def __init__(self, train_id, data_manager, tag=None):
    """
    Load the graph of a train job into this processor.

    Args:
        train_id (str): Specify the training job whose graph is loaded.
        data_manager: Manager object used to look up train jobs and tensors.
        tag (str, optional): Graph tag to load. When omitted, the first tag
            listed by the train job is used. Default: None.

    Raises:
        TrainJobNotExistError: The data manager has no graph train job
            for `train_id`.
        GraphNotExistError: The train job lists no tags, or `tag` was given
            but is not one of them.
    """
    # Presumably rejects an empty train_id; the helper is defined elsewhere.
    Validation.check_param_empty(train_id=train_id)
    super(GraphProcessor, self).__init__(data_manager)

    graph_job = self._data_manager.get_train_job_by_plugin(
        train_id, PluginNameEnum.GRAPH.value)
    if graph_job is None:
        raise exceptions.TrainJobNotExistError()

    available_tags = graph_job['tags']
    if not available_tags:
        raise exceptions.GraphNotExistError()

    if tag is None:
        # No tag requested: fall back to the first one the job provides.
        tag = available_tags[0]
    elif tag not in available_tags:
        raise exceptions.GraphNotExistError()

    graph_tensors = self._data_manager.list_tensors(train_id, tag=tag)
    first_tensor = graph_tensors[0]
    self._graph = first_tensor.value
def get_single_train_task(self, plugin_name, train_id):
    """
    Look up one train job for a given plugin.

    Args:
        plugin_name (str): Plugin name, refer `PluginNameEnum`.
        train_id (str): Specify a training job to query.

    Returns:
        dict, a mapping with the single key 'train_jobs' whose value is a
        one-element list holding the train job found, refer to restful api.

    Raises:
        TrainJobNotExistError: The data manager returns no train job for
            the given `train_id` and `plugin_name`.
    """
    # Both validators are defined elsewhere; they run before any lookup.
    Validation.check_param_empty(plugin_name=plugin_name, train_id=train_id)
    Validation.check_plugin_name(plugin_name=plugin_name)

    found_job = self._data_manager.get_train_job_by_plugin(
        train_id=train_id,
        plugin_name=plugin_name,
    )
    if found_job is None:
        raise exceptions.TrainJobNotExistError()

    return {'train_jobs': [found_job]}