Пример #1
0
def create_celery_app(app: Flask) -> Celery:
    """Create a Celery app configured from, and bound to, the Flask app.

    The broker URL is assembled from the 'RABBIT_HOST' / 'RABBIT_PORT'
    environment variables, falling back to the 'celery' section of the Flask
    config. The complete Flask config is then copied into the Celery config
    and the Celery task base class is replaced so that every task runs inside
    the Flask application context.

    Args:
        app: Flask application supplying configuration and app context.

    Returns:
        The configured Celery application.
    """
    flask_config = app.config

    # Environment variables take precedence over the configured broker address
    rabbit_host = os.environ.get(
        'RABBIT_HOST',
        get_conf(flask_config, 'celery', 'broker_host'),
    )
    rabbit_port = os.environ.get(
        'RABBIT_PORT',
        get_conf(flask_config, 'celery', 'broker_port'),
    )
    broker_url = 'pyamqp://{host}:{port}//'.format(
        host=rabbit_host,
        port=rabbit_port,
    )
    result_backend = get_conf(flask_config, 'celery', 'result_backend')
    # TODO: TES   include = get_conf_type(app.config, 'celery', 'include', types=(list))

    # Instantiate Celery app
    celery_app = Celery(
        app=__name__,
        broker=broker_url,
        backend=result_backend,
        # TODO: TES        include=include,
    )

    # Entry 1 of the call stack is the caller of this function
    caller = stack()[1]
    logger.info("Celery app created from '{calling_module}'.".format(
        calling_module=':'.join((caller.filename, caller.function)),
    ))

    # Copy the Flask app configuration into the Celery configuration
    celery_app.conf.update(flask_config)
    logger.info('Celery app configured.')

    # Task base class that wraps each task call in the Flask app context
    class ContextTask(celery_app.Task):  # type: ignore
        # https://github.com/python/mypy/issues/4284

        def __call__(self, *args, **kwargs):
            with app.app_context():
                return self.run(*args, **kwargs)

    celery_app.Task = ContextTask
    logger.debug("App context added to 'celery.Task' class.")

    return celery_app
def _create_task_document(
    config: Dict,
    document: Dict,
    sender: str,
    init_state: str = 'UNKNOWN',
) -> Dict:
    """Create a unique task identifier and insert the task document.

    A fresh task ID and Celery worker ID are generated and written into
    `document` (which is modified in place) until the insert no longer fails
    with a duplicate key error.

    Args:
        config: App configuration; provides the tasks collection and the
            task ID charset/length under the 'database' section.
        document: Task document; must already contain a 'task' mapping.
        sender: Identifier of the requester, stored under 'sender'.
        init_state: State written to `document['task']['state']`.

    Returns:
        The (mutated) document, including the generated identifiers.
    """
    collection_tasks = get_conf(config, 'database', 'collections', 'tasks')
    # SECURITY: the charset is a Python expression taken from the app config
    # and is evaluated with `eval`; the config file must therefore be trusted.
    id_charset = eval(get_conf(config, 'database', 'task_id', 'charset'))
    id_length = get_conf(config, 'database', 'task_id', 'length')

    # Keep on trying until a unique task id was found and inserted
    # TODO: If no more possible IDs => inf loop; fix (raise customerror; 500
    #       to user)
    while True:

        # Create unique task and Celery IDs
        task_id = _create_uuid(
            charset=id_charset,
            length=id_length,
        )
        worker_id = uuid()

        # Add task, worker, user and run identifiers
        document['task_id'] = document['task']['id'] = task_id
        document['worker_id'] = worker_id
        document['sender'] = sender
        document['user_id'] = None
        document['token'] = None
        document['run_id'] = None
        document['run_id_secondary'] = None

        # Set initial state
        document['task']['state'] = init_state

        try:
            # `Collection.insert` is deprecated (removed in PyMongo 4);
            # `insert_one` is the supported single-document equivalent
            collection_tasks.insert_one(document)

        # Try a new task id if the document already exists
        except DuplicateKeyError:
            continue

        # Log other database errors
        # NOTE(review): the document is still returned to the caller as if
        # the insert had succeeded -- confirm this best-effort is intended
        except Exception as e:
            logger.exception(('Database error. Original error message {type}: '
                              "{msg}").format(
                                  type=type(e).__name__,
                                  msg=e,
                              ))

        # Inserted, or failed with a non-retryable error: stop trying
        break

    return document
Пример #3
0
def register_mongodb(app: Flask) -> Flask:
    """Create the Mongo client and store database handles in the app config.

    Handles for the database entry and for the 'runs' and
    'service-info-proxy-tes' collections are written to the 'database'
    section of the Flask config, after which the service info is initialized.

    Args:
        app: Flask application to register the database with.

    Returns:
        The same Flask application, with its config extended.
    """
    config = app.config

    # Instantiate PyMongo client
    mongo = create_mongo_client(app=app, config=config)

    # Database name: the environment variable wins over the configured name
    db_name = os.environ.get(
        'MONGO_DBNAME',
        get_conf(config, 'database', 'name'),
    )
    db = mongo.db[db_name]

    # Collection backing '/service-info'
    service_info_collection = mongo.db['service-info-proxy-tes']
    logger.debug("Added database collection 'service_info'.")

    # Collection backing '/runs'
    runs_collection = mongo.db['runs']
    logger.debug("Added database collection 'runs'.")

    # Expose database and collections through the app config
    db_section = config['database']
    db_section['database'] = db
    db_section['collections'] = {
        'runs': runs_collection,
        'service_info_proxy_tes': service_info_collection,
    }
    app.config = config

    # Initialize service info
    logger.debug('Initializing service info...')
    get_service_info(config, silent=True)

    return app
Пример #4
0
def run_server():
    """Assemble the fully configured Connexion application.

    Sets up logging, parses the app config, creates the Connexion app and
    registers MongoDB, error handlers, the background task service, the
    OpenAPI specs and CORS, in that order.

    Returns:
        Tuple of the Connexion app and the parsed app configuration.
    """
    # Logging first, so that all later steps are logged
    configure_logging(config_var='TES_CONFIG_LOG')

    # Parse app configuration and create the Connexion app from it
    config = parse_app_config(config_var='TES_CONFIG')
    cnx_app = create_connexion_app(config)

    # Register MongoDB with the underlying Flask app
    cnx_app.app = register_mongodb(cnx_app.app)

    # Register error handlers
    cnx_app = register_error_handlers(cnx_app)

    # Create Celery app and register background task monitoring service
    register_task_service(cnx_app.app)

    # Register OpenAPI specs
    api_specs = get_conf_type(config, 'api', 'specs', types=(list))
    spec_dir = get_conf(config, 'storage', 'spec_dir')
    cnx_app = register_openapi(
        app=cnx_app,
        specs=api_specs,
        spec_dir=spec_dir,
        add_security_definitions=True,
    )

    # Enable cross-origin resource sharing
    enable_cors(cnx_app.app)

    return cnx_app, config
Пример #5
0
def create_mongo_client(
    app: Flask,
    config: Dict,
):
    """Register the MongoDB URI with the Flask app and create the client.

    Host, port and database name are read from the 'MONGO_HOST',
    'MONGO_PORT' and 'MONGO_DBNAME' environment variables, falling back to
    the 'database' section of `config`. Credentials are only added to the
    URI when 'MONGO_USERNAME' is set to a non-empty value.

    Args:
        app: Flask application; its 'MONGO_URI' config entry is set.
        config: App configuration providing the fallback values.

    Returns:
        The PyMongo client bound to `app`.
    """
    # An unset variable yields None from `os.environ.get`; defaulting to ''
    # keeps an unset username from producing a literal 'None:None@' prefix
    username = os.environ.get('MONGO_USERNAME', '')
    password = os.environ.get('MONGO_PASSWORD', '')
    if username:
        auth = '{username}:{password}@'.format(
            username=username,
            password=password,
        )
    else:
        auth = ''

    # Environment variables take precedence over configured values
    host = os.environ.get('MONGO_HOST', get_conf(config, 'database', 'host'))
    port = os.environ.get('MONGO_PORT', get_conf(config, 'database', 'port'))
    dbname = os.environ.get('MONGO_DBNAME',
                            get_conf(config, 'database', 'name'))

    app.config['MONGO_URI'] = 'mongodb://{auth}{host}:{port}/{dbname}'.format(
        host=host,
        port=port,
        dbname=dbname,
        auth=auth)

    # Instantiate MongoDB client
    mongo = PyMongo(app)
    logger.info(
        ("Registered database '{name}' at URI '{uri}':'{port}' with Flask "
         'application.').format(
             name=dbname,
             uri=host,
             port=port))
    return mongo
Пример #6
0
def __add_config_to_connexion_app(app: App, config: Mapping) -> App:
    """Apply the user configuration to the Connexion app and its Flask app.

    Host, port and debug flag of the Connexion app as well as the 'DEBUG',
    'ENV' and 'TESTING' entries of the Flask config are overwritten with
    values from the 'server' section of `config`; afterwards the whole of
    `config` is merged into the Flask config.

    Args:
        app: Connexion application to configure.
        config: Parsed user configuration.

    Returns:
        The configured Connexion application.
    """
    flask_config = app.app.config

    # Replace Connexion app settings
    for setting in ('host', 'port', 'debug'):
        setattr(app, setting, get_conf(config, 'server', setting))

    # Replace Flask app settings
    flask_config['DEBUG'] = app.debug
    flask_config['ENV'] = get_conf(config, 'server', 'environment')
    flask_config['TESTING'] = get_conf(config, 'server', 'testing')

    # Log Flask config as it is before the user configuration is merged in
    logger.debug('Flask app settings:')
    for name, current in flask_config.items():
        logger.debug('* {}: {}'.format(name, current))

    # Add user configuration to Flask app config
    flask_config.update(config)

    logger.info('Connexion app configured.')
    return app
Пример #7
0
def create_mongo_client(
    app: Flask,
    config: Dict,
):
    """Register the MongoDB URI with the Flask app and create the client.

    Host, port and database name are read from the 'MONGO_HOST',
    'MONGO_PORT' and 'MONGO_DBNAME' environment variables, falling back to
    the 'database' section of `config`. Credentials are only added to the
    URI when 'MONGO_USERNAME' is set to a non-empty value.

    Args:
        app: Flask application; its 'MONGO_URI' config entry is set.
        config: App configuration providing the fallback values.

    Returns:
        The PyMongo client bound to `app`.
    """
    uri_template = 'mongodb://{auth}{host}:{port}/{dbname}'

    # Set authentication
    username = os.getenv('MONGO_USERNAME', '')
    password = os.getenv('MONGO_PASSWORD', '')
    if username:
        auth = '{username}:{password}@'.format(
            username=username,
            password=password,
        )
    else:
        auth = ''

    # Environment variables take precedence over configured values
    host = os.getenv('MONGO_HOST', get_conf(config, 'database', 'host'))
    port = os.getenv('MONGO_PORT', get_conf(config, 'database', 'port'))
    dbname = os.getenv('MONGO_DBNAME', get_conf(config, 'database', 'name'))

    # Compile Mongo URI string
    app.config['MONGO_URI'] = uri_template.format(
        auth=auth,
        host=host,
        port=port,
        dbname=dbname,
    )

    # Instantiate MongoDB client
    mongo = PyMongo(app)

    # Log the URI with the password masked so that the credentials do not
    # end up in the log output
    if username:
        masked_auth = '{username}:***@'.format(username=username)
    else:
        masked_auth = ''
    logger.info(
        (
            "Registered database at '{mongo_uri}' with Flask application."
        ).format(
            mongo_uri=uri_template.format(
                auth=masked_auth,
                host=host,
                port=port,
                dbname=dbname,
            )
        )
    )

    # Return Mongo client
    return mongo
Пример #8
0
def register_openapi(
    app: App,
    specs: Optional[List[Dict]] = None,
    spec_dir: Optional[str] = None,
    add_security_definitions: bool = True,
) -> App:
    """Register OpenAPI specs with the Connexion app.

    Args:
        app: Connexion application the APIs are added to.
        specs: Spec descriptions; each provides 'path' (relative to this
            module's directory), 'type', 'strict_validation',
            'validate_responses', 'swagger_ui' and 'swagger_json'. `None`
            (the default) or an empty list registers nothing.
        spec_dir: Output directory handed to the security definition step.
        add_security_definitions: Whether to add security definitions to a
            copy of each spec before registering it.

    Returns:
        The Connexion application that was passed in.

    Raises:
        SystemExit: If a spec file cannot be found or read.
    """
    # Spec paths are resolved relative to the directory of _this_ module
    module_dir = os.path.abspath(os.path.dirname(os.path.realpath(__file__)))

    # `None` default instead of a shared mutable `[]` default
    for spec in specs or []:

        path = os.path.join(module_dir, get_conf(spec, 'path'))

        # Convert JSON to YAML
        if get_conf(spec, 'type') == 'json':
            path = __json_to_yaml(path)

        # Add security definitions to copy of specs
        if add_security_definitions:
            path = __add_security_definitions(
                in_file=path,
                out_dir=spec_dir,
            )

        # Generate API endpoints from OpenAPI spec
        try:
            app.add_api(
                path,
                strict_validation=get_conf(spec, 'strict_validation'),
                validate_responses=get_conf(spec, 'validate_responses'),
                swagger_ui=get_conf(spec, 'swagger_ui'),
                swagger_json=get_conf(spec, 'swagger_json'),
            )

            logger.info("API endpoints specified in '{path}' added.".format(
                path=path, ))

        except (FileNotFoundError, PermissionError) as e:
            logger.critical(
                ("API specification file not found or accessible at "
                 "'{path}'. Execution aborted. Original error message: "
                 "{type}: {msg}").format(
                     path=path,
                     type=type(e).__name__,
                     msg=e,
                 ))
            # Keep the original error attached as the cause
            raise SystemExit(1) from e

    return app
Пример #9
0
    def wrapper(*args, **kwargs):
        """Call the wrapped function, adding token data if auth is enabled.

        When 'security' -> 'authorization_required' is set in the app config,
        the request's JWT is parsed and validated and the wrapped function
        additionally receives `jwt`, `claims` and `user_id` keyword
        arguments. Otherwise it is called with the original arguments only.
        """
        authorization_required = get_conf(
            current_app.config,
            'security',
            'authorization_required',
        )

        # Authentication disabled: call wrapped function without token data
        if not authorization_required:
            return fn(*args, **kwargs)

        # Create JWT instance, validate it and determine the user.
        # NOTE: errors raised by these steps propagate unchanged; they are
        # not converted into `Unauthorized` here.
        token = JWT(request=request)
        token.validate()
        token.get_user()

        # Call wrapped function with token data
        return fn(
            *args,
            jwt=token.jwt,
            claims=token.claims,
            user_id=token.user,
            **kwargs
        )
Пример #10
0
    # Register MongoDB
    connexion_app.app = register_mongodb(connexion_app.app)

    # Register error handlers
    connexion_app = register_error_handlers(connexion_app)

    # Create Celery app and register background task monitoring service
    register_task_service(connexion_app.app)

    # Register OpenAPI specs
    connexion_app = register_openapi(
        app=connexion_app,
        specs=get_conf_type(config, 'api', 'specs', types=(list)),
        spec_dir=get_conf(config, 'storage', 'spec_dir'),
        add_security_definitions=True,
    )

    # Enable cross-origin resource sharing
    enable_cors(connexion_app.app)

    return connexion_app, config


if __name__ == '__main__':
    # Build the configured app, then start serving; the reloader setting
    # comes from the 'server' section of the app config
    connexion_app, config = run_server()
    connexion_app.run(use_reloader=get_conf(config, 'server', 'use_reloader'))
def list_tasks(
    config: Dict,
    *args,
    **kwargs,
) -> Dict:
    """List tasks, optionally restricted to those of a given user.

    Args:
        config: App configuration providing the tasks collection.
        **kwargs: Optional 'user_id' (only tasks of that user are listed)
            and 'view' ('MINIMAL', 'BASIC' or 'FULL'; defaults to 'BASIC').

    Returns:
        Dictionary with the task objects under 'tasks' and a placeholder
        'next_page_token'.

    Raises:
        BadRequest: If 'view' is not one of the supported values.
    """
    collection_tasks = get_conf(config, 'database', 'collections', 'tasks')

    # TODO: stable ordering (newest last?)
    # TODO: implement next page token
    # TODO: when implementing pagination, read 'page_size' from kwargs and
    #       fall back to the configured default page size
    #       (api -> endpoint_params -> default_page_size)

    # Restrict to the user's tasks only if a user ID was passed
    filter_dict = {'user_id': kwargs['user_id']} if 'user_id' in kwargs else {}

    # Database projection for each supported view
    projections = {
        'MINIMAL': {
            '_id': False,
            'task.id': True,
            'task.state': True,
        },
        'BASIC': {
            '_id': False,
            'task.inputs.content': False,
            'task.logs.system_logs': False,
            'task.logs.logs.stdout': False,
            'task.logs.logs.stderr': False,
        },
        'FULL': {
            '_id': False,
            'task': True,
        },
    }

    # Select projection for requested view
    view = kwargs.get('view', "BASIC")
    if view not in projections:
        raise BadRequest
    projection = projections[view]

    # Get tasks
    cursor = collection_tasks.find(
        filter=filter_dict,
        projection=projection,
    )
    tasks_list = [record['task'] for record in cursor]

    return {
        'next_page_token': 'token',
        'tasks': tasks_list
    }
def create_task(config: Dict, sender: str, *args, **kwargs) -> Dict:
    """Register a task and start its asynchronous submission to a TES.

    The request body is stored in a new database document, which receives a
    unique task ID and a Celery worker ID, and a background task is started
    that processes and submits the task.

    Args:
        config: App configuration.
        sender: Identifier of the requester, stored with the task.
        **kwargs: Must contain the task request under 'body'.

    Returns:
        Dictionary with the generated task ID under 'id'.

    Raises:
        BadRequest: If no 'body' was passed.
    """
    # Validate input data
    if 'body' not in kwargs:
        raise BadRequest

    # TODO (MAYBE): Check service info compatibility

    # Initialize database document
    # NOTE(review): 'request' and 'task' reference the SAME dictionary, so
    # the 'id' and 'state' written into 'task' by `_create_task_document`
    # also show up in 'request' -- confirm that this is intended
    document: Dict = dict()
    document['request'] = kwargs['body']
    document['task'] = kwargs['body']
    document['tes_uri'] = None
    document['task_id_tes'] = None

    # Get known TES instances
    document['tes_uris'] = get_conf_type(
        config,
        'tes',
        'service_list',
        types=(list),
    )

    # Create task document and insert into database
    document = _create_task_document(
        config=config,
        document=document,
        sender=sender,
        init_state='UNKNOWN',
    )

    # Get timeout duration; enforce a lower bound of 5 if one is set
    timeout = get_conf(
        config,
        'api',
        'endpoint_params',
        'timeout_task_execution',
    )
    if timeout is not None and timeout < 5:
        timeout = 5

    # Process and submit task asynchronously
    logger.info(("Starting submission of task '{task_id}' as worker task "
                 "'{worker_id}'...").format(
                     task_id=document['task_id'],
                     worker_id=document['worker_id'],
                 ))
    task__submit_task.apply_async(
        None,
        {
            'request': document['request'],
            'task_id': document['task_id'],
            'worker_id': document['worker_id'],
            'sender': sender,
            'tes_uris': document['tes_uris'],
        },
        # The Celery option that sets the ID of the background task is
        # `task_id`; running the task under the stored worker ID is what
        # allows it to be revoked by that ID later on
        task_id=document['worker_id'],
        soft_time_limit=timeout,
    )

    # Return response
    return {'id': document['task_id']}
Пример #13
0
def cancel_task(config: Dict, id: str, *args, **kwargs) -> Dict:
    """Cancel a task locally and, if already relayed, on the remote TES.

    Args:
        config: App configuration providing the tasks collection and the
            timeout for service calls.
        id: Task identifier (matched against 'task_id').
        **kwargs: Optional 'user_id'; if passed, it has to match the user ID
            stored with the task.

    Returns:
        Empty dictionary.

    Raises:
        TaskNotFound: If no task with the given ID exists.
        Forbidden: If a 'user_id' was passed that does not own the task.
    """
    collection = get_conf(config, 'database', 'collections', 'tasks')
    document = collection.find_one(filter={'task_id': id},
                                   projection={
                                       'task_id_tes': True,
                                       'tes_uri': True,
                                       'task.state': True,
                                       'user_id': True,
                                       'worker_id': True,
                                       '_id': False,
                                   })

    # Raise error if task was not found
    if not document:
        logger.error("Task '{id}' not found.".format(id=id))
        raise TaskNotFound

    # Raise error trying to access a task that is not owned by user
    # Only if authorization enabled
    if 'user_id' in kwargs and document['user_id'] != kwargs['user_id']:
        logger.error(
            ("User '{user_id}' is not allowed to access task '{id}'.").format(
                user_id=kwargs['user_id'],
                id=id,
            ))
        raise Forbidden

    # Tasks in any other state are left untouched
    state = document['task']['state']
    if state in States.CANCELABLE or state in States.UNDEFINED:

        # Get timeout duration
        timeout = get_conf(
            config,
            'api',
            'endpoint_params',
            'timeout_service_calls',
        )

        # Cancel local task
        current_app.control.revoke(document['worker_id'],
                                   terminate=True,
                                   signal='SIGKILL')

        # Cancel remote task, if the task was already sent to a TES
        if document['tes_uri'] is not None and document[
                'task_id_tes'] is not None:
            cli = tes.HTTPClient(document['tes_uri'], timeout=timeout)
            try:
                cli.cancel_task(document['task_id_tes'])
            except HTTPError as e:
                # Remote cancellation stays best-effort, but the failure is
                # recorded instead of being dropped silently
                # TODO: handle more robustly: only 400/Bad Request is okay;
                # TODO: other errors (e.g. 500) should be dealt with
                logger.warning(
                    ("Remote cancellation of task '{id}' at TES '{tes_uri}' "
                     "failed. Original error message: {type}: {msg}").format(
                         id=id,
                         tes_uri=document['tes_uri'],
                         type=type(e).__name__,
                         msg=e,
                     ))

        # Write log entry
        logger.info(
            ("Task '{id}' (worker ID '{worker_id}') was canceled.").format(
                id=id,
                worker_id=document['worker_id'],
            ))

        # Update task state
        set_task_state(
            collection=collection,
            task_id=id,
            worker_id=document['worker_id'],
            state='CANCELED',
        )

    return {}
def task__submit_task(
    self,
    request: Dict,
    task_id: str,
    worker_id: str,
    sender: str,
    tes_uris: List,
) -> None:
    """Process a task, deliver it to a TES instance and poll its state.

    Args:
        request: Task request to be sent to a TES instance.
        task_id: Task identifier, used for state updates and log messages.
        worker_id: Worker identifier used to address the database document.
        sender: Identifier of the requester (currently unused here).
        tes_uris: TES instances handed to the submission step.

    A failed submission or a soft time limit being exceeded sets the task
    state to 'SYSTEM_ERROR'.
    """
    # Get app config
    config = current_app.config

    # Get timeout for service calls
    timeout_service_calls = get_conf(
        config,
        'api',
        'endpoint_params',
        'timeout_service_calls',
    )

    # Create MongoDB client
    mongo = create_mongo_client(
        app=current_app,
        config=config,
    )
    collection = mongo.db['tasks']

    # Process task
    try:

        # TODO (LATER): Get associated workflow run & related info
        # NOTE:
        # - Get the following from callback via sender:
        #   - user_id
        #   - token
        #   - run_id
        #   - run_id_secondary (worker ID on WES)
        # Until then, hard-coded placeholder values are used
        user_id = None
        token = "ey23f423n4fln2flk3nf23lfn"
        run_id = "RUN123"
        run_id_secondary = "1234-23141-12341-12341"

        # Update database document
        upsert_fields_in_root_object(
            collection=collection,
            worker_id=worker_id,
            root='',
            user_id=user_id,
            token=token,
            run_id=run_id,
            run_id_secondary=run_id_secondary
        )

        # TODO (LATER): Apply middleware
        # - Token validation / renewal
        # - TEStribute
        # - Replace DRS IDs

        # TODO (PROPERLY): Send task to TES instance
        try:
            task_id_tes, tes_uri = _send_task(
                tes_uris=tes_uris,
                request=request,
                token=token,
                timeout=timeout_service_calls,
            )
            logger.info(
                (
                    "Task '{task_id}' was sent to TES '{tes_uri}' under remote "
                    "task ID '{task_id_tes}'."
                ).format(
                    task_id=task_id,
                    tes_uri=tes_uri,
                    task_id_tes=task_id_tes,
                )
            )

        # Handle submission failure
        except Exception as e:
            task_id_tes = None
            tes_uri = None
            set_task_state(
                collection=collection,
                task_id=task_id,
                worker_id=worker_id,
                state='SYSTEM_ERROR',
            )
            logger.error(
                (
                    "Task '{task_id}' could not be sent to any TES instance. "
                    "Task state was set to 'SYSTEM_ERROR'. Original error "
                    "message: '{type}: {msg}'"
                ).format(
                    task_id=task_id,
                    type=type(e).__name__,
                    # Exception args are not necessarily strings; convert
                    # them so that joining cannot fail inside this handler
                    msg='.'.join(str(arg) for arg in e.args),
                )
            )

        # TODO: Update database document
        document = upsert_fields_in_root_object(
            collection=collection,
            worker_id=worker_id,
            root='',
            task_id_tes=task_id_tes,
            tes_uri=tes_uri,
        )

        # TODO: Initiate polling
        interval = get_conf(
            config,
            'api',
            'endpoint_params',
            'interval_polling',
        )
        max_missed_heartbeats = get_conf(
            config,
            'api',
            'endpoint_params',
            'max_missed_heartbeats',
        )

        # Only poll if the task was actually accepted by a TES instance
        if tes_uri is not None and task_id_tes is not None:
            _poll_task(
                collection=collection,
                task_id=task_id,
                worker_id=worker_id,
                tes_uri=tes_uri,
                tes_task_id=task_id_tes,
                initial_state=document['task']['state'],
                token=token,
                interval=interval,
                max_missed_heartbeats=max_missed_heartbeats,
                timeout=timeout_service_calls,
            )

        # TODO (LATER): Logging

    except SoftTimeLimitExceeded as e:
        set_task_state(
            collection=collection,
            task_id=task_id,
            worker_id=worker_id,
            state='SYSTEM_ERROR',
        )
        logger.warning(
            (
                "Processing/submission of '{task_id}' timed out. Task state "
                "was set to 'SYSTEM_ERROR'. Original error message: "
                "{type}: {msg}"
            ).format(
                task_id=task_id,
                type=type(e).__name__,
                msg=e,
            )
        )
Пример #15
0
from pro_tes.config.config_parser import get_conf
from pro_tes.config.app_config import parse_app_config

# Load the app configuration; it supplies the defaults used below
flask_config = parse_app_config(config_var='TES_CONFIG')

# Number of Gunicorn workers and threads (overridable via environment)
workers = int(os.environ.get('GUNICORN_PROCESSES', '3'))
threads = int(os.environ.get('GUNICORN_THREADS', '1'))

# Gunicorn 'forwarded_allow_ips' setting
forwarded_allow_ips = '*'

# Gunicorn bind address, taken from the 'server' section of the app config
bind = '{}:{}'.format(
    get_conf(flask_config, 'server', 'host'),
    get_conf(flask_config, 'server', 'port'),
)

# Environment variables handed to the Gunicorn workers; values already set
# in the environment win over the configured defaults
raw_env = [
    'TES_CONFIG={}'.format(os.environ.get('TES_CONFIG', '')),
    'RABBIT_HOST={}'.format(os.environ.get(
        'RABBIT_HOST', get_conf(flask_config, 'celery', 'broker_host'))),
    'RABBIT_PORT={}'.format(os.environ.get(
        'RABBIT_PORT', get_conf(flask_config, 'celery', 'broker_port'))),
    'MONGO_HOST={}'.format(os.environ.get(
        'MONGO_HOST', get_conf(flask_config, 'database', 'host'))),
    'MONGO_PORT={}'.format(os.environ.get(
        'MONGO_PORT', get_conf(flask_config, 'database', 'port'))),
    'MONGO_DBNAME={}'.format(os.environ.get(
        'MONGO_DBNAME', get_conf(flask_config, 'database', 'name'))),
    'MONGO_USERNAME={}'.format(os.environ.get('MONGO_USERNAME', '')),
    'MONGO_PASSWORD={}'.format(os.environ.get('MONGO_PASSWORD', '')),
]
Пример #16
0
    register_task_service(connexion_app.app)

    # Register OpenAPI specs
    #connexion_app = register_openapi(
    #    app=connexion_app,
    #    specs=get_conf_type(config, 'api', 'specs', types=(list)),
    #    spec_dir=get_conf(config, 'storage', 'spec_dir'),
    #    add_security_definitions=True,
    #)
    connexion_app = register_openapi(app=connexion_app,
                                     specs=get_conf_type(config,
                                                         'api',
                                                         'specs',
                                                         types=(list)),
                                     spec_dir=get_conf(config, 'storage',
                                                       'spec_dir'),
                                     add_security_definitions=get_conf(
                                         config, 'security',
                                         'authorization_required'))

    # Enable cross-origin resource sharing
    enable_cors(connexion_app.app)

    return connexion_app, config


if __name__ == '__main__':
    # Build the configured app, then start serving; the reloader setting
    # comes from the 'server' section of the app config
    connexion_app, config = run_server()
    connexion_app.run(
        use_reloader=get_conf(config, 'server', 'use_reloader'),
    )
Пример #17
0
def run_workflow(
    config: Dict,
    body: Dict,
    sender: str,
    *args,
    **kwargs
) -> Dict:
    """Relay a task to the best TES instance; return a unique task ID.

    Args:
        config: App configuration.
        body: Task request to be relayed.
        sender: Identifier of the requester (currently unused here).

    Returns:
        Dictionary with the generated task ID under 'id'.

    Raises:
        Forbidden: If authorization is required and no valid access token
            could be obtained.
        InternalServerError: If the task could not be submitted.
    """
    # Get config parameters
    authorization_required = get_conf(
        config,
        'security',
        'authorization_required'
    )
    # NOTE(review): the value is requested as a list but is indexed with
    # string keys below -- confirm the expected type
    endpoint_params = get_conf_type(
        config,
        'tes',
        'endpoint_params',
        types=(list),
    )
    security_params = get_conf_type(
        config,
        'security',
        'jwt',
    )
    remote_urls = get_conf_type(
        config,
        'tes',
        'service-list',
        types=(list),
    )

    # Get associated workflow run
    # TODO: get run_id, task_id and user_id
    # FIXME: `document` is read below but never assigned in this function;
    #        it has to be provided by the step above once implemented

    # Set initial task state
    # TODO:

    # Set access token
    if authorization_required:
        try:
            access_token = request_access_token(
                user_id=document['user_id'],
                token_endpoint=endpoint_params['token_endpoint'],
                timeout=endpoint_params['timeout_token_request'],
            )
            validate_token(
                token=access_token,
                key=security_params['public_key'],
                identity_claim=security_params['identity_claim'],
            )
        except Exception as e:
            logger.exception(
                (
                    'Could not get access token from token endpoint '
                    "'{token_endpoint}'. Original error message {type}: {msg}"
                ).format(
                    token_endpoint=endpoint_params['token_endpoint'],
                    type=type(e).__name__,
                    msg=e,
                )
            )
            raise Forbidden
    else:
        access_token = None

    # Order TES instances by priority
    testribute = TEStribute_Interface()
    remote_urls_ordered = testribute.order_endpoint_list(
        tes_json=body,
        endpoints=remote_urls,
        access_token=access_token,
        method=endpoint_params['tes_distribution_method'],
    )

    # Send task to best TES instance
    # NOTE(review): confirm that `__send_task` accepts `access_token`
    try:
        remote_id, remote_url = __send_task(
            urls=remote_urls_ordered,
            body=body,
            access_token=access_token,
            timeout=endpoint_params['timeout_tes_submission'],
        )
    except Exception as e:
        logger.exception('{type}: {msg}'.format(
            type=type(e).__name__,
            msg=e,
        ))
        raise InternalServerError

    # Poll TES instance for state updates
    __initiate_state_polling(
        task_id=remote_id,
        run_id=document['run_id'],
        url=remote_url,
        interval_polling=endpoint_params['interval_polling'],
        timeout_polling=endpoint_params['timeout_polling'],
        max_time_polling=endpoint_params['max_time_polling'],
    )

    # Generate universally unique ID
    local_id = __amend_task_id(
        remote_id=remote_id,
        remote_url=remote_url,
        separator=endpoint_params['id_separator'],
        encoding=endpoint_params['id_encoding'],
    )

    # Format and return response
    response = {'id': local_id}
    return response


def request_access_token(
    user_id: str,
    token_endpoint: str,
    timeout: int = 5
) -> str:
    """Request an access token for a user from the token endpoint.

    Args:
        user_id: User identifier, posted as form field 'user_id'.
        token_endpoint: URL the request is posted to.
        timeout: Timeout passed on to the request.

    Returns:
        Value of 'access_token' in the JSON response.

    Raises:
        ConnectionError: If the response status code is not 200.
    """
    # Errors raised by the request itself propagate to the caller unchanged
    response = post(
        token_endpoint,
        data={'user_id': user_id},
        timeout=timeout
    )

    if response.status_code == 200:
        return response.json()['access_token']

    raise ConnectionError(
        (
            "Could not access token endpoint '{endpoint}'. Received "
            "status code '{code}'."
        ).format(
            endpoint=token_endpoint,
            code=response.status_code
        )
    )


def validate_token(
    token: str,
    key: str,
    identity_claim: str,
) -> None:
    """Decode a JWT and check that it carries the identity claim.

    Args:
        token: Encoded JWT.
        key: Key used to verify the token signature.
        identity_claim: Name of the claim that has to be present.

    Raises:
        ValueError: If the token could not be decoded.
    """
    # Decode token with the key passed by the caller (it used to be ignored
    # in favor of a second config lookup); the algorithm is still read from
    # the app config, as it is not part of the signature
    try:
        token_data = decode(
            jwt=token,
            key=key,
            algorithms=get_conf(
                current_app.config,
                'security',
                'jwt',
                'algorithm'
            ),
            verify=True,
        )
    except Exception as e:
        raise ValueError(
            (
                'Authentication token could not be decoded. Original '
                'error message: {type}: {msg}'
            ).format(
                type=type(e).__name__,
                msg=e,
            )
        ) from e

    # Validate claims, requiring the identity claim passed by the caller
    validate_claims(
        token_data=token_data,
        required_claims=[identity_claim],
    )


def __send_task(
    urls: List[str],
    body: Dict,
    timeout: int = 5
) -> Tuple[str, str]:
    """Submit the task to the first TES instance that accepts it.

    Args:
        urls: TES instance URLs, tried in the given order.
        body: Task request.
        timeout: Timeout passed to the TES client.

    Returns:
        Tuple of the task ID returned by the TES instance and its URL.

    Raises:
        ConnectionError: If the task could not be submitted to any instance.
    """
    task = tes.Task(body)       # TODO: implement this properly

    for url in urls:
        try:
            # TODO: fix problem with marshaling
            remote_id = tes.HTTPClient(url, timeout=timeout).create_task(task)
        except Exception as error:
            # Issue warning and move on to the next TES instance
            logger.warning(
                (
                    "Task could not be submitted to TES instance '{url}'. "
                    'Trying next TES instance in list. Original error '
                    "message: {type}: {msg}"
                ).format(
                    url=url,
                    type=type(error).__name__,
                    msg=error,
                )
            )
        else:
            # Submission succeeded: return task ID and URL of TES instance
            return (remote_id, url)

    # No TES instance accepted the task
    raise ConnectionError(
        'Task could not be submitted to any known TES instance.'
    )


def __initiate_state_polling(
    task_id: str,
    run_id: str,
    url: str,
    interval_polling: int = 2,
    timeout_polling: int = 1,
    max_time_polling: Optional[int] = None
) -> None:
    """Start a background task that polls a TES instance for the task state.

    Args:
        task_id: ID of the task on the TES instance.
        run_id: Run identifier passed on to the polling task.
        url: URL of the TES instance.
        interval_polling: Passed to the polling task as 'interval'.
        timeout_polling: Passed to the polling task as 'timeout'.
        max_time_polling: Soft time limit of the background task.
    """
    # The background task runs under a freshly generated ID
    celery_id = uuid()
    logger.debug(
        (
            "Starting polling of TES task '{task_id}' in "
            "background task '{celery_id}'..."
        ).format(
            task_id=task_id,
            celery_id=celery_id,
        )
    )

    polling_kwargs = {
        'task_id': task_id,
        'run_id': run_id,
        'url': url,
        'interval': interval_polling,
        'timeout': timeout_polling,
    }
    task__poll_task_state.apply_async(
        None,
        polling_kwargs,
        task_id=celery_id,
        soft_time_limit=max_time_polling,
    )


def __amend_task_id(
    remote_id: str,
    remote_url: str,
    separator: str = '@',   # TODO: add to config
    encoding: str = 'utf-8'  # TODO: add to config
) -> str:
    """Append the base64-encoded remote URL to the remote task ID.

    Args:
        remote_id: Task ID on the remote instance.
        remote_url: URL of the remote instance.
        separator: String placed between the ID and the encoded URL.
        encoding: Encoding used to convert the URL to bytes.

    Returns:
        String of the form '<remote_id><separator><base64 of remote_url>'.
    """
    # `b64encode` returns bytes, which cannot be joined with `str`; base64
    # output is pure ASCII, so decoding it is always safe
    encoded_url = base64.b64encode(remote_url.encode(encoding)).decode('ascii')
    return separator.join([remote_id, encoded_url])
Пример #18
0
def get_task(config: Dict, id: str, *args, **kwargs) -> Dict:
    """Get the task object for a specific task.

    Args:
        config: App configuration providing the tasks collection.
        id: Task identifier (matched against 'task.id').
        **kwargs: Optional 'user_id' (restricts the lookup to tasks of that
            user) and 'view' ('MINIMAL', 'BASIC' or 'FULL'; defaults to
            'BASIC').

    Returns:
        The task object, reduced according to the requested view.

    Raises:
        BadRequest: If 'view' is not one of the supported values.
        TaskNotFound: If no matching task exists. As the user ID is part of
            the database filter, this includes tasks of other users.
        Forbidden: If the stored user ID differs from the one passed.
    """
    # Get collection
    collection_tasks = get_conf(config, 'database', 'collections', 'tasks')

    # Set filters
    filter_dict = {'task.id': id}
    if 'user_id' in kwargs:
        filter_dict['user_id'] = kwargs['user_id']

    # Set projections
    # 'user_id' has to be included explicitly in the inclusion-style
    # projections, as the ownership check below reads it from the document;
    # the exclusion-style BASIC projection keeps it implicitly. Only the
    # 'task' part of the document is returned, so it is not exposed.
    projection_MINIMAL = {
        '_id': False,
        'task.id': True,
        'task.state': True,
        'user_id': True,
    }
    projection_BASIC = {
        '_id': False,
        'task.inputs.content': False,
        'task.logs.system_logs': False,
        'task.logs.logs.stdout': False,
        'task.logs.logs.stderr': False,
    }
    projection_FULL = {
        '_id': False,
        'task': True,
        'user_id': True,
    }

    # Check view mode
    view = kwargs.get('view', "BASIC")
    if view == "MINIMAL":
        projection = projection_MINIMAL
    elif view == "BASIC":
        projection = projection_BASIC
    elif view == "FULL":
        projection = projection_FULL
    else:
        raise BadRequest

    # Get task
    document = collection_tasks.find_one(
        filter=filter_dict,
        projection=projection,
    )

    # Raise error if task was not found
    if not document:
        logger.error("Task '{id}' not found.".format(id=id))
        raise TaskNotFound
    task = document['task']

    # Raise error trying to access a task that is not owned by user
    # Only if authorization enabled
    if 'user_id' in kwargs and document['user_id'] != kwargs['user_id']:
        logger.error(
            "User '{user_id}' is not allowed to access task '{id}'.".format(
                user_id=kwargs['user_id'],
                id=id,
            ))
        raise Forbidden

    return task