Example 1
class Demo(Evaluate):
    def __init__(self, opt):
        self.vocab = Vocab(config.demo_vocab_path, config.demo_vocab_size)
        self.opt = opt
        self.setup_valid()

    def evaluate(self, article, ref):
        dec = self.abstract(article)
        scores = rouge.get_scores(dec, ref)
        rouge_1 = sum([x["rouge-1"]["f"] for x in scores]) / len(scores)
        rouge_2 = sum([x["rouge-2"]["f"] for x in scores]) / len(scores)
        rouge_l = sum([x["rouge-l"]["f"] for x in scores]) / len(scores)
        return {
            'dec': dec,
            'rouge_1': rouge_1,
            'rouge_2': rouge_2,
            'rouge_l': rouge_l
        }

    def abstract(self, article):
        start_id = self.vocab.word2id(data.START_DECODING)
        end_id = self.vocab.word2id(data.STOP_DECODING)
        unk_id = self.vocab.word2id(data.UNKNOWN_TOKEN)
        example = Example(' '.join(jieba.cut(article)), '', self.vocab)
        batch = Batch([example], self.vocab, 1)
        enc_batch, enc_lens, enc_padding_mask, enc_batch_extend_vocab, extra_zeros, ct_e = get_enc_data(
            batch)
        with T.autograd.no_grad():
            enc_batch = self.model.embeds(enc_batch)
            enc_out, enc_hidden = self.model.encoder(enc_batch, enc_lens)
            pred_ids = beam_search(enc_hidden, enc_out, enc_padding_mask, ct_e,
                                   extra_zeros, enc_batch_extend_vocab,
                                   self.model, start_id, end_id, unk_id)

        # The batch holds a single example, so only the first prediction needs decoding.
        decoded_words = data.outputids2words(pred_ids[0], self.vocab,
                                             batch.art_oovs[0])
        return " ".join(decoded_words)
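
The evaluate() method above averages the F-scores returned by the rouge package over all (hypothesis, reference) pairs. A minimal standalone illustration of that averaging, assuming the rouge package is installed; the Rouge() instance and the example strings are illustrative only:

from rouge import Rouge

rouge = Rouge()
dec = ["the cat sat on the mat"]
ref = ["a cat was sitting on the mat"]
scores = rouge.get_scores(dec, ref)  # one score dict per (hypothesis, reference) pair
# Averaging the "f" fields mirrors evaluate() above.
rouge_1 = sum(s["rouge-1"]["f"] for s in scores) / len(scores)
print(rouge_1)
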
Example 2
class BeamSearch(object):
    def __init__(self, model_file_path):
        model_name = os.path.basename(model_file_path)
        self.model_path_name = model_name
        self._decode_dir = os.path.join(config.log_root,
                                        'decode_%s' % (model_name))
        self._rouge_ref_dir = os.path.join(self._decode_dir, 'rouge_ref')
        self._rouge_dec_dir = os.path.join(self._decode_dir, 'rouge_dec_dir')
        for p in [self._decode_dir, self._rouge_ref_dir, self._rouge_dec_dir]:
            if not os.path.exists(p):
                os.mkdir(p)

        self.vocab = Vocab(config.vocab_path, config.vocab_size)
        self.batcher = Batcher(config.decode_data_path,
                               self.vocab,
                               mode='decode',
                               batch_size=config.beam_size,
                               single_pass=True)
        time.sleep(15)

        self.model = Model(model_file_path, is_eval=True)

    def sort_beams(self, beams):
        return sorted(beams, key=lambda h: h.avg_log_prob, reverse=True)

    def decode(self):
        start = time.time()
        counter = 0
        batch = self.batcher.next_batch()

        decoded_result = []
        refered_result = []
        article_result = []
        while batch is not None:
            # Run beam search to get best Hypothesis
            best_summary = self.beam_search(batch)

            # Extract the output ids from the hypothesis and convert back to words
            output_ids = [int(t) for t in best_summary.tokens[1:]]
            decoded_words = data.outputids2words(
                output_ids, self.vocab,
                (batch.art_oovs[0] if config.pointer_gen else None))

            # Remove the [STOP] token from decoded_words, if necessary
            try:
                fst_stop_idx = decoded_words.index(data.STOP_DECODING)
                decoded_words = decoded_words[:fst_stop_idx]
            except ValueError:  # no [STOP] token was produced
                pass

            original_abstract_sents = batch.original_abstracts_sents[0]
            article = batch.original_articles[0]

            #write_for_rouge(original_abstract_sents, decoded_words, counter,
            #                self._rouge_ref_dir, self._rouge_dec_dir)
            decoded_sents = []
            while len(decoded_words) > 0:
                try:
                    fst_period_idx = decoded_words.index(".")
                except ValueError:
                    fst_period_idx = len(decoded_words)
                sent = decoded_words[:fst_period_idx + 1]
                decoded_words = decoded_words[fst_period_idx + 1:]
                decoded_sents.append(' '.join(sent))

            # pyrouge calls a perl script that puts the data into HTML files,
            # so we need to make our output HTML safe.
            decoded_sents = [make_html_safe(w) for w in decoded_sents]
            reference_sents = [
                make_html_safe(w) for w in original_abstract_sents
            ]
            decoded_result.append(' '.join(decoded_sents))
            refered_result.append(' '.join(reference_sents))
            article_result.append(article)
            counter += 1
            if counter % 1000 == 0:
                print('%d examples in %d sec' % (counter, time.time() - start))
                start = time.time()

            batch = self.batcher.next_batch()

        print("Decoder has finished reading dataset for single_pass.")
        print("Now starting ROUGE eval...")
        load_file = self.model_path_name
        self.print_original_predicted(decoded_result, refered_result,
                                      article_result, load_file)

        rouge = Rouge()
        scores = rouge.get_scores(decoded_result, refered_result)
        rouge_1 = sum([x["rouge-1"]["f"] for x in scores]) / len(scores)
        rouge_2 = sum([x["rouge-2"]["f"] for x in scores]) / len(scores)
        rouge_l = sum([x["rouge-l"]["f"] for x in scores]) / len(scores)
        rouge_1_r = sum([x["rouge-1"]["r"] for x in scores]) / len(scores)
        rouge_2_r = sum([x["rouge-2"]["r"] for x in scores]) / len(scores)
        rouge_l_r = sum([x["rouge-l"]["r"] for x in scores]) / len(scores)
        rouge_1_p = sum([x["rouge-1"]["p"] for x in scores]) / len(scores)
        rouge_2_p = sum([x["rouge-2"]["p"] for x in scores]) / len(scores)
        rouge_l_p = sum([x["rouge-l"]["p"] for x in scores]) / len(scores)
        log_str = " rouge_1:" + "%.4f" % rouge_1 + " rouge_2:" + "%.4f" % rouge_2 + " rouge_l:" + "%.4f" % rouge_l
        log_str_r = " rouge_1_r:" + "%.4f" % rouge_1_r + " rouge_2_r:" + "%.4f" % rouge_2_r + " rouge_l_r:" + "%.4f" % rouge_l_r
        logger.info(load_file + " rouge_1:" + "%.4f" % rouge_1 + " rouge_2:" +
                    "%.4f" % rouge_2 + " rouge_l:" + "%.4f" % rouge_l)
        log_str_p = " rouge_1_p:" + "%.4f" % rouge_1_p + " rouge_2_p:" + "%.4f" % rouge_2_p + " rouge_l_p:" + "%.4f" % rouge_l_p
        results_file = os.path.join(self._decode_dir, "ROUGE_results.txt")
        with open(results_file, "w") as f:
            f.write(log_str + '\n')
            f.write(log_str_r + '\n')
            f.write(log_str_p + '\n')

        #results_dict = rouge_eval(self._rouge_ref_dir, self._rouge_dec_dir)
        #rouge_log(results_dict, self._decode_dir)

    def print_original_predicted(self, decoded_sents, ref_sents, article_sents,
                                 loadfile):
        filename = "test_" + loadfile.split("_")[1] + ".txt"

        with open(os.path.join(self._rouge_dec_dir, filename), "w") as f:
            for i in range(len(decoded_sents)):
                f.write("article: " + article_sents[i] + "\n")
                f.write("ref: " + ref_sents[i] + "\n")
                f.write("dec: " + decoded_sents[i] + "\n\n")

    def beam_search(self, batch):
        #batch should have only one example
        enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, c_t_0, coverage_t_0 = \
            get_input_from_batch(batch, use_cuda)

        encoder_outputs, encoder_feature, encoder_hidden = self.model.encoder(
            enc_batch, enc_lens)
        s_t_0 = self.model.reduce_state(encoder_hidden)

        dec_h, dec_c = s_t_0  # 1 x 2*hidden_size
        dec_h = dec_h.squeeze()
        dec_c = dec_c.squeeze()

        # decoder batch preparation: start with beam_size identical hypotheses
        beams = [
            Beam(tokens=[self.vocab.word2id(data.START_DECODING)],
                 log_probs=[0.0],
                 state=(dec_h[0], dec_c[0]),
                 context=c_t_0[0],
                 coverage=(coverage_t_0[0] if config.is_coverage else None))
            for _ in range(config.beam_size)
        ]
        results = []
        steps = 0
        while steps < config.max_dec_steps and len(results) < config.beam_size:
            latest_tokens = [h.latest_token for h in beams]
            latest_tokens = [t if t < self.vocab.size() else self.vocab.word2id(data.UNKNOWN_TOKEN) \
                             for t in latest_tokens]
            y_t_1 = Variable(torch.LongTensor(latest_tokens))
            if use_cuda:
                y_t_1 = y_t_1.cuda()
            all_state_h = []
            all_state_c = []

            all_context = []

            for h in beams:
                state_h, state_c = h.state
                all_state_h.append(state_h)
                all_state_c.append(state_c)

                all_context.append(h.context)

            s_t_1 = (torch.stack(all_state_h, 0).unsqueeze(0),
                     torch.stack(all_state_c, 0).unsqueeze(0))
            c_t_1 = torch.stack(all_context, 0)

            coverage_t_1 = None
            if config.is_coverage:
                all_coverage = []
                for h in beams:
                    all_coverage.append(h.coverage)
                coverage_t_1 = torch.stack(all_coverage, 0)

            final_dist, s_t, c_t, attn_dist, p_gen, coverage_t = self.model.decoder(
                y_t_1, s_t_1, encoder_outputs, encoder_feature,
                enc_padding_mask, c_t_1, extra_zeros, enc_batch_extend_vocab,
                coverage_t_1, steps)
            log_probs = torch.log(final_dist)
            topk_log_probs, topk_ids = torch.topk(log_probs,
                                                  config.beam_size * 2)

            dec_h, dec_c = s_t
            dec_h = dec_h.squeeze()
            dec_c = dec_c.squeeze()

            all_beams = []
            num_orig_beams = 1 if steps == 0 else len(beams)
            for i in range(num_orig_beams):
                h = beams[i]
                state_i = (dec_h[i], dec_c[i])
                context_i = c_t[i]
                coverage_i = (coverage_t[i] if config.is_coverage else None)

                for j in range(config.beam_size * 2):  # for each of the top 2*beam_size hyps:
                    new_beam = h.extend(token=topk_ids[i, j].item(),
                                        log_prob=topk_log_probs[i, j].item(),
                                        state=state_i,
                                        context=context_i,
                                        coverage=coverage_i)
                    all_beams.append(new_beam)

            beams = []
            for h in self.sort_beams(all_beams):
                if h.latest_token == self.vocab.word2id(data.STOP_DECODING):
                    if steps >= config.min_dec_steps:
                        results.append(h)
                else:
                    beams.append(h)
                if len(beams) == config.beam_size or len(
                        results) == config.beam_size:
                    break

            steps += 1

        if len(results) == 0:
            results = beams

        beams_sorted = self.sort_beams(results)

        return beams_sorted[0]
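
sort_beams() and the decode loop above assume a Beam hypothesis object that exposes extend(), latest_token and avg_log_prob. A minimal sketch consistent with those calls; the actual implementation in the source repository may differ:

class Beam(object):
    """Hypothetical hypothesis container used by beam_search above."""

    def __init__(self, tokens, log_probs, state, context, coverage):
        self.tokens = tokens        # decoded token ids so far
        self.log_probs = log_probs  # log-probability of each decoded token
        self.state = state          # decoder (h, c) after the last step
        self.context = context      # attention context vector
        self.coverage = coverage    # coverage vector, or None

    def extend(self, token, log_prob, state, context, coverage):
        # Return a new hypothesis with one more decoded token appended.
        return Beam(tokens=self.tokens + [token],
                    log_probs=self.log_probs + [log_prob],
                    state=state,
                    context=context,
                    coverage=coverage)

    @property
    def latest_token(self):
        return self.tokens[-1]

    @property
    def avg_log_prob(self):
        # Length-normalized score used by sort_beams().
        return sum(self.log_probs) / len(self.tokens)
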
Example 3
class Train(object):
    def __init__(self):
        self.vocab = Vocab(config.vocab_path, config.vocab_size)

        self.batcher = Batcher(config.train_data_path, self.vocab, mode='train',
                               batch_size=config.batch_size, single_pass=False)
        # print("MODE MUST BE train")
        # time.sleep(15)
        self.print_interval = config.print_interval

        train_dir = config.train_dir
        if not os.path.exists(train_dir):
            os.mkdir(train_dir)

        self.model_dir = train_dir
        if not os.path.exists(self.model_dir):
            os.mkdir(self.model_dir)

        # self.summary_writer = tf.compat.v1.summary.FileWriter(train_dir)

    def save_model(self, running_avg_loss, iter):
        state = {
            'iter': iter,
            'encoder_state_dict': self.model.encoder.state_dict(),
            'decoder_state_dict': self.model.decoder.state_dict(),
            'reduce_state_dict': self.model.reduce_state.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'current_loss': running_avg_loss
        }
        model_save_path = os.path.join(self.model_dir, 'iter{}.pt'.format(iter))
        torch.save(state, model_save_path)

    def setup_train(self, model_file_path=None):
        self.model = Model(model_file_path)

        params = list(self.model.encoder.parameters()) + list(self.model.decoder.parameters()) + \
                 list(self.model.reduce_state.parameters())
        initial_lr = config.lr_coverage if config.is_coverage else config.lr
        self.optimizer = Adagrad(params, lr=initial_lr, initial_accumulator_value=config.adagrad_init_acc)

        start_iter, start_loss = 0, 0

        if model_file_path is not None:
            state = torch.load(model_file_path, map_location= lambda storage, location: storage)
            start_iter = state['iter']
            start_loss = state['current_loss']

            if not config.is_coverage:
                self.optimizer.load_state_dict(state['optimizer'])
                if use_cuda:
                    for state in self.optimizer.state.values():
                        for k, v in state.items():
                            if torch.is_tensor(v):
                                state[k] = v.cuda()

        return start_iter, start_loss

    def train_one_batch(self, batch):
        enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, c_t_1, coverage = \
            get_input_from_batch(batch, use_cuda)
        dec_batch, dec_padding_mask, max_dec_len, dec_lens_var, target_batch = \
            get_output_from_batch(batch, use_cuda)

        self.optimizer.zero_grad()

        if not config.is_hierarchical:
            encoder_outputs, encoder_feature, encoder_hidden = self.model.encoder(enc_batch, enc_lens)
            s_t_1 = self.model.reduce_state.forward1(encoder_hidden)

        else:
            stop_id = self.vocab.word2id('.')
            pad_id  = self.vocab.word2id('[PAD]')
            enc_sent_pos = get_sent_position(enc_batch, stop_id, pad_id)
            dec_sent_pos = get_sent_position(dec_batch, stop_id, pad_id)

            encoder_outputs, encoder_feature, encoder_hidden, sent_enc_outputs, sent_enc_feature, sent_enc_hidden, sent_enc_padding_mask, sent_lens, seq_lens2 = \
                                                                    self.model.encoder(enc_batch, enc_lens, enc_sent_pos)

            s_t_1, sent_s_t_1 = self.model.reduce_state(encoder_hidden, sent_enc_hidden)
        step_losses = []
        for di in range(min(max_dec_len, config.max_dec_steps)):
            y_t_1 = dec_batch[:, di]  # Teacher forcing
            if not config.is_hierarchical:
                # start = datetime.now()

                final_dist, s_t_1, c_t_1, attn_dist, p_gen, next_coverage = self.model.decoder.forward1(
                    y_t_1, s_t_1, encoder_outputs, encoder_feature, enc_padding_mask,
                    c_t_1, extra_zeros, enc_batch_extend_vocab, coverage, di)
                # print('NO HIER Time: ',datetime.now() - start)
                # import pdb; pdb.set_trace()
            else:
                # start = datetime.now()
                max_doc_len = enc_batch.size(1)
                final_dist, sent_s_t_1, c_t_1, attn_dist, p_gen, next_coverage = self.model.decoder(
                    y_t_1, sent_s_t_1, encoder_outputs, encoder_feature, enc_padding_mask, seq_lens2,
                    sent_s_t_1, sent_enc_outputs, sent_enc_feature, sent_enc_padding_mask,
                    sent_lens, max_doc_len, c_t_1, extra_zeros, enc_batch_extend_vocab, coverage, di)
                # print('DO HIER Time: ',datetime.now() - start)
                # import pdb; pdb.set_trace()


            target = target_batch[:, di]
            gold_probs = torch.gather(final_dist, 1, target.unsqueeze(1)).squeeze()
            step_loss = -torch.log(gold_probs + config.eps)
            if config.is_coverage:
                step_coverage_loss = torch.sum(torch.min(attn_dist, coverage), 1)
                step_loss = step_loss + config.cov_loss_wt * step_coverage_loss
                coverage = next_coverage

            step_mask = dec_padding_mask[:, di]
            step_loss = step_loss * step_mask
            step_losses.append(step_loss)

        sum_losses = torch.sum(torch.stack(step_losses, 1), 1)
        batch_avg_loss = sum_losses/dec_lens_var
        loss = torch.mean(batch_avg_loss)

        # start = datatime.now()
        loss.backward()
        # print('{} HIER Time: {}'.format(config.is_hierarchical ,datetime.now() - start))
        # import pdb; pdb.set_trace()

        clip_grad_norm_(self.model.encoder.parameters(), config.max_grad_norm)
        clip_grad_norm_(self.model.decoder.parameters(), config.max_grad_norm)
        clip_grad_norm_(self.model.reduce_state.parameters(), config.max_grad_norm)

        self.optimizer.step()

        return loss.item()

    def trainIters(self, n_iters, model_file_path=None):
        iter, running_avg_loss = self.setup_train(model_file_path)
        sys.stdout.flush()

        # data_path = "lib/data/batches_train.vocab50000.batch16.pk.bin"
        # with open(data_path, 'rb') as f:
        #     stored_batches = pickle.load(f, encoding="bytes")
        # print("loaded data: {}".format(data_path))
        # num_batches = len(stored_batches)

        while iter < n_iters:
            batch = self.batcher.next_batch()
            # batch_id = iter%num_batches
            # batch = stored_batches[batch_id]

            loss = self.train_one_batch(batch)

            # running_avg_loss = calc_running_avg_loss(loss, running_avg_loss, self.summary_writer, iter)
            running_avg_loss = calc_running_avg_loss(loss, running_avg_loss, iter)

            iter += 1

            # if iter % 100 == 0:
            #     self.summary_writer.flush()

            if iter % self.print_interval == 0:
                print("[{}] iter {}, loss: {:.5f}".format(str(datetime.now()), iter, loss))
                sys.stdout.flush()

            if iter % config.save_every == 0:
                self.save_model(running_avg_loss, iter)

        print("Finished training!")
Example 4
class BeamSearch(object):
    def __init__(self, model_file_path):
        model_name = os.path.basename(model_file_path)
        self._decode_dir = os.path.join(config.decode_dir, model_name)
        self._decode_dir = os.path.splitext(self._decode_dir)[0]
        self._rouge_ref_dir = os.path.join(self._decode_dir, 'rouge_ref')
        self._rouge_dec_dir = os.path.join(self._decode_dir, 'rouge_dec_dir')

        for p in [self._decode_dir, self._rouge_ref_dir, self._rouge_dec_dir]:
            Path(p).mkdir(parents=True, exist_ok=True)

        self.vocab = Vocab(config.vocab_path, config.vocab_size)
        self.pad_id = self.vocab.word2id(PAD_TOKEN)
        self.start_id = self.vocab.word2id(START_DECODING)
        self.stop_id = self.vocab.word2id(STOP_DECODING)

        self.model = Model(model_file_path, is_eval=True)

    def sort_beams(self, beams):
        return sorted(beams, key=lambda h: h.avg_log_prob, reverse=True)

    def if_already_exists(self, idx):
        decoded_file = os.path.join(self._rouge_dec_dir,
                                    "file.{}.txt".format(idx))
        return os.path.isfile(decoded_file)

    def decode(self, file_id_start, file_id_stop, ami_id='191209'):
        print("AMI transcription:", ami_id)

        test_data = load_ami_data(ami_id, 'test')

        # shuffle the ids so several (possibly slower CPU) machines can each pick up remaining files and redo any that failed
        idx_list = [i for i in range(file_id_start, file_id_stop)]
        random.shuffle(idx_list)
        for idx in idx_list:

            # for idx in range(file_id_start, file_id_stop):
            # check if this is written already
            if self.if_already_exists(idx):
                print("ID {} already exists".format(idx))
                continue

            # Run beam search to get best Hypothesis
            best_summary, art_oovs = self.beam_search(test_data, idx)

            # Extract the output ids from the hypothesis and convert back to words
            output_ids = [int(t) for t in best_summary.tokens[1:]]
            decoded_words = data.outputids2words(
                output_ids, self.vocab,
                (art_oovs[0] if config.pointer_gen else None))

            # Remove the [STOP] token from decoded_words, if necessary
            try:
                fst_stop_idx = decoded_words.index(data.STOP_DECODING)
                decoded_words = decoded_words[:fst_stop_idx]
            except ValueError:  # no [STOP] token was produced
                pass

            # original_abstract_sents = batch.original_abstracts_sents[0]
            original_abstract_sents = []

            write_for_rouge(original_abstract_sents, decoded_words, idx,
                            self._rouge_ref_dir, self._rouge_dec_dir)

            print("decoded idx = {}".format(idx))
        print("Finished decoding idx [{},{})".format(file_id_start,
                                                     file_id_stop))

        # print("Starting ROUGE eval...")
        # results_dict = rouge_eval(self._rouge_ref_dir, self._rouge_dec_dir)
        # rouge_log(results_dict, self._decode_dir)

    def beam_search(self, test_data, idx):
        # batch should have only one example
        # enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, c_t_0, coverage_t_0 = \
        #       get_input_from_batch(batch, use_cuda)

        enc_pack, art_oovs = get_a_batch_decode(test_data,
                                                idx,
                                                self.vocab,
                                                config.beam_size,
                                                config.max_enc_steps,
                                                config.max_dec_steps,
                                                self.start_id,
                                                self.stop_id,
                                                self.pad_id,
                                                sum_type='short',
                                                use_cuda=use_cuda)
        enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, c_t_0, coverage_t_0 = enc_pack
        encoder_outputs, encoder_feature, encoder_hidden = self.model.encoder(
            enc_batch, enc_lens)
        s_t_0 = self.model.reduce_state.forward1(encoder_hidden)

        dec_h, dec_c = s_t_0  # 1 x 2*hidden_size
        dec_h = dec_h.squeeze()
        dec_c = dec_c.squeeze()

        # decoder batch preparation: start with beam_size identical hypotheses
        beams = [
            Beam(tokens=[self.vocab.word2id(data.START_DECODING)],
                 log_probs=[0.0],
                 state=(dec_h[0], dec_c[0]),
                 context=c_t_0[0],
                 coverage=(coverage_t_0[0] if config.is_coverage else None))
            for _ in range(config.beam_size)
        ]

        results = []
        steps = 0
        while steps < config.max_dec_steps and len(results) < config.beam_size:
            latest_tokens = [h.latest_token for h in beams]
            latest_tokens = [t if t < self.vocab.size() else self.vocab.word2id(data.UNKNOWN_TOKEN) \
                             for t in latest_tokens]
            y_t_1 = Variable(torch.LongTensor(latest_tokens))
            if use_cuda: y_t_1 = y_t_1.cuda()
            all_state_h = []
            all_state_c = []

            all_context = []

            for h in beams:
                state_h, state_c = h.state
                all_state_h.append(state_h)
                all_state_c.append(state_c)

                all_context.append(h.context)

            s_t_1 = (torch.stack(all_state_h, 0).unsqueeze(0),
                     torch.stack(all_state_c, 0).unsqueeze(0))
            c_t_1 = torch.stack(all_context, 0)

            coverage_t_1 = None
            if config.is_coverage:
                all_coverage = []
                for h in beams:
                    all_coverage.append(h.coverage)
                coverage_t_1 = torch.stack(all_coverage, 0)

            final_dist, s_t, c_t, attn_dist, p_gen, coverage_t = self.model.decoder.forward1(
                y_t_1, s_t_1, encoder_outputs, encoder_feature,
                enc_padding_mask, c_t_1, extra_zeros, enc_batch_extend_vocab,
                coverage_t_1, steps)

            log_probs = torch.log(final_dist)
            topk_log_probs, topk_ids = torch.topk(log_probs,
                                                  config.beam_size * 2)

            dec_h, dec_c = s_t
            dec_h = dec_h.squeeze()
            dec_c = dec_c.squeeze()

            all_beams = []
            num_orig_beams = 1 if steps == 0 else len(beams)
            for i in range(num_orig_beams):
                h = beams[i]
                state_i = (dec_h[i], dec_c[i])
                context_i = c_t[i]
                coverage_i = (coverage_t[i] if config.is_coverage else None)

                for j in range(config.beam_size * 2):  # for each of the top 2*beam_size hyps:
                    new_beam = h.extend(token=topk_ids[i, j].item(),
                                        log_prob=topk_log_probs[i, j].item(),
                                        state=state_i,
                                        context=context_i,
                                        coverage=coverage_i)
                    all_beams.append(new_beam)

            beams = []
            for h in self.sort_beams(all_beams):
                if h.latest_token == self.vocab.word2id(data.STOP_DECODING):
                    if steps >= config.min_dec_steps:
                        results.append(h)
                else:
                    beams.append(h)
                if len(beams) == config.beam_size or len(
                        results) == config.beam_size:
                    break

            steps += 1

        if len(results) == 0:
            results = beams

        beams_sorted = self.sort_beams(results)

        return beams_sorted[0], art_oovs
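
The decoding loop above keeps the top 2*beam_size candidate tokens per hypothesis before pruning back to beam_size beams. A small self-contained illustration of that expand-then-prune step; the vocabulary size and scores are made up:

import torch

beam_size = 4
vocab_size = 50
log_probs = torch.log_softmax(torch.randn(beam_size, vocab_size), dim=1)  # fake next-token scores
topk_log_probs, topk_ids = torch.topk(log_probs, beam_size * 2)
# Keeping 2*beam_size candidates per beam guarantees that beam_size
# non-[STOP] hypotheses survive the pruning step.
print(topk_ids.shape)  # torch.Size([4, 8])
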
Example 5
class Train(object):
    def __init__(self, opt):
        self.vocab = Vocab(config.vocab_path, config.vocab_size)
        self.batcher = Batcher(config.train_data_path,
                               self.vocab,
                               mode='train',
                               batch_size=config.batch_size,
                               single_pass=False)
        self.opt = opt
        self.start_id = self.vocab.word2id(data.START_DECODING)
        self.end_id = self.vocab.word2id(data.STOP_DECODING)
        self.pad_id = self.vocab.word2id(data.PAD_TOKEN)
        self.unk_id = self.vocab.word2id(data.UNKNOWN_TOKEN)
        time.sleep(5)

    def save_model(self, iter):
        save_path = config.save_model_path + "/%07d.tar" % iter
        T.save(
            {
                "iter": iter + 1,
                "model_dict": self.model.state_dict(),
                "trainer_dict": self.trainer.state_dict()
            }, save_path)

    def setup_train(self):
        self.model = Model()
        self.model = get_cuda(self.model)
        self.trainer = T.optim.Adam(self.model.parameters(), lr=config.lr)
        start_iter = 0
        if self.opt.load_model is not None:
            load_model_path = os.path.join(config.save_model_path,
                                           self.opt.load_model)
            checkpoint = T.load(load_model_path)
            start_iter = checkpoint["iter"]
            self.model.load_state_dict(checkpoint["model_dict"])
            self.trainer.load_state_dict(checkpoint["trainer_dict"])
            print("Loaded model at " + load_model_path)
        if self.opt.new_lr is not None:
            self.trainer = T.optim.Adam(self.model.parameters(),
                                        lr=self.opt.new_lr)
        return start_iter

    def train_batch_MLE(self, enc_out, enc_hidden, enc_padding_mask, ct_e,
                        extra_zeros, enc_batch_extend_vocab, batch):
        ''' Calculate Negative Log Likelihood Loss for the given batch. In order to reduce exposure bias,
                pass the previous generated token as input with a probability of 0.25 instead of ground truth label
        Args:
        :param enc_out: Outputs of the encoder for all time steps (batch_size, length_input_sequence, 2*hidden_size)
        :param enc_hidden: Tuple containing final hidden state & cell state of encoder. Shape of h & c: (batch_size, hidden_size)
        :param enc_padding_mask: Mask for encoder input; Tensor of size (batch_size, length_input_sequence) with values of 0 for pad tokens & 1 for others
        :param ct_e: encoder context vector for time_step=0 (eq 5 in https://arxiv.org/pdf/1705.04304.pdf)
        :param extra_zeros: Tensor used to extend vocab distribution for pointer mechanism
        :param enc_batch_extend_vocab: Input batch that stores OOV ids
        :param batch: batch object
        '''
        dec_batch, max_dec_len, dec_lens, target_batch = get_dec_data(
            batch)  # Get input and target batches for training the decoder
        step_losses = []
        s_t = (enc_hidden[0], enc_hidden[1])  #Decoder hidden states
        x_t = get_cuda(T.LongTensor(len(enc_out)).fill_(
            self.start_id))  #Input to the decoder
        prev_s = None  #Used for intra-decoder attention (section 2.2 in https://arxiv.org/pdf/1705.04304.pdf)
        sum_temporal_srcs = None  #Used for intra-temporal attention (section 2.1 in https://arxiv.org/pdf/1705.04304.pdf)
        for t in range(min(max_dec_len, config.max_dec_steps)):
            # 0/1 mask: with probability 0.75 feed the ground-truth token, otherwise the previously decoded token
            use_ground_truth = get_cuda((T.rand(len(enc_out)) > 0.25)).long()
            x_t = use_ground_truth * dec_batch[:, t] + (
                1 - use_ground_truth) * x_t  # Select the decoder input accordingly
            x_t = self.model.embeds(x_t)
            final_dist, s_t, ct_e, sum_temporal_srcs, prev_s = self.model.decoder(
                x_t, s_t, enc_out, enc_padding_mask, ct_e, extra_zeros,
                enc_batch_extend_vocab, sum_temporal_srcs, prev_s)
            target = target_batch[:, t]
            log_probs = T.log(final_dist + config.eps)
            step_loss = F.nll_loss(log_probs,
                                   target,
                                   reduction="none",
                                   ignore_index=self.pad_id)
            step_losses.append(step_loss)
            x_t = T.multinomial(final_dist, 1).squeeze(
            )  #Sample words from final distribution which can be used as input in next time step
            is_oov = (x_t >= config.vocab_size
                      ).long()  #Mask indicating whether sampled word is OOV
            x_t = (1 - is_oov) * x_t.detach() + (
                is_oov) * self.unk_id  #Replace OOVs with [UNK] token

        losses = T.sum(
            T.stack(step_losses, 1), 1
        )  #unnormalized losses for each example in the batch; (batch_size)
        batch_avg_loss = losses / dec_lens  #Normalized losses; (batch_size)
        mle_loss = T.mean(batch_avg_loss)  #Average batch loss
        return mle_loss

    def train_batch_RL(self, enc_out, enc_hidden, enc_padding_mask, ct_e,
                       extra_zeros, enc_batch_extend_vocab, article_oovs,
                       greedy):
        '''Generate sentences from decoder entirely using sampled tokens as input. These sentences are used for ROUGE evaluation
        Args
        :param enc_out: Outputs of the encoder for all time steps (batch_size, length_input_sequence, 2*hidden_size)
        :param enc_hidden: Tuple containing final hidden state & cell state of encoder. Shape of h & c: (batch_size, hidden_size)
        :param enc_padding_mask: Mask for encoder input; Tensor of size (batch_size, length_input_sequence) with values of 0 for pad tokens & 1 for others
        :param ct_e: encoder context vector for time_step=0 (eq 5 in https://arxiv.org/pdf/1705.04304.pdf)
        :param extra_zeros: Tensor used to extend vocab distribution for pointer mechanism
        :param enc_batch_extend_vocab: Input batch that stores OOV ids
        :param article_oovs: Batch containing list of OOVs in each example
        :param greedy: If true, performs greedy based sampling, else performs multinomial sampling
        Returns:
        :decoded_strs: List of decoded sentences
        :log_probs: Log probabilities of sampled words
        '''
        s_t = enc_hidden  #Decoder hidden states
        x_t = get_cuda(T.LongTensor(len(enc_out)).fill_(
            self.start_id))  #Input to the decoder
        prev_s = None  #Used for intra-decoder attention (section 2.2 in https://arxiv.org/pdf/1705.04304.pdf)
        sum_temporal_srcs = None  #Used for intra-temporal attention (section 2.1 in https://arxiv.org/pdf/1705.04304.pdf)
        inds = []  #Stores sampled indices for each time step
        decoder_padding_mask = []  #Stores padding masks of generated samples
        log_probs = []  #Stores log probabilites of generated samples
        mask = get_cuda(
            T.LongTensor(len(enc_out)).fill_(1)
        )  #Values that indicate whether [STOP] token has already been encountered; 1 => Not encountered, 0 otherwise

        for t in range(config.max_dec_steps):
            x_t = self.model.embeds(x_t)
            probs, s_t, ct_e, sum_temporal_srcs, prev_s = self.model.decoder(
                x_t, s_t, enc_out, enc_padding_mask, ct_e, extra_zeros,
                enc_batch_extend_vocab, sum_temporal_srcs, prev_s)
            if greedy is False:
                multi_dist = Categorical(probs)
                x_t = multi_dist.sample()  #perform multinomial sampling
                log_prob = multi_dist.log_prob(x_t)
                log_probs.append(log_prob)
            else:
                _, x_t = T.max(probs, dim=1)  #perform greedy sampling
            x_t = x_t.detach()
            inds.append(x_t)
            mask_t = get_cuda(T.zeros(
                len(enc_out)))  #Padding mask of batch for current time step
            mask_t[
                mask ==
                1] = 1  #If [STOP] is not encountered till previous time step, mask_t = 1 else mask_t = 0
            mask[
                (mask == 1) + (x_t == self.end_id) ==
                2] = 0  #If [STOP] is not encountered till previous time step and current word is [STOP], make mask = 0
            decoder_padding_mask.append(mask_t)
            is_oov = (x_t >= config.vocab_size
                      ).long()  #Mask indicating whether sampled word is OOV
            x_t = (1 - is_oov) * x_t + (
                is_oov) * self.unk_id  #Replace OOVs with [UNK] token

        inds = T.stack(inds, dim=1)
        decoder_padding_mask = T.stack(decoder_padding_mask, dim=1)
        if greedy is False:  #If multinomial based sampling, compute log probabilites of sampled words
            log_probs = T.stack(log_probs, dim=1)
            log_probs = log_probs * decoder_padding_mask  #Not considering sampled words with padding mask = 0
            lens = T.sum(decoder_padding_mask,
                         dim=1)  #Length of sampled sentence
            log_probs = T.sum(
                log_probs, dim=1
            ) / lens  # (bs,)  normalized log probability of each sentence
        decoded_strs = []
        for i in range(len(enc_out)):
            id_list = inds[i].cpu().numpy()
            oovs = article_oovs[i]
            S = data.outputids2words(
                id_list, self.vocab,
                oovs)  # Generate sentence corresponding to sampled words
            try:
                end_idx = S.index(data.STOP_DECODING)
                S = S[:end_idx]
            except ValueError:  # no [STOP] token was produced
                pass
            if len(S) < 2:
                # Sentences shorter than 2 words (e.g. ".") make ROUGE fail, so replace them with "xxx"
                S = ["xxx"]
            S = " ".join(S)
            decoded_strs.append(S)

        return decoded_strs, log_probs

    def reward_function(self, decoded_sents, original_sents):
        rouge = Rouge()
        try:
            scores = rouge.get_scores(decoded_sents, original_sents)
        except Exception:
            print("ROUGE failed on the full batch; scoring each pair individually")
            scores = []
            for i in range(len(decoded_sents)):
                try:
                    score = rouge.get_scores(decoded_sents[i],
                                             original_sents[i])
                except Exception:
                    print("Error occured at:")
                    print("decoded_sents:", decoded_sents[i])
                    print("original_sents:", original_sents[i])
                    score = [{"rouge-l": {"f": 0.0}}]
                scores.append(score[0])
        rouge_l_f1 = [score["rouge-l"]["f"] for score in scores]
        rouge_l_f1 = get_cuda(T.FloatTensor(rouge_l_f1))
        return rouge_l_f1

    # def write_to_file(self, decoded, max, original, sample_r, baseline_r, iter):
    #     with open("temp.txt", "w") as f:
    #         f.write("iter:"+str(iter)+"\n")
    #         for i in range(len(original)):
    #             f.write("dec: "+decoded[i]+"\n")
    #             f.write("max: "+max[i]+"\n")
    #             f.write("org: "+original[i]+"\n")
    #             f.write("Sample_R: %.4f, Baseline_R: %.4f\n\n"%(sample_r[i].item(), baseline_r[i].item()))

    def train_one_batch(self, batch, iter):
        enc_batch, enc_lens, enc_padding_mask, enc_batch_extend_vocab, extra_zeros, context = get_enc_data(
            batch)

        enc_batch = self.model.embeds(
            enc_batch)  #Get embeddings for encoder input
        enc_out, enc_hidden = self.model.encoder(enc_batch, enc_lens)

        # -------------------------------Summarization-----------------------
        if self.opt.train_mle == "yes":  #perform MLE training
            mle_loss = self.train_batch_MLE(enc_out, enc_hidden,
                                            enc_padding_mask, context,
                                            extra_zeros,
                                            enc_batch_extend_vocab, batch)
        else:
            mle_loss = get_cuda(T.FloatTensor([0]))
        # --------------RL training-----------------------------------------------------
        if self.opt.train_rl == "yes":  #perform reinforcement learning training
            # multinomial sampling
            sample_sents, RL_log_probs = self.train_batch_RL(
                enc_out,
                enc_hidden,
                enc_padding_mask,
                context,
                extra_zeros,
                enc_batch_extend_vocab,
                batch.art_oovs,
                greedy=False)
            with T.autograd.no_grad():
                # greedy sampling
                greedy_sents, _ = self.train_batch_RL(enc_out,
                                                      enc_hidden,
                                                      enc_padding_mask,
                                                      context,
                                                      extra_zeros,
                                                      enc_batch_extend_vocab,
                                                      batch.art_oovs,
                                                      greedy=True)

            sample_reward = self.reward_function(sample_sents,
                                                 batch.original_abstracts)
            baseline_reward = self.reward_function(greedy_sents,
                                                   batch.original_abstracts)
            # if iter%200 == 0:
            #     self.write_to_file(sample_sents, greedy_sents, batch.original_abstracts, sample_reward, baseline_reward, iter)
            rl_loss = -(
                sample_reward - baseline_reward
            ) * RL_log_probs  #Self-critic policy gradient training (eq 15 in https://arxiv.org/pdf/1705.04304.pdf)
            rl_loss = T.mean(rl_loss)

            batch_reward = T.mean(sample_reward).item()
        else:
            rl_loss = get_cuda(T.FloatTensor([0]))
            batch_reward = 0

        # ------------------------------------------------------------------------------------
        self.trainer.zero_grad()
        (self.opt.mle_weight * mle_loss +
         self.opt.rl_weight * rl_loss).backward()
        self.trainer.step()
        return mle_loss.item(), batch_reward

    def trainIters(self):
        iter = self.setup_train()
        count = mle_total = r_total = 0
        while iter <= config.max_iterations:
            batch = self.batcher.next_batch()
            try:
                mle_loss, r = self.train_one_batch(batch, iter)
            except KeyboardInterrupt:
                print(
                    "-------------------Keyboard Interrupt------------------")
                exit(0)

            mle_total += mle_loss
            r_total += r
            count += 1
            iter += 1

            if iter % 50 == 0:
                mle_avg = mle_total / count
                r_avg = r_total / count
                logger.info("iter:" + str(iter) + "  mle_loss:" +
                            "%.3f" % mle_avg + "  reward:" + "%.4f" % r_avg)
                count = mle_total = r_total = 0

            if iter % 5000 == 0:
                self.save_model(iter)
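
A small numeric illustration of the self-critical policy-gradient loss computed in train_one_batch above (eq. 15 in https://arxiv.org/pdf/1705.04304.pdf); all rewards and log-probabilities below are made-up values:

import torch as T

sample_reward = T.tensor([0.32, 0.45, 0.28])    # ROUGE-L F1 of multinomial samples
baseline_reward = T.tensor([0.30, 0.50, 0.28])  # ROUGE-L F1 of greedy decodes
RL_log_probs = T.tensor([-12.0, -9.5, -11.2])   # normalized log-prob of each sample
rl_loss = T.mean(-(sample_reward - baseline_reward) * RL_log_probs)
# Samples that beat the greedy baseline get their probability pushed up,
# samples that do worse get pushed down.
print(rl_loss.item())
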
Example 6
class Train(object):
    def __init__(self):
        if config.is_hierarchical:
            raise Exception("Hierarchical PGN-AMI not supported!")

        self.vocab = Vocab(config.vocab_path, config.vocab_size)
        self.pad_id = self.vocab.word2id(PAD_TOKEN)
        self.start_id = self.vocab.word2id(START_DECODING)
        self.stop_id = self.vocab.word2id(STOP_DECODING)

        self.print_interval = config.print_interval

        train_dir = config.train_dir
        if not os.path.exists(train_dir):
            os.mkdir(train_dir)

        self.model_dir = train_dir
        if not os.path.exists(self.model_dir):
            os.mkdir(self.model_dir)

    def save_model(self, running_avg_loss, iter):
        state = {
            'iter': iter,
            'encoder_state_dict': self.model.encoder.state_dict(),
            'decoder_state_dict': self.model.decoder.state_dict(),
            'reduce_state_dict': self.model.reduce_state.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'current_loss': running_avg_loss
        }
        model_save_path = os.path.join(self.model_dir,
                                       'iter{}.pt'.format(iter))
        torch.save(state, model_save_path)

    def setup_train(self, model_file_path=None):
        self.model = Model(model_file_path)

        params = list(self.model.encoder.parameters()) + list(self.model.decoder.parameters()) + \
                 list(self.model.reduce_state.parameters())
        initial_lr = config.lr_coverage if config.is_coverage else config.lr
        self.optimizer = Adagrad(
            params,
            lr=initial_lr,
            initial_accumulator_value=config.adagrad_init_acc)

        start_iter, start_loss = 0, 0

        if model_file_path is not None:
            state = torch.load(model_file_path,
                               map_location=lambda storage, location: storage)
            start_iter = state['iter']
            start_loss = state['current_loss']

            if not config.is_coverage:
                self.optimizer.load_state_dict(state['optimizer'])
                if use_cuda:
                    for state in self.optimizer.state.values():
                        for k, v in state.items():
                            if torch.is_tensor(v):
                                state[k] = v.cuda()

        return start_iter, start_loss

    def train_one_batch(self, ami_data, idx):
        # enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, c_t_1, coverage = \
        #     get_ami_input_from_batch(batch, use_cuda)
        # dec_batch, dec_padding_mask, max_dec_len, dec_lens_var, target_batch = \
        #     get_ami_output_from_batch(batch, use_cuda)

        enc_pack, dec_pack = get_a_batch(ami_data,
                                         idx,
                                         self.vocab,
                                         config.batch_size,
                                         config.max_enc_steps,
                                         config.max_dec_steps,
                                         self.start_id,
                                         self.stop_id,
                                         self.pad_id,
                                         sum_type='short',
                                         use_cuda=use_cuda)
        enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, c_t_1, coverage = enc_pack
        dec_batch, dec_padding_mask, max_dec_len, dec_lens_var, target_batch = dec_pack

        self.optimizer.zero_grad()

        encoder_outputs, encoder_feature, encoder_hidden = self.model.encoder(
            enc_batch, enc_lens)
        s_t_1 = self.model.reduce_state.forward1(encoder_hidden)

        step_losses = []
        for di in range(min(max_dec_len, config.max_dec_steps)):
            y_t_1 = dec_batch[:, di]  # Teacher forcing

            final_dist, s_t_1, c_t_1, attn_dist, p_gen, next_coverage = self.model.decoder.forward1(
                y_t_1, s_t_1, encoder_outputs, encoder_feature,
                enc_padding_mask, c_t_1, extra_zeros, enc_batch_extend_vocab,
                coverage, di)

            target = target_batch[:, di]
            gold_probs = torch.gather(final_dist, 1,
                                      target.unsqueeze(1)).squeeze()
            step_loss = -torch.log(gold_probs + config.eps)
            if config.is_coverage:
                step_coverage_loss = torch.sum(torch.min(attn_dist, coverage),
                                               1)
                step_loss = step_loss + config.cov_loss_wt * step_coverage_loss
                coverage = next_coverage

            step_mask = dec_padding_mask[:, di]
            step_loss = step_loss * step_mask
            step_losses.append(step_loss)

        sum_losses = torch.sum(torch.stack(step_losses, 1), 1)
        batch_avg_loss = sum_losses / dec_lens_var
        loss = torch.mean(batch_avg_loss)

        loss.backward()

        clip_grad_norm_(self.model.encoder.parameters(), config.max_grad_norm)
        clip_grad_norm_(self.model.decoder.parameters(), config.max_grad_norm)
        clip_grad_norm_(self.model.reduce_state.parameters(),
                        config.max_grad_norm)

        self.optimizer.step()

        return loss.item()

    def trainIters(self, n_iters, model_file_path=None):
        iter, running_avg_loss = self.setup_train(model_file_path)
        sys.stdout.flush()

        ami_data = load_ami_data('train')
        valid_data = load_ami_data('valid')
        # top up the training data to 100 meetings with 6 shuffled validation meetings
        random.shuffle(valid_data)
        ami_data.extend(valid_data[:6])
        valid_data = valid_data[6:]

        num_batches = len(ami_data)
        idx = 0

        # validation & stopping
        best_valid_loss = float('inf')
        stop_counter = 0

        while iter < n_iters:
            if idx == 0:
                print("shuffle training data")
                random.shuffle(ami_data)

            loss = self.train_one_batch(ami_data, idx)

            running_avg_loss = calc_running_avg_loss(loss, running_avg_loss,
                                                     iter)

            iter += 1
            idx += config.batch_size
            if idx >= num_batches: idx = 0

            if iter % self.print_interval == 0:
                print("[{}] iter {}, loss: {:.5f}".format(
                    str(datetime.now()), iter, loss))
                sys.stdout.flush()

            if iter % config.save_every == 0:
                self.save_model(running_avg_loss, iter)

            if iter % config.eval_every == 0:
                valid_loss = self.run_eval(valid_data)
                print("valid_loss = {:.5f}".format(valid_loss))
                if valid_loss < best_valid_loss:
                    stop_counter = 0
                    best_valid_loss = valid_loss
                    print("VALID better")
                else:
                    stop_counter += 1
                    print(
                        "VALID NOT better, counter = {}".format(stop_counter))
                    if stop_counter == config.stop_after:
                        print("Stop training")
                        return

        print("Finished training!")

    def eval_one_batch(self, eval_data, idx):

        enc_pack, dec_pack = get_a_batch(eval_data,
                                         idx,
                                         self.vocab,
                                         1,
                                         config.max_enc_steps,
                                         config.max_dec_steps,
                                         self.start_id,
                                         self.stop_id,
                                         self.pad_id,
                                         sum_type='short',
                                         use_cuda=use_cuda)

        enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, c_t_1, coverage = enc_pack
        dec_batch, dec_padding_mask, max_dec_len, dec_lens_var, target_batch = dec_pack

        encoder_outputs, encoder_feature, encoder_hidden = self.model.encoder(
            enc_batch, enc_lens)
        s_t_1 = self.model.reduce_state.forward1(encoder_hidden)

        step_losses = []
        for di in range(min(max_dec_len, config.max_dec_steps)):
            y_t_1 = dec_batch[:, di]  # Teacher forcing
            final_dist, s_t_1, c_t_1, attn_dist, p_gen, next_coverage = self.model.decoder.forward1(
                y_t_1, s_t_1, encoder_outputs, encoder_feature,
                enc_padding_mask, c_t_1, extra_zeros, enc_batch_extend_vocab,
                coverage, di)

            target = target_batch[:, di]
            gold_probs = torch.gather(final_dist,
                                      dim=1,
                                      index=target.unsqueeze(1)).squeeze()
            step_loss = -torch.log(gold_probs + config.eps)
            if config.is_coverage:
                step_coverage_loss = torch.sum(torch.min(attn_dist, coverage),
                                               1)
                step_loss = step_loss + config.cov_loss_wt * step_coverage_loss
                coverage = next_coverage

            step_mask = dec_padding_mask[:, di]
            step_loss = step_loss * step_mask
            step_losses.append(step_loss)

        sum_step_losses = torch.sum(torch.stack(step_losses, 1), 1)
        batch_avg_loss = sum_step_losses / dec_lens_var
        loss = torch.mean(batch_avg_loss)

        return loss.data.item()

    def run_eval(self, eval_data):
        running_avg_loss, iter = 0, 0
        batch_losses = []
        num_batches = len(eval_data)
        print("valid data size = {}".format(num_batches))
        for idx in range(num_batches):
            loss = self.eval_one_batch(eval_data, idx)
            batch_losses.append(loss)
            running_avg_loss = calc_running_avg_loss(loss, running_avg_loss,
                                                     iter)
            print("#", end="")
            sys.stdout.flush()
        print()

        avg_loss = sum(batch_losses) / len(batch_losses)
        return avg_loss
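
The per-step NLL in train_one_batch and eval_one_batch above picks out the probability of the gold token with torch.gather. A minimal standalone illustration with a batch of 2 and a 5-word vocabulary; the probabilities are made up:

import torch

final_dist = torch.tensor([[0.1, 0.6, 0.1, 0.1, 0.1],
                           [0.2, 0.2, 0.2, 0.2, 0.2]])
target = torch.tensor([1, 3])
gold_probs = torch.gather(final_dist, 1, target.unsqueeze(1)).squeeze()
print(gold_probs)  # tensor([0.6000, 0.2000])
step_loss = -torch.log(gold_probs + 1e-12)  # a small eps keeps log() finite, as config.eps does above
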
Example 7
class BeamSearch(object):
    def __init__(self, model_file_path):
        model_name = os.path.basename(model_file_path)
        self._decode_dir = os.path.join(config.decode_dir, model_name)
        self._decode_dir = os.path.splitext(self._decode_dir)[0]
        self._rouge_ref_dir = os.path.join(self._decode_dir, 'rouge_ref')
        self._rouge_dec_dir = os.path.join(self._decode_dir, 'rouge_dec_dir')

        for p in [self._decode_dir, self._rouge_ref_dir, self._rouge_dec_dir]:
            Path(p).mkdir(parents=True, exist_ok=True)

        self.vocab = Vocab(config.vocab_path, config.vocab_size)
        # self.batcher = Batcher(config.decode_data_path, self.vocab, mode='decode',
        #                        batch_size=config.beam_size, single_pass=True)
        # time.sleep(15)
        self.get_batches(config.decode_pk_path)

        self.model = Model(model_file_path, is_eval=True)

    def sort_beams(self, beams):
        return sorted(beams, key=lambda h: h.avg_log_prob, reverse=True)

    def get_batches(self, path):
        """
        load batches dumped by pickle
        see batch_saver.py for more information
        """
        with open(path, 'rb') as f:
            batches = pickle.load(f, encoding="bytes")
        self.batches = batches

        print("loaded: {}".format(path))

    def if_already_exists(self, idx):
        ref_file = os.path.join(self._rouge_ref_dir,
                                "{}_reference.txt".format(idx))
        decoded_file = os.path.join(self._rouge_dec_dir,
                                    "{}_decoded.txt".format(idx))
        return os.path.isfile(ref_file) and os.path.isfile(decoded_file)

    def decode(self, file_id_start, file_id_stop):
        if file_id_stop > MAX_TEST_ID: file_id_stop = MAX_TEST_ID

        # while batch is not None:

        # shuffle the ids so several (possibly slower CPU) machines can each pick up remaining files and redo any that failed
        idx_list = [i for i in range(file_id_start, file_id_stop)]
        random.shuffle(idx_list)

        for idx in idx_list:

            # check if this is written already
            if self.if_already_exists(idx):
                # print("ID {} already exists".format(idx))
                continue

            # batch = self.batcher.next_batch()
            batch = self.batches[idx]

            # Run beam search to get best Hypothesis
            best_summary = self.beam_search(batch)

            # Extract the output ids from the hypothesis and convert back to words
            output_ids = [int(t) for t in best_summary.tokens[1:]]
            decoded_words = data.outputids2words(
                output_ids, self.vocab,
                (batch.art_oovs[0] if config.pointer_gen else None))

            # Remove the [STOP] token from decoded_words, if necessary
            try:
                fst_stop_idx = decoded_words.index(data.STOP_DECODING)
                decoded_words = decoded_words[:fst_stop_idx]
            except ValueError:  # no [STOP] token was produced
                pass

            original_abstract_sents = batch.original_abstracts_sents[0]

            write_for_rouge(original_abstract_sents, decoded_words, idx,
                            self._rouge_ref_dir, self._rouge_dec_dir)

            print("decoded idx = {}".format(idx))
        print("Finished decoding idx [{},{})".format(file_id_start,
                                                     file_id_stop))

        # print("Starting ROUGE eval...")
        # results_dict = rouge_eval(self._rouge_ref_dir, self._rouge_dec_dir)
        # rouge_log(results_dict, self._decode_dir)

    def beam_search(self, batch):
        # batch should have only one example
        enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, c_t_0, coverage_t_0 = \
            get_input_from_batch(batch, use_cuda)

        if not config.is_hierarchical:

            encoder_outputs, encoder_feature, encoder_hidden = self.model.encoder(
                enc_batch, enc_lens)
            s_t_0 = self.model.reduce_state(encoder_hidden)

        else:
            stop_id = self.vocab.word2id('.')
            enc_sent_pos = get_sent_position(enc_batch, stop_id)

            encoder_outputs, encoder_feature, encoder_hidden, sent_enc_outputs, sent_enc_feature, sent_enc_hidden, sent_enc_padding_mask = \
                                                                    self.model.encoder(enc_batch, enc_lens, enc_sent_pos)
            s_t_0, _ = self.model.reduce_state(encoder_hidden, sent_enc_hidden)

        dec_h, dec_c = s_t_0  # 1 x 2*hidden_size
        dec_h = dec_h.squeeze()
        dec_c = dec_c.squeeze()

        # decoder batch preparation, it has beam_size example initially everything is repeated
        beams = [
            Beam(tokens=[self.vocab.word2id(data.START_DECODING)],
                 log_probs=[0.0],
                 state=(dec_h[0], dec_c[0]),
                 context=c_t_0[0],
                 coverage=(coverage_t_0[0] if config.is_coverage else None))
            for _ in range(config.beam_size)
        ]

        results = []
        steps = 0
        while steps < config.max_dec_steps and len(results) < config.beam_size:
            latest_tokens = [h.latest_token for h in beams]
            latest_tokens = [t if t < self.vocab.size() else self.vocab.word2id(data.UNKNOWN_TOKEN) \
                             for t in latest_tokens]
            y_t_1 = Variable(torch.LongTensor(latest_tokens))
            if use_cuda: y_t_1 = y_t_1.cuda()
            all_state_h = []
            all_state_c = []

            all_context = []

            for h in beams:
                state_h, state_c = h.state
                all_state_h.append(state_h)
                all_state_c.append(state_c)

                all_context.append(h.context)

            s_t_1 = (torch.stack(all_state_h,
                                 0).unsqueeze(0), torch.stack(all_state_c,
                                                              0).unsqueeze(0))
            c_t_1 = torch.stack(all_context, 0)

            coverage_t_1 = None
            if config.is_coverage:
                all_coverage = []
                for h in beams:
                    all_coverage.append(h.coverage)
                coverage_t_1 = torch.stack(all_coverage, 0)

            if not config.is_hierarchical:
                final_dist, s_t, c_t, attn_dist, p_gen, coverage_t = self.model.decoder(
                    y_t_1, s_t_1, encoder_outputs, encoder_feature,
                    enc_padding_mask, c_t_1, extra_zeros,
                    enc_batch_extend_vocab, coverage_t_1, steps)

            else:
                final_dist, s_t, c_t, attn_dist, p_gen, coverage_t = self.model.decoder(
                    y_t_1, s_t_1, enc_sent_pos, encoder_outputs,
                    encoder_feature, enc_padding_mask, None, sent_enc_outputs,
                    sent_enc_feature, sent_enc_padding_mask, c_t_1,
                    extra_zeros, enc_batch_extend_vocab, coverage_t_1, steps)

            log_probs = torch.log(final_dist)
            topk_log_probs, topk_ids = torch.topk(log_probs,
                                                  config.beam_size * 2)

            dec_h, dec_c = s_t
            dec_h = dec_h.squeeze()
            dec_c = dec_c.squeeze()

            all_beams = []
            num_orig_beams = 1 if steps == 0 else len(beams)
            for i in range(num_orig_beams):
                h = beams[i]
                state_i = (dec_h[i], dec_c[i])
                context_i = c_t[i]
                coverage_i = (coverage_t[i] if config.is_coverage else None)

                for j in range(config.beam_size *
                               2):  # for each of the top 2*beam_size hyps:
                    new_beam = h.extend(token=topk_ids[i, j].item(),
                                        log_prob=topk_log_probs[i, j].item(),
                                        state=state_i,
                                        context=context_i,
                                        coverage=coverage_i)
                    all_beams.append(new_beam)

            beams = []
            for h in self.sort_beams(all_beams):
                if h.latest_token == self.vocab.word2id(data.STOP_DECODING):
                    if steps >= config.min_dec_steps:
                        results.append(h)
                else:
                    beams.append(h)
                if len(beams) == config.beam_size or len(
                        results) == config.beam_size:
                    break

            steps += 1

        if len(results) == 0:
            results = beams

        beams_sorted = self.sort_beams(results)

        return beams_sorted[0]
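
The get_batches helper above loads a pickle produced by batch_saver.py, which is not included here. A minimal sketch of such a dumper, assuming the same project-level Batcher, Vocab and config used throughout these examples (single-pass decode mode, next_batch() returning None once the data is exhausted); the function name and out_path argument are illustrative:

import pickle
import time

def dump_decode_batches(out_path):
    # Assumes Batcher, Vocab and config are the same project modules the
    # examples above use; out_path would typically be config.decode_pk_path.
    vocab = Vocab(config.vocab_path, config.vocab_size)
    batcher = Batcher(config.decode_data_path, vocab, mode='decode',
                      batch_size=config.beam_size, single_pass=True)
    time.sleep(15)  # same warm-up the examples use after starting the batcher

    batches = []
    batch = batcher.next_batch()
    while batch is not None:
        batches.append(batch)
        batch = batcher.next_batch()

    with open(out_path, 'wb') as f:
        pickle.dump(batches, f)
    print("dumped {} batches to {}".format(len(batches), out_path))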
Ejemplo n.º 8
0
class Evaluate(object):

    def __init__(self, data_path, opt, batch_size=config.batch_size):

        self.vocab = Vocab(config.vocab_path, config.vocab_size)
        self.batcher = Batcher(data_path, self.vocab, mode='eval', batch_size=batch_size,
                               single_pass=True)

        self.opt = opt
        time.sleep(5)

    def setup_valid(self):
        self.model = Model()
        self.model = get_cuda(self.model)
        checkpoint = torch.load(os.path.join(config.save_model_path,
                                             self.opt.load_model))
        # Load the model checkpoint saved during training
        self.model.load_state_dict(checkpoint['model_dict'])

    def print_original_predicted(self, decoded_sents, ref_sents, article_sents,
                                 loadfile):

        # Note: this filename handling may need a closer look when debugging.
        filename = 'test_' + loadfile.split('.')[0] + '.txt'

        with open(os.path.join('data', filename), 'w') as f:
            for i in range(len(decoded_sents)):
                f.write('article: ' + article_sents[i] + '\n')
                f.write('reference:' + ref_sents[i] + '\n')
                f.write('decoder:' + decoded_sents[i] + '\n')

    def evaluate_batch(self, print_sents=False):

        self.setup_valid()
        batch = self.batcher.next_batch()

        start_id = self.vocab.word2id(data.START_DECODING)
        end_id = self.vocab.word2id(data.STOP_DECODING)
        unk_id = self.vocab.word2id(data.UNKNOWN_TOKEN)

        decoded_sents = []
        ref_sents = []
        article_sents = []
        rouge = Rouge()

        batch_number = 0

        while batch is not None:

            enc_batch, enc_lens, enc_padding_mask, enc_batch_extend_vocab, \
                extra_zeros, ct_e = get_enc_data(batch)

            with torch.no_grad():
                enc_batch = self.model.embeds(enc_batch)
                enc_out, enc_hidden = self.model.encoder(enc_batch, enc_lens)

            with torch.no_grad():
                pred_ids = beam_search(enc_hidden, enc_out, enc_padding_mask, ct_e,
                                       extra_zeros, enc_batch_extend_vocab, self.model, start_id, end_id, unk_id)

            for i in range(len(pred_ids)):
                # outputids2words returns a list of words.
                decoded_words = data.outputids2words(pred_ids[i], self.vocab, batch.art_oovs[i])

                if len(decoded_words) < 2:
                    decoded_words = 'xxx'
                else:
                    decoded_words = ' '.join(decoded_words)

                decoded_sents.append(decoded_words)
                summary = batch.original_summarys[i]
                article = batch.original_articles[i]
                ref_sents.append(summary)
                article_sents.append(article)

            batch = self.batcher.next_batch()
            batch_number += 1

            # Evaluate at most 100 batches.
            if batch_number >= 100:
                break

        load_file = self.opt.load_model

        if print_sents:
            self.print_original_predicted(decoded_sents, ref_sents, article_sents, load_file)

        scores = rouge.get_scores(decoded_sents, ref_sents, avg=True)

        if self.opt.task == 'test':
            print(load_file, 'scores:', scores)
            sys.stdout.flush()
        else:
            rouge_l = scores['rouge-l']['f']
            print(load_file, 'rouge-l:', '%.4f' % rouge_l)
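
This evaluator, like several below, leans on the rouge package's Rouge.get_scores. A quick standalone check of the output shape that the averaging code iterates over; the hypothesis and reference strings are made up for illustration:

from rouge import Rouge

hyps = ["the cat sat on the mat", "a quick brown fox"]
refs = ["the cat was sitting on the mat", "the quick brown fox jumps"]

rouge = Rouge()
# Per-example scores: a list of dicts keyed by 'rouge-1' / 'rouge-2' / 'rouge-l',
# each holding 'f', 'p' and 'r' values.
scores = rouge.get_scores(hyps, refs)
rouge_l = sum(s["rouge-l"]["f"] for s in scores) / len(scores)

# avg=True collapses the list into a single dict of averaged scores instead.
avg_scores = rouge.get_scores(hyps, refs, avg=True)
print("%.4f %.4f" % (rouge_l, avg_scores["rouge-l"]["f"]))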
Ejemplo n.º 9
0
class TrainInstance:
    def __init__(self, args):
        self.hparams = hp()
        self.model = Model(self.hparams)
        self.vocab = Vocab(config.vocab_path, self.hparams.vocab_size)
        self.batcher = Batcher(config.train_data_path,
                               self.vocab,
                               mode='train',
                               batch_size=self.hparams.batch_size,
                               single_pass=False)
        self.args = args
        self.start_id = self.vocab.word2id(data.START_DECODING)
        self.end_id = self.vocab.word2id(data.STOP_DECODING)
        self.pad_id = self.vocab.word2id(data.PAD_TOKEN)
        self.unk_id = self.vocab.word2id(data.UNKNOWN_TOKEN)
        time.sleep(3)

    def save_model(self, iter):
        save_path = config.save_model_path + "/%07d.tar" % iter
        T.save(
            {
                "iter": iter + 1,
                "model_dict": self.model.state_dict(),
                "trainer_dict": self.trainer.state_dict()
            }, save_path)

    def setup_train(self):
        self.trainer = T.optim.Adam(self.model.parameters(), lr=config.lr)
        start_iter = 0
        if self.args.load_model is not None:
            load_model_path = os.path.join(config.save_model_path,
                                           self.args.load_model)
            checkpoint = T.load(load_model_path)
            start_iter = checkpoint["iter"]
            self.model.load_state_dict(checkpoint["model_dict"])
            self.trainer.load_state_dict(checkpoint["trainer_dict"])
            print("Loaded model at " + load_model_path)
        if self.args.new_lr is not None:
            self.trainer = T.optim.Adam(self.model.parameters(),
                                        lr=self.args.new_lr)
        return start_iter

    def decoder(self, enc_out, mask, batch):
        SOS_IDX = 2
        dec_batch, max_dec_len, dec_lens, target_batch = get_dec_data(batch)
        dec_batch = torch.cuda.LongTensor(dec_batch)
        dec_in = torch.cuda.LongTensor([SOS_IDX] *
                                       self.hparams.batch_size).unsqueeze(1)
        PAD_IDX = 1
        loss = 0
        try:
            # for t in range(min(max_dec_len, config.max_dec_steps)):
            for t in range(dec_batch.size(1)):
                dec_out = self.model.decoder(enc_out, dec_in, mask)
                loss += F.nll_loss(dec_out,
                                   dec_batch[:, t],
                                   reduction="sum",
                                   ignore_index=PAD_IDX)
                dec_in = torch.cat((dec_in, dec_batch[:, t].unsqueeze(1)), 1)
                # if VERBOSE:
                # 	for i, j in enumerate(dec_out.data.topk(1)[1]):
                # 		pred[i].append(scalar(j))
            loss /= dec_batch.data.gt(
                0).sum().float()  # divide by the number of unpadded tokens
            loss.backward()
            print("Loss: {}".format(mm.scalar(loss)))
            return loss
        except Exception:
            # skip the update for this batch (e.g. a malformed batch)
            return None

    def train_one_batch(self, batch):
        enc_batch, enc_lens, enc_padding_mask, enc_batch_extend_vocab, extra_zeros, context = get_enc_data(
            batch)

        enc_batch = torch.cuda.LongTensor(enc_batch)

        PAD_IDX = 1

        def mask_pad(x):
            return x.data.eq(PAD_IDX).view(self.hparams.batch_size, 1, 1, -1)

        self.trainer.zero_grad()
        mask = mask_pad(enc_batch)

        enc_out = self.model.encoder(enc_batch, mask)
        dec_out = self.decoder(enc_out, mask, batch)

        if dec_out is not None:
            # mle_loss = self.train_batch_MLE(enc_out, enc_hidden, enc_padding_mask, context, extra_zeros, enc_batch_extend_vocab, batch)
            # (self.opt.mle_weight * mle_loss + self.opt.rl_weight * rl_loss).backward()
            self.trainer.step()

        # return mle_loss.item(), batch_reward

    def train_iters(self):
        iter = self.setup_train()
        # count = mle_total = r_total = 0
        count = 0
        print("Training")
        while iter <= config.max_iterations:
            batch = self.batcher.next_batch()
            if count == 1:
                count += 1
                continue
            # try:
            # mle_loss, r = self.train_one_batch(batch, iter)
            print("Batch {}".format(count))
            self.train_one_batch(batch)
            # except KeyboardInterrupt:
            # print("-------------------Keyboard Interrupt------------------")
            # exit(0)

            # mle_total += mle_loss
            # r_total += r
            count += 1
            iter += 1

            if iter % 1000 == 0:
                # mle_avg = mle_total / count
                # r_avg = r_total / count
                # print("iter:", iter, "mle_loss:", "%.3f" % mle_avg, "reward:", "%.4f" % r_avg)
                # count = mle_total = r_total = 0
                count = 0
                print("iter: {}".format(iter))

            if iter % 5000 == 0:
                self.save_model(iter)
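
The mask_pad closure in train_one_batch reshapes the padding mask to (batch, 1, 1, src_len) so it can broadcast across attention heads and query positions. A tiny self-contained check of that shape logic (the toy tensors are illustrative, not taken from the example):

import torch

PAD_IDX = 1
batch_size, src_len = 2, 5
enc_batch = torch.tensor([[4, 7, 9, PAD_IDX, PAD_IDX],
                          [5, 6, PAD_IDX, PAD_IDX, PAD_IDX]])

# Same reshape as mask_pad: True wherever the source token is padding.
mask = enc_batch.eq(PAD_IDX).view(batch_size, 1, 1, -1)
print(mask.shape)  # torch.Size([2, 1, 1, 5])

# Broadcasts against attention scores shaped (batch, heads, tgt_len, src_len).
scores = torch.randn(batch_size, 8, 3, src_len)
scores = scores.masked_fill(mask, float("-inf"))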
Ejemplo n.º 10
0
class Evaluate(object):
    def __init__(self, model_file_path):
        self.vocab = Vocab(config.vocab_path, config.vocab_size)
        self.batcher = Batcher(config.eval_data_path,
                               self.vocab,
                               mode='eval',
                               batch_size=config.batch_size,
                               single_pass=True)
        # time.sleep(15)
        model_name = os.path.basename(model_file_path)
        # eval_dir = os.path.join(config.log_root, 'eval_%s' % (model_name))
        # if not os.path.exists(eval_dir):
        #     os.mkdir(eval_dir)
        # self.summary_writer = tf.summary.FileWriter(eval_dir)
        self.model_file_path = model_file_path
        self.model = Model(model_file_path, is_eval=True)

    def eval_one_batch(self, batch):
        enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, c_t_1, coverage = \
            get_input_from_batch(batch, use_cuda)
        dec_batch, dec_padding_mask, max_dec_len, dec_lens_var, target_batch = \
            get_output_from_batch(batch, use_cuda)

        if not config.is_hierarchical:
            encoder_outputs, encoder_feature, encoder_hidden = self.model.encoder(
                enc_batch, enc_lens)
            s_t_1 = self.model.reduce_state.forward1(encoder_hidden)

        else:
            stop_id = self.vocab.word2id('.')
            enc_sent_pos = get_sent_position(enc_batch, stop_id)
            dec_sent_pos = get_sent_position(dec_batch, stop_id)

            encoder_outputs, encoder_feature, encoder_hidden, sent_enc_outputs, sent_enc_feature, sent_enc_hidden, sent_enc_padding_mask = \
                                                                    self.model.encoder(enc_batch, enc_lens, enc_sent_pos)
            s_t_1, sent_s_t_1 = self.model.reduce_state(
                encoder_hidden, sent_enc_hidden)

        step_losses = []
        for di in range(min(max_dec_len, config.max_dec_steps)):
            y_t_1 = dec_batch[:, di]  # Teacher forcing
            if not config.is_hierarchical:
                final_dist, s_t_1, c_t_1, attn_dist, p_gen, next_coverage = self.model.decoder.forward1(
                    y_t_1, s_t_1, encoder_outputs, encoder_feature,
                    enc_padding_mask, c_t_1, extra_zeros,
                    enc_batch_extend_vocab, coverage, di)

            else:

                final_dist, s_t_1, c_t_1, attn_dist, p_gen, next_coverage = self.model.decoder(
                    y_t_1, s_t_1, enc_sent_pos, encoder_outputs,
                    encoder_feature, enc_padding_mask, sent_s_t_1,
                    sent_enc_outputs, sent_enc_feature, sent_enc_padding_mask,
                    c_t_1, extra_zeros, enc_batch_extend_vocab, coverage, di)

            target = target_batch[:, di]
            gold_probs = torch.gather(final_dist,
                                      dim=1,
                                      index=target.unsqueeze(1)).squeeze()
            step_loss = -torch.log(gold_probs + config.eps)
            if config.is_coverage:
                step_coverage_loss = torch.sum(torch.min(attn_dist, coverage),
                                               1)
                step_loss = step_loss + config.cov_loss_wt * step_coverage_loss
                coverage = next_coverage

            step_mask = dec_padding_mask[:, di]
            step_loss = step_loss * step_mask
            step_losses.append(step_loss)

        sum_step_losses = torch.sum(torch.stack(step_losses, 1), 1)
        batch_avg_loss = sum_step_losses / dec_lens_var
        loss = torch.mean(batch_avg_loss)

        return loss.data.item()

    def run_eval(self):
        running_avg_loss, iter = 0, 0
        batch_losses = []
        # while batch is not None:
        for _ in range(835):
            batch = self.batcher.next_batch()

            loss = self.eval_one_batch(batch)
            batch_losses.append(loss)
            # running_avg_loss = calc_running_avg_loss(loss, running_avg_loss, self.summary_writer, iter)
            running_avg_loss = calc_running_avg_loss(loss, running_avg_loss,
                                                     iter)
            iter += 1

            # if iter % 100 == 0:
            #     self.summary_writer.flush()

            print_interval = 10
            if iter % print_interval == 0:
                print("[{}] iter {}, loss: {:.5f}".format(
                    str(datetime.now()), iter, loss))

        avg_loss = sum(batch_losses) / len(batch_losses)
        print("Finished Eval for Model {}: Avg Loss = {:.5f}".format(
            self.model_file_path, avg_loss))
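
run_eval relies on calc_running_avg_loss, which is not shown in these examples; in pointer-generator codebases it is typically an exponential moving average of the per-batch loss. A minimal sketch under that assumption (the decay value is an assumption as well):

def calc_running_avg_loss(loss, running_avg_loss, step, decay=0.99):
    # Assumed behaviour: exponential moving average seeded with the first
    # observed loss; `step` is kept only to mirror the call signature above.
    if running_avg_loss == 0:
        return loss
    return running_avg_loss * decay + (1 - decay) * loss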
Ejemplo n.º 11
0
class BeamSearch(object):
    def __init__(self, model_file_path):
        model_name = os.path.basename(model_file_path)
        self._decode_dir = os.path.join(config.log_root, 'decode_%s' % (model_name))
        self._rouge_ref_dir = os.path.join(self._decode_dir, 'rouge_ref')
        self._rouge_dec_dir = os.path.join(self._decode_dir, 'rouge_dec_dir')
        for p in [self._decode_dir, self._rouge_ref_dir, self._rouge_dec_dir]:
            if not os.path.exists(p):
                os.mkdir(p)

        self.vocab = Vocab(config.vocab_path, config.vocab_size)
        self.batcher = Batcher(config.decode_data_path, self.vocab, mode='decode',
                               batch_size=config.beam_size, single_pass=True)
        time.sleep(15)

        self.model = Model(model_file_path, is_eval=True)

    def sort_beams(self, beams):
        return sorted(beams, key=lambda h: h.avg_log_prob, reverse=True)


    def decode(self):
        start = time.time()
        counter = 0
        batch = self.batcher.next_batch()
        while batch is not None:
            # Run beam search to get best Hypothesis
            best_summary = self.beam_search(batch)

            # Extract the output ids from the hypothesis and convert back to words
            output_ids = [int(t) for t in best_summary.tokens[1:]]
            decoded_words = data.outputids2words(output_ids, self.vocab,
                                                 (batch.art_oovs[0] if config.pointer_gen else None))

            # Remove the [STOP] token from decoded_words, if necessary
            try:
                fst_stop_idx = decoded_words.index(data.STOP_DECODING)
                decoded_words = decoded_words[:fst_stop_idx]
            except ValueError:
                pass  # no [STOP] token was generated

            original_abstract_sents = batch.original_abstracts_sents[0]

            write_for_rouge(original_abstract_sents, decoded_words, counter,
                            self._rouge_ref_dir, self._rouge_dec_dir)
            counter += 1
            if counter % 1000 == 0:
                print('%d example in %d sec'%(counter, time.time() - start))
                start = time.time()

            batch = self.batcher.next_batch()

        print("Decoder has finished reading dataset for single_pass.")
        print("Now starting ROUGE eval...")
        results_dict = rouge_eval(self._rouge_ref_dir, self._rouge_dec_dir)
        rouge_log(results_dict, self._decode_dir)


    def beam_search(self, batch):
        #batch should have only one example
        enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, c_t_0, coverage_t_0 = \
            get_input_from_batch(batch, use_cuda)

        encoder_outputs, encoder_hidden, max_encoder_output = self.model.encoder(enc_batch, enc_lens)
        s_t_0 = self.model.reduce_state(encoder_hidden)

        if config.use_maxpool_init_ctx:
            c_t_0 = max_encoder_output

        dec_h, dec_c = s_t_0 # 1 x 2*hidden_size
        dec_h = dec_h.squeeze()
        dec_c = dec_c.squeeze()

        #decoder batch preparation, it has beam_size example initially everything is repeated
        beams = [Beam(tokens=[self.vocab.word2id(data.START_DECODING)],
                      log_probs=[0.0],
                      state=(dec_h[0], dec_c[0]),
                      context = c_t_0[0],
                      coverage=(coverage_t_0[0] if config.is_coverage else None))
                 for _ in range(config.beam_size)]
        results = []
        steps = 0
        while steps < config.max_dec_steps and len(results) < config.beam_size:
            latest_tokens = [h.latest_token for h in beams]
            latest_tokens = [t if t < self.vocab.size() else self.vocab.word2id(data.UNKNOWN_TOKEN) \
                             for t in latest_tokens]
            y_t_1 = Variable(torch.LongTensor(latest_tokens))
            if use_cuda:
                y_t_1 = y_t_1.cuda()
            all_state_h = []
            all_state_c = []

            all_context = []

            for h in beams:
                state_h, state_c = h.state
                all_state_h.append(state_h)
                all_state_c.append(state_c)

                all_context.append(h.context)

            s_t_1 = (torch.stack(all_state_h, 0).unsqueeze(0), torch.stack(all_state_c, 0).unsqueeze(0))
            c_t_1 = torch.stack(all_context, 0)

            coverage_t_1 = None
            if config.is_coverage:
                all_coverage = []
                for h in beams:
                    all_coverage.append(h.coverage)
                coverage_t_1 = torch.stack(all_coverage, 0)

            final_dist, s_t, c_t, attn_dist, p_gen, coverage_t = self.model.decoder(y_t_1, s_t_1,
                                                        encoder_outputs, enc_padding_mask, c_t_1,
                                                        extra_zeros, enc_batch_extend_vocab, coverage_t_1)

            topk_log_probs, topk_ids = torch.topk(final_dist, config.beam_size * 2)

            dec_h, dec_c = s_t
            dec_h = dec_h.squeeze()
            dec_c = dec_c.squeeze()

            all_beams = []
            num_orig_beams = 1 if steps == 0 else len(beams)
            for i in range(num_orig_beams):
                h = beams[i]
                state_i = (dec_h[i], dec_c[i])
                context_i = c_t[i]
                coverage_i = (coverage_t[i] if config.is_coverage else None)

                for j in range(config.beam_size * 2):  # for each of the top 2*beam_size hyps:
                    new_beam = h.extend(token=topk_ids[i, j].item(),
                                        log_prob=topk_log_probs[i, j].item(),
                                        state=state_i,
                                        context=context_i,
                                        coverage=coverage_i)
                    all_beams.append(new_beam)

            beams = []
            for h in self.sort_beams(all_beams):
                if h.latest_token == self.vocab.word2id(data.STOP_DECODING):
                    if steps >= config.min_dec_steps:
                        results.append(h)
                else:
                    beams.append(h)
                if len(beams) == config.beam_size or len(results) == config.beam_size:
                    break

            steps += 1

        if len(results) == 0:
            results = beams

        beams_sorted = self.sort_beams(results)

        return beams_sorted[0]
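
sort_beams ranks hypotheses by avg_log_prob, the total log-probability divided by length, so longer beams are not penalized simply for having accumulated more terms. The Beam class itself is not shown in these examples; a minimal sketch consistent with how it is used here (an assumption: the real class in some examples also carries attn_dists and p_gens):

class Beam(object):
    # Minimal stand-in for the beam hypothesis object used above.
    def __init__(self, tokens, log_probs, state, context, coverage=None):
        self.tokens = tokens        # token ids decoded so far, starting with [START]
        self.log_probs = log_probs  # per-token log-probabilities
        self.state = state          # decoder (h, c) for this hypothesis
        self.context = context      # last attention context vector
        self.coverage = coverage    # coverage vector when coverage is enabled

    def extend(self, token, log_prob, state, context, coverage=None):
        return Beam(self.tokens + [token], self.log_probs + [log_prob],
                    state, context, coverage)

    @property
    def latest_token(self):
        return self.tokens[-1]

    @property
    def avg_log_prob(self):
        # Length-normalized score used by sort_beams.
        return sum(self.log_probs) / len(self.tokens)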
Ejemplo n.º 12
0
class Evaluate(object):
    def __init__(self, data_path, opt, batch_size=config.batch_size):
        self.vocab = Vocab(config.vocab_path, config.vocab_size)
        self.batcher = Batcher(data_path,
                               self.vocab,
                               mode='eval',
                               batch_size=batch_size,
                               single_pass=True)
        self.opt = opt
        time.sleep(5)

    def setup_valid(self):
        self.model = Model()
        self.model = get_cuda(self.model)
        if config.cuda:
            checkpoint = T.load(
                os.path.join(config.demo_model_path, self.opt.load_model))
        else:
            checkpoint = T.load(os.path.join(config.demo_model_path,
                                             self.opt.load_model),
                                map_location='cpu')
        self.model.load_state_dict(checkpoint["model_dict"])

    def print_original_predicted(self, decoded_sents, ref_sents, article_sents,
                                 loadfile):
        filename = "test_" + loadfile.split(".")[0] + ".txt"

        with open(os.path.join("data", filename), "w", encoding='utf-8') as f:
            for i in range(len(decoded_sents)):
                f.write("article: " + article_sents[i] + "\n")
                f.write("ref: " + ref_sents[i] + "\n")
                f.write("dec: " + decoded_sents[i] + "\n\n")

    def evaluate_batch(self, article):

        self.setup_valid()
        batch = self.batcher.next_batch()
        start_id = self.vocab.word2id(data.START_DECODING)
        end_id = self.vocab.word2id(data.STOP_DECODING)
        unk_id = self.vocab.word2id(data.UNKNOWN_TOKEN)
        decoded_sents = []
        ref_sents = []
        article_sents = []
        rouge = Rouge()
        while batch is not None:
            enc_batch, enc_lens, enc_padding_mask, enc_batch_extend_vocab, extra_zeros, ct_e = get_enc_data(
                batch)
            with T.autograd.no_grad():
                enc_batch = self.model.embeds(enc_batch)
                enc_out, enc_hidden = self.model.encoder(enc_batch, enc_lens)

            #-----------------------Summarization----------------------------------------------------
            with T.autograd.no_grad():
                pred_ids = beam_search(enc_hidden, enc_out, enc_padding_mask,
                                       ct_e, extra_zeros,
                                       enc_batch_extend_vocab, self.model,
                                       start_id, end_id, unk_id)

            for i in range(len(pred_ids)):
                decoded_words = data.outputids2words(pred_ids[i], self.vocab,
                                                     batch.art_oovs[i])
                if len(decoded_words) < 2:
                    decoded_words = "xxx"
                else:
                    decoded_words = " ".join(decoded_words)
                decoded_sents.append(decoded_words)
                abstract = batch.original_abstracts[i]
                ref_sents.append(abstract)
                # append directly so the `article` print flag is not clobbered
                article_sents.append(batch.original_articles[i])

            batch = self.batcher.next_batch()

        load_file = self.opt.load_model

        if article:
            self.print_original_predicted(decoded_sents, ref_sents,
                                          article_sents, load_file)

        scores = rouge.get_scores(decoded_sents, ref_sents)
        rouge_1 = sum([x["rouge-1"]["f"] for x in scores]) / len(scores)
        rouge_2 = sum([x["rouge-2"]["f"] for x in scores]) / len(scores)
        rouge_l = sum([x["rouge-l"]["f"] for x in scores]) / len(scores)
        logger.info(load_file + " rouge_1:" + "%.4f" % rouge_1 + " rouge_2:" +
                    "%.4f" % rouge_2 + " rouge_l:" + "%.4f" % rouge_l)
Ejemplo n.º 13
0
class BeamSearch(object):
    def __init__(self, model_file_path):
        model_name = os.path.basename(model_file_path)
        self._decode_dir = os.path.join(config.log_root,
                                        'decode_%s' % (model_name))
        self._rouge_ref_dir = os.path.join(self._decode_dir, 'rouge_ref')
        self._rouge_dec_dir = os.path.join(self._decode_dir, 'rouge_dec_dir')
        for p in [self._decode_dir, self._rouge_ref_dir, self._rouge_dec_dir]:
            if not os.path.exists(p):
                os.mkdir(p)

        self.vocab = Vocab(config.vocab_path, config.vocab_size)
        self.concept_vocab = Concept_vocab(config.concept_vocab_path,
                                           config.vocab_size)
        self.batcher = Batcher(config.decode_data_path,
                               self.vocab,
                               self.concept_vocab,
                               mode='decode',
                               batch_size=config.beam_size,
                               single_pass=True)
        time.sleep(15)

        self.model = Model(model_file_path, is_eval=True)

    def sort_beams(self, beams):
        return sorted(beams, key=lambda h: h.avg_log_prob, reverse=True)

    def decode(self):
        start = time.time()
        counter = 0
        batch = self.batcher.next_batch()
        while batch is not None:
            best_summary = self.beam_search(batch)

            output_ids = [int(t) for t in best_summary.tokens[1:]]
            decoded_words = data.outputids2words(
                output_ids, self.vocab,
                (batch.art_oovs[0] if config.pointer_gen else None))

            try:
                fst_stop_idx = decoded_words.index(data.STOP_DECODING)
                decoded_words = decoded_words[:fst_stop_idx]
            except ValueError:
                pass  # no [STOP] token was generated

            original_abstract_sents = batch.original_abstracts_sents[0]

            write_for_rouge(original_abstract_sents, decoded_words, counter,
                            self._rouge_ref_dir, self._rouge_dec_dir)
            counter += 1
            if counter % 1000 == 0:
                print('%d example in %d sec' % (counter, time.time() - start))
                start = time.time()

            batch = self.batcher.next_batch()

        print("Decoder has finished reading dataset for single_pass.")

    def beam_search(self, batch):
        enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, enc_batch_concept_extend_vocab, concept_p, position, concept_mask, extra_zeros, c_t_0, coverage_t_0 = \
            get_input_from_batch(batch, use_cuda)

        encoder_outputs, encoder_hidden, max_encoder_output, _, _ = self.model.encoder(
            enc_batch, enc_lens, enc_batch, enc_batch)
        s_t_0 = self.model.reduce_state(encoder_hidden)

        if config.use_maxpool_init_ctx:
            c_t_0 = max_encoder_output

        dec_h, dec_c = s_t_0
        dec_h = dec_h.squeeze()
        dec_c = dec_c.squeeze()

        beams = [
            Beam(tokens=[self.vocab.word2id(data.START_DECODING)],
                 log_probs=[0.0],
                 state=(dec_h[0], dec_c[0]),
                 context=c_t_0[0],
                 coverage=(coverage_t_0[0] if config.is_coverage else None))
            for _ in range(config.beam_size)
        ]
        results = []
        steps = 0
        while steps < config.max_dec_steps and len(results) < config.beam_size:
            latest_tokens = [h.latest_token for h in beams]
            latest_tokens = [t if t < self.vocab.size() else self.vocab.word2id(data.UNKNOWN_TOKEN) \
                             for t in latest_tokens]
            y_t_1 = Variable(torch.LongTensor(latest_tokens))
            if use_cuda:
                y_t_1 = y_t_1.cuda()
            all_state_h = []
            all_state_c = []

            all_context = []

            for h in beams:
                state_h, state_c = h.state
                all_state_h.append(state_h)
                all_state_c.append(state_c)

                all_context.append(h.context)

            s_t_1 = (torch.stack(all_state_h,
                                 0).unsqueeze(0), torch.stack(all_state_c,
                                                              0).unsqueeze(0))
            c_t_1 = torch.stack(all_context, 0)

            coverage_t_1 = None
            if config.is_coverage:
                all_coverage = []
                for h in beams:
                    all_coverage.append(h.coverage)
                coverage_t_1 = torch.stack(all_coverage, 0)

            final_dist, s_t, c_t, attn_dist, p_gen, coverage_t = self.model.decoder(
                'decode', y_t_1, s_t_1, encoder_outputs, enc_padding_mask,
                c_t_1, extra_zeros, enc_batch_extend_vocab,
                enc_batch_concept_extend_vocab, concept_p, position,
                concept_mask, coverage_t_1, steps)

            topk_log_probs, topk_ids = torch.topk(final_dist,
                                                  config.beam_size * 2)

            dec_h, dec_c = s_t
            dec_h = dec_h.squeeze()
            dec_c = dec_c.squeeze()

            all_beams = []
            num_orig_beams = 1 if steps == 0 else len(beams)
            for i in range(num_orig_beams):
                h = beams[i]
                state_i = (dec_h[i], dec_c[i])
                context_i = c_t[i]
                coverage_i = (coverage_t[i] if config.is_coverage else None)

                for j in range(config.beam_size * 2):
                    new_beam = h.extend(token=topk_ids[i, j].item(),
                                        log_prob=topk_log_probs[i, j].item(),
                                        state=state_i,
                                        context=context_i,
                                        coverage=coverage_i)
                    all_beams.append(new_beam)

            beams = []
            for h in self.sort_beams(all_beams):
                if h.latest_token == self.vocab.word2id(data.STOP_DECODING):
                    if steps >= config.min_dec_steps:
                        results.append(h)
                else:
                    beams.append(h)
                if len(beams) == config.beam_size or len(
                        results) == config.beam_size:
                    break

            steps += 1

        if len(results) == 0:
            results = beams

        beams_sorted = self.sort_beams(results)

        return beams_sorted[0]
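
At every decode step the search expands each live hypothesis with its 2*beam_size best continuations via torch.topk; a tiny standalone illustration of that call (the random logits are only for demonstration):

import torch

beam_size, vocab_size = 4, 10
log_probs = torch.log_softmax(torch.randn(beam_size, vocab_size), dim=1)

# For each of the beam_size rows, take the 2*beam_size best token ids and
# their scores; these are what Beam.extend consumes above.
topk_log_probs, topk_ids = torch.topk(log_probs, beam_size * 2)
print(topk_ids.shape)  # torch.Size([4, 8])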
Ejemplo n.º 14
0
class TrainSeq2Seq(object):
    def __init__(self, is_word_level=False, is_combined=False, alpha=0.3):
        self.vocab = Vocab(config.vocab_path, config.vocab_size)
        # self.batcher = Batcher(config.train_data_path, self.vocab, mode='train',
        #                        batch_size=config.batch_size, single_pass=False)
        self.dataset = DailyMailDataset("train", self.vocab)
        #time.sleep(15)

        self.is_word_level = is_word_level
        self.is_combined = is_combined
        self.alpha = alpha

        if is_word_level:
            print("Using Word Level Policy Gradient")
        elif is_combined:
            print("Using Combined Policy Gradient w/ alpha = ", alpha)
        else:
            print("Using Sentence Level Policy Gradient")

        train_dir = './train_dumps'
        # train_dir = './train_dumps'
        if not os.path.exists(train_dir):
            #print('create dict')
            os.mkdir(train_dir)

        self.model_dir = os.path.join(
            train_dir, 'dumps_model_{:%m_%d_%H_%M}'.format(datetime.now()))
        if not os.path.exists(self.model_dir):
            #print('create folder')
            os.mkdir(self.model_dir)

    def save_model(self, running_avg_loss, iter):
        state = {
            'iter': iter,
            'encoder_state_dict': self.model.encoder.state_dict(),
            'decoder_state_dict': self.model.decoder.state_dict(),
            'reduce_state_dict': self.model.reduce_state.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'current_loss': running_avg_loss
        }
        model_save_path = os.path.join(
            self.model_dir, 'model_%d_%d' % (iter, int(time.time())))
        torch.save(state, model_save_path)
        return model_save_path

    def setup(self, seqseq_model, model_file_path):
        self.model = seqseq_model

        params = list(self.model.encoder.parameters()) + list(self.model.decoder.parameters()) + \
                 list(self.model.reduce_state.parameters())
        initial_lr = config.lr_coverage if config.is_coverage else config.lr
        self.optimizer = Adagrad(
            params,
            lr=initial_lr,
            initial_accumulator_value=config.adagrad_init_acc)
        #self.optimizer = Adam(params, lr=initial_lr)

        start_iter, start_loss = 0, 0

        if model_file_path is not None:
            print("Loading checkpoint .... ")
            state = torch.load(model_file_path,
                               map_location=lambda storage, location: storage)
            start_iter = state['iter']
            start_loss = state['current_loss']

            if not config.is_coverage:
                self.optimizer.load_state_dict(state['optimizer'])
                if config.use_gpu:
                    for state in self.optimizer.state.values():
                        for k, v in state.items():
                            if torch.is_tensor(v):
                                state[k] = v.cuda()

        return start_iter, start_loss

    def train_one_batch_nll(self, batch):
        enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, c_t_1, coverage = \
            get_input_from_batch(batch, config.use_gpu)
        dec_batch, dec_padding_mask, max_dec_len, dec_lens_var, target_batch = \
            get_output_from_batch(batch, config.use_gpu)

        self.optimizer.zero_grad()

        encoder_outputs, encoder_feature, encoder_hidden = self.model.encoder(
            enc_batch, enc_lens)
        s_t_1 = self.model.reduce_state(encoder_hidden)

        step_losses = []
        for di in range(min(max_dec_len, config.max_dec_steps)):
            y_t_1 = dec_batch[:, di]  # Teacher forcing
            final_dist, s_t_1, c_t_1, attn_dist, p_gen, next_coverage = self.model.decoder(
                y_t_1, s_t_1, encoder_outputs, encoder_feature,
                enc_padding_mask, c_t_1, extra_zeros, enc_batch_extend_vocab,
                coverage, di)
            target = target_batch[:, di]
            gold_probs = torch.gather(final_dist, 1,
                                      target.unsqueeze(1)).squeeze()
            step_loss = -torch.log(gold_probs + config.eps)
            if config.is_coverage:
                step_coverage_loss = torch.sum(torch.min(attn_dist, coverage),
                                               1)
                step_loss = step_loss + config.cov_loss_wt * step_coverage_loss
                coverage = next_coverage

            step_mask = dec_padding_mask[:, di]
            step_loss = step_loss * step_mask
            step_losses.append(step_loss)

        sum_losses = torch.sum(torch.stack(step_losses, 1), 1)
        batch_avg_loss = sum_losses / dec_lens_var
        loss = torch.mean(batch_avg_loss)

        loss.backward()

        self.norm = clip_grad_norm_(self.model.encoder.parameters(),
                                    config.max_grad_norm)
        clip_grad_norm_(self.model.decoder.parameters(), config.max_grad_norm)
        clip_grad_norm_(self.model.reduce_state.parameters(),
                        config.max_grad_norm)

        self.optimizer.step()

        return loss.item()

    def train_nll(self, n_iters, iter, running_avg_loss):
        start = time.time()
        while iter < n_iters:
            batch = self.batcher.next_batch()
            loss = self.train_one_batch_nll(batch)

            running_avg_loss = calc_running_avg_loss(loss, running_avg_loss,
                                                     iter)
            print("Iteration:", iter, "  loss:", loss, "  Running avg loss:",
                  running_avg_loss)
            iter += 1

            print_interval = 1000
            if iter % print_interval == 0:
                print('steps %d, seconds for %d batch: %.2f , loss: %f' %
                      (iter, print_interval, time.time() - start, loss))
                start = time.time()
            if iter % 1000 == 0:
                self.save_model(running_avg_loss, iter)

    def train_pg(self,
                 n_iters,
                 start_iter,
                 start_running_avg_loss,
                 start_pg_losses,
                 start_run_avg_losses,
                 num_epochs=50):
        """
        The generator is trained using policy gradients, using the reward from the discriminator.
        Training is done for num_batches batches.
        """

        dataloader = DataLoader(self.dataset,
                                batch_size=config.batch_size,
                                shuffle=True,
                                num_workers=1,
                                collate_fn=create_batch_collate(
                                    self.vocab, config.batch_size))
        # pg_batcher = Batcher(config.train_data_path, self.vocab, mode='train',
        #     batch_size=config.batch_size, single_pass=False)
        #
        # time.sleep(15)

        start = time.time()
        running_avg_loss = start_running_avg_loss
        pg_losses = start_pg_losses
        run_avg_losses = start_run_avg_losses
        iteration = start_iter

        for epoch in range(num_epochs):
            print("Epoch :", epoch + 1)
            for batch in dataloader:
                iteration += 1

                loss = self.train_one_batch_pg(batch)

                running_avg_loss = calc_running_avg_loss(
                    loss, running_avg_loss, iteration)
                print("Iteration:", iteration, "  PG loss:", loss,
                      "  Running avg loss:", running_avg_loss)
                pg_losses.append(loss)
                run_avg_losses.append(running_avg_loss)

                print_interval = 10
                if iteration % print_interval == 0:
                    print(
                        'steps %d, seconds for %d batch: %.2f , loss: %f' %
                        (iteration, print_interval, time.time() - start, loss))

                    start = time.time()

                if iteration % 10 == 0:
                    # Dump model and losses
                    model_file_path = self.save_model(running_avg_loss,
                                                      iteration)
                    pickle.dump(
                        pg_losses,
                        open(
                            os.path.join(
                                self.model_dir,
                                'train_pg_losses_{}.p'.format(iteration)),
                            'wb'))
                    pickle.dump(
                        run_avg_losses,
                        open(
                            os.path.join(
                                self.model_dir,
                                'train_run_avg_losses_{}.p'.format(iteration)),
                            'wb'))
                    # Run eval
                    eval_processor = Evaluate_pg(
                        model_file_path,
                        is_word_level=self.is_word_level,
                        is_combined=self.is_combined,
                        alpha=self.alpha)
                    eval_losses = eval_processor.run_eval(
                        self.model_dir, iteration)

                    # Check if we should stop
                    avg_eval_loss = np.mean(eval_losses)
                    if running_avg_loss < avg_eval_loss:
                        print("Stopping at iteration {}".format(iteration))
                        break

    def compute_policy_grads_using_rewards(self, sentence_rewards,
                                           word_rewards, sentence_losses,
                                           word_losses, word_to_sent_ind):
        if self.is_combined:
            pg_losses = [[(self.alpha * word_reward + (1 - self.alpha) *
                           sentence_rewards[i][word_to_sent_ind[i][j]]) *
                          word_losses[i][j]
                          for j, word_reward in enumerate(abstract_rewards)
                          if j < len(word_to_sent_ind[i])]
                         for i, abstract_rewards in enumerate(word_rewards)]
            pg_losses = [sum(pg) for pg in pg_losses]
        elif self.is_word_level:
            pg_losses = [[
                word_reward * word_losses[i][j]
                for j, word_reward in enumerate(abstract_rewards)
                if j < len(word_to_sent_ind[i])
            ] for i, abstract_rewards in enumerate(word_rewards)]
            pg_losses = [sum(pg) for pg in pg_losses]
        else:
            pg_losses = [[
                rs * sentence_losses[ri][rsi] for rsi, rs in enumerate(r)
            ] for ri, r in enumerate(sentence_rewards)]
            pg_losses = [sum(pg) for pg in pg_losses]
        return pg_losses

    def compute_pg_loss(self, orig, pred, sentence_losses, split_predictions,
                        word_losses, word_to_sent_ind):
        sentence_rewards = None
        word_rewards = None
        # First compute the rewards
        if not self.is_word_level or self.is_combined:
            sentence_rewards = get_sentence_rewards(orig, pred)

        if self.is_word_level or self.is_combined:
            word_rewards = get_word_level_rewards(orig, split_predictions)

        pg_losses = self.compute_policy_grads_using_rewards(
            sentence_rewards=sentence_rewards,
            word_rewards=word_rewards,
            sentence_losses=sentence_losses,
            word_losses=word_losses,
            word_to_sent_ind=word_to_sent_ind)

        return pg_losses

    def compute_batched_sentence_loss(self, word_losses, orig, pred):
        orig_sum = []
        new_pred = []
        pred_sum = []
        sentence_losses = []

        # Convert the original sum as one single string per article
        for i in range(len(orig)):
            orig_sum.append(' '.join(map(str, orig[i])))
            new_pred.append([])
            pred_sum.append([])
            sentence_losses.append([])

        batch_sent_indices = []
        for i in range(len(pred)):
            sentence = pred[i]
            losses = word_losses[i]
            sentence_indices = []
            count = 0
            while len(sentence) > 0:
                try:
                    idx = sentence.index(".")
                except ValueError:
                    idx = len(sentence)

                sentence_indices.extend([count for _ in range(idx)])

                if count > 0:
                    new_pred[i].append(new_pred[i][count - 1] +
                                       sentence[:idx + 1])
                else:
                    new_pred[i].append(sentence[:idx + 1])

                sentence_losses[i].append(sum(losses[:idx + 1]))

                sentence = sentence[idx + 1:]
                losses = losses[idx + 1:]
                count += 1
            batch_sent_indices.append(sentence_indices)

        for i in range(len(pred)):
            for j in range(len(new_pred[i])):
                pred_sum[i].append(' '.join(map(str, new_pred[i][j])))

        pg_losses = self.compute_pg_loss(orig_sum,
                                         pred_sum,
                                         sentence_losses,
                                         split_predictions=pred,
                                         word_losses=word_losses,
                                         word_to_sent_ind=batch_sent_indices)

        return pg_losses

    def train_one_batch_pg(self, batch):
        batch_size = batch.batch_size

        enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, c_t_1, coverage = \
            get_input_from_batch(batch, config.use_gpu)
        dec_batch, dec_padding_mask, max_dec_len, dec_lens_var, target_batch = \
            get_output_from_batch(batch, config.use_gpu)

        self.optimizer.zero_grad()

        encoder_outputs, encoder_feature, encoder_hidden = self.model.encoder(
            enc_batch, enc_lens)
        s_t_1 = self.model.reduce_state(encoder_hidden)

        step_losses = []
        output_ids = []
        # Begin with START symbol
        y_t_1 = torch.ones(batch_size, dtype=torch.long) * self.vocab.word2id(
            data.START_DECODING)
        if config.use_gpu:
            y_t_1 = y_t_1.cuda()

        for _ in range(batch_size):
            output_ids.append([])
            step_losses.append([])

        for di in range(min(max_dec_len, config.max_dec_steps)):
            #y_t_1 = dec_batch[:, di]  # Teacher forcing
            final_dist, s_t_1, c_t_1, attn_dist, p_gen, next_coverage = self.model.decoder(
                y_t_1, s_t_1, encoder_outputs, encoder_feature,
                enc_padding_mask, c_t_1, extra_zeros, enc_batch_extend_vocab,
                coverage, di)
            target = target_batch[:, di]
            gold_probs = torch.gather(final_dist, 1,
                                      target.unsqueeze(1)).squeeze()
            step_loss = -torch.log(gold_probs + config.eps)  # NLL

            step_mask = dec_padding_mask[:, di]
            step_loss = step_loss * step_mask

            # Move on to next token
            _, idx = torch.max(final_dist, 1)
            idx = idx.reshape(batch_size, -1).squeeze()
            y_t_1 = idx

            for i, pred in enumerate(y_t_1):
                if pred.item() != self.vocab.word2id(data.PAD_TOKEN):
                    output_ids[i].append(pred.item())

            for i, loss in enumerate(step_loss):
                step_losses[i].append(step_loss[i])

        # Obtain the original and predicted summaries
        original_abstracts = batch.original_abstracts_sents
        predicted_abstracts = [
            data.outputids2words(ids, self.vocab, None) for ids in output_ids
        ]

        # Compute the batched loss
        batched_losses = self.compute_batched_sentence_loss(
            step_losses, original_abstracts, predicted_abstracts)
        #batched_losses = Variable(batched_losses, requires_grad=True)
        losses = torch.stack(batched_losses)
        losses = losses / dec_lens_var

        loss = torch.mean(losses)
        loss.backward()

        self.norm = clip_grad_norm_(self.model.encoder.parameters(),
                                    config.max_grad_norm)
        clip_grad_norm_(self.model.decoder.parameters(), config.max_grad_norm)
        clip_grad_norm_(self.model.reduce_state.parameters(),
                        config.max_grad_norm)

        self.optimizer.step()

        return loss.item()
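
compute_policy_grads_using_rewards weights each word's NLL by a convex mix alpha * word_reward + (1 - alpha) * sentence_reward of its own reward and the reward of the sentence it belongs to. A small numeric sketch of that combined branch (all reward and loss values are made up for illustration):

alpha = 0.3
word_rewards = [0.2, 0.5, 0.1, 0.9]   # one reward per decoded word
sentence_rewards = [0.6, 0.4]         # one reward per predicted sentence
word_to_sent = [0, 0, 1, 1]           # sentence index of each word
word_losses = [1.2, 0.8, 1.5, 0.7]    # per-word NLL from the decoder

pg_loss = sum(
    (alpha * wr + (1 - alpha) * sentence_rewards[word_to_sent[j]]) * word_losses[j]
    for j, wr in enumerate(word_rewards))
print(round(pg_loss, 4))  # ~1.882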
Ejemplo n.º 15
0
class Validate(object):
    def __init__(self, data_path, batch_size=config.batch_size):
        self.vocab = Vocab(config.vocab_path, config.vocab_size)
        self.batcher = Batcher(data_path,
                               self.vocab,
                               mode='eval',
                               batch_size=batch_size,
                               single_pass=True)

        time.sleep(5)

    def setup_valid(self):
        self.model = Model()
        self.model = get_cuda(self.model)
        checkpoint = T.load(config.load_model_path)
        self.model.load_state_dict(checkpoint["model_dict"])

    def validate_batch(self):

        self.setup_valid()
        batch = self.batcher.next_batch()
        start_id = self.vocab.word2id(data.START_DECODING)
        end_id = self.vocab.word2id(data.STOP_DECODING)
        unk_id = self.vocab.word2id(data.UNKNOWN_TOKEN)
        decoded_sents = []
        original_sents = []
        rouge = Rouge()
        while batch is not None:
            enc_batch, enc_lens, enc_padding_mask, enc_batch_extend_vocab, extra_zeros, c_t_1 = get_enc_data(
                batch)

            with T.autograd.no_grad():
                enc_batch = self.model.embeds(enc_batch)
                enc_out, enc_hidden = self.model.encoder(enc_batch, enc_lens)

            with T.autograd.no_grad():
                pred_ids = beam_search_on_batch(enc_hidden, enc_out,
                                                enc_padding_mask, c_t_1,
                                                extra_zeros,
                                                enc_batch_extend_vocab,
                                                self.model, start_id, end_id,
                                                unk_id)

            for i in range(len(pred_ids)):
                decoded_words = data.outputids2words(pred_ids[i], self.vocab,
                                                     batch.art_oovs[i])
                if len(decoded_words) < 2:
                    decoded_words = "xxx"
                else:
                    decoded_words = " ".join(decoded_words)
                decoded_sents.append(decoded_words)
                tar = batch.original_abstracts[i]
                original_sents.append(tar)

            batch = self.batcher.next_batch()

        load_file = config.load_model_path.split("/")[-1]

        scores = rouge.get_scores(decoded_sents, original_sents, avg=True)
        rouge_l = scores["rouge-l"]["f"]
        print(load_file, "rouge_l:", "%.4f" % rouge_l)
Ejemplo n.º 16
0
class BeamSearch(object):
    def __init__(self, model_file_path):
        self._decode_dir = os.path.join(config.log_root,
                                        'decode_%d' % (int(time.time())))
        self._rouge_ref_dir = os.path.join(self._decode_dir, 'rouge_ref')
        self._rouge_dec_dir = os.path.join(self._decode_dir, 'rouge_dec_dir')
        for p in [self._decode_dir, self._rouge_ref_dir, self._rouge_dec_dir]:
            if not os.path.exists(p):
                os.mkdir(p)

        self.vocab = Vocab(config.vocab_path, config.vocab_size)
        self.batcher = Batcher(config.decode_data_path,
                               self.vocab,
                               mode='decode',
                               batch_size=config.beam_size,
                               single_pass=True)
        time.sleep(15)

        self.model = Model(model_file_path, is_eval=True)

    def sort_beams(self, beams):
        return sorted(beams, key=lambda h: h.avg_log_prob, reverse=True)

    def decode(self):
        start = time.time()
        counter = 0
        batch = self.batcher.next_batch()
        while batch is not None:
            # Run beam search to get best Hypothesis
            best_summary = self.beam_search(batch)

            # Extract the output ids from the hypothesis and convert back to words
            output_ids = [int(t) for t in best_summary.tokens[1:]]
            decoded_words = data.outputids2words(
                output_ids, self.vocab,
                (batch.art_oovs[0] if config.pointer_gen else None))

            # Remove the [STOP] token from decoded_words, if necessary
            try:
                fst_stop_idx = decoded_words.index(data.STOP_DECODING)
                decoded_words = decoded_words[:fst_stop_idx]
            except ValueError:
                pass  # no [STOP] token was generated

            original_abstract_sents = batch.original_abstracts_sents[0]

            self.write_for_rouge(original_abstract_sents, decoded_words,
                                 counter)
            counter += 1
            if counter % 10000 == 0:
                print('%d example in %d sec' % (counter, time.time() - start))
                start = time.time()

            batch = self.batcher.next_batch()

        print("Decoder has finished reading dataset for single_pass.")
        print("Now starting ROUGE eval...")
        results_dict = rouge_eval(self._rouge_ref_dir, self._rouge_dec_dir)
        rouge_log(results_dict, self._decode_dir)

    def beam_search(self, batch):
        #batch should have only one example
        enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, c_t_0 = \
            get_input_from_batch(batch, use_cuda)

        encoder_outputs, encoder_hidden = self.model.encoder(
            enc_batch, enc_lens)
        s_t_0 = self.model.reduce_state(encoder_hidden)

        dec_h, dec_c = s_t_0  # 1 x 2*hidden_size
        dec_h = dec_h.squeeze()
        dec_c = dec_c.squeeze()
        #decoder batch preparation, it has beam_size example initially everything is repeated
        beams = [
            Beam(tokens=[self.vocab.word2id(data.START_DECODING)],
                 log_probs=[0.0],
                 state=(dec_h[0], dec_c[0]),
                 context=c_t_0[0],
                 attn_dists=[],
                 p_gens=[]) for _ in range(config.beam_size)
        ]
        results = []
        steps = 0
        while steps < config.max_dec_steps and len(results) < config.beam_size:
            latest_tokens = [h.latest_token for h in beams]
            latest_tokens = [
                t if t < self.vocab.size() else self.vocab.word2id(
                    data.UNKNOWN_TOKEN) for t in latest_tokens
            ]
            y_t_1 = Variable(torch.LongTensor(latest_tokens))
            if use_cuda:
                y_t_1 = y_t_1.cuda()
            all_state_h = []
            all_state_c = []

            all_context = []

            for h in beams:
                state_h, state_c = h.state
                all_state_h.append(state_h)
                all_state_c.append(state_c)

                all_context.append(h.context)

            s_t_1 = (torch.stack(all_state_h,
                                 0).unsqueeze(0), torch.stack(all_state_c,
                                                              0).unsqueeze(0))
            c_t_1 = torch.stack(all_context, 0)

            final_dist, s_t, c_t, attn_dist, p_gen = self.model.decoder(
                y_t_1, s_t_1, encoder_outputs, enc_padding_mask, c_t_1,
                extra_zeros, enc_batch_extend_vocab)

            topk_log_probs, topk_ids = torch.topk(final_dist,
                                                  config.beam_size * 2)

            dec_h, dec_c = s_t
            dec_h = dec_h.squeeze()
            dec_c = dec_c.squeeze()

            all_beams = []
            num_orig_beams = 1 if steps == 0 else len(beams)
            for i in range(num_orig_beams):
                h = beams[i]
                for j in range(config.beam_size * 2):  # for each of the top 2*beam_size hyps:
                    new_beam = h.extend(token=topk_ids[i, j].item(),
                                        log_prob=topk_log_probs[i, j].item(),
                                        state=(dec_h[i], dec_c[i]),
                                        context=c_t[i],
                                        attn_dist=attn_dist[i],
                                        p_gen=p_gen[i])
                    all_beams.append(new_beam)

            beams = []
            for h in self.sort_beams(all_beams):
                if h.latest_token == self.vocab.word2id(data.STOP_DECODING):
                    if steps >= config.min_dec_steps:
                        results.append(h)
                else:
                    beams.append(h)
                if len(beams) == config.beam_size or len(
                        results) == config.beam_size:
                    break

            steps += 1

        if len(results) == 0:
            results = beams

        beams_sorted = self.sort_beams(results)

        return beams_sorted[0]

    def write_for_rouge(self, reference_sents, decoded_words, ex_index):
        decoded_sents = []
        while len(decoded_words) > 0:
            try:
                fst_period_idx = decoded_words.index(".")
            except ValueError:
                fst_period_idx = len(decoded_words)
            sent = decoded_words[:fst_period_idx + 1]
            decoded_words = decoded_words[fst_period_idx + 1:]
            decoded_sents.append(' '.join(sent))

        # pyrouge calls a perl script that puts the data into HTML files.
        # Therefore we need to make our output HTML safe.
        decoded_sents = [make_html_safe(w) for w in decoded_sents]
        reference_sents = [make_html_safe(w) for w in reference_sents]

        ref_file = os.path.join(self._rouge_ref_dir,
                                "%06d_reference.txt" % ex_index)
        decoded_file = os.path.join(self._rouge_dec_dir,
                                    "%06d_decoded.txt" % ex_index)

        with open(ref_file, "w") as f:
            for idx, sent in enumerate(reference_sents):
                f.write(sent) if idx == len(reference_sents) - 1 else f.write(
                    sent + "\n")
        with open(decoded_file, "w") as f:
            for idx, sent in enumerate(decoded_sents):
                f.write(sent) if idx == len(decoded_sents) - 1 else f.write(
                    sent + "\n")

        print("Wrote example %i to file" % ex_index)
Ejemplo n.º 17
0
class Evaluate(object):
    def __init__(self, data_path, opt, batch_size=config.batch_size):
        self.vocab = Vocab(config.vocab_path, config.vocab_size)
        self.batcher = Batcher(data_path,
                               self.vocab,
                               mode='eval',
                               batch_size=batch_size,
                               single_pass=True)
        self.opt = opt
        time.sleep(5)

    def setup_valid(self):
        self.model = Model()
        self.model = get_cuda(self.model)
        map_location = T.device('cuda') if T.cuda.is_available() else T.device('cpu')
        checkpoint = T.load(
            os.path.join(config.save_model_path, self.opt.load_model),
            map_location=map_location)
        self.model.load_state_dict(checkpoint["model_dict"])


#        mlflow.pytorch.save_model(self.model,config.save_model_path+'_2')
#        mlflow.pytorch.load_model(config.save_model_path+'_2')

    def print_original_predicted(self, decoded_sents, ref_sents, article_sents,
                                 loadfile):
        filename = "test_" + loadfile.split(".")[0] + ".txt"

        with open(os.path.join("data", filename), "w") as f:
            for i in range(len(decoded_sents)):
                f.write("article: " + article_sents[i] + "\n")
                f.write("ref: " + ref_sents[i] + "\n")
                f.write("dec: " + decoded_sents[i] + "\n\n")

    def evaluate_batch(self, print_sents=False):

        self.setup_valid()
        batch = self.batcher.next_batch()
        start_id = self.vocab.word2id(data.START_DECODING)
        end_id = self.vocab.word2id(data.STOP_DECODING)
        unk_id = self.vocab.word2id(data.UNKNOWN_TOKEN)
        decoded_sents = []
        ref_sents = []
        article_sents = []
        rouge = Rouge()
        while batch is not None:
            enc_batch, enc_lens, enc_padding_mask, enc_batch_extend_vocab, extra_zeros, ct_e = get_enc_data(
                batch)

            with T.autograd.no_grad():
                enc_batch = self.model.embeds(enc_batch)
                enc_out, enc_hidden = self.model.encoder(enc_batch, enc_lens)

            #-----------------------Summarization----------------------------------------------------
            with T.autograd.no_grad():
                pred_ids = beam_search(enc_hidden, enc_out, enc_padding_mask,
                                       ct_e, extra_zeros,
                                       enc_batch_extend_vocab, self.model,
                                       start_id, end_id, unk_id)

            for i in range(len(pred_ids)):
                decoded_words = data.outputids2words(pred_ids[i], self.vocab,
                                                     batch.art_oovs[i])
                if len(decoded_words) < 2:
                    decoded_words = "xxx"
                else:
                    decoded_words = " ".join(decoded_words)
                decoded_sents.append(decoded_words)
                abstract = batch.original_abstracts[i]
                article = batch.original_articles[i]
                ref_sents.append(abstract)
                article_sents.append(article)
                article_art_oovs = batch.art_oovs[i]

            #batch = self.batcher.next_batch()
            break  # demo/serving mode: only the first batch is summarized

        load_file = self.opt.load_model  # just a model name

        #if print_sents:
        #    self.print_original_predicted(decoded_sents, ref_sents, article_sents, load_file)
        Batcher.article_summary = decoded_sents[0]
        Batcher.oovs = " ".join(article_art_oovs)

        #        print('Article: ',article_sents[0], '\n==> Summary: [',decoded_sents[0],']\nOut of vocabulary: ', " ".join(article_art_oovs),'\nModel used: ', load_file)
        if self.opt.task == "test":
            print('Done.')
            #print(load_file, "scores:", scores)
        else:
            # ROUGE is computed only on this path; the demo path above skips it.
            scores = rouge.get_scores(decoded_sents, ref_sents, avg=True)
            rouge_l = scores["rouge-l"]["f"]
Ejemplo n.º 18
0
class Train(object):
    def __init__(self, opt):
        self.vocab = Vocab(config.vocab_path, config.vocab_size)
        self.batcher = Batcher(config.train_data_path,
                               self.vocab,
                               mode='train',
                               batch_size=config.batch_size,
                               single_pass=False)

        self.opt = opt
        self.start_id = self.vocab.word2id(data.START_DECODING)
        self.end_id = self.vocab.word2id(data.STOP_DECODING)
        self.pad_id = self.vocab.word2id(data.PAD_TOKEN)
        self.unk_id = self.vocab.word2id(data.UNKNOWN_TOKEN)

        time.sleep(5)

    def save_model(self, iter):
        save_path = config.save_model_path + "%07d.tar" % iter
        torch.save(
            {
                'iter': iter + 1,
                'model_dict': self.model.state_dict(),
                'trainer_dict': self.trainer.state_dict()  # key name must match the load in setup_train
            }, save_path)

    def setup_train(self):
        self.model = Model()
        self.model = get_cuda(self.model)

        self.trainer = torch.optim.Adam(self.model.parameters(), lr=config.lr)

        start_iter = 0

        if self.opt.load_model is not None:
            load_model_path = os.path.join(config.save_model_path,
                                           self.opt.load_model)
            checkpoint = torch.load(load_model_path)
            start_iter = checkpoint['iter']
            self.model.load_state_dict(checkpoint['model_dict'])
            self.trainer.load_state_dict(checkpoint['trainer_dict'])
            print("load model at" + load_model_path)

        if self.opt.new_lr is not None:
            self.trainer = torch.optim.Adam(self.model.parameters(),
                                            lr=self.opt.new_lr)
            # Alternatively, update the existing optimizer in place:
            # for params in self.trainer.param_groups:
            #     params['lr'] = self.opt.new_lr

        return start_iter

    def train_batch_MLE(self, enc_out, enc_hidden, enc_padding_mask, ct_e,
                        extra_zeros, enc_batch_extend_vocab, batch):
        '''
        With probability 0.25 the previously generated token is used as the next
        decoder input, and with probability 0.75 the ground-truth label is used
        (scheduled sampling).

        Inputs:
        enc_out: encoder output at every time step.
            [batch_size, max_seq_len, 2 * hidden_dim]

        enc_hidden: hidden state and cell state of the final encoder step, as (h, c).
            [batch_size, hidden_dim]

        enc_padding_mask: distinguishes padding from real tokens in the encoder input.
            Inputs are padded to the length of the longest example when the batch
            is formed. [batch_size, max_seq_len]; 0 marks padding, 1 marks real input.

        ct_e: attention context vector over the encoder outputs for the current
            decoder time step. [batch_size, 2 * hidden_dim]; it changes from step to step.

        extra_zeros: placeholder for the OOV words.
            [batch_size, max_art_oovs]

        enc_batch_extend_vocab: the input batch in which every article's OOV words
            are replaced by their temporary OOV ids. [batch_size, max_seq_len]

        batch: the input batch, an instance of class Batch.
        '''

        dec_batch, max_dec_len, dec_lens, target_batch = get_dec_data(batch)

        step_losses = []

        h_t = (enc_hidden[0], enc_hidden[1])

        x_t = get_cuda(torch.LongTensor(len(enc_out)).fill_(self.start_id))

        prev_s = None
        sum_temporal_srcs = None

        for t in range(min(max_dec_len, config.max_dec_steps)):
            # For every article in the batch, draw a random number to decide
            # whether the ground-truth label (1) or the previously generated
            # token (0) is used as the next decoder input.
            use_ground_truth = get_cuda(
                (torch.rand(len(enc_out)) > 0.25)).long()

            x_t = use_ground_truth * dec_batch[:, t] + (1 -
                                                        use_ground_truth) * x_t
            # Note: x_t does not need a last dimension of config.vocab_size here;
            # the embedding layer maps the integer ids directly to vectors of
            # size emb_dim.
            x_t = self.model.embeds(x_t)

            final_dist, h_t, ct_e, sum_temporal_srcs, prev_s = self.model.decoder(
                x_t, h_t, enc_out, enc_padding_mask, ct_e, extra_zeros,
                enc_batch_extend_vocab, sum_temporal_srcs, prev_s)
            target = target_batch[:, t]

            log_probs = torch.log(final_dist + config.eps)
            step_loss = F.nll_loss(log_probs,
                                   target,
                                   reduction='none',
                                   ignore_index=self.pad_id)

            step_losses.append(step_loss)

            # final_dist: [batch_size, config.vocab_size + batch.max_art_oovs]
            # Sample one index per example along the second dimension; the
            # returned values are the sampled positions.
            # x_t: [batch_size, 1] --> [batch_size]
            x_t = torch.multinomial(final_dist, 1).squeeze()

            is_oov = (x_t >= config.vocab_size).long()
            # x_t: [batch_size]
            x_t = (1 - is_oov) * x_t.detach() + (is_oov) * self.unk_id

        losses = torch.sum(torch.stack(step_losses, 1), 1)

        batch_avg_loss = losses / dec_lens
        mle_loss = torch.mean(batch_avg_loss)

        return mle_loss

    # All the steps performed in one training iteration
    def train_one_batch(self, batch):

        enc_batch, enc_lens, enc_padding_mask, enc_batch_extend_vocab, \
            extra_zeros, context = get_enc_data(batch)

        enc_batch = self.model.embeds(enc_batch)

        enc_out, enc_hidden = self.model.encoder(enc_batch, enc_lens)

        if self.opt.train_mle == 'yes':
            mle_loss = self.train_batch_MLE(enc_out, enc_hidden,
                                            enc_padding_mask, context,
                                            extra_zeros,
                                            enc_batch_extend_vocab, batch)
        else:
            mle_loss = get_cuda(torch.FloatTensor([0]))

        self.trainer.zero_grad()
        mle_loss.backward()
        self.trainer.step()

        return mle_loss.item()

    # The actual training loop
    def train_iters(self):
        iter = self.setup_train()
        count = mle_total = 0

        while iter <= config.max_iterations:
            batch = self.batcher.next_batch()
            try:
                mle_loss = self.train_one_batch(batch)
            except KeyboardInterrupt:
                print("-------------Keyboard Interrupt------------")
                exit(0)

            mle_total += mle_loss
            mle_loss = 0
            count += 1
            iter += 1

            if iter % 1000 == 0:
                mle_avg = mle_total / count

                print('iter:', iter, 'mle_loss:', "%.3f" % mle_avg)

                count = mle_total = 0
                sys.stdout.flush()

            if iter % 2000 == 0:
                self.save_model(iter)
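train_batch_MLE above mixes ground-truth and model-generated inputs per example with a 0.75/0.25 split. The same mix, isolated with toy tensors, can be sketched as:

import torch

# Isolated sketch of the scheduled-sampling input mix used in train_batch_MLE:
# each example keeps the ground-truth token with probability 0.75 and the
# previously sampled token otherwise (values here are toy ids).
batch_size = 4
ground_truth = torch.tensor([11, 12, 13, 14])   # dec_batch[:, t]
sampled = torch.tensor([21, 22, 23, 24])        # x_t sampled at the previous step

use_ground_truth = (torch.rand(batch_size) > 0.25).long()
x_t = use_ground_truth * ground_truth + (1 - use_ground_truth) * sampled
print(use_ground_truth.tolist(), x_t.tolist())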
Ejemplo n.º 19
0
class Evaluate(object):
    def __init__(self, data_path, opt, batch_size=config.batch_size):
        self.vocab = Vocab(config.vocab_path, config.vocab_size)
        self.batcher = Batcher(data_path,
                               self.vocab,
                               mode='eval',
                               batch_size=batch_size,
                               single_pass=True)
        self.opt = opt
        time.sleep(5)

    def setup_valid(self):
        self.model = Model()
        self.model = get_cuda(self.model)
        checkpoint = T.load(
            os.path.join(config.save_model_path, self.opt.load_model))
        self.model.load_state_dict(checkpoint["model_dict"])
        # save the light version of picked model
        '''
        print(config.save_model_path,' ','light_'+self.opt.load_model)
        save_path = config.save_model_path + '/light_'+self.opt.load_model 
        print(save_path)
        T.save({
            "model_dict": self.model.state_dict(),
        }, save_path)      
        exit()
        ''' #-- end --

    def print_original_predicted(self, decoded_sents, ref_sents, article_sents,
                                 loadfile):
        filename = "test_" + loadfile.split(".")[0] + ".txt"

        with open(os.path.join("data", filename), "w") as f:
            for i in range(len(decoded_sents)):
                f.write("article: " + article_sents[i] + "\n")
                f.write("ref: " + ref_sents[i] + "\n")
                f.write("dec: " + decoded_sents[i] + "\n\n")

    def evaluate_batch(self, print_sents=False):

        self.setup_valid()
        batch = self.batcher.next_batch()
        start_id = self.vocab.word2id(data.START_DECODING)
        end_id = self.vocab.word2id(data.STOP_DECODING)
        unk_id = self.vocab.word2id(data.UNKNOWN_TOKEN)
        decoded_sents = []
        ref_sents = []
        article_sents = []
        rouge = Rouge()
        while batch is not None:
            enc_batch, enc_lens, enc_padding_mask, enc_batch_extend_vocab, extra_zeros, ct_e = get_enc_data(
                batch)

            with T.autograd.no_grad():
                enc_batch = self.model.embeds(enc_batch)
                enc_out, enc_hidden = self.model.encoder(enc_batch, enc_lens)

            #-----------------------Summarization----------------------------------------------------
            with T.autograd.no_grad():
                pred_ids = beam_search(enc_hidden, enc_out, enc_padding_mask,
                                       ct_e, extra_zeros,
                                       enc_batch_extend_vocab, self.model,
                                       start_id, end_id, unk_id)

            for i in range(len(pred_ids)):
                decoded_words = data.outputids2words(pred_ids[i], self.vocab,
                                                     batch.art_oovs[i])
                if len(decoded_words) < 2:
                    decoded_words = "xxx"
                else:
                    decoded_words = " ".join(decoded_words)
                decoded_sents.append(decoded_words)
                abstract = batch.original_abstracts[i]
                article = batch.original_articles[i]
                ref_sents.append(abstract)
                article_sents.append(article)

            batch = self.batcher.next_batch()

        load_file = self.opt.load_model

        if print_sents:
            self.print_original_predicted(decoded_sents, ref_sents,
                                          article_sents, load_file)

        scores = rouge.get_scores(decoded_sents, ref_sents, avg=True)
        if self.opt.task == "test":
            print(load_file, "scores:", scores)
        else:
            rouge_l = scores["rouge-l"]["f"]
            print(load_file, "rouge_l:", "%.4f" % rouge_l)
            with open("test_rg.txt", "a") as f:
                f.write("\n" + load_file + " - rouge_l: " + str(rouge_l))
                f.close()
Ejemplo n.º 20
0
class Train(object):
    def __init__(self, opt):
        self.vocab = Vocab(config.vocab_path, config.vocab_size)
        self.batcher = Batcher(config.train_data_path,
                               self.vocab,
                               mode='train',
                               batch_size=config.batch_size,
                               single_pass=False)
        time.sleep(15)

        train_dir = os.path.join(config.log_root,
                                 'train_%d' % (int(time.time())))
        if not os.path.exists(train_dir):
            os.mkdir(train_dir)

        self.model_dir = os.path.join(train_dir, 'model')
        if not os.path.exists(self.model_dir):
            os.mkdir(self.model_dir)
        self.opt = opt
        self.summary_writer = tf.summary.FileWriter(train_dir)

    def save_model(self, running_avg_loss, iter):
        state = {
            'iter': iter,
            'encoder_state_dict': self.model.encoder.state_dict(),
            'decoder_state_dict': self.model.decoder.state_dict(),
            'reduce_state_dict': self.model.reduce_state.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'current_loss': running_avg_loss
        }
        model_save_path = os.path.join(
            self.model_dir, 'model_%d_%d' % (iter, int(time.time())))
        torch.save(state, model_save_path)

    def setup_train(self, model_file_path=None):

        # Training setup: build the model and optimizer, optionally restoring a checkpoint.
        if self.opt.load_model is not None:
            model_file_path = os.path.join(self.model_dir, self.opt.load_model)
        else:
            model_file_path = None

        self.model = Model(model_file_path)

        params = list(self.model.encoder.parameters()) + list(self.model.decoder.parameters()) + \
                 list(self.model.reduce_state.parameters())
        initial_lr = config.lr_coverage if config.is_coverage else config.lr
        self.optimizer = Adagrad(
            params,
            lr=initial_lr,
            initial_accumulator_value=config.adagrad_init_acc)

        start_iter, start_loss = 0, 0

        if model_file_path is not None:
            state = torch.load(model_file_path,
                               map_location=lambda storage, location: storage)
            start_iter = state['iter']
            start_loss = state['current_loss']

            if not config.is_coverage:
                self.optimizer.load_state_dict(state['optimizer'])
                if use_cuda:
                    for state in self.optimizer.state.values():
                        for k, v in state.items():
                            if torch.is_tensor(v):
                                state[k] = v.cuda()

        return start_iter, start_loss

    def train_one_batch(self, batch):
        enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, c_t_1, coverage = \
            get_input_from_batch(batch, use_cuda)
        dec_batch, dec_padding_mask, max_dec_len, dec_lens_var, target_batch = \
            get_output_from_batch(batch, use_cuda)

        self.optimizer.zero_grad()

        encoder_outputs, encoder_feature, encoder_hidden = self.model.encoder(
            enc_batch, enc_lens)
        s_t_1 = self.model.reduce_state(encoder_hidden)

        if self.opt.train_mle == "yes":
            step_losses = []
            for di in range(min(max_dec_len, config.max_dec_steps)):
                y_t_1 = dec_batch[:, di]  # Teacher forcing
                final_dist, s_t_1, c_t_1, attn_dist, p_gen, next_coverage = self.model.decoder(
                    y_t_1, s_t_1, encoder_outputs, encoder_feature,
                    enc_padding_mask, c_t_1, extra_zeros,
                    enc_batch_extend_vocab, coverage, di)
                target = target_batch[:, di]
                gold_probs = torch.gather(final_dist, 1,
                                          target.unsqueeze(1)).squeeze()
                step_loss = -torch.log(gold_probs + config.eps)
                if config.is_coverage:
                    step_coverage_loss = torch.sum(
                        torch.min(attn_dist, coverage), 1)
                    step_loss = step_loss + config.cov_loss_wt * step_coverage_loss
                    coverage = next_coverage

                step_mask = dec_padding_mask[:, di]
                step_loss = step_loss * step_mask
                step_losses.append(step_loss)

            sum_losses = torch.sum(torch.stack(step_losses, 1), 1)
            batch_avg_loss = sum_losses / dec_lens_var
            mle_loss = torch.mean(batch_avg_loss)
        else:
            mle_loss = get_cuda(torch.FloatTensor([0]))
            # --------------RL training-----------------------------------------------------
        if self.opt.train_rl == "yes":  # perform reinforcement learning training
            # multinomial sampling
            sample_sents, RL_log_probs = self.train_batch_RL(
                encoder_outputs,
                encoder_hidden,
                enc_padding_mask,
                encoder_feature,
                enc_batch_extend_vocab,
                extra_zeros,
                c_t_1,
                batch.art_oovs,
                coverage,
                greedy=False)
            with torch.autograd.no_grad():
                # greedy sampling
                greedy_sents, _ = self.train_batch_RL(encoder_outputs,
                                                      encoder_hidden,
                                                      enc_padding_mask,
                                                      encoder_feature,
                                                      enc_batch_extend_vocab,
                                                      extra_zeros,
                                                      c_t_1,
                                                      batch.art_oovs,
                                                      coverage,
                                                      greedy=True)

            sample_reward = self.reward_function(sample_sents,
                                                 batch.original_abstracts)
            baseline_reward = self.reward_function(greedy_sents,
                                                   batch.original_abstracts)
            # if iter%200 == 0:
            #     self.write_to_file(sample_sents, greedy_sents, batch.original_abstracts, sample_reward, baseline_reward, iter)
            rl_loss = -(
                sample_reward - baseline_reward
            ) * RL_log_probs  # Self-critic policy gradient training (eq 15 in https://arxiv.org/pdf/1705.04304.pdf)
            rl_loss = torch.mean(rl_loss)

            batch_reward = torch.mean(sample_reward).item()
        else:
            rl_loss = get_cuda(torch.FloatTensor([0]))
            batch_reward = 0
        #loss.backward()
        (self.opt.mle_weight * mle_loss +
         self.opt.rl_weight * rl_loss).backward()
        self.norm = clip_grad_norm_(self.model.encoder.parameters(),
                                    config.max_grad_norm)
        clip_grad_norm_(self.model.decoder.parameters(), config.max_grad_norm)
        clip_grad_norm_(self.model.reduce_state.parameters(),
                        config.max_grad_norm)

        self.optimizer.step()

        return mle_loss.item(), batch_reward

    def train_batch_RL(self, encoder_outputs, encoder_hidden, enc_padding_mask,
                       encoder_feature, enc_batch_extend_vocab, extra_zeros,
                       c_t_1, article_oovs, coverage, greedy):
        '''Generate sentences from decoder entirely using sampled tokens as input. These sentences are used for ROUGE evaluation
        Args
        :param enc_out: Outputs of the encoder for all time steps (batch_size, length_input_sequence, 2*hidden_size)
        :param enc_hidden: Tuple containing final hidden state & cell state of encoder. Shape of h & c: (batch_size, hidden_size)
        :param enc_padding_mask: Mask for encoder input; Tensor of size (batch_size, length_input_sequence) with values of 0 for pad tokens & 1 for others
        :param ct_e: encoder context vector for time_step=0 (eq 5 in https://arxiv.org/pdf/1705.04304.pdf)
        :param extra_zeros: Tensor used to extend vocab distribution for pointer mechanism
        :param enc_batch_extend_vocab: Input batch that stores OOV ids
        :param article_oovs: Batch containing list of OOVs in each example
        :param greedy: If true, performs greedy based sampling, else performs multinomial sampling
        Returns:
        :decoded_strs: List of decoded sentences
        :log_probs: Log probabilities of sampled words
        '''
        s_t_1 = self.model.reduce_state(
            encoder_hidden)  # Decoder hidden states
        y_t_1 = get_cuda(
            torch.LongTensor(len(encoder_outputs)).fill_(
                self.vocab.word2id(data.START_DECODING)))  # Input to the decoder
        # Used for intra-temporal attention (section 2.1 in https://arxiv.org/pdf/1705.04304.pdf)
        inds = []  # Stores sampled indices for each time step
        decoder_padding_mask = []  # Stores padding masks of the generated samples
        log_probs = []  # Stores log probabilites of generated samples
        mask = get_cuda(
            torch.LongTensor(len(encoder_outputs)).fill_(1)
        )  # Values that indicate whether [STOP] token has already been encountered; 1 => Not encountered, 0 otherwise

        for t in range(config.max_dec_steps):
            probs, s_t_1, c_t_1, attn_dist, p_gen, next_coverage = self.model.decoder(
                y_t_1, s_t_1, encoder_outputs, encoder_feature,
                enc_padding_mask, c_t_1, extra_zeros, enc_batch_extend_vocab,
                coverage, t)
            if greedy is False:
                multi_dist = Categorical(probs)  # sample according to the probability distribution
                y_t_1 = multi_dist.sample()  # perform multinomial sampling
                log_prob = multi_dist.log_prob(y_t_1)
                log_probs.append(log_prob)
            else:
                _, y_t_1 = torch.max(
                    probs, dim=1)  # take the most probable word (greedy sampling)
            y_t_1 = y_t_1.detach()
            inds.append(y_t_1)
            # Padding mask of the batch for the current time step:
            # mask_t = 1 if [STOP] has not been produced before this step, else 0.
            mask_t = get_cuda(torch.zeros(len(encoder_outputs)))
            mask_t[mask == 1] = 1
            # If [STOP] had not been produced before and the current word is [STOP],
            # set mask = 0 from now on.
            mask[(mask == 1) +
                 (y_t_1 == self.vocab.word2id(data.STOP_DECODING)) == 2] = 0
            decoder_padding_mask.append(mask_t)
            is_oov = (y_t_1 >= config.vocab_size
                      ).long()  # Mask indicating whether sampled word is OOV
            y_t_1 = (1 - is_oov) * y_t_1 + (is_oov) * self.vocab.word2id(
                data.UNKNOWN_TOKEN)  # Replace OOVs with [UNK] token

        inds = torch.stack(inds, dim=1)
        decoder_padding_mask = torch.stack(decoder_padding_mask, dim=1)
        if greedy is False:  # If multinomial based sampling, compute log probabilites of sampled words
            log_probs = torch.stack(log_probs, dim=1)
            log_probs = log_probs * decoder_padding_mask  # Not considering sampled words with padding mask = 0
            lens = torch.sum(decoder_padding_mask,
                             dim=1)  # Length of sampled sentence
            log_probs = torch.sum(
                log_probs,
                dim=1) / lens  # (bs,) normalized log probability of each sentence
        decoded_strs = []
        for i in range(len(encoder_outputs)):
            id_list = inds[i].cpu().numpy()
            oovs = article_oovs[i]
            S = data.outputids2words(
                id_list, self.vocab,
                oovs)  # Generate sentence corresponding to sampled words
            try:
                end_idx = S.index(data.STOP_DECODING)
                S = S[:end_idx]
            except ValueError:
                pass  # no [STOP] token was generated
            if len(S) < 2:
                # Replace sentences shorter than 2 words with "xxx"; outputs such
                # as "." raise errors when computing ROUGE.
                S = ["xxx"]
            S = " ".join(S)
            decoded_strs.append(S)

        return decoded_strs, log_probs

    def reward_function(self, decoded_sents, original_sents):
        rouge = Rouge()
        try:
            scores = rouge.get_scores(decoded_sents, original_sents)
        except Exception:
            print(
                "ROUGE failed for multi-sentence evaluation; scoring each pair individually."
            )
            scores = []
            for i in range(len(decoded_sents)):
                try:
                    score = rouge.get_scores(decoded_sents[i],
                                             original_sents[i])
                except Exception:
                    print("Error occured at:")
                    print("decoded_sents:", decoded_sents[i])
                    print("original_sents:", original_sents[i])
                    score = [{"rouge-1": {"p": 0.0}}]
                scores.append(score[0])
        # The reward is the ROUGE-1 precision of each decoded sentence.
        rouge_l_p1 = [score["rouge-1"]["p"] for score in scores]
        rouge_l_p1 = get_cuda(torch.FloatTensor(rouge_l_p1))
        return rouge_l_p1

    def trainIters(self, n_iters, model_file_path=None):
        iter, running_avg_loss = self.setup_train(model_file_path)
        start = time.time()
        while iter < n_iters:
            batch = self.batcher.next_batch()
            loss = self.train_one_batch(batch)

            running_avg_loss = calc_running_avg_loss(loss, running_avg_loss,
                                                     self.summary_writer, iter)
            iter += 1

            if iter % 50 == 0:
                self.summary_writer.flush()
            print_interval = 50
            if iter % print_interval == 0:
                print('steps %d, seconds for %d batch: %.2f , loss: %f' %
                      (iter, print_interval, time.time() - start, loss))
                start = time.time()
            if iter % 100 == 0:
                self.save_model(running_avg_loss, iter)
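The RL branch in train_one_batch above implements self-critical training (eq. 15 in https://arxiv.org/pdf/1705.04304.pdf): the reward of the multinomial sample minus the greedy baseline's reward scales the sample's log-probability. A toy sketch of that loss with made-up reward values:

import torch

# Toy sketch of the self-critical loss computed in train_one_batch:
# rl_loss = -(r(sample) - r(greedy)) * log p(sample), averaged over the batch.
sample_reward = torch.tensor([0.32, 0.10, 0.45])    # ROUGE of multinomial samples
baseline_reward = torch.tensor([0.30, 0.15, 0.40])  # ROUGE of greedy decodes
rl_log_probs = torch.tensor([-1.2, -0.8, -1.5])     # normalised log p of samples

rl_loss = -(sample_reward - baseline_reward) * rl_log_probs
rl_loss = torch.mean(rl_loss)
print(rl_loss.item())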
Ejemplo n.º 21
0
class BeamSearch(object):
    def __init__(self, model_file_path):
        model_name = os.path.basename(model_file_path)
        self._decode_dir = os.path.join(config.log_root,
                                        'decode_%s' % (model_name))
        self._rouge_ref_dir = os.path.join(self._decode_dir, 'rouge_ref')
        self._rouge_dec_dir = os.path.join(self._decode_dir, 'rouge_dec_dir')
        for p in [self._decode_dir, self._rouge_ref_dir, self._rouge_dec_dir]:
            if not os.path.exists(p):
                os.mkdir(p)

        self.vocab = Vocab(config.vocab_path, config.vocab_size)
        self.batcher = Batcher(config.decode_data_path,
                               self.vocab,
                               mode='decode',
                               batch_size=config.beam_size,
                               single_pass=True)
        time.sleep(15)

        self.model = Model(model_file_path, is_eval=True)

    def sort_beams(self, beams):
        return sorted(beams, key=lambda h: h.avg_log_prob, reverse=True)

    def decode(self):
        start = time.time()
        counter = 0
        batch = self.batcher.next_batch()

        # In the newer architecture this is part of the training decode step.
        while batch is not None:

            # Run beam search to get best Hypothesis
            best_summary = self.beam_search(batch)

            # Extract the output ids from the hypothesis and convert back to words
            output_ids = [int(t) for t in best_summary.tokens[1:]]
            decoded_words = data.outputids2words(
                output_ids, self.vocab,
                (batch.art_oovs[0] if config.pointer_gen else None))

            # Remove the [STOP] token from decoded_words, if necessary
            try:
                fst_stop_idx = decoded_words.index(data.MARK_EOS)
                decoded_words = decoded_words[:fst_stop_idx]
            except ValueError:
                pass  # no [EOS] token in the output

            original_abstract_sents = batch.original_abstracts_sents[0]
            original_article = batch.original_articles[0]

            # English:
            # write_for_rouge(original_abstract_sents, decoded_words, counter,
            #                 self._rouge_ref_dir, self._rouge_dec_dir)
            # Chinese:
            self.write_result(original_article, original_abstract_sents,
                              decoded_words, counter)
            counter += 1
            #     if counter % 1000 == 0:
            #         print('%d example in %d sec'%(counter, time.time() - start))
            #         start = time.time()

            batch = self.batcher.next_batch()

        # print("Decoder has finished reading dataset for single_pass.")
        # print("Now starting ROUGE eval...")
        # results_dict = rouge_eval(self._rouge_ref_dir, self._rouge_dec_dir)
        # rouge_log(results_dict, self._decode_dir)

    def write_result(self, original_title, reference_summarization,
                     decoded_words, ex_index):
        """
        Write output to file.

        Args:
            reference_sents: list of strings
            decoded_words: list of strings
            ex_index: int, the index with which to label the files
        """
        summarization = ''.join(decoded_words)

        # Write to file
        result_file = os.path.join(self._decode_dir, "result.txt")

        with open(result_file, 'a') as f:  # append so earlier examples are not overwritten
            f.write(original_title + '\t\t' + reference_summarization +
                    '\t\t' + summarization + "\n")

        print("Wrote example %i to file" % ex_index)

    def beam_search(self, batch):
        #batch should have only one example
        enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, c_t_0, coverage_t_0 = \
            get_input_from_batch(batch, use_cuda)

        encoder_outputs, encoder_feature, encoder_hidden = self.model.encoder(
            enc_batch, enc_lens)
        s_t_0 = self.model.reduce_state(encoder_hidden)

        dec_h, dec_c = s_t_0  # 1 x 2*hidden_size
        dec_h = dec_h.squeeze()
        dec_c = dec_c.squeeze()

        #decoder batch preparation, it has beam_size example initially everything is repeated
        beams = [
            Beam(tokens=[self.vocab.word2id(data.MARK_GO)],
                 log_probs=[0.0],
                 state=(dec_h[0], dec_c[0]),
                 context=c_t_0[0],
                 coverage=(coverage_t_0[0] if config.is_coverage else None))
            for _ in range(config.beam_size)
        ]
        results = []
        steps = 0
        while steps < config.max_dec_steps and len(results) < config.beam_size:
            latest_tokens = [h.latest_token for h in beams]
            latest_tokens = [t if t < self.vocab.size() else self.vocab.word2id(data.MARK_UNK) \
                             for t in latest_tokens]
            y_t_1 = Variable(torch.LongTensor(latest_tokens))  # tensor of the latest tokens
            if use_cuda:
                y_t_1 = y_t_1.cuda()
            all_state_h = []
            all_state_c = []

            all_context = []

            for h in beams:
                state_h, state_c = h.state
                all_state_h.append(state_h)
                all_state_c.append(state_c)

                all_context.append(h.context)

            s_t_1 = (torch.stack(all_state_h,
                                 0).unsqueeze(0), torch.stack(all_state_c,
                                                              0).unsqueeze(0))
            c_t_1 = torch.stack(all_context, 0)

            coverage_t_1 = None
            if config.is_coverage:
                all_coverage = []
                for h in beams:
                    all_coverage.append(h.coverage)
                coverage_t_1 = torch.stack(all_coverage, 0)

            final_dist, s_t, c_t, attn_dist, p_gen, coverage_t = self.model.decoder(
                y_t_1, s_t_1, encoder_outputs, encoder_feature,
                enc_padding_mask, c_t_1, extra_zeros, enc_batch_extend_vocab,
                coverage_t_1, steps)
            log_probs = torch.log(final_dist)
            topk_log_probs, topk_ids = torch.topk(log_probs,
                                                  config.beam_size * 2)

            dec_h, dec_c = s_t
            dec_h = dec_h.squeeze()
            dec_c = dec_c.squeeze()

            all_beams = []
            num_orig_beams = 1 if steps == 0 else len(beams)
            for i in range(num_orig_beams):  # for each hypothesis
                h = beams[i]
                state_i = (dec_h[i], dec_c[i])
                context_i = c_t[i]
                coverage_i = (coverage_t[i] if config.is_coverage else None)

                for j in range(config.beam_size *
                               2):  # for each of the top 2*beam_size hyps:
                    new_beam = h.extend(token=topk_ids[i, j].item(),
                                        log_prob=topk_log_probs[i, j].item(),
                                        state=state_i,
                                        context=context_i,
                                        coverage=coverage_i)
                    all_beams.append(new_beam)

            beams = []
            for h in self.sort_beams(all_beams):
                if h.latest_token == self.vocab.word2id(data.MARK_EOS):
                    if steps >= config.min_dec_steps:
                        results.append(h)
                else:
                    beams.append(h)
                if len(beams) == config.beam_size or len(
                        results) == config.beam_size:
                    break

            steps += 1

        if len(results) == 0:
            results = beams

        beams_sorted = self.sort_beams(results)

        return beams_sorted[0]
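When config.is_coverage is set, the beam search above carries a per-hypothesis coverage vector, and the Train classes penalise repeated attention with sum(min(attention, coverage)) at each step. A toy sketch of that bookkeeping:

import torch

# Toy sketch of the coverage mechanism: coverage accumulates past attention,
# and the step loss penalises overlap between the current attention and what
# has already been attended to (values here are made up).
attn_dist = torch.tensor([[0.7, 0.2, 0.1],
                          [0.1, 0.8, 0.1]])   # current step attention
coverage = torch.tensor([[0.5, 0.3, 0.2],
                         [0.0, 0.9, 0.1]])    # sum of previous attention distributions

step_coverage_loss = torch.sum(torch.min(attn_dist, coverage), dim=1)
next_coverage = coverage + attn_dist
print(step_coverage_loss.tolist())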
Ejemplo n.º 22
0
class Train(object):
    def __init__(self):
        self.vocab = Vocab(config.vocab_path, config.vocab_size)
        self.batcher = Batcher(config.train_data_path,
                               self.vocab,
                               mode='train',
                               batch_size=config.batch_size,
                               single_pass=False)
        time.sleep(15)

        train_dir = os.path.join(config.ouput_root,
                                 'train_%d' % (int(time.time())))
        if not os.path.exists(train_dir):
            os.makedirs(train_dir)

        self.checkpoint_dir = os.path.join(train_dir, 'checkpoints')
        if not os.path.exists(self.checkpoint_dir):
            os.makedirs(self.checkpoint_dir)

        self.train_summary_writer = tf.summary.create_file_writer(
            os.path.join(train_dir, 'log', 'train'))
        self.eval_summary_writer = tf.summary.create_file_writer(
            os.path.join(train_dir, 'log', 'eval'))

    def save_model(self, model_path, running_avg_loss, iter):
        state = {
            'iter': iter,
            'encoder_state_dict': self.model.encoder.state_dict(),
            'decoder_state_dict': self.model.decoder.state_dict(),
            'reduce_state_dict': self.model.reduce_state.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'current_loss': running_avg_loss
        }
        torch.save(state, model_path)

    def setup_train(self, model_file_path=None):
        self.model = Model(device, model_file_path)

        params = list(self.model.encoder.parameters()) + list(self.model.decoder.parameters()) + \
                 list(self.model.reduce_state.parameters())
        initial_lr = config.lr_coverage if config.is_coverage else config.lr
        self.optimizer = Adagrad(
            params,
            lr=initial_lr,
            initial_accumulator_value=config.adagrad_init_acc)

        start_iter, start_loss = 0, 0

        if model_file_path is not None:
            state = torch.load(model_file_path,
                               map_location=lambda storage, location: storage)
            start_iter = state['iter']
            start_loss = state['current_loss']

            if not config.is_coverage:
                self.optimizer.load_state_dict(state['optimizer'])
                for state in self.optimizer.state.values():
                    for k, v in state.items():
                        if torch.is_tensor(v):
                            state[k] = v.to(device)

        return start_iter, start_loss

    def train_one_batch(self, batch, forcing_ratio=1):
        enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, c_t_1, coverage = \
            get_input_from_batch(batch, device)
        dec_batch, dec_padding_mask, max_dec_len, dec_lens_var, target_batch = \
            get_output_from_batch(batch, device)

        self.optimizer.zero_grad()

        encoder_outputs, encoder_feature, encoder_hidden = self.model.encoder(
            enc_batch, enc_lens)
        s_t_1 = self.model.reduce_state(encoder_hidden)

        step_losses = []
        y_t_1_hat = None
        for di in range(min(max_dec_len, config.max_dec_steps)):
            y_t_1 = dec_batch[:, di]
            # decide the next input
            if di == 0 or random.random() < forcing_ratio:
                x_t = y_t_1  # teacher forcing, use label from last time step as input
            else:
                # use the embedding of [UNK] for all OOV words (ids >= vocab size)
                y_t_1_hat[y_t_1_hat >= self.vocab.size()] = self.vocab.word2id(
                    UNKNOWN_TOKEN)
                x_t = y_t_1_hat.flatten(
                )  # use the prediction from the last time step as input
            final_dist, s_t_1, c_t_1, attn_dist, p_gen, next_coverage = self.model.decoder(
                x_t, s_t_1, encoder_outputs, encoder_feature, enc_padding_mask,
                c_t_1, extra_zeros, enc_batch_extend_vocab, coverage, di)
            _, y_t_1_hat = final_dist.data.topk(1)
            target = target_batch[:, di].unsqueeze(1)
            step_loss = cal_NLLLoss(target, final_dist)
            if config.is_coverage:  # if not using coverage, keep coverage=None
                step_coverage_loss = torch.sum(torch.min(attn_dist, coverage),
                                               1)
                step_loss = step_loss + config.cov_loss_wt * step_coverage_loss
                coverage = next_coverage

            step_mask = dec_padding_mask[:,
                                         di]  # padding in target should not count into loss
            step_loss = step_loss * step_mask
            step_losses.append(step_loss)

        sum_losses = torch.sum(torch.stack(step_losses, 1), 1)
        batch_avg_loss = sum_losses / dec_lens_var
        loss = torch.mean(batch_avg_loss)

        loss.backward()

        self.norm = clip_grad_norm_(self.model.encoder.parameters(),
                                    config.max_grad_norm)
        clip_grad_norm_(self.model.decoder.parameters(), config.max_grad_norm)
        clip_grad_norm_(self.model.reduce_state.parameters(),
                        config.max_grad_norm)

        self.optimizer.step()

        return loss.item()

    def train(self, n_iters, init_model_path=None):
        iter, avg_loss = self.setup_train(init_model_path)
        start = time.time()
        cnt = 0
        best_model_path = None
        min_eval_loss = float('inf')
        while iter < n_iters:
            s = config.forcing_ratio
            k = config.decay_to_0_iter
            x = iter
            near_zero = 0.0001
            if config.forcing_decay_type:
                if x >= config.decay_to_0_iter:
                    forcing_ratio = 0
                elif config.forcing_decay_type == 'linear':
                    forcing_ratio = s * (k - x) / k
                elif config.forcing_decay_type == 'exp':
                    p = pow(near_zero, 1 / k)
                    forcing_ratio = s * (p**x)
                elif config.forcing_decay_type == 'sig':
                    r = math.log((1 / near_zero) - 1) / k
                    forcing_ratio = s / (1 + pow(math.e, r * (x - k / 2)))
                else:
                    raise ValueError('Unrecognized forcing_decay_type: ' +
                                     config.forcing_decay_type)
            else:
                forcing_ratio = config.forcing_ratio
            batch = self.batcher.next_batch()
            loss = self.train_one_batch(batch, forcing_ratio=forcing_ratio)
            model_path = os.path.join(self.checkpoint_dir,
                                      'model_step_%d' % (iter + 1))
            avg_loss = calc_avg_loss(loss, avg_loss)

            if (iter + 1) % config.print_interval == 0:
                with self.train_summary_writer.as_default():
                    tf.summary.scalar(name='loss', data=loss, step=iter)
                self.train_summary_writer.flush()
                logger.info('steps %d, took %.2f seconds, train avg loss: %f' %
                            (iter + 1, time.time() - start, avg_loss))
                start = time.time()
            if config.eval_interval is not None and (
                    iter + 1) % config.eval_interval == 0:
                start = time.time()
                logger.info("Start Evaluation on model %s" % model_path)
                eval_processor = Evaluate(self.model, self.vocab)
                eval_loss = eval_processor.run_eval()
                logger.info(
                    "Evaluation finished, took %.2f seconds, eval loss: %f" %
                    (time.time() - start, eval_loss))
                with self.eval_summary_writer.as_default():
                    tf.summary.scalar(name='eval_loss',
                                      data=eval_loss,
                                      step=iter)
                self.eval_summary_writer.flush()
                if eval_loss < min_eval_loss:
                    logger.info(
                        "This is the best model so far, saving it to disk.")
                    min_eval_loss = eval_loss
                    best_model_path = model_path
                    self.save_model(model_path, eval_loss, iter)
                    cnt = 0
                else:
                    cnt += 1
                    if cnt > config.patience:
                        logger.info(
                            "Eval loss hasn't dropped for %d evaluations in a row, early stopping.\n"
                            "Best model: %s (eval loss: %f)" %
                            (config.patience, best_model_path, min_eval_loss))
                        break
                start = time.time()
            elif (iter + 1) % config.save_interval == 0:
                self.save_model(model_path, avg_loss, iter)
            iter += 1
        else:
            logger.info(
                "Training finished, best model: %s, with eval loss: %f" %
                (best_model_path, min_eval_loss))
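Train.train above decays forcing_ratio toward a near-zero floor at decay_to_0_iter with a linear, exponential, or sigmoid schedule. A standalone sketch of the same schedules, with hypothetical parameter values:

import math

def forcing_ratio_at(step, start_ratio=0.75, decay_to_0_iter=10000,
                     decay_type='linear', near_zero=1e-4):
    # Standalone sketch of the teacher-forcing decay schedules used above
    # (parameter values here are hypothetical defaults).
    if step >= decay_to_0_iter:
        return 0.0
    if decay_type == 'linear':
        return start_ratio * (decay_to_0_iter - step) / decay_to_0_iter
    if decay_type == 'exp':
        p = pow(near_zero, 1 / decay_to_0_iter)
        return start_ratio * (p ** step)
    if decay_type == 'sig':
        r = math.log((1 / near_zero) - 1) / decay_to_0_iter
        return start_ratio / (1 + pow(math.e, r * (step - decay_to_0_iter / 2)))
    raise ValueError('Unrecognized decay_type: ' + decay_type)

print([round(forcing_ratio_at(s), 3) for s in (0, 2500, 5000, 7500, 10000)])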
Ejemplo n.º 23
0
class Evaluate_pg(object):
    def __init__(self, model_file_path, is_word_level, is_combined, alpha):
        self.vocab = Vocab(config.vocab_path, config.vocab_size)
        # self.batcher = Batcher(config.eval_data_path, self.vocab, mode='eval',
        #                        batch_size=config.batch_size, single_pass=True)
        self.dataset = DailyMailDataset("val", self.vocab)
        # time.sleep(15)
        model_name = os.path.basename(model_file_path)

        self.is_word_level = is_word_level
        self.is_combined = is_combined
        self.alpha = alpha

        eval_dir = os.path.join(config.log_root, 'eval_%s' % (model_name))
        if not os.path.exists(eval_dir):
            os.mkdir(eval_dir)

        self.model = Model(model_file_path, is_eval=True)

    def compute_policy_grads_using_rewards(self, sentence_rewards,
                                           word_rewards, sentence_losses,
                                           word_losses, word_to_sent_ind):
        if self.is_combined:
            pg_losses = [[(self.alpha * word_reward + (1 - self.alpha) *
                           sentence_rewards[i][word_to_sent_ind[i][j]]) *
                          word_losses[i][j]
                          for j, word_reward in enumerate(abstract_rewards)
                          if j < len(word_to_sent_ind[i])]
                         for i, abstract_rewards in enumerate(word_rewards)]
            pg_losses = [sum(pg) for pg in pg_losses]
        elif self.is_word_level:
            pg_losses = [[
                word_reward * word_losses[i][j]
                for j, word_reward in enumerate(abstract_rewards)
                if j < len(word_to_sent_ind[i])
            ] for i, abstract_rewards in enumerate(word_rewards)]
            pg_losses = [sum(pg) for pg in pg_losses]
        else:
            pg_losses = [[
                rs * sentence_losses[ri][rsi] for rsi, rs in enumerate(r)
            ] for ri, r in enumerate(sentence_rewards)]
            pg_losses = [sum(pg) for pg in pg_losses]
        return pg_losses

    def compute_pg_loss(self, orig, pred, sentence_losses, split_predictions,
                        word_losses, word_to_sent_ind):
        sentence_rewards = None
        word_rewards = None
        # First compute the rewards
        if not self.is_word_level or self.is_combined:
            sentence_rewards = get_sentence_rewards(orig, pred)

        if self.is_word_level or self.is_combined:
            word_rewards = get_word_level_rewards(orig, split_predictions)

        pg_losses = self.compute_policy_grads_using_rewards(
            sentence_rewards=sentence_rewards,
            word_rewards=word_rewards,
            sentence_losses=sentence_losses,
            word_losses=word_losses,
            word_to_sent_ind=word_to_sent_ind)

        return pg_losses

    def compute_batched_loss(self, word_losses, orig, pred):
        orig_sum = []
        new_pred = []
        pred_sum = []
        sentence_losses = []

        # Convert the original sum as one single string per article
        for i in range(len(orig)):
            orig_sum.append(' '.join(map(str, orig[i])))
            new_pred.append([])
            pred_sum.append([])
            sentence_losses.append([])

        batch_sent_indices = []
        for i in range(len(pred)):
            sentence = pred[i]
            losses = word_losses[i]
            sentence_indices = []
            count = 0
            while len(sentence) > 0:
                try:
                    idx = sentence.index(".")
                except ValueError:
                    idx = len(sentence)

                sentence_indices.extend([count for _ in range(idx)])

                if count > 0:
                    new_pred[i].append(new_pred[i][count - 1] +
                                       sentence[:idx + 1])
                else:
                    new_pred[i].append(sentence[:idx + 1])

                sentence_losses[i].append(sum(losses[:idx + 1]))

                sentence = sentence[idx + 1:]
                losses = losses[idx + 1:]
                count += 1
            batch_sent_indices.append(sentence_indices)

        for i in range(len(pred)):
            for j in range(len(new_pred[i])):
                pred_sum[i].append(' '.join(map(str, new_pred[i][j])))

        pg_losses = self.compute_pg_loss(orig_sum,
                                         pred_sum,
                                         sentence_losses,
                                         split_predictions=pred,
                                         word_losses=word_losses,
                                         word_to_sent_ind=batch_sent_indices)

        return pg_losses

    def eval_one_batch(self, batch):
        batch_size = batch.batch_size

        enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, c_t_1, coverage = \
            get_input_from_batch(batch, use_cuda)
        dec_batch, dec_padding_mask, max_dec_len, dec_lens_var, target_batch = \
            get_output_from_batch(batch, use_cuda)

        encoder_outputs, encoder_feature, encoder_hidden = self.model.encoder(
            enc_batch, enc_lens)
        s_t_1 = self.model.reduce_state(encoder_hidden)

        step_losses = []
        output_ids = []
        y_t_1 = torch.ones(batch_size, dtype=torch.long) * self.vocab.word2id(
            data.START_DECODING)

        if config.use_gpu:
            y_t_1 = y_t_1.cuda()

        for _ in range(batch_size):
            output_ids.append([])
            step_losses.append([])

        for di in range(min(max_dec_len, config.max_dec_steps)):
            #y_t_1 = dec_batch[:, di]  # Teacher forcing
            final_dist, s_t_1, c_t_1, attn_dist, p_gen, next_coverage = self.model.decoder(
                y_t_1, s_t_1, encoder_outputs, encoder_feature,
                enc_padding_mask, c_t_1, extra_zeros, enc_batch_extend_vocab,
                coverage, di)
            target = target_batch[:, di]
            gold_probs = torch.gather(final_dist, 1,
                                      target.unsqueeze(1)).squeeze()
            step_loss = -torch.log(gold_probs + config.eps)  #NLL
            if config.is_coverage:
                step_coverage_loss = torch.sum(torch.min(attn_dist, coverage),
                                               1)
                step_loss = step_loss + config.cov_loss_wt * step_coverage_loss
                coverage = next_coverage

            step_mask = dec_padding_mask[:, di]
            step_loss = step_loss * step_mask

            # Move on to the next token
            _, idx = torch.max(final_dist, 1)
            idx = idx.reshape(batch_size, -1).squeeze()
            y_t_1 = idx

            for i, pred in enumerate(y_t_1):
                if pred.item() != self.vocab.word2id(data.PAD_TOKEN):  # skip padding ids
                    output_ids[i].append(pred.item())

            for i, loss in enumerate(step_loss):
                step_losses[i].append(loss)

        # Obtain the original and predicted summaries
        original_abstracts = batch.original_abstracts_sents
        predicted_abstracts = [
            data.outputids2words(ids, self.vocab, None) for ids in output_ids
        ]

        # Compute the batched loss
        batched_losses = self.compute_batched_loss(step_losses,
                                                   original_abstracts,
                                                   predicted_abstracts)
        losses = torch.stack(batched_losses)
        losses = losses / dec_lens_var

        loss = torch.mean(losses)

        return loss.item()

    def run_eval(self, model_dir, train_iter_id):
        dataloader = DataLoader(self.dataset,
                                batch_size=config.batch_size,
                                shuffle=False,
                                num_workers=1,
                                collate_fn=create_batch_collate(
                                    self.vocab, config.batch_size))
        running_avg_loss, iter = 0, 0
        start = time.time()
        # batch = self.batcher.next_batch()
        pg_losses = []
        run_avg_losses = []
        for batch in dataloader:
            loss = self.eval_one_batch(batch)

            running_avg_loss = calc_running_avg_loss(loss, running_avg_loss,
                                                     iter)
            print("Iteration:", iter, "  loss:", loss, "  Running avg loss:",
                  running_avg_loss)
            iter += 1

            print_interval = 1000
            if iter % print_interval == 0:
                print('steps %d, seconds for %d batch: %.2f , loss: %f' %
                      (iter, print_interval, time.time() - start,
                       running_avg_loss))
                start = time.time()

            pg_losses.append(loss)
            run_avg_losses.append(running_avg_loss)

        # Dump val losses
        with open(
                os.path.join(model_dir,
                             'val_pg_losses_{}.p'.format(train_iter_id)),
                'wb') as f:
            pickle.dump(pg_losses, f)
        with open(
                os.path.join(model_dir,
                             'val_run_avg_losses_{}.p'.format(train_iter_id)),
                'wb') as f:
            pickle.dump(run_avg_losses, f)

        return run_avg_losses
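
calc_running_avg_loss is used by the eval loops above and below but its definition is not part of these snippets; the following is a minimal sketch of the usual exponential-moving-average helper (the decay value, the clamp, and the exact signature are assumptions; some of the examples below also pass a TensorBoard summary writer):

def calc_running_avg_loss(loss, running_avg_loss, step, decay=0.99):
    # Exponential moving average of the per-batch loss.
    if running_avg_loss == 0:  # first step: start from the raw loss
        running_avg_loss = loss
    else:
        running_avg_loss = running_avg_loss * decay + (1 - decay) * loss
    # Clamp so that a single diverging batch does not dominate the logs.
    return min(running_avg_loss, 12)
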
Ejemplo n.º 24
0
class Evaluate(object):
    def __init__(self, data_path, opt, batch_size = config.batch_size):
        self.vocab = Vocab(config.vocab_path, config.vocab_size)
        self.batcher = Batcher(data_path, self.vocab, mode='eval',
                               batch_size=batch_size, single_pass=True)
        self.opt = opt
        time.sleep(5)

    def setup_valid(self):
        self.model = Model()
        self.model = get_cuda(self.model)
        checkpoint = T.load(os.path.join(config.save_model_path, self.opt.load_model))
        self.model.load_state_dict(checkpoint["model_dict"])

    # self.print_original_predicted(decoded_sents, ref_sents, article_sents, load_file)
    def print_original_predicted(self, decoded_sents, ref_sents, article_sents, loadfile):
        # filename = "test_"+loadfile.split(".")[0]+".txt"
        filename = "E:\\4160\\TextSummarization\\data\\test_output.txt"
        print(f"file name {filename}")
        # with open(os.path.join("data",filename), "w") as f:
        with open(filename, "w+", encoding="utf-8") as f:
            for i in range(len(decoded_sents)):
                f.write("article: "+article_sents[i] + "\n")
                f.write("ref: " + ref_sents[i] + "\n")
                f.write("dec: " + decoded_sents[i] + "\n\n")

    def evaluate_batch(self, print_sents = True):

        self.setup_valid()
        batch = self.batcher.next_batch()
        start_id = self.vocab.word2id(data.START_DECODING)
        end_id = self.vocab.word2id(data.STOP_DECODING)
        unk_id = self.vocab.word2id(data.UNKNOWN_TOKEN)
        decoded_sents = []
        ref_sents = []
        article_sents = []
        rouge = Rouge()

        # start_time = time.time()
        # test_batch = 0
        while batch is not None:
            enc_batch, enc_lens, enc_padding_mask, enc_batch_extend_vocab, extra_zeros, ct_e = get_enc_data(batch)

            with T.autograd.no_grad():
                enc_batch = self.model.embeds(enc_batch)
                enc_out, enc_hidden = self.model.encoder(enc_batch, enc_lens)

            #-----------------------Summarization----------------------------------------------------
            with T.autograd.no_grad():
                pred_ids = beam_search(enc_hidden, enc_out, enc_padding_mask, ct_e, extra_zeros, enc_batch_extend_vocab, self.model, start_id, end_id, unk_id)

            for i in range(len(pred_ids)):
                decoded_words = data.outputids2words(pred_ids[i], self.vocab, batch.art_oovs[i])
                if len(decoded_words) < 2:
                    decoded_words = "xxx"
                else:
                    decoded_words = " ".join(decoded_words)
                decoded_sents.append(decoded_words)
                abstract = batch.original_abstracts[i]
                article = batch.original_articles[i]
                ref_sents.append(abstract)
                article_sents.append(article)

            # test_batch_time = time.time() - start_time
            # start_time = time.time()
            # print("Current test batch", test_batch)
            # print("Testing time", test_batch_time)
            # test_batch += 1

            batch = self.batcher.next_batch()

        load_file = self.opt.load_model

        if print_sents:
            self.print_original_predicted(decoded_sents, ref_sents, article_sents, load_file)

        scores = rouge.get_scores(decoded_sents, ref_sents, avg = True)
        if self.opt.task == "test":
            print(load_file, "scores:", scores)
        else:
            rouge_l = scores["rouge-l"]["f"]
            print(load_file, "rouge_l:", "%.4f" % rouge_l)
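
A hypothetical driver for the Evaluate class above; the argument names mirror the opt attributes accessed in the code, and config.valid_data_path is an assumed config field:

if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--task", choices=["validate", "test"], default="validate")
    parser.add_argument("--load_model", required=True)  # checkpoint under config.save_model_path
    opt = parser.parse_args()

    evaluator = Evaluate(config.valid_data_path, opt)
    evaluator.evaluate_batch(print_sents=(opt.task == "test"))
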
Ejemplo n.º 25
0
# Copy the word2vec vectors into the embedding weight matrix, offset by 4 for the
# special tokens. (The loop header is reconstructed from context; `weight` is
# assumed to be a pre-allocated [vocab_size, embedding_dim] tensor defined in the
# elided part of this cell.)
for i in range(len(w2vec.wv.index2entity)):
    try:
        vocab_word = vocab._id_to_word[i + 4]
        w2vec_word = w2vec.wv.index2entity[i]
    except Exception:
        continue
    if i + 4 > vocab_size: break
    #     print(vocab_word,w2vec_word)
    weight[i + 4, :] = torch.from_numpy(w2vec.wv.vectors[i])

embedding = torch.nn.Embedding.from_pretrained(weight)
# requires_grad controls whether the embedding weights are fine-tuned during training
embedding.weight.requires_grad = True
embedding

# In[53]:

vocab.word2id('the')

# # Embedding/glove

# In[54]:

from glove import Glove
from glove import Corpus

vocab_count = 50000
# write vocab to file
if not os.path.exists('Embedding/category/glove'):
    os.makedirs('Embedding/category/glove')
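
As a sketch of how the imports above are typically used, fitting embeddings with the glove package's Corpus/Glove API looks like the following (the tokenized sentences iterable, the hyperparameters, and the output filename are assumptions):

corpus = Corpus()
corpus.fit(sentences, window=10)  # sentences: iterable of token lists

glove_model = Glove(no_components=128, learning_rate=0.05)
glove_model.fit(corpus.matrix, epochs=30, no_threads=4, verbose=True)
glove_model.add_dictionary(corpus.dictionary)
glove_model.save('Embedding/category/glove/glove.model')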

# In[55]:
class Main(object):
    def __init__(self):
        self.vocab = Vocab(VOCAB_PATH, VOCAB_SIZE)
        self.batcher = Batcher(TRAIN_DATA_PATH, self.vocab, mode = 'train',batch_size = BATCH_SIZE, single_pass = False)
        self.start_id = self.vocab.word2id(data.START_DECODING)
        self.end_id = self.vocab.word2id(data.STOP_DECODING)
        self.pad_id = self.vocab.word2id(data.PAD_TOKEN)
        self.unk_id = self.vocab.word2id(data.UNKNOWN_TOKEN)
        self.model = MyModel().to(DEVICE)
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=LR)


    def beamSearch(self,enc_hid, enc_out, enc_padding_mask, ct_e, extra_zeros, enc_batch_extend_vocab):
        batch_size = len(enc_hid[0])
        beam_idx = torch.LongTensor(list(range(batch_size)))
        beams = [Beam(self.start_id, self.end_id, self.unk_id, (enc_hid[0][i], enc_hid[1][i]), ct_e[i]) for i in range(batch_size)]
        n_rem = batch_size
        sum_exp = None
        prev_s = None

        for t in range(MAX_DEC_STEPS):
            x_t = torch.stack(
                [beam.getTokens() for beam in beams if beam.done == False]  # remaining(rem),beam
            ).contiguous().view(-1)
            x_t = self.model.embeds(x_t)

            dec_h = torch.stack(
                [beam.hid_h for beam in beams if beam.done == False]
            ).contiguous().view(-1, HIDDEN_DIM)
            dec_c = torch.stack(
                [beam.hid_c for beam in beams if beam.done == False]
            ).contiguous().view(-1, HIDDEN_DIM)

            ct_e = torch.stack(
                [beam.context for beam in beams if beam.done == False]
            ).contiguous().view(-1, 2 * HIDDEN_DIM)

            if sum_exp is not None:
                sum_exp = torch.stack(
                    [beam.sum_exp for beam in beams if beam.done == False]
                ).contiguous().view(-1, enc_out.size(1))

            if prev_s is not None:
                prev_s = torch.stack(
                    [beam.prev_s for beam in beams if beam.done == False]
                )
                # try:
                prev_s = prev_s.contiguous().view(-1, t + 1, HIDDEN_DIM)
                # print(prev_s.shape)
                # except:
                #     continue
            s_t = (dec_h, dec_c)
            enc_out_beam = enc_out[beam_idx].view(n_rem, -1).repeat(1, BEAM_SIZE).view(-1, enc_out.size(1),
                                                                                              enc_out.size(2))
            enc_pad_mask_beam = enc_padding_mask[beam_idx].repeat(1, BEAM_SIZE).view(-1,
                                                                                            enc_padding_mask.size(1))

            extra_zeros_beam = None
            if extra_zeros is not None:
                extra_zeros_beam = extra_zeros[beam_idx].repeat(1, BEAM_SIZE).view(-1, extra_zeros.size(1))

            enc_extend_vocab_beam = enc_batch_extend_vocab[beam_idx].repeat(1, BEAM_SIZE).view(-1,enc_batch_extend_vocab.size(1))
            # print(enc_out_beam.shape)
            # try:
            final_dist, (dec_h, dec_c), ct_e, sum_exp, prev_s = self.model.decoder(x_t, s_t, enc_out_beam,enc_pad_mask_beam, ct_e,extra_zeros_beam,enc_extend_vocab_beam,sum_exp,prev_s)

            # except:
            #     continue
            # print(prev_s.shape)
            final_dist = final_dist.view(n_rem, BEAM_SIZE, -1)
            dec_h = dec_h.view(n_rem, BEAM_SIZE, -1)
            dec_c = dec_c.view(n_rem, BEAM_SIZE, -1)
            ct_e = ct_e.view(n_rem, BEAM_SIZE, -1)

            if sum_exp is not None:
                sum_exp = sum_exp.view(n_rem, BEAM_SIZE, -1)  # rem, beam, n_seq

            if prev_s is not None:
                prev_s = prev_s.view(n_rem, BEAM_SIZE, -1, HIDDEN_DIM)  # rem, beam, t
                # print("prev_s",prev_s.shape)

            active = []

            for i in range(n_rem):
                b = beam_idx[i].item()
                beam = beams[b]
                if beam.done:
                    continue

                sum_exp_i = prev_s_i = None
                if sum_exp is not None:
                    sum_exp_i = sum_exp[i]
                if prev_s is not None:
                    prev_s_i = prev_s[i]
                beam.advance(final_dist[i], (dec_h[i], dec_c[i]), ct_e[i], sum_exp_i, prev_s_i)
                if beam.done == False:
                    active.append(b)

            # print(len(active))
            if len(active) == 0:
                break

            beam_idx = torch.LongTensor(active)
            n_rem = len(beam_idx)

        predicted_words = []
        for beam in beams:
            predicted_words.append(beam.getBest())
        # print(len(predicted_words[0]))
        return predicted_words

    # Build the encoder-side tensors from a batch
    def getEncData(self,batch):
        batch_size = len(batch.enc_lens)
        enc_batch = torch.from_numpy(batch.enc_batch).long()
        enc_padding_mask = torch.from_numpy(batch.enc_padding_mask).float()

        enc_lens = batch.enc_lens

        ct_e = torch.zeros(batch_size, 2 * HIDDEN_DIM)

        enc_batch = enc_batch.to(DEVICE)
        enc_padding_mask = enc_padding_mask.to(DEVICE)

        ct_e = ct_e.to(DEVICE)

        enc_batch_extend_vocab = None
        if batch.enc_batch_extend_vocab is not None:
            enc_batch_extend_vocab = torch.from_numpy(batch.enc_batch_extend_vocab).long()
            enc_batch_extend_vocab = enc_batch_extend_vocab.to(DEVICE)

        extra_zeros = None
        if batch.max_art_oovs > 0:
            extra_zeros = torch.zeros(batch_size, batch.max_art_oovs)
            extra_zeros = extra_zeros.to(DEVICE)

        return enc_batch, enc_lens, enc_padding_mask, enc_batch_extend_vocab, extra_zeros, ct_e

    # Build the decoder-side tensors from a batch
    def getDecData(self,batch):
        dec_batch = torch.from_numpy(batch.dec_batch).long()
        dec_lens = batch.dec_lens
        max_dec_len = np.max(dec_lens)
        dec_lens = torch.from_numpy(batch.dec_lens).float()

        target_batch = torch.from_numpy(batch.target_batch).long()

        dec_batch = dec_batch.to(DEVICE)
        dec_lens = dec_lens.to(DEVICE)
        target_batch = target_batch.to(DEVICE)

        return dec_batch, max_dec_len, dec_lens, target_batch

    # Maximum-likelihood (teacher-forced) training step
    def trainMLEStep(self, batch):
        enc_batch, enc_lens, enc_padding_mask, enc_batch_extend_vocab, extra_zeros, ct_e = self.getEncData(batch)
        # print(enc_lens)
        enc_batch = self.model.embeds(enc_batch)
        # Feed the embedded input through the encoder
        enc_output,enc_hidden = self.model.encoder(enc_batch,enc_lens)
        # print(enc_output.shape)
        dec_batch, max_dec_len, dec_lens, target_batch = self.getDecData(batch)
        total_loss = 0
        s_t = (enc_hidden[0], enc_hidden[1])
        x_t =torch.LongTensor(len(enc_output)).fill_(self.start_id).to(DEVICE)
        # print(x_t.shape)
        prev_s = None
        sum_exp = None
        # Decode up to the maximum number of steps, one token per step
        for t in range(min(max_dec_len, MAX_DEC_STEPS)):
            # print(max_dec_len)
            choice = (torch.rand(len(enc_output)) > 0.25).long().to(DEVICE)
            # Scheduled sampling: use the ground-truth token with probability 0.75, otherwise the model's previous prediction
            x_t = choice * dec_batch[:, t] + (1 - choice) * x_t
            x_t = self.model.embeds(x_t)
            # print(x_t.shape)
            final_dist, s_t, ct_e, sum_exp, prev_s = self.model.decoder(x_t, s_t, enc_output, enc_padding_mask, ct_e, extra_zeros, enc_batch_extend_vocab, sum_exp, prev_s)
            # print(prev_s.shape)
            target = target_batch[:, t]
            # print(target_batch.shape)
            log_probs = torch.log(final_dist + EPS)
            step_loss = F.nll_loss(log_probs, target, reduction="none", ignore_index=self.pad_id)
            total_loss = total_loss + step_loss
            x_t = torch.multinomial(final_dist, 1).squeeze()
            is_oov = (x_t >= VOCAB_SIZE).long()
            x_t = (1 - is_oov) * x_t.detach() + (is_oov) * self.unk_id

        batch_avg_loss = total_loss / dec_lens
        loss = torch.mean(batch_avg_loss)
        return loss

    # Reinforcement-learning (self-critical) rollout
    def trainRLStep(self,enc_output, enc_hidden, enc_padding_mask, ct_e, extra_zeros, enc_batch_extend_vocab, article_oovs, type):


        s_t = enc_hidden
        x_t = torch.LongTensor(len(enc_output)).fill_(self.start_id).to(DEVICE)
        prev_s = None
        sum_exp = None
        inds = []
        decoder_padding_mask = []
        log_probs = []
        mask = torch.LongTensor(len(enc_output)).fill_(1).to(DEVICE)

        for t in range(MAX_DEC_STEPS):
            x_t = self.model.embeds(x_t)
            probs, s_t, ct_e, sum_exp, prev_s = self.model.decoder(x_t, s_t, enc_output, enc_padding_mask, ct_e,
                                                                             extra_zeros, enc_batch_extend_vocab,
                                                                             sum_exp, prev_s)
            if type == "sample":
                # Sample the next token from the predicted distribution
                multi_dist = Categorical(probs)
                # print(multi_dist)
                x_t = multi_dist.sample()
                # print(x_t.shape)
                log_prob = multi_dist.log_prob(x_t)
                log_probs.append(log_prob)
            else:
                #greedy sample
                _, x_t = torch.max(probs, dim=1)
            x_t = x_t.detach()
            inds.append(x_t)
            mask_t = torch.zeros(len(enc_output)).to(DEVICE)
            mask_t[mask == 1] = 1
            mask[(mask == 1) & (x_t == self.end_id)] = 0  # stop counting steps once END has been produced
            decoder_padding_mask.append(mask_t)
            is_oov = (x_t >= VOCAB_SIZE).long()
            # Replace out-of-vocabulary ids with UNK before the next embedding lookup
            x_t = (1 - is_oov) * x_t + (is_oov) * self.unk_id

        inds = torch.stack(inds, dim=1)
        decoder_padding_mask = torch.stack(decoder_padding_mask, dim=1)
        if type == "sample":
            log_probs = torch.stack(log_probs, dim=1)
            # Zero out log-probs at padded positions
            log_probs = log_probs * decoder_padding_mask
            lens = torch.sum(decoder_padding_mask, dim=1)
            # log-prob term of Eq. 15: average log p over non-pad steps
            log_probs = torch.sum(log_probs,dim=1) / lens
            # print(log_prob.shape)
        decoded_strs = []
        # Convert output ids back to words
        for i in range(len(enc_output)):
            id_list = inds[i].cpu().numpy()
            oovs = article_oovs[i]
            S = data.outputids2words(id_list, self.vocab, oovs)  # Generate sentence corresponding to sampled words
            try:
                end_idx = S.index(data.STOP_DECODING)
                S = S[:end_idx]
            except ValueError:
                S = S
            if len(S) < 2:
                S = ["xxx"]
            S = " ".join(S)
            decoded_strs.append(S)

        return decoded_strs, log_probs


    def rewardFunction(self, decoded_sents, original_sents):
        rouge = Rouge()
        scores = rouge.get_scores(decoded_sents, original_sents)
        rouge_l_f1 = [score["rouge-l"]["f"] for score in scores]
        rouge_l_f1 = (torch.FloatTensor(rouge_l_f1)).to(DEVICE)
        return rouge_l_f1



    # Evaluate on the test set with beam search
    def test(self):
        # time.sleep(5)

        batcher = Batcher(TEST_DATA_PATH, self.vocab, mode='test',batch_size=BATCH_SIZE, single_pass=True)
        batch = batcher.next_batch()
        decoded_sents = []
        ref_sents = []
        article_sents = []
        rouge = Rouge()
        count = 0
        while batch is not None:
            enc_batch, enc_lens, enc_padding_mask, enc_batch_extend_vocab, extra_zeros, ct_e = self.getEncData(batch)
            with torch.autograd.no_grad():
                enc_batch = self.model.embeds(enc_batch)
                enc_out, enc_hidden = self.model.encoder(enc_batch, enc_lens)

            with torch.autograd.no_grad():
                pred_ids = self.beamSearch(enc_hidden, enc_out, enc_padding_mask, ct_e, extra_zeros, enc_batch_extend_vocab)
                # print(len(pred_ids[0]))
            for i in range(len(pred_ids)):
                # print('t',pred_ids[i])
                decoded_words = data.outputids2words(pred_ids[i], self.vocab, batch.art_oovs[i])
                # print(decoded_words)
                if len(decoded_words) < 2:
                    decoded_words = "xxx"
                else:
                    decoded_words = " ".join(decoded_words)
                decoded_sents.append(decoded_words)
                abstract = batch.original_abstracts[i]
                article = batch.original_articles[i]
                ref_sents.append(abstract)
                article_sents.append(article)
            # print(decoded_sents)
            batch = batcher.next_batch()
            scores = rouge.get_scores(decoded_sents, ref_sents, avg=True)
            # Accumulate ROUGE results across batches
            if count == 1:
                k0_sum = scores[KEYS[0]]
                k1_sum = scores[KEYS[1]]
                k2_sum = scores[KEYS[2]]

            if count > 1:
                k0_sum = dict(Counter(Counter(k0_sum) + Counter(scores[KEYS[0]])))
                k1_sum = dict(Counter(Counter(k1_sum) + Counter(scores[KEYS[1]])))
                k2_sum = dict(Counter(Counter(k2_sum) + Counter(scores[KEYS[2]])))
            if count == 10:
                break

            count += 1


        # print(scores)
        print(KEYS[0], end=' ')
        for k in k0_sum:
            print(k,k0_sum[k] / count,end = ' ')
        print('\n')
        print(KEYS[1],end = ' ')
        for k in k1_sum:
            print(k,k1_sum[k] / count,end = ' ')
        print('\n')
        print(KEYS[2], end=' ')
        for k in k2_sum:
            print(k,k2_sum[k] / count,end = ' ')
        print('\n')






    def train(self):
        iter = 1
        count = 0
        total_loss = 0
        total_reward = 0
        while iter <= MAX_ITERATIONS:
            batch = self.batcher.next_batch()

            enc_batch, enc_lens, enc_padding_mask, enc_batch_extend_vocab, extra_zeros, context = self.getEncData(batch)

            enc_batch = self.model.embeds(enc_batch)  # Get embeddings for encoder input
            enc_out, enc_hidden = self.model.encoder(enc_batch, enc_lens)
            # print(enc_out.shape)
            # Pass the batch through the teacher-forced MLE step
            mle_loss = self.trainMLEStep(batch)
            sample_sents, RL_log_probs = self.trainRLStep(enc_out, enc_hidden, enc_padding_mask, context, extra_zeros, enc_batch_extend_vocab, batch.art_oovs,"sample")
            with torch.autograd.no_grad():
                # greedy sampling
                greedy_sents, _ = self.trainRLStep(enc_out, enc_hidden, enc_padding_mask, context, extra_zeros, enc_batch_extend_vocab, batch.art_oovs, "greedy")

            sample_reward = self.rewardFunction(sample_sents, batch.original_abstracts)
            baseline_reward = self.rewardFunction(greedy_sents, batch.original_abstracts)
            # Eq. 15: self-critical loss, reward difference times the sampled log-probs
            rl_loss = -(sample_reward - baseline_reward) * RL_log_probs
            rl_loss = torch.mean(rl_loss)

            batch_reward = torch.mean(sample_reward).item()
            loss = LAMBDA * mle_loss + (1 - LAMBDA) * rl_loss
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # accumulate the batch loss
            total_loss += loss.item()
            # accumulate the batch reward
            total_reward += batch_reward
            count += 1
            iter += 1

            if iter % PRINT_PER_ITER == 0:


                loss_avg = total_loss / count
                reward_avg = total_reward / count
                total_loss = 0
                print("iter:", iter, "loss:", "%.3f" % loss_avg ,"reward:", "%.3f" % reward_avg)
                count = 0


            if iter % TEST_PER_ITER == 0:
                self.test()
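
The comments in trainRLStep and train above refer to "Eq. 15", the standard self-critical policy-gradient objective; as a sketch (the notation is assumed rather than taken from the source), the quantities computed by the code correspond to

    L_{rl} = -\bigl(r(y^{s}) - r(y^{g})\bigr)\,\frac{1}{T'}\sum_{t=1}^{T'} \log p\bigl(y^{s}_{t} \mid y^{s}_{<t}, x\bigr)
    L = \lambda\, L_{mle} + (1 - \lambda)\, L_{rl}

where y^{s} is the sampled summary, y^{g} the greedy (baseline) summary, r(.) the ROUGE-L F1 reward returned by rewardFunction, and T' the number of unmasked decoding steps.
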
Ejemplo n.º 27
0
class BeamSearch(object):
    def __init__(self, model_file_path):
        model_name = os.path.basename(model_file_path)
        self._decode_dir = os.path.join(config.log_root, 'decode_%s' % (model_name))
        self._rouge_ref_dir = os.path.join(self._decode_dir, 'rouge_ref')
        self._rouge_dec_dir = os.path.join(self._decode_dir, 'rouge_dec_dir')
        for p in [self._decode_dir, self._rouge_ref_dir, self._rouge_dec_dir]:
            if not os.path.exists(p):
                os.mkdir(p)

        self.vocab = Vocab(config.vocab_path, config.vocab_size)
        self.batcher = Batcher(config.decode_data_path, self.vocab, mode='decode',
                               batch_size=config.beam_size, single_pass=True)
        time.sleep(15)

        self.model = Model(model_file_path, is_eval=True)

    def sort_beams(self, beams):
        return sorted(beams, key=lambda h: h.avg_log_prob, reverse=True)


    def decode(self):
        start = time.time()
        counter = 0
        batch = self.batcher.next_batch()
        while batch is not None:
            # Run beam search to get best Hypothesis
            with torch.no_grad():
                best_summary = self.beam_search(batch)

            # Extract the output ids from the hypothesis and convert back to words
            output_ids = [int(t) for t in best_summary.tokens[1:]]
            decoded_words = data.outputids2words(output_ids, self.vocab,
                                                 (batch.art_oovs[0] if config.pointer_gen else None))

            # Remove the [STOP] token from decoded_words, if necessary
            try:
                fst_stop_idx = decoded_words.index(data.STOP_DECODING)
                decoded_words = decoded_words[:fst_stop_idx]
            except ValueError:
                decoded_words = decoded_words

            print("===============SUMMARY=============")
            print(' '.join(decoded_words))

            original_abstract_sents = batch.original_abstracts_sents[0]

            write_for_rouge(original_abstract_sents, decoded_words, counter,
                            self._rouge_ref_dir, self._rouge_dec_dir)
            counter += 1
            if counter % 1000 == 0:
                print('%d example in %d sec'%(counter, time.time() - start))
                start = time.time()

            batch = self.batcher.next_batch()

        print("Decoder has finished reading dataset for single_pass.")
        print("Now starting ROUGE eval...")
        results_dict = rouge_eval(self._rouge_ref_dir, self._rouge_dec_dir)
        rouge_log(results_dict, self._decode_dir)


    def beam_search(self, batch):
        #batch should have only one example
        enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, c_t_0, coverage_t_0, sent_lens = \
            get_input_from_batch(batch, use_cuda)

        encoder_outputs, encoder_hidden, max_encoder_output = self.model.encoder(enc_batch, enc_lens)
        s_t_0 = self.model.reduce_state(encoder_hidden)
        if config.use_maxpool_init_ctx:
            c_t_0 = max_encoder_output

        gamma = None
        if config.is_sentence_filtering:
            gamma, sent_dists = self.model.sentence_filterer(encoder_outputs, sent_lens)

        section_outputs, section_hidden = self.model.section_encoder(s_t_0)
        s_t_0 = self.model.section_reduce_state(section_hidden)

        dec_h, dec_c = s_t_0 # 1 x 2*hidden_size
        dec_h = dec_h.squeeze()
        dec_c = dec_c.squeeze()

        #decoder batch preparation, it has beam_size example initially everything is repeated
        beams = [Beam(tokens=[self.vocab.word2id(data.START_DECODING)],
                      log_probs=[0.0],
                      state=(dec_h[0], dec_c[0]),
                      context = c_t_0[0],
                      coverage=(coverage_t_0[0] if config.is_coverage else None))
                 for _ in range(config.beam_size)]
        results = []
        steps = 0
        while steps < config.max_dec_steps and len(results) < config.beam_size:
            latest_tokens = [h.latest_token for h in beams]
            latest_tokens = [t if t < self.vocab.size() else self.vocab.word2id(data.UNKNOWN_TOKEN) \
                             for t in latest_tokens]
            y_t_1 = Variable(torch.LongTensor(latest_tokens))
            if use_cuda:
                y_t_1 = y_t_1.cuda()
            all_state_h =[]
            all_state_c = []

            all_context = []

            for h in beams:
                state_h, state_c = h.state
                all_state_h.append(state_h)
                all_state_c.append(state_c)

                all_context.append(h.context)

            s_t_1 = (torch.stack(all_state_h, 0).unsqueeze(0), torch.stack(all_state_c, 0).unsqueeze(0))
            c_t_1 = torch.stack(all_context, 0)

            coverage_t_1 = None
            if config.is_coverage:
                all_coverage = []
                for h in beams:
                    all_coverage.append(h.coverage)
                coverage_t_1 = torch.stack(all_coverage, 0)

            final_dist, s_t, c_t, attn_dist, p_gen, coverage_t = self.model.decoder(y_t_1, s_t_1,
                                                        encoder_outputs, section_outputs, enc_padding_mask,
                                                        c_t_1, extra_zeros, enc_batch_extend_vocab, coverage_t_1, gamma)

            topk_log_probs, topk_ids = torch.topk(final_dist, config.beam_size * 2)

            dec_h, dec_c = s_t
            dec_h = dec_h.squeeze()
            dec_c = dec_c.squeeze()

            all_beams = []
            num_orig_beams = 1 if steps == 0 else len(beams)
            for i in range(num_orig_beams):
                h = beams[i]
                state_i = (dec_h[i], dec_c[i])
                context_i = c_t[i]
                coverage_i = (coverage_t[i] if config.is_coverage else None)

                for j in range(config.beam_size * 2):  # for each of the top 2*beam_size hyps:
                    new_beam = h.extend(token=topk_ids[i, j].item(),
                                   log_prob=topk_log_probs[i, j].item(),
                                   state=state_i,
                                   context=context_i,
                                   coverage=coverage_i)
                    all_beams.append(new_beam)

            beams = []
            for h in self.sort_beams(all_beams):
                if h.latest_token == self.vocab.word2id(data.STOP_DECODING):
                    if steps >= config.min_dec_steps:
                        results.append(h)
                else:
                    beams.append(h)
                if len(beams) == config.beam_size or len(results) == config.beam_size:
                    break

            steps += 1

        if len(results) == 0:
            results = beams

        beams_sorted = self.sort_beams(results)

        return beams_sorted[0]
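
A hypothetical entry point for the BeamSearch decoder above (taking the checkpoint path from the command line is an assumption):

if __name__ == '__main__':
    import sys

    beam_processor = BeamSearch(sys.argv[1])  # path to a saved model checkpoint
    beam_processor.decode()
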
Ejemplo n.º 28
0
class Train(object):
    def __init__(self):
        self.vocab = Vocab(config.vocab_path, config.vocab_size)
        self.concept_vocab = Concept_vocab(config.concept_vocab_path,
                                           config.vocab_size)
        self.batcher = Batcher(config.train_data_path,
                               self.vocab,
                               self.concept_vocab,
                               mode='train',
                               batch_size=config.batch_size,
                               single_pass=False)
        self.ds_batcher = Batcher(config.train_ds_data_path,
                                  self.vocab,
                                  self.concept_vocab,
                                  mode='train',
                                  batch_size=500,
                                  single_pass=False)
        time.sleep(15)

        train_dir = os.path.join(config.log_root,
                                 'train_%d' % (int(time.time())))
        if not os.path.exists(train_dir):
            os.mkdir(train_dir)

        self.model_dir = os.path.join(train_dir, 'model')
        if not os.path.exists(self.model_dir):
            os.mkdir(self.model_dir)

        self.summary_writer = tf.summary.FileWriter(train_dir)

    def save_model(self, running_avg_loss, iter):
        state = {
            'iter': iter,
            'encoder_state_dict': self.model.encoder.state_dict(),
            'decoder_state_dict': self.model.decoder.state_dict(),
            'reduce_state_dict': self.model.reduce_state.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'current_loss': running_avg_loss
        }
        model_save_path = os.path.join(
            self.model_dir, 'model_%d_%d' % (iter, int(time.time())))
        torch.save(state, model_save_path)

    def setup_train(self, model_file_path=None):
        self.model = Model(model_file_path)

        params = list(self.model.encoder.parameters()) + list(self.model.decoder.parameters()) + \
                 list(self.model.reduce_state.parameters())
        initial_lr = config.lr_coverage if config.is_coverage else config.lr
        self.optimizer = AdagradCustom(
            params,
            lr=initial_lr,
            initial_accumulator_value=config.adagrad_init_acc)

        start_iter, start_loss = 0, 0

        if model_file_path is not None:
            state = torch.load(model_file_path,
                               map_location=lambda storage, location: storage)
            start_iter = state['iter']
            start_loss = state['current_loss']

            if not config.is_coverage:
                self.optimizer.load_state_dict(state['optimizer'])
                if use_cuda:
                    for state in self.optimizer.state.values():
                        for k, v in state.items():
                            if torch.is_tensor(v):
                                state[k] = v.cuda()

        return start_iter, start_loss

    def calc_Rouge_1(self, sub, string):
        new_sub = [str(x) for x in sub]
        new_sub.insert(0, '"')
        new_sub.append('"')
        token_c = ' '.join(new_sub)
        summary = [[token_c]]
        new_string = [str(x) for x in string]
        new_string.insert(0, '"')
        new_string.append('"')
        token_r = ' '.join(new_string)
        reference = [[[token_r]]]

        rouge = Pythonrouge(summary_file_exist=False,
                            summary=summary,
                            reference=reference,
                            n_gram=2,
                            ROUGE_SU4=False,
                            ROUGE_L=False,
                            recall_only=True,
                            stemming=True,
                            stopwords=True,
                            word_level=True,
                            length_limit=True,
                            length=30,
                            use_cf=False,
                            cf=95,
                            scoring_formula='average',
                            resampling=False,
                            samples=10,
                            favor=True,
                            p=0.5)
        score = rouge.calc_score()
        return score['ROUGE-1']

    def calc_Rouge_2_recall(self, sub, string):
        token_c = sub
        token_r = string
        model = []
        ref = []
        if len(string) < 2 or len(sub) < 2:  # no bigram overlap is possible
            score = 0.0
        else:
            i = 1
            while i < len(string):
                ref.append(str(token_r[i - 1]) + str(token_r[i]))
                i += 1
            i = 1
            while i < len(sub):
                model.append(str(token_c[i - 1]) + str(token_c[i]))
                i += 1
            sam = 0
            i = 0
            for i in range(len(ref)):
                for j in range(len(model)):
                    if ref[i] == model[j]:
                        sam += 1
                        model[j] = '-1'
                        break

            score = sam / float(len(ref))

        return score

    def calc_Rouge_L(self, sub, string):
        beta = 1.
        token_c = sub
        token_r = string
        if (len(string) < len(sub)):
            sub, string = string, sub
        lengths = [[0 for i in range(0,
                                     len(sub) + 1)]
                   for j in range(0,
                                  len(string) + 1)]
        for j in range(1, len(sub) + 1):
            for i in range(1, len(string) + 1):
                if (string[i - 1] == sub[j - 1]):
                    lengths[i][j] = lengths[i - 1][j - 1] + 1
                else:
                    lengths[i][j] = max(lengths[i - 1][j], lengths[i][j - 1])
        lcs = lengths[len(string)][len(sub)]

        prec = lcs / float(len(token_c))
        rec = lcs / float(len(token_r))

        if (prec != 0 and rec != 0):
            score = ((1 + beta**2) * prec * rec) / float(rec + beta**2 * prec)
        else:
            score = 0.0
        return rec
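
    # Worked example (hypothetical inputs): with sub = ['a', 'b', 'c'] and
    # string = ['a', 'c'], the LCS length is 2, so prec = 2/3 and
    # rec = 2/2 = 1.0; note that the F-score above is computed but it is the
    # recall that is returned.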

    def calc_kl(self, dec, enc):
        kl = 0.
        dec = np.exp(dec)
        enc = np.exp(enc)
        all_dec = np.sum(dec)
        all_enc = np.sum(enc)
        for d, c in zip(dec, enc):
            d = d / all_dec
            c = c / all_enc
            kl = kl + c * np.log(c / d)
        return kl

    def calc_euc(self, dec, enc):
        euc = 0.
        for d, c in zip(dec, enc):
            euc = euc + np.sqrt(np.square(d - c))
        #print euc
        return euc

    def ds_loss(self, enc_batch_ds_emb, enc_padding_mask_ds, dec_batch_emb,
                dec_padding_mask):
        b1, t_k1, emb1 = list(enc_batch_ds_emb.size())
        b2, t_k2, emb2 = list(dec_batch_emb.size())
        enc_padding_mask_ds = enc_padding_mask_ds.unsqueeze(2).expand(
            b1, t_k1, emb1).contiguous()
        dec_padding_mask = dec_padding_mask.unsqueeze(2).expand(
            b2, t_k2, emb2).contiguous()
        enc_batch_ds_emb = enc_batch_ds_emb * enc_padding_mask_ds
        dec_batch_emb = dec_batch_emb * dec_padding_mask
        enc_batch_ds_emb = torch.sum(enc_batch_ds_emb, 1)
        dec_batch_emb = torch.sum(dec_batch_emb, 1)
        dec_title = dec_batch_emb.tolist()
        enc_article = enc_batch_ds_emb.tolist()
        dec_title_len = len(dec_title)
        enc_article_len = len(enc_article)
        dsloss = 0.
        for dec in dec_title:
            for enc in enc_article:
                dsloss = dsloss + self.calc_kl(dec, enc)
        dsloss = dsloss / float(dec_title_len * enc_article_len)
        print(dsloss)
        return dsloss

    def train_one_batch(self, batch, steps, batch_ds):
        enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, enc_batch_concept_extend_vocab, concept_p, position, concept_mask, extra_zeros, c_t_1, coverage = \
            get_input_from_batch(batch, use_cuda)
        dec_batch, dec_padding_mask, max_dec_len, dec_lens_var, target_batch = \
            get_output_from_batch(batch, use_cuda)
        enc_batch_ds, enc_padding_mask_ds, enc_lens_ds, _, _, _, _, _, _, _, _ = \
            get_input_from_batch(batch_ds, use_cuda)

        self.optimizer.zero_grad()
        encoder_outputs, encoder_hidden, max_encoder_output, enc_batch_ds_emb, dec_batch_emb = self.model.encoder(
            enc_batch, enc_lens, enc_batch_ds, dec_batch)
        if config.DS_train:
            ds_final_loss = self.ds_loss(enc_batch_ds_emb, enc_padding_mask_ds,
                                         dec_batch_emb, dec_padding_mask)
        s_t_1 = self.model.reduce_state(encoder_hidden)
        s_t_0 = s_t_1
        c_t_0 = c_t_1
        if config.use_maxpool_init_ctx:
            c_t_1 = max_encoder_output
            c_t_0 = c_t_1

        step_losses = []
        for di in range(min(max_dec_len, config.max_dec_steps)):
            y_t_1 = dec_batch[:, di]
            final_dist, s_t_1, c_t_1, attn_dist, p_gen, next_coverage = self.model.decoder(
                'train', y_t_1, s_t_1, encoder_outputs, enc_padding_mask,
                c_t_1, extra_zeros, enc_batch_extend_vocab,
                enc_batch_concept_extend_vocab, concept_p, position,
                concept_mask, coverage, di)
            target = target_batch[:, di]
            gold_probs = torch.gather(final_dist, 1,
                                      target.unsqueeze(1)).squeeze()
            step_loss = -torch.log(gold_probs + config.eps)
            if config.is_coverage:
                step_coverage_loss = torch.sum(torch.min(attn_dist, coverage),
                                               1)
                step_loss = step_loss + config.cov_loss_wt * step_coverage_loss
                coverage = next_coverage

            step_mask = dec_padding_mask[:, di]
            step_loss = step_loss * step_mask
            step_losses.append(step_loss)

        sum_losses = torch.sum(torch.stack(step_losses, 1), 1)
        batch_avg_loss = sum_losses / dec_lens_var
        if config.DS_train:
            ds_final_loss = Variable(torch.FloatTensor([ds_final_loss]),
                                     requires_grad=False)
            ds_final_loss = ds_final_loss.cuda()
            loss = (config.pi - ds_final_loss) * torch.mean(batch_avg_loss)
        else:
            loss = torch.mean(batch_avg_loss)
        if steps > config.traintimes:
            scores = []
            sample_y = []
            s_t_1 = s_t_0
            c_t_1 = c_t_0
            for di in range(min(max_dec_len, config.max_dec_steps)):
                if di == 0:
                    y_t_1 = dec_batch[:, di]
                    sample_y.append(y_t_1.cpu().numpy().tolist())
                else:
                    sample_latest_tokens = sample_y[-1]
                    sample_latest_tokens = [t if t < self.vocab.size() else self.vocab.word2id(data.UNKNOWN_TOKEN) \
                                            for t in sample_latest_tokens]

                    y_t_1 = Variable(torch.LongTensor(sample_latest_tokens))
                    y_t_1 = y_t_1.cuda()

                final_dist, s_t_1, c_t_1, attn_dist, p_gen, next_coverage = self.model.decoder(
                    'train', y_t_1, s_t_1, encoder_outputs, enc_padding_mask,
                    c_t_1, extra_zeros, enc_batch_extend_vocab,
                    enc_batch_concept_extend_vocab, concept_p, position,
                    concept_mask, coverage, di)
                sample_select = torch.multinomial(final_dist, 1).view(-1)
                sample_log_probs = torch.gather(
                    final_dist, 1, sample_select.unsqueeze(1)).squeeze()
                sample_y.append(sample_select.cpu().numpy().tolist())
                sample_step_loss = -torch.log(sample_log_probs + config.eps)
                sample_step_mask = dec_padding_mask[:, di]
                sample_step_loss = sample_step_loss * sample_step_mask
                scores.append(sample_step_loss)
            sample_sum_losses = torch.sum(torch.stack(scores, 1), 1)
            sample_batch_avg_loss = sample_sum_losses / dec_lens_var

            sample_y = np.transpose(sample_y).tolist()

            base_y = []
            s_t_1 = s_t_0
            c_t_1 = c_t_0
            for di in range(min(max_dec_len, config.max_dec_steps)):
                if di == 0:
                    y_t_1 = dec_batch[:, di]
                    base_y.append(y_t_1.cpu().numpy().tolist())
                else:
                    base_latest_tokens = base_y[-1]
                    base_latest_tokens = [t if t < self.vocab.size() else self.vocab.word2id(data.UNKNOWN_TOKEN) \
                                            for t in base_latest_tokens]

                    y_t_1 = Variable(torch.LongTensor(base_latest_tokens))
                    y_t_1 = y_t_1.cuda()

                final_dist, s_t_1, c_t_1, attn_dist, p_gen, next_coverage = self.model.decoder(
                    'train', y_t_1, s_t_1, encoder_outputs, enc_padding_mask,
                    c_t_1, extra_zeros, enc_batch_extend_vocab,
                    enc_batch_concept_extend_vocab, concept_p, position,
                    concept_mask, coverage, di)
                base_log_probs, base_ids = torch.topk(final_dist, 1)
                base_y.append(base_ids[:, 0].cpu().numpy().tolist())

            base_y = np.transpose(base_y).tolist()

            refs = dec_batch.cpu().numpy().tolist()
            sample_dec_lens_var = list(map(int, dec_lens_var.cpu().numpy().tolist()))  # list: it is iterated twice below
            sample_rougeL = [
                self.calc_Rouge_L(sample[:reflen],
                                  ref[:reflen]) for sample, ref, reflen in zip(
                                      sample_y, refs, sample_dec_lens_var)
            ]
            base_rougeL = [
                self.calc_Rouge_L(base[:reflen], ref[:reflen])
                for base, ref, reflen in zip(base_y, refs, sample_dec_lens_var)
            ]
            sample_rougeL = Variable(torch.FloatTensor(sample_rougeL),
                                     requires_grad=False)
            base_rougeL = Variable(torch.FloatTensor(base_rougeL),
                                   requires_grad=False)
            sample_rougeL = sample_rougeL.cuda()
            base_rougeL = base_rougeL.cuda()
            word_loss = -sample_batch_avg_loss * (base_rougeL - sample_rougeL)
            reinforce_loss = torch.mean(word_loss)
            loss = (1 - config.rein) * loss + config.rein * reinforce_loss

        loss.backward()

        clip_grad_norm(self.model.encoder.parameters(), config.max_grad_norm)
        clip_grad_norm(self.model.decoder.parameters(), config.max_grad_norm)
        clip_grad_norm(self.model.reduce_state.parameters(),
                       config.max_grad_norm)

        self.optimizer.step()

        return loss.data[0]

    def trainIters(self, n_iters, model_file_path=None):
        iter, running_avg_loss = self.setup_train(model_file_path)
        start = time.time()
        while iter < n_iters:
            batch = self.batcher.next_batch()
            batch_ds = self.ds_batcher.next_batch()
            loss = self.train_one_batch(batch, iter, batch_ds)
            loss = loss.cpu()

            running_avg_loss = calc_running_avg_loss(loss, running_avg_loss,
                                                     self.summary_writer, iter)
            iter += 1

            if iter % 100 == 0:
                self.summary_writer.flush()
            print_interval = 5
            if iter % print_interval == 0:
                print('steps %d , loss: %f' % (iter, loss))
                start = time.time()
            if iter % 50000 == 0:
                self.save_model(running_avg_loss, iter)
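
A hypothetical launcher for the Train class above (config.max_iterations is an assumed config field; pass a checkpoint path to resume from a saved model):

if __name__ == '__main__':
    trainer = Train()
    trainer.trainIters(config.max_iterations, model_file_path=None)
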
Ejemplo n.º 29
0
class BeamSearch(object):
    def __init__(self, model_file_path):
        model_name = os.path.basename(model_file_path)
        self._decode_dir = os.path.join(config.log_root,
                                        'decode_%s' % (model_name))
        self._rouge_ref_dir = os.path.join(self._decode_dir, 'rouge_ref')
        self._rouge_dec_dir = os.path.join(self._decode_dir, 'rouge_dec_dir')
        for p in [self._decode_dir, self._rouge_ref_dir, self._rouge_dec_dir]:
            if not os.path.exists(p):
                os.mkdir(p)

        self.vocab = Vocab(config.vocab_path, config.vocab_size)
        '''self.batcher = Batcher(config.oped_data_path, self.vocab, mode='decode',
                               batch_size=config.beam_size, single_pass=True)'''
        self.batches = self.read_opeds(config.oped_data_path, self.vocab,
                                       config.beam_size)

        self.model = Model(model_file_path, is_eval=True)

    def read_opeds(self, config_path, vocab, beam_size):
        file_list = glob.glob(config_path)
        #file_list = os.listdir(config_path)
        batch_list = []
        for file in file_list:
            with open(file, 'rb') as f:
                text = f.read().lower().decode('utf-8')
                text = re.sub('\n', '', text)
                text = re.sub(r'([.,!?()"])', r' \1 ', text).encode('utf-8')
                print(text)

                ex = Example(text, [], vocab)

                # text = text.split()
                # if len(text) > config.max_enc_steps:
                #     text = text[:config.max_enc_steps]
                # enc_input = [vocab.word2id(w.decode('utf-8')) for w in text]
                # assert(sum(enc_input) != 0)

                enc_input = [ex for _ in range(beam_size)]
                batch = Batch(enc_input, vocab, beam_size)
                batch_list.append(batch)
                print(batch.enc_batch)
        return batch_list

    def sort_beams(self, beams):
        return sorted(beams, key=lambda h: h.avg_log_prob, reverse=True)

    def decode(self):
        start = time.time()
        counter = 0
        for batch in self.batches:
            # Run beam search to get best Hypothesis
            best_summary = self.beam_search(batch)

            # Extract the output ids from the hypothesis and convert back to words
            output_ids = [int(t) for t in best_summary.tokens[1:]]
            decoded_words = data.outputids2words(
                output_ids, self.vocab,
                (batch.art_oovs[0] if config.pointer_gen else None))

            # Remove the [STOP] token from decoded_words, if necessary
            try:
                fst_stop_idx = decoded_words.index(data.STOP_DECODING)
                decoded_words = decoded_words[:fst_stop_idx]
            except ValueError:
                decoded_words = decoded_words

            write_results(decoded_words, counter, self._rouge_dec_dir)
            counter += 1
            if counter % 1000 == 0:
                print('%d example in %d sec' % (counter, time.time() - start))
                start = time.time()
        '''print("Decoder has finished reading dataset for single_pass.")
        print("Now starting ROUGE eval...")
        results_dict = rouge_eval(self._rouge_ref_dir, self._rouge_dec_dir)
        rouge_log(results_dict, self._decode_dir)'''

    def beam_search(self, batch):
        #batch should have only one example
        enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, c_t_0, coverage_t_0 = \
            get_input_from_batch(batch, use_cuda)

        encoder_outputs, encoder_feature, encoder_hidden = self.model.encoder(
            enc_batch, enc_lens)
        s_t_0 = self.model.reduce_state(encoder_hidden)

        dec_h, dec_c = s_t_0  # 1 x 2*hidden_size
        dec_h = dec_h.squeeze()
        dec_c = dec_c.squeeze()

        #decoder batch preparation, it has beam_size example initially everything is repeated
        beams = [
            Beam(tokens=[self.vocab.word2id(data.START_DECODING)],
                 log_probs=[0.0],
                 state=(dec_h[0], dec_c[0]),
                 context=c_t_0[0],
                 coverage=(coverage_t_0[0] if config.is_coverage else None))
            for _ in range(config.beam_size)
        ]
        results = []
        steps = 0
        while steps < config.max_dec_steps and len(results) < config.beam_size:
            latest_tokens = [h.latest_token for h in beams]
            latest_tokens = [t if t < self.vocab.size() else self.vocab.word2id(data.UNKNOWN_TOKEN) \
                             for t in latest_tokens]
            y_t_1 = Variable(torch.LongTensor(latest_tokens))
            if use_cuda:
                y_t_1 = y_t_1.cuda()
            all_state_h = []
            all_state_c = []

            all_context = []

            for h in beams:
                state_h, state_c = h.state
                all_state_h.append(state_h)
                all_state_c.append(state_c)

                all_context.append(h.context)

            s_t_1 = (torch.stack(all_state_h,
                                 0).unsqueeze(0), torch.stack(all_state_c,
                                                              0).unsqueeze(0))
            c_t_1 = torch.stack(all_context, 0)

            coverage_t_1 = None
            if config.is_coverage:
                all_coverage = []
                for h in beams:
                    all_coverage.append(h.coverage)
                coverage_t_1 = torch.stack(all_coverage, 0)

            final_dist, s_t, c_t, attn_dist, p_gen, coverage_t = self.model.decoder(
                y_t_1, s_t_1, encoder_outputs, encoder_feature,
                enc_padding_mask, c_t_1, extra_zeros, enc_batch_extend_vocab,
                coverage_t_1, steps)
            log_probs = torch.log(final_dist)
            topk_log_probs, topk_ids = torch.topk(log_probs,
                                                  config.beam_size * 2)

            dec_h, dec_c = s_t
            dec_h = dec_h.squeeze()
            dec_c = dec_c.squeeze()

            all_beams = []
            num_orig_beams = 1 if steps == 0 else len(beams)
            for i in range(num_orig_beams):
                h = beams[i]
                state_i = (dec_h[i], dec_c[i])
                context_i = c_t[i]
                coverage_i = (coverage_t[i] if config.is_coverage else None)

                for j in range(config.beam_size *
                               2):  # for each of the top 2*beam_size hyps:
                    new_beam = h.extend(token=topk_ids[i, j].item(),
                                        log_prob=topk_log_probs[i, j].item(),
                                        state=state_i,
                                        context=context_i,
                                        coverage=coverage_i)
                    all_beams.append(new_beam)

            beams = []
            for h in self.sort_beams(all_beams):
                if h.latest_token == self.vocab.word2id(data.STOP_DECODING):
                    if steps >= config.min_dec_steps:
                        results.append(h)
                else:
                    beams.append(h)
                if len(beams) == config.beam_size or len(
                        results) == config.beam_size:
                    break

            steps += 1

        if len(results) == 0:
            results = beams

        beams_sorted = self.sort_beams(results)

        return beams_sorted[0]
Ejemplo n.º 30
0
class Train(object):
    def __init__(self):
        self.vocab = Vocab(config.vocab_path, config.vocab_size)
        self.batcher = Batcher(config.train_data_path, self.vocab, mode='train',
                               batch_size=config.batch_size, single_pass=False)
        time.sleep(15)

        train_dir = os.path.join(config.log_root, 'train_%d' % (int(time.time())))
        if not os.path.exists(train_dir):
            os.mkdir(train_dir)

        self.model_dir = os.path.join(train_dir, 'model')
        if not os.path.exists(self.model_dir):
            os.mkdir(self.model_dir)

        self.summary_writer = tf.summary.FileWriter(train_dir)
        self.last_good_model_save_path = None 

    def save_model(self, running_avg_loss, iter):
        state = {
            'iter': iter,
            'encoder_state_dict': self.model.encoder.state_dict(),
            'decoder_state_dict': self.model.decoder.state_dict(),
            'reduce_state_dict': self.model.reduce_state.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'current_loss': running_avg_loss
        }
        model_save_path = os.path.join(self.model_dir, 'model_%d_%d' % (iter, int(time.time())))
        # save the path to the last model that was not nan
        if (not math.isnan(running_avg_loss)):
            self.last_good_model_save_path = model_save_path 
        torch.save(state, model_save_path)

    def setup_train(self, model_file_path=None):
        self.model = Model(model_file_path)
        self.last_good_model_save_path = model_file_path

        params = list(self.model.encoder.parameters()) + list(self.model.decoder.parameters()) + \
                 list(self.model.reduce_state.parameters())
        initial_lr = config.lr_coverage if config.is_coverage else config.lr
        self.optimizer = Adagrad(params, lr=initial_lr, initial_accumulator_value=config.adagrad_init_acc)

        start_iter, start_loss = 0, 0

        if model_file_path is not None:
            state = torch.load(model_file_path, map_location= lambda storage, location: storage)
            start_iter = state['iter']
            start_loss = state['current_loss']

            if not config.is_coverage:
                self.optimizer.load_state_dict(state['optimizer'])
                if use_cuda:
                    for state in self.optimizer.state.values():
                        for k, v in state.items():
                            if torch.is_tensor(v):
                                state[k] = v.cuda()

        return start_iter, start_loss

    def train_one_batch(self, batch):
        enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, c_t_1, coverage = \
            get_input_from_batch(batch, use_cuda)
        dec_batch, dec_padding_mask, max_dec_len, dec_lens_var, target_batch = \
            get_output_from_batch(batch, use_cuda)

        self.optimizer.zero_grad()

        encoder_outputs, encoder_feature, encoder_hidden = self.model.encoder(enc_batch, enc_lens)
        s_t_1 = self.model.reduce_state(encoder_hidden)

        step_losses = []
        for di in range(min(max_dec_len, config.max_dec_steps)):
            y_t_1 = dec_batch[:, di]  # Teacher forcing
            final_dist, s_t_1, c_t_1, attn_dist, p_gen, next_coverage = \
                self.model.decoder(y_t_1, s_t_1, encoder_outputs, encoder_feature,
                                   enc_padding_mask, c_t_1, extra_zeros,
                                   enc_batch_extend_vocab, coverage, di)

            target = target_batch[:, di]
            gold_probs = torch.gather(final_dist, 1, target.unsqueeze(1)).squeeze()
            step_loss = -torch.log(gold_probs + config.eps)
            if config.is_coverage:
                step_coverage_loss = torch.sum(torch.min(attn_dist, coverage), 1)
                step_loss = step_loss + config.cov_loss_wt * step_coverage_loss
                coverage = next_coverage

            # Copy loss: project the copy distribution (1 - p_gen) * attn_dist onto the
            # extended vocabulary, then penalise the copy mass assigned to every index
            # except [UNK] (y_unk_neg zeroes out the [UNK] column of the mask).
            vocab_zero = torch.zeros(self.model.decoder.vocab_dist_.shape, dtype=torch.float)
            if use_cuda:
                vocab_zero = vocab_zero.cuda()
            if extra_zeros is not None:
                vocab_zero = torch.cat([vocab_zero, extra_zeros], 1)
            attn_dist_ = (1 - p_gen) * attn_dist
            attn_expanded = vocab_zero.scatter_add(1, enc_batch_extend_vocab, attn_dist_)
            vocab_zero[:, self.vocab.word2id('[UNK]')] = 1.0
            # Not sure whether we want to add loss for the extra vocab indices
            #vocab_zero[:, config.vocab_size:] = 1.0
            y_unk_neg = 1.0 - vocab_zero
            # The batched dot product yields shape (batch, 1, 1); flatten it to (batch,)
            # so the addition below does not broadcast against step_loss.
            copy_loss = torch.bmm(y_unk_neg.unsqueeze(1), attn_expanded.unsqueeze(2)).view(-1)

            # add copy loss with lambda_2 weight
            step_loss = step_loss + config.copy_loss_wt * copy_loss

            step_mask = dec_padding_mask[:, di]
            step_loss = step_loss * step_mask
            step_losses.append(step_loss)

        sum_losses = torch.sum(torch.stack(step_losses, 1), 1)
        batch_avg_loss = sum_losses/dec_lens_var
        loss = torch.mean(batch_avg_loss)

        loss.backward()

        # Clip gradients of every sub-module; store the encoder's gradient norm for later inspection.
        self.norm = clip_grad_norm_(self.model.encoder.parameters(), config.max_grad_norm)
        clip_grad_norm_(self.model.decoder.parameters(), config.max_grad_norm)
        clip_grad_norm_(self.model.reduce_state.parameters(), config.max_grad_norm)

        self.optimizer.step()

        return loss.item()

    def trainIters(self, n_iters, model_file_path=None):
        iter, running_avg_loss = self.setup_train(model_file_path)
        start = time.time()
        while iter < n_iters:
            batch = self.batcher.next_batch()
            loss = self.train_one_batch(batch)

            running_avg_loss = calc_running_avg_loss(loss, running_avg_loss, self.summary_writer, iter)
            iter += 1

            if math.isnan(running_avg_loss):
                # The loss diverged; roll back to the last checkpoint that had a finite loss.
                print('Running average loss is NaN. Restarting training from {}'
                      .format(self.last_good_model_save_path))
                iter, running_avg_loss = self.setup_train(self.last_good_model_save_path)
                start = time.time()

            if iter % 100 == 0:
                self.summary_writer.flush()
            print_interval = 1000
            if iter % print_interval == 0:
                print('steps %d, seconds for %d batch: %.2f , loss: %f' % (iter, print_interval,
                                                                           time.time() - start, loss))
                start = time.time()
            if iter % 1000 == 0:
                self.save_model(running_avg_loss, iter)
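A driver for this Train class is not included in the example. The snippet below is a minimal sketch of how it might be invoked; the --model_file_path flag and the config.max_iterations field are assumptions, not part of the code above.

if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description='Train the pointer-generator summarizer')
    parser.add_argument('--model_file_path', default=None,
                        help='optional checkpoint to resume training from')
    args = parser.parse_args()

    trainer = Train()
    # config.max_iterations is assumed to be defined alongside the other config fields used above.
    trainer.trainIters(config.max_iterations, args.model_file_path)

Because setup_train() records the checkpoint path it was given, resuming from --model_file_path also seeds last_good_model_save_path for the NaN-restart logic in trainIters().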