示例#1
3
def test_gravatar_url():
    """Check the URLs gravatar_url() builds for one (redacted) address.

    Covers debug mode with and without an explicit size, non-debug mode, and
    non-debug mode once ``SITE_URL`` is configured, where the ``d`` query
    parameter switches from ``mm`` to the URL-encoded default-avatar.png
    path under that site.
    """
    flask_app = Flask(__name__)
    address = "*****@*****.**"

    with flask_app.test_request_context("/"):
        flask_app.debug = True
        eq_(gravatar_url(address),
            "http://www.gravatar.com/avatar/" "b642b4217b34b1e8d3bd915fc65c4452?d=mm")
        # An explicit size shows up as the "s" query parameter.
        eq_(gravatar_url(address, 200),
            "http://www.gravatar.com/avatar/" "b642b4217b34b1e8d3bd915fc65c4452?s=200&d=mm")

        flask_app.debug = False
        eq_(gravatar_url(address),
            "http://www.gravatar.com/avatar/" "b642b4217b34b1e8d3bd915fc65c4452?d=mm")

        # Once SITE_URL is set, "d" points at the site's default avatar.
        flask_app.config["SITE_URL"] = "http://www.site.com"
        eq_(gravatar_url(address),
            "http://www.gravatar.com/avatar/"
            "b642b4217b34b1e8d3bd915fc65c4452"
            "?d=http%3A%2F%2Fwww.site.com%2Fstatic%2Fimg%2Fdefault-avatar.png")
示例#2
1
def test_context_local():
    """`github` must proxy to the blueprint session of whichever app is active."""
    responses.add(responses.GET, "https://google.com")

    # Two independent apps, each with its own blueprint and access token.
    first_app = Flask(__name__)
    first_bp = make_github_blueprint("foo1", "bar1", redirect_to="url1")
    first_app.register_blueprint(first_bp)
    first_bp.token_getter(lambda: {"access_token": "app1"})

    second_app = Flask(__name__)
    second_bp = make_github_blueprint("foo2", "bar2", redirect_to="url2")
    second_app.register_blueprint(second_bp)
    second_bp.token_getter(lambda: {"access_token": "app2"})

    # Without a request context the proxy cannot resolve and raises.
    with pytest.raises(RuntimeError):
        github.get("https://google.com")

    # Inside each app's request context the outgoing call must carry that
    # app's token; responses.calls records the calls in order.
    expectations = [(first_app, "Bearer app1"), (second_app, "Bearer app2")]
    for call_index, (flask_app, expected_auth) in enumerate(expectations):
        with flask_app.test_request_context("/"):
            flask_app.preprocess_request()
            github.get("https://google.com")
            sent = responses.calls[call_index].request
            assert sent.headers["Authorization"] == expected_auth
示例#3
1
def test_context_local():
    """`jira` must sign requests with the consumer key of the active app."""
    responses.add(responses.GET, "https://google.com")

    # Two independent apps, each with its own Jira blueprint and credentials.
    app1 = Flask(__name__)
    jbp1 = make_jira_blueprint(
        "https://t1.atlassian.com", "foo1", "bar1", redirect_to="url1"
    )
    app1.register_blueprint(jbp1)

    app2 = Flask(__name__)
    jbp2 = make_jira_blueprint(
        "https://t2.atlassian.com", "foo2", "bar2", redirect_to="url2"
    )
    app2.register_blueprint(jbp2)

    # Without a request context the proxy cannot resolve and raises.
    with pytest.raises(RuntimeError):
        jira.get("https://google.com")

    # (app, index into responses.calls, expected consumer key, expected signature)
    cases = [(app1, 0, "foo1", "sig1"), (app2, 1, "foo2", "sig2")]
    for flask_app, call_index, consumer_key, signature in cases:
        with flask_app.test_request_context("/"):
            # Pin each blueprint's signature so the header is predictable.
            jbp1.session.auth.client.get_oauth_signature = mock.Mock(return_value="sig1")
            jbp2.session.auth.client.get_oauth_signature = mock.Mock(return_value="sig2")

            flask_app.preprocess_request()
            jira.get("https://google.com")
            raw_header = responses.calls[call_index].request.headers["Authorization"].decode("utf-8")
            auth_header = dict(parse_authorization_header(raw_header))
            assert auth_header["oauth_consumer_key"] == consumer_key
            assert auth_header["oauth_signature"] == signature
示例#4
0
def gen_app(config):
    """Generate a fresh app.

    Creates a testing Flask app named ``'testapp'``, applies ``config``,
    initialises the extensions below, registers the OAuth client and settings
    blueprints, creates all database tables, pushes a test request context,
    and seeds three active users through the invenio-accounts datastore.

    :param config: Mapping of configuration overrides, passed to
        ``app.config.update`` as keyword arguments.
    :returns: The configured Flask application.
    """
    app = Flask('testapp')
    app.testing = True
    # Caller-supplied settings are applied after TESTING, so they can
    # override it.
    app.config.update(**config)

    # Extension initialisation. NOTE(review): kept in this exact order on the
    # assumption that later extensions depend on earlier ones; confirm
    # before reordering.
    FlaskCLI(app)
    FlaskMenu(app)
    Babel(app)
    Mail(app)
    InvenioDB(app)
    InvenioAccounts(app)
    FlaskOAuth(app)
    InvenioOAuthClient(app)

    app.register_blueprint(blueprint_client)
    app.register_blueprint(blueprint_settings)

    # Tables must exist before the users below are created.
    with app.app_context():
        db.create_all()

    # NOTE(review): this request context is pushed and never popped here.
    # Presumably the tests rely on it staying active; confirm that the
    # caller cleans it up.
    app.test_request_context().push()

    datastore = app.extensions['invenio-accounts'].datastore

    # Seed three active users (addresses/passwords are redacted in this copy).
    datastore.create_user(
        email="*****@*****.**", password='******', active=True)
    datastore.create_user(
        email="*****@*****.**", password='******', active=True)
    datastore.create_user(
        email="*****@*****.**", password='******', active=True)
    datastore.commit()

    return app
示例#5
0
文件: test_api.py 项目: CeBkCn/dobot
    def test_handle_smart_errors(self):
        """404 errors suggest a near-miss route until ERROR_404_HELP is False.

        An unknown path with no close match gets the plain NotFound message.
        A case-mismatched path ("/fOo") gets a "did you mean /foo ?" hint.
        Setting ERROR_404_HELP to False suppresses that hint again.
        """
        app = Flask(__name__)
        api = flask_restful.Api(app)
        view = flask_restful.Resource

        api.add_resource(view, '/foo', endpoint='bor')
        api.add_resource(view, '/fee', endpoint='bir')
        api.add_resource(view, '/fii', endpoint='ber')

        # assertEqual/assertIn are used instead of the deprecated assertEquals
        # alias (removed in Python 3.12) and assertTrue(x in y), which gives
        # no useful failure message.
        with app.test_request_context("/faaaaa"):
            resp = api.handle_error(NotFound())
            self.assertEqual(resp.status_code, 404)
            self.assertEqual(resp.data.decode(), dumps({
                "message": NotFound.description,
            }) + "\n")

        with app.test_request_context("/fOo"):
            resp = api.handle_error(NotFound())
            self.assertEqual(resp.status_code, 404)
            self.assertIn('did you mean /foo ?', resp.data.decode())

        app.config['ERROR_404_HELP'] = False

        with app.test_request_context("/fOo"):
            resp = api.handle_error(NotFound())
            self.assertEqual(resp.status_code, 404)
            self.assertEqual(resp.data.decode(), dumps({
                "message": NotFound.description
            }) + "\n")
class PluginManagerGetPlugins(unittest.TestCase):
    """Tests for the plugin lookup helpers backed by a PluginManager.

    Uses assertEqual rather than the deprecated assertEquals alias, which
    was removed in Python 3.12.
    """

    def setUp(self):
        self.app = Flask(__name__)
        self.app.config['TESTING'] = True
        self.plugin_manager = PluginManager()
        self.plugin_manager.init_app(self.app)

    def test_get_enabled_plugins(self):
        with self.app.test_request_context():
            plugins = get_enabled_plugins()

        # Compared as sets: only membership matters, not order.
        self.assertEqual(
            set(plugins),
            set(self.plugin_manager.plugins.values())
        )

    def test_get_all_plugins(self):
        with self.app.test_request_context():
            plugins = get_all_plugins()

        self.assertEqual(len(plugins), 3)

    def test_get_plugin(self):
        with self.app.test_request_context():
            plugin = get_plugin("test1")

        self.assertEqual(plugin, self.plugin_manager.plugins["test1"])

    def test_get_plugin_from_all(self):
        # "test3" is expected to be found even though it is disabled.
        with self.app.test_request_context():
            plugin = get_plugin_from_all("test3")

        self.assertFalse(plugin.enabled)
示例#7
0
def test_sentry6():
    """Sentry 6 reports no user info anonymously, and id + SENTRY_USER_ATTRS after login."""
    from invenio_logging.sentry import InvenioLoggingSentry

    flask_app = Flask('testapp')
    flask_app.config.update(
        SENTRY_DSN='http://*****:*****@localhost/0',
        LOGGING_SENTRY_CLASS='invenio_logging.sentry6:Sentry6',
        SENTRY_USER_ATTRS=['name'],
        SECRET_KEY='CHANGEME',
    )
    InvenioLoggingSentry(flask_app)
    LoginManager(flask_app)

    class DummyUser(UserMixin):
        def __init__(self, user_id, name):
            self.id = user_id
            self.name = name

    sentry = flask_app.extensions['sentry']

    # Anonymous request: nothing to report.
    with flask_app.test_request_context('/'):
        assert sentry.get_user_info(request) == {}

    # Logged-in request: the id is reported as a string, plus 'name'.
    expected = {'id': '1', 'name': 'viggo'}
    with flask_app.test_request_context('/'):
        login_user(DummyUser(1, 'viggo'))
        assert sentry.get_user_info(request) == expected
示例#8
0
class SearchTestBase(unittest.TestCase):
    """Base class for search tests: a fresh app, database and index per test.

    Subclasses must assign a model class to ``self.Post`` before calling
    :meth:`init_data`.
    """

    def setUp(self):
        class TestConfig(object):
            SQLALCHEMY_TRACK_MODIFICATIONS = True
            SQLALCHEMY_DATABASE_URI = 'sqlite://'
            DEBUG = True
            TESTING = True
            # A new temporary directory per test; removed again in tearDown.
            MSEARCH_INDEX_NAME = mkdtemp()
            # MSEARCH_BACKEND = 'whoosh'

        self.app = Flask(__name__)
        self.app.config.from_object(TestConfig())
        # we need this instance to be:
        #  a) global for all objects we share and
        #  b) fresh for every test run
        global db
        db = SQLAlchemy(self.app)
        self.search = Search(self.app, db=db)
        self.Post = None

    def init_data(self):
        """Create the tables and save one ``self.Post`` per entry in ``titles``.

        ``titles`` is defined outside this class.
        """
        if self.Post is None:
            self.fail('Post class not defined')
        with self.app.test_request_context():
            db.create_all()
            for (i, title) in enumerate(titles, 1):
                post = self.Post(title=title, content='content%d' % i)
                post.save()

    def tearDown(self):
        import shutil

        with self.app.test_request_context():
            db.drop_all()
            db.metadata.clear()
        # setUp calls mkdtemp() once per test. Without this, every run leaks
        # an index directory into the system temp dir.
        shutil.rmtree(self.app.config['MSEARCH_INDEX_NAME'], ignore_errors=True)
示例#9
0
    def test_handle_smart_errors(self):
        """handle_error() adds a "did you mean" hint for near-miss 404 paths.

        Without a close route the exception's data is returned unchanged.
        For "/fOo" the hint is appended to the existing message, or becomes
        the whole message when the exception data has none.
        """
        app = Flask(__name__)
        api = flask_restful.Api(app)
        view = flask_restful.Resource

        exception = Mock()
        exception.code = 404
        exception.data = {"status": 404, "message": "Not Found"}
        api.add_resource(view, '/foo', endpoint='bor')
        api.add_resource(view, '/fee', endpoint='bir')
        api.add_resource(view, '/fii', endpoint='ber')

        # assertEqual instead of the deprecated assertEquals alias, which
        # was removed in Python 3.12.
        with app.test_request_context("/faaaaa"):
            resp = api.handle_error(exception)
            self.assertEqual(resp.status_code, 404)
            self.assertEqual(resp.data, dumps({
                "status": 404, "message": "Not Found",
            }))

        with app.test_request_context("/fOo"):
            resp = api.handle_error(exception)
            self.assertEqual(resp.status_code, 404)
            self.assertEqual(resp.data, dumps({
                "status": 404, "message": "Not Found. You have requested this URI [/fOo] but did you mean /foo ?",
            }))

        with app.test_request_context("/fOo"):
            # Drop the message so the hint alone must form it.
            del exception.data["message"]
            resp = api.handle_error(exception)
            self.assertEqual(resp.status_code, 404)
            self.assertEqual(resp.data, dumps({
                "status": 404, "message": "You have requested this URI [/fOo] but did you mean /foo ?",
            }))
    def test_heartbeat(self):
        """heartbeat() passes on its own and reports version/releaseID from the version file."""
        app = Flask(__name__)
        controller = HeartbeatController()

        with app.test_request_context('/'):
            response = controller.heartbeat()
        eq_(200, response.status_code)
        eq_(controller.HEALTH_CHECK_TYPE, response.headers.get('Content-Type'))
        data = json.loads(response.data)
        eq_('pass', data['status'])

        # The version file (controller.VERSION_FILENAME) lives two directories
        # above this test module.
        root_dir = os.path.join(os.path.split(__file__)[0], "..", "..")
        version_filename = os.path.join(root_dir, controller.VERSION_FILENAME)

        # Create a mock configuration object to test with.
        class MockConfiguration(Configuration):
            instance = dict()

        # Previously the file was removed only if heartbeat() succeeded. The
        # try/finally guarantees cleanup, so a failure cannot leave it behind
        # for later tests or runs.
        try:
            with open(version_filename, 'w') as f:
                f.write('ba.na.na-10-ssssssssss')
            with app.test_request_context('/'):
                response = controller.heartbeat(conf_class=MockConfiguration)
        finally:
            if os.path.exists(version_filename):
                os.remove(version_filename)

        eq_(200, response.status_code)
        content_type = response.headers.get('Content-Type')
        eq_(controller.HEALTH_CHECK_TYPE, content_type)

        data = json.loads(response.data)
        eq_('pass', data['status'])
        eq_('ba.na.na', data['version'])
        eq_('ba.na.na-10-ssssssssss', data['releaseID'])
示例#11
0
def test_create_filter_dsl():
    """_create_filter_dsl turns query args that have a definition into filters.

    Defined keys yield filters (two here) and their values are merged after
    the caller's kwargs. Undefined keys yield no filters, and the returned
    args compare equal to the kwargs passed in.
    """
    flask_app = Flask('testapp')
    extra = MultiDict([('a', '1')])
    defs = dict(
        type=terms_filter('type.type'),
        subtype=terms_filter('type.subtype'),
    )

    query = u'?type=a&type=b&subtype=c&type=zażółcić'
    with flask_app.test_request_context(query):
        filters, args = _create_filter_dsl(extra, defs)
        assert len(filters) == 2
        # Repeated keys and their original order (including non-ASCII
        # values) must be preserved.
        assert args == MultiDict([
            ('a', u'1'),
            ('type', u'a'),
            ('type', u'b'),
            ('subtype', u'c'),
            ('type', u'zażółcić')
        ])

    extra = MultiDict([('a', '1')])
    with flask_app.test_request_context('?atype=a&atype=b'):
        filters, args = _create_filter_dsl(extra, defs)
        assert not filters
        assert args == extra
 def test_non_blueprint_rest_error_routing(self):
     """Each Api claims routes (and their errors) only within its own URL space."""
     blueprint = Blueprint('test', __name__)
     bp_api = flask_restful.Api(blueprint)
     bp_api.add_resource(HelloWorld(), '/hi', endpoint="hello")
     bp_api.add_resource(GoodbyeWorld(404), '/bye', endpoint="bye")
     app = Flask(__name__)
     app.register_blueprint(blueprint, url_prefix='/blueprint')
     app_api = flask_restful.Api(app)
     app_api.add_resource(HelloWorld(), '/hi', endpoint="hello")
     app_api.add_resource(GoodbyeWorld(404), '/bye', endpoint="bye")

     # For each path, the blueprint Api answers "yes" exactly when the path
     # lives under /blueprint, and the app-level Api answers the opposite.
     for path, bp_owns in (('/hi', False), ('/blueprint/hi', True)):
         if bp_owns:
             bp_check, app_check = assert_true, assert_false
         else:
             bp_check, app_check = assert_false, assert_true
         with app.test_request_context(path, method='POST'):
             bp_check(bp_api._should_use_fr_error_handler())
             app_check(app_api._should_use_fr_error_handler())
             bp_check(bp_api._has_fr_route())
             app_check(app_api._has_fr_route())

     # With the error-handler check stubbed out, _has_fr_route must still
     # tell the two Apis apart.
     bp_api._should_use_fr_error_handler = Mock(return_value=False)
     app_api._should_use_fr_error_handler = Mock(return_value=False)
     for path, bp_owns in (('/bye', False), ('/blueprint/bye', True)):
         if bp_owns:
             bp_check, app_check = assert_true, assert_false
         else:
             bp_check, app_check = assert_false, assert_true
         with app.test_request_context(path):
             bp_check(bp_api._has_fr_route())
             app_check(app_api._has_fr_route())
def test_resource2():
    """Resource methods map to endpoints: get -> "hi" (/hi), post_login -> "hi@login" (/hi/login)."""
    app = Flask(__name__)
    app.debug = True
    api = Api(app)

    @api.resource(name="hi")
    class Hello(Resource):
        def get(self, *args, **kvargs):
            return "hello"

        def post_login(self, *args, **kvargs):
            return "login"

    # Commented-out pdb breakpoints and the no-op `assert True` were removed.
    client = app.test_client()
    with app.test_request_context("/hi"):
        assert url_for("hi") == "/hi"
        assert url_for("hi@login") == "/hi/login"

    with app.test_request_context("/hi/login"):
        assert request.endpoint == "hi@login"

    # NOTE(review): response .data is bytes on Python 3, so these str
    # comparisons only hold on Python 2 (or if this Api returns str). Confirm
    # the target runtime.
    assert "hello" == client.get("/hi").data
    assert "login" == client.post("/hi/login").data
示例#14
0
class LoginInSessionTestCase(unittest.TestCase):
    ''' Tests for login_user_in_session function

    assertIn/assertNotIn replace assertTrue(x in y) so that a failure
    reports the missing key and the session contents.
    '''

    def setUp(self):
        self.app = Flask(__name__)
        self.app.config['SECRET_KEY'] = 'deterministic'
        self.login_manager = LoginManager()
        self.login_manager.init_app(self.app)

        # Cooperative call instead of naming unittest.TestCase directly.
        super(LoginInSessionTestCase, self).setUp()

    def test_login_user_in_session(self):
        # A plain dict stands in for the session being written to.
        with self.app.test_request_context():
            session = {}
            login_user_in_session(session, notch)
            self.assertIn('user_id', session)
            self.assertIn('_fresh', session)
            self.assertIn('_id', session)
            # Without remember=True no 'remember' key is written.
            self.assertNotIn('remember', session)

    def test_login_user_in_session_remember(self):
        with self.app.test_request_context():
            session = {}
            login_user_in_session(session, notch, remember=True)
            self.assertIn('user_id', session)
            self.assertIn('_fresh', session)
            self.assertIn('_id', session)
            self.assertTrue(session['remember'])
示例#15
0
class TestCoasterViews(unittest.TestCase):
    """Tests for coaster's view helpers: URL helpers, jsonp and requestargs."""

    def setUp(self):
        self.app = app = Flask(__name__)
        load_config_from_file(app, "settings.py")
        app.add_url_rule('/', 'index', index)
        app.add_url_rule('/', 'external', external)

    def test_get_current_url(self):
        # get_current_url() echoes the path together with any query string.
        for url in ('/', '/?q=hasgeek'):
            with self.app.test_request_context(url):
                self.assertEqual(get_current_url(), url)

    def test_get_next_url(self):
        with self.app.test_request_context('/?next=http://example.com'):
            # An external ?next= is returned only when external=True;
            # otherwise the default ('/' unless overridden) comes back.
            self.assertEqual(get_next_url(external=True), 'http://example.com')
            self.assertEqual(get_next_url(), '/')
            self.assertEqual(get_next_url(default=()), ())

        with self.app.test_request_context('/'):
            # With session=True the target may come from session['next'].
            session['next'] = '/external'
            self.assertEqual(get_next_url(session=True), '/external')

    def test_jsonp(self):
        with self.app.test_request_context('/?callback=callback'):
            # With ?callback=..., the JSON body is wrapped in that callback.
            kwargs = {'lang': 'en-us', 'query': 'python'}
            wrapped = jsonp(**kwargs)
            expected = 'callback({\n  "%s": "%s",\n  "%s": "%s"\n});' % (
                'lang', kwargs['lang'], 'query', kwargs['query'])
            self.assertEqual(expected, wrapped.data)

        with self.app.test_request_context('/'):
            # Keyword args, a dict and a list of pairs must all serialise to
            # the same JSON object.
            param1, param2 = 1, 2
            call_forms = [
                ((), {'param1': param1, 'param2': param2}),
                (({'param1': param1, 'param2': param2},), {}),
                (([('param1', param1), ('param2', param2)],), {}),
            ]
            for positional, keywords in call_forms:
                decoded = json.loads(jsonp(*positional, **keywords).response[0])
                self.assertEqual(decoded['param1'], param1)
                self.assertEqual(decoded['param2'], param2)

    def test_requestargs(self):
        # (request URL, view, call kwargs, expected return value)
        cases = [
            ('/?p3=1&p3=2&p2=3&p1=1', f, {}, (u'1', 3, [1, 2])),
            ('/?p2=2', f, {'p1': '1'}, (u'1', 2, None)),
            ('/?p3=1&p3=2&p2=3&p1=1', f1, {}, (u'1', 3, [u'1', u'2'])),
        ]
        for url, view, call_kwargs, expected in cases:
            with self.app.test_request_context(url):
                self.assertEqual(view(**call_kwargs), expected)

        with self.app.test_request_context('/?p2=2&p4=4'):
            # An unexpected keyword argument is rejected.
            self.assertRaises(TypeError, f, p4='4')
            self.assertRaises(BadRequest, f, p4='4')
示例#16
0
    def test_endpoints(self):
        """_has_fr_route() is true only for paths matching a registered resource."""
        app = Flask(__name__)
        api = flask_restful.Api(app)
        api.add_resource(HelloWorld, '/ids/<int:id>', endpoint="hello")

        for path, should_match in (('/foo', False), ('/ids/3', True)):
            with app.test_request_context(path):
                if should_match:
                    self.assertTrue(api._has_fr_route())
                else:
                    self.assertFalse(api._has_fr_route())
示例#17
0
def test_upload_field():
    """FileUploadField saves an upload, replaces it, then deletes it on request."""
    app = Flask(__name__)

    path = _create_temp()

    def _remove_testfiles():
        for name in ('test1.txt', 'test2.txt'):
            safe_delete(path, name)

    class TestForm(form.BaseForm):
        upload = form.FileUploadField('Upload', path=path)

    class Dummy(object):
        pass

    eq_(TestForm().upload.path, path)

    _remove_testfiles()

    target = Dummy()

    def _post(data):
        # Build a POST request context carrying the given form data.
        return app.test_request_context(method='POST', data=data)

    # Upload: the file is stored and its name recorded on the object.
    with _post({'upload': (BytesIO(b'Hello World 1'), 'test1.txt')}):
        bound = TestForm(helpers.get_form_data())

        ok_(bound.validate())

        bound.populate_obj(target)

        eq_(target.upload, 'test1.txt')
        ok_(op.exists(op.join(path, 'test1.txt')))

    # Replace: the new file is stored and the previous one removed.
    with _post({'upload': (BytesIO(b'Hello World 2'), 'test2.txt')}):
        bound = TestForm(helpers.get_form_data())

        ok_(bound.validate())
        bound.populate_obj(target)

        eq_(target.upload, 'test2.txt')
        ok_(not op.exists(op.join(path, 'test1.txt')))
        ok_(op.exists(op.join(path, 'test2.txt')))

    # Delete: the checkbox clears the field and removes the file.
    with _post({'_upload-delete': 'checked'}):
        bound = TestForm(helpers.get_form_data())

        ok_(bound.validate())

        bound.populate_obj(target)
        eq_(target.upload, None)

        ok_(not op.exists(op.join(path, 'test2.txt')))
示例#18
0
def base_app(request):
    """Flask application fixture without OAuthClient initialized.

    Builds a testing app with ORCID remote-app settings and the core
    extensions (CLI, menu, Babel, Mail, DB, accounts), creates the database
    if needed plus all tables, registers a finalizer that tears the database
    down again, and leaves a test request context pushed.

    :param request: presumably pytest's fixture ``request`` object; only
        ``addfinalizer`` is used.
    :returns: The configured Flask application.
    """
    # NOTE(review): this temporary directory is removed in teardown() but is
    # never passed to Flask (e.g. as instance_path=...). Confirm whether
    # that was intended.
    instance_path = tempfile.mkdtemp()
    base_app = Flask('testapp')
    base_app.config.update(
        TESTING=True,
        WTF_CSRF_ENABLED=False,
        LOGIN_DISABLED=False,
        CACHE_TYPE='simple',
        OAUTHCLIENT_REMOTE_APPS=dict(
            orcid=REMOTE_APP,
        ),
        ORCID_APP_CREDENTIALS=dict(
            consumer_key='changeme',
            consumer_secret='changeme',
        ),
        # use local memory mailbox
        EMAIL_BACKEND='flask_email.backends.locmem.Mail',
        # Overridable from the environment; defaults to in-memory SQLite.
        SQLALCHEMY_DATABASE_URI=os.getenv('SQLALCHEMY_DATABASE_URI',
                                          'sqlite://'),
        SERVER_NAME='localhost',
        DEBUG=False,
        SECRET_KEY='TEST',
        SECURITY_PASSWORD_HASH='plaintext',
        SECURITY_PASSWORD_SCHEMES=['plaintext'],
    )
    # Extension initialisation; OAuthClient is deliberately not initialised.
    FlaskCLI(base_app)
    FlaskMenu(base_app)
    Babel(base_app)
    Mail(base_app)
    InvenioDB(base_app)
    InvenioAccounts(base_app)

    with base_app.app_context():
        # Any URL other than plain 'sqlite://' gets its database created
        # first if it does not exist yet.
        if str(db.engine.url) != 'sqlite://' and \
           not database_exists(str(db.engine.url)):
                create_database(str(db.engine.url))
        db.create_all()

    def teardown():
        # Close the session, drop the database unless it is the plain
        # 'sqlite://' URL, and remove the temporary directory.
        with base_app.app_context():
            db.session.close()
            if str(db.engine.url) != 'sqlite://':
                drop_database(str(db.engine.url))
            shutil.rmtree(instance_path)

    request.addfinalizer(teardown)

    # NOTE(review): pushed and never popped here; presumably tests rely on
    # an active request context. Confirm this is intended.
    base_app.test_request_context().push()

    return base_app
示例#19
0
class TestSelf(object):
    """Tests for the ``Self`` link: attribute filtering and serialisation."""

    def setup(self):
        self.app = Flask(__name__)

    @staticmethod
    def _expected(href):
        """Return the dict a ``Self(name='foo')`` link should serialise to."""
        return {
            'self': {
                'href': href,
                'name': 'foo',
            }
        }

    def test_only_valid_link_attrs_set(self):
        with self.app.test_request_context():
            link = Self(foo='foo', name='foo')

            # The unknown 'foo' kwarg is dropped; the known 'name' is kept.
            assert not hasattr(link, 'foo')
            assert link.name == 'foo'

    def test_to_dict(self):
        with self.app.test_request_context():
            link = Self(foo='foo', name='foo')

            assert link.to_dict() == self._expected('/')

    def test_to_json(self):
        with self.app.test_request_context():
            link = Self(foo='foo', name='foo')

            assert link.to_json() == json.dumps(self._expected('/'))

    def test_with_server_name(self):
        # With SERVER_NAME configured, the href becomes an absolute URL.
        self.app.config['SERVER_NAME'] = 'foo.com'
        with self.app.test_request_context():
            link = Self(foo='foo', name='foo')

            assert link.to_dict() == self._expected('http://foo.com/')
示例#20
0
    def test_get_request_body_args(self):
        """
        Tests that the body args are extracted from a flask request:
        a JSON body comes back as the decoded object, while a form-encoded
        body comes back with each value wrapped in a list of strings.
        """
        app = Flask('myapp')
        body = dict(x=1)

        json_ctx = app.test_request_context(
            '/', data=json.dumps(body), content_type='application/json')
        with json_ctx:
            _, parsed = get_request_query_body_args(request)
            self.assertDictEqual(parsed, body)

        with app.test_request_context('/', data=body):  # Form encoded
            _, parsed = get_request_query_body_args(request)
            self.assertDictEqual(parsed, dict(x=['1']))
示例#21
0
    def test_base_url(self):
        """
        Tests that FlaskDispatcher.base_url is the request's host URL,
        with url_prefix appended when one is configured.
        """
        app = Flask('myapp')

        plain = FlaskDispatcher(app)
        with app.test_request_context():
            self.assertEqual(plain.base_url, 'http://localhost/')

        prefixed = FlaskDispatcher(app, url_prefix='someprefix', auto_options_name='Options2')
        with app.test_request_context():
            self.assertEqual(prefixed.base_url, 'http://localhost/someprefix')
示例#22
0
class SwaggerBenchmark(Benchmark):
    '''Benchmark Swagger spec serialization for the full API, built fresh vs. read from api.__schema__'''
    times = 1000

    def before_class(self):
        flask_app = Flask(__name__)
        api.init_app(flask_app)
        self.app = flask_app

    def bench_swagger_specs(self):
        # Rebuild the spec from scratch on every iteration.
        with self.app.test_request_context('/'):
            specs = Swagger(api).as_dict()
        return specs

    def bench_swagger_specs_cached(self):
        # Read api.__schema__ (presumably cached after the first build).
        with self.app.test_request_context('/'):
            schema = api.__schema__
        return schema
class RuleTestCase(unittest.TestCase):
    """Tests for Rule construction and URL generation through process_rules."""

    def setUp(self):
        # A brand-new app per test so registered rules never leak between tests.
        self.app = Flask(__name__)

    def _make_rule(self, **kwargs):
        """Build a Rule, using a default for any argument missing from ``kwargs``."""
        def vf():
            return {}

        routes = kwargs.get('routes', ['/', ])
        methods = kwargs.get('methods', ['GET', ])
        view = kwargs.get('view_func_or_data', vf)
        renderer = kwargs.get('renderer', json_renderer)
        view_kwargs = kwargs.get('view_kwargs')
        return Rule(routes, methods, view, renderer, view_kwargs)

    def test_rule_single_route(self):
        # A bare string route ends up as a one-element list.
        rule = self._make_rule(routes='/')
        assert_equal(rule.routes, ['/', ])

    def test_rule_single_method(self):
        # A bare string method ends up as a one-element list.
        rule = self._make_rule(methods='GET')
        assert_equal(rule.methods, ['GET', ])

    def test_rule_lambda_view(self):
        rule = self._make_rule(view_func_or_data=lambda: '')
        assert_true(callable(rule.view_func_or_data))

    def test_url_for_simple(self):
        rule = Rule(['/project/'], 'get', view_func_or_data=dummy_view, renderer=json_renderer)
        process_rules(self.app, [rule])
        with self.app.test_request_context():
            assert_equal(url_for('JSONRenderer__dummy_view'), '/project/')

    def test_url_for_with_argument(self):
        rule = Rule(['/project/<pid>/'], 'get', view_func_or_data=dummy_view2, renderer=json_renderer)
        process_rules(self.app, [rule])
        with self.app.test_request_context():
            assert_equal(url_for('JSONRenderer__dummy_view2', pid=123), '/project/123/')

    def test_url_for_with_prefix(self):
        rule = Rule(['/project/'], 'get', view_func_or_data=dummy_view3,
                    renderer=json_renderer)
        process_rules(self.app, [rule], prefix='/api/v1')
        with self.app.test_request_context():
            assert_equal(url_for('JSONRenderer__dummy_view3'), '/api/v1/project/')
示例#24
0
class SecretKeyTestCase(unittest.TestCase):
    """_secret_key() yields bytes whether SECRET_KEY is bytes or a native str."""

    def setUp(self):
        self.app = Flask(__name__)

    def test_bytes(self):
        configured = b'\x9e\x8f\x14'
        self.app.config['SECRET_KEY'] = configured
        with self.app.test_request_context():
            self.assertEqual(_secret_key(), b'\x9e\x8f\x14')

    def test_native(self):
        configured = '\x9e\x8f\x14'
        self.app.config['SECRET_KEY'] = configured
        with self.app.test_request_context():
            self.assertEqual(_secret_key(), b'\x9e\x8f\x14')

    def test_default(self):
        # An explicit argument is used directly; no app context is needed.
        result = _secret_key('\x9e\x8f\x14')
        self.assertEqual(result, b'\x9e\x8f\x14')
示例#25
0
 def test_resource_resp(self):
     """dispatch_request() calls get() once and returns the Response it produced."""
     app = Flask(__name__)
     resource = flask_restful.Resource()
     resource.get = Mock()
     with app.test_request_context("/foo"):
         resource.get.return_value = flask.make_response('')
         resp = resource.dispatch_request()
     # The original test asserted nothing beyond "does not raise"; pin the
     # delegation to get() and the returned response explicitly.
     resource.get.assert_called_once_with()
     self.assertIs(resp, resource.get.return_value)
示例#26
0
    def test_url_absolute(self):
        """An absolute Url field renders a full http://localhost/... URL."""
        app = Flask(__name__)
        app.add_url_rule("/<hey>", "foobar", view_func=lambda x: x)
        field = fields.Url("foobar", absolute=True)

        with app.test_request_context("/"):
            # assertEqual: the assertEquals alias is deprecated and was
            # removed in Python 3.12.
            self.assertEqual("http://localhost/3", field.output("hey", Foo()))
示例#27
0
    def test_url_invalid_object(self):
        """Outputting a Url field for a None object raises MarshallingException."""
        app = Flask(__name__)
        app.add_url_rule("/<hey>", "foobar", view_func=lambda x: x)
        url_field = fields.Url("foobar")

        with app.test_request_context("/"):
            with self.assertRaises(MarshallingException):
                url_field.output("hey", None)
示例#28
0
文件: tests.py 项目: biner/flask-mail
class TestCase(unittest.TestCase):
    """Base test case: a Flask app configured from this class, with Mail
    attached and a test request context pushed for each test.

    The assert* helpers below delegate to unittest when it provides them and
    otherwise fall back to assertTrue/assertFalse, now passing ``msg`` along
    so a custom failure message is not lost.
    """

    TESTING = True
    MAIL_DEFAULT_SENDER = "*****@*****.**"

    def setUp(self):
        self.app = Flask(__name__)
        # Upper-case class attributes (TESTING, MAIL_DEFAULT_SENDER) become config.
        self.app.config.from_object(self)
        self.assertTrue(self.app.testing)
        self.mail = Mail(self.app)
        self.ctx = self.app.test_request_context()
        self.ctx.push()

    def tearDown(self):
        self.ctx.pop()

    def assertIn(self, member, container, msg=None):
        if hasattr(unittest.TestCase, "assertIn"):
            return unittest.TestCase.assertIn(self, member, container, msg)
        return self.assertTrue(member in container, msg)

    def assertNotIn(self, member, container, msg=None):
        if hasattr(unittest.TestCase, "assertNotIn"):
            return unittest.TestCase.assertNotIn(self, member, container, msg)
        return self.assertFalse(member in container, msg)

    def assertIsNone(self, obj, msg=None):
        if hasattr(unittest.TestCase, "assertIsNone"):
            return unittest.TestCase.assertIsNone(self, obj, msg)
        return self.assertTrue(obj is None, msg)

    def assertIsNotNone(self, obj, msg=None):
        if hasattr(unittest.TestCase, "assertIsNotNone"):
            return unittest.TestCase.assertIsNotNone(self, obj, msg)
        return self.assertTrue(obj is not None, msg)
示例#29
0
class InitializationTestCase(unittest.TestCase):
    ''' Tests both ways of initializing LoginManager: init_app() and the constructor '''

    def setUp(self):
        self.app = Flask(__name__)
        self.app.config['SECRET_KEY'] = '1234'

    def test_init_app(self):
        manager = LoginManager()
        manager.init_app(self.app, add_context_processor=True)

        self.assertIsInstance(manager, LoginManager)

    def test_class_init(self):
        manager = LoginManager(self.app, add_context_processor=True)

        self.assertIsInstance(manager, LoginManager)

    def test_login_disabled_is_set(self):
        manager = LoginManager(self.app, add_context_processor=True)
        self.assertFalse(manager._login_disabled)

    def test_no_user_loader_raises(self):
        manager = LoginManager(self.app, add_context_processor=True)
        prefix = 'No user_loader has been installed'
        with self.app.test_request_context():
            # Seed the session with a user id before reloading.
            session['user_id'] = '2'
            with self.assertRaises(Exception) as caught:
                manager.reload_user()
            self.assertTrue(str(caught.exception).startswith(prefix))
示例#30
0
    def test_media_types(self):
        """mediatypes() reflects the request's Accept header."""
        app = Flask(__name__)
        api = flask_restful.Api(app)

        # assertEqual: the assertEquals alias is deprecated and was removed
        # in Python 3.12.
        with app.test_request_context(
                "/foo", headers={'Accept': 'application/json'}):
            self.assertEqual(api.mediatypes(), ['application/json'])
示例#31
0
class TestUrlForHelpers(unittest.TestCase):
    """Tests for the api/web/waterbutler URL-building helpers.

    setUp registers dummy views under API, web and GUID-style web rules on a
    fresh Flask app so the helpers have real routes to resolve.
    """

    def setUp(self):
        def dummy_view():
            return {}

        def dummy_guid_project_view():
            return {}

        def dummy_guid_profile_view():
            return {}

        self.app = Flask(__name__)

        api_rule = Rule(['/api/v1/<pid>/', '/api/v1/<pid>/component/<nid>/'],
                        'get', dummy_view, json_renderer)
        web_rule = Rule(['/<pid>/', '/<pid>/component/<nid>/'], 'get',
                        dummy_view, OsfWebRenderer)
        web_guid_project_rule = Rule([
            '/project/<pid>/',
            '/project/<pid>/node/<nid>/',
        ], 'get', dummy_guid_project_view, OsfWebRenderer)
        web_guid_profile_rule = Rule([
            '/profile/<pid>/',
        ], 'get', dummy_guid_profile_view, OsfWebRenderer)

        process_rules(
            self.app,
            [api_rule, web_rule, web_guid_project_rule, web_guid_profile_rule])

    def test_api_url_for(self):
        """api_url_for resolves a view to its /api/v1/ route."""
        with self.app.test_request_context():
            assert api_url_for('dummy_view', pid='123') == '/api/v1/123/'

    def test_api_v2_url_with_port(self):
        """api_v2_url joins base route, prefix and path, keeping the port."""
        full_url = api_v2_url('/nodes/abcd3/contributors/',
                              base_route='http://localhost:8000/',
                              base_prefix='v2/')
        assert_equal(full_url,
                     'http://localhost:8000/v2/nodes/abcd3/contributors/')

        # Handles URL the same way whether or not user enters a leading slash
        full_url = api_v2_url('nodes/abcd3/contributors/',
                              base_route='http://localhost:8000/',
                              base_prefix='v2/')
        assert_equal(full_url,
                     'http://localhost:8000/v2/nodes/abcd3/contributors/')

    def test_api_v2_url_with_params(self):
        """Handles- and encodes- URLs with parameters (dict and kwarg) correctly"""
        full_url = api_v2_url('/nodes/abcd3/contributors/',
                              params={'filter[fullname]': 'bob'},
                              base_route='https://api.osf.io/',
                              base_prefix='v2/',
                              page_size=10)
        assert_equal(
            full_url,
            'https://api.osf.io/v2/nodes/abcd3/contributors/?filter%5Bfullname%5D=bob&page_size=10'
        )

    def test_api_v2_url_base_path(self):
        """Given a blank string, should return the base path (domain + port + prefix) with no extra cruft at end"""
        full_url = api_v2_url('',
                              base_route='http://localhost:8000/',
                              base_prefix='v2/')
        assert_equal(full_url, 'http://localhost:8000/v2/')

    def test_web_url_for(self):
        """web_url_for resolves a view to its web route."""
        with self.app.test_request_context():
            assert web_url_for('dummy_view', pid='123') == '/123/'

    def test_web_url_for_guid(self):
        """_guid=True collapses to the innermost id; otherwise the full path is used."""
        with self.app.test_request_context():
            # check /project/<pid>
            assert_equal(
                '/pid123/',
                web_url_for('dummy_guid_project_view',
                            pid='pid123',
                            _guid=True))
            assert_equal(
                '/project/pid123/',
                web_url_for('dummy_guid_project_view',
                            pid='pid123',
                            _guid=False))
            assert_equal('/project/pid123/',
                         web_url_for('dummy_guid_project_view', pid='pid123'))
            # check /project/<pid>/node/<nid>
            assert_equal(
                '/nid321/',
                web_url_for('dummy_guid_project_view',
                            pid='pid123',
                            nid='nid321',
                            _guid=True))
            assert_equal(
                '/project/pid123/node/nid321/',
                web_url_for('dummy_guid_project_view',
                            pid='pid123',
                            nid='nid321',
                            _guid=False))
            assert_equal(
                '/project/pid123/node/nid321/',
                web_url_for('dummy_guid_project_view',
                            pid='pid123',
                            nid='nid321'))
            # check /profile/<pid>
            assert_equal(
                '/pro123/',
                web_url_for('dummy_guid_profile_view',
                            pid='pro123',
                            _guid=True))
            assert_equal(
                '/profile/pro123/',
                web_url_for('dummy_guid_profile_view',
                            pid='pro123',
                            _guid=False))
            assert_equal('/profile/pro123/',
                         web_url_for('dummy_guid_profile_view', pid='pro123'))

    def test_web_url_for_guid_regex_conditions(self):
        """Short ids do not collapse to GUID URLs; ids of six characters do."""
        with self.app.test_request_context():
            # regex matches limit keys to a minimum of 5 alphanumeric characters.
            # check /project/<pid>
            assert_not_equal(
                '/123/',
                web_url_for('dummy_guid_project_view', pid='123', _guid=True))
            assert_equal(
                '/123456/',
                web_url_for('dummy_guid_project_view',
                            pid='123456',
                            _guid=True))
            # check /project/<pid>/node/<nid>
            assert_not_equal(
                '/321/',
                web_url_for('dummy_guid_project_view',
                            pid='123',
                            nid='321',
                            _guid=True))
            assert_equal(
                '/654321/',
                web_url_for('dummy_guid_project_view',
                            pid='123456',
                            nid='654321',
                            _guid=True))
            # check /profile/<pid>
            assert_not_equal(
                '/123/',
                web_url_for('dummy_guid_profile_view', pid='123', _guid=True))
            assert_equal(
                '/123456/',
                web_url_for('dummy_guid_profile_view',
                            pid='123456',
                            _guid=True))

    def test_web_url_for_guid_case_sensitive(self):
        """GUID URLs preserve the case of the id."""
        with self.app.test_request_context():
            # check /project/<pid>
            assert_equal(
                '/ABCdef/',
                web_url_for('dummy_guid_project_view',
                            pid='ABCdef',
                            _guid=True))
            # check /project/<pid>/node/<nid>
            assert_equal(
                '/GHIjkl/',
                web_url_for('dummy_guid_project_view',
                            pid='ABCdef',
                            nid='GHIjkl',
                            _guid=True))
            # check /profile/<pid>
            assert_equal(
                '/MNOpqr/',
                web_url_for('dummy_guid_profile_view',
                            pid='MNOpqr',
                            _guid=True))

    def test_web_url_for_guid_invalid_unicode(self):
        """Non-ASCII ids fall back to the full, percent-encoded route."""
        with self.app.test_request_context():
            # unicode id's are not supported when encoding guid url's.
            # check /project/<pid>
            assert_not_equal(
                '/ø∆≤µ©/',
                web_url_for('dummy_guid_project_view', pid='ø∆≤µ©',
                            _guid=True))
            assert_equal(
                '/project/%C3%B8%CB%86%E2%88%86%E2%89%A4%C2%B5%CB%86/',
                web_url_for('dummy_guid_project_view',
                            pid='øˆ∆≤µˆ',
                            _guid=True))
            # check /project/<pid>/node/<nid>
            assert_not_equal(
                '/ø∆≤µ©/',
                web_url_for('dummy_guid_project_view',
                            pid='ø∆≤µ©',
                            nid='©µ≤∆ø',
                            _guid=True))
            assert_equal(
                '/project/%C3%B8%CB%86%E2%88%86%E2%89%A4%C2%B5%CB%86/node/%C2%A9%C2%B5%E2%89%A4%E2%88%86%C3%B8/',
                web_url_for('dummy_guid_project_view',
                            pid='øˆ∆≤µˆ',
                            nid='©µ≤∆ø',
                            _guid=True))
            # check /profile/<pid>
            assert_not_equal(
                '/ø∆≤µ©/',
                web_url_for('dummy_guid_profile_view', pid='ø∆≤µ©',
                            _guid=True))
            assert_equal(
                '/profile/%C3%B8%CB%86%E2%88%86%E2%89%A4%C2%B5%CB%86/',
                web_url_for('dummy_guid_profile_view',
                            pid='øˆ∆≤µˆ',
                            _guid=True))

    def test_api_url_for_with_multiple_urls(self):
        """With both pid and nid, api_url_for picks the component route."""
        with self.app.test_request_context():
            url = api_url_for('dummy_view', pid='123', nid='abc')
            assert url == '/api/v1/123/component/abc/'

    def test_web_url_for_with_multiple_urls(self):
        """With both pid and nid, web_url_for picks the component route."""
        with self.app.test_request_context():
            url = web_url_for('dummy_view', pid='123', nid='abc')
            assert url == '/123/component/abc/'

    def test_is_json_request(self):
        """is_json_request recognises application/json, with or without a charset."""
        with self.app.test_request_context(content_type='application/json'):
            assert_true(is_json_request())
        with self.app.test_request_context(content_type=None):
            assert_false(is_json_request())
        with self.app.test_request_context(
                content_type='application/json;charset=UTF-8'):
            assert_true(is_json_request())

    def test_waterbutler_api_url_for(self):
        """The public WaterButler URL contains the file id, provider and path."""
        with self.app.test_request_context():
            url = waterbutler_api_url_for('fakeid', 'provider', '/path')
        assert_in('/fakeid/', url)
        assert_in('/path', url)
        assert_in('/providers/provider/', url)
        assert_in(settings.WATERBUTLER_URL, url)

    def test_waterbutler_api_url_for_internal(self):
        """_internal=True builds the URL on WATERBUTLER_INTERNAL_URL instead."""
        # This test overrides a module-level setting; register a cleanup so the
        # original value is restored and later tests are not affected.
        missing = object()
        previous = getattr(settings, 'WATERBUTLER_INTERNAL_URL', missing)

        def _restore_internal_url():
            if previous is missing:
                delattr(settings, 'WATERBUTLER_INTERNAL_URL')
            else:
                settings.WATERBUTLER_INTERNAL_URL = previous

        self.addCleanup(_restore_internal_url)

        settings.WATERBUTLER_INTERNAL_URL = 'http://1.2.3.4:7777'
        with self.app.test_request_context():
            url = waterbutler_api_url_for('fakeid',
                                          'provider',
                                          '/path',
                                          _internal=True)

        assert_not_in(settings.WATERBUTLER_URL, url)
        assert_in(settings.WATERBUTLER_INTERNAL_URL, url)
        assert_in('/fakeid/', url)
        assert_in('/path', url)
        assert_in('/providers/provider', url)
示例#32
0
def test_image_upload_field():
    """Exercise ImageUploadField end to end against a temporary directory.

    Covers:
    - an initial upload with thumbnail generation;
    - a replacement, where the old file and its thumbnail are removed;
    - deletion via the ``_upload-delete`` checkbox;
    - an upload without thumbnails;
    - automatic resizing, where a TIFF upload is stored as ``.jpg``;
    - the set of accepted image extensions;
    - an upper-case extension passing validation.
    """
    app = Flask(__name__)

    path = _create_temp()

    def _remove_testimages():
        # Remove every file this test can create so that runs start (and end)
        # with a clean directory. Thumbnails keep the source extension
        # (test1.png -> test1_thumb.png); the .jpg thumbnail names are kept for
        # leftovers from older runs. safe_delete is also called on names that
        # may not exist.
        for name in ('test1.png', 'test1_thumb.png', 'test1_thumb.jpg',
                     'test2.png', 'test2_thumb.png', 'test2_thumb.jpg',
                     'test1.jpg', 'test1.jpeg', 'test1.gif', 'test1.tiff',
                     'copyleft.gif', 'copyleft.jpg', 'copyleft.jpeg',
                     'copyleft.png', 'copyleft.tiff'):
            safe_delete(path, name)

    class TestForm(form.BaseForm):
        upload = form.ImageUploadField('Upload',
                                       base_path=path,
                                       thumbnail_size=(100, 100, True))

    class TestNoResizeForm(form.BaseForm):
        upload = form.ImageUploadField('Upload',
                                       base_path=path,
                                       endpoint='test')

    class TestAutoResizeForm(form.BaseForm):
        upload = form.ImageUploadField('Upload',
                                       base_path=path,
                                       max_size=(64, 64, True))

    class Dummy(object):
        pass

    my_form = TestForm()
    eq_(my_form.upload.base_path, path)
    eq_(my_form.upload.endpoint, 'static')

    _remove_testimages()

    dummy = Dummy()

    # Check upload
    filename = op.join(op.dirname(__file__), 'data', 'copyleft.png')

    with open(filename, 'rb') as fp:
        with app.test_request_context(method='POST',
                                      data={'upload': (fp, 'test1.png')}):
            my_form = TestForm(helpers.get_form_data())

            ok_(my_form.validate())

            my_form.populate_obj(dummy)

            eq_(dummy.upload, 'test1.png')
            ok_(op.exists(op.join(path, 'test1.png')))
            ok_(op.exists(op.join(path, 'test1_thumb.png')))

    # Check replace
    with open(filename, 'rb') as fp:
        with app.test_request_context(method='POST',
                                      data={'upload': (fp, 'test2.png')}):
            my_form = TestForm(helpers.get_form_data())

            ok_(my_form.validate())

            my_form.populate_obj(dummy)

            eq_(dummy.upload, 'test2.png')
            ok_(op.exists(op.join(path, 'test2.png')))
            ok_(op.exists(op.join(path, 'test2_thumb.png')))

            # The replaced image and its thumbnail (created above as
            # test1_thumb.png, not .jpg) must both be gone.
            ok_(not op.exists(op.join(path, 'test1.png')))
            ok_(not op.exists(op.join(path, 'test1_thumb.png')))

    # Check delete
    with app.test_request_context(method='POST',
                                  data={'_upload-delete': 'checked'}):
        my_form = TestForm(helpers.get_form_data())

        ok_(my_form.validate())

        my_form.populate_obj(dummy)
        eq_(dummy.upload, None)

        ok_(not op.exists(op.join(path, 'test2.png')))
        ok_(not op.exists(op.join(path, 'test2_thumb.png')))

    # Check upload no-resize
    with open(filename, 'rb') as fp:
        with app.test_request_context(method='POST',
                                      data={'upload': (fp, 'test1.png')}):
            my_form = TestNoResizeForm(helpers.get_form_data())

            ok_(my_form.validate())

            my_form.populate_obj(dummy)

            eq_(dummy.upload, 'test1.png')
            ok_(op.exists(op.join(path, 'test1.png')))
            ok_(not op.exists(op.join(path, 'test1_thumb.png')))

    # Check upload, auto-resize (same copyleft.png source as above)
    with open(filename, 'rb') as fp:
        with app.test_request_context(method='POST',
                                      data={'upload': (fp, 'test1.png')}):
            my_form = TestAutoResizeForm(helpers.get_form_data())

            ok_(my_form.validate())

            my_form.populate_obj(dummy)

            eq_(dummy.upload, 'test1.png')
            ok_(op.exists(op.join(path, 'test1.png')))

    filename = op.join(op.dirname(__file__), 'data', 'copyleft.tiff')

    with open(filename, 'rb') as fp:
        with app.test_request_context(method='POST',
                                      data={'upload': (fp, 'test1.tiff')}):
            my_form = TestAutoResizeForm(helpers.get_form_data())

            ok_(my_form.validate())

            my_form.populate_obj(dummy)

            eq_(dummy.upload, 'test1.jpg')
            ok_(op.exists(op.join(path, 'test1.jpg')))

    # check allowed extensions
    for extension in ('gif', 'jpg', 'jpeg', 'png', 'tiff'):
        filename = 'copyleft.' + extension
        filepath = op.join(op.dirname(__file__), 'data', filename)
        with open(filepath, 'rb') as fp:
            with app.test_request_context(method='POST',
                                          data={'upload': (fp, filename)}):
                my_form = TestNoResizeForm(helpers.get_form_data())
                ok_(my_form.validate())
                my_form.populate_obj(dummy)
                eq_(dummy.upload, my_form.upload.data.filename)

    # check case-sensitivity for extensions
    filename = op.join(op.dirname(__file__), 'data', 'copyleft.jpg')
    with open(filename, 'rb') as fp:
        with app.test_request_context(method='POST',
                                      data={'upload': (fp, 'copyleft.JPG')}):
            my_form = TestNoResizeForm(helpers.get_form_data())
            ok_(my_form.validate())

    # Leave the temporary directory clean for the next run.
    _remove_testimages()
示例#33
0
def create_app(config_name='ProductionConfig'):
    """Create and configure the ``sfa_api`` Flask application.

    Setup steps:
    - initialise Sentry;
    - load configuration from ``sfa_api.config.<config_name>``, plus the file
      named by the ``REDIS_SETTINGS`` environment variable when it is set;
    - install the marshmallow integration, error handlers and a Talisman
      Content-Security-Policy;
    - register custom URL converters;
    - register every API blueprint behind ``protect_endpoint``;
    - record each non-static view in the OpenAPI ``spec``.

    Finally it adds routes serving the spec as YAML and JSON and a ReDoc
    documentation page at ``/``.

    Args:
        config_name: name of the config object inside ``sfa_api.config``.

    Returns:
        The configured Flask application.
    """
    sentry_sdk.init(send_default_pii=False, integrations=[FlaskIntegration()])
    app = Flask(__name__)
    app.config.from_object(f'sfa_api.config.{config_name}')
    if 'REDIS_SETTINGS' in os.environ:
        app.config.from_envvar('REDIS_SETTINGS')
    ma.init_app(app)
    register_error_handlers(app)

    redoc_version = app.config['REDOC_VERSION']
    redoc_script = (f"https://cdn.jsdelivr.net/npm/redoc@{redoc_version}"
                    "/bundles/redoc.standalone.js")

    # The ReDoc bundle URL is whitelisted explicitly in script-src.
    csp = {
        'default-src': "'self'",
        'style-src': "'unsafe-inline' 'self'",
        'img-src': "'self' data:",
        'object-src': "'none'",
        'script-src': ["'unsafe-inline'", 'blob:', redoc_script,
                       "'strict-dynamic'"],
        'child-src': "blob:",
        'base-uri': "'none'",
    }
    talisman.init_app(app,
                      content_security_policy=csp,
                      content_security_policy_nonce_in=['script-src'])

    converters = app.url_map.converters
    converters['uuid_str'] = UUIDStringConverter
    converters['zone_str'] = ZoneStringConverter

    # Blueprint modules are imported here rather than at module level.
    from sfa_api.observations import obs_blp
    from sfa_api.forecasts import forecast_blp
    from sfa_api.sites import site_blp
    from sfa_api.users import user_blp, user_email_blp
    from sfa_api.roles import role_blp
    from sfa_api.permissions import permission_blp
    from sfa_api.reports import reports_blp
    from sfa_api.aggregates import agg_blp
    from sfa_api.zones import zone_blp

    protected_blueprints = (obs_blp, forecast_blp, site_blp, user_blp,
                            user_email_blp, role_blp, permission_blp,
                            reports_blp, agg_blp, zone_blp)
    for blueprint in protected_blueprints:
        blueprint.before_request(protect_endpoint)
        app.register_blueprint(blueprint)

    # Document every registered view except Flask's built-in static handler.
    # This runs before the documentation routes below are added, so those
    # routes do not appear in the spec.
    with app.test_request_context():
        for endpoint, view in app.view_functions.items():
            if endpoint != 'static':
                spec.path(view=view)

    @app.route('/openapi.yaml')
    def get_apispec_yaml():
        """Serve the OpenAPI spec as YAML."""
        return Response(spec.to_yaml(), mimetype='application/yaml')

    @app.route('/openapi.json')
    def get_apispec_json():
        """Serve the OpenAPI spec as JSON."""
        return jsonify(spec.to_dict())

    @app.route('/')
    def render_docs():
        """Render the ReDoc page pointing at the JSON spec."""
        return render_template('doc.html',
                               apispec_path=url_for('get_apispec_json'),
                               redoc_script=redoc_script)

    return app
示例#34
0
def test_upload_field():
    """Exercise FileUploadField: upload, replace, delete and ``allow_overwrite``.

    Files are written to a temporary directory. With ``allow_overwrite=False``,
    a second upload under an existing name must fail validation.
    """
    app = Flask(__name__)

    path = _create_temp()

    def _remove_testfiles():
        # Remove every file this test may create.
        for name in ('test1.txt', 'test2.txt'):
            safe_delete(path, name)

    class TestForm(form.BaseForm):
        upload = form.FileUploadField('Upload', base_path=path)

    class TestNoOverWriteForm(form.BaseForm):
        upload = form.FileUploadField('Upload',
                                      base_path=path,
                                      allow_overwrite=False)

    class Dummy(object):
        pass

    my_form = TestForm()
    eq_(my_form.upload.base_path, path)

    _remove_testfiles()

    dummy = Dummy()

    # Check upload
    with app.test_request_context(
            method='POST',
            data={'upload': (BytesIO(b'Hello World 1'), 'test1.txt')}):
        my_form = TestForm(helpers.get_form_data())

        ok_(my_form.validate())

        my_form.populate_obj(dummy)

        eq_(dummy.upload, 'test1.txt')
        ok_(op.exists(op.join(path, 'test1.txt')))

    # Check replace: the previous file is removed when a new one is stored.
    with app.test_request_context(
            method='POST',
            data={'upload': (BytesIO(b'Hello World 2'), 'test2.txt')}):
        my_form = TestForm(helpers.get_form_data())

        ok_(my_form.validate())
        my_form.populate_obj(dummy)

        eq_(dummy.upload, 'test2.txt')
        ok_(not op.exists(op.join(path, 'test1.txt')))
        ok_(op.exists(op.join(path, 'test2.txt')))

    # Check delete
    with app.test_request_context(method='POST',
                                  data={'_upload-delete': 'checked'}):

        my_form = TestForm(helpers.get_form_data())

        ok_(my_form.validate())

        my_form.populate_obj(dummy)
        eq_(dummy.upload, None)

        ok_(not op.exists(op.join(path, 'test2.txt')))

    # Check overwrite: first upload into an empty directory succeeds...
    _remove_testfiles()
    with app.test_request_context(
            method='POST', data={'upload': (BytesIO(b'Hullo'), 'test1.txt')}):
        my_form_ow = TestNoOverWriteForm(helpers.get_form_data())

        ok_(my_form_ow.validate())
        my_form_ow.populate_obj(dummy)
        eq_(dummy.upload, 'test1.txt')
        ok_(op.exists(op.join(path, 'test1.txt')))

    # ...and a second upload under the same name is rejected.
    with app.test_request_context(
            method='POST', data={'upload': (BytesIO(b'Hullo'), 'test1.txt')}):
        my_form_ow = TestNoOverWriteForm(helpers.get_form_data())

        ok_(not my_form_ow.validate())

    _remove_testfiles()
示例#35
0
def test_format_currency(app: Flask) -> None:
    """format_currency renders Decimals as plain strings; invert=True negates them."""
    request_ctx = app.test_request_context("/long-example/")
    with request_ctx:
        # Run before_request hooks so request-scoped state is initialised.
        app.preprocess_request()
        plain = format_currency(Decimal("2.12"))
        assert plain == "2.12"
        inverted = format_currency(Decimal("2.13"), invert=True)
        assert inverted == "-2.13"
示例#36
0
 def setUp(self):
     """Create a Flask app, push a request context on it and build a test client."""
     flask_app = Flask('test')
     # NOTE(review): the context pushed here is not popped in this method;
     # confirm a tearDown pops self.ctx.
     self.ctx = flask_app.test_request_context('')
     self.ctx.push()
     self.app = flask_app.test_client()
示例#37
0
 def test_resource(self):
     """Smoke test: dispatching /foo to a Resource with a mocked ``get`` does not raise."""
     flask_app = Flask(__name__)
     res = flask_restful.Resource()
     res.get = Mock()
     with flask_app.test_request_context("/foo"):
         res.dispatch_request()
示例#38
0
def base_app(request):
    """Flask application fixture without OAuthClient initialized.

    Setup:
    - configure CERN, ORCID and GitHub remote apps with placeholder
      credentials;
    - install FlaskMenu, Babel, Mail, InvenioDB and InvenioAccounts;
    - create the database (when it is not ``sqlite://`` and does not exist
      yet) and its tables;
    - push a test request context, which this fixture never pops.

    The registered finalizer closes the DB session, drops the database unless
    its URL is exactly ``sqlite://``, and removes the temporary directory.
    """
    # NOTE(review): instance_path is created and later removed, but it is not
    # passed to Flask here — confirm whether that is intended.
    instance_path = tempfile.mkdtemp()
    app = Flask('testapp')

    remote_apps = dict(
        cern=CERN_REMOTE_APP,
        orcid=ORCID_REMOTE_APP,
        github=GITHUB_REMOTE_APP,
    )
    app.config.update(
        TESTING=True,
        WTF_CSRF_ENABLED=False,
        LOGIN_DISABLED=False,
        CACHE_TYPE='simple',
        OAUTHCLIENT_REMOTE_APPS=remote_apps,
        GITHUB_APP_CREDENTIALS=dict(
            consumer_key='github_key_changeme',
            consumer_secret='github_secret_changeme',
        ),
        ORCID_APP_CREDENTIALS=dict(
            consumer_key='orcid_key_changeme',
            consumer_secret='orcid_secret_changeme',
        ),
        CERN_APP_CREDENTIALS=dict(
            consumer_key='cern_key_changeme',
            consumer_secret='cern_secret_changeme',
        ),
        # use local memory mailbox
        EMAIL_BACKEND='flask_email.backends.locmem.Mail',
        SQLALCHEMY_DATABASE_URI=os.getenv('SQLALCHEMY_DATABASE_URI',
                                          'sqlite://'),
        SERVER_NAME='localhost',
        DEBUG=False,
        SECRET_KEY='TEST',
        SECURITY_DEPRECATED_PASSWORD_SCHEMES=[],
        SECURITY_PASSWORD_HASH='plaintext',
        SECURITY_PASSWORD_SCHEMES=['plaintext'],
    )
    for extension in (FlaskMenu, Babel, Mail, InvenioDB, InvenioAccounts):
        extension(app)

    with app.app_context():
        db_url = str(db.engine.url)
        # The plain in-memory SQLite URL needs no explicit database creation.
        if db_url != 'sqlite://' and not database_exists(db_url):
            create_database(db_url)
        db.create_all()

    def teardown():
        with app.app_context():
            db.session.close()
            if str(db.engine.url) != 'sqlite://':
                drop_database(str(db.engine.url))
            shutil.rmtree(instance_path)

    request.addfinalizer(teardown)

    app.test_request_context().push()

    return app
 def setUp(self):
     """Register EntityAddResource at /entities/ and prepare a client and an unpushed request context."""
     flask_app = Flask(__name__)
     self.api = Api(flask_app)
     self.api.add_resource(EntityAddResource, '/entities/')
     client = flask_app.test_client()
     self.app = client
     self.context = flask_app.test_request_context()
示例#40
0
class WrappedDash(Dash):
    'Wrapper around the Plotly Dash application instance'
    # pylint: disable=too-many-arguments, too-many-instance-attributes
    def __init__(self,
                 base_pathname=None, replacements=None, ndid=None, serve_locally=False,
                 **kwargs):
        """Create a Dash app served through a PseudoFlask stub instead of a real server.

        Args:
            base_pathname: URL base pathname; forwarded to Dash as ``url_base_pathname``.
            replacements: optional mapping of component id -> property overrides
                applied to the initial layout.
            ndid: identifier for this instance; also used as the name of the
                underlying stub Flask app.
            serve_locally: value assigned to both the css and scripts
                ``serve_locally`` config flags.
            **kwargs: remaining Dash constructor arguments (``url_base_pathname``
                and ``server`` are overwritten).
        """
        self._uid = ndid

        self._flask_app = Flask(self._uid)
        self._notflask = PseudoFlask()
        self._base_pathname = base_pathname

        kwargs.update(url_base_pathname=self._base_pathname,
                      server=self._notflask)

        super().__init__(__name__, **kwargs)

        for resource_config in (self.css.config, self.scripts.config):
            resource_config.serve_locally = serve_locally

        self._adjust_id = False
        self._replacements = replacements if replacements else dict()
        # The stock Dash layout is only usable when there is no state to inject.
        self._use_dash_layout = len(self._replacements) == 0

        self._return_embedded = False

    def use_dash_layout(self):
        """Report whether the underlying Dash layout can be served as-is.

        This is False when replacement values were supplied at construction,
        because the layout then has to be augmented with that state.
        """
        can_use_stock_layout = self._use_dash_layout
        return can_use_stock_layout

    def augment_initial_layout(self, base_response, initial_arguments=None):
        """Apply stored replacements and per-request arguments to a layout response.

        Args:
            base_response: response whose ``data`` is UTF-8 encoded layout JSON;
                its ``mimetype`` is passed through unchanged.
            initial_arguments: optional mapping of component id -> {property: value},
                or a JSON string encoding such a mapping. Entries override the
                replacements given at construction.

        Returns:
            A ``(json_text, mimetype)`` tuple with the reworked layout.
        """
        # The layout is always decoded and re-serialized. (An earlier early
        # return for the "plain Dash layout" case was guarded by ``and False``
        # and could never run, so it has been removed.)
        base_data = json.loads(base_response.data.decode('utf-8'))

        if not initial_arguments:
            initial_arguments = {}
        elif isinstance(initial_arguments, str):
            initial_arguments = json.loads(initial_arguments)

        # Per-request arguments take precedence over construction-time replacements.
        overrides = dict(self._replacements)
        overrides.update(initial_arguments)

        # Walk tree. If at any point we have an element whose id
        # matches, then replace any named values at this level
        reworked_data = self.walk_tree_and_replace(base_data, overrides)

        response_data = json.dumps(reworked_data,
                                   cls=PlotlyJSONEncoder)

        return response_data, base_response.mimetype

    def walk_tree_and_extract(self, data, target):
        """Collect, per component id, the non-structural keys found in a layout tree.

        Recurses into the ``children`` and ``props`` entries of dicts and into
        every element of lists. For a dict with an ``id``, every key other than
        ``props``, ``options``, ``children`` and ``id`` is merged into
        ``target[id]``. The entry is stored only when something was collected.
        ``target`` is mutated in place; nothing is returned.
        """
        if isinstance(data, list):
            for element in data:
                self.walk_tree_and_extract(element, target)
            return
        if not isinstance(data, dict):
            return

        for nested_key in ('children', 'props'):
            self.walk_tree_and_extract(data.get(nested_key), target)

        ident = data.get('id')
        if ident is None:
            return

        structural_keys = ('props', 'options', 'children', 'id')
        collected = target.get(ident, {})
        for key, value in data.items():
            if key not in structural_keys:
                collected[key] = value
        if collected:
            target[ident] = collected

    def walk_tree_and_replace(self, data, overrides):
        """Return a copy of a JSON-decoded layout tree with override values applied.

        Dicts and lists are recognised purely by type, relying on the tree
        having come from JSON decoding. For a dict carrying an ``id``:
        - a dict (pattern) id is matched against each key of ``overrides`` with
          ``compare``, and the first match is used;
        - any other non-None id is looked up directly in ``overrides``.

        Each key of the dict takes its replacement value when that value is not
        None; otherwise the key's value is walked recursively. Lists are walked
        element by element. Any other value is returned unchanged.
        """
        if isinstance(data, list):
            return [self.walk_tree_and_replace(item, overrides) for item in data]
        if not isinstance(data, dict):
            return data

        component_id = data.get('id')
        if isinstance(component_id, dict):
            # Pattern-matching id: linear search over the override keys.
            replacements = next(
                (value for key, value in overrides.items()
                 if compare(id_python=component_id, id_dash=key)),
                {})
        elif component_id is None:
            replacements = {}
        else:
            replacements = overrides.get(component_id, {})

        result = {}
        for key, value in data.items():
            replacement = replacements.get(key, None)
            if replacement is None:
                result[key] = self.walk_tree_and_replace(value, overrides)
            else:
                result[key] = replacement
        return result

    def flask_app(self):
        """Return the stub Flask application created in ``__init__``."""
        stub_app = self._flask_app
        return stub_app

    def base_url(self):
        """Return the base pathname this component was constructed with."""
        pathname = self._base_pathname
        return pathname

    def app_context(self, *args, **kwargs):
        """Return an application context from the underlying stub Flask app.

        All arguments are forwarded unchanged to ``Flask.app_context``.
        """
        stub_app = self._flask_app
        return stub_app.app_context(*args, **kwargs)

    def test_request_context(self, *args, **kwargs):
        """Return a test request context from the underlying stub Flask app.

        All arguments are forwarded unchanged to ``Flask.test_request_context``.
        """
        stub_app = self._flask_app
        return stub_app.test_request_context(*args, **kwargs)

    def locate_endpoint_function(self, name=None):
        """Return the view function registered on the stub server for a view.

        The endpoint key is the base pathname alone when ``name`` is None, and
        ``"<base_pathname>_<name>"`` otherwise.
        """
        if name is None:
            endpoint_key = self._base_pathname
        else:
            endpoint_key = "%s_%s" % (self._base_pathname, name)
        registration = self._notflask.endpoints[endpoint_key]
        return registration['view_func']

    # pylint: disable=no-member
    @Dash.layout.setter
    def layout(self, value):
        """Set the Dash layout, first rewriting component ids when id adjustment is on.

        Overrides only the setter of ``Dash.layout``; storage is delegated to the
        base class's ``fset``.
        """

        if self._adjust_id:
            # Prefix the ids of ``value`` and its children in place.
            self._fix_component_id(value)
        return Dash.layout.fset(self, value)

    def _fix_component_id(self, component):
        """Rewrite the id of ``component`` and, recursively, of all its children.

        Each non-None ``id`` is passed through ``_fix_id`` and set back on the
        object. This is best-effort: objects without an iterable ``children``
        attribute are skipped.
        """

        the_id = getattr(component, "id", None)
        if the_id is not None:
            setattr(component, "id", self._fix_id(the_id))
        try:
            for child in component.children:
                self._fix_component_id(child)
        except Exception:  # pylint: disable=broad-except
            # Deliberately best-effort, as before. Catching Exception instead of
            # using a bare ``except:`` lets KeyboardInterrupt and SystemExit
            # propagate.
            pass

    def _fix_id(self, name):
        """Return ``name`` prefixed with ``"<uid>_-_"`` when id adjustment is enabled.

        When adjustment is disabled, ``name`` is returned unchanged.
        """
        return "%s_-_%s" % (self._uid, name) if self._adjust_id else name

    def _fix_callback_item(self, item):
        """Pass ``item.component_id`` through ``_fix_id`` in place and return ``item``."""
        adjusted_id = self._fix_id(item.component_id)
        item.component_id = adjusted_id
        return item

    def callback(self, output, inputs=None, state=None, events=None):
        """Register a Dash callback after adjusting the component ids it references.

        Args:
            output: a single dependency, or a list/tuple of them.
            inputs: iterable of input dependencies (defaults to none).
            state: iterable of state dependencies (defaults to none).
            events: accepted for backward compatibility only; it is ignored and
                not forwarded to Dash.

        Each dependency object has its ``component_id`` rewritten in place.
        """
        # None defaults replace the former shared mutable ``[]`` defaults.
        inputs = [] if inputs is None else inputs
        state = [] if state is None else state
        del events  # intentionally unused

        if isinstance(output, (list, tuple)):
            fixed_outputs = [self._fix_callback_item(x) for x in output]
        else:
            fixed_outputs = self._fix_callback_item(output)

        return super(WrappedDash, self).callback(fixed_outputs,
                                                 [self._fix_callback_item(x) for x in inputs],
                                                 [self._fix_callback_item(x) for x in state])

    def clientside_callback(self, clientside_function, output, inputs=None, state=None):
        """Register a Dash clientside callback after adjusting component ids.

        Args:
            clientside_function: forwarded unchanged to Dash.
            output: a single dependency, or a list/tuple of them.
            inputs: iterable of input dependencies (defaults to none).
            state: iterable of state dependencies (defaults to none).

        Each dependency object has its ``component_id`` rewritten in place.
        """
        # None defaults replace the former shared mutable ``[]`` defaults.
        inputs = [] if inputs is None else inputs
        state = [] if state is None else state

        if isinstance(output, (list, tuple)):
            fixed_outputs = [self._fix_callback_item(x) for x in output]
        else:
            fixed_outputs = self._fix_callback_item(output)

        return super(WrappedDash, self).clientside_callback(clientside_function,
                                                            fixed_outputs,
                                                            [self._fix_callback_item(x) for x in inputs],
                                                            [self._fix_callback_item(x) for x in state])

    def dispatch(self):
        """Dispatch a callback using the JSON body of the current Flask request.

        No extra arguments are supplied (``argMap`` is an empty dict).
        """
        import flask
        return self.dispatch_with_args(flask.request.get_json(), argMap=dict())

    #pylint: disable=too-many-locals
    def dispatch_with_args(self, body, argMap):
        '''Perform callback dispatching, with enhanced arguments and recording of response.

        :param body: decoded JSON body of the callback request. Keys read here:
            ``inputs``, ``state``, ``output`` (required), ``outputs`` and
            ``changedPropIds``.
        :param argMap: extra keyword arguments for the callback. It is mutated:
            ``callback_context`` is added when the map is non-empty on entry,
            and ``outputs_list`` is always added. When it holds a truthy
            ``dash_app`` entry, that object receives ``update_current_state``
            calls for every matched input/state value and every tracked output.
        :returns: the callback's result. Returns the string ``'EDGECASEEXIT'``
            instead when fewer values were matched than the callback has
            registered inputs.
        :raises NotImplementedError: when a non-string output item would need
            its state saved (pattern-matching outputs).
        '''
        inputs = body.get('inputs', [])
        input_values = inputs_to_dict(inputs)
        states = body.get('state', [])
        output = body['output']
        # Dash 1.11 introduces a set of outputs; otherwise split the combined
        # output id. Computed once and reused for the context and for argMap.
        outputs_list = body.get('outputs') or split_callback_id(output)
        changed_props = body.get('changedPropIds', [])
        triggered_inputs = [{"prop_id": x, "value": input_values.get(x)} for x in changed_props]

        callback_context = CallbackContext(
            inputs_list=inputs,
            inputs=input_values,
            states_list=states,
            states=inputs_to_dict(states),
            outputs_list=outputs_list,
            outputs=outputs_list,
            triggered=triggered_inputs,
            )

        # Overload dash global variable
        dash.callback_context = callback_context

        # Add context to arg map, if extended callbacks in use
        if argMap:
            argMap['callback_context'] = callback_context

        # A multi-output id is wrapped in '..' and uses '...' between entries;
        # anything else is a single output (not in a list).
        if output.startswith('..') and output.endswith('..'):
            outputs = output[2:-2].split('...')
        else:
            outputs = [output]

        da = argMap.get('dash_app', None)

        callback_info = self.callback_map[output]

        args = []

        def _collect(registrations, supplied):
            'Append the value of each supplied item matching a registration, in registration order'
            for registration in registrations:
                for item in supplied:
                    if (item['property'] == registration['property']
                            and compare(id_python=item['id'], id_dash=registration['id'])):
                        value = item.get('value', None)
                        args.append(value)
                        if da:
                            da.update_current_state(item['id'], item['property'], value)

        # Positional callback arguments: all inputs first, then all state.
        _collect(callback_info['inputs'], inputs)
        _collect(callback_info['state'], states)

        argMap['outputs_list'] = outputs_list

        # Special: intercept case of insufficient arguments
        # This happens when a property has been updated with a pipe component
        # TODO see if this can be attacked from the client end
        if len(args) < len(callback_info['inputs']):
            return 'EDGECASEEXIT'

        callback = callback_info["callback"]
        # smart injection of parameters if .expanded is defined
        if callback.expanded is not None:
            parameters_to_inject = {*callback.expanded, 'outputs_list'}
            res = callback(*args, **{k: v for k, v in argMap.items() if k in parameters_to_inject})
        else:
            res = callback(*args, **argMap)

        if da:
            root_value = json.loads(res).get('response', {})

            for output_item in outputs:
                if isinstance(output_item, str):
                    output_id, output_property = output_item.split('.')
                    # Only outputs that already have a tracked state entry are updated.
                    if da.have_current_state_entry(output_id, output_property):
                        value = root_value.get(output_id, {}).get(output_property, None)
                        da.update_current_state(output_id, output_property, value)
                else:
                    # todo: implement saving of state for pattern matching outputs
                    raise NotImplementedError("Updating state for dict keys (pattern matching) is not yet implemented")

        return res

    def slugified_id(self):
        'Return the app uid passed through slugify (used e.g. to build CSS class names)'
        return slugify(self._uid)

    def extra_html_properties(self, prefix=None, postfix=None, template_type=None):
        '''
        Build the space-separated CSS class names that identify this app.

        The result is ``"<prefix> <prefix>-<template_type> <prefix>-app-<slug>[-<postfix>]"``.
        A falsy ``prefix`` becomes ``"django-plotly-dash"`` and a falsy
        ``template_type`` becomes ``"iframe"``; the ``-<postfix>`` part is
        only present when ``postfix`` is truthy.

        The content returned from this function is injected unescaped into templates.
        '''
        prefix = prefix or "django-plotly-dash"
        suffix = "-%s" % postfix if postfix else ""
        template_type = template_type or "iframe"

        slug = self.slugified_id()

        return "%s %s-%s %s-app-%s%s" % (prefix,
                                          prefix, template_type,
                                          prefix, slug, suffix)

    def index(self, *args, **kwargs):  # pylint: disable=unused-argument
        scripts = self._generate_scripts_html()
        css = self._generate_css_dist_html()
        config = self._generate_config_html()
        metas = self._generate_meta_html()
        renderer = self._generate_renderer()
        title = getattr(self, 'title', 'Dash')
        if self._favicon:
            import flask
            favicon = '<link rel="icon" type="image/x-icon" href="{}">'.format(
                flask.url_for('assets.static', filename=self._favicon))
        else:
            favicon = ''

            _app_entry = '''
<div id="react-entry-point">
  <div class="_dash-loading">
    Loading...
  </div>
</div>
'''
        index = self.interpolate_index(
            metas=metas, title=title, css=css, config=config,
            scripts=scripts, app_entry=_app_entry, favicon=favicon,
            renderer=renderer)

        return index

    def interpolate_index(self, **kwargs): #pylint: disable=arguments-differ
        '''Render the full index page, or divert its assets when embedded.

        With no embedded holder active (falsy ``self._return_embedded``) this
        defers to Dash's own ``interpolate_index``. Otherwise the css, config
        and scripts fragments are handed to the holder and only the app entry
        markup is returned.
        '''
        holder = self._return_embedded
        if holder:
            holder.add_css(kwargs['css'])
            holder.add_config(kwargs['config'])
            holder.add_scripts(kwargs['scripts'])
            return kwargs['app_entry']

        return super(WrappedDash, self).interpolate_index(**kwargs)

    def set_embedded(self, embedded_holder=None):
        '''Install the holder that collects css/config/scripts before a view function runs.

        A new EmbeddedHolder is created when ``embedded_holder`` is falsy.
        '''
        if not embedded_holder:
            embedded_holder = EmbeddedHolder()
        self._return_embedded = embedded_holder

    def exit_embedded(self):
        '''Leave the embedded section once a view has been processed.'''
        # interpolate_index treats any falsy holder as "not embedded".
        self._return_embedded = False
示例#41
0
# File upload
@app.route('/upload', methods=['GET', 'POST'])
def upload_file():
    """Save an uploaded file on POST; show a minimal upload form on GET.

    A Flask view must return a response; returning None (as both branches
    previously did) makes Flask raise an error instead of answering.
    """
    if request.method == 'POST':
        f = request.files['the_file']
        f.save('/var/www/uploads/uploaded_file.txt')
        # To keep the client's file name instead, sanitise it first:
        # f.save('/var/www/uploads/' + secure_filename(f.filename))
        return 'File uploaded'
    return (
        '<form method="post" enctype="multipart/form-data">'
        '<input type="file" name="the_file">'
        '<input type="submit" value="Upload">'
        '</form>'
    )

# Passing values to a template
@app.route('/hello/')
@app.route('/hello/<name>')
def hello(name=None):
    """Render hello.html; ``name`` is None when reached through /hello/."""
    context = {'name': name}
    return render_template('hello.html', **context)

# Build the URLs of routed functions by handling a simulated request context
with app.test_request_context():
    for endpoint, values in (
            ('hello_world', {}),
            ('show_user_profile', {'username': '******'}),
            ('login', {'next': '/'}),
            ('static', {'filename': 'style.css'}),
    ):
        print(url_for(endpoint, **values))

# A request context can also be built for a specific path and HTTP method
with app.test_request_context('/hello', method='POST'):
    assert request.path == '/hello'
    assert request.method == 'POST'

# Alternatively, push a context built from an existing WSGI environ:
# with app.request_context(environ):
#     assert request.method == 'POST'

# TODO: message flashing?
示例#42
0
class TestLoadModels(unittest.TestCase):
    """Exercise the ``t_*`` view helpers and the models' permission checks.

    The ``t_*`` functions, the document models, ``db``, ``login_manager`` and
    ``app1`` are all defined outside this class. NOTE(review): the ``t_*``
    helpers presumably wrap views in a model-loading decorator -- confirm
    against their definitions.
    """
    # App whose request context setUp pushes (and tearDown pops) around the
    # database work; setUp later rebinds ``self.app`` on the instance.
    app = app1

    def setUp(self):
        """Push a request context on ``app1``, create the schema and insert fixtures.

        Rows are committed one at a time in a fixed order. Several tests look
        documents up by numeric id (``'1-...'``, ``'2-...'``, ``u'1'``, ``1``),
        which presumably depends on this insertion order -- keep it stable.
        """
        self.ctx = self.app.test_request_context()
        self.ctx.push()

        db.create_all()
        self.session = db.session
        c = Container(name=u'c')
        self.session.add(c)
        self.container = c
        self.nd1 = NamedDocument(container=c, title=u"Named Document")
        self.session.add(self.nd1)
        self.session.commit()
        self.nd2 = NamedDocument(container=c, title=u"Another Named Document")
        self.session.add(self.nd2)
        self.session.commit()
        self.rd1 = RedirectDocument(container=c,
                                    title=u"Redirect Document",
                                    target=self.nd1)
        self.session.add(self.rd1)
        self.session.commit()
        self.snd1 = ScopedNamedDocument(container=c,
                                        title=u"Scoped Named Document")
        self.session.add(self.snd1)
        self.session.commit()
        self.snd2 = ScopedNamedDocument(container=c,
                                        title=u"Another Scoped Named Document")
        self.session.add(self.snd2)
        self.session.commit()
        self.ind1 = IdNamedDocument(container=c, title=u"Id Named Document")
        self.session.add(self.ind1)
        self.session.commit()
        self.ind2 = IdNamedDocument(container=c,
                                    title=u"Another Id Named Document")
        self.session.add(self.ind2)
        self.session.commit()
        self.sid1 = ScopedIdDocument(container=c)
        self.session.add(self.sid1)
        self.session.commit()
        self.sid2 = ScopedIdDocument(container=c)
        self.session.add(self.sid2)
        self.session.commit()
        self.sind1 = ScopedIdNamedDocument(container=c,
                                           title=u"Scoped Id Named Document")
        self.session.add(self.sind1)
        self.session.commit()
        self.sind2 = ScopedIdNamedDocument(
            container=c, title=u"Another Scoped Id Named Document")
        self.session.add(self.sind2)
        self.session.commit()
        self.pc = ParentDocument(title=u"Parent")
        self.session.add(self.pc)
        self.session.commit()
        self.child1 = ChildDocument(parent=self.pc.middle)
        self.session.add(self.child1)
        self.session.commit()
        self.child2 = ChildDocument(parent=self.pc.middle)
        self.session.add(self.child2)
        self.session.commit()
        # Rebind ``self.app`` to a bare Flask app that only knows the
        # 'redirect_document' rule. The tests' own test_request_context()
        # blocks use this app, while ``self.ctx`` still belongs to ``app1``.
        self.app = Flask(__name__)
        self.app.add_url_rule('/<container>/<document>', 'redirect_document',
                              t_redirect_document)

    def tearDown(self):
        """Roll back the session, drop all tables and pop the context pushed in setUp."""
        self.session.rollback()
        db.drop_all()
        self.ctx.pop()

    def test_container(self):
        """t_container returns the container named 'c' for a logged-in test user."""
        with self.app.test_request_context():
            login_manager.set_user_for_testing(User(username='******'),
                                               load=True)
            self.assertEqual(t_container(container=u'c'), self.container)

    def test_named_document(self):
        """t_named_document finds each named document by its name."""
        self.assertEqual(
            t_named_document(container=u'c', document=u'named-document'),
            self.nd1)
        self.assertEqual(
            t_named_document(container=u'c',
                             document=u'another-named-document'), self.nd2)

    def test_redirect_document(self):
        """Named documents load directly; a RedirectDocument yields a 307 to its
        target's URL, carrying the query string over."""
        with self.app.test_request_context('/c/named-document'):
            self.assertEqual(
                t_redirect_document(container=u'c',
                                    document=u'named-document'), self.nd1)
        with self.app.test_request_context('/c/another-named-document'):
            self.assertEqual(
                t_redirect_document(container=u'c',
                                    document=u'another-named-document'),
                self.nd2)
        with self.app.test_request_context('/c/redirect-document'):
            response = t_redirect_document(container=u'c',
                                           document=u'redirect-document')
            self.assertEqual(response.status_code, 307)
            self.assertEqual(response.headers['Location'], '/c/named-document')
        with self.app.test_request_context(
                '/c/redirect-document?preserve=this'):
            response = t_redirect_document(container=u'c',
                                           document=u'redirect-document')
            self.assertEqual(response.status_code, 307)
            self.assertEqual(response.headers['Location'],
                             '/c/named-document?preserve=this')

    def test_scoped_named_document(self):
        """t_scoped_named_document finds each scoped named document by name."""
        self.assertEqual(
            t_scoped_named_document(container=u'c',
                                    document=u'scoped-named-document'),
            self.snd1)
        self.assertEqual(
            t_scoped_named_document(container=u'c',
                                    document=u'another-scoped-named-document'),
            self.snd2)

    def test_id_named_document(self):
        """'<id>-<name>' lookup: correct names load the document, a wrong name
        gives a 302 to the canonical URL (query kept), and a non-integer id
        raises NotFound."""
        self.assertEqual(
            t_id_named_document(container=u'c',
                                document=u'1-id-named-document'), self.ind1)
        self.assertEqual(
            t_id_named_document(container=u'c',
                                document=u'2-another-id-named-document'),
            self.ind2)
        with self.app.test_request_context('/c/1-wrong-name'):
            r = t_id_named_document(container=u'c', document=u'1-wrong-name')
            self.assertEqual(r.status_code, 302)
            self.assertEqual(r.location, '/c/1-id-named-document')
        with self.app.test_request_context('/c/1-wrong-name?preserve=this'):
            r = t_id_named_document(container=u'c', document=u'1-wrong-name')
            self.assertEqual(r.status_code, 302)
            self.assertEqual(r.location,
                             '/c/1-id-named-document?preserve=this')
        self.assertRaises(NotFound,
                          t_id_named_document,
                          container=u'c',
                          document=u'random-non-integer')

    def test_scoped_id_document(self):
        """t_scoped_id_document accepts the id as either a string or an int."""
        self.assertEqual(t_scoped_id_document(container=u'c', document=u'1'),
                         self.sid1)
        self.assertEqual(t_scoped_id_document(container=u'c', document=u'2'),
                         self.sid2)
        self.assertEqual(t_scoped_id_document(container=u'c', document=1),
                         self.sid1)
        self.assertEqual(t_scoped_id_document(container=u'c', document=2),
                         self.sid2)

    def test_scoped_id_named_document(self):
        """Same contract as test_id_named_document, for scoped id-named documents."""
        self.assertEqual(
            t_scoped_id_named_document(container=u'c',
                                       document=u'1-scoped-id-named-document'),
            self.sind1)
        self.assertEqual(
            t_scoped_id_named_document(
                container=u'c',
                document=u'2-another-scoped-id-named-document'), self.sind2)
        with self.app.test_request_context('/c/1-wrong-name'):
            r = t_scoped_id_named_document(container=u'c',
                                           document=u'1-wrong-name')
            self.assertEqual(r.status_code, 302)
            self.assertEqual(r.location, '/c/1-scoped-id-named-document')
        self.assertRaises(NotFound,
                          t_scoped_id_named_document,
                          container=u'c',
                          document=u'random-non-integer')

    def test_callable_document(self):
        """t_callable_document returns child1 / child2 of 'parent' for child=1 / 2."""
        self.assertEqual(t_callable_document(document=u'parent', child=1),
                         self.child1)
        self.assertEqual(t_callable_document(document=u'parent', child=2),
                         self.child2)

    def test_dotted_document(self):
        """t_dotted_document returns child1 / child2 of 'parent' for child=1 / 2."""
        self.assertEqual(t_dotted_document(document=u'parent', child=1),
                         self.child1)
        self.assertEqual(t_dotted_document(document=u'parent', child=2),
                         self.child2)

    def test_direct_permissions(self):
        """Permissions granted directly by the parent and, via inheritance, the child.

        NOTE(review): both usernames read '******' in this copy (apparently
        redacted), yet the expected permissions differ, so the real fixture
        names must differ.
        """
        user1 = User(username='******')
        user2 = User(username='******')
        self.assertEqual(self.pc.permissions(user1),
                         set(['view', 'edit', 'delete']))
        self.assertEqual(self.pc.permissions(user2), set(['view']))
        self.assertEqual(
            self.child1.permissions(user1,
                                    inherited=self.pc.permissions(user1)),
            set(['view', 'edit']))
        self.assertEqual(
            self.child1.permissions(user2,
                                    inherited=self.pc.permissions(user2)),
            set(['view']))

    def test_inherited_permissions(self):
        """Permissions passed in via ``inherited`` appear in the result."""
        user = User(username='******')
        self.assertEqual(
            self.pc.permissions(user, inherited=set(['add-video'])),
            set(['add-video', 'view']))

    def test_unmutated_inherited_permissions(self):
        """The inherited permission set should not be mutated by a permission check"""
        user = User(username='******')
        inherited = set(['add-video'])
        self.assertEqual(self.pc.permissions(user, inherited=inherited),
                         set(['add-video', 'view']))
        self.assertEqual(inherited, set(['add-video']))

    def test_loadmodel_permissions(self):
        """View and edit loaders succeed for the test user; the delete loader raises Forbidden."""
        with self.app.test_request_context():
            login_manager.set_user_for_testing(User(username='******'), load=True)
            self.assertEqual(
                t_dotted_document_view(document=u'parent', child=1),
                self.child1)
            self.assertEqual(
                t_dotted_document_edit(document=u'parent', child=1),
                self.child1)
            self.assertRaises(Forbidden,
                              t_dotted_document_delete,
                              document=u'parent',
                              child=1)

    def test_load_user_to_g(self):
        """t_load_user_to_g returns the user also stored on ``g``; unknown users raise NotFound."""
        with self.app.test_request_context():
            user = User(username=u'baz')
            self.session.add(user)
            self.session.commit()
            self.assertFalse(hasattr(g, 'user'))
            self.assertEqual(t_load_user_to_g(username=u'baz'), g.user)
            self.assertRaises(NotFound, t_load_user_to_g, username=u'boo')

    def test_single_model_in_loadmodels(self):
        """t_single_model_in_loadmodels returns the loaded user, which is also stored on ``g``."""
        with self.app.test_request_context():
            user = User(username=u'user1')
            self.session.add(user)
            self.session.commit()
            self.assertEqual(t_single_model_in_loadmodels(username=u'user1'),
                             g.user)
示例#43
0
 def test_url_for(self):
     """Api.url_for builds the URL of a resource registered with an int converter."""
     flask_app = Flask(__name__)
     rest_api = flask_restful.Api(flask_app)
     rest_api.add_resource(HelloWorld, '/ids/<int:id>')
     expected_url = '/ids/123'
     with flask_app.test_request_context('/foo'):
         self.assertEqual(rest_api.url_for(HelloWorld, id=123), expected_url)
示例#44
0
class TestOIDCAuthentication(object):
    # Both mocks report 2017-01-01 00:00 local time: as a float for the
    # patched time.time, and truncated to an int for the patched
    # oic utc_time_sans_frac.
    mock_time = Mock()
    mock_time.return_value = time.mktime(datetime(2017, 1, 1).timetuple())
    mock_time_int = Mock()
    mock_time_int.return_value = int(mock_time.return_value)

    @pytest.fixture(autouse=True)
    def create_flask_app(self):
        """Give every test a fresh Flask app served as 'localhost' with a session secret key."""
        settings = {
            'SERVER_NAME': 'localhost',
            'SECRET_KEY': 'test_key',
        }
        self.app = Flask(__name__)
        self.app.config.update(settings)

    @responses.activate
    def test_store_internal_redirect_uri_on_static_client_reg(self):
        """A statically registered client gets exactly one redirect URI: the app's /redirect_uri."""
        discovery_document = dict(issuer=ISSUER, token_endpoint=ISSUER + '/token')
        responses.add(responses.GET,
                      ISSUER + '/.well-known/openid-configuration',
                      body=json.dumps(discovery_document),
                      content_type='application/json')

        authn = OIDCAuthentication(self.app,
                                   issuer=ISSUER,
                                   client_registration_info=dict(
                                       client_id='abc', client_secret='foo'))

        redirect_uris = authn.client.registration_response['redirect_uris']
        assert len(redirect_uris) == 1
        assert redirect_uris[0] == 'http://localhost/redirect_uri'

    @pytest.mark.parametrize('method', ['GET', 'POST'])
    def test_configurable_userinfo_endpoint_method_is_used(self, method):
        """The configured userinfo_endpoint_method is passed to the userinfo request."""
        state, nonce, sub = 'state', 'nonce', 'foobar'
        authn = OIDCAuthentication(
            self.app,
            provider_configuration_info={'issuer': ISSUER, 'token_endpoint': '/token'},
            client_registration_info={'client_id': 'foo'},
            userinfo_endpoint_method=method)

        token_response = AccessTokenResponse(
            id_token=IdToken(sub=sub, nonce=nonce),
            access_token='access_token')
        authn.client.do_access_token_request = MagicMock(return_value=token_response)

        userinfo_request_mock = MagicMock(return_value=OpenIDSchema(sub=sub))
        authn.client.do_user_info_request = userinfo_request_mock

        callback_url = '/redirect_uri?code=foo&state={}'.format(state)
        with self.app.test_request_context(callback_url):
            for key, value in (('state', state), ('nonce', nonce), ('destination', '/')):
                flask.session[key] = value
            authn._handle_authentication_response()

        userinfo_request_mock.assert_called_with(method=method, state=state)

    def test_no_userinfo_request_is_done_if_no_userinfo_endpoint_method_is_specified(
            self):
        """With userinfo_endpoint_method=None, no userinfo request is sent."""
        authn = OIDCAuthentication(
            self.app,
            provider_configuration_info={'issuer': ISSUER},
            client_registration_info={'client_id': 'foo'},
            userinfo_endpoint_method=None)
        userinfo_spy = MagicMock()
        authn.client.do_user_info_request = userinfo_spy

        authn._do_userinfo_request('state', None)

        assert not userinfo_spy.called

    def test_authenticatate_with_extra_request_parameters(self):
        """extra_request_args appear in the query string of the authentication redirect."""
        extra_params = {"foo": "bar", "abc": "xyz"}
        authn = OIDCAuthentication(
            self.app,
            provider_configuration_info={'issuer': ISSUER},
            client_registration_info={'client_id': 'foo'},
            extra_request_args=extra_params)

        with self.app.test_request_context('/'):
            auth_redirect = authn._authenticate()

        query_string = urlparse(auth_redirect.location).query
        request_params = dict(parse_qsl(query_string))
        assert set(extra_params.items()) <= set(request_params.items())

    def test_reauthenticate_if_no_session(self):
        """Without a session, the protected view starts authorization instead of running."""
        authn = OIDCAuthentication(
            self.app,
            provider_configuration_info={'issuer': ISSUER},
            client_registration_info={'client_id': 'foo'})
        client_mock, callback_mock = MagicMock(), MagicMock()
        callback_mock.__name__ = 'test_callback'  # required for Python 2
        authn.client = client_mock

        with self.app.test_request_context('/'):
            protected_view = authn.oidc_auth(callback_mock)
            protected_view()

        assert client_mock.construct_AuthorizationRequest.called
        assert not callback_mock.called

    def test_reauthenticate_silent_if_refresh_expired(self):
        """Once the refresh interval has passed, re-authentication uses prompt=none."""
        authn = OIDCAuthentication(
            self.app,
            provider_configuration_info={'issuer': ISSUER},
            client_registration_info={
                'client_id': 'foo',
                'session_refresh_interval_seconds': 1
            })
        client_mock = MagicMock()
        callback_mock = MagicMock()
        callback_mock.__name__ = 'test_callback'  # required for Python 2
        authn.client = client_mock
        with self.app.test_request_context('/'):
            # Authenticated well before the 1 s refresh interval. Using exactly
            # the interval ("time.time() - 1") may put the expiry check on its
            # boundary, so the outcome would depend on the clock advancing
            # between two time.time() calls (not guaranteed on coarse clocks).
            flask.session['last_authenticated'] = time.time() - 10
            authn.oidc_auth(callback_mock)()
        assert client_mock.construct_AuthorizationRequest.called
        assert client_mock.construct_AuthorizationRequest.call_args[1][
            'request_args']['prompt'] == 'none'
        assert not callback_mock.called

    def test_dont_reauthenticate_silent_if_authentication_not_expired(self):
        """Within the refresh interval, the protected view runs without a new authorization request."""
        registration = {
            'client_id': 'foo',
            'session_refresh_interval_seconds': 999,
        }
        authn = OIDCAuthentication(
            self.app,
            provider_configuration_info={'issuer': ISSUER},
            client_registration_info=registration)
        client_mock, callback_mock = MagicMock(), MagicMock()
        callback_mock.__name__ = 'test_callback'  # required for Python 2
        authn.client = client_mock

        with self.app.test_request_context('/'):
            # freshly authenticated
            flask.session['last_authenticated'] = time.time()
            protected_view = authn.oidc_auth(callback_mock)
            protected_view()

        assert not client_mock.construct_AuthorizationRequest.called
        assert callback_mock.called

    @patch('time.time', mock_time)
    @patch('oic.utils.time_util.utc_time_sans_frac', mock_time_int)
    @responses.activate
    def test_session_expiration_set_to_id_token_exp(self):
        """With SESSION_PERMANENT, the session lifetime is the ID token's exp minus the (mocked) current time."""
        token_endpoint = ISSUER + '/token'
        userinfo_endpoint = ISSUER + '/userinfo'
        exp_time = 10
        issued_at = int(time.mktime(datetime(2017, 1, 1).timetuple()))
        id_token = IdToken(sub='sub1',
                           iat=issued_at,
                           iss=ISSUER,
                           aud='foo',
                           nonce='test',
                           exp=issued_at + exp_time)

        mocked_endpoints = (
            (token_endpoint, {
                'access_token': 'test',
                'token_type': 'Bearer',
                'id_token': id_token.to_jwt()
            }),
            (userinfo_endpoint, {'sub': 'sub1'}),
        )
        for url, payload in mocked_endpoints:
            responses.add(responses.POST,
                          url,
                          body=json.dumps(payload),
                          content_type='application/json')

        provider_info = {
            'issuer': ISSUER,
            'token_endpoint': token_endpoint,
            'userinfo_endpoint': userinfo_endpoint,
        }
        client_info = {'client_id': 'foo', 'client_secret': 'foo'}
        authn = OIDCAuthentication(self.app,
                                   provider_configuration_info=provider_info,
                                   client_registration_info=client_info)

        self.app.config.update({'SESSION_PERMANENT': True})
        with self.app.test_request_context('/redirect_uri?state=test&code=test'):
            flask.session['destination'] = '/'
            flask.session['state'] = 'test'
            flask.session['nonce'] = 'test'
            authn._handle_authentication_response()
            assert flask.session.permanent
            assert int(flask.session.permanent_session_lifetime) == exp_time

    def test_logout(self):
        """_logout clears the session tokens and answers with a 303 to the provider's end_session endpoint."""
        end_session_endpoint = 'https://provider.example.com/end_session'
        post_logout_uri = 'https://client.example.com/post_logout'
        provider_info = {
            'issuer': ISSUER,
            'end_session_endpoint': end_session_endpoint,
        }
        client_info = {
            'client_id': 'foo',
            'post_logout_redirect_uris': [post_logout_uri],
        }
        authn = OIDCAuthentication(self.app,
                                   provider_configuration_info=provider_info,
                                   client_registration_info=client_info)
        id_token = IdToken(sub='sub1', nonce='nonce')
        cleared_keys = ('access_token', 'userinfo', 'id_token', 'id_token_jwt')

        with self.app.test_request_context('/logout'):
            session_values = {
                'access_token': 'abcde',
                'userinfo': {'foo': 'bar', 'abc': 'xyz'},
                'id_token': id_token.to_dict(),
                'id_token_jwt': id_token.to_jwt(),
            }
            for key, value in session_values.items():
                flask.session[key] = value

            end_session_redirect = authn._logout()
            assert not any(key in flask.session for key in cleared_keys)

            assert end_session_redirect.status_code == 303
            location = end_session_redirect.headers['Location']
            assert location.startswith(end_session_endpoint)
            parsed_request = dict(parse_qsl(urlparse(location).query))
            assert parsed_request['state'] == flask.session['end_session_state']
            assert parsed_request['id_token_hint'] == id_token.to_jwt()
            assert parsed_request['post_logout_redirect_uri'] == post_logout_uri

    def test_logout_handles_provider_without_end_session_endpoint(self):
        """Without an end_session endpoint, _logout still clears the session but returns no redirect."""
        post_logout_uri = 'https://client.example.com/post_logout'
        authn = OIDCAuthentication(
            self.app,
            provider_configuration_info={'issuer': ISSUER},
            client_registration_info={
                'client_id': 'foo',
                'post_logout_redirect_uris': [post_logout_uri],
            })
        id_token = IdToken(sub='sub1', nonce='nonce')
        cleared_keys = ('access_token', 'userinfo', 'id_token', 'id_token_jwt')

        with self.app.test_request_context('/logout'):
            session_values = {
                'access_token': 'abcde',
                'userinfo': {'foo': 'bar', 'abc': 'xyz'},
                'id_token': id_token.to_dict(),
                'id_token_jwt': id_token.to_jwt(),
            }
            for key, value in session_values.items():
                flask.session[key] = value

            end_session_redirect = authn._logout()
            assert not any(key in flask.session for key in cleared_keys)

        assert end_session_redirect is None

    def test_oidc_logout_redirects_to_provider(self):
        """The oidc_logout-wrapped view first redirects (303) to the provider instead of running."""
        end_session_endpoint = 'https://provider.example.com/end_session'
        post_logout_uri = 'https://client.example.com/post_logout'
        provider_info = {
            'issuer': ISSUER,
            'end_session_endpoint': end_session_endpoint,
        }
        client_info = {
            'client_id': 'foo',
            'post_logout_redirect_uris': [post_logout_uri],
        }
        authn = OIDCAuthentication(self.app,
                                   provider_configuration_info=provider_info,
                                   client_registration_info=client_info)
        callback_mock = MagicMock()
        callback_mock.__name__ = 'test_callback'  # required for Python 2
        id_token = IdToken(sub='sub1', nonce='nonce')

        with self.app.test_request_context('/logout'):
            flask.session['id_token_jwt'] = id_token.to_jwt()
            logout_view = authn.oidc_logout(callback_mock)
            resp = logout_view()

        assert resp.status_code == 303
        assert not callback_mock.called

    def test_oidc_logout_handles_redirects_from_provider(self):
        """Returning from the provider with the stored state runs the view and drops end_session_state."""
        end_session_endpoint = 'https://provider.example.com/end_session'
        post_logout_uri = 'https://client.example.com/post_logout'
        provider_info = {
            'issuer': ISSUER,
            'end_session_endpoint': end_session_endpoint,
        }
        client_info = {
            'client_id': 'foo',
            'post_logout_redirect_uris': [post_logout_uri],
        }
        authn = OIDCAuthentication(self.app,
                                   provider_configuration_info=provider_info,
                                   client_registration_info=client_info)
        callback_mock = MagicMock()
        callback_mock.__name__ = 'test_callback'  # required for Python 2
        state = 'end_session_123'

        with self.app.test_request_context('/logout?state={}'.format(state)):
            flask.session['end_session_state'] = state
            logout_view = authn.oidc_logout(callback_mock)
            logout_view()
            assert 'end_session_state' not in flask.session

        assert callback_mock.called

    def test_authentication_error_reponse_calls_to_error_view_if_set(self):
        """An error in the authentication response is passed to the error view.

        The ``error``/``error_description`` query parameters must be forwarded
        to the configured error view as keyword arguments.
        """
        state = 'test_tate'
        error_response = {
            'error': 'invalid_request',
            'error_description': 'test error'
        }
        authn = OIDCAuthentication(
            self.app,
            provider_configuration_info={'issuer': ISSUER},
            client_registration_info={'client_id': 'abc',
                                      'client_secret': 'foo'})
        error_view = MagicMock()
        authn._error_view = error_view

        redirect_url = '/redirect_uri?{error}&state={state}'.format(
            error=urlencode(error_response), state=state)
        with self.app.test_request_context(redirect_url):
            flask.session['state'] = state
            authn._handle_authentication_response()

        error_view.assert_called_with(**error_response)

    def test_authentication_error_reponse_returns_default_error_if_no_error_view_set(
            self):
        """Without an error view, an authentication error yields a fixed message."""
        state = 'test_tate'
        error_response = {
            'error': 'invalid_request',
            'error_description': 'test error'
        }
        authn = OIDCAuthentication(
            self.app,
            provider_configuration_info={'issuer': ISSUER},
            client_registration_info={'client_id': 'abc',
                                      'client_secret': 'foo'})

        redirect_url = '/redirect_uri?{error}&state={state}'.format(
            error=urlencode(error_response), state=state)
        with self.app.test_request_context(redirect_url):
            flask.session['state'] = state
            response = authn._handle_authentication_response()

        expected = "Something went wrong with the authentication, please try to login again."
        assert response == expected

    @responses.activate
    def test_token_error_reponse_calls_to_error_view_if_set(self):
        """A token-endpoint error response is passed to the error view.

        The mocked token endpoint answers with an OAuth error JSON body; its
        fields must reach the configured error view as keyword arguments.
        """
        token_endpoint = ISSUER + '/token'
        error_response = {
            'error': 'invalid_request',
            'error_description': 'test error'
        }
        responses.add(responses.POST, token_endpoint,
                      body=json.dumps(error_response),
                      content_type='application/json')

        provider_info = {'issuer': ISSUER, 'token_endpoint': token_endpoint}
        client_info = {'client_id': 'abc', 'client_secret': 'foo'}
        authn = OIDCAuthentication(self.app,
                                   provider_configuration_info=provider_info,
                                   client_registration_info=client_info)
        error_view = MagicMock()
        authn._error_view = error_view

        state = 'test_tate'
        redirect_url = '/redirect_uri?code=foo&state=' + state
        with self.app.test_request_context(redirect_url):
            flask.session['state'] = state
            authn._handle_authentication_response()

        error_view.assert_called_with(**error_response)

    @responses.activate
    def test_token_error_reponse_returns_default_error_if_no_error_view_set(
            self):
        """Without an error view, a token-endpoint error yields a fixed message."""
        token_endpoint = ISSUER + '/token'
        error_response = {
            'error': 'invalid_request',
            'error_description': 'test error'
        }
        responses.add(responses.POST, token_endpoint,
                      body=json.dumps(error_response),
                      content_type='application/json')

        provider_info = {'issuer': ISSUER, 'token_endpoint': token_endpoint}
        client_info = {'client_id': 'abc', 'client_secret': 'foo'}
        authn = OIDCAuthentication(self.app,
                                   provider_configuration_info=provider_info,
                                   client_registration_info=client_info)

        state = 'test_tate'
        redirect_url = '/redirect_uri?code=foo&state=' + state
        with self.app.test_request_context(redirect_url):
            flask.session['state'] = state
            response = authn._handle_authentication_response()

        expected = "Something went wrong with the authentication, please try to login again."
        assert response == expected
示例#45
0
 def test_resource_head(self):
     """A HEAD request to a bare ``Resource`` makes dispatch_request assert.

     The base ``flask_restful.Resource`` is instantiated directly, with no
     handler methods defined, and dispatching must raise ``AssertionError``.
     """
     app = Flask(__name__)
     resource = flask_restful.Resource()
     with app.test_request_context("/foo", method="HEAD"):
         with self.assertRaises(AssertionError):
             resource.dispatch_request()
示例#46
0
    return "Post %d" % post_id


@app.route("/user/<path:subpath>")
def show_subpath(subpath):
    """Echo the part of a ``/user/...`` URL that follows ``/user/``.

    The ``path`` converter accepts slashes, so ``/user/a/b`` yields
    ``subpath == "a/b"``.

    :param subpath: the remainder of the URL path after ``/user/``.
    :return: the text ``"Subpath <subpath>"``, with *subpath* passed
        through ``escape`` before it is interpolated into the response.
    """
    # escape() keeps user-controlled URL text from being rendered as markup.
    return "Subpath %s" % escape(subpath)


# Demonstrate URL building: url_for() needs an active request context, so a
# test request context stands in for a real incoming request.
with app.test_request_context():
    for _endpoint, _url_values in (
        ("show_user_profile", {"username": "******"}),
        ("show_post", {"post_id": 123}),
        ("show_subpath", {"subpath": "path/next"}),
    ):
        print(url_for(_endpoint, **_url_values))

# Demonstrate context locals: inside a simulated POST to /login, the
# ``request`` proxy reports that path and method.
with app.test_request_context("/login", method="POST"):
    assert request.path == "/login"
    assert request.method == "POST"

# The login() view shows more request handling.

if __name__ == "__main__":
    app.run()
class FlaskSessionCaptchaTestCase(unittest.TestCase):
    """Tests for FlaskSessionCaptcha on top of a Flask-Session store.

    ``setUp`` configures a ``sqlalchemy`` session backend on SQLite
    (``sqlite://``) with captcha enabled and ``CAPTCHA_LENGTH`` 5; individual
    tests adjust the config before creating the extension. The routes come
    from ``_default_routes`` (defined elsewhere). Judging by the assertions,
    GET ``/`` returns the current captcha answer and POST ``/`` answers
    ``ok`` or ``nope`` -- TODO confirm against its definition.
    """

    def setUp(self):
        self.app = Flask(__name__)
        self.app.config['SECRET_KEY'] = 'aba'
        self.app.config['SQLALCHEMY_DATABASE_URI'] = 'sqlite://'
        self.app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False
        self.app.config['SESSION_TYPE'] = 'sqlalchemy'
        self.app.config['CAPTCHA_ENABLE'] = True
        self.app.config['CAPTCHA_LENGTH'] = 5
        self.app.testing = True
        Session(self.app)

        self.client = self.app.test_client()

    def test_captcha_wrong(self):
        captcha = FlaskSessionCaptcha(self.app)
        _default_routes(captcha, self.app)

        # try some wrong values
        r = self.client.post("/", data={"s": "something"})
        assert r.data == b"nope"
        r = self.client.post("/", data={"s": "something", "captcha": ""})
        assert r.data == b"nope"
        r = self.client.post("/",
                             data={
                                 "s": "something",
                                 "captcha": "also wrong"
                             })
        assert r.data == b"nope"

    def test_captcha_without_cookie(self):
        captcha = FlaskSessionCaptcha(self.app)
        _default_routes(captcha, self.app)

        # Fetch a valid answer, then replace the session cookie with a bogus
        # value so the server can no longer find the stored answer.
        r = self.client.get("/")
        self.client.set_cookie("localhost", "session", "wrong")
        r = self.client.post("/",
                             data={
                                 "s": "something",
                                 "captcha": r.data.decode('utf-8')
                             })
        assert r.data == b"nope"  # no session

    def test_captcha_ok(self):
        captcha = FlaskSessionCaptcha(self.app)
        _default_routes(captcha, self.app)
        # everything ok
        r = self.client.get("/")
        r = self.client.post("/",
                             data={
                                 "s": "something",
                                 "captcha": r.data.decode('utf-8')
                             })
        assert r.data == b"ok"

    def test_captcha_replay(self):
        captcha = FlaskSessionCaptcha(self.app)
        _default_routes(captcha, self.app)

        r = self.client.get("/")
        captcha_value = r.data.decode('utf-8')

        # Remember the session cookie as it was before the first submission.
        cookies = self.client.cookie_jar._cookies['localhost.local']['/'][
            'session']
        r = self.client.post("/",
                             data={
                                 "s": "something",
                                 "captcha": captcha_value
                             })
        assert r.data == b"ok"
        # Replaying the old cookie with the same answer must be rejected.
        self.client.set_cookie("localhost", "session", cookies.value)
        r = self.client.post("/",
                             data={
                                 "s": "something",
                                 "captcha": captcha_value
                             })
        assert r.data == b"nope"

    def test_captcha_passthrough_when_disabled(self):
        self.app.config["CAPTCHA_ENABLE"] = False
        captcha = FlaskSessionCaptcha(self.app)
        _default_routes(captcha, self.app)

        r = self.client.post("/", data={"s": "something"})
        assert r.data == b"ok"
        r = self.client.get("/")
        captcha_value = r.data.decode('utf-8')
        r = self.client.post("/",
                             data={
                                 "s": "something",
                                 "captcha": captcha_value
                             })
        assert r.data == b"ok"
        r = self.client.post("/", data={"s": "something", "captcha": "false"})
        assert r.data == b"ok"

    def test_captcha_least_digits(self):
        self.app.config["CAPTCHA_LENGTH"] = 8
        captcha = FlaskSessionCaptcha(self.app)
        _default_routes(captcha, self.app)

        r = self.client.get("http://localhost:5000/")
        captcha_value = r.data.decode('utf-8')
        assert len(captcha_value) == 8

    def test_captcha_validate_value(self):
        captcha = FlaskSessionCaptcha(self.app)
        _default_routes(captcha, self.app)

        with self.app.test_request_context('/'):
            captcha.generate()
            answer = captcha.get_answer()
            assert not captcha.validate(value="wrong")
            captcha.generate()
            answer = captcha.get_answer()
            assert captcha.validate(value=answer)

    def test_captcha_jinja_global(self):
        captcha = FlaskSessionCaptcha(self.app)
        with self.app.test_request_context('/'):
            function = self.app.jinja_env.globals['captcha']
            assert not captcha.get_answer()
            img = function()
            assert "<img" in img
            # Rendering the global generated (and stored) a new answer.
            assert captcha.get_answer()

    def test_captcha_jinja_global_empty_while_disabled(self):
        self.app.config["CAPTCHA_ENABLE"] = False
        captcha = FlaskSessionCaptcha(self.app)
        with self.app.test_request_context('/'):
            function = self.app.jinja_env.globals['captcha']
            # get_answer() is expected to fail while the captcha is disabled.
            # assertRaises is used because an `assert False` inside a bare
            # `except:` block would be swallowed and could never fail.
            with self.assertRaises(Exception):
                captcha.get_answer()
            img = function()
            assert img == ""

    def test_captcha_warning_on_non_server_storage(self):
        # Session types that do not persist server-side must be refused.
        self.app.config['SESSION_TYPE'] = 'null'
        Session(self.app)
        with self.assertRaises(RuntimeWarning):
            FlaskSessionCaptcha(self.app)
        self.app.config['SESSION_TYPE'] = None
        Session(self.app)
        with self.assertRaises(RuntimeWarning):
            FlaskSessionCaptcha(self.app)

    def test_captcha_session_file_storage(self):
        self.app.config['SESSION_TYPE'] = 'filesystem'
        Session(self.app)
        captcha = FlaskSessionCaptcha(self.app)
        _default_routes(captcha, self.app)

        r = self.client.get("/")
        r = self.client.post("/",
                             data={
                                 "s": "something",
                                 "captcha": r.data.decode('utf-8')
                             })
        assert r.data == b"ok"

    def test_captcha_with_init_app_ok(self):
        captcha = FlaskSessionCaptcha()
        _default_routes(captcha, self.app)
        captcha.init_app(self.app)
        # everything ok
        r = self.client.get("/")
        r = self.client.post("/",
                             data={
                                 "s": "something",
                                 "captcha": r.data.decode('utf-8')
                             })
        assert r.data == b"ok"

    def tearDown(self):
        pass
class AvatarsTestCase(unittest.TestCase):
    """Tests for the Flask-Avatars extension, its URL helpers and image tools.

    ``setUp`` pushes a request context for the whole test, so ``current_app``
    and ``render_template_string`` can be used directly in every test.
    Files written into ``basedir`` are registered for deletion with
    ``addCleanup`` as soon as they exist, so they are removed even when an
    assertion later in the test fails.
    """

    def setUp(self):
        self.app = Flask(__name__)

        self.app.testing = True
        self.app.secret_key = 'for test'

        self.email_hash = hashlib.md5(
            '*****@*****.**'.lower().encode('utf-8')).hexdigest()

        avatars = Avatars(self.app)  # noqa

        # Keep both the extension instance and the _Avatars helper so the
        # *_mirror tests can check that they produce the same URLs.
        self.real_avatars = avatars

        self.avatars = _Avatars

        self.context = self.app.test_request_context()
        self.context.push()
        self.client = self.app.test_client()

    def tearDown(self):
        self.context.pop()

    def _remove_after_test(self, *filenames):
        """Schedule ``basedir/<filename>`` for deletion when the test ends.

        Uses ``addCleanup`` so removal happens whether the test passes or
        fails.
        """
        for filename in filenames:
            self.addCleanup(self._remove_if_present,
                            os.path.join(basedir, filename))

    @staticmethod
    def _remove_if_present(path):
        """Delete *path* if it exists (it may not, if the test failed early)."""
        if os.path.exists(path):
            os.remove(path)

    def _assert_saved(self, filenames):
        """Assert every name in *filenames* exists under AVATARS_SAVE_PATH."""
        save_path = current_app.config['AVATARS_SAVE_PATH']
        for filename in filenames:
            self.assertTrue(os.path.exists(os.path.join(save_path, filename)))

    def _assert_first_size_width(self, filename):
        """Assert *filename*'s width equals AVATARS_SIZE_TUPLE[0]."""
        path = os.path.join(current_app.config['AVATARS_SAVE_PATH'], filename)
        # The context manager closes the image even if the assertion fails.
        with Image.open(path) as file_s:
            self.assertEqual(file_s.size[0],
                             current_app.config['AVATARS_SIZE_TUPLE'][0])

    def test_extension_init(self):
        self.assertIn('avatars', current_app.extensions)

    def test_gravatar(self):
        avatar_url = self.avatars.gravatar(self.email_hash)
        self.assertIn('https://gravatar.com/avatar/%s' % self.email_hash,
                      avatar_url)
        self.assertIn('s=100', avatar_url)
        self.assertIn('r=g', avatar_url)
        self.assertIn('d=identicon', avatar_url)

        avatar_url = self.avatars.gravatar(self.email_hash,
                                           size=200,
                                           rating='x',
                                           default='monsterid',
                                           include_extension=True)
        self.assertIn('https://gravatar.com/avatar/%s' % self.email_hash,
                      avatar_url)
        self.assertIn('s=200', avatar_url)
        self.assertIn('r=x', avatar_url)
        self.assertIn('d=monsterid', avatar_url)

    def test_robohash(self):
        avatar_url = self.avatars.robohash(self.email_hash)
        self.assertEqual(
            avatar_url,
            'https://robohash.org/%s?size=200x200' % self.email_hash)

        avatar_url = self.avatars.robohash(self.email_hash, size=100)
        self.assertEqual(
            avatar_url,
            'https://robohash.org/%s?size=100x100' % self.email_hash)

    def test_social_media(self):
        avatar_url = self.avatars.social_media('greyli')
        self.assertEqual(avatar_url,
                         'https://avatars.io/twitter/greyli/medium')

        avatar_url = self.avatars.social_media('greyli',
                                               platform='facebook',
                                               size='small')
        self.assertEqual(avatar_url,
                         'https://avatars.io/facebook/greyli/small')

    def test_default_avatar(self):
        avatar_url = self.avatars.default()
        self.assertEqual(avatar_url, '/avatars/static/default/default_m.jpg')

        avatar_url = self.avatars.default(size='l')
        self.assertEqual(avatar_url, '/avatars/static/default/default_l.jpg')

        avatar_url = self.avatars.default(size='s')
        self.assertEqual(avatar_url, '/avatars/static/default/default_s.jpg')

        response = self.client.get(avatar_url)
        self.assertEqual(response.status_code, 200)

    def test_load_jcrop(self):
        rv = self.avatars.jcrop_css()
        self.assertIn('<link rel="stylesheet" href="https://cdn.jsdelivr.net',
                      rv)

        rv = self.avatars.jcrop_js()
        self.assertIn('<script src="https://cdn.jsdelivr.net', rv)
        self.assertIn('jquery.min.js', rv)

        rv = self.avatars.jcrop_js(with_jquery=False)
        self.assertIn('jquery.Jcrop.min.js', rv)
        self.assertNotIn('jquery.min.js', rv)

    def test_local_resources(self):
        current_app.config['AVATARS_SERVE_LOCAL'] = True

        response = self.client.get(
            '/avatars/static/jcrop/js/jquery.Jcrop.min.js')
        self.assertEqual(response.status_code, 200)

        response = self.client.get(
            '/avatars/static/jcrop/css/jquery.Jcrop.min.css')
        self.assertEqual(response.status_code, 200)

        response = self.client.get('/avatars/static/jcrop/js/jquery.min.js')
        self.assertEqual(response.status_code, 200)

        rv = self.avatars.jcrop_css()
        self.assertIn('/avatars/static/jcrop/css/jquery.Jcrop.min.css', rv)
        self.assertNotIn(
            '<link rel="stylesheet" href="https://cdn.jsdelivr.net', rv)

        rv = self.avatars.jcrop_js()
        self.assertIn('/avatars/static/jcrop/js/jquery.Jcrop.min.js', rv)
        self.assertIn('/avatars/static/jcrop/js/jquery.min.js', rv)
        self.assertNotIn('<script src="https://cdn.jsdelivr.net', rv)

        rv = self.avatars.jcrop_js(with_jquery=False)
        self.assertIn('/avatars/static/jcrop/js/jquery.Jcrop.min.js', rv)
        self.assertNotIn('jquery.min.js', rv)

    def test_init_jcrop(self):
        rv = self.avatars.init_jcrop()
        self.assertIn('var jcrop_api,', rv)

    def test_crop_box(self):
        rv = self.avatars.crop_box()
        self.assertIn('id="crop-box"', rv)

    def test_preview_box(self):
        rv = self.avatars.preview_box()
        self.assertIn('<div id="preview-box">', rv)

    def test_render_template(self):
        rv = render_template_string('''{{ avatars.jcrop_css() }}''')
        self.assertIn('<link rel="stylesheet" href="https://cdn.jsdelivr.net',
                      rv)

        rv = render_template_string('''{{ avatars.jcrop_js() }}''')
        self.assertIn('<script src="https://cdn.jsdelivr.net', rv)

        rv = render_template_string('''{{ avatars.init_jcrop() }}''')
        self.assertIn('var jcrop_api,', rv)

        rv = render_template_string('''{{ avatars.crop_box() }}''')
        self.assertIn('id="crop-box"', rv)

        rv = render_template_string('''{{ avatars.preview_box() }}''')
        self.assertIn('<div id="preview-box">', rv)

    def test_resize_avatar(self):
        current_app.config['AVATARS_SAVE_PATH'] = basedir
        img = Image.new(mode='RGB', size=(800, 800), color=(125, 125, 125))
        resized = self.real_avatars.resize_avatar(img, 300)
        self.assertEqual(resized.size[0], 300)

    def test_save_avatar(self):
        current_app.config['AVATARS_SAVE_PATH'] = basedir
        img = Image.new(mode='RGB', size=(800, 800), color=(125, 125, 125))
        filename = self.real_avatars.save_avatar(img)
        self._remove_after_test(filename)
        self.assertIn('_raw.png', filename)
        self.assertTrue(os.path.exists(os.path.join(basedir, filename)))

    def test_crop_avatar(self):
        current_app.config['AVATARS_SAVE_PATH'] = basedir
        img = Image.new(mode='RGB', size=(1000, 1000), color=(125, 125, 125))
        img.save(os.path.join(basedir, 'test.png'))
        self._remove_after_test('test.png')
        filenames = self.real_avatars.crop_avatar('test.png',
                                                  x=1,
                                                  y=1,
                                                  w=30,
                                                  h=30)
        self._remove_after_test(*filenames)
        self.assertIn('_s.png', filenames[0])
        self.assertIn('_m.png', filenames[1])
        self.assertIn('_l.png', filenames[2])

        self._assert_saved(filenames[:3])
        self._assert_first_size_width(filenames[0])

    def test_crop_default_avatar(self):
        current_app.config['AVATARS_SAVE_PATH'] = basedir
        filenames = self.real_avatars.crop_avatar(None, x=1, y=1, w=100, h=100)
        self._remove_after_test(*filenames)
        self.assertIn('_s.png', filenames[0])
        self.assertIn('_m.png', filenames[1])
        self.assertIn('_l.png', filenames[2])

        self._assert_saved(filenames[:3])
        self._assert_first_size_width(filenames[0])

    def test_gravatar_mirror(self):
        mirror = self.real_avatars.gravatar(self.email_hash)
        real = self.avatars.gravatar(self.email_hash)
        self.assertEqual(mirror, real)

    def test_robohash_mirror(self):
        mirror = self.real_avatars.robohash(self.email_hash)
        real = self.avatars.robohash(self.email_hash)
        self.assertEqual(mirror, real)

    def test_social_media_mirror(self):
        mirror = self.real_avatars.social_media('grey')
        real = self.avatars.social_media('grey')
        self.assertEqual(mirror, real)

    def test_default_avatar_mirror(self):
        mirror = self.real_avatars.default()
        real = self.avatars.default()
        self.assertEqual(mirror, real)

    def test_identicon(self):
        current_app.config['AVATARS_SAVE_PATH'] = basedir

        avatar = Identicon()
        filenames = avatar.generate(text='grey')
        # To inspect the generated images, comment out this line and then
        # delete the files manually.
        self._remove_after_test(*filenames)
        self.assertEqual(filenames[0], 'grey_s.png')
        self.assertEqual(filenames[1], 'grey_m.png')
        self.assertEqual(filenames[2], 'grey_l.png')

        self._assert_saved(filenames[:3])
示例#49
0
    result = re.split('\n', result)
    url = 'http://' + str(config['ip_diskstation']) + ':' + str(config['port'])
    syno = synology.synology(url, config['usuario'], config['password'])
    toAdd = ''
    for download in result:
        download = download.replace('\r', '')
        toAdd += download + ', '
    t = 'resultado.html'
    if (syno.addDownload(toAdd)):
        resultado = u'Añadido'
    else:
        resultado = u'Algo a fallado'
    return render_template(t, resultado=resultado)


# Simulate requests against the app. Inside each test request context, the
# ``request`` context local reflects the given path and HTTP method, so
# basic assertions can be made until the end of the ``with`` block.
with app.test_request_context('/', method='POST'):
    assert request.path == '/'
    assert request.method == 'POST'

with app.test_request_context('/enlaces', method='POST'):
    assert request.path == '/enlaces'
    assert request.method == 'POST'

with app.test_request_context('/enlaces', method='GET'):
    assert request.path == '/enlaces'
    # Also check the method, as the two contexts above do.
    assert request.method == 'GET'
class OnFetchedMethodTestCase(ItemsServiceTestCase):
    """Tests for the on_fetched() method.

    Each test runs with ``PUBLICAPI_URL`` and ``URLS`` configured on a fresh
    Flask app, and with an app context plus a test request context for
    ``items/`` pushed. Both contexts are popped again in ``tearDown``.
    """

    def setUp(self):
        super().setUp()

        app = Flask(__name__)
        app.config['PUBLICAPI_URL'] = 'http://content_api.com'
        app.config['URLS'] = {'items': 'items_endpoint'}
        self.app = app

        self.app_context = app.app_context()
        self.app_context.push()
        self.req_context = app.test_request_context('items/')
        self.req_context.push()

    def tearDown(self):
        # Pop in reverse push order before the base class cleans up.
        self.req_context.pop()
        self.app_context.pop()
        super().tearDown()

    def _fetch(self, result):
        """Run ``on_fetched`` on *result* with a new items-datasource service."""
        service = self._make_one(datasource='items')
        service.on_fetched(result)

    def test_sets_uri_field_on_all_fetched_documents(self):
        result = {'_items': [
            {'_id': 'item:123', 'headline': 'a test item'},
            {'_id': 'item:555', 'headline': 'another item'},
        ]}

        self._fetch(result)

        # %3A == urlquote(':')
        expected_uris = (
            'http://content_api.com/items_endpoint/item%3A123',
            'http://content_api.com/items_endpoint/item%3A555',
        )
        documents = result['_items']
        for position, expected_uri in enumerate(expected_uris):
            self.assertEqual(documents[position].get('uri'), expected_uri)

    def test_removes_non_ninjs_content_fields_from_all_fetched_documents(self):
        result = {'_items': [
            {'_id': 'item:123', '_etag': '12345abcde',
             '_created': '12345abcde', '_updated': '12345abcde',
             'headline': 'breaking news'},
            {'_id': 'item:555', '_etag': '67890fedcb',
             '_created': '2121abab', '_updated': '2121abab',
             'headline': 'good news'},
        ]}

        self._fetch(result)

        stripped_fields = ('_created', '_etag', '_id', '_updated')
        for document in result['_items']:
            for name in stripped_fields:
                self.assertNotIn(name, document)

    def test_does_not_remove_hateoas_links_from_fetched_documents(self):
        def make_links(href):
            # A fresh dict on every call, so the expected value is never the
            # same object that on_fetched() received (and could have mutated).
            return {'self': {'href': href, 'title': 'Item'}}

        hrefs = ('link/to/item_123', 'link/to/item_555')
        result = {'_items': [
            {'_id': 'item:123', '_etag': '12345abcde',
             '_created': '12345abcde', '_updated': '12345abcde',
             'headline': 'breaking news',
             '_links': make_links('link/to/item_123')},
            {'_id': 'item:555', '_etag': '67890fedcb',
             '_created': '2121abab', '_updated': '2121abab',
             'headline': 'good news',
             '_links': make_links('link/to/item_555')},
        ]}

        self._fetch(result)

        documents = result['_items']
        for position, href in enumerate(hrefs):
            self.assertEqual(documents[position].get('_links'),
                             make_links(href))

    def test_sets_collection_self_link_to_relative_original_url(self):
        result = {'_items': [], '_links': {'self': {'href': 'foo/bar/baz'}}}

        # The request URL carries a fragment that must not appear in the link.
        with self.app.test_request_context('items?start_date=1975-12-31#foo'):
            self._fetch(result)

        self_link = result.get('_links', {}).get('self', {}).get('href')
        self.assertEqual(self_link, 'items?start_date=1975-12-31')
示例#51
0
class TestJWTManager(unittest.TestCase):
    """Tests for JWTManager construction and its default and custom callbacks."""

    def setUp(self):
        self.app = Flask(__name__)

    def _parse_callback_result(self, result):
        """
        Returns a tuple, where the first item is http status code and
        the second is the data (via json.loads)

        ``result`` is the ``(response, status_code)`` pair returned by a
        JWTManager callback.
        """
        response = result[0]
        status_code = result[1]
        data = json.loads(response.get_data(as_text=True))
        return status_code, data

    def _assert_callback_result(self, result, expected_status, expected_data):
        """Assert a callback result has the given status code and JSON body."""
        status_code, data = self._parse_callback_result(result)
        self.assertEqual(status_code, expected_status)
        self.assertEqual(data, expected_data)

    def test_init_app(self):
        jwt_manager = JWTManager()
        jwt_manager.init_app(self.app)
        self.assertIsInstance(jwt_manager, JWTManager)

    def test_class_init(self):
        jwt_manager = JWTManager(self.app)
        self.assertIsInstance(jwt_manager, JWTManager)

    def test_default_user_claims_callback(self):
        identity = 'foobar'
        m = JWTManager(self.app)
        self.assertEqual(m._user_claims_callback(identity), {})

    def test_default_user_identity_callback(self):
        identity = 'foobar'
        m = JWTManager(self.app)
        self.assertEqual(m._user_identity_callback(identity), identity)

    def test_default_expired_token_callback(self):
        with self.app.test_request_context():
            m = JWTManager(self.app)
            self._assert_callback_result(
                m._expired_token_callback(), 401,
                {'msg': 'Token has expired'})

    def test_default_invalid_token_callback(self):
        with self.app.test_request_context():
            m = JWTManager(self.app)
            err = "Test error"
            self._assert_callback_result(
                m._invalid_token_callback(err), 422, {'msg': err})

    def test_default_unauthorized_callback(self):
        with self.app.test_request_context():
            m = JWTManager(self.app)
            self._assert_callback_result(
                m._unauthorized_callback("Missing Authorization Header"), 401,
                {'msg': 'Missing Authorization Header'})

    def test_default_needs_fresh_token_callback(self):
        with self.app.test_request_context():
            m = JWTManager(self.app)
            self._assert_callback_result(
                m._needs_fresh_token_callback(), 401,
                {'msg': 'Fresh token required'})

    def test_default_revoked_token_callback(self):
        with self.app.test_request_context():
            m = JWTManager(self.app)
            self._assert_callback_result(
                m._revoked_token_callback(), 401,
                {'msg': 'Token has been revoked'})

    def test_default_user_loader_callback(self):
        m = JWTManager(self.app)
        self.assertIsNone(m._user_loader_callback)

    def test_default_user_loader_error_callback(self):
        with self.app.test_request_context():
            identity = 'foobar'
            m = JWTManager(self.app)
            self._assert_callback_result(
                m._user_loader_error_callback(identity), 401,
                {'msg': 'Error loading the user foobar'})

    def test_default_has_user_loader(self):
        m = JWTManager(self.app)
        self.assertFalse(m.has_user_loader())

    def test_custom_user_claims_callback(self):
        identity = 'foobar'
        m = JWTManager(self.app)

        @m.user_claims_loader
        def custom_user_claims(identity):
            return {'foo': 'bar'}

        # self.assertEqual instead of a bare assert: bare asserts are
        # stripped under ``python -O`` and give poorer failure messages.
        self.assertEqual(m._user_claims_callback(identity), {'foo': 'bar'})

    def test_custom_expired_token_callback(self):
        with self.app.test_request_context():
            m = JWTManager(self.app)

            @m.expired_token_loader
            def custom_expired_token():
                return jsonify({"res": "TOKEN IS EXPIRED FOOL"}), 422

            self._assert_callback_result(
                m._expired_token_callback(), 422,
                {'res': 'TOKEN IS EXPIRED FOOL'})

    def test_custom_invalid_token_callback(self):
        with self.app.test_request_context():
            m = JWTManager(self.app)
            err = "Test error"

            @m.invalid_token_loader
            def custom_invalid_token(err):
                return jsonify({"err": err}), 200

            self._assert_callback_result(
                m._invalid_token_callback(err), 200, {'err': err})

    def test_custom_unauthorized_callback(self):
        with self.app.test_request_context():
            m = JWTManager(self.app)

            @m.unauthorized_loader
            def custom_unauthorized(err_str):
                return jsonify({"err": err_str}), 200

            self._assert_callback_result(
                m._unauthorized_callback("GOTTA LOGIN FOOL"), 200,
                {'err': 'GOTTA LOGIN FOOL'})

    def test_custom_needs_fresh_token_callback(self):
        with self.app.test_request_context():
            m = JWTManager(self.app)

            @m.needs_fresh_token_loader
            def custom_token_needs_refresh():
                return jsonify({'sub_status': 101}), 200

            self._assert_callback_result(
                m._needs_fresh_token_callback(), 200, {'sub_status': 101})

    def test_custom_revoked_token_callback(self):
        with self.app.test_request_context():
            m = JWTManager(self.app)

            @m.revoked_token_loader
            def custom_revoked_token():
                return jsonify({"err": "Nice knowing you!"}), 422

            self._assert_callback_result(
                m._revoked_token_callback(), 422,
                {'err': 'Nice knowing you!'})

    def test_custom_user_loader(self):
        with self.app.test_request_context():
            m = JWTManager(self.app)

            @m.user_loader_callback_loader
            def custom_user_loader(identity):
                if identity == 'foo':
                    return None
                return identity

            identity = 'foobar'
            self.assertEqual(m._user_loader_callback(identity), identity)
            # Also exercise the "user not found" branch of the custom loader.
            self.assertIsNone(m._user_loader_callback('foo'))
            self.assertTrue(m.has_user_loader())

    def test_custom_user_loader_error_callback(self):
        with self.app.test_request_context():
            m = JWTManager(self.app)

            @m.user_loader_error_loader
            def custom_user_loader_error(identity):
                return jsonify({'msg': 'Not found'}), 404

            identity = 'foobar'
            self._assert_callback_result(
                m._user_loader_error_callback(identity), 404,
                {'msg': 'Not found'})
示例#52
0
def create_app():
    """Build and configure the Flask application.

    Registers extensions and all feature blueprints, applies the app-wide
    configuration (debug, security, mail, CSRF, debug-toolbar settings),
    creates the database tables, and wires up Flask-Security plus admin
    views for ``User`` and ``Role``.

    Returns:
        The configured ``Flask`` application.
    """
    app = Flask(__name__)
    app.config["SQLALCHEMY_DATABASE_URI"] = 'sqlite:///flask-admin.db'

    register_extensions(app)

    # Create modules (registration order preserved).
    for blueprint in (
        indexModule,
        controlModule,
        personalCenterModule,
        registeredModule,
        inforDetectionModule,
        monitorModule,
        behavioGaugeModule,
        reportManageModule,
        knowledgeModule,
        systemManageModule,
        systemmanageModule,
        # Weibo
        weiboxnroperateModule,
        weiboxnrcreateModule,
        weiboxnrmanageModule,
        weiboxnrassessmentModule,
        weiboxnrknowledgebasemanagementModule,
        weiboxnrmonitorModule,
        weiboxnrwarmingModule,
        weiboxnrwarmingnewModule,
        weiboxnrreportmanageModule,
        weiboxnrcommunityModule,
        # QQ
        qqxnrmanageModule,
        qqxnroperateModule,
        qqxnrassessmentModule,
        qqxnrmonitorModule,
        qqxnrreportmanageModule,
        qqxnrwarmingModule,
        # WeChat
        wxxnrmanageModule,
        wxxnroperateModule,
        wxxnrmonitorModule,
        wxxnrassessmentModule,
        wxxnrreportmanageModule,
        wxxnrwarningModule,
        # Facebook
        facebookxnrcreateModule,
        facebookxnrwarningModule,
        facebookxnrmonitorModule,
        facebookxnrassessmentModule,
        facebookxnrmanageModule,
        facebookxnrknowledgebasemanagementModule,
        facebookxnrreportmanageModule,
        facebookxnrcommunityModule,
        # Twitter
        twitterxnrassessmentModule,
        twitterxnrcreateModule,
        twitterxnrwarningModule,
        twitterxnrmonitorModule,
        twitterxnrmanageModule,
        twitterxnrknowledgebasemanagementModule,
        twitterxnrreportmanageModule,
        # Operate modules
        facebookxnroperateModule,
        twitterxnroperateModule,
        intelligentwritingModule,
    ):
        app.register_blueprint(blueprint)
    # app.register_blueprint(commoncorpusmanagementModule)

    app.config['DEBUG'] = True

    app.config['ADMINS'] = frozenset(['*****@*****.**'])
    # NOTE(review): hard-coded signing secrets; consider loading them from
    # the environment like the mail credentials below.
    app.config['SECRET_KEY'] = 'SecretKeyForSessionSigning'

    app.config['SECURITY_REGISTERABLE'] = True
    app.config['SECURITY_REGISTER_URL'] = '/create_account'

    # Mail settings must be in app.config *before* Mail(app) runs: Flask-Mail
    # reads them once at init time, so values written afterwards are ignored
    # and the default server/port would be used instead.
    app.config['MAIL_SERVER'] = 'smtp.163.com'
    app.config['MAIL_PORT'] = 25
    app.config['MAIL_USE_TLS'] = True
    app.config['MAIL_USERNAME'] = os.environ.get('MAIL_USERNAME')
    app.config['MAIL_PASSWORD'] = os.environ.get('MAIL_PASSWORD')
    Mail(app)

    # app.config['SECURITY_CONFIRMABLE'] = True
    app.config['SECURITY_RETYPABLE'] = True
    # Alternative MySQL backend (disabled):
    # app.config['SQLALCHEMY_DATABASE_URI'] = 'mysql+mysqldb://%s:@%s/%s?charset=utf8' % (MYSQL_USER, MYSQL_HOST, MYSQL_DB)
    # app.config['SQLALCHEMY_ECHO'] = False
    app.config['DATABASE_CONNECT_OPTIONS'] = {}

    app.config['THREADS_PER_PAGE'] = 8

    app.config['CSRF_ENABLED'] = True
    app.config['CSRF_SESSION_KEY'] = 'somethingimpossibletoguess'

    # The debug toolbar is only enabled in debug mode.
    app.config['DEBUG_TB_ENABLED'] = app.debug
    # Should intercept redirects?
    app.config['DEBUG_TB_INTERCEPT_REDIRECTS'] = True
    # Enable the profiler on all requests, default to false
    app.config['DEBUG_TB_PROFILER_ENABLED'] = True
    # Enable the template editor, default to false
    app.config['DEBUG_TB_TEMPLATE_EDITOR_ENABLED'] = True

    # Debug toolbar (disabled):
    # toolbar = DebugToolbarExtension(app)
    # app.config['MONGO_HOST'] = '219.224.134.212'
    # app.config['MONGO_PORT'] = 27017
    # app.config['MONGO_DBNAME'] = 'mrq'

    # init database
    db.init_app(app)
    with app.test_request_context():
        db.create_all()

    # init security
    security.init_app(app, datastore=user_datastore)

    # init admin
    admin.init_app(app)
    admin.add_view(AdminAccessView_user(User, db.session))
    admin.add_view(AdminAccessView_role(Role, db.session))
    # admin.add_view(sqla.ModelView(User, db.session))
    # admin.add_view(sqla.ModelView(Role, db.session))

    return app
class EpisodeRoutesTestCase(unittest.TestCase):
    """Tests for the Episode routes: query-arg parsing and page responses."""

    def setUp(self):
        # Put the real application into testing mode before creating its client.
        bg_tracking.app.testing = True
        self.app = Flask(__name__)
        self.atc = bg_tracking.app.test_client()

    def test_edit_new_with_show_id(self):
        wanted = 2
        with self.app.test_request_context('/episode/add/?show_id={}'.format(wanted)):
            assert int(request.args.get('show_id')) == wanted

    def test_edit_new_with_missing_show_id(self):
        with self.app.test_request_context('/episode/add'):
            self.assertIsNone(request.args.get('show_id'))

    def test_edit_existing(self):
        url = '/episode/{}/edit'.format(4)
        resp = self.atc.get(url, follow_redirects=True)
        self.assertEqual(resp.status_code, 200)

    def test_default_details_view(self):
        url = '/episode/{}'.format(1)
        resp = self.atc.get(url, follow_redirects=True)
        self.assertEqual(resp.status_code, 200)
        # assert b'<title>Error</title>' in response.data

    # NOTE(review): lacks the "test_" prefix, so unittest never collects this
    # method — confirm whether that is intentional before renaming it.
    def tet_submit_new_record(self):
        payload = dict(
            title='unit test {:%Y%d%m}'.format(datetime.now()),
            number=1,
            show=3,
            id=None,
        )
        resp = self.atc.post('/episode/add/', data=payload)
        self.assertEqual(resp.status_code, 200)

    def _out_test_submit_new_record(self):
        # Disabled variant: dispatches the POST through self.app by hand.
        payload = dict(
            title='unit test {:%Y%d%m}'.format(datetime.now()),
            number=1,
            show=3,
            id=None,
        )
        flask_app = self.app
        with flask_app.test_request_context('/episode/add/', method='POST', data=payload):
            early = flask_app.preprocess_request()
            if early is None:
                dispatched = flask_app.dispatch_request()
                resp = flask_app.process_response(flask_app.make_response(dispatched))
            else:
                resp = flask_app.make_response(early)

            self.assertEqual(resp.status_code, 200)
示例#54
0
def create_app():
    """Application factory: build, configure and return the Flask app.

    Loads configuration from the object named by the ``APP_SETTINGS``
    environment variable, initialises the extensions, creates the database
    tables, registers the ``users`` and ``posts`` blueprints and mounts a
    Flask-Admin interface at ``/`` that is restricted to administrators.

    Raises:
        KeyError: if ``APP_SETTINGS`` is not set in the environment.
    """
    app = Flask(__name__)
    app.config.from_object(os.environ['APP_SETTINGS'])

    db.init_app(app)
    with app.test_request_context():
        db.create_all()

    login_manager.init_app(app)
    login_manager.login_view = 'users.log'

    bootstrap.init_app(app)

    ckeditor.init_app(app)

    mail.init_app(app)

    csrf_protect.init_app(app)

    moment.init_app(app)

    import app.users.controllers as users
    app.register_blueprint(users.module)
    import app.posts.controllers as posts
    app.register_blueprint(posts.module)

    # Flask Admin

    from wtforms.fields import HiddenField
    from app.models import User, Role, Post, Tag, StorageImg

    class AdminMixin:
        """Restrict an admin view to administrators; send everyone else to login."""

        def is_accessible(self):
            # NOTE(review): assumes current_user (including the anonymous
            # user) implements is_administrator() — confirm.
            return current_user.is_administrator()

        def inaccessible_callback(self, name, **kwargs):
            return redirect(url_for('users.log', next=request.url))

    class BaseModelView(ModelView):
        """Model view that regenerates the model's slug on every save."""

        def on_model_change(self, form, model, is_created):
            model.generate_slug()
            return super(BaseModelView,
                         self).on_model_change(form, model, is_created)

    class HomeAdminView(AdminMixin, AdminIndexView):
        pass

    # The one-element tuples below need their trailing comma: ('slug') is just
    # the string 'slug', and a membership test against a string is a substring
    # test, so any column whose name is a substring of it would be hidden too.
    class AdminUserView(AdminMixin, ModelView):
        can_create = False
        column_exclude_list = ('password_hash',)
        form_overrides = dict(password_hash=HiddenField)

    class RoleView(AdminMixin, ModelView):
        pass

    class ImgView(AdminMixin, ModelView):
        pass

    class PostView(AdminMixin, BaseModelView):
        column_exclude_list = ('slug',)
        form_overrides = dict(slug=HiddenField, body=CKEditorField)
        create_template = 'admin/edit.html'
        edit_template = 'admin/edit.html'

    class TagView(AdminMixin, BaseModelView):
        column_exclude_list = ('slug',)
        form_overrides = dict(slug=HiddenField)

    admin = Admin(app,
                  'Adminka',
                  url='/',
                  index_view=HomeAdminView(name='Home'),
                  template_mode='bootstrap3')
    admin.add_view(AdminUserView(User, db.session))
    admin.add_view(RoleView(Role, db.session))
    admin.add_view(PostView(Post, db.session))
    admin.add_view(TagView(Tag, db.session))
    admin.add_view(ImgView(StorageImg, db.session))

    return app
示例#55
0
def test_cannot_signup_with_unknown_provider(app: Flask):
    """Signing up with an unregistered provider name raises NotFound."""
    with app.app_context(), app.test_request_context('/auth/signup/with/whatever'):
        with pytest.raises(exceptions.NotFound):
            signup_with('whatever')
示例#56
0
    for i in newData.index:
        if newData['Vendor name'][i] != vendordata1[newData['Vendor Code'][i]]:
            newData.drop([i])

    for i in newData.index:
        if newData['Vendor Code'][i] != vendordata2[newData['Vendor name'][i]]:
            newData.drop([i])
    newData = newData.drop_duplicates(['Invoice Numbers'])
    result['Ts'] = float(newData['Amt in loc.cur.'].sum())
    df = newData.drop_duplicates(subset='Vendor name', keep="first")
    result['Nu'] = int(df['Vendor name'].count())
    result['I'] = int(data.shape[0] - newData.shape[0])
    json_object = json.dumps(result, indent=4)
    with open("finaldata.json", "w") as outfile:
        outfile.write(json_object)
    os.remove(filename)


def wrongfilefunction():
    """Overwrite ``finaldata.json`` in the working directory with an empty JSON object."""
    with open("finaldata.json", "w") as result_file:
        json.dump({}, result_file, indent=4)


# Apply CORS at import time, *before* the development server starts:
# app.run() blocks until shutdown, so a CORS() call placed after the
# __main__ guard never takes effect when this file is run as a script.
CORS(app, expose_headers='Authorization')

if __name__ == "__main__":
    # NOTE(review): this value is written into a throwaway request context
    # and is discarded when the context exits; Flask also refuses session
    # writes when no secret key is configured. If the intent was a signing
    # key, `app.secret_key = os.urandom(24)` is probably what was meant —
    # confirm.
    with app.test_request_context("/"):
        session["key"] = os.urandom(24)
    app.run(debug=True, port=int(os.environ.get('PORT', 5000)))
示例#57
0
def app():
    """Yield a fresh Flask app with a test request context active until teardown."""
    flask_app = Flask(__name__)
    with flask_app.test_request_context():
        yield flask_app
示例#58
0
book_schema = BookSchema()

# Sample, unsaved data: one author with two books.
author = Author(name="Chuck Paluhniuk")
book, book1 = (
    Book(title="Fight Club", author=author),
    Book(title="street fighter", author=author),
)

# Persisting the sample data is left disabled, e.g.:
#   db.session.add(author); db.session.add(book); db.session.commit()

# Open question from the original author: does dumping need a request context?
# NOTE(review): it is only needed if a schema field touches request-bound
# state (e.g. builds URLs with url_for) — confirm against the schemas.
# NOTE(review): the second dump serializes a Book with author_schema;
# book_schema may have been intended.
with app.test_request_context():
    print(author_schema.dump(author))
    print(author_schema.dump(book))

# view
@app.route("/api/book/")
def users():
    """Return every book, serialized with ``book_schema``, as a JSON array.

    ``many=True`` is required because ``Book.all()`` presumably returns a
    collection; without it the schema serializes the whole collection as if
    it were a single Book. Assumes marshmallow 3, where dump() returns plain
    data.
    """
    # jsonify lets a top-level JSON array be returned on Flask < 2.2 as well.
    from flask import jsonify

    all_book = Book.all()
    return jsonify(book_schema.dump(all_book, many=True))


@app.route("/api/book/<id>")
def user_detail(id):
    """Return the book looked up by ``id``, serialized with ``book_schema``."""
    return book_schema.dump(Book.get(id))
示例#59
0
from flask import Flask, url_for
# Application object; the routes below are registered on it.
app = Flask(__name__)


@app.route('/')
def index():
    """Placeholder view for ``/``; exists only so url_for can build its URL."""


@app.route('/login')
def login():
    """Placeholder view for ``/login``; exists only so url_for can build its URL."""


@app.route('/user/<username>')
def profile(username):
    """Placeholder view for ``/user/<username>``; exists only for URL building."""


# Build URLs for the placeholder routes above. url_for needs an active
# application/request context, so a throwaway test request context is pushed.
# print(...) with a single argument behaves the same on Python 2 and 3; the
# original print statements were a SyntaxError on Python 3.
with app.test_request_context():
    print(url_for('index'))
    print(url_for('login'))
    # Arguments that are not part of the rule (here ``next``) become query args.
    print(url_for('login', next='/'))
    print(url_for('profile', username='******'))
示例#60
0
def test_prettyprint():
    app = Flask("test")
    with app.test_request_context("/?prettyprint=1"):
        serializer = JSONSerializer()
        assert '{\n  "key": "1"\n}' == serializer.serialize_object({"key": "1"})