Example No. 1
class BeamSearch(object):
    def __init__(self, model_file_path):
        model_name = os.path.basename(model_file_path)
        self._decode_dir = os.path.join(config.log_root,
                                        'decode_%s' % (model_name))
        self._rouge_ref_dir = os.path.join(self._decode_dir, 'rouge_ref')
        self._rouge_dec_dir = os.path.join(self._decode_dir, 'rouge_dec_dir')
        for p in [self._decode_dir, self._rouge_ref_dir, self._rouge_dec_dir]:
            if not os.path.exists(p):
                os.mkdir(p)

        self.vocab = Vocab(config.vocab_path, config.vocab_size)
        self.batcher = Batcher(config.decode_data_path,
                               self.vocab,
                               mode='decode',
                               batch_size=config.beam_size,
                               single_pass=True)
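        # give the batcher's background threads a moment to fill the example queue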
        time.sleep(15)

        self.model = Model(model_file_path, is_eval=True)

    def sort_beams(self, beams):
        return sorted(beams, key=lambda h: h.avg_log_prob, reverse=True)

    def decode(self):
        start = time.time()
        counter = 0
        batch = self.batcher.next_batch()
        while batch is not None:
            # Run beam search to get best Hypothesis
            best_summary = self.beam_search(batch)

            # Extract the output ids from the hypothesis and convert back to words
            output_ids = [int(t) for t in best_summary.tokens[1:]]
            decoded_words = data.outputids2words(
                output_ids, self.vocab,
                (batch.art_oovs[0] if config.pointer_gen else None))

            # Remove the [STOP] token from decoded_words, if necessary
            try:
                fst_stop_idx = decoded_words.index(data.STOP_DECODING)
                decoded_words = decoded_words[:fst_stop_idx]
            except ValueError:
                pass  # no [STOP] token was produced; keep the full sequence

            original_abstract = batch.original_abstracts_sents[0]

            write_for_rouge(original_abstract, decoded_words, counter,
                            self._rouge_ref_dir, self._rouge_dec_dir)
            counter += 1
            if counter % 1000 == 0:
                print('%d examples in %d sec' % (counter, time.time() - start))
                start = time.time()

            batch = self.batcher.next_batch()

        print("Decoder has finished reading dataset for single_pass.")
        print("Now starting ROUGE eval...")
        results_dict = rouge_eval(self._rouge_ref_dir, self._rouge_dec_dir)
        rouge_log(results_dict, self._decode_dir)

    def beam_search(self, batch):
        # batch should have only one example
        enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, c_t_0, coverage_t_0 = \
            get_input_from_batch(batch, use_cuda)

        encoder_outputs, encoder_feature, encoder_hidden = self.model.encoder(
            enc_batch, enc_lens)
        s_t_0 = self.model.reduce_state(encoder_hidden)

        dec_h, dec_c = s_t_0  # 1 x 2*hidden_size
        dec_h = dec_h.squeeze()
        dec_c = dec_c.squeeze()

        # decoder batch preparation: the batch initially holds beam_size copies of the same example
        beams = [
            Beam(tokens=[self.vocab.word2id(data.START_DECODING)],
                 log_probs=[0.0],
                 state=(dec_h[0], dec_c[0]),
                 context=c_t_0[0],
                 coverage=(coverage_t_0[0] if config.is_coverage else None))
            for _ in range(config.beam_size)
        ]
        results = []
        steps = 0
        while steps < config.max_dec_steps and len(results) < config.beam_size:
            latest_tokens = [h.latest_token for h in beams]
            latest_tokens = [t if t < self.vocab.size() else self.vocab.word2id(data.UNKNOWN_TOKEN) \
                             for t in latest_tokens]
            y_t_1 = Variable(torch.LongTensor(latest_tokens))
            if use_cuda:
                y_t_1 = y_t_1.cuda()
            all_state_h = []
            all_state_c = []

            all_context = []

            for h in beams:
                state_h, state_c = h.state
                all_state_h.append(state_h)
                all_state_c.append(state_c)

                all_context.append(h.context)

            s_t_1 = (torch.stack(all_state_h,
                                 0).unsqueeze(0), torch.stack(all_state_c,
                                                              0).unsqueeze(0))
            c_t_1 = torch.stack(all_context, 0)

            coverage_t_1 = None
            if config.is_coverage:
                all_coverage = []
                for h in beams:
                    all_coverage.append(h.coverage)
                coverage_t_1 = torch.stack(all_coverage, 0)

            final_dist, s_t, c_t, attn_dist, p_gen, coverage_t = self.model.decoder(
                y_t_1, s_t_1, encoder_outputs, encoder_feature,
                enc_padding_mask, c_t_1, extra_zeros, enc_batch_extend_vocab,
                coverage_t_1, steps)

            # final_dist holds probabilities; take the log first so the beam
            # scores accumulate genuine log-probabilities (as in the variants below)
            log_probs = torch.log(final_dist)
            topk_log_probs, topk_ids = torch.topk(log_probs,
                                                  config.beam_size * 2)

            dec_h, dec_c = s_t
            dec_h = dec_h.squeeze()
            dec_c = dec_c.squeeze()

            all_beams = []
            num_orig_beams = 1 if steps == 0 else len(beams)
            for i in range(num_orig_beams):
                h = beams[i]
                state_i = (dec_h[i], dec_c[i])
                context_i = c_t[i]
                coverage_i = (coverage_t[i] if config.is_coverage else None)

                for j in range(config.beam_size *
                               2):  # for each of the top 2*beam_size hyps:
                    new_beam = h.extend(token=topk_ids[i, j].item(),
                                        log_prob=topk_log_probs[i, j].item(),
                                        state=state_i,
                                        context=context_i,
                                        coverage=coverage_i)
                    all_beams.append(new_beam)

            beams = []
            for h in self.sort_beams(all_beams):
                if h.latest_token == self.vocab.word2id(data.STOP_DECODING):
                    if steps >= config.min_dec_steps:
                        results.append(h)
                else:
                    beams.append(h)
                if len(beams) == config.beam_size or len(
                        results) == config.beam_size:
                    break

            steps += 1

        if len(results) == 0:
            results = beams

        beams_sorted = self.sort_beams(results)

        return beams_sorted[0]
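
Note: the Beam hypothesis class used above is not part of this listing. Below is a minimal sketch of what the search loop assumes about it; the attribute and method names follow the pointer-generator codebase, but treat the exact signatures as assumptions:

class Beam(object):
    """One partial hypothesis: a token prefix plus the decoder state needed to extend it."""
    def __init__(self, tokens, log_probs, state, context, coverage):
        self.tokens = tokens        # token ids decoded so far
        self.log_probs = log_probs  # per-token log-probabilities
        self.state = state          # (h, c) decoder LSTM state
        self.context = context      # last attention context vector
        self.coverage = coverage    # coverage vector, or None

    def extend(self, token, log_prob, state, context, coverage):
        # return a new hypothesis with one more token appended
        return Beam(tokens=self.tokens + [token],
                    log_probs=self.log_probs + [log_prob],
                    state=state, context=context, coverage=coverage)

    @property
    def latest_token(self):
        return self.tokens[-1]

    @property
    def avg_log_prob(self):
        # length-normalized score used by sort_beams, so longer hypotheses
        # are not penalized just for accumulating more log terms
        return sum(self.log_probs) / len(self.tokens)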
Example No. 2
def main(unused_argv):
    # %%
    # choose what level of logging you want
    tf.logging.set_verbosity(tf.logging.INFO)
    tf.logging.info('Starting to run in %s mode...', FLAGS.mode)
    # build the vocabulary
    vocab = Vocab(FLAGS.vocab_path, FLAGS.vocab_size)

    hparam_list = [
        'mode', 'lr', 'adagrad_init_acc', 'rand_unif_init_mag',
        'trunc_norm_init_std', 'max_grad_norm', 'hidden_dim', 'emb_dim',
        'batch_size', 'max_dec_sen_num', 'max_dec_steps', 'max_enc_steps'
    ]
    hps_dict = {}
    for key, val in FLAGS.__flags.items():
        if key in hparam_list:
            hps_dict[key] = val.value  # add it to the dict
    hps_generator = namedtuple("HParams", hps_dict.keys())(**hps_dict)

    hparam_list = [
        'lr', 'adagrad_init_acc', 'rand_unif_init_mag', 'trunc_norm_init_std',
        'max_grad_norm', 'hidden_dim', 'emb_dim', 'batch_size',
        'max_enc_sen_num', 'max_enc_seq_len'
    ]
    hps_dict = {}

    for key, val in FLAGS.__flags.items():  # for each flag
        if key in hparam_list:
            hps_dict[key] = val.value  # add it to the dict
    hps_discriminator = namedtuple("HParams", hps_dict.keys())(**hps_dict)

    # fetch the data at the minimum batch size
    batcher = GenBatcher(vocab, hps_generator)
    # print(batcher.train_batch[0].original_review_inputs)
    # print(len(batcher.train_batch[0].original_review_inputs))
    tf.set_random_seed(123)
    # %%
    if FLAGS.mode == 'train_generator':

        # print("Start pre-training ......")
        ge_model = Generator(hps_generator, vocab)
        sess_ge, saver_ge, train_dir_ge = setup_training_generator(ge_model)

        generated = Generated_sample(ge_model, vocab, batcher, sess_ge)
        print("Start pre-training generator......")
        run_pre_train_generator(ge_model, batcher, 300, sess_ge, saver_ge,
                                train_dir_ge)
        # util.load_ckpt(saver_ge, sess_ge, ckpt_dir="train-generator")
        print("finished pre-training the generator")

        print("Generating negative examples......")
        generator_graph = tf.Graph()
        with generator_graph.as_default():
            util.load_ckpt(saver_ge, sess_ge, ckpt_dir="train-generator")
            print("finish load train-generator")

        generated.generator_train_negative_example()
        generated.generator_test_negative_example()

        print("finish write")
    elif FLAGS.mode == 'train_discriminator':
        # print("Start pre-training ......")
        model_dis = Discriminator(hps_discriminator, vocab)
        dis_batcher = DisBatcher(hps_discriminator, vocab,
                                 "discriminator_train/positive/*",
                                 "discriminator_train/negative/*",
                                 "discriminator_test/positive/*",
                                 "discriminator_test/negative/*")
        sess_dis, saver_dis, train_dir_dis = setup_training_discriminator(
            model_dis)

        print("Start pre-training discriminator......")
        if not os.path.exists("discriminator_result"):
            os.mkdir("discriminator_result")
        run_pre_train_discriminator(model_dis, dis_batcher, 1000, sess_dis,
                                    saver_dis, train_dir_dis)

    elif FLAGS.mode == "adversarial_train":

        generator_graph = tf.Graph()
        discriminator_graph = tf.Graph()

        print("Start adversarial-training......")
        # tf.reset_default_graph()

        with generator_graph.as_default():
            model = Generator(hps_generator, vocab)
            sess_ge, saver_ge, train_dir_ge = setup_training_generator(model)
            generated = Generated_sample(model, vocab, batcher, sess_ge)

            util.load_ckpt(saver_ge, sess_ge, ckpt_dir="train-generator")
            print("finish load train-generator")
        with discriminator_graph.as_default():
            model_dis = Discriminator(hps_discriminator, vocab)
            dis_batcher = DisBatcher(hps_discriminator, vocab,
                                     "discriminator_train/positive/*",
                                     "discriminator_train/negative/*",
                                     "discriminator_test/positive/*",
                                     "discriminator_test/negative/*")
            sess_dis, saver_dis, train_dir_dis = setup_training_discriminator(
                model_dis)

            util.load_ckpt(saver_dis, sess_dis, ckpt_dir="train-discriminator")
            print("finish load train-discriminator")

        print("Start adversarial  training......")
        if not os.path.exists("train_sample_generated"):
            os.mkdir("train_sample_generated")
        if not os.path.exists("test_max_generated"):
            os.mkdir("test_max_generated")
        if not os.path.exists("test_sample_generated"):
            os.mkdir("test_sample_generated")

        whole_decay = False

        for epoch in range(100):
            print('Start training')
            batches = batcher.get_batches(mode='train')
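            # the generator is trained on fixed chunks of 14 batches per adversarial step (hard-coded below)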
            for step in range(int(len(batches) / 14)):

                run_train_generator(model, model_dis, sess_dis, batcher,
                                    dis_batcher,
                                    batches[step * 14:(step + 1) * 14],
                                    sess_ge, saver_ge, train_dir_ge)
                generated.generator_sample_example(
                    "train_sample_generated/" + str(epoch) + "epoch_step" +
                    str(step) + "_temp_positive", "train_sample_generated/" +
                    str(epoch) + "epoch_step" + str(step) + "_temp_negative",
                    14)

                tf.logging.info("test performance: ")
                tf.logging.info("epoch: " + str(epoch) + " step: " + str(step))

                #                print("evaluate the diversity of DP-GAN (decode based on  max probability)")
                #                generated.generator_test_sample_example(
                #                    "test_sample_generated/" +
                #                    str(epoch) + "epoch_step" + str(step) + "_temp_positive",
                #                    "test_sample_generated/" + str(epoch) + "epoch_step" + str(step) + "_temp_negative", 14)
                #
                #                print("evaluate the diversity of DP-GAN (decode based on sampling)")
                #                generated.generator_test_max_example(
                #                    "test_max_generated/" +
                #                    str(epoch) + "epoch_step" + str(step) + "_temp_positive",
                #                    "test_max_generated/" + str(epoch) + "epoch_step" + str(step) + "_temp_negative", 14)

                dis_batcher.train_queue = []
                for i in range(epoch + 1):
                    for j in range(step + 1):
                        dis_batcher.train_queue += dis_batcher.fill_example_queue(
                            "train_sample_generated/" + str(i) + "epoch_step" +
                            str(j) + "_temp_positive/*")
                        dis_batcher.train_queue += dis_batcher.fill_example_queue(
                            "train_sample_generated/" + str(i) + "epoch_step" +
                            str(j) + "_temp_negative/*")
                dis_batcher.train_batch = dis_batcher.create_batches(
                    mode="train", shuffleis=True)
                whole_decay = run_train_discriminator(
                    model_dis, 5, dis_batcher,
                    dis_batcher.get_batches(mode="train"), sess_dis, saver_dis,
                    train_dir_dis, whole_decay)
    elif FLAGS.mode == "test_language_model":
        ge_model = Generator(hps_generator, vocab)
        sess_ge, saver_ge, train_dir_ge = setup_training_generator(ge_model)
        util.load_ckpt(saver_ge, sess_ge, ckpt_dir="train-generator")
        print("finish load train-generator")

        #        generator_graph = tf.Graph()
        #        with generator_graph.as_default():
        #            util.load_ckpt(saver_ge, sess_ge, ckpt_dir="train-generator")
        #            print("finish load train-generator")

        # jieba.load_userdict('dir.txt')
        inputs = ''
        while inputs != "close":

            inputs = input("Enter your ask: ")
            sentence = segmentor.segment(t2s.convert(inputs))
            #            sentence = jieba.cut(inputs)
            sentence = (" ".join(sentence))
            sentence = s2t.convert(sentence)
            print(sentence)
            sentence = sentence.split()

            enc_input = [vocab.word2id(w) for w in sentence]
            enc_lens = np.array([len(enc_input)])
            enc_input = np.array([enc_input])
            out_sentence = ('[START]').split()
            dec_batch = [vocab.word2id(w) for w in out_sentence]
            #dec_batch = [2] + dec_batch
            # dec_batch.append(3)
            while len(dec_batch) < 40:
                dec_batch.append(1)

            dec_batch = np.array([dec_batch])
            dec_batch = np.resize(dec_batch, (1, 1, 40))
            dec_lens = np.array([len(dec_batch)])
            if (FLAGS.beamsearch == 'beamsearch_train'):
                result = ge_model.run_test_language_model(
                    sess_ge, enc_input, enc_lens, dec_batch, dec_lens)
                #                print(result['generated'])
                #                print(result['generated'].shape)
                output_ids = result['generated'][0]
                decoded_words = data.outputids2words(output_ids, vocab, None)
                print("decoded_words :", decoded_words)
            else:
                results = ge_model.run_test_beamsearch_example(
                    sess_ge, enc_input, enc_lens, dec_batch, dec_lens)
                beamsearch_outputs = results['beamsearch_outputs']
                for i in range(5):
                    predict_list = np.ndarray.tolist(beamsearch_outputs[:, :,
                                                                        i])
                    predict_list = predict_list[0]
                    predict_seq = [vocab.id2word(idx) for idx in predict_list]
                    decoded_words = " ".join(predict_seq).split()

                    try:
                        if decoded_words[0] == '[STOPDOC]':
                            decoded_words = decoded_words[1:]
                        # index of the (first) [STOP] symbol
                        fst_stop_idx = decoded_words.index(data.STOP_DECODING)
                        decoded_words = decoded_words[:fst_stop_idx]
                    except ValueError:
                        pass  # no [STOP] token was produced; keep the full sequence

                    if decoded_words[-1] != '.' and decoded_words[
                            -1] != '!' and decoded_words[-1] != '?':
                        decoded_words.append('.')
                    decoded_words_all = []
                    decoded_output = ' '.join(
                        decoded_words).strip()  # single string
                    decoded_words_all.append(decoded_output)
                    decoded_words_all = ' '.join(decoded_words_all).strip()
                    decoded_words_all = decoded_words_all.replace("[UNK] ", "")
                    decoded_words_all = decoded_words_all.replace("[UNK]", "")
                    decoded_words_all = decoded_words_all.replace(" ", "")
                    decoded_words_all, _ = re.subn(r"(! ){2,}", "",
                                                   decoded_words_all)
                    decoded_words_all, _ = re.subn(r"(\. ){2,}", "",
                                                   decoded_words_all)
                    if decoded_words_all.startswith(','):
                        decoded_words_all = decoded_words_all[1:]
                    print("The resonse   : {}".format(decoded_words_all))
Example No. 3
def main(argv):
    tf.set_random_seed(111)  # a seed value for randomness

    # Create a batcher object that will create minibatches of data
    # TODO change to pass number

    # --------------- building graph ---------------
    hparam_gen = [
        'mode',
        'model_dir',
        'adagrad_init_acc',
        'steps_per_checkpoint',
        'batch_size',
        'beam_size',
        'cov_loss_wt',
        'coverage',
        'emb_dim',
        'rand_unif_init_mag',
        'gen_vocab_file',
        'gen_vocab_size',
        'hidden_dim',
        'gen_lr',
        'gen_max_gradient',
        'max_dec_steps',
        'max_enc_steps',
        'min_dec_steps',
        'trunc_norm_init_std',
        'single_pass',
        'log_root',
        'data_path',
    ]

    hps_dict = {}
    for key, val in FLAGS.__flags.iteritems():  # for each flag
        if key in hparam_gen:  # if it's in the list
            hps_dict[key] = val  # add it to the dict

    hps_gen = namedtuple("HParams4Gen", hps_dict.keys())(**hps_dict)

    print("Building vocabulary for generator ...")
    gen_vocab = Vocab(join_path(hps_gen.data_path, hps_gen.gen_vocab_file),
                      hps_gen.gen_vocab_size)

    hparam_dis = [
        'mode',
        'vocab_type',
        'model_dir',
        'dis_vocab_size',
        'steps_per_checkpoint',
        'learning_rate_decay_factor',
        'dis_vocab_file',
        'num_class',
        'layer_size',
        'conv_layers',
        'max_steps',
        'kernel_size',
        'early_stop',
        'pool_size',
        'pool_layers',
        'dis_max_gradient',
        'batch_size',
        'dis_lr',
        'lr_decay_factor',
        'cell_type',
        'max_enc_steps',
        'max_dec_steps',
        'single_pass',
        'data_path',
        'num_models',
    ]
    hps_dict = {}
    for key, val in FLAGS.__flags.iteritems():  # for each flag
        if key in hparam_dis:  # if it's in the list
            hps_dict[key] = val  # add it to the dict

    hps_dis = namedtuple("HParams4Dis", hps_dict.keys())(**hps_dict)
    if hps_gen.gen_vocab_file == hps_dis.dis_vocab_file:
        hps_dis = hps_dis._replace(vocab_type="word")
        hps_dis = hps_dis._replace(layer_size=hps_gen.emb_dim)
        hps_dis = hps_dis._replace(dis_vocab_size=hps_gen.gen_vocab_size)
    else:
        hps_dis = hps_dis._replace(max_enc_steps=hps_dis.max_enc_steps * 2)
        hps_dis = hps_dis._replace(max_dec_steps=hps_dis.max_dec_steps * 2)
    if FLAGS.mode == "train_gan":
        hps_gen = hps_gen._replace(batch_size=hps_gen.batch_size *
                                   hps_dis.num_models)

    if FLAGS.mode != "pretrain_dis":
        with tf.variable_scope("generator"):
            generator = PointerGenerator(hps_gen, gen_vocab)
            print("Building generator graph ...")
            gen_decoder_scope = generator.build_graph()

    if FLAGS.mode != "pretrain_gen":
        print("Building vocabulary for discriminator ...")
        dis_vocab = Vocab(join_path(hps_dis.data_path, hps_dis.dis_vocab_file),
                          hps_dis.dis_vocab_size)
    if FLAGS.mode in ['train_gan', 'pretrain_dis']:
        with tf.variable_scope("discriminator"), tf.device("/gpu:0"):
            discriminator = Seq2ClassModel(hps_dis)
            print("Building discriminator graph ...")
            discriminator.build_graph()

    hparam_gan = [
        'mode',
        'model_dir',
        'gan_iter',
        'gan_gen_iter',
        'gan_dis_iter',
        'gan_lr',
        'rollout_num',
        'sample_num',
    ]
    hps_dict = {}
    for key, val in FLAGS.__flags.iteritems():  # for each flag
        if key in hparam_gan:  # if it's in the list
            hps_dict[key] = val  # add it to the dict

    hps_gan = namedtuple("HParams4GAN", hps_dict.keys())(**hps_dict)
    hps_gan = hps_gan._replace(mode="train_gan")
    if FLAGS.mode == 'train_gan':
        with tf.device("/gpu:0"):
            print("Creating rollout...")
            rollout = Rollout(generator, 0.8, gen_decoder_scope)

    # --------------- initializing variables ---------------
    all_variables = tf.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES) + \
        tf.get_collection_ref(tf.GraphKeys.WEIGHTS) + \
        tf.get_collection_ref(tf.GraphKeys.BIASES)
    sess = tf.Session(config=utils.get_config())
    sess.run(tf.variables_initializer(all_variables))
    if FLAGS.mode == "pretrain_gen":
        val_dir = ensure_exists(
            join_path(FLAGS.model_dir, 'generator', FLAGS.val_dir))
        model_dir = ensure_exists(join_path(FLAGS.model_dir, 'generator'))
        print("Restoring the generator model from the latest checkpoint...")
        gen_saver = tf.train.Saver(
            max_to_keep=3,
            var_list=[
                v for v in all_variables
                if "generator" in v.name and "GAN" not in v.name
            ])
        gen_dir = ensure_exists(join_path(FLAGS.model_dir, "generator"))
        # gen_dir = ensure_exists(FLAGS.model_dir)
        # temp_saver = tf.train.Saver(
        #     var_list=[v for v in all_variables if "generator" in v.name and "Adagrad" not in v.name])
        ckpt_path = utils.load_ckpt(gen_saver, sess, gen_dir)
        print('going to restore embeddings from checkpoint')
        if not ckpt_path:
            emb_path = join_path(FLAGS.model_dir, "generator", "init_embed")
            if os.path.exists(emb_path):  # the joined path string is always truthy; test the filesystem instead
                generator.saver.restore(
                    sess,
                    tf.train.get_checkpoint_state(
                        emb_path).model_checkpoint_path)
                print(
                    colored(
                        "successfully restored embeddings form %s" % emb_path,
                        'green'))
            else:
                print(
                    colored("failed to restore embeddings form %s" % emb_path,
                            'red'))

    elif FLAGS.mode in ["decode", "train_gan"]:
        print("Restoring the generator model from the best checkpoint...")
        dec_saver = tf.train.Saver(
            max_to_keep=3,
            var_list=[v for v in all_variables if "generator" in v.name])
        gan_dir = ensure_exists(
            join_path(FLAGS.model_dir, 'generator', FLAGS.gan_dir))
        gan_val_dir = ensure_exists(
            join_path(FLAGS.model_dir, 'generator', FLAGS.gan_dir,
                      FLAGS.val_dir))
        gan_saver = tf.train.Saver(
            max_to_keep=3,
            var_list=[v for v in all_variables if "generator" in v.name])
        gan_val_saver = tf.train.Saver(
            max_to_keep=3,
            var_list=[v for v in all_variables if "generator" in v.name])
        utils.load_ckpt(dec_saver, sess, val_dir,
                        (FLAGS.mode in ["train_gan", "decode"]))

    if FLAGS.mode in ["pretrain_dis", "train_gan"]:
        dis_saver = tf.train.Saver(
            max_to_keep=3,
            var_list=[v for v in all_variables if "discriminator" in v.name])
        dis_dir = ensure_exists(join_path(FLAGS.model_dir, 'discriminator'))
        ckpt = utils.load_ckpt(dis_saver, sess, dis_dir)
        if not ckpt:
            if hps_dis.vocab_type == "word":
                discriminator.init_emb(
                    sess, join_path(FLAGS.model_dir, "generator",
                                    "init_embed"))
            else:
                discriminator.init_emb(
                    sess,
                    join_path(FLAGS.model_dir, "discriminator", "init_embed"))

    # --------------- train models ---------------
    if FLAGS.mode != "pretrain_dis":
        gen_batcher_train = GenBatcher("train",
                                       gen_vocab,
                                       hps_gen,
                                       single_pass=hps_gen.single_pass)
        decoder = Decoder(sess, generator, gen_vocab)
        gen_batcher_val = GenBatcher("val",
                                     gen_vocab,
                                     hps_gen,
                                     single_pass=True)
        val_saver = tf.train.Saver(
            max_to_keep=10,
            var_list=[
                v for v in all_variables
                if "generator" in v.name and "GAN" not in v.name
            ])

    if FLAGS.mode != "pretrain_gen":
        dis_val_batch_size = hps_dis.batch_size * hps_dis.num_models \
            if hps_dis.mode == "train_gan" else hps_dis.batch_size * hps_dis.num_models * 2
        dis_batcher_val = DisBatcher(
            hps_dis.data_path,
            "eval",
            gen_vocab,
            dis_vocab,
            dis_val_batch_size,
            single_pass=True,
            max_art_steps=hps_dis.max_enc_steps,
            max_abs_steps=hps_dis.max_dec_steps,
        )

    if FLAGS.mode == "pretrain_gen":
        # reload the latest checkpoint (if any) and pretrain the generator
        print('Going to pretrain the generator')
        try:
            pretrain_generator(generator, gen_batcher_train, sess,
                               gen_batcher_val, gen_saver, model_dir,
                               val_saver, val_dir)
        except KeyboardInterrupt:
            tf.logging.info("Caught keyboard interrupt on worker....")

    elif FLAGS.mode == "pretrain_dis":
        print('Going to pretrain the discriminator')
        dis_batcher = DisBatcher(
            hps_dis.data_path,
            "decode",
            gen_vocab,
            dis_vocab,
            hps_dis.batch_size * hps_dis.num_models,
            single_pass=hps_dis.single_pass,
            max_art_steps=hps_dis.max_enc_steps,
            max_abs_steps=hps_dis.max_dec_steps,
        )
        try:
            pretrain_discriminator(sess, discriminator, dis_batcher_val,
                                   dis_vocab, dis_batcher, dis_saver)
        except KeyboardInterrupt:
            tf.logging.info("Caught keyboard interrupt on worker....")

    elif FLAGS.mode == "train_gan":
        gen_best_loss = get_best_loss_from_chpt(val_dir)
        gen_global_step = 0
        print('Going to tune the two models with GAN')
        for i_gan in range(hps_gan.gan_iter):
            # Train the generator for one step
            g_losses = []
            current_speed = []
            for it in range(hps_gan.gan_gen_iter):
                start_time = time.time()
                batch = gen_batcher_train.next_batch()

                # generate samples
                enc_states, dec_in_state, n_samples, n_targets_padding_mask = decoder.mc_generate(
                    batch, include_start_token=True, s_num=hps_gan.sample_num)
                # get rewards for the samples
                n_rewards = rollout.get_reward(sess, gen_vocab, dis_vocab,
                                               batch, enc_states, dec_in_state,
                                               n_samples, hps_gan.rollout_num,
                                               discriminator)

                # fine tune the generator
                n_sample_targets = [samples[:, 1:] for samples in n_samples]
                n_targets_padding_mask = [
                    padding_mask[:, 1:]
                    for padding_mask in n_targets_padding_mask
                ]
                n_samples = [samples[:, :-1] for samples in n_samples]
                # sample_target_padding_mask = pad_sample(sample_target, gen_vocab, hps_gen)
                n_samples = [
                    np.where(
                        np.less(samples, hps_gen.gen_vocab_size), samples,
                        np.array([[gen_vocab.word2id(data.UNKNOWN_TOKEN)] *
                                  hps_gen.max_dec_steps] * hps_gen.batch_size))
                    for samples in n_samples
                ]
                results = generator.run_gan_batch(sess, batch, n_samples,
                                                  n_sample_targets,
                                                  n_targets_padding_mask,
                                                  n_rewards)

                gen_global_step = results["global_step"]

                # for visualization
                g_loss = results["loss"]
                if not math.isnan(g_loss):
                    g_losses.append(g_loss)
                else:
                    print(colored('a nan in gan loss', 'red'))
                current_speed.append(time.time() - start_time)

            # Test
            # if FLAGS.gan_gen_iter and (i_gan % 100 == 0 or i_gan == hps_gan.gan_iter - 1):
            if i_gan % 100 == 0 or i_gan == hps_gan.gan_iter - 1:
                print('Going to test the generator.')
                current_speed = sum(current_speed) / (len(current_speed) *
                                                      hps_gen.batch_size)
                average_g_loss = sum(g_losses) / len(g_losses)
                # one more process should be opened for the evaluation
                eval_loss, gen_best_loss = save_ckpt(
                    sess, generator, gen_best_loss, gan_dir, gan_saver,
                    gen_batcher_val, gan_val_dir, gan_val_saver,
                    gen_global_step)

                if eval_loss:
                    print("\nDashboard for " +
                          colored("GAN Generator", 'green') + " updated %s, "
                          "finished steps:\t%s\n"
                          "\tBatch size:\t%s\n"
                          "\tVocabulary size:\t%s\n"
                          "\tCurrent speed:\t%.4f seconds/article\n"
                          "\tAverage training loss:\t%.4f; "
                          "eval loss:\t%.4f" % (
                              datetime.datetime.now().strftime(
                                  "on %m-%d at %H:%M"),
                              gen_global_step,
                              FLAGS.batch_size,
                              hps_gen.gen_vocab_size,
                              current_speed,
                              average_g_loss.item(),
                              eval_loss.item(),
                          ))

            # Train the discriminator
            print('Going to train the discriminator.')
            dis_best_loss = 1000
            dis_losses = []
            dis_accuracies = []
            for d_gan in range(hps_gan.gan_dis_iter):
                batch = gen_batcher_train.next_batch()
                enc_states, dec_in_state, k_samples_words, _ = decoder.mc_generate(
                    batch, s_num=hps_gan.sample_num)
                # should first translate ids to words to avoid UNK
                articles_oovs = batch.art_oovs
                for samples_words in k_samples_words:
                    dec_batch_words = batch.target_batch
                    conditions_words = batch.enc_batch_extend_vocab
                    if hps_dis.vocab_type == "char":
                        samples = gen_vocab2dis_vocab(samples_words, gen_vocab,
                                                      articles_oovs, dis_vocab,
                                                      hps_dis.max_dec_steps,
                                                      STOP_DECODING)
                        dec_batch = gen_vocab2dis_vocab(
                            dec_batch_words, gen_vocab, articles_oovs,
                            dis_vocab, hps_dis.max_dec_steps, STOP_DECODING)
                        conditions = gen_vocab2dis_vocab(
                            conditions_words, gen_vocab, articles_oovs,
                            dis_vocab, hps_dis.max_enc_steps, PAD_TOKEN)
                    else:
                        samples = samples_words
                        dec_batch = dec_batch_words
                        conditions = conditions_words
                        # the unknown in target

                    inputs = np.concatenate([samples, dec_batch], 0)
                    conditions = np.concatenate([conditions, conditions], 0)

                    targets = [[1, 0] for _ in samples] + [[0, 1]
                                                           for _ in dec_batch]
                    targets = np.array(targets)
                    # randomize the samples
                    assert len(inputs) == len(conditions) == len(
                        targets
                    ), "lengths of the inputs, conditions and targets should be the same."
                    indices = np.random.permutation(len(inputs))
                    inputs = np.split(inputs[indices], 2)
                    conditions = np.split(conditions[indices], 2)
                    targets = np.split(targets[indices], 2)
                    assert len(inputs) % 2 == 0, "the length should be mean"

                    results = discriminator.run_one_batch(
                        sess, inputs[0], conditions[0], targets[0])
                    dis_accuracies.append(results["accuracy"].item())
                    dis_losses.append(results["loss"].item())

                    results = discriminator.run_one_batch(
                        sess, inputs[1], conditions[1], targets[1])
                    dis_accuracies.append(results["accuracy"].item())

                ave_dis_acc = sum(dis_accuracies) / len(dis_accuracies)
                if d_gan == hps_gan.gan_dis_iter - 1:
                    if (sum(dis_losses) / len(dis_losses)) < dis_best_loss:
                        dis_best_loss = sum(dis_losses) / len(dis_losses)
                        checkpoint_path = ensure_exists(
                            join_path(hps_dis.model_dir,
                                      "discriminator")) + "/model.ckpt"
                        dis_saver.save(sess,
                                       checkpoint_path,
                                       global_step=results["global_step"])
                    print_dashboard("GAN Discriminator",
                                    results["global_step"].item(),
                                    hps_dis.batch_size, hps_dis.dis_vocab_size,
                                    results["loss"].item(), 0.00, 0.00, 0.00)
                    print("Average training accuracy: \t%.4f" % ave_dis_acc)

                if ave_dis_acc > 0.9:
                    break

    # --------------- decoding samples ---------------
    elif FLAGS.mode == "decode":
        print('Going to decode from the generator.')
        decoder.bs_decode(gen_batcher_train)
        print("Finished decoding..")
        # decode for generating corpus for discriminator

    sess.close()
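
run_gan_batch is not shown in this listing; in SeqGAN-style training it typically minimizes a REINFORCE policy-gradient loss, weighting each sampled token's negative log-probability by the rollout reward. A minimal sketch of that objective follows; the names and shapes are assumptions, not this repository's API:

import tensorflow as tf

def policy_gradient_loss(logits, samples, rewards, padding_mask):
    """REINFORCE loss for sampled sequences (hypothetical helper).

    logits:       [batch, steps, vocab] decoder outputs
    samples:      [batch, steps] sampled token ids (the actions taken)
    rewards:      [batch, steps] rollout rewards per position
    padding_mask: [batch, steps] 1.0 for real tokens, 0.0 for padding
    """
    # sparse cross-entropy equals -log p(sample), so weighting it by the
    # reward and minimizing maximizes the expected reward
    neg_log_prob = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=samples, logits=logits)
    weighted = neg_log_prob * rewards * padding_mask
    return tf.reduce_sum(weighted) / tf.reduce_sum(padding_mask)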
Example No. 4
class BeamSearch(object):
    def __init__(self, model, config, step):
        self.config = config
        self.model = model.to(device)

        self._decode_dir = os.path.join(config.log_root,
                                        'decode_S%s' % str(step))
        self._rouge_ref = os.path.join(self._decode_dir, 'rouge_ref')
        self._rouge_dec = os.path.join(self._decode_dir, 'rouge_dec')

        if not os.path.exists(self._decode_dir): os.mkdir(self._decode_dir)

        self.vocab = Vocab(config.vocab_file, config.vocab_size)
        self.test_data = CNNDMDataset('test', config.data_path, config,
                                      self.vocab)

    def sort_beams(self, beams):
        return sorted(beams, key=lambda h: h.avg_log_prob, reverse=True)

    @staticmethod
    def report_rouge(ref_path, dec_path):
        print("Now starting ROUGE eval...")
        files_rouge = FilesRouge(dec_path, ref_path)
        scores = files_rouge.get_scores(avg=True)
        logging(str(scores))

    def get_summary(self, best_summary, batch):
        # Extract the output ids from the hypothesis and convert back to words
        output_ids = [int(t) for t in best_summary.tokens[1:]]
        decoded_words = output2words(
            output_ids, self.vocab,
            (batch.art_oovs[0] if self.config.pointer_gen else None))

        # Remove the [STOP] token from decoded_words, if necessary
        try:
            fst_stop_idx = decoded_words.index('<end>')
            decoded_words = decoded_words[:fst_stop_idx]
        except ValueError:
            pass  # no <end> token was produced; keep the full sequence
        decoded_abstract = ' '.join(decoded_words)
        return decoded_abstract

    def decode(self):
        config = self.config
        start = time.time()
        counter = 0
        test_loader = DataLoader(
            self.test_data,
            batch_size=1,
            shuffle=False,
            collate_fn=Collate(beam_size=config.beam_size))

        ref = open(self._rouge_ref, 'w')
        dec = open(self._rouge_dec, 'w')

        for batch in test_loader:
            # Run beam search to get best Hypothesis
            best_summary = self.beam_search(batch)

            original_abstract = batch.original_abstract[0]
            decoded_abstract = self.get_summary(best_summary, batch)

            ref.write(original_abstract + '\n')
            dec.write(decoded_abstract + '\n')

            counter += 1
            if counter % 1000 == 0:
                print('%d examples in %d sec' % (counter, time.time() - start))
                start = time.time()

        print("Decoder has finished reading dataset for single_pass.")
        ref.close()
        dec.close()
        self.report_rouge(self._rouge_ref, self._rouge_dec)

    def beam_search(self, batch):
        config = self.config
        # batch should have only one example
        enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, c_t_0, coverage_t_0 = \
            get_input_from_batch(batch, config, device)

        encoder_outputs, encoder_feature, encoder_hidden = self.model.encoder(
            enc_batch, enc_lens)
        s_t_0 = self.model.reduce_state(encoder_hidden)

        dec_h, dec_c = s_t_0  # 1 x 2*hidden_size

        dec_h = dec_h.squeeze()
        dec_c = dec_c.squeeze()

        # decoder batch preparation: the batch initially holds beam_size copies of the same example
        beams = [
            Beam(tokens=[self.vocab.word2id('<start>')],
                 log_probs=[0.0],
                 state=(dec_h[0], dec_c[0]),
                 context=c_t_0[0],
                 coverage=(coverage_t_0[0] if config.is_coverage else None))
            for _ in range(config.beam_size)
        ]
        results = []
        steps = 0
        while steps < config.max_dec_steps and len(results) < config.beam_size:
            latest_tokens = [h.latest_token for h in beams]
            latest_tokens = [t if t < self.vocab.size() else self.vocab.word2id('<unk>') \
                             for t in latest_tokens]
            y_t_1 = Variable(torch.tensor(latest_tokens)).to(device)

            all_state_h = []
            all_state_c = []

            all_context = []

            for h in beams:
                state_h, state_c = h.state
                all_state_h.append(state_h)
                all_state_c.append(state_c)
                all_context.append(h.context)

            s_t_1 = (torch.stack(all_state_h,
                                 0).unsqueeze(0), torch.stack(all_state_c,
                                                              0).unsqueeze(0))
            c_t_1 = torch.stack(all_context, 0)

            coverage_t_1 = None
            if config.is_coverage:
                all_coverage = []
                for h in beams:
                    all_coverage.append(h.coverage)
                coverage_t_1 = torch.stack(all_coverage, 0)

            final_dist, s_t, c_t, attn_dist, p_gen, coverage_t = self.model.decoder(
                y_t_1, s_t_1, encoder_outputs, encoder_feature,
                enc_padding_mask, c_t_1, extra_zeros, enc_batch_extend_vocab,
                coverage_t_1, steps)
            log_probs = torch.log(final_dist)
            topk_log_probs, topk_ids = torch.topk(log_probs,
                                                  config.beam_size * 2)

            dec_h, dec_c = s_t
            dec_h = dec_h.squeeze()
            dec_c = dec_c.squeeze()

            all_beams = []
            num_orig_beams = 1 if steps == 0 else len(beams)
            for i in range(num_orig_beams):
                h = beams[i]
                state_i = (dec_h[i], dec_c[i])
                context_i = c_t[i]
                coverage_i = (coverage_t[i] if config.is_coverage else None)

                for j in range(config.beam_size *
                               2):  # for each of the top 2*beam_size hyps:
                    new_beam = h.extend(token=topk_ids[i, j].item(),
                                        log_prob=topk_log_probs[i, j].item(),
                                        state=state_i,
                                        context=context_i,
                                        coverage=coverage_i)
                    all_beams.append(new_beam)

            beams = []
            for h in self.sort_beams(all_beams):
                if h.latest_token == self.vocab.word2id('<end>'):
                    if steps >= config.min_dec_steps:
                        results.append(h)
                else:
                    beams.append(h)
                if len(beams) == config.beam_size or len(
                        results) == config.beam_size:
                    break
            steps += 1

        if len(results) == 0:
            results = beams

        beams_sorted = self.sort_beams(results)

        return beams_sorted[0]
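
output2words (like data.outputids2words in the other examples) maps output ids back to words, resolving ids beyond the fixed vocabulary to the article's OOV list when pointer generation is enabled. A minimal sketch, assuming vocab.size() gives the fixed vocabulary size and vocab.id2word handles in-vocabulary ids:

def output2words(output_ids, vocab, art_oovs=None):
    words = []
    for idx in output_ids:
        if idx < vocab.size():
            words.append(vocab.id2word(idx))  # ordinary in-vocabulary id
        else:
            # extended-vocabulary id: an index into this article's OOV list
            assert art_oovs is not None, "OOV id produced without pointer_gen"
            words.append(art_oovs[idx - vocab.size()])
    return words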
Example No. 5
class BeamSearch(object):
    def __init__(self, model_file_path, data_path, data_class='val'):
        self.data_class = data_class
        if self.data_class not in ['val', 'test']:
            raise ValueError("data_class must be 'val' or 'test'.")

        # model_file_path e.g. --> ../log/{MODE NAME}/best_model/model_best_XXXXX
        model_name = os.path.basename(model_file_path)
        # log_root e.g. --> ../log/{MODE NAME}/
        log_root = os.path.dirname(os.path.dirname(model_file_path))
        # _decode_dir e.g. --> ../log/{MODE NAME}/decode_model_best_XXXXX/
        self._decode_dir = os.path.join(log_root, 'decode_%s' % (model_name))
        self._rouge_ref_dir = os.path.join(self._decode_dir, 'rouge_ref')
        self._rouge_dec_dir = os.path.join(self._decode_dir, 'rouge_dec_dir')
        self._result_path = os.path.join(self._decode_dir, 'result_%s_%s.txt' \
                                                        % (model_name, self.data_class))
        # remove the result file if it already exists
        if os.path.isfile(self._result_path):
            os.remove(self._result_path)
        for p in [self._decode_dir, self._rouge_ref_dir, self._rouge_dec_dir]:
            if not os.path.exists(p):
                os.mkdir(p)

        self.vocab = Vocab(config.vocab_path, config.vocab_size)
        self.batcher = Batcher(data_path, self.vocab, mode='decode',
                               batch_size=config.beam_size, single_pass=True)
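        # give the batcher's background threads a moment to fill the example queue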
        time.sleep(5)

        self.model = Model(model_file_path, is_eval=True)


    def sort_beams(self, beams):
        return sorted(beams, key=lambda h: h.avg_log_prob, reverse=True)


    def beam_search(self, batch):
        # batch should have only one example
        enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, c_t_0, coverage_t_0 = \
            get_input_from_batch(batch, use_cuda)

        encoder_outputs, encoder_feature, encoder_hidden = self.model.encoder(enc_batch, enc_lens)
        s_t_0 = self.model.reduce_state(encoder_hidden)

        dec_h, dec_c = s_t_0 # 1 x 2H
        dec_h = dec_h.squeeze()
        dec_c = dec_c.squeeze()

        # decoder batch preparation: the batch initially holds beam_size copies of the same example
        beams = [Beam(tokens=[self.vocab.word2id(data.START_DECODING)],
                      log_probs=[0.0],
                      state=(dec_h[0], dec_c[0]),
                      context=c_t_0[0],
                      coverage=(coverage_t_0[0] if config.is_coverage else None))
                 for _ in range(config.beam_size)]
        results = []
        steps = 0
        while steps < config.max_dec_steps and len(results) < config.beam_size:
            latest_tokens = [h.latest_token for h in beams]
            latest_tokens = [t if t < self.vocab.size() else self.vocab.word2id(data.UNKNOWN_TOKEN) \
                             for t in latest_tokens]

            y_t_1 = Variable(torch.LongTensor(latest_tokens))
            if use_cuda:
                y_t_1 = y_t_1.cuda()
            all_state_h = []
            all_state_c = []
            all_context = []

            for h in beams:
                state_h, state_c = h.state
                all_state_h.append(state_h)
                all_state_c.append(state_c)
                all_context.append(h.context)

            s_t_1 = (torch.stack(all_state_h, 0).unsqueeze(0), torch.stack(all_state_c, 0).unsqueeze(0))
            c_t_1 = torch.stack(all_context, 0)

            coverage_t_1 = None
            if config.is_coverage:
                all_coverage = []
                for h in beams:
                    all_coverage.append(h.coverage)
                coverage_t_1 = torch.stack(all_coverage, 0)

            final_dist, s_t, c_t, attn_dist, p_gen, coverage_t = self.model.decoder(y_t_1, s_t_1,
                                                        encoder_outputs, encoder_feature, enc_padding_mask, c_t_1,
                                                        extra_zeros, enc_batch_extend_vocab, coverage_t_1, steps)

            log_probs = torch.log(final_dist)
            topk_log_probs, topk_ids = torch.topk(log_probs, config.beam_size * 2)

            dec_h, dec_c = s_t
            dec_h = dec_h.squeeze()
            dec_c = dec_c.squeeze()

            all_beams = []
            num_orig_beams = 1 if steps == 0 else len(beams)
            for i in range(num_orig_beams):
                h = beams[i]
                state_i = (dec_h[i], dec_c[i])
                context_i = c_t[i]
                coverage_i = (coverage_t[i] if config.is_coverage else None)

                for j in range(config.beam_size * 2):  # for each of the top 2*beam_size hyps:
                    new_beam = h.extend(token=topk_ids[i, j].item(),
                                        log_prob=topk_log_probs[i, j].item(),
                                        state=state_i,
                                        context=context_i,
                                        coverage=coverage_i)
                    all_beams.append(new_beam)

            beams = []
            for h in self.sort_beams(all_beams):
                if h.latest_token == self.vocab.word2id(data.STOP_DECODING):
                    if steps >= config.min_dec_steps:
                        results.append(h)
                else:
                    beams.append(h)
                if len(beams) == config.beam_size or len(results) == config.beam_size:
                    break

            steps += 1

        if len(results) == 0:
            results = beams

        beams_sorted = self.sort_beams(results)
        return beams_sorted[0]
    
    
    def decode(self):
        start = time.time()
        counter = 0
        bleu_scores = []
        batch = self.batcher.next_batch()
        while batch is not None:
            # Run beam search to get best Hypothesis
            best_summary = self.beam_search(batch)

            # Extract the output ids from the hypothesis and convert back to words
            output_ids = [int(t) for t in best_summary.tokens[1:]]
            decoded_words = data.outputids2words(output_ids, self.vocab,
                                                 (batch.art_oovs[0] if config.pointer_gen else None))

            # Remove the [STOP] token from decoded_words, if necessary
            try:
                fst_stop_idx = decoded_words.index(data.STOP_DECODING)
                decoded_words = decoded_words[:fst_stop_idx]
            except ValueError:
                pass  # no [STOP] token was produced; keep the full sequence

            original_articles = batch.original_articles[0]
            original_abstracts = batch.original_abstracts_sents[0]
            reference = original_abstracts[0].strip().split()
            bleu = nltk.translate.bleu_score.sentence_bleu([reference], decoded_words, weights=(0.5, 0.5))
            bleu_scores.append(bleu)

            # write_for_rouge(original_abstracts, decoded_words, counter,
            #                 self._rouge_ref_dir, self._rouge_dec_dir)

            write_for_result(original_articles, original_abstracts, decoded_words, \
                                                self._result_path, self.data_class)

            counter += 1
            if counter % 1000 == 0:
                print('%d examples in %d sec' % (counter, time.time() - start))
                start = time.time()

            batch = self.batcher.next_batch()
        
        '''
        # uncomment this if you successfully install `pyrouge`
        print("Decoder has finished reading dataset for single_pass.")
        print("Now starting ROUGE eval...")
        results_dict = rouge_eval(self._rouge_ref_dir, self._rouge_dec_dir)
        rouge_log(results_dict, self._decode_dir)
        '''

        if self.data_class == 'val':
            print('Average BLEU score:', np.mean(bleu_scores))
            with open(self._result_path, "a") as f:
                print('Average BLEU score:', np.mean(bleu_scores), file=f)

    def get_processed_path(self):
        # ../log/{MODE NAME}/decode_model_best_XXXXX/result_model_best_2800_{data_class}.txt
        input_path = self._result_path
        temp = os.path.splitext(input_path)
        # ../log/{MODE NAME}/decode_model_best_XXXXX/result_model_best_2800_{data_class}_processed.txt
        output_path = temp[0] + "_processed" + temp[1]
        return input_path, output_path
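
A usage note on the BLEU evaluation in decode: with bigram-only weights (0.5, 0.5) and short decoded sequences, sentence_bleu can hit zero n-gram counts and emit warnings. NLTK's smoothing functions avoid this; a hedged variant of the scoring line:

from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction

smoothie = SmoothingFunction().method1  # add-epsilon smoothing for zero counts
bleu = sentence_bleu([reference], decoded_words,
                     weights=(0.5, 0.5),
                     smoothing_function=smoothie)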