Example #1
0
def main(_):
    """Train an NP2P model, checkpointing whenever the dev score improves.

    All settings are read from the module-level ``FLAGS`` namespace. The
    function:

    1. creates ``FLAGS.model_dir`` and opens ``NP2P.<suffix>.log`` inside it,
    2. saves the configuration to ``NP2P.<suffix>.config.json``,
    3. loads the train and test sets with the reader selected by
       ``FLAGS.infile_format``,
    4. loads the vocabularies from disk when ``NP2P.<suffix>.best.model``
       already exists, otherwise builds them from the training set and dumps
       them next to the model,
    5. builds a training graph and a validation graph under the same
       ``"Model"`` variable scope (the latter with ``reuse=True``),
    6. restores the existing checkpoint, if any, and
    7. runs ``train_size * FLAGS.max_epochs`` training steps, evaluating on
       the dev stream every ``get_num_batch()`` steps and at the last step.
       The weights and the configuration are saved whenever the dev accuracy
       (``ce_train``) or dev BLEU (``rl_train``) beats the best value so far.

    Args:
        _: Unused positional argument (presumably the argv list handed over
            by ``tf.app.run`` -- TODO confirm against the caller).
    """
    print('Configurations:')
    print(FLAGS)

    log_dir = FLAGS.model_dir
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)

    path_prefix = log_dir + "/NP2P.{}".format(FLAGS.suffix)
    config_path = path_prefix + ".config.json"
    log_file_path = path_prefix + ".log"
    print('Log file path: {}'.format(log_file_path))

    def read_dataset(path):
        """Read one dataset with the reader chosen by FLAGS.infile_format."""
        if FLAGS.infile_format == 'fof':
            return NP2P_data_stream.read_generation_datasets_from_fof(
                path, isLower=FLAGS.isLower)
        if FLAGS.infile_format == 'plain':
            return NP2P_data_stream.read_all_GenerationDatasets(
                path, isLower=FLAGS.isLower)
        return NP2P_data_stream.read_all_GQA_questions(
            path, isLower=FLAGS.isLower, switch=FLAGS.switch_qa)

    # The context manager guarantees the log is closed (and its buffer
    # written out) even when loading or training raises.
    with open(log_file_path, 'wt') as log_file:
        log_file.write("{}\n".format(FLAGS))
        log_file.flush()

        # save configuration
        namespace_utils.save_namespace(FLAGS, config_path)

        print('Loading train set.')
        trainset, train_ans_len = read_dataset(FLAGS.train_path)
        print('Number of training samples: {}'.format(len(trainset)))

        print('Loading test set.')
        testset, test_ans_len = read_dataset(FLAGS.test_path)
        print('Number of test samples: {}'.format(len(testset)))

        max_actual_len = max(train_ans_len, test_ans_len)
        print('Max answer length: {}, truncated to {}'.format(
            max_actual_len, FLAGS.max_answer_len))

        word_vocab = None
        POS_vocab = None
        NER_vocab = None
        char_vocab = None
        has_pretrained_model = False
        best_path = path_prefix + ".best.model"
        if os.path.exists(best_path + ".index"):
            # A checkpoint exists: reuse the vocab files dumped by the run
            # that produced it so that indices stay aligned with the weights.
            has_pretrained_model = True
            print('!!Existing pretrained model. Loading vocabs.')
            if FLAGS.with_word:
                word_vocab = Vocab(FLAGS.word_vec_path, fileformat='txt2')
                print('word_vocab: {}'.format(word_vocab.word_vecs.shape))
            if FLAGS.with_char:
                char_vocab = Vocab(path_prefix + ".char_vocab",
                                   fileformat='txt2')
                print('char_vocab: {}'.format(char_vocab.word_vecs.shape))
            if FLAGS.with_POS:
                POS_vocab = Vocab(path_prefix + ".POS_vocab",
                                  fileformat='txt2')
                print('POS_vocab: {}'.format(POS_vocab.word_vecs.shape))
            if FLAGS.with_NER:
                NER_vocab = Vocab(path_prefix + ".NER_vocab",
                                  fileformat='txt2')
                print('NER_vocab: {}'.format(NER_vocab.word_vecs.shape))
        else:
            print('Collecting vocabs.')
            (allWords, allChars, allPOSs,
             allNERs) = NP2P_data_stream.collect_vocabs(trainset)
            print('Number of words: {}'.format(len(allWords)))
            print('Number of allChars: {}'.format(len(allChars)))
            print('Number of allPOSs: {}'.format(len(allPOSs)))
            print('Number of allNERs: {}'.format(len(allNERs)))

            if FLAGS.with_word:
                word_vocab = Vocab(FLAGS.word_vec_path, fileformat='txt2')
            if FLAGS.with_char:
                char_vocab = Vocab(voc=allChars,
                                   dim=FLAGS.char_dim,
                                   fileformat='build')
                char_vocab.dump_to_txt2(path_prefix + ".char_vocab")
            if FLAGS.with_POS:
                POS_vocab = Vocab(voc=allPOSs,
                                  dim=FLAGS.POS_dim,
                                  fileformat='build')
                POS_vocab.dump_to_txt2(path_prefix + ".POS_vocab")
            if FLAGS.with_NER:
                NER_vocab = Vocab(voc=allNERs,
                                  dim=FLAGS.NER_dim,
                                  fileformat='build')
                NER_vocab.dump_to_txt2(path_prefix + ".NER_vocab")

        # word_vocab stays None when FLAGS.with_word is off; do not crash on
        # a purely informational message in that case.
        if word_vocab is not None:
            print('word vocab size {}'.format(word_vocab.vocab_size))
        sys.stdout.flush()

        print('Build DataStream ... ')
        trainDataStream = NP2P_data_stream.QADataStream(trainset,
                                                        word_vocab,
                                                        char_vocab,
                                                        POS_vocab,
                                                        NER_vocab,
                                                        options=FLAGS,
                                                        isShuffle=True,
                                                        isLoop=True,
                                                        isSort=True)

        devDataStream = NP2P_data_stream.QADataStream(testset,
                                                      word_vocab,
                                                      char_vocab,
                                                      POS_vocab,
                                                      NER_vocab,
                                                      options=FLAGS,
                                                      isShuffle=False,
                                                      isLoop=False,
                                                      isSort=True)
        print('Number of instances in trainDataStream: {}'.format(
            trainDataStream.get_num_instance()))
        print('Number of instances in devDataStream: {}'.format(
            devDataStream.get_num_instance()))
        print('Number of batches in trainDataStream: {}'.format(
            trainDataStream.get_num_batch()))
        print('Number of batches in devDataStream: {}'.format(
            devDataStream.get_num_batch()))
        sys.stdout.flush()

        init_scale = 0.01
        # Initialize the best bleu and accu scores for the current training
        # session; they are present in FLAGS only if an earlier session
        # stored them. dict.get replaces has_key, which Python 3 removed.
        best_accu = FLAGS.__dict__.get('best_accu', 0.0)
        best_bleu = FLAGS.__dict__.get('best_bleu', 0.0)
        if best_accu > 0.0:
            print('With initial dev accuracy {}'.format(best_accu))
        if best_bleu > 0.0:
            print('With initial dev BLEU score {}'.format(best_bleu))

        with tf.Graph().as_default():
            initializer = tf.random_uniform_initializer(-init_scale,
                                                        init_scale)
            with tf.name_scope("Train"):
                with tf.variable_scope("Model",
                                       reuse=None,
                                       initializer=initializer):
                    train_graph = ModelGraph(word_vocab=word_vocab,
                                             char_vocab=char_vocab,
                                             POS_vocab=POS_vocab,
                                             NER_vocab=NER_vocab,
                                             options=FLAGS,
                                             mode=FLAGS.mode)

            assert FLAGS.mode in (
                'ce_train',
                'rl_train',
            )
            # Cross-entropy training is validated by accuracy, RL training
            # by BLEU.
            valid_mode = ('evaluate'
                          if FLAGS.mode == 'ce_train' else 'evaluate_bleu')

            # reuse=True: the validation graph is built on the variables
            # created for the training graph.
            with tf.name_scope("Valid"):
                with tf.variable_scope("Model",
                                       reuse=True,
                                       initializer=initializer):
                    valid_graph = ModelGraph(word_vocab=word_vocab,
                                             char_vocab=char_vocab,
                                             POS_vocab=POS_vocab,
                                             NER_vocab=NER_vocab,
                                             options=FLAGS,
                                             mode=valid_mode)

            initializer = tf.global_variables_initializer()

            # Checkpoint only the "Model" variables, leaving out the word
            # embeddings.
            vars_ = {}
            for var in tf.global_variables():
                if "word_embedding" in var.name:
                    continue
                if not var.name.startswith("Model"):
                    continue
                vars_[var.name.split(":")[0]] = var
            saver = tf.train.Saver(vars_)

            sess = tf.Session()
            sess.run(initializer)
            if has_pretrained_model:
                print("Restoring model from " + best_path)
                saver.restore(sess, best_path)
                print("DONE!")

                # No best score recorded for the restored weights yet:
                # measure it now so that training only overwrites the
                # checkpoint with something better.
                if FLAGS.mode == 'rl_train' and abs(best_bleu) < 0.00001:
                    print("Getting BLEU score for the model")
                    best_bleu = evaluate(sess,
                                         valid_graph,
                                         devDataStream,
                                         options=FLAGS)['dev_bleu']
                    FLAGS.best_bleu = best_bleu
                    namespace_utils.save_namespace(FLAGS, config_path)
                    print('BLEU = %.4f' % best_bleu)
                    log_file.write('BLEU = %.4f\n' % best_bleu)
                if FLAGS.mode == 'ce_train' and abs(best_accu) < 0.00001:
                    print("Getting ACCU score for the model")
                    best_accu = evaluate(sess,
                                         valid_graph,
                                         devDataStream,
                                         options=FLAGS)['dev_accu']
                    FLAGS.best_accu = best_accu
                    namespace_utils.save_namespace(FLAGS, config_path)
                    print('ACCU = %.4f' % best_accu)
                    log_file.write('ACCU = %.4f\n' % best_accu)

            print('Start the training loop.')
            train_size = trainDataStream.get_num_batch()
            max_steps = train_size * FLAGS.max_epochs
            total_loss = 0.0
            start_time = time.time()
            for step in xrange(max_steps):
                cur_batch = trainDataStream.nextBatch()
                if FLAGS.mode == 'rl_train':
                    loss_value = train_graph.run_rl_training_2(
                        sess, cur_batch, FLAGS)
                elif FLAGS.mode == 'ce_train':
                    loss_value = train_graph.run_ce_training(
                        sess, cur_batch, FLAGS)
                total_loss += loss_value

                # Lightweight progress indicator on a single line.
                if step % 100 == 0:
                    print('{} '.format(step), end="")
                    sys.stdout.flush()

                # Save a checkpoint and evaluate the model periodically.
                if (step + 1) % trainDataStream.get_num_batch() == 0 or (
                        step + 1) == max_steps:
                    print()
                    duration = time.time() - start_time
                    print('Step %d: loss = %.2f (%.3f sec)' %
                          (step, total_loss, duration))
                    log_file.write('Step %d: loss = %.2f (%.3f sec)\n' %
                                   (step, total_loss, duration))
                    log_file.flush()
                    sys.stdout.flush()
                    total_loss = 0.0

                    # Evaluate against the validation set.
                    start_time = time.time()
                    print('Validation Data Eval:')
                    res_dict = evaluate(sess,
                                        valid_graph,
                                        devDataStream,
                                        options=FLAGS,
                                        suffix=str(step))
                    if valid_graph.mode == 'evaluate':
                        dev_loss = res_dict['dev_loss']
                        dev_accu = res_dict['dev_accu']
                        dev_right = int(res_dict['dev_right'])
                        dev_total = int(res_dict['dev_total'])
                        print('Dev loss = %.4f' % dev_loss)
                        log_file.write('Dev loss = %.4f\n' % dev_loss)
                        print('Dev accu = %.4f %d/%d' %
                              (dev_accu, dev_right, dev_total))
                        log_file.write('Dev accu = %.4f %d/%d\n' %
                                       (dev_accu, dev_right, dev_total))
                        log_file.flush()
                        if best_accu < dev_accu:
                            print(
                                'Saving weights, ACCU {} (prev_best) < {} (cur)'
                                .format(best_accu, dev_accu))
                            saver.save(sess, best_path)
                            best_accu = dev_accu
                            FLAGS.best_accu = dev_accu
                            namespace_utils.save_namespace(
                                FLAGS, config_path)
                    else:
                        dev_bleu = res_dict['dev_bleu']
                        print('Dev bleu = %.4f' % dev_bleu)
                        log_file.write('Dev bleu = %.4f\n' % dev_bleu)
                        log_file.flush()
                        if best_bleu < dev_bleu:
                            print(
                                'Saving weights, BLEU {} (prev_best) < {} (cur)'
                                .format(best_bleu, dev_bleu))
                            saver.save(sess, best_path)
                            best_bleu = dev_bleu
                            FLAGS.best_bleu = dev_bleu
                            namespace_utils.save_namespace(
                                FLAGS, config_path)
                    duration = time.time() - start_time
                    print('Duration %.3f sec' % (duration))
                    sys.stdout.flush()

                    log_file.write('Duration %.3f sec\n' % (duration))
                    log_file.flush()
    if FLAGS.with_char:
        char_vocab = Vocab(model_prefix + ".char_vocab", fileformat='txt2')
        print('char_vocab: {}'.format(char_vocab.word_vecs.shape))
    if FLAGS.with_POS:
        POS_vocab = Vocab(model_prefix + ".POS_vocab", fileformat='txt2')
        print('POS_vocab: {}'.format(POS_vocab.word_vecs.shape))
    if FLAGS.with_NER:
        NER_vocab = Vocab(model_prefix + ".NER_vocab", fileformat='txt2')
        print('NER_vocab: {}'.format(NER_vocab.word_vecs.shape))

    print('Loading test set.')
    if FLAGS.infile_format == 'fof':
        testset, _ = NP2P_data_stream.read_generation_datasets_from_fof(
            in_path, isLower=FLAGS.isLower)
    elif FLAGS.infile_format == 'plain':
        testset, _ = NP2P_data_stream.read_all_GenerationDatasets(
            in_path, isLower=FLAGS.isLower)
    else:
        testset, _ = NP2P_data_stream.read_all_GQA_questions(
            in_path, isLower=FLAGS.isLower, switch=FLAGS.switch_qa)
    print('Number of samples: {}'.format(len(testset)))

    print('Build DataStream ... ')
    batch_size = -1
    if mode not in (
            'pointwise',
            'multinomial',
            'greedy',
            'greedy_evaluate',
    ):
        batch_size = 1
    devDataStream = NP2P_data_stream.DataStream(testset,
Example #3
0
    if FLAGS.with_POS:
        POS_vocab = Vocab(model_prefix + ".POS_vocab", fileformat='txt2')
        print('POS_vocab: {}'.format(POS_vocab.word_vecs.shape))
    action_vocab = Vocab(model_prefix + ".action_vocab", fileformat='txt2')
    print('action_vocab: {}'.format(action_vocab.word_vecs.shape))
    feat_vocab = Vocab(model_prefix + ".feat_vocab", fileformat='txt2')
    print('feat_vocab: {}'.format(feat_vocab.word_vecs.shape))

    print('Loading test set.')
    if use_dep:
        testset = NP2P_data_stream.read_Testset(in_path, ulfdep=args.ulf)
    elif FLAGS.infile_format == 'fof':
        testset = NP2P_data_stream.read_generation_datasets_from_fof(
            in_path, isLower=FLAGS.isLower, ulfdep=args.ulf)
    else:
        testset = NP2P_data_stream.read_all_GenerationDatasets(
            in_path, isLower=FLAGS.isLower, ulfdep=args.ulf)
    print('Number of samples: {}'.format(len(testset)))

    print('Build DataStream ... ')
    batch_size = 1
    assert batch_size == 1

    devDataStream = NP2P_data_stream.DataStream(testset,
                                                word_vocab=word_vocab,
                                                char_vocab=char_vocab,
                                                POS_vocab=POS_vocab,
                                                feat_vocab=feat_vocab,
                                                action_vocab=action_vocab,
                                                options=FLAGS,
                                                isShuffle=False,
                                                isLoop=False,
Example #4
0
            all_spans = re.split("\\s+", self.syntaxSpans)
            for i in xrange(len(all_spans)):
                cur_span = all_spans[i]
                items = re.split("-", cur_span)
                cur_start = int(items[0])
                cur_end = int(items[1])
                cur_label = items[2]
                if cur_end - cur_start >= max_chunk_len: continue
                self.chunk_starts.append(cur_start)
                self.chunk_ends.append(cur_end)
                self.chunk_labels.append(cur_label)
        return (self.chunk_starts, self.chunk_ends, self.chunk_labels)


if __name__ == "__main__":
    # Manual smoke test: load a tokenized dataset, take the second element of
    # its first entry and print every syntax chunk that survives the length
    # cut-off passed to collect_all_syntax_chunks.
    import NP2P_data_stream
    inpath = "/u/zhigwang/zhigwang1/sentence_generation/cnn-dailymail/data/val.json.tok"
    all_instances, _ = NP2P_data_stream.read_all_GenerationDatasets(
        inpath, isLower=True)
    sample_instance = all_instances[0][1]
    print('Raw text: {}'.format(sample_instance.rawText))

    chunk_starts, chunk_ends, chunk_labels = (
        sample_instance.collect_all_syntax_chunks(5))
    # The three lists are filled in lockstep, so walk them together.
    for cur_start, cur_end, cur_label in zip(chunk_starts, chunk_ends,
                                             chunk_labels):
        cur_text = sample_instance.getTokChunk(cur_start, cur_end)
        print("{}-{}-{}:{}".format(cur_start, cur_end, cur_label, cur_text))
    print("DONE!")
def main(_):
    """Train an NP2P encoder/decoder model with optional fine-tuning.

    All settings are read from the module-level ``FLAGS`` namespace. The
    function:

    1. creates ``FLAGS.model_dir`` and opens ``NP2P.<suffix>.log`` inside it,
    2. saves the configuration to ``NP2P.<suffix>.config.json``,
    3. loads the training and dev sets, plus a fine-tune set when
       ``FLAGS.finetune_path`` is non-empty,
    4. loads the encoder/decoder word vocabularies and either loads the
       character vocabulary (a checkpoint already exists) or builds it from
       the training set and dumps it next to the model,
    5. builds a training graph and a validation graph under the same
       ``"Model"`` variable scope (the latter with ``reuse=True``),
    6. restores the existing checkpoint, if any, and
    7. runs ``train_size * FLAGS.max_epochs`` training steps. At the end of
       every pass over the training batches, at the last step, and every
       2000 steps when a pass has more than 10000 batches, it calls
       ``fine_tune`` (fine-tune set present) or ``validate_and_save``; both
       return the updated ``(best_accu, best_bleu)`` pair.

    Args:
        _: Unused positional argument (presumably the argv list handed over
            by ``tf.app.run`` -- TODO confirm against the caller).
    """
    print('Configurations:')
    print(FLAGS)

    log_dir = FLAGS.model_dir
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)

    path_prefix = log_dir + "/NP2P.{}".format(FLAGS.suffix)
    config_path = path_prefix + ".config.json"
    log_file_path = path_prefix + ".log"
    print('Log file path: {}'.format(log_file_path))

    # The context manager guarantees the log is closed (and its buffer
    # written out) even when loading or training raises.
    with open(log_file_path, 'wt') as log_file:
        log_file.write("{}\n".format(FLAGS))
        log_file.flush()

        # save configuration
        namespace_utils.save_namespace(FLAGS, config_path)

        print('Loading training set.')
        trainset, train_ans_len = (
            NP2P_data_stream.read_all_GenerationDatasets(
                FLAGS.train_path, isLower=FLAGS.isLower))
        print('Number of training samples: {}'.format(len(trainset)))

        print('Loading dev set.')
        devset, dev_ans_len = NP2P_data_stream.read_all_GenerationDatasets(
            FLAGS.test_path, isLower=FLAGS.isLower)
        print('Number of dev samples: {}'.format(len(devset)))

        # NOTE(review): the switch tests FLAGS.finetune_path but the data is
        # read from FLAGS.ft_path -- confirm both flags exist and that this
        # is intended; left unchanged here.
        if FLAGS.finetune_path != "":
            print('Loading finetune set.')
            ftset, ft_ans_len = NP2P_data_stream.read_all_GenerationDatasets(
                FLAGS.ft_path, isLower=FLAGS.isLower)
            print('Number of finetune samples: {}'.format(len(ftset)))
        else:
            ftset, ft_ans_len = (None, 0)

        max_actual_len = max(train_ans_len, ft_ans_len, dev_ans_len)
        print('Max answer length: {}, truncated to {}'.format(
            max_actual_len, FLAGS.max_answer_len))

        enc_word_vocab = None
        dec_word_vocab = None
        char_vocab = None
        has_pretrained_model = False
        best_path = path_prefix + ".best.model"
        if os.path.exists(best_path + ".index"):
            # A checkpoint exists: reuse the char vocab dumped by the run
            # that produced it so that indices stay aligned with the weights.
            has_pretrained_model = True
            print('!!Existing pretrained model. Loading vocabs.')
            if FLAGS.with_word:
                enc_word_vocab = Vocab(FLAGS.enc_word_vec_path,
                                       fileformat='txt2')
                dec_word_vocab = Vocab(FLAGS.dec_word_vec_path,
                                       fileformat='txt2')
                print('Encoder word vocab: {}'.format(
                    enc_word_vocab.word_vecs.shape))
                print('Decoder word vocab: {}'.format(
                    dec_word_vocab.word_vecs.shape))
            if FLAGS.with_char:
                char_vocab = Vocab(path_prefix + ".char_vocab",
                                   fileformat='txt2')
                print('char_vocab: {}'.format(char_vocab.word_vecs.shape))
        else:
            print('Collecting vocabs.')
            (allWords, allChars) = NP2P_data_stream.collect_vocabs(trainset)
            print('Number of words: {}'.format(len(allWords)))
            print('Number of allChars: {}'.format(len(allChars)))

            if FLAGS.with_word:
                enc_word_vocab = Vocab(FLAGS.enc_word_vec_path,
                                       fileformat='txt2')
                dec_word_vocab = Vocab(FLAGS.dec_word_vec_path,
                                       fileformat='txt2')
            if FLAGS.with_char:
                char_vocab = Vocab(voc=allChars,
                                   dim=FLAGS.char_dim,
                                   fileformat='build')
                char_vocab.dump_to_txt2(path_prefix + ".char_vocab")

        # Both word vocabs stay None when FLAGS.with_word is off; do not
        # crash on purely informational messages in that case.
        if enc_word_vocab is not None:
            print('Encoder word vocab size {}'.format(
                enc_word_vocab.vocab_size))
        if dec_word_vocab is not None:
            print('Decoder word vocab size {}'.format(
                dec_word_vocab.vocab_size))
        sys.stdout.flush()

        print('Build DataStream ... ')
        trainDataStream = NP2P_data_stream.DataStream(trainset,
                                                      enc_word_vocab,
                                                      dec_word_vocab,
                                                      char_vocab,
                                                      options=FLAGS,
                                                      isShuffle=True,
                                                      isLoop=True,
                                                      isSort=True)
        devDataStream = NP2P_data_stream.DataStream(devset,
                                                    enc_word_vocab,
                                                    dec_word_vocab,
                                                    char_vocab,
                                                    options=FLAGS,
                                                    isShuffle=False,
                                                    isLoop=False,
                                                    isSort=True)
        print('Number of instances in trainDataStream: {}'.format(
            trainDataStream.get_num_instance()))
        print('Number of instances in devDataStream: {}'.format(
            devDataStream.get_num_instance()))
        print('Number of batches in trainDataStream: {}'.format(
            trainDataStream.get_num_batch()))
        print('Number of batches in devDataStream: {}'.format(
            devDataStream.get_num_batch()))
        if ftset is not None:
            ftDataStream = NP2P_data_stream.DataStream(ftset,
                                                       enc_word_vocab,
                                                       dec_word_vocab,
                                                       char_vocab,
                                                       options=FLAGS,
                                                       isShuffle=True,
                                                       isLoop=True,
                                                       isSort=True)
            print('Number of instances in ftDataStream: {}'.format(
                ftDataStream.get_num_instance()))
            print('Number of batches in ftDataStream: {}'.format(
                ftDataStream.get_num_batch()))

        sys.stdout.flush()

        init_scale = 0.01
        # Initialize the best bleu and accu scores for the current training
        # session; they are present in FLAGS only if an earlier session
        # stored them. dict.get replaces has_key, which Python 3 removed.
        best_accu = FLAGS.__dict__.get('best_accu', 0.0)
        best_bleu = FLAGS.__dict__.get('best_bleu', 0.0)
        if best_accu > 0.0:
            print('With initial dev accuracy {}'.format(best_accu))
        if best_bleu > 0.0:
            print('With initial dev BLEU score {}'.format(best_bleu))

        with tf.Graph().as_default():
            initializer = tf.random_uniform_initializer(-init_scale,
                                                        init_scale)
            with tf.name_scope("Train"):
                with tf.variable_scope("Model",
                                       reuse=None,
                                       initializer=initializer):
                    train_graph = ModelGraph(enc_word_vocab=enc_word_vocab,
                                             dec_word_vocab=dec_word_vocab,
                                             char_vocab=char_vocab,
                                             POS_vocab=None,
                                             NER_vocab=None,
                                             options=FLAGS,
                                             mode=FLAGS.mode)

            assert FLAGS.mode in (
                'ce_train',
                'rl_train',
            )
            # Cross-entropy training is validated by accuracy, RL training
            # by BLEU.
            valid_mode = ('evaluate'
                          if FLAGS.mode == 'ce_train' else 'evaluate_bleu')

            # reuse=True: the validation graph is built on the variables
            # created for the training graph.
            with tf.name_scope("Valid"):
                with tf.variable_scope("Model",
                                       reuse=True,
                                       initializer=initializer):
                    valid_graph = ModelGraph(enc_word_vocab=enc_word_vocab,
                                             dec_word_vocab=dec_word_vocab,
                                             char_vocab=char_vocab,
                                             POS_vocab=None,
                                             NER_vocab=None,
                                             options=FLAGS,
                                             mode=valid_mode)

            initializer = tf.global_variables_initializer()

            # Checkpoint only the "Model" variables; word embeddings are
            # left out when they are kept fixed.
            vars_ = {}
            for var in tf.global_variables():
                if FLAGS.fix_word_vec and "word_embedding" in var.name:
                    continue
                if not var.name.startswith("Model"):
                    continue
                print(var)
                vars_[var.name.split(":")[0]] = var
            saver = tf.train.Saver(vars_)

            sess = tf.Session()
            sess.run(initializer)
            if has_pretrained_model:
                print("Restoring model from " + best_path)
                saver.restore(sess, best_path)
                print("DONE!")

                # No best score recorded for the restored weights yet:
                # measure it now so that training only overwrites the
                # checkpoint with something better.
                if FLAGS.mode == 'rl_train' and abs(best_bleu) < 0.00001:
                    print("Getting BLEU score for the model")
                    best_bleu = evaluate(sess,
                                         valid_graph,
                                         devDataStream,
                                         options=FLAGS)['dev_bleu']
                    FLAGS.best_bleu = best_bleu
                    namespace_utils.save_namespace(FLAGS, config_path)
                    print('BLEU = %.4f' % best_bleu)
                    log_file.write('BLEU = %.4f\n' % best_bleu)
                if FLAGS.mode == 'ce_train' and abs(best_accu) < 0.00001:
                    print("Getting ACCU score for the model")
                    best_accu = evaluate(sess,
                                         valid_graph,
                                         devDataStream,
                                         options=FLAGS)['dev_accu']
                    FLAGS.best_accu = best_accu
                    namespace_utils.save_namespace(FLAGS, config_path)
                    print('ACCU = %.4f' % best_accu)
                    log_file.write('ACCU = %.4f\n' % best_accu)

            print('Start the training loop.')
            train_size = trainDataStream.get_num_batch()
            max_steps = train_size * FLAGS.max_epochs
            total_loss = 0.0
            start_time = time.time()
            for step in xrange(max_steps):
                cur_batch = trainDataStream.nextBatch()
                if FLAGS.mode == 'rl_train':
                    loss_value = train_graph.run_rl_training_2(
                        sess, cur_batch, FLAGS)
                elif FLAGS.mode == 'ce_train':
                    loss_value = train_graph.run_ce_training(
                        sess, cur_batch, FLAGS)
                total_loss += loss_value

                # Lightweight progress indicator on a single line.
                if step % 100 == 0:
                    print('{} '.format(step), end="")
                    sys.stdout.flush()

                # Save a checkpoint and evaluate the model periodically:
                # after each pass, at the very last step, and every 2000
                # steps when a single pass is longer than 10000 batches.
                if (step + 1) % trainDataStream.get_num_batch() == 0 or \
                        (step + 1) == max_steps or \
                        (trainDataStream.get_num_batch() > 10000 and
                         (step + 1) % 2000 == 0):
                    print()
                    duration = time.time() - start_time
                    print('Step %d: loss = %.2f (%.3f sec)' %
                          (step, total_loss, duration))
                    log_file.write('Step %d: loss = %.2f (%.3f sec)\n' %
                                   (step, total_loss, duration))
                    log_file.flush()
                    sys.stdout.flush()
                    total_loss = 0.0

                    if ftset is not None:
                        best_accu, best_bleu = fine_tune(
                            sess, saver, FLAGS, log_file, ftDataStream,
                            devDataStream, train_graph, valid_graph,
                            path_prefix, best_accu, best_bleu)
                    else:
                        best_accu, best_bleu = validate_and_save(
                            sess, saver, FLAGS, log_file, devDataStream,
                            valid_graph, path_prefix, best_accu, best_bleu)
                    # Restart the timer so the next report excludes the time
                    # spent in validation / fine-tuning.
                    start_time = time.time()
Example #6
0
def question_gen_run(argv):
    """Decode a test file with a trained model and write the generated text.

    Args:
        argv: Sequence of at least four strings, read positionally:
            ``argv[0]`` is the model prefix (``.config.json``,
            ``.best.model`` and the ``.char_vocab`` / ``.POS_vocab`` /
            ``.NER_vocab`` suffixes are appended to it), ``argv[1]`` the
            input path, ``argv[2]`` the output path and ``argv[3]`` the
            decoding mode.

    The mode selects the decoder and the output layout:
        * ``pointwise``: calls ``search`` and writes reference, prediction
          and a ``========`` separator to the output path; prints a running
          accuracy comparing output ids with ``in_answer_words``.
        * ``greedy`` / ``multinomial``: calls ``search`` and writes id,
          reference, prediction and a separator per instance.
        * ``greedy_evaluate``: greedy ``search``; references go to
          ``<out_path>.ref`` and predictions to ``<out_path>.pred``.
        * ``beam_evaluate``: ``run_beam_search``; the first hypothesis goes
          to ``<out_path>.pred`` and the reference to ``<out_path>.ref``.
        * anything else: ``run_beam_search``; input, reference and every
          hypothesis with its ``avg_log_prob()`` go to the output path.

    Raises:
        IndexError: if ``argv`` holds fewer than four items.
    """
    print(sys.argv)
    model_prefix = argv[0]
    in_path = argv[1]
    out_path = argv[2]
    mode = argv[3]

    # The variable may be unset (e.g. CPU-only runs); a log line must not
    # abort the whole decoding run with a KeyError.
    print("CUDA_VISIBLE_DEVICES " +
          os.environ.get('CUDA_VISIBLE_DEVICES', '<unset>'))

    # load the configuration file
    print('Loading configurations from ' + model_prefix + ".config.json")
    FLAGS = namespace_utils.load_namespace(model_prefix + ".config.json")
    FLAGS = NP2P_trainer.enrich_options(FLAGS)

    # load vocabs
    print('Loading vocabs.')
    word_vocab = char_vocab = POS_vocab = NER_vocab = None
    if FLAGS.with_word:
        word_vocab = Vocab(FLAGS.word_vec_path, fileformat='txt2')
        print('word_vocab: {}'.format(word_vocab.word_vecs.shape))
    if FLAGS.with_char:
        char_vocab = Vocab(model_prefix + ".char_vocab", fileformat='txt2')
        print('char_vocab: {}'.format(char_vocab.word_vecs.shape))
    if FLAGS.with_POS:
        POS_vocab = Vocab(model_prefix + ".POS_vocab", fileformat='txt2')
        print('POS_vocab: {}'.format(POS_vocab.word_vecs.shape))
    if FLAGS.with_NER:
        NER_vocab = Vocab(model_prefix + ".NER_vocab", fileformat='txt2')
        print('NER_vocab: {}'.format(NER_vocab.word_vecs.shape))

    print('Loading test set.')
    if FLAGS.infile_format == 'fof':
        testset, _ = NP2P_data_stream.read_generation_datasets_from_fof(
            in_path, isLower=FLAGS.isLower)
    elif FLAGS.infile_format == 'plain':
        testset, _ = NP2P_data_stream.read_all_GenerationDatasets(
            in_path, isLower=FLAGS.isLower)
    else:
        testset, _ = NP2P_data_stream.read_all_GQA_questions(
            in_path, isLower=FLAGS.isLower, switch=FLAGS.switch_qa)
    print('Number of samples: {}'.format(len(testset)))

    print('Build DataStream ... ')
    # Beam search handles one instance at a time; -1 is passed through to
    # QADataStream unchanged for every other mode.
    batch_size = 1 if 'beam' in mode else -1
    devDataStream = NP2P_data_stream.QADataStream(testset,
                                                  word_vocab,
                                                  char_vocab,
                                                  POS_vocab,
                                                  NER_vocab,
                                                  options=FLAGS,
                                                  isShuffle=False,
                                                  isLoop=False,
                                                  isSort=True,
                                                  batch_size=batch_size)
    print('Number of instances in testDataStream: {}'.format(
        devDataStream.get_num_instance()))
    print('Number of batches in testDataStream: {}'.format(
        devDataStream.get_num_batch()))

    best_path = model_prefix + ".best.model"
    with tf.Graph().as_default():
        initializer = tf.random_uniform_initializer(-0.01, 0.01)
        with tf.name_scope("Valid"):
            with tf.variable_scope("Model",
                                   reuse=False,
                                   initializer=initializer):
                valid_graph = ModelGraph(word_vocab=word_vocab,
                                         char_vocab=char_vocab,
                                         POS_vocab=POS_vocab,
                                         NER_vocab=NER_vocab,
                                         options=FLAGS,
                                         mode="decode")

        # Restore only variables under "Model", skipping word embeddings
        # (presumably those come from word_vec_path rather than from the
        # checkpoint -- confirm against the trainer's saver).
        vars_ = {}
        for var in tf.global_variables():
            if "word_embedding" in var.name:
                continue
            if not var.name.startswith("Model"):
                continue
            vars_[var.name.split(":")[0]] = var
        saver = tf.train.Saver(vars_)

        init_op = tf.global_variables_initializer()
        sess = tf.Session()
        # Every file opened below is registered here so the ``finally``
        # clause can close it even when decoding fails half-way through.
        open_files = []
        try:
            sess.run(init_op)
            saver.restore(sess, best_path)  # restore the model

            total = 0
            correct = 0
            if mode.endswith('evaluate'):
                ref_outfile = open(out_path + ".ref", 'wt')
                open_files.append(ref_outfile)
                pred_outfile = open(out_path + ".pred", 'wt')
                open_files.append(pred_outfile)
            else:
                outfile = open(out_path, 'wt')
                open_files.append(outfile)

            total_num = devDataStream.get_num_batch()
            devDataStream.reset()
            for i in range(total_num):
                cur_batch = devDataStream.get_batch(i)
                if mode == 'pointwise':
                    sentences, _, _, generator_output_idx = search(
                        sess,
                        valid_graph,
                        word_vocab,
                        cur_batch,
                        FLAGS,
                        decode_mode=mode)
                    for j in range(cur_batch.batch_size):
                        cur_total = cur_batch.answer_lengths[j]
                        cur_correct = 0
                        for k in range(cur_total):
                            if generator_output_idx[
                                    j, k] == cur_batch.in_answer_words[j, k]:
                                cur_correct += 1.0
                        total += cur_total
                        correct += cur_correct
                        outfile.write(
                            cur_batch.instances[j][1].tokText.encode('utf-8') +
                            "\n")
                        outfile.write(sentences[j].encode('utf-8') + "\n")
                        outfile.write("========\n")
                    outfile.flush()
                    # Guard against batches that contributed no answer
                    # tokens so far; report 0.00 instead of dividing by zero.
                    accuracy = correct / float(total) * 100 if total else 0.0
                    print('Current dev accuracy is %d/%d=%.2f' %
                          (correct, total, accuracy))
                elif mode in ['greedy', 'multinomial']:
                    print('Batch {}'.format(i))
                    sentences, _, _, _ = search(sess,
                                                valid_graph,
                                                word_vocab,
                                                cur_batch,
                                                FLAGS,
                                                decode_mode=mode)
                    for j in range(cur_batch.batch_size):
                        outfile.write(
                            cur_batch.instances[j][1].ID_num.encode('utf-8') +
                            "\n")
                        outfile.write(
                            cur_batch.instances[j][1].tokText.encode('utf-8') +
                            "\n")
                        outfile.write(sentences[j].encode('utf-8') + "\n")
                        outfile.write("========\n")
                    outfile.flush()
                elif mode == 'greedy_evaluate':
                    print('Batch {}'.format(i))
                    sentences, _, _, _ = search(sess,
                                                valid_graph,
                                                word_vocab,
                                                cur_batch,
                                                FLAGS,
                                                decode_mode="greedy")
                    for j in range(cur_batch.batch_size):
                        ref_outfile.write(
                            cur_batch.instances[j][1].tokText.encode('utf-8') +
                            "\n")
                        pred_outfile.write(
                            sentences[j].encode('utf-8') + "\n")
                    ref_outfile.flush()
                    pred_outfile.flush()
                elif mode == 'beam_evaluate':
                    print('Instance {}'.format(i))
                    # The reference is written before decoding so it is on
                    # disk even if the beam search for this instance fails.
                    ref_outfile.write(
                        cur_batch.instances[0][1].tokText.encode('utf-8') +
                        "\n")
                    ref_outfile.flush()
                    hyps = run_beam_search(sess, valid_graph, word_vocab,
                                           cur_batch, FLAGS)
                    cur_passage = cur_batch.instances[0][0]
                    cur_id2phrase = None
                    if FLAGS.with_phrase_projection:
                        _, cur_id2phrase = cur_batch.phrase_vocabs[0]
                    cur_sent = hyps[0].idx_seq_to_string(
                        cur_passage, cur_id2phrase, word_vocab, FLAGS)
                    pred_outfile.write(cur_sent.encode('utf-8') + "\n")
                    pred_outfile.flush()
                else:  # beam search
                    print('Instance {}'.format(i))
                    hyps = run_beam_search(sess, valid_graph, word_vocab,
                                           cur_batch, FLAGS)
                    outfile.write(
                        "Input: " +
                        cur_batch.instances[0][0].tokText.encode('utf-8') +
                        "\n")
                    outfile.write(
                        "Truth: " +
                        cur_batch.instances[0][1].tokText.encode('utf-8') +
                        "\n")
                    # Passage and phrase table are the same for every
                    # hypothesis of this instance, so look them up once.
                    cur_passage = cur_batch.instances[0][0]
                    cur_id2phrase = None
                    if FLAGS.with_phrase_projection:
                        _, cur_id2phrase = cur_batch.phrase_vocabs[0]
                    for j, hyp in enumerate(hyps):
                        cur_sent = hyp.idx_seq_to_string(
                            cur_passage, cur_id2phrase, word_vocab, FLAGS)
                        outfile.write("Hyp-{}: ".format(j) +
                                      cur_sent.encode('utf-8') +
                                      " {}".format(hyp.avg_log_prob()) + "\n")
                    outfile.flush()
        finally:
            for handle in open_files:
                handle.close()
            sess.close()