Example No. 1
    def update(self, aug_obs):
        obs = aug_obs

        enc_features = self.enc(obs)
        mu = self.enc_mu(enc_features)
        logvar = self.enc_logvar(enc_features)

        stds = (0.5 * logvar).exp()
        epsilon = ptu.randn(*mu.size())
        code = epsilon * stds + mu

        kle = -0.5 * torch.sum(
            1 + logvar - mu.pow(2) - logvar.exp(), dim=1
        ).mean()

        obs_distribution_params = self.dec(code)
        # 'elementwise_mean' is deprecated; current PyTorch uses 'mean'
        log_prob = -1. * F.mse_loss(obs, obs_distribution_params,
                                    reduction='mean')

        loss = self.beta * kle - log_prob

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return loss.cpu().item()
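
For reference, the loss assembled in update above is a standard beta-VAE objective with a unit-variance Gaussian decoder (so the reconstruction log-likelihood reduces to a negative MSE up to constants), with the usual closed-form KL term against a standard normal prior:

$$\mathcal{L}(x) = \beta \, \mathrm{KL}\big(q_\phi(z \mid x) \,\|\, \mathcal{N}(0, I)\big) - \log p_\theta(x \mid z), \qquad \mathrm{KL} = -\tfrac{1}{2} \sum_j \big(1 + \log \sigma_j^2 - \mu_j^2 - \sigma_j^2\big)$$

where $$z = \mu + \sigma \odot \epsilon$$ and $$\epsilon \sim \mathcal{N}(0, I)$$ is the reparameterization noise drawn by ptu.randn (assumed here to be a device-aware wrapper around torch.randn, as in rlkit's pytorch_util).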
Example No. 2
    def debug_statistics(self):
        """
        Given an image $$x$$, sample a batch of latents $$z_i$$ from the prior
        and decode them into $$\hat x_i$$.
        Compare these to $$\hat x$$, the reconstruction of $$x$$.
        Ideally
         - All the $$\hat x_i$$s do worse than $$\hat x$$ (makes sure the VAE
           isn’t ignoring the latent)
         - Some $$\hat x_i$$ do better than other $$\hat x_i$$ (tests for
           coverage)
        """
        debug_batch_size = 64
        data = self.get_batch(train=False)
        reconstructions, _, _ = self.model(data)
        img = data[0]
        recon_mse = ((reconstructions[0] - img)**2).mean().view(-1)
        img_repeated = img.expand((debug_batch_size, img.shape[0]))

        samples = ptu.randn(debug_batch_size, self.representation_size)
        random_imgs, _ = self.model.decode(samples)
        random_mses = (random_imgs - img_repeated)**2
        mse_improvement = ptu.get_numpy(random_mses.mean(dim=1) - recon_mse)
        stats = create_stats_ordered_dict(
            'debug/MSE improvement over random',
            mse_improvement,
        )
        stats.update(
            create_stats_ordered_dict(
                'debug/MSE of random decoding',
                ptu.get_numpy(random_mses),
            ))
        stats['debug/MSE of reconstruction'] = ptu.get_numpy(recon_mse)[0]
        return stats
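
A minimal usage sketch for the statistics returned above, assuming a trainer instance exposing debug_statistics() and an rlkit-style tabular logger (the names vae_trainer, logger.record_tabular and logger.dump_tabular are illustrative, not taken from the source):

    # Hypothetical: merge the VAE debug stats into the per-epoch tabular log.
    stats = vae_trainer.debug_statistics()
    for key, value in stats.items():
        logger.record_tabular(key, value)
    logger.dump_tabular()

A healthy model should show a positive mean for 'debug/MSE improvement over random' (reconstructions beat prior samples) and a nonzero spread in 'debug/MSE of random decoding' (prior samples differ from one another).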
Example No. 3
    def load_dataset(self, dataset_path):
        dataset = load_local_or_remote_file(dataset_path)
        dataset = dataset.item()

        observations = dataset['observations']
        actions = dataset['actions']

        # dataset['observations'].shape # (2000, 50, 6912)
        # dataset['actions'].shape # (2000, 50, 2)
        # dataset['env'].shape # (2000, 6912)
        N, H, imlength = observations.shape

        self.vae.eval()
        for n in range(N):
            x0 = ptu.from_numpy(dataset['env'][n:n + 1, :] / 255.0)
            x = ptu.from_numpy(observations[n, :, :] / 255.0)
            latents = self.vae.encode(x, x0, distrib=False)

            r1, r2 = self.vae.latent_sizes
            conditioning = latents[0, r1:]
            goal = torch.cat(
                [ptu.randn(self.vae.latent_sizes[0]), conditioning])
            goal = ptu.get_numpy(goal)  # latents[-1, :]

            latents = ptu.get_numpy(latents)
            latent_delta = latents - goal
            distances = np.zeros((H - 1, 1))
            for i in range(H - 1):
                distances[i, 0] = np.linalg.norm(latent_delta[i + 1, :])

            terminals = np.zeros((H - 1, 1))
            # terminals[-1, 0] = 1
            path = dict(
                observations=[],
                actions=actions[n, :H - 1, :],
                next_observations=[],
                rewards=-distances,
                terminals=terminals,
            )

            for t in range(H - 1):
                # reward = -np.linalg.norm(latent_delta[i, :])

                obs = dict(
                    latent_observation=latents[t, :],
                    latent_achieved_goal=latents[t, :],
                    latent_desired_goal=goal,
                )
                next_obs = dict(
                    latent_observation=latents[t + 1, :],
                    latent_achieved_goal=latents[t + 1, :],
                    latent_desired_goal=goal,
                )

                path['observations'].append(obs)
                path['next_observations'].append(next_obs)

            # import ipdb; ipdb.set_trace()
            self.replay_buffer.add_path(path)
Example No. 4
 def dump_samples(self, epoch):
     self.model.eval()
     sample = ptu.randn(64, self.representation_size)
     sample = self.model.decode(sample)[0].cpu()
     save_dir = osp.join(self.log_dir, 's%d.png' % epoch)
     save_image(
         sample.data.view(64, self.input_channels, self.imsize,
                          self.imsize).transpose(2, 3), save_dir)
Example No. 5
 def dump_samples(self, epoch):
     self.model.eval()
     sample = ptu.randn(64, self.representation_size)
     sample = self.model.decode(sample)[0].cpu()
     save_dir = osp.join(logger.get_snapshot_dir(), 's%d.png' % epoch)
     save_image(
         sample.data.view(64, self.input_channels, self.imsize,
                          self.imsize), save_dir)
Example No. 6
 def get_encoding_and_suff_stats(self, x):
     output = self(x)
     z_dim = output.shape[1] // 2
     means, log_var = (output[:, :z_dim], output[:, z_dim:])
     stds = (0.5 * log_var).exp()
     epsilon = ptu.randn(means.shape)
     latents = epsilon * stds + means
     return latents, means, log_var, stds
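
A self-contained sketch of the same reparameterization using plain torch, assuming (as in rlkit's pytorch_util) that ptu.randn is simply a device-aware wrapper around torch.randn:

    import torch

    def reparameterize(output):
        # output: (batch, 2 * z_dim); first half are means, second half log-variances
        z_dim = output.shape[1] // 2
        means, log_var = output[:, :z_dim], output[:, z_dim:]
        stds = (0.5 * log_var).exp()
        epsilon = torch.randn_like(means)  # stands in for ptu.randn
        latents = epsilon * stds + means
        return latents, means, log_var, stds

    latents, means, log_var, stds = reparameterize(torch.randn(8, 10))  # toy encoder output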
Example No. 7
 def dump_samples(self, epoch, save_prefix='s'):
     self.model.eval()
     sample = ptu.randn(64, self.representation_size)
     sample = self.model.decode(sample)[0].cpu()
     save_dir = osp.join(logger.get_snapshot_dir(),
                         '{}{}.png'.format(save_prefix, epoch))
     save_image(
         sample.data.view(64, self.input_channels, self.imsize,
                          self.imsize).transpose(2, 3), save_dir)
Example No. 8
    def sample_prior(self, batch_size, x_0):
        if x_0.shape[0] == 1:
            x_0 = x_0.repeat(batch_size, 1)

        x_0 = x_0.reshape(-1, self.input_channels, self.imsize, self.imsize)
        z_cond, _, _, _ = self.netE.cond_encoder(x_0)
        z_delta = ptu.randn(batch_size, self.latent_size, 1, 1)
        cond_sample = torch.cat([z_delta, z_cond], dim=1)
        return cond_sample
Example No. 9
 def dump_samples(self, epoch):
     self.model.eval()
     sample = ptu.randn(64, self.representation_size)
     sample = self.model.decode(sample)[0].cpu()
     # save_dir = osp.join(logger.get_snapshot_dir(), 's%d.png' % epoch)
     # save_dir = osp.join('/mnt/manh/project/visual_RL_imaged_goal', 's%d.png' % epoch)
     project_path = osp.abspath(os.curdir)
     save_dir = osp.join(project_path, 'result_image',
                         's%d.png' % epoch)
     save_image(
         sample.data.view(64, self.input_channels, self.imsize,
                          self.imsize).transpose(2, 3), save_dir)
Example No. 10
    def sample_prior(self, batch_size, x_0, true_prior=True):
        if x_0.shape[0] == 1:
            x_0 = x_0.repeat(batch_size, 1)

        z_sample = ptu.randn(batch_size, self.latent_sizes[0])

        if not true_prior:
            stds = np.exp(0.5 * self.prior_logvar)
            z_sample = z_sample * stds + self.prior_mu

        conditioning = self.bn_c(self.c(self.dropout(self.cond_encoder(x_0))))
        cond_sample = torch.cat([z_sample, conditioning], dim=1)
        return cond_sample
Example No. 11
    def sample_prior(self, batch_size, cond=None, image_cond=True):
        if cond.shape[0] == 1:
            cond = cond.repeat(batch_size, axis=0)
        cond = ptu.from_numpy(cond)

        if image_cond:
            z_cond = self.encode_cond(batch_size, cond)
        else:
            z_cat = cond.reshape(batch_size, 2 * self.embedding_dim, self.root_len, self.root_len)
            z_cond = z_cat[:, self.embedding_dim:]

        z_delta = ptu.randn(batch_size, self.embedding_dim, self.root_len, self.root_len)
        z_cat = torch.cat([z_delta, z_cond], dim=1).view(-1, self.representation_size)
        
        return ptu.get_numpy(z_cat)
Example No. 12
    def compute_density(self, data):
        orig_data_length = len(data)
        data = np.vstack([data for _ in range(self.n_average)])
        data = ptu.from_numpy(data)
        if self.mode == 'biased':
            latents, means, log_vars, stds = (
                self.encoder.get_encoding_and_suff_stats(data))
            importance_weights = ptu.ones(data.shape[0])
        elif self.mode == 'prior':
            latents = ptu.randn(len(data), self.z_dim)
            importance_weights = ptu.ones(data.shape[0])
        elif self.mode == 'importance_sampling':
            latents, means, log_vars, stds = (
                self.encoder.get_encoding_and_suff_stats(data))
            prior = Normal(ptu.zeros(1), ptu.ones(1))
            prior_log_prob = prior.log_prob(latents).sum(dim=1)

            encoder_distrib = Normal(means, stds)
            encoder_log_prob = encoder_distrib.log_prob(latents).sum(dim=1)

            importance_weights = (prior_log_prob - encoder_log_prob).exp()
        else:
            raise NotImplementedError()

        unweighted_data_log_prob = self.compute_log_prob(
            data, self.decoder, latents).squeeze(1)
        unweighted_data_prob = unweighted_data_log_prob.exp()
        unnormalized_data_prob = unweighted_data_prob * importance_weights
        """
        Average over `n_average`
        """
        dp_split = torch.split(unnormalized_data_prob, orig_data_length, dim=0)
        # dp_stacked.shape = ORIG_LEN x N_AVERAGE
        dp_stacked = torch.stack(dp_split, dim=1)
        # unnormalized_dp.shape = ORIG_LEN
        unnormalized_dp = torch.sum(dp_stacked, dim=1, keepdim=False)
        """
        Compute the importance weight denominators.
        This requires summing across the `n_average` dimension.
        """
        iw_split = torch.split(importance_weights, orig_data_length, dim=0)
        iw_stacked = torch.stack(iw_split, dim=1)
        iw_denominators = iw_stacked.sum(dim=1, keepdim=False)

        final = unnormalized_dp / iw_denominators
        return ptu.get_numpy(final)
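
The three modes above correspond to different Monte-Carlo estimates of $$p(x) = \int p(x \mid z)\, p(z)\, dz$$, assuming compute_log_prob returns the decoder log-likelihood $$\log p(x \mid z)$$. The 'importance_sampling' branch forms the self-normalized estimator

$$\hat p(x) = \frac{\sum_i w_i \, p(x \mid z_i)}{\sum_i w_i}, \qquad w_i = \frac{p(z_i)}{q_\phi(z_i \mid x)}, \quad z_i \sim q_\phi(z \mid x),$$

which is exactly the unnormalized_dp / iw_denominators ratio after summing over the n_average replicas of each datapoint. The 'prior' mode sets $$w_i = 1$$ and draws $$z_i \sim \mathcal{N}(0, I)$$ via ptu.randn, while the 'biased' mode sets $$w_i = 1$$ but still samples from the encoder.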
Example No. 13
    def sample_prior(self, batch_size, cond=None, image_cond=True):
        if cond.shape[0] == 1:
            cond = cond.repeat(batch_size, axis=0)
        cond = ptu.from_numpy(cond)

        if image_cond:
            cond = cond.reshape(-1, self.input_channels, self.imsize,
                                self.imsize)
            z_cond, _, _, _ = self.netE.cond_encoder(cond)
            z_cond = z_cond.reshape(-1, self.latent_size)
        else:
            z_cond = cond[:, self.latent_size:]

        z_delta = ptu.randn(batch_size, self.latent_size)
        cond_sample = torch.cat([z_delta, z_cond], dim=1)

        return ptu.get_numpy(cond_sample)
Example No. 14
    def dump_samples(self, epoch):
        self.model.eval()
        batch, _ = self.eval_data["test/last_batch"]
        sample = ptu.randn(64, self.representation_size)
        sample = self.model.decode(sample, batch["observations"])[0].cpu()
        save_dir = osp.join(self.log_dir, 's%d.png' % epoch)
        save_image(
            sample.data.view(64, 3, self.imsize, self.imsize).transpose(2, 3),
            save_dir)

        x0 = batch["x_0"]
        x0_img = x0[:64].narrow(start=0, length=self.imlength // 2,
                                dim=1).contiguous().view(
                                    -1, 3, self.imsize,
                                    self.imsize).transpose(2, 3)
        save_dir = osp.join(self.log_dir, 'x0_%d.png' % epoch)
        save_image(x0_img.data.cpu(), save_dir)
Example No. 15
    def get_output_for(self, aug_obs, sample=True):
        """
        Returns the log probability of the given observation.
        """
        obs = aug_obs
        with torch.no_grad():
            enc_features = self.enc(obs)
            mu = self.enc_mu(enc_features)
            logvar = self.enc_logvar(enc_features)

            stds = (0.5 * logvar).exp()
            if sample:
                epsilon = ptu.randn(*mu.size())
            else:
                # Note: all-ones epsilon evaluates the decoder at mu + stds
                # (one standard deviation from the mean), not at mu itself.
                epsilon = torch.ones_like(mu)
            code = epsilon * stds + mu

            obs_distribution_params = self.dec(code)
            log_prob = -1. * F.mse_loss(obs, obs_distribution_params,
                                        reduction='none')
            log_prob = torch.sum(log_prob, -1, keepdim=True)
        return log_prob.detach()
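
Note that the "log probability" returned above is the negative squared error summed over observation dimensions, i.e. a Gaussian decoder log-likelihood with fixed unit variance, up to additive constants and a factor of 1/2:

$$\log p(x \mid z) \;\propto\; -\sum_d \big(x_d - \hat x_d(z)\big)^2.$$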
Example No. 16
 def sample(self, num_samples):
     return ptu.get_numpy(
         self.sample_given_z(ptu.randn(num_samples, self.z_dim)))
Example No. 17
 def fixed_noise(self, b_size):
     return ptu.randn(b_size, self.representation_size, 1, 1)
Example No. 18
 def noise(self, size, num_epochs, epoch):
     noise = ptu.randn(size)
     std = 0.1 * (num_epochs - epoch) / num_epochs
     return std * noise
Example No. 19
 def fixed_noise(self, b_size, latent):
     z_cond = latent[:, self.model.latent_size:]
     z_delta = ptu.randn(b_size, self.model.latent_size, 1, 1)
     return torch.cat([z_delta, z_cond], dim=1)
Example No. 20
    def train_from_torch(self, batch):
        rewards = batch['rewards']
        terminals = batch['terminals']
        obs = batch['observations']
        actions = batch['actions']
        next_obs = batch['next_observations']

        """
        Critic operations.
        """

        next_actions = self.target_policy(next_obs)
        noise = ptu.randn(next_actions.shape) * self.target_policy_noise
        noise = torch.clamp(
            noise,
            -self.target_policy_noise_clip,
            self.target_policy_noise_clip
        )
        noisy_next_actions = next_actions + noise

        target_q1_values = self.target_qf1(next_obs, noisy_next_actions)
        target_q2_values = self.target_qf2(next_obs, noisy_next_actions)
        target_q_values = torch.min(target_q1_values, target_q2_values)
        q_target = self.reward_scale * rewards + (1. - terminals) * self.discount * target_q_values
        q_target = q_target.detach()

        q1_pred = self.qf1(obs, actions)
        bellman_errors_1 = (q1_pred - q_target) ** 2
        qf1_loss = bellman_errors_1.mean()

        q2_pred = self.qf2(obs, actions)
        bellman_errors_2 = (q2_pred - q_target) ** 2
        qf2_loss = bellman_errors_2.mean()

        """
        Update Networks
        """
        self.qf1_optimizer.zero_grad()
        qf1_loss.backward()
        self.qf1_optimizer.step()

        self.qf2_optimizer.zero_grad()
        qf2_loss.backward()
        self.qf2_optimizer.step()

        policy_actions = policy_loss = None
        if self._n_train_steps_total % self.policy_and_target_update_period == 0:
            policy_actions = self.policy(obs)
            q_output = self.qf1(obs, policy_actions)
            policy_loss = - q_output.mean()

            self.policy_optimizer.zero_grad()
            policy_loss.backward()
            self.policy_optimizer.step()

            ptu.soft_update_from_to(self.policy, self.target_policy, self.tau)
            ptu.soft_update_from_to(self.qf1, self.target_qf1, self.tau)
            ptu.soft_update_from_to(self.qf2, self.target_qf2, self.tau)

        if self._need_to_update_eval_statistics:
            self._need_to_update_eval_statistics = False
            if policy_loss is None:
                policy_actions = self.policy(obs)
                q_output = self.qf1(obs, policy_actions)
                policy_loss = - q_output.mean()

            self.eval_statistics['QF1 Loss'] = np.mean(ptu.get_numpy(qf1_loss))
            self.eval_statistics['QF2 Loss'] = np.mean(ptu.get_numpy(qf2_loss))
            self.eval_statistics['Policy Loss'] = np.mean(ptu.get_numpy(
                policy_loss
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Q1 Predictions',
                ptu.get_numpy(q1_pred),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Q2 Predictions',
                ptu.get_numpy(q2_pred),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Q Targets',
                ptu.get_numpy(q_target),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Bellman Errors 1',
                ptu.get_numpy(bellman_errors_1),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Bellman Errors 2',
                ptu.get_numpy(bellman_errors_2),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Policy Action',
                ptu.get_numpy(policy_actions),
            ))
        self._n_train_steps_total += 1
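
For reference, the critic target built above is the TD3 update with target policy smoothing and clipped double-Q learning:

$$a' = \pi_{\bar\phi}(s') + \epsilon, \quad \epsilon = \mathrm{clip}\big(\mathcal{N}(0, \sigma^2), -c, c\big), \qquad y = r_{\text{scale}}\, r + \gamma\, (1 - d)\, \min\big(Q_{\bar\theta_1}(s', a'), Q_{\bar\theta_2}(s', a')\big),$$

where ptu.randn supplies the smoothing noise, $$\sigma$$ is target_policy_noise, $$c$$ is target_policy_noise_clip, and the policy and target networks are updated only every policy_and_target_update_period steps. (Unlike the variant in the later td4.py example, this version does not clamp the smoothed action back into the action bounds.)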
Example No. 21
all_imgs = [
    x0.narrow(start=0, length=imlength,
              dim=1).contiguous().view(-1, 3, imsize, imsize).transpose(2, 3),
]
comparison = torch.cat(all_imgs)
save_dir = "/home/ashvin/data/s3doodad/share/multiobj/%sx0.png" % prefix
save_image(comparison.data.cpu(), save_dir, nrow=8)

vae = load_local_or_remote_file(vae_path).to("cpu")
vae.eval()

model = vae
all_imgs = []
for i in range(N_ROWS):
    latent = ptu.randn(
        n,
        model.representation_size)  # model.sample_prior(self.batch_size, env)
    samples = model.decode(latent)[0]
    all_imgs.extend([samples.view(
        n,
        3,
        imsize,
        imsize,
    )[:n].transpose(2, 3)])
comparison = torch.cat(all_imgs)
save_dir = "/home/ashvin/data/s3doodad/share/multiobj/%svae_samples.png" % prefix
save_image(comparison.data.cpu(), save_dir, nrow=8)

cvae = load_local_or_remote_file(cvae_path).to("cpu")
cvae.eval()
Example No. 22
File: td4.py Project: xtma/dsac
    def train_from_torch(self, batch):
        rewards = batch['rewards']
        terminals = batch['terminals']
        obs = batch['observations']
        actions = batch['actions']
        next_obs = batch['next_observations']
        gt.stamp('preback_start', unique=False)
        """
        Update QF
        """
        with torch.no_grad():
            next_actions = self.target_policy(next_obs)
            noise = ptu.randn(next_actions.shape) * self.target_policy_noise
            noise = torch.clamp(noise, -self.target_policy_noise_clip,
                                self.target_policy_noise_clip)
            noisy_next_actions = torch.clamp(next_actions + noise,
                                             -self.max_action, self.max_action)

            next_tau, next_tau_hat, next_presum_tau = self.get_tau(
                next_obs, noisy_next_actions, fp=self.target_fp)
            target_z1_values = self.target_zf1(next_obs, noisy_next_actions,
                                               next_tau_hat)
            target_z2_values = self.target_zf2(next_obs, noisy_next_actions,
                                               next_tau_hat)
            target_z_values = torch.min(target_z1_values, target_z2_values)
            z_target = self.reward_scale * rewards + (
                1. - terminals) * self.discount * target_z_values

        tau, tau_hat, presum_tau = self.get_tau(obs, actions, fp=self.fp)
        z1_pred = self.zf1(obs, actions, tau_hat)
        z2_pred = self.zf2(obs, actions, tau_hat)
        zf1_loss = self.zf_criterion(z1_pred, z_target, tau_hat,
                                     next_presum_tau)
        zf2_loss = self.zf_criterion(z2_pred, z_target, tau_hat,
                                     next_presum_tau)
        gt.stamp('preback_zf', unique=False)

        self.zf1_optimizer.zero_grad()
        zf1_loss.backward()
        self.zf1_optimizer.step()
        gt.stamp('backward_zf1', unique=False)

        self.zf2_optimizer.zero_grad()
        zf2_loss.backward()
        self.zf2_optimizer.step()
        gt.stamp('backward_zf2', unique=False)
        """
        Update FP
        """
        if self.tau_type == 'fqf':
            with torch.no_grad():
                dWdtau = 0.5 * (2 * self.zf1(obs, actions, tau[:, :-1]) -
                                z1_pred[:, :-1] - z1_pred[:, 1:] +
                                2 * self.zf2(obs, actions, tau[:, :-1]) -
                                z2_pred[:, :-1] - z2_pred[:, 1:])
                dWdtau /= dWdtau.shape[0]  # (N, T-1)
            gt.stamp('preback_fp', unique=False)
            self.fp_optimizer.zero_grad()
            tau[:, :-1].backward(gradient=dWdtau)
            self.fp_optimizer.step()
            gt.stamp('backward_fp', unique=False)
        """
        Policy Loss
        """
        policy_actions = self.policy(obs)
        risk_param = self.risk_schedule(self._n_train_steps_total)

        if self.risk_type == 'VaR':
            tau_ = ptu.ones_like(rewards) * risk_param
            q_new_actions = self.zf1(obs, policy_actions, tau_)
        else:
            with torch.no_grad():
                new_tau, new_tau_hat, new_presum_tau = self.get_tau(
                    obs, policy_actions, fp=self.fp)
            z_new_actions = self.zf1(obs, policy_actions, new_tau_hat)
            if self.risk_type in ['neutral', 'std']:
                q_new_actions = torch.sum(new_presum_tau * z_new_actions,
                                          dim=1,
                                          keepdims=True)
                if self.risk_type == 'std':
                    q_std = new_presum_tau * (z_new_actions -
                                              q_new_actions).pow(2)
                    q_new_actions -= risk_param * q_std.sum(
                        dim=1, keepdims=True).sqrt()
            else:
                with torch.no_grad():
                    risk_weights = distortion_de(new_tau_hat, self.risk_type,
                                                 risk_param)
                q_new_actions = torch.sum(risk_weights * new_presum_tau *
                                          z_new_actions,
                                          dim=1,
                                          keepdims=True)

        policy_loss = -q_new_actions.mean()

        gt.stamp('preback_policy', unique=False)

        if self._n_train_steps_total % self.policy_and_target_update_period == 0:
            self.policy_optimizer.zero_grad()
            policy_loss.backward()
            policy_grad = ptu.fast_clip_grad_norm(self.policy.parameters(),
                                                  self.clip_norm)
            self.policy_optimizer.step()
            gt.stamp('backward_policy', unique=False)

            ptu.soft_update_from_to(self.policy, self.target_policy,
                                    self.soft_target_tau)
            ptu.soft_update_from_to(self.zf1, self.target_zf1,
                                    self.soft_target_tau)
            ptu.soft_update_from_to(self.zf2, self.target_zf2,
                                    self.soft_target_tau)
            if self.tau_type == 'fqf':
                ptu.soft_update_from_to(self.fp, self.target_fp,
                                        self.soft_target_tau)
        gt.stamp('soft_update', unique=False)
        """
        Save some statistics for eval
        """
        if self._need_to_update_eval_statistics:
            self._need_to_update_eval_statistics = False
            """
            Eval should set this to None.
            This way, these statistics are only computed for one batch.
            """

            self.eval_statistics['ZF1 Loss'] = zf1_loss.item()
            self.eval_statistics['ZF2 Loss'] = zf2_loss.item()
            self.eval_statistics['Policy Loss'] = policy_loss.item()
            self.eval_statistics['Policy Grad'] = policy_grad
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Z1 Predictions',
                    ptu.get_numpy(z1_pred),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Z2 Predictions',
                    ptu.get_numpy(z2_pred),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Z Targets',
                    ptu.get_numpy(z_target),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Policy Action',
                    ptu.get_numpy(policy_actions),
                ))

        self._n_train_steps_total += 1
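
In the distributional variant above, the 'neutral' policy objective appears to reduce to the quantile-weighted expectation that the code computes as torch.sum(new_presum_tau * z_new_actions, ...):

$$Q(s, a) \;\approx\; \sum_i \delta_i \, Z_{\hat\tau_i}(s, a),$$

where the $$\delta_i$$ (presum_tau) are the probability masses between consecutive quantile fractions and $$Z_{\hat\tau_i}$$ are the critic's quantile values at the midpoints $$\hat\tau_i$$; the other risk types reweight the same sum with a distortion derivative (distortion_de) or, for VaR, read off a single quantile.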
Example No. 23
 def sample_prior(self, batch_size):
     z_s = ptu.randn(batch_size, self.representation_size)
     return ptu.get_numpy(z_s)
Example No. 24
 def dump_samples(self, epoch):
     self.model.eval()
     sample = ptu.randn(64, self.representation_size)
     sample = self.model.decode(sample)
     save_dir = osp.join(self.log_dir, 's%d.png' % epoch)
     save_image(sample.data.transpose(2, 3), save_dir)
Example No. 25
 def rsample(self, latent_distribution_params):
     mu, logvar = latent_distribution_params
     stds = (0.5 * logvar).exp()
     epsilon = ptu.randn(*mu.size())
     latents = epsilon * stds + mu
     return latents