Example #1
def test(model, path, vocab):
    model.load_state_dict(torch.load(path))
    model.eval()

    beam_search = LSTMBeamSearch(conf.beam_size, conf.vocab_size, conf.max_decode_len, model)
    batcher = Batcher(conf.decode_data_path, vocab, mode='decode', batch_size=1, single_pass=True)

    counter = 0
    batch = batcher.next_batch()

    while batch is not None:
        input_ids, input_mask, input_lens, extended_input_ids, extra_zeros = prepare_src_batch(batch)
        best_summary = beam_search.generate(input_ids, extended_input_ids, extra_zeros)
        output_ids = [int(t) for t in best_summary.tokens[1:]]
        decoded_words = outputids2words(output_ids, vocab, batch.art_oovs[0])

        try:
            fst_stop_idx = decoded_words.index(STOP_DECODING)
            decoded_words = decoded_words[:fst_stop_idx]
        except ValueError:
            pass  # no [STOP] token produced; keep the full output

        write_for_rouge(batch.original_abstracts_sents[0], decoded_words, counter,
                        conf.rouge_ref_dir, conf.rouge_dec_dir)
        batch = batcher.next_batch()
        counter += 1

    results_dict = rouge_eval(conf.rouge_ref_dir, conf.rouge_dec_dir)
    rouge_log(results_dict, conf.decode_dir)
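The helper `outputids2words` used above is not defined in this example. A minimal sketch, assuming the pointer-generator convention that ids at or beyond `vocab.size()` index into the per-article OOV list:

def outputids2words(output_ids, vocab, article_oovs):
    # Hypothetical sketch: map extended-vocabulary ids back to words.
    words = []
    for i in output_ids:
        if i < vocab.size():
            words.append(vocab.id2word(i))
        else:
            # Ids past the fixed vocabulary point into this article's OOVs.
            oov_idx = i - vocab.size()
            assert article_oovs is not None and oov_idx < len(article_oovs), \
                'id %d is neither in-vocab nor a known article OOV' % i
            words.append(article_oovs[oov_idx])
    return words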
Example #2
class BeamSearchDecoder:
    def __init__(self, model):

        self._decode_dir = os.path.join(config.log_root,
                                        'decode_%s' % ("model_bert_coverage"))
        self._rouge_ref_dir = os.path.join(self._decode_dir, 'rouge_ref')
        self._rouge_dec_dir = os.path.join(self._decode_dir, 'rouge_dec_dir')

        for p in [self._decode_dir, self._rouge_ref_dir, self._rouge_dec_dir]:
            if not os.path.exists(p):
                os.mkdir(p)

        self.vocab = VocabBert(config.vocab_path, config.vocab_size)
        self.batcher = Batcher(config.decode_data_path,
                               self.vocab,
                               mode='decode',
                               batch_size=config.beam_size,
                               single_pass=True)
        self.model = model

    def beam_search(self, batch, conf):
        # The batch holds beam_size copies of a single example (decode mode).
        enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros = helper.prepare_src_batch(
            batch)
        encoder_output, _ = self.model.encoder.forward(
            enc_batch, enc_padding_mask.squeeze(1))

        hyps_list = [
            Hypothesis(tokens=[self.vocab.word2id(data.START_DECODING)],
                       log_probs=[0.0]) for _ in range(config.beam_size)
        ]
        results = []
        steps = 0
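        # yt accumulates decoder input ids; the initial zero column is a
        # placeholder that gets sliced off (yt[:, 1:]) before decoding.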
        yt = torch.zeros(config.beam_size, 1).long().to(device)
        while steps < config.max_dec_steps and len(results) < config.beam_size:
            latest_tokens = [h.latest_token for h in hyps_list]
            latest_tokens = [
                t if t < self.vocab.size() else self.vocab.word2id(
                    data.UNKNOWN_TOKEN) for t in latest_tokens
            ]

            curr_yt = torch.LongTensor(latest_tokens).unsqueeze(1).to(
                device)  # [Bx1]
            yt = torch.cat((yt, curr_yt), dim=1)

            out, _ = self.model.decode(
                encoder_output, yt[:, 1:], enc_padding_mask,
                helper.subsequent_mask(yt[:, 1:].size(-1)))
            extra_zeros_ip = None
            if extra_zeros is not None:
                extra_zeros_ip = extra_zeros[:, 0:steps + 1, :]

            if config.coverage:
                op_dist, _ = self.model.generator(out, encoder_output,
                                                  enc_padding_mask,
                                                  enc_batch_extend_vocab,
                                                  extra_zeros_ip)
            else:
                op_dist = self.model.generator(out, encoder_output,
                                               enc_padding_mask,
                                               enc_batch_extend_vocab,
                                               extra_zeros_ip)

            log_probs = op_dist[:, -1, :]
            topk_log_probs, topk_ids = torch.topk(log_probs,
                                                  config.beam_size * 2)

            all_hyps = []
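            # On step 0 all beams hold the same [START] hypothesis, so only
            # the first is expanded to avoid beam_size duplicates.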
            num_orig_hyps = 1 if steps == 0 else len(hyps_list)

            for i in range(num_orig_hyps):
                h = hyps_list[i]

                for j in range(config.beam_size *
                               2):  # for each of the top 2*beam_size candidates
                    hyp = h.extend(token=topk_ids[i, j].item(),
                                   log_prob=topk_log_probs[i, j].item())
                    all_hyps.append(hyp)

            hyps_list = []
            sorted_hyps = sorted(all_hyps,
                                 key=lambda h: h.avg_log_prob,
                                 reverse=True)
            for h in sorted_hyps:
                if h.latest_token == self.vocab.word2id(data.STOP_DECODING):
                    if steps >= config.min_dec_steps:
                        results.append(h)
                else:
                    hyps_list.append(h)
                if len(hyps_list) == config.beam_size or len(
                        results) == config.beam_size:
                    break

            steps += 1

        if len(results) == 0:
            results = hyps_list

        results_sorted = sorted(results,
                                key=lambda h: h.avg_log_prob,
                                reverse=True)
        return results_sorted[0]

    def decode(self, conf):

        self.model.eval()
        start = time.time()
        counter = 0
        batch = self.batcher.next_batch()
        i = 0
        while batch is not None:

            i += 1
            # Run beam search to get best Hypothesis
            best_summary = self.beam_search(batch, conf)

            # Extract the output ids from the hypothesis and convert back to words
            output_ids = [int(t) for t in best_summary.tokens[1:]]
            # print(output_ids)
            decoded_words = data.outputids2words(
                output_ids, self.vocab,
                (batch.art_oovs[0] if config.pointer_gen else None))

            # Remove the [STOP] token from decoded_words, if necessary
            try:
                fst_stop_idx = decoded_words.index(data.STOP_DECODING)
                decoded_words = decoded_words[:fst_stop_idx]
            except ValueError:
                pass  # no [STOP] token produced; keep the full output
            if i % 100 == 0:
                print("Batch: {}".format(i))
                print(decoded_words)
            original_abstract_sents = batch.original_abstracts_sents[0]

            write_for_rouge(original_abstract_sents, decoded_words, counter,
                            self._rouge_ref_dir, self._rouge_dec_dir)
            counter += 1
            if counter % 1000 == 0:
                print('%d examples in %.2f sec' % (counter, time.time() - start))
                start = time.time()

            batch = self.batcher.next_batch()

        print("Decoder has finished reading dataset for single_pass.")
        print("Now starting ROUGE eval...")
        results_dict = rouge_eval(self._rouge_ref_dir, self._rouge_dec_dir)
        rouge_log(results_dict, self._decode_dir)
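The `Hypothesis` objects that `beam_search` manipulates are assumed to provide `extend`, `latest_token`, and `avg_log_prob`. A minimal sketch consistent with that usage (the real class may also carry decoder state such as attention context or coverage):

class Hypothesis:
    def __init__(self, tokens, log_probs):
        self.tokens = tokens        # token ids, starting with [START]
        self.log_probs = log_probs  # per-token log-probabilities

    def extend(self, token, log_prob):
        # Return a new hypothesis with one more decoded token.
        return Hypothesis(self.tokens + [token], self.log_probs + [log_prob])

    @property
    def latest_token(self):
        return self.tokens[-1]

    @property
    def avg_log_prob(self):
        # Length-normalized score used to rank beams.
        return sum(self.log_probs) / len(self.tokens)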
Example #3
class Train(object):
    def __init__(self, args, model_name=None):
        self.args = args
        vocab = args.vocab_path if args.vocab_path is not None else config.vocab_path
        self.vocab = Vocab(vocab, config.vocab_size, config.embeddings_file,
                           args)
        self.train_batcher = Batcher(args.train_data_path,
                                     self.vocab,
                                     mode='train',
                                     batch_size=args.batch_size,
                                     single_pass=False,
                                     args=args)
        self.eval_batcher = Batcher(args.eval_data_path,
                                    self.vocab,
                                    mode='eval',
                                    batch_size=args.batch_size,
                                    single_pass=True,
                                    args=args)
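        # Give the batcher's background threads time to fill their queues.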
        time.sleep(30)

        if model_name is None:
            self.train_dir = os.path.join(config.log_root,
                                          'train_%d' % (int(time.time())))
        else:
            self.train_dir = os.path.join(config.log_root, model_name)

        if not os.path.exists(self.train_dir):
            os.mkdir(self.train_dir)
        self.model_dir = os.path.join(self.train_dir, 'model')
        if not os.path.exists(self.model_dir):
            os.mkdir(self.model_dir)

        #self.summary_writer = SummaryWriter(train_dir)

    def save_model(self, running_avg_loss, iter, logger, best_val_loss):
        state = {
            'iter': iter,
            'best_val_loss': best_val_loss,
            'encoder_state_dict': self.model.module.encoder.state_dict(),
            'decoder_state_dict': self.model.module.decoder.state_dict(),
            'reduce_state_dict': self.model.module.reduce_state.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'current_loss': running_avg_loss
        }
        model_save_path = os.path.join(
            self.model_dir, 'model_%d_%d' % (iter, int(time.time())))
        print(model_save_path)
        logger.debug(model_save_path)
        torch.save(state, model_save_path)
        if self.args.clear_old_checkpoints:
            self.clear_model_dir(checkpoints=self.args.keep_ckpts,
                                 logger=logger)

    def clear_model_dir(self, checkpoints, logger):
        """
        Clears the model directory and only maintains the latest `checkpoints` number of checkpoints.
        """
        files = os.listdir(self.model_dir)
        last_modification = [(os.path.getmtime(os.path.join(self.model_dir,
                                                            f)), f)
                             for f in files]

        # Sort the list by last modified.
        last_modification.sort(key=itemgetter(0))

        # Delete everything but the most recent `checkpoints` files.
        ckpnt_no = 0
        for _, f in last_modification[:-checkpoints]:
            ckpnt_no += 1
            os.remove(os.path.join(self.model_dir, f))
        msg = "Deleted %d checkpoints" % ckpnt_no
        logger.debug(msg)
        print(msg)

    def setup_train(self, args):
        self.model = nn.DataParallel(Model(args, self.vocab)).to(device)

        params = list(self.model.module.encoder.parameters()) + list(self.model.module.decoder.parameters()) + \
                 list(self.model.module.reduce_state.parameters())

        initial_lr = args.lr_coverage if args.is_coverage else args.lr
        self.optimizer = AdagradCustom(
            params,
            lr=initial_lr,
            initial_accumulator_value=config.adagrad_init_acc)

        self.crossentropy = nn.CrossEntropyLoss(ignore_index=-1)
        self.head_child_crossent = nn.CrossEntropyLoss(ignore_index=-1,
                                                       weight=torch.Tensor(
                                                           [0.1, 1]).cuda())
        self.attn_mse_loss = nn.MSELoss()

        start_iter, start_loss = 0, 0
        best_val_loss = None
        if args.reload_path is not None:
            print('Loading from checkpoint: ' + str(args.reload_path))
            state = torch.load(args.reload_path,
                               map_location=lambda storage, location: storage)
            start_iter = state['iter']
            start_loss = state['current_loss']
            #if 'best_val_loss' in state:
            #    best_val_loss = state['best_val_loss']

            if not args.is_coverage:
                self.optimizer.load_state_dict(state['optimizer'])
                if use_cuda:
                    # Move optimizer state tensors to the GPU; use a distinct
                    # name so the checkpoint dict `state` is not shadowed.
                    for opt_state in self.optimizer.state.values():
                        for k, v in opt_state.items():
                            if torch.is_tensor(v):
                                opt_state[k] = v.to(device)

        return start_iter, start_loss, best_val_loss

    def setup_logging(self):
        logger = logging.getLogger()
        logger.setLevel(logging.DEBUG)
        filename = os.path.join(self.train_dir, 'train.log')
        ah = logging.FileHandler(filename)
        ah.setLevel(logging.DEBUG)
        formatter = logging.Formatter('%(asctime)s - %(message)s')
        ah.setFormatter(formatter)
        logger.addHandler(ah)
        return logger

    def train_one_batch(self, batch, args):

        self.optimizer.zero_grad()
        self.model.module.encoder.document_structure_att.output = None
        loss, _, _, _ = self.get_loss(batch, args)
        if loss is None:
            return None
        s1 = time.time()
        loss.backward()
        #print("time for backward: "+str(time.time() - s1))

        clip_grad_norm(self.model.module.encoder.parameters(),
                       config.max_grad_norm)
        clip_grad_norm(self.model.module.decoder.parameters(),
                       config.max_grad_norm)
        clip_grad_norm(self.model.module.reduce_state.parameters(),
                       config.max_grad_norm)

        self.optimizer.step()
        return loss.item()

    def train_iters(self, n_iters, args):
        start_iter, running_avg_loss, best_val_loss = self.setup_train(args)
        logger = self.setup_logging()
        logger.debug(str(args))
        logger.debug(str(config))

        start = time.time()
        # best_val_loss = None

        for it in tqdm(range(n_iters), dynamic_ncols=True):
            iter = start_iter + it
            self.model.module.train()
            batch = self.train_batcher.next_batch()
            start1 = time.time()
            loss = self.train_one_batch(batch, args)
            #print("time for 1 batch+get: "+str(time.time() - start))
            #print("time for 1 batch: "+str(time.time() - start1))
            #start=time.time()
            #print(loss)
            # for n,p in self.model.module.encoder.named_parameters():
            #     print('===========\ngradient:{}\n----------\n{}'.format(n,p.grad))
            # exit()
            if loss is not None:
                if math.isnan(loss):
                    msg = "Loss has reached NaN. Exiting"
                    logger.debug(msg)
                    print(msg)
                    exit()
                running_avg_loss = calc_running_avg_loss(
                    loss, running_avg_loss, iter)
                iter += 1

            print_interval = 200
            if iter % print_interval == 0:
                msg = 'steps %d, seconds for %d batch: %.2f , loss: %f' % (
                    iter, print_interval, time.time() - start, loss)
                print(msg)
                logger.debug(msg)
                start = time.time()
            if iter % config.eval_interval == 0:
                print("Starting Eval")
                loss = self.run_eval(logger, args)
                if best_val_loss is None or loss < best_val_loss:
                    best_val_loss = loss
                    self.save_model(running_avg_loss, iter, logger,
                                    best_val_loss)
                    print("Saving best model")
                    logger.debug("Saving best model")
                    # print("Deleting older checkpoints")
                    # ckpt_no = 0
                    # for f in sorted(os.listdir(self.model_dir))[:-10]:
                    #     ckpt_no +=1
                    #     os.remove(f)
                    # print("Deleted %d checkpoints" % (ckpt_no))

    def get_loss(self, batch, args, mode='train'):

        s2 = time.time()
        dec_batch, dec_padding_mask, max_dec_len, dec_lens_var, target_batch = \
            get_output_from_batch(batch, use_cuda)

        enc_batch, enc_padding_token_mask, enc_padding_sent_mask, enc_doc_lens, enc_sent_lens, \
            enc_batch_extend_vocab, extra_zeros, c_t_1, coverage, word_batch, word_padding_mask, enc_word_lens, \
                enc_tags_batch, enc_sent_tags, enc_sent_token_mat, adj_mat, weighted_adj_mat, norm_adj_mat,\
                    parent_heads, undir_weighted_adj_mat = get_input_from_batch(batch, use_cuda, args)
        #print("time for input func: "+str(time.time() - s2))

        final_dist_list, attn_dist_list, p_gen_list, coverage_list, sent_attention_matrix, \
        sent_single_head_scores, sent_all_head_scores, sent_all_child_scores, \
        token_score, sent_score, doc_score = self.model.forward(
            enc_batch, enc_padding_token_mask, enc_padding_sent_mask,
            enc_doc_lens, enc_sent_lens, enc_batch_extend_vocab, extra_zeros,
            c_t_1, coverage, word_batch, word_padding_mask, enc_word_lens,
            enc_tags_batch, enc_sent_token_mat, max_dec_len, dec_batch,
            adj_mat, weighted_adj_mat, undir_weighted_adj_mat, args)

        step_losses = []
        loss = 0
        ind_losses = {
            'summ_loss': 0,
            'sent_single_head_loss': 0,
            'sent_all_head_loss': 0,
            'sent_all_child_loss': 0,
            'token_contsel_loss': 0,
            'sent_imp_loss': 0,
            'doc_imp_loss': 0
        }
        counts = {
            'token_consel_num_correct': 0,
            'token_consel_num': 0,
            'sent_imp_num_correct': 0,
            'doc_imp_num_correct': 0,
            'sent_single_heads_num_correct': 0,
            'sent_single_heads_num': 0,
            'sent_all_heads_num_correct': 0,
            'sent_all_heads_num': 0,
            'sent_all_child_num_correct': 0,
            'sent_all_child_num': 0
        }
        eval_data = {}

        s1 = time.time()
        if args.use_summ_loss:
            for di in range(min(max_dec_len, args.max_dec_steps)):
                final_dist = final_dist_list[:, di, :]
                attn_dist = attn_dist_list[:, di, :]
                if args.is_coverage:
                    coverage = coverage_list[:, di, :]

                target = target_batch[:, di]
                gold_probs = torch.gather(final_dist, 1,
                                          target.unsqueeze(1)).squeeze()
                step_loss = -torch.log(gold_probs + config.eps)
                if args.is_coverage:
                    step_coverage_loss = torch.sum(
                        torch.min(attn_dist, coverage), 1)
                    step_loss = step_loss + config.cov_loss_wt * step_coverage_loss
                step_mask = dec_padding_mask[:, di]
                step_loss = step_loss * step_mask
                step_losses.append(step_loss)
            sum_losses = torch.sum(torch.stack(step_losses, 1), 1)
            batch_avg_loss = sum_losses / dec_lens_var
            loss += torch.mean(batch_avg_loss)
            ind_losses['summ_loss'] += torch.mean(batch_avg_loss).item()

        if args.heuristic_chains:
            if args.use_attmat_loss:
                pred = sent_attention_matrix[:, :, 1:].contiguous().view(-1)
                gold = norm_adj_mat.view(-1)
                loss_aux = self.attn_mse_loss(pred, gold)
                loss += 100 * loss_aux
            if args.use_sent_single_head_loss:
                pred = sent_single_head_scores
                pred = pred.view(-1, pred.size(2))
                head_labels = parent_heads.view(-1)
                loss_aux = self.crossentropy(pred, head_labels.long())
                loss += loss_aux
                prediction = torch.argmax(
                    pred.clone().detach().requires_grad_(False), dim=1)
                if mode == 'eval':
                    prediction[
                        head_labels ==
                        -1] = -2  # Explicitly set masked tokens as different from value in gold
                    counts['sent_single_heads_num_correct'] = torch.sum(
                        prediction.eq(head_labels.long())).item()
                    counts['sent_single_heads_num'] = torch.sum(
                        head_labels != -1).item()
                ind_losses['sent_single_head_loss'] += loss_aux.item()
            if args.use_sent_all_head_loss:
                pred = sent_all_head_scores
                pred = pred.view(-1, pred.size(3))
                target_h = adj_mat.permute(0, 2, 1).contiguous().view(-1)
                #print(pred.size(), target.size())
                loss_aux = self.head_child_crossent(pred, target_h.long())
                loss += loss_aux
                prediction = torch.argmax(
                    pred.clone().detach().requires_grad_(False), dim=1)
                if mode == 'eval':
                    prediction[
                        target_h ==
                        -1] = -2  # Explicitly set masked tokens as different from value in gold
                    counts['sent_all_heads_num_correct'] = torch.sum(
                        prediction.eq(target_h.long())).item()
                    counts['sent_all_heads_num_correct_1'] = torch.sum(
                        prediction[target_h == 1].eq(
                            target_h[target_h == 1].long())).item()
                    counts['sent_all_heads_num_correct_0'] = torch.sum(
                        prediction[target_h == 0].eq(
                            target_h[target_h == 0].long())).item()
                    counts['sent_all_heads_num_1'] = torch.sum(
                        target_h == 1).item()
                    counts['sent_all_heads_num_0'] = torch.sum(
                        target_h == 0).item()
                    counts['sent_all_heads_num'] = torch.sum(
                        target_h != -1).item()
                    eval_data['sent_all_heads_pred'] = prediction.cpu().numpy()
                    eval_data['sent_all_heads_true'] = target_h.cpu().numpy()
                ind_losses['sent_all_head_loss'] += loss_aux.item()
                #print('all head '+str(loss_aux.item()))
            if args.use_sent_all_child_loss:
                pred = sent_all_child_scores
                pred = pred.view(-1, pred.size(3))
                target = adj_mat.contiguous().view(-1)
                loss_aux = self.head_child_crossent(pred, target.long())
                loss += loss_aux
                prediction = torch.argmax(
                    pred.clone().detach().requires_grad_(False), dim=1)
                if mode == 'eval':
                    prediction[
                        target ==
                        -1] = -2  # Explicitly set masked tokens as different from value in gold
                    counts['sent_all_child_num_correct'] = torch.sum(
                        prediction.eq(target.long())).item()
                    counts['sent_all_child_num_correct_1'] = torch.sum(
                        prediction[target == 1].eq(
                            target[target == 1].long())).item()
                    counts['sent_all_child_num_correct_0'] = torch.sum(
                        prediction[target == 0].eq(
                            target[target == 0].long())).item()
                    counts['sent_all_child_num_1'] = torch.sum(
                        target == 1).item()
                    counts['sent_all_child_num_0'] = torch.sum(
                        target == 0).item()
                    counts['sent_all_child_num'] = torch.sum(
                        target != -1).item()
                    eval_data['sent_all_child_pred'] = prediction.cpu().numpy()
                    eval_data['sent_all_child_true'] = target.cpu().numpy()
                ind_losses['sent_all_child_loss'] += loss_aux.item()
                #print('all child '+str(loss_aux.item()))
            # print(target_h.long().eq(target.long()))
            # print(adj_mat)
            #else:
            #   pass

        if args.use_token_contsel_loss:
            pred = token_score.view(-1, 2)
            gold = enc_tags_batch.view(-1)
            loss1 = self.crossentropy(pred, gold.long())
            loss += loss1
            if mode == 'eval':
                prediction = torch.argmax(
                    pred.clone().detach().requires_grad_(False), dim=1)
                prediction[
                    gold ==
                    -1] = -2  # Explicitly set masked tokens as different from value in gold
                counts['token_consel_num_correct'] = torch.sum(
                    prediction.eq(gold)).item()
                counts['token_consel_num'] = torch.sum(gold != -1).item()
            ind_losses['token_contsel_loss'] += loss1.item()
        if args.use_sent_imp_loss:
            pred = sent_score.view(-1)
            enc_sent_tags[enc_sent_tags == -1] = 0
            gold = enc_sent_tags.sum(dim=-1).float()
            gold = gold / gold.sum(dim=1, keepdim=True).repeat(1, gold.size(1))
            gold = gold.view(-1)
            loss2 = self.attn_mse_loss(pred, gold)
            ind_losses['sent_imp_loss'] += loss2.item()
            loss += loss2
        if args.use_doc_imp_loss:
            pred = doc_score.view(-1)
            count_tags = enc_tags_batch.clone().detach()
            count_tags[count_tags == 0] = 1
            count_tags[count_tags == -1] = 0
            token_count = count_tags.sum(dim=-1).sum(dim=-1)
            enc_tags_batch[enc_tags_batch == -1] = 0
            gold = enc_tags_batch.sum(dim=-1)
            gold = gold.sum(dim=-1)
            gold = gold / token_count
            loss3 = self.attn_mse_loss(pred, gold)
            loss += loss3
            ind_losses['doc_imp_loss'] += loss3.item()
        #print("time for loss compute: "+str(time.time() - s1))
        #print("time for 1 batch func: "+str(time.time() - s2))
        return loss, ind_losses, counts, eval_data

    def run_eval(self, logger, args):
        running_avg_loss, iter = 0, 0
        run_avg_losses = {
            'summ_loss': 0,
            'sent_single_head_loss': 0,
            'sent_all_head_loss': 0,
            'sent_all_child_loss': 0,
            'token_contsel_loss': 0,
            'sent_imp_loss': 0,
            'doc_imp_loss': 0
        }
        counts = {
            'token_consel_num_correct': 0,
            'token_consel_num': 0,
            'sent_single_heads_num_correct': 0,
            'sent_single_heads_num': 0,
            'sent_all_heads_num_correct': 0,
            'sent_all_heads_num': 0,
            'sent_all_heads_num_correct_1': 0,
            'sent_all_heads_num_1': 0,
            'sent_all_heads_num_correct_0': 0,
            'sent_all_heads_num_0': 0,
            'sent_all_child_num_correct': 0,
            'sent_all_child_num': 0,
            'sent_all_child_num_correct_1': 0,
            'sent_all_child_num_1': 0,
            'sent_all_child_num_correct_0': 0,
            'sent_all_child_num_0': 0
        }
        eval_res = {
            'sent_all_heads_pred': [],
            'sent_all_heads_true': [],
            'sent_all_child_pred': [],
            'sent_all_child_true': [],
        }
        self.model.module.eval()
        self.eval_batcher._finished_reading = False
        self.eval_batcher.setup_queues()
        batch = self.eval_batcher.next_batch()
        while batch is not None:
            loss, sample_ind_losses, sample_counts, eval_data = self.get_loss(
                batch, args, mode='eval')
            if loss is not None:
                loss = loss.item()
                running_avg_loss = calc_running_avg_loss(
                    loss, running_avg_loss, iter)

                if args.use_summ_loss:
                    run_avg_losses['summ_loss'] = calc_running_avg_loss(
                        sample_ind_losses['summ_loss'],
                        run_avg_losses['summ_loss'], iter)
                if args.use_sent_single_head_loss:
                    run_avg_losses[
                        'sent_single_head_loss'] = calc_running_avg_loss(
                            sample_ind_losses['sent_single_head_loss'],
                            run_avg_losses['sent_single_head_loss'], iter)
                    counts['sent_single_heads_num_correct'] += sample_counts[
                        'sent_single_heads_num_correct']
                    counts['sent_single_heads_num'] += sample_counts[
                        'sent_single_heads_num']
                if args.use_sent_all_head_loss:
                    run_avg_losses[
                        'sent_all_head_loss'] = calc_running_avg_loss(
                            sample_ind_losses['sent_all_head_loss'],
                            run_avg_losses['sent_all_head_loss'], iter)
                    counts['sent_all_heads_num_correct'] += sample_counts[
                        'sent_all_heads_num_correct']
                    counts['sent_all_heads_num'] += sample_counts[
                        'sent_all_heads_num']
                    counts['sent_all_heads_num_correct_1'] += sample_counts[
                        'sent_all_heads_num_correct_1']
                    counts['sent_all_heads_num_1'] += sample_counts[
                        'sent_all_heads_num_1']
                    counts['sent_all_heads_num_correct_0'] += sample_counts[
                        'sent_all_heads_num_correct_0']
                    counts['sent_all_heads_num_0'] += sample_counts[
                        'sent_all_heads_num_0']
                    eval_res['sent_all_heads_pred'].append(
                        eval_data['sent_all_heads_pred'])
                    eval_res['sent_all_heads_true'].append(
                        eval_data['sent_all_heads_true'])
                if args.use_sent_all_child_loss:
                    run_avg_losses[
                        'sent_all_child_loss'] = calc_running_avg_loss(
                            sample_ind_losses['sent_all_child_loss'],
                            run_avg_losses['sent_all_child_loss'], iter)
                    counts['sent_all_child_num_correct'] += sample_counts[
                        'sent_all_child_num_correct']
                    counts['sent_all_child_num'] += sample_counts[
                        'sent_all_child_num']
                    counts['sent_all_child_num_correct_1'] += sample_counts[
                        'sent_all_child_num_correct_1']
                    counts['sent_all_child_num_1'] += sample_counts[
                        'sent_all_child_num_1']
                    counts['sent_all_child_num_correct_0'] += sample_counts[
                        'sent_all_child_num_correct_0']
                    counts['sent_all_child_num_0'] += sample_counts[
                        'sent_all_child_num_0']
                    eval_res['sent_all_child_pred'].append(
                        eval_data['sent_all_child_pred'])
                    eval_res['sent_all_child_true'].append(
                        eval_data['sent_all_child_true'])
                if args.use_token_contsel_loss:
                    run_avg_losses[
                        'token_contsel_loss'] = calc_running_avg_loss(
                            sample_ind_losses['token_contsel_loss'],
                            run_avg_losses['token_contsel_loss'], iter)
                    counts['token_consel_num_correct'] += sample_counts[
                        'token_consel_num_correct']
                    counts['token_consel_num'] += sample_counts[
                        'token_consel_num']
                if args.use_sent_imp_loss:
                    run_avg_losses['sent_imp_loss'] = calc_running_avg_loss(
                        sample_ind_losses['sent_imp_loss'],
                        run_avg_losses['sent_imp_loss'], iter)
                if args.use_doc_imp_loss:
                    run_avg_losses['doc_imp_loss'] = calc_running_avg_loss(
                        sample_ind_losses['doc_imp_loss'],
                        run_avg_losses['doc_imp_loss'], iter)
                iter += 1
            batch = self.eval_batcher.next_batch()

        msg = 'Eval: loss: %f' % running_avg_loss
        print(msg)
        logger.debug(msg)

        if args.use_summ_loss:
            msg = 'Summ Eval: loss: %f' % run_avg_losses['summ_loss']
            print(msg)
            logger.debug(msg)
        if args.use_sent_single_head_loss:
            msg = 'Single Sent Head Eval: loss: %f' % run_avg_losses[
                'sent_single_head_loss']
            print(msg)
            logger.debug(msg)
            msg = 'Average Sent Single Head Accuracy: %f' % (
                counts['sent_single_heads_num_correct'] /
                float(counts['sent_single_heads_num']))
            print(msg)
            logger.debug(msg)
        if args.use_sent_all_head_loss:
            msg = 'All Sent Head Eval: loss: %f' % run_avg_losses[
                'sent_all_head_loss']
            print(msg)
            logger.debug(msg)
            msg = 'Average Sent All Head Accuracy: %f' % (
                counts['sent_all_heads_num_correct'] /
                float(counts['sent_all_heads_num']))
            print(msg)
            logger.debug(msg)
            # msg = 'Average Sent All Head Class1 Accuracy: %f' % (counts['sent_all_heads_num_correct_1']/float(counts['sent_all_heads_num_1']))
            # print(msg)
            # logger.debug(msg)
            # msg = 'Average Sent All Head Class0 Accuracy: %f' % (counts['sent_all_heads_num_correct_0']/float(counts['sent_all_heads_num_0']))
            # print(msg)
            # logger.debug(msg)
            y_pred = np.concatenate(eval_res['sent_all_heads_pred'])
            y_true = np.concatenate(eval_res['sent_all_heads_true'])
            msg = classification_report(y_true, y_pred, labels=[0, 1])
            print(msg)
            logger.debug(msg)

        if args.use_sent_all_child_loss:
            msg = 'All Sent Child Eval: loss: %f' % run_avg_losses[
                'sent_all_child_loss']
            print(msg)
            logger.debug(msg)
            msg = 'Average Sent All Child Accuracy: %f' % (
                counts['sent_all_child_num_correct'] /
                float(counts['sent_all_child_num']))
            print(msg)
            logger.debug(msg)
            # msg = 'Average Sent All Child Class1 Accuracy: %f' % (counts['sent_all_child_num_correct_1']/float(counts['sent_all_child_num_1']))
            # print(msg)
            # logger.debug(msg)
            # msg = 'Average Sent All Child Class0 Accuracy: %f' % (counts['sent_all_child_num_correct_0']/float(counts['sent_all_child_num_0']))
            # print(msg)
            # logger.debug(msg)
            y_pred = np.concatenate(eval_res['sent_all_child_pred'])
            y_true = np.concatenate(eval_res['sent_all_child_true'])
            msg = classification_report(y_true, y_pred, labels=[0, 1])
            print(msg)
            logger.debug(msg)
        if args.use_token_contsel_loss:
            msg = 'Token Contsel Eval: loss: %f' % run_avg_losses[
                'token_contsel_loss']
            print(msg)
            logger.debug(msg)
            msg = 'Average token content sel Accuracy: %f' % (
                counts['token_consel_num_correct'] /
                float(counts['token_consel_num']))
            print(msg)
            logger.debug(msg)
        if args.use_sent_imp_loss:
            msg = 'Sent Imp Eval: loss: %f' % run_avg_losses['sent_imp_loss']
            print(msg)
            logger.debug(msg)
        if args.use_doc_imp_loss:
            msg = 'Doc Imp Eval: loss: %f' % run_avg_losses['doc_imp_loss']
            print(msg)
            logger.debug(msg)

        return running_avg_loss
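`calc_running_avg_loss`, referenced throughout, is not shown. A plausible sketch of the exponentially decayed average used in pointer-generator training code (Example #4 calls a variant that also takes a summary writer for TensorBoard logging):

def calc_running_avg_loss(loss, running_avg_loss, step, decay=0.99):
    # Seed the average on the first step, then decay it exponentially.
    if running_avg_loss == 0:
        running_avg_loss = loss
    else:
        running_avg_loss = running_avg_loss * decay + (1 - decay) * loss
    # Clip so early spikes do not distort the logs.
    return min(running_avg_loss, 12)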
Example #4
class Evaluate(object):
    def __init__(self, model_file_path):
        self.vocab = Vocab(config.vocab_path, config.vocab_size)
        self.batcher = Batcher(config.eval_data_path,
                               self.vocab,
                               mode='eval',
                               batch_size=config.batch_size,
                               single_pass=True)
        time.sleep(15)
        model_name = os.path.basename(model_file_path)

        eval_dir = os.path.join(config.log_root, 'eval_%s' % (model_name))
        if not os.path.exists(eval_dir):
            os.mkdir(eval_dir)
        self.summary_writer = SummaryWriter(eval_dir)
        self.model = Model(model_file_path, is_eval=True)

    def eval_one_batch(self, batch):
        enc_batch, enc_padding_token_mask, enc_padding_sent_mask,  enc_doc_lens, enc_sent_lens, enc_batch_extend_vocab, extra_zeros, c_t_1, coverage = \
            get_input_from_batch(batch, use_cuda)

        dec_batch, dec_padding_mask, max_dec_len, dec_lens_var, target_batch = \
            get_output_from_batch(batch, use_cuda)

        encoder_outputs, encoder_hidden, max_encoder_output = self.model.encoder(
            enc_batch, enc_sent_lens, enc_doc_lens, enc_padding_token_mask,
            enc_padding_sent_mask)
        s_t_1 = self.model.reduce_state(encoder_hidden)
        if config.use_maxpool_init_ctx:
            c_t_1 = max_encoder_output

        step_losses = []
        for di in range(min(max_dec_len, config.max_dec_steps)):
            y_t_1 = dec_batch[:, di]  # Teacher forcing
            final_dist, s_t_1, c_t_1, attn_dist, p_gen, coverage = self.model.decoder(
                y_t_1, s_t_1, encoder_outputs, enc_padding_sent_mask, c_t_1,
                extra_zeros, enc_batch_extend_vocab, coverage)
            target = target_batch[:, di]
            gold_probs = torch.gather(final_dist, 1,
                                      target.unsqueeze(1)).squeeze()
            step_loss = -torch.log(gold_probs + config.eps)
            if config.is_coverage:
                step_coverage_loss = torch.sum(torch.min(attn_dist, coverage),
                                               1)
                step_loss = step_loss + config.cov_loss_wt * step_coverage_loss
            step_mask = dec_padding_mask[:, di]
            step_loss = step_loss * step_mask
            step_losses.append(step_loss)

        sum_losses = torch.sum(torch.stack(step_losses, 1), 1)
        batch_avg_loss = sum_losses / dec_lens_var
        loss = torch.mean(batch_avg_loss)

        del enc_batch, enc_padding_token_mask, enc_padding_sent_mask, enc_doc_lens, enc_sent_lens, enc_batch_extend_vocab, extra_zeros, c_t_1, coverage
        gc.collect()
        torch.cuda.empty_cache()

        return loss.item()

    def run_eval(self):
        running_avg_loss, iter = 0, 0
        start = time.time()
        batch = self.batcher.next_batch()
        while batch is not None:
            loss = self.eval_one_batch(batch)

            running_avg_loss = calc_running_avg_loss(loss, running_avg_loss,
                                                     self.summary_writer, iter)
            iter += 1

            # if iter % 100 == 0:
            #     self.summary_writer.flush()
            print_interval = 1000
            if iter % print_interval == 0:
                print('steps %d, seconds for %d batch: %.2f , loss: %f' %
                      (iter, print_interval, time.time() - start,
                       running_avg_loss))
                start = time.time()
            batch = self.batcher.next_batch()
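The per-step loss in `eval_one_batch` is the negative log-likelihood of the gold token, plus the coverage penalty when enabled, masked by decoder padding. A self-contained illustration on dummy tensors (shapes and the `eps`/`cov_loss_wt` constants are illustrative, not the project's config values):

import torch

batch_size, vocab_size, enc_len = 4, 50, 7
eps, cov_loss_wt = 1e-12, 1.0

final_dist = torch.softmax(torch.randn(batch_size, vocab_size), dim=1)
target = torch.randint(0, vocab_size, (batch_size,))
attn_dist = torch.softmax(torch.randn(batch_size, enc_len), dim=1)
coverage = torch.rand(batch_size, enc_len)
step_mask = torch.tensor([1., 1., 1., 0.])  # last example is padding here

gold_probs = final_dist.gather(1, target.unsqueeze(1)).squeeze(1)
step_loss = -torch.log(gold_probs + eps)                          # NLL term
step_loss += cov_loss_wt * torch.min(attn_dist, coverage).sum(1)  # coverage term
step_loss = step_loss * step_mask                                 # zero out padding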
Example #5
class BeamSearch(object):
    def __init__(self, args, model_file_path, save_path):
        model_name = os.path.basename(model_file_path)
        self.args = args
        self._decode_dir = os.path.join(config.log_root, save_path,
                                        'decode_%s' % (model_name))
        self._structures_dir = os.path.join(self._decode_dir, 'structures')
        self._sent_single_heads_dir = os.path.join(self._decode_dir,
                                                   'sent_heads_preds')
        self._sent_single_heads_ref_dir = os.path.join(self._decode_dir,
                                                       'sent_heads_ref')
        self._contsel_dir = os.path.join(self._decode_dir, 'content_sel_preds')
        self._contsel_ref_dir = os.path.join(self._decode_dir,
                                             'content_sel_ref')
        self._rouge_ref_dir = os.path.join(self._decode_dir, 'rouge_ref')
        self._rouge_dec_dir = os.path.join(self._decode_dir, 'rouge_dec_dir')

        self._rouge_ref_file = os.path.join(self._decode_dir, 'rouge_ref.json')
        self._rouge_pred_file = os.path.join(self._decode_dir,
                                             'rouge_pred.json')
        self.stat_res_file = os.path.join(self._decode_dir, 'stats.txt')
        self.sent_count_file = os.path.join(self._decode_dir,
                                            'sent_used_counts.txt')
        for p in [
                self._decode_dir, self._structures_dir,
                self._sent_single_heads_ref_dir, self._sent_single_heads_dir,
                self._contsel_ref_dir, self._contsel_dir, self._rouge_ref_dir,
                self._rouge_dec_dir
        ]:
            if not os.path.exists(p):
                os.mkdir(p)
        vocab = args.vocab_path if args.vocab_path is not None else config.vocab_path
        self.vocab = Vocab(vocab, config.vocab_size, config.embeddings_file,
                           args)
        self.batcher = Batcher(args.decode_data_path,
                               self.vocab,
                               mode='decode',
                               batch_size=args.beam_size,
                               single_pass=True,
                               args=args)
        self.batcher.setup_queues()
        time.sleep(30)

        self.model = Model(args, self.vocab).to(device)
        self.model.eval()

    def sort_beams(self, beams):
        return sorted(beams, key=lambda h: h.avg_log_prob, reverse=True)

    def extract_structures(self, batch, sent_attention_matrix,
                           doc_attention_matrix, count, use_cuda, sent_scores):
        fileName = os.path.join(self._structures_dir,
                                "%06d_struct.txt" % count)
        fp = open(fileName, "w")
        fp.write("Doc: " + str(count) + "\n")
        #exit(0)
        doc_attention_matrix = doc_attention_matrix[:, :]  # this change is yet to be tested!
        l = batch.enc_doc_lens[0].item()
        doc_sent_no = 0

        # for i in range(l):
        #     printstr = ''
        #     sent = batch.enc_batch[0][i]
        #     #scores = str_scores_sent[sent_no][0:l, 0:l]
        #     token_count = 0
        #     for j in range(batch.enc_sent_lens[0][i].item()):
        #         token = sent[j].item()
        #         printstr += self.vocab.id2word(token)+" "
        #         token_count = token_count + 1
        #     #print(printstr)
        #     fp.write(printstr+"\n")
        #
        #     scores = sent_attention_matrix[doc_sent_no][0:token_count, 0:token_count]
        #     shape2 = sent_attention_matrix[doc_sent_no][0:token_count,0:token_count].size()
        #     row = torch.ones([1, shape2[1]+1]).cuda()
        #     column = torch.zeros([shape2[0], 1]).cuda()
        #     new_scores = torch.cat([column, scores], dim=1)
        #     new_scores = torch.cat([row, new_scores], dim=0)
        #
        #     heads, tree_score = chu_liu_edmonds(new_scores.data.cpu().numpy().astype(np.float64))
        #     #print(heads, tree_score)
        #     fp.write(str(heads)+" ")
        #     fp.write(str(tree_score)+"\n")
        #     doc_sent_no+=1

        shape2 = doc_attention_matrix[0:l, 0:l + 1].size()
        row = torch.zeros([1, shape2[1]]).cuda()
        #column = torch.zeros([shape2[0], 1]).cuda()
        scores = doc_attention_matrix[0:l, 0:l + 1]
        #new_scores = torch.cat([column, scores], dim=1)
        new_scores = torch.cat([row, scores], dim=0)
        val, root_edge = torch.max(new_scores[:, 0], dim=0)
        root_score = torch.zeros([shape2[0] + 1, 1]).cuda()
        root_score[root_edge] = 1
        new_scores[:, 0] = root_score.squeeze()
        #print(new_scores)
        #print(new_scores.sum(dim=0))
        #print(new_scores.sum(dim=1))
        #print(new_scores.size())
        heads, tree_score = chu_liu_edmonds(
            new_scores.data.cpu().numpy().astype(np.float64))
        height = find_height(heads)
        leaf_nodes = leaf_node_proportion(heads)
        #print(heads, tree_score)
        fp.write("\n")
        sentences = str(batch.original_articles[0]).split("<split1>")
        for idx, sent in enumerate(sentences):
            fp.write(str(idx) + "\t" + str(sent) + "\n")
        #fp.write(str("\n".join(batch.original_articles[0].split("<split1>"))+"\n")
        fp.write(str(heads) + " ")
        fp.write(str(tree_score) + "\n")
        fp.write(str(height) + "\n")
        s = sent_scores[0].data.cpu().numpy()
        for val in s:
            fp.write(str(val))
        fp.close()
        #exit()
        structure_info = dict()
        structure_info['heads'] = heads
        structure_info['height'] = height
        structure_info['leaf_nodes'] = leaf_nodes
        return structure_info

    def decode(self):
        start = time.time()
        counter = 0
        sent_counter = []
        avg_max_seq_len_list = []
        copied_sequence_len = Counter()
        copied_sequence_per_sent = []
        article_copy_id_count_tot = Counter()
        sentence_copy_id_count = Counter()
        novel_counter = Counter()
        repeated_counter = Counter()
        summary_sent_count = Counter()
        summary_sent = []
        article_sent = []
        summary_len = []
        abstract_ref = []
        abstract_pred = []
        sentence_count = []
        tot_sentence_id_count = Counter()
        height_avg = []
        leaf_node_proportion_avg = []
        precision_tree_dist = []
        recall_tree_dist = []
        batch = self.batcher.next_batch()
        height_counter = Counter()
        leaf_nodes_counter = Counter()
        sent_count_fp = open(self.sent_count_file, 'w')

        counts = {
            'token_consel_num_correct': 0,
            'token_consel_num': 0,
            'sent_single_heads_num_correct': 0,
            'sent_single_heads_num': 0,
            'sent_all_heads_num_correct': 0,
            'sent_all_heads_num': 0,
            'sent_all_child_num_correct': 0,
            'sent_all_child_num': 0
        }
        no_batches_processed = 0
        while batch is not None:
            # Run beam search to get best Hypothesis
            #start = time.process_time()
            has_summary, best_summary, sample_predictions, sample_counts, structure_info, adj_mat = self.get_decoded_outputs(
                batch, counter)
            #print('Time taken for decoder: ', time.process_time() - start)
            # token_contsel_tot_correct += token_consel_num_correct
            # token_contsel_tot_num += token_consel_num
            # sent_heads_tot_correct += sent_heads_num_correct
            # sent_heads_tot_num += sent_heads_num

            if args.predict_contsel_tags:
                no_words = batch.enc_word_lens[0]
                prediction = sample_predictions['token_contsel_prediction'][
                    0:no_words]
                ref = batch.contsel_tags[0]
                write_tags(prediction, ref, counter, self._contsel_dir,
                           self._contsel_ref_dir)
                counts['token_consel_num_correct'] += sample_counts[
                    'token_consel_num_correct']
                counts['token_consel_num'] += sample_counts['token_consel_num']

            if args.predict_sent_single_head:
                no_sents = batch.enc_doc_lens[0]
                prediction = sample_predictions[
                    'sent_single_heads_prediction'][0:no_sents].tolist()
                ref = batch.original_parent_heads[0]
                write_tags(prediction, ref, counter,
                           self._sent_single_heads_dir,
                           self._sent_single_heads_ref_dir)
                counts['sent_single_heads_num_correct'] += sample_counts[
                    'sent_single_heads_num_correct']
                counts['sent_single_heads_num'] += sample_counts[
                    'sent_single_heads_num']

            if args.predict_sent_all_head:
                counts['sent_all_heads_num_correct'] += sample_counts[
                    'sent_all_heads_num_correct']
                counts['sent_all_heads_num'] += sample_counts[
                    'sent_all_heads_num']

            if args.predict_sent_all_child:
                counts['sent_all_child_num_correct'] += sample_counts[
                    'sent_all_child_num_correct']
                counts['sent_all_child_num'] += sample_counts[
                    'sent_all_child_num']

            if not has_summary:
                batch = self.batcher.next_batch()
                continue
            # Extract the output ids from the hypothesis and convert back to words
            output_ids = [int(t) for t in best_summary.tokens[1:]]
            decoded_words = data.outputids2words(
                output_ids, self.vocab,
                (batch.art_oovs[0] if self.args.pointer_gen else None))

            # Remove the [STOP] token from decoded_words, if necessary
            try:
                fst_stop_idx = decoded_words.index(data.STOP_DECODING)
                decoded_words = decoded_words[:fst_stop_idx]
            except ValueError:
                pass  # no [STOP] token produced; keep the full output

            original_abstract_sents = batch.original_abstracts_sents[0]

            summary_len.append(len(decoded_words))
            assert adj_mat is not None, "Explicit matrix is none."
            assert structure_info['heads'] is not None, "Heads is none."
            precision, recall = tree_distance(
                structure_info['heads'],
                adj_mat.cpu().data.numpy()[0, :, :])
            if precision is not None and recall is not None:
                precision_tree_dist.append(precision)
                recall_tree_dist.append(recall)
            height_counter[structure_info['height']] += 1
            height_avg.append(structure_info['height'])
            leaf_node_proportion_avg.append(structure_info['leaf_nodes'])
            leaf_nodes_counter[np.floor(structure_info['leaf_nodes'] *
                                        10)] += 1
            abstract_ref.append(" ".join(original_abstract_sents))
            abstract_pred.append(" ".join(decoded_words))

            sent_res = get_sent_dist(" ".join(decoded_words),
                                     batch.original_articles[0].decode(),
                                     minimum_seq=self.args.minimum_seq)

            sent_counter.append(
                (sent_res['seen_sent'], sent_res['article_sent']))
            summary_len.append(sent_res['summary_len'])
            summary_sent.append(sent_res['summary_sent'])
            summary_sent_count[sent_res['summary_sent']] += 1
            article_sent.append(sent_res['article_sent'])
            if sent_res['avg_copied_seq_len'] is not None:
                avg_max_seq_len_list.append(sent_res['avg_copied_seq_len'])
                copied_sequence_per_sent.append(
                    np.average(
                        list(sent_res['counter_summary_sent_id'].values())))
            copied_sequence_len.update(sent_res['counter_copied_sequence_len'])
            sentence_copy_id_count.update(sent_res['counter_summary_sent_id'])
            article_copy_id_count_tot.update(
                sent_res['counter_article_sent_id'])
            novel_counter.update(sent_res['novel_ngram_counter'])
            repeated_counter.update(sent_res['repeated_ngram_counter'])

            sent_count_fp.write(
                str(counter) + "\t" + str(sent_res['article_sent']) + "\t" +
                str(sent_res['seen_sent']) + "\n")
            write_for_rouge(original_abstract_sents, decoded_words, counter,
                            self._rouge_ref_dir, self._rouge_dec_dir)

            batch = self.batcher.next_batch()

            counter += 1
            if counter % 1000 == 0:
                print('%d examples in %.2f sec' % (counter, time.time() - start))
                start = time.time()
            #print('Time taken for rest: ', time.process_time() - start)
            if args.decode_for_subset:
                if counter == 1000:
                    break

        print("Decoder has finished reading dataset for single_pass.")

        fp = open(self.stat_res_file, 'w')
        percentages = [
            float(len(seen_sent)) / float(sent_count)
            for seen_sent, sent_count in sent_counter
        ]
        avg_percentage = sum(percentages) / float(len(percentages))
        nosents = [len(seen_sent) for seen_sent, sent_count in sent_counter]
        avg_nosents = sum(nosents) / float(len(nosents))

        res = dict()
        res['avg_percentage_seen_sent'] = avg_percentage
        res['avg_nosents'] = avg_nosents
        res['summary_len'] = summary_sent_count
        res['avg_summary_len'] = np.average(summary_len)
        res['summary_sent'] = np.average(summary_sent)
        res['article_sent'] = np.average(article_sent)
        res['avg_copied_seq_len'] = np.average(avg_max_seq_len_list)
        res['avg_sequences_per_sent'] = np.average(copied_sequence_per_sent)
        res['counter_copied_sequence_len'] = copied_sequence_len
        res['counter_summary_sent_id'] = sentence_copy_id_count
        res['counter_article_sent_id'] = article_copy_id_count_tot
        res['novel_ngram_counter'] = novel_counter
        res['repeated_ngram_counter'] = repeated_counter

        fp.write("Summary metrics\n")
        for key in res:
            fp.write('{}: {}\n'.format(key, res[key]))

        fp.write("Structures metrics\n")
        fp.write("Average depth of RST tree: " +
                 str(sum(height_avg) / len(height_avg)) + "\n")
        fp.write("Average proportion of leaf nodes in RST tree: " + str(
            sum(leaf_node_proportion_avg) / len(leaf_node_proportion_avg)) +
                 "\n")
        fp.write("Precision of edges latent to explicit: " +
                 str(np.average(precision_tree_dist)) + "\n")
        fp.write("Recall of edges latent to explicit: " +
                 str(np.average(recall_tree_dist)) + "\n")
        fp.write("Tree height counter:\n")
        fp.write(str(height_counter) + "\n")
        fp.write("Tree leaf proportion counter:")
        fp.write(str(leaf_nodes_counter) + "\n")

        if args.predict_contsel_tags:
            fp.write("Avg token_contsel: " +
                     str(counts['token_consel_num_correct'] /
                         float(counts['token_consel_num'])) + "\n")
        if args.predict_sent_single_head:
            fp.write("Avg single sent heads: " +
                     str(counts['sent_single_heads_num_correct'] /
                         float(counts['sent_single_heads_num'])) + "\n")
        if args.predict_sent_all_head:
            fp.write("Avg all sent heads: " +
                     str(counts['sent_all_heads_num_correct'] /
                         float(counts['sent_all_heads_num'])) + "\n")
        if args.predict_sent_all_child:
            fp.write("Avg all sent child: " +
                     str(counts['sent_all_child_num_correct'] /
                         float(counts['sent_all_child_num'])) + "\n")
        fp.close()
        sent_count_fp.close()

        write_to_json_file(abstract_ref, self._rouge_ref_file)
        write_to_json_file(abstract_pred, self._rouge_pred_file)
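        # write_to_json_file is defined elsewhere in this repo; a minimal
        # sketch of its assumed behavior (dump the collected reference and
        # predicted summaries for later ROUGE scoring) might be:
        #
        #   import json
        #
        #   def write_to_json_file(summaries, path):
        #       with open(path, 'w') as f:
        #           json.dump(summaries, f)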

    def get_decoded_outputs(self, batch, count):
        # batch should have only one example (repeated to fill the beam)
        (enc_batch, enc_padding_token_mask, enc_padding_sent_mask,
         enc_doc_lens, enc_sent_lens, enc_batch_extend_vocab, extra_zeros,
         c_t_0, coverage_t_0, word_batch, word_padding_mask, enc_word_lens,
         enc_tags_batch, enc_sent_tags, enc_sent_token_mat, adj_mat,
         weighted_adj_mat, norm_adj_mat, parent_heads,
         undir_weighted_adj_mat) = get_input_from_batch(
             batch, use_cuda, self.args)

        # Pick which annotation graph is fed to the encoder.
        enc_adj_mat = adj_mat
        if args.use_weighted_annotations:
            if args.use_undirected_weighted_graphs:
                enc_adj_mat = undir_weighted_adj_mat
            else:
                enc_adj_mat = weighted_adj_mat

        encoder_output = self.model.encoder.forward_test(
            enc_batch, enc_sent_lens, enc_doc_lens, enc_padding_token_mask,
            enc_padding_sent_mask, word_batch, word_padding_mask,
            enc_word_lens, enc_tags_batch, enc_sent_token_mat, enc_adj_mat)
        (encoder_outputs, enc_padding_mask, encoder_last_hidden,
         max_encoder_output, enc_batch_extend_vocab,
         token_level_sentence_scores, sent_outputs, token_scores,
         sent_scores, sent_matrix, sent_level_rep) = \
            self.model.get_app_outputs(encoder_output, enc_padding_token_mask,
                                       enc_padding_sent_mask,
                                       enc_batch_extend_vocab,
                                       enc_sent_token_mat)

        # Sentence-validity mask for the first (only) example, with an extra
        # leading column for the root slot of the structured-attention matrix.
        mask = enc_padding_sent_mask[0].unsqueeze(0).repeat(
            enc_padding_sent_mask.size(1),
            1) * enc_padding_sent_mask[0].unsqueeze(1).transpose(1, 0)
        mask = torch.cat((enc_padding_sent_mask[0].unsqueeze(1), mask), dim=1)
        mat = encoder_output['sent_attention_matrix'][0][:, :] * mask

        structure_info = self.extract_structures(
            batch, encoder_output['token_attention_matrix'], mat, count,
            use_cuda, encoder_output['sent_score'])

        counts = {}
        predictions = {}
        if args.predict_contsel_tags:
            pred = encoder_output['token_score'][0, :, :].view(-1, 2)
            token_contsel_gold = enc_tags_batch[0, :].view(-1)
            token_contsel_prediction = torch.argmax(pred.detach(), dim=1)
            # Set masked positions to a value that can never match a gold
            # label, so they are excluded from the correct count.
            token_contsel_prediction[token_contsel_gold == -1] = -2
            token_consel_num_correct = torch.sum(
                token_contsel_prediction.eq(token_contsel_gold)).item()
            token_consel_num = torch.sum(token_contsel_gold != -1).item()
            predictions['token_contsel_prediction'] = token_contsel_prediction
            counts['token_consel_num_correct'] = token_consel_num_correct
            counts['token_consel_num'] = token_consel_num

        if args.predict_sent_single_head:
            pred = encoder_output['sent_single_head_scores'][0, :, :]
            head_labels = parent_heads[0, :].view(-1)
            sent_single_heads_prediction = torch.argmax(pred.detach(), dim=1)
            # Masked positions can never match a gold label.
            sent_single_heads_prediction[head_labels == -1] = -2
            sent_single_heads_num_correct = torch.sum(
                sent_single_heads_prediction.eq(head_labels)).item()
            sent_single_heads_num = torch.sum(head_labels != -1).item()
            predictions['sent_single_heads_prediction'] = \
                sent_single_heads_prediction
            counts['sent_single_heads_num_correct'] = \
                sent_single_heads_num_correct
            counts['sent_single_heads_num'] = sent_single_heads_num

        if args.predict_sent_all_head:
            pred = encoder_output['sent_all_head_scores'][0, :, :, :]
            # NOTE: the original permute(0, 1) is an identity on a 2-D tensor,
            # so it is dropped here; if the head target was meant to be the
            # transpose of the child adjacency (as in the gold-annotation
            # branch below), this should read permute(1, 0).
            target = adj_mat[0, :, :].view(-1)
            sent_all_heads_prediction = torch.argmax(pred.detach(), dim=1)
            # Masked positions can never match a gold label.
            sent_all_heads_prediction[target == -1] = -2
            sent_all_heads_num_correct = torch.sum(
                sent_all_heads_prediction.eq(target)).item()
            sent_all_heads_num = torch.sum(target != -1).item()
            predictions['sent_all_heads_prediction'] = \
                sent_all_heads_prediction
            counts['sent_all_heads_num_correct'] = sent_all_heads_num_correct
            counts['sent_all_heads_num'] = sent_all_heads_num

        if args.predict_sent_all_child:
            pred = encoder_output['sent_all_child_scores'][0, :, :, :]
            target = adj_mat[0, :, :].view(-1)
            sent_all_child_prediction = torch.argmax(pred.detach(), dim=1)
            # Masked positions can never match a gold label.
            sent_all_child_prediction[target == -1] = -2
            sent_all_child_num_correct = torch.sum(
                sent_all_child_prediction.eq(target)).item()
            sent_all_child_num = torch.sum(target != -1).item()
            predictions['sent_all_child_prediction'] = \
                sent_all_child_prediction
            counts['sent_all_child_num_correct'] = sent_all_child_num_correct
            counts['sent_all_child_num'] = sent_all_child_num
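        # The four blocks above share one masked-accuracy pattern; a compact
        # equivalent (hypothetical helper, not part of this repo) would be:
        #
        #   def masked_accuracy_counts(scores, gold, masked=-1):
        #       pred = torch.argmax(scores.detach(), dim=1)
        #       pred[gold == masked] = -2  # can never match a gold label
        #       correct = torch.sum(pred.eq(gold)).item()
        #       total = torch.sum(gold != masked).item()
        #       return correct, total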

        results = []
        steps = 0
        has_summary = False
        beams_sorted = [None]
        if args.predict_summaries:
            has_summary = True

            if args.fixed_scorer:
                # Score tokens and sentences with the frozen pretrained
                # scorer instead of the jointly trained one.
                scorer_output = self.model.module.pretrained_scorer.forward_test(
                    enc_batch, enc_sent_lens, enc_doc_lens,
                    enc_padding_token_mask, enc_padding_sent_mask, word_batch,
                    word_padding_mask, enc_word_lens, enc_tags_batch)
                token_scores = scorer_output['token_score']
                # Broadcast each sentence score across all of its tokens.
                sent_scores = scorer_output['sent_score'].unsqueeze(1).repeat(
                    1, enc_padding_token_mask.size(2), 1, 1).view(
                        enc_padding_token_mask.size(0),
                        enc_padding_token_mask.size(1) *
                        enc_padding_token_mask.size(2))

            all_child, all_head = None, None
            if args.use_gold_annotations_for_decode:

                def normalize_rows(mat):
                    # Row-normalize a batched adjacency matrix so every
                    # non-empty row sums to 1; all-zero (padding) rows are
                    # left untouched to avoid 0/0 divisions.
                    normalized = mat.clone()
                    row_sums = torch.sum(mat, dim=2, keepdim=True)
                    nonzero = row_sums.expand_as(mat) != 0
                    normalized[nonzero] = mat[nonzero] / \
                        row_sums.expand_as(mat)[nonzero]
                    return normalized

                # Pick the gold graph matching the annotation flags, then
                # derive head and child distributions from it.
                if args.use_weighted_annotations:
                    if args.use_undirected_weighted_graphs:
                        gold_adj = undir_weighted_adj_mat
                    else:
                        gold_adj = weighted_adj_mat
                else:
                    gold_adj = adj_mat
                # Heads are the transpose of the child adjacency.
                all_head = normalize_rows(gold_adj.permute(0, 2, 1))
                all_child = normalize_rows(gold_adj)
                # Alternative kept from the original (disabled): add a small
                # epsilon (e.g. 5e-5) to every entry and divide by row sums.
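                # Example: a child row [2.0, 0.0, 2.0] becomes
                # [0.5, 0.0, 0.5], while an all-zero padding row stays zero
                # instead of producing NaNs from 0/0 division.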

            s_t_0 = self.model.reduce_state(encoder_last_hidden)

            if config.use_maxpool_init_ctx:
                c_t_0 = max_encoder_output

            dec_h, dec_c = s_t_0  # 1 x 2*hidden_size
            dec_h = dec_h.squeeze()
            dec_c = dec_c.squeeze()

            # Decoder batch preparation: start with beam_size identical
            # hypotheses; everything is repeated until the beams diverge.
            beams = [
                Beam(tokens=[self.vocab.word2id(data.START_DECODING)],
                     log_probs=[0.0],
                     state=(dec_h[0], dec_c[0]),
                     context=c_t_0[0],
                     coverage=(coverage_t_0[0] if self.args.is_coverage
                               or self.args.bu_coverage_penalty else None))
                for _ in range(args.beam_size)
            ]
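            # Each Beam carries its token history, per-step log-probs, the
            # LSTM decoder state, the attention context vector and, when
            # coverage is enabled, a coverage vector over source tokens.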

            while steps < args.max_dec_steps and len(results) < args.beam_size:
                latest_tokens = [h.latest_token for h in beams]
                # Map in-article OOV ids (beyond the fixed vocabulary) back
                # to UNK before feeding tokens into the decoder.
                latest_tokens = [
                    t if t < self.vocab.size() else self.vocab.word2id(
                        data.UNKNOWN_TOKEN) for t in latest_tokens
                ]
                y_t_1 = Variable(torch.LongTensor(latest_tokens))
                if use_cuda:
                    y_t_1 = y_t_1.cuda()
                all_state_h = []
                all_state_c = []

                all_context = []

                for h in beams:
                    state_h, state_c = h.state
                    all_state_h.append(state_h)
                    all_state_c.append(state_c)

                    all_context.append(h.context)

                s_t_1 = (torch.stack(all_state_h, 0).unsqueeze(0),
                         torch.stack(all_state_c, 0).unsqueeze(0))
                c_t_1 = torch.stack(all_context, 0)

                coverage_t_1 = None
                if self.args.is_coverage or self.args.bu_coverage_penalty:
                    all_coverage = []
                    for h in beams:
                        all_coverage.append(h.coverage)
                    coverage_t_1 = torch.stack(all_coverage, 0)

                final_dist, s_t, c_t, attn_dist, p_gen, coverage_t = self.model.decoder(
                    y_t_1, s_t_1, encoder_outputs, word_padding_mask, c_t_1,
                    extra_zeros, enc_batch_extend_vocab, coverage_t_1,
                    token_scores, sent_scores, sent_outputs,
                    enc_sent_token_mat, all_head, all_child, sent_level_rep)

                if args.bu_coverage_penalty:
                    # Bottom-up coverage penalty: penalize attention mass
                    # above 1.0 accumulated on any source position.
                    penalty = torch.max(coverage_t,
                                        torch.ones_like(coverage_t)).sum(-1)
                    penalty -= coverage_t.size(-1)
                    final_dist -= args.beta * penalty.unsqueeze(1).expand_as(
                        final_dist)
                if args.bu_length_penalty:
                    penalty = ((5 + steps + 1) / 6.0)**args.alpha
                    final_dist = final_dist / penalty
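                # The length penalty is the GNMT formula of Wu et al. (2016),
                # lp = ((5 + t) / 6)^alpha, which keeps longer hypotheses
                # from being dominated by their summed log-probabilities.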

                topk_log_probs, topk_ids = torch.topk(final_dist,
                                                      args.beam_size * 2)

                dec_h, dec_c = s_t
                dec_h = dec_h.squeeze()
                dec_c = dec_c.squeeze()

                all_beams = []
                num_orig_beams = 1 if steps == 0 else len(beams)
                for i in range(num_orig_beams):
                    h = beams[i]
                    state_i = (dec_h[i], dec_c[i])
                    context_i = c_t[i]
                    coverage_i = (coverage_t[i] if self.args.is_coverage
                                  or self.args.bu_coverage_penalty else None)

                    # Examine the top 2*beam_size extensions so enough
                    # non-STOP candidates survive the pruning below.
                    for j in range(args.beam_size * 2):
                        new_beam = h.extend(token=topk_ids[i, j].item(),
                                            log_prob=topk_log_probs[i, j].item(),
                                            state=state_i,
                                            context=context_i,
                                            coverage=coverage_i)
                        all_beams.append(new_beam)

                beams = []
                for h in self.sort_beams(all_beams):
                    if h.latest_token == self.vocab.word2id(
                            data.STOP_DECODING):
                        # Only accept finished hypotheses that have reached
                        # the minimum decode length.
                        if steps >= config.min_dec_steps:
                            results.append(h)
                    else:
                        beams.append(h)
                    if len(beams) == args.beam_size or len(
                            results) == args.beam_size:
                        break

                steps += 1

            if len(results) == 0:
                # No hypothesis emitted STOP_DECODING; fall back to the
                # unfinished beams.
                results = beams

            beams_sorted = self.sort_beams(results)

        return (has_summary, beams_sorted[0], predictions, counts,
                structure_info, undir_weighted_adj_mat)
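    # Usage sketch (assumed; the surrounding driver loop is not shown in
    # this excerpt): iterate the decode batcher and keep the top beam, e.g.
    #
    #   batch = decoder.batcher.next_batch()
    #   count = 0
    #   while batch is not None:
    #       has_summary, best, preds, counts, structure, adj = \
    #           decoder.get_decoded_outputs(batch, count)
    #       batch = decoder.batcher.next_batch()
    #       count += 1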