Example #1
    def forward(self, pretrained_emb, flag_emb, predicates_1D, seq_len, use_bert=False, para=False):
        self.bilstm_hidden_state = (
            torch.zeros(2 * self.bilstm_num_layers, self.batch_size, self.bilstm_hidden_size).to(device),
            torch.zeros(2 * self.bilstm_num_layers, self.batch_size, self.bilstm_hidden_size).to(device))

        input_emb = torch.cat((pretrained_emb, flag_emb), 2)
        if not para:
            input_emb = self.dropout_word(input_emb)
        if not use_bert:
            bilstm_output, (_, bilstm_final_state) = self.bilstm_layer(input_emb, self.bilstm_hidden_state)
        else:
            bilstm_output, (_, bilstm_final_state) = self.bilstm_bert(input_emb, self.bilstm_hidden_state)

        bilstm_output = bilstm_output.contiguous()
        hidden_input = bilstm_output.view(bilstm_output.shape[0] * bilstm_output.shape[1], -1)
        hidden_input = hidden_input.view(self.batch_size, seq_len, -1)
        if not para:
            hidden_input = self.dropout_hidden(hidden_input)
        arg_hidden = self.mlp_arg(hidden_input)
        pred_recur = hidden_input[np.arange(0, self.batch_size), predicates_1D]
        pred_hidden = self.mlp_pred(pred_recur)
        SRL_output = bilinear(arg_hidden, self.rel_W, pred_hidden, self.mlp_size, seq_len, 1, self.batch_size,
                              num_outputs=self.target_vocab_size, bias_x=True, bias_y=True)
        SRL_output = SRL_output.view(self.batch_size * seq_len, -1)
        return SRL_output
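All of the SRL snippets in this listing call a helper bilinear(arg_hidden, rel_W, pred_hidden, mlp_size, seq_len, 1, batch_size, num_outputs=..., bias_x=True, bias_y=True) whose definition is not shown. Below is only a minimal sketch of what such a biaffine scorer typically computes, with assumed shapes; it is not the projects' actual bilinear():

import torch

def biaffine_score(x, W, y, num_outputs, bias_x=True, bias_y=True):
    # x: (batch, seq_len, d)   argument representations
    # y: (batch, d)            representation of the single predicate
    # W: (d + bias_x, num_outputs * (d + bias_y))
    # returns scores of shape (batch, seq_len, num_outputs)
    batch_size, seq_len, _ = x.shape
    if bias_x:
        x = torch.cat([x, x.new_ones(batch_size, seq_len, 1)], dim=2)
    if bias_y:
        y = torch.cat([y, y.new_ones(batch_size, 1)], dim=1)
    lin = x.matmul(W).view(batch_size, seq_len, num_outputs, -1)  # (B, T, R, d+1)
    return torch.einsum('btrd,bd->btr', lin, y)                   # (B, T, R)

B, T, d, R = 2, 5, 8, 10
scores = biaffine_score(torch.randn(B, T, d), torch.randn(d + 1, R * (d + 1)),
                        torch.randn(B, d), R)
print(scores.shape)  # torch.Size([2, 5, 10])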
Example #2
    def forward(self, batch_input_emb, predicates_1D):
        input_emb = batch_input_emb
        seq_len = batch_input_emb.shape[1]
        bilstm_output, (_, bilstm_final_state) = self.bilstm_layer(
            input_emb, self.bilstm_hidden_state)
        bilstm_output = bilstm_output.contiguous()
        hidden_input = bilstm_output.view(
            bilstm_output.shape[0] * bilstm_output.shape[1], -1)
        hidden_input = hidden_input.view(self.batch_size, seq_len, -1)

        arg_hidden = self.mlp_arg(hidden_input)
        pred_recur = hidden_input[np.arange(0, self.batch_size), predicates_1D]
        pred_hidden = self.mlp_pred(pred_recur)
        output = bilinear(arg_hidden,
                          self.rel_W,
                          pred_hidden,
                          self.mlp_size,
                          seq_len,
                          1,
                          self.batch_size,
                          num_outputs=self.target_vocab_size,
                          bias_x=True,
                          bias_y=True)
        en_output = output.view(self.batch_size * seq_len, -1)
        """
        cat_output = en_output.view(self.batch_size, seq_len, -1)
        pred_recur = pred_recur.unsqueeze(1).expand(self.batch_size, seq_len, 2*self.bilstm_hidden_size)
        all_cat = torch.cat((hidden_input, pred_recur, cat_output), 2)
        shuffled_timestep = np.arange(0, seq_len)
        np.random.shuffle(shuffled_timestep)
        all_cat = all_cat.index_select(dim=1, index=get_torch_variable_from_np(shuffled_timestep))
        all_cat = self.out_dropout(all_cat)
        """
        enc_real = torch.mean(hidden_input, dim=1)
        return en_output, enc_real
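As an assumed usage note (not from the project): the flattened en_output above has shape (batch_size * seq_len, num_labels), so it can be fed directly to nn.CrossEntropyLoss together with flattened gold labels; the tensors below are dummies for illustration:

import torch
import torch.nn as nn

batch_size, seq_len, num_labels = 2, 5, 10
en_output = torch.randn(batch_size * seq_len, num_labels)         # stand-in for the scores above
gold_roles = torch.randint(0, num_labels, (batch_size, seq_len))  # hypothetical gold role labels
loss = nn.CrossEntropyLoss()(en_output, gold_roles.view(-1))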
Example #3
    def forward(self, batch_input, elmo):

        flag_batch = get_torch_variable_from_np(batch_input['flag'])
        word_batch = get_torch_variable_from_np(batch_input['word'])
        lemma_batch = get_torch_variable_from_np(batch_input['lemma'])
        pos_batch = get_torch_variable_from_np(batch_input['pos'])
        deprel_batch = get_torch_variable_from_np(batch_input['deprel'])
        pretrain_batch = get_torch_variable_from_np(batch_input['pretrain'])
        origin_batch = batch_input['origin']
        origin_deprel_batch = batch_input['deprel']
        chars_batch = get_torch_variable_from_np(batch_input['char'])

        if self.use_flag_embedding:
            flag_emb = self.flag_embedding(flag_batch)
        else:
            flag_emb = flag_batch.view(flag_batch.shape[0],
                                       flag_batch.shape[1], 1).float()
        seq_len = flag_batch.shape[1]
        word_emb = self.word_embedding(word_batch)
        lemma_emb = self.lemma_embedding(lemma_batch)
        pos_emb = self.pos_embedding(pos_batch)
        char_embeddings = self.char_embeddings(chars_batch)
        character_embeddings = self.charCNN(char_embeddings)
        pretrain_emb = self.pretrained_embedding(pretrain_batch)

        if self.use_deprel:
            deprel_emb = self.deprel_embedding(deprel_batch)

        # predicate_emb = self.word_embedding(predicate_batch)
        # predicate_pretrain_emb = self.pretrained_embedding(predicate_pretrain_batch)

        ##sentence learner#####################################
        SL_input_emb = self.word_dropout(
            torch.cat([word_emb, pretrain_emb, character_embeddings], 2))
        h0, (_,
             SL_final_state) = self.sentence_learner(SL_input_emb,
                                                     self.SL_hidden_state0)
        h1, (_, SL_final_state) = self.sentence_learner_high(
            h0, self.SL_hidden_state0_high)
        SL_output = h1
        POS_output = self.pos_classifier(SL_output).view(
            self.batch_size * seq_len, -1)
        PI_output = self.PI_classifier(SL_output).view(
            self.batch_size * seq_len, -1)
        ## deprel
        hidden_input = SL_output
        arg_hidden = self.mlp_dropout(self.mlp_arg_deprel(SL_output))
        predicates_1D = batch_input['predicates_idx']
        pred_recur = hidden_input[np.arange(0, self.batch_size), predicates_1D]
        pred_hidden = self.pred_dropout(self.mlp_pred_deprel(pred_recur))
        deprel_output = bilinear(arg_hidden,
                                 self.deprel_W,
                                 pred_hidden,
                                 self.mlp_size,
                                 seq_len,
                                 1,
                                 self.batch_size,
                                 num_outputs=self.deprel_vocab_size,
                                 bias_x=True,
                                 bias_y=True)

        arg_hidden = self.mlp_dropout(self.mlp_arg_link(SL_output))
        predicates_1D = batch_input['predicates_idx']
        pred_recur = hidden_input[np.arange(0, self.batch_size), predicates_1D]
        pred_hidden = self.pred_dropout(self.mlp_pred_link(pred_recur))
        Link_output = bilinear(arg_hidden,
                               self.link_W,
                               pred_hidden,
                               self.mlp_size,
                               seq_len,
                               1,
                               self.batch_size,
                               num_outputs=4,
                               bias_x=True,
                               bias_y=True)
        deprel_output = deprel_output.view(self.batch_size * seq_len, -1)
        Link_output = Link_output.view(self.batch_size * seq_len, -1)

        POS_probs = F.softmax(POS_output, dim=1).view(self.batch_size, seq_len,
                                                      -1)
        deprel_probs = F.softmax(deprel_output,
                                 dim=1).view(self.batch_size, seq_len, -1)
        link_probs = F.softmax(Link_output,
                               dim=1).view(self.batch_size, seq_len, -1)

        POS_compose = torch.tanh(self.POS2hidden(POS_probs))
        deprel_compose = torch.tanh(self.deprel2hidden(deprel_probs))
        link_compose = link_probs

        ####### semantic role labeler #######

        if self.use_deprel:
            input_emb = torch.cat([
                flag_emb, word_emb, pretrain_emb, character_embeddings,
                POS_compose, deprel_compose, link_compose
            ], 2)  #
        else:
            input_emb = torch.cat([
                flag_emb, word_emb, pretrain_emb, character_embeddings,
                POS_compose
            ], 2)  #

        input_emb = self.word_dropout(input_emb)

        w = F.softmax(self.elmo_w, dim=0)
        SRL_composer = self.elmo_gamma * (w[0] * h0 + w[1] * h1)
        SRL_composer = self.elmo_mlp(SRL_composer)
        bilstm_output_0, (_, bilstm_final_state) = self.bilstm_layer(
            input_emb, self.bilstm_hidden_state0)
        high_input = torch.cat((bilstm_output_0, SRL_composer), 2)
        bilstm_output, (_, bilstm_final_state) = self.bilstm_layer_high(
            high_input, self.bilstm_hidden_state0_high)

        # bilstm_final_state = bilstm_final_state.view(self.bilstm_num_layers, 2, self.batch_size, self.bilstm_hidden_size)

        # bilstm_final_state = bilstm_final_state[-1]

        # sentence latent representation
        # bilstm_final_state = torch.cat([bilstm_final_state[0], bilstm_final_state[1]], 1)

        # bilstm_output = self.bilstm_mlp(bilstm_output)

        if self.use_self_attn:
            x = torch.tanh(self.attn_linear_first(bilstm_output))
            x = self.attn_linear_second(x)
            x = self.softmax(x, 1)
            attention = x.transpose(1, 2)
            sentence_embeddings = torch.matmul(attention, bilstm_output)
            sentence_embeddings = torch.sum(sentence_embeddings,
                                            1) / self.self_attn_head

            context = sentence_embeddings.repeat(bilstm_output.size(1), 1,
                                                 1).transpose(0, 1)

            bilstm_output = torch.cat([bilstm_output, context], 2)

            bilstm_output = self.attn_linear_final(bilstm_output)

            # energy = self.biaf_attn(bilstm_output, bilstm_output)

            # # energy = energy.transpose(1, 2)

            # flag_indices = batch_input['flag_indices']

            # attention = []
            # for idx in range(len(flag_indices)):
            #     attention.append(energy[idx,:,:,flag_indices[idx]].view(1, self.bilstm_hidden_size, -1))

            # attention = torch.cat(attention,dim=0)

            # # attention = attention.transpose(1, 2)

            # attention = self.softmax(attention, 2)

            # # attention = attention.transpose(1,2)

            # sentence_embeddings = attention@bilstm_output

            # sentence_embeddings = torch.sum(sentence_embeddings,1)/self.self_attn_head

            # context = sentence_embeddings.repeat(bilstm_output.size(1), 1, 1).transpose(0, 1)

            # bilstm_output = torch.cat([bilstm_output, context], 2)

            # bilstm_output = self.attn_linear_final(bilstm_output)
        else:
            bilstm_output = bilstm_output.contiguous()

        if self.use_gcn:
            # in_rep = torch.matmul(bilstm_output, self.W_in)
            # out_rep = torch.matmul(bilstm_output, self.W_out)
            # self_rep = torch.matmul(bilstm_output, self.W_self)

            # child_indicies = batch_input['children_indicies']

            # head_batch = batch_input['head']

            # context = []
            # for idx in range(head_batch.shape[0]):
            #     states = []
            #     for jdx in range(head_batch.shape[1]):
            #         head_ind = head_batch[idx, jdx]-1
            #         childs = child_indicies[idx][jdx]
            #         state = self_rep[idx, jdx]
            #         if head_ind != -1:
            #               state = state + in_rep[idx, head_ind]
            #         for child in childs:
            #             state = state + out_rep[idx, child]
            #         state = F.relu(state + self.gcn_bias)
            #         states.append(state.unsqueeze(0))
            #     context.append(torch.cat(states, dim=0))

            # context = torch.cat(context, dim=0)

            # bilstm_output = context

            seq_len = bilstm_output.shape[1]

            adj_arc_in = np.zeros((self.batch_size * seq_len, 2),
                                  dtype='int32')
            adj_lab_in = np.zeros((self.batch_size * seq_len), dtype='int32')

            adj_arc_out = np.zeros((self.batch_size * seq_len, 2),
                                   dtype='int32')
            adj_lab_out = np.zeros((self.batch_size * seq_len), dtype='int32')

            mask_in = np.zeros((self.batch_size * seq_len), dtype='float32')
            mask_out = np.zeros((self.batch_size * seq_len), dtype='float32')

            mask_loop = np.ones((self.batch_size * seq_len, 1),
                                dtype='float32')

            for idx in range(len(origin_batch)):
                for jdx in range(len(origin_batch[idx])):

                    offset = jdx + idx * seq_len

                    head_ind = int(origin_batch[idx][jdx][10]) - 1

                    if head_ind == -1:
                        continue

                    dependent_ind = int(origin_batch[idx][jdx][4]) - 1

                    adj_arc_in[offset] = np.array([idx, dependent_ind])

                    adj_lab_in[offset] = np.array(
                        [origin_deprel_batch[idx, jdx]])

                    mask_in[offset] = 1

                    adj_arc_out[offset] = np.array([idx, head_ind])

                    adj_lab_out[offset] = np.array(
                        [origin_deprel_batch[idx, jdx]])

                    mask_out[offset] = 1

            if USE_CUDA:
                adj_arc_in = torch.LongTensor(np.transpose(adj_arc_in)).cuda()
                adj_arc_out = torch.LongTensor(
                    np.transpose(adj_arc_out)).cuda()

                adj_lab_in = Variable(torch.LongTensor(adj_lab_in).cuda())
                adj_lab_out = Variable(torch.LongTensor(adj_lab_out).cuda())

                mask_in = Variable(
                    torch.FloatTensor(
                        mask_in.reshape(
                            (self.batch_size * seq_len, 1))).cuda())
                mask_out = Variable(
                    torch.FloatTensor(
                        mask_out.reshape(
                            (self.batch_size * seq_len, 1))).cuda())
                mask_loop = Variable(torch.FloatTensor(mask_loop).cuda())
            else:
                adj_arc_in = torch.LongTensor(np.transpose(adj_arc_in))
                adj_arc_out = torch.LongTensor(np.transpose(adj_arc_out))

                adj_lab_in = Variable(torch.LongTensor(adj_lab_in))
                adj_lab_out = Variable(torch.LongTensor(adj_lab_out))

                mask_in = Variable(
                    torch.FloatTensor(
                        mask_in.reshape((self.batch_size * seq_len, 1))))
                mask_out = Variable(
                    torch.FloatTensor(
                        mask_out.reshape((self.batch_size * seq_len, 1))))
                mask_loop = Variable(torch.FloatTensor(mask_loop))

            gcn_context = self.syntactic_gcn(bilstm_output, adj_arc_in,
                                             adj_arc_out, adj_lab_in,
                                             adj_lab_out, mask_in, mask_out,
                                             mask_loop)

            #gcn_context = self.softmax(gcn_context, axis=2)
            gcn_context = F.softmax(gcn_context, dim=2)

            bilstm_output = torch.cat([bilstm_output, gcn_context], dim=2)

            bilstm_output = self.gcn_mlp(bilstm_output)

        hidden_input = bilstm_output.view(
            bilstm_output.shape[0] * bilstm_output.shape[1], -1)

        if self.use_highway:
            for current_layer in self.highway_layers:
                hidden_input = current_layer(hidden_input)

            output = self.output_layer(hidden_input)
        else:
            hidden_input = hidden_input.view(self.batch_size, seq_len, -1)
            #output = self.output_layer(hidden_input)

        if self.use_biaffine:
            arg_hidden = self.mlp_dropout(self.mlp_arg(hidden_input))
            predicates_1D = batch_input['predicates_idx']
            pred_recur = hidden_input[np.arange(0, self.batch_size),
                                      predicates_1D]
            pred_hidden = self.pred_dropout(self.mlp_pred(pred_recur))
            output = bilinear(arg_hidden,
                              self.rel_W,
                              pred_hidden,
                              self.mlp_size,
                              seq_len,
                              1,
                              self.batch_size,
                              num_outputs=self.target_vocab_size,
                              bias_x=True,
                              bias_y=True)
            output = output.view(self.batch_size * seq_len, -1)

        return output, POS_output, PI_output, deprel_output, Link_output
Example #4
File: bilinear.py Project: Amabitur/CUDA
tex.set_address_mode(1, driver.address_mode.CLAMP)
driver.matrix_to_texref(prep_image, tex, order="C")

bilinear_interpolation_kernel(driver.Out(result),
                              driver.In(x_out),
                              driver.In(y_out),
                              np.int32(M1),
                              np.int32(N1),
                              np.int32(M2),
                              np.int32(N2),
                              block=block,
                              grid=grid,
                              texrefs=[tex])
big_image = normalize_image(result, image.shape[2])
stop.record()
stop.synchronize()
gpu_time = stop.time_since(start)
print("Время интерполяции на ГПУ: %.3f ms" % (gpu_time))
cv2.imwrite("./data/big-gpu-seal.jpg", big_image.astype(np.uint8))

#p_image = prepare_image(image)

print("Считаем на ЦПУ...")
start = timeit.default_timer()
cpu_result = bilinear(image)
cpu_time = timeit.default_timer() - start
print("Время интерполяции на ЦПУ: %.3f ms" % (cpu_time * 1e3))

#big_cpu_image = normalize_image(cpu_result, image.shape[2])

cv2.imwrite("./data/big-cpu-seal.jpg", cpu_result.astype(np.uint8))
Example #5
    def forward(self, batch_input, elmo, withParallel=True, lang='En'):

        if lang == 'En':
            word_batch = get_torch_variable_from_np(batch_input['word'])
            pretrain_batch = get_torch_variable_from_np(
                batch_input['pretrain'])
        else:
            word_batch = get_torch_variable_from_np(batch_input['word'])
            pretrain_batch = get_torch_variable_from_np(
                batch_input['pretrain'])

        flag_batch = get_torch_variable_from_np(batch_input['flag'])
        pos_batch = get_torch_variable_from_np(batch_input['pos'])
        pos_emb = self.pos_embedding(pos_batch)
        #log(pretrain_batch)
        #log(flag_batch)

        role_index = get_torch_variable_from_np(batch_input['role_index'])
        role_mask = get_torch_variable_from_np(batch_input['role_mask'])
        role_mask_expand = role_mask.unsqueeze(2).expand(
            self.batch_size, self.target_vocab_size, self.word_emb_size)

        role2word_batch = pretrain_batch.gather(dim=1, index=role_index)
        #log(role2word_batch)
        role2word_emb = self.pretrained_embedding(role2word_batch).detach()
        role2word_emb = role2word_emb * role_mask_expand.float()
        #log(role2word_emb[0][1])
        #log(role2word_emb[0][2])
        #log(role_mask)
        #log(role_mask_expand)
        #log("#################")
        #log(role_index)
        #log(role_mask)
        #log(role2word_batch)

        if withParallel:
            fr_word_batch = get_torch_variable_from_np(batch_input['fr_word'])
            fr_pretrain_batch = get_torch_variable_from_np(
                batch_input['fr_pretrain'])
            fr_flag_batch = get_torch_variable_from_np(batch_input['fr_flag'])

        #log(fr_pretrain_batch[0])
        #log(fr_flag_batch)
        if self.use_flag_embedding:
            flag_emb = self.flag_embedding(flag_batch)
        else:
            flag_emb = flag_batch.view(flag_batch.shape[0],
                                       flag_batch.shape[1], 1).float()

        seq_len = flag_batch.shape[1]
        if lang == "En":
            word_emb = self.word_embedding(word_batch)
            pretrain_emb = self.pretrained_embedding(pretrain_batch).detach()

        else:
            word_emb = self.fr_word_embedding(word_batch)
            pretrain_emb = self.fr_pretrained_embedding(
                pretrain_batch).detach()

        if withParallel:
            #fr_word_emb = self.fr_word_embedding(fr_word_batch)
            fr_pretrain_emb = self.fr_pretrained_embedding(
                fr_pretrain_batch).detach()
            fr_flag_emb = self.flag_embedding(fr_flag_batch)
            fr_seq_len = fr_flag_batch.shape[1]
            role_mask_expand_timestep = role_mask.unsqueeze(2).expand(
                self.batch_size, self.target_vocab_size, fr_seq_len)
            role_mask_expand_forFR = \
                role_mask_expand.unsqueeze(1).expand(self.batch_size, fr_seq_len, self.target_vocab_size,
                                                     self.pretrain_emb_size)

        # predicate_emb = self.word_embedding(predicate_batch)
        # predicate_pretrain_emb = self.pretrained_embedding(predicate_pretrain_batch)

        ####### semantic role labeler #######

        # note: both branches below are identical; the deprel features are not used in this variant
        if self.use_deprel:
            input_emb = torch.cat([flag_emb, pretrain_emb], 2)  #
        else:
            input_emb = torch.cat([flag_emb, pretrain_emb], 2)  #

        if withParallel:
            fr_input_emb = torch.cat([fr_flag_emb, fr_pretrain_emb], 2)

        input_emb = self.word_dropout(input_emb)
        #input_emb = torch.cat((input_emb, get_torch_variable_from_np(np.zeros((self.batch_size, seq_len, self.target_vocab_size))).float()), 2)
        bilstm_output, (_, bilstm_final_state) = self.bilstm_layer(
            input_emb, self.bilstm_hidden_state)
        bilstm_output = bilstm_output.contiguous()
        hidden_input = bilstm_output.view(
            bilstm_output.shape[0] * bilstm_output.shape[1], -1)
        hidden_input = hidden_input.view(self.batch_size, seq_len, -1)
        #output = self.output_layer(hidden_input)

        arg_hidden = self.mlp_dropout(self.mlp_arg(hidden_input))
        predicates_1D = batch_input['predicates_idx']
        pred_recur = hidden_input[np.arange(0, self.batch_size), predicates_1D]
        pred_hidden = self.pred_dropout(self.mlp_pred(pred_recur))
        output = bilinear(arg_hidden,
                          self.rel_W,
                          pred_hidden,
                          self.mlp_size,
                          seq_len,
                          1,
                          self.batch_size,
                          num_outputs=self.target_vocab_size,
                          bias_x=True,
                          bias_y=True)
        en_output = output.view(self.batch_size * seq_len, -1)

        if withParallel:
            role2word_emb_1 = role2word_emb.view(self.batch_size,
                                                 self.target_vocab_size, -1)
            role2word_emb_2 = role2word_emb_1.unsqueeze(dim=1)
            role2word_emb_expand = role2word_emb_2.expand(
                self.batch_size, fr_seq_len, self.target_vocab_size,
                self.pretrain_emb_size)

            #log(role2word_emb_expand[0, 1, 4])
            #log(role2word_emb_expand[0, 2, 4])

            fr_pretrain_emb = fr_pretrain_emb.view(self.batch_size, fr_seq_len,
                                                   -1)
            fr_pretrain_emb = fr_pretrain_emb.unsqueeze(dim=2)
            fr_pretrain_emb_expand = fr_pretrain_emb.expand(
                self.batch_size, fr_seq_len, self.target_vocab_size,
                self.pretrain_emb_size)
            fr_pretrain_emb_expand = fr_pretrain_emb_expand  # * role_mask_expand_forFR.float()
            # B T R W
            emb_distance = fr_pretrain_emb_expand - role2word_emb_expand
            emb_distance = emb_distance * emb_distance
            # B T R
            emb_distance = emb_distance.sum(dim=3)
            emb_distance = torch.sqrt(emb_distance)
            emb_distance = emb_distance.transpose(1, 2)
            # emb_distance_min = emb_distance_min.expand(self.batch_size, fr_seq_len, self.target_vocab_size)
            # emb_distance = F.softmax(emb_distance, dim=1)*role_mask_expand_timestep.float()
            # B R
            emb_distance_min, emb_distance_argmin = torch.min(emb_distance,
                                                              dim=2,
                                                              keepdim=True)
            emb_distance_max, emb_distance_argmax = torch.max(emb_distance,
                                                              dim=2,
                                                              keepdim=True)
            emb_distance_max_expand = emb_distance_max.expand(
                self.batch_size, self.target_vocab_size, fr_seq_len)
            emb_distance_min_expand = emb_distance_min.expand(
                self.batch_size, self.target_vocab_size, fr_seq_len)
            emb_distance_nomalized = (emb_distance - emb_distance_min_expand
                                      ) / (emb_distance_max_expand -
                                           emb_distance_min_expand)
            #* role_mask_expand_timestep.float()
            # emb_distance_argmin_expand = emb_distance_argmin.expand(self.batch_size, fr_seq_len, self.target_vocab_size)
            # log("######################")

            emb_distance_nomalized = emb_distance_nomalized.detach()

            top2 = emb_distance_nomalized.topk(2,
                                               dim=2,
                                               largest=False,
                                               sorted=True)[0]

            top2_gap = top2[:, :, 1] - top2[:, :, 0]
            #log(top2_gap)
            #log(role_mask)
            #log(emb_distance_nomalized[0,:,2])
            #log(emb_distance_argmax)

            #emb_distance_argmin = emb_distance_argmin*role_mask
            #log(emb_distance_argmin)
            #log(emb_distance_nomalized[0][2])
            """
            weight_4_loss = emb_distance_nomalized
            #log(emb_distance_argmin[0,2])
            for i in range(self.batch_size):
                for j in range(self.target_vocab_size):
                    weight_4_loss[i,j].index_fill_(0, emb_distance_argmin[i][j], 1)
            """
            #log(weight_4_loss[0][2])

            fr_input_emb = self.word_dropout(fr_input_emb).detach()
            #fr_input_emb = torch.cat((fr_input_emb, emb_distance_nomalized), 2)
            fr_bilstm_output, (_, bilstm_final_state) = self.bilstm_layer(
                fr_input_emb, self.fr_bilstm_hidden_state)
            fr_bilstm_output = fr_bilstm_output.contiguous()
            hidden_input = fr_bilstm_output.view(
                fr_bilstm_output.shape[0] * fr_bilstm_output.shape[1], -1)
            hidden_input = hidden_input.view(self.batch_size, fr_seq_len, -1)
            arg_hidden = self.mlp_dropout(self.mlp_arg(hidden_input))
            predicates_1D = batch_input['fr_predicates_idx']
            #log(predicates_1D)
            #log(predicates_1D)
            pred_recur = hidden_input[np.arange(0, self.batch_size),
                                      predicates_1D]
            pred_hidden = self.pred_dropout(self.mlp_pred(pred_recur))
            output = bilinear(arg_hidden,
                              self.rel_W,
                              pred_hidden,
                              self.mlp_size,
                              fr_seq_len,
                              1,
                              self.batch_size,
                              num_outputs=self.target_vocab_size,
                              bias_x=True,
                              bias_y=True)
            output = output.view(self.batch_size, fr_seq_len, -1)
            output = output.transpose(1, 2)
            #B R T
            #output_p = F.softmax(output, dim=2)
            #output = output * role_mask_expand_timestep.float()
            #log(role_mask_expand_timestep)
            """
            output_exp = torch.exp(output)
            output_exp_weighted = output_exp #* weight_4_loss
            output_expsum = output_exp_weighted.sum(dim=2, keepdim=True).expand(self.batch_size, self.target_vocab_size,
                                                                                fr_seq_len)
            output = output_exp/output_expsum
            """
            #log(output_expsum)
            #log(output_p[0, 2])
            # B R 1
            #output_pmax, output_pmax_arg = torch.max(output_p, dim=2, keepdim=True)

            #log(output[0])
            #log(emb_distance[0,:, 2])
            #log(emb_distance_nomalized[0,:, 2])
            #log(role_mask[0])
            #log(role_mask[0])
            #log(output[0][0])
            #log("++++++++++++++++++++")
            #log(emb_distance)
            #log(emb_distance.gather(1, emb_distance_argmin))
            #output_pargminD = output_p.gather(2, emb_distance_argmin)
            #log(output_argminD)
            #weighted_distance = (output/output_argminD) * emb_distance_nomalized
            #bias = torch.FloatTensor(output_pmax.size()).fill_(1).cuda()
            #rank_loss = (output_pmax/output_pargminD)# * emb_distance_nomalized.gather(2, output_pmax_arg)
            #rank_loss = rank_loss.view(self.batch_size, self.target_vocab_size)
            #rank_loss = rank_loss * role_mask.float()
            #log("++++++++++++++++++++++")
            #log(output_max-output_argminD)
            #log(emb_distance_nomalized.gather(1, output_max_arg))
            # B R
            #weighted_distance = weighted_distance.squeeze() * role_mask.float()
            """
            output = F.softmax(output, dim=1)
            output = output.transpose(1, 2)
            fr_role2word_emb = torch.bmm(output, fr_pretrain_emb)
            criterion = nn.MSELoss(reduce=False)
            l2_loss = criterion(fr_role2word_emb.view(self.batch_size*self.target_vocab_size, -1),
                                  role2word_emb.view(self.batch_size*self.target_vocab_size, -1))
            """
            criterion = nn.CrossEntropyLoss(ignore_index=0, reduction='none')

            output = output.view(self.batch_size * self.target_vocab_size, -1)
            emb_distance_argmin = emb_distance_argmin * role_mask.unsqueeze(2)
            emb_distance_argmin = emb_distance_argmin.view(-1)
            #log(emb_distance_argmin[0][2])
            l2_loss = criterion(output, emb_distance_argmin)
            l2_loss = l2_loss.view(self.batch_size, self.target_vocab_size)
            top2_gap = top2_gap.view(self.batch_size, self.target_vocab_size)
            l2_loss = l2_loss * top2_gap
            l2_loss = l2_loss.sum()
            #l2_loss = F.nll_loss(torch.log(output), emb_distance_argmin, ignore_index=0)

            #log(emb_distance_argmin)
            #log(l2_loss)
            #log("+++++++++++++++++++++")
            #log(l2_loss)
            #log(batch_input['fr_loss_mask'])
            l2_loss = l2_loss * get_torch_variable_from_np(
                batch_input['fr_loss_mask']).float()
            l2_loss = l2_loss.sum()  #/float_role_mask.sum()
            #log("+")
            #log(l2_loss)
            return en_output, l2_loss
        return en_output
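The distance computation in the withParallel branch above (expand, subtract, square, sum, sqrt, transpose) is a batched pairwise Euclidean distance; an equivalent compact sketch with dummy shapes (illustration only, not the project's code):

import torch

B, T_fr, R, E = 2, 7, 10, 16                # dummy batch, fr_seq_len, target_vocab_size, emb size
fr_pretrain_emb = torch.randn(B, T_fr, E)   # French token embeddings
role2word_emb = torch.randn(B, R, E)        # one prototype word embedding per role
# (B, R, T_fr): Euclidean distance between each role prototype and each French token
emb_distance = torch.cdist(role2word_emb, fr_pretrain_emb, p=2)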
Example #6
File: GAN.py Project: RuiCaiNLP/Base
    def forward(self, batch_input, unlabeled_batch_input, lang, unlabeled):

        pretrain_batch = get_torch_variable_from_np(batch_input['pretrain'])
        flag_batch = get_torch_variable_from_np(batch_input['flag'])
        flag_emb = self.flag_embedding(flag_batch)

        if unlabeled:
            fr_pretrain_batch = get_torch_variable_from_np(
                unlabeled_batch_input['pretrain'])
            fr_flag_batch = get_torch_variable_from_np(
                unlabeled_batch_input['flag'])
            fr_pretrain_emb = self.fr_pretrained_embedding(
                fr_pretrain_batch).detach()
            fr_flag_emb = self.flag_embedding(fr_flag_batch).detach()
            predicates_1D = unlabeled_batch_input['predicates_idx']
        else:
            predicates_1D = batch_input['predicates_idx']

        if lang == "En":
            pretrain_emb = self.pretrained_embedding(pretrain_batch).detach()
        else:
            pretrain_emb = self.fr_pretrained_embedding(
                pretrain_batch).detach()

        if not unlabeled:
            input_emb = torch.cat([flag_emb, pretrain_emb], 2)
        else:
            input_emb = torch.cat([fr_flag_emb, fr_pretrain_emb], 2)

        input_emb = self.word_dropout(input_emb)
        seq_len = input_emb.shape[1]
        bilstm_output, (_, bilstm_final_state) = self.bilstm_layer(
            input_emb, self.bilstm_hidden_state)
        bilstm_output = bilstm_output.contiguous()
        hidden_input = bilstm_output.view(
            bilstm_output.shape[0] * bilstm_output.shape[1], -1)
        hidden_input = hidden_input.view(self.batch_size, seq_len, -1)
        hidden_input = self.out_dropout(hidden_input)
        #predicate_hidden = hidden_input[np.arange(0, self.batch_size), predicates_1D]
        #predicates_hidden = predicate_hidden.unsqueeze(1).expand(self.batch_size, seq_len, 2*self.bilstm_hidden_size)
        arg_hidden = self.mlp_arg(hidden_input)
        pred_recur = hidden_input[np.arange(0, self.batch_size), predicates_1D]
        pred_hidden = self.mlp_pred(pred_recur)
        output = bilinear(arg_hidden,
                          self.rel_W,
                          pred_hidden,
                          self.mlp_size,
                          seq_len,
                          1,
                          self.batch_size,
                          num_outputs=self.target_vocab_size,
                          bias_x=True,
                          bias_y=True)
        output = output.view(self.batch_size * seq_len, -1)
        """
        cat_output = en_output.view(self.batch_size, seq_len, -1)
        pred_recur = pred_recur.unsqueeze(1).expand(self.batch_size, seq_len, 2*self.bilstm_hidden_size)
        all_cat = torch.cat((hidden_input, pred_recur, cat_output), 2)
        shuffled_timestep = np.arange(0, seq_len)
        np.random.shuffle(shuffled_timestep)
        all_cat = all_cat.index_select(dim=1, index=get_torch_variable_from_np(shuffled_timestep))
        all_cat = self.out_dropout(all_cat)
        """
        #enc = torch.mean(hidden_input, dim=1)
        #enc = Input4Gan_0(hidden_input, predicates_1D)
        enc = Input4Gan_1(output.view(self.batch_size, seq_len, -1),
                          predicates_1D)
        return output, enc.view(self.batch_size, seq_len, -1)
Example #7
File: model.py Project: RuiCaiNLP/SRL_word
    def forward(self, batch_input, lang='En', unlabeled=False):
        if unlabeled:
            #self.batch_size=1
            loss = self.parallel_train(batch_input)
            #loss_word = self.word_train(batch_input)
            loss_word = 0
            #self.batch_size = 30
            return loss, loss_word
        word_batch = get_torch_variable_from_np(batch_input['word'])
        pretrain_batch = get_torch_variable_from_np(batch_input['pretrain'])
        predicates_1D = batch_input['predicates_idx']
        flag_batch = get_torch_variable_from_np(batch_input['flag'])
        word_id = get_torch_variable_from_np(batch_input['word_times'])
        word_id_emb = self.id_embedding(word_id)
        flag_emb = self.flag_embedding(flag_batch)

        if lang == "En":
            pretrain_emb = self.pretrained_embedding(pretrain_batch).detach()
            input_emb = torch.cat((pretrain_emb, flag_emb), 2)
            input_emb_word = input_emb
        else:
            pretrain_emb = self.fr_pretrained_embedding(
                pretrain_batch).detach()
            pretrain_emb_matrixed = self.word_matrix(pretrain_emb).detach()
            input_emb = torch.cat((pretrain_emb, flag_emb), 2)
            input_emb_word = torch.cat((pretrain_emb, flag_emb), 2)
        """
        if lang == "En":
            word_emb = self.word_embedding(word_batch)
        else:
            word_emb = self.fr_word_embedding(word_batch)
        """

        #input_emb = torch.cat((pretrain_emb, word_emb), 2)

        input_emb = self.word_dropout(input_emb)
        seq_len = input_emb.shape[1]
        bilstm_output, (_, bilstm_final_state) = self.bilstm_layer(
            input_emb, self.bilstm_hidden_state)
        bilstm_output = bilstm_output.contiguous()
        hidden_input = bilstm_output.view(
            bilstm_output.shape[0] * bilstm_output.shape[1], -1)
        hidden_input = hidden_input.view(self.batch_size, seq_len, -1)
        hidden_input = self.out_dropout(hidden_input)
        arg_hidden = self.mlp_arg(hidden_input)
        pred_recur = hidden_input[np.arange(0, self.batch_size), predicates_1D]
        pred_hidden = self.mlp_pred(pred_recur)
        SRL_output = bilinear(arg_hidden,
                              self.rel_W,
                              pred_hidden,
                              self.mlp_size,
                              seq_len,
                              1,
                              self.batch_size,
                              num_outputs=self.target_vocab_size,
                              bias_x=True,
                              bias_y=True)

        SRL_output = SRL_output.view(self.batch_size * seq_len, -1)

        SRL_input = SRL_output.view(self.batch_size, seq_len, -1)
        SRL_input = SRL_input.detach()
        compress_input = torch.cat(
            (input_emb_word.detach(), word_id_emb, SRL_input), 2)
        bilstm_output_word, (_,
                             bilstm_final_state_word) = self.bilstm_layer_word(
                                 compress_input, self.bilstm_hidden_state_word)
        bilstm_output_word = bilstm_output_word.contiguous()
        #hidden_input_word = bilstm_output_word.view(bilstm_output_word.shape[0] * bilstm_output_word.shape[1], -1)
        pred_recur = bilstm_output_word[np.arange(0, self.batch_size),
                                        predicates_1D]
        pred_recur = pred_recur.view(self.batch_size,
                                     self.bilstm_hidden_size * 2)
        pred_recur = pred_recur.unsqueeze(1).expand(
            self.batch_size, seq_len, self.bilstm_hidden_size * 2)
        combine = torch.cat(
            (pred_recur, input_emb_word.detach(), word_id_emb.detach()), 2)
        output_word = self.match_word(combine)
        output_word = output_word.view(self.batch_size * seq_len, -1)
        return SRL_output, output_word
Example #8
File: model.py Project: RuiCaiNLP/SRL_word
    def word_train(self, batch_input):
        unlabeled_data_en, unlabeled_data_fr = batch_input

        pretrain_batch_fr = get_torch_variable_from_np(
            unlabeled_data_fr['pretrain'])
        predicates_1D_fr = unlabeled_data_fr['predicates_idx']
        flag_batch_fr = get_torch_variable_from_np(unlabeled_data_fr['flag'])
        # log(flag_batch_fr)
        word_id_fr = get_torch_variable_from_np(
            unlabeled_data_fr['word_times'])
        word_id_emb_fr = self.id_embedding(word_id_fr).detach()
        flag_emb_fr = self.flag_embedding(flag_batch_fr).detach()
        pretrain_emb_fr = self.fr_pretrained_embedding(
            pretrain_batch_fr).detach()
        pretrain_emb_fr = self.word_matrix(pretrain_emb_fr)
        input_emb_fr = torch.cat((pretrain_emb_fr, flag_emb_fr), 2)
        seq_len_fr = input_emb_fr.shape[1]

        word_batch = get_torch_variable_from_np(unlabeled_data_en['word'])
        pretrain_batch = get_torch_variable_from_np(
            unlabeled_data_en['pretrain'])
        predicates_1D = unlabeled_data_en['predicates_idx']
        flag_batch = get_torch_variable_from_np(unlabeled_data_en['flag'])
        #log(flag_batch)
        word_id = get_torch_variable_from_np(unlabeled_data_en['word_times'])
        word_id_emb = self.id_embedding(word_id)
        flag_emb = self.flag_embedding(flag_batch)
        pretrain_emb = self.pretrained_embedding(pretrain_batch).detach()
        word_id_emb_en = word_id_emb.detach()
        pretrain_emb_en = pretrain_emb
        input_emb = torch.cat((pretrain_emb, flag_emb), 2)
        #input_emb = self.word_dropout(input_emb)
        input_emb_en = input_emb
        seq_len = input_emb.shape[1]
        seq_len_en = seq_len
        bilstm_output, (_, bilstm_final_state) = self.bilstm_layer(
            input_emb, self.bilstm_hidden_state_p)
        bilstm_output = bilstm_output.contiguous()
        hidden_input = bilstm_output.view(
            bilstm_output.shape[0] * bilstm_output.shape[1], -1)
        hidden_input = hidden_input.view(self.batch_size, seq_len, -1)
        arg_hidden = self.mlp_arg(hidden_input)
        pred_recur = hidden_input[np.arange(0, self.batch_size), predicates_1D]
        pred_hidden = self.mlp_pred(pred_recur)
        SRL_output = bilinear(arg_hidden,
                              self.rel_W,
                              pred_hidden,
                              self.mlp_size,
                              seq_len,
                              1,
                              self.batch_size,
                              num_outputs=self.target_vocab_size,
                              bias_x=True,
                              bias_y=True)
        SRL_output = SRL_output.view(self.batch_size * seq_len, -1)

        SRL_input = SRL_output.view(self.batch_size, seq_len, -1)
        compress_input = torch.cat((input_emb, word_id_emb, SRL_input), 2)
        bilstm_output_word, (_,
                             bilstm_final_state_word) = self.bilstm_layer_word(
                                 compress_input,
                                 self.bilstm_hidden_state_word_p)
        bilstm_output_word = bilstm_output_word.contiguous()
        # hidden_input_word = bilstm_output_word.view(bilstm_output_word.shape[0] * bilstm_output_word.shape[1], -1)
        pred_recur = bilstm_output_word[np.arange(0, self.batch_size),
                                        predicates_1D]
        pred_recur = pred_recur.view(self.batch_size,
                                     self.bilstm_hidden_size * 2)
        pred_recur_1 = pred_recur.unsqueeze(1).expand(
            self.batch_size, seq_len, self.bilstm_hidden_size * 2)
        pred_recur_2 = pred_recur.unsqueeze(1).expand(
            self.batch_size, seq_len_fr, self.bilstm_hidden_size * 2)
        pred_recur_en = pred_recur_1
        pred_recur_en_2 = pred_recur_2

        combine = torch.cat((pred_recur_en, input_emb, word_id_emb), 2)
        output_word = self.match_word(combine)
        output_word_en = output_word.view(self.batch_size, seq_len, -1)
        output_word_en = F.softmax(output_word_en, 2)
        max_role_en = torch.max(output_word_en, 1)[0].detach()

        combine_fr = torch.cat(
            (pred_recur_en_2.detach(), input_emb_fr, word_id_emb_fr.detach()),
            2)
        output_word_fr = self.match_word(combine_fr)
        output_word_fr = output_word_fr.view(self.batch_size, seq_len_fr, -1)
        output_word_fr = F.softmax(output_word_fr, 2)
        max_role_fr = torch.max(output_word_fr, 1)[0]
        loss = nn.MSELoss()
        word_loss = loss(max_role_fr, max_role_en)
        return word_loss
Example #9
File: model.py Project: RuiCaiNLP/SRL_word
    def parallel_train(self, batch_input):
        unlabeled_data_en, unlabeled_data_fr = batch_input

        pretrain_batch_fr = get_torch_variable_from_np(
            unlabeled_data_fr['pretrain'])
        predicates_1D_fr = unlabeled_data_fr['predicates_idx']
        flag_batch_fr = get_torch_variable_from_np(unlabeled_data_fr['flag'])
        # log(flag_batch_fr)
        word_id_fr = get_torch_variable_from_np(
            unlabeled_data_fr['word_times'])
        word_id_emb_fr = self.id_embedding(word_id_fr).detach()
        flag_emb_fr = self.flag_embedding(flag_batch_fr).detach()

        pretrain_emb_fr = self.fr_pretrained_embedding(
            pretrain_batch_fr).detach()
        pretrain_emb_fr_matrixed = self.word_matrix(pretrain_emb_fr)
        input_emb_fr = torch.cat((pretrain_emb_fr, flag_emb_fr), 2).detach()
        #input_emb_fr_matrixed = torch.cat((pretrain_emb_fr_matrixed, flag_emb_fr), 2).detach()
        seq_len_fr = input_emb_fr.shape[1]

        pretrain_batch = get_torch_variable_from_np(
            unlabeled_data_en['pretrain'])
        predicates_1D = unlabeled_data_en['predicates_idx']
        flag_batch = get_torch_variable_from_np(unlabeled_data_en['flag'])
        #log(flag_batch)
        word_id = get_torch_variable_from_np(unlabeled_data_en['word_times'])
        word_id_emb = self.id_embedding(word_id)
        flag_emb = self.flag_embedding(flag_batch)
        pretrain_emb = self.pretrained_embedding(pretrain_batch).detach()
        word_id_emb_en = word_id_emb.detach()
        pretrain_emb_en = pretrain_emb
        input_emb = torch.cat((pretrain_emb, flag_emb), 2).detach()
        #input_emb = self.word_dropout(input_emb)
        input_emb_en = input_emb.detach()
        seq_len = input_emb.shape[1]
        seq_len_en = seq_len
        bilstm_output, (_, bilstm_final_state) = self.bilstm_layer(
            input_emb, self.bilstm_hidden_state_p)
        bilstm_output = bilstm_output.contiguous()
        hidden_input = bilstm_output.view(
            bilstm_output.shape[0] * bilstm_output.shape[1], -1)
        hidden_input = hidden_input.view(self.batch_size, seq_len, -1)
        arg_hidden = self.mlp_arg(hidden_input)
        pred_recur = hidden_input[np.arange(0, self.batch_size), predicates_1D]
        pred_hidden = self.mlp_pred(pred_recur)
        SRL_output = bilinear(arg_hidden,
                              self.rel_W,
                              pred_hidden,
                              self.mlp_size,
                              seq_len,
                              1,
                              self.batch_size,
                              num_outputs=self.target_vocab_size,
                              bias_x=True,
                              bias_y=True)
        SRL_output = SRL_output.view(self.batch_size * seq_len, -1)

        SRL_input = SRL_output.view(self.batch_size, seq_len, -1)
        compress_input = torch.cat(
            (input_emb.detach(), word_id_emb.detach(), SRL_input.detach()), 2)
        bilstm_output_word, (_,
                             bilstm_final_state_word) = self.bilstm_layer_word(
                                 compress_input,
                                 self.bilstm_hidden_state_word_p)
        bilstm_output_word = bilstm_output_word.contiguous().detach()
        # hidden_input_word = bilstm_output_word.view(bilstm_output_word.shape[0] * bilstm_output_word.shape[1], -1)
        pred_recur = bilstm_output_word[np.arange(0, self.batch_size),
                                        predicates_1D]
        pred_recur = pred_recur.view(self.batch_size,
                                     self.bilstm_hidden_size * 2)
        pred_recur_1 = pred_recur.unsqueeze(1).expand(
            self.batch_size, seq_len, self.bilstm_hidden_size * 2)
        pred_recur_2 = pred_recur.unsqueeze(1).expand(
            self.batch_size, seq_len_fr, self.bilstm_hidden_size * 2)
        pred_recur_en = pred_recur_1.detach()
        pred_recur_en_2 = pred_recur_2.detach()
        """
        En event vector, En word
        """
        combine = torch.cat(
            (pred_recur_en.detach(), input_emb.detach(), word_id_emb.detach()),
            2)
        output_word = self.match_word(combine)
        output_word_en = output_word.view(self.batch_size * seq_len,
                                          -1).detach()

        bilstm_output_fr, (_, bilstm_final_state) = self.bilstm_layer(
            input_emb_fr, self.bilstm_hidden_state_p)
        bilstm_output_fr = bilstm_output_fr.contiguous()
        hidden_input_fr = bilstm_output_fr.view(
            bilstm_output_fr.shape[0] * bilstm_output_fr.shape[1], -1)
        hidden_input_fr = hidden_input_fr.view(self.batch_size, seq_len_fr, -1)
        arg_hidden_fr = self.mlp_arg(hidden_input_fr)
        pred_recur_fr = hidden_input_fr[np.arange(0, self.batch_size),
                                        predicates_1D_fr]
        pred_hidden_fr = self.mlp_pred(pred_recur_fr)
        SRL_output_fr = bilinear(arg_hidden_fr,
                                 self.rel_W,
                                 pred_hidden_fr,
                                 self.mlp_size,
                                 seq_len_fr,
                                 1,
                                 self.batch_size,
                                 num_outputs=self.target_vocab_size,
                                 bias_x=True,
                                 bias_y=True)
        SRL_output_fr = SRL_output_fr.view(self.batch_size * seq_len_fr, -1)

        SRL_input_fr = SRL_output_fr.view(self.batch_size, seq_len_fr, -1)
        compress_input_fr = torch.cat(
            (input_emb_fr.detach(), word_id_emb_fr.detach(), SRL_input_fr), 2)
        bilstm_output_word_fr, (
            _, bilstm_final_state_word) = self.bilstm_layer_word(
                compress_input_fr, self.bilstm_hidden_state_word_p)
        bilstm_output_word_fr = bilstm_output_word_fr.contiguous()
        pred_recur_fr = bilstm_output_word_fr[np.arange(0, self.batch_size),
                                              predicates_1D_fr]
        pred_recur_fr = pred_recur_fr.view(self.batch_size,
                                           self.bilstm_hidden_size * 2)

        #############################################
        """
        Fr event vector, En word
        """

        pred_recur_fr_1 = pred_recur_fr.unsqueeze(1).expand(
            self.batch_size, seq_len_en, self.bilstm_hidden_size * 2)
        combine = torch.cat(
            (pred_recur_fr_1, input_emb_en.detach(), word_id_emb_en.detach()),
            2)
        output_word_fr = self.match_word(combine)
        output_word_fr = output_word_fr.view(self.batch_size * seq_len_en, -1)

        unlabeled_loss_function = nn.KLDivLoss(reduction='sum')
        output_word_en = F.softmax(output_word_en, dim=1).detach()
        output_word_fr = F.log_softmax(output_word_fr, dim=1)
        loss = unlabeled_loss_function(output_word_fr, output_word_en) / (
            seq_len_en * self.para_batch_size)

        #############################################3
        """
        En event vector, Fr word
        """
        combine = torch.cat((pred_recur_en_2.detach(), input_emb_fr.detach(),
                             word_id_emb_fr.detach()), 2)
        output_word = self.match_word(combine)
        output_word_en_2 = output_word.view(self.batch_size * seq_len_fr, -1)
        """
        Fr event vector, Fr word
        """
        pred_recur_fr_2 = pred_recur_fr.unsqueeze(1).expand(
            self.batch_size, seq_len_fr, self.bilstm_hidden_size * 2)
        combine = torch.cat(
            (pred_recur_fr_2, input_emb_fr.detach(), word_id_emb_fr.detach()),
            2)
        output_word_fr_2 = self.match_word(combine)
        output_word_fr_2 = output_word_fr_2.view(self.batch_size * seq_len_fr,
                                                 -1)

        unlabeled_loss_function = nn.KLDivLoss(reduction='sum')
        output_word_en_2 = F.softmax(output_word_en_2, dim=1).detach()
        output_word_fr_2 = F.log_softmax(output_word_fr_2, dim=1)
        loss_2 = unlabeled_loss_function(
            output_word_fr_2,
            output_word_en_2) / (seq_len_fr * self.para_batch_size)
        return loss, loss_2
Example #10
    def run(self, words, tags, heads, rels, masks_w, masks_t, isTrain):
        if config.biaffine:
            mlp_dep_bias = dy.parameter(self.mlp_dep_bias)
            mlp_dep = dy.parameter(self.mlp_dep)
            mlp_head_bias = dy.parameter(self.mlp_head_bias)
            mlp_head = dy.parameter(self.mlp_head)
            W_arc = dy.parameter(self.W_arc)
            W_rel = dy.parameter(self.W_rel)

        #tokens in the sentence and root
        seq_len = len(words) + 1

        punct_mask = np.array(
            [1 if rel != self._punct_id else 0 for rel in rels],
            dtype=np.uint32)

        preds_arc = []
        preds_rel = []

        loss_arc = 0
        loss_rel = 0

        num_cor_arc = 0
        num_cor_rel = 0

        if isTrain:
            # embs_w = [self.lp_w[w if w < self._vocab_size_w else 0] * mask_w for w, mask_w in zip(words, masks_w)]
            # embs_t = [self.lp_t[t if t < self._vocab_size_t else 0] * mask_t for t, mask_t in zip(tags, masks_t)]
            embs_w = [
                self.lp_w[w] * mask_w for w, mask_w in zip(words, masks_w)
            ]
            embs_t = [
                self.lp_t[t] * mask_t for t, mask_t in zip(tags, masks_t)
            ]
            embs_w = [self.emb_root[0] * masks_t[-1]] + embs_w
            embs_t = [self.emb_root[1] * masks_w[-1]] + embs_t

        else:
            # embs_w = [self.lp_w[w if w < self._vocab_size_w else 0] for w in words]
            # embs_t = [self.lp_t[t if t < self._vocab_size_t else 0] for t in tags]
            embs_w = [self.lp_w[w] for w in words]
            embs_t = [self.lp_t[t] for t in tags]
            embs_w = [self.emb_root[0]] + embs_w
            embs_t = [self.emb_root[1]] + embs_t

        lstm_ins = [
            dy.concatenate([emb_w, emb_t])
            for emb_w, emb_t in zip(embs_w, embs_t)
        ]
        # lstm_outs = dy.concatenate_cols([self.emb_root[0]] + utils.bilstm(self.l2r_lstm, self.r2l_lstm, lstm_ins, self._pdrop))
        # lstm_outs = dy.concatenate_cols(utils.bilstm(self.LSTM_builders[0], self.LSTM_builders[1], lstm_ins, self._pdrop_lstm))
        lstm_outs = dy.concatenate_cols(
            utils.biLSTM(self.LSTM_builders, lstm_ins, None, self._pdrop_lstm,
                         self._pdrop_lstm))

        # if isTrain:
        #     lstm_outs = dy.dropout(lstm_outs, self._pdrop)

        if config.biaffine:
            embs_dep, embs_head = \
                utils.leaky_relu(dy.affine_transform([mlp_dep_bias, mlp_dep, lstm_outs])), \
                utils.leaky_relu(dy.affine_transform([mlp_head_bias, mlp_head, lstm_outs]))

            if isTrain:
                embs_dep, embs_head = dy.dropout(embs_dep,
                                                 self._pdrop_mlp), dy.dropout(
                                                     embs_head,
                                                     self._pdrop_mlp)

            dep_arc, dep_rel = embs_dep[:self._arc_dim], embs_dep[self._arc_dim:]
            head_arc, head_rel = embs_head[:self._arc_dim], embs_head[self._arc_dim:]

            logits_arc = utils.bilinear(dep_arc, W_arc, head_arc,
                                        self._arc_dim, seq_len,
                                        config.batch_size, 1,
                                        self.biaffine_bias_x_arc,
                                        self.biaffine_bias_y_arc)
        else:
            mlp = dy.parameter(self.mlp)
            mlp_bias = dy.parameter(self.mlp_bias)

            embs = \
                utils.leaky_relu(dy.affine_transform([mlp_bias, mlp, lstm_outs]))
            if isTrain:
                embs = dy.dropout(embs, self._pdrop_mlp)

            embs_arc, embs_rel = embs[:self._arc_dim * 2], embs[self._arc_dim * 2:]

            W_r_arc = dy.parameter(self.V_r_arc)
            W_i_arc = dy.parameter(self.V_i_arc)
            bias_arc = dy.parameter(self.bias_arc)

            logits_arc = utils.biED(embs_arc,
                                    W_r_arc,
                                    W_i_arc,
                                    embs_arc,
                                    seq_len,
                                    1,
                                    bias=bias_arc)

        # flat_logits_arc = dy.reshape(logits_arc[:][1:], (seq_len,), seq_len - 1)
        flat_logits_arc = dy.reshape(logits_arc, (seq_len, ), seq_len)
        # flat_logits_arc = dy.pick_batch_elems(flat_logits_arc, [e for e in range(1, seq_len)])
        flat_logits_arc = dy.pick_batch_elems(
            flat_logits_arc, np.arange(1, seq_len, dtype='int32'))

        loss_arc = dy.pickneglogsoftmax_batch(flat_logits_arc, heads)

        if not isTrain:
            # msk = [1] * seq_len
            msk = np.ones((seq_len), dtype='int32')
            arc_probs = dy.softmax(logits_arc).npvalue()
            arc_probs = np.transpose(arc_probs)
            preds_arc = utils.arc_argmax(arc_probs,
                                         seq_len,
                                         msk,
                                         ensure_tree=True)

            # preds_arc = logits_arc.npvalue().argmax(0)
            cor_arcs = np.multiply(np.equal(preds_arc[1:], heads), punct_mask)
            num_cor_arc = np.sum(cor_arcs)

        if not config.las:
            return loss_arc, num_cor_arc, num_cor_rel

        if config.biaffine:
            logits_rel = utils.bilinear(dep_rel, W_rel, head_rel,
                                        self._rel_dim, seq_len, 1,
                                        self._vocab_size_r,
                                        self.biaffine_bias_x_rel,
                                        self.biaffine_bias_y_rel)
        else:
            V_r_rel = dy.parameter(self.V_r_rel)
            V_i_rel = dy.parameter(self.V_i_rel)
            bias_rel = dy.parameter(self.bias_rel)

            logits_rel = utils.biED(embs_rel,
                                    V_r_rel,
                                    V_i_rel,
                                    embs_rel,
                                    seq_len,
                                    self._vocab_size_r,
                                    bias=bias_rel)

        # flat_logits_rel = dy.reshape(logits_rel[:][1:], (seq_len, self._vocab_size_r), seq_len - 1)
        flat_logits_rel = dy.reshape(logits_rel, (seq_len, self._vocab_size_r),
                                     seq_len)
        # flat_logits_rel = dy.pick_batch_elems(flat_logits_rel, [e for e in range(1, seq_len)])
        flat_logits_rel = dy.pick_batch_elems(
            flat_logits_rel, np.arange(1, seq_len, dtype='int32'))

        partial_rel_logits = dy.pick_batch(flat_logits_rel,
                                           heads if isTrain else preds_arc[1:])

        if isTrain:
            loss_rel = dy.sum_batches(
                dy.pickneglogsoftmax_batch(partial_rel_logits, rels))
        else:
            preds_rel = partial_rel_logits.npvalue().argmax(0)
            num_cor_rel = np.sum(
                np.multiply(np.equal(preds_rel, rels), cor_arcs))
        return loss_arc + loss_rel, num_cor_arc, num_cor_rel