Example #1
def load_batches_decode():

    vocab   = Vocab(config.vocab_path, config.vocab_size)
    batcher = Batcher(config.decode_data_path, vocab, mode='decode',
                           batch_size=config.beam_size, single_pass=True)

    batches = [None for _ in range(TEST_DATA_SIZE)]
    for i in range(TEST_DATA_SIZE):
        batch = batcher.next_batch()
        batches[i] = batch

    with open("lib/data/batches_test.vocab{}.beam{}.pk.bin".format(vocab.size(), config.beam_size), "wb") as f:
        pickle.dump(batches, f)
Example #2
def load_batches_train():

    vocab   = Vocab(config.vocab_path, config.vocab_size)
    batcher = Batcher(config.decode_data_path, vocab, mode='train',
                           batch_size=config.batch_size, single_pass=False)

    TRAIN_DATA_SIZE = 287226
    num_batches = TRAIN_DATA_SIZE // config.batch_size
    batches = [None for _ in range(num_batches)]
    for i in tqdm(range(num_batches)):
        batch = batcher.next_batch()
        batches[i] = batch

    with open("lib/data/batches_train.vocab{}.batch{}.pk.bin".format(vocab.size(), config.batch_size), "wb") as f:
        pickle.dump(batches, f)
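Both functions above build their batches once and pickle them to disk despite the "load" naming. A minimal counterpart for reading a dump back, assuming the same pickle format:

def load_pickled_batches(path):
    # Return the list of Batch objects dumped by load_batches_decode / load_batches_train.
    with open(path, "rb") as f:
        return pickle.load(f)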
Example #3
class BeamSearch(object):
    def __init__(self, model_file_path):
        model_name = os.path.basename(model_file_path)
        self._decode_dir = os.path.join(config.log_root,
                                        'decode_%s' % (model_name))
        self._rouge_ref_dir = os.path.join(self._decode_dir, 'rouge_ref')
        self._rouge_dec_dir = os.path.join(self._decode_dir, 'rouge_dec_dir')
        for p in [self._decode_dir, self._rouge_ref_dir, self._rouge_dec_dir]:
            if not os.path.exists(p):
                os.mkdir(p)

        self.vocab = Vocab(config.vocab_path, config.vocab_size)
        # self.batcher = Batcher(config.oped_data_path, self.vocab, mode='decode',
        #                        batch_size=config.beam_size, single_pass=True)
        self.batches = self.read_opeds(config.oped_data_path, self.vocab,
                                       config.beam_size)

        self.model = Model(model_file_path, is_eval=True)

    def read_opeds(self, config_path, vocab, beam_size):
        file_list = glob.glob(config_path)
        #file_list = os.listdir(config_path)
        batch_list = []
        for file in file_list:
            with open(file, 'rb') as f:
                text = f.read().lower().decode('utf-8')
                text = re.sub('\n', '', text)
                text = re.sub(r'([.,!?()"])', r' \1 ', text).encode('utf-8')
                print(text)

                ex = Example(text, [], vocab)

                # text = text.split()
                # if len(text) > config.max_enc_steps:
                #     text = text[:config.max_enc_steps]
                # enc_input = [vocab.word2id(w.decode('utf-8')) for w in text]
                # assert(sum(enc_input) != 0)

                enc_input = [ex for _ in range(beam_size)]
                batch = Batch(enc_input, vocab, beam_size)
                batch_list.append(batch)
                print(batch.enc_batch)
        return batch_list

    def sort_beams(self, beams):
        return sorted(beams, key=lambda h: h.avg_log_prob, reverse=True)

    def decode(self):
        start = time.time()
        counter = 0
        for batch in self.batches:
            # Run beam search to get best Hypothesis
            best_summary = self.beam_search(batch)

            # Extract the output ids from the hypothesis and convert back to words
            output_ids = [int(t) for t in best_summary.tokens[1:]]
            decoded_words = data.outputids2words(
                output_ids, self.vocab,
                (batch.art_oovs[0] if config.pointer_gen else None))

            # Remove the [STOP] token from decoded_words, if necessary
            try:
                fst_stop_idx = decoded_words.index(data.STOP_DECODING)
                decoded_words = decoded_words[:fst_stop_idx]
            except ValueError:
                pass  # no [STOP] token was generated

            write_results(decoded_words, counter, self._rouge_dec_dir)
            counter += 1
            if counter % 1000 == 0:
                print('%d examples in %d sec' % (counter, time.time() - start))
                start = time.time()
        '''print("Decoder has finished reading dataset for single_pass.")
        print("Now starting ROUGE eval...")
        results_dict = rouge_eval(self._rouge_ref_dir, self._rouge_dec_dir)
        rouge_log(results_dict, self._decode_dir)'''

    def beam_search(self, batch):
        #batch should have only one example
        enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, c_t_0, coverage_t_0 = \
            get_input_from_batch(batch, use_cuda)

        encoder_outputs, encoder_feature, encoder_hidden = self.model.encoder(
            enc_batch, enc_lens)
        s_t_0 = self.model.reduce_state(encoder_hidden)

        dec_h, dec_c = s_t_0  # 1 x 2*hidden_size
        dec_h = dec_h.squeeze()
        dec_c = dec_c.squeeze()

        # decoder batch preparation: start with beam_size identical hypotheses
        beams = [
            Beam(tokens=[self.vocab.word2id(data.START_DECODING)],
                 log_probs=[0.0],
                 state=(dec_h[0], dec_c[0]),
                 context=c_t_0[0],
                 coverage=(coverage_t_0[0] if config.is_coverage else None))
            for _ in range(config.beam_size)
        ]
        results = []
        steps = 0
        while steps < config.max_dec_steps and len(results) < config.beam_size:
            latest_tokens = [h.latest_token for h in beams]
            latest_tokens = [t if t < self.vocab.size() else self.vocab.word2id(data.UNKNOWN_TOKEN) \
                             for t in latest_tokens]
            y_t_1 = Variable(torch.LongTensor(latest_tokens))
            if use_cuda:
                y_t_1 = y_t_1.cuda()
            all_state_h = []
            all_state_c = []

            all_context = []

            for h in beams:
                state_h, state_c = h.state
                all_state_h.append(state_h)
                all_state_c.append(state_c)

                all_context.append(h.context)

            s_t_1 = (torch.stack(all_state_h, 0).unsqueeze(0),
                     torch.stack(all_state_c, 0).unsqueeze(0))
            c_t_1 = torch.stack(all_context, 0)

            coverage_t_1 = None
            if config.is_coverage:
                all_coverage = []
                for h in beams:
                    all_coverage.append(h.coverage)
                coverage_t_1 = torch.stack(all_coverage, 0)

            final_dist, s_t, c_t, attn_dist, p_gen, coverage_t = self.model.decoder(
                y_t_1, s_t_1, encoder_outputs, encoder_feature,
                enc_padding_mask, c_t_1, extra_zeros, enc_batch_extend_vocab,
                coverage_t_1, steps)
            log_probs = torch.log(final_dist)
            topk_log_probs, topk_ids = torch.topk(log_probs,
                                                  config.beam_size * 2)

            dec_h, dec_c = s_t
            dec_h = dec_h.squeeze()
            dec_c = dec_c.squeeze()

            all_beams = []
            num_orig_beams = 1 if steps == 0 else len(beams)
            for i in range(num_orig_beams):
                h = beams[i]
                state_i = (dec_h[i], dec_c[i])
                context_i = c_t[i]
                coverage_i = (coverage_t[i] if config.is_coverage else None)

                for j in range(config.beam_size * 2):  # for each of the top 2*beam_size hyps
                    new_beam = h.extend(token=topk_ids[i, j].item(),
                                        log_prob=topk_log_probs[i, j].item(),
                                        state=state_i,
                                        context=context_i,
                                        coverage=coverage_i)
                    all_beams.append(new_beam)

            beams = []
            for h in self.sort_beams(all_beams):
                if h.latest_token == self.vocab.word2id(data.STOP_DECODING):
                    if steps >= config.min_dec_steps:
                        results.append(h)
                else:
                    beams.append(h)
                if len(beams) == config.beam_size or len(results) == config.beam_size:
                    break

            steps += 1

        if len(results) == 0:
            results = beams

        beams_sorted = self.sort_beams(results)

        return beams_sorted[0]
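The beam search above assumes a hypothesis object that exposes tokens, log_probs, state, context, coverage, a latest_token accessor, an extend method, and the length-normalized avg_log_prob used by sort_beams. A minimal sketch consistent with that usage (an assumption, not necessarily the original class):

class Beam(object):
    def __init__(self, tokens, log_probs, state, context, coverage):
        self.tokens = tokens        # token ids decoded so far
        self.log_probs = log_probs  # one log-probability per decoded token
        self.state = state          # decoder (h, c) state for this hypothesis
        self.context = context      # attention context vector
        self.coverage = coverage    # coverage vector, or None

    def extend(self, token, log_prob, state, context, coverage):
        # Return a new hypothesis extended by one decoded token.
        return Beam(tokens=self.tokens + [token],
                    log_probs=self.log_probs + [log_prob],
                    state=state, context=context, coverage=coverage)

    @property
    def latest_token(self):
        return self.tokens[-1]

    @property
    def avg_log_prob(self):
        # Length-normalized score so longer hypotheses are not unfairly penalized.
        return sum(self.log_probs) / len(self.tokens)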
Example #4
class Decoder(object):
    def __init__(self):
        self.vocab = Vocab(args.vocab_path, args.vocab_size)
        self.batcher = Batcher(
            args.decode_data_path,
            self.vocab,
            mode='decode',
            batch_size=1,
            single_pass=True)  # support only 1 item at a time
        time.sleep(15)
        vocab_size = self.vocab.size()
        self.beam_size = args.beam_size
        # self.bertClient = BertClient()
        self.encoder = EncoderLSTM(args.hidden_size, self.vocab.size())
        self.decoder = DecoderLSTM(args.hidden_size, self.vocab.size())
        if use_cuda:
            self.encoder = self.encoder.cuda()
            self.decoder = self.decoder.cuda()

        # Prepare the output folder and files
        output_dir = os.path.join(args.logs, "outputs")
        if not os.path.exists(output_dir):
            os.mkdir(output_dir)
        output_file = os.path.join(output_dir,
                                   "decoder_{}.txt".format(args.output_name))
        self.file = open(output_file, "w+")

    def load_model(self, checkpoint_file):
        print("Loading Checkpoint: ", checkpoint_file)
        checkpoint = torch.load(checkpoint_file)
        encoder_state_dict = checkpoint['encoder_state_dict']
        decoder_state_dict = checkpoint['decoder_state_dict']
        self.encoder.load_state_dict(encoder_state_dict)
        self.decoder.load_state_dict(decoder_state_dict)
        self.encoder.eval()
        self.decoder.eval()
        print("Weights Loaded")

    def decode(self):
        batch = self.batcher.next_batch()
        count = 0

        while batch is not None:
            # Preparing
            enc_batch, enc_padding_mask, enc_lens = get_input_from_batch(
                batch, use_cuda=use_cuda)
            dec_batch, target_batch, dec_padding_mask, max_dec_len = get_output_from_batch(
                batch, use_cuda=use_cuda)
            batch_size = len(enc_batch)
            proba_list = [0.0] * batch_size
            generated_list = [[] for _ in range(batch_size)]

            # Encoding sentences
            outputs, hidden_state = self.encoder(enc_batch)

            x = torch.LongTensor([2])  # start sequence
            if use_cuda:
                x = x.cuda()
            """ Normal Approach """
            #answers = torch.ones((batch_size, args.max_dec_steps), dtype=torch.long)
            #for t in range(max(args.max_dec_steps, max_dec_len)):
            #    output, hidden_state = self.decoder(x, hidden_state)  # Output: batch * vocab_size (prob.)
            #    idx = torch.argmax(output, dim=1)
            #    answers[:, t] = idx.detach()
            #    x = idx
            """ Beam Approach """
            for t in range(max(args.max_dec_steps, max_dec_len) - 1):
                output, hidden_state = self.decoder(x, hidden_state)

                # For each sentence, find b best answers (beam search)
                states = []  # (probab, generated, hidden_state_index)
                for i, each_decode in enumerate(output):
                    prev_proba = proba_list[i]
                    prev_generated = generated_list[i]
                    arr = each_decode.detach().cpu().numpy()  # log-prob of each word
                    indices = arr.argsort()[-self.beam_size:][::-1]  # top beam_size ids
                    for idx in indices:
                        proba = arr[idx] + prev_proba  # extend the running sequence score
                        generated = prev_generated.copy()
                        generated.append(idx)
                        states.append((proba, generated, i))

                # Sort for the best generated sequence among all
                states.sort(key=lambda x: x[0], reverse=True)

                # Variables
                new_proba_list = []
                new_generated = []
                new_hidden = torch.Tensor()
                new_cell = torch.Tensor()
                new_x = torch.LongTensor()

                if use_cuda:
                    new_hidden = new_hidden.cuda()
                    new_cell = new_cell.cuda()
                    new_x = new_x.cuda()

                # Select top b sequences
                for state in states[:self.beam_size]:
                    new_proba_list.append(state[0])
                    new_generated.append(state[1])
                    idx = state[2]

                    h_0 = hidden_state[0].squeeze(0)[idx].unsqueeze(0)
                    c_0 = hidden_state[1].squeeze(0)[idx].unsqueeze(0)
                    new_hidden = torch.cat((new_hidden, h_0), dim=0)
                    new_cell = torch.cat((new_cell, c_0), dim=0)
                    generated_idx = torch.LongTensor([state[1][-1]])
                    if use_cuda:
                        generated_idx = generated_idx.cuda()

                    new_x = torch.cat((new_x, generated_idx))

                # Save the list
                proba_list = new_proba_list
                generated_list = new_generated
                hidden_state = (new_hidden.unsqueeze(0), new_cell.unsqueeze(0))
                x = new_x

            # Convert from id to word
            # answer = answers[0].numpy()
            answer = generated_list[0]  # best sequence after the final beam step
            sentence = ids2words(answer, self.vocab)
            self.file.write("{}\n".format(sentence))
            print("Writing line #{} to file ...".format(count + 1))
            self.file.flush()
            sys.stdout.flush()

            count += 1
            batch = self.batcher.next_batch()
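A hypothetical driver for the class above, assuming args.checkpoint_file names a saved checkpoint:

if __name__ == '__main__':
    decoder = Decoder()
    decoder.load_model(args.checkpoint_file)  # assumed flag, not part of the original code
    decoder.decode()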
Example #5
class BeamSearch(object):
    def __init__(self, model_file_path):
        model_name = os.path.basename(model_file_path)
        self._decode_dir = os.path.join(config.log_root, 'decode_%s' % (model_name))
        self._rouge_ref_dir = os.path.join(self._decode_dir, 'rouge_ref')
        self._rouge_dec_dir = os.path.join(self._decode_dir, 'rouge_dec_dir')
        for p in [self._decode_dir, self._rouge_ref_dir, self._rouge_dec_dir]:
            if not os.path.exists(p):
                os.mkdir(p)

        self.vocab = Vocab(config.vocab_path, config.vocab_size)
        self.batcher = Batcher(config.decode_data_path, self.vocab, mode='decode',
                               batch_size=config.beam_size, single_pass=True)
        time.sleep(15)

        self.model = Model(model_file_path, is_eval=True)

    def sort_beams(self, beams):
        return sorted(beams, key=lambda h: h.avg_log_prob, reverse=True)


    def decode(self):
        start = time.time()
        counter = 0
        batch = self.batcher.next_batch()
        while batch is not None:
            # Run beam search to get best Hypothesis
            with torch.no_grad():
                best_summary = self.beam_search(batch)

            # Extract the output ids from the hypothesis and convert back to words
            output_ids = [int(t) for t in best_summary.tokens[1:]]
            decoded_words = data.outputids2words(output_ids, self.vocab,
                                                 (batch.art_oovs[0] if config.pointer_gen else None))

            # Remove the [STOP] token from decoded_words, if necessary
            try:
                fst_stop_idx = decoded_words.index(data.STOP_DECODING)
                decoded_words = decoded_words[:fst_stop_idx]
            except ValueError:
                pass  # no [STOP] token was generated

            print("===============SUMMARY=============")
            print(' '.join(decoded_words))

            original_abstract_sents = batch.original_abstracts_sents[0]

            write_for_rouge(original_abstract_sents, decoded_words, counter,
                            self._rouge_ref_dir, self._rouge_dec_dir)
            counter += 1
            if counter % 1000 == 0:
                print('%d examples in %d sec' % (counter, time.time() - start))
                start = time.time()

            batch = self.batcher.next_batch()

        print("Decoder has finished reading dataset for single_pass.")
        print("Now starting ROUGE eval...")
        results_dict = rouge_eval(self._rouge_ref_dir, self._rouge_dec_dir)
        rouge_log(results_dict, self._decode_dir)


    def beam_search(self, batch):
        #batch should have only one example
        enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, c_t_0, coverage_t_0, sent_lens = \
            get_input_from_batch(batch, use_cuda)

        encoder_outputs, encoder_hidden, max_encoder_output = self.model.encoder(enc_batch, enc_lens)
        s_t_0 = self.model.reduce_state(encoder_hidden)
        if config.use_maxpool_init_ctx:
            c_t_0 = max_encoder_output

        gamma = None
        if config.is_sentence_filtering:
            gamma, sent_dists = self.model.sentence_filterer(encoder_outputs, sent_lens)

        section_outputs, section_hidden = self.model.section_encoder(s_t_0)
        s_t_0 = self.model.section_reduce_state(section_hidden)

        dec_h, dec_c = s_t_0 # 1 x 2*hidden_size
        dec_h = dec_h.squeeze()
        dec_c = dec_c.squeeze()

        # decoder batch preparation: start with beam_size identical hypotheses
        beams = [Beam(tokens=[self.vocab.word2id(data.START_DECODING)],
                      log_probs=[0.0],
                      state=(dec_h[0], dec_c[0]),
                      context=c_t_0[0],
                      coverage=(coverage_t_0[0] if config.is_coverage else None))
                 for _ in range(config.beam_size)]
        results = []
        steps = 0
        while steps < config.max_dec_steps and len(results) < config.beam_size:
            latest_tokens = [h.latest_token for h in beams]
            latest_tokens = [t if t < self.vocab.size() else self.vocab.word2id(data.UNKNOWN_TOKEN) \
                             for t in latest_tokens]
            y_t_1 = Variable(torch.LongTensor(latest_tokens))
            if use_cuda:
                y_t_1 = y_t_1.cuda()
            all_state_h = []
            all_state_c = []

            all_context = []

            for h in beams:
                state_h, state_c = h.state
                all_state_h.append(state_h)
                all_state_c.append(state_c)

                all_context.append(h.context)

            s_t_1 = (torch.stack(all_state_h, 0).unsqueeze(0), torch.stack(all_state_c, 0).unsqueeze(0))
            c_t_1 = torch.stack(all_context, 0)

            coverage_t_1 = None
            if config.is_coverage:
                all_coverage = []
                for h in beams:
                    all_coverage.append(h.coverage)
                coverage_t_1 = torch.stack(all_coverage, 0)

            final_dist, s_t, c_t, attn_dist, p_gen, coverage_t = self.model.decoder(y_t_1, s_t_1,
                                                        encoder_outputs, section_outputs, enc_padding_mask,
                                                        c_t_1, extra_zeros, enc_batch_extend_vocab, coverage_t_1, gamma)

            # unlike the other variants here, no torch.log is applied first;
            # final_dist is presumably already in log space for this model
            topk_log_probs, topk_ids = torch.topk(final_dist, config.beam_size * 2)

            dec_h, dec_c = s_t
            dec_h = dec_h.squeeze()
            dec_c = dec_c.squeeze()

            all_beams = []
            num_orig_beams = 1 if steps == 0 else len(beams)
            for i in range(num_orig_beams):
                h = beams[i]
                state_i = (dec_h[i], dec_c[i])
                context_i = c_t[i]
                coverage_i = (coverage_t[i] if config.is_coverage else None)

                for j in range(config.beam_size * 2):  # for each of the top 2*beam_size hyps
                    new_beam = h.extend(token=topk_ids[i, j].item(),
                                        log_prob=topk_log_probs[i, j].item(),
                                        state=state_i,
                                        context=context_i,
                                        coverage=coverage_i)
                    all_beams.append(new_beam)

            beams = []
            for h in self.sort_beams(all_beams):
                if h.latest_token == self.vocab.word2id(data.STOP_DECODING):
                    if steps >= config.min_dec_steps:
                        results.append(h)
                else:
                    beams.append(h)
                if len(beams) == config.beam_size or len(results) == config.beam_size:
                    break

            steps += 1

        if len(results) == 0:
            results = beams

        beams_sorted = self.sort_beams(results)

        return beams_sorted[0]
Example #6
class BeamSearch(object):
    def __init__(self, model_file_path):
        model_name = os.path.basename(model_file_path)
        self._decode_dir = os.path.join(config.decode_dir, model_name)
        self._decode_dir = os.path.splitext(self._decode_dir)[0]
        self._rouge_ref_dir = os.path.join(self._decode_dir, 'rouge_ref')
        self._rouge_dec_dir = os.path.join(self._decode_dir, 'rouge_dec_dir')

        for p in [self._decode_dir, self._rouge_ref_dir, self._rouge_dec_dir]:
            Path(p).mkdir(parents=True, exist_ok=True)

        self.vocab = Vocab(config.vocab_path, config.vocab_size)
        self.pad_id = self.vocab.word2id(PAD_TOKEN)
        self.start_id = self.vocab.word2id(START_DECODING)
        self.stop_id = self.vocab.word2id(STOP_DECODING)

        self.model = Model(model_file_path, is_eval=True)

    def sort_beams(self, beams):
        return sorted(beams, key=lambda h: h.avg_log_prob, reverse=True)

    def if_already_exists(self, idx):
        decoded_file = os.path.join(self._rouge_dec_dir,
                                    "file.{}.txt".format(idx))
        return os.path.isfile(decoded_file)

    def decode(self, file_id_start, file_id_stop, ami_id='191209'):
        print("AMI transcription:", ami_id)

        test_data = load_ami_data(ami_id, 'test')

        # shuffle so several CPU workers can split the range and redo files that failed
        idx_list = list(range(file_id_start, file_id_stop))
        random.shuffle(idx_list)
        for idx in idx_list:

            # skip files that have already been decoded
            if self.if_already_exists(idx):
                print("ID {} already exists".format(idx))
                continue

            # Run beam search to get best Hypothesis
            best_summary, art_oovs = self.beam_search(test_data, idx)

            # Extract the output ids from the hypothesis and convert back to words
            output_ids = [int(t) for t in best_summary.tokens[1:]]
            decoded_words = data.outputids2words(
                output_ids, self.vocab,
                (art_oovs[0] if config.pointer_gen else None))

            # Remove the [STOP] token from decoded_words, if necessary
            try:
                fst_stop_idx = decoded_words.index(data.STOP_DECODING)
                decoded_words = decoded_words[:fst_stop_idx]
            except ValueError:
                pass  # no [STOP] token was generated

            # original_abstract_sents = batch.original_abstracts_sents[0]
            original_abstract_sents = []

            write_for_rouge(original_abstract_sents, decoded_words, idx,
                            self._rouge_ref_dir, self._rouge_dec_dir)

            print("decoded idx = {}".format(idx))
        print("Finished decoding idx [{},{})".format(file_id_start,
                                                     file_id_stop))

        # print("Starting ROUGE eval...")
        # results_dict = rouge_eval(self._rouge_ref_dir, self._rouge_dec_dir)
        # rouge_log(results_dict, self._decode_dir)

    def beam_search(self, test_data, idx):
        # batch should have only one example
        # enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, c_t_0, coverage_t_0 = \
        #       get_input_from_batch(batch, use_cuda)

        enc_pack, art_oovs = get_a_batch_decode(test_data,
                                                idx,
                                                self.vocab,
                                                config.beam_size,
                                                config.max_enc_steps,
                                                config.max_dec_steps,
                                                self.start_id,
                                                self.stop_id,
                                                self.pad_id,
                                                sum_type='short',
                                                use_cuda=use_cuda)
        enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, c_t_0, coverage_t_0 = enc_pack
        encoder_outputs, encoder_feature, encoder_hidden = self.model.encoder(
            enc_batch, enc_lens)
        s_t_0 = self.model.reduce_state.forward1(encoder_hidden)

        dec_h, dec_c = s_t_0  # 1 x 2*hidden_size
        dec_h = dec_h.squeeze()
        dec_c = dec_c.squeeze()

        # decoder batch preparation: start with beam_size identical hypotheses
        beams = [
            Beam(tokens=[self.vocab.word2id(data.START_DECODING)],
                 log_probs=[0.0],
                 state=(dec_h[0], dec_c[0]),
                 context=c_t_0[0],
                 coverage=(coverage_t_0[0] if config.is_coverage else None))
            for _ in range(config.beam_size)
        ]

        results = []
        steps = 0
        while steps < config.max_dec_steps and len(results) < config.beam_size:
            latest_tokens = [h.latest_token for h in beams]
            latest_tokens = [t if t < self.vocab.size() else self.vocab.word2id(data.UNKNOWN_TOKEN) \
                             for t in latest_tokens]
            y_t_1 = Variable(torch.LongTensor(latest_tokens))
            if use_cuda:
                y_t_1 = y_t_1.cuda()
            all_state_h = []
            all_state_c = []

            all_context = []

            for h in beams:
                state_h, state_c = h.state
                all_state_h.append(state_h)
                all_state_c.append(state_c)

                all_context.append(h.context)

            s_t_1 = (torch.stack(all_state_h, 0).unsqueeze(0),
                     torch.stack(all_state_c, 0).unsqueeze(0))
            c_t_1 = torch.stack(all_context, 0)

            coverage_t_1 = None
            if config.is_coverage:
                all_coverage = []
                for h in beams:
                    all_coverage.append(h.coverage)
                coverage_t_1 = torch.stack(all_coverage, 0)

            final_dist, s_t, c_t, attn_dist, p_gen, coverage_t = self.model.decoder.forward1(
                y_t_1, s_t_1, encoder_outputs, encoder_feature,
                enc_padding_mask, c_t_1, extra_zeros, enc_batch_extend_vocab,
                coverage_t_1, steps)

            log_probs = torch.log(final_dist)
            topk_log_probs, topk_ids = torch.topk(log_probs,
                                                  config.beam_size * 2)

            dec_h, dec_c = s_t
            dec_h = dec_h.squeeze()
            dec_c = dec_c.squeeze()

            all_beams = []
            num_orig_beams = 1 if steps == 0 else len(beams)
            for i in range(num_orig_beams):
                h = beams[i]
                state_i = (dec_h[i], dec_c[i])
                context_i = c_t[i]
                coverage_i = (coverage_t[i] if config.is_coverage else None)

                for j in range(config.beam_size * 2):  # for each of the top 2*beam_size hyps
                    new_beam = h.extend(token=topk_ids[i, j].item(),
                                        log_prob=topk_log_probs[i, j].item(),
                                        state=state_i,
                                        context=context_i,
                                        coverage=coverage_i)
                    all_beams.append(new_beam)

            beams = []
            for h in self.sort_beams(all_beams):
                if h.latest_token == self.vocab.word2id(data.STOP_DECODING):
                    if steps >= config.min_dec_steps:
                        results.append(h)
                else:
                    beams.append(h)
                if len(beams) == config.beam_size or len(results) == config.beam_size:
                    break

            steps += 1

        if len(results) == 0:
            results = beams

        beams_sorted = self.sort_beams(results)

        return beams_sorted[0], art_oovs
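A hypothetical driver for this resumable decoder; the checkpoint path and ID range below are assumptions:

searcher = BeamSearch('log/train_123/model/model_50000')  # hypothetical checkpoint path
searcher.decode(file_id_start=0, file_id_stop=20, ami_id='191209')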
Example #7
class BeamSearch(object):
    def __init__(self, model_file_path):
        model_name = os.path.basename(model_file_path)
        self._decode_dir = os.path.join(config.log_root,
                                        'decode_%s' % (model_name))
        self._rouge_ref_dir = os.path.join(self._decode_dir, 'rouge_ref')
        self._rouge_dec_dir = os.path.join(self._decode_dir, 'rouge_dec_dir')
        for p in [self._decode_dir, self._rouge_ref_dir, self._rouge_dec_dir]:
            if not os.path.exists(p):
                os.mkdir(p)

        self.vocab = Vocab(config.vocab_path, config.vocab_size)
        self.batcher = Batcher(config.decode_data_path,
                               self.vocab,
                               mode='decode',
                               batch_size=config.beam_size,
                               single_pass=True)
        time.sleep(15)

        self.model = Model(model_file_path, is_eval=True)

    def sort_beams(self, beams):
        return sorted(beams, key=lambda h: h.avg_log_prob, reverse=True)

    def decode(self):
        start = time.time()
        counter = 0
        batch = self.batcher.next_batch()

        # In the new architecture, this is written in the decode part of training
        while batch is not None:

            # Run beam search to get best Hypothesis
            best_summary = self.beam_search(batch)

            # Extract the output ids from the hypothesis and convert back to words
            output_ids = [int(t) for t in best_summary.tokens[1:]]
            decoded_words = data.outputids2words(
                output_ids, self.vocab,
                (batch.art_oovs[0] if config.pointer_gen else None))

            # Remove the end-of-sequence token from decoded_words, if necessary
            try:
                fst_stop_idx = decoded_words.index(data.MARK_EOS)
                decoded_words = decoded_words[:fst_stop_idx]
            except ValueError:
                pass  # no end-of-sequence token was generated

            original_abstract_sents = batch.original_abstracts_sents[0]
            original_article = batch.original_articles[0]

            # English:
            # write_for_rouge(original_abstract_sents, decoded_words, counter,
            #                 self._rouge_ref_dir, self._rouge_dec_dir)
            # Chinese:
            self.write_result(original_article, original_abstract_sents,
                              decoded_words, counter)
            counter += 1
            #     if counter % 1000 == 0:
            #         print('%d example in %d sec'%(counter, time.time() - start))
            #         start = time.time()

            batch = self.batcher.next_batch()

        # print("Decoder has finished reading dataset for single_pass.")
        # print("Now starting ROUGE eval...")
        # results_dict = rouge_eval(self._rouge_ref_dir, self._rouge_dec_dir)
        # rouge_log(results_dict, self._decode_dir)

    def write_result(self, original_title, reference_summarization,
                     decoded_words, ex_index):
        """
        Write output to file.

        Args:
            reference_sents: list of strings
            decoded_words: list of strings
            ex_index: int, the index with which to label the files
        """
        summarization = ''.join(decoded_words)

        # Write to file
        result_file = os.path.join(self._decode_dir, "result.txt")

        with open(result_file, 'a') as f:  # append; 'w' would overwrite earlier examples
            f.write(original_title + '\t\t' + reference_summarization +
                    '\t\t' + summarization + "\n")

        print("Wrote example %i to file" % ex_index)

    def beam_search(self, batch):
        #batch should have only one example
        enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, c_t_0, coverage_t_0 = \
            get_input_from_batch(batch, use_cuda)

        encoder_outputs, encoder_feature, encoder_hidden = self.model.encoder(
            enc_batch, enc_lens)
        s_t_0 = self.model.reduce_state(encoder_hidden)

        dec_h, dec_c = s_t_0  # 1 x 2*hidden_size
        dec_h = dec_h.squeeze()
        dec_c = dec_c.squeeze()

        # decoder batch preparation: start with beam_size identical hypotheses
        beams = [
            Beam(tokens=[self.vocab.word2id(data.MARK_GO)],
                 log_probs=[0.0],
                 state=(dec_h[0], dec_c[0]),
                 context=c_t_0[0],
                 coverage=(coverage_t_0[0] if config.is_coverage else None))
            for _ in range(config.beam_size)
        ]
        results = []
        steps = 0
        while steps < config.max_dec_steps and len(results) < config.beam_size:
            latest_tokens = [h.latest_token for h in beams]
            latest_tokens = [t if t < self.vocab.size() else self.vocab.word2id(data.MARK_UNK) \
                             for t in latest_tokens]
            y_t_1 = Variable(torch.LongTensor(latest_tokens))  # vector of latest token ids
            if use_cuda:
                y_t_1 = y_t_1.cuda()
            all_state_h = []
            all_state_c = []

            all_context = []

            for h in beams:
                state_h, state_c = h.state
                all_state_h.append(state_h)
                all_state_c.append(state_c)

                all_context.append(h.context)

            s_t_1 = (torch.stack(all_state_h, 0).unsqueeze(0),
                     torch.stack(all_state_c, 0).unsqueeze(0))
            c_t_1 = torch.stack(all_context, 0)

            coverage_t_1 = None
            if config.is_coverage:
                all_coverage = []
                for h in beams:
                    all_coverage.append(h.coverage)
                coverage_t_1 = torch.stack(all_coverage, 0)

            final_dist, s_t, c_t, attn_dist, p_gen, coverage_t = self.model.decoder(
                y_t_1, s_t_1, encoder_outputs, encoder_feature,
                enc_padding_mask, c_t_1, extra_zeros, enc_batch_extend_vocab,
                coverage_t_1, steps)
            log_probs = torch.log(final_dist)
            topk_log_probs, topk_ids = torch.topk(log_probs,
                                                  config.beam_size * 2)

            dec_h, dec_c = s_t
            dec_h = dec_h.squeeze()
            dec_c = dec_c.squeeze()

            all_beams = []
            num_orig_beams = 1 if steps == 0 else len(beams)
            for i in range(num_orig_beams):  # for each of the different sentences (beams)
                h = beams[i]
                state_i = (dec_h[i], dec_c[i])
                context_i = c_t[i]
                coverage_i = (coverage_t[i] if config.is_coverage else None)

                for j in range(config.beam_size * 2):  # for each of the top 2*beam_size hyps
                    new_beam = h.extend(token=topk_ids[i, j].item(),
                                        log_prob=topk_log_probs[i, j].item(),
                                        state=state_i,
                                        context=context_i,
                                        coverage=coverage_i)
                    all_beams.append(new_beam)

            beams = []
            for h in self.sort_beams(all_beams):
                if h.latest_token == self.vocab.word2id(data.MARK_EOS):
                    if steps >= config.min_dec_steps:
                        results.append(h)
                else:
                    beams.append(h)
                if len(beams) == config.beam_size or len(results) == config.beam_size:
                    break

            steps += 1

        if len(results) == 0:
            results = beams

        beams_sorted = self.sort_beams(results)

        return beams_sorted[0]
Example #8
class BeamSearch(object):
    def __init__(self, model_file_path):
        model_name = os.path.basename(model_file_path)
        self._decode_dir = os.path.join(config.decode_dir, model_name)
        self._decode_dir = os.path.splitext(self._decode_dir)[0]
        self._rouge_ref_dir = os.path.join(self._decode_dir, 'rouge_ref')
        self._rouge_dec_dir = os.path.join(self._decode_dir, 'rouge_dec_dir')

        for p in [self._decode_dir, self._rouge_ref_dir, self._rouge_dec_dir]:
            Path(p).mkdir(parents=True, exist_ok=True)

        self.vocab = Vocab(config.vocab_path, config.vocab_size)
        # self.batcher = Batcher(config.decode_data_path, self.vocab, mode='decode',
        #                        batch_size=config.beam_size, single_pass=True)
        # time.sleep(15)
        self.get_batches(config.decode_pk_path)

        self.model = Model(model_file_path, is_eval=True)

    def sort_beams(self, beams):
        return sorted(beams, key=lambda h: h.avg_log_prob, reverse=True)

    def get_batches(self, path):
        """
        load batches dumped by pickle
        see batch_saver.py for more information
        """
        with open(path, 'rb') as f:
            batches = pickle.load(f, encoding="bytes")
        self.batches = batches

        print("loaded: {}".format(path))

    def if_already_exists(self, idx):
        ref_file = os.path.join(self._rouge_ref_dir,
                                "{}_reference.txt".format(idx))
        decoded_file = os.path.join(self._rouge_dec_dir,
                                    "{}_decoded.txt".format(idx))
        return os.path.isfile(ref_file) and os.path.isfile(decoded_file)

    def decode(self, file_id_start, file_id_stop):
        if file_id_stop > MAX_TEST_ID:
            file_id_stop = MAX_TEST_ID

        # shuffle so several CPU workers can split the range and redo files that failed
        idx_list = list(range(file_id_start, file_id_stop))
        random.shuffle(idx_list)

        for idx in idx_list:

            # check if this is written already
            if self.if_already_exists(idx):
                # print("ID {} already exists".format(idx))
                continue

            # batch = self.batcher.next_batch()
            batch = self.batches[idx]

            # Run beam search to get best Hypothesis
            best_summary = self.beam_search(batch)

            # Extract the output ids from the hypothesis and convert back to words
            output_ids = [int(t) for t in best_summary.tokens[1:]]
            decoded_words = data.outputids2words(
                output_ids, self.vocab,
                (batch.art_oovs[0] if config.pointer_gen else None))

            # Remove the [STOP] token from decoded_words, if necessary
            try:
                fst_stop_idx = decoded_words.index(data.STOP_DECODING)
                decoded_words = decoded_words[:fst_stop_idx]
            except ValueError:
                pass  # no [STOP] token was generated

            original_abstract_sents = batch.original_abstracts_sents[0]

            write_for_rouge(original_abstract_sents, decoded_words, idx,
                            self._rouge_ref_dir, self._rouge_dec_dir)

            print("decoded idx = {}".format(idx))
        print("Finished decoding idx [{},{})".format(file_id_start,
                                                     file_id_stop))

        # print("Starting ROUGE eval...")
        # results_dict = rouge_eval(self._rouge_ref_dir, self._rouge_dec_dir)
        # rouge_log(results_dict, self._decode_dir)

    def beam_search(self, batch):
        # batch should have only one example
        enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, c_t_0, coverage_t_0 = \
            get_input_from_batch(batch, use_cuda)

        if not config.is_hierarchical:

            encoder_outputs, encoder_feature, encoder_hidden = self.model.encoder(
                enc_batch, enc_lens)
            s_t_0 = self.model.reduce_state(encoder_hidden)

        else:
            stop_id = self.vocab.word2id('.')
            enc_sent_pos = get_sent_position(enc_batch, stop_id)

            encoder_outputs, encoder_feature, encoder_hidden, sent_enc_outputs, sent_enc_feature, sent_enc_hidden, sent_enc_padding_mask = \
                                                                    self.model.encoder(enc_batch, enc_lens, enc_sent_pos)
            s_t_0, _ = self.model.reduce_state(encoder_hidden, sent_enc_hidden)

        dec_h, dec_c = s_t_0  # 1 x 2*hidden_size
        dec_h = dec_h.squeeze()
        dec_c = dec_c.squeeze()

        # decoder batch preparation: start with beam_size identical hypotheses
        beams = [
            Beam(tokens=[self.vocab.word2id(data.START_DECODING)],
                 log_probs=[0.0],
                 state=(dec_h[0], dec_c[0]),
                 context=c_t_0[0],
                 coverage=(coverage_t_0[0] if config.is_coverage else None))
            for _ in range(config.beam_size)
        ]

        results = []
        steps = 0
        while steps < config.max_dec_steps and len(results) < config.beam_size:
            latest_tokens = [h.latest_token for h in beams]
            latest_tokens = [t if t < self.vocab.size() else self.vocab.word2id(data.UNKNOWN_TOKEN) \
                             for t in latest_tokens]
            y_t_1 = Variable(torch.LongTensor(latest_tokens))
            if use_cuda:
                y_t_1 = y_t_1.cuda()
            all_state_h = []
            all_state_c = []

            all_context = []

            for h in beams:
                state_h, state_c = h.state
                all_state_h.append(state_h)
                all_state_c.append(state_c)

                all_context.append(h.context)

            s_t_1 = (torch.stack(all_state_h, 0).unsqueeze(0),
                     torch.stack(all_state_c, 0).unsqueeze(0))
            c_t_1 = torch.stack(all_context, 0)

            coverage_t_1 = None
            if config.is_coverage:
                all_coverage = []
                for h in beams:
                    all_coverage.append(h.coverage)
                coverage_t_1 = torch.stack(all_coverage, 0)

            if not config.is_hierarchical:
                final_dist, s_t, c_t, attn_dist, p_gen, coverage_t = self.model.decoder(
                    y_t_1, s_t_1, encoder_outputs, encoder_feature,
                    enc_padding_mask, c_t_1, extra_zeros,
                    enc_batch_extend_vocab, coverage_t_1, steps)

            else:
                final_dist, s_t, c_t, attn_dist, p_gen, coverage_t = self.model.decoder(
                    y_t_1, s_t_1, enc_sent_pos, encoder_outputs,
                    encoder_feature, enc_padding_mask, None, sent_enc_outputs,
                    sent_enc_feature, sent_enc_padding_mask, c_t_1,
                    extra_zeros, enc_batch_extend_vocab, coverage_t_1, steps)

            log_probs = torch.log(final_dist)
            topk_log_probs, topk_ids = torch.topk(log_probs,
                                                  config.beam_size * 2)

            dec_h, dec_c = s_t
            dec_h = dec_h.squeeze()
            dec_c = dec_c.squeeze()

            all_beams = []
            num_orig_beams = 1 if steps == 0 else len(beams)
            for i in range(num_orig_beams):
                h = beams[i]
                state_i = (dec_h[i], dec_c[i])
                context_i = c_t[i]
                coverage_i = (coverage_t[i] if config.is_coverage else None)

                for j in range(config.beam_size * 2):  # for each of the top 2*beam_size hyps
                    new_beam = h.extend(token=topk_ids[i, j].item(),
                                        log_prob=topk_log_probs[i, j].item(),
                                        state=state_i,
                                        context=context_i,
                                        coverage=coverage_i)
                    all_beams.append(new_beam)

            beams = []
            for h in self.sort_beams(all_beams):
                if h.latest_token == self.vocab.word2id(data.STOP_DECODING):
                    if steps >= config.min_dec_steps:
                        results.append(h)
                else:
                    beams.append(h)
                if len(beams) == config.beam_size or len(results) == config.beam_size:
                    break

            steps += 1

        if len(results) == 0:
            results = beams

        beams_sorted = self.sort_beams(results)

        return beams_sorted[0]
Example #9
class Train(object):
    def __init__(self):
        self.vocab = Vocab(config.vocab_path, config.vocab_size)
        self.concept_vocab = Concept_vocab(config.concept_vocab_path,
                                           config.vocab_size)
        self.batcher = Batcher(config.train_data_path,
                               self.vocab,
                               self.concept_vocab,
                               mode='train',
                               batch_size=config.batch_size,
                               single_pass=False)
        self.ds_batcher = Batcher(config.train_ds_data_path,
                                  self.vocab,
                                  self.concept_vocab,
                                  mode='train',
                                  batch_size=500,
                                  single_pass=False)
        time.sleep(15)

        train_dir = os.path.join(config.log_root,
                                 'train_%d' % (int(time.time())))
        if not os.path.exists(train_dir):
            os.mkdir(train_dir)

        self.model_dir = os.path.join(train_dir, 'model')
        if not os.path.exists(self.model_dir):
            os.mkdir(self.model_dir)

        self.summary_writer = tf.summary.FileWriter(train_dir)

    def save_model(self, running_avg_loss, iter):
        state = {
            'iter': iter,
            'encoder_state_dict': self.model.encoder.state_dict(),
            'decoder_state_dict': self.model.decoder.state_dict(),
            'reduce_state_dict': self.model.reduce_state.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'current_loss': running_avg_loss
        }
        model_save_path = os.path.join(
            self.model_dir, 'model_%d_%d' % (iter, int(time.time())))
        torch.save(state, model_save_path)

    def setup_train(self, model_file_path=None):
        self.model = Model(model_file_path)

        params = list(self.model.encoder.parameters()) + list(self.model.decoder.parameters()) + \
                 list(self.model.reduce_state.parameters())
        initial_lr = config.lr_coverage if config.is_coverage else config.lr
        self.optimizer = AdagradCustom(
            params,
            lr=initial_lr,
            initial_accumulator_value=config.adagrad_init_acc)

        start_iter, start_loss = 0, 0

        if model_file_path is not None:
            state = torch.load(model_file_path,
                               map_location=lambda storage, location: storage)
            start_iter = state['iter']
            start_loss = state['current_loss']

            if not config.is_coverage:
                self.optimizer.load_state_dict(state['optimizer'])
                if use_cuda:
                    # move optimizer state tensors to GPU; use a fresh name so the
                    # loaded checkpoint dict named state above is not shadowed
                    for opt_state in self.optimizer.state.values():
                        for k, v in opt_state.items():
                            if torch.is_tensor(v):
                                opt_state[k] = v.cuda()

        return start_iter, start_loss

    def calc_Rouge_1(self, sub, string):
        new_sub = [str(x) for x in sub]
        new_sub.insert(0, '"')
        new_sub.append('"')
        token_c = ' '.join(new_sub)
        summary = [[token_c]]
        new_string = [str(x) for x in string]
        new_string.insert(0, '"')
        new_string.append('"')
        token_r = ' '.join(new_string)
        reference = [[[token_r]]]

        rouge = Pythonrouge(summary_file_exist=False,
                            summary=summary,
                            reference=reference,
                            n_gram=2,
                            ROUGE_SU4=False,
                            ROUGE_L=False,
                            recall_only=True,
                            stemming=True,
                            stopwords=True,
                            word_level=True,
                            length_limit=True,
                            length=30,
                            use_cf=False,
                            cf=95,
                            scoring_formula='average',
                            resampling=False,
                            samples=10,
                            favor=True,
                            p=0.5)
        score = rouge.calc_score()
        return score['ROUGE-1']
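    # e.g. calc_Rouge_1([1, 2, 3], [1, 2, 4]) joins each id sequence into a quoted
    # string and returns the ROUGE-1 recall reported by Pythonrouge.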

    def calc_Rouge_2_recall(self, sub, string):
        token_c = sub
        token_r = string
        model = []
        ref = []
        if len(string) < 2 or len(sub) < 2:
            score = 0.0
        else:
            i = 1
            while i < len(string):
                ref.append(str(token_r[i - 1]) + str(token_r[i]))
                i += 1
            i = 1
            while i < len(sub):
                model.append(str(token_c[i - 1]) + str(token_c[i]))
                i += 1
            sam = 0
            for i in range(len(ref)):
                for j in range(len(model)):
                    if ref[i] == model[j]:
                        sam += 1
                        model[j] = '-1'
                        break

            score = sam / float(len(ref))

        return score
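    # Worked example for calc_Rouge_2_recall: sub = [1, 2, 3], string = [1, 2, 4]
    # gives reference bigrams ['12', '24'] and candidate bigrams ['12', '23'];
    # one bigram matches, so the recall is 1 / 2 = 0.5.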

    def calc_Rouge_L(self, sub, string):
        beta = 1.
        token_c = sub
        token_r = string
        if (len(string) < len(sub)):
            sub, string = string, sub
        lengths = [[0 for i in range(0,
                                     len(sub) + 1)]
                   for j in range(0,
                                  len(string) + 1)]
        for j in range(1, len(sub) + 1):
            for i in range(1, len(string) + 1):
                if (string[i - 1] == sub[j - 1]):
                    lengths[i][j] = lengths[i - 1][j - 1] + 1
                else:
                    lengths[i][j] = max(lengths[i - 1][j], lengths[i][j - 1])
        lcs = lengths[len(string)][len(sub)]

        prec = lcs / float(len(token_c))
        rec = lcs / float(len(token_r))

        if (prec != 0 and rec != 0):
            score = ((1 + beta**2) * prec * rec) / float(rec + beta**2 * prec)
        else:
            score = 0.0
        return score
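    # Worked example for calc_Rouge_L: sub = [1, 2, 3], string = [2, 3, 4] gives
    # LCS = [2, 3], so prec = rec = 2/3 and the beta = 1 F-score is 2/3.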

    def calc_kl(self, dec, enc):
        kl = 0.
        dec = np.exp(dec)
        enc = np.exp(enc)
        all_dec = np.sum(dec)
        all_enc = np.sum(enc)
        for d, c in zip(dec, enc):
            d = d / all_dec
            c = c / all_enc
            kl = kl + c * np.log(c / d)
        return kl
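    # calc_kl softmax-normalizes both input vectors and returns
    # KL(enc || dec) = sum_i p_enc(i) * log(p_enc(i) / p_dec(i)).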

    def calc_euc(self, dec, enc):
        # Euclidean (L2) distance between the two embedding vectors
        diff = np.asarray(dec) - np.asarray(enc)
        return np.sqrt(np.sum(np.square(diff)))

    def ds_loss(self, enc_batch_ds_emb, enc_padding_mask_ds, dec_batch_emb,
                dec_padding_mask):
        b1, t_k1, emb1 = list(enc_batch_ds_emb.size())
        b2, t_k2, emb2 = list(dec_batch_emb.size())
        enc_padding_mask_ds = enc_padding_mask_ds.unsqueeze(2).expand(
            b1, t_k1, emb1).contiguous()
        dec_padding_mask = dec_padding_mask.unsqueeze(2).expand(
            b2, t_k2, emb2).contiguous()
        enc_batch_ds_emb = enc_batch_ds_emb * enc_padding_mask_ds
        dec_batch_emb = dec_batch_emb * dec_padding_mask
        enc_batch_ds_emb = torch.sum(enc_batch_ds_emb, 1)
        dec_batch_emb = torch.sum(dec_batch_emb, 1)
        dec_title = dec_batch_emb.tolist()
        enc_article = enc_batch_ds_emb.tolist()
        dec_title_len = len(dec_title)
        enc_article_len = len(enc_article)
        dsloss = 0.
        for dec in dec_title:
            for enc in enc_article:
                dsloss = dsloss + self.calc_kl(dec, enc)
        dsloss = dsloss / float(dec_title_len * enc_article_len)
        print(dsloss)
        return dsloss
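    # ds_loss averages the pairwise KL divergence between every decoder (title)
    # embedding and every encoder (article) embedding, where each embedding is a
    # mask-weighted sum over time steps.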

    def train_one_batch(self, batch, steps, batch_ds):
        enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, enc_batch_concept_extend_vocab, concept_p, position, concept_mask, extra_zeros, c_t_1, coverage = \
            get_input_from_batch(batch, use_cuda)
        dec_batch, dec_padding_mask, max_dec_len, dec_lens_var, target_batch = \
            get_output_from_batch(batch, use_cuda)
        enc_batch_ds, enc_padding_mask_ds, enc_lens_ds, _, _, _, _, _, _, _, _ = \
            get_input_from_batch(batch_ds, use_cuda)

        self.optimizer.zero_grad()
        encoder_outputs, encoder_hidden, max_encoder_output, enc_batch_ds_emb, dec_batch_emb = self.model.encoder(
            enc_batch, enc_lens, enc_batch_ds, dec_batch)
        if config.DS_train:
            ds_final_loss = self.ds_loss(enc_batch_ds_emb, enc_padding_mask_ds,
                                         dec_batch_emb, dec_padding_mask)
        s_t_1 = self.model.reduce_state(encoder_hidden)
        s_t_0 = s_t_1
        c_t_0 = c_t_1
        if config.use_maxpool_init_ctx:
            c_t_1 = max_encoder_output
            c_t_0 = c_t_1

        step_losses = []
        for di in range(min(max_dec_len, config.max_dec_steps)):
            y_t_1 = dec_batch[:, di]
            final_dist, s_t_1, c_t_1, attn_dist, p_gen, next_coverage = self.model.decoder(
                'train', y_t_1, s_t_1, encoder_outputs, enc_padding_mask,
                c_t_1, extra_zeros, enc_batch_extend_vocab,
                enc_batch_concept_extend_vocab, concept_p, position,
                concept_mask, coverage, di)
            target = target_batch[:, di]
            gold_probs = torch.gather(final_dist, 1,
                                      target.unsqueeze(1)).squeeze()
            step_loss = -torch.log(gold_probs + config.eps)
            if config.is_coverage:
                step_coverage_loss = torch.sum(torch.min(attn_dist, coverage), 1)
                step_loss = step_loss + config.cov_loss_wt * step_coverage_loss
                coverage = next_coverage

            step_mask = dec_padding_mask[:, di]
            step_loss = step_loss * step_mask
            step_losses.append(step_loss)

        sum_losses = torch.sum(torch.stack(step_losses, 1), 1)
        batch_avg_loss = sum_losses / dec_lens_var
        if config.DS_train:
            ds_final_loss = Variable(torch.FloatTensor([ds_final_loss]),
                                     requires_grad=False)
            ds_final_loss = ds_final_loss.cuda()
            loss = (config.pi - ds_final_loss) * torch.mean(batch_avg_loss)
        else:
            loss = torch.mean(batch_avg_loss)
        if steps > config.traintimes:
            scores = []
            sample_y = []
            s_t_1 = s_t_0
            c_t_1 = c_t_0
            for di in range(min(max_dec_len, config.max_dec_steps)):
                if di == 0:
                    y_t_1 = dec_batch[:, di]
                    sample_y.append(y_t_1.cpu().numpy().tolist())
                else:
                    sample_latest_tokens = sample_y[-1]
                    sample_latest_tokens = [t if t < self.vocab.size() else self.vocab.word2id(data.UNKNOWN_TOKEN) \
                                            for t in sample_latest_tokens]

                    y_t_1 = Variable(torch.LongTensor(sample_latest_tokens))
                    y_t_1 = y_t_1.cuda()

                final_dist, s_t_1, c_t_1, attn_dist, p_gen, next_coverage = self.model.decoder(
                    'train', y_t_1, s_t_1, encoder_outputs, enc_padding_mask,
                    c_t_1, extra_zeros, enc_batch_extend_vocab,
                    enc_batch_concept_extend_vocab, concept_p, position,
                    concept_mask, coverage, di)
                sample_select = torch.multinomial(final_dist, 1).view(-1)
                # gather returns probabilities, not log-probabilities
                sample_probs = torch.gather(
                    final_dist, 1, sample_select.unsqueeze(1)).squeeze()
                sample_y.append(sample_select.cpu().numpy().tolist())
                sample_step_loss = -torch.log(sample_probs + config.eps)
                sample_step_mask = dec_padding_mask[:, di]
                sample_step_loss = sample_step_loss * sample_step_mask
                scores.append(sample_step_loss)
            sample_sum_losses = torch.sum(torch.stack(scores, 1), 1)
            sample_batch_avg_loss = sample_sum_losses / dec_lens_var

            sample_y = np.transpose(sample_y).tolist()

            base_y = []
            s_t_1 = s_t_0
            c_t_1 = c_t_0
            for di in range(min(max_dec_len, config.max_dec_steps)):
                if di == 0:
                    y_t_1 = dec_batch[:, di]
                    base_y.append(y_t_1.cpu().numpy().tolist())
                else:
                    base_latest_tokens = base_y[-1]
                    base_latest_tokens = [t if t < self.vocab.size() else self.vocab.word2id(data.UNKNOWN_TOKEN) \
                                            for t in base_latest_tokens]

                    y_t_1 = Variable(torch.LongTensor(base_latest_tokens))
                    y_t_1 = y_t_1.cuda()

                final_dist, s_t_1, c_t_1, attn_dist, p_gen, next_coverage = self.model.decoder(
                    'train', y_t_1, s_t_1, encoder_outputs, enc_padding_mask,
                    c_t_1, extra_zeros, enc_batch_extend_vocab,
                    enc_batch_concept_extend_vocab, concept_p, position,
                    concept_mask, coverage, di)
                base_log_probs, base_ids = torch.topk(final_dist, 1)
                base_y.append(base_ids[:, 0].cpu().numpy().tolist())

            base_y = np.transpose(base_y).tolist()

            refs = dec_batch.cpu().numpy().tolist()
            # materialize the map: in Python 3 a bare map iterator would be
            # exhausted by the first zip below, leaving the second one empty
            sample_dec_lens_var = list(
                map(int, dec_lens_var.cpu().numpy().tolist()))
            sample_rougeL = [
                self.calc_Rouge_L(sample[:reflen],
                                  ref[:reflen]) for sample, ref, reflen in zip(
                                      sample_y, refs, sample_dec_lens_var)
            ]
            base_rougeL = [
                self.calc_Rouge_L(base[:reflen], ref[:reflen])
                for base, ref, reflen in zip(base_y, refs, sample_dec_lens_var)
            ]
            sample_rougeL = Variable(torch.FloatTensor(sample_rougeL),
                                     requires_grad=False)
            base_rougeL = Variable(torch.FloatTensor(base_rougeL),
                                   requires_grad=False)
            sample_rougeL = sample_rougeL.cuda()
            base_rougeL = base_rougeL.cuda()
            word_loss = -sample_batch_avg_loss * (base_rougeL - sample_rougeL)
            reinforce_loss = torch.mean(word_loss)
            loss = (1 - config.rein) * loss + config.rein * reinforce_loss

        loss.backward()

        clip_grad_norm(self.model.encoder.parameters(), config.max_grad_norm)
        clip_grad_norm(self.model.decoder.parameters(), config.max_grad_norm)
        clip_grad_norm(self.model.reduce_state.parameters(),
                       config.max_grad_norm)

        self.optimizer.step()

        return loss.item()

    def trainIters(self, n_iters, model_file_path=None):
        iter, running_avg_loss = self.setup_train(model_file_path)
        start = time.time()
        while iter < n_iters:
            batch = self.batcher.next_batch()
            batch_ds = self.ds_batcher.next_batch()
            loss = self.train_one_batch(batch, iter, batch_ds)
            # train_one_batch already returns a plain Python float,
            # so no .cpu() call is needed here

            running_avg_loss = calc_running_avg_loss(loss, running_avg_loss,
                                                     self.summary_writer, iter)
            iter += 1

            if iter % 100 == 0:
                self.summary_writer.flush()
            print_interval = 5
            if iter % print_interval == 0:
                print('steps %d, loss: %f' % (iter, loss))
                start = time.time()
            if iter % 50000 == 0:
                self.save_model(running_avg_loss, iter)
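
The trainer above mixes the maximum-likelihood loss with a self-critical (SCST-style) reinforcement term once steps exceeds config.traintimes: a greedy decode supplies the reward baseline for a sampled decode, with ROUGE-L as the reward. A minimal sketch of that reward-weighted term in isolation (the function and argument names here are illustrative, not taken from the snippet):

import torch

def self_critical_loss(sample_nll, sample_rouge, baseline_rouge):
    # sample_nll: (batch,) NLL of the sampled summaries
    # sample_rouge / baseline_rouge: (batch,) ROUGE-L of the sampled and
    # greedy-decoded summaries
    # -sample_nll is the sequence log-probability; weighting it by the
    # (baseline - sample) reward gap raises the likelihood of samples that
    # beat the greedy baseline and lowers it otherwise
    return torch.mean(-sample_nll * (baseline_rouge - sample_rouge))
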
class Train(object):
    def __init__(self):
        self.vocab = Vocab(args.vocab_path, args.vocab_size)
        sys.stdout.flush()
        self.batcher = Batcher(args.train_data_path,
                               self.vocab,
                               mode='train',
                               batch_size=args.batch_size,
                               single_pass=False)
        time.sleep(15)  # give the Batcher's background threads time to fill the queue
        vocab_size = self.vocab.size()
        self.model = BertLSTMModel(args.hidden_size, self.vocab.size(),
                                   args.max_dec_steps)
        # self.model = Seq2SeqLSTM(args.hidden_size, self.vocab.size(), args.max_dec_steps)
        if use_cuda:
            self.model = self.model.cuda()

        self.model_optimizer = torch.optim.Adam(self.model.parameters(),
                                                lr=args.lr)

        train_logs = os.path.join(args.logs, "train_logs")
        eval_logs = os.path.join(args.logs, "eval_logs")
        self.train_summary_writer = tf.summary.FileWriter(train_logs)
        self.eval_summary_writer = tf.summary.FileWriter(eval_logs)

    def trainOneBatch(self, batch):
        self.model_optimizer.zero_grad()
        loss = self.model(batch)
        loss.backward()

        clip_grad_norm_(self.model.parameters(), 5)
        # clip_grad_norm_(self.model.decoder.parameters(), 5)
        self.model_optimizer.step()
        return loss.item() / args.max_dec_steps

    def trainIters(self):
        running_avg_loss = 0
        s_time = time.time()
        for t in range(args.max_iteration):
            batch = self.batcher.next_batch()
            loss = self.trainOneBatch(batch)
            running_avg_loss = calc_running_avg_loss(loss,
                                                     running_avg_loss,
                                                     decay=0.999)
            save_running_avg_loss(running_avg_loss, t,
                                  self.train_summary_writer)

            # Print every 100 steps
            if (t + 1) % 100 == 0:
                time_run = time.time() - s_time
                s_time = time.time()
                print("timestep: {}, loss: {}, time: {}s".format(
                    t, running_avg_loss, time_run))
                sys.stdout.flush()

            # Save a checkpoint and evaluate every args.save_every_itr steps
            if (t + 1) % args.save_every_itr == 0:
                with torch.no_grad():
                    self.save_checkpoint(t, running_avg_loss)
                    self.model.eval()
                    self.evaluate(t)
                    self.model.train()

    def save_checkpoint(self, step, loss):
        checkpoint_file = "checkpoint_{}".format(step)
        checkpoint_path = os.path.join(args.logs, checkpoint_file)
        torch.save(
            {
                'timestep': step,
                # note: only the decoder is checkpointed here; a full-model
                # variant is commented out below
                'model_state_dict': self.model.decoder.state_dict(),
                'optimizer_state_dict': self.model_optimizer.state_dict(),
                'loss': loss
            }, checkpoint_path)
        # torch.save({
        #     'timestep': step,
        #     'encoder_state_dict': self.model.encoder.state_dict(),
        #     'decoder_state_dict': self.model.decoder.state_dict(),
        #     'optimizer_state_dict': self.model_optimizer.state_dict(),
        #     'loss': loss
        # }, checkpoint_path)

    def evaluate(self, timestep):
        self.eval_batcher = Batcher(args.eval_data_path,
                                    self.vocab,
                                    mode='train',
                                    batch_size=args.batch_size,
                                    single_pass=True)
        time.sleep(15)
        t1 = time.time()
        batch = self.eval_batcher.next_batch()
        running_avg_loss = 0
        while batch is not None:
            loss = self.model(batch).item() / args.max_dec_steps
            running_avg_loss = calc_running_avg_loss(loss, running_avg_loss)
            batch = self.eval_batcher.next_batch()

        # Save the evaluation score
        time_spent = time.time() - t1
        print("Evaluation Loss: {}, Time: {}s".format(running_avg_loss,
                                                      time_spent))
        save_running_avg_loss(running_avg_loss, timestep,
                              self.eval_summary_writer)
        sys.stdout.flush()
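
calc_running_avg_loss itself is not shown in this listing; a minimal sketch consistent with how it is called above, assuming it is the usual exponential moving average:

def calc_running_avg_loss(loss, running_avg_loss, decay=0.99):
    # exponential moving average of the loss; 0 means "not initialized yet"
    if running_avg_loss == 0:
        return loss
    return running_avg_loss * decay + (1 - decay) * loss
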
Example No. 11
class Train(object):
    def __init__(self):
        self.vocab = Vocab(config.vocab_path, config.vocab_size)
        self.batcher = Batcher(config.train_data_path,
                               self.vocab,
                               mode='train',
                               batch_size=config.batch_size,
                               single_pass=False)
        time.sleep(15)

        train_dir = os.path.join(config.ouput_root,
                                 'train_%d' % (int(time.time())))
        if not os.path.exists(train_dir):
            os.makedirs(train_dir)

        self.checkpoint_dir = os.path.join(train_dir, 'checkpoints')
        if not os.path.exists(self.checkpoint_dir):
            os.makedirs(self.checkpoint_dir)

        self.train_summary_writer = tf.summary.create_file_writer(
            os.path.join(train_dir, 'log', 'train'))
        self.eval_summary_writer = tf.summary.create_file_writer(
            os.path.join(train_dir, 'log', 'eval'))

    def save_model(self, model_path, running_avg_loss, iter):
        state = {
            'iter': iter,
            'encoder_state_dict': self.model.encoder.state_dict(),
            'decoder_state_dict': self.model.decoder.state_dict(),
            'reduce_state_dict': self.model.reduce_state.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'current_loss': running_avg_loss
        }
        torch.save(state, model_path)

    def setup_train(self, model_file_path=None):
        self.model = Model(device, model_file_path)

        params = list(self.model.encoder.parameters()) + list(self.model.decoder.parameters()) + \
                 list(self.model.reduce_state.parameters())
        initial_lr = config.lr_coverage if config.is_coverage else config.lr
        self.optimizer = Adagrad(
            params,
            lr=initial_lr,
            initial_accumulator_value=config.adagrad_init_acc)

        start_iter, start_loss = 0, 0

        if model_file_path is not None:
            state = torch.load(model_file_path,
                               map_location=lambda storage, location: storage)
            start_iter = state['iter']
            start_loss = state['current_loss']

            if not config.is_coverage:
                self.optimizer.load_state_dict(state['optimizer'])
                for state in self.optimizer.state.values():
                    for k, v in state.items():
                        if torch.is_tensor(v):
                            state[k] = v.to(device)

        return start_iter, start_loss

    def train_one_batch(self, batch, forcing_ratio=1):
        enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, c_t_1, coverage = \
            get_input_from_batch(batch, device)
        dec_batch, dec_padding_mask, max_dec_len, dec_lens_var, target_batch = \
            get_output_from_batch(batch, device)

        self.optimizer.zero_grad()

        encoder_outputs, encoder_feature, encoder_hidden = self.model.encoder(
            enc_batch, enc_lens)
        s_t_1 = self.model.reduce_state(encoder_hidden)

        step_losses = []
        y_t_1_hat = None
        for di in range(min(max_dec_len, config.max_dec_steps)):
            y_t_1 = dec_batch[:, di]
            # decide the next input
            if di == 0 or random.random() < forcing_ratio:
                x_t = y_t_1  # teacher forcing, use label from last time step as input
            else:
                # map OOV ids back to UNK before embedding; extended-vocab
                # ids start at vocab.size(), so the test must be >=
                y_t_1_hat[y_t_1_hat >= self.vocab.size()] = self.vocab.word2id(
                    UNKNOWN_TOKEN)
                x_t = y_t_1_hat.flatten(
                )  # use prediction from last time step as input
            final_dist, s_t_1, c_t_1, attn_dist, p_gen, next_coverage = self.model.decoder(
                x_t, s_t_1, encoder_outputs, encoder_feature, enc_padding_mask,
                c_t_1, extra_zeros, enc_batch_extend_vocab, coverage, di)
            _, y_t_1_hat = final_dist.data.topk(1)
            target = target_batch[:, di].unsqueeze(1)
            step_loss = cal_NLLLoss(target, final_dist)
            if config.is_coverage:  # if not using coverage, keep coverage=None
                step_coverage_loss = torch.sum(torch.min(attn_dist, coverage),
                                               1)
                step_loss = step_loss + config.cov_loss_wt * step_coverage_loss
                coverage = next_coverage

            step_mask = dec_padding_mask[:,
                                         di]  # padding in target should not count into loss
            step_loss = step_loss * step_mask
            step_losses.append(step_loss)

        sum_losses = torch.sum(torch.stack(step_losses, 1), 1)
        batch_avg_loss = sum_losses / dec_lens_var
        loss = torch.mean(batch_avg_loss)

        loss.backward()

        self.norm = clip_grad_norm_(self.model.encoder.parameters(),
                                    config.max_grad_norm)
        clip_grad_norm_(self.model.decoder.parameters(), config.max_grad_norm)
        clip_grad_norm_(self.model.reduce_state.parameters(),
                        config.max_grad_norm)

        self.optimizer.step()

        return loss.item()

    def train(self, n_iters, init_model_path=None):
        iter, avg_loss = self.setup_train(init_model_path)
        start = time.time()
        cnt = 0
        best_model_path = None
        min_eval_loss = float('inf')
        while iter < n_iters:
            s = config.forcing_ratio
            k = config.decay_to_0_iter
            x = iter
            near_zero = 0.0001
            if config.forcing_decay_type:
                if x >= config.decay_to_0_iter:
                    forcing_ratio = 0
                elif config.forcing_decay_type == 'linear':
                    forcing_ratio = s * (k - x) / k
                elif config.forcing_decay_type == 'exp':
                    p = pow(near_zero, 1 / k)
                    forcing_ratio = s * (p**x)
                elif config.forcing_decay_type == 'sig':
                    r = math.log((1 / near_zero) - 1) / k
                    forcing_ratio = s / (1 + pow(math.e, r * (x - k / 2)))
                else:
                    raise ValueError('Unrecognized forcing_decay_type: ' +
                                     config.forcing_decay_type)
            else:
                forcing_ratio = config.forcing_ratio
            batch = self.batcher.next_batch()
            loss = self.train_one_batch(batch, forcing_ratio=forcing_ratio)
            model_path = os.path.join(self.checkpoint_dir,
                                      'model_step_%d' % (iter + 1))
            avg_loss = calc_avg_loss(loss, avg_loss)

            if (iter + 1) % config.print_interval == 0:
                with self.train_summary_writer.as_default():
                    tf.summary.scalar(name='loss', data=loss, step=iter)
                self.train_summary_writer.flush()
                logger.info('steps %d, took %.2f seconds, train avg loss: %f' %
                            (iter + 1, time.time() - start, avg_loss))
                start = time.time()
            if config.eval_interval is not None and (
                    iter + 1) % config.eval_interval == 0:
                start = time.time()
                logger.info("Start Evaluation on model %s" % model_path)
                eval_processor = Evaluate(self.model, self.vocab)
                eval_loss = eval_processor.run_eval()
                logger.info(
                    "Evaluation finished, took %.2f seconds, eval loss: %f" %
                    (time.time() - start, eval_loss))
                with self.eval_summary_writer.as_default():
                    tf.summary.scalar(name='eval_loss',
                                      data=eval_loss,
                                      step=iter)
                self.eval_summary_writer.flush()
                if eval_loss < min_eval_loss:
                    logger.info(
                        "This is the best model so far, saving it to disk.")
                    min_eval_loss = eval_loss
                    best_model_path = model_path
                    self.save_model(model_path, eval_loss, iter)
                    cnt = 0
                else:
                    cnt += 1
                    if cnt > config.patience:
                        logger.info(
                            "Eval loss hasn't dropped for %d straight evaluations, early stopping.\n"
                            "Best model: %s (eval loss: %f)" %
                            (config.patience, best_model_path, min_eval_loss))
                        break
                start = time.time()
            elif (iter + 1) % config.save_interval == 0:
                self.save_model(model_path, avg_loss, iter)
            iter += 1
        else:
            # while-else: runs only when the loop ends without early stopping
            logger.info(
                "Training finished, best model: %s, with eval loss %f" %
                (best_model_path, min_eval_loss))
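
The training loop above decays the teacher-forcing ratio with one of three schedules (linear, exponential, sigmoid). The same arithmetic factored into a standalone helper, for clarity (a sketch mirroring the code above):

import math

def forcing_ratio_at(x, s, k, decay_type, near_zero=0.0001):
    # x: current iteration; s: initial forcing ratio;
    # k: iteration by which the ratio should reach (roughly) zero
    if x >= k:
        return 0.0
    if decay_type == 'linear':
        return s * (k - x) / k
    if decay_type == 'exp':
        return s * (near_zero ** (1.0 / k)) ** x
    if decay_type == 'sig':
        r = math.log(1.0 / near_zero - 1.0) / k
        return s / (1.0 + math.exp(r * (x - k / 2.0)))
    raise ValueError('Unrecognized forcing_decay_type: ' + decay_type)
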
Example No. 12
class BeamSearch(object):
    def __init__(self, model_file_path):
        self._decode_dir = os.path.join(config.log_root,
                                        'decode_%d' % (int(time.time())))
        self._rouge_ref_dir = os.path.join(self._decode_dir, 'rouge_ref')
        self._rouge_dec_dir = os.path.join(self._decode_dir, 'rouge_dec_dir')
        for p in [self._decode_dir, self._rouge_ref_dir, self._rouge_dec_dir]:
            if not os.path.exists(p):
                os.mkdir(p)

        self.vocab = Vocab(config.vocab_path, config.vocab_size)
        self.batcher = Batcher(config.decode_data_path,
                               self.vocab,
                               mode='decode',
                               batch_size=config.beam_size,
                               single_pass=True)
        time.sleep(15)

        self.model = Model(model_file_path, is_eval=True)

    def sort_beams(self, beams):
        return sorted(beams, key=lambda h: h.avg_log_prob, reverse=True)

    def decode(self):
        start = time.time()
        counter = 0
        batch = self.batcher.next_batch()
        while batch is not None:
            # Run beam search to get best Hypothesis
            best_summary = self.beam_search(batch)

            # Extract the output ids from the hypothesis and convert back to words
            output_ids = [int(t) for t in best_summary.tokens[1:]]
            decoded_words = data.outputids2words(
                output_ids, self.vocab,
                (batch.art_oovs[0] if config.pointer_gen else None))

            # Remove the [STOP] token from decoded_words, if necessary
            try:
                fst_stop_idx = decoded_words.index(data.STOP_DECODING)
                decoded_words = decoded_words[:fst_stop_idx]
            except ValueError:
                pass  # no [STOP] token; keep the full sequence

            original_abstract_sents = batch.original_abstracts_sents[0]

            self.write_for_rouge(original_abstract_sents, decoded_words,
                                 counter)
            counter += 1
            if counter % 1000 == 0:
                print('%d examples in %d sec' % (counter, time.time() - start))
                start = time.time()

            batch = self.batcher.next_batch()

        print("Decoder has finished reading dataset for single_pass.")
        print("Now starting ROUGE eval...")
        results_dict = rouge_eval(self._rouge_ref_dir, self._rouge_dec_dir)
        rouge_log(results_dict, self._decode_dir)

    def beam_search(self, batch):
        #batch should have only one example
        enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, c_t_0 = \
            get_input_from_batch(batch, use_cuda)

        encoder_outputs, encoder_hidden = self.model.encoder(
            enc_batch, enc_lens)
        s_t_0 = self.model.reduce_state(encoder_hidden)

        dec_h, dec_c = s_t_0  # 1 x 2*hidden_size
        dec_h = dec_h.squeeze()
        dec_c = dec_c.squeeze()
        #decoder batch preparation, it has beam_size example initially everything is repeated
        beams = [
            Beam(tokens=[self.vocab.word2id(data.START_DECODING)],
                 log_probs=[0.0],
                 state=(dec_h[0], dec_c[0]),
                 context=c_t_0[0],
                 attn_dists=[],
                 p_gens=[]) for _ in range(config.beam_size)
        ]
        results = []
        steps = 0
        while steps < config.max_dec_steps and len(results) < config.beam_size:
            latest_tokens = [h.latest_token for h in beams]
            latest_tokens = [
                t if t < self.vocab.size() else self.vocab.word2id(
                    data.UNKNOWN_TOKEN) for t in latest_tokens
            ]
            y_t_1 = Variable(torch.LongTensor(latest_tokens))
            if use_cuda:
                y_t_1 = y_t_1.cuda()
            all_state_h = []
            all_state_c = []

            all_context = []

            for h in beams:
                state_h, state_c = h.state
                all_state_h.append(state_h)
                all_state_c.append(state_c)

                all_context.append(h.context)

            s_t_1 = (torch.stack(all_state_h,
                                 0).unsqueeze(0), torch.stack(all_state_c,
                                                              0).unsqueeze(0))
            c_t_1 = torch.stack(all_context, 0)

            final_dist, s_t, c_t, attn_dist, p_gen = self.model.decoder(
                y_t_1, s_t_1, encoder_outputs, enc_padding_mask, c_t_1,
                extra_zeros, enc_batch_extend_vocab)

            # final_dist holds probabilities; take the log so the beam's
            # avg_log_prob length normalization behaves as intended
            log_probs = torch.log(final_dist)
            topk_log_probs, topk_ids = torch.topk(log_probs,
                                                  config.beam_size * 2)

            dec_h, dec_c = s_t
            dec_h = dec_h.squeeze()
            dec_c = dec_c.squeeze()

            all_beams = []
            num_orig_beams = 1 if steps == 0 else len(beams)
            for i in range(num_orig_beams):
                h = beams[i]
                for j in range(config.beam_size *
                               2):  # for each of the top 2*beam_size hyps:
                    new_beam = h.extend(token=topk_ids[i, j].item(),
                                        log_prob=topk_log_probs[i, j].item(),
                                        state=(dec_h[i], dec_c[i]),
                                        context=c_t[i],
                                        attn_dist=attn_dist[i],
                                        p_gen=p_gen[i])
                    all_beams.append(new_beam)

            beams = []
            for h in self.sort_beams(all_beams):
                if h.latest_token == self.vocab.word2id(data.STOP_DECODING):
                    if steps >= config.min_dec_steps:
                        results.append(h)
                else:
                    beams.append(h)
                if len(beams) == config.beam_size or len(
                        results) == config.beam_size:
                    break

            steps += 1

        if len(results) == 0:
            results = beams

        beams_sorted = self.sort_beams(results)

        return beams_sorted[0]

    def write_for_rouge(self, reference_sents, decoded_words, ex_index):
        decoded_sents = []
        while len(decoded_words) > 0:
            try:
                fst_period_idx = decoded_words.index(".")
            except ValueError:
                fst_period_idx = len(decoded_words)
            sent = decoded_words[:fst_period_idx + 1]
            decoded_words = decoded_words[fst_period_idx + 1:]
            decoded_sents.append(' '.join(sent))

        # pyrouge calls a perl script that puts the data into HTML files.
        # Therefore we need to make our output HTML safe.
        decoded_sents = [make_html_safe(w) for w in decoded_sents]
        reference_sents = [make_html_safe(w) for w in reference_sents]

        ref_file = os.path.join(self._rouge_ref_dir,
                                "%06d_reference.txt" % ex_index)
        decoded_file = os.path.join(self._rouge_dec_dir,
                                    "%06d_decoded.txt" % ex_index)

        with open(ref_file, "w") as f:
            f.write("\n".join(reference_sents))
        with open(decoded_file, "w") as f:
            f.write("\n".join(decoded_sents))

        print("Wrote example %i to file" % ex_index)
Example No. 13
class BeamSearch(object):
    def __init__(self, model_file_path):
        model_name = os.path.basename(model_file_path)
        self._decode_dir = os.path.join(config.log_root, 'decode_%s' % (model_name))
        self._rouge_ref_dir = os.path.join(self._decode_dir, 'rouge_ref')
        self._rouge_dec_dir = os.path.join(self._decode_dir, 'rouge_dec_dir')
        for p in [self._decode_dir, self._rouge_ref_dir, self._rouge_dec_dir]:
            if not os.path.exists(p):
                os.mkdir(p)

        self.vocab = Vocab(config.vocab_path, config.vocab_size)
        self.batcher = Batcher(config.decode_data_path, self.vocab, mode='decode',
                               batch_size=config.beam_size, single_pass=True)
        time.sleep(15)

        self.model = Model(model_file_path, is_eval=True)

    def sort_beams(self, beams):
        return sorted(beams, key=lambda h: h.avg_log_prob, reverse=True)


    def decode(self):
        start = time.time()
        counter = 0
        batch = self.batcher.next_batch()
        while batch is not None:
            # Run beam search to get best Hypothesis
            best_summary = self.beam_search(batch)

            # Extract the output ids from the hypothesis and convert back to words
            output_ids = [int(t) for t in best_summary.tokens[1:]]
            decoded_words = data.outputids2words(output_ids, self.vocab,
                                                 (batch.art_oovs[0] if config.pointer_gen else None))

            # Remove the [STOP] token from decoded_words, if necessary
            try:
                fst_stop_idx = decoded_words.index(data.STOP_DECODING)
                decoded_words = decoded_words[:fst_stop_idx]
            except ValueError:
                pass  # no [STOP] token; keep the full sequence

            original_abstract_sents = batch.original_abstracts_sents[0]

            write_for_rouge(original_abstract_sents, decoded_words, counter,
                            self._rouge_ref_dir, self._rouge_dec_dir)
            counter += 1
            if counter % 1000 == 0:
                print('%d examples in %d sec' % (counter, time.time() - start))
                start = time.time()

            batch = self.batcher.next_batch()

        print("Decoder has finished reading dataset for single_pass.")
        print("Now starting ROUGE eval...")
        results_dict = rouge_eval(self._rouge_ref_dir, self._rouge_dec_dir)
        rouge_log(results_dict, self._decode_dir)


    def beam_search(self, batch):
        #batch should have only one example
        enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, c_t_0, coverage_t_0 = \
            get_input_from_batch(batch, use_cuda)

        encoder_outputs, encoder_hidden, max_encoder_output = self.model.encoder(enc_batch, enc_lens)
        s_t_0 = self.model.reduce_state(encoder_hidden)

        if config.use_maxpool_init_ctx:
            c_t_0 = max_encoder_output

        dec_h, dec_c = s_t_0 # 1 x 2*hidden_size
        dec_h = dec_h.squeeze()
        dec_c = dec_c.squeeze()

        #decoder batch preparation, it has beam_size example initially everything is repeated
        beams = [Beam(tokens=[self.vocab.word2id(data.START_DECODING)],
                      log_probs=[0.0],
                      state=(dec_h[0], dec_c[0]),
                      context = c_t_0[0],
                      coverage=(coverage_t_0[0] if config.is_coverage else None))
                 for _ in range(config.beam_size)]
        results = []
        steps = 0
        while steps < config.max_dec_steps and len(results) < config.beam_size:
            latest_tokens = [h.latest_token for h in beams]
            latest_tokens = [t if t < self.vocab.size() else self.vocab.word2id(data.UNKNOWN_TOKEN) \
                             for t in latest_tokens]
            y_t_1 = Variable(torch.LongTensor(latest_tokens))
            if use_cuda:
                y_t_1 = y_t_1.cuda()
            all_state_h =[]
            all_state_c = []

            all_context = []

            for h in beams:
                state_h, state_c = h.state
                all_state_h.append(state_h)
                all_state_c.append(state_c)

                all_context.append(h.context)

            s_t_1 = (torch.stack(all_state_h, 0).unsqueeze(0), torch.stack(all_state_c, 0).unsqueeze(0))
            c_t_1 = torch.stack(all_context, 0)

            coverage_t_1 = None
            if config.is_coverage:
                all_coverage = []
                for h in beams:
                    all_coverage.append(h.coverage)
                coverage_t_1 = torch.stack(all_coverage, 0)

            final_dist, s_t, c_t, attn_dist, p_gen, coverage_t = self.model.decoder(y_t_1, s_t_1,
                                                        encoder_outputs, enc_padding_mask, c_t_1,
                                                        extra_zeros, enc_batch_extend_vocab, coverage_t_1)

            log_probs = torch.log(final_dist)  # final_dist holds probabilities
            topk_log_probs, topk_ids = torch.topk(log_probs, config.beam_size * 2)

            dec_h, dec_c = s_t
            dec_h = dec_h.squeeze()
            dec_c = dec_c.squeeze()

            all_beams = []
            num_orig_beams = 1 if steps == 0 else len(beams)
            for i in range(num_orig_beams):
                h = beams[i]
                state_i = (dec_h[i], dec_c[i])
                context_i = c_t[i]
                coverage_i = (coverage_t[i] if config.is_coverage else None)

                for j in range(config.beam_size * 2):  # for each of the top 2*beam_size hyps:
                    new_beam = h.extend(token=topk_ids[i, j].item(),
                                   log_prob=topk_log_probs[i, j].item(),
                                   state=state_i,
                                   context=context_i,
                                   coverage=coverage_i)
                    all_beams.append(new_beam)

            beams = []
            for h in self.sort_beams(all_beams):
                if h.latest_token == self.vocab.word2id(data.STOP_DECODING):
                    if steps >= config.min_dec_steps:
                        results.append(h)
                else:
                    beams.append(h)
                if len(beams) == config.beam_size or len(results) == config.beam_size:
                    break

            steps += 1

        if len(results) == 0:
            results = beams

        beams_sorted = self.sort_beams(results)

        return beams_sorted[0]
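
The coverage mechanism threaded through the decoder above penalizes attending again to source positions that have already been attended to; at train time the corresponding loss term is torch.sum(torch.min(attn_dist, coverage), 1). The same step in isolation (a sketch; both tensors are assumed to have shape (batch, src_len)):

import torch

def coverage_step(attn_dist, coverage):
    # overlap between the new attention and the accumulated coverage
    step_coverage_loss = torch.sum(torch.min(attn_dist, coverage), dim=1)
    next_coverage = coverage + attn_dist
    return step_coverage_loss, next_coverage
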
Example No. 14
class BeamSearch(object):
    def __init__(self, model_file_path):
        model_name = os.path.basename(model_file_path)
        self._decode_dir = os.path.join(config.log_root,
                                        'decode_%s' % (model_name))
        self._rouge_ref_dir = os.path.join(self._decode_dir, 'rouge_ref')
        self._rouge_dec_dir = os.path.join(self._decode_dir, 'rouge_dec_dir')
        for p in [self._decode_dir, self._rouge_ref_dir, self._rouge_dec_dir]:
            if not os.path.exists(p):
                os.mkdir(p)

        self.vocab = Vocab(config.vocab_path, config.vocab_size)
        self.concept_vocab = Concept_vocab(config.concept_vocab_path,
                                           config.vocab_size)
        self.batcher = Batcher(config.decode_data_path,
                               self.vocab,
                               self.concept_vocab,
                               mode='decode',
                               batch_size=config.beam_size,
                               single_pass=True)
        time.sleep(15)

        self.model = Model(model_file_path, is_eval=True)

    def sort_beams(self, beams):
        return sorted(beams, key=lambda h: h.avg_log_prob, reverse=True)

    def decode(self):
        start = time.time()
        counter = 0
        batch = self.batcher.next_batch()
        while batch is not None:
            best_summary = self.beam_search(batch)

            output_ids = [int(t) for t in best_summary.tokens[1:]]
            decoded_words = data.outputids2words(
                output_ids, self.vocab,
                (batch.art_oovs[0] if config.pointer_gen else None))

            try:
                fst_stop_idx = decoded_words.index(data.STOP_DECODING)
                decoded_words = decoded_words[:fst_stop_idx]
            except ValueError:
                pass  # no [STOP] token; keep the full sequence

            original_abstract_sents = batch.original_abstracts_sents[0]

            write_for_rouge(original_abstract_sents, decoded_words, counter,
                            self._rouge_ref_dir, self._rouge_dec_dir)
            counter += 1
            if counter % 1000 == 0:
                print('%d examples in %d sec' % (counter, time.time() - start))
                start = time.time()

            batch = self.batcher.next_batch()

        print("Decoder has finished reading dataset for single_pass.")

    def beam_search(self, batch):
        enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, enc_batch_concept_extend_vocab, concept_p, position, concept_mask, extra_zeros, c_t_0, coverage_t_0 = \
            get_input_from_batch(batch, use_cuda)

        encoder_outputs, encoder_hidden, max_encoder_output, _, _ = self.model.encoder(
            enc_batch, enc_lens, enc_batch, enc_batch)
        s_t_0 = self.model.reduce_state(encoder_hidden)

        if config.use_maxpool_init_ctx:
            c_t_0 = max_encoder_output

        dec_h, dec_c = s_t_0
        dec_h = dec_h.squeeze()
        dec_c = dec_c.squeeze()

        beams = [
            Beam(tokens=[self.vocab.word2id(data.START_DECODING)],
                 log_probs=[0.0],
                 state=(dec_h[0], dec_c[0]),
                 context=c_t_0[0],
                 coverage=(coverage_t_0[0] if config.is_coverage else None))
            for _ in range(config.beam_size)
        ]
        results = []
        steps = 0
        while steps < config.max_dec_steps and len(results) < config.beam_size:
            latest_tokens = [h.latest_token for h in beams]
            latest_tokens = [t if t < self.vocab.size() else self.vocab.word2id(data.UNKNOWN_TOKEN) \
                             for t in latest_tokens]
            y_t_1 = Variable(torch.LongTensor(latest_tokens))
            if use_cuda:
                y_t_1 = y_t_1.cuda()
            all_state_h = []
            all_state_c = []

            all_context = []

            for h in beams:
                state_h, state_c = h.state
                all_state_h.append(state_h)
                all_state_c.append(state_c)

                all_context.append(h.context)

            s_t_1 = (torch.stack(all_state_h,
                                 0).unsqueeze(0), torch.stack(all_state_c,
                                                              0).unsqueeze(0))
            c_t_1 = torch.stack(all_context, 0)

            coverage_t_1 = None
            if config.is_coverage:
                all_coverage = []
                for h in beams:
                    all_coverage.append(h.coverage)
                coverage_t_1 = torch.stack(all_coverage, 0)

            final_dist, s_t, c_t, attn_dist, p_gen, coverage_t = self.model.decoder(
                'decode', y_t_1, s_t_1, encoder_outputs, enc_padding_mask,
                c_t_1, extra_zeros, enc_batch_extend_vocab,
                enc_batch_concept_extend_vocab, concept_p, position,
                concept_mask, coverage_t_1, steps)

            log_probs = torch.log(final_dist)  # final_dist holds probabilities
            topk_log_probs, topk_ids = torch.topk(log_probs,
                                                  config.beam_size * 2)

            dec_h, dec_c = s_t
            dec_h = dec_h.squeeze()
            dec_c = dec_c.squeeze()

            all_beams = []
            num_orig_beams = 1 if steps == 0 else len(beams)
            for i in range(num_orig_beams):
                h = beams[i]
                state_i = (dec_h[i], dec_c[i])
                context_i = c_t[i]
                coverage_i = (coverage_t[i] if config.is_coverage else None)

                for j in range(config.beam_size * 2):
                    new_beam = h.extend(token=topk_ids[i, j].item(),
                                        log_prob=topk_log_probs[i, j].item(),
                                        state=state_i,
                                        context=context_i,
                                        coverage=coverage_i)
                    all_beams.append(new_beam)

            beams = []
            for h in self.sort_beams(all_beams):
                if h.latest_token == self.vocab.word2id(data.STOP_DECODING):
                    if steps >= config.min_dec_steps:
                        results.append(h)
                else:
                    beams.append(h)
                if len(beams) == config.beam_size or len(
                        results) == config.beam_size:
                    break

            steps += 1

        if len(results) == 0:
            results = beams

        beams_sorted = self.sort_beams(results)

        return beams_sorted[0]
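
All of these decoders return a final_dist over an extended vocabulary, which is why every beam-search loop remaps ids with t < vocab.size() to UNK before the next step. The usual pointer-generator mixing that produces such a distribution looks roughly like this (a sketch of the standard formulation, not the exact decoder behind these snippets):

import torch

def final_distribution(p_gen, vocab_dist, attn_dist,
                       enc_batch_extend_vocab, extra_zeros):
    # p_gen: (batch, 1) generation probability
    # vocab_dist: (batch, vocab_size); attn_dist: (batch, src_len)
    vocab_dist_ = p_gen * vocab_dist
    attn_dist_ = (1.0 - p_gen) * attn_dist
    if extra_zeros is not None:
        # widen the distribution to cover in-article OOV ids
        vocab_dist_ = torch.cat([vocab_dist_, extra_zeros], dim=1)
    # add copy probabilities onto the (extended) vocab ids of the source tokens
    return vocab_dist_.scatter_add(1, enc_batch_extend_vocab, attn_dist_)
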
Example No. 15
class BeamSearch(object):
    def __init__(self, model_file_path):
        model_name = os.path.basename(model_file_path)
        self.model_path_name = model_name
        self._decode_dir = os.path.join(config.log_root,
                                        'decode_%s' % (model_name))
        self._rouge_ref_dir = os.path.join(self._decode_dir, 'rouge_ref')
        self._rouge_dec_dir = os.path.join(self._decode_dir, 'rouge_dec_dir')
        for p in [self._decode_dir, self._rouge_ref_dir, self._rouge_dec_dir]:
            if not os.path.exists(p):
                os.mkdir(p)

        self.vocab = Vocab(config.vocab_path, config.vocab_size)
        self.batcher = Batcher(config.decode_data_path,
                               self.vocab,
                               mode='decode',
                               batch_size=config.beam_size,
                               single_pass=True)
        time.sleep(15)

        self.model = Model(model_file_path, is_eval=True)

    def sort_beams(self, beams):
        return sorted(beams, key=lambda h: h.avg_log_prob, reverse=True)

    def decode(self):
        start = time.time()
        counter = 0
        batch = self.batcher.next_batch()

        decoded_result = []
        referred_result = []
        article_result = []
        while batch is not None:
            # Run beam search to get best Hypothesis
            best_summary = self.beam_search(batch)

            # Extract the output ids from the hypothesis and convert back to words
            output_ids = [int(t) for t in best_summary.tokens[1:]]
            decoded_words = data.outputids2words(
                output_ids, self.vocab,
                (batch.art_oovs[0] if config.pointer_gen else None))

            # Remove the [STOP] token from decoded_words, if necessary
            try:
                fst_stop_idx = decoded_words.index(data.STOP_DECODING)
                decoded_words = decoded_words[:fst_stop_idx]
            except ValueError:
                pass  # no [STOP] token; keep the full sequence

            original_abstract_sents = batch.original_abstracts_sents[0]
            article = batch.original_articles[0]

            #write_for_rouge(original_abstract_sents, decoded_words, counter,
            #                self._rouge_ref_dir, self._rouge_dec_dir)
            decoded_sents = []
            while len(decoded_words) > 0:
                try:
                    fst_period_idx = decoded_words.index(".")
                except ValueError:
                    fst_period_idx = len(decoded_words)
                sent = decoded_words[:fst_period_idx + 1]
                decoded_words = decoded_words[fst_period_idx + 1:]
                decoded_sents.append(' '.join(sent))

            # pyrouge calls a perl script that puts the data into HTML files.
            # Therefore we need to make our output HTML safe.
            decoded_sents = [make_html_safe(w) for w in decoded_sents]
            reference_sents = [
                make_html_safe(w) for w in original_abstract_sents
            ]
            decoded_result.append(' '.join(decoded_sents))
            referred_result.append(' '.join(reference_sents))
            article_result.append(article)
            counter += 1
            if counter % 1000 == 0:
                print('%d examples in %d sec' % (counter, time.time() - start))
                start = time.time()

            batch = self.batcher.next_batch()

        print("Decoder has finished reading dataset for single_pass.")
        print("Now starting ROUGE eval...")
        load_file = self.model_path_name
        self.print_original_predicted(decoded_result, referred_result,
                                      article_result, load_file)

        rouge = Rouge()
        scores = rouge.get_scores(decoded_result, referred_result)
        rouge_1 = sum([x["rouge-1"]["f"] for x in scores]) / len(scores)
        rouge_2 = sum([x["rouge-2"]["f"] for x in scores]) / len(scores)
        rouge_l = sum([x["rouge-l"]["f"] for x in scores]) / len(scores)
        rouge_1_r = sum([x["rouge-1"]["r"] for x in scores]) / len(scores)
        rouge_2_r = sum([x["rouge-2"]["r"] for x in scores]) / len(scores)
        rouge_l_r = sum([x["rouge-l"]["r"] for x in scores]) / len(scores)
        rouge_1_p = sum([x["rouge-1"]["p"] for x in scores]) / len(scores)
        rouge_2_p = sum([x["rouge-2"]["p"] for x in scores]) / len(scores)
        rouge_l_p = sum([x["rouge-l"]["p"] for x in scores]) / len(scores)
        log_str = " rouge_1:%.4f rouge_2:%.4f rouge_l:%.4f" % (rouge_1, rouge_2, rouge_l)
        log_str_r = " rouge_1_r:%.4f rouge_2_r:%.4f rouge_l_r:%.4f" % (rouge_1_r, rouge_2_r, rouge_l_r)
        log_str_p = " rouge_1_p:%.4f rouge_2_p:%.4f rouge_l_p:%.4f" % (rouge_1_p, rouge_2_p, rouge_l_p)
        logger.info(load_file + log_str)
        results_file = os.path.join(self._decode_dir, "ROUGE_results.txt")
        with open(results_file, "w") as f:
            f.write(log_str + '\n')
            f.write(log_str_r + '\n')
            f.write(log_str_p + '\n')

        #results_dict = rouge_eval(self._rouge_ref_dir, self._rouge_dec_dir)
        #rouge_log(results_dict, self._decode_dir)

    def print_original_predicted(self, decoded_sents, ref_sents, article_sents,
                                 loadfile):
        filename = "test_" + loadfile.split("_")[1] + ".txt"

        with open(os.path.join(self._rouge_dec_dir, filename), "w") as f:
            for i in range(len(decoded_sents)):
                f.write("article: " + article_sents[i] + "\n")
                f.write("ref: " + ref_sents[i] + "\n")
                f.write("dec: " + decoded_sents[i] + "\n\n")

    def beam_search(self, batch):
        #batch should have only one example
        enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, c_t_0, coverage_t_0 = \
            get_input_from_batch(batch, use_cuda)

        encoder_outputs, encoder_feature, encoder_hidden = self.model.encoder(
            enc_batch, enc_lens)
        s_t_0 = self.model.reduce_state(encoder_hidden)

        dec_h, dec_c = s_t_0  # 1 x 2*hidden_size
        dec_h = dec_h.squeeze()
        dec_c = dec_c.squeeze()

        #decoder batch preparation, it has beam_size example initially everything is repeated
        beams = [
            Beam(tokens=[self.vocab.word2id(data.START_DECODING)],
                 log_probs=[0.0],
                 state=(dec_h[0], dec_c[0]),
                 context=c_t_0[0],
                 coverage=(coverage_t_0[0] if config.is_coverage else None))
            for _ in range(config.beam_size)
        ]
        results = []
        steps = 0
        while steps < config.max_dec_steps and len(results) < config.beam_size:
            latest_tokens = [h.latest_token for h in beams]
            latest_tokens = [t if t < self.vocab.size() else self.vocab.word2id(data.UNKNOWN_TOKEN) \
                             for t in latest_tokens]
            y_t_1 = Variable(torch.LongTensor(latest_tokens))
            if use_cuda:
                y_t_1 = y_t_1.cuda()
            all_state_h = []
            all_state_c = []

            all_context = []

            for h in beams:
                state_h, state_c = h.state
                all_state_h.append(state_h)
                all_state_c.append(state_c)

                all_context.append(h.context)

            s_t_1 = (torch.stack(all_state_h,
                                 0).unsqueeze(0), torch.stack(all_state_c,
                                                              0).unsqueeze(0))
            c_t_1 = torch.stack(all_context, 0)

            coverage_t_1 = None
            if config.is_coverage:
                all_coverage = []
                for h in beams:
                    all_coverage.append(h.coverage)
                coverage_t_1 = torch.stack(all_coverage, 0)

            final_dist, s_t, c_t, attn_dist, p_gen, coverage_t = self.model.decoder(
                y_t_1, s_t_1, encoder_outputs, encoder_feature,
                enc_padding_mask, c_t_1, extra_zeros, enc_batch_extend_vocab,
                coverage_t_1, steps)
            log_probs = torch.log(final_dist)
            topk_log_probs, topk_ids = torch.topk(log_probs,
                                                  config.beam_size * 2)

            dec_h, dec_c = s_t
            dec_h = dec_h.squeeze()
            dec_c = dec_c.squeeze()

            all_beams = []
            num_orig_beams = 1 if steps == 0 else len(beams)
            for i in range(num_orig_beams):
                h = beams[i]
                state_i = (dec_h[i], dec_c[i])
                context_i = c_t[i]
                coverage_i = (coverage_t[i] if config.is_coverage else None)

                for j in range(config.beam_size *
                               2):  # for each of the top 2*beam_size hyps:
                    new_beam = h.extend(token=topk_ids[i, j].item(),
                                        log_prob=topk_log_probs[i, j].item(),
                                        state=state_i,
                                        context=context_i,
                                        coverage=coverage_i)
                    all_beams.append(new_beam)

            beams = []
            for h in self.sort_beams(all_beams):
                if h.latest_token == self.vocab.word2id(data.STOP_DECODING):
                    if steps >= config.min_dec_steps:
                        results.append(h)
                else:
                    beams.append(h)
                if len(beams) == config.beam_size or len(
                        results) == config.beam_size:
                    break

            steps += 1

        if len(results) == 0:
            results = beams

        beams_sorted = self.sort_beams(results)

        return beams_sorted[0]
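
Unlike the pyrouge-based examples above, the decode() method here scores with the rouge pip package. Minimal standalone usage for reference (assuming the package is installed via pip install rouge):

from rouge import Rouge

rouge = Rouge()
scores = rouge.get_scores(["the cat sat on the mat"], ["a cat was on the mat"])
print("%.4f" % scores[0]["rouge-l"]["f"])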