コード例 #1
0
    def test_list_tensors_with_keyerror(self):
        """Test that list_tensors rejects a tag the train job has no data for.

        The train job's loader wraps a fresh ``MSDataLoader``. Nothing in this
        test injects events data for ``tag``, so ``DataManager.list_tensors``
        is expected to raise ``ParamValueError``.

        NOTE(review): the test name mentions a KeyError; presumably the
        internal tag lookup fails with ``KeyError`` and ``list_tensors``
        translates it to ``ParamValueError`` -- confirm against the
        implementation.
        """
        summary_base_dir = tempfile.mkdtemp()
        # Always remove the temp dir, even if setup or the assertion fails,
        # so failing runs do not leak directories.
        try:
            train_job_01 = 'train_01'
            name_01 = 'train_job_01'
            log_path_01 = os.path.join(summary_base_dir, 'dir1')
            self._make_path_and_file_list(log_path_01)
            modify_time_01 = 1575460551.9777446
            ms_loader = MSDataLoader(log_path_01)
            loader_01 = DataLoader(log_path_01)
            # Inject the concrete loader directly instead of letting
            # DataLoader pick one.
            loader_01._loader = ms_loader

            loader = LoaderStruct(loader_id=train_job_01,
                                  name=name_01,
                                  path=log_path_01,
                                  latest_update_time=modify_time_01,
                                  data_loader=loader_01)
            loader_pool = {train_job_01: loader}
            d_manager = DataManager([DataLoaderGenerator(summary_base_dir)])
            # Put the manager in LOADING state with a hand-built pool,
            # bypassing the real loading process.
            d_manager._status = DataManagerStatus.LOADING.value
            d_manager._loader_pool = loader_pool
            tag = 'image'
            with pytest.raises(ParamValueError):
                d_manager.list_tensors(train_job_01, tag)
        finally:
            shutil.rmtree(summary_base_dir)
コード例 #2
0
    def test_list_tensors_success(self):
        """Test that list_tensors returns the samples stored for a tag.

        A mocked reservoir is registered for ``tag`` in the loader's events
        data. ``DataManager.list_tensors`` is expected to return exactly what
        that reservoir's ``samples()`` returns.
        """
        summary_base_dir = tempfile.mkdtemp()
        # Always remove the temp dir, even if setup or the assertion fails,
        # so failing runs do not leak directories.
        try:
            train_job_01 = 'train_01'
            name_01 = 'train_job_01'
            log_path_01 = os.path.join(summary_base_dir, 'dir1')
            self._make_path_and_file_list(log_path_01)
            modify_time_01 = 1575460551.9777446
            loader_01 = DataLoader(log_path_01)

            ms_loader = MSDataLoader(log_path_01)
            event_data = EventsData()
            # Stub reservoir: samples() returns a sentinel value, so the
            # assertion below checks that list_tensors passes it through
            # unchanged.
            mock_obj = mock.MagicMock()
            mock_obj.samples.return_value = {'test result'}
            tag = 'image'
            event_data._reservoir_by_tag = {tag: mock_obj}
            ms_loader._events_data = event_data
            # Inject the concrete loader directly instead of letting
            # DataLoader pick one.
            loader_01._loader = ms_loader

            loader = LoaderStruct(loader_id=train_job_01,
                                  name=name_01,
                                  path=log_path_01,
                                  latest_update_time=modify_time_01,
                                  data_loader=loader_01)
            loader_pool = {train_job_01: loader}
            d_manager = DataManager([DataLoaderGenerator(summary_base_dir)])
            # Put the manager in LOADING state with a hand-built pool,
            # bypassing the real loading process.
            d_manager._status = DataManagerStatus.LOADING.value
            d_manager._loader_pool = loader_pool

            res = d_manager.list_tensors(train_job_01, tag)
            assert res == {'test result'}
        finally:
            shutil.rmtree(summary_base_dir)