class CollectItem(ActionMixin, db.Model):
    """Row of the ``collect_items`` table, tagged with action type 'collect'.

    Each row stores a ``user_id`` together with a ``target_id`` /
    ``target_kind`` pair.
    """

    __tablename__ = 'collect_items'
    action_type = 'collect'

    user_id = db.Column(db.Integer)
    target_id = db.Column(db.Integer)
    target_kind = db.Column(db.Integer)

    # Composite index: target first, then the acting user.
    __table_args__ = (
        db.Index('idx_ti_tk_ui', target_id, target_kind, user_id),
    )
class LikeItem(ActionMixin, db.Model):
    """Row of the ``like_items`` table, tagged with action type 'like'.

    Each row stores a ``user_id`` together with a ``target_id`` /
    ``target_kind`` pair.
    """

    __tablename__ = 'like_items'
    action_type = 'like'

    user_id = db.Column(db.Integer)
    target_id = db.Column(db.Integer)
    target_kind = db.Column(db.Integer)

    # Composite index: target first, then the acting user.
    __table_args__ = (
        db.Index('idx_ti', target_id, target_kind, user_id),
    )
class Tag(BaseMixin, db.Model):
    """A tag identified by a unique name (``tags`` table).

    Tags cannot be updated or deleted through the model API: both
    ``update`` and ``delete`` raise ``NotAllowedException``.
    """

    __tablename__ = 'tags'
    name = db.Column(db.String(128), default='', unique=True)

    __table_args__ = (
        db.Index('idx_name', name),
    )

    def __repr__(self):
        # __repr__ must return a str; ``name`` is None on an instance that
        # was constructed without a name, which would raise TypeError.
        return self.name or ''

    @classmethod
    def get_by_name(cls, name):
        """Return the first tag whose name equals ``name``, or None."""
        return cls.query.filter_by(name=name).first()

    def delete(self):
        """Always raises ``NotAllowedException``."""
        raise NotAllowedException

    def update(self, **kwargs):
        """Always raises ``NotAllowedException``."""
        raise NotAllowedException

    @classmethod
    def create(cls, **kwargs):
        """Create a tag, storing its name lower-cased.

        Raises:
            KeyError: if ``name`` is not supplied.

        Returns whatever the parent ``create`` returns.
        """
        name = kwargs.pop('name')
        kwargs['name'] = name.lower()
        return super().create(**kwargs)

    @classmethod
    def __flush_event__(cls, target):
        # Drop the cached entry stored under MC_KEY_ALL_TAGS.
        rdb.delete(MC_KEY_ALL_TAGS)
class CommentItem(ActionMixin, LikeMixin, db.Model):
    """Row of the ``comment_items`` table, tagged with action type 'comment'.

    Besides the user/target columns it carries a ``ref_id`` (default 0) and
    a ``content`` value managed by ``PropsItem``.
    """

    __tablename__ = 'comment_items'
    kind = K_COMMENT
    action_type = 'comment'

    user_id = db.Column(db.Integer)
    target_id = db.Column(db.Integer)
    target_kind = db.Column(db.Integer)
    ref_id = db.Column(db.Integer, default=0)
    content = PropsItem('content', '')

    __table_args__ = (
        db.Index('idx_ti_tk', target_id, target_kind, user_id),
    )

    @cached_hybrid_property
    def html_content(self):
        """Return the comment content unchanged."""
        return self.content

    @cached_hybrid_property
    def user(self):
        """Return the result of ``User.get`` for this row's ``user_id``."""
        return User.get(self.user_id)
class userFollowStats(BaseMixin, db.Model):
    """Per-id follower / following counters, both defaulting to 0."""

    follower_count = db.Column(db.Integer, default=0)
    following_count = db.Column(db.Integer, default=0)

    # __table_args__ = {
    #     'mysql_charset': 'utf8'
    # }

    @classmethod
    def get(cls, id):
        """Look the row up through ``cls.cache``."""
        return cls.cache.get(id)

    @classmethod
    def get_or_create(cls, id, **kw):
        """Return the row for ``id``, inserting a fresh one if none is found.

        A new row is added and committed through a session obtained from
        ``db.create_scoped_session()``. Extra keyword arguments are accepted
        but not used.
        """
        existing = cls.get(id)
        if existing:
            return existing

        session = db.create_scoped_session()
        fresh = cls(id=id)
        session.add(fresh)
        session.commit()
        return fresh
class Contact(BaseMixin, db.Model):
    """A (from_id, to_id) relation stored in the ``contacts`` table.

    The pair is unique (``uk_from_to``). Creating or deleting a row adjusts
    the ``userFollowStats`` counters of both ids and invalidates the related
    cache keys; rows cannot be updated.
    """

    __tablename__ = 'contacts'
    to_id = db.Column(db.Integer)
    from_id = db.Column(db.Integer)

    __table_args__ = (
        db.UniqueConstraint('from_id', 'to_id', name='uk_from_to'),
        db.Index('idx_to_time_from', to_id, 'created_at', from_id),
        db.Index('idx_time_to_from', 'created_at', to_id, from_id),
    )

    def update(self, **kwargs):
        """Always raises ``NotAllowedException``."""
        raise NotAllowedException

    @classmethod
    def create(cls, **kwargs):
        """Create the relation and, on success, update counters and feeds.

        Returns the ``(ok, obj)`` pair produced by the parent ``create``.
        """
        ok, obj = super().create(**kwargs)
        if ok:
            # Only touch the counters when a row was really created;
            # otherwise a repeated call would inflate both counts.
            cls.clear_mc(obj, 1)
            from handler.tasks import feed_to_followers
            feed_to_followers.delay(obj.from_id, obj.to_id)
        return ok, obj

    def delete(self):
        """Delete the relation, decrement counters and schedule feed cleanup."""
        super().delete()
        self.clear_mc(self, -1)
        from handler.tasks import remove_user_posts_from_feed
        remove_user_posts_from_feed.delay(self.from_id, self.to_id)

    @classmethod
    @cache(MC_KEY_FOLLOWERS % ('{user_id}', '{page}'))
    def get_follower_ids(cls, user_id, page=1):
        """Return one page of ``from_id`` values whose ``to_id`` is ``user_id``.

        The pagination object's ``items`` are flattened to plain ids and its
        ``query`` attribute is removed before it is returned.
        """
        query = cls.query.with_entities(cls.from_id).filter_by(
            to_id=user_id)
        followers = query.paginate(page, PER_PAGE)
        followers.items = [follower_id for follower_id, in followers.items]
        del followers.query
        return followers

    @classmethod
    @cache(MC_KEY_FOLLOWING % ('{user_id}', '{page}'))
    def get_following_ids(cls, user_id, page=1):
        """Return one page of ``to_id`` values whose ``from_id`` is ``user_id``.

        The pagination object's ``items`` are flattened to plain ids and its
        ``query`` attribute is removed before it is returned.
        """
        query = cls.query.with_entities(cls.to_id).filter_by(
            from_id=user_id)
        following = query.paginate(page, PER_PAGE)
        following.items = [followed_id for followed_id, in following.items]
        del following.query
        return following

    @classmethod
    @cache(MC_KEY_FOLLOW_ITEM % ('{from_id}', '{to_id}'))
    def get_follow_item(cls, from_id, to_id):
        """Return the row for ``(from_id, to_id)``, or None."""
        return cls.query.filter_by(from_id=from_id, to_id=to_id).first()

    @classmethod
    def clear_mc(cls, target, amount):
        """Apply ``amount`` to both counters and drop the affected cache keys.

        Args:
            target: object exposing ``to_id`` and ``from_id``.
            amount: delta added to ``to_id``'s follower count and
                ``from_id``'s following count.
        """
        to_id = target.to_id
        from_id = target.from_id

        st = userFollowStats.get_or_create(to_id)
        follower_count = st.follower_count or 0
        st.follower_count = follower_count + amount
        st.save()

        st = userFollowStats.get_or_create(from_id)
        following_count = st.following_count or 0
        st.following_count = following_count + amount
        st.save()

        rdb.delete(MC_KEY_FOLLOW_ITEM % (from_id, to_id))

        # Page count is derived from the totals read before the change;
        # at least page 1 is always invalidated.
        for user_id, total, mc_key in (
                (to_id, follower_count, MC_KEY_FOLLOWERS),
                (from_id, following_count, MC_KEY_FOLLOWING)):
            pages = math.ceil((max(total, 0) or 1) / PER_PAGE)
            for p in range(1, pages + 1):
                rdb.delete(mc_key % (user_id, p))
class Post(BaseMixin, CommentMixin, LikeMixin, CollectMixin, db.Model):
    """A post (``posts`` table) with author, title, source URL and content."""

    __tablename__ = 'posts'
    author_id = db.Column(db.Integer)
    title = db.Column(db.String(128), default='')
    orig_url = db.Column(db.String(255), default='')
    can_comment = db.Column(db.Boolean, default=True)
    content = PropsItem('content', '')
    kind = K_POST

    __table_args__ = (
        db.Index('idx_title', title),
        db.Index('idx_authorId', author_id)
    )

    def url(self):
        """Return the site-relative URL ``/<lowercased class name>/<id>/``."""
        return '/{}/{}/'.format(self.__class__.__name__.lower(), self.id)

    @classmethod
    def __flush_event__(cls, target):
        # Drop the cached entry stored under MC_KEY_ALL_TAGS.
        rdb.delete(MC_KEY_ALL_TAGS)

    @classmethod
    def get(cls, identifier):
        """Fetch a post by numeric id, otherwise by exact title."""
        if is_numeric(identifier):
            return cls.cache.get(identifier)
        return cls.cache.filter(title=identifier).first()

    @property
    @cache(MC_KEY_POST_TAGS % ('{self.id}'))
    def tags(self):
        """Return the ``Tag`` rows linked to this post through ``PostTag``."""
        at_ids = PostTag.query.with_entities(
            PostTag.tag_id).filter(
                PostTag.post_id == self.id
        ).all()
        # Materialise the ids: ``in_`` is given a concrete list rather than
        # a generator, and an empty list needs no second query at all.
        tag_ids = [tag_id for tag_id, in at_ids]
        if not tag_ids:
            return []
        tags = Tag.query.filter(Tag.id.in_(tag_ids)).all()
        return tags

    @classmethod
    @cache(MC_KEY_POST_LIST % ('{page}', '{per_page}', '{order_by}'))
    def get_posts_list(cls, page=1, per_page=10, order_by='id'):
        """Return one page of posts ordered by ``order_by``."""
        query = cls.query.filter().order_by(order_by)
        posts = query.paginate(page, per_page)
        del posts.query  # Fix `TypeError: can't pickle _thread.lock objects`
        return posts

    @cached_hybrid_property
    def abstract_content(self):
        """Return ``content`` passed through ``trunc_utf8(..., 100)``."""
        return trunc_utf8(self.content, 100)

    @cached_hybrid_property
    def author(self):
        """Return the result of ``User.get`` for ``author_id``."""
        return User.get(self.author_id)

    @classmethod
    def create_or_update(cls, **kwargs):
        """Create or update a post, then sync tags and schedule tasks.

        ``tags`` (optional list of tag names) is removed from ``kwargs``
        before delegating to the parent implementation. Returns the
        ``(created, obj)`` pair from the parent.
        """
        tags = kwargs.pop('tags', [])
        created, obj = super(Post, cls).create_or_update(**kwargs)
        if tags:
            PostTag.update_multi(obj.id, tags, [])

        # Dispatch Celery tasks
        if created:
            from handler.tasks import feed_post, reindex
            reindex.delay(obj.id, obj.kind, op_type='create')
            feed_post.delay(obj.id)
        return created, obj

    def delete(self):
        """Delete the post, its ``PostTag`` rows, and schedule feed removal."""
        # Capture identifiers first so nothing is read from the instance
        # after it has been deleted.
        post_id = self.id
        author_id = self.author_id
        super().delete()
        for pt in PostTag.query.filter_by(post_id=post_id):
            pt.delete()

        from handler.tasks import remove_post_from_feed
        remove_post_from_feed.delay(post_id, author_id)

    @cached_hybrid_property
    def netloc(self):
        """Return the network-location part of ``orig_url``."""
        return urlparse(self.orig_url).netloc

    @staticmethod
    def _flush_insert_event(mapper, connection, target):
        target._flush_event(mapper, connection, target)
        target.__flush_insert_event__(target)
class PostTag(BaseMixin, db.Model):
    """Association row linking a post to a tag (``post_tags`` table)."""

    __tablename__ = 'post_tags'
    post_id = db.Column(db.Integer)
    tag_id = db.Column(db.Integer)

    __table_args__ = (
        db.Index('idx_post_id', post_id, 'updated_at'),
        db.Index('idx_tag_id', tag_id, 'updated_at'),
    )

    @classmethod
    def _get_posts_by_tag(cls, identifier):
        """Build a query of posts carrying the given tag.

        ``identifier`` is a tag id, or a tag name when it is not numeric.
        Returns ``[]`` for an empty identifier, ``None`` when a named tag
        does not exist, otherwise a ``Post`` query ordered by id descending.
        """
        if not identifier:
            return []
        if not is_numeric(identifier):
            tag = Tag.get_by_name(identifier)
            if not tag:
                return None
            identifier = tag.id
        at_ids = cls.query.with_entities(cls.post_id).filter(
            cls.tag_id == identifier
        ).all()
        # Hand ``in_`` a concrete list rather than a generator.
        post_ids = [post_id for post_id, in at_ids]
        query = Post.query.filter(
            Post.id.in_(post_ids)).order_by(Post.id.desc())
        return query

    @classmethod
    @cache(MC_KEY_POSTS_BY_TAG % ('{identifier}', '{page}'))
    def get_posts_by_tag(cls, identifier, page, per):
        """Return one page of posts for the tag, or ``[]`` if there is none."""
        query = cls._get_posts_by_tag(identifier)
        if not query:
            return []
        posts = query.paginate(page, per)
        del posts.query  # Fix `TypeError: can't pickle _thread.lock objects`
        return posts

    @classmethod
    @cache(MC_KEY_POST_STATS_BY_TAG % ('{identifier}'))
    def get_count_by_tag(cls, identifier):
        """Return the number of posts for the tag (0 when it is unknown)."""
        query = cls._get_posts_by_tag(identifier)
        # _get_posts_by_tag may hand back None or [] instead of a query.
        if not query:
            return 0
        return query.count()

    @classmethod
    def update_multi(cls, post_id, tags, origin_tags=None):
        """Bring the post's tag links in line with ``tags``.

        Args:
            post_id: id of the post whose links are updated.
            tags: desired tag names.
            origin_tags: tags currently linked (names or ``Tag`` rows);
                when None they are read from ``Post.get(post_id).tags``.
        """
        if origin_tags is None:
            origin_tags = Post.get(post_id).tags
        # ``Post.tags`` yields Tag rows while ``tags`` holds names, so
        # normalise to names before diffing the two collections.
        origin_names = [getattr(tag, 'name', tag) for tag in origin_tags]

        need_add = {tag for tag in tags if tag not in origin_names}
        need_del = {tag for tag in origin_names if tag not in tags}

        need_add_tag_ids = set()
        need_del_tag_ids = set()
        for tag_name in need_add:
            _, tag = Tag.create(name=tag_name)
            need_add_tag_ids.add(tag.id)
        for tag_name in need_del:
            # Look the tag up instead of creating it: a tag that does not
            # exist has no link to remove. Tag.create stores names
            # lower-cased, hence the lower() here.
            tag = Tag.get_by_name(tag_name.lower())
            if tag:
                need_del_tag_ids.add(tag.id)

        if need_del_tag_ids:
            obj = cls.query.filter(cls.post_id == post_id,
                                   cls.tag_id.in_(need_del_tag_ids))
            obj.delete(synchronize_session='fetch')
        for tag_id in need_add_tag_ids:
            cls.create(post_id=post_id, tag_id=tag_id)
        db.session.commit()

    @staticmethod
    def _flush_insert_event(mapper, connection, target):
        super(PostTag, target)._flush_insert_event(mapper, connection, target)
        target.clear_mc(target, 1)

    @staticmethod
    def _flush_delete_event(mapper, connection, target):
        super(PostTag, target)._flush_delete_event(mapper, connection, target)
        target.clear_mc(target, -1)

    @staticmethod
    def _flush_after_update_event(mapper, connection, target):
        super(PostTag, target)._flush_after_update_event(
            mapper, connection, target)
        target.clear_mc(target, 1)

    @staticmethod
    def _flush_before_update_event(mapper, connection, target):
        super(PostTag, target)._flush_before_update_event(
            mapper, connection, target)
        target.clear_mc(target, -1)

    @staticmethod
    def clear_mc(target, amount):
        """Adjust the per-identifier counters and drop cached post pages.

        Runs once for the row's ``post_id`` and once for its tag's name:
        the counter key is changed by ``amount`` via ``incr_key`` and every
        cached page up to the resulting total is deleted.
        """
        post_id = target.post_id
        tag_name = Tag.get(target.tag_id).name
        for ident in (post_id, tag_name):
            total = incr_key(MC_KEY_POST_STATS_BY_TAG % ident, amount)
            # At least page 1 is always invalidated.
            pages = math.ceil((max(total, 0) or 1) / PER_PAGE)
            for p in range(1, pages + 1):
                rdb.delete(MC_KEY_POSTS_BY_TAG % (ident, p))
class User(db.Model, UserMixin, BaseMixin):
    """Account row (``users`` table) with profile, login-tracking and
    avatar fields, plus helpers for follow relations."""

    __tablename__ = 'users'
    bio = db.Column(db.String(128), default='')
    name = db.Column(db.String(128), default='')
    nickname = db.Column(db.String(128), default='')
    email = db.Column(db.String(191), default='')
    password = db.Column(db.String(191))
    website = db.Column(db.String(191), default='')
    github_url = db.Column(db.String(191), default='')
    last_login_at = db.Column(db.DateTime())
    current_login_at = db.Column(db.DateTime())
    last_login_ip = db.Column(db.String(100))
    current_login_ip = db.Column(db.String(100))
    login_count = db.Column(db.Integer)
    active = db.Column(db.Boolean())
    icon_color = db.Column(db.String(7))
    confirmed_at = db.Column(db.DateTime())
    company = db.Column(db.String(191), default='')
    avatar_id = db.Column(db.String(20), default='')
    roles = db.relationship('Role', secondary=roles_users,
                            backref=db.backref('users', lazy='dynamic'))
    # Per-instance memo of (follower_count, following_count); None = not loaded.
    _stats = None

    __table_args__ = (
        # db.Index('idx_name', name),
        db.Index('idx_email', email),
    )

    def url(self):
        """Return the site-relative profile URL ``/user/<id>``."""
        return '/user/{}'.format(self.id)

    @property
    def github_id(self):
        """Return the last path segment of ``github_url`` ('' if unset)."""
        # Tolerate a missing URL and a trailing slash, both of which would
        # otherwise fail or yield an empty segment.
        return (self.github_url or '').rstrip('/').split('/')[-1]

    @property
    def avatar_path(self):
        """Return the static path of the avatar image, or '' if none is set."""
        avatar_id = self.avatar_id
        return '' if not avatar_id else '/static/avatars/{}.png'.format(
            avatar_id)

    def update_avatar(self, avatar_id):
        """Store ``avatar_id`` on the user and save."""
        self.avatar_id = avatar_id
        self.save()

    def upload_avatar(self, img):
        """Save a new avatar image and point the user at it.

        Args:
            img: either an ``http...`` URL string to download, or an object
                with a ``save(filename)`` method.

        When a URL download does not answer with status 200, nothing is
        written and the current avatar is left unchanged.
        """
        avatar_id = generate_id()
        filename = os.path.join(UPLOAD_FOLDER, 'avatars',
                                '{}.png'.format(avatar_id))
        if isinstance(img, str) and img.startswith('http'):
            r = requests.get(img, stream=True)
            try:
                if r.status_code != 200:
                    # No file was written; do not record a dangling id.
                    return
                with open(filename, 'wb') as f:
                    for chunk in r.iter_content(1024):
                        f.write(chunk)
            finally:
                # Streamed responses hold their connection until closed.
                r.close()
        else:
            img.save(filename)
        self.update_avatar(avatar_id)

    def follow(self, from_id):
        """Make ``from_id`` follow this user; return whether a row was created."""
        ok, _ = Contact.create(to_id=self.id, from_id=from_id)
        if ok:
            self._stats = None  # force a reload of the counters
        return ok

    def unfollow(self, from_id):
        """Remove ``from_id``'s follow of this user; return True if it existed."""
        contact = Contact.get_follow_item(from_id, self.id)
        if contact:
            contact.delete()
            self._stats = None  # force a reload of the counters
            return True
        return False

    def is_followed_by(self, user_id):
        """Return True when ``user_id`` follows this user."""
        contact = Contact.get_follow_item(user_id, self.id)
        return bool(contact)

    @property
    def n_following(self):
        """Number of users this user follows."""
        return self._follow_stats[1]

    @property
    def n_followers(self):
        """Number of users following this user."""
        return self._follow_stats[0]

    @property
    def _follow_stats(self):
        """Return ``(follower_count, following_count)``, loading it once.

        Falls back to ``(0, 0)`` when no ``userFollowStats`` row exists.
        """
        if self._stats is None:
            stats = userFollowStats.get(self.id)
            if not stats:
                self._stats = 0, 0
            else:
                self._stats = stats.follower_count, stats.following_count
        return self._stats
class Role(db.Model, RoleMixin):
    """A named role that can be attached to users."""

    # Primary key; the roles_users association table references it
    # through ForeignKey('role.id').
    id = db.Column(db.Integer(), primary_key=True)
    name = db.Column(db.String(80), unique=True)
    description = db.Column(db.String(191))
from flask_security import SQLAlchemyUserDatastore, UserMixin, RoleMixin import requests from base.base_extend import db from base.utils import generate_id from base.mixin import BaseMixin from .contact import Contact, userFollowStats from base.utils import cached_hybrid_property from flask_app.utils import get_config UPLOAD_FOLDER = get_config().UPLOAD_FOLDER roles_users = db.Table( 'roles_users', db.Column('user_id', db.Integer(), db.ForeignKey('users.id')), db.Column('role_id', db.Integer(), db.ForeignKey('role.id'))) class BranSQLAlchemyUserDatastore(SQLAlchemyUserDatastore): def get_user_name(self, identifier): return self._get_user(identifier, 'name') def get_user_email(self, identifier): return self._get_user(identifier, 'email') def _get_user(self, identifier, attr): user_model_query = self.user_model.query if hasattr(self.user_model, 'roles'): from sqlalchemy.orm import joinedload user_model_query = user_model_query.options(joinedload('roles'))