# Imports assumed by all snippets below; USE_CUDA and project helpers
# (inter_utils, log, get_data, sem_f1_score, bilinear, gaussian, etc.) are
# defined elsewhere in the repository.
import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable


def forward(self, batch_input, lang):
    word_batch = get_torch_variable_from_np(batch_input['word'])
    pretrain_batch = get_torch_variable_from_np(batch_input['pretrain'])
    if lang == "En":
        pretrain_emb = self.pretrained_embedding(pretrain_batch).detach()
    else:
        pretrain_emb = self.fr_pretrained_embedding(pretrain_batch).detach()
    if lang == "En":
        word_emb = self.word_embedding(word_batch)
    else:
        word_emb = self.fr_word_embedding(word_batch)
    input_emb = torch.cat((pretrain_emb, word_emb), 2)
    input_emb = self.word_dropout(input_emb)
    seq_len = input_emb.shape[1]
    bilstm_output, (_, bilstm_final_state) = self.bilstm_layer(
        input_emb, self.bilstm_hidden_state)
    bilstm_output = bilstm_output.contiguous()
    hidden_input = bilstm_output.view(
        bilstm_output.shape[0] * bilstm_output.shape[1], -1)
    hidden_input = hidden_input.view(self.batch_size, seq_len, -1)
    hidden_input = self.out_dropout(hidden_input)
    output = self.scorer(hidden_input)
    output = output.view(self.batch_size * seq_len, -1)
    return output
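# The helper get_torch_variable_from_np is called throughout these snippets
# but never shown. A minimal sketch consistent with how it is used (NumPy
# array in, device-placed tensor out) might look like the following; this is
# an assumed reconstruction, not the original implementation, and it relies
# on the module-level USE_CUDA flag used elsewhere in this code.
def get_torch_variable_from_np(np_array):
    # NumPy dtype is preserved: int64 arrays become LongTensors (needed for
    # embedding lookups and index_select), float32 arrays FloatTensors.
    tensor = torch.from_numpy(np_array)
    if USE_CUDA:
        tensor = tensor.cuda()
    return tensor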
def __init__(self, model_params):
    super(SR_Matcher, self).__init__()
    self.dropout_word = nn.Dropout(p=0.3)
    self.mlp_size = 300
    self.dropout_mlp = model_params['dropout_mlp']
    self.batch_size = model_params['batch_size']
    self.target_vocab_size = model_params['target_vocab_size']
    self.use_flag_embedding = model_params['use_flag_embedding']
    self.flag_emb_size = model_params['flag_embedding_size']
    self.pretrain_emb_size = 768  # model_params['pretrain_emb_size']
    self.bilstm_num_layers = model_params['bilstm_num_layers']
    self.bilstm_hidden_size = model_params['bilstm_hidden_size']
    self.emb2vector = nn.Sequential(
        nn.Linear(self.pretrain_emb_size + 0 * self.flag_emb_size, 200),
        nn.Tanh())
    self.matrix = nn.Parameter(
        get_torch_variable_from_np(np.zeros((200, 200)).astype("float32")))
    self.probs2vector = nn.Sequential(
        nn.Linear(self.target_vocab_size, self.target_vocab_size),
        nn.Tanh())
    self.vector2probs = nn.Sequential(
        nn.Linear(self.target_vocab_size, self.target_vocab_size),
        nn.Tanh(),
        nn.Linear(self.target_vocab_size, self.target_vocab_size - 1))
def learn_loss(self, output_SRL, pretrain_emb, flag_emb, seq_len, mask_copy,
               mask_unk):
    SRL_input = output_SRL.view(self.batch_size, seq_len, -1)
    output = SRL_input.view(self.batch_size * seq_len, -1)
    SRL_input = F.softmax(SRL_input, 2).detach()
    pred = torch.max(SRL_input, dim=2)[1]
    for i in range(self.batch_size):
        for j in range(seq_len):
            if pred[i][j] > 1:
                mask_copy[i][j] = 1
    mask_copy = get_torch_variable_from_np(mask_copy)
    mask_final = mask_copy.view(self.batch_size * seq_len) \
        * mask_unk.view(self.batch_size * seq_len)
    pred_recur = self.SR_Compressor(SRL_input, pretrain_emb,
                                    flag_emb.detach(), None, None, seq_len,
                                    para=False, use_bert=True)
    output_word = self.SR_Matcher(pred_recur, pretrain_emb,
                                  flag_emb.detach(), None, seq_len,
                                  copy=True, para=False, use_bert=True)
    score4Null = torch.zeros_like(output_word[:, 1:2])
    output_word = torch.cat(
        (output_word[:, 0:1], score4Null, output_word[:, 1:]), 1)
    criterion = nn.CrossEntropyLoss(reduction='none')
    _, prediction_batch_variable = torch.max(output, 1)
    loss_word = criterion(output_word, prediction_batch_variable) * mask_final
    loss_word = loss_word.sum() / (self.batch_size * seq_len)
    return loss_word
def forward(self, role_vectors, pretrained_emb, word_id_emb, seq_len,
            para=False):
    query_vector = torch.cat((pretrained_emb, word_id_emb), 2)
    role_vectors = role_vectors.view(self.batch_size,
                                     self.target_vocab_size, 200)
    # B T R V
    role_vectors = role_vectors.unsqueeze(1).expand(
        self.batch_size, seq_len, self.target_vocab_size, 200)
    # B T R W
    query_vector = query_vector.unsqueeze(2).expand(
        self.batch_size, seq_len, self.target_vocab_size,
        self.pretrain_emb_size + self.flag_emb_size)
    # B T R V
    y = torch.mm(
        query_vector.contiguous().view(
            self.batch_size * seq_len * self.target_vocab_size, -1),
        self.matrix)
    y = y.contiguous().view(self.batch_size, seq_len,
                            self.target_vocab_size, 200)
    # B T R
    roles_scores = torch.sum(role_vectors * y, dim=3)
    zerosNull = get_torch_variable_from_np(
        np.zeros((self.batch_size, seq_len, 1), dtype='float32'))
    roles_scores = roles_scores.view(self.batch_size, seq_len, -1)
    output_word = torch.cat(
        (roles_scores[:, :, 0:1], zerosNull, roles_scores[:, :, 2:]), 2)
    output_word = output_word.view(self.batch_size * seq_len, -1)
    return output_word
def __init__(self, model_params):
    super(SR_Matcher, self).__init__()
    # self.dropout_word_1 = nn.Dropout(p=0.0)
    # self.dropout_word_2 = nn.Dropout(p=0.0)
    self.mlp_size = 300
    # self.dropout_mlp = model_params['dropout_mlp']
    self.batch_size = model_params['batch_size']
    self.target_vocab_size = model_params['target_vocab_size']
    self.use_flag_embedding = model_params['use_flag_embedding']
    self.flag_emb_size = model_params['flag_embedding_size']
    self.pretrain_emb_size = model_params['pretrain_emb_size']
    self.bilstm_num_layers = model_params['bilstm_num_layers']
    self.bilstm_hidden_size = model_params['bilstm_hidden_size']
    self.bert_size = 768
    self.base_emb2vector = nn.Sequential(nn.Linear(self.bert_size, 300),
                                         nn.Tanh())
    self.query_emb2vector = nn.Sequential(nn.Linear(self.bert_size, 300),
                                          nn.Tanh())
    self.probs2vector = nn.Sequential(
        nn.Linear(self.target_vocab_size, 100), nn.Tanh(),
        nn.Linear(100, 50), nn.Tanh())
    self.vector2scores = nn.Sequential(
        nn.Linear(50, 30), nn.Tanh(),
        nn.Linear(30, self.target_vocab_size))
    self.matrix = nn.Parameter(
        get_torch_variable_from_np(np.zeros((300, 300)).astype("float32")))
def Input4Gan_1(output, predicates_1D):
    output = F.softmax(output, dim=2)
    seq_len = output.shape[1]
    shuffled_timestep = np.arange(0, seq_len)
    np.random.shuffle(shuffled_timestep)
    # index_select returns a new tensor that indexes the input tensor along
    # dimension dim using the entries in index, which is a LongTensor.
    shuffled_output = output.index_select(
        dim=1, index=get_torch_variable_from_np(shuffled_timestep))
    return shuffled_output
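# A minimal usage sketch for Input4Gan_1 under assumed shapes: a batch of
# per-timestep role scores (batch, seq_len, n_roles) comes back softmaxed
# with its timesteps permuted, so a downstream discriminator cannot key on
# word order. The shapes here are illustrative, not from the original code;
# predicates_1D is unused by the function body and may be None.
scores = get_torch_variable_from_np(
    np.random.randn(2, 5, 10).astype("float32"))  # (batch, seq_len, roles)
shuffled = Input4Gan_1(scores, predicates_1D=None)
assert shuffled.shape == scores.shape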
def forward(self, batch_input, lang='En', unlabeled=False):
    if unlabeled:
        loss = self.parallel_train(batch_input)
        loss_word = 0
        return loss, loss_word
    pretrain_batch = get_torch_variable_from_np(batch_input['pretrain'])
    predicates_1D = batch_input['predicates_idx']
    flag_batch = get_torch_variable_from_np(batch_input['flag'])
    word_id = get_torch_variable_from_np(batch_input['word_times'])
    word_id_emb = self.id_embedding(word_id)
    flag_emb = self.flag_embedding(flag_batch)
    if lang == "En":
        pretrain_emb = self.pretrained_embedding(pretrain_batch).detach()
    else:
        pretrain_emb = self.fr_pretrained_embedding(pretrain_batch).detach()
    seq_len = flag_emb.shape[1]
    SRL_output = self.SR_Labeler(pretrain_emb, flag_emb, predicates_1D,
                                 seq_len, para=False)
    SRL_input = SRL_output.view(self.batch_size, seq_len, -1)
    SRL_input = SRL_input.detach()
    pred_recur = self.SR_Compressor(SRL_input, pretrain_emb, word_id_emb,
                                    seq_len, para=False)
    output_word = self.SR_Matcher(pred_recur, pretrain_emb,
                                  word_id_emb.detach(), seq_len, para=False)
    return SRL_output, output_word
def __init__(self, model_params):
    super(SR_Labeler, self).__init__()
    self.dropout_word = nn.Dropout(p=0.5)
    self.dropout_hidden = nn.Dropout(p=0.3)
    self.dropout_mlp = model_params['dropout_mlp']
    self.batch_size = model_params['batch_size']
    self.target_vocab_size = model_params['target_vocab_size']
    self.use_flag_embedding = model_params['use_flag_embedding']
    self.flag_emb_size = model_params['flag_embedding_size']
    self.pretrain_emb_size = model_params['pretrain_emb_size']
    self.bilstm_num_layers = model_params['bilstm_num_layers']
    self.bilstm_hidden_size = model_params['bilstm_hidden_size']
    if USE_CUDA:
        self.bilstm_hidden_state = (
            Variable(torch.randn(2 * self.bilstm_num_layers, self.batch_size,
                                 self.bilstm_hidden_size),
                     requires_grad=True).cuda(),
            Variable(torch.randn(2 * self.bilstm_num_layers, self.batch_size,
                                 self.bilstm_hidden_size),
                     requires_grad=True).cuda())
    else:
        self.bilstm_hidden_state = (
            Variable(torch.randn(2 * self.bilstm_num_layers, self.batch_size,
                                 self.bilstm_hidden_size),
                     requires_grad=True),
            Variable(torch.randn(2 * self.bilstm_num_layers, self.batch_size,
                                 self.bilstm_hidden_size),
                     requires_grad=True))
    self.bilstm_layer = nn.LSTM(input_size=768 + 1 * self.flag_emb_size,
                                hidden_size=self.bilstm_hidden_size,
                                num_layers=self.bilstm_num_layers,
                                bidirectional=True,
                                bias=True,
                                batch_first=True)
    self.mlp_size = 300
    self.rel_W = nn.Parameter(
        get_torch_variable_from_np(
            np.zeros((self.mlp_size + 1,
                      self.target_vocab_size * (self.mlp_size + 1))
                     ).astype("float32")))
    self.mlp_arg = nn.Sequential(
        nn.Linear(2 * self.bilstm_hidden_size, self.mlp_size), nn.ReLU())
    self.mlp_pred = nn.Sequential(
        nn.Linear(2 * self.bilstm_hidden_size, self.mlp_size), nn.ReLU())
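# The bilinear scorer used with rel_W above is not defined in this section.
# Below is a sketch consistent with the call sites here, following the usual
# biaffine formulation (append a bias column to both inputs, then score
# every (argument, role) pair against the predicate vector). It assumes W is
# shaped (input_size + 1, num_outputs * (input_size + 1)) as rel_W is, and
# that y_seq_len == 1 at all call sites; treat it as an assumed
# reconstruction, not the original helper.
def bilinear(x, W, y, input_size, x_seq_len, y_seq_len, batch_size,
             num_outputs=1, bias_x=False, bias_y=False):
    # x: (batch, x_seq_len, input_size) argument representations
    # y: (batch, input_size) predicate representation
    if bias_x:
        x = torch.cat((x, x.new_ones(batch_size, x_seq_len, 1)), dim=2)
    if bias_y:
        y = torch.cat((y, y.new_ones(batch_size, 1)), dim=1)
    d = x.size(2)
    # (batch*x_seq_len, d) @ (d, num_outputs*d) -> (batch, x_seq_len*n, d)
    lin = x.contiguous().view(-1, d).mm(W).view(
        batch_size, x_seq_len * num_outputs, -1)
    # (batch, x_seq_len*n, d) @ (batch, d, 1) -> (batch, x_seq_len, n)
    blin = lin.bmm(y.unsqueeze(2)).view(batch_size, x_seq_len, num_outputs)
    return blin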
def forward(self, tree, inputs):  # , tree_hidden
    for idx in range(tree.num_children):
        self.forward(tree.children[idx], inputs)  # , tree_hidden
    if tree.num_children == 0:
        child_h = Variable(inputs[0].data.new(1, self.mem_dim).fill_(0.))
        child_deprels = Variable(
            inputs[0].data.new(1, self.mem_dim).fill_(0.))
    else:
        child_h = [x.state for x in tree.children]
        child_h = torch.cat(child_h, dim=0)
        child_deprels = self.deprel_emb[get_torch_variable_from_np(
            np.array([c.deprel for c in tree.children]))]
    tree.state = self.node_forward(inputs[tree.idx], child_h, child_deprels)
    return tree.state
def __init__(self, model_params):
    super(SR_Matcher, self).__init__()
    self.dropout_word = nn.Dropout(p=0.0)
    self.mlp_size = 300
    self.dropout_mlp = model_params['dropout_mlp']
    self.batch_size = model_params['batch_size']
    self.target_vocab_size = model_params['target_vocab_size']
    self.use_flag_embedding = model_params['use_flag_embedding']
    self.flag_emb_size = model_params['flag_embedding_size']
    self.pretrain_emb_size = model_params['pretrain_emb_size']
    self.bilstm_num_layers = model_params['bilstm_num_layers']
    self.bilstm_hidden_size = model_params['bilstm_hidden_size']
    self.matrix = nn.Parameter(
        get_torch_variable_from_np(
            np.zeros((self.pretrain_emb_size + self.flag_emb_size,
                      200)).astype("float32")))
def forward(self, tree, inputs):  # , tree_hidden
    for idx in range(tree.num_children):
        self.forward(tree.children[idx], inputs)  # , tree_hidden
    if tree.num_children == 0:
        child_c = Variable(inputs[0].data.new(1, self.mem_dim).fill_(0.))
        child_h = Variable(inputs[0].data.new(1, self.mem_dim).fill_(0.))
        deprels = None
    else:
        child_c, child_h = zip(*map(lambda x: x.state, tree.children))
        child_c = torch.cat(child_c, dim=0)
        child_h = torch.cat(child_h, dim=0)
        deprels = get_torch_variable_from_np(
            np.array([c.deprel for c in tree.children]))
    tree.state = self.node_forward(inputs[tree.idx], child_c, child_h,
                                   deprels)
    # tree_hidden[tree.idx] = tree.state
    return tree.state
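# node_forward is not shown in this section. For orientation, a standalone
# child-sum Tree-LSTM cell in the style of Tai et al. (2015) is sketched
# below; the original method likely also consumes the deprels argument,
# which this sketch ignores, and all layer names here (W_iou, U_iou, W_f,
# U_f) are hypothetical nn.Linear modules, not names from the source.
def child_sum_node_forward(x, child_c, child_h, W_iou, U_iou, W_f, U_f):
    # x: (1, in_dim) input for the current node
    # child_c, child_h: (num_children, mem_dim) stacked child states
    h_sum = torch.sum(child_h, dim=0, keepdim=True)
    iou = W_iou(x) + U_iou(h_sum)                      # (1, 3 * mem_dim)
    i, o, u = torch.chunk(iou, 3, dim=1)
    i, o, u = torch.sigmoid(i), torch.sigmoid(o), torch.tanh(u)
    # One forget gate per child; W_f(x) broadcasts over children.
    f = torch.sigmoid(W_f(x) + U_f(child_h))           # (num_children, mem_dim)
    c = i * u + torch.sum(f * child_c, dim=0, keepdim=True)
    h = o * torch.tanh(c)
    return c, h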
def __init__(self, model_params):
    super(SR_Matcher, self).__init__()
    # self.dropout_word_1 = nn.Dropout(p=0.0)
    # self.dropout_word_2 = nn.Dropout(p=0.0)
    self.mlp_size = 300
    # self.dropout_mlp = model_params['dropout_mlp']
    self.batch_size = model_params['batch_size']
    self.target_vocab_size = model_params['target_vocab_size']
    self.use_flag_embedding = model_params['use_flag_embedding']
    self.flag_emb_size = model_params['flag_embedding_size']
    self.pretrain_emb_size = model_params['pretrain_emb_size']
    self.bilstm_num_layers = model_params['bilstm_num_layers']
    self.bilstm_hidden_size = model_params['bilstm_hidden_size']
    self.compress_emb = nn.Sequential(nn.Linear(768, 128), nn.Tanh())
    self.matrix = nn.Parameter(
        get_torch_variable_from_np(np.zeros((256, 256)).astype("float32")))
    self.specific_NULL_emb = nn.Sequential(nn.Linear(768, 256), nn.Tanh(),
                                           nn.Linear(256, 128), nn.Tanh())
    self.scorer = nn.Sequential(nn.Linear(256, 64), nn.Tanh(),
                                nn.Linear(64, 1))
def forward(self, batch_input, lang='En', unlabeled=False, use_bert=False,
            isTrain=True):
    if unlabeled:
        loss = self.parallel_train(batch_input)
        loss_word = 0
        return loss, loss_word
    pretrain_batch = get_torch_variable_from_np(batch_input['pretrain'])
    predicates_1D = batch_input['predicates_idx']
    flag_batch = get_torch_variable_from_np(batch_input['flag'])
    word_id = get_torch_variable_from_np(batch_input['word_times'])
    word_id_emb = self.id_embedding(word_id)
    flag_emb = self.flag_embedding(flag_batch)
    actual_lens = batch_input['seq_len']
    if lang == "En":
        pretrain_emb = self.pretrained_embedding(pretrain_batch).detach()
    else:
        pretrain_emb = self.fr_pretrained_embedding(pretrain_batch).detach()
    bert_input_ids = get_torch_variable_from_np(batch_input['bert_input_ids'])
    bert_input_mask = get_torch_variable_from_np(
        batch_input['bert_input_mask'])
    bert_out_positions = get_torch_variable_from_np(
        batch_input['bert_out_positions'])
    bert_emb = self.model(bert_input_ids, attention_mask=bert_input_mask)
    bert_emb = bert_emb[0]
    # Drop [CLS]/[SEP], then gather one sub-token vector per word.
    bert_emb = bert_emb[:, 1:-1, :].contiguous().detach()
    bert_emb = bert_emb[torch.arange(bert_emb.size(0)).unsqueeze(-1),
                        bert_out_positions].detach()
    # Zero out positions beyond each sentence's actual length.
    for i in range(len(bert_emb)):
        if i >= len(actual_lens):
            break
        for j in range(len(bert_emb[i])):
            if j >= actual_lens[i]:
                bert_emb[i][j] = get_torch_variable_from_np(
                    np.zeros(768, dtype="float32"))
    bert_emb = bert_emb.detach()
    pretrain_emb = bert_emb
    seq_len = flag_emb.shape[1]
    SRL_output = self.SR_Labeler(pretrain_emb, flag_emb, predicates_1D,
                                 seq_len, para=False)
    SRL_input = SRL_output.view(self.batch_size, seq_len, -1)
    SRL_input = F.softmax(SRL_input, 2).detach()
    pred_recur = self.SR_Compressor(pretrain_emb, word_id_emb, seq_len,
                                    para=False)
    output_word = self.SR_Matcher(pred_recur, SRL_input, pretrain_emb,
                                  word_id_emb.detach(), seq_len, para=False)
    score4Null = torch.zeros_like(output_word[:, 1:2])
    output_word = torch.cat(
        (output_word[:, 0:1], score4Null, output_word[:, 1:]), 1)
    teacher = SRL_input.view(self.batch_size * seq_len, -1).detach()
    student = torch.log_softmax(output_word, 1)
    unlabeled_loss_function = nn.KLDivLoss(reduction='none')
    loss_copy = unlabeled_loss_function(student, teacher)
    loss_copy = loss_copy.sum() / (self.batch_size * seq_len)
    return SRL_output, output_word, loss_copy
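# The teacher/student pattern above (and in the consistency losses further
# down) follows PyTorch's KLDivLoss convention: the student passes through
# log_softmax, the teacher through softmax with gradients detached, and the
# elementwise loss is summed over the label axis and averaged over
# batch * seq_len positions. A self-contained sketch with illustrative
# shapes (4 sentences, 7 tokens, 10 labels):
teacher_logits = torch.randn(4 * 7, 10)   # (batch*seq_len, n_labels)
student_logits = torch.randn(4 * 7, 10)
teacher = F.softmax(teacher_logits, dim=1).detach()
student = torch.log_softmax(student_logits, dim=1)
kl = nn.KLDivLoss(reduction='none')(student, teacher)
loss = kl.sum() / (4 * 7)  # sum over labels, mean over positions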
def parallel_train(self, batch_input):
    unlabeled_data_en, unlabeled_data_fr = batch_input
    pretrain_batch_fr = get_torch_variable_from_np(
        unlabeled_data_fr['pretrain'])
    predicates_1D_fr = unlabeled_data_fr['predicates_idx']
    flag_batch_fr = get_torch_variable_from_np(unlabeled_data_fr['flag'])
    word_id_fr = get_torch_variable_from_np(unlabeled_data_fr['word_times'])
    word_id_emb_fr = self.id_embedding(word_id_fr).detach()
    flag_emb_fr = self.flag_embedding(flag_batch_fr).detach()
    pretrain_emb_fr = self.fr_pretrained_embedding(pretrain_batch_fr).detach()
    pretrain_batch = get_torch_variable_from_np(unlabeled_data_en['pretrain'])
    predicates_1D = unlabeled_data_en['predicates_idx']
    flag_batch = get_torch_variable_from_np(unlabeled_data_en['flag'])
    word_id = get_torch_variable_from_np(unlabeled_data_en['word_times'])
    word_id_emb = self.id_embedding(word_id)
    flag_emb = self.flag_embedding(flag_batch)
    seq_len = flag_emb.shape[1]
    seq_len_en = seq_len
    pretrain_emb = self.pretrained_embedding(pretrain_batch).detach()
    SRL_output = self.SR_Labeler(pretrain_emb, flag_emb.detach(),
                                 predicates_1D, seq_len, para=True)
    pred_recur = self.SR_Compressor(pretrain_emb, word_id_emb.detach(),
                                    seq_len, para=True)
    seq_len_fr = flag_emb_fr.shape[1]
    SRL_output_fr = self.SR_Labeler(pretrain_emb_fr, flag_emb_fr.detach(),
                                    predicates_1D_fr, seq_len_fr, para=True)
    pred_recur_fr = self.SR_Compressor(pretrain_emb_fr,
                                       word_id_emb_fr.detach(), seq_len_fr,
                                       para=True)
    # En event vector, En word
    output_word_en = self.SR_Matcher(pred_recur.detach(),
                                     SRL_output.detach(), pretrain_emb,
                                     word_id_emb.detach(), seq_len, para=True)
    # Fr event vector, En word
    output_word_fr = self.SR_Matcher(pred_recur_fr.detach(), SRL_output_fr,
                                     pretrain_emb, word_id_emb.detach(),
                                     seq_len, para=True)
    # size_average=False in the original is the deprecated spelling of
    # reduction='sum'.
    unlabeled_loss_function = nn.KLDivLoss(reduction='sum')
    output_word_en = F.softmax(output_word_en, dim=1).detach()
    output_word_fr = F.log_softmax(output_word_fr, dim=1)
    loss = unlabeled_loss_function(output_word_fr, output_word_en) \
        / (seq_len_en * self.batch_size)
    # En event vector, Fr word
    output_word_en = self.SR_Matcher(pred_recur.detach(),
                                     SRL_output.detach(), pretrain_emb_fr,
                                     word_id_emb_fr.detach(), seq_len_fr,
                                     para=True)
    # Fr event vector, Fr word
    output_word_fr = self.SR_Matcher(pred_recur_fr.detach(), SRL_output_fr,
                                     pretrain_emb_fr, word_id_emb_fr.detach(),
                                     seq_len_fr, para=True)
    unlabeled_loss_function = nn.KLDivLoss(reduction='sum')
    output_word_en = F.softmax(output_word_en, dim=1).detach()
    output_word_fr = F.log_softmax(output_word_fr, dim=1)
    loss_2 = unlabeled_loss_function(output_word_fr, output_word_en) \
        / (seq_len_fr * self.batch_size)
    return loss, loss_2
def forward(self, batch_input, elmo, withParallel=True, lang='En'):
    if lang == 'En':
        word_batch = get_torch_variable_from_np(batch_input['word'])
        pretrain_batch = get_torch_variable_from_np(batch_input['pretrain'])
    else:
        word_batch = get_torch_variable_from_np(batch_input['word'])
        pretrain_batch = get_torch_variable_from_np(batch_input['pretrain'])
    flag_batch = get_torch_variable_from_np(batch_input['flag'])
    pos_batch = get_torch_variable_from_np(batch_input['pos'])
    pos_emb = self.pos_embedding(pos_batch)
    role_index = get_torch_variable_from_np(batch_input['role_index'])
    role_mask = get_torch_variable_from_np(batch_input['role_mask'])
    role_mask_expand = role_mask.unsqueeze(2).expand(
        self.batch_size, self.target_vocab_size, self.word_emb_size)
    role2word_batch = pretrain_batch.gather(dim=1, index=role_index)
    role2word_emb = self.pretrained_embedding(role2word_batch).detach()
    role2word_emb = role2word_emb * role_mask_expand.float()
    if withParallel:
        fr_word_batch = get_torch_variable_from_np(batch_input['fr_word'])
        fr_pretrain_batch = get_torch_variable_from_np(
            batch_input['fr_pretrain'])
        fr_flag_batch = get_torch_variable_from_np(batch_input['fr_flag'])
    if self.use_flag_embedding:
        flag_emb = self.flag_embedding(flag_batch)
    else:
        flag_emb = flag_batch.view(flag_batch.shape[0], flag_batch.shape[1],
                                   1).float()
    seq_len = flag_batch.shape[1]
    if lang == "En":
        word_emb = self.word_embedding(word_batch)
        pretrain_emb = self.pretrained_embedding(pretrain_batch).detach()
    else:
        word_emb = self.fr_word_embedding(word_batch)
        pretrain_emb = self.fr_pretrained_embedding(pretrain_batch).detach()
    if withParallel:
        # fr_word_emb = self.fr_word_embedding(fr_word_batch)
        fr_pretrain_emb = self.fr_pretrained_embedding(
            fr_pretrain_batch).detach()
        fr_flag_emb = self.flag_embedding(fr_flag_batch)
        fr_seq_len = fr_flag_batch.shape[1]
        role_mask_expand_timestep = role_mask.unsqueeze(2).expand(
            self.batch_size, self.target_vocab_size, fr_seq_len)
        role_mask_expand_forFR = role_mask_expand.unsqueeze(1).expand(
            self.batch_size, fr_seq_len, self.target_vocab_size,
            self.pretrain_emb_size)

    # semantic role labeler
    if self.use_deprel:
        input_emb = torch.cat([flag_emb, pretrain_emb], 2)
    else:
        input_emb = torch.cat([flag_emb, pretrain_emb], 2)
    if withParallel:
        fr_input_emb = torch.cat([fr_flag_emb, fr_pretrain_emb], 2)
    input_emb = self.word_dropout(input_emb)
    bilstm_output, (_, bilstm_final_state) = self.bilstm_layer(
        input_emb, self.bilstm_hidden_state)
    bilstm_output = bilstm_output.contiguous()
    hidden_input = bilstm_output.view(
        bilstm_output.shape[0] * bilstm_output.shape[1], -1)
    hidden_input = hidden_input.view(self.batch_size, seq_len, -1)
    arg_hidden = self.mlp_dropout(self.mlp_arg(hidden_input))
    predicates_1D = batch_input['predicates_idx']
    pred_recur = hidden_input[np.arange(0, self.batch_size), predicates_1D]
    pred_hidden = self.pred_dropout(self.mlp_pred(pred_recur))
    output = bilinear(arg_hidden, self.rel_W, pred_hidden, self.mlp_size,
                      seq_len, 1, self.batch_size,
                      num_outputs=self.target_vocab_size,
                      bias_x=True, bias_y=True)
    en_output = output.view(self.batch_size * seq_len, -1)
    if withParallel:
        role2word_emb_1 = role2word_emb.view(self.batch_size,
                                             self.target_vocab_size, -1)
        role2word_emb_2 = role2word_emb_1.unsqueeze(dim=1)
        role2word_emb_expand = role2word_emb_2.expand(
            self.batch_size, fr_seq_len, self.target_vocab_size,
            self.pretrain_emb_size)
        fr_pretrain_emb = fr_pretrain_emb.view(self.batch_size, fr_seq_len,
                                               -1)
        fr_pretrain_emb = fr_pretrain_emb.unsqueeze(dim=2)
        fr_pretrain_emb_expand = fr_pretrain_emb.expand(
            self.batch_size, fr_seq_len, self.target_vocab_size,
            self.pretrain_emb_size)  # * role_mask_expand_forFR.float()
        # B T R W
        emb_distance = fr_pretrain_emb_expand - role2word_emb_expand
        emb_distance = emb_distance * emb_distance
        # B T R
        emb_distance = emb_distance.sum(dim=3)
        emb_distance = torch.sqrt(emb_distance)
        emb_distance = emb_distance.transpose(1, 2)
        # B R 1
        emb_distance_min, emb_distance_argmin = torch.min(
            emb_distance, dim=2, keepdim=True)
        emb_distance_max, emb_distance_argmax = torch.max(
            emb_distance, dim=2, keepdim=True)
        emb_distance_max_expand = emb_distance_max.expand(
            self.batch_size, self.target_vocab_size, fr_seq_len)
        emb_distance_min_expand = emb_distance_min.expand(
            self.batch_size, self.target_vocab_size, fr_seq_len)
        emb_distance_normalized = \
            (emb_distance - emb_distance_min_expand) / \
            (emb_distance_max_expand - emb_distance_min_expand)
        emb_distance_normalized = emb_distance_normalized.detach()
        top2 = emb_distance_normalized.topk(2, dim=2, largest=False,
                                            sorted=True)[0]
        top2_gap = top2[:, :, 1] - top2[:, :, 0]
        """
        weight_4_loss = emb_distance_normalized
        for i in range(self.batch_size):
            for j in range(self.target_vocab_size):
                weight_4_loss[i, j].index_fill_(
                    0, emb_distance_argmin[i][j], 1)
        """
        fr_input_emb = self.word_dropout(fr_input_emb).detach()
        # fr_input_emb = torch.cat((fr_input_emb, emb_distance_normalized), 2)
        fr_bilstm_output, (_, bilstm_final_state) = self.bilstm_layer(
            fr_input_emb, self.fr_bilstm_hidden_state)
        fr_bilstm_output = fr_bilstm_output.contiguous()
        hidden_input = fr_bilstm_output.view(
            fr_bilstm_output.shape[0] * fr_bilstm_output.shape[1], -1)
        hidden_input = hidden_input.view(self.batch_size, fr_seq_len, -1)
        arg_hidden = self.mlp_dropout(self.mlp_arg(hidden_input))
        predicates_1D = batch_input['fr_predicates_idx']
        pred_recur = hidden_input[np.arange(0, self.batch_size),
                                  predicates_1D]
        pred_hidden = self.pred_dropout(self.mlp_pred(pred_recur))
        output = bilinear(arg_hidden, self.rel_W, pred_hidden, self.mlp_size,
                          fr_seq_len, 1, self.batch_size,
                          num_outputs=self.target_vocab_size,
                          bias_x=True, bias_y=True)
        output = output.view(self.batch_size, fr_seq_len, -1)
        output = output.transpose(1, 2)  # B R T
        """
        output_exp = torch.exp(output)
        output_exp_weighted = output_exp  # * weight_4_loss
        output_expsum = output_exp_weighted.sum(dim=2, keepdim=True).expand(
            self.batch_size, self.target_vocab_size, fr_seq_len)
        output = output_exp / output_expsum
        """
        """
        output = F.softmax(output, dim=1)
        output = output.transpose(1, 2)
        fr_role2word_emb = torch.bmm(output, fr_pretrain_emb)
        criterion = nn.MSELoss(reduction='none')
        l2_loss = criterion(
            fr_role2word_emb.view(
                self.batch_size * self.target_vocab_size, -1),
            role2word_emb.view(
                self.batch_size * self.target_vocab_size, -1))
        """
        criterion = nn.CrossEntropyLoss(ignore_index=0, reduction='none')
        # .contiguous() is required here because of the transpose above.
        output = output.contiguous().view(
            self.batch_size * self.target_vocab_size, -1)
        emb_distance_argmin = emb_distance_argmin * role_mask.unsqueeze(2)
        emb_distance_argmin = emb_distance_argmin.view(-1)
        l2_loss = criterion(output, emb_distance_argmin)
        l2_loss = l2_loss.view(self.batch_size, self.target_vocab_size)
        top2_gap = top2_gap.view(self.batch_size, self.target_vocab_size)
        l2_loss = l2_loss * top2_gap
        l2_loss = l2_loss.sum()
        # l2_loss = F.nll_loss(torch.log(output), emb_distance_argmin,
        #                      ignore_index=0)
        l2_loss = l2_loss * get_torch_variable_from_np(
            batch_input['fr_loss_mask']).float()
        l2_loss = l2_loss.sum()
        return en_output, l2_loss
    return en_output
def forward(self, batch_input, lang='En', unlabeled=False,
            self_constrain=False, use_bert=False):
    if unlabeled:
        loss = self.parallel_train(batch_input, use_bert)
        loss_word = 0
        return loss, loss_word
    if self_constrain:
        loss = self.self_train(batch_input)
        return loss
    # pretrain_batch = get_torch_variable_from_np(batch_input['pretrain'])
    predicates_1D = batch_input['predicates_idx']
    flag_batch = get_torch_variable_from_np(batch_input['flag'])
    word_id = get_torch_variable_from_np(batch_input['word_times'])
    word_id_emb = self.id_embedding(word_id)
    flag_emb = self.flag_embedding(flag_batch)
    actual_lens = batch_input['seq_len']
    if use_bert:
        bert_input_ids = get_torch_variable_from_np(
            batch_input['bert_input_ids'])
        bert_input_mask = get_torch_variable_from_np(
            batch_input['bert_input_mask'])
        bert_out_positions = get_torch_variable_from_np(
            batch_input['bert_out_positions'])
        bert_emb = self.model(bert_input_ids, attention_mask=bert_input_mask)
        bert_emb = bert_emb[0]
        bert_emb = bert_emb[:, 1:-1, :].contiguous().detach()
        bert_emb = bert_emb[torch.arange(bert_emb.size(0)).unsqueeze(-1),
                            bert_out_positions].detach()
        for i in range(len(bert_emb)):
            if i >= len(actual_lens):
                break
            for j in range(len(bert_emb[i])):
                if j >= actual_lens[i]:
                    bert_emb[i][j] = get_torch_variable_from_np(
                        np.zeros(768, dtype="float32"))
        bert_emb = bert_emb.detach()
    # if lang == "En":
    #     pretrain_emb = self.pretrained_embedding(pretrain_batch).detach()
    # else:
    #     pretrain_emb = self.fr_pretrained_embedding(pretrain_batch).detach()
    seq_len = flag_emb.shape[1]
    if not use_bert:
        # NOTE: this branch relies on pretrain_emb from the commented-out
        # lines above and will fail unless they are restored.
        SRL_output = self.SR_Labeler(pretrain_emb, flag_emb, predicates_1D,
                                     seq_len, para=False)
        SRL_input = SRL_output.view(self.batch_size, seq_len, -1)
        pred_recur = self.SR_Compressor(SRL_input, pretrain_emb,
                                        flag_emb.detach(), word_id_emb,
                                        predicates_1D, seq_len, para=False)
        output_word = self.SR_Matcher(pred_recur, pretrain_emb,
                                      flag_emb.detach(),
                                      word_id_emb.detach(), seq_len,
                                      para=False)
    else:
        SRL_output = self.SR_Labeler(bert_emb, flag_emb, predicates_1D,
                                     seq_len, para=False, use_bert=True)
        SRL_input = SRL_output.view(self.batch_size, seq_len, -1)
        pred_recur = self.SR_Compressor(SRL_input, bert_emb,
                                        flag_emb.detach(), word_id_emb,
                                        predicates_1D, seq_len, para=False,
                                        use_bert=True)
        output_word = self.SR_Matcher(pred_recur, bert_emb,
                                      flag_emb.detach(),
                                      word_id_emb.detach(), seq_len,
                                      para=False, use_bert=True)
    copy_loss = 0  # self.copy_loss(SRL_input, flag_emb, bert_emb, seq_len)
    return SRL_output, output_word, copy_loss
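# The indexing idiom used above to pull one vector per word out of BERT's
# sub-token sequence is easy to misread, so here it is isolated with toy
# shapes (all names and values below are illustrative, not from the source):
hidden = torch.randn(2, 6, 768)                      # (batch, n_subtokens, dim)
first_subtoken = torch.tensor([[0, 2, 3, 3, 3, 3],
                               [0, 1, 4, 4, 4, 4]])  # per-word positions
rows = torch.arange(hidden.size(0)).unsqueeze(-1)    # (batch, 1), broadcasts
word_vectors = hidden[rows, first_subtoken]          # (batch, n_words, dim)
assert word_vectors.shape == (2, 6, 768)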
for batch_i, train_input_data in enumerate(
        inter_utils.get_batch(train_dataset, batch_size, word2idx,
                              fr_word2idx, lemma2idx, pos2idx, pretrain2idx,
                              fr_pretrain2idx, deprel2idx, argument2idx,
                              idx2word, shuffle=True, withParrallel=True)):
    srl_model.train()
    target_argument = train_input_data['argument']
    flat_argument = train_input_data['flat_argument']
    target_batch_variable = get_torch_variable_from_np(flat_argument)
    out = srl_model(train_input_data, elmo, withParallel=False, lang='En',
                    isPretrain=True)
    loss = criterion(out, target_batch_variable)
    if batch_i % 50 == 0:
        log(batch_i, loss)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
def eval_data(model, elmo, dataset, batch_size, word2idx, fr_word2idx,
              lemma2idx, pos2idx, pretrain2idx, fr_pretrain2idx, deprel2idx,
              argument2idx, idx2argument, idx2word, unify_pred=False,
              predicate_correct=0, predicate_sum=0, isPretrain=False):
    model.eval()
    golden = []
    predict = []
    output_data = []
    cur_sentence = None
    cur_sentence_data = None
    for batch_i, input_data in enumerate(
            inter_utils.get_batch(dataset, batch_size, word2idx, fr_word2idx,
                                  lemma2idx, pos2idx, pretrain2idx,
                                  fr_pretrain2idx, deprel2idx, argument2idx,
                                  idx2word, lang="En")):
        target_flags = input_data['sen_flags']
        flat_argument = input_data['flat_argument']
        target_batch_variable = get_torch_variable_from_np(flat_argument)
        sentence_id = input_data['sentence_id']
        predicate_id = input_data['predicate_id']
        word_id = input_data['word_id']
        sentence_len = input_data['sentence_len']
        seq_len = input_data['seq_len']
        bs = input_data['batch_size']
        psl = input_data['pad_seq_len']
        out = model(input_data, lang='En')
        _, pred = torch.max(out, 1)
        pred = get_data(pred)
        pred = np.reshape(pred, target_flags.shape)
        for idx in range(pred.shape[0]):
            predict.append(list(pred[idx]))
            golden.append(list(target_flags[idx]))
        pre_data = []
        for b in range(len(seq_len)):
            line_data = ['_' for _ in range(sentence_len[b])]
            for s in range(seq_len[b]):
                wid = word_id[b][s]
                # line_data[wid - 1] = idx2argument[pred[b][s]]
                line_data[wid - 1] = str(pred[b][s])
            pre_data.append(line_data)
        for b in range(len(sentence_id)):
            if cur_sentence != sentence_id[b]:
                if cur_sentence_data is not None:
                    output_data.append(cur_sentence_data)
                cur_sentence_data = [[sentence_id[b]] * len(pre_data[b]),
                                     pre_data[b]]
                cur_sentence = sentence_id[b]
            else:
                assert cur_sentence_data is not None
                cur_sentence_data.append(pre_data[b])
    if cur_sentence_data is not None and len(cur_sentence_data) > 0:
        output_data.append(cur_sentence_data)
    score = sem_f1_score(golden, predict, argument2idx, unify_pred,
                         predicate_correct, predicate_sum)
    """
    P = correct_pos / NonullPredict_pos
    R = correct_pos / NonullTruth_pos
    F = 2 * P * R / (P + R)
    log("POS: ", P, R, F)
    P = correct_PI / NonullPredict_PI
    R = correct_PI / NonullTruth_PI
    F = 2 * P * R / (P + R)
    log("PI: ", P, R, F)
    P = correct_deprel / NonullPredict_deprel
    R = correct_deprel / NonullTruth_deprel
    F = 2 * P * R / (P + R)
    log("deprel: ", P, R, F)
    P = correct_link / NonullPredict_link
    R = correct_link / NonullTruth_link
    F = 2 * P * R / (P + R)
    log(correct_link, NonullPredict_link, NonullTruth_link)
    log("link: ", P, R, F)
    """
    model.train()
    return score, output_data
def forward(self, batch_input, lang='En', unlabeled=False,
            self_constrain=False, use_bert=False, isTrain=False):
    if unlabeled:
        # l2loss = self.word_trans(batch_input, use_bert)
        consistent_loss = self.parallel_train(batch_input, use_bert)
        return consistent_loss
    pretrain_batch = get_torch_variable_from_np(batch_input['pretrain'])
    predicates_1D = batch_input['predicates_idx']
    flag_batch = get_torch_variable_from_np(batch_input['flag'])
    flag_emb = self.flag_embedding(flag_batch)
    actual_lens = batch_input['seq_len']
    # print(actual_lens)
    if use_bert:
        bert_input_ids = get_torch_variable_from_np(
            batch_input['bert_input_ids'])
        bert_input_mask = get_torch_variable_from_np(
            batch_input['bert_input_mask'])
        bert_out_positions = get_torch_variable_from_np(
            batch_input['bert_out_positions'])
        bert_emb = self.model(bert_input_ids, attention_mask=bert_input_mask)
        bert_emb = bert_emb[0]
        bert_emb = bert_emb[:, 1:-1, :].contiguous().detach()
        bert_emb = bert_emb[torch.arange(bert_emb.size(0)).unsqueeze(-1),
                            bert_out_positions].detach()
        """
        for i in range(len(bert_emb)):
            if i >= len(actual_lens):
                break
            for j in range(len(bert_emb[i])):
                if j >= actual_lens[i]:
                    bert_emb[i][j] = get_torch_variable_from_np(
                        np.zeros(768, dtype="float32"))
        bert_emb = bert_emb.detach()
        """
    seq_len = flag_emb.shape[1]
    SRL_output = self.SR_Labeler(bert_emb, flag_emb, predicates_1D, seq_len,
                                 para=False, use_bert=True)
    SRL_input = SRL_output.view(self.batch_size, seq_len, -1)
    SRL_input_probs = F.softmax(SRL_input, 2).detach()
    output_word = self.SR_Matcher(bert_emb.detach(), bert_emb.detach(),
                                  SRL_input_probs, seq_len, seq_len,
                                  isTrain=isTrain, para=False)
    # score4Null = torch.zeros_like(output_word[:, 1:2])
    # output_word = torch.cat((output_word[:, 0:1], score4Null,
    #                          output_word[:, 1:]), 1)
    teacher = SRL_input_probs.view(self.batch_size * seq_len, -1).detach()
    student = torch.log_softmax(output_word, dim=1)
    unlabeled_loss_function = nn.KLDivLoss(reduction='none')
    loss_copy = unlabeled_loss_function(student, teacher)
    loss_copy = loss_copy.sum() / (self.batch_size * seq_len)
    return SRL_output, output_word, loss_copy
def parallel_train_(self, batch_input, use_bert, isTrain=True):
    unlabeled_data_en, unlabeled_data_fr = batch_input
    predicates_1D_fr = unlabeled_data_fr['predicates_idx']
    flag_batch_fr = get_torch_variable_from_np(unlabeled_data_fr['flag'])
    flag_emb_fr = self.flag_embedding(flag_batch_fr).detach()
    actual_lens_fr = unlabeled_data_fr['seq_len']
    predicates_1D = unlabeled_data_en['predicates_idx']
    flag_batch = get_torch_variable_from_np(unlabeled_data_en['flag'])
    actual_lens_en = unlabeled_data_en['seq_len']
    flag_emb = self.flag_embedding(flag_batch).detach()
    seq_len = flag_emb.shape[1]
    seq_len_en = seq_len
    if use_bert:
        bert_input_ids_fr = get_torch_variable_from_np(
            unlabeled_data_fr['bert_input_ids'])
        bert_input_mask_fr = get_torch_variable_from_np(
            unlabeled_data_fr['bert_input_mask'])
        bert_out_positions_fr = get_torch_variable_from_np(
            unlabeled_data_fr['bert_out_positions'])
        bert_emb_fr = self.model(bert_input_ids_fr,
                                 attention_mask=bert_input_mask_fr)
        bert_emb_fr = bert_emb_fr[0]
        bert_emb_fr = bert_emb_fr[:, 1:-1, :].contiguous().detach()
        bert_emb_fr = bert_emb_fr[
            torch.arange(bert_emb_fr.size(0)).unsqueeze(-1),
            bert_out_positions_fr].detach()
        for i in range(len(bert_emb_fr)):
            if i >= len(actual_lens_fr):
                print("error")
                break
            for j in range(len(bert_emb_fr[i])):
                if j >= actual_lens_fr[i]:
                    bert_emb_fr[i][j] = get_torch_variable_from_np(
                        np.zeros(768, dtype="float32"))
        bert_emb_fr = gaussian(bert_emb_fr, isTrain, 0, 0.1)
        bert_emb_fr = bert_emb_fr.detach()
        bert_input_ids_en = get_torch_variable_from_np(
            unlabeled_data_en['bert_input_ids'])
        bert_input_mask_en = get_torch_variable_from_np(
            unlabeled_data_en['bert_input_mask'])
        bert_out_positions_en = get_torch_variable_from_np(
            unlabeled_data_en['bert_out_positions'])
        bert_emb_en = self.model(bert_input_ids_en,
                                 attention_mask=bert_input_mask_en)
        bert_emb_en = bert_emb_en[0]
        bert_emb_en = bert_emb_en[:, 1:-1, :].contiguous().detach()
        bert_emb_en = bert_emb_en[
            torch.arange(bert_emb_en.size(0)).unsqueeze(-1),
            bert_out_positions_en].detach()
        for i in range(len(bert_emb_en)):
            if i >= len(actual_lens_en):
                print("error")
                break
            for j in range(len(bert_emb_en[i])):
                if j >= actual_lens_en[i]:
                    bert_emb_en[i][j] = get_torch_variable_from_np(
                        np.zeros(768, dtype="float32"))
        bert_emb_en = gaussian(bert_emb_en, isTrain, 0, 0.1)
        bert_emb_en = bert_emb_en.detach()
    seq_len = flag_emb.shape[1]
    SRL_output = self.SR_Labeler(bert_emb_en, flag_emb.detach(),
                                 predicates_1D, seq_len, para=True,
                                 use_bert=True)
    SRL_input = SRL_output.view(self.batch_size, seq_len, -1)
    pred_recur = self.SR_Compressor(SRL_input.detach(), bert_emb_en,
                                    flag_emb.detach(), None, predicates_1D,
                                    seq_len, para=True, use_bert=True)
    seq_len_fr = flag_emb_fr.shape[1]
    SRL_output_fr = self.SR_Labeler(bert_emb_fr, flag_emb_fr.detach(),
                                    predicates_1D_fr, seq_len_fr, para=True,
                                    use_bert=True)
    SRL_input_fr = SRL_output_fr.view(self.batch_size, seq_len_fr, -1)
    pred_recur_fr = self.SR_Compressor(SRL_input_fr, bert_emb_fr,
                                       flag_emb_fr.detach(), None,
                                       predicates_1D_fr, seq_len_fr,
                                       para=True, use_bert=True)
    # En event vector, En word
    output_word_en_en = self.SR_Matcher(pred_recur.detach(), bert_emb_en,
                                        flag_emb.detach(), None, seq_len,
                                        para=True, use_bert=True).detach()
    # Fr event vector, En word
    output_word_fr_en = self.SR_Matcher(pred_recur_fr, bert_emb_en,
                                        flag_emb.detach(), None, seq_len,
                                        para=True, use_bert=True)
    # B*T R 2
    Union_enfr_en = torch.cat(
        (output_word_en_en.view(-1, self.target_vocab_size, 1),
         output_word_fr_en.view(-1, self.target_vocab_size, 1)), 2)
    # B*T R
    max_enfr_en = torch.max(Union_enfr_en, 2)[0]
    # En event vector, Fr word
    output_word_en_fr = self.SR_Matcher(pred_recur.detach(), bert_emb_fr,
                                        flag_emb_fr.detach(), None,
                                        seq_len_fr, para=True,
                                        use_bert=True).detach()
    # Fr event vector, Fr word
    output_word_fr_fr = self.SR_Matcher(pred_recur_fr, bert_emb_fr,
                                        flag_emb_fr.detach(), None,
                                        seq_len_fr, para=True, use_bert=True)
    # B*T R 2
    Union_enfr_fr = torch.cat(
        (output_word_en_fr.view(-1, self.target_vocab_size, 1),
         output_word_fr_fr.view(-1, self.target_vocab_size, 1)), 2)
    # B*T R
    max_enfr_fr = torch.max(Union_enfr_fr, 2)[0]
    unlabeled_loss_function = nn.KLDivLoss(reduction='none')
    """
    word_mask_4en = self.P_word_mask(
        output_word_fr_en.view(self.batch_size, seq_len, -1),
        output_word_fr_fr.view(self.batch_size, seq_len_fr, -1), seq_len_en)
    word_mask_4en_tensor = get_torch_variable_from_np(word_mask_4en).view(
        self.batch_size * seq_len_en, -1)
    word_mask_4fr = self.R_word_mask(
        output_word_en_en.view(self.batch_size, seq_len, -1),
        output_word_en_fr.view(self.batch_size, seq_len_fr, -1), seq_len_fr)
    word_mask_4fr_tensor = get_torch_variable_from_np(word_mask_4fr).view(
        self.batch_size * seq_len_fr, -1)
    word_mask_en, word_mask_fr = self.word_mask_soft(
        output_word_en_en.view(self.batch_size, seq_len_en, -1),
        output_word_en_fr.view(self.batch_size, seq_len_fr, -1),
        seq_len_en, seq_len_fr)
    word_mask_en_tensor = get_torch_variable_from_np(word_mask_en).view(
        self.batch_size * seq_len_en)
    word_mask_fr_tensor = get_torch_variable_from_np(word_mask_fr).view(
        self.batch_size * seq_len_fr)
    """
    # output_word_en_en = F.softmax(output_word_en_en, dim=1).detach()
    # output_word_fr_en = F.log_softmax(output_word_fr_en, dim=1)
    # loss = unlabeled_loss_function(output_word_fr_en, output_word_en_en)
    output_word_en_en = F.softmax(output_word_en_en, dim=1).detach()
    max_enfr_en = F.log_softmax(max_enfr_en, dim=1)
    loss = unlabeled_loss_function(max_enfr_en, output_word_en_en)
    loss = loss.sum(dim=1)  # * word_mask_en_tensor
    loss = loss.sum() / (self.batch_size * seq_len_en)
    # output_word_en_fr = F.softmax(output_word_en_fr, dim=1).detach()
    max_enfr_fr = F.softmax(max_enfr_fr, dim=1).detach()
    output_word_fr_fr = F.log_softmax(output_word_fr_fr, dim=1)
    loss_2 = unlabeled_loss_function(output_word_fr_fr, max_enfr_fr)
    loss_2 = loss_2.sum(dim=1)  # * word_mask_fr_tensor
    loss_2 = loss_2.sum() / (self.batch_size * seq_len_fr)
    return loss, loss_2
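# The gaussian helper called in both parallel_train_ variants is not defined
# in this section. A sketch consistent with the call
# gaussian(tensor, isTrain, 0, 0.1) -- additive Gaussian noise applied only
# at training time -- is shown below; this is an assumed reconstruction, not
# the original helper.
def gaussian(tensor, is_training, mean, stddev):
    if not is_training:
        return tensor
    noise = torch.randn_like(tensor) * stddev + mean
    return tensor + noise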
def forward(self, batch_input, elmo):
    flag_batch = get_torch_variable_from_np(batch_input['flag'])
    word_batch = get_torch_variable_from_np(batch_input['word'])
    lemma_batch = get_torch_variable_from_np(batch_input['lemma'])
    pos_batch = get_torch_variable_from_np(batch_input['pos'])
    deprel_batch = get_torch_variable_from_np(batch_input['deprel'])
    pretrain_batch = get_torch_variable_from_np(batch_input['pretrain'])
    origin_batch = batch_input['origin']
    origin_deprel_batch = batch_input['deprel']
    chars_batch = get_torch_variable_from_np(batch_input['char'])
    if self.use_flag_embedding:
        flag_emb = self.flag_embedding(flag_batch)
    else:
        flag_emb = flag_batch.view(flag_batch.shape[0], flag_batch.shape[1],
                                   1).float()
    seq_len = flag_batch.shape[1]
    word_emb = self.word_embedding(word_batch)
    lemma_emb = self.lemma_embedding(lemma_batch)
    pos_emb = self.pos_embedding(pos_batch)
    char_embeddings = self.char_embeddings(chars_batch)
    character_embeddings = self.charCNN(char_embeddings)
    pretrain_emb = self.pretrained_embedding(pretrain_batch)
    if self.use_deprel:
        deprel_emb = self.deprel_embedding(deprel_batch)

    # sentence learner
    SL_input_emb = self.word_dropout(
        torch.cat([word_emb, pretrain_emb, character_embeddings], 2))
    h0, (_, SL_final_state) = self.sentence_learner(SL_input_emb,
                                                    self.SL_hidden_state0)
    h1, (_, SL_final_state) = self.sentence_learner_high(
        h0, self.SL_hidden_state0_high)
    SL_output = h1
    POS_output = self.pos_classifier(SL_output).view(
        self.batch_size * seq_len, -1)
    PI_output = self.PI_classifier(SL_output).view(
        self.batch_size * seq_len, -1)

    # deprel
    hidden_input = SL_output
    arg_hidden = self.mlp_dropout(self.mlp_arg_deprel(SL_output))
    predicates_1D = batch_input['predicates_idx']
    pred_recur = hidden_input[np.arange(0, self.batch_size), predicates_1D]
    pred_hidden = self.pred_dropout(self.mlp_pred_deprel(pred_recur))
    deprel_output = bilinear(arg_hidden, self.deprel_W, pred_hidden,
                             self.mlp_size, seq_len, 1, self.batch_size,
                             num_outputs=self.deprel_vocab_size,
                             bias_x=True, bias_y=True)
    arg_hidden = self.mlp_dropout(self.mlp_arg_link(SL_output))
    predicates_1D = batch_input['predicates_idx']
    pred_recur = hidden_input[np.arange(0, self.batch_size), predicates_1D]
    pred_hidden = self.pred_dropout(self.mlp_pred_link(pred_recur))
    Link_output = bilinear(arg_hidden, self.link_W, pred_hidden,
                           self.mlp_size, seq_len, 1, self.batch_size,
                           num_outputs=4, bias_x=True, bias_y=True)
    deprel_output = deprel_output.view(self.batch_size * seq_len, -1)
    Link_output = Link_output.view(self.batch_size * seq_len, -1)
    POS_probs = F.softmax(POS_output, dim=1).view(self.batch_size, seq_len,
                                                  -1)
    deprel_probs = F.softmax(deprel_output, dim=1).view(self.batch_size,
                                                        seq_len, -1)
    link_probs = F.softmax(Link_output, dim=1).view(self.batch_size, seq_len,
                                                    -1)
    # F.tanh is deprecated; torch.tanh is the equivalent replacement.
    POS_compose = torch.tanh(self.POS2hidden(POS_probs))
    deprel_compose = torch.tanh(self.deprel2hidden(deprel_probs))
    link_compose = link_probs

    # semantic role labeler
    if self.use_deprel:
        input_emb = torch.cat([flag_emb, word_emb, pretrain_emb,
                               character_embeddings, POS_compose,
                               deprel_compose, link_compose], 2)
    else:
        input_emb = torch.cat([flag_emb, word_emb, pretrain_emb,
                               character_embeddings, POS_compose], 2)
    input_emb = self.word_dropout(input_emb)
    w = F.softmax(self.elmo_w, dim=0)
    SRL_composer = self.elmo_gamma * (w[0] * h0 + w[1] * h1)
    SRL_composer = self.elmo_mlp(SRL_composer)
    bilstm_output_0, (_, bilstm_final_state) = self.bilstm_layer(
        input_emb, self.bilstm_hidden_state0)
    high_input = torch.cat((bilstm_output_0, SRL_composer), 2)
    bilstm_output, (_, bilstm_final_state) = self.bilstm_layer_high(
        high_input, self.bilstm_hidden_state0_high)
    # bilstm_final_state = bilstm_final_state.view(
    #     self.bilstm_num_layers, 2, self.batch_size, self.bilstm_hidden_size)
    # bilstm_final_state = bilstm_final_state[-1]
    # bilstm_final_state = torch.cat(
    #     [bilstm_final_state[0], bilstm_final_state[1]], 1)
    # bilstm_output = self.bilstm_mlp(bilstm_output)
    if self.use_self_attn:
        x = torch.tanh(self.attn_linear_first(bilstm_output))
        x = self.attn_linear_second(x)
        x = self.softmax(x, 1)
        attention = x.transpose(1, 2)
        sentence_embeddings = torch.matmul(attention, bilstm_output)
        sentence_embeddings = torch.sum(sentence_embeddings,
                                        1) / self.self_attn_head
        context = sentence_embeddings.repeat(bilstm_output.size(1), 1,
                                             1).transpose(0, 1)
        bilstm_output = torch.cat([bilstm_output, context], 2)
        bilstm_output = self.attn_linear_final(bilstm_output)
        # energy = self.biaf_attn(bilstm_output, bilstm_output)
        # flag_indices = batch_input['flag_indices']
        # attention = []
        # for idx in range(len(flag_indices)):
        #     attention.append(energy[idx, :, :, flag_indices[idx]].view(
        #         1, self.bilstm_hidden_size, -1))
        # attention = torch.cat(attention, dim=0)
        # attention = self.softmax(attention, 2)
        # sentence_embeddings = attention @ bilstm_output
        # sentence_embeddings = torch.sum(sentence_embeddings,
        #                                 1) / self.self_attn_head
        # context = sentence_embeddings.repeat(bilstm_output.size(1), 1,
        #                                      1).transpose(0, 1)
        # bilstm_output = torch.cat([bilstm_output, context], 2)
        # bilstm_output = self.attn_linear_final(bilstm_output)
    else:
        bilstm_output = bilstm_output.contiguous()
    if self.use_gcn:
        # A manual per-node aggregation (in_rep/out_rep/self_rep over
        # children_indicies and head) was commented out in the original and
        # replaced by the batched adjacency construction below.
        seq_len = bilstm_output.shape[1]
        adj_arc_in = np.zeros((self.batch_size * seq_len, 2), dtype='int32')
        adj_lab_in = np.zeros((self.batch_size * seq_len), dtype='int32')
        adj_arc_out = np.zeros((self.batch_size * seq_len, 2), dtype='int32')
        adj_lab_out = np.zeros((self.batch_size * seq_len), dtype='int32')
        mask_in = np.zeros((self.batch_size * seq_len), dtype='float32')
        mask_out = np.zeros((self.batch_size * seq_len), dtype='float32')
        mask_loop = np.ones((self.batch_size * seq_len, 1), dtype='float32')
        for idx in range(len(origin_batch)):
            for jdx in range(len(origin_batch[idx])):
                offset = jdx + idx * seq_len
                head_ind = int(origin_batch[idx][jdx][10]) - 1
                if head_ind == -1:
                    continue
                dependent_ind = int(origin_batch[idx][jdx][4]) - 1
                adj_arc_in[offset] = np.array([idx, dependent_ind])
                adj_lab_in[offset] = np.array(
                    [origin_deprel_batch[idx, jdx]])
                mask_in[offset] = 1
                adj_arc_out[offset] = np.array([idx, head_ind])
                adj_lab_out[offset] = np.array(
                    [origin_deprel_batch[idx, jdx]])
                mask_out[offset] = 1
        if USE_CUDA:
            adj_arc_in = torch.LongTensor(np.transpose(adj_arc_in)).cuda()
            adj_arc_out = torch.LongTensor(np.transpose(adj_arc_out)).cuda()
            adj_lab_in = Variable(torch.LongTensor(adj_lab_in).cuda())
            adj_lab_out = Variable(torch.LongTensor(adj_lab_out).cuda())
            mask_in = Variable(torch.FloatTensor(
                mask_in.reshape((self.batch_size * seq_len, 1))).cuda())
            mask_out = Variable(torch.FloatTensor(
                mask_out.reshape((self.batch_size * seq_len, 1))).cuda())
            mask_loop = Variable(torch.FloatTensor(mask_loop).cuda())
        else:
            adj_arc_in = torch.LongTensor(np.transpose(adj_arc_in))
            adj_arc_out = torch.LongTensor(np.transpose(adj_arc_out))
            adj_lab_in = Variable(torch.LongTensor(adj_lab_in))
            adj_lab_out = Variable(torch.LongTensor(adj_lab_out))
            mask_in = Variable(torch.FloatTensor(
                mask_in.reshape((self.batch_size * seq_len, 1))))
            mask_out = Variable(torch.FloatTensor(
                mask_out.reshape((self.batch_size * seq_len, 1))))
            mask_loop = Variable(torch.FloatTensor(mask_loop))
        gcn_context = self.syntactic_gcn(bilstm_output, adj_arc_in,
                                         adj_arc_out, adj_lab_in,
                                         adj_lab_out, mask_in, mask_out,
                                         mask_loop)
        # gcn_context = self.softmax(gcn_context, axis=2)
        gcn_context = F.softmax(gcn_context, dim=2)
        bilstm_output = torch.cat([bilstm_output, gcn_context], dim=2)
        bilstm_output = self.gcn_mlp(bilstm_output)
    hidden_input = bilstm_output.view(
        bilstm_output.shape[0] * bilstm_output.shape[1], -1)
    if self.use_highway:
        for current_layer in self.highway_layers:
            hidden_input = current_layer(hidden_input)
        output = self.output_layer(hidden_input)
    else:
        hidden_input = hidden_input.view(self.batch_size, seq_len, -1)
        # output = self.output_layer(hidden_input)
        if self.use_biaffine:
            arg_hidden = self.mlp_dropout(self.mlp_arg(hidden_input))
            predicates_1D = batch_input['predicates_idx']
            pred_recur = hidden_input[np.arange(0, self.batch_size),
                                      predicates_1D]
            pred_hidden = self.pred_dropout(self.mlp_pred(pred_recur))
            output = bilinear(arg_hidden, self.rel_W, pred_hidden,
                              self.mlp_size, seq_len, 1, self.batch_size,
                              num_outputs=self.target_vocab_size,
                              bias_x=True, bias_y=True)
            output = output.view(self.batch_size * seq_len, -1)
    return output, POS_output, PI_output, deprel_output, Link_output
def forward(self, batch_input, elmo, withParallel=True, lang='En',
            isPretrain=False, TrainGenerator=False):
    if lang == 'En':
        pretrain_batch = get_torch_variable_from_np(batch_input['pretrain'])
    else:
        pretrain_batch = get_torch_variable_from_np(batch_input['pretrain'])
    flag_batch = get_torch_variable_from_np(batch_input['flag'])
    if withParallel:
        fr_pretrain_batch = get_torch_variable_from_np(
            batch_input['fr_pretrain'])
        fr_flag_batch = get_torch_variable_from_np(batch_input['fr_flag'])
    if self.use_flag_embedding:
        flag_emb = self.flag_embedding(flag_batch)
    else:
        flag_emb = flag_batch.view(flag_batch.shape[0], flag_batch.shape[1],
                                   1).float()
    if lang == "En":
        pretrain_emb = self.pretrained_embedding(pretrain_batch).detach()
    else:
        pretrain_emb = self.fr_pretrained_embedding(pretrain_batch).detach()
    if withParallel:
        fr_pretrain_emb = self.fr_pretrained_embedding(
            fr_pretrain_batch).detach()
        fr_flag_emb = self.flag_embedding(fr_flag_batch).detach()
    input_emb = torch.cat([flag_emb, pretrain_emb], 2)
    predicates_1D = batch_input['predicates_idx']
    if withParallel:
        fr_input_emb = torch.cat([fr_flag_emb, fr_pretrain_emb], 2)
    output_en, enc_real = self.EN_Labeler(input_emb, predicates_1D)
    output_fr, _ = self.FR_Labeler(input_emb, predicates_1D)
    if not withParallel:
        if isPretrain:
            return output_en
        else:
            return output_fr
    predicates_1D = batch_input['fr_predicates_idx']
    _, enc_fake = self.FR_Labeler(fr_input_emb, predicates_1D)
    # The original referenced undefined real_states/fake_states; the encoder
    # states returned above (enc_real/enc_fake) are used instead.
    x_D = torch.cat([enc_real.detach(), enc_fake.detach()], 0)
    x_G = torch.cat([enc_real.detach(), enc_fake], 0)
    y = torch.FloatTensor(2 * self.batch_size).zero_().to(device)
    y[:self.batch_size] = 1 - self.dis_smooth
    y[self.batch_size:] = self.dis_smooth
    if not TrainGenerator:
        # prob_real_decision = self.Discriminator(enc_real.detach())
        # prob_fake_decision = self.Discriminator(enc_fake.detach())
        # D_loss = -torch.mean(torch.log(prob_real_decision)
        #                      + torch.log(1. - prob_fake_decision))
        preds = self.Discriminator(Variable(x_D.data))
        D_loss = F.binary_cross_entropy(preds, 1 - y)
        return D_loss * get_torch_variable_from_np(
            batch_input['fr_loss_mask']).float()
    else:
        # prob_fake_decision_G = self.Discriminator(enc_fake)
        # G_loss = -torch.mean(torch.log(prob_fake_decision_G))
        preds = self.Discriminator(x_G)
        G_loss = F.binary_cross_entropy(preds, y)
        return G_loss * get_torch_variable_from_np(
            batch_input['fr_loss_mask']).float()
def parallel_train_(self, batch_input, use_bert, isTrain=True):
    unlabeled_data_en, unlabeled_data_fr = batch_input
    predicates_1D_fr = unlabeled_data_fr['predicates_idx']
    flag_batch_fr = get_torch_variable_from_np(unlabeled_data_fr['flag'])
    flag_emb_fr = self.flag_embedding(flag_batch_fr).detach()
    actual_lens_fr = unlabeled_data_fr['seq_len']
    predicates_1D = unlabeled_data_en['predicates_idx']
    flag_batch = get_torch_variable_from_np(unlabeled_data_en['flag'])
    actual_lens_en = unlabeled_data_en['seq_len']
    flag_emb = self.flag_embedding(flag_batch).detach()
    seq_len = flag_emb.shape[1]
    seq_len_en = seq_len
    if use_bert:
        bert_input_ids_fr = get_torch_variable_from_np(
            unlabeled_data_fr['bert_input_ids'])
        bert_input_mask_fr = get_torch_variable_from_np(
            unlabeled_data_fr['bert_input_mask'])
        bert_out_positions_fr = get_torch_variable_from_np(
            unlabeled_data_fr['bert_out_positions'])
        bert_emb_fr = self.model(bert_input_ids_fr,
                                 attention_mask=bert_input_mask_fr)
        bert_emb_fr = bert_emb_fr[0]
        bert_emb_fr = bert_emb_fr[:, 1:-1, :].contiguous().detach()
        bert_emb_fr = bert_emb_fr[
            torch.arange(bert_emb_fr.size(0)).unsqueeze(-1),
            bert_out_positions_fr].detach()
        for i in range(len(bert_emb_fr)):
            if i >= len(actual_lens_fr):
                print("error")
                break
            for j in range(len(bert_emb_fr[i])):
                if j >= actual_lens_fr[i]:
                    bert_emb_fr[i][j] = get_torch_variable_from_np(
                        np.zeros(768, dtype="float32"))
        bert_emb_fr_noise = gaussian(bert_emb_fr, isTrain, 0, 0.1).detach()
        bert_emb_fr = bert_emb_fr.detach()
        bert_input_ids_en = get_torch_variable_from_np(
            unlabeled_data_en['bert_input_ids'])
        bert_input_mask_en = get_torch_variable_from_np(
            unlabeled_data_en['bert_input_mask'])
        bert_out_positions_en = get_torch_variable_from_np(
            unlabeled_data_en['bert_out_positions'])
        bert_emb_en = self.model(bert_input_ids_en,
                                 attention_mask=bert_input_mask_en)
        bert_emb_en = bert_emb_en[0]
        bert_emb_en = bert_emb_en[:, 1:-1, :].contiguous().detach()
        bert_emb_en = bert_emb_en[
            torch.arange(bert_emb_en.size(0)).unsqueeze(-1),
            bert_out_positions_en].detach()
        for i in range(len(bert_emb_en)):
            if i >= len(actual_lens_en):
                print("error")
                break
            for j in range(len(bert_emb_en[i])):
                if j >= actual_lens_en[i]:
                    bert_emb_en[i][j] = get_torch_variable_from_np(
                        np.zeros(768, dtype="float32"))
        bert_emb_en_noise = gaussian(bert_emb_en, isTrain, 0, 0.1).detach()
        bert_emb_en = bert_emb_en.detach()
    seq_len = flag_emb.shape[1]
    SRL_output = self.SR_Labeler(bert_emb_en, flag_emb.detach(),
                                 predicates_1D, seq_len, para=True,
                                 use_bert=True)
    CopyLoss_en_noise = self.copy_loss(SRL_output, bert_emb_en_noise,
                                       flag_emb.detach(), seq_len)
    CopyLoss_en = self.copy_loss(SRL_output, bert_emb_en, flag_emb.detach(),
                                 seq_len)
    SRL_input = SRL_output.view(self.batch_size, seq_len, -1)
    SRL_input = F.softmax(SRL_input, 2)
    pred_recur = self.SR_Compressor(SRL_input.detach(), bert_emb_en,
                                    flag_emb.detach(), None, predicates_1D,
                                    seq_len, para=True, use_bert=True)
    seq_len_fr = flag_emb_fr.shape[1]
    SRL_output_fr = self.SR_Labeler(bert_emb_fr, flag_emb_fr.detach(),
                                    predicates_1D_fr, seq_len_fr, para=True,
                                    use_bert=True)
    CopyLoss_fr_noise = self.copy_loss(SRL_output_fr, bert_emb_fr_noise,
                                       flag_emb_fr.detach(), seq_len_fr)
    CopyLoss_fr = self.copy_loss(SRL_output_fr, bert_emb_fr,
                                 flag_emb_fr.detach(), seq_len_fr)
    SRL_input_fr = SRL_output_fr.view(self.batch_size, seq_len_fr, -1)
    SRL_input_fr = F.softmax(SRL_input_fr, 2)
    pred_recur_fr = self.SR_Compressor(SRL_input_fr, bert_emb_fr,
                                       flag_emb_fr.detach(), None,
                                       predicates_1D_fr, seq_len_fr,
                                       para=True, use_bert=True)
    # En event vector, En word
    output_word_en_en = self.SR_Matcher(pred_recur.detach(), bert_emb_en,
                                        flag_emb.detach(), None, seq_len,
                                        para=True, use_bert=True).detach()
    score4Null = torch.zeros_like(output_word_en_en[:, 1:2])
    output_word_en_en = torch.cat(
        (output_word_en_en[:, 0:1], score4Null, output_word_en_en[:, 1:]), 1)
    # Fr event vector, En word
    output_word_fr_en = self.SR_Matcher(pred_recur_fr, bert_emb_en,
                                        flag_emb.detach(), None, seq_len,
                                        para=True, use_bert=True)
    score4Null = torch.zeros_like(output_word_fr_en[:, 1:2])
    output_word_fr_en = torch.cat(
        (output_word_fr_en[:, 0:1], score4Null, output_word_fr_en[:, 1:]), 1)
    # En event vector, Fr word
    output_word_en_fr = self.SR_Matcher(pred_recur.detach(), bert_emb_fr,
                                        flag_emb_fr.detach(), None,
                                        seq_len_fr, para=True,
                                        use_bert=True).detach()
    score4Null = torch.zeros_like(output_word_en_fr[:, 1:2])
    output_word_en_fr = torch.cat(
        (output_word_en_fr[:, 0:1], score4Null, output_word_en_fr[:, 1:]), 1)
    # Fr event vector, Fr word
    output_word_fr_fr = self.SR_Matcher(pred_recur_fr, bert_emb_fr,
                                        flag_emb_fr.detach(), None,
                                        seq_len_fr, para=True, use_bert=True)
    score4Null = torch.zeros_like(output_word_fr_fr[:, 1:2])
    output_word_fr_fr = torch.cat(
        (output_word_fr_fr[:, 0:1], score4Null, output_word_fr_fr[:, 1:]), 1)
    """
    mask_en_en, mask_en_fr = self.filter_word(SRL_input, output_word_en_en,
                                              output_word_en_fr, seq_len,
                                              seq_len_fr)
    mask_en_en = get_torch_variable_from_np(mask_en_en)
    mask_en_fr = get_torch_variable_from_np(mask_en_fr)
    mask_fr_en, mask_fr_fr = self.filter_word_fr(output_word_fr_en,
                                                 output_word_fr_fr, seq_len,
                                                 seq_len_fr)
    mask_fr_en = get_torch_variable_from_np(mask_fr_en)
    mask_fr_fr = get_torch_variable_from_np(mask_fr_fr)
    mask_en_word = mask_en_en + mask_fr_en - mask_en_en * mask_fr_en
    mask_fr_word = mask_en_fr + mask_fr_fr - mask_en_fr * mask_fr_fr
    """
    unlabeled_loss_function = nn.KLDivLoss(reduction='none')
    output_word_en_en = F.softmax(output_word_en_en, dim=1).detach()
    output_word_fr_en = F.log_softmax(output_word_fr_en, dim=1)
    loss = unlabeled_loss_function(
        output_word_fr_en, output_word_en_en).sum(dim=1)
    # loss = loss * mask_en_word.view(-1); loss = loss.sum() / mask_en_word.sum()
    loss = loss.sum() / (self.batch_size * seq_len)
    output_word_en_fr = F.softmax(output_word_en_fr, dim=1).detach()
    output_word_fr_fr = F.log_softmax(output_word_fr_fr, dim=1)
    loss_2 = unlabeled_loss_function(
        output_word_fr_fr, output_word_en_fr).sum(dim=1)
    # loss_2 = loss_2 * mask_fr_word.view(-1)
    loss_2 = loss_2.sum() / (self.batch_size * seq_len_fr)
    return loss, loss_2, CopyLoss_en, CopyLoss_fr, CopyLoss_en_noise, \
        CopyLoss_fr_noise
def train_1_epoc(srl_model, criterion, optimizer, train_dataset, labeled_dataset_fr, batch_size,
                 word2idx, fr_word2idx, lemma2idx, pos2idx, pretrain2idx, fr_pretrain2idx,
                 deprel2idx, argument2idx, idx2word, shuffle=False, lang='En',
                 dev_best_score=None, test_best_score=None, test_ood_best_score=None):
    for batch_i, train_input_data in enumerate(
            inter_utils.get_batch(train_dataset, batch_size, word2idx, fr_word2idx, lemma2idx,
                                  pos2idx, pretrain2idx, fr_pretrain2idx, deprel2idx,
                                  argument2idx, idx2word, shuffle=shuffle, lang=lang)):
        flat_argument = train_input_data['flat_argument']
        target_batch_variable = get_torch_variable_from_np(flat_argument)

        out, out_word = srl_model(train_input_data, lang='En')
        loss = criterion(out, target_batch_variable)
        loss_word = criterion(out_word, target_batch_variable)
        if batch_i % 50 == 0:
            log(batch_i, loss, loss_word)

        # Backpropagate the sum of both objectives in one step.
        optimizer.zero_grad()
        (loss + loss_word).backward()
        optimizer.step()

        if batch_i > 0 and batch_i % show_steps == 0:
            _, pred = torch.max(out, 1)
            pred = get_data(pred)
            log('\n')
            log('*' * 80)
            eval_train_batch(epoch, batch_i, loss.item(), flat_argument, pred, argument2idx)

            log('FR test:')
            score, dev_output = eval_data(srl_model, elmo, labeled_dataset_fr, batch_size,
                                          word2idx, fr_word2idx, lemma2idx, pos2idx, pretrain2idx,
                                          fr_pretrain2idx, deprel2idx, argument2idx, idx2argument,
                                          idx2word, False, dev_predicate_correct,
                                          dev_predicate_sum, lang='Fr')
            if dev_best_score is None or score[5] > dev_best_score[5]:
                dev_best_score = score
                output_predict(
                    os.path.join(result_path,
                                 'dev_argument_{:.2f}.pred'.format(dev_best_score[2] * 100)),
                    dev_output)
                # torch.save(srl_model, os.path.join(os.path.dirname(__file__),
                #            'model/best_{:.2f}.pkl'.format(dev_best_score[2] * 100)))
            log('\tdev best P:{:.2f} R:{:.2f} F1:{:.2f} NP:{:.2f} NR:{:.2f} NF1:{:.2f}'.format(
                dev_best_score[0] * 100, dev_best_score[1] * 100, dev_best_score[2] * 100,
                dev_best_score[3] * 100, dev_best_score[4] * 100, dev_best_score[5] * 100))
    return dev_best_score
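# get_torch_variable_from_np() is used throughout but never shown in this
# excerpt. A minimal sketch under the assumption that it wraps a NumPy array as
# a torch tensor on the active device; the `device` global is an assumption of
# this sketch, not part of the original code.
import numpy as np
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def get_torch_variable_from_np(array):
    # Convert and move in one step; dtype follows the NumPy array.
    return torch.from_numpy(array).to(device)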
def eval_pred_recog_data(model, elmo, dataset, batch_size, word2idx, lemma2idx, pos2idx,
                         pretrain2idx, pred2idx, idx2pred):
    model.eval()
    golden = []
    predict = []
    output_data = []
    cur_sentence = None
    cur_sentence_data = None

    for batch_i, input_data in enumerate(
            pred_recog_inter_utils.get_batch(dataset, batch_size, word2idx, lemma2idx,
                                             pos2idx, pretrain2idx, pred2idx)):
        target_pred = input_data['pred']
        flat_pred = input_data['flat_pred']
        target_batch_variable = get_torch_variable_from_np(flat_pred)

        sentence_id = input_data['sentence_id']
        word_id = input_data['word_id']
        sentence_len = input_data['sentence_len']
        seq_len = input_data['seq_len']
        bs = input_data['batch_size']
        psl = input_data['pad_seq_len']

        out = model(input_data, elmo)
        _, pred = torch.max(out, 1)
        pred = get_data(pred)
        pred = pred.tolist()
        golden += flat_pred.tolist()
        predict += pred

        # Map flat predictions back onto each sentence's word positions.
        pre_data = []
        for b in range(len(seq_len)):
            line_data = ['_' for _ in range(sentence_len[b])]
            for s in range(seq_len[b]):
                wid = word_id[b][s]
                line_data[wid - 1] = idx2pred[pred[b * psl + s]]
            pre_data.append(line_data)

        # Group consecutive batch items belonging to the same sentence.
        for b in range(len(sentence_id)):
            if cur_sentence != sentence_id[b]:
                if cur_sentence_data is not None:
                    output_data.append(cur_sentence_data)
                cur_sentence_data = [[sentence_id[b]] * len(pre_data[b]), pre_data[b]]
                cur_sentence = sentence_id[b]
            else:
                assert cur_sentence_data is not None
                cur_sentence_data.append(pre_data[b])

    if cur_sentence_data is not None and len(cur_sentence_data) > 0:
        output_data.append(cur_sentence_data)

    score = pred_recog_score(golden, predict, pred2idx)
    model.train()
    return score, output_data
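# The b * psl + s indexing above flattens (batch, padded sequence) into one
# list. An illustrative, self-contained example of how flat predictions map
# back to per-sentence labels; all values here are made up.
psl = 4                                   # padded length of every sequence in the batch
flat_pred = [0, 2, 0, 0,                  # sentence b=0 (real length 3, then padding)
             1, 0, 0, 0]                  # sentence b=1 (real length 2, then padding)
seq_len = [3, 2]

for b in range(len(seq_len)):
    labels = [flat_pred[b * psl + s] for s in range(seq_len[b])]
    print(b, labels)                      # 0 [0, 2, 0]  /  1 [1, 0]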
def parallel_train(self, batch_input, use_bert, isTrain=True):
    unlabeled_data_en, unlabeled_data_fr = batch_input

    predicates_1D_fr = unlabeled_data_fr['predicates_idx']
    flag_batch_fr = get_torch_variable_from_np(unlabeled_data_fr['flag'])
    flag_emb_fr = self.flag_embedding(flag_batch_fr).detach()
    actual_lens_fr = unlabeled_data_fr['seq_len']

    predicates_1D = unlabeled_data_en['predicates_idx']
    flag_batch = get_torch_variable_from_np(unlabeled_data_en['flag'])
    actual_lens_en = unlabeled_data_en['seq_len']
    flag_emb = self.flag_embedding(flag_batch).detach()
    seq_len = flag_emb.shape[1]
    seq_len_en = seq_len

    if use_bert:
        bert_input_ids_fr = get_torch_variable_from_np(unlabeled_data_fr['bert_input_ids'])
        bert_input_mask_fr = get_torch_variable_from_np(unlabeled_data_fr['bert_input_mask'])
        bert_out_positions_fr = get_torch_variable_from_np(unlabeled_data_fr['bert_out_positions'])

        bert_emb_fr = self.model(bert_input_ids_fr, attention_mask=bert_input_mask_fr)
        bert_emb_fr = bert_emb_fr[0]
        bert_emb_fr = bert_emb_fr[:, 1:-1, :].contiguous().detach()
        bert_emb_fr = bert_emb_fr[torch.arange(bert_emb_fr.size(0)).unsqueeze(-1),
                                  bert_out_positions_fr].detach()

        for i in range(len(bert_emb_fr)):
            if i >= len(actual_lens_fr):
                print("error: batch longer than seq_len list")
                break
            for j in range(len(bert_emb_fr[i])):
                if j >= actual_lens_fr[i]:
                    bert_emb_fr[i][j] = get_torch_variable_from_np(np.zeros(768, dtype="float32"))
        # bert_emb_fr = gaussian(bert_emb_fr, isTrain, 0, 0.1)
        bert_emb_fr = bert_emb_fr.detach()

        bert_input_ids_en = get_torch_variable_from_np(unlabeled_data_en['bert_input_ids'])
        bert_input_mask_en = get_torch_variable_from_np(unlabeled_data_en['bert_input_mask'])
        bert_out_positions_en = get_torch_variable_from_np(unlabeled_data_en['bert_out_positions'])

        bert_emb_en = self.model(bert_input_ids_en, attention_mask=bert_input_mask_en)
        bert_emb_en = bert_emb_en[0]
        # NOTE: unlike the French branch, [CLS]/[SEP] stripping is left disabled here.
        # bert_emb_en = bert_emb_en[:, 1:-1, :].contiguous().detach()
        bert_emb_en = bert_emb_en[torch.arange(bert_emb_en.size(0)).unsqueeze(-1),
                                  bert_out_positions_en].detach()

        for i in range(len(bert_emb_en)):
            if i >= len(actual_lens_en):
                print("error: batch longer than seq_len list")
                break
            for j in range(len(bert_emb_en[i])):
                if j >= actual_lens_en[i]:
                    bert_emb_en[i][j] = get_torch_variable_from_np(np.zeros(768, dtype="float32"))
        # bert_emb_en = gaussian(bert_emb_en, isTrain, 0, 0.1)
        bert_emb_en = bert_emb_en.detach()

    seq_len = flag_emb.shape[1]
    SRL_output = self.SR_Labeler(bert_emb_en, flag_emb.detach(), predicates_1D, seq_len,
                                 para=True, use_bert=True)
    SRL_input = SRL_output.view(self.batch_size, seq_len, -1)
    pred_recur = self.SR_Compressor(SRL_input.detach(), bert_emb_en, flag_emb.detach(), None,
                                    predicates_1D, seq_len, para=True, use_bert=True)

    seq_len_fr = flag_emb_fr.shape[1]
    SRL_output_fr = self.SR_Labeler(bert_emb_fr, flag_emb_fr.detach(), predicates_1D_fr, seq_len_fr,
                                    para=True, use_bert=True)
    SRL_input_fr = SRL_output_fr.view(self.batch_size, seq_len_fr, -1)
    pred_recur_fr = self.SR_Compressor(SRL_input_fr, bert_emb_fr, flag_emb_fr.detach(), None,
                                       predicates_1D_fr, seq_len_fr, para=True, use_bert=True)

    # En event vector, En word.
    output_word_en_en = self.SR_Matcher(pred_recur.detach(), bert_emb_en, flag_emb.detach(), None,
                                        seq_len, para=True, use_bert=True)
    output_word_en_en = F.softmax(output_word_en_en, dim=1).detach()
    output_word_en_en_nonNull = torch.cat((output_word_en_en[:, 0:1], output_word_en_en[:, 2:]), 1)
    output_en_en_nonNull_max, output_en_en_nonNull_argmax = torch.max(output_word_en_en_nonNull, 1)

    # Fr event vector, En word.
    output_word_fr_en = self.SR_Matcher(pred_recur_fr, bert_emb_en, flag_emb.detach(), None,
                                        seq_len, para=True, use_bert=True)
    output_word_fr_en = F.softmax(output_word_fr_en, dim=1)

    # En event vector, Fr word.
    output_word_en_fr = self.SR_Matcher(pred_recur.detach(), bert_emb_fr, flag_emb_fr.detach(), None,
                                        seq_len_fr, para=True, use_bert=True)
    output_word_en_fr = F.softmax(output_word_en_fr, dim=1).detach()
    output_word_en_fr_nonNull = torch.cat((output_word_en_fr[:, 0:1], output_word_en_fr[:, 2:]), 1)
    output_en_fr_nonNull_max, output_en_fr_nonNull_argmax = torch.max(output_word_en_fr_nonNull, 1)

    # Fr event vector, Fr word.
    output_word_fr_fr = self.SR_Matcher(pred_recur_fr, bert_emb_fr, flag_emb_fr.detach(), None,
                                        seq_len_fr, para=True, use_bert=True)
    output_word_fr_fr = F.softmax(output_word_fr_fr, dim=1)
    output_word_fr_fr_nonNull = torch.cat((output_word_fr_fr[:, 0:1], output_word_fr_fr[:, 2:]), 1)
    # French probability at the role the English teacher considers most likely.
    output_word_fr_fr_nonNull_maxarg = torch.gather(output_word_fr_fr_nonNull, 1,
                                                    output_en_fr_nonNull_argmax.view(-1, 1))

    # (disabled) elementwise max of the two distributions:
    # max_enfr_fr = torch.max(output_word_fr_fr, output_word_en_fr).detach()
    # max_enfr_fr[:, :2] = output_word_fr_fr[:, :2].detach()

    unlabeled_loss_function = nn.L1Loss(reduction='none')

    # English words: penalize disagreement on the Null probability, but only at
    # tokens where the English teacher itself prefers Null over every non-Null role.
    loss = unlabeled_loss_function(output_word_fr_en[:, 1], output_word_en_en[:, 1])
    theta = torch.gt(output_word_en_en[:, 1], output_en_en_nonNull_max)
    loss = theta * loss
    if torch.gt(theta.sum(), 0):
        loss = loss.sum() / theta.sum()
    else:
        loss = loss.sum()

    # French words: align the French probability of the teacher's argmax role, but
    # only at tokens where the English teacher prefers a non-Null role.
    loss_2 = unlabeled_loss_function(output_word_fr_fr_nonNull_maxarg.view(-1),
                                     output_en_fr_nonNull_max)
    theta = torch.gt(output_en_fr_nonNull_max, output_word_en_fr[:, 1])
    loss_2 = theta * loss_2
    if torch.gt(theta.sum(), 0):
        loss_2 = loss_2.sum() / theta.sum()
    else:
        loss_2 = loss_2.sum()

    return loss, loss_2
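# The gating pattern above (an L1 penalty applied only where a condition on the
# teacher holds, then averaged over the gated tokens) can be read in isolation.
# A self-contained sketch with toy tensors; all values are illustrative.
import torch
import torch.nn as nn

# Per-token Null probabilities from a "teacher" and a "student".
teacher_null = torch.tensor([0.9, 0.2, 0.7])
student_null = torch.tensor([0.5, 0.1, 0.6])
# The teacher's best non-Null probability per token.
teacher_best_non_null = torch.tensor([0.05, 0.6, 0.2])

l1 = nn.L1Loss(reduction='none')
loss = l1(student_null, teacher_null)

# Gate: keep only tokens where the teacher prefers Null over every non-Null role.
theta = torch.gt(teacher_null, teacher_best_non_null)
loss = theta * loss
loss = loss.sum() / theta.sum() if theta.sum() > 0 else loss.sum()
print(loss)  # average L1 gap over the gated tokens (here: tokens 0 and 2)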
def word_trans(self, batch_input, use_bert, isTrain=True):
    unlabeled_data_en, unlabeled_data_fr = batch_input

    predicates_1D_fr = unlabeled_data_fr['predicates_idx']
    flag_batch_fr = get_torch_variable_from_np(unlabeled_data_fr['flag'])
    flag_emb_fr = self.flag_embedding(flag_batch_fr).detach()
    actual_lens_fr = unlabeled_data_fr['seq_len']

    predicates_1D = unlabeled_data_en['predicates_idx']
    flag_batch = get_torch_variable_from_np(unlabeled_data_en['flag'])
    actual_lens_en = unlabeled_data_en['seq_len']
    flag_emb = self.flag_embedding(flag_batch).detach()
    seq_len = flag_emb.shape[1]
    seq_len_en = seq_len

    if use_bert:
        bert_input_ids_fr = get_torch_variable_from_np(unlabeled_data_fr['bert_input_ids'])
        bert_input_mask_fr = get_torch_variable_from_np(unlabeled_data_fr['bert_input_mask'])
        bert_out_positions_fr = get_torch_variable_from_np(unlabeled_data_fr['bert_out_positions'])

        bert_emb_fr = self.model(bert_input_ids_fr, attention_mask=bert_input_mask_fr)
        bert_emb_fr = bert_emb_fr[0]
        bert_emb_fr = bert_emb_fr[:, 1:-1, :].contiguous().detach()
        bert_emb_fr = bert_emb_fr[torch.arange(bert_emb_fr.size(0)).unsqueeze(-1),
                                  bert_out_positions_fr].detach()

        for i in range(len(bert_emb_fr)):
            if i >= len(actual_lens_fr):
                print("error: batch longer than seq_len list")
                break
            for j in range(len(bert_emb_fr[i])):
                if j >= actual_lens_fr[i]:
                    bert_emb_fr[i][j] = get_torch_variable_from_np(np.zeros(768, dtype="float32"))
        bert_emb_fr_noise = gaussian(bert_emb_fr, isTrain, 0, 0.1).detach()
        bert_emb_fr = bert_emb_fr.detach()

        bert_input_ids_en = get_torch_variable_from_np(unlabeled_data_en['bert_input_ids'])
        bert_input_mask_en = get_torch_variable_from_np(unlabeled_data_en['bert_input_mask'])
        bert_out_positions_en = get_torch_variable_from_np(unlabeled_data_en['bert_out_positions'])

        bert_emb_en = self.model(bert_input_ids_en, attention_mask=bert_input_mask_en)
        bert_emb_en = bert_emb_en[0]
        bert_emb_en = bert_emb_en[:, 1:-1, :].contiguous().detach()
        bert_emb_en = bert_emb_en[torch.arange(bert_emb_en.size(0)).unsqueeze(-1),
                                  bert_out_positions_en].detach()

        for i in range(len(bert_emb_en)):
            if i >= len(actual_lens_en):
                print("error: batch longer than seq_len list")
                break
            for j in range(len(bert_emb_en[i])):
                if j >= actual_lens_en[i]:
                    bert_emb_en[i][j] = get_torch_variable_from_np(np.zeros(768, dtype="float32"))
        bert_emb_en_noise = gaussian(bert_emb_en, isTrain, 0, 0.1).detach()
        bert_emb_en = bert_emb_en.detach()

    # Predicate-position BERT vectors for each sentence in the batch.
    pred_bert_fr = bert_emb_fr[np.arange(0, self.batch_size), predicates_1D_fr]
    pred_bert_en = bert_emb_en[np.arange(0, self.batch_size), predicates_1D]

    En_Extracted = self.bert_FeatureExtractor(pred_bert_en)
    Fr_Extracted = self.bert_FeatureExtractor(pred_bert_fr)
    # (disabled) direct L2 alignment alternative:
    # l2loss = nn.MSELoss()(Fr_Extracted, En_Extracted)
    # return l2loss

    # Domain-discriminator loss: English features are "real" (1), French "fake" (0).
    x_D_real = En_Extracted.view(-1, 256)
    x_D_fake = Fr_Extracted.view(-1, 256)

    en_preds = self.Discriminator(x_D_real).view(self.batch_size, 2)
    real_labels = torch.empty((self.batch_size, 1), dtype=torch.long).fill_(1).view(-1)
    fr_preds = self.Discriminator(x_D_fake).view(self.batch_size, 2)
    fake_labels = torch.empty((self.batch_size, 1), dtype=torch.long).fill_(0).view(-1)

    preds = torch.cat((en_preds, fr_preds), 0)
    labels = torch.cat((real_labels, fake_labels)).to(device)
    criterion = nn.CrossEntropyLoss()
    loss = criterion(preds, labels)
    return loss
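# word_trans only computes the discriminator's classification loss; the
# adversarial setup it implies can be sketched in isolation. Layer sizes below
# (768 -> 256, matching the .view(-1, 256) above) and the layer choices are
# illustrative assumptions, not the repository's actual modules.
import torch
import torch.nn as nn

feature_extractor = nn.Sequential(nn.Linear(768, 256), nn.LeakyReLU(0.2))
discriminator = nn.Sequential(nn.Linear(256, 128), nn.LeakyReLU(0.2), nn.Linear(128, 2))

en = torch.randn(30, 768)   # predicate vectors, English ("real", label 1)
fr = torch.randn(30, 768)   # predicate vectors, French ("fake", label 0)

preds = discriminator(torch.cat((feature_extractor(en), feature_extractor(fr)), 0))
labels = torch.cat((torch.ones(30, dtype=torch.long), torch.zeros(30, dtype=torch.long)))
d_loss = nn.CrossEntropyLoss()(preds, labels)
print(d_loss.item())
# Training the extractor against this discriminator (e.g. with flipped labels)
# would push English and French predicate features toward a shared space.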
def forward(self, batch_input, lang='En', unlabeled=False, self_constrain=False,
            use_bert=False, isTrain=False):
    if unlabeled:
        loss, loss_2 = self.parallel_train_(batch_input, use_bert)
        return loss, loss_2
        # l2loss = self.word_trans(batch_input, use_bert)
        # return l2loss

    pretrain_batch = get_torch_variable_from_np(batch_input['pretrain'])
    predicates_1D = batch_input['predicates_idx']
    flag_batch = get_torch_variable_from_np(batch_input['flag'])
    word_id = get_torch_variable_from_np(batch_input['word_times'])
    word_id_emb = self.id_embedding(word_id)
    flag_emb = self.flag_embedding(flag_batch)
    actual_lens = batch_input['seq_len']

    if use_bert:
        bert_input_ids = get_torch_variable_from_np(batch_input['bert_input_ids'])
        bert_input_mask = get_torch_variable_from_np(batch_input['bert_input_mask'])
        bert_out_positions = get_torch_variable_from_np(batch_input['bert_out_positions'])

        bert_emb = self.model(bert_input_ids, attention_mask=bert_input_mask)
        bert_emb = bert_emb[0]
        bert_emb = bert_emb[:, 1:-1, :].contiguous().detach()
        bert_emb = bert_emb[torch.arange(bert_emb.size(0)).unsqueeze(-1),
                            bert_out_positions].detach()

        for i in range(len(bert_emb)):
            if i >= len(actual_lens):
                break
            for j in range(len(bert_emb[i])):
                if j >= actual_lens[i]:
                    bert_emb[i][j] = get_torch_variable_from_np(np.zeros(768, dtype="float32"))
        bert_emb = bert_emb.detach()
        # bert_emb_noise = gaussian(bert_emb, isTrain, 0, 0.1).detach()

    seq_len = flag_emb.shape[1]
    SRL_output = self.SR_Labeler(bert_emb, flag_emb, predicates_1D, seq_len,
                                 para=False, use_bert=True)
    SRL_input = SRL_output.view(self.batch_size, seq_len, -1)
    SRL_input_probs = F.softmax(SRL_input, 2).detach()

    # Compress role distributions into role vectors, then reconstruct word-level
    # role scores; copy mode is enabled only during training.
    role_embs = self.SR_Compressor(SRL_input_probs, bert_emb.detach(), flag_emb.detach(),
                                   word_id_emb, predicates_1D, seq_len,
                                   para=False, use_bert=True)
    output_word = self.SR_Matcher(role_embs, bert_emb.detach(), flag_emb.detach(),
                                  word_id_emb.detach(), seq_len, copy=isTrain,
                                  para=False, use_bert=True)

    # (disabled) Null-score insertion:
    # score4Null = torch.zeros_like(output_word[:, 1:2])
    # output_word = torch.cat((output_word[:, 0:1], score4Null, output_word[:, 1:]), 1)

    recover_loss = self.learn_loss(SRL_input, output_word, seq_len)
    return SRL_output, output_word, recover_loss
def forward(self, batch_input, lang='En', unlabeled=False, self_constrain=False,
            use_bert=False, isTrain=False):
    if unlabeled:
        loss, loss_2, copy_loss_en, copy_loss_fr, a, b = self.parallel_train_(batch_input, use_bert)
        return loss, loss_2, copy_loss_en, copy_loss_fr, a, b
        # l2loss = self.word_trans(batch_input, use_bert)
        # return l2loss

    pretrain_batch = get_torch_variable_from_np(batch_input['pretrain'])
    predicates_1D = batch_input['predicates_idx']
    flag_batch = get_torch_variable_from_np(batch_input['flag'])
    word_id = get_torch_variable_from_np(batch_input['word_times'])
    word_id_emb = self.id_embedding(word_id)
    flag_emb = self.flag_embedding(flag_batch)
    actual_lens = batch_input['seq_len']

    if use_bert:
        bert_input_ids = get_torch_variable_from_np(batch_input['bert_input_ids'])
        bert_input_mask = get_torch_variable_from_np(batch_input['bert_input_mask'])
        bert_out_positions = get_torch_variable_from_np(batch_input['bert_out_positions'])

        bert_emb = self.model(bert_input_ids, attention_mask=bert_input_mask)
        bert_emb = bert_emb[0]
        bert_emb = bert_emb[:, 1:-1, :].contiguous().detach()
        bert_emb = bert_emb[torch.arange(bert_emb.size(0)).unsqueeze(-1),
                            bert_out_positions].detach()

        for i in range(len(bert_emb)):
            if i >= len(actual_lens):
                break
            for j in range(len(bert_emb[i])):
                if j >= actual_lens[i]:
                    bert_emb[i][j] = get_torch_variable_from_np(np.zeros(768, dtype="float32"))
        bert_emb = bert_emb.detach()
        # bert_emb = self.bert_NonlinearTrans(bert_emb)

    if lang == "En":
        pretrain_emb = self.pretrained_embedding(pretrain_batch).detach()
        # bert_emb = self.En_LinearTrans(bert_emb).detach()
    else:
        pretrain_emb = self.fr_pretrained_embedding(pretrain_batch).detach()
        # bert_emb = self.Fr_LinearTrans(bert_emb).detach()
        # bert_emb = self.Fr2En_Trans(bert_emb).detach()

    seq_len = flag_emb.shape[1]
    if not use_bert:
        SRL_output = self.SR_Labeler(pretrain_emb, flag_emb, predicates_1D, seq_len, para=False)
        SRL_input = SRL_output.view(self.batch_size, seq_len, -1)
        pred_recur = self.SR_Compressor(SRL_input, pretrain_emb, flag_emb.detach(), word_id_emb,
                                        predicates_1D, seq_len, para=False)
        output_word = self.SR_Matcher(pred_recur, pretrain_emb, flag_emb.detach(),
                                      word_id_emb.detach(), seq_len, para=False)
    else:
        SRL_output = self.SR_Labeler(bert_emb, flag_emb, predicates_1D, seq_len,
                                     para=False, use_bert=True)
        SRL_input = SRL_output.view(self.batch_size, seq_len, -1)
        SRL_input_probs = F.softmax(SRL_input, 2).detach()
        if isTrain:
            pred_recur = self.SR_Compressor(SRL_input_probs, bert_emb.detach(), flag_emb.detach(),
                                            word_id_emb, predicates_1D, seq_len,
                                            para=False, use_bert=False)
            output_word = self.SR_Matcher(pred_recur, bert_emb.detach(), flag_emb.detach(),
                                          word_id_emb.detach(), seq_len, copy=True,
                                          para=False, use_bert=False)
        else:
            pred_recur = self.SR_Compressor(SRL_input_probs, bert_emb.detach(), flag_emb.detach(),
                                            word_id_emb, predicates_1D, seq_len,
                                            para=False, use_bert=False)
            output_word = self.SR_Matcher(pred_recur, bert_emb.detach(), flag_emb.detach(),
                                          word_id_emb.detach(), seq_len,
                                          para=False, use_bert=False)

    # Insert an explicit zero score for the Null role, then distill the labeler's
    # distribution into the reconstructed word-level scores (the "copy" loss).
    score4Null = torch.zeros_like(output_word[:, 1:2])
    output_word = torch.cat((output_word[:, 0:1], score4Null, output_word[:, 1:]), 1)

    teacher = F.softmax(SRL_input.view(self.batch_size * seq_len, -1), dim=1).detach()
    student = F.log_softmax(output_word, dim=1)
    unlabeled_loss_function = nn.KLDivLoss(reduction='none')
    loss_copy = unlabeled_loss_function(student, teacher).view(self.batch_size * seq_len, -1)
    loss_copy = loss_copy.sum() / (self.batch_size * seq_len)
    # CopyLoss = self.learn_loss(SRL_output, pretrain_emb, flag_emb.detach(), seq_len,
    #                            mask_copy, mask_unk)

    return SRL_output, output_word, loss_copy
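# The copy loss above is KL distillation: the labeler's softmax distribution is
# the (detached) teacher, the matcher's reconstruction the student. A
# self-contained sketch with toy shapes; batch_size * seq_len = 6 tokens and 4
# role classes are made-up values.
import torch
import torch.nn as nn
import torch.nn.functional as F

teacher_logits = torch.randn(6, 4)                      # stands in for SRL_input
student_logits = torch.randn(6, 4, requires_grad=True)  # stands in for output_word

teacher = F.softmax(teacher_logits, dim=1).detach()     # teacher carries no gradient
student = F.log_softmax(student_logits, dim=1)          # KLDivLoss expects log-probs

kl = nn.KLDivLoss(reduction='none')
loss_copy = kl(student, teacher).sum() / 6              # mean KL per token
loss_copy.backward()                                    # gradients reach the student only
print(loss_copy.item())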
        idx2word, shuffle=True)):
    target_argument = train_input_data['argument']
    flat_argument = train_input_data['flat_argument']
    gold_pos = train_input_data['gold_pos']
    gold_PI = train_input_data['predicates_flag']
    gold_deprel = train_input_data['sep_dep_rel']
    gold_link = train_input_data['sep_dep_link']

    target_batch_variable = get_torch_variable_from_np(flat_argument)
    gold_pos_batch_variable = get_torch_variable_from_np(gold_pos)
    gold_PI_batch_variable = get_torch_variable_from_np(gold_PI)
    gold_deprel_batch_variable = get_torch_variable_from_np(gold_deprel)
    gold_link_batch_variable = get_torch_variable_from_np(gold_link)

    bs = train_input_data['batch_size']
    sl = train_input_data['seq_len']

    out, out_pos, out_PI, out_deprel, out_link = srl_model(train_input_data, elmo)
    loss = criterion(out, target_batch_variable)