Example #1
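The snippets in this collection are functions extracted from a graph-to-sequence (G2S) training/decoding codebase and omit their module-level imports. Below is a minimal sketch of the imports they rely on, inferred from the names used in the code; the exact module paths for Vocab and ModelGraph are assumptions and may differ in the original repository.

from __future__ import print_function  # the snippets call print(..., end="")

import os
import sys
import time
import json
import random

import numpy as np
import tensorflow as tf  # TensorFlow 1.x APIs (tf.variable_scope, tf.Session, ...)

import G2S_data_stream   # batching / file-reading helpers used throughout
import namespace_utils   # save_namespace / load_namespace for the .config.json files
# Assumed locations (hypothetical module names):
# from vocab_utils import Vocab
# from G2S_model_graph import ModelGraph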
def evaluate(sess, valid_graph, devDataStream, options=None, suffix=''):
    devDataStream.reset()
    gen = []
    ref = []
    dev_loss = 0.0
    dev_right = 0.0
    dev_total = 0.0
    for batch_index in xrange(devDataStream.get_num_batch()): # for each batch
        cur_batch = devDataStream.get_batch(batch_index)
        cur_batch = G2S_data_stream.G2SBatchPadd(cur_batch)
        if valid_graph.mode == 'evaluate':
            accu_value, loss_value = valid_graph.run_ce_training(sess, cur_batch, options, only_eval=True)
            dev_loss += loss_value
            dev_right += accu_value
            dev_total += np.sum(cur_batch.sent_len)
        elif valid_graph.mode == 'evaluate_bleu':
            gen.extend(valid_graph.run_greedy(sess, cur_batch, options).tolist())
            ref.extend(cur_batch.sent_out.tolist())
        else:
            assert False

    if valid_graph.mode == 'evaluate':
        return {'dev_loss':dev_loss, 'dev_accu':1.0*dev_right/dev_total, 'dev_right':dev_right, 'dev_total':dev_total, }
    else:
        return {'dev_bleu':document_bleu(valid_graph.word_vocab,gen,ref,suffix), }
def evaluate(sess, valid_graph, devDataStream, options=None, suffix=''):
    devDataStream.reset()
    gen = []
    ref = []
    dev_loss = 0.0
    dev_right1 = 0.0
    dev_right2 = 0.0
    dev_total = 0.0
    for batch_index in xrange(devDataStream.get_num_batch()):  # for each batch
        cur_batch = devDataStream.get_batch(batch_index)
        cur_batch = G2S_data_stream.G2SBatchPadd(cur_batch)
#         file = open('../data/train_output.txt', 'a+')
        if valid_graph.mode == 'evaluate':
            dic = valid_graph.word_vocab.id2word
            accu1_value, loss_value, vocab_score1, vocab_score2, greedy_words, loss_weights = valid_graph.run_ce_training(sess, cur_batch, options, only_eval=True)
            accu2_value = accu1_value[1]
            accu1_value = accu1_value[0]  # accu_value contains accu1 and accu2
            dev_loss += loss_value
            dev_right1 += accu1_value
            dev_right2 += accu2_value
            dev_total += np.sum(cur_batch.sent_len)
#             with tf.Session() as sess:
#                 # sess.run(tf.global_variables_initializer())
#                 vocab_score1 = sess.run([vocab_score1])
#                 greedy_words = sess.run([greedy_words])
#                 vocab_score2 = sess.run([vocab_score2])
#             file.write("target:       " + str(cur_batch.target_ref) + '\n')
#             file.write("out_seqs1:    " + str(_values_to_words(vocab_score1, loss_weights, dic)) + '\n')
#             file.write("greedy_words: " + str(_values_to_words(greedy_words, loss_weights, dic)) + '\n')
#             file.write("out_seqs2:    " + str(_values_to_words(vocab_score2, loss_weights, dic)) + '\n')
        elif valid_graph.mode == 'evaluate_bleu':
            gen.extend(valid_graph.run_greedy(sess, cur_batch, options).tolist())
            ref.extend(cur_batch.sent_out.tolist())
        else:
            assert False

    if valid_graph.mode == 'evaluate':
        return {'dev_loss':dev_loss, 'dev_accu1':1.0*dev_right1/dev_total, 'dev_accu':1.0*dev_right2/dev_total, 'dev_right1':dev_right1, 'dev_right':dev_right2, 'dev_total':dev_total, }
    else:
        return {'dev_bleu':document_bleu(valid_graph.word_vocab,gen,ref,suffix), }
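A minimal usage sketch (sess, valid_graph, devDataStream, and FLAGS are assumed to exist) showing how the two valid_graph.mode settings map onto the keys of the returned dictionary; the training loops in Examples #4 and #6 consume the result the same way.

res = evaluate(sess, valid_graph, devDataStream, options=FLAGS)
if valid_graph.mode == 'evaluate':
    print('Dev loss = %.4f' % res['dev_loss'])
    print('Dev accu = %.4f %d/%d' % (res['dev_accu'], int(res['dev_right']), int(res['dev_total'])))
else:  # 'evaluate_bleu'
    print('Dev bleu = %.4f' % res['dev_bleu'])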
def fine_tune(sess, saver, FLAGS, log_file,
    ftDataStream, devDataStream, train_graph, valid_graph, path_prefix, best_accu, best_bleu):
    print('=====Start the fine tuning.')
    sys.stdout.flush()
    max_steps = ftDataStream.get_num_batch() * 1  # one pass over the fine-tuning data
    best_path = path_prefix + ".best.model"
    total_loss = 0.0
    start_time = time.time()
    for step in xrange(max_steps):
        cur_batch = ftDataStream.nextBatch()
        cur_batch = G2S_data_stream.G2SBatchPadd(cur_batch)
        if FLAGS.mode == 'rl_train':
            loss_value = train_graph.run_rl_training_subsample(sess, cur_batch, FLAGS)
        elif FLAGS.mode == 'ce_train':
            loss_value = train_graph.run_ce_training(sess, cur_batch, FLAGS)
        total_loss += loss_value

        if step % 100 == 0:
            print('{} '.format(step), end="")
            sys.stdout.flush()

        # Save a checkpoint and evaluate the model periodically.
        if (step + 1) % ftDataStream.get_num_batch() == 0 or (step + 1) == max_steps:
            print()
            duration = time.time() - start_time
            print('Step %d: loss = %.2f (%.3f sec)' % (step, total_loss, duration))
            sys.stdout.flush()
            log_file.write('Step %d: loss = %.2f (%.3f sec)\n' % (step, total_loss, duration))
            log_file.flush()
            best_accu, best_bleu = validate_and_save(sess, saver, FLAGS, log_file,
                    devDataStream, valid_graph, path_prefix, best_accu, best_bleu)
            total_loss = 0.0
            start_time = time.time()

    print('=====End the fine tuning.')
    sys.stdout.flush()
    return best_accu, best_bleu
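fine_tune above (and the training loop in Example #6) delegates checkpointing to a validate_and_save helper that is not included in these snippets. The following is a hypothetical reconstruction, assembled from the inline validation/checkpoint block of Example #4; the real helper may differ in detail.

def validate_and_save(sess, saver, FLAGS, log_file, devDataStream, valid_graph,
                      path_prefix, best_accu, best_bleu):
    # Sketch only: evaluate on the dev stream and checkpoint when the score improves,
    # mirroring the inline block of Example #4.
    best_path = path_prefix + ".best.model"
    res_dict = evaluate(sess, valid_graph, devDataStream, options=FLAGS)
    if valid_graph.mode == 'evaluate':
        dev_accu = res_dict['dev_accu']
        print('Dev accu = %.4f' % dev_accu)
        log_file.write('Dev accu = %.4f\n' % dev_accu)
        if best_accu < dev_accu:
            saver.save(sess, best_path)
            best_accu = dev_accu
            FLAGS.best_accu = dev_accu
            namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json")
    else:  # 'evaluate_bleu'
        dev_bleu = res_dict['dev_bleu']
        print('Dev bleu = %.4f' % dev_bleu)
        log_file.write('Dev bleu = %.4f\n' % dev_bleu)
        if best_bleu < dev_bleu:
            saver.save(sess, best_path)
            best_bleu = dev_bleu
            FLAGS.best_bleu = dev_bleu
            namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json")
    log_file.flush()
    return best_accu, best_bleu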
Example #4
def main(_):
    print('Configurations:')
    print(FLAGS)

    log_dir = FLAGS.model_dir
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)

    path_prefix = log_dir + "/G2S.{}".format(FLAGS.suffix)
    log_file_path = path_prefix + ".log"
    print('Log file path: {}'.format(log_file_path))
    log_file = open(log_file_path, 'wt')
    log_file.write("{}\n".format(FLAGS))
    log_file.flush()

    # save configuration
    namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json")

    word_vocab_enc = None
    word_vocab_dec = None
    char_vocab = None
    edgelabel_vocab = None
    has_pretrained_model = False
    best_path = path_prefix + ".best.model"
    if os.path.exists(best_path + ".index"):
        has_pretrained_model = True
        print('!!Existing pretrained model. Loading vocabs.')
        word_vocab_enc = Vocab(FLAGS.word_vec_src_path, fileformat='txt2')
        print('word_vocab SRC: {}'.format(word_vocab_enc.word_vecs.shape))
        word_vocab_dec = Vocab(FLAGS.word_vec_tgt_path, fileformat='txt2')
        print('word_vocab TGT: {}'.format(word_vocab_dec.word_vecs.shape))
        if FLAGS.with_char:
            char_vocab = Vocab(path_prefix + ".char_vocab", fileformat='txt2')
            print('char_vocab: {}'.format(char_vocab.word_vecs.shape))
        edgelabel_vocab = Vocab(path_prefix + ".edgelabel_vocab", fileformat='txt2')
    else:
        print('Collecting vocabs.')
        word_vocab_enc = Vocab(FLAGS.word_vec_src_path, fileformat='txt2')
        word_vocab_dec = Vocab(FLAGS.word_vec_tgt_path, fileformat='txt2')
        if FLAGS.with_char:
            char_vocab = Vocab(voc=allChars, dim=FLAGS.char_dim, fileformat='build')
            char_vocab.dump_to_txt2(path_prefix + ".char_vocab")
        allEdgelabels = set([line.strip().split()[0] \
                for line in open(FLAGS.edgelabel_vocab, 'rU')])
        edgelabel_vocab = Vocab(voc=allEdgelabels, dim=FLAGS.edgelabel_dim, fileformat='build')
        edgelabel_vocab.dump_to_txt2(path_prefix + ".edgelabel_vocab")

    print('word vocab SRC size {}'.format(word_vocab_enc.vocab_size))
    print('word vocab TGT size {}'.format(word_vocab_dec.vocab_size))
    sys.stdout.flush()

    print('Loading train set.')
    if FLAGS.infile_format == 'fof':
        trainset, trn_node, trn_in_neigh, trn_out_neigh, trn_sent = G2S_data_stream.read_amr_from_fof(FLAGS.train_path, FLAGS,
                word_vocab_enc, word_vocab_dec, char_vocab, edgelabel_vocab)
    else:
        trainset, trn_node, trn_in_neigh, trn_out_neigh, trn_sent = G2S_data_stream.read_amr_file(FLAGS.train_path, FLAGS,
                word_vocab_enc, word_vocab_dec, char_vocab, edgelabel_vocab)
    print('Number of training samples: {}'.format(len(trainset)))

    print('Loading test set.')
    if FLAGS.infile_format == 'fof':
        testset, tst_node, tst_in_neigh, tst_out_neigh, tst_sent = G2S_data_stream.read_amr_from_fof(FLAGS.test_path, FLAGS,
                word_vocab_enc, word_vocab_dec, char_vocab, edgelabel_vocab)
    else:
        testset, tst_node, tst_in_neigh, tst_out_neigh, tst_sent = G2S_data_stream.read_amr_file(FLAGS.test_path, FLAGS,
                word_vocab_enc, word_vocab_dec, char_vocab, edgelabel_vocab)
    print('Number of test samples: {}'.format(len(testset)))

    max_node = max(trn_node, tst_node)
    max_in_neigh = max(trn_in_neigh, tst_in_neigh)
    max_out_neigh = max(trn_out_neigh, tst_out_neigh)
    max_sent = max(trn_sent, tst_sent)
    print('Max node number: {}, while max allowed is {}'.format(max_node, FLAGS.max_node_num))
    print('Max parent number: {}, truncated to {}'.format(max_in_neigh, FLAGS.max_in_neigh_num))
    print('Max children number: {}, truncated to {}'.format(max_out_neigh, FLAGS.max_out_neigh_num))
    print('Max answer length: {}, truncated to {}'.format(max_sent, FLAGS.max_answer_len))

    print('Build DataStream ... ')
    trainDataStream = G2S_data_stream.G2SDataStream(trainset, word_vocab_enc, word_vocab_dec, char_vocab, edgelabel_vocab,
            options=FLAGS, isShuffle=True, isLoop=True, isSort=True)

    devDataStream = G2S_data_stream.G2SDataStream(testset, word_vocab_enc, word_vocab_dec, char_vocab, edgelabel_vocab,
            options=FLAGS, isShuffle=False, isLoop=False, isSort=True)
    print('Number of instances in trainDataStream: {}'.format(trainDataStream.get_num_instance()))
    print('Number of instances in devDataStream: {}'.format(devDataStream.get_num_instance()))
    print('Number of batches in trainDataStream: {}'.format(trainDataStream.get_num_batch()))
    print('Number of batches in devDataStream: {}'.format(devDataStream.get_num_batch()))
    sys.stdout.flush()

    # initialize the best bleu and accu scores for current training session
    best_accu = FLAGS.best_accu if FLAGS.__dict__.has_key('best_accu') else 0.0
    best_bleu = FLAGS.best_bleu if FLAGS.__dict__.has_key('best_bleu') else 0.0
    if best_accu > 0.0:
        print('With initial dev accuracy {}'.format(best_accu))
    if best_bleu > 0.0:
        print('With initial dev BLEU score {}'.format(best_bleu))

    init_scale = 0.01
    with tf.Graph().as_default():
        initializer = tf.random_uniform_initializer(-init_scale, init_scale)
        with tf.name_scope("Train"):
            with tf.variable_scope("Model", reuse=None, initializer=initializer):
                train_graph = ModelGraph(word_vocab_enc=word_vocab_enc, word_vocab_dec=word_vocab_dec, Edgelabel_vocab=edgelabel_vocab,
                                         char_vocab=char_vocab, options=FLAGS, mode=FLAGS.mode)

        assert FLAGS.mode in ('ce_train', 'rl_train', )
        valid_mode = 'evaluate' if FLAGS.mode == 'ce_train' else 'evaluate_bleu'

        with tf.name_scope("Valid"):
            with tf.variable_scope("Model", reuse=True, initializer=initializer):
                valid_graph = ModelGraph(word_vocab_enc=word_vocab_enc, word_vocab_dec=word_vocab_dec, Edgelabel_vocab=edgelabel_vocab,
                                         char_vocab=char_vocab, options=FLAGS, mode=valid_mode)

        initializer = tf.global_variables_initializer()

        for var in tf.trainable_variables():
            print(var)

        vars_ = {}
        for var in tf.all_variables():
            if FLAGS.fix_word_vec and "word_embedding" in var.name: continue
            if not var.name.startswith("Model"): continue
            vars_[var.name.split(":")[0]] = var
        saver = tf.train.Saver(vars_)

        sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
        sess.run(initializer)
        if has_pretrained_model:
            print("Restoring model from " + best_path)
            saver.restore(sess, best_path)
            print("DONE!")

            if FLAGS.mode == 'rl_train' and abs(best_bleu) < 0.00001:
                print("Getting BLEU score for the model")
                sys.stdout.flush()
                best_bleu = evaluate(sess, valid_graph, devDataStream, options=FLAGS)['dev_bleu']
                FLAGS.best_bleu = best_bleu
                namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json")
                print('BLEU = %.4f' % best_bleu)
                sys.stdout.flush()
                log_file.write('BLEU = %.4f\n' % best_bleu)
            if FLAGS.mode == 'ce_train' and abs(best_accu) < 0.00001:
                print("Getting ACCU score for the model")
                best_accu = evaluate(sess, valid_graph, devDataStream, options=FLAGS)['dev_accu']
                FLAGS.best_accu = best_accu
                namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json")
                print('ACCU = %.4f' % best_accu)
                log_file.write('ACCU = %.4f\n' % best_accu)

        print('Start the training loop.')
        train_size = trainDataStream.get_num_batch()
        max_steps = train_size * FLAGS.max_epochs
        total_loss = 0.0
        start_time = time.time()
        for step in xrange(max_steps):
            cur_batch = trainDataStream.nextBatch()
            cur_batch = G2S_data_stream.G2SBatchPadd(cur_batch)
            if FLAGS.mode == 'rl_train':
                loss_value = train_graph.run_rl_training_subsample(sess, cur_batch, FLAGS)
            elif FLAGS.mode == 'ce_train':
                loss_value = train_graph.run_ce_training(sess, cur_batch, FLAGS)
            total_loss += loss_value

            if step % 100 == 0:
                print('{} '.format(step), end="")
                sys.stdout.flush()


            # Save a checkpoint and evaluate the model periodically.
            if (step + 1) % trainDataStream.get_num_batch() == 0 or (step + 1) == max_steps or (step != 0 and step % 2000 == 0):
                print()
                duration = time.time() - start_time
                print('Step %d: loss = %.2f (%.3f sec)' % (step, total_loss, duration))
                log_file.write('Step %d: loss = %.2f (%.3f sec)\n' % (step, total_loss, duration))
                log_file.flush()
                sys.stdout.flush()
                total_loss = 0.0

                # Evaluate against the validation set.
                start_time = time.time()
                print('Validation Data Eval:')
                res_dict = evaluate(sess, valid_graph, devDataStream, options=FLAGS, suffix=str(step))
                if valid_graph.mode == 'evaluate':
                    dev_loss = res_dict['dev_loss']
                    dev_accu = res_dict['dev_accu']
                    dev_right = int(res_dict['dev_right'])
                    dev_total = int(res_dict['dev_total'])
                    print('Dev loss = %.4f' % dev_loss)
                    log_file.write('Dev loss = %.4f\n' % dev_loss)
                    print('Dev accu = %.4f %d/%d' % (dev_accu, dev_right, dev_total))
                    log_file.write('Dev accu = %.4f %d/%d\n' % (dev_accu, dev_right, dev_total))
                    log_file.flush()
                    if best_accu < dev_accu:
                        print('Saving weights, ACCU {} (prev_best) < {} (cur)'.format(best_accu, dev_accu))
                        saver.save(sess, best_path)
                        best_accu = dev_accu
                        FLAGS.best_accu = dev_accu
                        namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json")
                else:
                    dev_bleu = res_dict['dev_bleu']
                    print('Dev bleu = %.4f' % dev_bleu)
                    log_file.write('Dev bleu = %.4f\n' % dev_bleu)
                    log_file.flush()
                    if best_bleu < dev_bleu:
                        print('Saving weights, BLEU {} (prev_best) < {} (cur)'.format(best_bleu, dev_bleu))
                        saver.save(sess, best_path)
                        best_bleu = dev_bleu
                        FLAGS.best_bleu = dev_bleu
                        namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json")
                duration = time.time() - start_time
                print('Duration %.3f sec' % (duration))
                sys.stdout.flush()

                log_file.write('Duration %.3f sec\n' % (duration))
                log_file.flush()

    log_file.close()
    FLAGS = G2S_trainer.enrich_options(FLAGS)

    # load vocabs
    print('Loading vocabs.')
    word_vocab = Vocab(FLAGS.word_vec_path, fileformat='txt2')
    print('word_vocab: {}'.format(word_vocab.word_vecs.shape))
    edgelabel_vocab = Vocab(model_prefix + ".edgelabel_vocab",
                            fileformat='txt2')
    print('edgelabel_vocab: {}'.format(edgelabel_vocab.word_vecs.shape))
    char_vocab = None
    if FLAGS.with_char:
        char_vocab = Vocab(model_prefix + ".char_vocab", fileformat='txt2')
        print('char_vocab: {}'.format(char_vocab.word_vecs.shape))

    print('Loading test set from {}.'.format(in_path))
    testset, _, _, _, _ = G2S_data_stream.read_amr_file(in_path)
    print('Number of samples: {}'.format(len(testset)))

    print('Build DataStream ... ')
    batch_size = -1
    if mode not in (
            'pointwise',
            'multinomial',
            'greedy',
            'greedy_evaluate',
    ):
        batch_size = 1

    devDataStream = G2S_data_stream.G2SDataStream(testset,
                                                  word_vocab,
                                                  char_vocab,
Example #6
def main(_):
    print('Configurations:')
    print(FLAGS)

    log_dir = FLAGS.model_dir
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)

    path_prefix = log_dir + "/G2S.{}".format(FLAGS.suffix)
    log_file_path = path_prefix + ".log"
    print('Log file path: {}'.format(log_file_path))
    log_file = open(log_file_path, 'wt')
    log_file.write("{}\n".format(FLAGS))
    log_file.flush()

    # save configuration
    namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json")

    print('Loading train set.')
    trainset, trn_node, trn_in_neigh, trn_out_neigh, trn_sent = G2S_data_stream.read_amr_file(
        FLAGS.train_path)
    print('Number of training samples: {}'.format(len(trainset)))

    print('Loading dev set.')
    devset, tst_node, tst_in_neigh, tst_out_neigh, tst_sent = G2S_data_stream.read_amr_file(
        FLAGS.test_path)
    print('Number of dev samples: {}'.format(len(devset)))

    if FLAGS.finetune_path != "":
        print('Loading finetune set.')
        ftset, ft_node, ft_in_neigh, ft_out_neigh, ft_sent = G2S_data_stream.read_amr_file(
            FLAGS.finetune_path)
        print('Number of finetune samples: {}'.format(len(ftset)))
    else:
        ftset, ft_node, ft_in_neigh, ft_out_neigh, ft_sent = (None, 0, 0, 0, 0)

    max_node = max(trn_node, tst_node, ft_node)
    max_in_neigh = max(trn_in_neigh, tst_in_neigh, ft_in_neigh)
    max_out_neigh = max(trn_out_neigh, tst_out_neigh, ft_out_neigh)
    max_sent = max(trn_sent, tst_sent, ft_sent)
    print('Max node number: {}, while max allowed is {}'.format(
        max_node, FLAGS.max_node_num))
    print('Max parent number: {}, truncated to {}'.format(
        max_in_neigh, FLAGS.max_in_neigh_num))
    print('Max children number: {}, truncated to {}'.format(
        max_out_neigh, FLAGS.max_out_neigh_num))
    print('Max answer length: {}, truncated to {}'.format(
        max_sent, FLAGS.max_answer_len))

    word_vocab = None
    char_vocab = None
    edgelabel_vocab = None
    has_pretrained_model = False
    best_path = path_prefix + ".best.model"
    if os.path.exists(best_path + ".index"):
        has_pretrained_model = True
        print('!!Existing pretrained model. Loading vocabs.')
        word_vocab = Vocab(FLAGS.word_vec_path, fileformat='txt2')
        print('word_vocab: {}'.format(word_vocab.word_vecs.shape))
        char_vocab = None
        if FLAGS.with_char:
            char_vocab = Vocab(path_prefix + ".char_vocab", fileformat='txt2')
            print('char_vocab: {}'.format(char_vocab.word_vecs.shape))
        edgelabel_vocab = Vocab(path_prefix + ".edgelabel_vocab",
                                fileformat='txt2')
    else:
        print('Collecting vocabs.')
        (allWords, allChars,
         allEdgelabels) = G2S_data_stream.collect_vocabs(trainset)
        print('Number of words: {}'.format(len(allWords)))
        print('Number of allChars: {}'.format(len(allChars)))
        print('Number of allEdgelabels: {}'.format(len(allEdgelabels)))

        word_vocab = Vocab(FLAGS.word_vec_path, fileformat='txt2')
        char_vocab = None
        if FLAGS.with_char:
            char_vocab = Vocab(voc=allChars,
                               dim=FLAGS.char_dim,
                               fileformat='build')
            char_vocab.dump_to_txt2(path_prefix + ".char_vocab")
        edgelabel_vocab = Vocab(voc=allEdgelabels,
                                dim=FLAGS.edgelabel_dim,
                                fileformat='build')
        edgelabel_vocab.dump_to_txt2(path_prefix + ".edgelabel_vocab")

    print('word vocab size {}'.format(word_vocab.vocab_size))
    sys.stdout.flush()

    print('Build DataStream ... ')
    trainDataStream = G2S_data_stream.G2SDataStream(trainset,
                                                    word_vocab,
                                                    char_vocab,
                                                    edgelabel_vocab,
                                                    options=FLAGS,
                                                    isShuffle=True,
                                                    isLoop=True,
                                                    isSort=True)

    devDataStream = G2S_data_stream.G2SDataStream(devset,
                                                  word_vocab,
                                                  char_vocab,
                                                  edgelabel_vocab,
                                                  options=FLAGS,
                                                  isShuffle=False,
                                                  isLoop=False,
                                                  isSort=True)
    print('Number of instances in trainDataStream: {}'.format(
        trainDataStream.get_num_instance()))
    print('Number of instances in devDataStream: {}'.format(
        devDataStream.get_num_instance()))
    print('Number of batches in trainDataStream: {}'.format(
        trainDataStream.get_num_batch()))
    print('Number of batches in devDataStream: {}'.format(
        devDataStream.get_num_batch()))
    if ftset != None:
        ftDataStream = G2S_data_stream.G2SDataStream(ftset,
                                                     word_vocab,
                                                     char_vocab,
                                                     edgelabel_vocab,
                                                     options=FLAGS,
                                                     isShuffle=True,
                                                     isLoop=True,
                                                     isSort=True)
        print('Number of instances in ftDataStream: {}'.format(
            ftDataStream.get_num_instance()))
        print('Number of batches in ftDataStream: {}'.format(
            ftDataStream.get_num_batch()))

    sys.stdout.flush()

    # initialize the best bleu and accu scores for current training session
    best_accu = FLAGS.best_accu if FLAGS.__dict__.has_key('best_accu') else 0.0
    best_bleu = FLAGS.best_bleu if FLAGS.__dict__.has_key('best_bleu') else 0.0
    if best_accu > 0.0:
        print('With initial dev accuracy {}'.format(best_accu))
    if best_bleu > 0.0:
        print('With initial dev BLEU score {}'.format(best_bleu))

    init_scale = 0.01
    with tf.Graph().as_default():
        initializer = tf.random_uniform_initializer(-init_scale, init_scale)
        with tf.name_scope("Train"):
            with tf.variable_scope("Model",
                                   reuse=None,
                                   initializer=initializer):
                train_graph = ModelGraph(word_vocab=word_vocab,
                                         Edgelabel_vocab=edgelabel_vocab,
                                         char_vocab=char_vocab,
                                         options=FLAGS,
                                         mode=FLAGS.mode)

        assert FLAGS.mode in ('ce_train', 'rl_train', 'transformer')
        valid_mode = 'evaluate' if FLAGS.mode in (
            'ce_train', 'transformer') else 'evaluate_bleu'

        with tf.name_scope("Valid"):
            with tf.variable_scope("Model",
                                   reuse=True,
                                   initializer=initializer):
                valid_graph = ModelGraph(word_vocab=word_vocab,
                                         Edgelabel_vocab=edgelabel_vocab,
                                         char_vocab=char_vocab,
                                         options=FLAGS,
                                         mode=valid_mode)

        initializer = tf.global_variables_initializer()

        vars_ = {}
        for var in tf.all_variables():
            if FLAGS.fix_word_vec and "word_embedding" in var.name: continue
            if not var.name.startswith("Model"): continue
            vars_[var.name.split(":")[0]] = var
        saver = tf.train.Saver(vars_)

        sess = tf.Session()
        sess.run(initializer)
        if has_pretrained_model:
            print("Restoring model from " + best_path)
            saver.restore(sess, best_path)
            print("DONE!")

            if FLAGS.mode == 'rl_train' and abs(best_bleu) < 0.00001:
                print("Getting BLEU score for the model")
                sys.stdout.flush()
                best_bleu = evaluate(sess,
                                     valid_graph,
                                     devDataStream,
                                     options=FLAGS)['dev_bleu']
                FLAGS.best_bleu = best_bleu
                namespace_utils.save_namespace(FLAGS,
                                               path_prefix + ".config.json")
                print('BLEU = %.4f' % best_bleu)
                sys.stdout.flush()
                log_file.write('BLEU = %.4f\n' % best_bleu)
            if FLAGS.mode in ('ce_train', 'rl_train',
                              'transformer') and abs(best_accu) < 0.00001:
                print("Getting ACCU score for the model")
                best_accu = evaluate(sess,
                                     valid_graph,
                                     devDataStream,
                                     options=FLAGS)['dev_accu']
                FLAGS.best_accu = best_accu
                namespace_utils.save_namespace(FLAGS,
                                               path_prefix + ".config.json")
                print('ACCU = %.4f' % best_accu)
                log_file.write('ACCU = %.4f\n' % best_accu)

        print('Start the training loop.')
        train_size = trainDataStream.get_num_batch()
        max_steps = train_size * FLAGS.max_epochs
        total_loss = 0.0
        start_time = time.time()
        for step in xrange(max_steps):
            cur_batch = trainDataStream.nextBatch()
            if FLAGS.mode == 'rl_train':
                loss_value = train_graph.run_rl_training_subsample(
                    sess, cur_batch, FLAGS)
            elif FLAGS.mode in ('ce_train', 'rl_train', 'transformer'):
                loss_value = train_graph.run_ce_training(
                    sess, cur_batch, FLAGS)
            total_loss += loss_value

            if step % 100 == 0:
                print('{} '.format(step), end="")
                sys.stdout.flush()

            # Save a checkpoint and evaluate the model periodically.
            if (step + 1) % trainDataStream.get_num_batch() == 0 or (step + 1) == max_steps or \
                    (trainDataStream.get_num_batch() > 10000 and (step+1)%2000 == 0):
                print()
                duration = time.time() - start_time
                print('Step %d: loss = %.2f (%.3f sec)' %
                      (step, total_loss, duration))
                log_file.write('Step %d: loss = %.2f (%.3f sec)\n' %
                               (step, total_loss, duration))
                log_file.flush()
                sys.stdout.flush()
                total_loss = 0.0

                if ftset != None:
                    best_accu, best_bleu = fine_tune(sess, saver, FLAGS,
                                                     log_file, ftDataStream,
                                                     devDataStream,
                                                     train_graph, valid_graph,
                                                     path_prefix, best_accu,
                                                     best_bleu)
                else:
                    best_accu, best_bleu = validate_and_save(
                        sess, saver, FLAGS, log_file, devDataStream,
                        valid_graph, path_prefix, best_accu, best_bleu)
                start_time = time.time()

    log_file.close()
Example #7
def main(_):
    print('Configurations:')
    print(FLAGS)

    log_dir = FLAGS.model_dir
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)

    path_prefix = log_dir + "/G2S.{}".format(FLAGS.suffix)
    log_file_path = path_prefix + ".log"
    print('Log file path: {}'.format(log_file_path))
    log_file = open(log_file_path, 'wt')
    log_file.write("{}\n".format(FLAGS))
    log_file.flush()

    # save configuration
    namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json")

    print('Loading train set.')
    if FLAGS.infile_format == 'fof':
        trainset, trn_node, trn_in_neigh, trn_out_neigh, trn_sent = G2S_data_stream.read_nary_from_fof(
            FLAGS.train_path, FLAGS)
    else:
        trainset, trn_node, trn_in_neigh, trn_out_neigh, trn_sent = G2S_data_stream.read_nary_file(
            FLAGS.train_path, FLAGS)

    random.shuffle(trainset)
    devset = trainset[:200]
    trainset = trainset[200:]

    print('Number of training samples: {}'.format(len(trainset)))
    print('Number of dev samples: {}'.format(len(devset)))

    max_node = trn_node
    max_in_neigh = trn_in_neigh
    max_out_neigh = trn_out_neigh
    max_sent = trn_sent
    print('Max node number: {}, while max allowed is {}'.format(
        max_node, FLAGS.max_node_num))
    print('Max parent number: {}, truncated to {}'.format(
        max_in_neigh, FLAGS.max_in_neigh_num))
    print('Max children number: {}, truncated to {}'.format(
        max_out_neigh, FLAGS.max_out_neigh_num))
    print('Max entity size: {}, truncated to {}'.format(
        max_sent, FLAGS.max_entity_size))

    word_vocab = None
    char_vocab = None
    edgelabel_vocab = None
    has_pretrained_model = False
    best_path = path_prefix + ".best.model"
    if os.path.exists(best_path + ".index"):
        has_pretrained_model = True
        print('!!Existing pretrained model. Loading vocabs.')
        word_vocab = Vocab(FLAGS.word_vec_path, fileformat='txt2')
        print('word_vocab: {}'.format(word_vocab.word_vecs.shape))
        char_vocab = None
        if FLAGS.with_char:
            char_vocab = Vocab(path_prefix + ".char_vocab", fileformat='txt2')
            print('char_vocab: {}'.format(char_vocab.word_vecs.shape))
        edgelabel_vocab = Vocab(path_prefix + ".edgelabel_vocab",
                                fileformat='txt2')
    else:
        print('Collecting vocabs.')
        (allWords, allChars,
         allEdgelabels) = G2S_data_stream.collect_vocabs(trainset)
        print('Number of words: {}'.format(len(allWords)))
        print('Number of allChars: {}'.format(len(allChars)))
        print('Number of allEdgelabels: {}'.format(len(allEdgelabels)))

        word_vocab = Vocab(FLAGS.word_vec_path, fileformat='txt2')
        char_vocab = None
        if FLAGS.with_char:
            char_vocab = Vocab(voc=allChars,
                               dim=FLAGS.char_dim,
                               fileformat='build')
            char_vocab.dump_to_txt2(path_prefix + ".char_vocab")
        edgelabel_vocab = Vocab(voc=allEdgelabels,
                                dim=FLAGS.edgelabel_dim,
                                fileformat='build')
        edgelabel_vocab.dump_to_txt2(path_prefix + ".edgelabel_vocab")

    print('word vocab size {}'.format(word_vocab.vocab_size))
    sys.stdout.flush()

    print('Build DataStream ... ')
    trainDataStream = G2S_data_stream.G2SDataStream(trainset,
                                                    word_vocab,
                                                    char_vocab,
                                                    edgelabel_vocab,
                                                    options=FLAGS,
                                                    isShuffle=True,
                                                    isLoop=True,
                                                    isSort=False)

    devDataStream = G2S_data_stream.G2SDataStream(devset,
                                                  word_vocab,
                                                  char_vocab,
                                                  edgelabel_vocab,
                                                  options=FLAGS,
                                                  isShuffle=False,
                                                  isLoop=False,
                                                  isSort=False)
    print('Number of instances in trainDataStream: {}'.format(
        trainDataStream.get_num_instance()))
    print('Number of instances in devDataStream: {}'.format(
        devDataStream.get_num_instance()))
    print('Number of batches in trainDataStream: {}'.format(
        trainDataStream.get_num_batch()))
    print('Number of batches in devDataStream: {}'.format(
        devDataStream.get_num_batch()))
    sys.stdout.flush()

    # initialize the best bleu and accu scores for current training session
    best_accu = FLAGS.best_accu if FLAGS.__dict__.has_key('best_accu') else 0.0
    if best_accu > 0.0:
        print('With initial dev accuracy {}'.format(best_accu))

    init_scale = 0.01
    with tf.Graph().as_default():
        initializer = tf.random_uniform_initializer(-init_scale, init_scale)
        with tf.name_scope("Train"):
            with tf.variable_scope("Model",
                                   reuse=None,
                                   initializer=initializer):
                train_graph = ModelGraph(word_vocab=word_vocab,
                                         Edgelabel_vocab=edgelabel_vocab,
                                         char_vocab=char_vocab,
                                         options=FLAGS,
                                         mode='train')

        with tf.name_scope("Valid"):
            with tf.variable_scope("Model",
                                   reuse=True,
                                   initializer=initializer):
                valid_graph = ModelGraph(word_vocab=word_vocab,
                                         Edgelabel_vocab=edgelabel_vocab,
                                         char_vocab=char_vocab,
                                         options=FLAGS,
                                         mode='evaluate')

        initializer = tf.global_variables_initializer()

        vars_ = {}
        for var in tf.all_variables():
            if "word_embedding" in var.name: continue
            if not var.name.startswith("Model"): continue
            vars_[var.name.split(":")[0]] = var
        saver = tf.train.Saver(vars_)

        sess = tf.Session()
        sess.run(initializer)
        if has_pretrained_model:
            print("Restoring model from " + best_path)
            saver.restore(sess, best_path)
            print("DONE!")

            if abs(best_accu) < 0.00001:
                print("Getting ACCU score for the model")
                best_accu = evaluate(sess,
                                     valid_graph,
                                     devDataStream,
                                     options=FLAGS)['dev_accu']
                FLAGS.best_accu = best_accu
                namespace_utils.save_namespace(FLAGS,
                                               path_prefix + ".config.json")
                print('ACCU = %.4f' % best_accu)
                log_file.write('ACCU = %.4f\n' % best_accu)

        print('Start the training loop.')
        train_size = trainDataStream.get_num_batch()
        max_steps = train_size * FLAGS.max_epochs
        last_step = 0
        total_loss = 0.0
        start_time = time.time()
        for step in xrange(max_steps):
            cur_batch = trainDataStream.nextBatch()
            _, loss_value, _ = train_graph.execute(sess,
                                                   cur_batch,
                                                   FLAGS,
                                                   is_train=True)
            total_loss += loss_value

            if step % 100 == 0:
                print('{} '.format(step), end="")
                sys.stdout.flush()

            # Save a checkpoint and evaluate the model periodically.
            if (step + 1) % trainDataStream.get_num_batch() == 0 or (
                    step + 1) == max_steps:
                print()
                duration = time.time() - start_time
                print('Step %d: loss = %.2f (%.3f sec)' %
                      (step, total_loss / (step - last_step), duration))
                log_file.write('Step %d: loss = %.2f (%.3f sec)\n' %
                               (step, total_loss /
                                (step - last_step), duration))
                sys.stdout.flush()
                log_file.flush()
                last_step = step
                total_loss = 0.0

                # Evaluate against the validation set.
                start_time = time.time()
                print('Validation Data Eval:')
                res_dict = evaluate(sess,
                                    valid_graph,
                                    devDataStream,
                                    options=FLAGS,
                                    suffix=str(step))
                dev_loss = res_dict['dev_loss']
                dev_accu = res_dict['dev_accu']
                dev_right = int(res_dict['dev_right'])
                dev_total = int(res_dict['dev_total'])
                print('Dev loss = %.4f' % dev_loss)
                log_file.write('Dev loss = %.4f\n' % dev_loss)
                print('Dev accu = %.4f %d/%d' %
                      (dev_accu, dev_right, dev_total))
                log_file.write('Dev accu = %.4f %d/%d\n' %
                               (dev_accu, dev_right, dev_total))
                log_file.flush()
                if best_accu < dev_accu:
                    print('Saving weights, ACCU {} (prev_best) < {} (cur)'.
                          format(best_accu, dev_accu))
                    saver.save(sess, best_path)
                    best_accu = dev_accu
                    FLAGS.best_accu = dev_accu
                    namespace_utils.save_namespace(
                        FLAGS, path_prefix + ".config.json")
                    json.dump(res_dict['data'], open(FLAGS.output_path, 'w'))
                duration = time.time() - start_time
                print('Duration %.3f sec' % (duration))
                sys.stdout.flush()

                log_file.write('Duration %.3f sec\n' % (duration))
                log_file.flush()

    log_file.close()
    char_vocab = None
    POS_vocab = None
    if FLAGS.with_char:
        char_vocab = Vocab(model_prefix + ".char_vocab", fileformat='txt2')
        print('char_vocab: {}'.format(char_vocab.word_vecs.shape))
    if FLAGS.with_POS:
        POS_vocab = Vocab(model_prefix + ".POS_vocab", fileformat='txt2')
        print('POS_vocab: {}'.format(POS_vocab.word_vecs.shape))
    edgelabel_vocab = Vocab(model_prefix + ".edgelabel_vocab",
                            fileformat='txt2')
    print('edgelabel_vocab: {}'.format(edgelabel_vocab.word_vecs.shape))

    print('Loading test set from {}.'.format(in_path))
    if not hasattr(FLAGS, 'num_relations'):
        FLAGS.num_relations = 2
    testset = G2S_data_stream.read_bionlp_file(in_path, in_dep_path, FLAGS)
    print('Number of samples: {}'.format(len(testset)))

    print('Build DataStream ... ')
    testDataStream = G2S_data_stream.G2SDataStream(FLAGS,
                                                   testset,
                                                   word_vocab,
                                                   char_vocab,
                                                   POS_vocab,
                                                   edgelabel_vocab,
                                                   isShuffle=False,
                                                   isLoop=False,
                                                   isSort=False)
    print('Number of instances in testDataStream: {}'.format(
        testDataStream.get_num_instance()))
    print('Number of batches in testDataStream: {}'.format(
        testDataStream.get_num_batch()))
Example #9
    print('Loading vocabs.')
    word_vocab_enc = Vocab(FLAGS.word_vec_src_path, fileformat='txt2')
    print('word_vocab SRC: {}'.format(word_vocab_enc.word_vecs.shape))
    word_vocab_dec = Vocab(FLAGS.word_vec_tgt_path, fileformat='txt2')
    print('word_vocab TGT: {}'.format(word_vocab_dec.word_vecs.shape))
    edgelabel_vocab = Vocab(model_prefix + ".edgelabel_vocab", fileformat='txt2')
    print('edgelabel_vocab: {}'.format(edgelabel_vocab.word_vecs.shape))
    char_vocab = None
    if FLAGS.with_char:
        char_vocab = Vocab(model_prefix + ".char_vocab", fileformat='txt2')
        print('char_vocab: {}'.format(char_vocab.word_vecs.shape))


    print('Loading test set from {}.'.format(in_path))
    if FLAGS.infile_format == 'fof':
        testset, _, _, _, _ = G2S_data_stream.read_amr_from_fof(in_path, FLAGS,
                word_vocab_enc, word_vocab_dec, char_vocab, edgelabel_vocab)
    else:
        testset, _, _, _, _ = G2S_data_stream.read_amr_file(in_path, FLAGS,
                word_vocab_enc, word_vocab_dec, char_vocab, edgelabel_vocab)
    print('Number of samples: {}'.format(len(testset)))

    print('Build DataStream ... ')
    batch_size = -1
    if mode not in ('pointwise', 'multinomial', 'greedy', 'greedy_evaluate', ):
        batch_size = 1

    devDataStream = G2S_data_stream.G2SDataStream(testset, word_vocab_enc, word_vocab_dec, char_vocab, edgelabel_vocab,
            options=FLAGS, isShuffle=False, isLoop=False, isSort=True, batch_size=batch_size)
    print('Number of instances in testDataStream: {}'.format(devDataStream.get_num_instance()))
    print('Number of batches in testDataStream: {}'.format(devDataStream.get_num_batch()))
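Example #9 stops after building the test stream. Below is a hedged sketch of how such a decoding script typically restores the checkpoint that the trainers above save as <path_prefix>.best.model; the graph-construction details and the decoding mode value are assumptions, and the Saver is built the same way as in the training examples.

best_path = model_prefix + ".best.model"
with tf.variable_scope("Model", reuse=None):
    valid_graph = ModelGraph(word_vocab_enc=word_vocab_enc, word_vocab_dec=word_vocab_dec,
                             Edgelabel_vocab=edgelabel_vocab, char_vocab=char_vocab,
                             options=FLAGS, mode=mode)  # e.g. mode='greedy' (assumed)

# Restore only the "Model" variables, mirroring how the training scripts build their Saver.
vars_ = {}
for var in tf.all_variables():
    if not var.name.startswith("Model"): continue
    vars_[var.name.split(":")[0]] = var
saver = tf.train.Saver(vars_)

sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
sess.run(tf.global_variables_initializer())
print("Restoring model from " + best_path)
saver.restore(sess, best_path)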
Example #10
def main(_):
    log_dir = FLAGS.model_dir
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)

    path_prefix = log_dir + "/G2S.{}".format(FLAGS.suffix)
    log_file_path = path_prefix + ".log"
    print('Log file path: {}'.format(log_file_path))
    log_file = open(log_file_path, 'wt')
    log_file.write("{}\n".format(FLAGS))
    log_file.flush()

    # save configuration
    namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json")

    print('Loading data.')
    FLAGS.num_relations = 2
    trainset = G2S_data_stream.read_bionlp_file(FLAGS.train_path, FLAGS.train_dep_path, FLAGS)
    if FLAGS.dev_gen == 'shuffle':
        random.shuffle(trainset)
    elif FLAGS.dev_gen == 'last':
        trainset.reverse()
    N = int(len(trainset)*FLAGS.dev_percent)
    devset = trainset[:N]
    trainset = trainset[N:]

    print('Number of training samples: {}'.format(len(trainset)))
    print('Number of dev samples: {}'.format(len(devset)))
    print('Number of relations: {}'.format(FLAGS.num_relations))

    word_vocab = None
    char_vocab = None
    POS_vocab = None
    edgelabel_vocab = None
    has_pretrained_model = False
    best_path = path_prefix + ".best.model"
    if os.path.exists(best_path + ".index"):
        has_pretrained_model = True
        print('!!Existing pretrained model. Loading vocabs.')
        word_vocab = Vocab(FLAGS.word_vec_path, fileformat='txt2')
        print('word_vocab: {}'.format(word_vocab.word_vecs.shape))
        if FLAGS.with_char:
            char_vocab = Vocab(path_prefix + ".char_vocab", fileformat='txt2')
            print('char_vocab: {}'.format(char_vocab.word_vecs.shape))
        if FLAGS.with_POS:
            POS_vocab = Vocab(path_prefix + ".POS_vocab", fileformat='txt2')
            print('POS_vocab: {}'.format(POS_vocab.word_vecs.shape))
        edgelabel_vocab = Vocab(path_prefix + ".edgelabel_vocab", fileformat='txt2')
        print('edgelabel_vocab: {}'.format(edgelabel_vocab.word_vecs.shape))
    else:
        print('Collecting vocabs.')
        all_words = set()
        all_chars = set()
        all_poses = set()
        all_edgelabels = set()
        G2S_data_stream.collect_vocabs(trainset, all_words, all_chars, all_poses, all_edgelabels)
        G2S_data_stream.collect_vocabs(devset, all_words, all_chars, all_poses, all_edgelabels)
        print('Number of words: {}'.format(len(all_words)))
        print('Number of chars: {}'.format(len(all_chars)))
        print('Number of poses: {}'.format(len(all_poses)))
        print('Number of edgelabels: {}'.format(len(all_edgelabels)))

        word_vocab = Vocab(FLAGS.word_vec_path, fileformat='txt2')
        if FLAGS.with_char:
            char_vocab = Vocab(voc=all_chars, dim=FLAGS.char_dim, fileformat='build')
            char_vocab.dump_to_txt2(path_prefix + ".char_vocab")
        if FLAGS.with_POS:
            POS_vocab = Vocab(voc=all_poses, dim=FLAGS.POS_dim, fileformat='build')
            POS_vocab.dump_to_txt2(path_prefix + ".POS_vocab")
        edgelabel_vocab = Vocab(voc=all_edgelabels, dim=FLAGS.edgelabel_dim, fileformat='build')
        edgelabel_vocab.dump_to_txt2(path_prefix + ".edgelabel_vocab")

    print('word vocab size {}'.format(word_vocab.vocab_size))
    sys.stdout.flush()

    print('Build DataStream ... ')
    trainDataStream = G2S_data_stream.G2SDataStream(FLAGS, trainset, word_vocab, char_vocab, POS_vocab, edgelabel_vocab,
            isShuffle=True, isLoop=True, isSort=True, is_training=True)

    devDataStream = G2S_data_stream.G2SDataStream(FLAGS, devset, word_vocab, char_vocab, POS_vocab, edgelabel_vocab,
            isShuffle=False, isLoop=False, isSort=True)
    print('Number of instances in trainDataStream: {}'.format(trainDataStream.get_num_instance()))
    print('Number of instances in devDataStream: {}'.format(devDataStream.get_num_instance()))
    print('Number of batches in trainDataStream: {}'.format(trainDataStream.get_num_batch()))
    print('Number of batches in devDataStream: {}'.format(devDataStream.get_num_batch()))
    sys.stdout.flush()

    FLAGS.trn_bch_num = trainDataStream.get_num_batch()

    # initialize the best bleu and accu scores for current training session
    best_accu = FLAGS.best_accu if FLAGS.__dict__.has_key('best_accu') else 0.0
    if best_accu > 0.0:
        print('With initial dev accuracy {}'.format(best_accu))

    init_scale = 0.01
    with tf.Graph().as_default():
        initializer = tf.random_uniform_initializer(-init_scale, init_scale)
        with tf.name_scope("Train"):
            with tf.variable_scope("Model", reuse=None, initializer=initializer):
                train_graph = ModelGraph(word_vocab, char_vocab, POS_vocab, edgelabel_vocab,
                                         FLAGS, mode='train')

        with tf.name_scope("Valid"):
            with tf.variable_scope("Model", reuse=True, initializer=initializer):
                valid_graph = ModelGraph(word_vocab, char_vocab, POS_vocab, edgelabel_vocab,
                                         FLAGS, mode='evaluate')

        initializer = tf.global_variables_initializer()

        vars_ = {}
        for var in tf.all_variables():
            if FLAGS.fix_word_vec and "word_embedding" in var.name: continue
            if not var.name.startswith("Model"): continue
            vars_[var.name.split(":")[0]] = var
        saver = tf.train.Saver(vars_)

        sess = tf.Session()
        sess.run(initializer)
        if has_pretrained_model:
            print("Restoring model from " + best_path)
            saver.restore(sess, best_path)
            print("DONE!")

            if abs(best_accu) < 1e-5:
                print("Getting ACCU score for the model")
                best_accu = evaluate(sess, valid_graph, devDataStream, FLAGS)['dev_f1']
                FLAGS.best_accu = best_accu
                namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json")
                print('ACCU = %.4f' % best_accu)
                log_file.write('ACCU = %.4f\n' % best_accu)

        print('Start the training loop.')
        train_size = trainDataStream.get_num_batch()
        max_steps = train_size * FLAGS.max_epochs
        last_step = 0
        total_loss = 0.0
        start_time = time.time()
        for step in xrange(max_steps):
            cur_batch = trainDataStream.nextBatch()
            _, _, cur_loss, _ = train_graph.execute(sess, cur_batch, FLAGS, is_train=True)
            total_loss += cur_loss

            if step % 100 == 0:
                print('{} '.format(step), end="")
                sys.stdout.flush()

            # Save a checkpoint and evaluate the model periodically.
            if (step + 1) % trainDataStream.get_num_batch() == 0 or (step + 1) == max_steps:
                print()
                duration = time.time() - start_time
                print('Step %d: loss = %.2f (%.3f sec)' % (step, total_loss/(step-last_step), duration))
                log_file.write('Step %d: loss = %.2f (%.3f sec)\n' % (step, total_loss/(step-last_step), duration))
                sys.stdout.flush()
                log_file.flush()
                last_step = step
                total_loss = 0.0

                # Evaluate against the validation set.
                start_time = time.time()
                print('Validation Data Eval:')
                res_dict = evaluate(sess, valid_graph, devDataStream, FLAGS)
                dev_loss = res_dict['dev_loss']
                dev_accu = res_dict['dev_f1']
                dev_precision = res_dict['dev_precision']
                dev_recall = res_dict['dev_recall']
                print('Dev loss = %.4f' % dev_loss)
                log_file.write('Dev loss = %.4f\n' % dev_loss)
                print('Dev F1 = %.4f, P = %.4f, R = %.4f' % (dev_accu, dev_precision, dev_recall))
                log_file.write('Dev F1 = %.4f, P = %.4f, R = %.4f\n' % (dev_accu, dev_precision, dev_recall))
                log_file.flush()
                if best_accu < dev_accu:
                    print('Saving weights, ACCU {} (prev_best) < {} (cur)'.format(best_accu, dev_accu))
                    saver.save(sess, best_path)
                    best_accu = dev_accu
                    FLAGS.best_accu = dev_accu
                    namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json")
                duration = time.time() - start_time
                print('Duration %.3f sec' % (duration))
                sys.stdout.flush()

                log_file.write('Duration %.3f sec\n' % (duration))
                log_file.flush()
                start_time = time.time()

    log_file.close()
Example #11
    # load vocabs
    print('Loading vocabs.')
    word_vocab = Vocab(FLAGS.word_vec_path, fileformat='txt2')
    print('word_vocab: {}'.format(word_vocab.word_vecs.shape))
    edgelabel_vocab = Vocab(model_prefix + ".edgelabel_vocab",
                            fileformat='txt2')
    print('edgelabel_vocab: {}'.format(edgelabel_vocab.word_vecs.shape))
    char_vocab = None
    if FLAGS.with_char:
        char_vocab = Vocab(model_prefix + ".char_vocab", fileformat='txt2')
        print('char_vocab: {}'.format(char_vocab.word_vecs.shape))

    print('Loading test set from {}.'.format(in_path))
    if FLAGS.infile_format == 'fof':
        testset = G2S_data_stream.read_nary_from_fof(in_path,
                                                     FLAGS,
                                                     is_rev=False)
        testset_rev = G2S_data_stream.read_nary_from_fof(in_path,
                                                         FLAGS,
                                                         is_rev=True)
    else:
        testset = G2S_data_stream.read_nary_file(in_path, FLAGS, is_rev=False)
        testset_rev = G2S_data_stream.read_nary_file(in_path,
                                                     FLAGS,
                                                     is_rev=True)
    print('Number of samples: {}'.format(len(testset)))

    print('Build DataStream ... ')
    batch_size = -1
    devDataStream = G2S_data_stream.G2SDataStream(testset,
                                                  word_vocab,
Example #12
    # load the configuration file
    print('Loading configurations from ' + model_prefix + ".config.json")
    FLAGS = namespace_utils.load_namespace(model_prefix + ".config.json")
    FLAGS = G2S_trainer.enrich_options(FLAGS)

    # load vocabs
    print('Loading vocabs.')
    word_vocab_enc = Vocab(FLAGS.word_vec_src_path, fileformat='txt2')
    print('word_vocab SRC: {}'.format(word_vocab_enc.word_vecs.shape))
    word_vocab_dec = Vocab(FLAGS.word_vec_tgt_path, fileformat='txt2')
    print('word_vocab TGT: {}'.format(word_vocab_dec.word_vecs.shape))

    print('Loading test set from {}.'.format(in_path))
    if FLAGS.infile_format == 'fof':
        testset, _, _, _ = G2S_data_stream.read_amr_from_fof(in_path)
    else:
        testset, _, _, _ = G2S_data_stream.read_amr_file(in_path)
    print('Number of samples: {}'.format(len(testset)))

    print('Build DataStream ... ')
    batch_size = -1
    if mode not in (
            'pointwise',
            'multinomial',
            'greedy',
            'greedy_evaluate',
    ):
        batch_size = 1

    devDataStream = G2S_data_stream.G2SDataStream(testset,