    def test_big_data_hdf5_loop(self):
        # create 3 datasets as hdf5 files
        hdf5_files = []
        infos = []
        for index in range(3):
            output_path = os.path.join(self.output_dir, f'ds{index}')
            os.makedirs(output_path, exist_ok=True)
            config_dict = {
                'output_path': output_path,
                'store_hdf5': True,
                'training_validation_split': 1.0
            }
            config = DataSaverConfig().create(config_dict=config_dict)
            self.data_saver = DataSaver(config=config)
            infos.append(
                generate_dummy_dataset(self.data_saver,
                                       num_runs=2,
                                       input_size=(3, 10, 10),
                                       fixed_input_value=(0.3 * index) *
                                       np.ones((3, 10, 10)),
                                       store_hdf5=True))
            self.assertTrue(
                os.path.isfile(os.path.join(output_path, 'train.hdf5')))
            hdf5_files.append(os.path.join(output_path, 'train.hdf5'))
            hdf5_files.append(os.path.join(output_path, 'wrong.hdf5'))  # non-existent file; the loader is expected to skip it

        # create a data loader that loops over the three hdf5 training sets
        conf = {
            'output_path': self.output_dir,
            'hdf5_files': hdf5_files,
            'batch_size': 15,
            'loop_over_hdf5_files': True
        }
        loader = DataLoader(DataLoaderConfig().create(config_dict=conf))

        # each pass over get_data_batch() / sample_shuffled_batch() should draw from
        # the next hdf5 file, wrapping around after the last one
        for batch in loader.get_data_batch():
            self.assertAlmostEqual(batch.observations[0][0, 0, 0].item(), 0)
        for batch in loader.get_data_batch():
            self.assertAlmostEqual(batch.observations[0][0, 0, 0].item(), 0.3,
                                   2)
        for batch in loader.get_data_batch():
            self.assertAlmostEqual(batch.observations[0][0, 0, 0].item(), 0.6,
                                   2)
        for batch in loader.get_data_batch():
            self.assertAlmostEqual(batch.observations[0][0, 0, 0].item(), 0, 2)
        for batch in loader.sample_shuffled_batch():
            self.assertAlmostEqual(batch.observations[0][0, 0, 0].item(), 0.3,
                                   2)
        for batch in loader.sample_shuffled_batch():
            self.assertAlmostEqual(batch.observations[0][0, 0, 0].item(), 0.6,
                                   2)
        for batch in loader.sample_shuffled_batch():
            self.assertAlmostEqual(batch.observations[0][0, 0, 0].item(), 0, 2)
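        # Optional sanity check (a sketch, not part of the original test): inspect one
        # of the generated hdf5 files directly with h5py. The internal group/dataset
        # names are not documented in this snippet, so we only list whatever is stored.
        import h5py
        with h5py.File(hdf5_files[0], 'r') as hdf5_file:
            hdf5_file.visit(print)  # print every group/dataset name in the file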

    def test_sample_batch(self):
        self.info = generate_dummy_dataset(self.data_saver,
                                           num_runs=20,
                                           input_size=(100, 100, 3),
                                           output_size=(3, ),
                                           continuous=False)
        max_num_batches = 2
        config_dict = {
            'data_directories': self.info['episode_directories'],
            'output_path': self.output_dir,
            'random_seed': 1,
            'batch_size': 3
        }
        data_loader = DataLoader(config=DataLoaderConfig().create(
            config_dict=config_dict))
        data_loader.load_dataset()
        first_batch = []
        index = 0
        for index, batch in enumerate(
                data_loader.sample_shuffled_batch(
                    max_number_of_batches=max_num_batches)):
            if index == 0:
                first_batch = deepcopy(batch)
            self.assertEqual(len(batch), config_dict['batch_size'])
        self.assertEqual(index, max_num_batches - 1)

        # test sampling seed for reproduction
        config_dict['random_seed'] = 2
        data_loader = DataLoader(config=DataLoaderConfig().create(
            config_dict=config_dict))
        data_loader.load_dataset()
        second_batch = []
        for index, batch in enumerate(
                data_loader.sample_shuffled_batch(
                    max_number_of_batches=max_num_batches)):
            second_batch = deepcopy(batch)
            break
        self.assertNotEqual(np.sum(np.asarray(first_batch.observations[0])),
                            np.sum(np.asarray(second_batch.observations[0])))
        config_dict['random_seed'] = 1
        data_loader = DataLoader(config=DataLoaderConfig().create(
            config_dict=config_dict))
        data_loader.load_dataset()
        third_batch = []
        for index, batch in enumerate(
                data_loader.sample_shuffled_batch(
                    max_number_of_batches=max_num_batches)):
            third_batch = deepcopy(batch)
            break
        self.assertEqual(np.sum(np.asarray(first_batch.observations[0])),
                         np.sum(np.asarray(third_batch.observations[0])))
class DomainAdaptationTrainer(Trainer):

    def __init__(self, config: TrainerConfig, network: BaseNet, quiet: bool = False):
        super().__init__(config, network, quiet=True)  # base Trainer stays quiet; optimizer, scheduler and logger are set up below

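        # epsilon weights the domain adaptation term in train():
        # loss = (1 - epsilon) * task_loss + epsilon * domain_loss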
        self._config.epsilon = 0.2 if self._config.epsilon == "default" else self._config.epsilon

        self.target_data_loader = DataLoader(config=self._config.target_data_loader_config)
        self.target_data_loader.load_dataset()
        # instantiate the configured domain adaptation criterion by name, defaulting to MMDLossZhao
        self._domain_adaptation_criterion = eval(f'{self._config.domain_adaptation_criterion}()') \
            if not self._config.domain_adaptation_criterion == 'default' else MMDLossZhao()
        self._domain_adaptation_criterion.to(self._device)

        if not quiet:
            self._optimizer = eval(f'torch.optim.{self._config.optimizer}')(params=self._net.parameters(),
                                                                            lr=self._config.learning_rate,
                                                                            weight_decay=self._config.weight_decay)

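            # LambdaLR multiplies the base learning rate by this factor each epoch:
            # a linear decay from 1 to 0 over scheduler_config.number_of_epochs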
            lambda_function = lambda f: 1 - f / self._config.scheduler_config.number_of_epochs
            self._scheduler = torch.optim.lr_scheduler.LambdaLR(self._optimizer, lr_lambda=lambda_function) \
                if self._config.scheduler_config is not None else None

            self._logger = get_logger(name=get_filename_without_extension(__file__),
                                      output_path=config.output_path,
                                      quiet=False)
            cprint('Started.', self._logger)

    def train(self, epoch: int = -1, writer=None) -> str:
        self.put_model_on_device()
        total_error = []
        task_error = []
        domain_error = []
        for source_batch, target_batch in zip(self.data_loader.sample_shuffled_batch(),
                                              self.target_data_loader.sample_shuffled_batch()):
            self._optimizer.zero_grad()
            targets = data_to_tensor(source_batch.actions).type(self._net.dtype).to(self._device)
            # task loss
            predictions = self._net.forward(source_batch.observations, train=True)
            task_loss = (1 - self._config.epsilon) * self._criterion(predictions, targets).mean()

            # add domain adaptation loss
            domain_loss = self._config.epsilon * self._domain_adaptation_criterion(
                self._net.get_features(source_batch.observations, train=True),
                self._net.get_features(target_batch.observations, train=True))

            loss = task_loss + domain_loss
            loss.backward()
            if self._config.gradient_clip_norm != -1:
                nn.utils.clip_grad_norm_(self._net.parameters(),
                                         self._config.gradient_clip_norm)
            self._optimizer.step()
            self._net.global_step += 1
            task_error.append(task_loss.cpu().detach())
            domain_error.append(domain_loss.cpu().detach())
            total_error.append(loss.cpu().detach())
        self.put_model_back_to_original_device()

        if self._scheduler is not None:
            self._scheduler.step()

        task_error_distribution = Distribution(task_error)
        domain_error_distribution = Distribution(domain_error)
        total_error_distribution = Distribution(total_error)
        if writer is not None:
            writer.set_step(self._net.global_step)
            writer.write_distribution(task_error_distribution, 'training/task_error')
            writer.write_distribution(domain_error_distribution, 'training/domain_error')
            writer.write_distribution(total_error_distribution, 'training/total_error')
            if self._config.store_output_on_tensorboard and epoch % 30 == 0:
                writer.write_output_image(predictions, 'source/predictions')
                writer.write_output_image(targets, 'source/targets')
                writer.write_output_image(torch.stack(source_batch.observations), 'source/inputs')
                writer.write_output_image(self._net.forward(target_batch.observations, train=True),
                                          'target/predictions')
                writer.write_output_image(torch.stack(target_batch.observations), 'target/inputs')

        return f' training task: {self._config.criterion} {task_error_distribution.mean: 0.3e} ' \
               f'[{task_error_distribution.std:0.2e}]' \
               f' domain: {self._config.domain_adaptation_criterion} {domain_error_distribution.mean: 0.3e} ' \
               f'[{domain_error_distribution.std:0.2e}]'
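
# Below is a minimal, self-contained sketch of what an MMD-style domain adaptation
# criterion (such as the MMDLossZhao used above) computes between source and target
# feature batches. The RBF kernel and fixed bandwidth are illustrative assumptions;
# the actual MMDLossZhao implementation is not shown in this snippet.
import torch
import torch.nn as nn


class RBFMMDLoss(nn.Module):
    """Biased maximum mean discrepancy estimate with a Gaussian (RBF) kernel."""

    def __init__(self, bandwidth: float = 1.0):
        super().__init__()
        self.bandwidth = bandwidth

    def _kernel(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        # pairwise squared Euclidean distances between rows of a and b
        squared_distances = torch.cdist(a, b) ** 2
        return torch.exp(-squared_distances / (2 * self.bandwidth ** 2))

    def forward(self, source_features: torch.Tensor, target_features: torch.Tensor) -> torch.Tensor:
        # flatten everything except the batch dimension so each sample is a vector
        source_features = source_features.flatten(start_dim=1)
        target_features = target_features.flatten(start_dim=1)
        k_ss = self._kernel(source_features, source_features).mean()
        k_tt = self._kernel(target_features, target_features).mean()
        k_st = self._kernel(source_features, target_features).mean()
        # MMD^2 = E[k(s, s')] + E[k(t, t')] - 2 E[k(s, t)]
        return k_ss + k_tt - 2 * k_st

# Example with hypothetical feature tensors: RBFMMDLoss()(source_feats, target_feats)
# returns a scalar that shrinks as the two feature distributions become more similar.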