class TestCase(unittest.TestCase):
    '''A helper mixin for common operations.'''

    def setUp(self):
        '''Initialize a Flask application.'''
        self.app = Flask(__name__)

    @contextmanager
    def context(self):
        '''Run the enclosed block inside a test request context for ``/``.'''
        with self.app.test_request_context('/'):
            yield

    def _get_json(self, url, status=200):
        '''GET ``url``, check status and JSON content type, return the decoded body.

        ``assertEqual`` is used because the ``assertEquals`` alias was removed
        from :mod:`unittest` in Python 3.12.
        '''
        with self.app.test_client() as client:
            response = client.get(url)
            self.assertEqual(response.status_code, status)
            self.assertEqual(response.content_type, 'application/json')
            return json.loads(response.data.decode('utf8'))

    def get_specs(self, prefix='/api', app=None):
        '''Get a Swagger specification for a RestPlus API.

        :param prefix: URL prefix of the API; ``{prefix}/specs.json`` is fetched.
        :param app: unused; kept for backward compatibility.
        :returns: the decoded JSON document (status must be 200).
        '''
        return self._get_json('{0}/specs.json'.format(prefix))

    def get_declaration(self, namespace='default', prefix='/api', status=200, app=None):
        '''Get an API declaration for a given namespace.

        :param namespace: namespace name; ``{prefix}/{namespace}.json`` is fetched.
        :param prefix: URL prefix of the API.
        :param status: expected HTTP status code.
        :param app: unused; kept for backward compatibility.
        :returns: the decoded JSON document.
        '''
        return self._get_json('{0}/{1}.json'.format(prefix, namespace), status=status)
class DefaultsTestCase(FlaskCorsTestCase):
    '''Behavior of ``cross_origin()`` with all-default arguments.'''

    def setUp(self):
        self.app = Flask(__name__)

        @self.app.route('/', methods=['GET', 'OPTIONS'])
        @cross_origin()
        def wildcard():
            return 'Welcome!'

    def test_wildcard_defaults_no_origin(self):
        '''
        With the default configuration, the wildcard Access-Control-Allow-Origin
        header is sent even when the request has no Origin header, because
        ``always_send`` defaults to ``True``. A strictly w3-conformant server
        would omit the header in this case.
        '''
        with self.app.test_client() as c:
            for verb in self.iter_verbs(c):
                result = verb('/')
                self.assertEqual(result.headers.get(AccessControlAllowOrigin), '*')

    def test_wildcard_defaults_origin(self):
        '''
        If there is an Origin header in the request, the default configuration
        still answers with the wildcard ``*`` rather than echoing the origin.
        '''
        example_origin = 'http://example.com'
        with self.app.test_client() as c:
            for verb in self.iter_verbs(c):
                result = verb('/', headers={'Origin': example_origin})
                self.assertEqual(result.headers.get(AccessControlAllowOrigin), '*')
class W3TestCase(FlaskCorsTestCase):
    '''``cross_origin`` configured to echo the Origin header, w3-spec style.'''

    def setUp(self):
        self.app = Flask(__name__)

        @self.app.route('/', methods=['GET', 'OPTIONS'])
        @cross_origin(origins='*', send_wildcard=False, always_send=False)
        def allowOrigins():
            '''
            This sets up flask-cors to echo the request's `Origin` header,
            only if it is actually set. This behavior is most similar to
            the actual W3 specification, http://www.w3.org/TR/cors/ but
            is not the default because it is more common to use the
            wildcard approach.
            '''
            return 'Welcome!'

    def test_wildcard_origin_header(self):
        '''
        If there is an Origin header in the request, the
        Access-Control-Allow-Origin header should be echoed back.
        '''
        example_origin = 'http://example.com'
        with self.app.test_client() as c:
            for verb in self.iter_verbs(c):
                result = verb('/', headers={'Origin': example_origin})
                self.assertEqual(result.headers.get(AccessControlAllowOrigin), example_origin)

    def test_wildcard_no_origin_header(self):
        '''
        If there is no Origin header in the request, the
        Access-Control-Allow-Origin header should not be included.
        '''
        with self.app.test_client() as c:
            for verb in self.iter_verbs(c):
                result = verb('/')
                # assertNotIn reports the offending headers on failure.
                self.assertNotIn(AccessControlAllowOrigin, result.headers)
class SupportsCredentialsCase(FlaskCorsTestCase):
    '''The ``supports_credentials`` option of ``cross_origin``.'''

    def setUp(self):
        self.app = Flask(__name__)

        @self.app.route('/test_credentials')
        @cross_origin(supports_credentials=True)
        def test_credentials():
            return 'Credentials!'

        @self.app.route('/test_open')
        @cross_origin()
        def test_open():
            return 'Open!'

    def test_credentialed_request(self):
        '''
        The specified route should return the
        Access-Control-Allow-Credentials header with value ``true``.
        '''
        with self.app.test_client() as c:
            result = c.get('/test_credentials')
            header = result.headers.get(ACL_CREDENTIALS)
            # assertEqual: the assertEquals alias was removed in Python 3.12.
            self.assertEqual(header, 'true')

    def test_open_request(self):
        '''
        The default behavior should be to disallow credentials, i.e. the
        Access-Control-Allow-Credentials header is absent.
        '''
        with self.app.test_client() as c:
            result = c.get('/test_open')
            self.assertNotIn(ACL_CREDENTIALS, result.headers)
class TestCase(unittest.TestCase):
    '''A helper mixin for common operations.'''

    def setUp(self):
        '''Initialize a Flask application.'''
        self.app = Flask(__name__)

    @contextmanager
    def context(self, **kwargs):
        '''Run the enclosed block inside a test request context for ``/``.

        Keyword arguments are forwarded to ``test_request_context``.
        '''
        with self.app.test_request_context('/', **kwargs):
            yield

    @contextmanager
    def settings(self, **settings):
        '''
        A context manager to alter app settings during a test and restore
        them afterwards.

        Restoration runs in a ``finally`` clause, so a failing test does not
        leak its settings into later tests. Keys that did not exist before are
        removed again instead of being left behind with a ``None`` value.
        '''
        missing = object()  # sentinel: distinguishes "absent" from "set to None"
        original = {}
        for key, value in settings.items():
            original[key] = self.app.config.get(key, missing)
            self.app.config[key] = value
        try:
            yield
        finally:
            for key, value in original.items():
                if value is missing:
                    self.app.config.pop(key, None)
                else:
                    self.app.config[key] = value

    @contextmanager
    def assert_warning(self, category=Warning):
        '''Assert that the enclosed block emits at least one warning and
        that the first recorded warning is exactly of ``category``.'''
        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter('always')  # Cause all warnings to always be triggered.
            yield
            self.assertGreaterEqual(len(w), 1, 'It should raise a warning')
            warning = w[0]
            self.assertEqual(warning.category, category, 'It should raise {0}'.format(category.__name__))

    def get(self, url, **kwargs):
        '''Issue a GET request through a fresh test client.'''
        with self.app.test_client() as client:
            return client.get(url, **kwargs)

    def post(self, url, **kwargs):
        '''Issue a POST request through a fresh test client.'''
        with self.app.test_client() as client:
            return client.post(url, **kwargs)

    def get_json(self, url, status=200, **kwargs):
        '''GET ``url``, check status and JSON content type, return the decoded body.'''
        response = self.get(url, **kwargs)
        self.assertEqual(response.status_code, status)
        self.assertEqual(response.content_type, 'application/json')
        return json.loads(response.data.decode('utf8'))

    def get_specs(self, prefix='', status=200, **kwargs):
        '''Get a Swagger specification for a RestPlus API'''
        return self.get_json('{0}/swagger.json'.format(prefix), status=status, **kwargs)

    def assertDataEqual(self, tested, expected):
        '''Compare data without caring about order and type (dict vs. OrderedDict)'''
        assert_data_equal(tested, expected)
class OriginsW3TestCase(FlaskCorsTestCase):
    '''``cross_origin`` echoing the Origin header instead of sending ``*``.'''

    def setUp(self):
        self.app = Flask(__name__)

        @self.app.route('/')
        @cross_origin(origins='*', send_wildcard=False, always_send=False)
        def allowOrigins():
            ''' This sets up flask-cors to echo the request's `Origin` header,
                only if it is actually set. This behavior is most similar to
                the actual W3 specification, http://www.w3.org/TR/cors/ but
                is not the default because it is more common to use the
                wildcard configuration in order to support CDN caching.
            '''
            return 'Welcome!'

        @self.app.route('/default-origins')
        @cross_origin(send_wildcard=False, always_send=False)
        def noWildcard():
            ''' With the default origins configuration, send_wildcard should
                still be respected.
            '''
            return 'Welcome!'

    def test_wildcard_origin_header(self):
        ''' If there is an Origin header in the request, the
            Access-Control-Allow-Origin header should be echoed back.
        '''
        example_origin = 'http://example.com'
        with self.app.test_client() as c:
            for verb in self.iter_verbs(c):
                result = verb('/', headers={'Origin': example_origin})
                self.assertEqual(result.headers.get(ACL_ORIGIN), example_origin)

    def test_wildcard_no_origin_header(self):
        ''' If there is no Origin header in the request, the
            Access-Control-Allow-Origin header should not be included.
        '''
        with self.app.test_client() as c:
            for verb in self.iter_verbs(c):
                result = verb('/')
                self.assertNotIn(ACL_ORIGIN, result.headers)

    def test_wildcard_default_origins(self):
        ''' With the default ``origins`` setting and ``send_wildcard=False``,
            the request's Origin header is echoed back in
            Access-Control-Allow-Origin rather than replaced by ``*``.
        '''
        example_origin = 'http://example.com'
        with self.app.test_client() as c:
            for verb in self.iter_verbs(c):
                result = verb('/default-origins', headers={'Origin': example_origin})
                self.assertEqual(result.headers.get(ACL_ORIGIN), example_origin)
def test_bugsnag_custom_data(self, deliver):
    """Per-request metadata must not bleed from one request into the next."""
    # list.pop() takes from the end, so request 1 sees {"again": ...}
    # and request 2 sees {"hello": ...}.
    meta_data = [{"hello": {"world": "once"}}, {"again": {"hello": "world"}}]
    app = Flask("bugsnag")

    @app.route("/hello")
    def hello():
        bugsnag.configure_request(meta_data=meta_data.pop())
        raise SentinelError("oops")

    handle_exceptions(app)
    for _ in range(2):
        app.test_client().get('/hello')

    self.assertEqual(deliver.call_count, 2)

    def delivered_metadata(n):
        return deliver.call_args_list[n][0][0]['events'][0]['metaData']

    first = delivered_metadata(0)
    self.assertEqual(first.get('hello'), None)
    self.assertEqual(first['again']['hello'], 'world')

    second = delivered_metadata(1)
    self.assertEqual(second['hello']['world'], 'once')
    self.assertEqual(second.get('again'), None)
def testGetTpls(self):
    '''Exercise get_tpls for an unknown name, empty/all, and several names.

    The ``c.get`` calls are made inside ``with app.test_client()`` so the
    request context stays active while ``statserv.server`` is queried.
    '''
    app = Flask(__name__)

    # Unknown template: get_tpls should answer with the "does not exist" error.
    with app.test_client() as c:
        c.get('/tpls.json?which=asdfjdskfjs')
        self.assertEqual(statserv.server.send_error(request, 'template does not exist'),
                         statserv.server.get_tpls(),
                         "The get_tpls method really shouldn't try to send "
                         "back a template for 'asdfjdskfjs.'")

    # `which=` (empty) and `which=all` must produce the same result.
    with app.test_client() as c:
        c.get('/tpls.json?which=')
        tplsempty = statserv.server.get_tpls()
    with app.test_client() as c:
        c.get('/tpls.json?which=all')
        self.assertEqual(statserv.server.get_tpls(), tplsempty,
                         'The get_tpls method should send back all '
                         'templates on both which=all and which=.')

    # Two `which` values plus a `callback` parameter.
    with app.test_client() as c:
        c.get('/tpls.json?callback=blah&which=header&which=home')
        tpl_dir = statserv.server.determine_path() + '/tpls/'
        # Read via `with` so the file handles are closed, not leaked.
        with open(tpl_dir + 'header.tpl') as tpl_file:
            header = tpl_file.read()
        with open(tpl_dir + 'home.tpl') as tpl_file:
            home = tpl_file.read()
        response = statserv.server.make_response('blah', {'header': header, 'home': home})
        self.assertEqual(statserv.server.get_tpls(), response,
                         'The single-template support does not seem to be '
                         'working properly.')
class KazooTestCase(unittest.TestCase):
    '''Write and read a ZooKeeper node through two Flask routes using the Kazoo extension.'''

    def setUp(self):
        self.app = Flask(__name__)
        self.app.config['TESTING'] = True
        self.kazoo = Kazoo(self.app)
        self.node_prefix = '/kazoo-test/'
        self.node_name = 'test-kazoo-node'
        self.node_value = uuid.uuid1().bytes
        # Create the test node (makepath=True also creates missing parents).
        self.app.extensions['kazoo']['client'].create(
            self.node_prefix + self.node_name, self.node_value, makepath=True)

        @self.app.route('/write/<path>')
        def write(path):
            return str(kazoo_client.set(self.node_prefix + path, b'OK'))

        @self.app.route('/read/<path>')
        def read(path):
            return kazoo_client.get(self.node_prefix + path)[0]

    def tearDown(self):
        try:
            self.app.extensions['kazoo']['client'].delete(self.node_prefix + self.node_name)
        except NoNodeError:
            # Best effort: a missing node is reported, not treated as a failure.
            # print() call form: the old print statement is a SyntaxError on Python 3.
            print("Teardown got NoNodeError")

    def test_kazoo_set(self):
        with self.app.test_client() as c:
            results = c.get('/write/%s' % self.node_name)
            self.assertEqual(results.status_code, 200)

    def test_kazoo_read(self):
        with self.app.test_client() as c:
            results = c.get('/read/%s' % self.node_name)
            self.assertEqual(results.data, self.node_value)
class LogicalAPINamingTestCase(unittest.TestCase):
    '''Several named ``Api`` blueprints may share one version prefix.'''

    def setUp(self):
        self.app = Flask(__name__)
        self.tasks = [{"id": 1, "task": "Do the laundry"}, {"id": 2, "task": "Do the dishes"}]

        def get_tasks(request):
            return self.tasks

        def post_task(request):
            data = request.json
            self.tasks.append({"task": data["task"]})
            return {}, 201

        self.get_tasks = get_tasks
        self.post_task = post_task

    def _assert_get_tasks(self, client):
        '''GET /v1/task/ and check it returns the two seeded tasks as JSON.'''
        resp = client.get("/v1/task/", content_type="application/json")
        self.assertEqual(resp.status_code, 200)
        self.assertEqual(resp.headers["Content-Type"], "application/json")
        # get_data(as_text=True) decodes the body; Response.charset was
        # deprecated in Werkzeug 2.3 and removed in 3.0.
        data = json.loads(resp.get_data(as_text=True))
        self.assertEqual(len(data), 2)

    def test_naming_a_single_api(self):
        api_1 = Api(version="v1", name="read-only-methods")
        api_1.register_endpoint(ApiEndpoint(http_method="GET", endpoint="/task/", handler=self.get_tasks))
        self.app.register_blueprint(api_1)
        self.app.config["TESTING"] = True
        client = self.app.test_client()
        self._assert_get_tasks(client)

    def test_naming_multiple_logic_apis(self):
        api_1 = Api(version="v1", name="read-only-methods")
        api_1.register_endpoint(ApiEndpoint(http_method="GET", endpoint="/task/", handler=self.get_tasks))
        self.app.register_blueprint(api_1)
        api_2 = Api(version="v1", name="write-methods")
        api_2.register_endpoint(ApiEndpoint(http_method="POST", endpoint="/task/", handler=self.post_task))
        self.app.register_blueprint(api_2)
        self.app.config["TESTING"] = True
        client = self.app.test_client()

        # Testing GET (served by the "read-only-methods" API)
        self._assert_get_tasks(client)

        # Testing POST (served by the "write-methods" API)
        resp = client.post("/v1/task/", content_type="application/json",
                           data=json.dumps({"task": "New Task!"}))
        self.assertEqual(resp.status_code, 201)
        self.assertEqual(resp.headers["Content-Type"], "application/json")
        self.assertEqual(len(self.tasks), 3)
def test_modal_edit():
    """The edit modal markup appears only when ``edit_modal`` is enabled,
    for both the bootstrap2 and bootstrap3 template modes."""

    class EditModalOn(fileadmin.FileAdmin):
        edit_modal = True
        editable_extensions = ('txt',)

    class EditModalOff(fileadmin.FileAdmin):
        edit_modal = False
        editable_extensions = ('txt',)

    files_dir = op.join(op.dirname(__file__), 'files')
    modal_on_view = EditModalOn(files_dir, '/files/', endpoint='edit_modal_on')
    modal_off_view = EditModalOff(files_dir, '/files/', endpoint='edit_modal_off')

    # The same two view instances are added to one Admin per template mode.
    for mode in ("bootstrap2", "bootstrap3"):
        flask_app = Flask(__name__)
        admin = Admin(flask_app, template_mode=mode)
        admin.add_view(modal_on_view)
        admin.add_view(modal_off_view)
        test_client = flask_app.test_client()

        page = test_client.get('/admin/edit_modal_on/')
        eq_(page.status_code, 200)
        ok_('fa_modal_window' in page.data.decode('utf-8'))

        page = test_client.get('/admin/edit_modal_off/')
        eq_(page.status_code, 200)
        ok_('fa_modal_window' not in page.data.decode('utf-8'))
class ExposeHeadersTestCase(FlaskCorsTestCase):
    '''Serialization of the ``expose_headers`` option of ``cross_origin``.'''

    def setUp(self):
        self.app = Flask(__name__)

        @self.app.route('/test_default')
        @cross_origin()
        def test_default():
            return 'Welcome!'

        @self.app.route('/test_list')
        @cross_origin(expose_headers=["Foo", "Bar"])
        def test_list():
            return 'Welcome!'

        @self.app.route('/test_string')
        @cross_origin(expose_headers="Foo")
        def test_string():
            return 'Welcome!'

        @self.app.route('/test_set')
        @cross_origin(expose_headers=set(["Foo", "Bar"]))
        def test_set():
            return 'Welcome!'

    def test_default(self):
        '''By default no Access-Control-Expose-Headers header is sent.'''
        with self.app.test_client() as c:
            resp = c.get('/test_default')
            self.assertIsNone(resp.headers.get(ACL_EXPOSE_HEADERS),
                              "Default should have no exposed headers")

    def test_list_serialized(self):
        '''
        A list passed as ``expose_headers`` is sent as a sorted,
        comma-separated Access-Control-Expose-Headers value.
        '''
        with self.app.test_client() as c:
            resp = c.get('/test_list')
            self.assertEqual(resp.headers.get(ACL_EXPOSE_HEADERS), 'Bar, Foo')

    def test_string_serialized(self):
        '''
        A single string passed as ``expose_headers`` is sent unchanged as the
        Access-Control-Expose-Headers value.
        '''
        with self.app.test_client() as c:
            resp = c.get('/test_string')
            self.assertEqual(resp.headers.get(ACL_EXPOSE_HEADERS), 'Foo')

    def test_set_serialized(self):
        '''
        A set passed as ``expose_headers`` is sent as a sorted,
        comma-separated Access-Control-Expose-Headers value.
        '''
        with self.app.test_client() as c:
            resp = c.get('/test_set')
            self.assertEqual(resp.headers.get(ACL_EXPOSE_HEADERS), 'Bar, Foo')
class TestCase(unittest.TestCase):
    """A helper mixin for common operations."""

    def setUp(self):
        """Initialize a Flask application."""
        self.app = Flask(__name__)

    @contextmanager
    def context(self):
        """Run the enclosed block inside a test request context for ``/``."""
        with self.app.test_request_context("/"):
            yield

    @contextmanager
    def settings(self, **settings):
        """
        A context manager to alter app settings during a test and restore
        them afterwards.

        Restoration runs in a ``finally`` clause, so a failing test does not
        leak its settings. Keys that did not exist before are removed again
        instead of being left behind with a ``None`` value.
        """
        missing = object()  # sentinel: distinguishes "absent" from "set to None"
        original = {}
        for key, value in settings.items():
            original[key] = self.app.config.get(key, missing)
            self.app.config[key] = value
        try:
            yield
        finally:
            for key, value in original.items():
                if value is missing:
                    self.app.config.pop(key, None)
                else:
                    self.app.config[key] = value

    def get_specs(self, prefix="/api", app=None, status=200):
        """Get a Swagger specification for a RestPlus API.

        :param prefix: URL prefix of the API; ``{prefix}/swagger.json`` is fetched.
        :param app: unused; kept for backward compatibility.
        :param status: expected HTTP status code.
        """
        return self.get_json("{0}/swagger.json".format(prefix), status=status)

    def get_declaration(self, namespace="default", prefix="/api", status=200, app=None):
        """Get an API declaration for a given namespace.

        The URL is always ``{prefix}/swagger.json``: ``namespace`` does not
        affect which document is fetched and is kept only for backward
        compatibility, as is ``app``.
        """
        return self.get_json("{0}/swagger.json".format(prefix), status=status)

    def get_json(self, url, status=200):
        """GET ``url``, check status and JSON content type, return the decoded body.

        ``assertEqual`` is used because the ``assertEquals`` alias was removed
        from :mod:`unittest` in Python 3.12.
        """
        with self.app.test_client() as client:
            response = client.get(url)
            self.assertEqual(response.status_code, status)
            self.assertEqual(response.content_type, "application/json")
            return json.loads(response.data.decode("utf8"))
def test_multiple_apps(self):
    """One Limiter shared by two apps keeps independent limits per app."""
    app1 = Flask(__name__)
    app2 = Flask(__name__)
    limiter = Limiter(global_limits=["1/second"])
    for flask_app in (app1, app2):
        limiter.init_app(flask_app)

    @app1.route("/ping")
    def ping():
        return "PONG"

    @app1.route("/slowping")
    @limiter.limit("1/minute")
    def slow_ping():
        return "PONG"

    @app2.route("/ping")
    @limiter.limit("2/second")
    def ping_2():
        return "PONG"

    @app2.route("/slowping")
    @limiter.limit("2/minute")
    def slow_ping_2():
        return "PONG"

    def expect(client, path, *codes):
        """Issue one GET to ``path`` per expected status code, in order."""
        for code in codes:
            self.assertEqual(client.get(path).status_code, code)

    with hiro.Timeline().freeze() as clock:
        with app1.test_client() as client:
            expect(client, "/ping", 200, 429)
            clock.forward(1)
            expect(client, "/ping", 200)
            expect(client, "/slowping", 200)
            clock.forward(59)
            expect(client, "/slowping", 429)
            clock.forward(1)
            expect(client, "/slowping", 200)
        with app2.test_client() as client:
            expect(client, "/ping", 200, 200, 429)
            clock.forward(1)
            expect(client, "/ping", 200)
            expect(client, "/slowping", 200)
            clock.forward(59)
            expect(client, "/slowping", 200, 429)
            clock.forward(1)
            expect(client, "/slowping", 200)
def test_bugsnag_notify(deliver):
    """A handled ``bugsnag.notify`` call is delivered with the request URL."""
    app = Flask("bugsnag")

    @app.route("/hello")
    def hello():
        bugsnag.notify(SentinalError("oops"))
        return "OK"

    handle_exceptions(app)
    client = app.test_client()
    client.get('/hello')

    eq_(deliver.call_count, 1)
    request_meta = deliver.call_args[0][0]['events'][0]['metaData']['request']
    eq_(request_meta['url'], 'http://localhost/hello')
def test_bugsnag_crash(deliver):
    """An unhandled exception in a view is delivered with its class and URL."""
    app = Flask("bugsnag")

    @app.route("/hello")
    def hello():
        raise SentinalError("oops")

    handle_exceptions(app)
    client = app.test_client()
    client.get('/hello')

    eq_(deliver.call_count, 1)
    event = deliver.call_args[0][0]['events'][0]
    eq_(event['exceptions'][0]['errorClass'], 'test_flask.SentinalError')
    eq_(event['metaData']['request']['url'], 'http://localhost/hello')
def test_bugsnag_crash(deliver):
    """An unhandled exception in a view is delivered with its class and URL."""
    # NOTE(review): another test named test_bugsnag_crash appears earlier in
    # this chunk; if both live in one module only this later one is collected.
    app = Flask("bugsnag")

    @app.route("/hello")
    def hello():
        raise SentinalError("oops")

    handle_exceptions(app)
    client = app.test_client()
    client.get("/hello")

    eq_(deliver.call_count, 1)
    event = deliver.call_args[0][0]["events"][0]
    eq_(event["exceptions"][0]["errorClass"], "test_flask.SentinalError")
    eq_(event["metaData"]["request"]["url"], "http://localhost/hello")
def test_bugsnag_notify(self):
    """A handled ``bugsnag.notify`` call reaches the fake server with the URL."""
    app = Flask("bugsnag")

    @app.route("/hello")
    def hello():
        bugsnag.notify(SentinelError("oops"))
        return "OK"

    handle_exceptions(app)
    app.test_client().get('/hello')

    received = self.server.received
    self.assertEqual(1, len(received))
    request_meta = received[0]['json_body']['events'][0]['metaData']['request']
    self.assertEqual(request_meta['url'], 'http://localhost/hello')
def test_bugsnag_includes_unknown_content_type_posted_data(deliver):
    """A body with an octet-stream content type is still reported."""
    app = Flask("bugsnag")

    @app.route("/form", methods=["PUT"])
    def hello():
        raise SentinalError("oops")

    handle_exceptions(app)
    client = app.test_client()
    client.put("/form", data="_data", content_type="application/octet-stream")

    eq_(deliver.call_count, 1)
    event = deliver.call_args[0][0]["events"][0]
    request_meta = event["metaData"]["request"]
    eq_(event["exceptions"][0]["errorClass"], "test_flask.SentinalError")
    eq_(request_meta["url"], "http://localhost/form")
    ok_("_data" in request_meta["data"]["body"])
def test_bugsnag_includes_posted_json_data(deliver):
    """A JSON request body is reported as decoded data."""
    app = Flask("bugsnag")

    @app.route("/ajax", methods=["POST"])
    def hello():
        raise SentinalError("oops")

    handle_exceptions(app)
    client = app.test_client()
    client.post("/ajax", data='{"key": "value"}', content_type="application/json")

    eq_(deliver.call_count, 1)
    event = deliver.call_args[0][0]["events"][0]
    request_meta = event["metaData"]["request"]
    eq_(event["exceptions"][0]["errorClass"], "test_flask.SentinalError")
    eq_(request_meta["url"], "http://localhost/ajax")
    eq_(request_meta["data"], dict(key="value"))
def test_menu_render():
    """Two top-level menu entries are registered, so render() yields two items."""
    registry = decorators.menu
    registry.clear()  # start from an empty registry

    flask_app = Flask(__name__)
    flask_app.config["TESTING"] = True

    @registry("Hello World", group_name="admin")
    class Hello(object):
        @registry("Index")
        def index(self):
            pass

        @registry("Page 2")
        def index2(self):
            pass

    @registry("Monster")
    class Monster(object):
        @registry("Home")
        def maggi(self):
            pass

    with flask_app.test_client() as client:
        # The request keeps a request context active for render();
        # presumably render() needs one — confirm.
        client.get("/")
        rendered = registry.render()
        assert len(rendered) == 2
def test_flask(self):
    """elastic_query sorts Flask-SQLAlchemy models by a JSON sort spec."""
    app = Flask(__name__)
    # Configure before SQLAlchemy(app): the extension reads the URI at init
    # time in recent Flask-SQLAlchemy releases, so setting it afterwards is ignored.
    app.config['SQLALCHEMY_DATABASE_URI'] = 'sqlite:///:memory:'
    app.config['TESTING'] = True
    db = SQLAlchemy(app)

    class Cities(db.Model):
        __tablename__ = 'users'
        id = Column(Integer, primary_key=True)
        name = Column(String)
        population = Column(Integer)

        def __init__(self, name, population):
            self.name = name
            self.population = population

    # Database work needs an application context in newer Flask-SQLAlchemy
    # versions; pushing one is harmless in older ones.
    with app.app_context():
        db.create_all()
        db.session.add(Cities("Cordoba", 1000000))
        db.session.add(Cities("Rafaela", 99000))
        db.session.commit()

        query_string = '{ "sort": { "population" : "desc" } }'
        results = elastic_query(Cities, query_string)
        assert results[0].name == 'Cordoba'
def main():
    """Build the demo app and print the response to a fixed request sequence."""
    app = Flask(__name__)
    app.config.update(
        DB_CONNECTION_STRING=':memory:',
        CACHE_TYPE='simple',
        SQLALCHEMY_DATABASE_URI='sqlite://',
    )
    app.debug = True
    injector = init_app(app=app, modules=[AppModule])
    configure_views(app=app, cached=injector.get(Cache).cached)
    post_init_app(app, injector)

    client = app.test_client()
    # (client method, path, extra keyword arguments), issued in this order.
    script = [
        ('get', '/', {}),
        ('post', '/', {'data': {'key': 'foo', 'value': 'bar'}}),
        ('get', '/', {}),
        ('get', '/hello', {}),
        ('delete', '/hello', {}),
        ('get', '/', {}),
        ('get', '/hello', {}),
        ('delete', '/hello', {}),
    ]
    for method_name, path, extra in script:
        response = getattr(client, method_name)(path, **extra)
        print('%s\n%s%s' % (response.status, response.headers, response.data))
def setUp(self):
    """Create an app with one CSS and one JS bundle rendered at ``/``."""
    app = Flask(__name__)
    Funnel(app)
    app.config.update(
        CSS_BUNDLES={
            'css-bundle': (
                'css/test.css',
                'less/test.less',
                'scss/test.scss',
                'stylus/test.styl',
            ),
        },
        JS_BUNDLES={
            'js-bundle': (
                'js/test1.js',
                'js/test2.js',
                'coffee/test.coffee',
            ),
        },
    )

    @app.route('/')
    def index():
        template = "{{ css('css-bundle') }} {{ js('js-bundle') }}"
        return render_template_string(template)

    self.app = app
    self.client = app.test_client()
def test_custom_headers_from_config(self):
    """Rate-limit header names can be overridden through app config."""
    app = Flask(__name__)
    custom_names = (
        (C.HEADER_LIMIT, "X-Limit"),
        (C.HEADER_REMAINING, "X-Remaining"),
        (C.HEADER_RESET, "X-Reset"),
    )
    for config_key, header_name in custom_names:
        app.config.setdefault(config_key, header_name)
    limiter = Limiter(app, global_limits=["10/minute"], headers_enabled=True)

    @app.route("/t1")
    @limiter.limit("2/second; 10 per minute; 20/hour")
    def t():
        return "test"

    with hiro.Timeline().freeze() as clock:
        with app.test_client() as client:
            # 11 requests, one second apart: the last one exceeds 10/minute.
            for _ in range(11):
                last = client.get("/t1")
                clock.forward(1)
            self.assertEqual(last.headers.get('X-Limit'), '10')
            self.assertEqual(last.headers.get('X-Remaining'), '0')
            self.assertEqual(last.headers.get('X-Reset'), str(int(time.time() + 49)))
def test_whitelisting(self):
    """Requests matched by a request_filter bypass the global limit."""
    app = Flask(__name__)
    limiter = Limiter(app, global_limits=["1/minute"], headers_enabled=True)

    @app.route("/")
    def t():
        return "test"

    @limiter.request_filter
    def w():
        # Exempt any request carrying the header "internal: true".
        return request.headers.get("internal", None) == "true"

    with hiro.Timeline().freeze() as clock:
        with app.test_client() as client:
            self.assertEqual(client.get("/").status_code, 200)
            self.assertEqual(client.get("/").status_code, 429)
            clock.forward(60)
            self.assertEqual(client.get("/").status_code, 200)
            internal = {"internal": "true"}
            for _ in range(10):
                self.assertEqual(client.get("/", headers=internal).status_code, 200)
def setUp(self):
    """Build a two-step engine, expose it as a view and keep a test client.

    ``op1`` reads ``in`` and writes ``middle`` (optional step); ``op2`` reads
    ``middle`` and writes ``out``. The client is stored on ``self.app``.
    """
    self.engine = Engine("op1", "op2")
    self.engine.op1.setup(in_name="in", out_name="middle", required=False)
    self.engine.op2.setup(in_name="middle", out_name="out")
    self.engine.op1.set(OptProductEx())
    # "foisdouze" (French: "times twelve"): an OptProductEx whose "factor"
    # option is forced to 12.
    foisdouze = OptProductEx("foisdouze")
    foisdouze.force_option_value("factor", 12)
    self.engine.op2.set(foisdouze, OptProductEx())

    egn_view = EngineView(self.engine, name="my_egn")
    egn_view.add_input("in", Numeric(vtype=int, min=-5, max=5))
    egn_view.add_input("middle", Numeric(vtype=int))
    for output_name in ("in", "middle", "out"):
        egn_view.add_output(output_name)

    api = ReliureAPI()
    api.register_view(egn_view)

    app = Flask(__name__)
    app.config['TESTING'] = True
    app.register_blueprint(api, url_prefix="/api")
    self.app = app.test_client()
def test_headers_breach(self):
    """Once the 10/minute limit is exceeded the X-RateLimit-* headers report it."""
    app = Flask(__name__)
    limiter = Limiter(app, global_limits=["10/minute"], headers_enabled=True)

    @app.route("/t1")
    @limiter.limit("2/second; 10 per minute; 20/hour")
    def t():
        return "test"

    with hiro.Timeline().freeze() as clock:
        with app.test_client() as client:
            # 11 requests, one second apart: the last one exceeds 10/minute.
            for _ in range(11):
                last = client.get("/t1")
                clock.forward(1)
            headers = last.headers
            self.assertEqual(headers.get('X-RateLimit-Limit'), '10')
            self.assertEqual(headers.get('X-RateLimit-Remaining'), '0')
            self.assertEqual(headers.get('X-RateLimit-Reset'), str(int(time.time() + 49)))
class MethodViewLoginTestCase(unittest.TestCase):
    """Login decorators applied to a MethodView through ``decorators``."""

    def setUp(self):
        self.app = Flask(__name__)
        self.login_manager = LoginManager()
        self.login_manager.init_app(self.app)
        self.login_manager._login_disabled = False

        class SecretEndpoint(MethodView):
            decorators = [login_required, fresh_login_required]

            def options(self):
                return u''

            def get(self):
                return u''

        secret_view = SecretEndpoint.as_view('secret')
        self.app.add_url_rule('/secret', view_func=secret_view)

    def test_options_call_exempt(self):
        """An unauthenticated OPTIONS request still gets a 200 response."""
        with self.app.test_client() as client:
            response = client.open('/secret', method='OPTIONS')
            self.assertEqual(response.status_code, 200)
def setUp(self):
    """Create an app protected by digest auth in the realm "My Realm"."""
    app = Flask(__name__)
    app.config['SECRET_KEY'] = 'my secret'
    digest_auth_my_realm = HTTPDigestAuth(realm='My Realm')
    known_passwords = {'susan': 'hello', 'john': 'bye'}

    @digest_auth_my_realm.get_password
    def get_digest_password_3(username):
        # Unknown users yield None.
        return known_passwords.get(username)

    @app.route('/')
    def index():
        return 'index'

    @app.route('/digest-with-realm')
    @digest_auth_my_realm.login_required
    def digest_auth_my_realm_route():
        return 'digest_auth_my_realm:' + digest_auth_my_realm.username()

    self.app = app
    self.client = app.test_client()
def __init__(self):
    """Register ``api`` on a new app and store that app's test client as ``self.app``."""
    flask_app = Flask(__name__)
    flask_app.register_blueprint(api)
    self.app = flask_app.test_client()
def client(app: Flask) -> FlaskClient:
    """Return a fresh test client for *app*.

    No ``with`` block is used: returning from inside one exits it
    immediately, so it would not keep any request context alive for the
    caller anyway.
    """
    return app.test_client()
def setUp(self):
    """Runs before every test case.

    Registers the health-check view on a new app in testing mode and stores
    that app's test client as ``self.app``. The first parameter is named
    ``self`` because ``setUp`` is an instance method, not a classmethod.
    """
    app = Flask(__name__)
    healthcheck.HealthView.register(app)
    app.config['TESTING'] = True
    self.app = app.test_client()
def on_success(service: Flask):
    """Return a fresh test client for the *service* app.

    No ``with`` block is used: returning from inside one exits it
    immediately, so it would not keep any request context alive for the
    caller anyway.
    """
    return service.test_client()
class TestEndpointsWithHeadersAndCookies(unittest.TestCase):
    '''A protected endpoint accepts a JWT from the Authorization header or from cookies.'''

    def setUp(self):
        self.app = Flask(__name__)
        self.app.secret_key = 'super=secret'
        self.app.config.update({
            'JWT_TOKEN_LOCATION': ['cookies', 'headers'],
            'JWT_COOKIE_CSRF_PROTECT': True,
            'JWT_ACCESS_COOKIE_PATH': '/api/',
            'JWT_REFRESH_COOKIE_PATH': '/auth/refresh',
        })
        self.jwt_manager = JWTManager(self.app)
        self.client = self.app.test_client()

        @self.app.route('/auth/login_cookies', methods=['POST'])
        def login_cookies():
            # Create the tokens we will be sending back to the user
            access_token = create_access_token(identity='test')
            refresh_token = create_refresh_token(identity='test')
            # Set the JWTs and the CSRF double submit protection cookies in this response
            resp = jsonify({'login': True})
            set_access_cookies(resp, access_token)
            set_refresh_cookies(resp, refresh_token)
            return resp, 200

        @self.app.route('/auth/login_headers', methods=['POST'])
        def login_headers():
            tokens = {
                'access_token': create_access_token('test', fresh=True),
                'refresh_token': create_refresh_token('test'),
            }
            return jsonify(tokens), 200

        @self.app.route('/api/protected')
        @jwt_required
        def protected():
            return jsonify({'msg': "hello world"})

    @staticmethod
    def _status_and_json(response):
        '''Return ``(status_code, decoded JSON body)`` for a test response.'''
        return response.status_code, json.loads(response.get_data(as_text=True))

    def _jwt_post(self, url, jwt):
        response = self.client.post(
            url,
            content_type='application/json',
            headers={'Authorization': 'Bearer {}'.format(jwt)},
        )
        return self._status_and_json(response)

    def _jwt_get(self, url, jwt, header_name='Authorization', header_type='Bearer'):
        header_value = '{} {}'.format(header_type, jwt).strip()
        response = self.client.get(url, headers={header_name: header_value})
        return self._status_and_json(response)

    def _store_cookie(self, set_cookie_header):
        '''Copy one Set-Cookie header value into the client's cookie jar.

        Returns the stored value: everything after the first ``=``.
        NOTE(review): "".join drops every later ``=`` as well, including those
        in attributes such as ``Path=/api/``.
        '''
        pieces = set_cookie_header.split('=')
        name, value = pieces[0], "".join(pieces[1:])
        self.client.set_cookie('localhost', name, value)
        return value

    def _login_cookies(self):
        '''Log in via cookies, mirror the Set-Cookie headers into the client,
        and return ``(access_csrf, refresh_csrf)`` ("" when CSRF is off).'''
        resp = self.client.post('/auth/login_cookies')
        csrf_enabled = self.app.config['JWT_COOKIE_CSRF_PROTECT']
        # Headers are read by position starting at 1 (index 0 is presumably
        # Content-Type): access cookie, [access CSRF], refresh cookie, [refresh CSRF].
        position = 1
        self._store_cookie(resp.headers[position][1])
        position += 1
        access_csrf = ""
        if csrf_enabled:
            access_csrf = self._store_cookie(resp.headers[position][1]).split(';')[0]
            position += 1
        self._store_cookie(resp.headers[position][1])
        position += 1
        refresh_csrf = ""
        if csrf_enabled:
            refresh_csrf = self._store_cookie(resp.headers[position][1]).split(';')[0]
        return access_csrf, refresh_csrf

    def _login_headers(self):
        '''Log in via the JSON endpoint and return ``(access_token, refresh_token)``.'''
        _, tokens = self._status_and_json(self.client.post('/auth/login_headers'))
        return tokens['access_token'], tokens['refresh_token']

    def test_accessing_endpoint_with_headers(self):
        access_token, _ = self._login_headers()
        auth_header = 'Bearer {}'.format(access_token).strip()
        status_code, data = self._status_and_json(
            self.client.get('/api/protected', headers={'Authorization': auth_header}))
        self.assertEqual(status_code, 200)
        self.assertEqual(data, {'msg': 'hello world'})

    def test_accessing_endpoint_with_cookies(self):
        access_csrf, _ = self._login_cookies()
        status_code, data = self._status_and_json(
            self.client.get('/api/protected', headers={'X-CSRF-TOKEN': access_csrf}))
        self.assertEqual(status_code, 200)
        self.assertEqual(data, {'msg': 'hello world'})

    def test_accessing_endpoint_without_jwt(self):
        status_code, data = self._status_and_json(self.client.get('/api/protected'))
        self.assertEqual(status_code, 401)
        self.assertIn('msg', data)
class TestEndpointsWithCookies(unittest.TestCase):
    '''End-to-end tests for JWTs delivered in cookies, with and without the
    double-submit CSRF protection (JWT_COOKIE_CSRF_PROTECT) enabled.'''

    def setUp(self):
        self.app = Flask(__name__)
        self.app.secret_key = 'super=secret'
        self.app.config['JWT_TOKEN_LOCATION'] = 'cookies'
        self.app.config['JWT_ACCESS_COOKIE_PATH'] = '/api/'
        self.app.config['JWT_REFRESH_COOKIE_PATH'] = '/auth/refresh'
        self.app.config['JWT_ACCESS_COOKIE_NAME'] = 'access_token_cookie'
        self.app.config['JWT_ALGORITHM'] = 'HS256'
        self.jwt_manager = JWTManager(self.app)
        self.client = self.app.test_client()

        @self.app.route('/auth/login', methods=['POST'])
        def login():
            # Create the tokens we will be sending back to the user
            access_token = create_access_token(identity='test')
            refresh_token = create_refresh_token(identity='test')

            # Set the JWTs and the CSRF double submit protection cookies in
            # this response
            resp = jsonify({'login': True})
            set_access_cookies(resp, access_token)
            set_refresh_cookies(resp, refresh_token)
            return resp, 200

        @self.app.route('/auth/logout', methods=['POST'])
        def logout():
            resp = jsonify({'logout': True})
            unset_jwt_cookies(resp)
            return resp, 200

        @self.app.route('/auth/refresh', methods=['POST'])
        @jwt_refresh_token_required
        def refresh():
            username = get_jwt_identity()
            access_token = create_access_token(username, fresh=False)
            resp = jsonify({'refresh': True})
            set_access_cookies(resp, access_token)
            return resp, 200

        @self.app.route('/api/protected', methods=['POST'])
        @jwt_required
        def protected():
            return jsonify({'msg': "hello world"})

    def _set_cookie_from_header(self, resp, index):
        '''Copy the cookie carried by ``resp.headers[index]`` into the client.

        The cookie name is the text before the first '='. The value is the
        remainder of the header with every '=' removed. NOTE(review): that
        also strips '=' from attributes such as "Path=/"; kept as-is to
        preserve the existing behaviour.

        :returns: the value string that was stored in the client.
        '''
        cookie_str = resp.headers[index][1]
        parts = cookie_str.split('=')
        cookie_key = parts[0]
        cookie_value = "".join(parts[1:])
        self.client.set_cookie('localhost', cookie_key, cookie_value)
        return cookie_value

    def _assert_error(self, response, status_code=401):
        '''Assert ``response`` has ``status_code`` and a JSON body with 'msg'.'''
        data = json.loads(response.get_data(as_text=True))
        self.assertEqual(response.status_code, status_code)
        self.assertIn('msg', data)

    def _assert_json(self, response, expected, status_code=200):
        '''Assert ``response`` has ``status_code`` and JSON body ``expected``.'''
        data = json.loads(response.get_data(as_text=True))
        self.assertEqual(response.status_code, status_code)
        self.assertEqual(data, expected)

    def _login(self):
        '''Log in via /auth/login and copy the issued cookies into the client.

        Cookies are read positionally from the response headers, starting at
        index 1 (access cookie, then optional access CSRF cookie, then the
        refresh cookie, then the optional refresh CSRF cookie).

        :returns: ``(access_csrf, refresh_csrf)``. Both are empty strings when
            JWT_COOKIE_CSRF_PROTECT is disabled.
        '''
        resp = self.client.post('/auth/login')
        csrf_protect = self.app.config['JWT_COOKIE_CSRF_PROTECT']
        index = 1

        self._set_cookie_from_header(resp, index)
        index += 1
        if csrf_protect:
            access_csrf_value = self._set_cookie_from_header(resp, index)
            index += 1
            # The raw token is everything before the first cookie attribute.
            access_csrf = access_csrf_value.split(';')[0]
        else:
            access_csrf = ""

        self._set_cookie_from_header(resp, index)
        index += 1
        if csrf_protect:
            refresh_csrf_value = self._set_cookie_from_header(resp, index)
            refresh_csrf = refresh_csrf_value.split(';')[0]
        else:
            refresh_csrf = ""

        return access_csrf, refresh_csrf

    def test_headers(self):
        # Try with default options
        resp = self.client.post('/auth/login')
        access_cookie = resp.headers[1][1]
        access_csrf = resp.headers[2][1]
        refresh_cookie = resp.headers[3][1]
        refresh_csrf = resp.headers[4][1]
        self.assertIn('access_token_cookie', access_cookie)
        self.assertIn('csrf_access_token', access_csrf)
        self.assertIn('Path=/', access_csrf)
        self.assertIn('refresh_token_cookie', refresh_cookie)
        self.assertIn('csrf_refresh_token', refresh_csrf)
        self.assertIn('Path=/', refresh_csrf)

        # Try with overwritten options
        self.app.config['JWT_ACCESS_COOKIE_NAME'] = 'new_access_cookie'
        self.app.config['JWT_REFRESH_COOKIE_NAME'] = 'new_refresh_cookie'
        self.app.config['JWT_ACCESS_CSRF_COOKIE_NAME'] = 'x_csrf_access_token'
        self.app.config['JWT_REFRESH_CSRF_COOKIE_NAME'] = 'x_csrf_refresh_token'
        self.app.config['JWT_ACCESS_COOKIE_PATH'] = None
        self.app.config['JWT_REFRESH_COOKIE_PATH'] = None
        resp = self.client.post('/auth/login')
        access_cookie = resp.headers[1][1]
        access_csrf = resp.headers[2][1]
        refresh_cookie = resp.headers[3][1]
        refresh_csrf = resp.headers[4][1]
        self.assertIn('new_access_cookie', access_cookie)
        self.assertIn('x_csrf_access_token', access_csrf)
        self.assertIn('Path=/', access_csrf)
        self.assertIn('new_refresh_cookie', refresh_cookie)
        self.assertIn('x_csrf_refresh_token', refresh_csrf)
        self.assertIn('Path=/', refresh_csrf)

        # Try logout headers
        resp = self.client.post('/auth/logout')
        refresh_cookie = resp.headers[1][1]
        access_cookie = resp.headers[2][1]
        self.assertIn('Expires=Thu, 01-Jan-1970', refresh_cookie)
        self.assertIn('Expires=Thu, 01-Jan-1970', access_cookie)

    def test_endpoints_with_cookies(self):
        self.app.config['JWT_COOKIE_CSRF_PROTECT'] = False

        # Try access without logging in
        self._assert_error(self.client.post('/api/protected'))

        # Try refresh without logging in
        self._assert_error(self.client.post('/auth/refresh'))

        # Try with logging in
        self._login()
        self._assert_json(self.client.post('/api/protected'),
                          {'msg': 'hello world'})

        # Try refresh after logging in
        response = self.client.post('/auth/refresh')
        access_cookie_str = response.headers[1][1]
        self.assertIn('access_token_cookie', access_cookie_str)
        self._assert_json(response, {'refresh': True})

        # Try accessing endpoint with newly refreshed token
        self._set_cookie_from_header(response, 1)
        self._assert_json(self.client.post('/api/protected'),
                          {'msg': 'hello world'})

    def test_access_endpoints_with_cookies_and_csrf(self):
        self.app.config['JWT_COOKIE_CSRF_PROTECT'] = True

        # Try without logging in
        self._assert_error(self.client.post('/api/protected'))

        # Login
        access_csrf, _ = self._login()

        # Try with logging in but without double submit csrf protection
        self._assert_error(self.client.post('/api/protected'))

        # Try with logged in and bad header name for double submit token
        response = self.client.post('/api/protected',
                                    headers={'bad-header-name': 'banana'})
        self._assert_error(response)

        # Try with logged in and bad header data for double submit token
        response = self.client.post('/api/protected',
                                    headers={'X-CSRF-TOKEN': 'banana'})
        self._assert_error(response)

        # Try with logged in and good double submit token
        response = self.client.post('/api/protected',
                                    headers={'X-CSRF-TOKEN': access_csrf})
        self._assert_json(response, {'msg': 'hello world'})

    def test_access_endpoints_with_cookie_missing_csrf_field(self):
        # Test accessing a csrf protected endpoint with a cookie that does not
        # have a csrf token in it
        self.app.config['JWT_COOKIE_CSRF_PROTECT'] = False
        self._login()
        self.app.config['JWT_COOKIE_CSRF_PROTECT'] = True

        self._assert_error(self.client.post('/api/protected'), 422)

    def test_access_endpoints_with_cookie_csrf_claim_not_string(self):
        now = datetime.utcnow()
        token_data = {
            'exp': now + timedelta(minutes=5),
            'iat': now,
            'nbf': now,
            'jti': 'banana',
            'identity': 'banana',
            # The token is installed in the *access* cookie below, so it is
            # typed as an access token; only the csrf claim is malformed.
            'type': 'access',
            'csrf': 404
        }
        secret = self.app.secret_key
        algorithm = self.app.config['JWT_ALGORITHM']
        encoded_token = jwt.encode(token_data, secret, algorithm)
        # PyJWT < 2.0 returns bytes, PyJWT >= 2.0 returns str.
        if isinstance(encoded_token, bytes):
            encoded_token = encoded_token.decode('utf-8')
        access_cookie_key = self.app.config['JWT_ACCESS_COOKIE_NAME']
        self.client.set_cookie('localhost', access_cookie_key, encoded_token)

        self.app.config['JWT_COOKIE_CSRF_PROTECT'] = True
        self._assert_error(self.client.post('/api/protected'), 401)

    def test_custom_csrf_methods(self):
        @self.app.route('/protected-post', methods=['POST'])
        @jwt_required
        def protected_post():
            return jsonify({'msg': "hello world"})

        @self.app.route('/protected-get', methods=['GET'])
        @jwt_required
        def protected_get():
            return jsonify({'msg': "hello world"})

        # Login (saves jwts in the cookies for the test client)
        self.app.config['JWT_COOKIE_CSRF_PROTECT'] = True
        self._login()

        # Test being able to access GET without CSRF protection, and POST with
        # CSRF protection
        self.app.config['JWT_CSRF_METHODS'] = ['POST']
        self._assert_error(self.client.post('/protected-post'))
        self._assert_json(self.client.get('/protected-get'),
                          {'msg': 'hello world'})

        # Now swap it around, and verify the JWT_CSRF_METHODS are being honored
        self.app.config['JWT_CSRF_METHODS'] = ['GET']
        self._assert_error(self.client.get('/protected-get'))
        self._assert_json(self.client.post('/protected-post'),
                          {'msg': 'hello world'})
class LoginTestCase(unittest.TestCase):
    ''' Tests for the LoginManager: login/logout, unauthorized handling,
    remember-me cookies, session protection, token loaders and the view
    decorators. '''

    def setUp(self):
        self.app = Flask(__name__)
        self.app.config['SECRET_KEY'] = 'deterministic'
        self.app.config['SESSION_PROTECTION'] = None
        self.app.config['TESTING'] = True
        self.remember_cookie_name = 'remember'
        self.app.config['REMEMBER_COOKIE_NAME'] = self.remember_cookie_name
        self.login_manager = LoginManager()
        self.login_manager.init_app(self.app)
        self.login_manager._login_disabled = False

        @self.app.route('/')
        def index():
            return u'Welcome!'

        @self.app.route('/secret')
        def secret():
            return self.login_manager.unauthorized()

        @self.app.route('/login-notch')
        def login_notch():
            return unicode(login_user(notch))

        @self.app.route('/login-notch-remember')
        def login_notch_remember():
            return unicode(login_user(notch, remember=True))

        @self.app.route('/login-notch-permanent')
        def login_notch_permanent():
            session.permanent = True
            return unicode(login_user(notch))

        @self.app.route('/needs-refresh')
        def needs_refresh():
            return self.login_manager.needs_refresh()

        @self.app.route('/confirm-login')
        def _confirm_login():
            confirm_login()
            return u''

        @self.app.route('/username')
        def username():
            if current_user.is_authenticated():
                return current_user.name
            return u'Anonymous'

        @self.app.route('/is-fresh')
        def is_fresh():
            return unicode(login_fresh())

        @self.app.route('/logout')
        def logout():
            return unicode(logout_user())

        @self.login_manager.user_loader
        def load_user(user_id):
            return USERS[int(user_id)]

        @self.login_manager.header_loader
        def load_user_from_header(header_value):
            # Accept "Basic <base64 id>" as well as a bare base64 id.
            if header_value.startswith('Basic '):
                header_value = header_value.replace('Basic ', '', 1)
            try:
                user_id = int(base64.b64decode(header_value))
            except (TypeError, ValueError):
                # Malformed header: Python 2 raises TypeError for bad base64,
                # Python 3 raises binascii.Error (a ValueError subclass), and
                # int() raises ValueError for a non-numeric payload. Treat all
                # of these as "no user" instead of crashing on an unbound name.
                return None
            return USERS.get(user_id)

        @self.login_manager.request_loader
        def load_user_from_request(request):
            user_id = request.args.get('user_id')
            try:
                user_id = int(float(user_id))
            except (TypeError, ValueError):
                # TypeError: no ?user_id= argument; ValueError: non-numeric.
                return None
            return USERS.get(user_id)

        @self.app.route('/empty_session')
        def empty_session():
            return unicode(u'modified=%s' % session.modified)

        # This will help us with the possibility of typos in the tests. Now
        # we shouldn't have to check each response to help us set up state
        # (such as login pages) to make sure it worked: we will always
        # get an exception raised (rather than return a 404 response)
        @self.app.errorhandler(404)
        def handle_404(e):
            raise e

        unittest.TestCase.setUp(self)

    def _get_remember_cookie(self, test_client):
        our_cookies = test_client.cookie_jar._cookies['localhost.local']['/']
        return our_cookies[self.remember_cookie_name]

    def _delete_session(self, c):
        # Helper method to cause the session to be deleted
        # as if the browser was closed. This will remove
        # the session regardless of the permanent flag
        # on the session!
        with c.session_transaction() as sess:
            sess.clear()

    def _basic_auth_headers(self, user_id):
        '''Return request headers carrying ``user_id`` as a Basic auth value.'''
        encoded = bytes.decode(base64.b64encode(str.encode(str(user_id))))
        return [('Authorization', 'Basic {0}'.format(encoded))]

    #
    # Login
    #
    def test_test_request_context_users_are_anonymous(self):
        with self.app.test_request_context():
            self.assertTrue(current_user.is_anonymous())

    def test_defaults_anonymous(self):
        with self.app.test_client() as c:
            result = c.get('/username')
            self.assertEqual(u'Anonymous', result.data.decode('utf-8'))

    def test_login_user(self):
        with self.app.test_request_context():
            result = login_user(notch)
            self.assertTrue(result)
            self.assertEqual(current_user.name, u'Notch')

    def test_login_user_emits_signal(self):
        with self.app.test_request_context():
            with listen_to(user_logged_in) as listener:
                login_user(notch)
                listener.assert_heard_one(self.app, user=notch)

    def test_login_inactive_user(self):
        with self.app.test_request_context():
            result = login_user(creeper)
            self.assertTrue(current_user.is_anonymous())
            self.assertFalse(result)

    def test_login_inactive_user_forced(self):
        with self.app.test_request_context():
            login_user(creeper, force=True)
            self.assertEqual(current_user.name, u'Creeper')

    def test_login_user_with_header(self):
        user_id = 2
        user_name = USERS[user_id].name
        self.login_manager.request_callback = None
        with self.app.test_client() as c:
            headers = self._basic_auth_headers(user_id)
            result = c.get('/username', headers=headers)
            self.assertEqual(user_name, result.data.decode('utf-8'))

    def test_login_invalid_user_with_header(self):
        user_id = 4
        user_name = u'Anonymous'
        self.login_manager.request_callback = None
        with self.app.test_client() as c:
            headers = self._basic_auth_headers(user_id)
            result = c.get('/username', headers=headers)
            self.assertEqual(user_name, result.data.decode('utf-8'))

    def test_login_user_with_request(self):
        user_id = 2
        user_name = USERS[user_id].name
        with self.app.test_client() as c:
            url = '/username?user_id={user_id}'.format(user_id=user_id)
            result = c.get(url)
            self.assertEqual(user_name, result.data.decode('utf-8'))

    def test_login_invalid_user_with_request(self):
        user_id = 4
        user_name = u'Anonymous'
        with self.app.test_client() as c:
            url = '/username?user_id={user_id}'.format(user_id=user_id)
            result = c.get(url)
            self.assertEqual(user_name, result.data.decode('utf-8'))

    #
    # Logout
    #
    def test_logout_logs_out_current_user(self):
        with self.app.test_request_context():
            login_user(notch)
            logout_user()
            self.assertTrue(current_user.is_anonymous())

    def test_logout_emits_signal(self):
        with self.app.test_request_context():
            login_user(notch)
            with listen_to(user_logged_out) as listener:
                logout_user()
                listener.assert_heard_one(self.app, user=notch)

    #
    # Unauthorized
    #
    def test_unauthorized_fires_unauthorized_signal(self):
        with self.app.test_client() as c:
            with listen_to(user_unauthorized) as listener:
                c.get('/secret')
                listener.assert_heard_one(self.app)

    def test_unauthorized_flashes_message_with_login_view(self):
        self.login_manager.login_view = '/login'

        expected_message = self.login_manager.login_message = u'Log in!'
        expected_category = self.login_manager.login_message_category = 'login'

        with self.app.test_client() as c:
            c.get('/secret')
            msgs = get_flashed_messages(category_filter=[expected_category])
            self.assertEqual([expected_message], msgs)

    def test_unauthorized_flash_message_localized(self):
        def _gettext(msg):
            if msg == u'Log in!':
                return u'Einloggen'

        self.login_manager.login_view = '/login'
        self.login_manager.localize_callback = _gettext
        self.login_manager.login_message = u'Log in!'

        expected_message = u'Einloggen'
        expected_category = self.login_manager.login_message_category = 'login'

        with self.app.test_client() as c:
            c.get('/secret')
            msgs = get_flashed_messages(category_filter=[expected_category])
            self.assertEqual([expected_message], msgs)
        self.login_manager.localize_callback = None

    def test_unauthorized_uses_authorized_handler(self):
        @self.login_manager.unauthorized_handler
        def _callback():
            return Response('This is secret!', 401)

        with self.app.test_client() as c:
            result = c.get('/secret')
            self.assertEqual(result.status_code, 401)
            self.assertEqual(u'This is secret!', result.data.decode('utf-8'))

    def test_unauthorized_aborts_with_401(self):
        with self.app.test_client() as c:
            result = c.get('/secret')
            self.assertEqual(result.status_code, 401)

    def test_unauthorized_redirects_to_login_view(self):
        self.login_manager.login_view = 'login'

        @self.app.route('/login')
        def login():
            return 'Login Form Goes Here!'

        with self.app.test_client() as c:
            result = c.get('/secret')
            self.assertEqual(result.status_code, 302)
            self.assertEqual(result.location,
                             'http://localhost/login?next=%2Fsecret')

    #
    # Session Persistence/Freshness
    #
    def test_login_persists(self):
        with self.app.test_client() as c:
            c.get('/login-notch')
            result = c.get('/username')
            self.assertEqual(u'Notch', result.data.decode('utf-8'))

    def test_logout_persists(self):
        with self.app.test_client() as c:
            c.get('/login-notch')
            c.get('/logout')
            result = c.get('/username')
            self.assertEqual(result.data.decode('utf-8'), u'Anonymous')

    def test_incorrect_id_logs_out(self):
        # Ensure that any attempt to reload the user by the ID
        # will seem as if the user is no longer valid
        @self.login_manager.user_loader
        def new_user_loader(user_id):
            return

        with self.app.test_client() as c:
            # Successfully logs in
            c.get('/login-notch')
            result = c.get('/username')
            self.assertEqual(u'Anonymous', result.data.decode('utf-8'))

    def test_authentication_is_fresh(self):
        with self.app.test_client() as c:
            c.get('/login-notch-remember')
            result = c.get('/is-fresh')
            self.assertEqual(u'True', result.data.decode('utf-8'))

    def test_remember_me(self):
        with self.app.test_client() as c:
            c.get('/login-notch-remember')
            self._delete_session(c)
            result = c.get('/username')
            self.assertEqual(u'Notch', result.data.decode('utf-8'))

    def test_remember_me_uses_custom_cookie_parameters(self):
        name = self.app.config['REMEMBER_COOKIE_NAME'] = 'myname'
        duration = self.app.config['REMEMBER_COOKIE_DURATION'] = \
            timedelta(days=2)
        domain = self.app.config['REMEMBER_COOKIE_DOMAIN'] = '.localhost.local'

        with self.app.test_client() as c:
            c.get('/login-notch-remember')

            # TODO: Is there a better way to test this?
            self.assertIn(domain, c.cookie_jar._cookies,
                          'Custom domain not found as cookie domain')
            domain_cookie = c.cookie_jar._cookies[domain]
            self.assertIn(name, domain_cookie['/'],
                          'Custom name not found as cookie name')
            cookie = domain_cookie['/'][name]

            expiration_date = datetime.fromtimestamp(cookie.expires)
            expected_date = datetime.now() + duration
            difference = expected_date - expiration_date

            fail_msg = 'The expiration date {0} was far from the expected {1}'
            fail_msg = fail_msg.format(expiration_date, expected_date)
            self.assertLess(difference, timedelta(seconds=10), fail_msg)
            self.assertGreater(difference, timedelta(seconds=-10), fail_msg)

    def test_remember_me_is_unfresh(self):
        with self.app.test_client() as c:
            c.get('/login-notch-remember')
            self._delete_session(c)
            self.assertEqual(u'False', c.get('/is-fresh').data.decode('utf-8'))

    def test_user_loaded_from_cookie_fired(self):
        with self.app.test_client() as c:
            c.get('/login-notch-remember')
            self._delete_session(c)
            with listen_to(user_loaded_from_cookie) as listener:
                c.get('/username')
                listener.assert_heard_one(self.app, user=notch)

    def test_user_loaded_from_header_fired(self):
        user_id = 1
        user_name = USERS[user_id].name
        self.login_manager.request_callback = None
        with self.app.test_client() as c:
            with listen_to(user_loaded_from_header) as listener:
                headers = self._basic_auth_headers(user_id)
                result = c.get('/username', headers=headers)
                self.assertEqual(user_name, result.data.decode('utf-8'))
                listener.assert_heard_one(self.app, user=USERS[user_id])

    def test_user_loaded_from_request_fired(self):
        user_id = 1
        user_name = USERS[user_id].name
        with self.app.test_client() as c:
            with listen_to(user_loaded_from_request) as listener:
                url = '/username?user_id={user_id}'.format(user_id=user_id)
                result = c.get(url)
                self.assertEqual(user_name, result.data.decode('utf-8'))
                listener.assert_heard_one(self.app, user=USERS[user_id])

    def test_logout_stays_logged_out_with_remember_me(self):
        with self.app.test_client() as c:
            c.get('/login-notch-remember')
            c.get('/logout')
            result = c.get('/username')
            self.assertEqual(result.data.decode('utf-8'), u'Anonymous')

    def test_needs_refresh_uses_handler(self):
        @self.login_manager.needs_refresh_handler
        def _on_refresh():
            return u'Needs Refresh!'

        with self.app.test_client() as c:
            c.get('/login-notch-remember')
            result = c.get('/needs-refresh')
            self.assertEqual(u'Needs Refresh!', result.data.decode('utf-8'))

    def test_needs_refresh_fires_needs_refresh_signal(self):
        with self.app.test_client() as c:
            c.get('/login-notch-remember')
            with listen_to(user_needs_refresh) as listener:
                c.get('/needs-refresh')
                listener.assert_heard_one(self.app)

    def test_needs_refresh_fires_flash_when_redirect_to_refresh_view(self):
        self.login_manager.refresh_view = '/refresh_view'

        self.login_manager.needs_refresh_message = u'Refresh'
        self.login_manager.needs_refresh_message_category = 'refresh'
        category_filter = [self.login_manager.needs_refresh_message_category]

        with self.app.test_client() as c:
            c.get('/login-notch-remember')
            c.get('/needs-refresh')
            msgs = get_flashed_messages(category_filter=category_filter)
            self.assertIn(self.login_manager.needs_refresh_message, msgs)

    def test_needs_refresh_flash_message_localized(self):
        def _gettext(msg):
            if msg == u'Refresh':
                return u'Aktualisieren'

        self.login_manager.refresh_view = '/refresh_view'
        self.login_manager.localize_callback = _gettext

        self.login_manager.needs_refresh_message = u'Refresh'
        self.login_manager.needs_refresh_message_category = 'refresh'
        category_filter = [self.login_manager.needs_refresh_message_category]

        with self.app.test_client() as c:
            c.get('/login-notch-remember')
            c.get('/needs-refresh')
            msgs = get_flashed_messages(category_filter=category_filter)
            self.assertIn(u'Aktualisieren', msgs)
        self.login_manager.localize_callback = None

    def test_needs_refresh_aborts_403(self):
        with self.app.test_client() as c:
            c.get('/login-notch-remember')
            result = c.get('/needs-refresh')
            self.assertEqual(result.status_code, 403)

    def test_redirects_to_refresh_view(self):
        @self.app.route('/refresh-view')
        def refresh_view():
            return ''

        self.login_manager.refresh_view = 'refresh_view'
        with self.app.test_client() as c:
            c.get('/login-notch-remember')
            result = c.get('/needs-refresh')
            self.assertEqual(result.status_code, 302)
            expected = 'http://localhost/refresh-view?next=%2Fneeds-refresh'
            self.assertEqual(result.location, expected)

    def test_confirm_login(self):
        with self.app.test_client() as c:
            c.get('/login-notch-remember')
            self._delete_session(c)
            self.assertEqual(u'False', c.get('/is-fresh').data.decode('utf-8'))
            c.get('/confirm-login')
            self.assertEqual(u'True', c.get('/is-fresh').data.decode('utf-8'))

    def test_user_login_confirmed_signal_fired(self):
        with self.app.test_client() as c:
            with listen_to(user_login_confirmed) as listener:
                c.get('/confirm-login')
                listener.assert_heard_one(self.app)

    def test_session_not_modified(self):
        with self.app.test_client() as c:
            # Within the request we think we didn't modify the session.
            self.assertEqual(
                u'modified=False',
                c.get('/empty_session').data.decode('utf-8'))
            # But after the request, the session could be modified by the
            # "after_request" handlers that call _update_remember_cookie.
            # Ensure that if nothing changed the session is not modified.
            self.assertFalse(session.modified)

    #
    # Session Protection
    #
    def test_session_protection_basic_passes_successive_requests(self):
        self.app.config['SESSION_PROTECTION'] = 'basic'
        with self.app.test_client() as c:
            c.get('/login-notch-remember')
            username_result = c.get('/username')
            self.assertEqual(u'Notch', username_result.data.decode('utf-8'))
            fresh_result = c.get('/is-fresh')
            self.assertEqual(u'True', fresh_result.data.decode('utf-8'))

    def test_session_protection_strong_passes_successive_requests(self):
        self.app.config['SESSION_PROTECTION'] = 'strong'
        with self.app.test_client() as c:
            c.get('/login-notch-remember')
            username_result = c.get('/username')
            self.assertEqual(u'Notch', username_result.data.decode('utf-8'))
            fresh_result = c.get('/is-fresh')
            self.assertEqual(u'True', fresh_result.data.decode('utf-8'))

    def test_session_protection_basic_marks_session_unfresh(self):
        self.app.config['SESSION_PROTECTION'] = 'basic'
        with self.app.test_client() as c:
            c.get('/login-notch-remember')
            username_result = c.get('/username',
                                    headers=[('User-Agent', 'different')])
            self.assertEqual(u'Notch', username_result.data.decode('utf-8'))
            fresh_result = c.get('/is-fresh')
            self.assertEqual(u'False', fresh_result.data.decode('utf-8'))

    def test_session_protection_basic_fires_signal(self):
        self.app.config['SESSION_PROTECTION'] = 'basic'

        with self.app.test_client() as c:
            c.get('/login-notch-remember')
            with listen_to(session_protected) as listener:
                c.get('/username', headers=[('User-Agent', 'different')])
                listener.assert_heard_one(self.app)

    def test_session_protection_basic_skips_when_remember_me(self):
        self.app.config['SESSION_PROTECTION'] = 'basic'

        with self.app.test_client() as c:
            c.get('/login-notch-remember')
            # clear session to force remember me (and remove old session id)
            self._delete_session(c)
            # should not trigger protection because "sess" is empty
            with listen_to(session_protected) as listener:
                c.get('/username')
                listener.assert_heard_none(self.app)

    def test_session_protection_strong_skips_when_remember_me(self):
        self.app.config['SESSION_PROTECTION'] = 'strong'

        with self.app.test_client() as c:
            c.get('/login-notch-remember')
            # clear session to force remember me (and remove old session id)
            self._delete_session(c)
            # should not trigger protection because "sess" is empty
            with listen_to(session_protected) as listener:
                c.get('/username')
                listener.assert_heard_none(self.app)

    def test_permanent_strong_session_protection_marks_session_unfresh(self):
        self.app.config['SESSION_PROTECTION'] = 'strong'
        with self.app.test_client() as c:
            c.get('/login-notch-permanent')
            username_result = c.get('/username',
                                    headers=[('User-Agent', 'different')])
            self.assertEqual(u'Notch', username_result.data.decode('utf-8'))
            fresh_result = c.get('/is-fresh')
            self.assertEqual(u'False', fresh_result.data.decode('utf-8'))

    def test_permanent_strong_session_protection_fires_signal(self):
        self.app.config['SESSION_PROTECTION'] = 'strong'

        with self.app.test_client() as c:
            c.get('/login-notch-permanent')
            with listen_to(session_protected) as listener:
                c.get('/username', headers=[('User-Agent', 'different')])
                listener.assert_heard_one(self.app)

    def test_session_protection_strong_deletes_session(self):
        self.app.config['SESSION_PROTECTION'] = 'strong'
        with self.app.test_client() as c:
            c.get('/login-notch-remember')
            username_result = c.get('/username',
                                    headers=[('User-Agent', 'different')])
            self.assertEqual(u'Anonymous',
                             username_result.data.decode('utf-8'))

    def test_session_protection_strong_fires_signal_user_agent(self):
        self.app.config['SESSION_PROTECTION'] = 'strong'

        with self.app.test_client() as c:
            c.get('/login-notch-remember')
            with listen_to(session_protected) as listener:
                c.get('/username', headers=[('User-Agent', 'different')])
                listener.assert_heard_one(self.app)

    def test_session_protection_strong_fires_signal_x_forwarded_for(self):
        self.app.config['SESSION_PROTECTION'] = 'strong'

        with self.app.test_client() as c:
            c.get('/login-notch-remember',
                  headers=[('X-Forwarded-For', '10.1.1.1')])
            with listen_to(session_protected) as listener:
                c.get('/username', headers=[('X-Forwarded-For', '10.1.1.2')])
                listener.assert_heard_one(self.app)

    def test_session_protection_skip_when_off_and_anonymous(self):
        with self.app.test_client() as c:
            # no user access
            with listen_to(user_accessed) as user_listener:
                c.get('/')
                user_listener.assert_heard_none(self.app)

            # access user with no session data
            with listen_to(session_protected) as session_listener:
                results = c.get('/username')
                self.assertEqual(results.data.decode('utf-8'), u'Anonymous')
                session_listener.assert_heard_none(self.app)

            # verify no session data has been set
            self.assertFalse(session)

    def test_session_protection_skip_when_basic_and_anonymous(self):
        self.app.config['SESSION_PROTECTION'] = 'basic'

        with self.app.test_client() as c:
            # no user access
            with listen_to(user_accessed) as user_listener:
                c.get('/')
                user_listener.assert_heard_none(self.app)

            # access user with no session data
            with listen_to(session_protected) as session_listener:
                results = c.get('/username')
                self.assertEqual(results.data.decode('utf-8'), u'Anonymous')
                session_listener.assert_heard_none(self.app)

            # verify no session data has been set other than '_id'
            self.assertIsNotNone(session.get('_id'))
            self.assertEqual(len(session), 1)

    #
    # Custom Token Loader
    #
    def test_custom_token_loader(self):
        @self.login_manager.token_loader
        def load_token(token):
            return USER_TOKENS.get(token)

        with self.app.test_client() as c:
            c.get('/login-notch-remember')
            self._delete_session(c)

            # Test that remember me functionality still works
            self.assertEqual(u'Notch', c.get('/username').data.decode('utf-8'))

            # Test that we used the custom authentication token
            remember_cookie = self._get_remember_cookie(c)
            expected_value = make_secure_token(u'Notch', key='deterministic')
            self.assertEqual(expected_value, remember_cookie.value)

    def test_change_api_key_with_token_loader(self):
        @self.login_manager.token_loader
        def load_token(token):
            return USER_TOKENS.get(token)

        with self.app.test_client() as c:
            c.get('/login-notch-remember')
            self._delete_session(c)
            self.app.config['SECRET_KEY'] = 'ima change this now'

            result = c.get('/username')
            self.assertEqual(result.data.decode('utf-8'), u'Notch')

    def test_custom_token_loader_with_no_user(self):
        @self.login_manager.token_loader
        def load_token(token):
            return

        with self.app.test_client() as c:
            c.get('/login-notch-remember')
            self._delete_session(c)

            result = c.get('/username')
            self.assertEqual(result.data.decode('utf-8'), u'Anonymous')

    #
    # Lazy Access User
    #
    def test_requests_without_accessing_session(self):
        with self.app.test_client() as c:
            c.get('/login-notch')

            # no session access
            with listen_to(user_accessed) as listener:
                c.get('/')
                listener.assert_heard_none(self.app)

            # should have a session access
            with listen_to(user_accessed) as listener:
                result = c.get('/username')
                listener.assert_heard_one(self.app)
                self.assertEqual(result.data.decode('utf-8'), u'Notch')

    #
    # View Decorators
    #
    def test_login_required_decorator(self):
        @self.app.route('/protected')
        @login_required
        def protected():
            return u'Access Granted'

        with self.app.test_client() as c:
            result = c.get('/protected')
            self.assertEqual(result.status_code, 401)

            c.get('/login-notch')
            result2 = c.get('/protected')
            self.assertIn(u'Access Granted', result2.data.decode('utf-8'))

    def test_decorators_are_disabled(self):
        @self.app.route('/protected')
        @login_required
        @fresh_login_required
        def protected():
            return u'Access Granted'

        self.app.login_manager._login_disabled = True

        with self.app.test_client() as c:
            result = c.get('/protected')
            self.assertIn(u'Access Granted', result.data.decode('utf-8'))

    def test_fresh_login_required_decorator(self):
        @self.app.route('/very-protected')
        @fresh_login_required
        def very_protected():
            return 'Access Granted'

        with self.app.test_client() as c:
            result = c.get('/very-protected')
            self.assertEqual(result.status_code, 401)

            c.get('/login-notch-remember')
            logged_in_result = c.get('/very-protected')
            self.assertEqual(u'Access Granted',
                             logged_in_result.data.decode('utf-8'))

            self._delete_session(c)
            stale_result = c.get('/very-protected')
            self.assertEqual(stale_result.status_code, 403)

            c.get('/confirm-login')
            refreshed_result = c.get('/very-protected')
            self.assertEqual(u'Access Granted',
                             refreshed_result.data.decode('utf-8'))

    #
    # Misc
    #
    @unittest.skipIf(werkzeug_version.startswith("0.9"),
                     "wait for upstream implementing RFC 5987")
    def test_chinese_user_agent(self):
        with self.app.test_client() as c:
            result = c.get('/', headers=[('User-Agent', u'中文')])
            self.assertEqual(u'Welcome!', result.data.decode('utf-8'))

    @unittest.skipIf(werkzeug_version.startswith("0.9"),
                     "wait for upstream implementing RFC 5987")
    def test_russian_cp1251_user_agent(self):
        with self.app.test_client() as c:
            headers = [('User-Agent', u'ЯЙЮя'.encode('cp1251'))]
            response = c.get('/', headers=headers)
            self.assertEqual(response.data.decode('utf-8'), u'Welcome!')

    def test_make_secure_token_default_key(self):
        with self.app.test_request_context():
            self.assertEqual(make_secure_token('foo'),
                             '0f05743a2b617b2625362ab667c0dbdf4c9ec13a')

    def test_user_context_processor(self):
        with self.app.test_request_context():
            _ucp = self.app.context_processor(_user_context_processor)
            self.assertIsInstance(_ucp()['current_user'], AnonymousUserMixin)
def client(request):
    """Yield a Flask test client for a wrapper app hosting the ``app`` blueprint.

    The blueprint is mounted under the URL prefix taken from
    ``request.param``, so each parametrization of this fixture exercises the
    blueprint at a different mount point. The client is used inside its
    context manager and is closed once the consuming test finishes.
    """
    wrapper_app = Flask(__name__)
    mount_prefix = request.param
    wrapper_app.register_blueprint(app, url_prefix=mount_prefix)
    with wrapper_app.test_client() as test_client:
        yield test_client
def setUp(self):
    """Build a Flask app that exercises every HTTP auth variant under test.

    Each auth object gets its own credential source (``get_password``,
    ``hash_password`` or ``verify_password``) and a protected route that
    echoes the authenticated username. The app, a test client and all auth
    objects are stored on ``self`` for the individual test methods.
    """
    app = Flask(__name__)
    app.config['SECRET_KEY'] = 'my secret'

    basic_auth = HTTPBasicAuth()
    basic_auth_my_realm = HTTPBasicAuth(scheme='CustomBasic', realm='My Realm')
    basic_custom_auth = HTTPBasicAuth()
    basic_verify_auth = HTTPBasicAuth()
    digest_auth = HTTPDigestAuth()
    digest_auth_my_realm = HTTPDigestAuth(scheme='CustomDigest',
                                          realm='My Realm')
    # With use_ha1_pw=True, get_password below returns HA1 hashes
    # (built with get_ha1) instead of plain-text passwords.
    digest_auth_ha1_pw = HTTPDigestAuth(use_ha1_pw=True)

    @digest_auth_ha1_pw.get_password
    def get_digest_password(username):
        # HA1 is computed against this auth object's own realm.
        if username == 'susan':
            return get_ha1(username, 'hello', digest_auth_ha1_pw.realm)
        elif username == 'john':
            return get_ha1(username, 'bye', digest_auth_ha1_pw.realm)
        else:
            return None

    @basic_auth.get_password
    def get_basic_password(username):
        if username == 'john':
            return 'hello'
        elif username == 'susan':
            return 'bye'
        else:
            return None

    @basic_auth_my_realm.get_password
    def get_basic_password_2(username):
        # Stored values equal username + password, matching the
        # hash_password callback registered below.
        if username == 'john':
            return 'johnhello'
        elif username == 'susan':
            return 'susanbye'
        else:
            return None

    @basic_auth_my_realm.hash_password
    def basic_auth_my_realm_hash_password(username, password):
        return username + password

    @basic_auth_my_realm.error_handler
    def basic_auth_my_realm_error():
        return 'custom error'

    @basic_custom_auth.get_password
    def get_basic_custom_auth_get_password(username):
        # Stored values are MD5 hex digests; the client-supplied password
        # is hashed the same way by basic_custom_auth_hash_password.
        if username == 'john':
            return md5('hello').hexdigest()
        elif username == 'susan':
            return md5('bye').hexdigest()
        else:
            return None

    @basic_custom_auth.hash_password
    def basic_custom_auth_hash_password(password):
        return md5(password).hexdigest()

    @basic_verify_auth.verify_password
    def basic_verify_auth_verify_password(username, password):
        # g.anon records whether the request was accepted as anonymous
        # (empty username); the route below echoes it back.
        g.anon = False
        if username == 'john':
            return password == 'hello'
        elif username == 'susan':
            return password == 'bye'
        elif username == '':
            g.anon = True
            return True
        return False

    @digest_auth.get_password
    def get_digest_password_2(username):
        if username == 'susan':
            return 'hello'
        elif username == 'john':
            return 'bye'
        else:
            return None

    @digest_auth_my_realm.get_password
    def get_digest_password_3(username):
        if username == 'susan':
            return 'hello'
        elif username == 'john':
            return 'bye'
        else:
            return None

    @app.route('/')
    def index():
        return 'index'

    @app.route('/basic')
    @basic_auth.login_required
    def basic_auth_route():
        return 'basic_auth:' + basic_auth.username()

    @app.route('/basic-with-realm')
    @basic_auth_my_realm.login_required
    def basic_auth_my_realm_route():
        return 'basic_auth_my_realm:' + basic_auth_my_realm.username()

    @app.route('/basic-custom')
    @basic_custom_auth.login_required
    def basic_custom_auth_route():
        return 'basic_custom_auth:' + basic_custom_auth.username()

    @app.route('/basic-verify')
    @basic_verify_auth.login_required
    def basic_verify_auth_route():
        return 'basic_verify_auth:' + basic_verify_auth.username() + \
            ' anon:' + str(g.anon)

    @app.route('/digest')
    @digest_auth.login_required
    def digest_auth_route():
        return 'digest_auth:' + digest_auth.username()

    @app.route('/digest_ha1_pw')
    @digest_auth_ha1_pw.login_required
    def digest_auth_ha1_pw_route():
        # Report the username from the auth object that actually guards
        # this route (previously read from the unrelated digest_auth).
        # The response prefix is kept unchanged.
        return 'digest_auth:' + digest_auth_ha1_pw.username()

    @app.route('/digest-with-realm')
    @digest_auth_my_realm.login_required
    def digest_auth_my_realm_route():
        return 'digest_auth_my_realm:' + digest_auth_my_realm.username()

    self.app = app
    self.basic_auth = basic_auth
    self.basic_auth_my_realm = basic_auth_my_realm
    self.basic_custom_auth = basic_custom_auth
    self.basic_verify_auth = basic_verify_auth
    self.digest_auth = digest_auth
    # Expose the remaining auth objects too, consistent with the ones above.
    self.digest_auth_my_realm = digest_auth_my_realm
    self.digest_auth_ha1_pw = digest_auth_ha1_pw
    self.client = app.test_client()
val = self.serializer.dumps(dict(session)) self.redis.setex(self.key_prefix + session.sid, val, int(app.permanent_session_lifetime.total_seconds())) response.set_cookie(app.session_cookie_name, session.sid, expires=expires, httponly=http_only, domain=domain, path=path, secure=secure) if __name__ == '__main__': from flask import Flask, session app = Flask(__name__) RedisSession.init_app(app) @app.route('/') def index(): session['test'] = {"test": "test"} return 'set session ok' @app.route('/dump') def dump(): print session return 'dump session ok' with app.test_client() as c: print c.get('/') print c.get('/dump')
class FlaskTestCase(unittest.TestCase): def setUp(self): from flask import Flask from flask_appbuilder import AppBuilder from flask_appbuilder.models.sqla.interface import SQLAInterface from flask_appbuilder import ModelRestApi self.app = Flask(__name__) self.basedir = os.path.abspath(os.path.dirname(__file__)) self.app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite:///" self.app.config["SECRET_KEY"] = "thisismyscretkey" self.app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False self.app.config["FAB_API_MAX_PAGE_SIZE"] = MAX_PAGE_SIZE self.app.config["WTF_CSRF_ENABLED"] = False self.db = SQLA(self.app) self.appbuilder = AppBuilder(self.app, self.db.session) # Create models and insert data insert_data(self.db.session, MODEL1_DATA_SIZE) class Model1Api(ModelRestApi): datamodel = SQLAInterface(Model1) list_columns = [ "field_integer", "field_float", "field_string", "field_date", ] description_columns = { "field_integer": "Field Integer", "field_float": "Field Float", "field_string": "Field String", } self.model1api = Model1Api self.appbuilder.add_api(Model1Api) class Model1ApiFieldsInfo(Model1Api): datamodel = SQLAInterface(Model1) add_columns = ["field_integer", "field_float", "field_string", "field_date"] edit_columns = ["field_string", "field_integer"] self.model1apifieldsinfo = Model1ApiFieldsInfo self.appbuilder.add_api(Model1ApiFieldsInfo) class Model1FuncApi(ModelRestApi): datamodel = SQLAInterface(Model1) list_columns = [ "field_integer", "field_float", "field_string", "field_date", "full_concat", ] description_columns = { "field_integer": "Field Integer", "field_float": "Field Float", "field_string": "Field String", } self.model1funcapi = Model1Api self.appbuilder.add_api(Model1FuncApi) class Model1ApiExcludeCols(ModelRestApi): datamodel = SQLAInterface(Model1) list_exclude_columns = ["field_integer", "field_float", "field_date"] show_exclude_columns = list_exclude_columns edit_exclude_columns = list_exclude_columns add_exclude_columns = list_exclude_columns 
self.appbuilder.add_api(Model1ApiExcludeCols) class Model1ApiOrder(ModelRestApi): datamodel = SQLAInterface(Model1) base_order = ("field_integer", "desc") self.appbuilder.add_api(Model1ApiOrder) class Model1ApiRestrictedPermissions(ModelRestApi): datamodel = SQLAInterface(Model1) base_permissions = ["can_get", "can_info"] self.appbuilder.add_api(Model1ApiRestrictedPermissions) class Model1ApiFiltered(ModelRestApi): datamodel = SQLAInterface(Model1) base_filters = [ ["field_integer", FilterGreater, 2], ["field_integer", FilterSmaller, 4], ] self.appbuilder.add_api(Model1ApiFiltered) class ModelWithEnumsApi(ModelRestApi): datamodel = SQLAInterface(ModelWithEnums) self.appbuilder.add_api(ModelWithEnumsApi) class Model1BrowserLogin(ModelRestApi): datamodel = SQLAInterface(Model1) allow_browser_login = True self.appbuilder.add_api(Model1BrowserLogin) class ModelMMApi(ModelRestApi): datamodel = SQLAInterface(ModelMMParent) self.appbuilder.add_api(ModelMMApi) class ModelMMRequiredApi(ModelRestApi): datamodel = SQLAInterface(ModelMMParentRequired) self.appbuilder.add_api(ModelMMRequiredApi) class Model1CustomValidationApi(ModelRestApi): datamodel = SQLAInterface(Model1) validators_columns = {"field_string": validate_name} self.appbuilder.add_api(Model1CustomValidationApi) class Model2Api(ModelRestApi): datamodel = SQLAInterface(Model2) list_columns = ["group"] show_columns = ["group"] self.model2api = Model2Api self.appbuilder.add_api(Model2Api) class Model2ApiFilteredRelFields(ModelRestApi): datamodel = SQLAInterface(Model2) list_columns = ["group"] show_columns = ["group"] add_query_rel_fields = { "group": [ ["field_integer", FilterGreater, 2], ["field_integer", FilterSmaller, 4], ] } edit_query_rel_fields = add_query_rel_fields self.model2apifilteredrelfields = Model2ApiFilteredRelFields self.appbuilder.add_api(Model2ApiFilteredRelFields) role_admin = self.appbuilder.sm.find_role("Admin") self.appbuilder.sm.add_user( USERNAME, "admin", "user", "*****@*****.**", 
role_admin, PASSWORD ) def tearDown(self): self.appbuilder = None self.app = None self.db = None @staticmethod def auth_client_get(client, token, uri): return client.get(uri, headers={"Authorization": "Bearer {}".format(token)}) @staticmethod def auth_client_delete(client, token, uri): return client.delete(uri, headers={"Authorization": "Bearer {}".format(token)}) @staticmethod def auth_client_put(client, token, uri, json): return client.put( uri, json=json, headers={"Authorization": "Bearer {}".format(token)} ) @staticmethod def auth_client_post(client, token, uri, json): return client.post( uri, json=json, headers={"Authorization": "Bearer {}".format(token)} ) @staticmethod def _login(client, username, password): """ Login help method :param client: Flask test client :param username: username :param password: password :return: Flask client response class """ return client.post( "api/{}/security/login".format(API_SECURITY_VERSION), data=json.dumps( { API_SECURITY_USERNAME_KEY: username, API_SECURITY_PASSWORD_KEY: password, API_SECURITY_PROVIDER_KEY: "db", } ), content_type="application/json", ) def login(self, client, username, password): # Login with default admin rv = self._login(client, username, password) try: return json.loads(rv.data.decode("utf-8")).get("access_token") except Exception: return rv def browser_login(self, client, username, password): # Login with default admin return client.post( "/login/", data=dict(username=username, password=password), follow_redirects=True, ) def browser_logout(self, client): return client.get("/logout/") def test_auth_login(self): """ REST Api: Test auth login """ client = self.app.test_client() rv = self._login(client, USERNAME, PASSWORD) eq_(rv.status_code, 200) assert json.loads(rv.data.decode("utf-8")).get( API_SECURITY_ACCESS_TOKEN_KEY, False ) def test_auth_login_failed(self): """ REST Api: Test auth login failed """ client = self.app.test_client() rv = self._login(client, "fail", "fail") eq_(json.loads(rv.data), 
{"message": "Not authorized"}) eq_(rv.status_code, 401) def test_auth_login_bad(self): """ REST Api: Test auth login bad request """ client = self.app.test_client() rv = client.post("api/v1/security/login", data="BADADATA") eq_(rv.status_code, 400) def test_auth_authorization_browser(self): """ REST Api: Test auth with browser login """ client = self.app.test_client() rv = self.browser_login(client, USERNAME, PASSWORD) # Test access with browser login uri = "api/v1/model1browserlogin/1" rv = client.get(uri) eq_(rv.status_code, 200) # Test unauthorized access with browser login uri = "api/v1/model1api/1" rv = client.get(uri) eq_(rv.status_code, 401) # Test access wihout cookie or JWT rv = self.browser_logout(client) # Test access with browser login uri = "api/v1/model1browserlogin/1" rv = client.get(uri) eq_(rv.status_code, 401) # Test access with JWT but without cookie token = self.login(client, USERNAME, PASSWORD) uri = "api/v1/model1browserlogin/1" rv = self.auth_client_get(client, token, uri) eq_(rv.status_code, 200) def test_auth_authorization(self): """ REST Api: Test auth base limited authorization """ client = self.app.test_client() token = self.login(client, USERNAME, PASSWORD) # Test unauthorized DELETE pk = 1 uri = "api/v1/model1apirestrictedpermissions/{}".format(pk) rv = self.auth_client_delete(client, token, uri) eq_(rv.status_code, 401) # Test unauthorized POST item = dict( field_string="test{}".format(MODEL1_DATA_SIZE + 1), field_integer=MODEL1_DATA_SIZE + 1, field_float=float(MODEL1_DATA_SIZE + 1), field_date=None, ) uri = "api/v1/model1apirestrictedpermissions/" rv = self.auth_client_post(client, token, uri, item) eq_(rv.status_code, 401) # Test unauthorized GET uri = "api/v1/model1apirestrictedpermissions/1" rv = self.auth_client_get(client, token, uri) eq_(rv.status_code, 200) def test_get_item(self): """ REST Api: Test get item """ client = self.app.test_client() token = self.login(client, USERNAME, PASSWORD) for i in range(1, MODEL1_DATA_SIZE): 
rv = self.auth_client_get(client, token, "api/v1/model1api/{}".format(i)) data = json.loads(rv.data.decode("utf-8")) eq_(rv.status_code, 200) self.assert_get_item(rv, data, i - 1) def assert_get_item(self, rv, data, value): eq_( data[API_RESULT_RES_KEY], { "field_date": None, "field_float": float(value), "field_integer": value, "field_string": "test{}".format(value), }, ) # test descriptions eq_(data["description_columns"], self.model1api.description_columns) # test labels eq_( data[API_LABEL_COLUMNS_RES_KEY], { "field_date": "Field Date", "field_float": "Field Float", "field_integer": "Field Integer", "field_string": "Field String", }, ) eq_(rv.status_code, 200) def test_get_item_select_cols(self): """ REST Api: Test get item with select columns """ client = self.app.test_client() token = self.login(client, USERNAME, PASSWORD) for i in range(1, MODEL1_DATA_SIZE): uri = "api/v1/model1api/{}?q=({}:!(field_integer))".format( i, API_SELECT_COLUMNS_RIS_KEY ) rv = self.auth_client_get(client, token, uri) data = json.loads(rv.data.decode("utf-8")) eq_(data[API_RESULT_RES_KEY], {"field_integer": i - 1}) eq_( data[API_DESCRIPTION_COLUMNS_RES_KEY], {"field_integer": "Field Integer"}, ) eq_(data[API_LABEL_COLUMNS_RES_KEY], {"field_integer": "Field Integer"}) eq_(rv.status_code, 200) def test_get_item_select_meta_data(self): """ REST Api: Test get item select meta data """ client = self.app.test_client() token = self.login(client, USERNAME, PASSWORD) selectable_keys = [ API_DESCRIPTION_COLUMNS_RIS_KEY, API_LABEL_COLUMNS_RIS_KEY, API_SHOW_COLUMNS_RIS_KEY, API_SHOW_TITLE_RIS_KEY, ] for selectable_key in selectable_keys: argument = {API_SELECT_KEYS_RIS_KEY: [selectable_key]} uri = "api/v1/model1api/1?{}={}".format( API_URI_RIS_KEY, prison.dumps(argument) ) rv = self.auth_client_get(client, token, uri) data = json.loads(rv.data.decode("utf-8")) eq_(len(data.keys()), 1 + 2) # always exist id, result # We assume that rison meta key equals result meta key assert selectable_key in 
data def test_get_item_excluded_cols(self): """ REST Api: Test get item with excluded columns """ client = self.app.test_client() token = self.login(client, USERNAME, PASSWORD) pk = 1 rv = self.auth_client_get( client, token, "api/v1/model1apiexcludecols/{}".format(pk) ) data = json.loads(rv.data.decode("utf-8")) eq_(data[API_RESULT_RES_KEY], {"field_string": "test0"}) eq_(rv.status_code, 200) def test_get_item_not_found(self): """ REST Api: Test get item not found """ client = self.app.test_client() token = self.login(client, USERNAME, PASSWORD) pk = MODEL1_DATA_SIZE + 1 rv = self.auth_client_get(client, token, "api/v1/model1api/{}".format(pk)) eq_(rv.status_code, 404) def test_get_item_base_filters(self): """ REST Api: Test get item with base filters """ client = self.app.test_client() token = self.login(client, USERNAME, PASSWORD) # We can't get a base filtered item pk = 1 rv = self.auth_client_get( client, token, "api/v1/model1apifiltered/{}".format(pk) ) eq_(rv.status_code, 404) # This one is ok pk=4 field_integer=3 2>3<4 pk = 4 rv = self.auth_client_get( client, token, "api/v1/model1apifiltered/{}".format(pk) ) eq_(rv.status_code, 200) def test_get_item_1m_field(self): """ REST Api: Test get item with 1-N related field """ client = self.app.test_client() token = self.login(client, USERNAME, PASSWORD) # We can't get a base filtered item pk = 1 rv = self.auth_client_get(client, token, "api/v1/model2api/{}".format(pk)) data = json.loads(rv.data.decode("utf-8")) eq_(rv.status_code, 200) expected_rel_field = { "group": { "field_date": None, "field_float": 0.0, "field_integer": 0, "field_string": "test0", "id": 1, } } eq_(data[API_RESULT_RES_KEY], expected_rel_field) def test_get_item_mm_field(self): """ REST Api: Test get item with N-N related field """ client = self.app.test_client() token = self.login(client, USERNAME, PASSWORD) # We can't get a base filtered item pk = 1 rv = self.auth_client_get(client, token, "api/v1/modelmmapi/{}".format(pk)) data = 
json.loads(rv.data.decode("utf-8")) eq_(rv.status_code, 200) expected_rel_field = [ {"field_string": "1", "id": 1}, {"field_string": "2", "id": 2}, {"field_string": "3", "id": 3}, ] eq_(data[API_RESULT_RES_KEY]["children"], expected_rel_field) def test_get_list(self): """ REST Api: Test get list """ client = self.app.test_client() token = self.login(client, USERNAME, PASSWORD) rv = self.auth_client_get(client, token, "api/v1/model1api/") data = json.loads(rv.data.decode("utf-8")) # Tests count property eq_(data["count"], MODEL1_DATA_SIZE) # Tests data result default page size eq_(len(data[API_RESULT_RES_KEY]), self.model1api.page_size) for i in range(1, self.model1api.page_size): self.assert_get_list(rv, data[API_RESULT_RES_KEY][i - 1], i - 1) @staticmethod def assert_get_list(rv, data, value): eq_( data, { "field_date": None, "field_float": float(value), "field_integer": value, "field_string": "test{}".format(value), }, ) eq_(rv.status_code, 200) def test_get_list_order(self): """ REST Api: Test get list order params """ client = self.app.test_client() token = self.login(client, USERNAME, PASSWORD) # test string order asc arguments = {"order_column": "field_integer", "order_direction": "asc"} uri = "api/v1/model1api/?{}={}".format(API_URI_RIS_KEY, prison.dumps(arguments)) rv = self.auth_client_get(client, token, uri) data = json.loads(rv.data.decode("utf-8")) eq_( data[API_RESULT_RES_KEY][0], { "field_date": None, "field_float": 0.0, "field_integer": 0, "field_string": "test0", }, ) eq_(rv.status_code, 200) # test string order desc arguments = {"order_column": "field_integer", "order_direction": "desc"} uri = "api/v1/model1api/?{}={}".format(API_URI_RIS_KEY, prison.dumps(arguments)) rv = self.auth_client_get(client, token, uri) data = json.loads(rv.data.decode("utf-8")) eq_( data[API_RESULT_RES_KEY][0], { "field_date": None, "field_float": float(MODEL1_DATA_SIZE - 1), "field_integer": MODEL1_DATA_SIZE - 1, "field_string": "test{}".format(MODEL1_DATA_SIZE - 1), }, 
) eq_(rv.status_code, 200) def test_get_list_base_order(self): """ REST Api: Test get list with base order """ client = self.app.test_client() token = self.login(client, USERNAME, PASSWORD) # test string order asc rv = self.auth_client_get(client, token, "api/v1/model1apiorder/") data = json.loads(rv.data.decode("utf-8")) eq_( data[API_RESULT_RES_KEY][0], { "field_date": None, "field_float": float(MODEL1_DATA_SIZE - 1), "field_integer": MODEL1_DATA_SIZE - 1, "field_string": "test{}".format(MODEL1_DATA_SIZE - 1), }, ) # Test override arguments = {"order_column": "field_integer", "order_direction": "asc"} uri = "api/v1/model1apiorder/?{}={}".format( API_URI_RIS_KEY, prison.dumps(arguments) ) rv = self.auth_client_get(client, token, uri) data = json.loads(rv.data.decode("utf-8")) eq_( data[API_RESULT_RES_KEY][0], { "field_date": None, "field_float": 0.0, "field_integer": 0, "field_string": "test0", }, ) def test_get_list_page(self): """ REST Api: Test get list page params """ page_size = 5 client = self.app.test_client() token = self.login(client, USERNAME, PASSWORD) # test page zero arguments = { "page_size": page_size, "page": 0, "order_column": "field_integer", "order_direction": "asc", } uri = "api/v1/model1api/?{}={}".format(API_URI_RIS_KEY, prison.dumps(arguments)) rv = self.auth_client_get(client, token, uri) data = json.loads(rv.data.decode("utf-8")) eq_( data[API_RESULT_RES_KEY][0], { "field_date": None, "field_float": 0.0, "field_integer": 0, "field_string": "test0", }, ) eq_(rv.status_code, 200) eq_(len(data[API_RESULT_RES_KEY]), page_size) # test page one arguments = { "page_size": page_size, "page": 1, "order_column": "field_integer", "order_direction": "asc", } uri = "api/v1/model1api/?{}={}".format(API_URI_RIS_KEY, prison.dumps(arguments)) rv = self.auth_client_get(client, token, uri) data = json.loads(rv.data.decode("utf-8")) eq_( data[API_RESULT_RES_KEY][0], { "field_date": None, "field_float": float(page_size), "field_integer": page_size, 
"field_string": "test{}".format(page_size), }, ) eq_(rv.status_code, 200) eq_(len(data[API_RESULT_RES_KEY]), page_size) def test_get_list_max_page_size(self): """ REST Api: Test get list max page size config setting """ page_size = 100 # Max is globally set to MAX_PAGE_SIZE client = self.app.test_client() token = self.login(client, USERNAME, PASSWORD) # test page zero arguments = { "page_size": page_size, "page": 0, "order_column": "field_integer", "order_direction": "asc", } uri = "api/v1/model1api/?{}={}".format(API_URI_RIS_KEY, prison.dumps(arguments)) print("URI {}".format(uri)) rv = self.auth_client_get(client, token, uri) data = json.loads(rv.data.decode("utf-8")) eq_(len(data[API_RESULT_RES_KEY]), MAX_PAGE_SIZE) def test_get_list_filters(self): """ REST Api: Test get list filter params """ client = self.app.test_client() token = self.login(client, USERNAME, PASSWORD) filter_value = 5 # test string order asc arguments = { API_FILTERS_RIS_KEY: [ {"col": "field_integer", "opr": "gt", "value": filter_value} ], "order_column": "field_integer", "order_direction": "asc", } uri = "api/v1/model1api/?{}={}".format(API_URI_RIS_KEY, prison.dumps(arguments)) rv = self.auth_client_get(client, token, uri) data = json.loads(rv.data.decode("utf-8")) eq_( data[API_RESULT_RES_KEY][0], { "field_date": None, "field_float": float(filter_value + 1), "field_integer": filter_value + 1, "field_string": "test{}".format(filter_value + 1), }, ) eq_(rv.status_code, 200) def test_get_list_select_cols(self): """ REST Api: Test get list with selected columns """ client = self.app.test_client() token = self.login(client, USERNAME, PASSWORD) argument = { API_SELECT_COLUMNS_RIS_KEY: ["field_integer"], "order_column": "field_integer", "order_direction": "asc", } uri = "api/v1/model1api/?{}={}".format(API_URI_RIS_KEY, prison.dumps(argument)) rv = self.auth_client_get(client, token, uri) data = json.loads(rv.data.decode("utf-8")) eq_(data[API_RESULT_RES_KEY][0], {"field_integer": 0}) 
eq_(data[API_LABEL_COLUMNS_RES_KEY], {"field_integer": "Field Integer"}) eq_(data[API_DESCRIPTION_COLUMNS_RES_KEY], {"field_integer": "Field Integer"}) eq_(data[API_LIST_COLUMNS_RES_KEY], ["field_integer"]) eq_(rv.status_code, 200) def test_get_list_select_meta_data(self): """ REST Api: Test get list select meta data """ client = self.app.test_client() token = self.login(client, USERNAME, PASSWORD) selectable_keys = [ API_DESCRIPTION_COLUMNS_RIS_KEY, API_LABEL_COLUMNS_RIS_KEY, API_ORDER_COLUMNS_RIS_KEY, API_LIST_COLUMNS_RIS_KEY, API_LIST_TITLE_RIS_KEY, ] for selectable_key in selectable_keys: argument = {API_SELECT_KEYS_RIS_KEY: [selectable_key]} uri = "api/v1/model1api/?{}={}".format( API_URI_RIS_KEY, prison.dumps(argument) ) rv = self.auth_client_get(client, token, uri) data = json.loads(rv.data.decode("utf-8")) eq_(len(data.keys()), 1 + 3) # always exist count, ids, result # We assume that rison meta key equals result meta key assert selectable_key in data def test_get_list_exclude_cols(self): """ REST Api: Test get list with excluded columns """ client = self.app.test_client() token = self.login(client, USERNAME, PASSWORD) uri = "api/v1/model1apiexcludecols/" rv = self.auth_client_get(client, token, uri) data = json.loads(rv.data.decode("utf-8")) eq_(data[API_RESULT_RES_KEY][0], {"field_string": "test0"}) def test_get_list_base_filters(self): """ REST Api: Test get list with base filters """ client = self.app.test_client() token = self.login(client, USERNAME, PASSWORD) arguments = {"order_column": "field_integer", "order_direction": "desc"} uri = "api/v1/model1apifiltered/?{}={}".format( API_URI_RIS_KEY, prison.dumps(arguments) ) rv = self.auth_client_get(client, token, uri) data = json.loads(rv.data.decode("utf-8")) expected_result = [ { "field_date": None, "field_float": 3.0, "field_integer": 3, "field_string": "test3", } ] eq_(data[API_RESULT_RES_KEY], expected_result) def test_info_filters(self): """ REST Api: Test info filters """ client = 
self.app.test_client() token = self.login(client, USERNAME, PASSWORD) uri = "api/v1/model1api/_info" rv = self.auth_client_get(client, token, uri) data = json.loads(rv.data.decode("utf-8")) expected_filters = { "field_date": [ {"name": "Equal to", "operator": "eq"}, {"name": "Greater than", "operator": "gt"}, {"name": "Smaller than", "operator": "lt"}, {"name": "Not Equal to", "operator": "neq"}, ], "field_float": [ {"name": "Equal to", "operator": "eq"}, {"name": "Greater than", "operator": "gt"}, {"name": "Smaller than", "operator": "lt"}, {"name": "Not Equal to", "operator": "neq"}, ], "field_integer": [ {"name": "Equal to", "operator": "eq"}, {"name": "Greater than", "operator": "gt"}, {"name": "Smaller than", "operator": "lt"}, {"name": "Not Equal to", "operator": "neq"}, ], "field_string": [ {"name": "Starts with", "operator": "sw"}, {"name": "Ends with", "operator": "ew"}, {"name": "Contains", "operator": "ct"}, {"name": "Equal to", "operator": "eq"}, {"name": "Not Starts with", "operator": "nsw"}, {"name": "Not Ends with", "operator": "new"}, {"name": "Not Contains", "operator": "nct"}, {"name": "Not Equal to", "operator": "neq"}, ], } eq_(data["filters"], expected_filters) def test_info_fields(self): """ REST Api: Test info fields (add, edit) """ client = self.app.test_client() token = self.login(client, USERNAME, PASSWORD) uri = "api/v1/model1apifieldsinfo/_info" rv = self.auth_client_get(client, token, uri) data = json.loads(rv.data.decode("utf-8")) expect_add_fields = [ { "description": "Field Integer", "label": "Field Integer", "name": "field_integer", "required": False, "unique": False, "type": "Integer", }, { "description": "Field Float", "label": "Field Float", "name": "field_float", "required": False, "unique": False, "type": "Float", }, { "description": "Field String", "label": "Field String", "name": "field_string", "required": True, "unique": True, "type": "String", "validate": ["<Length(min=None, max=50, equal=None, error=None)>"], }, { 
"description": "", "label": "Field Date", "name": "field_date", "required": False, "unique": False, "type": "Date", }, ] expect_edit_fields = list() for edit_col in self.model1apifieldsinfo.edit_columns: for item in expect_add_fields: if item["name"] == edit_col: expect_edit_fields.append(item) eq_(data[API_ADD_COLUMNS_RES_KEY], expect_add_fields) eq_(data[API_EDIT_COLUMNS_RES_KEY], expect_edit_fields) def test_info_fields_rel_field(self): """ REST Api: Test info fields with related fields """ client = self.app.test_client() token = self.login(client, USERNAME, PASSWORD) uri = "api/v1/model2api/_info" rv = self.auth_client_get(client, token, uri) data = json.loads(rv.data.decode("utf-8")) expected_rel_add_field = { "count": MODEL2_DATA_SIZE, "description": "", "label": "Group", "name": "group", "required": True, "unique": False, "type": "Related", "values": [], } for i in range(self.model2api.page_size): expected_rel_add_field["values"].append( {"id": i + 1, "value": "test{}".format(i)} ) for rel_field in data[API_ADD_COLUMNS_RES_KEY]: if rel_field["name"] == "group": eq_(rel_field, expected_rel_add_field) def test_info_fields_rel_filtered_field(self): """ REST Api: Test info fields with filtered related fields """ client = self.app.test_client() token = self.login(client, USERNAME, PASSWORD) uri = "api/v1/model2apifilteredrelfields/_info" rv = self.auth_client_get(client, token, uri) data = json.loads(rv.data.decode("utf-8")) expected_rel_add_field = { "description": "", "label": "Group", "name": "group", "required": True, "unique": False, "type": "Related", "count": 1, "values": [{"id": 4, "value": "test3"}], } for rel_field in data[API_ADD_COLUMNS_RES_KEY]: if rel_field["name"] == "group": eq_(rel_field, expected_rel_add_field) for rel_field in data[API_EDIT_COLUMNS_RES_KEY]: if rel_field["name"] == "group": eq_(rel_field, expected_rel_add_field) def test_info_permissions(self): """ REST Api: Test info permissions """ client = self.app.test_client() token = 
self.login(client, USERNAME, PASSWORD) uri = "api/v1/model1api/_info" rv = self.auth_client_get(client, token, uri) data = json.loads(rv.data.decode("utf-8")) expected_permissions = [ "can_delete", "can_get", "can_info", "can_post", "can_put", ] eq_(sorted(data[API_PERMISSIONS_RES_KEY]), expected_permissions) uri = "api/v1/model1apirestrictedpermissions/_info" rv = self.auth_client_get(client, token, uri) data = json.loads(rv.data.decode("utf-8")) expected_permissions = ["can_get", "can_info"] eq_(sorted(data[API_PERMISSIONS_RES_KEY]), expected_permissions) def test_info_select_meta_data(self): """ REST Api: Test info select meta data """ # select meta for add fields client = self.app.test_client() token = self.login(client, USERNAME, PASSWORD) selectable_keys = [ API_ADD_COLUMNS_RIS_KEY, API_EDIT_COLUMNS_RIS_KEY, API_PERMISSIONS_RIS_KEY, API_FILTERS_RIS_KEY, API_ADD_TITLE_RIS_KEY, API_EDIT_TITLE_RIS_KEY, ] for selectable_key in selectable_keys: arguments = {API_SELECT_KEYS_RIS_KEY: [selectable_key]} uri = "api/v1/model1api/_info?{}={}".format( API_URI_RIS_KEY, prison.dumps(arguments) ) rv = self.auth_client_get(client, token, uri) data = json.loads(rv.data.decode("utf-8")) eq_(len(data.keys()), 1) # We assume that rison meta key equals result meta key assert selectable_key in data def test_delete_item(self): """ REST Api: Test delete item """ client = self.app.test_client() token = self.login(client, USERNAME, PASSWORD) pk = 2 uri = "api/v1/model1api/{}".format(pk) rv = self.auth_client_delete(client, token, uri) eq_(rv.status_code, 200) model = self.db.session.query(Model1).get(pk) eq_(model, None) def test_delete_item_not_found(self): """ REST Api: Test delete item not found """ client = self.app.test_client() token = self.login(client, USERNAME, PASSWORD) pk = MODEL1_DATA_SIZE + 1 uri = "api/v1/model1api/{}".format(pk) rv = self.auth_client_delete(client, token, uri) eq_(rv.status_code, 404) def test_delete_item_base_filters(self): """ REST Api: Test delete 
item with base filters """ client = self.app.test_client() token = self.login(client, USERNAME, PASSWORD) # Try to delete a filtered item pk = 1 uri = "api/v1/model1apifiltered/{}".format(pk) rv = self.auth_client_delete(client, token, uri) eq_(rv.status_code, 404) def test_update_item(self): """ REST Api: Test update item """ client = self.app.test_client() token = self.login(client, USERNAME, PASSWORD) pk = 3 item = dict(field_string="test_Put", field_integer=0, field_float=0.0) uri = "api/v1/model1api/{}".format(pk) rv = self.auth_client_put(client, token, uri, item) eq_(rv.status_code, 200) model = self.db.session.query(Model1).get(pk) eq_(model.field_string, "test_Put") eq_(model.field_integer, 0) eq_(model.field_float, 0.0) def test_update_custom_validation(self): """ REST Api: Test update item custom validation """ client = self.app.test_client() token = self.login(client, USERNAME, PASSWORD) pk = 3 item = dict(field_string="test_Put", field_integer=0, field_float=0.0) uri = "api/v1/model1customvalidationapi/{}".format(pk) rv = self.auth_client_put(client, token, uri, item) eq_(rv.status_code, 422) pk = 3 item = dict(field_string="Atest_Put", field_integer=0, field_float=0.0) uri = "api/v1/model1customvalidationapi/{}".format(pk) rv = self.auth_client_put(client, token, uri, item) eq_(rv.status_code, 200) def test_update_item_base_filters(self): """ REST Api: Test update item with base filters """ client = self.app.test_client() token = self.login(client, USERNAME, PASSWORD) pk = 4 item = dict(field_string="test_Put", field_integer=3, field_float=3.0) uri = "api/v1/model1apifiltered/{}".format(pk) rv = self.auth_client_put(client, token, uri, item) eq_(rv.status_code, 200) model = self.db.session.query(Model1).get(pk) eq_(model.field_string, "test_Put") eq_(model.field_integer, 3) eq_(model.field_float, 3.0) # We can't update an item that is base filtered pk = 1 uri = "api/v1/model1apifiltered/{}".format(pk) rv = self.auth_client_put(client, token, uri, 
item) eq_(rv.status_code, 404) def test_update_item_not_found(self): """ REST Api: Test update item not found """ client = self.app.test_client() token = self.login(client, USERNAME, PASSWORD) pk = MODEL1_DATA_SIZE + 1 item = dict(field_string="test_Put", field_integer=0, field_float=0.0) uri = "api/v1/model1api/{}".format(pk) rv = self.auth_client_put(client, token, uri, item) eq_(rv.status_code, 404) def test_update_val_size(self): """ REST Api: Test update validate size """ client = self.app.test_client() token = self.login(client, USERNAME, PASSWORD) pk = 1 field_string = "a" * 51 item = dict(field_string=field_string, field_integer=11, field_float=11.0) uri = "api/v1/model1api/{}".format(pk) rv = self.auth_client_put(client, token, uri, item) eq_(rv.status_code, 422) data = json.loads(rv.data.decode("utf-8")) eq_(data["message"]["field_string"][0], "Longer than maximum length 50.") def test_update_mm_field(self): """ REST Api: Test update m-m field """ model = ModelMMChild() model.field_string = "update_m,m" self.appbuilder.get_session.add(model) self.appbuilder.get_session.commit() client = self.app.test_client() token = self.login(client, USERNAME, PASSWORD) pk = 1 item = dict(children=[4]) uri = "api/v1/modelmmapi/{}".format(pk) rv = self.auth_client_put(client, token, uri, item) eq_(rv.status_code, 200) data = json.loads(rv.data.decode("utf-8")) eq_(data[API_RESULT_RES_KEY], {"children": [4], "field_string": "0"}) def test_update_item_val_type(self): """ REST Api: Test update validate type """ client = self.app.test_client() token = self.login(client, USERNAME, PASSWORD) pk = 1 item = dict( field_string="test{}".format(MODEL1_DATA_SIZE + 1), field_integer="test{}".format(MODEL1_DATA_SIZE + 1), field_float=11.0, ) uri = "api/v1/model1api/{}".format(pk) rv = self.auth_client_put(client, token, uri, item) eq_(rv.status_code, 422) data = json.loads(rv.data.decode("utf-8")) eq_(data["message"]["field_integer"][0], "Not a valid integer.") item = 
dict(field_string=11, field_integer=11, field_float=11.0) rv = self.auth_client_put(client, token, uri, item) eq_(rv.status_code, 422) data = json.loads(rv.data.decode("utf-8")) eq_(data["message"]["field_string"][0], "Not a valid string.") def test_update_item_excluded_cols(self): """ REST Api: Test update item with excluded cols """ client = self.app.test_client() token = self.login(client, USERNAME, PASSWORD) pk = 1 item = dict(field_string="test_Put", field_integer=1000) uri = "api/v1/model1apiexcludecols/{}".format(pk) rv = self.auth_client_put(client, token, uri, item) eq_(rv.status_code, 200) model = self.db.session.query(Model1).get(pk) eq_(model.field_integer, 0) eq_(model.field_float, 0.0) eq_(model.field_date, None) def test_create_item(self): """ REST Api: Test create item """ client = self.app.test_client() token = self.login(client, USERNAME, PASSWORD) item = dict( field_string="test{}".format(MODEL1_DATA_SIZE + 1), field_integer=MODEL1_DATA_SIZE + 1, field_float=float(MODEL1_DATA_SIZE + 1), field_date=None, ) uri = "api/v1/model1api/" rv = self.auth_client_post(client, token, uri, item) data = json.loads(rv.data.decode("utf-8")) eq_(rv.status_code, 201) eq_(data[API_RESULT_RES_KEY], item) model = ( self.db.session.query(Model1) .filter_by(field_string="test{}".format(MODEL1_DATA_SIZE + 1)) .first() ) eq_(model.field_string, "test{}".format(MODEL1_DATA_SIZE + 1)) eq_(model.field_integer, MODEL1_DATA_SIZE + 1) eq_(model.field_float, float(MODEL1_DATA_SIZE + 1)) def test_create_item_custom_validation(self): """ REST Api: Test create item custom validation """ client = self.app.test_client() token = self.login(client, USERNAME, PASSWORD) item = dict( field_string="test{}".format(MODEL1_DATA_SIZE + 1), field_integer=MODEL1_DATA_SIZE + 1, field_float=float(MODEL1_DATA_SIZE + 1), field_date=None, ) uri = "api/v1/model1customvalidationapi/" rv = self.auth_client_post(client, token, uri, item) data = json.loads(rv.data.decode("utf-8")) eq_(rv.status_code, 
422) eq_(data, {"message": {"field_string": ["Name must start with an A"]}}) item = dict( field_string="A{}".format(MODEL1_DATA_SIZE + 1), field_integer=MODEL1_DATA_SIZE + 1, field_float=float(MODEL1_DATA_SIZE + 1), field_date=None, ) uri = "api/v1/model1customvalidationapi/" rv = self.auth_client_post(client, token, uri, item) data = json.loads(rv.data.decode("utf-8")) eq_(rv.status_code, 201) def test_create_item_val_size(self): """ REST Api: Test create validate size """ client = self.app.test_client() token = self.login(client, USERNAME, PASSWORD) field_string = "a" * 51 item = dict( field_string=field_string, field_integer=MODEL1_DATA_SIZE + 1, field_float=float(MODEL1_DATA_SIZE + 1), ) uri = "api/v1/model1api/" rv = self.auth_client_post(client, token, uri, item) eq_(rv.status_code, 422) data = json.loads(rv.data.decode("utf-8")) eq_(data["message"]["field_string"][0], "Longer than maximum length 50.") def test_create_item_val_type(self): """ REST Api: Test create validate type """ # Test integer as string client = self.app.test_client() token = self.login(client, USERNAME, PASSWORD) item = dict( field_string="test{}".format(MODEL1_DATA_SIZE), field_integer="test{}".format(MODEL1_DATA_SIZE), field_float=float(MODEL1_DATA_SIZE), ) uri = "api/v1/model1api/" rv = self.auth_client_post(client, token, uri, item) eq_(rv.status_code, 422) data = json.loads(rv.data.decode("utf-8")) eq_(data["message"]["field_integer"][0], "Not a valid integer.") # Test string as integer item = dict( field_string=MODEL1_DATA_SIZE, field_integer=MODEL1_DATA_SIZE, field_float=float(MODEL1_DATA_SIZE), ) rv = self.auth_client_post(client, token, uri, item) eq_(rv.status_code, 422) data = json.loads(rv.data.decode("utf-8")) eq_(data["message"]["field_string"][0], "Not a valid string.") def test_create_item_excluded_cols(self): """ REST Api: Test create with excluded columns """ client = self.app.test_client() token = self.login(client, USERNAME, PASSWORD) item = 
dict(field_string="test{}".format(MODEL1_DATA_SIZE + 1)) uri = "api/v1/model1apiexcludecols/" rv = self.auth_client_post(client, token, uri, item) eq_(rv.status_code, 201) item = dict( field_string="test{}".format(MODEL1_DATA_SIZE + 2), field_integer=MODEL1_DATA_SIZE + 2, ) rv = self.auth_client_post(client, token, uri, item) eq_(rv.status_code, 201) model = ( self.db.session.query(Model1) .filter_by(field_string="test{}".format(MODEL1_DATA_SIZE + 1)) .first() ) eq_(model.field_integer, None) eq_(model.field_float, None) eq_(model.field_date, None) def test_create_item_with_enum(self): """ REST Api: Test create item with enum """ client = self.app.test_client() token = self.login(client, USERNAME, PASSWORD) item = dict(enum2="e1") uri = "api/v1/modelwithenumsapi/" rv = self.auth_client_post(client, token, uri, item) data = json.loads(rv.data.decode("utf-8")) eq_(rv.status_code, 201) model = self.db.session.query(ModelWithEnums).get(data["id"]) eq_(model.enum2, TmpEnum.e1) def test_create_item_mm_field(self): """ REST Api: Test create with M-M field """ client = self.app.test_client() token = self.login(client, USERNAME, PASSWORD) item = dict( field_string='new1', children=[1, 2] ) uri = "api/v1/modelmmapi/" rv = self.auth_client_post(client, token, uri, item) eq_(rv.status_code, 201) data = json.loads(rv.data.decode("utf-8")) eq_(data[API_RESULT_RES_KEY], {"children": [1, 2], "field_string": "new1"}) # Test without M-M field data, default is not required item = dict( field_string='new2' ) uri = "api/v1/modelmmapi/" rv = self.auth_client_post(client, token, uri, item) eq_(rv.status_code, 201) data = json.loads(rv.data.decode("utf-8")) eq_(data[API_RESULT_RES_KEY], {"children": [], "field_string": "new2"}) # Test without M-M field data, default is required item = dict( field_string='new1' ) uri = "api/v1/modelmmrequiredapi/" rv = self.auth_client_post(client, token, uri, item) eq_(rv.status_code, 422) data = json.loads(rv.data.decode("utf-8")) eq_(data, {"message": 
{"children": ["Missing data for required field."]}}) def test_get_list_col_function(self): """ REST Api: Test get list of objects with columns as functions """ client = self.app.test_client() token = self.login(client, USERNAME, PASSWORD) uri = "api/v1/model1funcapi/" rv = self.auth_client_get(client, token, uri) data = json.loads(rv.data.decode("utf-8")) # Tests count property eq_(data["count"], MODEL1_DATA_SIZE) # Tests data result default page size eq_(len(data[API_RESULT_RES_KEY]), self.model1api.page_size) for i in range(1, self.model1api.page_size): item = data[API_RESULT_RES_KEY][i - 1] eq_( item["full_concat"], "{}.{}.{}.{}".format("test" + str(i - 1), i - 1, float(i - 1), None), ) def test_openapi(self): """ REST Api: Test OpenAPI spec """ client = self.app.test_client() token = self.login(client, USERNAME, PASSWORD) uri = "api/v1/_openapi" rv = self.auth_client_get(client, token, uri) eq_(rv.status_code, 200)
class TestAPI(TestCase):
    """Exercise the Flask-APScheduler REST API through a Flask test client.

    Each test gets a fresh Flask app with the scheduler API enabled and the
    scheduler started; ``tearDown`` shuts that scheduler down again.
    """

    def setUp(self):
        self.app = Flask(__name__)
        self.scheduler = APScheduler()
        self.scheduler.api_enabled = True
        self.scheduler.init_app(self.app)
        self.scheduler.start()
        self.client = self.app.test_client()

    def tearDown(self):
        # setUp starts a scheduler for every test; stop it so tests do not
        # leave running schedulers behind.
        self.scheduler.shutdown()

    @staticmethod
    def _future_year():
        """Return next calendar year.

        Trigger dates built from it always lie in the future. A hard-coded
        date eventually passes, and then a ``date`` job misfires and a cron
        ``next_run_time`` no longer equals its ``start_date``.
        """
        return datetime.now().year + 1

    def test_scheduler_info(self):
        response = self.client.get(self.scheduler.api_prefix)
        self.assertEqual(response.status_code, 200)
        info = json.loads(response.get_data(as_text=True))
        self.assertIsNotNone(info['current_host'])
        self.assertEqual(info['allowed_hosts'], ['*'])
        self.assertTrue(info['running'])

    def test_add_job(self):
        job = {
            'id': 'job1',
            'func': 'tests.test_api:job1',
            'trigger': 'date',
            'run_date': '{}-12-01T12:30:01+00:00'.format(self._future_year()),
        }
        response = self.client.post(self.scheduler.api_prefix + '/jobs', data=json.dumps(job))
        self.assertEqual(response.status_code, 200)
        job2 = json.loads(response.get_data(as_text=True))
        self.assertEqual(job.get('id'), job2.get('id'))
        self.assertEqual(job.get('func'), job2.get('func'))
        self.assertEqual(job.get('trigger'), job2.get('trigger'))
        self.assertEqual(job.get('run_date'), job2.get('run_date'))

    def test_add_conflicted_job(self):
        job = {
            'id': 'job1',
            'func': 'tests.test_api:job1',
            'trigger': 'date',
            # A future run date keeps the first job scheduled, so the second
            # POST reliably conflicts instead of racing a misfired removal.
            'run_date': '{}-12-01T12:30:01+00:00'.format(self._future_year()),
        }
        response = self.client.post(self.scheduler.api_prefix + '/jobs', data=json.dumps(job))
        self.assertEqual(response.status_code, 200)
        response = self.client.post(self.scheduler.api_prefix + '/jobs', data=json.dumps(job))
        self.assertEqual(response.status_code, 409)

    def test_add_invalid_job(self):
        job = {
            'id': None,
        }
        response = self.client.post(self.scheduler.api_prefix + '/jobs', data=json.dumps(job))
        self.assertEqual(response.status_code, 500)

    def test_delete_job(self):
        self.__add_job()
        response = self.client.delete(self.scheduler.api_prefix + '/jobs/job1')
        self.assertEqual(response.status_code, 204)
        response = self.client.get(self.scheduler.api_prefix + '/jobs/job1')
        self.assertEqual(response.status_code, 404)

    def test_delete_job_not_found(self):
        response = self.client.delete(self.scheduler.api_prefix + '/jobs/job1')
        self.assertEqual(response.status_code, 404)

    def test_get_job(self):
        job = self.__add_job()
        response = self.client.get(self.scheduler.api_prefix + '/jobs/job1')
        self.assertEqual(response.status_code, 200)
        job2 = json.loads(response.get_data(as_text=True))
        self.assertEqual(job.get('id'), job2.get('id'))
        self.assertEqual(job.get('func'), job2.get('func'))
        self.assertEqual(job.get('trigger'), job2.get('trigger'))
        self.assertEqual(job.get('minutes'), job2.get('minutes'))

    def test_get_job_not_found(self):
        response = self.client.get(self.scheduler.api_prefix + '/jobs/job1')
        self.assertEqual(response.status_code, 404)

    def test_get_all_jobs(self):
        job = self.__add_job()
        response = self.client.get(self.scheduler.api_prefix + '/jobs')
        self.assertEqual(response.status_code, 200)
        jobs = json.loads(response.get_data(as_text=True))
        self.assertEqual(len(jobs), 1)
        job2 = jobs[0]
        self.assertEqual(job.get('id'), job2.get('id'))
        self.assertEqual(job.get('func'), job2.get('func'))
        self.assertEqual(job.get('trigger'), job2.get('trigger'))
        self.assertEqual(job.get('minutes'), job2.get('minutes'))

    def test_update_job(self):
        job = self.__add_job()
        year = self._future_year()
        data_to_update = {
            'args': [1],
            'trigger': 'cron',
            'minute': '*/1',
            'start_date': '{}-01-01'.format(year),  # means midnight local time
        }
        response = self.client.patch(self.scheduler.api_prefix + '/jobs/job1', data=json.dumps(data_to_update))
        self.assertEqual(response.status_code, 200)
        job2 = json.loads(response.get_data(as_text=True))
        self.assertEqual(job.get('id'), job2.get('id'))
        self.assertEqual(job.get('func'), job2.get('func'))
        self.assertEqual(data_to_update.get('args'), job2.get('args'))
        self.assertEqual(data_to_update.get('trigger'), job2.get('trigger'))
        # The start date is in the future, so the first cron fire time is
        # exactly that midnight rather than the next minute after "now".
        expected_midnight = datetime(year, 1, 1, tzinfo=tzlocal()).isoformat()
        self.assertEqual(expected_midnight, job2.get('start_date'))
        self.assertEqual(expected_midnight, job2.get('next_run_time'))

    def test_update_job_not_found(self):
        data_to_update = {
            'args': [1],
            'trigger': 'cron',
            'minute': '*/1',
            'start_date': '{}-01-01'.format(self._future_year()),
        }
        response = self.client.patch(self.scheduler.api_prefix + '/jobs/job1', data=json.dumps(data_to_update))
        self.assertEqual(response.status_code, 404)

    def test_update_invalid_job(self):
        self.__add_job()
        data_to_update = {
            'trigger': 'invalid_trigger',
        }
        response = self.client.patch(self.scheduler.api_prefix + '/jobs/job1', data=json.dumps(data_to_update))
        self.assertEqual(response.status_code, 500)

    def test_pause_and_resume_job(self):
        self.__add_job()
        response = self.client.post(self.scheduler.api_prefix + '/jobs/job1/pause')
        self.assertEqual(response.status_code, 200)
        job = json.loads(response.get_data(as_text=True))
        self.assertIsNone(job.get('next_run_time'))
        response = self.client.post(self.scheduler.api_prefix + '/jobs/job1/resume')
        self.assertEqual(response.status_code, 200)
        job = json.loads(response.get_data(as_text=True))
        self.assertIsNotNone(job.get('next_run_time'))

    def test_pause_and_resume_job_not_found(self):
        response = self.client.post(self.scheduler.api_prefix + '/jobs/job1/pause')
        self.assertEqual(response.status_code, 404)
        response = self.client.post(self.scheduler.api_prefix + '/jobs/job1/resume')
        self.assertEqual(response.status_code, 404)

    def test_run_job(self):
        self.__add_job()
        response = self.client.post(self.scheduler.api_prefix + '/jobs/job1/run')
        self.assertEqual(response.status_code, 200)
        job = json.loads(response.get_data(as_text=True))
        self.assertIsNotNone(job.get('next_run_time'))

    def test_run_job_not_found(self):
        response = self.client.post(self.scheduler.api_prefix + '/jobs/job1/run')
        self.assertEqual(response.status_code, 404)

    def __add_job(self):
        """POST a 10-minute interval job with id ``job1``; return the decoded JSON reply."""
        job = {
            'id': 'job1',
            'func': 'tests.test_api:job1',
            'trigger': 'interval',
            'minutes': 10,
        }
        response = self.client.post(self.scheduler.api_prefix + '/jobs', data=json.dumps(job))
        return json.loads(response.get_data(as_text=True))
from flask import Flask from main import rest_methods import json import unittest app = Flask(__name__) app.register_blueprint(rest_methods) test_client = app.test_client() class TestRESTMethods(unittest.TestCase): # testing a get request def test_get_method(self): with test_client.get("/one/") as response: self.assertEqual( json.loads(response.get_data(as_text=True)), { "request_method": "GET", "request_variables": { "id": "one" }, "request_args": {}, "request_data": {} }) # testing a post request def test_post_method(self):
class TestUserClaimsVerification(unittest.TestCase):
    """Check that the JWT claims-verification loader gates ``/protected``."""

    def setUp(self):
        self.app = Flask(__name__)
        self.app.secret_key = 'super=secret'
        self.jwt_manager = JWTManager(self.app)
        self.client = self.app.test_client()

        @self.jwt_manager.claims_verification_loader
        def claims_verification(user_claims):
            # Verification succeeds only if every required key is present.
            return all(key in user_claims for key in ['foo', 'bar'])

        @self.app.route('/auth/login', methods=['POST'])
        def login():
            payload = {'access_token': create_access_token('test')}
            return jsonify(payload), 200

        @self.app.route('/protected')
        @jwt_required
        def protected():
            return jsonify({'msg': "hello world"})

    def _jwt_get(self, url, jwt, header_name='Authorization', header_type='Bearer'):
        """GET ``url`` with the token in ``header_name``; return ``(status_code, json_body)``.

        The header value is ``header_type`` and ``jwt`` joined by a space,
        with surrounding whitespace stripped (so an empty ``header_type``
        sends the bare token).
        """
        header_value = '{} {}'.format(header_type, jwt).strip()
        response = self.client.get(url, headers={header_name: header_value})
        return response.status_code, json.loads(response.get_data(as_text=True))

    def _login_token(self):
        """POST to ``/auth/login`` and return the ``access_token`` from its JSON body."""
        response = self.client.post('/auth/login')
        return json.loads(response.get_data(as_text=True))['access_token']

    def test_valid_user_claims(self):
        @self.jwt_manager.user_claims_loader
        def user_claims_callback(identity):
            return {'foo': 'baz', 'bar': 'boom'}

        token = self._login_token()
        status, data = self._jwt_get('/protected', token)
        self.assertEqual(data, {'msg': 'hello world'})
        self.assertEqual(status, 200)

    def test_empty_claims_verification_error(self):
        token = self._login_token()
        status, data = self._jwt_get('/protected', token)
        self.assertEqual(data, {'msg': 'User claims verification failed'})
        self.assertEqual(status, 400)

    def test_bad_claims_verification_error(self):
        @self.jwt_manager.user_claims_loader
        def user_claims_callback(identity):
            return {'super': 'banana'}

        token = self._login_token()
        status, data = self._jwt_get('/protected', token)
        self.assertEqual(data, {'msg': 'User claims verification failed'})
        self.assertEqual(status, 400)

    def test_bad_claims_custom_error_callback(self):
        @self.jwt_manager.claims_verification_failed_loader
        def user_claims_callback():
            return jsonify({'foo': 'bar'}), 404

        token = self._login_token()
        status, data = self._jwt_get('/protected', token)
        self.assertEqual(data, {'foo': 'bar'})
        self.assertEqual(status, 404)
class TestIwebModule(unittest.TestCase):
    """Exercise the fake Ubersmith ``iweb.*`` methods over the HTTP API."""

    def setUp(self):
        self.data_store = DataStore()
        self.iweb = IWeb(self.data_store)
        self.app = Flask(__name__)
        self.base_iweb_api = UbersmithBase(self.data_store)
        # Hook the iweb module into the base API, then the base API into Flask.
        self.iweb.hook_to(self.base_iweb_api)
        self.base_iweb_api.hook_to(self.app)

    def _post(self, client, form):
        """POST ``form`` to ``api/2.0/`` with ``client`` and return the response."""
        return client.post('api/2.0/', data=form)

    @staticmethod
    def _body(resp):
        """Return the UTF-8 JSON body of ``resp`` decoded into Python objects."""
        return json.loads(resp.data.decode('utf-8'))

    def test_log_event_successfully(self):
        event = {
            "event_type": "Some event type",
            "reference_type": "client",
            "action": "Client Id 12345 performed a 'method'.",
            "clientid": "clientid",
            "user": "******",
            "reference_id": "clientid"
        }
        form = {"method": "iweb.log_event"}
        form.update(event)

        with self.app.test_client() as c:
            resp = self._post(c, form)

            self.assertEqual(resp.status_code, 200)
            self.assertEqual(self._body(resp), {
                "data": "1",
                "error_code": None,
                "error_message": "",
                "status": True
            })
            # The stored event is the submitted form minus the "method" key.
            self.assertEqual(self.data_store.event_log[0], event)

    def test_add_role_successfully(self):
        with self.app.test_client() as c:
            resp = self._post(c, {
                "method": "iweb.acl_admin_role_add",
                'name': 'A Admin Role',
                'descr': 'A Admin Role',
                'acls[admin.portal][read]': 1,
            })
            # The new role's id is generated server-side; read it back.
            new_role_id = next(iter(self.data_store.roles))

            self.assertEqual(resp.status_code, 200)
            self.assertEqual(self._body(resp), {
                "data": new_role_id,
                "error_code": None,
                "error_message": "",
                "status": True
            })
            expected_role = {
                'role_id': new_role_id,
                'name': 'A Admin Role',
                'descr': 'A Admin Role',
                'acls': {'admin.portal': {'read': '1'}},
            }
            self.assertEqual(self.data_store.roles, {new_role_id: expected_role})

    def test_add_role_that_already_exists_fails(self):
        self.data_store.roles = {
            'some_role_id': {'role_id': 'some_role_id', 'name': 'A Admin Role'}
        }
        with self.app.test_client() as c:
            resp = self._post(c, {
                "method": "iweb.acl_admin_role_add",
                'name': 'A Admin Role',
                'descr': 'A Admin Role',
                'acls[admin.portal][read]': 1,
            })

            self.assertEqual(resp.status_code, 200)
            self.assertEqual(self._body(resp), {
                "data": "",
                "error_code": 1,
                "error_message": "The specified Role Name is already in use",
                "status": False
            })

    def test_add_user_role_successfully(self):
        user_id = 'some_user_id'
        self.data_store.roles = {'a_role_id': {}, 'another_role_id': {}}
        with self.app.test_client() as c:
            # Assign both roles in turn; only the last response is checked.
            for assigned_role in ('a_role_id', 'another_role_id'):
                resp = self._post(c, {
                    "method": "iweb.user_role_assign",
                    "user_id": user_id,
                    "role_id": assigned_role
                })

            self.assertEqual(resp.status_code, 200)
            self.assertEqual(self._body(resp), {
                "status": True,
                "error_code": None,
                "error_message": "",
                "data": 1
            })
            self.assertEqual(self.data_store.user_mapping[user_id],
                             {'roles': {'a_role_id', 'another_role_id'}})

    def test_add_same_role_to_user_not_allowed(self):
        role_id = 'some_role_id'
        user_id = 'some_user_id'
        # NOTE(review): ``roles`` holds only '1', not ``role_id``, so the
        # request may be rejected because the role is unknown rather than
        # because the user already has it -- confirm which path is exercised.
        self.data_store.roles = {'1': {}}
        self.data_store.user_mapping[user_id] = {'roles': {role_id}}
        with self.app.test_client() as c:
            resp = self._post(c, {
                "method": "iweb.user_role_assign",
                "user_id": user_id,
                "role_id": role_id
            })

            self.assertEqual(resp.status_code, 200)
            self.assertEqual(self._body(resp), {
                "error_code": 1,
                "error_message": "Can't assign role with id '{}' to user "
                                 "with id '{}'".format(role_id, user_id),
                "status": False,
                "data": ""
            })
def setUp(self):
    """Create a testing Flask app serving ``InternalBlueprint``; store its test client as ``self.app``."""
    flask_app = Flask(__name__)
    flask_app.register_blueprint(InternalBlueprint)
    flask_app.testing = True
    # Note: ``self.app`` is the test client, not the Flask application.
    self.app = flask_app.test_client()
class BootstrapTestCase(unittest.TestCase):
    """Tests for the Bootstrap extension's resource loaders and template macros."""

    def setUp(self):
        self.app = Flask(__name__)
        self.app.testing = True
        self.app.config['SQLALCHEMY_DATABASE_URI'] = 'sqlite:///'
        self.app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False
        self.app.secret_key = 'for test'
        self.bootstrap = Bootstrap(self.app)  # noqa

        @self.app.route('/')
        def index():
            return render_template_string(
                '{{ bootstrap.load_css() }}{{ bootstrap.load_js() }}')

        # Keep a request context pushed for the whole test so tests can use
        # current_app and call load_css()/load_js() directly; popped in tearDown.
        self.context = self.app.test_request_context()
        self.context.push()
        self.client = self.app.test_client()

    def tearDown(self):
        self.context.pop()

    def _register_message_pagination(self, rule, template):
        """Route ``rule`` to a view rendering ``template`` over a paginated table.

        Binds a fresh ``SQLAlchemy`` instance and ``Message`` model to the test
        app. On every request the view drops and recreates the tables, inserts
        100 messages, and paginates them 10 per page. The page number comes
        from the ``page`` query argument (default 1). ``pagination`` and
        ``messages`` are passed to the template.
        """
        db = SQLAlchemy(self.app)

        class Message(db.Model):
            id = db.Column(db.Integer, primary_key=True)

        @self.app.route(rule)
        def test():
            db.drop_all()
            db.create_all()
            db.session.add_all([Message() for _ in range(100)])
            db.session.commit()
            page = request.args.get('page', 1, type=int)
            # Pass ``page`` by keyword. Flask-SQLAlchemy 3.x makes paginate()'s
            # arguments keyword-only, so a positional page raises TypeError;
            # 2.x accepts the keyword form too.
            pagination = Message.query.paginate(page=page, per_page=10)
            messages = pagination.items
            return render_template_string(
                template, pagination=pagination, messages=messages)

    def test_extension_init(self):
        self.assertIn('bootstrap', current_app.extensions)

    def test_load_css(self):
        rv = self.bootstrap.load_css()
        self.assertIn('bootstrap.min.css', rv)

    def test_load_js(self):
        rv = self.bootstrap.load_js()
        self.assertIn('bootstrap.min.js', rv)

    def test_local_resources(self):
        current_app.config['BOOTSTRAP_SERVE_LOCAL'] = True

        response = self.client.get('/')
        data = response.get_data(as_text=True)
        self.assertNotIn('https://cdn.jsdelivr.net/npm/bootstrap', data)
        self.assertIn('bootstrap.min.js', data)
        self.assertIn('bootstrap.min.css', data)
        self.assertIn('jquery.min.js', data)

        css_response = self.client.get(
            '/bootstrap/static/css/bootstrap.min.css')
        js_response = self.client.get('/bootstrap/static/js/bootstrap.min.js')
        jquery_response = self.client.get('/bootstrap/static/jquery.min.js')
        self.assertNotEqual(css_response.status_code, 404)
        self.assertNotEqual(js_response.status_code, 404)
        self.assertNotEqual(jquery_response.status_code, 404)

        css_rv = self.bootstrap.load_css()
        js_rv = self.bootstrap.load_js()
        self.assertIn('/bootstrap/static/css/bootstrap.min.css', css_rv)
        self.assertIn('/bootstrap/static/js/bootstrap.min.js', js_rv)
        self.assertNotIn('https://cdn.jsdelivr.net/npm/bootstrap', css_rv)
        self.assertNotIn('https://cdn.jsdelivr.net/npm/bootstrap', js_rv)

    def test_cdn_resources(self):
        response = self.client.get('/')
        data = response.get_data(as_text=True)
        self.assertIn('https://cdn.jsdelivr.net/npm/bootstrap', data)
        self.assertIn('bootstrap.min.js', data)
        self.assertIn('bootstrap.min.css', data)

        css_rv = self.bootstrap.load_css()
        js_rv = self.bootstrap.load_js()
        self.assertNotIn('/bootstrap/static/css/bootstrap.min.css', css_rv)
        self.assertNotIn('/bootstrap/static/js/bootstrap.min.js', js_rv)
        self.assertIn('https://cdn.jsdelivr.net/npm/bootstrap', css_rv)
        self.assertIn('https://cdn.jsdelivr.net/npm/bootstrap', js_rv)

    def test_render_field(self):
        @self.app.route('/field')
        def test():
            form = HelloForm()
            return render_template_string('''
            {% from 'bootstrap/form.html' import render_field %}
            {{ render_field(form.username) }}
            {{ render_field(form.password) }}
            ''', form=form)

        response = self.client.get('/field')
        data = response.get_data(as_text=True)
        self.assertIn(
            '<input class="form-control" id="username" name="username"', data)
        self.assertIn(
            '<input class="form-control" id="password" name="password"', data)

    def test_render_form(self):
        @self.app.route('/form')
        def test():
            form = HelloForm()
            return render_template_string('''
            {% from 'bootstrap/form.html' import render_form %}
            {{ render_form(form) }}
            ''', form=form)

        response = self.client.get('/form')
        data = response.get_data(as_text=True)
        self.assertIn(
            '<input class="form-control" id="username" name="username"', data)
        self.assertIn(
            '<input class="form-control" id="password" name="password"', data)

    def test_render_pager(self):
        self._register_message_pagination('/pager', '''
            {% from 'bootstrap/pagination.html' import render_pager %}
            {{ render_pager(pagination) }}
            ''')

        response = self.client.get('/pager')
        data = response.get_data(as_text=True)
        self.assertIn('<nav aria-label="Page navigation">', data)
        self.assertIn('Previous', data)
        self.assertIn('Next', data)
        # On page 1 the "Previous" link is disabled.
        self.assertIn('<li class="page-item disabled">', data)

        response = self.client.get('/pager?page=2')
        data = response.get_data(as_text=True)
        self.assertIn('<nav aria-label="Page navigation">', data)
        self.assertIn('Previous', data)
        self.assertIn('Next', data)
        self.assertNotIn('<li class="page-item disabled">', data)

    def test_render_pagination(self):
        self._register_message_pagination('/pagination', '''
            {% from 'bootstrap/pagination.html' import render_pagination %}
            {{ render_pagination(pagination) }}
            ''')

        response = self.client.get('/pagination')
        data = response.get_data(as_text=True)
        self.assertIn('<nav aria-label="Page navigation">', data)
        self.assertIn(
            '<a class="page-link" href="#">1 <span class="sr-only">(current)</span></a>',
            data)
        self.assertIn('10</a>', data)

        response = self.client.get('/pagination?page=2')
        data = response.get_data(as_text=True)
        self.assertIn('<nav aria-label="Page navigation">', data)
        self.assertIn('1</a>', data)
        self.assertIn(
            '<a class="page-link" href="#">2 <span class="sr-only">(current)</span></a>',
            data)
        self.assertIn('10</a>', data)

    def test_render_nav_item(self):
        # The view must be named ``test``: render_nav_item('test', ...) links
        # to that endpoint and marks it active on the current request.
        @self.app.route('/nav_item')
        def test():
            return render_template_string('''
            {% from 'bootstrap/nav.html' import render_nav_item %}
            {{ render_nav_item('test', 'Home') }}
            ''')

        response = self.client.get('/nav_item')
        data = response.get_data(as_text=True)
        self.assertIn('<a class="nav-item nav-link active"', data)

    def test_render_breadcrumb_item(self):
        # As above, the endpoint name ``test`` is referenced by the macro call.
        @self.app.route('/breadcrumb_item')
        def test():
            return render_template_string('''
            {% from 'bootstrap/nav.html' import render_breadcrumb_item %}
            {{ render_breadcrumb_item('test', 'Home') }}
            ''')

        response = self.client.get('/breadcrumb_item')
        data = response.get_data(as_text=True)
        self.assertIn(
            '<li class="breadcrumb-item active" aria-current="page">', data)

    def test_render_static(self):
        @self.app.route('/test_static')
        def test():
            return render_template_string('''
            {% from 'bootstrap/utils.html' import render_static %}
            {{ render_static('css', 'test.css') }}
            {{ render_static('js', 'test.js') }}
            {{ render_static('icon', 'test.ico') }}
            ''')

        response = self.client.get('/test_static')
        data = response.get_data(as_text=True)
        self.assertIn(
            '<link rel="stylesheet" href="/static/test.css" type="text/css">',
            data)
        self.assertIn(
            '<script type="text/javascript" src="/static/test.js"></script>',
            data)
        self.assertIn('<link rel="icon" href="/static/test.ico">', data)
class WebServer:
    """Flask web front-end: accounts, dashboard, uploads, simulations and search.

    ``database`` is the persistence object that every route reads and writes
    through ``self.db``. ``set_routes`` registers all routes on ``self.app``.
    """

    def __init__(self, database):
        self.db = database
        self.app = Flask(__name__)
        UPLOAD_FOLDER = "\\uploads\\sounds"
        self.app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
        self.app.secret_key = SECRET_KEY
        self.set_routes()

    def get_app(self):
        """Return a test client for this server's Flask app."""
        return self.app.test_client()

    def start(self):
        """Serve the app on all interfaces with debug mode off (blocks)."""
        self.app.run(debug=False, host='0.0.0.0')

    def shutdown_server(self):
        """Stop the Werkzeug development server; must be called inside a request.

        Raises:
            RuntimeError: the app is not being served by Werkzeug.
        """
        func = request.environ.get('werkzeug.server.shutdown')
        if func is None:
            raise RuntimeError('Not running with the Werkzeug Server')
        func()

    # ------------------------------------------------------------------
    # Stateless helpers
    # ------------------------------------------------------------------
    @staticmethod
    def allowed_file(filename):
        """Return True if ``filename`` has an extension in ALLOWED_EXTENSIONS."""
        return WebServer.get_file_ext(filename) in ALLOWED_EXTENSIONS

    @staticmethod
    def get_file_ext(filename):
        """Return the lower-cased extension of ``filename``, or False if it has no dot."""
        return '.' in filename and filename.rsplit('.', 1)[1].lower()

    @staticmethod
    def _hash_password(password, salt):
        """Return the hex SHA-512 digest of ``password + salt`` (the stored hash format).

        NOTE(review): a fast salted hash is weak for passwords. Moving to a KDF
        (hashlib.scrypt / pbkdf2_hmac) needs a rehash-on-login migration for
        existing accounts.
        """
        return hashlib.sha512(str(password + salt).encode('utf-8')).hexdigest()

    @staticmethod
    def _local_redirect_path(ref):
        """Return ``'/' + ref`` with leading slashes and backslashes removed from ``ref``.

        ``ref`` comes from the ``?ref=`` query parameter. Without stripping, a value
        such as ``//evil.example`` would become a protocol-relative URL and
        redirect the user off-site after logging in.
        """
        return '/' + ref.lstrip('/\\')

    @staticmethod
    def _doc_file_path(page, doc_dir='server/doc_files'):
        """Return the path of ``<doc_dir>/<page>.md``, or None if it escapes ``doc_dir``.

        ``page`` comes straight from the query string. ``..`` segments and absolute
        paths must not reach markdown files outside the documentation folder.
        """
        base = os.path.abspath(doc_dir)
        candidate = os.path.abspath(os.path.join(base, '{0}.md'.format(page)))
        try:
            if os.path.commonpath([base, candidate]) != base:
                return None
        except ValueError:  # e.g. different drives on Windows
            return None
        return candidate

    @staticmethod
    def _read_file(path):
        """Return the full text of ``path``, closing the file afterwards."""
        with open(path) as f:
            return f.read()

    @staticmethod
    def _write_json(path, data):
        """Write ``data`` to ``path`` as indented JSON, keeping non-ASCII characters."""
        with open(path, 'w') as f:
            json.dump(data, f, sort_keys=False, indent=4, ensure_ascii=False)

    @staticmethod
    def _remove_files(*paths):
        """Delete each truthy path on a best-effort basis.

        Returns:
            False if any removal raised OSError, otherwise True. Empty or None
            paths (e.g. a result archive that was never produced) are skipped.
        """
        all_removed = True
        for path in paths:
            if not path:
                continue
            try:
                os.remove(path)
            except OSError:
                all_removed = False
        return all_removed

    # ------------------------------------------------------------------
    # Request-scoped helpers (must be called while handling a request)
    # ------------------------------------------------------------------
    @staticmethod
    def _owned_by_current_user(item):
        """Return True if ``item`` exists and belongs to the logged-in user."""
        return (item is not None and 'userID' in session
                and item.userID == session['userID'])

    def _visible_to_current_user(self, item, item_type):
        """Return True if ``item`` exists and is either owned by the user or
        published as a public item of ``item_type``."""
        return item is not None and (
            WebServer._owned_by_current_user(item)
            or self.db.item_is_public(item.id, item_type))

    def _current_user(self):
        """Return the logged-in user's record, or None when nobody is logged in."""
        if 'userID' not in session:
            return None
        return self.db.get_user(id=session['userID'])

    def _add_counts(self, items):
        """Map each public item's id to the number of users who added it."""
        return {
            item.id: len(self.db.get_all(
                'SELECT * FROM user_added_items WHERE itemID = ?', [item.id],
                type=UserAddedItem))
            for item in items
        }

    def _find_public_items(self, query, search_for):
        """Return public items whose type is LIKE ``search_for``.

        When ``query`` is not None, also require it as a substring of the item's
        name or description. Values are bound as SQL parameters.
        """
        if query is None:
            return self.db.get_all(
                'SELECT * FROM public_items WHERE type LIKE ?', [search_for],
                type=PublicItem)
        pattern = '%' + query + '%'
        return self.db.get_all(
            "SELECT * FROM public_items WHERE type LIKE ? AND (name LIKE ? OR description LIKE ?)",
            [search_for, pattern, pattern], type=PublicItem)

    def set_routes(self):
        """Register every HTTP route on ``self.app``.

        Handlers are closures over ``self`` so they can reach ``self.db``.
        """
        from urllib.parse import urlencode  # used by /publish to build its redirect

        @self.app.route("/")
        def index():
            return render_template('index.html', user=self._current_user())

        @self.app.route('/signup', methods=['POST', 'GET'])
        def signup():
            if request.method != 'POST':
                return render_template('signup.html', message="")
            email = request.form['email']
            if self.db.is_email_used(email):
                return render_template('signup.html', message="used-email")
            password = request.form['password']
            if len(password) < 8:
                return render_template('signup.html', message="short-pass")
            if password != request.form['confpass']:
                return render_template('signup.html', message="no-match-pass")
            salt = uuid.uuid4().hex
            user = self.db.insert_user(
                User(email, WebServer._hash_password(password, salt), salt))
            session['userID'] = user.id
            return redirect('/')

        @self.app.route('/login', methods=['POST', 'GET'])
        def login():
            if request.method != 'POST':
                # Remember where to send the user once they have logged in
                if 'ref' in request.args:
                    session['ref_url'] = request.args.get('ref')
                return render_template('login.html', message="")
            email = request.form['email']
            if not self.db.is_email_used(email):
                return render_template('login.html', message="no-match-email")
            user = self.db.get_user(email=email)
            if WebServer._hash_password(request.form['password'], user.salt) != user.hash:
                return render_template('login.html', message="wrong-pass")
            session['userID'] = user.id
            ref = session.pop('ref_url', None)
            if ref is None:
                return redirect('/')
            return redirect(WebServer._local_redirect_path(ref))

        @self.app.route('/logout', methods=['GET'])
        def logout():
            session.pop('userID', None)
            return redirect('/')

        @self.app.route('/dashboard', methods=['GET'])
        def dashboard():
            if 'userID' not in session:
                return redirect('/login?ref=dashboard')
            user_id = session['userID']
            user_added_items = self.db.get_all(
                'SELECT * FROM user_added_items WHERE userID = ?', [user_id],
                type=UserAddedItem)
            added_items = [
                self.db.get_one('SELECT * FROM public_items WHERE id = ?',
                                [x.itemID], type=PublicItem)
                for x in user_added_items
            ]
            # A public item may have been removed after the user added it
            added_items = [item for item in added_items if item is not None]
            return render_template('dashboard.html',
                                   user=self.db.get_user(id=user_id),
                                   simulations=self.db.get_user_sims(user_id),
                                   sounds=self.db.get_user_sounds(user_id),
                                   robots=self.db.get_user_robots(user_id),
                                   microphones=self.db.get_user_mics(user_id),
                                   addedItems=added_items,
                                   addCount=self._add_counts(added_items))

        @self.app.route('/simulator', methods=['GET'])
        def simulator():
            if 'userID' not in session:
                return redirect('/login?ref=simulator')
            user_id = session['userID']
            simconf = ""
            if 'sim' in request.args:
                sim = self.db.get_simulation(request.args['sim'])
                if WebServer._owned_by_current_user(sim):
                    simconf = WebServer._read_file(sim.pathToConfig)
            return render_template('simulator.html',
                                   user=self.db.get_user(id=user_id),
                                   sounds=self.db.get_user_sounds(user_id),
                                   robots=self.db.get_user_robots(user_id),
                                   simconf=simconf)

        def delete_owned(form_key, getter, deleter, path_attrs, failure_message):
            """Delete the caller's item named by ``request.form[form_key]`` and its files.

            Items that do not exist or belong to another user are refused.
            File-removal failures are reported but do not block the database delete.
            """
            if 'userID' not in session:
                return jsonify({'success': 'false'})
            item_id = request.form[form_key]
            item = getter(item_id)
            if not WebServer._owned_by_current_user(item):
                return jsonify({'success': 'false'})
            paths = [getattr(item, attr) for attr in path_attrs]
            if not WebServer._remove_files(*paths):
                print(failure_message)
            deleter(item_id)
            return jsonify({'success': 'true'})

        @self.app.route("/removesimulation", methods=['POST'])
        def remove_sim():
            return delete_owned('sim_id', self.db.get_simulation,
                                self.db.delete_simulation,
                                ('pathToConfig', 'pathToZip'),
                                "Cannot delete simulation files!")

        @self.app.route("/removesound", methods=['POST'])
        def remove_sound():
            return delete_owned('sound_id', self.db.get_sound,
                                self.db.delete_sound, ('pathToFile',),
                                "Cannot delete sound file!")

        @self.app.route("/removerobot", methods=['POST'])
        def remove_robot():
            return delete_owned('robot_id', self.db.get_robot,
                                self.db.delete_robot, ('pathToConfig',),
                                "Cannot delete robot file!")

        @self.app.route("/removemic", methods=['POST'])
        def remove_mic():
            return delete_owned('mic_id', self.db.get_microphone,
                                self.db.delete_microphone, ('pathToConfig',),
                                "Cannot delete microphone file!")

        @self.app.route('/revoke_simulation', methods=['POST'])
        def revoke_simulation():
            if 'userID' not in session:
                return jsonify({"success": "false"})
            simid = request.form['sim_id']
            sim = self.db.get_simulation(simid)
            # Only the owner may cancel a running simulation
            if not WebServer._owned_by_current_user(sim):
                return jsonify({"success": "false"})
            endTask(sim.taskID)
            self.db.run_query("UPDATE simulations SET state = ? WHERE id = ?",
                              ("cancelled", simid))
            return jsonify({"success": "true"})

        @self.app.route('/simulator/run_simulation', methods=['POST'])
        def run_simulation():
            if 'userID' not in session:
                return jsonify({"success": "false", "reason": "no user found"})
            user_id = session['userID']

            # The config comes from an existing simulation or from the form
            if 'sim_id' in request.form:
                old_sim = self.db.get_simulation(request.form['sim_id'])
                if not self._visible_to_current_user(old_sim, "SIM"):
                    return jsonify({"success": "false",
                                    "reason": "No simulation found."})
                sim_conf = json.loads(WebServer._read_file(old_sim.pathToConfig))
            else:
                sim_conf = json.loads(request.form['config'])

            # Link variables
            sim_conf['simulation_config'] = util.link_vars(
                sim_conf['simulation_config'], sim_conf['variables'])
            seed = str(uuid.uuid4())
            sim_conf['simulation_config']['seed'] = seed

            if 'sim_to_update' in request.form:
                sim = self.db.get_simulation(request.form['sim_to_update'])
                if not WebServer._owned_by_current_user(sim):
                    return jsonify({"success": "false"})
                WebServer._write_json(sim.pathToConfig, sim_conf)
                return jsonify({"success": "true"})

            # Utterance: a fresh upload or an existing (own or public) sound
            if 'utterance' in request.files:
                utt = processFileUpload(request.files['utterance'], user_id)
            else:
                utt = self.db.get_sound(request.form['utterance_id'])
            if not self._visible_to_current_user(utt, "SOUND"):
                return jsonify({
                    "success": "false",
                    "reason": "No sound found."
                })

            # Background noise is optional; a negative id means "none"
            bgn = None
            if 'bgnoise' in request.files:
                bgn = processFileUpload(request.files['bgnoise'], user_id)
            elif int(request.form['bgnoise_id']) >= 0:
                bgn = self.db.get_sound(request.form['bgnoise_id'])
            if bgn is not None and not self._visible_to_current_user(bgn, "SOUND"):
                return jsonify({
                    "success": "false",
                    "reason": "No background noise sound found."
                })

            sounds = {
                'utterance': utt.pathToFile,
                'bgnoise': bgn.pathToFile if bgn is not None else None,
            }

            robot = self.db.get_one("SELECT * FROM robots WHERE id =?",
                                    [request.form['robot_id']], type=Robot)
            if not self._visible_to_current_user(robot, "ROBOT"):
                return jsonify({
                    "success": "false",
                    "reason": "No robot found."
                })
            robot_conf_dict = json.loads(WebServer._read_file(robot.pathToConfig))

            # Save the JSON config only after every input has been validated,
            # so rejected requests leave no orphan config files behind.
            unique_name = uuid.uuid4()
            filename = UPLOAD_DIR + "simulation_configs/{0}.json".format(
                unique_name)
            print("putting sim file in: {0}".format(filename))
            WebServer._write_json(filename, sim_conf)

            date = str(dt.now().date())
            sim = self.db.insert_simulation(
                Simulation(filename, date, seed, user_id))
            print("Running sim with: ", sim_conf, robot_conf_dict, sounds)
            runSimulation.delay(sim_conf, robot_conf_dict, sounds, unique_name,
                                sim.id)
            return jsonify({"success": "true"})

        def processFileUpload(file, user_id):
            """Persist an uploaded file and return the inserted database record.

            A ``.wav`` upload is stored as a Sound. Any other allowed extension is
            stored as a Microphone response text file. Returns None for an empty
            or disallowed upload.
            """
            # If the user selects no file, the browser submits an empty part
            # without a filename.
            if not file or file.filename == '':
                return None
            if not WebServer.allowed_file(file.filename):
                return None
            unique_name = uuid.uuid4()
            if WebServer.get_file_ext(file.filename) == 'wav':
                savename = UPLOAD_DIR + 'sounds/{0}.wav'.format(unique_name)
                file.save(savename)
                return self.db.insert_sound(
                    Sound(file.filename, savename, user_id))
            # Must be a txt/mic response file
            savename = UPLOAD_DIR + 'mic_responses/{0}.txt'.format(unique_name)
            file.save(savename)
            return self.db.insert_microphone(
                Microphone(file.filename, savename, user_id))

        def insertFilePaths(conf, files, mot_id_map, mic_id_map):
            """Fill in ``uid``/``path`` for each mapped motor sound and mic response.

            ``conf`` is the robot config; it is modified in place and returned.
            ``files`` maps form-field ids to records uploaded in this request
            (``{'sounds': {...}, 'responses': {...}}``). Ids not found there are
            looked up as pre-existing database records.
            """
            motors = conf['robot_config']['motors']
            mics = conf['robot_config']['microphones']
            for motor in motors:
                motor_key = str(motor['id'])
                if motor_key not in mot_id_map:
                    continue
                snd_id = str(mot_id_map[motor_key])
                if snd_id in files['sounds']:
                    sound = files['sounds'][snd_id]
                else:  # must be a pre-existing sound
                    sound = self.db.get_sound(snd_id)
                motor['sound']['uid'] = sound.id
                motor['sound']['path'] = sound.pathToFile
            for mic in mics:
                mic_key = str(mic['id'])
                if mic_key not in mic_id_map:
                    continue
                res_id = str(mic_id_map[mic_key])
                if res_id in files['responses']:
                    response = files['responses'][res_id]
                else:  # must be a pre-existing response
                    response = self.db.get_microphone(res_id)
                mic['mic_style']['uid'] = response.id
                mic['mic_style']['path'] = response.pathToFile
            return conf

        @self.app.route('/designer/save', methods=['POST'])
        def save_robot_config():
            if 'userID' not in session:
                return jsonify({"success": "false"})
            # Process uploaded sounds / mic responses, keyed by form field name
            files = {'sounds': {}, 'responses': {}}
            for field_name in request.files:
                file_obj = processFileUpload(request.files[field_name],
                                             session['userID'])
                if file_obj is None:
                    return jsonify({
                        "success": "false",
                        "reason": "invalid file"
                    })
                bucket = 'sounds' if isinstance(file_obj, Sound) else 'responses'
                files[bucket][str(field_name)] = file_obj

            conf = json.loads(request.form['robot-config'])  # robot config
            # motor id -> motor sound file/id, mic id -> mic response file/id
            mot_id_map = json.loads(request.form['mot_id_map'])
            mic_id_map = json.loads(request.form['mic_id_map'])
            conf = insertFilePaths(conf, files, mot_id_map, mic_id_map)

            if 'robot_to_update' in request.form:
                robot = self.db.get_robot(request.form['robot_to_update'])
                if not WebServer._owned_by_current_user(robot):
                    return jsonify({"success": "false"})
                WebServer._write_json(robot.pathToConfig, conf)
            else:
                unique_name = uuid.uuid4()
                filename = UPLOAD_DIR + 'robot_configs/{0}.json'.format(
                    unique_name)
                WebServer._write_json(filename, conf)
                self.db.insert_robot(
                    Robot(request.form['robot_name'], filename,
                          session['userID']))
            return jsonify({"success": "true"})

        @self.app.route('/robotdesign', methods=['GET'])
        def robotdesigner():
            if 'userID' not in session:
                return redirect('/login?ref=robotdesign')
            user_id = session['userID']
            robot = None
            robot_conf = ""
            if 'robot' in request.args:
                candidate = self.db.get_robot(request.args['robot'])
                if WebServer._owned_by_current_user(candidate):
                    robot = candidate
                    robot_conf = WebServer._read_file(robot.pathToConfig)
            return render_template('robotdesign.html',
                                   user=self.db.get_user(id=user_id),
                                   sounds=self.db.get_user_sounds(user_id),
                                   mic_responses=self.db.get_user_mics(user_id),
                                   robotconfig=robot_conf,
                                   robot=robot)

        @self.app.route('/upload_config', methods=['POST'])
        def review_config():
            if 'userID' not in session or 'file' not in request.files:
                return jsonify({"success": "false"})
            try:
                config = json.loads(request.files['file'].read())
            except ValueError:  # not valid JSON / not decodable text
                return jsonify({"success": "false"})
            # TODO: also validate the structure of the uploaded config.
            return jsonify({"success": "true", "config": config})

        @self.app.route("/dl/<path>")
        def downloadLogFile(path=None):
            filename = os.path.join(self.app.root_path, 'static', 'dl', path)
            return send_file(filename, as_attachment=True)

        @self.app.route("/uploads/sounds/<name>")
        def serveSound(name=None):
            return send_file("server/uploads/sounds/{0}".format(name),
                             as_attachment=True)

        @self.app.route("/search")
        def search():
            relevant_items = self._find_public_items(
                request.args.get('query'), request.args.get('type', '%'))
            return render_template('search.html',
                                   user=self._current_user(),
                                   items=relevant_items,
                                   addCount=self._add_counts(relevant_items))

        @self.app.route("/quicksearch")
        def quick_search():
            relevant_items = self._find_public_items(
                request.args.get('query'), request.args.get('type', '%'))
            processed_items = [{
                'id': x.id,
                'name': x.name,
                'desc': x.description,
                'likes': x.likes,
                'type': x.type,
                'itemID': x.itemID
            } for x in relevant_items]
            return jsonify({'result': processed_items})

        @self.app.route("/publish", methods=['POST'])
        def publish():
            getters = {
                "ROBOT": self.db.get_robot,
                "SIM": self.db.get_simulation,
                "SOUND": self.db.get_sound,
                "MIC": self.db.get_microphone,
            }
            item_type = request.form['type']
            getter = getters.get(item_type)
            item = getter(request.form['id']) if getter is not None else None
            # Only the owner may publish; anonymous users also fail this check
            if not WebServer._owned_by_current_user(item):
                return redirect('/dashboard')
            date = str(dt.now().date())
            public_item = self.db.insert_public_item(
                PublicItem(request.form['name'], request.form['desc'],
                           item_type, request.form['id'], session['userID'],
                           publishDate=date))
            # Encode so names containing '&', '#', spaces, ... survive the redirect
            return redirect('/search?' + urlencode({
                'query': public_item.name,
                'type': public_item.type
            }))

        @self.app.route("/toggle_like", methods=["POST"])
        def toggleLike():
            if 'userID' not in session or 'item' not in request.form:
                return jsonify({"success": "false"})
            item_id = request.form['item']
            user_id = session['userID']
            existing_like = self.db.get_one(
                "SELECT * FROM user_liked_items WHERE itemID = ? AND userID = ?",
                [item_id, user_id], type=UserLikedItem)
            if existing_like is None:
                self.db.insert_user_liked_item(UserLikedItem(user_id, item_id))
            else:
                self.db.run_query('DELETE FROM user_liked_items WHERE id = ?',
                                  [existing_like.id])
            like_count = len(self.db.get_all(
                'SELECT * FROM user_liked_items WHERE itemID = ?', [item_id],
                type=UserLikedItem))
            # Keep the denormalised counter on the public item in sync
            self.db.run_query('UPDATE public_items SET likes = ? WHERE id = ?',
                              [like_count, item_id])
            return jsonify({"success": "true", "like_count": like_count})

        @self.app.route("/toggle_add", methods=["POST"])
        def toggle_add():
            if 'userID' not in session or 'item' not in request.form:
                return jsonify({"success": "false"})
            item_id = request.form['item']
            user_id = session['userID']
            existing_add = self.db.get_one(
                "SELECT * FROM user_added_items WHERE itemID = ? AND userID = ?",
                [item_id, user_id], type=UserAddedItem)
            if existing_add is None:
                self.db.insert_user_added_item(UserAddedItem(user_id, item_id))
            else:
                self.db.run_query('DELETE FROM user_added_items WHERE id = ?',
                                  [existing_add.id])
            add_count = len(self.db.get_all(
                'SELECT * FROM user_added_items WHERE itemID = ?', [item_id],
                type=UserAddedItem))
            return jsonify({"success": "true", "add_count": add_count})

        @self.app.route('/documentation')
        def documentation():
            doc_path = WebServer._doc_file_path(
                request.args.get('p', 'introduction'))
            if doc_path is None:
                return jsonify({'success': 'false'})
            try:
                doc = WebServer._read_file(doc_path)
            except (OSError, ValueError):  # missing file or undecodable text
                return jsonify({'success': 'false'})
            doc_html = Markup(
                markdown.markdown(doc, extensions=['fenced_code']))
            return jsonify({'success': 'true', 'html': doc_html})

        @self.app.route('/getrobotconfig')
        def getrobotconfig():
            if 'userID' not in session:
                return jsonify({
                    "success": "false",
                    "reason": "No user session"
                })
            robot_id = request.args.get('robot')
            if robot_id is None:
                return jsonify({
                    "success": "false",
                    "reason": "robotID not sent"
                })
            robot = self.db.get_one("SELECT * FROM robots WHERE id =?",
                                    [robot_id], type=Robot)
            if not self._visible_to_current_user(robot, "ROBOT"):
                return jsonify({
                    "success": "false",
                    "reason": "robot not found"
                })
            return jsonify({
                "success": "true",
                "robot": json.loads(WebServer._read_file(robot.pathToConfig))
            })
def getClient():
    """Create a fresh Flask app, register the project routes on it and return its test client."""
    fresh_app = Flask(__name__)
    define_routes(fresh_app)
    return fresh_app.test_client()
def client(app: Flask):
    """Yield a test client for ``app``, kept inside its context-manager scope."""
    test_client = app.test_client()
    with test_client as active_client:
        yield active_client
class TestEndpoints(unittest.TestCase):
    """End-to-end tests of the flask_jwt_extended decorators and token flows."""

    def setUp(self):
        self.app = Flask(__name__)
        self.app.secret_key = 'super=secret'
        self.app.config['JWT_ALGORITHM'] = 'HS256'
        # Very short lifetimes so expiry can be tested with a brief sleep
        self.app.config['JWT_ACCESS_TOKEN_EXPIRES'] = timedelta(seconds=1)
        self.app.config['JWT_REFRESH_TOKEN_EXPIRES'] = timedelta(seconds=1)
        self.jwt_manager = JWTManager(self.app)
        self.client = self.app.test_client()

        @self.app.route('/auth/login', methods=['POST'])
        def login():
            ret = {
                'access_token': create_access_token('test', fresh=True),
                'refresh_token': create_refresh_token('test')
            }
            return jsonify(ret), 200

        @self.app.route('/auth/refresh', methods=['POST'])
        @jwt_refresh_token_required
        def refresh():
            username = get_jwt_identity()
            ret = {'access_token': create_access_token(username, fresh=False)}
            return jsonify(ret), 200

        @self.app.route('/auth/fresh-login', methods=['POST'])
        def fresh_login():
            ret = {'access_token': create_access_token('test', fresh=True)}
            return jsonify(ret), 200

        @self.app.route('/protected')
        @jwt_required
        def protected():
            return jsonify({'msg': "hello world"})

        @self.app.route('/fresh-protected')
        @fresh_jwt_required
        def fresh_protected():
            return jsonify({'msg': "fresh hello world"})

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _json(response):
        """Decode a test-client response body as JSON."""
        return json.loads(response.get_data(as_text=True))

    def _login(self):
        """POST /auth/login and return the decoded token payload."""
        return self._json(self.client.post('/auth/login'))

    def _raw_get(self, url, headers=None):
        """GET ``url`` with explicit headers; return (status_code, json_body)."""
        response = self.client.get(url, headers=headers)
        return response.status_code, self._json(response)

    def _jwt_post(self, url, jwt):
        """POST to ``url`` with a Bearer token; return (status_code, json_body)."""
        response = self.client.post(
            url, content_type='application/json',
            headers={'Authorization': 'Bearer {}'.format(jwt)})
        return response.status_code, self._json(response)

    def _jwt_get(self, url, jwt, header_name='Authorization', header_type='Bearer'):
        """GET ``url`` sending ``"<header_type> <jwt>"`` in ``header_name``.

        Returns (status_code, json_body). An empty ``header_type`` sends the bare token.
        """
        header_value = '{} {}'.format(header_type, jwt).strip()
        return self._raw_get(url, headers={header_name: header_value})

    # ------------------------------------------------------------------
    # Tests
    # ------------------------------------------------------------------
    def test_login(self):
        response = self.client.post('/auth/login')
        data = self._json(response)
        self.assertEqual(response.status_code, 200)
        self.assertIn('access_token', data)
        self.assertIn('refresh_token', data)

    def test_fresh_login(self):
        response = self.client.post('/auth/fresh-login')
        data = self._json(response)
        self.assertEqual(response.status_code, 200)
        self.assertIn('access_token', data)
        self.assertNotIn('refresh_token', data)

    def test_refresh(self):
        refresh_token = self._login()['refresh_token']
        status_code, data = self._jwt_post('/auth/refresh', refresh_token)
        self.assertEqual(status_code, 200)
        self.assertIn('access_token', data)
        self.assertNotIn('refresh_token', data)

    def test_wrong_token_refresh(self):
        access_token = self._login()['access_token']
        # Try to refresh with an access token instead of a refresh one
        status_code, data = self._jwt_post('/auth/refresh', access_token)
        self.assertEqual(status_code, 422)
        self.assertIn('msg', data)

    def test_jwt_required(self):
        tokens = self._login()

        # Works with a fresh token
        status, data = self._jwt_get('/protected', tokens['access_token'])
        self.assertEqual(data, {'msg': 'hello world'})
        self.assertEqual(status, 200)

        # Works with a non-fresh access token
        _, data = self._jwt_post('/auth/refresh', tokens['refresh_token'])
        status, data = self._jwt_get('/protected', data['access_token'])
        self.assertEqual(status, 200)
        self.assertEqual(data, {'msg': 'hello world'})

    def test_jwt_required_wrong_token(self):
        refresh_token = self._login()['refresh_token']
        # Shouldn't work with a refresh token
        status, _ = self._jwt_get('/protected', refresh_token)
        self.assertEqual(status, 422)

    def test_fresh_jwt_required(self):
        tokens = self._login()

        # Works with a fresh token
        status, data = self._jwt_get('/fresh-protected', tokens['access_token'])
        self.assertEqual(data, {'msg': 'fresh hello world'})
        self.assertEqual(status, 200)

        # Rejected with a non-fresh access token
        _, data = self._jwt_post('/auth/refresh', tokens['refresh_token'])
        status, _ = self._jwt_get('/fresh-protected', data['access_token'])
        self.assertEqual(status, 401)

    def test_fresh_jwt_required_wrong_token(self):
        refresh_token = self._login()['refresh_token']
        # Shouldn't work with a refresh token
        status, _ = self._jwt_get('/fresh-protected', refresh_token)
        self.assertEqual(status, 422)

    def test_without_secret_key(self):
        app = Flask(__name__)
        app.testing = True  # Propagate exceptions
        JWTManager(app)
        client = app.test_client()

        @app.route('/login', methods=['POST'])
        def login():
            ret = {'access_token': create_access_token('test', fresh=True)}
            return jsonify(ret), 200

        with self.assertRaises(RuntimeError):
            client.post('/login')

    def test_bad_jwt_requests(self):
        access_token = self._login()['access_token']

        # No authorization header at all
        status_code, data = self._raw_get('/protected')
        self.assertEqual(status_code, 401)
        self.assertIn('msg', data)

        # Malformed headers: missing type, type not Bearer, too many items
        malformed_headers = (
            access_token,
            "BANANA {}".format(access_token),
            "Bearer {} BANANA".format(access_token),
        )
        for auth_header in malformed_headers:
            with self.subTest(auth_header=auth_header):
                status_code, data = self._raw_get(
                    '/protected', headers={'Authorization': auth_header})
                self.assertEqual(status_code, 422)
                self.assertIn('msg', data)

    def test_bad_tokens(self):
        # Expired access token
        access_token = self._login()['access_token']
        status_code, data = self._jwt_get('/protected', access_token)
        self.assertEqual(status_code, 200)
        self.assertIn('msg', data)
        time.sleep(2)
        status_code, data = self._jwt_get('/protected', access_token)
        self.assertEqual(status_code, 401)
        self.assertIn('msg', data)

        # Expired refresh token
        refresh_token = self._login()['refresh_token']
        status_code, data = self._jwt_post('/auth/refresh', refresh_token)
        self.assertEqual(status_code, 200)
        self.assertIn('access_token', data)
        self.assertNotIn('msg', data)
        time.sleep(2)
        status_code, data = self._jwt_post('/auth/refresh', refresh_token)
        self.assertEqual(status_code, 401)
        self.assertNotIn('access_token', data)
        self.assertIn('msg', data)

        # Bogus token
        status_code, data = self._jwt_get('/protected',
                                          'this_is_totally_an_access_token')
        self.assertEqual(status_code, 422)
        self.assertIn('msg', data)

        # Token signed with a different key
        with self.app.test_request_context():
            token = encode_access_token('foo', 'newsecret', 'HS256',
                                        timedelta(minutes=5), True, {},
                                        csrf=False)
        status_code, data = self._jwt_get('/protected', token)
        self.assertEqual(status_code, 422)
        self.assertIn('msg', data)

        # Validly signed token that is missing required claims
        now = datetime.utcnow()
        token_data = {'exp': now + timedelta(minutes=5)}
        encoded_token = jwt.encode(token_data, self.app.config['SECRET_KEY'],
                                   self.app.config['JWT_ALGORITHM'])
        # PyJWT < 2 returns bytes, PyJWT >= 2 returns str
        if isinstance(encoded_token, bytes):
            encoded_token = encoded_token.decode('utf-8')
        status_code, data = self._jwt_get('/protected', encoded_token)
        self.assertEqual(status_code, 422)
        self.assertIn('msg', data)

    def test_jwt_identity_claims(self):
        # Setup custom claims
        @self.jwt_manager.user_claims_loader
        def custom_claims(identity):
            return {'foo': 'bar'}

        @self.app.route('/claims')
        @jwt_required
        def claims():
            return jsonify({
                'username': get_jwt_identity(),
                'claims': get_jwt_claims()
            })

        access_token = self._login()['access_token']
        status, data = self._jwt_get('/claims', access_token)
        self.assertEqual(status, 200)
        # /auth/login issues its tokens for the identity 'test'
        self.assertEqual(data, {'username': 'test', 'claims': {'foo': 'bar'}})

    def test_jwt_raw_token(self):
        # Endpoint that returns the keys of the raw decoded token
        @self.app.route('/claims')
        @jwt_required
        def claims():
            raw_token = get_raw_jwt()
            return jsonify(list(raw_token)), 200

        access_token = self._login()['access_token']
        status, data = self._jwt_get('/claims', access_token)
        self.assertEqual(status, 200)
        for claim in ('exp', 'iat', 'nbf', 'jti', 'identity', 'fresh', 'type',
                      'user_claims'):
            self.assertIn(claim, data)

    def test_different_headers(self):
        access_token = self._login()['access_token']

        self.app.config['JWT_HEADER_TYPE'] = 'JWT'
        status, data = self._jwt_get('/protected', access_token, header_type='JWT')
        self.assertEqual(data, {'msg': 'hello world'})
        self.assertEqual(status, 200)

        self.app.config['JWT_HEADER_TYPE'] = ''
        status, data = self._jwt_get('/protected', access_token, header_type='')
        self.assertEqual(data, {'msg': 'hello world'})
        self.assertEqual(status, 200)

        self.app.config['JWT_HEADER_TYPE'] = ''
        status, data = self._jwt_get('/protected', access_token, header_type='Bearer')
        self.assertIn('msg', data)
        self.assertEqual(status, 422)

        self.app.config['JWT_HEADER_TYPE'] = 'Bearer'
        self.app.config['JWT_HEADER_NAME'] = 'Auth'
        status, data = self._jwt_get('/protected', access_token,
                                     header_name='Auth', header_type='Bearer')
        self.assertIn('msg', data)
        self.assertEqual(status, 200)
        status, data = self._jwt_get('/protected', access_token,
                                     header_name='Authorization',
                                     header_type='Bearer')
        self.assertIn('msg', data)
        self.assertEqual(status, 401)

    def test_cookie_methods_fail_with_headers_configured(self):
        app = Flask(__name__)
        app.config['JWT_TOKEN_LOCATION'] = ['headers']
        app.secret_key = 'super=secret'
        app.testing = True
        JWTManager(app)
        client = app.test_client()

        @app.route('/login-bad', methods=['POST'])
        def bad_login():
            access_token = create_access_token('test')
            resp = jsonify({'login': True})
            set_access_cookies(resp, access_token)
            return resp, 200

        @app.route('/refresh-bad', methods=['POST'])
        def bad_refresh():
            refresh_token = create_refresh_token('test')
            resp = jsonify({'login': True})
            set_refresh_cookies(resp, refresh_token)
            return resp, 200

        @app.route('/logout-bad', methods=['POST'])
        def bad_logout():
            resp = jsonify({'logout': True})
            unset_jwt_cookies(resp)
            return resp, 200

        for url in ('/login-bad', '/refresh-bad', '/logout-bad'):
            with self.assertRaises(RuntimeWarning):
                client.post(url)
class TestNetworkPort(BaseTest):
    """Tests for NetworkPort blueprint"""

    @staticmethod
    def _load_json(path):
        """Return the parsed JSON content of the mockup file at ``path``."""
        with open(path) as f:
            return json.load(f)

    def setUp(self):
        """Tests preparation

        Builds a Flask app with the network_port blueprint and Redfish-style
        JSON handlers for 500 and 404 errors, then stores a test client in
        ``self.app`` for the test methods to issue requests through.
        """
        app = Flask(__name__)
        app.register_blueprint(network_port.network_port)

        @app.errorhandler(status.HTTP_500_INTERNAL_SERVER_ERROR)
        def internal_server_error(error):
            """Creates a Internal Server Error response"""
            redfish_error = RedfishError(
                "InternalError",
                "The request failed due to an internal service error. "
                "The service is still operational.")
            redfish_error.add_extended_info("InternalError")
            error_str = redfish_error.serialize()
            return Response(
                response=error_str,
                status=status.HTTP_500_INTERNAL_SERVER_ERROR,
                mimetype="application/json")

        @app.errorhandler(status.HTTP_404_NOT_FOUND)
        def not_found(error):
            """Creates a Not Found Error response"""
            redfish_error = RedfishError("GeneralError", error.description)
            error_str = redfish_error.serialize()
            return Response(
                response=error_str,
                status=status.HTTP_404_NOT_FOUND,
                mimetype='application/json')

        # Propagate the exceptions to the test client. The flag must be set
        # on the Flask application itself; setting it on the object returned
        # by test_client() (as this code previously did) has no effect.
        app.testing = True
        # Creates a test client. It is kept under the name ``self.app``
        # because every test method below requests through that attribute.
        self.app = app.test_client()

    @mock.patch.object(network_port, 'g')
    def test_get_network_port(self, g):
        """Tests NetworkPort 1 of adapter 3 against the Ethernet mockup"""
        # Loading server_hardware mockup value
        server_hardware = self._load_json(
            'oneview_redfish_toolkit/mockups/oneview/ServerHardware.json')

        # Loading NetworkPort mockup result
        network_port_mockup = self._load_json(
            'oneview_redfish_toolkit/mockups/redfish/'
            'NetworkPort1-Ethernet.json')

        # Create mock response
        g.oneview_client.server_hardware.get.return_value = server_hardware

        # Get NetworkPort
        response = self.app.get(
            "/redfish/v1/Chassis/30303437-3034-4D32-3230-313133364752/"
            "NetworkAdapters/3/NetworkPorts/1")

        # Gets json from response
        result = json.loads(response.data.decode("utf-8"))

        # Tests response
        self.assertEqual(status.HTTP_200_OK, response.status_code)
        self.assertEqual("application/json", response.mimetype)
        self.assertEqual(network_port_mockup, result)

    @mock.patch.object(network_port, 'g')
    def test_get_network_port_fibre_channel(self, g):
        """Tests NetworkPort 1 of adapter 2 against the FibreChannel mockup"""
        # Loading server_hardware mockup value
        server_hardware = self._load_json(
            'oneview_redfish_toolkit/mockups/oneview/'
            'ServerHardwareFibreChannel.json')

        # Loading NetworkPort mockup result
        network_port_mockup = self._load_json(
            'oneview_redfish_toolkit/mockups/redfish/'
            'NetworkPort1-FibreChannel.json')

        # Create mock response
        g.oneview_client.server_hardware.get.return_value = server_hardware

        # Get NetworkPort
        response = self.app.get(
            "/redfish/v1/Chassis/30303437-3034-4D32-3230-313133364752/"
            "NetworkAdapters/2/NetworkPorts/1")

        # Gets json from response
        result = json.loads(response.data.decode("utf-8"))

        # Tests response
        self.assertEqual(status.HTTP_200_OK, response.status_code)
        self.assertEqual("application/json", response.mimetype)
        self.assertEqual(network_port_mockup, result)

    @mock.patch.object(network_port, 'g')
    def test_get_network_port_invalid_device_id(self, g):
        """Tests NetworkPort with a non-numeric adapter id returns 404"""
        # Loading server_hardware mockup value
        server_hardware = self._load_json(
            'oneview_redfish_toolkit/mockups/oneview/ServerHardware.json')

        # Create mock response
        g.oneview_client.server_hardware.get.return_value = server_hardware

        # Get NetworkPort
        response = self.app.get(
            "/redfish/v1/Chassis/30303437-3034-4D32-3230-313133364752/"
            "NetworkAdapters/invalid_id/NetworkPorts/1")

        # Tests response
        self.assertEqual(status.HTTP_404_NOT_FOUND, response.status_code)
        self.assertEqual("application/json", response.mimetype)

    @mock.patch.object(network_port, 'g')
    def test_get_network_port_sh_not_found(self, g):
        """Tests NetworkPort server hardware not found"""
        e = HPOneViewException({
            'errorCode': 'RESOURCE_NOT_FOUND',
            'message': 'server-hardware not found',
        })

        g.oneview_client.server_hardware.get.side_effect = e

        # Get NetworkPort
        response = self.app.get(
            "/redfish/v1/Chassis/30303437-3034-4D32-3230-313133364752/"
            "NetworkAdapters/3/NetworkPorts/1")

        self.assertEqual(status.HTTP_404_NOT_FOUND, response.status_code)
        self.assertEqual("application/json", response.mimetype)

    @mock.patch.object(network_port, 'g')
    def test_get_network_port_sh_exception(self, g):
        """Tests NetworkPort unknown exception"""
        e = HPOneViewException({
            'errorCode': 'ANOTHER_ERROR',
            'message': 'server-hardware error',
        })

        g.oneview_client.server_hardware.get.side_effect = e

        # Get NetworkPort
        response = self.app.get(
            "/redfish/v1/Chassis/30303437-3034-4D32-3230-313133364752/"
            "NetworkAdapters/3/NetworkPorts/1")

        self.assertEqual(
            status.HTTP_500_INTERNAL_SERVER_ERROR,
            response.status_code)
        self.assertEqual("application/json", response.mimetype)
complex_parameters = self.get_known_parameter_values() if len(complex_parameters) > 0: result = self.search_for_complex(complex_parameters) else: text_response = 'Введён запрос поиска без параметров или же по запросу ничего не найдено...' result = { 'speech': text_response, 'displayText': text_response, 'source': 'webhookdata' } result['contextOut'] = obj['result']['contexts'] return result @app.route('/', methods=['POST']) def post(): wb = Webhook() req = request.get_json(silent=True, force=True) res = wb.get_result(req) return make_response(jsonify(res)) if __name__ == '__main__': file = open('sample_request.json', 'r') j = json.load(file) p = app.test_client().post('/', data=json.dumps(j)) if p.status_code == 200: j = json.loads(p.data) print(j['data']) # [END app]
class BootstrapTestCase(unittest.TestCase):
    """Tests for the Bootstrap extension: resource loading and template macros.

    ``setUp`` builds a fresh Flask app per test, registers the extension,
    pushes a test request context (so ``current_app`` is usable directly in
    the tests) and creates a test client. Individual tests register their own
    routes on ``self.app`` and inspect the rendered HTML.
    """

    def setUp(self):
        self.app = Flask(__name__)
        self.app.testing = True
        # Used by the pagination/table tests, which bind SQLAlchemy to the app.
        self.app.config['SQLALCHEMY_DATABASE_URI'] = 'sqlite:///'
        self.app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False
        # The session (used by flash()) and the FlaskForm-based tests need a
        # secret key.
        self.app.secret_key = 'for test'
        self.bootstrap = Bootstrap(self.app)  # noqa

        @self.app.route('/')
        def index():
            return render_template_string(
                '{{ bootstrap.load_css() }}{{ bootstrap.load_js() }}')

        self.context = self.app.test_request_context()
        self.context.push()
        self.client = self.app.test_client()

    def tearDown(self):
        self.context.pop()

    def test_extension_init(self):
        self.assertIn('bootstrap', current_app.extensions)

    def test_load_css(self):
        rv = self.bootstrap.load_css()
        self.assertIn('bootstrap.min.css', rv)

    def test_load_js(self):
        rv = self.bootstrap.load_js()
        self.assertIn('bootstrap.min.js', rv)

    def test_local_resources(self):
        """With BOOTSTRAP_SERVE_LOCAL the assets come from the local route."""
        current_app.config['BOOTSTRAP_SERVE_LOCAL'] = True

        response = self.client.get('/')
        data = response.get_data(as_text=True)
        self.assertNotIn('https://cdn.jsdelivr.net/npm/bootstrap', data)
        self.assertIn('bootstrap.min.js', data)
        self.assertIn('bootstrap.min.css', data)
        self.assertIn('jquery.min.js', data)

        css_response = self.client.get(
            '/bootstrap/static/css/bootstrap.min.css')
        js_response = self.client.get('/bootstrap/static/js/bootstrap.min.js')
        jquery_response = self.client.get('/bootstrap/static/jquery.min.js')
        self.assertNotEqual(css_response.status_code, 404)
        self.assertNotEqual(js_response.status_code, 404)
        self.assertNotEqual(jquery_response.status_code, 404)

        css_rv = self.bootstrap.load_css()
        js_rv = self.bootstrap.load_js()
        self.assertIn('/bootstrap/static/css/bootstrap.min.css', css_rv)
        self.assertIn('/bootstrap/static/js/bootstrap.min.js', js_rv)
        self.assertNotIn('https://cdn.jsdelivr.net/npm/bootstrap', css_rv)
        self.assertNotIn('https://cdn.jsdelivr.net/npm/bootstrap', js_rv)

    def test_cdn_resources(self):
        """By default the assets are loaded from the jsDelivr CDN."""
        response = self.client.get('/')
        data = response.get_data(as_text=True)
        self.assertIn('https://cdn.jsdelivr.net/npm/bootstrap', data)
        self.assertIn('bootstrap.min.js', data)
        self.assertIn('bootstrap.min.css', data)

        css_rv = self.bootstrap.load_css()
        js_rv = self.bootstrap.load_js()
        self.assertNotIn('/bootstrap/static/css/bootstrap.min.css', css_rv)
        self.assertNotIn('/bootstrap/static/js/bootstrap.min.js', js_rv)
        self.assertIn('https://cdn.jsdelivr.net/npm/bootstrap', css_rv)
        self.assertIn('https://cdn.jsdelivr.net/npm/bootstrap', js_rv)

    def test_render_field(self):
        @self.app.route('/field')
        def test():
            form = HelloForm()
            return render_template_string('''
                    {% from 'bootstrap/form.html' import render_field %}
                    {{ render_field(form.username) }}
                    {{ render_field(form.password) }}
                    ''', form=form)

        response = self.client.get('/field')
        data = response.get_data(as_text=True)
        self.assertIn(
            '<input class="form-control" id="username" name="username"', data)
        self.assertIn(
            '<input class="form-control" id="password" name="password"', data)

    def test_render_form(self):
        @self.app.route('/form')
        def test():
            form = HelloForm()
            return render_template_string('''
                    {% from 'bootstrap/form.html' import render_form %}
                    {{ render_form(form) }}
                    ''', form=form)

        response = self.client.get('/form')
        data = response.get_data(as_text=True)
        self.assertIn(
            '<input class="form-control" id="username" name="username"', data)
        self.assertIn(
            '<input class="form-control" id="password" name="password"', data)

    def test_render_form_row(self):
        @self.app.route('/form')
        def test():
            form = HelloForm()
            return render_template_string('''
                    {% from 'bootstrap/form.html' import render_form_row %}
                    {{ render_form_row([form.username, form.password]) }}
                    ''', form=form)

        response = self.client.get('/form')
        data = response.get_data(as_text=True)
        self.assertIn('<div class="form-row">', data)
        self.assertIn('<div class="col">', data)

    def test_render_form_row_row_class(self):
        @self.app.route('/form')
        def test():
            form = HelloForm()
            return render_template_string('''
                    {% from 'bootstrap/form.html' import render_form_row %}
                    {{ render_form_row([form.username, form.password], row_class='row') }}
                    ''', form=form)

        response = self.client.get('/form')
        data = response.get_data(as_text=True)
        self.assertIn('<div class="row">', data)

    def test_render_form_row_col_class_default(self):
        @self.app.route('/form')
        def test():
            form = HelloForm()
            return render_template_string('''
                    {% from 'bootstrap/form.html' import render_form_row %}
                    {{ render_form_row([form.username, form.password], col_class_default='col-md-6') }}
                    ''', form=form)

        response = self.client.get('/form')
        data = response.get_data(as_text=True)
        self.assertIn('<div class="col-md-6">', data)

    def test_render_form_row_col_map(self):
        """col_map overrides the column class per field name.

        ``username`` is mapped to ``col-md-6`` while ``password`` keeps the
        default ``col`` class; both classes must appear in the output.
        """
        @self.app.route('/form')
        def test():
            form = HelloForm()
            return render_template_string('''
                    {% from 'bootstrap/form.html' import render_form_row %}
                    {{ render_form_row([form.username, form.password], col_map={'username': 'col-md-6'}) }}
                    ''', form=form)

        response = self.client.get('/form')
        data = response.get_data(as_text=True)
        self.assertIn('<div class="col">', data)
        self.assertIn('<div class="col-md-6">', data)

    def test_render_pager(self):
        db = SQLAlchemy(self.app)

        class Message(db.Model):
            id = db.Column(db.Integer, primary_key=True)

        @self.app.route('/pager')
        def test():
            db.drop_all()
            db.create_all()
            for i in range(100):
                m = Message()
                db.session.add(m)
            db.session.commit()
            page = request.args.get('page', 1, type=int)
            # Keyword arguments: Flask-SQLAlchemy 3.x made paginate()'s
            # parameters keyword-only; 2.x accepts them by keyword too.
            pagination = Message.query.paginate(page=page, per_page=10)
            messages = pagination.items
            return render_template_string('''
                            {% from 'bootstrap/pagination.html' import render_pager %}
                            {{ render_pager(pagination) }}
                            ''', pagination=pagination, messages=messages)

        # First page: the "previous" link is rendered disabled.
        response = self.client.get('/pager')
        data = response.get_data(as_text=True)
        self.assertIn('<nav aria-label="Page navigation">', data)
        self.assertIn('Previous', data)
        self.assertIn('Next', data)
        self.assertIn('<li class="page-item disabled">', data)

        # Middle page: neither link is disabled.
        response = self.client.get('/pager?page=2')
        data = response.get_data(as_text=True)
        self.assertIn('<nav aria-label="Page navigation">', data)
        self.assertIn('Previous', data)
        self.assertIn('Next', data)
        self.assertNotIn('<li class="page-item disabled">', data)

    def test_render_pagination(self):
        db = SQLAlchemy(self.app)

        class Message(db.Model):
            id = db.Column(db.Integer, primary_key=True)

        @self.app.route('/pagination')
        def test():
            db.drop_all()
            db.create_all()
            for i in range(100):
                m = Message()
                db.session.add(m)
            db.session.commit()
            page = request.args.get('page', 1, type=int)
            # Keyword arguments for Flask-SQLAlchemy 2.x and 3.x compatibility.
            pagination = Message.query.paginate(page=page, per_page=10)
            messages = pagination.items
            return render_template_string('''
                            {% from 'bootstrap/pagination.html' import render_pagination %}
                            {{ render_pagination(pagination) }}
                            ''', pagination=pagination, messages=messages)

        response = self.client.get('/pagination')
        data = response.get_data(as_text=True)
        self.assertIn('<nav aria-label="Page navigation">', data)
        self.assertIn(
            '<a class="page-link" href="#">1 <span class="sr-only">(current)</span></a>',
            data)
        self.assertIn('10</a>', data)

        response = self.client.get('/pagination?page=2')
        data = response.get_data(as_text=True)
        self.assertIn('<nav aria-label="Page navigation">', data)
        self.assertIn('1</a>', data)
        self.assertIn(
            '<a class="page-link" href="#">2 <span class="sr-only">(current)</span></a>',
            data)
        self.assertIn('10</a>', data)

    def test_render_nav_item(self):
        # The view must be named ``test``: the template links to that endpoint.
        @self.app.route('/nav_item')
        def test():
            return render_template_string('''
                {% from 'bootstrap/nav.html' import render_nav_item %}
                {{ render_nav_item('test', 'Home') }}
                ''')

        response = self.client.get('/nav_item')
        data = response.get_data(as_text=True)
        self.assertIn('<a class="nav-item nav-link active"', data)

    def test_render_breadcrumb_item(self):
        # The view must be named ``test``: the template links to that endpoint.
        @self.app.route('/breadcrumb_item')
        def test():
            return render_template_string('''
                {% from 'bootstrap/nav.html' import render_breadcrumb_item %}
                {{ render_breadcrumb_item('test', 'Home') }}
                ''')

        response = self.client.get('/breadcrumb_item')
        data = response.get_data(as_text=True)
        self.assertIn(
            '<li class="breadcrumb-item active" aria-current="page">', data)

    def test_render_static(self):
        @self.app.route('/test_static')
        def test():
            return render_template_string('''
                            {% from 'bootstrap/utils.html' import render_static %}
                            {{ render_static('css', 'test.css') }}
                            {{ render_static('js', 'test.js') }}
                            {{ render_static('icon', 'test.ico') }}
                            ''')

        response = self.client.get('/test_static')
        data = response.get_data(as_text=True)
        self.assertIn(
            '<link rel="stylesheet" href="/static/test.css" type="text/css">',
            data)
        self.assertIn(
            '<script type="text/javascript" src="/static/test.js"></script>',
            data)
        self.assertIn('<link rel="icon" href="/static/test.ico">', data)

    def test_render_messages(self):
        @self.app.route('/messages')
        def test_messages():
            flash('test message', 'danger')
            return render_template_string('''
                            {% from 'bootstrap/utils.html' import render_messages %}
                            {{ render_messages() }}
                            ''')

        @self.app.route('/container')
        def test_container():
            flash('test message', 'danger')
            return render_template_string('''
                            {% from 'bootstrap/utils.html' import render_messages %}
                            {{ render_messages(container=True) }}
                            ''')

        @self.app.route('/dismissible')
        def test_dismissible():
            flash('test message', 'danger')
            return render_template_string('''
                            {% from 'bootstrap/utils.html' import render_messages %}
                            {{ render_messages(dismissible=True) }}
                            ''')

        @self.app.route('/dismiss_animate')
        def test_dismiss_animate():
            flash('test message', 'danger')
            return render_template_string('''
                            {% from 'bootstrap/utils.html' import render_messages %}
                            {{ render_messages(dismissible=True, dismiss_animate=True) }}
                            ''')

        response = self.client.get('/messages')
        data = response.get_data(as_text=True)
        self.assertIn('<div class="alert alert-danger"', data)

        response = self.client.get('/container')
        data = response.get_data(as_text=True)
        self.assertIn('<div class="container flashed-messages">', data)

        response = self.client.get('/dismissible')
        data = response.get_data(as_text=True)
        self.assertIn('alert-dismissible', data)
        self.assertIn(
            '<button type="button" class="close" data-dismiss="alert"', data)
        self.assertNotIn('fade show', data)

        response = self.client.get('/dismiss_animate')
        data = response.get_data(as_text=True)
        self.assertIn('alert-dismissible', data)
        self.assertIn(
            '<button type="button" class="close" data-dismiss="alert"', data)
        self.assertIn('fade show', data)

    # test WTForm fields for render_form and render_field
    def test_render_form_enctype(self):
        """Forms containing a file field are rendered as multipart/form-data.

        Covers both a single ``FileField`` and a ``MultipleFileField``.
        """
        class SingleUploadForm(FlaskForm):
            avatar = FileField('Avatar')

        class MultiUploadForm(FlaskForm):
            photos = MultipleFileField('Multiple photos')

        @self.app.route('/single')
        def single():
            form = SingleUploadForm()
            return render_template_string('''
                {% from 'bootstrap/form.html' import render_form %}
                {{ render_form(form) }}
                ''', form=form)

        @self.app.route('/multi')
        def multi():
            # Previously this reused SingleUploadForm, leaving the
            # MultipleFileField case untested.
            form = MultiUploadForm()
            return render_template_string('''
                {% from 'bootstrap/form.html' import render_form %}
                {{ render_form(form) }}
                ''', form=form)

        response = self.client.get('/single')
        data = response.get_data(as_text=True)
        self.assertIn('multipart/form-data', data)

        response = self.client.get('/multi')
        data = response.get_data(as_text=True)
        self.assertIn('multipart/form-data', data)

    # test render_kw class for WTForms field
    def test_form_render_kw_class(self):
        class TestForm(FlaskForm):
            username = StringField('Username')
            password = PasswordField(
                'Password', render_kw={'class': 'my-password-class'})
            submit = SubmitField(render_kw={'class': 'my-awesome-class'})

        @self.app.route('/render_kw')
        def render_kw():
            form = TestForm()
            return render_template_string('''
                {% from 'bootstrap/form.html' import render_form %}
                {{ render_form(form) }}
                ''', form=form)

        response = self.client.get('/render_kw')
        data = response.get_data(as_text=True)
        self.assertIn('class="form-control"', data)
        # No trailing space when a field has no extra class.
        self.assertNotIn('class="form-control "', data)
        self.assertIn('class="form-control my-password-class"', data)
        self.assertIn('my-awesome-class', data)
        self.assertIn('btn', data)

    # test WTForm field description for BooleanField
    def test_form_description_for_booleanfield(self):
        class TestForm(FlaskForm):
            remember = BooleanField(
                'Remember me', description='Just check this')

        @self.app.route('/description')
        def description():
            form = TestForm()
            return render_template_string('''
                {% from 'bootstrap/form.html' import render_form %}
                {{ render_form(form) }}
                ''', form=form)

        response = self.client.get('/description')
        data = response.get_data(as_text=True)
        self.assertIn('Remember me', data)
        self.assertIn(
            '<small class="form-text text-muted">Just check this</small>',
            data)

    def test_button_size(self):
        """BOOTSTRAP_BTN_SIZE sets the default; button_size overrides it."""
        self.assertEqual(current_app.config['BOOTSTRAP_BTN_SIZE'], 'md')
        current_app.config['BOOTSTRAP_BTN_SIZE'] = 'lg'

        @self.app.route('/form')
        def test():
            form = HelloForm()
            return render_template_string('''
                    {% from 'bootstrap/form.html' import render_form %}
                    {{ render_form(form) }}
                    ''', form=form)

        response = self.client.get('/form')
        data = response.get_data(as_text=True)
        self.assertIn('btn-lg', data)

        @self.app.route('/form2')
        def test_overwrite():
            form = HelloForm()
            return render_template_string('''
                    {% from 'bootstrap/form.html' import render_form %}
                    {{ render_form(form, button_size='sm') }}
                    ''', form=form)

        response = self.client.get('/form2')
        data = response.get_data(as_text=True)
        self.assertNotIn('btn-lg', data)
        self.assertIn('btn-sm', data)

    def test_button_style(self):
        """BOOTSTRAP_BTN_STYLE sets the default; button_style/button_map override."""
        self.assertEqual(
            current_app.config['BOOTSTRAP_BTN_STYLE'], 'secondary')
        current_app.config['BOOTSTRAP_BTN_STYLE'] = 'primary'

        @self.app.route('/form')
        def test():
            form = HelloForm()
            return render_template_string('''
                    {% from 'bootstrap/form.html' import render_form %}
                    {{ render_form(form) }}
                    ''', form=form)

        response = self.client.get('/form')
        data = response.get_data(as_text=True)
        self.assertIn('btn-primary', data)

        @self.app.route('/form2')
        def test_overwrite():
            form = HelloForm()
            return render_template_string('''
                    {% from 'bootstrap/form.html' import render_form %}
                    {{ render_form(form, button_style='success') }}
                    ''', form=form)

        response = self.client.get('/form2')
        data = response.get_data(as_text=True)
        self.assertNotIn('btn-primary', data)
        self.assertIn('btn-success', data)

        @self.app.route('/form3')
        def test_button_map():
            form = HelloForm()
            return render_template_string('''
                    {% from 'bootstrap/form.html' import render_form %}
                    {{ render_form(form, button_map={'submit': 'warning'}) }}
                    ''', form=form)

        response = self.client.get('/form3')
        data = response.get_data(as_text=True)
        self.assertNotIn('btn-primary', data)
        self.assertIn('btn-warning', data)

    def test_error_message_for_radiofield_and_booleanfield(self):
        class TestForm(FlaskForm):
            remember = BooleanField('Remember me', validators=[DataRequired()])
            option = RadioField(
                choices=[('dog', 'Dog'), ('cat', 'Cat'),
                         ('bird', 'Bird'), ('alien', 'Alien')],
                validators=[DataRequired()])

        @self.app.route('/error', methods=['GET', 'POST'])
        def error():
            form = TestForm()
            if form.validate_on_submit():
                pass
            return render_template_string('''
                {% from 'bootstrap/form.html' import render_form %}
                {{ render_form(form) }}
                ''', form=form)

        # An empty POST leaves the required fields unset.
        response = self.client.post('/error', follow_redirects=True)
        data = response.get_data(as_text=True)
        self.assertIn('This field is required', data)

    def test_render_simple_table(self):
        db = SQLAlchemy(self.app)

        class Message(db.Model):
            id = db.Column(db.Integer, primary_key=True)
            text = db.Column(db.Text)

        @self.app.route('/table')
        def test():
            db.drop_all()
            db.create_all()
            for i in range(10):
                m = Message(text='Test message {}'.format(i + 1))
                db.session.add(m)
            db.session.commit()
            page = request.args.get('page', 1, type=int)
            # Keyword arguments for Flask-SQLAlchemy 2.x and 3.x compatibility.
            pagination = Message.query.paginate(page=page, per_page=10)
            messages = pagination.items
            titles = [('id', '#'), ('text', 'Message')]
            return render_template_string('''
                            {% from 'bootstrap/table.html' import render_table %}
                            {{ render_table(messages, titles) }}
                            ''', titles=titles, messages=messages)

        response = self.client.get('/table')
        data = response.get_data(as_text=True)
        self.assertIn('<table class="table">', data)
        self.assertIn('<th scope="col">#</th>', data)
        self.assertIn('<th scope="col">Message</th>', data)
        self.assertIn('<th scope="row">1</th>', data)
        self.assertIn('<td>Test message 1</td>', data)

    def test_render_customized_table(self):
        db = SQLAlchemy(self.app)

        class Message(db.Model):
            id = db.Column(db.Integer, primary_key=True)
            text = db.Column(db.Text)

        @self.app.route('/table')
        def test():
            db.drop_all()
            db.create_all()
            for i in range(10):
                m = Message(text='Test message {}'.format(i + 1))
                db.session.add(m)
            db.session.commit()
            page = request.args.get('page', 1, type=int)
            # Keyword arguments for Flask-SQLAlchemy 2.x and 3.x compatibility.
            pagination = Message.query.paginate(page=page, per_page=10)
            messages = pagination.items
            titles = [('id', '#'), ('text', 'Message')]
            return render_template_string('''
                            {% from 'bootstrap/table.html' import render_table %}
                            {{ render_table(messages, titles, table_classes='table-striped',
                                            header_classes='thead-dark', caption='Messages') }}
                            ''', titles=titles, messages=messages)

        response = self.client.get('/table')
        data = response.get_data(as_text=True)
        self.assertIn('<table class="table table-striped">', data)
        self.assertIn('<thead class="thead-dark">', data)
        self.assertIn('<caption>Messages</caption>', data)

    def test_render_responsive_table(self):
        db = SQLAlchemy(self.app)

        class Message(db.Model):
            id = db.Column(db.Integer, primary_key=True)
            text = db.Column(db.Text)

        @self.app.route('/table')
        def test():
            db.drop_all()
            db.create_all()
            for i in range(10):
                m = Message(text='Test message {}'.format(i + 1))
                db.session.add(m)
            db.session.commit()
            page = request.args.get('page', 1, type=int)
            # Keyword arguments for Flask-SQLAlchemy 2.x and 3.x compatibility.
            pagination = Message.query.paginate(page=page, per_page=10)
            messages = pagination.items
            titles = [('id', '#'), ('text', 'Message')]
            return render_template_string('''
                            {% from 'bootstrap/table.html' import render_table %}
                            {{ render_table(messages, titles, responsive=True,
                                            responsive_class='table-responsive-sm') }}
                            ''', titles=titles, messages=messages)

        response = self.client.get('/table')
        data = response.get_data(as_text=True)
        self.assertIn('<div class="table-responsive-sm">', data)

    def test_build_table_titles(self):
        """Without explicit titles, headers are derived from the model columns."""
        db = SQLAlchemy(self.app)

        class Message(db.Model):
            id = db.Column(db.Integer, primary_key=True)
            text = db.Column(db.Text)

        @self.app.route('/table')
        def test():
            db.drop_all()
            db.create_all()
            for i in range(10):
                m = Message(text='Test message {}'.format(i + 1))
                db.session.add(m)
            db.session.commit()
            page = request.args.get('page', 1, type=int)
            # Keyword arguments for Flask-SQLAlchemy 2.x and 3.x compatibility.
            pagination = Message.query.paginate(page=page, per_page=10)
            messages = pagination.items
            return render_template_string('''
                            {% from 'bootstrap/table.html' import render_table %}
                            {{ render_table(messages) }}
                            ''', messages=messages)

        response = self.client.get('/table')
        data = response.get_data(as_text=True)
        self.assertIn('<table class="table">', data)
        self.assertIn('<th scope="col">#</th>', data)
        self.assertIn('<th scope="col">Text</th>', data)
        self.assertIn('<th scope="row">1</th>', data)
        self.assertIn('<td>Test message 1</td>', data)

    def test_build_table_titles_with_empty_data(self):
        """Rendering a table from an empty list must not fail."""
        @self.app.route('/table')
        def test():
            messages = []
            return render_template_string('''
                            {% from 'bootstrap/table.html' import render_table %}
                            {{ render_table(messages) }}
                            ''', messages=messages)

        response = self.client.get('/table')
        self.assertEqual(response.status_code, 200)
class UnicodeCookieUserIDTestCase(unittest.TestCase):
    """Remember-me cookie round trip for a user with non-ASCII name and id.

    The fixture user ``germanjapanese`` is logged in with ``remember=True``.
    The session is then wiped, and the tests check that the user (name and
    id) is restored from the remember cookie.
    """

    def setUp(self):
        app = Flask(__name__)
        self.app = app
        app.config['SECRET_KEY'] = 'deterministic'
        app.config['SESSION_PROTECTION'] = None
        self.remember_cookie_name = 'remember'
        app.config['REMEMBER_COOKIE_NAME'] = self.remember_cookie_name

        manager = LoginManager()
        self.login_manager = manager
        manager.init_app(app)
        manager._login_disabled = False

        @app.route('/')
        def index():
            return u'Welcome!'

        @app.route('/login-germanjapanese-remember')
        def login_germanjapanese_remember():
            logged_in = login_user(germanjapanese, remember=True)
            return unicode(logged_in)

        @app.route('/username')
        def username():
            return (current_user.name if current_user.is_authenticated
                    else u'Anonymous')

        @app.route('/userid')
        def user_id():
            return (current_user.id if current_user.is_authenticated
                    else u'wrong_id')

        @manager.user_loader
        def load_user(user_id):
            return USERS[unicode(user_id)]

        # Re-raise 404s instead of returning a 404 response. A typo in a URL
        # therefore fails loudly, and the tests do not need to check every
        # set-up response (such as the login page) to know it worked.
        @app.errorhandler(404)
        def handle_404(e):
            raise e

        unittest.TestCase.setUp(self)

    def _delete_session(self, c):
        # Clear the client's session, as if the browser had been closed.
        # This drops the session regardless of its "permanent" flag.
        with c.session_transaction() as stored:
            stored.clear()

    def test_remember_me_username(self):
        with self.app.test_client() as client:
            client.get('/login-germanjapanese-remember')
            self._delete_session(client)
            response = client.get('/username')
            self.assertEqual(u'Müller', response.data.decode('utf-8'))

    def test_remember_me_user_id(self):
        with self.app.test_client() as client:
            client.get('/login-germanjapanese-remember')
            self._delete_session(client)
            response = client.get('/userid')
            self.assertEqual(u'佐藤', response.data.decode('utf-8'))
class UserAPITestCase(unittest.TestCase):
    """Tests for the login/logout views, a restricted view and UserAPI deletion.

    Uses ``self.assert*`` methods rather than bare ``assert`` statements, so
    the checks still run under ``python -O``. Response bodies are compared
    against bytes literals because ``response.data`` is bytes on Python 3
    (a ``str`` in ``bytes`` membership test raises TypeError there).
    """

    def __init__(self, *args, **kwargs):
        # The app and client are built when the TestCase instance is
        # constructed, i.e. before setUp runs.
        self.app = Flask(__name__)
        self.app.secret_key = 'test_key'
        # add an index rule because some views redirect to index
        self.app.add_url_rule(
            '/',
            view_func=TestView.as_view('index', template='index.html'),
            methods=['GET', 'POST'])
        self.app.add_url_rule(
            '/restricted/',
            view_func=RestrictedView.as_view(
                'restricted', template='index.html'),
            methods=['GET'])
        self.app.add_url_rule('/logout/', view_func=LoginAPI.logout)
        self.app.add_url_rule(
            '/login/',
            view_func=LoginAPI.as_view('login'),
            methods=['GET', 'POST'])
        self.app.add_url_rule(
            '/delete/<user_id>/',
            view_func=UserAPI.as_view('user_api'),
            methods=['DELETE'])
        self.test_client = self.app.test_client()
        self.user = None
        super(UserAPITestCase, self).__init__(*args, **kwargs)

    def setUp(self):
        # Create the user record the tests operate on.
        self.user = DatabaseTestCase.add_user()

    def tearDown(self):
        # Presumably removes the user added in setUp; its return value
        # replaces self.user (kept as in the original).
        self.user = DatabaseTestCase.delete_user()

    def login(self, password):
        """POST the login form with ``password`` and follow redirects.

        Returns the final test-client response.
        """
        return self.test_client.post('/login/', data=dict(
            username='******',
            password=password
        ), follow_redirects=True)

    def test_delete_user(self):
        '''test setting a user to inactive through UserAPI.
        ensures that deletion only works when user is logged in.'''
        with self.app.test_client() as c:
            # first check when user isn't logged in: the user stays active
            c.delete('/delete/{0}/'.format(self.user.id))
            self.assertTrue(User.objects(username='******')[0].active)
            # then log in by writing the username into the session directly
            with c.session_transaction() as sess:
                sess['username'] = '******'
            c.delete('/delete/{0}/'.format(self.user.id))
            self.assertFalse(User.objects(username='******')[0].active)

    def test_login(self):
        # good password works
        self.assertIn(b'OK', self.login('test_password').data)
        # bad password doesn't
        self.assertIn(b'error', self.login('bad_password').data)
        self.user.active = False
        self.user.save()
        # inactive users can't log in
        self.assertIn(b'error', self.login('test_password').data)

    def test_user_logout(self):
        rv = self.test_client.get('/logout/', follow_redirects=True)
        self.assertIn(b'OK', rv.data)

    def test_restricted_view(self):
        rv = self.test_client.get('/restricted/', follow_redirects=True)
        self.assertIn(b'not logged in', rv.data)
class CompressionAlgoTests(unittest.TestCase):
    """
    Test different scenarios for compression algorithm negotiation between
    client and server. Please note that algorithm names (even the "supported"
    ones) in these tests **do not** indicate that all of these are actually
    supported by this extension.
    """

    def setUp(self):
        super(CompressionAlgoTests, self).setUp()

        # Create the app here but don't call `Compress()` on it just yet; we need
        # to be able to modify the settings in various tests. Calling `Compress(self.app)`
        # twice would result in two `@after_request` handlers, which would be bad.
        self.app = Flask(__name__)
        self.app.testing = True

        small_path = os.path.join(os.getcwd(), 'tests', 'templates', 'small.html')
        # NOTE(review): the -1 presumably accounts for a trailing newline dropped
        # when the template is rendered — confirm against the template file.
        self.small_size = os.path.getsize(small_path) - 1

        @self.app.route('/small/')
        def small():
            return render_template('small.html')

    def _choose(self, accept_encoding, server_algorithms):
        """Configure the server algorithms, register Compress and negotiate.

        :param accept_encoding: raw value of the client's Accept-Encoding header
        :param server_algorithms: value stored in ``COMPRESS_ALGORITHM``
        :return: the algorithm chosen by ``_choose_compress_algorithm``
        """
        self.app.config['COMPRESS_ALGORITHM'] = server_algorithms
        c = Compress(self.app)
        return c._choose_compress_algorithm(accept_encoding)

    def test_setting_compress_algorithm_simple_string(self):
        """ Test that a single entry in `COMPRESS_ALGORITHM` still works for backwards compatibility """
        self.app.config['COMPRESS_ALGORITHM'] = 'gzip'
        c = Compress(self.app)
        self.assertListEqual(c.enabled_algorithms, ['gzip'])

    def test_setting_compress_algorithm_cs_string(self):
        """ Test that `COMPRESS_ALGORITHM` can be a comma-separated string """
        self.app.config['COMPRESS_ALGORITHM'] = 'gzip, br, zstd'
        c = Compress(self.app)
        self.assertListEqual(c.enabled_algorithms, ['gzip', 'br', 'zstd'])

    def test_setting_compress_algorithm_list(self):
        """ Test that `COMPRESS_ALGORITHM` can be a list of strings """
        self.app.config['COMPRESS_ALGORITHM'] = ['gzip', 'br', 'deflate']
        c = Compress(self.app)
        self.assertListEqual(c.enabled_algorithms, ['gzip', 'br', 'deflate'])

    def test_one_algo_supported(self):
        """ Tests requesting a single supported compression algorithm """
        self.assertEqual(self._choose('gzip', ['br', 'gzip']), 'gzip')

    def test_one_algo_unsupported(self):
        """ Tests requesting a single unsupported compression algorithm """
        self.assertIsNone(self._choose('some-alien-algorithm', ['br', 'gzip']))

    def test_multiple_algos_supported(self):
        """ Tests requesting multiple supported compression algorithms """
        # When the decision is tied, we expect to see the first server-configured algorithm
        self.assertEqual(
            self._choose('br, gzip, zstd', ['zstd', 'br', 'gzip']), 'zstd')

    def test_multiple_algos_unsupported(self):
        """ Tests requesting multiple unsupported compression algorithms """
        self.assertIsNone(self._choose(
            'future-algo, alien-algo, forbidden-algo', ['zstd', 'br', 'gzip']))

    def test_multiple_algos_with_wildcard(self):
        """ Tests requesting multiple unsupported compression algorithms and a wildcard """
        # We expect to see the first server-configured algorithm
        self.assertEqual(
            self._choose('future-algo, alien-algo, forbidden-algo, *',
                         ['zstd', 'br', 'gzip']),
            'zstd')

    def test_multiple_algos_with_different_quality(self):
        """ Tests requesting multiple supported compression algorithms with different q-factors """
        self.assertEqual(
            self._choose('zstd;q=0.8, br;q=0.9, gzip;q=0.5',
                         ['zstd', 'br', 'gzip']),
            'br')

    def test_multiple_algos_with_equal_quality(self):
        """ Tests requesting multiple supported compression algorithms with equal q-factors """
        # We expect to see the first server-configured algorithm
        self.assertEqual(
            self._choose('zstd;q=0.5, br;q=0.5, gzip;q=0.5',
                         ['gzip', 'br', 'zstd']),
            'gzip')

    def test_default_quality_is_1(self):
        """ Tests that when making mixed-quality requests, the default q-factor is 1.0 """
        self.assertEqual(
            self._choose('deflate, br;q=0.999, gzip;q=0.5',
                         ['gzip', 'br', 'deflate']),
            'deflate')

    def test_default_wildcard_quality_is_0(self):
        """ Tests that a wildcard has a default q-factor of 0.0 """
        self.assertEqual(
            self._choose('br;q=0.001, *', ['gzip', 'br', 'deflate']), 'br')

    def test_wildcard_quality(self):
        """ Tests that a wildcard with q=0 is discarded """
        self.assertIsNone(self._choose('*;q=0', ['gzip', 'br', 'deflate']))

    def test_identity(self):
        """ Tests that identity is understood """
        self.assertIsNone(self._choose(
            'identity;q=1, br;q=0.5, *;q=0', ['gzip', 'br', 'deflate']))

    def test_chrome_ranged_requests(self):
        """ Tests that Chrome ranged requests behave as expected """
        self.assertIsNone(self._choose(
            'identity;q=1, *;q=0', ['gzip', 'br', 'deflate']))

    def test_content_encoding_is_correct(self):
        """ Test that the `Content-Encoding` header matches the compression algorithm """
        self.app.config['COMPRESS_ALGORITHM'] = ['br', 'gzip', 'deflate']
        Compress(self.app)

        client = self.app.test_client()
        for algorithm in ('gzip', 'br', 'deflate'):
            response = client.options(
                '/small/', headers=[('Accept-Encoding', algorithm)])
            self.assertIn('Content-Encoding', response.headers)
            self.assertEqual(response.headers.get('Content-Encoding'), algorithm)
class UrlTests(unittest.TestCase):
    """End-to-end compression tests against small and large template routes."""

    def setUp(self):
        self.app = Flask(__name__)
        self.app.testing = True

        small_path = os.path.join(os.getcwd(), 'tests', 'templates', 'small.html')
        large_path = os.path.join(os.getcwd(), 'tests', 'templates', 'large.html')
        # NOTE(review): the -1 presumably accounts for a trailing newline dropped
        # when the template is rendered — confirm against the template files.
        self.small_size = os.path.getsize(small_path) - 1
        self.large_size = os.path.getsize(large_path) - 1

        Compress(self.app)

        @self.app.route('/small/')
        def small():
            return render_template('small.html')

        @self.app.route('/large/')
        def large():
            return render_template('large.html')

    def client_get(self, ufs):
        """GET *ufs* with ``Accept-Encoding: gzip`` and assert a 200 response."""
        client = self.app.test_client()
        response = client.get(ufs, headers=[('Accept-Encoding', 'gzip')])
        self.assertEqual(response.status_code, 200)
        return response

    def _large_response_size(self, encoding):
        """Return the body length of GET /large/ sent with the given Accept-Encoding."""
        client = self.app.test_client()
        response = client.get('/large/', headers=[('Accept-Encoding', encoding)])
        return len(response.data)

    def test_br_algorithm(self):
        client = self.app.test_client()
        headers = [('Accept-Encoding', 'br')]

        response = client.options('/small/', headers=headers)
        self.assertEqual(response.status_code, 200)

        response = client.options('/large/', headers=headers)
        self.assertEqual(response.status_code, 200)

    def test_compress_min_size(self):
        """ Tests COMPRESS_MIN_SIZE correctly affects response data. """
        response = self.client_get('/small/')
        self.assertEqual(self.small_size, len(response.data))

        response = self.client_get('/large/')
        self.assertNotEqual(self.large_size, len(response.data))

    def test_mimetype_mismatch(self):
        """ Tests if mimetype not in COMPRESS_MIMETYPES. """
        response = self.client_get('/static/1.png')
        self.assertEqual(response.mimetype, 'image/png')

    def test_content_length_options(self):
        client = self.app.test_client()
        headers = [('Accept-Encoding', 'gzip')]
        response = client.options('/small/', headers=headers)
        self.assertEqual(response.status_code, 200)

    def test_gzip_compression_level(self):
        """ Tests COMPRESS_LEVEL correctly affects response data. """
        self.app.config['COMPRESS_LEVEL'] = 1
        response1_size = self._large_response_size('gzip')

        self.app.config['COMPRESS_LEVEL'] = 6
        response6_size = self._large_response_size('gzip')

        self.assertNotEqual(response1_size, response6_size)

    def test_br_compression_level(self):
        """ Tests that COMPRESS_BR_LEVEL correctly affects response data. """
        self.app.config['COMPRESS_BR_LEVEL'] = 4
        response4_size = self._large_response_size('br')

        self.app.config['COMPRESS_BR_LEVEL'] = 11
        response11_size = self._large_response_size('br')

        self.assertNotEqual(response4_size, response11_size)

    def test_deflate_compression_level(self):
        """ Tests COMPRESS_DEFLATE_LEVEL correctly affects response data. """
        self.app.config['COMPRESS_DEFLATE_LEVEL'] = -1
        response_size = self._large_response_size('deflate')

        self.app.config['COMPRESS_DEFLATE_LEVEL'] = 1
        response1_size = self._large_response_size('deflate')

        self.assertNotEqual(response_size, response1_size)
class FlstatsTestCase(unittest.TestCase):
    """Tests for the flstats statistics blueprint and ``@statistics`` decorator."""

    def setUp(self):
        """Creates a Flask test app and registers two routes as well as the
        flstats blueprint.
        """
        self.app = Flask(__name__)
        self.app.register_blueprint(webstatistics)
        self.client = self.app.test_client()

        # Flask views must return str/bytes/Response (or a tuple thereof); a
        # bare int is rejected when Flask builds the response.
        @self.app.route('/url1')
        @statistics
        def url1():
            return str(random.randint(0, 1000))

        @self.app.route('/url2')
        @statistics
        def url2():
            return str(random.randint(0, 1000))

    def _get_stats(self):
        """Wait briefly, fetch /flstats/ and return its decoded 'stats' list."""
        # We make sure data processing is complete
        sleep(0.1)
        response = self.client.get('/flstats/')
        self.assertEqual(response.status, '200 OK')
        return json.loads(response.data.decode('utf8'))['stats']

    def test_url1(self):
        """Test with one URL only"""
        self.client.get('/url1')

        stats = self._get_stats()
        self.assertEqual(len(stats), 1)
        stat = stats.pop()
        self.assertEqual(stat['url'], 'http://localhost/url1')
        self.assertEqual(stat['throughput'], 1)
        self.assertTrue(stat['min'] == stat['avg'] == stat['max'])

        for _ in range(9):
            self.client.get('/url1')

        stats = self._get_stats()
        self.assertEqual(len(stats), 1)
        stat = stats.pop()
        self.assertEqual(stat['url'], 'http://localhost/url1')
        self.assertEqual(stat['throughput'], 9)
        self.assertTrue(stat['min'] <= stat['avg'] <= stat['max'])

    def test_url2(self):
        """Test with two URLs"""
        # NOTE(review): only /url2 is requested here, yet two entries are
        # expected; this appears to rely on /url1 statistics surviving from
        # test_url1 (shared state across app instances) — confirm.
        for _ in range(20):
            self.client.get('/url2')

        stats = self._get_stats()
        self.assertEqual(len(stats), 2)
        for stat in stats:
            if stat['url'] == 'http://localhost/url1':
                self.assertEqual(stat['throughput'], 0)
                self.assertTrue(stat['min'] <= stat['avg'] <= stat['max'])
            elif stat['url'] == 'http://localhost/url2':
                self.assertEqual(stat['throughput'], 20)
                self.assertTrue(stat['min'] <= stat['avg'] <= stat['max'])
            else:
                self.fail('Invalid URL, WTF?!')
class TestViews(TestCase):
    """Tests for the scheduler REST views exposed under /scheduler/jobs."""

    def setUp(self):
        self.app = Flask(__name__)
        self.app.config['SCHEDULER_VIEWS_ENABLED'] = True
        self.scheduler = APScheduler(app=self.app)
        self.scheduler.start()
        self.client = self.app.test_client()

    def tearDown(self):
        # The scheduler started in setUp must be stopped, otherwise every test
        # leaves a running scheduler behind.
        self.scheduler.shutdown()
        super(TestViews, self).tearDown()

    def test_add_job(self):
        job = {
            'id': 'job1',
            'func': 'test_views:job1',
            'trigger': 'date',
            'run_date': '2020-12-01T12:30:01+00:00',
        }

        response = self.client.post('/scheduler/jobs', data=json.dumps(job))
        self.assertEqual(response.status_code, 200)

        job2 = json.loads(response.get_data(as_text=True))

        self.assertEqual(job.get('id'), job2.get('id'))
        self.assertEqual(job.get('func'), job2.get('func'))
        self.assertEqual(job.get('trigger'), job2.get('trigger'))
        self.assertEqual(job.get('run_date'), job2.get('run_date'))

    def test_delete_job(self):
        self.__add_job()

        response = self.client.delete('/scheduler/jobs/job1')
        self.assertEqual(response.status_code, 204)

        response = self.client.get('/scheduler/jobs/job1')
        self.assertEqual(response.status_code, 404)

    def test_get_job(self):
        job = self.__add_job()

        response = self.client.get('/scheduler/jobs/job1')
        self.assertEqual(response.status_code, 200)

        job2 = json.loads(response.get_data(as_text=True))

        self.assertEqual(job.get('id'), job2.get('id'))
        self.assertEqual(job.get('func'), job2.get('func'))
        self.assertEqual(job.get('trigger'), job2.get('trigger'))
        self.assertEqual(job.get('minutes'), job2.get('minutes'))

    def test_get_all_jobs(self):
        job = self.__add_job()

        response = self.client.get('/scheduler/jobs')
        self.assertEqual(response.status_code, 200)

        jobs = json.loads(response.get_data(as_text=True))

        self.assertEqual(len(jobs), 1)

        job2 = jobs[0]

        self.assertEqual(job.get('id'), job2.get('id'))
        self.assertEqual(job.get('func'), job2.get('func'))
        self.assertEqual(job.get('trigger'), job2.get('trigger'))
        self.assertEqual(job.get('minutes'), job2.get('minutes'))

    def test_update_job(self):
        job = self.__add_job()

        data_to_update = {
            'args': [1]
        }

        response = self.client.patch('/scheduler/jobs/job1', data=json.dumps(data_to_update))
        self.assertEqual(response.status_code, 200)

        job2 = json.loads(response.get_data(as_text=True))

        self.assertEqual(job.get('id'), job2.get('id'))
        self.assertEqual(job.get('func'), job2.get('func'))
        self.assertEqual(data_to_update.get('args'), job2.get('args'))
        self.assertEqual(job.get('trigger'), job2.get('trigger'))
        self.assertEqual(job.get('minutes'), job2.get('minutes'))

    def test_pause_and_resume_job(self):
        self.__add_job()

        response = self.client.post('/scheduler/jobs/job1/pause')
        self.assertEqual(response.status_code, 200)
        job = json.loads(response.get_data(as_text=True))
        self.assertIsNone(job.get('next_run_time'))

        response = self.client.post('/scheduler/jobs/job1/resume')
        self.assertEqual(response.status_code, 200)
        job = json.loads(response.get_data(as_text=True))
        self.assertIsNotNone(job.get('next_run_time'))

    def __add_job(self):
        """Add an interval job 'job1' through the API and return the decoded response.

        Asserts that the POST succeeded so later failures point at the
        real cause rather than a missing job.
        """
        job = {
            'id': 'job1',
            'func': 'test_views:job1',
            'trigger': 'interval',
            'minutes': 10,
        }

        response = self.client.post('/scheduler/jobs', data=json.dumps(job))
        self.assertEqual(response.status_code, 200)
        return json.loads(response.get_data(as_text=True))