Example #1
import os
from glob import glob
from typing import List

# `get_data_dir` is provided by the surrounding codebase.

 def post_init(self):  # add default options
     if self.data_directories is None:
         self.data_directories = []
     assert self.subsample >= 1
     if self.input_size is None:
         self.input_size = []
     if self.hdf5_files is None:
         self.hdf5_files = []
     if self.data_directories:
         self.data_directories = [
             os.path.join(get_data_dir(os.environ['HOME']), d)
             if not d.startswith('/') else d for d in self.data_directories
         ]
         # if 'raw_data' is a subdirectory of d, expand it into its individual runs
         data_directories = []
         for d in self.data_directories:
             if os.path.isdir(os.path.join(d, 'raw_data')):
                 data_directories.extend(glob(f'{d}/raw_data/*'))
             else:
                 data_directories.append(d)
         self.data_directories = data_directories
     if self.hdf5_files:
         self.hdf5_files = [
             os.path.join(get_data_dir(os.environ['HOME']), hdf5_f)
             if not hdf5_f.startswith('/') else hdf5_f
             for hdf5_f in self.hdf5_files
         ]
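
For illustration, assuming `get_data_dir` resolves to '/data' and 'runA' contains a 'raw_data' subdirectory with two runs, the resolution above expands the list as follows (hypothetical values):

# data_directories = ['runA', '/abs/runB']
# after anchoring:          ['/data/runA', '/abs/runB']
# after raw_data expansion: ['/data/runA/raw_data/run_0',
#                            '/data/runA/raw_data/run_1',
#                            '/abs/runB']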
 def test_get_data_dir(self):
     # with the DATADIR environment variable set, it should be returned
     if "DATADIR" not in os.environ:
         os.environ["DATADIR"] = '/my/wonderful/data/dir'
     self.assertIn("DATADIR", os.environ)
     result = get_data_dir(os.environ['HOME'])
     self.assertEqual(result, os.environ["DATADIR"])
     # without DATADIR, the provided fallback directory should be returned
     del os.environ['DATADIR']
     self.assertNotIn("DATADIR", os.environ)
     result = get_data_dir(os.environ['HOME'])
     self.assertEqual(result, os.environ["HOME"])
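
The test above pins down the contract of `get_data_dir`; a minimal reconstruction consistent with it, as an assumption rather than the codebase's actual implementation:

def get_data_dir(fallback: str) -> str:
    # prefer an explicit DATADIR environment variable, else use the fallback
    return os.environ.get('DATADIR', fallback)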
 def test_load_imagenet_pretrained_checkpoint(self):
     base_config['architecture'] = 'auto_encoder_deeply_supervised'
     base_config['batch_normalisation'] = True
     network = eval(base_config['architecture']).Net(
         config=ArchitectureConfig().create(config_dict=base_config), )
     checkpoint = torch.load(
         os.path.join(get_data_dir(os.environ['PWD']), 'pretrained_models',
                      'auto_encoder_deeply_supervised', 'torch_checkpoints',
                      'checkpoint_latest.ckpt'))
     network.load_checkpoint(checkpoint['net_ckpt'])
 def adjust_hdf5_files(hdf5_files: List[str]) -> List[str]:
     # nested helper: `self.local_home` and `original_to_new_location_tuples`
     # are captured from the enclosing scope
     # anchor relative paths at the data directory
     hdf5_files = [
         os.path.join(get_data_dir(os.environ['HOME']),
                      original_hdf5_file) if
         not original_hdf5_file.startswith('/') else original_hdf5_file
         for original_hdf5_file in hdf5_files
     ]
     # give each file a fresh numbered name under the local home directory
     new_hdf5_files = [
         os.path.join(
             self.local_home,
             f"{len(original_to_new_location_tuples) + index}.hdf5")
         for index in range(len(hdf5_files))
     ]
     original_to_new_location_tuples.extend(
         zip(hdf5_files, new_hdf5_files))
     return new_hdf5_files
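
The same path-resolution idiom recurs throughout these snippets; factored into a standalone helper it would look roughly like this (hypothetical name, minimal sketch):

def resolve_path(path: str) -> str:
    """Leave absolute paths untouched; anchor relative ones at the data dir."""
    if path.startswith('/'):
        return path
    return os.path.join(get_data_dir(os.environ['HOME']), path)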
Example #5
import copy
import os
from typing import List, Optional

import numpy as np
import torch
from tqdm import tqdm

# `Dataset`, `get_data_dir` and the parse/blur/gradient helpers below come
# from the surrounding codebase.
def augment_background_textured(dataset: Dataset, texture_directory: str,
                                p: float = 0.0, p_empty: float = 0.0,
                                binary_images: Optional[List[np.ndarray]] = None) -> Dataset:
    """
    parse background and fore ground. Give fore ground random color and background a crop of a textured image.
    :param binary_images:
    :param p_empty: in case of augmentation with texture, potentially augment without any foreground and empty action map (hacky!)
    :param p: probability for each image to be augmented with background
    :param dataset: augmented dataset in which observations are adjusted with new back- and foregrounds
    :param texture_directory: directory with sub-directories for each texture
    :return: augmented dataset
    """
    # 1. load texture image paths and keep in RAM
    if not texture_directory.startswith('/'):
        texture_directory = os.path.join(get_data_dir(os.environ['HOME']), texture_directory)
    texture_paths = [os.path.join(texture_directory, sub_directory, image)
                     for sub_directory in os.listdir(texture_directory)
                     if os.path.isdir(os.path.join(texture_directory, sub_directory))
                     for image in os.listdir(os.path.join(texture_directory, sub_directory))
                     if image.endswith('.jpg')]
    assert len(texture_paths) != 0
    # 2. parse binary images to extract fore- and background
    if binary_images is None:
        binary_images = parse_binary_maps(copy.deepcopy(dataset.observations), invert=True)
    # 3. combine foreground and background in augmented dataset
    augmented_dataset = copy.deepcopy(dataset)
    augmented_dataset.observations = []
    num_channels = dataset.observations[0].shape[0]
    # iterate over all but the last mask (the commented-out motion-blur
    # variant below pairs mask index with index + 1)
    for index, image in tqdm(enumerate(binary_images[:-1])):
        if np.random.binomial(1, p):
            bg_image = np.random.choice(texture_paths)
            bg = load_and_preprocess_file(bg_image,
                                          size=(num_channels, image.shape[0], image.shape[1])).permute(1, 2, 0).numpy()
            if not np.random.binomial(1, p_empty):
                # blurred_mask = motion_blur_mask(binary_images[index], binary_images[index + 1], num_channels)
                blurred_mask = gaussian_blur_mask(binary_images[index], num_channels)
                fg = create_random_gradient_image(size=bg.shape, mean=bg.mean())
                new_img = torch.as_tensor(((1 - blurred_mask) * bg + blurred_mask * fg)).permute(2, 0, 1)
                augmented_dataset.observations.append(new_img)
            else:  # for ratio p_empty put only background
                augmented_dataset.observations.append(torch.as_tensor(bg).permute(2, 0, 1))
                augmented_dataset.actions[index] = torch.zeros_like(augmented_dataset.actions[index])
        else:
            augmented_dataset.observations.append(dataset.observations[index])
    return augmented_dataset
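
A minimal usage sketch; `load_dataset` and the directory name are hypothetical stand-ins and the probabilities are illustrative:

dataset = load_dataset('line_world')  # hypothetical loader returning a Dataset
augmented = augment_background_textured(dataset,
                                        texture_directory='dtd_textures',
                                        p=0.5, p_empty=0.1)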
 def post_init(self):  # add default options
     if not self.output_path.startswith('/'):
         self.output_path = os.path.join(get_data_dir(self.codebase_dir),
                                         self.output_path)
     if self.use_singularity:  # set singularity directory to gluster or opal
         if not self.singularity_dir.startswith('/'):
             if os.path.isdir(
                     '/gluster/visics/kkelchte/singularity_images'):
                 self.singularity_dir = os.path.join(
                     '/gluster/visics/kkelchte/singularity_images',
                     self.singularity_dir)
             else:
                 self.singularity_dir = os.path.join(
                     '/esat/opal/kkelchte/singularity_images',
                     self.singularity_dir)
         if self.singularity_file == '':
             # default to the alphabetically last (assumed most recent) image
             self.singularity_file = sorted(
                 glob.glob(f'{self.singularity_dir}/*'))[-1]
         elif not self.singularity_file.startswith('/'):
             self.singularity_file = os.path.join(self.singularity_dir,
                                                  self.singularity_file)
         assert os.path.isfile(self.singularity_file)
 def test_load_dronet_checkpoint(self):
     base_config['architecture'] = 'dronet'
     network = eval(base_config['architecture']).Net(
         config=ArchitectureConfig().create(config_dict=base_config), )
     checkpoint = torch.load(
         os.path.join(get_data_dir(os.environ['PWD']), 'pretrained_models',
                      'dronet', 'torch_checkpoints',
                      'checkpoint_latest.ckpt'))
     network.load_checkpoint(checkpoint['net_ckpt'])
     # compare layer weight sums against stored reference values with an
     # absolute tolerance; without abs() a large negative difference would
     # pass the assertion
     self.assertLess(abs(network.conv2d_1.weight.sum().item() - conv2d),
                     0.001)
     self.assertLess(
         abs(network.batch_normalization_1.weight.sum().item() -
             batch_normalization), 0.001)
     self.assertLess(abs(network.conv2d_2.weight.sum().item() - conv2d_1),
                     0.001)
     self.assertLess(
         abs(network.batch_normalization_2.weight.sum().item() -
             batch_normalization_1), 0.001)
     self.assertLess(abs(network.conv2d_4.weight.sum().item() - conv2d_3),
                     0.001)
     self.assertLess(abs(network.conv2d_3.weight.sum().item() - conv2d_2),
                     0.001)
     self.assertLess(
         abs(network.batch_normalization_3.weight.sum().item() -
             batch_normalization_2), 0.001)
     self.assertLess(abs(network.conv2d_5.weight.sum().item() - conv2d_4),
                     0.001)
     self.assertLess(
         abs(network.batch_normalization_4.weight.sum().item() -
             batch_normalization_3), 0.001)
     self.assertLess(abs(network.conv2d_7.weight.sum().item() - conv2d_6),
                     0.001)
     self.assertLess(abs(network.conv2d_6.weight.sum().item() - conv2d_5),
                     0.001)
     self.assertLess(
         abs(network.batch_normalization_5.weight.sum().item() -
             batch_normalization_4), 0.001)
     self.assertLess(abs(network.conv2d_8.weight.sum().item() - conv2d_7),
                     0.001)
     self.assertLess(
         abs(network.batch_normalization_6.weight.sum().item() -
             batch_normalization_5), 0.001)
     self.assertLess(abs(network.conv2d_10.weight.sum().item() - conv2d_9),
                     0.001)
     self.assertLess(abs(network.conv2d_9.weight.sum().item() - conv2d_8),
                     0.001)
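
The reference values (`conv2d`, `batch_normalization`, ...) are assumed to be module-level constants recorded from a known-good checkpoint; a sketch of how such values could be captured:

# record per-layer weight sums from a trusted network for later comparison
reference_sums = {name: param.sum().item()
                  for name, param in network.named_parameters()
                  if name.endswith('weight')}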
Example #8
            cprint(f'Terminated successfully? {bool(result)}', self._logger,
                   msg_type=MessageType.info if result else MessageType.warning)
        if self._data_saver is not None:
            self._data_saver.remove()
        if self._trainer is not None:
            self._trainer.remove()
        if self._evaluator is not None:
            self._evaluator.remove()
        if self._net is not None:
            self._net.remove()
        if self._episode_runner is not None:
            self._episode_runner.remove()
        for h in self._logger.handlers:
            h.close()


if __name__ == "__main__":
    arguments = Parser().parse_args()
    config_file = arguments.config
    if arguments.rm:  # optionally wipe a previous run's output_path first
        with open(config_file, 'r') as f:
            configuration = yaml.load(f, Loader=yaml.FullLoader)
        if not configuration['output_path'].startswith('/'):
            configuration['output_path'] = os.path.join(get_data_dir(os.environ['HOME']), configuration['output_path'])
        shutil.rmtree(configuration['output_path'], ignore_errors=True)

    experiment_config = ExperimentConfig().create(config_file=config_file,
                                                  seed=arguments.seed)
    experiment = Experiment(experiment_config)
    experiment.run()
    experiment.shutdown()
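
For completeness, a minimal config this entry point could consume; the file name and value are illustrative, and `output_path` is the only key the code above reads directly:

import yaml

with open('demo_config.yml', 'w') as f:
    yaml.dump({'output_path': 'experiments/demo'}, f)
# python experiment.py --config demo_config.yml --rm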