Example #1
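
# NOTE: the imports below are reconstructed from usage in this snippet; the
# project-specific helpers (DataLoaderConfig, Dataset, get_logger, cprint,
# MessageType, get_filename_without_extension, load_dataset_from_hdf5,
# load_run, balance_weights_over_actions, select, parse_binary_maps, ...)
# come from the surrounding codebase and their exact import paths are not shown here.
import os
from typing import Generator, List

import numpy as np
from tqdm import tqdm
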
class DataLoader:
    def __init__(self, config: DataLoaderConfig):
        self._config = config
        self._logger = get_logger(
            name=get_filename_without_extension(__file__),
            output_path=config.output_path,
            quiet=False)
        cprint(f'Started.', self._logger)
        self._dataset = Dataset()
        self._num_runs = 0
        self._probabilities: List = []
        self.seed()
        self._hdf5_file_index = -1

    def seed(self, seed: int = None):
        np.random.seed(self._config.random_seed if seed is None else seed)

    def update_data_directories_with_raw_data(self):
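        # Register every run directory found under <output_path>/raw_data in the
        # config's data_directories list and drop duplicates.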
        if self._config.data_directories is None:
            self._config.data_directories = []
        for d in sorted(
                os.listdir(os.path.join(self._config.output_path,
                                        'raw_data'))):
            self._config.data_directories.append(
                os.path.join(self._config.output_path, 'raw_data', d))
        self._config.data_directories = list(set(
            self._config.data_directories))

    def load_dataset(self):
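        # Prefer pre-packed HDF5 files when they are configured; otherwise fall
        # back to loading the raw run directories listed in the config.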
        if len(self._config.hdf5_files) != 0:
            if self._config.loop_over_hdf5_files:
                self._dataset = Dataset()
                self._hdf5_file_index += 1
                self._hdf5_file_index %= len(self._config.hdf5_files)
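                # Keep trying files until one loads; a file that raises OSError
                # is dropped from the list and the next one is attempted.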
                while len(self._dataset) == 0:
                    try:
                        self._dataset.extend(
                            load_dataset_from_hdf5(
                                self._config.hdf5_files[self._hdf5_file_index],
                                input_size=self._config.input_size))
                    except OSError:
                        cprint(
                            f'Failed to load {self._config.hdf5_files[self._hdf5_file_index]}',
                            self._logger,
                            msg_type=MessageType.warning)
                        del self._config.hdf5_files[self._hdf5_file_index]
                        if len(self._config.hdf5_files) == 0:
                            # avoid a modulo-by-zero when every file failed to load
                            raise OSError('none of the configured HDF5 files could be loaded')
                        self._hdf5_file_index %= len(self._config.hdf5_files)
                cprint(
                    f'Loaded {len(self._dataset)} datapoints from {self._config.hdf5_files[self._hdf5_file_index]}',
                    self._logger,
                    msg_type=MessageType.warning if len(
                        self._dataset.observations) == 0 else MessageType.info)
            else:
                for hdf5_file in self._config.hdf5_files:
                    self._dataset.extend(
                        load_dataset_from_hdf5(
                            hdf5_file, input_size=self._config.input_size))
                cprint(
                    f'Loaded {len(self._dataset)} datapoints from {self._config.hdf5_files}',
                    self._logger,
                    msg_type=MessageType.warning if len(
                        self._dataset.observations) == 0 else MessageType.info)
        else:
            self.load_dataset_from_directories(self._config.data_directories)

        if self._config.subsample != 1:
            self._dataset.subsample(self._config.subsample)

        if self._config.balance_over_actions:
            self._probabilities = balance_weights_over_actions(self._dataset)

    def load_dataset_from_directories(self,
                                      directories: List[str] = None
                                      ) -> Dataset:
        directory_generator = tqdm(directories, ascii=True, desc=__name__) \
            if len(directories) > 10 else directories
        for directory in directory_generator:
            run = load_run(directory,
                           arrange_according_to_timestamp=False,
                           input_size=self._config.input_size,
                           scope=self._config.input_scope)
            if len(run) != 0:
                self._dataset.extend(experiences=run)
        cprint(
            f'Loaded {len(self._dataset)} data points from {len(directories)} directories',
            self._logger,
            msg_type=MessageType.warning
            if len(self._dataset) == 0 else MessageType.info)
        return self._dataset

    def empty_dataset(self) -> None:
        self._dataset = Dataset()

    def set_dataset(self, ds: Dataset = None) -> None:
        if ds is not None:
            self._dataset = ds
        else:
            self._dataset = Dataset()
            self.update_data_directories_with_raw_data()
            self.load_dataset()

    def get_dataset(self) -> Dataset:
        return self._dataset

    def get_data_batch(self) -> Generator[Dataset, None, None]:
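        # Yield the dataset in consecutive, unshuffled slices of batch_size;
        # a batch_size of -1 returns the whole dataset as a single batch.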
        if len(self._dataset) == 0 or self._config.loop_over_hdf5_files:
            self.load_dataset()
        index = 0
        while index < len(self._dataset):
            batch = Dataset()
            end_index = min(index + self._config.batch_size, len(self._dataset)) \
                if self._config.batch_size != -1 else len(self._dataset)
            batch.observations = self._dataset.observations[index:end_index]
            batch.actions = self._dataset.actions[index:end_index]
            batch.done = self._dataset.done[index:end_index]
            batch.rewards = self._dataset.rewards[index:end_index]
            index = (index + self._config.batch_size
                     if self._config.batch_size != -1 else len(self._dataset))
            yield batch

    def sample_shuffled_batch(self, max_number_of_batches: int = 1000) \
            -> Generator[Dataset, None, None]:
        """
        randomly shuffle data samples in runs in dataset and provide them as ready run objects
        :param batch_size: number of samples or datapoints in one batch
        :param max_number_of_batches: define an upperbound in number of batches to end epoch
        :param dataset: list of runs with inputs, outputs and batches
        :return: yield a batch up until all samples are done
        """
        if len(self._dataset) == 0 or self._config.loop_over_hdf5_files:
            self.load_dataset()
        # Get data indices:
        batch_count = 0
        while batch_count < min(
                len(self._dataset),
                max_number_of_batches * self._config.batch_size):
            sample_indices = np.random.choice(
                list(range(len(self._dataset))),
                size=self._config.batch_size,
                replace=len(self._dataset) < self._config.batch_size,
                p=self._probabilities
                if len(self._probabilities) != 0 else None)
            batch = select(self._dataset, sample_indices)
            batch_count += len(batch)
            yield batch
        return

    def split_data(self, indices: np.ndarray,
                   *args) -> Generator[tuple, None, None]:
        """
        Split the indices in batches of configs batch_size and select the data in args.
        :param indices: possible indices to be selected. If all indices can be selected, provide empty array.
        :param args: lists or tensors from which the corresponding data according to the indices is selected.
        :return: provides a tuple in the same order as the args with the selected data.
        """
        if len(indices) == 0:
            indices = np.arange(len(self._dataset))
        np.random.shuffle(indices)
        splits = np.array_split(
            indices, max(1, int(len(self._dataset) / self._config.batch_size)))
        for selected_indices in splits:
            return_tuple = tuple(select(data, selected_indices) for data in args)
            yield return_tuple

    def remove(self):
        for handler in self._logger.handlers:
            handler.close()

    def _clean(self, filename_tag: str, runs: List[str]) -> None:
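        # Turn raw run directories into (optionally shuffled) HDF5 shards:
        # skip unsuccessful runs, drop the first N frames, subsample, clip to
        # the maximum run length, optionally augment the background, and write
        # a new <filename_tag>_<index>.hdf5 shard whenever the size cap is hit.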
        total_data_points = 0
        filename_index = 0
        hdf5_data = Dataset()
        for run in tqdm(runs):
            if self._config.require_success:
                if not os.path.isfile(os.path.join(run, 'Success')):
                    continue
            # load data in dataset in input size
            run_dataset = self._data_loader.load_dataset_from_directories(
                [run])
            if len(run_dataset) <= self._config.remove_first_n_timestamps:
                continue
            # remove first N frames
            for _ in range(self._config.remove_first_n_timestamps):
                run_dataset.pop()
            # subsample
            run_dataset.subsample(self._config.data_loader_config.subsample)
            # enforce max run length
            if self._config.max_run_length != -1:
                run_dataset.clip(self._config.max_run_length)
                assert len(run_dataset) <= self._config.max_run_length
            # augment with background noise and change target to binary map

            binary_maps = parse_binary_maps(run_dataset.observations, invert=self._config.invert_binary_maps) \
                if self._config.augment_background_noise != 0 or self._config.augment_background_textured != 0 else None
            if self._config.binary_maps_as_target:
                run_dataset = set_binary_maps_as_target(
                    run_dataset,
                    invert=self._config.invert_binary_maps,
                    binary_images=binary_maps,
                    smoothen_labels=self._config.smoothen_labels)

            if self._config.augment_background_noise != 0:
                run_dataset = augment_background_noise(
                    run_dataset,
                    p=self._config.augment_background_noise,
                    binary_images=binary_maps)
            if self._config.augment_background_textured != 0:
                run_dataset = augment_background_textured(
                    run_dataset,
                    texture_directory=self._config.texture_directory,
                    p=self._config.augment_background_textured,
                    p_empty=self._config.augment_empty_images,
                    binary_images=binary_maps)
            # store dhf5 file once max dataset size is reached
            hdf5_data.extend(run_dataset)
            self._data_loader.empty_dataset()
            if hdf5_data.get_memory_size() > self._config.max_hdf5_size:
                if self._config.shuffle:
                    hdf5_data.shuffle()
                create_hdf5_file_from_dataset(
                    filename=os.path.join(
                        self._config.output_path,
                        f'{filename_tag}_{filename_index}.hdf5'),
                    dataset=hdf5_data)
                filename_index += 1
                total_data_points += len(hdf5_data)
                hdf5_data = Dataset()
        if len(hdf5_data) != 0:
            if self._config.shuffle:
                hdf5_data.shuffle()
            create_hdf5_file_from_dataset(
                filename=os.path.join(
                    self._config.output_path,
                    f'{filename_tag}_{filename_index}.hdf5'),
                dataset=hdf5_data)
            total_data_points += len(hdf5_data)
        print(f'Total data points: {total_data_points}')
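
# Minimal usage sketch (assumption: a DataLoaderConfig instance `config` has been
# built elsewhere with at least output_path, batch_size and random_seed set; its
# exact fields and defaults are not shown in this snippet):
#
#   loader = DataLoader(config=config)
#   loader.update_data_directories_with_raw_data()  # pick up runs under <output_path>/raw_data
#   loader.load_dataset()
#   for batch in loader.sample_shuffled_batch(max_number_of_batches=100):
#       ...  # feed batch.observations / batch.actions to a trainer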