def test_big_data_hdf5_loop(self):
        """
        Check that a loader configured with 'loop_over_hdf5_files' serves a different
        dataset on each successive pass over its batches and wraps around afterwards.
        """
        # Build three datasets, each filled with one constant value (0, 0.3, 0.6)
        # and each stored as its own hdf5 file.
        hdf5_files = []
        infos = []
        for index in range(3):
            output_path = os.path.join(self.output_dir, f'ds{index}')
            os.makedirs(output_path, exist_ok=True)
            saver_config = DataSaverConfig().create(
                config_dict={
                    'output_path': output_path,
                    'store_hdf5': True,
                    'training_validation_split': 1.0
                })
            self.data_saver = DataSaver(config=saver_config)
            constant_input = (0.3 * index) * np.ones((3, 10, 10))
            dataset_info = generate_dummy_dataset(
                self.data_saver,
                num_runs=2,
                input_size=(3, 10, 10),
                fixed_input_value=constant_input,
                store_hdf5=True)
            infos.append(dataset_info)

            train_file = os.path.join(output_path, 'train.hdf5')
            self.assertTrue(os.path.isfile(train_file))
            # Every valid file is followed by a 'wrong.hdf5' path in the same directory
            # (presumably non-existent, to exercise the loader's handling -- TODO confirm).
            hdf5_files.extend(
                [train_file,
                 os.path.join(output_path, 'wrong.hdf5')])

        # Data loader that receives all collected hdf5 paths and loops over them.
        loader_config = DataLoaderConfig().create(
            config_dict={
                'output_path': self.output_dir,
                'hdf5_files': hdf5_files,
                'batch_size': 15,
                'loop_over_hdf5_files': True
            })
        loader = DataLoader(loader_config)

        def first_value(batch):
            # First element of the first observation in the batch, as a python float.
            return batch.observations[0][0, 0, 0].item()

        # The very first pass is compared with the default precision of assertAlmostEqual.
        for batch in loader.get_data_batch():
            self.assertAlmostEqual(first_value(batch), 0)

        # Every following pass is compared up to two decimal places.
        expected_passes = [
            (loader.get_data_batch, 0.3),
            (loader.get_data_batch, 0.6),
            (loader.get_data_batch, 0),
            (loader.sample_shuffled_batch, 0.3),
            (loader.sample_shuffled_batch, 0.6),
            (loader.sample_shuffled_batch, 0),
        ]
        for draw_batches, expected_value in expected_passes:
            for batch in draw_batches():
                self.assertAlmostEqual(first_value(batch), expected_value, 2)
    def test_data_batch(self):
        """Verify that the first batch drawn from the loader has the configured batch size."""
        batch_size = 3
        self.info = generate_dummy_dataset(
            self.data_saver,
            num_runs=20,
            input_size=(100, 100, 3),
            output_size=(3, ),
            continuous=False,
        )
        loader_config = DataLoaderConfig().create(
            config_dict={
                'data_directories': self.info['episode_directories'],
                'output_path': self.output_dir,
                'random_seed': 1,
                'batch_size': batch_size
            })
        data_loader = DataLoader(config=loader_config)
        data_loader.load_dataset()

        # Only the first batch is inspected; nothing is asserted if the loader yields none.
        for first_batch in data_loader.get_data_batch():
            self.assertEqual(len(first_batch), batch_size)
            break
Esempio n. 3
0
class Evaluator:
    """
    Evaluate a network on the dataset provided by the configured data loader.

    The criterion named in the config is applied per batch between the network's
    predictions and the batch actions. The lowest mean error seen so far is tracked
    so callers can tell when a new best checkpoint has been reached.
    """

    def __init__(self,
                 config: EvaluatorConfig,
                 network: BaseNet,
                 quiet: bool = False):
        """
        Build the data loader and criterion and load the dataset.

        :param config: evaluator settings; this class reads data_loader_config, output_path,
            device, criterion, criterion_args_str and store_output_on_tensorboard
        :param network: network to evaluate; if None, no original device is recorded here
        :param quiet: when True, no logger is created and no start message is printed
        """
        self._config = config
        self._net = network
        self.data_loader = DataLoader(config=self._config.data_loader_config)

        # Always define the attribute so remove() works for quiet instances and for
        # subclasses, which deliberately do not get a logger created here.
        self._logger = None
        if not quiet:
            if type(self) is Evaluator:
                self._logger = get_logger(
                    name=get_filename_without_extension(__file__),
                    output_path=config.output_path,
                    quiet=False)
            cprint('Started.', self._logger)

        # 'gpu' and 'cuda' both select CUDA when it is available; anything else is CPU.
        self._device = torch.device(
            "cuda" if self._config.device in ['gpu', 'cuda']
            and torch.cuda.is_available() else "cpu")
        # NOTE(review): eval() executes whatever the config strings contain, so the
        # configuration must come from a trusted source. The criterion name is resolved
        # against the names in scope of this module.
        self._criterion = eval(
            f'{self._config.criterion}(reduction=\'none\', {self._config.criterion_args_str})'
        )
        self._criterion.to(self._device)
        self._lowest_validation_loss = None
        self.data_loader.load_dataset()

        self._minimum_error = float(10**6)
        self._original_model_device = self._net.get_device(
        ) if self._net is not None else None

    def put_model_on_device(self, device: str = None):
        """
        Remember the network's current device and move it to the evaluation device.

        :param device: explicit device string; when None the device resolved in __init__ is
            used, so the network ends up on the same device as the criterion and targets
        """
        self._original_model_device = self._net.get_device()
        # The raw config value may be 'gpu' (not a valid torch device string) or may name
        # CUDA on a machine without it, hence the already-resolved self._device.
        target_device = self._device if device is None else torch.device(
            device)
        self._net.set_device(target_device)

    def put_model_back_to_original_device(self):
        """Move the network back to the device recorded by put_model_on_device()."""
        self._net.set_device(self._original_model_device)

    def evaluate(self,
                 epoch: int = -1,
                 writer=None,
                 tag: str = 'validation') -> Tuple[str, bool]:
        """
        Compute the mean criterion value of every batch and summarise the results.

        :param epoch: current epoch; only used to decide whether output images are written
            (every 30th epoch)
        :param writer: optional object offering write_distribution and write_output_image
        :param tag: label used in the returned message and for the writer entries; the tag
            'test' always triggers image output when enabled in the config
        :return: a summary message and a flag that is True when the mean error is the
            lowest one seen by this evaluator so far
        """
        self.put_model_on_device()
        total_error = []
        # Pre-bind the loop variables so the image section below can tell whether
        # any batch was processed at all.
        batch = predictions = targets = None
        try:
            #        for batch in tqdm(self.data_loader.get_data_batch(), ascii=True, desc='evaluate'):
            for batch in self.data_loader.get_data_batch():
                with torch.no_grad():
                    predictions = self._net.forward(batch.observations,
                                                    train=False)
                    targets = data_to_tensor(batch.actions).type(
                        self._net.dtype).to(self._device)
                    error = self._criterion(predictions, targets).mean()
                    total_error.append(error)
        finally:
            # Restore the network's device even if evaluation fails half-way.
            self.put_model_back_to_original_device()
        error_distribution = Distribution(total_error)

        if writer is not None:
            writer.write_distribution(error_distribution, tag)
            store_images = self._config.store_output_on_tensorboard and (
                epoch % 30 == 0 or tag == 'test')
            # Images come from the last processed batch; skip them if there was none.
            if store_images and batch is not None:
                writer.write_output_image(predictions, f'{tag}/predictions')
                writer.write_output_image(targets, f'{tag}/targets')
                writer.write_output_image(torch.stack(batch.observations),
                                          f'{tag}/inputs')

        msg = f' {tag} {self._config.criterion} {error_distribution.mean: 0.3e} [{error_distribution.std:0.2e}]'

        best_checkpoint = False
        if self._lowest_validation_loss is None or error_distribution.mean < self._lowest_validation_loss:
            self._lowest_validation_loss = error_distribution.mean
            best_checkpoint = True
        return msg, best_checkpoint

    def evaluate_extensive(self) -> None:
        """
        Extra offline evaluation methods for an extensive evaluation at the end of training.

        Subsamples the loaded dataset, runs the network on it on the CPU and hands the
        observations and predictions to the segmentation video writer.
        :return: None
        """
        self.put_model_on_device('cpu')
        try:
            self.data_loader.get_dataset().subsample(10)
            dataset = self.data_loader.get_dataset()
            # No gradients are needed for inference, consistent with evaluate().
            with torch.no_grad():
                predictions = self._net.forward(dataset.observations,
                                                train=False).detach().cpu()
            #error = predictions - torch.stack(dataset.actions)
        finally:
            # Restore the network's device even if the forward pass fails.
            self.put_model_back_to_original_device()

        # save_output_plots(output_dir=self._config.output_path,
        #                   data={'expert': np.stack(dataset.actions),
        #                         'network': predictions.numpy(),
        #                         'difference': error.numpy()})
        # create_output_video(output_dir=self._config.output_path,
        #                     observations=dataset.observations,
        #                     actions={'expert': np.stack(dataset.actions),
        #                              'network': predictions.numpy()})
        create_output_video_segmentation_network(
            output_dir=self._config.output_path,
            observations=torch.stack(dataset.observations).numpy(),
            predictions=predictions.numpy())

    def remove(self):
        """Release the data loader and close the logger's handlers, if a logger exists."""
        self.data_loader.remove()
        # No logger exists for quiet instances or subclasses.
        if self._logger is not None:
            for handler in self._logger.handlers:
                handler.close()