## flask-sqlalchemy 的痛点
我们在使用 Flask 开发接口的时候,经常需要将数据库的模型直接转化为 JSON 给前端调用。之前我们一直使用非常笨的方法,比如将字段一个一个取出来,然后赋值给 dict,或者写一个蹩脚的转换器。在新项目中,我改进了这种工作方式,给 model 统一添加了 to_dict 与 from_dict 方法,以便更方便地开发符合 RESTful 规范的接口。
## SQLAlchemy Model To Dict
首先我们创建一个 BaseModel,让其他模型都继承自它:
from flask import json
from flask_sqlalchemy import SQLAlchemy
from sqlalchemy.orm.attributes import QueryableAttribute
from wakatime_website import app
# Flask-SQLAlchemy extension bound to the app; provides the db.Model base and
# the db.Column / db.relationship helpers used by the models below.
db = SQLAlchemy(app)
class BaseModel(db.Model):
    """Abstract base model providing a configurable ``to_dict`` serializer.

    Subclasses may declare these optional class attributes, read by
    :meth:`to_dict`:

    ``_default_fields``
        Names serialized even when not requested through ``show``.
        ``id``, ``modified_at`` and ``created_at`` are always treated as
        default fields as well.
    ``_hidden_fields``
        Names that are never serialized, even when requested.
    """

    __abstract__ = True

    def to_dict(self, show=None, _hide=None, _path=None):
        """Return a dictionary representation of this model.

        Three kinds of attribute are considered: table columns, mapper
        relationships, and other class attributes that are ``property`` or
        ``QueryableAttribute`` objects. Names starting with ``_`` are skipped.
        An attribute is included when its path is in ``show`` or its name is
        a default field, and it is neither in ``_hide`` nor a hidden field.

        Args:
            show: Optional list of extra field paths to include, relative to
                this model, e.g. ``["goals", "goals.accomplished"]``. Paths
                already prefixed with the lower-cased root table name are
                accepted as well. The list passed in is not modified.
            _hide: Internal. Field paths to exclude. Each relationship that is
                expanded is added to it before its related objects are
                serialized.
            _path: Internal. Dotted path of this object from the root of the
                serialization; defaults to the lower-cased ``__tablename__``.

        Returns:
            dict: column values as read from the instance; to-many
            relationships as lists of dicts; to-one relationships as a dict
            (or ``None`` when unset); other attributes via their own
            ``to_dict`` if they have one, otherwise round-tripped through
            JSON (values that cannot be JSON-encoded are omitted).
        """
        hidden = getattr(self, "_hidden_fields", [])
        # Build a new list: extending ``_default_fields`` itself would grow
        # the class-level list on every call.
        default = list(getattr(self, "_default_fields", []))
        default.extend(["id", "modified_at", "created_at"])

        if not _path:
            _path = self.__tablename__.lower()
        # Paths handed down by a parent call are already qualified with the
        # root table name (e.g. "user.goals.accomplished" inside the nested
        # "user.goals" call), so compare against the root segment only;
        # comparing with the whole dotted path would prefix them a second time.
        root = _path.split(".", 1)[0]

        def prepend_path(item):
            """Return ``item`` lower-cased and qualified with ``_path``."""
            item = item.lower()
            if item.split(".", 1)[0] == root:
                return item
            if not item:
                return item
            if not item.startswith("."):
                item = "." + item
            return _path + item

        # Fresh lists: the caller's ``show`` is left untouched and no default
        # list is shared between calls (``_hide`` is appended to below).
        _hide = [prepend_path(x) for x in (_hide or [])]
        show = [prepend_path(x) for x in (show or [])]

        columns = self.__table__.columns.keys()
        relationships = self.__mapper__.relationships.keys()
        properties = dir(self)
        ret_data = {}

        def wanted(key):
            """Return True if attribute ``key`` of this object should be output."""
            if key.startswith("_"):
                return False
            check = "%s.%s" % (_path, key)
            if check in _hide or key in hidden:
                return False
            return check in show or key in default

        def nested(value, key):
            """Serialize a related object one level deeper in the path."""
            return value.to_dict(
                show=list(show),
                _hide=list(_hide),
                _path="%s.%s" % (_path, key.lower()),
            )

        for key in columns:
            if wanted(key):
                ret_data[key] = getattr(self, key)

        for key in relationships:
            if not wanted(key):
                continue
            # Hide this relationship path in the nested calls below (they
            # receive copies of ``_hide``).
            _hide.append("%s.%s" % (_path, key))
            rel = self.__mapper__.relationships[key]
            if rel.uselist:
                items = getattr(self, key)
                # Query-backed collections (e.g. lazy="dynamic") expose .all().
                if rel.query_class is not None and hasattr(items, "all"):
                    items = items.all()
                ret_data[key] = [nested(item, key) for item in items]
            elif rel.query_class is not None or rel.instrument_class is not None:
                item = getattr(self, key)
                ret_data[key] = None if item is None else nested(item, key)
            else:
                ret_data[key] = getattr(self, key)

        skip = set(columns) | set(relationships)
        # dir() is sorted, so extra attributes come out in a stable order.
        for key in properties:
            if key in skip or key.startswith("_"):
                continue
            attr = getattr(self.__class__, key, None)
            if not isinstance(attr, (property, QueryableAttribute)):
                continue
            if not wanted(key):
                continue
            val = getattr(self, key)
            if hasattr(val, "to_dict"):
                ret_data[key] = nested(val, key)
            else:
                try:
                    ret_data[key] = json.loads(json.dumps(val))
                except (TypeError, ValueError):
                    # Not JSON-encodable: omit it rather than fail the whole dict.
                    pass
        return ret_data
在这段代码中,我们给 BaseModel 添加了 to_dict 方法:id、modified_at、created_at 总是作为默认字段输出,此外模型还可以定义以下三个可选的类属性(其中 _readonly_fields 不会被 to_dict 读取):
- _default_fields
- _hidden_fields
- _readonly_fields
Example:
class User(BaseModel):
    """Example model: a user account serialized through ``BaseModel.to_dict``."""

    id = db.Column(UUID(), primary_key=True, default=uuid.uuid4)
    username = db.Column(db.String(), nullable=False, unique=True)
    password = db.Column(db.String())
    email_confirmed = db.Column(db.Boolean())
    modified_at = db.Column(db.DateTime())
    created_at = db.Column(db.DateTime(), nullable=False, default=datetime.utcnow)

    # Output by to_dict() even when not requested via ``show``.
    _default_fields = [
        "username",
        "joined_recently",
    ]
    # Never output by to_dict().
    _hidden_fields = [
        "password",
    ]
    # Not read by to_dict(); presumably meant for from_dict() — confirm.
    _readonly_fields = [
        "email_confirmed",
    ]

    @property
    def joined_recently(self):
        """Return True if the user was created within the last three days.

        ``created_at`` is None until the column default has been applied
        (e.g. on an unsaved user); report False then instead of raising
        ``TypeError`` from comparing None with a datetime.
        """
        if self.created_at is None:
            return False
        return self.created_at > datetime.utcnow() - timedelta(days=3)
# Create and save a sample user.
user = User(username="zzzeek")
db.session.add(user)
db.session.commit()
# No `show` argument: only default fields are output; "password" is hidden.
print(user.to_dict())
可以看到,我们支持默认输出字段、默认隐藏字段,对 @property 也同样支持。
to_dict方法同样支持关联模型
Example:
# Sketch of an addition to the User model above: "..." stands in for the
# columns and field lists already shown there, so in a real module add
# `goals` to that class rather than redefining User.
class User(BaseModel):
...
# lazy='dynamic': to_dict() calls .all() on this collection when available.
goals = db.relationship('Goal', backref='user', lazy='dynamic')
class Goal(BaseModel):
    """Example model: a goal that belongs to a User (see ``User.goals``)."""

    id = db.Column(UUID(), primary_key=True, default=uuid.uuid4)
    # Foreign key needed by the User.goals relationship and its ``user``
    # backref: without one SQLAlchemy cannot determine the join condition.
    # NOTE(review): assumes User's table is named "user" (the Flask-SQLAlchemy
    # default for a class called User) — adjust if __tablename__ is overridden.
    user_id = db.Column(UUID(), db.ForeignKey("user.id"))
    title = db.Column(db.String(), nullable=False)
    accomplished = db.Column(db.Boolean())
    created_at = db.Column(db.DateTime(), nullable=False, default=datetime.utcnow)

    # Output by to_dict() even when not requested via ``show``.
    _default_fields = [
        "title",
    ]
# Add a goal to the user's goals collection and save it.
goal = Goal(title="Mountain", accomplished=True)
user.goals.append(goal)
db.session.commit()
# Expand the "goals" relationship and also request the non-default
# "accomplished" column on each nested goal.
print(user.to_dict(show=['goals', 'goals.accomplished']))
它将打印如下内容:
{
'id': UUID('488345de-88a1-4c87-9304-46a1a31c9414'),
'username': 'zzzeek',
'goals': [
{
'id': UUID('c72cfef0-0988-45e4-9f4b-8a4a7d4f8d8f'),
'title': 'Mountain',
'accomplished': True,
'created_at': datetime.datetime(2018, 7, 11, 6, 45, 18, 299924),
},
],
'joined_recently': True,
'modified_at': datetime.datetime(2018, 7, 11, 6, 36, 47, 939084),
'created_at': datetime.datetime(2018, 7, 11, 6, 28, 56, 905379),
}
是不是方便多了?