Example #1
def experiment(variant):
    from rlkit.core import logger
    # imports this snippet needs; module paths follow rlkit's layout,
    # matching the imports shown in Example #2 below
    from rlkit.misc.ml_util import PiecewiseLinearSchedule
    from rlkit.torch.vae.conv_vae import ConvVAE
    from rlkit.torch.vae.vae_trainer import ConvVAETrainer
    import rlkit.torch.pytorch_util as ptu
    beta = variant["beta"]
    representation_size = variant["representation_size"]
    train_data, test_data, info = generate_vae_dataset(
        **variant['get_data_kwargs'])
    logger.save_extra_data(info)
    logger.get_snapshot_dir()
    if 'beta_schedule_kwargs' in variant:
        beta_schedule = PiecewiseLinearSchedule(
            **variant['beta_schedule_kwargs'])
    else:
        beta_schedule = None
    m = ConvVAE(representation_size,
                input_channels=3,
                **variant['conv_vae_kwargs'])
    if ptu.gpu_enabled():
        m.to(ptu.device)
    t = ConvVAETrainer(train_data,
                       test_data,
                       m,
                       beta=beta,
                       beta_schedule=beta_schedule,
                       **variant['algo_kwargs'])
    save_period = variant['save_period']
    for epoch in range(variant['num_epochs']):
        should_save_imgs = (epoch % save_period == 0)
        t.train_epoch(epoch)
        t.test_epoch(epoch,
                     save_reconstruction=should_save_imgs,
                     save_scatterplot=should_save_imgs)
        if should_save_imgs:
            t.dump_samples(epoch)
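A minimal sketch of the `variant` dict this `experiment` expects. Every key is taken from the reads in the function above; the concrete values are illustrative assumptions, not defaults from the source:

variant = dict(
    beta=0.5,
    representation_size=16,
    num_epochs=100,
    save_period=10,
    get_data_kwargs=dict(),        # forwarded to generate_vae_dataset
    conv_vae_kwargs=dict(),        # forwarded to ConvVAE
    algo_kwargs=dict(),            # forwarded to ConvVAETrainer
    # optional; enables a PiecewiseLinearSchedule for beta:
    # beta_schedule_kwargs=dict(x_values=[0, 100], y_values=[0.0, 0.5]),
)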
Example #2
def train_vae(variant, return_data=False):
    from rlkit.misc.ml_util import PiecewiseLinearSchedule
    from rlkit.torch.vae.vae_trainer import ConvVAETrainer
    from rlkit.core import logger
    beta = variant["beta"]
    use_linear_dynamics = variant.get('use_linear_dynamics', False)
    generate_vae_dataset_fctn = variant.get('generate_vae_data_fctn',
                                            generate_vae_dataset)
    variant['generate_vae_dataset_kwargs'][
        'use_linear_dynamics'] = use_linear_dynamics
    train_dataset, test_dataset, info = generate_vae_dataset_fctn(
        variant['generate_vae_dataset_kwargs'])
    if use_linear_dynamics:
        action_dim = train_dataset.data['actions'].shape[2]
    else:
        action_dim = 0
    model = get_vae(variant, action_dim)

    logger.save_extra_data(info)
    logger.get_snapshot_dir()
    if 'beta_schedule_kwargs' in variant:
        beta_schedule = PiecewiseLinearSchedule(
            **variant['beta_schedule_kwargs'])
    else:
        beta_schedule = None

    vae_trainer_class = variant.get('vae_trainer_class', ConvVAETrainer)
    trainer = vae_trainer_class(model,
                                beta=beta,
                                beta_schedule=beta_schedule,
                                **variant['algo_kwargs'])
    save_period = variant['save_period']

    dump_skew_debug_plots = variant.get('dump_skew_debug_plots', False)
    for epoch in range(variant['num_epochs']):
        should_save_imgs = (epoch % save_period == 0)
        trainer.train_epoch(epoch, train_dataset)
        trainer.test_epoch(epoch, test_dataset)

        if should_save_imgs:
            trainer.dump_reconstructions(epoch)
            trainer.dump_samples(epoch)
            if dump_skew_debug_plots:
                trainer.dump_best_reconstruction(epoch)
                trainer.dump_worst_reconstruction(epoch)
                trainer.dump_sampling_histogram(epoch)

        stats = trainer.get_diagnostics()
        for k, v in stats.items():
            logger.record_tabular(k, v)
        logger.dump_tabular()
        trainer.end_epoch(epoch)

        if epoch % 50 == 0:
            logger.save_itr_params(epoch, model)
    logger.save_extra_data(model, 'vae.pkl', mode='pickle')
    if return_data:
        return model, train_dataset, test_dataset
    return model
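Note that `generate_vae_data_fctn` is called with the kwargs dict as a single positional argument (unlike Example #1, which unpacks it with `**`) and must return a `(train_dataset, test_dataset, info)` triple; with `use_linear_dynamics`, the train dataset is additionally expected to expose `.data['actions']`. A hypothetical drop-in generator, with made-up contents, might look like:

import numpy as np

def my_generate_vae_dataset(generate_vae_dataset_kwargs):
    n = generate_vae_dataset_kwargs.get('N', 1000)
    # flattened fake images; a real generator would collect env observations
    data = np.random.rand(n, 3 * 48 * 48).astype(np.float32)
    train_dataset, test_dataset = data[:n * 9 // 10], data[n * 9 // 10:]
    info = dict(source='my_generator')  # saved via logger.save_extra_data
    return train_dataset, test_dataset, info

variant['generate_vae_data_fctn'] = my_generate_vae_dataset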
Example #3
def save_paths(algo, epoch):
    expl_paths = algo.expl_data_collector.get_epoch_paths()
    filename = osp.join(logger.get_snapshot_dir(),
                        'video_{epoch}_vae.p'.format(epoch=epoch))
    # use context managers so the file handles are closed after dumping
    with open(filename, "wb") as f:
        pickle.dump(expl_paths, f)
    print("saved", filename)
    eval_paths = algo.eval_data_collector.get_epoch_paths()
    filename = osp.join(logger.get_snapshot_dir(),
                        'video_{epoch}_env.p'.format(epoch=epoch))
    with open(filename, "wb") as f:
        pickle.dump(eval_paths, f)
    print("saved", filename)
Example #4
    def plot_scattered(self, z, epoch):
        try:
            import matplotlib.pyplot as plt
        except ImportError:
            logger.log(__file__ + ": Unable to load matplotlib. Consider "
                       "setting do_scatterplot to False")
            return
        # scatter the two latent dimensions with the largest standard deviation
        dim_and_stds = [(i, np.std(z[:, i])) for i in range(z.shape[1])]
        dim_and_stds = sorted(dim_and_stds, key=lambda x: x[1])
        dim1 = dim_and_stds[-1][0]
        dim2 = dim_and_stds[-2][0]
        plt.figure(figsize=(8, 8))
        plt.scatter(z[:, dim1], z[:, dim2], marker='o', edgecolor='none')
        if self.model.dist_mu is not None:
            x1 = self.model.dist_mu[dim1:dim1 + 1]
            y1 = self.model.dist_mu[dim2:dim2 + 1]
            x2 = x1 + self.model.dist_std[dim1:dim1 + 1]
            y2 = y1 + self.model.dist_std[dim2:dim2 + 1]
            # only draw the mean/std segment when dist_mu is available;
            # x1..y2 are undefined otherwise
            plt.plot([x1, x2], [y1, y2], color='k', linestyle='-', linewidth=2)
        axes = plt.gca()
        axes.set_xlim([-6, 6])
        axes.set_ylim([-6, 6])
        axes.set_title('dim {} vs dim {}'.format(dim1, dim2))
        plt.grid(True)
        save_file = osp.join(logger.get_snapshot_dir(), 'scatter%d.png' % epoch)
        plt.savefig(save_file)
Example #5
def experiment(user_variant):
    variant = default_variant.copy()
    variant.update(user_variant)

    if ptu.gpu_enabled():
        enable_gpus("0")

    env_id = variant["env"]
    env = build_env(env_id)

    agent_configs = variant["agent_configs"]
    agent = build_agent(env, env_id, agent_configs)
    agent.visualize = variant["visualize"]
    model_file = variant.get("model_file")
    # `model_file is not ""` is an identity comparison that is effectively
    # always True; a truthiness check correctly skips both None and ""
    if model_file:
        agent.load_model(model_file)

    log_dir = logger.get_snapshot_dir()
    if variant["train"]:
        agent.train(max_iter=variant["max_iter"],
                    test_episodes=variant["test_episodes"],
                    output_dir=log_dir,
                    output_iters=variant["output_iters"])
    else:
        agent.eval(num_episodes=variant["test_episodes"])

    return
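A plausible shape for the `default_variant` this merges into, with keys inferred from the reads above (the values are illustrative assumptions, not the source's defaults):

default_variant = dict(
    env="HalfCheetah-v2",      # assumed env id
    agent_configs=dict(),
    visualize=False,
    model_file="",             # empty means "don't load a model"
    train=True,
    max_iter=1000,
    test_episodes=10,
    output_iters=50,
)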
Example #6
def test_iql():
    logger.reset()

    # make tests small by mutating variant
    iql.variant["algo_kwargs"]["start_epoch"] = -2
    iql.variant["algo_kwargs"]["num_epochs"] = 2
    iql.variant["algo_kwargs"]["batch_size"] = 2
    iql.variant["algo_kwargs"]["num_eval_steps_per_epoch"] = 2
    iql.variant["algo_kwargs"]["num_expl_steps_per_train_loop"] = 2
    iql.variant["algo_kwargs"]["num_trains_per_train_loop"] = 100
    iql.variant["algo_kwargs"]["min_num_steps_before_training"] = 2
    iql.variant["qf_kwargs"] = dict(hidden_sizes=[2, 2])

    iql.variant["seed"] = 25580

    iql.main()

    reference_csv = "tests/regression/iql/halfcheetah_online_progress.csv"
    output_csv = os.path.join(logger.get_snapshot_dir(), "progress.csv")
    print("comparing reference %s against output %s" %
          (reference_csv, output_csv))
    output = csv_util.get_exp(output_csv)
    reference = csv_util.get_exp(reference_csv)
    keys = [
        "epoch",
        "expl/num steps total",
        "expl/Average Returns",
        "trainer/Q1 Predictions Mean",
    ]
    csv_util.check_equal(reference, output, keys)
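csv_util here is the test suite's own helper; as a rough sketch of what a key-wise comparison like check_equal plausibly does (an assumption, not the actual implementation), one could compare the selected columns row by row within a tolerance:

import csv
import math

def load_rows(path):
    with open(path) as f:
        return list(csv.DictReader(f))

def check_columns_close(reference_rows, output_rows, keys, tol=1e-6):
    assert len(reference_rows) == len(output_rows), "row count mismatch"
    for ref, out in zip(reference_rows, output_rows):
        for k in keys:
            assert math.isclose(float(ref[k]), float(out[k]),
                                rel_tol=tol, abs_tol=tol), k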
Example #7
def create_save_h_vs_state_distance_fn(save_period, initial_save_period,
                                       encoder, encoder_input_key):
    import matplotlib.pyplot as plt
    from rlkit.core import logger
    import os.path as osp

    logdir = logger.get_snapshot_dir()

    def save_h_vs_state_distance(algo, epoch):
        if ((epoch < save_period and epoch % initial_save_period == 0)
                or epoch % save_period == 0 or epoch >= algo.num_epochs - 1):
            filename = osp.join(
                logdir,
                'h_vs_distance_scatterplot_{epoch}.png'.format(epoch=epoch))
            replay_buffer = algo.replay_buffer
            size = min(1024, replay_buffer._size)
            idxs1 = replay_buffer._sample_indices(size)
            idxs2 = replay_buffer._sample_indices(size)
            encoder_obs = replay_buffer._obs[encoder_input_key]
            x1 = encoder_obs[idxs1]
            x2 = encoder_obs[idxs2]
            z1 = encoder.encode(x1)
            z2 = encoder.encode(x2)

            state_obs = replay_buffer._obs['state_observation']
            states1 = state_obs[idxs1]
            states2 = state_obs[idxs2]
            state_deltas = np.linalg.norm(states1 - states2, axis=1, ord=1)
            encoder_deltas = np.linalg.norm(z1 - z2, axis=1, ord=1)

            plt.clf()
            plt.scatter(state_deltas, encoder_deltas, alpha=0.2)
            plt.savefig(filename)

    return save_h_vs_state_distance
Example #8
    def __init__(self,
                 model,
                 data_collector,
                 tag,
                 save_video_period,
                 goal_image_key=None,
                 decode_goal_image_key=None,
                 reconstruction_key=None,
                 **kwargs):
        self.model = model
        self.data_collector = data_collector
        self.tag = tag
        self.goal_image_key = goal_image_key
        self.decode_goal_image_key = decode_goal_image_key
        self.reconstruction_key = reconstruction_key
        self.dump_video_kwargs = kwargs
        self.save_video_period = save_video_period
        self.keys = []
        if goal_image_key:
            self.keys.append(goal_image_key)
        if decode_goal_image_key:
            self.keys.append(decode_goal_image_key)
        self.keys.append("image_observation")
        if reconstruction_key:
            self.keys.append(reconstruction_key)
        self.logdir = logger.get_snapshot_dir()
Example #9
    def dump_samples(self, epoch):
        self.model.eval()
        # set gpu device explicitly
        label = torch.tensor(np.zeros(64)).long().to(ptu.device)

        # autoregressively sample 3x3 grids of discrete codes from the
        # PixelCNN prior
        sample = self.model.pixelcnn.generate(label, (3, 3), 64).cpu()

        # `sample` is already a tensor, so reshape it directly instead of
        # re-wrapping it with torch.tensor()
        e_indices = sample.reshape(-1, 1).long().to(ptu.device)

        # one-hot encode the code indices, then look up the codebook
        # embeddings and decode them back to images
        min_encodings = torch.zeros(e_indices.shape[0], 64).to(ptu.device)
        min_encodings.scatter_(1, e_indices, 1)
        e_weights = self.model.vector_quantization.embedding.weight
        z_q = torch.matmul(min_encodings, e_weights).view((64, 3, 3, 2))
        z_q = z_q.permute(0, 3, 1, 2).contiguous()
        z_q = self.model.pre_dequantization_conv(z_q)
        x_recon = self.model.decoder(z_q)
        sample = x_recon.view(64, -1)
        save_dir = osp.join(logger.get_snapshot_dir(), 's%d.png' % epoch)
        save_image(
            sample.data.view(64, self.input_channels, self.imsize,
                             self.imsize).transpose(2, 3), save_dir)
Example #10
    def save(self):
        path = logger.get_snapshot_dir()
        trainer = self.trainer
        expl_data_collector = self.expl_data_collector
        eval_data_collector = self.eval_data_collector
        replay_buffer = self.replay_buffer
        expl_env = self.expl_env
        eval_env = self.eval_env
        pretrain_policy = self.pretrain_policy

        delattr(self, "trainer")
        delattr(self, "expl_data_collector")
        delattr(self, "eval_data_collector")
        delattr(self, "replay_buffer")
        delattr(self, "expl_env")
        delattr(self, "eval_env")
        delattr(self, "pretrain_policy")

        with open(os.path.join(path, "algorithm.pkl"), "wb") as f:
            pickle.dump(self, f)

        trainer.save(path, "trainer.pkl")
        expl_data_collector.save(path, "expl_data_collector.pkl")
        eval_data_collector.save(path, "eval_data_collector.pkl")
        replay_buffer.save(path, "replay_buffer.pkl")
        expl_env.save(path, "expl_env.pkl")
        eval_env.save(path, "eval_env.pkl")
        pretrain_policy.save(path, "pretrain_policy.pkl")

        self.trainer = trainer
        self.expl_data_collector = expl_data_collector
        self.eval_data_collector = eval_data_collector
        self.replay_buffer = replay_buffer
        self.expl_env = expl_env
        self.eval_env = eval_env
        self.pretrain_policy = pretrain_policy
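The save above strips the heavyweight components before pickling, so loading algorithm.pkl alone yields an object without its trainer, buffers, or envs; a minimal reload sketch (the re-attachment logic belongs to each component's own load method, which is not shown here):

import os
import pickle

with open(os.path.join(path, "algorithm.pkl"), "rb") as f:
    algorithm = pickle.load(f)
# trainer.pkl, replay_buffer.pkl, expl_env.pkl, etc. must be loaded and
# re-attached to `algorithm` before training can resume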
Example #11
def post_epoch_visualize_func(algorithm, epoch):
    if epoch % 10 == 0:
        visualize_rollout(
            algorithm.eval_env.envs[0],
            algorithm.trainer.world_model,
            logger.get_snapshot_dir(),
            algorithm.max_path_length,
            low_level_primitives=algorithm.low_level_primitives,
            policy=algorithm.eval_data_collector._policy,
            use_raps_obs=False,
            use_true_actions=True,
            num_rollouts=6,
        )
        if algorithm.low_level_primitives:
            visualize_rollout(
                algorithm.eval_env.envs[0],
                algorithm.trainer.world_model,
                logger.get_snapshot_dir(),
                algorithm.max_path_length,
                low_level_primitives=algorithm.low_level_primitives,
                policy=algorithm.eval_data_collector._policy,
                use_raps_obs=True,
                use_true_actions=True,
                num_rollouts=2,
            )
            visualize_rollout(
                algorithm.eval_env.envs[0],
                algorithm.trainer.world_model,
                logger.get_snapshot_dir(),
                algorithm.max_path_length,
                low_level_primitives=algorithm.low_level_primitives,
                policy=algorithm.eval_data_collector._policy,
                use_raps_obs=True,
                use_true_actions=False,
                num_rollouts=2,
            )
            visualize_rollout(
                algorithm.eval_env.envs[0],
                algorithm.trainer.world_model,
                logger.get_snapshot_dir(),
                algorithm.max_path_length,
                low_level_primitives=algorithm.low_level_primitives,
                policy=algorithm.eval_data_collector._policy,
                use_raps_obs=False,
                use_true_actions=False,
                num_rollouts=2,
            )
Example #12
    def train(self):
        epoch = -1
        for t in range(self.num_update_loops_per_train_call):
            epoch += 1
            for _ in range(self.num_disc_updates_per_loop_iter):
                self._do_reward_training(epoch)

            self.discriminator.eval()
            logits = self.discriminator(self.xy_var, None)
            rewards = self._convert_logits_to_reward(logits)
            self.discriminator.train()

            logit_bound = 10.0
            if self.train_objective == 'airl':
                rew_bound = 10.0
            elif self.train_objective == 'fairl':
                rew_bound = 100.0
            elif self.train_objective == 'gail':
                rew_bound = 10.0
            elif self.train_objective == 'w1':
                rew_bound = 10.0
            else:
                raise ValueError(
                    'unknown train_objective: %s' % self.train_objective)

            # plot the logits of the discriminator
            logits = ptu.get_numpy(logits)
            logits = np.reshape(logits, (int(self._d_len), int(self._d_len))).T
            plot_seaborn_grid(
                logits, -logit_bound, logit_bound,
                'Disc Logits Epoch %d' % epoch,
                osp.join(logger.get_snapshot_dir(),
                         'disc_logits_epoch_%d.png' % epoch))

            # plot the rewards given by the discriminator
            rewards = ptu.get_numpy(rewards)
            rewards = np.reshape(rewards,
                                 (int(self._d_len), int(self._d_len))).T
            plot_seaborn_grid(
                rewards, -rew_bound, rew_bound,
                'Disc Rewards Epoch %d' % epoch,
                osp.join(logger.get_snapshot_dir(),
                         'disc_rewards_epoch_%d.png' % epoch))

            logger.dump_tabular(with_prefix=False, with_timestamp=False)
            self.rewardf_eval_statistics = None
Example #13
    def test_epoch(
            self,
            epoch,
            save_reconstruction=True,
            save_vae=True,
            from_rl=False,
    ):
        self.model.eval()
        losses = []
        log_probs = []
        kles = []
        zs = []
        beta = float(self.beta_schedule.get_value(epoch))
        for batch_idx in range(10):
            next_obs = self.get_batch(train=False)
            reconstructions, obs_distribution_params, \
                latent_distribution_params = self.model(next_obs)
            log_prob = self.model.logprob(next_obs, obs_distribution_params)
            kle = self.model.kl_divergence(latent_distribution_params)
            loss = -1 * log_prob + beta * kle

            encoder_mean = latent_distribution_params[0]
            z_data = ptu.get_numpy(encoder_mean.cpu())
            for i in range(len(z_data)):
                zs.append(z_data[i, :])
            losses.append(loss.item())
            log_probs.append(log_prob.item())
            kles.append(kle.item())

            if batch_idx == 0 and save_reconstruction:
                n = min(next_obs.size(0), 8)
                comparison = torch.cat([
                    next_obs[:n].narrow(start=0, length=self.imlength, dim=1)
                        .contiguous().view(
                        -1, self.input_channels, self.imsize, self.imsize
                    ).transpose(2, 3),
                    reconstructions.view(
                        self.batch_size,
                        self.input_channels,
                        self.imsize,
                        self.imsize,
                    )[:n].transpose(2, 3)
                ])
                save_dir = osp.join(logger.get_snapshot_dir(),
                                    'r%d.png' % epoch)
                save_image(comparison.data.cpu(), save_dir, nrow=n)

        zs = np.array(zs)

        self.eval_statistics['epoch'] = epoch
        self.eval_statistics['test/log prob'] = np.mean(log_probs)
        self.eval_statistics['test/KL'] = np.mean(kles)
        self.eval_statistics['test/loss'] = np.mean(losses)
        self.eval_statistics['beta'] = beta
        if not from_rl:
            for k, v in self.eval_statistics.items():
                logger.record_tabular(k, v)
            logger.dump_tabular()
            if save_vae:
                logger.save_itr_params(epoch, self.model)
Example #14
def run_task(variant):
    from rlkit.core import logger
    print(variant)
    logger.log("Hello from script")
    logger.log("variant: " + str(variant))
    logger.record_tabular("value", 1)
    logger.dump_tabular()
    logger.log("snapshot_dir:", logger.get_snapshot_dir())
Example #15
    def render_video(self, tag, counter):
        import numpy as np

        paths = self.eval_data_collector.collect_new_paths(
            self.max_path_length,
            self.num_eval_steps_per_epoch,
            discard_incomplete_paths=True,
        )

        # dump a reconstruction gif only if the env reports reconstructions
        if "vae_reconstruction" in paths[0]['env_infos'][0]:
            video = np.array([[y['vae_reconstruction'] for y in x['env_infos']]
                              for x in paths])
            display_gif(images=video,
                        logdir=logger.get_snapshot_dir() + "/" + tag + "_reconstruction",
                        fps=15, counter=counter)

        video = np.array([[y['rendering'] for y in x['env_infos']]
                          for x in paths])
        display_gif(images=video,
                    logdir=logger.get_snapshot_dir() + "/" + tag,
                    fps=15, counter=counter)
Example #16
    def dump_samples(self, epoch):
        self.model.eval()
        sample = ptu.Variable(torch.randn(64, self.representation_size))
        sample = self.model.decode(sample).cpu()
        save_dir = osp.join(logger.get_snapshot_dir(), 's%d.png' % epoch)
        save_image(
            sample.data.view(64, self.input_channels, self.imsize,
                             self.imsize), save_dir)
Example #17
def train_vae(variant):
    #train_path = '/home/jcoreyes/objects/rlkit/examples/monet/clevr_train.hdf5'
    #test_path = '/home/jcoreyes/objects/rlkit/examples/monet/clevr_test.hdf5'

    # train_path = '/home/jcoreyes/objects/RailResearch/DataGeneration/ColorBigTwoBallSmall.h5'
    # test_path = '/home/jcoreyes/objects/RailResearch/DataGeneration/ColorBigTwoBallSmall.h5'

    train_path = '/home/jcoreyes/objects/RailResearch/BlocksGeneration/rendered/fiveBlock10kActions.h5'
    test_path = '/home/jcoreyes/objects/RailResearch/BlocksGeneration/rendered/fiveBlock10kActions.h5'

    train_feats, train_actions = load_dataset(train_path, train=True)
    test_feats, test_actions = load_dataset(test_path, train=False)

    K = variant['vae_kwargs']['K']
    rep_size = variant['vae_kwargs']['representation_size']

    logger.get_snapshot_dir()
    variant['vae_kwargs']['architecture'] = iodine.imsize64_large_iodine_architecture
    variant['vae_kwargs']['decoder_class'] = BroadcastCNN

    refinement_net = RefinementNetwork(**iodine.imsize64_large_iodine_architecture['refine_args'],
                                       hidden_activation=nn.ELU())

    physics_net = PhysicsNetwork(K, rep_size, train_actions.shape[-1])
    m = IodineVAE(
        **variant['vae_kwargs'],
        refinement_net=refinement_net,
        dynamic=True,
        physics_net=physics_net,
    )

    m.to(ptu.device)

    t = IodineTrainer(train_feats, test_feats, m,
                      variant['train_seedsteps'], variant['test_seedsteps'],
                      train_actions=train_actions, test_actions=test_actions,
                      **variant['algo_kwargs'])
    save_period = variant['save_period']
    for epoch in range(variant['num_epochs']):
        should_save_imgs = (epoch % save_period == 0)
        t.train_epoch(epoch, batches=train_feats.shape[0]//variant['algo_kwargs']['batch_size'])
        t.test_epoch(epoch, save_vae=True, train=False, record_stats=True, batches=1,
                     save_reconstruction=should_save_imgs)
        t.test_epoch(epoch, save_vae=False, train=True, record_stats=False, batches=1,
                     save_reconstruction=should_save_imgs)
    logger.save_extra_data(m, 'vae.pkl', mode='pickle')
Example #18
    def dump_samples(self, epoch, save_prefix='s'):
        self.model.eval()
        sample = ptu.randn(64, self.representation_size)
        sample = self.model.decode(sample)[0].cpu()
        save_dir = osp.join(logger.get_snapshot_dir(),
                            '{}{}.png'.format(save_prefix, epoch))
        save_image(
            sample.data.view(64, self.input_channels, self.imsize,
                             self.imsize).transpose(2, 3), save_dir)
Example #19
    def dump_debug_images(
        self,
        epoch,
        dump_images=True,
        num_recons=10,
        num_samples=25,
        debug_period=10,
        unnormalize_images=False,
    ):
        """

        :param epoch:
        :param dump_images: Set to False to not dump any images.
        :param num_recons:
        :param num_samples:
        :param debug_period: How often do you dump debug images?
        :param unnormalize_images: Should your unnormalize images before
            dumping them? Set to True if images are floats in [0, 1].
        :return:
        """
        if not dump_images or epoch % debug_period != 0:
            return
        example_obs_batch_np = ptu.get_numpy(self.example_obs_batch)
        recon_examples_np = ptu.get_numpy(
            self.vae.reconstruct(self.example_obs_batch))

        top_row_example = example_obs_batch_np[:num_recons]
        bottom_row_recon = np.clip(recon_examples_np, 0, 1)[:num_recons]

        recon_vis = combine_images_into_grid(
            imgs=list(top_row_example) + list(bottom_row_recon),
            imwidth=example_obs_batch_np.shape[2],
            imheight=example_obs_batch_np.shape[3],
            max_num_cols=len(top_row_example),
            image_format='CWH',
            unnormalize=unnormalize_images,
        )

        logdir = logger.get_snapshot_dir()
        cv2.imwrite(
            osp.join(logdir, '{}_recons.png'.format(epoch)),
            cv2.cvtColor(recon_vis, cv2.COLOR_RGB2BGR),
        )

        raw_samples = ptu.get_numpy(self.vae.sample(num_samples))
        vae_samples = np.clip(raw_samples, 0, 1)
        vae_sample_vis = combine_images_into_grid(
            imgs=list(vae_samples),
            imwidth=example_obs_batch_np.shape[2],
            imheight=example_obs_batch_np.shape[3],
            image_format='CWH',
            unnormalize=unnormalize_images,
        )
        cv2.imwrite(
            osp.join(logdir, '{}_vae_samples.png'.format(epoch)),
            cv2.cvtColor(vae_sample_vis, cv2.COLOR_RGB2BGR),
        )
Example #20
def train_vae(variant, return_data=False):
    beta = variant["beta"]
    representation_size = variant["representation_size"]
    generate_vae_dataset_fctn = variant.get('generate_vae_data_fctn',
                                            generate_vae_dataset)
    train_data, test_data, info = generate_vae_dataset_fctn(
        variant['generate_vae_dataset_kwargs'])
    logger.save_extra_data(info)
    logger.get_snapshot_dir()
    if variant.get('decoder_activation', None) == 'sigmoid':
        decoder_activation = torch.nn.Sigmoid()
    else:
        decoder_activation = identity
    architecture = variant['vae_kwargs'].get('architecture', None)
    if not architecture and variant.get('imsize') == 84:
        architecture = conv_vae.imsize84_default_architecture
    elif not architecture and variant.get('imsize') == 48:
        architecture = conv_vae.imsize48_default_architecture
    variant['vae_kwargs']['architecture'] = architecture
    variant['vae_kwargs']['imsize'] = variant.get('imsize')

    m = ConvVAE(representation_size,
                decoder_output_activation=decoder_activation,
                **variant['vae_kwargs'])
    m.to(ptu.device)
    t = ConvVAETrainer(train_data,
                       test_data,
                       m,
                       beta=beta,
                       **variant['algo_kwargs'])
    save_period = variant['save_period']
    for epoch in range(variant['num_epochs']):
        should_save_imgs = (epoch % save_period == 0)
        t.train_epoch(epoch)
        t.test_epoch(
            epoch,
            save_reconstruction=should_save_imgs,
        )
        if should_save_imgs:
            t.dump_samples(epoch)
    logger.save_extra_data(m, 'vae.pkl', mode='pickle')
    if return_data:
        return m, train_data, test_data
    return m
Example #21
    def _log_stats(self, epoch):
        logger.log("Epoch {} finished".format(epoch), with_timestamp=True)
        """
        Replay Buffer
        """
        logger.record_dict(self.replay_buffer.get_diagnostics(),
                           prefix='replay_buffer/')

        # If you want to save replay buffer as a whole, use this
        snap_shot_dir = logger.get_snapshot_dir()
        self.replay_buffer.save_buffer(snap_shot_dir + '/online_buffer.hdf5')
        """
        Trainer
        """
        logger.record_dict(self.trainer.get_diagnostics(), prefix='trainer/')
        """
        Exploration
        """
        logger.record_dict(self.expl_data_collector.get_diagnostics(),
                           prefix='exploration/')
        expl_paths = self.expl_data_collector.get_epoch_paths()
        if hasattr(self.expl_env, 'get_diagnostics'):
            logger.record_dict(
                self.expl_env.get_diagnostics(expl_paths),
                prefix='exploration/',
            )
        if not self.batch_rl or self.eval_both:
            logger.record_dict(
                eval_util.get_generic_path_information(expl_paths),
                prefix="exploration/",
            )
        """
        Evaluation
        """
        logger.record_dict(
            self.eval_data_collector.get_diagnostics(),
            prefix='evaluation/',
        )
        eval_paths = self.eval_data_collector.get_epoch_paths()
        if hasattr(self.eval_env, 'get_diagnostics'):
            logger.record_dict(
                self.eval_env.get_diagnostics(eval_paths),
                prefix='evaluation/',
            )
        logger.record_dict(
            eval_util.get_generic_path_information(eval_paths),
            prefix="evaluation/",
        )
        """
        Misc
        """
        gt.stamp('logging')
        logger.record_dict(_get_epoch_timings())
        logger.record_tabular('Epoch', epoch)
        logger.dump_tabular(with_prefix=False, with_timestamp=False)
Example #22
def train_set_vae(create_vae_kwargs,
                  vae_trainer_kwargs,
                  algo_kwargs,
                  data_loader_kwargs,
                  generate_set_kwargs,
                  num_ungrouped_images,
                  env_id=None,
                  env_class=None,
                  env_kwargs=None,
                  beta_schedule_kwargs=None,
                  env=None,
                  renderer=None,
                  sets=None) -> VAE:
    if beta_schedule_kwargs is None:
        beta_schedule_kwargs = {}
    print("vae_launcher:train_set_vae: device", ptu.device)
    eval_set_imgs, renderer, set_imgs, set_imgs_iterator, ungrouped_imgs = create_dataset(
        env_id,
        env_class,
        env_kwargs,
        generate_set_kwargs,
        num_ungrouped_images,
        env=env,
        renderer=renderer,
        sets=sets,
    )

    set_imgs_flat = set_imgs.view((-1, *set_imgs.shape[-3:]))
    all_imgs = torch.cat([ungrouped_imgs, set_imgs_flat], dim=0)
    all_imgs_iterator = data.DataLoader(all_imgs, **data_loader_kwargs)

    vae = create_image_vae(img_chw=renderer.image_chw, **create_vae_kwargs)

    set_key = 'set'
    data_key = 'data'
    dict_loader = DictLoader({
        data_key: all_imgs_iterator,
        set_key: infinite(set_imgs_iterator),
    })
    beta_schedule = create_beta_schedule(**beta_schedule_kwargs)
    vae_trainer = SetVAETrainer(vae=vae,
                                set_key=set_key,
                                data_key=data_key,
                                train_sets=set_imgs,
                                eval_sets=eval_set_imgs[:2],
                                beta_schedule=beta_schedule,
                                **vae_trainer_kwargs)
    algorithm = UnsupervisedTorchAlgorithm(
        vae_trainer,
        dict_loader,
        **algo_kwargs,
    )
    algorithm.to(ptu.device)
    algorithm.run()
    print(logger.get_snapshot_dir())
    return vae
Example #23
def plot_encoder_function(variant, encoder, tag=""):
    import matplotlib.pyplot as plt
    from matplotlib.animation import FuncAnimation
    from rlkit.core import logger
    logdir = logger.get_snapshot_dir()

    def plot_encoder(algo, epoch, is_x=False):
        save_period = variant.get('save_video_period', 50)
        if epoch % save_period == 0 or epoch == algo.num_epochs:
            filename = osp.join(
                logdir,
                'encoder_{}_{}_{epoch}_env.gif'.format(tag,
                                                       "x" if is_x else "y",
                                                       epoch=epoch))

            vary = np.arange(-4, 4, .1)
            static = np.zeros(len(vary))

            points_x = np.c_[vary.reshape(-1, 1), static.reshape(-1, 1)]
            points_y = np.c_[static.reshape(-1, 1), vary.reshape(-1, 1)]

            encoded_points_x = ptu.get_numpy(
                encoder.forward(ptu.from_numpy(points_x)))
            encoded_points_y = ptu.get_numpy(
                encoder.forward(ptu.from_numpy(points_y)))

            plt.clf()
            fig = plt.figure()
            plt.xlim(
                min(min(encoded_points_x[:, 0]), min(encoded_points_y[:, 0])),
                max(max(encoded_points_x[:, 0]), max(encoded_points_y[:, 0])))
            plt.ylim(
                min(min(encoded_points_x[:, 1]), min(encoded_points_y[:, 1])),
                max(max(encoded_points_x[:, 1]), max(encoded_points_y[:, 1])))
            colors = ["red", "blue"]
            lines = [
                plt.plot([], [], 'o', color=colors[i], alpha=0.4)[0]
                for i in range(2)
            ]

            def animate(i):
                lines[0].set_data(encoded_points_x[:i + 1, 0],
                                  encoded_points_x[:i + 1, 1])
                lines[1].set_data(encoded_points_y[:i + 1, 0],
                                  encoded_points_y[:i + 1, 1])
                return lines

            ani = FuncAnimation(fig, animate, frames=len(vary), interval=40)
            ani.save(filename, writer='imagemagick', fps=60)

    return plot_encoder
Example #24
    def visualize_representation(algo, epoch):
        if ((epoch < save_period and epoch % initial_save_period == 0)
                or epoch % save_period == 0 or epoch >= algo.num_epochs - 1):
            logdir = logger.get_snapshot_dir()
            filename = osp.join(
                logdir,
                'obj{obj_id}_sweep_visualization_{epoch}.png'.format(
                    obj_id=obj_to_sweep, epoch=epoch),
            )
            visualizations = []
            for goal_dict in goal_dicts:
                start_img = goal_dict['image_observation']
                new_states = goal_dict['new_states']

                encoded = encoder.encode(new_states)
                # img_format = renderer.output_image_format
                images_to_stack = [start_img]
                for i in range(encoded.shape[1]):
                    values = encoded[:, i]
                    value_image = values.reshape(env_renderer.image_chw[1:])
                    # TODO: fix hardcoding of CHW
                    value_img_rgb = np.repeat(value_image[None, :, :],
                                              3,
                                              axis=0)
                    value_img_rgb = (
                        (value_img_rgb - value_img_rgb.min()) /
                        (value_img_rgb.max() - value_img_rgb.min() + 1e-9))
                    images_to_stack.append(value_img_rgb)

                visualizations.append(
                    combine_images_into_grid(
                        images_to_stack,
                        imwidth=renderer.image_chw[2],
                        imheight=renderer.image_chw[1],
                        max_num_cols=len(start_states),
                        pad_length=1,
                        pad_color=0,
                        subpad_length=1,
                        subpad_color=128,
                        image_format=renderer.output_image_format,
                        unnormalize=True,
                    ))

            final_image = combine_images_into_grid(
                visualizations,
                imwidth=visualizations[0].shape[1],
                imheight=visualizations[0].shape[0],
                max_num_cols=3,
                image_format='HWC',
                pad_length=0,
                subpad_length=0,
            )
            cv2.imwrite(filename, final_image)

            print("Saved visualization image to to ", filename)
Example #25
def create_and_save_dict(list_of_waypoints, goals):
    dataset = {
        'list_of_waypoints': list_of_waypoints,
        'goals': goals,
    }
    from rlkit.core import logger
    logdir = logger.get_snapshot_dir()
    np.save(
        os.path.join(logdir, 'example_dataset.npy'),
        dataset
    )
    return dataset
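Since np.save pickles the dict into a zero-dimensional object array, reading it back needs allow_pickle=True plus .item() to unwrap:

import os

import numpy as np

dataset = np.load(os.path.join(logdir, 'example_dataset.npy'),
                  allow_pickle=True).item()
list_of_waypoints, goals = dataset['list_of_waypoints'], dataset['goals']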
Example #26
def get_video_save_func(rollout_function, env, policy, variant):
    from multiworld.core.image_env import ImageEnv
    from rlkit.core import logger
    from rlkit.envs.vae_wrappers import temporary_mode
    from rlkit.visualization.video import dump_video
    logdir = logger.get_snapshot_dir()
    save_period = variant.get('save_video_period', 50)
    do_state_exp = variant.get("do_state_exp", False)
    dump_video_kwargs = variant.get("dump_video_kwargs", dict())
    dump_video_kwargs['horizon'] = variant['max_path_length']

    if do_state_exp:
        imsize = variant.get('imsize')
        dump_video_kwargs['imsize'] = imsize
        image_env = ImageEnv(
            env,
            imsize,
            init_camera=variant.get('init_camera', None),
            transpose=True,
            normalize=True,
        )

        def save_video(algo, epoch):
            if epoch % save_period == 0 or epoch == algo.num_epochs:
                filename = osp.join(
                    logdir, 'video_{epoch}_env.mp4'.format(epoch=epoch))
                dump_video(image_env, policy, filename, rollout_function,
                           **dump_video_kwargs)
    else:
        image_env = env
        dump_video_kwargs['imsize'] = env.imsize

        def save_video(algo, epoch):
            if epoch % save_period == 0 or epoch == algo.num_epochs:
                filename = osp.join(
                    logdir, 'video_{epoch}_env.mp4'.format(epoch=epoch))
                temporary_mode(image_env,
                               mode='video_env',
                               func=dump_video,
                               args=(image_env, policy, filename,
                                     rollout_function),
                               kwargs=dump_video_kwargs)
                filename = osp.join(
                    logdir, 'video_{epoch}_vae.mp4'.format(epoch=epoch))
                temporary_mode(image_env,
                               mode='video_vae',
                               func=dump_video,
                               args=(image_env, policy, filename,
                                     rollout_function),
                               kwargs=dump_video_kwargs)

    return save_video
Example #27
def experiment(variant):
    from rlkit.core import logger
    import rlkit.torch.pytorch_util as ptu
    beta = variant["beta"]
    representation_size = variant["representation_size"]
    train_data, test_data, info = get_data(**variant['get_data_kwargs'])
    logger.save_extra_data(info)
    logger.get_snapshot_dir()
    beta_schedule = PiecewiseLinearSchedule(**variant['beta_schedule_kwargs'])
    m = ConvVAE(representation_size, input_channels=3)
    if ptu.gpu_enabled():
        m.to(ptu.device)
    t = ConvVAETrainer(train_data,
                       test_data,
                       m,
                       beta=beta,
                       beta_schedule=beta_schedule,
                       **variant['algo_kwargs'])
    for epoch in range(variant['num_epochs']):
        t.train_epoch(epoch)
        t.test_epoch(epoch)
        t.dump_samples(epoch)
Example #28
File: monet.py Project: mbchang/OP3
def train_vae(variant):
    #train_path = '/home/jcoreyes/objects/rlkit/examples/monet/clevr_train_10000.hdf5'
    #test_path = '/home/jcoreyes/objects/rlkit/examples/monet/clevr_test.hdf5'

    train_path = '/home/jcoreyes/objects/RailResearch/DataGeneration/ColorTwoBallSmall.h5'
    test_path = '/home/jcoreyes/objects/RailResearch/DataGeneration/ColorTwoBallSmall.h5'

    train_data = load_dataset(train_path, train=True)
    test_data = load_dataset(test_path, train=False)

    train_data = train_data.reshape((train_data.shape[0], -1))
    test_data = test_data.reshape((test_data.shape[0], -1))
    #logger.save_extra_data(info)
    logger.get_snapshot_dir()
    variant['vae_kwargs']['architecture'] = \
        monet.imsize64_monet_architecture  # or monet.imsize84_monet_architecture
    variant['vae_kwargs']['decoder_output_activation'] = identity
    variant['vae_kwargs']['decoder_class'] = BroadcastCNN

    attention_net = UNet(in_channels=4,
                         n_classes=1,
                         up_mode='upsample',
                         depth=3,
                         padding=True)
    m = MonetVAE(**variant['vae_kwargs'], attention_net=attention_net)

    m.to(ptu.device)
    t = MonetTrainer(train_data, test_data, m, **variant['algo_kwargs'])
    save_period = variant['save_period']
    for epoch in range(variant['num_epochs']):
        should_save_imgs = (epoch % save_period == 0)
        t.train_epoch(epoch)
        t.test_epoch(
            epoch,
            save_reconstruction=should_save_imgs,
        )
        if should_save_imgs:
            t.dump_samples(epoch)
    logger.save_extra_data(m, 'vae.pkl', mode='pickle')
Example #29
def train_vae(variant):
    train_path = '/home/jcoreyes/objects/rlkit/examples/monet/clevr_train.hdf5'
    test_path = '/home/jcoreyes/objects/rlkit/examples/monet/clevr_test.hdf5'

    #train_path = '/home/jcoreyes/objects/RailResearch/DataGeneration/ColorBigTwoBallSmall.h5'
    #test_path = '/home/jcoreyes/objects/RailResearch/DataGeneration/ColorBigTwoBallSmall.h5'

    train_data = load_dataset(train_path, train=True)
    test_data = load_dataset(test_path, train=False)

    train_data = train_data.reshape((train_data.shape[0], -1))[:500]
    #train_data = train_data.reshape((train_data.shape[0], -1))[0]
    #train_data = np.reshape(train_data[:2], (2, -1)).repeat(100, 0)
    test_data = test_data.reshape((test_data.shape[0], -1))[:10]
    #logger.save_extra_data(info)
    logger.get_snapshot_dir()
    variant['vae_kwargs']['architecture'] = iodine.imsize84_iodine_architecture
    variant['vae_kwargs']['decoder_class'] = BroadcastCNN

    refinement_net = RefinementNetwork(
        **iodine.imsize84_iodine_architecture['refine_args'],
        hidden_activation=nn.ELU())
    m = IodineVAE(**variant['vae_kwargs'], refinement_net=refinement_net)

    m.to(ptu.device)
    t = IodineTrainer(train_data, test_data, m, **variant['algo_kwargs'])
    save_period = variant['save_period']
    for epoch in range(variant['num_epochs']):
        should_save_imgs = (epoch % save_period == 0)
        t.train_epoch(epoch,
                      batches=train_data.shape[0] //
                      variant['algo_kwargs']['batch_size'])
        t.test_epoch(epoch,
                     save_reconstruction=should_save_imgs,
                     save_vae=False)
        if should_save_imgs:
            t.dump_samples(epoch)
    logger.save_extra_data(m, 'vae.pkl', mode='pickle')
Example #30
    def evaluate(self, epoch):
        """
        Evaluate the policy, e.g. save/print progress.
        :param epoch:
        :return:
        """
        statistics = OrderedDict()
        try:
            statistics.update(self.eval_statistics)
            self.eval_statistics = None
        except TypeError:
            # self.eval_statistics was None; nothing to merge this epoch.
            # A bare `except:` here would also swallow KeyboardInterrupt.
            print('No Stats to Eval')

        logger.log("Collecting samples for evaluation")
        test_paths = self.eval_sampler.obtain_samples()

        statistics.update(
            eval_util.get_generic_path_information(
                test_paths,
                stat_prefix="Test",
            ))
        statistics.update(
            eval_util.get_generic_path_information(
                self._exploration_paths,
                stat_prefix="Exploration",
            ))

        if hasattr(self.env, "log_diagnostics"):
            self.env.log_diagnostics(test_paths)
        if hasattr(self.env, "log_statistics"):
            statistics.update(self.env.log_statistics(test_paths))
        if epoch % self.freq_log_visuals == 0:
            if hasattr(self.env, "log_visuals"):
                self.env.log_visuals(test_paths, epoch,
                                     logger.get_snapshot_dir())

        average_returns = eval_util.get_average_returns(test_paths)
        statistics['AverageReturn'] = average_returns
        for key, value in statistics.items():
            logger.record_tabular(key, value)

        best_statistic = statistics[self.best_key]
        if best_statistic > self.best_statistic_so_far:
            self.best_statistic_so_far = best_statistic
            if self.save_best and epoch >= self.save_best_starting_from_epoch:
                data_to_save = {'epoch': epoch, 'statistics': statistics}
                data_to_save.update(self.get_epoch_snapshot(epoch))
                logger.save_extra_data(data_to_save, 'best.pkl')
                print('\n\nSAVED BEST\n\n')