Exemplo n.º 1
0
def _update_model(call: APICall, company_id, model_id=None):
    """
    Apply the update fields found in ``call.data`` to one model of a company.

    :param call: API call whose ``data`` carries the fields to update (and the
        ``model`` id when ``model_id`` is not supplied)
    :param company_id: company the model must belong to
    :param model_id: explicit model id; falls back to ``call.data["model"]``
    :return: UpdateResponse holding the updated-document count and the
        fields that were written
    :raises errors.bad_request.InvalidModelId: no such model in the company
    """
    target_id = model_id or call.data["model"]

    with translate_errors_context():
        lookup = dict(id=target_id, company=company_id)
        existing = Model.objects(**lookup).first()
        if not existing:
            raise errors.bad_request.InvalidModelId(**lookup)

        update_data = prepare_update_fields(call, company_id, call.data)

        # When both a task and an iteration are part of the update, raise the
        # task's last_iteration_max statistic accordingly.
        linked_task = update_data.get("task")
        reported_iteration = update_data.get("iteration")
        if linked_task and reported_iteration is not None:
            TaskBLL.update_statistics(
                task_id=linked_task,
                company_id=company_id,
                last_iteration_max=reported_iteration,
            )

        count, written = Model.safe_update(company_id, existing.id,
                                           update_data)
        if count:
            old_project = existing.project
            project_after = written.get("project", old_project)
            if project_after == old_project:
                _update_cached_tags(company_id,
                                    project=old_project,
                                    fields=written)
            else:
                # The model moved: the tag caches of both projects are stale.
                _reset_cached_tags(company_id,
                                   projects=[project_after, old_project])
        conform_output_tags(call, written)
        return UpdateResponse(updated=count, fields=written)
Exemplo n.º 2
0
def _update_model(call: APICall, model_id=None):
    """
    Update a model owned by the caller's company from ``call.data``.

    :param call: API call carrying the update fields and the caller identity
    :param model_id: explicit model id; falls back to ``call.data["model"]``
    :return: UpdateResponse with the updated-document count and written fields
    :raises errors.bad_request.InvalidModelId: model not found for the company
    """
    identity = call.identity
    target_id = model_id or call.data["model"]

    with translate_errors_context():
        lookup = dict(id=target_id, company=identity.company)
        existing = Model.objects(**lookup).first()
        if not existing:
            raise errors.bad_request.InvalidModelId(**lookup)

        update_data = prepare_update_fields(call, call.data)

        # A reported iteration on a linked task raises that task's
        # last_iteration_max statistic.
        linked_task = update_data.get("task")
        reported_iteration = update_data.get("iteration")
        if linked_task and reported_iteration is not None:
            TaskBLL.update_statistics(
                task_id=linked_task,
                company_id=identity.company,
                last_iteration_max=reported_iteration,
            )

        count, written = Model.safe_update(
            identity.company, existing.id, update_data
        )
        conform_output_tags(call, written)
        return UpdateResponse(updated=count, fields=written)
Exemplo n.º 3
0
    def _update_task(self,
                     company_id,
                     task_id,
                     now,
                     iter_max=None,
                     last_events=None):
        """
        Update task information in DB with aggregated results after handling event(s) related to this task.

        This updates the task with the highest iteration value encountered during the last events update, as well
        as the latest metric/variant scalar values reported (according to the report timestamp) and the task's last
        update time.
        """
        fields = {}

        if iter_max is not None:
            fields["last_iteration_max"] = iter_max

        if last_events:
            fields["last_values"] = list(
                flatten_nested_items(
                    last_events,
                    nesting=2,
                    include_leaves=["value", "metric", "variant"],
                ))

        if not fields:
            return False

        return TaskBLL.update_statistics(task_id,
                                         company_id,
                                         last_update=now,
                                         **fields)
Exemplo n.º 4
0
def edit(call: APICall, company_id, _):
    """
    Edit an existing model of the company with the fields in the API call.

    Dict values targeting embedded-document fields are merged on top of the
    stored embedded document rather than replacing it. When an ``iteration``
    is supplied and a task is associated (the model's current task, else the
    ``task`` being set), the task's ``last_iteration_max`` is updated too.
    Project tag caches are refreshed after a successful update.

    The UpdateResponse is stored on ``call.result.data_model``.

    :raises errors.bad_request.InvalidModelId: model not found for the company
    """
    model_id = call.data["model"]

    with translate_errors_context():
        lookup = dict(id=model_id, company=company_id)
        model = Model.objects(**lookup).first()
        if not model:
            raise errors.bad_request.InvalidModelId(**lookup)

        parsed = parse_model_fields(call, create_fields)
        fields = prepare_update_fields(call, company_id, parsed)

        # Merge partial dicts into existing embedded documents so that keys
        # not mentioned in the request are preserved.
        for name in fields:
            current = getattr(model, name, None)
            new_value = fields[name]
            if not current:
                continue
            if isinstance(new_value, dict) and isinstance(current,
                                                          EmbeddedDocument):
                merged = current.to_mongo(use_db_field=False).to_dict()
                merged.update(new_value)
                fields[name] = merged

        reported_iteration = call.data.get("iteration")
        linked_task = model.task or fields.get("task")
        if linked_task and reported_iteration is not None:
            TaskBLL.update_statistics(
                task_id=linked_task,
                company_id=company_id,
                last_iteration_max=reported_iteration,
            )

        if not fields:
            call.result.data_model = UpdateResponse(updated=0)
            return

        updated = model.update(upsert=False, **fields)
        if updated:
            old_project = model.project
            project_after = fields.get("project", old_project)
            if project_after == old_project:
                _update_cached_tags(company_id,
                                    project=old_project,
                                    fields=fields)
            else:
                # Moved between projects: invalidate both tag caches.
                _reset_cached_tags(company_id,
                                   projects=[project_after, old_project])
        conform_output_tags(call, fields)
        call.result.data_model = UpdateResponse(updated=updated,
                                                fields=fields)
Exemplo n.º 5
0
def edit(call):
    """
    Edit an existing model of the caller's company from the API call fields.

    Dict values targeting embedded-document fields are merged into the
    stored embedded document. If an ``iteration`` is supplied and a task is
    associated (the model's current task, else the ``task`` being set), the
    task's ``last_iteration_max`` statistic is updated as well.

    The UpdateResponse is stored on ``call.result.data_model``.

    :raises errors.bad_request.InvalidModelId: model not found for the company
    """
    assert isinstance(call, APICall)
    identity = call.identity
    target_id = call.data["model"]

    with translate_errors_context():
        lookup = dict(id=target_id, company=identity.company)
        model = Model.objects(**lookup).first()
        if not model:
            raise errors.bad_request.InvalidModelId(**lookup)

        parsed = parse_model_fields(call, create_fields)
        fields = prepare_update_fields(call, parsed)

        # Keep untouched keys of embedded documents by merging partial dicts.
        for name in fields:
            current = getattr(model, name, None)
            new_value = fields[name]
            if not current:
                continue
            if isinstance(new_value, dict) and isinstance(current,
                                                          EmbeddedDocument):
                merged = current.to_mongo(use_db_field=False).to_dict()
                merged.update(new_value)
                fields[name] = merged

        reported_iteration = call.data.get("iteration")
        linked_task = model.task or fields.get('task')
        if linked_task and reported_iteration is not None:
            TaskBLL.update_statistics(
                task_id=linked_task,
                company_id=identity.company,
                last_iteration_max=reported_iteration,
            )

        if not fields:
            call.result.data_model = UpdateResponse(updated=0)
            return

        updated = model.update(upsert=False, **fields)
        call.result.data_model = UpdateResponse(updated=updated,
                                                fields=fields)
Exemplo n.º 6
0
    def _update_task(self,
                     company_id,
                     task_id,
                     now,
                     iter=None,
                     last_events=None,
                     last_metrics=None):
        """
        Update task information in DB with aggregated results after handling event(s) related to this task.

        This updates the task with the highest iteration value encountered during the last events update, as well
        as the latest metric/variant scalar values reported (according to the report timestamp) and the task's last
        update time.
        """
        fields = {}

        if iter is not None:
            fields["last_iteration"] = iter

        if last_events:

            def get_metric_event(ev):
                me = MetricEvent.from_dict(**ev)
                if "timestamp" in ev:
                    me.timestamp = datetime.utcfromtimestamp(ev["timestamp"] /
                                                             1000)
                return me

            new_last_metrics = nested_dict(2, MetricEvent)
            new_last_metrics.update(last_metrics)

            for metric_hash, variants in last_events.items():
                for variant_hash, event in variants.items():
                    new_last_metrics[metric_hash][
                        variant_hash] = get_metric_event(event)

            fields["last_metrics"] = new_last_metrics.to_dict()

        if not fields:
            return False

        return TaskBLL.update_statistics(task_id,
                                         company_id,
                                         last_update=now,
                                         **fields)
Exemplo n.º 7
0
def update_for_task(call: APICall, company_id, _):
    """
    Attach an output model to a task: use an existing model, update the
    task's current output model, or create a new one.

    Exactly one of ``uri`` / ``override_model_id`` must be present in
    ``call.data``:

    - ``override_model_id``: the referenced company model becomes the task's
      output model.
    - ``uri``: if the task already has an output model, that model is updated
      via ``_update_model`` (response carries ``created: False``); otherwise
      a new model is created from the task's execution info.

    Only tasks in the ``created`` or ``in_progress`` status are accepted.
    The result dict (``id`` and ``created``) is written to
    ``call.result.data``.

    :raises errors.bad_request.MissingRequiredFields: neither or both of
        ``uri`` / ``override_model_id`` were given
    :raises errors.bad_request.InvalidTaskId: task not found for the company
    :raises errors.bad_request.InvalidTaskStatus: task not in an allowed status
    :raises errors.bad_request.InvalidModelId: override model not found
    """
    task_id = call.data["task"]
    uri = call.data.get("uri")
    iteration = call.data.get("iteration")
    override_model_id = call.data.get("override_model_id")
    # Exactly one of uri / override_model_id: reject neither and both.
    if not (uri or override_model_id) or (uri and override_model_id):
        raise errors.bad_request.MissingRequiredFields(
            "exactly one field is required",
            fields=("uri", "override_model_id"))

    with translate_errors_context():

        query = dict(id=task_id, company=company_id)
        task = Task.get_for_writing(
            id=task_id,
            company=company_id,
            _only=["output", "execution", "name", "status", "project"],
        )
        if not task:
            raise errors.bad_request.InvalidTaskId(**query)

        allowed_states = [TaskStatus.created, TaskStatus.in_progress]
        if task.status not in allowed_states:
            raise errors.bad_request.InvalidTaskStatus(
                f"model can only be updated for tasks in the {allowed_states} states",
                **query,
            )

        if override_model_id:
            query = dict(company=company_id, id=override_model_id)
            model = Model.objects(**query).first()
            if not model:
                raise errors.bad_request.InvalidModelId(**query)
        else:
            # call.data is mutated here on purpose: both _update_model and
            # parse_model_fields below read their input from call.data.
            if "name" not in call.data:
                # use task name if name not provided
                call.data["name"] = task.name

            if "comment" not in call.data:
                call.data[
                    "comment"] = f"Created by task `{task.name}` ({task.id})"

            if task.output and task.output.model:
                # model exists, update
                # call.data still contains "task" (and possibly "iteration"),
                # so iteration statistics are presumably handled inside
                # _update_model — confirm prepare_update_fields keeps them.
                res = _update_model(call,
                                    company_id,
                                    model_id=task.output.model).to_struct()
                res.update({"id": task.output.model, "created": False})
                call.result.data = res
                return

            # new model, create
            fields = parse_model_fields(call, create_fields)

            # create and save model
            # NOTE(review): the status gate above only admits created /
            # in_progress tasks, so `ready` below is always False.
            model = Model(
                id=database.utils.id(),
                created=datetime.utcnow(),
                user=call.identity.user,
                company=company_id,
                project=task.project,
                framework=task.execution.framework,
                parent=task.execution.model,
                design=task.execution.model_desc,
                labels=task.execution.model_labels,
                ready=(task.status == TaskStatus.published),
                **fields,
            )
            model.save()
            _update_cached_tags(company_id,
                                project=model.project,
                                fields=fields)

        # Point the task at the (overridden or newly created) model.
        # NOTE(review): iteration may be None here when not supplied —
        # presumably update_statistics ignores a None last_iteration_max;
        # confirm.
        TaskBLL.update_statistics(
            task_id=task_id,
            company_id=company_id,
            last_iteration_max=iteration,
            output__model=model.id,
        )

        # NOTE(review): the override_model_id path also reports
        # created=True although no model was created — confirm intended.
        call.result.data = {"id": model.id, "created": True}