Example #1
# Imports assumed by this example; the global config dict `args` and the
# project-specific Encoder/Decoder modules are not shown in the source.
import torch
import torch.nn as nn

class Model(nn.Module):
    """
    Implementation of a seq2seq model.
    Architecture:
        Encoder/decoder
        2 LSTM layers
    """
    def __init__(self, w2i, i2w):
        """
        Args:
            w2i: word-to-index vocabulary mapping
            i2w: index-to-word vocabulary mapping
        """
        super(Model, self).__init__()
        print("Model creation...")

        self.word2index = w2i
        self.index2word = i2w
        self.max_length = args['maxLengthDeco']

        self.dtype = 'float32'
        self.NLLloss = torch.nn.NLLLoss(reduction='none')
        self.CEloss = torch.nn.CrossEntropyLoss(reduction='none')

        self.embedding = nn.Embedding(args['vocabularySize'],
                                      args['embeddingSize'])
        self.emo_embedding = nn.Embedding(args['emo_labelSize'],
                                          args['embeddingSize'])
        self.encoder = Encoder(w2i, i2w, self.embedding)
        self.decoder = Decoder(w2i, i2w, self.embedding)
        self.tanh = nn.Tanh()
        self.softmax = nn.Softmax(dim=-1)
        # self.BERTtokenizer = AlbertTokenizer.from_pretrained('albert-base-v2')
        # albert_base_configuration = AlbertConfig(
        #       hidden_size=args['ALBERT_hidden_size'],
        #       num_attention_heads=12,
        #       intermediate_size=3072,
        #   )
        # self.Albert_model = AlbertModel(albert_base_configuration)

    def buildmodel(self, x):
        '''
        :param x: batch dict with
            x['enc_input']  [batch, enc_len], x['enc_len'],
            x['dec_input']  [batch, dec_len], x['dec_len'],
            x['dec_target'] [batch, dec_len], x['emo_label']
        :return: (recon_loss_mean, en_state, info_vector)
        '''

        # print(x['enc_input'])
        self.encoderInputs = x['enc_input']
        self.encoder_lengths = x['enc_len']
        self.decoderInputs = x['dec_input']
        self.decoder_lengths = x['dec_len']
        self.decoderTargets = x['dec_target']
        self.emo_label = x['emo_label']
        self.batch_size = self.encoderInputs.size()[0]
        '''
        ALBERT: https://huggingface.co/transformers/model_doc/albert.html
        
        last_hidden_state (torch.FloatTensor of shape (batch_size, sequence_length, hidden_size)):
        
        Sequence of hidden-states at the output of the last layer of the model.
        
        pooler_output (torch.FloatTensor of shape (batch_size, hidden_size)):
        
        Last layer hidden-state of the first token of the sequence (classification token) further processed by a Linear 
        layer and a Tanh activation function. The Linear layer weights are trained from the next sentence prediction 
        (classification) objective during pre-training.
        
        This output is usually not a good summary of the semantic content of the input; you're often better off averaging 
        or pooling the sequence of hidden-states for the whole input sequence.
        
        hidden_states (tuple(torch.FloatTensor), optional, returned when output_hidden_states=True is passed or when 
        config.output_hidden_states=True):
        
        Tuple of torch.FloatTensor (one for the output of the embeddings + one for the output of each layer) of shape 
        (batch_size, sequence_length, hidden_size).
        
        Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        
        attentions (tuple(torch.FloatTensor), optional, returned when output_attentions=True is passed or when 
        config.output_attentions=True):
        
        Tuple of torch.FloatTensor (one for each layer) of shape (batch_size, num_heads, sequence_length, sequence_length).
        
        Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
        '''

        # ALBERT_input_sentences = x['enc_input_raw']
        # ALBERT_encoded_inputs = self.BERTtokenizer(ALBERT_input_sentences,padding=True, truncation=True,return_tensors="pt")
        # ALBERT_encoded_inputs['input_ids'] = ALBERT_encoded_inputs['input_ids'].to(args['device'])
        # ALBERT_encoded_inputs['token_type_ids'] = ALBERT_encoded_inputs['token_type_ids'].to(args['device'])
        # ALBERT_encoded_inputs['attention_mask'] = ALBERT_encoded_inputs['attention_mask'].to(args['device'])
        # last_hidden_state, pooler_output = self.Albert_model(**ALBERT_encoded_inputs)

        _, en_state = self.encoder(self.encoderInputs, self.encoder_lengths)
        emo_vector = self.emo_embedding(self.emo_label)  # batch * hid (dedicated emotion-label embedding)
        # info_vector = torch.cat([emo_vector, pooler_output], dim = 1)
        info_vector = emo_vector
        de_outputs = self.decoder(en_state, info_vector, self.decoderInputs,
                                  self.decoder_lengths, self.decoderTargets)

        recon_loss = self.CEloss(torch.transpose(de_outputs, 1, 2),
                                 self.decoderTargets)
        mask = torch.sign(self.decoderTargets.float())
        recon_loss = torch.squeeze(recon_loss) * mask
        recon_loss_mean = torch.mean(recon_loss)

        return recon_loss_mean, en_state, info_vector

    def forward(self, x):
        loss, _, _ = self.buildmodel(x)
        return loss

    def predict(self, x):
        _, en_state, info = self.buildmodel(x)
        de_words = self.decoder.generate(en_state, info)
        return de_words
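
The masked reconstruction loss above (CrossEntropyLoss with reduction='none', then a sign mask over padding) is a reusable pattern. A minimal, self-contained sketch with dummy tensors, independent of the project's Encoder/Decoder:

import torch

batch, dec_len, vocab = 4, 7, 100
logits = torch.randn(batch, dec_len, vocab)          # decoder outputs [B, T, V]
targets = torch.randint(1, vocab, (batch, dec_len))  # token ids; 0 = padding
targets[:, 5:] = 0                                   # pretend the tails are padding

ce = torch.nn.CrossEntropyLoss(reduction='none')
loss = ce(logits.transpose(1, 2), targets)           # CE wants [B, V, T] -> [B, T]
mask = torch.sign(targets.float())                   # 1 for real tokens, 0 for pad
masked = loss * mask
# torch.mean(masked), as in the example, still averages over the zeroed pad
# positions; dividing by mask.sum() gives a true per-token mean instead.
per_token_mean = masked.sum() / mask.sum()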
Example #2
# Imports assumed by this example; `args`, Encoder, Decoder,
# GaussianOrthogonalRandomMatrix and dot_product_attention are
# project-specific (Performer-style attention utilities) not shown here.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter

class LSTM_CTE_Model(nn.Module):
    """
    Implementation of a seq2seq model.
    Architecture:
        Encoder/decoder
        2 LSTM layers
    """
    def __init__(self, w2i, i2w, embs):
        """
        Args:
            w2i: word-to-index vocabulary mapping
            i2w: index-to-word vocabulary mapping
            embs: pretrained embedding weights for nn.Embedding.from_pretrained
        """
        super(LSTM_CTE_Model, self).__init__()
        print("Model creation...")

        self.word2index = w2i
        self.index2word = i2w
        self.max_length = args['maxLengthDeco']

        self.NLLloss = torch.nn.NLLLoss(reduction='none')
        self.CEloss = torch.nn.CrossEntropyLoss(reduction='none')

        self.embedding = nn.Embedding.from_pretrained(embs)

        self.encoder = Encoder(w2i, i2w)
        self.encoder2 = Encoder(w2i, i2w, inputsize=args['hiddenSize'])

        self.decoder = Decoder(w2i, i2w, self.embedding)
        self.tanh = nn.Tanh()
        self.relu = nn.ReLU()
        self.softmax = nn.Softmax(dim=-1)

        self.att_size_r = 60
        self.grm = GaussianOrthogonalRandomMatrix()
        self.att_projection_matrix = Parameter(
            self.grm.get_2d_array(args['hiddenSize'], self.att_size_r))
        self.M = Parameter(
            torch.randn([args['hiddenSize'], args['embeddingSize']]))

        # self.z_to_fea = nn.Linear(args['hiddenSize'], args['hiddenSize']).to(args['device'])
        self.SentenceClassifier = nn.Sequential(
            nn.Linear(args['hiddenSize'], 1), nn.Sigmoid())

    def sample_z(self, mu, log_var, batch_size):
        # Reparameterization trick: z = mu + sigma * eps, eps ~ N(0, I).
        # (The deprecated Variable wrapper and the elementwise 'ba,ba->ba'
        # einsum are equivalent to plain tensor ops.)
        eps = torch.randn(batch_size, args['style_len'] * 2 *
                          args['numLayers']).to(args['device'])
        return mu + torch.exp(log_var / 2) * eps

    def cos(self, x1, x2):
        '''
        Batched cosine similarity.
        :param x1: [batch, seq1, emb]
        :param x2: [batch, seq2, emb]
        :return: [batch, seq1, seq2]
        '''
        xx = torch.einsum('bse,bte->bst', x1, x2)
        x1n = torch.norm(x1, dim=-1, keepdim=True)
        x2n = torch.norm(x2, dim=-1, keepdim=True)
        xd = torch.einsum('bse,bte->bst', x1n, x2n).clamp(min=0.0001)
        return xx / xd

    def sample_gumbel(self, shape, eps=1e-20):
        U = torch.rand(shape).to(args['device'])
        return -torch.log(-torch.log(U + eps) + eps)

    def gumbel_softmax_sample(self, logits, temperature):
        y = logits + self.sample_gumbel(logits.size())
        return F.softmax(y / temperature, dim=-1)

    def gumbel_softmax(self, logits, temperature=args['temperature']):
        """
        Straight-through (ST) Gumbel-softmax.
        input: [*, n_class] logits
        return: (hard one-hot sample, soft sample), both [*, n_class]
        """
        y = self.gumbel_softmax_sample(logits, temperature)
        shape = y.size()
        _, ind = y.max(dim=-1)
        y_hard = torch.zeros_like(y).view(-1, shape[-1])
        y_hard.scatter_(1, ind.view(-1, 1), 1)
        y_hard = y_hard.view(*shape)
        y_hard = (y_hard - y).detach() + y  # straight-through: hard forward, soft gradients
        return y_hard, y

    def build(self, x, eps=1e-6):
        '''
        :param x: batch object with contextSeqs, questionSeqs, decoderSeqs,
            targetSeqs, context_mask, sentence_mask, optionSeqs and label
        :return: (loss, en_state, scores, sentence_logits)
        '''

        # D,Q -> s: P(s|D,Q)
        context_inputs = torch.LongTensor(x.contextSeqs).to(args['device'])
        q_inputs = torch.LongTensor(x.questionSeqs).to(args['device'])
        answer_dec = torch.LongTensor(x.decoderSeqs).to(args['device'])
        answer_tar = torch.LongTensor(x.targetSeqs).to(args['device'])
        context_mask = torch.FloatTensor(x.context_mask).to(
            args['device'])  # batch sentence
        sentence_mask = torch.FloatTensor(x.sentence_mask).to(
            args['device'])  # batch sennum contextlen

        opt_inputs = [x.optionSeqs[i] for i in range(4)]  # the four answer options
        answerlabel = torch.LongTensor(x.label).to(args['device'])

        batch_size = context_inputs.size()[0]
        context_inputs_embs = self.embedding(context_inputs)
        q_inputs_embs = self.embedding(q_inputs)

        en_context_output, (en_context_hidden, en_context_cell) = self.encoder(
            context_inputs_embs)  # b s e
        en_q_output, (en_q_hidden, en_q_cell) = self.encoder(q_inputs_embs)

        # en_context_output_flat = en_context_output.transpose(0,1).reshape(batch_size,-1)
        # en_q_output_flat = en_q_output.transpose(0,1).reshape(batch_size,-1)

        c_q = torch.cat([en_context_output, en_q_output], dim=1)

        attentioned_context = dot_product_attention(
            query=c_q,
            key=c_q,
            value=c_q,
            projection_matrix=self.att_projection_matrix)  # b s h

        c_after_att, q_after_att = attentioned_context.split(
            [en_context_output.shape[1], en_q_output.shape[1]], dim=1)
        # print(attentioned_context)
        # exit()

        # opt_input_embed =[]
        # for i in range(4):
        #     opt_input_embed.append(self.embedding(opt_inputs[i]))
        #
        # cos_context_q = self.cos(context_inputs_embs, q_inputs_embs) # b c q
        # # cos_context_q, _ = torch.max(cos_context_q, dim = 1)
        # # M_q = torch.mean(cos_context_q, dim = 1) # b q(e)
        # att_con_q = self.softmax(cos_context_q) # b c q
        #
        # att_con = torch.einsum('bcq,bqe->bce', att_con_q, q_inputs_embs)
        # cos_context_a = []
        # M_a = []
        # for i in range(4):
        #     # print(cos_context_q.size(),opt_input_embed[i].size())
        #     cos_context_a.append(self.cos(att_con, opt_input_embed[i])) # b c a
        #     # print(cos_context_a[i].size())
        #     cos_context_a[i], _ = torch.max(cos_context_a[i], dim = 1)
        #     # print(cos_context_a[i].size())
        #     M_a.append(torch.mean(cos_context_a[i], dim = 1)) # b q(e)
        # M_as = torch.stack(M_a) # 4 b
        # # print(M_q.size(), M_as.size())
        # # scores = self.ChargeClassifier(M_q.unsqueeze(0) + M_as).transpose(0,1) # b 4
        # scores = M_as.transpose(0,1)

        # coatt = torch.einsum('bse,bte->bts', en_context_output, en_q_output)
        # coatt = self.softmax(coatt)
        # coatt2 = torch.einsum('bse,bte->bst', en_context_output, en_q_output)
        # coatt2 = self.softmax(coatt2)
        # q_info = torch.einsum('bts,bse->bte', coatt, en_context_output)
        # q_info_cat = torch.cat([q_info, q_inputs_embs], dim = 2) # b q e
        # q_info_cat_info = torch.einsum('bst,bte->bse', coatt2, q_info_cat)
        # q_info_cat_info_con = torch.cat([q_info_cat_info, context_inputs_embs], dim = 2) # b q e
        # # print(q_info_cat_info.size())
        # out_info, _ = self.encoder2(q_info_cat_info_con) # b c e
        out_info, _ = self.encoder2(c_after_att)  # b c e

        sentence_embs = torch.einsum(
            'bce,bsc->bse', out_info,
            sentence_mask) / (sentence_mask.sum(2, keepdim=True) + eps)
        # print(sentence_embs)
        raw_sentence_logits = self.SentenceClassifier(sentence_embs).squeeze()
        # print(raw_sentence_logits.size(), context_mask.size())
        sentence_logits = raw_sentence_logits * context_mask + (
            1 - context_mask) * (-1e30)  # batch sentence
        sentence_sample, _ = self.gumbel_softmax(sentence_logits)
        # print(sentence_embs.size(), sentence_sample.size())
        decoder_input = torch.einsum('bse,bs->be', sentence_embs,
                                     sentence_sample.squeeze())
        en_state = self.decoder.vector2state(decoder_input)
        # print(en_state)
        de_outputs = self.decoder(en_state, answer_dec, answer_tar)
        # print(de_outputs)
        # CE labels should be the target sequence (matching the mask below),
        # not the shifted decoder inputs
        recon_loss = self.CEloss(torch.transpose(de_outputs, 1, 2), answer_tar)
        mask = torch.sign(answer_tar.float())
        recon_loss = torch.squeeze(recon_loss) * mask

        recon_loss_mean = torch.mean(recon_loss)

        opt_vec = []
        # opt_input_embed =[]
        for i in range(4):
            # opt_input_embed.append(self.embedding(opt_inputs[i]))
            opt1 = []
            for j in range(batch_size):
                embs = self.embedding(
                    torch.LongTensor(opt_inputs[i][j]).to(args['device']))
                # print(embs.size())
                opt1.append(torch.mean(embs, dim=0))
            # print(torch.stack(opt1).size())
            # exit()
            opt_vec.append(torch.stack(opt1))  # batch dim
        #
        opt_vec_stack = torch.stack(opt_vec)  # 4 batch dim

        mul1 = torch.einsum('be,er->br', decoder_input, self.M)
        scores = torch.einsum('obe,be->bo', opt_vec_stack, mul1)
        recon_loss_c = self.CEloss(scores, answerlabel)

        loss = recon_loss_mean + recon_loss_c.mean()
        return loss, en_state, scores, sentence_logits

    def forward(self, x):
        loss, _, _, _ = self.build(x)
        return loss

    def predict(self, x):
        loss, en_state, output, sentence_probs = self.build(x)
        de_words = self.decoder.generate(en_state)
        return loss, de_words, torch.argmax(output, dim=-1), torch.argmax(
            sentence_probs, dim=-1)
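
Examples #2 and #3 both rely on a straight-through Gumbel-softmax to sample hard one-hot decisions while keeping gradients. A self-contained sketch of the trick (the name gumbel_softmax_st is illustrative, not from the source):

import torch
import torch.nn.functional as F

def gumbel_softmax_st(logits, temperature=1.0, eps=1e-20):
    # Add Gumbel(0, 1) noise, then relax the categorical sample with a softmax
    u = torch.rand_like(logits)
    g = -torch.log(-torch.log(u + eps) + eps)
    y = F.softmax((logits + g) / temperature, dim=-1)
    # Straight-through: hard one-hot in the forward pass, soft gradients backward
    y_hard = F.one_hot(y.argmax(dim=-1), logits.size(-1)).type_as(y)
    return (y_hard - y).detach() + y, y

logits = torch.randn(3, 5, requires_grad=True)
hard, soft = gumbel_softmax_st(logits, temperature=0.5)
score = (hard * torch.randn(5)).sum()  # stand-in for a downstream loss
score.backward()                       # gradients reach `logits` via the soft sample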
Example #3
# Imports assumed by this example; `args`, Encoder and Decoder are
# project-specific modules not shown here.
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter

class LSTM_CTE_Model(nn.Module):
    """
    Implementation of a seq2seq model.
    Architecture:
        Encoder/decoder
        2 LSTM layers
    """
    def __init__(self, w2i, i2w, embs=None, title_emb=None):
        """
        Args:
            w2i: word-to-index vocabulary mapping
            i2w: index-to-word vocabulary mapping
            embs: optional pretrained word-embedding weights
            title_emb: optional pretrained field/title embedding weights
        """
        super(LSTM_CTE_Model, self).__init__()
        print("Model creation...")

        self.word2index = w2i
        self.index2word = i2w
        self.max_length = args['maxLengthDeco']

        self.NLLloss = torch.nn.NLLLoss(ignore_index=0)
        self.CEloss = torch.nn.CrossEntropyLoss(ignore_index=0)

        if embs is not None:
            self.embedding = nn.Embedding.from_pretrained(embs)
        else:
            self.embedding = nn.Embedding(args['vocabularySize'],
                                          args['embeddingSize'])

        if title_emb is not None:
            self.field_embedding = nn.Embedding.from_pretrained(title_emb)
        else:
            self.field_embedding = nn.Embedding(args['TitleNum'],
                                                args['embeddingSize'])

        self.encoder = Encoder(w2i, i2w, bidirectional=True)
        # self.encoder_answer_only = Encoder(w2i, i2w)
        self.encoder_no_answer = Encoder(w2i, i2w)
        self.encoder_pure_answer = Encoder(w2i, i2w)

        self.decoder_answer = Decoder(w2i,
                                      i2w,
                                      self.embedding,
                                      copy='pure',
                                      max_dec_len=10)
        self.decoder_no_answer = Decoder(w2i,
                                         i2w,
                                         self.embedding,
                                         input_dim=args['embeddingSize'] * 2,
                                         copy='semi')

        self.ansmax2state_h = nn.Linear(args['embeddingSize'],
                                        args['hiddenSize'] * 2,
                                        bias=False)
        self.ansmax2state_c = nn.Linear(args['embeddingSize'],
                                        args['hiddenSize'] * 2,
                                        bias=False)
        self.tanh = nn.Tanh()
        self.relu = nn.ReLU()
        self.softmax = nn.Softmax(dim=-1)
        self.sigmoid = nn.Sigmoid()

        self.att_size_r = 60
        # self.grm = GaussianOrthogonalRandomMatrix()
        # self.att_projection_matrix = Parameter(self.grm.get_2d_array(args['embeddingSize'], self.att_size_r))
        self.M = Parameter(
            torch.randn([args['embeddingSize'], args['hiddenSize'] * 2, 2]))

        self.shrink_copy_input = nn.Linear(args['hiddenSize'] * 2,
                                           args['hiddenSize'],
                                           bias=False)
        self.emb2hid = nn.Linear(args['embeddingSize'],
                                 args['hiddenSize'],
                                 bias=False)
        # self.z_logit2prob = nn.Sequential(
        #     nn.Linear(args['hiddenSize'], 2)
        # )

        # self.z_to_fea = nn.Linear(args['hiddenSize'], args['hiddenSize']).to(args['device'])
        # self.SEClassifier = nn.Sequential(
        #     nn.Linear(args['hiddenSize'], 2),
        #     nn.Sigmoid()
        # )
        #
        # self.SentenceClassifier = nn.Sequential(
        #     nn.Linear(args['hiddenSize'], 1),
        #     nn.Sigmoid()
        # )

    def sample_z(self, mu, log_var, batch_size):
        # Reparameterization trick: z = mu + sigma * eps, eps ~ N(0, I).
        # (The deprecated Variable wrapper and the elementwise 'ba,ba->ba'
        # einsum are equivalent to plain tensor ops.)
        eps = torch.randn(batch_size, args['style_len'] * 2 *
                          args['numLayers']).to(args['device'])
        return mu + torch.exp(log_var / 2) * eps

    def cos(self, x1, x2):
        '''
        Batched cosine similarity.
        :param x1: [batch, seq1, emb]
        :param x2: [batch, seq2, emb]
        :return: [batch, seq1, seq2]
        '''
        xx = torch.einsum('bse,bte->bst', x1, x2)
        x1n = torch.norm(x1, dim=-1, keepdim=True)
        x2n = torch.norm(x2, dim=-1, keepdim=True)
        xd = torch.einsum('bse,bte->bst', x1n, x2n).clamp(min=0.0001)
        return xx / xd

    def sample_gumbel(self, shape, eps=1e-20):
        U = torch.rand(shape).to(args['device'])
        return -torch.log(-torch.log(U + eps) + eps)

    def gumbel_softmax_sample(self, logits, temperature):
        y = logits + self.sample_gumbel(logits.size())
        return F.softmax(y / temperature, dim=-1)

    def gumbel_softmax(self, logits, temperature=args['temperature']):
        """
        Straight-through (ST) Gumbel-softmax.
        input: [*, n_class] logits
        return: (hard one-hot sample, soft sample), both [*, n_class]
        """
        y = self.gumbel_softmax_sample(logits, temperature)
        shape = y.size()
        _, ind = y.max(dim=-1)
        y_hard = torch.zeros_like(y).view(-1, shape[-1])
        y_hard.scatter_(1, ind.view(-1, 1), 1)
        y_hard = y_hard.view(*shape)
        y_hard = (y_hard - y).detach() + y  # straight-through: hard forward, soft gradients
        return y_hard, y

    def get_pretrain_parameters(self):
        return list(self.embedding.parameters()) + list(
            self.encoder.parameters()) + list(
                self.decoder_no_answer.parameters())

    def build(self, x, mode, eps=1e-6):
        '''
        :param x: batch object with contextSeqs, field, decoderSeqs, targetSeqs,
            ContextDecoderSeqs, ContextTargetSeqs and answerSeqs
        :param mode: 'train' samples z via straight-through Gumbel-softmax;
            otherwise z is thresholded at 0.5
        :return: dict of loss, states, masks and copy-mechanism inputs
        '''

        # D,Q -> s: P(s|D,Q)
        context_inputs = torch.LongTensor(x.contextSeqs).to(args['device'])
        field = torch.LongTensor(x.field).to(args['device'])
        answer_dec = torch.LongTensor(x.decoderSeqs).to(args['device'])
        answer_tar = torch.LongTensor(x.targetSeqs).to(args['device'])
        context_dec = torch.LongTensor(x.ContextDecoderSeqs).to(args['device'])
        context_tar = torch.LongTensor(x.ContextTargetSeqs).to(args['device'])
        pure_answer = torch.LongTensor(x.answerSeqs).to(args['device'])
        # context_mask = torch.FloatTensor(x.context_mask).to(args['device'])  # batch sentence
        # sentence_mask = torch.FloatTensor(x.sentence_mask).to(args['device'])  # batch sennum contextlen
        # start_positions = torch.FloatTensor(x.starts).to(args['device'])
        # end_positions = torch.FloatTensor(x.ends).to(args['device'])
        # ans_context_input = torch.LongTensor(x.ans_contextSeqs).to(args['device'])
        # ans_context_mask = torch.LongTensor(x.ans_con_mask).to(args['device'])

        pure_answer_embs = self.embedding(pure_answer)
        # print(' context_inputs: ', context_inputs[0])
        # print(' context_dec: ', context_dec[0])
        # print(' context_tar: ', context_tar[0])

        mask = torch.sign(context_inputs).float()
        mask_pure_answer = torch.sign(pure_answer).float()

        batch_size = context_inputs.size()[0]
        seq_len = context_inputs.size()[1]
        context_inputs_embs = self.embedding(context_inputs)
        q_inputs_embs = self.field_embedding(field)  #.unsqueeze(1) # batch emb
        #
        # attentioned_context = dot_product_attention(query=q_inputs_embs.unsqueeze(1), key=context_inputs_embs, value=context_inputs_embs,
        #                                             projection_matrix=self.att_projection_matrix)  # b s h

        en_context_output, en_context_state = self.encoder(
            context_inputs_embs)  # b s e

        # print(q_inputs_embs.size(), en_context_output.size())
        att1 = self.tanh(torch.einsum('be,ehc->bhc', q_inputs_embs, self.M))
        # print(att1.size(), en_context_output.size())
        z_logit = torch.einsum('bhc,bsh->bsc', att1, en_context_output)
        # z_embs = self.tanh(self.q_att_layer(q_inputs_embs) + self.c_att_layer(en_context_output)) # b s h
        # z_logit = self.z_logit2prob(z_embs).squeeze() # b s 2
        # z_logit = torch.cat([1-z_logit_1, z_logit_1], dim = 2)
        z_logit_fla = z_logit.reshape((batch_size * seq_len, 2))
        z_prob = self.softmax(z_logit)
        if mode == 'train':
            sampled_seq, sampled_seq_soft = self.gumbel_softmax(z_logit_fla)
            sampled_seq = sampled_seq.reshape((batch_size, seq_len, 2))
            sampled_seq_soft = sampled_seq_soft.reshape(
                (batch_size, seq_len, 2))
            sampled_seq = sampled_seq * mask.unsqueeze(2)
            sampled_seq_soft = sampled_seq_soft * mask.unsqueeze(2)
        else:
            sampled_seq = (z_prob > 0.5).float() * mask.unsqueeze(2)

        gold_ans_mask, _ = (
            context_inputs.unsqueeze(2) == pure_answer.unsqueeze(1)).max(2)

        if mode == 'train':
            ans_mask = gold_ans_mask  # same computation as gold_ans_mask above
            noans_mask = 1 - ans_mask.int()
        else:
            ans_mask = sampled_seq[:, :, 1].int()
            noans_mask = sampled_seq[:, :, 0].int()

        answer_only_sequence = context_inputs_embs * sampled_seq[:, :, 1].unsqueeze(2)
        no_answer_sequence = context_inputs_embs * noans_mask.unsqueeze(2)  # .detach()
        # ANS_END = torch.LongTensor([5] * batch_size).to(args['device'])
        # ANS_END = ANS_END.unsqueeze(1)
        # ANS_END_emb = self.embedding(ANS_END)
        # no_answer_sequence = torch.cat([pure_answer_embs, ANS_END_emb, no_answer_sequence], dim = 1)

        answer_only_logp_z0 = torch.log(z_prob[:, :, 0].clamp(
            eps, 1.0))  # [B,T], log P(z = 0 | x)
        answer_only_logp_z1 = torch.log(z_prob[:, :, 1].clamp(
            eps, 1.0))  # [B,T], log P(z = 1 | x)
        # answer_only_logpz = (1-sampled_seq[:, :, 1]) * answer_only_logp_z0 + sampled_seq[:, :, 1] * answer_only_logp_z1
        answer_only_logpz = torch.where(sampled_seq[:, :, 1] == 0,
                                        answer_only_logp_z0,
                                        answer_only_logp_z1)
        # no_answer_logpz = torch.where(sampled_seq[:, :, 1] == 0,answer_only_logp_z1, answer_only_logp_z0)
        answer_only_logpz = mask * answer_only_logpz
        # no_answer_logpz = mask * no_answer_logpz

        # answer_only_output, answer_only_state = self.encoder_answer_only(answer_only_sequence)
        answer_only_info, _ = torch.max(answer_only_sequence, dim=1)
        # print(answer_only_info.size())
        # Map the pooled vector to (h, c) and reshape to [num_layers, batch, hidden]
        answer_only_state = (
            self.ansmax2state_h(answer_only_info).reshape(
                [batch_size, args['numLayers'], args['hiddenSize']]),
            self.ansmax2state_c(answer_only_info).reshape(
                [batch_size, args['numLayers'], args['hiddenSize']]))
        answer_only_state = (answer_only_state[0].transpose(0, 1).contiguous(),
                             answer_only_state[1].transpose(0, 1).contiguous())

        no_answer_output, no_answer_state = self.encoder_no_answer(
            no_answer_sequence)
        # no_answer_output, no_answer_state = self.encoder_no_answer(context_inputs_embs)

        en_context_output_shrink = self.shrink_copy_input(
            en_context_output)  # bsh
        # answer_latent_emb,_ = torch.max(answer_only_output)
        enc_onehot = F.one_hot(context_inputs,
                               num_classes=args['vocabularySize'])
        answer_de_output = self.decoder_answer(
            answer_only_state,
            answer_dec,
            answer_tar,
            enc_embs=en_context_output_shrink,
            enc_mask=mask,
            enc_onehot=enc_onehot)
        answer_recon_loss = self.NLLloss(
            torch.transpose(answer_de_output, 1, 2), answer_tar)
        # answer_mask = torch.sign(answer_tar.float())
        # answer_recon_loss = torch.squeeze(answer_recon_loss) * answer_mask
        answer_recon_loss_mean = answer_recon_loss  #torch.mean(answer_recon_loss, dim = 1)
        #
        ######################## no_answer do not contain answer #####################
        # Cosine-style overlap between predicted no-answer tokens and the answer
        pred_no_answer_seq = context_inputs_embs * sampled_seq[:, :, 0].unsqueeze(2)
        cross_len = torch.abs(
            torch.einsum('bse,bae->bsa', pred_no_answer_seq, pure_answer_embs)
            * mask.unsqueeze(2)) \
            / (torch.norm(pred_no_answer_seq, dim=2).unsqueeze(2) + eps) \
            / (torch.norm(pure_answer_embs, dim=2).unsqueeze(1) + eps)
        # print(cross_len)
        # print(torch.max(cross_len))
        # print(torch.mean(cross_len))
        # exit()
        cross_sim = torch.mean(cross_len)

        # ################### no-answer context  + answer info -> origin context  #############
        # pure_answer_output, pure_answer_state = self.encoder_pure_answer(pure_answer_embs)
        # pure_answer_output = torch.mean(pure_answer_embs, dim = 1, keepdim=True)
        pure_answer_mask = torch.sign(pure_answer).float()
        pure_answer_output = torch.sum(
            pure_answer_embs * pure_answer_mask.unsqueeze(2),
            dim=1,
            keepdim=True) / torch.sum(pure_answer_mask, dim=1,
                                      keepdim=True).unsqueeze(2)
        # no_ans_plus_pureans_state = (torch.cat([no_answer_state[0], pure_answer_state[0]], dim = 2),
        #                              torch.cat([no_answer_state[1], pure_answer_state[1]], dim=2))
        en_context_output_plus = torch.cat(
            [en_context_output_shrink * noans_mask.unsqueeze(2),
             self.emb2hid(pure_answer_embs)], dim=1)
        mask_plus = torch.cat([mask, mask_pure_answer], dim=1)
        enc_onehot_plus = F.one_hot(
            torch.cat([context_inputs * noans_mask, pure_answer], dim=1),
            num_classes=args['vocabularySize'])
        pa_mask, _ = F.one_hot(pure_answer,
                               num_classes=args['vocabularySize']).max(1)  # batch voc
        pa_mask[:, 0] = 0  # drop the padding id from the answer-vocabulary mask
        pa_mask = pa_mask.detach()
        # print(pa_mask)
        context_de_output = self.decoder_no_answer(
            no_answer_state,
            context_dec,
            context_tar,
            cat=pure_answer_output,
            enc_embs=en_context_output_plus,
            enc_mask=mask_plus,
            enc_onehot=enc_onehot_plus,
            lstm_mask=pa_mask)
        # context_de_output = self.decoder_no_answer(no_answer_state, context_dec, context_tar, cat=None, enc_embs = en_context_output_plus, enc_mask=mask_plus, enc_onehot = enc_onehot_plus, lstm_mask = pa_mask)
        # context_de_output = self.decoder_no_answer(en_context_state, context_dec, context_tar)#, cat=torch.max(pure_answer_output, dim = 1, keepdim=True)[0])
        context_recon_loss = self.NLLloss(
            torch.transpose(context_de_output, 1, 2), context_tar)
        # context_mask = torch.sign(context_tar.float())
        # context_recon_loss = torch.squeeze(context_recon_loss) * context_mask
        context_recon_loss_mean = context_recon_loss  #torch.mean(context_recon_loss, dim = 1)

        I_x_z = torch.abs(
            torch.mean(-torch.log(z_prob[:, :, 0] + eps), 1) + np.log(0.5))
        # I_x_z = torch.abs(torch.mean(torch.log(z_prob[:, :, 1]+eps), 1) -np.log(0.1))

        loss = (10 * I_x_z.mean() + answer_recon_loss_mean.mean()
                + context_recon_loss_mean + cross_sim * 100)
        # + ((answer_recon_loss_mean.detach()) * answer_only_logpz.mean(1)).mean()
        # print(loss, 100 * I_x_z.mean(), answer_recon_loss_mean.mean(), context_recon_loss_mean, cross_sim)
        #    + context_recon_loss_mean.detach() * no_answer_logpz.mean(1)).mean()
        # loss = context_recon_loss_mean.mean()
        self.tt = [
            answer_recon_loss_mean.mean(), context_recon_loss_mean,
            (sampled_seq[:, :, 1].sum(1) * 1.0 / mask.sum(1)).mean(), cross_sim
        ]
        # self.tt = [context_recon_loss_mean.mean(),]
        # self.tt = [answer_recon_loss_mean.mean() , (sampled_seq[:,:,1].sum(1)*1.0/ mask.sum(1)).mean()]
        # return loss, answer_only_state, no_answer_state, pure_answer_output, (sampled_seq[:,:,1].sum(1)*1.0/ mask.sum(1)).mean(), sampled_seq[:,:,1], \
        #        en_context_output_shrink, mask,  enc_onehot, en_context_output_plus, mask_plus, enc_onehot_plus
        return {
            'loss': loss,
            'answer_only_state': answer_only_state,
            'no_answer_state': no_answer_state,
            'pure_answer_output': pure_answer_output,
            'closs': (sampled_seq[:, :, 1].sum(1) * 1.0 / mask.sum(1)).mean(),
            'sampled_words': sampled_seq[:, :, 1],
            'en_context_output': en_context_output_shrink,
            'mask': mask,
            'enc_onehot': enc_onehot,
            'en_context_output_plus': en_context_output_plus,
            'mask_plus': mask_plus,
            'enc_onehot_plus': enc_onehot_plus,
            'context_inputs': context_inputs,
            'context_inputs_embs': context_inputs_embs,
            'ans_mask': ans_mask,
            'gold_ans_mask': gold_ans_mask
        }

    def forward(self, x):
        data = self.build(x, mode='train')
        return data['loss'], data['closs']

    def predict(self, x):
        # Note: build is called with mode='train', so z is still Gumbel-sampled
        # at predict time
        data = self.build(x, mode='train')
        de_words_answer = []
        if data['answer_only_state'] is not None:
            de_words_answer = self.decoder_answer.generate(
                data['answer_only_state'],
                enc_embs=data['en_context_output'],
                enc_mask=data['mask'],
                enc_onehot=data['enc_onehot'])
        de_words_context = self.decoder_no_answer.generate(
            data['no_answer_state'],
            cat=data['pure_answer_output'],
            enc_embs=data['en_context_output_plus'],
            enc_mask=data['mask_plus'],
            enc_onehot=data['enc_onehot_plus'])
        # de_words_context = self.decoder_no_answer.generate(no_answer_state, cat = None, enc_embs = en_context_output_plus, enc_mask=mask_plus, enc_onehot = enc_onehot_plus)

        return data['loss'], de_words_answer, de_words_context, data[
            'sampled_words'], data['gold_ans_mask'], data['mask']

    def pre_training_forward(self, x, eps=1e-6):
        context_inputs = torch.LongTensor(x.contextSeqs).to(args['device'])
        context_dec = torch.LongTensor(x.ContextDecoderSeqs).to(args['device'])
        context_tar = torch.LongTensor(x.ContextTargetSeqs).to(args['device'])
        context_inputs_embs = self.embedding(context_inputs)
        en_context_output, en_context_state = self.encoder(
            context_inputs_embs)  # b s e

        mask = torch.sign(context_inputs).float()
        enc_onehot = F.one_hot(context_inputs,
                               num_classes=args['vocabularySize'])
        batch_size = context_inputs.size()[0]

        context_de_output = self.decoder_no_answer(
            en_context_state,
            context_dec,
            context_tar,
            cat=torch.zeros([batch_size, 1,
                             args['embeddingSize']]).to(args['device']),
            enc_embs=en_context_output,
            enc_mask=mask,
            enc_onehot=enc_onehot)
        context_recon_loss = self.NLLloss(
            torch.transpose(context_de_output, 1, 2), context_tar)
        context_mask = torch.sign(context_tar.float())
        context_recon_loss = torch.squeeze(context_recon_loss) * context_mask
        context_recon_loss_mean = torch.mean(context_recon_loss, dim=1)
        return context_recon_loss_mean
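
Example #3 turns a pooled embedding into an LSTM initial state via two bias-free linear maps and a reshape (the ansmax2state_h/_c pattern). A self-contained sketch; the sizes are illustrative, and the reshape assumes out_features == num_layers * hidden:

import torch
import torch.nn as nn

batch, seq, emb, hid, layers = 4, 10, 32, 64, 2
tokens = torch.randn(batch, seq, emb)
pooled, _ = torch.max(tokens, dim=1)  # max-pool over time -> [B, E]

to_h = nn.Linear(emb, hid * layers, bias=False)
to_c = nn.Linear(emb, hid * layers, bias=False)
# Reshape to [num_layers, B, H], the state layout nn.LSTM expects
h0 = to_h(pooled).reshape(batch, layers, hid).transpose(0, 1).contiguous()
c0 = to_c(pooled).reshape(batch, layers, hid).transpose(0, 1).contiguous()

lstm = nn.LSTM(emb, hid, num_layers=layers, batch_first=True)
out, (hn, cn) = lstm(tokens, (h0, c0))  # run the LSTM from the derived state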