Пример #1
0
    def forward_w_pert_identity(self, x1, x2, s=[]):
        """
        Inference pass that treats the perturbation as the identity map.

        The module is switched to eval mode first. Working with the first
        parameter returned by each distribution module (presumably its mean):
        - embeds x1 with q(z1|x1,s),
        - predicts x2 by decoding z1 directly (p(z2|z1) replaced by z2 = z1),
        - embeds x2 with the same encoder, q(z2|x2,s),
        - reconstructs x2 by decoding z2.

        Returns a dict with keys 'z1', 'qz1', 'px2', 'x2_pert', 'z2', 'qz2',
        'px2_rec', 'x2_rec'.
        """
        self.eval()

        # optional one-hot nuisance covariate, appended to every encoder/decoder input
        cond = [blk.one_hot(s, self.dim_s)] if self.use_s else []

        # q(z1|x1,s)
        qz1 = self.encoder_z1([x1] + cond)
        z1_mu = qz1[0]

        # identity perturbation: the x2 prediction decodes z1 itself
        px2 = self.decoder_x([z1_mu] + cond)

        # q(z2|x2,s) with the shared encoder, then decode it back
        qz2 = self.encoder_z1([x2] + cond)
        z2_mu = qz2[0]
        px2_rec = self.decoder_x([z2_mu] + cond)

        return {
            'z1': z1_mu,
            'qz1': qz1,
            'px2': px2,
            'x2_pert': px2[0],        # p(x2|z2,s) where z2 ~ q(z1|x1) (perturbation is identity)
            'z2': z2_mu,
            'qz2': qz2,               # q(z2|x2,s)
            'px2_rec': px2_rec,
            'x2_rec': px2_rec[0],     # p(x2|z2,s) where z2 ~ q(z2|x2)
        }
Пример #2
0
    def forward(self, x1, s=[]):
        """
        Inference pass (the module is switched to eval mode first).

        Using the first parameter of q(z1|x1,s) (presumably its mean) as the
        embedding, this
        - classifies it with q(y|z1), and
        - reconstructs x1 from it with p(x1|z1,s).

        Returns a dict with:
            'pred'   - most probable class for a discrete y, otherwise the
                       first output of encoder_y
            'proba'  - first output of encoder_y, or its second output for a
                       continuous y whose head returns more than one value
            'z1', 'qz1'      - embedding and the full q(z1|x1,s) output
            'px1', 'x1_rec'  - full p(x1|z1,s) output and its first element
        """
        self.eval()

        # optional one-hot nuisance covariate, appended to encoder/decoder inputs
        cond = [blk.one_hot(s, self.dim_s)] if self.use_s else []

        # q(z1|x1,s)
        qz1 = self.encoder_z1([x1] + cond)
        z1_mu = qz1[0]

        # q(y|z1)
        qypred = self.encoder_y([z1_mu])
        if self.type_y == 'discrete':
            pred = self.encoder_y.most_probable(*qypred)
            proba = qypred[0]
        else:
            # a single-output head (e.g. Bernoulli) reports the same value twice
            pred = qypred[0]
            proba = qypred[1] if len(qypred) > 1 else qypred[0]

        # p(x1|z1,s)
        px1 = self.decoder_x([z1_mu] + cond)

        return {
            'pred': pred,
            'proba': proba,
            'z1': z1_mu,
            'qz1': qz1,
            'px1': px1,
            'x1_rec': px1[0],
        }
Пример #3
0
    def forward(self, x1, s=[]):
        """
        Inference pass (the module is switched to eval mode first).

        Working with the first parameter of each distribution module
        (presumably its mean), this produces
        - the embedding q(z1|x1,s),
        - the predicted perturbed embedding p(z2|z1),
        - the reconstruction p(x1|z1,s),
        - the perturbation prediction p(x2|z2,s) decoded from z2.

        Returns a dict with keys 'z1', 'qz1', 'px1', 'x1_rec', 'z2', 'pz2',
        'px2', 'x2_pert'.
        """
        self.eval()

        # optional one-hot nuisance covariate, appended to encoder/decoder inputs
        cond = [blk.one_hot(s, self.dim_s)] if self.use_s else []

        # q(z1|x1,s)
        qz1 = self.encoder_z1([x1] + cond)
        z1_mu = qz1[0]

        # p(z2|z1): predicted post-perturbation embedding
        pz2Fz1 = self.decoder_z2Fz1([z1_mu])
        z2_mu = pz2Fz1[0]

        # decode both embeddings with the shared decoder
        px1 = self.decoder_x([z1_mu] + cond)
        px2 = self.decoder_x([z2_mu] + cond)

        return {'z1': z1_mu, 'qz1': qz1, 'px1': px1, 'x1_rec': px1[0],
                'z2': z2_mu, 'pz2': pz2Fz1, 'px2': px2, 'x2_pert': px2[0]}
Пример #4
0
    def _compute_losses(self, x1, y, s, L):
        """
        Compute all losses of the model. For unlabeled data marginalize y.
        RECL - reconstruction loss E_{q(z1|x1)}[ p(x1|z1) ]
        KLD  - kl-divergences of all the other matching q and p distributions
        YL   - prediction loss on y (for labeled data)
        MMD  - maximum mean discrepancy of z1 embedding w.r.t. grouping s
        """
        RECL, KLD, YL, MMD = 0., 0., 0., 0.
        N = x1.size(0)

        if self.type_y == 'discrete':
            y1inK = blk.one_hot(y, self.dim_y)
        else:
            if len(y) > 0:
                y = y.float()
        # instantiate prior of y for discrete y
        if isinstance(self.prior_y, str) and self.prior_y == 'uniform':
            pr_y = Variable(
                torch.from_numpy(np.ones(
                    (N, self.dim_y)) / (1. * self.dim_y))).float()
        else:
            pr_y = Variable(
                torch.from_numpy(np.ones(
                    (N, self.dim_y)) * self.prior_y)).float()

        s1inK = None
        if self.use_s:
            s1inK = blk.one_hot(s, self.dim_s)
        if self.use_s and self.use_MMD:
            sind = []  # get the indices for the nuisance variable groups
            for si in range(self.dim_s):
                sind.append(torch.eq(s, si).squeeze())

        # get q(z1|x1, s)
        if self.use_s:
            in_qz1 = [x1, s1inK]
        else:
            in_qz1 = [x1]
        if self.training and self.add_noise:
            eps = x1.data.new(x1.size()).normal_()
            eps = Variable(eps.mul_(self.add_noise_var))
            in_qz1[0] += eps
        qz1 = self.encoder_z1(in_qz1)

        Lf = 1. * L
        for _ in range(L):
            # sample from q(z1|x1, s)
            z1 = self.encoder_z1.sample(*qz1)
            z1_sample = z1[
                0]  # for compatibility with IAF encoders return a tuple

            ## get the reconstruction loss
            # p(x1|z1,s) where z1 ~ q(z1|x1,s)
            if self.use_s:
                in_px1 = [z1_sample, s]
            else:
                in_px1 = [z1_sample]
            px1 = self.decoder_x(in_px1)
            RECL += self.decoder_x.logp(x1, *px1) / Lf

            ## prediction: q(y|z1)
            qy = self.encoder_y([z1_sample])

            _KLD = 0.
            if len(y) > 0:
                ## if y is given then
                # (i) compute prediction loss
                if self.type_y == 'discrete':
                    YL += self.encoder_y.logp(y, *qy) / Lf
                    # ## Margin Ranking Loss
                    # MRL = 0.
                    # for i in range(1, y.size(0)-1):
                    #     a = qy[0][:-i,1]
                    #     b = qy[0][i:,1]
                    #     ya = y.view(-1)[:-i]
                    #     yb = y.view(-1)[i:]
                    #     t = ((ya > yb).float() * 2 ) - 1
                    #     idx = torch.nonzero(ya.data != yb.data).view(-1)
                    #     if len(idx) == 0: continue
                    #     # print(a[:20], b[:20], t[:20])
                    #     MRL += -F.margin_ranking_loss(a[idx], b[idx], t[idx])
                    # YL += MRL / y.size(0) / Lf / 100.
                    # # print(MRL)
                else:
                    # print(y.size(), qy[0].size())
                    # print(y.view(-1) - qy[0].view(-1))
                    ## Squared Error Loss
                    YL += -((y.view(-1) - qy[0].view(-1))**2).sum() / Lf
                    # ## Margin Ranking Loss
                    # MRL = 0.
                    # for i in range(1, y.size(0)-1):
                    #     a = qy[0].view(-1,1)[:-i]
                    #     b = qy[0].view(-1,1)[i:]
                    #     t = ((y.view(-1)[:-i] > y.view(-1)[i:]).float() * 2 ) - 1
                    #     # print(a[:20], b[:20], t[:20])
                    #     MRL += -F.margin_ranking_loss(a, b, t)
                    # YL += MRL / y.size(0) / Lf
                    # # print(MRL)
                    ## Log Likelihood
                    # YL += self.encoder_y.logp(y, *qy) / Lf
                # (ii) condition on true y
                if self.type_y == 'discrete':
                    _y = y1inK
                else:
                    _y = y
                    if _y.ndimension() == 1:
                        _y.data.unsqueeze_(1)
                # the logprior of y is ommited as it is constant wrt the optimization
                _KLD = self._fprop(z1, qz1, _y)
            else:
                ## otherwise use predicted qy to marginalize out y
                if self.type_y == 'discrete':
                    # sum out y
                    for _j in range(self.dim_y):
                        _y_j = blk.one_hot(
                            Variable(x1.data.new(N).float().fill_(_j)),
                            self.dim_y)
                        _KLD_j = self._fprop(z1, qz1, _y_j)
                        assert qy[0][:, _j].size() == _KLD_j.size()
                        _KLD += qy[0][:, _j] * _KLD_j
                    # add logprior of y
                    _KLD += self.encoder_y.kldivergence_perx(qy[0],
                                                             pr_y).sum(1)
                else:
                    # if continous then just use SGVB and sample y
                    _y = self.encoder_y.sample(*qy)
                    _KLD = self._fprop(z1, qz1, _y)
                    # add logprior of y
                    if self.prior_y != 'uniform':
                        _KLD += self.encoder_y.kldivergence_perx(
                            *(qy + self.prior_y)).sum(1)
            KLD += torch.sum(_KLD) / Lf

            # maximum mean discrepancy regularization
            if self.use_s and self.use_MMD:
                MMD += self._get_mmd_criterion(z1_sample, sind) / Lf

        # yhat = self.encoder_y.most_probable(*qy)
        # print(yhat.eq(y).data.numpy().mean())

        ## loss per batch
        return OrderedDict([('RECL', RECL), ('KLD', KLD), ('YL', YL),
                            ('MMD', MMD)])
Пример #5
0
    def _compute_losses(self, x1, x2, s, L):
        """
        Compute all losses of the model. For unlabeled data marginalize y.
        RECL - reconstruction loss E_{q(z1|x1)}[ p(x1|z1) ]
        KLD  - kl-divergences of all the other matching q and p distributions
        PERT - perturbation prediction loss E_{p(z2|z1)q(z1|x1)}[ p(x2|z2) ]
        MMD  - maximum mean discrepancy of z1 embedding w.r.t. grouping s
        """
        RECL, KLD, PERT, MMD = 0., 0., 0., 0.
        N = x1.size(0)
        isPertPair = len(x2) > 0

        s1inK = None
        if self.use_s:
            s1inK = blk.one_hot(s, self.dim_s)
        if self.use_s and self.use_MMD:
            sind = [] # get the indices for the nuisance variable groups
            for si in range(self.dim_s):
                sind.append(torch.eq(s, si).squeeze())

        # get q(z1|x1, s) and q(z2|x2, s)        
        if self.use_s:
            in_qz1 = [x1, s1inK]
        else:
            in_qz1 = [x1]
        if self.training and self.add_noise:
            eps = x1.data.new(x1.size()).normal_()
            eps = Variable(eps.mul_(self.add_noise_var))
            in_qz1[0] += eps
        qz1 = self.encoder_z1(in_qz1)
        if isPertPair:
            if self.use_s:
                in_qz2 = [x2, s1inK]
            else:
                in_qz2 = [x2]
            if self.training and self.add_noise:
                eps = x2.data.new(x2.size()).normal_()
                eps = Variable(eps.mul_(self.add_noise_var))
                in_qz2[0] += eps
            qz2 = self.encoder_z1(in_qz2) # !! use the same encoder as z1 !!

        Lf = 1. * L
        for _ in range(L):
            # sample from q(z1|x1, s)
            z1 = self.encoder_z1.sample(*qz1) 
            z1_sample = z1[0] # for compatibility with IAF encoders return a tuple
            # sample from q(z2|x2, s)
            if isPertPair:
                z2 = self.encoder_z1.sample(*qz1) # !! use the same encoder as z1 !!
                z2_sample = z2[0] # for compatibility with IAF encoders return a tuple

            # encode and sample from p(z2|z1) ## TODO: extend to p(z2|z1,c,m)
            pz2Fz1 = self.decoder_z2Fz1([z1_sample])
            z2Fz1 = self.decoder_z2Fz1.sample(*pz2Fz1)
            z2Fz1_sample = z2Fz1[0]

            ## get the reconstruction loss
            # p(x1|z1,s) where z1 ~ q(z1|x1,s)
            if self.use_s:
                in_px1 = [z1_sample, s]
            else:
                in_px1 = [z1_sample]
            px1 = self.decoder_x(in_px1)
            RECL += self.decoder_x.logp(x1, *px1) / Lf

            ## KL-divergence q(z1|x1) || p(z1); p(z1)=N(0,I)
            KLD_perx = 0.
            try:
                KLD_perx += self._use_free_bits(self.encoder_z1.kldivergence_from_prior_perx(*qz1)) # add KL from prior
            except:
                # no KL-divergence, use logq(z) - logp(z) Monte Carlo estimation
                logq_perx = self.encoder_z1.logp_perx(*(z1 + qz1))
                logp_perx = self.encoder_z1.logp_prior_perx(z1_sample)
                KLD_perx += self._use_free_bits(logq_perx - logp_perx)
            KLD += torch.sum(KLD_perx) / Lf

            ## get reconstruction loss & perturbation prediction loss for x2
            if isPertPair:
                # p(x2|z2,s) where z2 ~ q(z2|x2,s)
                if self.use_s:
                    in_px2 = [z2_sample, s]
                else:
                    in_px2 = [z2_sample]
                px2 = self.decoder_x(in_px2)
                RECL += self.decoder_x.logp(x2, *px2) / Lf

                # p(x2|z2,s) where z2 ~ p(z2|z1)
                # PERT = Variable(torch.FloatTensor(1).zero_())
                try:
                    if self.use_s:
                        in_px2pert = [z2Fz1_sample, s]
                    else:
                        in_px2pert = [z2Fz1_sample]
                    px2pert = self.decoder_x(in_px2pert)
                    PERT += self.decoder_x.logp(x2, *px2pert) / Lf
                except Exception as e:
                    print(e)

                ## KL-divergence q(z2|x2) || p(z2); p(z2)=N(0,I)
                KLD_perx = 0.
                try: # add KL from prior
                    KLD_perx += self._use_free_bits(self.encoder_z1.kldivergence_from_prior_perx(*qz2)) # !! use the same encoder as z1 !!
                except:
                    # no KL-divergence, use logq(z) - logp(z) Monte Carlo estimation
                    logq_perx = self.encoder_z1.logp_perx(*(z2 + qz2))
                    logp_perx = self.encoder_z1.logp_prior_perx(z2_sample)
                    KLD_perx += self._use_free_bits(logq_perx - logp_perx)
                KLD += torch.sum(KLD_perx) / Lf

            ## match distributions over z2: KL( q(z2|x2,s) || p(z2|z1) )
            if isPertPair:
                try:
                    #### try analytic KL-divergence
                    KLz2Fz1_perx = self._use_free_bits(self.encoder_z1.kldivergence_perx(*(qz2 + pz2Fz1)) )  # !! use the same encoder as z1 !!
                    # print('{}\t{:.4f}'.format(self.finished_training_iters, KLz2Fz1_perx.sum().data.numpy()[0]))
                except:
                    #### no KL-divergence, use Monte Carlo estimation: ( logp(z2|z1) - logq(z2|x2,s) )
                    KLz2Fz1_perx = 0.
                    ## for MC use a sample from q(z2|x2,s) !!! works only if p(z2|z1) is not IAF !!!
                    try:
                        # no KL-divergence, use logq(z) - logp(z) Monte Carlo estimation
                        logq_perx = self.encoder_z1.logp_perx(*(z2 + qz2))  # !! use the same encoder as z1 !!
                        logp_perx = self.decoder_z2Fz1.logp_perx(*(z2[:1] + pz2Fz1)).clamp(min=-10e10) # log probability of a sample from q(z2|x2,s) in p(z2|z1)
                        KLz2Fz1_perx += self._use_free_bits(logq_perx - logp_perx)
                    except Exception as e:
                        print(e)
                ## apply free bits/nats
                # KLz2Fz1_perx = self._use_free_bits(KLz2Fz1_perx)
                ## sum KL
                beta_pert = 1.
                if self.anneal_perturb_rate_itermax > 0:
                    beta_pert = self._compute_anneal_coef(self.finished_training_iters,
                                                iter_max = self.anneal_perturb_rate_itermax,
                                                iter_offset = self.anneal_perturb_rate_offset)
                KLD += beta_pert * ( self.kl_qz2pz2_rate * torch.sum(KLz2Fz1_perx) / Lf )
                # KLD += torch.norm(self.decoder_z2Fz1.W_mu) / Lf

            ## maximum mean discrepancy regularization
            if self.use_s and self.use_MMD:
                MMD += self._get_mmd_criterion(z1_sample, sind) / Lf
                if isPertPair:
                    MMD += self._get_mmd_criterion(z2_sample, sind) / Lf
        
        ## loss per batch
        return OrderedDict([('RECL',RECL), ('KLD',KLD), ('PERT',PERT), ('MMD',MMD)])
Пример #6
0
    def _compute_losses(self, x1, x2, s, y, L):
        """
        Compute all losses of the model. For unlabeled data marginalize y.
        RECL - reconstruction loss E_{q(z1|x1)}[ p(x1|z1) ]
        KLD  - kl-divergences of all the other matching q and p distributions
        PERT - perturbation prediction loss E_{p(z2|z1)q(z1|x1)}[ p(x2|z2) ]
        YL   - prediction loss on y (for labeled data)
        MMD  - maximum mean discrepancy of z1 embedding w.r.t. grouping s
        """
        RECL, KLD, PERT, YL, MMD = 0., 0., 0., 0., 0.
        N = x1.size(0)
        isPertPair = len(x2) > 0

        if self.type_y == 'discrete':
            y1inK = blk.one_hot(y, self.dim_y)
        else:
            if len(y) > 0:
                y = y.float()
        # instantiate prior of y for discrete y
        if isinstance(self.prior_y, str) and self.prior_y == 'uniform':
            pr_y = Variable(
                torch.from_numpy(np.ones(
                    (N, self.dim_y)) / (1. * self.dim_y))).float()
        else:
            pr_y = Variable(
                torch.from_numpy(np.ones(
                    (N, self.dim_y)) * self.prior_y)).float()

        s1inK = None
        if self.use_s:
            s1inK = blk.one_hot(s, self.dim_s)
        if self.use_s and self.use_MMD:
            sind = []  # get the indices for the nuisance variable groups
            for si in range(self.dim_s):
                sind.append(torch.eq(s, si).squeeze())

        # get q(z1|x1, s) and q(z2|x2, s)
        if self.use_s:
            in_qz1 = [x1, s1inK]
        else:
            in_qz1 = [x1]
        if self.training and self.add_noise:
            eps = x1.data.new(x1.size()).normal_()
            eps = Variable(eps.mul_(self.add_noise_var))
            in_qz1[0] += eps
        qz1 = self.encoder_z1(in_qz1)
        if isPertPair:
            if self.use_s:
                in_qz2 = [x2, s1inK]
            else:
                in_qz2 = [x2]
            if self.training and self.add_noise:
                eps = x2.data.new(x2.size()).normal_()
                eps = Variable(eps.mul_(self.add_noise_var))
                in_qz2[0] += eps
            qz2 = self.encoder_z1(in_qz2)  # !! use the same encoder as z1 !!

        Lf = 1. * L
        for _ in range(L):
            # sample from q(z1|x1, s)
            z1 = self.encoder_z1.sample(*qz1)
            z1_sample = z1[
                0]  # for compatibility with IAF encoders return a tuple
            # sample from q(z2|x2, s)
            if isPertPair:
                z2 = self.encoder_z1.sample(
                    *qz1)  # !! use the same encoder as z1 !!
                z2_sample = z2[
                    0]  # for compatibility with IAF encoders return a tuple

            # encode and sample from p(z2|z1) ## TODO: extend to p(z2|z1,c,m)
            pz2Fz1 = self.decoder_z2Fz1([z1_sample])
            z2Fz1 = self.decoder_z2Fz1.sample(*pz2Fz1)
            z2Fz1_sample = z2Fz1[0]

            ## get the reconstruction loss
            # p(x1|z1,s) where z1 ~ q(z1|x1,s)
            if self.use_s:
                in_px1 = [z1_sample, s]
            else:
                in_px1 = [z1_sample]
            px1 = self.decoder_x(in_px1)
            RECL += self.decoder_x.logp(x1, *px1) / Lf

            ## get reconstruction loss & perturbation prediction loss for x2
            if isPertPair:
                # p(x2|z2,s) where z2 ~ q(z2|x2,s)
                if self.use_s:
                    in_px2 = [z2_sample, s]
                else:
                    in_px2 = [z2_sample]
                px2 = self.decoder_x(in_px2)
                RECL += self.decoder_x.logp(x2, *px2) / Lf

                # p(x2|z2,s) where z2 ~ p(z2|z1)
                if self.use_s:
                    in_px2pert = [z2Fz1_sample, s]
                else:
                    in_px2pert = [z2Fz1_sample]
                px2pert = self.decoder_x(in_px2pert)
                PERT += self.decoder_x.logp(x2, *px2pert) / Lf

            ## match distributions over z2: KL( p(z2|z1) || q(z2|x2,s) )
            if isPertPair:
                try:
                    #### try analytic KL-divergence
                    KLz2Fz1_perx = self._use_free_bits(
                        self.encoder_z1.kldivergence_perx(
                            *(qz2 +
                              pz2Fz1)))  # !! use the same encoder as z1 !!
                    # print('{}\t{:.4f}'.format(self.finished_training_iters, KLz2Fz1_perx.sum().data.numpy()[0]))
                except:
                    #### no KL-divergence, use Monte Carlo estimation: ( logp(z2|z1) - logq(z2|x2,s) )
                    KLz2Fz1_perx = 0.
                    ## for MC use a sample from q(z2|x2,s) !!! works only if p(z2|z1) is not IAF !!!
                    try:
                        # no KL-divergence, use logq(z) - logp(z) Monte Carlo estimation
                        logq_perx = self.encoder_z1.logp_perx(
                            *(z2 + qz2))  # !! use the same encoder as z1 !!
                        logp_perx = self.decoder_z2Fz1.logp_perx(*(
                            z2[:1] + pz2Fz1
                        )).clamp(
                            min=-1e10
                        )  # log probability of a sample from q(z2|x2,s) in p(z2|z1)
                        KLz2Fz1_perx += self._use_free_bits(logq_perx -
                                                            logp_perx)
                    except Exception as e:
                        print(e)
                ## apply free bits/nats
                # KLz2Fz1_perx = self._use_free_bits(KLz2Fz1_perx)
                ## sum KL
                beta_pert = 1.
                if self.anneal_perturb_rate_itermax > 0:
                    beta_pert = self._compute_anneal_coef(
                        self.finished_training_iters,
                        iter_max=self.anneal_perturb_rate_itermax,
                        iter_offset=self.anneal_perturb_rate_offset)
                KLD += beta_pert * (self.kl_qz2pz2_rate *
                                    torch.sum(KLz2Fz1_perx) / Lf)
                ## prevent qz1 and qz2 from collapsing : -( KL( q(z1|x1,s) || q(z2|x2,s) ) + KL( q(z2|x2,s) || q(z1|x1,s) ) )/2.
                # KLqz1qz2 = -(self.encoder_z1.kldivergence_perx(*(qz1 + qz2)) + self.encoder_z1.kldivergence_perx(*(qz2 + qz1))) / 2.
                # KLD += beta_pert * ( self.kl_qz2pz2_rate * 0.05 * torch.sum(KLz2Fz1_perx.clamp(max=500)) / Lf )

            ## prediction: q(y|z1,z2) [or q(y|z2)]
            if self.clf_z1z2:
                # in_clf = [z1_sample, z2Fz1_sample]
                in_clf = [z1_sample, z2Fz1_sample - z1_sample]
            else:
                in_clf = [z2Fz1_sample]
                # in_clf = [z2Fz1_sample - z1_sample]
            qy = self.encoder_y(in_clf)

            _KLD = 0.
            if len(y) > 0:
                ## if y is given then
                # (i) compute prediction loss
                YL += self.encoder_y.logp(y, *qy) / Lf
                # (ii) condition on true y
                if self.type_y == 'discrete':
                    _y = y1inK
                else:
                    _y = y
                    if _y.ndimension() == 1:
                        _y.data.unsqueeze_(1)
                # the logprior of y is omitted as it is constant wrt the optimization
                _KLD = self._fprop(z1, qz1, _y)
            else:
                ## otherwise use predicted qy to marginalize out y
                if self.type_y == 'discrete':
                    # sum out y
                    for _j in range(self.dim_y):
                        _y_j = blk.one_hot(
                            Variable(x1.data.new(N).float().fill_(_j)),
                            self.dim_y)
                        _KLD_j = self._fprop(z1, qz1, _y_j)
                        assert qy[0][:, _j].size() == _KLD_j.size()
                        _KLD += qy[0][:, _j] * _KLD_j
                    # add logprior of y
                    _KLD += self.encoder_y.kldivergence_perx(qy[0],
                                                             pr_y).sum(1)
                else:
                    # if continous then just use SGVB and sample y
                    _y = self.encoder_y.sample(*qy)[0]
                    _KLD = self._fprop(z1, qz1, _y)
                    # add logprior of y
                    if self.prior_y != 'uniform':
                        _KLD += self.encoder_y.kldivergence_perx(
                            *(qy + self.prior_y)).sum(1)
            KLD += torch.sum(_KLD) / Lf

            ## maximum mean discrepancy regularization
            if self.use_s and self.use_MMD:
                MMD += self._get_mmd_criterion(z1_sample, sind) / Lf
                if isPertPair:
                    MMD += self._get_mmd_criterion(z2_sample, sind) / Lf

        ## loss per batch
        return OrderedDict([('RECL', RECL), ('KLD', KLD), ('PERT', PERT),
                            ('YL', YL), ('MMD', MMD)])
Пример #7
0
    def forward(self, x1, s=[]):
        """
        Inference pass (the module is switched to eval mode first).

        Working with the first parameter of each distribution module
        (presumably its mean), this
        - embeds x1 with q(z1|x1,s),
        - predicts the perturbed embedding z2 with p(z2|z1),
        - classifies with q(y|z1, z2 - z1) when ``self.clf_z1z2`` is set,
          otherwise with q(y|z2),
        - reconstructs x1 from z1 and predicts x2 from z2 with p(x|z,s).

        Returns a dict with keys 'pred', 'proba', 'z1', 'qz1', 'px1',
        'x1_rec', 'z2', 'pz2', 'px2', 'x2_pert'.
        """
        self.eval()

        # optional one-hot nuisance covariate, appended to encoder/decoder inputs
        cond = [blk.one_hot(s, self.dim_s)] if self.use_s else []

        # q(z1|x1,s)
        qz1 = self.encoder_z1([x1] + cond)
        z1_mu = qz1[0]

        # p(z2|z1)
        pz2Fz1 = self.decoder_z2Fz1([z1_mu])
        z2_mu = pz2Fz1[0]

        # classifier input: (z1, shift z2 - z1) or just z2
        in_clf = [z1_mu, z2_mu - z1_mu] if self.clf_z1z2 else [z2_mu]
        qypred = self.encoder_y(in_clf)
        if self.type_y == 'discrete':
            pred = self.encoder_y.most_probable(*qypred)
            proba = qypred[0]
        else:
            # a single-output head (e.g. Bernoulli) reports the same value twice
            pred = qypred[0]
            proba = qypred[1] if len(qypred) > 1 else qypred[0]

        # decode both embeddings with the shared decoder
        px1 = self.decoder_x([z1_mu] + cond)
        px2 = self.decoder_x([z2_mu] + cond)

        return {
            'pred': pred,
            'proba': proba,
            'z1': z1_mu,
            'qz1': qz1,
            'px1': px1,
            'x1_rec': px1[0],
            'z2': z2_mu,
            'pz2': pz2Fz1,
            'px2': px2,
            'x2_pert': px2[0],
        }
Пример #8
0
    def forward_w_pert_identity(self, x1, x2, s=[]):
        """
        Inference pass that treats the perturbation as the identity map,
        i.e. p(z2|z1) is replaced by z2 = z1 (module set to eval mode first).

        Working with the first parameter of each distribution module
        (presumably its mean), this
        - embeds x1 with q(z1|x1,s),
        - classifies with q(y|z1, 0) when ``self.clf_z1z2`` is set (the shift
          z2 - z1 is zero under the identity), otherwise with q(y|z1),
        - predicts x2 by decoding z1,
        - embeds x2 with the same encoder, q(z2|x2,s),
        - reconstructs x2 by decoding z2.

        Returns a dict with keys 'pred', 'proba', 'z1', 'qz1', 'px2',
        'x2_pert', 'z2', 'qz2', 'px2_rec', 'x2_rec'.
        """
        self.eval()

        # optional one-hot nuisance covariate, appended to encoder/decoder inputs
        cond = [blk.one_hot(s, self.dim_s)] if self.use_s else []

        # q(z1|x1,s)
        qz1 = self.encoder_z1([x1] + cond)
        z1_mu = qz1[0]

        # identity perturbation: z2 == z1, so the shift fed to the classifier is zero
        in_clf = [z1_mu, z1_mu - z1_mu] if self.clf_z1z2 else [z1_mu]
        qypred = self.encoder_y(in_clf)
        if self.type_y == 'discrete':
            pred = self.encoder_y.most_probable(*qypred)
            proba = qypred[0]
        else:
            # a single-output head (e.g. Bernoulli) reports the same value twice
            pred = qypred[0]
            proba = qypred[1] if len(qypred) > 1 else qypred[0]

        # x2 prediction decodes z1 directly
        px2 = self.decoder_x([z1_mu] + cond)

        # q(z2|x2,s) with the shared encoder, then decode it back
        qz2 = self.encoder_z1([x2] + cond)
        z2_mu = qz2[0]
        px2_rec = self.decoder_x([z2_mu] + cond)

        return {
            'pred': pred,
            'proba': proba,
            'z1': z1_mu,
            'qz1': qz1,
            'px2': px2,
            'x2_pert': px2[0],       # p(x2|z2,s) where z2 ~ q(z1|x1) (perturbation is identity)
            'z2': z2_mu,
            'qz2': qz2,              # q(z2|x2,s)
            'px2_rec': px2_rec,
            'x2_rec': px2_rec[0],    # p(x2|z2,s) where z2 ~ q(z2|x2)
        }