Example #1
0
# Imports inferred from the usage below (PyTorch, NumPy, AllenNLP). Project-local helpers
# (EncCompression, NewAttention, RougeStrEvaluation, CompExecutor, two_dim_index_select,
# easy_post_processing, get_device) are assumed to come from this repository's own modules.
import random
from collections import OrderedDict
from typing import List

import numpy as np
import torch
from allennlp.modules import FeedForward


class CompressDecoder(torch.nn.Module):
    def __init__(self,
                 context_dim,
                 dec_state_dim,
                 enc_hid_dim,
                 text_field_embedder,
                 aggressive_compression: int = -1,
                 keep_threshold: float = 0.5,
                 abs_board_file="/home/cc/exComp/board.txt",
                 gather='mean',
                 dropout=0.5,
                 dropout_emb=0.2,
                 valid_tmp_path='/scratch/cluster/jcxu/exComp',
                 serilization_name: str = "",
                 vocab=None,
                 elmo: bool = False,
                 elmo_weight: str = "elmo_2x1024_128_2048cnn_1xhighway_weights.hdf5"):
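        """Decoder that scores compression (span-deletion) options for each emitted sentence.

        Word embeddings come from ELMo when ``elmo=True`` and from ``text_field_embedder``
        otherwise. With ``aggressive_compression < 0`` the model runs in threshold-based
        deletion mode and keeps one ROUGE evaluator per value in ``keep_thres``.
        """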
        super().__init__()
        self.use_elmo = elmo
        self.serilization_name = serilization_name
        if elmo:
            from allennlp.modules.elmo import Elmo, batch_to_ids
            from allennlp.modules.seq2seq_encoders import Seq2SeqEncoder, PytorchSeq2SeqWrapper
            self.vocab = vocab

            options_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x1024_128_2048cnn_1xhighway/elmo_2x1024_128_2048cnn_1xhighway_options.json"
            weight_file = elmo_weight
            self.elmo = Elmo(options_file, weight_file, 1, dropout=dropout_emb)
            # print(self.elmo.get_output_dim())
            # self.word_emb_dim = text_field_embedder.get_output_dim()
            # self._context_layer = PytorchSeq2SeqWrapper(
            #     torch.nn.LSTM(self.word_emb_dim + self.elmo.get_output_dim(), self.word_emb_dim,
            #                   batch_first=True, bidirectional=True))
            self.word_emb_dim = self.elmo.get_output_dim()
        else:
            self._text_field_embedder = text_field_embedder
            self.word_emb_dim = text_field_embedder.get_output_dim()

        self.XEloss = torch.nn.CrossEntropyLoss(reduction='none')
        self.device = get_device()

        # self.rouge_metrics_compression = RougeStrEvaluation(name='cp', path_to_valid=valid_tmp_path,
        #                                                     writting_address=valid_tmp_path,
        #                                                     serilization_name=serilization_name)
        # self.rouge_metrics_compression_best_possible = RougeStrEvaluation(name='cp_ub', path_to_valid=valid_tmp_path,
        #                                                                   writting_address=valid_tmp_path,
        #                                                                   serilization_name=serilization_name)
        self.enc = EncCompression(inp_dim=self.word_emb_dim, hid_dim=enc_hid_dim, gather=gather)  # TODO dropout

        self.aggressive_compression = aggressive_compression
        self.relu = torch.nn.ReLU()

        self.attn = NewAttention(enc_dim=self.enc.get_output_dim(),
                                 dec_dim=self.enc.get_output_dim_unit() * 2 + dec_state_dim)

        self.concat_size = self.enc.get_output_dim() + self.enc.get_output_dim_unit() * 2 + dec_state_dim
        self.valid_tmp_path = valid_tmp_path
        if self.aggressive_compression < 0:
            self.XELoss = torch.nn.CrossEntropyLoss(reduction='none', ignore_index=-1)
            # self.nn_lin = torch.nn.Linear(self.concat_size, self.concat_size)
            # self.nn_lin2 = torch.nn.Linear(self.concat_size, 2)

            self.ff = FeedForward(input_dim=self.concat_size, num_layers=3,
                                  hidden_dims=[self.concat_size, self.concat_size, 2],
                                  activations=[torch.nn.Tanh(), torch.nn.Tanh(), lambda x: x],
                                  dropout=dropout
                                  )
            # Keep threshold

            # self.keep_thres = list(np.arange(start=0.2, stop=0.6, step=0.075))
            self.keep_thres = [0.0, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 1.0]
            self.rouge_metrics_compression_dict = OrderedDict()
            for thres in self.keep_thres:
                self.rouge_metrics_compression_dict["{}".format(thres)] = RougeStrEvaluation(name='cp_{}'.format(thres),
                                                                                             path_to_valid=valid_tmp_path,
                                                                                             writting_address=valid_tmp_path,
                                                                                             serilization_name=serilization_name)

    def encode_sent_and_span_paral(self, text,  # batch, max_sent, max_word
                                   text_msk,  # batch, max_sent, max_word
                                   span,  # batch, max_sent_num, max_span_num, max_word
                                   sent_idx  # batch size
                                   ):
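        """Embed the sentence selected by ``sent_idx`` for every batch entry (via ELMo or
        the text field embedder) and encode its compression spans.

        Returns the concatenated span representations, the span mask, and the original
        sentence representation.
        """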
        this_text = two_dim_index_select(text['tokens'], sent_idx)  # batch, max_word
        if self.use_elmo:
            from allennlp.modules.elmo import batch_to_ids
            this_text_list: List = this_text.tolist()
            text_str_list = []
            for sample in this_text_list:
                s = [self.vocab.get_token_from_index(x) for x in sample]
                text_str_list.append(s)
            character_ids = batch_to_ids(text_str_list).to(self.device)
            this_context = self.elmo(character_ids)
            # print(this_context['elmo_representations'][0].size())
            this_context = this_context['elmo_representations'][0]
        else:
            this_text = {'tokens': this_text}
            this_context = self._text_field_embedder(this_text)

        num_doc, max_word, inp_dim = this_context.size()
        batch_size = sent_idx.size()[0]
        assert batch_size == num_doc

        # text is the original text of the selected sentence.
        # this_context = two_dim_index_select(context, sent_idx)  # batch, max_word, hdim
        this_context_mask = two_dim_index_select(text_msk, sent_idx)  # batch, max_word
        this_span = two_dim_index_select(span, sent_idx)  # batch , nspan, max_word

        concat_rep_of_compression, \
        span_msk, original_sent_rep = self.enc.forward(word_emb=this_context,
                                                       word_emb_msk=this_context_mask,
                                                       span=this_span)
        return concat_rep_of_compression, span_msk, original_sent_rep

    def encode_sent_and_span(self, text, text_msk, span, batch_idx, sent_idx):
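        """Single-example variant: embed the sentence at (``batch_idx``, ``sent_idx``) and
        encode its compression spans; returns the flattened sentence encoding, the span
        representations, and the span mask (see the shape comment below)."""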
        context = self._text_field_embedder(text)
        num_doc, max_sent, max_word, inp_dim = context.size()
        num_doc_, max_sent_, nspan = span.size()[0:-1]
        assert num_doc == num_doc_
        assert max_sent == max_sent_
        this_context = context[batch_idx, sent_idx, :, :].unsqueeze(0)
        this_span = span[batch_idx, sent_idx, :, :].unsqueeze(0)
        this_context_mask = text_msk[batch_idx, sent_idx, :].unsqueeze(0)
        flattened_enc, attn_dist, \
        spans_rep, span_msk, score \
            = self.enc.forward(word_emb=this_context,
                               word_emb_msk=this_context_mask,
                               span=this_span)
        return flattened_enc, spans_rep, span_msk
        # 1, hid*2      1, span num, hid        1, span num

    def indep_compression_judger(self, reps):
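        """Score each compression span independently with the feed-forward classifier.

        ``reps`` is (timestep, batch, max_span_num, concat_size); the returned score is
        (timestep, batch, max_span_num, 2). A sample is printed occasionally (p = 0.005)
        for debugging.
        """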
        # t, batch_size_, max_span_num,self.concat_size
        timestep, batch_size, max_span_num, dim = reps.size()
        score = self.ff.forward(reps)
        # lin_out = self.nn_lin(reps)
        # activated = torch.sigmoid(lin_out)
        # score = self.nn_lin2(activated)
        if random.random() < 0.005:
            print("score: {}".format(score[0]))
        return score

    def get_out_dim(self):
        return self.concat_size

    def forward_parallel(self, sent_decoder_states,  # t, batch, hdim
                         sent_decoder_outputs_logit,  # t, batch
                         document_rep,  # batch, hdim
                         text,  # batch, max_sent, max_word
                         text_msk,  # batch, max_sent, max_word
                         span):  # batch, max_sent_num, max_span_num, max_word
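        """For every decoding timestep, concatenate the decoder state, document
        representation, and selected-sentence representation with each span
        representation; in aggressive-compression mode also compute attention
        distributions and scores over the spans."""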
        # Encode compression options given sent emission.
        # output scores, attn dist, ...
        t, batch_size, hdim = sent_decoder_states.size()
        t_, batch_size_ = sent_decoder_outputs_logit.size()  # invalid bits are -1
        batch, max_sent, max_span_num, max_word = span.size()
        # assert t == t_
        t = min(t, t_)
        assert batch_size == batch == batch_size_
        if self.aggressive_compression > 0:
            all_attn_dist = torch.zeros((t, batch_size, max_span_num)).to(self.device)
            all_scores = torch.ones((t, batch_size, max_span_num)).to(self.device) * -100
        else:
            all_attn_dist = None
            all_scores = None
        all_reps = torch.zeros((t, batch_size_, max_span_num, self.concat_size), device=self.device)
        for timestep in range(t):
            dec_state = sent_decoder_states[timestep]  # batch, dim
            logit = sent_decoder_outputs_logit[timestep]  # batch

            # valid_mask = (logit > 0)
            positive_logit = self.relu(logit.float()).long()  # turn -1 to 0

            span_t, span_msk_t, sent_t = self.encode_sent_and_span_paral(text=text,
                                                                         text_msk=text_msk,
                                                                         span=span,
                                                                         sent_idx=positive_logit)
            # sent_t : batch, sent_dim
            # span_t: batch, span_num, span_dim
            # span_msk_t: batch, span_num [[1,1,1,0,0,0],

            concated_rep_high_level = torch.cat([dec_state, document_rep, sent_t], dim=1)
            # batch, DIM
            if self.aggressive_compression > 0:
                attn_dist, score = self.attn.forward_one_step(enc_state=span_t,
                                                              dec_state=concated_rep_high_level,
                                                              enc_mask=span_msk_t.float())
            # attn_dist: batch, span num
            # score:    batch, span num

            # concated_rep: batch, dim ==> batch, 1, dim ==> batch, max_span_num, dim
            expanded_concated_rep = concated_rep_high_level.unsqueeze(1).expand((batch, max_span_num, -1))
            all_reps[timestep, :, :, :] = torch.cat([expanded_concated_rep, span_t], dim=2)
            if self.aggressive_compression > 0:
                all_attn_dist[timestep, :, :] = attn_dist
                all_scores[timestep, :, :] = score

        return all_attn_dist, all_scores, all_reps

    def comp_loss_inf_deletion(self,
                               decoder_outputs_logit,  # gold label!!!!
                               # span_seq_label,  # batch, max sent num
                               span_rouge,  # batch, max sent num, max compression num
                               scores,
                               comp_rouge_ratio,
                               loss_thres=1
                               ):
        """

        :param decoder_outputs_logit:
        :param span_rouge: [batch, max_sent, max_compression]
        :param scores: [timestep, batch, max_compression, 2]
        :param comp_rouge_ratio: [batch_size, max_sent, max_compression]
        :return:
        """
        tim, bat = decoder_outputs_logit.size()
        time, batch, max_span, _ = scores.size()
        batch_, sent_len, max_sp = span_rouge.size()
        assert batch_ == batch == bat
        assert time == tim
        assert max_sp == max_span
        goal_rouge_label = torch.ones((tim, batch, max_span), device=self.device, dtype=torch.long,
                                      ) * (-1)
        weights = torch.ones((tim, batch, max_span), device=self.device, dtype=torch.float)
        decoder_outputs_logit_mask = (decoder_outputs_logit >= 0).unsqueeze(2).expand(
            (time, batch, max_span)).float().view(-1)
        decoder_outputs_logit = torch.nn.functional.relu(decoder_outputs_logit).long()
        z = torch.zeros((1), device=self.device)
        for tt in range(tim):
            decoder_outputs_logit_t = decoder_outputs_logit[tt]
            out = two_dim_index_select(inp=comp_rouge_ratio, index=decoder_outputs_logit_t)
            label = torch.gt(out, loss_thres).long()

            mini_mask = torch.gt(out, 0.01).float()

            # baseline_mask = 1 - torch.lt(torch.abs(out - 0.99), 0.01).float()  # baseline will be 0

            # weight = torch.max(input=-out + 0.5, other=z) + 1
            # weights[tt] = mini_mask * baseline_mask
            weights[tt] = mini_mask
            goal_rouge_label[tt] = label
        probs = scores.view(-1, 2)
        goal_rouge_label = goal_rouge_label.view(-1)
        weights = weights.view(-1)
        loss = self.XELoss(input=probs, target=goal_rouge_label)
        loss = loss * decoder_outputs_logit_mask * weights
        return torch.mean(loss)

    def comp_loss(self, decoder_outputs_logit,  # gold label!!!!
                  scores,
                  span_seq_label,  # batch, max sent num
                  span_rouge,  # batch, max sent num, max compression num
                  comp_rouge_ratio
                  ):
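        """Cross-entropy loss for the aggressive-compression mode: at each timestep with a
        valid gold sentence index, supervise the span scores, keeping only compressions
        whose ROUGE ratio exceeds 0.98, and return the mean over all steps."""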
        t, batch = decoder_outputs_logit.size()
        t_, batch_, comp_num = scores.size()
        b, max_sent = span_seq_label.size()
        # b_, max_sen, max_comp_, _ = span.size()
        _b, max_sent_, max_comp = span_rouge.size()
        assert batch == batch_ == b == _b
        assert max_sent_ == max_sent
        assert comp_num == max_comp
        span_seq_label = span_seq_label.long()
        total_loss = torch.zeros((t, b)).to(self.device)
        # print(decoder_outputs_logit)
        # print(span_seq_label)
        for timestep in range(t):

            # this is the sent idx
            for batch_idx in range(b):
                logit = decoder_outputs_logit[timestep][batch_idx]
                # print(logit)
                # decoder_outputs_logit should be the gold label for sentence emission.
                # if it's 0 or -1, then we skip supervision.
                if logit < 0:
                    continue
                ref_rouge_score = comp_rouge_ratio[batch_idx][logit]
                num_of_compression = ref_rouge_score.size()[0]

                _supervision_label_msk = (ref_rouge_score > 0.98).float()
                label = torch.from_numpy(np.arange(num_of_compression)).to(self.device).long()
                score_t = scores[timestep][batch_idx].unsqueeze(0)  # comp num
                score_t = score_t.expand(num_of_compression, -1)
                # label = span_seq_label[batch_idx][logit].unsqueeze(0)

                loss = self.XEloss(score_t, label)
                # print(loss)
                loss = _supervision_label_msk * loss
                total_loss[timestep][batch_idx] = torch.sum(loss)
                # sent_msk_t = two_dim_index_select(sent_mask, logit)

        return torch.mean(total_loss)

    def _dec_compression_one_step(self, predict_compression,
                                  sp_meta,
                                  word_sent: List[str], keep_threshold: List[float],
                                  context: List[List[str]] = None):
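        """Apply the predicted deletions to a single sentence, once per keep threshold.

        A span is deleted when it is already covered by the decoding ``context`` or when
        its predicted probability exceeds the threshold. Returns the per-span statistics
        (ordered by probability) and, per threshold, [threshold, visual output, words for
        evaluation, kept-word ratio].
        """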

        full_set_len = set(range(len(word_sent)))
        # max_comp, _ = predict_compression.size

        preds = [full_set_len.copy() for _ in range(len(keep_threshold))]

        # Show all of the compression spans
        stat_compression = {}
        for comp_idx, comp_meta in enumerate(sp_meta):
            p = predict_compression[comp_idx][1]
            node_type, sel_idx, rouge, ratio = comp_meta
            if node_type != "BASELINE":
                selected_words = [x for idx, x in enumerate(word_sent) if idx in sel_idx]
                selected_words_str = "_".join(selected_words)
                stat_compression["{}".format(selected_words_str)] = {
                    "prob": float("{0:.2f}".format(p)),  # float("{0:.2f}".format())
                    "type": node_type,
                    "rouge": float("{0:.2f}".format(rouge)),
                    "ratio": float("{0:.2f}".format(ratio)),
                    "sel_idx": sel_idx,
                    "len": len(sel_idx)
                }
        stat_compression_order = OrderedDict(
            sorted(stat_compression.items(), key=lambda item: item[1]["prob"], reverse=True))  # Python 3
        for idx, _keep_thres in enumerate(keep_threshold):
            history: List[str] = context[idx]
            his_set = set((" ".join(history)).split(" "))
            for key, value in stat_compression_order.items():
                p = value['prob']
                sel_idx = value['sel_idx']
                sel_txt = set([word_sent[x] for x in sel_idx])
                if sel_txt - his_set == set():
                    # print("Save big!")
                    # print("Context: {}\tCandidate: {}".format(his_set, sel_txt))
                    preds[idx] = preds[idx] - set(value['sel_idx'])
                    continue
                if p > _keep_thres:
                    preds[idx] = preds[idx] - set(value['sel_idx'])

        preds = [list(x) for x in preds]
        for pred in preds:
            pred.sort()
        # Visual output
        visual_outputs: List[str] = []
        words_for_evaluation: List[str] = []
        meta_keep_ratio_word = []

        for idx, compression in enumerate(preds):
            output = [word_sent[jdx] if (jdx in compression) else '_' + word_sent[jdx] + '_' for jdx in
                      range(len(word_sent))]
            visual_outputs.append(" ".join(output))

            words = [word_sent[x] for x in compression]
            meta_keep_ratio_word.append(float(len(words) / len(word_sent)))
            # meta_kepp_ratio_span.append(1 - float(len(survery['type'][idx]) / len(sp_meta)))
            words = " ".join(words)
            words = easy_post_processing(words)
            # print(words)
            words_for_evaluation.append(words)
        d: List[List] = []
        for kep_th, vis, words_eva, keep_word_ratio in zip(keep_threshold, visual_outputs, words_for_evaluation,
                                                           meta_keep_ratio_word):
            d.append([kep_th, vis, words_eva, keep_word_ratio])
        return stat_compression_order, d

    def decode_inf_deletion(self,
                            sent_decoder_outputs_logit,  # time, batch
                            span_prob,  # time, batch, max_comp, 2
                            metadata: List,
                            span_meta: List,
                            span_rouge,  # batch, sent, max_comp
                            keep_threshold: List[float]
                            ):
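        """Decode compressions for a batch: softmax the span scores, run ``CompExecutor``
        on each document, and feed the per-threshold predictions into the corresponding
        ROUGE metrics."""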
        batch_size, max_sent_num, max_comp_num = span_rouge.size()
        t, batsz, max_comp, _ = span_prob.size()
        span_score = torch.nn.functional.softmax(span_prob, dim=3).cpu().numpy()
        timestep, batch = sent_decoder_outputs_logit.size()
        sent_decoder_outputs_logit = sent_decoder_outputs_logit.cpu().data

        for idx, m in enumerate(metadata):
            abs_s = [" ".join(s) for s in m["abs_list"]]
            comp_exe = CompExecutor(span_meta=span_meta[idx],
                                    sent_idxs=sent_decoder_outputs_logit[:, idx],
                                    prediction_score=span_score[:, idx, :, :],
                                    abs_str=abs_s,
                                    name=m['name'],
                                    doc_list=m["doc_list"],
                                    keep_threshold=keep_threshold,
                                    part=m['name'], ser_dir=self.valid_tmp_path,
                                    ser_fname=self.serilization_name
                                    )
            # processed_words, del_record, \
            # compressions, full_sents, \
            bag_pred_eval = comp_exe.run()
            full_sents: List[List[str]] = comp_exe.full_sents
            # assemble full sents
            full_sents = [" ".join(x) for x in full_sents]

            # visual to console
            for j, thres in enumerate(keep_threshold):
                self.rouge_metrics_compression_dict["{}".format(thres)](pred=bag_pred_eval[j],
                                                                        ref=[abs_s],
                                                                        origin=full_sents)
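
# --- Usage sketch (not part of the original repository) ---
# A minimal, hypothetical call sequence; it assumes the repository's local modules
# (EncCompression, NewAttention, RougeStrEvaluation, CompExecutor, ...) are importable
# and that `my_embedder` / `my_vocab` are an AllenNLP TextFieldEmbedder / Vocabulary:
#
#   decoder = CompressDecoder(context_dim=512, dec_state_dim=512, enc_hid_dim=256,
#                             text_field_embedder=my_embedder, vocab=my_vocab,
#                             aggressive_compression=-1,  # threshold-based deletion mode
#                             valid_tmp_path="/tmp/exComp", serilization_name="demo")
#   attn_dist, scores, reps = decoder.forward_parallel(sent_decoder_states,
#                                                      sent_decoder_outputs_logit,
#                                                      document_rep, text, text_msk, span)
#   span_scores = decoder.indep_compression_judger(reps)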
Example #2
0
# Imports inferred from the usage below (PyTorch + AllenNLP).
from typing import Dict

from torch import tensor, cat, clamp, mean
from torch.nn import CosineSimilarity

from allennlp.data import Vocabulary
from allennlp.models import Model
from allennlp.modules import FeedForward, TextFieldEmbedder
from allennlp.nn import Activation


class CitationRanker(Model):
    def __init__(self,
                 vocab: Vocabulary,
                 title_embedder: TextFieldEmbedder,
                 abstract_embedder: TextFieldEmbedder,
                 dense_dim=75) -> None:
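        """Pairwise citation ranker: six similarity features are passed through a
        three-layer feed-forward network (elu, elu, sigmoid) ending in a single
        citation-probability output."""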

        super().__init__(vocab)

        self.title_embedder = title_embedder
        self.abstract_embedder = abstract_embedder
        self.intermediate_dim = 6
        self.n_layers = 3
        self.layer_dims = [dense_dim for i in range(self.n_layers - 1)]
        self.layer_dims.append(1)

        self.activations = [
            Activation.by_name("elu")(),
            Activation.by_name("elu")(),
            Activation.by_name("sigmoid")()
        ]
        self.layers = FeedForward(input_dim=self.intermediate_dim,
                                  num_layers=self.n_layers,
                                  hidden_dims=self.layer_dims,
                                  activations=self.activations)

    def forward(self,
                query_title: Dict[str, tensor],
                query_abstract: Dict[str, tensor],
                candidate_title: Dict[str, tensor],
                candidate_abstract: Dict[str, tensor],
                candidate_citations: tensor,
                title_intersection: tensor,
                abstract_intersection: tensor,
                cos_sim: tensor,
                label: tensor = None) -> Dict[str, tensor]:
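        """Embed query and candidate titles/abstracts, compute their cosine similarities,
        concatenate them with the precomputed features (``candidate_citations``,
        ``title_intersection``, ``abstract_intersection``, ``cos_sim``), and return the
        predicted citation probability, plus a ranking loss when ``label`` is given."""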

        query_title_embed = self.title_embedder.forward(query_title["tokens"])
        query_abstract_embed = self.abstract_embedder.forward(
            query_abstract["tokens"])

        candidate_title_embed = self.title_embedder.forward(
            candidate_title["tokens"])
        candidate_abstract_embed = self.abstract_embedder.forward(
            candidate_abstract["tokens"])

        title_cos_sim = CosineSimilarity().forward(
            query_title_embed, candidate_title_embed).unsqueeze(-1)

        abstract_cos_sim = CosineSimilarity().forward(
            query_abstract_embed, candidate_abstract_embed).unsqueeze(-1)

        intermediate_output = cat(
            (title_cos_sim, abstract_cos_sim, candidate_citations,
             title_intersection, abstract_intersection, cos_sim),
            dim=-1)

        pred = self.layers.forward(intermediate_output)

        output = {"cite_prob": pred}

        if label is not None:
            output["loss"] = self._compute_loss(pred, label)

        return output

    def _compute_loss(self, pred: tensor, label: tensor) -> tensor:
        # In the existing implementation, training examples with even indices are positive
        # and odd indices are negative.
        positive = pred[::2]
        negative = pred[1::2]

        #"margin is given by the difference in label"
        margin = label[::2] - label[1::2]
        delta = clamp(margin + negative - positive, min=0)

        return mean(delta)
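
# --- Worked example of the pairwise margin loss (not from the original code) ---
# Predictions and labels are interleaved as (positive, negative, positive, negative, ...);
# each pair contributes max(0, (label_pos - label_neg) + pred_neg - pred_pos):
#
#   import torch
#   pred = torch.tensor([0.9, 0.3, 0.6, 0.7])    # two (positive, negative) pairs
#   label = torch.tensor([1.0, 0.0, 1.0, 0.0])
#   margin = label[::2] - label[1::2]            # tensor([1., 1.])
#   delta = torch.clamp(margin + pred[1::2] - pred[::2], min=0)  # tensor([0.4000, 1.1000])
#   loss = delta.mean()                          # tensor(0.7500)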