Example no. 1
    def forward(self, batch_input, lang):
        word_batch = get_torch_variable_from_np(batch_input['word'])
        pretrain_batch = get_torch_variable_from_np(batch_input['pretrain'])

        if lang == "En":
            pretrain_emb = self.pretrained_embedding(pretrain_batch).detach()
        else:
            pretrain_emb = self.fr_pretrained_embedding(
                pretrain_batch).detach()

        if lang == "En":
            word_emb = self.word_embedding(word_batch)
        else:
            word_emb = self.fr_word_embedding(word_batch)

        input_emb = torch.cat((pretrain_emb, word_emb), 2)
        input_emb = self.word_dropout(input_emb)
        seq_len = input_emb.shape[1]
        bilstm_output, (_, bilstm_final_state) = self.bilstm_layer(
            input_emb, self.bilstm_hidden_state)
        bilstm_output = bilstm_output.contiguous()
        hidden_input = bilstm_output.view(
            bilstm_output.shape[0] * bilstm_output.shape[1], -1)
        hidden_input = hidden_input.view(self.batch_size, seq_len, -1)
        hidden_input = self.out_dropout(hidden_input)
        output = self.scorer(hidden_input)
        output = output.view(self.batch_size * seq_len, -1)

        return output
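The snippets on this page rely on a get_torch_variable_from_np helper that is never shown. A minimal sketch of what it presumably does (an assumption about the helper, not the project's actual implementation):

import numpy as np
import torch

USE_CUDA = torch.cuda.is_available()

def get_torch_variable_from_np(array):
    # Assumed behaviour: wrap a NumPy array as a torch tensor and move it
    # to the GPU when CUDA is available.
    tensor = torch.from_numpy(np.asarray(array))
    return tensor.cuda() if USE_CUDA else tensor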
Example no. 2
    def __init__(self, model_params):
        super(SR_Matcher, self).__init__()
        self.dropout_word = nn.Dropout(p=0.3)
        self.mlp_size = 300
        self.dropout_mlp = model_params['dropout_mlp']
        self.batch_size = model_params['batch_size']

        self.target_vocab_size = model_params['target_vocab_size']
        self.use_flag_embedding = model_params['use_flag_embedding']
        self.flag_emb_size = model_params['flag_embedding_size']
        self.pretrain_emb_size = 768  #model_params['pretrain_emb_size']

        self.bilstm_num_layers = model_params['bilstm_num_layers']
        self.bilstm_hidden_size = model_params['bilstm_hidden_size']
        self.emb2vector = nn.Sequential(
            nn.Linear(self.pretrain_emb_size + 0 * self.flag_emb_size, 200),
            nn.Tanh())
        self.matrix = nn.Parameter(
            get_torch_variable_from_np(np.zeros((200, 200)).astype("float32")))

        self.probs2vector = nn.Sequential(
            nn.Linear(self.target_vocab_size, self.target_vocab_size),
            nn.Tanh())

        self.vector2probs = nn.Sequential(
            nn.Linear(self.target_vocab_size, self.target_vocab_size),
            nn.Tanh(),
            nn.Linear(self.target_vocab_size, self.target_vocab_size - 1))
Example no. 3
    def learn_loss(self, output_SRL, pretrain_emb, flag_emb, seq_len, mask_copy, mask_unk):
        SRL_input = output_SRL.view(self.batch_size, seq_len, -1)
        output = SRL_input.view(self.batch_size*seq_len, -1)
        SRL_input = F.softmax(SRL_input, 2).detach()

        pred = torch.max(SRL_input, dim=2)[1]
        for i in range(self.batch_size):
            for j in range(seq_len):
                if pred[i][j] > 1:
                    mask_copy[i][j] = 1

        mask_copy = get_torch_variable_from_np(mask_copy)
        mask_final = mask_copy.view(self.batch_size * seq_len) * mask_unk.view(self.batch_size * seq_len)
        pred_recur = self.SR_Compressor(SRL_input, pretrain_emb,
                                        flag_emb.detach(), None, None, seq_len, para=False, use_bert=True)

        output_word = self.SR_Matcher(pred_recur, pretrain_emb, flag_emb.detach(), None, seq_len, copy = True,
                                      para=False, use_bert=True)

        score4Null = torch.zeros_like(output_word[:, 1:2])
        output_word = torch.cat((output_word[:, 0:1], score4Null, output_word[:, 1:]), 1)

        criterion = nn.CrossEntropyLoss(reduction='none')
        _, prediction_batch_variable = torch.max(output, 1)

        loss_word = criterion(output_word, prediction_batch_variable)*mask_final

        loss_word = loss_word.sum()/(self.batch_size*seq_len)


        return loss_word
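The nested Python loop that fills mask_copy can be vectorized. A sketch, assuming mask_copy is a NumPy array of shape (batch_size, seq_len) and that label indices 0 and 1 are excluded exactly as in the loop above:

import numpy as np
import torch

def build_copy_mask(srl_probs: torch.Tensor, mask_copy: np.ndarray) -> np.ndarray:
    # Vectorized equivalent of the per-token loop above.
    # srl_probs: (batch_size, seq_len, n_labels) softmax scores.
    # mask_copy: (batch_size, seq_len) NumPy array, updated in place.
    pred = srl_probs.argmax(dim=2)                  # predicted label per token
    mask_copy[pred.detach().cpu().numpy() > 1] = 1  # keep tokens whose label index > 1
    return mask_copy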
Example no. 4
    def forward(self,
                role_vectors,
                pretrained_emb,
                word_id_emb,
                seq_len,
                para=False):
        query_vector = torch.cat((pretrained_emb, word_id_emb), 2)
        role_vectors = role_vectors.view(self.batch_size,
                                         self.target_vocab_size, 200)
        # B T R V
        role_vectors = role_vectors.unsqueeze(1).expand(
            self.batch_size, seq_len, self.target_vocab_size, 200)
        # B T R W
        query_vector = query_vector.unsqueeze(2).expand(
            self.batch_size, seq_len, self.target_vocab_size,
            self.pretrain_emb_size + self.flag_emb_size)
        # B T R V
        y = torch.mm(
            query_vector.contiguous().view(
                self.batch_size * seq_len * self.target_vocab_size, -1),
            self.matrix)
        # B T R
        y = y.contiguous().view(self.batch_size, seq_len,
                                self.target_vocab_size, 200)
        roles_scores = torch.sum(role_vectors * y, dim=3)
        zerosNull = get_torch_variable_from_np(
            np.zeros((self.batch_size, seq_len, 1), dtype='float32'))
        roles_scores = roles_scores.view(self.batch_size, seq_len, -1)

        output_word = torch.cat(
            (roles_scores[:, :, 0:1], zerosNull, roles_scores[:, :, 2:]), 2)
        output_word = output_word.view(self.batch_size * seq_len, -1)

        return output_word
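The expand/torch.mm/sum sequence above computes a bilinear score query · matrix · role for every (token, role) pair. Under that reading, the same scores can be written more compactly with torch.einsum; this is a sketch of the computation, not a drop-in replacement for the class method:

import torch

def bilinear_role_scores(query_vector: torch.Tensor,
                         role_vectors: torch.Tensor,
                         matrix: torch.Tensor) -> torch.Tensor:
    # query_vector: (B, T, D) token queries (pretrained emb + word-id emb).
    # role_vectors: (B, R, K) role vectors.
    # matrix:       (D, K) bilinear weight.
    # Returns:      (B, T, R) scores, score[b, t, r] = q[b, t] @ matrix @ v[b, r].
    projected = torch.einsum('btd,dk->btk', query_vector, matrix)
    return torch.einsum('btk,brk->btr', projected, role_vectors)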
Example no. 5
    def __init__(self, model_params):
        super(SR_Matcher, self).__init__()
        #self.dropout_word_1 = nn.Dropout(p=0.0)
        #self.dropout_word_2 = nn.Dropout(p=0.0)
        self.mlp_size = 300
        #self.dropout_mlp = model_params['dropout_mlp']
        self.batch_size = model_params['batch_size']

        self.target_vocab_size = model_params['target_vocab_size']
        self.use_flag_embedding = model_params['use_flag_embedding']
        self.flag_emb_size = model_params['flag_embedding_size']
        self.pretrain_emb_size = model_params['pretrain_emb_size']

        self.bilstm_num_layers = model_params['bilstm_num_layers']
        self.bilstm_hidden_size = model_params['bilstm_hidden_size']
        self.bert_size = 768
        self.base_emb2vector = nn.Sequential(nn.Linear(self.bert_size, 300),
                                             nn.Tanh())

        self.query_emb2vector = nn.Sequential(nn.Linear(self.bert_size, 300),
                                              nn.Tanh())

        self.probs2vector = nn.Sequential(
            nn.Linear(self.target_vocab_size, 100), nn.Tanh(),
            nn.Linear(100, 50), nn.Tanh())

        self.vector2scores = nn.Sequential(
            nn.Linear(50, 30), nn.Tanh(), nn.Linear(30,
                                                    self.target_vocab_size))

        self.matrix = nn.Parameter(
            get_torch_variable_from_np(np.zeros((300, 300)).astype("float32")))
Example no. 6
def Input4Gan_1(output, predicates_1D):
    output = F.softmax(output, dim=2)
    seq_len = output.shape[1]
    shuffled_timestep = np.arange(0, seq_len)
    np.random.shuffle(shuffled_timestep)
    #Returns a new tensor which indexes the input tensor along dimension dim using the entries in index which is a LongTensor.
    shuffled_output = output.index_select(
        dim=1, index=get_torch_variable_from_np(shuffled_timestep))
    return shuffled_output
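A minimal usage sketch for Input4Gan_1 (assumptions: the module defining it imports torch.nn.functional as F and the get_torch_variable_from_np helper sketched earlier, and all tensors live on the same device; predicates_1D is accepted but never used):

import torch

batch_size, seq_len, n_labels = 2, 6, 5
scores = torch.randn(batch_size, seq_len, n_labels)  # stand-in for raw SRL scores
shuffled = Input4Gan_1(scores, predicates_1D=None)   # softmax over labels, timesteps permuted
assert shuffled.shape == scores.shape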
Example no. 7
    def forward(self, batch_input, lang='En', unlabeled=False):
        if unlabeled:

            loss = self.parallel_train(batch_input)

            loss_word = 0

            return loss, loss_word

        pretrain_batch = get_torch_variable_from_np(batch_input['pretrain'])
        predicates_1D = batch_input['predicates_idx']
        flag_batch = get_torch_variable_from_np(batch_input['flag'])
        word_id = get_torch_variable_from_np(batch_input['word_times'])
        word_id_emb = self.id_embedding(word_id)
        flag_emb = self.flag_embedding(flag_batch)

        if lang == "En":
            pretrain_emb = self.pretrained_embedding(pretrain_batch).detach()
        else:
            pretrain_emb = self.fr_pretrained_embedding(
                pretrain_batch).detach()

        seq_len = flag_emb.shape[1]
        SRL_output = self.SR_Labeler(pretrain_emb,
                                     flag_emb,
                                     predicates_1D,
                                     seq_len,
                                     para=False)

        SRL_input = SRL_output.view(self.batch_size, seq_len, -1)
        SRL_input = SRL_input.detach()
        pred_recur = self.SR_Compressor(SRL_input,
                                        pretrain_emb,
                                        word_id_emb,
                                        seq_len,
                                        para=False)

        output_word = self.SR_Matcher(pred_recur,
                                      pretrain_emb,
                                      word_id_emb.detach(),
                                      seq_len,
                                      para=False)
        return SRL_output, output_word
Example no. 8
    def __init__(self, model_params):
        super(SR_Labeler, self).__init__()
        self.dropout_word = nn.Dropout(p=0.5)
        self.dropout_hidden = nn.Dropout(p=0.3)
        self.dropout_mlp = model_params['dropout_mlp']
        self.batch_size = model_params['batch_size']

        self.target_vocab_size = model_params['target_vocab_size']
        self.use_flag_embedding = model_params['use_flag_embedding']
        self.flag_emb_size = model_params['flag_embedding_size']
        self.pretrain_emb_size = model_params['pretrain_emb_size']

        self.bilstm_num_layers = model_params['bilstm_num_layers']
        self.bilstm_hidden_size = model_params['bilstm_hidden_size']

        if USE_CUDA:
            self.bilstm_hidden_state = (Variable(torch.randn(
                2 * self.bilstm_num_layers, self.batch_size,
                self.bilstm_hidden_size),
                                                 requires_grad=True).cuda(),
                                        Variable(torch.randn(
                                            2 * self.bilstm_num_layers,
                                            self.batch_size,
                                            self.bilstm_hidden_size),
                                                 requires_grad=True).cuda())
        else:
            self.bilstm_hidden_state = (Variable(torch.randn(
                2 * self.bilstm_num_layers, self.batch_size,
                self.bilstm_hidden_size),
                                                 requires_grad=True),
                                        Variable(torch.randn(
                                            2 * self.bilstm_num_layers,
                                            self.batch_size,
                                            self.bilstm_hidden_size),
                                                 requires_grad=True))
        self.bilstm_layer = nn.LSTM(input_size=768 + 1 * self.flag_emb_size,
                                    hidden_size=self.bilstm_hidden_size,
                                    num_layers=self.bilstm_num_layers,
                                    bidirectional=True,
                                    bias=True,
                                    batch_first=True)

        self.mlp_size = 300
        self.rel_W = nn.Parameter(
            get_torch_variable_from_np(
                np.zeros((self.mlp_size + 1, self.target_vocab_size *
                          (self.mlp_size + 1))).astype("float32")))
        self.mlp_arg = nn.Sequential(
            nn.Linear(2 * self.bilstm_hidden_size, self.mlp_size), nn.ReLU())
        self.mlp_pred = nn.Sequential(
            nn.Linear(2 * self.bilstm_hidden_size, self.mlp_size), nn.ReLU())
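self.rel_W, self.mlp_arg, and self.mlp_pred feed a bilinear(...) scorer that appears in later examples (e.g. Example no. 15) but is not shown on this page. Below is a hypothetical sketch consistent with rel_W's (mlp_size + 1, target_vocab_size * (mlp_size + 1)) layout, under the assumption that bias_x/bias_y mean a constant 1 is appended to the argument and predicate representations:

import torch

def biaffine_scores(arg_hidden: torch.Tensor,
                    pred_hidden: torch.Tensor,
                    rel_W: torch.Tensor,
                    num_outputs: int) -> torch.Tensor:
    # arg_hidden:  (B, T, D) argument MLP outputs.
    # pred_hidden: (B, D)    predicate MLP output.
    # rel_W:       (D + 1, num_outputs * (D + 1)) weight, as in self.rel_W above.
    # Returns:     (B, T, num_outputs) label scores.
    B, T, D = arg_hidden.shape
    x = torch.cat((arg_hidden, arg_hidden.new_ones(B, T, 1)), dim=2)   # bias_x
    y = torch.cat((pred_hidden, pred_hidden.new_ones(B, 1)), dim=1)    # bias_y
    lin = (x @ rel_W).view(B, T, num_outputs, D + 1)                   # (B, T, R, D+1)
    return torch.einsum('btrd,bd->btr', lin, y)                        # contract with predicate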
Example no. 9
    def forward(self, tree, inputs): #, tree_hidden
        for idx in range(tree.num_children):
            self.forward(tree.children[idx], inputs) #, tree_hidden

        if tree.num_children == 0:
            child_h = Variable(inputs[0].data.new(1, self.mem_dim).fill_(0.))
            child_deprels = Variable(inputs[0].data.new(1, self.mem_dim).fill_(0.))
        else:
            child_h = [x.state for x in tree.children]
            child_h = torch.cat(child_h, dim=0)
            child_deprels = self.deprel_emb[get_torch_variable_from_np(np.array([c.deprel for c in tree.children]))]

        tree.state = self.node_forward(inputs[tree.idx], child_h, child_deprels)

        return tree.state
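The recursion assumes a tree node object exposing num_children, children, idx, deprel, and a writable state attribute. A minimal stand-in for that interface (an assumption inferred from the attribute accesses above, not the project's actual Tree class):

class Tree:
    # Minimal node interface implied by the recursive forward above.
    def __init__(self, idx, deprel, children=None):
        self.idx = idx                  # position of this node's token in `inputs`
        self.deprel = deprel            # dependency-relation id of the incoming arc
        self.children = children or []  # child Tree nodes
        self.state = None               # filled in by forward() with node_forward's output

    @property
    def num_children(self):
        return len(self.children)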
Example no. 10
    def __init__(self, model_params):
        super(SR_Matcher, self).__init__()
        self.dropout_word = nn.Dropout(p=0.0)
        self.mlp_size = 300
        self.dropout_mlp = model_params['dropout_mlp']
        self.batch_size = model_params['batch_size']

        self.target_vocab_size = model_params['target_vocab_size']
        self.use_flag_embedding = model_params['use_flag_embedding']
        self.flag_emb_size = model_params['flag_embedding_size']
        self.pretrain_emb_size = model_params['pretrain_emb_size']

        self.bilstm_num_layers = model_params['bilstm_num_layers']
        self.bilstm_hidden_size = model_params['bilstm_hidden_size']
        self.matrix = nn.Parameter(
            get_torch_variable_from_np(
                np.zeros((self.pretrain_emb_size + self.flag_emb_size,
                          200)).astype("float32")))
Example no. 11
    def forward(self, tree, inputs):  #, tree_hidden
        for idx in range(tree.num_children):
            self.forward(tree.children[idx], inputs)  #, tree_hidden

        if tree.num_children == 0:
            child_c = Variable(inputs[0].data.new(1, self.mem_dim).fill_(0.))
            child_h = Variable(inputs[0].data.new(1, self.mem_dim).fill_(0.))
            deprels = None
        else:
            child_c, child_h = zip(*map(lambda x: x.state, tree.children))
            child_c, child_h = torch.cat(child_c, dim=0), torch.cat(child_h,
                                                                    dim=0)
            deprels = get_torch_variable_from_np(
                np.array([c.deprel for c in tree.children]))

        tree.state = self.node_forward(inputs[tree.idx], child_c, child_h,
                                       deprels)

        # tree_hidden[tree.idx] = tree.state

        return tree.state
Example no. 12
    def __init__(self, model_params):
        super(SR_Matcher, self).__init__()
        #self.dropout_word_1 = nn.Dropout(p=0.0)
        #self.dropout_word_2 = nn.Dropout(p=0.0)
        self.mlp_size = 300
        #self.dropout_mlp = model_params['dropout_mlp']
        self.batch_size = model_params['batch_size']

        self.target_vocab_size = model_params['target_vocab_size']
        self.use_flag_embedding = model_params['use_flag_embedding']
        self.flag_emb_size = model_params['flag_embedding_size']
        self.pretrain_emb_size = model_params['pretrain_emb_size']

        self.bilstm_num_layers = model_params['bilstm_num_layers']
        self.bilstm_hidden_size = model_params['bilstm_hidden_size']
        self.compress_emb = nn.Sequential(nn.Linear(768, 128), nn.Tanh())
        self.matrix = nn.Parameter(
            get_torch_variable_from_np(np.zeros((256, 256)).astype("float32")))

        self.specific_NULL_emb = nn.Sequential(nn.Linear(768, 256), nn.Tanh(),
                                               nn.Linear(256, 128), nn.Tanh())

        self.scorer = nn.Sequential(nn.Linear(256, 64), nn.Tanh(),
                                    nn.Linear(64, 1))
Example no. 13
    def forward(self,
                batch_input,
                lang='En',
                unlabeled=False,
                use_bert=False,
                isTrain=True):
        if unlabeled:

            loss = self.parallel_train(batch_input)

            loss_word = 0

            return loss, loss_word

        pretrain_batch = get_torch_variable_from_np(batch_input['pretrain'])
        predicates_1D = batch_input['predicates_idx']
        flag_batch = get_torch_variable_from_np(batch_input['flag'])
        word_id = get_torch_variable_from_np(batch_input['word_times'])
        word_id_emb = self.id_embedding(word_id)
        flag_emb = self.flag_embedding(flag_batch)
        actual_lens = batch_input['seq_len']

        if lang == "En":
            pretrain_emb = self.pretrained_embedding(pretrain_batch).detach()
        else:
            pretrain_emb = self.fr_pretrained_embedding(
                pretrain_batch).detach()

        bert_input_ids = get_torch_variable_from_np(
            batch_input['bert_input_ids'])
        bert_input_mask = get_torch_variable_from_np(
            batch_input['bert_input_mask'])
        bert_out_positions = get_torch_variable_from_np(
            batch_input['bert_out_positions'])

        bert_emb = self.model(bert_input_ids, attention_mask=bert_input_mask)
        bert_emb = bert_emb[0]
        bert_emb = bert_emb[:, 1:-1, :].contiguous().detach()

        bert_emb = bert_emb[torch.arange(bert_emb.size(0)).unsqueeze(-1),
                            bert_out_positions].detach()

        for i in range(len(bert_emb)):
            if i >= len(actual_lens):
                break

            for j in range(len(bert_emb[i])):
                if j >= actual_lens[i]:
                    bert_emb[i][j] = get_torch_variable_from_np(
                        np.zeros(768, dtype="float32"))
        bert_emb = bert_emb.detach()

        pretrain_emb = bert_emb
        seq_len = flag_emb.shape[1]
        SRL_output = self.SR_Labeler(pretrain_emb,
                                     flag_emb,
                                     predicates_1D,
                                     seq_len,
                                     para=False)

        SRL_input = SRL_output.view(self.batch_size, seq_len, -1)
        SRL_input = F.softmax(SRL_input, 2).detach()
        pred_recur = self.SR_Compressor(pretrain_emb,
                                        word_id_emb,
                                        seq_len,
                                        para=False)

        output_word = self.SR_Matcher(pred_recur,
                                      SRL_input,
                                      pretrain_emb,
                                      word_id_emb.detach(),
                                      seq_len,
                                      para=False)

        score4Null = torch.zeros_like(output_word[:, 1:2])
        output_word = torch.cat(
            (output_word[:, 0:1], score4Null, output_word[:, 1:]), 1)

        teacher = SRL_input.view(self.batch_size * seq_len, -1).detach()
        eps = 1e-7
        student = torch.log_softmax(output_word, 1)

        unlabeled_loss_function = nn.KLDivLoss(reduction='none')
        loss_copy = unlabeled_loss_function(student, teacher)
        loss_copy = loss_copy.sum() / (self.batch_size * seq_len)

        return SRL_output, output_word, loss_copy
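The closing lines implement a teacher-student copy loss: the labeler's softmax output is the detached teacher, the matcher's log-softmax output is the student, and the two are compared with KL divergence averaged over all tokens. The same pattern in isolation:

import torch
import torch.nn as nn

def copy_distillation_loss(student_logits: torch.Tensor,
                           teacher_probs: torch.Tensor) -> torch.Tensor:
    # student_logits: (N, n_labels) raw matcher scores.
    # teacher_probs:  (N, n_labels) softmax scores from the labeler (treated as fixed).
    student = torch.log_softmax(student_logits, dim=1)
    kl = nn.KLDivLoss(reduction='none')(student, teacher_probs.detach())
    return kl.sum() / student_logits.size(0)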
Example no. 14
    def parallel_train(self, batch_input):
        unlabeled_data_en, unlabeled_data_fr = batch_input

        pretrain_batch_fr = get_torch_variable_from_np(
            unlabeled_data_fr['pretrain'])
        predicates_1D_fr = unlabeled_data_fr['predicates_idx']
        flag_batch_fr = get_torch_variable_from_np(unlabeled_data_fr['flag'])
        word_id_fr = get_torch_variable_from_np(
            unlabeled_data_fr['word_times'])
        word_id_emb_fr = self.id_embedding(word_id_fr).detach()
        flag_emb_fr = self.flag_embedding(flag_batch_fr).detach()
        pretrain_emb_fr = self.fr_pretrained_embedding(
            pretrain_batch_fr).detach()

        pretrain_batch = get_torch_variable_from_np(
            unlabeled_data_en['pretrain'])
        predicates_1D = unlabeled_data_en['predicates_idx']
        flag_batch = get_torch_variable_from_np(unlabeled_data_en['flag'])
        word_id = get_torch_variable_from_np(unlabeled_data_en['word_times'])
        word_id_emb = self.id_embedding(word_id)
        flag_emb = self.flag_embedding(flag_batch)
        seq_len = flag_emb.shape[1]
        seq_len_en = seq_len
        pretrain_emb = self.pretrained_embedding(pretrain_batch).detach()

        seq_len = flag_emb.shape[1]
        SRL_output = self.SR_Labeler(pretrain_emb,
                                     flag_emb.detach(),
                                     predicates_1D,
                                     seq_len,
                                     para=True)

        pred_recur = self.SR_Compressor(pretrain_emb,
                                        word_id_emb.detach(),
                                        seq_len,
                                        para=True)

        seq_len_fr = flag_emb_fr.shape[1]
        SRL_output_fr = self.SR_Labeler(pretrain_emb_fr,
                                        flag_emb_fr.detach(),
                                        predicates_1D_fr,
                                        seq_len_fr,
                                        para=True)

        pred_recur_fr = self.SR_Compressor(pretrain_emb_fr,
                                           word_id_emb_fr.detach(),
                                           seq_len_fr,
                                           para=True)
        """
        En event vector, En word
        """
        output_word_en = self.SR_Matcher(pred_recur.detach(),
                                         SRL_output.detach(),
                                         pretrain_emb,
                                         word_id_emb.detach(),
                                         seq_len,
                                         para=True)

        #############################################
        """
        Fr event vector, En word
        """
        output_word_fr = self.SR_Matcher(pred_recur_fr.detach(),
                                         SRL_output_fr,
                                         pretrain_emb,
                                         word_id_emb.detach(),
                                         seq_len,
                                         para=True)
        unlabeled_loss_function = nn.KLDivLoss(reduction='sum')
        output_word_en = F.softmax(output_word_en, dim=1).detach()
        output_word_fr = F.log_softmax(output_word_fr, dim=1)
        loss = unlabeled_loss_function(
            output_word_fr, output_word_en) / (seq_len_en * self.batch_size)
        #############################################
        """
        En event vector, Fr word
        """
        output_word_en = self.SR_Matcher(pred_recur.detach(),
                                         SRL_output.detach(),
                                         pretrain_emb_fr,
                                         word_id_emb_fr.detach(),
                                         seq_len_fr,
                                         para=True)
        """
        Fr event vector, Fr word
        """
        output_word_fr = self.SR_Matcher(pred_recur_fr.detach(),
                                         SRL_output_fr,
                                         pretrain_emb_fr,
                                         word_id_emb_fr.detach(),
                                         seq_len_fr,
                                         para=True)
        unlabeled_loss_function = nn.KLDivLoss(reduction='sum')
        output_word_en = F.softmax(output_word_en, dim=1).detach()
        output_word_fr = F.log_softmax(output_word_fr, dim=1)
        loss_2 = unlabeled_loss_function(
            output_word_fr, output_word_en) / (seq_len_fr * self.batch_size)
        return loss, loss_2
Example no. 15
    def forward(self, batch_input, elmo, withParallel=True, lang='En'):

        if lang == 'En':
            word_batch = get_torch_variable_from_np(batch_input['word'])
            pretrain_batch = get_torch_variable_from_np(
                batch_input['pretrain'])
        else:
            word_batch = get_torch_variable_from_np(batch_input['word'])
            pretrain_batch = get_torch_variable_from_np(
                batch_input['pretrain'])

        flag_batch = get_torch_variable_from_np(batch_input['flag'])
        pos_batch = get_torch_variable_from_np(batch_input['pos'])
        pos_emb = self.pos_embedding(pos_batch)
        #log(pretrain_batch)
        #log(flag_batch)

        role_index = get_torch_variable_from_np(batch_input['role_index'])
        role_mask = get_torch_variable_from_np(batch_input['role_mask'])
        role_mask_expand = role_mask.unsqueeze(2).expand(
            self.batch_size, self.target_vocab_size, self.word_emb_size)

        role2word_batch = pretrain_batch.gather(dim=1, index=role_index)
        #log(role2word_batch)
        role2word_emb = self.pretrained_embedding(role2word_batch).detach()
        role2word_emb = role2word_emb * role_mask_expand.float()
        #log(role2word_emb[0][1])
        #log(role2word_emb[0][2])
        #log(role_mask)
        #log(role_mask_expand)
        #log("#################")
        #log(role_index)
        #log(role_mask)
        #log(role2word_batch)

        if withParallel:
            fr_word_batch = get_torch_variable_from_np(batch_input['fr_word'])
            fr_pretrain_batch = get_torch_variable_from_np(
                batch_input['fr_pretrain'])
            fr_flag_batch = get_torch_variable_from_np(batch_input['fr_flag'])

        #log(fr_pretrain_batch[0])
        #log(fr_flag_batch)
        if self.use_flag_embedding:
            flag_emb = self.flag_embedding(flag_batch)
        else:
            flag_emb = flag_batch.view(flag_batch.shape[0],
                                       flag_batch.shape[1], 1).float()

        seq_len = flag_batch.shape[1]
        if lang == "En":
            word_emb = self.word_embedding(word_batch)
            pretrain_emb = self.pretrained_embedding(pretrain_batch).detach()

        else:
            word_emb = self.fr_word_embedding(word_batch)
            pretrain_emb = self.fr_pretrained_embedding(
                pretrain_batch).detach()

        if withParallel:
            #fr_word_emb = self.fr_word_embedding(fr_word_batch)
            fr_pretrain_emb = self.fr_pretrained_embedding(
                fr_pretrain_batch).detach()
            fr_flag_emb = self.flag_embedding(fr_flag_batch)
            fr_seq_len = fr_flag_batch.shape[1]
            role_mask_expand_timestep = role_mask.unsqueeze(2).expand(
                self.batch_size, self.target_vocab_size, fr_seq_len)
            role_mask_expand_forFR = \
                role_mask_expand.unsqueeze(1).expand(self.batch_size, fr_seq_len, self.target_vocab_size,
                                                     self.pretrain_emb_size)

        # predicate_emb = self.word_embedding(predicate_batch)
        # predicate_pretrain_emb = self.pretrained_embedding(predicate_pretrain_batch)

        ####### semantic role labeler #######

        if self.use_deprel:
            input_emb = torch.cat([flag_emb, pretrain_emb], 2)  #
        else:
            input_emb = torch.cat([flag_emb, pretrain_emb], 2)  #

        if withParallel:
            fr_input_emb = torch.cat([fr_flag_emb, fr_pretrain_emb], 2)

        input_emb = self.word_dropout(input_emb)
        #input_emb = torch.cat((input_emb, get_torch_variable_from_np(np.zeros((self.batch_size, seq_len, self.target_vocab_size))).float()), 2)
        bilstm_output, (_, bilstm_final_state) = self.bilstm_layer(
            input_emb, self.bilstm_hidden_state)
        bilstm_output = bilstm_output.contiguous()
        hidden_input = bilstm_output.view(
            bilstm_output.shape[0] * bilstm_output.shape[1], -1)
        hidden_input = hidden_input.view(self.batch_size, seq_len, -1)
        #output = self.output_layer(hidden_input)

        arg_hidden = self.mlp_dropout(self.mlp_arg(hidden_input))
        predicates_1D = batch_input['predicates_idx']
        pred_recur = hidden_input[np.arange(0, self.batch_size), predicates_1D]
        pred_hidden = self.pred_dropout(self.mlp_pred(pred_recur))
        output = bilinear(arg_hidden,
                          self.rel_W,
                          pred_hidden,
                          self.mlp_size,
                          seq_len,
                          1,
                          self.batch_size,
                          num_outputs=self.target_vocab_size,
                          bias_x=True,
                          bias_y=True)
        en_output = output.view(self.batch_size * seq_len, -1)

        if withParallel:
            role2word_emb_1 = role2word_emb.view(self.batch_size,
                                                 self.target_vocab_size, -1)
            role2word_emb_2 = role2word_emb_1.unsqueeze(dim=1)
            role2word_emb_expand = role2word_emb_2.expand(
                self.batch_size, fr_seq_len, self.target_vocab_size,
                self.pretrain_emb_size)

            #log(role2word_emb_expand[0, 1, 4])
            #log(role2word_emb_expand[0, 2, 4])

            fr_pretrain_emb = fr_pretrain_emb.view(self.batch_size, fr_seq_len,
                                                   -1)
            fr_pretrain_emb = fr_pretrain_emb.unsqueeze(dim=2)
            fr_pretrain_emb_expand = fr_pretrain_emb.expand(
                self.batch_size, fr_seq_len, self.target_vocab_size,
                self.pretrain_emb_size)
            fr_pretrain_emb_expand = fr_pretrain_emb_expand  # * role_mask_expand_forFR.float()
            # B T R W
            emb_distance = fr_pretrain_emb_expand - role2word_emb_expand
            emb_distance = emb_distance * emb_distance
            # B T R
            emb_distance = emb_distance.sum(dim=3)
            emb_distance = torch.sqrt(emb_distance)
            emb_distance = emb_distance.transpose(1, 2)
            # emb_distance_min = emb_distance_min.expand(self.batch_size, fr_seq_len, self.target_vocab_size)
            # emb_distance = F.softmax(emb_distance, dim=1)*role_mask_expand_timestep.float()
            # B R
            emb_distance_min, emb_distance_argmin = torch.min(emb_distance,
                                                              dim=2,
                                                              keepdim=True)
            emb_distance_max, emb_distance_argmax = torch.max(emb_distance,
                                                              dim=2,
                                                              keepdim=True)
            emb_distance_max_expand = emb_distance_max.expand(
                self.batch_size, self.target_vocab_size, fr_seq_len)
            emb_distance_min_expand = emb_distance_min.expand(
                self.batch_size, self.target_vocab_size, fr_seq_len)
            emb_distance_nomalized = (emb_distance - emb_distance_min_expand
                                      ) / (emb_distance_max_expand -
                                           emb_distance_min_expand)
            #* role_mask_expand_timestep.float()
            # emb_distance_argmin_expand = emb_distance_argmin.expand(self.batch_size, fr_seq_len, self.target_vocab_size)
            # log("######################")

            emb_distance_nomalized = emb_distance_nomalized.detach()

            top2 = emb_distance_nomalized.topk(2,
                                               dim=2,
                                               largest=False,
                                               sorted=True)[0]

            top2_gap = top2[:, :, 1] - top2[:, :, 0]
            #log(top2_gap)
            #log(role_mask)
            #log(emb_distance_nomalized[0,:,2])
            #log(emb_distance_argmax)

            #emb_distance_argmin = emb_distance_argmin*role_mask
            #log(emb_distance_argmin)
            #log(emb_distance_nomalized[0][2])
            """
            weight_4_loss = emb_distance_nomalized
            #log(emb_distance_argmin[0,2])
            for i in range(self.batch_size):
                for j in range(self.target_vocab_size):
                    weight_4_loss[i,j].index_fill_(0, emb_distance_argmin[i][j], 1)
            """
            #log(weight_4_loss[0][2])

            fr_input_emb = self.word_dropout(fr_input_emb).detach()
            #fr_input_emb = torch.cat((fr_input_emb, emb_distance_nomalized), 2)
            fr_bilstm_output, (_, bilstm_final_state) = self.bilstm_layer(
                fr_input_emb, self.fr_bilstm_hidden_state)
            fr_bilstm_output = fr_bilstm_output.contiguous()
            hidden_input = fr_bilstm_output.view(
                fr_bilstm_output.shape[0] * fr_bilstm_output.shape[1], -1)
            hidden_input = hidden_input.view(self.batch_size, fr_seq_len, -1)
            arg_hidden = self.mlp_dropout(self.mlp_arg(hidden_input))
            predicates_1D = batch_input['fr_predicates_idx']
            #log(predicates_1D)
            #log(predicates_1D)
            pred_recur = hidden_input[np.arange(0, self.batch_size),
                                      predicates_1D]
            pred_hidden = self.pred_dropout(self.mlp_pred(pred_recur))
            output = bilinear(arg_hidden,
                              self.rel_W,
                              pred_hidden,
                              self.mlp_size,
                              fr_seq_len,
                              1,
                              self.batch_size,
                              num_outputs=self.target_vocab_size,
                              bias_x=True,
                              bias_y=True)
            output = output.view(self.batch_size, fr_seq_len, -1)
            output = output.transpose(1, 2)
            #B R T
            #output_p = F.softmax(output, dim=2)
            #output = output * role_mask_expand_timestep.float()
            #log(role_mask_expand_timestep)
            """
            output_exp = torch.exp(output)
            output_exp_weighted = output_exp #* weight_4_loss
            output_expsum = output_exp_weighted.sum(dim=2, keepdim=True).expand(self.batch_size, self.target_vocab_size,
                                                                                fr_seq_len)
            output = output_exp/output_expsum
            """
            #log(output_expsum)
            #log(output_p[0, 2])
            # B R 1
            #output_pmax, output_pmax_arg = torch.max(output_p, dim=2, keepdim=True)

            #log(output[0])
            #log(emb_distance[0,:, 2])
            #log(emb_distance_nomalized[0,:, 2])
            #log(role_mask[0])
            #log(role_mask[0])
            #log(output[0][0])
            #log("++++++++++++++++++++")
            #log(emb_distance)
            #log(emb_distance.gather(1, emb_distance_argmin))
            #output_pargminD = output_p.gather(2, emb_distance_argmin)
            #log(output_argminD)
            #weighted_distance = (output/output_argminD) * emb_distance_nomalized
            #bias = torch.FloatTensor(output_pmax.size()).fill_(1).cuda()
            #rank_loss = (output_pmax/output_pargminD)# * emb_distance_nomalized.gather(2, output_pmax_arg)
            #rank_loss = rank_loss.view(self.batch_size, self.target_vocab_size)
            #rank_loss = rank_loss * role_mask.float()
            #log("++++++++++++++++++++++")
            #log(output_max-output_argminD)
            #log(emb_distance_nomalized.gather(1, output_max_arg))
            # B R
            #weighted_distance = weighted_distance.squeeze() * role_mask.float()
            """
            output = F.softmax(output, dim=1)
            output = output.transpose(1, 2)
            fr_role2word_emb = torch.bmm(output, fr_pretrain_emb)
            criterion = nn.MSELoss(reduce=False)
            l2_loss = criterion(fr_role2word_emb.view(self.batch_size*self.target_vocab_size, -1),
                                  role2word_emb.view(self.batch_size*self.target_vocab_size, -1))
            """
            criterion = nn.CrossEntropyLoss(ignore_index=0, reduction='none')

            output = output.view(self.batch_size * self.target_vocab_size, -1)
            emb_distance_argmin = emb_distance_argmin * role_mask.unsqueeze(2)
            emb_distance_argmin = emb_distance_argmin.view(-1)
            #log(emb_distance_argmin[0][2])
            l2_loss = criterion(output, emb_distance_argmin)
            l2_loss = l2_loss.view(self.batch_size, self.target_vocab_size)
            top2_gap = top2_gap.view(self.batch_size, self.target_vocab_size)
            l2_loss = l2_loss * top2_gap
            l2_loss = l2_loss.sum()
            #l2_loss = F.nll_loss(torch.log(output), emb_distance_argmin, ignore_index=0)

            #log(emb_distance_argmin)
            #log(l2_loss)
            #log("+++++++++++++++++++++")
            #log(l2_loss)
            #log(batch_input['fr_loss_mask'])
            l2_loss = l2_loss * get_torch_variable_from_np(
                batch_input['fr_loss_mask']).float()
            l2_loss = l2_loss.sum()  #/float_role_mask.sum()
            #log("+")
            #log(l2_loss)
            return en_output, l2_loss
        return en_output
Example no. 16
    def forward(self,
                batch_input,
                lang='En',
                unlabeled=False,
                self_constrain=False,
                use_bert=False):
        if unlabeled:

            loss = self.parallel_train(batch_input, use_bert)

            loss_word = 0

            return loss, loss_word
        if self_constrain:

            loss = self.self_train(batch_input)
            return loss

        pretrain_batch = get_torch_variable_from_np(batch_input['pretrain'])
        predicates_1D = batch_input['predicates_idx']
        flag_batch = get_torch_variable_from_np(batch_input['flag'])
        word_id = get_torch_variable_from_np(batch_input['word_times'])
        word_id_emb = self.id_embedding(word_id)
        flag_emb = self.flag_embedding(flag_batch)
        actual_lens = batch_input['seq_len']
        #print(actual_lens)
        if use_bert:
            bert_input_ids = get_torch_variable_from_np(
                batch_input['bert_input_ids'])
            bert_input_mask = get_torch_variable_from_np(
                batch_input['bert_input_mask'])
            bert_out_positions = get_torch_variable_from_np(
                batch_input['bert_out_positions'])

            bert_emb = self.model(bert_input_ids,
                                  attention_mask=bert_input_mask)
            bert_emb = bert_emb[0]
            bert_emb = bert_emb[:, 1:-1, :].contiguous().detach()

            bert_emb = bert_emb[torch.arange(bert_emb.size(0)).unsqueeze(-1),
                                bert_out_positions].detach()

            for i in range(len(bert_emb)):
                if i >= len(actual_lens):
                    break
                for j in range(len(bert_emb[i])):
                    if j >= actual_lens[i]:
                        bert_emb[i][j] = get_torch_variable_from_np(
                            np.zeros(768, dtype="float32"))

            bert_emb = bert_emb.detach()

        # pretrain embeddings are required by the non-BERT path below
        if lang == "En":
            pretrain_emb = self.pretrained_embedding(pretrain_batch).detach()
        else:
            pretrain_emb = self.fr_pretrained_embedding(
                pretrain_batch).detach()

        seq_len = flag_emb.shape[1]
        if not use_bert:
            SRL_output = self.SR_Labeler(pretrain_emb,
                                         flag_emb,
                                         predicates_1D,
                                         seq_len,
                                         para=False)

            SRL_input = SRL_output.view(self.batch_size, seq_len, -1)
            SRL_input = SRL_input
            pred_recur = self.SR_Compressor(SRL_input,
                                            pretrain_emb,
                                            flag_emb.detach(),
                                            word_id_emb,
                                            predicates_1D,
                                            seq_len,
                                            para=False)

            output_word = self.SR_Matcher(pred_recur,
                                          pretrain_emb,
                                          flag_emb.detach(),
                                          word_id_emb.detach(),
                                          seq_len,
                                          para=False)
            copy_loss = 0  # no copy loss is computed in the non-BERT path

        else:
            SRL_output = self.SR_Labeler(bert_emb,
                                         flag_emb,
                                         predicates_1D,
                                         seq_len,
                                         para=False,
                                         use_bert=True)

            SRL_input = SRL_output.view(self.batch_size, seq_len, -1)
            SRL_input = SRL_input
            pred_recur = self.SR_Compressor(SRL_input,
                                            bert_emb,
                                            flag_emb.detach(),
                                            word_id_emb,
                                            predicates_1D,
                                            seq_len,
                                            para=False,
                                            use_bert=True)

            output_word = self.SR_Matcher(pred_recur,
                                          bert_emb,
                                          flag_emb.detach(),
                                          word_id_emb.detach(),
                                          seq_len,
                                          para=False,
                                          use_bert=True)
            copy_loss = 0  #self.copy_loss(SRL_input, flag_emb, bert_emb, seq_len)

        return SRL_output, output_word, copy_loss
Example no. 17
                                          batch_size,
                                          word2idx,
                                          fr_word2idx,
                                          lemma2idx,
                                          pos2idx,
                                          pretrain2idx,
                                          fr_pretrain2idx,
                                          deprel2idx,
                                          argument2idx,
                                          idx2word,
                                          shuffle=True,
                                          withParrallel=True)):
                srl_model.train()
                target_argument = train_input_data['argument']
                flat_argument = train_input_data['flat_argument']
                target_batch_variable = get_torch_variable_from_np(
                    flat_argument)

                out = srl_model(train_input_data,
                                elmo,
                                withParallel=False,
                                lang='En',
                                isPretrain=True)
                loss = criterion(out, target_batch_variable)

                if batch_i % 50 == 0:
                    log(batch_i, loss)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
Example no. 18
def eval_data(model, elmo, dataset, batch_size, word2idx, fr_word2idx, lemma2idx, pos2idx, pretrain2idx, fr_pretrain2idx, deprel2idx, argument2idx, idx2argument, idx2word, unify_pred=False, predicate_correct=0, predicate_sum=0, isPretrain=False):

    model.eval()
    golden = []
    predict = []

    output_data = []
    cur_sentence = None
    cur_sentence_data = None

    for batch_i, input_data in enumerate(inter_utils.get_batch(dataset, batch_size, word2idx, fr_word2idx,
                                                             lemma2idx, pos2idx, pretrain2idx,
                                                             fr_pretrain2idx, deprel2idx, argument2idx, idx2word, lang="En")):
        target_flags = input_data['sen_flags']
        flat_argument = input_data['flat_argument']
        target_batch_variable = get_torch_variable_from_np(flat_argument)

        sentence_id = input_data['sentence_id']
        predicate_id = input_data['predicate_id']
        word_id = input_data['word_id']
        sentence_len = input_data['sentence_len']
        seq_len = input_data['seq_len']
        bs = input_data['batch_size']
        psl = input_data['pad_seq_len']
        out = model(input_data, lang='En')

        _, pred = torch.max(out, 1)
        pred = get_data(pred)
        pred = np.reshape(pred, target_flags.shape)

        for idx in range(pred.shape[0]):
            predict.append(list(pred[idx]))
            golden.append(list(target_flags[idx]))

        pre_data = []
        for b in range(len(seq_len)):
            line_data = ['_' for _ in range(sentence_len[b])]
            for s in range(seq_len[b]):
                wid = word_id[b][s]
                #line_data[wid-1] = idx2argument[pred[b][s]]
                line_data[wid - 1] = str(pred[b][s])
            pre_data.append(line_data)

        for b in range(len(sentence_id)):
            if cur_sentence != sentence_id[b]:
                if cur_sentence_data is not None:
                    output_data.append(cur_sentence_data)
                cur_sentence_data = [[sentence_id[b]]*len(pre_data[b]),pre_data[b]]
                cur_sentence = sentence_id[b]
            else:
                assert cur_sentence_data is not None
                cur_sentence_data.append(pre_data[b])

    if cur_sentence_data is not None and len(cur_sentence_data)>0:
        output_data.append(cur_sentence_data)
    
    score = sem_f1_score(golden, predict, argument2idx, unify_pred, predicate_correct, predicate_sum)
    """
    P = correct_pos / NonullPredict_pos
    R = correct_pos / NonullTruth_pos
    F = 2 * P * R / (P + R)
    log("POS: ", P, R, F)

    P = correct_PI / NonullPredict_PI
    R = correct_PI / NonullTruth_PI
    F = 2 * P * R / (P + R)
    log("PI: ", P, R, F)

    P = correct_deprel / NonullPredict_deprel
    R = correct_deprel / NonullTruth_deprel
    F = 2 * P * R / (P + R)
    log("deprel: ", P, R, F)

    P = correct_link / NonullPredict_link
    R = correct_link / NonullTruth_link
    F = 2 * P * R / (P + R)
    log(correct_link, NonullPredict_link, NonullTruth_link)
    log("link: ", P, R, F)
    """

    model.train()

    return score, output_data
Example no. 19
    def forward(self,
                batch_input,
                lang='En',
                unlabeled=False,
                self_constrain=False,
                use_bert=False,
                isTrain=False):
        if unlabeled:
            #l2loss = self.word_trans(batch_input, use_bert)
            consistent_loss = self.parallel_train(batch_input, use_bert)
            return consistent_loss
        pretrain_batch = get_torch_variable_from_np(batch_input['pretrain'])
        predicates_1D = batch_input['predicates_idx']
        flag_batch = get_torch_variable_from_np(batch_input['flag'])
        flag_emb = self.flag_embedding(flag_batch)
        actual_lens = batch_input['seq_len']
        # print(actual_lens)
        if use_bert:
            bert_input_ids = get_torch_variable_from_np(
                batch_input['bert_input_ids'])
            bert_input_mask = get_torch_variable_from_np(
                batch_input['bert_input_mask'])
            bert_out_positions = get_torch_variable_from_np(
                batch_input['bert_out_positions'])

            bert_emb = self.model(bert_input_ids,
                                  attention_mask=bert_input_mask)
            bert_emb = bert_emb[0]
            bert_emb = bert_emb[:, 1:-1, :].contiguous().detach()

            bert_emb = bert_emb[torch.arange(bert_emb.size(0)).unsqueeze(-1),
                                bert_out_positions].detach()
            """
            for i in range(len(bert_emb)):
                if i >= len(actual_lens):
                    break
                for j in range(len(bert_emb[i])):
                    if j >= actual_lens[i]:
                        bert_emb[i][j] = get_torch_variable_from_np(np.zeros(768, dtype="float32"))
            bert_emb = bert_emb.detach()
            """

        seq_len = flag_emb.shape[1]

        SRL_output = self.SR_Labeler(bert_emb,
                                     flag_emb,
                                     predicates_1D,
                                     seq_len,
                                     para=False,
                                     use_bert=True)

        SRL_input = SRL_output.view(self.batch_size, seq_len, -1)
        SRL_input_probs = F.softmax(SRL_input, 2).detach()

        output_word = self.SR_Matcher(bert_emb.detach(),
                                      bert_emb.detach(),
                                      SRL_input_probs,
                                      seq_len,
                                      seq_len,
                                      isTrain=isTrain,
                                      para=False)

        #score4Null = torch.zeros_like(output_word[:, 1:2])
        #output_word = torch.cat((output_word[:, 0:1], score4Null, output_word[:, 1:]), 1)

        teacher = SRL_input_probs.view(self.batch_size * seq_len, -1).detach()
        student = torch.log_softmax(output_word, dim=1)
        unlabeled_loss_function = nn.KLDivLoss(reduction='none')
        loss_copy = unlabeled_loss_function(student, teacher)
        loss_copy = loss_copy.sum() / (self.batch_size * seq_len)

        return SRL_output, output_word, loss_copy
Example no. 20
    def parallel_train_(self, batch_input, use_bert, isTrain=True):
        unlabeled_data_en, unlabeled_data_fr = batch_input

        predicates_1D_fr = unlabeled_data_fr['predicates_idx']
        flag_batch_fr = get_torch_variable_from_np(unlabeled_data_fr['flag'])

        flag_emb_fr = self.flag_embedding(flag_batch_fr).detach()
        actual_lens_fr = unlabeled_data_fr['seq_len']

        predicates_1D = unlabeled_data_en['predicates_idx']
        flag_batch = get_torch_variable_from_np(unlabeled_data_en['flag'])
        actual_lens_en = unlabeled_data_en['seq_len']
        flag_emb = self.flag_embedding(flag_batch).detach()
        seq_len = flag_emb.shape[1]
        seq_len_en = seq_len

        if use_bert:
            bert_input_ids_fr = get_torch_variable_from_np(unlabeled_data_fr['bert_input_ids'])
            bert_input_mask_fr = get_torch_variable_from_np(unlabeled_data_fr['bert_input_mask'])
            bert_out_positions_fr = get_torch_variable_from_np(unlabeled_data_fr['bert_out_positions'])

            bert_emb_fr = self.model(bert_input_ids_fr, attention_mask=bert_input_mask_fr)
            bert_emb_fr = bert_emb_fr[0]
            bert_emb_fr = bert_emb_fr[:, 1:-1, :].contiguous().detach()
            bert_emb_fr = bert_emb_fr[torch.arange(bert_emb_fr.size(0)).unsqueeze(-1), bert_out_positions_fr].detach()

            for i in range(len(bert_emb_fr)):
                if i >= len(actual_lens_fr):
                    print("error")
                    break
                for j in range(len(bert_emb_fr[i])):
                    if j >= actual_lens_fr[i]:
                        bert_emb_fr[i][j] = get_torch_variable_from_np(np.zeros(768, dtype="float32"))
            bert_emb_fr = gaussian(bert_emb_fr, isTrain, 0, 0.1)
            bert_emb_fr = bert_emb_fr.detach()

            bert_input_ids_en = get_torch_variable_from_np(unlabeled_data_en['bert_input_ids'])
            bert_input_mask_en = get_torch_variable_from_np(unlabeled_data_en['bert_input_mask'])
            bert_out_positions_en = get_torch_variable_from_np(unlabeled_data_en['bert_out_positions'])

            bert_emb_en = self.model(bert_input_ids_en, attention_mask=bert_input_mask_en)
            bert_emb_en = bert_emb_en[0]
            bert_emb_en = bert_emb_en[:, 1:-1, :].contiguous().detach()
            bert_emb_en = bert_emb_en[torch.arange(bert_emb_en.size(0)).unsqueeze(-1), bert_out_positions_en].detach()

            for i in range(len(bert_emb_en)):
                if i >= len(actual_lens_en):
                    print("error")
                    break
                for j in range(len(bert_emb_en[i])):
                    if j >= actual_lens_en[i]:
                        bert_emb_en[i][j] = get_torch_variable_from_np(np.zeros(768, dtype="float32"))
            bert_emb_en = gaussian(bert_emb_en, isTrain, 0, 0.1)
            bert_emb_en = bert_emb_en.detach()


        seq_len = flag_emb.shape[1]
        SRL_output = self.SR_Labeler(bert_emb_en, flag_emb.detach(), predicates_1D, seq_len, para=True, use_bert=True)

        SRL_input = SRL_output.view(self.batch_size, seq_len, -1)
        pred_recur = self.SR_Compressor(SRL_input.detach(), bert_emb_en,
                                        flag_emb.detach(), None, predicates_1D, seq_len, para=True, use_bert=True)



        seq_len_fr = flag_emb_fr.shape[1]
        SRL_output_fr = self.SR_Labeler(bert_emb_fr, flag_emb_fr.detach(), predicates_1D_fr, seq_len_fr, para=True, use_bert=True)

        SRL_input_fr = SRL_output_fr.view(self.batch_size, seq_len_fr, -1)
        pred_recur_fr = self.SR_Compressor(SRL_input_fr, bert_emb_fr,
                                        flag_emb_fr.detach(), None, predicates_1D_fr, seq_len_fr, para=True, use_bert=True)

        """
        En event vector, En word
        """
        output_word_en_en = self.SR_Matcher(pred_recur.detach(), bert_emb_en, flag_emb.detach(), None, seq_len,
                                         para=True, use_bert=True).detach()

        #############################################
        """
        Fr event vector, En word
        """
        output_word_fr_en = self.SR_Matcher(pred_recur_fr, bert_emb_en, flag_emb.detach(), None, seq_len,
                                      para=True, use_bert=True)

        ## B*T R 2
        Union_enfr_en = torch.cat((output_word_en_en.view(-1, self.target_vocab_size, 1),
                                   output_word_fr_en.view(-1, self.target_vocab_size, 1)), 2)
        ## B*T R
        max_enfr_en = torch.max(Union_enfr_en, 2)[0]

        #############################################
        """
        En event vector, Fr word
        """
        output_word_en_fr = self.SR_Matcher(pred_recur.detach(), bert_emb_fr, flag_emb_fr.detach(), None, seq_len_fr,
                                         para=True, use_bert=True).detach()
        """
        Fr event vector, Fr word
        """
        output_word_fr_fr = self.SR_Matcher(pred_recur_fr, bert_emb_fr, flag_emb_fr.detach(), None, seq_len_fr,
                                         para=True, use_bert=True)

        ## B*T R 2
        Union_enfr_fr = torch.cat((output_word_en_fr.view(-1, self.target_vocab_size, 1),
                                   output_word_fr_fr.view(-1, self.target_vocab_size, 1)), 2)
        ## B*T R
        max_enfr_fr = torch.max(Union_enfr_fr, 2)[0]

        unlabeled_loss_function = nn.KLDivLoss(reduction='none')
        """
        word_mask_4en = self.P_word_mask(output_word_fr_en.view(self.batch_size, seq_len, -1),
                                         output_word_fr_fr.view(self.batch_size, seq_len_fr, -1), seq_len_en)
        word_mask_4en_tensor = get_torch_variable_from_np(word_mask_4en).view(self.batch_size*seq_len_en, -1)

        word_mask_4fr = self.R_word_mask(output_word_en_en.view(self.batch_size, seq_len, -1),
                                         output_word_en_fr.view(self.batch_size, seq_len_fr, -1), seq_len_fr)
        word_mask_4fr_tensor = get_torch_variable_from_np(word_mask_4fr).view(self.batch_size*seq_len_fr, -1)
        
        word_mask_en, word_mask_fr = self.word_mask_soft(output_word_en_en.view(self.batch_size, seq_len_en, -1),
                                                    output_word_en_fr.view(self.batch_size, seq_len_fr, -1),
                                                    seq_len_en, seq_len_fr)
        word_mask_en_tensor = get_torch_variable_from_np(word_mask_en).view(self.batch_size * seq_len_en)
        word_mask_fr_tensor = get_torch_variable_from_np(word_mask_fr).view(self.batch_size * seq_len_fr)
        """
        #output_word_en_en = F.softmax(output_word_en_en, dim=1).detach()
        #output_word_fr_en = F.log_softmax(output_word_fr_en, dim=1)
        #loss = unlabeled_loss_function(output_word_fr_en, output_word_en_en)
        output_word_en_en = F.softmax(output_word_en_en, dim=1).detach()
        max_enfr_en = F.log_softmax(max_enfr_en, dim=1)
        loss = unlabeled_loss_function(max_enfr_en, output_word_en_en)
        loss = loss.sum(dim=1)#*word_mask_en_tensor
        loss = loss.sum() / (self.batch_size*seq_len_en)

        #output_word_en_fr = F.softmax(output_word_en_fr, dim=1).detach()
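        # Fr words: the max over En/Fr-event predictions is the detached
        # teacher, the Fr-event predictions are the student.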
        max_enfr_fr = F.softmax(max_enfr_fr, dim=1).detach()
        output_word_fr_fr = F.log_softmax(output_word_fr_fr, dim=1)
        loss_2 = unlabeled_loss_function(output_word_fr_fr, max_enfr_fr)
        loss_2 = loss_2.sum(dim=1)#*word_mask_fr_tensor

        loss_2 = loss_2.sum() / (self.batch_size * seq_len_fr)

        return loss, loss_2
Esempio n. 21
0
    def forward(self, batch_input, elmo):

        flag_batch = get_torch_variable_from_np(batch_input['flag'])
        word_batch = get_torch_variable_from_np(batch_input['word'])
        lemma_batch = get_torch_variable_from_np(batch_input['lemma'])
        pos_batch = get_torch_variable_from_np(batch_input['pos'])
        deprel_batch = get_torch_variable_from_np(batch_input['deprel'])
        pretrain_batch = get_torch_variable_from_np(batch_input['pretrain'])
        origin_batch = batch_input['origin']
        origin_deprel_batch = batch_input['deprel']
        chars_batch = get_torch_variable_from_np(batch_input['char'])

        if self.use_flag_embedding:
            flag_emb = self.flag_embedding(flag_batch)
        else:
            flag_emb = flag_batch.view(flag_batch.shape[0],
                                       flag_batch.shape[1], 1).float()
        seq_len = flag_batch.shape[1]
        word_emb = self.word_embedding(word_batch)
        lemma_emb = self.lemma_embedding(lemma_batch)
        pos_emb = self.pos_embedding(pos_batch)
        char_embeddings = self.char_embeddings(chars_batch)
        character_embeddings = self.charCNN(char_embeddings)
        pretrain_emb = self.pretrained_embedding(pretrain_batch)

        if self.use_deprel:
            deprel_emb = self.deprel_embedding(deprel_batch)

        # predicate_emb = self.word_embedding(predicate_batch)
        # predicate_pretrain_emb = self.pretrained_embedding(predicate_pretrain_batch)

        ##sentence learner#####################################
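        # Shared two-layer BiLSTM over word/pretrain/char features; its states
        # feed auxiliary POS, predicate-identification (PI), dependency-relation
        # and link classifiers whose distributions are composed back into the
        # SRL input below.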
        SL_input_emb = self.word_dropout(
            torch.cat([word_emb, pretrain_emb, character_embeddings], 2))
        h0, (_,
             SL_final_state) = self.sentence_learner(SL_input_emb,
                                                     self.SL_hidden_state0)
        h1, (_, SL_final_state) = self.sentence_learner_high(
            h0, self.SL_hidden_state0_high)
        SL_output = h1
        POS_output = self.pos_classifier(SL_output).view(
            self.batch_size * seq_len, -1)
        PI_output = self.PI_classifier(SL_output).view(
            self.batch_size * seq_len, -1)
        ## deprel
        hidden_input = SL_output
        arg_hidden = self.mlp_dropout(self.mlp_arg_deprel(SL_output))
        predicates_1D = batch_input['predicates_idx']
        pred_recur = hidden_input[np.arange(0, self.batch_size), predicates_1D]
        pred_hidden = self.pred_dropout(self.mlp_pred_deprel(pred_recur))
        deprel_output = bilinear(arg_hidden,
                                 self.deprel_W,
                                 pred_hidden,
                                 self.mlp_size,
                                 seq_len,
                                 1,
                                 self.batch_size,
                                 num_outputs=self.deprel_vocab_size,
                                 bias_x=True,
                                 bias_y=True)

        arg_hidden = self.mlp_dropout(self.mlp_arg_link(SL_output))
        predicates_1D = batch_input['predicates_idx']
        pred_recur = hidden_input[np.arange(0, self.batch_size), predicates_1D]
        pred_hidden = self.pred_dropout(self.mlp_pred_link(pred_recur))
        Link_output = bilinear(arg_hidden,
                               self.link_W,
                               pred_hidden,
                               self.mlp_size,
                               seq_len,
                               1,
                               self.batch_size,
                               num_outputs=4,
                               bias_x=True,
                               bias_y=True)
        deprel_output = deprel_output.view(self.batch_size * seq_len, -1)
        Link_output = Link_output.view(self.batch_size * seq_len, -1)

        POS_probs = F.softmax(POS_output, dim=1).view(self.batch_size, seq_len,
                                                      -1)
        deprel_probs = F.softmax(deprel_output,
                                 dim=1).view(self.batch_size, seq_len, -1)
        link_probs = F.softmax(Link_output,
                               dim=1).view(self.batch_size, seq_len, -1)

        POS_compose = F.tanh(self.POS2hidden(POS_probs))
        deprel_compose = F.tanh(self.deprel2hidden(deprel_probs))
        link_compose = link_probs

        #######semantic role labelerxxxxxxxxxx

        if self.use_deprel:
            input_emb = torch.cat([
                flag_emb, word_emb, pretrain_emb, character_embeddings,
                POS_compose, deprel_compose, link_compose
            ], 2)  #
        else:
            input_emb = torch.cat([
                flag_emb, word_emb, pretrain_emb, character_embeddings,
                POS_compose
            ], 2)  #

        input_emb = self.word_dropout(input_emb)
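        # ELMo-style scalar mix of the two sentence-learner layers (h0, h1),
        # projected by elmo_mlp and concatenated to the first BiLSTM's output
        # before the second (high) BiLSTM.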

        w = F.softmax(self.elmo_w, dim=0)
        SRL_composer = self.elmo_gamma * (w[0] * h0 + w[1] * h1)
        SRL_composer = self.elmo_mlp(SRL_composer)
        bilstm_output_0, (_, bilstm_final_state) = self.bilstm_layer(
            input_emb, self.bilstm_hidden_state0)
        high_input = torch.cat((bilstm_output_0, SRL_composer), 2)
        bilstm_output, (_, bilstm_final_state) = self.bilstm_layer_high(
            high_input, self.bilstm_hidden_state0_high)

        # bilstm_final_state = bilstm_final_state.view(self.bilstm_num_layers, 2, self.batch_size, self.bilstm_hidden_size)

        # bilstm_final_state = bilstm_final_state[-1]

        # sentence latent representation
        # bilstm_final_state = torch.cat([bilstm_final_state[0], bilstm_final_state[1]], 1)

        # bilstm_output = self.bilstm_mlp(bilstm_output)

        if self.use_self_attn:
            x = F.tanh(self.attn_linear_first(bilstm_output))
            x = self.attn_linear_second(x)
            x = self.softmax(x, 1)
            attention = x.transpose(1, 2)
            sentence_embeddings = torch.matmul(attention, bilstm_output)
            sentence_embeddings = torch.sum(sentence_embeddings,
                                            1) / self.self_attn_head

            context = sentence_embeddings.repeat(bilstm_output.size(1), 1,
                                                 1).transpose(0, 1)

            bilstm_output = torch.cat([bilstm_output, context], 2)

            bilstm_output = self.attn_linear_final(bilstm_output)

            # energy = self.biaf_attn(bilstm_output, bilstm_output)

            # # energy = energy.transpose(1, 2)

            # flag_indices = batch_input['flag_indices']

            # attention = []
            # for idx in range(len(flag_indices)):
            #     attention.append(energy[idx,:,:,flag_indices[idx]].view(1, self.bilstm_hidden_size, -1))

            # attention = torch.cat(attention,dim=0)

            # # attention = attention.transpose(1, 2)

            # attention = self.softmax(attention, 2)

            # # attention = attention.transpose(1,2)

            # sentence_embeddings = attention@bilstm_output

            # sentence_embeddings = torch.sum(sentence_embeddings,1)/self.self_attn_head

            # context = sentence_embeddings.repeat(bilstm_output.size(1), 1, 1).transpose(0, 1)

            # bilstm_output = torch.cat([bilstm_output, context], 2)

            # bilstm_output = self.attn_linear_final(bilstm_output)
        else:
            bilstm_output = bilstm_output.contiguous()

        if self.use_gcn:
            # in_rep = torch.matmul(bilstm_output, self.W_in)
            # out_rep = torch.matmul(bilstm_output, self.W_out)
            # self_rep = torch.matmul(bilstm_output, self.W_self)

            # child_indicies = batch_input['children_indicies']

            # head_batch = batch_input['head']

            # context = []
            # for idx in range(head_batch.shape[0]):
            #     states = []
            #     for jdx in range(head_batch.shape[1]):
            #         head_ind = head_batch[idx, jdx]-1
            #         childs = child_indicies[idx][jdx]
            #         state = self_rep[idx, jdx]
            #         if head_ind != -1:
            #               state = state + in_rep[idx, head_ind]
            #         for child in childs:
            #             state = state + out_rep[idx, child]
            #         state = F.relu(state + self.gcn_bias)
            #         states.append(state.unsqueeze(0))
            #     context.append(torch.cat(states, dim=0))

            # context = torch.cat(context, dim=0)

            # bilstm_output = context

            seq_len = bilstm_output.shape[1]
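            # Build flat (batch * seq_len) arc-index, label and mask arrays for
            # the syntactic GCN: one incoming and one outgoing arc per token,
            # read from the dependency heads of the original CoNLL columns.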

            adj_arc_in = np.zeros((self.batch_size * seq_len, 2),
                                  dtype='int32')
            adj_lab_in = np.zeros((self.batch_size * seq_len), dtype='int32')

            adj_arc_out = np.zeros((self.batch_size * seq_len, 2),
                                   dtype='int32')
            adj_lab_out = np.zeros((self.batch_size * seq_len), dtype='int32')

            mask_in = np.zeros((self.batch_size * seq_len), dtype='float32')
            mask_out = np.zeros((self.batch_size * seq_len), dtype='float32')

            mask_loop = np.ones((self.batch_size * seq_len, 1),
                                dtype='float32')

            for idx in range(len(origin_batch)):
                for jdx in range(len(origin_batch[idx])):

                    offset = jdx + idx * seq_len

                    head_ind = int(origin_batch[idx][jdx][10]) - 1

                    if head_ind == -1:
                        continue

                    dependent_ind = int(origin_batch[idx][jdx][4]) - 1

                    adj_arc_in[offset] = np.array([idx, dependent_ind])

                    adj_lab_in[offset] = np.array(
                        [origin_deprel_batch[idx, jdx]])

                    mask_in[offset] = 1

                    adj_arc_out[offset] = np.array([idx, head_ind])

                    adj_lab_out[offset] = np.array(
                        [origin_deprel_batch[idx, jdx]])

                    mask_out[offset] = 1

            if USE_CUDA:
                adj_arc_in = torch.LongTensor(np.transpose(adj_arc_in)).cuda()
                adj_arc_out = torch.LongTensor(
                    np.transpose(adj_arc_out)).cuda()

                adj_lab_in = Variable(torch.LongTensor(adj_lab_in).cuda())
                adj_lab_out = Variable(torch.LongTensor(adj_lab_out).cuda())

                mask_in = Variable(
                    torch.FloatTensor(
                        mask_in.reshape(
                            (self.batch_size * seq_len, 1))).cuda())
                mask_out = Variable(
                    torch.FloatTensor(
                        mask_out.reshape(
                            (self.batch_size * seq_len, 1))).cuda())
                mask_loop = Variable(torch.FloatTensor(mask_loop).cuda())
            else:
                adj_arc_in = torch.LongTensor(np.transpose(adj_arc_in))
                adj_arc_out = torch.LongTensor(np.transpose(adj_arc_out))

                adj_lab_in = Variable(torch.LongTensor(adj_lab_in))
                adj_lab_out = Variable(torch.LongTensor(adj_lab_out))

                mask_in = Variable(
                    torch.FloatTensor(
                        mask_in.reshape((self.batch_size * seq_len, 1))))
                mask_out = Variable(
                    torch.FloatTensor(
                        mask_out.reshape((self.batch_size * seq_len, 1))))
                mask_loop = Variable(torch.FloatTensor(mask_loop))

            gcn_context = self.syntactic_gcn(bilstm_output, adj_arc_in,
                                             adj_arc_out, adj_lab_in,
                                             adj_lab_out, mask_in, mask_out,
                                             mask_loop)

            #gcn_context = self.softmax(gcn_context, axis=2)
            gcn_context = F.softmax(gcn_context, dim=2)

            bilstm_output = torch.cat([bilstm_output, gcn_context], dim=2)

            bilstm_output = self.gcn_mlp(bilstm_output)

        hidden_input = bilstm_output.view(
            bilstm_output.shape[0] * bilstm_output.shape[1], -1)

        if self.use_highway:
            for current_layer in self.highway_layers:
                hidden_input = current_layer(hidden_input)

            output = self.output_layer(hidden_input)
        else:
            hidden_input = hidden_input.view(self.batch_size, seq_len, -1)
            #output = self.output_layer(hidden_input)
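        # Biaffine role scorer: argument MLP states are combined with the MLP
        # state of the predicate token through the bilinear weight rel_W.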

        if self.use_biaffine:
            arg_hidden = self.mlp_dropout(self.mlp_arg(hidden_input))
            predicates_1D = batch_input['predicates_idx']
            pred_recur = hidden_input[np.arange(0, self.batch_size),
                                      predicates_1D]
            pred_hidden = self.pred_dropout(self.mlp_pred(pred_recur))
            output = bilinear(arg_hidden,
                              self.rel_W,
                              pred_hidden,
                              self.mlp_size,
                              seq_len,
                              1,
                              self.batch_size,
                              num_outputs=self.target_vocab_size,
                              bias_x=True,
                              bias_y=True)
            output = output.view(self.batch_size * seq_len, -1)

        return output, POS_output, PI_output, deprel_output, Link_output
Esempio n. 22
0
    def forward(self,
                batch_input,
                elmo,
                withParallel=True,
                lang='En',
                isPretrain=False,
                TrainGenerator=False):
        if lang == 'En':
            pretrain_batch = get_torch_variable_from_np(
                batch_input['pretrain'])
        else:
            pretrain_batch = get_torch_variable_from_np(
                batch_input['pretrain'])
        flag_batch = get_torch_variable_from_np(batch_input['flag'])

        if withParallel:
            fr_pretrain_batch = get_torch_variable_from_np(
                batch_input['fr_pretrain'])
            fr_flag_batch = get_torch_variable_from_np(batch_input['fr_flag'])

        if self.use_flag_embedding:
            flag_emb = self.flag_embedding(flag_batch)
        else:
            flag_emb = flag_batch.view(flag_batch.shape[0],
                                       flag_batch.shape[1], 1).float()

        if lang == "En":
            pretrain_emb = self.pretrained_embedding(pretrain_batch).detach()
        else:
            pretrain_emb = self.fr_pretrained_embedding(
                pretrain_batch).detach()

        if withParallel:
            fr_pretrain_emb = self.fr_pretrained_embedding(
                fr_pretrain_batch).detach()
            fr_flag_emb = self.flag_embedding(fr_flag_batch).detach()

        input_emb = torch.cat([flag_emb, pretrain_emb], 2)
        predicates_1D = batch_input['predicates_idx']
        if withParallel:
            fr_input_emb = torch.cat([fr_flag_emb, fr_pretrain_emb], 2)

        output_en, enc_real = self.EN_Labeler(input_emb, predicates_1D)
        output_fr, _ = self.FR_Labeler(input_emb, predicates_1D)

        if not withParallel:
            if isPretrain:
                return output_en
            else:
                return output_fr

        predicates_1D = batch_input['fr_predicates_idx']
        _, enc_fake = self.FR_Labeler(fr_input_emb, predicates_1D)
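        # Adversarial setup: concatenate the (detached) En encoder states with
        # the Fr encoder states; the discriminator branch below uses fully
        # detached inputs, while the generator branch lets gradients flow into
        # the Fr labeler.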

        x_D = torch.cat([enc_real.detach(), enc_fake.detach()], 0)
        x_G = torch.cat([enc_real.detach(), enc_fake], 0)
        y = torch.FloatTensor(2 * self.batch_size).zero_().to(device)
        y[:self.batch_size] = 1 - self.dis_smooth
        y[self.batch_size:] = self.dis_smooth

        if not TrainGenerator:
            #prob_real_decision = self.Discriminator(real_states.detach())
            #prob_fake_decision = self.Discriminator(fake_states.detach())
            #D_loss= - torch.mean(torch.log(prob_real_decision) + torch.log(1. - prob_fake_decision))

            preds = self.Discriminator(Variable(x_D.data))
            D_loss = F.binary_cross_entropy(preds, 1 - y)
            return D_loss * get_torch_variable_from_np(
                batch_input['fr_loss_mask']).float()
        else:
            #prob_fake_decision_G = self.Discriminator(real_states.detach())
            #G_loss = -torch.mean(torch.log(prob_fake_decision_G))
            #log("G loss:", G_loss)
            preds = self.Discriminator(x_G)
            G_loss = F.binary_cross_entropy(preds, y)
            return G_loss * get_torch_variable_from_np(
                batch_input['fr_loss_mask']).float()
Esempio n. 23
0
    def parallel_train_(self, batch_input, use_bert, isTrain=True):
        unlabeled_data_en, unlabeled_data_fr = batch_input

        predicates_1D_fr = unlabeled_data_fr['predicates_idx']
        flag_batch_fr = get_torch_variable_from_np(unlabeled_data_fr['flag'])

        flag_emb_fr = self.flag_embedding(flag_batch_fr).detach()
        actual_lens_fr = unlabeled_data_fr['seq_len']

        predicates_1D = unlabeled_data_en['predicates_idx']
        flag_batch = get_torch_variable_from_np(unlabeled_data_en['flag'])
        actual_lens_en = unlabeled_data_en['seq_len']
        flag_emb = self.flag_embedding(flag_batch).detach()
        seq_len = flag_emb.shape[1]
        seq_len_en = seq_len
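        # Word-level BERT features for both languages: run the encoder, drop
        # [CLS]/[SEP], gather the first sub-token of each word through
        # bert_out_positions, zero the padded positions, and keep a
        # Gaussian-noised copy for the noisy reconstruction loss.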

        if use_bert:
            bert_input_ids_fr = get_torch_variable_from_np(unlabeled_data_fr['bert_input_ids'])
            bert_input_mask_fr = get_torch_variable_from_np(unlabeled_data_fr['bert_input_mask'])
            bert_out_positions_fr = get_torch_variable_from_np(unlabeled_data_fr['bert_out_positions'])

            bert_emb_fr = self.model(bert_input_ids_fr, attention_mask=bert_input_mask_fr)
            bert_emb_fr = bert_emb_fr[0]
            bert_emb_fr = bert_emb_fr[:, 1:-1, :].contiguous().detach()
            bert_emb_fr = bert_emb_fr[torch.arange(bert_emb_fr.size(0)).unsqueeze(-1), bert_out_positions_fr].detach()

            for i in range(len(bert_emb_fr)):
                if i >= len(actual_lens_fr):
                    print("error")
                    break
                for j in range(len(bert_emb_fr[i])):
                    if j >= actual_lens_fr[i]:
                        bert_emb_fr[i][j] = get_torch_variable_from_np(np.zeros(768, dtype="float32"))
            bert_emb_fr_noise = gaussian(bert_emb_fr, isTrain, 0, 0.1).detach()
            bert_emb_fr = bert_emb_fr.detach()

            bert_input_ids_en = get_torch_variable_from_np(unlabeled_data_en['bert_input_ids'])
            bert_input_mask_en = get_torch_variable_from_np(unlabeled_data_en['bert_input_mask'])
            bert_out_positions_en = get_torch_variable_from_np(unlabeled_data_en['bert_out_positions'])

            bert_emb_en = self.model(bert_input_ids_en, attention_mask=bert_input_mask_en)
            bert_emb_en = bert_emb_en[0]
            bert_emb_en = bert_emb_en[:, 1:-1, :].contiguous().detach()
            bert_emb_en = bert_emb_en[torch.arange(bert_emb_en.size(0)).unsqueeze(-1), bert_out_positions_en].detach()

            for i in range(len(bert_emb_en)):
                if i >= len(actual_lens_en):
                    print("error")
                    break
                for j in range(len(bert_emb_en[i])):
                    if j >= actual_lens_en[i]:
                        bert_emb_en[i][j] = get_torch_variable_from_np(np.zeros(768, dtype="float32"))
            bert_emb_en_noise = gaussian(bert_emb_en, isTrain, 0, 0.1).detach()
            bert_emb_en = bert_emb_en.detach()

        seq_len = flag_emb.shape[1]
        SRL_output = self.SR_Labeler(bert_emb_en, flag_emb.detach(), predicates_1D, seq_len, para=True, use_bert=True)

        CopyLoss_en_noise = self.copy_loss(SRL_output, bert_emb_en_noise, flag_emb.detach(), seq_len)
        CopyLoss_en = self.copy_loss(SRL_output, bert_emb_en, flag_emb.detach(), seq_len)

        SRL_input = SRL_output.view(self.batch_size, seq_len, -1)
        SRL_input = F.softmax(SRL_input, 2)
        pred_recur = self.SR_Compressor(SRL_input.detach(), bert_emb_en,
                                        flag_emb.detach(), None, predicates_1D, seq_len, para=True, use_bert=True)

        seq_len_fr = flag_emb_fr.shape[1]
        SRL_output_fr = self.SR_Labeler(bert_emb_fr, flag_emb_fr.detach(), predicates_1D_fr, seq_len_fr, para=True,
                                        use_bert=True)

        CopyLoss_fr_noise = self.copy_loss(SRL_output_fr, bert_emb_fr_noise, flag_emb_fr.detach(), seq_len_fr)
        CopyLoss_fr = self.copy_loss(SRL_output_fr, bert_emb_fr, flag_emb_fr.detach(), seq_len_fr)


        SRL_input_fr = SRL_output_fr.view(self.batch_size, seq_len_fr, -1)
        SRL_input_fr = F.softmax(SRL_input_fr, 2)
        pred_recur_fr = self.SR_Compressor(SRL_input_fr, bert_emb_fr,
                                           flag_emb_fr.detach(), None, predicates_1D_fr, seq_len_fr, para=True,
                                           use_bert=True)
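        # Score every word against both event vectors (En and Fr); a zero
        # "Null" column is spliced into each output so that the matcher's
        # label space lines up with the labeler's.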


        """
        En event vector, En word
        """
        output_word_en_en = self.SR_Matcher(pred_recur.detach(), bert_emb_en, flag_emb.detach(), None, seq_len,
                                            para=True, use_bert=True).detach()
        score4Null = torch.zeros_like(output_word_en_en[:, 1:2])
        output_word_en_en = torch.cat((output_word_en_en[:, 0:1], score4Null, output_word_en_en[:, 1:]), 1)

        #############################################
        """
        Fr event vector, En word
        """
        output_word_fr_en = self.SR_Matcher(pred_recur_fr, bert_emb_en, flag_emb.detach(), None, seq_len,
                                            para=True, use_bert=True)
        score4Null = torch.zeros_like(output_word_fr_en[:, 1:2])
        output_word_fr_en = torch.cat((output_word_fr_en[:, 0:1], score4Null, output_word_fr_en[:, 1:]), 1)

        #############################################
        """
        En event vector, Fr word
        """
        output_word_en_fr = self.SR_Matcher(pred_recur.detach(), bert_emb_fr, flag_emb_fr.detach(), None, seq_len_fr,
                                            para=True, use_bert=True).detach()
        score4Null = torch.zeros_like(output_word_en_fr[:, 1:2])
        output_word_en_fr = torch.cat((output_word_en_fr[:, 0:1], score4Null, output_word_en_fr[:, 1:]), 1)

        """
        Fr event vector, Fr word
        """
        output_word_fr_fr = self.SR_Matcher(pred_recur_fr, bert_emb_fr, flag_emb_fr.detach(), None, seq_len_fr,
                                            para=True, use_bert=True)
        score4Null = torch.zeros_like(output_word_fr_fr[:, 1:2])
        output_word_fr_fr = torch.cat((output_word_fr_fr[:, 0:1], score4Null, output_word_fr_fr[:, 1:]), 1)


        """
        mask_en_en, mask_en_fr = self.filter_word(SRL_input, output_word_en_en,output_word_en_fr, seq_len, seq_len_fr)
        mask_en_en = get_torch_variable_from_np(mask_en_en)
        mask_en_fr = get_torch_variable_from_np(mask_en_fr)

        mask_fr_en, mask_fr_fr = self.filter_word_fr(output_word_fr_en, output_word_fr_fr, seq_len, seq_len_fr)
        mask_fr_en = get_torch_variable_from_np(mask_fr_en)
        mask_fr_fr = get_torch_variable_from_np(mask_fr_fr)

        mask_en_word = mask_en_en + mask_fr_en - mask_en_en*mask_fr_en
        mask_fr_word = mask_en_fr + mask_fr_fr - mask_en_fr*mask_fr_fr
        """

        unlabeled_loss_function = nn.KLDivLoss(reduction='none')
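        # Cross-lingual agreement (token-averaged KL): the detached En-event
        # distributions act as teachers, and the Fr-event distributions over
        # the same words (En and Fr respectively) are the students.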

        # output_word_en_en = F.softmax(output_word_en_en, dim=1).detach()
        # output_word_fr_en = F.log_softmax(output_word_fr_en, dim=1)
        # loss = unlabeled_loss_function(output_word_fr_en, output_word_en_en)
        #output_word_en_en = F.softmax(output_word_en_en, dim=1).detach()
        output_word_en_en = F.softmax(output_word_en_en, dim=1).detach()
        output_word_fr_en = F.log_softmax(output_word_fr_en, dim=1)
        loss = unlabeled_loss_function(output_word_fr_en, output_word_en_en).sum(dim=1)#*mask_en_word.view(-1)
        #loss = loss.sum() / mask_en_word.sum() #(self.batch_size * seq_len_en)
        #if mask_en_word.sum().cpu().numpy() > 1:
        loss = loss.sum() / (self.batch_size * seq_len)
        #else:
        #    loss = loss.sum()

        # output_word_en_fr = F.softmax(output_word_en_fr, dim=1).detach()
        output_word_en_fr = F.softmax(output_word_en_fr, dim=1).detach()
        output_word_fr_fr = F.log_softmax(output_word_fr_fr, dim=1)
        #output_word_fr_fr = F.log_softmax(SRL_output_fr, dim=1)
        loss_2 = unlabeled_loss_function(output_word_fr_fr, output_word_en_fr).sum(dim=1)#*mask_fr_word.view(-1)
        #if mask_fr_word.sum().cpu().numpy() > 1:
        loss_2 = loss_2.sum() / (self.batch_size * seq_len_fr)
        #else:
        #    loss_2 = loss_2.sum()

        #print(mask_en_word.sum())
        #print(mask_fr_word.sum())


        return loss, loss_2, CopyLoss_en, CopyLoss_fr, CopyLoss_en_noise, CopyLoss_fr_noise
Esempio n. 24
0
def train_1_epoc(srl_model,
                 criterion,
                 optimizer,
                 train_dataset,
                 labeled_dataset_fr,
                 batch_size,
                 word2idx,
                 fr_word2idx,
                 lemma2idx,
                 pos2idx,
                 pretrain2idx,
                 fr_pretrain2idx,
                 deprel2idx,
                 argument2idx,
                 idx2word,
                 shuffle=False,
                 lang='En',
                 dev_best_score=None,
                 test_best_score=None,
                 test_ood_best_score=None):
    for batch_i, train_input_data in enumerate(
            inter_utils.get_batch(train_dataset,
                                  batch_size,
                                  word2idx,
                                  fr_word2idx,
                                  lemma2idx,
                                  pos2idx,
                                  pretrain2idx,
                                  fr_pretrain2idx,
                                  deprel2idx,
                                  argument2idx,
                                  idx2word,
                                  shuffle=shuffle,
                                  lang=lang)):
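        # The model returns two score sets (out, out_word); both are trained
        # against the same flat gold argument labels, and the French labeled
        # set is evaluated every show_steps batches.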

        flat_argument = train_input_data['flat_argument']
        target_batch_variable = get_torch_variable_from_np(flat_argument)

        out, out_word = srl_model(train_input_data, lang='En')
        loss = criterion(out, target_batch_variable)
        loss_word = criterion(out_word, target_batch_variable)
        if batch_i % 50 == 0:
            log(batch_i, loss, loss_word)

        optimizer.zero_grad()
        (loss + loss_word).backward()
        optimizer.step()

        if batch_i > 0 and batch_i % show_steps == 0:

            _, pred = torch.max(out, 1)

            pred = get_data(pred)

            # pred = pred.reshape([bs, sl])

            log('\n')
            log('*' * 80)

            eval_train_batch(epoch, batch_i, loss.data[0], flat_argument, pred,
                             argument2idx)

            log('FR test:')
            score, dev_output = eval_data(srl_model,
                                          elmo,
                                          labeled_dataset_fr,
                                          batch_size,
                                          word2idx,
                                          fr_word2idx,
                                          lemma2idx,
                                          pos2idx,
                                          pretrain2idx,
                                          fr_pretrain2idx,
                                          deprel2idx,
                                          argument2idx,
                                          idx2argument,
                                          idx2word,
                                          False,
                                          dev_predicate_correct,
                                          dev_predicate_sum,
                                          lang='Fr')

            if dev_best_score is None or score[5] > dev_best_score[5]:
                dev_best_score = score
                output_predict(
                    os.path.join(
                        result_path, 'dev_argument_{:.2f}.pred'.format(
                            dev_best_score[2] * 100)), dev_output)
                # torch.save(srl_model, os.path.join(os.path.dirname(__file__),'model/best_{:.2f}.pkl'.format(dev_best_score[2]*100)))
            log('\tdev best P:{:.2f} R:{:.2f} F1:{:.2f} NP:{:.2f} NR:{:.2f} NF1:{:.2f}'
                .format(dev_best_score[0] * 100, dev_best_score[1] * 100,
                        dev_best_score[2] * 100, dev_best_score[3] * 100,
                        dev_best_score[4] * 100, dev_best_score[5] * 100))
    return dev_best_score
Esempio n. 25
0
def eval_pred_recog_data(model, elmo, dataset, batch_size, word2idx, lemma2idx,
                         pos2idx, pretrain2idx, pred2idx, idx2pred):
    model.eval()
    golden = []
    predict = []

    output_data = []
    cur_sentence = None
    cur_sentence_data = None

    for batch_i, input_data in enumerate(
            pred_recog_inter_utils.get_batch(dataset, batch_size, word2idx,
                                             lemma2idx, pos2idx, pretrain2idx,
                                             pred2idx)):

        target_pred = input_data['pred']

        flat_pred = input_data['flat_pred']

        target_batch_variable = get_torch_variable_from_np(flat_pred)

        sentence_id = input_data['sentence_id']
        word_id = input_data['word_id']
        sentence_len = input_data['sentence_len']
        seq_len = input_data['seq_len']
        bs = input_data['batch_size']
        psl = input_data['pad_seq_len']

        out = model(input_data, elmo)

        _, pred = torch.max(out, 1)

        pred = get_data(pred)

        pred = pred.tolist()

        golden += flat_pred.tolist()
        predict += pred
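        # Map the flat predictions back onto sentences: word_id gives each
        # token's position in its source sentence, and consecutive batch items
        # sharing a sentence_id are merged into one output record.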

        pre_data = []
        for b in range(len(seq_len)):
            line_data = ['_' for _ in range(sentence_len[b])]
            for s in range(seq_len[b]):
                wid = word_id[b][s]
                line_data[wid - 1] = idx2pred[pred[b * psl + s]]
            pre_data.append(line_data)

        for b in range(len(sentence_id)):
            if cur_sentence != sentence_id[b]:
                if cur_sentence_data is not None:
                    output_data.append(cur_sentence_data)
                cur_sentence_data = [[sentence_id[b]] * len(pre_data[b]),
                                     pre_data[b]]
                cur_sentence = sentence_id[b]
            else:
                assert cur_sentence_data is not None
                cur_sentence_data.append(pre_data[b])

    if cur_sentence_data is not None and len(cur_sentence_data) > 0:
        output_data.append(cur_sentence_data)

    score = pred_recog_score(golden, predict, pred2idx)

    model.train()

    return score, output_data
Esempio n. 26
0
    def parallel_train(self, batch_input, use_bert, isTrain=True):
        unlabeled_data_en, unlabeled_data_fr = batch_input

        predicates_1D_fr = unlabeled_data_fr['predicates_idx']
        flag_batch_fr = get_torch_variable_from_np(unlabeled_data_fr['flag'])

        flag_emb_fr = self.flag_embedding(flag_batch_fr).detach()
        actual_lens_fr = unlabeled_data_fr['seq_len']

        predicates_1D = unlabeled_data_en['predicates_idx']
        flag_batch = get_torch_variable_from_np(unlabeled_data_en['flag'])
        actual_lens_en = unlabeled_data_en['seq_len']
        flag_emb = self.flag_embedding(flag_batch).detach()
        seq_len = flag_emb.shape[1]
        seq_len_en = seq_len

        if use_bert:
            bert_input_ids_fr = get_torch_variable_from_np(unlabeled_data_fr['bert_input_ids'])
            bert_input_mask_fr = get_torch_variable_from_np(unlabeled_data_fr['bert_input_mask'])
            bert_out_positions_fr = get_torch_variable_from_np(unlabeled_data_fr['bert_out_positions'])

            bert_emb_fr = self.model(bert_input_ids_fr, attention_mask=bert_input_mask_fr)
            bert_emb_fr = bert_emb_fr[0]
            bert_emb_fr = bert_emb_fr[:, 1:-1, :].contiguous().detach()
            bert_emb_fr = bert_emb_fr[torch.arange(bert_emb_fr.size(0)).unsqueeze(-1), bert_out_positions_fr].detach()

            for i in range(len(bert_emb_fr)):
                if i >= len(actual_lens_fr):
                    print("error")
                    break
                for j in range(len(bert_emb_fr[i])):
                    if j >= actual_lens_fr[i]:
                        bert_emb_fr[i][j] = get_torch_variable_from_np(np.zeros(768, dtype="float32"))
            #bert_emb_fr = gaussian(bert_emb_fr, isTrain, 0, 0.1)
            bert_emb_fr = bert_emb_fr.detach()

            bert_input_ids_en = get_torch_variable_from_np(unlabeled_data_en['bert_input_ids'])
            bert_input_mask_en = get_torch_variable_from_np(unlabeled_data_en['bert_input_mask'])
            bert_out_positions_en = get_torch_variable_from_np(unlabeled_data_en['bert_out_positions'])

            bert_emb_en = self.model(bert_input_ids_en, attention_mask=bert_input_mask_en)
            bert_emb_en = bert_emb_en[0]
            #bert_emb_en = bert_emb_en[:, 1:-1, :].contiguous().detach()
            bert_emb_en = bert_emb_en[torch.arange(bert_emb_en.size(0)).unsqueeze(-1), bert_out_positions_en].detach()

            for i in range(len(bert_emb_en)):
                if i >= len(actual_lens_en):
                    print("error")
                    break
                for j in range(len(bert_emb_en[i])):
                    if j >= actual_lens_en[i]:
                        bert_emb_en[i][j] = get_torch_variable_from_np(np.zeros(768, dtype="float32"))
            #bert_emb_en = gaussian(bert_emb_en, isTrain, 0, 0.1)
            bert_emb_en = bert_emb_en.detach()

        seq_len = flag_emb.shape[1]
        SRL_output = self.SR_Labeler(bert_emb_en, flag_emb.detach(), predicates_1D, seq_len, para=True, use_bert=True)

        SRL_input = SRL_output.view(self.batch_size, seq_len, -1)
        pred_recur = self.SR_Compressor(SRL_input.detach(), bert_emb_en,
                                        flag_emb.detach(), None, predicates_1D, seq_len, para=True, use_bert=True)

        seq_len_fr = flag_emb_fr.shape[1]
        SRL_output_fr = self.SR_Labeler(bert_emb_fr, flag_emb_fr.detach(), predicates_1D_fr, seq_len_fr, para=True,
                                        use_bert=True)

        SRL_input_fr = SRL_output_fr.view(self.batch_size, seq_len_fr, -1)
        pred_recur_fr = self.SR_Compressor(SRL_input_fr, bert_emb_fr,
                                           flag_emb_fr.detach(), None, predicates_1D_fr, seq_len_fr, para=True,
                                           use_bert=True)

        """
        En event vector, En word
        """
        output_word_en_en = self.SR_Matcher(pred_recur.detach(), bert_emb_en, flag_emb.detach(), None, seq_len,
                                            para=True, use_bert=True)
        output_word_en_en = F.softmax(output_word_en_en, dim=1).detach()

        output_word_en_en_nonNull = torch.cat((output_word_en_en[:, 0:1], output_word_en_en[:, 2:]), 1)
        output_en_en_nonNull_max, output_en_en_nonNull_argmax = torch.max(output_word_en_en_nonNull, 1)

        #############################################
        """
        Fr event vector, En word
        """
        output_word_fr_en = self.SR_Matcher(pred_recur_fr, bert_emb_en, flag_emb.detach(), None, seq_len,
                                            para=True, use_bert=True)
        output_word_fr_en = F.softmax(output_word_fr_en, dim=1)



        #############################################
        """
        En event vector, Fr word
        """
        output_word_en_fr = self.SR_Matcher(pred_recur.detach(), bert_emb_fr, flag_emb_fr.detach(), None, seq_len_fr,
                                            para=True, use_bert=True)
        output_word_en_fr = F.softmax(output_word_en_fr, dim=1).detach()

        output_word_en_fr_nonNull = torch.cat((output_word_en_fr[:, 0:1], output_word_en_fr[:, 2:]), 1)
        output_en_fr_nonNull_max, output_en_fr_nonNull_argmax = torch.max(output_word_en_fr_nonNull, 1)


        """
        Fr event vector, Fr word
        """
        output_word_fr_fr = self.SR_Matcher(pred_recur_fr, bert_emb_fr, flag_emb_fr.detach(), None, seq_len_fr,
                                            para=True, use_bert=True)
        output_word_fr_fr = F.softmax(output_word_fr_fr, dim=1)

        output_word_fr_fr_nonNull = torch.cat((output_word_fr_fr[:, 0:1], output_word_fr_fr[:, 2:]), 1)
        output_word_fr_fr_nonNull_maxarg = torch.gather(output_word_fr_fr_nonNull, 1, output_en_fr_nonNull_argmax.view(-1,1))


        ## B*T R 2
        #Union_enfr_fr = torch.cat((output_word_en_fr.view(-1, self.target_vocab_size, 1),
        #                           output_word_fr_fr.view(-1, self.target_vocab_size, 1)), 2)
        ## B*T R
        #max_enfr_fr = torch.max(output_word_fr_fr, output_word_en_fr).detach()
        #max_enfr_fr[:, :2] = output_word_fr_fr[:,:2].detach()


        unlabeled_loss_function = nn.L1Loss(reduction='none')
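        # Confidence-gated L1 agreement: for En words, pull the Fr-event "Null"
        # probability towards the En-event one, but only on tokens where the
        # En-event model itself prefers "Null" (theta); for Fr words, pull the
        # Fr-event probability at the En-teacher's argmax role towards the
        # teacher's max non-Null probability, only where the teacher prefers a
        # non-Null role.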
        loss = unlabeled_loss_function(output_word_fr_en[:, 1], output_word_en_en[:, 1])
        theta = torch.gt(output_word_en_en[:, 1], output_en_en_nonNull_max)
        loss = theta * loss
        if torch.gt(theta.sum(), 0):
            loss = loss.sum() / theta.sum()
        else:
            loss = loss.sum()


        loss_2 = unlabeled_loss_function(output_word_fr_fr_nonNull_maxarg.view(-1), output_en_fr_nonNull_max)
        theta = torch.gt(output_en_fr_nonNull_max, output_word_en_fr[:, 1])
        loss_2 = theta*loss_2
        if torch.gt(theta.sum(), 0):
            loss_2 = loss_2.sum() / theta.sum()
        else:
            loss_2 = loss_2.sum()


        return loss, loss_2
Esempio n. 27
0
    def word_trans(self, batch_input, use_bert, isTrain=True):
        unlabeled_data_en, unlabeled_data_fr = batch_input

        predicates_1D_fr = unlabeled_data_fr['predicates_idx']
        flag_batch_fr = get_torch_variable_from_np(unlabeled_data_fr['flag'])

        flag_emb_fr = self.flag_embedding(flag_batch_fr).detach()
        actual_lens_fr = unlabeled_data_fr['seq_len']

        predicates_1D = unlabeled_data_en['predicates_idx']
        flag_batch = get_torch_variable_from_np(unlabeled_data_en['flag'])
        actual_lens_en = unlabeled_data_en['seq_len']
        flag_emb = self.flag_embedding(flag_batch).detach()
        seq_len = flag_emb.shape[1]
        seq_len_en = seq_len

        if use_bert:
            bert_input_ids_fr = get_torch_variable_from_np(unlabeled_data_fr['bert_input_ids'])
            bert_input_mask_fr = get_torch_variable_from_np(unlabeled_data_fr['bert_input_mask'])
            bert_out_positions_fr = get_torch_variable_from_np(unlabeled_data_fr['bert_out_positions'])

            bert_emb_fr = self.model(bert_input_ids_fr, attention_mask=bert_input_mask_fr)
            bert_emb_fr = bert_emb_fr[0]
            bert_emb_fr = bert_emb_fr[:, 1:-1, :].contiguous().detach()
            bert_emb_fr = bert_emb_fr[torch.arange(bert_emb_fr.size(0)).unsqueeze(-1), bert_out_positions_fr].detach()

            for i in range(len(bert_emb_fr)):
                if i >= len(actual_lens_fr):
                    print("error")
                    break
                for j in range(len(bert_emb_fr[i])):
                    if j >= actual_lens_fr[i]:
                        bert_emb_fr[i][j] = get_torch_variable_from_np(np.zeros(768, dtype="float32"))
            bert_emb_fr_noise = gaussian(bert_emb_fr, isTrain, 0, 0.1).detach()
            bert_emb_fr = bert_emb_fr.detach()

            bert_input_ids_en = get_torch_variable_from_np(unlabeled_data_en['bert_input_ids'])
            bert_input_mask_en = get_torch_variable_from_np(unlabeled_data_en['bert_input_mask'])
            bert_out_positions_en = get_torch_variable_from_np(unlabeled_data_en['bert_out_positions'])

            bert_emb_en = self.model(bert_input_ids_en, attention_mask=bert_input_mask_en)
            bert_emb_en = bert_emb_en[0]
            bert_emb_en = bert_emb_en[:, 1:-1, :].contiguous().detach()
            bert_emb_en = bert_emb_en[torch.arange(bert_emb_en.size(0)).unsqueeze(-1), bert_out_positions_en].detach()

            for i in range(len(bert_emb_en)):
                if i >= len(actual_lens_en):
                    print("error")
                    break
                for j in range(len(bert_emb_en[i])):
                    if j >= actual_lens_en[i]:
                        bert_emb_en[i][j] = get_torch_variable_from_np(np.zeros(768, dtype="float32"))
            bert_emb_en_noise = gaussian(bert_emb_en, isTrain, 0, 0.1).detach()
            bert_emb_en = bert_emb_en.detach()

            pred_bert_fr = bert_emb_fr[np.arange(0, self.batch_size), predicates_1D_fr]
            pred_bert_en = bert_emb_en[np.arange(0, self.batch_size), predicates_1D]
            #transed_bert_fr = self.Fr2En_Trans(pred_bert_fr)
            En_Extracted = self.bert_FeatureExtractor(pred_bert_en)
            Fr_Extracted = self.bert_FeatureExtractor(pred_bert_fr)
            #loss = nn.MSELoss()
            #l2loss = loss(Fr_Extracted, En_Extracted)
            #return l2loss
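            # Adversarial alignment of predicate representations: En and Fr
            # predicate vectors go through the shared bert_FeatureExtractor and
            # a two-way Discriminator is trained with cross-entropy to tell
            # En (label 1) from Fr (label 0).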

            x_D_real = En_Extracted.view(-1, 256)#self.bert_NonlinearTrans(pred_bert_en.detach().view(-1, 768))
            x_D_fake = Fr_Extracted.view(-1, 256)
            #x_D_real = self.En_LinearTrans(pred_bert_en.detach()).view(-1, 768)
            #x_D_fake = self.Fr_LinearTrans(pred_bert_fr.detach()).view(-1, 768)
            en_preds = self.Discriminator(x_D_real).view(self.batch_size, 2)
            real_labels = torch.empty((self.batch_size, 1), dtype=torch.long).fill_(1).view(-1)
            #D_loss_real = F.binary_cross_entropy(en_preds, real_labels)
            fr_preds = self.Discriminator(x_D_fake).view(self.batch_size, 2)
            fake_labels = torch.empty((self.batch_size, 1), dtype=torch.long).fill_(0).view(-1)
            #D_loss_fake = F.binary_cross_entropy(fr_preds, fake_labels)
            #D_loss = 0.5 * (D_loss_real + D_loss_fake)
            preds = torch.cat((en_preds, fr_preds), 0)
            labels = torch.cat((real_labels, fake_labels)).to(device)
            criterion = nn.CrossEntropyLoss()
            loss = criterion(preds, labels)
            return loss
Esempio n. 28
0
    def forward(self,
                batch_input,
                lang='En',
                unlabeled=False,
                self_constrain=False,
                use_bert=False,
                isTrain=False):
        if unlabeled:
            loss, loss_2 = self.parallel_train_(batch_input, use_bert)

            return loss, loss_2  #, copy_loss_en, copy_loss_fr
            #l2loss = self.word_trans(batch_input, use_bert)
            #return l2loss

        pretrain_batch = get_torch_variable_from_np(batch_input['pretrain'])
        predicates_1D = batch_input['predicates_idx']
        flag_batch = get_torch_variable_from_np(batch_input['flag'])
        word_id = get_torch_variable_from_np(batch_input['word_times'])
        word_id_emb = self.id_embedding(word_id)
        flag_emb = self.flag_embedding(flag_batch)
        actual_lens = batch_input['seq_len']
        # print(actual_lens)
        if use_bert:
            bert_input_ids = get_torch_variable_from_np(
                batch_input['bert_input_ids'])
            bert_input_mask = get_torch_variable_from_np(
                batch_input['bert_input_mask'])
            bert_out_positions = get_torch_variable_from_np(
                batch_input['bert_out_positions'])

            bert_emb = self.model(bert_input_ids,
                                  attention_mask=bert_input_mask)
            bert_emb = bert_emb[0]
            bert_emb = bert_emb[:, 1:-1, :].contiguous().detach()

            bert_emb = bert_emb[torch.arange(bert_emb.size(0)).unsqueeze(-1),
                                bert_out_positions].detach()

            for i in range(len(bert_emb)):
                if i >= len(actual_lens):
                    break
                for j in range(len(bert_emb[i])):
                    if j >= actual_lens[i]:
                        bert_emb[i][j] = get_torch_variable_from_np(
                            np.zeros(768, dtype="float32"))

            bert_emb = bert_emb.detach()
            #bert_emb_noise = gaussian(bert_emb, isTrain, 0, 0.1).detach()

        seq_len = flag_emb.shape[1]

        SRL_output = self.SR_Labeler(bert_emb,
                                     flag_emb,
                                     predicates_1D,
                                     seq_len,
                                     para=False,
                                     use_bert=True)

        SRL_input = SRL_output.view(self.batch_size, seq_len, -1)
        SRL_input_probs = F.softmax(SRL_input, 2).detach()
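        # Compress the labeler's (detached) role distribution into a compact
        # event representation, then let the matcher re-score every word;
        # during training the matcher runs in copy mode so learn_loss can push
        # it to reconstruct the labeler's output.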

        if isTrain:
            role_embs = self.SR_Compressor(SRL_input_probs,
                                           bert_emb.detach(),
                                           flag_emb.detach(),
                                           word_id_emb,
                                           predicates_1D,
                                           seq_len,
                                           para=False,
                                           use_bert=True)

            output_word = self.SR_Matcher(role_embs,
                                          bert_emb.detach(),
                                          flag_emb.detach(),
                                          word_id_emb.detach(),
                                          seq_len,
                                          copy=True,
                                          para=False,
                                          use_bert=True)
        else:
            role_embs = self.SR_Compressor(SRL_input_probs,
                                           bert_emb.detach(),
                                           flag_emb.detach(),
                                           word_id_emb,
                                           predicates_1D,
                                           seq_len,
                                           para=False,
                                           use_bert=True)

            output_word = self.SR_Matcher(role_embs,
                                          bert_emb.detach(),
                                          flag_emb.detach(),
                                          word_id_emb.detach(),
                                          seq_len,
                                          copy=False,
                                          para=False,
                                          use_bert=True)
        #score4Null = torch.zeros_like(output_word[:, 1:2])
        #output_word = torch.cat((output_word[:, 0:1], score4Null, output_word[:, 1:]), 1)

        recover_loss = self.learn_loss(SRL_input, output_word, seq_len)

        return SRL_output, output_word, recover_loss
Esempio n. 29
0
    def forward(self, batch_input, lang='En', unlabeled=False, self_constrain=False, use_bert=False, isTrain=False):
        if unlabeled:
            loss, loss_2, copy_loss_en, copy_loss_fr, a, b = self.parallel_train_(batch_input, use_bert)

            return loss, loss_2, copy_loss_en, copy_loss_fr, a, b
            #l2loss = self.word_trans(batch_input, use_bert)
            #return l2loss

        pretrain_batch = get_torch_variable_from_np(batch_input['pretrain'])
        predicates_1D = batch_input['predicates_idx']
        flag_batch = get_torch_variable_from_np(batch_input['flag'])
        word_id = get_torch_variable_from_np(batch_input['word_times'])
        word_id_emb = self.id_embedding(word_id)
        flag_emb = self.flag_embedding(flag_batch)
        actual_lens = batch_input['seq_len']

        if use_bert:
            bert_input_ids = get_torch_variable_from_np(batch_input['bert_input_ids'])
            bert_input_mask = get_torch_variable_from_np(batch_input['bert_input_mask'])
            bert_out_positions = get_torch_variable_from_np(batch_input['bert_out_positions'])

            bert_emb = self.model(bert_input_ids, attention_mask=bert_input_mask)
            bert_emb = bert_emb[0]
            bert_emb = bert_emb[:, 1:-1, :].contiguous().detach()

            bert_emb = bert_emb[torch.arange(bert_emb.size(0)).unsqueeze(-1), bert_out_positions].detach()

            for i in range(len(bert_emb)):
                if i >= len(actual_lens):
                    break
                for j in range(len(bert_emb[i])):
                    if j >= actual_lens[i]:
                        bert_emb[i][j] = get_torch_variable_from_np(np.zeros(768, dtype="float32"))


            bert_emb = bert_emb.detach()
        #bert_emb = self.bert_NonlinearTrans(bert_emb)
        #bert_emb_noise = gaussian(bert_emb, isTrain, 0, 0.1).detach()

        if lang == "En":
            pretrain_emb = self.pretrained_embedding(pretrain_batch).detach()
            #bert_emb = gaussian(bert_emb, isTrain, 0, 0.1).detach()
            #bert_emb = self.En_LinearTrans(bert_emb).detach()
        else:
            pretrain_emb = self.fr_pretrained_embedding(pretrain_batch).detach()
            #bert_emb = self.Fr_LinearTrans(bert_emb).detach()
            #bert_emb = self.Fr2En_Trans(bert_emb).detach()
        #bert_emb = self.bert_FeatureExtractor(bert_emb)
        seq_len = flag_emb.shape[1]
        if not use_bert:
            SRL_output = self.SR_Labeler(pretrain_emb, flag_emb, predicates_1D, seq_len, para=False)

            SRL_input = SRL_output.view(self.batch_size, seq_len, -1)
            pred_recur = self.SR_Compressor(SRL_input, pretrain_emb,
                                            flag_emb.detach(), word_id_emb, predicates_1D, seq_len, para=False)

            output_word = self.SR_Matcher(pred_recur, pretrain_emb, flag_emb.detach(), word_id_emb.detach(), seq_len,
                                          para=False)

        else:
            SRL_output = self.SR_Labeler(bert_emb, flag_emb, predicates_1D, seq_len, para=False, use_bert=True)

            SRL_input = SRL_output.view(self.batch_size, seq_len, -1)
            SRL_input_probs = F.softmax(SRL_input, 2).detach()


            if isTrain:
                pred_recur = self.SR_Compressor(SRL_input_probs, bert_emb.detach(),
                                                flag_emb.detach(), word_id_emb, predicates_1D, seq_len, para=False,
                                                use_bert=False)

                output_word = self.SR_Matcher(pred_recur, bert_emb.detach(), flag_emb.detach(), word_id_emb.detach(), seq_len, copy=True,
                                              para=False, use_bert=False)
            else:
                pred_recur = self.SR_Compressor(SRL_input_probs, bert_emb.detach(),
                                                flag_emb.detach(), word_id_emb, predicates_1D, seq_len, para=False,
                                                use_bert=False)

                output_word = self.SR_Matcher(pred_recur, bert_emb.detach(), flag_emb.detach(), word_id_emb.detach(),
                                              seq_len, para=False, use_bert=False)

            score4Null = torch.zeros_like(output_word[:, 1:2])
            output_word = torch.cat((output_word[:, 0:1], score4Null, output_word[:, 1:]), 1)
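            # Reconstruction (copy) loss: the labeler's distribution is the
            # detached teacher, the matcher's output (with the spliced Null
            # column) is the student; KL is averaged over all tokens.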

            teacher = F.softmax(SRL_input.view(self.batch_size * seq_len, -1), dim=1).detach()
            student = F.log_softmax(output_word, dim=1)
            unlabeled_loss_function = nn.KLDivLoss(reduction='none')

            loss_copy = unlabeled_loss_function(student, teacher).view(self.batch_size*seq_len,-1)
            loss_copy = loss_copy.sum() / (self.batch_size*seq_len)

            #CopyLoss = self.learn_loss(SRL_output, pretrain_emb, flag_emb.detach(), seq_len, mask_copy,
            #                              mask_unk)


        return SRL_output, output_word, loss_copy
Esempio n. 30
0
                                          idx2word,
                                          shuffle=True)):

                target_argument = train_input_data['argument']

                flat_argument = train_input_data['flat_argument']

                gold_pos = train_input_data['gold_pos']

                gold_PI = train_input_data['predicates_flag']

                gold_deprel = train_input_data['sep_dep_rel']

                gold_link = train_input_data['sep_dep_link']

                target_batch_variable = get_torch_variable_from_np(
                    flat_argument)
                gold_pos_batch_variable = get_torch_variable_from_np(gold_pos)
                gold_PI_batch_variable = get_torch_variable_from_np(gold_PI)
                gold_deprel_batch_variable = get_torch_variable_from_np(
                    gold_deprel)
                gold_link_batch_variable = get_torch_variable_from_np(
                    gold_link)

                bs = train_input_data['batch_size']
                sl = train_input_data['seq_len']

                out, out_pos, out_PI, out_deprel, out_link = srl_model(
                    train_input_data, elmo)

                loss = criterion(out, target_batch_variable)