aboutsummaryrefslogtreecommitdiff
path: root/pyaggr3g470r/controllers/abstract.py
diff options
context:
space:
mode:
Diffstat (limited to 'pyaggr3g470r/controllers/abstract.py')
-rw-r--r--pyaggr3g470r/controllers/abstract.py69
1 files changed, 69 insertions, 0 deletions
diff --git a/pyaggr3g470r/controllers/abstract.py b/pyaggr3g470r/controllers/abstract.py
new file mode 100644
index 00000000..a99e67f3
--- /dev/null
+++ b/pyaggr3g470r/controllers/abstract.py
@@ -0,0 +1,69 @@
+import logging
+from bootstrap import db
+from sqlalchemy import update
+from werkzeug.exceptions import Forbidden, NotFound
+
+logger = logging.getLogger(__name__)
+
+
+class AbstractController(object):
+ _db_cls = None # reference to the database class
+ _user_id_key = 'user_id'
+
+ def __init__(self, user_id):
+ self.user_id = user_id
+
+ def _to_filters(self, **filters):
+ if self.user_id:
+ filters[self._user_id_key] = self.user_id
+ db_filters = set()
+ for key, value in filters.items():
+ if key.endswith('__gt'):
+ db_filters.add(getattr(self._db_cls, key[:-4]) > value)
+ elif key.endswith('__lt'):
+ db_filters.add(getattr(self._db_cls, key[:-4]) < value)
+ elif key.endswith('__ge'):
+ db_filters.add(getattr(self._db_cls, key[:-4]) >= value)
+ elif key.endswith('__le'):
+ db_filters.add(getattr(self._db_cls, key[:-4]) <= value)
+ elif key.endswith('__ne'):
+ db_filters.add(getattr(self._db_cls, key[:-4]) != value)
+ elif key.endswith('__in'):
+ db_filters.add(getattr(self._db_cls, key[:-4]).in_(value))
+ else:
+ db_filters.add(getattr(self._db_cls, key) == value)
+ return db_filters
+
+ def _get(self, **filters):
+ return self._db_cls.query.filter(*self._to_filters(**filters))
+
+ def get(self, **filters):
+ obj = self._get(**filters).first()
+ if not obj:
+ raise NotFound({'message': 'No %r (%r)'
+ % (self._db_cls.__class__.__name__, filters)})
+ if getattr(obj, self._user_id_key) != self.user_id:
+ raise Forbidden({'message': 'No authorized to access %r (%r)'
+ % (self._db_cls.__class__.__name__, filters)})
+ return obj
+
+ def create(self, **attrs):
+ attrs[self._user_id_key] = self.user_id
+ obj = self._db_cls(**attrs)
+ db.session.add(obj)
+ db.session.commit()
+ return obj
+
+ def read(self, **filters):
+ return self._get(**filters)
+
+ def update(self, filters, attrs):
+ result = self._get(**filters).update(attrs, synchronize_session=False)
+ db.session.commit()
+ return result
+
+ def delete(self, obj_id):
+ obj = self.get(id=obj_id)
+ db.session.delete(obj)
+ db.session.commit()
+ return obj
bgstack15