async def _in_tx_mutation_logic(cur):
    """Apply the requested tag mutation to the run row, inside the caller's transaction.

    Relies on closure variables from the enclosing handler: flow_name,
    run_number, tags_to_add_set, tags_to_remove_set. Returns a DBResponse:
    any failed DB call's response unchanged, a 422 when a system tag would
    be removed, or a 200 carrying the resulting user-tag list.
    """
    get_response = await self._async_table.get_run(flow_name, run_number, cur=cur)
    if get_response.response_code != 200:
        # Propagate get_run errors (404 etc.) directly — most useful for the
        # client, versus additional wrapping.
        return get_response
    run = get_response.body

    current_tags = set(run["tags"])
    system_tags = set(run["system_tags"])

    blocked_removals = tags_to_remove_set & system_tags
    if blocked_removals:
        # 422: the request was structurally valid and understood, but it
        # violates business rules (system tags are immutable here).
        return DBResponse(
            response_code=422,
            body="Cannot remove tags that are existing system tags %s" % str(blocked_removals))

    # Removals apply before additions; an existing system tag is never
    # duplicated as a user tag.
    desired_tags = (current_tags - tags_to_remove_set) | (tags_to_add_set - system_tags)
    if desired_tags == current_tags:
        # No effective change — skip the write and report the tag set as-is.
        return DBResponse(response_code=200, body={"tags": list(desired_tags)})

    new_tag_list = list(desired_tags)
    update_response = await self._async_table.update_run_tags(
        flow_name, run_number, new_tag_list, cur=cur)
    if update_response.response_code != 200:
        return update_response
    return DBResponse(response_code=200, body={"tags": new_tag_list})
async def get_tags(self, conditions: List[str] = None, values=None,
                   limit: int = 0, offset: int = 0):
    """
    Get a paginated, distinct list of tags (user tags and system tags combined).

    Parameters
    ----------
    conditions : List[str]
        optional SQL condition fragments (with %s placeholders), joined with AND
    values : List
        values for the %s placeholders in `conditions`
    limit : int (optional) (default 0)
        LIMIT for the query; 0 means no limit
    offset : int (optional) (default 0)
        OFFSET for the query; 0 means no offset

    Returns
    -------
    (DBResponse, DBPagination)
        the DBResponse body is a flat list of tag strings
    """
    # None default instead of a mutable default argument (shared across calls).
    if values is None:
        values = []
    sql_template = """
    SELECT DISTINCT tag
    FROM (
        SELECT JSONB_ARRAY_ELEMENTS_TEXT(tags||system_tags) AS tag
        FROM {table_name}
    ) AS t
    {conditions}
    {limit}
    {offset}
    """
    select_sql = sql_template.format(
        table_name=self.table_name,
        conditions="WHERE {}".format(" AND ".join(conditions)) if conditions else "",
        limit="LIMIT {}".format(limit) if limit else "",
        offset="OFFSET {}".format(offset) if offset else "",
    )

    res, pagination = await self.execute_sql(select_sql=select_sql, values=values, serialize=False)

    # process the unserialized DBResponse: each row is a 1-tuple (tag,)
    _body = [row[0] for row in res.body]

    return DBResponse(res.response_code, _body), pagination
async def apply_run_tags_to_db_response(flow_id, run_number, run_table_postgres,
                                        db_response: DBResponse) -> DBResponse:
    """
    We want read APIs to return steps, tasks and artifact objects with tags
    and system_tags set to their ancestral Run. This is a prerequisite for
    supporting Run-based tag mutation.
    """
    # Work on (and return) a modified deep copy; the input response is untouched.
    response_copy = copy.deepcopy(db_response)

    # Only replace tags on successful responses.
    # Object creation ought to return 201 (let's prepare for that).
    if response_copy.response_code not in (200, 201):
        return response_copy

    # Normalize to a list of item dicts; a single-object body is wrapped so the
    # mutation loop below also updates it in place.
    if isinstance(response_copy.body, list):
        targets = response_copy.body
    else:
        targets = [response_copy.body]
    if not targets:
        return response_copy

    # The ancestral run must be successfully read from the DB, otherwise
    # surface the failure as a 500 carrying the lookup's body.
    run_lookup = await run_table_postgres.get_run(flow_id, run_number)
    if run_lookup.response_code != 200:
        return DBResponse(response_code=500, body=run_lookup.body)
    run = run_lookup.body

    for target in targets:
        target['tags'] = run['tags']
        target['system_tags'] = run['system_tags']
    return response_copy
async def get_step_names(
        self, conditions: List[str] = None,
        values: List[str] = None, limit: int = 0,
        offset: int = 0) -> Tuple[DBResponse, DBPagination]:
    """
    Get a paginated set of step names.

    Parameters
    ----------
    conditions : List[str]
        list of conditions to pass the sql execute, with %s placeholders for values
    values : List[str]
        list of values to be passed for the sql execute.
    limit : int (optional) (default 0)
        limit for the number of results
    offset : int (optional) (default 0)
        offset for the results.

    Returns
    -------
    (DBResponse, DBPagination)
    """
    # None defaults avoid the shared-mutable-default-argument pitfall.
    conditions = [] if conditions is None else conditions
    values = [] if values is None else values
    sql_template = """
    SELECT step_name FROM (
        SELECT DISTINCT step_name, flow_id, run_number, run_id
        FROM {table_name}
    ) T
    {conditions}
    {limit}
    {offset}
    """
    # NOTE: dropped the previously-passed `keys=` format argument; the
    # template contains no {keys} placeholder, so it was never used.
    select_sql = sql_template.format(
        table_name=self.table_name,
        conditions=("WHERE {}".format(" AND ".join(conditions)) if conditions else ""),
        limit="LIMIT {}".format(limit) if limit else "",
        offset="OFFSET {}".format(offset) if offset else "")

    res, pag = await self.execute_sql(select_sql=select_sql, values=values,
                                      fetch_single=False, expanded=False,
                                      limit=limit, offset=offset, serialize=False)

    # process the unserialized DBResponse: each row is a 1-tuple (step_name,)
    _body = [row[0] for row in res.body]

    return DBResponse(res.response_code, _body), pag
async def get_task_log(self, request, logtype=STDOUT):
    """Fetch a task's log and emit it as a JSON-wrapped, paginated list of rows."""
    task = await self.get_task_by_request(request)
    if not task:
        # Task could not be resolved from the request; empty 404 payload.
        return web_response(404, {'data': []})

    limit, page, reverse_order = get_pagination_params(request)
    lines, page_count = await read_and_output(
        self.cache, task, logtype, limit, page, reverse_order)

    # Wrap the log lines in the standard paginated response envelope.
    log_response = DBResponse(200, lines)
    log_pagination = DBPagination(limit, limit * (page - 1), len(log_response.body), page)
    status, body = format_response_list(
        request, log_response, log_pagination, page, page_count)

    return web_response(status, body)
async def execute_sql(
        self, select_sql: str, values=None, fetch_single=False,
        expanded=False, limit: int = 0, offset: int = 0,
        serialize: bool = True) -> Tuple[DBResponse, DBPagination]:
    """
    Execute a SELECT against the pool and return the rows plus pagination info.

    Parameters
    ----------
    select_sql : str
        the SQL to execute (with %s placeholders)
    values : List
        values for the %s placeholders
    fetch_single : bool
        return only the first row as the body (raises internally if no rows,
        surfaced via aiopg_exception_handling)
    expanded : bool
        passed through to row serialization
    limit : int
        limit used only for computing the pagination metadata
    offset : int
        offset used only for computing the pagination metadata
    serialize : bool
        when True, rows are mapped through self._row_type(...).serialize()

    Returns
    -------
    (DBResponse, DBPagination)
        on error the DBResponse comes from aiopg_exception_handling and the
        pagination is None
    """
    # None default instead of a mutable default argument (shared across calls).
    if values is None:
        values = []
    try:
        with (await self.db.pool.cursor(
                cursor_factory=psycopg2.extras.DictCursor)) as cur:
            await cur.execute(select_sql, values)
            records = await cur.fetchall()
            rows = []
            if serialize:
                for record in records:
                    # pylint-initial-ignore: Lack of __init__ makes this too hard for pylint
                    # pylint: disable=not-callable
                    row = self._row_type(**record)
                    rows.append(row.serialize(expanded))
            else:
                rows = records
            count = len(rows)

            # Will raise IndexError in case fetch_single=True and there's no results
            body = rows[0] if fetch_single else rows

            pagination = DBPagination(
                limit=limit,
                offset=offset,
                count=count,
                # max(..., 1) guards against division by zero when limit == 0
                page=math.floor(int(offset) / max(int(limit), 1)) + 1,
            )
            # NOTE: removed the redundant cur.close(); the `with` block closes
            # the cursor on exit.
            return DBResponse(response_code=200, body=body), pagination
    except IndexError as error:
        # fetch_single=True with an empty result set lands here.
        return aiopg_exception_handling(error), None
    except (Exception, psycopg2.DatabaseError) as error:
        self.db.logger.exception("Exception occurred")
        return aiopg_exception_handling(error), None
async def get_run_parameters(self, request):
    """
    ---
    description: Get parameters of a run
    tags:
      - Run
    parameters:
      - $ref: '#/definitions/Params/Path/flow_id'
      - $ref: '#/definitions/Params/Path/run_number'
      - $ref: '#/definitions/Params/Custom/invalidate'
    produces:
      - application/json
    responses:
        "200":
            description: Returns parameters of a run
            schema:
              $ref: '#/definitions/ResponsesRunParameters'
        "405":
            description: invalid HTTP Method
            schema:
              $ref: '#/definitions/ResponsesError405'
        "500":
            description: Internal Server Error (with error id)
            schema:
              $ref: '#/definitions/ResponsesRunParametersError500'
    """
    flow_name = request.match_info['flow_id']
    run_number = request.match_info.get("run_number")
    invalidate_cache = query_param_enabled(request, "invalidate")

    # _artifact_store.get_run_parameters will translate run_number/run_id properly
    combined_results = await self._artifact_store.get_run_parameters(
        flow_name, run_number, invalidate_cache=invalidate_cache)

    # A postprocess failure during artifact refinement is surfaced as an
    # exception (handled upstream as a 500 with error id).
    postprocess_error = combined_results.get("postprocess_error", None)
    if postprocess_error:
        raise GetParametersFailed(
            postprocess_error["detail"],
            postprocess_error["id"],
            postprocess_error["traceback"])

    status, body = format_response(request, DBResponse(200, combined_results))
    return web_response(status, body)
def test_format_response():
    """format_response wraps a single-object DBResponse in the standard envelope."""
    request = make_mocked_request('GET', '/runs?_limit=10', headers={'Host': 'test'})
    db_response = DBResponse(response_code=200, body={"foo": "bar"})

    # Key order matters below: the comparison is done on json.dumps output.
    expected = {
        "data": {"foo": "bar"},
        "status": 200,
        "links": {"self": "http://test/runs?_limit=10"},
        "query": {"_limit": "10"},
    }

    status, response = format_response(request, db_response)

    assert status == 200
    assert json.dumps(response) == json.dumps(expected)
def test_format_response_list_next_page_null():
    """Last page of a list response reports next/last links and pages as null."""
    request = make_mocked_request(
        'GET', '/runs?_limit=10&_page=2', headers={'Host': 'test'})
    db_response = DBResponse(response_code=200, body=[{"foo": "bar"}])
    # count(1) < limit(10) on page 2 -> this is the final page.
    pagination = DBPagination(limit=10, offset=0, count=1, page=2)

    # Key order matters below: the comparison is done on json.dumps output.
    expected = {
        "data": [{"foo": "bar"}],
        "status": 200,
        "links": {
            "self": "http://test/runs?_limit=10&_page=2",
            "first": "http://test/runs?_limit=10&_page=1",
            "prev": "http://test/runs?_limit=10&_page=1",
            "next": None,
            "last": None,
        },
        "pages": {
            "self": 2,
            "first": 1,
            "prev": 1,
            "next": None,
            "last": None,
        },
        "query": {
            "_limit": "10",
            "_page": "2",
        },
    }

    status, response = format_response_list(request, db_response, pagination, 2)

    assert status == 200
    assert json.dumps(response) == json.dumps(expected)
async def postprocess(self, response: DBResponse, invalidate_cache=False):
    """
    Calls the refiner postprocessing to fetch Metaflow artifacts.

    Parameters
    ----------
    response : DBResponse
        The DBResponse to be refined
    invalidate_cache : bool
        Force the cache action to refetch instead of serving cached values.

    Returns
    -------
    A refined DBResponse, or in case of errors, the original DBResponse
    """
    if FEATURE_REFINE_DISABLE:
        return response

    # Nothing to refine on errors or empty bodies.
    if response.response_code != 200 or not response.body:
        return response

    # Renamed from `input` to avoid shadowing the builtin.
    action_input = self._response_to_action_input(response)

    errors = {}

    def _event_stream(event):
        # Collect streamed errors keyed by their target so they can be
        # attached to the matching record below.
        if event.get("type") == "error" and event.get("key"):
            # Get last element from cache key which usually translates to "target"
            target = event["key"].split(':')[-1]
            errors[target] = event

    data = await self.fetch_data(action_input, event_stream=_event_stream,
                                 invalidate_cache=invalidate_cache)

    async def _process(record):
        target = self._record_to_action_input(record)

        if target in errors:
            # Add streamed postprocess errors if any
            record["postprocess_error"] = format_error_body(
                errors[target].get("id"),
                errors[target].get("message"),
                errors[target].get("traceback"))

        if target in data:
            success, value, detail, trace = unpack_processed_value(data[target])
            if success:
                record = await self.refine_record(record, value)
            else:
                record['postprocess_error'] = format_error_body(
                    value if value else "artifact-handle-failed",
                    detail if detail else "Unknown error during postprocessing",
                    trace)
        else:
            record['postprocess_error'] = format_error_body(
                "artifact-value-not-found",
                "Artifact value not found")
        return record

    if isinstance(response.body, list):
        body = [await _process(task) for task in response.body]
    else:
        body = await _process(response.body)

    return DBResponse(response_code=response.response_code, body=body)
async def get_run_dag(self, request):
    """
    ---
    description: Get DAG structure for a run.
    tags:
      - Run
    parameters:
      - $ref: '#/definitions/Params/Path/flow_id'
      - $ref: '#/definitions/Params/Path/run_number'
      - $ref: '#/definitions/Params/Custom/invalidate'
    produces:
      - application/json
    responses:
        "200":
            description: Return DAG structure for a specific run
            schema:
              $ref: '#/definitions/ResponsesDag'
        "405":
            description: invalid HTTP Method
            schema:
              $ref: '#/definitions/ResponsesError405'
        "404":
            description: necessary data for DAG generation Not Found
            schema:
              $ref: '#/definitions/ResponsesError404'
        "500":
            description: Internal Server Error (with error id)
            schema:
              $ref: '#/definitions/ResponsesDagError500'
    """
    flow_name = request.match_info['flow_id']
    run_number = request.match_info.get("run_number")

    # Before running the cache action, we make sure that the run has
    # the necessary data to generate a DAG.
    dag_data_response = await get_run_dag_data(self.db, flow_name, run_number)
    if dag_data_response.response_code != 200:
        # DAG data was not found, return with the corresponding status.
        status, body = format_response(request, dag_data_response)
        return web_response(status, body)

    # Prefer run_id over run_number
    flow_name = dag_data_response.body['flow_id']
    run_id = dag_data_response.body.get('run_id') or dag_data_response.body['run_number']

    invalidate_cache = query_param_enabled(request, "invalidate")

    dag = await self._dag_store.cache.GenerateDag(
        flow_name, run_id, invalidate_cache=invalidate_cache)
    if dag.has_pending_request():
        async for event in dag.stream():
            if event["type"] == "error":
                # raise error, there was an exception during processing.
                raise GenerateDAGFailed(
                    event["message"], event["id"], event["traceback"])
        # wait until results are ready
        await dag.wait()
    dag = dag.get()

    status, body = format_response(request, DBResponse(200, dag))
    return web_response(status, body)
async def mutate_user_tags(self, request):
    """
    ---
    description: mutate user tags
    tags:
      - Run
    parameters:
      - name: "flow_id"
        in: "path"
        description: "flow_id"
        required: true
        type: "string"
      - name: "run_number"
        in: "path"
        description: "run_number"
        required: true
        type: "string"
      - name: "body"
        in: "body"
        description: "body"
        required: true
        schema:
          type: object
          properties:
            tags_to_add:
              type: array of string
            tags_to_remove:
              type: array of string
    produces:
      - 'text/plain'
    responses:
        "200":
            description: successful operation. Tags updated. Returns latest user tags
        "400":
            description: invalid HTTP Request
        "405":
            description: invalid HTTP Method
        "409":
            description: mutation request conflicts with an existing in-flight mutation. Retry recommended
        "422":
            description: illegal tag mutation. No update performed. E.g. could be because we tried to remove a system tag.
    """
    flow_name = request.match_info.get("flow_id")
    run_number = request.match_info.get("run_number")

    payload = await read_body(request.content)
    tags_to_add = payload.get("tags_to_add", [])
    tags_to_remove = payload.get("tags_to_remove", [])

    # Malformed request structure -> 400. NOTE(review): these return a
    # DBResponse directly rather than a web_response — presumably converted
    # upstream; confirm against the route wiring.
    for field_name, field_value in (("tags_to_add", tags_to_add),
                                    ("tags_to_remove", tags_to_remove)):
        if not isinstance(field_value, list):
            return DBResponse(response_code=400,
                              body="%s must be a list" % field_name)
    # let's make sure we have a list of strings
    if not all(isinstance(t, str) for t in chain(tags_to_add, tags_to_remove)):
        return DBResponse(response_code=400,
                          body="All tag values must be strings")

    tags_to_add_set = set(tags_to_add)
    tags_to_remove_set = set(tags_to_remove)

    async def _in_tx_mutation_logic(cur):
        # Runs inside a serializable transaction (see the call below).
        run_db_response = await self._async_table.get_run(
            flow_name, run_number, cur=cur)
        if run_db_response.response_code != 200:
            # Pass get_run failures (e.g. 404) straight through to the client
            # instead of wrapping them further.
            return run_db_response
        run = run_db_response.body

        user_tags = set(run["tags"])
        system_tags = set(run["system_tags"])

        protected = tags_to_remove_set & system_tags
        if protected:
            # 422: well-formed and understood, but violates business rules.
            return DBResponse(
                response_code=422,
                body="Cannot remove tags that are existing system tags %s" % str(protected))

        # Apply removals before additions, and never re-add an existing
        # system tag as a user tag.
        next_tag_set = (user_tags - tags_to_remove_set) | (tags_to_add_set - system_tags)
        if next_tag_set == user_tags:
            # Nothing would change — skip the write.
            return DBResponse(response_code=200, body={"tags": list(next_tag_set)})

        next_tags = list(next_tag_set)
        update_db_response = await self._async_table.update_run_tags(
            flow_name, run_number, next_tags, cur=cur)
        if update_db_response.response_code != 200:
            return update_db_response
        return DBResponse(response_code=200, body={"tags": next_tags})

    return await self._async_table.run_in_transaction_with_serializable_isolation_level(
        _in_tx_mutation_logic)