def get_outputs_for_deletion(task, force=False):
    """
    Collect the models and child tasks that are affected by deleting a task.

    :param task: the task that is about to be deleted
    :param force: when false, published models or published child tasks
        abort the operation
    :return: tuple of (task outputs split into ``published``/``draft``,
        query of the task's children)
    :raise errors.bad_request.TaskCannotBeDeleted: the task has published
        models or published children and ``force`` is not set
    """
    with TimingContext("mongo", "get_task_models"):
        task_models = Model.objects(task=task.id).only("id", "task", "ready")
        models = TaskOutputs(attrgetter("ready"), Model, task_models)
        if not force and models.published:
            raise errors.bad_request.TaskCannotBeDeleted(
                "has output models, use force=True",
                task=task.id,
                models=len(models.published),
            )

    if task.output.model:
        output_model = get_output_model(task, force)
        if output_model:
            # The task's own output model joins the matching group
            bucket = models.published if output_model.ready else models.draft
            bucket.append(output_model)

    with TimingContext("mongo", "get_task_children"):
        tasks = Task.objects(parent=task.id).only("id", "parent", "status")
        published_tasks = [
            child for child in tasks if child.status == TaskStatus.published
        ]
        if not force and published_tasks:
            raise errors.bad_request.TaskCannotBeDeleted(
                "has children, use force=True",
                task=task.id,
                children=published_tasks,
            )

    return models, tasks
def cleanup_task(task: Task, force: bool = False):
    """
    Check that a task may be deleted, then delete or modify everything it produced.

    Child tasks and published models are re-pointed to the trashed task ID,
    draft models are deleted and the task's events are removed.

    :param task: task object
    :param force: whether to delete task with published outputs
    :return: CleanupResult holding the count of deleted and modified items
    """
    models, child_tasks = get_outputs_for_deletion(task, force)
    deleted_task_id = trash_task_id(task.id)

    updated_children = 0
    if child_tasks:
        with TimingContext("mongo", "update_task_children"):
            updated_children = child_tasks.update(parent=deleted_task_id)

    deleted_models = 0
    if models.draft:
        with TimingContext("mongo", "delete_models"):
            deleted_models = models.draft.objects().delete()

    updated_models = 0
    if models.published:
        with TimingContext("mongo", "update_task_models"):
            updated_models = models.published.objects().update(
                task=deleted_task_id
            )

    event_bll.delete_task_events(task.company, task.id, allow_locked=force)

    return CleanupResult(
        deleted_models=deleted_models,
        updated_children=updated_children,
        updated_models=updated_models,
    )
def authorize_credentials(auth_data, service, action, call_data_items):
    """
    Validate basic-auth credentials for a call to the given service/action.

    :param auth_data: base64 encoded ``key:secret`` string
    :param service: name of the called service (used for the guest endpoint check)
    :param action: name of the called action (used for the guest endpoint check)
    :param call_data_items: request data (dicts); not used by this implementation
    :return: a new Basic object (auth payload)
    :raise errors.unauthorized.BadCredentials: the auth data could not be decoded
    :raise errors.unauthorized.InvalidCredentials: the credentials do not match
        a user, the user's company was not found, or a guest user called an
        endpoint that is not allowed for guests
    """
    try:
        decoded = base64.b64decode(auth_data.encode()).decode('latin-1')
        access_key, _, secret_key = decoded.partition(':')
    except Exception as e:
        log.exception('malformed credentials')
        raise errors.unauthorized.BadCredentials(str(e))

    # Default lookup: a user holding a matching key/secret pair
    query = Q(
        credentials__match=Credentials(key=access_key, secret=secret_key)
    )

    fixed_user = None
    if FixedUser.enabled():
        fixed_user = FixedUser.get_by_username(access_key)

    if fixed_user:
        if secret_key != fixed_user.password:
            raise errors.unauthorized.InvalidCredentials(
                'bad username or password'
            )
        guest_allowed = (
            not fixed_user.is_guest
            or FixedUser.is_guest_endpoint(service, action)
        )
        if not guest_allowed:
            raise errors.unauthorized.InvalidCredentials(
                'endpoint not allowed for guest'
            )
        # Fixed users are looked up by their user ID instead of credentials
        query = Q(id=fixed_user.user_id)

    with TimingContext("mongo", "user_by_cred"), \
            translate_errors_context('authorizing request'):
        user = User.objects(query).first()
        if not user:
            raise errors.unauthorized.InvalidCredentials(
                'failed to locate provided credentials'
            )
        if not fixed_user:
            # In case these are proper credentials, update last used time
            User.objects(id=user.id, credentials__key=access_key).update(
                **{"set__credentials__$__last_used": datetime.utcnow()}
            )

    with TimingContext("mongo", "company_by_id"):
        company = Company.objects(id=user.company).only('id', 'name').first()
        if not company:
            raise errors.unauthorized.InvalidCredentials(
                'invalid user company'
            )

    identity = Identity(
        user=user.id,
        company=user.company,
        role=user.role,
        user_name=user.name,
        company_name=company.name,
    )
    return Basic(user_key=access_key, identity=identity)
def _validate_and_get_task_from_call(call: APICall, **kwargs):
    """
    Build a task object from the call data and validate it.

    :param call: the API call carrying the task fields
    :param kwargs: forwarded as-is to ``prepare_create_fields``
    :return: the created (and validated) task object
    """
    error_translation = translate_errors_context(
        field_does_not_exist_cls=errors.bad_request.ValidationError
    )
    with error_translation, TimingContext("code", "parse_call"):
        fields = prepare_create_fields(call, **kwargs)
        task = task_bll.create(call, fields)

    with TimingContext("code", "validate"):
        task_bll.validate(task)

    return task
def scroll_task_events(
    self,
    company_id,
    task_id,
    order,
    event_type=None,
    batch_size=10000,
    scroll_id=None,
):
    """
    Fetch a batch of task events using an Elasticsearch scroll.

    When ``scroll_id`` is given the existing scroll is continued; otherwise a
    new search sorted by timestamp is started for the task.

    :param company_id: company owning the events index
    :param task_id: task whose events are requested
    :param order: sort order applied to the ``timestamp`` field
    :param event_type: event type used for the index name; all types ("*") if None
    :param batch_size: requested batch size, capped at 10000
    :param scroll_id: ID of a previously opened scroll to continue
    :return: tuple of (events, next scroll ID, total events);
        ``([], None, 0)`` when the index does not exist
    """
    if scroll_id:
        with translate_errors_context(), TimingContext("es", "task_log_events"):
            es_res = self.es.scroll(scroll_id=scroll_id, scroll="1h")
    else:
        size = min(batch_size, 10000)
        index_type = "*" if event_type is None else event_type
        es_index = EventMetrics.get_index_name(company_id, index_type)
        if not self.es.indices.exists(es_index):
            return [], None, 0

        task_filter = {"term": {"task": task_id}}
        es_req = {
            "size": size,
            "sort": {"timestamp": {"order": order}},
            "query": {"bool": {"must": [task_filter]}},
        }
        with translate_errors_context(), TimingContext("es", "scroll_task_events"):
            es_res = self.es.search(
                index=es_index, body=es_req, scroll="1h", routing=task_id
            )

    hits = es_res["hits"]
    events = [hit["_source"] for hit in hits["hits"]]
    next_scroll_id = es_res["_scroll_id"]
    return events, next_scroll_id, hits["total"]
def update(call: APICall):
    """
    Update project information.

    The accepted fields are those of `project.create`.

    :param call: the API call; ``call.data["project"]`` holds the project ID
    :return: nothing; the call result is set to an UpdateResponse with
        ``updated`` (number of projects updated) and ``fields`` (updated fields)
    :raise errors.bad_request.InvalidProjectId: no writable project with this ID
    """
    project_id = call.data["project"]

    with translate_errors_context():
        project = Project.get_for_writing(
            company=call.identity.company, id=project_id
        )
        if not project:
            raise errors.bad_request.InvalidProjectId(id=project_id)

        fields = parse_from_call(
            call.data,
            create_fields,
            Project.get_fields(),
            discard_none_values=False,
        )
        conform_tag_fields(call, fields, validate=True)
        fields["last_update"] = datetime.utcnow()

        with TimingContext("mongo", "projects_update"):
            updated = project.update(upsert=False, **fields)

        conform_output_tags(call, fields)
        response = UpdateResponse(updated=updated, fields=fields)
        call.result.data_model = response
def get_by_id(
    company_id,
    task_id,
    required_status=None,
    only_fields=None,
    allow_public=False,
):
    """
    Fetch a single task by ID, optionally verifying its status.

    :param company_id: company to search in
    :param task_id: ID of the requested task
    :param required_status: if set, the task must have this status
    :param only_fields: a field name or an iterable of field names to return;
        "status" is always added when a projection is requested
    :param allow_public: passed through to ``Task.get_many``
    :return: the task object
    :raise errors.bad_request.InvalidTaskId: no such task was found
    :raise errors.bad_request.InvalidTaskStatus: the status does not match
    """
    if only_fields:
        if isinstance(only_fields, string_types):
            requested = [only_fields]
        else:
            requested = list(only_fields)
        # The status check below relies on this field being returned
        only_fields = requested + ["status"]

    with TimingContext("mongo", "task_by_id_all"):
        tasks = Task.get_many(
            company=company_id,
            query=Q(id=task_id),
            allow_public=allow_public,
            override_projection=only_fields,
            return_dicts=False,
        )
        task = tasks[0] if tasks else None

    if not task:
        raise errors.bad_request.InvalidTaskId(id=task_id)

    if required_status and not task.status == required_status:
        raise errors.bad_request.InvalidTaskStatus(expected=required_status)

    return task
def assert_exists(
    company_id, task_ids, only=None, allow_public=False, return_tasks=True
) -> Optional[Sequence[Task]]:
    """
    Make sure all the given task IDs exist.

    :param company_id: company to search in
    :param task_ids: a single task ID or a collection of task IDs
    :param only: optional field names to restrict the returned tasks to
    :param allow_public: passed through to ``Task.get_many``
    :param return_tasks: whether the matching tasks should be returned
    :return: the matching tasks when ``return_tasks`` is set, otherwise None
    :raise errors.bad_request.InvalidTaskId: the number of tasks found differs
        from the number of distinct IDs requested
    """
    if isinstance(task_ids, six.string_types):
        task_ids = [task_ids]

    with translate_errors_context(), TimingContext("mongo", "task_exists"):
        ids = set(task_ids)
        q = Task.get_many(
            company=company_id,
            query=Q(id__in=ids),
            allow_public=allow_public,
            return_dicts=False,
        )

        if only:
            res = q.only(*only)
        elif return_tasks:
            res = list(q)
        else:
            res = None

        # Without materialized results, ask for the count directly
        count = q.count() if res is None else len(res)
        if count != len(ids):
            raise errors.bad_request.InvalidTaskId(ids=task_ids)

        return res if return_tasks else None
def get_vector_metrics_per_iter(self, company_id, task_id, metric, variant):
    """
    Get the vector values reported for a task metric/variant, ordered by iteration.

    :param company_id: company owning the events index
    :param task_id: task to query
    :param metric: metric name to match
    :param variant: variant name to match
    :return: tuple of (iterations, vectors); two empty lists when the index
        does not exist. At most 10000 events are requested.
    """
    es_index = EventBLL.get_index_name(company_id, "training_stats_vector")
    if not self.es.indices.exists(es_index):
        return [], []

    filters = [
        {"term": {"task": task_id}},
        {"term": {"metric": metric}},
        {"term": {"variant": variant}},
    ]
    es_req = {
        "size": 10000,
        "query": {"bool": {"must": filters}},
        "_source": ["iter", "value"],
        "sort": ["iter"],
    }
    with translate_errors_context(), TimingContext("es", "task_stats_vector"):
        es_res = self.es.search(index=es_index, body=es_req, routing=task_id)

    sources = [hit["_source"] for hit in es_res["hits"]["hits"]]
    vectors = [source["value"] for source in sources]
    iterations = [source["iter"] for source in sources]
    return iterations, vectors
def get_by_id(
    company_id,
    task_id,
    required_status=None,
    required_dataset=None,
    only_fields=None,
):
    """
    Fetch a single task by ID, optionally verifying its status and input dataset.

    :param company_id: company the task belongs to
    :param task_id: ID of the requested task
    :param required_status: if set, the task must have this status
    :param required_dataset: if set, the dataset must appear in the entries
        of the task's input view
    :param only_fields: a field name or an iterable of field names to return;
        "status" and "input" are always added when a projection is requested
    :return: the task object
    :raise errors.bad_request.InvalidTaskId: no such task was found
    :raise errors.bad_request.InvalidTaskStatus: the status does not match
    :raise errors.bad_request.InvalidId: the dataset is not in the input view
    """
    with TimingContext("mongo", "task_by_id_all"):
        qs = Task.objects(id=task_id, company=company_id)
        if only_fields:
            if isinstance(only_fields, string_types):
                qs = qs.only(only_fields)
            else:
                qs = qs.only(*only_fields)
            # make sure all fields we rely on here are also returned
            qs = qs.only("status", "input")
        task = qs.first()

    if not task:
        raise errors.bad_request.InvalidTaskId(id=task_id)

    if required_status and not task.status == required_status:
        raise errors.bad_request.InvalidTaskStatus(expected=required_status)

    if required_dataset:
        view_datasets = (entry.dataset for entry in task.input.view.entries)
        if required_dataset not in view_datasets:
            raise errors.bad_request.InvalidId(
                "not in input view", dataset=required_dataset
            )

    return task
def get_metrics_and_variants(self, company_id, task_id, event_type):
    """
    Get the metrics reported for a task and the variants of each metric.

    :param company_id: company owning the events index
    :param task_id: task to query
    :param event_type: event type used to resolve the index name
    :return: dict mapping each metric name to a list of its variant names;
        an empty dict when the index does not exist. The aggregation requests
        up to 200 metrics and up to 200 variants per metric.
    """
    es_index = EventBLL.get_index_name(company_id, event_type)
    if not self.es.indices.exists(es_index):
        return {}

    variants_agg = {"variants": {"terms": {"field": "variant", "size": 200}}}
    es_req = {
        "size": 0,
        "aggs": {
            "metrics": {
                "terms": {"field": "metric", "size": 200},
                "aggs": variants_agg,
            }
        },
        "query": {"bool": {"must": [{"term": {"task": task_id}}]}},
    }
    with translate_errors_context(), TimingContext(
        "es", "events_get_metrics_and_variants"
    ):
        es_res = self.es.search(index=es_index, body=es_req, routing=task_id)

    return {
        metric_bucket["key"]: [
            variant_bucket["key"]
            for variant_bucket in metric_bucket["variants"].get("buckets")
        ]
        for metric_bucket in es_res["aggregations"]["metrics"].get("buckets")
    }
def _get_task_metrics(self, task_id, es_index, event_type: EventType) -> Sequence:
    """
    Get the names of the metrics reported for a task with the given event type.

    :param task_id: task to query
    :param es_index: index to search in
    :param event_type: event type whose ``value`` is matched against "type"
    :return: list of metric names (at most ``MAX_METRICS_COUNT``); empty when
        the response carries no metric buckets
    """
    filters = [
        {"term": {"task": task_id}},
        {"term": {"type": event_type.value}},
    ]
    metrics_agg = {"terms": {"field": "metric", "size": self.MAX_METRICS_COUNT}}
    es_req = {
        "size": 0,
        "query": {"bool": {"must": filters}},
        "aggs": {"metrics": metrics_agg},
    }
    with translate_errors_context(), TimingContext("es", "_get_task_metrics"):
        es_res = self.es.search(index=es_index, body=es_req, routing=task_id)

    buckets = safe_get(es_res, "aggregations/metrics/buckets", default=[])
    return [bucket["key"] for bucket in buckets]
def delete(call: APICall, company_id, req_model: DeleteRequest):
    """
    Delete a task, optionally keeping a copy in the trash collection.

    :param call: the API call; its result data receives ``deleted=True``
        together with the cleanup counters
    :param company_id: company of the calling user
    :param req_model: request carrying ``task``, ``move_to_trash`` and ``force``
    :raise errors.bad_request.TaskCannotBeDeleted: the task is not in the
        "created" status and ``force`` is not set
    """
    task = TaskBLL.get_task_with_access(
        req_model.task, company_id=company_id, requires_write_access=True
    )
    move_to_trash = req_model.move_to_trash
    force = req_model.force

    if task.status != TaskStatus.created and not force:
        raise errors.bad_request.TaskCannotBeDeleted(
            "due to status, use force=True",
            task=task.id,
            expected=TaskStatus.created,
            current=task.status,
        )

    with translate_errors_context():
        result = cleanup_task(task, force)

        if move_to_trash:
            collection_name = task._get_collection_name()
            task.switch_collection("{}__trash".format(collection_name))
            try:
                # A plain save() is not enough because of mongoengine caching
                # (nothing would be saved), so an insert is forced. If such an
                # ID already exists in the trash, carry on with the deletion.
                with TimingContext("mongo", "save_task"):
                    task.save(force_insert=True)
            except Exception:
                pass
            task.switch_collection(collection_name)

        task.delete()
        org_bll.update_org_tags(company_id, reset=True)
        cleanup_counts = attr.asdict(result)
        call.result.data = dict(deleted=True, **cleanup_counts)
def _query_aggregation_for_metrics_and_tasks(
    self,
    es_index: str,
    aggs: dict,
    task_ids: Sequence[str],
    task_metrics: Sequence[TaskMetric],
) -> dict:
    """
    Run the given aggregation, filtered by the given task IDs and metrics.

    When ``task_metrics`` is provided the query matches any of the
    (task, metric, variant) combinations; otherwise it matches the task IDs.

    :param es_index: index to search in
    :param aggs: the aggregation part of the request
    :param task_ids: IDs of the tasks, also used for routing
    :param task_metrics: (task, metric, variant) triples to filter by
    :return: the "aggregations" part of the search response (None if absent)
    """
    if not task_metrics:
        condition = {"must": [{"terms": {"task": task_ids}}]}
    else:
        metric_terms = [
            self._build_metric_terms(task, metric, variant)
            for task, metric, variant in task_metrics
        ]
        condition = {"should": metric_terms}

    es_req = {
        "size": 0,
        "_source": {"excludes": []},
        "query": {"bool": condition},
        "aggs": aggs,
        "version": True,
    }
    routing = ",".join(task_ids)
    with translate_errors_context(), TimingContext("es", "task_stats_scalar"):
        es_res = self.es.search(index=es_index, body=es_req, routing=routing)

    return es_res.get("aggregations")
def get_last_iters(self, es_index, task_id, event_type, iters):
    """
    Get the highest iteration numbers reported for a task.

    :param es_index: index to search in
    :param task_id: task to query
    :param event_type: if set, only events of this type are considered
    :param iters: number of iterations to return
    :return: list of iteration values in descending order; an empty list when
        the index does not exist or the response has no aggregations
    """
    if not self.es.indices.exists(es_index):
        return []

    must_terms = [{"term": {"task": task_id}}]
    if event_type:
        must_terms.append({"term": {"type": event_type}})

    iters_agg = {
        "terms": {
            "field": "iter",
            "size": iters,
            "order": {"_term": "desc"},
        }
    }
    es_req: dict = {
        "size": 0,
        "aggs": {"iters": iters_agg},
        "query": {"bool": {"must": must_terms}},
    }

    with translate_errors_context(), TimingContext("es", "task_last_iter"):
        es_res = self.es.search(index=es_index, body=es_req, routing=task_id)

    if "aggregations" not in es_res:
        return []

    buckets = es_res["aggregations"]["iters"]["buckets"]
    return [bucket["key"] for bucket in buckets]
def get_queue_metrics(
    self,
    company_id: str,
    from_date: float,
    to_date: float,
    interval: int,
    queue_ids: Sequence[str],
) -> dict:
    """
    Get the company queue metrics in the specified time range.

    The result is a set of date histograms of average values per queue and
    metric type. Because the range may hold only a few points, from_date is
    moved back by 'metrics_before_from_date' seconds (taken from the
    'services.queues' configuration, 3600 by default).
    If no queue ids are given, the average across all the company queues is
    calculated for each metric.

    :raise bad_request.FieldsValueError: from_date is not less than to_date
    """
    if from_date >= to_date:
        raise bad_request.FieldsValueError(
            "from_date must be less than to_date"
        )

    seconds_before = config.get(
        "services.queues.metrics_before_from_date", 3600
    )
    range_start = from_date - seconds_before
    must_terms = [QueryBuilder.dates_range(range_start, to_date)]
    if queue_ids:
        must_terms.append(QueryBuilder.terms("queue", queue_ids))

    es_req = {
        "size": 0,
        "query": {"bool": {"must": must_terms}},
        "aggs": self._get_dates_agg(interval),
    }

    with translate_errors_context(), TimingContext("es", "get_queue_metrics"):
        res = self._search_company_metrics(company_id, es_req)

    if "aggregations" not in res:
        return {}

    # Only date buckets that actually contain documents are reported
    date_metrics = []
    for date_bucket in res["aggregations"]["dates"]["buckets"]:
        if date_bucket["doc_count"] > 0:
            date_metrics.append(
                dict(
                    timestamp=date_bucket["key"],
                    queue_metrics=self._extract_queue_metrics(
                        date_bucket["queues"]["buckets"]
                    ),
                )
            )

    if queue_ids:
        return self._datetime_histogram_per_queue(date_metrics)
    return self._average_datetime_histogram(date_metrics)
def get_output_model(task, force=False):
    """
    Fetch the output model referenced by a task that is about to be deleted.

    :param task: task whose ``output.model`` is looked up
    :param force: when false, a ready output model aborts the operation
    :return: the output model, or None if it was not found
    :raise errors.bad_request.TaskCannotBeDeleted: the output model is ready
        and ``force`` is not set
    """
    model_id = task.output.model
    with TimingContext("mongo", "get_task_output_model"):
        output_model = Model.objects(id=model_id).first()

    blocks_deletion = output_model and output_model.ready and not force
    if blocks_deletion:
        raise errors.bad_request.TaskCannotBeDeleted(
            "has output model, use force=True",
            task=task.id,
            model=task.output.model,
        )

    return output_model
def _get(
    self, company: str, user: str = "*", worker_id: str = "*"
) -> Sequence[WorkerEntry]:
    """
    Get worker entries matching the company and user, worker patterns

    :param company: company ID used to build the key pattern
    :param user: user ID or pattern ("*" by default)
    :param worker_id: worker ID or pattern ("*" by default)
    :return: list of the worker entries stored under the matching keys
    """
    match = self._get_worker_key(company, user, worker_id)
    entries = []
    seen_keys = set()
    # scan_iter() is lazy, so the iteration and the reads are kept inside the
    # timing context to make it measure the actual redis work
    with TimingContext("redis", "workers_get_all"):
        for key in self.redis.scan_iter(match):
            # SCAN may report the same key more than once
            if key in seen_keys:
                continue
            seen_keys.add(key)
            data = self.redis.get(key)
            # A key can expire between the scan and the read; GET then
            # returns None and there is no entry to parse
            if data:
                entries.append(WorkerEntry.from_json(data))
    return entries
def create(call: APICall, company_id, req_model: CreateRequest):
    """
    Create a new task from the call data and report its ID.

    :param call: the API call; its result data is set to ``{"id": <task id>}``
    :param company_id: company of the calling user (not used here)
    :param req_model: parsed request model (not used here)
    """
    new_task = _validate_and_get_task_from_call(call)

    with translate_errors_context(), TimingContext("mongo", "save_task"):
        new_task.save()
        update_project_time(new_task.project)

    call.result.data = {"id": new_task.id}
def get_scalar_metrics_average_per_iter(self, company_id, task_id):
    """
    Get the average scalar value per iteration for each metric/variant of a task.

    :param company_id: company owning the events index
    :param task_id: task to query
    :return: dict of ``{metric: {variant: {"x": [iterations], "y": [averages],
        "name": variant}}}``; an empty dict when the index does not exist or
        the response has no aggregations
    """
    es_index = EventBLL.get_index_name(company_id, "training_stats_scalar")
    if not self.es.indices.exists(es_index):
        return {}

    variants_agg = {
        "terms": {
            "field": "variant",
            "size": 500,
            "order": {"_term": "desc"},
        },
        "aggs": {"avg_val": {"avg": {"field": "value"}}},
    }
    metrics_agg = {
        "terms": {
            "field": "metric",
            "size": 200,
            "order": {"_term": "desc"},
        },
        "aggs": {"variants": variants_agg},
    }
    es_req = {
        "size": 0,
        "_source": {"excludes": []},
        "query": {"term": {"task": task_id}},
        "aggs": {
            "iters": {
                "histogram": {"field": "iter", "interval": 1, "min_doc_count": 1},
                "aggs": {"metrics": metrics_agg},
            }
        },
        "version": True,
    }

    with translate_errors_context(), TimingContext("es", "task_stats_scalar"):
        es_res = self.es.search(index=es_index, body=es_req, routing=task_id)

    metrics = {}
    if "aggregations" not in es_res:
        return metrics

    for iter_bucket in es_res["aggregations"]["iters"]["buckets"]:
        iteration = int(iter_bucket["key"])
        for metric_bucket in iter_bucket["metrics"]["buckets"]:
            metric_data = metrics.setdefault(metric_bucket["key"], {})
            for variant_bucket in metric_bucket["variants"]["buckets"]:
                variant = variant_bucket["key"]
                value = variant_bucket["avg_val"]["value"]
                series = metric_data.setdefault(
                    variant, {"x": [], "y": [], "name": variant}
                )
                series["x"].append(iteration)
                series["y"].append(value)

    return metrics
def edit(call: APICall, company_id, req_model: UpdateRequest):
    """
    Edit an existing task using the fields provided in the call.

    The provided fields are merged with the task's current embedded documents,
    validated as if a task was created from them, and then written to the task.

    :param call: the API call; its result is set to an UpdateResponse
    :param company_id: company of the calling user
    :param req_model: request carrying ``task`` (the task ID) and ``force``
    :raise errors.bad_request.InvalidTaskId: no writable task with this ID
    :raise errors.bad_request.InvalidTaskStatus: the task is not in the
        "created" status and ``force`` is not set
    """
    task_id = req_model.task
    force = req_model.force

    with translate_errors_context():
        task = Task.get_for_writing(id=task_id, company=company_id)
        if not task:
            raise errors.bad_request.InvalidTaskId(id=task_id)

        if not force and task.status != TaskStatus.created:
            raise errors.bad_request.InvalidTaskStatus(
                expected=TaskStatus.created, status=task.status)

        # Editing accepts the creation fields plus "status"
        edit_fields = create_fields.copy()
        edit_fields.update(dict(status=None))

        with translate_errors_context(
                field_does_not_exist_cls=errors.bad_request.ValidationError
        ), TimingContext("code", "parse_and_validate"):
            fields = prepare_create_fields(call,
                                           valid_fields=edit_fields,
                                           output=task.output,
                                           previous_task=task)

        # A dict provided for an embedded document updates the document's
        # current content instead of replacing it as a whole
        for key in fields:
            field = getattr(task, key, None)
            value = fields[key]
            if (field and isinstance(value, dict)
                    and isinstance(field, EmbeddedDocument)):
                d = field.to_mongo(use_db_field=False).to_dict()
                d.update(value)
                fields[key] = d

        # Validate the merged fields by building a task from them; the built
        # task itself is not kept
        task_bll.validate(task_bll.create(call, fields))

        # make sure field names do not end in mongoengine comparison operators
        fixed_fields = {(k if k not in COMPARISON_OPERATORS else "%s__" % k): v
                        for k, v in fields.items()}
        if fixed_fields:
            # "fields" is what gets reported back to the caller, "fixed_fields"
            # is what gets written; both receive the same update time
            now = datetime.utcnow()
            fields.update(last_update=now)
            fixed_fields.update(last_update=now)
            updated = task.update(upsert=False, **fixed_fields)
            if updated:
                # "task" still holds the project it was loaded with, so this
                # detects a move to a different project
                new_project = fixed_fields.get("project", task.project)
                if new_project != task.project:
                    _reset_cached_tags(company_id,
                                       projects=[new_project, task.project])
                else:
                    _update_cached_tags(company_id,
                                        project=task.project,
                                        fields=fixed_fields)
            update_project_time(fields.get("project"))
            unprepare_from_saved(call, fields)
            call.result.data_model = UpdateResponse(updated=updated,
                                                    fields=fields)
        else:
            call.result.data_model = UpdateResponse(updated=0)
def get_all_ex(call: APICall, company_id, _):
    """
    Get the models matching the call data, using a joined query.

    :param call: the API call; ``call.data`` is the query and the result data
        is set to ``{"models": <models>}``
    :param company_id: company to search in (public models are allowed too)
    :param _: parsed request model (not used)
    """
    query = call.data
    conform_tag_fields(call, query)

    with translate_errors_context():
        # Only the database query itself is timed
        with TimingContext("mongo", "models_get_all_ex"):
            found_models = Model.get_many_with_join(
                company=company_id,
                query_dict=query,
                allow_public=True,
            )
        conform_output_tags(call, found_models)
        call.result.data = {"models": found_models}
def create(call: APICall, company_id, req_model: CreateRequest):
    """
    Create a new task from the call data and report its ID.

    :param call: the API call; its result is set to an IdResponse holding
        the new task ID
    :param company_id: company of the calling user, used for the tags update
    :param req_model: parsed request model (not used here)
    """
    new_task, task_fields = _validate_and_get_task_from_call(call)

    with translate_errors_context(), TimingContext("mongo", "save_task"):
        new_task.save()
        _update_org_tags(company_id, task_fields)
        update_project_time(new_task.project)

    response = IdResponse(id=new_task.id)
    call.result.data_model = response
def delete_task_events(self, company_id, task_id):
    """
    Delete all the events of a task, across all event types.

    :param company_id: company owning the events indices
    :param task_id: task whose events are deleted
    :return: the "deleted" count reported by Elasticsearch (0 if absent)
    """
    all_types_index = EventBLL.get_index_name(company_id, "*")
    task_query = {"query": {"term": {"task": task_id}}}

    with translate_errors_context(), TimingContext("es", "delete_task_events"):
        es_res = self.es.delete_by_query(
            index=all_types_index,
            body=task_query,
            routing=task_id,
            refresh=True,
        )

    return es_res.get("deleted", 0)
def _get_valid_tasks(company_id, task_ids: Set, allow_locked_tasks=False) -> Set:
    """
    Verify that task exists and can be updated

    :param company_id: company the tasks must belong to
    :param task_ids: IDs of the tasks to check
    :param allow_locked_tasks: when false, tasks whose status is one of
        LOCKED_TASK_STATUSES are left out
    :return: the set of IDs of the tasks that were found
    """
    if not task_ids:
        return set()

    with translate_errors_context(), TimingContext("mongo", "task_by_ids"):
        query = Q(id__in=task_ids, company=company_id)
        if not allow_locked_tasks:
            query = query & Q(status__nin=LOCKED_TASK_STATUSES)
        found_tasks = Task.objects(query).only("id")
        return {found.id for found in found_tasks}
def get_all_ex(call: APICall):
    """
    Get the tasks matching the call data, using a joined query.

    :param call: the API call; ``call.data`` is the query and the result data
        is set to ``{"tasks": <tasks>}``
    """
    with translate_errors_context():
        # Only the database query itself is timed
        with TimingContext("mongo", "task_get_all_ex"):
            found_tasks = Task.get_many_with_join(
                company=call.identity.company,
                query_dict=call.data,
                query_options=get_all_query_options,
                # required in case projection is requested for public
                # dataset/versions
                allow_public=True,
            )
        call.result.data = {"tasks": found_tasks}
def get_response(self):
    """
    Build the response for this call.

    :return: tuple of (response body, content type). Raw result data is
        returned untouched when present and the call did not fail; otherwise a
        dict with "meta" and "data" is built and, for the JSON content type,
        serialized to a string.
    """

    def make_version_number(version):
        """
        Client versions <=2.0 expect endpoint versions in float format, otherwise throwing an exception
        """
        if version is None:
            return None
        if self.requested_endpoint_version < PartialVersion("2.1"):
            return float(str(version))
        return str(version)

    if self.result.raw_data and not self.failed:
        # endpoint returned raw data and no error was detected, return raw data, no fancy dicts
        return self.result.raw_data, self.result.content_type

    res = {
        "meta": {
            "id": self.id,
            "trx": self.trx,
            "endpoint": {
                "name": self.endpoint_name,
                "requested_version": make_version_number(
                    self.requested_endpoint_version
                ),
                "actual_version": make_version_number(
                    self.actual_endpoint_version
                ),
            },
            "result_code": self.result.code,
            "result_subcode": self.result.subcode,
            "result_msg": self.result.msg,
            "error_stack": self.result.traceback,
        },
        "data": self.result.data,
    }

    if self.content_type.lower() != JSON_CONTENT_TYPE:
        return res, self.content_type

    with TimingContext("json", "serialization"):
        try:
            res = json.dumps(res)
        except Exception as ex:
            # JSON serialization may fail, probably problem with data so pop it and try again
            if not self.result.data:
                raise
            self.result.data = None
            msg = "Error serializing response data: " + str(ex)
            self.set_error_result(
                code=500, subcode=0, msg=msg, include_stack=False
            )
            return self.get_response()

    return res, self.content_type
def get_all(call):
    """
    Get the projects matching the call data.

    :param call: the API call (an APICall instance); ``call.data`` is used
        both as the query and as the parameters, and the result data is set
        to ``{"projects": <projects>}``
    """
    assert isinstance(call, APICall)

    request_data = call.data
    with translate_errors_context(), TimingContext("mongo", "projects_get_all"):
        found_projects = Project.get_many(
            company=call.identity.company,
            query_dict=request_data,
            query_options=get_all_query_options,
            parameters=request_data,
            allow_public=True,
        )
        call.result.data = {"projects": found_projects}
def get_all(call: APICall):
    """
    Get the projects matching the call data.

    :param call: the API call; ``call.data`` is used both as the query and as
        the parameters, and the result data is set to
        ``{"projects": <projects>}``
    """
    request_data = call.data
    conform_tag_fields(call, request_data)

    with translate_errors_context(), TimingContext("mongo", "projects_get_all"):
        found_projects = Project.get_many(
            company=call.identity.company,
            query_dict=request_data,
            query_options=get_all_query_options,
            parameters=request_data,
            allow_public=True,
        )
        conform_output_tags(call, found_projects)
        call.result.data = {"projects": found_projects}
def get_all(call: APICall):
    """
    Get the models matching the call data.

    :param call: the API call; ``call.data`` is used both as the query and as
        the parameters, and the result data is set to ``{"models": <models>}``
    """
    request_data = call.data
    conform_tag_fields(call, request_data)

    with translate_errors_context():
        # Only the database query itself is timed
        with TimingContext("mongo", "models_get_all"):
            found_models = Model.get_many(
                company=call.identity.company,
                parameters=request_data,
                query_dict=request_data,
                allow_public=True,
                query_options=get_all_query_options,
            )
        conform_output_tags(call, found_models)
        call.result.data = {"models": found_models}