database.py 11 KB


  1. # -*- coding: utf-8 -*-
  2. """Database module, including the SQLAlchemy database object and DB-related utilities."""
  3. from pprint import pformat
  4. from sqlalchemy import desc, or_
  5. from sqlalchemy.sql.sqltypes import Date, DateTime
  6. from werkzeug import cached_property
  7. from flask import current_app
  8. from walle.service.extensions import db
  9. from walle.service.utils import basestring
  10. from walle.service.utils import datetime_str_to_obj, date_str_to_obj
  # Short module-level aliases for common SQLAlchemy names exposed on the
  # shared ``db`` object.
  Column, relationship = db.Column, db.relationship
  def _like_pattern(value):
      """Wrap ``value`` in SQL ``%`` wildcards so LIKE performs a substring match.

      NOTE: ``%`` and ``_`` inside ``value`` are not escaped and therefore act
      as wildcards themselves.
      """
      # ``(value,)`` stops a tuple ``value`` from being unpacked by ``%``.
      return '%%%s%%' % (value,)


  # Expression builders keyed by operator name. Each value is a callable
  # ``f(cls, field_name, value)`` returning the SQLAlchemy expression built from
  # ``getattr(cls, field_name)``. Several operators have both a symbolic and a
  # mnemonic alias ('>' / 'gt', '!=' / 'ne' / 'neq', ...).
  OPERATOR_FUNC_DICT = {
      # Comparisons.
      '=': (lambda cls, k, v: getattr(cls, k) == v),
      '==': (lambda cls, k, v: getattr(cls, k) == v),
      'eq': (lambda cls, k, v: getattr(cls, k) == v),
      '!=': (lambda cls, k, v: getattr(cls, k) != v),
      'ne': (lambda cls, k, v: getattr(cls, k) != v),
      'neq': (lambda cls, k, v: getattr(cls, k) != v),
      '>': (lambda cls, k, v: getattr(cls, k) > v),
      'gt': (lambda cls, k, v: getattr(cls, k) > v),
      '>=': (lambda cls, k, v: getattr(cls, k) >= v),
      'gte': (lambda cls, k, v: getattr(cls, k) >= v),
      '<': (lambda cls, k, v: getattr(cls, k) < v),
      'lt': (lambda cls, k, v: getattr(cls, k) < v),
      '<=': (lambda cls, k, v: getattr(cls, k) <= v),
      'lte': (lambda cls, k, v: getattr(cls, k) <= v),
      # Membership. 'or' ORs together ``column == item`` for every item in ``v``.
      'or': (lambda cls, k, v: or_(getattr(cls, k) == value for value in v)),
      'in': (lambda cls, k, v: getattr(cls, k).in_(v)),
      'nin': (lambda cls, k, v: ~getattr(cls, k).in_(v)),
      # Substring match. 'nlike' wraps the value exactly as 'like' does, so it
      # is the negation of 'like' rather than NOT LIKE on the raw value.
      'like': (lambda cls, k, v: getattr(cls, k).like(_like_pattern(v))),
      'nlike': (lambda cls, k, v: ~getattr(cls, k).like(_like_pattern(v))),
      # Arithmetic: these build value expressions, not boolean filters.
      '+': (lambda cls, k, v: getattr(cls, k) + v),
      'incr': (lambda cls, k, v: getattr(cls, k) + v),
      '-': (lambda cls, k, v: getattr(cls, k) - v),
      'decr': (lambda cls, k, v: getattr(cls, k) - v),
  }
  def parse_operator(cls, filter_name_dict):
      """Build SQLAlchemy filter expressions from a filter-spec dict.

      Args:
          cls (class): model class whose attributes are filtered.
          filter_name_dict (dict): field name -> ``{operator: value}``, e.g.::

              {
                  'last_name': {'eq': 'wang'},  # the inner dict's key is the operator
                  'age': {'>': 12},
              }

      Returns:
          list: expressions produced by ``OPERATOR_FUNC_DICT``, suitable for
          ``query.filter(*expressions)``. Operators missing from
          ``OPERATOR_FUNC_DICT`` are silently skipped.
      """
      def _change_type(cls, field, value):
          """Convert string values compared against Date/DateTime columns.

          Values coming from the front end are strings, which cannot be
          compared directly with Date/DateTime columns, so they are converted
          with the matching ``*_str_to_obj`` helper. list/tuple/set values
          (as used by 'in', 'nin' and 'or') are converted element by element.
          Values for any other column type are returned unchanged.

          Args:
              cls (class): model class.
              field (str): model attribute name.
              value: value to compare against.
          """
          field_type = getattr(cls, field).type
          if isinstance(field_type, Date):
              convert = date_str_to_obj
          elif isinstance(field_type, DateTime):
              convert = datetime_str_to_obj
          else:
              return value
          if isinstance(value, (list, tuple, set)):
              return [convert(item) for item in value]
          return convert(value)

      binary_expression_list = []
      for field, op_dict in filter_name_dict.items():
          for op, op_val in op_dict.items():
              build = OPERATOR_FUNC_DICT.get(op)
              if build is None:
                  # Unknown operator: skip it before converting, so a value
                  # that is never used cannot make the conversion fail.
                  continue
              binary_expression_list.append(
                  build(cls, field, _change_type(cls, field, op_val))
              )
      return binary_expression_list
  class CRUDMixin(object):
      """Mixin adding convenience CRUD (create, read, update, delete) methods.

      Meant to be mixed into a Flask-SQLAlchemy model: the methods use the
      shared ``db.session`` and the model's ``__table__``.
      """

      @classmethod
      def create(cls, **kwargs):
          """Create a record from keyword arguments and save it to the database."""
          instance = cls(**kwargs)
          return instance.save()

      @classmethod
      def create_from_dict(cls, d):
          """Create a record from dict ``d`` and save it to the database."""
          assert isinstance(d, dict)
          instance = cls(**d)
          return instance.save()

      def update(self, commit=True, **kwargs):
          """Set the given attributes on this record.

          Args:
              commit (bool): when true, also save (add + commit) the record.
              **kwargs: attribute name -> new value.

          Returns:
              self
          """
          for attr, value in kwargs.items():
              setattr(self, attr, value)
          return self.save() if commit else self

      def save(self, commit=True):
          """Add this record to the session and optionally commit.

          On commit failure the error is logged and the session rolled back;
          the exception is not re-raised, so ``self`` is returned either way.
          """
          db.session.add(self)
          if commit:
              try:
                  db.session.commit()
              except Exception:
                  # Log at ERROR with the traceback: the caller never sees
                  # this failure, so the log is the only trace of it.
                  current_app.logger.exception(
                      'Failed to save %s record', type(self).__name__)
                  db.session.rollback()
          return self

      def delete(self, commit=True):
          """Remove this record from the database.

          On commit failure the error is logged and the session rolled back;
          the exception is not re-raised. Returns ``self``.
          """
          db.session.delete(self)
          if commit:
              try:
                  db.session.commit()
              except Exception:
                  current_app.logger.exception(
                      'Failed to delete %s record', type(self).__name__)
                  db.session.rollback()
          return self

      def to_dict(self, fields_list=None):
          """Return this record's field values as a dict.

          Args:
              fields_list (list of str): names of the fields to include;
                  defaults to every column of ``__table__``.

          Returns:
              dict: field name -> attribute value.
          """
          column_list = fields_list or [
              column.name for column in self.__table__.columns
          ]
          return {name: getattr(self, name) for name in column_list}

      @classmethod
      def create_or_update(cls, query_dict, update_dict=None):
          """Update the first record matching ``query_dict``, or create one.

          Args:
              query_dict (dict): ``filter_by`` criteria used to find the record.
              update_dict (dict): attributes to set. If a record is found and
                  this is None, the record is returned unchanged.

          Returns:
              The updated, unchanged, or newly created instance. A new instance
              is built from ``query_dict`` merged with ``update_dict``.
          """
          instance = db.session.query(cls).filter_by(**query_dict).first()
          if not instance:
              # Merge into a copy so the caller's query_dict is left untouched.
              data = dict(query_dict)
              data.update(update_dict or {})
              return cls.create(**data)
          if update_dict is not None:
              return instance.update(**update_dict)
          return instance

      @classmethod
      def query_paginate(cls, page=1, limit=20, fields=None,
                         order_by_list=(('id', 'desc'),),
                         filter_name_dict=None):
          """Generic paginated query.

          Args:
              page (int): 1-based page number; values below 1 are treated as 1.
              limit (int): number of rows per page.
              fields (list of str): column names to select; when omitted,
                  whole model instances are returned.
              order_by_list (sequence of (field, order) tuples): sort keys,
                  applied in order. ``order`` may be 1 / -1 or 'asc' / 'desc'
                  (case-insensitive): 1 and 'asc' sort ascending, anything else
                  sorts descending. Pass None or an empty sequence for no
                  ordering.
              filter_name_dict (dict): filter conditions keyed by field name,
                  each value an ``{operator: value}`` dict, e.g.::

                      {
                          'last_name': {'eq': 'wang'},  # inner key is the operator
                          'age': {'>': 12},
                      }

          Returns:
              tuple: ``(keytuple_list, total_cnt)`` when ``fields`` is given,
              otherwise ``(instance_list, total_cnt)``. ``total_cnt`` counts
              every matching row, ignoring pagination.

          Front-end query parameter conventions (``request.args``), e.g.::

              ImmutableMultiDict([('limit', '10'), ('page', '1'), ('filter', '[{"field":"name","op":"eq","q":"g"},{"field":"id","op":"gt","q":"5"}]')])

              page: page number
              limit: page size
              order: sort order, "asc" or "desc", e.g. 'name', 'asc', 'model', 'desc'
              fields: fields to return
              filter: list of conditions, each of the form
                  {
                      field: field to filter on
                      op: operator, one of "eq", "neq", "gt", "gte", "lt",
                          "lte", "in", "nin", "like"
                      q: value to compare against
                  }
          """
          columns = [getattr(cls, column) for column in fields] if fields else None
          query = db.session.query(*columns) if columns else db.session.query(cls)

          if filter_name_dict:
              query = query.filter(*parse_operator(cls, filter_name_dict))

          # The total ignores ordering and pagination, so count first.
          total_cnt = query.count()

          for field, order in order_by_list or ():
              column = getattr(cls, field)
              if isinstance(order, basestring):
                  ascending = order.lower() == 'asc'
              else:
                  ascending = order == 1
              query = query.order_by(column if ascending else desc(column))

          # Clamp so page <= 0 cannot produce a negative OFFSET.
          start = (max(page, 1) - 1) * limit
          return query.offset(start).limit(limit).all(), total_cnt

      @classmethod
      def dump_schema(cls, items, fields, schema_class):
          """Serialize objects queried from the database with a marshmallow schema.

          Args:
              items (list): objects queried from the database.
              fields (list of str): fields to dump; defaults to every field
                  declared on ``schema_class``.
              schema_class (marshmallow.Schema subclass): schema used to dump.

          Returns:
              tuple: ``(items, err)`` as unpacked from ``Schema.dump``.
          """
          # NOTE(review): unpacking ``(data, errors)`` matches marshmallow 2.x;
          # marshmallow 3 returns the data alone -- confirm the pinned version.
          fields = fields or list(schema_class._declared_fields.keys())
          schema = schema_class(many=True, only=fields)
          items, err = schema.dump(items)
          return items, err

      @classmethod
      def query_paginate_and_dump_schema(cls, page=1, limit=20, fields=None,
                                         order_by_list=None,
                                         filter_name_dict=None,
                                         schema_class=None):
          """Run ``query_paginate`` and serialize the page with ``schema_class``.

          ``order_by_list`` defaults to None here, i.e. no ordering, unlike
          ``query_paginate``'s own default.

          Returns:
              tuple: ``(dumped_items, total_cnt)``; serialization errors are
              discarded.
          """
          assert schema_class
          items, total_cnt = cls.query_paginate(
              page, limit, fields, order_by_list, filter_name_dict
          )
          items, _err = cls.dump_schema(items, fields, schema_class)
          return items, total_cnt

      def __repr__(self):
          return pformat(self.to_dict())

      @cached_property
      def column_name_set(self):
          """Set of this model's table column names (cached on the instance)."""
          return {column.name for column in self.__table__.columns}

      @classmethod
      def get_common_fields(cls, fields=None):
          """Return the subset of ``fields`` that are columns of this model.

          Guards against callers passing field names the model does not have.
          The result order is unspecified (it comes from a set intersection).
          """
          if not fields:
              return []
          table_fields_set = {column.name for column in cls.__table__.columns}
          return list(table_fields_set & set(fields))
  class Model(CRUDMixin, db.Model):
      """Abstract base for the application's models, with CRUD helpers mixed in.

      ``__abstract__ = True`` tells SQLAlchemy not to map a table for this
      class itself.
      """
      __abstract__ = True

      # Integer status codes inherited by every subclass.
      status_remove, status_default, status_available = -1, 0, 1
  # From Mike Bayer's "Building the app" talk
  # https://speakerdeck.com/zzzeek/building-the-app
  class SurrogatePK(object):
      """Mixin adding an integer surrogate primary key column named ``id``.

      Intended for declarative-mapped classes.
      """
      # SQLAlchemy Table option: allow (re)defining a table that is already
      # registered in the MetaData.
      __table_args__ = {'extend_existing': True}
      id = db.Column(db.Integer, primary_key=True)

      @classmethod
      def get_by_id(cls, record_id):
          """Return the record whose primary key is ``record_id``, or None.

          Accepts an int or a digit-only string. Any other value, including a
          string that ``str.isdigit`` accepts but ``int`` cannot parse (e.g.
          ``'²'``), yields None instead of raising.
          """
          if isinstance(record_id, basestring):
              if not record_id.isdigit():
                  return None
          elif not isinstance(record_id, int):
              return None
          try:
              pk = int(record_id)
          except ValueError:
              # str.isdigit() is true for characters such as superscript
              # digits that int() rejects.
              return None
          return cls.query.get(pk)
  255. ):
  256. return cls.query.get(int(record_id))
  257. return None
  258. def reference_col(tablename, nullable=False, pk_name='id', **kwargs):
  259. """Column that adds primary key foreign key reference.
  260. Usage: ::
  261. category_id = reference_col('category')
  262. category = relationship('Category', backref='categories')
  263. """
  264. return db.Column(
  265. db.ForeignKey('{0}.{1}'.format(tablename, pk_name)),
  266. nullable=nullable, **kwargs)