コード例 #1
0
ファイル: models.py プロジェクト: ramonemiliani93/clearml
    def _get_output_models(task):
        # type: ("clearml.Task") -> ModelsList # noqa: F821
        """
        Collect the output models of ``task``, ordered and named.

        Model IDs come from two sources, in this order:

        1. models the backend returns for the task via ``models.GetAllRequest``
           (ordered by ``"created"``, only the ``id`` field is requested);
        2. the values of ``task.output_models_id``.

        Duplicates are removed while keeping first-seen order. On API >= 2.13 a model is
        named after its entry in ``task.data.models.output``; any model without such an
        entry (and every model on older APIs) is named ``"Output Model #<index>"``, where
        ``index`` is its position in the de-duplicated ID list. If two models resolve to the
        same name, the later one replaces the earlier value (``OrderedDict`` semantics).

        :param task: The task whose output models are collected.
        :return: ``ModelsList`` built from an ordered ``{name: Model}`` mapping.
        """
        res = task.send(
            models.GetAllRequest(
                task=[task.id], order_by=["created"], only_fields=["id"]
            )
        )
        ids = [m.id for m in res.response.models or []] + list(task.output_models_id.values())
        # remove duplicates and preserve order
        ids = list(OrderedDict.fromkeys(ids))

        id_to_name = {}
        if Session.check_min_api_version("2.13"):
            # The ``models`` / ``output`` section may be absent from the task data; fall back
            # to positional names in that case instead of failing with AttributeError/TypeError.
            task_models = getattr(task.data, "models", None)
            output_entries = getattr(task_models, "output", None) or []
            id_to_name = {x.model: x.name for x in output_entries}

        def resolve_name(index, model_id):
            # Prefer the name recorded on the task; otherwise use a positional placeholder.
            return id_to_name.get(model_id, "Output Model #{}".format(index))

        from clearml.model import Model

        output_models = OrderedDict(
            (resolve_name(i, m_id), Model(model_id=m_id)) for i, m_id in enumerate(ids)
        )

        return ModelsList(output_models)
コード例 #2
0
    def _get_input_models(self, task):
        # type: ("clearml.Task") -> ModelsList # noqa: F821
        """
        Collect the input models of ``task``, ordered and named.

        On API >= 2.13, when the task data has a non-empty ``models.input`` section, models
        are taken from it and keyed by each entry's ``name``. Otherwise model IDs are
        gathered from ``task.input_models_id`` (preceded, on API < 2.13, by IDs parsed from
        ``task.comment`` with ``self._input_models_re``), de-duplicated in first-seen order,
        and keyed as ``"Input Model #<index>"``.

        In both paths a model is included only if ``Model._get_model_data()`` succeeds for
        it; models that fail that check are skipped, so no name ever maps to ``None``.

        :param task: The task whose input models are collected.
        :return: ``ModelsList`` built from an ordered ``{name: Model}`` mapping.
        """
        new_api = Session.check_min_api_version("2.13")

        if new_api:
            parsed_ids = list(task.input_models_id.values())
        else:
            # Below API 2.13 the ``models.input`` section is never consulted: recover input
            # model IDs from the task comment, then append the IDs stored on the task itself.
            # Guard against an empty/None comment so the regex never receives None.
            parsed_ids = [i[-1] for i in self._input_models_re.findall(task.comment or "")]
            parsed_ids.extend(task.input_models_id.values())

        from clearml.model import Model

        def get_model(id_):
            """Return a ``Model`` for ``id_`` if its data can be fetched, otherwise None."""
            m = Model(model_id=id_)
            # noinspection PyBroadException
            try:
                # make sure the model is valid
                # noinspection PyProtectedMember
                m._get_model_data()
            except Exception:
                # best effort: a model that cannot be fetched is left out of the result
                return None
            return m

        # noinspection PyProtectedMember
        if new_api and task._get_task_property(
                "models.input", raise_on_error=False, log_on_error=False):
            input_models = OrderedDict()
            for entry in task.data.models.input:
                a_model = get_model(entry.model)
                # skip models that failed validation, like the fallback branch below does
                if a_model is not None:
                    input_models[entry.name] = a_model
        else:
            # remove duplicates and preserve order; drop models that failed validation
            input_models = OrderedDict(
                ("Input Model #{}".format(i), a_model)
                for i, a_model in enumerate(
                    filter(None, map(get_model, OrderedDict.fromkeys(parsed_ids)))
                )
            )

        return ModelsList(input_models)