示例#1
0
class Script(EmbeddedDocument, ProperDictMixin):
    """
    Embedded document describing the code that a task runs.

    ``Task`` embeds an instance of this class as its ``script`` field, and
    ``script.repository`` / ``script.entry_point`` are part of the ``Task``
    text index. Keep the declaration order as-is: the document metaclass
    (presumably mongoengine's) records fields in declaration order.
    """

    # Presumably the interpreter/executable used to run the script;
    # defaults to "python".
    binary = StringField(default="python", strip=True)
    # Repository location; defaults to "" and is text-searchable on Task.
    repository = StringField(default="", strip=True)
    # Presumably version-control identifiers. Which one takes precedence
    # when several are set is not decided in this class.
    tag = StringField(strip=True)
    branch = StringField(strip=True)
    version_num = StringField(strip=True)
    # Entry point; defaults to "" and is text-searchable on Task.
    entry_point = StringField(default="", strip=True)
    working_dir = StringField(strip=True)
    # Free-form dict; its expected schema is not defined in this class.
    requirements = SafeDictField()
    # Unlike the other string fields, not declared with strip=True, so the
    # text is kept verbatim (presumably a source diff -- confirm with callers).
    diff = StringField()
示例#2
0
class Execution(EmbeddedDocument, ProperDictMixin):
    """
    Embedded document holding a task's execution details.

    ``Task`` embeds an instance of this class as its ``execution`` field.
    It groups the test split, parameters, model description/labels,
    framework, artifacts and the queue. ``Task`` applies its numeric-aware
    collation override to fields under ``execution.parameters.``.
    """

    # ``strict`` is a module-level setting defined outside this view.
    meta = {"strict": strict}
    test_split = IntField(default=0)
    # ``default=dict`` is a callable default, presumably invoked per document
    # so instances do not share one dict.
    parameters = SafeDictField(default=dict)
    # Map whose values are StringFields declared with default="".
    model_desc = SafeMapField(StringField(default=""))
    model_labels = ModelLabels()
    framework = StringField()
    # Map of embedded Artifact documents keyed by string (per the annotation).
    artifacts: Dict[str, Artifact] = SafeMapField(
        field=EmbeddedDocumentField(Artifact))
    queue = StringField(reference_field="Queue")
    """ Queue ID where task was queued """
示例#3
0
class Task(AttributedDocument):
    """
    Database document describing a task.

    Besides the field declarations, the class body defines:
    - per-prefix collation overrides (``_field_collation_overrides``),
    - the collection's indexes (``meta["indexes"]``),
    - the filter options accepted by "get all" queries
      (``get_all_query_options``).

    Field flags such as ``user_set_allowed`` and ``exclude_by_default`` are
    interpreted outside this file. Presumably ``user_set_allowed`` marks a
    user-settable field and ``exclude_by_default`` omits the field from
    default query projections -- confirm.
    """

    # Fields under these dotted prefixes use AttributedDocument._numeric_locale
    # collation (presumably numeric-aware ordering for user-defined values such
    # as parameters, metrics and hyperparameters -- confirm in the base class).
    _field_collation_overrides = {
        "execution.parameters.": AttributedDocument._numeric_locale,
        "last_metrics.": AttributedDocument._numeric_locale,
        "hyperparams.": AttributedDocument._numeric_locale,
    }

    meta = {
        "db_alias":
        Database.backend,
        # ``strict`` is a module-level setting defined outside this view.
        "strict":
        strict,
        "indexes": [
            # Single-field indexes.
            "created",
            "started",
            "completed",
            "active_duration",
            "parent",
            "project",
            "models.input.model",
            # Compound indexes, mostly scoped by company.
            ("company", "name"),
            ("company", "user"),
            ("company", "status", "type"),
            ("company", "system_tags", "last_update"),
            ("company", "type", "system_tags", "status"),
            ("company", "project", "type", "system_tags", "status"),
            ("status", "last_update"),  # for maintenance tasks
            # Company/project index built with the numeric-aware collation.
            {
                "fields": ["company", "project"],
                "collation": AttributedDocument._numeric_locale,
            },
            # Weighted free-text index: name/id/comment matches weigh 10,
            # model references 2, script repository/entry point 1.
            {
                "name":
                "%s.task.main_text_index" % Database.backend,
                "fields": [
                    "$name",
                    "$id",
                    "$comment",
                    "$models.input.model",
                    "$models.output.model",
                    "$script.repository",
                    "$script.entry_point",
                ],
                "default_language":
                "english",
                "weights": {
                    "name": 10,
                    "id": 10,
                    "comment": 10,
                    "models.output.model": 2,
                    "models.input.model": 2,
                    "script.repository": 1,
                    "script.entry_point": 1,
                },
            },
        ],
    }
    # Which fields accept list-of-values, range, datetime and pattern filters
    # in "get all" queries (the filter semantics live in GetMixin).
    get_all_query_options = GetMixin.QueryParameterOptions(
        list_fields=(
            "id",
            "user",
            "tags",
            "system_tags",
            "type",
            "status",
            "project",
            "parent",
            "hyperparams.*",
        ),
        range_fields=("started", "active_duration", "last_metrics.*",
                      "last_iteration"),
        datetime_fields=("status_changed", "last_update"),
        pattern_fields=("name", "comment"),
    )

    id = StringField(primary_key=True)
    # Required, at least 3 characters long.
    name = StrippedStringField(required=True,
                               user_set_allowed=True,
                               sparse=False,
                               min_length=3)

    # Restricted to the options of TaskType.
    type = StringField(required=True, choices=get_options(TaskType))
    # New tasks start as TaskStatus.created; restricted to TaskStatus options.
    status = StringField(default=TaskStatus.created,
                         choices=get_options(TaskStatus))
    status_reason = StringField()
    status_message = StringField(user_set_allowed=True)
    status_changed = DateTimeField()
    comment = StringField(user_set_allowed=True)
    created = DateTimeField(required=True, user_set_allowed=True)
    started = DateTimeField()
    completed = DateTimeField()
    published = DateTimeField()
    active_duration = IntField(default=None)
    # Refers to Task by name as a string: the name ``Task`` is not bound yet
    # while its own class body executes.
    parent = StringField(reference_field="Task")
    project = StringField(reference_field=Project, user_set_allowed=True)
    # Embedded sub-documents. Passing the class itself as ``default`` presumably
    # creates a fresh instance per task.
    output: Output = EmbeddedDocumentField(Output, default=Output)
    execution: Execution = EmbeddedDocumentField(Execution, default=Execution)
    tags = SafeSortedListField(StringField(required=True),
                               user_set_allowed=True)
    system_tags = SafeSortedListField(StringField(required=True),
                                      user_set_allowed=True)
    script: Script = EmbeddedDocumentField(Script, default=Script)
    last_worker = StringField()
    last_worker_report = DateTimeField()
    last_update = DateTimeField()
    last_change = DateTimeField()
    last_iteration = IntField(default=DEFAULT_LAST_ITERATION)
    # Two-level map of MetricEvent (presumably metric -> variant -> event).
    last_metrics = SafeMapField(
        field=SafeMapField(EmbeddedDocumentField(MetricEvent)))
    metric_stats = SafeMapField(field=EmbeddedDocumentField(MetricEventStats))
    # Origin company of a task that has been made public; consulted by
    # get_index_company when ``company`` is empty.
    company_origin = StringField(exclude_by_default=True)
    duration = IntField()  # task duration in seconds
    # Two-level map of ParamsItem (presumably section -> parameter name).
    hyperparams = SafeMapField(
        field=SafeMapField(EmbeddedDocumentField(ParamsItem)))
    configuration = SafeMapField(
        field=EmbeddedDocumentField(ConfigurationItem))
    runtime = SafeDictField(default=dict)
    models: Models = EmbeddedDocumentField(Models, default=Models)
    # Map whose values are NullableStringField (presumably allowing None).
    container = SafeMapField(field=NullableStringField())
    # Restricted to TaskStatus options. Presumably the status recorded around
    # enqueueing -- confirm with callers.
    enqueue_status = StringField(choices=get_options(TaskStatus),
                                 exclude_by_default=True)

    def get_index_company(self) -> str:
        """
        Return the company ID used for locating indices containing task data.

        Resolution order (the first truthy value wins):
        1. ``self.company``: the task's own company. It is not declared in
           this class, so it is presumably inherited from AttributedDocument.
        2. ``self.company_origin``: set when the task has been made public;
           in that case the origin company should be used.
        3. ``""``: an empty company.

        :return: the company ID, or an empty string if neither is set.
        """
        return self.company or self.company_origin or ""
示例#4
0
class Model(DbModelMixin, Document):
    """
    Database document describing a model.

    The class body declares the model schema, the collection's indexes
    (``meta["indexes"]``) and the filter options accepted by "get all"
    queries (``get_all_query_options``).

    Field flags such as ``user_set_allowed`` and ``exclude_by_default`` are
    interpreted outside this file. Presumably ``user_set_allowed`` marks a
    user-settable field and ``exclude_by_default`` omits the field from
    default query projections -- confirm.
    """

    meta = {
        "db_alias":
        Database.backend,
        # ``strict`` is a module-level setting defined outside this view.
        "strict":
        strict,
        "indexes": [
            "parent",
            "project",
            "task",
            "last_update",
            # Indexes on the key/type of entries in the ``metadata`` list.
            "metadata.key",
            "metadata.type",
            ("company", "framework"),
            ("company", "name"),
            ("company", "user"),
            # Weighted free-text index: name/id/comment weigh 10, parent 5,
            # task and project 3.
            {
                "name":
                "%s.model.main_text_index" % Database.backend,
                "fields":
                ["$name", "$id", "$comment", "$parent", "$task", "$project"],
                "default_language":
                "english",
                "weights": {
                    "name": 10,
                    "id": 10,
                    "comment": 10,
                    "parent": 5,
                    "task": 3,
                    "project": 3,
                },
            },
        ],
    }
    # Which fields accept pattern, plain, list-of-values and datetime filters
    # in "get all" queries (the filter semantics live in GetMixin).
    get_all_query_options = GetMixin.QueryParameterOptions(
        pattern_fields=("name", "comment"),
        fields=("ready", ),
        list_fields=(
            "tags",
            "system_tags",
            "framework",
            "uri",
            "id",
            "user",
            "project",
            "task",
            "parent",
        ),
        datetime_fields=("last_update", ),
    )

    id = StringField(primary_key=True)
    # Optional here (unlike Task.name), but at least 3 characters when set.
    name = StrippedStringField(user_set_allowed=True, min_length=3)
    # Refers to Model by name as a string: the name ``Model`` is not bound yet
    # while its own class body executes.
    parent = StringField(reference_field="Model", required=False)
    user = StringField(required=True, reference_field=User)
    company = StringField(required=True, reference_field=Company)
    project = StringField(reference_field=Project, user_set_allowed=True)
    created = DateTimeField(required=True, user_set_allowed=True)
    task = StringField(reference_field=Task)
    comment = StringField(user_set_allowed=True)
    tags = SafeSortedListField(StringField(required=True),
                               user_set_allowed=True)
    system_tags = SafeSortedListField(StringField(required=True),
                                      user_set_allowed=True)
    uri = StrippedStringField(default="", user_set_allowed=True)
    framework = StringField()
    # Free-form dict; its schema is not defined in this class.
    design = SafeDictField()
    labels = ModelLabels()
    # Required and without a default, so every model must set it explicitly.
    ready = BooleanField(required=True)
    last_update = DateTimeField()
    # Presumably a UI-side cache; excluded by default (see class docstring).
    ui_cache = SafeDictField(default=dict,
                             user_set_allowed=True,
                             exclude_by_default=True)
    company_origin = StringField(exclude_by_default=True)
    # List of embedded MetadataItem documents; ``default=list`` is a callable
    # default (presumably a fresh list per document).
    metadata: Sequence[MetadataItem] = EmbeddedDocumentListField(
        MetadataItem, default=list, user_set_allowed=True)