Esempio n. 1
0
class TestCase(unittest.TestCase):
    '''An helper mixin for common operations'''
    def setUp(self):
        '''Initialize an Flask application'''
        self.app = Flask(__name__)

    @contextmanager
    def context(self):
        with self.app.test_request_context('/'):
            yield

    def get_specs(self, prefix='/api', app=None):
        '''Get a Swagger specification for a RestPlus API'''
        with self.app.test_client() as client:
            response = client.get('{0}/specs.json'.format(prefix))
            self.assertEquals(response.status_code, 200)
            self.assertEquals(response.content_type, 'application/json')
            return json.loads(response.data.decode('utf8'))

    def get_declaration(self, namespace='default', prefix='/api', status=200, app=None):
        '''Get an API declaration for a given namespace'''
        with self.app.test_client() as client:
            response = client.get('{0}/{1}.json'.format(prefix, namespace))
            self.assertEquals(response.status_code, status)
            self.assertEquals(response.content_type, 'application/json')
            return json.loads(response.data.decode('utf8'))
Esempio n. 2
0
class DefaultsTestCase(FlaskCorsTestCase):
    '''Checks the Access-Control-Allow-Origin header produced by a bare
    ``@cross_origin()`` decorator, i.e. with every option left at its default.
    '''

    def setUp(self):
        self.app = Flask(__name__)

        @self.app.route('/', methods=['GET', 'OPTIONS'])
        @cross_origin()
        def wildcard():
            return 'Welcome!'

    def test_wildcard_defaults_no_origin(self):
        ''' Even when the request carries no Origin header, the default
            configuration sends the wildcard ``*`` Access-Control-Allow-Origin
            header, because ``always_send`` defaults to ``True``. (A strict W3
            implementation would omit the header in this case.)
        '''
        with self.app.test_client() as c:
            for verb in self.iter_verbs(c):
                result = verb('/')
                self.assertEqual(result.headers.get(AccessControlAllowOrigin), '*')

    def test_wildcard_defaults_origin(self):
        ''' When the request carries an Origin header, the default
            configuration still answers with the wildcard ``*`` rather than
            echoing the request's origin.
        '''
        example_origin = 'http://example.com'
        with self.app.test_client() as c:
            for verb in self.iter_verbs(c):
                result = verb('/', headers={'Origin': example_origin})
                self.assertEqual(result.headers.get(AccessControlAllowOrigin), '*')
Esempio n. 3
0
class W3TestCase(FlaskCorsTestCase):
    '''Checks the W3-spec-like mode: echo the request Origin when present,
    omit Access-Control-Allow-Origin otherwise.
    '''

    def setUp(self):
        self.app = Flask(__name__)

        @self.app.route('/', methods=['GET', 'OPTIONS'])
        @cross_origin(origins='*', send_wildcard=False, always_send=False)
        def allowOrigins():
            ''' This sets up flask-cors to echo the request's `Origin` header,
                only if it is actually set. This behavior is most similar to the
                actual W3 specification, http://www.w3.org/TR/cors/ but
                is not the default because it is more common to use the wildcard
                approach.
            '''
            return 'Welcome!'

    def test_wildcard_origin_header(self):
        ''' If there is an Origin header in the request, the Access-Control-Allow-Origin
            header should be echoed back.
        '''
        example_origin = 'http://example.com'
        with self.app.test_client() as c:
            for verb in self.iter_verbs(c):
                result = verb('/', headers={'Origin': example_origin})
                self.assertEqual(result.headers.get(AccessControlAllowOrigin), example_origin)

    def test_wildcard_no_origin_header(self):
        ''' If there is no Origin header in the request, the Access-Control-Allow-Origin
            header should not be included.
        '''
        with self.app.test_client() as c:
            for verb in self.iter_verbs(c):
                result = verb('/')
                # assertNotIn reports the offending headers on failure.
                self.assertNotIn(AccessControlAllowOrigin, result.headers)
Esempio n. 4
0
class SupportsCredentialsCase(FlaskCorsTestCase):
    '''Checks when the Access-Control-Allow-Credentials header is emitted.'''

    def setUp(self):
        self.app = Flask(__name__)

        @self.app.route('/test_credentials')
        @cross_origin(supports_credentials=True)
        def test_credentials():
            return 'Credentials!'

        @self.app.route('/test_open')
        @cross_origin()
        def test_open():
            return 'Open!'

    def test_credentialed_request(self):
        ''' The specified route should return the
            Access-Control-Allow-Credentials header.
        '''
        with self.app.test_client() as c:
            result = c.get('/test_credentials')
            header = result.headers.get(ACL_CREDENTIALS)
            # assertEquals is a deprecated alias (removed in Python 3.12).
            self.assertEqual(header, 'true')

    def test_open_request(self):
        ''' The default behavior should be to disallow credentials.
        '''
        with self.app.test_client() as c:
            result = c.get('/test_open')
            self.assertNotIn(ACL_CREDENTIALS, result.headers)
Esempio n. 5
0
class TestCase(unittest.TestCase):
    '''An helper mixin for common operations'''
    def setUp(self):
        '''Initialize an Flask application'''
        self.app = Flask(__name__)

    @contextmanager
    def context(self, **kwargs):
        with self.app.test_request_context('/', **kwargs):
            yield

    @contextmanager
    def settings(self, **settings):
        '''
        A context manager to alter app settings during a test and restore it after..
        '''
        original = {}

        # backup
        for key, value in settings.items():
            original[key] = self.app.config.get(key)
            self.app.config[key] = value

        yield

        # restore
        for key, value in original.items():
            self.app.config[key] = value

    @contextmanager
    def assert_warning(self, category=Warning):
        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter('always')
            # Cause all warnings to always be triggered.
            yield
            self.assertGreaterEqual(len(w), 1, 'It should raise a warning')
            warning = w[0]
            self.assertEqual(warning.category, category, 'It should raise {0}'.format(category.__name__))

    def get(self, url, **kwargs):
        with self.app.test_client() as client:
            return client.get(url, **kwargs)

    def post(self, url, **kwargs):
        with self.app.test_client() as client:
            return client.post(url, **kwargs)

    def get_json(self, url, status=200, **kwargs):
        response = self.get(url, **kwargs)
        self.assertEqual(response.status_code, status)
        self.assertEqual(response.content_type, 'application/json')
        return json.loads(response.data.decode('utf8'))

    def get_specs(self, prefix='', status=200, **kwargs):
        '''Get a Swagger specification for a RestPlus API'''
        return self.get_json('{0}/swagger.json'.format(prefix), status=status, **kwargs)

    def assertDataEqual(self, tested, expected):
        '''Compare data without caring about order and type (dict vs. OrderedDict)'''
        assert_data_equal(tested, expected)
Esempio n. 6
0
class OriginsW3TestCase(FlaskCorsTestCase):
    '''Checks origin echoing when the wildcard is disabled.'''

    def setUp(self):
        self.app = Flask(__name__)

        @self.app.route('/')
        @cross_origin(origins='*', send_wildcard=False, always_send=False)
        def allowOrigins():
            ''' This sets up flask-cors to echo the request's `Origin` header,
                only if it is actually set. This behavior is most similar to
                the actual W3 specification, http://www.w3.org/TR/cors/ but
                is not the default because it is more common to use the
                wildcard configuration in order to support CDN caching.
            '''
            return 'Welcome!'

        @self.app.route('/default-origins')
        @cross_origin(send_wildcard=False, always_send=False)
        def noWildcard():
            ''' With the default origins configuration, send_wildcard should
                still be respected.
            '''
            return 'Welcome!'

    def test_wildcard_origin_header(self):
        ''' If there is an Origin header in the request, the
            Access-Control-Allow-Origin header should be echoed back.
        '''
        example_origin = 'http://example.com'
        with self.app.test_client() as c:
            for verb in self.iter_verbs(c):
                result = verb('/', headers={'Origin': example_origin})
                self.assertEqual(
                    result.headers.get(ACL_ORIGIN),
                    example_origin
                )

    def test_wildcard_no_origin_header(self):
        ''' If there is no Origin header in the request, the
            Access-Control-Allow-Origin header should not be included.
        '''
        with self.app.test_client() as c:
            for verb in self.iter_verbs(c):
                result = verb('/')
                self.assertNotIn(ACL_ORIGIN, result.headers)

    def test_wildcard_default_origins(self):
        ''' With the default ``origins`` but ``send_wildcard=False``, an
            Origin header in the request should be echoed back instead of
            answering with ``*``.
        '''
        example_origin = 'http://example.com'
        with self.app.test_client() as c:
            for verb in self.iter_verbs(c):
                result = verb(
                    '/default-origins',
                    headers={'Origin': example_origin}
                )
                self.assertEqual(
                    result.headers.get(ACL_ORIGIN),
                    example_origin
                )

    # NOTE(review): this bugsnag test looks misplaced in a CORS test class,
    # apparently fused in from another module. Its ``deliver`` argument is
    # presumably injected by a mock.patch decorator that is not present here.
    # Confirm, and move it next to the other bugsnag tests.
    def test_bugsnag_custom_data(self, deliver):
        ''' Request-scoped meta data from one request must not leak into the
            next request's notification.
        '''
        meta_data = [{"hello": {"world": "once"}},
                     {"again": {"hello": "world"}}]

        app = Flask("bugsnag")

        @app.route("/hello")
        def hello():
            # pop() takes from the end, so the first request gets {"again": ...}.
            bugsnag.configure_request(meta_data=meta_data.pop())
            raise SentinelError("oops")

        handle_exceptions(app)
        app.test_client().get('/hello')
        app.test_client().get('/hello')

        self.assertEqual(deliver.call_count, 2)

        payload = deliver.call_args_list[0][0][0]
        event = payload['events'][0]
        self.assertEqual(event['metaData'].get('hello'), None)
        self.assertEqual(event['metaData']['again']['hello'], 'world')

        payload = deliver.call_args_list[1][0][0]
        event = payload['events'][0]
        self.assertEqual(event['metaData']['hello']['world'], 'once')
        self.assertEqual(event['metaData'].get('again'), None)
Esempio n. 8
0
 def testGetTpls(self):
     """Exercise statserv.server.get_tpls() against several ``which`` query strings.

     NOTE(review): ``app`` is a bare Flask app with no routes. The ``c.get``
     calls presumably exist only to push a request context, which
     ``get_tpls`` then reads via ``request``. Confirm.
     """
     app = Flask(__name__)
     with app.test_client() as c:
         c.get('/tpls.json?which=asdfjdskfjs')
         expected_error = statserv.server.send_error(request,
                                                     'template does'
                                                     ' not exist')
         self.assertEqual(expected_error,
                          statserv.server.get_tpls(),
                          'The get_tpls method really shouldn\'t try to send '
                          'back a template for \'asdfjdskfjs.\'')
     with app.test_client() as c:
         c.get('/tpls.json?which=')
         tplsempty = statserv.server.get_tpls()
     with app.test_client() as c:
         c.get('/tpls.json?which=all')
         self.assertEqual(statserv.server.get_tpls(),
                          tplsempty,
                          'The get_tpls method should send back all '
                          'templates on both which=all and which=.')
     with app.test_client() as c:
         c.get('/tpls.json?callback=blah'
               '&which=header&which=home')
         tpl_dir = statserv.server.determine_path() + '/tpls/'
         # Use context managers so the template files are closed promptly.
         with open(tpl_dir + 'header.tpl') as header_file:
             header = header_file.read()
         with open(tpl_dir + 'home.tpl') as home_file:
             home = home_file.read()
         response = statserv.server.make_response('blah', {'header': header,
                                                           'home': home})
         self.assertEqual(statserv.server.get_tpls(),
                          response,
                          'The single-template support does not seem to be '
                          'working properly.')
Esempio n. 9
0
class KazooTestCase(unittest.TestCase):
    '''Exercise Flask-Kazoo's ``kazoo_client`` proxy through two Flask routes.

    NOTE(review): setUp/tearDown talk to ZooKeeper through the extension's
    client, so a reachable server is presumably required. Confirm.
    '''

    def setUp(self):
        self.app = Flask(__name__)
        self.app.config['TESTING'] = True
        self.kazoo = Kazoo(self.app)

        self.node_prefix = '/kazoo-test/'
        self.node_name = 'test-kazoo-node'
        self.node_value = uuid.uuid1().bytes
        # Build the path from node_prefix so it always matches the routes and tearDown.
        self.app.extensions['kazoo']['client'].create(
            self.node_prefix + self.node_name, self.node_value, makepath=True)

        @self.app.route('/write/<path>')
        def write(path):
            return str(kazoo_client.set(self.node_prefix + path, b'OK'))

        @self.app.route('/read/<path>')
        def read(path):
            return kazoo_client.get(self.node_prefix + path)[0]

    def tearDown(self):
        try:
            self.app.extensions['kazoo']['client'].delete(self.node_prefix + self.node_name)
        except NoNodeError:
            # Nothing to clean up if the node is already gone. print() as a
            # function call is valid on Python 2 and 3; the old print statement
            # was a SyntaxError on Python 3.
            print("Teardown got NoNodeError")

    def test_kazoo_set(self):
        with self.app.test_client() as c:
            results = c.get('/write/%s' % self.node_name)
            self.assertEqual(results.status_code, 200)

    def test_kazoo_read(self):
        with self.app.test_client() as c:
            results = c.get('/read/%s' % self.node_name)
            self.assertEqual(results.data, self.node_value)
class LogicalAPINamingTestCase(unittest.TestCase):
    '''Several named ``Api`` blueprints can share one version prefix.'''

    def setUp(self):
        self.app = Flask(__name__)
        self.tasks = [{"id": 1, "task": "Do the laundry"}, {"id": 2, "task": "Do the dishes"}]

        def get_tasks(request):
            return self.tasks

        def post_task(request):
            data = request.json
            self.tasks.append({"task": data["task"]})
            return {}, 201

        self.get_tasks = get_tasks
        self.post_task = post_task

    def test_naming_a_single_api(self):
        api_1 = Api(version="v1", name="read-only-methods")
        api_1.register_endpoint(ApiEndpoint(http_method="GET", endpoint="/task/", handler=self.get_tasks))

        self.app.register_blueprint(api_1)
        self.app.config["TESTING"] = True

        client = self.app.test_client()
        resp = client.get("/v1/task/", content_type="application/json")
        self.assertEqual(resp.status_code, 200)
        self.assertEqual(resp.headers["Content-Type"], "application/json")

        # get_data(as_text=True) decodes the body; Response.charset was removed in Werkzeug 3.
        data = json.loads(resp.get_data(as_text=True))
        self.assertEqual(len(data), 2)

    def test_naming_multiple_logic_apis(self):
        api_1 = Api(version="v1", name="read-only-methods")
        api_1.register_endpoint(ApiEndpoint(http_method="GET", endpoint="/task/", handler=self.get_tasks))

        self.app.register_blueprint(api_1)

        api_2 = Api(version="v1", name="write-methods")
        api_2.register_endpoint(ApiEndpoint(http_method="POST", endpoint="/task/", handler=self.post_task))
        self.app.register_blueprint(api_2)
        self.app.config["TESTING"] = True

        client = self.app.test_client()

        # Testing GET
        resp = client.get("/v1/task/", content_type="application/json")
        self.assertEqual(resp.status_code, 200)
        self.assertEqual(resp.headers["Content-Type"], "application/json")

        data = json.loads(resp.get_data(as_text=True))
        self.assertEqual(len(data), 2)

        # Testing POST
        resp = client.post("/v1/task/", content_type="application/json", data=json.dumps({"task": "New Task!"}))

        self.assertEqual(resp.status_code, 201)
        self.assertEqual(resp.headers["Content-Type"], "application/json")
        self.assertEqual(len(self.tasks), 3)
Esempio n. 11
0
def test_modal_edit():
    """The file-admin list page includes the edit modal only when ``edit_modal`` is on.

    The same two views are mounted on a bootstrap2 and then a bootstrap3
    admin, and each page is checked for the ``fa_modal_window`` marker.
    """
    class EditModalOn(fileadmin.FileAdmin):
        edit_modal = True
        editable_extensions = ('txt',)

    class EditModalOff(fileadmin.FileAdmin):
        edit_modal = False
        editable_extensions = ('txt',)

    files_dir = op.join(op.dirname(__file__), 'files')
    modal_on_view = EditModalOn(files_dir, '/files/', endpoint='edit_modal_on')
    modal_off_view = EditModalOff(files_dir, '/files/', endpoint='edit_modal_off')

    expectations = (
        ('/admin/edit_modal_on/', True),
        ('/admin/edit_modal_off/', False),
    )

    for template_mode in ("bootstrap2", "bootstrap3"):
        flask_app = Flask(__name__)
        admin = Admin(flask_app, template_mode=template_mode)
        admin.add_view(modal_on_view)
        admin.add_view(modal_off_view)

        client = flask_app.test_client()
        for url, modal_expected in expectations:
            rv = client.get(url)
            eq_(rv.status_code, 200)
            page = rv.data.decode('utf-8')
            has_modal = 'fa_modal_window' in page
            ok_(has_modal if modal_expected else not has_modal)
Esempio n. 12
0
class ExposeHeadersTestCase(FlaskCorsTestCase):
    '''Checks how the ``expose_headers`` option is serialized into ACL_EXPOSE_HEADERS.'''

    def setUp(self):
        self.app = Flask(__name__)

        @self.app.route('/test_default')
        @cross_origin()
        def test_default():
            return 'Welcome!'

        @self.app.route('/test_list')
        @cross_origin(expose_headers=["Foo", "Bar"])
        def test_list():
            return 'Welcome!'

        @self.app.route('/test_string')
        @cross_origin(expose_headers="Foo")
        def test_string():
            return 'Welcome!'

        @self.app.route('/test_set')
        @cross_origin(expose_headers=set(["Foo", "Bar"]))
        def test_set():
            return 'Welcome!'

    def test_default(self):
        ''' Without ``expose_headers``, no expose-headers header is sent.
        '''
        with self.app.test_client() as c:
            resp = c.get('/test_default')
            self.assertIsNone(resp.headers.get(ACL_EXPOSE_HEADERS),
                              "Default should have no exposed headers")

    def test_list_serialized(self):
        ''' A list is serialized as a sorted, comma-separated header value.
        '''
        with self.app.test_client() as c:
            resp = c.get('/test_list')
            self.assertEqual(resp.headers.get(ACL_EXPOSE_HEADERS), 'Bar, Foo')

    def test_string_serialized(self):
        ''' A single string is sent through unchanged.
        '''
        with self.app.test_client() as c:
            resp = c.get('/test_string')
            self.assertEqual(resp.headers.get(ACL_EXPOSE_HEADERS), 'Foo')

    def test_set_serialized(self):
        ''' A set is serialized like a list: sorted and comma-separated.
        '''
        with self.app.test_client() as c:
            resp = c.get('/test_set')
            self.assertEqual(resp.headers.get(ACL_EXPOSE_HEADERS), 'Bar, Foo')
Esempio n. 13
0
class TestCase(unittest.TestCase):
    """An helper mixin for common operations"""

    def setUp(self):
        """Initialize an Flask application"""
        self.app = Flask(__name__)

    @contextmanager
    def context(self):
        with self.app.test_request_context("/"):
            yield

    @contextmanager
    def settings(self, **settings):
        """
        A context manager to alter app settings during a test and restore it after..
        """
        original = {}

        # backup
        for key, value in settings.items():
            original[key] = self.app.config.get(key)
            self.app.config[key] = value

        yield

        # restore
        for key, value in original.items():
            self.app.config[key] = value

    def get_specs(self, prefix="/api", app=None, status=200):
        """Get a Swagger specification for a RestPlus API"""
        with self.app.test_client() as client:
            response = client.get("{0}/swagger.json".format(prefix))
            self.assertEquals(response.status_code, status)
            self.assertEquals(response.content_type, "application/json")
            return json.loads(response.data.decode("utf8"))

    def get_declaration(self, namespace="default", prefix="/api", status=200, app=None):
        """Get an API declaration for a given namespace"""
        with self.app.test_client() as client:
            response = client.get("{0}/swagger.json".format(prefix, namespace))
            self.assertEquals(response.status_code, status)
            self.assertEquals(response.content_type, "application/json")
            return json.loads(response.data.decode("utf8"))

    def get_json(self, url, status=200):
        with self.app.test_client() as client:
            response = client.get(url)

        self.assertEquals(response.status_code, status)
        self.assertEquals(response.content_type, "application/json")
        return json.loads(response.data.decode("utf8"))
Esempio n. 14
0
    def test_multiple_apps(self):
        """One Limiter initialised on two apps enforces each app's own limits.

        app1 relies on the "1/second" global limit for /ping and overrides
        /slowping with "1/minute"; app2 overrides its routes with "2/second"
        and "2/minute". Time is frozen with hiro and advanced explicitly via
        ``timeline.forward``, so the assertions depend on their exact order.
        """
        app1 = Flask(__name__)
        app2 = Flask(__name__)

        # Default limit; app1's undecorated /ping is governed by it.
        limiter = Limiter(global_limits = ["1/second"])
        limiter.init_app(app1)
        limiter.init_app(app2)

        @app1.route("/ping")
        def ping():
            return "PONG"

        @app1.route("/slowping")
        @limiter.limit("1/minute")
        def slow_ping():
            return "PONG"


        @app2.route("/ping")
        @limiter.limit("2/second")
        def ping_2():
            return "PONG"

        @app2.route("/slowping")
        @limiter.limit("2/minute")
        def slow_ping_2():
            return "PONG"

        with hiro.Timeline().freeze() as timeline:
            with app1.test_client() as cli:
                # /ping at 1/second: the second hit in the same second is rejected.
                self.assertEqual(cli.get("/ping").status_code, 200)
                self.assertEqual(cli.get("/ping").status_code, 429)
                timeline.forward(1)
                self.assertEqual(cli.get("/ping").status_code, 200)
                # /slowping at 1/minute: still rejected after 59s, allowed after 60s.
                self.assertEqual(cli.get("/slowping").status_code, 200)
                timeline.forward(59)
                self.assertEqual(cli.get("/slowping").status_code, 429)
                timeline.forward(1)
                self.assertEqual(cli.get("/slowping").status_code, 200)
            with app2.test_client() as cli:
                # /ping at 2/second: the third hit in the same second is rejected.
                self.assertEqual(cli.get("/ping").status_code, 200)
                self.assertEqual(cli.get("/ping").status_code, 200)
                self.assertEqual(cli.get("/ping").status_code, 429)
                timeline.forward(1)
                self.assertEqual(cli.get("/ping").status_code, 200)
                # /slowping at 2/minute: two hits within 60s pass, the third is
                # rejected, and one more second frees a slot again.
                self.assertEqual(cli.get("/slowping").status_code, 200)
                timeline.forward(59)
                self.assertEqual(cli.get("/slowping").status_code, 200)
                self.assertEqual(cli.get("/slowping").status_code, 429)
                timeline.forward(1)
                self.assertEqual(cli.get("/slowping").status_code, 200)
Esempio n. 15
0
def test_bugsnag_notify(deliver):
    """A handled ``bugsnag.notify`` call is delivered once, tagged with the request URL."""
    app = Flask("bugsnag")

    @app.route("/hello")
    def hello():
        bugsnag.notify(SentinalError("oops"))
        return "OK"

    handle_exceptions(app)
    client = app.test_client()
    client.get('/hello')

    eq_(deliver.call_count, 1)
    positional_args = deliver.call_args[0]
    payload = positional_args[0]
    request_meta = payload['events'][0]['metaData']['request']
    eq_(request_meta['url'], 'http://localhost/hello')
Esempio n. 16
0
def test_bugsnag_crash(deliver):
    """An unhandled exception in a view is delivered once with its class and request URL."""
    app = Flask("bugsnag")

    @app.route("/hello")
    def hello():
        raise SentinalError("oops")

    handle_exceptions(app)
    client = app.test_client()
    client.get('/hello')

    eq_(deliver.call_count, 1)
    payload = deliver.call_args[0][0]
    first_event = payload['events'][0]
    eq_(first_event['exceptions'][0]['errorClass'], 'test_flask.SentinalError')
    eq_(first_event['metaData']['request']['url'], 'http://localhost/hello')
Esempio n. 17
0
def test_bugsnag_crash(deliver):
    """A crashing view produces exactly one delivery carrying the error class and URL."""
    app = Flask("bugsnag")

    @app.route("/hello")
    def hello():
        raise SentinalError("oops")

    handle_exceptions(app)
    test_client = app.test_client()
    test_client.get("/hello")

    eq_(deliver.call_count, 1)
    (payload, *_rest) = deliver.call_args[0]
    event = payload["events"][0]
    eq_(event["exceptions"][0]["errorClass"], "test_flask.SentinalError")
    eq_(event["metaData"]["request"]["url"], "http://localhost/hello")
Esempio n. 18
0
    def test_bugsnag_notify(self):
        """A handled notify() reaches the test server once, with the request URL attached."""
        app = Flask("bugsnag")

        @app.route("/hello")
        def hello():
            bugsnag.notify(SentinelError("oops"))
            return "OK"

        handle_exceptions(app)
        client = app.test_client()
        client.get('/hello')

        received = self.server.received
        self.assertEqual(1, len(received))
        payload = received[0]['json_body']
        request_url = payload['events'][0]['metaData']['request']['url']
        self.assertEqual(request_url, 'http://localhost/hello')
Esempio n. 19
0
def test_bugsnag_includes_unknown_content_type_posted_data(deliver):
    """A body with an unrecognised content type is still reported in the request data."""
    app = Flask("bugsnag")

    @app.route("/form", methods=["PUT"])
    def hello():
        raise SentinalError("oops")

    handle_exceptions(app)
    client = app.test_client()
    client.put("/form", data="_data", content_type="application/octet-stream")

    eq_(deliver.call_count, 1)
    event = deliver.call_args[0][0]["events"][0]
    eq_(event["exceptions"][0]["errorClass"], "test_flask.SentinalError")
    request_meta = event["metaData"]["request"]
    eq_(request_meta["url"], "http://localhost/form")
    ok_("_data" in request_meta["data"]["body"])
Esempio n. 20
0
def test_bugsnag_includes_posted_json_data(deliver):
    """A JSON request body is reported as decoded data alongside the crash."""
    app = Flask("bugsnag")

    @app.route("/ajax", methods=["POST"])
    def hello():
        raise SentinalError("oops")

    handle_exceptions(app)
    client = app.test_client()
    client.post("/ajax", data='{"key": "value"}', content_type="application/json")

    eq_(deliver.call_count, 1)
    event = deliver.call_args[0][0]["events"][0]
    eq_(event["exceptions"][0]["errorClass"], "test_flask.SentinalError")
    request_meta = event["metaData"]["request"]
    eq_(request_meta["url"], "http://localhost/ajax")
    eq_(request_meta["data"], dict(key="value"))
Esempio n. 21
0
def test_menu_render():
    """menu.render() returns two entries for the two @menu-decorated classes below."""
    menu = decorators.menu
    menu.clear()
    app = Flask(__name__)
    app.config["TESTING"] = True  # same setting as ``app.testing = True``

    @menu("Hello World", group_name="admin")
    class Hello(object):

        @menu("Index")
        def index(self):
            pass

        @menu("Page 2")
        def index2(self):
            pass

    @menu("Monster")
    class Monster(object):

        @menu("Home")
        def maggi(self):
            pass

    with app.test_client() as client:
        client.get("/")
        rendered = menu.render()
        assert len(rendered) == 2
Esempio n. 22
0
    def test_flask(self):
        """elastic_query sorts model rows by a JSON sort spec (population desc)."""
        app = Flask(__name__)
        # Configure before SQLAlchemy(app): Flask-SQLAlchemy 3.x reads the
        # database URI when the extension is initialised, not lazily.
        app.config['SQLALCHEMY_DATABASE_URI'] = 'sqlite:///:memory:'
        app.config['TESTING'] = True
        db = SQLAlchemy(app)

        class Cities(db.Model):
            __tablename__ = 'users'

            id = Column(Integer, primary_key=True)
            name = Column(String)
            population = Column(Integer)

            def __init__(self, name, population):
                self.name = name
                self.population = population

        # Database work needs an app context in Flask-SQLAlchemy 3.x; in 2.x
        # pushing one is harmless.
        with app.app_context():
            db.create_all()

            db.session.add(Cities("Cordoba", 1000000))
            db.session.add(Cities("Rafaela", 99000))
            db.session.commit()

            query_string = '{ "sort": { "population" : "desc" } }'
            results = elastic_query(Cities, query_string)
            assert results[0].name == 'Cordoba'
Esempio n. 23
0
def main():
    """Build the demo app, send a fixed sequence of requests and print each response."""
    app = Flask(__name__)
    app.config.update(
        DB_CONNECTION_STRING=':memory:',
        CACHE_TYPE='simple',
        SQLALCHEMY_DATABASE_URI='sqlite://',
    )
    app.debug = True
    injector = init_app(app=app, modules=[AppModule])

    configure_views(app=app, cached=injector.get(Cache).cached)

    post_init_app(app, injector)

    client = app.test_client()

    # (client method, path, extra keyword arguments), issued in this order.
    request_plan = [
        (client.get, '/', {}),
        (client.post, '/', {'data': {'key': 'foo', 'value': 'bar'}}),
        (client.get, '/', {}),
        (client.get, '/hello', {}),
        (client.delete, '/hello', {}),
        (client.get, '/', {}),
        (client.get, '/hello', {}),
        (client.delete, '/hello', {}),
    ]
    for send, path, extra in request_plan:
        response = send(path, **extra)
        print('%s\n%s%s' % (response.status, response.headers, response.data))
Esempio n. 24
0
    def setUp(self):
        """Create a Funnel-enabled app with one CSS and one JS bundle plus an index page."""
        app = Flask(__name__)
        Funnel(app)

        css_bundles = {
            'css-bundle': (
                'css/test.css',
                'less/test.less',
                'scss/test.scss',
                'stylus/test.styl',
            ),
        }
        js_bundles = {
            'js-bundle': (
                'js/test1.js',
                'js/test2.js',
                'coffee/test.coffee',
            ),
        }
        app.config.update(CSS_BUNDLES=css_bundles, JS_BUNDLES=js_bundles)

        @app.route('/')
        def index():
            template = "{{ css('css-bundle') }} {{ js('js-bundle') }}"
            return render_template_string(template)

        self.app = app
        self.client = app.test_client()
Esempio n. 25
0
    def test_custom_headers_from_config(self):
        """Rate-limit header names configured via C.HEADER_* replace the defaults."""
        app = Flask(__name__)
        custom_names = (
            (C.HEADER_LIMIT, "X-Limit"),
            (C.HEADER_REMAINING, "X-Remaining"),
            (C.HEADER_RESET, "X-Reset"),
        )
        for config_key, header_name in custom_names:
            app.config.setdefault(config_key, header_name)
        limiter = Limiter(app, global_limits=["10/minute"], headers_enabled=True)
        @app.route("/t1")
        @limiter.limit("2/second; 10 per minute; 20/hour")
        def t():
            return "test"

        with hiro.Timeline().freeze() as timeline:
            with app.test_client() as cli:
                # One request per second; the 11th breaches the 10/minute limit.
                for _ in range(11):
                    resp = cli.get("/t1")
                    timeline.forward(1)

                expected_headers = {
                    'X-Limit': '10',
                    'X-Remaining': '0',
                    'X-Reset': str(int(time.time() + 49)),
                }
                for header_name, expected in expected_headers.items():
                    self.assertEqual(resp.headers.get(header_name), expected)
Esempio n. 26
0
    def test_whitelisting(self):
        """Requests sent with an ``internal: true`` header bypass the 1/minute limit."""
        app = Flask(__name__)
        limiter = Limiter(app, global_limits=["1/minute"], headers_enabled=True)

        @app.route("/")
        def t():
            return "test"

        @limiter.request_filter
        def w():
            return request.headers.get("internal", None) == "true"

        with hiro.Timeline().freeze() as timeline:
            with app.test_client() as cli:
                self.assertEqual(cli.get("/").status_code, 200)
                self.assertEqual(cli.get("/").status_code, 429)
                timeline.forward(60)
                self.assertEqual(cli.get("/").status_code, 200)

                for _ in range(10):
                    status = cli.get("/", headers={"internal": "true"}).status_code
                    self.assertEqual(status, 200)
Esempio n. 27
0
    def setUp(self):
        """Wire a two-step engine into a ReliureAPI view and build a test client.

        Note that ``self.app`` ends up holding the Flask *test client*, not the app.
        """
        self.engine = Engine("op1", "op2")
        # op1 maps "in" -> "middle" and is optional (required=False); op2 maps "middle" -> "out".
        self.engine.op1.setup(in_name="in", out_name="middle", required=False)
        self.engine.op2.setup(in_name="middle", out_name="out")

        self.engine.op1.set(OptProductEx())
        # op2 gets two components; "foisdouze" ("times twelve") has its factor option forced to 12.
        foisdouze = OptProductEx("foisdouze")
        foisdouze.force_option_value("factor", 12)
        self.engine.op2.set(foisdouze, OptProductEx())

        egn_view = EngineView(self.engine, name="my_egn")
        # "in" is declared as an int Numeric with min=-5, max=5. "middle" is also
        # exposed as an input, presumably so callers can bypass the optional op1 (TODO: confirm).
        egn_view.add_input("in", Numeric(vtype=int, min=-5, max=5))
        egn_view.add_input("middle", Numeric(vtype=int))
        # NOTE(review): debug print left in setUp; consider removing.
        print(self.engine.needed_inputs())
        egn_view.add_output("in")
        egn_view.add_output("middle")
        egn_view.add_output("out")

        api = ReliureAPI()
        api.register_view(egn_view)

        app = Flask(__name__)
        app.config['TESTING'] = True
        app.register_blueprint(api, url_prefix="/api")
        self.app = app.test_client()
Esempio n. 28
0
    def test_headers_breach(self):
        """After breaching 10/minute, the X-RateLimit-* headers report limit, zero remaining and reset."""
        app = Flask(__name__)
        limiter = Limiter(app, global_limits=["10/minute"], headers_enabled=True)

        @app.route("/t1")
        @limiter.limit("2/second; 10 per minute; 20/hour")
        def t():
            return "test"

        with hiro.Timeline().freeze() as timeline:
            with app.test_client() as cli:
                # One request per second; the 11th breaches the 10/minute limit.
                for _ in range(11):
                    resp = cli.get("/t1")
                    timeline.forward(1)

                expected_headers = (
                    ('X-RateLimit-Limit', '10'),
                    ('X-RateLimit-Remaining', '0'),
                    ('X-RateLimit-Reset', str(int(time.time() + 49))),
                )
                for header_name, expected in expected_headers:
                    self.assertEqual(resp.headers.get(header_name), expected)
Esempio n. 29
0
class MethodViewLoginTestCase(unittest.TestCase):
    """OPTIONS must get through login_required/fresh_login_required on a MethodView."""

    def setUp(self):
        self.app = Flask(__name__)
        self.login_manager = LoginManager()
        self.login_manager.init_app(self.app)
        self.login_manager._login_disabled = False

        class SecretEndpoint(MethodView):
            # Flask wraps the function returned by as_view() with these.
            decorators = [
                login_required,
                fresh_login_required,
            ]

            def options(self):
                return u''

            def get(self):
                return u''

        self.app.add_url_rule('/secret',
                              view_func=SecretEndpoint.as_view('secret'))

    def test_options_call_exempt(self):
        with self.app.test_client() as c:
            result = c.open('/secret', method='OPTIONS')
            self.assertEqual(result.status_code, 200)


class DigestAuthWithRealmTestCase(unittest.TestCase):
    """Fixture: an HTTPDigestAuth instance using the custom realm 'My Realm'.

    Kept in its own class: a second ``setUp`` in MethodViewLoginTestCase's
    body would silently replace that class's fixture (the later ``def`` wins).
    """

    def setUp(self):
        app = Flask(__name__)
        # Required by the digest-auth flow, which stores state in the session.
        app.config['SECRET_KEY'] = 'my secret'

        digest_auth_my_realm = HTTPDigestAuth(realm='My Realm')

        @digest_auth_my_realm.get_password
        def get_digest_password_3(username):
            # Known users and their passwords; anyone else is rejected.
            if username == 'susan':
                return 'hello'
            elif username == 'john':
                return 'bye'
            else:
                return None

        @app.route('/')
        def index():
            return 'index'

        @app.route('/digest-with-realm')
        @digest_auth_my_realm.login_required
        def digest_auth_my_realm_route():
            return 'digest_auth_my_realm:' + digest_auth_my_realm.username()

        self.app = app
        self.client = app.test_client()
Esempio n. 31
0
 def __init__(self):
     """Create a Flask app serving the ``api`` blueprint and keep its test client.

     Note that ``self.app`` is the *test client*, not the Flask application.
     """
     flask_app = Flask(__name__)
     flask_app.register_blueprint(api)
     self.app = flask_app.test_client()
Esempio n. 32
0
def client(app: Flask) -> FlaskClient:
    """Return a new test client for *app*.

    There is deliberately no ``with`` block here. Returning from inside
    ``with app.test_client() as c`` exits the client's context before the
    caller receives it, so such a ``with`` would only look like it preserves
    a request context. Callers that want context preservation should write
    ``with client(app) as c:`` themselves.
    """
    return app.test_client()
Esempio n. 33
0
 def setUp(self):
     """Runs before every test case: build the app and store its test client.

     ``self.app`` holds the Flask *test client* with HealthView registered.
     """
     app = Flask(__name__)
     healthcheck.HealthView.register(app)
     app.config['TESTING'] = True
     self.app = app.test_client()
Esempio n. 34
0
 def on_success(service: Flask):
     """Return a new test client for *service*.

     Returning from inside ``with service.test_client() as client`` exits the
     client's context before the caller receives it, so that ``with`` only
     looked like it preserved a request context. The client is returned
     directly; wrap it in ``with`` at the call site if preservation is needed.
     """
     return service.test_client()
class TestEndpointsWithHeadersAndCookies(unittest.TestCase):
    """A protected endpoint reachable with a JWT from cookies or from headers.

    JWT_TOKEN_LOCATION enables both locations and JWT_COOKIE_CSRF_PROTECT is
    on, so the cookie-based test also sends the access CSRF token back in the
    X-CSRF-TOKEN header.
    """
    def setUp(self):
        self.app = Flask(__name__)
        self.app.secret_key = 'super=secret'
        self.app.config['JWT_TOKEN_LOCATION'] = ['cookies', 'headers']
        self.app.config['JWT_COOKIE_CSRF_PROTECT'] = True
        # Path attributes configured for the access and refresh JWT cookies.
        self.app.config['JWT_ACCESS_COOKIE_PATH'] = '/api/'
        self.app.config['JWT_REFRESH_COOKIE_PATH'] = '/auth/refresh'
        self.jwt_manager = JWTManager(self.app)
        self.client = self.app.test_client()

        @self.app.route('/auth/login_cookies', methods=['POST'])
        def login_cookies():
            # Create the tokens we will be sending back to the user
            access_token = create_access_token(identity='test')
            refresh_token = create_refresh_token(identity='test')

            # Set the JWTs and the CSRF double submit protection cookies in this response
            resp = jsonify({'login': True})
            set_access_cookies(resp, access_token)
            set_refresh_cookies(resp, refresh_token)
            return resp, 200

        @self.app.route('/auth/login_headers', methods=['POST'])
        def login_headers():
            # Hand back a fresh access token and a refresh token in the JSON body.
            ret = {
                'access_token': create_access_token('test', fresh=True),
                'refresh_token': create_refresh_token('test')
            }
            return jsonify(ret), 200

        @self.app.route('/api/protected')
        @jwt_required
        def protected():
            return jsonify({'msg': "hello world"})

    def _jwt_post(self, url, jwt):
        """POST to *url* with ``Authorization: Bearer <jwt>``.

        Returns ``(status_code, decoded_json_body)``. Not called by any test in
        this class. (The ``jwt`` parameter shadows any module-level ``jwt``.)
        """
        response = self.client.post(
            url,
            content_type='application/json',
            headers={'Authorization': 'Bearer {}'.format(jwt)})
        status_code = response.status_code
        data = json.loads(response.get_data(as_text=True))
        return status_code, data

    def _jwt_get(self,
                 url,
                 jwt,
                 header_name='Authorization',
                 header_type='Bearer'):
        """GET *url* with header ``<header_name>: <header_type> <jwt>``.

        An empty *header_type* sends the bare token (``strip()`` drops the
        leading space). Returns ``(status_code, decoded_json_body)``. Not
        called by any test in this class.
        """
        # header_type is reused to hold the complete header value.
        header_type = '{} {}'.format(header_type, jwt).strip()
        response = self.client.get(url, headers={header_name: header_type})
        status_code = response.status_code
        data = json.loads(response.get_data(as_text=True))
        return status_code, data

    def _login_cookies(self):
        """Log in via /auth/login_cookies and copy its Set-Cookie values into
        the test client.

        Returns ``(access_csrf, refresh_csrf)``. Each is the CSRF double-submit
        token (value before the first ';'), or ``""`` when
        JWT_COOKIE_CSRF_PROTECT is off.

        NOTE(review): cookies are read by *position* in ``resp.headers``
        (starting at index 1, presumably right after Content-Type). This relies
        on the header order produced by the installed Flask/Werkzeug and
        flask_jwt_extended versions.
        """
        resp = self.client.post('/auth/login_cookies')
        # Position of the next expected Set-Cookie header.
        index = 1

        access_cookie_str = resp.headers[index][1]
        access_cookie_key = access_cookie_str.split('=')[0]
        # NOTE(review): joining with "" drops every '=' after the first (e.g.
        # "Path=/api/" becomes "Path/api/"), and the value keeps the cookie
        # attributes. Left as-is because the tests may depend on how the test
        # client stores this string -- verify before changing.
        access_cookie_value = "".join(access_cookie_str.split('=')[1:])
        self.client.set_cookie('localhost', access_cookie_key,
                               access_cookie_value)
        index += 1

        if self.app.config['JWT_COOKIE_CSRF_PROTECT']:
            access_csrf_str = resp.headers[index][1]
            access_csrf_key = access_csrf_str.split('=')[0]
            access_csrf_value = "".join(access_csrf_str.split('=')[1:])
            self.client.set_cookie('localhost', access_csrf_key,
                                   access_csrf_value)
            index += 1
            # Strip the cookie attributes to get the bare token.
            access_csrf = access_csrf_value.split(';')[0]
        else:
            access_csrf = ""

        refresh_cookie_str = resp.headers[index][1]
        refresh_cookie_key = refresh_cookie_str.split('=')[0]
        refresh_cookie_value = "".join(refresh_cookie_str.split('=')[1:])
        self.client.set_cookie('localhost', refresh_cookie_key,
                               refresh_cookie_value)
        index += 1

        if self.app.config['JWT_COOKIE_CSRF_PROTECT']:
            refresh_csrf_str = resp.headers[index][1]
            refresh_csrf_key = refresh_csrf_str.split('=')[0]
            refresh_csrf_value = "".join(refresh_csrf_str.split('=')[1:])
            self.client.set_cookie('localhost', refresh_csrf_key,
                                   refresh_csrf_value)
            refresh_csrf = refresh_csrf_value.split(';')[0]
        else:
            refresh_csrf = ""

        return access_csrf, refresh_csrf

    def _login_headers(self):
        """Log in via /auth/login_headers; return (access_token, refresh_token)."""
        resp = self.client.post('/auth/login_headers')
        data = json.loads(resp.get_data(as_text=True))
        return data['access_token'], data['refresh_token']

    def test_accessing_endpoint_with_headers(self):
        # A Bearer access token in the Authorization header is accepted.
        access_token, _ = self._login_headers()
        header_type = '{} {}'.format('Bearer', access_token).strip()
        response = self.client.get('/api/protected',
                                   headers={'Authorization': header_type})
        status_code = response.status_code
        data = json.loads(response.get_data(as_text=True))
        self.assertEqual(status_code, 200)
        self.assertEqual(data, {'msg': 'hello world'})

    def test_accessing_endpoint_with_cookies(self):
        # Cookie login plus the access CSRF token in X-CSRF-TOKEN is accepted.
        access_csrf, _ = self._login_cookies()
        response = self.client.get('/api/protected',
                                   headers={'X-CSRF-TOKEN': access_csrf})
        status_code = response.status_code
        data = json.loads(response.get_data(as_text=True))
        self.assertEqual(status_code, 200)
        self.assertEqual(data, {'msg': 'hello world'})

    def test_accessing_endpoint_without_jwt(self):
        # No token anywhere: 401 with an explanatory 'msg'.
        response = self.client.get('/api/protected')
        status_code = response.status_code
        data = json.loads(response.get_data(as_text=True))
        self.assertEqual(status_code, 401)
        self.assertIn('msg', data)
class TestEndpointsWithCookies(unittest.TestCase):
    """Cookie-only JWT flows: login, logout, refresh and CSRF double submit.

    Several tests toggle JWT_COOKIE_CSRF_PROTECT / JWT_CSRF_METHODS at runtime
    on ``self.app.config``.
    """
    def setUp(self):
        self.app = Flask(__name__)
        self.app.secret_key = 'super=secret'
        # NOTE(review): a bare string rather than a list; presumably accepted
        # by the flask_jwt_extended version in use.
        self.app.config['JWT_TOKEN_LOCATION'] = 'cookies'
        self.app.config['JWT_ACCESS_COOKIE_PATH'] = '/api/'
        self.app.config['JWT_REFRESH_COOKIE_PATH'] = '/auth/refresh'
        self.app.config['JWT_ACCESS_COOKIE_NAME'] = 'access_token_cookie'
        self.app.config['JWT_ALGORITHM'] = 'HS256'
        self.jwt_manager = JWTManager(self.app)
        self.client = self.app.test_client()

        @self.app.route('/auth/login', methods=['POST'])
        def login():
            # Create the tokens we will be sending back to the user
            access_token = create_access_token(identity='test')
            refresh_token = create_refresh_token(identity='test')

            # Set the JWTs and the CSRF double submit protection cookies in this response
            resp = jsonify({'login': True})
            set_access_cookies(resp, access_token)
            set_refresh_cookies(resp, refresh_token)
            return resp, 200

        @self.app.route('/auth/logout', methods=['POST'])
        def logout():
            resp = jsonify({'logout': True})
            unset_jwt_cookies(resp)
            return resp, 200

        @self.app.route('/auth/refresh', methods=['POST'])
        @jwt_refresh_token_required
        def refresh():
            # Issue a new, non-fresh access token cookie for the same identity.
            username = get_jwt_identity()
            access_token = create_access_token(username, fresh=False)
            resp = jsonify({'refresh': True})
            set_access_cookies(resp, access_token)
            return resp, 200

        @self.app.route('/api/protected', methods=['POST'])
        @jwt_required
        def protected():
            return jsonify({'msg': "hello world"})

    def _login(self):
        """Log in via /auth/login and copy its Set-Cookie values into the client.

        Returns ``(access_csrf, refresh_csrf)``. Each is the CSRF double-submit
        token (value before the first ';'), or ``""`` when
        JWT_COOKIE_CSRF_PROTECT is off.

        NOTE(review): cookies are read by *position* in ``resp.headers``
        (starting at index 1, presumably after Content-Type), so this depends
        on the header order of the installed library versions.
        """
        resp = self.client.post('/auth/login')
        # Position of the next expected Set-Cookie header.
        index = 1

        access_cookie_str = resp.headers[index][1]
        access_cookie_key = access_cookie_str.split('=')[0]
        # NOTE(review): joining with "" drops every '=' after the first (e.g.
        # "Path=/api/" becomes "Path/api/"), and the value keeps the cookie
        # attributes. Left as-is because tests such as
        # test_custom_csrf_methods may depend on how the test client stores
        # this string -- verify before changing.
        access_cookie_value = "".join(access_cookie_str.split('=')[1:])
        self.client.set_cookie('localhost', access_cookie_key,
                               access_cookie_value)
        index += 1

        if self.app.config['JWT_COOKIE_CSRF_PROTECT']:
            access_csrf_str = resp.headers[index][1]
            access_csrf_key = access_csrf_str.split('=')[0]
            access_csrf_value = "".join(access_csrf_str.split('=')[1:])
            self.client.set_cookie('localhost', access_csrf_key,
                                   access_csrf_value)
            index += 1
            # Strip the cookie attributes to get the bare token.
            access_csrf = access_csrf_value.split(';')[0]
        else:
            access_csrf = ""

        refresh_cookie_str = resp.headers[index][1]
        refresh_cookie_key = refresh_cookie_str.split('=')[0]
        refresh_cookie_value = "".join(refresh_cookie_str.split('=')[1:])
        self.client.set_cookie('localhost', refresh_cookie_key,
                               refresh_cookie_value)
        index += 1

        if self.app.config['JWT_COOKIE_CSRF_PROTECT']:
            refresh_csrf_str = resp.headers[index][1]
            refresh_csrf_key = refresh_csrf_str.split('=')[0]
            refresh_csrf_value = "".join(refresh_csrf_str.split('=')[1:])
            self.client.set_cookie('localhost', refresh_csrf_key,
                                   refresh_csrf_value)
            refresh_csrf = refresh_csrf_value.split(';')[0]
        else:
            refresh_csrf = ""

        return access_csrf, refresh_csrf

    def test_headers(self):
        # Assumes header order: [1] access JWT, [2] access CSRF,
        # [3] refresh JWT, [4] refresh CSRF (positional; see _login).
        # Try with default options
        resp = self.client.post('/auth/login')
        access_cookie = resp.headers[1][1]
        access_csrf = resp.headers[2][1]
        refresh_cookie = resp.headers[3][1]
        refresh_csrf = resp.headers[4][1]
        self.assertIn('access_token_cookie', access_cookie)
        self.assertIn('csrf_access_token', access_csrf)
        self.assertIn('Path=/', access_csrf)
        self.assertIn('refresh_token_cookie', refresh_cookie)
        self.assertIn('csrf_refresh_token', refresh_csrf)
        self.assertIn('Path=/', refresh_csrf)

        # Try with overwritten options
        self.app.config['JWT_ACCESS_COOKIE_NAME'] = 'new_access_cookie'
        self.app.config['JWT_REFRESH_COOKIE_NAME'] = 'new_refresh_cookie'
        self.app.config['JWT_ACCESS_CSRF_COOKIE_NAME'] = 'x_csrf_access_token'
        self.app.config[
            'JWT_REFRESH_CSRF_COOKIE_NAME'] = 'x_csrf_refresh_token'
        self.app.config['JWT_ACCESS_COOKIE_PATH'] = None
        self.app.config['JWT_REFRESH_COOKIE_PATH'] = None

        resp = self.client.post('/auth/login')
        access_cookie = resp.headers[1][1]
        access_csrf = resp.headers[2][1]
        refresh_cookie = resp.headers[3][1]
        refresh_csrf = resp.headers[4][1]
        self.assertIn('new_access_cookie', access_cookie)
        self.assertIn('x_csrf_access_token', access_csrf)
        self.assertIn('Path=/', access_csrf)
        self.assertIn('new_refresh_cookie', refresh_cookie)
        self.assertIn('x_csrf_refresh_token', refresh_csrf)
        self.assertIn('Path=/', refresh_csrf)

        # Try logout headers: expects unset_jwt_cookies to emit the refresh
        # cookie ([1]) before the access cookie ([2]), both already expired.
        resp = self.client.post('/auth/logout')
        refresh_cookie = resp.headers[1][1]
        access_cookie = resp.headers[2][1]
        self.assertIn('Expires=Thu, 01-Jan-1970', refresh_cookie)
        self.assertIn('Expires=Thu, 01-Jan-1970', access_cookie)

    def test_endpoints_with_cookies(self):
        self.app.config['JWT_COOKIE_CSRF_PROTECT'] = False

        # Try access without logging in
        response = self.client.post('/api/protected')
        status_code = response.status_code
        data = json.loads(response.get_data(as_text=True))
        self.assertEqual(status_code, 401)
        self.assertIn('msg', data)

        # Try refresh without logging in
        response = self.client.post('/auth/refresh')
        status_code = response.status_code
        data = json.loads(response.get_data(as_text=True))
        self.assertEqual(status_code, 401)
        self.assertIn('msg', data)

        # Try with logging in
        self._login()
        response = self.client.post('/api/protected')
        status_code = response.status_code
        data = json.loads(response.get_data(as_text=True))
        self.assertEqual(status_code, 200)
        self.assertEqual(data, {'msg': 'hello world'})

        # Try refresh after logging in; the response sets a new access cookie
        response = self.client.post('/auth/refresh')
        access_cookie_str = response.headers[1][1]
        status_code = response.status_code
        data = json.loads(response.get_data(as_text=True))
        self.assertIn('access_token_cookie', access_cookie_str)
        self.assertEqual(status_code, 200)
        self.assertEqual(data, {'refresh': True})

        # Try accessing endpoint with newly refreshed token
        access_cookie_key = access_cookie_str.split('=')[0]
        access_cookie_value = "".join(access_cookie_str.split('=')[1:])
        self.client.set_cookie('localhost', access_cookie_key,
                               access_cookie_value)
        response = self.client.post('/api/protected')
        status_code = response.status_code
        data = json.loads(response.get_data(as_text=True))
        self.assertEqual(status_code, 200)
        self.assertEqual(data, {'msg': 'hello world'})

    def test_access_endpoints_with_cookies_and_csrf(self):
        self.app.config['JWT_COOKIE_CSRF_PROTECT'] = True

        # Try without logging in
        response = self.client.post('/api/protected')
        status_code = response.status_code
        data = json.loads(response.get_data(as_text=True))
        self.assertEqual(status_code, 401)
        self.assertIn('msg', data)

        # Login
        access_csrf, refresh_csrf = self._login()

        # Try with logging in but without double submit csrf protection
        response = self.client.post('/api/protected')
        status_code = response.status_code
        data = json.loads(response.get_data(as_text=True))
        self.assertEqual(status_code, 401)
        self.assertIn('msg', data)

        # Try with logged in and bad header name for double submit token
        response = self.client.post('/api/protected',
                                    headers={'bad-header-name': 'banana'})
        status_code = response.status_code
        data = json.loads(response.get_data(as_text=True))
        self.assertEqual(status_code, 401)
        self.assertIn('msg', data)

        # Try with logged in and bad header data for double submit token
        response = self.client.post('/api/protected',
                                    headers={'X-CSRF-TOKEN': 'banana'})
        status_code = response.status_code
        data = json.loads(response.get_data(as_text=True))
        self.assertEqual(status_code, 401)
        self.assertIn('msg', data)

        # Try with logged in and good double submit token
        response = self.client.post('/api/protected',
                                    headers={'X-CSRF-TOKEN': access_csrf})
        status_code = response.status_code
        data = json.loads(response.get_data(as_text=True))
        self.assertEqual(status_code, 200)
        self.assertEqual(data, {'msg': 'hello world'})

    def test_access_endpoints_with_cookie_missing_csrf_field(self):
        # Test accessing a csrf protected endpoint with a cookie that does not
        # have a csrf token in it: log in with protection off (token minted
        # without a csrf claim), then turn protection on. Expected: 422.
        self.app.config['JWT_COOKIE_CSRF_PROTECT'] = False
        self._login()
        self.app.config['JWT_COOKIE_CSRF_PROTECT'] = True

        response = self.client.post('/api/protected')
        status_code = response.status_code
        data = json.loads(response.get_data(as_text=True))
        self.assertEqual(status_code, 422)
        self.assertIn('msg', data)

    def test_access_endpoints_with_cookie_csrf_claim_not_string(self):
        # Hand-craft a token whose 'csrf' claim is an int (404) and store it
        # in the access cookie.
        # NOTE(review): the token's 'type' is 'refresh' even though it goes in
        # the access cookie, and no X-CSRF-TOKEN header is sent, so the
        # asserted 401 may come from either of those rather than from the
        # non-string claim -- confirm what this test actually pins.
        now = datetime.utcnow()
        token_data = {
            'exp': now + timedelta(minutes=5),
            'iat': now,
            'nbf': now,
            'jti': 'banana',
            'identity': 'banana',
            'type': 'refresh',
            'csrf': 404
        }
        # Presumably the same key flask_jwt_extended signs with when no
        # dedicated JWT secret is configured -- verify.
        secret = self.app.secret_key
        algorithm = self.app.config['JWT_ALGORITHM']
        # NOTE(review): .decode() assumes jwt.encode returns bytes (PyJWT 1.x);
        # PyJWT 2.x returns str and this line would raise AttributeError.
        encoded_token = jwt.encode(token_data, secret,
                                   algorithm).decode('utf-8')
        access_cookie_key = self.app.config['JWT_ACCESS_COOKIE_NAME']
        self.client.set_cookie('localhost', access_cookie_key, encoded_token)

        self.app.config['JWT_COOKIE_CSRF_PROTECT'] = True
        response = self.client.post('/api/protected')
        status_code = response.status_code
        data = json.loads(response.get_data(as_text=True))
        self.assertEqual(status_code, 401)
        self.assertIn('msg', data)

    def test_custom_csrf_methods(self):
        @self.app.route('/protected-post', methods=['POST'])
        @jwt_required
        def protected_post():
            return jsonify({'msg': "hello world"})

        @self.app.route('/protected-get', methods=['GET'])
        @jwt_required
        def protected_get():
            return jsonify({'msg': "hello world"})

        # Login (saves the JWTs in the test client's cookies)
        self.app.config['JWT_COOKIE_CSRF_PROTECT'] = True
        self._login()

        # Test being able to access GET without CSRF protection, and POST with
        # CSRF protection
        self.app.config['JWT_CSRF_METHODS'] = ['POST']

        response = self.client.post('/protected-post')
        status_code = response.status_code
        data = json.loads(response.get_data(as_text=True))
        self.assertEqual(status_code, 401)
        self.assertIn('msg', data)

        response = self.client.get('/protected-get')
        status_code = response.status_code
        data = json.loads(response.get_data(as_text=True))
        self.assertEqual(status_code, 200)
        self.assertEqual(data, {'msg': 'hello world'})

        # Now swap it around, and verify the JWT_CSRF_METHODS are being honored
        self.app.config['JWT_CSRF_METHODS'] = ['GET']

        response = self.client.get('/protected-get')
        status_code = response.status_code
        data = json.loads(response.get_data(as_text=True))
        self.assertEqual(status_code, 401)
        self.assertIn('msg', data)

        response = self.client.post('/protected-post')
        status_code = response.status_code
        data = json.loads(response.get_data(as_text=True))
        self.assertEqual(status_code, 200)
        self.assertEqual(data, {'msg': 'hello world'})
Esempio n. 37
0
class LoginTestCase(unittest.TestCase):
    ''' Tests for results of the login_user function '''
    def setUp(self):
        """Build a Flask app wired to a LoginManager plus helper routes.

        Each route wraps one flask-login operation so tests can drive it
        through the test client. A 404 handler re-raises, so a mistyped URL
        in a test fails loudly instead of quietly returning a 404 response.
        """
        self.app = Flask(__name__)
        self.app.config['SECRET_KEY'] = 'deterministic'
        self.app.config['SESSION_PROTECTION'] = None
        self.app.config['TESTING'] = True
        self.remember_cookie_name = 'remember'
        self.app.config['REMEMBER_COOKIE_NAME'] = self.remember_cookie_name
        self.login_manager = LoginManager()
        self.login_manager.init_app(self.app)
        self.login_manager._login_disabled = False

        @self.app.route('/')
        def index():
            return u'Welcome!'

        @self.app.route('/secret')
        def secret():
            return self.login_manager.unauthorized()

        @self.app.route('/login-notch')
        def login_notch():
            return unicode(login_user(notch))

        @self.app.route('/login-notch-remember')
        def login_notch_remember():
            return unicode(login_user(notch, remember=True))

        @self.app.route('/login-notch-permanent')
        def login_notch_permanent():
            session.permanent = True
            return unicode(login_user(notch))

        @self.app.route('/needs-refresh')
        def needs_refresh():
            return self.login_manager.needs_refresh()

        @self.app.route('/confirm-login')
        def _confirm_login():
            confirm_login()
            return u''

        @self.app.route('/username')
        def username():
            if current_user.is_authenticated():
                return current_user.name
            return u'Anonymous'

        @self.app.route('/is-fresh')
        def is_fresh():
            return unicode(login_fresh())

        @self.app.route('/logout')
        def logout():
            return unicode(logout_user())

        @self.login_manager.user_loader
        def load_user(user_id):
            return USERS[int(user_id)]

        @self.login_manager.header_loader
        def load_user_from_header(header_value):
            # Accept "Basic <base64 id>" as well as a bare base64 id.
            if header_value.startswith('Basic '):
                header_value = header_value.replace('Basic ', '', 1)
            try:
                user_id = int(base64.b64decode(header_value))
            except (TypeError, ValueError):
                # Bad base64 (TypeError on Python 2; binascii.Error, a
                # ValueError subclass, on Python 3) or a non-integer payload.
                # Previously this fell through with user_id unbound.
                return None
            return USERS.get(user_id)

        @self.login_manager.request_loader
        def load_user_from_request(request):
            user_id = request.args.get('user_id')
            try:
                user_id = int(float(user_id))
            except (TypeError, ValueError):
                # Missing (None) or non-numeric ?user_id=: no user.
                return None
            return USERS.get(user_id)

        @self.app.route('/empty_session')
        def empty_session():
            return unicode(u'modified=%s' % session.modified)

        # This will help us with the possibility of typos in the tests. Now
        # we shouldn't have to check each response to help us set up state
        # (such as login pages) to make sure it worked: we will always
        # get an exception raised (rather than return a 404 response)
        @self.app.errorhandler(404)
        def handle_404(e):
            raise e

        unittest.TestCase.setUp(self)

    def _get_remember_cookie(self, test_client):
        our_cookies = test_client.cookie_jar._cookies['localhost.local']['/']
        return our_cookies[self.remember_cookie_name]

    def _delete_session(self, c):
        # Helper method to cause the session to be deleted
        # as if the browser was closed. This will remove
        # the session regardless of the permament flag
        # on the session!
        with c.session_transaction() as sess:
            sess.clear()

    #
    # Login
    #
    def test_test_request_context_users_are_anonymous(self):
        """A bare test_request_context has an anonymous current_user."""
        with self.app.test_request_context():
            is_anon = current_user.is_anonymous()
            self.assertTrue(is_anon)

    def test_defaults_anonymous(self):
        """With nobody logged in, /username reports 'Anonymous'."""
        with self.app.test_client() as client:
            body = client.get('/username').data.decode('utf-8')
            self.assertEqual(u'Anonymous', body)

    def test_login_user(self):
        """login_user succeeds for an active user and sets current_user."""
        with self.app.test_request_context():
            logged_in = login_user(notch)
            self.assertTrue(logged_in)
            self.assertEqual(current_user.name, u'Notch')

    def test_login_user_emits_signal(self):
        """login_user fires user_logged_in once, carrying the user."""
        with self.app.test_request_context(), \
                listen_to(user_logged_in) as listener:
            login_user(notch)
            listener.assert_heard_one(self.app, user=notch)

    def test_login_inactive_user(self):
        """login_user refuses an inactive user, who stays anonymous."""
        with self.app.test_request_context():
            outcome = login_user(creeper)
            self.assertTrue(current_user.is_anonymous())
            self.assertFalse(outcome)

    def test_login_inactive_user_forced(self):
        """force=True logs in even an inactive user."""
        with self.app.test_request_context():
            login_user(creeper, force=True)
            logged_in_name = current_user.name
            self.assertEqual(logged_in_name, u'Creeper')

    def test_login_user_with_header(self):
        """A valid base64 id in the Authorization header logs that user in."""
        user_id = 2
        user_name = USERS[user_id].name
        # Clear the request-loader callback so the header loader is exercised.
        self.login_manager.request_callback = None
        encoded_id = base64.b64encode(str(user_id).encode()).decode()
        headers = [('Authorization', 'Basic {0}'.format(encoded_id))]
        with self.app.test_client() as client:
            response = client.get('/username', headers=headers)
            self.assertEqual(user_name, response.data.decode('utf-8'))

    def test_login_invalid_user_with_header(self):
        """An unknown id in the Authorization header leaves the user anonymous."""
        user_id = 4
        user_name = u'Anonymous'
        # Clear the request-loader callback so the header loader is exercised.
        self.login_manager.request_callback = None
        encoded_id = base64.b64encode(str(user_id).encode()).decode()
        headers = [('Authorization', 'Basic {0}'.format(encoded_id))]
        with self.app.test_client() as client:
            response = client.get('/username', headers=headers)
            self.assertEqual(user_name, response.data.decode('utf-8'))

    def test_login_user_with_request(self):
        """A valid ?user_id= query argument logs that user in."""
        user_id = 2
        expected = USERS[user_id].name
        url = '/username?user_id={user_id}'.format(user_id=user_id)
        with self.app.test_client() as client:
            body = client.get(url).data.decode('utf-8')
            self.assertEqual(expected, body)

    def test_login_invalid_user_with_request(self):
        """An unknown ?user_id= query argument leaves the user anonymous."""
        user_id = 4
        expected = u'Anonymous'
        url = '/username?user_id={user_id}'.format(user_id=user_id)
        with self.app.test_client() as client:
            body = client.get(url).data.decode('utf-8')
            self.assertEqual(expected, body)

    #
    # Logout
    #
    def test_logout_logs_out_current_user(self):
        """After logout_user, current_user is anonymous again."""
        with self.app.test_request_context():
            login_user(notch)
            logout_user()
            still_anonymous = current_user.is_anonymous()
            self.assertTrue(still_anonymous)

    def test_logout_emits_signal(self):
        """logout_user fires user_logged_out once, carrying the user."""
        with self.app.test_request_context():
            login_user(notch)
            # Start listening only after login, so just the logout is heard.
            signal_listener = listen_to(user_logged_out)
            with signal_listener as listener:
                logout_user()
                listener.assert_heard_one(self.app, user=notch)

    #
    # Unauthorized
    #
    def test_unauthorized_fires_unauthorized_signal(self):
        """Hitting a view that calls unauthorized() fires user_unauthorized."""
        with self.app.test_client() as client, \
                listen_to(user_unauthorized) as listener:
            client.get('/secret')
            listener.assert_heard_one(self.app)

    def test_unauthorized_flashes_message_with_login_view(self):
        """With a login_view set, unauthorized() flashes login_message."""
        self.login_manager.login_view = '/login'

        expected_message = u'Log in!'
        expected_category = 'login'
        self.login_manager.login_message = expected_message
        self.login_manager.login_message_category = expected_category

        with self.app.test_client() as client:
            client.get('/secret')
            # Read inside the block, while the client still holds the last
            # request's context.
            flashed = get_flashed_messages(category_filter=[expected_category])
            self.assertEqual([expected_message], flashed)

    def test_unauthorized_flash_message_localized(self):
        """localize_callback is applied to login_message before flashing."""
        translations = {u'Log in!': u'Einloggen'}

        def _gettext(msg):
            # Unknown messages translate to None.
            return translations.get(msg)

        self.login_manager.login_view = '/login'
        self.login_manager.localize_callback = _gettext
        self.login_manager.login_message = u'Log in!'

        expected_message = u'Einloggen'
        expected_category = 'login'
        self.login_manager.login_message_category = expected_category

        with self.app.test_client() as client:
            client.get('/secret')
            flashed = get_flashed_messages(category_filter=[expected_category])
            self.assertEqual([expected_message], flashed)
        self.login_manager.localize_callback = None

    def test_unauthorized_uses_authorized_handler(self):
        """A registered unauthorized_handler supplies the response."""
        @self.login_manager.unauthorized_handler
        def _callback():
            return Response('This is secret!', 401)

        with self.app.test_client() as client:
            response = client.get('/secret')
            self.assertEqual(response.status_code, 401)
            body = response.data.decode('utf-8')
            self.assertEqual(u'This is secret!', body)

    def test_unauthorized_aborts_with_401(self):
        """With no login_view or handler configured, unauthorized() gives 401."""
        with self.app.test_client() as client:
            status = client.get('/secret').status_code
            self.assertEqual(status, 401)

    def test_unauthorized_redirects_to_login_view(self):
        """With login_view set, unauthorized() redirects there with ?next=."""
        self.login_manager.login_view = 'login'

        # The endpoint name 'login' comes from this function's name.
        @self.app.route('/login')
        def login():
            return 'Login Form Goes Here!'

        expected_location = 'http://localhost/login?next=%2Fsecret'
        with self.app.test_client() as client:
            response = client.get('/secret')
            self.assertEqual(response.status_code, 302)
            self.assertEqual(response.location, expected_location)

    #
    # Session Persistence/Freshness
    #
    def test_login_persists(self):
        """A login made in one request is still in effect on the next one."""
        with self.app.test_client() as client:
            client.get('/login-notch')
            body = client.get('/username').data.decode('utf-8')
            self.assertEqual(u'Notch', body)

    def test_logout_persists(self):
        """A logout made in one request is still in effect on the next one."""
        with self.app.test_client() as client:
            for path in ('/login-notch', '/logout'):
                client.get(path)
            body = client.get('/username').data.decode('utf-8')
            self.assertEqual(body, u'Anonymous')

    def test_incorrect_id_logs_out(self):
        """If the stored id no longer resolves, the user is treated as anonymous."""
        # Replace the loader with one that always returns None, so every
        # attempt to reload the user by id fails.
        @self.login_manager.user_loader
        def new_user_loader(user_id):
            return

        with self.app.test_client() as client:
            # The login itself succeeds...
            client.get('/login-notch')
            # ...but reloading the user on the next request does not.
            body = client.get('/username').data.decode('utf-8')
            self.assertEqual(u'Anonymous', body)

    def test_authentication_is_fresh(self):
        """A direct login (even with remember=True) is reported as fresh."""
        with self.app.test_client() as client:
            client.get('/login-notch-remember')
            freshness = client.get('/is-fresh').data.decode('utf-8')
            self.assertEqual(u'True', freshness)

    def test_remember_me(self):
        """The remember cookie restores the user after the session is wiped."""
        with self.app.test_client() as client:
            client.get('/login-notch-remember')
            self._delete_session(client)
            body = client.get('/username').data.decode('utf-8')
            self.assertEqual(u'Notch', body)

    def test_remember_me_uses_custom_cookie_parameters(self):
        """REMEMBER_COOKIE_NAME/DURATION/DOMAIN shape the remember cookie."""
        config = self.app.config
        name = 'myname'
        duration = timedelta(days=2)
        domain = '.localhost.local'
        config['REMEMBER_COOKIE_NAME'] = name
        config['REMEMBER_COOKIE_DURATION'] = duration
        config['REMEMBER_COOKIE_DOMAIN'] = domain

        with self.app.test_client() as client:
            client.get('/login-notch-remember')

            # TODO: Is there a better way to test this? This reads the private
            # cookie_jar._cookies mapping (domain -> path -> name -> cookie).
            jar = client.cookie_jar._cookies
            self.assertTrue(domain in jar,
                            'Custom domain not found as cookie domain')
            cookies_at_root = jar[domain]['/']
            self.assertTrue(name in cookies_at_root,
                            'Custom name not found as cookie name')
            cookie = cookies_at_root[name]

            expiration_date = datetime.fromtimestamp(cookie.expires)
            expected_date = datetime.now() + duration
            difference = expected_date - expiration_date

            fail_msg = ('The expiration date {0} was far from the expected {1}'
                        .format(expiration_date, expected_date))
            # Allow +/-10 seconds of slack for time spent handling the request.
            self.assertLess(difference, timedelta(seconds=10), fail_msg)
            self.assertGreater(difference, timedelta(seconds=-10), fail_msg)

    def test_remember_me_is_unfresh(self):
        """A user restored from the remember cookie is not fresh."""
        with self.app.test_client() as client:
            client.get('/login-notch-remember')
            self._delete_session(client)
            freshness = client.get('/is-fresh').data.decode('utf-8')
            self.assertEqual(u'False', freshness)

    def test_user_loaded_from_cookie_fired(self):
        """Reloading the user from the remember cookie emits its signal once."""
        with self.app.test_client() as client:
            client.get('/login-notch-remember')
            self._delete_session(client)
            with listen_to(user_loaded_from_cookie) as spy:
                client.get('/username')
                spy.assert_heard_one(self.app, user=notch)

    def test_user_loaded_from_header_fired(self):
        """Authenticating via an Authorization header emits the header signal."""
        user_id = 1
        user_name = USERS[user_id].name
        self.login_manager.request_callback = None

        # Basic credentials carrying just the base64-encoded user id.
        credentials = base64.b64encode(str(user_id).encode()).decode()
        headers = [('Authorization', 'Basic %s' % credentials)]

        with self.app.test_client() as client:
            with listen_to(user_loaded_from_header) as spy:
                response = client.get('/username', headers=headers)
                self.assertEqual(user_name, response.data.decode('utf-8'))
                spy.assert_heard_one(self.app, user=USERS[user_id])

    def test_user_loaded_from_request_fired(self):
        """Loading the user from a query argument emits the request signal."""
        user_id = 1
        user_name = USERS[user_id].name
        url = '/username?user_id={user_id}'.format(user_id=user_id)

        with self.app.test_client() as client:
            with listen_to(user_loaded_from_request) as spy:
                response = client.get(url)
                self.assertEqual(user_name, response.data.decode('utf-8'))
                spy.assert_heard_one(self.app, user=USERS[user_id])

    def test_logout_stays_logged_out_with_remember_me(self):
        """Logging out after a remember-me login leaves the client anonymous."""
        with self.app.test_client() as client:
            client.get('/login-notch-remember')
            client.get('/logout')
            body = client.get('/username').data.decode('utf-8')
            self.assertEqual(body, u'Anonymous')

    def test_needs_refresh_uses_handler(self):
        """A registered needs_refresh_handler supplies the response body."""
        @self.login_manager.needs_refresh_handler
        def _on_refresh():
            return u'Needs Refresh!'

        with self.app.test_client() as client:
            client.get('/login-notch-remember')
            body = client.get('/needs-refresh').data.decode('utf-8')
            self.assertEqual(u'Needs Refresh!', body)

    def test_needs_refresh_fires_needs_refresh_signal(self):
        """Requesting the needs-refresh view emits user_needs_refresh once."""
        with self.app.test_client() as client:
            client.get('/login-notch-remember')
            with listen_to(user_needs_refresh) as spy:
                client.get('/needs-refresh')
                spy.assert_heard_one(self.app)

    def test_needs_refresh_fires_flash_when_redirect_to_refresh_view(self):
        """Redirecting to refresh_view flashes needs_refresh_message."""
        manager = self.login_manager
        manager.refresh_view = '/refresh_view'
        manager.needs_refresh_message = u'Refresh'
        manager.needs_refresh_message_category = 'refresh'
        categories = [manager.needs_refresh_message_category]

        with self.app.test_client() as client:
            client.get('/login-notch-remember')
            client.get('/needs-refresh')
            flashed = get_flashed_messages(category_filter=categories)
            self.assertIn(manager.needs_refresh_message, flashed)

    def test_needs_refresh_flash_message_localized(self):
        """The needs-refresh flash message goes through localize_callback.

        ``localize_callback`` is reset to ``None`` in a ``finally`` clause so
        that a failing assertion does not leave the stub translator installed.
        This matters only if the login manager outlives this test, which is
        not visible here.
        """
        def _gettext(msg):
            # Minimal translation stub: only 'Refresh' has a translation.
            # Unknown messages fall back to themselves, as gettext does.
            if msg == u'Refresh':
                return u'Aktualisieren'
            return msg

        self.login_manager.refresh_view = '/refresh_view'
        self.login_manager.localize_callback = _gettext
        try:
            self.login_manager.needs_refresh_message = u'Refresh'
            self.login_manager.needs_refresh_message_category = 'refresh'
            category_filter = [
                self.login_manager.needs_refresh_message_category]

            with self.app.test_client() as c:
                c.get('/login-notch-remember')
                c.get('/needs-refresh')
                msgs = get_flashed_messages(category_filter=category_filter)
                self.assertIn(u'Aktualisieren', msgs)
        finally:
            self.login_manager.localize_callback = None

    def test_needs_refresh_aborts_403(self):
        """With no refresh handler/view set by this test, the view returns 403."""
        with self.app.test_client() as client:
            client.get('/login-notch-remember')
            status = client.get('/needs-refresh').status_code
            self.assertEqual(status, 403)

    def test_redirects_to_refresh_view(self):
        """needs-refresh redirects to refresh_view with a ``next`` argument."""
        # The endpoint name comes from the function name, and refresh_view
        # below refers to it, so the function must keep this name.
        @self.app.route('/refresh-view')
        def refresh_view():
            return ''

        self.login_manager.refresh_view = 'refresh_view'
        expected = 'http://localhost/refresh-view?next=%2Fneeds-refresh'
        with self.app.test_client() as client:
            client.get('/login-notch-remember')
            response = client.get('/needs-refresh')
            self.assertEqual(response.status_code, 302)
            self.assertEqual(response.location, expected)

    def test_confirm_login(self):
        """confirm-login turns a stale, cookie-restored login fresh again."""
        with self.app.test_client() as client:
            client.get('/login-notch-remember')
            self._delete_session(client)
            before = client.get('/is-fresh').data.decode('utf-8')
            self.assertEqual(u'False', before)

            client.get('/confirm-login')
            after = client.get('/is-fresh').data.decode('utf-8')
            self.assertEqual(u'True', after)

    def test_user_login_confirmed_signal_fired(self):
        """The confirm-login view emits user_login_confirmed exactly once."""
        with self.app.test_client() as client:
            with listen_to(user_login_confirmed) as spy:
                client.get('/confirm-login')
                spy.assert_heard_one(self.app)

    def test_session_not_modified(self):
        """An untouched session is not flagged as modified after a request.

        Uses ``assertEqual``; the ``assertEquals`` alias is deprecated and
        was removed in Python 3.12.
        """
        with self.app.test_client() as c:
            # Within the request we think we didn't modify the session.
            self.assertEqual(u'modified=False',
                             c.get('/empty_session').data.decode('utf-8'))
            # But after the request, the session could be modified by the
            # "after_request" handlers that call _update_remember_cookie.
            # Ensure that if nothing changed the session is not modified.
            self.assertFalse(session.modified)

    #
    # Session Protection
    #
    def test_session_protection_basic_passes_successive_requests(self):
        """Under basic protection, same-client requests stay logged in and fresh."""
        self.app.config['SESSION_PROTECTION'] = 'basic'
        with self.app.test_client() as client:
            client.get('/login-notch-remember')
            username = client.get('/username').data.decode('utf-8')
            self.assertEqual(u'Notch', username)
            freshness = client.get('/is-fresh').data.decode('utf-8')
            self.assertEqual(u'True', freshness)

    def test_session_protection_strong_passes_successive_requests(self):
        """Under strong protection, same-client requests stay logged in and fresh."""
        self.app.config['SESSION_PROTECTION'] = 'strong'
        with self.app.test_client() as client:
            client.get('/login-notch-remember')
            username = client.get('/username').data.decode('utf-8')
            self.assertEqual(u'Notch', username)
            freshness = client.get('/is-fresh').data.decode('utf-8')
            self.assertEqual(u'True', freshness)

    def test_session_protection_basic_marks_session_unfresh(self):
        """Basic protection keeps the user but drops freshness on a UA change."""
        self.app.config['SESSION_PROTECTION'] = 'basic'
        other_agent = [('User-Agent', 'different')]
        with self.app.test_client() as client:
            client.get('/login-notch-remember')
            username = client.get('/username', headers=other_agent)
            self.assertEqual(u'Notch', username.data.decode('utf-8'))
            freshness = client.get('/is-fresh')
            self.assertEqual(u'False', freshness.data.decode('utf-8'))

    def test_session_protection_basic_fires_signal(self):
        """Basic protection emits session_protected when the UA changes."""
        self.app.config['SESSION_PROTECTION'] = 'basic'
        other_agent = [('User-Agent', 'different')]

        with self.app.test_client() as client:
            client.get('/login-notch-remember')
            with listen_to(session_protected) as spy:
                client.get('/username', headers=other_agent)
                spy.assert_heard_one(self.app)

    def test_session_protection_basic_skips_when_remember_me(self):
        """Basic protection is not triggered for a remember-me restore."""
        self.app.config['SESSION_PROTECTION'] = 'basic'

        with self.app.test_client() as client:
            client.get('/login-notch-remember')
            # Wipe the session, including the stored session id, so the next
            # request has to fall back to the remember-me cookie.
            self._delete_session(client)
            # With an empty session there is nothing to protect.
            with listen_to(session_protected) as spy:
                client.get('/username')
                spy.assert_heard_none(self.app)

    def test_session_protection_strong_skips_when_remember_me(self):
        """Strong protection is not triggered for a remember-me restore."""
        self.app.config['SESSION_PROTECTION'] = 'strong'

        with self.app.test_client() as client:
            client.get('/login-notch-remember')
            # Wipe the session, including the stored session id, so the next
            # request has to fall back to the remember-me cookie.
            self._delete_session(client)
            # With an empty session there is nothing to protect.
            with listen_to(session_protected) as spy:
                client.get('/username')
                spy.assert_heard_none(self.app)

    def test_permanent_strong_session_protection_marks_session_unfresh(self):
        """For a permanent session, strong protection keeps the user but unfreshes."""
        self.app.config['SESSION_PROTECTION'] = 'strong'
        other_agent = [('User-Agent', 'different')]
        with self.app.test_client() as client:
            client.get('/login-notch-permanent')
            username = client.get('/username', headers=other_agent)
            self.assertEqual(u'Notch', username.data.decode('utf-8'))
            freshness = client.get('/is-fresh')
            self.assertEqual(u'False', freshness.data.decode('utf-8'))

    def test_permanent_strong_session_protection_fires_signal(self):
        """For a permanent session, strong protection emits session_protected."""
        self.app.config['SESSION_PROTECTION'] = 'strong'
        other_agent = [('User-Agent', 'different')]

        with self.app.test_client() as client:
            client.get('/login-notch-permanent')
            with listen_to(session_protected) as spy:
                client.get('/username', headers=other_agent)
                spy.assert_heard_one(self.app)

    def test_session_protection_strong_deletes_session(self):
        """Strong protection leaves the client anonymous after a UA change."""
        self.app.config['SESSION_PROTECTION'] = 'strong'
        other_agent = [('User-Agent', 'different')]
        with self.app.test_client() as client:
            client.get('/login-notch-remember')
            response = client.get('/username', headers=other_agent)
            self.assertEqual(u'Anonymous', response.data.decode('utf-8'))

    def test_session_protection_strong_fires_signal_user_agent(self):
        """Strong protection emits session_protected when the UA changes."""
        self.app.config['SESSION_PROTECTION'] = 'strong'
        other_agent = [('User-Agent', 'different')]

        with self.app.test_client() as client:
            client.get('/login-notch-remember')
            with listen_to(session_protected) as spy:
                client.get('/username', headers=other_agent)
                spy.assert_heard_one(self.app)

    def test_session_protection_strong_fires_signal_x_forwarded_for(self):
        """Strong protection emits session_protected when X-Forwarded-For changes."""
        self.app.config['SESSION_PROTECTION'] = 'strong'
        first_hop = [('X-Forwarded-For', '10.1.1.1')]
        second_hop = [('X-Forwarded-For', '10.1.1.2')]

        with self.app.test_client() as client:
            client.get('/login-notch-remember', headers=first_hop)
            with listen_to(session_protected) as spy:
                client.get('/username', headers=second_hop)
                spy.assert_heard_one(self.app)

    def test_session_protection_skip_when_off_and_anonymous(self):
        """Anonymous access fires no protection signal and writes no session.

        SESSION_PROTECTION is left at the fixture's default, which the test
        name implies is off.
        """
        with self.app.test_client() as client:
            # The index view does no user access.
            with listen_to(user_accessed) as access_spy:
                client.get('/')
                access_spy.assert_heard_none(self.app)

            # Reading the user with no session data must not trip protection.
            with listen_to(session_protected) as protect_spy:
                response = client.get('/username')
                self.assertEqual(response.data.decode('utf-8'), u'Anonymous')
                protect_spy.assert_heard_none(self.app)

            # Nothing should have been written to the session.
            self.assertFalse(session)

    def test_session_protection_skip_when_basic_and_anonymous(self):
        """Basic protection stays silent for anonymous access.

        Afterwards the session must contain exactly one key, ``_id``, with
        a non-None value.
        """
        self.app.config['SESSION_PROTECTION'] = 'basic'

        with self.app.test_client() as c:
            # no user access
            with listen_to(user_accessed) as user_listener:
                c.get('/')
                user_listener.assert_heard_none(self.app)

            # access user with no session data
            with listen_to(session_protected) as session_listener:
                results = c.get('/username')
                self.assertEqual(results.data.decode('utf-8'), u'Anonymous')
                session_listener.assert_heard_none(self.app)

            # verify no session data has been set other than '_id'
            self.assertIsNotNone(session.get('_id'))
            # Comparing the key set reports the unexpected keys on failure,
            # unlike assertTrue(len(session) == 1).
            self.assertEqual(set(session.keys()), {'_id'})

    #
    # Custom Token Loader
    #
    def test_custom_token_loader(self):
        """Remember-me works through a custom token_loader and cookie token."""
        @self.login_manager.token_loader
        def load_token(token):
            return USER_TOKENS.get(token)

        with self.app.test_client() as client:
            client.get('/login-notch-remember')
            self._delete_session(client)

            # Remember-me still restores the user...
            username = client.get('/username').data.decode('utf-8')
            self.assertEqual(u'Notch', username)

            # ...and the cookie carries the custom authentication token.
            remember_cookie = self._get_remember_cookie(client)
            expected_value = make_secure_token(u'Notch', key='deterministic')
            self.assertEqual(expected_value, remember_cookie.value)

    def test_change_api_key_with_token_loader(self):
        """With a token_loader, remember-me survives a SECRET_KEY change."""
        @self.login_manager.token_loader
        def load_token(token):
            return USER_TOKENS.get(token)

        with self.app.test_client() as client:
            client.get('/login-notch-remember')
            self._delete_session(client)
            self.app.config['SECRET_KEY'] = 'ima change this now'

            username = client.get('/username').data.decode('utf-8')
            self.assertEqual(username, u'Notch')

    def test_custom_token_loader_with_no_user(self):
        """A token_loader that resolves nothing leaves the client anonymous."""
        @self.login_manager.token_loader
        def load_token(token):
            return None

        with self.app.test_client() as client:
            client.get('/login-notch-remember')
            self._delete_session(client)

            username = client.get('/username').data.decode('utf-8')
            self.assertEqual(username, u'Anonymous')

    #
    # Lazy Access User
    #
    def test_requests_without_accessing_session(self):
        """The user is loaded lazily, only by views that access it."""
        with self.app.test_client() as client:
            client.get('/login-notch')

            # The index view does not touch the user: no access signal.
            with listen_to(user_accessed) as spy:
                client.get('/')
                spy.assert_heard_none(self.app)

            # The username view reads the user exactly once.
            with listen_to(user_accessed) as spy:
                response = client.get('/username')
                spy.assert_heard_one(self.app)
                self.assertEqual(response.data.decode('utf-8'), u'Notch')

    #
    # View Decorators
    #
    def test_login_required_decorator(self):
        """login_required returns 401 when anonymous and serves after login."""
        @self.app.route('/protected')
        @login_required
        def protected():
            return u'Access Granted'

        with self.app.test_client() as client:
            anonymous = client.get('/protected')
            self.assertEqual(anonymous.status_code, 401)

            client.get('/login-notch')
            authorised = client.get('/protected')
            self.assertIn(u'Access Granted', authorised.data.decode('utf-8'))

    def test_decorators_are_disabled(self):
        """With _login_disabled set, both login decorators let requests through."""
        @self.app.route('/protected')
        @login_required
        @fresh_login_required
        def protected():
            return u'Access Granted'

        self.app.login_manager._login_disabled = True

        with self.app.test_client() as client:
            body = client.get('/protected').data.decode('utf-8')
            self.assertIn(u'Access Granted', body)

    def test_fresh_login_required_decorator(self):
        """fresh_login_required: 401 anonymous, 403 stale, allowed when fresh."""
        @self.app.route('/very-protected')
        @fresh_login_required
        def very_protected():
            return 'Access Granted'

        with self.app.test_client() as client:
            # Anonymous: unauthorised.
            anonymous = client.get('/very-protected')
            self.assertEqual(anonymous.status_code, 401)

            # Fresh login: allowed.
            client.get('/login-notch-remember')
            fresh = client.get('/very-protected')
            self.assertEqual(u'Access Granted', fresh.data.decode('utf-8'))

            # Session dropped, user restored from the cookie: stale, so 403.
            self._delete_session(client)
            stale = client.get('/very-protected')
            self.assertEqual(stale.status_code, 403)

            # Confirming the login makes it fresh again.
            client.get('/confirm-login')
            refreshed = client.get('/very-protected')
            self.assertEqual(u'Access Granted',
                             refreshed.data.decode('utf-8'))

    #
    # Misc
    #
    @unittest.skipIf(werkzeug_version.startswith("0.9"),
                     "wait for upstream implementing RFC 5987")
    def test_chinese_user_agent(self):
        """A non-ASCII (Chinese) User-Agent header does not break requests."""
        chinese_agent = [('User-Agent', u'中文')]
        with self.app.test_client() as client:
            response = client.get('/', headers=chinese_agent)
            self.assertEqual(u'Welcome!', response.data.decode('utf-8'))

    @unittest.skipIf(werkzeug_version.startswith("0.9"),
                     "wait for upstream implementing RFC 5987")
    def test_russian_cp1251_user_agent(self):
        """A cp1251-encoded (bytes) User-Agent header does not break requests."""
        with self.app.test_client() as client:
            agent = u'ЯЙЮя'.encode('cp1251')
            body = client.get('/', headers=[('User-Agent', agent)]).data
            self.assertEqual(body.decode('utf-8'), u'Welcome!')

    def test_make_secure_token_default_key(self):
        """make_secure_token with no explicit key yields a known digest.

        It runs inside a request context; presumably the key comes from the
        app's configuration (not visible here).
        """
        expected = '0f05743a2b617b2625362ab667c0dbdf4c9ec13a'
        with self.app.test_request_context():
            self.assertEqual(make_secure_token('foo'), expected)

    def test_user_context_processor(self):
        """The user context processor exposes an anonymous current_user."""
        with self.app.test_request_context():
            processor = self.app.context_processor(_user_context_processor)
            context = processor()
            self.assertIsInstance(context['current_user'], AnonymousUserMixin)
def client(request):
    """Yield a test client for a host app with ``app`` mounted at ``request.param``.

    ``request.param`` is used as the blueprint's URL prefix; presumably this
    is a parametrized pytest fixture (its decorator is not visible here).
    """
    host_app = Flask(__name__)
    host_app.register_blueprint(app, url_prefix=request.param)
    with host_app.test_client() as test_client:
        yield test_client
Esempio n. 39
0
    def setUp(self):
        """Build a Flask app with one route per HTTP auth configuration.

        Covers: plain basic auth; basic auth with a custom scheme/realm and
        password hashing; basic auth with a custom hash; basic auth via
        ``verify_password``; digest auth; digest auth with a custom
        scheme/realm; and digest auth with pre-hashed HA1 passwords. The app,
        every auth object and a test client are stored on ``self``.
        """
        app = Flask(__name__)
        app.config['SECRET_KEY'] = 'my secret'

        basic_auth = HTTPBasicAuth()
        basic_auth_my_realm = HTTPBasicAuth(scheme='CustomBasic',
                                            realm='My Realm')
        basic_custom_auth = HTTPBasicAuth()
        basic_verify_auth = HTTPBasicAuth()
        digest_auth = HTTPDigestAuth()
        digest_auth_my_realm = HTTPDigestAuth(scheme='CustomDigest',
                                              realm='My Realm')
        digest_auth_ha1_pw = HTTPDigestAuth(use_ha1_pw=True)

        # With use_ha1_pw=True, get_password returns a precomputed HA1 value
        # rather than the clear-text password.
        @digest_auth_ha1_pw.get_password
        def get_digest_password(username):
            if username == 'susan':
                return get_ha1(username, 'hello', digest_auth_ha1_pw.realm)
            elif username == 'john':
                return get_ha1(username, 'bye', digest_auth_ha1_pw.realm)
            else:
                return None

        @basic_auth.get_password
        def get_basic_password(username):
            if username == 'john':
                return 'hello'
            elif username == 'susan':
                return 'bye'
            else:
                return None

        # Stored passwords are username + password, so the hash_password
        # callback below (which concatenates the same way) makes them match.
        @basic_auth_my_realm.get_password
        def get_basic_password_2(username):
            if username == 'john':
                return 'johnhello'
            elif username == 'susan':
                return 'susanbye'
            else:
                return None

        @basic_auth_my_realm.hash_password
        def basic_auth_my_realm_hash_password(username, password):
            return username + password

        @basic_auth_my_realm.error_handler
        def basic_auth_my_realm_error():
            return 'custom error'

        # NOTE(review): ``md5`` is whatever this module imports; hashlib.md5
        # would need bytes on Python 3. Confirm it is a str-accepting wrapper.
        @basic_custom_auth.get_password
        def get_basic_custom_auth_get_password(username):
            if username == 'john':
                return md5('hello').hexdigest()
            elif username == 'susan':
                return md5('bye').hexdigest()
            else:
                return None

        @basic_custom_auth.hash_password
        def basic_custom_auth_hash_password(password):
            return md5(password).hexdigest()

        # Accepts john/susan with their passwords. An empty username is let
        # through as anonymous, which the view reports through g.anon.
        @basic_verify_auth.verify_password
        def basic_verify_auth_verify_password(username, password):
            g.anon = False
            if username == 'john':
                return password == 'hello'
            elif username == 'susan':
                return password == 'bye'
            elif username == '':
                g.anon = True
                return True
            return False

        @digest_auth.get_password
        def get_digest_password_2(username):
            if username == 'susan':
                return 'hello'
            elif username == 'john':
                return 'bye'
            else:
                return None

        @digest_auth_my_realm.get_password
        def get_digest_password_3(username):
            if username == 'susan':
                return 'hello'
            elif username == 'john':
                return 'bye'
            else:
                return None

        @app.route('/')
        def index():
            return 'index'

        @app.route('/basic')
        @basic_auth.login_required
        def basic_auth_route():
            return 'basic_auth:' + basic_auth.username()

        @app.route('/basic-with-realm')
        @basic_auth_my_realm.login_required
        def basic_auth_my_realm_route():
            return 'basic_auth_my_realm:' + basic_auth_my_realm.username()

        @app.route('/basic-custom')
        @basic_custom_auth.login_required
        def basic_custom_auth_route():
            return 'basic_custom_auth:' + basic_custom_auth.username()

        @app.route('/basic-verify')
        @basic_verify_auth.login_required
        def basic_verify_auth_route():
            return 'basic_verify_auth:' + basic_verify_auth.username() + \
                ' anon:' + str(g.anon)

        @app.route('/digest')
        @digest_auth.login_required
        def digest_auth_route():
            return 'digest_auth:' + digest_auth.username()

        @app.route('/digest_ha1_pw')
        @digest_auth_ha1_pw.login_required
        def digest_auth_ha1_pw_route():
            # Report the user from the auth object that guards this view,
            # not the unrelated ``digest_auth`` instance.
            return 'digest_auth:' + digest_auth_ha1_pw.username()

        @app.route('/digest-with-realm')
        @digest_auth_my_realm.login_required
        def digest_auth_my_realm_route():
            return 'digest_auth_my_realm:' + digest_auth_my_realm.username()

        self.app = app
        self.basic_auth = basic_auth
        self.basic_auth_my_realm = basic_auth_my_realm
        self.basic_custom_auth = basic_custom_auth
        self.basic_verify_auth = basic_verify_auth
        self.digest_auth = digest_auth
        # Expose the remaining auth objects too, for parity with the above.
        self.digest_auth_my_realm = digest_auth_my_realm
        self.digest_auth_ha1_pw = digest_auth_ha1_pw
        self.client = app.test_client()
Esempio n. 40
0
        val = self.serializer.dumps(dict(session))
        self.redis.setex(self.key_prefix + session.sid, val,
                         int(app.permanent_session_lifetime.total_seconds()))
        response.set_cookie(app.session_cookie_name,
                            session.sid,
                            expires=expires,
                            httponly=http_only,
                            domain=domain,
                            path=path,
                            secure=secure)


if __name__ == '__main__':
    # Manual smoke test: store a value in the session on one request, then
    # print the session on the next. print() is written as a call so the
    # script parses on Python 3; with one argument it prints the same on
    # Python 2.
    from flask import Flask, session

    app = Flask(__name__)
    RedisSession.init_app(app)

    @app.route('/')
    def index():
        session['test'] = {"test": "test"}
        return 'set session ok'

    @app.route('/dump')
    def dump():
        print(session)
        return 'dump session ok'

    with app.test_client() as c:
        print(c.get('/'))
        print(c.get('/dump'))
Esempio n. 41
0
class FlaskTestCase(unittest.TestCase):
    def setUp(self):
        """Create the app, database and AppBuilder, then register the test APIs.

        Each ``ModelRestApi`` subclass below varies one feature (column
        selection, exclusions, ordering, permissions, filters, browser login,
        validators, related-field filters). A few classes are kept on ``self``
        for tests to inspect. Finally the admin user is created.
        """
        from flask import Flask
        from flask_appbuilder import AppBuilder
        from flask_appbuilder.models.sqla.interface import SQLAInterface
        from flask_appbuilder import ModelRestApi

        self.app = Flask(__name__)
        self.basedir = os.path.abspath(os.path.dirname(__file__))
        self.app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite:///"
        self.app.config["SECRET_KEY"] = "thisismyscretkey"
        self.app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False
        self.app.config["FAB_API_MAX_PAGE_SIZE"] = MAX_PAGE_SIZE
        self.app.config["WTF_CSRF_ENABLED"] = False

        self.db = SQLA(self.app)
        self.appbuilder = AppBuilder(self.app, self.db.session)
        # Create models and insert data
        insert_data(self.db.session, MODEL1_DATA_SIZE)

        class Model1Api(ModelRestApi):
            datamodel = SQLAInterface(Model1)
            list_columns = [
                "field_integer",
                "field_float",
                "field_string",
                "field_date",
            ]
            description_columns = {
                "field_integer": "Field Integer",
                "field_float": "Field Float",
                "field_string": "Field String",
            }

        self.model1api = Model1Api
        self.appbuilder.add_api(Model1Api)

        class Model1ApiFieldsInfo(Model1Api):
            datamodel = SQLAInterface(Model1)
            add_columns = ["field_integer", "field_float", "field_string", "field_date"]
            edit_columns = ["field_string", "field_integer"]

        self.model1apifieldsinfo = Model1ApiFieldsInfo
        self.appbuilder.add_api(Model1ApiFieldsInfo)

        # Like Model1Api but also lists the computed "full_concat" column.
        class Model1FuncApi(ModelRestApi):
            datamodel = SQLAInterface(Model1)
            list_columns = [
                "field_integer",
                "field_float",
                "field_string",
                "field_date",
                "full_concat",
            ]
            description_columns = {
                "field_integer": "Field Integer",
                "field_float": "Field Float",
                "field_string": "Field String",
            }

        # Previously Model1Api was stored here by mistake (copy-paste); keep
        # the attribute pointing at the class that is actually registered.
        self.model1funcapi = Model1FuncApi
        self.appbuilder.add_api(Model1FuncApi)

        class Model1ApiExcludeCols(ModelRestApi):
            datamodel = SQLAInterface(Model1)
            list_exclude_columns = ["field_integer", "field_float", "field_date"]
            show_exclude_columns = list_exclude_columns
            edit_exclude_columns = list_exclude_columns
            add_exclude_columns = list_exclude_columns

        self.appbuilder.add_api(Model1ApiExcludeCols)

        class Model1ApiOrder(ModelRestApi):
            datamodel = SQLAInterface(Model1)
            base_order = ("field_integer", "desc")

        self.appbuilder.add_api(Model1ApiOrder)

        class Model1ApiRestrictedPermissions(ModelRestApi):
            datamodel = SQLAInterface(Model1)
            base_permissions = ["can_get", "can_info"]

        self.appbuilder.add_api(Model1ApiRestrictedPermissions)

        # Only rows with 2 < field_integer < 4 are visible.
        class Model1ApiFiltered(ModelRestApi):
            datamodel = SQLAInterface(Model1)
            base_filters = [
                ["field_integer", FilterGreater, 2],
                ["field_integer", FilterSmaller, 4],
            ]

        self.appbuilder.add_api(Model1ApiFiltered)

        class ModelWithEnumsApi(ModelRestApi):
            datamodel = SQLAInterface(ModelWithEnums)

        self.appbuilder.add_api(ModelWithEnumsApi)

        # Accepts a browser session cookie in addition to a JWT.
        class Model1BrowserLogin(ModelRestApi):
            datamodel = SQLAInterface(Model1)
            allow_browser_login = True

        self.appbuilder.add_api(Model1BrowserLogin)

        class ModelMMApi(ModelRestApi):
            datamodel = SQLAInterface(ModelMMParent)

        self.appbuilder.add_api(ModelMMApi)

        class ModelMMRequiredApi(ModelRestApi):
            datamodel = SQLAInterface(ModelMMParentRequired)

        self.appbuilder.add_api(ModelMMRequiredApi)

        class Model1CustomValidationApi(ModelRestApi):
            datamodel = SQLAInterface(Model1)
            validators_columns = {"field_string": validate_name}

        self.appbuilder.add_api(Model1CustomValidationApi)

        class Model2Api(ModelRestApi):
            datamodel = SQLAInterface(Model2)
            list_columns = ["group"]
            show_columns = ["group"]

        self.model2api = Model2Api
        self.appbuilder.add_api(Model2Api)

        # Restricts the "group" choices offered for add/edit to 2 < x < 4.
        class Model2ApiFilteredRelFields(ModelRestApi):
            datamodel = SQLAInterface(Model2)
            list_columns = ["group"]
            show_columns = ["group"]
            add_query_rel_fields = {
                "group": [
                    ["field_integer", FilterGreater, 2],
                    ["field_integer", FilterSmaller, 4],
                ]
            }
            edit_query_rel_fields = add_query_rel_fields

        self.model2apifilteredrelfields = Model2ApiFilteredRelFields
        self.appbuilder.add_api(Model2ApiFilteredRelFields)

        role_admin = self.appbuilder.sm.find_role("Admin")
        self.appbuilder.sm.add_user(
            USERNAME, "admin", "user", "*****@*****.**", role_admin, PASSWORD
        )

    def tearDown(self):
        self.appbuilder = None
        self.app = None
        self.db = None

    @staticmethod
    def auth_client_get(client, token, uri):
        return client.get(uri, headers={"Authorization": "Bearer {}".format(token)})

    @staticmethod
    def auth_client_delete(client, token, uri):
        return client.delete(uri, headers={"Authorization": "Bearer {}".format(token)})

    @staticmethod
    def auth_client_put(client, token, uri, json):
        return client.put(
            uri, json=json, headers={"Authorization": "Bearer {}".format(token)}
        )

    @staticmethod
    def auth_client_post(client, token, uri, json):
        return client.post(
            uri, json=json, headers={"Authorization": "Bearer {}".format(token)}
        )

    @staticmethod
    def _login(client, username, password):
        """POST ``db``-provider credentials to the API security login endpoint.

        :param client: Flask test client
        :param username: username to authenticate as
        :param password: password for ``username``
        :return: the Flask response from the login endpoint
        """
        login_uri = "api/{}/security/login".format(API_SECURITY_VERSION)
        credentials = {
            API_SECURITY_USERNAME_KEY: username,
            API_SECURITY_PASSWORD_KEY: password,
            API_SECURITY_PROVIDER_KEY: "db",
        }
        return client.post(
            login_uri,
            data=json.dumps(credentials),
            content_type="application/json",
        )

    def login(self, client, username, password):
        # Login with default admin
        rv = self._login(client, username, password)
        try:
            return json.loads(rv.data.decode("utf-8")).get("access_token")
        except Exception:
            return rv

    def browser_login(self, client, username, password):
        # Login with default admin
        return client.post(
            "/login/",
            data=dict(username=username, password=password),
            follow_redirects=True,
        )

    def browser_logout(self, client):
        return client.get("/logout/")

    def test_auth_login(self):
        """
            REST Api: valid credentials return 200 and an access token
        """
        client = self.app.test_client()
        response = self._login(client, USERNAME, PASSWORD)
        eq_(response.status_code, 200)
        payload = json.loads(response.data.decode("utf-8"))
        assert payload.get(API_SECURITY_ACCESS_TOKEN_KEY, False)

    def test_auth_login_failed(self):
        """
            REST Api: Test auth login failed
        """
        client = self.app.test_client()
        response = self._login(client, "fail", "fail")
        body = json.loads(response.data)
        eq_(body, {"message": "Not authorized"})
        eq_(response.status_code, 401)

    def test_auth_login_bad(self):
        """
            REST Api: Test auth login bad request
        """
        client = self.app.test_client()
        # The body is not JSON at all, so the login endpoint must reject it
        response = client.post("api/v1/security/login", data="BADADATA")
        eq_(response.status_code, 400)

    def test_auth_authorization_browser(self):
        """
            REST Api: Test auth with browser login
        """
        client = self.app.test_client()
        browser_uri = "api/v1/model1browserlogin/1"
        self.browser_login(client, USERNAME, PASSWORD)
        # The browser session grants access to the browser-login enabled API
        rv = client.get(browser_uri)
        eq_(rv.status_code, 200)
        # ... but the browser session alone is rejected by model1api
        rv = client.get("api/v1/model1api/1")
        eq_(rv.status_code, 401)
        # After logout there is neither a session cookie nor a JWT
        self.browser_logout(client)
        rv = client.get(browser_uri)
        eq_(rv.status_code, 401)
        # A JWT without a session cookie also grants access
        token = self.login(client, USERNAME, PASSWORD)
        rv = self.auth_client_get(client, token, browser_uri)
        eq_(rv.status_code, 200)

    def test_auth_authorization(self):
        """
            REST Api: Test auth base limited authorization
        """
        client = self.app.test_client()
        token = self.login(client, USERNAME, PASSWORD)
        # DELETE is not permitted on the restricted API
        pk = 1
        uri = "api/v1/model1apirestrictedpermissions/{}".format(pk)
        rv = self.auth_client_delete(client, token, uri)
        eq_(rv.status_code, 401)
        # POST is not permitted on the restricted API either
        item = dict(
            field_string="test{}".format(MODEL1_DATA_SIZE + 1),
            field_integer=MODEL1_DATA_SIZE + 1,
            field_float=float(MODEL1_DATA_SIZE + 1),
            field_date=None,
        )
        uri = "api/v1/model1apirestrictedpermissions/"
        rv = self.auth_client_post(client, token, uri, item)
        eq_(rv.status_code, 401)
        # GET is permitted, so reading an item succeeds
        uri = "api/v1/model1apirestrictedpermissions/1"
        rv = self.auth_client_get(client, token, uri)
        eq_(rv.status_code, 200)

    def test_get_item(self):
        """
            REST Api: Test get item
        """
        client = self.app.test_client()
        token = self.login(client, USERNAME, PASSWORD)
        # pks are 1-based while the fixture values are 0-based, so the item
        # with pk N holds value N - 1. Include the last item
        # (pk == MODEL1_DATA_SIZE), which the previous range skipped.
        for pk in range(1, MODEL1_DATA_SIZE + 1):
            uri = "api/v1/model1api/{}".format(pk)
            rv = self.auth_client_get(client, token, uri)
            # Check the status first so a failure is reported clearly,
            # instead of as a KeyError on the error payload
            eq_(rv.status_code, 200)
            data = json.loads(rv.data.decode("utf-8"))
            self.assert_get_item(rv, data, pk - 1)

    def assert_get_item(self, rv, data, value):
        """
            Assert a single-item GET payload for the Model1 row created
            from fixture value ``value``
        :param rv: Flask client response class
        :param data: decoded JSON body of ``rv``
        :param value: fixture value the row was created with
        """
        expected_result = {
            "field_date": None,
            "field_float": float(value),
            "field_integer": value,
            "field_string": "test{}".format(value),
        }
        expected_labels = {
            "field_date": "Field Date",
            "field_float": "Field Float",
            "field_integer": "Field Integer",
            "field_string": "Field String",
        }
        eq_(data[API_RESULT_RES_KEY], expected_result)
        # Descriptions must mirror the ones declared on the API class
        eq_(data["description_columns"], self.model1api.description_columns)
        eq_(data[API_LABEL_COLUMNS_RES_KEY], expected_labels)
        eq_(rv.status_code, 200)

    def test_get_item_select_cols(self):
        """
            REST Api: Test get item with select columns
        """
        client = self.app.test_client()
        token = self.login(client, USERNAME, PASSWORD)

        # The item with pk N holds fixture value N - 1; cover every item,
        # including the last one (pk == MODEL1_DATA_SIZE)
        for pk in range(1, MODEL1_DATA_SIZE + 1):
            uri = "api/v1/model1api/{}?q=({}:!(field_integer))".format(
                pk, API_SELECT_COLUMNS_RIS_KEY
            )
            rv = self.auth_client_get(client, token, uri)
            # Check the status before indexing into the body
            eq_(rv.status_code, 200)
            data = json.loads(rv.data.decode("utf-8"))
            eq_(data[API_RESULT_RES_KEY], {"field_integer": pk - 1})
            eq_(
                data[API_DESCRIPTION_COLUMNS_RES_KEY],
                {"field_integer": "Field Integer"},
            )
            eq_(data[API_LABEL_COLUMNS_RES_KEY], {"field_integer": "Field Integer"})

    def test_get_item_select_meta_data(self):
        """
            REST Api: Test get item select meta data
        """
        client = self.app.test_client()
        token = self.login(client, USERNAME, PASSWORD)

        for selectable_key in (
            API_DESCRIPTION_COLUMNS_RIS_KEY,
            API_LABEL_COLUMNS_RIS_KEY,
            API_SHOW_COLUMNS_RIS_KEY,
            API_SHOW_TITLE_RIS_KEY,
        ):
            rison_args = prison.dumps({API_SELECT_KEYS_RIS_KEY: [selectable_key]})
            uri = "api/v1/model1api/1?{}={}".format(API_URI_RIS_KEY, rison_args)
            rv = self.auth_client_get(client, token, uri)
            data = json.loads(rv.data.decode("utf-8"))
            # The selected key plus the always-present id and result keys
            eq_(len(data.keys()), 1 + 2)
            # The rison meta key is expected to equal the response meta key
            assert selectable_key in data

    def test_get_item_excluded_cols(self):
        """
            REST Api: Test get item with excluded columns
        """
        client = self.app.test_client()
        token = self.login(client, USERNAME, PASSWORD)

        uri = "api/v1/model1apiexcludecols/{}".format(1)
        rv = self.auth_client_get(client, token, uri)
        result = json.loads(rv.data.decode("utf-8"))[API_RESULT_RES_KEY]
        # Only field_string survives the column exclusion
        eq_(result, {"field_string": "test0"})
        eq_(rv.status_code, 200)

    def test_get_item_not_found(self):
        """
            REST Api: Test get item not found
        """
        client = self.app.test_client()
        token = self.login(client, USERNAME, PASSWORD)

        # One past the last fixture row
        missing_pk = MODEL1_DATA_SIZE + 1
        uri = "api/v1/model1api/{}".format(missing_pk)
        rv = self.auth_client_get(client, token, uri)
        eq_(rv.status_code, 404)

    def test_get_item_base_filters(self):
        """
            REST Api: Test get item with base filters
        """
        client = self.app.test_client()
        token = self.login(client, USERNAME, PASSWORD)

        # pk=1 is hidden by the API's base filters (404), while pk=4
        # (field_integer=3) passes them (200)
        for pk, expected_status in ((1, 404), (4, 200)):
            uri = "api/v1/model1apifiltered/{}".format(pk)
            rv = self.auth_client_get(client, token, uri)
            eq_(rv.status_code, expected_status)

    def test_get_item_1m_field(self):
        """
            REST Api: Test get item with 1-N related field
        """
        client = self.app.test_client()
        token = self.login(client, USERNAME, PASSWORD)

        # Item pk=1 is expected to reference the "test0" row (id 1) through
        # its "group" relation, serialized as a nested object
        pk = 1
        rv = self.auth_client_get(client, token, "api/v1/model2api/{}".format(pk))
        data = json.loads(rv.data.decode("utf-8"))
        eq_(rv.status_code, 200)
        expected_rel_field = {
            "group": {
                "field_date": None,
                "field_float": 0.0,
                "field_integer": 0,
                "field_string": "test0",
                "id": 1,
            }
        }
        eq_(data[API_RESULT_RES_KEY], expected_rel_field)

    def test_get_item_mm_field(self):
        """
            REST Api: Test get item with N-N related field
        """
        client = self.app.test_client()
        token = self.login(client, USERNAME, PASSWORD)

        # Item pk=1 is expected to be linked to children ids 1-3, serialized
        # as a list of nested objects
        pk = 1
        rv = self.auth_client_get(client, token, "api/v1/modelmmapi/{}".format(pk))
        data = json.loads(rv.data.decode("utf-8"))
        eq_(rv.status_code, 200)
        expected_rel_field = [
            {"field_string": "1", "id": 1},
            {"field_string": "2", "id": 2},
            {"field_string": "3", "id": 3},
        ]
        eq_(data[API_RESULT_RES_KEY]["children"], expected_rel_field)

    def test_get_list(self):
        """
            REST Api: Test get list
        """
        client = self.app.test_client()
        token = self.login(client, USERNAME, PASSWORD)

        rv = self.auth_client_get(client, token, "api/v1/model1api/")
        eq_(rv.status_code, 200)
        data = json.loads(rv.data.decode("utf-8"))
        # Tests count property
        eq_(data["count"], MODEL1_DATA_SIZE)
        # Tests data result default page size
        rows = data[API_RESULT_RES_KEY]
        eq_(len(rows), self.model1api.page_size)
        # Rows are expected in fixture order (row i holds value i). Check
        # every row of the page; the old range(1, page_size) skipped the last.
        for i, row in enumerate(rows):
            self.assert_get_list(rv, row, i)

    @staticmethod
    def assert_get_list(rv, data, value):
        """
            Assert one list row for the Model1 row created from fixture
            value ``value``
        :param rv: Flask client response class the row came from
        :param data: a single decoded row of the list result
        :param value: fixture value the row was created with
        """
        expected_row = {
            "field_date": None,
            "field_float": float(value),
            "field_integer": value,
            "field_string": "test{}".format(value),
        }
        eq_(data, expected_row)
        eq_(rv.status_code, 200)

    def test_get_list_order(self):
        """
            REST Api: Test get list order params
        """
        client = self.app.test_client()
        token = self.login(client, USERNAME, PASSWORD)

        # For each direction, the fixture value expected in the first row
        for direction, first_value in (("asc", 0), ("desc", MODEL1_DATA_SIZE - 1)):
            arguments = {"order_column": "field_integer", "order_direction": direction}
            uri = "api/v1/model1api/?{}={}".format(
                API_URI_RIS_KEY, prison.dumps(arguments)
            )
            rv = self.auth_client_get(client, token, uri)
            data = json.loads(rv.data.decode("utf-8"))
            eq_(
                data[API_RESULT_RES_KEY][0],
                {
                    "field_date": None,
                    "field_float": float(first_value),
                    "field_integer": first_value,
                    "field_string": "test{}".format(first_value),
                },
            )
            eq_(rv.status_code, 200)

    def test_get_list_base_order(self):
        """
            REST Api: Test get list with base order
        """
        client = self.app.test_client()
        token = self.login(client, USERNAME, PASSWORD)

        # Without order arguments the API's base order applies: the first
        # row is the one with the highest field_integer
        rv = self.auth_client_get(client, token, "api/v1/model1apiorder/")
        eq_(rv.status_code, 200)
        data = json.loads(rv.data.decode("utf-8"))
        eq_(
            data[API_RESULT_RES_KEY][0],
            {
                "field_date": None,
                "field_float": float(MODEL1_DATA_SIZE - 1),
                "field_integer": MODEL1_DATA_SIZE - 1,
                "field_string": "test{}".format(MODEL1_DATA_SIZE - 1),
            },
        )
        # Explicit order arguments override the base order
        arguments = {"order_column": "field_integer", "order_direction": "asc"}
        uri = "api/v1/model1apiorder/?{}={}".format(
            API_URI_RIS_KEY, prison.dumps(arguments)
        )
        rv = self.auth_client_get(client, token, uri)
        eq_(rv.status_code, 200)
        data = json.loads(rv.data.decode("utf-8"))
        eq_(
            data[API_RESULT_RES_KEY][0],
            {
                "field_date": None,
                "field_float": 0.0,
                "field_integer": 0,
                "field_string": "test0",
            },
        )

    def test_get_list_page(self):
        """
            REST Api: Test get list page params
        """
        page_size = 5
        client = self.app.test_client()
        token = self.login(client, USERNAME, PASSWORD)

        for page in (0, 1):
            arguments = {
                "page_size": page_size,
                "page": page,
                "order_column": "field_integer",
                "order_direction": "asc",
            }
            uri = "api/v1/model1api/?{}={}".format(
                API_URI_RIS_KEY, prison.dumps(arguments)
            )
            rv = self.auth_client_get(client, token, uri)
            data = json.loads(rv.data.decode("utf-8"))
            # Pages are zero-based, so page N starts at value N * page_size
            first_value = page * page_size
            eq_(
                data[API_RESULT_RES_KEY][0],
                {
                    "field_date": None,
                    "field_float": float(first_value),
                    "field_integer": first_value,
                    "field_string": "test{}".format(first_value),
                },
            )
            eq_(rv.status_code, 200)
            eq_(len(data[API_RESULT_RES_KEY]), page_size)

    def test_get_list_max_page_size(self):
        """
            REST Api: Test get list max page size config setting
        """
        # Request more rows than allowed (assumes MAX_PAGE_SIZE < 100); the
        # server is expected to cap the page at the global MAX_PAGE_SIZE
        page_size = 100
        client = self.app.test_client()
        token = self.login(client, USERNAME, PASSWORD)

        arguments = {
            "page_size": page_size,
            "page": 0,
            "order_column": "field_integer",
            "order_direction": "asc",
        }
        uri = "api/v1/model1api/?{}={}".format(API_URI_RIS_KEY, prison.dumps(arguments))
        rv = self.auth_client_get(client, token, uri)
        eq_(rv.status_code, 200)
        data = json.loads(rv.data.decode("utf-8"))
        eq_(len(data[API_RESULT_RES_KEY]), MAX_PAGE_SIZE)

    def test_get_list_filters(self):
        """
            REST Api: Test get list filter params
        """
        client = self.app.test_client()
        token = self.login(client, USERNAME, PASSWORD)

        filter_value = 5
        # Keep rows with field_integer > filter_value, ordered ascending, so
        # the first row returned holds filter_value + 1
        arguments = {
            API_FILTERS_RIS_KEY: [
                {"col": "field_integer", "opr": "gt", "value": filter_value}
            ],
            "order_column": "field_integer",
            "order_direction": "asc",
        }
        uri = "api/v1/model1api/?{}={}".format(API_URI_RIS_KEY, prison.dumps(arguments))

        rv = self.auth_client_get(client, token, uri)
        # Check the status before indexing into the result list
        eq_(rv.status_code, 200)
        data = json.loads(rv.data.decode("utf-8"))
        expected_value = filter_value + 1
        eq_(
            data[API_RESULT_RES_KEY][0],
            {
                "field_date": None,
                "field_float": float(expected_value),
                "field_integer": expected_value,
                "field_string": "test{}".format(expected_value),
            },
        )

    def test_get_list_select_cols(self):
        """
            REST Api: Test get list with selected columns
        """
        client = self.app.test_client()
        token = self.login(client, USERNAME, PASSWORD)

        rison_query = prison.dumps(
            {
                API_SELECT_COLUMNS_RIS_KEY: ["field_integer"],
                "order_column": "field_integer",
                "order_direction": "asc",
            }
        )
        uri = "api/v1/model1api/?{}={}".format(API_URI_RIS_KEY, rison_query)
        rv = self.auth_client_get(client, token, uri)
        data = json.loads(rv.data.decode("utf-8"))
        # Both result rows and metadata are narrowed to the selected column
        integer_only = {"field_integer": "Field Integer"}
        eq_(data[API_RESULT_RES_KEY][0], {"field_integer": 0})
        eq_(data[API_LABEL_COLUMNS_RES_KEY], integer_only)
        eq_(data[API_DESCRIPTION_COLUMNS_RES_KEY], integer_only)
        eq_(data[API_LIST_COLUMNS_RES_KEY], ["field_integer"])
        eq_(rv.status_code, 200)

    def test_get_list_select_meta_data(self):
        """
            REST Api: Test get list select meta data
        """
        client = self.app.test_client()
        token = self.login(client, USERNAME, PASSWORD)

        for selectable_key in (
            API_DESCRIPTION_COLUMNS_RIS_KEY,
            API_LABEL_COLUMNS_RIS_KEY,
            API_ORDER_COLUMNS_RIS_KEY,
            API_LIST_COLUMNS_RIS_KEY,
            API_LIST_TITLE_RIS_KEY,
        ):
            rison_args = prison.dumps({API_SELECT_KEYS_RIS_KEY: [selectable_key]})
            uri = "api/v1/model1api/?{}={}".format(API_URI_RIS_KEY, rison_args)
            rv = self.auth_client_get(client, token, uri)
            data = json.loads(rv.data.decode("utf-8"))
            # The selected key plus the always-present count, ids and result
            eq_(len(data.keys()), 1 + 3)
            # The rison meta key is expected to equal the response meta key
            assert selectable_key in data

    def test_get_list_exclude_cols(self):
        """
            REST Api: Test get list with excluded columns
        """
        client = self.app.test_client()
        token = self.login(client, USERNAME, PASSWORD)

        rv = self.auth_client_get(client, token, "api/v1/model1apiexcludecols/")
        first_row = json.loads(rv.data.decode("utf-8"))[API_RESULT_RES_KEY][0]
        # Only field_string survives the column exclusion
        eq_(first_row, {"field_string": "test0"})

    def test_get_list_base_filters(self):
        """
            REST Api: Test get list with base filters
        """
        client = self.app.test_client()
        token = self.login(client, USERNAME, PASSWORD)

        rison_args = prison.dumps(
            {"order_column": "field_integer", "order_direction": "desc"}
        )
        uri = "api/v1/model1apifiltered/?{}={}".format(API_URI_RIS_KEY, rison_args)
        rv = self.auth_client_get(client, token, uri)
        data = json.loads(rv.data.decode("utf-8"))
        # The base filters let exactly one row through: field_integer == 3
        only_row = {
            "field_date": None,
            "field_float": 3.0,
            "field_integer": 3,
            "field_string": "test3",
        }
        eq_(data[API_RESULT_RES_KEY], [only_row])

    def test_info_filters(self):
        """
            REST Api: Test info filters
        """
        client = self.app.test_client()
        token = self.login(client, USERNAME, PASSWORD)
        uri = "api/v1/model1api/_info"
        rv = self.auth_client_get(client, token, uri)
        data = json.loads(rv.data.decode("utf-8"))
        # Date, float and integer columns share the same comparison operators
        comparison_filters = [
            {"name": "Equal to", "operator": "eq"},
            {"name": "Greater than", "operator": "gt"},
            {"name": "Smaller than", "operator": "lt"},
            {"name": "Not Equal to", "operator": "neq"},
        ]
        string_filters = [
            {"name": "Starts with", "operator": "sw"},
            {"name": "Ends with", "operator": "ew"},
            {"name": "Contains", "operator": "ct"},
            {"name": "Equal to", "operator": "eq"},
            {"name": "Not Starts with", "operator": "nsw"},
            {"name": "Not Ends with", "operator": "new"},
            {"name": "Not Contains", "operator": "nct"},
            {"name": "Not Equal to", "operator": "neq"},
        ]
        expected_filters = {
            "field_date": comparison_filters,
            "field_float": comparison_filters,
            "field_integer": comparison_filters,
            "field_string": string_filters,
        }
        eq_(data["filters"], expected_filters)

    def test_info_fields(self):
        """
            REST Api: Test info fields (add, edit)
        """
        client = self.app.test_client()
        token = self.login(client, USERNAME, PASSWORD)

        uri = "api/v1/model1apifieldsinfo/_info"
        rv = self.auth_client_get(client, token, uri)
        data = json.loads(rv.data.decode("utf-8"))
        expect_add_fields = [
            {
                "description": "Field Integer",
                "label": "Field Integer",
                "name": "field_integer",
                "required": False,
                "unique": False,
                "type": "Integer",
            },
            {
                "description": "Field Float",
                "label": "Field Float",
                "name": "field_float",
                "required": False,
                "unique": False,
                "type": "Float",
            },
            {
                "description": "Field String",
                "label": "Field String",
                "name": "field_string",
                "required": True,
                "unique": True,
                "type": "String",
                "validate": ["<Length(min=None, max=50, equal=None, error=None)>"],
            },
            {
                "description": "",
                "label": "Field Date",
                "name": "field_date",
                "required": False,
                "unique": False,
                "type": "Date",
            },
        ]
        # Edit fields are the add-field specs selected and ordered by the
        # API's edit_columns
        expect_edit_fields = [
            field_spec
            for edit_col in self.model1apifieldsinfo.edit_columns
            for field_spec in expect_add_fields
            if field_spec["name"] == edit_col
        ]
        eq_(data[API_ADD_COLUMNS_RES_KEY], expect_add_fields)
        eq_(data[API_EDIT_COLUMNS_RES_KEY], expect_edit_fields)

    def test_info_fields_rel_field(self):
        """
            REST Api: Test info fields with related fields
        """
        client = self.app.test_client()
        token = self.login(client, USERNAME, PASSWORD)

        uri = "api/v1/model2api/_info"
        rv = self.auth_client_get(client, token, uri)
        data = json.loads(rv.data.decode("utf-8"))
        # The related choices cover one page of related rows: ids are 1-based,
        # display values come from the 0-based fixture value
        expected_rel_add_field = {
            "count": MODEL2_DATA_SIZE,
            "description": "",
            "label": "Group",
            "name": "group",
            "required": True,
            "unique": False,
            "type": "Related",
            "values": [
                {"id": i + 1, "value": "test{}".format(i)}
                for i in range(self.model2api.page_size)
            ],
        }
        group_fields = [
            rel_field
            for rel_field in data[API_ADD_COLUMNS_RES_KEY]
            if rel_field["name"] == "group"
        ]
        # Fail loudly when the related field is missing, instead of the
        # loop below passing vacuously
        assert group_fields, "'group' missing from add columns"
        for rel_field in group_fields:
            eq_(rel_field, expected_rel_add_field)

    def test_info_fields_rel_filtered_field(self):
        """
            REST Api: Test info fields with filtered
            related fields
        """
        client = self.app.test_client()
        token = self.login(client, USERNAME, PASSWORD)
        uri = "api/v1/model2apifilteredrelfields/_info"
        rv = self.auth_client_get(client, token, uri)
        data = json.loads(rv.data.decode("utf-8"))
        # After filtering, a single related choice (id 4, "test3") remains,
        # and the add and edit specs are expected to be identical
        expected_rel_field = {
            "description": "",
            "label": "Group",
            "name": "group",
            "required": True,
            "unique": False,
            "type": "Related",
            "count": 1,
            "values": [{"id": 4, "value": "test3"}],
        }
        for columns_key in (API_ADD_COLUMNS_RES_KEY, API_EDIT_COLUMNS_RES_KEY):
            group_fields = [
                rel_field
                for rel_field in data[columns_key]
                if rel_field["name"] == "group"
            ]
            # Fail loudly when the related field is missing, instead of
            # passing vacuously
            assert group_fields, "'group' missing from {}".format(columns_key)
            for rel_field in group_fields:
                eq_(rel_field, expected_rel_field)

    def test_info_permissions(self):
        """
            REST Api: Test info permissions
        """
        client = self.app.test_client()
        token = self.login(client, USERNAME, PASSWORD)
        # Full CRUD API versus the restricted, read-only one
        cases = (
            (
                "api/v1/model1api/_info",
                ["can_delete", "can_get", "can_info", "can_post", "can_put"],
            ),
            (
                "api/v1/model1apirestrictedpermissions/_info",
                ["can_get", "can_info"],
            ),
        )
        for uri, expected_permissions in cases:
            rv = self.auth_client_get(client, token, uri)
            data = json.loads(rv.data.decode("utf-8"))
            eq_(sorted(data[API_PERMISSIONS_RES_KEY]), expected_permissions)

    def test_info_select_meta_data(self):
        """
            REST Api: Test info select meta data
        """
        client = self.app.test_client()
        token = self.login(client, USERNAME, PASSWORD)

        # Select each _info meta key on its own
        for selectable_key in (
            API_ADD_COLUMNS_RIS_KEY,
            API_EDIT_COLUMNS_RIS_KEY,
            API_PERMISSIONS_RIS_KEY,
            API_FILTERS_RIS_KEY,
            API_ADD_TITLE_RIS_KEY,
            API_EDIT_TITLE_RIS_KEY,
        ):
            rison_args = prison.dumps({API_SELECT_KEYS_RIS_KEY: [selectable_key]})
            uri = "api/v1/model1api/_info?{}={}".format(API_URI_RIS_KEY, rison_args)
            rv = self.auth_client_get(client, token, uri)
            data = json.loads(rv.data.decode("utf-8"))
            # Only the selected key comes back
            eq_(len(data.keys()), 1)
            # The rison meta key is expected to equal the response meta key
            assert selectable_key in data

    def test_delete_item(self):
        """
            REST Api: Test delete item
        """
        client = self.app.test_client()
        token = self.login(client, USERNAME, PASSWORD)

        pk = 2
        rv = self.auth_client_delete(client, token, "api/v1/model1api/{}".format(pk))
        eq_(rv.status_code, 200)
        # The row must be gone from the database as well
        eq_(self.db.session.query(Model1).get(pk), None)

    def test_delete_item_not_found(self):
        """
            REST Api: Test delete item not found
        """
        client = self.app.test_client()
        token = self.login(client, USERNAME, PASSWORD)

        # One past the last fixture row
        missing_pk = MODEL1_DATA_SIZE + 1
        rv = self.auth_client_delete(
            client, token, "api/v1/model1api/{}".format(missing_pk)
        )
        eq_(rv.status_code, 404)

    def test_delete_item_base_filters(self):
        """
            REST Api: Test delete item with base filters
        """
        client = self.app.test_client()
        token = self.login(client, USERNAME, PASSWORD)
        # pk=1 is hidden by the base filters, so deleting it reports 404
        filtered_uri = "api/v1/model1apifiltered/{}".format(1)
        rv = self.auth_client_delete(client, token, filtered_uri)
        eq_(rv.status_code, 404)

    def test_update_item(self):
        """
            REST Api: Test update item
        """
        client = self.app.test_client()
        token = self.login(client, USERNAME, PASSWORD)
        pk = 3
        uri = "api/v1/model1api/{}".format(pk)
        payload = dict(field_string="test_Put", field_integer=0, field_float=0.0)
        rv = self.auth_client_put(client, token, uri, payload)
        eq_(rv.status_code, 200)
        # Verify that the change was persisted
        stored = self.db.session.query(Model1).get(pk)
        eq_(stored.field_string, "test_Put")
        eq_(stored.field_integer, 0)
        eq_(stored.field_float, 0.0)

    def test_update_custom_validation(self):
        """
            REST Api: Test update item custom validation
        """
        client = self.app.test_client()
        token = self.login(client, USERNAME, PASSWORD)
        pk = 3
        uri = "api/v1/model1customvalidationapi/{}".format(pk)
        # Rejected by the API's custom validation; compared with the accepted
        # value below, it presumably requires a leading "A"
        rejected = dict(field_string="test_Put", field_integer=0, field_float=0.0)
        rv = self.auth_client_put(client, token, uri, rejected)
        eq_(rv.status_code, 422)
        accepted = dict(field_string="Atest_Put", field_integer=0, field_float=0.0)
        rv = self.auth_client_put(client, token, uri, accepted)
        eq_(rv.status_code, 200)

    def test_update_item_base_filters(self):
        """
            REST Api: Test update item with base filters
        """
        client = self.app.test_client()
        token = self.login(client, USERNAME, PASSWORD)
        item = dict(field_string="test_Put", field_integer=3, field_float=3.0)
        # pk=4 passes the base filters, so it can be updated
        pk = 4
        rv = self.auth_client_put(
            client, token, "api/v1/model1apifiltered/{}".format(pk), item
        )
        eq_(rv.status_code, 200)
        stored = self.db.session.query(Model1).get(pk)
        eq_(stored.field_string, "test_Put")
        eq_(stored.field_integer, 3)
        eq_(stored.field_float, 3.0)
        # pk=1 is hidden by the base filters, so updating it reports 404
        pk = 1
        rv = self.auth_client_put(
            client, token, "api/v1/model1apifiltered/{}".format(pk), item
        )
        eq_(rv.status_code, 404)

    def test_update_item_not_found(self):
        """
            REST Api: Test update item not found
        """
        client = self.app.test_client()
        token = self.login(client, USERNAME, PASSWORD)
        # One past the last fixture row
        missing_pk = MODEL1_DATA_SIZE + 1
        payload = dict(field_string="test_Put", field_integer=0, field_float=0.0)
        uri = "api/v1/model1api/{}".format(missing_pk)
        rv = self.auth_client_put(client, token, uri, payload)
        eq_(rv.status_code, 404)

    def test_update_val_size(self):
        """
            REST Api: Test update validate size
        """
        client = self.app.test_client()
        token = self.login(client, USERNAME, PASSWORD)
        pk = 1
        # One character over the 50-character limit on field_string
        too_long = "a" * 51
        payload = dict(field_string=too_long, field_integer=11, field_float=11.0)
        rv = self.auth_client_put(
            client, token, "api/v1/model1api/{}".format(pk), payload
        )
        eq_(rv.status_code, 422)
        errors = json.loads(rv.data.decode("utf-8"))["message"]
        eq_(errors["field_string"][0], "Longer than maximum length 50.")

    def test_update_mm_field(self):
        """
            REST Api: Test update m-m field
        """
        # Add a new child row; it is presumably assigned id 4
        child = ModelMMChild()
        child.field_string = "update_m,m"
        self.appbuilder.get_session.add(child)
        self.appbuilder.get_session.commit()
        client = self.app.test_client()
        token = self.login(client, USERNAME, PASSWORD)
        pk = 1
        uri = "api/v1/modelmmapi/{}".format(pk)
        # Replace the whole children relation with the single new child
        rv = self.auth_client_put(client, token, uri, dict(children=[4]))
        eq_(rv.status_code, 200)
        result = json.loads(rv.data.decode("utf-8"))[API_RESULT_RES_KEY]
        eq_(result, {"children": [4], "field_string": "0"})

    def test_update_item_val_type(self):
        """
            REST Api: Test update validate type
        """
        client = self.app.test_client()
        token = self.login(client, USERNAME, PASSWORD)
        pk = 1
        uri = "api/v1/model1api/{}".format(pk)
        bad_value = "test{}".format(MODEL1_DATA_SIZE + 1)

        # A string where an integer is expected
        item = dict(field_string=bad_value, field_integer=bad_value, field_float=11.0)
        rv = self.auth_client_put(client, token, uri, item)
        eq_(rv.status_code, 422)
        errors = json.loads(rv.data.decode("utf-8"))["message"]
        eq_(errors["field_integer"][0], "Not a valid integer.")

        # An integer where a string is expected
        item = dict(field_string=11, field_integer=11, field_float=11.0)
        rv = self.auth_client_put(client, token, uri, item)
        eq_(rv.status_code, 422)
        errors = json.loads(rv.data.decode("utf-8"))["message"]
        eq_(errors["field_string"][0], "Not a valid string.")

    def test_update_item_excluded_cols(self):
        """REST Api: Test update item with excluded cols

        PUTs field_integer=1000 through the exclude-cols endpoint. The stored
        row is expected to keep its original field_integer/field_float/
        field_date values.
        """
        client = self.app.test_client()
        token = self.login(client, USERNAME, PASSWORD)
        pk = 1
        target = "api/v1/model1apiexcludecols/{}".format(pk)
        response = self.auth_client_put(
            client, token, target, dict(field_string="test_Put", field_integer=1000)
        )
        eq_(response.status_code, 200)
        stored = self.db.session.query(Model1).get(pk)
        eq_(stored.field_integer, 0)
        eq_(stored.field_float, 0.0)
        eq_(stored.field_date, None)

    def test_create_item(self):
        """REST Api: Test create item

        POSTs a new Model1 row and checks the API echo and the stored row.
        """
        client = self.app.test_client()
        token = self.login(client, USERNAME, PASSWORD)
        next_index = MODEL1_DATA_SIZE + 1
        name = "test{}".format(next_index)
        item = dict(
            field_string=name,
            field_integer=next_index,
            field_float=float(next_index),
            field_date=None,
        )
        response = self.auth_client_post(client, token, "api/v1/model1api/", item)
        body = json.loads(response.data.decode("utf-8"))
        eq_(response.status_code, 201)
        eq_(body[API_RESULT_RES_KEY], item)

        stored = self.db.session.query(Model1).filter_by(field_string=name).first()
        eq_(stored.field_string, name)
        eq_(stored.field_integer, next_index)
        eq_(stored.field_float, float(next_index))

    def test_create_item_custom_validation(self):
        """REST Api: Test create item custom validation

        The endpoint's custom validator rejects a field_string that does not
        start with "A" (422 with "Name must start with an A"). It accepts the
        same payload once the name is "A"-prefixed (201).
        """
        client = self.app.test_client()
        token = self.login(client, USERNAME, PASSWORD)
        uri = "api/v1/model1customvalidationapi/"
        next_index = MODEL1_DATA_SIZE + 1

        # Rejected: name does not start with "A".
        item = dict(
            field_string="test{}".format(next_index),
            field_integer=next_index,
            field_float=float(next_index),
            field_date=None,
        )
        rv = self.auth_client_post(client, token, uri, item)
        data = json.loads(rv.data.decode("utf-8"))
        eq_(rv.status_code, 422)
        eq_(data, {"message": {"field_string": ["Name must start with an A"]}})

        # Accepted: identical payload with an "A"-prefixed name. (The old
        # version re-parsed this body into an unused local and reassigned
        # `uri` to the same value; both were dead code.)
        item = dict(
            field_string="A{}".format(next_index),
            field_integer=next_index,
            field_float=float(next_index),
            field_date=None,
        )
        rv = self.auth_client_post(client, token, uri, item)
        eq_(rv.status_code, 201)

    def test_create_item_val_size(self):
        """REST Api: Test create validate size

        A POST whose field_string is 51 characters long must be rejected
        with 422 and a "Longer than maximum length 50." error.
        """
        client = self.app.test_client()
        token = self.login(client, USERNAME, PASSWORD)
        next_index = MODEL1_DATA_SIZE + 1
        payload = dict(
            field_string="a" * 51,
            field_integer=next_index,
            field_float=float(next_index),
        )
        response = self.auth_client_post(client, token, "api/v1/model1api/", payload)
        eq_(response.status_code, 422)
        errors = json.loads(response.data.decode("utf-8"))["message"]
        eq_(errors["field_string"][0], "Longer than maximum length 50.")

    def test_create_item_val_type(self):
        """REST Api: Test create validate type

        POSTs with a wrongly typed field must yield 422. The first sends a
        string for the integer column; the second sends an integer for the
        string column.
        """
        client = self.app.test_client()
        token = self.login(client, USERNAME, PASSWORD)
        uri = "api/v1/model1api/"
        label = "test{}".format(MODEL1_DATA_SIZE)

        # String where an integer is expected.
        payload = dict(
            field_string=label,
            field_integer=label,
            field_float=float(MODEL1_DATA_SIZE),
        )
        response = self.auth_client_post(client, token, uri, payload)
        eq_(response.status_code, 422)
        errors = json.loads(response.data.decode("utf-8"))["message"]
        eq_(errors["field_integer"][0], "Not a valid integer.")

        # Integer where a string is expected.
        payload = dict(
            field_string=MODEL1_DATA_SIZE,
            field_integer=MODEL1_DATA_SIZE,
            field_float=float(MODEL1_DATA_SIZE),
        )
        response = self.auth_client_post(client, token, uri, payload)
        eq_(response.status_code, 422)
        errors = json.loads(response.data.decode("utf-8"))["message"]
        eq_(errors["field_string"][0], "Not a valid string.")

    def test_create_item_excluded_cols(self):
        """REST Api: Test create with excluded columns

        Both a POST without and a POST with field_integer are accepted (201).
        The first row is then expected to have no integer, float or date
        value stored.
        """
        client = self.app.test_client()
        token = self.login(client, USERNAME, PASSWORD)
        uri = "api/v1/model1apiexcludecols/"
        first_name = "test{}".format(MODEL1_DATA_SIZE + 1)

        response = self.auth_client_post(client, token, uri, dict(field_string=first_name))
        eq_(response.status_code, 201)

        second = dict(
            field_string="test{}".format(MODEL1_DATA_SIZE + 2),
            field_integer=MODEL1_DATA_SIZE + 2,
        )
        response = self.auth_client_post(client, token, uri, second)
        eq_(response.status_code, 201)

        stored = self.db.session.query(Model1).filter_by(field_string=first_name).first()
        eq_(stored.field_integer, None)
        eq_(stored.field_float, None)
        eq_(stored.field_date, None)

    def test_create_item_with_enum(self):
        """REST Api: Test create item with enum

        POSTing ``enum2="e1"`` must store ``TmpEnum.e1`` on the new row.
        """
        client = self.app.test_client()
        token = self.login(client, USERNAME, PASSWORD)
        response = self.auth_client_post(
            client, token, "api/v1/modelwithenumsapi/", dict(enum2="e1")
        )
        created = json.loads(response.data.decode("utf-8"))
        eq_(response.status_code, 201)
        stored = self.db.session.query(ModelWithEnums).get(created["id"])
        eq_(stored.enum2, TmpEnum.e1)

    def test_create_item_mm_field(self):
        """REST Api: Test create with M-M field

        Covers three cases:

        * explicit children;
        * children omitted on an endpoint where the field is optional
          (defaults to []);
        * children omitted on an endpoint where it is required (422).
        """
        client = self.app.test_client()
        token = self.login(client, USERNAME, PASSWORD)
        optional_uri = "api/v1/modelmmapi/"
        required_uri = "api/v1/modelmmrequiredapi/"

        # Explicit M-M children are echoed back.
        response = self.auth_client_post(
            client, token, optional_uri, dict(field_string='new1', children=[1, 2])
        )
        eq_(response.status_code, 201)
        body = json.loads(response.data.decode("utf-8"))
        eq_(body[API_RESULT_RES_KEY], {"children": [1, 2], "field_string": "new1"})

        # Without M-M field data, default is not required.
        response = self.auth_client_post(
            client, token, optional_uri, dict(field_string='new2')
        )
        eq_(response.status_code, 201)
        body = json.loads(response.data.decode("utf-8"))
        eq_(body[API_RESULT_RES_KEY], {"children": [], "field_string": "new2"})

        # Without M-M field data, default is required.
        response = self.auth_client_post(
            client, token, required_uri, dict(field_string='new1')
        )
        eq_(response.status_code, 422)
        body = json.loads(response.data.decode("utf-8"))
        eq_(body, {"message": {"children": ["Missing data for required field."]}})

    def test_get_list_col_function(self):
        """REST Api: Test get list of objects with columns as functions

        Lists model1funcapi and checks three things: the total count, the
        default page size, and that every row on the first page carries the
        computed ``full_concat`` column in the form
        "test<i>.<i>.<float(i)>.None".
        """
        client = self.app.test_client()
        token = self.login(client, USERNAME, PASSWORD)
        uri = "api/v1/model1funcapi/"
        rv = self.auth_client_get(client, token, uri)
        eq_(rv.status_code, 200)
        data = json.loads(rv.data.decode("utf-8"))
        # Tests count property
        eq_(data["count"], MODEL1_DATA_SIZE)
        # Tests data result default page size
        results = data[API_RESULT_RES_KEY]
        eq_(len(results), self.model1api.page_size)
        # Check every row on the page. The previous range(1, page_size) loop
        # stopped one row short and never inspected the last item.
        for i, item in enumerate(results):
            eq_(
                item["full_concat"],
                "{}.{}.{}.{}".format("test" + str(i), i, float(i), None),
            )

    def test_openapi(self):
        """REST Api: Test OpenAPI spec

        An authenticated GET of the OpenAPI spec endpoint must return 200.
        """
        client = self.app.test_client()
        response = self.auth_client_get(
            client, self.login(client, USERNAME, PASSWORD), "api/v1/_openapi"
        )
        eq_(response.status_code, 200)
Esempio n. 42
0
class TestAPI(TestCase):
    """Integration tests for the Flask-APScheduler REST API endpoints."""

    # Far-future timestamps. The date-triggered job must not fire (or misfire
    # and be dropped) between requests. Also, test_update_job's next_run_time
    # assertion only holds while the cron start_date is still in the future.
    # The former 2025 dates turned these tests into time bombs once they passed.
    FUTURE_RUN_DATE = '2099-12-01T12:30:01+00:00'
    FUTURE_START_DATE = '2099-01-01'  # means midnight local time

    def setUp(self):
        self.app = Flask(__name__)
        self.scheduler = APScheduler()
        self.scheduler.api_enabled = True
        self.scheduler.init_app(self.app)
        self.scheduler.start()
        self.client = self.app.test_client()

    def tearDown(self):
        # Stop the scheduler started in setUp so it does not keep running
        # after the test finishes.
        self.scheduler.shutdown()

    def _url(self, suffix=''):
        """Return the scheduler API URL with *suffix* appended."""
        return self.scheduler.api_prefix + suffix

    def test_scheduler_info(self):
        response = self.client.get(self._url())
        self.assertEqual(response.status_code, 200)
        info = json.loads(response.get_data(as_text=True))
        self.assertIsNotNone(info['current_host'])
        self.assertEqual(info['allowed_hosts'], ['*'])
        self.assertTrue(info['running'])

    def test_add_job(self):
        job = {
            'id': 'job1',
            'func': 'tests.test_api:job1',
            'trigger': 'date',
            'run_date': self.FUTURE_RUN_DATE,
        }

        response = self.client.post(self._url('/jobs'), data=json.dumps(job))
        self.assertEqual(response.status_code, 200)

        job2 = json.loads(response.get_data(as_text=True))

        self.assertEqual(job.get('id'), job2.get('id'))
        self.assertEqual(job.get('func'), job2.get('func'))
        self.assertEqual(job.get('trigger'), job2.get('trigger'))
        self.assertEqual(job.get('run_date'), job2.get('run_date'))

    def test_add_conflicted_job(self):
        job = {
            'id': 'job1',
            'func': 'tests.test_api:job1',
            'trigger': 'date',
            'run_date': self.FUTURE_RUN_DATE,
        }

        response = self.client.post(self._url('/jobs'), data=json.dumps(job))
        self.assertEqual(response.status_code, 200)

        response = self.client.post(self._url('/jobs'), data=json.dumps(job))
        self.assertEqual(response.status_code, 409)

    def test_add_invalid_job(self):
        job = {
            'id': None,
        }

        response = self.client.post(self._url('/jobs'), data=json.dumps(job))
        self.assertEqual(response.status_code, 500)

    def test_delete_job(self):
        self.__add_job()

        response = self.client.delete(self._url('/jobs/job1'))
        self.assertEqual(response.status_code, 204)

        response = self.client.get(self._url('/jobs/job1'))
        self.assertEqual(response.status_code, 404)

    def test_delete_job_not_found(self):
        response = self.client.delete(self._url('/jobs/job1'))
        self.assertEqual(response.status_code, 404)

    def test_get_job(self):
        job = self.__add_job()

        response = self.client.get(self._url('/jobs/job1'))
        self.assertEqual(response.status_code, 200)

        job2 = json.loads(response.get_data(as_text=True))

        self.assertEqual(job.get('id'), job2.get('id'))
        self.assertEqual(job.get('func'), job2.get('func'))
        self.assertEqual(job.get('trigger'), job2.get('trigger'))
        self.assertEqual(job.get('minutes'), job2.get('minutes'))

    def test_get_job_not_found(self):
        response = self.client.get(self._url('/jobs/job1'))
        self.assertEqual(response.status_code, 404)

    def test_get_all_jobs(self):
        job = self.__add_job()

        response = self.client.get(self._url('/jobs'))
        self.assertEqual(response.status_code, 200)

        jobs = json.loads(response.get_data(as_text=True))

        self.assertEqual(len(jobs), 1)

        job2 = jobs[0]

        self.assertEqual(job.get('id'), job2.get('id'))
        self.assertEqual(job.get('func'), job2.get('func'))
        self.assertEqual(job.get('trigger'), job2.get('trigger'))
        self.assertEqual(job.get('minutes'), job2.get('minutes'))

    def test_update_job(self):
        job = self.__add_job()

        data_to_update = {
            'args': [1],
            'trigger': 'cron',
            'minute': '*/1',
            'start_date': self.FUTURE_START_DATE,
        }

        response = self.client.patch(self._url('/jobs/job1'),
                                     data=json.dumps(data_to_update))
        self.assertEqual(response.status_code, 200)

        job2 = json.loads(response.get_data(as_text=True))

        # Midnight local time on FUTURE_START_DATE. Because that is still in
        # the future, it is also the cron trigger's first fire time.
        expected_start = datetime(2099, 1, 1, tzinfo=tzlocal()).isoformat()

        self.assertEqual(job.get('id'), job2.get('id'))
        self.assertEqual(job.get('func'), job2.get('func'))
        self.assertEqual(data_to_update.get('args'), job2.get('args'))
        self.assertEqual(data_to_update.get('trigger'), job2.get('trigger'))
        self.assertEqual(expected_start, job2.get('start_date'))
        self.assertEqual(expected_start, job2.get('next_run_time'))

    def test_update_job_not_found(self):
        data_to_update = {
            'args': [1],
            'trigger': 'cron',
            'minute': '*/1',
            'start_date': self.FUTURE_START_DATE,
        }

        response = self.client.patch(self._url('/jobs/job1'),
                                     data=json.dumps(data_to_update))
        self.assertEqual(response.status_code, 404)

    def test_update_invalid_job(self):
        self.__add_job()

        data_to_update = {
            'trigger': 'invalid_trigger',
        }

        response = self.client.patch(self._url('/jobs/job1'),
                                     data=json.dumps(data_to_update))
        self.assertEqual(response.status_code, 500)

    def test_pause_and_resume_job(self):
        self.__add_job()

        response = self.client.post(self._url('/jobs/job1/pause'))
        self.assertEqual(response.status_code, 200)
        job = json.loads(response.get_data(as_text=True))
        self.assertIsNone(job.get('next_run_time'))

        response = self.client.post(self._url('/jobs/job1/resume'))
        self.assertEqual(response.status_code, 200)
        job = json.loads(response.get_data(as_text=True))
        self.assertIsNotNone(job.get('next_run_time'))

    def test_pause_and_resume_job_not_found(self):
        response = self.client.post(self._url('/jobs/job1/pause'))
        self.assertEqual(response.status_code, 404)

        response = self.client.post(self._url('/jobs/job1/resume'))
        self.assertEqual(response.status_code, 404)

    def test_run_job(self):
        self.__add_job()

        response = self.client.post(self._url('/jobs/job1/run'))
        self.assertEqual(response.status_code, 200)
        job = json.loads(response.get_data(as_text=True))
        self.assertIsNotNone(job.get('next_run_time'))

    def test_run_job_not_found(self):
        response = self.client.post(self._url('/jobs/job1/run'))
        self.assertEqual(response.status_code, 404)

    def __add_job(self):
        """Create interval job 'job1' via the API and return its JSON form."""
        job = {
            'id': 'job1',
            'func': 'tests.test_api:job1',
            'trigger': 'interval',
            'minutes': 10,
        }

        response = self.client.post(self._url('/jobs'), data=json.dumps(job))
        return json.loads(response.get_data(as_text=True))
Esempio n. 43
0
from flask import Flask
from main import rest_methods

import json
import unittest

# Module-level app shared by every test below: the ``rest_methods``
# blueprint (imported from ``main``) is registered once, and a single
# test client is reused across all test methods.
app = Flask(__name__)
app.register_blueprint(rest_methods)

test_client = app.test_client()


class TestRESTMethods(unittest.TestCase):

    # testing a get request
    def test_get_method(self):
        """GET /one/ must echo the method and the captured 'id' URL variable."""
        expected = {
            "request_method": "GET",
            "request_variables": {"id": "one"},
            "request_args": {},
            "request_data": {},
        }
        with test_client.get("/one/") as response:
            payload = json.loads(response.get_data(as_text=True))
            self.assertEqual(payload, expected)

    # testing a post request
    def test_post_method(self):
class TestUserClaimsVerification(unittest.TestCase):
    """Tests for the JWT manager's user-claims verification hooks.

    setUp installs a verifier that accepts a token only when its user claims
    contain both ``foo`` and ``bar``. Each test varies the claims loader or
    the verification-failure callback and inspects ``/protected``.
    """

    def setUp(self):
        self.app = Flask(__name__)
        self.app.secret_key = 'super=secret'
        self.jwt_manager = JWTManager(self.app)
        self.client = self.app.test_client()

        @self.jwt_manager.claims_verification_loader
        def claims_verification(user_claims):
            # Accept only when every required key is present.
            return all(key in user_claims for key in ('foo', 'bar'))

        @self.app.route('/auth/login', methods=['POST'])
        def login():
            # Issue an access token for the fixed identity 'test'.
            return jsonify({'access_token': create_access_token('test')}), 200

        @self.app.route('/protected')
        @jwt_required
        def protected():
            return jsonify({'msg': "hello world"})

    def _login(self):
        """POST to /auth/login and return the issued access token."""
        reply = self.client.post('/auth/login')
        return json.loads(reply.get_data(as_text=True))['access_token']

    def _jwt_get(self,
                 url,
                 jwt,
                 header_name='Authorization',
                 header_type='Bearer'):
        """GET *url* with *jwt* in header *header_name*.

        The header value is "<header_type> <jwt>". With an empty
        *header_type* it is just the token. Returns (status code, decoded
        JSON body).
        """
        header_value = '{} {}'.format(header_type, jwt).strip()
        response = self.client.get(url, headers={header_name: header_value})
        return response.status_code, json.loads(response.get_data(as_text=True))

    def test_valid_user_claims(self):
        @self.jwt_manager.user_claims_loader
        def claims_with_required_keys(identity):
            return {'foo': 'baz', 'bar': 'boom'}

        status, body = self._jwt_get('/protected', self._login())
        self.assertEqual(body, {'msg': 'hello world'})
        self.assertEqual(status, 200)

    def test_empty_claims_verification_error(self):
        status, body = self._jwt_get('/protected', self._login())
        self.assertEqual(body, {'msg': 'User claims verification failed'})
        self.assertEqual(status, 400)

    def test_bad_claims_verification_error(self):
        @self.jwt_manager.user_claims_loader
        def claims_missing_keys(identity):
            return {'super': 'banana'}

        status, body = self._jwt_get('/protected', self._login())
        self.assertEqual(body, {'msg': 'User claims verification failed'})
        self.assertEqual(status, 400)

    def test_bad_claims_custom_error_callback(self):
        @self.jwt_manager.claims_verification_failed_loader
        def custom_failure_response():
            return jsonify({'foo': 'bar'}), 404

        status, body = self._jwt_get('/protected', self._login())
        self.assertEqual(body, {'foo': 'bar'})
        self.assertEqual(status, 404)
Esempio n. 45
0
class TestIwebModule(unittest.TestCase):
    """Exercise the iweb.* API methods through the Flask test client."""

    def setUp(self):
        # Wire IWeb -> UbersmithBase -> Flask app around a fresh DataStore.
        self.data_store = DataStore()
        self.iweb = IWeb(self.data_store)

        self.app = Flask(__name__)
        self.base_iweb_api = UbersmithBase(self.data_store)

        self.iweb.hook_to(self.base_iweb_api)
        self.base_iweb_api.hook_to(self.app)

    @staticmethod
    def _call(client, form):
        """POST *form* to the 2.0 API endpoint and return the response."""
        return client.post('api/2.0/', data=form)

    @staticmethod
    def _body(resp):
        """Decode the JSON payload of *resp*."""
        return json.loads(resp.data.decode('utf-8'))

    def test_log_event_successfully(self):
        event = {
            "event_type": "Some event type",
            "reference_type": "client",
            "action": "Client Id 12345 performed a 'method'.",
            "clientid": "clientid",
            "user": "******",
            "reference_id": "clientid",
        }
        with self.app.test_client() as c:
            resp = self._call(c, dict(method="iweb.log_event", **event))

        self.assertEqual(resp.status_code, 200)
        self.assertEqual(self._body(resp), {
            "data": "1",
            "error_code": None,
            "error_message": "",
            "status": True,
        })
        self.assertEqual(self.data_store.event_log[0], event)

    def test_add_role_successfully(self):
        form = {
            "method": "iweb.acl_admin_role_add",
            'name': 'A Admin Role',
            'descr': 'A Admin Role',
            'acls[admin.portal][read]': 1,
        }
        with self.app.test_client() as c:
            resp = self._call(c, form)

        # The store generates the id; grab whichever one was created.
        role_id = next(iter(self.data_store.roles))

        self.assertEqual(resp.status_code, 200)
        self.assertEqual(self._body(resp), {
            "data": role_id,
            "error_code": None,
            "error_message": "",
            "status": True,
        })
        expected_role = {
            'role_id': role_id,
            'name': 'A Admin Role',
            'descr': 'A Admin Role',
            'acls': {'admin.portal': {'read': '1'}},
        }
        self.assertEqual(self.data_store.roles, {role_id: expected_role})

    def test_add_role_that_already_exists_fails(self):
        self.data_store.roles = {
            'some_role_id': {'role_id': 'some_role_id', 'name': 'A Admin Role'},
        }
        form = {
            "method": "iweb.acl_admin_role_add",
            'name': 'A Admin Role',
            'descr': 'A Admin Role',
            'acls[admin.portal][read]': 1,
        }
        with self.app.test_client() as c:
            resp = self._call(c, form)

        self.assertEqual(resp.status_code, 200)
        self.assertEqual(self._body(resp), {
            "data": "",
            "error_code": 1,
            "error_message": "The specified Role Name is already in use",
            "status": False,
        })

    def test_add_user_role_successfully(self):
        user_id = 'some_user_id'
        self.data_store.roles = {'a_role_id': {}, 'another_role_id': {}}
        with self.app.test_client() as c:
            self._call(c, {
                "method": "iweb.user_role_assign",
                "user_id": user_id,
                "role_id": "a_role_id",
            })
            resp = self._call(c, {
                "method": "iweb.user_role_assign",
                "user_id": user_id,
                "role_id": "another_role_id",
            })

        self.assertEqual(resp.status_code, 200)
        self.assertEqual(self._body(resp), {
            "status": True,
            "error_code": None,
            "error_message": "",
            "data": 1,
        })
        self.assertEqual(self.data_store.user_mapping[user_id],
                         {'roles': {'a_role_id', 'another_role_id'}})

    def test_add_same_role_to_user_not_allowed(self):
        role_id = 'some_role_id'
        user_id = 'some_user_id'
        # NOTE(review): `roles` does not contain role_id, so this may exercise
        # an "unknown role" path rather than the duplicate-assignment one the
        # test name describes. Confirm against IWeb's implementation.
        self.data_store.roles = {'1': {}}
        self.data_store.user_mapping[user_id] = {'roles': {role_id}}
        with self.app.test_client() as c:
            resp = self._call(c, {
                "method": "iweb.user_role_assign",
                "user_id": user_id,
                "role_id": role_id,
            })

        expected_message = ("Can't assign role with id '{}' to user "
                            "with id '{}'".format(role_id, user_id))
        self.assertEqual(resp.status_code, 200)
        self.assertEqual(self._body(resp), {
            "error_code": 1,
            "error_message": expected_message,
            "status": False,
            "data": "",
        })
Esempio n. 46
0
    def setUp(self):
        """Mount InternalBlueprint on a fresh app in testing mode.

        Note: ``self.app`` holds the app's *test client*, not the Flask app.
        """
        flask_app = Flask(__name__)
        flask_app.register_blueprint(InternalBlueprint)
        flask_app.testing = True
        self.app = flask_app.test_client()
Esempio n. 47
0
class BootstrapTestCase(unittest.TestCase):
    """Tests for the Bootstrap extension's resource loaders and template macros.

    Each test gets its own fixture:

    * a fresh Flask app in testing mode, with in-memory SQLite and the
      extension initialised;
    * a pushed test request context (popped in tearDown);
    * a test client.
    """

    def setUp(self):
        self.app = Flask(__name__)
        self.app.testing = True
        self.app.config['SQLALCHEMY_DATABASE_URI'] = 'sqlite:///'
        self.app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False
        self.app.secret_key = 'for test'
        self.bootstrap = Bootstrap(self.app)  # noqa

        @self.app.route('/')
        def index():
            return render_template_string(
                '{{ bootstrap.load_css() }}{{ bootstrap.load_js() }}')

        self.context = self.app.test_request_context()
        self.context.push()
        self.client = self.app.test_client()

    def tearDown(self):
        self.context.pop()

    def _add_paginated_route(self, rule, template):
        """Register *rule* to render *template* with a 10-per-page pagination.

        Each request recreates the tables, inserts 100 rows, then paginates
        them using the ``page`` query argument (default 1).
        """
        db = SQLAlchemy(self.app)

        class Message(db.Model):
            id = db.Column(db.Integer, primary_key=True)

        @self.app.route(rule)
        def test():
            db.drop_all()
            db.create_all()
            db.session.add_all([Message() for _ in range(100)])
            db.session.commit()
            page = request.args.get('page', 1, type=int)
            # Keyword arguments: Flask-SQLAlchemy 3.x made paginate()'s
            # parameters keyword-only, so a positional page raises TypeError.
            # Keywords also work on 2.x.
            pagination = Message.query.paginate(page=page, per_page=10)
            return render_template_string(template,
                                          pagination=pagination,
                                          messages=pagination.items)

    def test_extension_init(self):
        self.assertIn('bootstrap', current_app.extensions)

    def test_load_css(self):
        rv = self.bootstrap.load_css()
        self.assertIn('bootstrap.min.css', rv)

    def test_load_js(self):
        rv = self.bootstrap.load_js()
        self.assertIn('bootstrap.min.js', rv)

    def test_local_resources(self):
        current_app.config['BOOTSTRAP_SERVE_LOCAL'] = True

        response = self.client.get('/')
        data = response.get_data(as_text=True)
        self.assertNotIn('https://cdn.jsdelivr.net/npm/bootstrap', data)
        self.assertIn('bootstrap.min.js', data)
        self.assertIn('bootstrap.min.css', data)
        self.assertIn('jquery.min.js', data)

        css_response = self.client.get(
            '/bootstrap/static/css/bootstrap.min.css')
        js_response = self.client.get('/bootstrap/static/js/bootstrap.min.js')
        jquery_response = self.client.get('/bootstrap/static/jquery.min.js')
        self.assertNotEqual(css_response.status_code, 404)
        self.assertNotEqual(js_response.status_code, 404)
        self.assertNotEqual(jquery_response.status_code, 404)

        css_rv = self.bootstrap.load_css()
        js_rv = self.bootstrap.load_js()
        self.assertIn('/bootstrap/static/css/bootstrap.min.css', css_rv)
        self.assertIn('/bootstrap/static/js/bootstrap.min.js', js_rv)
        self.assertNotIn('https://cdn.jsdelivr.net/npm/bootstrap', css_rv)
        self.assertNotIn('https://cdn.jsdelivr.net/npm/bootstrap', js_rv)

    def test_cdn_resources(self):
        response = self.client.get('/')
        data = response.get_data(as_text=True)
        self.assertIn('https://cdn.jsdelivr.net/npm/bootstrap', data)
        self.assertIn('bootstrap.min.js', data)
        self.assertIn('bootstrap.min.css', data)

        css_rv = self.bootstrap.load_css()
        js_rv = self.bootstrap.load_js()
        self.assertNotIn('/bootstrap/static/css/bootstrap.min.css', css_rv)
        self.assertNotIn('/bootstrap/static/js/bootstrap.min.js', js_rv)
        self.assertIn('https://cdn.jsdelivr.net/npm/bootstrap', css_rv)
        self.assertIn('https://cdn.jsdelivr.net/npm/bootstrap', js_rv)

    def test_render_field(self):
        @self.app.route('/field')
        def test():
            form = HelloForm()
            return render_template_string('''
            {% from 'bootstrap/form.html' import render_field %}
            {{ render_field(form.username) }}
            {{ render_field(form.password) }}
            ''',
                                          form=form)

        response = self.client.get('/field')
        data = response.get_data(as_text=True)
        self.assertIn(
            '<input class="form-control" id="username" name="username"', data)
        self.assertIn(
            '<input class="form-control" id="password" name="password"', data)

    def test_render_form(self):
        @self.app.route('/form')
        def test():
            form = HelloForm()
            return render_template_string('''
                    {% from 'bootstrap/form.html' import render_form %}
                    {{ render_form(form) }}
                    ''',
                                          form=form)

        response = self.client.get('/form')
        data = response.get_data(as_text=True)
        self.assertIn(
            '<input class="form-control" id="username" name="username"', data)
        self.assertIn(
            '<input class="form-control" id="password" name="password"', data)

    def test_render_pager(self):
        self._add_paginated_route('/pager', '''
                            {% from 'bootstrap/pagination.html' import render_pager %}
                            {{ render_pager(pagination) }}
                            ''')

        response = self.client.get('/pager')
        data = response.get_data(as_text=True)
        self.assertIn('<nav aria-label="Page navigation">', data)
        self.assertIn('Previous', data)
        self.assertIn('Next', data)
        self.assertIn('<li class="page-item disabled">', data)

        response = self.client.get('/pager?page=2')
        data = response.get_data(as_text=True)
        self.assertIn('<nav aria-label="Page navigation">', data)
        self.assertIn('Previous', data)
        self.assertIn('Next', data)
        self.assertNotIn('<li class="page-item disabled">', data)

    def test_render_pagination(self):
        self._add_paginated_route('/pagination', '''
                                    {% from 'bootstrap/pagination.html' import render_pagination %}
                                    {{ render_pagination(pagination) }}
                                    ''')

        response = self.client.get('/pagination')
        data = response.get_data(as_text=True)
        self.assertIn('<nav aria-label="Page navigation">', data)
        self.assertIn(
            '<a class="page-link" href="#">1 <span class="sr-only">(current)</span></a>',
            data)
        self.assertIn('10</a>', data)

        response = self.client.get('/pagination?page=2')
        data = response.get_data(as_text=True)
        self.assertIn('<nav aria-label="Page navigation">', data)
        self.assertIn('1</a>', data)
        self.assertIn(
            '<a class="page-link" href="#">2 <span class="sr-only">(current)</span></a>',
            data)
        self.assertIn('10</a>', data)

    def test_render_nav_item(self):
        # The view must be named `test`: render_nav_item('test', ...) targets
        # that endpoint and marks it active when it is the current request.
        @self.app.route('/nav_item')
        def test():
            return render_template_string('''
                    {% from 'bootstrap/nav.html' import render_nav_item %}
                    {{ render_nav_item('test', 'Home') }}
                    ''')

        response = self.client.get('/nav_item')
        data = response.get_data(as_text=True)
        self.assertIn('<a class="nav-item nav-link active"', data)

    def test_render_breadcrumb_item(self):
        @self.app.route('/breadcrumb_item')
        def test():
            return render_template_string('''
                    {% from 'bootstrap/nav.html' import render_breadcrumb_item %}
                    {{ render_breadcrumb_item('test', 'Home') }}
                    ''')

        response = self.client.get('/breadcrumb_item')
        data = response.get_data(as_text=True)
        self.assertIn(
            '<li class="breadcrumb-item active"  aria-current="page">', data)

    def test_render_static(self):
        @self.app.route('/test_static')
        def test():
            return render_template_string('''
                            {% from 'bootstrap/utils.html' import render_static %}
                            {{ render_static('css', 'test.css') }}
                            {{ render_static('js', 'test.js') }}
                            {{ render_static('icon', 'test.ico') }}
                            ''')

        response = self.client.get('/test_static')
        data = response.get_data(as_text=True)
        self.assertIn(
            '<link rel="stylesheet" href="/static/test.css" type="text/css">',
            data)
        self.assertIn(
            '<script type="text/javascript" src="/static/test.js"></script>',
            data)
        self.assertIn('<link rel="icon" href="/static/test.ico">', data)
Esempio n. 48
0
class WebServer:
    """Flask web front-end: accounts, the simulator and robot-designer pages,
    per-user asset management and the public item catalogue.

    Every route is registered on ``self.app`` by :meth:`set_routes`; all data
    access goes through the ``database`` object given to the constructor.
    """

    def __init__(self, database):
        """Create and configure the Flask app, then register all routes.

        :param database: data-access object used by every route.
        """
        self.db = database
        self.app = Flask(__name__)

        # NOTE(review): Windows-style path; the routes below build paths from
        # UPLOAD_DIR instead, so this setting looks unused here -- confirm.
        UPLOAD_FOLDER = "\\uploads\\sounds"
        self.app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER

        self.app.secret_key = SECRET_KEY

        self.set_routes()

    def get_app(self):
        """Return a Flask test client bound to this server's app."""
        return self.app.test_client()

    def start(self):
        """Serve the app on all interfaces (host 0.0.0.0) with debug off."""
        self.app.run(debug=False, host='0.0.0.0')

    def shutdown_server(self):
        """Stop the Werkzeug development server; must run inside a request.

        :raises RuntimeError: if the WSGI environ has no Werkzeug shutdown hook.
        """
        # NOTE(review): recent Werkzeug releases no longer provide this hook.
        func = request.environ.get('werkzeug.server.shutdown')
        if func is None:
            raise RuntimeError('Not running with the Werkzeug Server')
        func()

    @staticmethod
    def allowed_file(filename):
        """Return True if *filename* has an extension listed in ALLOWED_EXTENSIONS."""
        return WebServer.get_file_ext(filename) in ALLOWED_EXTENSIONS

    @staticmethod
    def get_file_ext(filename):
        """Return the lower-cased text after the last '.' of *filename*,
        or False when the name contains no '.'."""
        return '.' in filename and filename.rsplit('.', 1)[1].lower()

    @staticmethod
    def _read_file(path):
        """Return the whole text content of *path*, always closing the file."""
        with open(path) as f:
            return f.read()

    @staticmethod
    def _is_safe_name(name):
        """Return True if *name* is a single, plain file-name component.

        Rejects empty/None names, '.' and '..', and anything containing a path
        separator or NUL, so URL input cannot escape the serving directory.
        """
        return (bool(name) and name not in ('.', '..')
                and not any(ch in name for ch in '/\\\x00'))

    @staticmethod
    def _safe_redirect_path(ref):
        """Turn a user-supplied ``ref`` into the same-site path ``/<ref>``.

        Whitespace/control characters (which browsers drop from URLs) and
        leading slashes/backslashes are removed, so the result can never be a
        protocol-relative ``//other-host`` URL (open redirect).
        """
        cleaned = ''.join(ch for ch in str(ref)
                          if ch.isprintable() and not ch.isspace())
        return '/' + cleaned.lstrip('/\\')

    def set_routes(self):
        """Register every HTTP route on ``self.app``.

        The views are closures over ``self`` (for ``self.db``). Each nested
        view's function name is its Flask endpoint name, so keep those stable.
        """

        # ---- shared helpers -------------------------------------------------

        def owns(item):
            # True only for an existing item belonging to the logged-in user.
            return item is not None and item.userID == session.get('userID')

        def usable(item, kind):
            # True for an existing item the current user owns, or that is
            # published as a public item of type `kind`.
            return item is not None and (
                item.userID == session['userID']
                or self.db.item_is_public(item.id, kind))

        def write_json(path, data):
            # Pretty-print `data` to `path`, keeping key order and non-ASCII text.
            with open(path, 'w') as f:
                json.dump(data,
                          f,
                          sort_keys=False,
                          indent=4,
                          ensure_ascii=False)

        def add_counts(items):
            # Map each public item's id to the number of user_added_items rows
            # that reference it.
            return {
                x.id: len(
                    self.db.get_all(
                        'SELECT * FROM user_added_items WHERE itemID = ?',
                        [x.id],
                        type=UserAddedItem))
                for x in items
            }

        def find_public_items():
            # Public items whose type matches ?type= (SQL LIKE pattern, default
            # '%' = any) and, when ?query= is present, whose name or
            # description contains the query text.
            query = request.args.get('query')
            search_for = request.args.get('type', '%')
            if query is None:
                return self.db.get_all(
                    'SELECT * FROM public_items WHERE type LIKE ?',
                    [search_for],
                    type=PublicItem)
            pattern = '%' + query + '%'
            return self.db.get_all(
                "SELECT * FROM public_items WHERE type LIKE ? AND (name LIKE ? OR description LIKE ?)",
                [search_for, pattern, pattern],
                type=PublicItem)

        def delete_owned(item, path_attrs, delete_row, row_id, error_message):
            # Shared body of the /remove* endpoints. Only the owner may delete.
            # Removing backing files is best-effort (a failure is reported on
            # stdout); the DB row is deleted once ownership is confirmed.
            if not owns(item):
                return jsonify({'success': 'false'})
            failed = False
            for attr in path_attrs:
                path = getattr(item, attr, None)
                if path is None:
                    continue  # nothing stored for this path
                try:
                    os.remove(path)
                except (OSError, TypeError):
                    failed = True
            if failed:
                print(error_message)
            delete_row(row_id)
            return jsonify({'success': 'true'})

        def processFileUpload(upload, user_id):
            # Save an uploaded .wav as a Sound, or any other allowed file as a
            # microphone response; returns the inserted DB object or None.
            # Browsers send an empty part without filename when nothing is
            # selected, which makes the FileStorage falsy.
            if not upload:
                return None
            if not WebServer.allowed_file(upload.filename):
                return None

            unique_name = uuid.uuid4()
            if WebServer.get_file_ext(upload.filename) == 'wav':
                savename = UPLOAD_DIR + 'sounds/{0}.wav'.format(unique_name)
                upload.save(savename)
                return self.db.insert_sound(
                    Sound(upload.filename, savename, user_id))

            # Must be a txt/mic response file
            savename = UPLOAD_DIR + 'mic_responses/{0}.txt'.format(unique_name)
            upload.save(savename)
            return self.db.insert_microphone(
                Microphone(upload.filename, savename, user_id))

        def insertFilePaths(conf, files, mot_id_map, mic_id_map):
            # Fill in the uid/path of each motor's sound and each mic's
            # response file in the robot config `conf` (mutated and returned).
            # `files` holds this request's uploads ({'sounds': {...},
            # 'responses': {...}} keyed by form field id); ids not found there
            # are looked up as pre-existing DB rows.
            # NOTE(review): pre-existing ids are not ownership-checked here.
            motors = conf['robot_config']['motors']
            mics = conf['robot_config']['microphones']
            for motor in motors:
                if str(motor['id']) in mot_id_map:
                    snd_id = str(mot_id_map[str(motor['id'])])
                    if snd_id in files['sounds']:
                        motor['sound']['uid'] = files['sounds'][snd_id].id
                        motor['sound']['path'] = files['sounds'][
                            snd_id].pathToFile
                    else:  # else Must be a pre existing sound
                        sound_obj = self.db.get_sound(snd_id)
                        motor['sound']['uid'] = sound_obj.id
                        motor['sound']['path'] = sound_obj.pathToFile

            for mic in mics:
                if str(mic['id']) in mic_id_map:
                    res_id = str(mic_id_map[str(mic['id'])])
                    if res_id in files['responses']:
                        mic['mic_style']['uid'] = files['responses'][res_id].id
                        mic['mic_style']['path'] = files['responses'][
                            res_id].pathToFile
                    else:  # else Must be a pre existing response
                        mic_obj = self.db.get_microphone(res_id)
                        mic['mic_style']['uid'] = mic_obj.id
                        mic['mic_style']['path'] = mic_obj.pathToFile

            return conf

        # ---- pages and authentication -----------------------------------------

        @self.app.route("/")
        def index():
            user = self.db.get_user(
                id=session['userID']) if 'userID' in session else None
            return render_template('index.html', user=user)

        @self.app.route('/signup', methods=['POST', 'GET'])
        def signup():
            if request.method != 'POST':
                return render_template('signup.html', message="")

            email = request.form['email']
            if self.db.is_email_used(email):
                return render_template('signup.html', message="used-email")

            password = request.form['password']
            if len(password) < 8:
                return render_template('signup.html', message="short-pass")
            if password != request.form['confpass']:
                return render_template('signup.html', message="no-match-pass")

            # NOTE(review): salted SHA-512 is a fast hash; a dedicated password
            # KDF would be stronger, but changing it requires migrating
            # existing stored hashes.
            salt = uuid.uuid4().hex
            pass_hash = hashlib.sha512(
                str(password + salt).encode('utf-8')).hexdigest()
            user = self.db.insert_user(User(email, pass_hash, salt))
            session['userID'] = user.id
            return redirect('/')

        @self.app.route('/login', methods=['POST', 'GET'])
        def login():
            import hmac

            if request.method != 'POST':  # GET
                # Remember where to send the user after a successful login.
                if 'ref' in request.args:
                    session['ref_url'] = request.args.get('ref')
                return render_template('login.html', message="")

            email = request.form['email']
            if not self.db.is_email_used(email):
                return render_template('login.html', message="no-match-email")

            user = self.db.get_user(email=email)
            input_pass_hash = hashlib.sha512(
                str(request.form['password'] +
                    user.salt).encode('utf-8')).hexdigest()
            # Constant-time comparison so response timing leaks nothing about
            # the stored hash.
            if not hmac.compare_digest(input_pass_hash, user.hash):
                return render_template('login.html', message="wrong-pass")

            session['userID'] = user.id
            ref = session.pop('ref_url', None)
            if ref is not None:
                # Never redirect off-site, whatever the stored ref contains.
                return redirect(WebServer._safe_redirect_path(ref))
            return redirect('/')

        @self.app.route('/logout', methods=['GET'])
        def logout():
            session.pop('userID', None)
            return redirect('/')

        @self.app.route('/dashboard', methods=['GET'])
        def dashboard():
            if 'userID' not in session: return redirect('/login?ref=dashboard')
            user_id = session['userID']

            simulations = self.db.get_user_sims(user_id)
            sounds = self.db.get_user_sounds(user_id)
            robots = self.db.get_user_robots(user_id)
            microphones = self.db.get_user_mics(user_id)

            user_added_items = self.db.get_all(
                'SELECT * FROM user_added_items WHERE userID = ?',
                [user_id],
                type=UserAddedItem)

            looked_up = (
                self.db.get_one('SELECT * FROM public_items WHERE id = ?',
                                [x.itemID],
                                type=PublicItem) for x in user_added_items)
            # get_one yields None for a missing row; skip public items that no
            # longer exist instead of crashing on `.id` below.
            added_items = [item for item in looked_up if item is not None]

            return render_template('dashboard.html',
                                   user=self.db.get_user(id=user_id),
                                   simulations=simulations,
                                   sounds=sounds,
                                   robots=robots,
                                   microphones=microphones,
                                   addedItems=added_items,
                                   addCount=add_counts(added_items))

        @self.app.route('/simulator', methods=['GET'])
        def simulator():
            if 'userID' not in session: return redirect('/login?ref=simulator')

            simconf = ""
            if 'sim' in request.args:
                sim = self.db.get_simulation(request.args['sim'])
                if owns(sim):
                    simconf = WebServer._read_file(sim.pathToConfig)

            sounds = self.db.get_user_sounds(session['userID'])
            robots = self.db.get_user_robots(session['userID'])

            return render_template('simulator.html',
                                   user=self.db.get_user(id=session['userID']),
                                   sounds=sounds,
                                   robots=robots,
                                   simconf=simconf)

        # ---- deleting the user's own items --------------------------------------

        @self.app.route("/removesimulation", methods=['POST'])
        def remove_sim():
            if 'userID' not in session: return jsonify({'success': 'false'})
            sim_id = request.form['sim_id']
            return delete_owned(self.db.get_simulation(sim_id),
                                ('pathToConfig', 'pathToZip'),
                                self.db.delete_simulation, sim_id,
                                "Cannot delete simulation files!")

        @self.app.route("/removesound", methods=['POST'])
        def remove_sound():
            if 'userID' not in session: return jsonify({'success': 'false'})
            sound_id = request.form['sound_id']
            return delete_owned(self.db.get_sound(sound_id), ('pathToFile',),
                                self.db.delete_sound, sound_id,
                                "Cannot delete sound file!")

        @self.app.route("/removerobot", methods=['POST'])
        def remove_robot():
            if 'userID' not in session: return jsonify({'success': 'false'})
            robot_id = request.form['robot_id']
            return delete_owned(self.db.get_robot(robot_id), ('pathToConfig',),
                                self.db.delete_robot, robot_id,
                                "Cannot delete robot file!")

        @self.app.route("/removemic", methods=['POST'])
        def remove_mic():
            if 'userID' not in session: return jsonify({'success': 'false'})
            mic_id = request.form['mic_id']
            return delete_owned(self.db.get_microphone(mic_id),
                                ('pathToConfig',),
                                self.db.delete_microphone, mic_id,
                                "Cannot delete microphone file!")

        @self.app.route('/revoke_simulation', methods=['POST'])
        def revoke_simulation():
            if 'userID' not in session: return jsonify({"success": "false"})

            simid = request.form['sim_id']
            sim = self.db.get_simulation(simid)
            # Only the owner may cancel a simulation.
            if not owns(sim):
                return jsonify({"success": "false"})
            endTask(sim.taskID)

            self.db.run_query("UPDATE simulations SET state = ? WHERE id = ?",
                              ("cancelled", simid))

            return jsonify({"success": "true"})

        # ---- running simulations ------------------------------------------------

        @self.app.route('/simulator/run_simulation', methods=['POST'])
        def run_simulation():
            # Either re-save an existing simulation's config ('sim_to_update'),
            # or store a new config and queue it via runSimulation.delay.
            if 'userID' not in session:
                return jsonify({"success": "false", "reason": "no user found"})
            user_id = session['userID']

            if 'sim_id' in request.form:
                # Start from a stored config; only its owner may read it.
                old_sim = self.db.get_simulation(request.form['sim_id'])
                if not owns(old_sim):
                    return jsonify({
                        "success": "false",
                        "reason": "No simulation found."
                    })
                sim_conf = json.loads(WebServer._read_file(old_sim.pathToConfig))
            else:
                sim_conf = json.loads(request.form['config'])

            # Link variables (util.link_vars) into the simulation settings.
            sim_conf['simulation_config'] = util.link_vars(
                sim_conf['simulation_config'], sim_conf['variables'])

            seed = str(uuid.uuid4())
            sim_conf['simulation_config']['seed'] = seed

            if 'sim_to_update' in request.form:
                sim = self.db.get_simulation(request.form['sim_to_update'])
                if not owns(sim):
                    return jsonify({"success": "false"})
                write_json(sim.pathToConfig, sim_conf)
                return jsonify({"success": "true"})

            # Utterance: a fresh upload or an existing sound the user may use.
            if 'utterance' in request.files:
                utt = processFileUpload(request.files['utterance'], user_id)
            else:
                utt = self.db.get_sound(request.form['utterance_id'])
            if not usable(utt, "SOUND"):
                return jsonify({
                    "success": "false",
                    "reason": "No sound found."
                })

            # Background noise is optional; a negative bgnoise_id means none.
            bgn = None
            if 'bgnoise' in request.files:
                bgn = processFileUpload(request.files['bgnoise'], user_id)
            else:
                try:
                    wants_bgnoise = int(request.form['bgnoise_id']) >= 0
                except ValueError:
                    return jsonify({
                        "success": "false",
                        "reason": "Invalid background noise id."
                    })
                if wants_bgnoise:
                    bgn = self.db.get_sound(request.form['bgnoise_id'])
            if bgn is not None and not usable(bgn, "SOUND"):
                return jsonify({
                    "success": "false",
                    "reason": "No background noise sound found."
                })

            sounds = {
                'utterance': utt.pathToFile,
                'bgnoise': bgn.pathToFile if bgn is not None else None
            }

            robot = self.db.get_one("SELECT * FROM robots WHERE id =?",
                                    [request.form['robot_id']],
                                    type=Robot)
            if not usable(robot, "ROBOT"):
                print("Robot: {0}".format(robot))
                if robot is not None:
                    print("userID {0}, sessionID {1}".format(
                        robot.userID, user_id))
                    print("public item {0}".format(
                        self.db.item_is_public(robot.id, "ROBOT")))
                return jsonify({
                    "success": "false",
                    "reason": "No robot found."
                })

            robot_conf_dict = json.loads(WebServer._read_file(robot.pathToConfig))

            # Persist the config only after every input is validated, so a
            # rejected request leaves no orphaned config file behind.
            unique_name = uuid.uuid4()
            filename = UPLOAD_DIR + "simulation_configs/{0}.json".format(
                unique_name)
            print("putting sim file in: {0}".format(filename))
            write_json(filename, sim_conf)

            date = str(dt.now().date())
            sim = self.db.insert_simulation(
                Simulation(filename, date, seed, user_id))

            print("Running sim with: ", sim_conf, robot_conf_dict, sounds)

            runSimulation.delay(sim_conf, robot_conf_dict, sounds,
                                unique_name, sim.id)

            return jsonify({"success": "true"})

        # ---- robot designer -----------------------------------------------------

        @self.app.route('/designer/save', methods=['POST'])
        def save_robot_config():
            if 'userID' not in session: return jsonify({"success": "false"})
            user_id = session['userID']

            # Check ownership before storing any upload, so a rejected update
            # leaves no stray files or rows.
            robot = None
            if 'robot_to_update' in request.form:
                robot = self.db.get_robot(request.form['robot_to_update'])
                if not owns(robot):
                    return jsonify({"success": "false"})

            # Process uploaded sounds / mic responses, keyed by form field id.
            files = {'sounds': {}, 'responses': {}}
            for field_id in request.files:
                file_obj = processFileUpload(request.files[field_id], user_id)
                if file_obj is None:
                    return jsonify({
                        "success": "false",
                        "reason": "invalid file"
                    })
                elif isinstance(file_obj, Sound):
                    files['sounds'][str(field_id)] = file_obj
                else:
                    files['responses'][str(field_id)] = file_obj

            conf = json.loads(request.form['robot-config'])  # robot config
            # map of motor id to motor sound file/id
            mot_id_map = json.loads(request.form['mot_id_map'])
            # map of mic id to mic response file/id
            mic_id_map = json.loads(request.form['mic_id_map'])

            # Update the config with new id values
            conf = insertFilePaths(conf, files, mot_id_map, mic_id_map)

            if robot is not None:
                write_json(robot.pathToConfig, conf)
            else:
                unique_name = uuid.uuid4()
                filename = UPLOAD_DIR + 'robot_configs/{0}.json'.format(
                    unique_name)
                write_json(filename, conf)
                self.db.insert_robot(
                    Robot(request.form['robot_name'], filename, user_id))

            return jsonify({"success": "true"})

        @self.app.route('/robotdesign', methods=['GET'])
        def robotdesigner():
            if 'userID' not in session:
                return redirect('/login?ref=robotdesign')

            sounds = self.db.get_user_sounds(session['userID'])
            mics = self.db.get_user_mics(session['userID'])

            robot = None
            robot_conf = ""
            if 'robot' in request.args:
                candidate = self.db.get_robot(request.args['robot'])
                if owns(candidate):
                    robot = candidate
                    robot_conf = WebServer._read_file(robot.pathToConfig)

            return render_template('robotdesign.html',
                                   user=self.db.get_user(id=session['userID']),
                                   sounds=sounds,
                                   mic_responses=mics,
                                   robotconfig=robot_conf,
                                   robot=robot)

        @self.app.route('/upload_config', methods=['POST'])
        def review_config():
            if 'userID' not in session: return jsonify({"success": "false"})

            if 'file' not in request.files:
                return jsonify({"success": "false"})

            # TODO: validate the config's structure, not just that it is JSON.
            try:
                config = json.loads(request.files['file'].read())
            except ValueError:  # malformed JSON or undecodable bytes
                return jsonify({"success": "false", "reason": "invalid JSON"})
            return jsonify({"success": "true", "config": config})

        # ---- file serving -------------------------------------------------------

        @self.app.route("/dl/<path>")
        def downloadLogFile(path=None):
            if not WebServer._is_safe_name(path):
                return "Not found", 404
            filename = os.path.join(self.app.root_path, 'static', 'dl', path)
            return send_file(filename, as_attachment=True)

        @self.app.route("/uploads/sounds/<name>")
        def serveSound(name=None):
            if not WebServer._is_safe_name(name):
                return "Not found", 404
            return send_file("server/uploads/sounds/{0}".format(name),
                             as_attachment=True)

        # ---- public catalogue ---------------------------------------------------

        @self.app.route("/search")
        def search():
            relevantItems = find_public_items()
            # Searching is allowed while logged out.
            user = self.db.get_user(
                id=session['userID']) if 'userID' in session else None
            return render_template('search.html',
                                   user=user,
                                   items=relevantItems,
                                   addCount=add_counts(relevantItems))

        @self.app.route("/quicksearch")
        def quick_search():
            processedItems = [{
                'id': x.id,
                'name': x.name,
                'desc': x.description,
                'likes': x.likes,
                'type': x.type,
                'itemID': x.itemID
            } for x in find_public_items()]

            return jsonify({'result': processedItems})

        @self.app.route("/publish", methods=['POST'])
        def publish():
            from urllib.parse import urlencode

            if 'userID' not in session:
                return redirect('/login?ref=dashboard')

            getters = {
                "ROBOT": self.db.get_robot,
                "SIM": self.db.get_simulation,
                "SOUND": self.db.get_sound,
                "MIC": self.db.get_microphone,
            }
            getter = getters.get(request.form['type'])
            item = getter(request.form['id']) if getter is not None else None

            # Only the owner may publish an item.
            if not owns(item):
                return redirect('/dashboard')

            date = str(dt.now().date())
            publicItem = PublicItem(request.form['name'],
                                    request.form['desc'],
                                    request.form['type'],
                                    request.form['id'],
                                    session['userID'],
                                    publishDate=date)
            publicItem = self.db.insert_public_item(publicItem)

            # urlencode escapes '&', '#', spaces etc. in the user-chosen name.
            return redirect('/search?' + urlencode({
                'query': publicItem.name,
                'type': publicItem.type
            }))

        @self.app.route("/toggle_like", methods=["POST"])
        def toggleLike():
            if 'userID' not in session: return jsonify({"success": "false"})

            if 'item' not in request.form:
                return jsonify({"success": "false"})

            existing_like = self.db.get_one(
                "SELECT * FROM user_liked_items WHERE itemID = ? AND userID = ?",
                [request.form['item'], session['userID']],
                type=UserLikedItem)

            if existing_like is None:
                self.db.insert_user_liked_item(
                    UserLikedItem(session['userID'], request.form['item']))
            else:
                self.db.run_query('DELETE FROM user_liked_items WHERE id = ?',
                                  [existing_like.id])

            allLikes = self.db.get_all(
                'SELECT * FROM user_liked_items WHERE itemID = ?',
                [request.form['item']],
                type=UserLikedItem)

            # Keep the denormalised like counter in public_items in sync.
            self.db.run_query('UPDATE public_items SET likes = ? WHERE id = ?',
                              [len(allLikes), request.form['item']])
            return jsonify({"success": "true", "like_count": len(allLikes)})

        @self.app.route("/toggle_add", methods=["POST"])
        def toggle_add():
            if 'userID' not in session: return jsonify({"success": "false"})

            if 'item' not in request.form:
                return jsonify({"success": "false"})

            existing_add = self.db.get_one(
                "SELECT * FROM user_added_items WHERE itemID = ? AND userID = ?",
                [request.form['item'], session['userID']],
                type=UserAddedItem)

            if existing_add is None:
                self.db.insert_user_added_item(
                    UserAddedItem(session['userID'], request.form['item']))
            else:
                self.db.run_query('DELETE FROM user_added_items WHERE id = ?',
                                  [existing_add.id])

            allAdds = self.db.get_all(
                'SELECT * FROM user_added_items WHERE itemID = ?',
                [request.form['item']],
                type=UserAddedItem)

            return jsonify({"success": "true", "add_count": len(allAdds)})

        # ---- documentation and robot config API ---------------------------------

        @self.app.route('/documentation')
        def documentation():
            p = request.args.get('p', 'introduction')

            try:
                doc_root = os.path.realpath(os.path.join('server', 'doc_files'))
                doc_path = os.path.realpath(os.path.join(doc_root, p + '.md'))
                # Refuse anything resolving outside the docs dir (e.g. "../").
                if os.path.commonpath([doc_root, doc_path]) != doc_root:
                    doc_path = None
                if doc_path is not None:
                    with open(doc_path, 'r') as f:
                        doc = f.read()
            except (OSError, ValueError):  # missing file, bad bytes, NUL in p
                doc_path = None
            if doc_path is None:
                return jsonify({'success': 'false'})

            doc_html = Markup(
                markdown.markdown(doc, extensions=['fenced_code']))

            return jsonify({'success': 'true', 'html': doc_html})

        @self.app.route('/getrobotconfig')
        def getrobotconfig():
            if 'userID' not in session:
                return jsonify({
                    "success": "false",
                    "reason": "No user session"
                })

            robotID = request.args.get('robot')
            print(robotID)
            if robotID is None:
                return jsonify({
                    "success": "false",
                    "reason": "robotID not sent"
                })

            robot = self.db.get_one("SELECT * FROM robots WHERE id =?",
                                    [robotID],
                                    type=Robot)
            if not usable(robot, "ROBOT"):
                return jsonify({
                    "success": "false",
                    "reason": "robot not found"
                })

            return jsonify({
                "success": "true",
                "robot": json.loads(WebServer._read_file(robot.pathToConfig))
            })
Esempio n. 49
0
def getClient():
    """Build a fresh Flask app with the project's routes and return its test client."""
    flask_app = Flask(__name__)
    define_routes(flask_app)
    return flask_app.test_client()
def client(app: Flask):
    """Yield a test client for *app*; the client's context closes after the consumer resumes."""
    with app.test_client() as test_client:
        yield test_client
class TestEndpoints(unittest.TestCase):
    """Integration tests for endpoints protected by flask_jwt_extended.

    ``setUp`` registers login, refresh, fresh-login and protected routes on a
    fresh app. Access and refresh tokens both expire after one second, so
    expiry is exercised with a short ``time.sleep`` in ``test_bad_tokens``.
    """

    def setUp(self):
        """Create the app, JWTManager, test client and the routes under test."""
        self.app = Flask(__name__)
        self.app.secret_key = 'super=secret'
        self.app.config['JWT_ALGORITHM'] = 'HS256'
        # Short lifetimes so token expiry can be observed quickly.
        self.app.config['JWT_ACCESS_TOKEN_EXPIRES'] = timedelta(seconds=1)
        self.app.config['JWT_REFRESH_TOKEN_EXPIRES'] = timedelta(seconds=1)
        self.jwt_manager = JWTManager(self.app)
        self.client = self.app.test_client()

        @self.app.route('/auth/login', methods=['POST'])
        def login():
            ret = {
                'access_token': create_access_token('test', fresh=True),
                'refresh_token': create_refresh_token('test')
            }
            return jsonify(ret), 200

        @self.app.route('/auth/refresh', methods=['POST'])
        @jwt_refresh_token_required
        def refresh():
            username = get_jwt_identity()
            ret = {'access_token': create_access_token(username, fresh=False)}
            return jsonify(ret), 200

        @self.app.route('/auth/fresh-login', methods=['POST'])
        def fresh_login():
            ret = {'access_token': create_access_token('test', fresh=True)}
            return jsonify(ret), 200

        @self.app.route('/protected')
        @jwt_required
        def protected():
            return jsonify({'msg': "hello world"})

        @self.app.route('/fresh-protected')
        @fresh_jwt_required
        def fresh_protected():
            return jsonify({'msg': "fresh hello world"})

    @staticmethod
    def _json(response):
        """Decode a test-client response body as JSON."""
        return json.loads(response.get_data(as_text=True))

    def _login(self):
        """POST /auth/login and return its JSON body (access + refresh token)."""
        return self._json(self.client.post('/auth/login'))

    def _get(self, url, headers=None):
        """GET *url* with raw *headers*; return ``(status_code, json_body)``."""
        response = self.client.get(url, headers=headers)
        return response.status_code, self._json(response)

    def _jwt_post(self, url, jwt):
        """POST to *url* with ``Authorization: Bearer <jwt>``.

        Returns ``(status_code, json_body)``.
        """
        response = self.client.post(
            url,
            content_type='application/json',
            headers={'Authorization': 'Bearer {}'.format(jwt)})
        return response.status_code, self._json(response)

    def _jwt_get(self,
                 url,
                 jwt,
                 header_name='Authorization',
                 header_type='Bearer'):
        """GET *url* sending ``"<header_type> <jwt>"`` in header *header_name*.

        An empty *header_type* sends the bare token. Returns
        ``(status_code, json_body)``.
        """
        header_value = '{} {}'.format(header_type, jwt).strip()
        return self._get(url, headers={header_name: header_value})

    def test_login(self):
        """Login returns both an access and a refresh token."""
        response = self.client.post('/auth/login')
        data = self._json(response)
        self.assertEqual(response.status_code, 200)
        self.assertIn('access_token', data)
        self.assertIn('refresh_token', data)

    def test_fresh_login(self):
        """Fresh login returns only an access token."""
        response = self.client.post('/auth/fresh-login')
        data = self._json(response)
        self.assertEqual(response.status_code, 200)
        self.assertIn('access_token', data)
        self.assertNotIn('refresh_token', data)

    def test_refresh(self):
        """A refresh token can be exchanged for a new access token."""
        refresh_token = self._login()['refresh_token']

        status_code, data = self._jwt_post('/auth/refresh', refresh_token)
        self.assertEqual(status_code, 200)
        self.assertIn('access_token', data)
        self.assertNotIn('refresh_token', data)

    def test_wrong_token_refresh(self):
        """Refreshing with an access token is rejected with 422."""
        access_token = self._login()['access_token']

        # Try to refresh with an access token instead of a refresh one
        status_code, data = self._jwt_post('/auth/refresh', access_token)
        self.assertEqual(status_code, 422)
        self.assertIn('msg', data)

    def test_jwt_required(self):
        """@jwt_required accepts both fresh and non-fresh access tokens."""
        data = self._login()
        fresh_access_token = data['access_token']
        refresh_token = data['refresh_token']

        # Test it works with a fresh token
        status, data = self._jwt_get('/protected', fresh_access_token)
        self.assertEqual(data, {'msg': 'hello world'})
        self.assertEqual(status, 200)

        # Test it works with a non-fresh access token
        _, data = self._jwt_post('/auth/refresh', refresh_token)
        non_fresh_token = data['access_token']
        status, data = self._jwt_get('/protected', non_fresh_token)
        self.assertEqual(status, 200)
        self.assertEqual(data, {'msg': 'hello world'})

    def test_jwt_required_wrong_token(self):
        """@jwt_required rejects a refresh token with 422."""
        refresh_token = self._login()['refresh_token']

        # Shouldn't work with a refresh token
        status, _ = self._jwt_get('/protected', refresh_token)
        self.assertEqual(status, 422)

    def test_fresh_jwt_required(self):
        """@fresh_jwt_required accepts fresh tokens and rejects non-fresh (401)."""
        data = self._login()
        fresh_access_token = data['access_token']
        refresh_token = data['refresh_token']

        # Test it works with a fresh token
        status, data = self._jwt_get('/fresh-protected', fresh_access_token)
        self.assertEqual(data, {'msg': 'fresh hello world'})
        self.assertEqual(status, 200)

        # A non-fresh access token (obtained via refresh) must be refused
        _, data = self._jwt_post('/auth/refresh', refresh_token)
        non_fresh_token = data['access_token']
        status, _ = self._jwt_get('/fresh-protected', non_fresh_token)
        self.assertEqual(status, 401)

    def test_fresh_jwt_required_wrong_token(self):
        """@fresh_jwt_required rejects a refresh token with 422."""
        refresh_token = self._login()['refresh_token']

        # Shouldn't work with a refresh token
        status, _ = self._jwt_get('/fresh-protected', refresh_token)
        self.assertEqual(status, 422)

    def test_without_secret_key(self):
        """Creating a token on an app with no secret key raises RuntimeError."""
        app = Flask(__name__)
        app.testing = True  # Propagate exceptions
        JWTManager(app)
        client = app.test_client()

        @app.route('/login', methods=['POST'])
        def login():
            ret = {'access_token': create_access_token('test', fresh=True)}
            return jsonify(ret), 200

        with self.assertRaises(RuntimeError):
            client.post('/login')

    def test_bad_jwt_requests(self):
        """Missing header gives 401; malformed Authorization headers give 422."""
        access_token = self._login()['access_token']

        # Test with no authorization header
        status_code, data = self._get('/protected')
        self.assertEqual(status_code, 401)
        self.assertIn('msg', data)

        # Missing type, type not being Bearer, and too many items in header
        malformed_headers = (
            access_token,
            "BANANA {}".format(access_token),
            "Bearer {} BANANA".format(access_token),
        )
        for auth_header in malformed_headers:
            with self.subTest(auth_header=auth_header):
                status_code, data = self._get(
                    '/protected', headers={'Authorization': auth_header})
                self.assertEqual(status_code, 422)
                self.assertIn('msg', data)

    def test_bad_tokens(self):
        """Expired, bogus, foreign-key-signed and claim-less tokens are refused."""
        # Test expired access token
        access_token = self._login()['access_token']
        status_code, data = self._jwt_get('/protected', access_token)
        self.assertEqual(status_code, 200)
        self.assertIn('msg', data)
        time.sleep(2)
        status_code, data = self._jwt_get('/protected', access_token)
        self.assertEqual(status_code, 401)
        self.assertIn('msg', data)

        # Test expired refresh token
        refresh_token = self._login()['refresh_token']
        status_code, data = self._jwt_post('/auth/refresh', refresh_token)
        self.assertEqual(status_code, 200)
        self.assertIn('access_token', data)
        self.assertNotIn('msg', data)
        time.sleep(2)
        status_code, data = self._jwt_post('/auth/refresh', refresh_token)
        self.assertEqual(status_code, 401)
        self.assertNotIn('access_token', data)
        self.assertIn('msg', data)

        # Test Bogus token
        status_code, data = self._jwt_get('/protected',
                                          'this_is_totally_an_access_token')
        self.assertEqual(status_code, 422)
        self.assertIn('msg', data)

        # Test token that was signed with a different key
        with self.app.test_request_context():
            token = encode_access_token('foo',
                                        'newsecret',
                                        'HS256',
                                        timedelta(minutes=5),
                                        True, {},
                                        csrf=False)
        status_code, data = self._jwt_get('/protected', token)
        self.assertEqual(status_code, 422)
        self.assertIn('msg', data)

        # Test with valid token that is missing required claims
        now = datetime.utcnow()
        token_data = {'exp': now + timedelta(minutes=5)}
        encoded_token = jwt.encode(token_data, self.app.config['SECRET_KEY'],
                                   self.app.config['JWT_ALGORITHM'])
        # PyJWT < 2 returns bytes while PyJWT >= 2 returns str; accept both.
        if isinstance(encoded_token, bytes):
            encoded_token = encoded_token.decode('utf-8')
        status_code, data = self._jwt_get('/protected', encoded_token)
        self.assertEqual(status_code, 422)
        self.assertIn('msg', data)

    def test_jwt_identity_claims(self):
        """Identity and custom user claims are readable inside a protected view."""
        # Setup custom claims
        @self.jwt_manager.user_claims_loader
        def custom_claims(identity):
            return {'foo': 'bar'}

        @self.app.route('/claims')
        @jwt_required
        def claims():
            return jsonify({
                'username': get_jwt_identity(),
                'claims': get_jwt_claims()
            })

        access_token = self._login()['access_token']

        # /auth/login issues tokens for the identity 'test'.
        status, data = self._jwt_get('/claims', access_token)
        self.assertEqual(status, 200)
        self.assertEqual(data, {'username': 'test', 'claims': {'foo': 'bar'}})

    def test_jwt_raw_token(self):
        """get_raw_jwt() exposes the standard and library-specific claims."""
        @self.app.route('/claims')
        @jwt_required
        def claims():
            # Named raw_jwt so it does not shadow the module-level ``jwt``.
            raw_jwt = get_raw_jwt()
            return jsonify(list(raw_jwt)), 200

        access_token = self._login()['access_token']

        status, data = self._jwt_get('/claims', access_token)
        self.assertEqual(status, 200)
        for claim in ('exp', 'iat', 'nbf', 'jti', 'identity', 'fresh',
                      'type', 'user_claims'):
            self.assertIn(claim, data)

    def test_different_headers(self):
        """JWT_HEADER_TYPE / JWT_HEADER_NAME control where the token is read."""
        access_token = self._login()['access_token']

        self.app.config['JWT_HEADER_TYPE'] = 'JWT'
        status, data = self._jwt_get('/protected',
                                     access_token,
                                     header_type='JWT')
        self.assertEqual(data, {'msg': 'hello world'})
        self.assertEqual(status, 200)

        self.app.config['JWT_HEADER_TYPE'] = ''
        status, data = self._jwt_get('/protected',
                                     access_token,
                                     header_type='')
        self.assertEqual(data, {'msg': 'hello world'})
        self.assertEqual(status, 200)

        self.app.config['JWT_HEADER_TYPE'] = ''
        status, data = self._jwt_get('/protected',
                                     access_token,
                                     header_type='Bearer')
        self.assertIn('msg', data)
        self.assertEqual(status, 422)

        self.app.config['JWT_HEADER_TYPE'] = 'Bearer'
        self.app.config['JWT_HEADER_NAME'] = 'Auth'
        status, data = self._jwt_get('/protected',
                                     access_token,
                                     header_name='Auth',
                                     header_type='Bearer')
        self.assertIn('msg', data)
        self.assertEqual(status, 200)

        status, data = self._jwt_get('/protected',
                                     access_token,
                                     header_name='Authorization',
                                     header_type='Bearer')
        self.assertIn('msg', data)
        self.assertEqual(status, 401)

    def test_cookie_methods_fail_with_headers_configured(self):
        """Cookie helpers raise RuntimeWarning when only headers are enabled."""
        app = Flask(__name__)
        app.config['JWT_TOKEN_LOCATION'] = ['headers']
        app.secret_key = 'super=secret'
        app.testing = True
        JWTManager(app)
        client = app.test_client()

        @app.route('/login-bad', methods=['POST'])
        def bad_login():
            access_token = create_access_token('test')
            resp = jsonify({'login': True})
            set_access_cookies(resp, access_token)
            return resp, 200

        @app.route('/refresh-bad', methods=['POST'])
        def bad_refresh():
            refresh_token = create_refresh_token('test')
            resp = jsonify({'login': True})
            set_refresh_cookies(resp, refresh_token)
            return resp, 200

        @app.route('/logout-bad', methods=['POST'])
        def bad_logout():
            resp = jsonify({'logout': True})
            unset_jwt_cookies(resp)
            return resp, 200

        with self.assertRaises(RuntimeWarning):
            client.post('/login-bad')
        with self.assertRaises(RuntimeWarning):
            client.post('/refresh-bad')
        with self.assertRaises(RuntimeWarning):
            client.post('/logout-bad')
class TestNetworkPort(BaseTest):
    """Tests for NetworkPort blueprint"""

    @staticmethod
    def _read_mockup(relative_path):
        """Load and return a JSON mockup under oneview_redfish_toolkit/mockups/."""
        with open('oneview_redfish_toolkit/mockups/' + relative_path) as mockup:
            return json.load(mockup)

    def _get_port(self, adapter_id):
        """GET NetworkPort 1 of adapter *adapter_id* on the fixed test chassis."""
        return self.app.get(
            "/redfish/v1/Chassis/30303437-3034-4D32-3230-313133364752/"
            "NetworkAdapters/{}/NetworkPorts/1".format(adapter_id))

    def setUp(self):
        """Build an app with the NetworkPort blueprint and Redfish error pages."""
        flask_app = Flask(__name__)
        flask_app.register_blueprint(network_port.network_port)

        @flask_app.errorhandler(status.HTTP_500_INTERNAL_SERVER_ERROR)
        def internal_server_error(error):
            """Create an Internal Server Error response in Redfish format."""
            redfish_error = RedfishError(
                "InternalError",
                "The request failed due to an internal service error.  "
                "The service is still operational.")
            redfish_error.add_extended_info("InternalError")
            return Response(response=redfish_error.serialize(),
                            status=status.HTTP_500_INTERNAL_SERVER_ERROR,
                            mimetype="application/json")

        @flask_app.errorhandler(status.HTTP_404_NOT_FOUND)
        def not_found(error):
            """Create a Not Found response in Redfish format."""
            body = RedfishError("GeneralError", error.description).serialize()
            return Response(response=body,
                            status=status.HTTP_404_NOT_FOUND,
                            mimetype='application/json')

        # From here on the tests only talk to the app through its test client.
        test_client = flask_app.test_client()
        # NOTE(review): ``testing`` is set on the FlaskClient, not on the Flask
        # app, so it does not enable exception propagation as the original
        # comment intended. Kept as-is: moving it to ``flask_app.testing``
        # could change whether the 500 handler above is reached.
        test_client.testing = True
        self.app = test_client

    @mock.patch.object(network_port, 'g')
    def test_get_network_port(self, mocked_g):
        """Ethernet port 1 of adapter 3 matches the Redfish mockup."""
        server_hardware = self._read_mockup('oneview/ServerHardware.json')
        expected = self._read_mockup('redfish/NetworkPort1-Ethernet.json')
        mocked_g.oneview_client.server_hardware.get.return_value = \
            server_hardware

        response = self._get_port('3')
        result = json.loads(response.data.decode("utf-8"))

        self.assertEqual(status.HTTP_200_OK, response.status_code)
        self.assertEqual("application/json", response.mimetype)
        self.assertEqual(expected, result)

    @mock.patch.object(network_port, 'g')
    def test_get_network_port_fibre_channel(self, mocked_g):
        """Fibre Channel port 1 of adapter 2 matches the Redfish mockup."""
        server_hardware = self._read_mockup(
            'oneview/ServerHardwareFibreChannel.json')
        expected = self._read_mockup('redfish/NetworkPort1-FibreChannel.json')
        mocked_g.oneview_client.server_hardware.get.return_value = \
            server_hardware

        response = self._get_port('2')
        result = json.loads(response.data.decode("utf-8"))

        self.assertEqual(status.HTTP_200_OK, response.status_code)
        self.assertEqual("application/json", response.mimetype)
        self.assertEqual(expected, result)

    @mock.patch.object(network_port, 'g')
    def test_get_network_port_invalid_device_id(self, mocked_g):
        """A non-numeric adapter id yields a JSON 404."""
        server_hardware = self._read_mockup('oneview/ServerHardware.json')
        mocked_g.oneview_client.server_hardware.get.return_value = \
            server_hardware

        response = self._get_port('invalid_id')

        self.assertEqual(status.HTTP_404_NOT_FOUND, response.status_code)
        self.assertEqual("application/json", response.mimetype)

    @mock.patch.object(network_port, 'g')
    def test_get_network_port_sh_not_found(self, mocked_g):
        """RESOURCE_NOT_FOUND from OneView yields a JSON 404."""
        mocked_g.oneview_client.server_hardware.get.side_effect = \
            HPOneViewException({
                'errorCode': 'RESOURCE_NOT_FOUND',
                'message': 'server-hardware not found',
            })

        response = self._get_port('3')

        self.assertEqual(status.HTTP_404_NOT_FOUND, response.status_code)
        self.assertEqual("application/json", response.mimetype)

    @mock.patch.object(network_port, 'g')
    def test_get_network_port_sh_exception(self, mocked_g):
        """Any other OneView error yields a JSON 500."""
        mocked_g.oneview_client.server_hardware.get.side_effect = \
            HPOneViewException({
                'errorCode': 'ANOTHER_ERROR',
                'message': 'server-hardware error',
            })

        response = self._get_port('3')

        self.assertEqual(status.HTTP_500_INTERNAL_SERVER_ERROR,
                         response.status_code)
        self.assertEqual("application/json", response.mimetype)
Esempio n. 53
0
            complex_parameters = self.get_known_parameter_values()
            if len(complex_parameters) > 0:
                result = self.search_for_complex(complex_parameters)
            else:
                text_response = 'Введён запрос поиска без параметров или же по запросу ничего не найдено...'
                result = {
                    'speech': text_response,
                    'displayText': text_response,
                    'source': 'webhookdata'
                }
        result['contextOut'] = obj['result']['contexts']
        return result


@app.route('/', methods=['POST'])
def post():
    """Handle a webhook POST: parse the JSON body and return Webhook's reply."""
    hook = Webhook()
    payload = request.get_json(silent=True, force=True)
    return make_response(jsonify(hook.get_result(payload)))


if __name__ == '__main__':
    # Replay a captured webhook request against the app without a server.
    # The context manager closes the sample file (the original leaked it).
    with open('sample_request.json', 'r') as sample_file:
        sample_request = json.load(sample_file)
    response = app.test_client().post('/', data=json.dumps(sample_request))
    if response.status_code == 200:
        body = json.loads(response.data)
        print(body['data'])
# [END app]
Esempio n. 54
0
class BootstrapTestCase(unittest.TestCase):
    def setUp(self):
        """Create a testing app with Bootstrap and an index page using its assets.

        A request context is pushed for each test so tests can use
        ``current_app`` and call ``self.bootstrap`` helpers directly.
        """
        app = Flask(__name__)
        app.testing = True
        app.config.update(
            SQLALCHEMY_DATABASE_URI='sqlite:///',
            SQLALCHEMY_TRACK_MODIFICATIONS=False,
        )
        app.secret_key = 'for test'
        self.app = app
        self.bootstrap = Bootstrap(app)  # noqa

        @app.route('/')
        def index():
            return render_template_string(
                '{{ bootstrap.load_css() }}{{ bootstrap.load_js() }}')

        self.context = app.test_request_context()
        self.context.push()
        self.client = app.test_client()

    def tearDown(self):
        """Pop the request context pushed in setUp."""
        pushed_context = self.context
        pushed_context.pop()

    def test_extension_init(self):
        """Bootstrap(app) registers itself in app.extensions."""
        registered = current_app.extensions
        self.assertIn('bootstrap', registered)

    def test_load_css(self):
        """load_css() references the minified Bootstrap stylesheet."""
        self.assertIn('bootstrap.min.css', self.bootstrap.load_css())

    def test_load_js(self):
        """load_js() references the minified Bootstrap script."""
        self.assertIn('bootstrap.min.js', self.bootstrap.load_js())

    def test_local_resources(self):
        """With BOOTSTRAP_SERVE_LOCAL, assets are served locally, not via CDN."""
        current_app.config['BOOTSTRAP_SERVE_LOCAL'] = True
        cdn_prefix = 'https://cdn.jsdelivr.net/npm/bootstrap'

        page = self.client.get('/').get_data(as_text=True)
        self.assertNotIn(cdn_prefix, page)
        for asset in ('bootstrap.min.js', 'bootstrap.min.css', 'jquery.min.js'):
            self.assertIn(asset, page)

        # Each locally referenced static file must be routable.
        static_urls = ('/bootstrap/static/css/bootstrap.min.css',
                       '/bootstrap/static/js/bootstrap.min.js',
                       '/bootstrap/static/jquery.min.js')
        static_responses = [self.client.get(url) for url in static_urls]
        for static_response in static_responses:
            self.assertNotEqual(static_response.status_code, 404)

        css_markup = self.bootstrap.load_css()
        js_markup = self.bootstrap.load_js()
        self.assertIn('/bootstrap/static/css/bootstrap.min.css', css_markup)
        self.assertIn('/bootstrap/static/js/bootstrap.min.js', js_markup)
        self.assertNotIn(cdn_prefix, css_markup)
        self.assertNotIn(cdn_prefix, js_markup)

    def test_cdn_resources(self):
        """By default, Bootstrap assets are loaded from the jsDelivr CDN."""
        cdn_prefix = 'https://cdn.jsdelivr.net/npm/bootstrap'

        page = self.client.get('/').get_data(as_text=True)
        for expected in (cdn_prefix, 'bootstrap.min.js', 'bootstrap.min.css'):
            self.assertIn(expected, page)

        css_markup = self.bootstrap.load_css()
        js_markup = self.bootstrap.load_js()
        self.assertNotIn('/bootstrap/static/css/bootstrap.min.css', css_markup)
        self.assertNotIn('/bootstrap/static/js/bootstrap.min.js', js_markup)
        self.assertIn(cdn_prefix, css_markup)
        self.assertIn(cdn_prefix, js_markup)

    def test_render_field(self):
        """render_field emits a form-control input for each field."""
        @self.app.route('/field')
        def test():
            return render_template_string('''
            {% from 'bootstrap/form.html' import render_field %}
            {{ render_field(form.username) }}
            {{ render_field(form.password) }}
            ''',
                                          form=HelloForm())

        html = self.client.get('/field').get_data(as_text=True)
        expected_inputs = (
            '<input class="form-control" id="username" name="username"',
            '<input class="form-control" id="password" name="password"',
        )
        for expected in expected_inputs:
            self.assertIn(expected, html)

    def test_render_form(self):
        """render_form emits a form-control input for every form field."""
        @self.app.route('/form')
        def test():
            return render_template_string('''
                    {% from 'bootstrap/form.html' import render_form %}
                    {{ render_form(form) }}
                    ''',
                                          form=HelloForm())

        html = self.client.get('/form').get_data(as_text=True)
        expected_inputs = (
            '<input class="form-control" id="username" name="username"',
            '<input class="form-control" id="password" name="password"',
        )
        for expected in expected_inputs:
            self.assertIn(expected, html)

    def test_render_form_row(self):
        """render_form_row defaults to a form-row wrapper with plain col cells."""
        @self.app.route('/form')
        def test():
            return render_template_string('''
                    {% from 'bootstrap/form.html' import render_form_row %}
                    {{ render_form_row([form.username, form.password]) }}
                    ''',
                                          form=HelloForm())

        html = self.client.get('/form').get_data(as_text=True)
        for expected in ('<div class="form-row">', '<div class="col">'):
            self.assertIn(expected, html)

    def test_render_form_row_row_class(self):
        """row_class overrides the CSS class of the row wrapper."""
        @self.app.route('/form')
        def test():
            return render_template_string('''
                    {% from 'bootstrap/form.html' import render_form_row %}
                    {{ render_form_row([form.username, form.password], row_class='row') }}
                    ''',
                                          form=HelloForm())

        html = self.client.get('/form').get_data(as_text=True)
        self.assertIn('<div class="row">', html)

    def test_render_form_row_col_class_default(self):
        """col_class_default overrides the CSS class of every column."""
        @self.app.route('/form')
        def test():
            return render_template_string('''
                    {% from 'bootstrap/form.html' import render_form_row %}
                    {{ render_form_row([form.username, form.password], col_class_default='col-md-6') }}
                    ''',
                                          form=HelloForm())

        html = self.client.get('/form').get_data(as_text=True)
        self.assertIn('<div class="col-md-6">', html)

    def test_render_form_row_col_map(self):
        """col_map sets a per-field column class; other fields keep 'col'.

        The original template mapped 'username' to a redacted placeholder
        ('******'), so the 'col-md-6' assertion below could never pass.
        """
        @self.app.route('/form')
        def test():
            form = HelloForm()
            return render_template_string('''
                    {% from 'bootstrap/form.html' import render_form_row %}
                    {{ render_form_row([form.username, form.password], col_map={'username': 'col-md-6'}) }}
                    ''',
                                          form=form)

        response = self.client.get('/form')
        data = response.get_data(as_text=True)
        # 'password' is not in col_map and falls back to the default 'col'.
        self.assertIn('<div class="col">', data)
        self.assertIn('<div class="col-md-6">', data)

    def test_render_pager(self):
        """render_pager shows a disabled item on page 1 but not on page 2."""
        db = SQLAlchemy(self.app)

        class Message(db.Model):
            id = db.Column(db.Integer, primary_key=True)

        @self.app.route('/pager')
        def test():
            # Recreate the table with 100 rows on each request.
            db.drop_all()
            db.create_all()
            db.session.add_all([Message() for _ in range(100)])
            db.session.commit()
            pagination = Message.query.paginate(
                request.args.get('page', 1, type=int), per_page=10)
            return render_template_string('''
                            {% from 'bootstrap/pagination.html' import render_pager %}
                            {{ render_pager(pagination) }}
                            ''',
                                          pagination=pagination,
                                          messages=pagination.items)

        first_page = self.client.get('/pager').get_data(as_text=True)
        for expected in ('<nav aria-label="Page navigation">', 'Previous',
                         'Next', '<li class="page-item disabled">'):
            self.assertIn(expected, first_page)

        second_page = self.client.get('/pager?page=2').get_data(as_text=True)
        for expected in ('<nav aria-label="Page navigation">', 'Previous',
                         'Next'):
            self.assertIn(expected, second_page)
        self.assertNotIn('<li class="page-item disabled">', second_page)

    def test_render_pagination(self):
        """render_pagination marks the current page and links to page 10."""
        db = SQLAlchemy(self.app)

        class Message(db.Model):
            id = db.Column(db.Integer, primary_key=True)

        @self.app.route('/pagination')
        def test():
            # Recreate the table with 100 rows on each request.
            db.drop_all()
            db.create_all()
            db.session.add_all([Message() for _ in range(100)])
            db.session.commit()
            pagination = Message.query.paginate(
                request.args.get('page', 1, type=int), per_page=10)
            return render_template_string('''
                                    {% from 'bootstrap/pagination.html' import render_pagination %}
                                    {{ render_pagination(pagination) }}
                                    ''',
                                          pagination=pagination,
                                          messages=pagination.items)

        first_page = self.client.get('/pagination').get_data(as_text=True)
        self.assertIn('<nav aria-label="Page navigation">', first_page)
        self.assertIn(
            '<a class="page-link" href="#">1 <span class="sr-only">(current)</span></a>',
            first_page)
        self.assertIn('10</a>', first_page)

        second_page = self.client.get('/pagination?page=2').get_data(
            as_text=True)
        self.assertIn('<nav aria-label="Page navigation">', second_page)
        self.assertIn('1</a>', second_page)
        self.assertIn(
            '<a class="page-link" href="#">2 <span class="sr-only">(current)</span></a>',
            second_page)
        self.assertIn('10</a>', second_page)

    def test_render_nav_item(self):
        """render_nav_item flags the link to the page being served as active."""
        # The view keeps the function name ``test``: the template passes the
        # endpoint 'test' to render_nav_item, i.e. this very view.
        @self.app.route('/nav_item')
        def test():
            template = '''
                    {% from 'bootstrap/nav.html' import render_nav_item %}
                    {{ render_nav_item('test', 'Home') }}
                    '''
            return render_template_string(template)

        html = self.client.get('/nav_item').get_data(as_text=True)
        self.assertIn('<a class="nav-item nav-link active"', html)

    def test_render_breadcrumb_item(self):
        """render_breadcrumb_item marks the crumb for the current endpoint active."""
        # The template refers to the endpoint 'test', which is this view's
        # function name, so the view must not be renamed.
        @self.app.route('/breadcrumb_item')
        def test():
            source = '''
                    {% from 'bootstrap/nav.html' import render_breadcrumb_item %}
                    {{ render_breadcrumb_item('test', 'Home') }}
                    '''
            return render_template_string(source)

        page = self.client.get('/breadcrumb_item').get_data(as_text=True)
        # The double space is intentional; presumably it mirrors the macro's
        # literal whitespace.
        expected = '<li class="breadcrumb-item active"  aria-current="page">'
        self.assertIn(expected, page)

    def test_render_static(self):
        """render_static emits the right tag for css, js and icon resources."""
        @self.app.route('/test_static')
        def test():
            template = '''
                            {% from 'bootstrap/utils.html' import render_static %}
                            {{ render_static('css', 'test.css') }}
                            {{ render_static('js', 'test.js') }}
                            {{ render_static('icon', 'test.ico') }}
                            '''
            return render_template_string(template)

        page = self.client.get('/test_static').get_data(as_text=True)
        expected_tags = (
            '<link rel="stylesheet" href="/static/test.css" type="text/css">',
            '<script type="text/javascript" src="/static/test.js"></script>',
            '<link rel="icon" href="/static/test.ico">',
        )
        for tag in expected_tags:
            self.assertIn(tag, page)

    def test_render_messages(self):
        """render_messages honours the container, dismissible and animate options.

        Each view flashes one 'danger' message and renders it in the same
        request with a different combination of macro arguments.
        """
        @self.app.route('/messages')
        def test_messages():
            flash('test message', 'danger')
            return render_template_string('''
                            {% from 'bootstrap/utils.html' import render_messages %}
                            {{ render_messages() }}
                            ''')

        @self.app.route('/container')
        def test_container():
            flash('test message', 'danger')
            return render_template_string('''
                            {% from 'bootstrap/utils.html' import render_messages %}
                            {{ render_messages(container=True) }}
                            ''')

        @self.app.route('/dismissible')
        def test_dismissible():
            flash('test message', 'danger')
            return render_template_string('''
                            {% from 'bootstrap/utils.html' import render_messages %}
                            {{ render_messages(dismissible=True) }}
                            ''')

        @self.app.route('/dismiss_animate')
        def test_dismiss_animate():
            flash('test message', 'danger')
            return render_template_string('''
                            {% from 'bootstrap/utils.html' import render_messages %}
                            {{ render_messages(dismissible=True, dismiss_animate=True) }}
                            ''')

        def fetch(path):
            # Issue a GET through the shared client and return the body text.
            return self.client.get(path).get_data(as_text=True)

        self.assertIn('<div class="alert alert-danger"', fetch('/messages'))
        self.assertIn('<div class="container flashed-messages">',
                      fetch('/container'))

        close_button = '<button type="button" class="close" data-dismiss="alert"'
        # Both dismissible variants render a close button; only the animated
        # one carries the fade classes.
        for path, animated in (('/dismissible', False),
                               ('/dismiss_animate', True)):
            html = fetch(path)
            self.assertIn('alert-dismissible', html)
            self.assertIn(close_button, html)
            if animated:
                self.assertIn('fade show', html)
            else:
                self.assertNotIn('fade show', html)

    # test WTForm fields for render_form and render_field
    def test_render_form_enctype(self):
        """render_form switches to multipart encoding for file-upload fields.

        Covers both a form with a single ``FileField`` and a form with a
        ``MultipleFileField``.
        """
        class SingleUploadForm(FlaskForm):
            avatar = FileField('Avatar')

        class MultiUploadForm(FlaskForm):
            photos = MultipleFileField('Multiple photos')

        @self.app.route('/single')
        def single():
            form = SingleUploadForm()
            return render_template_string('''
            {% from 'bootstrap/form.html' import render_form %}
            {{ render_form(form) }}
            ''',
                                          form=form)

        @self.app.route('/multi')
        def multi():
            # Use the MultipleFileField form here; reusing SingleUploadForm
            # would leave the multiple-file case untested.
            form = MultiUploadForm()
            return render_template_string('''
            {% from 'bootstrap/form.html' import render_form %}
            {{ render_form(form) }}
            ''',
                                          form=form)

        response = self.client.get('/single')
        data = response.get_data(as_text=True)
        self.assertIn('multipart/form-data', data)

        response = self.client.get('/multi')
        data = response.get_data(as_text=True)
        self.assertIn('multipart/form-data', data)

    # test render_kw class for WTForms field
    def test_form_render_kw_class(self):
        """Classes from ``render_kw`` are merged with the default field classes."""
        class TestForm(FlaskForm):
            username = StringField('Username')
            password = PasswordField('Password',
                                     render_kw={'class': 'my-password-class'})
            submit = SubmitField(render_kw={'class': 'my-awesome-class'})

        @self.app.route('/render_kw')
        def render_kw():
            form = TestForm()
            template = '''
            {% from 'bootstrap/form.html' import render_form %}
            {{ render_form(form) }}
            '''
            return render_template_string(template, form=form)

        html = self.client.get('/render_kw').get_data(as_text=True)
        # A field without custom classes must not get a dangling space.
        self.assertIn('class="form-control"', html)
        self.assertNotIn('class="form-control "', html)
        for fragment in ('class="form-control my-password-class"',
                         'my-awesome-class',
                         'btn'):
            self.assertIn(fragment, html)

    # test WTForm field description for BooleanField
    def test_form_description_for_booleanfield(self):
        """A BooleanField's description is rendered as muted help text."""
        class TestForm(FlaskForm):
            remember = BooleanField('Remember me',
                                    description='Just check this')

        @self.app.route('/description')
        def description():
            form = TestForm()
            template = '''
            {% from 'bootstrap/form.html' import render_form %}
            {{ render_form(form) }}
            '''
            return render_template_string(template, form=form)

        html = self.client.get('/description').get_data(as_text=True)
        help_text = '<small class="form-text text-muted">Just check this</small>'
        self.assertIn('Remember me', html)
        self.assertIn(help_text, html)

    def test_button_size(self):
        """BOOTSTRAP_BTN_SIZE sets the default button size; button_size= overrides it.

        Checks the 'md' default, switches the config to 'lg', and verifies
        that an explicit ``button_size='sm'`` wins over the config value.
        """
        self.assertEqual(current_app.config['BOOTSTRAP_BTN_SIZE'], 'md')
        current_app.config['BOOTSTRAP_BTN_SIZE'] = 'lg'

        # Register every route before the first request: newer Flask releases
        # refuse to add routes once the app has handled a request.
        @self.app.route('/form')
        def test():
            form = HelloForm()
            return render_template_string('''
            {% from 'bootstrap/form.html' import render_form %}
            {{ render_form(form) }}
            ''',
                                          form=form)

        @self.app.route('/form2')
        def test_overwrite():
            form = HelloForm()
            return render_template_string('''
            {% from 'bootstrap/form.html' import render_form %}
            {{ render_form(form, button_size='sm') }}
            ''',
                                          form=form)

        response = self.client.get('/form')
        data = response.get_data(as_text=True)
        self.assertIn('btn-lg', data)

        response = self.client.get('/form2')
        data = response.get_data(as_text=True)
        self.assertNotIn('btn-lg', data)
        self.assertIn('btn-sm', data)

    def test_button_style(self):
        """BOOTSTRAP_BTN_STYLE sets the default style; per-call arguments override it.

        Checks the 'secondary' default, switches the config to 'primary',
        then verifies that ``button_style=`` and ``button_map=`` both take
        precedence over the configured style.
        """
        self.assertEqual(current_app.config['BOOTSTRAP_BTN_STYLE'],
                         'secondary')
        current_app.config['BOOTSTRAP_BTN_STYLE'] = 'primary'

        # Register every route before the first request: newer Flask releases
        # refuse to add routes once the app has handled a request.
        @self.app.route('/form')
        def test():
            form = HelloForm()
            return render_template_string('''
            {% from 'bootstrap/form.html' import render_form %}
            {{ render_form(form) }}
            ''',
                                          form=form)

        @self.app.route('/form2')
        def test_overwrite():
            form = HelloForm()
            return render_template_string('''
            {% from 'bootstrap/form.html' import render_form %}
            {{ render_form(form, button_style='success') }}
            ''',
                                          form=form)

        @self.app.route('/form3')
        def test_button_map():
            form = HelloForm()
            return render_template_string('''
            {% from 'bootstrap/form.html' import render_form %}
            {{ render_form(form, button_map={'submit': 'warning'}) }}
            ''',
                                          form=form)

        response = self.client.get('/form')
        data = response.get_data(as_text=True)
        self.assertIn('btn-primary', data)

        response = self.client.get('/form2')
        data = response.get_data(as_text=True)
        self.assertNotIn('btn-primary', data)
        self.assertIn('btn-success', data)

        response = self.client.get('/form3')
        data = response.get_data(as_text=True)
        self.assertNotIn('btn-primary', data)
        self.assertIn('btn-warning', data)

    def test_error_message_for_radiofield_and_booleanfield(self):
        """Validation errors are shown for required BooleanField/RadioField inputs."""
        class TestForm(FlaskForm):
            remember = BooleanField('Remember me', validators=[DataRequired()])
            option = RadioField(choices=[('dog', 'Dog'), ('cat', 'Cat'),
                                         ('bird', 'Bird'), ('alien', 'Alien')],
                                validators=[DataRequired()])

        @self.app.route('/error', methods=['GET', 'POST'])
        def error():
            form = TestForm()
            # Validation runs only for its side effect of populating the
            # field errors; its boolean result is not needed here.
            form.validate_on_submit()
            template = '''
            {% from 'bootstrap/form.html' import render_form %}
            {{ render_form(form) }}
            '''
            return render_template_string(template, form=form)

        # An empty POST leaves both required fields unset.
        reply = self.client.post('/error', follow_redirects=True)
        self.assertIn('This field is required', reply.get_data(as_text=True))

    def test_render_simple_table(self):
        """render_table renders explicit column titles and one row per record.

        Ten messages are seeded and shown on a single page of ten.
        """
        db = SQLAlchemy(self.app)

        class Message(db.Model):
            id = db.Column(db.Integer, primary_key=True)
            text = db.Column(db.Text)

        @self.app.route('/table')
        def test():
            # Rebuild and reseed the table on every request for a known state.
            db.drop_all()
            db.create_all()
            db.session.add_all(
                [Message(text='Test message {}'.format(i)) for i in range(1, 11)])
            db.session.commit()
            page = request.args.get('page', 1, type=int)
            # Keyword arguments: paginate() parameters are keyword-only in
            # Flask-SQLAlchemy 3.x and accepted by keyword in 2.x.
            pagination = Message.query.paginate(page=page, per_page=10)
            messages = pagination.items
            titles = [('id', '#'), ('text', 'Message')]
            return render_template_string('''
                                    {% from 'bootstrap/table.html' import render_table %}
                                    {{ render_table(messages, titles) }}
                                    ''',
                                          titles=titles,
                                          messages=messages)

        response = self.client.get('/table')
        data = response.get_data(as_text=True)
        self.assertIn('<table class="table">', data)
        self.assertIn('<th scope="col">#</th>', data)
        self.assertIn('<th scope="col">Message</th>', data)
        self.assertIn('<th scope="row">1</th>', data)
        self.assertIn('<td>Test message 1</td>', data)
        # All ten seeded rows fit on the first page (per_page=10).
        self.assertIn('<td>Test message 10</td>', data)

    def test_render_customized_table(self):
        """render_table applies table/header classes and a caption when given."""
        db = SQLAlchemy(self.app)

        class Message(db.Model):
            id = db.Column(db.Integer, primary_key=True)
            text = db.Column(db.Text)

        @self.app.route('/table')
        def test():
            # Rebuild and reseed the table on every request for a known state.
            db.drop_all()
            db.create_all()
            db.session.add_all(
                [Message(text='Test message {}'.format(i)) for i in range(1, 11)])
            db.session.commit()
            page = request.args.get('page', 1, type=int)
            # Keyword arguments: paginate() parameters are keyword-only in
            # Flask-SQLAlchemy 3.x and accepted by keyword in 2.x.
            pagination = Message.query.paginate(page=page, per_page=10)
            messages = pagination.items
            titles = [('id', '#'), ('text', 'Message')]
            return render_template_string('''
                                    {% from 'bootstrap/table.html' import render_table %}
                                    {{ render_table(messages, titles, table_classes='table-striped',
                                    header_classes='thead-dark', caption='Messages') }}
                                    ''',
                                          titles=titles,
                                          messages=messages)

        response = self.client.get('/table')
        data = response.get_data(as_text=True)
        self.assertIn('<table class="table table-striped">', data)
        self.assertIn('<thead class="thead-dark">', data)
        self.assertIn('<caption>Messages</caption>', data)

    def test_render_responsive_table(self):
        """render_table wraps the table in the requested responsive container."""
        db = SQLAlchemy(self.app)

        class Message(db.Model):
            id = db.Column(db.Integer, primary_key=True)
            text = db.Column(db.Text)

        @self.app.route('/table')
        def test():
            # Rebuild and reseed the table on every request for a known state.
            db.drop_all()
            db.create_all()
            db.session.add_all(
                [Message(text='Test message {}'.format(i)) for i in range(1, 11)])
            db.session.commit()
            page = request.args.get('page', 1, type=int)
            # Keyword arguments: paginate() parameters are keyword-only in
            # Flask-SQLAlchemy 3.x and accepted by keyword in 2.x.
            pagination = Message.query.paginate(page=page, per_page=10)
            messages = pagination.items
            titles = [('id', '#'), ('text', 'Message')]
            return render_template_string('''
                                    {% from 'bootstrap/table.html' import render_table %}
                                    {{ render_table(messages, titles, responsive=True,
                                    responsive_class='table-responsive-sm') }}
                                    ''',
                                          titles=titles,
                                          messages=messages)

        response = self.client.get('/table')
        data = response.get_data(as_text=True)
        self.assertIn('<div class="table-responsive-sm">', data)

    def test_build_table_titles(self):
        """Without explicit titles, render_table derives them from the model columns.

        The ``id`` column is expected as '#' and ``text`` as 'Text'.
        """
        db = SQLAlchemy(self.app)

        class Message(db.Model):
            id = db.Column(db.Integer, primary_key=True)
            text = db.Column(db.Text)

        @self.app.route('/table')
        def test():
            # Rebuild and reseed the table on every request for a known state.
            db.drop_all()
            db.create_all()
            db.session.add_all(
                [Message(text='Test message {}'.format(i)) for i in range(1, 11)])
            db.session.commit()
            page = request.args.get('page', 1, type=int)
            # Keyword arguments: paginate() parameters are keyword-only in
            # Flask-SQLAlchemy 3.x and accepted by keyword in 2.x.
            pagination = Message.query.paginate(page=page, per_page=10)
            messages = pagination.items
            return render_template_string('''
                                    {% from 'bootstrap/table.html' import render_table %}
                                    {{ render_table(messages) }}
                                    ''',
                                          messages=messages)

        response = self.client.get('/table')
        data = response.get_data(as_text=True)
        self.assertIn('<table class="table">', data)
        self.assertIn('<th scope="col">#</th>', data)
        self.assertIn('<th scope="col">Text</th>', data)
        self.assertIn('<th scope="row">1</th>', data)
        self.assertIn('<td>Test message 1</td>', data)
        # All ten seeded rows fit on the first page (per_page=10).
        self.assertIn('<td>Test message 10</td>', data)

    def test_build_table_titles_with_empty_data(self):
        """render_table without titles still renders (HTTP 200) for zero rows."""
        @self.app.route('/table')
        def test():
            template = '''
                                    {% from 'bootstrap/table.html' import render_table %}
                                    {{ render_table(messages) }}
                                    '''
            return render_template_string(template, messages=[])

        self.assertEqual(self.client.get('/table').status_code, 200)
Esempio n. 55
0
class UnicodeCookieUserIDTestCase(unittest.TestCase):
    """Remember-me login must round-trip non-ASCII user names and user IDs.

    Uses the ``germanjapanese`` user, the ``USERS`` mapping and the
    ``unicode`` name, all defined outside this class.
    """

    def setUp(self):
        self.app = Flask(__name__)
        self.remember_cookie_name = 'remember'
        self.app.config.update(
            SECRET_KEY='deterministic',
            SESSION_PROTECTION=None,
            REMEMBER_COOKIE_NAME=self.remember_cookie_name,
        )
        self.login_manager = LoginManager()
        self.login_manager.init_app(self.app)
        self.login_manager._login_disabled = False

        app = self.app

        @app.route('/')
        def index():
            return u'Welcome!'

        @app.route('/login-germanjapanese-remember')
        def login_germanjapanese_remember():
            return unicode(login_user(germanjapanese, remember=True))

        @app.route('/username')
        def username():
            return current_user.name if current_user.is_authenticated else u'Anonymous'

        @app.route('/userid')
        def user_id():
            return current_user.id if current_user.is_authenticated else u'wrong_id'

        @self.login_manager.user_loader
        def load_user(user_id):
            return USERS[unicode(user_id)]

        # Re-raise every 404 so a mistyped URL in a test fails loudly instead
        # of quietly producing a 404 response that nobody checks.
        @app.errorhandler(404)
        def handle_404(e):
            raise e

        super(UnicodeCookieUserIDTestCase, self).setUp()

    def _delete_session(self, c):
        """Clear the client's session, as if the browser had been closed.

        The session is emptied regardless of its permanent flag.
        """
        with c.session_transaction() as sess:
            sess.clear()

    def _get_after_remembered_login(self, path):
        """Log in with remember-me, drop the session, then GET *path* as text."""
        with self.app.test_client() as c:
            c.get('/login-germanjapanese-remember')
            self._delete_session(c)
            return c.get(path).data.decode('utf-8')

    def test_remember_me_username(self):
        self.assertEqual(u'Müller', self._get_after_remembered_login('/username'))

    def test_remember_me_user_id(self):
        self.assertEqual(u'佐藤', self._get_after_remembered_login('/userid'))
Esempio n. 56
0
class UserAPITestCase(unittest.TestCase):
    """Exercise the login, logout, restricted and user-deletion views.

    The Flask app and its routes are built once per test instance in
    ``__init__``; a fixture user is created in ``setUp`` and removed in
    ``tearDown`` via ``DatabaseTestCase`` helpers defined elsewhere.
    """

    def __init__(self, *args, **kwargs):
        self.app = Flask(__name__)
        self.app.secret_key = 'test_key'
        # Add an index rule because some views redirect to 'index'.
        self.app.add_url_rule(
            '/',
            view_func=TestView.as_view('index', template='index.html'),
            methods=['GET', 'POST'])
        self.app.add_url_rule(
            '/restricted/',
            view_func=RestrictedView.as_view('restricted', template='index.html'),
            methods=['GET'])
        self.app.add_url_rule('/logout/', view_func=LoginAPI.logout)
        self.app.add_url_rule(
            '/login/',
            view_func=LoginAPI.as_view('login'),
            methods=['GET', 'POST'])
        self.app.add_url_rule(
            '/delete/<user_id>/',
            view_func=UserAPI.as_view('user_api'),
            methods=['DELETE'])
        self.test_client = self.app.test_client()
        self.user = None
        super(UserAPITestCase, self).__init__(*args, **kwargs)

    def setUp(self):
        # Create the fixture user every test works with.
        self.user = DatabaseTestCase.add_user()

    def tearDown(self):
        # Remove the fixture user again (presumably undoing add_user — confirm).
        self.user = DatabaseTestCase.delete_user()

    def login(self, password):
        """POST the login form for the fixture user and follow redirects."""
        return self.test_client.post('/login/', data=dict(
            username='******',
            password=password),
            follow_redirects=True)

    def test_delete_user(self):
        '''test setting a user to inactive through UserAPI.
        ensures that deletion only works when user is logged in.'''
        with self.app.test_client() as c:
            # Anonymous request: the user must remain active.
            c.delete('/delete/{0}/'.format(self.user.id))
            self.assertEqual(User.objects(username='******')[0].active, True)
            # Same request with a logged-in session deactivates the user.
            with c.session_transaction() as sess:
                sess['username'] = '******'
            c.delete('/delete/{0}/'.format(self.user.id))
            self.assertEqual(User.objects(username='******')[0].active, False)

    def test_login(self):
        # Response bodies are bytes, so compare against bytes literals; a str
        # needle raises TypeError on Python 3.
        # Good password works.
        self.assertIn(b'OK', self.login('test_password').data)
        # Bad password doesn't.
        self.assertIn(b'error', self.login('bad_password').data)
        self.user.active = False
        self.user.save()
        # Inactive users can't log in.
        self.assertIn(b'error', self.login('test_password').data)

    def test_user_logout(self):
        rv = self.test_client.get('/logout/', follow_redirects=True)
        self.assertIn(b'OK', rv.data)

    def test_restricted_view(self):
        rv = self.test_client.get('/restricted/', follow_redirects=True)
        self.assertIn(b'not logged in', rv.data)
Esempio n. 57
0
class CompressionAlgoTests(unittest.TestCase):
    """
    Test different scenarios for compression algorithm negotiation between
    client and server. Please note that algorithm names (even the "supported"
    ones) in these tests **do not** indicate that all of these are actually
    supported by this extension.
    """

    def setUp(self):
        super(CompressionAlgoTests, self).setUp()

        # `Compress()` is deliberately not attached here: each test tweaks the
        # configuration first, and attaching twice would register two
        # `@after_request` handlers.
        self.app = Flask(__name__)
        self.app.testing = True

        small_path = os.path.join(os.getcwd(), 'tests', 'templates', 'small.html')
        self.small_size = os.path.getsize(small_path) - 1

        @self.app.route('/small/')
        def small():
            return render_template('small.html')

    def _enabled(self, setting):
        """Attach Compress with COMPRESS_ALGORITHM=*setting*; return its enabled list."""
        self.app.config['COMPRESS_ALGORITHM'] = setting
        return Compress(self.app).enabled_algorithms

    def _negotiate(self, algorithms, accept_encoding):
        """Attach Compress with *algorithms* and pick one for *accept_encoding*."""
        self.app.config['COMPRESS_ALGORITHM'] = algorithms
        return Compress(self.app)._choose_compress_algorithm(accept_encoding)

    def test_setting_compress_algorithm_simple_string(self):
        """A single-string `COMPRESS_ALGORITHM` still works (backwards compatibility)."""
        self.assertListEqual(self._enabled('gzip'), ['gzip'])

    def test_setting_compress_algorithm_cs_string(self):
        """`COMPRESS_ALGORITHM` can be a comma-separated string."""
        self.assertListEqual(self._enabled('gzip, br, zstd'), ['gzip', 'br', 'zstd'])

    def test_setting_compress_algorithm_list(self):
        """`COMPRESS_ALGORITHM` can be a list of strings."""
        self.assertListEqual(self._enabled(['gzip', 'br', 'deflate']),
                             ['gzip', 'br', 'deflate'])

    def test_one_algo_supported(self):
        """Requesting a single supported algorithm selects it."""
        self.assertEqual(self._negotiate(['br', 'gzip'], 'gzip'), 'gzip')

    def test_one_algo_unsupported(self):
        """Requesting a single unsupported algorithm selects nothing."""
        self.assertIsNone(self._negotiate(['br', 'gzip'], 'some-alien-algorithm'))

    def test_multiple_algos_supported(self):
        """Requesting several supported algorithms."""
        chosen = self._negotiate(['zstd', 'br', 'gzip'], 'br, gzip, zstd')
        # On a tie the first server-configured algorithm wins.
        self.assertEqual(chosen, 'zstd')

    def test_multiple_algos_unsupported(self):
        """Requesting several unsupported algorithms selects nothing."""
        chosen = self._negotiate(['zstd', 'br', 'gzip'],
                                 'future-algo, alien-algo, forbidden-algo')
        self.assertIsNone(chosen)

    def test_multiple_algos_with_wildcard(self):
        """Unsupported algorithms plus a wildcard."""
        chosen = self._negotiate(['zstd', 'br', 'gzip'],
                                 'future-algo, alien-algo, forbidden-algo, *')
        # The wildcard resolves to the first server-configured algorithm.
        self.assertEqual(chosen, 'zstd')

    def test_multiple_algos_with_different_quality(self):
        """Supported algorithms with different q-factors: highest q wins."""
        chosen = self._negotiate(['zstd', 'br', 'gzip'],
                                 'zstd;q=0.8, br;q=0.9, gzip;q=0.5')
        self.assertEqual(chosen, 'br')

    def test_multiple_algos_with_equal_quality(self):
        """Supported algorithms with equal q-factors."""
        chosen = self._negotiate(['gzip', 'br', 'zstd'],
                                 'zstd;q=0.5, br;q=0.5, gzip;q=0.5')
        # On a tie the first server-configured algorithm wins.
        self.assertEqual(chosen, 'gzip')

    def test_default_quality_is_1(self):
        """In mixed-quality requests an entry without q defaults to 1.0."""
        chosen = self._negotiate(['gzip', 'br', 'deflate'],
                                 'deflate, br;q=0.999, gzip;q=0.5')
        self.assertEqual(chosen, 'deflate')

    def test_default_wildcard_quality_is_0(self):
        """A wildcard without q has a default q-factor of 0.0."""
        chosen = self._negotiate(['gzip', 'br', 'deflate'], 'br;q=0.001, *')
        self.assertEqual(chosen, 'br')

    def test_wildcard_quality(self):
        """A wildcard with q=0 is discarded."""
        self.assertIsNone(self._negotiate(['gzip', 'br', 'deflate'], '*;q=0'))

    def test_identity(self):
        """`identity` is understood."""
        chosen = self._negotiate(['gzip', 'br', 'deflate'],
                                 'identity;q=1, br;q=0.5, *;q=0')
        self.assertIsNone(chosen)

    def test_chrome_ranged_requests(self):
        """Chrome ranged requests behave as expected."""
        chosen = self._negotiate(['gzip', 'br', 'deflate'], 'identity;q=1, *;q=0')
        self.assertIsNone(chosen)

    def test_content_encoding_is_correct(self):
        """The `Content-Encoding` header matches the negotiated algorithm."""
        self.app.config['COMPRESS_ALGORITHM'] = ['br', 'gzip', 'deflate']
        Compress(self.app)

        for algorithm in ('gzip', 'br', 'deflate'):
            client = self.app.test_client()
            response = client.options('/small/',
                                      headers=[('Accept-Encoding', algorithm)])
            self.assertIn('Content-Encoding', response.headers)
            self.assertEqual(response.headers.get('Content-Encoding'), algorithm)
Esempio n. 58
0
class UrlTests(unittest.TestCase):
    """End-to-end tests of Compress against small and large template pages."""

    def setUp(self):
        self.app = Flask(__name__)
        self.app.testing = True

        template_dir = os.path.join(os.getcwd(), 'tests', 'templates')
        # Expected uncompressed body sizes; the "- 1" presumably offsets a
        # trailing newline dropped when the template is rendered — confirm.
        self.small_size = os.path.getsize(os.path.join(template_dir, 'small.html')) - 1
        self.large_size = os.path.getsize(os.path.join(template_dir, 'large.html')) - 1

        Compress(self.app)

        @self.app.route('/small/')
        def small():
            return render_template('small.html')

        @self.app.route('/large/')
        def large():
            return render_template('large.html')

    def client_get(self, ufs):
        """GET *ufs* advertising gzip support; assert a 200 and return the response."""
        response = self.app.test_client().get(ufs, headers=[('Accept-Encoding', 'gzip')])
        self.assertEqual(response.status_code, 200)
        return response

    def _large_body_size(self, encoding, level_key, level):
        """Set config *level_key* to *level*; return the /large/ body size for *encoding*."""
        self.app.config[level_key] = level
        client = self.app.test_client()
        response = client.get('/large/', headers=[('Accept-Encoding', encoding)])
        return len(response.data)

    def test_br_algorithm(self):
        client = self.app.test_client()
        headers = [('Accept-Encoding', 'br')]
        for path in ('/small/', '/large/'):
            response = client.options(path, headers=headers)
            self.assertEqual(response.status_code, 200)

    def test_compress_min_size(self):
        """COMPRESS_MIN_SIZE decides whether the response body is compressed."""
        # The small page stays at its original size ...
        self.assertEqual(self.small_size, len(self.client_get('/small/').data))
        # ... while the large page is compressed and so differs in size.
        self.assertNotEqual(self.large_size, len(self.client_get('/large/').data))

    def test_mimetype_mismatch(self):
        """A mimetype outside COMPRESS_MIMETYPES is served with its own type."""
        self.assertEqual(self.client_get('/static/1.png').mimetype, 'image/png')

    def test_content_length_options(self):
        client = self.app.test_client()
        response = client.options('/small/', headers=[('Accept-Encoding', 'gzip')])
        self.assertEqual(response.status_code, 200)

    def test_gzip_compression_level(self):
        """COMPRESS_LEVEL changes the gzip output."""
        level1_size = self._large_body_size('gzip', 'COMPRESS_LEVEL', 1)
        level6_size = self._large_body_size('gzip', 'COMPRESS_LEVEL', 6)
        self.assertNotEqual(level1_size, level6_size)

    def test_br_compression_level(self):
        """COMPRESS_BR_LEVEL changes the brotli output."""
        level4_size = self._large_body_size('br', 'COMPRESS_BR_LEVEL', 4)
        level11_size = self._large_body_size('br', 'COMPRESS_BR_LEVEL', 11)
        self.assertNotEqual(level4_size, level11_size)

    def test_deflate_compression_level(self):
        """COMPRESS_DEFLATE_LEVEL changes the deflate output."""
        default_size = self._large_body_size('deflate', 'COMPRESS_DEFLATE_LEVEL', -1)
        level1_size = self._large_body_size('deflate', 'COMPRESS_DEFLATE_LEVEL', 1)
        self.assertNotEqual(default_size, level1_size)
Esempio n. 59
0
class FlstatsTestCase(unittest.TestCase):
    """Tests for the statistics reported by the flstats blueprint."""

    def setUp(self):
        """Creates a Flask test app and registers two routes
        as well as the flstats blueprint.
        """

        self.app = Flask(__name__)
        self.app.register_blueprint(webstatistics)
        self.client = self.app.test_client()

        # NOTE(review): these views return a bare int; whether that is an
        # acceptable Flask response depends on what @statistics does with the
        # return value -- confirm.
        @self.app.route('/url1')
        @statistics
        def url1():
            return random.randint(0, 1000)

        @self.app.route('/url2')
        @statistics
        def url2():
            return random.randint(0, 1000)

    def _fetch_stats(self):
        """Return the 'stats' list served by /flstats/, asserting a 200 OK.

        Sleeps briefly first so that the statistics for the requests just made
        have been processed before they are read.
        """
        # We make sure data processing is complete
        sleep(0.1)

        response = self.client.get('/flstats/')
        self.assertEqual(response.status, '200 OK')

        data = json.loads(response.data)
        return data['stats']

    def _assert_bounds_ordered(self, stat):
        """Assert min <= avg <= max for one stats entry."""
        self.assertLessEqual(stat['min'], stat['avg'])
        self.assertLessEqual(stat['avg'], stat['max'])

    def test_url1(self):
        """Test with one URL only"""

        self.client.get('/url1')

        stats = self._fetch_stats()
        self.assertEqual(len(stats), 1)

        stat = stats.pop()
        self.assertEqual(stat['url'], 'http://localhost/url1')
        self.assertEqual(stat['throughput'], 1)
        # With a single sample all three aggregates must coincide.
        self.assertEqual(stat['min'], stat['avg'])
        self.assertEqual(stat['avg'], stat['max'])

        for _ in range(9):
            self.client.get('/url1')

        stats = self._fetch_stats()
        self.assertEqual(len(stats), 1)

        stat = stats.pop()
        self.assertEqual(stat['url'], 'http://localhost/url1')
        # NOTE(review): 9 (not 10) after ten requests in total -- presumably
        # throughput only counts requests since the previous read; confirm.
        self.assertEqual(stat['throughput'], 9)
        self._assert_bounds_ordered(stat)

    def test_url2(self):
        """Test with two URLs"""

        for _ in range(20):
            self.client.get('/url2')

        # NOTE(review): two entries are expected although only /url2 is hit
        # here; this presumably relies on flstats keeping state from
        # test_url1 across app instances (i.e. on test ordering) -- confirm.
        stats = self._fetch_stats()
        self.assertEqual(len(stats), 2)

        for stat in stats:
            if stat['url'] == 'http://localhost/url1':
                self.assertEqual(stat['throughput'], 0)
                self._assert_bounds_ordered(stat)
            elif stat['url'] == 'http://localhost/url2':
                self.assertEqual(stat['throughput'], 20)
                self._assert_bounds_ordered(stat)
            else:
                self.fail('Invalid URL, WTF?!')
class TestViews(TestCase):
    """Tests for the scheduler job endpoints under /scheduler/jobs."""

    def setUp(self):
        self.app = Flask(__name__)
        self.app.config['SCHEDULER_VIEWS_ENABLED'] = True
        self.scheduler = APScheduler(app=self.app)
        self.scheduler.start()
        # Shut the scheduler down after every test so that the scheduler
        # started here does not outlive the test.
        self.addCleanup(self.scheduler.shutdown)
        self.client = self.app.test_client()

    def _assert_fields_equal(self, expected, actual, keys):
        """Assert that the job dicts *expected* and *actual* agree on *keys*."""
        for key in keys:
            self.assertEqual(expected.get(key), actual.get(key), msg=key)

    def test_add_job(self):
        # NOTE(review): this run_date is now in the past, so the scheduler may
        # treat the job as missed right after it is added. Only the API's echo
        # is checked below; consider a far-future date.
        job = {
            'id': 'job1',
            'func': 'test_views:job1',
            'trigger': 'date',
            'run_date': '2020-12-01T12:30:01+00:00',
        }

        response = self.client.post('/scheduler/jobs', data=json.dumps(job))
        self.assertEqual(response.status_code, 200)

        job2 = json.loads(response.get_data(as_text=True))

        self._assert_fields_equal(job, job2, ('id', 'func', 'trigger', 'run_date'))

    def test_delete_job(self):
        self.__add_job()

        response = self.client.delete('/scheduler/jobs/job1')
        self.assertEqual(response.status_code, 204)

        response = self.client.get('/scheduler/jobs/job1')
        self.assertEqual(response.status_code, 404)

    def test_get_job(self):
        job = self.__add_job()

        response = self.client.get('/scheduler/jobs/job1')
        self.assertEqual(response.status_code, 200)

        job2 = json.loads(response.get_data(as_text=True))

        self._assert_fields_equal(job, job2, ('id', 'func', 'trigger', 'minutes'))

    def test_get_all_jobs(self):
        job = self.__add_job()

        response = self.client.get('/scheduler/jobs')
        self.assertEqual(response.status_code, 200)

        jobs = json.loads(response.get_data(as_text=True))

        self.assertEqual(len(jobs), 1)

        job2 = jobs[0]

        self._assert_fields_equal(job, job2, ('id', 'func', 'trigger', 'minutes'))

    def test_update_job(self):
        job = self.__add_job()

        data_to_update = {
            'args': [1]
        }

        response = self.client.patch('/scheduler/jobs/job1', data=json.dumps(data_to_update))
        self.assertEqual(response.status_code, 200)

        job2 = json.loads(response.get_data(as_text=True))

        # The patched field must reflect the update; the rest must be unchanged.
        self.assertEqual(data_to_update.get('args'), job2.get('args'))
        self._assert_fields_equal(job, job2, ('id', 'func', 'trigger', 'minutes'))

    def test_pause_and_resume_job(self):
        self.__add_job()

        response = self.client.post('/scheduler/jobs/job1/pause')
        self.assertEqual(response.status_code, 200)
        job = json.loads(response.get_data(as_text=True))
        self.assertIsNone(job.get('next_run_time'))

        response = self.client.post('/scheduler/jobs/job1/resume')
        self.assertEqual(response.status_code, 200)
        job = json.loads(response.get_data(as_text=True))
        self.assertIsNotNone(job.get('next_run_time'))

    def __add_job(self):
        """Create interval job 'job1' via the API and return the echoed JSON."""
        job = {
            'id': 'job1',
            'func': 'test_views:job1',
            'trigger': 'interval',
            'minutes': 10,
        }

        response = self.client.post('/scheduler/jobs', data=json.dumps(job))
        # Fail here if the add was rejected, instead of letting callers
        # compare their expectations against an error payload.
        self.assertEqual(response.status_code, 200)
        return json.loads(response.get_data(as_text=True))