예제 #1
0
class Account(BaseModel):
    """
    Identifiable information belonging to one user.

    Attributes:
        username (str): Public identifier (not declared unique).
        email (str): Private, unique identifier; may be used for verification.
        phone (str): Mobile number including the country code; unique, nullable,
            and may be used for verification.
        password (str): Account password; must be stored as an Argon hash.
        disabled (bool): When True the account is kept but cannot be used.
        verified (bool): When False the account is unusable until verified via
            two-step verification or another method.
        roles (ManyToManyRelation[Role]): Roles granted to this account, linked
            through the ``account_role`` table.
    """

    username = fields.CharField(max_length=32)
    email = fields.CharField(max_length=255, unique=True)
    phone = fields.CharField(max_length=14, unique=True, null=True)
    password = fields.CharField(max_length=255)
    disabled = fields.BooleanField(default=False)
    verified = fields.BooleanField(default=False)
    roles: fields.ManyToManyRelation["Role"] = fields.ManyToManyField(
        "models.Role", through="account_role"
    )

    def json(self) -> dict:
        """Return a serializable summary of the account (no phone or password)."""
        return dict(
            id=self.id,
            date_created=str(self.date_created),
            date_updated=str(self.date_updated),
            email=self.email,
            username=self.username,
            disabled=self.disabled,
            verified=self.verified,
        )

    def validate(self) -> None:
        """
        Raise if the account is not in a usable state.

        Checks run in order (deleted, unverified, disabled); the first failing
        check raises.

        Raises:
            DeletedError: the account is marked deleted.
            UnverifiedError: the account has not been verified.
            DisabledError: the account has been disabled.
        """
        # ``deleted`` is presumably provided by BaseModel — not declared here.
        if self.deleted:
            raise DeletedError("Account has been deleted.")
        if not self.verified:
            raise UnverifiedError()
        if self.disabled:
            raise DisabledError()

    @staticmethod
    async def _get_live_account(missing_message: str, **criteria):
        """Fetch the single non-deleted account matching ``criteria``.

        Raises NotFoundError(missing_message) when no such account exists.
        """
        try:
            return await Account.filter(**criteria, deleted=False).get()
        except DoesNotExist:
            raise NotFoundError(missing_message)

    @staticmethod
    async def get_via_email(email: str):
        """
        Look up a non-deleted account by email.

        Args:
            email (str): Email of the wanted account.

        Returns:
            account

        Raises:
            NotFoundError: no non-deleted account has this email.
        """
        return await Account._get_live_account(
            "Account with this email does not exist.", email=email)

    @staticmethod
    async def get_via_username(username: str):
        """
        Look up a non-deleted account by username.

        NOTE(review): ``username`` is not declared unique; if several live
        accounts share it, the underlying ``.get()`` will not pick one —
        confirm this is intended.

        Args:
            username (str): Username of the wanted account.

        Returns:
            account

        Raises:
            NotFoundError: no non-deleted account has this username.
        """
        return await Account._get_live_account(
            "Account with this username does not exist.", username=username)

    @staticmethod
    async def get_via_phone(phone: str):
        """
        Look up a non-deleted account by phone number.

        Args:
            phone (str): Phone number of the wanted account.

        Returns:
            account

        Raises:
            NotFoundError: no non-deleted account has this phone number.
        """
        return await Account._get_live_account(
            "Account with this phone number does not exist.", phone=phone)
예제 #2
0
class Recipe(Model):
    """A recipe: a uniquely named meal, its ingredients and optional notes."""

    id = fields.IntField(pk=True)
    meal_name = fields.CharField(max_length=100, unique=True)  # one recipe per meal name
    ingredients = fields.CharField(max_length=255)
    notes = fields.CharField(max_length=255, null=True)  # optional free text
예제 #3
0
class Section(models.Model):
    """A section with an optional name (up to 50 characters)."""

    id = fields.IntField(pk=True)
    name = fields.CharField(max_length=50, null=True)

    def __str__(self):
        # ``name`` must be read from the instance; it is nullable, and __str__
        # has to return a str, so a missing name becomes "".
        return self.name or ""
예제 #4
0
class UserProfile(Model):
    """
    Profile of a local or remote (ActivityPub) user.

    Local profiles build their URIs from the linked ``User``'s username via
    ``uri``; remote profiles (``is_remote=True``) build them from ``ap_id``.
    """

    id = fields.IntField(pk=True)
    ap_id = fields.CharField(256, unique=True, null=True)
    name = fields.CharField(64, unique=False, null=True)  # Display name
    disabled = fields.BooleanField(default=False)  # True if the user is disabled in the server
    is_remote = fields.BooleanField(default=False)  # The user is a remote user
    private_key = fields.TextField(null=True)  # Private key used to sign AP actions
    public_key = fields.TextField(null=True)  # Public key
    # NOTE(review): unique=True together with default="" allows only one row to
    # keep the default empty value for description/avatar_file; a second
    # profile created with the default would violate the constraint. Left
    # unchanged because fixing it alters the DB schema — confirm and migrate.
    description = fields.CharField(512, unique=True, default="")  # Description of the profile
    avatar_file = fields.CharField(512, unique=True, default="")
    following_count = fields.IntField(default=0)
    followers_count = fields.IntField(default=0)
    statuses_count = fields.IntField(default=0)
    user = fields.ForeignKeyField('models.User', on_delete='CASCADE')
    public_inbox = fields.CharField(256, null=True)

    # Self-referential many-to-many; the reverse accessor is ``following``.
    followers = fields.ManyToManyField('models.UserProfile', related_name="following")

    @property
    async def username(self):
        """Username of the linked ``User``; await the property."""
        await self.fetch_related('user')
        return self.user.username

    @property
    async def is_private(self):
        """``is_private`` flag of the linked ``User``; await the property."""
        await self.fetch_related('user')
        return self.user.is_private

    @property
    async def uris(self):
        """
        Build the ``URIs`` object for this profile; await the property.

        Remote profiles derive every endpoint from ``ap_id``; local profiles
        use the named ``uri`` routes keyed by username.
        """
        if self.is_remote:
            return URIs(
                id=self.ap_id,
                inbox=f'{self.ap_id}/inbox',
                outbox=f'{self.ap_id}/outbox',
                following=f'{self.ap_id}/following',
                followers=f'{self.ap_id}/followers',
            )

        # Resolve the username once instead of re-fetching the relation per URI.
        username = await self.username
        return URIs(
            id=uri("user", {"username": username}),
            following=uri("following", {"username": username}),
            followers=uri("followers", {"username": username}),
            outbox=uri("outbox", {"username": username}),
            inbox=uri("inbox", {"username": username}),
            atom=uri("atom", {"id": self.id}),
            featured=uri("featured", {"username": username}),
            avatar=uri('profile_image', {"name": self.avatar_file}),
            client=uri('user_client', {'username': username})
        )

    async def to_json(self):
        """
        Serialize the profile to an account dict for the client API.

        ``acct`` is ``ap_id`` for remote profiles and the username otherwise.
        """
        await self.fetch_related('user')
        username = await self.username
        data = {
            'id': self.id,
            'username': username,
            'name': self.name,
            'display_name': self.name,
            'locked': await self.is_private,
            'created_at': str(self.user.created_at),
            'followers_count': self.followers_count,
            'following_count': self.following_count,
            'statuses_count': self.statuses_count,
            'note': self.description,
            'url': None,
            'avatar': self.avatar,
            'moved': None,
            'fields': [],
            'bot': self.user.is_bot,
        }
        data['acct'] = self.ap_id if self.is_remote else username
        return data

    async def to_activitystream(self):
        """
        Serialize the profile as an ActivityStreams ``Person`` object.

        Collections, public key, summary and endpoints are only added for local
        profiles.
        """
        uris = await self.uris
        username = await self.username

        json_dict = {
            "@context": [
                "https://www.w3.org/ns/activitystreams",
                "https://w3id.org/security/v1",
            ],
            "type": "Person",
            "id": uris.id,
            "name": self.name,
            "preferredUsername": username,
        }

        if not self.is_remote:
            json_dict.update({
                "following": uris.following,
                "followers": uris.followers,
                "outbox": uris.outbox,
                "inbox": uris.inbox,
                "publicKey": {
                    # NOTE(review): uses the key returned by import_keys(), not
                    # this profile's ``public_key`` — confirm this is intended.
                    'publicKeyPem': import_keys()["actorKeys"]["publicKey"],
                    'id': f'{BASE_URL}/users/{username}#main-key',
                    'owner': f'{BASE_URL}/users/{username}'
                },
                "summary": self.description,
                "manuallyApprovesFollowers": await self.is_private,
                "featured": uris.featured,
                "endpoints": {
                    "sharedInbox": uri('sharedInbox')
                }
            })

        return json_dict

    def _create_avatar_id(self):
        """Hashid-encode ``self.id`` plus the current seconds since 1970-01-01 (naive local clock)."""
        hashid = Hashids(salt=salt_code, min_length=6)

        possible_id = self.id + int((datetime.datetime.now() - datetime.datetime(1970, 1, 1)).total_seconds())
        return hashid.encode(possible_id)

    def _create_avatar_file(self, image):
        """
        Save ``image`` as an RGB JPEG thumbnail (at most 400x400) in MEDIA_FOLDER/avatars.

        image - A byte array with the image

        Returns the generated file name (``<hashid>.jpeg``).
        """
        filename = self._create_avatar_id()
        with Image.open(io.BytesIO(image)) as source:
            im = source.convert('RGB')
        # LANCZOS is the filter ANTIALIAS was an alias of; ANTIALIAS was
        # removed in Pillow 10.
        im.thumbnail((400, 400), Image.LANCZOS)
        file_path = os.path.join(MEDIA_FOLDER, 'avatars', filename + '.jpeg')
        im.save(file_path, 'jpeg')

        return f'{filename}.jpeg'

    # Backward-compatible alias for the previous, misspelled helper name.
    _crate_avatar_file = _create_avatar_file

    def update_avatar(self, image):
        """Write a new avatar file and return its name (does not set ``avatar_file``)."""
        return self._create_avatar_file(image)

    @property
    def avatar(self):
        """URI of the profile image named by ``avatar_file``."""
        return uri("profile_image", {"name": self.avatar_file})

    async def is_following(self, user):
        """True when ``user`` is present exactly once in ``self.followers``."""
        return len(await self.followers.filter(pk=user.pk)) == 1

    async def follow(self, target):
        """
        The current user follows the target account (adds it to ``self.followers``).

        target: An instance of UserProfile
        """
        await self.followers.add(target)

    async def unfollow(self, target, valid=False):
        """
        The current user unfollows the target account: reverses ``follow`` by
        removing ``target`` from ``self.followers``.

        target: An instance of UserProfile
        valid: Unused; kept for backward compatibility.
        """
        await self.followers.remove(target)
예제 #5
0
class Message(Model):
    """A stored message: its text plus a JSON ``keyboard`` payload."""

    key = fields.IntField(pk=True)
    text = fields.CharField(max_length=255)
    # A callable default gives every new instance its own empty list instead of
    # one mutable list object shared by all of them.
    keyboard = fields.JSONField(default=list)
예제 #6
0
class Giveaway(Model):
    """
    A giveaway in a server, started either by a user or by the bot.

    In the queries below, a giveaway is "active" while ``end_at`` is set and no
    ``winning_user`` is recorded, and a bot giveaway is "pending" while
    ``end_at`` is still NULL.
    """

    started_by = fields.ForeignKeyField(
        'db.User', related_name='started_giveaways', index=True, null=True
    )
    started_by_bot = fields.BooleanField(default=False)
    # Amounts hold stringified Env.amount_to_raw(...) values.
    base_amount = fields.CharField(max_length=50, default='0')
    final_amount = fields.CharField(max_length=50, default='0', null=True)
    entry_fee = fields.CharField(max_length=50, default='0')
    end_at = fields.DatetimeField(null=True)
    ended_at = fields.DatetimeField(null=True)
    server_id = fields.BigIntField()
    started_in_channel = fields.BigIntField()
    winning_user = fields.ForeignKeyField(
        'db.User', related_name='won_giveaways', null=True
    )

    class Meta:
        table = 'giveaways'

    @staticmethod
    def _active_query(**criteria):
        """Query active giveaways matching ``criteria``, latest ``end_at`` first, with ``started_by`` prefetched."""
        candidates = Giveaway.filter(
            **criteria, end_at__not_isnull=True, winning_user=None
        )
        return candidates.prefetch_related('started_by').order_by('-end_at')

    @staticmethod
    async def get_active_giveaway(server_id: int) -> 'Giveaway':
        """Return the server's active giveaway (latest end time), or None."""
        return await Giveaway._active_query(server_id=server_id).first()

    @staticmethod
    async def get_active_giveaway_by_id(id: int) -> 'Giveaway':
        """Return the giveaway with this id if it is active, otherwise None."""
        return await Giveaway._active_query(id=id).first()

    @staticmethod
    async def get_active_giveaways(server_ids: List[int]) -> List['Giveaway']:
        """Return a list of all active giveaways across ``server_ids`` (possibly empty)."""
        return await Giveaway._active_query(server_id__in=server_ids)

    @staticmethod
    async def get_pending_bot_giveaway(server_id: int) -> 'Giveaway':
        """Return a bot giveaway in this server that has no end time yet, or None."""
        # NOTE(review): every match has a NULL end_at, so ordering by it does
        # not make the choice deterministic if several exist.
        pending = Giveaway.filter(
            server_id=server_id, end_at__isnull=True, started_by_bot=True
        )
        return await pending.order_by('-end_at').first()

    @staticmethod
    async def _ensure_no_active_giveaway(server_id: int) -> None:
        """Raise a plain Exception if the server already has an active giveaway."""
        if await Giveaway.get_active_giveaway(server_id) is not None:
            raise Exception("There's already an active giveaway")

    @staticmethod
    async def start_giveaway_user(server_id: int,
                                  started_by: usr.User,
                                  amount: float,
                                  entry_fee: float,
                                  duration: int,
                                  started_in_channel: int,
                                  conn=None) -> 'Giveaway':
        """
        Create and save a user-started giveaway ending ``duration`` minutes from now (UTC).

        Raises Exception if the server already has an active giveaway.
        """
        await Giveaway._ensure_no_active_giveaway(server_id)
        new_giveaway = Giveaway(
            started_by=started_by,
            base_amount=str(Env.amount_to_raw(amount)),
            entry_fee=str(Env.amount_to_raw(entry_fee)),
            end_at=datetime.datetime.utcnow() + datetime.timedelta(minutes=duration),
            server_id=server_id,
            started_in_channel=started_in_channel,
        )
        await new_giveaway.save(using_db=conn)
        return new_giveaway

    @staticmethod
    async def start_giveaway_bot(server_id: int,
                                 entry_fee: float,
                                 started_in_channel: int,
                                 conn=None) -> 'Giveaway':
        """
        Create and save a bot giveaway with a zero base amount and no end time (pending).

        Raises Exception if the server already has an active giveaway.
        """
        await Giveaway._ensure_no_active_giveaway(server_id)
        new_giveaway = Giveaway(
            started_by_bot=True,
            base_amount=str(Env.amount_to_raw(0)),
            entry_fee=str(Env.amount_to_raw(entry_fee)),
            server_id=server_id,
            started_in_channel=started_in_channel,
        )
        await new_giveaway.save(using_db=conn)
        return new_giveaway

    async def get_transactions(self):
        """Return all rows of the ``giveaway_transactions`` reverse relation (declared on another model)."""
        return await self.giveaway_transactions.all()
예제 #7
0
class UniqueName(Model):
    """Holds a single optional name of at most 20 characters, unique across rows."""

    name = fields.CharField(unique=True, null=True, max_length=20)
예제 #8
0
class Store(Model, ModelTimeMixin):
    """Store (shop)."""

    id = fields.IntField(pk=True)
    # Unique store name; the description text '店铺名称' means "store name".
    name = fields.CharField(max_length=64, unique=True, description='店铺名称')
    # Optional store introduction; '店铺简介' means "store introduction".
    desc = fields.CharField(max_length=255, null=True, description='店铺简介')
예제 #9
0
class Domain(Model):
    """A domain keyed by a string id and linked to a ``TLD`` row."""

    id = fields.CharField(pk=True, max_length=255)
    # Reverse accessor on the TLD side is ``domains``.
    tld = fields.ForeignKeyField('models.TLD', related_name='domains')
    owner = fields.CharField(max_length=36)
    token_id = fields.BigIntField(null=True)  # optional
예제 #10
0
 def test_max_length_missing(self):
     # CharField has no default length, so omitting it must fail at call time.
     expected = "missing 1 required positional argument: 'max_length'"
     with self.assertRaisesRegex(TypeError, expected):
         fields.CharField()
예제 #11
0
 def test_max_length_bad(self):
     # A zero length is rejected as a configuration error.
     expected = "'max_length' must be >= 1"
     with self.assertRaisesRegex(ConfigurationError, expected):
         fields.CharField(max_length=0)
예제 #12
0
class Post(CommentMixin, ReactMixin, BaseModel):
    """
    A blog post or standalone page.

    Columns hold the metadata; the Markdown body is stored under the
    ``content`` prop and read through the async ``content`` property.
    """

    STATUSES = (STATUS_UNPUBLISHED, STATUS_ONLINE) = range(2)

    TYPES = (TYPE_ARTICLE, TYPE_PAGE) = range(2)

    title = fields.CharField(max_length=100, unique=True)
    author_id = fields.IntField()
    slug = fields.CharField(max_length=100)
    summary = fields.CharField(max_length=255)
    can_comment = fields.BooleanField(default=True)
    status = fields.IntField(default=STATUS_UNPUBLISHED)
    type = fields.IntField(default=TYPE_ARTICLE)
    kind = K_POST

    class Meta:
        table = 'posts'

    @classmethod
    async def create(cls, **kwargs):
        """
        Create a post, then attach its tags and content.

        ``tags`` (optional list of tag names) and ``content`` (required,
        KeyError if absent) are popped from kwargs before the row is created.
        """
        tags = kwargs.pop('tags', [])
        content = kwargs.pop('content')
        obj = await super().create(**kwargs)
        if tags:
            await PostTag.update_multi(obj.id, tags)
        await obj.set_content(content)
        return obj

    async def update_tags(self, tagnames):
        """Replace this post's tags when ``tagnames`` is non-empty; always returns True."""
        if tagnames:
            await PostTag.update_multi(self.id, tagnames)
        return True

    @property
    @cache(MC_KEY_TAGS_BY_POST_ID % ('{self.id}'))
    async def tags(self):
        """Tags attached to this post via PostTag (cached); await the property."""
        pts = await PostTag.filter(post_id=self.id)
        if not pts:
            return []
        ids = [pt.tag_id for pt in pts]
        return await Tag.filter(id__in=ids).all()

    @property
    async def author(self):
        """Author looked up through ``User.cache``; await the property."""
        rv = await User.cache(self.author_id)
        return rv

    @property
    def preview_url(self):
        """Preview path ``/<classname>/<id>/preview``."""
        return f'/{self.__class__.__name__.lower()}/{self.id}/preview'

    async def set_content(self, content):
        """Store ``content`` (after pangu spacing) as the ``content`` prop."""
        return await self.set_props_by_key('content',
                                           pangu.spacing_text(content))

    async def save(self, *args, **kwargs):
        """Save the row; an optional ``content`` kwarg is stored as the prop first."""
        content = kwargs.pop('content', None)
        if content is not None:
            await self.set_content(content)
        return await super().save(*args, **kwargs)

    @property
    async def content(self):
        """Decoded UTF-8 body, or None when no content prop is stored."""
        rv = await self.get_props_by_key('content')
        if rv:
            return rv.decode('utf-8')

    @property
    async def html_content(self):
        """Body rendered with ``markdown``; '' when there is no content."""
        content = await self.content
        if not content:
            return ''
        return markdown(content)

    @property
    async def excerpt(self):
        """``summary`` if set, else the first 100 (UTF-8 truncated) chars of the stripped HTML."""
        if self.summary:
            return self.summary
        s = MLStripper()
        s.feed(await self.html_content)
        return trunc_utf8(
            BQ_REGEX.sub('', s.get_data()).replace('\n', ''), 100)  # noqa

    @cache(MC_KEY_RELATED % ('{self.id}'), ONE_HOUR)
    async def get_related(self, limit=4):
        """
        Return up to ``limit`` random posts sharing a tag with this one.

        Posts created more than 180 days ago or not online are excluded.
        """
        tag_ids = [tag.id for tag in await self.tags]
        if not tag_ids:
            return []
        post_ids = set(await PostTag.filter(Q(post_id__not=self.id),
                                            Q(tag_id__in=tag_ids)).values_list(
                                                'post_id', flat=True))

        excluded_ids = await self.filter(
            Q(created_at__lt=(datetime.now() - timedelta(days=180)))
            | Q(status__not=self.STATUS_ONLINE)).values_list('id', flat=True)

        post_ids -= set(excluded_ids)
        try:
            # random.sample() requires a sequence; passing a set raises
            # TypeError on Python 3.11+, which the except below would not catch.
            post_ids = random.sample(list(post_ids), limit)
        except ValueError:
            # Fewer candidates than ``limit``: keep all of them.
            pass
        return await self.get_multi(post_ids)

    async def clear_mc(self):
        """Invalidate every cache key that may contain this post, ignoring individual failures."""
        coros = [
            clear_mc(MC_KEY_RELATED % self.id),
            clear_mc(MC_KEY_POST_BY_SLUG % self.slug),
            clear_mc(MC_KEY_ARCHIVE % self.created_at.year)
        ]
        for key in [
                MC_KEY_FEED, MC_KEY_SITEMAP, MC_KEY_SEARCH, MC_KEY_ARCHIVES,
                MC_KEY_TAGS
        ]:
            coros.append(clear_mc(key))
        for i in [True, False]:
            coros.append(clear_mc(MC_KEY_ALL_POSTS % i))

        for tag in await self.tags:
            coros.append(clear_mc(MC_KEY_TAG % tag.id))
        await asyncio.gather(*coros, return_exceptions=True)

    @classmethod
    @cache(MC_KEY_POST_BY_SLUG % '{slug}')
    async def get_by_slug(cls, slug):
        """First post with this slug, or None (cached)."""
        return await cls.filter(slug=slug).first()

    @classmethod
    @cache(MC_KEY_ALL_POSTS % '{with_page}')
    async def get_all(cls, with_page=True):
        """Online posts, newest id first; pages are excluded when ``with_page`` is False."""
        if with_page:
            return await Post.sync_filter(status=Post.STATUS_ONLINE,
                                          orderings=['-id'],
                                          limit=None)
        return await Post.sync_filter(status=Post.STATUS_ONLINE,
                                      type__not=cls.TYPE_PAGE,
                                      orderings=['-id'],
                                      limit=None)

    @classmethod
    async def cache(cls, ident):
        """Look up by numeric id through the parent cache, otherwise by slug."""
        if str(ident).isdigit():
            return await super().cache(ident)
        return await cls.get_by_slug(ident)

    @property
    async def toc(self):
        """Rendered table of contents (levels up to 4); '' when there is no content."""
        content = await self.content
        if not content:
            return ''
        # ``toc``/``toc_md`` are shared renderer objects, so reset state first.
        toc.reset_toc()
        toc_md.parse(content)
        return toc.render_toc(level=4)

    @property
    def is_page(self):
        """True when this post is a standalone page."""
        return self.type == self.TYPE_PAGE

    @property
    def url(self):
        """``/page/<slug>`` for pages, otherwise the parent's URL."""
        return f'/page/{self.slug}' if self.is_page else super().url
예제 #13
0
class User(TortoiseBaseUserModel):
    """User model extending TortoiseBaseUserModel with an optional first name."""

    first_name = fields.CharField(max_length=255, null=True)
예제 #14
0
파일: chat.py 프로젝트: alekssamos/KarmaBot
class Chat(Model):
    """A Telegram chat, plus helpers for the chat's karma rankings."""

    chat_id = fields.BigIntField(pk=True, generated=False)
    type_: ChatType = typing.cast(ChatType, fields.CharEnumField(ChatType))
    title = fields.CharField(max_length=255, null=True)
    username = fields.CharField(max_length=32, null=True)
    description = fields.CharField(max_length=255, null=True)
    # noinspection PyUnresolvedReferences
    user_karma: fields.ReverseRelation['UserKarma']

    class Meta:
        table = "chats"

    @classmethod
    async def create_from_tg_chat(cls, chat):
        """Create a row from a Telegram chat object (its description is not copied)."""
        chat = await cls.create(chat_id=chat.id,
                                type_=chat.type,
                                title=chat.title,
                                username=chat.username)
        return chat

    @classmethod
    async def get_or_create_from_tg_chat(cls, chat):
        """Return the stored chat with ``chat.id``, creating it when absent."""
        try:
            chat = await cls.get(chat_id=chat.id)
        except DoesNotExist:
            chat = await cls.create_from_tg_chat(chat=chat)
        return chat

    @property
    def mention(self):
        """HTML link to t.me/<username> when a username exists, else the escaped title."""
        return hlink(self.title,
                     f"t.me/{self.username}") if self.username else quote_html(
                         self.title)

    def __str__(self):
        rez = f"Chat with type: {self.type_} with ID {self.chat_id}, title: {self.title}"
        if self.username:
            rez += f" Username @{self.username}"
        if self.description:
            rez += f". description: {self.description}"
        return rez

    def __repr__(self):
        return str(self)

    # noinspection PyUnresolvedReferences
    async def get_top_karma_list(self, limit: int = 15):
        """
        Return up to ``limit`` ``(user, karma_round)`` pairs, ordered by ``karma_filters``.
        """
        # Query the reverse relation directly, as get_neighbours does; fetching
        # the whole relation first only loaded every karma row of the chat.
        users_karmas = await self.user_karma.order_by(
            *karma_filters).limit(limit).prefetch_related("user").all()
        return [(user_karma.user, user_karma.karma_round)
                for user_karma in users_karmas]

    # noinspection PyUnresolvedReferences
    async def get_neighbours(
            self, user) -> typing.Tuple["UserKarma", "UserKarma", "UserKarma"]:
        """
        Return (previous, user's own, next) karma records around ``user``.

        NOTE(review): assumes both neighbour ids resolve to rows; if fewer than
        two are found, ``uk[1]`` raises IndexError — confirm how edges behave.
        """
        prev_id, next_id = await get_neighbours_id(self.chat_id, user.id)
        uk = await self.user_karma.filter(user_id__in=(prev_id, next_id)
                                          ).prefetch_related("user").order_by(
                                              *karma_filters).all()

        user_uk = await self.user_karma.filter(
            user=user).prefetch_related("user").first()
        return uk[0], user_uk, uk[1]
예제 #15
0
class AbstractModel(Model, OneMixin):
    """Abstract base adding ``new_field`` to the fields contributed by OneMixin."""

    new_field = fields.CharField(max_length=100)

    class Meta:
        # Only meant to be inherited from, not used as a concrete model.
        abstract = True
예제 #16
0
class Record(Model):
    """A record keyed by a string id and attached to a ``Domain``."""

    id = fields.CharField(pk=True, max_length=255)
    # Reverse accessor on the Domain side is ``records``.
    domain = fields.ForeignKeyField('models.Domain', related_name='records')
    address = fields.CharField(max_length=36, null=True)  # optional
예제 #17
0
class Category(Model):
    """A category with a slug, a display name and its creation time."""

    slug = fields.CharField(max_length=200)
    name = fields.CharField(max_length=200)
    # Set automatically when the row is first created.
    created_at = fields.DatetimeField(auto_now_add=True)
예제 #18
0
class TLD(Model):
    """TLD entry keyed by a string id, with an owner string."""

    id = fields.CharField(pk=True, max_length=255)
    owner = fields.CharField(max_length=36)
예제 #19
0
class User(Model):
    """User with a unique email and nickname, a password and a creation time."""

    id = fields.IntField(pk=True)
    email = fields.CharField(max_length=64, index=True, unique=True)
    password = fields.CharField(max_length=255)  # presumably a hash — confirm with callers
    nickname = fields.CharField(max_length=16, index=True, unique=True)
    created_at = fields.DatetimeField(auto_now_add=True)  # set on first save
예제 #20
0
class Expiry(Model):
    """An optional expiry timestamp keyed by a string id."""

    id = fields.CharField(pk=True, max_length=255)
    expiry = fields.DatetimeField(null=True)  # NULL when no expiry is set
예제 #21
0
class Todo(BaseModel):
    """A to-do item owned by a ``User``."""

    # NOTE(review): related_name="users" makes the reverse accessor on User
    # ``users`` even though it yields todos; renaming it would break callers.
    user = fields.ForeignKeyField("models.User", related_name="users")
    name = fields.CharField(max_length=100)
    memo = fields.TextField(null=True)
    completed = fields.BooleanField(default=False, null=True)
    completed_at = fields.DatetimeField(null=True)  # optional completion time
예제 #22
0
class Users(Model):
    """A user with an integer primary key and a name of up to 50 characters."""

    # Assuming ``fields`` is tortoise.fields (as everywhere else in this file):
    # it has no IntegerField, and the integer primary key is IntField(pk=True).
    id = fields.IntField(pk=True)
    name = fields.CharField(50)

    def __str__(self):
        return f"User {self.id}: {self.name}"
예제 #23
0
class Config(Model):
    """A key/value configuration entry."""

    id = fields.IntField(pk=True)
    key = fields.CharField(max_length=255)  # not declared unique
    value = fields.CharField(max_length=255)
예제 #24
0
class Users(Model):
    """A user with an integer primary key and a name of up to 50 characters."""

    id = fields.IntField(pk=True)
    name = fields.CharField(max_length=50)

    def __str__(self):
        return f"User {self.id}: {self.name}"
예제 #25
0
class ShoppingListItem(Model):
    """One shopping-list entry; item names are unique."""

    id = fields.IntField(pk=True)
    item_name = fields.CharField(max_length=100, unique=True)
예제 #26
0
class OneMixin(ZeroMixin):
    """Mixin adding an optional ``one`` CharField (max 40) on top of ZeroMixin."""

    one = fields.CharField(max_length=40, null=True)
예제 #27
0
class User(Model, TimestampMixin):
    """User specific information"""
    id = fields.UUIDField(pk=True)
    first_name: str = fields.CharField(max_length=50)
    last_name: str = fields.CharField(max_length=50)
    phone_number: str = fields.CharField(max_length=50, null=True)

    class Meta:
        table = 'user'

    def __str__(self):
        return f"{self.first_name} {self.last_name}"

    async def create_login(self,
                           email: str,
                           password: str,
                           email_confirmed: bool = False):
        """
        Create a Login for this user unless some Login already uses ``email``.

        Returns:
            The saved Login, or False when the email is already taken.
        """
        if await Login.filter(email=email).first() is not None:
            return False
        login = Login(user_id=self.id,
                      email=email,
                      hash_pass=hash_password(password),
                      email_confirmed=email_confirmed)
        await login.save()
        return login

    async def get_login(self):
        """Returns the login information for the associated user, returns None if no login info is stored"""
        # NOTE(review): create_login writes ``user_id`` but this filters on
        # ``login_user_id``; one of them is likely wrong — verify Login's schema.
        return await Login.filter(login_user_id=self.id).first()

    async def create_subscription(self, subscriber_id: str,
                                  subscriber_email: str, cost: str,
                                  currency: str, period: int):
        """
        Create a subscription with this user as merchant and a fresh payment address.

        The initial expiration date is one day from now (naive local time).

        Returns:
            dict with ``subscription_id``, ``payment_address`` and ``expiration_date``.
        """
        payment_address = await create_account()
        expiration_date = datetime.now() + timedelta(days=1)
        subscription = Subscriptions(merchant_id=self.id,
                                     subscriber_id=subscriber_id,
                                     subscriber_email=subscriber_email,
                                     payment_address=payment_address,
                                     cost=cost,
                                     currency=currency,
                                     period=period,
                                     expiration_date=expiration_date)
        await subscription.save()
        return {
            'subscription_id': subscription.subscription_id,
            'payment_address': payment_address,
            'expiration_date': expiration_date
        }

    async def get_subscriptions(self):
        """Returns the subscription information for the associated user (possibly an empty list)"""
        # NOTE(review): create_subscription writes ``merchant_id`` but this
        # filters on ``subscription_user_id`` — verify Subscriptions' schema.
        return await Subscriptions.filter(subscription_user_id=self.id).all()

    async def get_forwarding_address(self):
        """Returns the address the user has set to forward payments to, returns None if no address is set"""
        # NOTE(review): set_forwarding_address uses ``user_id`` but this uses
        # ``forwarding_user_id`` — verify ForwardingAddress' schema.
        return await ForwardingAddress.filter(forwarding_user_id=self.id
                                              ).first()

    async def set_forwarding_address(self, new_address: str = None):
        """
        Set or update the forwarding address for a user.

        Returns:
            The created/updated ForwardingAddress, or None when the address is
            missing, fails the checksum, or equals the stored one.
        """
        if new_address is None or not validate_checksum_xrb(new_address):
            return None
        old_address = await ForwardingAddress.filter(user_id=self.id).first()
        if old_address is None:
            forwarding_address = ForwardingAddress(user_id=self.id,
                                                   address=new_address)
            await forwarding_address.save()
            return forwarding_address
        if old_address.address == new_address:
            return None
        old_address.address = new_address
        await old_address.save(update_fields=['address'])
        return old_address

    async def set_name(self, first_name: str, last_name: str):
        """
        Update the user's first name and last name.

        Inputs are stripped and title-cased *before* being compared with the
        stored values, so only genuine changes are written.

        Returns:
            self when at least one name changed and was saved, otherwise None.
        """
        new_first = first_name.strip().title()
        new_last = last_name.strip().title()
        update_fields = []
        if self.first_name != new_first:
            self.first_name = new_first
            update_fields.append('first_name')
        if self.last_name != new_last:
            self.last_name = new_last
            update_fields.append('last_name')
        if not update_fields:
            return None
        await self.save(update_fields=update_fields)
        return self

    async def set_phone_number(self, phone_number: str):
        """
        Update a user's phone number if it contains a recognisable number.

        The check is a substring search: optional country code and
        (parenthesised) area code, then 3 + 4 digits separated by whitespace,
        "." or "-".

        Returns:
            self after saving, or None when no phone number is found.
        """
        # The hyphen must sit last inside the character classes: ``re``
        # rejects "[\s-\.]" as a bad character range, so the previous pattern
        # raised re.error on every call. Raw string avoids invalid escapes.
        regex = r'(?:\+?(\d{1})?-?\(?(\d{3})\)?[\s.-]?)?(\d{3})[\s.-]?(\d{4})[\s.-]?'
        if not re.search(regex, phone_number):
            return None
        self.phone_number = phone_number
        await self.save(update_fields=['phone_number'])
        return self
예제 #28
0
class TwoMixin:
    """Plain mixin declaring a ``two`` CharField of at most 40 characters."""

    two = fields.CharField(max_length=40)
예제 #29
0
class Cours(models.Model):
    """A course: unique name, optional hours, and its class and teacher."""

    id = fields.IntField(pk=True)
    name = fields.CharField(max_length=20, unique=True)
    hours = fields.IntField(null=True)
    # Both relations are mandatory (null=False).
    classe = fields.ForeignKeyField('models.Class', null=False)
    teacher = fields.ForeignKeyField('models.User', null=False)
예제 #30
0
class Session(BaseModel):
    """
    Used for client identification and verification. Base session model that all session models derive from.

    Attributes:
        expiration_date (datetime): Time the session expires and can no longer be used.
        active (bool): Determines if the session can be used.
        ip (str): IP address of client creating session.
        token (uuid): Token stored on the client's browser in a cookie for identification.
        bearer (ForeignKeyRelation[Account]): Account associated with this session.
        ctx (SimpleNamespace): Store whatever additional information you need about the session. Fields stored will be encoded.
    """

    expiration_date = fields.DatetimeField(null=True)
    active = fields.BooleanField(default=True)
    # NOTE(review): 16 characters fits dotted-quad IPv4 only; textual IPv6
    # addresses can be up to 39 characters (45 with an embedded IPv4 suffix).
    # Widening requires a schema migration, so it is left unchanged here.
    ip = fields.CharField(max_length=16)
    token = fields.UUIDField(unique=True, default=uuid.uuid4, max_length=36)
    bearer: fields.ForeignKeyRelation["Account"] = fields.ForeignKeyField(
        "models.Account", null=True)

    @property
    def ctx(self) -> SimpleNamespace:
        """
        Per-session namespace whose attributes are merged into the JWT payload
        by ``encode``.

        Created lazily on first access so each instance owns its namespace no
        matter how the instance was constructed. A class-level
        ``SimpleNamespace()`` would be a single object shared by every
        session, leaking data stored on one session into the cookies of all
        others.
        """
        try:
            return self._ctx
        except AttributeError:
            self._ctx = SimpleNamespace()
            return self._ctx

    @ctx.setter
    def ctx(self, value: SimpleNamespace) -> None:
        # Allows ``session.ctx = ...`` and ``Session(ctx=...)`` to keep working.
        self._ctx = value

    @classmethod
    def _cookie_name(cls) -> str:
        """Name of the cookie carrying this session type's JWT (used by encode and decode_raw)."""
        return f"{security_config.SESSION_PREFIX}_{cls.__name__.lower()[:4]}_session"

    def json(self) -> dict:
        """Serialize public session information into a dict."""
        return {
            "id": self.id,
            "date_created": str(self.date_created),
            "date_updated": str(self.date_updated),
            "expiration_date": str(self.expiration_date),
            "bearer":
            self.bearer.email if isinstance(self.bearer, Account) else None,
            "active": self.active,
        }

    def validate(self) -> None:
        """
        Raises an error with respect to session state.

        Raises:
            DeletedError
            ExpiredError
            DeactivatedError
        """
        if self.deleted:
            raise DeletedError("Session has been deleted.")
        elif (self.expiration_date and datetime.datetime.now(
                datetime.timezone.utc) >= self.expiration_date):
            raise ExpiredError()
        elif not self.active:
            raise DeactivatedError()

    async def check_client_location(self, request) -> None:
        """
        Checks if client ip address has been used previously within other sessions.

        Args:
            request: Sanic request the client ip address is read from.

        Raises:
            UnrecognisedLocationError
        """
        ip = get_ip(request)
        if not await self.filter(ip=ip, bearer=self.bearer,
                                 deleted=False).exists():
            # bearer is nullable, so don't let logging raise AttributeError
            # and mask the intended UnrecognisedLocationError.
            email = getattr(self.bearer, "email", None)
            logger.warning(
                f"Client ({email}/{ip}) ip address is unrecognised"
            )
            raise UnrecognisedLocationError()

    def encode(self, response: HTTPResponse):
        """
        Transforms session into jwt and then is stored in a cookie.

        Args:
            response (HTTPResponse): Sanic response used to store JWT into a cookie on the client.
        """
        payload = {
            "id": self.id,
            "date_created": str(self.date_created),
            "date_updated": str(self.date_updated),
            "expiration_date": str(self.expiration_date),
            "token": str(self.token),
            "ip": self.ip,
            # NOTE(review): ctx entries are spread last, so a ctx attribute
            # named like one of the keys above (e.g. "token") overrides it.
            **self.ctx.__dict__,
        }
        cookie = self._cookie_name()
        encoded_session = jwt.encode(
            payload, security_config.SECRET,
            security_config.SESSION_ENCODING_ALGORITHM)
        # jwt.encode may return bytes (older PyJWT) or str; cookies need text.
        if isinstance(encoded_session, bytes):
            encoded_session = encoded_session.decode()
        response.cookies[cookie] = encoded_session
        response.cookies[cookie]["httponly"] = security_config.SESSION_HTTPONLY
        response.cookies[cookie]["samesite"] = security_config.SESSION_SAMESITE
        response.cookies[cookie]["secure"] = security_config.SESSION_SECURE
        if security_config.SESSION_EXPIRES_ON_CLIENT and self.expiration_date:
            response.cookies[cookie]["expires"] = self.expiration_date
        if security_config.SESSION_DOMAIN:
            response.cookies[cookie]["domain"] = security_config.SESSION_DOMAIN

    @classmethod
    def decode_raw(cls, request: Request) -> dict:
        """
        Decodes JWT token from client cookie into a python dict.

        Args:
            request (Request): Sanic request parameter.

        Returns:
            session_dict

        Raises:
            JWTDecodeError
        """
        cookie = request.cookies.get(cls._cookie_name())
        if not cookie:
            raise JWTDecodeError("Session token not provided.")
        # PUBLIC_SECRET (when configured) takes precedence over SECRET.
        key = security_config.PUBLIC_SECRET or security_config.SECRET
        try:
            # ``algorithms`` must be passed by keyword as a list: positionally,
            # PyJWT 1.x binds the third argument to ``verify`` (leaving every
            # algorithm accepted), and 2.x would do a substring check on a str.
            return jwt.decode(
                cookie,
                key,
                algorithms=[security_config.SESSION_ENCODING_ALGORITHM],
            )
        except (DecodeError, jwt.InvalidTokenError) as e:
            # InvalidTokenError also covers rejected algorithms, bad claims,
            # etc., which DecodeError alone would let escape as server errors.
            raise JWTDecodeError(str(e)) from e

    @classmethod
    async def decode(cls, request: Request):
        """
        Decodes session JWT from client cookie to a Sanic Security session.

        Args:
            request (Request): Sanic request parameter.

        Returns:
            session

        Raises:
            JWTDecodeError
            NotFoundError
        """
        decoded_raw = cls.decode_raw(request)
        try:
            decoded_session = (await
                               cls.filter(token=decoded_raw["token"]
                                          ).prefetch_related("bearer").get())
        except DoesNotExist:
            raise NotFoundError("Session could not be found.")
        return decoded_session

    class Meta:
        abstract = True