def train(self):
        '''
        meta-training loop
        '''
        self.pretrain()
        params = self.get_epoch_snapshot(-1)
        logger.save_itr_params(-1, params)
        gt.reset()
        gt.set_def_unique(False)
        self._current_path_builder = PathBuilder()

        # at each iteration, we first collect data from tasks, perform meta-updates, then try to evaluate
        for it_ in gt.timed_for(
                range(self.num_iterations),
                save_itrs=True,
        ):
            self._start_epoch(it_)
            self.training_mode(True)
            if it_ == 0:
                print('collecting initial pool of data for train and eval')
                # temp for evaluating
                for idx in self.train_tasks:
                    self.task_idx = idx
                    self.env.reset_task(idx)
                    self.collect_data(self.num_initial_steps, 1, np.inf)
            # Sample data from train tasks.
            for i in range(self.num_tasks_sample):
                idx = np.random.randint(len(self.train_tasks))
                self.task_idx = idx
                self.env.reset_task(idx)
                self.enc_replay_buffer.task_buffers[idx].clear()

                # collect some trajectories with z ~ prior
                if self.num_steps_prior > 0:
                    self.collect_data(self.num_steps_prior, 1, np.inf)
                # collect some trajectories with z ~ posterior
                if self.num_steps_posterior > 0:
                    self.collect_data(self.num_steps_posterior, 1,
                                      self.update_post_train)
                # even if the encoder is trained only on samples from the prior, the policy needs to learn to handle z ~ posterior
                if self.num_extra_rl_steps_posterior > 0:
                    self.collect_data(self.num_extra_rl_steps_posterior,
                                      1,
                                      self.update_post_train,
                                      add_to_enc_buffer=False)

            # Sample train tasks and compute gradient updates on parameters.
            for train_step in range(self.num_train_steps_per_itr):
                indices = np.random.choice(self.train_tasks, self.meta_batch)
                self._do_training(indices)
                self._n_train_steps_total += 1
            gt.stamp('train')

            self.training_mode(False)

            # eval
            self._try_to_eval(it_)
            gt.stamp('eval')

            self._end_epoch()
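For context, `get_epoch_snapshot` in rlkit-style algorithms typically just bundles the objects worth persisting into a plain dict, which `logger.save_itr_params` then serializes per iteration. A minimal sketch, assuming hypothetical attribute names (`agent` and `exploration_policy` are illustrative, not taken from the example above):

    # Hypothetical sketch of an epoch snapshot; attribute names are assumptions.
    def get_epoch_snapshot(self, epoch):
        return dict(
            epoch=epoch,
            policy=self.agent,                           # assumed attribute
            exploration_policy=self.exploration_policy,  # assumed attribute
            env=self.env,
        )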
Example no. 2
 def run(self):
     if self.progress_csv_file_name != 'progress.csv':
         logger.remove_tabular_output('progress.csv',
                                      relative_to_snapshot_dir=True)
         logger.add_tabular_output(
             self.progress_csv_file_name,
             relative_to_snapshot_dir=True,
         )
     timer.return_global_times = True
     for _ in range(self.num_iters):
         self._begin_epoch()
         timer.start_timer('saving')
         logger.save_itr_params(self.epoch, self._get_snapshot())
         timer.stop_timer('saving')
         log_dict, _ = self._train()
         logger.record_dict(log_dict)
         logger.dump_tabular(with_prefix=True, with_timestamp=False)
         self._end_epoch()
     logger.save_itr_params(self.epoch, self._get_snapshot())
     if self.progress_csv_file_name != 'progress.csv':
         logger.remove_tabular_output(
             self.progress_csv_file_name,
             relative_to_snapshot_dir=True,
         )
         logger.add_tabular_output(
             'progress.csv',
             relative_to_snapshot_dir=True,
         )
Example no. 3
    def _end_epoch(self, epoch):
        print('in _end_epoch, epoch is: {}'.format(epoch))
        snapshot = self._get_snapshot()
        logger.save_itr_params(epoch, snapshot)
        # trainer_obj = self.trainer
        # ckpt_path='ckpt.pkl'
        # logger.save_ckpt(epoch, trainer_obj, ckpt_path)
        # gt.stamp('saving')
        if epoch % 1 == 0:
            self.save_snapshot_2(epoch)
        expl_paths = self.expl_data_collector.get_epoch_paths()
        d = eval_util.get_generic_path_information(expl_paths)
        # print(d.keys())
        metric_val = d['Rewards Mean']

        cur_best_metric_val = self.get_cur_best_metric_val()
        if epoch != 0:
            self.save_snapshot_2_best_only(
                metric_val=metric_val,
                cur_best_metric_val=cur_best_metric_val,
                min_or_max='max')
        self._log_stats(epoch)

        self.expl_data_collector.end_epoch(epoch)
        self.eval_data_collector.end_epoch(epoch)
        self.replay_buffer.end_epoch(epoch)
        self.trainer.end_epoch(epoch)

        for post_epoch_func in self.post_epoch_funcs:
            post_epoch_func(self, epoch)
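The `save_snapshot_2_best_only` call above only passes the new and current-best metric values; a plausible sketch of such a helper is shown below. The body, the setter name, and the output file name are assumptions, not the actual implementation.

    # Hypothetical "save best only" helper consistent with the call site above.
    def save_snapshot_2_best_only(self, metric_val, cur_best_metric_val,
                                  min_or_max='max'):
        improved = (metric_val > cur_best_metric_val
                    if min_or_max == 'max'
                    else metric_val < cur_best_metric_val)
        if improved:
            self.set_cur_best_metric_val(metric_val)  # assumed setter
            # file name is illustrative only
            logger.save_extra_data(self._get_snapshot(),
                                   'best_snapshot.pkl', mode='pickle')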
Example no. 4
    def _try_to_eval(self, epoch):
        if self._can_evaluate():
            # save if it's time to save
            if epoch % self.freq_saving == 0:
                logger.save_extra_data(self.get_extra_data_to_save(epoch))
                params = self.get_epoch_snapshot(epoch)
                logger.save_itr_params(epoch, params)

            self.evaluate(epoch)

            logger.record_tabular(
                "Number of train calls total",
                self._n_train_steps_total,
            )

            times_itrs = gt.get_times().stamps.itrs
            train_time = times_itrs['train'][-1]
            sample_time = times_itrs['sample'][-1]
            eval_time = times_itrs['eval'][-1] if epoch > 0 else 0
            epoch_time = train_time + sample_time + eval_time
            total_time = gt.get_times().total

            logger.record_tabular('Train Time (s)', train_time)
            logger.record_tabular('(Previous) Eval Time (s)', eval_time)
            logger.record_tabular('Sample Time (s)', sample_time)
            logger.record_tabular('Epoch Time (s)', epoch_time)
            logger.record_tabular('Total Train Time (s)', total_time)

            logger.record_tabular("Epoch", epoch)
            logger.dump_tabular(with_prefix=False, with_timestamp=False)
        else:
            logger.log("Skipping eval for now.")
Example no. 5
    def _try_to_eval(self, epoch=0):
        if epoch % self.save_extra_data_interval == 0:
            logger.save_extra_data(self.get_extra_data_to_save(epoch), epoch)
        if self._can_evaluate():
            self.evaluate(epoch)

            params = self.get_epoch_snapshot(epoch)
            logger.save_itr_params(epoch, params)
            table_keys = logger.get_table_key_set()
            if self._old_table_keys is not None:
                assert table_keys == self._old_table_keys, (
                    "Table keys cannot change from iteration to iteration.")
            self._old_table_keys = table_keys

            logger.record_tabular(
                "Number of train steps total",
                self._n_train_steps_total,
            )
            logger.record_tabular(
                "Number of env steps total",
                self._n_env_steps_total,
            )
            logger.record_tabular(
                "Number of rollouts total",
                self._n_rollouts_total,
            )

            logger.record_tabular("Epoch", epoch)
            logger.dump_tabular(with_prefix=False, with_timestamp=False)
        else:
            logger.log("Skipping eval for now.")
Example no. 6
    def test_epoch(
            self,
            epoch,
            save_reconstruction=True,
            save_vae=True,
            from_rl=False,
    ):
        self.model.eval()
        losses = []
        log_probs = []
        kles = []
        zs = []
        beta = float(self.beta_schedule.get_value(epoch))
        for batch_idx in range(10):
            next_obs = self.get_batch(train=False)
            reconstructions, obs_distribution_params, latent_distribution_params = self.model(next_obs)
            log_prob = self.model.logprob(next_obs, obs_distribution_params)
            kle = self.model.kl_divergence(latent_distribution_params)
            loss = -1 * log_prob + beta * kle

            encoder_mean = latent_distribution_params[0]
            z_data = ptu.get_numpy(encoder_mean.cpu())
            for i in range(len(z_data)):
                zs.append(z_data[i, :])
            losses.append(loss.item())
            log_probs.append(log_prob.item())
            kles.append(kle.item())

            if batch_idx == 0 and save_reconstruction:
                n = min(next_obs.size(0), 8)
                comparison = torch.cat([
                    next_obs[:n].narrow(start=0, length=self.imlength, dim=1)
                        .contiguous().view(
                        -1, self.input_channels, self.imsize, self.imsize
                    ).transpose(2, 3),
                    reconstructions.view(
                        self.batch_size,
                        self.input_channels,
                        self.imsize,
                        self.imsize,
                    )[:n].transpose(2, 3)
                ])
                save_dir = osp.join(logger.get_snapshot_dir(),
                                    'r%d.png' % epoch)
                save_image(comparison.data.cpu(), save_dir, nrow=n)

        zs = np.array(zs)

        self.eval_statistics['epoch'] = epoch
        self.eval_statistics['test/log prob'] = np.mean(log_probs)
        self.eval_statistics['test/KL'] = np.mean(kles)
        self.eval_statistics['test/loss'] = np.mean(losses)
        self.eval_statistics['beta'] = beta
        if not from_rl:
            for k, v in self.eval_statistics.items():
                logger.record_tabular(k, v)
            logger.dump_tabular()
            if save_vae:
                logger.save_itr_params(epoch, self.model)
Example no. 7
def train_vae(variant, return_data=False):
    from rlkit.misc.ml_util import PiecewiseLinearSchedule
    from rlkit.torch.vae.vae_trainer import ConvVAETrainer
    from rlkit.core import logger
    beta = variant["beta"]
    use_linear_dynamics = variant.get('use_linear_dynamics', False)
    generate_vae_dataset_fctn = variant.get('generate_vae_data_fctn',
                                            generate_vae_dataset)
    variant['generate_vae_dataset_kwargs'][
        'use_linear_dynamics'] = use_linear_dynamics
    train_dataset, test_dataset, info = generate_vae_dataset_fctn(
        variant['generate_vae_dataset_kwargs'])
    if use_linear_dynamics:
        action_dim = train_dataset.data['actions'].shape[2]
    else:
        action_dim = 0
    model = get_vae(variant, action_dim)

    logger.save_extra_data(info)
    logger.get_snapshot_dir()
    if 'beta_schedule_kwargs' in variant:
        beta_schedule = PiecewiseLinearSchedule(
            **variant['beta_schedule_kwargs'])
    else:
        beta_schedule = None

    vae_trainer_class = variant.get('vae_trainer_class', ConvVAETrainer)
    trainer = vae_trainer_class(model,
                                beta=beta,
                                beta_schedule=beta_schedule,
                                **variant['algo_kwargs'])
    save_period = variant['save_period']

    dump_skew_debug_plots = variant.get('dump_skew_debug_plots', False)
    for epoch in range(variant['num_epochs']):
        should_save_imgs = (epoch % save_period == 0)
        trainer.train_epoch(epoch, train_dataset)
        trainer.test_epoch(epoch, test_dataset)

        if should_save_imgs:
            trainer.dump_reconstructions(epoch)
            trainer.dump_samples(epoch)
            if dump_skew_debug_plots:
                trainer.dump_best_reconstruction(epoch)
                trainer.dump_worst_reconstruction(epoch)
                trainer.dump_sampling_histogram(epoch)

        stats = trainer.get_diagnostics()
        for k, v in stats.items():
            logger.record_tabular(k, v)
        logger.dump_tabular()
        trainer.end_epoch(epoch)

        if epoch % 50 == 0:
            logger.save_itr_params(epoch, model)
    logger.save_extra_data(model, 'vae.pkl', mode='pickle')
    if return_data:
        return model, train_dataset, test_dataset
    return model
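A minimal call to `train_vae` might look like the sketch below. Only the keys read directly in the function are shown, the values are placeholders, and `get_vae` may require additional keys not listed here.

# Illustrative only: values are placeholders, not recommended settings.
variant = dict(
    beta=20,
    num_epochs=100,
    save_period=25,
    generate_vae_dataset_kwargs=dict(N=1000),  # assumed kwargs
    algo_kwargs=dict(lr=1e-3),                 # assumed kwargs
)
model = train_vae(variant)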
Example no. 8
    def _try_to_eval(self, epoch):
        logger.save_extra_data(self.get_extra_data_to_save(epoch))
        if self._can_evaluate():
            if self.environment_farming:
                # Create a new eval_sampler for each evaluation to avoid the released-environment problem
                env_for_eval_sampler = self.farmer.force_acq_env()
                print(env_for_eval_sampler)
                self.eval_sampler = InPlacePathSampler(
                    env=env_for_eval_sampler,
                    policy=self.eval_policy,
                    max_samples=self.num_steps_per_eval + self.max_path_length,
                    max_path_length=self.max_path_length,
                )

            self.evaluate(epoch)

            # Return the env to the farmer's free_env list
            if self.environment_farming:
                self.farmer.add_free_env(env_for_eval_sampler)

            params = self.get_epoch_snapshot(epoch)
            logger.save_itr_params(epoch, params)
            table_keys = logger.get_table_key_set()
            if self._old_table_keys is not None:
                assert table_keys == self._old_table_keys, (
                    "Table keys cannot change from iteration to iteration.")
            self._old_table_keys = table_keys

            logger.record_tabular(
                "Number of train steps total",
                self._n_train_steps_total,
            )
            logger.record_tabular(
                "Number of env steps total",
                self._n_env_steps_total,
            )
            logger.record_tabular(
                "Number of rollouts total",
                self._n_rollouts_total,
            )

            times_itrs = gt.get_times().stamps.itrs
            train_time = times_itrs['train'][-1]
            sample_time = times_itrs['sample'][-1]
            eval_time = times_itrs['eval'][-1] if epoch > 0 else 0
            epoch_time = train_time + sample_time + eval_time
            total_time = gt.get_times().total

            logger.record_tabular('Train Time (s)', train_time)
            logger.record_tabular('(Previous) Eval Time (s)', eval_time)
            logger.record_tabular('Sample Time (s)', sample_time)
            logger.record_tabular('Epoch Time (s)', epoch_time)
            logger.record_tabular('Total Train Time (s)', total_time)

            logger.record_tabular("Epoch", epoch)
            logger.dump_tabular(with_prefix=False, with_timestamp=False)
        else:
            logger.log("Skipping eval for now.")
Example no. 9
 def train(self, start_epoch=0):
     self.pretrain()
     if start_epoch == 0:
         params = self.get_epoch_snapshot(-1)
         logger.save_itr_params(-1, params)
     self.training_mode(False)
     self._n_env_steps_total = start_epoch * self.num_env_steps_per_epoch
     gt.reset()
     gt.set_def_unique(False)
     self.train_online(start_epoch=start_epoch)
Example no. 10
    def _end_epoch(self, epoch):
        snapshot = self._get_snapshot()
        logger.save_itr_params(epoch, snapshot)
        gt.stamp('saving')
        self._log_stats(epoch)

        self.expl_data_collector.end_epoch(epoch)
        self.eval_data_collector.end_epoch(epoch)
        self.replay_buffer.end_epoch(epoch)
        self.trainer.end_epoch(epoch)
Example no. 11
    def _end_epoch(self, epoch):
        snapshot = self._get_snapshot()
        logger.save_itr_params(epoch, snapshot)
        gt.stamp('saving')
        self._log_stats(epoch)

        self.eval_data_collector.end_epoch(epoch)
        self.trainer.end_epoch(epoch)

        for post_epoch_func in self.post_epoch_funcs:
            post_epoch_func(self, epoch)
Example no. 12
    def _try_to_eval(self, epoch):
        if epoch % self.freq_saving == 0:
            logger.save_extra_data(self.get_extra_data_to_save(epoch))
        if self._can_evaluate():
            self.evaluate(epoch)

            if epoch % self.freq_saving == 0:
                params = self.get_epoch_snapshot(epoch)
                logger.save_itr_params(epoch, params)
            table_keys = logger.get_table_key_set()
            # if self._old_table_keys is not None:
            #     print('$$$$$$$$$$$$$$$')
            #     print(table_keys)
            #     print('\n'*4)
            #     print(self._old_table_keys)
            #     print('$$$$$$$$$$$$$$$')
            #     print(set(table_keys) - set(self._old_table_keys))
            #     print(set(self._old_table_keys) - set(table_keys))
            #     assert table_keys == self._old_table_keys, (
            #         "Table keys cannot change from iteration to iteration."
            #     )
            # self._old_table_keys = table_keys

            logger.record_tabular(
                "Number of train steps total",
                self._n_train_steps_total,
            )
            logger.record_tabular(
                "Number of env steps total",
                self._n_env_steps_total,
            )
            logger.record_tabular(
                "Number of rollouts total",
                self._n_rollouts_total,
            )

            times_itrs = gt.get_times().stamps.itrs
            train_time = times_itrs['train'][-1]
            sample_time = times_itrs['sample'][-1]
            eval_time = times_itrs['eval'][-1] if epoch > 0 else 0
            epoch_time = train_time + sample_time + eval_time
            total_time = gt.get_times().total

            logger.record_tabular('Train Time (s)', train_time)
            logger.record_tabular('(Previous) Eval Time (s)', eval_time)
            logger.record_tabular('Sample Time (s)', sample_time)
            logger.record_tabular('Epoch Time (s)', epoch_time)
            logger.record_tabular('Total Train Time (s)', total_time)

            logger.record_tabular("Epoch", epoch)
            logger.dump_tabular(with_prefix=False, with_timestamp=False)
        else:
            logger.log("Skipping eval for now.")
Example no. 13
 def train(self):
     timer.return_global_times = True
     for _ in range(self.num_epochs):
         self._begin_epoch()
         timer.start_timer('saving')
         logger.save_itr_params(self.epoch, self._get_snapshot())
         timer.stop_timer('saving')
         log_dict, _ = self._train()
         logger.record_dict(log_dict)
         logger.dump_tabular(with_prefix=True, with_timestamp=False)
         self._end_epoch()
     logger.save_itr_params(self.epoch, self._get_snapshot())
Example no. 14
    def train(self):
        '''
        meta-training loop
        '''
        self.pretrain()
        gt.reset()
        gt.set_def_unique(False)
        self._current_path_builder = PathBuilder()

        # at each iteration, we first collect data from tasks, perform meta-updates, then try to evaluate
        for it_ in gt.timed_for(
                range(self.num_iterations),
                save_itrs=True,
        ):
            self._start_epoch(it_)
            self.training_mode(True)

            # Sample train tasks and compute gradient updates on parameters.
            batch_idxes = np.random.randint(0,
                                            len(self.train_goals),
                                            size=self.meta_batch_size)
            train_batch_obj_id = self.replay_buffers.sample_training_data(
                batch_idxes, self.use_same_context)
            for _ in range(self.num_train_steps_per_itr):
                train_raw_batch = ray.get(train_batch_obj_id)
                gt.stamp('sample_training_data', unique=False)

                batch_idxes = np.random.randint(0,
                                                len(self.train_goals),
                                                size=self.meta_batch_size)
                # In this way, we can start the data sampling job for the
                # next training while doing training for the current loop.
                train_batch_obj_id = self.replay_buffers.sample_training_data(
                    batch_idxes, self.use_same_context)
                gt.stamp('set_up_sampling', unique=False)

                train_data = self.construct_training_batch(train_raw_batch)
                gt.stamp('construct_training_batch', unique=False)

                self._do_training(train_data)
                self._n_train_steps_total += 1
            gt.stamp('train')

            self.training_mode(False)

            # eval
            self._try_to_eval(it_)
            gt.stamp('eval')

            self._end_epoch()
            # range() stops at num_iterations - 1, so compare against that to
            # save the final snapshot on the last iteration
            if it_ == self.num_iterations - 1:
                logger.save_itr_params(it_, self.agent.get_snapshot())
Example no. 15
    def _end_epoch(self, epoch, num_epochs_per_eval=0):
        snapshot = self._get_snapshot()
        logger.save_itr_params(epoch, snapshot)
        gt.stamp('saving')
        self._log_stats(epoch, num_epochs_per_eval)

        self.expl_data_collector.end_epoch(epoch)
        # self.eval_data_collector.end_epoch(epoch)
        self.replay_buffer.end_epoch(epoch)
        self.trainer.end_epoch(epoch)

        for post_epoch_func in self.post_epoch_funcs:
            post_epoch_func(self, epoch)
Example no. 16
    def _end_epoch(self, epoch, solved=False):
        snapshot = self._get_snapshot()
        logger.save_itr_params(epoch, snapshot)
        gt.stamp('saving')
        self._log_stats(epoch, solved=solved)

        self.eval_data_collector.end_epoch(epoch)
        if not solved:
            self.expl_data_collector.end_epoch(epoch)
            self.replay_buffer.end_epoch(epoch)
            self.trainer.end_epoch(epoch)

        for post_epoch_func in self.post_epoch_funcs:
            post_epoch_func(self, epoch)
Example no. 17
    def _end_epoch(self, epoch):
        #print ("core/rl_algorithm, _end_epoch(): ", "epoch: ", epoch)
        snapshot = self._get_snapshot()
        #print ("core/rl_algorithm, _end_epoch(): ", "snapshot: ", snapshot)
        logger.save_itr_params(epoch, snapshot)
        gt.stamp('saving')
        self._log_stats(epoch)

        self.expl_data_collector.end_epoch(epoch)
        self.eval_data_collector.end_epoch(epoch)
        self.replay_buffer.end_epoch(epoch)
        self.trainer.end_epoch(epoch)

        for post_epoch_func in self.post_epoch_funcs:
            post_epoch_func(self, epoch)
Example no. 18
    def _end_epoch(self, epoch):
        snapshot = self._get_snapshot()
        # only save params for the first gpu
        if ptu.dist_rank == 0:
            logger.save_itr_params(epoch, snapshot)
        gt.stamp("saving")
        self._log_stats(epoch)

        self.expl_data_collector.end_epoch(epoch)
        self.eval_data_collector.end_epoch(epoch)
        self.replay_buffer.end_epoch(epoch)
        self.trainer.end_epoch(epoch)

        for post_epoch_func in self.post_epoch_funcs:
            post_epoch_func(self, epoch)
Example no. 19
 def train(self):
     self.fix_data_set()
     logger.log("Done creating dataset.")
     num_batches_total = 0
     for epoch in range(self.num_epochs):
         for _ in range(self.num_batches_per_epoch):
             self.qf.train(True)
             self._do_training()
             num_batches_total += 1
         logger.push_prefix('Iteration #%d | ' % epoch)
         self.qf.train(False)
         self.evaluate(epoch)
         params = self.get_epoch_snapshot(epoch)
         logger.save_itr_params(epoch, params)
         logger.log("Done evaluating")
         logger.pop_prefix()
Example no. 20
    def _end_epoch(self, epoch):
        snapshot = self._get_snapshot()
        logger.save_itr_params(epoch, snapshot)
        gt.stamp('saving')
        self._log_stats(epoch)

        if self.collect_actions and epoch % self.collect_actions_every == 0:
            self._log_actions(epoch)

        self.expl_data_collector.end_epoch(epoch)
        self.eval_data_collector.end_epoch(epoch)
        self.replay_buffer.end_epoch(epoch)
        self.trainer.end_epoch(epoch)

        for post_epoch_func in self.post_epoch_funcs:
            post_epoch_func(self, epoch)
Example no. 21
 def train(self, start_epoch=0):
     self.pretrain()
     if start_epoch == 0:
         params = self.get_epoch_snapshot(-1)
         logger.save_itr_params(-1, params)
     self.training_mode(False)
     self._n_env_steps_total = start_epoch * self.num_env_steps_per_epoch
     gt.reset()
     gt.set_def_unique(False)
     if self.collection_mode == 'online':
         self.train_online(start_epoch=start_epoch)
     elif self.collection_mode == 'batch':
         self.train_batch(start_epoch=start_epoch)
     else:
         raise TypeError("Invalid collection_mode: {}".format(
             self.collection_mode))
Example no. 22
    def _try_to_eval(self,
                     epoch,
                     eval_all=False,
                     eval_train_offline=True,
                     animated=False):
        logger.save_extra_data(self.get_extra_data_to_save(epoch))
        if self._can_evaluate():
            self.evaluate(epoch, eval_all, eval_train_offline, animated)

            params = self.get_epoch_snapshot(epoch)
            logger.save_itr_params(epoch, params)
            table_keys = logger.get_table_key_set()
            if self._old_table_keys is not None:
                assert table_keys == self._old_table_keys, (
                    "Table keys cannot change from iteration to iteration.")
            self._old_table_keys = table_keys

            logger.record_tabular(
                "Number of train steps total",
                self._n_train_steps_total,
            )
            logger.record_tabular(
                "Number of env steps total",
                self._n_env_steps_total,
            )
            logger.record_tabular(
                "Number of rollouts total",
                self._n_rollouts_total,
            )

            times_itrs = gt.get_times().stamps.itrs
            train_time = times_itrs.get('train', [0])[-1]
            sample_time = times_itrs.get('sample', [0])[-1]
            eval_time = times_itrs.get('eval', [0])[-1] if epoch > 0 else 0
            epoch_time = train_time + sample_time + eval_time
            total_time = gt.get_times().total

            logger.record_tabular('Train Time (s)', train_time)
            logger.record_tabular('(Previous) Eval Time (s)', eval_time)
            logger.record_tabular('Sample Time (s)', sample_time)
            logger.record_tabular('Epoch Time (s)', epoch_time)
            logger.record_tabular('Total Train Time (s)', total_time)

            logger.record_tabular("Epoch", epoch)
            logger.dump_tabular(with_prefix=False, with_timestamp=False)
        else:
            logger.log("Skipping eval for now.")
Example no. 23
    def _end_epoch(self, epoch):
        if not self.trainer.discrete:
            snapshot = self._get_snapshot()
            logger.save_itr_params(epoch, snapshot)
            # if snapshot['evaluation/Average Returns'] >= self.best_rewrad:
            #     self.best_rewrad = snapshot['evaluation/Average Returns']

            gt.stamp('saving')
        self._log_stats(epoch)

        self.expl_data_collector.end_epoch(epoch)
        self.eval_data_collector.end_epoch(epoch)
        self.replay_buffer.end_epoch(epoch)
        self.trainer.end_epoch(epoch)

        for post_epoch_func in self.post_epoch_funcs:
            post_epoch_func(self, epoch)
Example no. 24
def create_policy(variant):
    bottom_snapshot = joblib.load(variant['bottom_path'])
    column_snapshot = joblib.load(variant['column_path'])
    policy = variant['combiner_class'](
        policy1=bottom_snapshot['naf_policy'],
        policy2=column_snapshot['naf_policy'],
    )
    env = bottom_snapshot['env']
    logger.save_itr_params(0, dict(
        policy=policy,
        env=env,
    ))
    path = rollout(
        env,
        policy,
        max_path_length=variant['max_path_length'],
        animated=variant['render'],
    )
    env.log_diagnostics([path])
    logger.dump_tabular()
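`create_policy` only reads a handful of keys from `variant`; an illustrative call follows (the paths and the combiner class are placeholders):

# Illustrative call; paths and combiner class are placeholders.
variant = dict(
    bottom_path='/path/to/bottom_snapshot.pkl',
    column_path='/path/to/column_snapshot.pkl',
    combiner_class=MyCombinedNafPolicy,  # hypothetical combiner class
    max_path_length=300,
    render=False,
)
create_policy(variant)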
Example no. 25
    def train(self):
        for epoch in range(self.num_epochs):
            logger.push_prefix('Iteration #%d | ' % epoch)

            start_time = time.time()
            for _ in range(self.num_steps_per_epoch):
                batch = self.get_batch()
                train_dict = self.get_train_dict(batch)

                self.policy_optimizer.zero_grad()
                policy_loss = train_dict['Policy Loss']
                policy_loss.backward()
                self.policy_optimizer.step()
            logger.log("Train time: {}".format(time.time() - start_time))

            start_time = time.time()
            self.evaluate(epoch)
            logger.log("Eval time: {}".format(time.time() - start_time))

            params = self.get_epoch_snapshot(epoch)
            logger.save_itr_params(epoch, params)
            logger.pop_prefix()
Example no. 26
def train(dataset_generator,
          n_start_samples,
          projection=project_samples_square_np,
          n_samples_to_add_per_epoch=1000,
          n_epochs=100,
          z_dim=1,
          hidden_size=32,
          save_period=10,
          append_all_data=True,
          full_variant=None,
          dynamics_noise=0,
          decoder_output_var='learned',
          num_bins=5,
          skew_config=None,
          use_perfect_samples=False,
          use_perfect_density=False,
          vae_reset_period=0,
          vae_kwargs=None,
          use_dataset_generator_first_epoch=True,
          **kwargs):
    """
    Sanitize Inputs
    """
    assert skew_config is not None
    if not (use_perfect_density and use_perfect_samples):
        assert vae_kwargs is not None
    if vae_kwargs is None:
        vae_kwargs = {}

    report = HTMLReport(
        logger.get_snapshot_dir() + '/report.html',
        images_per_row=10,
    )
    dynamics = Dynamics(projection, dynamics_noise)
    if full_variant:
        report.add_header("Variant")
        report.add_text(
            json.dumps(
                ppp.dict_to_safe_json(full_variant, sort=True),
                indent=2,
            ))

    vae, decoder, decoder_opt, encoder, encoder_opt = get_vae(
        decoder_output_var,
        hidden_size,
        z_dim,
        vae_kwargs,
    )
    vae.to(ptu.device)

    epochs = []
    losses = []
    kls = []
    log_probs = []
    hist_heatmap_imgs = []
    vae_heatmap_imgs = []
    sample_imgs = []
    entropies = []
    tvs_to_uniform = []
    entropy_gains_from_reweighting = []
    p_theta = Histogram(num_bins)
    p_new = Histogram(num_bins)

    orig_train_data = dataset_generator(n_start_samples)
    train_data = orig_train_data
    start = time.time()
    for epoch in progressbar(range(n_epochs)):
        p_theta = Histogram(num_bins)
        if epoch == 0 and use_dataset_generator_first_epoch:
            vae_samples = dataset_generator(n_samples_to_add_per_epoch)
        else:
            if use_perfect_samples and epoch != 0:
                # Ideally the VAE = p_new, but in practice, it won't be...
                vae_samples = p_new.sample(n_samples_to_add_per_epoch)
            else:
                vae_samples = vae.sample(n_samples_to_add_per_epoch)
        projected_samples = dynamics(vae_samples)
        if append_all_data:
            train_data = np.vstack((train_data, projected_samples))
        else:
            train_data = np.vstack((orig_train_data, projected_samples))

        p_theta.fit(train_data)
        if use_perfect_density:
            prob = p_theta.compute_density(train_data)
        else:
            prob = vae.compute_density(train_data)
        all_weights = prob_to_weight(prob, skew_config)
        p_new.fit(train_data, weights=all_weights)
        if epoch == 0 or (epoch + 1) % save_period == 0:
            epochs.append(epoch)
            report.add_text("Epoch {}".format(epoch))
            hist_heatmap_img = visualize_histogram(p_theta, skew_config,
                                                   report)
            vae_heatmap_img = visualize_vae(
                vae,
                skew_config,
                report,
                resolution=num_bins,
            )
            sample_img = visualize_vae_samples(
                epoch,
                train_data,
                vae,
                report,
                dynamics,
            )

            visualize_samples(
                p_theta.sample(n_samples_to_add_per_epoch),
                report,
                title="P Theta/RB Samples",
            )
            visualize_samples(
                p_new.sample(n_samples_to_add_per_epoch),
                report,
                title="P Adjusted Samples",
            )
            hist_heatmap_imgs.append(hist_heatmap_img)
            vae_heatmap_imgs.append(vae_heatmap_img)
            sample_imgs.append(sample_img)
            report.save()

            Image.fromarray(
                hist_heatmap_img).save(logger.get_snapshot_dir() +
                                       '/hist_heatmap{}.png'.format(epoch))
            Image.fromarray(
                vae_heatmap_img).save(logger.get_snapshot_dir() +
                                      '/vae_heatmap{}.png'.format(epoch))
            Image.fromarray(sample_img).save(logger.get_snapshot_dir() +
                                             '/samples{}.png'.format(epoch))
        """
        train VAE to look like p_new
        """
        if sum(all_weights) == 0:
            all_weights[:] = 1
        if vae_reset_period > 0 and epoch % vae_reset_period == 0:
            vae, decoder, decoder_opt, encoder, encoder_opt = get_vae(
                decoder_output_var,
                hidden_size,
                z_dim,
                vae_kwargs,
            )
            vae.to(ptu.device)
        vae.fit(train_data, weights=all_weights)
        epoch_stats = vae.get_epoch_stats()

        losses.append(np.mean(epoch_stats['losses']))
        kls.append(np.mean(epoch_stats['kls']))
        log_probs.append(np.mean(epoch_stats['log_probs']))
        entropies.append(p_theta.entropy())
        tvs_to_uniform.append(p_theta.tv_to_uniform())
        entropy_gain = p_new.entropy() - p_theta.entropy()
        entropy_gains_from_reweighting.append(entropy_gain)

        for k in sorted(epoch_stats.keys()):
            logger.record_tabular(k, epoch_stats[k])

        logger.record_tabular("Epoch", epoch)
        logger.record_tabular('Entropy', p_theta.entropy())
        logger.record_tabular('KL from uniform', p_theta.kl_from_uniform())
        logger.record_tabular('TV to uniform', p_theta.tv_to_uniform())
        logger.record_tabular('Entropy gain from reweight', entropy_gain)
        logger.record_tabular('Total Time (s)', time.time() - start)
        logger.dump_tabular()
        logger.save_itr_params(
            epoch, {
                'vae': vae,
                'train_data': train_data,
                'vae_samples': vae_samples,
                'dynamics': dynamics,
            })

    report.add_header("Training Curves")
    plot_curves(
        [
            ("Training Loss", losses),
            ("KL", kls),
            ("Log Probs", log_probs),
            ("Entropy Gain from Reweighting", entropy_gains_from_reweighting),
        ],
        report,
    )
    plot_curves(
        [
            ("Entropy", entropies),
            ("TV to Uniform", tvs_to_uniform),
        ],
        report,
    )
    report.add_text("Max entropy: {}".format(p_theta.max_entropy()))
    report.save()

    for filename, imgs in [
        ("hist_heatmaps", hist_heatmap_imgs),
        ("vae_heatmaps", vae_heatmap_imgs),
        ("samples", sample_imgs),
    ]:
        video = np.stack(imgs)
        vwrite(
            logger.get_snapshot_dir() + '/{}.mp4'.format(filename),
            video,
        )
        local_gif_file_path = '{}.gif'.format(filename)
        gif_file_path = '{}/{}'.format(logger.get_snapshot_dir(),
                                       local_gif_file_path)
        gif(gif_file_path, video)
        report.add_image(local_gif_file_path, txt=filename, is_url=True)
    report.save()
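`prob_to_weight` is not shown in this example. A minimal sketch, assuming Skew-Fit style power weighting where `skew_config` carries an exponent `alpha` (that key is an assumption):

import numpy as np

# Hypothetical sketch: weight samples by density raised to a negative power
# so that rare samples are upweighted; not the actual implementation.
def prob_to_weight(prob, skew_config):
    alpha = skew_config.get('alpha', -1.0)  # assumed key; alpha < 0
    prob = np.maximum(np.asarray(prob, dtype=np.float64), 1e-12)
    return prob ** alpha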
Example no. 27
    def train(self):
        '''
        meta-training loop
        '''
        self.pretrain()
        params = self.get_epoch_snapshot(-1)
        logger.save_itr_params(-1, params)
        gt.reset()
        gt.set_def_unique(False)
        self._current_path_builder = PathBuilder()
        self.train_obs = self._start_new_rollout()

        # at each iteration, we first collect data from tasks, perform meta-updates, then try to evaluate
        for it_ in gt.timed_for(
                range(self.num_iterations),
                save_itrs=True,
        ):
            self._start_epoch(it_)
            self.training_mode(True)
            if it_ == 0:
                print('collecting initial pool of data for train and eval')
                # temp for evaluating
                for idx in self.train_tasks:
                    print('train task', idx)
                    self.task_idx = idx
                    self.env.reset_task(idx)
                    self.collect_data_sampling_from_prior(
                        num_samples=self.max_path_length * 10,
                        resample_z_every_n=self.max_path_length,
                        eval_task=False)
                """
                for idx in self.eval_tasks:
                    self.task_idx = idx
                    self.env.reset_task(idx)
                    # TODO: make number of initial trajectories a parameter
                    self.collect_data_sampling_from_prior(num_samples=self.max_path_length * 20,
                                                          resample_z_every_n=self.max_path_length,
                                                          eval_task=True)
                """

            # Sample data from train tasks.
            for i in range(self.num_tasks_sample):
                idx = np.random.randint(len(self.train_tasks))
                self.task_idx = idx
                self.env.reset_task(idx)

                # TODO: there may be more permutations of sampling/adding to encoding buffer we may wish to try
                if self.train_embedding_source == 'initial_pool':
                    # embeddings are computed using only the initial pool of data
                    # sample data from posterior to train RL algorithm
                    self.collect_data_from_task_posterior(
                        idx=idx,
                        num_samples=self.num_steps_per_task,
                        add_to_enc_buffer=False)
                elif self.train_embedding_source == 'posterior_only':
                    self.collect_data_from_task_posterior(
                        idx=idx,
                        num_samples=self.num_steps_per_task,
                        eval_task=False,
                        add_to_enc_buffer=True)
                elif self.train_embedding_source == 'online_exploration_trajectories':
                    # embeddings are computed using only data collected using the prior
                    # sample data from posterior to train RL algorithm
                    self.enc_replay_buffer.task_buffers[idx].clear()
                    # resamples using current policy, conditioned on prior

                    self.collect_data_sampling_from_prior(
                        num_samples=self.num_steps_per_task,
                        resample_z_every_n=self.max_path_length,
                        add_to_enc_buffer=True)

                    self.env.reset_task(idx)

                    self.collect_data_from_task_posterior(
                        idx=idx,
                        num_samples=self.num_steps_per_task,
                        add_to_enc_buffer=False,
                        viz=True)

                elif self.train_embedding_source == 'online_on_policy_trajectories':
                    # sample from prior, then sample more from the posterior
                    # embeddings computed from both prior and posterior data
                    self.enc_replay_buffer.task_buffers[idx].clear()
                    self.collect_data_online(
                        idx=idx,
                        num_samples=self.num_steps_per_task,
                        add_to_enc_buffer=True)
                else:
                    raise Exception(
                        "Invalid option for computing train embedding {}".
                        format(self.train_embedding_source))

            # Sample train tasks and compute gradient updates on parameters.
            for train_step in range(self.num_train_steps_per_itr):
                indices = np.random.choice(self.train_tasks, self.meta_batch)
                self._do_training(indices, train_step)
                self._n_train_steps_total += 1
            gt.stamp('train')

            #self.training_mode(False)

            # eval
            self._try_to_eval(it_)
            gt.stamp('eval')

            self._end_epoch()
Example no. 28
def train_pixelcnn(
    vqvae=None,
    vqvae_path=None,
    num_epochs=100,
    batch_size=32,
    n_layers=15,
    dataset_path=None,
    save=True,
    save_period=10,
    cached_dataset_path=False,
    trainer_kwargs=None,
    model_kwargs=None,
    data_filter_fn=lambda x: x,
    debug=False,
    data_size=float('inf'),
    num_train_batches_per_epoch=None,
    num_test_batches_per_epoch=None,
    train_img_loader=None,
    test_img_loader=None,
):
    trainer_kwargs = {} if trainer_kwargs is None else trainer_kwargs
    model_kwargs = {} if model_kwargs is None else model_kwargs

    # Load VQVAE + Define Args
    if vqvae is None:
        vqvae = load_local_or_remote_file(vqvae_path)
        vqvae.to(ptu.device)
        vqvae.eval()

    root_len = vqvae.root_len
    num_embeddings = vqvae.num_embeddings
    embedding_dim = vqvae.embedding_dim
    cond_size = vqvae.num_embeddings
    imsize = vqvae.imsize
    discrete_size = root_len * root_len
    representation_size = embedding_dim * discrete_size
    input_channels = vqvae.input_channels
    imlength = imsize * imsize * input_channels

    log_dir = logger.get_snapshot_dir()

    # Define data loading info
    new_path = osp.join(log_dir, 'pixelcnn_data.npy')

    def prep_sample_data(cached_path):
        data = load_local_or_remote_file(cached_path).item()
        train_data = data['train']
        test_data = data['test']
        return train_data, test_data

    def encode_dataset(path, object_list):
        data = load_local_or_remote_file(path)
        data = data.item()
        data = data_filter_fn(data)

        all_data = []
        n = min(data["observations"].shape[0], data_size)

        for i in tqdm(range(n)):
            obs = ptu.from_numpy(data["observations"][i] / 255.0)
            latent = vqvae.encode(obs, cont=False)
            all_data.append(latent)

        encodings = ptu.get_numpy(torch.stack(all_data, dim=0))
        return encodings

    if train_img_loader:
        _, test_loader, test_batch_loader = create_conditional_data_loader(
            test_img_loader, 80, vqvae, "test2")  # 80
        _, train_loader, train_batch_loader = create_conditional_data_loader(
            train_img_loader, 2000, vqvae, "train2")  # 2000
    else:
        if cached_dataset_path:
            train_data, test_data = prep_sample_data(cached_dataset_path)
        else:
            # the object_list argument is unused here
            train_data = encode_dataset(dataset_path['train'], None)
            test_data = encode_dataset(dataset_path['test'], None)
        dataset = {'train': train_data, 'test': test_data}
        np.save(new_path, dataset)

        _, _, train_loader, test_loader, _ = \
            rlkit.torch.vae.pixelcnn_utils.load_data_and_data_loaders(new_path, 'COND_LATENT_BLOCK', batch_size)

    #train_dataset = InfiniteBatchLoader(train_loader)
    #test_dataset = InfiniteBatchLoader(test_loader)

    print("Finished loading data")

    model = GatedPixelCNN(num_embeddings,
                          root_len**2,
                          n_classes=representation_size,
                          **model_kwargs).to(ptu.device)
    trainer = PixelCNNTrainer(
        model,
        vqvae,
        batch_size=batch_size,
        **trainer_kwargs,
    )

    print("Starting training")

    BEST_LOSS = float('inf')  # best test loss seen so far
    for epoch in range(num_epochs):
        should_save = (epoch % save_period == 0) and (epoch > 0)
        trainer.train_epoch(epoch, train_loader, num_train_batches_per_epoch)
        trainer.test_epoch(epoch, test_loader, num_test_batches_per_epoch)

        if train_img_loader:
            # sample a batch from each conditional batch loader (these loaders
            # only exist when image loaders were provided)
            test_data = test_batch_loader.random_batch(batch_size)["x"]
            train_data = train_batch_loader.random_batch(batch_size)["x"]
            trainer.dump_samples(epoch, test_data, test=True)
            trainer.dump_samples(epoch, train_data, test=False)

        if should_save:
            logger.save_itr_params(epoch, model)

        stats = trainer.get_diagnostics()

        cur_loss = stats["test/loss"]
        if cur_loss < BEST_LOSS:
            BEST_LOSS = cur_loss
            vqvae.set_pixel_cnn(model)
            logger.save_extra_data(vqvae, 'best_vqvae', mode='torch')
        else:
            return vqvae

        for k, v in stats.items():
            logger.record_tabular(k, v)
        logger.dump_tabular()
        trainer.end_epoch(epoch)

    return vqvae
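An illustrative call, assuming a pre-trained VQ-VAE checkpoint and raw observation datasets on disk (all paths and settings are placeholders):

vqvae = train_pixelcnn(
    vqvae_path='/path/to/vqvae.pkl',
    dataset_path=dict(train='/path/to/train_data.npy',
                      test='/path/to/test_data.npy'),
    num_epochs=100,
    batch_size=32,
    save_period=10,
)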
Example no. 29
    def train(self):
        '''
        meta-training loop
        '''
        self.pretrain()
        params = self.get_epoch_snapshot(-1)
        logger.save_itr_params(-1, params)
        gt.reset()
        gt.set_def_unique(False)
        self._current_path_builder = PathBuilder()

        # at each iteration, we first collect data from tasks, perform meta-updates, then try to evaluate
        for it_ in gt.timed_for(
                range(self.num_iterations),
                save_itrs=True,
        ):
            self._start_epoch(it_)
            self.training_mode(True)
            print("\nIteration:{}".format(it_+1))
            if it_ == 0:  # step 1 of the algorithm: initialize each task's buffer
                print('\nCollecting initial pool of data for train and eval')
                # temp for evaluating
                for idx in self.train_tasks:  # before training starts, collect 2000 transitions for each task
                    self.task_idx = idx  # switch the current task idx
                    self.env.reset_task(idx)  # reset the task
                    self.collect_data(self.num_initial_steps, 1, np.inf)  # collect num_initial_steps of context c and update self.z via q(z|c)
                    # print("task id:", self.task_idx, " env:", self.replay_buffer.env)
                    # print("buffer ", self.task_idx, ":", self.replay_buffer.task_buffers[self.task_idx].__dict__.items())
            # Sample data from train tasks.
            print("\nFinished collecting initial pool of data")
            print("\nSampling data from train tasks for Meta-training")
            # randomly pick num_tasks_sample tasks from train_tasks and collect
            # num_steps_prior + num_extra_rl_steps_posterior transitions for each task's buffer
            for i in range(self.num_tasks_sample):
                print("\nSample data, round {}".format(i + 1))  # collect num_steps_prior transitions for each task's enc_buffer
                idx = np.random.randint(len(self.train_tasks))  # pick a random task from train_tasks
                self.task_idx = idx
                self.env.reset_task(idx)  # reset the task
                self.enc_replay_buffer.task_buffers[idx].clear()  # clear the corresponding enc_buffer

                # collect some trajectories with z ~ prior
                if self.num_steps_prior > 0:
                    print("\ncollect some trajectories with z ~ prior")
                    self.collect_data(self.num_steps_prior, 1, np.inf)  # collect num_steps_prior transitions using the prior over z
                # collect some trajectories with z ~ posterior
                if self.num_steps_posterior > 0:
                    print("\ncollect some trajectories with z ~ posterior")
                    self.collect_data(self.num_steps_posterior, 1, self.update_post_train)  # collect trajectories using the posterior over z
                # even if the encoder is trained only on samples from the prior, the policy needs to learn to handle z ~ posterior
                if self.num_extra_rl_steps_posterior > 0:
                    print("\ncollect some trajectories for policy update only")
                    # collect num_extra_rl_steps_posterior trajectories using the posterior over z, used only for the policy
                    self.collect_data(self.num_extra_rl_steps_posterior, 1, self.update_post_train, add_to_enc_buffer=False)
            print("\nFinished sampling data from train tasks")
            # Sample train tasks and compute gradient updates on parameters.
            print("\nStarting Meta-training, Episode {}".format(it_))
            for train_step in range(self.num_train_steps_per_itr):  # num_train_steps_per_itr gradient steps per iteration (500 x 2000 = 1000000)
                indices = np.random.choice(self.train_tasks, self.meta_batch)  # pick meta_batch tasks from train_tasks; sample RL batch b ~ B
                if (train_step + 1) % 500 == 0:
                    print("\nTraining step {}".format(train_step + 1))
                    print("Indices: {}".format(indices))
                    print("alpha: {}".format(self.alpha))
                self._do_training(indices)  # gradient descent
                self._n_train_steps_total += 1
            gt.stamp('train')

            self.training_mode(False)

            # eval
            self._try_to_eval(it_)
            gt.stamp('eval')

            self._end_epoch()
Example no. 30
    def test_epoch(
        self,
        epoch,
        sample_batch=None,
        key=None,
        save_reconstruction=True,
        save_vae=True,
        from_rl=False,
        save_prefix='r',
        only_train_vae=False,
    ):
        self.model.eval()
        losses = []
        log_probs = []
        triplet_losses = []
        matching_losses = []
        vae_matching_losses = []
        kles = []
        lstm_kles = []
        ae_losses = []
        contrastive_losses = []
        beta = float(self.beta_schedule.get_value(epoch))
        for batch_idx in range(10):
            # print(batch_idx)
            if sample_batch is not None:
                data = sample_batch(self.batch_size, key=key)
                next_obs = data['next_obs']
            else:
                next_obs = self.get_batch(epoch=epoch)

            reconstructions, obs_distribution_params, vae_latent_distribution_params, lstm_latent_encodings = self.model(
                next_obs)
            latent_encodings = lstm_latent_encodings
            vae_mu = vae_latent_distribution_params[0]  # this is lstm inputs
            latent_distribution_params = vae_latent_distribution_params

            triplet_loss = ptu.zeros(1)
            for tri_idx, triplet_type in enumerate(self.triplet_loss_type):
                if triplet_type == 1 and not only_train_vae:
                    triplet_loss += self.triplet_loss_coef[
                        tri_idx] * self.triplet_loss(latent_encodings)
                elif triplet_type == 2 and not only_train_vae:
                    triplet_loss += self.triplet_loss_coef[
                        tri_idx] * self.triplet_loss_2(next_obs)
                elif triplet_type == 3 and not only_train_vae:
                    triplet_loss += self.triplet_loss_coef[
                        tri_idx] * self.triplet_loss_3(next_obs)

            if self.matching_loss_coef > 0 and not only_train_vae:
                matching_loss = self.matching_loss(next_obs)
            else:
                matching_loss = ptu.zeros(1)

            if self.vae_matching_loss_coef > 0:
                matching_loss_vae = self.matching_loss_vae(next_obs)
            else:
                matching_loss_vae = ptu.zeros(1)

            if self.contrastive_loss_coef > 0 and not only_train_vae:
                contrastive_loss = self.contrastive_loss(next_obs)
            else:
                contrastive_loss = ptu.zeros(1)

            log_prob = self.model.logprob(next_obs, obs_distribution_params)
            kle = self.model.kl_divergence(latent_distribution_params)
            lstm_kle = ptu.zeros(1)

            ae_loss = F.mse_loss(
                latent_encodings.view((-1, self.model.representation_size)),
                vae_mu.detach())
            ae_losses.append(ae_loss.item())

            loss = -self.recon_loss_coef * log_prob + beta * kle + \
                        self.matching_loss_coef * matching_loss + self.ae_loss_coef * ae_loss + triplet_loss + \
                            self.vae_matching_loss_coef * matching_loss_vae + self.contrastive_loss_coef * contrastive_loss

            losses.append(loss.item())
            log_probs.append(log_prob.item())
            triplet_losses.append(triplet_loss.item())
            matching_losses.append(matching_loss.item())
            vae_matching_losses.append(matching_loss_vae.item())
            kles.append(kle.item())
            lstm_kles.append(lstm_kle.item())
            contrastive_losses.append(contrastive_loss.item())

            if batch_idx == 0 and save_reconstruction:
                seq_len, batch_size, feature_size = next_obs.shape
                show_obs = next_obs[0][:8]
                reconstructions = reconstructions.view(
                    (seq_len, batch_size, feature_size))[0][:8]
                comparison = torch.cat([
                    show_obs.narrow(start=0, length=self.imlength,
                                    dim=1).contiguous().view(
                                        -1, self.input_channels, self.imsize,
                                        self.imsize).transpose(2, 3),
                    reconstructions.view(
                        -1,
                        self.input_channels,
                        self.imsize,
                        self.imsize,
                    ).transpose(2, 3)
                ])
                save_dir = osp.join(logger.get_snapshot_dir(),
                                    '{}{}.png'.format(save_prefix, epoch))
                save_image(comparison.data.cpu(), save_dir, nrow=8)

        self.eval_statistics['epoch'] = epoch
        self.eval_statistics['test/log prob'] = np.mean(log_probs)
        self.eval_statistics['test/triplet loss'] = np.mean(triplet_losses)
        self.eval_statistics['test/vae matching loss'] = np.mean(
            vae_matching_losses)
        self.eval_statistics['test/matching loss'] = np.mean(matching_losses)
        self.eval_statistics['test/KL'] = np.mean(kles)
        self.eval_statistics['test/lstm KL'] = np.mean(lstm_kles)
        self.eval_statistics['test/loss'] = np.mean(losses)
        self.eval_statistics['test/contrastive loss'] = np.mean(
            contrastive_losses)
        self.eval_statistics['beta'] = beta
        self.eval_statistics['test/ae loss'] = np.mean(ae_losses)

        if not from_rl:
            for k, v in self.eval_statistics.items():
                logger.record_tabular(k, v)
            logger.dump_tabular()
            if save_vae:
                logger.save_itr_params(epoch, self.model)

        torch.cuda.empty_cache()