Example #1
import os
import time

import numpy as np
import torch

# Project-local names this snippet assumes are importable from the surrounding
# repo: config, data, helper, Vocab, Batcher, Hypothesis, write_for_rouge,
# rouge_eval, rouge_log, and a torch `device`.


class BeamSearchDecoder:
    def __init__(self, model):

        self._decode_dir = os.path.join(config.log_root,
                                        'decode_%s' % ("model2"))
        self._rouge_ref_dir = os.path.join(self._decode_dir, 'rouge_ref')
        self._rouge_dec_dir = os.path.join(self._decode_dir, 'rouge_dec_dir')

        for p in [self._decode_dir, self._rouge_ref_dir, self._rouge_dec_dir]:
            if not os.path.exists(p):
                os.mkdir(p)

        self.vocab = Vocab(config.vocab_path, config.vocab_size)
        self.batcher = Batcher(config.decode_data_path,
                               self.vocab,
                               mode='decode',
                               batch_size=1,
                               single_pass=True)
        self.model = model

    def beam_search(self, batch, conf):
        # batch should have only one example
        enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros = helper.prepare_src_batch(
            batch, conf)
        encoder_output = self.model.encode(enc_batch, enc_padding_mask)
        # Decoding starts from a single hypothesis containing only the [START] token.
        hyps_list = [
            Hypothesis(tokens=[self.vocab.word2id(data.START_DECODING)],
                       log_probs=[0.0])
        ]
        results = []
        steps = 0

        while steps < config.max_dec_len and len(results) < config.beam_size:
            hyp_tokens = [h.tokens for h in hyps_list]
            np_hyp_tokens = np.asarray(hyp_tokens)
            # Ids beyond the fixed vocabulary are copy-mechanism OOVs; map them
            # back to [UNK] before feeding the decoder.
            np_hyp_tokens[
                np_hyp_tokens >= self.vocab.size()] = self.vocab.word2id(
                    data.UNKNOWN_TOKEN)
            yt = torch.LongTensor(np_hyp_tokens).to(device)
            temp_enc_out = encoder_output.repeat(yt.size(0), 1, 1)
            out, _ = self.model.decode(temp_enc_out, yt, enc_padding_mask,
                                       helper.subsequent_mask(yt.size(-1)))
            extra_zeros_ip = None
            if extra_zeros is not None:
                extra_zeros_ip = extra_zeros[:, 0:steps + 1, :].repeat(
                    yt.size(0), 1, 1)
            if conf.coverage:
                log_probs, _ = self.model.generator(out, temp_enc_out,
                                                    enc_padding_mask,
                                                    enc_batch_extend_vocab,
                                                    extra_zeros_ip)
            else:
                log_probs = self.model.generator(out, temp_enc_out,
                                                 enc_padding_mask,
                                                 enc_batch_extend_vocab,
                                                 extra_zeros_ip)
            log_probs = log_probs.squeeze(1)
            topk_log_probs, topk_ids = torch.topk(log_probs,
                                                  config.beam_size * 2)
            # If the generator returned a distribution per decoding position,
            # keep only the last position's top-k.
            if len(topk_log_probs.size()) == 3:
                topk_log_probs = topk_log_probs[:, -1, :].squeeze(1)
                topk_ids = topk_ids[:, -1, :].squeeze(1)
            all_hyps = []
            # On the first step all hypotheses are identical, so only extend one.
            num_orig_hyps = 1 if steps == 0 else len(hyps_list)
            for i in range(num_orig_hyps):
                h = hyps_list[i]
                for j in range(config.beam_size * 2):  # each of the top 2*beam_size extensions
                    hyp = h.extend(token=topk_ids[i, j].item(),
                                   log_prob=topk_log_probs[i, j].item())
                    all_hyps.append(hyp)
            hyps_list = []
            sorted_hyps = sorted(all_hyps,
                                 key=lambda h: h.avg_log_prob,
                                 reverse=True)
            for h in sorted_hyps:
                if h.latest_token == self.vocab.word2id(data.STOP_DECODING):
                    if steps >= config.min_dec_steps:
                        results.append(h)
                else:
                    hyps_list.append(h)
                if len(hyps_list) == config.beam_size or len(
                        results) == config.beam_size:
                    break

            steps += 1

        if len(results) == 0:
            results = hyps_list

        results_sorted = sorted(results,
                                key=lambda h: h.avg_log_prob,
                                reverse=True)
        return results_sorted[0]

    def decode(self, conf):

        self.model.eval()
        start = time.time()
        counter = 0
        batch = self.batcher.next_batch()

        i = 0
        while batch is not None:

            i += 1
            if i % 10 == 0:
                print(i)
            # Run beam search to get best Hypothesis
            best_summary = self.beam_search(batch, conf)

            # Extract the output ids from the hypothesis and convert back to words
            output_ids = [int(t) for t in best_summary.tokens[1:]]
            decoded_words = data.outputids2words(
                output_ids, self.vocab,
                (batch.art_oovs[0] if config.pointer_gen else None))

            # Truncate at the first [STOP] token, if one was produced
            try:
                fst_stop_idx = decoded_words.index(data.STOP_DECODING)
                decoded_words = decoded_words[:fst_stop_idx]
            except ValueError:
                pass  # no [STOP] token in the output

            original_abstract_sents = batch.original_abstracts_sents[0]

            write_for_rouge(original_abstract_sents, decoded_words, counter,
                            self._rouge_ref_dir, self._rouge_dec_dir)
            counter += 1
            if counter % 1000 == 0:
                print('%d examples in %d sec' % (counter, time.time() - start))
                start = time.time()

            batch = self.batcher.next_batch()

        results_dict = rouge_eval(self._rouge_ref_dir, self._rouge_dec_dir)
        rouge_log(results_dict, self._decode_dir)
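
A note on the `Hypothesis` container: `beam_search` above relies on `Hypothesis.extend`, `latest_token`, and `avg_log_prob` without defining them. The following is a minimal sketch of that interface, written for illustration only (the repo's actual class may differ; `Beam` in Example #2 extends the same idea with decoder state, context, and coverage):

class Hypothesis:
    """A partial summary: the token ids decoded so far and their log-probs."""

    def __init__(self, tokens, log_probs):
        self.tokens = tokens
        self.log_probs = log_probs

    def extend(self, token, log_prob):
        # Return a new hypothesis with one more decoded token appended.
        return Hypothesis(tokens=self.tokens + [token],
                          log_probs=self.log_probs + [log_prob])

    @property
    def latest_token(self):
        return self.tokens[-1]

    @property
    def avg_log_prob(self):
        # Length-normalized score, so hypotheses of different lengths are comparable.
        return sum(self.log_probs) / len(self.tokens)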
Example #2
import os
import time
from collections import Counter

import numpy as np
import torch
from torch.autograd import Variable

# Project-local names this snippet assumes are importable from the surrounding
# repo: args, config, data, use_cuda, device, Vocab, Batcher, Model, Beam,
# chu_liu_edmonds, find_height, leaf_node_proportion, tree_distance,
# get_sent_dist, get_input_from_batch, write_tags, write_for_rouge,
# write_to_json_file.


class BeamSearch(object):
    def __init__(self, args, model_file_path, save_path):
        model_name = os.path.basename(model_file_path)
        self.args = args
        self._decode_dir = os.path.join(config.log_root, save_path,
                                        'decode_%s' % (model_name))
        self._structures_dir = os.path.join(self._decode_dir, 'structures')
        self._sent_single_heads_dir = os.path.join(self._decode_dir,
                                                   'sent_heads_preds')
        self._sent_single_heads_ref_dir = os.path.join(self._decode_dir,
                                                       'sent_heads_ref')
        self._contsel_dir = os.path.join(self._decode_dir, 'content_sel_preds')
        self._contsel_ref_dir = os.path.join(self._decode_dir,
                                             'content_sel_ref')
        self._rouge_ref_dir = os.path.join(self._decode_dir, 'rouge_ref')
        self._rouge_dec_dir = os.path.join(self._decode_dir, 'rouge_dec_dir')

        self._rouge_ref_file = os.path.join(self._decode_dir, 'rouge_ref.json')
        self._rouge_pred_file = os.path.join(self._decode_dir,
                                             'rouge_pred.json')
        self.stat_res_file = os.path.join(self._decode_dir, 'stats.txt')
        self.sent_count_file = os.path.join(self._decode_dir,
                                            'sent_used_counts.txt')
        for p in [
                self._decode_dir, self._structures_dir,
                self._sent_single_heads_ref_dir, self._sent_single_heads_dir,
                self._contsel_ref_dir, self._contsel_dir, self._rouge_ref_dir,
                self._rouge_dec_dir
        ]:
            if not os.path.exists(p):
                os.mkdir(p)
        vocab = args.vocab_path if args.vocab_path is not None else config.vocab_path
        self.vocab = Vocab(vocab, config.vocab_size, config.embeddings_file,
                           args)
        self.batcher = Batcher(args.decode_data_path,
                               self.vocab,
                               mode='decode',
                               batch_size=args.beam_size,
                               single_pass=True,
                               args=args)
        self.batcher.setup_queues()
        time.sleep(30)  # give the batcher's loader threads time to fill the queues

        self.model = Model(args, self.vocab).to(device)
        self.model.eval()

    def sort_beams(self, beams):
        return sorted(beams, key=lambda h: h.avg_log_prob, reverse=True)

    def extract_structures(self, batch, sent_attention_matrix,
                           doc_attention_matrix, count, use_cuda, sent_scores):
        fileName = os.path.join(self._structures_dir,
                                "%06d_struct.txt" % count)
        fp = open(fileName, "w")
        fp.write("Doc: " + str(count) + "\n")
        doc_attention_matrix = doc_attention_matrix[:, :]  # author's note: this change is yet to be tested
        l = batch.enc_doc_lens[0].item()


        # Build an (l+1) x (l+1) score matrix with an artificial root row, then
        # force a single highest-scoring root edge before running the MST decoder.
        shape2 = doc_attention_matrix[0:l, 0:l + 1].size()
        row = torch.zeros([1, shape2[1]]).cuda()
        scores = doc_attention_matrix[0:l, 0:l + 1]
        new_scores = torch.cat([row, scores], dim=0)
        val, root_edge = torch.max(new_scores[:, 0], dim=0)
        root_score = torch.zeros([shape2[0] + 1, 1]).cuda()
        root_score[root_edge] = 1
        new_scores[:, 0] = root_score.squeeze()
        heads, tree_score = chu_liu_edmonds(
            new_scores.data.cpu().numpy().astype(np.float64))
        height = find_height(heads)
        leaf_nodes = leaf_node_proportion(heads)
        fp.write("\n")
        sentences = str(batch.original_articles[0]).split("<split1>")
        for idx, sent in enumerate(sentences):
            fp.write(str(idx) + "\t" + str(sent) + "\n")
        fp.write(str(heads) + " ")
        fp.write(str(tree_score) + "\n")
        fp.write(str(height) + "\n")
        s = sent_scores[0].data.cpu().numpy()
        for val in s:
            fp.write(str(val) + " ")  # space-separate scores so they remain readable
        fp.close()
        structure_info = dict()
        structure_info['heads'] = heads
        structure_info['height'] = height
        structure_info['leaf_nodes'] = leaf_nodes
        return structure_info

    def decode(self):
        start = time.time()
        counter = 0
        sent_counter = []
        avg_max_seq_len_list = []
        copied_sequence_len = Counter()
        copied_sequence_per_sent = []
        article_copy_id_count_tot = Counter()
        sentence_copy_id_count = Counter()
        novel_counter = Counter()
        repeated_counter = Counter()
        summary_sent_count = Counter()
        summary_sent = []
        article_sent = []
        summary_len = []
        abstract_ref = []
        abstract_pred = []
        sentence_count = []
        tot_sentence_id_count = Counter()
        height_avg = []
        leaf_node_proportion_avg = []
        precision_tree_dist = []
        recall_tree_dist = []
        batch = self.batcher.next_batch()
        height_counter = Counter()
        leaf_nodes_counter = Counter()
        sent_count_fp = open(self.sent_count_file, 'w')

        counts = {
            'token_consel_num_correct': 0,
            'token_consel_num': 0,
            'sent_single_heads_num_correct': 0,
            'sent_single_heads_num': 0,
            'sent_all_heads_num_correct': 0,
            'sent_all_heads_num': 0,
            'sent_all_child_num_correct': 0,
            'sent_all_child_num': 0
        }
        no_batches_processed = 0
        while batch is not None:
            # Run beam search to get the best hypothesis
            has_summary, best_summary, sample_predictions, sample_counts, structure_info, adj_mat = self.get_decoded_outputs(
                batch, counter)

            if args.predict_contsel_tags:
                no_words = batch.enc_word_lens[0]
                prediction = sample_predictions['token_contsel_prediction'][
                    0:no_words]
                ref = batch.contsel_tags[0]
                write_tags(prediction, ref, counter, self._contsel_dir,
                           self._contsel_ref_dir)
                counts['token_consel_num_correct'] += sample_counts[
                    'token_consel_num_correct']
                counts['token_consel_num'] += sample_counts['token_consel_num']

            if args.predict_sent_single_head:
                no_sents = batch.enc_doc_lens[0]
                prediction = sample_predictions[
                    'sent_single_heads_prediction'][0:no_sents].tolist()
                ref = batch.original_parent_heads[0]
                write_tags(prediction, ref, counter,
                           self._sent_single_heads_dir,
                           self._sent_single_heads_ref_dir)
                counts['sent_single_heads_num_correct'] += sample_counts[
                    'sent_single_heads_num_correct']
                counts['sent_single_heads_num'] += sample_counts[
                    'sent_single_heads_num']

            if args.predict_sent_all_head:
                counts['sent_all_heads_num_correct'] += sample_counts[
                    'sent_all_heads_num_correct']
                counts['sent_all_heads_num'] += sample_counts[
                    'sent_all_heads_num']

            if args.predict_sent_all_child:
                counts['sent_all_child_num_correct'] += sample_counts[
                    'sent_all_child_num_correct']
                counts['sent_all_child_num'] += sample_counts[
                    'sent_all_child_num']

            if not has_summary:
                batch = self.batcher.next_batch()
                continue
            # Extract the output ids from the hypothesis and convert back to words
            output_ids = [int(t) for t in best_summary.tokens[1:]]
            decoded_words = data.outputids2words(
                output_ids, self.vocab,
                (batch.art_oovs[0] if self.args.pointer_gen else None))

            # Truncate at the first [STOP] token, if one was produced
            try:
                fst_stop_idx = decoded_words.index(data.STOP_DECODING)
                decoded_words = decoded_words[:fst_stop_idx]
            except ValueError:
                pass  # no [STOP] token in the output

            original_abstract_sents = batch.original_abstracts_sents[0]

            summary_len.append(len(decoded_words))
            assert adj_mat is not None, "Explicit matrix is none."
            assert structure_info['heads'] is not None, "Heads is none."
            precision, recall = tree_distance(
                structure_info['heads'],
                adj_mat.cpu().data.numpy()[0, :, :])
            if precision is not None and recall is not None:
                precision_tree_dist.append(precision)
                recall_tree_dist.append(recall)
            height_counter[structure_info['height']] += 1
            height_avg.append(structure_info['height'])
            leaf_node_proportion_avg.append(structure_info['leaf_nodes'])
            leaf_nodes_counter[np.floor(structure_info['leaf_nodes'] *
                                        10)] += 1
            abstract_ref.append(" ".join(original_abstract_sents))
            abstract_pred.append(" ".join(decoded_words))

            sent_res = get_sent_dist(" ".join(decoded_words),
                                     batch.original_articles[0].decode(),
                                     minimum_seq=self.args.minimum_seq)

            sent_counter.append(
                (sent_res['seen_sent'], sent_res['article_sent']))
            summary_len.append(sent_res['summary_len'])
            summary_sent.append(sent_res['summary_sent'])
            summary_sent_count[sent_res['summary_sent']] += 1
            article_sent.append(sent_res['article_sent'])
            if sent_res['avg_copied_seq_len'] is not None:
                avg_max_seq_len_list.append(sent_res['avg_copied_seq_len'])
                copied_sequence_per_sent.append(
                    np.average(
                        list(sent_res['counter_summary_sent_id'].values())))
            copied_sequence_len.update(sent_res['counter_copied_sequence_len'])
            sentence_copy_id_count.update(sent_res['counter_summary_sent_id'])
            article_copy_id_count_tot.update(
                sent_res['counter_article_sent_id'])
            novel_counter.update(sent_res['novel_ngram_counter'])
            repeated_counter.update(sent_res['repeated_ngram_counter'])

            sent_count_fp.write(
                str(counter) + "\t" + str(sent_res['article_sent']) + "\t" +
                str(sent_res['seen_sent']) + "\n")
            write_for_rouge(original_abstract_sents, decoded_words, counter,
                            self._rouge_ref_dir, self._rouge_dec_dir)

            batch = self.batcher.next_batch()

            counter += 1
            if counter % 1000 == 0:
                print('%d examples in %d sec' % (counter, time.time() - start))
                start = time.time()
            if args.decode_for_subset and counter == 1000:
                break

        print("Decoder has finished reading dataset for single_pass.")

        fp = open(self.stat_res_file, 'w')
        percentages = [
            float(len(seen_sent)) / float(sent_count)
            for seen_sent, sent_count in sent_counter
        ]
        avg_percentage = sum(percentages) / float(len(percentages))
        nosents = [len(seen_sent) for seen_sent, sent_count in sent_counter]
        avg_nosents = sum(nosents) / float(len(nosents))

        res = dict()
        res['avg_percentage_seen_sent'] = avg_percentage
        res['avg_nosents'] = avg_nosents
        res['summary_len'] = summary_sent_count
        res['avg_summary_len'] = np.average(summary_len)
        res['summary_sent'] = np.average(summary_sent)
        res['article_sent'] = np.average(article_sent)
        res['avg_copied_seq_len'] = np.average(avg_max_seq_len_list)
        res['avg_sequences_per_sent'] = np.average(copied_sequence_per_sent)
        res['counter_copied_sequence_len'] = copied_sequence_len
        res['counter_summary_sent_id'] = sentence_copy_id_count
        res['counter_article_sent_id'] = article_copy_id_count_tot
        res['novel_ngram_counter'] = novel_counter
        res['repeated_ngram_counter'] = repeated_counter

        fp.write("Summary metrics\n")
        for key in res:
            fp.write('{}: {}\n'.format(key, res[key]))

        fp.write("Structures metrics\n")
        fp.write("Average depth of RST tree: " +
                 str(sum(height_avg) / len(height_avg)) + "\n")
        fp.write("Average proportion of leaf nodes in RST tree: " + str(
            sum(leaf_node_proportion_avg) / len(leaf_node_proportion_avg)) +
                 "\n")
        fp.write("Precision of edges latent to explicit: " +
                 str(np.average(precision_tree_dist)) + "\n")
        fp.write("Recall of edges latent to explicit: " +
                 str(np.average(recall_tree_dist)) + "\n")
        fp.write("Tree height counter:\n")
        fp.write(str(height_counter) + "\n")
        fp.write("Tree leaf proportion counter:")
        fp.write(str(leaf_nodes_counter) + "\n")

        if args.predict_contsel_tags:
            fp.write("Avg token_contsel: " +
                     str(counts['token_consel_num_correct'] /
                         float(counts['token_consel_num'])) + "\n")
        if args.predict_sent_single_head:
            fp.write("Avg single sent heads: " +
                     str(counts['sent_single_heads_num_correct'] /
                         float(counts['sent_single_heads_num'])) + "\n")
        if args.predict_sent_all_head:
            fp.write("Avg all sent heads: " +
                     str(counts['sent_all_heads_num_correct'] /
                         float(counts['sent_all_heads_num'])) + "\n")
        if args.predict_sent_all_child:
            fp.write("Avg all sent child: " +
                     str(counts['sent_all_child_num_correct'] /
                         float(counts['sent_all_child_num'])) + "\n")
        fp.close()
        sent_count_fp.close()

        write_to_json_file(abstract_ref, self._rouge_ref_file)
        write_to_json_file(abstract_pred, self._rouge_pred_file)

    def get_decoded_outputs(self, batch, count):
        #batch should have only one example
        enc_batch, enc_padding_token_mask, enc_padding_sent_mask, enc_doc_lens, enc_sent_lens, \
            enc_batch_extend_vocab, extra_zeros, c_t_0, coverage_t_0, word_batch, word_padding_mask, enc_word_lens, \
                enc_tags_batch, enc_sent_tags, enc_sent_token_mat, adj_mat, weighted_adj_mat, norm_adj_mat, \
                    parent_heads, undir_weighted_adj_mat = get_input_from_batch(batch, use_cuda, self.args)

        enc_adj_mat = adj_mat
        if args.use_weighted_annotations:
            if args.use_undirected_weighted_graphs:
                enc_adj_mat = undir_weighted_adj_mat
            else:
                enc_adj_mat = weighted_adj_mat

        encoder_output = self.model.encoder.forward_test(
            enc_batch, enc_sent_lens, enc_doc_lens, enc_padding_token_mask,
            enc_padding_sent_mask, word_batch, word_padding_mask,
            enc_word_lens, enc_tags_batch, enc_sent_token_mat, enc_adj_mat)
        encoder_outputs, enc_padding_mask, encoder_last_hidden, max_encoder_output, \
        enc_batch_extend_vocab, token_level_sentence_scores, sent_outputs, token_scores, \
        sent_scores, sent_matrix, sent_level_rep = \
                                    self.model.get_app_outputs(encoder_output, enc_padding_token_mask,
                                                   enc_padding_sent_mask, enc_batch_extend_vocab, enc_sent_token_mat)

        # Pairwise sentence mask: entry (i, j) is nonzero only when both
        # sentences are real (not padding); a root column is prepended.
        mask = enc_padding_sent_mask[0].unsqueeze(0).repeat(
            enc_padding_sent_mask.size(1),
            1) * enc_padding_sent_mask[0].unsqueeze(1).transpose(1, 0)
        mask = torch.cat((enc_padding_sent_mask[0].unsqueeze(1), mask), dim=1)
        mat = encoder_output['sent_attention_matrix'][0][:, :] * mask

        structure_info = self.extract_structures(
            batch, encoder_output['token_attention_matrix'], mat, count,
            use_cuda, encoder_output['sent_score'])

        counts = {}
        predictions = {}
        if args.predict_contsel_tags:
            pred = encoder_output['token_score'][0, :, :].view(-1, 2)
            token_contsel_gold = enc_tags_batch[0, :].view(-1)
            token_contsel_prediction = torch.argmax(
                pred.clone().detach().requires_grad_(False), dim=1)
            token_contsel_prediction[
                token_contsel_gold ==
                -1] = -2  # Explicitly set masked tokens as different from value in gold
            token_consel_num_correct = torch.sum(
                token_contsel_prediction.eq(token_contsel_gold)).item()
            token_consel_num = torch.sum(token_contsel_gold != -1).item()
            predictions['token_contsel_prediction'] = token_contsel_prediction
            counts['token_consel_num_correct'] = token_consel_num_correct
            counts['token_consel_num'] = token_consel_num

        if args.predict_sent_single_head:
            pred = encoder_output['sent_single_head_scores'][0, :, :]
            head_labels = parent_heads[0, :].view(-1)
            sent_single_heads_prediction = torch.argmax(
                pred.clone().detach().requires_grad_(False), dim=1)
            sent_single_heads_prediction[
                head_labels ==
                -1] = -2  # Explicitly set masked tokens as different from value in gold
            sent_single_heads_num_correct = torch.sum(
                sent_single_heads_prediction.eq(head_labels)).item()
            sent_single_heads_num = torch.sum(head_labels != -1).item()
            predictions[
                'sent_single_heads_prediction'] = sent_single_heads_prediction
            counts[
                'sent_single_heads_num_correct'] = sent_single_heads_num_correct
            counts['sent_single_heads_num'] = sent_single_heads_num

        if args.predict_sent_all_head:
            pred = encoder_output['sent_all_head_scores'][0, :, :, :]
            target = adj_mat[0, :, :].permute(0, 1).view(-1)  # NOTE: permute(0, 1) is the identity; the transpose permute(1, 0), as used for heads elsewhere, may have been intended
            sent_all_heads_prediction = torch.argmax(
                pred.clone().detach().requires_grad_(False), dim=1)
            sent_all_heads_prediction[
                target ==
                -1] = -2  # Explicitly set masked tokens as different from value in gold
            sent_all_heads_num_correct = torch.sum(
                sent_all_heads_prediction.eq(target)).item()
            sent_all_heads_num = torch.sum(target != -1).item()
            predictions[
                'sent_all_heads_prediction'] = sent_all_heads_prediction
            counts['sent_all_heads_num_correct'] = sent_all_heads_num_correct
            counts['sent_all_heads_num'] = sent_all_heads_num

        if args.predict_sent_all_child:
            pred = encoder_output['sent_all_child_scores'][0, :, :, :]
            target = adj_mat[0, :, :].view(-1)
            sent_all_child_prediction = torch.argmax(
                pred.clone().detach().requires_grad_(False), dim=1)
            sent_all_child_prediction[
                target ==
                -1] = -2  # Explicitly set masked tokens as different from value in gold
            sent_all_child_num_correct = torch.sum(
                sent_all_child_prediction.eq(target)).item()
            sent_all_child_num = torch.sum(target != -1).item()
            predictions[
                'sent_all_child_prediction'] = sent_all_child_prediction
            counts['sent_all_child_num_correct'] = sent_all_child_num_correct
            counts['sent_all_child_num'] = sent_all_child_num

        results = []
        steps = 0
        has_summary = False
        beams_sorted = [None]
        if args.predict_summaries:
            has_summary = True

            if args.fixed_scorer:
                scorer_output = self.model.module.pretrained_scorer.forward_test(
                    enc_batch, enc_sent_lens, enc_doc_lens,
                    enc_padding_token_mask, enc_padding_sent_mask, word_batch,
                    word_padding_mask, enc_word_lens, enc_tags_batch)
                token_scores = scorer_output['token_score']
                sent_scores = scorer_output['sent_score'].unsqueeze(1).repeat(
                    1, enc_padding_token_mask.size(2), 1, 1).view(
                        enc_padding_token_mask.size(0),
                        enc_padding_token_mask.size(1) *
                        enc_padding_token_mask.size(2))

            all_child, all_head = None, None
            if args.use_gold_annotations_for_decode:

                def row_normalize(mat):
                    # Divide each row by its sum, leaving all-zero rows untouched.
                    out = mat.clone()
                    row_sums = torch.sum(mat, dim=2, keepdim=True)
                    nonzero = row_sums.expand_as(mat) != 0
                    out[nonzero] = mat[nonzero] / row_sums.expand_as(mat)[nonzero]
                    return out

                if args.use_weighted_annotations:
                    base = (undir_weighted_adj_mat
                            if args.use_undirected_weighted_graphs else
                            weighted_adj_mat)
                else:
                    base = adj_mat
                # Heads correspond to the transposed adjacency; children to the adjacency itself.
                all_head = row_normalize(base.permute(0, 2, 1))
                all_child = row_normalize(base)

            s_t_0 = self.model.reduce_state(encoder_last_hidden)

            if config.use_maxpool_init_ctx:
                c_t_0 = max_encoder_output

            dec_h, dec_c = s_t_0  # 1 x 2*hidden_size
            dec_h = dec_h.squeeze()
            dec_c = dec_c.squeeze()

            # Decoder batch preparation: initially all beam_size beams are copies of the same start hypothesis.
            beams = [
                Beam(tokens=[self.vocab.word2id(data.START_DECODING)],
                     log_probs=[0.0],
                     state=(dec_h[0], dec_c[0]),
                     context=c_t_0[0],
                     coverage=(coverage_t_0[0] if self.args.is_coverage
                               or self.args.bu_coverage_penalty else None))
                for _ in range(args.beam_size)
            ]

            while steps < args.max_dec_steps and len(results) < args.beam_size:
                latest_tokens = [h.latest_token for h in beams]
                latest_tokens = [t if t < self.vocab.size() else self.vocab.word2id(data.UNKNOWN_TOKEN) \
                                 for t in latest_tokens]
                y_t_1 = Variable(torch.LongTensor(latest_tokens))
                if use_cuda:
                    y_t_1 = y_t_1.cuda()
                all_state_h = []
                all_state_c = []

                all_context = []

                for h in beams:
                    state_h, state_c = h.state
                    all_state_h.append(state_h)
                    all_state_c.append(state_c)

                    all_context.append(h.context)

                s_t_1 = (torch.stack(all_state_h, 0).unsqueeze(0),
                         torch.stack(all_state_c, 0).unsqueeze(0))
                c_t_1 = torch.stack(all_context, 0)

                coverage_t_1 = None
                if self.args.is_coverage or self.args.bu_coverage_penalty:
                    all_coverage = []
                    for h in beams:
                        all_coverage.append(h.coverage)
                    coverage_t_1 = torch.stack(all_coverage, 0)

                final_dist, s_t, c_t, attn_dist, p_gen, coverage_t = self.model.decoder(
                    y_t_1, s_t_1, encoder_outputs, word_padding_mask, c_t_1,
                    extra_zeros, enc_batch_extend_vocab, coverage_t_1,
                    token_scores, sent_scores, sent_outputs,
                    enc_sent_token_mat, all_head, all_child, sent_level_rep)

                if args.bu_coverage_penalty:
                    penalty = torch.max(coverage_t,
                                        coverage_t.clone().fill_(1.0)).sum(-1)
                    penalty -= coverage_t.size(-1)
                    final_dist -= args.beta * penalty.unsqueeze(1).expand_as(
                        final_dist)
                if args.bu_length_penalty:
                    # GNMT-style length penalty: ((5 + length) / 6) ** alpha
                    penalty = ((5 + steps + 1) / 6.0)**args.alpha
                    final_dist = final_dist / penalty

                topk_log_probs, topk_ids = torch.topk(final_dist,
                                                      args.beam_size * 2)

                dec_h, dec_c = s_t
                dec_h = dec_h.squeeze()
                dec_c = dec_c.squeeze()

                all_beams = []
                # On the first step all beams are identical, so only extend one.
                num_orig_beams = 1 if steps == 0 else len(beams)
                for i in range(num_orig_beams):
                    h = beams[i]
                    state_i = (dec_h[i], dec_c[i])
                    context_i = c_t[i]
                    coverage_i = (coverage_t[i] if self.args.is_coverage
                                  or self.args.bu_coverage_penalty else None)

                    for j in range(args.beam_size *
                                   2):  # for each of the top 2*beam_size hyps:
                        new_beam = h.extend(token=topk_ids[i, j].item(),
                                            log_prob=topk_log_probs[i,
                                                                    j].item(),
                                            state=state_i,
                                            context=context_i,
                                            coverage=coverage_i)
                        all_beams.append(new_beam)

                beams = []
                for h in self.sort_beams(all_beams):
                    if h.latest_token == self.vocab.word2id(
                            data.STOP_DECODING):
                        if steps >= config.min_dec_steps:
                            results.append(h)
                    else:
                        beams.append(h)
                    if len(beams) == args.beam_size or len(
                            results) == args.beam_size:
                        break

                steps += 1

            if len(results) == 0:
                results = beams

            beams_sorted = self.sort_beams(results)

        return has_summary, beams_sorted[
            0], predictions, counts, structure_info, undir_weighted_adj_mat
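
A note on the tree helpers: Example #2 calls `chu_liu_edmonds`, `find_height`, and `leaf_node_proportion` without defining them. `chu_liu_edmonds` is the standard maximum-spanning-tree decoder for dependency structures; it returns a `heads` array in which `heads[i]` is the parent of node `i` and index 0 is an artificial root. Under that assumption, the two statistics could look like the sketch below (an illustration, not the repo's implementation):

def find_height(heads):
    # Depth of the deepest node, counting edges up to the artificial root (node 0).
    def depth(i):
        d = 0
        while i != 0:
            i = heads[i]
            d += 1
        return d
    return max(depth(i) for i in range(1, len(heads)))

def leaf_node_proportion(heads):
    # Fraction of real nodes (1..n) that are nobody's parent.
    parents = set(heads[1:])
    n = len(heads) - 1
    leaves = sum(1 for i in range(1, len(heads)) if i not in parents)
    return leaves / float(n)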