def select_channels(self, dataset: Dataset):
     """Keep only the channels listed in ``self.channels_to_select``.

     A 3-D ``data_array`` gets a new leading axis first, and a 3-D
     selection result (e.g. when a single integer channel index collapses
     the last axis) gets a new trailing axis, so the stored array is 4-D
     for 3-D or 4-D inputs.

     Args:
         dataset: Object whose ``data_array`` attribute is indexed along
             its last axis; it is updated in place.

     Returns:
         The same ``dataset`` object, with ``data_array`` replaced.
     """
     images = dataset.data_array
     if len(images.shape) == 3:
         # Store the expanded array right away so the dataset already
         # holds the 4-D version even if the channel indexing below fails.
         images = images[np.newaxis, ...]
         dataset.data_array = images
     picked = images[..., self.channels_to_select]
     if len(picked.shape) == 3:
         # The selection dropped the last axis; put it back.
         picked = picked[..., np.newaxis]
     dataset.data_array = picked
     return dataset
 def normalize_by_image(self, dataset: Dataset):
     """Rescale each (sample, channel) plane using its NaN-ignoring range.

     For every index pair on axes 0 and 3, the minimum over axes 1 and 2
     is subtracted and the result is divided by the maximum of the
     shifted values. NaNs are ignored when computing the extrema.
     Presumably the layout is (sample, height, width, channel) -- confirm
     against the caller.

     Note:
         The subtraction is performed in place, so the array originally
         held by ``dataset`` is modified. A plane whose values are all
         equal ends up divided by zero.

     Args:
         dataset: Object whose ``data_array`` is a 4-D numeric array.

     Returns:
         The same ``dataset`` object with the rescaled array stored in
         ``data_array``.
     """
     plane_axes = (1, 2)
     pixels = dataset.data_array
     lowest = np.nanmin(pixels, axis=plane_axes)
     # In-place shift: mutates the array the dataset currently holds.
     pixels -= lowest[:, np.newaxis, np.newaxis, :]
     highest = np.nanmax(pixels, axis=plane_axes)
     dataset.data_array = pixels / highest[:, np.newaxis, np.newaxis, :]
     return dataset
 def normalize_by_sample(self, dataset: Dataset):
     """Rescale each sample (index on axis 0) using its NaN-ignoring range.

     The minimum over axes 1, 2 and 3 is subtracted from each sample and
     the result is divided by the maximum of the shifted values, so one
     offset and one scale are shared by everything in that sample.

     Note:
         The subtraction is performed in place, so the array originally
         held by ``dataset`` is modified. A sample whose values are all
         equal ends up divided by zero.

     Args:
         dataset: Object whose ``data_array`` is a 4-D numeric array.

     Returns:
         The same ``dataset`` object with the rescaled array stored in
         ``data_array``.
     """
     reduce_axes = (1, 2, 3)
     values = dataset.data_array
     sample_min = np.nanmin(values, axis=reduce_axes)
     # In-place shift: mutates the array the dataset currently holds.
     values -= sample_min[..., np.newaxis, np.newaxis, np.newaxis]
     sample_max = np.nanmax(values, axis=reduce_axes)
     scaled = values / sample_max[..., np.newaxis, np.newaxis, np.newaxis]
     dataset.data_array = scaled
     return dataset
 def nan_to_num(self, dataset: Dataset):
     """Overwrite every NaN in ``dataset.data_array`` with a fixed number.

     The replacement value is ``self.number_to_replace_nans``. The array is
     modified in place. When ``self.verbose`` is truthy, a one-line summary
     is printed using the length of what ``self._get_nans_samples_idx``
     returns.

     Args:
         dataset: Object whose ``data_array`` supports boolean-mask
             assignment.

     Returns:
         The same ``dataset`` object.
     """
     array = dataset.data_array
     fill_value = self.number_to_replace_nans
     # Queried before the replacement, while the NaNs are still present.
     affected_idx = self._get_nans_samples_idx(array)
     if self.verbose:
         message = ('%i samples with NaNs. NaNs replaced with number %s' %
                    (len(affected_idx), str(fill_value)))
         print(message)
     nan_mask = np.isnan(array)
     array[nan_mask] = fill_value
     dataset.data_array = array
     return dataset
 def crop_at_center(self, dataset: Dataset):
     """Crop every sample to a centered ``crop_size`` x ``crop_size`` window.

     The window is computed from the length of axis 1 and the same index
     range is applied to axes 1 and 2, so samples are assumed to be square
     (NOTE(review): confirm against the caller). ``crop_size`` must have
     the same parity as that length so the window can be exactly centered.

     Args:
         dataset: Object whose ``data_array`` is a 4-D array; it is updated
             in place.

     Returns:
         The same ``dataset`` object. It is returned untouched when
         ``self.crop_size`` is ``None``.

     Raises:
         AssertionError: If ``crop_size`` and the length of axis 1 differ
             in parity.
         ValueError: If ``crop_size`` is negative or larger than the
             length of axis 1.
     """
     if self.crop_size is None:
         return dataset
     samples = dataset.data_array
     side_length = samples.shape[1]
     # Raised explicitly rather than via ``assert`` so the check is not
     # stripped under ``python -O``; the exception type is kept for
     # backward compatibility.
     if side_length % 2 != self.crop_size % 2:
         raise AssertionError(
             'crop_size (%s) and image side (%s) must have the same parity'
             % (self.crop_size, side_length))
     # The parity check guarantees an integral value at this point.
     crop_size = int(self.crop_size)
     # An oversized crop would give a negative slice start, which NumPy
     # interprets as counting from the end and silently returns a
     # wrong-sized window.
     if not 0 <= crop_size <= side_length:
         raise ValueError(
             'crop_size (%s) must be between 0 and the image side (%s)'
             % (self.crop_size, side_length))
     # With matching parity the margin is even, so this is exactly centered.
     crop_begin = (side_length - crop_size) // 2
     crop_end = crop_begin + crop_size
     dataset.data_array = samples[:, crop_begin:crop_end,
                                  crop_begin:crop_end, :]
     return dataset
 def check_single_image(self, dataset: Dataset):
     """Give a 3-D ``data_array`` a new leading axis of length one.

     Arrays of any other rank are left untouched.

     Args:
         dataset: Object exposing a ``data_array`` attribute with a
             ``shape``; it is updated in place.

     Returns:
         The same ``dataset`` object.
     """
     image = dataset.data_array
     if len(image.shape) != 3:
         return dataset
     dataset.data_array = image[np.newaxis, ...]
     return dataset