예제 #1
0
class BaseHTTPService(BaseService):
    """Service whose HTTP interface is a Flask app built from its own methods.

    During :meth:`setup_server`, every attribute whose name starts with
    ``route_`` is registered as a view under that name, using the URL rule
    stored in its ``rule`` attribute, and every ``error_<code>`` method is
    registered as the handler for HTTP status ``<code>``.
    """

    # Passed to Flask as ``template_folder``. NOTE(review): Flask's own
    # default is "templates" and it treats None as "no template folder", so
    # leaving this unset disables template loading - confirm that is intended.
    template_folder = None
    # Import name of the Flask app; the class name is used when unset.
    application_name = None

    def setup_server(self):
        """Create ``self.app`` and attach the ``route_*``/``error_*`` handlers."""
        super(BaseHTTPService, self).setup_server()

        # Create a Flask application for this service using the service name
        self.app = Flask(self.application_name or self.__class__.__name__, template_folder=self.template_folder or None)

        methods = dir(self)
        # Attach each route in the class; each route_* method carries its URL
        # rule in a ``rule`` attribute and is registered under its own name.
        for name in [x for x in methods if x.startswith("route_")]:
            method = getattr(self, name)
            self.app.add_url_rule(method.rule, name, method)
            self.logger.debug("Adding handler '%s' for '%s' rule", name, method.rule)

        # Attach error handlers in the class. The status code is the second
        # "_"-separated part of the name ("error_404" -> 404).
        for name in [x for x in methods if x.startswith("error_")]:
            method = getattr(self, name)
            code = int(name.split("_", 2)[1])
            # Use the public API rather than writing into error_handler_spec:
            # that dict is a Flask internal whose layout changed after 0.10
            # (handlers are now stored per exception class), so a bare
            # code -> function entry no longer works.
            self.app.register_error_handler(code, method)
            self.logger.debug("Adding handler '%s' for error code '%d'", name, code)

    def run(self):
        """Serve ``self.app`` on ``listener_address``:``listener_port``.

        A KeyboardInterrupt is logged and followed by :meth:`stop`.
        """
        self.logger.debug("Waiting for clients")
        try:
            self.app.run(self.listener_address, self.listener_port)
        except KeyboardInterrupt:
            self.logger.warning("Canceled by the user")
            self.stop()

    def stop(self):
        """Log the shutdown; no other cleanup is performed here."""
        self.logger.debug("Stopping server")
예제 #2
0
파일: __init__.py 프로젝트: hekun97/MIS
def create_app(test_config=None):
    """Application factory: create, configure and return the Flask app.

    :param test_config: optional mapping of settings. When given, it is
        applied on top of the defaults and ``config.py`` from the instance
        folder is not read.
    :returns: the configured :class:`Flask` instance.
    """
    app = Flask(__name__, instance_relative_config=True)
    defaults = {
        'SECRET_KEY': 'dev',
        # path and file name of the SQLite database (inside the instance folder)
        'DATABASE': os.path.join(app.instance_path, 'flaskr.sqlite'),
    }
    app.config.from_mapping(defaults)

    if test_config is not None:
        # testing: use the supplied settings
        app.config.from_mapping(test_config)
    else:
        # not testing: load config.py from the instance folder when it exists
        app.config.from_pyfile('config.py', silent=True)

    # Make sure the instance folder exists; any OSError (for example because
    # it already exists) is ignored.
    try:
        os.makedirs(app.instance_path)
    except OSError:
        pass

    # authentication blueprint (auth.bp)
    from . import auth
    app.register_blueprint(auth.bp)

    # system blueprint (system.bp); '/' is then bound to the endpoint name
    # 'index' (no view function is attached by this call)
    from . import system
    app.register_blueprint(system.bp)
    app.add_url_rule('/', endpoint='index')

    # user-facing blueprints
    init_user(app=app)
    # administrator blueprints
    init_admin(app=app)

    # database helpers from db.py
    from . import db
    db.init_app(app)

    return app
예제 #3
0
    def register(self, app: Flask) -> None:
        """Register the ``/scan/<int:washer_id>`` view (GET and POST) on *app*.

        GET renders ``scan.html``. POST expects a JSON object whose ``state``
        names a ``WasherState`` member:

        * ``RUNNING`` also needs a ``cycle`` key of ``self.CYCLES``; the value
          ``time() + self.CYCLES[cycle]`` is stored in
          ``self.washers[washer_id]``.
        * ``EMPTY`` stores the sentinel ``-1``.
        * any other state changes nothing.

        A body that is not an object, lacks those keys, or names an unknown
        state/cycle gets a 400 response instead of an unhandled error.
        """
        def scan(washer_id: int) -> 'str | tuple[str, int]':
            if request.method != 'POST':  # GET
                return render_template('scan.html')

            data = request.get_json()
            try:
                new_state: WasherState = WasherState[data['state']]
                cycle_length = (self.CYCLES[data['cycle']]
                                if new_state == WasherState.RUNNING else None)
            except (KeyError, TypeError):
                # Missing/unknown "state" or "cycle", or a body that is not a
                # JSON object (None or a list, for example).
                return 'invalid request', 400

            if new_state == WasherState.RUNNING:
                # time() returns a float, so the stored end time is a float.
                end_time: float = time() + cycle_length
                self.washers[washer_id] = end_time
            elif new_state == WasherState.EMPTY:
                # set to magic "empty" number
                self.washers[washer_id] = -1

            return 'ok'

        app.add_url_rule('/scan/<int:washer_id>',
                         'scan',
                         scan,
                         methods=['GET', 'POST'])
예제 #4
0
class BaseHTTPService(BaseService):
    """Service whose HTTP interface is a Flask app built from its own methods.

    During :meth:`setup_server`, every attribute whose name starts with
    ``route_`` is registered as a view under that name, using the URL rule
    stored in its ``rule`` attribute, and every ``error_<code>`` method is
    registered as the handler for HTTP status ``<code>``.
    """

    # Passed to Flask as ``template_folder``. NOTE(review): Flask's own
    # default is "templates" and it treats None as "no template folder", so
    # leaving this unset disables template loading - confirm that is intended.
    template_folder = None
    # Import name of the Flask app; the class name is used when unset.
    application_name = None

    def setup_server(self):
        """Create ``self.app`` and attach the ``route_*``/``error_*`` handlers."""
        super(BaseHTTPService, self).setup_server()

        # Create a Flask application for this service using the service name
        self.app = Flask(self.application_name or self.__class__.__name__,
                         template_folder=self.template_folder or None)

        methods = dir(self)
        # Attach each route in the class; each route_* method carries its URL
        # rule in a ``rule`` attribute and is registered under its own name.
        for name in [x for x in methods if x.startswith("route_")]:
            method = getattr(self, name)
            self.app.add_url_rule(method.rule, name, method)
            self.logger.debug("Adding handler '%s' for '%s' rule", name,
                              method.rule)

        # Attach error handlers in the class. The status code is the second
        # "_"-separated part of the name ("error_404" -> 404).
        for name in [x for x in methods if x.startswith("error_")]:
            method = getattr(self, name)
            code = int(name.split("_", 2)[1])
            # Use the public API rather than writing into error_handler_spec:
            # that dict is a Flask internal whose layout changed after 0.10
            # (handlers are now stored per exception class), so a bare
            # code -> function entry no longer works.
            self.app.register_error_handler(code, method)
            self.logger.debug("Adding handler '%s' for error code '%d'", name,
                              code)

    def run(self):
        """Serve ``self.app`` on ``listener_address``:``listener_port``.

        A KeyboardInterrupt is logged and followed by :meth:`stop`.
        """
        self.logger.debug("Waiting for clients")
        try:
            self.app.run(self.listener_address, self.listener_port)
        except KeyboardInterrupt:
            self.logger.warning("Canceled by the user")
            self.stop()

    def stop(self):
        """Log the shutdown; no other cleanup is performed here."""
        self.logger.debug("Stopping server")
예제 #5
0
class SimpleFlaskAppTest(unittest.TestCase):
    """Serializer, resource and router tests on a Flask app backed by MongoEngine.

    ``setUp`` builds a fresh Flask app and test client, drops the ``test``
    database and seeds a ``TestCol`` collection with two documents.
    """

    def setUp(self):
        self.app = Flask(__name__)

        self.client = self.app.test_client()

        self.db = MongoEngine(self.app)

        with self.app.app_context():
            self.db.connection.drop_database("test")

        class TestCol(db.Document):
            value = db.StringField()

            def __unicode__(self):
                return "TestCol(value={})".format(self.value)

        TestCol.objects.delete()

        TestCol.objects.create(value="1")
        TestCol.objects.create(value="2")

        self.TestCol = TestCol

    def _parse(self, resp):
        """Decode a UTF-8 response body and parse it as JSON."""
        resp = resp.decode("utf-8")
        return json.loads(resp)

    def test_validation_mongoengine_will_work_with_model_serializer(self):
        class Doc(db.Document):
            value = db.StringField(validation=RegexpValidator(
                r"\d+", message="Bad value").for_mongoengine())

        Doc.drop_collection()

        class Serializer(ModelSerializer):
            class Meta:
                model = Doc

        Doc.objects.create(value="123")

        s = Serializer(data={"value": "asd"})
        self.assertEqual(s.validate(), False)
        self.assertEqual(s.errors, {"value": ["Bad value"]})

    def test_resource_decorator(self):
        class S(BaseSerializer):

            field = fields.StringField(required=True)

        @self.app.route("/test", methods=["POST"])
        @validate(S)
        def resource(cleaned_data):
            return "OK"

        resp = self.client.post("/test",
                                data=json.dumps({}),
                                headers={"Content-Type": "application/json"})
        self.assertEqual(resp.status_code, 400)
        self.assertEqual(self._parse(resp.data),
                         {'field': ['Field is required']})

    def testSimpleResourceAndRouter(self):

        router = DefaultRouter(self.app)

        class Resource(BaseResource):
            def get(self, request):
                return "GET"

            def post(self, request):
                return "POST"

            def put(self, request):
                return "PUT"

            def patch(self, request):
                return "PATCH"

            def delete(self, request):
                return "DELETE"

            @list_route(methods=["GET", "POST"])
            def listroute(self, request):
                return "LIST"

            @detail_route(methods=["GET", "POST"])
            def detailroute(self, request, pk):
                return "detail"

        self.assertSetEqual(set(Resource.get_allowed_methods()),
                            {"get", "post", "put", "patch", "delete"})

        router.register("/test", Resource, "test")

        for method in ["get", "post", "put", "patch", "delete"]:
            resp = getattr(self.client, method)("/test")
            self.assertEqual(resp.data.decode("utf-8"), method.upper())

        for method in ["GET", "POST"]:
            resp = getattr(self.client, method.lower())("/test/listroute")
            self.assertEqual(resp.status_code, 200)
            self.assertEqual(resp.data.decode("utf-8"), "LIST")

        for method in ["GET", "POST"]:
            resp = getattr(self.client, method.lower())("/test/detailroute/1")
            self.assertEqual(resp.status_code, 200)
            self.assertEqual(resp.data.decode("utf-8"), "detail")

        # A detail route without a primary key must not match.
        resp = self.client.get("/test/detailroute")
        self.assertEqual(resp.status_code, 404)

    def testRoutingWithBluePrint(self):

        bp = Blueprint("test", __name__)
        router = DefaultRouter(bp)

        class Res(BaseResource):
            def get(self, request):
                return "GET"

        router.register("/blabla", Res, "blabla")

        self.app.register_blueprint(bp, url_prefix="/test")

        with self.app.test_request_context():
            self.assertEqual(url_for("test.blabla"), "/test/blabla")

    @pytest.mark.testModelResource123
    def testModelResource(self):

        router = DefaultRouter(self.app)

        class Base(db.Document):

            title = db.StringField()

        class ED(db.EmbeddedDocument):

            value = db.StringField()

        class Model(db.Document):

            base = db.ReferenceField(Base)
            f1 = db.StringField()
            f2 = db.BooleanField()
            f3 = db.StringField()

            embedded = db.EmbeddedDocumentField(ED)
            listf = db.EmbeddedDocumentListField(ED)

            dictf = db.DictField()

        Model.objects.delete()

        ins = Model.objects.create(base=Base.objects.create(title="1"),
                                   f1="1",
                                   f2=True,
                                   f3="1",
                                   embedded={"value": "123"},
                                   listf=[{
                                       "value": "234"
                                   }],
                                   dictf={"key": "value"})

        Model.objects.create(base=Base.objects.create(title="2"),
                             f1="2",
                             f2=True,
                             f3="2",
                             embedded={"value": "123"},
                             listf=[{
                                 "value": "234"
                             }])

        class S(ModelSerializer):
            title = fields.ForeignKeyField("base__title")

            class Meta:
                model = Model
                fk_fields = ("base__title", )

        class ModelRes(ModelResource):

            serializer_class = S
            queryset = Model.objects.all()
            pagination_class = DefaultPagination

        router.register("/test", ModelRes, "modelres")
        resp = self.client.get("/test")
        self.assertEqual(resp.status_code, 200)
        data = self._parse(resp.data)

        self.assertEqual(len(data["results"]), 2)
        item = data["results"][0]
        self.assertEqual(item["dictf"], {"key": "value"})
        self.assertEqual(item["title"], "1")
        self.assertEqual(item["base__title"], "1")

        # get one object; the body must be valid JSON
        resp = self.client.get("/test/{}".format(ins.id))
        self.assertEqual(resp.status_code, 200)
        self._parse(resp.data)

        # test pagination
        for i in range(10):
            Model.objects.create(base=Base.objects.create(title="1"),
                                 f1="1",
                                 f2=True,
                                 f3="2")

        self.assertEqual(Model.objects.count(), 12)

        resp = self.client.get("/test?page=1")
        self.assertEqual(resp.status_code, 200)
        data = self._parse(resp.data)
        results = data["results"]

        self.assertEqual(results[0]["embedded"], {"value": "123"})
        self.assertEqual(results[0]["listf"], [{"value": "234"}])

        self.assertEqual(len(data["results"]), 10)

        resp = self.client.get("/test?page=2")
        self.assertEqual(resp.status_code, 200)
        data = self._parse(resp.data)
        self.assertEqual(len(data["results"]), 2)

        resp = self.client.get("/test?page=2&page_size=5")
        self.assertEqual(resp.status_code, 200)
        data = self._parse(resp.data)
        self.assertEqual(len(data["results"]), 5)

        resp = self.client.get("/test?page=3&page_size=5")
        self.assertEqual(resp.status_code, 200)
        data = self._parse(resp.data)
        self.assertEqual(len(data["results"]), 2)

        # test put
        resp = self.client.put("/test/{}".format(ins.id),
                               data=json.dumps({"f3": "OLALA"}),
                               headers={"Content-Type": "application/json"})
        self.assertEqual(resp.status_code, 200, resp.data)
        data = self._parse(resp.data)
        self.assertEqual(data["f1"], "1")
        self.assertEqual(data["f2"], True)
        self.assertEqual(data["f3"], "OLALA")

    def testMongoEngineForeignKeyField(self):

        self.assertEqual(self.TestCol.objects.count(), 2)

        class Serializer(BaseSerializer):
            fk = fields.MongoEngineIdField(self.TestCol, required=True)

        v = Serializer({"fk": "123"})
        self.assertEqual(v.validate(), False)
        self.assertEqual(v.errors, {'fk': ['Incorrect id: 123']})

        v = Serializer({"fk": str(self.TestCol.objects.first().id)})
        v.validate()
        self.assertEqual(v.errors, {})
        self.assertEqual(v.cleaned_data["fk"], self.TestCol.objects.first())

        class R(BaseResource):
            def post(self, request):
                errors, _ = self.validate_request(Serializer)
                if errors:
                    return errors

                return "OK"

        self.app.add_url_rule("/api",
                              view_func=R.as_view("test2"),
                              methods=["GET", "POST"])

        resp = self.client.post("/api",
                                data=json.dumps({}),
                                headers={"Content-Type": "application/json"})
        self.assertEqual(resp.status_code, 400)
        data = self._parse(resp.data)
        self.assertEqual(data["fk"], ['Field is required'])

    def testSerialization(self):
        class Col(db.Document):

            value = db.StringField()
            created = db.DateTimeField(default=datetime.datetime.now)

        Col.objects.delete()
        Col.objects.create(value="1")
        Col.objects.create(value="2")

        class S(BaseSerializer):

            value = fields.StringField()
            created = fields.DateTimeField(read_only=True)

        data = S(Col.objects.all()).to_python()
        self.assertEqual(len(data), 2)
        self.assertEqual([i["value"] for i in data], ["1", "2"])

        # test can't set read only field
        ser = S({"value": "1", "created": "2016-01-01 00:00:00"})
        ser.validate()

        self.assertNotIn("created", ser.cleaned_data)

    def testModelSerialization(self):
        class DeepInner(db.EmbeddedDocument):
            value = db.StringField()

        class Inner(db.EmbeddedDocument):
            value = db.StringField()
            deep = db.EmbeddedDocumentField(DeepInner)

        class Col(db.Document):

            value = db.StringField()
            excluded_field = db.StringField(default="excluded")
            created = db.DateTimeField(default=datetime.datetime.now)
            inner = db.EmbeddedDocumentField(Inner)

        Col.objects.delete()
        Col.objects.create(value="1",
                           inner={
                               "value": "inner1",
                               "deep": {
                                   "value": "123"
                               }
                           })
        Col.objects.create(value="2", inner={"value": "inner2"})

        class Serializer(ModelSerializer):

            method_field = fields.MethodField("test")

            renamed = fields.ForeignKeyField(
                document_fieldname="inner__deep__value")

            def test(self, doc):
                return doc.value

            class Meta:
                model = Col
                fields = ("value", "created", "method_field")
                fk_fields = ("inner__value", "inner__deep__value")

        data = Serializer(Col.objects.all()).to_python()

        for item in data:
            self.assertIn("value", item)
            self.assertEqual(item["value"], item["method_field"])
            # assertTrue(type(x), T) always passed (T was only the message);
            # check the type for real.
            self.assertIsInstance(item["created"], datetime.datetime)
            self.assertEqual(item["renamed"], item["inner__deep__value"])
예제 #6
0
class TestFlask(unittest.TestCase):
    """End-to-end tests of ``init_flask`` against a Flask server process.

    ``setUp`` starts ``self.app`` in a child process on 127.0.0.1:8008; the
    tests POST to ``/test`` (served by ``test_api``) with encoded internal
    headers and check the JSON that comes back.
    """
    app = None

    # Endpoint of the server process started in setUp.
    _URL = 'http://127.0.0.1:8008/test'
    # Seconds to wait for a response before failing instead of hanging.
    _REQUEST_TIMEOUT = 10
    # Seconds to wait for the terminated server process to exit.
    _JOIN_TIMEOUT = 5

    def setUp(self):
        modules.flags_manager.remove_all_flags()

        self.app = Flask(__name__)
        self.app.add_url_rule('/test', 'test', test_api, methods=['POST'])

        init_flask('test_app', self.app, setup_kuber_config_loader=False)

        flags.DEFINE_INTEGER_FLAG("test")
        self.loader = KuberConfigLoader("test_service")
        self.loader.load_config(
            config_pb2.GlobalConfig(
                flags=[{
                    "name": "test",
                    "type": "INTEGER",
                    "value": {
                        "base_value": {
                            "number_value": 1
                        }
                    }
                }],
                experiments={
                    2:
                    ExperimentDefinition(
                        id=2,
                        flag_values={
                            "test": FlagValue(base_value={"number_value": 2})
                        })
                }))

        self.server = Process(target=self.app.run, args=("127.0.0.1", 8008))
        self.server.start()
        # NOTE(review): fixed delay for the server to start listening; can be
        # flaky on slow machines.
        sleep(1)

    def tearDown(self):
        self.server.terminate()
        # Reap the child and make sure it has exited before the next test's
        # setUp starts another server on the same port.
        self.server.join(self._JOIN_TIMEOUT)

    def _post(self, headers):
        """POST to the test endpoint with *headers*; return the decoded JSON."""
        return requests.post(url=self._URL,
                             headers=headers,
                             timeout=self._REQUEST_TIMEOUT).json()

    def test_experiments(self):
        headers = {
            'x-internal-state-bin': 'EgA='  # experiments: []
        }
        data = self._post(headers)
        self.assertEqual(data['test'], 1)

        headers = {
            'x-internal-state-bin': 'EgEC'  # experiments: [2]
        }
        data = self._post(headers)
        self.assertEqual(data['test'], 2)

    def test_trace_info(self):
        headers = {
            'x-internal-trace-info-bin':
            'ChYSCzA5MTIzNDU2Nzg5CgcxLjIuMy40'  # ip: 1.2.3.4, phone: 09123456789
        }
        data = self._post(headers)
        self.assertEqual(data['trace_info']['client']['ip'], '1.2.3.4')
        self.assertEqual(data['trace_info']['client']['phone'], '09123456789')
예제 #7
0
    def __init__(self, *positional, **options):
        """Pass every construction argument through, unchanged, to the base class."""
        super(HacknightNamespace, self).__init__(*positional, **options)
 
    def on_hashtag(self, hashtag):
        """Handle a ``hashtag`` event by spawning ``send_tweets``.

        ``send_tweets(self, value)`` is started in a new gevent greenlet, with
        ``value`` read from ``hashtag['value']`` (None when the key is absent).
        """
        gevent.spawn(send_tweets, self, hashtag.get('value'))
 
    def on_test(self, data):
        """Handle a ``test`` event by emitting ``datetime`` once per second.

        ``data`` is ignored. The loop has no exit condition; it yields to other
        greenlets during each ``gevent.sleep(1)``.
        """
        while 1:
            now = str(datetime.utcnow())
            self.emit('datetime', now)
            gevent.sleep(1)
 
def main_endpoint(remaining_path):
    """Hand a ``/socket.io/...`` request over to gevent-socketio.

    ``remaining_path`` is captured by the URL rule and not used here. An empty
    body is returned because Flask rejects a view that returns None.
    """
    socketio_manage(request.environ, {'/main': HacknightNamespace}, request)
    return ''
 
app.add_url_rule('/socket.io/<path:remaining_path>', 'main', main_endpoint)
 
 
 
@run_with_reloader
def run_server():
    """Serve the module-level ``app`` with gevent-socketio on port 80.

    Rebinds the global ``app`` to a ``SharedDataMiddleware`` wrapper (with no
    exports) before serving; ``serve_forever`` does not return.
    """
    global app

    print('Starting SocketIO Server with Gevent Mode ...')
    app = SharedDataMiddleware(app, {})
    listen_on = ('', 80)
    server = SocketIOServer(listen_on, app, namespace="socket.io",
                            policy_server=False)
    server.serve_forever()
 
if __name__ == '__main__':
    run_server()
예제 #8
0
class FlaskRestServer(object):
    '''
    Small server which can be run on a different thread.
    Wires REST calls to DAOs.

    NOTE(review): the last error message (``currentError``) and the
    serializer's ``currentUri``/``currentType`` are per-request values kept on
    shared objects; if requests are served concurrently they can observe each
    other's values.
    '''
    def __init__(self,
                 dbUri='sqlite:///:memory:',
                 port=5000,
                 verbose=False,
                 serializer=jsonSerializerWithUri):
        '''
        Create the Flask app and configure SQLAlchemy.

        :param dbUri: value stored as ``SQLALCHEMY_DATABASE_URI``.
        :param port: port used by :meth:`start` and in log messages.
        :param verbose: when true, ``SQLALCHEMY_ECHO`` is enabled.
        :param serializer: serializer for request and response bodies.
        '''
        self.flaskApp = Flask(__name__,
                              static_url_path='',
                              static_folder=abspath('./static'))

        print('Using following sqlite database URI: %s' % (dbUri,))
        self.flaskApp.config['SQLALCHEMY_DATABASE_URI'] = dbUri

        self.flaskApp.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False

        if verbose:
            self.flaskApp.config['SQLALCHEMY_ECHO'] = True

        @self.flaskApp.route('/')
        def index():
            return self.flaskApp.send_static_file('index.html')

        self.port = port
        self.serializer = serializer
        self.namespaces = {}
        # Message recorded by the _abort helper in wire(); consumed by the 400
        # handler installed in start().
        self.currentError = None

    def wire(self, cls, icls=None, namespace=None):
        '''
        Wire HTTP calls for a given class to a DAO.

        :param cls: model class that is queried, created, updated and deleted.
        :param icls: type handed to the serializer; defaults to ``cls``.
        :param namespace: URL prefix and endpoint base name; defaults to the
            ``icls`` class name followed by ``'s'``.

        Routes registered under ``/<namespace>``: GET (list; query arguments
        other than ``_`` become ``filter_by`` filters), GET ``/<uid>``, GET
        ``/<uid>/<propName>`` (see :meth:`wireOneToMany`), POST (create),
        PUT ``/<uid>`` (update) and DELETE ``/<uid>``.
        '''
        if icls is None:
            icls = cls

        # assign default namespace if need be
        if namespace is None:
            # NOTE(review): this used to be spelled name[0] + name[1:], which
            # is just the class name - perhaps lower-casing the first letter
            # was intended. Kept as-is because the URLs and endpoint names are
            # part of the public API.
            namespace = icls.__name__ + 's'

        namespaceSingle = namespace[:-1]

        self.namespaces[cls] = {
            'cls': cls,
            'icls': icls,
            'namespace': namespace,
            'namespaceSingle': namespaceSingle,
            'innerNamespaces': {}
        }

        def _ok(content=None):
            '''
            Build a 200 response; the payload defaults to {'result': True}.
            '''
            if content is None:
                content = {'result': True}
            return Response(self.serializer._to(content),
                            status=200,
                            content_type=self.serializer.contentType)

        def _serialized(uri, currentType, content):
            '''
            Serialize ``content`` with the serializer's URI/type context set,
            clearing that context afterwards even if serialization fails.
            '''
            self.serializer.currentUri = uri
            self.serializer.currentType = currentType
            try:
                return _ok(content)
            finally:
                self.serializer.currentUri = None
                self.serializer.currentType = None

        def _abort(exceptionText):
            '''
            Record the error text for the 400 handler and abort with 400.
            '''
            self.currentError = exceptionText
            abort(400)

        def _deserializeRequestData():
            '''
            Deserialize the request body (POST or PUT); an empty body gives {}.
            '''
            body = flask_request.get_data()
            if not body:
                return {}

            # if there is POST/PUT data, it should be serialized using expected method
            contentType = flask_request.headers.get('Content-Type')
            if str(contentType) != self.serializer.contentType:
                _abort('Request Content-Type does not match expected: %s!=%s' %
                       (contentType, self.serializer.contentType))
            try:
                return self.serializer._from(body)
            except Exception:
                _abort(
                    'Could not deserialize data %s using deserialization method %s'
                    % (body, self.serializer._from.__name__))

        def handleGetAll():
            '''
            List items; query arguments (except "_") become equality filters.
            '''
            if flask_request.args:
                filterArgs = {
                    name: str(value)
                    for name, value in flask_request.args.items()
                    if name != '_'
                }
                items = cls.query.filter_by(**filterArgs).all()
            else:
                items = cls.query.all()

            return _serialized(url_for(namespace, _external=True), icls,
                               {namespace: items})

        def handleGet(uid):
            '''
            Handle findById (an unknown uid is serialized as None).
            '''
            return _serialized(
                url_for(namespace + '_findById', uid=uid, _external=True),
                icls, {namespaceSingle: cls.query.get(uid)})

        def handleGetInner(uid, propName):
            '''
            Handle finding of attributes mapped with one-to-many.
            '''
            innerNamespace = self.namespaces[cls]['innerNamespaces'].get(propName)
            item = cls.query.get(uid)
            if innerNamespace is None or item is None:
                # Unknown relationship or missing parent: 404 rather than a
                # KeyError/AttributeError surfacing as a 500.
                abort(404)
            innerQuery = getattr(item, propName)

            if flask_request.args:
                filterArgs = {
                    name: str(value)
                    for name, value in flask_request.args.items()
                }
                innerItems = innerQuery.filter_by(**filterArgs).all()
            else:
                innerItems = innerQuery.all()

            return _serialized(
                url_for(innerNamespace['namespace'], _external=True),
                innerNamespace['icls'],
                {innerNamespace['namespace']: innerItems})

        def handlePost():
            '''
            Create a new item from the request data.
            '''
            data = _deserializeRequestData()
            newItem = cls(**data)
            dbsession.add(newItem)
            dbsession.commit()
            dbsession.refresh(newItem)

            return _serialized(
                url_for(namespace + '_findById', uid=newItem.uid,
                        _external=True),
                icls, {namespaceSingle: newItem})

        def handlePut(uid):
            '''
            Update the item with the given uid from the request data.
            '''
            data = _deserializeRequestData()
            dbsession.query(cls).filter_by(uid=uid).update(data)
            dbsession.commit()
            item = cls.query.get(uid)
            if item is None:
                abort(404)

            return _serialized(
                url_for(namespace + '_findById', uid=item.uid,
                        _external=True),
                icls, {namespaceSingle: item})

        def handleDelete(uid):
            '''
            Delete the item with the given uid (404 when it does not exist).
            '''
            item = cls.query.get(uid)
            if item is None:
                abort(404)
            dbsession.delete(item)
            dbsession.commit()
            return _ok()

        # Add HTTP hooks: (URL suffix, endpoint, view, HTTP method).
        rules = (
            ('', namespace, handleGetAll, 'GET'),
            ('/<int:uid>', namespace + '_findById', handleGet, 'GET'),
            ('/<int:uid>/<string:propName>', namespace + '_findInner',
             handleGetInner, 'GET'),
            ('', namespace + '_create', handlePost, 'POST'),
            ('/<int:uid>', namespace + '_update', handlePut, 'PUT'),
            ('/<int:uid>', namespace + '_delete', handleDelete, 'DELETE'),
        )
        for suffix, endpoint, view, method in rules:
            self.flaskApp.add_url_rule('/' + namespace + suffix,
                                       endpoint,
                                       view,
                                       methods=[method])

        print('Wired: http://localhost:%d/%s' % (self.port, namespace))

    def wireOneToMany(self, cls, innerCls, propName):
        '''
        Note that there is a one-to-many relationship between cls and innerCls
        that can be accessed in cls through propName. Both classes must have
        been passed to :meth:`wire` first.
        '''
        self.namespaces[cls]['innerNamespaces'][propName] = self.namespaces[
            innerCls]

    def start(self, threaded=True):
        '''
        Launch server, on a new thread when ``threaded`` is true.

        :raises ServerException: if ``self.port`` is already listening.
        '''
        def _handleBadRequest(error):
            '''
            Serialize a ServerException for a 400 response. Uses the message
            recorded by _abort when there is one, otherwise the 400 error's
            own text (e.g. a 400 raised by Flask/Werkzeug itself).
            '''
            message = self.currentError
            self.currentError = None
            if message is None:
                message = str(error)
            e = ServerException(message)
            stderr.write('Server Exception: %s\n' % e)
            return Response(self.serializer._to(e),
                            status=400,
                            content_type=self.serializer.contentType)

        # handle 400 error
        self.flaskApp.register_error_handler(400, _handleBadRequest)

        print('Starting server')
        if isPortListening(port=self.port):
            raise ServerException('Port %d already in use' % self.port)

        if threaded:
            # start http server on different thread so current one can go on changing the variables.
            start_new_thread(self.flaskApp.run, ('0.0.0.0', self.port))
        else:
            self.flaskApp.run('0.0.0.0', self.port)
예제 #9
0
#! /usr/bin/env python
from flask.app import Flask
from flask.globals import request

from zaguan.examples.colors.controller import ColorsController

# Messages produced by the controller, buffered until they are delivered.
_msg_queue = list()
def queue_send(msg):
    """Buffer *msg* at the end of the module-level message queue."""
    queue = _msg_queue
    queue.append(msg)

if __name__ == '__main__':
    controller = ColorsController()
    # Messages from the controller are buffered in _msg_queue.
    controller.send_function = queue_send

    app = Flask(__name__, static_folder="html")
    # Also serve the files in "html/" from the site root (e.g. /js/debug.js).
    # Flask already registers a "static" endpoint for static_folder; reusing
    # that name only works while Flask's own view compares equal to
    # app.send_static_file, which is not the case in recent Flask releases,
    # so a separate endpoint name is used.
    app.add_url_rule('/<path:filename>', endpoint='root_static',
                     view_func=app.send_static_file)
    app.debug = True

    @app.route("/")
    def index(module=None):
        # Serve index.html with the debug script appended; close the file.
        with open("html/index.html") as index_file:
            html = index_file.read()
        return html + '<script src="js/debug.js" type="text/javascript"></script>'

    @app.route("/colors/<action>")
    def actions(action):
        controller.process_uri(request.url)
        return ""

    @app.route("/debug/messages")
    def debug_send():
        # NOTE(review): this body appears truncated in this copy of the file;
        # as written the view returns None, which Flask rejects.
        global _msg_queue
예제 #10
0
class SimpleFlaskAppTest(unittest.TestCase):

    def setUp(self):
        """Build a fresh app, test client and MongoEngine binding.

        Drops the "test" database, declares the ``TestCol`` document, seeds it
        with the values "1" and "2" and exposes the class as ``self.TestCol``.
        """
        app = Flask(__name__)
        self.app = app
        self.client = app.test_client()
        self.db = MongoEngine(app)

        # Every test starts from an empty database.
        with app.app_context():
            self.db.connection.drop_database("test")

        # NOTE(review): documents are declared on the module-level `db`,
        # not on self.db.
        class TestCol(db.Document):
            value = db.StringField()

            def __unicode__(self):
                return "TestCol(value={})".format(self.value)

        TestCol.objects.delete()
        for seed in ("1", "2"):
            TestCol.objects.create(value=seed)

        self.TestCol = TestCol

    def _parse(self, resp):
        """Decode a UTF-8 encoded JSON body (bytes) into Python objects."""
        return json.loads(resp.decode("utf-8"))

    def test_validation_mongoengine_will_work_with_model_serializer(self):
        """A validator attached to a mongoengine field surfaces in ModelSerializer errors."""

        class Doc(db.Document):
            value = db.StringField(
                validation=RegexpValidator(r"\d+", message="Bad value").for_mongoengine()
            )

        Doc.drop_collection()

        class Serializer(ModelSerializer):
            class Meta:
                model = Doc

        Doc.objects.create(value="123")

        serializer = Serializer(data={"value": "asd"})
        outcome = serializer.validate()
        self.assertEqual(outcome, False)
        self.assertEqual(serializer.errors, {"value": ["Bad value"]})

    def test_resource_decorator(self):
        """@validate answers 400 with field errors when a required field is missing."""

        class S(BaseSerializer):
            field = fields.StringField(required=True)

        @self.app.route("/test", methods=["POST"])
        @validate(S)
        def resource(cleaned_data):
            return "OK"

        response = self.client.post(
            "/test",
            data=json.dumps({}),
            headers={"Content-Type": "application/json"},
        )
        self.assertEqual(response.status_code, 400)
        body = json.loads(response.data.decode("utf-8"))
        self.assertEqual(body, {'field': ['Field is required']})

    def testSimpleResourceAndRouter(self):
        """DefaultRouter dispatches each HTTP verb plus list/detail routes of a resource."""

        router = DefaultRouter(self.app)

        class Resource(BaseResource):

            def get(self, request):
                return "GET"

            def post(self, request):
                return "POST"

            def put(self, request):
                return "PUT"

            def patch(self, request):
                return "PATCH"

            def delete(self, request):
                return "DELETE"

            @list_route(methods=["GET", "POST"])
            def listroute(self, request):
                return "LIST"

            @detail_route(methods=["GET", "POST"])
            def detailroute(self, request, pk):
                return "detail"

        # Only the plain verb handlers count as allowed methods.
        self.assertSetEqual(
            set(Resource.get_allowed_methods()), {"get", "post", "put", "patch", "delete"}
        )

        router.register("/test", Resource, "test")

        # Each verb reaches the handler of the same name.
        for method in ["get", "post", "put", "patch", "delete"]:
            resp = getattr(self.client, method)("/test")
            self.assertEqual(resp.data.decode("utf-8"), method.upper())

        for method in ["GET", "POST"]:
            resp = getattr(self.client, method.lower())("/test/listroute")
            self.assertEqual(resp.status_code, 200)
            self.assertEqual(resp.data.decode("utf-8"), "LIST")

        for method in ["GET", "POST"]:
            resp = getattr(self.client, method.lower())("/test/detailroute/1")
            self.assertEqual(resp.status_code, 200)
            self.assertEqual(resp.data.decode("utf-8"), "detail")

        # The detail route needs a pk segment. This check does not depend on
        # the verb loop above, so it runs once after it.
        resp = self.client.get("/test/detailroute")
        self.assertEqual(resp.status_code, 404)



    def testRoutingWithBluePrint(self):
        """Routes registered through a blueprint router get the blueprint's URL prefix."""

        blueprint = Blueprint("test", __name__)
        bp_router = DefaultRouter(blueprint)

        class Res(BaseResource):
            def get(self, request):
                return "GET"

        bp_router.register("/blabla", Res, "blabla")
        self.app.register_blueprint(blueprint, url_prefix="/test")

        # url_for needs a request context; the comparison does not.
        with self.app.test_request_context():
            built_url = url_for("test.blabla")
        self.assertEqual(built_url, "/test/blabla")

    @pytest.mark.testModelResource123
    def testModelResource(self):
        """ModelResource list, detail, pagination and PUT over a mongoengine model."""

        router = DefaultRouter(self.app)

        class Base(db.Document):

            title = db.StringField()

        class ED(db.EmbeddedDocument):

            value = db.StringField()

        class Model(db.Document):

            base = db.ReferenceField(Base)
            f1 = db.StringField()
            f2 = db.BooleanField()
            f3 = db.StringField()

            embedded = db.EmbeddedDocumentField(ED)
            listf = db.EmbeddedDocumentListField(ED)

            dictf = db.DictField()

        Model.objects.delete()

        ins = Model.objects.create(
            base=Base.objects.create(title="1"),
            f1="1",
            f2=True,
            f3="1",
            embedded={"value": "123"},
            listf=[{"value": "234"}],
            dictf={"key": "value"}
        )

        Model.objects.create(
            base=Base.objects.create(title="2"),
            f1="2",
            f2=True,
            f3="2",
            embedded={"value": "123"},
            listf=[{"value": "234"}]
        )

        class S(ModelSerializer):
            # Exposes the referenced Base's title under the name "title".
            title = fields.ForeignKeyField("base__title")

            class Meta:
                model = Model
                fk_fields = ("base__title", )

        class ModelRes(ModelResource):

            serializer_class = S
            queryset = Model.objects.all()
            pagination_class = DefaultPagination

        router.register("/test", ModelRes, "modelres")

        # list endpoint
        resp = self.client.get("/test")
        self.assertEqual(resp.status_code, 200)
        data = self._parse(resp.data)

        self.assertEqual(len(data["results"]), 2)
        item = data["results"][0]
        self.assertEqual(item["dictf"], {"key": "value"})
        self.assertEqual(item["title"], "1")
        self.assertEqual(item["base__title"], "1")

        # Detail endpoint: assert on the payload instead of only printing it.
        resp = self.client.get("/test/{}".format(ins.id))
        self.assertEqual(resp.status_code, 200)
        detail = self._parse(resp.data)
        self.assertEqual(detail["f1"], "1")
        self.assertEqual(detail["f3"], "1")

        # Pagination: add 10 more documents for 12 in total.
        for _ in range(10):
            Model.objects.create(
                base=Base.objects.create(title="1"),
                f1="1",
                f2=True,
                f3="2"
            )

        self.assertEqual(Model.objects.count(), 12)

        resp = self.client.get("/test?page=1")
        self.assertEqual(resp.status_code, 200)
        data = self._parse(resp.data)
        results = data["results"]

        self.assertEqual(results[0]["embedded"], {"value": "123"})
        self.assertEqual(results[0]["listf"], [{"value": "234"}])

        self.assertEqual(len(data["results"]), 10)

        resp = self.client.get("/test?page=2")
        self.assertEqual(resp.status_code, 200)
        data = self._parse(resp.data)
        self.assertEqual(len(data["results"]), 2)

        resp = self.client.get("/test?page=2&page_size=5")
        self.assertEqual(resp.status_code, 200)
        data = self._parse(resp.data)
        self.assertEqual(len(data["results"]), 5)

        resp = self.client.get("/test?page=3&page_size=5")
        self.assertEqual(resp.status_code, 200)
        data = self._parse(resp.data)
        self.assertEqual(len(data["results"]), 2)

        # PUT with a partial body: untouched fields keep their values.
        resp = self.client.put("/test/{}".format(ins.id), data=json.dumps({
            "f3": "OLALA"
        }), headers={"Content-Type": "application/json"})
        self.assertEqual(resp.status_code, 200, resp.data)
        data = self._parse(resp.data)
        self.assertEqual(data["f1"], "1")
        self.assertEqual(data["f2"], True)
        self.assertEqual(data["f3"], "OLALA")


    def testMongoEngineForeignKeyField(self):
        """MongoEngineIdField resolves ids to documents and rejects bad or missing ids."""

        self.assertEqual(self.TestCol.objects.count(), 2)

        class Serializer(BaseSerializer):
            fk = fields.MongoEngineIdField(self.TestCol, required=True)

        # An id that does not resolve is reported as a field error.
        invalid = Serializer({"fk": "123"})
        self.assertEqual(invalid.validate(), False)
        self.assertEqual(invalid.errors, {'fk': ['Incorrect id: 123']})

        # A real id validates and is replaced by the document itself.
        valid = Serializer({"fk": str(self.TestCol.objects.first().id)})
        valid.validate()
        self.assertEqual(valid.errors, {})
        self.assertEqual(valid.cleaned_data["fk"], self.TestCol.objects.first())

        class R(BaseResource):

            def post(self, request):
                errors, _ = self.validate_request(Serializer)
                return errors if errors else "OK"

        self.app.add_url_rule("/api", view_func=R.as_view("test2"),
                              methods=["GET", "POST"])

        response = self.client.post(
            "/api",
            data=json.dumps({}),
            headers={"Content-Type": "application/json"},
        )
        self.assertEqual(response.status_code, 400)
        payload = json.loads(response.data.decode("utf-8"))
        self.assertEqual(payload["fk"], ['Field is required'])

    def testSerialization(self):
        """BaseSerializer dumps a queryset and ignores input for read-only fields."""

        class Col(db.Document):

            value = db.StringField()
            created = db.DateTimeField(default=datetime.datetime.now)

        Col.objects.delete()
        Col.objects.create(value="1")
        Col.objects.create(value="2")

        class S(BaseSerializer):

            value = fields.StringField()
            created = fields.DateTimeField(read_only=True)

        data = S(Col.objects.all()).to_python()
        self.assertEqual(len(data), 2)
        self.assertEqual([item["value"] for item in data], ["1", "2"])

        # A read-only field must not be settable from input.
        ser = S({"value": "1", "created": "2016-01-01 00:00:00"})
        ser.validate()

        # assertNotIn reports the offending container on failure.
        self.assertNotIn("created", ser.cleaned_data)

    def testModelSerialization(self):

        class DeepInner(db.EmbeddedDocument):
            value = db.StringField()

        class Inner(db.EmbeddedDocument):
            value = db.StringField()
            deep = db.EmbeddedDocumentField(DeepInner)

        class Col(db.Document):

            value = db.StringField()
            excluded_field = db.StringField(default="excluded")
            created = db.DateTimeField(default=datetime.datetime.now)
            inner = db.EmbeddedDocumentField(Inner)


        Col.objects.delete()
        Col.objects.create(value="1", inner={"value": "inner1", "deep": {"value": "123"}})
        Col.objects.create(value="2", inner={"value": "inner2"})

        class Serializer(ModelSerializer):

            method_field = fields.MethodField("test")

            renamed = fields.ForeignKeyField(document_fieldname="inner__deep__value")

            def test(self, doc):
                return doc.value

            class Meta:
                model = Col
                fields = ("value", "created", "method_field")
                fk_fields = ("inner__value", "inner__deep__value")

        data = Serializer(Col.objects.all()).to_python()

        for item in data:
            self.assertTrue("value" in item)
            self.assertEqual(item["value"], item["method_field"])
            self.assertTrue(type(item["created"]), datetime.datetime)
            self.assertEqual(item["renamed"], item["inner__deep__value"])