Example No. 1
    def collect_paths(self, idx, epoch, eval_task=False):
        self.task_idx = idx
        dprint('Task:', idx)
        self.env.reset_task(idx)
        # if eval_task:
        #     num_evals = self.num_evals
        # else:
        num_evals = 1

        paths = []
        for _ in range(num_evals):
            paths += self.obtain_eval_paths(idx,
                                            eval_task=eval_task,
                                            deterministic=True)

        # goal = self.env._goal
        # for path in paths:
        #     path['goal'] = goal # goal

        # save the paths for visualization, only useful for point mass
        if self.dump_eval_paths:
            split = 'test' if eval_task else 'train'
            logger.save_extra_data(
                paths,
                path='eval_trajectories/{}-task{}-epoch{}'.format(
                    split, idx, epoch))
        return paths
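The dumping pattern shared by these snippets is rlkit's logger.save_extra_data, which serializes arbitrary Python objects into the run's snapshot directory; note that the examples pass the target as either path= or file_name= depending on the rlkit version. A minimal sketch of the call, assuming a configured rlkit logger and the path= keyword from Example No. 1:

from rlkit.core import logger

def dump_eval_paths(paths, split, idx, epoch):
    # Mirror Example No. 1: store evaluation trajectories under the
    # run's snapshot directory for later visualization.
    logger.save_extra_data(
        paths,
        path='eval_trajectories/{}-task{}-epoch{}'.format(split, idx, epoch))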
Example No. 2
    def _try_to_eval(self, epoch=0):
        if epoch % self.save_extra_data_interval == 0:
            logger.save_extra_data(self.get_extra_data_to_save(epoch), epoch)
        if self._can_evaluate():
            self.evaluate(epoch)

            params = self.get_epoch_snapshot(epoch)
            logger.save_itr_params(epoch, params)
            table_keys = logger.get_table_key_set()
            if self._old_table_keys is not None:
                assert table_keys == self._old_table_keys, (
                    "Table keys cannot change from iteration to iteration.")
            self._old_table_keys = table_keys

            logger.record_tabular(
                "Number of train steps total",
                self._n_train_steps_total,
            )
            logger.record_tabular(
                "Number of env steps total",
                self._n_env_steps_total,
            )
            logger.record_tabular(
                "Number of rollouts total",
                self._n_rollouts_total,
            )

            logger.record_tabular("Epoch", epoch)
            logger.dump_tabular(with_prefix=False, with_timestamp=False)
        else:
            logger.log("Skipping eval for now.")
Example No. 3
def train_vae_and_update_variant(variant):
    from rlkit.core import logger

    skewfit_variant = variant["skewfit_variant"]
    train_vae_variant = variant["train_vae_variant"]
    if skewfit_variant.get("vae_path", None) is None:
        logger.remove_tabular_output("progress.csv",
                                     relative_to_snapshot_dir=True)
        logger.add_tabular_output("vae_progress.csv",
                                  relative_to_snapshot_dir=True)
        vae, vae_train_data, vae_test_data = train_vae(train_vae_variant,
                                                       return_data=True)
        if skewfit_variant.get("save_vae_data", False):
            skewfit_variant["vae_train_data"] = vae_train_data
            skewfit_variant["vae_test_data"] = vae_test_data
        logger.save_extra_data(vae, "vae.pkl", mode="pickle")
        logger.remove_tabular_output("vae_progress.csv",
                                     relative_to_snapshot_dir=True)
        logger.add_tabular_output("progress.csv",
                                  relative_to_snapshot_dir=True)
        skewfit_variant["vae_path"] = vae  # just pass the VAE directly
    else:
        if skewfit_variant.get("save_vae_data", False):
            vae_train_data, vae_test_data, info = generate_vae_dataset(
                train_vae_variant["generate_vae_dataset_kwargs"])
            skewfit_variant["vae_train_data"] = vae_train_data
            skewfit_variant["vae_test_data"] = vae_test_data
Example No. 4
    def _try_to_eval(self, epoch):
        if self._can_evaluate():
            # save if it's time to save
            if epoch % self.freq_saving == 0:
                logger.save_extra_data(self.get_extra_data_to_save(epoch))
                params = self.get_epoch_snapshot(epoch)
                logger.save_itr_params(epoch, params)

            self.evaluate(epoch)

            logger.record_tabular(
                "Number of train calls total",
                self._n_train_steps_total,
            )

            times_itrs = gt.get_times().stamps.itrs
            train_time = times_itrs['train'][-1]
            sample_time = times_itrs['sample'][-1]
            eval_time = times_itrs['eval'][-1] if epoch > 0 else 0
            epoch_time = train_time + sample_time + eval_time
            total_time = gt.get_times().total

            logger.record_tabular('Train Time (s)', train_time)
            logger.record_tabular('(Previous) Eval Time (s)', eval_time)
            logger.record_tabular('Sample Time (s)', sample_time)
            logger.record_tabular('Epoch Time (s)', epoch_time)
            logger.record_tabular('Total Train Time (s)', total_time)

            logger.record_tabular("Epoch", epoch)
            logger.dump_tabular(with_prefix=False, with_timestamp=False)
        else:
            logger.log("Skipping eval for now.")
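Several of the _try_to_eval variants here (Examples No. 4, 10, 12 and 15) read their timing columns from gtimer. A minimal sketch of that stamping pattern, assuming the gtimer package and the same 'sample'/'train'/'eval' stamp names used above:

import gtimer as gt

gt.reset()
gt.set_def_unique(False)  # allow the same stamp name every iteration
for epoch in gt.timed_for(range(3), save_itrs=True):
    # ... collect rollouts ...
    gt.stamp('sample')
    # ... gradient updates ...
    gt.stamp('train')
    # ... evaluation ...
    gt.stamp('eval')

times_itrs = gt.get_times().stamps.itrs   # per-iteration durations keyed by stamp name
print(times_itrs['train'][-1], times_itrs['eval'][-1], gt.get_times().total)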
Example No. 5
    def collect_paths(self, idx, epoch, run):
        self.task_idx = idx
        self.env.reset_task(idx)

        self.agent.clear_z()
        paths = []
        num_transitions = 0
        num_trajs = 0
        while num_transitions < self.num_steps_per_eval:
            path, num = self.sampler.obtain_samples(
                deterministic=self.eval_deterministic,
                max_samples=self.num_steps_per_eval - num_transitions,
                max_trajs=1,
                accum_context=True)
            paths += path
            num_transitions += num
            num_trajs += 1
            if num_trajs >= self.num_exp_traj_eval:
                self.agent.infer_posterior(self.agent.context)

        if self.sparse_rewards:
            for p in paths:
                # Newer NumPy rejects generator inputs to np.stack, so build a list.
                sparse_rewards = np.stack(
                    [e['sparse_reward'] for e in p['env_infos']]).reshape(-1, 1)
                p['rewards'] = sparse_rewards

        # save the paths for visualization, only useful for point mass
        if self.dump_eval_paths:
            logger.save_extra_data(
                paths,
                path='eval_trajectories/task{}-epoch{}-run{}'.format(
                    idx, epoch, run))

        return paths
Example No. 6
def train_vae_and_update_variant(variant):
    from rlkit.core import logger
    skewfit_variant = variant['skewfit_variant']
    train_vae_variant = variant['train_vae_variant']
    if skewfit_variant.get('vae_path', None) is None:
        logger.remove_tabular_output('progress.csv',
                                     relative_to_snapshot_dir=True)
        logger.add_tabular_output('vae_progress.csv',
                                  relative_to_snapshot_dir=True)
        vae, vae_train_data, vae_test_data = train_vae(
            train_vae_variant, variant['other_variant'], return_data=True)
        if skewfit_variant.get('save_vae_data', False):
            skewfit_variant['vae_train_data'] = vae_train_data
            skewfit_variant['vae_test_data'] = vae_test_data
        logger.save_extra_data(vae, 'vae.pkl', mode='pickle')
        logger.remove_tabular_output(
            'vae_progress.csv',
            relative_to_snapshot_dir=True,
        )
        logger.add_tabular_output(
            'progress.csv',
            relative_to_snapshot_dir=True,
        )
        skewfit_variant['vae_path'] = vae  # just pass the VAE directly
    else:
        if skewfit_variant.get('save_vae_data', False):
            vae_train_data, vae_test_data, info = generate_vae_dataset(
                train_vae_variant['generate_vae_dataset_kwargs'])
            skewfit_variant['vae_train_data'] = vae_train_data
            skewfit_variant['vae_test_data'] = vae_test_data
Example No. 7
def experiment(variant):
    from rlkit.core import logger
    import rlkit.torch.pytorch_util as ptu
    beta = variant["beta"]
    representation_size = variant["representation_size"]
    train_data, test_data, info = generate_vae_dataset(
        **variant['get_data_kwargs'])
    logger.save_extra_data(info)
    logger.get_snapshot_dir()
    if 'beta_schedule_kwargs' in variant:
        beta_schedule = PiecewiseLinearSchedule(
            **variant['beta_schedule_kwargs'])
    else:
        beta_schedule = None
    m = ConvVAE(representation_size,
                input_channels=3,
                **variant['conv_vae_kwargs'])
    if ptu.gpu_enabled():
        m.to(ptu.device)
    t = ConvVAETrainer(train_data,
                       test_data,
                       m,
                       beta=beta,
                       beta_schedule=beta_schedule,
                       **variant['algo_kwargs'])
    save_period = variant['save_period']
    for epoch in range(variant['num_epochs']):
        should_save_imgs = (epoch % save_period == 0)
        t.train_epoch(epoch)
        t.test_epoch(epoch,
                     save_reconstruction=should_save_imgs,
                     save_scatterplot=should_save_imgs)
        if should_save_imgs:
            t.dump_samples(epoch)
Example No. 8
def train_vae(variant, return_data=False):
    from rlkit.misc.ml_util import PiecewiseLinearSchedule
    from rlkit.torch.vae.vae_trainer import ConvVAETrainer
    from rlkit.core import logger
    beta = variant["beta"]
    use_linear_dynamics = variant.get('use_linear_dynamics', False)
    generate_vae_dataset_fctn = variant.get('generate_vae_data_fctn',
                                            generate_vae_dataset)
    variant['generate_vae_dataset_kwargs'][
        'use_linear_dynamics'] = use_linear_dynamics
    train_dataset, test_dataset, info = generate_vae_dataset_fctn(
        variant['generate_vae_dataset_kwargs'])
    if use_linear_dynamics:
        action_dim = train_dataset.data['actions'].shape[2]
    else:
        action_dim = 0
    model = get_vae(variant, action_dim)

    logger.save_extra_data(info)
    logger.get_snapshot_dir()
    if 'beta_schedule_kwargs' in variant:
        beta_schedule = PiecewiseLinearSchedule(
            **variant['beta_schedule_kwargs'])
    else:
        beta_schedule = None

    vae_trainer_class = variant.get('vae_trainer_class', ConvVAETrainer)
    trainer = vae_trainer_class(model,
                                beta=beta,
                                beta_schedule=beta_schedule,
                                **variant['algo_kwargs'])
    save_period = variant['save_period']

    dump_skew_debug_plots = variant.get('dump_skew_debug_plots', False)
    for epoch in range(variant['num_epochs']):
        should_save_imgs = (epoch % save_period == 0)
        trainer.train_epoch(epoch, train_dataset)
        trainer.test_epoch(epoch, test_dataset)

        if should_save_imgs:
            trainer.dump_reconstructions(epoch)
            trainer.dump_samples(epoch)
            if dump_skew_debug_plots:
                trainer.dump_best_reconstruction(epoch)
                trainer.dump_worst_reconstruction(epoch)
                trainer.dump_sampling_histogram(epoch)

        stats = trainer.get_diagnostics()
        for k, v in stats.items():
            logger.record_tabular(k, v)
        logger.dump_tabular()
        trainer.end_epoch(epoch)

        if epoch % 50 == 0:
            logger.save_itr_params(epoch, model)
    logger.save_extra_data(model, 'vae.pkl', mode='pickle')
    if return_data:
        return model, train_dataset, test_dataset
    return model
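A hypothetical variant for the train_vae above (Example No. 8): the keys are the ones this snippet reads, the values are placeholders, and anything consumed by get_vae() or the trainer class beyond these is not shown in the example.

variant = dict(
    beta=20,                              # placeholder weight on the KL term
    num_epochs=200,                       # placeholder
    save_period=25,                       # dump reconstructions/samples every N epochs
    use_linear_dynamics=False,
    dump_skew_debug_plots=False,
    generate_vae_dataset_kwargs=dict(),   # filled in with dataset-specific settings
    algo_kwargs=dict(),                   # forwarded to the trainer class
    # optional: 'beta_schedule_kwargs', 'vae_trainer_class',
    # 'generate_vae_data_fctn', plus whatever get_vae() expects.
)
model, train_dataset, test_dataset = train_vae(variant, return_data=True)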
Example No. 9
    def evaluate(self, epoch):
        statistics = OrderedDict()
        statistics.update(self.eval_statistics)
        self.eval_statistics = None

        # statistics.update(eval_util.get_generic_path_information(
        #     self._exploration_paths, stat_prefix="Exploration",
        # ))

        for mode in ['meta_train', 'meta_test']:
            logger.log("Collecting samples for evaluation")
            test_paths = self.obtain_eval_samples(epoch, mode=mode)

            statistics.update(
                eval_util.get_generic_path_information(
                    test_paths,
                    stat_prefix="Test " + mode,
                ))
            # print(statistics.keys())
            if hasattr(self.env, "log_diagnostics"):
                self.env.log_diagnostics(test_paths)
            if hasattr(self.env, "log_statistics"):
                log_stats = self.env.log_statistics(test_paths)
                new_log_stats = OrderedDict(
                    (k + ' ' + mode, v) for k, v in log_stats.items())
                statistics.update(new_log_stats)

            average_returns = rlkit.core.eval_util.get_average_returns(
                test_paths)
            statistics['AverageReturn ' + mode] = average_returns

            if self.render_eval_paths:
                self.env.render_paths(test_paths)

        # meta_test_this_epoch = statistics['Percent_Solved meta_test']
        meta_test_this_epoch = statistics['AverageReturn meta_test']
        if meta_test_this_epoch >= self.best_meta_test:
            # make sure you set save_algorithm to true then call save_extra_data
            prev_save_alg = self.save_algorithm
            self.save_algorithm = True
            if self.save_best:
                if epoch > self.save_best_after_epoch:
                    temp = self.replay_buffer
                    self.replay_buffer = None
                    logger.save_extra_data(self.get_extra_data_to_save(epoch),
                                           'best_meta_test.pkl')
                    self.replay_buffer = temp
                    self.best_meta_test = meta_test_this_epoch
                    print('\n\nSAVED ALG AT EPOCH %d\n\n' % epoch)
            self.save_algorithm = prev_save_alg

        for key, value in statistics.items():
            logger.record_tabular(key, value)

        if self.plotter:
            self.plotter.draw()
Example No. 10
    def _try_to_eval(self, epoch):
        logger.save_extra_data(self.get_extra_data_to_save(epoch))
        if self._can_evaluate():
            if self.environment_farming:
                # Create a new eval_sampler for each evaluation to avoid the
                # released-environment problem.
                env_for_eval_sampler = self.farmer.force_acq_env()
                print(env_for_eval_sampler)
                self.eval_sampler = InPlacePathSampler(
                    env=env_for_eval_sampler,
                    policy=self.eval_policy,
                    max_samples=self.num_steps_per_eval + self.max_path_length,
                    max_path_length=self.max_path_length,
                )

            self.evaluate(epoch)

            if self.environment_farming:
                # Return the environment to the free_env list.
                self.farmer.add_free_env(env_for_eval_sampler)

            params = self.get_epoch_snapshot(epoch)
            logger.save_itr_params(epoch, params)
            table_keys = logger.get_table_key_set()
            if self._old_table_keys is not None:
                assert table_keys == self._old_table_keys, (
                    "Table keys cannot change from iteration to iteration.")
            self._old_table_keys = table_keys

            logger.record_tabular(
                "Number of train steps total",
                self._n_train_steps_total,
            )
            logger.record_tabular(
                "Number of env steps total",
                self._n_env_steps_total,
            )
            logger.record_tabular(
                "Number of rollouts total",
                self._n_rollouts_total,
            )

            times_itrs = gt.get_times().stamps.itrs
            train_time = times_itrs['train'][-1]
            sample_time = times_itrs['sample'][-1]
            eval_time = times_itrs['eval'][-1] if epoch > 0 else 0
            epoch_time = train_time + sample_time + eval_time
            total_time = gt.get_times().total

            logger.record_tabular('Train Time (s)', train_time)
            logger.record_tabular('(Previous) Eval Time (s)', eval_time)
            logger.record_tabular('Sample Time (s)', sample_time)
            logger.record_tabular('Epoch Time (s)', epoch_time)
            logger.record_tabular('Total Train Time (s)', total_time)

            logger.record_tabular("Epoch", epoch)
            logger.dump_tabular(with_prefix=False, with_timestamp=False)
        else:
            logger.log("Skipping eval for now.")
Example No. 11
    def _start_epoch(self, epoch):
        self._epoch_start_time = time.time()
        self._exploration_paths = []
        self._do_train_time = 0
        logger.push_prefix('Iteration #%d | ' % epoch)
        if epoch in self.save_extra_manual_beginning_epoch_list:
            logger.save_extra_data(
                self.get_extra_data_to_save(epoch),
                file_name='extra_snapshot_beginning_itr{}'.format(epoch),
                mode='cloudpickle',
            )
Example No. 12
    def _try_to_eval(self, epoch):
        if epoch % self.freq_saving == 0:
            logger.save_extra_data(self.get_extra_data_to_save(epoch))
        if self._can_evaluate():
            self.evaluate(epoch)

            if epoch % self.freq_saving == 0:
                params = self.get_epoch_snapshot(epoch)
                logger.save_itr_params(epoch, params)
            table_keys = logger.get_table_key_set()
            # if self._old_table_keys is not None:
            #     print('$$$$$$$$$$$$$$$')
            #     print(table_keys)
            #     print('\n'*4)
            #     print(self._old_table_keys)
            #     print('$$$$$$$$$$$$$$$')
            #     print(set(table_keys) - set(self._old_table_keys))
            #     print(set(self._old_table_keys) - set(table_keys))
            #     assert table_keys == self._old_table_keys, (
            #         "Table keys cannot change from iteration to iteration."
            #     )
            # self._old_table_keys = table_keys

            logger.record_tabular(
                "Number of train steps total",
                self._n_train_steps_total,
            )
            logger.record_tabular(
                "Number of env steps total",
                self._n_env_steps_total,
            )
            logger.record_tabular(
                "Number of rollouts total",
                self._n_rollouts_total,
            )

            times_itrs = gt.get_times().stamps.itrs
            train_time = times_itrs['train'][-1]
            sample_time = times_itrs['sample'][-1]
            eval_time = times_itrs['eval'][-1] if epoch > 0 else 0
            epoch_time = train_time + sample_time + eval_time
            total_time = gt.get_times().total

            logger.record_tabular('Train Time (s)', train_time)
            logger.record_tabular('(Previous) Eval Time (s)', eval_time)
            logger.record_tabular('Sample Time (s)', sample_time)
            logger.record_tabular('Epoch Time (s)', epoch_time)
            logger.record_tabular('Total Train Time (s)', total_time)

            logger.record_tabular("Epoch", epoch)
            logger.dump_tabular(with_prefix=False, with_timestamp=False)
        else:
            logger.log("Skipping eval for now.")
Example No. 13
    def collect_paths(self, idx, epoch, run):
        self.agent.clear_z()
        paths = []
        num_transitions = 0
        num_trajs = 0
        init_context = None
        infer_posterior_at_start = False
        while num_transitions < self.num_steps_per_eval:
            # We follow the PEARL protocol and never update the posterior or resample z within an episode during evaluation.
            if idx in self.fake_task_idx_to_z:
                initialized_z_reward = self.fake_task_idx_to_z[idx]
            else:
                initialized_z_reward = None
            loop_paths, num = self.sampler.obtain_samples(
                deterministic=self.eval_deterministic,
                max_samples=self.num_steps_per_eval - num_transitions,
                max_trajs=1,
                accum_context=True,
                initial_context=init_context,
                task_idx=idx,
                resample_latent_period=self.exploration_resample_latent_period,  # PEARL had this=0.
                update_posterior_period=0,  # following PEARL protocol
                infer_posterior_at_start=infer_posterior_at_start,
                initialized_z_reward=initialized_z_reward,
                use_predicted_reward=initialized_z_reward is not None,
            )
            paths += loop_paths
            num_transitions += num
            num_trajs += 1
            # accumulated contexts across rollouts
            init_context = paths[-1]['context']  # TODO clean hack
            if num_trajs >= self.num_exp_traj_eval:
                infer_posterior_at_start = True

        if self.sparse_rewards:
            for p in paths:
                sparse_rewards = np.stack(
                    [e['sparse_reward'] for e in p['env_infos']]).reshape(-1, 1)
                p['rewards'] = sparse_rewards

        goal = self.env._goal
        for path in paths:
            path['goal'] = goal  # goal

        # save the paths for visualization, only useful for point mass
        if self.dump_eval_paths and epoch >= 0:
            logger.save_extra_data(
                paths,
                file_name='eval_trajectories/task{}-epoch{}-run{}'.format(
                    idx, epoch, run))

        return paths
Example No. 14
    def evaluate(self, epoch):
        """
        Evaluate the policy, e.g. save/print progress.
        :param epoch:
        :return:
        """
        statistics = OrderedDict()
        try:
            statistics.update(self.eval_statistics)
            self.eval_statistics = None
        except TypeError:  # eval_statistics was None
            print('No Stats to Eval')

        logger.log("Collecting samples for evaluation")
        test_paths = self.eval_sampler.obtain_samples()

        statistics.update(
            eval_util.get_generic_path_information(
                test_paths,
                stat_prefix="Test",
            ))
        statistics.update(
            eval_util.get_generic_path_information(
                self._exploration_paths,
                stat_prefix="Exploration",
            ))

        if hasattr(self.env, "log_diagnostics"):
            self.env.log_diagnostics(test_paths)
        if hasattr(self.env, "log_statistics"):
            statistics.update(self.env.log_statistics(test_paths))
        if epoch % self.freq_log_visuals == 0:
            if hasattr(self.env, "log_visuals"):
                self.env.log_visuals(test_paths, epoch,
                                     logger.get_snapshot_dir())

        average_returns = eval_util.get_average_returns(test_paths)
        statistics['AverageReturn'] = average_returns
        for key, value in statistics.items():
            logger.record_tabular(key, value)

        best_statistic = statistics[self.best_key]
        if best_statistic > self.best_statistic_so_far:
            self.best_statistic_so_far = best_statistic
            if self.save_best and epoch >= self.save_best_starting_from_epoch:
                data_to_save = {'epoch': epoch, 'statistics': statistics}
                data_to_save.update(self.get_epoch_snapshot(epoch))
                logger.save_extra_data(data_to_save, 'best.pkl')
                print('\n\nSAVED BEST\n\n')
Example No. 15
    def _try_to_eval(self,
                     epoch,
                     eval_all=False,
                     eval_train_offline=True,
                     animated=False):
        logger.save_extra_data(self.get_extra_data_to_save(epoch))
        if self._can_evaluate():
            self.evaluate(epoch, eval_all, eval_train_offline, animated)

            params = self.get_epoch_snapshot(epoch)
            logger.save_itr_params(epoch, params)
            table_keys = logger.get_table_key_set()
            if self._old_table_keys is not None:
                assert table_keys == self._old_table_keys, (
                    "Table keys cannot change from iteration to iteration.")
            self._old_table_keys = table_keys

            logger.record_tabular(
                "Number of train steps total",
                self._n_train_steps_total,
            )
            logger.record_tabular(
                "Number of env steps total",
                self._n_env_steps_total,
            )
            logger.record_tabular(
                "Number of rollouts total",
                self._n_rollouts_total,
            )

            times_itrs = gt.get_times().stamps.itrs
            train_time = times_itrs.get('train', [0])[-1]
            sample_time = times_itrs.get('sample', [0])[-1]
            eval_time = times_itrs.get('eval', [0])[-1] if epoch > 0 else 0
            epoch_time = train_time + sample_time + eval_time
            total_time = gt.get_times().total

            logger.record_tabular('Train Time (s)', train_time)
            logger.record_tabular('(Previous) Eval Time (s)', eval_time)
            logger.record_tabular('Sample Time (s)', sample_time)
            logger.record_tabular('Epoch Time (s)', epoch_time)
            logger.record_tabular('Total Train Time (s)', total_time)

            logger.record_tabular("Epoch", epoch)
            logger.dump_tabular(with_prefix=False, with_timestamp=False)
        else:
            logger.log("Skipping eval for now.")
Example No. 16
    def collect_paths(self, idx, epoch, run):
        self.task_idx = idx
        self.env.reset_task(idx)

        self.agent.clear_z()
        paths = []
        all_zs = []
        num_transitions = 0
        num_trajs = 0
        while num_transitions < self.num_steps_per_eval:
            path, num = self.sampler.obtain_samples(
                deterministic=self.eval_deterministic,
                max_samples=self.num_steps_per_eval - num_transitions,
                max_trajs=1,
                accum_context=True)
            paths += path
            num_transitions += num
            num_trajs += 1
            if num_trajs >= self.num_exp_traj_eval:
                self.agent.infer_posterior(self.agent.context)
                all_zs.append({
                    'z_mean': self.agent.z_means.detach().cpu().numpy(),
                    'z_vars': self.agent.z_vars.detach().cpu().numpy(),
                    'z_sample': self.agent.z.detach().cpu().numpy(),
                })

        if self.sparse_rewards:
            for p in paths:
                sparse_rewards = np.stack(
                    [e['sparse_reward'] for e in p['env_infos']]).reshape(-1, 1)
                p['rewards'] = sparse_rewards

        #import ipdb ; ipdb.set_trace()

        if self.dump_eval_paths:
            logger.save_extra_data(
                {'paths': paths, 'zs': all_zs},
                _dir_annotation='inference/task_' + str(idx))

        return paths
Example No. 17
    def collect_paths(self, idx, epoch, run, animated=False):
        # print ('enter collect path')
        self.task_idx = idx
        self.env.reset_task(idx)

        self.agent.clear_z()
        paths = []
        num_transitions = 0
        num_trajs = 0
        while num_transitions < self.num_steps_per_eval:
            path, num = self.sampler.obtain_samples(
                deterministic=self.eval_deterministic,
                max_samples=self.num_steps_per_eval - num_transitions,
                max_trajs=1,
                accum_context=self.glob,
                animated=animated,
                glob=self.glob)
            paths += path
            num_transitions += num
            num_trajs += 1
            if num_trajs >= self.num_exp_traj_eval and self.glob:
                self.agent.infer_posterior(self.agent.context)

        if self.sparse_rewards:
            for p in paths:
                sparse_rewards = np.stack(
                    [e['sparse_reward'] for e in p['env_infos']]).reshape(-1, 1)
                p['rewards'] = sparse_rewards

        goal = self.env._goal
        for path in paths:
            path['goal'] = goal  # goal
        if animated:
            for i in range(len(paths)):
                video_writer = imageio.get_writer(os.path.join(
                    logger.get_snapshot_dir(),
                    'task{}-epoch{}-run{}.mp4'.format(idx, epoch, i)),
                                                  fps=20)
                for j in paths[i]['frames']:
                    video_writer.append_data(j)
                video_writer.close()
        # save the paths for visualization, only useful for point mass
        if self.dump_eval_paths:
            logger.save_extra_data(
                paths,
                path='eval_trajectories/task{}-epoch{}-run{}'.format(
                    idx, epoch, run))

        return paths
Example No. 18
def train_vae(variant):
    #train_path = '/home/jcoreyes/objects/rlkit/examples/monet/clevr_train.hdf5'
    #test_path = '/home/jcoreyes/objects/rlkit/examples/monet/clevr_test.hdf5'

    # train_path = '/home/jcoreyes/objects/RailResearch/DataGeneration/ColorBigTwoBallSmall.h5'
    # test_path = '/home/jcoreyes/objects/RailResearch/DataGeneration/ColorBigTwoBallSmall.h5'

    train_path = '/home/jcoreyes/objects/RailResearch/BlocksGeneration/rendered/fiveBlock10kActions.h5'
    test_path = '/home/jcoreyes/objects/RailResearch/BlocksGeneration/rendered/fiveBlock10kActions.h5'

    train_feats, train_actions = load_dataset(train_path, train=True)
    test_feats, test_actions = load_dataset(test_path, train=False)

    K = variant['vae_kwargs']['K']
    rep_size = variant['vae_kwargs']['representation_size']

    logger.get_snapshot_dir()
    variant['vae_kwargs']['architecture'] = iodine.imsize64_large_iodine_architecture
    variant['vae_kwargs']['decoder_class'] = BroadcastCNN

    refinement_net = RefinementNetwork(**iodine.imsize64_large_iodine_architecture['refine_args'],
                                       hidden_activation=nn.ELU())

    physics_net = PhysicsNetwork(K, rep_size, train_actions.shape[-1])
    m = IodineVAE(
        **variant['vae_kwargs'],
        refinement_net=refinement_net,
        dynamic=True,
        physics_net=physics_net,
    )

    m.to(ptu.device)

    t = IodineTrainer(train_feats, test_feats, m,
                      variant['train_seedsteps'], variant['test_seedsteps'],
                      train_actions=train_actions, test_actions=test_actions,
                      **variant['algo_kwargs'])
    save_period = variant['save_period']
    for epoch in range(variant['num_epochs']):
        should_save_imgs = (epoch % save_period == 0)
        t.train_epoch(epoch, batches=train_feats.shape[0]//variant['algo_kwargs']['batch_size'])
        t.test_epoch(epoch, save_vae=True, train=False, record_stats=True, batches=1,
                     save_reconstruction=should_save_imgs)
        t.test_epoch(epoch, save_vae=False, train=True, record_stats=False, batches=1,
                     save_reconstruction=should_save_imgs)
    logger.save_extra_data(m, 'vae.pkl', mode='pickle')
Example No. 19
def train_vae(variant, return_data=False):
    beta = variant["beta"]
    representation_size = variant["representation_size"]
    generate_vae_dataset_fctn = variant.get('generate_vae_data_fctn',
                                            generate_vae_dataset)
    train_data, test_data, info = generate_vae_dataset_fctn(
        variant['generate_vae_dataset_kwargs'])
    logger.save_extra_data(info)
    logger.get_snapshot_dir()
    if variant.get('decoder_activation', None) == 'sigmoid':
        decoder_activation = torch.nn.Sigmoid()
    else:
        decoder_activation = identity
    architecture = variant['vae_kwargs'].get('architecture', None)
    if not architecture and variant.get('imsize') == 84:
        architecture = conv_vae.imsize84_default_architecture
    elif not architecture and variant.get('imsize') == 48:
        architecture = conv_vae.imsize48_default_architecture
    variant['vae_kwargs']['architecture'] = architecture
    variant['vae_kwargs']['imsize'] = variant.get('imsize')

    m = ConvVAE(representation_size,
                decoder_output_activation=decoder_activation,
                **variant['vae_kwargs'])
    m.to(ptu.device)
    t = ConvVAETrainer(train_data,
                       test_data,
                       m,
                       beta=beta,
                       **variant['algo_kwargs'])
    save_period = variant['save_period']
    for epoch in range(variant['num_epochs']):
        should_save_imgs = (epoch % save_period == 0)
        t.train_epoch(epoch)
        t.test_epoch(
            epoch,
            save_reconstruction=should_save_imgs,
        )
        if should_save_imgs:
            t.dump_samples(epoch)
    logger.save_extra_data(m, 'vae.pkl', mode='pickle')
    if return_data:
        return m, train_data, test_data
    return m
Example No. 20
    def collect_paths(self, idx, epoch, run, wideeval=False):
        self.task_idx = idx
        if not wideeval:
            self.env.reset_task(idx)
        else:
            self.env_eval.reset_task(idx)
        self.agent.clear_z()
        paths = []
        num_transitions = 0
        num_trajs = 0
        test_suc = 0
        while num_transitions < self.num_steps_per_eval:
            if not wideeval:
                path, num, info = self.sampler.obtain_samples(
                    deterministic=self.eval_deterministic,
                    max_samples=self.num_steps_per_eval - num_transitions,
                    max_trajs=1,
                    accum_context=True)
            else:
                path, num, info = self.sampler_eval.obtain_samples(
                    deterministic=self.eval_deterministic,
                    max_samples=self.num_steps_per_eval - num_transitions,
                    max_trajs=1,
                    accum_context=True)
            paths += path
            num_transitions += num
            num_trajs += 1
            test_suc += info['n_success_num']
            if num_trajs >= self.num_exp_traj_eval:
                self.agent.infer_posterior(self.agent.context)
        suc_rate = test_suc / num_trajs
        if self.sparse_rewards:
            for p in paths:
                sparse_rewards = np.stack([e['sparse_reward'] for e in p['env_infos']]).reshape(-1, 1)
                p['rewards'] = sparse_rewards

        goal = self.env._goal
        for path in paths:
            path['goal'] = goal # goal

        # save the paths for visualization, only useful for point mass
        if self.dump_eval_paths:
            logger.save_extra_data(
                paths,
                path='eval_trajectories/task{}-epoch{}-run{}'.format(
                    idx, epoch, run))

        return paths, suc_rate
Example No. 21
File: monet.py  Project: mbchang/OP3
def train_vae(variant):
    #train_path = '/home/jcoreyes/objects/rlkit/examples/monet/clevr_train_10000.hdf5'
    #test_path = '/home/jcoreyes/objects/rlkit/examples/monet/clevr_test.hdf5'

    train_path = '/home/jcoreyes/objects/RailResearch/DataGeneration/ColorTwoBallSmall.h5'
    test_path = '/home/jcoreyes/objects/RailResearch/DataGeneration/ColorTwoBallSmall.h5'

    train_data = load_dataset(train_path, train=True)
    test_data = load_dataset(test_path, train=False)

    train_data = train_data.reshape((train_data.shape[0], -1))
    test_data = test_data.reshape((test_data.shape[0], -1))
    #logger.save_extra_data(info)
    logger.get_snapshot_dir()
    variant['vae_kwargs'][
        'architecture'] = monet.imsize64_monet_architecture  #monet.imsize84_monet_architecture
    variant['vae_kwargs']['decoder_output_activation'] = identity
    variant['vae_kwargs']['decoder_class'] = BroadcastCNN

    attention_net = UNet(in_channels=4,
                         n_classes=1,
                         up_mode='upsample',
                         depth=3,
                         padding=True)
    m = MonetVAE(**variant['vae_kwargs'], attention_net=attention_net)

    m.to(ptu.device)
    t = MonetTrainer(train_data, test_data, m, **variant['algo_kwargs'])
    save_period = variant['save_period']
    for epoch in range(variant['num_epochs']):
        should_save_imgs = (epoch % save_period == 0)
        t.train_epoch(epoch)
        t.test_epoch(
            epoch,
            save_reconstruction=should_save_imgs,
        )
        if should_save_imgs:
            t.dump_samples(epoch)
    logger.save_extra_data(m, 'vae.pkl', mode='pickle')
Example No. 22
def experiment(variant):
    from rlkit.core import logger
    import rlkit.torch.pytorch_util as ptu
    beta = variant["beta"]
    representation_size = variant["representation_size"]
    train_data, test_data, info = get_data(**variant['get_data_kwargs'])
    logger.save_extra_data(info)
    logger.get_snapshot_dir()
    beta_schedule = PiecewiseLinearSchedule(**variant['beta_schedule_kwargs'])
    m = ConvVAE(representation_size, input_channels=3)
    if ptu.gpu_enabled():
        m.to(ptu.device)
    t = ConvVAETrainer(train_data,
                       test_data,
                       m,
                       beta=beta,
                       beta_schedule=beta_schedule,
                       **variant['algo_kwargs'])
    for epoch in range(variant['num_epochs']):
        t.train_epoch(epoch)
        t.test_epoch(epoch)
        t.dump_samples(epoch)
Example No. 23
def train_vae(variant):
    train_path = '/home/jcoreyes/objects/rlkit/examples/monet/clevr_train.hdf5'
    test_path = '/home/jcoreyes/objects/rlkit/examples/monet/clevr_test.hdf5'

    #train_path = '/home/jcoreyes/objects/RailResearch/DataGeneration/ColorBigTwoBallSmall.h5'
    #test_path = '/home/jcoreyes/objects/RailResearch/DataGeneration/ColorBigTwoBallSmall.h5'

    train_data = load_dataset(train_path, train=True)
    test_data = load_dataset(test_path, train=False)

    train_data = train_data.reshape((train_data.shape[0], -1))[:500]
    #train_data = train_data.reshape((train_data.shape[0], -1))[0]
    #train_data = np.reshape(train_data[:2], (2, -1)).repeat(100, 0)
    test_data = test_data.reshape((test_data.shape[0], -1))[:10]
    #logger.save_extra_data(info)
    logger.get_snapshot_dir()
    variant['vae_kwargs']['architecture'] = iodine.imsize84_iodine_architecture
    variant['vae_kwargs']['decoder_class'] = BroadcastCNN

    refinement_net = RefinementNetwork(
        **iodine.imsize84_iodine_architecture['refine_args'],
        hidden_activation=nn.ELU())
    m = IodineVAE(**variant['vae_kwargs'], refinement_net=refinement_net)

    m.to(ptu.device)
    t = IodineTrainer(train_data, test_data, m, **variant['algo_kwargs'])
    save_period = variant['save_period']
    for epoch in range(variant['num_epochs']):
        should_save_imgs = (epoch % save_period == 0)
        t.train_epoch(epoch,
                      batches=train_data.shape[0] //
                      variant['algo_kwargs']['batch_size'])
        t.test_epoch(epoch,
                     save_reconstruction=should_save_imgs,
                     save_vae=False)
        if should_save_imgs:
            t.dump_samples(epoch)
    logger.save_extra_data(m, 'vae.pkl', mode='pickle')
Example No. 24
def grill_her_full_experiment(variant, mode='td3'):
    train_vae_variant = variant['train_vae_variant']
    grill_variant = variant['grill_variant']
    env_class = variant['env_class']
    env_kwargs = variant['env_kwargs']
    init_camera = variant['init_camera']
    train_vae_variant['generate_vae_dataset_kwargs']['env_class'] = env_class
    train_vae_variant['generate_vae_dataset_kwargs']['env_kwargs'] = env_kwargs
    train_vae_variant['generate_vae_dataset_kwargs'][
        'init_camera'] = init_camera
    grill_variant['env_class'] = env_class
    grill_variant['env_kwargs'] = env_kwargs
    grill_variant['init_camera'] = init_camera
    if 'vae_paths' not in grill_variant:
        logger.remove_tabular_output('progress.csv',
                                     relative_to_snapshot_dir=True)
        logger.add_tabular_output('vae_progress.csv',
                                  relative_to_snapshot_dir=True)
        vae = train_vae(train_vae_variant)
        rdim = train_vae_variant['representation_size']
        vae_file = logger.save_extra_data(vae, 'vae.pkl', mode='pickle')
        logger.remove_tabular_output(
            'vae_progress.csv',
            relative_to_snapshot_dir=True,
        )
        logger.add_tabular_output(
            'progress.csv',
            relative_to_snapshot_dir=True,
        )
        grill_variant['vae_paths'] = {
            str(rdim): vae_file,
        }
        grill_variant['rdim'] = str(rdim)
    if mode == 'td3':
        grill_her_td3_experiment(variant['grill_variant'])
    elif mode == 'twin-sac':
        grill_her_twin_sac_experiment(variant['grill_variant'])
    elif mode == 'sac':
        grill_her_sac_experiment(variant['grill_variant'])
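A hypothetical top-level variant for grill_her_full_experiment above; only the keys the snippet itself touches appear here, with placeholder values, and 'vae_paths'/'rdim' are filled in by the function after the VAE is trained.

variant = dict(
    env_class=None,      # placeholder: an environment class object
    env_kwargs=dict(),
    init_camera=None,
    train_vae_variant=dict(
        representation_size=16,              # placeholder
        generate_vae_dataset_kwargs=dict(),  # env_class/env_kwargs/init_camera get injected
        # ... plus whatever train_vae() reads ...
    ),
    grill_variant=dict(
        # RL settings; 'vae_paths' and 'rdim' are added automatically
    ),
)
grill_her_full_experiment(variant, mode='td3')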
Example No. 25
    def evaluate(self, epoch):
        if self.eval_statistics is None:
            self.eval_statistics = OrderedDict()

        ### sample trajectories from prior for debugging / visualization
        if self.dump_eval_paths:
            # 100 arbitrarily chosen for visualizations of point_robot trajectories
            # just want stochasticity of z, not the policy
            self.agent.clear_z()
            prior_paths, _ = self.sampler.obtain_samples(
                deterministic=self.eval_deterministic,
                max_samples=self.max_path_length * 20,
                accum_context=False,
                resample=1,
                testing=True)
            logger.save_extra_data(
                prior_paths,
                path='eval_trajectories/prior-epoch{}'.format(epoch))

        ### train tasks
        # eval on a subset of train tasks for speed
        indices = np.random.choice(self.train_tasks, len(self.eval_tasks))
        eval_util.dprint('evaluating on {} train tasks'.format(len(indices)))
        ### eval train tasks with posterior sampled from the training replay buffer
        train_returns = []
        for idx in indices:
            self.task_idx = idx
            self.env.reset_task(idx)
            paths = []
            for _ in range(self.num_steps_per_eval // self.max_path_length):
                context = self.sample_context(idx)
                self.agent.infer_posterior(context)
                p, _ = self.sampler.obtain_samples(
                    deterministic=self.eval_deterministic,
                    max_samples=self.max_path_length,
                    accum_context=False,
                    max_trajs=1,
                    resample=np.inf,
                    testing=True)
                paths += p

            if self.sparse_rewards:
                for p in paths:
                    sparse_rewards = np.stack(
                        [e['sparse_reward'] for e in p['env_infos']]).reshape(-1, 1)
                    p['rewards'] = sparse_rewards

            train_returns.append(eval_util.get_average_returns(paths))
        train_returns = np.mean(train_returns)
        ### eval train tasks with on-policy data to match eval of test tasks
        train_final_returns, train_online_returns = self._do_eval(
            indices, epoch)
        eval_util.dprint('train online returns')
        eval_util.dprint(train_online_returns)

        ### test tasks
        eval_util.dprint('evaluating on {} test tasks'.format(
            len(self.eval_tasks)))
        test_final_returns, test_online_returns = self._do_eval(
            self.eval_tasks, epoch)
        eval_util.dprint('test online returns')
        eval_util.dprint(test_online_returns)

        # save the final posterior
        self.agent.log_diagnostics(self.eval_statistics)

        if hasattr(self.env, "log_diagnostics"):
            self.env.log_diagnostics(paths, prefix=None)

        avg_train_return = np.mean(train_final_returns)
        avg_test_return = np.mean(test_final_returns)
        avg_train_online_return = np.mean(np.stack(train_online_returns),
                                          axis=0)
        avg_test_online_return = np.mean(np.stack(test_online_returns), axis=0)
        self.eval_statistics[
            'AverageTrainReturn_all_train_tasks'] = train_returns
        self.eval_statistics[
            'AverageReturn_all_train_tasks'] = avg_train_return
        self.eval_statistics['AverageReturn_all_test_tasks'] = avg_test_return
        logger.save_extra_data(avg_train_online_return,
                               path='online-train-epoch{}'.format(epoch))
        logger.save_extra_data(avg_test_online_return,
                               path='online-test-epoch{}'.format(epoch))

        for key, value in self.eval_statistics.items():
            logger.record_tabular(key, value)
        self.eval_statistics = None

        if self.render_eval_paths:
            self.env.render_paths(paths)

        if self.plotter:
            self.plotter.draw()
Example No. 26
def train_vae(variant):
    #train_path = '/home/jcoreyes/objects/rlkit/examples/monet/clevr_train.hdf5'
    #test_path = '/home/jcoreyes/objects/rlkit/examples/monet/clevr_test.hdf5'

    # train_path = '/home/jcoreyes/objects/RailResearch/DataGeneration/ColorBigTwoBallSmall.h5'
    # test_path = '/home/jcoreyes/objects/RailResearch/DataGeneration/ColorBigTwoBallSmall.h5'

    train_path = '/home/jcoreyes/objects/RailResearch/BlocksGeneration/rendered/fiveBlock10k.h5'
    test_path = '/home/jcoreyes/objects/RailResearch/BlocksGeneration/rendered/fiveBlock10k.h5'

    train_data = load_dataset(train_path, train=True)
    test_data = load_dataset(test_path, train=False)

    n_frames = 2
    imsize = train_data.shape[-1]
    T = variant['vae_kwargs']['T']
    K = variant['vae_kwargs']['K']
    rep_size = variant['vae_kwargs']['representation_size']
    # t_sample = np.array([0, 0, 0, 0, 0, 10, 15, 20, 25, 30])
    #t_sample = np.array([0, 34, 34, 34, 34])
    t_sample = np.array([0, 0, 0, 0, 1])
    train_data = train_data.reshape(
        (n_frames, -1, 3, imsize, imsize)).swapaxes(0, 1)[:8000, t_sample]
    test_data = test_data.reshape(
        (n_frames, -1, 3, imsize, imsize)).swapaxes(0, 1)[:50, t_sample]
    #logger.save_extra_data(info)
    logger.get_snapshot_dir()
    variant['vae_kwargs'][
        'architecture'] = iodine.imsize64_large_iodine_architecture
    variant['vae_kwargs']['decoder_class'] = BroadcastCNN

    refinement_net = RefinementNetwork(
        **iodine.imsize64_large_iodine_architecture['refine_args'],
        hidden_activation=nn.ELU())
    physics_net = None
    if variant['physics']:
        physics_net = PhysicsNetwork(K, rep_size)
    m = IodineVAE(
        **variant['vae_kwargs'],
        refinement_net=refinement_net,
        dynamic=True,
        physics_net=physics_net,
    )

    m.to(ptu.device)

    t = IodineTrainer(train_data, test_data, m, **variant['algo_kwargs'])
    save_period = variant['save_period']
    for epoch in range(variant['num_epochs']):
        should_save_imgs = (epoch % save_period == 0)
        t.train_epoch(epoch,
                      batches=train_data.shape[0] //
                      variant['algo_kwargs']['batch_size'])
        t.test_epoch(epoch,
                     save_vae=True,
                     train=False,
                     record_stats=True,
                     batches=1,
                     save_reconstruction=should_save_imgs)
        t.test_epoch(epoch,
                     save_vae=False,
                     train=True,
                     record_stats=False,
                     batches=1,
                     save_reconstruction=should_save_imgs)
    logger.save_extra_data(m, 'vae.pkl', mode='pickle')
Example No. 27
def train_vae(variant, other_variant, return_data=False):
    from rlkit.util.ml_util import PiecewiseLinearSchedule
    from rlkit.torch.vae.conv_vae import ConvVAE
    import rlkit.torch.vae.conv_vae as conv_vae
    from rlkit.torch.vae.vae_trainer import ConvVAETrainer
    from rlkit.core import logger
    import rlkit.torch.pytorch_util as ptu
    from rlkit.pythonplusplus import identity
    import torch
    beta = variant["beta"]
    representation_size = variant["representation_size"]
    generate_vae_dataset_fctn = variant.get('generate_vae_data_fctn',
                                            generate_vae_dataset)
    train_data, test_data, info = generate_vae_dataset_fctn(
        variant['generate_vae_dataset_kwargs'])
    logger.save_extra_data(info)
    logger.get_snapshot_dir()
    if 'beta_schedule_kwargs' in variant:
        beta_schedule = PiecewiseLinearSchedule(
            **variant['beta_schedule_kwargs'])
    else:
        beta_schedule = None
    if variant.get('decoder_activation', None) == 'sigmoid':
        decoder_activation = torch.nn.Sigmoid()
    else:
        decoder_activation = identity
    architecture = variant['vae_kwargs'].get('architecture', None)
    if not architecture and variant.get('imsize') == 84:
        architecture = conv_vae.imsize84_default_architecture
    elif not architecture and variant.get('imsize') == 48:
        architecture = conv_vae.imsize48_default_architecture
    variant['vae_kwargs']['architecture'] = architecture
    variant['vae_kwargs']['imsize'] = variant.get('imsize')

    m = ConvVAE(representation_size,
                decoder_output_activation=decoder_activation,
                **variant['vae_kwargs'])
    m.to(ptu.device)
    t = ConvVAETrainer(train_data,
                       test_data,
                       m,
                       other_variant,
                       beta=beta,
                       beta_schedule=beta_schedule,
                       **variant['algo_kwargs'])
    save_period = variant['save_period']
    dump_skew_debug_plots = variant.get('dump_skew_debug_plots', False)
    for epoch in range(variant['num_epochs']):
        should_save_imgs = (epoch % save_period == 0)
        t.train_epoch(epoch)
        t.test_epoch(
            epoch,
            save_reconstruction=should_save_imgs,
            # save_vae=False,
        )
        if should_save_imgs:
            t.dump_samples(epoch)
        t.update_train_weights()
    logger.save_extra_data(m, 'vae.pkl', mode='pickle')
    # torch.save(m, other_variant['vae_pkl_path'] + '/online_vae.pkl')  # easy way: load model for via bonus
    if return_data:
        return m, train_data, test_data
    return m
Example No. 28
    def evaluate(self, epoch):
        """
        Evaluate the policy, e.g. save/print progress.
        :param epoch:
        :return:
        """

        statistics = OrderedDict()
        try:
            statistics.update(self.eval_statistics)
            self.eval_statistics = None
        except TypeError:  # eval_statistics was None
            print('No Stats to Eval')

        logger.log("Collecting random samples for evaluation")

        eval_steps = self.num_steps_per_eval

        test_paths = self.eval_sampler.obtain_samples(eval_steps)
        obs = torch.Tensor(
            np.squeeze(np.vstack([path["observations"]
                                  for path in test_paths])))
        acts = torch.Tensor(
            np.squeeze(np.vstack([path["actions"] for path in test_paths])))
        if len(acts.shape) < 2:
            acts = torch.unsqueeze(acts, 1)
        random_input = torch.cat([obs, acts], dim=1).to(ptu.device)

        exp_batch = self.get_batch(eval_steps,
                                   keys=['observations', 'actions'],
                                   use_expert_buffer=True)
        # exp_batch = {'observations':torch.Tensor([[0.],[1.],[2.],[3.],[4.],[5.],[6.],[7.],[8.],[9.],[10.]]), 'actions':torch.Tensor([[0.5]]*11)}

        obs = exp_batch['observations']
        acts = exp_batch['actions']
        exp_input = torch.cat([obs, acts], dim=1).to(ptu.device)

        statistics['random_avg_energy'] = self.ebm(random_input).mean().item()
        statistics['expert_avg_energy'] = self.get_energy(
            exp_input).mean().item()
        statistics['expert*20_avg_energy'] = self.get_energy(
            exp_input * 20).mean().item()

        statistics["random_expert_diff"] = statistics[
            "random_avg_energy"] - statistics["expert_avg_energy"]

        for key, value in statistics.items():
            logger.record_tabular(key, value)

        best_statistic = statistics[self.best_key]

        if best_statistic > self.best_statistic_so_far:
            self.best_statistic_so_far = best_statistic
            self.best_epoch = epoch
            self.best_random_avg_energy = statistics['random_avg_energy']
            self.best_expert_avg_energy = statistics['expert_avg_energy']
            logger.record_tabular("Best Model Epoch", self.best_epoch)
            logger.record_tabular("Best Random Energy",
                                  self.best_random_avg_energy)
            logger.record_tabular("Best Expert Energy",
                                  self.best_expert_avg_energy)
            if self.save_best and epoch >= self.save_best_starting_from_epoch:
                data_to_save = {'epoch': epoch, 'statistics': statistics}
                data_to_save.update(self.get_epoch_snapshot(epoch))
                logger.save_extra_data(data_to_save, 'best.pkl')
                print('\n\nSAVED BEST\n\n')
        logger.record_tabular("Best Model Epoch", self.best_epoch)
        logger.record_tabular("Best Random Energy",
                              self.best_random_avg_energy)
        logger.record_tabular("Best Expert Energy",
                              self.best_expert_avg_energy)
Example No. 29
    def train(self):
        '''
        meta-training loop
        '''
        params = self.get_epoch_snapshot(-1)
        logger.save_itr_params(-1, params)
        gt.reset()
        gt.set_def_unique(False)
        self._current_path_builder = PathBuilder()

        # at each iteration, we first collect data from tasks, perform meta-updates, then try to evaluate
        for it_ in gt.timed_for(range(self.num_iterations), save_itrs=True):
            self._start_epoch(it_)
            self.training_mode(True)
            if it_ == 0:
                print('collecting initial pool of data for train and eval')
                # temp for evaluating
                for idx in self.train_tasks:
                    self.task_idx = idx
                    self.env.reset_task(idx)
                    self.collect_data(self.num_initial_steps,
                                      1,
                                      np.inf,
                                      buffer=self.train_buffer)
            # Sample data from train tasks.
            for i in range(self.num_tasks_sample):
                idx = np.random.choice(self.train_tasks, 1)[0]
                self.task_idx = idx
                self.env.reset_task(idx)
                self.enc_replay_buffer.task_buffers[idx].clear()

                # collect some trajectories with z ~ prior
                if self.num_steps_prior > 0:
                    self.collect_data(self.num_steps_prior,
                                      1,
                                      np.inf,
                                      buffer=self.train_buffer)
                # collect some trajectories with z ~ posterior
                if self.num_steps_posterior > 0:
                    self.collect_data(self.num_steps_posterior,
                                      1,
                                      self.update_post_train,
                                      buffer=self.train_buffer)
                # even if encoder is trained only on samples from the prior, the policy needs to learn to handle z ~ posterior
                if self.num_extra_rl_steps_posterior > 0:
                    self.collect_data(self.num_extra_rl_steps_posterior,
                                      1,
                                      self.update_post_train,
                                      buffer=self.train_buffer,
                                      add_to_enc_buffer=False)

            indices_lst = []
            z_means_lst = []
            z_vars_lst = []
            # Sample train tasks and compute gradient updates on parameters.
            for train_step in range(self.num_train_steps_per_itr):
                indices = np.random.choice(self.train_tasks,
                                           self.meta_batch,
                                           replace=self.mb_replace)
                z_means, z_vars = self._do_training(indices, zloss=True)
                indices_lst.append(indices)
                z_means_lst.append(z_means)
                z_vars_lst.append(z_vars)
                self._n_train_steps_total += 1

            indices = np.concatenate(indices_lst)
            z_means = np.concatenate(z_means_lst)
            z_vars = np.concatenate(z_vars_lst)
            data_dict = self.data_dict(indices, z_means, z_vars)
            logger.save_itr_data(it_, **data_dict)
            gt.stamp('train')
            self.training_mode(False)
            # eval
            params = self.get_epoch_snapshot(it_)
            logger.save_itr_params(it_, params)

            if self.allow_eval:
                logger.save_extra_data(self.get_extra_data_to_save(it_))
                self._try_to_eval(it_)
                gt.stamp('eval')
            self._end_epoch()
Example No. 30
def train_pixelcnn(
    vqvae=None,
    vqvae_path=None,
    num_epochs=100,
    batch_size=32,
    n_layers=15,
    dataset_path=None,
    save=True,
    save_period=10,
    cached_dataset_path=False,
    trainer_kwargs=None,
    model_kwargs=None,
    data_filter_fn=lambda x: x,
    debug=False,
    data_size=float('inf'),
    num_train_batches_per_epoch=None,
    num_test_batches_per_epoch=None,
    train_img_loader=None,
    test_img_loader=None,
):
    trainer_kwargs = {} if trainer_kwargs is None else trainer_kwargs
    model_kwargs = {} if model_kwargs is None else model_kwargs

    # Load VQVAE + Define Args
    if vqvae is None:
        vqvae = load_local_or_remote_file(vqvae_path)
        vqvae.to(ptu.device)
        vqvae.eval()

    root_len = vqvae.root_len
    num_embeddings = vqvae.num_embeddings
    embedding_dim = vqvae.embedding_dim
    cond_size = vqvae.num_embeddings
    imsize = vqvae.imsize
    discrete_size = root_len * root_len
    representation_size = embedding_dim * discrete_size
    input_channels = vqvae.input_channels
    imlength = imsize * imsize * input_channels

    log_dir = logger.get_snapshot_dir()

    # Define data loading info
    new_path = osp.join(log_dir, 'pixelcnn_data.npy')

    def prep_sample_data(cached_path):
        data = load_local_or_remote_file(cached_path).item()
        train_data = data['train']
        test_data = data['test']
        return train_data, test_data

    def encode_dataset(path, object_list):
        data = load_local_or_remote_file(path)
        data = data.item()
        data = data_filter_fn(data)

        all_data = []
        n = min(data["observations"].shape[0], data_size)

        for i in tqdm(range(n)):
            obs = ptu.from_numpy(data["observations"][i] / 255.0)
            latent = vqvae.encode(obs, cont=False)
            all_data.append(latent)

        encodings = ptu.get_numpy(torch.stack(all_data, dim=0))
        return encodings

    if train_img_loader:
        _, test_loader, test_batch_loader = create_conditional_data_loader(
            test_img_loader, 80, vqvae, "test2")  # 80
        _, train_loader, train_batch_loader = create_conditional_data_loader(
            train_img_loader, 2000, vqvae, "train2")  # 2000
    else:
        if cached_dataset_path:
            train_data, test_data = prep_sample_data(cached_dataset_path)
        else:
            train_data = encode_dataset(dataset_path['train'],
                                        None)  # object_list)
            test_data = encode_dataset(dataset_path['test'], None)
        dataset = {'train': train_data, 'test': test_data}
        np.save(new_path, dataset)

        _, _, train_loader, test_loader, _ = \
            rlkit.torch.vae.pixelcnn_utils.load_data_and_data_loaders(new_path, 'COND_LATENT_BLOCK', batch_size)

    #train_dataset = InfiniteBatchLoader(train_loader)
    #test_dataset = InfiniteBatchLoader(test_loader)

    print("Finished loading data")

    model = GatedPixelCNN(num_embeddings,
                          root_len**2,
                          n_classes=representation_size,
                          **model_kwargs).to(ptu.device)
    trainer = PixelCNNTrainer(
        model,
        vqvae,
        batch_size=batch_size,
        **trainer_kwargs,
    )

    print("Starting training")

    BEST_LOSS = 999
    for epoch in range(num_epochs):
        should_save = (epoch % save_period == 0) and (epoch > 0)
        trainer.train_epoch(epoch, train_loader, num_train_batches_per_epoch)
        trainer.test_epoch(epoch, test_loader, num_test_batches_per_epoch)

        if train_img_loader:
            # Dump conditional samples; the batch loaders only exist on the
            # img-loader path, and `bz` was undefined here, so use batch_size.
            test_data = test_batch_loader.random_batch(batch_size)["x"]
            train_data = train_batch_loader.random_batch(batch_size)["x"]
            trainer.dump_samples(epoch, test_data, test=True)
            trainer.dump_samples(epoch, train_data, test=False)

        if should_save:
            logger.save_itr_params(epoch, model)

        stats = trainer.get_diagnostics()

        cur_loss = stats["test/loss"]
        if cur_loss < BEST_LOSS:
            BEST_LOSS = cur_loss
            vqvae.set_pixel_cnn(model)
            logger.save_extra_data(vqvae, 'best_vqvae', mode='torch')
        else:
            return vqvae

        for k, v in stats.items():
            logger.record_tabular(k, v)
        logger.dump_tabular()
        trainer.end_epoch(epoch)

    return vqvae
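A hypothetical invocation of the train_pixelcnn above; the paths and kwargs are placeholders, and the function can alternatively be given pre-built train_img_loader/test_img_loader objects instead of dataset_path.

vqvae = train_pixelcnn(
    vqvae_path='path/to/vqvae.pkl',      # placeholder: trained VQ-VAE checkpoint
    dataset_path=dict(
        train='path/to/train_data.npy',  # placeholder paths to .npy datasets
        test='path/to/test_data.npy',
    ),
    num_epochs=100,
    batch_size=32,
    save_period=10,
    trainer_kwargs=dict(),   # forwarded to PixelCNNTrainer
    model_kwargs=dict(),     # forwarded to GatedPixelCNN
)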