def action(self, sent, index, output, train=True):
    if train:
        action = sent.gold[index]
    else:
        max_score, max_index = torch.max(output, dim=1)
        # print(max_index)
        # print(max_score)
        action = self.id2gold[utils.to_scalar(max_index)]
    sent.action.append(action)
    pos_id = action.find('#')
    if pos_id == -1:
        # append action: glue the current character onto the last word
        last_word_record = sent.words_record[-1] + sent.chars[index]
        sent.words_record[-1] = last_word_record
    else:
        # separate action: start a new word and record its POS tag
        sent.words_record.append(sent.chars[index])
        pos_record = action[(pos_id + 1):]
        sent.pos_record.append(pos_record)
        sent.pos_index.append(self.pos2id[pos_record])
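# Hedged, standalone illustration (not part of the original classes): the
# action-string convention consumed by action() above. 'SEP#<POS>' starts a
# new word tagged <POS>; an action without '#' appends the character to the
# current word. The helper name and the example data below are made up.
def apply_actions(chars, actions):
    words, pos_tags = [], []
    for ch, act in zip(chars, actions):
        sep = act.find('#')
        if sep == -1:                      # append to the last word
            words[-1] += ch
        else:                              # separate: new word plus POS tag
            words.append(ch)
            pos_tags.append(act[sep + 1:])
    return words, pos_tags

# apply_actions(['下', '雨', '了'], ['SEP#VV', 'APP', 'SEP#SP'])
# -> (['下雨', '了'], ['VV', 'SP'])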
def forward(self, fea_v, length, target_start, target_end):
    if self.add_char:
        word_v = fea_v[0]
        char_v = fea_v[1]
    else:
        word_v = fea_v
    batch_size = word_v.size(0)
    seq_length = word_v.size(1)
    word_emb = self.embedding(word_v)
    word_emb = self.dropout_emb(word_emb)
    if self.static:
        word_static = self.embedding_static(word_v)
        word_static = self.dropout_emb(word_static)
        word_emb = torch.cat([word_emb, word_static], 2)
    x = torch.transpose(word_emb, 0, 1)
    packed_words = pack_padded_sequence(x, length)
    lstm_out, self.hidden = self.lstm(packed_words, self.hidden)
    lstm_out, _ = pad_packed_sequence(lstm_out)
    ##### lstm_out: (seq_len, batch_size, hidden_size)
    lstm_out = self.dropout_lstm(lstm_out)
    x = lstm_out
    ##### batch version
    # x = torch.squeeze(lstm_out, 1)
    # x: variable (seq_len, batch_size, hidden_size)
    # target_start: variable (batch_size)
    _, start = torch.max(target_start.unsqueeze(0), dim=1)
    max_start = utils.to_scalar(target_start[start])
    _, end = torch.min(target_end.unsqueeze(0), dim=1)
    min_end = utils.to_scalar(target_end[end])
    x = x.transpose(0, 1)
    left_save = []
    mask_left_save = []
    right_save = []
    mask_right_save = []
    target_save = []
    for idx in range(batch_size):
        x_len_cur = x[idx].size(0)
        start_cur = utils.to_scalar(target_start[idx])
        left_len_cur = start_cur
        left_len_max = max_start
        if start_cur != 0:
            x_cur_left = x[idx][:start_cur]
            left_len_sub = left_len_max - left_len_cur
            mask_cur_left = [1 for _ in range(left_len_cur)]
        else:
            x_cur_left = x[idx][0].unsqueeze(0)
            left_len_sub = left_len_max - 1
            # mask_cur_left = [-1e+20]
            mask_cur_left = [0]
        # x_cur_left: variable (start_cur, two_hidden_size)
        # mask_cur_left: list (start_cur)
        if start_cur < max_start:
            add = Variable(torch.zeros(left_len_sub, self.lstm_hiddens))
            if self.use_cuda:
                add = add.cuda()
            x_cur_left = torch.cat([x_cur_left, add], 0)
            # x_cur_left: variable (max_start, two_hidden_size)
            left_save.append(x_cur_left.unsqueeze(0))
            # mask_cur_left.extend([-1e+20 for _ in range(left_len_sub)])
            mask_cur_left.extend([0 for _ in range(left_len_sub)])
            # mask_cur_left: list (max_start)
            mask_left_save.append(mask_cur_left)
        else:
            left_save.append(x_cur_left.unsqueeze(0))
            mask_left_save.append(mask_cur_left)
        end_cur = utils.to_scalar(target_end[idx])
        right_len_cur = x_len_cur - end_cur - 1
        right_len_max = x_len_cur - min_end - 1
        if (end_cur + 1) != x_len_cur:
            x_cur_right = x[idx][(end_cur + 1):]
            right_len_sub = right_len_max - right_len_cur
            mask_cur_right = [1 for _ in range(right_len_cur)]
        else:
            x_cur_right = x[idx][end_cur].unsqueeze(0)
            right_len_sub = right_len_max - right_len_cur - 1
            # mask_cur_right = [-1e+20]
            mask_cur_right = [0]
        # x_cur_right: variable ((x_len_cur-end_cur-1), two_hidden_size)
        # mask_cur_right: list (x_len_cur-end_cur-1==right_len)
        if end_cur > min_end:
            add = Variable(torch.zeros(right_len_sub, self.lstm_hiddens))
            if self.use_cuda:
                add = add.cuda()
            x_cur_right = torch.cat([x_cur_right, add], 0)
            right_save.append(x_cur_right.unsqueeze(0))
            # mask_cur_right.extend([-1e+20 for _ in range(right_len_sub)])
            mask_cur_right.extend([0 for _ in range(right_len_sub)])
            mask_right_save.append(mask_cur_right)
        else:
            right_save.append(x_cur_right.unsqueeze(0))
            mask_right_save.append(mask_cur_right)
        # target_sub = end_cur-start_cur
        x_target = x[idx][start_cur:(end_cur + 1)]
        x_average_target = torch.mean(x_target, 0)
        target_save.append(x_average_target.unsqueeze(0))
    mask_left_save = Variable(torch.ByteTensor(mask_left_save))
    # mask_left_save: variable (batch_size, left_len_max)
    mask_right_save = Variable(torch.ByteTensor(mask_right_save))
    # mask_right_save: variable (batch_size, right_len_max)
    left_save = torch.cat(left_save, 0)
    right_save = torch.cat(right_save, 0)
    target_save = torch.cat(target_save, 0)
    # left_save: variable (batch_size, left_len_max, two_hidden_size)
    # right_save: variable (batch_size, right_len_max, two_hidden_size)
    # target_save: variable (batch_size, two_hidden_size)
    if self.use_cuda:
        mask_right_save = mask_right_save.cuda()
        mask_left_save = mask_left_save.cuda()
        left_save = left_save.cuda()
        right_save = right_save.cuda()
        target_save = target_save.cuda()
    s = self.attention(x, target_save, None)
    s_l = self.attention_l(left_save, target_save, mask_left_save)
    s_r = self.attention_r(right_save, target_save, mask_right_save)
    result = self.linear(s)
    # result: variable (batch_size, label_num)
    # result = self.linear_l(s_l)
    result = torch.add(result, self.linear_l(s_l))
    result = torch.add(result, self.linear_r(s_r))
    # result: variable (batch_size, label_num)
    # print(result)
    return result
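# Hedged sketch (an assumption, not the original module): one plausible
# implementation of the target-conditioned, mask-aware attention that
# self.attention / self.attention_l / self.attention_r are called with above,
# i.e. attention(inputs, target, mask) -> (batch, hidden). The class name,
# scoring function, and parameterization are made up.
import torch
import torch.nn as nn
import torch.nn.functional as F

class TargetAttention(nn.Module):
    def __init__(self, hidden_size):
        super(TargetAttention, self).__init__()
        self.proj = nn.Linear(hidden_size * 2, hidden_size)
        self.score = nn.Linear(hidden_size, 1, bias=False)

    def forward(self, inputs, target, mask=None):
        # inputs: (batch, seq_len, hidden)  context representations
        # target: (batch, hidden)           pooled target representation
        # mask:   (batch, seq_len), 1 for real positions, 0 for padding
        target_exp = target.unsqueeze(1).expand_as(inputs)
        e = self.score(F.tanh(self.proj(torch.cat([inputs, target_exp], 2))))
        e = e.squeeze(2)                              # (batch, seq_len)
        if mask is not None:
            e = e.masked_fill(mask == 0, -1e20)       # ignore padded slots
        alpha = F.softmax(e, dim=1)                   # attention weights
        return torch.bmm(alpha.unsqueeze(1), inputs).squeeze(1)   # (batch, hidden)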
def forward(self, fea_v, length, target_start, target_end, label_v):
    word_v = fea_v
    batch_size = word_v.size(0)
    seq_length = word_v.size(1)
    word_emb = self.embedding(word_v)
    word_emb = self.dropout_emb(word_emb)
    label_emb = self.embedding_label(label_v)
    label_emb = self.dropout_emb(label_emb)
    # x = torch.cat([word_emb, label_emb], 2)
    x = torch.transpose(word_emb, 0, 1)
    # x = torch.transpose(x, 0, 1)
    # lstm_out, _ = self.lstm(x)
    packed_words = pack_padded_sequence(x, length)
    lstm_out, self.hidden = self.lstm(packed_words, self.hidden)
    lstm_out, _ = pad_packed_sequence(lstm_out)
    ##### lstm_out: (seq_len, batch_size, hidden_size)
    lstm_out = self.dropout_lstm(lstm_out)
    x = lstm_out
    # print(x)
    x = x.transpose(0, 1)
    ##### batch version
    # x = torch.squeeze(lstm_out, 1)
    # x: variable (seq_len, batch_size, hidden_size)
    # target_start: variable (batch_size)
    # _, start = torch.max(target_start.unsqueeze(0), dim=1)
    # max_start = utils.to_scalar(target_start[start])
    # _, end = torch.min(target_end.unsqueeze(0), dim=1)
    # min_end = utils.to_scalar(target_end[end])
    max_length = 0
    for index in range(batch_size):
        x_len = x[index].size(0)
        start = utils.to_scalar(target_start[index])
        end = utils.to_scalar(target_end[index])
        none_t = x_len - (end - start + 1)
        if none_t > max_length:
            max_length = none_t
    none_target = []
    mask_none_target = []
    target_save = []
    for idx in range(batch_size):
        mask_none_t = []
        none_t = None
        x_len_cur = x[idx].size(0)
        start_cur = utils.to_scalar(target_start[idx])
        end_cur = utils.to_scalar(target_end[idx])
        x_target = x[idx][start_cur:(end_cur + 1)]
        x_average_target = torch.mean(x_target, 0)
        target_save.append(x_average_target.unsqueeze(0))
        if start_cur != 0:
            left = x[idx][:start_cur]
            none_t = left
            mask_none_t.extend([1] * start_cur)
        if end_cur != (x_len_cur - 1):
            right = x[idx][(end_cur + 1):]
            if none_t is not None:
                none_t = torch.cat([none_t, right], 0)
            else:
                none_t = right
            mask_none_t.extend([1] * (x_len_cur - end_cur - 1))
        if len(mask_none_t) != max_length:
            add_t = Variable(torch.zeros((max_length - len(mask_none_t)), self.lstm_hiddens))
            if self.use_cuda:
                add_t = add_t.cuda()
            mask_none_t.extend([0] * (max_length - len(mask_none_t)))
            # print(add_t)
            none_t = torch.cat([none_t, add_t], 0)
        mask_none_target.append(mask_none_t)
        none_target.append(none_t.unsqueeze(0))
    target_save = torch.cat(target_save, 0)
    # print(none_target)
    none_target = torch.cat(none_target, 0)
    mask_none_target = Variable(torch.ByteTensor(mask_none_target))
    # target_save: variable (batch_size, two_hidden_size)
    if self.use_cuda:
        target_save = target_save.cuda()
        mask_none_target = mask_none_target.cuda()
        none_target = none_target.cuda()
    # squence = torch.cat(none_target, 1)
    # s = self.attention(x, target_save, None)
    s = self.attention(none_target, target_save, mask_none_target)
    # print(s)
    ##### s: variable (batch_size, lstm_hiddens)
    result = self.linear(s)
    # result: variable (batch_size, label_num)
    return result
def forward(self, fea_v, length, target_start, target_end, label_v):
    word_v = fea_v
    batch_size = word_v.size(0)
    seq_length = word_v.size(1)
    word_emb = self.embedding(word_v)
    word_emb = self.dropout_emb(word_emb)
    # label_emb = self.embedding_label(label_v)
    # label_emb = self.dropout_emb(label_emb)
    # x = torch.cat([word_emb, label_emb], 2)
    # x = torch.transpose(x, 0, 1)
    x = torch.transpose(word_emb, 0, 1)
    packed_words = pack_padded_sequence(x, length)
    lstm_out, self.hidden = self.lstm(packed_words, self.hidden)
    lstm_out, _ = pad_packed_sequence(lstm_out)
    ##### lstm_out: (seq_len, batch_size, hidden_size)
    lstm_out = self.dropout_lstm(lstm_out)
    x = lstm_out
    x = x.transpose(0, 1)
    ##### batch version
    # x = torch.squeeze(lstm_out, 1)
    # x: variable (seq_len, batch_size, hidden_size)
    # target_start: variable (batch_size)
    _, start = torch.max(target_start.unsqueeze(0), dim=1)
    max_start = utils.to_scalar(target_start[start])
    _, end = torch.min(target_end.unsqueeze(0), dim=1)
    min_end = utils.to_scalar(target_end[end])

    # longest non-target (left + right context) length in the batch
    max_length = 0
    for index in range(batch_size):
        x_len = x[index].size(0)
        start = utils.to_scalar(target_start[index])
        end = utils.to_scalar(target_end[index])
        none_t = x_len - (end - start + 1)
        if none_t > max_length:
            max_length = none_t

    left_save = []
    mask_left_save = []
    right_save = []
    mask_right_save = []
    target_save = []
    none_target = []
    mask_none_target = []
    target_save_mean = []
    for idx in range(batch_size):
        mask_none_t = []
        none_t = None
        x_len_cur = x[idx].size(0)
        start_cur = utils.to_scalar(target_start[idx])
        end_cur = utils.to_scalar(target_end[idx])
        # mean pooling over the target span
        x_target = x[idx][start_cur:(end_cur + 1)]
        x_target_mean = torch.mean(x_target, 0)
        target_save_mean.append(x_target_mean.unsqueeze(0))
        # learned pooling over the target span
        x_target_fc = F.tanh(self.linear_try(x_target))
        beta = torch.mm(x_target_fc, self.u_try)
        beta = torch.t(beta)
        alpha = F.softmax(beta, dim=1)
        # alpha: variable (1, len)
        x_average_target = torch.mm(alpha, x_target)
        x_average_target = x_average_target.squeeze(0)
        # print(x_average_target)
        target_save.append(x_average_target.unsqueeze(0))
        # combined non-target context (left + right), padded to max_length
        if start_cur != 0:
            left = x[idx][:start_cur]
            none_t = left
            mask_none_t.extend([1] * start_cur)
        if end_cur != (x_len_cur - 1):
            right = x[idx][(end_cur + 1):]
            if none_t is not None:
                none_t = torch.cat([none_t, right], 0)
            else:
                none_t = right
            mask_none_t.extend([1] * (x_len_cur - end_cur - 1))
        if len(mask_none_t) != max_length:
            add_t = Variable(torch.zeros((max_length - len(mask_none_t)), self.lstm_hiddens))
            if self.use_cuda:
                add_t = add_t.cuda()
            mask_none_t.extend([0] * (max_length - len(mask_none_t)))
            # print(add_t)
            none_t = torch.cat([none_t, add_t], 0)
        mask_none_target.append(mask_none_t)
        none_target.append(none_t.unsqueeze(0))
        # left context, padded to max_start
        left_len_cur = start_cur
        left_len_max = max_start
        if start_cur != 0:
            x_cur_left = x[idx][:start_cur]
            left_len_sub = left_len_max - left_len_cur
            mask_cur_left = [1 for _ in range(left_len_cur)]
        else:
            x_cur_left = x[idx][0].unsqueeze(0)
            left_len_sub = left_len_max - 1
            # mask_cur_left = [-1e+20]
            mask_cur_left = [0]
        # x_cur_left: variable (start_cur, two_hidden_size)
        if start_cur < max_start:
            add = Variable(torch.zeros(left_len_sub, self.lstm_hiddens))
            if self.use_cuda:
                add = add.cuda()
            x_cur_left = torch.cat([x_cur_left, add], 0)
            # x_cur_left: variable (max_start, two_hidden_size)
            left_save.append(x_cur_left.unsqueeze(0))
            # mask_cur_left.extend([-1e+20 for _ in range(left_len_sub)])
            mask_cur_left.extend([0 for _ in range(left_len_sub)])
            # mask_cur_left: list (max_start)
            mask_left_save.append(mask_cur_left)
        else:
            left_save.append(x_cur_left.unsqueeze(0))
            mask_left_save.append(mask_cur_left)
        # right context, padded to (sentence length - min_end - 1)
        right_len_cur = x_len_cur - end_cur - 1
        right_len_max = x_len_cur - min_end - 1
        if (end_cur + 1) != x_len_cur:
            x_cur_right = x[idx][(end_cur + 1):]
            right_len_sub = right_len_max - right_len_cur
            mask_cur_right = [1 for _ in range(right_len_cur)]
        else:
            x_cur_right = x[idx][end_cur].unsqueeze(0)
            right_len_sub = right_len_max - right_len_cur - 1
            # mask_cur_right = [-1e+20]
            mask_cur_right = [0]
        # x_cur_right: variable ((x_len_cur-end_cur-1), two_hidden_size)
        if end_cur > min_end:
            add = Variable(torch.zeros(right_len_sub, self.lstm_hiddens))
            if self.use_cuda:
                add = add.cuda()
            x_cur_right = torch.cat([x_cur_right, add], 0)
            right_save.append(x_cur_right.unsqueeze(0))
            # mask_cur_right.extend([-1e+20 for _ in range(right_len_sub)])
            mask_cur_right.extend([0 for _ in range(right_len_sub)])
            mask_right_save.append(mask_cur_right)
        else:
            right_save.append(x_cur_right.unsqueeze(0))
            mask_right_save.append(mask_cur_right)

    mask_left_save = Variable(torch.ByteTensor(mask_left_save))
    # mask_left_save: variable (batch_size, left_len_max)
    mask_right_save = Variable(torch.ByteTensor(mask_right_save))
    # mask_right_save: variable (batch_size, right_len_max)
    left_save = torch.cat(left_save, 0)
    right_save = torch.cat(right_save, 0)
    target_save = torch.cat(target_save, 0)
    # left_save: variable (batch_size, left_len_max, two_hidden_size)
    # right_save: variable (batch_size, right_len_max, two_hidden_size)
    # target_save: variable (batch_size, two_hidden_size)
    none_target = torch.cat(none_target, 0)
    mask_none_target = Variable(torch.ByteTensor(mask_none_target))
    target_save_mean = torch.cat(target_save_mean, 0)
    if self.use_cuda:
        mask_right_save = mask_right_save.cuda()
        mask_left_save = mask_left_save.cuda()
        left_save = left_save.cuda()
        right_save = right_save.cuda()
        target_save = target_save.cuda()
        mask_none_target = mask_none_target.cuda()
        none_target = none_target.cuda()
        target_save_mean = target_save_mean.cuda()
    # s = self.attention(x, target_save, None)
    s = self.attention(none_target, target_save_mean, mask_none_target)
    s_l = self.attention_l(left_save, target_save_mean, mask_left_save)
    s_r = self.attention_r(right_save, target_save_mean, mask_right_save)
    # s = torch.cat([s, target_save], 1)
    result = self.linear(s)
    # result = self.linear_l(s_l)
    result = torch.add(result, self.linear_l(s_l))
    result = torch.add(result, self.linear_r(s_r))
    # result: variable (batch_size, label_num)
    # print(result)
    return result
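# Hedged sketch (made-up helper, not part of the original classes): the learned
# target pooling used in the forward() above, written as a standalone function.
# It assumes linear_try is an nn.Linear(hidden, hidden) and u_try a (hidden, 1)
# parameter, matching how they are used above.
import torch
import torch.nn.functional as F

def pool_target(x_target, linear_try, u_try):
    # x_target: (target_len, hidden) LSTM states of the target span
    scores = torch.mm(F.tanh(linear_try(x_target)), u_try)   # (target_len, 1)
    alpha = F.softmax(torch.t(scores), dim=1)                 # (1, target_len)
    return torch.mm(alpha, x_target).squeeze(0)                # (hidden,)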
def train_segpos(train_insts, dev_insts, test_insts, encode, decode, config, params):
    print('training...')
    # keep the trainable parameters as lists so they can be reused for gradient clipping
    parameters_en = list(filter(lambda p: p.requires_grad, encode.parameters()))
    # optimizer_en = torch.optim.SGD(params=parameters_en, lr=config.learning_rate, momentum=0.9, weight_decay=config.decay)
    optimizer_en = torch.optim.Adam(params=parameters_en, lr=config.learning_rate, weight_decay=config.decay)
    parameters_de = list(filter(lambda p: p.requires_grad, decode.parameters()))
    # optimizer_de = torch.optim.SGD(params=parameters_de, lr=config.learning_rate, momentum=0.9, weight_decay=config.decay)
    optimizer_de = torch.optim.Adam(params=parameters_de, lr=config.learning_rate, weight_decay=config.decay)

    best_dev_f1_seg = float('-inf')
    best_dev_f1_pos = float('-inf')
    best_test_f1_seg = float('-inf')
    best_test_f1_pos = float('-inf')
    dev_eval_seg = Eval()
    dev_eval_pos = Eval()
    test_eval_seg = Eval()
    test_eval_pos = Eval()

    for epoch in range(config.maxIters):
        start_time = time.time()
        encode.train()
        decode.train()
        train_insts = utils.random_instances(train_insts)
        epoch_loss = 0
        train_buckets = params.generate_batch_buckets(config.train_batch_size, train_insts)

        for index in range(len(train_buckets)):
            batch_length = np.array([np.sum(mask) for mask in train_buckets[index][-1]])
            var_b, list_b, mask_v, length_v, gold_v = utils.patch_var(
                train_buckets[index], batch_length.tolist(), params)
            encode.zero_grad()
            decode.zero_grad()
            if mask_v.size(0) != config.train_batch_size:
                encode.hidden = encode.init_hidden(mask_v.size(0))
            else:
                encode.hidden = encode.init_hidden(config.train_batch_size)
            lstm_out = encode.forward(var_b, list_b, mask_v, batch_length.tolist())
            output, state = decode.forward(lstm_out, var_b, list_b, mask_v,
                                           batch_length.tolist(), is_train=True)
            ##### output: variable (batch_size, max_length, segpos_num)
            # num_total = output.size(0)*output.size(1)
            # output = output.contiguous().view(num_total, output.size(2))
            # gold_v = gold_v.view(num_total)
            gold_v = gold_v.view(output.size(0))
            loss = F.cross_entropy(output, gold_v)
            loss.backward()
            nn.utils.clip_grad_norm(parameters_en, max_norm=config.clip_grad)
            nn.utils.clip_grad_norm(parameters_de, max_norm=config.clip_grad)
            optimizer_en.step()
            optimizer_de.step()
            epoch_loss += utils.to_scalar(loss)

        print('\nepoch is {}, average loss is {} '.format(
            epoch, (epoch_loss / (config.train_batch_size * len(train_buckets)))))
        # update lr
        # adjust_learning_rate(optimizer, config.learning_rate / (1 + (epoch + 1) * config.decay))
        # acc = float(correct_num) / float(gold_num)
        # print('\nepoch is {}, accuracy is {}'.format(epoch, acc))
        print('the {} epoch training costs time: {} s '.format(epoch, time.time() - start_time))

        print('\nDev...')
        dev_eval_seg.clear()
        dev_eval_pos.clear()
        test_eval_seg.clear()
        test_eval_pos.clear()
        start_time = time.time()
        dev_f1_seg, dev_f1_pos = eval_batch(dev_insts, encode, decode, config, params,
                                            dev_eval_seg, dev_eval_pos)
        print('the {} epoch dev costs time: {} s'.format(epoch, time.time() - start_time))

        if dev_f1_seg > best_dev_f1_seg:
            best_dev_f1_seg = dev_f1_seg
            if dev_f1_pos > best_dev_f1_pos:
                best_dev_f1_pos = dev_f1_pos
            print('\nTest...')
            start_time = time.time()
            test_f1_seg, test_f1_pos = eval_batch(test_insts, encode, decode, config, params,
                                                  test_eval_seg, test_eval_pos)
            print('the {} epoch testing costs time: {} s'.format(epoch, time.time() - start_time))
            if test_f1_seg > best_test_f1_seg:
                best_test_f1_seg = test_f1_seg
            if test_f1_pos > best_test_f1_pos:
                best_test_f1_pos = test_f1_pos
            print('now, test fscore of seg is {}, test fscore of pos is {}, '
                  'best test fscore of seg is {}, best test fscore of pos is {}'
                  .format(test_f1_seg, test_f1_pos, best_test_f1_seg, best_test_f1_pos))
            torch.save(encode.state_dict(), config.save_encode_path)
            torch.save(decode.state_dict(), config.save_decode_path)
        else:
            if dev_f1_pos > best_dev_f1_pos:
                best_dev_f1_pos = dev_f1_pos
                print('\nTest...')
                start_time = time.time()
                test_f1_seg, test_f1_pos = eval_batch(test_insts, encode, decode, config, params,
                                                      test_eval_seg, test_eval_pos)
                print('the {} epoch testing costs time: {} s'.format(epoch, time.time() - start_time))
                if test_f1_seg > best_test_f1_seg:
                    best_test_f1_seg = test_f1_seg
                if test_f1_pos > best_test_f1_pos:
                    best_test_f1_pos = test_f1_pos
                print('now, test fscore of seg is {}, test fscore of pos is {}, '
                      'best test fscore of seg is {}, best test fscore of pos is {}'
                      .format(test_f1_seg, test_f1_pos, best_test_f1_seg, best_test_f1_pos))
                torch.save(encode.state_dict(), config.save_encode_path)
                torch.save(decode.state_dict(), config.save_decode_path)

        print('now, dev fscore of seg is {}, dev fscore of pos is {}, '
              'best dev fscore of seg is {}, best dev fscore of pos is {}, '
              'best test fscore of seg is {}, best test fscore of pos is {}'
              .format(dev_f1_seg, dev_f1_pos, best_dev_f1_seg, best_dev_f1_pos,
                      best_test_f1_seg, best_test_f1_pos))
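# Hedged sketch: one possible adjust_learning_rate helper matching the
# commented-out call in train_segpos above; the original helper is not shown
# here, so its exact behaviour is an assumption.
def adjust_learning_rate(optimizer, lr):
    # set every parameter group of the optimizer to the new learning rate
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr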
def action_batch(self, index, states, output, length):
    # (the gold-action path used at training time, with actions such as
    #  'SEP#P', 'SEP#CD', 'SEP#NN', is handled elsewhere; here the action is
    #  always decoded from the model output)
    max_score, max_index = torch.max(output, dim=1)
    # print(max_index)
    # action = [self.id2gold[utils.to_scalar(ele)] for ele in max_index]
    action = []
    for id, ele in enumerate(max_index):
        action_cur = self.id2gold[utils.to_scalar(ele)]
        if index >= length[id]:
            action_cur = '<pad>'
        action.append(action_cur)
    # print(action)
    if index == 0:
        for id, ele in enumerate(action):
            pos_id = ele.find('#')
            if pos_id == -1:
                print('action at the first index is invalid.')
            else:
                states.words_record[id].append(states.chars[index][id])
                pos_record = action[id][(pos_id + 1):]
                states.pos_record[id].append(pos_record)
                pos_index = self.pos2id[pos_record]
                states.pos_index[id].append(pos_index)
    # states.action.append(action)
    for id, ele in enumerate(action):
        states.action[id].append(ele)
    if index != 0:
        for id, ele in enumerate(action):
            if ele == '<pad>':
                states.words_record[id].append('<pad>')
                states.pos_record[id].append('<pad>')
                states.pos_index[id].append(self.posPadID)
            else:
                pos_id = ele.find('#')
                if pos_id == -1:
                    # append action; note chars and golds are organized differently
                    last_cur = states.words_record[id][-1] + states.chars[index][id]
                    states.words_record[id][-1] = last_cur
                else:
                    # separate action: new word plus POS tag
                    states.words_record[id].append(states.chars[index][id])
                    pos_record = action[id][(pos_id + 1):]
                    states.pos_record[id].append(pos_record)
                    pos_index = self.pos2id[pos_record]
                    states.pos_index[id].append(pos_index)
def forward(self, fea_v, length, target_start, target_end, label_v):
    if self.add_char:
        word_v = fea_v[0]
        char_v = fea_v[1]
    else:
        word_v = fea_v
    batch_size = word_v.size(0)
    seq_length = word_v.size(1)
    word_emb = self.embedding(word_v)
    word_emb = self.dropout_emb(word_emb)
    label_emb = self.embedding_label(label_v)
    label_emb = self.dropout_emb(label_emb)
    x = torch.cat([word_emb, label_emb], 2)
    x = torch.transpose(x, 0, 1)
    # x = torch.transpose(word_emb, 0, 1)
    packed_words = pack_padded_sequence(x, length)
    lstm_out, self.hidden = self.lstm(packed_words, self.hidden)
    lstm_out, _ = pad_packed_sequence(lstm_out)
    ##### lstm_out: (seq_len, batch_size, hidden_size)
    lstm_out = self.dropout_lstm(lstm_out)
    x = lstm_out
    ##### batch version
    # x: variable (seq_len, batch_size, hidden_size)
    # target_start: variable (batch_size)
    _, start = torch.max(target_start.unsqueeze(0), dim=1)
    max_start = utils.to_scalar(target_start[start])
    _, end = torch.min(target_end.unsqueeze(0), dim=1)
    min_end = utils.to_scalar(target_end[end])
    x = x.transpose(0, 1)

    # longest non-target (left + right context) length in the batch
    max_length = 0
    for index in range(batch_size):
        x_len = x[index].size(0)
        start = utils.to_scalar(target_start[index])
        end = utils.to_scalar(target_end[index])
        none_t = x_len - (end - start + 1)
        if none_t > max_length:
            max_length = none_t

    left_save = []
    mask_left_save = []
    right_save = []
    mask_right_save = []
    target_save_mean = []
    none_target = []
    mask_none_target = []
    for idx in range(batch_size):
        mask_none_t = []
        none_t = None
        x_len_cur = x[idx].size(0)
        start_cur = utils.to_scalar(target_start[idx])
        end_cur = utils.to_scalar(target_end[idx])
        # left context, padded to max_start
        left_len_cur = start_cur
        left_len_max = max_start
        if start_cur != 0:
            x_cur_left = x[idx][:start_cur]
            left_len_sub = left_len_max - left_len_cur
            mask_cur_left = [1 for _ in range(left_len_cur)]
        else:
            x_cur_left = x[idx][0].unsqueeze(0)
            left_len_sub = left_len_max - 1
            # mask_cur_left = [-1e+20]
            mask_cur_left = [0]
        # x_cur_left: variable (start_cur, two_hidden_size)
        if start_cur < max_start:
            if left_len_sub == 0:
                print('error')
            add = Variable(torch.rand(left_len_sub, self.lstm_hiddens))
            if self.use_cuda:
                add = add.cuda()
            x_cur_left = torch.cat([x_cur_left, add], dim=0)
            # x_cur_left: variable (max_start, two_hidden_size)
            left_save.append(x_cur_left.unsqueeze(0))
            # mask_cur_left.extend([-1e+20 for _ in range(left_len_sub)])
            mask_cur_left.extend([0 for _ in range(left_len_sub)])
            # mask_cur_left: list (max_start)
            mask_left_save.append(mask_cur_left)
        else:
            left_save.append(x_cur_left.unsqueeze(0))
            mask_left_save.append(mask_cur_left)
        # right context, padded to (sentence length - min_end - 1)
        right_len_cur = x_len_cur - end_cur - 1
        right_len_max = x_len_cur - min_end - 1
        if (end_cur + 1) != x_len_cur:
            x_cur_right = x[idx][(end_cur + 1):]
            right_len_sub = right_len_max - right_len_cur
            mask_cur_right = [1 for _ in range(right_len_cur)]
        else:
            x_cur_right = x[idx][end_cur].unsqueeze(0)
            right_len_sub = right_len_max - right_len_cur - 1
            # mask_cur_right = [-1e+20]
            mask_cur_right = [0]
        # x_cur_right: variable ((x_len_cur-end_cur-1), two_hidden_size)
        if end_cur > min_end:
            if right_len_sub == 0:
                print('error2')
            add = Variable(torch.rand(right_len_sub, self.lstm_hiddens))
            if self.use_cuda:
                add = add.cuda()
            x_cur_right = torch.cat([x_cur_right, add], dim=0)
            right_save.append(x_cur_right.unsqueeze(0))
            # mask_cur_right.extend([-1e+20 for _ in range(right_len_sub)])
            mask_cur_right.extend([0 for _ in range(right_len_sub)])
            mask_right_save.append(mask_cur_right)
        else:
            right_save.append(x_cur_right.unsqueeze(0))
            mask_right_save.append(mask_cur_right)
        # combined non-target context (left + right), padded to max_length
        if start_cur != 0:
            left = x[idx][:start_cur]
            none_t = left
            mask_none_t.extend([1] * start_cur)
        if end_cur != (x_len_cur - 1):
            right = x[idx][(end_cur + 1):]
            if none_t is not None:
                none_t = torch.cat([none_t, right], 0)
            else:
                none_t = right
            mask_none_t.extend([1] * (x_len_cur - end_cur - 1))
        if len(mask_none_t) != max_length:
            add_t = Variable(torch.zeros((max_length - len(mask_none_t)), self.lstm_hiddens))
            if self.use_cuda:
                add_t = add_t.cuda()
            mask_none_t.extend([0] * (max_length - len(mask_none_t)))
            # print(add_t)
            none_t = torch.cat([none_t, add_t], 0)
        mask_none_target.append(mask_none_t)
        none_target.append(none_t.unsqueeze(0))
        # mean-pooled target span
        x_target = x[idx][start_cur:(end_cur + 1)]
        x_average_target = torch.mean(x_target, 0)
        target_save_mean.append(x_average_target.unsqueeze(0))

    mask_left_save = Variable(torch.ByteTensor(mask_left_save))
    # mask_left_save: variable (batch_size, left_len_max)
    mask_right_save = Variable(torch.ByteTensor(mask_right_save))
    # mask_right_save: variable (batch_size, right_len_max)
    left_save = torch.cat(left_save, dim=0)
    right_save = torch.cat(right_save, dim=0)
    target_save_mean = torch.cat(target_save_mean, dim=0)
    none_target = torch.cat(none_target, 0)
    mask_none_target = Variable(torch.ByteTensor(mask_none_target))
    # left_save: variable (batch_size, left_len_max, two_hidden_size)
    # right_save: variable (batch_size, right_len_max, two_hidden_size)
    # target_save_mean: variable (batch_size, two_hidden_size)
    if self.use_cuda:
        mask_right_save = mask_right_save.cuda()
        mask_left_save = mask_left_save.cuda()
        left_save = left_save.cuda()
        right_save = right_save.cuda()
        target_save_mean = target_save_mean.cuda()
        none_target = none_target.cuda()
        mask_none_target = mask_none_target.cuda()

    # s, s_alpha = self.attention(x, target_save, None)
    # s_l, s_l_alpha = self.attention_l(left_save, target_save, mask_left_save)
    # s_r, s_r_alpha = self.attention_r(right_save, target_save, mask_right_save)
    # s = self.attention(x, target_save, None)
    s = self.attention(none_target, target_save_mean, mask_none_target)
    s_l = self.attention_l(left_save, target_save_mean, mask_left_save)
    s_r = self.attention_r(right_save, target_save_mean, mask_right_save)

    # dimension-wise gate over the three attention summaries
    w1s = torch.mm(self.w1, torch.t(s))
    u1t = torch.mm(self.u1, torch.t(target_save_mean))
    if self.use_cuda:
        w1s = w1s.cuda()
        u1t = u1t.cuda()
    if batch_size == self.batch_size:
        z = torch.exp(w1s + u1t + self.b1)
    else:
        z = torch.exp(w1s + u1t)
    z_all = z
    # z_all: variable (two_hidden_size, batch_size)
    z_all = z_all.unsqueeze(2)

    w2s = torch.mm(self.w2, torch.t(s_l))
    u2t = torch.mm(self.u2, torch.t(target_save_mean))
    if self.use_cuda:
        w2s = w2s.cuda()
        u2t = u2t.cuda()
    if batch_size == self.batch_size:
        z_l = torch.exp(w2s + u2t + self.b2)
    else:
        z_l = torch.exp(w2s + u2t)
    # print(z_all)
    # print(z_l)
    z_all = torch.cat([z_all, z_l.unsqueeze(2)], dim=2)

    w3s = torch.mm(self.w3, torch.t(s_r))
    u3t = torch.mm(self.u3, torch.t(target_save_mean))
    if self.use_cuda:
        w3s = w3s.cuda()
        u3t = u3t.cuda()
    if batch_size == self.batch_size:
        z_r = torch.exp(w3s + u3t + self.b3)
    else:
        z_r = torch.exp(w3s + u3t)
    z_all = torch.cat([z_all, z_r.unsqueeze(2)], dim=2)
    # z_all: variable (two_hidden_size, batch_size, 3)
    z_all = F.softmax(z_all, dim=2)
    # z_all = torch.t(z_all)
    z_all = z_all.permute(2, 1, 0)
    # z = torch.unsqueeze(z_all[:batch_size], 0)
    # z_l = torch.unsqueeze(z_all[batch_size:(2*batch_size)], 0)
    # z_r = torch.unsqueeze(z_all[(2*batch_size):], 0)
    # z = z_all[:batch_size]
    # z_l = z_all[batch_size:(2*batch_size)]
    # z_r = z_all[(2*batch_size):]
    z = z_all[0]
    z_l = z_all[1]
    z_r = z_all[2]
    ss = torch.mul(z, s)
    ss = torch.add(ss, torch.mul(z_l, s_l))
    ss = torch.add(ss, torch.mul(z_r, s_r))
    logit = self.linear_2(ss)
    # print(logit)
    # alpha = [s_alpha, s_l_alpha, s_r_alpha]
    return logit
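# Hedged, self-contained illustration (made-up shapes and values, not from the
# original code): the fusion above mixes the three attention summaries with a
# softmax over the "which context" axis, computed separately for every hidden
# dimension of every batch element.
def _gate_fusion_demo():
    import torch
    import torch.nn.functional as F
    hidden, batch = 4, 2
    s = torch.randn(batch, hidden)                     # non-target context summary
    s_l = torch.randn(batch, hidden)                   # left-context summary
    s_r = torch.randn(batch, hidden)                   # right-context summary
    scores = torch.randn(hidden, batch, 3)             # stand-in for exp(w*s + u*t + b) per context
    gate = F.softmax(scores, dim=2).permute(2, 1, 0)   # (3, batch, hidden); sums to 1 over contexts
    fused = gate[0] * s + gate[1] * s_l + gate[2] * s_r  # (batch, hidden)
    return fused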
def forward(self, fea_v, length, target_start, target_end):
    if self.add_char:
        word_v = fea_v[0]
        char_v = fea_v[1]
    else:
        word_v = fea_v
    batch_size = word_v.size(0)
    seq_length = word_v.size(1)
    word_emb = self.embedding(word_v)
    word_emb = self.dropout_emb(word_emb)
    if self.static:
        word_static = self.embedding_static(word_v)
        word_static = self.dropout_emb(word_static)
        word_emb = torch.cat([word_emb, word_static], 2)
    x = torch.transpose(word_emb, 0, 1)
    packed_words = pack_padded_sequence(x, length)
    lstm_out, self.hidden = self.lstm(packed_words, self.hidden)
    lstm_out, _ = pad_packed_sequence(lstm_out)
    ##### lstm_out: (seq_len, batch_size, hidden_size)
    lstm_out = self.dropout_lstm(lstm_out)
    x = lstm_out
    x = x.transpose(0, 1)
    ##### batch version
    # x = torch.squeeze(lstm_out, 1)
    # x: variable (seq_len, batch_size, hidden_size)
    # target_start: variable (batch_size)
    # _, start = torch.max(target_start.unsqueeze(0), dim=1)
    # max_start = utils.to_scalar(target_start[start])
    # _, end = torch.min(target_end.unsqueeze(0), dim=1)
    # min_end = utils.to_scalar(target_end[end])

    # longest non-target (left + right context) length in the batch
    max_length = 0
    for index in range(batch_size):
        x_len = x[index].size(0)
        start = utils.to_scalar(target_start[index])
        end = utils.to_scalar(target_end[index])
        none_t = x_len - (end - start + 1)
        if none_t > max_length:
            max_length = none_t

    # only the combined non-target context is used in this variant; the separate
    # left/right context attention branches are disabled here
    none_target = []
    mask_none_target = []
    target_save = []
    for idx in range(batch_size):
        mask_none_t = []
        none_t = None
        x_len_cur = x[idx].size(0)
        start_cur = utils.to_scalar(target_start[idx])
        end_cur = utils.to_scalar(target_end[idx])
        # mean-pooled target span
        x_target = x[idx][start_cur:(end_cur + 1)]
        x_average_target = torch.mean(x_target, 0)
        target_save.append(x_average_target.unsqueeze(0))
        # non-target context (left of start, right of end), padded to max_length
        if start_cur != 0:
            left = x[idx][:start_cur]
            none_t = left
            mask_none_t.extend([1] * start_cur)
        if end_cur != (x_len_cur - 1):
            right = x[idx][(end_cur + 1):]
            if none_t is not None:
                none_t = torch.cat([none_t, right], 0)
            else:
                none_t = right
            mask_none_t.extend([1] * (x_len_cur - end_cur - 1))
        if len(mask_none_t) != max_length:
            add_t = Variable(torch.zeros((max_length - len(mask_none_t)), self.lstm_hiddens))
            if self.use_cuda:
                add_t = add_t.cuda()
            mask_none_t.extend([0] * (max_length - len(mask_none_t)))
            # print(add_t)
            none_t = torch.cat([none_t, add_t], 0)
        mask_none_target.append(mask_none_t)
        none_target.append(none_t.unsqueeze(0))

    target_save = torch.cat(target_save, 0)
    # print(none_target)
    none_target = torch.cat(none_target, 0)
    mask_none_target = Variable(torch.ByteTensor(mask_none_target))
    # none_target: variable (batch_size, max_length, two_hidden_size)
    # target_save: variable (batch_size, two_hidden_size)
    if self.use_cuda:
        target_save = target_save.cuda()
        mask_none_target = mask_none_target.cuda()
        none_target = none_target.cuda()
    # squence = torch.cat(none_target, 1)
    s = self.attention(none_target, target_save, mask_none_target)
    # s = self.attention(x, target_save, None)
    # s_l = self.attention_l(left_save, target_save, mask_left_save)
    # s_r = self.attention_r(right_save, target_save, mask_right_save)
    result = self.linear(s)
    # result = torch.add(result, self.linear_l(s_l))
    # result = torch.add(result, self.linear_r(s_r))
    # result: variable (batch_size, label_num)
    # print(result)
    return result