Example #1
        async def _in_tx_mutation_logic(cur):
            run_db_response = await self._async_table.get_run(flow_name,
                                                              run_number,
                                                              cur=cur)
            if run_db_response.response_code != 200:
                # if something went wrong with get_run, just return the error from that directly
                # e.g. 404, or some other error. This is useful for the client (vs additional wrapping, etc).
                return run_db_response
            run = run_db_response.body
            existing_tag_set = set(run["tags"])
            existing_system_tag_set = set(run["system_tags"])

            if tags_to_remove_set & existing_system_tag_set:
                # We use 422 here to communicate that the request was well-formatted in terms of structure and
                # that the server understood what was being requested. However, it failed business rules.
                return DBResponse(
                    response_code=422,
                    body="Cannot remove tags that are existing system tags %s"
                    % str(tags_to_remove_set & existing_system_tag_set))

            # Apply removals before additions.
            # And, make sure no existing system tags get added as a user tag
            next_run_tag_set = (existing_tag_set - tags_to_remove_set) | (
                tags_to_add_set - existing_system_tag_set)
            if next_run_tag_set == existing_tag_set:
                return DBResponse(response_code=200,
                                  body={"tags": list(next_run_tag_set)})
            next_run_tags = list(next_run_tag_set)

            update_db_response = await self._async_table.update_run_tags(
                flow_name, run_number, next_run_tags, cur=cur)
            if update_db_response.response_code != 200:
                return update_db_response
            return DBResponse(response_code=200, body={"tags": next_run_tags})
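
The tag arithmetic above is easiest to see with concrete values. A minimal standalone sketch (illustrative tag names only): removals are applied to the existing user tags first, and existing system tags are filtered out of the additions so they are never duplicated as user tags.

existing_tag_set = {"alpha", "beta"}
existing_system_tag_set = {"user:alice"}
tags_to_add_set = {"gamma", "user:alice"}
tags_to_remove_set = {"beta"}

# Same expression as in the transaction logic above.
next_run_tag_set = (existing_tag_set - tags_to_remove_set) | (
    tags_to_add_set - existing_system_tag_set)
assert next_run_tag_set == {"alpha", "gamma"}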
Example #2
    async def get_tags(self,
                       conditions: List[str] = None,
                       values=[],
                       limit: int = 0,
                       offset: int = 0):
        sql_template = """
        SELECT DISTINCT tag
        FROM (
            SELECT JSONB_ARRAY_ELEMENTS_TEXT(tags||system_tags) AS tag
            FROM {table_name}
        ) AS t
        {conditions}
        {limit}
        {offset}
        """
        select_sql = sql_template.format(
            table_name=self.table_name,
            conditions="WHERE {}".format(" AND ".join(conditions))
            if conditions else "",
            limit="LIMIT {}".format(limit) if limit else "",
            offset="OFFSET {}".format(offset) if offset else "",
        )

        res, pagination = await self.execute_sql(select_sql=select_sql,
                                                 values=values,
                                                 serialize=False)

        # process the unserialized DBResponse
        _body = [row[0] for row in res.body]

        return DBResponse(res.response_code, _body), pagination
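
A hedged usage sketch, assuming a table object (here called run_table) that exposes the get_tags coroutine above. Note that the conditions are applied to the outer query, so only the derived tag column is available for filtering; the %s placeholders are bound positionally from values by psycopg2.

async def list_matching_tags(run_table, prefix: str):
    # Hypothetical helper: fetch up to 50 distinct tags matching a prefix.
    db_response, _pagination = await run_table.get_tags(
        conditions=["tag LIKE %s"],
        values=[prefix + "%"],
        limit=50)
    if db_response.response_code != 200:
        return []
    return db_response.body  # a flat list of tag strings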
Example #3
async def apply_run_tags_to_db_response(flow_id, run_number,
                                        run_table_postgres,
                                        db_response: DBResponse) -> DBResponse:
    """
    We want read APIs to return steps, tasks and artifact objects with tags
    and system_tags set to their ancestral Run.

    This is a prerequisite for supporting Run-based tag mutation.
    """
    # we will return a modified copy of db_response
    new_db_response = copy.deepcopy(db_response)
    # Only replace tags if response code is legit
    # Object creation ought to return 201 (let's prepare for that)
    if new_db_response.response_code not in (200, 201):
        return new_db_response
    if isinstance(new_db_response.body, list):
        items_to_modify = new_db_response.body
    else:
        items_to_modify = [new_db_response.body]
    if not items_to_modify:
        return new_db_response
    # items_to_modify now references all the items we want to modify

    # The ancestral run must be successfully read from DB
    db_response_for_run = await run_table_postgres.get_run(flow_id, run_number)
    if db_response_for_run.response_code != 200:
        return DBResponse(response_code=500, body=db_response_for_run.body)
    run = db_response_for_run.body
    for item_as_dict in items_to_modify:
        item_as_dict['tags'] = run['tags']
        item_as_dict['system_tags'] = run['system_tags']
    return new_db_response
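
A hedged sketch of how a read path for child objects might apply this helper, assuming hypothetical task_table and run_table objects whose get_task/get_run methods return DBResponse instances like the ones above.

async def get_task_with_run_tags(task_table, run_table,
                                 flow_id, run_number, step_name, task_id):
    # Hypothetical read path: fetch a task row, then overwrite its tags and
    # system_tags with those of the ancestral run, so run-level tag mutations
    # are reflected on the task.
    task_response = await task_table.get_task(flow_id, run_number,
                                              step_name, task_id)
    return await apply_run_tags_to_db_response(flow_id, run_number,
                                               run_table, task_response)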
Example #4
    async def get_step_names(
            self,
            conditions: List[str] = [],
            values: List[str] = [],
            limit: int = 0,
            offset: int = 0) -> Tuple[DBResponse, DBPagination]:
        """
        Get a paginated set of step names.

        Parameters
        ----------
        conditions : List[str]
            list of conditions to pass to the SQL execute, with %s placeholders for values
        values : List[str]
            list of values to be passed to the SQL execute.
        limit : int (optional) (default 0)
            limit for the number of results
        offset : int (optional) (default 0)
            offset for the results.

        Returns
        -------
        (DBResponse, DBPagination)
        """
        sql_template = """
            SELECT step_name FROM (
                SELECT DISTINCT step_name, flow_id, run_number, run_id
                FROM {table_name}
            ) T
            {conditions}
            {limit}
            {offset}
            """
        select_sql = sql_template.format(
            table_name=self.table_name,
            keys=",".join(self.select_columns),
            conditions=("WHERE {}".format(" AND ".join(conditions))
                        if conditions else ""),
            limit="LIMIT {}".format(limit) if limit else "",
            offset="OFFSET {}".format(offset) if offset else "")

        res, pag = await self.execute_sql(select_sql=select_sql,
                                          values=values,
                                          fetch_single=False,
                                          expanded=False,
                                          limit=limit,
                                          offset=offset,
                                          serialize=False)
        # process the unserialized DBResponse
        _body = [row[0] for row in res.body]

        return DBResponse(res.response_code, _body), pag
Example #5
    async def get_task_log(self, request, logtype=STDOUT):
        "fetches log and emits it as a list of rows wrapped in json"
        task = await self.get_task_by_request(request)
        if not task:
            return web_response(404, {'data': []})
        limit, page, reverse_order = get_pagination_params(request)

        lines, page_count = await read_and_output(self.cache, task, logtype,
                                                  limit, page, reverse_order)

        # paginated response
        response = DBResponse(200, lines)
        pagination = DBPagination(limit, limit * (page - 1),
                                  len(response.body), page)
        status, body = format_response_list(request, response, pagination,
                                            page, page_count)
        return web_response(status, body)
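
The offset recorded in DBPagination above is derived from the requested page rather than from SQL, since read_and_output has already sliced the log lines; a quick sketch of that arithmetic with illustrative numbers:

# Illustrative only: page 3 of a log read with a limit of 100 lines per page.
limit, page = 100, 3
offset = limit * (page - 1)
assert offset == 200  # lines 0-199 were served on pages 1 and 2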
Example #6
    async def execute_sql(
            self,
            select_sql: str,
            values=[],
            fetch_single=False,
            expanded=False,
            limit: int = 0,
            offset: int = 0,
            serialize: bool = True) -> Tuple[DBResponse, DBPagination]:
        try:
            with (await
                  self.db.pool.cursor(cursor_factory=psycopg2.extras.DictCursor
                                      )) as cur:
                await cur.execute(select_sql, values)

                rows = []
                records = await cur.fetchall()
                if serialize:
                    for record in records:
                        # pylint-initial-ignore: Lack of __init__ makes this too hard for pylint
                        # pylint: disable=not-callable
                        row = self._row_type(**record)
                        rows.append(row.serialize(expanded))
                else:
                    rows = records

                count = len(rows)

                # Will raise IndexError if fetch_single=True and there are no results
                body = rows[0] if fetch_single else rows

                pagination = DBPagination(
                    limit=limit,
                    offset=offset,
                    count=count,
                    page=math.floor(int(offset) / max(int(limit), 1)) + 1,
                )

                cur.close()
                return DBResponse(response_code=200, body=body), pagination
        except IndexError as error:
            return aiopg_exception_handling(error), None
        except (Exception, psycopg2.DatabaseError) as error:
            self.db.logger.exception("Exception occurred")
            return aiopg_exception_handling(error), None
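
The page number placed in DBPagination is reconstructed from offset and limit; max(int(limit), 1) guards against division by zero when no limit was requested. A small sketch of that calculation with illustrative values:

import math

def page_for(offset: int, limit: int) -> int:
    # Same arithmetic as in execute_sql above; limit=0 (no limit) maps to page 1.
    return math.floor(int(offset) / max(int(limit), 1)) + 1

assert page_for(0, 10) == 1
assert page_for(20, 10) == 3
assert page_for(0, 0) == 1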
Example #7
    async def get_run_parameters(self, request):
        """
         ---
          description: Get parameters of a run
          tags:
          - Run
          parameters:
            - $ref: '#/definitions/Params/Path/flow_id'
            - $ref: '#/definitions/Params/Path/run_number'
            - $ref: '#/definitions/Params/Custom/invalidate'
          produces:
          - application/json
          responses:
              "200":
                  description: Returns parameters of a run
                  schema:
                    $ref: '#/definitions/ResponsesRunParameters'
              "405":
                  description: invalid HTTP Method
                  schema:
                    $ref: '#/definitions/ResponsesError405'
              "500":
                  description: Internal Server Error (with error id)
                  schema:
                    $ref: '#/definitions/ResponsesRunParametersError500'
        """
        flow_name = request.match_info['flow_id']
        run_number = request.match_info.get("run_number")

        invalidate_cache = query_param_enabled(request, "invalidate")

        # _artifact_store.get_run_parameters will translate run_number/run_id properly
        combined_results = await self._artifact_store.get_run_parameters(
            flow_name, run_number, invalidate_cache=invalidate_cache)

        postprocess_error = combined_results.get("postprocess_error", None)
        if postprocess_error:
            raise GetParametersFailed(
                postprocess_error["detail"], postprocess_error["id"], postprocess_error["traceback"])
        else:
            response = DBResponse(200, combined_results)

        status, body = format_response(request, response)

        return web_response(status, body)
Example #8
def test_format_response():
    request = make_mocked_request('GET',
                                  '/runs?_limit=10',
                                  headers={'Host': 'test'})

    db_response = DBResponse(response_code=200, body={"foo": "bar"})

    expected_response = {
        "data": {
            "foo": "bar"
        },
        "status": 200,
        "links": {
            "self": "http://test/runs?_limit=10"
        },
        "query": {
            "_limit": "10"
        },
    }

    status, response = format_response(request, db_response)
    assert json.dumps(response) == json.dumps(expected_response)
    assert status == 200
Example #9
def test_format_response_list_next_page_null():
    request = make_mocked_request('GET',
                                  '/runs?_limit=10&_page=2',
                                  headers={'Host': 'test'})

    db_response = DBResponse(response_code=200, body=[{"foo": "bar"}])
    pagination = DBPagination(limit=10, offset=0, count=1, page=2)

    expected_response = {
        "data": [{
            "foo": "bar"
        }],
        "status": 200,
        "links": {
            "self": "http://test/runs?_limit=10&_page=2",
            "first": "http://test/runs?_limit=10&_page=1",
            "prev": "http://test/runs?_limit=10&_page=1",
            "next": None,
            "last": None
        },
        "pages": {
            "self": 2,
            "first": 1,
            "prev": 1,
            "next": None,
            "last": None
        },
        "query": {
            "_limit": "10",
            "_page": "2"
        },
    }

    status, response = format_response_list(request, db_response, pagination,
                                            2)
    assert json.dumps(response) == json.dumps(expected_response)
    assert status == 200
Example #10
    async def postprocess(self, response: DBResponse, invalidate_cache=False):
        """
        Calls the refiner postprocessing to fetch Metaflow artifacts.

        Parameters
        ----------
        response : DBResponse
            The DBResponse to be refined

        Returns
        -------
        A refined DBResponse, or in case of errors, the original DBResponse
        """
        if FEATURE_REFINE_DISABLE:
            return response

        if response.response_code != 200 or not response.body:
            return response

        input = self._response_to_action_input(response)

        errors = {}

        def _event_stream(event):
            if event.get("type") == "error" and event.get("key"):
                # Get last element from cache key which usually translates to "target"
                target = event["key"].split(':')[-1:][0]
                errors[target] = event

        data = await self.fetch_data(input,
                                     event_stream=_event_stream,
                                     invalidate_cache=invalidate_cache)

        async def _process(record):
            target = self._record_to_action_input(record)

            if target in errors:
                # Add streamed postprocess errors if any
                record["postprocess_error"] = format_error_body(
                    errors[target].get("id"), errors[target].get("message"),
                    errors[target].get("traceback"))

            if target in data:
                success, value, detail, trace = unpack_processed_value(
                    data[target])
                if success:
                    record = await self.refine_record(record, value)
                else:
                    record['postprocess_error'] = format_error_body(
                        value if value else "artifact-handle-failed", detail
                        if detail else "Unknown error during postprocessing",
                        trace)
            else:
                record['postprocess_error'] = format_error_body(
                    "artifact-value-not-found", "Artifact value not found")

            return record

        if isinstance(response.body, list):
            body = [await _process(task) for task in response.body]
        else:
            body = await _process(response.body)

        return DBResponse(response_code=response.response_code, body=body)
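
A hedged sketch of how a caller might consume the refined response, relying only on the postprocess_error field set above; the refiner name and calling convention here are assumptions for illustration.

async def refine_and_collect_errors(refiner, response: DBResponse):
    # Hypothetical caller: refine a DBResponse and gather per-record errors.
    refined = await refiner.postprocess(response)
    records = refined.body if isinstance(refined.body, list) else [refined.body]
    failed = [r for r in records
              if isinstance(r, dict) and "postprocess_error" in r]
    return refined, failed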
Example #11
    async def get_run_dag(self, request):
        """
        ---
        description: Get DAG structure for a run.
        tags:
        - Run
        parameters:
          - $ref: '#/definitions/Params/Path/flow_id'
          - $ref: '#/definitions/Params/Path/run_number'
          - $ref: '#/definitions/Params/Custom/invalidate'
        produces:
        - application/json
        responses:
            "200":
                description: Return DAG structure for a specific run
                schema:
                  $ref: '#/definitions/ResponsesDag'
            "405":
                description: invalid HTTP Method
                schema:
                  $ref: '#/definitions/ResponsesError405'
            "404":
                description: necessary data for DAG generation Not Found
                schema:
                  $ref: '#/definitions/ResponsesError404'
            "500":
                description: Internal Server Error (with error id)
                schema:
                    $ref: '#/definitions/ResponsesDagError500'
        """
        flow_name = request.match_info['flow_id']
        run_number = request.match_info.get("run_number")
        # Before running the cache action, we make sure that the run has
        # the necessary data to generate a DAG.
        db_response = await get_run_dag_data(self.db, flow_name, run_number)

        if not db_response.response_code == 200:
            # DAG data was not found, return with the corresponding status.
            status, body = format_response(request, db_response)
            return web_response(status, body)

        # Prefer run_id over run_number
        flow_name = db_response.body['flow_id']
        run_id = db_response.body.get(
            'run_id') or db_response.body['run_number']
        invalidate_cache = query_param_enabled(request, "invalidate")

        dag = await self._dag_store.cache.GenerateDag(
            flow_name, run_id, invalidate_cache=invalidate_cache)

        if dag.has_pending_request():
            async for event in dag.stream():
                if event["type"] == "error":
                    # raise error, there was an exception during processing.
                    raise GenerateDAGFailed(event["message"], event["id"],
                                            event["traceback"])
            await dag.wait()  # wait until results are ready
        dag = dag.get()
        response = DBResponse(200, dag)
        status, body = format_response(request, response)

        return web_response(status, body)
Example #12
    async def mutate_user_tags(self, request):
        """
        ---
        description: mutate user tags
        tags:
        - Run
        parameters:
        - name: "flow_id"
          in: "path"
          description: "flow_id"
          required: true
          type: "string"
        - name: "run_number"
          in: "path"
          description: "run_number"
          required: true
          type: "string"
        - name: "body"
          in: "body"
          description: "body"
          required: true
          schema:
            type: object
            properties:
                tags_to_add:
                    type: array of string
                tags_to_remove:
                    type: array of string
        produces:
        - 'text/plain'
        responses:
            "200":
                description: successful operation. Tags updated.  Returns latest user tags
            "400":
                description: invalid HTTP Request
            "405":
                description: invalid HTTP Method
            "409":
                description: mutation request conflicts with an existing in-flight mutation. Retry recommended
            "422":
                description: illegal tag mutation. No update performed.  E.g. could be because we tried to remove
                             a system tag.
        """
        flow_name = request.match_info.get("flow_id")
        run_number = request.match_info.get("run_number")
        body = await read_body(request.content)
        tags_to_add = body.get("tags_to_add", [])
        tags_to_remove = body.get("tags_to_remove", [])

        # We return 400 when request structure is wrong
        if not isinstance(tags_to_add, list):
            return DBResponse(response_code=400,
                              body="tags_to_add must be a list")

        if not isinstance(tags_to_remove, list):
            return DBResponse(response_code=400,
                              body="tags_to_remove must be a list")

        # let's make sure we have a list of strings
        if not all(
                isinstance(t, str)
                for t in chain(tags_to_add, tags_to_remove)):
            return DBResponse(response_code=400,
                              body="All tag values must be strings")

        tags_to_add_set = set(tags_to_add)
        tags_to_remove_set = set(tags_to_remove)

        async def _in_tx_mutation_logic(cur):
            run_db_response = await self._async_table.get_run(flow_name,
                                                              run_number,
                                                              cur=cur)
            if run_db_response.response_code != 200:
                # if something went wrong with get_run, just return the error from that directly
                # e.g. 404, or some other error. This is useful for the client (vs additional wrapping, etc).
                return run_db_response
            run = run_db_response.body
            existing_tag_set = set(run["tags"])
            existing_system_tag_set = set(run["system_tags"])

            if tags_to_remove_set & existing_system_tag_set:
                # We use 422 here to communicate that the request was well-formatted in terms of structure and
                # that the server understood what was being requested. However, it failed business rules.
                return DBResponse(
                    response_code=422,
                    body="Cannot remove tags that are existing system tags %s"
                    % str(tags_to_remove_set & existing_system_tag_set))

            # Apply removals before additions.
            # And, make sure no existing system tags get added as a user tag
            next_run_tag_set = (existing_tag_set - tags_to_remove_set) | (
                tags_to_add_set - existing_system_tag_set)
            if next_run_tag_set == existing_tag_set:
                return DBResponse(response_code=200,
                                  body={"tags": list(next_run_tag_set)})
            next_run_tags = list(next_run_tag_set)

            update_db_response = await self._async_table.update_run_tags(
                flow_name, run_number, next_run_tags, cur=cur)
            if update_db_response.response_code != 200:
                return update_db_response
            return DBResponse(response_code=200, body={"tags": next_run_tags})

        return await self._async_table.run_in_transaction_with_serializable_isolation_level(
            _in_tx_mutation_logic)
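
For completeness, a hedged client-side sketch of exercising this handler with aiohttp; the route path and HTTP method are assumptions based on the flow_id/run_number path parameters declared in the docstring, not taken from the routing table.

import aiohttp

async def mutate_run_tags(base_url, flow_id, run_number):
    # Assumed route shape; adjust to however the handler above is registered.
    url = f"{base_url}/flows/{flow_id}/runs/{run_number}/tag/mutate"
    payload = {"tags_to_add": ["reviewed"], "tags_to_remove": ["wip"]}
    async with aiohttp.ClientSession() as session:
        async with session.patch(url, json=payload) as resp:
            # 200: tags updated, 409: concurrent mutation (retry),
            # 422: attempted to remove a system tag, 400: malformed body.
            return resp.status, await resp.text()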