Example 1
File: gqn.py Project: soudia/snp
    def forward(self, C_t, D_t):
        n_episodes = D_t.size(0)

        # initialize recurrent states for the prior (p) and posterior (q) cores
        state_p = (self.rnn_p if self.shared_core else
                   self.rnn_p[0]).init_state(n_episodes,
                                             [self.z_height, self.z_height])
        state_q = (self.rnn_q if self.shared_core else
                   self.rnn_q[0]).init_state(n_episodes,
                                             [self.z_height, self.z_height])
        hidden_p = state_p[0]

        z_t = []
        kl = ScaledNormalizedAdder(
            next(self.parameters()).new_zeros(1), (1.0 / n_episodes))
        log_pq_ratio = NormalizedAdder(
            next(self.parameters()).new_zeros(n_episodes))
        for i in range(self.n_draw_steps):
            # select inference and generation core
            _rnn_p = self.rnn_p if self.shared_core else self.rnn_p[i]
            _rnn_q = self.rnn_q if self.shared_core else self.rnn_q[i]

            input_q = torch.cat([hidden_p, D_t], dim=1)
            hidden_q, state_q = _rnn_q(input_q, state_q)

            mean_q, logvar_q = self.reparam_q(hidden_q)
            mean_p, logvar_p = self.reparam_p(hidden_p)

            # sample z from q
            z_t_i = self.reparam_q.sample_gaussian(mean_q, logvar_q)

            # log p(z)/q(z): difference of negative log-likelihoods of the
            # sample under q and under p
            log_pq_ratio.append(
                loss_recon_gaussian(mean_q.view(n_episodes, -1),
                                    logvar_q.view(n_episodes, -1),
                                    z_t_i.view(n_episodes, -1),
                                    reduction="batch_sum") -
                loss_recon_gaussian(mean_p.view(n_episodes, -1),
                                    logvar_p.view(n_episodes, -1),
                                    z_t_i.view(n_episodes, -1),
                                    reduction="batch_sum"))

            # update prior rnn
            input_p = torch.cat([z_t_i, C_t], dim=1)
            hidden_p, state_p = _rnn_p(input_p, state_p)

            # append z to latent
            z_t += [z_t_i]
            kl.append(
                loss_kld_gaussian_vs_gaussian(mean_q, logvar_q, mean_p,
                                              logvar_p))

        # concat z
        z_t = torch.cat(z_t, dim=1) if self.concat_latents else z_t_i

        return {
            "z_t": z_t,
            "kl": kl.sum,
            "log_pz/qz_batchwise": log_pq_ratio.sum,
        }
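The NormalizedAdder and ScaledNormalizedAdder helpers used above are not shown in this snippet. A minimal sketch consistent with how they are called (an append method plus a sum attribute, with the scaled variant weighting each term by a fixed factor such as 1.0 / n_episodes) might look like this; the real helpers may differ:

    class NormalizedAdder:
        # Hypothetical reconstruction: running-sum accumulator whose
        # total is read back via the .sum attribute.
        def __init__(self, init):
            self.sum = init

        def append(self, value):
            self.sum = self.sum + value


    class ScaledNormalizedAdder(NormalizedAdder):
        # Hypothetical reconstruction: accumulator that multiplies each
        # appended term by a fixed scale before summing.
        def __init__(self, init, scale):
            super().__init__(init)
            self.scale = scale

        def append(self, value):
            self.sum = self.sum + self.scale * value

Under this reading, kl.sum is the KL summed over draw steps and averaged over episodes, while log_pq_ratio.sum keeps a per-episode vector.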
Example 2
    def loss(
        self,
        mu_qz,
        logvar_qz,
        mu_qz0,
        logvar_qz0,
        mu_pz0,
        logvar_pz0,
        mu_x,
        logvar_x,
        target_x,
        beta=1.0,
    ):
        # kld loss: log q(z|z0, x) - log p(z)
        kld_loss = loss_kld_gaussian(mu_qz, logvar_qz, do_sum=False)

        # aux kld loss: log q(z0|x) - log r(z0|z,x)
        aux_kld_loss = loss_kld_gaussian_vs_gaussian(
            mu_qz0,
            logvar_qz0,
            mu_pz0,
            logvar_pz0,
            do_sum=False,
        )

        # recon loss (neg likelihood): -log p(x|z)
        recon_loss = loss_recon_gaussian(mu_x,
                                         logvar_x,
                                         target_x.view(-1, 2),
                                         do_sum=False)

        # add loss
        loss = recon_loss + beta * kld_loss + beta * aux_kld_loss
        return (loss.mean(), recon_loss.mean(), kld_loss.mean(),
                aux_kld_loss.mean())
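The auxiliary term relies on loss_kld_gaussian_vs_gaussian. Assuming it computes the closed-form KL divergence between two diagonal Gaussians parameterized by mean and log-variance, which matches how it is called here and in Example 1, a sketch could be:

    import torch

    def loss_kld_gaussian_vs_gaussian(mu_q, logvar_q, mu_p, logvar_p,
                                      do_sum=True):
        # Closed-form KL(N(mu_q, exp(logvar_q)) || N(mu_p, exp(logvar_p)))
        # for diagonal Gaussians, summed over latent dimensions.
        kld = 0.5 * (logvar_p - logvar_q
                     + (logvar_q.exp() + (mu_q - mu_p).pow(2))
                     / logvar_p.exp()
                     - 1.0)
        kld = kld.view(kld.size(0), -1).sum(dim=1)  # per-sample KL
        return kld.sum() if do_sum else kld

With do_sum=False this returns one KL value per sample, which is why the examples can call .mean() on the result afterward.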
Example 3
    def aux_loss(self, mu_pz0, logvar_pz0, target_z0):
        # aux dec loss: -log r(z0|z,x)
        aux_recon_loss = loss_recon_gaussian(mu_pz0,
                                             logvar_pz0,
                                             target_z0.view(-1, self.z0_dim),
                                             do_sum=False)
        return aux_recon_loss
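All five examples call loss_recon_gaussian. Assuming it returns the Gaussian negative log-likelihood of the target under N(mu, exp(logvar)), consistent with the "-log p(x|z)" comments, a minimal sketch is below; note the snippets call it with both a do_sum flag and a reduction keyword, and this sketch covers only the do_sum form:

    import math
    import torch

    def loss_recon_gaussian(mu, logvar, target, do_sum=True):
        # Negative log-likelihood of target under a diagonal Gaussian
        # N(mu, exp(logvar)), summed over feature dimensions.
        nll = 0.5 * (math.log(2 * math.pi) + logvar
                     + (target - mu).pow(2) / logvar.exp())
        nll = nll.view(nll.size(0), -1).sum(dim=1)  # per-sample NLL
        return nll.sum() if do_sum else nll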
Example 4
    def primary_loss(self, z, mu_px, logvar_px, target_x):
        # loss from energy func
        prior_loss = self.energy_func(z.view(-1, self.z_dim))

        # recon loss (neg likelihood): -log p(x|z)
        recon_loss = loss_recon_gaussian(mu_px,
                                         logvar_px,
                                         target_x.view(-1, self.input_dim),
                                         do_sum=False)

        return recon_loss, prior_loss
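The energy_func here is supplied by the model and not shown. One common choice treats the prior as a standard Gaussian, whose energy equals -log p(z) up to an additive constant; a hypothetical stand-in:

    import torch

    def standard_normal_energy(z):
        # Energy of a standard Gaussian prior: 0.5 * ||z||^2 per sample,
        # i.e. -log p(z) up to an additive constant.
        return 0.5 * z.pow(2).sum(dim=1)

Under that choice, recon_loss + prior_loss is, up to a constant, the per-sample negative joint log-likelihood -log p(x|z) - log p(z).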
Example 5
    def loss(self, mu_z, logvar_z, mu_x, logvar_x, target_x, beta=1.0):
        # kld loss
        kld_loss = loss_kld_gaussian(mu_z, logvar_z, do_sum=False)

        # recon loss (neg likelihood): -log p(x|z)
        recon_loss = loss_recon_gaussian(mu_x,
                                         logvar_x,
                                         target_x.view(-1, 2),
                                         do_sum=False)

        # add loss
        loss = recon_loss + beta * kld_loss
        return loss.mean(), recon_loss.mean(), kld_loss.mean()
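Examples 2 and 5 also use loss_kld_gaussian. Assuming it is the standard closed-form KL from N(mu, exp(logvar)) to the unit Gaussian N(0, I), it might look like:

    import torch

    def loss_kld_gaussian(mu, logvar, do_sum=True):
        # Closed-form KL(N(mu, exp(logvar)) || N(0, I)),
        # summed over the latent dimensions.
        kld = -0.5 * (1.0 + logvar - mu.pow(2) - logvar.exp())
        kld = kld.view(kld.size(0), -1).sum(dim=1)  # per-sample KL
        return kld.sum() if do_sum else kld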