Example #1
    def test_epoch(self, epoch, save_network=True, batches=100):
        self.model.eval()
        mses = []
        losses = []
        for batch_idx in range(batches):
            data = self.get_batch(train=False)
            z = data["z"]
            z_proj = data['z_proj']
            z_proj_hat = self.model(z)
            mse = self.mse_loss(z_proj_hat, z_proj)
            loss = mse

            mses.append(mse.item())
            losses.append(loss.item())

        logger.record_tabular("test/epoch", epoch)
        logger.record_tabular("test/MSE", np.mean(mses))
        logger.record_tabular("test/loss", np.mean(losses))

        logger.dump_tabular()
        if save_network:
            logger.save_itr_params(epoch,
                                   self.model,
                                   prefix='reproj',
                                   save_anyway=True)
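A note on the shared structure: most examples on this page follow the same per-epoch rhythm, namely compute batch diagnostics, aggregate them with record_tabular, flush the row with dump_tabular, and checkpoint with save_itr_params. The sketch below condenses that loop; the trainer object and its train_batch/get_snapshot methods are hypothetical stand-ins, and only the logger calls mirror the railrl.core.logger API used in these examples.

# Minimal sketch (not from the railrl codebase) of the per-epoch
# logging/checkpointing pattern shown above; `trainer` is hypothetical.
import numpy as np
from railrl.core import logger

def run_epochs(trainer, num_epochs, batches_per_epoch=100, save_period=10):
    for epoch in range(num_epochs):
        losses = [trainer.train_batch() for _ in range(batches_per_epoch)]
        logger.record_tabular("Epoch", epoch)
        logger.record_tabular("train/loss", np.mean(losses))
        logger.dump_tabular()
        if epoch % save_period == 0:
            logger.save_itr_params(epoch, trainer.get_snapshot())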
Example #2
 def train(self, start_epoch=0):
     self.pretrain()
     if start_epoch == 0:
         params = self.get_epoch_snapshot(-1)
         logger.save_itr_params(-1, params)
     self.training_mode(False)
     self._n_env_steps_total = start_epoch * self.num_env_steps_per_epoch
     gt.reset()
     gt.set_def_unique(False)
     if self.collection_mode == 'online':
         self.train_online(start_epoch=start_epoch)
     elif self.collection_mode == 'online-parallel':
         try:
             self.train_parallel(start_epoch=start_epoch)
         except:
             import traceback
             traceback.print_exc()
             self.parallel_env.shutdown()
     elif self.collection_mode == 'batch':
         self.train_batch(start_epoch=start_epoch)
     elif self.collection_mode == 'offline':
         self.train_offline(start_epoch=start_epoch)
     else:
         raise TypeError("Invalid collection_mode: {}".format(
             self.collection_mode
         ))
     self.cleanup()
Example #3
def train_vae(variant, return_data=False):
    from railrl.misc.ml_util import PiecewiseLinearSchedule
    from railrl.torch.vae.vae_trainer import ConvVAETrainer
    from railrl.core import logger
    beta = variant["beta"]
    use_linear_dynamics = variant.get('use_linear_dynamics', False)
    generate_vae_dataset_fctn = variant.get('generate_vae_data_fctn',
                                            generate_vae_dataset)
    variant['generate_vae_dataset_kwargs'][
        'use_linear_dynamics'] = use_linear_dynamics
    train_dataset, test_dataset, info = generate_vae_dataset_fctn(
        variant['generate_vae_dataset_kwargs'])
    if use_linear_dynamics:
        action_dim = train_dataset.data['actions'].shape[2]
    else:
        action_dim = 0
    model = get_vae(variant, action_dim)

    logger.save_extra_data(info)
    logger.get_snapshot_dir()
    if 'beta_schedule_kwargs' in variant:
        beta_schedule = PiecewiseLinearSchedule(
            **variant['beta_schedule_kwargs'])
    else:
        beta_schedule = None

    vae_trainer_class = variant.get('vae_trainer_class', ConvVAETrainer)
    trainer = vae_trainer_class(model,
                                beta=beta,
                                beta_schedule=beta_schedule,
                                **variant['algo_kwargs'])
    save_period = variant['save_period']

    dump_skew_debug_plots = variant.get('dump_skew_debug_plots', False)
    for epoch in range(variant['num_epochs']):
        should_save_imgs = (epoch % save_period == 0)
        trainer.train_epoch(epoch, train_dataset)
        trainer.test_epoch(epoch, test_dataset)

        if should_save_imgs:
            trainer.dump_reconstructions(epoch)
            trainer.dump_samples(epoch)
            if dump_skew_debug_plots:
                trainer.dump_best_reconstruction(epoch)
                trainer.dump_worst_reconstruction(epoch)
                trainer.dump_sampling_histogram(epoch)

        stats = trainer.get_diagnostics()
        for k, v in stats.items():
            logger.record_tabular(k, v)
        logger.dump_tabular()
        trainer.end_epoch(epoch)

        if epoch % 50 == 0:
            logger.save_itr_params(epoch, model)
    logger.save_extra_data(model, 'vae.pkl', mode='pickle')
    if return_data:
        return model, train_dataset, test_dataset
    return model
Example #4
    def test_epoch(self, epoch, save_vae=True, **kwargs):
        self.model.eval()
        losses = []
        kles = []
        zs = []

        recon_logging_dict = {
            'MSE': [],
            'WSE': [],
        }
        for k in self.extra_recon_logging:
            recon_logging_dict[k] = []

        beta = self.beta_schedule.get_value(epoch)
        for batch_idx in range(100):
            data = self.get_batch(train=False)
            obs = data['obs']
            next_obs = data['next_obs']
            actions = data['actions']
            recon_batch, mu, logvar = self.model(next_obs)
            mse = self.logprob(recon_batch, next_obs)
            wse = self.logprob(recon_batch, next_obs, unorm_weights=self.recon_weights)
            for k, idx in self.extra_recon_logging.items():
                recon_loss = self.logprob(recon_batch, next_obs, idx=idx)
                recon_logging_dict[k].append(recon_loss.item())
            kle = self.kl_divergence(mu, logvar)
            if self.recon_loss_type == 'mse':
                loss = mse + beta * kle
            elif self.recon_loss_type == 'wse':
                loss = wse + beta * kle
            else:
                raise ValueError(
                    "Unknown recon_loss_type: {}".format(self.recon_loss_type))
            z_data = ptu.get_numpy(mu.cpu())
            for i in range(len(z_data)):
                zs.append(z_data[i, :])
            losses.append(loss.item())
            recon_logging_dict['WSE'].append(wse.item())
            recon_logging_dict['MSE'].append(mse.item())
            kles.append(kle.item())
        zs = np.array(zs)
        self.model.dist_mu = zs.mean(axis=0)
        self.model.dist_std = zs.std(axis=0)

        for k in recon_logging_dict:
            logger.record_tabular("/".join(["test", k]), np.mean(recon_logging_dict[k]))
        logger.record_tabular("test/KL", np.mean(kles))
        logger.record_tabular("test/loss", np.mean(losses))
        logger.record_tabular("beta", beta)

        process = psutil.Process(os.getpid())
        logger.record_tabular("RAM Usage (Mb)", int(process.memory_info().rss / 1000000))

        num_active_dims = 0
        for std in self.model.dist_std:
            if std > 0.15:
                num_active_dims += 1
        logger.record_tabular("num_active_dims", num_active_dims)

        logger.dump_tabular()
        if save_vae:
            logger.save_itr_params(epoch, self.model, prefix='vae', save_anyway=True)  # slow...
Example #5
 def train_offline(self, start_epoch=0):
     self.training_mode(False)
     params = self.get_epoch_snapshot(-1)
     logger.save_itr_params(-1, params)
     for epoch in range(start_epoch, self.num_epochs):
         self._start_epoch(epoch)
         self._try_to_train()
         self._try_to_offline_eval(epoch)
         self._end_epoch()
Example #6
    def _try_to_eval(self, epoch, eval_paths=None):
        logger.save_extra_data(self.get_extra_data_to_save(epoch))

        params = self.get_epoch_snapshot(epoch)
        logger.save_itr_params(epoch, params)

        if self._can_evaluate():
            self.evaluate(epoch, eval_paths=eval_paths)

            # params = self.get_epoch_snapshot(epoch)
            # logger.save_itr_params(epoch, params)
            table_keys = logger.get_table_key_set()
            if self._old_table_keys is not None:
                assert table_keys == self._old_table_keys, (
                    "Table keys cannot change from iteration to iteration."
                )
            self._old_table_keys = table_keys

            logger.record_tabular(
                "Number of train steps total",
                self._n_train_steps_total,
            )
            logger.record_tabular(
                "Number of env steps total",
                self._n_env_steps_total,
            )
            logger.record_tabular(
                "Number of rollouts total",
                self._n_rollouts_total,
            )

            if self.collection_mode != 'online-parallel':
                times_itrs = gt.get_times().stamps.itrs
                train_time = times_itrs['train'][-1]
                sample_time = times_itrs['sample'][-1]
                if 'eval' in times_itrs:
                    eval_time = times_itrs['eval'][-1] if epoch > 0 else -1
                else:
                    eval_time = -1
                epoch_time = train_time + sample_time + eval_time
                total_time = gt.get_times().total

                logger.record_tabular('Train Time (s)', train_time)
                logger.record_tabular('(Previous) Eval Time (s)', eval_time)
                logger.record_tabular('Sample Time (s)', sample_time)
                logger.record_tabular('Epoch Time (s)', epoch_time)
                logger.record_tabular('Total Train Time (s)', total_time)
            else:
                logger.record_tabular('Epoch Time (s)',
                                      time.time() - self._epoch_start_time)
            logger.record_tabular("Epoch", epoch)
            logger.dump_tabular(with_prefix=False, with_timestamp=False)
        else:
            logger.log("Skipping eval for now.")
Example #7
 def train(self):
     timer.return_global_times = True
     for _ in range(self.num_epochs):
         self._begin_epoch()
         # logger.save_itr_params(self.epoch, self._get_snapshot())
         # timer.stamp('saving')
         log_dict, _ = self._train()
         logger.record_dict(log_dict)
         logger.dump_tabular(with_prefix=True, with_timestamp=False)
         logger.save_itr_params(self.epoch, self._get_snapshot())
         self._end_epoch()
Example #8
 def _try_to_offline_eval(self, epoch):
     start_time = time.time()
     logger.save_extra_data(self.get_extra_data_to_save(epoch))
     self.offline_evaluate(epoch)
     params = self.get_epoch_snapshot(epoch)
     logger.save_itr_params(epoch, params)
     table_keys = logger.get_table_key_set()
     if self._old_table_keys is not None:
         assert table_keys == self._old_table_keys, (
             "Table keys cannot change from iteration to iteration.")
     self._old_table_keys = table_keys
     logger.dump_tabular(with_prefix=False, with_timestamp=False)
     logger.log("Eval Time: {0}".format(time.time() - start_time))
Example #9
 def train(self):
     self.fix_data_set()
     logger.log("Done creating dataset.")
     num_batches_total = 0
     for epoch in range(self.num_epochs):
         for _ in range(self.num_batches_per_epoch):
             self.qf.train(True)
             self._do_training()
             num_batches_total += 1
         logger.push_prefix('Iteration #%d | ' % epoch)
         self.qf.train(False)
         self.evaluate(epoch)
         params = self.get_epoch_snapshot(epoch)
         logger.save_itr_params(epoch, params)
         logger.log("Done evaluating")
         logger.pop_prefix()
Example #10
def create_policy(variant):
    bottom_snapshot = joblib.load(variant['bottom_path'])
    column_snapshot = joblib.load(variant['column_path'])
    policy = variant['combiner_class'](
        policy1=bottom_snapshot['naf_policy'],
        policy2=column_snapshot['naf_policy'],
    )
    env = bottom_snapshot['env']
    logger.save_itr_params(0, dict(
        policy=policy,
        env=env,
    ))
    path = rollout(
        env,
        policy,
        max_path_length=variant['max_path_length'],
        animated=variant['render'],
    )
    env.log_diagnostics([path])
    logger.dump_tabular()
Example #11
    def train(self):
        for epoch in range(self.num_epochs):
            logger.push_prefix('Iteration #%d | ' % epoch)

            start_time = time.time()
            for _ in range(self.num_steps_per_epoch):
                batch = self.get_batch()
                train_dict = self.get_train_dict(batch)

                self.policy_optimizer.zero_grad()
                policy_loss = train_dict['Policy Loss']
                policy_loss.backward()
                self.policy_optimizer.step()
            logger.log("Train time: {}".format(time.time() - start_time))

            start_time = time.time()
            self.evaluate(epoch)
            logger.log("Eval time: {}".format(time.time() - start_time))

            params = self.get_epoch_snapshot(epoch)
            logger.save_itr_params(epoch, params)
            logger.pop_prefix()
Example #12
def train_rfeatures_model(variant, return_data=False):
    from railrl.misc.ml_util import PiecewiseLinearSchedule
    # from railrl.torch.vae.conv_vae import (
    #     ConvVAE, ConvResnetVAE
    # )
    import railrl.torch.vae.conv_vae as conv_vae
    # from railrl.torch.vae.vae_trainer import ConvVAETrainer
    from railrl.launchers.experiments.ashvin.rfeatures.rfeatures_model import TimestepPredictionModel
    from railrl.launchers.experiments.ashvin.rfeatures.rfeatures_trainer import TimePredictionTrainer
    from railrl.core import logger
    import railrl.torch.pytorch_util as ptu
    from railrl.pythonplusplus import identity
    import torch
    output_classes = variant["output_classes"]
    representation_size = variant["representation_size"]
    batch_size = variant["batch_size"]
    variant['dataset_kwargs']["output_classes"] = output_classes
    train_dataset, test_dataset, info = get_data(variant['dataset_kwargs'])

    num_train_workers = variant.get("num_train_workers",
                                    0)  # 0 uses main process (good for pdb)
    train_dataset_loader = InfiniteBatchLoader(
        train_dataset,
        batch_size=batch_size,
        num_workers=num_train_workers,
    )
    test_dataset_loader = InfiniteBatchLoader(
        test_dataset,
        batch_size=batch_size,
    )

    logger.save_extra_data(info)
    logger.get_snapshot_dir()
    if variant.get('decoder_activation', None) == 'sigmoid':
        decoder_activation = torch.nn.Sigmoid()
    else:
        decoder_activation = identity
    architecture = variant['model_kwargs'].get('architecture', None)
    if not architecture and variant.get('imsize') == 84:
        architecture = conv_vae.imsize84_default_architecture
    elif not architecture and variant.get('imsize') == 48:
        architecture = conv_vae.imsize48_default_architecture
    variant['model_kwargs']['architecture'] = architecture

    model_class = variant.get('model_class', TimestepPredictionModel)
    model = model_class(
        representation_size,
        decoder_output_activation=decoder_activation,
        output_classes=output_classes,
        **variant['model_kwargs'],
    )
    # model = torch.nn.DataParallel(model)
    model.to(ptu.device)

    variant['trainer_kwargs']['batch_size'] = batch_size
    trainer_class = variant.get('trainer_class', TimePredictionTrainer)
    trainer = trainer_class(
        model,
        **variant['trainer_kwargs'],
    )
    save_period = variant['save_period']

    trainer.dump_trajectory_rewards(
        "initial", dict(train=train_dataset.dataset,
                        test=test_dataset.dataset))

    dump_skew_debug_plots = variant.get('dump_skew_debug_plots', False)
    for epoch in range(variant['num_epochs']):
        should_save_imgs = (epoch % save_period == 0)
        trainer.train_epoch(epoch, train_dataset_loader, batches=10)
        trainer.test_epoch(epoch, test_dataset_loader, batches=1)

        if should_save_imgs:
            trainer.dump_reconstructions(epoch)

        trainer.dump_trajectory_rewards(
            epoch, dict(train=train_dataset.dataset,
                        test=test_dataset.dataset), should_save_imgs)

        stats = trainer.get_diagnostics()
        for k, v in stats.items():
            logger.record_tabular(k, v)
        logger.dump_tabular()
        trainer.end_epoch(epoch)

        if epoch % 50 == 0:
            logger.save_itr_params(epoch, model)
    logger.save_extra_data(model, 'vae.pkl', mode='pickle')
    if return_data:
        return model, train_dataset, test_dataset
    return model
Example #13
def get_n_train_vae(latent_dim,
                    env,
                    vae_train_epochs,
                    num_image_examples,
                    vae_kwargs,
                    vae_trainer_kwargs,
                    vae_architecture,
                    vae_save_period=10,
                    vae_test_p=.9,
                    decoder_activation='sigmoid',
                    vae_class='VAE',
                    **kwargs):
    env.goal_sampling_mode = 'test'
    image_examples = unnormalize_image(
        env.sample_goals(num_image_examples)['desired_goal'])
    n = int(num_image_examples * vae_test_p)
    train_dataset = ImageObservationDataset(image_examples[:n, :])
    test_dataset = ImageObservationDataset(image_examples[n:, :])

    if decoder_activation == 'sigmoid':
        decoder_activation = torch.nn.Sigmoid()

    vae_class = vae_class.lower()
    if vae_class == 'VAE'.lower():
        vae_class = ConvVAE
    elif vae_class == 'SpatialVAE'.lower():
        vae_class = SpatialAutoEncoder
    else:
        raise RuntimeError("Invalid VAE Class: {}".format(vae_class))

    vae = vae_class(latent_dim,
                    architecture=vae_architecture,
                    decoder_output_activation=decoder_activation,
                    **vae_kwargs)

    trainer = ConvVAETrainer(vae, **vae_trainer_kwargs)

    logger.remove_tabular_output('progress.csv', relative_to_snapshot_dir=True)
    logger.add_tabular_output('vae_progress.csv',
                              relative_to_snapshot_dir=True)
    for epoch in range(vae_train_epochs):
        should_save_imgs = (epoch % vae_save_period == 0)
        trainer.train_epoch(epoch, train_dataset)
        trainer.test_epoch(epoch, test_dataset)

        if should_save_imgs:
            trainer.dump_reconstructions(epoch)
            trainer.dump_samples(epoch)
        stats = trainer.get_diagnostics()
        for k, v in stats.items():
            logger.record_tabular(k, v)

        logger.dump_tabular()
        trainer.end_epoch(epoch)

        if epoch % 50 == 0:
            logger.save_itr_params(epoch, vae)
    logger.save_extra_data(vae, 'vae.pkl', mode='pickle')
    logger.remove_tabular_output('vae_progress.csv',
                                 relative_to_snapshot_dir=True)
    logger.add_tabular_output('progress.csv', relative_to_snapshot_dir=True)
    return vae
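Both this example and pretrain_policy_with_bc in Example #17 temporarily reroute tabular output so a sub-training phase writes to its own CSV instead of progress.csv, then restore the original file afterwards. Below is a minimal sketch of just that bracketing pattern; pretrain_step is a hypothetical callable, while the logger calls follow the usage above.

# Sketch: route a sub-training phase to its own CSV, then restore progress.csv.
# `pretrain_step` is hypothetical and assumed to return a dict of scalars.
from railrl.core import logger

def pretrain_with_own_csv(pretrain_step, num_steps, filename='pretrain.csv'):
    logger.remove_tabular_output('progress.csv', relative_to_snapshot_dir=True)
    logger.add_tabular_output(filename, relative_to_snapshot_dir=True)
    for i in range(num_steps):
        logger.record_dict(pretrain_step(i))
        logger.dump_tabular(with_prefix=True, with_timestamp=False)
    logger.remove_tabular_output(filename, relative_to_snapshot_dir=True)
    logger.add_tabular_output('progress.csv', relative_to_snapshot_dir=True)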
Example #14
    def test_epoch(self,
                   epoch,
                   save_reconstruction=True,
                   save_interpolation=True,
                   save_vae=True):
        self.model.eval()
        vae_losses = []
        iwae_losses = []
        losses = []
        des = []
        kles = []
        linear_losses = []
        zs = []
        beta = float(self.beta_schedule.get_value(epoch))

        for batch_idx in range(10):
            data = self.get_batch(train=False)
            obs = data['obs']
            obs = obs.detach()
            next_obs = data['next_obs']
            next_obs = next_obs.detach()
            actions = data['actions']
            actions = actions.detach()

            x_recon, z_mu, z_logvar, z = self.model(next_obs, n_imp=25)
            x_recon = x_recon.detach()
            z_mu = z_mu.detach()
            z = z.detach()

            batch_size = x_recon.shape[0]
            k = x_recon.shape[1]
            x = next_obs.view((batch_size, 1, -1)).repeat(torch.Size([1, k,
                                                                      1]))
            x = x.detach()
            vae_loss, de, kle = self.compute_vae_loss(x_recon, x, z_mu,
                                                      z_logvar, z, beta)
            vae_loss, de, kle = vae_loss.detach(), de.detach(), kle.detach()
            iwae_loss = self.compute_iwae_loss(x_recon, x, z_mu, z_logvar, z,
                                               beta)
            iwae_loss = iwae_loss.detach()
            loss = vae_loss
            if self.use_linear_dynamics:
                linear_dynamics_loss = self.state_linearity_loss(
                    obs, next_obs, actions)
                linear_dynamics_loss = linear_dynamics_loss.detach()
                loss += self.linearity_weight * linear_dynamics_loss
                linear_losses.append(linear_dynamics_loss.item())

            z_data = ptu.get_numpy(z_mu[:, 0].cpu())
            for i in range(len(z_data)):
                zs.append(z_data[i, :].copy())
            vae_losses.append(vae_loss.item())
            iwae_losses.append(iwae_loss.item())
            losses.append(loss.item())
            des.append(de.item())
            kles.append(kle.item())

            if batch_idx == 0 and save_reconstruction:
                n = min(data['next_obs'].size(0), 16)
                comparison = torch.cat([
                    data['next_obs'][:n].narrow(start=0,
                                                length=self.imlength,
                                                dimension=1).contiguous().view(
                                                    -1, self.input_channels,
                                                    self.imsize, self.imsize),
                    x_recon[:, 0].contiguous().view(
                        self.batch_size,
                        self.input_channels,
                        self.imsize,
                        self.imsize,
                    )[:n]
                ])
                save_dir = osp.join(logger.get_snapshot_dir(),
                                    'r_%d.png' % epoch)
                save_image(comparison.data.cpu(), save_dir, nrow=n)
                del comparison

            if batch_idx == 0 and save_interpolation:
                n = min(data['next_obs'].size(0), 10)

                z1 = z_mu[:n, 0]
                z2 = z_mu[n:2 * n, 0]

                num_steps = 8

                z_interp = []
                for i in np.linspace(0.0, 1.0, num_steps):
                    z_interp.append(float(i) * z1 + float(1 - i) * z2)
                z_interp = torch.cat(z_interp)

                imgs = self.model.decode(z_interp)
                imgs = imgs.view((num_steps, n, 3, self.imsize, self.imsize))
                imgs = imgs.permute([1, 0, 2, 3, 4])
                imgs = imgs.contiguous().view(
                    (n * num_steps, 3, self.imsize, self.imsize))

                save_dir = osp.join(logger.get_snapshot_dir(),
                                    'i_%d.png' % epoch)
                save_image(
                    imgs.data.cpu(),
                    save_dir,
                    nrow=num_steps,
                )
                del imgs
                del z_interp

            del obs, next_obs, actions, x_recon, z_mu, z_logvar, \
                z, x, vae_loss, de, kle, loss

        zs = np.array(zs)
        self.model.dist_mu = zs.mean(axis=0)
        self.model.dist_std = zs.std(axis=0)
        del zs

        logger.record_tabular("test/decoder_loss", np.mean(des))
        logger.record_tabular("test/KL", np.mean(kles))
        if self.use_linear_dynamics:
            logger.record_tabular("test/linear_loss", np.mean(linear_losses))
        logger.record_tabular("test/loss", np.mean(losses))
        logger.record_tabular("test/vae_loss", np.mean(vae_losses))
        logger.record_tabular("test/iwae_loss", np.mean(iwae_losses))
        logger.record_tabular(
            "test/iwae_vae_diff",
            np.mean(np.array(iwae_losses) - np.array(vae_losses)))
        logger.record_tabular("beta", beta)

        process = psutil.Process(os.getpid())
        logger.record_tabular("RAM Usage (Mb)",
                              int(process.memory_info().rss / 1000000))

        num_active_dims = 0
        num_active_dims2 = 0
        for std in self.model.dist_std:
            if std > 0.15:
                num_active_dims += 1
            if std > 0.05:
                num_active_dims2 += 1
        logger.record_tabular("num_active_dims", num_active_dims)
        logger.record_tabular("num_active_dims2", num_active_dims2)

        logger.dump_tabular()
        if save_vae:
            logger.save_itr_params(epoch,
                                   self.model,
                                   prefix='vae',
                                   save_anyway=True)  # slow...
Example #15
def train(
        dataset_generator,
        n_start_samples,
        projection=project_samples_square_np,
        n_samples_to_add_per_epoch=1000,
        n_epochs=100,
        z_dim=1,
        hidden_size=32,
        save_period=10,
        append_all_data=True,
        full_variant=None,
        dynamics_noise=0,
        decoder_output_var='learned',
        num_bins=5,
        skew_config=None,
        use_perfect_samples=False,
        use_perfect_density=False,
        vae_reset_period=0,
        vae_kwargs=None,
        use_dataset_generator_first_epoch=True,
        **kwargs
):

    """
    Sanitize Inputs
    """
    assert skew_config is not None
    if not (use_perfect_density and use_perfect_samples):
        assert vae_kwargs is not None
    if vae_kwargs is None:
        vae_kwargs = {}

    report = HTMLReport(
        logger.get_snapshot_dir() + '/report.html',
        images_per_row=10,
    )
    dynamics = Dynamics(projection, dynamics_noise)
    if full_variant:
        report.add_header("Variant")
        report.add_text(
            json.dumps(
                ppp.dict_to_safe_json(
                    full_variant,
                    sort=True),
                indent=2,
            )
        )

    vae, decoder, decoder_opt, encoder, encoder_opt = get_vae(
        decoder_output_var,
        hidden_size,
        z_dim,
        vae_kwargs,
    )
    vae.to(ptu.device)

    epochs = []
    losses = []
    kls = []
    log_probs = []
    hist_heatmap_imgs = []
    vae_heatmap_imgs = []
    sample_imgs = []
    entropies = []
    tvs_to_uniform = []
    entropy_gains_from_reweighting = []
    p_theta = Histogram(num_bins)
    p_new = Histogram(num_bins)

    orig_train_data = dataset_generator(n_start_samples)
    train_data = orig_train_data
    start = time.time()
    for epoch in progressbar(range(n_epochs)):
        p_theta = Histogram(num_bins)
        if epoch == 0 and use_dataset_generator_first_epoch:
            vae_samples = dataset_generator(n_samples_to_add_per_epoch)
        else:
            if use_perfect_samples and epoch != 0:
                # Ideally the VAE = p_new, but in practice, it won't be...
                vae_samples = p_new.sample(n_samples_to_add_per_epoch)
            else:
                vae_samples = vae.sample(n_samples_to_add_per_epoch)
        projected_samples = dynamics(vae_samples)
        if append_all_data:
            train_data = np.vstack((train_data, projected_samples))
        else:
            train_data = np.vstack((orig_train_data, projected_samples))

        p_theta.fit(train_data)
        if use_perfect_density:
            prob = p_theta.compute_density(train_data)
        else:
            prob = vae.compute_density(train_data)
        all_weights = prob_to_weight(prob, skew_config)
        p_new.fit(train_data, weights=all_weights)
        if epoch == 0 or (epoch + 1) % save_period == 0:
            epochs.append(epoch)
            report.add_text("Epoch {}".format(epoch))
            hist_heatmap_img = visualize_histogram(p_theta, skew_config, report)
            vae_heatmap_img = visualize_vae(
                vae, skew_config, report,
                resolution=num_bins,
            )
            sample_img = visualize_vae_samples(
                epoch, train_data, vae, report, dynamics,
            )

            visualize_samples(
                p_theta.sample(n_samples_to_add_per_epoch),
                report,
                title="P Theta/RB Samples",
            )
            visualize_samples(
                p_new.sample(n_samples_to_add_per_epoch),
                report,
                title="P Adjusted Samples",
            )
            hist_heatmap_imgs.append(hist_heatmap_img)
            vae_heatmap_imgs.append(vae_heatmap_img)
            sample_imgs.append(sample_img)
            report.save()

            Image.fromarray(hist_heatmap_img).save(
                logger.get_snapshot_dir() + '/hist_heatmap{}.png'.format(epoch)
            )
            Image.fromarray(vae_heatmap_img).save(
                logger.get_snapshot_dir() + '/vae_heatmap{}.png'.format(epoch)
            )
            Image.fromarray(sample_img).save(
                logger.get_snapshot_dir() + '/samples{}.png'.format(epoch)
            )

        """
        train VAE to look like p_new
        """
        if sum(all_weights) == 0:
            all_weights[:] = 1
        if vae_reset_period > 0 and epoch % vae_reset_period == 0:
            vae, decoder, decoder_opt, encoder, encoder_opt = get_vae(
                decoder_output_var,
                hidden_size,
                z_dim,
                vae_kwargs,
            )
            vae.to(ptu.device)
        vae.fit(train_data, weights=all_weights)
        epoch_stats = vae.get_epoch_stats()

        losses.append(np.mean(epoch_stats['losses']))
        kls.append(np.mean(epoch_stats['kls']))
        log_probs.append(np.mean(epoch_stats['log_probs']))
        entropies.append(p_theta.entropy())
        tvs_to_uniform.append(p_theta.tv_to_uniform())
        entropy_gain = p_new.entropy() - p_theta.entropy()
        entropy_gains_from_reweighting.append(entropy_gain)

        for k in sorted(epoch_stats.keys()):
            logger.record_tabular(k, epoch_stats[k])

        logger.record_tabular("Epoch", epoch)
        logger.record_tabular('Entropy ', p_theta.entropy())
        logger.record_tabular('KL from uniform', p_theta.kl_from_uniform())
        logger.record_tabular('TV to uniform', p_theta.tv_to_uniform())
        logger.record_tabular('Entropy gain from reweight', entropy_gain)
        logger.record_tabular('Total Time (s)', time.time() - start)
        logger.dump_tabular()
        logger.save_itr_params(epoch, {
            'vae': vae,
            'train_data': train_data,
            'vae_samples': vae_samples,
            'dynamics': dynamics,
        })

    report.add_header("Training Curves")
    plot_curves(
        [
            ("Training Loss", losses),
            ("KL", kls),
            ("Log Probs", log_probs),
            ("Entropy Gain from Reweighting", entropy_gains_from_reweighting),
        ],
        report,
    )
    plot_curves(
        [
            ("Entropy", entropies),
            ("TV to Uniform", tvs_to_uniform),
        ],
        report,
    )
    report.add_text("Max entropy: {}".format(p_theta.max_entropy()))
    report.save()

    for filename, imgs in [
        ("hist_heatmaps", hist_heatmap_imgs),
        ("vae_heatmaps", vae_heatmap_imgs),
        ("samples", sample_imgs),
    ]:
        video = np.stack(imgs)
        vwrite(
            logger.get_snapshot_dir() + '/{}.mp4'.format(filename),
            video,
        )
        local_gif_file_path = '{}.gif'.format(filename)
        gif_file_path = '{}/{}'.format(
            logger.get_snapshot_dir(),
            local_gif_file_path
        )
        gif(gif_file_path, video)
        report.add_image(local_gif_file_path, txt=filename, is_url=True)
    report.save()
Example #16
def experiment(variant):
    num_rollouts = variant['num_rollouts']
    path = variant['qf_path']
    data = joblib.load(path)
    goal_conditioned_model = data['qf']
    env = data['env']
    argmax_qf_policy = data['policy']
    extra_data_path = Path(path).parent / 'extra_data.pkl'
    extra_data = joblib.load(str(extra_data_path))
    replay_buffer = extra_data['replay_buffer']
    """
    Train amortized policy
    """
    # goal_chooser = Mlp(
    #     output_size=env.goal_dim,
    #     input_size=int(env.observation_space.flat_dim),
    #     hidden_sizes=[100, 100],
    # )
    # goal_chooser = ReacherGoalChooser(
    #     hidden_sizes=[64, 64],
    # )
    goal_chooser = UniversalGoalChooser(input_goal_dim=7,
                                        output_goal_dim=env.goal_dim,
                                        obs_dim=int(
                                            env.observation_space.flat_dim),
                                        **variant['goal_chooser_params'])
    tau = variant['tau']
    if ptu.gpu_enabled():
        goal_chooser.to(ptu.device)
        goal_conditioned_model.to(ptu.device)
        argmax_qf_policy.to(ptu.device)
    train_amortized_goal_chooser(goal_chooser, goal_conditioned_model,
                                 argmax_qf_policy, tau, replay_buffer,
                                 **variant['train_params'])
    policy = AmortizedPolicy(argmax_qf_policy, goal_chooser)

    goal = np.array(variant['goal'])
    logger.save_itr_params(
        0, dict(
            env=env,
            policy=policy,
            goal_chooser=goal_chooser,
            goal=goal,
        ))
    """
    Eval policy.
    """
    paths = []
    # env.set_goal(goal)
    for _ in range(num_rollouts):
        # path = rollout(
        #     env,
        #     policy,
        #     **variant['rollout_params']
        # )
        # goal_expanded = np.expand_dims(goal, axis=0)
        # path['goal_states'] = goal_expanded.repeat(len(path['observations']), 0)
        goal = env.sample_goal_for_rollout()
        path = multitask_rollout(env, policy, goal,
                                 **variant['rollout_params'])
        paths.append(path)
    env.log_diagnostics(paths)
    logger.dump_tabular(with_timestamp=False)
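Both create_policy (Example #10) and the experiment function above save a one-off snapshot with save_itr_params(0, ...) before rolling out the policy and dumping environment diagnostics. A condensed sketch of that save-then-evaluate flow, assuming a rollout helper and an env exposing log_diagnostics as in those examples:

# Sketch: save an initial policy/env snapshot, collect rollouts, log diagnostics.
# `rollout_fn` stands in for the rollout/multitask_rollout helpers used above.
from railrl.core import logger

def save_and_evaluate(env, policy, rollout_fn, num_rollouts, max_path_length):
    logger.save_itr_params(0, dict(env=env, policy=policy))
    paths = [
        rollout_fn(env, policy, max_path_length=max_path_length)
        for _ in range(num_rollouts)
    ]
    env.log_diagnostics(paths)
    logger.dump_tabular(with_timestamp=False)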
Example #17
    def pretrain_policy_with_bc(self):
        logger.remove_tabular_output('progress.csv',
                                     relative_to_snapshot_dir=True)
        logger.add_tabular_output('pretrain_policy.csv',
                                  relative_to_snapshot_dir=True)
        for i in range(self.bc_num_pretrain_steps):
            train_batch = self.get_batch_from_buffer(self.demo_train_buffer)
            train_o = train_batch["observations"]
            train_u = train_batch["actions"]
            if self.goal_conditioned:
                train_g = train_batch["resampled_goals"]
                train_o = torch.cat((train_o, train_g), dim=1)

            train_pred_u = self.policy(train_o)
            train_error = (train_pred_u - train_u)**2
            train_bc_loss = train_error.mean()

            policy_loss = self.bc_weight * train_bc_loss.mean()

            self.policy_optimizer.zero_grad()
            policy_loss.backward()
            self.policy_optimizer.step()

            test_batch = self.get_batch_from_buffer(self.demo_test_buffer)
            test_o = test_batch["observations"]
            test_u = test_batch["actions"]

            if self.goal_conditioned:
                test_g = test_batch["resampled_goals"]
                test_o = torch.cat((test_o, test_g), dim=1)

            test_pred_u = self.policy(test_o)

            test_error = (test_pred_u - test_u)**2
            test_bc_loss = test_error.mean()

            train_loss_mean = np.mean(ptu.get_numpy(train_bc_loss))
            test_loss_mean = np.mean(ptu.get_numpy(test_bc_loss))

            stats = {
                "Train BC Loss": train_loss_mean,
                "Test BC Loss": test_loss_mean,
                "policy_loss": ptu.get_numpy(policy_loss),
                "batch": i,
            }
            logger.record_dict(stats)
            logger.dump_tabular(with_prefix=True, with_timestamp=False)

            if i % 1000 == 0:
                logger.save_itr_params(
                    i, {
                        "evaluation/policy": self.policy,
                        "evaluation/env": self.env.wrapped_env,
                    })

        logger.remove_tabular_output(
            'pretrain_policy.csv',
            relative_to_snapshot_dir=True,
        )
        logger.add_tabular_output(
            'progress.csv',
            relative_to_snapshot_dir=True,
        )
Example #18
def train_vae(variant, return_data=False):
    from railrl.misc.ml_util import PiecewiseLinearSchedule, ConstantSchedule
    from railrl.torch.vae.conv_vae import (
        ConvVAE,
        SpatialAutoEncoder,
        AutoEncoder,
    )
    import railrl.torch.vae.conv_vae as conv_vae
    from railrl.torch.vae.vae_trainer import ConvVAETrainer
    from railrl.core import logger
    import railrl.torch.pytorch_util as ptu
    from railrl.pythonplusplus import identity
    import torch
    beta = variant["beta"]
    representation_size = variant.get("representation_size",
                                      variant.get("latent_sizes", None))
    use_linear_dynamics = variant.get('use_linear_dynamics', False)
    generate_vae_dataset_fctn = variant.get('generate_vae_data_fctn',
                                            generate_vae_dataset)
    variant['generate_vae_dataset_kwargs'][
        'use_linear_dynamics'] = use_linear_dynamics
    variant['generate_vae_dataset_kwargs']['batch_size'] = variant[
        'algo_kwargs']['batch_size']
    train_dataset, test_dataset, info = generate_vae_dataset_fctn(
        variant['generate_vae_dataset_kwargs'])

    if use_linear_dynamics:
        action_dim = train_dataset.data['actions'].shape[2]

    logger.save_extra_data(info)
    logger.get_snapshot_dir()
    if 'beta_schedule_kwargs' in variant:
        beta_schedule = PiecewiseLinearSchedule(
            **variant['beta_schedule_kwargs'])
    else:
        beta_schedule = None
    if 'context_schedule' in variant:
        schedule = variant['context_schedule']
        if type(schedule) is dict:
            context_schedule = PiecewiseLinearSchedule(**schedule)
        else:
            context_schedule = ConstantSchedule(schedule)
        variant['algo_kwargs']['context_schedule'] = context_schedule
    if variant.get('decoder_activation', None) == 'sigmoid':
        decoder_activation = torch.nn.Sigmoid()
    else:
        decoder_activation = identity
    architecture = variant['vae_kwargs'].get('architecture', None)
    if not architecture and variant.get('imsize') == 84:
        architecture = conv_vae.imsize84_default_architecture
    elif not architecture and variant.get('imsize') == 48:
        architecture = conv_vae.imsize48_default_architecture
    variant['vae_kwargs']['architecture'] = architecture
    variant['vae_kwargs']['imsize'] = variant.get('imsize')

    if variant['algo_kwargs'].get('is_auto_encoder', False):
        model = AutoEncoder(representation_size,
                            decoder_output_activation=decoder_activation,
                            **variant['vae_kwargs'])
    elif variant.get('use_spatial_auto_encoder', False):
        model = SpatialAutoEncoder(
            representation_size,
            decoder_output_activation=decoder_activation,
            **variant['vae_kwargs'])
    else:
        vae_class = variant.get('vae_class', ConvVAE)
        if use_linear_dynamics:
            model = vae_class(representation_size,
                              decoder_output_activation=decoder_activation,
                              action_dim=action_dim,
                              **variant['vae_kwargs'])
        else:
            model = vae_class(representation_size,
                              decoder_output_activation=decoder_activation,
                              **variant['vae_kwargs'])
    model.to(ptu.device)

    vae_trainer_class = variant.get('vae_trainer_class', ConvVAETrainer)
    trainer = vae_trainer_class(model,
                                beta=beta,
                                beta_schedule=beta_schedule,
                                **variant['algo_kwargs'])
    save_period = variant['save_period']

    dump_skew_debug_plots = variant.get('dump_skew_debug_plots', False)
    for epoch in range(variant['num_epochs']):
        should_save_imgs = (epoch % save_period == 0)
        trainer.train_epoch(epoch, train_dataset)
        trainer.test_epoch(epoch, test_dataset)

        if should_save_imgs:
            trainer.dump_reconstructions(epoch)
            trainer.dump_samples(epoch)
            if dump_skew_debug_plots:
                trainer.dump_best_reconstruction(epoch)
                trainer.dump_worst_reconstruction(epoch)
                trainer.dump_sampling_histogram(epoch)

        stats = trainer.get_diagnostics()
        for k, v in stats.items():
            logger.record_tabular(k, v)
        logger.dump_tabular()
        trainer.end_epoch(epoch)

        if epoch % 50 == 0:
            logger.save_itr_params(epoch, model)
    logger.save_extra_data(model, 'vae.pkl', mode='pickle')
    if return_data:
        return model, train_dataset, test_dataset
    return model
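Besides per-iteration snapshots, several of the VAE examples also persist non-tabular artifacts with save_extra_data: the dataset info dict right after data generation, and the trained model as a pickle once training finishes. A short sketch of those two call forms as used in Examples #3, #12, and #18; the info and model arguments are placeholders.

# Sketch: the two save_extra_data forms used by the VAE training examples above.
from railrl.core import logger

def save_artifacts(info, model):
    # Saved under the run's snapshot directory (Example #16 later reloads
    # such a file as extra_data.pkl).
    logger.save_extra_data(info)
    # Explicit filename and pickle mode, as in the VAE examples' final save.
    logger.save_extra_data(model, 'vae.pkl', mode='pickle')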