import random
from collections import OrderedDict
from typing import List

import numpy as np
import torch
from allennlp.modules import FeedForward

# Project-local helpers used below (their import paths depend on the surrounding
# repo and are not reproduced here): get_device, two_dim_index_select,
# EncCompression, NewAttention, RougeStrEvaluation, easy_post_processing,
# CompExecutor.


class CompressDecoder(torch.nn.Module):
    def __init__(self, context_dim,
                 dec_state_dim,
                 enc_hid_dim,
                 text_field_embedder,
                 aggressive_compression: int = -1,
                 keep_threshold: float = 0.5,
                 abs_board_file="/home/cc/exComp/board.txt",
                 gather='mean',
                 dropout=0.5,
                 dropout_emb=0.2,
                 valid_tmp_path='/scratch/cluster/jcxu/exComp',
                 serilization_name: str = "",
                 vocab=None,
                 elmo: bool = False,
                 elmo_weight: str = "elmo_2x1024_128_2048cnn_1xhighway_weights.hdf5"):
        super().__init__()
        self.use_elmo = elmo
        self.serilization_name = serilization_name
        if elmo:
            from allennlp.modules.elmo import Elmo, batch_to_ids
            from allennlp.modules.seq2seq_encoders import Seq2SeqEncoder, PytorchSeq2SeqWrapper
            self.vocab = vocab
            options_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x1024_128_2048cnn_1xhighway/elmo_2x1024_128_2048cnn_1xhighway_options.json"
            weight_file = elmo_weight
            self.elmo = Elmo(options_file, weight_file, 1, dropout=dropout_emb)
            # print(self.elmo.get_output_dim())
            # self.word_emb_dim = text_field_embedder.get_output_dim()
            # self._context_layer = PytorchSeq2SeqWrapper(
            #     torch.nn.LSTM(self.word_emb_dim + self.elmo.get_output_dim(), self.word_emb_dim,
            #                   batch_first=True, bidirectional=True))
            self.word_emb_dim = self.elmo.get_output_dim()
        else:
            self._text_field_embedder = text_field_embedder
            self.word_emb_dim = text_field_embedder.get_output_dim()
        self.XEloss = torch.nn.CrossEntropyLoss(reduction='none')
        self.device = get_device()
        # self.rouge_metrics_compression = RougeStrEvaluation(name='cp', path_to_valid=valid_tmp_path,
        #                                                     writting_address=valid_tmp_path,
        #                                                     serilization_name=serilization_name)
        # self.rouge_metrics_compression_best_possible = RougeStrEvaluation(name='cp_ub', path_to_valid=valid_tmp_path,
        #                                                                   writting_address=valid_tmp_path,
        #                                                                   serilization_name=serilization_name)
        self.enc = EncCompression(inp_dim=self.word_emb_dim, hid_dim=enc_hid_dim, gather=gather)  # TODO dropout
        self.aggressive_compression = aggressive_compression
        self.relu = torch.nn.ReLU()
        self.attn = NewAttention(enc_dim=self.enc.get_output_dim(),
                                 dec_dim=self.enc.get_output_dim_unit() * 2 + dec_state_dim)
        self.concat_size = self.enc.get_output_dim() + self.enc.get_output_dim_unit() * 2 + dec_state_dim
        self.valid_tmp_path = valid_tmp_path
        if self.aggressive_compression < 0:
            self.XELoss = torch.nn.CrossEntropyLoss(reduction='none', ignore_index=-1)
            # self.nn_lin = torch.nn.Linear(self.concat_size, self.concat_size)
            # self.nn_lin2 = torch.nn.Linear(self.concat_size, 2)
            self.ff = FeedForward(input_dim=self.concat_size, num_layers=3,
                                  hidden_dims=[self.concat_size, self.concat_size, 2],
                                  activations=[torch.nn.Tanh(), torch.nn.Tanh(), lambda x: x],
                                  dropout=dropout)
            # Keep thresholds for decoding.
            # self.keep_thres = list(np.arange(start=0.2, stop=0.6, step=0.075))
            self.keep_thres = [0.0, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5,
                               0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 1.0]
            self.rouge_metrics_compression_dict = OrderedDict()
            for thres in self.keep_thres:
                self.rouge_metrics_compression_dict["{}".format(thres)] = RougeStrEvaluation(
                    name='cp_{}'.format(thres),
                    path_to_valid=valid_tmp_path,
                    writting_address=valid_tmp_path,
                    serilization_name=serilization_name)

    def encode_sent_and_span_paral(self, text,  # batch, max_sent, max_word
                                   text_msk,    # batch, max_sent, max_word
                                   span,        # batch, max_sent_num, max_span_num, max_word
                                   sent_idx     # batch size
                                   ):
        # text is the original text of the selected sentence.
        this_text = two_dim_index_select(text['tokens'], sent_idx)  # batch, max_word
        from allennlp.modules.elmo import batch_to_ids
        if self.use_elmo:
            this_text_list: List = this_text.tolist()
            text_str_list = []
            for sample in this_text_list:
                s = [self.vocab.get_token_from_index(x) for x in sample]
                text_str_list.append(s)
            character_ids = batch_to_ids(text_str_list).to(self.device)
            this_context = self.elmo(character_ids)
            # print(this_context['elmo_representations'][0].size())
            this_context = this_context['elmo_representations'][0]
        else:
            this_text = {'tokens': this_text}
            this_context = self._text_field_embedder(this_text)
        num_doc, max_word, inp_dim = this_context.size()
        batch_size = sent_idx.size()[0]
        assert batch_size == num_doc
        # this_context = two_dim_index_select(context, sent_idx)  # batch, max_word, hdim
        this_context_mask = two_dim_index_select(text_msk, sent_idx)  # batch, max_word
        this_span = two_dim_index_select(span, sent_idx)  # batch, nspan, max_word
        concat_rep_of_compression, span_msk, original_sent_rep = self.enc.forward(
            word_emb=this_context, word_emb_msk=this_context_mask, span=this_span)
        return concat_rep_of_compression, span_msk, original_sent_rep

    def encode_sent_and_span(self, text, text_msk, span, batch_idx, sent_idx):
        context = self._text_field_embedder(text)
        num_doc, max_sent, max_word, inp_dim = context.size()
        num_doc_, max_sent_, nspan = span.size()[0:-1]
        assert num_doc == num_doc_
        assert max_sent == max_sent_
        this_context = context[batch_idx, sent_idx, :, :].unsqueeze(0)
        this_span = span[batch_idx, sent_idx, :, :].unsqueeze(0)
        this_context_mask = text_msk[batch_idx, sent_idx, :].unsqueeze(0)
        flattened_enc, attn_dist, spans_rep, span_msk, score = self.enc.forward(
            word_emb=this_context, word_emb_msk=this_context_mask, span=this_span)
        # (1, hid*2), (1, span_num, hid), (1, span_num)
        return flattened_enc, spans_rep, span_msk

    def indep_compression_judger(self, reps):
        # reps: (t, batch_size, max_span_num, self.concat_size)
        timestep, batch_size, max_span_num, dim = reps.size()
        score = self.ff.forward(reps)
        # lin_out = self.nn_lin(reps)
        # activated = torch.sigmoid(lin_out)
        # score = self.nn_lin2(activated)
        if random.random() < 0.005:
            print("score: {}".format(score[0]))
        return score

    def get_out_dim(self):
        return self.concat_size

    def forward_parallel(self, sent_decoder_states,      # t, batch, hdim
                         sent_decoder_outputs_logit,     # t, batch
                         document_rep,                   # batch, hdim
                         text,                           # batch, max_sent, max_word
                         text_msk,                       # batch, max_sent, max_word
                         span):                          # batch, max_sent_num, max_span_num, max_word
        # Encode compression options given sentence emission; output scores, attn dist, ...
        t, batch_size, hdim = sent_decoder_states.size()
        t_, batch_size_ = sent_decoder_outputs_logit.size()  # invalid bits are -1
        batch, max_sent, max_span_num, max_word = span.size()
        # assert t == t_
        t = min(t, t_)
        assert batch_size == batch == batch_size_
        if self.aggressive_compression > 0:
            all_attn_dist = torch.zeros((t, batch_size, max_span_num)).to(self.device)
            all_scores = torch.ones((t, batch_size, max_span_num)).to(self.device) * -100
        else:
            all_attn_dist = None
            all_scores = None
        all_reps = torch.zeros((t, batch_size_, max_span_num, self.concat_size), device=self.device)
        for timestep in range(t):
            dec_state = sent_decoder_states[timestep]     # batch, dim
            logit = sent_decoder_outputs_logit[timestep]  # batch
            # valid_mask = (logit > 0)
            positive_logit = self.relu(logit.float()).long()  # turn -1 to 0
            span_t, span_msk_t, sent_t = self.encode_sent_and_span_paral(text=text,
                                                                         text_msk=text_msk,
                                                                         span=span,
                                                                         sent_idx=positive_logit)
            # sent_t: batch, sent_dim
            # span_t: batch, span_num, span_dim
            # span_msk_t: batch, span_num, e.g. [[1, 1, 1, 0, 0, 0], ...]
            concated_rep_high_level = torch.cat([dec_state, document_rep, sent_t], dim=1)  # batch, DIM
            if self.aggressive_compression > 0:
                attn_dist, score = self.attn.forward_one_step(enc_state=span_t,
                                                              dec_state=concated_rep_high_level,
                                                              enc_mask=span_msk_t.float())
                # attn_dist: batch, span num
                # score: batch, span num
            # concated_rep: batch, dim ==> batch, 1, dim ==> batch, max_span_num, dim
            expanded_concated_rep = concated_rep_high_level.unsqueeze(1).expand((batch, max_span_num, -1))
            all_reps[timestep, :, :, :] = torch.cat([expanded_concated_rep, span_t], dim=2)
            if self.aggressive_compression > 0:
                all_attn_dist[timestep, :, :] = attn_dist
                all_scores[timestep, :, :] = score
        return all_attn_dist, all_scores, all_reps

    def comp_loss_inf_deletion(self, decoder_outputs_logit,  # gold label for sentence emission
                               # span_seq_label,  # batch, max sent num
                               span_rouge,  # batch, max sent num, max compression num
                               scores,
                               comp_rouge_ratio,
                               loss_thres=1):
        """
        :param decoder_outputs_logit: [timestep, batch]
        :param span_rouge: [batch, max_sent, max_compression]
        :param scores: [timestep, batch, max_compression, 2]
        :param comp_rouge_ratio: [batch_size, max_sent, max_compression]
        :return: scalar loss
        """
        tim, bat = decoder_outputs_logit.size()
        time, batch, max_span, _ = scores.size()
        batch_, sent_len, max_sp = span_rouge.size()
        assert batch_ == batch == bat
        assert time == tim
        assert max_sp == max_span
        goal_rouge_label = torch.ones((tim, batch, max_span), device=self.device, dtype=torch.long) * (-1)
        weights = torch.ones((tim, batch, max_span), device=self.device, dtype=torch.float)
        decoder_outputs_logit_mask = (decoder_outputs_logit >= 0).unsqueeze(2).expand(
            (time, batch, max_span)).float().view(-1)
        decoder_outputs_logit = torch.nn.functional.relu(decoder_outputs_logit).long()
        z = torch.zeros(1, device=self.device)
        for tt in range(tim):
            decoder_outputs_logit_t = decoder_outputs_logit[tt]
            out = two_dim_index_select(inp=comp_rouge_ratio, index=decoder_outputs_logit_t)
            label = torch.gt(out, loss_thres).long()
            mini_mask = torch.gt(out, 0.01).float()
            # baseline_mask = 1 - torch.lt(torch.abs(out - 0.99), 0.01).float()  # baseline will be 0
            # weight = torch.max(input=-out + 0.5, other=z) + 1
            # weights[tt] = mini_mask * baseline_mask
            weights[tt] = mini_mask
            goal_rouge_label[tt] = label
        probs = scores.view(-1, 2)
        goal_rouge_label = goal_rouge_label.view(-1)
        weights = weights.view(-1)
        loss = self.XELoss(input=probs, target=goal_rouge_label)
        loss = loss * decoder_outputs_logit_mask * weights
        return torch.mean(loss)

    def comp_loss(self, decoder_outputs_logit,  # gold label for sentence emission
                  scores,
                  span_seq_label,  # batch, max sent num
                  span_rouge,  # batch, max sent num, max compression num
                  comp_rouge_ratio):
        t, batch = decoder_outputs_logit.size()
        t_, batch_, comp_num = scores.size()
        b, max_sent = span_seq_label.size()
        # b_, max_sen, max_comp_, _ = span.size()
        _b, max_sent_, max_comp = span_rouge.size()
        assert batch == batch_ == b == _b
        assert max_sent_ == max_sent
        assert comp_num == max_comp
        span_seq_label = span_seq_label.long()
        total_loss = torch.zeros((t, b)).to(self.device)
        # print(decoder_outputs_logit)
        # print(span_seq_label)
        for timestep in range(t):
            for batch_idx in range(b):
                # decoder_outputs_logit holds the gold sentence index for
                # emission; if it is negative, we skip supervision.
                logit = decoder_outputs_logit[timestep][batch_idx]  # this is the sent idx
                if logit < 0:
                    continue
                ref_rouge_score = comp_rouge_ratio[batch_idx][logit]
                num_of_compression = ref_rouge_score.size()[0]
                _supervision_label_msk = (ref_rouge_score > 0.98).float()
                label = torch.from_numpy(np.arange(num_of_compression)).to(self.device).long()
                score_t = scores[timestep][batch_idx].unsqueeze(0)  # comp num
                score_t = score_t.expand(num_of_compression, -1)
                # label = span_seq_label[batch_idx][logit].unsqueeze(0)
                loss = self.XEloss(score_t, label)
                # print(loss)
                loss = _supervision_label_msk * loss
                total_loss[timestep][batch_idx] = torch.sum(loss)
        # sent_msk_t = two_dim_index_select(sent_mask, logit)
        return torch.mean(total_loss)

    def _dec_compression_one_step(self, predict_compression, sp_meta, word_sent: List[str],
                                  keep_threshold: List[float],
                                  context: List[List[str]] = None):
        full_set_len = set(range(len(word_sent)))
        # max_comp, _ = predict_compression.size()
        preds = [full_set_len.copy() for _ in range(len(keep_threshold))]
        # Show all of the compression spans.
        stat_compression = {}
        for comp_idx, comp_meta in enumerate(sp_meta):
            p = predict_compression[comp_idx][1]
            node_type, sel_idx, rouge, ratio = comp_meta
            if node_type != "BASELINE":
                selected_words = [x for idx, x in enumerate(word_sent) if idx in sel_idx]
                selected_words_str = "_".join(selected_words)
                stat_compression["{}".format(selected_words_str)] = {
                    "prob": float("{0:.2f}".format(p)),
                    "type": node_type,
                    "rouge": float("{0:.2f}".format(rouge)),
                    "ratio": float("{0:.2f}".format(ratio)),
                    "sel_idx": sel_idx,
                    "len": len(sel_idx)
                }
        stat_compression_order = OrderedDict(
            sorted(stat_compression.items(), key=lambda item: item[1]["prob"], reverse=True))  # Python 3
        for idx, _keep_thres in enumerate(keep_threshold):
            history: List[str] = context[idx]
            his_set = set((" ".join(history)).split(" "))
            for key, value in stat_compression_order.items():
                p = value['prob']
                sel_idx = value['sel_idx']
                sel_txt = set([word_sent[x] for x in sel_idx])
                if sel_txt - his_set == set():
                    # Every word in this span already appears in the context,
                    # so delete it regardless of its probability.
                    # print("Context: {}\tCandidate: {}".format(his_set, sel_txt))
                    preds[idx] = preds[idx] - set(value['sel_idx'])
                    continue
                if p > _keep_thres:
                    preds[idx] = preds[idx] - set(value['sel_idx'])
        preds = [list(x) for x in preds]
        for pred in preds:
            pred.sort()
        # Visual output
        visual_outputs: List[str] = []
        words_for_evaluation: List[str] = []
        meta_keep_ratio_word = []
        for idx, compression in enumerate(preds):
            output = [word_sent[jdx] if (jdx in compression) else '_' + word_sent[jdx] + '_'
                      for jdx in range(len(word_sent))]
            visual_outputs.append(" ".join(output))
            words = [word_sent[x] for x in compression]
            meta_keep_ratio_word.append(float(len(words) / len(word_sent)))
            # meta_kepp_ratio_span.append(1 - float(len(survery['type'][idx]) / len(sp_meta)))
            words = " ".join(words)
            words = easy_post_processing(words)
            # print(words)
            words_for_evaluation.append(words)
        d: List[List] = []
        for kep_th, vis, words_eva, keep_word_ratio in zip(keep_threshold, visual_outputs,
                                                           words_for_evaluation, meta_keep_ratio_word):
            d.append([kep_th, vis, words_eva, keep_word_ratio])
        return stat_compression_order, d

    def decode_inf_deletion(self, sent_decoder_outputs_logit,  # time, batch
                            span_prob,  # time, batch, max_comp, 2
                            metadata: List,
                            span_meta: List,
                            span_rouge,  # batch, sent, max_comp
                            keep_threshold: List[float]):
        batch_size, max_sent_num, max_comp_num = span_rouge.size()
        t, batsz, max_comp, _ = span_prob.size()
        span_score = torch.nn.functional.softmax(span_prob, dim=3).cpu().numpy()
        timestep, batch = sent_decoder_outputs_logit.size()
        sent_decoder_outputs_logit = sent_decoder_outputs_logit.cpu().data
        for idx, m in enumerate(metadata):
            abs_s = [" ".join(s) for s in m["abs_list"]]
            comp_exe = CompExecutor(span_meta=span_meta[idx],
                                    sent_idxs=sent_decoder_outputs_logit[:, idx],
                                    prediction_score=span_score[:, idx, :, :],
                                    abs_str=abs_s,
                                    name=m['name'],
                                    doc_list=m["doc_list"],
                                    keep_threshold=keep_threshold,
                                    part=m['name'],
                                    ser_dir=self.valid_tmp_path,
                                    ser_fname=self.serilization_name)
            # processed_words, del_record, compressions, full_sents, ...
            bag_pred_eval = comp_exe.run()
            full_sents: List[List[str]] = comp_exe.full_sents
            # Assemble full sentences.
            full_sents = [" ".join(x) for x in full_sents]
            # visual to console
            for jdx in range(len(keep_threshold)):
                self.rouge_metrics_compression_dict["{}".format(keep_threshold[jdx])](
                    pred=bag_pred_eval[jdx], ref=[abs_s], origin=full_sents)
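# CompressDecoder leans on the repo's `two_dim_index_select` helper, whose
# implementation is not shown in this file. Below is a minimal sketch of the
# behavior the calls above assume (a hypothetical stand-in, not the repo's
# actual code): for each batch element b, pick row index[b] of inp[b], e.g.
# selecting one sentence per document from a (batch, max_sent, max_word) tensor.
def _two_dim_index_select_sketch(inp: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
    # inp: (batch, max_sent, ...); index: (batch,) -> returns (batch, ...)
    return inp[torch.arange(inp.size(0), device=inp.device), index]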
from typing import Dict

from torch import Tensor as tensor  # annotations below use the lowercase alias
from torch import cat, clamp, mean
from torch.nn import CosineSimilarity

from allennlp.data import Vocabulary
from allennlp.models import Model
from allennlp.modules import FeedForward, TextFieldEmbedder
from allennlp.nn import Activation


class CitationRanker(Model):
    def __init__(self,
                 vocab: Vocabulary,
                 title_embedder: TextFieldEmbedder,
                 abstract_embedder: TextFieldEmbedder,
                 dense_dim=75) -> None:
        super().__init__(vocab)
        self.title_embedder = title_embedder
        self.abstract_embedder = abstract_embedder
        # Six scalar features are concatenated in forward(): two cosine
        # similarities computed here plus four precomputed features.
        self.intermediate_dim = 6
        self.n_layers = 3
        self.layer_dims = [dense_dim for _ in range(self.n_layers - 1)]
        self.layer_dims.append(1)
        self.activations = [
            Activation.by_name("elu")(),
            Activation.by_name("elu")(),
            Activation.by_name("sigmoid")()
        ]
        self.layers = FeedForward(input_dim=self.intermediate_dim,
                                  num_layers=self.n_layers,
                                  hidden_dims=self.layer_dims,
                                  activations=self.activations)

    def forward(self,
                query_title: Dict[str, tensor],
                query_abstract: Dict[str, tensor],
                candidate_title: Dict[str, tensor],
                candidate_abstract: Dict[str, tensor],
                candidate_citations: tensor,
                title_intersection: tensor,
                abstract_intersection: tensor,
                cos_sim: tensor,
                label: tensor = None) -> Dict[str, tensor]:
        query_title_embed = self.title_embedder.forward(query_title["tokens"])
        query_abstract_embed = self.abstract_embedder.forward(query_abstract["tokens"])
        candidate_title_embed = self.title_embedder.forward(candidate_title["tokens"])
        candidate_abstract_embed = self.abstract_embedder.forward(candidate_abstract["tokens"])

        title_cos_sim = CosineSimilarity().forward(
            query_title_embed, candidate_title_embed).unsqueeze(-1)
        abstract_cos_sim = CosineSimilarity().forward(
            query_abstract_embed, candidate_abstract_embed).unsqueeze(-1)

        intermediate_output = cat((title_cos_sim, abstract_cos_sim, candidate_citations,
                                   title_intersection, abstract_intersection, cos_sim), dim=-1)
        pred = self.layers.forward(intermediate_output)

        output = {"cite_prob": pred}
        if label is not None:
            output["loss"] = self._compute_loss(pred, label)
        return output

    def _compute_loss(self, pred: tensor, label: tensor) -> tensor:
        # In the existing implementation, training examples with even indices
        # are positive and those with odd indices are negative.
        positive = pred[::2]
        negative = pred[1::2]
        # "margin is given by the difference in label"
        margin = label[::2] - label[1::2]
        delta = clamp(margin + negative - positive, min=0)
        return mean(delta)
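# A quick illustration (toy numbers, not from the repo) of how _compute_loss
# pairs interleaved examples: pred[::2] and pred[1::2] line up the positive and
# negative candidate of each pair, and the hinge penalizes any pair whose score
# gap falls short of the label margin.
import torch

pred = torch.tensor([[0.9], [0.4], [0.3], [0.7]])   # scores for (pos, neg, pos, neg)
label = torch.tensor([[1.0], [0.0], [1.0], [0.0]])
margin = label[::2] - label[1::2]                   # desired gap of 1.0 per pair
delta = torch.clamp(margin + pred[1::2] - pred[::2], min=0)
loss = torch.mean(delta)                            # 0.95: both pairs fall short of the margin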