diff options
Diffstat (limited to 'pyaggr3g470r/controllers/abstract.py')
-rw-r--r-- | pyaggr3g470r/controllers/abstract.py | 43 |
1 files changed, 28 insertions, 15 deletions
diff --git a/pyaggr3g470r/controllers/abstract.py b/pyaggr3g470r/controllers/abstract.py index fe437b09..8960c3be 100644 --- a/pyaggr3g470r/controllers/abstract.py +++ b/pyaggr3g470r/controllers/abstract.py @@ -1,47 +1,60 @@ -from flask import g +from bootstrap import db from pyaggr3g470r.lib.exceptions import Forbidden, NotFound class AbstractController(object): _db_cls = None + _user_id_key = 'user_id' def __init__(self, user_id): self.user_id = user_id def _get(self, **filters): if self.user_id: - filters['user_id'] = self.user_id - db_filters = [getattr(self._db_cls, key) == value - for key, value in filters.iteritems()] - return self._db_cls.query.filter(*db_filters).first() + filters[self._user_id_key] = self.user_id + db_filters = set() + for key, value in filters.iteritems(): + if key.endswith('__gt'): + db_filters.add(getattr(self._db_cls, key[:-4]) > value) + elif key.endswith('__lt'): + db_filters.add(getattr(self._db_cls, key[:-4]) < value) + elif key.endswith('__ge'): + db_filters.add(getattr(self._db_cls, key[:-4]) >= value) + elif key.endswith('__le'): + db_filters.add(getattr(self._db_cls, key[:-4]) <= value) + elif key.endswith('__ne'): + db_filters.add(getattr(self._db_cls, key[:-4]) != value) + elif key.endswith('__in'): + db_filters.add(getattr(self._db_cls, key[:-4]).in_(value)) + else: + db_filters.add(getattr(self._db_cls, key) == value) + return self._db_cls.query.filter(*db_filters) def get(self, **filters): obj = self._get(**filters).first() if not obj: raise NotFound({'message': 'No %r (%r)' % (self._db_cls.__class__.__name__, filters)}) - if obj.user_id != self.user_id: + if getattr(obj, self._user_id_key) != self.user_id: raise Forbidden({'message': 'No authorized to access %r (%r)' % (self._db_cls.__class__.__name__, filters)}) return obj def create(self, **attrs): obj = self._db_cls(**attrs) - g.db.session.commit() + db.session.commit() return obj def read(self, **filters): return self._get(**filters) - def update(self, obj_id, **attrs): - obj = self.get(id=obj_id) - for key, values in attrs.iteritems(): - setattr(obj, key, values) - g.db.session.commit() - return obj + def update(self, filters, attrs): + result = self._get(**filters).update(attrs, synchronize_session=False) + db.session.commit() + return result def delete(self, obj_id): obj = self.get(id=obj_id) - g.db.session.delete(obj) - g.db.session.commit() + db.session.delete(obj) + db.session.commit() return obj |