class EditHyperParamsRequest(TaskUpdateRequest):
    # Request schema for editing a task's hyperparameters.
    # At least one hyperparameter item must be provided.
    hyperparams: Sequence[HyperParamItem] = ListField(
        [HyperParamItem], validators=Length(minimum_value=1))
    # How existing hyperparams should be replaced; must be one of the
    # ReplaceHyperparams options, defaults to replacing none.
    replace_hyperparams = StringField(
        validators=Enum(*get_options(ReplaceHyperparams)),
        default=ReplaceHyperparams.none,
    )
class Artifact(models.Base):
    # API-layer schema for a task artifact entry.
    key = StringField(required=True)    # artifact name (key within the task)
    type = StringField(required=True)   # artifact type string
    # Artifact mode; restricted to the ArtifactModes options.
    mode = StringField(validators=Enum(*get_options(ArtifactModes)),
                       default=DEFAULT_ARTIFACT_MODE)
    uri = StringField()                 # storage location of the artifact
    hash = StringField()                # content hash (algorithm not visible here)
    content_size = IntField()           # presumably size in bytes — confirm with writer
    timestamp = IntField()
    type_data = EmbeddedField(ArtifactTypeData)
    display_data = ListField([list])    # rows of display values
class CreateUserRequest(Base):
    """Request schema for creating a new user."""

    name = StringField(required=True)
    company = StringField(required=True)
    # Fix: dropped the redundant set() wrapper around get_options(Role).
    # It added nothing (the options are already unique attribute values) and
    # made the validator's option order nondeterministic; this also matches
    # every other Enum(*get_options(...)) validator in the codebase.
    role = StringField(
        validators=Enum(*get_options(Role)),
        default=Role.user,
    )
    email = StringField(required=True)
    family_name = StringField()
    given_name = StringField()
    avatar = StringField()
class Artifact(EmbeddedDocument):
    # Database (mongoengine) model for a task artifact; mirrors the
    # API-layer Artifact schema.
    key = StringField(required=True)    # artifact name (key within the task)
    type = StringField(required=True)   # artifact type string
    # Artifact mode, constrained to the ArtifactModes options.
    mode = StringField(choices=get_options(ArtifactModes), default=DEFAULT_ARTIFACT_MODE)
    uri = StringField()                 # storage location
    hash = StringField()                # content hash
    content_size = LongField()          # presumably size in bytes — confirm with writer
    timestamp = LongField()
    type_data = EmbeddedDocumentField(ArtifactTypeData)
    # Sorted list of display rows; each cell may be an int, float or str.
    display_data = SafeSortedListField(ListField(UnionField((int, float, str))))
def delete_models(_: APICall, company_id: str, request: DeleteModelsRequest):
    """Detach the requested models from their task.

    Groups the requested model names by model type, issues one mongoengine
    ``pull`` operation per non-empty group, and returns the update count.
    """
    task = get_task_for_update(
        company_id=company_id, task_id=request.task, force=True
    )

    # Names of the models to remove, grouped by their declared type.
    names_by_type = {
        model_type: [
            model.name for model in request.models if model.type == model_type
        ]
        for model_type in get_options(TaskModelTypes)
    }

    # One "pull models of this type whose name is in the list" op per group.
    pull_ops = {}
    for model_type, names in names_by_type.items():
        if names:
            pull_ops[f"pull__models__{model_type}__name__in"] = names

    updated = task.update(last_change=datetime.utcnow(), **pull_ops)
    return {"updated": updated}
class User(DbModelMixin, AuthDocument):
    # Auth-database user document (stored in the auth DB alias).
    meta = {"db_alias": Database.auth, "strict": strict}

    id = StringField(primary_key=True)
    name = StringField()
    created = DateTimeField()
    """ User auth entry creation time """
    validated = DateTimeField()
    """ Last validation (login) time """
    role = StringField(required=True, choices=get_options(Role), default=Role.user)
    """ User role """
    company = StringField(required=True)
    """ Company this user belongs to """
    credentials = EmbeddedDocumentListField(Credentials, default=list)
    """ Credentials generated for this user """
    email = EmailField(unique=True, sparse=True)
    """ Email uniquely identifying the user """
def get_company_roles(cls) -> set:
    """Return every role option except the system-level roles."""
    all_roles = set(get_options(cls))
    return all_roles.difference(cls.get_system_roles())
class CreateRequest(TaskData):
    # Request schema for creating a task; name and type are mandatory and
    # the type must be one of the TaskType options.
    name = StringField(required=True)
    type = StringField(required=True, validators=Enum(*get_options(TaskType)))
class ArtifactId(models.Base):
    # API schema identifying an artifact by its key and mode.
    key = StringField(required=True)
    mode = StringField(validators=Enum(*get_options(ArtifactModes)),
                       default=DEFAULT_ARTIFACT_MODE)
def get_project_stats(
    cls,
    company: str,
    project_ids: Sequence[str],
    specific_state: Optional[EntityVisibility] = None,
    include_children: bool = True,
    filter_: Mapping[str, Any] = None,
) -> Tuple[Dict[str, dict], Dict[str, dict]]:
    """
    Compute per-project task statistics for the given projects.

    Returns a tuple of two dicts keyed by project ID:
    - stats: per visibility section (active/archived), the per-status task
      counts, total task count, total runtime, recently-completed count and
      last task start time;
    - children: the direct sub-projects of each project as sorted
      {"id", "name"} entries.
    Sub-project data is rolled up into each parent when include_children.
    """
    if not project_ids:
        return {}, {}

    # Optionally resolve sub-projects so their tasks are aggregated too.
    child_projects = (_get_sub_projects(project_ids, _only=("id", "name"))
                      if include_children else {})
    project_ids_with_children = set(project_ids) | {
        c.id for c in itertools.chain.from_iterable(child_projects.values())
    }
    status_count_pipeline, runtime_pipeline = cls.make_projects_get_all_pipelines(
        company,
        project_ids=list(project_ids_with_children),
        specific_state=specific_state,
        filter_=filter_,
    )

    # Every known task status starts at 0 so missing statuses are explicit.
    default_counts = dict.fromkeys(get_options(TaskStatus), 0)

    def set_default_count(entry):
        return dict(default_counts, **entry)

    status_count = defaultdict(lambda: {})
    key = itemgetter(EntityVisibility.archived.value)
    # Split each project's counts into archived/active sections
    # (groupby requires the pre-sort by the same key).
    for result in Task.aggregate(status_count_pipeline):
        for k, group in groupby(sorted(result["counts"], key=key), key):
            section = (EntityVisibility.archived
                       if k else EntityVisibility.active).value
            status_count[result["_id"]][section] = set_default_count({
                count_entry["status"]: count_entry["count"]
                for count_entry in group
            })

    def sum_status_count(a: Mapping[str, Mapping],
                         b: Mapping[str, Mapping]) -> Dict[str, dict]:
        # Merge two per-section status-count mappings by summing counts.
        return {
            section: {
                status: nested_get(a, (section, status), default=0)
                + nested_get(b, (section, status), default=0)
                for status in set(a.get(section, {})) | set(b.get(section, {}))
            }
            for section in set(a) | set(b)
        }

    # Roll child-project counts up into their parents.
    status_count = cls.aggregate_project_data(
        func=sum_status_count,
        project_ids=project_ids,
        child_projects=child_projects,
        data=status_count,
    )

    runtime = {
        result["_id"]: {k: v for k, v in result.items() if k != "_id"}
        for result in Task.aggregate(runtime_pipeline)
    }

    def sum_runtime(a: Mapping[str, Mapping],
                    b: Mapping[str, Mapping]) -> Dict[str, dict]:
        # Sum runtime fields, except "*max_task_started" which takes the
        # latest timestamp (None treated as datetime.min).
        return {
            section: a.get(section, 0) + b.get(section, 0)
            if not section.endswith("max_task_started")
            else max(a.get(section) or datetime.min,
                     b.get(section) or datetime.min)
            for section in set(a) | set(b)
        }

    runtime = cls.aggregate_project_data(
        func=sum_runtime,
        project_ids=project_ids,
        child_projects=child_projects,
        data=runtime,
    )

    def get_status_counts(project_id, section):
        # Assemble the stats entry for one project/visibility section.
        project_runtime = runtime.get(project_id, {})
        project_section_statuses = nested_get(
            status_count, (project_id, section), default=default_counts)

        def get_time_or_none(value):
            # datetime.min is the "never started" sentinel from sum_runtime.
            return value if value != datetime.min else None

        return {
            "status_count": project_section_statuses,
            "total_tasks": sum(project_section_statuses.values()),
            "total_runtime": project_runtime.get(section, 0),
            "completed_tasks_24h": project_runtime.get(
                f"{section}_recently_completed", 0),
            "last_task_run": get_time_or_none(
                project_runtime.get(f"{section}_max_task_started", datetime.min)),
        }

    # Report all visibility states unless a specific one was requested.
    report_for_states = [
        s for s in cls.visibility_states
        if not specific_state or specific_state == s
    ]

    stats = {
        project: {
            task_state.value: get_status_counts(project, task_state.value)
            for task_state in report_for_states
        }
        for project in project_ids
    }

    children = {
        project: sorted(
            [{
                "id": c.id,
                "name": c.name
            } for c in child_projects.get(project, [])],
            key=itemgetter("name"),
        )
        for project in project_ids
    }
    return stats, children
class ModelItemKey(models.Base):
    # Identifies a task model entry by its name and type
    # (type restricted to the TaskModelTypes options).
    name = StringField(required=True)
    type = StringField(required=True, validators=Enum(*get_options(TaskModelTypes)))
class AddUpdateModelRequest(TaskRequest):
    # Request schema for adding or updating a model entry on a task.
    name = StringField(required=True)
    model = StringField(required=True)   # presumably the model ID — confirm with caller
    type = StringField(required=True, validators=Enum(*get_options(TaskModelTypes)))
    iteration = IntField()               # optional iteration associated with the model
def check_mongo_empty() -> bool:
    """Return True when at least one configured Mongo database has no
    collections (i.e. not every database alias is populated)."""
    # any(empty) is equivalent to the original not all(non-empty).
    return any(
        not get_db(alias).collection_names()
        for alias in utils.get_options(Database)
    )
from datetime import datetime from typing import Sequence, Union import attr import six from apiserver.apierrors import errors from apiserver.database.errors import translate_errors_context from apiserver.database.model.project import Project from apiserver.database.model.task.task import Task, TaskStatus, TaskSystemTags from apiserver.database.utils import get_options from apiserver.timing_context import TimingContext from apiserver.utilities.attrs import typed_attrs valid_statuses = get_options(TaskStatus) deleted_prefix = "__DELETED__" @typed_attrs class ChangeStatusRequest(object): task = attr.ib(type=Task) new_status = attr.ib(type=six.string_types, validator=attr.validators.in_(valid_statuses)) status_reason = attr.ib(type=six.string_types, default="") status_message = attr.ib(type=six.string_types, default="") force = attr.ib(type=bool, default=False) allow_same_state_transition = attr.ib(type=bool, default=True) current_status_override = attr.ib(default=None) def execute(self, **kwargs): current_status = self.current_status_override or self.task.status
class Output(EmbeddedDocument):
    # Embedded document holding a task's output information.
    destination = StrippedStringField()          # output destination (whitespace-stripped)
    error = StringField(user_set_allowed=True)   # error text; may be set by the user
    result = StringField(choices=get_options(Result))  # constrained to Result options
class TaskType(object): training = "training" testing = "testing" inference = "inference" data_processing = "data_processing" application = "application" monitor = "monitor" controller = "controller" optimizer = "optimizer" service = "service" qc = "qc" custom = "custom" external_task_types = set(get_options(TaskType)) class Task(AttributedDocument): _field_collation_overrides = { "execution.parameters.": AttributedDocument._numeric_locale, "last_metrics.": AttributedDocument._numeric_locale, "hyperparams.": AttributedDocument._numeric_locale, } meta = { "db_alias": Database.backend, "strict": strict, "indexes": [
class Task(AttributedDocument):
    # Main task document stored in the backend database.

    # Field-name prefixes whose values are compared with a numeric locale
    # collation (so string-stored numbers sort numerically).
    _field_collation_overrides = {
        "execution.parameters.": AttributedDocument._numeric_locale,
        "last_metrics.": AttributedDocument._numeric_locale,
        "hyperparams.": AttributedDocument._numeric_locale,
    }

    meta = {
        "db_alias": Database.backend,
        "strict": strict,
        "indexes": [
            "created",
            "started",
            "completed",
            "active_duration",
            "parent",
            "project",
            "models.input.model",
            ("company", "name"),
            ("company", "user"),
            ("company", "status", "type"),
            ("company", "system_tags", "last_update"),
            ("company", "type", "system_tags", "status"),
            ("company", "project", "type", "system_tags", "status"),
            ("status", "last_update"),  # for maintenance tasks
            {
                "fields": ["company", "project"],
                "collation": AttributedDocument._numeric_locale,
            },
            # Weighted full-text index over name/id/comment/model/script fields.
            {
                "name": "%s.task.main_text_index" % Database.backend,
                "fields": [
                    "$name",
                    "$id",
                    "$comment",
                    "$models.input.model",
                    "$models.output.model",
                    "$script.repository",
                    "$script.entry_point",
                ],
                "default_language": "english",
                "weights": {
                    "name": 10,
                    "id": 10,
                    "comment": 10,
                    "models.output.model": 2,
                    "models.input.model": 2,
                    "script.repository": 1,
                    "script.entry_point": 1,
                },
            },
        ],
    }

    # Declares which fields the generic get_all queries may filter on,
    # and how (list membership, ranges, datetimes, patterns).
    get_all_query_options = GetMixin.QueryParameterOptions(
        list_fields=(
            "id",
            "user",
            "tags",
            "system_tags",
            "type",
            "status",
            "project",
            "parent",
            "hyperparams.*",
        ),
        range_fields=("started", "active_duration", "last_metrics.*",
                      "last_iteration"),
        datetime_fields=("status_changed", "last_update"),
        pattern_fields=("name", "comment"),
    )

    id = StringField(primary_key=True)
    name = StrippedStringField(required=True, user_set_allowed=True,
                               sparse=False, min_length=3)
    type = StringField(required=True, choices=get_options(TaskType))
    status = StringField(default=TaskStatus.created,
                         choices=get_options(TaskStatus))
    status_reason = StringField()
    status_message = StringField(user_set_allowed=True)
    status_changed = DateTimeField()
    comment = StringField(user_set_allowed=True)
    created = DateTimeField(required=True, user_set_allowed=True)
    started = DateTimeField()
    completed = DateTimeField()
    published = DateTimeField()
    active_duration = IntField(default=None)
    parent = StringField(reference_field="Task")
    project = StringField(reference_field=Project, user_set_allowed=True)
    output: Output = EmbeddedDocumentField(Output, default=Output)
    execution: Execution = EmbeddedDocumentField(Execution, default=Execution)
    tags = SafeSortedListField(StringField(required=True), user_set_allowed=True)
    system_tags = SafeSortedListField(StringField(required=True),
                                      user_set_allowed=True)
    script: Script = EmbeddedDocumentField(Script, default=Script)
    last_worker = StringField()
    last_worker_report = DateTimeField()
    last_update = DateTimeField()
    last_change = DateTimeField()
    last_iteration = IntField(default=DEFAULT_LAST_ITERATION)
    # metric -> variant -> last event value
    last_metrics = SafeMapField(
        field=SafeMapField(EmbeddedDocumentField(MetricEvent)))
    metric_stats = SafeMapField(field=EmbeddedDocumentField(MetricEventStats))
    # Original company for tasks that were made public; excluded by default.
    company_origin = StringField(exclude_by_default=True)
    duration = IntField()  # task duration in seconds
    # section -> name -> hyperparameter item
    hyperparams = SafeMapField(
        field=SafeMapField(EmbeddedDocumentField(ParamsItem)))
    configuration = SafeMapField(
        field=EmbeddedDocumentField(ConfigurationItem))
    runtime = SafeDictField(default=dict)
    models: Models = EmbeddedDocumentField(Models, default=Models)
    container = SafeMapField(field=NullableStringField())
    enqueue_status = StringField(choices=get_options(TaskStatus),
                                 exclude_by_default=True)

    def get_index_company(self) -> str:
        """
        Returns the company ID used for locating indices containing task data.
        In case the task has a valid company, this is the company ID.
        Otherwise, if the task has a company_origin, this is a task that has
        been made public and the origin company should be used.
        Otherwise, an empty company is used.
        """
        return self.company or self.company_origin or ""
def export_to_zip(
    cls,
    filename: str,
    experiments: Sequence[str] = None,
    projects: Sequence[str] = None,
    artifacts_path: str = None,
    task_statuses: Sequence[str] = None,
    tag_exported_entities: bool = False,
    metadata: Mapping[str, Any] = None,
) -> Sequence[str]:
    """
    Export the requested experiments/projects into a zip file.

    The output file name is suffixed with a hash of the exported content
    (plus the metadata, when given). A side ".map" file records what was
    exported so unchanged re-exports are skipped. When artifacts_path is an
    existing directory, matching artifacts are packed into a companion
    archive. Returns the list of created (or previously created) file paths.

    Raises ValueError if task_statuses contains an unknown status.
    """
    cls._init_entity_types()

    if task_statuses and not set(task_statuses).issubset(
            get_options(TaskStatus)):
        raise ValueError("Invalid task statuses")

    file = Path(filename)
    entities = cls._resolve_entities(experiments=experiments,
                                     projects=projects,
                                     task_statuses=task_statuses)

    # The hash covers the metadata (if any) and, later, the exported data.
    hash_ = hashlib.md5()
    if metadata:
        meta_str = json.dumps(metadata)
        hash_.update(meta_str.encode())
        metadata_hash = hash_.hexdigest()
    else:
        meta_str, metadata_hash = "", ""

    map_file = file.with_suffix(".map")
    updated, old_files = cls._check_for_update(map_file,
                                               entities=entities,
                                               metadata_hash=metadata_hash)
    if not updated:
        # Fix: removed the extraneous f-string prefix (no placeholders).
        print("There are no updates from the last export")
        return old_files

    # Remove the files produced by the previous export before re-exporting.
    for old in old_files:
        old_path = Path(old)
        if old_path.is_file():
            old_path.unlink()

    with ZipFile(file, **cls.zip_args) as zfile:
        if metadata:
            zfile.writestr(cls.metadata_filename, meta_str)
        artifacts = cls._export(
            zfile,
            entities=entities,
            hash_=hash_,
            tag_entities=tag_exported_entities,
        )

    # Rename the archive so the content hash is part of the file name.
    file_with_hash = file.with_name(
        f"{file.stem}_{hash_.hexdigest()}{file.suffix}")
    file.replace(file_with_hash)
    created_files = [str(file_with_hash)]

    artifacts = cls._filter_artifacts(artifacts)
    if artifacts and artifacts_path and os.path.isdir(artifacts_path):
        artifacts_file = file_with_hash.with_suffix(cls.artifacts_ext)
        with ZipFile(artifacts_file, **cls.zip_args) as zfile:
            cls._export_artifacts(zfile, artifacts, artifacts_path)
        created_files.append(str(artifacts_file))

    cls._write_update_file(
        map_file,
        entities=entities,
        created_files=created_files,
        metadata_hash=metadata_hash,
    )
    return created_files
from mongoengine import Q from apiserver.apierrors import errors from apiserver.config_repo import config from apiserver.database.errors import translate_errors_context from apiserver.database.model.auth import User, Entities, Credentials from apiserver.database.model.company import Company from apiserver.database.utils import get_options from apiserver.timing_context import TimingContext from .fixed_user import FixedUser from .identity import Identity from .payload import Payload, Token, Basic, AuthType log = config.logger(__file__) entity_keys = set(get_options(Entities)) verify_user_tokens = config.get("apiserver.auth.verify_user_tokens", True) def get_auth_func(auth_type): if auth_type == AuthType.bearer_token: return authorize_token elif auth_type == AuthType.basic: return authorize_credentials raise errors.unauthorized.BadAuthType() def authorize_token(jwt_token, *_, **__): """Validate token against service/endpoint and requests data (dicts). Returns a parsed token object (auth payload)
def get_project_stats(
    cls,
    company: str,
    project_ids: Sequence[str],
    specific_state: Optional[EntityVisibility] = None,
) -> Tuple[Dict[str, dict], Dict[str, dict]]:
    """
    Compute per-project task statistics for the given projects.

    Returns a tuple of two dicts keyed by project ID:
    - stats: per visibility section (active/archived), the per-status task
      counts and total runtime;
    - children: each project's sub-projects as sorted {"id", "name"} entries.
    Sub-project data is always rolled up into the parent projects.
    """
    if not project_ids:
        return {}, {}

    # Resolve sub-projects so their tasks are aggregated into the parents.
    child_projects = _get_sub_projects(project_ids, _only=("id", "name"))
    project_ids_with_children = set(project_ids) | {
        c.id for c in itertools.chain.from_iterable(child_projects.values())
    }
    status_count_pipeline, runtime_pipeline = cls.make_projects_get_all_pipelines(
        company,
        project_ids=list(project_ids_with_children),
        specific_state=specific_state,
    )

    # Every known task status starts at 0 so missing statuses are explicit.
    default_counts = dict.fromkeys(get_options(TaskStatus), 0)

    def set_default_count(entry):
        return dict(default_counts, **entry)

    status_count = defaultdict(lambda: {})
    key = itemgetter(EntityVisibility.archived.value)
    # Split each project's counts into archived/active sections
    # (groupby requires the pre-sort by the same key).
    for result in Task.aggregate(status_count_pipeline):
        for k, group in groupby(sorted(result["counts"], key=key), key):
            section = (
                EntityVisibility.archived if k else EntityVisibility.active
            ).value
            status_count[result["_id"]][section] = set_default_count(
                {
                    count_entry["status"]: count_entry["count"]
                    for count_entry in group
                }
            )

    def sum_status_count(
        a: Mapping[str, Mapping], b: Mapping[str, Mapping]
    ) -> Dict[str, dict]:
        # Merge two per-section status-count mappings by summing counts.
        return {
            section: {
                status: nested_get(a, (section, status), 0)
                + nested_get(b, (section, status), 0)
                for status in set(a.get(section, {})) | set(b.get(section, {}))
            }
            for section in set(a) | set(b)
        }

    # Roll child-project counts up into their parents.
    status_count = cls.aggregate_project_data(
        func=sum_status_count,
        project_ids=project_ids,
        child_projects=child_projects,
        data=status_count,
    )

    runtime = {
        result["_id"]: {k: v for k, v in result.items() if k != "_id"}
        for result in Task.aggregate(runtime_pipeline)
    }

    def sum_runtime(
        a: Mapping[str, Mapping], b: Mapping[str, Mapping]
    ) -> Dict[str, dict]:
        # Merge two runtime mappings by summing per-section values.
        return {
            section: a.get(section, 0) + b.get(section, 0)
            for section in set(a) | set(b)
        }

    runtime = cls.aggregate_project_data(
        func=sum_runtime,
        project_ids=project_ids,
        child_projects=child_projects,
        data=runtime,
    )

    def get_status_counts(project_id, section):
        # Assemble the stats entry for one project/visibility section.
        return {
            "total_runtime": nested_get(runtime, (project_id, section), 0),
            "status_count": nested_get(
                status_count, (project_id, section), default_counts
            ),
        }

    # Report all visibility states unless a specific one was requested.
    report_for_states = [
        s for s in EntityVisibility if not specific_state or specific_state == s
    ]

    stats = {
        project: {
            task_state.value: get_status_counts(project, task_state.value)
            for task_state in report_for_states
        }
        for project in project_ids
    }

    children = {
        project: sorted(
            [{"id": c.id, "name": c.name} for c in child_projects.get(project, [])],
            key=itemgetter("name"),
        )
        for project in project_ids
    }
    return stats, children