Example #1
    def forward(self, x, global_step):
        """
        Background inference forward pass
        
        :param x: shape (B, C, H, W)
        :param global_step: global training step
        :return:
            bg_likelihood: (B, 3, H, W)
            bg: (B, 3, H, W)
            kl_bg: (B,)
            log: a dictionary containing things for visualization
        """
        B = x.size(0)

        # (B, D)
        x_enc = self.image_enc(x)

        # Mask and component latents over the K slots
        masks = []
        z_masks = []
        # These two are Normal instances
        z_mask_posteriors = []
        z_comp_posteriors = []

        # Initialization: encode x and dummy z_mask_0
        z_mask = self.z_mask_0.expand(B, arch.z_mask_dim)
        h = self.rnn_mask_h.expand(B, arch.rnn_mask_hidden_dim)
        c = self.rnn_mask_c.expand(B, arch.rnn_mask_hidden_dim)

        for i in range(arch.K):
            # Encode x and z_{mask, 1:k}, (B, D)
            rnn_input = torch.cat((z_mask, x_enc), dim=1)
            (h, c) = self.rnn_mask(rnn_input, (h, c))

            # Predict next mask from x and z_{mask, 1:k-1}
            z_mask_loc, z_mask_scale = self.predict_mask(h)
            z_mask_post = Normal(z_mask_loc, z_mask_scale)
            z_mask = z_mask_post.rsample()
            z_masks.append(z_mask)
            z_mask_posteriors.append(z_mask_post)
            # Decode masks
            mask = self.mask_decoder(z_mask)
            masks.append(mask)

        # (B, K, 1, H, W), in range (0, 1)
        masks = torch.stack(masks, dim=1)

        # SBP to ensure they sum to 1
        masks = self.SBP(masks)
        # An alternative is to use softmax
        # masks = F.softmax(masks, dim=1)

        B, K, _, H, W = masks.size()

        # Reshape (B, K, 1, H, W) -> (B*K, 1, H, W)
        masks = masks.view(B * K, 1, H, W)

        # Concatenate images (B*K, 4, H, W)
        comp_vae_input = torch.cat(((masks + 1e-5).log(), x[:, None].repeat(
            1, K, 1, 1, 1).view(B * K, 3, H, W)),
                                   dim=1)

        # Component latents, each (B*K, L)
        z_comp_loc, z_comp_scale = self.comp_encoder(comp_vae_input)
        z_comp_post = Normal(z_comp_loc, z_comp_scale)
        z_comp = z_comp_post.rsample()

        # Record component posteriors here. We will use this for computing KL
        z_comp_loc_reshape = z_comp_loc.view(B, K, -1)
        z_comp_scale_reshape = z_comp_scale.view(B, K, -1)
        for i in range(arch.K):
            z_comp_post_this = Normal(z_comp_loc_reshape[:, i],
                                      z_comp_scale_reshape[:, i])
            z_comp_posteriors.append(z_comp_post_this)

        # Decode into component images, (B*K, 3, H, W)
        comps = self.comp_decoder(z_comp)

        # Reshape (B*K, ...) -> (B, K, 3, H, W)
        comps = comps.view(B, K, 3, H, W)
        masks = masks.view(B, K, 1, H, W)

        # Now we are ready to compute the background likelihoods
        # (B, K, 3, H, W)
        comp_dist = Normal(comps, torch.full_like(comps, self.bg_sigma))
        log_likelihoods = comp_dist.log_prob(x[:, None].expand_as(comps))

        # (B, K, 3, H, W) -> (B, 3, H, W), mixture likelihood
        log_sum = log_likelihoods + (masks + 1e-5).log()
        bg_likelihood = torch.logsumexp(log_sum, dim=1)

        # Background reconstruction
        bg = (comps * masks).sum(dim=1)

        # Below we compute priors and kls

        # Conditional KLs
        z_mask_total_kl = 0.0
        z_comp_total_kl = 0.0

        # Initial h and c. This is h_1 and c_1 in the paper
        h = self.rnn_mask_h_prior.expand(B, arch.rnn_mask_prior_hidden_dim)
        c = self.rnn_mask_c_prior.expand(B, arch.rnn_mask_prior_hidden_dim)

        for i in range(arch.K):
            # Compute prior distribution over z_masks
            z_mask_loc_prior, z_mask_scale_prior = self.predict_mask_prior(h)
            z_mask_prior = Normal(z_mask_loc_prior, z_mask_scale_prior)
            # Compute component prior, using posterior samples
            z_comp_loc_prior, z_comp_scale_prior = self.predict_comp_prior(
                z_masks[i])
            z_comp_prior = Normal(z_comp_loc_prior, z_comp_scale_prior)
            # Compute KLs as we go.
            z_mask_kl = kl_divergence(z_mask_posteriors[i],
                                      z_mask_prior).sum(dim=1)
            z_comp_kl = kl_divergence(z_comp_posteriors[i],
                                      z_comp_prior).sum(dim=1)
            # (B,)
            z_mask_total_kl += z_mask_kl
            z_comp_total_kl += z_comp_kl

            # Compute next state. Note we condition on posterior samples.
            # Again, this is a conditional prior.
            (h, c) = self.rnn_mask_prior(z_masks[i], (h, c))

        # For visualization
        kl_bg = z_mask_total_kl + z_comp_total_kl
        log = {
            # (B, K, 3, H, W)
            'comps': comps,
            # (B, K, 1, H, W)
            'masks': masks,
            # (B, 3, H, W)
            'bg': bg,
            'kl_bg': kl_bg
        }

        return bg_likelihood, bg, kl_bg, log
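
The `self.SBP(masks)` call above normalises the K raw masks so they sum to 1 over the slot dimension; the module itself is not shown in this snippet. A minimal stick-breaking sketch, under the assumption that each raw mask value in (0, 1) is the fraction of the remaining "stick" claimed by that slot (with the last slot absorbing the remainder), could look like:

import torch

def stick_breaking(masks):
    # masks: (B, K, 1, H, W), each entry in (0, 1); output sums to 1 over dim=1.
    B, K, _, H, W = masks.size()
    remaining = torch.ones(B, 1, 1, H, W, device=masks.device)
    slots = []
    for k in range(K):
        # Each slot claims a fraction of whatever stick is left; the last slot
        # takes the remainder so the K masks sum to one exactly.
        slot = masks[:, k:k + 1] * remaining if k < K - 1 else remaining
        slots.append(slot)
        remaining = remaining - slot
    return torch.cat(slots, dim=1)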
Example #2
def KL(normal_1, normal_2):
    kl = kl_divergence(normal_1, normal_2)
    kl = torch.mean(kl)
    return kl
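
A quick usage sketch for the helper above, assuming the standard `torch.distributions` imports (the distributions are illustrative):

import torch
from torch.distributions import Normal
from torch.distributions.kl import kl_divergence

p = Normal(torch.zeros(8, 4), torch.ones(8, 4))
q = Normal(torch.full((8, 4), 0.5), torch.full((8, 4), 2.0))
print(KL(p, q))  # scalar tensor: element-wise KL(p || q) averaged over all entries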
Example #3
    def forward(self, img, z_what_last_time, z_where_last_time,
                z_pres_last_time, hidden_last_time_temp):
        # img [B H W] input image
        # z_what_last_time [B numbers what_length]
        # z_where_last_time [B numbers where_length]
        # z_pres_last_time [B numbers pres_length]
        #
        # hidden_last_time_temp [B numbers hidden_size]
        n = img.size(0)
        numbers = z_what_last_time.size(1)
        #print("hidden_last_time_temp.size",hidden_last_time_temp.size())

        # initialise: per-object state tensors and their shapes
        #   h_rela, c_rela, h_temp, c_temp: (n, 1, 256)
        #   z_pres: (n, 1, 1), z_where: (n, 1, 3), z_what: (n, 1, 50)
        # torch.autograd.Variable is deprecated; plain tensors behave the same here.
        h_rela = torch.zeros(n, 1, 256, device=device)
        c_rela = torch.zeros(n, 1, 256, device=device)
        h_temp = torch.zeros(n, 1, 256, device=device)
        c_temp = torch.zeros(n, 1, 256, device=device)
        z_pres = torch.zeros(n, 1, 1, device=device)
        z_where = torch.zeros(n, 1, 3, device=device)
        z_what = torch.zeros(n, 1, 50, device=device)

        kl_z_what = torch.zeros(n, device=device)
        kl_z_where = torch.zeros(n, device=device)
        #print("numbers",numbers)
        for i in range(numbers):
            #print("i=",i)
            z_where_bias = self.prop_loca(z_where_last_time[:, i, :],
                                          hidden_last_time_temp[:,
                                                                i, :])  # [B 3]
            x_att_bias = attentive_stn_encode(z_where_bias,
                                              img)  # Spatial transform [B 400]
            encode_bias = self.glimpse_encoder(x_att_bias)  # [B 100]
            if (i != 0):
                h_rela_item, c_rela_item = relation_hidden_state(
                    self.relation_rnn, encode_bias, z_where_last_time[:, i, :],
                    z_where[:, i - 1, :], z_what_last_time[:, i, :],
                    z_what[:, i - 1, :], hidden_last_time_temp[:, i, :],
                    h_rela[:, i - 1, :], c_rela[:, i - 1, :])  # [B 1 256]
                h_rela = torch.cat((h_rela, h_rela_item), dim=1)
                c_rela = torch.cat((c_rela, c_rela_item), dim=1)
            elif (i == 0):
                #print("test2")
                h_rela, c_rela = relation_hidden_state(
                    self.relation_rnn, encode_bias, z_where_last_time[:, i, :],
                    z_where[:, i, :], z_what_last_time[:, i, :],
                    z_what[:, i, :], hidden_last_time_temp[:, i, :],
                    h_rela[:, i, :], c_rela[:, i, :])  # [B 1 256]
            #print("h_rela",h_rela.size())
            z_where_cal = torch.cat(
                (z_where_last_time[:, i, :], h_rela[:, i, :]), 1)  #[B 3+256]
            z_where_item, z_where_mean, z_where_std = self._reparameterized_sample_where(
                z_where_cal)  #[B 3]
            #print("z_where_item",z_where_item.size())
            x_att = attentive_stn_encode(z_where_item,
                                         img)  # Spatial transform [B 400]
            encode = self.glimpse_encoder(x_att)  # [B 100]

            h_temp_item, c_temp_item = temp_hidden_state(
                self.tem_rnn, encode, z_where[:, i - 1, :],
                hidden_last_time_temp[:, i, :], h_rela[:, i, :],
                c_rela[:, i, :])  # [B 1 256]
            if (i != 0):
                h_temp = torch.cat((h_temp, h_temp_item), dim=1)
                c_temp = torch.cat((c_temp, c_temp_item), dim=1)
            else:
                h_temp = h_temp_item
                c_temp = c_temp_item

            z_what_cal = torch.cat((z_what_last_time[:, i, :], h_rela[:, i, :],
                                    h_temp_item.squeeze(1)),
                                   1)  #[B 50+256+256]
            z_what_item, z_what_mean, z_what_std = self._reparameterized_sample_what(
                z_what_cal)  #[B 50]
            #print("z_what_item.shape",z_what_item.size())
            z_pres_cal = torch.cat((z_what_item, z_where_item, h_rela[:, i, :],
                                    h_temp_item.squeeze(1)),
                                   1)  #[B 50+3+256+256]
            #print("z_pres_cal.shape",z_pres_cal.size())
            z_pres_item = self._reparameterized_sample_pres(
                z_pres_cal, z_pres_last_time[:, i, :])  #[B 1]
            if (i == 0):
                z_pres = z_pres_item.unsqueeze(1)
                z_what = z_what_item.unsqueeze(1)
                z_where = z_where_item.unsqueeze(1)
            else:
                z_pres = torch.cat((z_pres, z_pres_item.unsqueeze(1)), dim=1)
                z_where = torch.cat((z_where, z_where_item.unsqueeze(1)),
                                    dim=1)
                z_what = torch.cat((z_what, z_what_item.unsqueeze(1)), dim=1)
            kl_z_what += kl_divergence(
                Normal(z_what_mean, z_what_std),
                Normal(torch.zeros(50).to(device),
                       torch.ones(50).to(device))).sum(
                           1) * z_pres_item.squeeze()  # [B]
            kl_z_where += kl_divergence(
                Normal(z_where_mean, z_where_std),
                Normal(
                    torch.tensor([0.3, 0., 0.]).to(device),
                    torch.tensor([
                        0.1, 1., 1.
                    ]).to(device))).sum(1) * z_pres_item.squeeze()  # [B]

        #print("z_pres_prop_shape",z_pres.size())
        #print("h_temp")
        return z_what, z_where, z_pres, kl_z_what, kl_z_where, h_temp  # z's: [B, numbers, *_length]; kl's: [B]; h_temp: [B, numbers, 256]
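
The `_reparameterized_sample_where`/`_what` helpers used above are not shown in this snippet. A generic sketch of such a reparameterised Gaussian head (layer names are illustrative, not the authors' exact module) might look like:

import torch
import torch.nn as nn

class ReparamGaussianHead(nn.Module):
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.mean = nn.Linear(in_dim, out_dim)
        self.log_std = nn.Linear(in_dim, out_dim)

    def forward(self, h):
        mean = self.mean(h)
        std = torch.exp(self.log_std(h))
        # Reparameterisation trick: sample = mean + eps * std keeps gradients
        # flowing through both mean and std.
        sample = mean + torch.randn_like(std) * std
        return sample, mean, std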
Example #4
    def train(self,
              x,
              epochs,
              batch_size,
              file,
              print_freq=50,
              x_test=None,
              means=None,
              stds=None,
              lr=0.001):
        """
        :param x: data tensor whose final n_properties columns are the target properties
        :param epochs: number of training epochs
        :param batch_size: number of rows sampled per epoch
        :param file: open file handle used for logging metrics
        :param print_freq: write metrics to file every print_freq epochs
        :param x_test: optional held-out data in the same format as x
        :param means: data means (stored as self.means)
        :param stds: data standard deviations (stored as self.stds)
        :param lr: learning rate for the Adam optimiser
        :return:
        """
        self.means = means
        self.stds = stds
        self.dir_name = os.path.dirname(file.name)
        self.file_start = file.name[len(self.dir_name) + 1:-4]

        optimiser = optim.Adam(
            list(self.encoder.parameters()) + list(self.decoder.parameters()),
            lr)

        self.epoch = 0
        for epoch in range(epochs):
            self.epoch = epoch
            optimiser.zero_grad()

            # Select a batch
            batch_idx = torch.randperm(x.shape[0])[:batch_size]
            x_batch = x[batch_idx, ...]  # [batch_size, x.shape[1]]
            target_batch = x_batch[:, -self.n_properties:]

            # Mask of the properties that are missing
            mask_batch = torch.isnan(x_batch[:, -self.n_properties:])

            # To form the context mask, additionally mark a random subset of the observed properties as missing
            mask_context = copy.deepcopy(mask_batch)
            batch_properties = [
                torch.where(~mask_batch[i, ...])[0]
                for i in range(mask_batch.shape[0])
            ]

            for i, properties in enumerate(batch_properties):
                ps = np.random.choice(properties.numpy(),
                                      size=np.random.randint(
                                          low=0, high=properties.shape[0] + 1),
                                      replace=False)

                # add property to those being masked
                mask_context[i, ps] = True
            input_batch = copy.deepcopy(x_batch)
            input_batch[:, -self.n_properties:][mask_batch] = 0.0

            # First use this to compute the posterior distribution over z
            mu_posts, var_posts = self.encoder(input_batch)

            context_batch = copy.deepcopy(input_batch)
            # Now set the property values of the context mask to 0 too.
            context_batch[:, -self.n_properties:][mask_context] = 0.0

            mu_priors, var_priors = self.encoder(context_batch)

            # Sample from the distribution over z.
            z = mu_priors + torch.randn_like(mu_priors) * var_priors**0.5

            mus_y, vars_y = self.decoder.forward(z, mask_batch)

            likelihood_term = 0

            for p in range(self.n_properties):
                target = target_batch[:, p][~mask_batch[:, p]]
                mu_y = mus_y[p].squeeze(1)
                var_y = vars_y[p].squeeze(1)

                ll = (-0.5 * np.log(2 * np.pi) - 0.5 * torch.log(var_y) - 0.5 *
                      ((target - mu_y)**2 / var_y))

                likelihood_term += torch.sum(ll)

            likelihood_term /= torch.sum(~mask_batch)

            # Compute the KL divergence between prior and posterior over z
            z_posteriors = [
                MultivariateNormal(mu, torch.diag_embed(var))
                for mu, var in zip(mu_posts, var_posts)
            ]
            z_priors = [
                MultivariateNormal(mu, torch.diag_embed(var))
                for mu, var in zip(mu_priors, var_priors)
            ]
            kl_div = [
                kl_divergence(z_posterior, z_prior).float()
                for z_posterior, z_prior in zip(z_posteriors, z_priors)
            ]
            kl_div = torch.sum(torch.stack(kl_div))
            kl_div /= torch.sum(~mask_batch)

            loss = kl_div - likelihood_term

            if (epoch % print_freq == 0) and (epoch > 0):
                file.write(
                    '\n Epoch {} Loss: {:4.4f} LL: {:4.4f} KL: {:4.4f}'.format(
                        epoch, loss.item(), likelihood_term.item(),
                        kl_div.item()))

                r2_scores, mlls, rmses = self.metrics_calculator(x,
                                                                 n_samples=100,
                                                                 test=False)
                r2_scores = np.array(r2_scores)
                mlls = np.array(mlls)
                rmses = np.array(rmses)

                file.write('\n R^2 score (train): {:.3f}+- {:.3f}'.format(
                    np.mean(r2_scores), np.std(r2_scores)))
                file.write('\n MLL (train): {:.3f}+- {:.3f} \n'.format(
                    np.mean(mlls), np.std(mlls)))
                file.write('\n RMSE (train): {:.3f}+- {:.3f} \n'.format(
                    np.mean(rmses), np.std(rmses)))
                file.flush()

                if x_test is not None:
                    r2_scores, mlls, rmses = self.metrics_calculator(
                        x_test, n_samples=100, test=True)
                    r2_scores = np.array(r2_scores)
                    mlls = np.array(mlls)
                    rmses = np.array(rmses)
                    file.write('\n R^2 score (test): {:.3f}+- {:.3f}'.format(
                        np.mean(r2_scores), np.std(r2_scores)))
                    file.write('\n MLL (test): {:.3f}+- {:.3f} \n'.format(
                        np.mean(mlls), np.std(mlls)))
                    file.write('\n RMSE (test): {:.3f}+- {:.3f} \n'.format(
                        np.mean(rmses), np.std(rmses)))

                    file.flush()
                    if (self.epoch % 1000) == 0 and (self.epoch > 0):
                        path_to_save = self.dir_name + '/' + self.file_start + '_' + str(
                            self.epoch)
                        np.save(path_to_save + 'r2_scores.npy', r2_scores)
                        np.save(path_to_save + 'mll_scores.npy', mlls)
                        np.save(path_to_save + 'rmse_scores.npy', rmses)

            loss.backward()

            optimiser.step()
Example #5
    def kl_divergence(self, datas1, datas2):
        distribution1 = Categorical(datas1)
        distribution2 = Categorical(datas2)

        return kl_divergence(distribution1,
                             distribution2).unsqueeze(1).float().to(device)

    # NOTE: this second definition has the same name as the one above and
    # shadows it, so only this Normal version remains callable on the class.
    def kl_divergence(self, mean1, std1, mean2, std2):
        distribution1   = Normal(mean1, std1)
        distribution2   = Normal(mean2, std2)

        return kl_divergence(distribution1, distribution2).float().to(device)  
Example #7
    def gaussian_kl2(self, m1, s1, m2, s2):
        # s1, s2 are log-variances: std = exp(0.5 * log_var).
        x = Normal(m1, torch.exp(0.5 * s1))
        y = Normal(m2, torch.exp(0.5 * s2))
        # Note the sign: this returns the *negative* KL(x || y).
        return -kl_divergence(x, y)
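
Given the `exp(0.5 * s)` factors, `s1`/`s2` look like log-variances, and the method returns the negative KL. A small sanity check against the closed-form Gaussian KL under that assumption (values chosen for illustration):

import torch
from torch.distributions import Normal
from torch.distributions.kl import kl_divergence

m1, s1 = torch.zeros(3), torch.zeros(3)                    # N(0, 1), log-variance 0
m2, s2 = torch.ones(3), torch.log(torch.full((3,), 4.0))   # N(1, 2^2)

p, q = Normal(m1, torch.exp(0.5 * s1)), Normal(m2, torch.exp(0.5 * s2))
closed_form = 0.5 * (s2 - s1 + (torch.exp(s1) + (m1 - m2) ** 2) / torch.exp(s2) - 1.0)
assert torch.allclose(kl_divergence(p, q), closed_form)    # gaussian_kl2 would return the negation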
Example #8
    def update(self, rollouts: Sequence[StepSequence]):
        # Turn the batch of rollouts into a list of steps
        concat_ros = StepSequence.concat(rollouts)
        concat_ros.torch(data_type=to.get_default_dtype())

        # Update the advantage estimator's parameters and return advantage estimates
        adv = self._critic.update(rollouts, use_empirical_returns=False)

        with to.no_grad():
            # Compute the action probabilities using the old (before update) policy
            act_stats = compute_action_statistics(concat_ros, self._expl_strat)
            log_probs_old = act_stats.log_probs
            act_distr_old = act_stats.act_distr

        # Attach advantages and old log probabilities to rollout
        concat_ros.add_data("adv", adv)
        concat_ros.add_data("log_probs_old", log_probs_old)

        # For logging the gradient norms
        policy_grad_norm = []

        # Iterations over the whole data set
        for e in range(self.num_epoch):

            for batch in tqdm(
                    concat_ros.split_shuffled_batches(
                        self.batch_size,
                        complete_rollouts=self._policy.is_recurrent),
                    total=num_iter_from_rollouts(None, concat_ros,
                                                 self.batch_size),
                    desc=f"Epoch {e}",
                    unit="batches",
                    file=sys.stdout,
                    leave=False,
            ):
                # Reset the gradients
                self.optim.zero_grad()

                # Compute log of the action probabilities for the mini-batch
                log_probs = compute_action_statistics(
                    batch, self._expl_strat).log_probs.to(self.policy.device)

                # Compute policy loss and backpropagate
                loss = self.loss_fcn(
                    log_probs, batch.log_probs_old.to(self.policy.device),
                    batch.adv.to(self.policy.device))
                loss.backward()

                # Clip the gradients if desired
                policy_grad_norm.append(
                    self.clip_grad(self._expl_strat.policy,
                                   self.max_grad_norm))

                # Call optimizer
                self.optim.step()

                if to.isnan(self._expl_strat.noise.std).any():
                    raise RuntimeError(
                        f"At least one exploration parameter became NaN! The exploration parameters are"
                        f"\n{self._expl_strat.std.detach().cpu().numpy()}")

            # Update the learning rate if a scheduler has been specified
            if self._lr_scheduler is not None:
                self._lr_scheduler.step()

        # Additional logging
        if self.log_loss:
            with to.no_grad():
                act_stats = compute_action_statistics(concat_ros,
                                                      self._expl_strat)
                log_probs_new = act_stats.log_probs
                act_distr_new = act_stats.act_distr
                loss_after = self.loss_fcn(log_probs_new, log_probs_old, adv)
                kl_avg = to.mean(kl_divergence(
                    act_distr_old,
                    act_distr_new))  # mean seeking a.k.a. inclusive KL
                self.logger.add_value("loss after", loss_after, 4)
                self.logger.add_value("KL(old_new)", kl_avg, 4)

        # Logging
        self.logger.add_value("avg expl strat std",
                              to.mean(self._expl_strat.noise.std), 4)
        self.logger.add_value("expl strat entropy",
                              self._expl_strat.noise.get_entropy(), 4)
        self.logger.add_value("avg grad norm policy",
                              np.mean(policy_grad_norm), 4)
        if self._lr_scheduler is not None:
            self.logger.add_value("avg lr",
                                  np.mean(self._lr_scheduler.get_last_lr()), 6)
Example #9
    def discover(self, x, x_embed, prev_relation):
        """
        Discover step for a single object in a single time step.
        
        This is basically the same as propagate, but without a temporal state
        input. However, to share the same Predict module, we will use an empty
        temporal state instead.
        
        There is some code duplication here; I keep it because refactoring
        would not give a good abstraction.
        
        Args:
            x: original image. Size (B, C, H, W)
            x_embed: extracted image feature. Size (B, N)
            prev_relation: see RelationState
            
        Returns:
            temporal_state: TemporalState
            relation_state:  RelationState
            kl: kl divergence for all z's. (B,)
            z_pres_likelihood: q(z_pres|x). (B,)
        """
        # First, encode relation info to get current h^{R, i}_t
        # Each being (B, N)
        h_rel, c_rel = self.rnn_relation(prev_relation.object.get_encoding(),
                                         (prev_relation.h, prev_relation.c))
        # (B, N)
        # Predict where and pres, using h^R, and x
        predict_input = torch.cat((h_rel, x_embed), dim=-1)
        # (B, 4), (B, 4), (B, 1)
        z_where_loc, z_where_scale, z_pres_prob = self.discover_predict(predict_input)

        # Sample from z_pres posterior. Shape (B, 1)
        # NOTE: don't use zero probability otherwise log q(z|x) will not work
        z_pres_post = Bernoulli(z_pres_prob)
        z_pres = z_pres_post.sample()
        # Mask z_pres. You don't have to do this for where and what because
        #   - their KL will be masked
        #   - they will be masked when computing the likelihood
        z_pres = z_pres * prev_relation.object.z_pres

        # Sample from z_where posterior, (B, 4)
        z_where_post = Normal(z_where_loc, z_where_scale)
        z_where = z_where_post.rsample()
        # Mask
        z_where = z_where * z_pres

        # Extract glimpse from x, shape (B, 1, H, W)
        glimpse = self.image_to_glimpse(x, z_where, inverse=True)

        # Compute the posterior distribution over z_what and sample
        z_what_loc, z_what_scale = self.discover_encoder(glimpse)
        z_what_post = Normal(z_what_loc, z_what_scale)
        # (B, N)
        z_what = z_what_post.rsample()
        # Mask
        z_what = z_what * z_pres

        # Construct prior distributions
        z_what_prior = Normal(arch.z_what_loc_prior, arch.z_what_scale_prior)
        z_where_prior = Normal(arch.z_where_loc_prior, arch.z_where_scale_prior)
        z_pres_prior = Bernoulli(arch.z_pres_prob_prior)

        # Compute KL divergence. Each (B, N)
        kl_z_what = kl_divergence(z_what_post, z_what_prior)
        kl_z_where = kl_divergence(z_where_post, z_where_prior)
        kl_z_pres = kl_divergence(z_pres_post, z_pres_prior)

        # Mask these terms.
        # Note for kl_z_pres, we will need to use z_pres of the previous time step.
        # This is because even when z_pres[t] = 0, p(z_pres[t] | z_pres[t-1])
        # cannot be ignored.

        # Also note that we do not mask where and what here. That's OK since we will
        # mask the image outside of this function later.

        kl_z_what = kl_z_what * z_pres
        kl_z_where = kl_z_where * z_pres
        kl_z_pres = kl_z_pres * prev_relation.object.z_pres
        
        vis_logger['kl_pres_list'].append(kl_z_pres.mean())
        vis_logger['kl_what_list'].append(kl_z_what.mean())
        vis_logger['kl_where_list'].append(kl_z_where.mean())
        
        # (B,) here, after reduction
        kl = kl_z_what.sum(dim=-1) + kl_z_where.sum(dim=-1) + kl_z_pres.sum(dim=-1)

        # Finally, we compute the discrete likelihoods.
        z_pres_likelihood = z_pres_post.log_prob(z_pres)

        # Note we also need to mask some of these terms since they do not depend
        # on the model parameters. (B, 1) here
        z_pres_likelihood = z_pres_likelihood * prev_relation.object.z_pres
        z_pres_likelihood = z_pres_likelihood.squeeze()
        
        
        # Compute id. If z_pres = 1, highest_id += 1, and use that id. Otherwise
        # we do not change the highest id and set id to zero
        self.highest_id += z_pres
        id = self.highest_id * z_pres

        B = x.size(0)
        # Collect terms into new states.
        object_state = ObjectState(
            z_pres, z_where, z_what, id=id, z_pres_prob=z_pres_prob,
            object_enc=glimpse.view(B, -1),
            mask=torch.zeros_like(glimpse.view(B, -1)),
            proposal=torch.zeros_like(z_where))
        
        # For temporal and prior state, we will use the initial state.
        
        temporal_state = TemporalState.get_initial_state(B, object_state)
        relation_state = RelationState(object_state, h_rel, c_rel)


        return temporal_state, relation_state, kl, z_pres_likelihood
Example #10
    def train(self, x_trains, y_trains, x_tests, y_tests, x_scalers, y_scalers, batch_size,
              iterations, testing, plotting, dataname, print_freq):
        """
        :param x_trains: A np.array with dimensions [N_functions, [N_train, x_size]]
                         containing the training data (x values)
        :param y_trains: A np.array with dimensions [N_functions, [N_train, y_size]]
                         containing the training data (y values)
        :param x_tests: A tensor with dimensions [N_functions, [N_test, x_size]]
                        containing the test data (x values)
        :param y_tests: A tensor with dimensions [N_functions, [N_test, y_size]]
                        containing the test data (y values)
        :param x_scalers: The standard scaler used when testing == True to convert the
                         x values back to the correct scale.
        :param y_scalers: The standard scaler used when testing == True to convert the predicted
                         y values back to the correct scale.
        :param batch_size: An integer describing the number of times we should
                           sample the set of context points used to form the
                           aggregated embedding during training, given the number
                           of context points to be sampled N_context. When testing
                           this is set to 1
        :param iterations: An integer, describing the number of iterations. In this case it
                           also corresponds to the number of times we sample the number of
                           context points N_context
        :param testing: A Boolean object; if set to be True, then every print_freq iterations the
                        R^2 score and RMSE values will be calculated and printed for
                        both the train and test data
        :param print_freq: An integer; the loss (and, if testing, the metrics) is reported
                           every print_freq iterations
        :param dataname: Name of the dataset, passed to metrics_calculator for logging/plots
        :param plotting: A Boolean; passed to metrics_calculator to control plotting
        :return:
        """

        n_functions = len(x_trains)

        for iteration in range(iterations):
            self.optimiser.zero_grad()

            # Sample the function from the set of functions
            idx_function = np.random.randint(n_functions)

            x_train = x_trains[idx_function]
            y_train = y_trains[idx_function]

            max_target = x_train.shape[0]

            # During training, we sample n_target points from the function, and
            # randomly select n_context points to condition on.

            num_target = torch.randint(low=5, high=int(max_target), size=(1,))
            num_context = torch.randint(low=3, high=int(num_target), size=(1,))

            idx = [np.random.permutation(x_train.shape[0])[:num_target] for i in
                   range(batch_size)]
            idx_context = [idx[i][:num_context] for i in range(batch_size)]

            x_target = [x_train[idx[i], :] for i in range(batch_size)]
            y_target = [y_train[idx[i], :] for i in range(batch_size)]
            x_context = [x_train[idx_context[i], :] for i in range(batch_size)]
            y_context = [y_train[idx_context[i], :] for i in range(batch_size)]

            x_target = torch.stack(x_target)
            y_target = torch.stack(y_target)
            x_context = torch.stack(x_context)
            y_context = torch.stack(y_context)

            # The deterministic encoder outputs the deterministic embedding r.
            r = self.det_encoder.forward(x_context, y_context, x_target)
            # [batch_size, N_target, r_size]

            # The latent encoder outputs a prior distribution over the
            # latent embedding z (conditioned only on the context points).
            z_priors, _, _ = self.lat_encoder.forward(x_context, y_context)
            z_posteriors, _, _ = self.lat_encoder.forward(x_target, y_target)

            # Sample z from the prior distribution.
            zs = [dist.rsample() for dist in z_priors]      # [batch_size, r_size]
            z = torch.cat(zs)
            z = z.view(-1, self.r_size)

            # The input to the decoder is the concatenation of the target x values,
            # the deterministic embedding r and the latent variable z
            # the output is the predicted target y for each value of x.
            dists, _, _ = self.decoder.forward(x_target.float(), r.float(), z.float())

            # Calculate the loss
            log_ps = [dist.log_prob(y_target[i, ...].float()) for i, dist in enumerate(dists)]
            log_ps = torch.cat(log_ps)

            kl_div = [kl_divergence(z_posterior, z_prior).float() for z_posterior, z_prior
                      in zip(z_posteriors, z_priors)]
            kl_div = torch.cat(kl_div)

            loss = -(torch.mean(log_ps) - torch.mean(kl_div))
            self.losslogger = loss

            # The loss should generally decrease with number of iterations, though it is not
            # guaranteed to decrease monotonically because at each iteration the set of
            # context points changes randomly.
            if iteration % print_freq == 0:
                print("Iteration " + str(iteration) + ":, Loss = {:.3f}".format(loss.item()))
                # We can set testing = True if we want to check that we are not over-fitting.
                if testing:
                    metrics_calculator(x_trains, y_trains, x_tests, y_tests, x_scalers,
                                       y_scalers, self.predict, dataname, plotting, iteration)
            loss.backward()
            self.optimiser.step()
Example #11
    def propagate(self, x, x_embed, prev_relation, prev_temporal):
        """
        Propagate step for a single object in a single time step.
        
        In this process, even empty objects are encoded in relation h's. This
        can be avoided by directly passing the previous relation state on to the
        next spatial step. May do this later.
        
        Args:
            x: original image. Size (B, C, H, W)
            x_embed: extracted image feature. Size (B, N)
            prev_relation: see RelationState
            prev_temporal: see TemporalState
            
        Returns:
            temporal_state: TemporalState
            relation_state:  RelationState
            kl: kl divergence for all z's. (B, 1)
            z_pres_likelihood: q(z_pres|x). (B, 1)
        """
        
        # First, encode relation and temporal info to get current h^{T, i}_t and
        # current h^{R, i}_t
        # Each being (B, N)
        h_rel, c_rel = self.rnn_relation(prev_relation.object.get_encoding(),
                                         (prev_relation.h, prev_relation.c))
        h_tem, c_tem = self.rnn_temporal(prev_temporal.object.get_encoding(),
                                         (prev_temporal.h, prev_temporal.c))
        
        # Compute proposal region to look at
        # (B, 4)
        proposal_region_delta = self.propagate_proposal(h_tem)
        proposal_region = prev_temporal.object.z_where + proposal_region_delta
        # self.i += 1
        # if self.i % 1000 == 0:
        #     print(proposal_region[0])
        proposal = self.image_to_glimpse(x, proposal_region, inverse=True)
        proposal_embed = self.proposal_embedding(proposal)
        
        # (B, N)
        # Predict where and pres, using h^R, h^T and the proposal embedding
        predict_input = torch.cat((h_rel, h_tem, proposal_embed), dim=-1)
        # (B, 4), (B, 4), (B, 1)
        # Note we only predict delta here.
        z_where_delta_loc, z_where_delta_scale, z_pres_prob = self.propagate_predict(predict_input)
        
        # Sample from z_pres posterior. Shape (B, 1)
        # NOTE: don't use zero probability otherwise log q(z|x) will not work
        z_pres_post = Bernoulli(z_pres_prob)
        z_pres = z_pres_post.sample()
        # Mask z_pres. You don't have to do this for where and what because
        #   - their KL will be masked
        #   - they will be masked when computing the likelihood
        z_pres = z_pres * prev_temporal.object.z_pres
        
        # Sample from z_where posterior, (B, 4)
        z_where_delta_post = Normal(z_where_delta_loc, z_where_delta_scale)
        z_where_delta = z_where_delta_post.rsample()
        z_where = prev_temporal.z_where + z_where_delta
        # Mask
        z_where = z_where * z_pres
        
        # Extract glimpse from x, shape (B, 1, H, W)
        glimpse = self.image_to_glimpse(x, z_where, inverse=True)
        # This is important for handling overlap
        glimpse_mask = self.glimpse_mask(h_tem)
        glimpse = glimpse * glimpse_mask
        
        # Compute the posterior distribution over z_what and sample
        z_what_delta_loc, z_what_delta_scale = self.propagate_encoder(glimpse, h_tem, h_rel)
        z_what_delta_post = Normal(z_what_delta_loc, z_what_delta_scale)
        # (B, N)
        z_what_delta = z_what_delta_post.rsample()
        z_what = prev_temporal.z_what + z_what_delta
        # Mask
        z_what = z_what * z_pres
        
        # Now we compute KL divergence and discrete likelihood. Before that, we
        # will need to compute the recursive prior. This is parametrized by
        # previous object state (z) and hidden states from LSTM.
        
        
        # Compute prior for current step
        (z_what_delta_loc_prior, z_what_delta_scale_prior, z_where_delta_loc_prior,
            z_where_delta_scale_prior, z_pres_prob_prior) = (
            self.propagate_prior(h_tem))
        
        # TODO: demand that scale to be small to guarantee consistency
        if DEBUG:
            z_what_delta_loc_prior = arch.z_what_delta_loc_prior.expand_as(z_what_delta_loc_prior)
            z_what_delta_scale_prior = arch.z_what_delta_scale_prior.expand_as(z_what_delta_scale_prior)
            z_where_delta_scale_prior = arch.z_where_delta_scale_prior.expand_as(z_where_delta_scale_prior)


        # Construct prior distributions
        z_what_delta_prior = Normal(z_what_delta_loc_prior, z_what_delta_scale_prior)
        z_where_delta_prior = Normal(z_where_delta_loc_prior, z_where_delta_scale_prior)
        z_pres_prior = Bernoulli(z_pres_prob_prior)
        
        # Compute KL divergence. Each (B, N)
        kl_z_what = kl_divergence(z_what_delta_post, z_what_delta_prior)
        kl_z_where = kl_divergence(z_where_delta_post, z_where_delta_prior)
        kl_z_pres = kl_divergence(z_pres_post, z_pres_prior)
        
        # Mask these terms.
        # Note for kl_z_pres, we will need to use z_pres of the previous time step.
        # This is because even when z_pres[t] = 0, p(z_pres[t] | z_pres[t-1])
        # cannot be ignored.

        # Also note that we do not mask where and what here. That's OK since we will
        # mask the image outside of this function later.
        
        kl_z_what = kl_z_what * z_pres
        kl_z_where = kl_z_where * z_pres
        kl_z_pres = kl_z_pres * prev_temporal.object.z_pres
        
        vis_logger['kl_pres_list'].append(kl_z_pres.mean())
        vis_logger['kl_what_list'].append(kl_z_what.mean())
        vis_logger['kl_where_list'].append(kl_z_where.mean())

        # (B,) here, after reduction
        kl = kl_z_what.sum(dim=-1) + kl_z_where.sum(dim=-1) + kl_z_pres.sum(dim=-1)
        
        # Finally, we compute the discrete likelihoods.
        z_pres_likelihood = z_pres_post.log_prob(z_pres)
        
        # Note we also need to mask some of these terms since they do not depend
        # on the model parameters. (B, 1) here
        z_pres_likelihood = z_pres_likelihood * prev_temporal.object.z_pres
        z_pres_likelihood = z_pres_likelihood.squeeze()
        
        # Compute id. If z_pres is 1, then inherit that id. Otherwise set it to
        # zero
        id = prev_temporal.object.id * z_pres
        
        B = x.size(0)
        # Collect terms into new states.
        object_state = ObjectState(
            z_pres, z_where, z_what, id, z_pres_prob=z_pres_prob,
            object_enc=glimpse.view(B, -1), mask=glimpse_mask.view(B, -1),
            proposal=proposal_region)
        temporal_state = TemporalState(object_state, h_tem, c_tem)
        relation_state = RelationState(object_state, h_rel, c_rel)
        
        return temporal_state, relation_state, kl, z_pres_likelihood
Example #12
def eq_dist(x, y):
    if type(x) != type(y):
        return False
    if x.batch_shape != y.batch_shape:
        return False
    return kl_divergence(x, y) == 0
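
A minimal usage sketch for `eq_dist` (the distributions are illustrative):

import torch
from torch.distributions import Normal

a = Normal(torch.zeros(2), torch.ones(2))
b = Normal(torch.zeros(2), torch.ones(2))
c = Normal(torch.zeros(3), torch.ones(3))

print(eq_dist(a, b))  # tensor([True, True]): element-wise zero KL
print(eq_dist(a, c))  # False: batch shapes differ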
Example #13
def train(current_epoch,
          loader,
          aux_loader,
          epochs_normal=1,
          epochs_aux=6,
          beta_clone=1.,
          aux_every=12):

    old_policy.load_state_dict(policy.state_dict())
    for epoch in range(epochs_normal):
        for counter, data in enumerate(loader):
            s, a, r, ns = data

            state_estimates = value(s)
            value_loss = F.mse_loss(state_estimates, r)

            next_state_estimates = value(ns)
            advantage = r + 0.99 * next_state_estimates - state_estimates

            adam_v.zero_grad()
            value_loss.backward()
            adam_v.step()

            mean, std = policy(s)
            dist = Normal(mean, std)
            log_probs = dist.log_prob(a)

            mean_o, std_o = old_policy(s)
            old_dist = Normal(mean_o, std_o)
            old_log_probs = old_dist.log_prob(a)

            policies_ratio = (log_probs - old_log_probs).exp()
            policy_loss = -torch.min(
                policies_ratio * advantage.detach(),
                policies_ratio.clamp(1. - epsilon_clip, 1. + epsilon_clip) *
                advantage.detach()).mean()

            # policy_loss = -(log_probs * advantage.detach()).mean()
            adam_p.zero_grad()
            policy_loss.backward()
            adam_p.step()

    if current_epoch % aux_every == 0:
        aux_policy.load_state_dict(policy.state_dict())
        for epoch in range(epochs_aux):
            batch_counter = 0
            for data in aux_loader:
                s, r = data
                with torch.no_grad():
                    aux_old_mean, aux_old_std = aux_policy(s)
                    aux_dist = Normal(aux_old_mean, aux_old_std)
                current_mean, current_std = policy(s)
                current_dist = Normal(current_mean, current_std)
                kl_loss = kl.kl_divergence(aux_dist, current_dist).mean()
                aux_value_loss = F.mse_loss(policy.get_values(s), r)
                aux_loss = aux_value_loss + beta_clone * kl_loss

                adam_p.zero_grad()
                aux_loss.backward()
                adam_p.step()

                state_estimates = value(s)
                value_loss = F.mse_loss(state_estimates, r)
                adam_v.zero_grad()
                value_loss.backward()
                adam_v.step()

                batch_counter += 1
                if batch_counter == 16:
                    break

    return (state_estimates.mean().item(), value_loss.item(),
            advantage.mean().item())
Example #14
    def train(self,
              x,
              y,
              x_test=None,
              y_test=None,
              x_scaler=None,
              y_scaler=None,
              nz_samples=1,
              ny_samples=1,
              batch_size=10,
              lr=0.001,
              epochs=3000,
              print_freq=100,
              VERBOSE=False,
              dataname=None):
        """

        :param x: [n_functions, [n_train, x_dim]]
        :param y: [n_functions, [n_train, y_dim]]
        :param lr: learning rate for the Adam optimiser
        :param epochs: number of training epochs
        :return:
        """
        self.optimiser = optim.Adam(list(self.det_encoder.parameters()) +
                                    list(self.prob_encoder.parameters()) +
                                    list(self.decoder.parameters()),
                                    lr=lr)

        for epoch in range(epochs):
            self.optimiser.zero_grad()

            # Sample the function from the set of functions
            x_context, y_context, x_target, y_target = batch_sampler(
                x, y, batch_size)

            # Make a forward pass through the ANP to obtain a distribution over the target set.
            mu_y, var_y, mus_z, vars_z, mus_z_posterior, vars_z_posterior = self.forward(
                x_context, y_context, x_target, y_target, nz_samples,
                ny_samples, batch_size)  #[batch_size*n_target, y_dim] x2

            # Measure the log probability of observing y_target given mu_y, var_y.
            log_ps = MultivariateNormal(
                mu_y, torch.diag_embed(var_y)).log_prob(y_target.float())
            log_ps = log_ps.reshape(batch_size, -1).sum(dim=-1)
            log_ps = torch.mean(log_ps)

            # Compute the KL divergence between prior and posterior over z
            z_posteriors = [
                MultivariateNormal(mu, torch.diag_embed(var))
                for mu, var in zip(mus_z_posterior, vars_z_posterior)
            ]
            z_priors = [
                MultivariateNormal(mu, torch.diag_embed(var))
                for mu, var in zip(mus_z, vars_z)
            ]

            kl_div = [
                kl_divergence(z_posterior, z_prior).float()
                for z_posterior, z_prior in zip(z_posteriors, z_priors)
            ]
            kl_div = torch.mean(torch.stack(kl_div))

            # Calculate the loss function from this.
            loss = -(log_ps - kl_div)
            self.losslogger = loss

            if epoch % print_freq == 0:
                print(
                    'Epoch {:.0f}: Loss = {:.5f} \t LL = {:.5f} \t KL = {:.5f}'
                    .format(epoch, loss, log_ps, kl_div))
                if epoch % int(10 * print_freq) == 0:
                    if VERBOSE:
                        metrics_calculator(self, 'anp', x, y, x_test, y_test,
                                           dataname, epoch, x_scaler, y_scaler)

            loss.backward()
            self.optimiser.step()
Example #15
    def train(self, x_train, y_train, x_test, y_test, x_scaler, y_scaler,
              batch_size, lr, iterations, testing, plotting):
        """
        :param x_train: A tensor with dimensions [N_train, x_size] containing the training
                        data (x values)
        :param y_train: A tensor with dimensions [N_train, y_size] containing the training
                        data (y values)
        :param x_test: A tensor with dimensions [N_test, x_size] containing the test data
                       (x values)
        :param y_test: A tensor with dimensions [N_test, y_size] containing the test data
                       (y values)
        :param x_scaler: The standard scaler used when testing == True to convert the
                         x values back to the correct scale.
        :param y_scaler: The standard scaler used when testing == True to convert the predicted
                         y values back to the correct scale.
        :param batch_size: An integer describing the number of times we should
                                    sample the set of context points used to form the
                                    aggregated embedding during training, given the number
                                    of context points to be sampled N_context. When testing
                                    this is set to 1
        :param lr: A float number, describing the optimiser's learning rate
        :param iterations: An integer, describing the number of iterations. In this case it
                           also corresponds to the number of times we sample the number of
                           context points N_context
        :param testing: A Boolean object; if set to be True, then every 30 iterations the
                        R^2 score and RMSE values will be calculated and printed for
                        both the train and test data
        :return:
        """
        self.gp_sampler = GPSampler(data=(x_train, y_train))
        self.batch_size = batch_size
        self._max_num_context = x_train.shape[0]
        self.iterations = iterations

        #Convert the data for use in PyTorch.
        x_train = torch.from_numpy(x_train).float()
        y_train = torch.from_numpy(y_train).float()
        x_test = torch.from_numpy(x_test).float()
        y_test = torch.from_numpy(y_test).float()

        # At prediction time the context points comprise the entire training set.
        x_tot_context = torch.unsqueeze(x_train, dim=0)
        y_tot_context = torch.unsqueeze(y_train, dim=0)

        for iteration in range(iterations):
            self.optimiser.zero_grad()

            # Randomly select the number of context points N_context (uniformly from 1 to
            # N_train - 1)
            num_context = np.random.randint(low=1, high=self._max_num_context)

            # Randomly select N_context context points from the training data, a total of
            # batch_size times.
            x_context, y_context, x_target, y_target = self.gp_sampler.sample(
                batch_size=self.batch_size,
                train_size=50,
                num_context=num_context,
                x_min=-4,
                x_max=4)

            x_context = torch.from_numpy(x_context).float()
            y_context = torch.from_numpy(y_context).float()
            x_target = torch.from_numpy(x_target).float()
            y_target = torch.from_numpy(y_target).float()

            # The input to both the deterministic and latent encoder is (x, y)_i for all data points in the set of context
            # points.
            input_context = torch.cat((x_context, y_context), dim=2)
            input_target = torch.cat((x_target, y_target), dim=2)

            #The deterministic encoder outputs the deterministic embedding r.
            r = self.det_encoder.forward(
                x_context, y_context,
                x_target)  #[batch_size, N_target, r_size]

            # The latent encoder outputs a prior distribution over the latent embedding z (conditioned only on the context points).
            z_priors, mu_prior, sigma_prior = self.lat_encoder.forward(
                x_context, y_context)

            if y_target is not None:
                z_posteriors, mu_posterior, sigma_posterior = self.lat_encoder.forward(
                    x_target, y_target)
                zs = [dist.sample()
                      for dist in z_posteriors]  #[batch_size, r_size]

            else:
                zs = [dist.sample()
                      for dist in z_priors]  #[batch_size, r_size]

            z = torch.cat(zs)
            z = z.view(-1, self.r_size)

            # The input to the decoder is the concatenation of the target x values, the deterministic embedding r and the latent variable z
            # the output is the predicted target y for each value of x.
            dists_y, _, _ = self.decoder.forward(x_target.float(), r.float(),
                                                 z.float())

            # Calculate the loss
            log_ps = [
                dist.log_prob(y_target[i, ...].float())
                for i, dist in enumerate(dists_y)
            ]
            log_ps = torch.cat(log_ps)

            kl_div = [
                kl_divergence(z_posterior, z_prior).float()
                for z_posterior, z_prior in zip(z_posteriors, z_priors)
            ]
            # Stack the per-batch KL terms; torch.tensor() would fail on
            # multi-element tensors and detach them from the autograd graph.
            kl_div = torch.stack(kl_div)

            loss = -(torch.mean(log_ps) - torch.mean(kl_div))
            self.losslogger = loss

            # The loss should generally decrease with number of iterations, though it is not
            # guaranteed to decrease monotonically because at each iteration the set of
            # context points changes randomly.
            if iteration % 200 == 0:
                print("Iteration " + str(iteration) +
                      ":, Loss = {:.3f}".format(loss.item()))
                # We can set testing = True if we want to check that we are not overfitting.
                if testing:

                    r2_train_list = []
                    rmse_train_list = []
                    nlpd_train_list = []
                    r2_test_list = []
                    rmse_test_list = []
                    nlpd_test_list = []

                    #Useful for determining uncertainty due to sampling z.
                    for j in range(10):
                        _, predict_train_mean, predict_train_var = self.predict(
                            x_tot_context, y_tot_context, x_tot_context)
                        predict_train_mean = np.squeeze(
                            predict_train_mean.data.numpy(), axis=0)
                        predict_train_var = np.squeeze(
                            predict_train_var.data.numpy(), axis=0)

                        # We transform the standardised predicted and actual y values back to the original data
                        # space
                        y_train_mean_pred = y_scaler.inverse_transform(
                            predict_train_mean)
                        y_train_var_pred = y_scaler.var_ * predict_train_var
                        y_train_untransformed = y_scaler.inverse_transform(
                            y_train)

                        r2_train = r2_score(y_train_untransformed,
                                            y_train_mean_pred)
                        nlpd_train = nlpd(y_train_mean_pred, y_train_var_pred,
                                          y_train_untransformed)
                        rmse_train = np.sqrt(
                            mean_squared_error(y_train_untransformed,
                                               y_train_mean_pred))
                        r2_train_list.append(r2_train)
                        rmse_train_list.append(rmse_train)
                        nlpd_train_list.append(nlpd_train)

                        x_test = torch.unsqueeze(x_test, dim=0)
                        _, predict_test_mean, predict_test_var = self.predict(
                            x_tot_context, y_tot_context, x_test)
                        x_test = torch.squeeze(x_test, dim=0)
                        predict_test_mean = np.squeeze(
                            predict_test_mean.data.numpy(), axis=0)
                        predict_test_var = np.squeeze(
                            predict_test_var.data.numpy(), axis=0)

                        # We transform the standardised predicted and actual y values back to the original data
                        # space
                        y_test_mean_pred = y_scaler.inverse_transform(
                            predict_test_mean)
                        y_test_var_pred = y_scaler.var_ * predict_test_var
                        y_test_untransformed = y_scaler.inverse_transform(
                            y_test)

                        indices = np.random.permutation(
                            y_test_untransformed.shape[0])[0:20]
                        r2_test = r2_score(y_test_untransformed[indices, 0],
                                           y_test_mean_pred[indices, 0])
                        rmse_test = np.sqrt(
                            mean_squared_error(
                                y_test_untransformed[indices, 0],
                                y_test_mean_pred[indices, 0]))
                        nlpd_test = nlpd(y_test_mean_pred[indices, 0],
                                         y_test_var_pred[indices, 0],
                                         y_test_untransformed[indices, 0])

                        r2_test_list.append(r2_test)
                        rmse_test_list.append(rmse_test)
                        nlpd_test_list.append(nlpd_test)

                    r2_train_list = np.array(r2_train_list)
                    rmse_train_list = np.array(rmse_train_list)
                    nlpd_train_list = np.array(nlpd_train_list)
                    r2_test_list = np.array(r2_test_list)
                    rmse_test_list = np.array(rmse_test_list)
                    nlpd_test_list = np.array(nlpd_test_list)

                    print("\nR^2 score (train): {:.3f} +- {:.3f}".format(
                        np.mean(r2_train_list),
                        np.std(r2_train_list) / np.sqrt(len(r2_train_list))))
                    #print("RMSE (train): {:.3f} +- {:.3f}".format(np.mean(rmse_train_list) / np.sqrt(
                    #len(rmse_train_list))))
                    print("NLPD (train): {:.3f} +- {:.3f}".format(
                        np.mean(nlpd_train_list),
                        np.std(nlpd_train_list) /
                        np.sqrt(len(nlpd_train_list))))
                    print("R^2 score (test): {:.3f} +- {:.3f}".format(
                        np.mean(r2_test_list),
                        np.std(r2_test_list) / np.sqrt(len(r2_test_list))))
                    #print("RMSE (test): {:.3f} +- {:.3f}".format(np.mean(rmse_test_list),
                    #np.std(rmse_test_list) / np.sqrt(len(rmse_test_list))))
                    print("NLPD (test): {:.3f} +- {:.3f}\n".format(
                        np.mean(nlpd_test_list),
                        np.std(nlpd_test_list) / np.sqrt(len(nlpd_test_list))))

                    if iteration % 1000 == 0:
                        if plotting:
                            x_c = x_scaler.inverse_transform(np.array(x_train))
                            y_c = y_train_untransformed
                            x_t = x_scaler.inverse_transform(np.array(x_test))
                            y_t = x_t**3

                            plt.figure(figsize=(7, 7))
                            plt.scatter(x_c,
                                        y_c,
                                        color='red',
                                        s=15,
                                        marker='o',
                                        label="Context points")
                            plt.plot(x_t,
                                     y_t,
                                     linewidth=1,
                                     color='red',
                                     label="Ground truth")
                            plt.plot(x_t,
                                     y_test_mean_pred,
                                     color='darkcyan',
                                     linewidth=1,
                                     label='Mean prediction')
                            plt.plot(x_t[:, 0],
                                     y_test_mean_pred[:, 0] -
                                     1.96 * np.sqrt(y_test_var_pred[:, 0]),
                                     linestyle='-.',
                                     marker=None,
                                     color='darkcyan',
                                     linewidth=0.5)
                            plt.plot(x_t[:, 0],
                                     y_test_mean_pred[:, 0] +
                                     1.96 * np.sqrt(y_test_var_pred[:, 0]),
                                     linestyle='-.',
                                     marker=None,
                                     color='darkcyan',
                                     linewidth=0.5,
                                     label='Two standard deviations')
                            plt.fill_between(
                                x_t[:, 0],
                                y_test_mean_pred[:, 0] -
                                1.96 * np.sqrt(y_test_var_pred[:, 0]),
                                y_test_mean_pred[:, 0] +
                                1.96 * np.sqrt(y_test_var_pred[:, 0]),
                                color='cyan',
                                alpha=0.2)
                            plt.title('Predictive distribution')
                            plt.ylabel('f(x)')
                            plt.yticks([-80, -60, -40, -20, 0, 20, 40, 60, 80])
                            plt.ylim(-80, 80)
                            plt.xlim(-4, 4)
                            plt.xlabel('x')
                            plt.xticks([-4, -2, 0, 2, 4])
                            plt.legend()
                            plt.savefig('results/anp_1dreg_crossatt_2selfatt' +
                                        str(iteration) + '.png')

            loss.backward()
            self.optimiser.step()
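Note: the nlpd metric called in the listing above is not defined there. Below is a minimal sketch of a Gaussian negative log predictive density with the same call signature (mean, variance, target); it is an assumed implementation, not the original helper.

import numpy as np

def nlpd(y_mean, y_var, y_true, eps=1e-12):
    # Average Gaussian negative log predictive density (assumed implementation).
    y_var = np.maximum(y_var, eps)  # guard against zero predicted variance
    return np.mean(0.5 * np.log(2.0 * np.pi * y_var)
                   + 0.5 * (y_true - y_mean) ** 2 / y_var)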
Exemplo n.º 16
0
 beliefs, prior_states, prior_means, prior_std_devs, posterior_states, posterior_means, posterior_std_devs = transition_model(
     init_state, actions[:-1], init_belief,
     bottle(encoder, (observations[1:], )), nonterminals[:-1])
 # Calculate observation likelihood, reward likelihood and KL losses (for t = 0 only for latent overshooting); sum over final dims, average over batch and time (original implementation, though paper seems to miss 1/T scaling?)
 observation_loss = F.mse_loss(
     bottle(observation_model, (beliefs, posterior_states)),
     observations[1:],
     reduction='none').sum(
         dim=2 if args.symbolic_env else (2, 3, 4)).mean(dim=(0, 1))
 reward_loss = F.mse_loss(bottle(reward_model,
                                 (beliefs, posterior_states)),
                          rewards[:-1],
                          reduction='none').mean(dim=(0, 1))
 kl_loss = torch.max(
     kl_divergence(Normal(posterior_means, posterior_std_devs),
                   Normal(prior_means,
                          prior_std_devs)).sum(dim=2), free_nats
 ).mean(
     dim=(0, 1)
 )  # Note that normalisation by overshooting distance and weighting by overshooting distance cancel out
 if args.global_kl_beta != 0:
     kl_loss += args.global_kl_beta * kl_divergence(
         Normal(posterior_means, posterior_std_devs),
         global_prior).sum(dim=2).mean(dim=(0, 1))
 # Calculate latent overshooting objective for t > 0
 if args.overshooting_kl_beta != 0:
     overshooting_vars = [
     ]  # Collect variables for overshooting to process in batch
     for t in range(1, args.chunk_size - 1):
         d = min(t + args.overshooting_distance,
                 args.chunk_size - 1)  # Overshooting distance
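Note: the bottle helper used above (and in the other PlaNet-style listings below) is not shown here. It typically folds the time dimension into the batch dimension, applies the module, and unfolds the result again; a sketch under that assumption:

import torch

def bottle(f, x_tuple):
    # Apply f to tensors shaped (T, B, ...) by flattening time into the batch dim (assumed helper).
    T, B = x_tuple[0].shape[:2]
    y = f(*[x.reshape(T * B, *x.shape[2:]) for x in x_tuple])
    return y.reshape(T, B, *y.shape[1:])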
Exemplo n.º 17
0
def train(netG, netD, opt_G, opt_D, opt_E):
	D_real = D_fake = D_z_kl = G_real = Z_recon = R_kl = 0
	fixed_z = torch.randn(64, Z_DIM).to(device)

	saved_image_folder, saved_model_folder = make_folders(SAVE_FOLDER, TRIAL_NAME)

	for n_iter in tqdm.tqdm(range(0, MAX_ITERATION+1)):

		if n_iter % SAVE_IMAGE_INTERVAL == 0:
			save_image_from_z(netG, fixed_z, pjoin(saved_image_folder, "z_%d.jpg"%n_iter))
			save_image_from_r(netG, R_DIM, pjoin(saved_image_folder, "r_%d.jpg"%n_iter))
		if n_iter % SAVE_MODEL_INTERVAL == 0:
			save_model(netG, netD, pjoin(saved_model_folder, "%d.pth"%n_iter))	
		
		### 0. prepare data
		real_image = next(dataloader)[0].to(device)

		z = torch.randn(BATCH_SIZE, Z_DIM).to(device)
		# e(r|z) as the likelihood of r given z
		r_sampler = netG.r_sampler(z)
		g_image = netG.generate(r_sampler.sample())

		### 1. Train Discriminator on real and generated data
		netD.zero_grad()
		pred_f = netD.discriminate(g_image.detach())
		pred_r, rec_z = netD(real_image)
		d_loss = loss_bce(torch.sigmoid(pred_r), torch.ones(pred_r.size()).to(device)) \
			+ loss_bce(torch.sigmoid(pred_f), torch.zeros(pred_f.size()).to(device))
		q_loss = KL_Loss(rec_z)
		#d_loss.backward()
		total_loss = d_loss + q_loss
		total_loss.backward()
		opt_D.step()

		# record the loss values
		D_real += torch.sigmoid(pred_r).mean().item()
		D_fake += torch.sigmoid(pred_f).mean().item()
		D_z_kl += q_loss.item()

		### 2. Train Generator
		netD.zero_grad()
		netG.zero_grad()
		# q(z|x) as the posterior of z given x
		pred_g, z_posterior = netD(g_image)
		# GAN loss for generator
		g_loss = LAMBDA_G * loss_bce(torch.sigmoid(pred_g), torch.ones(pred_g.size()).to(device))
		# reconstruction loss of z
		## TODO
		## Question: as stated in Algorithm 1 of the paper, this term should be -log(q(z|x)) rather than an MSE.
		recon_loss = loss_mse(z_posterior, z)
		# kl loss between e(r|z) || m(r) as a variational inference
		#kl_loss = BETA_KL * torch.distributions.kl.kl_divergence(r_likelihood, M_r).mean()
		kl_loss = BETA_KL * kl_divergence(r_sampler, M_r).mean()
		total_loss = g_loss + recon_loss + kl_loss
		total_loss.backward()
		opt_E.step()
		opt_G.step()

		# record the loss values
		G_real += torch.sigmoid(pred_g).mean().item()
		Z_recon += recon_loss.item()
		R_kl += kl_loss.item()

		if n_iter % LOG_INTERVAL == 0 and n_iter > 0:
			print("D(x): %.5f    D(G(z)): %.5f    D_kl: %.5f    G(z): %.5f    Z_rec: %.5f    R_kl: %.5f"%\
				(D_real/LOG_INTERVAL, D_fake/LOG_INTERVAL, D_z_kl/LOG_INTERVAL, G_real/LOG_INTERVAL, Z_recon/LOG_INTERVAL, R_kl/LOG_INTERVAL))
			D_real = D_fake = D_z_kl = G_real = Z_recon = R_kl = 0
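Note on the TODO above: if the discriminator's z-head is read as a Gaussian posterior q(z|x), the -log q(z|x) term from Algorithm 1 can replace the MSE. A hedged sketch assuming the head predicts a mean and a log-variance (these names are assumptions, not part of the listing):

import torch
from torch.distributions import Normal

def z_nll(z_mean, z_logvar, z):
    # -log q(z|x) under a diagonal Gaussian posterior (assumed parameterisation).
    q_z = Normal(z_mean, (0.5 * z_logvar).exp())
    return -q_z.log_prob(z).sum(dim=1).mean()

With a fixed unit variance this reduces to the MSE term used above, up to an additive constant and scaling.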
Exemplo n.º 18
0
    def update(self, rollouts: Sequence[StepSequence]):
        # Turn the batch of rollouts into a list of steps
        concat_ros = StepSequence.concat(rollouts)
        concat_ros.torch(data_type=to.get_default_dtype())

        with to.no_grad():
            # Compute the action probabilities using the old (before update) policy
            act_stats = compute_action_statistics(concat_ros, self._expl_strat)
            log_probs_old = act_stats.log_probs
            act_distr_old = act_stats.act_distr

            # Compute value predictions using the old (before the update) value function
            v_pred_old = self._critic.values(concat_ros)

        # Attach advantages and old log probabilities to rollout
        concat_ros.add_data("log_probs_old", log_probs_old)
        concat_ros.add_data("v_pred_old", v_pred_old)

        # For logging the gradient norms
        policy_grad_norm = []
        vfcn_grad_norm = []

        # Compute the value targets (empirical discounted returns) for all samples before fitting the V-fcn parameters
        adv = self._critic.gae(concat_ros)  # done with to.no_grad()
        v_targ = discounted_values(rollouts, self._critic.gamma).view(
            -1, 1)  # empirical discounted returns
        concat_ros.add_data("adv", adv)
        concat_ros.add_data("v_targ", v_targ)

        # Iterations over the whole data set
        for e in range(self.num_epoch):

            for batch in tqdm(
                    concat_ros.split_shuffled_batches(
                        self.batch_size,
                        complete_rollouts=self._policy.is_recurrent
                        or isinstance(self._critic.vfcn, RecurrentPolicy),
                    ),
                    total=num_iter_from_rollouts(None, concat_ros,
                                                 self.batch_size),
                    desc=f"Epoch {e}",
                    unit="batches",
                    file=sys.stdout,
                    leave=False,
            ):
                # Reset the gradients
                self.optim.zero_grad()

                # Compute log of the action probabilities for the mini-batch
                log_probs = compute_action_statistics(
                    batch, self._expl_strat).log_probs.to(self.policy.device)

                # Compute value predictions for the mini-batch
                v_pred = self._critic.values(batch)

                # Compute combined loss and backpropagate
                loss = self.loss_fcn(log_probs, batch.log_probs_old, batch.adv,
                                     v_pred, batch.v_pred_old, batch.v_targ)
                loss.backward()

                # Clip the gradients if desired
                policy_grad_norm.append(
                    self.clip_grad(self._expl_strat.policy,
                                   self.max_grad_norm))
                vfcn_grad_norm.append(
                    self.clip_grad(self._critic.vfcn, self.max_grad_norm))

                # Call optimizer
                self.optim.step()

                if to.isnan(self._expl_strat.noise.std).any():
                    raise RuntimeError(
                        f"At least one exploration parameter became NaN! The exploration parameters are"
                        f"\n{self._expl_strat.std.detach().cpu().numpy()}")

            # Update the learning rate if a scheduler has been specified
            if self._lr_scheduler is not None:
                self._lr_scheduler.step()

        # Additional logging
        if self.log_loss:
            with to.no_grad():
                # Compute value predictions using the new (after the updates) value function approximator
                v_pred = self._critic.values(concat_ros).to(self.policy.device)
                v_loss_old = self._critic.loss_fcn(
                    v_pred_old.to(self.policy.device),
                    v_targ.to(self.policy.device)).to(self.policy.device)
                v_loss_new = self._critic.loss_fcn(v_pred, v_targ).to(
                    self.policy.device)
                vfcn_loss_impr = v_loss_old - v_loss_new  # positive values are desired

                # Compute the action probabilities using the new (after the updates) policy
                act_stats = compute_action_statistics(concat_ros,
                                                      self._expl_strat)
                log_probs_new = act_stats.log_probs
                act_distr_new = act_stats.act_distr
                loss_after = self.loss_fcn(log_probs_new, log_probs_old, adv,
                                           v_pred, v_pred_old, v_targ)
                kl_avg = to.mean(kl_divergence(
                    act_distr_old,
                    act_distr_new))  # mean seeking a.k.a. inclusive KL

                # Compute explained variance (after the updates)
                self.logger.add_value("explained var",
                                      explained_var(v_pred, v_targ), 4)
                self.logger.add_value("loss improvement V-fcnovement",
                                      vfcn_loss_impr, 4)
                self.logger.add_value("loss after", loss_after, 4)
                self.logger.add_value("KL(old_new)", kl_avg, 4)

        # Logging
        self.logger.add_value("avg expl strat std",
                              to.mean(self._expl_strat.noise.std), 4)
        self.logger.add_value("expl strat entropy",
                              self._expl_strat.noise.get_entropy(), 4)
        self.logger.add_value("avg grad norm policy",
                              np.mean(policy_grad_norm), 4)
        self.logger.add_value("avg grad norm V-fcn", np.mean(vfcn_grad_norm),
                              4)
        if self._lr_scheduler is not None:
            self.logger.add_value("avg lr",
                                  np.mean(self._lr_scheduler.get_last_lr()), 6)
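Note: self.loss_fcn is defined outside this listing. A minimal sketch of a PPO-style objective that is consistent with the arguments passed above (clipped policy surrogate plus a clipped value-function term); it only illustrates the expected signature and is not the original implementation:

import torch

def ppo_loss(log_probs, log_probs_old, adv, v_pred, v_pred_old, v_targ,
             clip_eps=0.2, value_coef=0.5):
    # Probability ratio between the new and old policies.
    ratio = torch.exp(log_probs - log_probs_old)
    surr = torch.min(ratio * adv,
                     torch.clamp(ratio, 1 - clip_eps, 1 + clip_eps) * adv)
    # Clip the value update around the old prediction, as in common PPO variants.
    v_clipped = v_pred_old + torch.clamp(v_pred - v_pred_old, -clip_eps, clip_eps)
    v_loss = torch.max((v_pred - v_targ) ** 2, (v_clipped - v_targ) ** 2)
    return -surr.mean() + value_coef * v_loss.mean()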
Exemplo n.º 19
0
    def update(self, rollouts: Sequence[StepSequence]):
        # Turn the batch of rollouts into a list of steps
        concat_ros = StepSequence.concat(rollouts)
        concat_ros.torch(data_type=to.get_default_dtype())

        # Compute the value targets (empirical discounted returns) for all samples before fitting the V-fcn parameters
        adv = self._critic.gae(concat_ros)  # done with to.no_grad()
        v_targ = discounted_values(rollouts, self._critic.gamma).view(-1, 1).to(self.policy.device)  # empirical discounted returns

        with to.no_grad():
            # Compute value predictions and the GAE using the old (before the updates) value function approximator
            v_pred = self._critic.values(concat_ros)

            # Compute the action probabilities using the old (before update) policy
            act_stats = compute_action_statistics(concat_ros, self._expl_strat)
            log_probs_old = act_stats.log_probs
            act_distr_old = act_stats.act_distr
            loss_before = self.loss_fcn(log_probs_old, adv, v_pred, v_targ)
            self.logger.add_value('loss before', loss_before, 4)

        concat_ros.add_data('adv', adv)
        concat_ros.add_data('v_targ', v_targ)

        # For logging the gradients' norms
        policy_grad_norm = []

        for batch in tqdm(concat_ros.split_shuffled_batches(
            self.batch_size,
            complete_rollouts=self._policy.is_recurrent or isinstance(self._critic.vfcn, RecurrentPolicy)),
            total=num_iter_from_rollouts(None, concat_ros, self.batch_size),
            desc='Updating', unit='batches', file=sys.stdout, leave=False):
            # Reset the gradients
            self.optim.zero_grad()

            # Compute log of the action probabilities for the mini-batch
            log_probs = compute_action_statistics(batch, self._expl_strat).log_probs

            # Compute value predictions for the mini-batch
            v_pred = self._critic.values(batch)

            # Compute combined loss and backpropagate
            loss = self.loss_fcn(log_probs, batch.adv, v_pred, batch.v_targ)
            loss.backward()

            # Clip the gradients if desired
            policy_grad_norm.append(self.clip_grad(self.expl_strat.policy, self.max_grad_norm))

            # Call optimizer
            self.optim.step()

        # Update the learning rate if a scheduler has been specified
        if self._lr_scheduler is not None:
            self._lr_scheduler.step()

        if to.isnan(self.expl_strat.noise.std).any():
            raise RuntimeError(f'At least one exploration parameter became NaN! The exploration parameters are'
                               f'\n{self.expl_strat.std.item()}')

        # Logging
        with to.no_grad():
            # Compute value predictions and the GAE using the new (after the updates) value function approximator
            v_pred = self._critic.values(concat_ros).to(self.policy.device)
            adv = self._critic.gae(concat_ros)  # done with to.no_grad()

            # Compute the action probabilities using the new (after the updates) policy
            act_stats = compute_action_statistics(concat_ros, self._expl_strat)
            log_probs_new = act_stats.log_probs
            act_distr_new = act_stats.act_distr
            loss_after = self.loss_fcn(log_probs_new, adv, v_pred, v_targ)
            kl_avg = to.mean(
                kl_divergence(act_distr_old, act_distr_new))  # mean seeking a.k.a. inclusive KL
            explvar = explained_var(v_pred, v_targ)  # values close to 1 are desired
            self.logger.add_value('loss after', loss_after, 4)
            self.logger.add_value('KL(old_new)', kl_avg, 4)
            self.logger.add_value('explained var', explvar, 4)

        ent = self.expl_strat.noise.get_entropy()
        self.logger.add_value('avg expl strat std', to.mean(self.expl_strat.noise.std), 4)
        self.logger.add_value('expl strat entropy', to.mean(ent), 4)
        self.logger.add_value('avg grad norm policy', np.mean(policy_grad_norm), 4)
        if self._lr_scheduler is not None:
            self.logger.add_value('avg lr', np.mean(self._lr_scheduler.get_last_lr()), 6)
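Note: explained_var is logged in both update() variants but not defined here. The usual explained-variance metric (values close to 1 mean the value function accounts for most of the return variance) is a one-liner; an assumed implementation:

import torch as to

def explained_var(y_pred, y_targ):
    # 1 - Var[residual] / Var[target]; returns a scalar tensor (assumed implementation).
    return 1.0 - to.var(y_targ - y_pred) / to.var(y_targ)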
Exemplo n.º 20
0
        torch_w = torch_dist.log_prob(torch.FloatTensor(x).unsqueeze(-1))

        if (rel_error(torch_w, scipy_w) > 1e-6).any():
            raise ValueError(
                "Log pdf of torch and scipy versions doesn't match")

    print("Testing wishart KL divergence...")
    df1, scale1 = torch.randn(32).exp() + 2, torch.randn(32).exp() + 1e-5
    df2, scale2 = torch.randn(32).exp() + 2, torch.randn(32).exp() + 1e-5
    init_df1, init_scale1 = df1[0].clone(), scale1[0].clone()
    dist2 = DiagonalWishart(scale2.unsqueeze(-1), df2)
    df1.requires_grad, scale1.requires_grad = True, True
    gamma = 0.1
    for k in range(10000):
        dist1 = DiagonalWishart(scale1.unsqueeze(-1), df1)
        loss = kl_divergence(dist1, dist2).mean()
        if k % 1000 == 0:
            print(k, loss.item())
        loss.backward()
        with torch.no_grad():
            scale1 = scale1 - gamma * scale1.grad
            df1 = df1 - gamma * df1.grad
        scale1.requires_grad, df1.requires_grad = True, True
    print('df1/scale1 init', init_df1, init_scale1)
    print('df1/scale1 final', df1[0], scale1[0])
    print('df2/scale2', df2[0], scale2[0])

    print("All tests passed.")

    print("Testing normal wishart...")
    y = np.linspace(5, 20, 100)
Exemplo n.º 21
0
def _compute_kl(policy_net, states):
    pi = policy_net(states)
    pi_detach = detach_distribution(pi)
    kl = torch.mean(kl_divergence(pi_detach, pi))
    return kl
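Note: _compute_kl relies on detach_distribution so that the KL is taken against a constant copy of the current policy. A sketch for the diagonal-Gaussian case (other distribution families would need their own branches):

from torch.distributions import Normal

def detach_distribution(pi):
    # Return a copy of pi whose parameters are detached from the graph (Normal case only).
    if isinstance(pi, Normal):
        return Normal(pi.loc.detach(), pi.scale.detach())
    raise NotImplementedError(f"Unsupported distribution type: {type(pi)}")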
Exemplo n.º 22
0
    def forward(self, img, z_propagte_what, z_propagte_where, z_propagte_pres,
                propgate_GlimpseEncoder):
        """
        :param img: [B H W]
        :param z_propagte_where:[B number where_length]
        :param z_propagte_what:[B number what_length]
        :param z_propagte_pres:[B number pres_length]
        :return:
        """
        #print("z_propagte_pres",z_propagte_pres.size(),"z_propagte_what.shape",z_propagte_what.size())
        n = img.size(0)
        loss = 0
        # initial

        h_dis = Variable(torch.zeros(n, 1, 256)).to(device)
        c_dis = Variable(torch.zeros(n, 1, 256)).to(device)
        z_pres = Variable(torch.ones(n, 1, 1)).to(device)
        z_where = Variable(torch.zeros(n, 1, 3)).to(device)
        z_what = Variable(torch.zeros(n, 1, 50)).to(device)
        h_dis_item = Variable(torch.zeros(n, 1, 256)).to(device)

        kl_z_what = torch.zeros(n, device=device)
        kl_z_where = torch.zeros(n, device=device)
        obj_probs = torch.ones(n, self.max_step, device=device)

        h_baseline = torch.zeros(n, 256, device=device)
        c_baseline = torch.zeros_like(h_baseline)
        baseline = torch.zeros(n, 1).to(device)
        """
        h_dis = zeros(n, 1, 256)
        c_dis = zeros(n, 1, 256)
        z_pres = zeros(n, 1, 1)
        z_where = zeros(n, 1, 3)
        z_what = zeros(n, 1, 50)
        h_dis_item= zeros(n, 1, 256)
        """
        if (use_cuda):
            e_t = self.encoder_img(img.view(img.size(0),
                                            50 * 50).to(device))  #[B 100]
        else:
            e_t = self.encoder_img(img.view(img.size(0), 50 * 50))  # [B 100]
        for i in range(self.max_step):
            if (i == 0):
                z_where_item = z_propagte_where[:, -1, :]
                z_what_item = z_propagte_what[:, -1, :]
            # z_pres_item=z_propagte_pres[:,-1,:]
            else:
                z_where_item = z_where[:, i - 1, :]
                z_what_item = z_what[:, i - 1, :]
                #z_pres_item = z_pres[:, i - 1, :]
            h_dis_item, c_dis = dis_hidden_state(self.dis_rnncell, e_t,
                                                 z_where_item, z_what_item,
                                                 h_dis_item,
                                                 c_dis)  #[B 1 hidden_size]
            z_pres_proba, z_where_mu, z_where_sd = self.latent_predic_where_and_pres(
                h_dis_item.squeeze(1))
            #print("z_pres_proba_discovery", z_pres_proba.size())
            loss += self.latent_loss(z_where_mu, z_where_sd)
            #print("z_pres_item_discovery",z_pres_item.size())
            z_where_item, z_pres_item = self._reparameterized_sample_where_and_pres(
                z_where_mu, z_where_sd, z_pres_proba)
            x_att = attentive_stn_encode(z_where_item,
                                         img)  # Spatial transform [B 400]
            encode = propgate_GlimpseEncoder(x_att)  # [B 100]
            z_what_mean, z_what_std = self.latent_predic_what(encode)  #[B 50]
            loss += self.latent_loss(z_what_mean, z_what_std)  #[1]
            z_what_item = self._reparameterized_sample_what(
                z_what_mean, z_what_std)
            if (i == 0):
                z_what = z_what_item.unsqueeze(1)
                z_where = z_where_item.unsqueeze(1)
                z_pres = z_pres_item.unsqueeze(1)
                h_dis = h_dis_item
            else:
                z_what = torch.cat((z_what, z_what_item.unsqueeze(1)), dim=1)
                z_where = torch.cat((z_where, z_where_item.unsqueeze(1)),
                                    dim=1)
                z_pres = torch.cat((z_pres, z_pres_item.unsqueeze(1)), dim=1)
                h_dis = torch.cat((h_dis, h_dis_item), dim=1)

            baseline_input = torch.cat([
                e_t.view(n, -1).detach(),
                z_pres_item.detach(),
                z_what_item.detach(),
                z_where_item.detach()
            ],
                                       dim=-1)  # [B,1+3+50+100]
            h_baseline, c_baseline = self.baseline_rnn(
                baseline_input, (h_baseline, c_baseline))  #[B 256]
            #print("test self.baseline_linear(h_baseline).squeeze()=",self.baseline_linear(h_baseline).size())
            baseline += self.baseline_linear(
                h_baseline
            )  # note: masking by z_pres give poorer results [B 1]

            kl_z_what += kl_divergence(
                Normal(z_what_mean, z_what_std),
                Normal(torch.zeros(50).to(device),
                       torch.ones(50).to(device))).sum(
                           1) * z_pres_item.squeeze()  #[B 1]
            kl_z_where += kl_divergence(
                Normal(z_where_mu, z_where_sd),
                Normal(
                    torch.tensor([0.3, 0., 0.]).to(device),
                    torch.tensor([
                        0.1, 1., 1.
                    ]).to(device))).sum(1) * z_pres_item.squeeze()  #[B 1]

            #pred_counts[:, i] = z_pres_item.flatten()  # [b MAX_STEP] binary
            obj_probs[:, i] = z_pres_proba[:,
                                           0]  # [b MAX_STEP] z_pres_proba[b 1]

        q_z_pres = self.compute_geometric_from_bernoulli(obj_probs).to(device)
        #print("torch.arange(n)", torch.arange(n).type())
        score_fn = q_z_pres[torch.arange(n).long(),
                            z_pres.long().squeeze(2).sum(1)].log(
                            )  # log prob of num objects under the geometric
        #print("z_pres.long()",z_pres.long().type())
        kl_z_pres = self.compute_z_pres_kl(
            q_z_pres,
            Geometric(torch.tensor([1 - self.z_pres_prob
                                    ]).to(device))).sum(1)  # [B 1]

        z_what = torch.cat((z_propagte_what, z_what), dim=1)
        z_where = torch.cat((z_propagte_where, z_where), dim=1)
        z_pres = torch.cat((z_propagte_pres, z_pres), dim=1)
        return z_what, z_where, z_pres, kl_z_pres, kl_z_where, kl_z_what, baseline, score_fn, h_dis
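Note: compute_z_pres_kl above compares the inferred distribution over object counts (built from the per-step Bernoulli probabilities) with a Geometric prior. A hedged sketch as a standalone function, evaluating the discrete KL term by term over the truncated support; the caller then sums over the count dimension as in the listing:

import torch

def compute_z_pres_kl(q_z_pres, p_geometric, eps=1e-12):
    # q_z_pres: (B, max_step + 1) probabilities over object counts 0..max_step.
    # p_geometric: torch.distributions.Geometric prior over counts.
    counts = torch.arange(q_z_pres.size(1), dtype=torch.float,
                          device=q_z_pres.device)
    log_p = p_geometric.log_prob(counts)                   # (max_step + 1,)
    return q_z_pres * (torch.log(q_z_pres + eps) - log_p)  # (B, max_step + 1)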
Exemplo n.º 23
0
    def forward(
        self,
        input,
        seq_lens,
        span_token_ids,
        target,
        target_lens,
        definition=None,
        definition_lens=None,
        classification_labels=None,
        sentence_mask=None,
    ):
        batch_size, tgt_len = target.shape

        # (batch_size,seq_len,hidden_size), (batch_size,hidden_size), (num_layers,batch_size,seq_len,hidden_size)
        all_hidden_layers, last_hidden_layer = self.context_encoder(input, seq_lens, initial_state=None)
        definition_representation, _ = self.definition_encoder(
            definition, definition_lens, initial_state=None
        )

        span_ids = self._id_extractor(tokens=span_token_ids, batch=input, lens=seq_lens)
        span_representation, hidden_states = self._span_aggregator(
            all_hidden_layers if self.scalar_mix is not None else last_hidden_layer,
            sequence_mask(seq_lens),
            span_ids,
        )
        span_representation = self.context_feed_forward(span_representation)

        definition_representation = self.definition_feed_forward(
            definition_representation
        )

        post_project = self.w_z_post(
            torch.cat([span_representation, definition_representation], -1)
        )
        prior_project = self.w_z_prior(span_representation)

        mu = self.mean_layer(post_project)
        logvar = self.logvar_layer(post_project)

        mu_prime = self.mean_prime_layer(prior_project)
        logvar_prime = self.logvar_prime_layer(prior_project)

        z = mu + torch.exp(logvar * 0.5) * torch.randn_like(logvar)
        span_representation = self.z_project(z)
        KLD = kl_divergence(
            Normal(mu, torch.exp(logvar * 0.5)),
            Normal(mu_prime, torch.exp(logvar_prime * 0.5)),
        )
        kl_mask = (KLD > (self.target_kl / self.latent_size)).float()
        fake_loss_kl = (kl_mask * KLD).sum(dim=1)

        predictions, logits = self.decoder(
            target, target_lens, span_representation, hidden_states, seq_lens,
        )

        loss = (
            F.cross_entropy(
                logits.view(batch_size * (tgt_len - 1), -1),
                target[:, 1:].contiguous().view(-1),
                ignore_index=self.embeddings.tgt.padding_idx,
                reduction="none",
            )
            .view(batch_size, tgt_len - 1)
            .sum(1)
        )

        perplexity = F.cross_entropy(
            logits.view(batch_size * (tgt_len - 1), -1),
            target[:, 1:].contiguous().view(-1),
            ignore_index=self.embeddings.tgt.padding_idx,
            reduction="mean",
        ).exp()
        return DotMap(
            {
                "predictions": predictions,
                "logits": logits,
                "loss": loss,
                "perplexity": perplexity,
                "fake_kl": fake_loss_kl,
                "kl": KLD,
                "cosine_loss": cosine_loss,
            }
        )
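Note: the KLD term above is the element-wise KL between two diagonal Gaussians. For reference, a small self-contained check of torch's kl_divergence against the closed form log(s'/s) + (s^2 + (mu - mu')^2) / (2 s'^2) - 1/2:

import torch
from torch.distributions import Normal, kl_divergence

mu, logvar = torch.randn(4, 8), torch.randn(4, 8)
mu_p, logvar_p = torch.randn(4, 8), torch.randn(4, 8)
s, s_p = (0.5 * logvar).exp(), (0.5 * logvar_p).exp()

kl_lib = kl_divergence(Normal(mu, s), Normal(mu_p, s_p))
kl_manual = torch.log(s_p / s) + (s ** 2 + (mu - mu_p) ** 2) / (2 * s_p ** 2) - 0.5
assert torch.allclose(kl_lib, kl_manual, rtol=1e-4, atol=1e-5)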
Exemplo n.º 24
0
def main(
    env_name='AntDirection-v1',
    adapt_lr=0.1,
    meta_lr=3e-4,
    adapt_steps=3,
    num_iterations=1000,
    meta_bsz=40,
    adapt_bsz=20,
    ppo_clip=0.3,
    ppo_steps=5,
    tau=1.00,
    gamma=0.99,
    eta=0.0005,
    adaptive_penalty=False,
    kl_target=0.01,
    num_workers=4,
    seed=421,
):
    random.seed(seed)
    np.random.seed(seed)
    th.manual_seed(seed)

    def make_env():
        env = gym.make(env_name)
        env = ch.envs.ActionSpaceScaler(env)
        return env

    env = l2l.gym.AsyncVectorEnv([make_env for _ in range(num_workers)])
    env.seed(seed)
    env = ch.envs.ActionSpaceScaler(env)
    env = ch.envs.Torch(env)
    policy = DiagNormalPolicy(input_size=env.state_size,
                              output_size=env.action_size,
                              hiddens=[64, 64],
                              activation='tanh')
    meta_learner = l2l.algorithms.MAML(policy, lr=meta_lr)
    baseline = LinearValue(env.state_size, env.action_size)
    opt = optim.Adam(meta_learner.parameters(), lr=meta_lr)

    for iteration in range(num_iterations):
        iteration_reward = 0.0
        iteration_replays = []
        iteration_policies = []

        # Sample Trajectories
        for task_config in tqdm(env.sample_tasks(meta_bsz),
                                leave=False,
                                desc='Data'):
            clone = deepcopy(meta_learner)
            env.set_task(task_config)
            env.reset()
            task = ch.envs.Runner(env)
            task_replay = []
            task_policies = []

            # Fast Adapt
            for step in range(adapt_steps):
                for p in clone.parameters():
                    p.detach_().requires_grad_()
                task_policies.append(deepcopy(clone))
                train_episodes = task.run(clone, episodes=adapt_bsz)
                clone = fast_adapt_a2c(clone,
                                       train_episodes,
                                       adapt_lr,
                                       baseline,
                                       gamma,
                                       tau,
                                       first_order=True)
                task_replay.append(train_episodes)

            # Compute Validation Loss
            for p in clone.parameters():
                p.detach_().requires_grad_()
            task_policies.append(deepcopy(clone))
            valid_episodes = task.run(clone, episodes=adapt_bsz)
            task_replay.append(valid_episodes)
            iteration_reward += valid_episodes.reward().sum().item(
            ) / adapt_bsz
            iteration_replays.append(task_replay)
            iteration_policies.append(task_policies)

        # Print statistics
        print('\nIteration', iteration)
        adaptation_reward = iteration_reward / meta_bsz
        print('adaptation_reward', adaptation_reward)

        # ProMP meta-optimization
        for ppo_step in tqdm(range(ppo_steps), leave=False, desc='Optim'):
            promp_loss = 0.0
            kl_total = 0.0
            for task_replays, old_policies in zip(iteration_replays,
                                                  iteration_policies):
                new_policy = meta_learner.clone()
                states = task_replays[0].state()
                actions = task_replays[0].action()
                rewards = task_replays[0].reward()
                dones = task_replays[0].done()
                next_states = task_replays[0].next_state()
                old_policy = old_policies[0]
                (old_density, new_density, old_log_probs,
                 new_log_probs) = precompute_quantities(
                     states, actions, old_policy, new_policy)
                advantages = compute_advantages(baseline, tau, gamma, rewards,
                                                dones, states, next_states)
                advantages = ch.normalize(advantages).detach()
                for step in range(adapt_steps):
                    # Compute KL penalty
                    kl_pen = kl_divergence(old_density, new_density).mean()
                    kl_total += kl_pen.item()

                    # Update the clone
                    surr_loss = trpo.policy_loss(new_log_probs, old_log_probs,
                                                 advantages)
                    new_policy.adapt(surr_loss)

                    # Move to next adaptation step
                    states = task_replays[step + 1].state()
                    actions = task_replays[step + 1].action()
                    rewards = task_replays[step + 1].reward()
                    dones = task_replays[step + 1].done()
                    next_states = task_replays[step + 1].next_state()
                    old_policy = old_policies[step + 1]
                    (old_density, new_density, old_log_probs,
                     new_log_probs) = precompute_quantities(
                         states, actions, old_policy, new_policy)

                    # Compute clip loss
                    advantages = compute_advantages(baseline, tau, gamma,
                                                    rewards, dones, states,
                                                    next_states)
                    advantages = ch.normalize(advantages).detach()
                    clip_loss = ppo.policy_loss(new_log_probs,
                                                old_log_probs,
                                                advantages,
                                                clip=ppo_clip)

                    # Combine into ProMP loss
                    promp_loss += clip_loss + eta * kl_pen

            kl_total /= meta_bsz * adapt_steps
            promp_loss /= meta_bsz * adapt_steps
            opt.zero_grad()
            promp_loss.backward(retain_graph=True)
            opt.step()

            # Adapt KL penalty based on desired target
            if adaptive_penalty:
                if kl_total < kl_target / 1.5:
                    eta /= 2.0
                elif kl_total > kl_target * 1.5:
                    eta *= 2.0
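Note: precompute_quantities is not part of this listing. A minimal sketch that matches how its return values are used above (densities for the KL penalty, log-probabilities for the surrogate losses), assuming the policies expose a density(states) method returning a torch distribution; treat it as an illustration rather than the original helper:

def precompute_quantities(states, actions, old_policy, new_policy):
    # Densities under the old (fixed) and new (adapting) policies.
    old_density = old_policy.density(states)
    new_density = new_policy.density(states)
    # Per-sample log-probabilities of the taken actions; the old ones are detached.
    old_log_probs = old_density.log_prob(actions).mean(dim=1, keepdim=True).detach()
    new_log_probs = new_density.log_prob(actions).mean(dim=1, keepdim=True)
    return old_density, new_density, old_log_probs, new_log_probs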
Exemplo n.º 25
0
    def _validate(
        self,
        input,
        seq_lens,
        span_token_ids,
        target,
        target_lens,
        decode_strategy,
        definition=None,
        definition_lens=None,
        sentence_mask=None,
    ):
        batch_size, tgt_len = target.shape

        # (batch_size,seq_len,hidden_size), (batch_size,hidden_size), (num_layers,batch_size,seq_len,hidden_size)
        all_hidden_layers, last_hidden_layer = self.context_encoder(input, seq_lens, initial_state=None)
        definition_representation, _ = self.definition_encoder(
            definition, definition_lens, initial_state=None
        )

        span_ids = self._id_extractor(tokens=span_token_ids, batch=input, lens=seq_lens)
        span_representation, hidden_states = self._span_aggregator(
            all_hidden_layers if self.scalar_mix is not None else last_hidden_layer,
            sequence_mask(seq_lens),
            span_ids,
        )
        span_representation = self.context_feed_forward(span_representation)

        definition_representation = self.definition_feed_forward(
            definition_representation
        )

        post_project = self.w_z_post(
            torch.cat([span_representation, definition_representation], -1)
        )
        prior_project = self.w_z_prior(span_representation)

        mu = self.mean_layer(post_project)
        logvar = self.logvar_layer(post_project)

        mu_prime = self.mean_prime_layer(prior_project)
        logvar_prime = self.logvar_prime_layer(prior_project)

        z = mu + torch.exp(logvar * 0.5) * torch.randn_like(logvar)
        span_representation = self.z_project(z)
        KLD = (
            kl_divergence(
                Normal(mu, torch.exp(logvar * 0.5)),
                Normal(mu_prime, torch.exp(logvar_prime * 0.5)),
            )
            .sum(1)
            .mean()
        )

        memory_bank = hidden_states if self.decoder.attention else None
        _, logits = self.decoder(
            target, target_lens, span_representation, memory_bank, seq_lens,
        )

        loss = F.cross_entropy(
            logits.view(batch_size * (tgt_len - 1), -1),
            target[:, 1:].contiguous().view(-1),
            ignore_index=self.embeddings.tgt.padding_idx,
        )

        ppl = loss.exp()
        beam_results = self._strategic_decode(
            target,
            target_lens,
            decode_strategy,
            memory_bank,
            seq_lens,
            span_representation,
        )
        return DotMap(
            {
                "predictions": beam_results["predictions"],
                "logits": logits.view(batch_size * (tgt_len - 1), -1),
                "loss": loss,
                "perplexity": ppl,
                "kl": KLD,
            }
        )
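Note: sequence_mask(seq_lens) is used by both definition-model listings to mask padded positions. A common implementation builds a boolean mask from the lengths; sketched here under the assumption that seq_lens is a 1-D tensor of valid lengths:

import torch

def sequence_mask(seq_lens, max_len=None):
    # (B,) lengths -> (B, max_len) boolean mask, True where a position is valid.
    max_len = max_len or int(seq_lens.max())
    positions = torch.arange(max_len, device=seq_lens.device)
    return positions.unsqueeze(0) < seq_lens.unsqueeze(1)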
Exemplo n.º 26
0
    def compute_losses(self,
                       pd_dict: dict,
                       data_batch: dict,
                       recon_weight=1.0,
                       kl_weight=1.0,
                       **kwargs):
        loss_dict = dict()
        image = data_batch['image']
        b, c0, h0, w0 = image.size()

        # ---------------------------------------------------------------------------- #
        # Reconstruction
        # ---------------------------------------------------------------------------- #
        fg_recon = pd_dict['fg_recon']  # (b, c0, h0, w0)
        fg_mask = pd_dict['fg_mask']  # (b, 1, h0, w0)
        bg_recon = pd_dict['bg_recon']  # (b, c0, h0, w0)

        fg_recon_dist = Normal(fg_recon, scale=data_batch['fg_recon_scale_prior'])
        fg_recon_log_prob = fg_recon_dist.log_prob(image) + torch.log(fg_mask.clamp(min=self._eps))
        bg_recon_dist = Normal(bg_recon, scale=data_batch['bg_recon_scale_prior'])
        bg_recon_log_prob = bg_recon_dist.log_prob(image) + torch.log((1.0 - fg_mask).clamp(min=self._eps))
        # conditional probability p(x|z) = p(x|fg, z) * p(fg|z) + p(x|bg, z) * p(bg|z)
        image_recon_log_prob = torch.stack([fg_recon_log_prob, bg_recon_log_prob], dim=1)
        # log likelihood, (b, c0, h0, w0)
        image_recon_log_prob = torch.logsumexp(image_recon_log_prob, dim=1)

        observation_nll = - torch.sum(image_recon_log_prob, dim=[1, 2, 3])
        loss_dict['recon_loss'] = observation_nll.mean() * recon_weight

        # ---------------------------------------------------------------------------- #
        # KL divergence (z_where)
        # ---------------------------------------------------------------------------- #
        if 'z_where_loc_prior' in data_batch and 'z_where_scale_prior' in data_batch:
            z_where_post = pd_dict['z_where_post']  # (b, A * h1 * w1, 4)
            z_where_prior = Normal(loc=data_batch['z_where_loc_prior'],
                                   scale=data_batch['z_where_scale_prior'],
                                   )
            # (b, A * h1 * w1, 4)
            kl_where = kl_divergence(z_where_post, z_where_prior.expand(z_where_post.batch_shape))
            kl_where = kl_where.reshape(b, -1).sum(1)
            loss_dict['kl_where_loss'] = kl_where.mean() * kl_weight

        # ---------------------------------------------------------------------------- #
        # KL divergence (z_what)
        # ---------------------------------------------------------------------------- #
        if 'z_what_loc_prior' in data_batch and 'z_what_scale_prior' in data_batch:
            z_what_post = pd_dict['z_what_post']  # (b * A * h1 * w1, z_what_size)
            z_what_prior = Normal(loc=data_batch['z_what_loc_prior'],
                                  scale=data_batch['z_what_scale_prior'],
                                  )
            # (b * A * h1 * w1, z_what_size)
            kl_what = kl_divergence(z_what_post, z_what_prior.expand(z_what_post.batch_shape))
            kl_what = kl_what.reshape(b, -1).sum(1)
            loss_dict['kl_what_loss'] = kl_what.mean() * kl_weight

        # ---------------------------------------------------------------------------- #
        # KL divergence (z_pres)
        # ---------------------------------------------------------------------------- #
        if 'z_pres_p_prior' in data_batch:
            z_pres_p = pd_dict['z_pres_p']  # (b, A * h1 * w1)
            z_pres_post = Bernoulli(probs=z_pres_p)
            z_pres_prior = Bernoulli(probs=data_batch['z_pres_p_prior'])
            kl_pres = kl_divergence(z_pres_post, z_pres_prior.expand(z_pres_post.batch_shape))
            kl_pres = kl_pres.reshape(b, -1).sum(1)
            loss_dict['kl_pres_loss'] = kl_pres.mean() * kl_weight

        # ---------------------------------------------------------------------------- #
        # KL divergence (z_depth)
        # ---------------------------------------------------------------------------- #
        if 'z_depth_loc_prior' in data_batch and 'z_depth_scale_prior' in data_batch:
            z_depth_post = pd_dict['z_depth_post']  # (b, A * h1 * w1)
            z_depth_prior = Normal(loc=data_batch['z_depth_loc_prior'],
                                   scale=data_batch['z_depth_scale_prior'],
                                   )
            # (b, A * h1 * w1)
            kl_depth = kl_divergence(z_depth_post, z_depth_prior.expand(z_depth_post.batch_shape))
            kl_depth = kl_depth.reshape(b, -1).sum(1)
            loss_dict['kl_depth_loss'] = kl_depth.mean() * kl_weight

        return loss_dict
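Note: each prior above is expanded to the posterior's batch_shape before kl_divergence is called, which keeps the two distributions' shapes aligned. A tiny standalone illustration of the same pattern:

import torch
from torch.distributions import Normal, kl_divergence

post = Normal(torch.randn(8, 16, 4), torch.rand(8, 16, 4) + 0.1)  # per-sample posterior
prior = Normal(torch.zeros(4), torch.ones(4))                     # shared prior
kl = kl_divergence(post, prior.expand(post.batch_shape))          # (8, 16, 4)
kl = kl.reshape(8, -1).sum(1)                                     # (8,), per-sample KL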
Exemplo n.º 27
0
    def train_env_model(self, beliefs, prior_states, prior_means,
                        prior_std_devs, posterior_states, posterior_means,
                        posterior_std_devs, observations, actions, rewards,
                        nonterminals):
        # Calculate observation likelihood, reward likelihood and KL losses (for t = 0 only for latent overshooting); sum over final dims, average over batch and time (original implementation, though paper seems to miss 1/T scaling?)
        if args.worldmodel_LogProbLoss:
            observation_dist = Normal(
                bottle(self.observation_model, (beliefs, posterior_states)), 1)
            observation_loss = -observation_dist.log_prob(
                observations[1:]).sum(
                    dim=2 if args.symbolic_env else (2, 3, 4)).mean(dim=(0, 1))
        else:
            observation_loss = F.mse_loss(
                bottle(self.observation_model, (beliefs, posterior_states)),
                observations[1:],
                reduction='none').sum(
                    dim=2 if args.symbolic_env else (2, 3, 4)).mean(dim=(0, 1))
        if args.worldmodel_LogProbLoss:
            reward_dist = Normal(
                bottle(self.reward_model, (beliefs, posterior_states)), 1)
            reward_loss = -reward_dist.log_prob(rewards[:-1]).mean(dim=(0, 1))
        else:
            reward_loss = F.mse_loss(bottle(self.reward_model,
                                            (beliefs, posterior_states)),
                                     rewards[:-1],
                                     reduction='none').mean(dim=(0, 1))

        # transition loss
        div = kl_divergence(Normal(posterior_means, posterior_std_devs),
                            Normal(prior_means, prior_std_devs)).sum(dim=2)
        kl_loss = torch.max(div, self.free_nats).mean(
            dim=(0, 1)
        )  # Note that normalisation by overshooting distance and weighting by overshooting distance cancel out
        if args.global_kl_beta != 0:
            kl_loss += args.global_kl_beta * kl_divergence(
                Normal(posterior_means, posterior_std_devs),
                self.global_prior).sum(dim=2).mean(dim=(0, 1))
        # Calculate latent overshooting objective for t > 0
        if args.overshooting_kl_beta != 0:
            overshooting_vars = [
            ]  # Collect variables for overshooting to process in batch
            for t in range(1, args.chunk_size - 1):
                d = min(t + args.overshooting_distance,
                        args.chunk_size - 1)  # Overshooting distance
                t_, d_ = t - 1, d - 1  # Use t_ and d_ to deal with different time indexing for latent states
                seq_pad = (
                    0, 0, 0, 0, 0, t - d + args.overshooting_distance
                )  # Calculate sequence padding so overshooting terms can be calculated in one batch
                # Store (0) actions, (1) nonterminals, (2) rewards, (3) beliefs, (4) prior states, (5) posterior means, (6) posterior standard deviations and (7) sequence masks
                overshooting_vars.append(
                    (F.pad(actions[t:d],
                           seq_pad), F.pad(nonterminals[t:d], seq_pad),
                     F.pad(rewards[t:d],
                           seq_pad[2:]), beliefs[t_], prior_states[t_],
                     F.pad(posterior_means[t_ + 1:d_ + 1].detach(), seq_pad),
                     F.pad(posterior_std_devs[t_ + 1:d_ + 1].detach(),
                           seq_pad,
                           value=1),
                     F.pad(
                         torch.ones(d - t,
                                    args.batch_size,
                                    args.state_size,
                                    device=args.device), seq_pad))
                )  # Posterior standard deviations must be padded with > 0 to prevent infinite KL divergences
            overshooting_vars = tuple(zip(*overshooting_vars))
            # Update belief/state using prior from previous belief/state and previous action (over entire sequence at once)
            beliefs, prior_states, prior_means, prior_std_devs = self.upper_transition_model(
                torch.cat(overshooting_vars[4], dim=0),
                torch.cat(overshooting_vars[0], dim=1),
                torch.cat(overshooting_vars[3], dim=0), None,
                torch.cat(overshooting_vars[1], dim=1))
            seq_mask = torch.cat(overshooting_vars[7], dim=1)
            # Calculate overshooting KL loss with sequence mask
            kl_loss += (
                1 / args.overshooting_distance
            ) * args.overshooting_kl_beta * torch.max((kl_divergence(
                Normal(torch.cat(overshooting_vars[5], dim=1),
                       torch.cat(overshooting_vars[6], dim=1)),
                Normal(prior_means, prior_std_devs)
            ) * seq_mask).sum(dim=2), self.free_nats).mean(dim=(0, 1)) * (
                args.chunk_size
                - 1
            )  # Update KL loss (compensating for extra average over each overshooting/open loop sequence)
            # Calculate overshooting reward prediction loss with sequence mask
            if args.overshooting_reward_scale != 0:
                reward_loss += (
                    1 / args.overshooting_distance
                ) * args.overshooting_reward_scale * F.mse_loss(
                    bottle(self.reward_model,
                           (beliefs, prior_states)) * seq_mask[:, :, 0],
                    torch.cat(overshooting_vars[2], dim=1),
                    reduction='none'
                ).mean(dim=(0, 1)) * (
                    args.chunk_size - 1
                )  # Update reward loss (compensating for extra average over each overshooting/open loop sequence)
        # Apply linearly ramping learning rate schedule
        if args.learning_rate_schedule != 0:
            for group in self.model_optimizer.param_groups:
                group['lr'] = min(
                    group['lr'] + args.model_learning_rate /
                    args.model_learning_rate_schedule,
                    args.model_learning_rate)
        model_loss = observation_loss + reward_loss + kl_loss
        # Update model parameters
        self.model_optimizer.zero_grad()
        model_loss.backward()
        nn.utils.clip_grad_norm_(self.param_list,
                                 args.grad_clip_norm,
                                 norm_type=2)
        self.model_optimizer.step()
        return observation_loss, reward_loss, kl_loss
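Note: self.free_nats (and the module-level free_nats in the other PlaNet-style listings) is the usual free-nats floor on the KL; it is typically just a constant tensor so that torch.max(div, free_nats) broadcasts over time and batch. An assumed construction:

import torch

free_nats = torch.full((1,), 3.0)  # 3 free nats is a common PlaNet setting; the scripts take the value from args.free_nats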
Exemplo n.º 28
0

for episode in tqdm(range(metrics['episodes'][-1] + 1, args.episodes + 1), total=args.episodes, initial=metrics['episodes'][-1] + 1):
  # Model fitting
  losses = []
  for s in tqdm(range(args.collect_interval)):
    # Draw sequence chunks {(o_t, a_t, r_t+1, terminal_t+1)} ~ D uniformly at random from the dataset (including terminal flags)
    observations, actions, rewards, nonterminals = D.sample(args.batch_size, args.chunk_size)  # Transitions start at time t = 0
    # Create initial belief and state for time t = 0
    init_belief, init_state = torch.zeros(args.batch_size, args.belief_size, device=args.device), torch.zeros(args.batch_size, args.state_size, device=args.device)
    # Update belief/state using posterior from previous belief/state, previous action and current observation (over entire sequence at once)
    beliefs, prior_states, prior_means, prior_std_devs, posterior_states, posterior_means, posterior_std_devs = transition_model(init_state, actions[:-1], init_belief, bottle(encoder, (observations[1:], )), nonterminals[:-1])
    # Calculate observation likelihood, reward likelihood and KL losses (for t = 0 only for latent overshooting); sum over final dims, average over batch and time (original implementation, though paper seems to miss 1/T scaling?)
    observation_loss = F.mse_loss(bottle(observation_model, (beliefs, posterior_states)), observations[1:], reduction='none').sum(dim=2 if args.symbolic_env else (2, 3, 4)).mean(dim=(0, 1))
    reward_loss = F.mse_loss(bottle(reward_model, (beliefs, posterior_states)), rewards[:-1], reduction='none').mean(dim=(0, 1))
    kl_loss = torch.max(kl_divergence(Normal(posterior_means, posterior_std_devs), Normal(prior_means, prior_std_devs)).sum(dim=2), free_nats).mean(dim=(0, 1))  # Note that normalisation by overshooting distance and weighting by overshooting distance cancel out
    if args.global_kl_beta != 0:
      kl_loss += args.global_kl_beta * kl_divergence(Normal(posterior_means, posterior_std_devs), global_prior).sum(dim=2).mean(dim=(0, 1))
    # Calculate latent overshooting objective for t > 0
    if args.overshooting_kl_beta != 0:
      overshooting_vars = []  # Collect variables for overshooting to process in batch
      for t in range(1, args.chunk_size - 1):
        d = min(t + args.overshooting_distance, args.chunk_size - 1)  # Overshooting distance
        t_, d_ = t - 1, d - 1  # Use t_ and d_ to deal with different time indexing for latent states
        seq_pad = (0, 0, 0, 0, 0, t - d + args.overshooting_distance)  # Calculate sequence padding so overshooting terms can be calculated in one batch
        # Store (0) actions, (1) nonterminals, (2) rewards, (3) beliefs, (4) prior states, (5) posterior means, (6) posterior standard deviations and (7) sequence masks
        overshooting_vars.append((F.pad(actions[t:d], seq_pad), F.pad(nonterminals[t:d], seq_pad), F.pad(rewards[t:d], seq_pad[2:]), beliefs[t_], prior_states[t_], F.pad(posterior_means[t_ + 1:d_ + 1].detach(), seq_pad), F.pad(posterior_std_devs[t_ + 1:d_ + 1].detach(), seq_pad, value=1), F.pad(torch.ones(d - t, args.batch_size, args.state_size, device=args.device), seq_pad)))  # Posterior standard deviations must be padded with > 0 to prevent infinite KL divergences
      overshooting_vars = tuple(zip(*overshooting_vars))
      # Update belief/state using prior from previous belief/state and previous action (over entire sequence at once)
      beliefs, prior_states, prior_means, prior_std_devs = transition_model(torch.cat(overshooting_vars[4], dim=0), torch.cat(overshooting_vars[0], dim=1), torch.cat(overshooting_vars[3], dim=0), None, torch.cat(overshooting_vars[1], dim=1))
      seq_mask = torch.cat(overshooting_vars[7], dim=1)
Exemplo n.º 29
0
# Loss and optimizer
# nn.CrossEntropyLoss() computes softmax internally
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(encoder.parameters(), lr=learning_rate)

# Train the model
total_step = len(train_loader)
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        # Reshape images to (batch_size, input_size)
        images = images.reshape(-1, 28 * 28)

        # Forward pass
        z, dist, mu = encoder(images)
        prior = Normal(torch.zeros_like(z), torch.ones_like(z))
        kl = 0.001 * kl_divergence(dist, prior).sum(1).mean()
        loss = criterion(z, labels) + kl

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (i + 1) % 100 == 0:
            print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(
                epoch + 1, num_epochs, i + 1, total_step, loss.item()))

# Test the model
# In test phase, we don't need to compute gradients (for memory efficiency)
with torch.no_grad():
    correct = 0.
Exemplo n.º 30
0
def compute_loss(observations, actions, ssm, observation_model):
    # Input is B x T x ...., need T x B x ....
    observations = torch.transpose(observations, 0, 1).contiguous()
    actions = torch.transpose(actions, 0, 1).contiguous()

    # Create initial belief and state for time t = 0
    init_state = torch.zeros(args.batch_size,
                             args.state_size,
                             device=args.device)
    # Update belief/state using posterior from previous belief/state, previous action and current observation (over entire sequence at once)
    prior_states, prior_means, prior_std_devs, posterior_states, posterior_means, posterior_std_devs = ssm(
        init_state, actions[:-1], observations[1:])
    # Calculate observation likelihood, reward likelihood and KL losses (for t = 0 only for latent overshooting); sum over final dims, average over batch and time (original implementation, though paper seems to miss 1/T scaling?)
    observation_loss = F.mse_loss(bottle(observation_model,
                                         (posterior_states, )),
                                  observations[1:],
                                  reduction='none').sum(dim=(2, 3,
                                                             4)).mean(dim=(0,
                                                                           1))
    kl_loss = torch.max(
        kl_divergence(Normal(posterior_means, posterior_std_devs),
                      Normal(prior_means,
                             prior_std_devs)).sum(dim=2), free_nats
    ).mean(
        dim=(0, 1)
    )  # Note that normalisation by overshooting distance and weighting by overshooting distance cancel out
    if args.global_kl_beta != 0:
        kl_loss += args.global_kl_beta * kl_divergence(
            Normal(posterior_means, posterior_std_devs),
            global_prior).sum(dim=2).mean(dim=(0, 1))
    # Calculate latent overshooting objective for t > 0
    if args.overshooting_kl_beta != 0:
        overshooting_vars = [
        ]  # Collect variables for overshooting to process in batch
        for t in range(1, args.chunk_size - 1):
            d = min(t + args.overshooting_distance,
                    args.chunk_size - 1)  # Overshooting distance
            t_, d_ = t - 1, d - 1  # Use t_ and d_ to deal with different time indexing for latent states
            seq_pad = (
                0, 0, 0, 0, 0, t - d + args.overshooting_distance
            )  # Calculate sequence padding so overshooting terms can be calculated in one batch
            # Store (0) actions, (1) prior states, (2) posterior means, (3) posterior standard deviations and (4) sequence masks
            overshooting_vars.append(
                (F.pad(actions[t:d], seq_pad), prior_states[t_],
                 F.pad(posterior_means[t_ + 1:d_ + 1].detach(), seq_pad),
                 F.pad(posterior_std_devs[t_ + 1:d_ + 1].detach(),
                       seq_pad,
                       value=1),
                 F.pad(
                     torch.ones(d - t,
                                args.batch_size,
                                args.state_size,
                                device=args.device), seq_pad))
            )  # Posterior standard deviations must be padded with > 0 to prevent infinite KL divergences
        overshooting_vars = tuple(zip(*overshooting_vars))
        # Update belief/state using prior from previous belief/state and previous action (over entire sequence at once)
        prior_states, prior_means, prior_std_devs = ssm(
            torch.cat(overshooting_vars[1], dim=0),
            torch.cat(overshooting_vars[0], dim=1), None)
        seq_mask = torch.cat(overshooting_vars[4], dim=1)
        # Calculate overshooting KL loss with sequence mask
        kl_loss += (
            1 / args.overshooting_distance
        ) * args.overshooting_kl_beta * torch.max(
            (kl_divergence(
                Normal(torch.cat(overshooting_vars[2], dim=1),
                       torch.cat(overshooting_vars[3], dim=1)),
                Normal(prior_means, prior_std_devs)) * seq_mask).sum(dim=2),
            free_nats).mean(dim=(0, 1)) * (
                args.chunk_size - 1
            )  # Update KL loss (compensating for extra average over each overshooting/open loop sequence)
    return observation_loss + kl_loss, observation_loss, kl_loss
Exemplo n.º 31
0
# TODO: to make this stochastic, shuffle and make smaller batches.
start = time.time()
theta.train()
for epoch in range(args.num_epochs*2):
    # Keep track of reconstruction loss and total kl
    total_recon_loss = 0
    total_kl = 0
    total = 0
    for img, _ in loader:
        # no need to Variable(img).cuda()
        optim1.zero_grad()
        optim2.zero_grad()
        q = Normal(loc=mu, scale=logvar.mul(0.5).exp())
        # Reparameterized sample.
        qsamp = q.rsample()
        kl = kl_divergence(q, p).sum() # KL term
        out = theta(qsamp)
        recon_loss = criterion(out, img) # reconstruction term
        loss = (recon_loss + args.alpha * kl) / args.batch_size
        total_recon_loss += recon_loss.item() / args.batch_size
        total_kl += kl.item() / args.batch_size
        total += 1
        loss.backward()
        if args.clip:
            torch.nn.utils.clip_grad_norm_(theta.parameters(), args.clip)
            torch.nn.utils.clip_grad_norm_(mu, args.clip)
            torch.nn.utils.clip_grad_norm_(logvar, args.clip)
        if epoch % 2:
            optim1.step()
            wv = 'Theta'
            # print(theta.linear1.weight[:56:4])