コード例 #1
0
class Task(AttributedDocument):
    """Task document: collection settings in ``meta`` plus the task fields.

    ``meta`` binds the model to the ``Database.backend`` alias and declares the
    collection indexes, including a weighted text index.
    """

    meta = {
        "db_alias": Database.backend,
        "strict": strict,
        "indexes": [
            # Single-field indexes on the lifecycle timestamps.
            "created",
            "started",
            "completed",
            {
                # Text index; the "$" prefix marks a field as text-indexed in
                # MongoEngine index specs.
                "name": "%s.task.main_text_index" % Database.backend,
                "fields": [
                    "$name",
                    "$id",
                    "$comment",
                    "$execution.model",
                    "$output.model",
                    "$script.repository",
                    "$script.entry_point",
                ],
                "default_language": "english",
                # Relative weights in text-search scoring: name/id/comment
                # matches weigh the most, the script location the least.
                "weights": {
                    "name": 10,
                    "id": 10,
                    "comment": 10,
                    "execution.model": 2,
                    "output.model": 2,
                    "script.repository": 1,
                    "script.entry_point": 1,
                },
            },
        ],
    }

    id = StringField(primary_key=True)
    # Required, at least 3 characters long.
    name = StrippedStringField(
        required=True, user_set_allowed=True, sparse=False, min_length=3
    )

    # Restricted to the TaskType / TaskStatus options; status starts as
    # TaskStatus.created.
    type = StringField(required=True, choices=get_options(TaskType))
    status = StringField(default=TaskStatus.created, choices=get_options(TaskStatus))
    status_reason = StringField()
    status_message = StringField()
    status_changed = DateTimeField()
    comment = StringField(user_set_allowed=True)
    # Lifecycle timestamps; only "created" is required.
    created = DateTimeField(required=True, user_set_allowed=True)
    started = DateTimeField()
    completed = DateTimeField()
    published = DateTimeField()
    parent = StringField()
    # Stored as a string reference to Project (reference_field=Project);
    # presumably the project id — confirm.
    project = StringField(reference_field=Project, user_set_allowed=True)
    # Embedded sub-documents; the class passed as default is called to build a
    # fresh instance for each new document.
    output = EmbeddedDocumentField(Output, default=Output)
    execution: Execution = EmbeddedDocumentField(Execution, default=Execution)
    tags = ListField(StringField(required=True), user_set_allowed=True)
    system_tags = ListField(StringField(required=True), user_set_allowed=True)
    script = EmbeddedDocumentField(Script)
    last_worker = StringField()
    last_worker_report = DateTimeField()
    last_update = DateTimeField()
    last_iteration = IntField(default=DEFAULT_LAST_ITERATION)
    # Two-level map of MetricEvent entries (presumably metric -> variant — confirm).
    last_metrics = SafeMapField(field=SafeMapField(EmbeddedDocumentField(MetricEvent)))
コード例 #2
0
ファイル: auth.py プロジェクト: shomratalon/trains-server
class User(DbModelMixin, AuthDocument):
    """User record stored under the ``Database.auth`` database alias."""

    meta = {"db_alias": Database.auth, "strict": strict}

    id = StringField(primary_key=True)
    name = StringField()

    created = DateTimeField()
    """ User auth entry creation time """

    validated = DateTimeField()
    """ Last validation (login) time """

    # Must be one of the Role options; defaults to Role.user.
    role = StringField(required=True,
                       choices=get_options(Role),
                       default=Role.user)
    """ User role """

    company = StringField(required=True)
    """ Company this user belongs to """

    # default=list is a callable, so each new user gets its own empty list.
    credentials = EmbeddedDocumentListField(Credentials, default=list)
    """ Credentials generated for this user """

    email = EmailField(unique=True, required=True)
    """ Email uniquely identifying the user """
コード例 #3
0
ファイル: tasks.py プロジェクト: yjshen1982/trains-server
class EditHyperParamsRequest(TaskRequest):
    """Request model carrying hyperparameter items and a ReplaceHyperparams mode."""

    # At least one item is required (Length(minimum_value=1)).
    hyperparams: Sequence[HyperParamItem] = ListField(
        [HyperParamItem], validators=Length(minimum_value=1))
    # One of the ReplaceHyperparams options; defaults to ReplaceHyperparams.none.
    replace_hyperparams = StringField(
        validators=Enum(*get_options(ReplaceHyperparams)),
        default=ReplaceHyperparams.none,
    )
コード例 #4
0
class Artifact(EmbeddedDocument):
    """Embedded document describing a single artifact."""

    key = StringField(required=True)
    type = StringField(required=True)
    # One of the ArtifactModes options; defaults to ArtifactModes.output.
    mode = StringField(choices=get_options(ArtifactModes), default=ArtifactModes.output)
    uri = StringField()
    hash = StringField()
    # NOTE(review): units of content_size / timestamp are not visible here — confirm.
    content_size = LongField()
    timestamp = LongField()
    type_data = EmbeddedDocumentField(ArtifactTypeData)
    # A list of rows, each row a list of int/float/str values.
    display_data = SafeSortedListField(ListField(UnionField((int, float, str))))
コード例 #5
0
class CreateUserRequest(Base):
    """Request model for creating a user.

    ``name``, ``company`` and ``email`` are required; ``role`` must be one of
    the ``Role`` options and defaults to ``Role.user``.
    """

    name = StringField(required=True)
    company = StringField(required=True)
    # Options are passed in declaration order; wrapping them in a set() made
    # their order (e.g. in validation messages) vary with the hash seed.
    role = StringField(
        validators=Enum(*get_options(Role)),
        default=Role.user,
    )
    email = StringField(required=True)
    family_name = StringField()
    given_name = StringField()
    avatar = StringField()
コード例 #6
0
    def export_to_zip(
        cls,
        filename: str,
        experiments: Sequence[str] = None,
        projects: Sequence[str] = None,
        artifacts_path: str = None,
        task_statuses: Sequence[str] = None,
        tag_exported_entities: bool = False,
    ) -> Sequence[str]:
        """Export the selected entities into a zip archive (plus an artifacts zip).

        :param filename: target archive path; the final file name gets the
            export hash appended to its stem.
        :param experiments: forwarded to ``cls._resolve_entities``.
        :param projects: forwarded to ``cls._resolve_entities``.
        :param artifacts_path: local directory with artifact files; artifacts are
            packed only when this is an existing directory.
        :param task_statuses: optional status filter; every value must be one of
            ``get_options(TaskStatus)``.
        :param tag_exported_entities: forwarded to ``cls._export`` as ``tag_entities``.
        :return: paths of the created files, or the previously created files when
            ``cls._check_for_update`` reports no changes.
        :raises ValueError: if ``task_statuses`` contains an unknown status.
        """
        if task_statuses and not set(task_statuses).issubset(
                get_options(TaskStatus)):
            raise ValueError("Invalid task statuses")

        file = Path(filename)
        entities = cls._resolve_entities(experiments=experiments,
                                         projects=projects,
                                         task_statuses=task_statuses)

        # The ".map" file records the previous export; when nothing changed,
        # hand back the files produced by that export.
        map_file = file.with_suffix(".map")
        updated, old_files = cls._check_for_update(map_file, entities)
        if not updated:
            print("There are no updates from the last export")
            return old_files

        # Outdated export files are removed before the new export is written.
        for old in old_files:
            old_path = Path(old)
            if old_path.is_file():
                old_path.unlink()

        zip_args = dict(mode="w", compression=ZIP_BZIP2)
        try:
            with ZipFile(file, **zip_args) as zfile:
                artifacts, hash_ = cls._export(zfile,
                                               entities,
                                               tag_entities=tag_exported_entities)
        except BaseException:
            # Don't leave a truncated archive behind (the old files are gone).
            if file.is_file():
                file.unlink()
            raise

        # Embed the export hash in the final archive name.
        file_with_hash = file.with_name(f"{file.stem}_{hash_}{file.suffix}")
        file.replace(file_with_hash)
        created_files = [str(file_with_hash)]

        artifacts = cls._filter_artifacts(artifacts)
        if artifacts and artifacts_path and os.path.isdir(artifacts_path):
            artifacts_file = file_with_hash.with_suffix(".artifacts")
            with ZipFile(artifacts_file, **zip_args) as zfile:
                cls._export_artifacts(zfile, artifacts, artifacts_path)
            created_files.append(str(artifacts_file))

        cls._write_update_file(map_file, entities, created_files)

        return created_files
コード例 #7
0
from datetime import datetime
from typing import TypeVar, Callable, Tuple, Sequence

import attr
import six
from boltons.dictutils import OneToOne

from apierrors import errors
from database.errors import translate_errors_context
from database.model.project import Project
from database.model.task.task import Task, TaskStatus, TaskSystemTags
from database.utils import get_options
from timing_context import TimingContext
from utilities.attrs import typed_attrs

# Every TaskStatus option; ChangeStatusRequest.new_status must be one of these.
valid_statuses = get_options(TaskStatus)


@typed_attrs
class ChangeStatusRequest(object):
    """Request to change the status of ``task`` to ``new_status``.

    ``new_status`` is checked against ``valid_statuses`` by an attrs
    ``in_`` validator when the request is constructed.
    """

    task = attr.ib(type=Task)
    new_status = attr.ib(type=six.string_types,
                         validator=attr.validators.in_(valid_statuses))
    status_reason = attr.ib(type=six.string_types, default="")
    status_message = attr.ib(type=six.string_types, default="")
    # NOTE(review): how force / allow_same_state_transition affect the
    # transition is decided in execute(), not shown in full here — confirm.
    force = attr.ib(type=bool, default=False)
    allow_same_state_transition = attr.ib(type=bool, default=True)
    current_status_override = attr.ib(default=None)

    def execute(self, **kwargs):
        # A truthy override takes precedence over the task's stored status.
        current_status = self.current_status_override or self.task.status
コード例 #8
0
ファイル: tasks.py プロジェクト: lcasassa/trains-server
class CreateRequest(TaskData):
    """Task creation request: a required name and a type from TaskType."""

    name = StringField(required=True)
    type = StringField(required=True, validators=Enum(*get_options(TaskType)))
コード例 #9
0
ファイル: auth.py プロジェクト: shomratalon/trains-server
 def get_company_roles(cls) -> set:
     """Return the role options of ``cls`` excluding its system roles."""
     every_role = set(get_options(cls))
     system_roles = cls.get_system_roles()
     return every_role - system_roles
コード例 #10
0
from mongoengine import Q

from apierrors import errors
from config import config
from database.errors import translate_errors_context
from database.model.auth import User, Entities, Credentials
from database.model.company import Company
from database.utils import get_options
from timing_context import TimingContext
from .fixed_user import FixedUser
from .identity import Identity
from .payload import Payload, Token, Basic, AuthType

log = config.logger(__file__)

# Option values of Entities, collected into a set.
entity_keys = set(get_options(Entities))

# Read from "apiserver.auth.verify_user_tokens"; True when the key is unset.
verify_user_tokens = config.get("apiserver.auth.verify_user_tokens", True)


def get_auth_func(auth_type):
    """Return the authorization function that handles ``auth_type``.

    ``AuthType.bearer_token`` maps to ``authorize_token`` and ``AuthType.basic``
    maps to ``authorize_credentials``.

    :raises errors.unauthorized.BadAuthType: for any other auth type.
    """
    dispatch = (
        (AuthType.bearer_token, authorize_token),
        (AuthType.basic, authorize_credentials),
    )
    for supported_type, auth_func in dispatch:
        if auth_type == supported_type:
            return auth_func
    raise errors.unauthorized.BadAuthType()


def authorize_token(jwt_token, *_, **__):
    """ Validate token against service/endpoint and requests data (dicts).
        Returns a parsed token object (auth payload)
コード例 #11
0
ファイル: output.py プロジェクト: yjshen1982/trains-server
class Output(EmbeddedDocument):
    """Embedded document describing an output: destination, model, error, result."""

    destination = StrippedStringField()
    # Reference to the Model document, given by name ('Model') rather than class.
    model = StringField(reference_field='Model')
    error = StringField(user_set_allowed=True)
    # Restricted to the Result options.
    result = StringField(choices=get_options(Result))
コード例 #12
0
ファイル: projects.py プロジェクト: yjshen1982/trains-server
def get_all_ex(call: APICall):
    """Fill ``call.result.data`` with the matching projects, optionally with stats.

    Request fields read from ``call.data``:

    - ``include_stats``: when falsy, only ``{"projects": [...]}`` is returned.
    - ``stats_for_state``: an ``EntityVisibility`` value restricting the stats
      to that section (defaults to ``EntityVisibility.active.value``); a falsy
      value reports every ``EntityVisibility`` section.
    - ``non_public``: forwarded to ``Project.get_many_with_join`` as
      ``allow_public=not non_public``.

    With stats, every project gets ``project["stats"][section]`` holding
    ``total_runtime`` (0 when unknown) and ``status_count`` (a count for every
    ``TaskStatus`` option, 0 for statuses with no tasks).

    :raises errors.bad_request.FieldsValueError: if ``stats_for_state`` is not a
        valid ``EntityVisibility`` value.
    """
    include_stats = call.data.get("include_stats")
    stats_for_state = call.data.get("stats_for_state",
                                    EntityVisibility.active.value)
    allow_public = not call.data.get("non_public", False)

    if stats_for_state:
        try:
            specific_state = EntityVisibility(stats_for_state)
        except ValueError:
            raise errors.bad_request.FieldsValueError(
                stats_for_state=stats_for_state)
    else:
        specific_state = None

    conform_tag_fields(call, call.data)
    with translate_errors_context(), TimingContext("mongo",
                                                   "projects_get_all"):
        projects = Project.get_many_with_join(
            company=call.identity.company,
            query_dict=call.data,
            query_options=get_all_query_options,
            allow_public=allow_public,
        )
        conform_output_tags(call, projects)

        if not include_stats:
            call.result.data = {"projects": projects}
            return

        ids = [project["id"] for project in projects]
        status_count_pipeline, runtime_pipeline = make_projects_get_all_pipelines(
            call.identity.company, ids, specific_state=specific_state)

        default_counts = dict.fromkeys(get_options(TaskStatus), 0)

        def set_default_count(entry):
            # Start from zero for every status, then overlay the reported counts.
            return dict(default_counts, **entry)

        # project id -> section -> {status: count}
        status_count = defaultdict(dict)
        archived_key = itemgetter(EntityVisibility.archived.value)
        for result in Task.aggregate(status_count_pipeline):
            # groupby only merges adjacent entries, so sort by the same key first.
            grouped = groupby(sorted(result["counts"], key=archived_key),
                              archived_key)
            for is_archived, group in grouped:
                section = (EntityVisibility.archived
                           if is_archived else EntityVisibility.active).value
                status_count[result["_id"]][section] = set_default_count({
                    count_entry["status"]: count_entry["count"]
                    for count_entry in group
                })

        # project id -> the remaining aggregate fields (looked up per section below)
        runtime = {
            result["_id"]: {k: v
                            for k, v in result.items() if k != "_id"}
            for result in Task.aggregate(runtime_pipeline)
        }

    def lookup(stats, project_id, section, default):
        # Plain nested lookups (rather than "id/section" path strings) so ids
        # are never split on "/" or interpreted as glob patterns.
        return stats.get(project_id, {}).get(section, default)

    def get_status_counts(project_id, section):
        counts = lookup(status_count, project_id, section, None)
        return {
            "total_runtime": lookup(runtime, project_id, section, 0),
            # A fresh dict per project, so results never share one mutable default.
            "status_count": dict(default_counts) if counts is None else counts,
        }

    report_for_states = [
        s for s in EntityVisibility
        if not specific_state or specific_state == s
    ]

    for project in projects:
        project["stats"] = {
            task_state.value: get_status_counts(project["id"],
                                                task_state.value)
            for task_state in report_for_states
        }

    call.result.data = {"projects": projects}
コード例 #13
0
    def export_to_zip(
        cls,
        filename: str,
        experiments: Sequence[str] = None,
        projects: Sequence[str] = None,
        artifacts_path: str = None,
        task_statuses: Sequence[str] = None,
        tag_exported_entities: bool = False,
        metadata: Mapping[str, Any] = None,
    ) -> Sequence[str]:
        """Export the selected entities (and optional metadata) into a zip archive.

        :param filename: target archive path; the final file name gets the md5
            of the exported content (metadata included) appended to its stem.
        :param experiments: forwarded to ``cls._resolve_entities``.
        :param projects: forwarded to ``cls._resolve_entities``.
        :param artifacts_path: local directory with artifact files; artifacts are
            packed only when this is an existing directory.
        :param task_statuses: optional status filter; every value must be one of
            ``get_options(TaskStatus)``.
        :param tag_exported_entities: forwarded to ``cls._export`` as ``tag_entities``.
        :param metadata: optional JSON-serializable mapping written to
            ``cls.metadata_filename`` inside the archive; its hash is part of
            the change detection.
        :return: paths of the created files, or the previously created files when
            ``cls._check_for_update`` reports no changes.
        :raises ValueError: if ``task_statuses`` contains an unknown status.
        """
        if task_statuses and not set(task_statuses).issubset(get_options(TaskStatus)):
            raise ValueError("Invalid task statuses")

        file = Path(filename)
        entities = cls._resolve_entities(
            experiments=experiments, projects=projects, task_statuses=task_statuses
        )

        # One md5 accumulates the metadata first and then (inside _export) the
        # exported content; the metadata-only digest is taken before that.
        hash_ = hashlib.md5()
        if metadata:
            meta_str = json.dumps(metadata)
            hash_.update(meta_str.encode())
            metadata_hash = hash_.hexdigest()
        else:
            meta_str, metadata_hash = "", ""

        map_file = file.with_suffix(".map")
        updated, old_files = cls._check_for_update(
            map_file, entities=entities, metadata_hash=metadata_hash
        )
        if not updated:
            print("There are no updates from the last export")
            return old_files

        # Outdated export files are removed before the new export is written.
        for old in old_files:
            old_path = Path(old)
            if old_path.is_file():
                old_path.unlink()

        try:
            with ZipFile(file, **cls.zip_args) as zfile:
                if metadata:
                    zfile.writestr(cls.metadata_filename, meta_str)
                artifacts = cls._export(
                    zfile,
                    entities=entities,
                    hash_=hash_,
                    tag_entities=tag_exported_entities,
                )
        except BaseException:
            # Don't leave a truncated archive behind (the old files are gone).
            if file.is_file():
                file.unlink()
            raise

        file_with_hash = file.with_name(f"{file.stem}_{hash_.hexdigest()}{file.suffix}")
        file.replace(file_with_hash)
        created_files = [str(file_with_hash)]

        artifacts = cls._filter_artifacts(artifacts)
        if artifacts and artifacts_path and os.path.isdir(artifacts_path):
            artifacts_file = file_with_hash.with_suffix(cls.artifacts_ext)
            with ZipFile(artifacts_file, **cls.zip_args) as zfile:
                cls._export_artifacts(zfile, artifacts, artifacts_path)
            created_files.append(str(artifacts_file))

        cls._write_update_file(
            map_file,
            entities=entities,
            created_files=created_files,
            metadata_hash=metadata_hash,
        )

        return created_files
コード例 #14
0
class Task(AttributedDocument):
    """Task document: collection settings in ``meta`` plus the task fields."""

    meta = {
        'db_alias':
        Database.backend,
        'strict':
        strict,
        'indexes': [
            # Single-field indexes on the lifecycle timestamps.
            'created',
            'started',
            'completed',
            {
                # Text index; the "$" prefix marks a field as text-indexed in
                # MongoEngine index specs.
                'name':
                '%s.task.main_text_index' % Database.backend,
                'fields': [
                    '$name',
                    '$id',
                    '$comment',
                    '$execution.model',
                    '$output.model',
                    '$script.repository',
                    '$script.entry_point',
                ],
                'default_language':
                'english',
                # Relative weights in text-search scoring: name/id/comment
                # matches weigh the most, the script location the least.
                'weights': {
                    'name': 10,
                    'id': 10,
                    'comment': 10,
                    'execution.model': 2,
                    'output.model': 2,
                    'script.repository': 1,
                    'script.entry_point': 1,
                },
            },
        ],
    }

    id = StringField(primary_key=True)
    # Required, at least 3 characters long.
    name = StrippedStringField(required=True,
                               user_set_allowed=True,
                               sparse=False,
                               min_length=3)

    # Restricted to the TaskType / TaskStatus options; status starts as
    # TaskStatus.created.
    type = StringField(required=True, choices=get_options(TaskType))
    status = StringField(default=TaskStatus.created,
                         choices=get_options(TaskStatus))
    status_reason = StringField()
    status_message = StringField()
    status_changed = DateTimeField()
    comment = StringField(user_set_allowed=True)
    # Lifecycle timestamps; only "created" is required.
    created = DateTimeField(required=True, user_set_allowed=True)
    started = DateTimeField()
    completed = DateTimeField()
    published = DateTimeField()
    parent = StringField()
    # Stored as a string reference to Project (reference_field=Project);
    # presumably the project id — confirm.
    project = StringField(reference_field=Project, user_set_allowed=True)
    # Embedded sub-documents; the class passed as default is called to build a
    # fresh instance for each new document.
    output = EmbeddedDocumentField(Output, default=Output)
    execution: Execution = EmbeddedDocumentField(Execution, default=Execution)
    tags = ListField(StringField(required=True), user_set_allowed=True)
    script = EmbeddedDocumentField(Script)
    last_update = DateTimeField()
    last_iteration = IntField(default=DEFAULT_LAST_ITERATION)
    # Two-level map of MetricEvent entries (presumably metric -> variant — confirm).
    last_metrics = SafeMapField(
        field=SafeMapField(EmbeddedDocumentField(MetricEvent)))
コード例 #15
0
class Task(AttributedDocument):
    """Task document: collection settings in ``meta`` plus the task fields."""

    # MongoDB collation with numericOrdering enabled: digit sequences compare
    # by numeric value ("10" sorts after "9").
    _numeric_locale = {"locale": "en_US", "numericOrdering": True}
    # Field-path prefixes mapped to the collation above. NOTE(review): the
    # consumer of this mapping is not shown here; presumably the base class
    # applies it when sorting on fields under these prefixes — confirm.
    _field_collation_overrides = {
        "execution.parameters.": _numeric_locale,
        "last_metrics.": _numeric_locale,
        "hyperparams.": _numeric_locale,
        "configuration.": _numeric_locale,
    }

    meta = {
        "db_alias": Database.backend,
        "strict": strict,
        "indexes": [
            # Single-field indexes.
            "created",
            "started",
            "completed",
            "parent",
            "project",
            # Compound indexes: each tuple declares one index over those fields.
            ("company", "name"),
            ("company", "user"),
            ("company", "type", "system_tags", "status"),
            ("company", "project", "type", "system_tags", "status"),
            ("status", "last_update"),  # for maintenance tasks
            {
                # Text index; the "$" prefix marks a field as text-indexed in
                # MongoEngine index specs.
                "name": "%s.task.main_text_index" % Database.backend,
                "fields": [
                    "$name",
                    "$id",
                    "$comment",
                    "$execution.model",
                    "$output.model",
                    "$script.repository",
                    "$script.entry_point",
                ],
                "default_language": "english",
                # Relative weights in text-search scoring: name/id/comment
                # matches weigh the most, the script location the least.
                "weights": {
                    "name": 10,
                    "id": 10,
                    "comment": 10,
                    "execution.model": 2,
                    "output.model": 2,
                    "script.repository": 1,
                    "script.entry_point": 1,
                },
            },
        ],
    }
    # Query-parameter handling for GetMixin-based queries, grouped by field
    # kind; how each group is interpreted is defined by
    # GetMixin.QueryParameterOptions (not shown here).
    get_all_query_options = GetMixin.QueryParameterOptions(
        list_fields=("id", "user", "tags", "system_tags", "type", "status", "project"),
        datetime_fields=("status_changed",),
        pattern_fields=("name", "comment"),
        fields=("parent",),
    )

    id = StringField(primary_key=True)
    # Required, at least 3 characters long.
    name = StrippedStringField(
        required=True, user_set_allowed=True, sparse=False, min_length=3
    )

    # Restricted to the TaskType / TaskStatus options; status starts as
    # TaskStatus.created.
    type = StringField(required=True, choices=get_options(TaskType))
    status = StringField(default=TaskStatus.created, choices=get_options(TaskStatus))
    status_reason = StringField()
    status_message = StringField()
    status_changed = DateTimeField()
    comment = StringField(user_set_allowed=True)
    # Lifecycle timestamps; only "created" is required.
    created = DateTimeField(required=True, user_set_allowed=True)
    started = DateTimeField()
    completed = DateTimeField()
    published = DateTimeField()
    parent = StringField()
    # Stored as a string reference to Project (reference_field=Project);
    # presumably the project id — confirm.
    project = StringField(reference_field=Project, user_set_allowed=True)
    # Embedded sub-documents; the class passed as default is called to build a
    # fresh instance for each new document.
    output: Output = EmbeddedDocumentField(Output, default=Output)
    execution: Execution = EmbeddedDocumentField(Execution, default=Execution)
    tags = SafeSortedListField(StringField(required=True), user_set_allowed=True)
    system_tags = SafeSortedListField(StringField(required=True), user_set_allowed=True)
    script: Script = EmbeddedDocumentField(Script, default=Script)
    last_worker = StringField()
    last_worker_report = DateTimeField()
    last_update = DateTimeField()
    last_iteration = IntField(default=DEFAULT_LAST_ITERATION)
    # Two-level map of MetricEvent entries (presumably metric -> variant — confirm).
    last_metrics = SafeMapField(field=SafeMapField(EmbeddedDocumentField(MetricEvent)))
    metric_stats = SafeMapField(field=EmbeddedDocumentField(MetricEventStats))
    # NOTE(review): exclude_by_default=True presumably omits this field from
    # query results unless explicitly requested — confirm.
    company_origin = StringField(exclude_by_default=True)
    duration = IntField()  # task duration in seconds
    # Two-level map of ParamsItem entries (presumably section -> name — confirm).
    hyperparams = SafeMapField(field=SafeMapField(EmbeddedDocumentField(ParamsItem)))
    configuration = SafeMapField(field=EmbeddedDocumentField(ConfigurationItem))
    # default=dict is a callable, so each new task gets its own empty dict.
    runtime = SafeDictField(default=dict)
コード例 #16
0
class TaskType(object):
    """String constants naming the task types.

    NOTE(review): presumably get_options(TaskType) collects these values in
    declaration order — confirm against database.utils.get_options before
    reordering or adding non-constant attributes.
    """

    training = "training"
    testing = "testing"
    inference = "inference"
    data_processing = "data_processing"
    application = "application"
    monitor = "monitor"
    controller = "controller"
    optimizer = "optimizer"
    service = "service"
    qc = "qc"
    custom = "custom"


# All TaskType option values, as a set.
external_task_types = set(get_options(TaskType))


class Task(AttributedDocument):
    _numeric_locale = {"locale": "en_US", "numericOrdering": True}
    _field_collation_overrides = {
        "execution.parameters.": _numeric_locale,
        "last_metrics.": _numeric_locale,
        "hyperparams.": _numeric_locale,
        "configuration.": _numeric_locale,
    }

    meta = {
        "db_alias": Database.backend,
        "strict": strict,
        "indexes": [