예제 #1
0
    def test_query_single_train_task_with_plugin_name_not_exist(self, client, plugin_name):
        """
        Parsing unavailable plugin name to single train task.

        Test Params:
        request route: GET('/v1/mindinsight/datavisual/single-job').
        request params: plugin_name, train_id.

        Expect:
        response status code: 400.
        response json: {
            'error_code': '5054500B',
            'error_msg': "Plugin is not available. Detail: 'plugin_name' only can
                          be one of <PluginNameEnum.list_members()>"
        }
        """
        plugin_name_list = PluginNameEnum.list_members()

        params = dict(plugin_name=plugin_name, train_id="not_exist")
        url = get_url(TRAIN_ROUTES['single_job'], params)

        response = client.get(url)

        assert response.status_code == 400

        body = response.get_json()
        assert body['error_code'] == '5054500B'
        assert body['error_msg'] == "Plugin is not available. " \
                                    f"Detail: 'plugin_name' only can be one of {plugin_name_list}"
예제 #2
0
    def create_summary(self, log_dir, steps_list, tag_name_list):
        """
        Create a summary file in log_dir containing data for every plugin.

        Args:
            log_dir (str): Directory in which the summary file is written.
            steps_list (list): Steps passed to each non-graph log generator.
            tag_name_list (list): Tag names; one log entry per tag is generated
                for each non-graph plugin.

        Returns:
            dict, with keys "plugins" (plugin name -> list of tags),
            "metadata" (full tag name -> metadata returned by the generator)
            and "actual_values" (graph values keyed by plugin name, image
            values keyed by full tag name).
        """
        timestamp = time.time() + self._time_count
        file_path = os.path.join(log_dir, f'test.summary.{int(timestamp)}')

        metadata_dict = {"plugins": {}, "metadata": {}, "actual_values": {}}
        for plugin_name in PluginNameEnum.list_members():
            plugin_tags = []
            metadata_dict["plugins"][plugin_name] = plugin_tags
            log_generator = log_generators.get(plugin_name)
            if plugin_name == PluginNameEnum.GRAPH.value:
                # JSON is UTF-8 by specification; do not depend on the
                # platform's default text encoding.
                with open(self._graph_base_path, 'r', encoding='utf-8') as load_f:
                    graph_dict = json.load(load_f)
                values = log_generator.generate_log(file_path, graph_dict)
                metadata_dict["actual_values"][plugin_name] = values
                # Literal placeholder tag for the graph plugin; presumably
                # callers only compare the number of graph tags.
                plugin_tags.append("UUID str")
            else:
                for tag_name in tag_name_list:
                    metadata, values = log_generator.generate_log(
                        file_path, steps_list, tag_name)
                    full_tag_name = f'{tag_name}/{plugin_name}'
                    metadata_dict["metadata"][full_tag_name] = metadata
                    plugin_tags.append(full_tag_name)

                    # Only image values are recorded for later comparison.
                    if plugin_name == PluginNameEnum.IMAGE.value:
                        metadata_dict["actual_values"][full_tag_name] = values

        # Stamp the file with the same time used in its name; _time_count
        # shifts the next call's timestamp so file names/mtimes stay distinct.
        os.utime(file_path, (timestamp, timestamp))
        self._time_count += 1
        return metadata_dict
예제 #3
0
    def get_plugins(self, train_id, manual_update=True):
        """
        Query the plugin data of the specified training job.

        Args:
            train_id (str): Specify a training job to query.
            manual_update (bool): If True, ask the data manager to cache the
                train job (cache_train_job) before querying it. Default: True.

        Returns:
            dict, refer to restful api. When the job's cached detail is missing
            or holds no plugin data, every plugin name maps to an empty list.

        Raises:
            QueryStringContainsNullByteError: If train_id contains a null byte.
        """
        Validation.check_param_empty(train_id=train_id)
        if contains_null_byte(train_id=train_id):
            raise QueryStringContainsNullByteError("train job id: {} contains null byte.".format(train_id))

        if manual_update:
            self._data_manager.cache_train_job(train_id)

        train_job = self._data_manager.get_train_job(train_id)

        try:
            data_visual_content = train_job.get_detail(DATAVISUAL_CACHE_KEY)
            plugins = data_visual_content.get(DATAVISUAL_PLUGIN_KEY)
        except exceptions.TrainJobDetailNotInCacheError:
            plugins = []

        if not plugins:
            # No plugin data available: report every known plugin with no tags.
            plugins = {plugin_name: [] for plugin_name in PluginNameEnum.list_members()}

        return dict(plugins=plugins)
 def test_get_single_train_task_with_not_exists_train_id(self):
     """Every plugin must report an unknown train id as TrainJobNotExistError."""
     manager = TrainTaskManager(self._mock_data_manager)
     missing_id = "not_exist_id"
     expected_message = "Train job is not exist. " \
                        "Detail: Can not find the train job in data manager."
     for name in PluginNameEnum.list_members():
         with pytest.raises(TrainJobNotExistError) as raised:
             manager.get_single_train_task(name, missing_id)
         error = raised.value
         assert error.message == expected_message
         assert error.error_code == '50545005'
예제 #5
0
    def test_plugins_with_train_id_not_in_cache(self, client):
        """Test getting plugins for a train id that is not in the loader pool."""
        url = get_url(BASE_URL, dict(train_id="./summary0"))
        plugins = client.get(url).get_json().get('plugins')

        # Each plugin is expected to report no tags (an empty or missing entry).
        for name in PluginNameEnum.list_members():
            assert not plugins.get(name)
예제 #6
0
 def test_get_single_train_task_with_not_exists_train_id(self, load_data):
     """An unknown train id must raise ParamValueError for every plugin."""
     manager = TrainTaskManager(self._mock_data_manager)
     missing_id = "not_exist_id"
     for name in PluginNameEnum.list_members():
         with pytest.raises(ParamValueError) as raised:
             manager.get_single_train_task(name, missing_id)
         # pytest.raises also accepts subclasses, so pin the exact type.
         assert raised.type == ParamValueError
         assert raised.value.message == ("Invalid parameter value. Can not find "
                                         "the train job in data manager.")
         assert raised.value.error_code == '50540002'
    def test_get_single_train_task_with_params(self):
        """Check each (plugin, train id) pair reports tags exactly when the id map lists it."""
        manager = TrainTaskManager(self._mock_data_manager)
        for name in PluginNameEnum.list_members():
            for train_id in self._train_id_list:
                task = manager.get_single_train_task(name, train_id)
                has_tags = bool(task.get("train_jobs")[0].get("tags"))

                # Non-empty tags <=> train_id is registered for this plugin.
                assert (train_id in self._plugins_id_map.get(name)) == has_tags
    def check_plugin_name(cls, plugin_name):
        """
        Validate that plugin_name is one of the known plugin names.

        Args:
            plugin_name (str): The plugin name to validate.

        Raises:
            PluginNotAvailableError: If plugin_name is not in PluginNameEnum.list_members().
        """
        valid_names = PluginNameEnum.list_members()
        if plugin_name in valid_names:
            return
        raise PluginNotAvailableError(f"'plugin_name' only can be one of {valid_names}")
예제 #9
0
    def check_plugin_name(cls, plugin_name):
        """
        Validate that plugin_name is one of the known plugin names.

        Args:
            plugin_name (str): The plugin name to validate.

        Raises:
            ParamValueError: If plugin_name is not in PluginNameEnum.list_members().
        """
        valid_names = PluginNameEnum.list_members()
        if plugin_name in valid_names:
            return
        raise ParamValueError(f"'plugin_name' only can be one of {valid_names}")
 def test_query_single_train_task(self, client):
     """Test query single train task."""
     for train_id in gbl.summaries_metadata:
         expected = gbl.summaries_metadata.get(train_id).get("plugins")
         for plugin_name in PluginNameEnum.list_members():
             params = dict(train_id=train_id, plugin_name=plugin_name)
             url = get_url(BASE_URL, params)
             response = client.get(url)
             result = response.get_json()
             tags = result["train_jobs"][0]["tags"]
             if plugin_name == PluginNameEnum.GRAPH.value:
                 # Graph tags are compared by count only; presumably the
                 # expected list holds placeholders rather than real tags.
                 assert len(tags) == len(expected[plugin_name])
             else:
                 assert sorted(tags) == sorted(expected[plugin_name])
예제 #11
0
    def test_plugins(self, client):
        """Test getting plugins for the first known train job."""
        train_id = gbl.get_train_ids()[0]
        expected = gbl.summaries_metadata.get(train_id).get("plugins")

        url = get_url(BASE_URL, dict(train_id=train_id))
        actual = client.get(url).get_json().get('plugins')
        for name in PluginNameEnum.list_members():
            got = actual.get(name)
            want = expected.get(name)
            if name != PluginNameEnum.GRAPH.value:
                assert sorted(got) == sorted(want)
            else:
                # Graph tags are compared by count only.
                assert len(got) == len(want)
예제 #12
0
    def get_train_job(self, train_id):
        """
        Get train job by train ID.

        This method overrides parent method.

        Args:
            train_id (str): Train ID for train job.

        Returns:
            CachedTrainJob, whose DATAVISUAL_CACHE_KEY detail is the loader's
            dict (without 'data_loader') extended with the tags of every plugin
            under DATAVISUAL_PLUGIN_KEY; None if no loader is found for the job.
        """
        self._check_train_job_exist(train_id, self._loader_pool)

        loader = self._get_loader(train_id)
        if loader is None:
            logger.warning(
                "No valid summary log in train job %s, "
                "or it is not in the cache.", train_id)
            return None

        train_job = loader.to_dict()
        # Drop the 'data_loader' entry (presumably the loader object itself)
        # before the dict is stored as cached detail.
        train_job.pop('data_loader')

        plugin_data = {}
        for plugin_name in PluginNameEnum.list_members():
            job = self.get_train_job_by_plugin(train_id, plugin_name=plugin_name)
            # A plugin with no data for this job is reported with no tags.
            plugin_data[plugin_name] = [] if job is None else job['tags']

        train_job.update({DATAVISUAL_PLUGIN_KEY: plugin_data})

        # Will fill basic_info value in future.
        train_job_obj = CachedTrainJob(basic_info=None)
        train_job_obj.set(DATAVISUAL_CACHE_KEY, train_job)
        train_job_obj.cache_status = loader.cache_status

        return train_job_obj
예제 #13
0
    def get_plugins(self, train_id, manual_update=True):
        """
        Return the plugin-to-tags mapping of a training job.

        Args:
            train_id (str): The training job to query.
            manual_update (bool): Forwarded unchanged to the data manager's
                get_single_train_job. Default: True.

        Returns:
            dict, refer to restful api. When the data manager returns no job,
            every plugin name maps to an empty list.
        """
        Validation.check_param_empty(train_id=train_id)
        train_job = self._data_manager.get_single_train_job(
            train_id, manual_update=manual_update)
        if train_job:
            return dict(plugins=train_job['tag_mapping'])

        empty_mapping = {name: [] for name in PluginNameEnum.list_members()}
        return dict(plugins=empty_mapping)
예제 #14
0
    def _get_train_job_item(self, train_id):
        """
        Build the summary item describing one train job.

        Args:
            train_id (str): Specify train id.

        Returns:
            dict, the train job item including its "plugins" mapping, or None
            if the data manager raises TrainJobNotExistError for train_id.
        """
        try:
            train_job = self._data_manager.get_train_job(train_id)
        except exceptions.TrainJobNotExistError:
            logger.warning('Train job %s not existed', train_id)
            return None

        info = train_job.get_basic_info()
        time_format = '%Y-%m-%d %H:%M:%S'
        item = {
            'train_id': info.train_id,
            'relative_path': info.train_id,
            'create_time': info.create_time.strftime(time_format),
            'update_time': info.update_time.strftime(time_format),
            'profiler_dir': info.profiler_dir,
            'cache_status': train_job.cache_status.value,
            'profiler_type': info.profiler_type,
            'summary_files': info.summary_files,
            'graph_files': info.graph_files,
            'lineage_files': info.lineage_files,
            'dump_dir': info.dump_dir,
        }

        if train_job.cache_status != CacheStatus.NOT_IN_CACHE:
            plugin_info = self.get_plugins(train_id, manual_update=False)
        else:
            # Uncached job: skip the plugin query and report every plugin with no tags.
            plugin_info = dict(plugins={name: [] for name in PluginNameEnum.list_members()})

        item.update(plugin_info)
        return item
예제 #15
0
    def get_single_train_job(self, train_id, manual_update=False):
        """
        Get the train job with the given train ID as a dict.

        Args:
            train_id (str): Train ID for train job.
            manual_update (bool): Forwarded to _get_loader. Default: False.

        Returns:
            dict, the loader's dict (without 'data_loader') extended with a
            'tag_mapping' entry (plugin name -> tags); None if no loader is found.
        """
        self._check_status_valid()
        self._check_train_job_exist(train_id, self._loader_pool)

        loader = self._get_loader(train_id, manual_update)
        if loader is None:
            logger.warning(
                "No valid summary log in train job %s, "
                "or it is not in the cache.", train_id)
            return None

        job_dict = loader.to_dict()
        job_dict.pop('data_loader')

        tag_mapping = {}
        for name in PluginNameEnum.list_members():
            plugin_job = self.get_train_job_by_plugin(train_id, plugin_name=name)
            tag_mapping[name] = plugin_job['tags'] if plugin_job is not None else []

        job_dict.update({'tag_mapping': tag_mapping})
        return job_dict