Example #1
def app_init(app):
    app.logger.setLevel(logging.INFO)

    # Register duplicates only at runtime
    app.logger.info("Initializing app")
    dictionary_init(app)

    if app.config.get("USE_USER_HARAKIRI", True):
        setup_user_harakiri(app)

    app_register_blueprints(app)
    app_register_duplicate_blueprints(app)

    db_init(app)
    # exclude es init as it's not used yet
    # es_init(app)
    cors_init(app)
    submission.graphql.make_graph_traversal_dict(app)
    app.graphql_schema = submission.graphql.get_schema()
    app.schema_file = submission.generate_schema_file(app.graphql_schema, app.logger)
    try:
        app.secret_key = app.config["FLASK_SECRET_KEY"]
    except KeyError:
        app.logger.error("Secret key not set in config! Authentication will not work")
    async_pool_init(app)

    # ARBORIST deprecated, replaced by ARBORIST_URL
    arborist_url = os.environ.get("ARBORIST_URL", os.environ.get("ARBORIST"))
    if arborist_url:
        app.auth = ArboristClient(arborist_base_url=arborist_url)
    else:
        app.logger.info("Using default Arborist base URL")
        app.auth = ArboristClient()

    app.logger.info("Initialization complete.")
Example #2
def app_init() -> FastAPI:
    logger.info("Initializing app")
    config.validate(logger)

    debug = config["DEBUG"]
    app = FastAPI(
        title="Audit Service",
        version=version("audit"),
        debug=debug,
        root_path=config["DOCS_URL_PREFIX"],
    )
    app.add_middleware(ClientDisconnectMiddleware)
    app.async_client = httpx.AsyncClient()

    # The following call updates the logger's level, propagation, and handlers
    get_logger("audit-service", log_level="debug" if debug else "info")

    logger.info("Initializing Arborist client")
    if os.environ.get("ARBORIST_URL"):
        app.arborist_client = ArboristClient(
            arborist_base_url=os.environ["ARBORIST_URL"],
            logger=logger,
        )
    else:
        app.arborist_client = ArboristClient(logger=logger)

    db.init_app(app)
    load_modules(app)

    @app.on_event("startup")
    async def startup_event():
        if (config["PULL_FROM_QUEUE"]
                and config["QUEUE_CONFIG"].get("type") == "aws_sqs"):
            loop = asyncio.get_running_loop()
            loop.create_task(pull_from_queue_loop())
            loop.set_exception_handler(handle_exception)

    @app.on_event("shutdown")
    async def shutdown_event():
        logger.info("Closing async client.")
        await app.async_client.aclose()
        logger.info("[Completed] Closing async client.")

    def handle_exception(loop, context):
        """
        Whenever an exception occurs in the asyncio loop, the loop still continues to execute without crashing.
        Therefore, we implement a custom exception handler that will ensure that the loop is stopped upon an Exception.
        """
        msg = context.get("exception", context.get("message"))
        logger.error(f"Caught exception: {msg}")
        for index, task in enumerate(asyncio.all_tasks()):
            task.cancel()
        logger.info("Closed all tasks")

    return app
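
A minimal sketch of how an app factory like the one in this example is typically served; the import path "audit.app" and the port are assumptions, not part of the original code:

# Hypothetical entry point; the module path is an assumption.
import uvicorn

from audit.app import app_init

app = app_init()

if __name__ == "__main__":
    # Serving the app triggers the startup/shutdown events defined in the factory.
    uvicorn.run(app, host="0.0.0.0", port=8000)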
Example #3
def app_init(app):
    # Register duplicates only at runtime
    app.logger.info("Initializing app")

    # explicit options set for compatibility with gdc's api
    app.config["AUTH_SUBMISSION_LIST"] = True
    app.config["USE_DBGAP"] = False
    app.config["IS_GDC"] = False

    # default settings
    app.config["AUTO_MIGRATE_DATABASE"] = app.config.get(
        "AUTO_MIGRATE_DATABASE", True)
    app.config["REQUIRE_FILE_INDEX_EXISTS"] = (
        # If True, enforce indexd record exists before file node registration
        app.config.get("REQUIRE_FILE_INDEX_EXISTS", False))

    if app.config.get("USE_USER_HARAKIRI", True):
        setup_user_harakiri(app)

    app.config["AUTH_NAMESPACE"] = "/" + os.getenv("AUTH_NAMESPACE",
                                                   "").strip("/")

    app_register_blueprints(app)
    db_init(app)
    # exclude es init as it's not used yet
    # es_init(app)
    try:
        app.secret_key = app.config["FLASK_SECRET_KEY"]
    except KeyError:
        app.logger.error(
            "Secret key not set in config! Authentication will not work")

    # ARBORIST deprecated, replaced by ARBORIST_URL
    arborist_url = os.environ.get("ARBORIST_URL", os.environ.get("ARBORIST"))
    if arborist_url:
        app.auth = ArboristClient(arborist_base_url=arborist_url)
    else:
        app.logger.info("Using default Arborist base URL")
        app.auth = ArboristClient()

    app.node_authz_entity_name = os.environ.get("AUTHZ_ENTITY_NAME", None)
    app.node_authz_entity = None
    app.subject_entity = None
    if app.node_authz_entity_name:
        full_module_name = "datamodelutils.models"
        mymodule = importlib.import_module(full_module_name)
        for i in dir(mymodule):
            app.logger.warn(i)
            if i.lower() == "person":
                attribute = getattr(mymodule, i)
                app.subject_entity = attribute
            if i.lower() == app.node_authz_entity_name.lower():
                attribute = getattr(mymodule, i)
                app.node_authz_entity = attribute
Example #4
def app_init(app):
    app.logger.setLevel(logging.INFO)

    # Register duplicates only at runtime
    app.logger.info("Initializing app")
    dictionary_init(app)

    if app.config.get("USE_USER_HARAKIRI", True):
        setup_user_harakiri(app)

    app_register_blueprints(app)
    app_register_duplicate_blueprints(app)

    db_init(app)
    # exclude es init as it's not used yet
    # es_init(app)
    cors_init(app)
    submission.graphql.make_graph_traversal_dict(app)
    app.graphql_schema = submission.graphql.get_schema()
    app.schema_file = submission.generate_schema_file(app.graphql_schema,
                                                      app.logger)
    try:
        app.secret_key = app.config["FLASK_SECRET_KEY"]
    except KeyError:
        app.logger.error(
            "Secret key not set in config! Authentication will not work")
    async_pool_init(app)

    # ARBORIST deprecated, replaced by ARBORIST_URL
    arborist_url = os.environ.get("ARBORIST_URL", os.environ.get("ARBORIST"))
    if arborist_url:
        app.auth = ArboristClient(arborist_base_url=arborist_url)
    else:
        app.logger.info("Using default Arborist base URL")
        app.auth = ArboristClient()

    app.node_authz_entity_name = os.environ.get("AUTHZ_ENTITY_NAME", None)
    app.node_authz_entity = None
    app.subject_entity = None
    if app.node_authz_entity_name:
        full_module_name = "datamodelutils.models"
        mymodule = importlib.import_module(full_module_name)
        for i in dir(mymodule):
            # TODO look at all the nodes between node_authz_entity_name and project to generalize
            if i.lower() == "person":
                attribute = getattr(mymodule, i)
                app.subject_entity = attribute
            if i.lower() == app.node_authz_entity_name.lower():
                attribute = getattr(mymodule, i)
                app.node_authz_entity = attribute

    app.logger.info("Initialization complete.")
Example #5
def app(tmpdir, request):

    port = 8000
    dictionary_setup(_app)
    # this is to make sure sqlite is initialized
    # for every unit test
    reload(default_settings)

    # fresh files before running
    for filename in ['auth.sq3', 'index.sq3', 'alias.sq3']:
        if os.path.exists(filename):
            os.remove(filename)
    indexd_app = get_indexd_app()

    indexd_init(*INDEX_CLIENT['auth'])
    indexd = Process(target=indexd_app.run, args=['localhost', port])
    indexd.start()
    wait_for_indexd_alive(port)

    gencode_json = tmpdir.mkdir("slicing").join("test_gencode.json")
    gencode_json.write(
        json.dumps({
            'a_gene': ['chr1', None, 200],
            'b_gene': ['chr1', 150, 300],
            'c_gene': ['chr1', 200, None],
            'd_gene': ['chr1', None, None],
        }))

    def teardown():
        for filename in ['auth.sq3', 'index.sq3', 'alias.sq3']:
            if os.path.exists(filename):
                os.remove(filename)

        indexd.terminate()
        wait_for_indexd_not_alive(port)

    _app.config.from_object("sheepdog.test_settings")
    _app.config["PATH_TO_SCHEMA_DIR"] = PATH_TO_SCHEMA_DIR

    request.addfinalizer(teardown)

    app_init(_app)

    _app.logger.setLevel(os.environ.get("GDC_LOG_LEVEL", "WARNING"))

    _app.jwt_public_keys = {
        _app.config['USER_API']: {
            'key-test':
            utils.read_file('./integration/resources/keys/test_public_key.pem')
        }
    }

    _app.auth = ArboristClient()

    return _app
Example #6
def __init__(self, conn, arborist=None, **config):
    """
    Initialize the SQLAlchemy database driver.
    """
    super().__init__(conn, **config)
    Base.metadata.bind = self.engine
    Base.metadata.create_all()
    self.Session = sessionmaker(bind=self.engine)
    if arborist is not None:
        arborist = ArboristClient(arborist_base_url=arborist)
    self.arborist = arborist
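
A hedged usage sketch for a driver built with this constructor; the class name SQLAlchemyDriver, the connection string, and the Arborist URL are assumptions:

# Hypothetical instantiation of a driver exposing the __init__ above.
driver = SQLAlchemyDriver(
    "postgresql://user:password@localhost:5432/indexd",
    arborist="http://arborist-service",
)
session = driver.Session()  # sessionmaker bound to the engine in __init__
try:
    pass  # run queries through the session here
finally:
    session.close()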
Example #7
async def list_policies(arborist_client: ArboristClient,
                        expand: bool = False) -> list:
    """
    We could cache this data later if needed, but it's tricky: how long we
    can cache it depends on the source of the information, and the cache MUST
    be invalidated whenever Arborist changes a policy.
    For now, just call Arborist every time we need this information.
    """
    res = arborist_client.list_policies(expand=expand)
    if inspect.isawaitable(res):
        res = await res
    return res
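
The inspect.isawaitable check lets this helper accept either the synchronous or the asynchronous gen3authz client. A usage sketch under that assumption, reusing the AsyncClient name from Example #13 and a made-up base URL:

import asyncio

sync_client = ArboristClient(arborist_base_url="http://arborist-service")
async_client = AsyncClient(arborist_base_url="http://arborist-service")

async def demo():
    # Both calls are awaited the same way; list_policies hides the difference
    # between a plain result and an awaitable one.
    policies = await list_policies(sync_client)
    expanded = await list_policies(async_client, expand=True)
    return policies, expanded

asyncio.run(demo())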
Example #8
def app_init(app):
    # Register duplicates only at runtime
    app.logger.info("Initializing app")

    # explicit options set for compatibility with gdc's api
    app.config["AUTH_SUBMISSION_LIST"] = True
    app.config["USE_DBGAP"] = False
    app.config["IS_GDC"] = False

    # default settings
    app.config["AUTO_MIGRATE_DATABASE"] = app.config.get(
        "AUTO_MIGRATE_DATABASE", True)
    app.config["REQUIRE_FILE_INDEX_EXISTS"] = (
        # If True, enforce indexd record exists before file node registration
        app.config.get("REQUIRE_FILE_INDEX_EXISTS", False))

    if app.config.get("USE_USER_HARAKIRI", True):
        setup_user_harakiri(app)

    app.config["AUTH_NAMESPACE"] = "/" + os.getenv("AUTH_NAMESPACE",
                                                   "").strip("/")

    app_register_blueprints(app)
    db_init(app)
    # exclude es init as it's not used yet
    # es_init(app)
    try:
        app.secret_key = app.config["FLASK_SECRET_KEY"]
    except KeyError:
        app.logger.error(
            "Secret key not set in config! Authentication will not work")

    # ARBORIST deprecated, replaced by ARBORIST_URL
    arborist_url = os.environ.get("ARBORIST_URL", os.environ.get("ARBORIST"))
    if arborist_url:
        app.auth = ArboristClient(arborist_base_url=arborist_url)
    else:
        app.logger.info("Using default Arborist base URL")
        app.auth = ArboristClient()
Example #9
def app_init() -> FastAPI:
    logger.info("Initializing app")
    config.validate()

    debug = config["DEBUG"]
    app = FastAPI(
        title="Requestor",
        version=version("requestor"),
        debug=debug,
        root_path=config["DOCS_URL_PREFIX"],
    )
    app.add_middleware(ClientDisconnectMiddleware)
    app.async_client = httpx.AsyncClient()

    # The following call updates the logger's level, propagation, and handlers
    get_logger("requestor", log_level="debug" if debug else "info")

    logger.info("Initializing Arborist client")
    custom_arborist_url = os.environ.get("ARBORIST_URL", config["ARBORIST_URL"])
    if custom_arborist_url:
        app.arborist_client = ArboristClient(
            arborist_base_url=custom_arborist_url,
            authz_provider="requestor",
            logger=logger,
        )
    else:
        app.arborist_client = ArboristClient(authz_provider="requestor", logger=logger)

    db.init_app(app)
    load_modules(app)

    @app.on_event("shutdown")
    async def shutdown_event():
        logger.info("Closing async client.")
        await app.async_client.aclose()

    return app
Example #10
def mock_arborist_requests(app, request):
    """
    This fixture returns a function which you call to mock the call to
    arborist client's auth_request method.
    It returns a 401 error unless the resource and method match the provided
    `resource_method_to_authorized` dict, in which case it returns a 401 error
    or a 200 response depending on the dict value.
    """
    arborist_base_url = "arborist"
    app.auth.arborist = ArboristClient(arborist_base_url=arborist_base_url)

    def do_patch(resource_method_to_authorized={}):
        # Resource/Method to authorized: { RESOURCE: { METHOD: True/False } }

        def make_mock_response(method, url, *args, **kwargs):
            method = method.upper()
            mocked_response = mock.MagicMock(requests.Response)

            if url != f"{arborist_base_url}/auth/request":
                mocked_response.status_code = 404
                mocked_response.text = "NOT FOUND"
            elif method != "POST":
                mocked_response.status_code = 405
                mocked_response.text = "METHOD NOT ALLOWED"
            else:
                authz_res, authz_met = None, None
                authz_requests = kwargs["json"]["requests"]
                if authz_requests:
                    authz_res = kwargs["json"]["requests"][0]["resource"]
                    authz_met = kwargs["json"]["requests"][0]["action"][
                        "method"]
                authorized = resource_method_to_authorized.get(
                    authz_res, {}).get(authz_met, False)
                mocked_response.status_code = 200
                mocked_response.json.return_value = {"auth": authorized}
            return mocked_response

        mocked_method = mock.MagicMock(side_effect=make_mock_response)
        patch_method = patch(
            "gen3authz.client.arborist.client.httpx.Client.request",
            mocked_method)

        patch_method.start()
        request.addfinalizer(patch_method.stop)

    return do_patch
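
A hedged sketch of how a test might use this fixture; the client fixture, endpoint, and resource path are assumptions:

def test_submission_requires_authz(client, mock_arborist_requests):
    # Only POST on /programs/example is authorized; every other
    # resource/method combination gets {"auth": False} from the mock.
    mock_arborist_requests({"/programs/example": {"POST": True}})
    response = client.post("/v0/submission/", json={"name": "example"})
    assert response.status_code != 403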
Example #11
def _setup_arborist_client(app):
    if app.config.get("ARBORIST"):
        app.arborist = ArboristClient(arborist_base_url=config["ARBORIST"])
Example #12
def main():
    args = parse_arguments()

    # get database information
    sys.path.append(args.path)

    # replicate cfg loading done in flask app to maintain backwards compatibility
    # TODO (DEPRECATE LOCAL_SETTINGS): REMOVE this when putting cfg in
    # settings/local_settings is deprecated
    import flask

    settings_cfg = flask.Config(".")
    settings_cfg.from_object("fence.settings")
    config.update(dict(settings_cfg))

    # END - TODO (DEPRECATE LOCAL_SETTINGS): REMOVE

    config.load(search_folders=CONFIG_SEARCH_FOLDERS)

    DB = os.environ.get("FENCE_DB") or config.get("DB")

    # attempt to get from settings, this is backwards-compatibility for integration
    # tests
    if DB is None:
        try:
            from fence.settings import DB
        except ImportError:
            pass

    BASE_URL = os.environ.get("BASE_URL") or config.get("BASE_URL")
    ROOT_DIR = os.environ.get("ROOT_DIR") or os.path.dirname(
        os.path.dirname(os.path.realpath(__file__)))
    dbGaP = os.environ.get("dbGaP") or config.get("dbGaP")
    if not isinstance(dbGaP, list):
        dbGaP = [dbGaP]
    STORAGE_CREDENTIALS = os.environ.get("STORAGE_CREDENTIALS") or config.get(
        "STORAGE_CREDENTIALS")
    usersync = config.get("USERSYNC", {})
    sync_from_visas = usersync.get("sync_from_visas", False)
    fallback_to_dbgap_sftp = usersync.get("fallback_to_dbgap_sftp", False)

    arborist = None
    if args.arborist:
        arborist = ArboristClient(
            arborist_base_url=args.arborist,
            logger=get_logger("user_syncer.arborist_client"),
            authz_provider="user-sync",
        )

    if args.action == "create":
        yaml_input = args.__dict__["yaml-file-path"]
        create_sample_data(DB, yaml_input)
    elif args.action == "client-create":
        confidential = not args.public
        create_client_action(
            DB,
            username=args.username,
            client=args.client,
            urls=args.urls,
            auto_approve=args.auto_approve,
            grant_types=args.grant_types,
            confidential=confidential,
            arborist=arborist,
            policies=args.policies,
            allowed_scopes=args.allowed_scopes,
        )
    elif args.action == "client-modify":
        modify_client_action(
            DB,
            client=args.client,
            delete_urls=args.delete_urls,
            urls=args.urls,
            name=args.name,
            description=args.description,
            set_auto_approve=args.set_auto_approve,
            unset_auto_approve=args.unset_auto_approve,
            arborist=arborist,
            policies=args.policies,
            allowed_scopes=args.allowed_scopes,
            append=args.append,
        )
    elif args.action == "client-delete":
        delete_client_action(DB, args.client)
    elif args.action == "client-list":
        list_client_action(DB)
    elif args.action == "user-delete":
        delete_users(DB, args.users)
    elif args.action == "expired-service-account-delete":
        delete_expired_service_accounts(DB)
    elif args.action == "bucket-access-group-verify":
        verify_bucket_access_group(DB)
    elif args.action == "sync":
        sync_users(
            dbGaP,
            STORAGE_CREDENTIALS,
            DB,
            projects=args.project_mapping,
            is_sync_from_dbgap_server=str2bool(args.sync_from_dbgap),
            sync_from_local_csv_dir=args.csv_dir,
            sync_from_local_yaml_file=args.yaml,
            folder=args.folder,
            arborist=arborist,
            sync_from_visas=sync_from_visas,
            fallback_to_dbgap_sftp=fallback_to_dbgap_sftp,
        )
    elif args.action == "dbgap-download-access-files":
        download_dbgap_files(
            dbGaP,
            STORAGE_CREDENTIALS,
            DB,
            folder=args.folder,
        )
    elif args.action == "google-manage-keys":
        remove_expired_google_service_account_keys(DB)
    elif args.action == "google-init":
        google_init(DB)
    elif args.action == "google-manage-user-registrations":
        verify_user_registration(DB)
    elif args.action == "google-manage-account-access":
        remove_expired_google_accounts_from_proxy_groups(DB)
    elif args.action == "google-bucket-create":
        # true if true provided, false if anything else provided, leave as
        # None if not provided at all (policy will remain unchanged)
        if args.public and args.public.lower().strip() == "true":
            args.public = True
        elif args.public is not None:
            args.public = False

        create_or_update_google_bucket(
            DB,
            args.unique_name,
            storage_class=args.storage_class,
            public=args.public,
            requester_pays=args.requester_pays,
            google_project_id=args.google_project_id,
            project_auth_id=args.project_auth_id,
            access_logs_bucket=args.access_logs_bucket,
            allowed_privileges=args.allowed_privileges,
        )
    elif args.action == "google-logging-bucket-create":
        create_google_logging_bucket(
            args.unique_name,
            storage_class=args.storage_class,
            google_project_id=args.google_project_id,
        )
    elif args.action == "link-external-bucket":
        link_external_bucket(DB, name=args.bucket_name)
    elif args.action == "link-bucket-to-project":
        link_bucket_to_project(
            DB,
            bucket_id=args.bucket_id,
            bucket_provider=args.bucket_provider,
            project_auth_id=args.project_auth_id,
        )
    elif args.action == "google-list-authz-groups":
        google_list_authz_groups(DB)
    elif args.action == "token-create":
        keys_path = getattr(args, "keys-dir", os.path.join(ROOT_DIR, "keys"))
        keypairs = keys.load_keypairs(keys_path)
        # Default to the most recent one, but try to find the keypair with
        # matching ``kid`` to the argument provided.
        keypair = keypairs[-1]
        kid = getattr(args, "kid")
        if kid:
            for try_keypair in keypairs:
                if try_keypair.kid == kid:
                    keypair = try_keypair
                    break
        jwt_creator = JWTCreator(
            DB,
            BASE_URL,
            kid=keypair.kid,
            private_key=keypair.private_key,
            username=args.username,
            scopes=args.scopes,
            expires_in=args.exp,
        )
        token_type = str(args.type).strip().lower()
        if token_type == "access_token" or token_type == "access":
            print(jwt_creator.create_access_token().token)
        elif token_type == "refresh_token" or token_type == "refresh":
            print(jwt_creator.create_refresh_token().token)
        else:
            print('invalid token type "{}"; expected "access" or "refresh"'.
                  format(token_type))
            sys.exit(1)
    elif args.action == "force-link-google":
        exp = force_update_google_link(
            DB,
            username=args.username,
            google_email=args.google_email,
            expires_in=args.expires_in,
        )
        print(exp)
    elif args.action == "notify-problem-users":
        notify_problem_users(DB, args.emails, args.auth_ids,
                             args.check_linking, args.google_project_id)
    elif args.action == "migrate":
        migrate_database(DB)
    elif args.action == "update-visas":
        update_user_visas(
            DB,
            chunk_size=args.chunk_size,
            concurrency=args.concurrency,
            thread_pool_size=args.thread_pool_size,
            buffer_size=args.buffer_size,
        )
Example #13
def arborist_client(arborist_base_url, use_async):
    if use_async:
        return AsyncClient(arborist_base_url=arborist_base_url)
    else:
        return ArboristClient(arborist_base_url=arborist_base_url)
Example #14
async def create_arborist_policy(
    arborist_client: ArboristClient,
    resource_path: str,
    resource_description: str = None,
):
    # create the resource
    logger.debug(f"Attempting to create resource {resource_path} in Arborist")
    resources = resource_path.split("/")
    resource_name = resources[-1]
    parent_path = "/".join(resources[:-1])
    resource = {
        "name": resource_name,
        "description": resource_description,
    }
    res = arborist_client.create_resource(parent_path,
                                          resource,
                                          create_parents=True)
    if inspect.isawaitable(res):
        await res

    # Create "reader" and "storage_reader" roles in Arborist.
    # If they already exist arborist would "Do Nothing"
    roles = [
        {
            "id":
            "reader",
            "permissions": [{
                "id": "reader",
                "action": {
                    "service": "*",
                    "method": "read"
                }
            }],
        },
        {
            "id":
            "storage_reader",
            "permissions": [{
                "id": "storage_reader",
                "action": {
                    "service": "*",
                    "method": "read-storage"
                },
            }],
        },
    ]

    for role in roles:
        try:
            res = arborist_client.update_role(role["id"], role)
            if inspect.isawaitable(res):
                await res
        except ArboristError as e:
            logger.info(
                "An error occurred while updating role '{}': '{}'".format(
                    role["id"], str(e)))
            logger.debug(
                f"Attempting to create role '{role['id']}' in Arborist")
            res = arborist_client.create_role(role)
            if inspect.isawaitable(res):
                await res

    # create the policy
    policy_id = get_auto_policy_id_for_resource_path(resource_path)
    logger.debug(f"Attempting to create policy {policy_id} in Arborist")
    policy = {
        "id": policy_id,
        "description": "policy created by requestor",
        "role_ids": ["reader", "storage_reader"],
        "resource_paths": [resource_path],
    }
    res = arborist_client.create_policy(policy, skip_if_exists=True)
    if inspect.isawaitable(res):
        await res

    return policy_id
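
A call-site sketch for the helper above; the Arborist URL and resource path are assumptions, and the call is wrapped in a coroutine since the helper is async:

async def grant_read_access():
    # Hypothetical client and resource path.
    client = ArboristClient(arborist_base_url="http://arborist-service")
    return await create_arborist_policy(
        arborist_client=client,
        resource_path="/programs/example/projects/demo",
        resource_description="demo project resource",
    )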
Example #15
    def post_login(self, user=None, token_result=None):
        # TODO: I'm not convinced this code should be in post_login.
        # Just putting it in here for now, but might refactor later.
        # This saves us a call to RAS /userinfo, but will not make sense
        # when there is more than one visa issuer.

        # Clear all of user's visas, to avoid having duplicate visas
        # where only iss/exp/jti differ
        # TODO: This is not IdP-specific and will need a rethink when
        # we have multiple IdPs
        user.ga4gh_visas_v1 = []

        current_session.commit()

        encoded_visas = []

        try:
            encoded_visas = flask.current_app.ras_client.get_encoded_visas_v11_userinfo(
                flask.g.userinfo
            )
        except Exception as e:
            err_msg = "Could not retrieve visas"
            logger.error("{}: {}".format(e, err_msg))
            raise

        for encoded_visa in encoded_visas:
            try:
                # Do not move out of loop unless we can assume every visa has same issuer and kid
                public_key = get_public_key_for_token(
                    encoded_visa, attempt_refresh=True
                )
            except Exception as e:
                # (But don't log the visa contents!)
                logger.error(
                    "Could not get public key to validate visa: {}. Discarding visa.".format(
                        e
                    )
                )
                continue

            try:
                # Validate the visa per GA4GH AAI "Embedded access token" format rules.
                # pyjwt also validates signature and expiration.
                decoded_visa = validate_jwt(
                    encoded_visa,
                    public_key,
                    # Embedded token must not contain aud claim
                    aud=None,
                    # Embedded token must contain scope claim, which must include openid
                    scope={"openid"},
                    issuers=config.get("GA4GH_VISA_ISSUER_ALLOWLIST", []),
                    # Embedded token must contain iss, sub, iat, exp claims
                    # options={"require": ["iss", "sub", "iat", "exp"]},
                    # ^ FIXME 2021-05-13: Above needs pyjwt>=v2.0.0, which requires cryptography>=3.
                    # Once we can unpin and upgrade cryptography and pyjwt, switch to above "options" arg.
                    # For now, pyjwt 1.7.1 is able to require iat and exp;
                    # authutils' validate_jwt (i.e. the function being called) checks issuers already (see above);
                    # and we will check separately for sub below.
                    options={
                        "require_iat": True,
                        "require_exp": True,
                    },
                )

                # Also require 'sub' claim (see note above about pyjwt and the options arg).
                if "sub" not in decoded_visa:
                    raise JWTError("Visa is missing the 'sub' claim.")
            except Exception as e:
                logger.error("Visa failed validation: {}. Discarding visa.".format(e))
                continue

            visa = GA4GHVisaV1(
                user=user,
                source=decoded_visa["ga4gh_visa_v1"]["source"],
                type=decoded_visa["ga4gh_visa_v1"]["type"],
                asserted=int(decoded_visa["ga4gh_visa_v1"]["asserted"]),
                expires=int(decoded_visa["exp"]),
                ga4gh_visa=encoded_visa,
            )
            current_session.add(visa)
            current_session.commit()

        # Store refresh token in db
        assert "refresh_token" in flask.g.tokens, "No refresh_token in user tokens"
        refresh_token = flask.g.tokens["refresh_token"]
        assert "id_token" in flask.g.tokens, "No id_token in user tokens"
        id_token = flask.g.tokens["id_token"]
        decoded_id = jwt.decode(id_token, verify=False)

        # Refresh token expiration is iat + RAS_REFRESH_EXPIRATION (15 days by default)
        issued_time = int(decoded_id.get("iat"))
        expires = config["RAS_REFRESH_EXPIRATION"]

        # User-defined RAS refresh token expiration time
        parsed_url = urlparse(flask.session.get("redirect"))
        query_params = parse_qs(parsed_url.query)
        if query_params.get("upstream_expires_in"):
            custom_refresh_expiration = query_params.get("upstream_expires_in")[0]
            expires = get_valid_expiration(
                custom_refresh_expiration,
                expires,
                expires,
            )

        flask.current_app.ras_client.store_refresh_token(
            user=user, refresh_token=refresh_token, expires=expires + issued_time
        )

        global_parse_visas_on_login = config["GLOBAL_PARSE_VISAS_ON_LOGIN"]
        usersync = config.get("USERSYNC", {})
        sync_from_visas = usersync.get("sync_from_visas", False)
        parse_visas = global_parse_visas_on_login or (
            global_parse_visas_on_login == None
            and (
                strtobool(query_params.get("parse_visas")[0])
                if query_params.get("parse_visas")
                else False
            )
        )
        # If fence is configured to use visas as the authZ source, visa parsing
        # was requested (or GLOBAL_PARSE_VISAS_ON_LOGIN is true), and the user
        # has no project_access from a previous session or from usersync, run an
        # on-the-fly usersync for this user to give them instant access after
        # logging in through RAS.
        if sync_from_visas and parse_visas and not user.project_access:
            # Close the previous db session. Leaving it open causes a race
            # condition: usersync updates user.project_access while we are
            # reading it, which leads to partially updated records.
            current_session.close()

            DB = os.environ.get("FENCE_DB") or config.get("DB")
            if DB is None:
                try:
                    from fence.settings import DB
                except ImportError:
                    pass

            arborist = ArboristClient(
                arborist_base_url=config["ARBORIST"],
                logger=get_logger("user_syncer.arborist_client"),
                authz_provider="user-sync",
            )
            dbGaP = os.environ.get("dbGaP") or config.get("dbGaP")
            if not isinstance(dbGaP, list):
                dbGaP = [dbGaP]

            sync = init_syncer(
                dbGaP,
                None,
                DB,
                arborist=arborist,
            )
            sync.sync_single_user_visas(user, current_session)

        super(RASCallback, self).post_login()
Example #16

# revision identifiers, used by Alembic.
revision = "42cbae986650"
down_revision = "c0a92da5ac69"
branch_labels = None
depends_on = None


logger = get_logger("requestor-migrate", log_level="debug")

custom_arborist_url = os.environ.get("ARBORIST_URL", config["ARBORIST_URL"])
if custom_arborist_url:
    arborist_client = ArboristClient(
        arborist_base_url=custom_arborist_url,
        authz_provider="requestor",
        logger=logger,
    )
else:
    arborist_client = ArboristClient(authz_provider="requestor", logger=logger)


def escape(str):
    # escape single quotes for SQL statement
    return str.replace("'", "''")


def upgrade():
    # get the list of existing policies from Arborist
    if not config["LOCAL_MIGRATION"]:
        existing_policies = list_policies(arborist_client)