コード例 #1
0
ファイル: models.py プロジェクト: paklau99988/trains-server
def create(call, company, req_model):
    """Create and save a new model, reporting its id in the call result.

    A public model is stored with an empty company id. A referenced project
    is validated against that company id, and a referenced task is validated
    from the request data before the model is built.
    """
    assert isinstance(call, APICall)
    assert isinstance(req_model, CreateModelRequest)
    identity = call.identity

    owner_company = "" if req_model.public else company

    with translate_errors_context():

        project_id = req_model.project
        if project_id:
            validate_id(Project, company=owner_company, project=project_id)

        task_id = req_model.task
        request_data = req_model.to_struct()
        if task_id:
            validate_task(call, request_data)

        model_fields = filter_fields(Model, request_data)
        conform_tag_fields(call, model_fields)

        new_model = Model(
            id=database.utils.id(),
            user=identity.user,
            company=owner_company,
            created=datetime.utcnow(),
            **model_fields,
        )
        new_model.save()

        response = CreateModelResponse(id=new_model.id, created=True)
        call.result.data_model = response
コード例 #2
0
ファイル: queues.py プロジェクト: paklau99988/trains-server
def get_all(call: APICall):
    """Return the company's queues that match the request query.

    Tag fields are conformed on the request before querying and on the
    returned queues afterwards.
    """
    conform_tag_fields(call, call.data)
    found_queues = queue_bll.get_all(
        company_id=call.identity.company,
        query_dict=call.data,
    )
    conform_output_tags(call, found_queues)
    call.result.data = dict(queues=found_queues)
コード例 #3
0
ファイル: models.py プロジェクト: yjshen1982/trains-server
def prepare_update_fields(call, company_id, fields: dict):
    """Validate and normalize the fields of a model update request.

    Works on a shallow copy; the caller's dict is not modified.

    :param call: the API call, used when conforming tag fields
    :param company_id: company id passed to ``validate_task`` when a task is set
    :param fields: the requested field updates
    :return: the prepared copy of ``fields``
    :raises errors.bad_request.ValidationError: if ``labels`` is not a dict,
        has non-string keys, or has non-integer values (booleans rejected)
    """
    fields = fields.copy()
    if "uri" in fields:
        # clear UI cache if URI is provided (model updated), unless the
        # request explicitly carries a new ui_cache
        fields["ui_cache"] = fields.pop("ui_cache", {})
    if "task" in fields:
        validate_task(company_id, fields)

    if "labels" in fields:
        labels = fields["labels"]
        # Report a validation error instead of crashing on `.keys()` when
        # labels is e.g. null or a list.
        if not isinstance(labels, dict):
            raise errors.bad_request.ValidationError(
                "labels must be a dictionary", labels=labels)

        def find_invalid(iterable, is_valid):
            """Return the items failing ``is_valid``, as a set when hashable."""
            res = [x for x in iterable if not is_valid(x)]
            try:
                return set(res)
            except TypeError:
                # Un-hashable items (e.g. list or dict values)
                return res

        invalid_keys = find_invalid(labels.keys(),
                                    lambda key: isinstance(key, str))
        if invalid_keys:
            raise errors.bad_request.ValidationError(
                "labels keys must be strings", keys=invalid_keys)

        # bool is a subclass of int, so exclude it explicitly; otherwise
        # true/false would be accepted as integer label values.
        invalid_values = find_invalid(
            labels.values(),
            lambda value: isinstance(value, int) and not isinstance(value, bool),
        )
        if invalid_values:
            raise errors.bad_request.ValidationError(
                "labels values must be integers", values=invalid_values)

    conform_tag_fields(call, fields, validate=True)
    return fields
コード例 #4
0
ファイル: projects.py プロジェクト: yjshen1982/trains-server
def update(call: APICall):
    """
    update

    :summary: Update project information.
              See `project.create` for parameters.
    :return: updated - `int` - number of projects updated
             fields - `[string]` - updated fields
    """
    project_id = call.data["project"]

    with translate_errors_context():
        target = Project.get_for_writing(
            company=call.identity.company, id=project_id
        )
        if not target:
            raise errors.bad_request.InvalidProjectId(id=project_id)

        # None values are kept so they can be written to the project
        update_fields = parse_from_call(
            call.data,
            create_fields,
            Project.get_fields(),
            discard_none_values=False,
        )
        conform_tag_fields(call, update_fields, validate=True)
        update_fields["last_update"] = datetime.utcnow()

        with TimingContext("mongo", "projects_update"):
            updated_count = target.update(upsert=False, **update_fields)

        conform_output_tags(call, update_fields)
        response = UpdateResponse(updated=updated_count, fields=update_fields)
        call.result.data_model = response
コード例 #5
0
ファイル: models.py プロジェクト: yjshen1982/trains-server
def create(call: APICall, company_id, req_model: CreateModelRequest):
    """Create and save a new model, reporting its id in the call result.

    A public model is stored with an empty company id. A referenced project
    and task are validated first. After saving, ``_update_cached_tags`` is
    called with the model's project and fields.
    """
    owner_company = "" if req_model.public else company_id

    with translate_errors_context():

        project_id = req_model.project
        if project_id:
            validate_id(Project, company=owner_company, project=project_id)

        task_id = req_model.task
        request_data = req_model.to_struct()
        if task_id:
            validate_task(owner_company, request_data)

        model_fields = filter_fields(Model, request_data)
        conform_tag_fields(call, model_fields, validate=True)

        new_model = Model(
            id=database.utils.id(),
            user=call.identity.user,
            company=owner_company,
            created=datetime.utcnow(),
            **model_fields,
        )
        new_model.save()
        _update_cached_tags(
            owner_company, project=new_model.project, fields=model_fields
        )

        response = CreateModelResponse(id=new_model.id, created=True)
        call.result.data_model = response
コード例 #6
0
ファイル: queues.py プロジェクト: paklau99988/trains-server
def update(call: APICall, company_id, req_model: UpdateRequest):
    """Apply a partial update to a queue and report what changed.

    Only the fields present in the request are forwarded to
    ``queue_bll.update``; tag fields are conformed on the way in and out.
    """
    partial_data = call.data_model_for_partial_update
    conform_tag_fields(call, partial_data)

    update_result = queue_bll.update(
        company_id=company_id, queue_id=req_model.queue, **partial_data
    )
    updated_count, changed_fields = update_result

    conform_output_tags(call, changed_fields)
    call.result.data_model = UpdateResponse(
        updated=updated_count, fields=changed_fields
    )
コード例 #7
0
ファイル: models.py プロジェクト: yjshen1982/trains-server
def get_all_ex(call: APICall, company_id, _):
    """Return models matching the query, fetched via ``Model.get_many_with_join``.

    The query is run with ``allow_public=True``; tag fields are conformed on
    the request and on the returned models.
    """
    conform_tag_fields(call, call.data)
    with translate_errors_context():
        with TimingContext("mongo", "models_get_all_ex"):
            found_models = Model.get_many_with_join(
                company=company_id,
                query_dict=call.data,
                allow_public=True,
            )
        conform_output_tags(call, found_models)
        call.result.data = {"models": found_models}
コード例 #8
0
def prepare_update_fields(call, fields):
    """Return a normalized shallow copy of a model update request's fields.

    When ``uri`` is updated, ``ui_cache`` is reset to ``{}`` unless the
    request carries its own value. A referenced task is validated, and tag
    fields are conformed.
    """
    prepared = fields.copy()

    if "uri" in prepared:
        # clear UI cache if URI is provided (model updated)
        prepared["ui_cache"] = prepared.pop("ui_cache", {})

    if "task" in prepared:
        validate_task(call, prepared)

    conform_tag_fields(call, prepared)
    return prepared
コード例 #9
0
ファイル: tasks.py プロジェクト: lcasassa/trains-server
def prepare_update_fields(call: APICall, task, call_data):
    """Parse the user-settable fields of a task update request.

    :param call: the API call, used when conforming tag fields
    :param task: the task being updated; its class's ``user_set_allowed()``
        defines which fields a user may set
    :param call_data: the raw request data to parse
    :return: tuple of (parsed fields, copy of the user-settable field names)
    """
    valid_fields = deepcopy(task.__class__.user_set_allowed())
    update_fields = {
        k: v
        for k, v in create_fields.items() if k in valid_fields
    }
    # Also accept output__error in the update (with a None spec)
    update_fields["output__error"] = None
    # Extend a copy: adding to the module-level ``task_fields`` set in place
    # would permanently leak "output__error" into every other caller of it.
    t_fields = set(task_fields)
    t_fields.add("output__error")
    fields = parse_from_call(call_data, update_fields, t_fields)
    conform_tag_fields(call, fields)
    return fields, valid_fields
コード例 #10
0
ファイル: models.py プロジェクト: paklau99988/trains-server
def get_all(call: APICall):
    """Return the models matching the request query.

    The query runs with ``allow_public=True`` and ``get_all_query_options``;
    tag fields are conformed on the request and on the returned models.
    """
    conform_tag_fields(call, call.data)
    with translate_errors_context():
        with TimingContext("mongo", "models_get_all"):
            request_data = call.data
            found_models = Model.get_many(
                company=call.identity.company,
                parameters=request_data,
                query_dict=request_data,
                allow_public=True,
                query_options=get_all_query_options,
            )
        conform_output_tags(call, found_models)
        call.result.data = {"models": found_models}
コード例 #11
0
ファイル: tasks.py プロジェクト: lcasassa/trains-server
def get_all_ex(call: APICall):
    """Return tasks matching the query, fetched via ``Task.get_many_with_join``.

    Tag fields are conformed on the request and on the returned tasks.
    """
    conform_tag_fields(call, call.data)
    with translate_errors_context():
        with TimingContext("mongo", "task_get_all_ex"):
            # allow_public is required in case projection is requested for
            # public dataset/versions
            found_tasks = Task.get_many_with_join(
                company=call.identity.company,
                query_dict=call.data,
                query_options=get_all_query_options,
                allow_public=True,
            )
        conform_output_tags(call, found_tasks)
        call.result.data = {"tasks": found_tasks}
コード例 #12
0
ファイル: projects.py プロジェクト: gaxler/trains-server
def get_all(call: APICall):
    """Return the projects matching the request query.

    The query runs with ``allow_public=True``; tag fields are conformed on the
    request and on the returned projects.
    """
    conform_tag_fields(call, call.data)
    with translate_errors_context():
        with TimingContext("mongo", "projects_get_all"):
            found_projects = Project.get_many(
                company=call.identity.company,
                query_dict=call.data,
                query_options=get_all_query_options,
                parameters=call.data,
                allow_public=True,
            )
            conform_output_tags(call, found_projects)

            call.result.data = {"projects": found_projects}
コード例 #13
0
ファイル: tasks.py プロジェクト: Goku12321/trains-server
def get_all_ex(call: APICall, company_id, _):
    """Return tasks matching the query, fetched via ``Task.get_many_with_join``.

    Before querying, tag fields are conformed and ``escape_execution_parameters``
    is applied to the call. The results pass through ``unprepare_from_saved``.
    """
    conform_tag_fields(call, call.data)

    escape_execution_parameters(call)

    with translate_errors_context():
        with TimingContext("mongo", "task_get_all_ex"):
            # allow_public is required in case projection is requested for
            # public dataset/versions
            found_tasks = Task.get_many_with_join(
                company=company_id,
                query_dict=call.data,
                allow_public=True,
            )
        unprepare_from_saved(call, found_tasks)
        call.result.data = {"tasks": found_tasks}
コード例 #14
0
def prepare_for_save(call: APICall, fields: dict, previous_task: Task = None):
    """Normalize task fields in place before saving and return them.

    Tag fields are conformed (with validation), parameters are prepared via
    ``params_prepare_for_save``, and every string script field is stripped of
    surrounding whitespace to avoid unusable names and paths.
    """
    conform_tag_fields(call, fields, validate=True)
    params_prepare_for_save(fields, previous_task=previous_task)

    for script_field in task_script_fields:
        try:
            script_path = f"script/{script_field}"
            current = dpath.get(fields, script_path)
            cleaned = current.strip() if isinstance(current, str) else current
            dpath.set(fields, script_path, cleaned)
        except KeyError:
            # field absent from the request - nothing to strip
            continue

    return fields
コード例 #15
0
ファイル: projects.py プロジェクト: yjshen1982/trains-server
def create(call):
    """Create a project for the calling user's company and return its id.

    The ``created`` and ``last_update`` timestamps share the same UTC value.
    """
    assert isinstance(call, APICall)
    identity = call.identity

    with translate_errors_context():
        project_fields = parse_from_call(
            call.data, create_fields, Project.get_fields()
        )
        conform_tag_fields(call, project_fields, validate=True)
        timestamp = datetime.utcnow()
        new_project = Project(
            id=database.utils.id(),
            user=identity.user,
            company=identity.company,
            created=timestamp,
            last_update=timestamp,
            **project_fields
        )
        with TimingContext("mongo", "projects_save"):
            new_project.save()
        call.result.data = {"id": new_project.id}
コード例 #16
0
def prepare_for_save(call: APICall, fields: dict):
    """Normalize task fields in place before saving and return them.

    Steps:
    - conform tag fields;
    - strip surrounding whitespace from string script fields;
    - escape execution parameter keys with ``ParameterKeyEscaper`` so they are
      mongo-safe.
    """
    conform_tag_fields(call, fields)

    # Strip script fields to avoid unusable names and paths
    for script_field in task_script_fields:
        try:
            script_path = f"script/{script_field}"
            current = dpath.get(fields, script_path)
            cleaned = current.strip() if isinstance(current, str) else current
            dpath.set(fields, script_path, cleaned)
        except KeyError:
            continue

    params = safe_get(fields, "execution/parameters")
    if params is None:
        return fields

    escaped = {
        ParameterKeyEscaper.escape(name): value
        for name, value in params.items()
    }
    dpath.set(fields, "execution/parameters", escaped)
    return fields
コード例 #17
0
ファイル: tasks.py プロジェクト: lcasassa/trains-server
def prepare_create_fields(call: APICall,
                          valid_fields=None,
                          output=None,
                          previous_task: Task = None):
    """Parse and normalize the fields of a task creation request.

    :param call: the API call whose ``data`` is parsed
    :param valid_fields: field spec for ``parse_from_call``; defaults to
        ``create_fields``
    :param output: existing output object that receives ``output_dest`` as its
        ``destination``; a new ``Output`` is created when falsy
    :param previous_task: not used by this function
    :return: dict of parsed fields. ``output_dest`` is moved to
        ``output.destination``, string script fields are stripped, and
        execution parameter keys are stripped.
    """
    valid_fields = valid_fields if valid_fields is not None else create_fields
    # Extend a copy: adding to the module-level ``task_fields`` set in place
    # would permanently make "output_dest" a known field for every other caller.
    t_fields = set(task_fields)
    t_fields.add("output_dest")

    fields = parse_from_call(call.data, valid_fields, t_fields)

    # Move output_dest to output.destination. pop() always removes the key, so
    # an explicit None never reaches the saved fields.
    output_dest = fields.pop("output_dest", None)
    if output_dest is not None:
        if output:
            output.destination = output_dest
        else:
            output = Output(destination=output_dest)
        fields["output"] = output

    conform_tag_fields(call, fields)

    # Strip all script fields (remove leading and trailing whitespace chars) to avoid unusable names and paths
    for field in task_script_fields:
        try:
            path = "script/%s" % field
            value = dpath.get(fields, path)
            if isinstance(value, six.string_types):
                value = value.strip()
            dpath.set(fields, path, value)
        except KeyError:
            pass

    parameters = safe_get(fields, "execution/parameters")
    if parameters is not None:
        parameters = {k.strip(): v for k, v in parameters.items()}
        dpath.set(fields, "execution/parameters", parameters)

    return fields
コード例 #18
0
ファイル: projects.py プロジェクト: yjshen1982/trains-server
def get_all_ex(call: APICall):
    """Return projects matching the query, optionally with task statistics.

    Request keys read here:
    - ``include_stats``: when truthy, attach per-project stats.
    - ``stats_for_state``: defaults to ``EntityVisibility.active.value``; an
      empty value reports stats for every visibility state.
    - ``non_public``: when truthy, public projects are excluded.

    With stats enabled, each project gets ``stats`` keyed by visibility state.
    Each entry holds ``total_runtime`` (default 0) and ``status_count`` (a count
    for every ``TaskStatus`` option, default 0).

    :raises errors.bad_request.FieldsValueError: ``stats_for_state`` is not a
        valid ``EntityVisibility`` value
    """
    include_stats = call.data.get("include_stats")
    stats_for_state = call.data.get("stats_for_state",
                                    EntityVisibility.active.value)
    allow_public = not call.data.get("non_public", False)

    if stats_for_state:
        try:
            specific_state = EntityVisibility(stats_for_state)
        except ValueError:
            raise errors.bad_request.FieldsValueError(
                stats_for_state=stats_for_state)
    else:
        # No state filter: stats are reported for every visibility state
        specific_state = None

    conform_tag_fields(call, call.data)
    with translate_errors_context(), TimingContext("mongo",
                                                   "projects_get_all"):
        projects = Project.get_many_with_join(
            company=call.identity.company,
            query_dict=call.data,
            query_options=get_all_query_options,
            allow_public=allow_public,
        )
        conform_output_tags(call, projects)

        if not include_stats:
            call.result.data = {"projects": projects}
            return

        ids = [project["id"] for project in projects]
        status_count_pipeline, runtime_pipeline = make_projects_get_all_pipelines(
            call.identity.company, ids, specific_state=specific_state)

        # Every task status starts at zero so all statuses are always listed
        default_counts = dict.fromkeys(get_options(TaskStatus), 0)

        def set_default_count(entry):
            return dict(default_counts, **entry)

        # project id -> visibility section -> {status: count}
        status_count = defaultdict(dict)
        key = itemgetter(EntityVisibility.archived.value)
        for result in Task.aggregate(status_count_pipeline):
            # groupby only merges adjacent items, hence the sort by the same key
            for k, group in groupby(sorted(result["counts"], key=key), key):
                section = (EntityVisibility.archived
                           if k else EntityVisibility.active).value
                status_count[result["_id"]][section] = set_default_count({
                    count_entry["status"]: count_entry["count"]
                    for count_entry in group
                })

        # project id -> every aggregation output field except "_id"
        runtime = {
            result["_id"]: {k: v
                            for k, v in result.items() if k != "_id"}
            for result in Task.aggregate(runtime_pipeline)
        }

    def safe_get(obj, path, default=None):
        """Return the value at the '/'-separated ``path`` in ``obj``, else ``default``."""
        try:
            return dpath.get(obj, path)
        except KeyError:
            return default

    def get_status_counts(project_id, section):
        path = "/".join((project_id, section))
        return {
            "total_runtime": safe_get(runtime, path, 0),
            # Fresh copy per project so projects without counts don't all
            # alias one shared mutable dict in the response.
            "status_count": safe_get(status_count, path, dict(default_counts)),
        }

    report_for_states = [
        s for s in EntityVisibility
        if not specific_state or specific_state == s
    ]

    for project in projects:
        project["stats"] = {
            task_state.value: get_status_counts(project["id"],
                                                task_state.value)
            for task_state in report_for_states
        }

    call.result.data = {"projects": projects}
コード例 #19
0
ファイル: models.py プロジェクト: yjshen1982/trains-server
def parse_model_fields(call, valid_fields):
    """Parse model fields from the call's data and validate their tag fields.

    :param call: the API call whose ``data`` is parsed
    :param valid_fields: field spec forwarded to ``parse_from_call``
    :return: the parsed fields dict
    """
    parsed = parse_from_call(call.data, valid_fields, Model.get_fields())
    conform_tag_fields(call, parsed, validate=True)
    return parsed