示例#1
0
class RequestSchema:
    """webargs argument schemas for the bin endpoints (create, fetch, likes, views)."""

    # validate.Regexp checks values with re.match, which only anchors the
    # start of the string. Without an explicit end anchor the {m,n} upper
    # bound and the character whitelist are not enforced (any value that merely
    # starts with enough allowed characters would pass), so the patterns are
    # anchored at both ends.
    _NAME_RE = r"^[0-9a-zA-Z .@_-]{2,60}\Z"
    _TITLE_RE = r"^[0-9a-zA-Z .@_-]{1,60}\Z"

    # BIN creation v2
    BinV2POST = {
        "creator": fields.Str(
            required=True, validate=validate.Regexp(_NAME_RE)
        ),
        "title": fields.Str(
            required=True, validate=validate.Regexp(_TITLE_RE)
        ),
        "data": fields.Str(required=True),
        "private": fields.Bool(required=True),
        "language": fields.Str(required=True)
    }
    # Fetch a bin by slug
    BinV2GET = {
        "slug": fields.Str(
            required=True, validate=validate.Regexp(_NAME_RE)
        ),
    }
    # Likes of a bin, addressed by slug
    BinLikes = {
        "slug": fields.Str(
            required=True, validate=validate.Regexp(_NAME_RE)
        )
    }
    # Views of a bin, addressed by slug
    BinViews = {
        "slug": fields.Str(
            required=True, validate=validate.Regexp(_NAME_RE)
        )
    }
示例#2
0
class SuperUserPhoneView(SuperBaseView):
    """Super admin - user - change the user's phone number."""
    @use_args(
        {
            "sign": fields.String(required=True, comment="加密认证"),
            "timestamp": fields.Integer(required=True, comment="时间戳"),
            "user_id": fields.Integer(required=True, comment="用户ID"),
            "phone": fields.String(
                required=True,
                validate=[validate.Regexp(PHONE_RE)],
                comment="手机号",
            ),
            "sms_code": fields.String(required=True, comment="短信验证码"),
        },
        location="json")
    @SuperBaseView.validate_sign("sign", ("user_id", "timestamp"))
    def put(self, request, args):
        """Replace the current user's phone number after checking the SMS code.

        Responds with a 401 error when no current user can be resolved, with a
        failure response when the SMS code check or the update reports an
        error, and with a success response otherwise.
        """
        current_user = self._get_current_user(request)
        if not current_user:
            return self.send_error(
                status_code=status.HTTP_401_UNAUTHORIZED,
                error_message={"error_text": "用户未登录"},
            )

        new_phone = args["phone"]

        # Verify the SMS code that was sent to the new number.
        code_ok, code_info = validate_sms_code(new_phone, args["sms_code"])
        if not code_ok:
            return self.send_fail(error_text=code_info)

        updated, update_info = update_user_phone(current_user, new_phone)
        if not updated:
            return self.send_fail(error_text=update_info)

        return self.send_success()
示例#3
0
class MallUserRegisterView(MallBaseView):
    """Mall - user - registration."""
    # The parent class performs login authentication; override it here.
    authentication_classes = ()

    @use_args(
        {
            "phone":
            fields.String(required=True,
                          validate=[validate.Regexp(PHONE_RE)],
                          comment="手机号"),
            "sms_code":
            fields.String(required=True, comment="短信验证码"),
            "password1":
            fields.String(required=True, comment="密码1"),
            "password2":
            fields.String(required=True, comment="密码2"),
        },
        location="json",
    )
    def post(self, request, args, shop_code):
        """Register a user by phone number and log them in.

        Checks that both passwords match and that the submitted SMS code equals
        the one stored in Redis under ``sms_<phone>``, creates the user through
        ``UserCreateSerializer``, ensures a customer record exists for the
        current shop, and returns the JWT payload.
        """
        self._set_current_shop(request, shop_code)
        shop = self.current_shop
        phone = args.get("phone")
        sms_code = args.get("sms_code")
        # Make sure both password entries are identical.
        if args.get("password1") != args.get("password2"):
            return self.send_fail(error_text="两次输入的密码不一致")
        # Check the SMS verification code.
        redis_conn = get_redis_connection("verify_codes")
        real_sms_code = redis_conn.get("sms_%s" % phone)
        if not real_sms_code:
            return self.send_fail(error_text="验证码已过期")
        if str(real_sms_code.decode()) != sms_code:
            return self.send_error(status_code=status.HTTP_400_BAD_REQUEST,
                                   error_message={"detail": "短信验证码错误"})
        data = {
            "phone": phone,
            "username": phone,
            "nickname": "用户{phone}".format(phone=phone),
            "head_image_url":
            "http://img.senguo.cc/FlMKOOnlycuoZp1rR39LyCFUHUgl",
            "password": args.get("password1")
        }
        serializer = UserCreateSerializer(data=data)
        # The validation result used to be discarded, so invalid data went
        # straight on to save(). Raise the validation error instead so the
        # request is rejected with the serializer's error details.
        # NOTE(review): assumes UserCreateSerializer is a DRF serializer
        # (is_valid/save API) - confirm.
        serializer.is_valid(raise_exception=True)
        user = serializer.save()
        customer = get_customer_by_user_id_and_shop_id_interface(
            user.id, shop.id)
        # Create the customer record for a new customer.
        if not customer:
            create_customer(user.id, shop.id)
        token, refresh_token = self._set_current_user(user)
        response_data = jwt_response_payload_handler(token, refresh_token,
                                                     user, request)
        return self.send_success(data=response_data)
示例#4
0
文件: mods.py 项目: Sayo-nika/Backend
def b64img_field(name: str):
    """Build a required string field validated against ``DATA_URI_RE``.

    Args:
        name: Field name, interpolated into the validation error message.

    Returns:
        A ``fields.Str`` whose error message asks for a PNG or JPEG base64
        data URI.
    """
    message = (
        f"`{name}` should be a data uri like 'data:image/png;base64,<data>' or "
        "'data:image/jpeg;base64,<data>'"
    )
    uri_validator = validate.Regexp(DATA_URI_RE, error=message)
    return fields.Str(validate=uri_validator, required=True)
示例#5
0
class SuperUserPasswordView(SuperBaseView):
    """Super admin - user - change password."""
    @use_args({
        "sign": fields.String(required=True, comment="加密认证"),
        "timestamp": fields.Integer(required=True, comment="时间戳"),
        "user_id": fields.Integer(required=True, comment="用户ID"),
        "phone": fields.String(
            required=True,
            validate=[validate.Regexp(PHONE_RE)],
            comment="手机号",
        ),
        "sms_code": fields.String(required=True, comment="短信验证码"),
        "password1": DecryptPassword(
            required=True,
            validate=[validate.Regexp(PASSWORD_RE)],
            comment="密码",
        ),
        "password2": DecryptPassword(
            required=True,
            validate=[validate.Regexp(PASSWORD_RE)],
            comment="重复密码",
        ),
    })
    @SuperBaseView.validate_sign("sign", ("user_id", "timestamp"))
    def put(self, request, args):
        """Change the current user's password and issue fresh tokens.

        Responds with a 401 error when no current user can be resolved and
        with a failure response when the SMS code check or the password update
        reports an error; otherwise returns the JWT response payload.
        """
        current_user = self._get_current_user(request)
        if not current_user:
            return self.send_error(
                status_code=status.HTTP_401_UNAUTHORIZED,
                error_message={"error_text": "用户未登录"},
            )

        # Verify the SMS code sent to the given phone number.
        code_ok, code_info = validate_sms_code(args["phone"], args["sms_code"])
        if not code_ok:
            return self.send_fail(error_text=code_info)

        changed, change_info = update_user_password(
            current_user, args["password1"], args["password2"])
        if not changed:
            return self.send_fail(error_text=change_info)

        token, refresh_token = self._set_current_user(current_user)
        payload = jwt_response_payload_handler(
            token, refresh_token, current_user, request)
        return self.send_success(data=payload)
示例#6
0
    def __init__(self, urlmap=None):
        """Set up the default state and the base webargs arguments.

        Args:
            urlmap: Mapping stored on the instance as ``self.urlmap``. When
                omitted, each instance gets its own new empty dict.
        """
        self.release = None
        self.endpoint = None
        self.dapver = None
        # A literal ``{}`` default is created once and shared by every
        # instance built without an argument, so a mutation through one
        # instance would leak into the others.
        self.urlmap = {} if urlmap is None else urlmap
        # Every request must carry a release string matching 'MPL-<digit 1-9>'.
        self.base_args = {'release': fields.String(required=True,
                          validate=validate.Regexp('MPL-[1-9]'))}
        self.use_params = None
        self._required = None
        self._setmissing = None
        self._main_kwargs = {}
        # final_args starts as a copy of the base arguments.
        self.final_args = {}
        self.final_args.update(self.base_args)

        self._parser = parser
        self.use_args = use_args
        self.use_kwargs = use_kwargs
class EnvironmentResource(Resource):
    """Resource for a single environment: fetch, delete and rename."""

    @staticmethod
    def get(environment_id):
        """Return the environment with the given id.

        Raises:
            NotFound: if no environment has that id.
        """
        with make_session() as session:
            data = session.query(Environment).filter(
                Environment.id == environment_id).first()  # type: Environment
            if data is None:
                raise NotFound("Requested environment does not exist")

            # ``first()`` yields a single Environment (delete/patch below use
            # it directly), so it is marshalled as-is rather than indexed.
            return marshal(data, environment_fields)

    @staticmethod
    def delete(environment_id):
        """Delete the environment with the given id.

        Raises:
            NotFound: if no environment has that id.
            Conflict: if the environment reports that it is in use.
        """
        with make_session() as session:
            data = session.query(Environment).filter(
                Environment.id == environment_id).first()  # type: Environment
            if data is None:
                raise NotFound("Requested environment does not exist")
            if data.in_use():
                raise Conflict("Requested environment is in use")

            session.delete(data)
            return make_empty_response()

    patch_args = {
        'name':
        fields.String(required=True,
                      validate=(validate.Length(min=1, max=255),
                                # Regexp uses re.match, so the pattern is
                                # anchored at the end as well; unanchored it
                                # would accept any name not *starting* with '/'.
                                validate.Regexp(r'^[^/]+\Z')),
                      trim=True)
    }

    @staticmethod
    @use_kwargs(patch_args)
    def patch(environment_id, name):
        """Rename the environment with the given id.

        Raises:
            NotFound: if no environment has that id.
        """
        with make_session() as session:
            data = session.query(Environment).filter(
                Environment.id == environment_id).first()  # type: Environment
            if data is None:
                raise NotFound("Requested environment does not exist")

            data.name = name

            return make_empty_response()
示例#8
0
class SMSCodeView(GlobalBaseView):
    """User - endpoint that sends an SMS verification code."""
    @use_args(
        {
            "phone": fields.String(
                required=True,
                validate=[validate.Regexp(PHONE_RE)],
                comment="手机号",
            )
        },
        location="json")
    def post(self, request, args):
        """Generate an SMS code for the phone and store it in Redis.

        A request is refused when the same phone/IP pair already asked for a
        code within the last 60 seconds; the stored code expires after 300
        seconds.
        """
        phone = args["phone"]

        # Prefer the forwarded address header; otherwise use the peer address.
        meta = self.request.META
        if meta.get('HTTP_X_FORWARDED_FOR'):
            remote_ip = request.META.get("HTTP_X_FORWARDED_FOR")
        else:
            remote_ip = meta.get("REMOTE_ADDR")

        throttle_key = "bind_phone_ip:%s:%s" % (phone, remote_ip)
        redis_conn = get_redis_connection("verify_codes")
        if redis_conn.get(throttle_key):
            return self.send_fail(error_text="一分钟只能发生一次")

        sms_code = gen_sms_code()
        print("sms_code: ", sms_code)  # for testing

        # Persist the data before the SMS would be sent, so that repeated
        # requests are throttled and registration can verify the code.
        code_key = "sms_%s" % (phone)
        pipeline = redis_conn.pipeline()
        pipeline.setex(code_key, 300, sms_code)  # the code expires after 300 seconds
        pipeline.setex(throttle_key, 60, 1)  # one code per 60 seconds
        pipeline.execute()

        # TODO: delivery through the third-party SMS providers is currently
        # disabled. The disabled flow called
        # TencentSms.send_tencent_verify_code(phone, sms_code, use) first, fell
        # back to YunPianSms.send_yunpian_verify_code(phone, sms_code, use)
        # when that failed, and answered with send_fail when both failed or an
        # exception was raised.
        return self.send_success()
示例#9
0
class PackList(Resource):
    """Collection resource for packs: list all packs and create a new one."""

    @staticmethod
    def get():
        """Return every pack, marshalled with ``pack_fields``."""
        with make_session() as session:
            return marshal(session.query(Pack).all(), pack_fields)

    put_args = {
        'name':
        fields.String(required=True,
                      validate=(validate.Length(min=1, max=255),
                                # Regexp uses re.match, so the pattern is
                                # anchored at the end as well; unanchored it
                                # would accept any name not *starting* with '/'.
                                validate.Regexp(r'^[^/]+\Z')),
                      trim=True)
    }

    @staticmethod
    @use_kwargs(put_args)
    def put(name):
        """Create a pack named ``name`` (surrounding whitespace removed).

        Returns:
            The response built by ``make_id_response`` for the new pack's id.
        """
        pack = Pack(name=name.strip())
        with make_session() as session:
            session.add(pack)
            # Commit before building the response, which reads pack.id.
            session.commit()
            return make_id_response(pack.id)
示例#10
0
class EnvironmentList(Resource):
    """Collection resource for environments: list all and create a new one."""

    @staticmethod
    def get():
        """Return every environment, marshalled with ``environment_fields``."""
        with make_session() as session:
            return marshal(
                session.query(Environment).all(), environment_fields)

    put_args = {
        'name':
        fields.String(required=True,
                      validate=(validate.Length(min=1, max=255),
                                # Regexp uses re.match, so the pattern is
                                # anchored at the end as well; unanchored it
                                # would accept any name not *starting* with '/'.
                                validate.Regexp(r'^[^/]+\Z')),
                      trim=True)
    }

    @staticmethod
    @use_kwargs(put_args)
    def put(name: str):
        """Create an environment named ``name`` (surrounding whitespace removed).

        Returns:
            The response built by ``make_id_response`` for the new
            environment's id.
        """
        environment = Environment(name=name.strip())
        with make_session() as session:
            session.add(environment)
            # Commit before building the response, which reads environment.id.
            session.commit()
            return make_id_response(environment.id)
示例#11
0
from webargs.flaskparser import use_args, use_kwargs, parser


def plate_in_range(val):
    """Validate that a plate id is an integer value that is not below 6500.

    Args:
        val: The plate id, as an int or a numeric string.

    Raises:
        ValidationError: if ``val`` cannot be converted to an int, or if it is
            below 6500.
    """
    try:
        plateid = int(val)
    except (TypeError, ValueError):
        # Report non-numeric input as a validation failure instead of letting
        # the conversion error escape the validator.
        raise ValidationError('Plateid must be an integer')
    # NOTE(review): the message says "> 6500" but 6500 itself is accepted -
    # confirm which of the two is intended.
    if plateid < 6500:
        raise ValidationError('Plateid must be > 6500')


# List of global View arguments across all API routes
viewargs = {
    'name':
    fields.String(
        required=True,
        location='view_args',
        validate=[validate.Length(min=4),
                  validate.Regexp('^[0-9-]*$')]),
    'galid':
    fields.String(
        required=True,
        location='view_args',
        validate=[validate.Length(min=4),
                  validate.Regexp('^[0-9-]*$')]),
    'bintype':
    fields.String(required=True, location='view_args'),
    'template_kin':
    fields.String(required=True, location='view_args'),
    'property_name':
    fields.String(required=True, location='view_args'),
    'channel':
    fields.String(required=True, location='view_args'),
    'binid':
示例#12
0
import datetime
import decimal
from flask import json
from webargs import fields, validate


class CustomJSONEncoder(json.JSONEncoder):
    """JSON encoder that also serialises dates, datetimes and decimals."""

    def default(self, o):
        """Convert ``o`` to a JSON-friendly value.

        Dates and datetimes become ISO 8601 strings and ``Decimal`` values
        become their string form; anything else is handed to the base class.
        """
        # datetime.datetime is a subclass of datetime.date, so a single check
        # covers both types.
        if isinstance(o, datetime.date):
            return o.isoformat()
        if isinstance(o, decimal.Decimal):
            return str(o)
        return super().default(o)


# Required string field whose value must be shaped like YYYY-MM-DD. Only the
# digit layout is checked here, not whether the date exists in the calendar.
date_field = fields.Str(
    validate=validate.Regexp("^[0-9]{4}-[0-9]{2}-[0-9]{2}$"),
    required=True,
)


def strptime_or_none(date_string):
    """Parse a ``YYYY-MM-DD`` string into a ``datetime``.

    Returns ``None`` when ``date_string`` is falsy (``None`` or empty).

    Raises:
        ValueError: if a non-empty string does not match ``%Y-%m-%d``.
    """
    if not date_string:
        return None
    return datetime.datetime.strptime(date_string, "%Y-%m-%d")
示例#13
0
    def jsonify(self) -> UserJson:
        """Return the user's id and username as a plain mapping."""
        return dict(id=self.id, username=self.username)

    @classmethod
    def get_by_username(cls, username: str) -> "UserModel":
        """Return the first user whose username equals ``username``."""
        matching_users = cls.query.filter_by(username=username)
        return matching_users.first()

    def save_to_db(self) -> None:
        """Add this object to the database session and commit."""
        session = db.session
        session.add(self)
        db.session.commit()

    def remove_from_db(self) -> None:
        """Delete this object through the database session and commit."""
        session = db.session
        session.delete(self)
        db.session.commit()

    def authenticate(self, password: str) -> bool:
        """Check ``password`` against the stored password using ``checkpw``.

        Both values are encoded as UTF-8 bytes before the comparison.
        """
        candidate = password.encode('utf8')
        stored = self.password.encode('utf8')
        return checkpw(candidate, stored)


# Password policy, written as look-aheads from the start of the value: 8-64
# non-whitespace characters with at least one digit, one lowercase letter, one
# uppercase letter and one character that is neither a letter, a digit nor
# whitespace. A raw string keeps \S, \d and \s intact for the regex engine
# without relying on Python passing unknown string escapes through, which is
# deprecated.
pw_regex = r'^(?=\S{8,64}$)(?=.*?\d)(?=.*?[a-z])(?=.*?[A-Z])(?=.*?[^A-Za-z\s0-9])'
# Arguments for creating a user: the username length is bounded and the
# password must satisfy pw_regex.
user_args = {
    "username": fields.Str(
        validate=validate.Length(min=4, max=64),
        required=True,
    ),
    "password": fields.Str(
        validate=validate.Regexp(pw_regex),
        required=True,
    ),
}

# Arguments for logging in: both values are required, nothing else is checked.
login_args = {
    "username": fields.String(required=True),
    "password": fields.String(required=True),
}
示例#14
0
class GrampsObjectsResource(GrampsObjectResourceHelper, Resource):
    """Resource for multiple objects."""

    # All arguments are read from the query string.
    @use_args(
        {
            "backlinks": fields.Boolean(missing=False),
            # Date filter in year/month/day form. The alternatives accept, in
            # order: a single date where any part may be "*", a date with a
            # leading "-", a date with a trailing "-", and two dates joined
            # by "-".
            "dates": fields.Str(
                missing=None,
                validate=validate.Regexp(
                    r"^([0-9]+|\*)/([1-9]|1[0-2]|\*)/([1-9]|1[0-9]|2[0-9]|3[0-1]|\*)$|"
                    r"^-[0-9]+/([1-9]|1[0-2])/([1-9]|1[0-9]|2[0-9]|3[0-1])$|"
                    r"^[0-9]+/([1-9]|1[0-2])/([1-9]|1[0-9]|2[0-9]|3[0-1])-$|"
                    r"^[0-9]+/([1-9]|1[0-2])/([1-9]|1[0-9]|2[0-9]|3[0-1])-"
                    r"[0-9]+/([1-9]|1[0-2])/([1-9]|1[0-9]|2[0-9]|3[0-1])$"
                ),
            ),
            # Delimited list restricted to the choices below.
            "extend": fields.DelimitedList(
                fields.Str(validate=validate.Length(min=1)),
                validate=validate.ContainsOnly(
                    choices=[
                        "all",
                        "citation_list",
                        "event_ref_list",
                        "family_list",
                        "note_list",
                        "parent_family_list",
                        "person_ref_list",
                        "primary_parent_family",
                        "place",
                        "source_handle",
                        "father_handle",
                        "mother_handle",
                        "media_list",
                        "reporef_list",
                        "tag_list",
                        "backlinks",
                        "child_ref_list",
                    ]
                ),
            ),
            "filter": fields.Str(validate=validate.Length(min=1)),
            "formats": fields.DelimitedList(
                fields.Str(validate=validate.Length(min=1))
            ),
            "format_options": fields.Str(validate=validate.Length(min=1)),
            "gramps_id": fields.Str(validate=validate.Length(min=1)),
            "keys": fields.DelimitedList(fields.Str(validate=validate.Length(min=1))),
            "locale": fields.Str(missing=None, validate=validate.Length(min=1, max=5)),
            # The default of 0 means "no paging" (see the page check in get()).
            "page": fields.Integer(missing=0, validate=validate.Range(min=1)),
            "pagesize": fields.Integer(missing=20, validate=validate.Range(min=1)),
            # Delimited list restricted to the choices below.
            "profile": fields.DelimitedList(
                fields.Str(validate=validate.Length(min=1)),
                validate=validate.ContainsOnly(
                    choices=[
                        "all",
                        "self",
                        "families",
                        "events",
                        "age",
                        "span",
                        "ratings",
                        "references",
                    ]
                ),
            ),
            "rules": fields.Str(validate=validate.Length(min=1)),
            "skipkeys": fields.DelimitedList(
                fields.Str(validate=validate.Length(min=1))
            ),
            "sort": fields.DelimitedList(fields.Str(validate=validate.Length(min=1))),
            "soundex": fields.Boolean(missing=False),
            "strip": fields.Boolean(missing=False),
        },
        location="query",
    )
    def get(self, args: Dict) -> Response:
        """Get all objects.

        When ``gramps_id`` is given, only that object is returned (404 if it
        does not exist). Otherwise all handles of the resource's class are
        fetched, then filtered, matched against ``dates``, sorted and paged,
        in that order, before the objects are loaded. ``total_items`` is the
        number of handles before paging.
        """
        locale = get_locale_for_language(args["locale"], default=True)
        # Lookup of a single object by its Gramps ID.
        if "gramps_id" in args:
            obj = self.get_object_from_gramps_id(args["gramps_id"])
            if obj is None:
                abort(404)
            return self.response(
                200, [self.full_object(obj, args, locale=locale)], args, total_items=1
            )

        query_method = self.db_handle.method("get_%s_handles", self.gramps_class_name)
        # For these classes the handle query is called without sort arguments.
        if self.gramps_class_name in ["Event", "Repository", "Note"]:
            handles = query_method()
        else:
            handles = query_method(sort_handles=True, locale=locale)

        if "filter" in args or "rules" in args:
            handles = apply_filter(
                self.db_handle, args, self.gramps_class_name, handles
            )

        if args["dates"]:
            handles = self.match_dates(handles, args["dates"])

        if "sort" in args:
            handles = self.sort_objects(handles, args["sort"], locale=locale)

        # Counted before paging, so it reflects the full result set.
        total_items = len(handles)

        # Pages are 1-based; page 0 (the default) returns everything.
        if args["page"] > 0:
            offset = (args["page"] - 1) * args["pagesize"]
            handles = handles[offset : offset + args["pagesize"]]

        query_method = self.db_handle.method(
            "get_%s_from_handle", self.gramps_class_name
        )
        return self.response(
            200,
            [
                self.full_object(query_method(handle), args, locale=locale)
                for handle in handles
            ],
            args,
            total_items=total_items,
        )
示例#15
0
from marvin.utils.general import validate_jwt


def plate_in_range(val):
    """Validate that a plate id is an integer value that is not below 6500.

    Args:
        val: The plate id, as an int or a numeric string.

    Raises:
        ValidationError: if ``val`` cannot be converted to an int, or if it is
            below 6500.
    """
    try:
        plateid = int(val)
    except (TypeError, ValueError):
        # Report non-numeric input as a validation failure instead of letting
        # the conversion error escape the validator.
        raise ValidationError('Plateid must be an integer')
    # NOTE(review): the message says "> 6500" but 6500 itself is accepted -
    # confirm which of the two is intended.
    if plateid < 6500:
        raise ValidationError('Plateid must be > 6500')


# List of global View arguments across all API routes
viewargs = {
    'name':
    fields.String(
        required=True,
        location='view_args',
        validate=[validate.Length(min=4),
                  validate.Regexp('^[0-9-]*$')]),
    'galid':
    fields.String(
        required=True,
        location='view_args',
        validate=[validate.Length(min=4),
                  validate.Regexp('^[0-9-]*$')]),
    'bintype':
    fields.String(required=True, location='view_args'),
    'template':
    fields.String(required=True, location='view_args'),
    'property_name':
    fields.String(required=True, location='view_args'),
    'channel':
    fields.String(required=True, location='view_args'),
    'binid':
示例#16
0
from typing import Dict, Union

from flask import Blueprint, jsonify, abort
from flask_jwt_extended import create_access_token, jwt_refresh_token_required, current_user, create_refresh_token, \
    jwt_required
from webargs import fields, validate
from webargs.flaskparser import use_args

from api import jwt
from api.models import User
from api.sessions import FedauthSession

# Blueprint for the authentication routes, mounted under /auth.
bp = Blueprint("auth", __name__, url_prefix="/auth")

# Form arguments for logging in. The username must be four lowercase letters
# followed by four digits.
login_schema = {
    "username": fields.Str(
        required=True,
        validate=validate.Regexp(regex=r"^[a-z]{4}[0-9]{4}$"),
        location="form",
    ),
    "password": fields.Str(required=True, location="form"),
}


@jwt.user_identity_loader
def user_identity_lookup(user: User) -> str:
    """Return the user's identikey, registered as the JWT identity loader."""
    identity = user.identikey
    return identity


@jwt.user_loader_callback_loader
def user_loader_callback(identity: str) -> Union[User, None]:
    """Load the ``User`` whose primary key is the given JWT identity."""
    loaded_user = User.query.get(identity)
    return loaded_user


@bp.route("/login", methods=["POST"])
示例#17
0
class MallUserView(MallBaseView):
    """Mall - user - login."""
    # The parent class performs login authentication; override it here.
    authentication_classes = ()

    @use_args(
        {
            "code":
            fields.String(required=False, comment="微信code"),
            "phone":
            fields.String(required=False,
                          validate=[validate.Regexp(PHONE_RE)],
                          comment="手机号"),
            "sms_code":
            fields.String(required=False, comment="短信验证码"),
            "password":
            DecryptPassword(required=False,
                            validate=[validate.Regexp(PASSWORD_RE)],
                            comment="密码"),
            "login_type":
            fields.Integer(required=True,
                           validate=[
                               validate.OneOf([
                                   UserLoginType.WX, UserLoginType.PWD,
                                   UserLoginType.PHONE
                               ])
                           ],
                           comment="登录方式,0:微信,1:密码,2:手机")
        },
        location="json",
    )
    def post(self, request, args, shop_code):
        """Log a user in through WeChat, phone + password, or phone + SMS code.

        After the user is resolved for the chosen ``login_type``, a customer
        record is created for the current shop if none exists, and the JWT
        response payload is returned. Each branch answers with a failure or
        error response when its credentials are missing or rejected.
        """
        login_type = args["login_type"]
        code = args.get("code", None)
        phone = args.get("phone", None)
        pwd = args.get("password", None)
        sms_code = args.get("sms_code", None)
        self._set_current_shop(request, shop_code)
        shop = self.current_shop
        # TODO: the WeChat login still needs to be reworked.
        # Login through WeChat
        if login_type == UserLoginType.WX:
            if not code:
                return self.send_fail(error_text="微信登录缺少code")
            shop_appid = MP_APPID
            shop_appsecret = MP_APPSECRET
            wechat_oauth = WeChatOAuth(
                app_id=shop_appid,
                secret=shop_appsecret,
                redirect_uri="",
                scope="snsapi_userinfo",
            )
            try:
                wechat_oauth.fetch_access_token(code)
                user_info = wechat_oauth.get_user_info()
            except Exception:
                # Still a best-effort catch-all, but no longer a bare
                # ``except:``, which would also swallow SystemExit and
                # KeyboardInterrupt.
                return self.send_fail(error_text='获取微信授权失败')
            # Sample shape of user_info, per the original inline example:
            # keys "openid", "nickname", "sex", "language", "city",
            # "province", "country", "headimgurl" (an http:// URL),
            # "privilege" (list) and "unionid".
            user_info["headimgurl"] = user_info["headimgurl"].replace(
                "http://", "https://")
            user = get_user_by_wx_unionid(user_info.get("unionid"))
            if not user:
                new_user_info = {
                    "username": user_info.get('phone'),
                    "phone": user_info.get('phone'),
                    "sex": user_info.get('sex'),
                    "nickname": user_info.get("nickname"),
                    "realname": user_info.get("realname"),
                    "head_image_url": user_info.get("headimgurl"),
                    "wx_unionid": user_info.get("unionid"),
                    "wx_openid": user_info.get("openid"),
                    "wx_country": user_info.get("country"),
                    "wx_province": user_info.get("province"),
                    "wx_city": user_info.get("city"),
                }
                user_serializer = UserCreateSerializer(data=new_user_info)
                # save() was called without validating first; validate and
                # surface the serializer errors instead.
                # NOTE(review): assumes UserCreateSerializer is a DRF
                # serializer (is_valid/save API) - confirm.
                user_serializer.is_valid(raise_exception=True)
                user = user_serializer.save()
            ret, user_openid = get_openid_by_user_id_and_appid(
                user.id, shop_appid)
            # Store the user's openid when it does not exist yet.
            if not ret:
                info = {
                    'user_id': user.id,
                    'mp_appid': shop_appid,
                    'wx_openid': user_info.get("openid"),
                }
                create_user_openid(**info)
        # Login through password
        elif login_type == UserLoginType.PWD:
            # Both values are needed; the previous ``and`` only rejected the
            # request when *both* were missing.
            if not phone or not pwd:
                return self.send_fail(error_text="密码登录缺手机号或密码")
            success, user = get_user_by_phone_and_password(
                phone, pwd, login_type)
            if not success:
                return self.send_fail(error_text=user)
        # Login through phone number + SMS code
        else:
            # Both values are needed; the previous ``and`` only rejected the
            # request when *both* were missing.
            # NOTE(review): the message below names the password login although
            # this is the phone/SMS branch - looks copy-pasted; confirm wording.
            if not phone or not sms_code:
                return self.send_fail(error_text="密码登录缺手机号或验证码")
            redis_conn = get_redis_connection("verify_codes")
            real_sms_code = redis_conn.get("sms_%s" % phone)
            if not real_sms_code:
                return self.send_fail(error_text="验证码已过期")
            if str(real_sms_code.decode()) != sms_code:
                return self.send_error(status_code=status.HTTP_400_BAD_REQUEST,
                                       error_message={"detail": "短信验证码错误"})
            success, user = get_user_by_phone(phone, login_type)
            if not success:
                return self.send_fail(error_text=user)
            # The user does not exist
            if not user:
                return self.send_fail(error_text="该用户不存在")
        customer = get_customer_by_user_id_and_shop_id_interface(
            user.id, shop.id)
        # Create the customer record for a new customer.
        if not customer:
            create_customer(user.id, shop.id)
        token, refresh_token = self._set_current_user(user)
        response_data = jwt_response_payload_handler(token, refresh_token,
                                                     user, request)
        return self.send_success(data=response_data)
示例#18
0
from . import validators, utils

# Shortcut to the application's db attribute (presumably the database handle).
db = app.db
# Blueprint on which the username-based account endpoints are registered.
api = Blueprint('accounts.username', __name__)


# NOTE(review): @ratelimit sits above @api.route, so the route presumably
# registers the function before the rate limiter wraps it - confirm against
# the ratelimit implementation whether the limit is actually applied.
@ratelimit(limit=10, interval=300, key_prefix='register')
@api.route('/register.api', methods=['post'])
# @anonymous_user_required
@use_args(
    {
        'username':
        fields.Str(required=True,
                   validate=[
                       validate.Length(min=6, max=30),
                       # Raw string (the pattern contains ``\-``) and anchored
                       # at the end: Regexp uses re.match, so without ``$``
                       # only the first two characters were actually checked.
                       validate.Regexp(r'^[a-zA-Z][a-zA-Z0-9_\-]+$'),
                       validators.username_exists
                   ]),
        'password':
        fields.Str(required=True, validate=validate.Length(min=8, max=32)),
    },
    locations=('form', 'json'))
def register(args):
    """Register an account by username and password.

    Aborts with 422 when the ``ACCOUNT_USERNAME_DISABLE_REGISTER`` setting is
    truthy (it defaults to True when unset); otherwise returns the JSON result
    of ``utils.register`` with status 201.
    """
    if app.config.get('ACCOUNT_USERNAME_DISABLE_REGISTER', True):
        abort(422, errors={'global': '网站禁止使用用户名方式注册'})

    return jsonify(utils.register(args['username'], args['password'])), 201


@ratelimit(limit=5, interval=300, key_prefix='login:username')
@api.route('/login.api', methods=['post'])
示例#19
0
from webargs import fields, validate

# Request arguments for the greeting: "nombre" is optional and defaults to an
# empty string; a supplied value must be 4-20 letters or spaces.
saludo_request = {
    "nombre": fields.Str(
        validate=validate.Regexp(r"^[a-zA-Z ]{4,20}$"),
        missing="",
        required=False,
    ),
}
示例#20
0
from brain.api.base import processRequest
from marvin import config
from marvin.core.exceptions import MarvinError
from marvin.utils.datamodel.dap import datamodel as dm
from webargs import fields, validate, ValidationError
from webargs.flaskparser import use_args, use_kwargs, parser


def plate_in_range(val):
    """Validate that a plate id is an integer value that is not below 6500.

    Args:
        val: The plate id, as an int or a numeric string.

    Raises:
        ValidationError: if ``val`` cannot be converted to an int, or if it is
            below 6500.
    """
    try:
        plateid = int(val)
    except (TypeError, ValueError):
        # Report non-numeric input as a validation failure instead of letting
        # the conversion error escape the validator.
        raise ValidationError('Plateid must be an integer')
    # NOTE(review): the message says "> 6500" but 6500 itself is accepted -
    # confirm which of the two is intended.
    if plateid < 6500:
        raise ValidationError('Plateid must be > 6500')


# List of global View arguments across all API routes
viewargs = {'name': fields.String(required=True, location='view_args', validate=[validate.Length(min=4),
                                  validate.Regexp('^[0-9-]*$')]),
            'galid': fields.String(required=True, location='view_args', validate=[validate.Length(min=4),
                                   validate.Regexp('^[0-9-]*$')]),
            'bintype': fields.String(required=True, location='view_args'),
            'template': fields.String(required=True, location='view_args'),
            'property_name': fields.String(required=True, location='view_args'),
            'channel': fields.String(required=True, location='view_args'),
            'binid': fields.Integer(required=True, location='view_args', validate=validate.Range(min=-1, max=5800)),
            'plateid': fields.String(required=True, location='view_args', validate=[validate.Length(min=4, max=5),
                                     plate_in_range]),
            'x': fields.Integer(required=True, location='view_args', validate=validate.Range(min=0, max=100)),
            'y': fields.Integer(required=True, location='view_args', validate=validate.Range(min=0, max=100)),
            'mangaid': fields.String(required=True, location='view_args', validate=validate.Length(min=4, max=20)),
            'paramdisplay': fields.String(required=True, location='view_args', validate=validate.OneOf(['all', 'best'])),
            'cube_extension': fields.String(required=True, location='view_args',
                                            validate=validate.OneOf(['flux', 'ivar', 'mask',
示例#21
0
def add_generation_routes(app, is_local, server_path):
    """ Register the file generation endpoints on the given application

        Attributes:
            app: A Flask application
            is_local: A boolean flag indicating whether the application is being run locally or not
            server_path: A string containing the path to the server files (only applicable when run locally)
    """
    submission_file_types = ('D1', 'D2', 'E', 'F')
    submission_file_type_error = "Must be either D1, D2, E or F"
    date_format_error = "Must be in the format MM/DD/YYYY"

    def submission_file_type_arg():
        """ Build the required 'file_type' argument used by the submission-bound routes """
        return webargs_fields.String(
            required=True,
            validate=webargs_validate.OneOf(submission_file_types, error=submission_file_type_error))

    def date_arg():
        """ Build an optional date argument that must match DATE_REGEX """
        return webargs_fields.String(
            validate=webargs_validate.Regexp(DATE_REGEX, error=date_format_error))

    def agency_type_arg():
        """ Build the optional 'agency_type' argument, which falls back to 'awarding' """
        return webargs_fields.String(
            missing='awarding',
            validate=webargs_validate.OneOf(('awarding', 'funding'),
                                            error="Must be either awarding or funding if provided")
        )

    @app.route("/v1/generate_file/", methods=["POST"])
    @convert_to_submission_id
    @requires_submission_perms('writer')
    @use_kwargs({
        'file_type': submission_file_type_arg(),
        'start': date_arg(),
        'end': date_arg(),
        'agency_type': agency_type_arg()
    })
    def generate_file(submission_id, file_type, **kwargs):
        """ Kick off a file generation, or retrieve the cached version of the file.

            Attributes:
                submission_id: submission ID for which we're generating the file
                file_type: type of file to generate the job for
                start: the start date for the file to generate
                end: the end date for the file to generate
                agency_type: The type of agency (awarding or funding) to generate the file for
        """
        # Optional arguments that were not supplied are forwarded as None
        start, end, agency_type = (kwargs.get(key) for key in ('start', 'end', 'agency_type'))
        return generation_handler.generate_file(submission_id, file_type, start, end, agency_type)

    @app.route("/v1/check_generation_status/", methods=["GET"])
    @convert_to_submission_id
    @requires_submission_perms('reader')
    @use_kwargs({'file_type': submission_file_type_arg()})
    def check_generation_status(submission, file_type):
        """ Return status of file generation job

            Attributes:
                submission: submission for which we're generating the file
                file_type: type of file to check the status of
        """
        return generation_handler.check_generation(submission, file_type)

    @app.route("/v1/generate_detached_file/", methods=["POST"])
    @requires_login
    @use_kwargs({
        'file_type': webargs_fields.String(required=True, validate=webargs_validate.OneOf(('A', 'D1', 'D2'))),
        'cgac_code': webargs_fields.String(),
        'frec_code': webargs_fields.String(),
        'start': webargs_fields.String(),
        'end': webargs_fields.String(),
        'year': webargs_fields.Int(),
        'period': webargs_fields.Int(validate=webargs_validate.OneOf(list(range(2, 13)),
                                                                     error="Period must be an integer 2-12.")),
        'agency_type': agency_type_arg()
    })
    def generate_detached_file(file_type, **kwargs):
        """ Generate a file from external API, independent from a submission

            Attributes:
                file_type: type of file to be generated
                cgac_code: the code of a CGAC agency if generating for a CGAC agency
                frec_code: the code of a FREC agency if generating for a FREC agency
                start: start date in a string, formatted MM/DD/YYYY
                end: end date in a string, formatted MM/DD/YYYY
                year: integer indicating the year to generate for (YYYY)
                period: integer indicating the period to generate for (2-12)
                agency_type: The type of agency (awarding or funding) to generate the file for
        """
        # Positional order expected by the handler after file_type
        optional_names = ('cgac_code', 'frec_code', 'start', 'end', 'year', 'period', 'agency_type')
        optional_values = [kwargs.get(name) for name in optional_names]
        return generation_handler.generate_detached_file(file_type, *optional_values)

    @app.route("/v1/check_detached_generation_status/", methods=["GET"])
    @requires_login
    @use_kwargs({'job_id': webargs_fields.Int(required=True)})
    def check_detached_generation_status(job_id):
        """ Return status of the detached file generation job identified by job_id """
        return generation_handler.check_detached_generation(job_id)
示例#22
0
class Mods(RouteCog):
    @staticmethod
    def dict_all(models):
        return [m.to_dict() for m in models]

    @multiroute("/api/v1/mods", methods=["GET"], other_methods=["POST"])
    @json
    @use_kwargs(
        {
            "q": fields.Str(),
            "page": fields.Int(missing=0),
            "limit": fields.Int(missing=50),
            "category": EnumField(ModCategory),
            "rating": fields.Int(validate=validate.OneOf([1, 2, 3, 4, 5])),
            "status": EnumField(ModStatus),
            "sort": EnumField(ModSorting),
            "ascending": fields.Bool(missing=False)
        },
        locations=("query", ))
    async def get_mods(self,
                       q: str = None,
                       page: int = None,
                       limit: int = None,
                       category: ModCategory = None,
                       rating: int = None,
                       status: ModStatus = None,
                       sort: ModSorting = None,
                       ascending: bool = None):
        """List verified mods, optionally filtered, sorted and paginated.

        Parameters:
            q: Search text applied to title, tagline and description.
            page: 1-based page number; 0 or less selects the first page.
            limit: Page size, clamped to the range 1-100.
            category: Only return mods of this category.
            rating: Only return mods whose average review rating lies in
                ``[rating, rating + 1)``.
            status: Only return mods with this status.
            sort: Key into ``mod_sorters`` selecting the sort expression.
            ascending: Sort ascending when true, descending otherwise.

        Returns:
            A JSON response with ``total``, ``page``, ``limit`` and ``results``.
        """
        # Clamp `limit` to 1 or 100, whichever is appropriate
        limit = max(1, min(limit, 100))

        # Callers send 1-based pages; internally pages are 0-based.
        page = page - 1 if page > 0 else 0
        query = Mod.query.where(Mod.verified)

        if q is not None:
            like = f"%{q}%"

            # NOTE(review): every condition must hold at once (AND); confirm
            # this is intended rather than matching any one of the fields.
            query = query.where(
                and_(Mod.title.match(q), Mod.tagline.match(q),
                     Mod.description.match(q), Mod.title.ilike(like),
                     Mod.tagline.ilike(like), Mod.description.ilike(like)))

        if category is not None:
            # Filter on the category column (previously compared the status
            # column against the category value).
            query = query.where(Mod.category == category)

        if rating is not None:
            # A Python chained comparison cannot produce one SQL predicate,
            # so both bounds are spelled out explicitly.
            average_rating = db.select([func.avg(Review.rating)]).where(
                Review.mod_id == Mod.id).as_scalar()
            query = query.where(
                and_(average_rating >= rating, average_rating < rating + 1))

        if status is not None:
            query = query.where(Mod.status == status)

        if sort is not None:
            sort_by = mod_sorters[sort]
            query = query.order_by(
                sort_by.asc() if ascending else sort_by.desc())

        results = await paginate(query, page, limit).gino.all()
        total = await query.alias().count().gino.scalar()

        return jsonify(total=total,
                       page=page,
                       limit=limit,
                       results=self.dict_all(results))

    @multiroute("/api/v1/mods", methods=["POST"], other_methods=["GET"])
    @requires_login
    @json
    @use_kwargs(
        {
            "title":
            fields.Str(required=True, validate=validate.Length(max=64)),
            "tagline":
            fields.Str(required=True, validate=validate.Length(max=100)),
            "description":
            fields.Str(required=True,
                       validate=validate.Length(min=100, max=10000)),
            "website":
            fields.Url(required=True),
            "status":
            EnumField(ModStatus, required=True),
            "category":
            EnumField(ModCategory, required=True),
            "authors":
            fields.List(fields.Nested(AuthorSchema), required=True),
            "icon":
            fields.Str(validate=validate.Regexp(
                DATA_URI_RE,
                error=
                ("`icon` should be a data uri like 'data:image/png;base64,<data>' or "
                 "'data:image/jpeg;base64,<data>'")),
                       required=True),
            "banner":
            fields.Str(validate=validate.Regexp(
                DATA_URI_RE,
                error=
                ("`banner` should be a data uri like 'data:image/png;base64,<data>' or "
                 "'data:image/jpeg;base64,<data>'"),
            ),
                       required=True),
            "is_private_beta":
            fields.Bool(missing=False),
            "mod_playtester":
            fields.List(fields.Str()),
            "color":
            EnumField(ModColor, missing=ModColor.default),
            "recaptcha":
            fields.Str(required=True)
        },
        locations=("json", ))
    async def post_mods(self,
                        title: str,
                        tagline: str,
                        description: str,
                        website: str,
                        authors: List[dict],
                        status: ModStatus,
                        category: ModCategory,
                        icon: str,
                        banner: str,
                        recaptcha: str,
                        color: ModColor,
                        is_private_beta: bool = None,
                        mod_playtester: List[str] = None):
        """Create a new mod owned by the requesting user.

        Parameters:
            title, tagline, description, website: Text fields of the mod.
            authors: Additional authors as dicts with ``id`` and ``role``;
                the requesting user is always added as owner.
            status: Initial status; ``ModStatus.archived`` is rejected.
            category: Category of the mod.
            icon, banner: Images as base64 data URIs.
            recaptcha: Token passed to ``verify_recaptcha``.
            color: Theme colour of the mod.
            is_private_beta: Whether the mod is a private beta.
            mod_playtester: User ids to enroll as playtesters; only allowed
                for private betas.

        Returns:
            A JSON response containing the created mod.
        """
        score = await verify_recaptcha(recaptcha, self.core.aioh_sess, 3,
                                       "create_mod")

        if score < 0.5:
            # TODO: discuss what to do here
            abort(400, "Possibly a bot")

        token = request.headers.get("Authorization",
                                    request.cookies.get("token"))
        parsed_token = await jwt_service.verify_login_token(token, True)
        user_id = parsed_token["id"]

        # Check if any mod with a similar enough name exists already.
        generalized_title = generalize_text(title)
        mods = await Mod.get_any(True,
                                 generalized_title=generalized_title).first()

        if mods is not None:
            abort(400, "A mod with that title already exists")

        if status is ModStatus.archived:
            abort(400, "Can't create a new archived mod")

        mod = Mod(title=title,
                  tagline=tagline,
                  description=description,
                  website=website,
                  status=status,
                  category=category,
                  theme_color=color)

        icon_mimetype, icon_data = validate_img(icon, "icon")
        banner_mimetype, banner_data = validate_img(banner, "banner")

        # Build a new list instead of popping from `authors` while iterating
        # it, which skipped validation of the entry following a removed one.
        co_authors = []
        for author in authors:
            if author["id"] == user_id:
                # The requester is re-added below with the owner role.
                continue
            if not await User.exists(author["id"]):
                abort(400, f"Unknown user '{author['id']}'")
            co_authors.append(author)

        authors = co_authors + [{"id": user_id, "role": AuthorRole.owner}]

        if is_private_beta is not None:
            mod.is_private_beta = is_private_beta

        if mod_playtester is not None:
            if not is_private_beta:
                abort(400, "No need for `ModPlaytester` if open beta")

            for playtester in mod_playtester:
                if not await User.exists(playtester):
                    abort(400, f"Unknown user '{playtester}'")

        # Decode images and add name for mimetypes
        icon_data = base64.b64decode(icon_data)
        banner_data = base64.b64decode(banner_data)
        icon_ext = icon_mimetype.split("/")[1]
        banner_ext = banner_mimetype.split("/")[1]

        icon_data = NamedBytes(icon_data, name=f"icon.{icon_ext}")
        banner_data = NamedBytes(banner_data, name=f"banner.{banner_ext}")

        img_urls = await owo.async_upload_files(icon_data, banner_data)

        mod.icon = img_urls[icon_data.name]
        mod.banner = img_urls[banner_data.name]

        await mod.create()
        await ModAuthor.insert().gino.all(*[
            dict(user_id=author["id"], mod_id=mod.id, role=author["role"])
            for author in authors
        ])

        # Test the request argument, not the ModPlaytester model class
        # (which is never None), so a missing list is not iterated.
        if mod_playtester:
            await ModPlaytester.insert().gino.all(
                *[dict(user_id=user, mod_id=mod.id) for user in mod_playtester])

        return jsonify(mod.to_dict())

    @route("/api/v1/mods/recent_releases")
    @json
    async def get_recent_releases(self):
        """Return up to ten verified mods with released status, newest release first."""
        verified_and_released = and_(Mod.verified,
                                     Mod.status == ModStatus.released)
        newest_first = Mod.query.where(verified_and_released).order_by(
            Mod.released_at.desc())
        recent = await newest_first.limit(10).gino.all()

        return jsonify(self.dict_all(recent))

    @route("/api/v1/mods/most_loved")
    @json
    async def get_most_loved(self):
        """Return up to ten mods ordered by their number of user favorites, highest first."""
        # Correlated scalar subquery counting the favorites of each mod.
        favorite_count = select([func.count()])
        favorite_count = favorite_count.where(UserFavorite.mod_id == Mod.id)
        favorite_count = favorite_count.as_scalar()

        ranked = Mod.query.order_by(favorite_count.desc())
        top_mods = await ranked.limit(10).gino.all()

        return jsonify(self.dict_all(top_mods))

    @route("/api/v1/mods/most_downloads")
    @json
    async def get_most_downloads(self):
        """Return up to ten verified mods that have a release date, most downloaded first."""
        # `Mod.released_at is not None` is a Python identity test on the
        # column object and is always True; `isnot(None)` emits the intended
        # SQL `IS NOT NULL` predicate.
        mods = await Mod.query.where(
            and_(Mod.verified, Mod.released_at.isnot(None))
        ).order_by(Mod.downloads.desc()).limit(10).gino.all()

        return jsonify(self.dict_all(mods))

    @route("/api/v1/mods/trending")
    @json
    async def get_trending(self):
        """Placeholder endpoint that currently always responds with an empty JSON list."""
        # TODO: implement
        trending = []
        return jsonify(trending)

    @multiroute("/api/v1/mods/<mod_id>",
                methods=["GET"],
                other_methods=["PATCH", "DELETE"])
    @json
    async def get_mod(self, mod_id: str):
        """Return the mod identified by ``mod_id`` as JSON, or abort with 404 when none is found."""
        found = await Mod.get(mod_id)

        if found is None:
            abort(404, "Unknown mod")

        payload = found.to_dict()
        return jsonify(payload)

    @multiroute("/api/v1/mods/<mod_id>",
                methods=["PATCH"],
                other_methods=["GET", "DELETE"])
    @requires_login
    @json
    @use_kwargs(
        {
            "title":
            fields.Str(validate=validate.Length(max=64)),
            "tagline":
            fields.Str(validate=validate.Length(max=100)),
            "description":
            fields.Str(validate=validate.Length(min=100, max=10000)),
            "website":
            fields.Url(),
            "status":
            EnumField(ModStatus),
            "category":
            EnumField(ModCategory),
            "authors":
            fields.List(fields.Nested(AuthorSchema)),
            "icon":
            fields.Str(validate=validate.Regexp(
                DATA_URI_RE,
                error=
                ("`icon` should be a data uri like 'data:image/png;base64,<data>' or "
                 "'data:image/jpeg;base64,<data>'"))),
            "banner":
            fields.Str(validate=validate.Regexp(
                DATA_URI_RE,
                error=
                ("`banner` should be a data uri like 'data:image/png;base64,<data>' or "
                 "'data:image/jpeg;base64,<data>'"),
            )),
            "color":
            EnumField(ModColor),
            "is_private_beta":
            fields.Bool(),
            "mod_playtester":
            fields.List(fields.Str())
        },
        locations=("json", ))
    async def patch_mod(self,
                        mod_id: str = None,
                        authors: List[dict] = None,
                        mod_playtester: List[str] = None,
                        icon: str = None,
                        banner: str = None,
                        **kwargs):
        """Update an existing mod and optionally add authors and playtesters.

        Parameters:
            mod_id: Id of the mod to update; 404 when no such mod exists.
            authors: Authors to add as dicts with ``id`` and ``role``.
                Unknown users and users that already are authors of the mod
                are silently dropped.
            mod_playtester: User ids to enroll as playtesters; 400 when one
                is unknown or already enrolled.
            icon, banner: Replacement images as base64 data URIs.
            **kwargs: Remaining parsed fields, applied to the mod as-is.

        Returns:
            A JSON response containing the mod.
        """
        mod = await Mod.get(mod_id)

        if mod is None:
            abort(404, "Unknown mod")

        updates = mod.update(**kwargs)

        # Keep only authors that exist and are not yet linked to this mod.
        # `authors` may be omitted entirely, hence the `or []`.
        new_authors = []
        for author in authors or []:
            if not await User.exists(author["id"]):
                continue
            # TODO: if user is owner or co-owner, allow them to change the role of others to ones below them.
            already_author = await ModAuthor.query.where(
                and_(ModAuthor.user_id == author["id"],
                     ModAuthor.mod_id == mod_id)).gino.first()
            if not already_author:
                new_authors.append(author)

        if mod_playtester is not None:
            for playtester in mod_playtester:
                if not await User.exists(playtester):
                    abort(400, f"Unknown user '{playtester}'")
                elif await ModPlaytester.query.where(
                        and_(ModPlaytester.user_id == playtester,
                             ModPlaytester.mod_id == mod.id)).gino.all():
                    abort(400, f"{playtester} is already enrolled.")

        to_upload = []

        if icon is not None:
            icon_mimetype, icon_data = validate_img(icon, "icon")
            icon_data = base64.b64decode(icon_data)

            icon_ext = icon_mimetype.split("/")[1]
            icon_data = NamedBytes(icon_data, name=f"icon.{icon_ext}")

            to_upload.append(icon_data)

        if banner is not None:
            banner_mimetype, banner_data = validate_img(banner, "banner")
            banner_data = base64.b64decode(banner_data)

            banner_ext = banner_mimetype.split("/")[1]
            banner_data = NamedBytes(banner_data, name=f"banner.{banner_ext}")

            to_upload.append(banner_data)

        # Only contact the upload service when there is something to upload.
        img_urls = await owo.async_upload_files(*to_upload) if to_upload else {}
        img_updates = {}

        if icon is not None:
            img_updates["icon"] = img_urls[icon_data.name]

        if banner is not None:
            img_updates["banner"] = img_urls[banner_data.name]

        # Lump the image changes into the same pending update so everything
        # is applied in one go.
        updates = updates.update(**img_updates)

        await updates.apply()

        if new_authors:
            await ModAuthor.insert().gino.all(*[
                dict(user_id=author["id"], mod_id=mod.id, role=author["role"])
                for author in new_authors
            ])

        # Iterate the request argument, not the ModPlaytester model class.
        if mod_playtester:
            await ModPlaytester.insert().gino.all(
                *[dict(user_id=user, mod_id=mod.id) for user in mod_playtester])

        return jsonify(mod.to_dict())

    # TODO: decline route with reason, maybe doesn't 100% delete it? idk
    @multiroute("/api/v1/mods/<mod_id>",
                methods=["DELETE"],
                other_methods=["GET", "PATCH"])
    @requires_login
    @json
    async def delete_mod(self, mod_id: str):
        """Delete the mod rows whose id equals ``mod_id`` and respond with JSON ``true``."""
        # NOTE(review): no existence or ownership check is visible here -
        # confirm whether `requires_login` alone is the intended protection.
        deletion = Mod.delete.where(Mod.id == mod_id)
        await deletion.gino.status()

        return jsonify(True)

    @route("/api/v1/mods/<mod_id>/download")
    @json
    async def get_download(self, mod_id: str):
        """Return the download URL of a mod.

        Private beta mods are only served to authenticated users that are
        enrolled as playtesters of the mod; other mods are served to anyone.

        Parameters:
            mod_id: Id of the mod to download.

        Returns:
            A JSON response with the mod's ``url``. Aborts with 404 when the
            mod is unknown or has no download, and with 403 when access to a
            private beta is denied.
        """
        token = request.headers.get("Authorization",
                                    request.cookies.get("token"))
        parsed_token = await jwt_service.verify_login_token(token, True)
        user_id = parsed_token["id"]

        mod = await Mod.get(mod_id)

        if mod is None:
            abort(404, "Unknown mod")

        # The enrolment check applies to private betas only; previously it
        # ran for every mod and rejected downloads of public mods.
        if mod.is_private_beta:
            if user_id is None:
                abort(403, "Private beta mods requires authentication.")

            enrolled = await ModPlaytester.query.where(
                and_(ModPlaytester.user_id == user_id,
                     ModPlaytester.mod_id == mod.id)).gino.first()
            if not enrolled:
                abort(403, "You are not enrolled to the private beta.")

        if not mod.zip_url:
            abort(404, "Mod has no download")

        return jsonify(url=mod.zip_url)

    @multiroute("/api/v1/mods/<mod_id>/reviews",
                methods=["GET"],
                other_methods=["POST"])
    @json
    @use_kwargs(
        {
            "page":
            fields.Int(missing=0),
            "limit":
            fields.Int(missing=10),
            # Probably won't work right now, will need union field.
            "rating":
            fields.Int(validate=validate.OneOf([1, 2, 3, 4, 5, "all"]),
                       missing="all"),
            "sort":
            EnumField(ReviewSorting, missing=ReviewSorting.best)
        },
        locations=("query", ))
    async def get_reviews(self, mod_id: str, page: int, limit: int,
                          rating: Union[int, str], sort: ReviewSorting):
        """List the reviews of a mod, sorted, optionally filtered by rating, and paginated.

        Parameters:
            mod_id: Id of the mod; 404 when unknown.
            page: 1-based page number; 0 or less selects the first page.
            limit: Page size, clamped to the range 1-25.
            rating: Whole-star rating to filter by (also matches the half
                star above it), or ``"all"`` for no filter.
            sort: Requested ordering of the reviews.

        Returns:
            A JSON list of the selected reviews.
        """
        if not await Mod.exists(mod_id):
            abort(404, "Unknown mod")

        # Clamp `limit` to 1 or 25, whichever is appropriate
        limit = max(1, min(limit, 25))

        # Callers send 1-based pages; internally pages are 0-based.
        page = page - 1 if page > 0 else 0
        query = Review.query.where(Review.mod_id == mod_id)

        # Look the sorter up without raising for sort modes that are handled
        # by the dedicated branches below, and compare against None rather
        # than relying on the truthiness of a SQL expression.
        simple_sorter = review_sorters.get(sort)

        if simple_sorter is not None:
            query = query.order_by(simple_sorter)
        elif sort == ReviewSorting.best:
            upvoters_count = select([func.count()]).where(
                and_(ReviewReaction.review_id == Review.id,
                     ReviewReaction.reaction ==
                     ReactionType.upvote)).as_scalar()

            downvoters_count = select([func.count()]).where(
                and_(ReviewReaction.review_id == Review.id,
                     ReviewReaction.reaction ==
                     ReactionType.downvote)).as_scalar()

            # NOTE(review): this orders ascending by net votes - confirm
            # whether "best" should be descending.
            query = query.order_by(upvoters_count - downvoters_count)
        elif sort == ReviewSorting.funniest:
            # Get count of all funny ratings by review.
            sub_order = select([func.count()]).where(
                and_(ReviewReaction.review_id == Review.id,
                     ReviewReaction.reaction ==
                     ReactionType.funny)).as_scalar()
            query = query.order_by(sub_order.desc())

        if isinstance(rating, int):
            values = [rating, rating + 0.5]

            if rating == 1:
                # Also get reviews with a 0.5 star rating, otherwise they'll never appear.
                values.append(0.5)

            query = query.where(Review.rating.in_(values))

        # Apply the parsed `page`/`limit`, which were previously computed
        # but never used.
        reviews = await paginate(query, page, limit).gino.all()

        return jsonify(self.dict_all(reviews))

    @multiroute("/api/v1/mods/<mod_id>/reviews",
                methods=["POST"],
                other_methods=["GET"])
    @requires_login
    @json
    @use_kwargs({
        "rating":
        fields.Int(
            required=True,
            validate=[
                # Only allow increments of 0.5, up to 5.
                lambda value: 1 <= value <= 5,
                lambda value: value % 0.5 == 0
            ]),
        "content":
        fields.Str(required=True, validate=validate.Length(max=2000)),
        "title":
        fields.Str(required=True, validate=validate.Length(max=32))
    })
    async def post_review(self, mod_id: str, rating: int, content: str,
                          title: str):
        """Create a review of a mod on behalf of the requesting user.

        Aborts with 404 when the mod is unknown and with 400 when the user
        already has a review for this mod; otherwise responds with the new
        review as JSON.
        """
        mod_known = await Mod.exists(mod_id)
        if not mod_known:
            abort(404, "Unknown mod")

        # The token may arrive in the Authorization header or the cookie.
        cookie_token = request.cookies.get("token")
        token = request.headers.get("Authorization", cookie_token)
        claims = await jwt_service.verify_login_token(token, True)
        user_id = claims["id"]

        same_author_and_mod = and_(Review.author_id == user_id,
                                   Review.mod_id == mod_id)
        existing = await Review.query.where(same_author_and_mod).gino.first()
        if existing:
            abort(400, "Review already exists")

        review = await Review.create(title=title,
                                     content=content,
                                     rating=rating,
                                     author_id=user_id,
                                     mod_id=mod_id)

        return jsonify(review.to_json())

    @route("/api/v1/mods/<mod_id>/authors")
    @json
    async def get_authors(self, mod_id: str):
        """Return the users linked to a mod as authors, or abort with 404 when the mod is unknown."""
        mod_known = await Mod.exists(mod_id)
        if not mod_known:
            abort(404, "Unknown mod")

        # Resolve the author link rows to user ids, then load those users.
        links = await ModAuthor.query.where(
            ModAuthor.mod_id == mod_id).gino.all()
        author_ids = [link.user_id for link in links]
        users = await User.query.where(User.id.in_(author_ids)).gino.all()

        return jsonify(self.dict_all(users))

    # This handles POST requests to add zip_url.
    # Usually this would be done via a whole entry but this
    # is designed for existing content.
    @route("/api/v1/mods/<mod_id>/upload_content", methods=["POST"])
    @json
    @requires_supporter
    @requires_login
    async def upload(self, mod_id: str):
        """Not implemented yet: aborts with 404 for an unknown mod, otherwise with 501."""
        mod_known = await Mod.exists(mod_id)
        if not mod_known:
            abort(404, "Unknown mod")

        abort(501, "Coming soon")

    @route("/api/v1/mods/<mod_id>/report", methods=["POST"])
    @json
    @use_kwargs(
        {
            "content":
            fields.Str(required=True,
                       validate=validate.Length(min=100, max=1000)),
            "type_":
            EnumField(ReportType, required=True),
            "recaptcha":
            fields.Str(required=True)
        },
        locations=("json", ))
    @requires_login
    @limiter.limit("2 per hour")
    async def report_mod(self, mod_id: str, content: str, type_: ReportType,
                         recaptcha: str):
        """File a report against a mod on behalf of the requesting user.

        Parameters:
            mod_id: Id of the reported mod; 404 when unknown.
            content: Report text (100-1000 characters).
            type_: Kind of report.
            recaptcha: Token passed to ``verify_recaptcha``.

        Returns:
            A JSON response containing the created report.
        """
        # Reject reports for unknown mods up front, like the sibling routes.
        if not await Mod.exists(mod_id):
            abort(404, "Unknown mod")

        # NOTE(review): the returned score is not inspected here, unlike in
        # mod creation - confirm whether a threshold check is wanted.
        await verify_recaptcha(recaptcha, self.core.aioh_sess, 2)

        token = request.headers.get("Authorization",
                                    request.cookies.get("token"))
        parsed_token = await jwt_service.verify_login_token(token, True)
        user_id = parsed_token["id"]

        report = await Report.create(content=content,
                                     author_id=user_id,
                                     mod_id=mod_id,
                                     type=type_)
        return jsonify(report.to_dict())
示例#23
0
from webargs import fields, validate

# Arguments accepted when logging in: a well-formed email and a password.
login_args = dict(
    email=fields.String(required=True, validate=validate.Email()),
    password=fields.String(required=True),
)

# Arguments accepted when registering: a well-formed email, a username of
# 6-64 letters, digits or underscores, and a password of 8+ characters.
register_args = {
    'email':
    fields.String(required=True, validate=validate.Email()),
    'username':
    fields.String(required=True,
                  validate=validate.Regexp(
                      regex='^[a-zA-Z0-9_]{6,64}$',
                      # The pattern also accepts underscores, so say so.
                      error='Username must be 6-64 characters long and '
                      'contain only letters, digits and underscores')),
    'password':
    fields.String(required=True,
                  validate=validate.Length(
                      min=8,
                      error='Password must be at least 8 characters long'))
}
示例#24
0
class SuperUserEmailView(SuperBaseView):
    """Super admin backend - user - verify email, bind email and send activation email."""

    @use_args(
        {
            "sign": fields.String(required=True, comment="加密认证"),
            "timestamp": fields.Integer(required=True, comment="时间戳"),
            "user_id": fields.Integer(required=True, comment="用户ID"),
            "token": fields.String(required=True, comment="验证token"),
        },
        location="query")
    @SuperBaseView.validate_sign("sign", ("user_id", "timestamp"))
    def get(self, request, args):
        """Mark the email of the user resolved from the verification token as active."""
        # Validate the token
        verified_user = User.check_verify_email_token(args["token"])
        if verified_user is None:
            return self.send_error(status_code=status.HTTP_400_BAD_REQUEST,
                                   error_message={"detail": "链接信息无效"})

        verified_user.email_active = True
        verified_user.save()
        return self.send_success()

    @use_args(
        {
            "email":
            fields.String(required=True,
                          validate=[validate.Regexp(EMAIL_RE)],
                          comment="邮箱")
        },
        location="json")
    def put(self, request, args):
        """Bind the given email to the current user unless another user already has it."""
        current_user = self._get_current_user(request)
        if not current_user:
            return self.send_error(status_code=status.HTTP_401_UNAUTHORIZED,
                                   error_message={"error_text": "用户未登录"})

        email_owner = get_user_by_email(args["email"])
        if email_owner:
            return self.send_fail(error_text="该邮箱已绑定其他用户")

        email_serializer = EmailSerializer(current_user, data=args)
        email_serializer.is_valid(raise_exception=True)
        email_serializer.save()
        return self.send_success()

    @use_args(
        {
            "email":
            fields.String(required=True,
                          validate=[validate.Regexp(EMAIL_RE)],
                          comment="邮箱")
        },
        location="json")
    def post(self, request, args):
        """Send an email to the given address on behalf of the current user."""
        current_user = self._get_current_user(request)
        if not current_user:
            return self.send_error(status_code=status.HTTP_401_UNAUTHORIZED,
                                   error_message={"error_text": "用户未登录"})

        sent, failure_info = send_email(current_user, args["email"])
        if sent:
            return self.send_success()
        return self.send_fail(error_text=failure_info)
示例#25
0
class AdminUserView(UserBaseView):
    """Admin backend - user - login and registration."""
    # The parent class enforces login authentication; override it here.
    authentication_classes = ()

    def _send_login_success(self, user, request):
        """Issue tokens for ``user`` and return the success response carrying them."""
        token, refresh_token = self._set_current_user(user)
        response_data = jwt_response_payload_handler(
            token, refresh_token, user, request)
        return self.send_success(data=response_data)

    @use_args(
        {
            "code":
            fields.String(required=False, comment="微信code"),
            "phone":
            fields.String(required=False,
                          validate=[validate.Regexp(PHONE_RE)],
                          comment="手机号"),
            "sms_code":
            fields.String(required=False, comment="短信验证码"),
            "password":
            DecryptPassword(required=False,
                            validate=[validate.Regexp(PASSWORD_RE)],
                            comment="密码"),
            "login_type":
            fields.Integer(required=True,
                           validate=[
                               validate.OneOf([
                                   UserLoginType.WX, UserLoginType.PWD,
                                   UserLoginType.PHONE
                               ])
                           ],
                           comment="登录方式,0:微信,1:密码,2:手机")
        },
        location="json",
    )
    def post(self, request, args):
        """Log a user in by WeChat code, phone + password, or phone + SMS code.

        A phone + SMS login for a phone number without an account registers
        a new user first. Successful password and SMS logins respond with
        the issued tokens.
        """
        login_type = args["login_type"]
        code = args.get("code", None)
        phone = args.get("phone", None)
        pwd = args.get("password", None)
        sms_code = args.get("sms_code", None)
        # Login via WeChat
        if login_type == UserLoginType.WX:
            if not code:
                return self.send_fail(error_text="微信登录缺少code")
            # NOTE(review): when a code is supplied nothing further happens
            # and the view returns None - the WeChat flow looks unfinished.
        # Login via password
        elif login_type == UserLoginType.PWD:
            # Both values are required, so reject when either is missing
            # (the message itself says "phone number or password").
            if not phone or not pwd:
                return self.send_fail(error_text="密码登录缺手机号或密码")
            success, user = get_user_by_phone_and_password(
                phone, pwd, login_type)
            if not success:
                return self.send_fail(error_text=user)
            return self._send_login_success(user, request)
        # Login via phone number and SMS code
        else:
            # Both values are required; the message now names the phone
            # login instead of the copy-pasted password login.
            if not phone or not sms_code:
                return self.send_fail(error_text="手机登录缺手机号或验证码")
            redis_conn = get_redis_connection("verify_codes")
            real_sms_code = redis_conn.get("sms_%s" % phone)
            if not real_sms_code:
                return self.send_fail(error_text="验证码已过期")
            if str(real_sms_code.decode()) != sms_code:
                return self.send_error(status_code=status.HTTP_400_BAD_REQUEST,
                                       error_message={"detail": "短信验证码错误"})
            success, user = get_user_by_phone(phone, login_type)
            if not success:
                return self.send_fail(error_text=user)
            # The user does not exist yet, so register one
            if not user:
                data = {
                    "phone":
                    phone,
                    "username":
                    phone,
                    "nickname":
                    "用户{phone}".format(phone=phone),
                    "head_image_url":
                    "http://img.senguo.cc/FlMKOOnlycuoZp1rR39LyCFUHUgl"
                }
                serializer = UserCreateSerializer(data=data)
                # Surface validation problems as a validation error instead
                # of failing later inside save() on unvalidated data.
                serializer.is_valid(raise_exception=True)
                user = serializer.save()
            return self._send_login_success(user, request)
示例#26
0
class PersonTimelineResource(ProtectedResource, GrampsJSONEncoder):
    """Timeline of events anchored on a single person."""

    @use_args(
        {
            "ancestors": fields.Integer(
                missing=1, validate=validate.Range(min=1, max=5)
            ),
            # Accepted forms: "-y/m/d", "y/m/d-" and "y/m/d-y/m/d", with
            # month 1-12 and day 1-31 written without leading zeros.
            "dates": fields.Str(
                missing=None,
                validate=validate.Regexp(
                    r"^-[0-9]+/([1-9]|1[0-2])/([1-9]|1[0-9]|2[0-9]|3[0-1])$|"
                    r"^[0-9]+/([1-9]|1[0-2])/([1-9]|1[0-9]|2[0-9]|3[0-1])-$|"
                    r"^[0-9]+/([1-9]|1[0-2])/([1-9]|1[0-9]|2[0-9]|3[0-1])-"
                    r"[0-9]+/([1-9]|1[0-2])/([1-9]|1[0-9]|2[0-9]|3[0-1])$"
                ),
            ),
            "discard_empty": fields.Boolean(missing=True),
            "event_classes": fields.DelimitedList(
                fields.Str(validate=validate.Length(min=1)),
                validate=validate.ContainsOnly(choices=EVENT_CATEGORIES),
            ),
            "events": fields.DelimitedList(
                fields.Str(validate=validate.Length(min=1))
            ),
            "first": fields.Boolean(missing=True),
            "keys": fields.DelimitedList(
                fields.Str(validate=validate.Length(min=1))
            ),
            "last": fields.Boolean(missing=True),
            "locale": fields.Str(missing=None),
            "offspring": fields.Integer(
                missing=1, validate=validate.Range(min=1, max=5)
            ),
            "omit_anchor": fields.Boolean(missing=True),
            "page": fields.Integer(missing=0, validate=validate.Range(min=1)),
            "pagesize": fields.Integer(
                missing=20, validate=validate.Range(min=1)
            ),
            "precision": fields.Integer(
                missing=1, validate=validate.Range(min=1, max=3)
            ),
            "ratings": fields.Boolean(missing=False),
            "relative_event_classes": fields.DelimitedList(
                fields.Str(validate=validate.Length(min=1)),
                validate=validate.ContainsOnly(choices=EVENT_CATEGORIES),
            ),
            "relative_events": fields.DelimitedList(
                fields.Str(validate=validate.Length(min=1)),
            ),
            "relatives": fields.DelimitedList(
                fields.Str(validate=validate.Length(min=1)),
                validate=validate.ContainsOnly(choices=RELATIVES),
            ),
            "skipkeys": fields.DelimitedList(
                fields.Str(validate=validate.Length(min=1))
            ),
            "strip": fields.Boolean(missing=False),
        },
        location="query",
    )
    def get(self, args: Dict, handle: str):
        """Return the timeline profile for the person identified by *handle*.

        Aborts with 422 when building the timeline raises ``ValueError`` and
        with 404 when it raises ``HandleError``.
        """
        locale = get_locale_for_language(args["locale"], default=True)
        events = prepare_events(args)

        # Optional list arguments are simply absent from ``args`` when the
        # client did not send them.
        relatives = args.get("relatives", [])
        relative_events = args.get("relative_events", [])
        if "relative_event_classes" in args:
            relative_events = relative_events + args["relative_event_classes"]

        timeline_options = {
            "dates": args["dates"],
            "events": events,
            "ratings": args["ratings"],
            "relatives": relatives,
            "relative_events": relative_events,
            "discard_empty": args["discard_empty"],
            "omit_anchor": args["omit_anchor"],
            "precision": args["precision"],
            "locale": locale,
        }
        person_options = {
            "anchor": True,
            "start": args["first"],
            "end": args["last"],
            "ancestors": args["ancestors"],
            "offspring": args["offspring"],
        }

        try:
            timeline = Timeline(get_db_handle(), **timeline_options)
            timeline.add_person(Handle(handle), **person_options)
        except ValueError:
            abort(422)
        except HandleError:
            abort(404)

        payload = timeline.profile(
            page=args["page"], pagesize=args["pagesize"]
        )
        item_count = len(timeline.timeline)
        return self.response(200, payload, args, total_items=item_count)
示例#27
0
class GrampsObjectsResource(GrampsObjectResourceHelper, Resource):
    """Collection endpoint: list objects of one Gramps class or add new ones."""

    @use_args(
        {
            "backlinks": fields.Boolean(missing=False),
            # Accepted forms: a single "y/m/d" where any part may be "*",
            # or one of the range forms "-y/m/d", "y/m/d-", "y/m/d-y/m/d".
            "dates": fields.Str(
                missing=None,
                validate=validate.Regexp(
                    r"^([0-9]+|\*)/([1-9]|1[0-2]|\*)/([1-9]|1[0-9]|2[0-9]|3[0-1]|\*)$|"
                    r"^-[0-9]+/([1-9]|1[0-2])/([1-9]|1[0-9]|2[0-9]|3[0-1])$|"
                    r"^[0-9]+/([1-9]|1[0-2])/([1-9]|1[0-9]|2[0-9]|3[0-1])-$|"
                    r"^[0-9]+/([1-9]|1[0-2])/([1-9]|1[0-9]|2[0-9]|3[0-1])-"
                    r"[0-9]+/([1-9]|1[0-2])/([1-9]|1[0-9]|2[0-9]|3[0-1])$"
                ),
            ),
            "extend": fields.DelimitedList(
                fields.Str(validate=validate.Length(min=1)),
                validate=validate.ContainsOnly(
                    choices=[
                        "all", "citation_list", "event_ref_list",
                        "family_list", "note_list", "parent_family_list",
                        "person_ref_list", "primary_parent_family", "place",
                        "source_handle", "father_handle", "mother_handle",
                        "media_list", "reporef_list", "tag_list",
                        "backlinks", "child_ref_list",
                    ]
                ),
            ),
            "filter": fields.Str(validate=validate.Length(min=1)),
            "formats": fields.DelimitedList(
                fields.Str(validate=validate.Length(min=1))
            ),
            "format_options": fields.Str(validate=validate.Length(min=1)),
            "gramps_id": fields.Str(validate=validate.Length(min=1)),
            "keys": fields.DelimitedList(
                fields.Str(validate=validate.Length(min=1))
            ),
            "locale": fields.Str(
                missing=None, validate=validate.Length(min=1, max=5)
            ),
            "page": fields.Integer(missing=0, validate=validate.Range(min=1)),
            "pagesize": fields.Integer(
                missing=20, validate=validate.Range(min=1)
            ),
            "profile": fields.DelimitedList(
                fields.Str(validate=validate.Length(min=1)),
                validate=validate.ContainsOnly(
                    choices=[
                        "all", "self", "families", "events",
                        "age", "span", "ratings", "references",
                    ]
                ),
            ),
            "rules": fields.Str(validate=validate.Length(min=1)),
            "skipkeys": fields.DelimitedList(
                fields.Str(validate=validate.Length(min=1))
            ),
            "sort": fields.DelimitedList(
                fields.Str(validate=validate.Length(min=1))
            ),
            "soundex": fields.Boolean(missing=False),
            "strip": fields.Boolean(missing=False),
            "filemissing": fields.Boolean(missing=False),
        },
        location="query",
    )
    def get(self, args: Dict) -> Response:
        """Return the objects matching the query arguments.

        When ``gramps_id`` is supplied, only that object is returned (404 if
        the lookup yields ``None``). Otherwise the handle list is narrowed
        step by step (filter/rules, missing media files, dates), optionally
        sorted, and sliced when ``page`` is greater than 0; ``total_items``
        reports the count before slicing.
        """
        locale = get_locale_for_language(args["locale"], default=True)

        # Direct lookup by Gramps ID bypasses listing, filtering and paging.
        if "gramps_id" in args:
            found = self.get_object_from_gramps_id(args["gramps_id"])
            if found is None:
                abort(404)
            single_result = [self.full_object(found, args, locale=locale)]
            return self.response(200, single_result, args, total_items=1)

        list_handles = self.db_handle.method(
            "get_%s_handles", self.gramps_class_name
        )
        # These classes' handle getters are called without sorting arguments.
        if self.gramps_class_name in ("Event", "Repository", "Note"):
            handles = list_handles()
        else:
            handles = list_handles(sort_handles=True, locale=locale)

        wants_filter = "filter" in args or "rules" in args
        if wants_filter:
            handles = apply_filter(
                self.db_handle, args, self.gramps_class_name, handles
            )

        if self.gramps_class_name == "Media" and args.get("filemissing"):
            handles = get_missing_media_file_handles(self.db_handle, handles)

        if args["dates"]:
            handles = self.match_dates(handles, args["dates"])

        if "sort" in args:
            handles = self.sort_objects(handles, args["sort"], locale=locale)

        # Count before slicing so the total reflects the whole result set.
        total_items = len(handles)

        page, pagesize = args["page"], args["pagesize"]
        if page > 0:
            start = (page - 1) * pagesize
            handles = handles[start : start + pagesize]

        fetch_object = self.db_handle.method(
            "get_%s_from_handle", self.gramps_class_name
        )
        objects = [
            self.full_object(fetch_object(handle), args, locale=locale)
            for handle in handles
        ]
        return self.response(200, objects, args, total_items=total_items)

    def post(self) -> Response:
        """Add a new object and return the resulting transaction as JSON.

        Aborts with 400 when the request yields no object or when
        ``add_object`` raises ``ValueError``. Every object touched by the
        transaction is pushed to the search index afterwards.
        """
        require_permissions([PERM_ADD_OBJ])
        new_object = self._parse_object()
        if not new_object:
            abort(400)

        writable_db = self.db_handle_writable
        with DbTxn("Add objects", writable_db) as trans:
            try:
                add_object(writable_db, new_object, trans, fail_if_exists=True)
            except ValueError:
                abort(400)
            changes = transaction_to_json(trans)

        # update search index
        search_indexer: SearchIndexer = current_app.config["SEARCH_INDEXER"]
        with search_indexer.get_writer(overwrite=False, use_async=True) as writer:
            for change in changes:
                search_indexer.add_or_update_object(
                    writer, change["handle"], writable_db, change["_class"]
                )
        return self.response(201, changes, total_items=len(changes))
示例#28
0
class TimelineFamiliesResource(ProtectedResource, GrampsJSONEncoder):
    """Consolidated timeline of events for a set of families."""

    @use_args(
        {
            # Accepted forms: "-y/m/d", "y/m/d-" and "y/m/d-y/m/d", with
            # month 1-12 and day 1-31 written without leading zeros.
            "dates": fields.Str(
                missing=None,
                validate=validate.Regexp(
                    r"^-[0-9]+/([1-9]|1[0-2])/([1-9]|1[0-9]|2[0-9]|3[0-1])$|"
                    r"^[0-9]+/([1-9]|1[0-2])/([1-9]|1[0-9]|2[0-9]|3[0-1])-$|"
                    r"^[0-9]+/([1-9]|1[0-2])/([1-9]|1[0-9]|2[0-9]|3[0-1])-"
                    r"[0-9]+/([1-9]|1[0-2])/([1-9]|1[0-9]|2[0-9]|3[0-1])$"
                ),
            ),
            "discard_empty": fields.Boolean(missing=True),
            "event_classes": fields.DelimitedList(
                fields.Str(validate=validate.Length(min=1)),
                validate=validate.ContainsOnly(choices=EVENT_CATEGORIES),
            ),
            "events": fields.DelimitedList(
                fields.Str(validate=validate.Length(min=1))
            ),
            "filter": fields.Str(validate=validate.Length(min=1)),
            "keys": fields.DelimitedList(
                fields.Str(validate=validate.Length(min=1))
            ),
            "handles": fields.DelimitedList(
                fields.Str(validate=validate.Length(min=1))
            ),
            "locale": fields.Str(
                missing=None, validate=validate.Length(min=1, max=5)
            ),
            "page": fields.Integer(missing=0, validate=validate.Range(min=1)),
            "pagesize": fields.Integer(
                missing=20, validate=validate.Range(min=1)
            ),
            "ratings": fields.Boolean(missing=False),
            "rules": fields.Str(validate=validate.Length(min=1)),
            "skipkeys": fields.DelimitedList(
                fields.Str(validate=validate.Length(min=1))
            ),
            "strip": fields.Boolean(missing=False),
        },
        location="query",
    )
    def get(self, args: Dict):
        """Return the combined timeline profile for the selected families.

        Families come from the ``handles`` argument when present, otherwise
        from all family handles in the database; ``filter``/``rules`` narrow
        that list further. Aborts with 422 when constructing the timeline
        raises ``ValueError`` and with 404 on ``HandleError``.
        """
        db_handle = get_db_handle()
        locale = get_locale_for_language(args["locale"], default=True)
        events = prepare_events(args)

        try:
            timeline = Timeline(
                db_handle,
                dates=args["dates"],
                events=events,
                ratings=args["ratings"],
                discard_empty=args["discard_empty"],
                locale=locale,
            )
        except ValueError:
            abort(422)

        # Only query the database for handles when none were supplied.
        handles = (
            args["handles"]
            if "handles" in args
            else db_handle.get_family_handles(sort_handles=True, locale=locale)
        )

        wants_filter = "filter" in args or "rules" in args
        try:
            if wants_filter:
                handles = apply_filter(db_handle, args, "Family", handles)
            for family_handle in handles:
                timeline.add_family(family_handle)
        except HandleError:
            abort(404)

        payload = timeline.profile(
            page=args["page"], pagesize=args["pagesize"]
        )
        item_count = len(timeline.timeline)
        return self.response(200, payload, args, total_items=item_count)
示例#29
0
class TemplateResource(Resource):
    """Read, modify and delete a single ``Template`` and its variables."""

    # Arguments accepted by GET.
    get_args = {
        'with_text':
        fields.Boolean(required=False, missing=False, location='query'),
        'template_id':
        fields.Integer(location='view_args')
    }

    @staticmethod
    @use_kwargs(get_args)
    def get(template_id, with_text):
        """Return the marshalled template, including its text if requested.

        Args:
            template_id: id of the template, taken from the URL.
            with_text: when true, marshal with ``template_fields_with_text``
                instead of ``template_fields``.

        Raises:
            NotFound: no template with ``template_id`` exists.
        """
        with make_session() as session:
            template = session.query(Template).filter(
                Template.id == template_id).first()  # type: Template
            if template is None:
                raise NotFound("Requested template does not exist")

            return marshal(
                template,
                template_fields_with_text if with_text else template_fields)

    # Arguments accepted by PATCH. ``variables`` groups the ids to delete,
    # the variables to update and the variables to create.
    patch_args = {
        'name':
        fields.String(required=False),
        'text':
        fields.String(required=False),
        'variables':
        fields.Nested({
            'delete':
            fields.List(fields.Integer(required=True)),
            'update':
            fields.List(
                fields.Nested({
                    'id':
                    fields.Integer(required=True),
                    'description':
                    fields.String(required=False, missing=''),
                    'name':
                    fields.String(required=False,
                                  validate=validate.Regexp(
                                      re.compile('^[a-zA-Z_][a-zA-Z0-9_]*$')))
                })),
            'create':
            fields.List(
                fields.Nested({
                    'name':
                    fields.String(required=True,
                                  validate=validate.Regexp(
                                      re.compile('^[a-zA-Z_][a-zA-Z0-9_]*$'))),
                    'description':
                    fields.String(required=False, missing='')
                }))
        }),
        'template_id':
        fields.Integer(location='view_args')
    }

    @staticmethod
    def _first_or_404(query):
        """Return the first row of *query*, raising ``NotFound`` if empty."""
        variable = query.first()
        if variable is None:
            raise NotFound("Requested variable does not exist")
        return variable

    @staticmethod
    @use_kwargs(patch_args)
    def patch(template_id, name, text, variables):
        """Apply a partial update to a template and its variables.

        Args:
            template_id: id of the template, taken from the URL.
            name: new template name, or ``missing`` to leave it unchanged.
            text: new template text, or ``missing`` to leave it unchanged.
            variables: dict with optional ``delete``, ``update`` and
                ``create`` lists, or ``missing``.

        Returns:
            The response produced by ``make_id_response(template_id)``.

        Raises:
            NotFound: the template, or a variable referenced in ``delete``
                or ``update``, does not exist.
            Conflict: a variable that is in use would be deleted or renamed.
        """
        with make_session() as session:
            template = session.query(Template).filter(
                Template.id == template_id).first()  # type: Template
            if template is None:
                raise NotFound("Requested template does not exist")

            if variables is not missing:
                for variable_id in variables.get('delete', ()):
                    to_delete = session.query(Variable).filter(
                        Variable.id == variable_id)
                    # An unknown id used to crash with AttributeError on
                    # ``None``; report it as a 404 instead.
                    if TemplateResource._first_or_404(to_delete).in_use():
                        raise Conflict(
                            'Cannot delete variable because it is in use')
                    to_delete.delete()

                for update in variables.get('update', ()):
                    # The schema defaults ``description`` to '', so this
                    # guard is expected to pass for every validated item.
                    if 'description' not in update:
                        continue
                    variable = TemplateResource._first_or_404(
                        session.query(Variable).filter(
                            Variable.id == update['id']))
                    variable.description = update['description'].strip()
                    if 'name' in update:
                        # Strip like the create path does: the name regex's
                        # ``$`` also accepts a trailing newline.
                        new_name = update['name'].strip()
                        if variable.name != new_name and variable.in_use():
                            raise Conflict(
                                'Cannot change name of a variable that\'s in use'
                            )
                        variable.name = new_name

                for create in variables.get('create', ()):
                    if 'description' in create:
                        session.add(
                            Variable(
                                template=template,
                                name=create['name'].strip(),
                                description=create['description'].strip()))

            if text is not missing:
                template.text = text

            if name is not missing:
                template.name = name

            return make_id_response(template_id)

    @staticmethod
    def delete(template_id):
        """Delete a template unless any of its variables is in use.

        Raises:
            NotFound: no template with ``template_id`` exists.
            Conflict: at least one of the template's variables is in use.
        """
        with make_session() as session:
            template = session.query(Template).filter(
                Template.id == template_id).first()  # type: Template
            if template is None:
                raise NotFound("Requested template does not exist")

            for variable in template.variables:
                if variable.in_use():
                    raise Conflict(
                        "Cannot delete the template because one or more variables are in use"
                    )

            session.delete(template)
            return make_empty_response()
示例#30
0
class StaffApplyView(MallBaseView):
    """Admin backend - staff: submit a staff application and fetch its info."""

    def get_tmp_class(self, status):
        """Return a placeholder application object carrying only ``status``."""
        class TMP:
            def __init__(self, value):
                self.status = value

        placeholder = TMP(status)
        return placeholder

    def get(self, request, shop_code):
        """Return the current user's staff application for the given shop.

        When no application record exists, a placeholder is serialized
        instead: status ``PASS`` for the shop's super admin, ``UNAPPlY`` for
        everyone else.
        """
        user = self.current_user
        self._set_current_shop(request, shop_code)
        shop = self.current_shop
        application = get_staff_apply_by_user_id_and_shop_id(user.id, shop.id)

        # No review record means the user is either the super admin or
        # someone who has not applied yet.
        if not application:
            is_super_admin = shop.super_admin_id == user.id
            placeholder_status = (StaffApplyStatus.PASS if is_super_admin else
                                  StaffApplyStatus.UNAPPlY)
            application = self.get_tmp_class(placeholder_status)

        serialized = StaffApplySerializer(application)
        return self.send_success(data=serialized.data,
                                 shop_info={"shop_name": shop.shop_name})

    @use_args(
        {
            "realname": fields.String(
                required=True,
                validate=[validate.Length(1, 64)],
                comment="真实姓名",
            ),
            "phone": fields.String(
                required=False,
                validate=[validate.Regexp(PHONE_RE)],
                comment="手机号,已绑定的时候是不需要的",
            ),
            "code": fields.String(
                required=False,
                validate=[validate.Regexp(r"^[0-9]{4}$")],
                comment="验证码",
            ),
            "birthday": fields.Date(
                required=False, allow_none=True, comment="生日"
            ),
        },
        location="json")
    def post(self, request, args, shop_code):
        """Create a staff application for the current user in the given shop.

        Fails when the user is already a staff member of the shop or has
        already submitted an application; responds with HTTP 400 and the
        serializer errors when the submitted data is invalid.
        """
        user = self.current_user
        self._set_current_shop(request, shop_code)
        shop = self.current_shop

        # Reject users who are already staff of this shop.
        if get_staff_by_user_id_and_shop_id(user.id, shop.id):
            return self.send_fail(error_text="已经为该店铺的员工")

        # Reject duplicate applications.
        if get_staff_apply_by_user_id_and_shop_id(user.id, shop.id):
            return self.send_fail(error_text="已提交申请,无需重复提交")

        serializer = StaffApplyCreateSerializer(data=args,
                                                context={'self': self})
        if not serializer.is_valid():
            return self.send_error(error_message=serializer.errors,
                                   status_code=status.HTTP_400_BAD_REQUEST)

        created = serializer.save()
        return self.send_success(
            data={
                "staff_apply_id": created.id,
                "status": created.status,
                "expired": created.expired,
                "user_id": created.user_id,
            })