    def setUp(self) -> None:
        self.output_dir = f'{get_data_dir(os.environ["HOME"])}/test_dir/{get_filename_without_extension(__file__)}'
        if not os.path.isdir(self.output_dir):
            os.makedirs(self.output_dir)
        config_dict = {'output_path': self.output_dir, 'store_hdf5': True}
        config = DataSaverConfig().create(config_dict=config_dict)
        self.data_saver = DataSaver(config=config)
def generate_random_dataset_in_raw_data(
        output_dir: str,
        num_runs: int = 20,
        input_size: tuple = (100, 100, 3),
        output_size: tuple = (1, ),
        continuous: bool = True,
        fixed_input_value: Union[float, np.ndarray] = None,
        fixed_output_value: Union[float, np.ndarray] = None,
        store_hdf5: bool = False) -> dict:
    """Generate data, stored in raw_data directory of output_dir"""
    data_saver = DataSaver(config=DataSaverConfig().create(
        config_dict={
            'output_path': output_dir,
            'store_hdf5': store_hdf5,
            'separate_raw_data_runs': True
        }))
    info = generate_dummy_dataset(data_saver,
                                  num_runs=num_runs,
                                  input_size=input_size,
                                  output_size=output_size,
                                  continuous=continuous,
                                  fixed_input_value=fixed_input_value,
                                  fixed_output_value=fixed_output_value,
                                  store_hdf5=store_hdf5)
    return info
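
# Minimal usage sketch (illustrative addition, not part of the original source): the
# output directory below is hypothetical; the returned dict follows the structure of
# generate_dummy_dataset, i.e. 'episode_lengths' and 'episode_directories' keys.
if __name__ == '__main__':
    info = generate_random_dataset_in_raw_data(output_dir='/tmp/example_dataset',
                                               num_runs=5,
                                               input_size=(64, 64, 3),
                                               store_hdf5=True)
    print(f"stored {sum(info['episode_lengths'])} frames over "
          f"{len(info['episode_directories'])} runs")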
Example #7
def generate_dummy_dataset(data_saver: DataSaver,
                           num_runs: int = 10,
                           input_size: tuple = (100, 100, 3),
                           output_size: tuple = (1, ),
                           continuous: bool = True,
                           fixed_input_value: float = None,
                           fixed_output_value: float = None,
                           store_hdf5: bool = False) -> dict:
    episode_lengths = []
    episode_dirs = []
    for run in range(num_runs):
        episode_length = 0
        if run > 0:
            data_saver.update_saving_directory()
        for count, experience in enumerate(
                experience_generator(input_size=input_size,
                                     output_size=output_size,
                                     continuous=continuous,
                                     fixed_input_value=fixed_input_value,
                                     fixed_output_value=fixed_output_value)):
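            # only experiences with a known termination type count toward the episode length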
            if experience.done != TerminationType.Unknown:
                episode_length += 1
            data_saver.save(experience=experience)
        episode_lengths.append(episode_length)
        episode_dirs.append(data_saver.get_saving_directory())
    if store_hdf5:
        data_saver.create_train_validation_hdf5_files()
    return {
        'episode_lengths': episode_lengths,
        'episode_directories': episode_dirs
    }
    def test_big_data_hdf5_loop(self):
        # create 3 datasets as hdf5 files
        hdf5_files = []
        infos = []
        for index in range(3):
            output_path = os.path.join(self.output_dir, f'ds{index}')
            os.makedirs(output_path, exist_ok=True)
            config_dict = {
                'output_path': output_path,
                'store_hdf5': True,
                'training_validation_split': 1.0
            }
            config = DataSaverConfig().create(config_dict=config_dict)
            self.data_saver = DataSaver(config=config)
            infos.append(
                generate_dummy_dataset(self.data_saver,
                                       num_runs=2,
                                       input_size=(3, 10, 10),
                                       fixed_input_value=(0.3 * index) *
                                       np.ones((3, 10, 10)),
                                       store_hdf5=True))
            self.assertTrue(
                os.path.isfile(os.path.join(output_path, 'train.hdf5')))
            hdf5_files.append(os.path.join(output_path, 'train.hdf5'))
            hdf5_files.append(os.path.join(output_path, 'wrong.hdf5'))  # path to an hdf5 file that is never created

        # create a data loader with loop_over_hdf5_files enabled, pointing at the three hdf5 training sets
        conf = {
            'output_path': self.output_dir,
            'hdf5_files': hdf5_files,
            'batch_size': 15,
            'loop_over_hdf5_files': True
        }
        loader = DataLoader(DataLoaderConfig().create(config_dict=conf))

        # sample data batches and check that the loader advances to the next hdf5 file after each full pass, wrapping around at the end
        for batch in loader.get_data_batch():
            self.assertAlmostEqual(batch.observations[0][0, 0, 0].item(), 0)
        for batch in loader.get_data_batch():
            self.assertAlmostEqual(batch.observations[0][0, 0, 0].item(), 0.3,
                                   2)
        for batch in loader.get_data_batch():
            self.assertAlmostEqual(batch.observations[0][0, 0, 0].item(), 0.6,
                                   2)
        for batch in loader.get_data_batch():
            self.assertAlmostEqual(batch.observations[0][0, 0, 0].item(), 0, 2)
        for batch in loader.sample_shuffled_batch():
            self.assertAlmostEqual(batch.observations[0][0, 0, 0].item(), 0.3,
                                   2)
        for batch in loader.sample_shuffled_batch():
            self.assertAlmostEqual(batch.observations[0][0, 0, 0].item(), 0.6,
                                   2)
        for batch in loader.sample_shuffled_batch():
            self.assertAlmostEqual(batch.observations[0][0, 0, 0].item(), 0, 2)
Example #10
    def setUp(self) -> None:
        self.output_dir = f'{os.environ["PWD"]}/test_dir/{get_filename_without_extension(__file__)}'
        if not os.path.isdir(self.output_dir):
            os.makedirs(self.output_dir)
        config_dict = {'output_path': self.output_dir, 'store_hdf5': True}
        config = DataSaverConfig().create(config_dict=config_dict)
        self.data_saver = DataSaver(config=config)
        self.info = generate_dummy_dataset(self.data_saver,
                                           num_runs=20,
                                           input_size=(100, 100, 3),
                                           output_size=(3, ),
                                           continuous=False)
class TestDataSaver(unittest.TestCase):
    def setUp(self) -> None:
        self.output_dir = f'{get_data_dir(os.environ["PWD"])}/test_dir/{get_filename_without_extension(__file__)}'
        if not os.path.isdir(self.output_dir):
            os.makedirs(self.output_dir)
        self.data_saver = None

    def test_experience_generator(self):
        for count, experience in enumerate(experience_generator()):
            if count == 0:
                self.assertEqual(experience.done, TerminationType.Unknown)
        self.assertTrue(experience.done in [
            TerminationType.Done, TerminationType.Success,
            TerminationType.Failure
        ])

    def test_data_storage_in_raw_data(self):
        config_dict = {
            'output_path': self.output_dir,
            'separate_raw_data_runs': True
        }
        config = DataSaverConfig().create(config_dict=config_dict)
        self.data_saver = DataSaver(config=config)
        info = generate_dummy_dataset(self.data_saver, num_runs=2)
        for total, episode_dir in zip(info['episode_lengths'],
                                      info['episode_directories']):
            self.assertEqual(
                len(
                    os.listdir(
                        os.path.join(self.output_dir, 'raw_data', episode_dir,
                                     'observation'))), total)
            with open(
                    os.path.join(self.output_dir, 'raw_data', episode_dir,
                                 'action.data')) as f:
                expert_controls = f.readlines()
                self.assertEqual(len(expert_controls), total)

    def test_data_storage_in_raw_data_with_data_size_limit(self):
        config_dict = {
            'output_path': self.output_dir,
            'max_size': 25,
            'separate_raw_data_runs': True
        }
        config = DataSaverConfig().create(config_dict=config_dict)
        self.data_saver = DataSaver(config=config)
        first_info = generate_dummy_dataset(self.data_saver, num_runs=2)
        self.assertEqual(sum(first_info['episode_lengths']),
                         self.data_saver._frame_counter)
        self.data_saver.update_saving_directory()
        second_info = generate_dummy_dataset(self.data_saver, num_runs=2)
        self.assertTrue(
            (sum(first_info['episode_lengths']) +
             sum(second_info['episode_lengths'])) > config_dict['max_size'])
        self.assertTrue(
            self.data_saver._frame_counter <= config_dict['max_size'])
        raw_data_dir = os.path.dirname(self.data_saver.get_saving_directory())
        count_actual_frames = sum([
            len(
                os.listdir(
                    os.path.join(raw_data_dir, episode_dir, 'observation')))
            for episode_dir in os.listdir(raw_data_dir)
        ])
        self.assertEqual(count_actual_frames, self.data_saver._frame_counter)

    def test_create_train_validation_hdf5_files(self):
        num_runs = 10
        split = 0.7
        config_dict = {
            'output_path': self.output_dir,
            'training_validation_split': split,
            'store_hdf5': True,
            'separate_raw_data_runs': True
        }
        config = DataSaverConfig().create(config_dict=config_dict)
        self.data_saver = DataSaver(config=config)
        info = generate_dummy_dataset(self.data_saver, num_runs=num_runs)
        self.data_saver.create_train_validation_hdf5_files()

        config = DataLoaderConfig().create(
            config_dict={
                'output_path': self.output_dir,
                'hdf5_files': [os.path.join(self.output_dir, 'train.hdf5')]
            })
        training_data_loader = DataLoader(config=config)
        training_data_loader.load_dataset()
        training_data = training_data_loader.get_dataset()

        config = DataLoaderConfig().create(
            config_dict={
                'output_path': self.output_dir,
                'hdf5_files':
                [os.path.join(self.output_dir, 'validation.hdf5')]
            })
        validation_data_loader = DataLoader(config=config)
        validation_data_loader.load_dataset()
        validation_data = validation_data_loader.get_dataset()

        self.assertEqual(len(training_data),
                         sum(info['episode_lengths'][:int(split * num_runs)]))
        self.assertEqual(len(validation_data),
                         sum(info['episode_lengths'][int(split * num_runs):]))

    def test_create_hdf5_files_subsampled_in_time(self):
        num_runs = 10
        split = 1.0
        subsample = 3
        config_dict = {
            'output_path': self.output_dir,
            'training_validation_split': split,
            'store_hdf5': True,
            'subsample_hdf5': subsample,
            'separate_raw_data_runs': True
        }
        config = DataSaverConfig().create(config_dict=config_dict)
        self.data_saver = DataSaver(config=config)
        info = generate_dummy_dataset(self.data_saver, num_runs=num_runs)
        self.data_saver.create_train_validation_hdf5_files()

        config = DataLoaderConfig().create(
            config_dict={
                'output_path': self.output_dir,
                'hdf5_files': [os.path.join(self.output_dir, 'train.hdf5')]
            })
        training_data_loader = DataLoader(config=config)
        training_data_loader.load_dataset()
        training_data = training_data_loader.get_dataset()

        self.assertEqual(
            len(training_data),
            sum([
                np.ceil((el - 1) / subsample) + 1
                for el in info['episode_lengths']
            ]))

    def test_empty_saving_directory(self):
        config_dict = {
            'output_path': self.output_dir,
            'separate_raw_data_runs': True
        }
        number_of_runs = 5
        config = DataSaverConfig().create(config_dict=config_dict)
        self.data_saver = DataSaver(config=config)
        info = generate_dummy_dataset(self.data_saver, num_runs=number_of_runs)
        self.assertEqual(
            len(os.listdir(os.path.join(self.output_dir, 'raw_data'))),
            number_of_runs)
        self.data_saver.empty_raw_data_in_output_directory()
        self.assertEqual(
            len(os.listdir(os.path.join(self.output_dir, 'raw_data'))), 0)

    def test_store_in_ram(self):
        config_dict = {
            'output_path': self.output_dir,
            'store_on_ram_only': True,
            'max_size': 10
        }
        number_of_runs = 10
        config = DataSaverConfig().create(config_dict=config_dict)
        self.data_saver = DataSaver(config=config)
        info = generate_dummy_dataset(self.data_saver, num_runs=number_of_runs)
        data = self.data_saver.get_dataset()
        self.assertEqual(len(data), config_dict['max_size'])
        for lst in [data.observations, data.actions, data.rewards, data.done]:
            self.assertEqual(len(lst), config_dict['max_size'])
            self.assertTrue(isinstance(lst[0], torch.Tensor))

    def tearDown(self) -> None:
        shutil.rmtree(self.output_dir, ignore_errors=True)
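

# Illustrative addition, not part of the original file: the test case above can be
# run directly with the standard unittest runner.
if __name__ == '__main__':
    unittest.main()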
Example #14
class Experiment:

    def __init__(self, config: ExperimentConfig):
        np.random.seed(123)
        self._epoch = 0
        self._config = config
        self._logger = get_logger(name=get_filename_without_extension(__file__),
                                  output_path=config.output_path,
                                  quiet=False)
        self._data_saver = DataSaver(config=config.data_saver_config) \
            if self._config.data_saver_config is not None else None
        self._environment = EnvironmentFactory().create(config.environment_config) \
            if self._config.environment_config is not None else None
        self._net = eval(config.architecture_config.architecture).Net(config=config.architecture_config) \
            if self._config.architecture_config is not None else None
        self._trainer = TrainerFactory().create(config=self._config.trainer_config, network=self._net) \
            if self._config.trainer_config is not None else None
        self._evaluator = Evaluator(config=self._config.evaluator_config, network=self._net) \
            if self._config.evaluator_config is not None else None
        self._tester = None  # created at the end to avoid loading too much data into RAM
        self._writer = None
        if self._config.tensorboard:  # Local import so code can run without tensorboard
            from src.core.tensorboard_wrapper import TensorboardWrapper
            self._writer = TensorboardWrapper(log_dir=config.output_path)
        self._episode_runner = EpisodeRunner(config=self._config.episode_runner_config,
                                             data_saver=self._data_saver,
                                             environment=self._environment,
                                             net=self._net,
                                             writer=self._writer) \
            if self._config.episode_runner_config is not None else None

        if self._config.load_checkpoint_found \
                and len(glob(f'{self._config.output_path}/torch_checkpoints/*.ckpt')) > 0:
            self.load_checkpoint(self.get_checkpoint_file(self._config.output_path))
        elif self._config.load_checkpoint_file is not None:
            self.load_checkpoint(self._config.load_checkpoint_file)
        elif self._config.load_checkpoint_dir is not None:
            if not self._config.load_checkpoint_dir.startswith('/'):
                self._config.load_checkpoint_dir = f'{get_data_dir(self._config.output_path)}/' \
                                                   f'{self._config.load_checkpoint_dir}'
            self.load_checkpoint(self.get_checkpoint_file(self._config.load_checkpoint_dir))

        cprint(f'Initiated.', self._logger)

    def run(self):
        for self._epoch in range(self._config.number_of_epochs):
            best_ckpt = False
            msg = f'{get_date_time_tag()} epoch: {self._epoch + 1} / {self._config.number_of_epochs}'
            if self._environment is not None:
                if self._data_saver is not None and self._config.data_saver_config.clear_buffer_before_episode:
                    self._data_saver.clear_buffer()
                output_msg, best_ckpt = self._episode_runner.run(
                    store_frames=(self._config.tb_render_every_n_epochs != -1 and
                                  self._epoch % self._config.tb_render_every_n_epochs == 0 and
                                  self._writer is not None))
                if self._data_saver is not None and self._config.data_saver_config.store_hdf5:
                    self._data_saver.create_train_validation_hdf5_files()
                msg += output_msg
            if self._trainer is not None:
                if self._data_saver is not None:  # refresh the training data with newly collected samples
                    self._trainer.data_loader.set_dataset(
                        self._data_saver.get_dataset() if self._config.data_saver_config.store_on_ram_only else None
                    )
                msg += self._trainer.train(epoch=self._epoch, writer=self._writer)
            if self._evaluator is not None:  # a best checkpoint is flagged when the validation error reaches a new minimum
                output_msg, best_ckpt = self._evaluator.evaluate(epoch=self._epoch, writer=self._writer)
                msg += output_msg
            if self._config.run_test_episodes:
                output_msg, best_ckpt = self._episode_runner.run(
                    store_frames=(self._config.tb_render_every_n_epochs != -1 and
                                  self._epoch % self._config.tb_render_every_n_epochs == 0 and
                                  self._writer is not None),
                    test=True,
                    # ! adversarial tag should be in architecture name
                    adversarial='adversarial' in self._config.architecture_config.architecture
                )
                msg += output_msg
            if self._config.save_checkpoint_every_n != -1 and \
                    (self._epoch % self._config.save_checkpoint_every_n == 0 or
                     self._epoch == self._config.number_of_epochs - 1):
                self.save_checkpoint(tag=f'{self._epoch:05d}')
            if best_ckpt and self._config.save_checkpoint_every_n != -1:
                self.save_checkpoint(tag='best')
            cprint(msg, self._logger)
        if self._trainer is not None:
            self._trainer.data_loader.empty_dataset()
        if self._evaluator is not None and self._config.evaluator_config.evaluate_extensive:
            self._evaluator.evaluate_extensive()
        if self._evaluator is not None:
            self._evaluator.data_loader.empty_dataset()
        self._tester = Evaluator(config=self._config.tester_config, network=self._net) \
            if self._config.tester_config is not None else None
        if self._tester is not None:
            output_msg, _ = self._tester.evaluate(epoch=self._epoch, writer=self._writer, tag='test')
            cprint(f'Testing: {output_msg}', self._logger)
            if self._config.tester_config.evaluate_extensive:
                self._tester.evaluate_extensive()

        cprint(f'Finished.', self._logger)

    def save_checkpoint(self, tag: str = ''):
        filename = f'checkpoint_{tag}' if tag != '' else 'checkpoint'
        filename += '.ckpt'
        checkpoint = {
            'epoch': self._epoch,
        }
        for element, key in zip([self._net, self._trainer, self._environment],
                                ['net_ckpt', 'trainer_ckpt', 'environment_ckpt']):
            if element is not None:
                checkpoint[key] = element.get_checkpoint()
        os.makedirs(f'{self._config.output_path}/torch_checkpoints', exist_ok=True)
        torch.save(checkpoint, f'{self._config.output_path}/torch_checkpoints/{filename}')
        torch.save(checkpoint, f'{self._config.output_path}/torch_checkpoints/checkpoint_latest.ckpt')
        cprint(f'stored {filename}', self._logger)
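    # The checkpoint written above is a dict with an 'epoch' entry plus optional
    # 'net_ckpt', 'trainer_ckpt' and 'environment_ckpt' entries; it is stored both
    # under its tag-specific filename and as checkpoint_latest.ckpt.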

    def get_checkpoint_file(self, checkpoint_dir: str) -> str:
        """
        Search in torch_checkpoints directory for
        'best' and otherwise 'latest' and otherwise checkpoint with highest tag.
        Return absolute path.
        """
        if not checkpoint_dir.endswith('torch_checkpoints') and not checkpoint_dir.endswith('.ckpt'):
            checkpoint_dir += '/torch_checkpoints'

        if len(glob(f'{checkpoint_dir}/*.ckpt')) == 0 and len(glob(f'{checkpoint_dir}/torch_checkpoints/*.ckpt')) == 0:
            cprint(f'Could not find suitable checkpoint in {checkpoint_dir}', self._logger, MessageType.error)
            time.sleep(0.1)
            raise FileNotFoundError
        # Get checkpoint in following order
        if os.path.isfile(os.path.join(checkpoint_dir, 'checkpoint_best.ckpt')):
            checkpoint_file = os.path.join(checkpoint_dir, 'checkpoint_best.ckpt')
        elif os.path.isfile(os.path.join(checkpoint_dir, 'checkpoint_latest.ckpt')):
            checkpoint_file = os.path.join(checkpoint_dir, 'checkpoint_latest.ckpt')
        else:
            checkpoints = {int(f.split('.')[0].split('_')[-1]): os.path.join(checkpoint_dir, f)
                           for f in os.listdir(checkpoint_dir)}
            checkpoint_file = checkpoints[max(checkpoints.keys())]
        return checkpoint_file

    def load_checkpoint(self, checkpoint_file: str):
        # Load params for each experiment element
        checkpoint = torch.load(checkpoint_file, map_location=torch.device('cpu'))
        self._epoch = checkpoint['epoch'] if 'epoch' in checkpoint.keys() else 0
        for element, key in zip([self._net, self._trainer, self._environment],
                                ['net_ckpt', 'trainer_ckpt', 'environment_ckpt']):
            if element is not None and key in checkpoint.keys():
                element.load_checkpoint(checkpoint[key])
        cprint(f'loaded network from {checkpoint_file}', self._logger)

    def shutdown(self):
        if self._writer is not None:
            self._writer.close()
        if self._environment is not None:
            result = self._environment.remove()
            cprint(f'Terminated successfully? {bool(result)}', self._logger,
                   msg_type=MessageType.info if result else MessageType.warning)
        if self._data_saver is not None:
            self._data_saver.remove()
        if self._trainer is not None:
            self._trainer.remove()
        if self._evaluator is not None:
            self._evaluator.remove()
        if self._net is not None:
            self._net.remove()
        if self._episode_runner is not None:
            self._episode_runner.remove()
        [h.close() for h in self._logger.handlers]
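

# Illustrative sketch only, not from the original source: a typical way to drive the
# Experiment class, assuming ExperimentConfig follows the same
# .create(config_dict=...) pattern as the config classes above; the output path and
# epoch count are placeholder values, and any other required config keys are omitted.
if __name__ == '__main__':
    experiment_config = ExperimentConfig().create(config_dict={
        'output_path': '/tmp/example_experiment',
        'number_of_epochs': 1,
    })
    experiment = Experiment(config=experiment_config)
    experiment.run()
    experiment.shutdown()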