def experiment(variant):
    """Train a VAE and track its loss on a uniformly sampled reacher dataset.

    Keys read from ``variant``: ``beta``, ``representation_size``,
    ``generate_vae_dataset_fn`` / ``generate_vae_dataset_kwargs``,
    ``generate_uniform_dataset_kwargs``, ``vae`` / ``vae_kwargs``,
    ``algo_kwargs``, ``save_period``, ``num_epochs``,
    ``dump_skew_debug_plots`` (only on save epochs) and
    ``train_weight_update_period`` (read every epoch).
    No beta schedule is used.
    """
    from railrl.core import logger
    import railrl.torch.pytorch_util as ptu

    beta = variant["beta"]
    representation_size = variant["representation_size"]
    make_dataset = variant['generate_vae_dataset_fn']
    train_data, test_data, info = make_dataset(
        variant['generate_vae_dataset_kwargs'])
    uniform_dataset = generate_uniform_dataset_reacher(
        **variant['generate_uniform_dataset_kwargs'])
    logger.save_extra_data(info)
    logger.get_snapshot_dir()

    model = variant['vae'](
        representation_size,
        decoder_output_activation=nn.Sigmoid(),
        **variant['vae_kwargs']
    )
    model.to(ptu.device)
    trainer = ConvVAETrainer(
        train_data, test_data, model,
        beta=beta, beta_schedule=None,
        **variant['algo_kwargs']
    )

    save_period = variant['save_period']
    for epoch in range(variant['num_epochs']):
        save_imgs = epoch % save_period == 0
        trainer.train_epoch(epoch)
        trainer.log_loss_under_uniform(model, uniform_dataset)
        trainer.test_epoch(epoch, save_reconstruction=save_imgs,
                           save_scatterplot=save_imgs)
        if save_imgs:
            trainer.dump_samples(epoch)
            if variant['dump_skew_debug_plots']:
                trainer.dump_best_reconstruction(epoch)
                trainer.dump_worst_reconstruction(epoch)
                trainer.dump_sampling_histogram(epoch)
                trainer.dump_uniform_imgs_and_reconstructions(
                    dataset=uniform_dataset, epoch=epoch)
        # Periodically refresh the trainer's training-sample weights.
        if epoch % variant['train_weight_update_period'] == 0:
            trainer.update_train_weights()
def experiment(variant):
    """Train a 3-channel ConvVAE on data from ``generate_vae_dataset``.

    Keys read from ``variant``: ``beta``, ``representation_size``,
    ``get_data_kwargs`` (splatted into ``generate_vae_dataset``), optional
    ``beta_schedule_kwargs`` (builds a ``PiecewiseLinearSchedule``), optional
    ``gpu_id``, ``algo_kwargs``, ``save_period`` and ``num_epochs``.
    """
    from railrl.core import logger
    import railrl.torch.pytorch_util as ptu
    beta = variant["beta"]
    representation_size = variant["representation_size"]
    train_data, test_data, info = generate_vae_dataset(
        **variant['get_data_kwargs'])
    logger.save_extra_data(info)
    logger.get_snapshot_dir()
    if 'beta_schedule_kwargs' in variant:
        beta_schedule = PiecewiseLinearSchedule(
            **variant['beta_schedule_kwargs'])
    else:
        beta_schedule = None
    m = ConvVAE(representation_size, input_channels=3)
    if ptu.gpu_enabled():
        # Select the requested GPU *before* moving the model. The original
        # moved the model first, so the gpu_id selection came too late to
        # influence where the model was placed.
        gpu_id = variant.get("gpu_id", None)
        if gpu_id is not None:
            ptu.set_device(gpu_id)
        m.to(ptu.device)
    t = ConvVAETrainer(train_data,
                       test_data,
                       m,
                       beta=beta,
                       beta_schedule=beta_schedule,
                       **variant['algo_kwargs'])
    save_period = variant['save_period']
    for epoch in range(variant['num_epochs']):
        should_save_imgs = (epoch % save_period == 0)
        t.train_epoch(epoch)
        t.test_epoch(epoch,
                     save_reconstruction=should_save_imgs,
                     save_scatterplot=should_save_imgs)
        if should_save_imgs:
            t.dump_samples(epoch)
def train_vae(variant, return_data=False):
    """Build a VAE from ``variant``, train it, and save it to the snapshot dir.

    The dataset function is ``variant['generate_vae_data_fctn']`` (default
    ``generate_vae_dataset``). It is called with
    ``variant['generate_vae_dataset_kwargs']``, into which the
    ``use_linear_dynamics`` flag is written in place. The trainer class
    defaults to ``ConvVAETrainer``. Params are saved every 50 epochs, and the
    final model is pickled to ``vae.pkl``.

    Returns:
        The model, or ``(model, train_dataset, test_dataset)`` when
        ``return_data`` is true.
    """
    from railrl.misc.ml_util import PiecewiseLinearSchedule
    from railrl.torch.vae.vae_trainer import ConvVAETrainer
    from railrl.core import logger

    beta = variant["beta"]
    linear_dynamics = variant.get('use_linear_dynamics', False)
    dataset_fn = variant.get('generate_vae_data_fctn', generate_vae_dataset)
    dataset_kwargs = variant['generate_vae_dataset_kwargs']
    dataset_kwargs['use_linear_dynamics'] = linear_dynamics
    train_dataset, test_dataset, info = dataset_fn(dataset_kwargs)

    # Only linear-dynamics models condition on actions.
    action_dim = (train_dataset.data['actions'].shape[2]
                  if linear_dynamics else 0)
    model = get_vae(variant, action_dim)

    logger.save_extra_data(info)
    logger.get_snapshot_dir()
    beta_schedule = None
    if 'beta_schedule_kwargs' in variant:
        beta_schedule = PiecewiseLinearSchedule(
            **variant['beta_schedule_kwargs'])

    trainer_cls = variant.get('vae_trainer_class', ConvVAETrainer)
    trainer = trainer_cls(model, beta=beta, beta_schedule=beta_schedule,
                          **variant['algo_kwargs'])
    save_period = variant['save_period']
    skew_plots = variant.get('dump_skew_debug_plots', False)

    for epoch in range(variant['num_epochs']):
        save_imgs = epoch % save_period == 0
        trainer.train_epoch(epoch, train_dataset)
        trainer.test_epoch(epoch, test_dataset)

        if save_imgs:
            trainer.dump_reconstructions(epoch)
            trainer.dump_samples(epoch)
            if skew_plots:
                trainer.dump_best_reconstruction(epoch)
                trainer.dump_worst_reconstruction(epoch)
                trainer.dump_sampling_histogram(epoch)

        for key, value in trainer.get_diagnostics().items():
            logger.record_tabular(key, value)
        logger.dump_tabular()
        trainer.end_epoch(epoch)

        if epoch % 50 == 0:
            logger.save_itr_params(epoch, model)

    logger.save_extra_data(model, 'vae.pkl', mode='pickle')
    return (model, train_dataset, test_dataset) if return_data else model
Esempio n. 4
0
def dump_latent_plots(vae_env, epoch):
    """Save a grid image of every latent dimension over a 2-D state sweep.

    For each latent dimension, three heat-map panels are rendered: the mean
    (colour range [-2, 2]), the std = exp(0.5 * logvar) (range [0, 2]) and a
    noisy reparameterized sample (range [-3, 3]). The grid is written to
    ``z_<epoch>.png`` (or ``z.png`` when ``epoch`` is None) in the snapshot
    dir. Does nothing if ``vae_env`` has no ``get_states_sweep``.
    """
    from railrl.core import logger
    import os.path as osp
    from torchvision.utils import save_image

    if getattr(vae_env, "get_states_sweep", None) is None:
        return

    nx, ny = vae_env.vis_granularity, vae_env.vis_granularity
    sweep = vae_env.get_states_sweep(nx, ny)
    mu, logvar = vae_env.encode_states(sweep, clip_std=False)
    std = np.exp(0.5*logvar)
    noisy = vae_env.reparameterize(mu, logvar, noisy=True)

    imsize = 84
    # (latent values, colour-scale min, colour-scale max, collected panels)
    panels = (
        (mu, -2.0, 2.0, []),
        (std, 0.0, 2.0, []),
        (noisy, -3.0, 3.0, []),
    )
    for dim in range(mu.shape[1]):
        for values, lo, hi, collected in panels:
            plot = vae_env.get_image_plt(
                values[:, dim].reshape((nx, ny)),
                vmin=lo, vmax=hi,
                draw_state=False, imsize=imsize)
            collected.append(vae_env.transform_image(plot))

    # All mean panels first, then all std panels, then all sample panels.
    images = np.array([img for _, _, _, collected in panels
                       for img in collected])

    nrow = min(vae_env.representation_size, 16)
    name = 'z.png' if epoch is None else 'z_%d.png' % epoch
    save_path = osp.join(logger.get_snapshot_dir(), name)
    grid = images.reshape(
        (vae_env.representation_size*3, -1, imsize, imsize))
    save_image(ptu.FloatTensor(ptu.from_numpy(grid)), save_path, nrow=nrow)
Esempio n. 5
0
def save_paths(algo, epoch):
    """Pickle this epoch's exploration and evaluation paths to the snapshot dir.

    Exploration paths go to ``video_<epoch>_vae.p`` and evaluation paths to
    ``video_<epoch>_env.p``. Each saved filename is printed.
    """
    def dump(paths, suffix):
        filename = osp.join(
            logger.get_snapshot_dir(),
            'video_{epoch}_{suffix}.p'.format(epoch=epoch, suffix=suffix))
        # Use a context manager so the handle is closed (and flushed)
        # deterministically; the original leaked the open file object.
        with open(filename, "wb") as f:
            pickle.dump(paths, f)
        print("saved", filename)

    dump(algo.expl_data_collector.get_epoch_paths(), 'vae')
    dump(algo.eval_data_collector.get_epoch_paths(), 'env')
def experiment(variant):
    """Train a VAE while tracking its loss on a stored uniform goal dataset.

    The uniform dataset is loaded from ``variant['uniform_dataset_path']``,
    and its ``'image_desired_goal'`` entry is passed through
    ``unormalize_image``. When ``beta_schedule_kwargs`` is present, its last
    ``y_values`` entry is overwritten in place with ``variant['beta']``
    before the schedule is built. Training weights are refreshed every epoch.
    """
    from railrl.core import logger
    import railrl.torch.pytorch_util as ptu

    beta = variant["beta"]
    representation_size = variant["representation_size"]
    train_data, test_data, info = variant['generate_vae_dataset_fn'](
        variant['generate_vae_dataset_kwargs'])
    stored = load_local_or_remote_file(
        variant['uniform_dataset_path']).item()
    uniform_dataset = unormalize_image(stored['image_desired_goal'])
    logger.save_extra_data(info)
    logger.get_snapshot_dir()

    beta_schedule = None
    if 'beta_schedule_kwargs' in variant:
        schedule_kwargs = variant['beta_schedule_kwargs']
        # Make the schedule end at the configured beta.
        schedule_kwargs['y_values'][-1] = variant['beta']
        beta_schedule = PiecewiseLinearSchedule(**schedule_kwargs)

    model = variant['vae'](
        representation_size,
        decoder_output_activation=nn.Sigmoid(),
        **variant['vae_kwargs']
    )
    model.to(ptu.device)
    trainer = ConvVAETrainer(
        train_data, test_data, model,
        beta=beta, beta_schedule=beta_schedule,
        **variant['algo_kwargs']
    )

    save_period = variant['save_period']
    for epoch in range(variant['num_epochs']):
        save_imgs = epoch % save_period == 0
        trainer.train_epoch(epoch)
        trainer.log_loss_under_uniform(
            model, uniform_dataset,
            variant['algo_kwargs']['priority_function_kwargs'])
        trainer.test_epoch(epoch, save_reconstruction=save_imgs,
                           save_scatterplot=save_imgs)
        if save_imgs:
            trainer.dump_samples(epoch)
            if variant['dump_skew_debug_plots']:
                trainer.dump_best_reconstruction(epoch)
                trainer.dump_worst_reconstruction(epoch)
                trainer.dump_sampling_histogram(epoch)
                trainer.dump_uniform_imgs_and_reconstructions(
                    dataset=uniform_dataset, epoch=epoch)
        trainer.update_train_weights()
def experiment(variant):
    """Train a ConvVAE on pre-generated Sawyer torque-control image shards.

    Images are read from ``num_divisions`` ``.npy`` shards named
    ``<image_path_prefix><k>.npy`` for k = 1..num_divisions. Each shard is
    expected to hold ``images_per_division`` rows of 21168 values. The
    concatenated data is split 90/10 into train/test.

    Optional ``variant`` keys (defaults reproduce the original hard-coded
    values): ``num_divisions`` (5, same key as the supervised CNN
    launcher), ``images_per_division`` (10000), ``image_path_prefix``.

    When ``beta_schedule_kwargs`` is present, it is modified in place from
    ``beta``, ``flat_x`` and ``ramp_x`` before the schedule is built.
    """
    from railrl.core import logger
    import railrl.torch.pytorch_util as ptu
    beta = variant["beta"]
    representation_size = variant["representation_size"]
    num_divisions = variant.get('num_divisions', 5)
    per_division = variant.get('images_per_division', 10000)
    path_prefix = variant.get(
        'image_path_prefix',
        '/home/murtaza/vae_data/sawyer_torque_control_images100000_')
    images = np.zeros((num_divisions * per_division, 21168))
    for i in range(num_divisions):
        start = i * per_division
        images[start:start + per_division] = np.load(
            path_prefix + str(i + 1) + '.npy')
        print(i)
    mid = int(num_divisions * per_division * .9)
    train_data, test_data = images[:mid], images[mid:]
    info = dict()

    logger.save_extra_data(info)
    logger.get_snapshot_dir()
    if 'beta_schedule_kwargs' in variant:
        kwargs = variant['beta_schedule_kwargs']
        kwargs['y_values'][2] = variant['beta']
        kwargs['x_values'][1] = variant['flat_x']
        kwargs['x_values'][2] = variant['ramp_x'] + variant['flat_x']
        beta_schedule = PiecewiseLinearSchedule(
            **variant['beta_schedule_kwargs'])
    else:
        beta_schedule = None
    m = ConvVAE(representation_size,
                input_channels=3,
                **variant['conv_vae_kwargs'])
    if ptu.gpu_enabled():
        m.cuda()
    t = ConvVAETrainer(train_data,
                       test_data,
                       m,
                       beta=beta,
                       beta_schedule=beta_schedule,
                       **variant['algo_kwargs'])
    save_period = variant['save_period']
    for epoch in range(variant['num_epochs']):
        should_save_imgs = (epoch % save_period == 0)
        t.train_epoch(epoch)
        t.test_epoch(epoch,
                     save_reconstruction=should_save_imgs,
                     save_scatterplot=should_save_imgs)
        if should_save_imgs:
            t.dump_samples(epoch)
Esempio n. 8
0
def dump_reconstructions(vae_env, epoch, n_recon=16):
    """Save ``n_recon`` dataset images next to their VAE reconstructions.

    Samples ``n_recon`` indices (with replacement) from the ``next_obs`` of
    the VAE dataset at ``vae_env.vae_dataset_path``. In state-input mode,
    states and reconstructed states are rendered via
    ``wrapped_env.states_to_images``. Otherwise the stored images are
    normalized and reconstructed directly. The comparison grid is written to
    ``r_<epoch>.png`` (or ``r.png``) in the snapshot dir. Returns early when
    no VAE dataset is configured.
    """
    from railrl.core import logger
    import os.path as osp
    from torchvision.utils import save_image

    if not (vae_env.use_vae_dataset and vae_env.vae_dataset_path is not None):
        return

    from multiworld.core.image_env import normalize_image
    from railrl.misc.asset_loader import local_path_from_s3_or_local_path
    filename = local_path_from_s3_or_local_path(vae_env.vae_dataset_path)
    # The dataset is a pickled dict stored as a 0-d object array. NumPy >=
    # 1.16.3 refuses to unpickle unless allow_pickle=True. NOTE: this
    # unpickles the file, so only point vae_dataset_path at trusted data.
    dataset = np.load(filename, allow_pickle=True).item()
    sampled_idx = np.random.choice(dataset['next_obs'].shape[0], n_recon)
    if vae_env.vae_input_key_prefix == 'state':
        states = dataset['next_obs'][sampled_idx]
        imgs = ptu.np_to_var(
            vae_env.wrapped_env.states_to_images(states)
        )
        # NOTE(review): the VAE is unpacked as a 3-tuple here but a 4-tuple
        # in the image branch; presumably the two model types differ.
        recon_samples, _, _ = vae_env.vae(ptu.np_to_var(states))
        recon_imgs = ptu.np_to_var(
            vae_env.wrapped_env.states_to_images(ptu.get_numpy(recon_samples))
        )
    else:
        imgs = ptu.np_to_var(
            normalize_image(dataset['next_obs'][sampled_idx])
        )
        recon_imgs, _, _, _ = vae_env.vae(imgs)
    del dataset

    channels = vae_env.wrapped_env.channels
    imsize = vae_env.wrapped_env.imsize
    comparison = torch.cat([
        imgs.narrow(start=0, length=vae_env.wrapped_env.image_length,
                    dimension=1).contiguous().view(-1, channels, imsize, imsize),
        recon_imgs.contiguous().view(n_recon, channels, imsize, imsize),
    ])

    if epoch is not None:
        save_dir = osp.join(logger.get_snapshot_dir(), 'r_%d.png' % epoch)
    else:
        save_dir = osp.join(logger.get_snapshot_dir(), 'r.png')
    save_image(comparison.data.cpu(), save_dir, nrow=n_recon)
def experiment(variant):
    """Train a SpatialAutoEncoder with ``variant['feat_points']`` feature points.

    The model uses ``2 * feat_points`` latent units and 3 input channels,
    and is trained with ``ConvVAETrainer`` (using ``lr`` and ``beta`` from
    ``variant``) for ``num_epochs`` epochs. Samples are dumped every epoch.
    """
    feat_points = variant['feat_points']
    from railrl.core import logger
    beta = variant["beta"]

    print('collecting data')
    train_data, test_data, info = get_data(**variant['get_data_kwargs'])
    print('finish collecting data')
    logger.save_extra_data(info)
    logger.get_snapshot_dir()

    model = SpatialAutoEncoder(2 * feat_points, feat_points,
                               input_channels=3)
    trainer = ConvVAETrainer(train_data, test_data, model,
                             lr=variant['lr'], beta=beta)
    for epoch in range(variant['num_epochs']):
        trainer.train_epoch(epoch)
        trainer.test_epoch(epoch)
        trainer.dump_samples(epoch)
Esempio n. 10
0
 def __init__(self, variant):
     """Set up the directory where replay buffers are periodically dumped.

     Args:
         variant: experiment config. The optional ``"dump_buffer_kwargs"``
             dict may hold ``"dump_buffer_period"`` (default 50). That key is
             moved into ``self.save_period``, and the remaining kwargs are
             kept in ``self.dump_buffer_kwargs``.
     """
     self.logdir = logger.get_snapshot_dir()
     # Copy before popping: popping directly from variant's dict mutated the
     # caller's config, so a second construction lost the configured period.
     self.dump_buffer_kwargs = dict(variant.get("dump_buffer_kwargs", dict()))
     self.save_period = self.dump_buffer_kwargs.pop('dump_buffer_period',
                                                    50)
     self.buffer_dir = osp.join(self.logdir, 'buffers')
     # exist_ok avoids the check-then-create race of exists() + makedirs().
     os.makedirs(self.buffer_dir, exist_ok=True)
Esempio n. 11
0
def run_task(variant):
    """Smoke-test the logger by writing messages, one tabular row and the snapshot dir."""
    from railrl.core import logger
    print(variant)
    logger.log("Hello from script")
    logger.log("variant: " + str(variant))
    logger.record_tabular("value", 1)
    logger.dump_tabular()
    # logger.log takes the whole message as its first argument (assuming the
    # rllab-style signature log(s, with_prefix=True, with_timestamp=True)).
    # Passing the path as a second positional argument bound it to a flag,
    # so the directory was never written. Build a single message instead.
    logger.log("snapshot_dir: " + str(logger.get_snapshot_dir()))
Esempio n. 12
0
def experiment(variant):
    """Train a VAE on joblib-stored door observations with a uniform eval set.

    ``variant['file']`` is a joblib dump whose first ``int(data['size'])``
    rows of ``data['obs']`` are split 90/10 into train/test. A uniform door
    dataset, built from ``generate_uniform_dataset_kwargs``, is used for
    ``log_loss_under_uniform``. When ``beta_schedule_kwargs`` is present, its
    last ``y_values`` entry is overwritten in place with ``variant['beta']``.
    """
    from railrl.core import logger
    import railrl.torch.pytorch_util as ptu

    beta = variant["beta"]
    representation_size = variant["representation_size"]
    data = joblib.load(variant['file'])
    observations = data['obs']
    size = int(data['size'])
    dataset = observations[:size, :]
    split = int(size * .9)
    train_data, test_data = dataset[:split, :], dataset[split:, :]
    logger.get_snapshot_dir()
    print('SIZE: ', size)
    uniform_dataset = generate_uniform_dataset_door(
        **variant['generate_uniform_dataset_kwargs'])
    logger.get_snapshot_dir()

    beta_schedule = None
    if 'beta_schedule_kwargs' in variant:
        schedule_kwargs = variant['beta_schedule_kwargs']
        # Make the schedule end at the configured beta.
        schedule_kwargs['y_values'][-1] = variant['beta']
        beta_schedule = PiecewiseLinearSchedule(**schedule_kwargs)

    model = variant['vae'](
        representation_size,
        decoder_output_activation=nn.Sigmoid(),
        **variant['vae_kwargs']
    )
    model.to(ptu.device)
    trainer = ConvVAETrainer(
        train_data, test_data, model,
        beta=beta, beta_schedule=beta_schedule,
        **variant['algo_kwargs']
    )

    save_period = variant['save_period']
    for epoch in range(variant['num_epochs']):
        save_imgs = epoch % save_period == 0
        trainer.train_epoch(epoch)
        trainer.log_loss_under_uniform(uniform_dataset)
        trainer.test_epoch(epoch, save_reconstruction=save_imgs,
                           save_scatterplot=save_imgs)
        if save_imgs:
            trainer.dump_samples(epoch)
            if variant['dump_skew_debug_plots']:
                trainer.dump_best_reconstruction(epoch)
                trainer.dump_worst_reconstruction(epoch)
                trainer.dump_sampling_histogram(epoch)
                trainer.dump_uniform_imgs_and_reconstructions(
                    dataset=uniform_dataset, epoch=epoch)
        trainer.update_train_weights()
Esempio n. 13
0
    def pretrain_policy_with_bc(self):
        """Pretrain ``self.policy`` by behaviour cloning on the demo buffers.

        Runs ``self.bc_num_pretrain_steps`` optimizer steps on the
        ``bc_weight``-scaled BC loss from ``demo_train_buffer``, also
        evaluating on ``demo_test_buffer``. Tabular output goes to
        ``pretrain_policy.csv`` while pretraining runs and is switched back to
        ``progress.csv`` afterwards. Every ``pretraining_logging_period``
        steps, losses are logged and the policy is pickled to
        ``<snapshot_dir>/bc.pkl``. If ``post_bc_pretrain_hyperparams`` is set,
        it is applied via ``set_algorithm_weights`` at the end.
        """
        logger.remove_tabular_output(
            'progress.csv', relative_to_snapshot_dir=True
        )
        logger.add_tabular_output(
            'pretrain_policy.csv', relative_to_snapshot_dir=True
        )
        try:
            if self.do_pretrain_rollouts:
                total_ret = self.do_rollouts()
                # NOTE(review): the /20 presumably averages over 20 rollouts
                # performed by do_rollouts -- confirm.
                print("INITIAL RETURN", total_ret/20)

            prev_time = time.time()
            for i in range(self.bc_num_pretrain_steps):
                train_policy_loss, train_logp_loss, train_mse_loss, train_log_std = self.run_bc_batch(self.demo_train_buffer, self.policy)
                train_policy_loss = train_policy_loss * self.bc_weight

                self.policy_optimizer.zero_grad()
                train_policy_loss.backward()
                self.policy_optimizer.step()

                test_policy_loss, test_logp_loss, test_mse_loss, test_log_std = self.run_bc_batch(self.demo_test_buffer, self.policy)
                test_policy_loss = test_policy_loss * self.bc_weight

                if self.do_pretrain_rollouts and i % self.pretraining_env_logging_period == 0:
                    total_ret = self.do_rollouts()
                    print("Return at step {} : {}".format(i, total_ret/20))

                if i % self.pretraining_logging_period == 0:
                    stats = {
                        "pretrain_bc/batch": i,
                        "pretrain_bc/Train Logprob Loss": ptu.get_numpy(train_logp_loss),
                        "pretrain_bc/Test Logprob Loss": ptu.get_numpy(test_logp_loss),
                        "pretrain_bc/Train MSE": ptu.get_numpy(train_mse_loss),
                        "pretrain_bc/Test MSE": ptu.get_numpy(test_mse_loss),
                        "pretrain_bc/train_policy_loss": ptu.get_numpy(train_policy_loss),
                        "pretrain_bc/test_policy_loss": ptu.get_numpy(test_policy_loss),
                        "pretrain_bc/epoch_time": time.time() - prev_time,
                    }

                    if self.do_pretrain_rollouts:
                        # Most recent rollout return (may be from an earlier step).
                        stats["pretrain_bc/avg_return"] = total_ret / 20

                    logger.record_dict(stats)
                    logger.dump_tabular(with_prefix=True, with_timestamp=False)
                    # Close the checkpoint file deterministically; the original
                    # leaked one open handle per logging step.
                    with open(logger.get_snapshot_dir() + '/bc.pkl', "wb") as f:
                        pickle.dump(self.policy, f)
                    prev_time = time.time()
        finally:
            # Always switch tabular output back, even if pretraining raises.
            logger.remove_tabular_output(
                'pretrain_policy.csv',
                relative_to_snapshot_dir=True,
            )
            logger.add_tabular_output(
                'progress.csv',
                relative_to_snapshot_dir=True,
            )

        if self.post_bc_pretrain_hyperparams:
            self.set_algorithm_weights(**self.post_bc_pretrain_hyperparams)
def experiment(variant):
    """Train a ConvVAE with its state-similarity debug head on Sawyer data.

    Images and matching states are loaded from fixed ``.npy`` files. Both are
    loaded together, so ``generate_vae_dataset`` is not used. State columns
    7-13 are dropped, and the data is split randomly 90/10 into train/test.
    """
    from railrl.core import logger
    import railrl.torch.pytorch_util as ptu

    beta = variant["beta"]
    representation_size = variant["representation_size"]
    images = np.load(
        '/home/murtaza/vae_data/sawyer_torque_control_ou_imgs_zoomed_out10000.npy'
    )
    raw_states = np.load(
        '/home/murtaza/vae_data/sawyer_torque_control_ou_states_zoomed_out10000.npy'
    )
    states = np.concatenate((raw_states[:, :7], raw_states[:, 14:]), axis=1)
    X_train, X_test, Y_train, Y_test = train_test_split(images, states,
                                                        test_size=.1)
    logger.save_extra_data(dict())
    logger.get_snapshot_dir()

    beta_schedule = None
    if 'beta_schedule_kwargs' in variant:
        beta_schedule = PiecewiseLinearSchedule(
            **variant['beta_schedule_kwargs'])

    model = ConvVAE(
        representation_size,
        input_channels=3,
        state_sim_debug=True,
        state_size=states.shape[1],
        **variant['conv_vae_kwargs']
    )
    if ptu.gpu_enabled():
        model.cuda()
    trainer = ConvVAETrainer(
        (X_train, Y_train), (X_test, Y_test), model,
        beta=beta, beta_schedule=beta_schedule, state_sim_debug=True,
        **variant['algo_kwargs']
    )

    save_period = variant['save_period']
    for epoch in range(variant['num_epochs']):
        save_imgs = epoch % save_period == 0
        trainer.train_epoch(epoch)
        trainer.test_epoch(epoch, save_reconstruction=save_imgs,
                           save_scatterplot=save_imgs)
        if save_imgs:
            trainer.dump_samples(epoch)
Esempio n. 15
0
def plot_encoder_function(variant, encoder, tag=""):
    """Return an ``(algo, epoch, is_x=False)`` callback that animates ``encoder``.

    On epochs divisible by ``variant['save_video_period']`` (default 50), or
    on the final epoch, the callback encodes two axis-aligned sweeps of 2-D
    points: x over [-4, 4) with y = 0, and y over [-4, 4) with x = 0. It then
    saves an animated GIF of the first two encoder outputs (x-sweep red,
    y-sweep blue) to ``encoder_<tag>_<x|y>_<epoch>_env.gif`` in the snapshot
    dir. ``is_x`` only selects the filename label; both sweeps are always
    drawn.
    """
    import matplotlib.pyplot as plt
    from matplotlib.animation import FuncAnimation
    from railrl.core import logger
    logdir = logger.get_snapshot_dir()

    def plot_encoder(algo, epoch, is_x=False):
        save_period = variant.get('save_video_period', 50)
        if not (epoch % save_period == 0 or epoch == algo.num_epochs):
            return
        filename = osp.join(
            logdir,
            'encoder_{}_{}_{epoch}_env.gif'.format(tag,
                                                   "x" if is_x else "y",
                                                   epoch=epoch))

        vary = np.arange(-4, 4, .1)
        static = np.zeros(len(vary))

        points_x = np.c_[vary.reshape(-1, 1), static.reshape(-1, 1)]
        points_y = np.c_[static.reshape(-1, 1), vary.reshape(-1, 1)]

        encoded_points_x = ptu.get_numpy(
            encoder.forward(ptu.from_numpy(points_x)))
        encoded_points_y = ptu.get_numpy(
            encoder.forward(ptu.from_numpy(points_y)))

        # A dedicated figure is closed afterwards. The original left every
        # figure open in pyplot's registry, and its plt.clf() cleared (or
        # created) whatever unrelated figure happened to be current.
        fig = plt.figure()
        try:
            plt.xlim(
                min(min(encoded_points_x[:, 0]), min(encoded_points_y[:, 0])),
                max(max(encoded_points_x[:, 0]), max(encoded_points_y[:, 0])))
            plt.ylim(
                min(min(encoded_points_x[:, 1]), min(encoded_points_y[:, 1])),
                max(max(encoded_points_x[:, 1]), max(encoded_points_y[:, 1])))
            colors = ["red", "blue"]
            lines = [
                plt.plot([], [], 'o', color=colors[i], alpha=0.4)[0]
                for i in range(2)
            ]

            def animate(i):
                # Reveal the first i + 1 encoded points of each sweep.
                lines[0].set_data(encoded_points_x[:i + 1, 0],
                                  encoded_points_x[:i + 1, 1])
                lines[1].set_data(encoded_points_y[:i + 1, 0],
                                  encoded_points_y[:i + 1, 1])
                return lines

            ani = FuncAnimation(fig, animate, frames=len(vary), interval=40)
            ani.save(filename, writer='imagemagick', fps=60)
        finally:
            plt.close(fig)

    return plot_encoder
Esempio n. 16
0
def experiment(variant):
    """Supervised CNN regression from Sawyer images to joint angles.

    Loads ``num_divisions`` shards of 10000 images and states each. The
    first 7 state columns are taken modulo 2*pi. If ``variant['normalize']``
    is set, the states are standardized per column (mean and std are
    printed). The data is split 90/10 into train/test and trained with
    ``SupervisedAlgorithm`` on the GPU.
    """
    from railrl.core import logger
    import railrl.torch.pytorch_util as ptu

    ptu.set_gpu_mode(True)
    logger.save_extra_data(dict())
    logger.get_snapshot_dir()
    net = CNN(**variant['cnn_kwargs'])
    net.cuda()

    num_divisions = variant['num_divisions']
    chunk = 10000
    images = np.zeros((num_divisions * chunk, 21168))
    states = np.zeros((num_divisions * chunk, 7))
    for shard in range(num_divisions):
        suffix = str(shard + 1) + '.npy'
        shard_images = np.load(
            '/home/murtaza/vae_data/sawyer_torque_control_images100000_' +
            suffix)
        shard_states = np.load(
            '/home/murtaza/vae_data/sawyer_torque_control_states100000_' +
            suffix)[:, :7] % (2 * np.pi)
        rows = slice(shard * chunk, (shard + 1) * chunk)
        images[rows] = shard_images
        states[rows] = shard_states
        print(shard)

    if variant['normalize']:
        std = np.std(states, axis=0)
        mu = np.mean(states, axis=0)
        states = np.divide((states - mu), std)
        print(mu, std)

    split = int(num_divisions * chunk * .9)
    algo = SupervisedAlgorithm(
        images[:split], images[split:],
        states[:split], states[split:],
        net,
        batch_size=variant['batch_size'],
        lr=variant['lr'],
        weight_decay=variant['weight_decay'],
    )
    for epoch in range(variant['num_epochs']):
        algo.train_epoch(epoch)
        algo.test_epoch(epoch)
Esempio n. 17
0
def dump_samples(vae_env, epoch, n_samples=64):
    """Decode ``n_samples`` standard-normal latents and save them as an image grid.

    Puts ``vae_env.vae`` into eval mode; it is not switched back. In
    state-input mode, the decoded states are rendered with
    ``wrapped_env.states_to_images``. If that returns None, nothing is
    saved. The grid is written to ``s_<epoch>.png`` (or ``s.png``) in the
    snapshot dir with ``int(sqrt(n_samples))`` images per row.
    """
    from railrl.core import logger
    from torchvision.utils import save_image
    import os.path as osp
    vae_env.vae.eval()
    latents = ptu.Variable(torch.randn(n_samples, vae_env.representation_size))
    sample = vae_env.vae.decode(latents).cpu()
    if vae_env.vae_input_key_prefix == 'state':
        # Check the raw result *before* wrapping it. The original called
        # np_to_var first, which fails on None before the check could run.
        images = vae_env.wrapped_env.states_to_images(ptu.get_numpy(sample))
        if images is None:
            return
        sample = ptu.np_to_var(images)
    if epoch is not None:
        save_dir = osp.join(logger.get_snapshot_dir(), 's_%d.png' % epoch)
    else:
        save_dir = osp.join(logger.get_snapshot_dir(), 's.png')
    save_image(
        sample.data.view(n_samples, -1, vae_env.wrapped_env.imsize, vae_env.wrapped_env.imsize),
        save_dir,
        nrow=int(np.sqrt(n_samples))
    )
def experiment(variant):
    """Train a VAE built by ``variant['vae']`` on a generated dataset.

    When ``beta_schedule_kwargs`` is present, its last ``y_values`` entry is
    overwritten in place with ``variant['beta']`` before the schedule is
    built. Skew debug plots are dumped on save epochs if requested, and
    training weights are refreshed every epoch.
    """
    from railrl.core import logger
    import railrl.torch.pytorch_util as ptu

    beta = variant["beta"]
    representation_size = variant["representation_size"]
    train_data, test_data, info = variant['generate_vae_dataset_fn'](
        variant['generate_vae_dataset_kwargs'])
    logger.save_extra_data(info)
    logger.get_snapshot_dir()

    beta_schedule = None
    if 'beta_schedule_kwargs' in variant:
        schedule_kwargs = variant['beta_schedule_kwargs']
        # Make the schedule end at the configured beta.
        schedule_kwargs['y_values'][-1] = variant['beta']
        beta_schedule = PiecewiseLinearSchedule(**schedule_kwargs)

    model = variant['vae'](representation_size, **variant['vae_kwargs'])
    model.to(ptu.device)
    trainer = ConvVAETrainer(
        train_data, test_data, model,
        beta=beta, beta_schedule=beta_schedule,
        **variant['algo_kwargs']
    )

    save_period = variant['save_period']
    for epoch in range(variant['num_epochs']):
        save_imgs = epoch % save_period == 0
        trainer.train_epoch(epoch)
        trainer.test_epoch(epoch, save_reconstruction=save_imgs,
                           save_scatterplot=save_imgs)
        if save_imgs:
            trainer.dump_samples(epoch)
            if variant['dump_skew_debug_plots']:
                trainer.dump_best_reconstruction(epoch)
                trainer.dump_worst_reconstruction(epoch)
                trainer.dump_sampling_histogram(epoch)
        trainer.update_train_weights()
def get_video_save_func(rollout_function, env, policy, variant):
    """Build an ``(algo, epoch)`` callback that periodically saves rollout videos.

    Videos are saved on epochs divisible by ``variant['save_video_period']``
    (default 50) and on ``algo.num_epochs``.

    - State experiments (``do_state_exp``): ``env`` is wrapped in an
      ``ImageEnv`` of size ``variant['imsize']``, and one
      ``video_<epoch>_env.mp4`` is written.
    - Otherwise: ``env`` is used directly with its own ``imsize``, and
      ``video_<epoch>_env.mp4`` and ``video_<epoch>_vae.mp4`` are written
      under the ``video_env`` and ``video_vae`` modes respectively.

    ``variant['dump_video_kwargs']`` is copied, not modified.
    """
    from multiworld.core.image_env import ImageEnv
    from railrl.core import logger
    from railrl.envs.vae_wrappers import temporary_mode
    from railrl.visualization.video import dump_video
    logdir = logger.get_snapshot_dir()
    save_period = variant.get('save_video_period', 50)
    do_state_exp = variant.get("do_state_exp", False)
    # Copy so writing horizon/imsize below does not mutate the caller's
    # variant (the original wrote these keys into variant's own dict).
    dump_video_kwargs = dict(variant.get("dump_video_kwargs", dict()))
    dump_video_kwargs['horizon'] = variant['max_path_length']

    def should_save(algo, epoch):
        return epoch % save_period == 0 or epoch == algo.num_epochs

    def video_path(epoch, suffix):
        return osp.join(
            logdir, 'video_{epoch}_{suffix}.mp4'.format(epoch=epoch,
                                                         suffix=suffix))

    if do_state_exp:
        imsize = variant.get('imsize')
        dump_video_kwargs['imsize'] = imsize
        image_env = ImageEnv(
            env,
            imsize,
            init_camera=variant.get('init_camera', None),
            transpose=True,
            normalize=True,
        )

        def save_video(algo, epoch):
            if should_save(algo, epoch):
                dump_video(image_env, policy, video_path(epoch, 'env'),
                           rollout_function, **dump_video_kwargs)
    else:
        image_env = env
        dump_video_kwargs['imsize'] = env.imsize

        def save_video(algo, epoch):
            if not should_save(algo, epoch):
                return
            for mode, suffix in (('video_env', 'env'), ('video_vae', 'vae')):
                temporary_mode(image_env,
                               mode=mode,
                               func=dump_video,
                               args=(image_env, policy,
                                     video_path(epoch, suffix),
                                     rollout_function),
                               kwargs=dump_video_kwargs)

    return save_video
Esempio n. 20
0
    def visualization_post_processing(save_vis, save_video, epoch):
        """Dump VAE visualisations and, optionally, rollout videos for ``epoch``.

        Closure: reads ``vis_variant``, ``vae_env``, ``variant``,
        ``random_policy``, ``rollout_function`` and ``osp`` from the enclosing
        scope (not visible here -- NOTE(review): confirm their definitions).

        Args:
            save_vis: when false, nothing is dumped at all (videos also
                require it).
            save_video: when true (together with ``save_vis``), render videos.
                It also suppresses the latent-histogram dumps below.
            epoch: epoch number used in output names and in
                ``variant['dump_video_kwargs']['epoch']``.
        """
        vis_list = vis_variant.get('vis_list', [])

        if save_vis:
            # Reconstructions are only dumped for state-input VAEs.
            if vae_env.vae_input_key_prefix == 'state':
                vae_env.dump_reconstructions(epoch,
                                             n_recon=vis_variant.get(
                                                 'n_recon', 16))
            vae_env.dump_samples(epoch,
                                 n_samples=vis_variant.get('n_samples', 64))
            if 'latent_representation' in vis_list:
                vae_env.dump_latent_plots(epoch)
            # Any histogram-type visualisation first needs the histogram computed.
            if any(elem in vis_list for elem in [
                    'latent_histogram', 'latent_histogram_mu',
                    'latent_histogram_2d', 'latent_histogram_mu_2d'
            ]):
                vae_env.compute_latent_histogram()
            if not save_video and ('latent_histogram' in vis_list):
                vae_env.dump_latent_histogram(epoch=epoch,
                                              noisy=True,
                                              use_true_prior=True)
            if not save_video and ('latent_histogram_mu' in vis_list):
                vae_env.dump_latent_histogram(epoch=epoch,
                                              noisy=False,
                                              use_true_prior=True)

        if save_video and save_vis:
            from railrl.envs.vae_wrappers import temporary_mode
            from railrl.misc.video_gen import dump_video
            from railrl.core import logger

            vae_env.compute_goal_encodings()

            logdir = logger.get_snapshot_dir()
            filename = osp.join(logdir,
                                'video_{epoch}.mp4'.format(epoch=epoch))
            # NOTE(review): this writes into the shared variant dict in place.
            variant['dump_video_kwargs']['epoch'] = epoch
            temporary_mode(vae_env,
                           mode='video_env',
                           func=dump_video,
                           args=(vae_env, random_policy, filename,
                                 rollout_function),
                           kwargs=variant['dump_video_kwargs'])
            # A second, VAE-mode video is only written when
            # save_video_env_only is explicitly false (it defaults to True).
            if not vis_variant.get('save_video_env_only', True):
                filename = osp.join(
                    logdir, 'video_{epoch}_vae.mp4'.format(epoch=epoch))
                temporary_mode(vae_env,
                               mode='video_vae',
                               func=dump_video,
                               args=(vae_env, random_policy, filename,
                                     rollout_function),
                               kwargs=variant['dump_video_kwargs'])
def experiment(variant):
    """Train a 3-channel ConvVAE whose KL weight follows a piecewise-linear schedule.

    Keys read from ``variant``: ``beta``, ``representation_size``,
    ``get_data_kwargs`` (forwarded to ``get_data``), ``beta_schedule_kwargs``
    (forwarded to ``PiecewiseLinearSchedule``), ``algo_kwargs`` (forwarded to
    ``ConvVAETrainer``) and ``num_epochs``. Every epoch runs a training pass,
    a test pass and dumps samples through the trainer.
    """
    from railrl.core import logger
    import railrl.torch.pytorch_util as ptu

    beta, representation_size = variant["beta"], variant["representation_size"]
    data_kwargs = variant['get_data_kwargs']
    train_data, test_data, info = get_data(**data_kwargs)
    logger.save_extra_data(info)
    logger.get_snapshot_dir()  # return value is discarded here
    schedule = PiecewiseLinearSchedule(**variant['beta_schedule_kwargs'])

    vae = ConvVAE(representation_size, input_channels=3)
    if ptu.gpu_enabled():
        vae.to(ptu.device)

    trainer = ConvVAETrainer(
        train_data, test_data, vae,
        beta=beta, beta_schedule=schedule,
        **variant['algo_kwargs']
    )

    epochs = variant['num_epochs']
    for epoch in range(epochs):
        trainer.train_epoch(epoch)
        trainer.test_epoch(epoch)
        trainer.dump_samples(epoch)
def experiment(variant):
    """Train a 3-channel ConvVAE, optionally annealing the KL weight ``beta``.

    Keys read from ``variant``: ``beta``, ``representation_size``,
    ``generate_vae_dataset_kwargs``, ``conv_vae_kwargs``, ``algo_kwargs``,
    ``save_period``, ``num_epochs`` and, optionally, ``beta_schedule_kwargs``.

    When ``beta_schedule_kwargs`` is present, a ``PiecewiseLinearSchedule`` is
    built from a copy of it whose last ``y_values`` entry is replaced by
    ``variant['beta']``. The caller's variant is not modified. Otherwise no
    schedule is used. Every ``save_period`` epochs, the trainer's
    ``test_epoch`` is asked to save reconstructions and scatterplots, and
    samples are dumped.
    """
    from railrl.core import logger
    import railrl.torch.pytorch_util as ptu
    beta = variant["beta"]
    representation_size = variant["representation_size"]
    train_data, test_data, info = generate_vae_dataset(
        **variant['generate_vae_dataset_kwargs'])
    logger.save_extra_data(info)
    logger.get_snapshot_dir()
    if 'beta_schedule_kwargs' in variant:
        # Copy before overriding so the caller's variant stays untouched
        # (and tuple-valued y_values work); the schedule always ends at beta.
        schedule_kwargs = dict(variant['beta_schedule_kwargs'])
        y_values = list(schedule_kwargs['y_values'])
        y_values[-1] = beta
        schedule_kwargs['y_values'] = y_values
        beta_schedule = PiecewiseLinearSchedule(**schedule_kwargs)
    else:
        beta_schedule = None
    m = ConvVAE(representation_size,
                input_channels=3,
                **variant['conv_vae_kwargs'])
    if ptu.gpu_enabled():
        # Use ptu.device explicitly: ``.cuda()`` targets the current default
        # CUDA device, which need not be the GPU ptu was configured for.
        m.to(ptu.device)
    t = ConvVAETrainer(train_data,
                       test_data,
                       m,
                       beta=beta,
                       beta_schedule=beta_schedule,
                       **variant['algo_kwargs'])
    save_period = variant['save_period']
    for epoch in range(variant['num_epochs']):
        should_save_imgs = (epoch % save_period == 0)
        t.train_epoch(epoch)
        t.test_epoch(epoch,
                     save_reconstruction=should_save_imgs,
                     save_scatterplot=should_save_imgs)
        if should_save_imgs:
            t.dump_samples(epoch)
Esempio n. 23
0
def train_reprojection_network(variant):
    """Train a ``ReprojectionNetwork`` on top of ``variant['vae']``.

    Optional ``variant`` keys:
        ``generate_reprojection_network_dataset_kwargs``,
        ``reprojection_network_kwargs`` and ``algo_kwargs`` (each defaults
        to ``{}``), and ``num_epochs`` (defaults to 5000).

    The dataset kwargs are copied before the VAE is injected under ``'vae'``,
    so the caller's variant is not modified. ``save_network=True`` is passed
    to the trainer's ``test_epoch`` every 250 epochs and on the last epoch.
    After training, the network is pickled to ``reproj_network.pkl`` via
    ``logger.save_extra_data``.

    Returns:
        The trained ``ReprojectionNetwork``.
    """
    from railrl.torch.vae.reprojection_network import (
        ReprojectionNetwork,
        ReprojectionNetworkTrainer,
    )
    from railrl.core import logger
    import railrl.torch.pytorch_util as ptu

    logger.get_snapshot_dir()

    vae = variant['vae']

    # Copy so that adding 'vae' does not mutate the dict stored in variant.
    generate_reprojection_network_dataset_kwargs = dict(variant.get(
        "generate_reprojection_network_dataset_kwargs", {}))
    generate_reprojection_network_dataset_kwargs['vae'] = vae
    train_data, test_data = generate_reprojection_network_dataset(
        generate_reprojection_network_dataset_kwargs)

    reprojection_network_kwargs = variant.get("reprojection_network_kwargs",
                                              {})
    m = ReprojectionNetwork(vae, **reprojection_network_kwargs)
    if ptu.gpu_enabled():
        # Explicit device instead of ``.cuda()`` (current default CUDA device).
        m.to(ptu.device)

    algo_kwargs = variant.get("algo_kwargs", {})
    t = ReprojectionNetworkTrainer(train_data, test_data, m, **algo_kwargs)

    num_epochs = variant.get('num_epochs', 5000)
    for epoch in range(num_epochs):
        should_save_network = (epoch % 250 == 0 or epoch == num_epochs - 1)
        t.train_epoch(epoch)
        t.test_epoch(
            epoch,
            save_network=should_save_network,
        )
    logger.save_extra_data(m, 'reproj_network.pkl', mode='pickle')
    return m
Esempio n. 24
0
def dump_latent_histogram(vae_env, epoch, noisy=False, reproj=False, use_true_prior=None, draw_dots=False):
    """Save ``vae_env``'s latent-histogram images as one PNG grid in the snapshot dir.

    The images come from ``vae_env.get_image_latent_histogram`` with the
    ``noisy``, ``reproj``, ``draw_dots`` and ``use_true_prior`` flags
    forwarded unchanged.

    The file name prefix is ``h`` when ``noisy`` is set, otherwise ``h_r``
    when ``reproj`` is set, otherwise ``h_mu``. The file is ``{prefix}.png``
    when ``epoch`` is None and ``{prefix}_{epoch}.png`` otherwise. The grid
    has ``int(sqrt(N))`` images per row, where ``N = images.shape[0]``.
    """
    import os.path as osp

    # numpy and ptu are imported here like this function's other
    # dependencies, so it does not rely on module-level names.
    import numpy as np
    from torchvision.utils import save_image

    import railrl.torch.pytorch_util as ptu
    from railrl.core import logger

    images = vae_env.get_image_latent_histogram(
        noisy=noisy, reproj=reproj, draw_dots=draw_dots, use_true_prior=use_true_prior
    )
    if noisy:
        prefix = 'h'
    elif reproj:
        prefix = 'h_r'
    else:
        prefix = 'h_mu'

    if epoch is None:
        save_path = osp.join(logger.get_snapshot_dir(), prefix + '.png')
    else:
        save_path = osp.join(logger.get_snapshot_dir(), prefix + '_%d.png' % epoch)
    # images is presumably a numpy array (it is passed to ptu.from_numpy).
    save_image(
        ptu.FloatTensor(ptu.from_numpy(images)),
        save_path,
        nrow=int(np.sqrt(images.shape[0])),
    )
Esempio n. 25
0
def plot_buffer_function(save_period, buffer_key):
    """Build a post-epoch callback that scatter-plots replay-buffer entries.

    Args:
        save_period: plot every ``save_period`` epochs. The final epoch is
            also plotted.
        buffer_key: key into ``replay_buffer._next_obs`` whose first two
            columns are plotted.

    Returns:
        ``plot_buffer(algo, epoch)``. It reads the first
        ``algo.replay_buffer._size`` rows of
        ``algo.replay_buffer._next_obs[buffer_key]`` and saves a scatter of
        columns 0 vs 1 to ``{buffer_key}_buffer_{epoch}_env.png`` in the
        snapshot directory.
    """
    import os.path as osp

    import matplotlib.pyplot as plt
    from railrl.core import logger
    logdir = logger.get_snapshot_dir()

    def plot_buffer(algo, epoch):
        replay_buffer = algo.replay_buffer
        # Assumes the training loop runs ``for epoch in range(num_epochs)``,
        # so the last epoch is ``num_epochs - 1``. ``>=`` keeps the original
        # ``== num_epochs`` trigger as well.
        if epoch % save_period == 0 or epoch >= algo.num_epochs - 1:
            filename = osp.join(
                logdir, '{}_buffer_{epoch}_env.png'.format(buffer_key,
                                                           epoch=epoch))
            goals = replay_buffer._next_obs[buffer_key][:replay_buffer._size]

            plt.clf()
            plt.scatter(goals[:, 0], goals[:, 1], alpha=0.2)
            plt.savefig(filename)

    return plot_buffer
Esempio n. 26
0
 def __init__(self,
              env,
              variant,
              expl_path_collector=None,
              eval_path_collector=None):
     """Configure periodic video dumping for ``env``.

     Args:
         env: environment to record. ``env.imsize`` is used when
             ``dump_video_kwargs`` does not supply ``imsize``.
         variant: experiment config. Its optional ``dump_video_kwargs`` dict
             provides the video options; it is copied, never modified.
         expl_path_collector: optional exploration path collector, stored as-is.
         eval_path_collector: optional evaluation path collector, stored as-is.

     ``save_video_period`` (default 50), ``exploration_goal_image_key``
     (default ``"decoded_goal_image"``) and ``evaluation_goal_image_key``
     (default ``"image_desired_goal"``) are removed from the copied kwargs and
     stored as attributes. ``rows`` (2), ``columns`` (5) and ``unnormalize``
     (True) are defaulted in the copy.
     """
     self.env = env
     self.logdir = logger.get_snapshot_dir()
     # Work on a copy: the setdefault/pop calls below would otherwise edit
     # variant['dump_video_kwargs'] in place, so a second instance built from
     # the same variant would silently lose e.g. 'save_video_period'.
     self.dump_video_kwargs = dict(variant.get("dump_video_kwargs", dict()))
     if 'imsize' not in self.dump_video_kwargs:
         self.dump_video_kwargs['imsize'] = env.imsize
     self.dump_video_kwargs.setdefault("rows", 2)
     self.dump_video_kwargs.setdefault("columns", 5)
     self.dump_video_kwargs.setdefault("unnormalize", True)
     self.save_period = self.dump_video_kwargs.pop('save_video_period', 50)
     self.exploration_goal_image_key = self.dump_video_kwargs.pop(
         "exploration_goal_image_key", "decoded_goal_image")
     self.evaluation_goal_image_key = self.dump_video_kwargs.pop(
         "evaluation_goal_image_key", "image_desired_goal")
     self.expl_path_collector = expl_path_collector
     self.eval_path_collector = eval_path_collector
     self.variant = variant
Esempio n. 27
0
def get_save_video_function(rollout_function,
                            env,
                            policy,
                            save_video_period=10,
                            imsize=48,
                            tag="",
                            video_image_env_kwargs=None,
                            **dump_video_kwargs):
    """Build a post-epoch callback that records rollout videos of ``policy``.

    If ``env`` is neither an ``ImageEnv`` nor a ``VAEWrappedEnv``, it is
    wrapped in ``ImageEnv(env, imsize, transpose=True, normalize=True,
    **video_image_env_kwargs)``. Otherwise it is used directly and its
    ``imsize`` must equal ``imsize`` (checked with ``assert``).

    Args:
        rollout_function: forwarded to ``dump_video``.
        env: environment to record.
        policy: policy whose rollouts are recorded.
        save_video_period: record every ``save_video_period`` epochs. The
            final epoch is also recorded.
        imsize: image size for the wrapper and ``dump_video``.
        tag: inserted into the file name ``video_{tag}_{epoch}_env.mp4``.
        video_image_env_kwargs: extra ``ImageEnv`` kwargs (``None`` -> ``{}``).
        **dump_video_kwargs: forwarded to ``dump_video``.

    Returns:
        ``save_video(algo, epoch)``, which writes the video into the snapshot
        directory.
    """
    logdir = logger.get_snapshot_dir()

    if isinstance(env, (ImageEnv, VAEWrappedEnv)):
        image_env = env
        assert image_env.imsize == imsize, "Imsize must match env imsize"
    else:
        if video_image_env_kwargs is None:
            video_image_env_kwargs = {}
        image_env = ImageEnv(env,
                             imsize,
                             transpose=True,
                             normalize=True,
                             **video_image_env_kwargs)

    def save_video(algo, epoch):
        # Assumes the training loop runs ``for epoch in range(num_epochs)``,
        # so the last epoch is ``num_epochs - 1``. ``>=`` keeps the original
        # ``== num_epochs`` trigger as well.
        if epoch % save_video_period == 0 or epoch >= algo.num_epochs - 1:
            filename = osp.join(
                logdir,
                'video_{}_{epoch}_env.mp4'.format(tag, epoch=epoch),
            )
            dump_video(image_env,
                       policy,
                       filename,
                       rollout_function,
                       imsize=imsize,
                       **dump_video_kwargs)

    return save_video