Example #1
 def forward_sample(self, Z_g, A, T_forward, Z_start=None):
     if Z_start is None:
         # sample the initial latent state from a prior conditioned on the global latent Z_g
         mu0 = self.pre_t_mu(Z_g)
         sig0 = torch.nn.functional.softplus(self.pre_t_sigma(Z_g))
         Z_start = torch.squeeze(
             Independent(Normal(mu0, sig0), 1).sample((1, )))
     Zlist = [Z_start]
     for t in range(1, T_forward):
         Ztm1 = Zlist[t - 1]
         inp = torch.cat([Z_g, Ztm1, A[:, t - 1, :]], -1)
         # note: the include_baseline variant (see p_Zt_Ztm1 in the later examples,
         # where B is spliced into the action sequence) could be tried here as well
         mut = self.t_mu(inp)
         sigmat = torch.nn.functional.softplus(self.t_sigma(inp))
         Zlist.append(
             torch.squeeze(
                 Independent(Normal(mut, sigmat), 1).sample((1, ))))
     Z_t = torch.cat([k[:, None, :] for k in Zlist], 1)
     # decode the sampled latent trajectory into the observation model p(X | Z)
     p_x_mu, p_x_sigma = self.p_X_Z(Z_t, Z_g)
     sample = torch.squeeze(
         Independent(Normal(p_x_mu, p_x_sigma), 1).sample((1, )))
     return sample, (Z_t, p_x_mu, p_x_sigma)
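
A minimal standalone sketch of the ancestral-sampling pattern in forward_sample, with hypothetical toy dimensions and stand-in linear layers (t_mu and t_sigma below are not the model's actual modules):

import torch
from torch.distributions import Independent, Normal

bs, Dz, Da, T = 4, 8, 2, 5                      # batch, latent dim, action dim, horizon
t_mu = torch.nn.Linear(Dz + Da, Dz)             # stand-ins for self.t_mu / self.t_sigma
t_sigma = torch.nn.Linear(Dz + Da, Dz)

Z = [torch.zeros(bs, Dz)]                       # Z_start
A = torch.randn(bs, T, Da)
for t in range(1, T):
    inp = torch.cat([Z[t - 1], A[:, t - 1, :]], -1)
    mu = t_mu(inp)
    sig = torch.nn.functional.softplus(t_sigma(inp))    # keep scales positive
    Z.append(Independent(Normal(mu, sig), 1).sample())  # one draw per batch element
Z_t = torch.stack(Z, 1)                         # (bs, T, Dz), analogous to Z_t above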
Example #2
 def reparam_dist(self, mu, sigma):
     if self.post_approx == 'diag':
         dist = Independent(Normal(mu, sigma), 1)
     elif self.post_approx == 'low_rank':
         # sigma packs the diagonal term D in its first dim_stochastic entries
         # and the (flattened) low-rank factor W in the remainder
         if sigma.dim() == 2:
             W = sigma[..., self.dim_stochastic:].view(
                 sigma.shape[0], self.dim_stochastic, self.rank)
         elif sigma.dim() == 3:
             W = sigma[..., self.dim_stochastic:].view(
                 sigma.shape[0], sigma.shape[1], self.dim_stochastic, self.rank)
         else:
             raise NotImplementedError()
         D = sigma[..., :self.dim_stochastic]
         # covariance is W @ W.T + diag(D)
         dist = LowRankMultivariateNormal(mu, W, D)
     else:
         raise ValueError(f'unknown post_approx: {self.post_approx}')
     sample = torch.squeeze(dist.rsample((1, )))
     if len(sample.shape) == 1:
         sample = sample[None, ...]
     return sample, dist
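
A standalone sketch of the two posterior parameterizations handled above, with hypothetical shapes: a diagonal Gaussian versus a low-rank-plus-diagonal Gaussian where sigma packs the diagonal D and a flattened factor W:

import torch
from torch.distributions import Independent, Normal, LowRankMultivariateNormal

bs, Dz, rank = 4, 8, 2
mu = torch.randn(bs, Dz)

# 'diag': per-dimension scales
sigma = torch.nn.functional.softplus(torch.randn(bs, Dz))
diag_dist = Independent(Normal(mu, sigma), 1)

# 'low_rank': first Dz entries are the diagonal D, the rest reshape to W with
# shape (bs, Dz, rank); the covariance is W W^T + diag(D)
raw = torch.randn(bs, Dz + Dz * rank)
D = torch.nn.functional.softplus(raw[..., :Dz])
W = raw[..., Dz:].view(bs, Dz, rank)
low_rank_dist = LowRankMultivariateNormal(mu, W, D)

print(diag_dist.rsample().shape, low_rank_dist.rsample().shape)  # both (bs, Dz)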
Example #3
 def forward_sample(self,
                    A,
                    T_forward,
                    Z_start=None,
                    B=None,
                    X0=None,
                    A0=None,
                    eps=0.):
     if Z_start is None:
         # initial latent state drawn from a prior conditioned on baseline
         # covariates B, the first observation X0, and the first action A0
         inp_cat = torch.cat([B, X0, A0], -1)
         mu1 = self.prior_W(inp_cat)
         sig1 = torch.nn.functional.softplus(self.prior_sigma(inp_cat))
         Z_start = torch.squeeze(
             Independent(Normal(mu1, sig1), 1).sample((1, )))
     Zlist = [Z_start]
     for t in range(1, T_forward):
         Ztm1 = Zlist[t - 1]
         if self.hparams.include_baseline:
             Aval = A[:, t - 1, :]
             # splice the static baseline covariates B between the first
             # treatment channel and the remaining ones
             Acat = torch.cat([Aval[..., [0]], B, Aval[..., 1:]], -1)
             mut, sigmat = self.transition_fxn(Ztm1, Acat, eps=eps)
         else:
             mut, sigmat = self.transition_fxn(Ztm1,
                                               A[:, t - 1, :],
                                               eps=eps)
         sample = torch.squeeze(
             Independent(Normal(mut, sigmat), 1).sample((1, )))
         if len(sample.shape) == 1:
             sample = sample[None, ...]
         Zlist.append(sample)
     Z_t = torch.cat([k[:, None, :] for k in Zlist], 1)
     p_x_mu, p_x_sigma = self.p_X_Z(Z_t, A[:, :Z_t.shape[1], [0]])
     sample = torch.squeeze(
         Independent(Normal(p_x_mu, p_x_sigma), 1).sample((1, )))
     return sample, (Z_t, p_x_mu, p_x_sigma)
Example #4
    def p_Zt_Ztm1(self, Zg, Zt_1T, A, B, Xt):
        mu0 = self.pre_t_mu(Zg)[:, None, :]
        sig0 = torch.nn.functional.softplus(self.pre_t_sigma(Zg))[:, None, :]
        Tmax = Zt_1T.shape[1]
        Z_rep = Zg[:, None, :].repeat(1, Tmax - 1, 1)
        if self.augmented:
            Zinp = torch.cat([Zt_1T, Xt], -1)
        else:
            Zinp = Zt_1T
        inp = torch.cat([Zinp[:, :-1, :], A[:, 1:Tmax, :], Z_rep], -1)

        if self.include_baseline:
            Aval = A[:, 1:Tmax, :]
            # include baseline in both control and input signals
            Acat = torch.cat([
                Aval[..., [0]], B[:, None, :].repeat(1, Aval.shape[1], 1),
                Aval[..., 1:]
            ], -1)
            inp = torch.cat([B[:, None, :].repeat(1, Aval.shape[1], 1), inp],
                            -1)
            mu1T, sig1T = self.transition_fxn(inp, Acat)
        else:
            mu1T, sig1T = self.transition_fxn(inp, A[:, 1:Tmax, :])

        # prepend the t=0 prior parameters to the transition parameters for t=1..Tmax-1
        mu, sig = torch.cat([mu0, mu1T], 1), torch.cat([sig0, sig1T], 1)
        return Independent(Normal(mu, sig), 1)
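
A standalone sketch of the include_baseline concatenation that recurs above, with hypothetical shapes: the static covariates B are tiled over time and spliced in between the first treatment channel and the remaining ones.

import torch

bs, T, Da, Db = 4, 6, 3, 5
Aval = torch.randn(bs, T, Da)                       # time-varying treatments
B = torch.randn(bs, Db)                             # static baseline covariates
B_rep = B[:, None, :].repeat(1, Aval.shape[1], 1)   # (bs, T, Db)
Acat = torch.cat([Aval[..., [0]], B_rep, Aval[..., 1:]], -1)
print(Acat.shape)                                   # (bs, T, 1 + Db + (Da - 1)) = (4, 6, 8)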
Example #5
    def p_Zt_Ztm1(self, Zt, A, B, X, A0, Am, eps=0.):
        X0 = X[:, 0, :]
        Xt = X[:, 1:, :]
        inp_cat = torch.cat([B, X0, A0], -1)
        mu1 = self.prior_W(inp_cat)[:, None, :]
        sig1 = torch.nn.functional.softplus(self.prior_sigma(inp_cat))[:,
                                                                       None, :]

        Tmax = Zt.shape[1]
        if self.hparams['augmented']:
            Zinp = torch.cat([Zt[:, :-1, :], Xt[:, :-1, :]], -1)
        else:
            Zinp = Zt[:, :-1, :]
        Aval = A[:, 1:Tmax, :]
        # boolean causal mask: True on and below the diagonal, so each step can
        # only attend to latent states at earlier (or the same) time steps
        sub_mask = np.triu(
            np.ones((Aval.shape[0], Aval.shape[1], Aval.shape[1])),
            k=1).astype('uint8')
        Zm = (torch.from_numpy(sub_mask) == 0).to(Am.device)
        res = self.attn(self.attn_lin(torch.cat([Xt[:, :-1, :], Aval], -1)),
                        Zinp,
                        Zinp,
                        mask=Zm,
                        use_matmul=True)
        if self.hparams['include_baseline']:
            Acat = torch.cat([
                Aval[..., [0]], B[:, None, :].repeat(1, Aval.shape[1], 1),
                Aval[..., 1:]
            ], -1)
            mu2T, sig2T = self.transition_fxn(res, Acat, eps=eps)
        else:
            mu2T, sig2T = self.transition_fxn(res, A[:, 1:Tmax, :], eps=eps)
        mu, sig = torch.cat([mu1, mu2T], 1), torch.cat([sig1, sig2T], 1)
        return Independent(Normal(mu, sig), 1)
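
A standalone sketch of the attention mask constructed above: np.triu with k=1 marks entries strictly above the diagonal, and inverting it yields a boolean mask that is True on and below the diagonal, i.e. a causal mask over past steps. Shapes are hypothetical.

import numpy as np
import torch

bs, T = 2, 4
sub_mask = np.triu(np.ones((bs, T, T)), k=1).astype('uint8')  # 1 strictly above the diagonal
Zm = (torch.from_numpy(sub_mask) == 0)                        # True on and below the diagonal
print(Zm[0].int())
# Zm[0] is lower-triangular (including the diagonal):
# [[1, 0, 0, 0],
#  [1, 1, 0, 0],
#  [1, 1, 1, 0],
#  [1, 1, 1, 1]]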
Example #6
    def p_Zt_Ztm1(self, Zt, A, B, X, A0, Am, eps=0.):
        X0 = X[:, 0, :]
        Xt = X[:, 1:, :]
        inp_cat = torch.cat([B, X0, A0], -1)
        mu1 = self.prior_W(inp_cat)[:, None, :]
        sig1 = torch.nn.functional.softplus(self.prior_sigma(inp_cat))[:,
                                                                       None, :]

        Tmax = Zt.shape[1]
        if self.hparams['augmented']:
            Zinp = torch.cat([Zt[:, :-1, :], Xt[:, :-1, :]], -1)
        else:
            Zinp = Zt[:, :-1, :]
        Aval = A[:, 1:Tmax, :]
        Am_res = Am[:, 1:Tmax, 1:Tmax]
        if self.hparams['include_baseline']:
            Acat = torch.cat([
                Aval[..., [0]], B[:, None, :].repeat(1, Aval.shape[1], 1),
                Aval[..., 1:]
            ], -1)
            # attend from the latent states over the baseline-augmented treatments
            res = self.attn(self.attn_lin(Zinp),
                            Acat,
                            Acat,
                            mask=Am_res,
                            use_matmul=True)
        else:
            res = self.attn(self.attn_lin(Zinp),
                            Aval,
                            Aval,
                            mask=Am_res,
                            use_matmul=True)
        mu2T, sig2T = self.transition_fxn(Zinp, res, eps=eps)
        mu, sig = torch.cat([mu1, mu2T], 1), torch.cat([sig1, sig2T], 1)
        return Independent(Normal(mu, sig), 1)
Example #7
 def q_Z_XA(self, X, A, B, M):
     if self.hparams.inftype == 'rnn' or self.hparams.inftype == 'birnn':
         # 1. for time steps with more than one observed feature
         rnn_mask = (M[:, 1:].sum(-1) > 1) * 1.
         inp = torch.cat([
             X[:, 1:, :], rnn_mask[..., None], A[:, 1:, :],
             B[:, None, :].repeat(1, A.shape[1] - 1, 1)
         ], -1)
         m_t, _, lens = get_masks(M[:, 1:, :])
         pdseq = torch.nn.utils.rnn.pack_padded_sequence(
             inp, lens, batch_first=True, enforce_sorted=False)
         out_pd, _ = self.inf_network(pdseq)
         out, _ = torch.nn.utils.rnn.pad_packed_sequence(out_pd,
                                                         batch_first=True)
         mu = self.post_W(out[:, -1, :])
     elif self.hparams.inftype == 'ave_diff':
         # summarize X via average per-step differences before vs. after
         # treatment, using A[..., 1] as the treatment indicator
         sc, sh, _ = self.treatment_effects(A)
         A_inv = 1 - A  # for M_post
         A_ = A[..., 1, None]
         A_inv_ = A_inv[..., 1, None]
         pre = A_inv_ * X
         post = A_ * X
         diffs_pre = pre[:, 1:, :] - pre[:, :-1, :]
         diffs_post = post[:, 1:, :] - post[:, :-1, :]
         M_pre = diffs_pre.sum(1) / A_inv_.sum(1)
         diffs_post = ((diffs_post - sh) / sc) * A_[:, :-1, :]
         M_post = diffs_post.sum(1) / (A_.sum(1) - 1)
         M_ = torch.cat((M_pre, M_post), 1)
         mu = self.post_W(M_)
     else:
         raise ValueError(f'unknown inftype: {self.hparams.inftype}')
     sigma = mu * 0. + torch.nn.functional.softplus(self.post_sigma)
     q_dist = Independent(Normal(mu, sigma), 1)
     return q_dist
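
A standalone sketch of the packed-sequence RNN pattern used in the 'rnn'/'birnn' branch, with a hypothetical GRU and per-example lengths:

import torch

bs, T, Dx, Dh = 3, 5, 4, 8
inp = torch.randn(bs, T, Dx)
lens = torch.tensor([5, 3, 2])                 # valid time steps per example
rnn = torch.nn.GRU(Dx, Dh, batch_first=True)

pdseq = torch.nn.utils.rnn.pack_padded_sequence(
    inp, lens, batch_first=True, enforce_sorted=False)
out_pd, _ = rnn(pdseq)
out, out_lens = torch.nn.utils.rnn.pad_packed_sequence(out_pd, batch_first=True)
print(out.shape, out_lens)                     # (bs, max(lens), Dh); padded steps are zeros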
Example #8
 def p_Y_Z(self, Z, C):
     if 'ord' not in self.hparams.loss_type:
         mu = self.pred_W2(torch.sigmoid(self.pred_W1(Z)))
     else:
         mu = self.m_pred(Z)
     # output scale is a single learned parameter (not a function of Z), broadcast to mu's shape
     sigma = mu * 0. + torch.nn.functional.softplus(self.pred_sigma)
     p_y_z = Independent(Normal(mu, sigma), 1)
     return p_y_z
Example #9
    def p_Zt_Ztm1(self, Zt, A, B, X, A0, eps=0.):
        X0 = X[:, 0, :]
        Xt = X[:, 1:, :]
        inp_cat = torch.cat([B, X0, A0], -1)
        mu1 = self.prior_W(inp_cat)[:, None, :]
        sig1 = torch.nn.functional.softplus(self.prior_sigma(inp_cat))[:, None, :]
        # alternative considered: a zero-mean prior, mu1 = torch.zeros_like(sig1)

        Tmax = Zt.shape[1]
        if self.hparams['augmented']:
            Zinp = torch.cat([Zt[:, :-1, :], Xt[:, :-1, :]], -1)
        else:
            Zinp = Zt[:, :-1, :]
        if self.hparams['include_baseline'] != 'none':
            Aval = A[:, 1:Tmax, :]
            Acat = torch.cat([
                Aval[..., [0]], B[:, None, :].repeat(1, Aval.shape[1], 1),
                Aval[..., 1:]
            ], -1)
            mu2T, sig2T = self.transition_fxn(Zinp, Acat, eps=eps)
        else:
            mu2T, sig2T = self.transition_fxn(Zinp, A[:, 1:Tmax, :], eps=eps)
        mu, sig = torch.cat([mu1, mu2T], 1), torch.cat([sig1, sig2T], 1)
        return Independent(Normal(mu, sig), 1)
Example #10
 def p_Z1(self, B, X0, A0):
     inp_cat = torch.cat([B, X0, A0], -1)
     mu      = self.prior_W(inp_cat)
     sigma   = torch.nn.functional.softplus(self.prior_sigma(inp_cat))
     p_z_bxa = Independent(Normal(mu, sigma), 1)
     return p_z_bxa
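
A standalone sketch of the prior-head pattern in p_Z1: two linear maps produce the mean and a softplus-positive scale, and the resulting diagonal Gaussian can be sampled or used to score latents with log_prob (shapes and layers are hypothetical):

import torch
from torch.distributions import Independent, Normal

bs, Din, Dz = 4, 10, 8
prior_W = torch.nn.Linear(Din, Dz)
prior_sigma = torch.nn.Linear(Din, Dz)

inp_cat = torch.randn(bs, Din)                 # stand-in for torch.cat([B, X0, A0], -1)
mu = prior_W(inp_cat)
sigma = torch.nn.functional.softplus(prior_sigma(inp_cat))
p_z_bxa = Independent(Normal(mu, sigma), 1)

z = p_z_bxa.rsample()                          # (bs, Dz)
print(p_z_bxa.log_prob(z).shape)               # (bs,): one log-density per example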
Example #11
    def forward(self, x, a, m, b):
        rnn_mask = (m[:, 1:].sum(-1) > 1) * 1.
        inp = torch.cat([x[:, 1:, :], a[:, :-1, :]], -1)
        m_t, _, lens = get_masks(m[:, 1:, :])
        pdseq = torch.nn.utils.rnn.pack_padded_sequence(inp,
                                                        lens,
                                                        batch_first=True,
                                                        enforce_sorted=False)
        out_pd, _ = self.inf_rnn(pdseq)
        out, _ = torch.nn.utils.rnn.pad_packed_sequence(out_pd,
                                                        batch_first=True)

        # Infer global latent variable
        hid_zg = torch.tanh(
            self.hid_zg(out).sum(1) / lens[..., None] +
            self.hid_zg_b(torch.cat([b, x[:, 0, :], a[:, 0, :]], -1)))
        zg_mu = self.mu_zg(hid_zg)
        zg_sigma = torch.nn.functional.softplus(self.sigma_zg(hid_zg))
        q_zg = Independent(Normal(zg_mu, zg_sigma), 1)
        Z_g = torch.squeeze(q_zg.rsample((1, )))

        # Infer per-time-step variables in the DMM
        hid_zg_zt = self.hid_zg_zt(Z_g)
        hid_rnn_zt = self.hid_rnn_zt(out)
        hid_base = self.base_h1(torch.cat([x[:, 0, :], b, a[:, 0, :]],
                                          -1))  ## test this out
        if self.combiner_type == 'standard' or self.combiner_type == 'masked':
            mu, sigma = self.combiner_fxn(
                hid_base,
                hid_rnn_zt[:, 0, :],
                rnn_mask[:, [0]],
                self.mu_z1,
                self.sigma_z1,
                global_hid=hid_zg_zt
            )  # change to self.mu_zt, self.sigma_zt if necessary
        else:
            mu, sigma = self.combiner_fxn(
                hid_base, hid_rnn_zt[:, 0, :], rnn_mask[:, [0]],
                self.mu_zt, self.sigma_zt, self.mu_zt2, self.sigma_zt2,
                self.mu_zt3, self.sigma_zt3, global_hid=hid_zg_zt)
        z, _ = self.reparam_dist(mu, sigma)

        meanlist = [mu[:, None, :]]
        sigmalist = [sigma[:, None, :]]
        zlist = [z[:, None, :]]
        for t in range(1, out.shape[1]):
            ztm1 = torch.squeeze(zlist[t - 1])
            hid_ztm1_zt = self.hid_ztm1_zt(ztm1)

            if self.combiner_type == 'standard' or self.combiner_type == 'masked':
                mu, sigma = self.combiner_fxn(hid_ztm1_zt,
                                              hid_rnn_zt[:, t, :],
                                              rnn_mask[:, [t]],
                                              self.mu_zt,
                                              self.sigma_zt,
                                              global_hid=hid_zg_zt)
            else:
                mu, sigma = self.combiner_fxn(
                    hid_ztm1_zt, hid_rnn_zt[:, t, :], rnn_mask[:, [t]],
                    self.mu_zt, self.sigma_zt, self.mu_zt2, self.sigma_zt2,
                    self.mu_zt3, self.sigma_zt3, global_hid=hid_zg_zt)
            z, _ = self.reparam_dist(mu, sigma)

            meanlist += [mu[:, None, :]]
            sigmalist += [sigma[:, None, :]]
            zlist += [z[:, None, :]]
        # build the full posterior over the latent trajectory from the stacked parameters
        _, q_zt = self.reparam_dist(torch.cat(meanlist, 1),
                                    torch.cat(sigmalist, 1))
        Z_t = torch.cat(zlist, 1)
        return Z_g, q_zg, Z_t, q_zt
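
A standalone sketch (not the model's actual training loop) of how the Independent Normal distributions returned above are typically compared, e.g. a per-time-step KL between an approximate posterior and a prior:

import torch
from torch.distributions import Independent, Normal, kl_divergence

bs, T, Dz = 4, 6, 8
q_zt = Independent(Normal(torch.randn(bs, T, Dz), torch.rand(bs, T, Dz) + 0.1), 1)
p_zt = Independent(Normal(torch.zeros(bs, T, Dz), torch.ones(bs, T, Dz)), 1)
kl = kl_divergence(q_zt, p_zt)                 # (bs, T): one KL term per time step
print(kl.shape)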