Example #1
    def debug_statistics(self):
        """
        Given an image $$x$$, samples a bunch of latents $$z_i$$ from the
        prior and decodes them into $$\hat x_i$$.
        Compares these to $$\hat x$$, the reconstruction of $$x$$.
        Ideally
         - All the $$\hat x_i$$ do worse than $$\hat x$$ (makes sure the VAE
           isn't ignoring the latent)
         - Some $$\hat x_i$$ do better than other $$\hat x_i$$ (tests for
           coverage)
        """
        debug_batch_size = 64
        data = self.get_batch(train=False)
        reconstructions, _, _ = self.model(data)
        img = data[0]
        recon_mse = ((reconstructions[0] - img)**2).mean().view(-1)
        img_repeated = img.expand((debug_batch_size, img.shape[0]))

        samples = ptu.randn(debug_batch_size, self.representation_size)
        random_imgs, _ = self.model.decode(samples)
        random_mses = (random_imgs - img_repeated)**2
        mse_improvement = ptu.get_numpy(random_mses.mean(dim=1) - recon_mse)
        stats = create_stats_ordered_dict(
            'debug/MSE improvement over random',
            mse_improvement,
        )
        stats.update(
            create_stats_ordered_dict(
                'debug/MSE of random decoding',
                ptu.get_numpy(random_mses),
            ))
        stats['debug/MSE of reconstruction'] = ptu.get_numpy(recon_mse)[0]
        return stats
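For reference, the statistic logged above as 'debug/MSE improvement over random' is

$$\Delta_i = \overline{(\hat x_i - x)^2} - \overline{(\hat x - x)^2}, \qquad \hat x_i = \operatorname{decode}(z_i), \quad z_i \sim \mathcal N(0, I),$$

so the two docstring checks amount to: all $$\Delta_i > 0$$ (the decoder is not ignoring the latent) and the $$\Delta_i$$ vary across samples (the prior has coverage).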
Example #2
    def load_dataset(self, dataset_path):
        dataset = load_local_or_remote_file(dataset_path)
        dataset = dataset.item()

        observations = dataset['observations']
        actions = dataset['actions']

        # dataset['observations'].shape # (2000, 50, 6912)
        # dataset['actions'].shape # (2000, 50, 2)
        # dataset['env'].shape # (2000, 6912)
        N, H, imlength = observations.shape

        self.vae.eval()
        for n in range(N):
            x0 = ptu.from_numpy(dataset['env'][n:n + 1, :] / 255.0)
            x = ptu.from_numpy(observations[n, :, :] / 255.0)
            latents = self.vae.encode(x, x0, distrib=False)

            r1, r2 = self.vae.latent_sizes
            conditioning = latents[0, r1:]
            goal = torch.cat(
                [ptu.randn(self.vae.latent_sizes[0]), conditioning])
            goal = ptu.get_numpy(goal)  # latents[-1, :]

            latents = ptu.get_numpy(latents)
            latent_delta = latents - goal
            distances = np.zeros((H - 1, 1))
            for i in range(H - 1):
                distances[i, 0] = np.linalg.norm(latent_delta[i + 1, :])

            terminals = np.zeros((H - 1, 1))
            # terminals[-1, 0] = 1
            path = dict(
                observations=[],
                actions=actions[n, :H - 1, :],
                next_observations=[],
                rewards=-distances,
                terminals=terminals,
            )

            for t in range(H - 1):
                # reward = -np.linalg.norm(latent_delta[i, :])

                obs = dict(
                    latent_observation=latents[t, :],
                    latent_achieved_goal=latents[t, :],
                    latent_desired_goal=goal,
                )
                next_obs = dict(
                    latent_observation=latents[t + 1, :],
                    latent_achieved_goal=latents[t + 1, :],
                    latent_desired_goal=goal,
                )

                path['observations'].append(obs)
                path['next_observations'].append(next_obs)

            # import ipdb; ipdb.set_trace()
            self.replay_buffer.add_path(path)
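Written out, the reward stored for each transition above is the negative Euclidean distance in latent space between the next observation's encoding and the sampled goal,

$$r_t = -\,\bigl\lVert z_{t+1} - z_{\text{goal}} \bigr\rVert_2, \qquad z_{\text{goal}} = \bigl[\varepsilon,\ c\bigr], \quad \varepsilon \sim \mathcal N(0, I),$$

where $$c$$ is the conditioning block of the latent (dimensions $$r_1$$ onward) copied from the first timestep's encoding, and the goal's stochastic block is resampled from the prior.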
Example #3
    def dump_samples(self, epoch):
        self.model.eval()
        sample = ptu.randn(64, self.representation_size)
        sample = self.model.decode(sample)[0].cpu()
        save_dir = osp.join(self.log_dir, 's%d.png' % epoch)
        save_image(
            sample.data.view(64, self.input_channels, self.imsize,
                             self.imsize).transpose(2, 3), save_dir)
Example #4
    def get_encoding_and_suff_stats(self, x):
        output = self(x)
        z_dim = output.shape[1] // 2
        means, log_var = (
            output[:, :z_dim], output[:, z_dim:]
        )
        stds = (0.5 * log_var).exp()
        epsilon = ptu.randn(means.shape)
        latents = epsilon * stds + means
        return latents, means, log_var, stds
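The last three lines are the standard reparameterization trick: with the network outputting $$\mu$$ and $$\log \sigma^2$$, a latent is sampled as

$$z = \mu + \sigma \odot \varepsilon, \qquad \sigma = \exp\bigl(\tfrac{1}{2} \log \sigma^2\bigr), \quad \varepsilon \sim \mathcal N(0, I),$$

which keeps $$z$$ differentiable with respect to $$\mu$$ and $$\log \sigma^2$$ (the same step appears again in `rsample` in Example #8).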
Example #5
    def compute_density(self, data):
        orig_data_length = len(data)
        data = np.vstack([
            data for _ in range(self.n_average)
        ])
        data = ptu.from_numpy(data)
        if self.mode == 'biased':
            latents, means, log_vars, stds = (
                self.encoder.get_encoding_and_suff_stats(data)
            )
            importance_weights = ptu.ones(data.shape[0])
        elif self.mode == 'prior':
            latents = ptu.randn(len(data), self.z_dim)
            importance_weights = ptu.ones(data.shape[0])
        elif self.mode == 'importance_sampling':
            latents, means, log_vars, stds = (
                self.encoder.get_encoding_and_suff_stats(data)
            )
            prior = Normal(ptu.zeros(1), ptu.ones(1))
            prior_log_prob = prior.log_prob(latents).sum(dim=1)

            encoder_distrib = Normal(means, stds)
            encoder_log_prob = encoder_distrib.log_prob(latents).sum(dim=1)

            importance_weights = (prior_log_prob - encoder_log_prob).exp()
        else:
            raise NotImplementedError()

        unweighted_data_log_prob = self.compute_log_prob(
            data, self.decoder, latents
        ).squeeze(1)
        unweighted_data_prob = unweighted_data_log_prob.exp()
        unnormalized_data_prob = unweighted_data_prob * importance_weights
        """
        Average over `n_average`
        """
        dp_split = torch.split(unnormalized_data_prob, orig_data_length, dim=0)
        # dp_stacked.shape = (ORIG_LEN, N_AVERAGE)
        dp_stacked = torch.stack(dp_split, dim=1)
        # unnormalized_dp.shape = (ORIG_LEN,)
        unnormalized_dp = torch.sum(dp_stacked, dim=1, keepdim=False)

        """
        Compute the importance weight denominators.
        This requires summing across the `n_average` dimension.
        """
        iw_split = torch.split(importance_weights, orig_data_length, dim=0)
        iw_stacked = torch.stack(iw_split, dim=1)
        iw_denominators = iw_stacked.sum(dim=1, keepdim=False)

        final = unnormalized_dp / iw_denominators
        return ptu.get_numpy(final)
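In the 'importance_sampling' mode this returns the self-normalized importance-sampling estimate of the marginal likelihood, with the encoder as the proposal distribution:

$$p(x) \approx \frac{\sum_{k=1}^{K} w_k\, p_\theta(x \mid z_k)}{\sum_{k=1}^{K} w_k}, \qquad w_k = \frac{p(z_k)}{q_\phi(z_k \mid x)}, \quad z_k \sim q_\phi(z \mid x),$$

with $$K$$ equal to `n_average`. The 'biased' and 'prior' modes use the same formula with $$w_k = 1$$ and latents drawn from the encoder or from the prior, respectively.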
Example #6
    def sample_prior(self, batch_size, x_0, true_prior=True):
        if x_0.shape[0] == 1:
            x_0 = x_0.repeat(batch_size, 1)

        z_sample = ptu.randn(batch_size, self.latent_sizes[0])

        if not true_prior:
            stds = np.exp(0.5 * self.prior_logvar)
            z_sample = z_sample * stds + self.prior_mu

        conditioning = self.bn_c(self.c(self.dropout(self.cond_encoder(x_0))))
        cond_sample = torch.cat([z_sample, conditioning], dim=1)
        return cond_sample
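The returned batch concatenates a latent block, drawn either from the standard normal prior or from a learned Gaussian prior, with a deterministic conditioning vector computed from $$x_0$$:

$$z_{\text{cond}} = \bigl[z,\ c(x_0)\bigr], \qquad z \sim \begin{cases} \mathcal N(0, I) & \text{if true\_prior} \\ \mathcal N\bigl(\mu_p,\ \operatorname{diag}(e^{\log \sigma_p^2})\bigr) & \text{otherwise,} \end{cases}$$

where $$c(x_0)$$ abbreviates the conditioning branch (`cond_encoder` → dropout → `self.c` → `bn_c`).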
Example #7
    def dump_samples(self, epoch):
        self.model.eval()
        batch, _ = self.eval_data["test/last_batch"]
        sample = ptu.randn(64, self.representation_size)
        sample = self.model.decode(sample, batch["observations"])[0].cpu()
        save_dir = osp.join(self.log_dir, 's%d.png' % epoch)
        save_image(
            sample.data.view(64, 3, self.imsize, self.imsize).transpose(2, 3),
            save_dir)

        x0 = batch["x_0"]
        x0_img = x0[:64].narrow(start=0, length=self.imlength // 2,
                                dim=1).contiguous().view(
                                    -1, 3, self.imsize,
                                    self.imsize).transpose(2, 3)
        save_dir = osp.join(self.log_dir, 'x0_%d.png' % epoch)
        save_image(x0_img.data.cpu(), save_dir)
Example #8
    def rsample(self, latent_distribution_params):
        mu, logvar = latent_distribution_params
        stds = (0.5 * logvar).exp()
        epsilon = ptu.randn(*mu.size())
        latents = epsilon * stds + mu
        return latents
Example #9
    def sample(self, num_samples):
        return ptu.get_numpy(self.sample_given_z(
            ptu.randn(num_samples, self.z_dim)
        ))
Example #10
    def train_from_torch(self, batch):
        logger.push_tabular_prefix("train_q/")
        self.eval_statistics = dict()
        self._need_to_update_eval_statistics = True

        rewards = batch['rewards']
        terminals = batch['terminals']
        obs = batch['observations']
        actions = batch['actions']
        next_obs = batch['next_observations']

        next_actions = self.target_policy(next_obs)
        noise = ptu.randn(next_actions.shape) * self.target_policy_noise
        noise = torch.clamp(noise, -self.target_policy_noise_clip,
                            self.target_policy_noise_clip)
        noisy_next_actions = next_actions + noise

        target_q1_values = self.target_qf1(next_obs, noisy_next_actions)
        target_q2_values = self.target_qf2(next_obs, noisy_next_actions)
        target_q_values = torch.min(target_q1_values, target_q2_values)
        q_target = self.reward_scale * rewards + (
            1. - terminals) * self.discount * target_q_values
        q_target = q_target.detach()

        q1_pred = self.qf1(obs, actions)
        bellman_errors_1 = (q1_pred - q_target)**2
        qf1_loss = bellman_errors_1.mean()

        q2_pred = self.qf2(obs, actions)
        bellman_errors_2 = (q2_pred - q_target)**2
        qf2_loss = bellman_errors_2.mean()
        """
        Update Networks
        """
        self.qf1_optimizer.zero_grad()
        qf1_loss.backward()
        self.qf1_optimizer.step()

        self.qf2_optimizer.zero_grad()
        qf2_loss.backward()
        self.qf2_optimizer.step()

        policy_actions = policy_loss = None
        if self._n_train_steps_total % self.policy_update_period == 0:
            policy_actions = self.policy(obs)
            q_output = self.qf1(obs, policy_actions)

            if self.demo_train_buffer._size >= self.bc_batch_size:
                if self.use_demo_awr:
                    train_batch = self.get_batch_from_buffer(
                        self.demo_train_buffer)
                    train_o = train_batch["observations"]
                    train_u = train_batch["actions"]
                    if self.goal_conditioned:
                        train_g = train_batch["resampled_goals"]
                        train_o = torch.cat((train_o, train_g), dim=1)
                    train_pred_u = self.policy(train_o)
                    train_error = (train_pred_u - train_u)**2
                    train_bc_loss = train_error.mean()

                    policy_q_output_demo_state = self.qf1(
                        train_o, train_pred_u)
                    demo_q_output = self.qf1(train_o, train_u)

                    advantage = demo_q_output - policy_q_output_demo_state
                    self.eval_statistics['Train BC Loss'] = np.mean(
                        ptu.get_numpy(train_bc_loss))

                    if self.awr_policy_update:
                        train_bc_loss = (train_error * torch.exp(
                            (advantage) * self.demo_beta))
                        self.eval_statistics['Advantage'] = np.mean(
                            ptu.get_numpy(advantage))

                    if self._n_train_steps_total < self.max_steps_till_train_rl:
                        rl_weight = 0
                    else:
                        rl_weight = self.rl_weight

                    policy_loss = -rl_weight * q_output.mean(
                    ) + self.bc_weight * train_bc_loss.mean()
                else:
                    train_batch = self.get_batch_from_buffer(
                        self.demo_train_buffer)

                    train_o = train_batch["observations"]
                    # train_pred_u = self.policy(train_o)
                    if self.goal_conditioned:
                        train_g = train_batch["resampled_goals"]
                        train_o = torch.cat((train_o, train_g), dim=1)
                    train_pred_u = self.policy(train_o)
                    train_u = train_batch["actions"]
                    train_error = (train_pred_u - train_u)**2
                    train_bc_loss = train_error.mean()

                    # Advantage-weighted regression
                    policy_error = (policy_actions - actions)**2
                    policy_error = policy_error.mean(dim=1)
                    advantage = q1_pred - q_output
                    weights = F.softmax((advantage / self.beta)[:, 0], dim=0)

                    if self.awr_policy_update:
                        policy_loss = self.rl_weight * (
                            policy_error * weights.detach() *
                            self.bc_batch_size).mean()
                    else:
                        policy_loss = -self.rl_weight * q_output.mean(
                        ) + self.bc_weight * train_bc_loss.mean()

                    self.eval_statistics.update(
                        create_stats_ordered_dict(
                            'Advantage Weights',
                            ptu.get_numpy(weights),
                        ))

                self.eval_statistics['BC Loss'] = np.mean(
                    ptu.get_numpy(train_bc_loss))

            else:  # Normal TD3 update
                policy_loss = -self.rl_weight * q_output.mean()

            if self.update_policy and not (self.rl_weight == 0
                                           and self.bc_weight == 0):
                self.policy_optimizer.zero_grad()
                policy_loss.backward()
                self.policy_optimizer.step()

            self.eval_statistics['Policy Loss'] = np.mean(
                ptu.get_numpy(policy_loss))

        if self._n_train_steps_total % self.target_update_period == 0:
            ptu.soft_update_from_to(self.policy, self.target_policy, self.tau)
            ptu.soft_update_from_to(self.qf1, self.target_qf1, self.tau)
            ptu.soft_update_from_to(self.qf2, self.target_qf2, self.tau)

        if self._need_to_update_eval_statistics:
            self._need_to_update_eval_statistics = False
            if policy_loss is None:
                policy_actions = self.policy(obs)
                q_output = self.qf1(obs, policy_actions)
                policy_loss = -q_output.mean()

            self.eval_statistics['QF1 Loss'] = np.mean(ptu.get_numpy(qf1_loss))
            self.eval_statistics['QF2 Loss'] = np.mean(ptu.get_numpy(qf2_loss))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q1 Predictions',
                    ptu.get_numpy(q1_pred),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q2 Predictions',
                    ptu.get_numpy(q2_pred),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q Targets',
                    ptu.get_numpy(q_target),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Bellman Errors 1',
                    ptu.get_numpy(bellman_errors_1),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Bellman Errors 2',
                    ptu.get_numpy(bellman_errors_2),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Policy Action',
                    ptu.get_numpy(policy_actions),
                ))

            if self.demo_test_buffer._size >= self.bc_batch_size:
                train_batch = self.get_batch_from_buffer(
                    self.demo_train_buffer)
                train_u = train_batch["actions"]
                train_o = train_batch["observations"]
                if self.goal_conditioned:
                    train_g = train_batch["resampled_goals"]
                    train_o = torch.cat((train_o, train_g), dim=1)
                train_pred_u = self.policy(train_o)
                train_error = (train_pred_u - train_u)**2
                train_bc_loss = train_error

                policy_q_output_demo_state = self.qf1(train_o, train_pred_u)
                demo_q_output = self.qf1(train_o, train_u)

                train_advantage = demo_q_output - policy_q_output_demo_state

                test_batch = self.get_batch_from_buffer(self.demo_test_buffer)
                test_o = test_batch["observations"]
                test_u = test_batch["actions"]
                if self.goal_conditioned:
                    test_g = test_batch["resampled_goals"]
                    test_o = torch.cat((test_o, test_g), dim=1)
                test_pred_u = self.policy(test_o)
                test_error = (test_pred_u - test_u)**2
                test_bc_loss = test_error

                policy_q_output_demo_state = self.qf1(test_o, test_pred_u)
                demo_q_output = self.qf1(test_o, test_u)
                test_advantage = demo_q_output - policy_q_output_demo_state

                self.eval_statistics.update(
                    create_stats_ordered_dict(
                        'Train BC Loss',
                        ptu.get_numpy(train_bc_loss),
                    ))

                self.eval_statistics.update(
                    create_stats_ordered_dict(
                        'Train Demo Advantage',
                        ptu.get_numpy(train_advantage),
                    ))

                self.eval_statistics.update(
                    create_stats_ordered_dict(
                        'Test BC Loss',
                        ptu.get_numpy(test_bc_loss),
                    ))

                self.eval_statistics.update(
                    create_stats_ordered_dict(
                        'Test Demo Advantage',
                        ptu.get_numpy(test_advantage),
                    ))

        self._n_train_steps_total += 1

        logger.pop_tabular_prefix()
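Two formulas carry most of the logic above. The critic update uses the standard TD3 clipped double-Q target with smoothed target actions,

$$y = c_r\, r + (1 - d)\, \gamma\, \min_{j \in \{1, 2\}} Q'_j\bigl(s',\ \pi'(s') + \operatorname{clip}(\epsilon, -c, c)\bigr), \qquad \epsilon \sim \mathcal N(0, \sigma^2),$$

where $$c_r$$ is `reward_scale`, $$\sigma$$ is `target_policy_noise`, and $$c$$ is `target_policy_noise_clip`. In the advantage-weighted branch, the per-sample behavioral-cloning error is reweighted by a softmax over scaled advantages,

$$w_i = \operatorname{softmax}_i\bigl(A_i / \beta\bigr), \qquad A_i = Q_1(s_i, a_i) - Q_1\bigl(s_i, \pi(s_i)\bigr),$$

so actions with higher advantage relative to the current policy receive proportionally more weight in the policy loss.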
Example #11
all_imgs = [
    x0.narrow(start=0, length=imlength,
              dim=1).contiguous().view(-1, 3, imsize, imsize).transpose(2, 3),
]
comparison = torch.cat(all_imgs)
save_dir = "/home/ashvin/data/s3doodad/share/multiobj/%sx0.png" % prefix
save_image(comparison.data.cpu(), save_dir, nrow=8)

vae = load_local_or_remote_file(vae_path).to("cpu")
vae.eval()

model = vae
all_imgs = []
for i in range(N_ROWS):
    latent = ptu.randn(
        n,
        model.representation_size)  # model.sample_prior(self.batch_size, env)
    samples = model.decode(latent)[0]
    all_imgs.extend([samples.view(
        n,
        3,
        imsize,
        imsize,
    )[:n].transpose(2, 3)])
comparison = torch.cat(all_imgs)
save_dir = "/home/ashvin/data/s3doodad/share/multiobj/%svae_samples.png" % prefix
save_image(comparison.data.cpu(), save_dir, nrow=8)

cvae = load_local_or_remote_file(cvae_path).to("cpu")
cvae.eval()