aboutsummaryrefslogtreecommitdiff
path: root/pyaggr3g470r/controllers/abstract.py
diff options
context:
space:
mode:
Diffstat (limited to 'pyaggr3g470r/controllers/abstract.py')
-rw-r--r--pyaggr3g470r/controllers/abstract.py43
1 files changed, 28 insertions, 15 deletions
diff --git a/pyaggr3g470r/controllers/abstract.py b/pyaggr3g470r/controllers/abstract.py
index fe437b09..8960c3be 100644
--- a/pyaggr3g470r/controllers/abstract.py
+++ b/pyaggr3g470r/controllers/abstract.py
@@ -1,47 +1,60 @@
-from flask import g
+from bootstrap import db
from pyaggr3g470r.lib.exceptions import Forbidden, NotFound
class AbstractController(object):
_db_cls = None
+ _user_id_key = 'user_id'
def __init__(self, user_id):
self.user_id = user_id
def _get(self, **filters):
if self.user_id:
- filters['user_id'] = self.user_id
- db_filters = [getattr(self._db_cls, key) == value
- for key, value in filters.iteritems()]
- return self._db_cls.query.filter(*db_filters).first()
+ filters[self._user_id_key] = self.user_id
+ db_filters = set()
+ for key, value in filters.iteritems():
+ if key.endswith('__gt'):
+ db_filters.add(getattr(self._db_cls, key[:-4]) > value)
+ elif key.endswith('__lt'):
+ db_filters.add(getattr(self._db_cls, key[:-4]) < value)
+ elif key.endswith('__ge'):
+ db_filters.add(getattr(self._db_cls, key[:-4]) >= value)
+ elif key.endswith('__le'):
+ db_filters.add(getattr(self._db_cls, key[:-4]) <= value)
+ elif key.endswith('__ne'):
+ db_filters.add(getattr(self._db_cls, key[:-4]) != value)
+ elif key.endswith('__in'):
+ db_filters.add(getattr(self._db_cls, key[:-4]).in_(value))
+ else:
+ db_filters.add(getattr(self._db_cls, key) == value)
+ return self._db_cls.query.filter(*db_filters)
def get(self, **filters):
obj = self._get(**filters).first()
if not obj:
raise NotFound({'message': 'No %r (%r)'
% (self._db_cls.__class__.__name__, filters)})
- if obj.user_id != self.user_id:
+ if getattr(obj, self._user_id_key) != self.user_id:
raise Forbidden({'message': 'No authorized to access %r (%r)'
% (self._db_cls.__class__.__name__, filters)})
return obj
def create(self, **attrs):
obj = self._db_cls(**attrs)
- g.db.session.commit()
+ db.session.commit()
return obj
def read(self, **filters):
return self._get(**filters)
- def update(self, obj_id, **attrs):
- obj = self.get(id=obj_id)
- for key, values in attrs.iteritems():
- setattr(obj, key, values)
- g.db.session.commit()
- return obj
+ def update(self, filters, attrs):
+ result = self._get(**filters).update(attrs, synchronize_session=False)
+ db.session.commit()
+ return result
def delete(self, obj_id):
obj = self.get(id=obj_id)
- g.db.session.delete(obj)
- g.db.session.commit()
+ db.session.delete(obj)
+ db.session.commit()
return obj
bgstack15