Exemplo n.º 1
0
 def dump_samples(self, epoch, num_samples=64):
     """Decode random latent vectors and save the results as an image grid.

     Draws ``num_samples`` latent vectors with ``ptu.randn`` (each of size
     ``self.representation_size``), decodes them with ``self.model.decode``
     and writes ``samples_<epoch>.png`` into the logger's snapshot directory.

     Args:
         epoch: Epoch number, used only to name the output file.
         num_samples: Number of samples to draw and save. Defaults to 64,
             the value that used to be hard-coded.
     """
     import torch  # local import keeps this method self-contained

     self.model.eval()
     # Inference only: without no_grad, decoding would build an autograd
     # graph through the model parameters for no reason.
     with torch.no_grad():
         latents = ptu.randn(num_samples, self.representation_size)
         # Element [0] of decode()'s result is used as the decoded images
         # (presumably decode returns (images, params) -- verify).
         decoded = self.model.decode(latents)[0].cpu()
     save_path = osp.join(logger.get_snapshot_dir(),
                          'samples_%d.png' % epoch)
     images = decoded.view(num_samples, self.input_channels, self.img_size,
                           self.img_size)
     # Swap the last two (H/W) axes before writing, as the original did.
     save_image(images.transpose(2, 3), save_path)
Exemplo n.º 2
0
def get_video_save_func(rollout_function, env, policy, variant):
    """Build a ``save_video(algo, epoch)`` callback that periodically dumps videos.

    Args:
        rollout_function: Passed through to ``dump_video``.
        env: Environment to record. When ``variant['do_state_exp']`` is
            false, it must expose ``imsize`` and is passed to
            ``temporary_mode`` with modes ``'video_env'`` and ``'video_vae'``.
        policy: Policy used for the recorded rollouts.
        variant: Config dict. Keys read: ``save_video_period`` (default 50),
            ``do_state_exp`` (default False), ``dump_video_kwargs`` (default
            empty), plus ``imsize`` and ``init_camera`` for state experiments.
            ``variant`` itself is not modified.

    Returns:
        A callable ``save_video(algo, epoch)``. It writes ``.mp4`` files into
        the logger's snapshot directory when ``epoch`` is a multiple of the
        save period or equals ``algo.num_epochs``, and otherwise does nothing.
    """
    logdir = logger.get_snapshot_dir()
    save_period = variant.get('save_video_period', 50)
    do_state_exp = variant.get("do_state_exp", False)
    # Copy: 'imsize' is injected below, and writing into the dict returned by
    # variant.get() would silently modify the caller's variant.
    dump_video_kwargs = dict(variant.get("dump_video_kwargs", dict()))

    def should_save(algo, epoch):
        # Every `save_period` epochs, and always on the final epoch.
        return epoch % save_period == 0 or epoch == algo.num_epochs

    if do_state_exp:
        imsize = variant.get('imsize')
        dump_video_kwargs['imsize'] = imsize
        image_env = ImageEnv(
            env,
            imsize,
            init_camera=variant.get('init_camera', None),
            transpose=True,
            normalize=True,
        )

        def save_video(algo, epoch):
            if not should_save(algo, epoch):
                return
            filename = osp.join(
                logdir, 'video_{epoch}_env.mp4'.format(epoch=epoch))
            dump_video(image_env, policy, filename, rollout_function,
                       **dump_video_kwargs)
    else:
        image_env = env
        dump_video_kwargs['imsize'] = env.imsize
        # One video per env mode; temporary_mode switches image_env into the
        # given mode while dump_video runs.
        mode_templates = (
            ('video_env', 'video_{epoch}_env.mp4'),
            ('video_vae', 'video_{epoch}_vae.mp4'),
        )

        def save_video(algo, epoch):
            if not should_save(algo, epoch):
                return
            for mode, template in mode_templates:
                filename = osp.join(logdir, template.format(epoch=epoch))
                temporary_mode(image_env,
                               mode=mode,
                               func=dump_video,
                               args=(image_env, policy, filename,
                                     rollout_function),
                               kwargs=dump_video_kwargs)

    return save_video
Exemplo n.º 3
0
    def _dump_imgs_and_reconstructions(self, idxs, filename):
        """Save training images and their reconstructions as a two-row grid.

        The first row holds the images ``self.train_dataset[i]`` for ``i`` in
        ``idxs``; the second row holds their reconstructions by
        ``self.model``. The grid is written to ``filename`` inside the
        logger's snapshot directory. Nothing is written if ``idxs`` is empty.

        Args:
            idxs: Indices into ``self.train_dataset`` to visualize.
            filename: Output file name (relative to the snapshot directory).
        """
        idxs = list(idxs)
        if not idxs:
            # torch.stack([]) would raise, and nrow=0 is meaningless.
            return
        imgs = []
        recons = []
        # Visualization only: skip building autograd graphs.
        with torch.no_grad():
            for i in idxs:
                img_np = self.train_dataset[i]
                img_torch = ptu.from_numpy(normalize_image(img_np))
                recon, *_ = self.model(img_torch.view(1, -1))

                # Reshape the flat vectors to (C, H, W) and swap the last two
                # axes (presumably the data is stored width-major -- verify).
                img = img_torch.view(self.input_channels, self.img_size,
                                     self.img_size).transpose(1, 2)
                rimg = recon.view(self.input_channels, self.img_size,
                                  self.img_size).transpose(1, 2)
                imgs.append(img)
                recons.append(rimg)
        all_imgs = torch.stack(imgs + recons)
        save_file = osp.join(logger.get_snapshot_dir(), filename)
        save_image(
            all_imgs,
            save_file,
            nrow=len(idxs),
        )
Exemplo n.º 4
0
    def dump_uniform_imgs_and_reconstructions(self, dataset, epoch,
                                              num_imgs=4):
        """Save randomly chosen ``dataset`` images and their reconstructions.

        Writes ``uniform<epoch>.png`` to the logger's snapshot directory. The
        first row holds the sampled images and the second row their
        reconstructions by ``self.model``. Nothing is written when the
        dataset is empty or ``num_imgs`` is not positive.

        Args:
            dataset: Array indexed along axis 0 (``dataset.shape[0]`` rows),
                each row accepted by ``normalize_image``.
            epoch: Epoch number, used only in the output file name.
            num_imgs: Number of images to sample. Defaults to 4, the value
                that used to be hard-coded.
        """
        num_rows = dataset.shape[0]
        if num_rows == 0 or num_imgs <= 0:
            return
        # Sample distinct rows when there are enough of them, so the grid
        # shows no duplicate images; fall back to replacement otherwise.
        idxs = np.random.choice(num_rows, num_imgs,
                                replace=num_rows < num_imgs)
        filename = 'uniform{}.png'.format(epoch)
        imgs = []
        recons = []
        # Visualization only: skip building autograd graphs.
        with torch.no_grad():
            for i in idxs:
                img_np = dataset[i]
                img_torch = ptu.from_numpy(normalize_image(img_np))
                recon, *_ = self.model(img_torch.view(1, -1))

                # Reshape the flat vectors to (C, H, W) and swap the last two
                # axes (presumably the data is stored width-major -- verify).
                img = img_torch.view(self.input_channels, self.img_size,
                                     self.img_size).transpose(1, 2)
                rimg = recon.view(self.input_channels, self.img_size,
                                  self.img_size).transpose(1, 2)
                imgs.append(img)
                recons.append(rimg)
        all_imgs = torch.stack(imgs + recons)
        save_file = osp.join(logger.get_snapshot_dir(), filename)
        save_image(
            all_imgs,
            save_file,
            nrow=len(idxs),
        )
Exemplo n.º 5
0
    def test_epoch(
        self,
        epoch,
        save_reconstruction=True,
        save_vae=True,
        from_rl=False,
    ):
        """Evaluate the model on 10 test batches and record mean statistics.

        For each batch, computes the log-probability of the batch under the
        decoder (``self.model.logprob``), the KL term
        (``self.model.kl_divergence``) and the loss
        ``-log_prob + beta * KL``. Here ``beta`` is
        ``self.beta_schedule.get_value(epoch)``. The means over the batches
        are stored in ``self.eval_statistics`` together with ``epoch`` and
        ``beta``.

        Args:
            epoch: Current epoch; selects beta and names saved artifacts.
            save_reconstruction: If true, save ``test_<epoch>.png`` comparing
                up to 8 inputs from the first batch with their
                reconstructions.
            save_vae: If true and ``from_rl`` is false, save the model via
                ``logger.save_itr_params``.
            from_rl: If true, only fill ``self.eval_statistics``. Tabular
                logging and model saving are skipped.
        """
        self.model.eval()
        losses = []
        log_probs = []
        kles = []
        zs = []
        beta = float(self.beta_schedule.get_value(epoch))
        # Evaluation only: no autograd graph is needed for any of this.
        with torch.no_grad():
            for batch_idx in range(10):
                next_obs = self.get_batch(train=False)
                (reconstructions, obs_distribution_params,
                 latent_distribution_params) = self.model(next_obs)
                # NOTE(review): displays the first reconstruction of every
                # batch (10 times per call); looks like a debugging aid --
                # confirm it should stay enabled.
                re_show = reconstructions[0].reshape(self.input_channels,
                                                     self.img_size,
                                                     self.img_size)
                show_one_tensor_image(re_show,
                                      channel_first=True,
                                      name='test reconstruction')
                log_prob = self.model.logprob(next_obs,
                                              obs_distribution_params)
                kle = self.model.kl_divergence(latent_distribution_params)
                loss = -1 * log_prob + beta * kle

                # Element 0 of the latent distribution params is used as the
                # encoder mean; collect one row per sample.
                encoder_mean = latent_distribution_params[0]
                zs.extend(ptu.get_numpy(encoder_mean.cpu()))
                losses.append(loss.item())
                log_probs.append(log_prob.item())
                kles.append(kle.item())

                if batch_idx == 0 and save_reconstruction:
                    n = min(next_obs.size(0), 8)
                    image_shape = (-1, self.input_channels, self.img_size,
                                   self.img_size)
                    # Keep only the first `imlength` columns of each
                    # observation (presumably the image pixels).
                    originals = next_obs[:n].narrow(
                        start=0, length=self.imlength,
                        dim=1).contiguous().view(*image_shape)
                    # view(-1, ...) rather than view(self.batch_size, ...)
                    # so a test batch of a different size still reshapes.
                    recon_imgs = reconstructions.view(*image_shape)[:n]
                    comparison = torch.cat([
                        originals.transpose(2, 3),
                        recon_imgs.transpose(2, 3),
                    ])
                    save_path = osp.join(logger.get_snapshot_dir(),
                                         'test_%d.png' % epoch)
                    save_image(comparison.cpu(), save_path, nrow=n)

        # NOTE(review): zs is collected but not used below -- confirm whether
        # it was meant to feed latent statistics or can be dropped.
        zs = np.array(zs)

        self.eval_statistics['epoch'] = epoch
        self.eval_statistics['test/log prob'] = np.mean(log_probs)
        self.eval_statistics['test/KL'] = np.mean(kles)
        self.eval_statistics['test/loss'] = np.mean(losses)
        self.eval_statistics['beta'] = beta
        if not from_rl:
            for k, v in self.eval_statistics.items():
                logger.record_tabular(k, v)
            logger.dump_tabular()
            if save_vae:
                logger.save_itr_params(epoch, self.model)
Exemplo n.º 6
0
def train_vae(variant, return_data=False):
    """Build a ConvVAE from ``variant``, train it, and pickle it to the log dir.

    Note: ``variant['vae_kwargs']`` is updated in place with the resolved
    ``architecture`` and ``imsize`` before the model is built.

    Args:
        variant: Config dict. Reads ``beta``, ``representation_size``,
            ``vae_kwargs``, ``algo_kwargs``, ``save_period`` and
            ``num_epochs``. Optionally reads ``beta_schedule_kwargs``,
            ``decoder_activation`` and ``imsize``.
        return_data: If true, also return the train and test data.

    Returns:
        The trained model, or ``(model, train_data, test_data)`` when
        ``return_data`` is true.
    """
    from rlkit.util.ml_util import PiecewiseLinearSchedule
    from rorlkit.torch.vae.conv_vae import ConvVAE
    import rorlkit.torch.vae.conv_vae as conv_vae
    from rorlkit.torch.vae.vae_trainer import ConvVAETrainer
    from rorlkit.core import logger
    import rlkit.torch.pytorch_util as ptu
    from rlkit.pythonplusplus import identity
    import torch

    kl_weight = variant["beta"]
    latent_dim = variant["representation_size"]
    train_data, test_data = prepare_vae_dataset(variant)
    logger.get_snapshot_dir()

    beta_schedule = (
        PiecewiseLinearSchedule(**variant['beta_schedule_kwargs'])
        if 'beta_schedule_kwargs' in variant
        else None
    )
    wants_sigmoid = variant.get('decoder_activation', None) == 'sigmoid'
    decoder_activation = torch.nn.Sigmoid() if wants_sigmoid else identity

    # Fill in a default conv architecture for the two supported image sizes
    # when none was configured.
    vae_kwargs = variant['vae_kwargs']
    arch = vae_kwargs.get('architecture', None)
    imsize = variant.get('imsize')
    if not arch:
        if imsize == 84:
            arch = conv_vae.imsize84_default_architecture
        elif imsize == 48:
            arch = conv_vae.imsize48_default_architecture
    vae_kwargs['architecture'] = arch
    vae_kwargs['imsize'] = imsize

    model = ConvVAE(latent_dim,
                    decoder_output_activation=decoder_activation,
                    **vae_kwargs)
    model.to(ptu.device)
    trainer = ConvVAETrainer(train_data,
                             test_data,
                             model,
                             beta=kl_weight,
                             beta_schedule=beta_schedule,
                             **variant['algo_kwargs'])

    period = variant['save_period']
    for ep in range(variant['num_epochs']):
        dump_images = ep % period == 0
        trainer.train_epoch(ep)
        trainer.test_epoch(ep, save_reconstruction=dump_images, save_vae=False)
        if dump_images:
            trainer.dump_samples(ep)
        trainer.update_train_weights()

    logger.save_extra_data(model, 'vae.pkl', mode='pickle')
    return (model, train_data, test_data) if return_data else model
Exemplo n.º 7
0
def train_vae(cfgs, return_data=False):
    """Generate a VAE dataset, train a ConvVAE on it, and pickle the model.

    Args:
        cfgs: Attribute-style config with ``VAE`` and ``VAE_TRAINER``
            sections, read both via attribute access and ``.get``.
        return_data: If true, also return the generated train/test data.

    Returns:
        The trained ConvVAE, or ``(model, train_data, test_data)`` when
        ``return_data`` is true. The model is also saved as ``vae.pkl`` via
        ``logger.save_extra_data``.
    """
    from rlkit.util.ml_util import PiecewiseLinearSchedule
    from rlkit.torch.vae.conv_vae import ConvVAE
    import rlkit.torch.vae.conv_vae as conv_vae
    from rlkit.torch.vae.vae_trainer import ConvVAETrainer
    from rlkit.core import logger
    import rlkit.torch.pytorch_util as ptu
    from rlkit.pythonplusplus import identity
    import torch

    train_data, test_data, info = generate_vae_dataset(cfgs)
    logger.save_extra_data(info)
    logger.get_snapshot_dir()

    # TODO(review): an earlier FIXME read "default gaussian"; confirm the
    # intended default. As written, any value other than 'sigmoid'
    # (including unset) leaves the decoder output unsquashed (identity).
    if cfgs.VAE.get('decoder_activation', None) == 'sigmoid':
        decoder_activation = torch.nn.Sigmoid()
    else:
        decoder_activation = identity

    imsize = cfgs.VAE.imsize
    architecture = cfgs.VAE.get('architecture', None)
    if not architecture:
        # Choose the default for the image size the model is actually built
        # with. This previously consulted cfgs.ENV.img_size while ConvVAE
        # received cfgs.VAE.imsize, so the two could silently disagree.
        if imsize == 84:
            architecture = conv_vae.imsize84_default_architecture
        elif imsize == 48:
            architecture = conv_vae.imsize48_default_architecture

    vae_model = ConvVAE(
        representation_size=cfgs.VAE.representation_size,
        architecture=architecture,
        decoder_output_activation=decoder_activation,
        input_channels=cfgs.VAE.input_channels,
        decoder_distribution=cfgs.VAE.decoder_distribution,
        imsize=imsize,
    )
    vae_model.to(ptu.device)

    # Optional schedule for varying beta over epochs. The trainer presumably
    # queries it per epoch (e.g. get_value(epoch)) -- verify in
    # ConvVAETrainer.
    if 'beta_schedule_kwargs' in cfgs.VAE_TRAINER:
        beta_schedule = PiecewiseLinearSchedule(
            **cfgs.VAE_TRAINER.beta_schedule_kwargs)
    else:
        beta_schedule = None

    t = ConvVAETrainer(train_data,
                       test_data,
                       vae_model,
                       lr=cfgs.VAE_TRAINER.lr,
                       beta=cfgs.VAE_TRAINER.beta,
                       beta_schedule=beta_schedule)

    save_period = cfgs.VAE_TRAINER.save_period
    for epoch in range(cfgs.VAE_TRAINER.num_epochs):
        should_save_imgs = (epoch % save_period == 0)
        t.train_epoch(epoch)
        # save_vae is not passed, so test_epoch's own default applies.
        t.test_epoch(
            epoch,
            save_reconstruction=should_save_imgs,
        )
        if should_save_imgs:
            t.dump_samples(epoch)
        t.update_train_weights()
    logger.save_extra_data(vae_model, 'vae.pkl', mode='pickle')
    if return_data:
        return vae_model, train_data, test_data
    return vae_model