跳过正文

Sqlalchemy中model的增强方法

·510 字·3 分钟
Remy
作者
Remy
Bug的设计师,故障的制造机,P0的背锅侠。代码里的隐秘问题总能被我创造性地解锁。写代码如同解谜,有时谜底是惊喜,有时是惊吓。

flask-sqlalchemy的痛点
#

我们在使用flask开发接口的时候,经常需要将数据库的模型直接转化为json给前端调用。之前我们一直使用非常笨的方法,比如将字段一个一个取出来,然后赋值给dict,或者写一个蹩脚的转换器。在新项目中,我改进了这种工作方式,给model统一添加 to_dict 和 from_dict 方法,可以更方便地开发符合restful规范的接口。

Sqlalchemy Model To Dict
#

首先我们创建一个BaseModel,让其他模型都继承自它

from flask import json
from flask_sqlalchemy import SQLAlchemy
from sqlalchemy.orm.attributes import QueryableAttribute
from wakatime_website import app

# Flask-SQLAlchemy handle created from the imported `app`; the models below
# derive from `db.Model` and use `db.Column` / `db.session`.
db = SQLAlchemy(app)

class BaseModel(db.Model):
    """Abstract base model that adds dictionary serialization to its subclasses.

    ``to_dict`` consults two optional class-level lists on the concrete model:

    * ``_default_fields`` -- keys that are serialized without being asked for.
    * ``_hidden_fields`` -- keys that are never serialized.
    """

    __abstract__ = True

    def to_dict(self, show=None, _hide=None, _path=None):
        """Return a dictionary representation of this model.

        Columns, relationships and ``property``/mapped class attributes are
        included when they are listed in ``show`` or in the default field set
        (``_default_fields`` plus ``id``, ``modified_at`` and ``created_at``),
        and are not hidden. Keys starting with ``_`` are always skipped.

        Args:
            show: Optional iterable of extra keys to include. Nested keys use
                dotted paths relative to this model, e.g. ``"goals"`` or
                ``"goals.accomplished"``. The caller's object is not modified.
            _hide: Internal. Dotted paths that must not be serialized; used
                during recursion so an expanded relationship is not expanded
                again further down.
            _path: Internal. Dotted path of this object within the output;
                defaults to the lower-cased table name at the top level.

        Returns:
            A dict mapping each selected key to its value. Related models and
            values exposing ``to_dict`` are serialized recursively.
        """
        # Work on private copies: the original mutated the shared default
        # ``_hide`` list and the caller's ``show`` list in place.
        show = list(show) if show else []
        _hide = list(_hide) if _hide else []

        hidden = getattr(self, "_hidden_fields", [])
        # Copy before extending so the class-level ``_default_fields`` list
        # does not grow by three entries on every call.
        default = list(getattr(self, "_default_fields", []))
        default.extend(["id", "modified_at", "created_at"])

        if not _path:
            _path = self.__tablename__.lower()

            def prepend_path(item):
                """Qualify a user-supplied key with the root path."""
                item = item.lower()
                if not item or item.split(".", 1)[0] == _path:
                    return item
                if item[0] != ".":
                    item = ".%s" % item
                return "%s%s" % (_path, item)

            _hide = [prepend_path(x) for x in _hide]
            show = [prepend_path(x) for x in show]

        def is_selected(key):
            """Tell whether ``key`` should appear in the output of this object."""
            check = "%s.%s" % (_path, key)
            if check in _hide or key in hidden:
                return False
            return check in show or key in default

        def child_kwargs(key):
            """Build the arguments for serializing the value stored under ``key``."""
            return {
                "show": list(show),
                "_hide": list(_hide),
                "_path": "%s.%s" % (_path, key.lower()),
            }

        columns = self.__table__.columns.keys()
        relationships = self.__mapper__.relationships.keys()
        properties = dir(self)

        ret_data = {}

        for key in columns:
            if key.startswith("_") or not is_selected(key):
                continue
            ret_data[key] = getattr(self, key)

        for key in relationships:
            if key.startswith("_") or not is_selected(key):
                continue
            # Hide this path for everything serialized from here on, so the
            # same relationship is not expanded a second time.
            _hide.append("%s.%s" % (_path, key))
            relationship = self.__mapper__.relationships[key]
            if relationship.uselist:
                items = getattr(self, key)
                # Query-backed collections (e.g. lazy="dynamic") have to be
                # executed before they can be iterated as model instances.
                if relationship.query_class is not None and hasattr(items, "all"):
                    items = items.all()
                ret_data[key] = [item.to_dict(**child_kwargs(key)) for item in items]
            elif (
                relationship.query_class is not None
                or relationship.instrument_class is not None
            ):
                item = getattr(self, key)
                if item is not None:
                    ret_data[key] = item.to_dict(**child_kwargs(key))
                else:
                    ret_data[key] = None
            else:
                ret_data[key] = getattr(self, key)

        # ``dir()`` is sorted, so iterating it directly (instead of a set
        # difference) keeps the order of property keys deterministic.
        mapped = set(columns) | set(relationships)
        for key in properties:
            if key in mapped or key.startswith("_"):
                continue
            attr = getattr(self.__class__, key, None)
            if not isinstance(attr, (property, QueryableAttribute)):
                continue
            if not is_selected(key):
                continue
            val = getattr(self, key)
            if hasattr(val, "to_dict"):
                ret_data[key] = val.to_dict(**child_kwargs(key))
            else:
                try:
                    # Round-trip through JSON to keep only serializable values.
                    ret_data[key] = json.loads(json.dumps(val))
                except (TypeError, ValueError, OverflowError):
                    # Best effort: a value that cannot be JSON-encoded is
                    # left out of the result rather than failing the call.
                    pass

        return ret_data

在这段代码中,我们给BaseModel添加了to_dict方法,并约定了模型上的三个配置字段(上面的to_dict本身只读取前两个):

  • _default_fields
  • _hidden_fields
  • _readonly_fields

Example:

class User(BaseModel):
    """Example model showing the field lists that ``BaseModel.to_dict`` reads."""

    id = db.Column(UUID(), primary_key=True, default=uuid.uuid4)
    # Fixed typo: the keyword is ``nullable``; ``nullabe`` is not a Column argument.
    username = db.Column(db.String(), nullable=False, unique=True)
    password = db.Column(db.String())
    email_confirmed = db.Column(db.Boolean())
    modified_at = db.Column(db.DateTime())
    created_at = db.Column(db.DateTime(), nullable=False, default=datetime.utcnow)

    # Serialized by to_dict() even when not requested through ``show``.
    _default_fields = [
        "username",
        "joined_recently",
    ]
    # Never serialized by to_dict().
    _hidden_fields = [
        "password",
    ]
    # Not read by to_dict(); presumably consumed by the from_dict counterpart
    # mentioned in the article -- TODO confirm.
    _readonly_fields = [
        "email_confirmed",
    ]

    @property
    def joined_recently(self):
        """Return True when the user was created within the last three days."""
        # ``created_at`` is filled in by the column default, so it can still
        # be None on an instance that has not been flushed yet.
        if self.created_at is None:
            return False
        return self.created_at > datetime.utcnow() - timedelta(days=3)

# Create a user and persist it through the session.
user = User(username="zzzeek")
db.session.add(user)
db.session.commit()

# No `show` list is passed, so only default, non-hidden fields can appear.
print(user.to_dict())

可以看到,我们支持默认输出字段、默认隐藏字段,对@property也同样支持

to_dict方法同样支持关联模型 Example:

class User(BaseModel):
    # Placeholder for the rest of the class body -- presumably the fields of
    # the earlier User example; confirm against the full model.
    ...
    # Relationship to Goal; the backref argument names the reverse attribute
    # ("user") on the Goal side.
    goals = db.relationship('Goal', backref='user', lazy='dynamic')

class Goal(BaseModel):
    """Example related model, serialized through the ``goals`` relationship."""

    id = db.Column(UUID(), primary_key=True, default=uuid.uuid4)
    # Fixed typo: the keyword is ``nullable``; ``nullabe`` is not a Column argument.
    title = db.Column(db.String(), nullable=False)
    accomplished = db.Column(db.Boolean())
    created_at = db.Column(db.DateTime(), nullable=False, default=datetime.utcnow)

    # Serialized by to_dict() even when not requested through ``show``.
    _default_fields = [
        "title",
    ]

# Attach a goal to the existing user and persist it.
goal = Goal(title="Mountain", accomplished=True)
user.goals.append(goal)
db.session.commit()

# `show` opts in to the `goals` relationship and, for each goal, to the
# `accomplished` column, which is not among Goal's default fields.
print(user.to_dict(show=['goals', 'goals.accomplished']))

它将打印如下内容

{
    'id': UUID('488345de-88a1-4c87-9304-46a1a31c9414'),
    'username': 'zzzeek',
    'goals': [
        {
            'id': UUID('c72cfef0-0988-45e4-9f4b-8a4a7d4f8d8f'),
            'title': 'Mountain',
            'accomplished': True,
            'created_at': datetime.datetime(2018, 7, 11, 6, 45, 18, 299924),
        },
    ],
    'joined_recently': True,
    'modified_at': datetime.datetime(2018, 7, 11, 6, 36, 47, 939084),
    'created_at': datetime.datetime(2018, 7, 11, 6, 28, 56, 905379),
}

是不是方便多了?