def _pooling_avg(self, encoded_sentences, mask, len_seq_sent):
    mask = mask.unsqueeze(3).repeat(1, 1, 1, encoded_sentences.size(3))
    encoded_sentences = encoded_sentences * mask
    len_seq_sent = utils.cast_type(len_seq_sent, FLOAT, self.use_gpu)
    encoded_sentences = torch.div(torch.sum(encoded_sentences, dim=2),
                                  len_seq_sent.unsqueeze(2))
    return encoded_sentences, mask

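# --- Illustration (not part of the model) ---------------------------------
# A minimal standalone sketch of the masked-average pooling above, using toy
# shapes (batch=1, sents=1, tokens=3, dim=2); the tensor values are made up.
import torch

encoded = torch.tensor([[[[1.0, 2.0], [3.0, 4.0], [9.0, 9.0]]]])  # (1, 1, 3, 2)
mask = torch.tensor([[[1.0, 1.0, 0.0]]])                          # last token is padding
lens = torch.tensor([[2.0]])                                      # true length per sentence

masked = encoded * mask.unsqueeze(3)                # zero out padded positions
avg = torch.sum(masked, dim=2) / lens.unsqueeze(2)  # divide by the true length
print(avg)  # tensor([[[2., 3.]]]) -- mean of the two real tokens only
# ---------------------------------------------------------------------------
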
def update_cp_seq(len_seq, len_sents, config, cur_cp_ind, text_inputs, cp_seq_list):
    sent_mask = torch.sign(len_sents)  # (batch_size, len_sent)
    sent_mask = utils.cast_type(sent_mask, FLOAT, config.use_gpu)
    num_sents = sent_mask.sum(dim=1)  # (batch_size)

    for batch_ind, cur_doc_cp in enumerate(cur_cp_ind):
        cur_doc_len = int(len_seq[batch_ind])
        cur_sent_len = len_sents[batch_ind]  # (max_sent_num)
        cur_sent_num = int(num_sents[batch_ind])

        end_loc = 0
        prev_loc = 0
        cur_cp_seq = []
        cur_text_seq = text_inputs[batch_ind]
        for cur_sent_ind in range(cur_sent_num):
            end_loc += int(cur_sent_len[cur_sent_ind])
            cur_sent_seq = cur_text_seq[prev_loc:end_loc]  # get the current sentence seq from a doc
            cur_sent_cp = cur_doc_cp[cur_sent_ind]
            cur_cp_seq.append(int(cur_sent_seq[cur_sent_cp]))  # get seq id of Cp
            prev_loc = end_loc
        # end for
        cp_seq_list.append(cur_cp_seq)
    # end for

    return cp_seq_list

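# --- Illustration (not part of the model) ---------------------------------
# Toy sketch of the index arithmetic in update_cp_seq: slice each sentence out
# of a flat token sequence with per-sentence lengths, then read off the token
# at the preferred-center (Cp) index. All ids and lengths below are made up.
import torch

text_seq = torch.tensor([11, 12, 13, 21, 22, 31, 32, 33, 34])  # one flat doc
sent_lens = [3, 2, 4]                                          # tokens per sentence
cp_ind = [1, 0, 2]                                             # Cp index within each sentence

cp_seq, prev_loc = [], 0
for length, cp in zip(sent_lens, cp_ind):
    sent = text_seq[prev_loc:prev_loc + length]  # current sentence slice
    cp_seq.append(int(sent[cp]))                 # token id of the Cp
    prev_loc += length
print(cp_seq)  # [12, 21, 33]
# ---------------------------------------------------------------------------
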
def _pooling_doc_max(self, encoded_docs, mask, target_dim):
    mask = utils.cast_type(mask, FLOAT, self.use_gpu)
    mask = ((mask - 1) * 999).unsqueeze(target_dim + 1).repeat(
        1, 1, encoded_docs.size(target_dim + 1))
    encoded_docs = encoded_docs + mask
    encoded_docs = encoded_docs.max(dim=target_dim)[0]  # Batch * sent * dim
    return encoded_docs, mask

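# --- Illustration (not part of the model) ---------------------------------
# Standalone sketch of the max-pooling trick above: padded positions receive a
# large negative offset ((mask - 1) * 999) so max() never selects them. The
# values below are illustrative only.
import torch

encoded = torch.tensor([[[1.0, 2.0], [5.0, 0.5], [9.0, 9.0]]])  # (1, 3, 2)
mask = torch.tensor([[1.0, 1.0, 0.0]])                           # third row is padding

offset = ((mask - 1) * 999).unsqueeze(2)   # 0 for real items, -999 for padding
pooled = (encoded + offset).max(dim=1)[0]
print(pooled)  # tensor([[5., 2.]]) -- the padding row (9., 9.) is suppressed
# ---------------------------------------------------------------------------
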
def eval_measure_mse(self):
    cur_pred_list = torch.cat(self.pred_list, dim=0)
    reverted_pred = self._convert_to_origin_scale(cur_pred_list)
    reverted_pred = torch.round(reverted_pred)
    predicted = utils.cast_type(reverted_pred, LONG, self.use_gpu)

    cur_label_list = [label for row in self.label_list for label in row]
    cur_label_list = torch.LongTensor(cur_label_list)
    cur_label_list = utils.cast_type(cur_label_list, LONG, self.use_gpu)
    cur_label_list = cur_label_list.view(cur_label_list.shape[0], 1)

    correct = (predicted == cur_label_list).sum().item()
    total = predicted.size(0)
    accuracy = correct / total

    self.eval_reset()
    return accuracy

def __init__(self, config, vocab=None, key_vocab=None, ignore_vocab=None):
    super(PPLLoss, self).__init__()
    self.weight = None
    self.ignore = None

    if vocab is not None:
        if key_vocab is not None:
            self.logger.info("Use extra cost for key words")
            weight = np.ones(len(vocab))
            for key_w in key_vocab.values():
                weight[vocab.token2id.get(key_w)] = 10.0
            self.weight = cast_type(torch.from_numpy(weight), FLOAT,
                                    config.use_gpu)
        if ignore_vocab is not None:
            self.logger.info("Use extra vocab for ignore words")
            ignore = np.ones(len(vocab))
            for ignore_w in ignore_vocab.values():
                ignore[vocab.token2id.get(ignore_w)] = 0.0
            self.ignore = cast_type(torch.from_numpy(ignore), FLOAT,
                                    config.use_gpu)

def __init__(self, padding_idx, config, rev_vocab=None, key_vocab=None):
    super(NLLEntropy, self).__init__()
    self.padding_idx = padding_idx if padding_idx is not None else -100
    self.avg_type = config.avg_type

    if rev_vocab is None or key_vocab is None:
        self.weight = None
    else:
        self.logger.info("Use extra cost for key words")
        weight = np.ones(len(rev_vocab))
        for key in key_vocab:
            weight[rev_vocab[key]] = 10.0
        self.weight = cast_type(torch.from_numpy(weight), FLOAT,
                                config.use_gpu)

def gumbel_max(self, log_probs):
    """
    Obtain a sample via the Gumbel-max trick. Note this is not differentiable.
    :param log_probs: [batch_size x vocab_size]
    :return: [batch_size x 1] selected token IDs
    """
    sample = torch.Tensor(log_probs.size()).uniform_(0, 1)
    sample = cast_type(Variable(sample), FLOAT, self.use_gpu)

    # compute the gumbel sample
    matrix_u = -1.0 * torch.log(-1.0 * torch.log(sample))
    gumbel_log_probs = log_probs + matrix_u
    max_val, max_ids = torch.max(gumbel_log_probs, dim=-1, keepdim=True)
    return max_ids

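# --- Illustration (not part of the model) ---------------------------------
# Quick standalone check of the Gumbel-max trick used above: taking the argmax
# of (log_probs + Gumbel noise) draws a sample from the underlying categorical
# distribution. The probabilities below are illustrative only.
import torch

torch.manual_seed(0)
log_probs = torch.log(torch.tensor([[0.1, 0.2, 0.7]]))
counts = torch.zeros(3)
for _ in range(10000):
    u = torch.rand_like(log_probs)
    gumbel = -torch.log(-torch.log(u))        # standard Gumbel noise
    counts[(log_probs + gumbel).argmax(dim=-1)] += 1
print(counts / 10000)  # empirical frequencies close to tensor([0.1, 0.2, 0.7])
# ---------------------------------------------------------------------------
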
def forward(self, s=None, a_seq=None, mode=None, gen_type=None):
    a_seq = cast_type(a_seq, LONG, USE_GPU)
    batch_size = s.shape[0]
    # logging.info("s shape: {}".format(s.shape))
    dec_init_state = self.net(s).unsqueeze(0)
    # logging.info("h: {}, inp: {}".format(dec_init_state.shape, a_seq.shape))
    dec_outs, dec_last, dec_ctx = self.decoder(batch_size, a_seq[:, 0:-1],
                                               dec_init_state, mode=mode,
                                               gen_type=gen_type, beam_size=1)
    labels = a_seq[:, 1:].contiguous()
    enc_dec_nll = self.nll_loss(dec_outs, labels)
    return enc_dec_nll

def gen_positional_input(self, seq_x):
    pos_x = []
    for cur_batch in seq_x:
        cur_pos = []
        for ind, val in enumerate(cur_batch):
            if val != 0:
                cur_pos.append(ind + 1)
            else:
                cur_pos.append(0)
        pos_x.append(cur_pos)

    pos_input = torch.LongTensor(pos_x)
    pos_input = utils.cast_type(pos_input, LONG, self.use_gpu)
    return pos_input

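# --- Illustration (not part of the model) ---------------------------------
# A vectorized sketch of the same behavior as gen_positional_input, shown as a
# design alternative: non-pad tokens get 1-based positions, pad tokens (id 0)
# stay 0. The input values below are made up.
import torch

seq_x = torch.tensor([[5, 9, 0, 0], [7, 0, 0, 0]])
pos = torch.where(seq_x != 0,
                  torch.arange(1, seq_x.size(1) + 1).expand_as(seq_x),
                  torch.zeros_like(seq_x))
print(pos)  # tensor([[1, 2, 0, 0], [1, 0, 0, 0]])
# ---------------------------------------------------------------------------
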
def centering_attn(self, text_inputs, mask_input, len_sents, num_sents, tid):
    ## Parser stage1: determine forward-looking centers and preferred centers
    avg_sents_repr, fwrd_repr, batch_cp_ind = self.get_fwrd_centers(
        text_inputs, mask_input, len_sents)

    ## Parser stage2: decide backward centers
    back_repr = self.get_back_centers(avg_sents_repr, fwrd_repr)

    ## Parser stage3: construct hierarchical discourse segments
    batch_segMap = []
    batch_adj_mat = []
    batch_adj_list = []
    batch_root_list = []
    for ind_batch, cur_batch_repr in enumerate(back_repr):
        cur_sent_num = int(num_sents[ind_batch])

        ## Parser stage3-1: get structural information
        seg_map, adj_list, list_root_ds = self.get_disco_seg(
            cur_sent_num, ind_batch, fwrd_repr, cur_batch_repr)

        ## Parser stage3-2: make a tree structure using the information
        cur_tree = self.make_tree_stru(seg_map, adj_list, list_root_ds)

        ## Parser stage3-3: make a numpy array from the networkx tree
        cur_adj_mat = np.zeros((self.max_num_sents, self.max_num_sents))
        undir_tree = cur_tree.to_undirected()  # we make an undirected tree
        np_adj_mat = nx.to_numpy_matrix(undir_tree)
        cur_adj_mat[:np_adj_mat.shape[0], :np_adj_mat.shape[1]] = np_adj_mat

        ## store structures for statistical analysis
        batch_adj_mat.append(cur_adj_mat)
        batch_adj_list.append(adj_list)
        batch_root_list.append(list_root_ds)
        batch_segMap.append(list(seg_map.items()))
    # end for ind_batch

    # structural information which will be passed to the structure-aware transformer
    adj_mat = torch.from_numpy(np.array(batch_adj_mat))
    adj_mat = utils.cast_type(adj_mat, FLOAT, self.use_gpu)
    batch_cp_ind = batch_cp_ind.tolist()

    return adj_mat, avg_sents_repr, batch_adj_list, batch_root_list, batch_segMap, batch_cp_ind

def sent_repr_avg(self, batch_size, encoder_out, len_sents):
    sent_mask = torch.sign(len_sents)  # (batch_size, len_sent)
    num_sents = sent_mask.sum(dim=1)  # (batch_size)

    sent_repr = torch.zeros(batch_size, self.max_num_sents,
                            self.encoder_coh.encoder_out_size)
    sent_repr = utils.cast_type(sent_repr, FLOAT, self.use_gpu)
    for cur_ind_doc in range(batch_size):
        list_sent_len = len_sents[cur_ind_doc]
        cur_sent_num = int(num_sents[cur_ind_doc])
        cur_loc_sent = 0
        list_cur_doc_sents = []

        for cur_ind_sent in range(cur_sent_num):
            cur_sent_len = int(list_sent_len[cur_ind_sent])
            # cur_local_words = local_output_words[cur_batch, cur_ind_sent:end_sent, :]
            # cur_sent_repr = encoder_out[cur_ind_doc, cur_loc_sent + cur_sent_len - 1, :]  # pick the last representation of each sentence
            cur_sent_repr = torch.div(
                torch.sum(encoder_out[cur_ind_doc,
                                      cur_loc_sent:cur_loc_sent + cur_sent_len],
                          dim=0),
                cur_sent_len)  # avg version
            cur_sent_repr = cur_sent_repr.view(1, 1, -1)  # restore to (1, 1, rnn_cell_size)

            list_cur_doc_sents.append(cur_sent_repr)
            cur_loc_sent = cur_loc_sent + cur_sent_len
        # end for cur_len_sent

        cur_sents_repr = torch.stack(list_cur_doc_sents, dim=1)  # (1, num_sents, 1, rnn_cell_size)
        cur_sents_repr = cur_sents_repr.squeeze(2)  # not needed when the last repr is used

        sent_repr[cur_ind_doc, :cur_sent_num, :] = cur_sents_repr
    # end for cur_doc

    return sent_repr

def sent_repr_avg(self, batch_size, encoder_out, len_sents):
    """Return sentence representations by averaging all words in each sentence."""
    mask_sent = torch.sign(len_sents)  # (batch_size, len_sent)
    num_sents = mask_sent.sum(dim=1)  # (batch_size)

    sent_repr = torch.zeros(batch_size, self.max_num_sents,
                            self.base_encoder.encoder_out_size)
    sent_repr = utils.cast_type(sent_repr, FLOAT, self.use_gpu)
    for cur_ind_doc in range(batch_size):
        list_sent_len = len_sents[cur_ind_doc]
        cur_sent_num = int(num_sents[cur_ind_doc])
        cur_loc_sent = 0
        list_cur_doc_sents = []

        for cur_ind_sent in range(cur_sent_num):
            cur_sent_len = int(list_sent_len[cur_ind_sent])
            cur_sent_repr = torch.div(
                torch.sum(encoder_out[cur_ind_doc,
                                      cur_loc_sent:cur_loc_sent + cur_sent_len],
                          dim=0),
                cur_sent_len)  # avg version
            cur_sent_repr = cur_sent_repr.view(1, 1, -1)  # restore to (1, 1, rnn_cell_size)

            list_cur_doc_sents.append(cur_sent_repr)
            cur_loc_sent = cur_loc_sent + cur_sent_len
        # end for cur_len_sent

        cur_sents_repr = torch.stack(list_cur_doc_sents, dim=1)  # (1, num_sents, 1, rnn_cell_size)
        cur_sents_repr = cur_sents_repr.squeeze(2)  # not needed when the last repr is used

        sent_repr[cur_ind_doc, :cur_sent_num, :] = cur_sents_repr
    # end for cur_doc

    return sent_repr

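# --- Illustration (not part of the model) ---------------------------------
# Toy version of the averaging in sent_repr_avg: the word vectors of one
# document are split by sentence lengths and mean-pooled. Values are made up.
import torch

encoder_out = torch.tensor([[1.0, 1.0], [3.0, 3.0], [10.0, 0.0]])  # 3 words, dim 2
len_sents = [2, 1]                                                  # two sentences

loc, reprs = 0, []
for n in len_sents:
    reprs.append(encoder_out[loc:loc + n].mean(dim=0))  # average words of one sentence
    loc += n
print(torch.stack(reprs))  # tensor([[ 2.,  2.], [10.,  0.]])
# ---------------------------------------------------------------------------
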
def forward(self, logits, temperature, use_gpu, hard=False, return_max_id=False):
    """
    :param logits: [batch_size, n_class] unnormalized log-prob
    :param temperature: non-negative scalar
    :param hard: if True, take argmax
    :return: [batch_size, n_class] sample from gumbel softmax
    """
    y = self.gumbel_softmax_sample(logits, temperature, use_gpu)
    _, y_hard = torch.max(y, dim=-1, keepdim=True)
    if hard:
        y_onehot = cast_type(torch.zeros(y.size()), FLOAT, use_gpu)
        y_onehot.scatter_(-1, y_hard, 1.0)
        y = y_onehot
    if return_max_id:
        return y, y_hard
    else:
        return y

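# --- Illustration (not part of the model) ---------------------------------
# Standalone sketch of the hard branch above: scatter a one-hot at the argmax
# of a soft sample. As a design note, a common straight-through variant keeps
# gradients flowing via (y_onehot - y).detach() + y; the code above returns
# the one-hot directly. The soft sample below is made up.
import torch

y = torch.tensor([[0.1, 0.7, 0.2]])           # a soft gumbel-softmax sample
_, y_hard = torch.max(y, dim=-1, keepdim=True)
y_onehot = torch.zeros_like(y).scatter_(-1, y_hard, 1.0)
print(y_onehot)  # tensor([[0., 1., 0.]])
# ---------------------------------------------------------------------------
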
def forward(self, text_inputs, mask_input, len_seq, len_sents, tid, mode=""):
    # if self.pad_level == "sent" or self.pad_level == "sentence":
    #     text_inputs = text_inputs.view(self.batch_size, self.max_num_sents * self.max_len_sent)
    text_inputs = text_inputs.view(text_inputs.size(0),
                                   text_inputs.size(1) * text_inputs.size(2))

    encoder_out = self.base_encoder(text_inputs, mask_input, len_seq)

    ## average the RNN output by sequence length; named the implicit lexical cohesion (ILC) vector
    len_seq = utils.cast_type(len_seq, FLOAT, self.use_gpu)
    ilc_vec = torch.div(torch.sum(encoder_out, dim=1),
                        len_seq.unsqueeze(1))  # (batch_size, rnn_cell_size)

    #### fully connected layers
    fc_out = self.linear_1(ilc_vec)
    fc_out = self.leak_relu(fc_out)
    fc_out = self.dropout_layer(fc_out)

    fc_out = self.linear_2(fc_out)
    fc_out = self.leak_relu(fc_out)
    fc_out = self.dropout_layer(fc_out)

    fc_out = self.linear_out(fc_out)

    if self.corpus_target.lower() == "asap":
        fc_out = self.sigmoid(fc_out)

    return fc_out

def forward(self, text_inputs, mask_input, len_seq, len_sents, tid,
            len_para=None, list_rels=None, mode=""):
    #### word level representations
    encoder_out = self.encoder_base(text_inputs, mask_input, len_seq)

    #### make sentence representations
    batch_size = text_inputs.size(0)
    sent_mask = torch.sign(len_sents)  # (batch_size, len_sent)
    sent_mask = utils.cast_type(sent_mask, FLOAT, self.use_gpu)
    num_sents = sent_mask.sum(dim=1)  # (batch_size)

    ### sentence level representations
    sent_repr = self.sent_repr_avg(batch_size, encoder_out, len_sents)

    # do not consider sentence level attention in this implementation
    encoded_sentences = self.encoder_sent.forward_skip(sent_repr, sent_mask,
                                                       num_sents)

    # structured attention for the document level
    encoded_documents, doc_attention_matrix = self.structure_att(
        encoded_sentences)  # get structured attn for sent

    # pooling over sentences (num_sents serves as the divisor for the avg case)
    if self.pooling_sent.lower() == "avg":
        encoded_documents, mask = self._pooling_avg(encoded_documents,
                                                    sent_mask, num_sents)
    elif self.pooling_sent.lower() == "max":
        encoded_documents, mask = self._pooling_doc_max(encoded_documents,
                                                        sent_mask, 1)

    ## fully connected layers
    fc_out = self.linear_out(encoded_documents)

    if self.corpus_target.lower() == "asap":
        fc_out = self.sigmoid(fc_out)

    model_outputs = []
    model_outputs.append(fc_out)

    # return fc_out
    return model_outputs

def get_back_centers(self, avg_sents_repr, fwrd_repr):
    """Determine backward-looking centers."""
    batch_size = avg_sents_repr.size(0)

    back_repr = torch.zeros(batch_size, self.max_num_sents, self.topk_back,
                            self.base_encoder.encoder_out_size)
    back_repr = utils.cast_type(back_repr, FLOAT, self.use_gpu)

    for sent_i in range(self.max_num_sents):
        if sent_i == 0 or sent_i == self.max_num_sents - 1:
            # there is no backward center for the first sentence
            continue
        else:
            prev_fwrd_repr = fwrd_repr[:, sent_i - 1, :, :]  # (batch_size, topk_fwrd, dim)
            cur_fwrd_repr = fwrd_repr[:, sent_i, :, :]  # (batch_size, topk_fwrd, dim)
            cur_sent_repr = avg_sents_repr[:, sent_i, :]  # (batch_size, dim)

            sim_rank = self.sim_cosine_d2(prev_fwrd_repr,
                                          cur_sent_repr.unsqueeze(1))
            max_sim_val, max_sim_ind = torch.max(sim_rank, dim=1)

            idx = max_sim_ind.view(-1, self.topk_back, 1).expand(
                max_sim_ind.size(0), self.topk_back,
                self.base_encoder.encoder_out_size)
            cur_back_repr = prev_fwrd_repr.gather(1, idx)

            back_repr[:, sent_i] = cur_back_repr
        # end else
    # end for sent_i

    return back_repr

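# --- Illustration (not part of the model) ---------------------------------
# Toy check of the gather pattern above: for each batch element, pick the
# forward-center vector most similar (by cosine) to the current sentence
# representation. Shapes and values below are illustrative only.
import torch
import torch.nn.functional as F

prev_fwrd = torch.randn(2, 4, 8)                  # (batch, topk_fwr, dim)
cur_sent = torch.randn(2, 8)                      # (batch, dim)

sim = F.cosine_similarity(prev_fwrd,
                          cur_sent.unsqueeze(1).expand_as(prev_fwrd), dim=2)
best = sim.argmax(dim=1)                          # (batch,)
idx = best.view(-1, 1, 1).expand(-1, 1, 8)        # broadcast the index over dim
picked = prev_fwrd.gather(1, idx)                 # (batch, 1, dim)
print(picked.shape)  # torch.Size([2, 1, 8])
# ---------------------------------------------------------------------------
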
def forward(self, text_inputs, mask_input, len_seq, len_sents, tid, mode=""):
    # if self.pad_level == "sent" or self.pad_level == "sentence":
    text_inputs = text_inputs.view(text_inputs.shape[0],
                                   self.max_num_sents * self.max_len_sent)
    mask = mask_input.view(text_inputs.shape)

    encoder_out = self.base_encoder(text_inputs, mask_input, len_seq)

    # apply conv1d after the rnn
    avg_pooled = torch.zeros(text_inputs.shape[0], text_inputs.shape[1],
                             self.conv_output_size)
    avg_pooled = utils.cast_type(avg_pooled, FLOAT, self.use_gpu)
    for cur_batch, cur_tensor in enumerate(encoder_out):
        ## actual-length version
        if self.target_model == "conll17_al":
            cur_seq_len = int(len_seq[cur_batch])
            cur_tensor = cur_tensor.unsqueeze(0)
            crop_tensor = cur_tensor.narrow(1, 0, cur_seq_len)
            crop_tensor = crop_tensor.transpose(1, 0)
            cur_tensor = crop_tensor
        ## published version: do not consider the actual length
        else:
            cur_tensor = cur_tensor.unsqueeze(1)

        # apply conv
        cur_tensor = self.conv(cur_tensor)
        cur_tensor = self.leak_relu(cur_tensor)
        cur_tensor = self.dropout_layer(cur_tensor)

        # cur_tensor = self.avg_pool_1d(cur_tensor)
        cur_tensor = self.avg_adapt_pool1(cur_tensor)
        cur_tensor = cur_tensor.view(cur_tensor.shape[0], self.conv_output_size)
        avg_pooled[cur_batch, :cur_tensor.shape[0], :] = cur_tensor

    len_seq = utils.cast_type(len_seq, FLOAT, self.use_gpu)

    ## attention implemented with a learned context parameter
    context_weight = self.context_weight.unsqueeze(1)
    context_weight = context_weight.expand(text_inputs.shape[0],
                                           self.conv_output_size, 1)
    attn_weight = torch.bmm(avg_pooled, context_weight).squeeze(2)
    attn_weight = self.tanh(attn_weight)
    attn_weight = self.softmax(attn_weight)

    # attention applied
    attn_vec = torch.bmm(avg_pooled.transpose(1, 2), attn_weight.unsqueeze(2))
    ilc_vec = attn_vec.squeeze(2)

    ## attention implemented with a linear layer (unused)
    # attn_vec = self.attn(encoder_out.view(self.batch_size, -1)).unsqueeze(2)
    # attn_vec = self.softmax(attn_vec)
    # ilc_vec_attn = torch.bmm(encoder_out.transpose(1, 2), attn_vec).squeeze(2)

    ## fully connected stage
    fc_out = self.linear_1(ilc_vec)
    fc_out = self.leak_relu(fc_out)
    fc_out = self.dropout_layer(fc_out)

    fc_out = self.linear_2(fc_out)
    fc_out = self.leak_relu(fc_out)
    fc_out = self.dropout_layer(fc_out)

    fc_out = self.linear_out(fc_out)

    if self.corpus_target.lower() == "asap":
        fc_out = self.sigmoid(fc_out)

    return fc_out

def validate(model, evaluator, dataset_test, config, loss_func, is_test=False):
    model.eval()
    losses = []

    sampler_test = (SequentialSampler(dataset_test) if config.local_rank == -1
                    else DistributedSampler(dataset_test))
    dataloader_test = DataLoader(dataset_test, sampler=sampler_test,
                                 batch_size=config.batch_size)

    adj_list = []
    root_ds_list = []
    seg_map_list = []
    cp_ind_list = []
    num_sents_list = []
    cp_seq_list = []
    tid_list = []
    label_list = []

    for text_inputs, label_y, *remains in dataloader_test:
        mask_input = remains[0]
        len_seq = remains[1]
        len_sents = remains[2]
        tid = remains[3]
        cur_origin_score = remains[-1]  # may not be the origin score when the dataset does not need one; ignored in that case

        text_inputs = utils.cast_type(text_inputs, LONG, config.use_gpu)
        mask_input = utils.cast_type(mask_input, FLOAT, config.use_gpu)
        len_seq = utils.cast_type(len_seq, FLOAT, config.use_gpu)

        with torch.no_grad():
            model_outputs = model(text_inputs=text_inputs, mask_input=mask_input,
                                  len_seq=len_seq, len_sents=len_sents, tid=tid,
                                  mode="")
            coh_score = model_outputs[0]

            if config.output_size == 1:
                coh_score = coh_score.view(text_inputs.shape[0])
            else:
                coh_score = coh_score.view(text_inputs.shape[0], -1)

            if config.output_size == 1:
                label_y = utils.cast_type(label_y, FLOAT, config.use_gpu)
            else:
                label_y = utils.cast_type(label_y, LONG, config.use_gpu)
            label_y = label_y.view(text_inputs.shape[0])

            if loss_func is not None:
                loss = loss_func(coh_score, label_y)
                losses.append(loss.item())

            evaluator.eval_update(coh_score, label_y, tid, cur_origin_score)

            # for the project of the centering transformer
            if config.gen_logs and config.target_model.lower() == "cent_attn":
                batch_adj_list = model_outputs[1]
                batch_root_ds = model_outputs[2]
                batch_seg_map = model_outputs[3]
                batch_cp_ind = model_outputs[4]
                batch_num_sents = model_outputs[5]

                adj_list = adj_list + batch_adj_list
                root_ds_list = root_ds_list + batch_root_ds
                seg_map_list = seg_map_list + batch_seg_map
                cp_ind_list = cp_ind_list + batch_cp_ind
                num_sents_list = num_sents_list + batch_num_sents
                tid_list = tid_list + tid.flatten().tolist()
                label_list = label_list + label_y.flatten().tolist()

                cp_seq_list = update_cp_seq(len_seq, len_sents, config,
                                            model_outputs[4], text_inputs,
                                            cp_seq_list)
        # end with torch.no_grad()
    # end for batch_num

    eval_measure = evaluator.eval_measure(is_test)
    eval_best_val = None
    if is_test:
        eval_best_val = max(evaluator.eval_history)

    if loss_func is not None:
        valid_loss = sum(losses) / len(losses)
        if is_test:
            logger.info("Total valid loss {}".format(valid_loss))
    else:
        valid_loss = np.inf

    if is_test:
        logger.info("{} on Test {}".format(evaluator.eval_type, eval_measure))
        logger.info("Best {} on Test {}".format(evaluator.eval_type, eval_best_val))
    else:
        logger.info("{} on Valid {}".format(evaluator.eval_type, eval_measure))

    valid_itpt = [tid_list, label_list, adj_list, root_ds_list, seg_map_list,
                  cp_seq_list]

    return valid_loss, eval_measure, eval_best_val, valid_itpt

def get_fwrd_centers(self, text_inputs, mask_input, len_sents):
    """Determine forward-looking centers using an attention matrix in a PLM."""
    batch_size = text_inputs.size(0)

    fwrd_repr = torch.zeros(batch_size, self.max_num_sents, self.topk_fwr,
                            self.base_encoder.encoder_out_size)
    fwrd_repr = utils.cast_type(fwrd_repr, FLOAT, self.use_gpu)

    avg_sents_repr = torch.zeros(batch_size, self.max_num_sents,
                                 self.base_encoder.encoder_out_size)  # averaged sents repr in the sent level encoding
    avg_sents_repr = utils.cast_type(avg_sents_repr, FLOAT, self.use_gpu)

    batch_cp_ind = torch.zeros(batch_size, self.max_num_sents)  # only used for manual analysis later
    batch_cp_ind = utils.cast_type(batch_cp_ind, LONG, self.use_gpu)

    cur_ind = torch.zeros(batch_size, dtype=torch.int64)
    cur_ind = utils.cast_type(cur_ind, LONG, self.use_gpu)
    len_sents = utils.cast_type(len_sents, LONG, self.use_gpu)

    for sent_i in range(self.max_num_sents):
        cur_sent_lens = len_sents[:, sent_i]
        cur_max_len = int(torch.max(cur_sent_lens))

        if cur_max_len > 0:
            cur_sent_ids = torch.zeros(batch_size, cur_max_len, dtype=torch.int64)
            cur_sent_ids = utils.cast_type(cur_sent_ids, LONG, self.use_gpu)
            cur_mask = torch.zeros(batch_size, cur_max_len, dtype=torch.int64)
            cur_mask = utils.cast_type(cur_mask, FLOAT, self.use_gpu)

            prev_ind = cur_ind
            cur_ind = cur_ind + cur_sent_lens
            for batch_ind, sent_len in enumerate(cur_sent_lens):
                cur_loc = cur_ind[batch_ind]
                prev_loc = prev_ind[batch_ind]
                cur_sent_ids[batch_ind, :cur_loc - prev_loc] = text_inputs[batch_ind, prev_loc:cur_loc]
                cur_mask[batch_ind, :cur_loc - prev_loc] = mask_input[batch_ind, prev_loc:cur_loc]

            # encode each sentence
            cur_encoded = self.base_encoder(cur_sent_ids, cur_mask, cur_sent_lens)
            encoded_sent = cur_encoded[0]  # encoded output for the current sent
            attn_sent = cur_encoded[1]  # averaged attention for the current sent (batch, item, item)

            ## filter out: we do not consider special tokens and punctuation as a center; <.>, <sep>, and <cls>
            list_diag = []
            for batch_ind, cur_mat in enumerate(attn_sent):  # cur_mat: (item, item)
                cur_diag = torch.diag(cur_mat, diagonal=0)

                ## mask according to the length of each sentence
                cur_batch_sent_len = int(cur_sent_lens[batch_ind])  # i-th sentence in the batch
                if cur_batch_sent_len > 3:
                    cur_diag[cur_batch_sent_len - 3:] = 0  # also remove punctuation; XLNet puts SEP and CLS at the end of the sentence
                else:
                    cur_diag[cur_batch_sent_len - 2:] = 0  # only remove the special tokens
                list_diag.append(cur_diag)
            # end for

            attn_diag = torch.stack(list_diag)  # because torch.diag does not support batch

            ## select forward-looking centers by indices
            temp_fwr_centers, fwr_sort_ind = torch.sort(
                attn_diag, dim=1, descending=True)  # forward centers are selected by attn
            fwr_sort_ind = fwr_sort_ind[:, :self.topk_fwr]
            batch_cp_ind[:, sent_i] = fwr_sort_ind[:, 0]  # only consider the top-1 item for Cp

            # handle the exceptional case when the sent is shorter than topk
            fwr_centers = torch.zeros(batch_size, self.topk_fwr)
            fwr_centers = utils.cast_type(fwr_centers, LONG, self.use_gpu)
            fwr_centers[:, :fwr_sort_ind.size(1)] = fwr_sort_ind

            selected = encoded_sent.gather(
                1, fwr_centers.unsqueeze(-1).expand(batch_size, self.topk_fwr,
                                                    self.base_encoder.encoder_out_size))
            fwrd_repr[:, sent_i, :fwr_centers.size(1)] = selected

            ## make a sentence representation by averaging
            cur_sent_lens = cur_sent_lens + 1e-9  # prevent zero division
            cur_avg_repr = torch.div(torch.sum(encoded_sent, dim=1),
                                     cur_sent_lens.unsqueeze(1))
            avg_sents_repr[:, sent_i] = cur_avg_repr
        # end if
    # end for sent_i

    return avg_sents_repr, fwrd_repr, batch_cp_ind

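# --- Illustration (not part of the model) ---------------------------------
# Design note: the per-batch torch.diag loop in get_fwrd_centers could be
# replaced by the batched torch.diagonal, which extracts all self-attention
# diagonals at once. A minimal equivalence check with a random "attention"
# tensor:
import torch

attn = torch.rand(4, 7, 7)                              # (batch, item, item)
diag_loop = torch.stack([torch.diag(m) for m in attn])  # the loop version
diag_batched = torch.diagonal(attn, dim1=-2, dim2=-1)   # (batch, item)
print(torch.allclose(diag_loop, diag_batched))          # True
# ---------------------------------------------------------------------------
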
def forward(self, text_inputs, mask_input, len_seq, len_sents, tid,
            len_para=None, list_rels=None, mode=""):
    batch_size = text_inputs.size(0)

    #### stage1: sentence level representations
    sent_mask = torch.sign(len_sents)  # (batch_size, len_sent)
    sent_mask = utils.cast_type(sent_mask, FLOAT, self.use_gpu)
    num_sents = sent_mask.sum(dim=1)  # (batch_size)

    avg_sents_repr = torch.zeros(batch_size, self.max_num_sents,
                                 self.base_encoder.encoder_out_size)  # averaged sents repr in the sent level encoding
    avg_sents_repr = utils.cast_type(avg_sents_repr, FLOAT, self.use_gpu)

    cur_ind = torch.zeros(batch_size, dtype=torch.int64)
    cur_ind = utils.cast_type(cur_ind, LONG, self.use_gpu)
    len_sents = utils.cast_type(len_sents, LONG, self.use_gpu)

    for sent_i in range(self.max_num_sents):
        cur_sent_lens = len_sents[:, sent_i]
        cur_max_len = int(torch.max(cur_sent_lens))

        if cur_max_len > 0:
            cur_sent_ids = torch.zeros(batch_size, cur_max_len, dtype=torch.int64)
            cur_sent_ids = utils.cast_type(cur_sent_ids, LONG, self.use_gpu)
            cur_mask = torch.zeros(batch_size, cur_max_len, dtype=torch.int64)
            cur_mask = utils.cast_type(cur_mask, FLOAT, self.use_gpu)

            prev_ind = cur_ind
            cur_ind = cur_ind + cur_sent_lens
            for batch_ind, sent_len in enumerate(cur_sent_lens):
                cur_loc = cur_ind[batch_ind]
                prev_loc = prev_ind[batch_ind]
                cur_sent_ids[batch_ind, :cur_loc - prev_loc] = text_inputs[batch_ind, prev_loc:cur_loc]
                cur_mask[batch_ind, :cur_loc - prev_loc] = mask_input[batch_ind, prev_loc:cur_loc]

            cur_encoded = self.base_encoder(cur_sent_ids, cur_mask, cur_sent_lens)
            encoded_sent = cur_encoded[0]  # encoded output for the current sent

            cur_sent_lens = cur_sent_lens + 1e-9  # prevent zero division
            cur_avg_repr = torch.div(torch.sum(encoded_sent, dim=1),
                                     cur_sent_lens.unsqueeze(1))
            avg_sents_repr[:, sent_i] = cur_avg_repr

    # encode sentences
    encoded_sents = avg_sents_repr
    mask_sent = torch.arange(self.max_num_sents, device=num_sents.device).expand(
        len(num_sents), self.max_num_sents) < num_sents.unsqueeze(1)
    mask_sent = utils.cast_type(mask_sent, BOOL, self.use_gpu)
    num_sents = utils.cast_type(num_sents, FLOAT, self.use_gpu)

    #### stage2: update sentence representations using the tree transformer
    encoded_sents, break_probs = self.tt_encoder(
        encoded_sents, mask_sent)  # ['features'], ['node_order'], ['adjacency_list'], ['edge_order']

    #### stage3: document attention
    context_weight = self.context_weight.expand(encoded_sents.shape[0],
                                                encoded_sents.shape[2], 1)
    attn_weight = torch.bmm(encoded_sents, context_weight).squeeze(2)
    attn_weight = self.tanh(attn_weight)
    attn_weight = masked_softmax(attn_weight, sent_mask)

    # attention applied
    attn_vec = torch.bmm(encoded_sents.transpose(1, 2), attn_weight.unsqueeze(2))
    ilc_vec = attn_vec.squeeze(2)

    #### FC layer
    fc_out = self.linear_1(ilc_vec)
    fc_out = self.leak_relu(fc_out)
    fc_out = self.dropout_layer(fc_out)

    fc_out = self.linear_2(fc_out)
    fc_out = self.leak_relu(fc_out)
    fc_out = self.dropout_layer(fc_out)

    fc_out = self.linear_out(fc_out)

    if self.output_size == 1:
        fc_out = self.sigmoid(fc_out)

    outputs = []
    outputs.append(fc_out)

    # ### Sentence Ordering Task: sum_score (MSELoss) -- earlier variant, kept for reference
    # order_label_list = []
    # order_score_list = torch.empty(0)
    # for batch_i in range(batch_size):
    #     order_label = 0
    #     order_score = 0
    #     shuffled_sents = torch.randperm(int(num_sents[batch_i].item()))
    #     for sent_i in range(shuffled_sents.shape[0] - 1):
    #         sent_embed_1 = avg_sents_repr[batch_i, shuffled_sents[sent_i].item()]
    #         sent_embed_2 = avg_sents_repr[batch_i, shuffled_sents[sent_i + 1].item()]
    #         if shuffled_sents[sent_i].item() < shuffled_sents[sent_i + 1].item():
    #             order_label += 1
    #         so_fc_out = self.so_linear_1(torch.cat((sent_embed_1, sent_embed_2), dim=0))
    #         so_fc_out = self.so_linear_2(so_fc_out)
    #         so_fc_out = self.so_linear_out(so_fc_out)
    #         so_fc_out = self.sigmoid(so_fc_out)
    #         order_score += so_fc_out
    #     order_label_list.append(float(order_label))
    #     if order_score == 0:
    #         order_score = torch.Tensor([0.0]).cuda()
    #     if order_score_list.shape[0] == 0:
    #         order_score_list = order_score
    #     else:
    #         order_score_list = torch.cat((order_score_list, order_score), dim=0)

    ### Sentence Ordering Task: list_score (BCELoss)
    ### avg_sents_repr: [batch_size, sent_num]
    order_label_list = []
    order_score_list = []
    for batch_i in range(batch_size):
        order_label = []
        order_score = []
        shuffled_sents = torch.randperm(int(num_sents[batch_i].item()))
        for sent_i in range(shuffled_sents.shape[0] - 1):
            sent_embed_1 = encoded_sents[batch_i, shuffled_sents[sent_i].item()]
            sent_embed_2 = encoded_sents[batch_i, shuffled_sents[sent_i + 1].item()]
            if shuffled_sents[sent_i].item() < shuffled_sents[sent_i + 1].item():
                order_label.append(1.0)
            else:
                order_label.append(0.0)

            so_fc_out = self.so_linear_1(torch.cat((sent_embed_1, sent_embed_2), dim=0))
            # so_fc_out = self.leak_relu(so_fc_out)
            # so_fc_out = self.dropout_layer(so_fc_out)
            so_fc_out = self.so_linear_2(so_fc_out)
            # so_fc_out = self.leak_relu(so_fc_out)
            # so_fc_out = self.dropout_layer(so_fc_out)
            so_fc_out = self.so_linear_out(so_fc_out)
            so_fc_out = self.sigmoid(so_fc_out)
            order_score.append(so_fc_out)

        order_label_list.append(order_label)
        order_score_list.append(order_score)

    if len(order_label_list) != len(order_score_list):
        order_label_list = []
        order_score_list = []

    outputs.append(order_label_list)
    outputs.append(order_score_list)

    # return fc_out
    return outputs

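# --- Illustration (not part of the model) ---------------------------------
# Toy sketch of the pairwise ordering objective built above: each adjacent
# pair in a shuffled order yields a binary label (1 if the pair is already in
# the original order) and a sigmoid score, trained with BCE. The scores below
# are stand-ins for the model's outputs.
import torch
import torch.nn.functional as F

shuffled = torch.tensor([2, 0, 1])                 # a random sentence permutation
labels = torch.tensor([float(shuffled[i] < shuffled[i + 1])
                       for i in range(len(shuffled) - 1)])  # [0., 1.]
scores = torch.tensor([0.3, 0.8])                  # made-up sigmoid scores
print(F.binary_cross_entropy(scores, labels))      # tensor(0.2899)
# ---------------------------------------------------------------------------
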
def forward(self, batch_size, inputs=None, init_state=None, attn_context=None,
            mode=TEACH_FORCE, gen_type='greedy', beam_size=4):
    # sanity checks
    ret_dict = dict()

    h_0 = init_state.squeeze(0).unsqueeze(1).repeat(1, self.max_length - 1, 1)

    if mode == GEN:
        inputs = None

    if gen_type != 'beam':
        beam_size = 1

    if inputs is not None:
        decoder_input = inputs
    else:
        # prepare the BOS inputs
        with torch.no_grad():
            bos_var = Variable(torch.LongTensor([self.sos_id]))
        bos_var = cast_type(bos_var, LONG, self.use_gpu)
        decoder_input = bos_var.expand(batch_size * beam_size, 1)
        # logging.info("dec input: {}".format(decoder_input.shape))
        h_0 = init_state.squeeze(0).unsqueeze(1).repeat(beam_size, 1, 1)  # e.g. 12, 1, 100
        # logging.info("h0: {}".format(h_0.shape))

    if mode == GEN and gen_type == 'beam':
        # if beam search, repeat the initial states of the RNN
        if self.rnn_cell is nn.LSTM:
            h, c = init_state
            decoder_hidden = (self.repeat_state(h, batch_size, beam_size),
                              self.repeat_state(c, batch_size, beam_size))
        else:
            decoder_hidden = self.repeat_state(init_state, batch_size, beam_size)
    else:
        decoder_hidden = init_state

    decoder_outputs = []  # a list of logprobs
    sequence_symbols = []  # a list of word ids
    back_pointers = []  # a list of parent beam IDs
    lengths = np.array([self.max_length] * batch_size * beam_size)

    def decode(step, cum_sum, step_output, step_attn):
        decoder_outputs.append(step_output)
        step_output_slice = step_output.squeeze(1)
        if self.use_attention:
            ret_dict[DecoderRNN.KEY_ATTN_SCORE].append(step_attn)

        if gen_type == 'greedy':
            symbols = step_output_slice.topk(1)[1]
        elif gen_type == 'sample':
            symbols = self.gumbel_max(step_output_slice)
        elif gen_type == 'beam':
            if step == 0:
                seq_score = step_output_slice.view(batch_size, -1)
                seq_score = seq_score[:, 0:self.output_size]
            else:
                seq_score = cum_sum + step_output_slice
                seq_score = seq_score.view(batch_size, -1)

            top_v, top_id = seq_score.topk(beam_size)
            back_ptr = top_id.div(self.output_size).view(-1, 1)
            symbols = top_id.fmod(self.output_size).view(-1, 1)
            cum_sum = top_v.view(-1, 1)
            back_pointers.append(back_ptr)
        else:
            raise ValueError("Unsupported decoding mode")

        sequence_symbols.append(symbols)

        eos_batches = symbols.data.eq(self.eos_id)
        if eos_batches.dim() > 0:
            eos_batches = eos_batches.cpu().view(-1).numpy()
            update_idx = ((lengths > step) & eos_batches) != 0
            lengths[update_idx] = len(sequence_symbols)
        return cum_sum, symbols

    # Manual unrolling is used to support random teacher forcing.
    # If teacher_forcing_ratio is True or False instead of a probability,
    # the unrolling can be done in graph.
    if mode == TEACH_FORCE:
        decoder_output, decoder_hidden, attn = self.forward_step(
            decoder_input, decoder_hidden, attn_context, h_0)

        # in teacher-forcing mode, we don't need symbols
        decoder_outputs = decoder_output
    else:
        # do free running here
        cum_sum = None
        for di in range(self.max_length):
            decoder_output, decoder_hidden, step_attn = self.forward_step(
                decoder_input, decoder_hidden, attn_context, h_0)
            cum_sum, symbols = decode(di, cum_sum, decoder_output, step_attn)
            decoder_input = symbols

        decoder_outputs = torch.cat(decoder_outputs, dim=1)

        if gen_type == 'beam':
            # do back tracking here to recover the 1-best according to beam search
            final_seq_symbols = []
            cum_sum = cum_sum.view(-1, beam_size)
            max_seq_id = cum_sum.topk(1)[1].data.cpu().view(-1).numpy()
            rev_seq_symbols = sequence_symbols[::-1]
            rev_back_ptrs = back_pointers[::-1]

            for symbols, back_ptrs in zip(rev_seq_symbols, rev_back_ptrs):
                symbol2ds = symbols.view(-1, beam_size)
                back2ds = back_ptrs.view(-1, beam_size)
                selected_symbols = []
                selected_parents = []
                for b_id in range(batch_size):
                    selected_parents.append(back2ds[b_id, max_seq_id[b_id]])
                    selected_symbols.append(symbol2ds[b_id, max_seq_id[b_id]])

                final_seq_symbols.append(torch.stack(selected_symbols, 0).unsqueeze(1))
                max_seq_id = torch.stack(selected_parents).data.cpu().numpy()

            sequence_symbols = final_seq_symbols[::-1]

    # save the decoded sequence symbols and sequence lengths
    ret_dict[DecoderRNN.KEY_SEQUENCE] = sequence_symbols
    ret_dict[DecoderRNN.KEY_LENGTH] = lengths.tolist()

    return decoder_outputs, decoder_hidden, ret_dict

def forward(self, text_inputs, mask_input, len_seq, len_sents, tid,
            len_para=None, list_rels=None, mode=""):
    # mask_input: (batch, max_tokens), len_sents: (batch, max_num_sents)
    batch_size = text_inputs.size(0)

    mask_sent = torch.sign(len_sents)  # (batch_size, len_sent)
    mask_sent = utils.cast_type(mask_sent, FLOAT, self.use_gpu)
    num_sents = mask_sent.sum(dim=1)  # (batch_size)

    #### Stage1 and 2: sentence repr and discourse segments parser
    adj_mat, sent_repr, batch_adj_list, batch_root_list, batch_segMap, batch_cp_ind = self.centering_attn(
        text_inputs, mask_input, len_sents, num_sents, tid)

    # #### doc-level encoding of the input text (disable this part if GPU memory is not enough)
    # encoder_doc_out = self.base_encoder(text_inputs, mask_input, len_seq)
    # encoded_doc = encoder_doc_out[0]
    # if self.output_attentions:
    #     attn_doc_avg = encoder_doc_out[1]  # averaged mh attentions (batch, item, item)
    # mask_sent = torch.sign(len_sents)  # (batch_size, len_sent)
    # mask_sent = utils.cast_type(mask_sent, FLOAT, self.use_gpu)
    # num_sents = mask_sent.sum(dim=1)  # (batch_size)
    # sent_repr = self.sent_repr_avg(batch_size, encoded_doc, len_sents)

    #### Stage3: Structure-aware transformer
    mask_sent_tr = torch.arange(self.max_num_sents, device=num_sents.device).expand(
        len(num_sents), self.max_num_sents) < num_sents.unsqueeze(1)
    mask_sent_tr = utils.cast_type(mask_sent_tr, BOOL, self.use_gpu)
    encoded_sents, break_probs = self.tt_encoder(
        sent_repr, mask_sent_tr, adj_mat)  # ['features'], ['node_order'], ['adjacency_list'], ['edge_order']

    #### Stage4: Document attention
    context_weight = self.context_weight.expand(encoded_sents.shape[0],
                                                encoded_sents.shape[2], 1)
    attn_weight = torch.bmm(encoded_sents, context_weight).squeeze(2)
    attn_weight = self.tanh(attn_weight)
    attn_weight = masked_softmax(attn_weight, mask_sent)
    attn_vec = torch.bmm(encoded_sents.transpose(1, 2), attn_weight.unsqueeze(2))
    ilc_vec = attn_vec.squeeze(2)

    #### FC layer
    fc_out = self.linear_1(ilc_vec)
    fc_out = self.leak_relu(fc_out)
    fc_out = self.dropout_layer(fc_out)

    fc_out = self.linear_2(fc_out)
    fc_out = self.leak_relu(fc_out)
    fc_out = self.dropout_layer(fc_out)

    fc_out = self.linear_out(fc_out)

    if self.output_size == 1:
        fc_out = self.sigmoid(fc_out)

    # prepare output to return
    outputs = []
    if self.gen_logs:
        outputs.append(batch_adj_list)
        outputs.append(batch_root_list)
        outputs.append(batch_segMap)
        outputs.append(batch_cp_ind)
        outputs.append(num_sents.tolist())

    ### Sentence Ordering Task: list_score (BCELoss)
    ### avg_sents_repr: [batch_size, sent_num]
    ## training pass
    order_label_list = []
    order_score_list = []
    order_score_MaxMin_list = []
    for batch_i in range(batch_size):
        order_label = []
        order_score = []
        order_right_score = []
        order_score_MaxMin = []
        shuffled_sents = torch.randperm(int(num_sents[batch_i].item()))
        for sent_i in range(shuffled_sents.shape[0] - 1):
            sent_embed_1 = encoded_sents[batch_i, shuffled_sents[sent_i].item()]
            sent_embed_2 = encoded_sents[batch_i, shuffled_sents[sent_i + 1].item()]

            so_fc_out = self.so_linear_1(torch.cat((sent_embed_1, sent_embed_2), dim=0))
            # so_fc_out = self.leak_relu(so_fc_out)
            # so_fc_out = self.dropout_layer(so_fc_out)
            so_fc_out = self.so_linear_2(so_fc_out)
            # so_fc_out = self.leak_relu(so_fc_out)
            # so_fc_out = self.dropout_layer(so_fc_out)
            so_fc_out = self.so_linear_out(so_fc_out)
            so_fc_out = self.sigmoid(so_fc_out)
            order_score.append(so_fc_out)

            if shuffled_sents[sent_i].item() < shuffled_sents[sent_i + 1].item():
                order_label.append(1.0)
                order_right_score.append(so_fc_out)
            else:
                order_label.append(0.0)

        if len(order_right_score) != 0:
            # order_score_MaxMin.append(max(order_right_score))
            # order_score_MaxMin.append(min(order_right_score))
            order_score_MaxMin.append(sum(order_right_score) / len(order_right_score))

        order_label_list.append(order_label)
        order_score_list.append(order_score)
        order_score_MaxMin_list.append(order_score_MaxMin)

    if len(order_label_list) != len(order_score_list):
        order_label_list = []
        order_score_list = []
    if len(order_score_MaxMin_list) != batch_size:
        order_score_MaxMin_list = []

    outputs.append(fc_out)
    outputs.append(order_label_list)
    outputs.append(order_score_list)
    outputs.append(order_score_MaxMin_list)

    # return fc_out
    return outputs
# end def forward
# end class

def sample_gumbel(self, logits, use_gpu, eps=1e-20):
    u = torch.rand(logits.size())
    sample = -torch.log(-torch.log(u + eps) + eps)
    sample = cast_type(sample, FLOAT, use_gpu)
    return sample

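# --- Illustration (not part of the model) ---------------------------------
# Sanity check for the transform above: -log(-log(u)) with u ~ Uniform(0, 1)
# yields standard Gumbel noise, whose mean is the Euler-Mascheroni constant
# (about 0.5772).
import torch

torch.manual_seed(0)
u = torch.rand(100000)
g = -torch.log(-torch.log(u + 1e-20) + 1e-20)
print(g.mean())  # roughly 0.577
# ---------------------------------------------------------------------------
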
def forward(self, text_inputs, mask_input, len_seq, len_sents, tid,
            len_para=None, list_rels=None, mode=""):
    batch_size = text_inputs.size(0)

    #### stage1: sentence level representations
    sent_mask = torch.sign(len_sents)  # (batch_size, len_sent)
    sent_mask = utils.cast_type(sent_mask, FLOAT, self.use_gpu)
    num_sents = sent_mask.sum(dim=1)  # (batch_size); the sign trick yields the number of sentences per document

    avg_sents_repr = torch.zeros(batch_size, self.max_num_sents,
                                 self.base_encoder.encoder_out_size)  # averaged sents repr in the sent level encoding
    avg_sents_repr = utils.cast_type(avg_sents_repr, FLOAT, self.use_gpu)

    cur_ind = torch.zeros(batch_size, dtype=torch.int64)
    cur_ind = utils.cast_type(cur_ind, LONG, self.use_gpu)
    len_sents = utils.cast_type(len_sents, LONG, self.use_gpu)

    for sent_i in range(self.max_num_sents):
        cur_sent_lens = len_sents[:, sent_i]  # (batch_size); length of the sent_i-th sentence of each document in the batch
        cur_max_len = int(torch.max(cur_sent_lens))  # max sentence length at this position

        if cur_max_len > 0:  # i.e., not all padding at this position
            cur_sent_ids = torch.zeros(batch_size, cur_max_len, dtype=torch.int64)
            cur_sent_ids = utils.cast_type(cur_sent_ids, LONG, self.use_gpu)
            cur_mask = torch.zeros(batch_size, cur_max_len, dtype=torch.int64)
            cur_mask = utils.cast_type(cur_mask, FLOAT, self.use_gpu)

            prev_ind = cur_ind  # two pointers: prev_ind marks the start of the current sentence,
            cur_ind = cur_ind + cur_sent_lens  # cur_ind marks its end
            for batch_ind, sent_len in enumerate(cur_sent_lens):
                cur_loc = cur_ind[batch_ind]
                prev_loc = prev_ind[batch_ind]
                cur_sent_ids[batch_ind, :cur_loc - prev_loc] = text_inputs[batch_ind, prev_loc:cur_loc]
                cur_mask[batch_ind, :cur_loc - prev_loc] = mask_input[batch_ind, prev_loc:cur_loc]

            cur_encoded = self.base_encoder(cur_sent_ids, cur_mask, cur_sent_lens)
            encoded_sent = cur_encoded[0]  # encoded output for the current sent

            cur_sent_lens = cur_sent_lens + 1e-9  # prevent zero division
            cur_avg_repr = torch.div(torch.sum(encoded_sent, dim=1),
                                     cur_sent_lens.unsqueeze(1))
            avg_sents_repr[:, sent_i] = cur_avg_repr

    # encode sentences
    mask_sent = torch.arange(self.max_num_sents, device=num_sents.device).expand(
        len(num_sents), self.max_num_sents) < num_sents.unsqueeze(1)
    mask_sent = utils.cast_type(mask_sent, BOOL, self.use_gpu)
    num_sents = utils.cast_type(num_sents, FLOAT, self.use_gpu)
    encoded_sents = avg_sents_repr

    #### stage2: averaging
    ilc_vec = torch.div(torch.sum(encoded_sents, dim=1), num_sents.unsqueeze(1))

    #### FC layer: three linear layers, fc1 + (leaky_relu + dropout) + fc2 + (leaky_relu + dropout) + fc3 + sigmoid
    fc_out = self.linear_1(ilc_vec)
    fc_out = self.leak_relu(fc_out)
    fc_out = self.dropout_layer(fc_out)

    fc_out = self.linear_2(fc_out)
    fc_out = self.leak_relu(fc_out)
    fc_out = self.dropout_layer(fc_out)

    fc_out = self.linear_out(fc_out)

    if self.output_size == 1:
        fc_out = self.sigmoid(fc_out)

    outputs = []
    outputs.append(fc_out)

    # Multi-task addition: score each pair of adjacent sentences; assuming the
    # target score for every pair is 1, the objective is the MSE between the
    # summed score and (num_sents - 1).
    order_score = torch.zeros(batch_size)
    order_score = utils.cast_type(order_score, FLOAT, self.use_gpu)
    for batch_i in range(batch_size):
        batch_score = 0.0
        for sent_i in range(int(num_sents[batch_i].item()) - 1):
            sents_repr_concated = torch.cat((avg_sents_repr[batch_i, sent_i],
                                             avg_sents_repr[batch_i, sent_i + 1]), 0)
            fc_order_out = self.linear_order_1(sents_repr_concated)
            fc_order_out = self.leak_relu(fc_order_out)
            fc_order_out = self.dropout_layer(fc_order_out)
            fc_order_out = self.linear_order_2(fc_order_out)
            fc_order_out = self.leak_relu(fc_order_out)
            fc_order_out = self.dropout_layer(fc_order_out)
            fc_order_out = self.linear_order_out(fc_order_out)
            batch_score += self.sigmoid(fc_order_out)
        order_score[batch_i] = batch_score

    outputs.append(order_score)
    outputs.append(num_sents)

    return outputs

def forward(self, text_inputs, mask_input, len_seq, len_sents, tid,
            len_para=None, list_rels=None, mode=""):
    # mask_input: (batch, max_tokens), len_sents: (batch, max_num_sents)
    batch_size = text_inputs.size(0)

    mask_sent = torch.sign(len_sents)  # (batch_size, len_sent)
    mask_sent = utils.cast_type(mask_sent, FLOAT, self.use_gpu)
    num_sents = mask_sent.sum(dim=1)  # (batch_size)

    #### Stage1 and 2: sentence repr and discourse segments parser
    adj_mat, sent_repr, batch_adj_list, batch_root_list, batch_segMap, batch_cp_ind = self.centering_attn(
        text_inputs, mask_input, len_sents, num_sents, tid)

    # #### doc-level encoding of the input text (disable this part if GPU memory is not enough)
    # encoder_doc_out = self.base_encoder(text_inputs, mask_input, len_seq)
    # encoded_doc = encoder_doc_out[0]
    # if self.output_attentions:
    #     attn_doc_avg = encoder_doc_out[1]  # averaged mh attentions (batch, item, item)
    # mask_sent = torch.sign(len_sents)  # (batch_size, len_sent)
    # mask_sent = utils.cast_type(mask_sent, FLOAT, self.use_gpu)
    # num_sents = mask_sent.sum(dim=1)  # (batch_size)
    # sent_repr = self.sent_repr_avg(batch_size, encoded_doc, len_sents)

    #### Stage3: Structure-aware transformer
    mask_sent_tr = torch.arange(self.max_num_sents, device=num_sents.device).expand(
        len(num_sents), self.max_num_sents) < num_sents.unsqueeze(1)
    mask_sent_tr = utils.cast_type(mask_sent_tr, BOOL, self.use_gpu)
    encoded_sents, break_probs = self.tt_encoder(
        sent_repr, mask_sent_tr, adj_mat)  # ['features'], ['node_order'], ['adjacency_list'], ['edge_order']

    #### Stage4: Document attention
    context_weight = self.context_weight.expand(encoded_sents.shape[0],
                                                encoded_sents.shape[2], 1)
    attn_weight = torch.bmm(encoded_sents, context_weight).squeeze(2)
    attn_weight = self.tanh(attn_weight)
    attn_weight = masked_softmax(attn_weight, mask_sent)
    attn_vec = torch.bmm(encoded_sents.transpose(1, 2), attn_weight.unsqueeze(2))
    ilc_vec = attn_vec.squeeze(2)

    #### FC layer
    fc_out = self.linear_1(ilc_vec)
    fc_out = self.leak_relu(fc_out)
    fc_out = self.dropout_layer(fc_out)

    fc_out = self.linear_2(fc_out)
    fc_out = self.leak_relu(fc_out)
    fc_out = self.dropout_layer(fc_out)

    fc_out = self.linear_out(fc_out)

    if self.output_size == 1:
        fc_out = self.sigmoid(fc_out)

    # prepare output to return
    outputs = []
    if self.gen_logs:
        outputs.append(batch_adj_list)
        outputs.append(batch_root_list)
        outputs.append(batch_segMap)
        outputs.append(batch_cp_ind)
        outputs.append(num_sents.tolist())

    ### Sentence Ordering Task: list_score (BCELoss)
    ### avg_sents_repr: [batch_size, sent_num]
    ## In predict mode, the shuffle step would be skipped and the sentence-order
    ## score computed directly (earlier variant, kept for reference):
    # def Gaussian_prob(data, avg, sig):
    #     data = data.cpu()
    #     sqrt_2pi = np.power(2 * np.pi, 0.5)
    #     coef = 1 / (sqrt_2pi * sig)
    #     powercoef = -1 / (2 * np.power(sig, 2))
    #     mypow = powercoef * (np.power((data - avg), 2))
    #     return coef * (np.exp(mypow)).cuda()
    #
    # if mode == "predict":
    #     batch_order_score = []
    #     for batch_i in range(batch_size):
    #         order_score = []
    #         for sent_i in range(int(num_sents[batch_i].item()) - 1):
    #             sent_embed_1 = encoded_sents[batch_i, sent_i]
    #             sent_embed_2 = encoded_sents[batch_i, sent_i + 1]
    #             so_fc_out = self.so_linear_1(torch.cat((sent_embed_1, sent_embed_2), dim=0))
    #             so_fc_out = self.so_linear_2(so_fc_out)
    #             so_fc_out = self.so_linear_out(so_fc_out)
    #             so_fc_out = self.sigmoid(so_fc_out)
    #             order_score.append(so_fc_out)
    #         if len(order_score) != 0:
    #             sent_order_score = torch.mean(torch.stack(order_score), dim=0)
    #         else:
    #             sent_order_score = torch.Tensor([0.]).cuda()
    #         class_score = torch.cat((Gaussian_prob(sent_order_score, 0, 0.4),
    #                                  torch.cat((Gaussian_prob(sent_order_score, 0.5, 0.4),
    #                                             Gaussian_prob(sent_order_score, 1, 0.4)), dim=0)), dim=0)
    #         batch_order_score.append(class_score)
    #     batch_order_score = torch.stack(batch_order_score, dim=0)
    #     fc_out += batch_order_score
    #     outputs.append(fc_out)
    #     return outputs

    ## training pass
    order_label_list = []
    order_score_list = []
    for batch_i in range(batch_size):
        order_label = []
        order_score = []
        shuffled_sents = torch.randperm(int(num_sents[batch_i].item()))
        for sent_i in range(shuffled_sents.shape[0] - 1):
            sent_embed_1 = encoded_sents[batch_i, shuffled_sents[sent_i].item()]
            sent_embed_2 = encoded_sents[batch_i, shuffled_sents[sent_i + 1].item()]
            if shuffled_sents[sent_i].item() < shuffled_sents[sent_i + 1].item():
                order_label.append(1.0)
            else:
                order_label.append(0.0)

            so_fc_out = self.so_linear_1(torch.cat((sent_embed_1, sent_embed_2), dim=0))
            # so_fc_out = self.leak_relu(so_fc_out)
            # so_fc_out = self.dropout_layer(so_fc_out)
            so_fc_out = self.so_linear_2(so_fc_out)
            # so_fc_out = self.leak_relu(so_fc_out)
            # so_fc_out = self.dropout_layer(so_fc_out)
            so_fc_out = self.so_linear_out(so_fc_out)
            so_fc_out = self.sigmoid(so_fc_out)
            order_score.append(so_fc_out)

        order_label_list.append(order_label)
        order_score_list.append(order_score)

    if len(order_label_list) != len(order_score_list):
        order_label_list = []
        order_score_list = []

    outputs.append(fc_out)
    outputs.append(order_label_list)
    outputs.append(order_score_list)

    # return fc_out
    return outputs
# end def forward
# end class

def __init__(self, config, corpus_target):
    logger.info('Loading embeddings from: ' + config.path_pretrained_emb)

    # init parameters
    self.emb_path = config.path_pretrained_emb
    self.embed_size = config.embed_size
    if config.encoder_type == "transf":
        self.embed_size = config.d_model
    self.embedding_dim = None
    self.embeddings = {}
    self.emb_matrix = None
    self.x_embed = None  # the embedding layer to be returned
    self.tokenizer = corpus_target.tokenizer

    # load embedding
    self.pad_id = None
    if self.emb_path.startswith("bert-"):
        # BERT returns the embedding layer itself, not a matrix as other pretrained embeddings do
        self.x_embed = self.load_pretrained_bert()  # from the bert library
        self.embedding_dim = 768
        # self.pad_id = 0
        self.pad_id = self.tokenizer.sp_model.piece_to_id("<pad>")
    elif self.emb_path.startswith("xlnet-"):
        self.x_embed = None  # xlnet does not use an additional pretrained embedding class
        # self.embedding_dim = 768
        # if "large" in self.emb_path:
        #     self.embedding_dim = 1024
        self.embedding_dim = 0  # deprecated, should be cleaned up later (2020.03.14)
        self.pad_id = self.tokenizer.sp_model.piece_to_id("<pad>")
    elif self.emb_path.startswith("t5-"):
        self.x_embed = None  # t5 does not use an additional pretrained embedding class
        self.embedding_dim = 1024
        self.pad_id = self.tokenizer.sp_model.piece_to_id("<pad>")
    elif self.emb_path.startswith("bart-"):
        self.x_embed = None  # bart does not use an additional pretrained embedding class
        self.embedding_dim = 1024
        self.pad_id = self.tokenizer.sp_model.piece_to_id("<pad>")
    else:
        # manual version
        self.vocab_size = len(corpus_target.vocab)
        self.vocab = corpus_target.vocab  # word2id
        self.rev_vocab = corpus_target.rev_vocab  # id2word
        # self.pad_id = self.rev_vocab[PAD]
        self.pad_id = corpus_target.pad_id
        self.num_special_vocab = corpus_target.num_special_vocab

        if not self.emb_path.lower().startswith("none") and len(self.emb_path) > 1:
            self.load_pretrained_file()  # from a pretrained file

        self.emb_matrix = np.zeros((len(self.vocab), self.embed_size))
        self.get_emb_matrix_given_vocab()  # assign emb_matrix -> (len(vocab), embed_size)

        self.x_embed = nn.Embedding(self.vocab_size, self.embed_size,
                                    padding_idx=self.pad_id)
        self.x_embed = self.x_embed.from_pretrained(torch.FloatTensor(self.emb_matrix))
        self.x_embed.weight.data[self.pad_id] = 0.0  # zero padding
        # padding_idx disappears when we use "from_pretrained" in pytorch 1.0 (bug?)
        self.x_embed.padding_idx = self.pad_id
        self.embedding_dim = self.x_embed.embedding_dim

    if config.use_gpu and self.x_embed is not None:
        self.x_embed = utils.cast_type(self.x_embed, FLOAT, config.use_gpu)

    return

def np2var(self, inputs, dtype):
    if inputs is None:
        return None
    return utils.cast_type(Variable(torch.from_numpy(inputs)), dtype,
                           self.use_gpu)

def forward(self, text_inputs, mask_input, len_seq, len_sents, tid,
            len_para=None, mode=""):
    batch_size = text_inputs.size(0)

    # if self.pad_level == "sent" or self.pad_level == "sentence":
    text_inputs = text_inputs.view(batch_size,
                                   text_inputs.size(1) * text_inputs.size(2))

    #### word level encoding
    encoder_out = self.base_encoder(text_inputs, mask_input, len_seq)

    #### sentence representations
    sent_mask = torch.sign(len_sents)  # (batch_size, len_sent)
    num_sents = sent_mask.sum(dim=1)  # (batch_size)

    sent_repr = torch.zeros(batch_size, self.max_num_sents,
                            self.base_encoder.encoder_out_size)
    sent_repr = utils.cast_type(sent_repr, FLOAT, self.use_gpu)
    for cur_ind_doc in range(batch_size):
        list_sent_len = len_sents[cur_ind_doc]
        cur_sent_num = int(num_sents[cur_ind_doc])
        cur_loc_sent = 0
        list_cur_doc_sents = []

        for cur_ind_sent in range(cur_sent_num):
            cur_sent_len = int(list_sent_len[cur_ind_sent])
            cur_sent_repr = torch.div(
                torch.sum(encoder_out[cur_ind_doc,
                                      cur_loc_sent:cur_loc_sent + cur_sent_len],
                          dim=0),
                cur_sent_len)  # avg version
            cur_sent_repr = cur_sent_repr.view(1, 1, -1)  # restore to (1, 1, rnn_cell_size)

            list_cur_doc_sents.append(cur_sent_repr)
            cur_loc_sent = cur_loc_sent + cur_sent_len
        # end for cur_len_sent

        cur_sents_repr = torch.stack(list_cur_doc_sents, dim=1)  # (1, num_sents, 1, rnn_cell_size)
        cur_sents_repr = cur_sents_repr.squeeze(2)

        sent_repr[cur_ind_doc, :cur_sent_num, :] = cur_sents_repr
    # end for cur_doc

    # encode sentences
    mask_sent = torch.arange(self.max_num_sents, device=num_sents.device).expand(
        len(num_sents), self.max_num_sents) < num_sents.unsqueeze(1)
    mask_sent = utils.cast_type(mask_sent, FLOAT, self.use_gpu)
    num_sents = utils.cast_type(num_sents, FLOAT, self.use_gpu)

    # get the centroid by averaging
    ilc_vec_sent = torch.div(torch.sum(sent_repr, dim=1),
                             num_sents.unsqueeze(1))  # (batch_size, rnn_cell_size)
    sim_avg_sent = self.sim_cosine(ilc_vec_sent.unsqueeze(1), sent_repr)
    sim_avg_sent = sim_avg_sent * mask_sent

    # get the distance vector
    avg_pooled_sent = torch.zeros(sent_repr.shape[0], self.size_avg_pool_sent)
    avg_pooled_sent = utils.cast_type(avg_pooled_sent, FLOAT, self.use_gpu)
    for cur_batch, cur_tensor in enumerate(sim_avg_sent):
        cur_seq_len = int(num_sents[cur_batch])
        cur_tensor = cur_tensor.unsqueeze(0)
        crop_tensor = cur_tensor.narrow(1, 0, cur_seq_len)
        crop_tensor = crop_tensor.unsqueeze(1)

        sim_conv = self.conv_sent(crop_tensor)  # this part should be tested at the sentence level later
        sim_conv = self.leak_relu(sim_conv)
        sim_conv = self.dropout_01(sim_conv)
        # sim_conv = self.dropout_02(sim_conv)

        cur_avg_pooled = self.max_adapt_pool1_sent(sim_conv)
        avg_pooled_sent[cur_batch, :] = cur_avg_pooled

    #### FC layer
    ilc_vec = torch.cat((ilc_vec_sent, avg_pooled_sent), dim=1)  # concat the centroid and similarity vector

    fc_out = self.linear_1(ilc_vec)
    fc_out = self.leak_relu(fc_out)
    fc_out = self.dropout_layer(fc_out)

    fc_out = self.linear_2(fc_out)
    fc_out = self.leak_relu(fc_out)
    fc_out = self.dropout_layer(fc_out)

    fc_out = self.linear_out(fc_out)

    if self.output_size == 1:
        fc_out = self.sigmoid(fc_out)

    return fc_out

def forward(self, text_inputs, mask_input, len_seq, len_sents, tid,
            len_para=None, list_rels=None, mode=""):
    # mask_input: (batch, max_tokens), len_sents: (batch, max_num_sents)
    batch_size = text_inputs.size(0)

    mask_sent = torch.sign(len_sents)  # (batch_size, len_sent)
    mask_sent = utils.cast_type(mask_sent, FLOAT, self.use_gpu)
    num_sents = mask_sent.sum(dim=1)  # (batch_size)

    #### Stage1 and 2: sentence repr and discourse segments parser
    adj_mat, sent_repr, batch_adj_list, batch_root_list, batch_segMap, batch_cp_ind = self.centering_attn(
        text_inputs, mask_input, len_sents, num_sents, tid)

    # #### doc-level encoding of the input text (disable this part if GPU memory is not enough)
    # encoder_doc_out = self.base_encoder(text_inputs, mask_input, len_seq)
    # encoded_doc = encoder_doc_out[0]
    # if self.output_attentions:
    #     attn_doc_avg = encoder_doc_out[1]  # averaged mh attentions (batch, item, item)
    # mask_sent = torch.sign(len_sents)  # (batch_size, len_sent)
    # mask_sent = utils.cast_type(mask_sent, FLOAT, self.use_gpu)
    # num_sents = mask_sent.sum(dim=1)  # (batch_size)
    # sent_repr = self.sent_repr_avg(batch_size, encoded_doc, len_sents)

    #### Stage3: Structure-aware transformer
    mask_sent_tr = torch.arange(self.max_num_sents, device=num_sents.device).expand(
        len(num_sents), self.max_num_sents) < num_sents.unsqueeze(1)
    mask_sent_tr = utils.cast_type(mask_sent_tr, BOOL, self.use_gpu)
    encoded_sents, break_probs = self.tt_encoder(
        sent_repr, mask_sent_tr, adj_mat)  # ['features'], ['node_order'], ['adjacency_list'], ['edge_order']

    #### Stage4: Document attention
    context_weight = self.context_weight.expand(encoded_sents.shape[0],
                                                encoded_sents.shape[2], 1)
    attn_weight = torch.bmm(encoded_sents, context_weight).squeeze(2)
    attn_weight = self.tanh(attn_weight)
    attn_weight = masked_softmax(attn_weight, mask_sent)
    attn_vec = torch.bmm(encoded_sents.transpose(1, 2), attn_weight.unsqueeze(2))
    ilc_vec = attn_vec.squeeze(2)

    #### FC layer
    fc_out = self.linear_1(ilc_vec)
    fc_out = self.leak_relu(fc_out)
    fc_out = self.dropout_layer(fc_out)

    fc_out = self.linear_2(fc_out)
    fc_out = self.leak_relu(fc_out)
    fc_out = self.dropout_layer(fc_out)

    fc_out = self.linear_out(fc_out)

    if self.output_size == 1:
        fc_out = self.sigmoid(fc_out)

    # prepare output to return
    outputs = []
    outputs.append(fc_out)
    if self.gen_logs:
        outputs.append(batch_adj_list)
        outputs.append(batch_root_list)
        outputs.append(batch_segMap)
        outputs.append(batch_cp_ind)
        outputs.append(num_sents.tolist())

    # return fc_out
    return outputs
# end def forward
# end class

def init_hidden_layers(self, batch_size, num_layers, rnn_cell_size):
    hid = torch.autograd.Variable(torch.zeros(num_layers, batch_size, rnn_cell_size))
    hid = utils.cast_type(hid, FLOAT, self.use_gpu)
    nn.init.xavier_uniform_(hid)
    return hid

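# --- Illustration (not part of the model) ---------------------------------
# Minimal usage sketch for an initial hidden state of the shape produced by
# init_hidden_layers (num_layers x batch x hidden), fed to a GRU. Xavier init
# on the initial state is the choice made above; zeros are PyTorch's default.
# The GRU dimensions below are made up.
import torch
import torch.nn as nn

num_layers, batch_size, hidden = 2, 3, 8
gru = nn.GRU(input_size=5, hidden_size=hidden, num_layers=num_layers,
             batch_first=True)
h0 = torch.empty(num_layers, batch_size, hidden)
nn.init.xavier_uniform_(h0)

out, h_n = gru(torch.randn(batch_size, 7, 5), h0)
print(out.shape, h_n.shape)  # torch.Size([3, 7, 8]) torch.Size([2, 3, 8])
# ---------------------------------------------------------------------------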