Exemple #1
0
def __add_config_to_connexion_app(
    app: App,
    config: Mapping
) -> App:
    """Apply server settings from `config` to a Connexion app.

    Overrides the Connexion app's host, port and debug flag, copies the
    debug/environment/testing values into the underlying Flask app config,
    logs the resulting Flask settings and finally merges the whole user
    configuration into the Flask app config.

    Args:
        app: Connexion app; its `.app` attribute holds the Flask app.
        config: User configuration, read via `get_conf()` from the `server`
            section.

    Returns:
        The `app` object that was passed in, after configuration.
    """
    # Connexion-level settings, all taken from the 'server' section
    for attribute in ('host', 'port', 'debug'):
        setattr(app, attribute, get_conf(config, 'server', attribute))

    # Flask-level settings; 'DEBUG' mirrors the Connexion debug flag
    app.app.config['DEBUG'] = app.debug
    flask_overrides = (
        ('ENV', 'environment'),
        ('TESTING', 'testing'),
    )
    for flask_key, conf_key in flask_overrides:
        app.app.config[flask_key] = get_conf(config, 'server', conf_key)

    # Dump the Flask settings before the user configuration is merged in
    logger.debug('Flask app settings:')
    for setting in app.app.config.items():
        logger.debug('* {}: {}'.format(*setting))

    # Make the complete user configuration available via the Flask config
    app.app.config.update(config)

    logger.info('Connexion app configured.')
    return app
Exemple #2
0
def cancel_run(config: Dict, celery_app: Celery, run_id: str, *args,
               **kwargs) -> Dict:
    """Request cancelation of a workflow run.

    Looks the run up in the runs collection and, if its state is listed in
    `States.CANCELABLE`, dispatches `task__cancel_run` as a background task.

    Args:
        config: App configuration; provides the runs collection and the
            cancelation timeout.
        celery_app: Not used by this function.
        run_id: Identifier of the run to cancel.
        **kwargs: If `user_id` is present, it must match the owner stored
            with the run.

    Returns:
        Dictionary with the single key `run_id`.

    Raises:
        WorkflowNotFound: No run with the given identifier was found.
        Forbidden: `user_id` was supplied and differs from the run's owner.
    """
    runs = get_conf(config, 'database', 'collections', 'runs')
    wanted_fields = {
        'user_id': True,
        'task_id': True,
        'api.state': True,
        '_id': False,
    }
    document = runs.find_one(
        filter={'run_id': run_id},
        projection=wanted_fields,
    )

    # Unknown run
    if not document:
        logger.error("Run '{run_id}' not found.".format(run_id=run_id))
        raise WorkflowNotFound

    # Ownership check; only performed when a user ID was passed in, i.e.
    # when authorization is enabled
    if 'user_id' in kwargs:
        requester = kwargs['user_id']
        if document['user_id'] != requester:
            logger.error(("User '{user_id}' is not allowed to access workflow run "
                          "'{run_id}'.").format(
                              user_id=requester,
                              run_id=run_id,
                          ))
            raise Forbidden

    response = {'run_id': run_id}

    # Runs that are not in a cancelable state are left untouched
    if document['api']['state'] not in States.CANCELABLE:
        return response

    # The timeout is enforced as the soft time limit of the background task
    timeout_duration = get_conf(
        config,
        'api',
        'endpoint_params',
        'timeout_cancel_run',
    )

    # Hand the actual cancelation over to a background task
    cancel_task_id = uuid()
    logger.info(("Canceling run '{run_id}' as background task "
                 "'{task_id}'...").format(
                     run_id=run_id,
                     task_id=cancel_task_id,
                 ))
    task_kwargs = {
        'run_id': run_id,
        'task_id': document['task_id'],
    }
    task__cancel_run.apply_async(
        None,
        task_kwargs,
        task_id=cancel_task_id,
        soft_time_limit=timeout_duration,
    )

    return response
Exemple #3
0
def run_server():
    """Assemble the fully configured Connexion app.

    Configures logging, parses the app configuration and then registers
    MongoDB, error handlers, the background task service, the OpenAPI specs
    and CORS with the app, in that order.

    Returns:
        Tuple of the configured Connexion app and the parsed configuration.
    """
    # Logging and app configuration are read from environment-named sources
    configure_logging(config_var='WES_CONFIG_LOG')
    config = parse_app_config(config_var='WES_CONFIG')

    # Connexion app plus database and error handling
    connexion_app = create_connexion_app(config)
    connexion_app.app = register_mongodb(connexion_app.app)
    connexion_app = register_error_handlers(connexion_app)

    # Celery app and background task monitoring service
    register_task_service(connexion_app.app)

    # OpenAPI specs; security definitions only if authorization is required
    api_specs = get_conf_type(config, 'api', 'specs', types=(list))
    auth_required = get_conf(config, 'security', 'authorization_required')
    connexion_app = register_openapi(
        app=connexion_app,
        specs=api_specs,
        add_security_definitions=auth_required,
    )

    # Cross-origin resource sharing
    enable_cors(connexion_app.app)

    return connexion_app, config
Exemple #4
0
def get_run_log(config: Dict, run_id: str, *args, **kwargs) -> Dict:
    """Fetch the detailed log information stored for a workflow run.

    Args:
        config: App configuration; provides the runs collection.
        run_id: Identifier of the run to look up.
        **kwargs: If `user_id` is present, it must match the owner stored
            with the run.

    Returns:
        The content of the run document's `api` field.

    Raises:
        WorkflowNotFound: No run with the given identifier was found.
        Forbidden: `user_id` was supplied and differs from the run's owner.
    """
    runs = get_conf(config, 'database', 'collections', 'runs')
    document = runs.find_one(
        filter={'run_id': run_id},
        projection={'user_id': True, 'api': True, '_id': False},
    )

    # Unknown run
    if not document:
        logger.error("Run '{run_id}' not found.".format(run_id=run_id))
        raise WorkflowNotFound

    run_log = document['api']

    # Ownership check; only performed when a user ID was passed in, i.e.
    # when authorization is enabled
    if 'user_id' in kwargs:
        requester = kwargs['user_id']
        if document['user_id'] != requester:
            logger.error(("User '{user_id}' is not allowed to access workflow run "
                          "'{run_id}'.").format(
                              user_id=requester,
                              run_id=run_id,
                          ))
            raise Forbidden

    return run_log
Exemple #5
0
def task__cancel_run(
    self,
    run_id: str,
    task_id: str,
) -> None:
    """Revokes workflow task and tries to cancel all running TES tasks.

    Sets the run state to 'CANCELING' and then calls `__cancel_tes_tasks()`
    with the TES URL and timeout from the app configuration. If the soft
    time limit is exceeded, the run state is set to 'SYSTEM_ERROR' instead.

    Args:
        self: Task instance; not used by the body.
        run_id: Identifier of the workflow run to cancel.
        task_id: Identifier of the task associated with the run; passed to
            `set_run_state()`.
    """
    # Bound before the `try` block so that the timeout handler can tell
    # whether the collection handle was obtained before the time limit hit;
    # otherwise the handler itself would fail with an `UnboundLocalError`
    collection = None
    try:
        config = current_app.config
        # Create MongoDB client
        mongo = create_mongo_client(
            app=current_app,
            config=config,
        )
        collection = mongo.db['runs']
        # Set run state to 'CANCELING'
        set_run_state(
            collection=collection,
            run_id=run_id,
            task_id=task_id,
            state='CANCELING',
        )
        # Cancel individual TES tasks
        __cancel_tes_tasks(
            collection=collection,
            run_id=run_id,
            url=get_conf(config, 'tes', 'url'),
            timeout=get_conf(config, 'tes', 'timeout'),
        )

    except SoftTimeLimitExceeded as e:
        # Without a collection handle the state cannot be recorded; report
        # that instead of claiming the state was updated
        if collection is None:
            logger.warning(
                (
                    "Canceling workflow run '{run_id}' timed out before a "
                    "database connection was available. Run state could "
                    "not be updated. Original error message: {type}: {msg}"
                ).format(
                    run_id=run_id,
                    type=type(e).__name__,
                    msg=e,
                )
            )
            return
        set_run_state(
            collection=collection,
            run_id=run_id,
            task_id=task_id,
            state='SYSTEM_ERROR',
        )
        logger.warning(
            (
                "Canceling workflow run '{run_id}' timed out. Run state "
                "was set to 'SYSTEM_ERROR'. Original error message: "
                "{type}: {msg}"
            ).format(
                run_id=run_id,
                type=type(e).__name__,
                msg=e,
            )
        )
def create_mongo_client(
    app: Flask,
    config: Dict,
):
    """Build a PyMongo client for `app` from the `database` config section.

    Args:
        app: Flask app the client is created for.
        config: App configuration; `database.host`, `database.port` and
            `database.name` are read via `get_conf()`.

    Returns:
        The `PyMongo` instance created for the assembled URI.
    """
    db_host = get_conf(config, 'database', 'host')
    db_port = get_conf(config, 'database', 'port')
    db_name = get_conf(config, 'database', 'name')

    uri = 'mongodb://{host}:{port}/{name}'.format(
        host=db_host,
        port=db_port,
        name=db_name,
    )
    mongo = PyMongo(app, uri=uri)

    message = ("Registered database '{name}' at URI '{uri}' with Flask "
               'application.')
    logger.info(message.format(name=db_name, uri=uri))
    return mongo
Exemple #7
0
def list_runs(config: Dict, *args, **kwargs) -> Dict:
    """List run IDs and states, one page at a time.

    Args:
        config: App configuration; provides the runs collection and the
            default page size.
        **kwargs: Optional `page_size`, `page_token` (string form of the
            object ID of the last run of the previous page) and `user_id`
            (restricts the listing to runs with that user ID).

    Returns:
        Dictionary with `next_page_token` (empty string if the page holds
        no runs) and `runs`, a list of dictionaries with the keys `run_id`
        and `state`.
    """
    runs = get_conf(config, 'database', 'collections', 'runs')

    # User-supplied page size wins; the configured default is only looked
    # up when none was given
    if 'page_size' in kwargs:
        page_size = kwargs['page_size']
    else:
        page_size = config['api']['endpoint_params']['default_page_size']

    page_token = kwargs.get('page_token', '')

    # Assemble the query: optional owner filter plus pagination cut-off
    query = {}
    if 'user_id' in kwargs:
        query['user_id'] = kwargs['user_id']
    if page_token != '':
        query['_id'] = {'$lt': ObjectId(page_token)}

    # Descending object ID (+/- newest to oldest), capped at the page size
    cursor = runs.find(
        filter=query,
        projection={'run_id': True, 'api.state': True},
    )
    page = list(cursor.sort('_id', -1).limit(page_size))

    # The object ID of the last run on this page is the next page token
    next_page_token = str(page[-1]['_id']) if page else ''

    # Flatten each record to 'run_id' and 'state'
    for run in page:
        del run['_id']
        run['state'] = run.pop('api')['state']

    return {'next_page_token': next_page_token, 'runs': page}
Exemple #8
0
def register_openapi(app: App,
                     specs: List[Dict] = None,
                     add_security_definitions: bool = True) -> App:
    """Registers OpenAPI specs with Connexion app.

    Args:
        app: Connexion app to add the APIs to.
        specs: API spec configurations; each entry is read via `get_conf()`
            for the keys `path`, `strict_validation`, `validate_responses`,
            `swagger_ui` and `swagger_json`. `None` (the default) is treated
            as an empty list.
        add_security_definitions: If true, each spec file is passed through
            `__add_security_definitions()` and the path it returns is
            registered instead of the original one.

    Returns:
        The Connexion app that was passed in.

    Raises:
        SystemExit: A spec file could not be found or accessed.
    """
    # `None` replaces the former mutable `[]` default, which would have been
    # shared between calls
    if specs is None:
        specs = []

    # Spec paths are resolved relative to the directory of _this_ file
    module_dir = os.path.abspath(os.path.dirname(os.path.realpath(__file__)))

    # Iterate over list of API specs
    for spec in specs:
        path = os.path.join(module_dir, get_conf(spec, 'path'))

        # Add security definitions to copy of specs
        if add_security_definitions:
            path = __add_security_definitions(in_file=path)

        # Generate API endpoints from OpenAPI spec
        try:
            app.add_api(
                path,
                strict_validation=get_conf(spec, 'strict_validation'),
                validate_responses=get_conf(spec, 'validate_responses'),
                swagger_ui=get_conf(spec, 'swagger_ui'),
                swagger_json=get_conf(spec, 'swagger_json'),
            )

        except (FileNotFoundError, PermissionError) as e:
            logger.critical(
                ("API specification file not found or accessible at "
                 "'{path}'. Execution aborted. Original error message: "
                 "{type}: {msg}").format(
                     path=path,
                     type=type(e).__name__,
                     msg=e,
                 ))
            # Keep the original error attached as the cause of the exit
            raise SystemExit(1) from e

        else:
            logger.info("API endpoints specified in '{path}' added.".format(
                path=path, ))

    return app
def create_mongo_client(
    app: Flask,
    config: Dict,
):
    """Register MongoDB URI and credentials and instantiate the client.

    Host, port and database name are taken from the environment variables
    `MONGO_HOST`, `MONGO_PORT` and `MONGO_DBNAME`, falling back to the
    `database` section of `config`. Credentials are only added to the URI
    if `MONGO_USERNAME` is set to a non-empty value; the password is read
    from `MONGO_PASSWORD` (empty if unset).

    Args:
        app: Flask app; its `MONGO_URI` config entry is set by this call.
        config: App configuration providing the database defaults.

    Returns:
        The `PyMongo` instance created for `app`.
    """
    # An unset variable yields `None`, which must not be treated as a user
    # name (it would end up as 'None:None@' in the URI)
    username = os.environ.get('MONGO_USERNAME')
    if username:
        auth = '{username}:{password}@'.format(
            username=username,
            password=os.environ.get('MONGO_PASSWORD', ''),
        )
    else:
        auth = ''

    # Environment variables take precedence over the configuration file
    host = os.environ.get('MONGO_HOST', get_conf(config, 'database', 'host'))
    port = os.environ.get('MONGO_PORT', get_conf(config, 'database', 'port'))
    dbname = os.environ.get('MONGO_DBNAME',
                            get_conf(config, 'database', 'name'))

    app.config['MONGO_URI'] = 'mongodb://{auth}{host}:{port}/{dbname}'.format(
        host=host,
        port=port,
        dbname=dbname,
        auth=auth)

    # Instantiate MongoDB client
    mongo = PyMongo(app)
    logger.info(
        ("Registered database '{name}' at URI '{uri}':'{port}' with Flask "
         'application.').format(
             name=dbname,
             uri=host,
             port=port))
    return mongo
Exemple #10
0
def create_celery_app(app: Flask) -> Celery:
    """Build a Celery app configured from the Flask app.

    Broker host and port come from the environment variables `RABBIT_HOST`
    and `RABBIT_PORT`, falling back to the `celery` section of the Flask
    config. The Celery task base class is replaced by one that calls
    `run()` inside the Flask app context.

    Args:
        app: Flask app whose config is used and whose app context tasks
            are executed in.

    Returns:
        The configured Celery app.
    """
    flask_conf = app.config

    # Broker location: environment variables win over the configuration
    broker_host = os.environ.get(
        'RABBIT_HOST', get_conf(flask_conf, 'celery', 'broker_host'))
    broker_port = os.environ.get(
        'RABBIT_PORT', get_conf(flask_conf, 'celery', 'broker_port'))
    broker = 'pyamqp://{host}:{port}//'.format(
        host=broker_host,
        port=broker_port,
    )
    backend = get_conf(flask_conf, 'celery', 'result_backend')
    include = get_conf_type(flask_conf, 'celery', 'include', types=(list))
    maxsize = get_conf(flask_conf, 'celery', 'message_maxsize')

    # Instantiate Celery app
    celery = Celery(
        app=__name__,
        broker=broker,
        backend=backend,
        include=include,
    )

    # Report which file/function asked for the Celery app
    caller = stack()[1]
    logger.info("Celery app created from '{calling_module}'.".format(
        calling_module=':'.join([caller.filename, caller.function])
    ))

    # Apply the configured size limit to result/args/kwargs representations
    celery.Task.resultrepr_maxsize = maxsize
    celery.amqp.argsrepr_maxsize = maxsize
    celery.amqp.kwargsrepr_maxsize = maxsize

    # Carry the Flask configuration over into the Celery configuration
    celery.conf.update(flask_conf)
    logger.info('Celery app configured.')

    class ContextTask(celery.Task):  # type: ignore
        # https://github.com/python/mypy/issues/4284)
        def __call__(self, *args, **kwargs):
            # Execute the task body within the Flask app context
            with app.app_context():
                return self.run(*args, **kwargs)

    celery.Task = ContextTask
    logger.debug("App context added to 'celery.Task' class.")

    return celery
Exemple #11
0
    def wrapper(*args, **kwargs):
        """Call `fn`, adding token data when authorization is required.

        If `security.authorization_required` is not set, `fn` is called
        with the original arguments only. Otherwise the JWT is parsed from
        the request header, decoded and checked for the identity claim, and
        `fn` additionally receives the keyword arguments `token`,
        `token_data` and `user_id`.

        Raises:
            Unauthorized: The token could not be decoded.
        """
        conf = current_app.config

        # Authorization disabled: call through without token data
        if not get_conf(conf, 'security', 'authorization_required'):
            return fn(*args, **kwargs)

        # Parse token from HTTP header
        token = parse_jwt_from_header(
            header_name=get_conf(conf, 'security', 'jwt', 'header_name'),
            expected_prefix=get_conf(conf, 'security', 'jwt', 'token_prefix'),
        )

        # Decode token; any failure here, including the config lookups, is
        # reported as an authentication problem
        try:
            token_data = decode(
                jwt=token,
                key=get_conf(conf, 'security', 'jwt', 'public_key'),
                algorithms=get_conf(conf, 'security', 'jwt', 'algorithm'),
                verify=True,
            )
        except Exception as e:
            message = ('Authentication token could not be decoded. Original '
                       'error message: {type}: {msg}')
            logger.error(message.format(type=type(e).__name__, msg=e))
            raise Unauthorized

        # The configured identity claim must be among the token's claims
        identity_claim = get_conf(conf, 'security', 'jwt', 'identity_claim')
        validate_claims(
            token_data=token_data,
            required_claims=[identity_claim],
        )
        user_id = token_data[identity_claim]

        # Call wrapped function with token data added
        return fn(
            *args,
            token=token,
            token_data=token_data,
            user_id=user_id,
            **kwargs
        )
def register_mongodb(app: Flask) -> Flask:
    """Create the MongoDB client and register collections with the app.

    Stores the database handle and the `runs` and `service_info` collection
    handles in the `database` section of the app config, then initializes
    the service info.

    Args:
        app: Flask app whose config provides the database settings.

    Returns:
        The Flask app that was passed in, with its config extended.
    """
    config = app.config

    # Instantiate PyMongo client
    mongo = create_mongo_client(app=app, config=config)

    # Database handle; the environment variable wins over the configuration
    # NOTE(review): this indexes `mongo.db` by the database name -- confirm
    # that this yields the intended object
    db_name = os.environ.get('MONGO_DBNAME',
                             get_conf(config, 'database', 'name'))
    db = mongo.db[db_name]

    # Collection backing '/service-info'
    # NOTE(review): the collection is named 'service-info' while the log
    # message says 'service_info' -- confirm which one is intended
    service_info = mongo.db['service-info']
    logger.debug("Added database collection 'service_info'.")

    # Collection backing '/runs', with a unique, sparse compound index
    runs = mongo.db['runs']
    runs.create_index(
        [('run_id', ASCENDING), ('task_id', ASCENDING)],
        unique=True,
        sparse=True,
    )
    logger.debug("Added database collection 'runs'.")

    # Expose database and collections through the app config
    config['database']['database'] = db
    config['database']['collections'] = {
        'runs': runs,
        'service_info': service_info,
    }
    app.config = config

    # Initialize service info
    logger.debug('Initializing service info...')
    get_service_info(config, silent=True)

    return app
Exemple #13
0
def list_runs(config: Dict, *args, **kwargs) -> Dict:
    """List run IDs and states for all workflow runs (no pagination yet).

    Args:
        config: App configuration; provides the runs collection.
        **kwargs: If `user_id` is present, only runs with that user ID are
            listed.

    Returns:
        Dictionary with `runs`, the list of matching records, and
        `next_page_token`, currently the fixed placeholder 'token'.
    """
    runs = get_conf(config, 'database', 'collections', 'runs')

    # TODO: stable ordering (newest last?)
    # TODO: implement next page token
    # TODO: once pagination is implemented, fall back to the configured
    #       default page size (api -> endpoint_params -> default_page_size)
    #       when the caller does not pass `page_size`

    # Restrict the query to the user's own runs if a user ID is available
    query = {'user_id': kwargs['user_id']} if 'user_id' in kwargs else {}

    cursor = runs.find(
        filter=query,
        projection={'run_id': True, 'state': True, '_id': False},
    )
    matching_runs = list(cursor)

    return {'next_page_token': 'token', 'runs': matching_runs}
def __run_workflow(config: Dict, document: Dict, **kwargs) -> None:
    """Assemble the `cwl-tes` command and start it as a background task.

    Helper for `run_workflow()`.

    Args:
        config: App configuration; provides the TES URL, the remote storage
            URL, the JWT public key and the task timeout.
        document: Run document; `run_id`, `task_id` and the `internal`
            entries `tmp_dir`, `cwl_path` and `param_file_path` are read.
        **kwargs: If `token` is present, token and public key are passed to
            `cwl-tes` as additional arguments.
    """
    tes_url = get_conf(config, 'tes', 'url')
    remote_storage_url = get_conf(config, 'storage', 'remote_storage_url')
    run_id = document['run_id']
    task_id = document['task_id']
    internal = document['internal']
    tmp_dir = internal['tmp_dir']
    cwl_path = internal['cwl_path']
    param_file_path = internal['param_file_path']

    # Authorization parameters, only when a token was passed in
    auth_params = []
    if 'token' in kwargs:
        public_key = get_conf(config, 'security', 'jwt', 'public_key')
        auth_params = [
            '--token-public-key',
            public_key.encode('unicode_escape').decode('utf-8'),
            '--token',
            kwargs['token'],
        ]

    # Build command; authorization parameters go right after '--debug'
    # NOTE: for manual testing, `command_list` can be replaced here, e.g.
    #       by ['/path/to/non_existing/script'] (system error),
    #       ['/bin/false'] (executor error) or ['sleep', '30'] (slow
    #       completion without stdout/stderr)
    command_list = ['cwl-tes', '--debug'] + auth_params + [
        '--leave-outputs',
        '--remote-storage-url',
        remote_storage_url,
        '--tes',
        tes_url,
        cwl_path,
        param_file_path,
    ]

    # The timeout is enforced as the soft time limit of the background task
    timeout_duration = get_conf(
        config,
        'api',
        'endpoint_params',
        'timeout_run_workflow',
    )

    # Execute command as background task
    start_message = ("Starting execution of run '{run_id}' as task '{task_id}' in "
                     "'{tmp_dir}'...")
    logger.info(start_message.format(
        run_id=run_id,
        task_id=task_id,
        tmp_dir=tmp_dir,
    ))
    task_kwargs = {
        'command_list': command_list,
        'tmp_dir': tmp_dir,
    }
    task__run_workflow.apply_async(
        None,
        task_kwargs,
        task_id=task_id,
        soft_time_limit=timeout_duration,
    )
def __create_run_environment(config: Dict, document: Dict, **kwargs) -> Dict:
    """Creates unique run identifier and permanent and temporary storage
    directories for current run.

    Run IDs are generated until one is found for which both directories can
    be created and the document can be inserted into the runs collection.

    Args:
        config: App configuration; provides the runs collection, the
            storage directories and the run ID charset and length.
        document: Run document; `run_id`, `task_id`, `user_id` and the
            `internal` entries `tmp_dir` and `out_dir` are set on it.
        **kwargs: If `user_id` is present, it is stored with the document;
            otherwise the document's `user_id` is set to `None`.

    Returns:
        The document as returned by `__process_workflow_attachments()`. If
        inserting it failed with an error other than a duplicate key, the
        error is logged and the document is returned regardless.
    """
    import logging

    collection_runs = get_conf(config, 'database', 'collections', 'runs')
    out_dir = get_conf(config, 'storage', 'permanent_dir')
    tmp_dir = get_conf(config, 'storage', 'tmp_dir')
    # NOTE(review): `eval()` executes the configured charset expression as
    # code; this is only acceptable as long as the configuration is fully
    # trusted. Left unchanged because replacing it would alter which config
    # values are accepted.
    run_id_charset = eval(get_conf(config, 'database', 'run_id', 'charset'))
    run_id_length = get_conf(config, 'database', 'run_id', 'length')

    # Keep on trying until a unique run id was found and inserted
    # TODO: If no more possible IDs => inf loop; fix (raise custom error; 500
    #       to user)
    while True:

        # Create unique run and task ids
        run_id = __create_run_id(
            charset=run_id_charset,
            length=run_id_length,
        )
        task_id = uuid()

        # Set temporary and output directories
        current_tmp_dir = os.path.abspath(os.path.join(tmp_dir, run_id))
        current_out_dir = os.path.abspath(os.path.join(out_dir, run_id))

        # Try to create workflow run directories; try a new run id if either
        # of them already exists
        # TODO: Think about permissions
        # TODO: Add working dir (currently one has to run the app from
        #       outermost dir)
        try:
            os.makedirs(current_tmp_dir)
        except FileExistsError:
            # Existing directory belongs to another run; leave it alone
            continue
        try:
            os.makedirs(current_out_dir)
        except FileExistsError:
            # Do not leave behind the temporary directory created just now
            shutil.rmtree(current_tmp_dir, ignore_errors=True)
            continue

        # Add run/task/user identifier, temp/output directories to document
        document['run_id'] = run_id
        document['task_id'] = task_id
        document['user_id'] = kwargs.get('user_id')
        document['internal']['tmp_dir'] = current_tmp_dir
        document['internal']['out_dir'] = current_out_dir

        # Process workflow attachments
        document = __process_workflow_attachments(document)

        # Try to insert document into database
        try:
            collection_runs.insert(document)

        # Try new run id if document already exists
        except DuplicateKeyError:

            # And remove run directories created previously
            shutil.rmtree(current_tmp_dir, ignore_errors=True)
            shutil.rmtree(current_out_dir, ignore_errors=True)

            continue

        # Other database errors are logged with traceback; the loop is
        # still left, as before
        # TODO: implement properly
        except Exception:
            logging.getLogger(__name__).exception('Database error')

        # Exit loop
        break

    return document
Exemple #16
0
    def wrapper(*args, **kwargs):
        """Call `fn`, adding JWT data when authorization is required.

        If `security.authorization_required` is not set, `fn` is called
        with the original arguments only. Otherwise the JWT is parsed from
        the request header and validated with the configured methods
        ('userinfo' and/or 'public_key'); `validation_checks` decides
        whether one ('any') or every configured method ('all') has to
        succeed. On success, `fn` additionally receives the keyword
        arguments `jwt`, `claims` and `user_id`.

        Raises:
            Unauthorized: No validation method is configured,
                `validation_checks` has an illegal value, too few
                validations passed, or the identity claim check failed.
        """
        # Check if authentication is enabled
        if get_conf(current_app.config, 'security', 'authorization_required'):

            # Get config parameters
            validation_methods = get_conf_type(
                current_app.config,
                'security',
                'jwt',
                'validation_methods',
                types=(List),
            )
            validation_checks = get_conf(
                current_app.config,
                'security',
                'jwt',
                'validation_checks',
            )
            algorithms = get_conf_type(
                current_app.config,
                'security',
                'jwt',
                'algorithms',
                types=(List),
            )
            expected_prefix = get_conf(current_app.config, 'security', 'jwt',
                                       'token_prefix')
            header_name = get_conf(current_app.config, 'security', 'jwt',
                                   'header_name')
            claim_key_id = get_conf(current_app.config, 'security', 'jwt',
                                    'claim_key_id')
            claim_issuer = get_conf(current_app.config, 'security', 'jwt',
                                    'claim_issuer')
            claim_identity = get_conf(current_app.config, 'security', 'jwt',
                                      'claim_identity')

            # Ensure that at least one validation method was configured
            if not validation_methods:
                logger.error("No JWT validation methods configured.")
                raise Unauthorized

            # Ensure that a valid validation checks argument was configured
            if validation_checks == 'any':
                required_validations = 1
            elif validation_checks == 'all':
                required_validations = len(validation_methods)
            else:
                # The offending value is interpolated into the message; the
                # placeholder used to be logged verbatim
                logger.error(
                    ("Illegal argument '{validation_checks}' passed to "
                     "configuration paramater 'validation_checks'. Allowed "
                     "values: 'any', 'all'").format(
                         validation_checks=validation_checks,
                     ))
                raise Unauthorized

            # Parse JWT token from HTTP header
            jwt = parse_jwt_from_header(
                header_name=header_name,
                expected_prefix=expected_prefix,
            )

            # Initialize validation counter
            validated = 0

            # Validate JWT via /userinfo endpoint
            if 'userinfo' in validation_methods \
                    and validated < required_validations:
                logger.info(
                    ("Validating JWT via identity provider's '/userinfo' "
                     "endpoint..."))
                claims = validate_jwt_via_userinfo_endpoint(
                    jwt=jwt,
                    algorithms=algorithms,
                    claim_issuer=claim_issuer,
                )
                if claims:
                    validated += 1

            # Validate JWT via public key; skipped if enough validations
            # have passed already
            if 'public_key' in validation_methods \
                    and validated < required_validations:
                logger.info(
                    ("Validating JWT via identity provider's public key..."))
                claims = validate_jwt_via_public_key(
                    jwt=jwt,
                    algorithms=algorithms,
                    claim_key_id=claim_key_id,
                    claim_issuer=claim_issuer,
                )
                if claims:
                    validated += 1

            # Check whether enough validation checks passed
            if validated != required_validations:
                logger.error(
                    ("Insufficient number of JWT validation checks passed."))
                raise Unauthorized

            # Ensure that specified identity claim is available
            if not validate_jwt_claims(
                    claim_identity,
                    claims=claims,
            ):
                raise Unauthorized

            # Return wrapped function with token data
            return fn(jwt=jwt,
                      claims=claims,
                      user_id=claims[claim_identity],
                      *args,
                      **kwargs)

        # Return wrapped function without token data
        else:
            return fn(*args, **kwargs)
Exemple #17
0
    # Register MongoDB
    connexion_app.app = register_mongodb(connexion_app.app)

    # Register error handlers
    connexion_app = register_error_handlers(connexion_app)

    # Create Celery app and register background task monitoring service
    register_task_service(connexion_app.app)

    # Register OpenAPI specs
    connexion_app = register_openapi(app=connexion_app,
                                     specs=get_conf_type(config,
                                                         'api',
                                                         'specs',
                                                         types=(list)),
                                     add_security_definitions=get_conf(
                                         config, 'security',
                                         'authorization_required'))

    # Enable cross-origin resource sharing
    enable_cors(connexion_app.app)

    return connexion_app, config


if __name__ == '__main__':
    # Build the configured Connexion app only when executed as a script
    connexion_app, config = run_server()
    # Run app; the reloader flag is read from the 'server' config section
    connexion_app.run(use_reloader=get_conf(config, 'server', 'use_reloader'))
Exemple #18
0
from wes_elixir.config.config_parser import get_conf
from wes_elixir.config.app_config import parse_app_config

# Source the WES config for defaults; the config source is identified by
# 'WES_CONFIG'
flask_config = parse_app_config(config_var='WES_CONFIG')

# Gunicorn number of workers and threads; overridable via the environment
# variables GUNICORN_PROCESSES (default 3) and GUNICORN_THREADS (default 1)
workers = int(os.environ.get('GUNICORN_PROCESSES', '3'))
threads = int(os.environ.get('GUNICORN_THREADS', '1'))

# NOTE(review): '*' presumably makes Gunicorn accept forwarded headers from
# any address -- confirm that this is intended for the deployment
forwarded_allow_ips = '*'

# Gunicorn bind address, assembled from the 'server' section of the WES
# config
bind = '{address}:{port}'.format(
    address=get_conf(flask_config, 'server', 'host'),
    port=get_conf(flask_config, 'server', 'port'),
)

# Source the environment variables for the Gunicorn workers
raw_env = [
    "WES_CONFIG=%s" % os.environ.get('WES_CONFIG', ''),
    "RABBIT_HOST=%s" % os.environ.get(
        'RABBIT_HOST', get_conf(flask_config, 'celery', 'broker_host')),
    "RABBIT_PORT=%s" % os.environ.get(
        'RABBIT_PORT', get_conf(flask_config, 'celery', 'broker_port')),
    "MONGO_HOST=%s" %
    os.environ.get('MONGO_HOST', get_conf(flask_config, 'database', 'host')),
    "MONGO_PORT=%s" %
    os.environ.get('MONGO_PORT', get_conf(flask_config, 'database', 'port')),
    "MONGO_DBNAME=%s" %