Example No. 1
 def denormalize(self, v):
     mean = ptu.from_numpy(self.mean)
     std = ptu.from_numpy(self.std)
     if v.dim() == 2:
         mean = mean.unsqueeze(0)
         std = std.unsqueeze(0)
     return mean + v * std
Example No. 2
 def get_debug_batch(self, train=True):
     dataset = self.train_dataset if train else self.test_dataset
     X, Y = dataset
     ind = np.random.randint(0, Y.shape[0], self.batch_size)
     X = X[ind, :]
     Y = Y[ind, :]
     return ptu.from_numpy(X), ptu.from_numpy(Y)
Example No. 3
 def normalize(self, v, clip_range=None):
     if clip_range is None:
         clip_range = self.default_clip_range
     mean = ptu.from_numpy(self.mean)
     std = ptu.from_numpy(self.std)
     if v.dim() == 2:
         # Unsqueeze along the batch dimension and rely on automatic broadcasting.
         mean = mean.unsqueeze(0)
         std = std.unsqueeze(0)
     return torch.clamp((v - mean) / std, -clip_range, clip_range)
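The normalize/denormalize pair above works by broadcasting (D,)-shaped mean and std statistics against an (N, D) batch. Below is a minimal, self-contained sketch of the same round trip, with plain torch.from_numpy standing in for ptu.from_numpy so it runs on its own:

import numpy as np
import torch

# (D,)-shaped statistics broadcast over an (N, D) batch after unsqueeze(0).
mean = torch.from_numpy(np.array([0.0, 1.0], dtype=np.float32)).unsqueeze(0)
std = torch.from_numpy(np.array([1.0, 2.0], dtype=np.float32)).unsqueeze(0)
v = torch.tensor([[1.0, 3.0], [2.0, 5.0]])

normalized = torch.clamp((v - mean) / std, -5.0, 5.0)  # as in normalize()
recovered = mean + normalized * std                    # as in denormalize()
assert torch.allclose(recovered, v)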
Example No. 4
    def get_batch(self, train=True, epoch=None):
        if self.use_parallel_dataloading:
            if not train:
                dataloader = self.test_dataloader
            else:
                dataloader = self.train_dataloader
            samples = next(dataloader).to(ptu.device)
            return samples

        dataset = self.train_dataset if train else self.test_dataset
        skew = False
        if epoch is not None:
            skew = (self.start_skew_epoch < epoch)
        if train and self.skew_dataset and skew:
            # Sample indices in proportion to the maintained per-example weights.
            probs = self._train_weights / np.sum(self._train_weights)
            ind = np.random.choice(
                len(probs),
                self.batch_size,
                p=probs,
            )
        else:
            ind = np.random.randint(0, len(dataset), self.batch_size)
        samples = normalize_image(dataset[ind, :])
        if self.normalize:
            # Shift mean-subtracted pixels from [-1, 1] back into [0, 1].
            samples = ((samples - self.train_data_mean) + 1) / 2
        if self.background_subtract:
            # Remove the mean training image as a static background.
            samples = samples - self.train_data_mean
        return ptu.from_numpy(samples)
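The skewed branch of get_batch draws indices in proportion to per-example weights via np.random.choice. A standalone sketch of just that sampling step (the weights here are made up for illustration):

import numpy as np

# Hypothetical per-example weights, standing in for self._train_weights.
train_weights = np.array([1.0, 1.0, 4.0, 4.0])
probs = train_weights / np.sum(train_weights)

batch_size = 1000
ind = np.random.choice(len(probs), batch_size, p=probs)
counts = np.bincount(ind, minlength=len(probs))
# Indices 2 and 3 appear roughly four times as often as indices 0 and 1.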
Example No. 5
 def get_dataset_stats(self, data):
     torch_input = ptu.from_numpy(normalize_image(data))
     mus, log_vars = self.model.encode(torch_input)
     mus = ptu.get_numpy(mus)
     mean = np.mean(mus, axis=0)
     std = np.std(mus, axis=0)
     return mus, mean, std
Example No. 6
 def _reconstruct_img(self, flat_img):
     latent_distribution_params = self.vae.encode(
         ptu.from_numpy(flat_img.reshape(1, -1)))
     reconstructions, _ = self.vae.decode(latent_distribution_params[0])
     imgs = ptu.get_numpy(reconstructions)
     imgs = imgs.reshape(1, self.input_channels, self.imsize, self.imsize)
     return imgs[0]
Example No. 7
 def normalize_scale(self, v):
     """
     Only normalize the scale. Do not subtract the mean.
     """
     std = ptu.from_numpy(self.std)
     if v.dim() == 2:
         std = std.unsqueeze(0)
     return v / std
Example No. 8
    def random_vae_training_data(self, batch_size, epoch):
        # epoch no longer needed. Using self.skew in sample_weighted_indices
        # instead.
        weighted_idxs = self.sample_weighted_indices(batch_size)

        next_image_obs = normalize_image(
            self._next_obs[self.decoded_obs_key][weighted_idxs])
        return dict(next_obs=ptu.from_numpy(next_image_obs))
Example No. 9
 def denormalize_scale(self, v):
     """
     Only denormalize the scale. Do not add the mean.
     """
     std = ptu.from_numpy(self.std)
     if v.dim() == 2:
         std = std.unsqueeze(0)
     return v * std
Example No. 10
def compute_log_p_log_q_log_d(model,
                              data,
                              decoder_distribution='bernoulli',
                              num_latents_to_sample=1,
                              sampling_method='importance_sampling'):
    assert data.dtype == np.float64, 'images should be normalized'
    imgs = ptu.from_numpy(data)
    latent_distribution_params = model.encode(imgs)
    batch_size = data.shape[0]
    representation_size = model.representation_size
    log_p = ptu.zeros((batch_size, num_latents_to_sample))
    log_q = ptu.zeros((batch_size, num_latents_to_sample))
    log_d = ptu.zeros((batch_size, num_latents_to_sample))
    true_prior = Normal(ptu.zeros((batch_size, representation_size)),
                        ptu.ones((batch_size, representation_size)))
    mus, logvars = latent_distribution_params
    for i in range(num_latents_to_sample):
        if sampling_method == 'importance_sampling':
            # Draw latents from the approximate posterior; the importance weights
            # (log_p - log_q) are applied by the caller (see Example No. 11).
            latents = model.rsample(latent_distribution_params)
        elif sampling_method == 'biased_sampling':
            # Also draws from the posterior, but no reweighting is applied downstream.
            latents = model.rsample(latent_distribution_params)
        elif sampling_method == 'true_prior_sampling':
            # Draw latents directly from the N(0, I) prior.
            latents = true_prior.rsample()
        else:
            raise EnvironmentError('Invalid Sampling Method Provided')

        stds = logvars.exp().pow(.5)
        vae_dist = Normal(mus, stds)
        log_p_z = true_prior.log_prob(latents).sum(dim=1)
        log_q_z_given_x = vae_dist.log_prob(latents).sum(dim=1)
        if decoder_distribution == 'bernoulli':
            decoded = model.decode(latents)[0]
            # Per-pixel Bernoulli log-likelihood, summed over the image.
            log_d_x_given_z = torch.log(imgs * decoded + (1 - imgs) *
                                        (1 - decoded) + 1e-8).sum(dim=1)
        elif decoder_distribution == 'gaussian_identity_variance':
            _, obs_distribution_params = model.decode(latents)
            dec_mu, dec_logvar = obs_distribution_params
            dec_var = dec_logvar.exp()
            decoder_dist = Normal(dec_mu, dec_var.pow(.5))
            log_d_x_given_z = decoder_dist.log_prob(imgs).sum(dim=1)
        else:
            raise EnvironmentError('Invalid Decoder Distribution Provided')

        log_p[:, i] = log_p_z
        log_q[:, i] = log_q_z_given_x
        log_d[:, i] = log_d_x_given_z
    return log_p, log_q, log_d
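For intuition on how these three tensors are combined: log_loss_under_uniform in Example No. 11 uses log_d.mean() under prior and biased sampling and (log_p - log_q + log_d).mean() under importance sampling. The standard importance-weighted estimate of log p(x) is instead the log-mean-exp of log_p - log_q + log_d over the latent samples. The following toy check is purely illustrative (not part of the codebase above) and uses a model whose marginal is known in closed form, so the estimate can be verified exactly:

import math
import torch
from torch.distributions import Normal

# Toy model with a closed-form marginal: z ~ N(0, 1), x | z ~ N(z, 1)  =>  x ~ N(0, 2).
x = torch.tensor([0.7])
true_log_px = Normal(0.0, math.sqrt(2.0)).log_prob(x)

prior = Normal(torch.zeros(1), torch.ones(1))
posterior = Normal(x / 2.0, torch.tensor([math.sqrt(0.5)]))  # exact q(z | x) for this model

num_latents = 64
z = posterior.rsample((num_latents,))                         # shape (num_latents, 1)
log_p = prior.log_prob(z).sum(dim=1)                          # log p(z)
log_q = posterior.log_prob(z).sum(dim=1)                      # log q(z | x)
log_d = Normal(z, torch.ones_like(z)).log_prob(x).sum(dim=1)  # log p(x | z)

# Importance-weighted estimate: log-mean-exp of the per-sample log weights.
log_px_estimate = torch.logsumexp(log_p - log_q + log_d, dim=0) - math.log(num_latents)
assert torch.allclose(log_px_estimate, true_log_px, atol=1e-4)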
Example No. 11
    def log_loss_under_uniform(self, model, data, priority_function_kwargs):
        import torch.nn.functional as F
        log_probs_prior = []
        log_probs_biased = []
        log_probs_importance = []
        kles = []
        mses = []
        for i in range(0, data.shape[0], self.batch_size):
            img = normalize_image(data[i:min(data.shape[0], i +
                                             self.batch_size), :])
            torch_img = ptu.from_numpy(img)
            reconstructions, obs_distribution_params, latent_distribution_params = self.model(
                torch_img)

            priority_function_kwargs['sampling_method'] = 'true_prior_sampling'
            log_p, log_q, log_d = compute_log_p_log_q_log_d(
                model, img, **priority_function_kwargs)
            log_prob_prior = log_d.mean()

            priority_function_kwargs['sampling_method'] = 'biased_sampling'
            log_p, log_q, log_d = compute_log_p_log_q_log_d(
                model, img, **priority_function_kwargs)
            log_prob_biased = log_d.mean()

            priority_function_kwargs['sampling_method'] = 'importance_sampling'
            log_p, log_q, log_d = compute_log_p_log_q_log_d(
                model, img, **priority_function_kwargs)
            log_prob_importance = (log_p - log_q + log_d).mean()

            kle = model.kl_divergence(latent_distribution_params)
            mse = F.mse_loss(torch_img,
                             reconstructions,
                             reduction='mean')
            mses.append(mse.item())
            kles.append(kle.item())
            log_probs_prior.append(log_prob_prior.item())
            log_probs_biased.append(log_prob_biased.item())
            log_probs_importance.append(log_prob_importance.item())

        logger.record_tabular("Uniform Data Log Prob (True Prior)",
                              np.mean(log_probs_prior))
        logger.record_tabular("Uniform Data Log Prob (Biased)",
                              np.mean(log_probs_biased))
        logger.record_tabular("Uniform Data Log Prob (Importance)",
                              np.mean(log_probs_importance))
        logger.record_tabular("Uniform Data KL", np.mean(kles))
        logger.record_tabular("Uniform Data MSE", np.mean(mses))
Example No. 12
 def _update_info(self, info, obs):
     latent_distribution_params = self.vae.encode(
         ptu.from_numpy(obs[self.vae_input_observation_key].reshape(1, -1)))
     latent_obs, logvar = ptu.get_numpy(
         latent_distribution_params[0])[0], ptu.get_numpy(
             latent_distribution_params[1])[0]
     # assert (latent_obs == obs['latent_observation']).all()
     latent_goal = self.desired_goal['latent_desired_goal']
     dist = latent_goal - latent_obs
     var = np.exp(logvar.flatten())
     var = np.maximum(var, self.reward_min_variance)
     err = dist * dist / 2 / var
     mdist = np.sum(err)  # mahalanobis distance
     info["vae_mdist"] = mdist
     info["vae_success"] = 1 if mdist < self.epsilon else 0
     info["vae_dist"] = np.linalg.norm(dist, ord=self.norm_order)
     info["vae_dist_l1"] = np.linalg.norm(dist, ord=1)
     info["vae_dist_l2"] = np.linalg.norm(dist, ord=2)
Example No. 13
    def _dump_imgs_and_reconstructions(self, idxs, filename):
        imgs = []
        recons = []
        for i in idxs:
            img_np = self.train_dataset[i]
            img_torch = ptu.from_numpy(normalize_image(img_np))
            recon, *_ = self.model(img_torch.view(1, -1))

            img = img_torch.view(self.input_channels, self.imsize,
                                 self.imsize).transpose(1, 2)
            rimg = recon.view(self.input_channels, self.imsize,
                              self.imsize).transpose(1, 2)
            imgs.append(img)
            recons.append(rimg)
        all_imgs = torch.stack(imgs + recons)
        save_file = osp.join(logger.get_snapshot_dir(), filename)
        save_image(
            all_imgs.data,
            save_file,
            nrow=len(idxs),
        )
Example No. 14
    def dump_uniform_imgs_and_reconstructions(self, dataset, epoch):
        idxs = np.random.choice(range(dataset.shape[0]), 4)
        filename = 'uniform{}.png'.format(epoch)
        imgs = []
        recons = []
        for i in idxs:
            img_np = dataset[i]
            img_torch = ptu.from_numpy(normalize_image(img_np))
            recon, *_ = self.model(img_torch.view(1, -1))

            img = img_torch.view(self.input_channels, self.imsize,
                                 self.imsize).transpose(1, 2)
            rimg = recon.view(self.input_channels, self.imsize,
                              self.imsize).transpose(1, 2)
            imgs.append(img)
            recons.append(rimg)
        all_imgs = torch.stack(imgs + recons)
        save_file = osp.join(logger.get_snapshot_dir(), filename)
        save_image(
            all_imgs.data,
            save_file,
            nrow=4,
        )
Example No. 15
def torch_ify(np_array_or_other):
    if isinstance(np_array_or_other, np.ndarray):
        return ptu.from_numpy(np_array_or_other)
    else:
        return np_array_or_other
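A self-contained sketch of the same dispatch, with torch.from_numpy standing in for ptu.from_numpy so it runs without the ptu module:

import numpy as np
import torch

def torch_ify_plain(np_array_or_other):
    # Same dispatch as torch_ify above: convert ndarrays, pass everything else through.
    if isinstance(np_array_or_other, np.ndarray):
        return torch.from_numpy(np_array_or_other)
    return np_array_or_other

t = torch_ify_plain(np.ones((2, 3), dtype=np.float32))  # ndarray -> tensor
assert isinstance(t, torch.Tensor)
assert torch_ify_plain(t) is t                          # non-arrays pass through unchanged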
Example No. 16
 def _decode(self, latents):
     reconstructions, _ = self.vae.decode(ptu.from_numpy(latents))
     decoded = ptu.get_numpy(reconstructions)
     return decoded
Example No. 17
 def _encode(self, imgs):
     latent_distribution_params = self.vae.encode(ptu.from_numpy(imgs))
     return ptu.get_numpy(latent_distribution_params[0])
Example No. 18
def _elem_or_tuple_to_variable(elem_or_tuple):
    if isinstance(elem_or_tuple, tuple):
        return tuple(_elem_or_tuple_to_variable(e) for e in elem_or_tuple)
    return ptu.from_numpy(elem_or_tuple).float()
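The helper recurses over tuples so that nested structures of arrays convert element-wise. A standalone sketch of the same pattern, again with plain torch.from_numpy in place of ptu.from_numpy:

import numpy as np
import torch

def to_float_tensor(elem_or_tuple):
    # Mirrors _elem_or_tuple_to_variable: recurse into tuples, convert array leaves.
    if isinstance(elem_or_tuple, tuple):
        return tuple(to_float_tensor(e) for e in elem_or_tuple)
    return torch.from_numpy(elem_or_tuple).float()

batch = (np.zeros((4, 3)), (np.ones(2), np.arange(5)))
converted = to_float_tensor(batch)
assert converted[1][1].dtype == torch.float32  # nested leaves become float tensors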