def test_dataset_size(self):
    # Appending two identical experiences should double the dataset's memory footprint.
    dataset = Dataset()
    dataset.append(
        Experience(observation=torch.as_tensor([0] * 10),
                   action=torch.as_tensor([1] * 3),
                   reward=torch.as_tensor(0),
                   done=torch.as_tensor(2)))
    first_size = dataset.get_memory_size()
    dataset.append(
        Experience(observation=torch.as_tensor([0] * 10),
                   action=torch.as_tensor([1] * 3),
                   reward=torch.as_tensor(0),
                   done=torch.as_tensor(2)))
    self.assertEqual(2 * first_size, dataset.get_memory_size())
    # The same experience stored as float32 (4 bytes per element) should take
    # half the memory of the default int64 tensors (8 bytes per element) above.
    dataset = Dataset()
    dataset.append(
        Experience(observation=torch.as_tensor([0] * 10,
                                               dtype=torch.float32),
                   action=torch.as_tensor([1] * 3, dtype=torch.float32),
                   reward=torch.as_tensor(0, dtype=torch.float32),
                   done=torch.as_tensor(2, dtype=torch.float32)))
    second_size = dataset.get_memory_size()
    self.assertEqual(first_size, 2 * second_size)
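
# A minimal sketch of what get_memory_size() presumably measures, assuming the
# dataset stores Experience tuples of torch tensors (illustrative only, not the
# actual Dataset API): int64 elements occupy 8 bytes versus 4 bytes for float32,
# which is why the float32 dataset above is expected to be half the size.
def _approximate_memory_size(experiences) -> int:
    return sum(t.element_size() * t.nelement()
               for e in experiences
               for t in (e.observation, e.action, e.reward, e.done))
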
    def _clean(self, filename_tag: str, runs: List[str]) -> None:
        total_data_points = 0
        filename_index = 0
        hdf5_data = Dataset()
        for run in tqdm(runs):
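            # skip runs that were not marked successful (no 'Success' file in the run directory)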
            if self._config.require_success:
                if not os.path.isfile(os.path.join(run, 'Success')):
                    continue
            # load the run's data into a dataset (at the configured input size)
            run_dataset = self._data_loader.load_dataset_from_directories(
                [run])
            if len(run_dataset) <= self._config.remove_first_n_timestamps:
                continue
            # remove first N frames
            for _ in range(self._config.remove_first_n_timestamps):
                run_dataset.pop()
            # subsample
            run_dataset.subsample(self._config.data_loader_config.subsample)
            # enforce max run length
            if self._config.max_run_length != -1:
                run_dataset.clip(self._config.max_run_length)
                assert len(run_dataset) <= self._config.max_run_length
            # augment with background noise and change target to binary map
            needs_binary_maps = (self._config.augment_background_noise != 0
                                 or self._config.augment_background_textured != 0)
            binary_maps = parse_binary_maps(
                run_dataset.observations,
                invert=self._config.invert_binary_maps) if needs_binary_maps else None
            if self._config.binary_maps_as_target:
                run_dataset = set_binary_maps_as_target(
                    run_dataset,
                    invert=self._config.invert_binary_maps,
                    binary_images=binary_maps,
                    smoothen_labels=self._config.smoothen_labels)

            if self._config.augment_background_noise != 0:
                run_dataset = augment_background_noise(
                    run_dataset,
                    p=self._config.augment_background_noise,
                    binary_images=binary_maps)
            if self._config.augment_background_textured != 0:
                run_dataset = augment_background_textured(
                    run_dataset,
                    texture_directory=self._config.texture_directory,
                    p=self._config.augment_background_textured,
                    p_empty=self._config.augment_empty_images,
                    binary_images=binary_maps)
            # store hdf5 file once the maximum dataset size is reached
            hdf5_data.extend(run_dataset)
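            # clear the loader's in-memory dataset before loading the next run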
            self._data_loader.empty_dataset()
            if hdf5_data.get_memory_size() > self._config.max_hdf5_size:
                if self._config.shuffle:
                    hdf5_data.shuffle()
                create_hdf5_file_from_dataset(
                    filename=os.path.join(
                        self._config.output_path,
                        f'{filename_tag}_{filename_index}.hdf5'),
                    dataset=hdf5_data)
                filename_index += 1
                total_data_points += len(hdf5_data)
                hdf5_data = Dataset()
        if len(hdf5_data) != 0:
            if self._config.shuffle:
                hdf5_data.shuffle()
            create_hdf5_file_from_dataset(
                filename=os.path.join(
                    self._config.output_path,
                    f'{filename_tag}_{filename_index}.hdf5'),
                dataset=hdf5_data)
            total_data_points += len(hdf5_data)
        print(f'Total data points: {total_data_points}')
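
# A minimal sketch of how a helper like create_hdf5_file_from_dataset could
# write the cleaned data to disk with h5py; the group and field names here are
# assumptions for illustration, not necessarily what the real helper uses.
def _write_dataset_to_hdf5(filename: str, observations, actions) -> None:
    import h5py
    import numpy as np
    with h5py.File(filename, 'w') as hdf5_file:
        group = hdf5_file.create_group('dataset')
        group.create_dataset('observations',
                             data=np.stack([np.asarray(o) for o in observations]))
        group.create_dataset('actions',
                             data=np.stack([np.asarray(a) for a in actions]))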