示例#1
0
 def reparam_dist(self, mu, sigma):
     """Build the approximate posterior and draw one reparameterized sample.

     The distribution family is selected by ``self.post_approx``:

     * ``'diag'``: ``Independent(Normal(mu, sigma), 1)``.
     * ``'low_rank'``: ``LowRankMultivariateNormal``. The last axis of
       ``sigma`` is split into the first ``self.dim_stochastic`` entries
       (covariance diagonal) and the remaining entries, which are reshaped
       to ``(..., self.dim_stochastic, self.rank)`` (covariance factor).

     Args:
         mu: location tensor of the distribution.
         sigma: scale parameters; for ``'low_rank'`` it must be 2-D or 3-D.

     Returns:
         ``(sample, dist)`` where ``sample`` is a single ``rsample`` draw
         with the same leading shape as ``mu`` (a 1-D draw gets a leading
         axis of size 1 prepended) and ``dist`` is the distribution object.

     Raises:
         NotImplementedError: ``'low_rank'`` with ``sigma`` not 2-D or 3-D.
         ValueError: ``self.post_approx`` is neither supported value.
     """
     if self.post_approx == 'diag':
         dist = Independent(Normal(mu, sigma), 1)
     elif self.post_approx == 'low_rank':
         if sigma.dim() not in (2, 3):
             # NotImplemented is a comparison sentinel, not an exception;
             # calling it would raise an unrelated TypeError.
             raise NotImplementedError(
                 'low_rank posterior supports 2-D or 3-D sigma, got '
                 '%d-D' % sigma.dim())
         # Keep every leading axis and split the trailing factor entries
         # into a (dim_stochastic, rank) matrix.
         W = sigma[..., self.dim_stochastic:].view(*sigma.shape[:-1],
                                                   self.dim_stochastic,
                                                   self.rank)
         D = sigma[..., :self.dim_stochastic]
         dist = LowRankMultivariateNormal(mu, W, D)
     else:
         raise ValueError('should not be here')
     # Drop only the leading sample axis. A bare torch.squeeze would also
     # remove a batch/time/latent axis that happens to have size 1.
     sample = dist.rsample((1, )).squeeze(0)
     if sample.dim() == 1:
         sample = sample[None, ...]
     return sample, dist
示例#2
0
文件: sdmm.py 项目: clinicalml/ief
    def forward(self, x, a, m, b):
        """Run the inference network and sample global and per-step latents.

        ``x``, ``a`` and ``m`` are indexed as 3-D tensors
        ``[batch, time, feature]``; ``b`` is concatenated with ``x[:, 0, :]``
        along the last axis, so it is 2-D. Presumably these are
        observations, treatments, an observation mask and baseline
        covariates respectively -- confirm against the caller.

        Steps:
            1. Encode ``[x[:, 1:], a[:, :-1]]`` with ``self.inf_rnn`` as a
               packed sequence.
            2. Infer the global latent ``Z_g`` from the length-normalized
               sum of the RNN features plus the features of ``b`` and the
               first time step.
            3. Infer the per-step latents sequentially: each step combines
               the previous latent (or the base features at step 0), the
               RNN output at that step and the global latent through
               ``self.combiner_fxn``.

        Returns:
            ``(Z_g, q_zg, Z_t, q_zt)``: the global sample and its
            distribution, the per-step samples concatenated along the time
            axis, and the distribution built from the concatenated
            per-step parameters.
        """
        # 1.0 where the mask summed over features (steps 1..) exceeds 1.
        rnn_mask = (m[:, 1:].sum(-1) > 1) * 1.
        inp = torch.cat([x[:, 1:, :], a[:, :-1, :]], -1)
        # lens is used below as the per-sequence length for packing and
        # for normalizing the summed RNN features.
        m_t, _, lens = get_masks(m[:, 1:, :])
        pdseq = torch.nn.utils.rnn.pack_padded_sequence(inp,
                                                        lens,
                                                        batch_first=True,
                                                        enforce_sorted=False)
        out_pd, _ = self.inf_rnn(pdseq)
        out, _ = torch.nn.utils.rnn.pad_packed_sequence(out_pd,
                                                        batch_first=True)

        # Infer global latent variable
        hid_zg = torch.tanh(
            self.hid_zg(out).sum(1) / lens[..., None] +
            self.hid_zg_b(torch.cat([b, x[:, 0, :], a[:, 0, :]], -1)))
        zg_mu = self.mu_zg(hid_zg)
        zg_sigma = torch.nn.functional.softplus(self.sigma_zg(hid_zg))
        q_zg = Independent(Normal(zg_mu, zg_sigma), 1)
        # Drop only the sample axis; a bare torch.squeeze would also remove
        # the batch axis when the batch size is 1.
        Z_g = q_zg.rsample((1, )).squeeze(0)

        # Infer per-time-step variables in the DMM
        hid_zg_zt = self.hid_zg_zt(Z_g)
        hid_rnn_zt = self.hid_rnn_zt(out)
        hid_base = self.base_h1(torch.cat([x[:, 0, :], b, a[:, 0, :]], -1))
        simple_combiner = self.combiner_type in ('standard', 'masked')

        if simple_combiner:
            # The first step uses its own output heads
            # (change to self.mu_zt, self.sigma_zt if necessary).
            mu, sigma = self.combiner_fxn(hid_base,
                                          hid_rnn_zt[:, 0, :],
                                          rnn_mask[:, [0]],
                                          self.mu_z1,
                                          self.sigma_z1,
                                          global_hid=hid_zg_zt)
        else:
            mu, sigma = self.combiner_fxn(hid_base,
                                          hid_rnn_zt[:, 0, :],
                                          rnn_mask[:, [0]],
                                          self.mu_zt,
                                          self.sigma_zt,
                                          self.mu_zt2,
                                          self.sigma_zt2,
                                          self.mu_zt3,
                                          self.sigma_zt3,
                                          global_hid=hid_zg_zt)
        z, _ = self.reparam_dist(mu, sigma)

        meanlist = [mu[:, None, :]]
        sigmalist = [sigma[:, None, :]]
        zlist = [z[:, None, :]]
        for t in range(1, out.shape[1]):
            # Remove only the time axis that was inserted when the sample
            # was stored, so the batch axis survives for batch size 1.
            ztm1 = zlist[t - 1].squeeze(1)
            hid_ztm1_zt = self.hid_ztm1_zt(ztm1)

            if simple_combiner:
                mu, sigma = self.combiner_fxn(hid_ztm1_zt,
                                              hid_rnn_zt[:, t, :],
                                              rnn_mask[:, [t]],
                                              self.mu_zt,
                                              self.sigma_zt,
                                              global_hid=hid_zg_zt)
            else:
                mu, sigma = self.combiner_fxn(hid_ztm1_zt,
                                              hid_rnn_zt[:, t, :],
                                              rnn_mask[:, [t]],
                                              self.mu_zt,
                                              self.sigma_zt,
                                              self.mu_zt2,
                                              self.sigma_zt2,
                                              self.mu_zt3,
                                              self.sigma_zt3,
                                              global_hid=hid_zg_zt)
            z, _ = self.reparam_dist(mu, sigma)

            meanlist.append(mu[:, None, :])
            sigmalist.append(sigma[:, None, :])
            zlist.append(z[:, None, :])

        # Rebuild one distribution over all time steps from the stacked
        # parameters; the sample drawn by this call is discarded because
        # the per-step samples in zlist are the ones returned.
        _, q_zt = self.reparam_dist(torch.cat(meanlist, 1),
                                    torch.cat(sigmalist, 1))
        Z_t = torch.cat(zlist, 1)
        return Z_g, q_zg, Z_t, q_zt