Example No. 1
    def init_semicldc_cond_gmix_transadd(self, params):
        # transform one-hot y into a dense y, then add the dense y to z1
        self.leakyrelu = nn.LeakyReLU()
        self.yohtoy_toz2 = nn.Linear(self.label_size, params.z_dim)
        self.hbn_z1y = nn.BatchNorm1d(params.z_dim)
        self.z1y_z2 = Inferer(params, in_dim=params.z_dim)

        self.y_z2 = Inferer(params, in_dim=params.cldc_label_size)
        self.z2y_z1 = Inferer(params, in_dim=params.z_dim)
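
The snippets above construct Inferer modules that are defined elsewhere in the repository. From the way they are called (they return a (mu, logvar) pair and expose a reparameterize sampler), a minimal sketch might look like the following; the single-linear-head structure and the train/eval sampling behaviour are assumptions, not the original implementation.

import torch
import torch.nn as nn

class Inferer(nn.Module):
    """Maps an input vector to (mu, logvar) of a diagonal Gaussian over z."""

    def __init__(self, params, in_dim):
        super().__init__()
        self.mu = nn.Linear(in_dim, params.z_dim)
        self.logvar = nn.Linear(in_dim, params.z_dim)

    def forward(self, x):
        return self.mu(x), self.logvar(x)

    def reparameterize(self, mu, logvar):
        # z = mu + sigma * eps with eps ~ N(0, I); sample only in training mode
        if self.training:
            return mu + torch.exp(0.5 * logvar) * torch.randn_like(logvar)
        return mu
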
Example No. 2
    def init_semicldc_cond_gmix(self, params):
        # concat
        # concat z1 and y
        self.z1y_z2 = Inferer(params, in_dim=params.z_dim + self.label_size)
        # p(z2 | y)
        self.y_z2 = Inferer(params, in_dim=self.label_size)
        # p(z1 | z2)
        self.z2y_z1 = Inferer(params, in_dim=params.z_dim)
    def init_semicldc_cond_gmix_transadd(self, params):
        self.init_semicldc_cond_transadd(params)

        # y -> hid for z2
        self.ytohid_z2 = nn.Linear(self.label_size, params.aux_hid_dim)
        # y -> z2
        self.hbn_y = nn.BatchNorm1d(params.aux_hid_dim)
        self.y_z2 = Inferer(params, in_dim=params.aux_hid_dim)
    def init_semicldc_cond_transadd(self, params):
        self.leakyrelu = nn.LeakyReLU()

        # x -> hid for y
        self.xtohid_y = nn.Linear(params.x_hid_dim, params.aux_hid_dim)
        # z1 -> hid for y
        self.z1tohid_y = nn.Linear(params.z_dim, params.aux_hid_dim)
        self.hbn_z1x = nn.BatchNorm1d(params.aux_hid_dim)

        # x -> hid for z2
        self.xtohid_z2 = nn.Linear(params.x_hid_dim, params.aux_hid_dim)
        # z1 -> hid for z2
        self.z1tohid_z2 = nn.Linear(params.z_dim, params.aux_hid_dim)
        # y -> hid for z2
        self.ytohid_z2 = nn.Linear(self.label_size, params.aux_hid_dim)
        # z1,y,x -> z2
        self.hbn_z1xy = nn.BatchNorm1d(params.aux_hid_dim)
        self.z1yx_z2 = Inferer(params, in_dim=params.aux_hid_dim)

        # y -> hid for z1
        self.ytohid_z1 = nn.Linear(self.label_size, params.aux_hid_dim)
        # z2 -> hid for z1
        self.z2tohid_z1 = nn.Linear(params.z_dim, params.aux_hid_dim)
        # y,z2 -> z1
        self.hbn_z2y = nn.BatchNorm1d(params.aux_hid_dim)
        self.z2y_z1 = Inferer(params, in_dim=params.aux_hid_dim)

        # y -> hid for x
        self.ytohid_x = nn.Linear(self.label_size, params.aux_hid_dim)
        # z1 -> hid for x
        self.z1tohid_x = nn.Linear(params.z_dim, params.aux_hid_dim)
        # z2 -> hid for x
        self.z2tohid_x = nn.Linear(params.z_dim, params.aux_hid_dim)
        # y,z2,z1 -> x
        self.hbn_z2z1y = nn.BatchNorm1d(params.aux_hid_dim)
        self.z2z1y_x = nn.Linear(params.aux_hid_dim, params.z_dim)
class AUXSEMICLDCModel(SEMICLDCModel):
    def __init__(self,
                 params,
                 data_list,
                 model_dict=None,
                 task_model_dict=None):
        super(SEMICLDCModel, self).__init__()
        # label size
        self.label_size = params.cldc_label_size
        # batch size for U / label size
        self.bs_u = params.semicldc_U_bs // self.label_size
        # alpha for cldc classifier
        self.semicldc_classifier_alpha = params.semicldc_classifier_alpha

        # get y prior
        self.yprior = self.get_yprior(params, data_list)
        # xling vae
        self.xlingva = XlingVA(params, data_list, model_dict=model_dict)

        self.init_semicldc_cond = getattr(
            self, 'init_semicldc_cond_{}'.format(params.semicldc_cond_type))
        # initialize how x & y and z & y are combined
        self.init_semicldc_cond(params)

        # cldc MLP
        self.cldc_classifier = CLDCClassifier(
            params, params.aux_cldc_classifier_config)

        # functions
        self.forward = getattr(self,
                               'forward_{}'.format(params.cldc_train_mode))

        self.get_z1 = getattr(self, 'get_z1_{}'.format(params.cldc_train_mode))

        self.get_z1x = getattr(self,
                               'get_z1x_{}'.format(params.semicldc_cond_type))
        self.get_z1yx = getattr(
            self, 'get_z1yx_{}'.format(params.semicldc_cond_type))
        self.get_z2y = getattr(self,
                               'get_z2y_{}'.format(params.semicldc_cond_type))
        self.get_z2z1y = getattr(
            self, 'get_z2z1y_{}'.format(params.semicldc_cond_type))
        # calculate kl of z2
        self.kl_z2 = getattr(self,
                             'kl_z2_{}'.format(params.semicldc_cond_type))

        self.step = 0
        self.anneal_warm_up = params.semicldc_anneal_warm_up
        self.cyclic_period = params.cyclic_period

        # warm up stage
        self.warm_up = False

        self.init_model(task_model_dict)

        self.use_cuda = params.cuda
        if self.use_cuda:
            self.cuda()

    def train_classifier(self,
                         lang,
                         batch_in,
                         batch_lens,
                         batch_lb,
                         training=True):
        # x -> hid_x -> mu1, logvar1 -> z1
        mu1, logvar1, z1, x, loss_dis, loss_enc = self.get_z1(
            lang, batch_in, batch_lens)
        z1x_y = self.get_z1x(z1, x)

        # cldc loss
        cldc_loss, pred_p, pred = self.cldc_classifier(z1x_y,
                                                       y=batch_lb,
                                                       training=training)

        if training:
            L_dict = defaultdict(float)
            if not hasattr(self, 'adv_training') or self.adv_training is False:
                cldc_loss.mean().backward()
                L_dict['L_dis_loss'] = loss_dis
                L_dict['L_enc_loss'] = loss_enc
            elif self.adv_training is True:
                (cldc_loss.mean() + loss_dis.mean() +
                 loss_enc.mean()).backward()
                L_dict['L_dis_loss'] = loss_dis.mean().item()
                L_dict['L_enc_loss'] = loss_enc.mean().item()
            L_dict['L_cldc_loss'] = cldc_loss.mean().item()
            return L_dict, pred
        else:
            return cldc_loss, pred_p, pred

    def init_semicldc_cond_transadd(self, params):
        self.leakyrelu = nn.LeakyReLU()

        # x -> hid for y
        self.xtohid_y = nn.Linear(params.x_hid_dim, params.aux_hid_dim)
        # z1 -> hid for y
        self.z1tohid_y = nn.Linear(params.z_dim, params.aux_hid_dim)
        self.hbn_z1x = nn.BatchNorm1d(params.aux_hid_dim)

        # x -> hid for z2
        self.xtohid_z2 = nn.Linear(params.x_hid_dim, params.aux_hid_dim)
        # z1 -> hid for z2
        self.z1tohid_z2 = nn.Linear(params.z_dim, params.aux_hid_dim)
        # y -> hid for z2
        self.ytohid_z2 = nn.Linear(self.label_size, params.aux_hid_dim)
        # z1,y,x -> z2
        self.hbn_z1xy = nn.BatchNorm1d(params.aux_hid_dim)
        self.z1yx_z2 = Inferer(params, in_dim=params.aux_hid_dim)

        # y -> hid for z1
        self.ytohid_z1 = nn.Linear(self.label_size, params.aux_hid_dim)
        # z2 -> hid for z1
        self.z2tohid_z1 = nn.Linear(params.z_dim, params.aux_hid_dim)
        # y,z2 -> z1
        self.hbn_z2y = nn.BatchNorm1d(params.aux_hid_dim)
        self.z2y_z1 = Inferer(params, in_dim=params.aux_hid_dim)

        # y -> hid for x
        self.ytohid_x = nn.Linear(self.label_size, params.aux_hid_dim)
        # z1 -> hid for x
        self.z1tohid_x = nn.Linear(params.z_dim, params.aux_hid_dim)
        # z2 -> hid for x
        self.z2tohid_x = nn.Linear(params.z_dim, params.aux_hid_dim)
        # y,z2,z1 -> x
        self.hbn_z2z1y = nn.BatchNorm1d(params.aux_hid_dim)
        self.z2z1y_x = nn.Linear(params.aux_hid_dim, params.z_dim)

    def init_semicldc_cond_gmix_transadd(self, params):
        self.init_semicldc_cond_transadd(params)

        # y -> hid for z2
        self.ytohid_z2 = nn.Linear(self.label_size, params.aux_hid_dim)
        # y -> z2
        self.hbn_y = nn.BatchNorm1d(params.aux_hid_dim)
        self.y_z2 = Inferer(params, in_dim=params.aux_hid_dim)

    def forward_L_trainenc_batch(self, lang, batch_in, batch_lens, batch_lb,
                                 batch_ohlb):
        L_dict, L_pred, mu1, logvar1, z1, rec_mu1, rec_logvar1, z2 = self.forward_L_fixenc_batch(
            lang, batch_in, batch_lens, batch_lb, batch_ohlb)

        z2z1y_x = self.get_z2z1y(z2, z1, batch_ohlb)

        # nll_loss
        L_dict['L_nll'] = self.xlingva.decoder(lang,
                                               z2z1y_x,
                                               batch_in,
                                               reduction=None)
        # H(q(z1|x))
        # k/2 + k/2 log(2pi) + 1/2 log(|covariance|)
        L_dict['L_Hz1'] = mu1.shape[1] / 2.0 * (
            1 + const) + 1 / 2.0 * logvar1.sum(dim=-1)
        # regroup
        L_dict['L_z1kld'] = cal_kl_gau2(mu1, logvar1, rec_mu1, rec_logvar1)
        # fb
        L_dict['L_z1kld_fb'] = cal_kl_gau2_fb(mu1, logvar1, rec_mu1,
                                              rec_logvar1, l_z1_fb)

        return L_dict, L_pred

    def forward_U_trainenc_batch(self, lang, batch_uin, batch_ulens, batch_ulb,
                                 batch_uohlb):
        U_dict, U_pred_p, mu1, logvar1, dup_z1, dup_mu1, dup_logvar1, rec_mu1, rec_logvar1, z2 = self.forward_U_fixenc_batch(
            lang, batch_uin, batch_ulens, batch_ulb, batch_uohlb)

        z2z1y_x = self.get_z2z1y(z2, dup_z1, batch_uohlb)

        # nll_loss, expectation
        dup_batch_uin = self.enumerate_label(batch_uin, batch_uohlb)
        U_dict['U_nll'] = self.xlingva.decoder(lang,
                                               z2z1y_x,
                                               dup_batch_uin,
                                               reduction=None)
        U_dict['U_nll'] = (U_dict['U_nll'] * U_pred_p).view(
            batch_uin.shape[0], -1).sum(dim=1)
        # H(q(z1|x))
        # k/2 + k/2 log(2pi) + 1/2 log(|covariance|)
        U_dict['U_Hz1'] = mu1.shape[1] / 2.0 * (
            1 + const) + 1 / 2.0 * logvar1.sum(dim=-1)

        # regroup
        U_dict['U_z1kld'] = cal_kl_gau2(dup_mu1, dup_logvar1, rec_mu1,
                                        rec_logvar1)
        # fb
        U_dict['U_z1kld_fb'] = cal_kl_gau2_fb(dup_mu1, dup_logvar1, rec_mu1,
                                              rec_logvar1, u_z1_fb)
        return U_dict, U_pred_p

    def forward_L_fixenc_batch(self, lang, batch_in, batch_lens, batch_lb,
                               batch_ohlb):
        mu1, logvar1, z1, x, loss_dis, loss_enc = self.get_z1(
            lang, batch_in, batch_lens)

        z1x_y = self.get_z1x(z1, x)

        # cldc loss for training
        cldc_loss, _, pred = self.cldc_classifier(z1x_y,
                                                  y=batch_lb,
                                                  training=True)

        # z1yh -> z2
        mu2, logvar2, z2 = self.get_z2(z1, batch_ohlb, x)

        # au
        self.L_mu1.append(mu1)
        self.L_mu2.append(mu2)

        # z1, z2 distance
        self.L_z1z2kld += 0.5 * torch.mean(
            torch.sum(logvar1 - logvar2 - 1 +
                      ((mu2 - mu1).pow(2) + logvar2.exp()) / logvar1.exp(),
                      dim=1)).item()
        self.L_l2dist += torch.mean(
            torch.sqrt(torch.sum(((z1 - z2)**2), dim=1))).item()
        self.L_cosdist += torch.mean(F.cosine_similarity(z1, z2)).item()
        ''' 
    # gen
    self.eval()
    # fix y sample z
    # 8*4
    y_oh = torch.cat([torch.eye(4), torch.eye(4)]).cuda()
    # 2*4
    mu2 = torch.zeros(2, mu2.shape[1]).cuda()
    logvar2 = torch.zeros(2, mu2.shape[1]).cuda()
    z2 = self.z1yx_z2.reparameterize(mu2, logvar2)
    # 8*4
    z2 = z2.repeat_interleave(4, dim = 0)
    z2y_z1 = self.get_z2y(z2, y_oh)
    # z2y -> z1
    rec_mu1, rec_logvar1 = self.z2y_z1(z2y_z1)
    rec_z1 = self.z2y_z1.reparameterize(rec_mu1, rec_logvar1)
    # z2z1y - > x
    z2z1y_x = self.get_z2z1y(z2, rec_z1, y_oh)

    batch_size = z2z1y_x.shape[0]
    decoded_idxs = []
    # whether each sentence has finished generation
    finish_idxs = torch.tensor([False] * batch_size).cuda()
    
    lang_idx = self.xlingva.decoder.lang_dict[lang]
    # init hid
    dec_hid = self.xlingva.decoder.z2hid[lang_idx](z2z1y_x).unsqueeze(0)
    # init x, pad
    dec_in = torch.zeros(dec_hid.shape[1], dtype = torch.long).unsqueeze(1).cuda()
    # max length
    for di in range(200):
      embedded = self.xlingva.decoder.embeddings.embeddings[lang_idx](dec_in)
      batch_size = embedded.shape[0]
      # concatenate with z
      embedded = torch.cat((embedded, z2z1y_x.unsqueeze(1)), dim = -1)
      # linear transformation
      embedded = self.xlingva.decoder.zx2decin[lang_idx](embedded)
      out, dec_hid = self.xlingva.decoder.rnns[lang_idx](embedded, dec_hid)
      scores = self.xlingva.decoder.hid2vocab[lang_idx](out).squeeze(1)
      scores = F.log_softmax(scores, dim = 1)
      log_prob, topi = scores.data.topk(1)
      decoded_idx = topi[:, 0].detach()
      decoded_idxs.append(decoded_idx.unsqueeze(1))
      finish_idxs = (finish_idxs | (decoded_idx == 0))

      if finish_idxs.all().item():
        break
      
      dec_in = topi[:, 0].detach().unsqueeze(1)  # detach from history as input 
    
    decoded_idxs = torch.stack(decoded_idxs, dim = -1).cpu().squeeze(1).numpy()
    sents = []
    for i in range(decoded_idxs.shape[0]):
      sent = []
      for j in range(decoded_idxs.shape[1]):
        sent.append(self.xlingva.embeddings.embeddings[lang_idx].vocab.idx2word[decoded_idxs[i][j]])
      sents.append(' '.join(sent))
    pdb.set_trace()
    # gen
    '''

        # reconstruct z1 from z2
        rec_loss, rec_mu1, rec_logvar1 = self.z2_rec_z1(z1, z2, batch_ohlb)
        # kl divergence of z2
        kld = self.kl_z2(mu2, logvar2, batch_ohlb)
        # get y prior
        yprior = batch_ohlb @ self.yprior

        # fb
        kld_fb = cal_kl_gau1_fb(mu2, logvar2, l_z2_fb)

        L_dict = {
            'L_rec': rec_loss,
            'L_kld': kld,
            'L_yprior': yprior,
            'L_cldc_loss': cldc_loss,
            'L_kld_fb': kld_fb,
            'L_dis_loss': loss_dis,
            'L_enc_loss': loss_enc,
        }
        return L_dict, pred, mu1, logvar1, z1, rec_mu1, rec_logvar1, z2

    def forward_U_fixenc_batch(self, lang, batch_uin, batch_ulens, batch_ulb,
                               batch_uohlb):
        mu1, logvar1, z1, x, loss_dis, loss_enc = self.get_z1(
            lang, batch_uin, batch_ulens)

        z1x_y = self.get_z1x(z1, x)

        # use classifier to infer
        _, pred_p, _ = self.cldc_classifier(z1x_y, y=None, training=True)

        # duplicate z1, enumerate y
        # bs * label_size, z_dim
        dup_z1 = self.enumerate_label(z1, batch_uohlb)
        dup_mu1 = self.enumerate_label(mu1, batch_uohlb)
        dup_logvar1 = self.enumerate_label(logvar1, batch_uohlb)
        dup_x = self.enumerate_label(x, batch_uohlb)

        mu2, logvar2, z2 = self.get_z2(dup_z1, batch_uohlb, dup_x)

        # au
        self.U_mu1.append(dup_mu1)
        self.U_mu2.append(mu2)

        # z1, z2 distance
        self.U_z1z2kld += 0.5 * torch.mean(
            torch.sum(dup_logvar1 - logvar2 - 1 + (
                (mu2 - dup_mu1).pow(2) + logvar2.exp()) / dup_logvar1.exp(),
                      dim=1)).item()
        self.U_l2dist += torch.mean(
            torch.sqrt(torch.sum(((dup_z1 - z2)**2), dim=1))).item()
        self.U_cosdist += torch.mean(F.cosine_similarity(dup_z1, z2)).item()

        # reconstruct z1 from z2
        rec_loss, rec_mu1, rec_logvar1 = self.z2_rec_z1(
            dup_z1, z2, batch_uohlb)
        # kl divergence of z2
        kld = self.kl_z2(mu2, logvar2, batch_uohlb)
        # get y prior
        yprior = batch_uohlb @ self.yprior

        # calculate H(q(y | x ))
        H = -torch.sum(torch.mul(pred_p, torch.log(pred_p + 1e-32)), dim=1)

        # bs * label_size
        pred_p = pred_p.view(-1)

        #fb
        kld_fb = cal_kl_gau1_fb(mu2, logvar2, u_z2_fb)

        #fb
        U_dict = {
            'U_rec': rec_loss,
            'U_kld': kld,
            'U_yprior': yprior,
            'H': H,
            'U_kld_fb': kld_fb,
            'U_dis_loss': loss_dis,
            'U_enc_loss': loss_enc,
        }
        return U_dict, pred_p, mu1, logvar1, dup_z1, dup_mu1, dup_logvar1, rec_mu1, rec_logvar1, z2

    def get_z1x_transadd(self, z1, x):
        z1_y = self.z1tohid_y(z1)
        x_y = self.xtohid_y(x)
        z1x_y = z1_y + x_y
        if z1x_y.shape[0] > 1:
            z1x_y = self.hbn_z1x(z1x_y)
        z1x_y = self.leakyrelu(z1x_y)
        return z1x_y

    def get_z1x_gmix_transadd(self, z1, x):
        return self.get_z1x_transadd(z1, x)

    def get_z2(self, z1, batch_ohlb, x):
        z1yx_z2 = self.get_z1yx(z1, batch_ohlb, x)
        # z1y -> z2
        mu2, logvar2 = self.z1yx_z2(z1yx_z2)
        z2 = self.z1yx_z2.reparameterize(mu2, logvar2)
        # mu debug
        # z2 = mu2
        # mu debug
        return mu2, logvar2, z2

    def get_z1yx_transadd(self, z1, batch_ohlb, x):
        # z1yx -> z2
        y_z2 = self.ytohid_z2(batch_ohlb)
        z1_z2 = self.z1tohid_z2(z1)
        x_z2 = self.xtohid_z2(x)
        z1yx_z2 = z1_z2 + y_z2 + x_z2
        if z1yx_z2.shape[0] > 1:
            z1yx_z2 = self.hbn_z1xy(z1yx_z2)
        z1yx_z2 = self.leakyrelu(z1yx_z2)
        return z1yx_z2

    def get_z1yx_gmix_transadd(self, z1, batch_ohlb, x):
        return self.get_z1yx_transadd(z1, batch_ohlb, x)

    def z2_rec_z1(self, z1, z2, batch_ohlb):
        z2y_z1 = self.get_z2y(z2, batch_ohlb)
        # z2y -> z1
        rec_mu1, rec_logvar1 = self.z2y_z1(z2y_z1)
        rec_z1 = self.z2y_z1.reparameterize(rec_mu1, rec_logvar1)

        logpz1 = multi_diag_normal(z1, rec_mu1, rec_logvar1)
        return -logpz1, rec_mu1, rec_logvar1

    def get_z2y_transadd(self, z2, batch_ohlb):
        y_z1 = self.ytohid_z1(batch_ohlb)
        z2_z1 = self.z2tohid_z1(z2)
        # p(z1|z2, y)
        z2y_z1 = z2_z1 + y_z1
        # p(z1|z2)
        #z2y_z1 = z2_z1
        if z2y_z1.shape[0] > 1:
            z2y_z1 = self.hbn_z2y(z2y_z1)
        z2y_z1 = self.leakyrelu(z2y_z1)
        return z2y_z1

    def get_z2y_gmix_transadd(self, z2, batch_ohlb):
        return self.get_z2y_transadd(z2, batch_ohlb)

    def get_z2z1y_transadd(self, z2, z1, batch_ohlb):
        z2_x = self.z2tohid_x(z2)
        z1_x = self.z1tohid_x(z1)
        y_x = self.ytohid_x(batch_ohlb)
        z2z1y_x = z2_x + z1_x + y_x
        if z2z1y_x.shape[0] > 1:
            z2z1y_x = self.hbn_z2z1y(z2z1y_x)
        z2z1y_x = self.leakyrelu(z2z1y_x)
        z2z1y_x = self.z2z1y_x(z2z1y_x)
        return z2z1y_x

    def get_z2z1y_gmix_transadd(self, z2, z1, batch_ohlb):
        return self.get_z2z1y_transadd(z2, z1, batch_ohlb)

    def kl_z2_gmix_transadd(self, mu_post, logvar_post, batch_ohlb):
        y_z2 = self.ytohid_z2(batch_ohlb)
        if y_z2.shape[0] > 1:
            y_z2 = self.hbn_y(y_z2)
        mu_prior, logvar_prior = self.y_z2(y_z2)
        kld = cal_kl_gau2(mu_post, logvar_post, mu_prior, logvar_prior)
        return kld
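
The KL helpers cal_kl_gau2 and cal_kl_gau2_fb used above come from elsewhere in the repository. Judging from the inline z1/z2 distance computation and from kl_z2_gmix_transadd, cal_kl_gau2 is the closed-form KL divergence between two diagonal Gaussians, and the _fb variant presumably applies a per-dimension free-bits floor; a hedged sketch under those assumptions:

import torch

def cal_kl_gau2(mu_q, logvar_q, mu_p, logvar_p):
    # KL( N(mu_q, diag(exp(logvar_q))) || N(mu_p, diag(exp(logvar_p))) ), summed over dims
    return 0.5 * torch.sum(
        logvar_p - logvar_q - 1 +
        ((mu_q - mu_p).pow(2) + logvar_q.exp()) / logvar_p.exp(),
        dim=1)

def cal_kl_gau2_fb(mu_q, logvar_q, mu_p, logvar_p, fb):
    # same KL, but each dimension is clamped at the free-bits floor fb before summing
    kl = 0.5 * (logvar_p - logvar_q - 1 +
                ((mu_q - mu_p).pow(2) + logvar_q.exp()) / logvar_p.exp())
    return torch.clamp(kl, min=fb).sum(dim=1)
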
Example No. 6
class SEMICLDCModel(nn.Module):
    def __init__(self, params, data_list, model_dict=None):
        super(SEMICLDCModel, self).__init__()
        # label size
        self.label_size = params.cldc_label_size
        # batch size for U / label size
        self.bs_u = params.semicldc_U_bs // self.label_size
        # alpha for cldc classifier
        self.semicldc_classifier_alpha = params.semicldc_classifier_alpha

        # get y prior
        self.yprior = self.get_yprior(params, data_list)
        # xling vae
        self.xlingva = XlingVA(params, data_list, model_dict=model_dict)

        self.init_semicldc_cond = getattr(
            self, 'init_semicldc_cond_{}'.format(params.semicldc_cond_type))
        # initialize how x & y and z & y are combined
        self.init_semicldc_cond(params)

        # cldc MLP
        self.cldc_classifier = CLDCClassifier(params,
                                              params.cldc_classifier_config)

        # functions
        self.forward = getattr(self,
                               'forward_{}'.format(params.cldc_train_mode))
        self.get_z1 = getattr(self, 'get_z1_{}'.format(params.cldc_train_mode))
        # x & y
        self.get_z1y = getattr(self,
                               'get_z1y_{}'.format(params.semicldc_cond_type))
        # z & y
        self.get_z2y = getattr(self,
                               'get_z2y_{}'.format(params.semicldc_cond_type))
        # calculate kl of z2
        self.kl_z2 = getattr(self,
                             'kl_z2_{}'.format(params.semicldc_cond_type))

        self.step = 0
        self.anneal_warm_up = params.semicldc_anneal_warm_up

        # warm up stage
        self.warm_up = False

        self.use_cuda = params.cuda
        if self.use_cuda:
            self.cuda()

    def init_model(self, model_dict):
        if model_dict is None:
            return
        else:
            # load the new state dict; parameter names need to match exactly
            self.load_state_dict(model_dict)

    def train_classifier(self,
                         lang,
                         batch_in,
                         batch_lens,
                         batch_lb,
                         training=True):
        # x -> hid_x -> mu1, logvar1 -> z1
        mu1, logvar1, z1, x, loss_dis, loss_enc = self.get_z1(
            lang, batch_in, batch_lens)

        # cldc loss for training
        cldc_loss, pred_p, pred = self.cldc_classifier(z1,
                                                       y=batch_lb,
                                                       training=training)

        if training:
            L_dict = defaultdict(float)
            if not hasattr(self, 'adv_training') or self.adv_training is False:
                cldc_loss.mean().backward()
                L_dict['L_dis_loss'] = loss_dis
                L_dict['L_enc_loss'] = loss_enc
            elif self.adv_training is True:
                (cldc_loss.mean() + loss_dis.mean() +
                 loss_enc.mean()).backward()
                L_dict['L_dis_loss'] = loss_dis.mean().item()
                L_dict['L_enc_loss'] = loss_enc.mean().item()
            L_dict['L_cldc_loss'] = cldc_loss.mean().item()
            return L_dict, pred
        else:
            return cldc_loss, pred_p, pred

    def get_yprior(self, params, data_list):
        if params.semicldc_yprior_type == 'uniform':
            # prior scores of y
            yprior_score = torch.ones(self.label_size)
            # uniform distribution
            m = nn.LogSoftmax(dim=-1)
            yprior = m(yprior_score)
        elif params.semicldc_yprior_type == 'train_prop':
            # same distribution as the labelled training data
            train_prop = data_list[params.lang_dict[
                params.cldc_langs[0]]].train_prop
            yprior = torch.log(
                torch.tensor(train_prop + 1e-32,
                             dtype=torch.float,
                             requires_grad=False))

        if params.cuda:
            yprior = yprior.cuda()

        return yprior

    def init_semicldc_cond_concat(self, params):
        # directly concatenate z1 and one-hot y
        self.z1y_z2 = Inferer(params, in_dim=params.z_dim + self.label_size)
        self.z2y_z1 = Inferer(params, in_dim=params.z_dim + self.label_size)

    def init_semicldc_cond_transconcat(self, params):
        # transform one-hot y into a dense y, then concatenate z1 and y
        self.yohtoy = nn.Linear(self.label_size, params.z_dim)
        self.z1y_z2 = Inferer(params, in_dim=params.z_dim + params.z_dim)
        self.z2y_z1 = Inferer(params, in_dim=params.z_dim + params.z_dim)

    def init_semicldc_cond_transadd(self, params):
        # transform one-hot y into a dense y, then add the dense y to z1
        self.leakyrelu = nn.LeakyReLU()
        self.yohtoy_toz2 = nn.Linear(self.label_size, params.z_dim)
        self.hbn_z1y = nn.BatchNorm1d(params.z_dim)
        self.yohtoy_toz1 = nn.Linear(self.label_size, params.z_dim)
        self.hbn_z2y = nn.BatchNorm1d(params.z_dim)
        self.z1y_z2 = Inferer(params, in_dim=params.z_dim)
        self.z2y_z1 = Inferer(params, in_dim=params.z_dim)

    def init_semicldc_cond_gmix(self, params):
        # concat
        # concat z1 and y
        self.z1y_z2 = Inferer(params, in_dim=params.z_dim + self.label_size)
        # p(z2 | y)
        self.y_z2 = Inferer(params, in_dim=self.label_size)
        # p(z1 | z2)
        self.z2y_z1 = Inferer(params, in_dim=params.z_dim)
        '''
    # transadd
    # trans one_hot_y to dense_y, then add dense_y to z1
    self.yohtoy = nn.Linear(self.label_size, params.z_dim)
    self.z1y_z2 = Inferer(params, in_dim = params.z_dim)
    self.y2z2 = Inferer(params, in_dim = params.z_dim)
    self.z2y_z1 = Inferer(params, in_dim = params.z_dim)
    '''

    def init_semicldc_cond_gmix_transadd(self, params):
        # transform one-hot y into a dense y, then add the dense y to z1
        self.leakyrelu = nn.LeakyReLU()
        self.yohtoy_toz2 = nn.Linear(self.label_size, params.z_dim)
        self.hbn_z1y = nn.BatchNorm1d(params.z_dim)
        self.z1y_z2 = Inferer(params, in_dim=params.z_dim)

        self.y_z2 = Inferer(params, in_dim=params.cldc_label_size)
        self.z2y_z1 = Inferer(params, in_dim=params.z_dim)

    def forward_trainenc(self, lang, batch_in, batch_lens, batch_lb,
                         batch_ohlb, batch_uin, batch_ulens, batch_ulb,
                         batch_uohlb):
        # warm up
        if self.warm_up:
            return self.train_classifier(lang, batch_in, batch_lens, batch_lb)

        # au
        self.L_mu1, self.U_mu1 = [], []
        self.L_mu2, self.U_mu2 = [], []
        # z1, z2 distance
        self.L_l2dist, self.L_cosdist, self.L_z1z2kld = .0, .0, .0
        self.U_l2dist, self.U_cosdist, self.U_z1z2kld = .0, .0, .0

        # calculate L loss and classification loss
        L_dict, L_pred = self.forward_L_trainenc(
            lang,
            batch_in,
            batch_lens,
            batch_lb,
            batch_ohlb,
            cls_alpha=self.semicldc_classifier_alpha)

        # calculate U loss
        U_dict = self.forward_U_trainenc(lang, batch_uin, batch_ulens,
                                         batch_ulb, batch_uohlb)

        # merge two dicts
        loss_dict = {**L_dict, **U_dict}
        self.step += 1

        # z1, z2 distance
        loss_dict['L_l2_dist'], loss_dict['L_cosdist'], loss_dict[
            'L_z1z2kld'] = self.L_l2dist, self.L_cosdist, self.L_z1z2kld
        loss_dict['U_l2_dist'], loss_dict['U_cosdist'], loss_dict[
            'U_z1z2kld'] = self.U_l2dist, self.U_cosdist, self.U_z1z2kld

        # total MEAN loss
        loss_dict['total_loss'] = loss_dict['L_loss_trainenc'] + loss_dict[
            'U_loss_trainenc'] + self.semicldc_classifier_alpha * loss_dict[
                'L_cldc_loss']

        # au
        loss_dict['L_mu1'] = calc_au(self.L_mu1)[0]
        loss_dict['L_mu2'] = calc_au(self.L_mu2)[0]
        loss_dict['U_mu1'] = calc_au(self.U_mu1)[0]
        loss_dict['U_mu2'] = calc_au(self.U_mu2)[0]
        # au
        print()
        print('L_mu1: {}'.format(loss_dict['L_mu1']))
        print('L_mu2: {}'.format(loss_dict['L_mu2']))
        print('U_mu1: {}'.format(loss_dict['U_mu1']))
        print('U_mu2: {}'.format(loss_dict['U_mu2']))

        return loss_dict, L_pred

    def forward_L_trainenc(self,
                           lang,
                           batch_in,
                           batch_lens,
                           batch_lb,
                           batch_ohlb,
                           le=1.0,
                           cls_alpha=0.1):
        L_dict, L_pred = self.forward_L_trainenc_batch(lang, batch_in,
                                                       batch_lens, batch_lb,
                                                       batch_ohlb)

        # calculate all necessary losses
        L_dict = self.cal_L_trainenc(L_dict)

        # backward
        L_dict = self.backward_L_trainenc(L_dict, le, cls_alpha)

        return L_dict, L_pred

    def forward_L_trainenc_batch(self, lang, batch_in, batch_lens, batch_lb,
                                 batch_ohlb):
        L_dict, L_pred, mu1, logvar1, z1, rec_mu1, rec_logvar1 = self.forward_L_fixenc_batch(
            lang, batch_in, batch_lens, batch_lb, batch_ohlb)

        # nll_loss
        L_dict['L_nll'] = self.xlingva.decoder(lang,
                                               z1,
                                               batch_in,
                                               reduction=None)
        # H(q(z1|x))
        # k/2 + k/2 log(2pi) + 1/2 log(|covariance|)
        # L_dict['L_Hz1'] = -multi_diag_normal(z1, mu1, logvar1)
        L_dict['L_Hz1'] = mu1.shape[1] / 2.0 * (
            1 + const) + 1 / 2.0 * logvar1.sum(dim=-1)
        # regroup
        L_dict['L_z1kld'] = cal_kl_gau2(mu1, logvar1, rec_mu1, rec_logvar1)
        # fb
        L_dict['L_z1kld_fb'] = cal_kl_gau2_fb(mu1, logvar1, rec_mu1,
                                              rec_logvar1, l_z1_fb)

        return L_dict, L_pred

    def cal_L_trainenc(self, L_dict):
        lkld_fix = 5.0
        lz1kld_fix = 5.0

        L_dict = self.cal_L_fixenc(L_dict)
        # L_dict['L_loss_trainenc'] = L_dict['L_loss'] + lnll * L_dict['L_nll'] - lz1kl * L_dict['L_Hz1']
        # regroup
        '''
    kl_weight_z1 = get_cyclic_weight(self.step, self.cyclic_period)
    print()
    print(kl_weight_z1)
    kl_weight_z2 = get_cyclic_weight(self.step, self.cyclic_period)
    '''
        '''
    kl_weight_z1 = min(1.0, self.step / self.anneal_warm_up)
    kl_weight_z2 = min(1.0, self.step / self.anneal_warm_up)
    '''
        kl_weight_z1 = 1.0
        kl_weight_z2 = 1.0
        L_dict['L_loss_trainenc'] = L_dict['L_nll'] + kl_weight_z2 * L_dict['L_kld'] + kl_weight_z1 * \
                                    L_dict['L_z1kld'] - L_dict['L_yprior']
        # L_dict['L_loss_trainenc'] = L_dict['L_nll'] + torch.abs(lkld_fix - L_dict['L_kld']) + torch.abs(lz1kld_fix - L_dict['L_z1kld']) - L_dict['L_yprior']
        L_dict[
            'L_TKL'] = L_dict['L_kld'] + L_dict['L_z1kld'] - L_dict['L_yprior']

        # fb
        '''
    kl_weight_nll = min(1.0, self.step / self.anneal_warm_up)
    kl_weight_z1 = get_cyclic_weight(self.step, self.cyclic_period)
    print()
    print(kl_weight_z1)
    kl_weight_z2 = get_cyclic_weight(self.step, self.cyclic_period)
    kl_weight_nll = 1.0
    kl_weight_z1 = 1.0
    kl_weight_z2 = 1.0
    L_dict['L_loss_trainenc'] = kl_weight_nll * L_dict['L_nll'] + kl_weight_z2 * L_dict['L_kld_fb'] + kl_weight_z1 * L_dict['L_z1kld_fb'] - L_dict['L_yprior']
    '''

        # autoencoding wo KL
        # L_dict['L_loss_trainenc'] = L_dict['L_nll']

        return L_dict

    def backward_L_trainenc(self, L_dict, e, alpha):
        '''
    # annealing
    total_step = 5000 * 2
    alpha = self.semicldc_classifier_alpha - (self.semicldc_classifier_alpha - 0.1) * (
            self.step / total_step)
    print()
    print(alpha)
    # cyclic annealing
    # number of steps for increasing
    total_step = 100 * 2
    cur_step = self.step % total_step
    alpha = self.semicldc_classifier_alpha - (self.semicldc_classifier_alpha - 0.1) * (
            cur_step / total_step)
    print()
    print(alpha)
    '''

        if not hasattr(self, 'adv_training') or self.adv_training is False:
            (e * (L_dict['L_loss_trainenc'].mean()) +
             alpha * L_dict['L_cldc_loss'].mean()).backward()
        elif self.adv_training is True:
            (e * (L_dict['L_loss_trainenc'].mean()) +
             alpha * L_dict['L_cldc_loss'].mean() +
             L_dict['L_dis_loss'].mean() +
             L_dict['L_enc_loss'].mean()).backward()

        # autoencoding wo KL
        # (L_dict['L_loss_trainenc'].mean()).backward()

        # get mean().item(), reduce memory
        L_dict = {
            k: (v.mean().item() if torch.is_tensor(v) else float(v))
            for k, v in L_dict.items()
        }

        return L_dict

    def forward_U_trainenc(self,
                           lang,
                           batch_uin,
                           batch_ulens,
                           batch_ulb,
                           batch_uohlb,
                           ue=1.0):
        U_dict = defaultdict(float)

        cur_bs = batch_uin.shape[0]
        n_bs = math.ceil(cur_bs / self.bs_u)

        for i in range(n_bs):
            U_dict_batch, U_pred_p = self.forward_U_trainenc_batch(
                lang, batch_uin[i * self.bs_u:(i + 1) * self.bs_u],
                batch_ulens[i * self.bs_u:(i + 1) * self.bs_u],
                batch_ulb[i * self.bs_u * self.label_size:(i + 1) * self.bs_u *
                          self.label_size],
                batch_uohlb[i * self.bs_u * self.label_size:(i + 1) *
                            self.bs_u * self.label_size])
            # calculate all necessary losses
            U_dict_batch = self.cal_U_trainenc(U_dict_batch, U_pred_p)
            # backward
            U_dict_batch = self.backward_U_trainenc(U_dict_batch, cur_bs, ue)
            U_dict = {k: (U_dict[k] + v) for k, v in U_dict_batch.items()}

        # z1, z2 distance
        self.U_l2dist /= n_bs
        self.U_cosdist /= n_bs
        self.U_z1z2kld /= n_bs

        U_dict = {k: v / cur_bs for k, v in U_dict.items()}

        return U_dict

    def forward_U_trainenc_batch(self, lang, batch_uin, batch_ulens, batch_ulb,
                                 batch_uohlb):
        U_dict, U_pred_p, mu1, logvar1, z1, dup_mu1, dup_logvar1, rec_mu1, rec_logvar1 = self.forward_U_fixenc_batch(
            lang, batch_uin, batch_ulens, batch_ulb, batch_uohlb)

        U_dict['U_nll'] = self.xlingva.decoder(lang,
                                               z1,
                                               batch_uin,
                                               reduction=None)
        # H(q(z1|x))
        # k/2 + k/2 log(2pi) + 1/2 log(|covariance|)
        # U_dict['U_Hz1'] = -multi_diag_normal(z1, mu1, logvar1)
        U_dict['U_Hz1'] = mu1.shape[1] / 2.0 * (
            1 + const) + 1 / 2.0 * logvar1.sum(dim=-1)
        # regroup
        U_dict['U_z1kld'] = cal_kl_gau2(dup_mu1, dup_logvar1, rec_mu1,
                                        rec_logvar1)
        # fb
        U_dict['U_z1kld_fb'] = cal_kl_gau2_fb(dup_mu1, dup_logvar1, rec_mu1,
                                              rec_logvar1, u_z1_fb)

        return U_dict, U_pred_p

    def cal_U_trainenc(self, U_dict, U_pred_p):
        ukld_fix = 5.0
        uz1kld_fix = 5.0

        U_dict = self.cal_U_fixenc(U_dict, U_pred_p)

        # U_dict['U_loss_trainenc'] = U_dict['U_loss'] + unll * U_dict['U_nll'] - uz1kl * U_dict['U_Hz1']
        # regroup
        U_dict['U_z1kld'] = torch.sum(
            (U_dict['U_z1kld'] * U_pred_p).view(-1, self.label_size), dim=1)
        '''
    kl_weight_z1 = get_cyclic_weight(self.step, self.cyclic_period)
    kl_weight_z2 = get_cyclic_weight(self.step, self.cyclic_period)
    '''
        '''
    kl_weight_z1 = min(1.0, self.step / self.anneal_warm_up)
    kl_weight_z2 = min(1.0, self.step / self.anneal_warm_up)
    '''
        kl_weight_z1 = 1.0
        kl_weight_z2 = 1.0
        U_dict['U_loss_trainenc'] = U_dict['U_nll'] + kl_weight_z2 * U_dict['U_kld'] + kl_weight_z1 * \
                                    U_dict['U_z1kld'] + U_dict['kldy']
        # U_dict['U_loss_trainenc'] = U_dict['U_nll'] + torch.abs(ukld_fix - U_dict['U_kld']) + torch.abs(uz1kld_fix - U_dict['U_z1kld']) + U_dict['kldy']
        U_dict['U_TKL'] = U_dict['U_kld'] + U_dict['U_z1kld'] + U_dict['kldy']
        '''
    # fb
    U_dict['U_z1kld_fb'] = torch.sum((U_dict['U_z1kld_fb'] * U_pred_p).view(-1, self.label_size), dim = 1)
    #kl_weight_nll = min(1.0, self.step / self.anneal_warm_up)
    #kl_weight_z1 = get_cyclic_weight(self.step, self.cyclic_period)
    #kl_weight_z2 = get_cyclic_weight(self.step, self.cyclic_period)
    kl_weight_nll = 1.0
    kl_weight_z1 = 1.0
    kl_weight_z2 = 1.0
    kl_weight_y = 1.0
    U_dict['U_loss_trainenc'] = kl_weight_nll * U_dict['U_nll'] + kl_weight_z2 * U_dict['U_kld_fb'] + kl_weight_z1 * U_dict['U_z1kld_fb'] + kl_weight_y * U_dict['kldy']
    '''

        # autoencoding wo KL
        # U_dict['U_loss_trainenc'] = U_dict['U_nll']

        return U_dict

    def backward_U_trainenc(self, U_dict, cur_bs, e):
        # backward
        if not hasattr(self, 'adv_training') or self.adv_training is False:
            (e * U_dict['U_loss_trainenc'].sum() / cur_bs).backward()
        elif self.adv_training is True:
            (e * U_dict['U_loss_trainenc'].sum() / cur_bs +
             U_dict['U_dis_loss'].sum() / cur_bs +
             U_dict['U_enc_loss'].sum() / cur_bs).backward()

        # get sum().item()
        U_dict = {
            k: (v.sum().item() if torch.is_tensor(v) else float(v))
            for k, v in U_dict.items()
        }

        return U_dict

    def forward_fixenc(self, lang, batch_in, batch_lens, batch_lb, batch_ohlb,
                       batch_uin, batch_ulens, batch_ulb, batch_uohlb):
        # au
        self.L_mu1, self.U_mu1 = [], []
        self.L_mu2, self.U_mu2 = [], []
        # z1, z2 distance
        self.L_l2dist, self.L_cosdist, self.L_z1z2kld = .0, .0, .0
        self.U_l2dist, self.U_cosdist, self.U_z1z2kld = .0, .0, .0

        # calculate L loss and classification loss
        L_dict, L_pred = self.forward_L_fixenc(lang, batch_in, batch_lens,
                                               batch_lb, batch_ohlb)

        # calculate U loss
        U_dict = self.forward_U_fixenc(lang, batch_uin, batch_ulens, batch_ulb,
                                       batch_uohlb)

        # merge two dicts
        loss_dict = {**L_dict, **U_dict}
        self.step += 1

        # z1, z2 distance
        loss_dict['L_l2_dist'], loss_dict['L_cosdist'], loss_dict[
            'L_z1z2kld'] = self.L_l2dist, self.L_cosdist, self.L_z1z2kld
        loss_dict['U_l2_dist'], loss_dict['U_cosdist'], loss_dict[
            'U_z1z2kld'] = self.U_l2dist, self.U_cosdist, self.U_z1z2kld

        # total MEAN loss
        # total_loss = L_loss.mean()
        # total_loss = 0.1 * L_cldc_loss.mean()
        # total_loss = L_loss.mean() + 0.1 * L_cldc_loss.mean()
        # total_loss = L_loss + U_loss
        loss_dict['total_loss'] = loss_dict['L_loss'] + loss_dict[
            'U_loss'] + self.semicldc_classifier_alpha * loss_dict[
                'L_cldc_loss']

        # au
        loss_dict['L_mu1'] = calc_au(self.L_mu1)[0]
        loss_dict['L_mu2'] = calc_au(self.L_mu2)[0]
        loss_dict['U_mu1'] = calc_au(self.U_mu1)[0]
        loss_dict['U_mu2'] = calc_au(self.U_mu2)[0]

        return loss_dict, L_pred

    def forward_L_fixenc(self, lang, batch_in, batch_lens, batch_lb,
                         batch_ohlb):
        L_dict, L_pred, _, _, _, _, _ = self.forward_L_fixenc_batch(
            lang, batch_in, batch_lens, batch_lb, batch_ohlb)

        # calculate all necessary losses
        L_dict = self.cal_L_fixenc(L_dict)
        # backward
        L_dict = self.backward_L_fixenc(L_dict)

        return L_dict, L_pred

    def forward_L_fixenc_batch(self, lang, batch_in, batch_lens, batch_lb,
                               batch_ohlb):
        # x -> hid_x -> mu1, logvar1 -> z1
        mu1, logvar1, z1, hid, loss_dis, loss_enc = self.get_z1(
            lang, batch_in, batch_lens)
        # cldc loss for training
        cldc_loss, _, pred = self.cldc_classifier(z1,
                                                  y=batch_lb,
                                                  training=True)
        # z1y -> z2
        mu2, logvar2, z2 = self.get_z2(z1, batch_ohlb)

        # au
        self.L_mu1.append(mu1)
        self.L_mu2.append(mu2)

        # z1, z2 distance
        self.L_z1z2kld += 0.5 * torch.mean(
            torch.sum(logvar1 - logvar2 - 1 +
                      ((mu2 - mu1).pow(2) + logvar2.exp()) / logvar1.exp(),
                      dim=1)).item()
        self.L_l2dist += torch.mean(
            torch.sqrt(torch.sum(((z1 - z2)**2), dim=1))).item()
        self.L_cosdist += torch.mean(F.cosine_similarity(z1, z2)).item()

        # reconstruct z1 from z2
        rec_loss, rec_mu1, rec_logvar1 = self.z2_rec_z1(z1, z2, batch_ohlb)
        # kl divergence of z2
        kld = self.kl_z2(mu2, logvar2, batch_ohlb)
        # get y prior
        yprior = batch_ohlb @ self.yprior

        # fb
        kld_fb = cal_kl_gau1_fb(mu2, logvar2, l_z2_fb)

        # fb
        L_dict = {
            'L_rec': rec_loss,
            'L_kld': kld,
            'L_yprior': yprior,
            'L_cldc_loss': cldc_loss,
            'L_kld_fb': kld_fb,
            'L_dis_loss': loss_dis,
            'L_enc_loss': loss_enc,
        }
        return L_dict, pred, mu1, logvar1, z1, rec_mu1, rec_logvar1

    def kl_z2_gmix(self, mu_post, logvar_post, batch_ohlb):
        # concat
        mu_prior, logvar_prior = self.y2z2(batch_ohlb)
        kld = 0.5 * torch.sum(logvar_prior - logvar_post - 1 + (
            (mu_post - mu_prior).pow(2) + logvar_post.exp()) /
                              logvar_prior.exp(),
                              dim=1)
        '''
    # transadd
    y = self.yohtoy(batch_ohlb)
    mu_prior, logvar_prior = self.y2z2(y)
    kld = 0.5 * torch.sum(logvar_prior - logvar_post - 1 + ((mu_post - mu_prior).pow(2) + logvar_post.exp()) / logvar_prior.exp(), dim = 1)
    '''
        return kld

    def kl_z2_gmix_transadd(self, mu_post, logvar_post, batch_ohlb):
        # transadd
        mu_prior, logvar_prior = self.y_z2(batch_ohlb)
        kld = 0.5 * torch.sum(logvar_prior - logvar_post - 1 + (
            (mu_post - mu_prior).pow(2) + logvar_post.exp()) /
                              logvar_prior.exp(),
                              dim=1)
        return kld

    def kl_z2_general(self, mu_post, logvar_post, batch_ohlb):
        kld = -0.5 * torch.sum(
            1 + logvar_post - mu_post.pow(2) - logvar_post.exp(), dim=1)
        return kld

    def kl_z2_transadd(self, mu_post, logvar_post, batch_ohlb):
        return self.kl_z2_general(mu_post, logvar_post, batch_ohlb)

    def cal_L_fixenc(self, L_dict):
        # L
        L_dict[
            'L_loss'] = L_dict['L_rec'] + L_dict['L_kld'] - L_dict['L_yprior']
        # L_dict['L_loss'] = lrec * L_dict['L_rec'] + lkld * torch.abs(lkld_fix - L_dict['L_kld']) - L_dict['L_yprior']
        # L_loss = L_rec + min(1.0, self.step / 1000) * L_kld - L_yprior

        return L_dict

    def backward_L_fixenc(self, L_dict):
        # backprop
        (L_dict['L_loss'].mean() + self.semicldc_classifier_alpha *
         L_dict['L_cldc_loss'].mean()).backward()

        # get mean().item(), reduce memory
        L_dict = {k: v.mean().item() for k, v in L_dict.items()}

        return L_dict

    def forward_U_fixenc(self, lang, batch_uin, batch_ulens, batch_ulb,
                         batch_uohlb):
        U_dict = defaultdict(float)

        cur_bs = batch_uin.shape[0]
        n_bs = math.ceil(cur_bs / self.bs_u)

        for i in range(n_bs):
            U_dict_batch, U_pred_p, _, _, _, _, _, _, _ = self.forward_U_fixenc_batch(
                lang, batch_uin[i * self.bs_u:(i + 1) * self.bs_u],
                batch_ulens[i * self.bs_u:(i + 1) * self.bs_u],
                batch_ulb[i * self.bs_u * self.label_size:(i + 1) * self.bs_u *
                          self.label_size],
                batch_uohlb[i * self.bs_u * self.label_size:(i + 1) *
                            self.bs_u * self.label_size])

            # calculate all necessary losses
            U_dict_batch = self.cal_U_fixenc(U_dict_batch, U_pred_p)
            # backward
            U_dict_batch = self.backward_U_fixenc(U_dict_batch, cur_bs)
            U_dict = {k: (U_dict[k] + v) for k, v in U_dict_batch.items()}

        U_dict = {k: v / cur_bs for k, v in U_dict.items()}

        return U_dict

    def forward_U_fixenc_batch(self, lang, batch_uin, batch_ulens, batch_ulb,
                               batch_uohlb):
        mu1, logvar1, z1, hid, loss_dis, loss_enc = self.get_z1(
            lang, batch_uin, batch_ulens)
        # use classifier to infer
        _, pred_p, _ = self.cldc_classifier(z1, y=None, training=True)
        # gumbel softmax
        # duplicate z1, enumerate y
        # bs * label_size, z_dim
        dup_z1 = self.enumerate_label(z1, batch_uohlb)
        dup_mu1 = self.enumerate_label(mu1, batch_uohlb)
        dup_logvar1 = self.enumerate_label(logvar1, batch_uohlb)
        # z1y -> z2
        mu2, logvar2, z2 = self.get_z2(dup_z1, batch_uohlb)

        self.U_mu1.append(dup_mu1)
        self.U_mu2.append(mu2)

        # z1, z2 distance
        self.U_z1z2kld += 0.5 * torch.mean(
            torch.sum(dup_logvar1 - logvar2 - 1 + (
                (mu2 - dup_mu1).pow(2) + logvar2.exp()) / dup_logvar1.exp(),
                      dim=1)).item()
        self.U_l2dist += torch.mean(
            torch.sqrt(torch.sum(((dup_z1 - z2)**2), dim=1))).item()
        self.U_cosdist += torch.mean(F.cosine_similarity(dup_z1, z2)).item()

        # reconstruct z1 from z2
        rec_loss, rec_mu1, rec_logvar1 = self.z2_rec_z1(
            dup_z1, z2, batch_uohlb)
        # kl divergence of z2
        kld = self.kl_z2(mu2, logvar2, batch_uohlb)
        # get y prior
        yprior = batch_uohlb @ self.yprior

        # calculate H(q(y | x ))
        H = -torch.sum(torch.mul(pred_p, torch.log(pred_p + 1e-32)), dim=1)

        # bs * label_size
        pred_p = pred_p.view(-1)

        # fb
        kld_fb = cal_kl_gau1_fb(mu2, logvar2, u_z2_fb)

        # fb
        U_dict = {
            'U_rec': rec_loss,
            'U_kld': kld,
            'U_yprior': yprior,
            'H': H,
            'U_kld_fb': kld_fb,
            'U_dis_loss': loss_dis,
            'U_enc_loss': loss_enc,
        }

        return U_dict, pred_p, mu1, logvar1, z1, dup_mu1, dup_logvar1, rec_mu1, rec_logvar1

    def cal_U_fixenc(self, U_dict, U_pred_p):
        # L for U
        UL_rec, UL_kld, UL_yprior = U_dict['U_rec'], U_dict['U_kld'], U_dict[
            'U_yprior']

        UL_loss = UL_rec + UL_kld - UL_yprior
        # UL_loss = urec * UL_rec + ukld * torch.abs(ukld_fix - UL_kld) - uyp * UL_yprior
        # U_L_loss = U_rec + min(1.0, self.step / 1000) * U_kld - U_yprior

        # expectation of UL_loss
        U_dict['UL_mean_loss'] = torch.sum(
            (U_pred_p * UL_loss).view(-1, self.label_size), dim=-1)

        # Total U
        # U_loss =  U_L_mean_loss - H
        U_dict['U_loss'] = U_dict['UL_mean_loss'] - U_dict['H']

        # calculate expectations for each term
        U_dict['U_rec'] = torch.sum(
            (U_pred_p * UL_rec).view(-1, self.label_size), dim=-1)
        U_dict['U_kld'] = torch.sum(
            (U_pred_p * UL_kld).view(-1, self.label_size), dim=-1)
        U_dict['U_yprior'] = torch.sum(
            (U_pred_p * UL_yprior).view(-1, self.label_size), dim=-1)
        # fb
        U_dict['U_kld_fb'] = torch.sum(
            (U_pred_p * U_dict['U_kld_fb']).view(-1, self.label_size), dim=-1)

        # KL(q(y|x) || p(y)) for U
        # bs, label_size
        U_pred_p_rv = U_pred_p.view(-1, self.label_size)
        U_dict['kldy'] = (U_pred_p_rv *
                          (torch.log(U_pred_p_rv + 1e-32) - self.yprior)).mean(
                              dim=1)

        # U_dict['U_loss'] += - U_dict['kldy'] + torch.abs(U_dict['kldy'] - 0.3)

        return U_dict

    def backward_U_fixenc(self, U_dict, cur_bs):
        # backward
        (U_dict['U_loss'].sum() / cur_bs).backward()

        # get sum().item()
        U_dict = {k: v.sum().item() for k, v in U_dict.items()}

        return U_dict

    def get_z1_fixenc(self, lang, batch_in, batch_lens):
        with torch.no_grad():
            # extract feature: mu1, logvar1, eval mode
            self.xlingva.eval()
            mu1, logvar1, hid, loss_dis, loss_enc = self.xlingva.get_gaus(
                lang, batch_in, batch_lens)

            # stochastic sampling, z for training, mu for eval
            if self.training:
                self.xlingva.inferer.train()
            else:
                self.xlingva.inferer.eval()
            z1 = self.xlingva.inferer.reparameterize(mu1, logvar1)
            # mu debug
            # z1 = mu1
            # mu debug
        return mu1, logvar1, z1, hid, loss_dis, loss_enc

    def get_z1_trainenc(self, lang, batch_in, batch_lens):
        mu1, logvar1, hid, loss_dis, loss_enc = self.xlingva.get_gaus(
            lang, batch_in, batch_lens)
        z1 = self.xlingva.inferer.reparameterize(mu1, logvar1)
        return mu1, logvar1, z1, hid, loss_dis, loss_enc

    def get_z2(self, z1, batch_ohlb):
        z1y = self.get_z1y(z1, batch_ohlb)
        # z1y -> z2
        mu2, logvar2 = self.z1y_z2(z1y)
        z2 = self.z1y_z2.reparameterize(mu2, logvar2)
        # mu debug
        # z2 = mu2
        # mu debug
        return mu2, logvar2, z2

    def get_z1y_concat(self, z1, batch_ohlb):
        # z1y -> z2
        z1y = torch.cat([z1, batch_ohlb], dim=-1)
        return z1y

    def get_z1y_transconcat(self, z1, batch_ohlb):
        # z1y -> z2
        batch_lb = self.yohtoy(batch_ohlb)
        z1y = torch.cat([z1, batch_lb], dim=-1)
        return z1y

    def get_z1y_transadd(self, z1, batch_ohlb):
        # z1y -> z2
        batch_lb = self.yohtoy_toz2(batch_ohlb)
        z1y = z1 + batch_lb
        # batch norm requires more than one sample in the batch
        if z1y.shape[0] > 1:
            z1y = self.hbn_z1y(z1y)
        z1y = self.leakyrelu(z1y)
        return z1y

    def get_z1y_gmix(self, z1, batch_ohlb):
        # z1y -> z2
        # concat
        z1y = torch.cat([z1, batch_ohlb], dim=-1)
        '''
    # transadd
    y = self.yohtoy(batch_ohlb)
    z1y = z1 + y
    '''
        return z1y

    def get_z1y_gmix_transadd(self, z1, batch_ohlb):
        # z1y -> z2
        batch_lb = self.yohtoy_toz2(batch_ohlb)
        z1y = z1 + batch_lb
        # batch norm requires more than one sample in the batch
        if z1y.shape[0] > 1:
            z1y = self.hbn_z1y(z1y)
        z1y = self.leakyrelu(z1y)
        return z1y

    def z2_rec_z1(self, z1, z2, batch_ohlb):
        z2y = self.get_z2y(z2, batch_ohlb)
        # z2y -> z1
        rec_mu1, rec_logvar1 = self.z2y_z1(z2y)
        rec_z1 = self.z2y_z1.reparameterize(rec_mu1, rec_logvar1)

        logpz1 = multi_diag_normal(z1, rec_mu1, rec_logvar1)
        return -logpz1, rec_mu1, rec_logvar1

    def get_z2y_concat(self, z2, batch_ohlb):
        # reconstruction
        z2y = torch.cat([z2, batch_ohlb], dim=-1)
        return z2y

    def get_z2y_transconcat(self, z2, batch_ohlb):
        batch_lb = self.yohtoy(batch_ohlb)
        z2y = torch.cat([z2, batch_lb], dim=-1)
        return z2y

    def get_z2y_transadd(self, z2, batch_ohlb):
        batch_lb = self.yohtoy_toz1(batch_ohlb)
        z2y = z2 + batch_lb
        # batch norm requires more than one sample in the batch
        if z2y.shape[0] > 1:
            z2y = self.hbn_z2y(z2y)
        z2y = self.leakyrelu(z2y)
        return z2y

    def get_z2y_gmix(self, z2, batch_ohlb):
        return z2

    def get_z2y_gmix_transadd(self, z2, batch_ohlb):
        return z2

    def enumerate_label(self, batch_uin, batch_uohlb):
        # batch_dup_ulens = np.repeat(batch_ulens, repeats = batch_uohlb.shape[1], axis = 0)
        batch_dup_uin = batch_uin.unsqueeze(1).repeat(
            1, batch_uohlb.shape[1], 1).view(-1, batch_uin.shape[1])
        return batch_dup_uin
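
Two more external helpers appear in this example: multi_diag_normal, used in z2_rec_z1 as a log-likelihood of z1 under the reconstructed Gaussian, and cal_kl_gau1_fb, a KL against a standard normal prior with free bits. A plausible sketch consistent with those call sites (an assumption, not the original code):

import math

import torch

def multi_diag_normal(z, mu, logvar):
    # log-density of z under N(mu, diag(exp(logvar))), summed over the feature dimension
    return -0.5 * torch.sum(
        math.log(2 * math.pi) + logvar + (z - mu).pow(2) / logvar.exp(),
        dim=-1)

def cal_kl_gau1_fb(mu, logvar, fb):
    # KL( N(mu, diag(exp(logvar))) || N(0, I) ) with a per-dimension free-bits floor fb
    kl = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp())
    return torch.clamp(kl, min=fb).sum(dim=1)
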
Example No. 7
    def init_semicldc_cond_transconcat(self, params):
        # transform one-hot y into a dense y, then concatenate z1 and y
        self.yohtoy = nn.Linear(self.label_size, params.z_dim)
        self.z1y_z2 = Inferer(params, in_dim=params.z_dim + params.z_dim)
        self.z2y_z1 = Inferer(params, in_dim=params.z_dim + params.z_dim)
Example No. 8
    def init_semicldc_cond_concat(self, params):
        # directly concatenate z1 and one-hot y
        self.z1y_z2 = Inferer(params, in_dim=params.z_dim + self.label_size)
        self.z2y_z1 = Inferer(params, in_dim=params.z_dim + self.label_size)
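
For reference, the concat conditioning simply widens the Inferer input: z1 is concatenated with the one-hot label before inferring z2, which is why in_dim is z_dim + label_size. A toy, standalone illustration of the shapes involved (the dimensions are made up for the example):

import torch

z_dim, label_size, batch = 8, 4, 2
z1 = torch.randn(batch, z_dim)
y_onehot = torch.eye(label_size)[:batch]   # one-hot labels for the batch
z1y = torch.cat([z1, y_onehot], dim=-1)    # shape: (batch, z_dim + label_size)
print(z1y.shape)                           # torch.Size([2, 12])
# z1y would then be fed to self.z1y_z2 (an Inferer with in_dim=z_dim + label_size),
# which returns (mu2, logvar2) for z2.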