示例#1
0
class AdminStaffListView(AdminBaseView):
    """Admin backend - staff - list the staff members of the current shop."""
    pagination_class = StandardResultsSetPagination

    @AdminBaseView.permission_required(
        [AdminBaseView.staff_permissions.ADMIN_STAFF])
    @use_args(
        {
            "keyword": fields.String(
                required=False, comment="搜索关键字(姓名或手机号)"),
            "page": fields.Integer(
                required=False, missing=1, comment="页码"),
            "page_size": fields.Integer(
                required=False,
                missing=20,
                validate=[validate.Range(1)],
                comment="每页条数"),
        },
        location="query")
    def get(self, request, args):
        """Return the current shop's staff, optionally filtered by ``keyword``.

        A positive ``page`` returns one page via ``_get_paginated_data``.
        Any other value (the front end sends -1) returns every match.
        That result is wrapped as ``{'results': [...]}`` so the payload keeps
        the shape the front end expects.
        """
        staff_members = list_staff_by_shop_id(
            self.current_shop.id, args.get("keyword"))
        if args.get("page") <= 0:
            # Unpaginated: mimic the paginated payload's "results" key.
            payload = {
                'results': StaffSerializer(staff_members, many=True).data
            }
        else:
            payload = self._get_paginated_data(staff_members, StaffSerializer)
        return self.send_success(data_list=payload)
示例#2
0
class EditableSchema(Schema):
    """Schema for an editable record: id, type, state, editor and revision info.

    ``id`` and ``state`` are required; the remaining fields are optional.
    """
    id = fields.Integer(required=True)
    type = fields.String()
    state = fields.String(required=True)
    # Nested EditingUserSchema payload; None is allowed (presumably when
    # nobody is editing the record -- confirm with callers).
    editor = fields.Nested(EditingUserSchema, allow_none=True)
    timeline_url = fields.String()
    revision_count = fields.Integer()
示例#3
0
class Publish(Resource):
    """Parse the ``data`` payload of a request and hand it to the sender."""

    publish_args = {
        'uuid': fields.String(required=True),
        'batch_size': fields.Integer(required=True),
        'offset': fields.Integer(required=True)
    }

    @use_kwargs(publish_args)
    def post(self, uuid, batch_size, offset):
        """Parse ``data`` from the JSON body and send it in batches.

        :param uuid: identifier passed to the parser and echoed back on success
        :param batch_size: batch size forwarded to ``sender.sendData``
        :param offset: offset forwarded to ``parser.parseData``
        :return: ``(uuid, 201)`` on success.
            ``(message, 400)`` when the body is not a JSON object with a
            ``data`` key.
            ``(str(ex), 500)`` when parsing or sending raises.
        """
        # Lazy %-style arguments: messages are only built when DEBUG is on.
        logging.debug("uuid = %s", uuid)
        logging.debug("batch_size = %d", batch_size)
        logging.debug("offset= %d", offset)

        # silent=True returns None instead of raising on a malformed body.
        # A bad request therefore becomes a 400 instead of an unhandled
        # KeyError/TypeError (500).
        raw_json = request.get_json(silent=True)
        if not isinstance(raw_json, dict) or "data" not in raw_json:
            return "request body must be a JSON object with a 'data' key", 400
        raw_data = raw_json["data"]
        logging.debug(raw_data)
        try:
            logging.debug("parsing data")
            parsed_data = parser.parseData(uuid, offset, raw_data)
            logging.debug("sending data")
            sender.sendData(parsed_data, batch_size)
        except Exception as ex:
            # Record the traceback; a debug-level message hid real failures.
            logging.exception("publishing %s failed", uuid)
            return str(ex), 500
        return uuid, 201
示例#4
0
class TrainArgsSchema(Schema):
    """Training arguments; every field is optional and has a default."""

    class Meta:
        unknown = INCLUDE  # support 'full_paths' parameter

    # available fields are e.g. fields.Integer(), fields.Str(), fields.Boolean()
    # full list of fields: https://marshmallow.readthedocs.io/en/stable/api_reference.html
    dataset = fields.Str(
        required=False,
        missing="run01",
        description="processed dataset which should be used for training")

    save_name = fields.Str(
        required=False,
        missing="gan",
        description="save name which should be used in predict")

    # NOTE(review): default points at a user-specific directory; consider
    # making it configurable.
    cache_path = fields.Str(required=False,
                            missing="/home/tmp/koepke/cache",
                            description="cache basepath for dataset")

    epochs = fields.Integer(required=False,
                            missing=1,
                            description="number of epochs for training")

    # The description used to duplicate the epochs text.
    batchsize = fields.Integer(required=False,
                               missing=1024,
                               description="batch size for training")
示例#5
0
class GetBlockDataByDateEndpoint(Resource):
    """
    Class implementing get block by date API
    """
    args = {"day": fields.Integer(),
            "month": fields.Integer(),
            "year": fields.Integer(),
            "date_offset": fields.Integer()
            }

    @use_kwargs(args)
    def get(self, year, month, day, date_offset):
        """
        Return the blocks whose ``ntime`` lies inside a date window.

        The window starts at midnight of ``year-month-day``. ``time.mktime``
        interprets that as local time. The window spans ``date_offset`` days,
        and both ends are inclusive.

        :param year: start year
        :param month: start month
        :param day: start day
        :param date_offset: window length in days
        :return: dict with ``ResponseCode``/``ResponseDesc`` and one of:
            ``ValidationErrors`` on bad input; ``FromDate``/``ToDate``/
            ``NumberOfBlocks``/``Blocks`` on success; ``ErrorMessage`` when
            nothing matches or an exception occurs.
        """
        try:
            # Validate User Input
            validations_result = validate_block_input(year, month, day, date_offset)
            if validations_result:
                return {"ResponseCode": ResponseCodes.InvalidRequestParameter.value,
                        "ResponseDesc": ResponseCodes.InvalidRequestParameter.name,
                        "ValidationErrors": validations_result}

            from_time = datetime(int(year), int(month), int(day))
            to_time = from_time + timedelta(days=int(date_offset))

            # get the unix time to form the query
            from_unixtime = time.mktime(from_time.timetuple())
            to_unixtime = time.mktime(to_time.timetuple())

            # Fetch once with .all(). The old len(list(query)) check ran
            # the query once and then iterated (re-ran) it again.
            blocks = db_session.query(Block).filter(
                and_(Block.ntime >= from_unixtime, Block.ntime <= to_unixtime)
            ).order_by(Block.ntime.asc()).all()
            if not blocks:
                return {"ResponseCode": ResponseCodes.NoDataFound.value,
                        "ResponseDesc": ResponseCodes.NoDataFound.name,
                        "ErrorMessage": ResponseDescriptions.NoDataFound.value}

            block_list = [serialize_block(block) for block in blocks]
            return {
                "ResponseCode": ResponseCodes.Success.value,
                "ResponseDesc": ResponseCodes.Success.name,
                "FromDate": from_time.strftime('%Y-%m-%d %H:%M:%S'),
                "ToDate": to_time.strftime('%Y-%m-%d %H:%M:%S'),
                "NumberOfBlocks": len(block_list),
                "Blocks": block_list}
        except Exception as ex:
            # No `return` inside `finally` any more: it silently discarded
            # every exception, including KeyboardInterrupt and SystemExit.
            return {"ResponseCode": ResponseCodes.InternalError.value,
                    "ResponseDesc": ResponseCodes.InternalError.name,
                    "ErrorMessage": str(ex)}
示例#6
0
class SuperUserPhoneView(SuperBaseView):
    """Super admin - user - change a user's phone number."""

    @use_args(
        {
            "sign": fields.String(required=True, comment="加密认证"),
            "timestamp": fields.Integer(required=True, comment="时间戳"),
            "user_id": fields.Integer(required=True, comment="用户ID"),
            "phone": fields.String(
                required=True,
                validate=[validate.Regexp(PHONE_RE)],
                comment="手机号",
            ),
            "sms_code": fields.String(required=True, comment="短信验证码"),
        },
        location="json")
    @SuperBaseView.validate_sign("sign", ("user_id", "timestamp"))
    def put(self, request, args):
        """Replace the current user's phone number after checking the SMS code.

        Responds 401 when no user is resolved from the request. Responds
        with a failure carrying the helper's message when either the SMS
        code check or the update fails.
        """
        target_user = self._get_current_user(request)
        if not target_user:
            return self.send_error(status_code=status.HTTP_401_UNAUTHORIZED,
                                   error_message={"error_text": "用户未登录"})
        new_phone = args["phone"]
        # The SMS code must pass before the number is changed.
        ok, reason = validate_sms_code(new_phone, args["sms_code"])
        if ok:
            ok, reason = update_user_phone(target_user, new_phone)
        if not ok:
            return self.send_fail(error_text=reason)
        return self.send_success()
示例#7
0
class OddsList(Resource):
    """Odds collection: filtered listing (GET) and creation (POST)."""

    odds_args = {
        'odds_name': fields.String(),
        'odds_type': fields.String(),
        'sport': fields.String(),
        'per_page': fields.Integer(),
        'page': fields.Integer(),
        'filter_name': fields.String(),
        'active': fields.Boolean(),
    }

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Body returned alongside 422 when a POST payload fails validation.
        self._default_response = dict(status='fail', message='Invalid payload.')
        self._create_odds_schema = CreateOddsSchema()

    @use_args(odds_args)
    @json_response
    def get(self, query_params):
        """Return the odds matching the parsed query arguments with 200."""
        matching_odds = get_odds_by(**query_params)
        return matching_odds, 200

    @json_response
    def post(self):
        """Create an Odds entity from the JSON body; 422 if it fails validation."""
        payload = request.get_json()
        validation_errors = self._create_odds_schema.validate(payload)
        if not validation_errors:
            return create_entity(Odds(**payload))
        return self._default_response, 422
示例#8
0
class LivingResource(ProtectedResource, GrampsJSONEncoder):
    """Living calculator resource."""

    @use_args(
        {
            "average_generation_gap": fields.Integer(
                missing=None, validate=validate.Range(min=1)),
            "max_age_probably_alive": fields.Integer(
                missing=None, validate=validate.Range(min=1)),
            "max_sibling_age_difference": fields.Integer(
                missing=None, validate=validate.Range(min=1)),
        },
        location="query",
    )
    def get(self, args: Dict, handle: Handle) -> Response:
        """Report whether the person identified by ``handle`` is probably alive.

        Aborts with 404 when the person lookup returns an empty dict. The
        three optional query parameters (None when omitted) are forwarded to
        ``probably_alive``.
        """
        db_handle = get_db_handle()
        person = get_person_by_handle(db_handle, handle)
        if person == {}:
            abort(404)

        overrides = {
            "max_sib_age_diff": args["max_sibling_age_difference"],
            "max_age_prob_alive": args["max_age_probably_alive"],
            "avg_generation_gap": args["average_generation_gap"],
        }
        is_living = probably_alive(person, db_handle, **overrides)
        return self.response(200, {"living": is_living})
示例#9
0
class TrainArgsSchema(Schema):
    """Training arguments; all optional, unknown keys are kept (INCLUDE)."""

    class Meta:
        unknown = INCLUDE  # supports extra parameters

    train_epochs = fields.Integer(
        description="Number of training epochs",
        missing=10,
        required=False,
    )

    batch_size = fields.Integer(
        description='Global Batch size',
        missing=16,
        required=False,
    )

    # Validated by gpus_must_exist (presumably checks that the requested
    # GPUs are available -- confirm).
    num_gpus = fields.Integer(
        description='Number of GPUs to use, if available (0 = CPU)',
        validate=gpus_must_exist,
        missing=1,
        required=False,
    )

    upload_back = fields.Boolean(
        description='Either upload a trained graph back to the remote storage (True) or not (False, default)',
        enum=[False, True],
        missing=False,
        required=False,
    )
示例#10
0
class SimilarAnime:
    """Falcon resource listing anime similar to a given title."""

    request_args = {
        'to': fields.Str(required=True),
        'by_genre': fields.Str(),
        # Integer defaults: marshmallow returns ``missing`` values as-is
        # (no deserialization). The old '10'/'0' strings therefore made
        # limit/offset str when omitted and int when supplied.
        'limit': fields.Integer(missing=10),
        'offset': fields.Integer(missing=0),
    }

    def __init__(self):
        self.resource = AnimeModel()

    @use_args(request_args)
    def on_get(self, req: Request, resp: Response, q_args: Dict) -> None:
        """Respond with anime similar to ``to``, optionally restricted by genre.

        Raises ``falcon.HTTPNotFound`` when the model returns no results.
        """
        anime_title = q_args['to']
        genre = q_args.get('by_genre')
        json_result: List[Dict] = self.resource.get_similar_anime(anime_title,
                                                                  q_args['limit'],
                                                                  q_args['offset'],
                                                                  genre=genre)

        if not json_result:
            # Only mention the genre when one was requested (was "by None").
            genre_part = f' by {genre}' if genre else ''
            raise falcon.HTTPNotFound(
                title='Animes not Found',
                description=f'No animes are similar to {anime_title}{genre_part}',
            )

        resp.json = json_result
        resp.status = falcon.HTTP_200
示例#11
0
class GeomagField(Resource):
    """Geomagnetic field for a position, altitude and date given as query args."""

    schema_args = {
        # accepts normalised [-90, 90] or unnormalised [0, 180]
        "lat": fields.Float(required=True,
                            validate=lambda dLat: -90 <= dLat <= 180),
        # accepts normalised [-180, 180] or unnormalised [0, 360]
        "lng": fields.Float(required=True,
                            validate=lambda dLng: -180 <= dLng <= 360),
        "altitude_km": fields.Float(required=True,
                                    validate=lambda km: -1 < km <= 850),
        "yr": fields.Integer(required=True,
                             validate=lambda y: 2020 <= y < 2025),
        "mth": fields.Integer(required=True,
                              validate=lambda m: 1 <= m <= 12),
        # Inclusive upper bound: the old `d < 31` rejected every 31st.
        # Month/day combinations are checked in get().
        "day": fields.Integer(required=True,
                              validate=lambda d: 1 <= d <= 31),
    }

    @use_kwargs(schema_args, location="query")
    def get(self, lat, lng, altitude_km, yr, mth, day):
        """Return ``calculate_field`` for the given point and date.

        Returns ``({"message": ...}, 400)`` when the year/month/day triple is
        not a real calendar date (e.g. 2021-02-30). Each part passes its own
        range check, but ``datetime.date`` raises ValueError.
        NOTE(review): the tuple return assumes a Flask-RESTful-style Resource.
        """
        try:
            requested_date = datetime.date(yr, mth, day)
        except ValueError as err:
            return {"message": str(err)}, 400
        return calculate_field(lat, lng, altitude_km, requested_date)
示例#12
0
class TrainArgsSchemaPro(Schema):
    """Benchmark training arguments (pro variant); every field is optional."""

    class Meta:
        unknown = INCLUDE  # supports extra parameters

    # Multi-line descriptions use implicit string concatenation. The former
    # backslash continuations *inside* the literals copied each following
    # line's indentation (runs of ~25 spaces) into the description text.
    batch_size_per_device = fields.Integer(
        missing=64, description='Batch size for each GPU.', required=False)
    dataset = fields.Str(
        missing='synthetic_data',
        enum=['synthetic_data', 'imagenet', 'imagenet_mini', 'cifar10'],
        description=('Dataset to perform training on. '
                     'synthetic_data: randomly generated ImageNet-like '
                     'images; imagenet_mini: 3% of the real ImageNet '
                     'dataset'),
        required=False)
    model = fields.Str(missing='resnet50 (ImageNet)',
                       enum=[
                           'googlenet (ImageNet)', 'inception3 (ImageNet)',
                           'mobilenet (ImageNet)', 'overfeat (ImageNet)',
                           'resnet50 (ImageNet)', 'resnet152 (ImageNet)',
                           'vgg16 (ImageNet)', 'vgg19 (ImageNet)',
                           'resnet56 (Cifar10)', 'resnet110 (Cifar10)',
                           'alexnet (ImageNet, Cifar10)'
                       ],
                       description=('CNN model for training. N.B. Models only '
                                    'support specific data sets, given in brackets. '
                                    'synthetic_data can only be processed by '
                                    'ImageNet models.'),
                       required=False)
    num_gpus = fields.Integer(missing=1,
                              description=('Number of GPUs to train on '
                                           '(one node only). If set to zero, '
                                           'CPU is used.'),
                              required=False)
    num_epochs = fields.Float(missing=NUM_EPOCHS,
                              description=('Number of epochs to train on '
                                           '(float value, < 1.0 allowed).'),
                              required=False)
    optimizer = fields.Str(missing='sgd',
                           enum=['sgd', 'momentum', 'rmsprop', 'adam'],
                           description='Optimizer to use.',
                           required=False)
    use_fp16 = fields.Boolean(missing=False,
                              enum=[False, True],
                              description=('Use 16-bit floats for certain '
                                           'tensors instead of 32-bit floats.'),
                              required=False)
    weight_decay = fields.Float(missing=4.0e-5,
                                description='Weight decay factor for training',
                                required=False)
    evaluation = fields.Boolean(missing=True,
                                enum=[False, True],
                                description=('Perform evaluation after the '
                                             'benchmark in order to get accuracy '
                                             'results (only meaningful on real '
                                             'data sets!).'),
                                required=False)
    if_cleanup = fields.Boolean(missing=False,
                                enum=[False, True],
                                description=('If to delete training and '
                                             'evaluation directories.'),
                                required=False)
示例#13
0
class RHAPIRegistrationTagsAssign(RHRegistrationsActionBase):
    """Internal API to assign and remove registration tags."""

    @use_kwargs({
        # ``list`` is a callable default: marshmallow calls it for each load.
        # Every request therefore gets a fresh empty list instead of sharing
        # one list object across all requests.
        'add': fields.List(fields.Integer(), load_default=list),
        'remove': fields.List(fields.Integer(), load_default=list),
    })
    def _process_POST(self, add, remove):
        """Apply the tag additions/removals to ``self.registrations``; reply 204."""
        _assign_registration_tags(self.event, self.registrations, add, remove)
        return '', 204
示例#14
0
文件: schema.py 项目: dallanb/course
class DumpHoleSchema(Schema):
    """Schema for a hole record: identity, parent course, timestamps and layout."""
    uuid = fields.UUID()
    # UUID of the course this hole belongs to
    course_uuid = fields.UUID()
    # Integer timestamps (presumably creation / modification time -- confirm)
    ctime = fields.Integer()
    mtime = fields.Integer()
    name = fields.String()
    number = fields.Integer()
    par = fields.Integer()
    # Integer distance; unit not shown here (presumably yards/metres -- confirm)
    distance = fields.Integer()
示例#15
0
class SuperUserView(SuperBaseView):
    """Super admin - user - fetch user details & update basic user info."""

    @use_args(
        {
            "sign": fields.String(required=True, comment="加密认证"),
            "timestamp": fields.Integer(required=True, comment="时间戳"),
            "user_id": fields.Integer(required=True, comment="用户ID"),
        },
        location="query")
    @SuperBaseView.validate_sign("sign", ("user_id", "timestamp"))
    def get(self, request, args):
        """Return the current user via SuperUserSerializer, or 401 if none."""
        user = self._get_current_user(request)
        if not user:
            return self.send_error(status_code=status.HTTP_401_UNAUTHORIZED,
                                   error_message={"error_text": "用户未登录"})
        serializer = SuperUserSerializer(user)
        return self.send_success(data=serializer.data)

    @use_args({
        "sign":
        fields.String(required=True, comment="加密认证"),
        "timestamp":
        fields.Integer(required=True, comment="时间戳"),
        "user_id":
        fields.Integer(required=True, comment="用户ID"),
        "nickname":
        fields.String(required=False,
                      validate=[validate.Length(1, 15)],
                      comment="用户昵称"),
        "realname":
        fields.String(required=False,
                      validate=[validate.Length(1, 15)],
                      comment="用户真实姓名"),
        "sex":
        fields.Integer(
            required=False,
            validate=[validate.OneOf([Sex.UNKNOWN, Sex.FEMALE, Sex.MALE])],
        ),
        "birthday":
        fields.Date(required=False, comment="出生日期"),
        "head_image_url":
        fields.String(required=False,
                      validate=[validate.Length(0, 1024)],
                      comment="用户头像")
    })
    @SuperBaseView.validate_sign("sign", ("user_id", "timestamp"))
    def put(self, request, args):
        """Update the current user's basic profile fields.

        Responds 401 when no user is resolved. Responds with a "参数有误"
        failure when none of the optional profile fields (nickname, realname,
        sex, birthday, head_image_url) was sent.
        """
        user = self._get_current_user(request)
        if not user:
            return self.send_error(status_code=status.HTTP_401_UNAUTHORIZED,
                                   error_message={"error_text": "用户未登录"})
        # sign/timestamp/user_id are required, so ``args`` is never empty
        # and the old `if not args` could not fire. Check the editable
        # fields instead.
        profile_fields = {
            key: value
            for key, value in args.items()
            if key not in ("sign", "timestamp", "user_id")
        }
        if not profile_fields:
            return self.send_fail(error_text="参数有误")
        user = update_user_basic_data(user, args)
        serializer = UserSerializer(user)
        return self.send_success(data=serializer.data)
示例#16
0
class PlayersList(HTTPEndpoint):
    """Paginated, searchable player listing."""

    @use_args({
        # Bounds are enforced with validators. marshmallow stores unknown
        # ``min=``/``max=`` keyword arguments as metadata and never checks them.
        "limit": fields.Integer(missing=25, validate=lambda v: 1 <= v <= 50),
        # Start from the first row by default; the old default of 25
        # silently skipped the first page of players.
        "offset": fields.Integer(missing=0, validate=lambda v: v >= 0),
        "search": fields.String(),
        "desc": fields.Bool(missing=True),
    })
    async def get(self, request, args):
        """ List players. """

        return responder.render(
            await request.state.league.list(**args).players()
        )
示例#17
0
def test_lists(location):
    """Scalar and list arguments are parsed to ints from each location."""
    def raw_payload():
        # A fresh dict (and inner list) per request, as the original literals gave.
        return {'single': '123', 'multi': ['456', '789']}

    requests_by_location = {
        'query': MockRequest(args=raw_payload()),
        'form': MockRequest(form=raw_payload()),
        'json': MockRequest(json={'single': 123, 'multi': [456, 789]}),
    }
    chosen_request = requests_by_location[location]
    argmap = {
        'single': fields.Integer(),
        'multi': fields.List(fields.Integer()),
    }
    parsed = make_decorated_func(use_kwargs, argmap, chosen_request,
                                 location=location)
    expected = {'single': 123, 'multi': [456, 789]}
    assert parsed() == expected
示例#18
0
class BookSchema(ma.Schema):
    """Book schema whose dump is enriched with the book's related rows.

    After dumping, the rows linked to the dumped ``id`` are added under
    ``authors``, ``languages``, ``bookshelfs``, ``subjects`` and
    ``download_links``.
    """
    title = fields.String(dump_only=True)
    id = fields.Integer(dump_only=True)
    gutenberg_id = fields.Integer()
    download_count = fields.Integer()

    @staticmethod
    def _dump_via_link(link_target_col, link_book_col, target_table,
                       schema_cls, book_id):
        """Dump the ``target_table`` rows tied to ``book_id`` through a link table.

        Runs two queries: first the linked ids from the association table,
        then the target rows whose ``id`` is among them.
        """
        linked_ids = [
            row[0]
            for row in db.session.query(link_target_col).filter(
                link_book_col == book_id).all()
        ]
        linked_rows = db.session.query(target_table).filter(
            target_table.c.id.in_(linked_ids)).all()
        return schema_cls().dump(linked_rows, many=True)

    @post_dump()
    def dump_related_data(self, data, many):
        book_id = data['id']
        data['authors'] = self._dump_via_link(
            m.book_authors.c.author_id, m.book_authors.c.book_id,
            m.author, AuthorSchema, book_id)
        data['languages'] = self._dump_via_link(
            m.book_languages.c.language_id, m.book_languages.c.book_id,
            m.language, LanguageSchema, book_id)
        data['bookshelfs'] = self._dump_via_link(
            m.book_bookshelves.c.bookshelf_id, m.book_bookshelves.c.book_id,
            m.bookshelf, BookShelfSchema, book_id)
        data['subjects'] = self._dump_via_link(
            m.book_subjects.c.subject_id, m.book_subjects.c.book_id,
            m.subject, SubjectSchema, book_id)

        # Formats reference the book directly, so one query suffices.
        format_rows = db.session.query(m.book_format).filter(
            m.book_format.c.book_id == book_id).all()
        data['download_links'] = FormatSchema().dump(format_rows, many=True)
        return data
示例#19
0
class BaseReadArgsSchema(ma.Schema):
    """
    Base webargs schema for shared read options.

    It covers pagination (``page``/``per_page``), sorting (``sort_by`` and
    an ``asc``/``desc`` direction), and an ``and``/``or`` ``operator``
    (presumably used to combine filters -- confirm in subclasses).
    """
    page = fields.Integer(missing=1)
    per_page = fields.Integer(missing=25)
    sort_by = fields.List(fields.String(), missing=['download_count'])
    sort = fields.String(missing='desc',
                         validate=validate.OneOf(['asc', 'desc']))
    operator = fields.String(missing='or',
                             validate=validate.OneOf(['and', 'or']))
示例#20
0
文件: list.py 项目: ttibau/API
class MatchesList(HTTPEndpoint):
    """Paginated, searchable match listing."""

    @use_args({
        # Bounds are enforced with validators. marshmallow stores unknown
        # ``min=``/``max=`` keyword arguments as metadata and never checks them.
        "limit": fields.Integer(missing=25, validate=lambda v: 1 <= v <= 50),
        # Offset 0 is the first row. The old ``min=1`` contradicted the
        # default of 0.
        "offset": fields.Integer(missing=0, validate=lambda v: v >= 0),
        "search": fields.String(),
        "desc": fields.Bool(missing=True),
    })
    async def get(self, request, args):
        """ Gets list of matches. """

        return Responder(await
                         request.state.league.list(**args).matches()).json()
示例#21
0
class DumpSportSchema(Schema):
    """Schema for a sport record: uuid, integer timestamps and name."""
    uuid = fields.UUID()
    # Integer timestamps (presumably creation / modification time -- confirm)
    ctime = fields.Integer()
    mtime = fields.Integer()
    name = fields.String()

    def get_attribute(self, obj, attr, default):
        # Overrides the schema's attribute accessor with a plain getattr
        # lookup, returning ``default`` when the attribute is absent.
        # NOTE(review): marshmallow's default accessor also reads dict keys
        # and dotted paths; confirm only objects are ever dumped here.
        return getattr(obj, attr, default)

    @post_dump
    def make_obj(self, data, **kwargs):
        # Post-dump hook that currently returns the serialized data unchanged.
        return data
示例#22
0
class Data(Resource):
    """ Market data at an instance in time """

    @use_kwargs({'id': fields.Integer(missing=None)})
    @validate_db(db)
    def get(self, id):
        """Return every row when ``id`` is omitted, else that row (404 if absent)."""
        if id is None:
            return [query_to_dict(q) for q in DataModel.query.all()]
        return query_to_dict(DataModel.query.get_or_404(id))

    @use_kwargs(data_kwargs)
    @validate_db(db)
    def post(self, id, low, high, close, volume):
        """Insert a new row; abort with 422 if it cannot be stored."""
        try:
            post_request = DataModel(id, low, high, close, volume)
            db.session.add(post_request)
            db.session.commit()
        except Exception:  # e.g. ID already exists -- use PUT
            # Discard the failed transaction so the session stays usable.
            db.session.rollback()
            abort(HTTP_CODES.UNPROCESSABLE_ENTITY)
        else:
            return query_to_dict(post_request)

    @use_kwargs(data_kwargs)
    @validate_db(db)
    def put(self, id, low, high, close, volume):
        """ Change only the fields that were specified (404 if ``id`` is absent).
        NOTE: Arg values of -1 clears since each must be >= 0 to be valid
        """
        query = DataModel.query.get_or_404(id)
        # Explicit field map. The old locals() scan compared names with
        # `is not` (identity, not equality) and also picked up the local
        # `query`, assigning query.query = query.
        updates = {'low': low, 'high': high, 'close': close, 'volume': volume}
        for field_name, value in updates.items():
            if value is None:
                continue
            setattr(query, field_name, None if value == -1 else value)
        db.session.commit()
        return query_to_dict(query)

    @use_kwargs({'id': fields.Integer(missing=None)})
    @validate_db(db)
    def delete(self, id):
        """Delete one row (404 if absent), or every row when ``id`` is omitted."""
        # Look the row up outside the try. A missing id now yields 404
        # instead of being swallowed into {'status': 'failed'}.
        target = None if id is None else DataModel.query.get_or_404(id)
        try:
            if target is None:
                DataModel.query.delete()
            else:
                db.session.delete(target)
            db.session.commit()
        except Exception:
            db.session.rollback()
            return {'status': 'failed'}
        else:
            return {'status': 'successful'}
示例#23
0
def search_args():
    """Return the webargs field map used to parse index search parameters.

    Optional filters default to None (``search``, ``team_id``) or are
    omitted. Paging defaults to ``offset=0`` / ``limit=12``. ``types`` is
    read from the ``type`` parameter and defaults to ``"image"``.
    """
    return dict(
        search=fields.String(missing=None),
        team_id=fields.UUID(missing=None),
        # NOTE(review): ``load_from`` is the marshmallow 2 spelling;
        # marshmallow 3 renamed it ``data_key`` -- confirm the installed version.
        types=fields.String(load_from="type", missing="image"),
        pipeline=fields.Integer(),
        start_date=fields.DateTime(),
        end_date=fields.DateTime(),
        offset=fields.Integer(missing=0),
        limit=fields.Integer(missing=12),
        notify_clients=fields.Boolean(missing=False),
    )
示例#24
0
class ManageRequest(restful.Resource):
    """Update the state of a callup request."""

    @use_args(
        {
            'id': fields.Integer(required=True),
            'state': fields.Integer(required=True),
        },
        location='json')
    def post(self, args):
        """Set ``callup_request.state`` for the row ``id`` and commit."""
        cursor = g.db.cursor()
        # Placeholders keep the request values out of the SQL text.
        statement = "update callup_request set state = ? where id = ?"
        cursor.execute(statement, (args['state'], args['id']))
        g.db.commit()
        return {'result': 'success'}
示例#25
0
class SuperShopVerifyView(UserBaseView):
    """Super admin - update a shop's verification status.

    NOTE(review): unlike the other super-admin views this inherits
    UserBaseView while using SuperBaseView.validate_sign -- confirm intended.
    """

    @use_args(
        {
            "sign": fields.String(required=True, comment="加密认证"),
            "timestamp": fields.Integer(required=True, comment="时间戳"),
            "user_id": fields.Integer(required=True, comment="用户ID"),
            "shop_id": fields.Integer(
                required=True,
                validate=[validate.Range(1)],
                comment="店铺ID",
            ),
            "verify_status": fields.Integer(
                required=True,
                validate=[
                    validate.OneOf([
                        ShopVerifyActive.YES,
                        ShopVerifyActive.CHECKING,
                        ShopVerifyActive.REJECTED,
                    ])
                ],
                comment="店铺认证状态",
            ),
            "verify_type": fields.Integer(
                required=True,
                validate=[
                    validate.OneOf([
                        ShopVerifyType.ENTERPRISE,
                        ShopVerifyType.INDIVIDUAL,
                    ])
                ],
                comment="店铺认证类型,个人/企业",
            ),
            "verify_content": fields.String(
                required=True,
                validate=[validate.Length(0, 200)],
                comment="认证内容",
            ),
        },
        location="json")
    @SuperBaseView.validate_sign("sign", ("user_id", "timestamp"))
    def put(self, request, args):
        """Apply the verification fields in ``args`` to the shop ``shop_id``.

        Fails when the shop does not exist. Responds 400 with the
        serializer's errors when the data is rejected.
        """
        # shop_id picks the shop; the remaining args feed the serializer.
        shop = get_shop_by_shop_id(args.pop("shop_id"))
        if not shop:
            return self.send_fail(error_text="店铺不存在")
        serializer = SuperShopVerifySerializer(shop, data=args)
        if serializer.is_valid():
            serializer.save()
            return self.send_success()
        return self.send_error(error_message=serializer.errors,
                               status_code=status.HTTP_400_BAD_REQUEST)
示例#26
0
def _make_request_schema(require_all: bool = False) -> dict:
    """Build the field map expected in a request body or query.

    Args:
        require_all (bool): when True every field is marked required;
            otherwise every field is optional.

    Returns:
        dict: field name -> webargs field. ``departure_date`` is a string
        field; all other entries are integer fields.
    """
    integer_keys = ('capacity', 'time_range_id', 'driver_id',
                    'start_location_id', 'destination_id')
    schema = {
        'departure_date': webargs_fields.String(required=require_all),  # pylint: disable=E1101
    }
    for key in integer_keys:
        schema[key] = webargs_fields.Integer(required=require_all)  # pylint: disable=E1101
    return schema
示例#27
0
class AdminOrderRefundView(AdminBaseView):
    """Admin backend - orders - refund an order."""

    @AdminBaseView.permission_required(
        [AdminBaseView.staff_permissions.ADMIN_ORDER])
    @use_args(
        {
            "order_id":
            fields.Integer(
                required=True, validate=[validate.Range(1)], comment="订单ID"),
            "refund_type":
            fields.Integer(
                required=True,
                validate=[
                    validate.OneOf([
                        OrderRefundType.WEIXIN_JSAPI_REFUND,
                        OrderRefundType.UNDERLINE_REFUND,
                    ])
                ],
            ),
        },
        location="json",
    )
    def post(self, request, args):
        """Refund one of the current shop's orders.

        Fails in any of these cases:
        - the order does not exist;
        - its status is not PAID/CONFIRMED/FINISHED/REFUND_FAIL;
        - a cash-on-delivery order is refunded via WeChat JSAPI;
        - ``refund_order`` reports failure.
        """
        shop_id = self.current_shop.id
        order = get_order_by_shop_id_and_id(shop_id, args.get("order_id"))
        if not order:
            return self.send_fail(error_text="订单不存在")
        if order.order_status not in (
                OrderStatus.PAID,
                OrderStatus.CONFIRMED,
                OrderStatus.FINISHED,
                OrderStatus.REFUND_FAIL,
        ):
            return self.send_fail(error_text="订单状态已改变")
        refund_type = args["refund_type"]
        # Cash-on-delivery orders can only be refunded offline.
        if (order.pay_type == OrderPayType.ON_DELIVERY
                and refund_type == OrderRefundType.WEIXIN_JSAPI_REFUND):
            return self.send_fail(error_text="货到付款的订单只能进行线下退款")
        success, msg = refund_order(
            shop_id,
            order,
            refund_type,
            self.current_user.id,
        )
        if not success:
            return self.send_fail(error_obj=msg)
        # TODO: send the WeChat refund template message once re-enabled:
        # load get_msg_notify_by_shop_id_interface(shop_id). When its
        # order_refund_wx is set and order.pay_type is WEIXIN_JSAPI, call
        # order_refund_tplmsg_interface(order.id). (The settings lookup used
        # to run here with its result unused.)
        return self.send_success()
示例#28
0
class Prediction(Resource):
    """Predict whether an employee will leave the company."""

    predictionRequest = {
        'Work_accident': fields.Integer(required=True),
        'time_spend_company': fields.Integer(required=True),
        'salary': fields.Integer(required=True),
        'number_project': fields.Integer(required=True),
        # Fractional scores (e.g. 0.38). Integer truncated them to 0 or
        # rejected them; Float still accepts whole numbers.
        'satisfaction_level': fields.Float(required=True),
        'last_evaluation': fields.Float(required=True),
        'average_montly_hours': fields.Integer(required=True),
        'promotion_last_5years': fields.Integer(required=True),
    }

    def get(self):
        """Placeholder GET; returns an empty object."""
        return {}

    @use_args(predictionRequest)
    def post(self, request):
        """Run the predictor on the parsed features.

        Returns a message plus the raw prediction; a result of 1 means the
        employee is predicted to leave.
        """
        # Build the feature vector in the declared field order explicitly,
        # rather than relying on the parsed dict's iteration order.
        to_predict_list = [float(request[key]) for key in self.predictionRequest]
        result = PredictorHelper.predictionValue(to_predict_list)[0]

        if int(result) == 1:
            message = 'El empleado va a dejar la empresa'
        else:
            message = 'El empleado no va a dejar la empresa'
        return {
            'messageResult': message,
            'predictionResult': result
        }
示例#29
0
class AdminProductGroupsView(AdminBaseView):
    """Admin backend - products - batch-assign products to a product group."""

    @AdminBaseView.permission_required([AdminBaseView.staff_permissions.ADMIN_PRODUCT])
    @use_args(
        {
            "product_ids": fields.List(
                fields.Integer(required=True),
                required=True,
                validate=[validate.Length(1)],
                comment="货品ID列表",
            ),
            "group_id": fields.Integer(required=True, comment="货品分组ID"),
        },
        location="json"
    )
    def put(self, request, args):
        """Move ``product_ids`` into group ``group_id`` of the current shop."""
        shop_id = self.current_shop.id
        target_group_id = args.get("group_id")
        ids_to_move = args.pop("product_ids")
        # The target group must belong to the current shop.
        if not get_product_group_by_shop_id_and_id(shop_id, target_group_id):
            return self.send_fail(error_text="货品分组不存在")
        # NOTE(review): the products' shop is not checked in this view --
        # confirm update_product_product_group_by_ids restricts to this shop.
        update_product_product_group_by_ids(ids_to_move, target_group_id)
        return self.send_success()

    @AdminBaseView.permission_required([AdminBaseView.staff_permissions.ADMIN_PRODUCT])
    @use_args(
        {
            "status": StrToList(
                required=False,
                missing=[ProductStatus.ON, ProductStatus.OFF],
                validate=[
                    validate.ContainsOnly(
                        [ProductStatus.ON, ProductStatus.OFF]
                    )
                ],
                comment="货品状态,上架/下架",
            )
        },
        location="query"
    )
    def get(self, request, args):
        """List the shop's product groups with product counts.

        The parsed ``status`` list is forwarded as a keyword argument.
        """
        groups_with_count = list_product_group_with_product_count(
            self.current_shop.id, **args)
        return self.send_success(
            data_list=AdminProductGroupSerializer(groups_with_count, many=True).data)
示例#30
0
class CategoryPostListResource(TokenRequiredResource):
    """Paginated, filterable list of the posts in one category."""

    get_args = {
        "title": fields.String(allow_none=True, validate=lambda x: 0 <= len(x) <= 255),
        "slug": fields.String(allow_none=True, validate=lambda x: 0 <= len(x) <= 255),
        "author_id": fields.Integer(allow_none=True, validate=lambda x: x > 0),
        "created_at": fields.DateTime(allow_none=True, format="iso8601"),
    }

    @use_args(get_args)
    def get(self, query_args, id):
        """Return a page of category ``id``'s posts matching the filters (404 if no category).

        Filters:
        - ``title`` and ``slug``: substring (LIKE) matches;
        - ``author_id`` and ``created_at``: exact matches.
        """
        # get_or_404: the old get() returned None for an unknown id and
        # then failed with AttributeError (HTTP 500).
        category = Category.query.get_or_404(id)

        filters = []
        # NOTE(review): "main" is not declared in get_args, so webargs
        # presumably never passes it and this filter is dead. Declare the
        # field (and check that PostCategory is joined) or drop the branch.
        if "main" in query_args:
            filters.append(PostCategory.primary == query_args["main"])
        if "title" in query_args:
            filters.append(Post.title.like("%{filter}%".format(filter=query_args["title"])))
        if "slug" in query_args:
            filters.append(Post.slug.like("%{filter}%".format(filter=query_args["slug"])))
        if "author_id" in query_args:
            filters.append(Post.author_id == query_args["author_id"])
        if "created_at" in query_args:
            filters.append(Post.created_at == query_args["created_at"])

        pagination_helper = PaginationHelper(
            request,
            query=category.posts.filter(*filters),
            resource_for_url="api.category_posts",
            key_name="results",
            schema=post_schema,
            url_parameters={"id": id},
            query_args=query_args,
        )
        return pagination_helper.paginate_query()