Exemplo n.º 1
0
class EditHyperParamsRequest(TaskUpdateRequest):
    """Request model carrying hyper-parameter items to edit.

    Extends ``TaskUpdateRequest`` with the list of items and the replacement
    mode that should be applied to them.
    """

    # List of HyperParamItem entries; the Length validator demands at least one.
    hyperparams: Sequence[HyperParamItem] = ListField(
        [HyperParamItem], validators=Length(minimum_value=1))
    # Limited to the options get_options() yields for ReplaceHyperparams;
    # ReplaceHyperparams.none is used when the caller sends nothing.
    replace_hyperparams = StringField(
        validators=Enum(*get_options(ReplaceHyperparams)),
        default=ReplaceHyperparams.none,
    )
Exemplo n.º 2
0
class Artifact(models.Base):
    """API-level model describing one artifact entry.

    ``key`` and ``type`` are mandatory; every other field is optional.
    """

    key = StringField(required=True)
    type = StringField(required=True)
    # Limited to the options get_options() yields for ArtifactModes;
    # DEFAULT_ARTIFACT_MODE applies when no mode is supplied.
    mode = StringField(validators=Enum(*get_options(ArtifactModes)),
                       default=DEFAULT_ARTIFACT_MODE)
    uri = StringField()
    hash = StringField()
    content_size = IntField()
    timestamp = IntField()
    type_data = EmbeddedField(ArtifactTypeData)
    # A list whose items are themselves lists.
    display_data = ListField([list])
Exemplo n.º 3
0
class CreateUserRequest(Base):
    """Request model for creating a user.

    ``name``, ``company`` and ``email`` are required; the remaining fields
    are optional.
    """

    name = StringField(required=True)
    company = StringField(required=True)
    # set() removes duplicate Role options before they are handed to Enum;
    # Role.user is the default role.
    role = StringField(
        validators=Enum(*(set(get_options(Role)))),
        default=Role.user,
    )
    email = StringField(required=True)
    family_name = StringField()
    given_name = StringField()
    avatar = StringField()
Exemplo n.º 4
0
class Artifact(EmbeddedDocument):
    """Embedded database document describing one artifact entry.

    ``key`` and ``type`` are mandatory; every other field is optional.
    """

    key = StringField(required=True)
    type = StringField(required=True)
    # Limited to the options get_options() yields for ArtifactModes;
    # DEFAULT_ARTIFACT_MODE applies when no mode is stored.
    mode = StringField(choices=get_options(ArtifactModes),
                       default=DEFAULT_ARTIFACT_MODE)
    uri = StringField()
    hash = StringField()
    content_size = LongField()
    timestamp = LongField()
    type_data = EmbeddedDocumentField(ArtifactTypeData)
    # A list of lists whose elements may be int, float or str.
    # NOTE(review): SafeSortedListField presumably keeps the outer list
    # sorted - confirm against its implementation.
    display_data = SafeSortedListField(ListField(UnionField(
        (int, float, str))))
Exemplo n.º 5
0
def delete_models(_: APICall, company_id: str, request: DeleteModelsRequest):
    """Pull the requested models, matched by name, out of a task's model lists.

    The task is fetched for update with ``force=True``. For every option of
    ``TaskModelTypes`` the names of the requested models of that type are
    gathered; types with no requested names produce no update command.

    :return: ``{"updated": <result of task.update>}``
    """
    task = get_task_for_update(
        company_id=company_id, task_id=request.task, force=True
    )

    # One "pull ... name in [...]" command per model type that has names.
    pull_commands = {}
    for model_type in get_options(TaskModelTypes):
        names = [
            item.name for item in request.models if item.type == model_type
        ]
        if names:
            pull_commands[f"pull__models__{model_type}__name__in"] = names

    update_result = task.update(last_change=datetime.utcnow(), **pull_commands)
    return {"updated": update_result}
Exemplo n.º 6
0
class User(DbModelMixin, AuthDocument):
    """Auth document holding a user's name, role, company, credentials and email."""

    # Stored under the auth database alias; strictness comes from the
    # module-level ``strict`` value.
    meta = {"db_alias": Database.auth, "strict": strict}

    id = StringField(primary_key=True)
    name = StringField()

    created = DateTimeField()
    """ User auth entry creation time """

    validated = DateTimeField()
    """ Last validation (login) time """

    # Limited to the options get_options() yields for Role; defaults to Role.user.
    role = StringField(required=True, choices=get_options(Role), default=Role.user)
    """ User role """

    company = StringField(required=True)
    """ Company this user belongs to """

    # default=list gives every document its own fresh, empty credentials list.
    credentials = EmbeddedDocumentListField(Credentials, default=list)
    """ Credentials generated for this user """

    # Declared unique and sparse.
    email = EmailField(unique=True, sparse=True)
    """ Email uniquely identifying the user """
Exemplo n.º 7
0
 def get_company_roles(cls) -> set:
     """Return the role options of ``cls`` that are not system roles."""
     every_role = set(get_options(cls))
     system_roles = cls.get_system_roles()
     return every_role - system_roles
Exemplo n.º 8
0
class CreateRequest(TaskData):
    """Creation request extending ``TaskData`` with a mandatory name and type."""

    name = StringField(required=True)
    # Must be one of the options get_options() yields for TaskType.
    type = StringField(required=True, validators=Enum(*get_options(TaskType)))
Exemplo n.º 9
0
class ArtifactId(models.Base):
    """Identifies an artifact by its mandatory ``key`` together with a ``mode``."""

    key = StringField(required=True)
    # Limited to the options get_options() yields for ArtifactModes;
    # DEFAULT_ARTIFACT_MODE applies when no mode is supplied.
    mode = StringField(validators=Enum(*get_options(ArtifactModes)),
                       default=DEFAULT_ARTIFACT_MODE)
Exemplo n.º 10
0
    def get_project_stats(
        cls,
        company: str,
        project_ids: Sequence[str],
        specific_state: Optional[EntityVisibility] = None,
        include_children: bool = True,
        filter_: Mapping[str, Any] = None,
    ) -> Tuple[Dict[str, dict], Dict[str, dict]]:
        """Collect per-project task statistics and the projects' sub-project lists.

        :param company: company ID handed to the pipeline builder
        :param project_ids: projects to report on; an empty value short-circuits
            to two empty dicts
        :param specific_state: when given, only this visibility state is reported
        :param include_children: when False, no sub-projects are looked up
        :param filter_: forwarded unchanged to ``make_projects_get_all_pipelines``
        :return: ``(stats, children)`` where ``stats`` maps
            project id -> state value -> dict with the keys ``status_count``,
            ``total_tasks``, ``total_runtime``, ``completed_tasks_24h`` and
            ``last_task_run``, and ``children`` maps project id -> list of
            ``{"id", "name"}`` dicts sorted by name
        """
        if not project_ids:
            return {}, {}

        # Sub-projects are resolved only on request; otherwise an empty mapping
        # keeps the code below uniform.
        child_projects = (_get_sub_projects(project_ids, _only=("id", "name"))
                          if include_children else {})
        project_ids_with_children = set(project_ids) | {
            c.id
            for c in itertools.chain.from_iterable(child_projects.values())
        }
        # Both pipelines are built over the requested projects plus sub-projects.
        status_count_pipeline, runtime_pipeline = cls.make_projects_get_all_pipelines(
            company,
            project_ids=list(project_ids_with_children),
            specific_state=specific_state,
            filter_=filter_,
        )

        # Zero-filled template holding one key per TaskStatus option.
        default_counts = dict.fromkeys(get_options(TaskStatus), 0)

        def set_default_count(entry):
            """Return a copy of the zero template overlaid with ``entry``."""
            return dict(default_counts, **entry)

        status_count = defaultdict(lambda: {})
        key = itemgetter(EntityVisibility.archived.value)
        for result in Task.aggregate(status_count_pipeline):
            # groupby only merges adjacent items, so the counts are sorted by
            # the same key first.
            for k, group in groupby(sorted(result["counts"], key=key), key):
                section = (EntityVisibility.archived
                           if k else EntityVisibility.active).value
                status_count[result["_id"]][section] = set_default_count({
                    count_entry["status"]: count_entry["count"]
                    for count_entry in group
                })

        def sum_status_count(a: Mapping[str, Mapping],
                             b: Mapping[str, Mapping]) -> Dict[str, dict]:
            """Add two section -> status -> count mappings; missing entries count as 0."""
            return {
                section: {
                    status: nested_get(a, (section, status), default=0) +
                    nested_get(b, (section, status), default=0)
                    for status in set(a.get(section, {}))
                    | set(b.get(section, {}))
                }
                for section in set(a) | set(b)
            }

        # NOTE(review): aggregate_project_data presumably folds sub-project data
        # into the requested projects using ``func`` - confirm.
        status_count = cls.aggregate_project_data(
            func=sum_status_count,
            project_ids=project_ids,
            child_projects=child_projects,
            data=status_count,
        )

        # Per project id: every field of the aggregation result except "_id".
        runtime = {
            result["_id"]: {k: v
                            for k, v in result.items() if k != "_id"}
            for result in Task.aggregate(runtime_pipeline)
        }

        def sum_runtime(a: Mapping[str, Mapping],
                        b: Mapping[str, Mapping]) -> Dict[str, dict]:
            """Merge two runtime mappings key by key.

            Keys ending with ``max_task_started`` keep the larger of the two
            values (a missing or falsy value is treated as ``datetime.min``);
            every other key is summed, with a missing value counting as 0.
            """
            return {
                section: a.get(section, 0) + b.get(section, 0)
                if not section.endswith("max_task_started") else max(
                    a.get(section) or datetime.min,
                    b.get(section) or datetime.min)
                for section in set(a) | set(b)
            }

        runtime = cls.aggregate_project_data(
            func=sum_runtime,
            project_ids=project_ids,
            child_projects=child_projects,
            data=runtime,
        )

        def get_status_counts(project_id, section):
            """Build the statistics dict of one project for one section."""
            project_runtime = runtime.get(project_id, {})
            # Falls back to the shared zero template when nothing was counted.
            project_section_statuses = nested_get(status_count,
                                                  (project_id, section),
                                                  default=default_counts)

            def get_time_or_none(value):
                """Map the ``datetime.min`` placeholder to ``None``."""
                return value if value != datetime.min else None

            return {
                "status_count":
                project_section_statuses,
                "total_tasks":
                sum(project_section_statuses.values()),
                "total_runtime":
                project_runtime.get(section, 0),
                "completed_tasks_24h":
                project_runtime.get(f"{section}_recently_completed", 0),
                "last_task_run":
                get_time_or_none(
                    project_runtime.get(f"{section}_max_task_started",
                                        datetime.min)),
            }

        # All of the class's visibility states, or only the requested one.
        report_for_states = [
            s for s in cls.visibility_states
            if not specific_state or specific_state == s
        ]

        stats = {
            project: {
                task_state.value: get_status_counts(project, task_state.value)
                for task_state in report_for_states
            }
            for project in project_ids
        }

        # Sub-projects of every requested project, ordered by name.
        children = {
            project: sorted(
                [{
                    "id": c.id,
                    "name": c.name
                } for c in child_projects.get(project, [])],
                key=itemgetter("name"),
            )
            for project in project_ids
        }
        return stats, children
Exemplo n.º 11
0
class ModelItemKey(models.Base):
    """Key of a task model entry: a mandatory ``name`` plus a mandatory ``type``."""

    name = StringField(required=True)
    # Must be one of the options get_options() yields for TaskModelTypes.
    type = StringField(required=True,
                       validators=Enum(*get_options(TaskModelTypes)))
Exemplo n.º 12
0
class AddUpdateModelRequest(TaskRequest):
    """Request extending ``TaskRequest`` with a model entry.

    ``name``, ``model`` and ``type`` are required; ``iteration`` is optional.
    """

    name = StringField(required=True)
    model = StringField(required=True)
    # Must be one of the options get_options() yields for TaskModelTypes.
    type = StringField(required=True,
                       validators=Enum(*get_options(TaskModelTypes)))
    iteration = IntField()
Exemplo n.º 13
0
def check_mongo_empty() -> bool:
    """Report whether at least one of the configured databases has no collections.

    Every alias that ``utils.get_options(Database)`` yields is checked in
    turn; checking stops at the first database with an empty collection list.

    :return: True when some database has no collections, False when every
        database has at least one
    """
    # Database.collection_names() was deprecated in PyMongo 3.7 and removed in
    # PyMongo 4.0; list_collection_names() is its supported replacement.
    return not all(
        get_db(alias).list_collection_names()
        for alias in utils.get_options(Database))
Exemplo n.º 14
0
from datetime import datetime
from typing import Sequence, Union

import attr
import six

from apiserver.apierrors import errors
from apiserver.database.errors import translate_errors_context
from apiserver.database.model.project import Project
from apiserver.database.model.task.task import Task, TaskStatus, TaskSystemTags
from apiserver.database.utils import get_options
from apiserver.timing_context import TimingContext
from apiserver.utilities.attrs import typed_attrs

# Options that get_options() yields for TaskStatus.
valid_statuses = get_options(TaskStatus)
# NOTE(review): presumably a marker prepended to values of deleted entities -
# confirm against the places that use it.
deleted_prefix = "__DELETED__"


@typed_attrs
class ChangeStatusRequest(object):
    """Parameters describing one status change of a task."""

    task = attr.ib(type=Task)
    # Rejected by the validator unless the value is in valid_statuses.
    new_status = attr.ib(type=six.string_types,
                         validator=attr.validators.in_(valid_statuses))
    status_reason = attr.ib(type=six.string_types, default="")
    status_message = attr.ib(type=six.string_types, default="")
    force = attr.ib(type=bool, default=False)
    allow_same_state_transition = attr.ib(type=bool, default=True)
    current_status_override = attr.ib(default=None)

    def execute(self, **kwargs):
        # A truthy override takes precedence over the status stored on the task.
        current_status = self.current_status_override or self.task.status
Exemplo n.º 15
0
class Output(EmbeddedDocument):
    """Embedded document with ``destination``, ``error`` and ``result`` fields."""

    destination = StrippedStringField()
    error = StringField(user_set_allowed=True)
    # Limited to the options get_options() yields for Result.
    result = StringField(choices=get_options(Result))
Exemplo n.º 16
0
class TaskType(object):
    """String constants naming the known task types.

    Every attribute's value equals its own name.
    """

    training = "training"
    testing = "testing"
    inference = "inference"
    data_processing = "data_processing"
    application = "application"
    monitor = "monitor"
    controller = "controller"
    optimizer = "optimizer"
    service = "service"
    qc = "qc"
    custom = "custom"


# Set of the options that get_options() yields for TaskType.
external_task_types = set(get_options(TaskType))


class Task(AttributedDocument):
    _field_collation_overrides = {
        "execution.parameters.": AttributedDocument._numeric_locale,
        "last_metrics.": AttributedDocument._numeric_locale,
        "hyperparams.": AttributedDocument._numeric_locale,
    }

    meta = {
        "db_alias":
        Database.backend,
        "strict":
        strict,
        "indexes": [
Exemplo n.º 17
0
class Task(AttributedDocument):
    """Backend-database document describing a task.

    Declares the collection's indexes, the parameters accepted by
    "get all" queries and the task's fields.
    """

    # Field-path prefixes that use the numeric locale collation.
    _field_collation_overrides = {
        "execution.parameters.": AttributedDocument._numeric_locale,
        "last_metrics.": AttributedDocument._numeric_locale,
        "hyperparams.": AttributedDocument._numeric_locale,
    }

    meta = {
        "db_alias":
        Database.backend,
        "strict":
        strict,
        "indexes": [
            "created",
            "started",
            "completed",
            "active_duration",
            "parent",
            "project",
            "models.input.model",
            ("company", "name"),
            ("company", "user"),
            ("company", "status", "type"),
            ("company", "system_tags", "last_update"),
            ("company", "type", "system_tags", "status"),
            ("company", "project", "type", "system_tags", "status"),
            ("status", "last_update"),  # for maintenance tasks
            # Compound index that uses the numeric locale collation.
            {
                "fields": ["company", "project"],
                "collation": AttributedDocument._numeric_locale,
            },
            # Weighted text index; "$" marks a field as a text-index field.
            {
                "name":
                "%s.task.main_text_index" % Database.backend,
                "fields": [
                    "$name",
                    "$id",
                    "$comment",
                    "$models.input.model",
                    "$models.output.model",
                    "$script.repository",
                    "$script.entry_point",
                ],
                "default_language":
                "english",
                "weights": {
                    "name": 10,
                    "id": 10,
                    "comment": 10,
                    "models.output.model": 2,
                    "models.input.model": 2,
                    "script.repository": 1,
                    "script.entry_point": 1,
                },
            },
        ],
    }
    # Fields grouped by the kind of query parameter they are declared under.
    get_all_query_options = GetMixin.QueryParameterOptions(
        list_fields=(
            "id",
            "user",
            "tags",
            "system_tags",
            "type",
            "status",
            "project",
            "parent",
            "hyperparams.*",
        ),
        range_fields=("started", "active_duration", "last_metrics.*",
                      "last_iteration"),
        datetime_fields=("status_changed", "last_update"),
        pattern_fields=("name", "comment"),
    )

    id = StringField(primary_key=True)
    name = StrippedStringField(required=True,
                               user_set_allowed=True,
                               sparse=False,
                               min_length=3)

    # Limited to the options get_options() yields for TaskType.
    type = StringField(required=True, choices=get_options(TaskType))
    # Limited to the TaskStatus options; a new document starts as "created".
    status = StringField(default=TaskStatus.created,
                         choices=get_options(TaskStatus))
    status_reason = StringField()
    status_message = StringField(user_set_allowed=True)
    status_changed = DateTimeField()
    comment = StringField(user_set_allowed=True)
    created = DateTimeField(required=True, user_set_allowed=True)
    started = DateTimeField()
    completed = DateTimeField()
    published = DateTimeField()
    active_duration = IntField(default=None)
    parent = StringField(reference_field="Task")
    project = StringField(reference_field=Project, user_set_allowed=True)
    # Embedded sub-documents default to a freshly constructed instance.
    output: Output = EmbeddedDocumentField(Output, default=Output)
    execution: Execution = EmbeddedDocumentField(Execution, default=Execution)
    tags = SafeSortedListField(StringField(required=True),
                               user_set_allowed=True)
    system_tags = SafeSortedListField(StringField(required=True),
                                      user_set_allowed=True)
    script: Script = EmbeddedDocumentField(Script, default=Script)
    last_worker = StringField()
    last_worker_report = DateTimeField()
    last_update = DateTimeField()
    last_change = DateTimeField()
    last_iteration = IntField(default=DEFAULT_LAST_ITERATION)
    # Two-level map whose leaves are MetricEvent documents.
    last_metrics = SafeMapField(
        field=SafeMapField(EmbeddedDocumentField(MetricEvent)))
    metric_stats = SafeMapField(field=EmbeddedDocumentField(MetricEventStats))
    company_origin = StringField(exclude_by_default=True)
    duration = IntField()  # task duration in seconds
    # Two-level map whose leaves are ParamsItem documents.
    hyperparams = SafeMapField(
        field=SafeMapField(EmbeddedDocumentField(ParamsItem)))
    configuration = SafeMapField(
        field=EmbeddedDocumentField(ConfigurationItem))
    runtime = SafeDictField(default=dict)
    models: Models = EmbeddedDocumentField(Models, default=Models)
    container = SafeMapField(field=NullableStringField())
    # Limited to the options get_options() yields for TaskStatus.
    enqueue_status = StringField(choices=get_options(TaskStatus),
                                 exclude_by_default=True)

    def get_index_company(self) -> str:
        """
        Return the company ID used for locating indices containing task data.

        If the task has a valid company, that company ID is returned.
        Otherwise, if the task has a company_origin, the task has been made
        public and the origin company is returned.
        Otherwise an empty string is returned.
        """
        return self.company or self.company_origin or ""
Exemplo n.º 18
0
    def export_to_zip(
        cls,
        filename: str,
        experiments: Sequence[str] = None,
        projects: Sequence[str] = None,
        artifacts_path: str = None,
        task_statuses: Sequence[str] = None,
        tag_exported_entities: bool = False,
        metadata: Mapping[str, Any] = None,
    ) -> Sequence[str]:
        """Export the resolved entities into a zip file named after its content hash.

        :param filename: path of the zip file to write; the final file gets the
            hash appended to its stem
        :param experiments: forwarded to ``_resolve_entities``
        :param projects: forwarded to ``_resolve_entities``
        :param artifacts_path: when it is an existing directory and artifacts
            were exported, a second zip with the artifacts is written
        :param task_statuses: forwarded to ``_resolve_entities``; every value
            must be a ``TaskStatus`` option
        :param tag_exported_entities: forwarded to ``_export`` as ``tag_entities``
        :param metadata: when truthy, stored in the zip as JSON and mixed into
            the hash
        :return: the files created by this call, or the files of the previous
            export when ``_check_for_update`` reports no update
        :raise ValueError: if ``task_statuses`` holds an unknown status
        """
        cls._init_entity_types()

        if task_statuses:
            if not set(task_statuses).issubset(get_options(TaskStatus)):
                raise ValueError("Invalid task statuses")

        target = Path(filename)
        entities = cls._resolve_entities(
            experiments=experiments,
            projects=projects,
            task_statuses=task_statuses,
        )

        # The same digest object keeps accumulating during _export, so the
        # metadata-only hash is captured before the export starts.
        digest = hashlib.md5()
        meta_str = ""
        metadata_hash = ""
        if metadata:
            meta_str = json.dumps(metadata)
            digest.update(meta_str.encode())
            metadata_hash = digest.hexdigest()

        map_file = target.with_suffix(".map")
        updated, old_files = cls._check_for_update(
            map_file, entities=entities, metadata_hash=metadata_hash
        )
        if not updated:
            print(f"There are no updates from the last export")
            return old_files

        # Remove whatever the previous export left behind.
        for stale in map(Path, old_files):
            if stale.is_file():
                stale.unlink()

        with ZipFile(target, **cls.zip_args) as archive:
            if metadata:
                archive.writestr(cls.metadata_filename, meta_str)
            artifacts = cls._export(
                archive,
                entities=entities,
                hash_=digest,
                tag_entities=tag_exported_entities,
            )

        hashed_name = f"{target.stem}_{digest.hexdigest()}{target.suffix}"
        file_with_hash = target.with_name(hashed_name)
        target.replace(file_with_hash)
        created_files = [str(file_with_hash)]

        artifacts = cls._filter_artifacts(artifacts)
        if artifacts and artifacts_path and os.path.isdir(artifacts_path):
            artifacts_file = file_with_hash.with_suffix(cls.artifacts_ext)
            with ZipFile(artifacts_file, **cls.zip_args) as artifacts_archive:
                cls._export_artifacts(
                    artifacts_archive, artifacts, artifacts_path
                )
            created_files.append(str(artifacts_file))

        cls._write_update_file(
            map_file,
            entities=entities,
            created_files=created_files,
            metadata_hash=metadata_hash,
        )

        return created_files
Exemplo n.º 19
0
from mongoengine import Q

from apiserver.apierrors import errors
from apiserver.config_repo import config
from apiserver.database.errors import translate_errors_context
from apiserver.database.model.auth import User, Entities, Credentials
from apiserver.database.model.company import Company
from apiserver.database.utils import get_options
from apiserver.timing_context import TimingContext
from .fixed_user import FixedUser
from .identity import Identity
from .payload import Payload, Token, Basic, AuthType

# Logger obtained from the project config helper for this file.
log = config.logger(__file__)

# Set of the options that get_options() yields for Entities.
entity_keys = set(get_options(Entities))

# Read from the "apiserver.auth.verify_user_tokens" setting.
# NOTE(review): the second argument is presumably the default used when the
# setting is absent - confirm against config.get.
verify_user_tokens = config.get("apiserver.auth.verify_user_tokens", True)


def get_auth_func(auth_type):
    """Pick the authorization callable matching ``auth_type``.

    :return: ``authorize_token`` for ``AuthType.bearer_token``,
        ``authorize_credentials`` for ``AuthType.basic``
    :raise errors.unauthorized.BadAuthType: for any other value
    """
    if auth_type == AuthType.bearer_token:
        handler = authorize_token
    elif auth_type == AuthType.basic:
        handler = authorize_credentials
    else:
        raise errors.unauthorized.BadAuthType()
    return handler


def authorize_token(jwt_token, *_, **__):
    """Validate token against service/endpoint and requests data (dicts).
    Returns a parsed token object (auth payload)
Exemplo n.º 20
0
    def get_project_stats(
        cls,
        company: str,
        project_ids: Sequence[str],
        specific_state: Optional[EntityVisibility] = None,
    ) -> Tuple[Dict[str, dict], Dict[str, dict]]:
        """Collect per-project task statistics and the projects' sub-project lists.

        :param company: company ID handed to the pipeline builder
        :param project_ids: projects to report on; an empty value short-circuits
            to two empty dicts
        :param specific_state: when given, only this visibility state is reported
        :return: ``(stats, children)`` where ``stats`` maps
            project id -> state value -> ``{"total_runtime", "status_count"}``
            and ``children`` maps project id -> list of ``{"id", "name"}``
            dicts sorted by name
        """
        if not project_ids:
            return {}, {}

        child_projects = _get_sub_projects(project_ids, _only=("id", "name"))
        descendant_ids = {
            sub.id
            for sub in itertools.chain.from_iterable(child_projects.values())
        }
        all_project_ids = set(project_ids) | descendant_ids
        pipelines = cls.make_projects_get_all_pipelines(
            company,
            project_ids=list(all_project_ids),
            specific_state=specific_state,
        )
        status_count_pipeline, runtime_pipeline = pipelines

        # Zero-filled template holding one key per TaskStatus option.
        default_counts = {status: 0 for status in get_options(TaskStatus)}

        archived_flag = itemgetter(EntityVisibility.archived.value)
        status_count = defaultdict(dict)
        for result in Task.aggregate(status_count_pipeline):
            # groupby only merges adjacent items, so sort by the same key first.
            ordered_counts = sorted(result["counts"], key=archived_flag)
            for archived, entries in groupby(ordered_counts, archived_flag):
                if archived:
                    visibility = EntityVisibility.archived
                else:
                    visibility = EntityVisibility.active
                section = visibility.value
                counted = {
                    entry["status"]: entry["count"] for entry in entries
                }
                status_count[result["_id"]][section] = dict(
                    default_counts, **counted
                )

        def sum_status_count(
            a: Mapping[str, Mapping], b: Mapping[str, Mapping]
        ) -> Dict[str, dict]:
            """Add two section -> status -> count mappings; missing entries count as 0."""
            merged = {}
            for section in set(a) | set(b):
                statuses = set(a.get(section, {})) | set(b.get(section, {}))
                merged[section] = {
                    status: nested_get(a, (section, status), 0)
                    + nested_get(b, (section, status), 0)
                    for status in statuses
                }
            return merged

        status_count = cls.aggregate_project_data(
            func=sum_status_count,
            project_ids=project_ids,
            child_projects=child_projects,
            data=status_count,
        )

        # Per project id: every field of the aggregation result except "_id".
        runtime = {}
        for result in Task.aggregate(runtime_pipeline):
            runtime[result["_id"]] = {
                name: value for name, value in result.items() if name != "_id"
            }

        def sum_runtime(
            a: Mapping[str, Mapping], b: Mapping[str, Mapping]
        ) -> Dict[str, dict]:
            """Sum two runtime mappings key by key; a missing key counts as 0."""
            totals = {}
            for section in set(a) | set(b):
                totals[section] = a.get(section, 0) + b.get(section, 0)
            return totals

        runtime = cls.aggregate_project_data(
            func=sum_runtime,
            project_ids=project_ids,
            child_projects=child_projects,
            data=runtime,
        )

        def get_status_counts(project_id, section):
            """Build the statistics dict of one project for one section."""
            path = (project_id, section)
            total_runtime = nested_get(runtime, path, 0)
            counts = nested_get(status_count, path, default_counts)
            return {"total_runtime": total_runtime, "status_count": counts}

        # Every visibility state, or only the requested one.
        if specific_state:
            report_for_states = [
                s for s in EntityVisibility if specific_state == s
            ]
        else:
            report_for_states = list(EntityVisibility)

        stats = {}
        for project in project_ids:
            stats[project] = {
                state.value: get_status_counts(project, state.value)
                for state in report_for_states
            }

        # Sub-projects of every requested project, ordered by name.
        children = {}
        for project in project_ids:
            entries = [
                {"id": sub.id, "name": sub.name}
                for sub in child_projects.get(project, [])
            ]
            entries.sort(key=itemgetter("name"))
            children[project] = entries

        return stats, children