def _read_dataset(data_path):
    """Read one dataset with the reader selected by ``FLAGS.infile_format``.

    ``'fof'`` and ``'plain'`` select the two generation-dataset readers; any
    other value falls back to the GQA question reader, which also receives
    ``FLAGS.switch_qa``.

    Returns:
        The ``(instances, answer_length)`` pair returned by the reader.
    """
    if FLAGS.infile_format == 'fof':
        return NP2P_data_stream.read_generation_datasets_from_fof(
            data_path, isLower=FLAGS.isLower)
    if FLAGS.infile_format == 'plain':
        return NP2P_data_stream.read_all_GenerationDatasets(
            data_path, isLower=FLAGS.isLower)
    return NP2P_data_stream.read_all_GQA_questions(
        data_path, isLower=FLAGS.isLower, switch=FLAGS.switch_qa)


def _load_or_build_vocabs(path_prefix, trainset, has_pretrained_model):
    """Return ``(word_vocab, char_vocab, POS_vocab, NER_vocab)``.

    Each vocab is ``None`` unless its ``FLAGS.with_*`` switch is on. The word
    vocab is always read from ``FLAGS.word_vec_path``. When a pretrained model
    exists, the char/POS/NER vocabs are re-read from
    ``<path_prefix>.char_vocab`` / ``.POS_vocab`` / ``.NER_vocab``; otherwise
    they are built from ``trainset`` and dumped to those same files.
    """
    word_vocab = char_vocab = POS_vocab = NER_vocab = None
    if has_pretrained_model:
        print('!!Existing pretrained model. Loading vocabs.')
        if FLAGS.with_word:
            word_vocab = Vocab(FLAGS.word_vec_path, fileformat='txt2')
            print('word_vocab: {}'.format(word_vocab.word_vecs.shape))
        if FLAGS.with_char:
            char_vocab = Vocab(path_prefix + ".char_vocab", fileformat='txt2')
            print('char_vocab: {}'.format(char_vocab.word_vecs.shape))
        if FLAGS.with_POS:
            POS_vocab = Vocab(path_prefix + ".POS_vocab", fileformat='txt2')
            print('POS_vocab: {}'.format(POS_vocab.word_vecs.shape))
        if FLAGS.with_NER:
            NER_vocab = Vocab(path_prefix + ".NER_vocab", fileformat='txt2')
            print('NER_vocab: {}'.format(NER_vocab.word_vecs.shape))
    else:
        print('Collecting vocabs.')
        (allWords, allChars, allPOSs, allNERs) = NP2P_data_stream.collect_vocabs(trainset)
        print('Number of words: {}'.format(len(allWords)))
        print('Number of allChars: {}'.format(len(allChars)))
        print('Number of allPOSs: {}'.format(len(allPOSs)))
        print('Number of allNERs: {}'.format(len(allNERs)))

        if FLAGS.with_word:
            word_vocab = Vocab(FLAGS.word_vec_path, fileformat='txt2')
        if FLAGS.with_char:
            char_vocab = Vocab(voc=allChars, dim=FLAGS.char_dim, fileformat='build')
            char_vocab.dump_to_txt2(path_prefix + ".char_vocab")
        if FLAGS.with_POS:
            POS_vocab = Vocab(voc=allPOSs, dim=FLAGS.POS_dim, fileformat='build')
            POS_vocab.dump_to_txt2(path_prefix + ".POS_vocab")
        if FLAGS.with_NER:
            NER_vocab = Vocab(voc=allNERs, dim=FLAGS.NER_dim, fileformat='build')
            NER_vocab.dump_to_txt2(path_prefix + ".NER_vocab")
    return word_vocab, char_vocab, POS_vocab, NER_vocab


def _validate_and_save(sess, saver, valid_graph, devDataStream, log_file,
                       path_prefix, step, best_accu, best_bleu):
    """Evaluate on the dev stream and checkpoint the model if it improved.

    When ``valid_graph.mode == 'evaluate'`` the dev accuracy is compared with
    ``best_accu``; otherwise the dev BLEU is compared with ``best_bleu``. On
    improvement the weights are saved to ``<path_prefix>.best.model`` and the
    new best score is stored on ``FLAGS`` and written to
    ``<path_prefix>.config.json``. Results and timing go to stdout and
    ``log_file``.

    Returns:
        The possibly-updated ``(best_accu, best_bleu)`` pair.
    """
    start_time = time.time()
    print('Validation Data Eval:')
    res_dict = evaluate(sess, valid_graph, devDataStream, options=FLAGS,
                        suffix=str(step))
    if valid_graph.mode == 'evaluate':
        dev_loss = res_dict['dev_loss']
        dev_accu = res_dict['dev_accu']
        dev_right = int(res_dict['dev_right'])
        dev_total = int(res_dict['dev_total'])
        print('Dev loss = %.4f' % dev_loss)
        log_file.write('Dev loss = %.4f\n' % dev_loss)
        print('Dev accu = %.4f %d/%d' % (dev_accu, dev_right, dev_total))
        log_file.write('Dev accu = %.4f %d/%d\n' % (dev_accu, dev_right, dev_total))
        log_file.flush()
        if best_accu < dev_accu:
            print('Saving weights, ACCU {} (prev_best) < {} (cur)'.format(
                best_accu, dev_accu))
            saver.save(sess, path_prefix + ".best.model")
            best_accu = dev_accu
            FLAGS.best_accu = dev_accu
            namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json")
    else:
        dev_bleu = res_dict['dev_bleu']
        print('Dev bleu = %.4f' % dev_bleu)
        log_file.write('Dev bleu = %.4f\n' % dev_bleu)
        log_file.flush()
        if best_bleu < dev_bleu:
            print('Saving weights, BLEU {} (prev_best) < {} (cur)'.format(
                best_bleu, dev_bleu))
            saver.save(sess, path_prefix + ".best.model")
            best_bleu = dev_bleu
            FLAGS.best_bleu = dev_bleu
            namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json")
    duration = time.time() - start_time
    print('Duration %.3f sec' % (duration))
    sys.stdout.flush()
    log_file.write('Duration %.3f sec\n' % (duration))
    log_file.flush()
    return best_accu, best_bleu


def main(_):
    """Train the model configured by ``FLAGS`` and keep the best checkpoint.

    All settings come from the module-level ``FLAGS`` namespace. The log is
    written to ``<model_dir>/NP2P.<suffix>.log``. The configuration, including
    the best dev score reached so far, is saved to ``<...>.config.json``. The
    best weights are saved to ``<...>.best.model``. If that checkpoint already
    exists, the saved char/POS/NER vocabularies are reloaded and the weights
    restored before training.

    Training runs ``FLAGS.max_epochs`` passes over the training stream in
    ``ce_train`` or ``rl_train`` mode. After every epoch (and at the last
    step) the dev set is evaluated: accuracy for ``ce_train``, BLEU for
    ``rl_train``. The weights are saved whenever that score improves.

    The positional argument is unused.
    """
    print('Configurations:')
    print(FLAGS)

    log_dir = FLAGS.model_dir
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    path_prefix = log_dir + "/NP2P.{}".format(FLAGS.suffix)
    config_path = path_prefix + ".config.json"
    log_file_path = path_prefix + ".log"
    print('Log file path: {}'.format(log_file_path))
    # `with` closes the log even if loading or training raises.
    with open(log_file_path, 'wt') as log_file:
        log_file.write("{}\n".format(FLAGS))
        log_file.flush()
        # save configuration
        namespace_utils.save_namespace(FLAGS, config_path)

        print('Loading train set.')
        trainset, train_ans_len = _read_dataset(FLAGS.train_path)
        print('Number of training samples: {}'.format(len(trainset)))

        print('Loading test set.')
        testset, test_ans_len = _read_dataset(FLAGS.test_path)
        print('Number of test samples: {}'.format(len(testset)))

        max_actual_len = max(train_ans_len, test_ans_len)
        print('Max answer length: {}, truncated to {}'.format(
            max_actual_len, FLAGS.max_answer_len))

        best_path = path_prefix + ".best.model"
        # The checkpoint's ".index" file marks a previously saved best model.
        has_pretrained_model = os.path.exists(best_path + ".index")
        (word_vocab, char_vocab, POS_vocab, NER_vocab) = _load_or_build_vocabs(
            path_prefix, trainset, has_pretrained_model)
        # word_vocab is None when FLAGS.with_word is off.
        if word_vocab is not None:
            print('word vocab size {}'.format(word_vocab.vocab_size))
        sys.stdout.flush()

        print('Build DataStream ... ')
        trainDataStream = NP2P_data_stream.QADataStream(
            trainset, word_vocab, char_vocab, POS_vocab, NER_vocab, options=FLAGS,
            isShuffle=True, isLoop=True, isSort=True)
        devDataStream = NP2P_data_stream.QADataStream(
            testset, word_vocab, char_vocab, POS_vocab, NER_vocab, options=FLAGS,
            isShuffle=False, isLoop=False, isSort=True)
        print('Number of instances in trainDataStream: {}'.format(
            trainDataStream.get_num_instance()))
        print('Number of instances in devDataStream: {}'.format(
            devDataStream.get_num_instance()))
        print('Number of batches in trainDataStream: {}'.format(
            trainDataStream.get_num_batch()))
        print('Number of batches in devDataStream: {}'.format(
            devDataStream.get_num_batch()))
        sys.stdout.flush()

        init_scale = 0.01
        # initialize the best bleu and accu scores for current training session
        best_accu = FLAGS.__dict__.get('best_accu', 0.0)
        best_bleu = FLAGS.__dict__.get('best_bleu', 0.0)
        if best_accu > 0.0:
            print('With initial dev accuracy {}'.format(best_accu))
        if best_bleu > 0.0:
            print('With initial dev BLEU score {}'.format(best_bleu))

        with tf.Graph().as_default():
            initializer = tf.random_uniform_initializer(-init_scale, init_scale)
            with tf.name_scope("Train"):
                with tf.variable_scope("Model", reuse=None, initializer=initializer):
                    train_graph = ModelGraph(word_vocab=word_vocab, char_vocab=char_vocab,
                                             POS_vocab=POS_vocab, NER_vocab=NER_vocab,
                                             options=FLAGS, mode=FLAGS.mode)

            assert FLAGS.mode in ('ce_train', 'rl_train', )
            valid_mode = 'evaluate' if FLAGS.mode == 'ce_train' else 'evaluate_bleu'

            # The validation graph reuses the training graph's "Model" variables.
            with tf.name_scope("Valid"):
                with tf.variable_scope("Model", reuse=True, initializer=initializer):
                    valid_graph = ModelGraph(word_vocab=word_vocab, char_vocab=char_vocab,
                                             POS_vocab=POS_vocab, NER_vocab=NER_vocab,
                                             options=FLAGS, mode=valid_mode)

            initializer = tf.global_variables_initializer()

            # Save/restore every "Model" variable except the word embeddings.
            vars_ = {}
            for var in tf.global_variables():
                if "word_embedding" in var.name:
                    continue
                if not var.name.startswith("Model"):
                    continue
                vars_[var.name.split(":")[0]] = var
            saver = tf.train.Saver(vars_)

            # The session is closed when training finishes or fails.
            with tf.Session() as sess:
                sess.run(initializer)
                if has_pretrained_model:
                    print("Restoring model from " + best_path)
                    saver.restore(sess, best_path)
                    print("DONE!")

                    # No stored score for this objective yet: measure the
                    # restored model so that only real improvements get saved.
                    if FLAGS.mode == 'rl_train' and abs(best_bleu) < 0.00001:
                        print("Getting BLEU score for the model")
                        best_bleu = evaluate(sess, valid_graph, devDataStream,
                                             options=FLAGS)['dev_bleu']
                        FLAGS.best_bleu = best_bleu
                        namespace_utils.save_namespace(FLAGS, config_path)
                        print('BLEU = %.4f' % best_bleu)
                        log_file.write('BLEU = %.4f\n' % best_bleu)
                    if FLAGS.mode == 'ce_train' and abs(best_accu) < 0.00001:
                        print("Getting ACCU score for the model")
                        best_accu = evaluate(sess, valid_graph, devDataStream,
                                             options=FLAGS)['dev_accu']
                        FLAGS.best_accu = best_accu
                        namespace_utils.save_namespace(FLAGS, config_path)
                        print('ACCU = %.4f' % best_accu)
                        log_file.write('ACCU = %.4f\n' % best_accu)

                print('Start the training loop.')
                train_size = trainDataStream.get_num_batch()
                max_steps = train_size * FLAGS.max_epochs
                total_loss = 0.0
                start_time = time.time()
                for step in xrange(max_steps):
                    cur_batch = trainDataStream.nextBatch()
                    if FLAGS.mode == 'rl_train':
                        loss_value = train_graph.run_rl_training_2(sess, cur_batch, FLAGS)
                    elif FLAGS.mode == 'ce_train':
                        loss_value = train_graph.run_ce_training(sess, cur_batch, FLAGS)
                    total_loss += loss_value

                    if step % 100 == 0:
                        print('{} '.format(step), end="")
                        sys.stdout.flush()

                    # Save a checkpoint and evaluate the model at the end of
                    # every epoch and at the final step.
                    if (step + 1) % train_size == 0 or (step + 1) == max_steps:
                        print()
                        duration = time.time() - start_time
                        print('Step %d: loss = %.2f (%.3f sec)' % (step, total_loss, duration))
                        log_file.write('Step %d: loss = %.2f (%.3f sec)\n'
                                       % (step, total_loss, duration))
                        log_file.flush()
                        sys.stdout.flush()
                        total_loss = 0.0

                        # Evaluate against the validation set.
                        best_accu, best_bleu = _validate_and_save(
                            sess, saver, valid_graph, devDataStream, log_file,
                            path_prefix, step, best_accu, best_bleu)
                        # Restart the clock so the next reported training
                        # duration does not include validation time.
                        start_time = time.time()
if FLAGS.with_char: char_vocab = Vocab(model_prefix + ".char_vocab", fileformat='txt2') print('char_vocab: {}'.format(char_vocab.word_vecs.shape)) if FLAGS.with_POS: POS_vocab = Vocab(model_prefix + ".POS_vocab", fileformat='txt2') print('POS_vocab: {}'.format(POS_vocab.word_vecs.shape)) if FLAGS.with_NER: NER_vocab = Vocab(model_prefix + ".NER_vocab", fileformat='txt2') print('NER_vocab: {}'.format(NER_vocab.word_vecs.shape)) print('Loading test set.') if FLAGS.infile_format == 'fof': testset, _ = NP2P_data_stream.read_generation_datasets_from_fof( in_path, isLower=FLAGS.isLower) elif FLAGS.infile_format == 'plain': testset, _ = NP2P_data_stream.read_all_GenerationDatasets( in_path, isLower=FLAGS.isLower) else: testset, _ = NP2P_data_stream.read_all_GQA_questions( in_path, isLower=FLAGS.isLower, switch=FLAGS.switch_qa) print('Number of samples: {}'.format(len(testset))) print('Build DataStream ... ') batch_size = -1 if mode not in ( 'pointwise', 'multinomial', 'greedy', 'greedy_evaluate', ): batch_size = 1 devDataStream = NP2P_data_stream.DataStream(testset,
if FLAGS.with_POS: POS_vocab = Vocab(model_prefix + ".POS_vocab", fileformat='txt2') print('POS_vocab: {}'.format(POS_vocab.word_vecs.shape)) action_vocab = Vocab(model_prefix + ".action_vocab", fileformat='txt2') print('action_vocab: {}'.format(action_vocab.word_vecs.shape)) feat_vocab = Vocab(model_prefix + ".feat_vocab", fileformat='txt2') print('feat_vocab: {}'.format(feat_vocab.word_vecs.shape)) print('Loading test set.') if use_dep: testset = NP2P_data_stream.read_Testset(in_path, ulfdep=args.ulf) elif FLAGS.infile_format == 'fof': testset = NP2P_data_stream.read_generation_datasets_from_fof( in_path, isLower=FLAGS.isLower, ulfdep=args.ulf) else: testset = NP2P_data_stream.read_all_GenerationDatasets( in_path, isLower=FLAGS.isLower, ulfdep=args.ulf) print('Number of samples: {}'.format(len(testset))) print('Build DataStream ... ') batch_size = 1 assert batch_size == 1 devDataStream = NP2P_data_stream.DataStream(testset, word_vocab=word_vocab, char_vocab=char_vocab, POS_vocab=POS_vocab, feat_vocab=feat_vocab, action_vocab=action_vocab, options=FLAGS, isShuffle=False, isLoop=False,
all_spans = re.split("\\s+", self.syntaxSpans) for i in xrange(len(all_spans)): cur_span = all_spans[i] items = re.split("-", cur_span) cur_start = int(items[0]) cur_end = int(items[1]) cur_label = items[2] if cur_end - cur_start >= max_chunk_len: continue self.chunk_starts.append(cur_start) self.chunk_ends.append(cur_end) self.chunk_labels.append(cur_label) return (self.chunk_starts, self.chunk_ends, self.chunk_labels) if __name__ == "__main__": import NP2P_data_stream inpath = "/u/zhigwang/zhigwang1/sentence_generation/cnn-dailymail/data/val.json.tok" all_instances, _ = NP2P_data_stream.read_all_GenerationDatasets( inpath, isLower=True) sample_instance = all_instances[0][1] print('Raw text: {}'.format(sample_instance.rawText)) (chunk_starts, chunk_ends, chunk_labels) = sample_instance.collect_all_syntax_chunks(5) for i in xrange(len(chunk_starts)): cur_start = chunk_starts[i] cur_end = chunk_ends[i] cur_label = chunk_labels[i] cur_text = sample_instance.getTokChunk(cur_start, cur_end) print("{}-{}-{}:{}".format(cur_start, cur_end, cur_label, cur_text)) print("DONE!")
def main(_):
    """Train an NP2P generation model, optionally fine-tuning on a second set.

    All settings come from the module-level ``FLAGS`` namespace. The function:

    * loads the train and dev sets, plus the fine-tune set at
      ``FLAGS.finetune_path`` when that option is non-empty;
    * loads the encoder/decoder word vocabs and loads or builds the char vocab;
    * builds train and valid graphs that share the "Model" variable scope;
    * restores ``<model_dir>/NP2P.<suffix>.best.model`` when it exists;
    * trains for ``FLAGS.max_epochs`` epochs.

    Evaluation happens at the end of each epoch, at the final step, and every
    2000 steps when an epoch has more than 10000 batches. At each of these
    points it calls ``fine_tune`` (if a fine-tune set was loaded) or
    ``validate_and_save``; both return the updated best accuracy and BLEU.

    The positional argument is unused.
    """
    print('Configurations:')
    print(FLAGS)

    log_dir = FLAGS.model_dir
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    path_prefix = log_dir + "/NP2P.{}".format(FLAGS.suffix)
    config_path = path_prefix + ".config.json"
    log_file_path = path_prefix + ".log"
    print('Log file path: {}'.format(log_file_path))
    # `with` closes the log even if loading or training raises.
    with open(log_file_path, 'wt') as log_file:
        log_file.write("{}\n".format(FLAGS))
        log_file.flush()
        # save configuration
        namespace_utils.save_namespace(FLAGS, config_path)

        print('Loading training set.')
        trainset, train_ans_len = NP2P_data_stream.read_all_GenerationDatasets(
            FLAGS.train_path, isLower=FLAGS.isLower)
        print('Number of training samples: {}'.format(len(trainset)))

        print('Loading dev set.')
        devset, dev_ans_len = NP2P_data_stream.read_all_GenerationDatasets(
            FLAGS.test_path, isLower=FLAGS.isLower)
        print('Number of dev samples: {}'.format(len(devset)))

        if FLAGS.finetune_path:
            print('Loading finetune set.')
            # Read the same option that gates this branch.
            ftset, ft_ans_len = NP2P_data_stream.read_all_GenerationDatasets(
                FLAGS.finetune_path, isLower=FLAGS.isLower)
            print('Number of finetune samples: {}'.format(len(ftset)))
        else:
            ftset, ft_ans_len = (None, 0)

        max_actual_len = max(train_ans_len, ft_ans_len, dev_ans_len)
        print('Max answer length: {}, truncated to {}'.format(
            max_actual_len, FLAGS.max_answer_len))

        enc_word_vocab = None
        dec_word_vocab = None
        char_vocab = None
        best_path = path_prefix + ".best.model"
        # The checkpoint's ".index" file marks a previously saved best model.
        has_pretrained_model = os.path.exists(best_path + ".index")
        if has_pretrained_model:
            print('!!Existing pretrained model. Loading vocabs.')
            if FLAGS.with_word:
                enc_word_vocab = Vocab(FLAGS.enc_word_vec_path, fileformat='txt2')
                dec_word_vocab = Vocab(FLAGS.dec_word_vec_path, fileformat='txt2')
                print('Encoder word vocab: {}'.format(enc_word_vocab.word_vecs.shape))
                print('Decoder word vocab: {}'.format(dec_word_vocab.word_vecs.shape))
            if FLAGS.with_char:
                char_vocab = Vocab(path_prefix + ".char_vocab", fileformat='txt2')
                print('char_vocab: {}'.format(char_vocab.word_vecs.shape))
        else:
            print('Collecting vocabs.')
            (allWords, allChars) = NP2P_data_stream.collect_vocabs(trainset)
            print('Number of words: {}'.format(len(allWords)))
            print('Number of allChars: {}'.format(len(allChars)))

            if FLAGS.with_word:
                enc_word_vocab = Vocab(FLAGS.enc_word_vec_path, fileformat='txt2')
                dec_word_vocab = Vocab(FLAGS.dec_word_vec_path, fileformat='txt2')
            if FLAGS.with_char:
                char_vocab = Vocab(voc=allChars, dim=FLAGS.char_dim, fileformat='build')
                char_vocab.dump_to_txt2(path_prefix + ".char_vocab")

        # Both word vocabs are None when FLAGS.with_word is off.
        if enc_word_vocab is not None:
            print('Encoder word vocab size {}'.format(enc_word_vocab.vocab_size))
            print('Decoder word vocab size {}'.format(dec_word_vocab.vocab_size))
        sys.stdout.flush()

        print('Build DataStream ... ')
        trainDataStream = NP2P_data_stream.DataStream(
            trainset, enc_word_vocab, dec_word_vocab, char_vocab, options=FLAGS,
            isShuffle=True, isLoop=True, isSort=True)
        devDataStream = NP2P_data_stream.DataStream(
            devset, enc_word_vocab, dec_word_vocab, char_vocab, options=FLAGS,
            isShuffle=False, isLoop=False, isSort=True)
        print('Number of instances in trainDataStream: {}'.format(
            trainDataStream.get_num_instance()))
        print('Number of instances in devDataStream: {}'.format(
            devDataStream.get_num_instance()))
        print('Number of batches in trainDataStream: {}'.format(
            trainDataStream.get_num_batch()))
        print('Number of batches in devDataStream: {}'.format(
            devDataStream.get_num_batch()))
        if ftset is not None:
            ftDataStream = NP2P_data_stream.DataStream(
                ftset, enc_word_vocab, dec_word_vocab, char_vocab, options=FLAGS,
                isShuffle=True, isLoop=True, isSort=True)
            print('Number of instances in ftDataStream: {}'.format(
                ftDataStream.get_num_instance()))
            print('Number of batches in ftDataStream: {}'.format(
                ftDataStream.get_num_batch()))
        sys.stdout.flush()

        init_scale = 0.01
        # initialize the best bleu and accu scores for current training session
        best_accu = FLAGS.__dict__.get('best_accu', 0.0)
        best_bleu = FLAGS.__dict__.get('best_bleu', 0.0)
        if best_accu > 0.0:
            print('With initial dev accuracy {}'.format(best_accu))
        if best_bleu > 0.0:
            print('With initial dev BLEU score {}'.format(best_bleu))

        with tf.Graph().as_default():
            initializer = tf.random_uniform_initializer(-init_scale, init_scale)
            with tf.name_scope("Train"):
                with tf.variable_scope("Model", reuse=None, initializer=initializer):
                    train_graph = ModelGraph(enc_word_vocab=enc_word_vocab,
                                             dec_word_vocab=dec_word_vocab,
                                             char_vocab=char_vocab, POS_vocab=None,
                                             NER_vocab=None, options=FLAGS,
                                             mode=FLAGS.mode)

            assert FLAGS.mode in ('ce_train', 'rl_train', )
            valid_mode = 'evaluate' if FLAGS.mode == 'ce_train' else 'evaluate_bleu'

            # The validation graph reuses the training graph's "Model" variables.
            with tf.name_scope("Valid"):
                with tf.variable_scope("Model", reuse=True, initializer=initializer):
                    valid_graph = ModelGraph(enc_word_vocab=enc_word_vocab,
                                             dec_word_vocab=dec_word_vocab,
                                             char_vocab=char_vocab, POS_vocab=None,
                                             NER_vocab=None, options=FLAGS,
                                             mode=valid_mode)

            initializer = tf.global_variables_initializer()

            # Checkpoint every "Model" variable; word embeddings are left out
            # only when FLAGS.fix_word_vec is set.
            vars_ = {}
            for var in tf.global_variables():
                if FLAGS.fix_word_vec and "word_embedding" in var.name:
                    continue
                if not var.name.startswith("Model"):
                    continue
                print(var)
                vars_[var.name.split(":")[0]] = var
            saver = tf.train.Saver(vars_)

            # The session is closed when training finishes or fails.
            with tf.Session() as sess:
                sess.run(initializer)
                if has_pretrained_model:
                    print("Restoring model from " + best_path)
                    saver.restore(sess, best_path)
                    print("DONE!")

                    # No stored score for this objective yet: measure the
                    # restored model so that only real improvements get saved.
                    if FLAGS.mode == 'rl_train' and abs(best_bleu) < 0.00001:
                        print("Getting BLEU score for the model")
                        best_bleu = evaluate(sess, valid_graph, devDataStream,
                                             options=FLAGS)['dev_bleu']
                        FLAGS.best_bleu = best_bleu
                        namespace_utils.save_namespace(FLAGS, config_path)
                        print('BLEU = %.4f' % best_bleu)
                        log_file.write('BLEU = %.4f\n' % best_bleu)
                    if FLAGS.mode == 'ce_train' and abs(best_accu) < 0.00001:
                        print("Getting ACCU score for the model")
                        best_accu = evaluate(sess, valid_graph, devDataStream,
                                             options=FLAGS)['dev_accu']
                        FLAGS.best_accu = best_accu
                        namespace_utils.save_namespace(FLAGS, config_path)
                        print('ACCU = %.4f' % best_accu)
                        log_file.write('ACCU = %.4f\n' % best_accu)

                print('Start the training loop.')
                train_size = trainDataStream.get_num_batch()
                max_steps = train_size * FLAGS.max_epochs
                total_loss = 0.0
                start_time = time.time()
                for step in xrange(max_steps):
                    cur_batch = trainDataStream.nextBatch()
                    if FLAGS.mode == 'rl_train':
                        loss_value = train_graph.run_rl_training_2(sess, cur_batch, FLAGS)
                    elif FLAGS.mode == 'ce_train':
                        loss_value = train_graph.run_ce_training(sess, cur_batch, FLAGS)
                    total_loss += loss_value

                    if step % 100 == 0:
                        print('{} '.format(step), end="")
                        sys.stdout.flush()

                    # Save a checkpoint and evaluate the model at the end of
                    # each epoch, at the final step, and every 2000 steps when
                    # an epoch is longer than 10000 batches.
                    if ((step + 1) % train_size == 0 or (step + 1) == max_steps
                            or (train_size > 10000 and (step + 1) % 2000 == 0)):
                        print()
                        duration = time.time() - start_time
                        print('Step %d: loss = %.2f (%.3f sec)' % (step, total_loss, duration))
                        log_file.write('Step %d: loss = %.2f (%.3f sec)\n'
                                       % (step, total_loss, duration))
                        log_file.flush()
                        sys.stdout.flush()
                        total_loss = 0.0

                        if ftset is not None:
                            best_accu, best_bleu = fine_tune(
                                sess, saver, FLAGS, log_file, ftDataStream,
                                devDataStream, train_graph, valid_graph,
                                path_prefix, best_accu, best_bleu)
                        else:
                            best_accu, best_bleu = validate_and_save(
                                sess, saver, FLAGS, log_file, devDataStream,
                                valid_graph, path_prefix, best_accu, best_bleu)
                        start_time = time.time()
def _hyp_to_sentence(hyp, cur_batch, word_vocab, options):
    """Render beam hypothesis ``hyp`` for the first instance of ``cur_batch``.

    When ``options.with_phrase_projection`` is on, the batch's id-to-phrase
    table (second element of ``cur_batch.phrase_vocabs[0]``) is also passed to
    ``hyp.idx_seq_to_string``.
    """
    cur_passage = cur_batch.instances[0][0]
    cur_id2phrase = None
    if options.with_phrase_projection:
        (_, cur_id2phrase) = cur_batch.phrase_vocabs[0]
    return hyp.idx_seq_to_string(cur_passage, cur_id2phrase, word_vocab, options)


def _decode_batches(sess, valid_graph, word_vocab, devDataStream, options, mode,
                    outfile, ref_outfile, pred_outfile):
    """Decode every batch of ``devDataStream`` in ``mode`` and write the results.

    ``outfile`` is written by the ``'pointwise'``, ``'greedy'``/``'multinomial'``
    and plain beam-search modes. ``ref_outfile``/``pred_outfile`` are written by
    ``'greedy_evaluate'`` and ``'beam_evaluate'``. Files are flushed but not
    closed here.
    """
    total = 0
    correct = 0
    total_num = devDataStream.get_num_batch()
    devDataStream.reset()
    for i in range(total_num):
        cur_batch = devDataStream.get_batch(i)
        if mode == 'pointwise':
            (sentences, _, _, generator_output_idx) = search(
                sess, valid_graph, word_vocab, cur_batch, options, decode_mode=mode)
            for j in xrange(cur_batch.batch_size):
                # Token accuracy over the reference answer's length.
                cur_total = cur_batch.answer_lengths[j]
                cur_correct = 0
                for k in xrange(cur_total):
                    if generator_output_idx[j, k] == cur_batch.in_answer_words[j, k]:
                        cur_correct += 1.0
                total += cur_total
                correct += cur_correct
                outfile.write(cur_batch.instances[j][1].tokText.encode('utf-8') + "\n")
                outfile.write(sentences[j].encode('utf-8') + "\n")
                outfile.write("========\n")
            outfile.flush()
            # Avoid ZeroDivisionError while no reference tokens have been seen.
            accuracy = correct / float(total) * 100 if total else 0.0
            print('Current dev accuracy is %d/%d=%.2f' % (correct, total, accuracy))
        elif mode in ['greedy', 'multinomial']:
            print('Batch {}'.format(i))
            (sentences, _, _, _) = search(
                sess, valid_graph, word_vocab, cur_batch, options, decode_mode=mode)
            for j in xrange(cur_batch.batch_size):
                outfile.write(cur_batch.instances[j][1].ID_num.encode('utf-8') + "\n")
                outfile.write(cur_batch.instances[j][1].tokText.encode('utf-8') + "\n")
                outfile.write(sentences[j].encode('utf-8') + "\n")
                outfile.write("========\n")
            outfile.flush()
        elif mode == 'greedy_evaluate':
            print('Batch {}'.format(i))
            (sentences, _, _, _) = search(
                sess, valid_graph, word_vocab, cur_batch, options, decode_mode="greedy")
            for j in xrange(cur_batch.batch_size):
                ref_outfile.write(cur_batch.instances[j][1].tokText.encode('utf-8') + "\n")
                pred_outfile.write(sentences[j].encode('utf-8') + "\n")
            ref_outfile.flush()
            pred_outfile.flush()
        elif mode == 'beam_evaluate':
            print('Instance {}'.format(i))
            ref_outfile.write(cur_batch.instances[0][1].tokText.encode('utf-8') + "\n")
            ref_outfile.flush()
            hyps = run_beam_search(sess, valid_graph, word_vocab, cur_batch, options)
            cur_sent = _hyp_to_sentence(hyps[0], cur_batch, word_vocab, options)
            pred_outfile.write(cur_sent.encode('utf-8') + "\n")
            pred_outfile.flush()
        else:  # beam search
            print('Instance {}'.format(i))
            hyps = run_beam_search(sess, valid_graph, word_vocab, cur_batch, options)
            outfile.write(
                "Input: " + cur_batch.instances[0][0].tokText.encode('utf-8') + "\n")
            outfile.write(
                "Truth: " + cur_batch.instances[0][1].tokText.encode('utf-8') + "\n")
            for j, hyp in enumerate(hyps):
                cur_sent = _hyp_to_sentence(hyp, cur_batch, word_vocab, options)
                outfile.write("Hyp-{}: ".format(j) + cur_sent.encode('utf-8') +
                              " {}".format(hyp.avg_log_prob()) + "\n")
            outfile.flush()


def question_gen_run(argv):
    """Decode a test file with a trained NP2P model and write the outputs.

    ``argv`` is ``[model_prefix, in_path, out_path, mode]``. The configuration
    (``<model_prefix>.config.json``), the char/POS/NER vocabs and the best
    checkpoint (``<model_prefix>.best.model``) are all read from
    ``model_prefix``.

    Modes:
      * ``'pointwise'``: decode, print running token accuracy, and write the
        reference, the prediction and a separator per instance to ``out_path``.
      * ``'greedy'`` / ``'multinomial'``: write the ID, the reference and the
        prediction per instance to ``out_path``.
      * ``'greedy_evaluate'`` / ``'beam_evaluate'``: write references to
        ``<out_path>.ref`` and predictions to ``<out_path>.pred``.
      * any other mode: beam search on the first instance of each batch,
        writing the input, the truth and every hypothesis with its average
        log-probability to ``out_path``.

    Modes containing ``'beam'`` use a batch size of 1.
    """
    print(sys.argv)
    model_prefix = argv[0]
    in_path = argv[1]
    out_path = argv[2]
    mode = argv[3]
    # .get(): an unset variable should not abort decoding with a KeyError.
    print("CUDA_VISIBLE_DEVICES " + os.environ.get('CUDA_VISIBLE_DEVICES', '(unset)'))

    # load the configuration file
    print('Loading configurations from ' + model_prefix + ".config.json")
    FLAGS = namespace_utils.load_namespace(model_prefix + ".config.json")
    FLAGS = NP2P_trainer.enrich_options(FLAGS)

    # load vocabs
    print('Loading vocabs.')
    word_vocab = char_vocab = POS_vocab = NER_vocab = None
    if FLAGS.with_word:
        word_vocab = Vocab(FLAGS.word_vec_path, fileformat='txt2')
        print('word_vocab: {}'.format(word_vocab.word_vecs.shape))
    if FLAGS.with_char:
        char_vocab = Vocab(model_prefix + ".char_vocab", fileformat='txt2')
        print('char_vocab: {}'.format(char_vocab.word_vecs.shape))
    if FLAGS.with_POS:
        POS_vocab = Vocab(model_prefix + ".POS_vocab", fileformat='txt2')
        print('POS_vocab: {}'.format(POS_vocab.word_vecs.shape))
    if FLAGS.with_NER:
        NER_vocab = Vocab(model_prefix + ".NER_vocab", fileformat='txt2')
        print('NER_vocab: {}'.format(NER_vocab.word_vecs.shape))

    print('Loading test set.')
    if FLAGS.infile_format == 'fof':
        testset, _ = NP2P_data_stream.read_generation_datasets_from_fof(
            in_path, isLower=FLAGS.isLower)
    elif FLAGS.infile_format == 'plain':
        testset, _ = NP2P_data_stream.read_all_GenerationDatasets(
            in_path, isLower=FLAGS.isLower)
    else:
        testset, _ = NP2P_data_stream.read_all_GQA_questions(
            in_path, isLower=FLAGS.isLower, switch=FLAGS.switch_qa)
    print('Number of samples: {}'.format(len(testset)))

    print('Build DataStream ... ')
    # Beam search decodes one instance per batch; -1 is passed otherwise
    # (presumably meaning "use the configured batch size" -- verify in QADataStream).
    batch_size = 1 if 'beam' in mode else -1
    devDataStream = NP2P_data_stream.QADataStream(
        testset, word_vocab, char_vocab, POS_vocab, NER_vocab, options=FLAGS,
        isShuffle=False, isLoop=False, isSort=True, batch_size=batch_size)
    print('Number of instances in testDataStream: {}'.format(
        devDataStream.get_num_instance()))
    print('Number of batches in testDataStream: {}'.format(
        devDataStream.get_num_batch()))

    best_path = model_prefix + ".best.model"
    with tf.Graph().as_default():
        initializer = tf.random_uniform_initializer(-0.01, 0.01)
        with tf.name_scope("Valid"):
            with tf.variable_scope("Model", reuse=False, initializer=initializer):
                valid_graph = ModelGraph(word_vocab=word_vocab, char_vocab=char_vocab,
                                         POS_vocab=POS_vocab, NER_vocab=NER_vocab,
                                         options=FLAGS, mode="decode")

        # Restore every "Model" variable except the word embeddings.
        vars_ = {}
        for var in tf.global_variables():
            if "word_embedding" in var.name:
                continue
            if not var.name.startswith("Model"):
                continue
            vars_[var.name.split(":")[0]] = var
        saver = tf.train.Saver(vars_)

        initializer = tf.global_variables_initializer()
        with tf.Session() as sess:
            sess.run(initializer)
            saver.restore(sess, best_path)  # restore the model

            ref_outfile = pred_outfile = outfile = None
            try:
                if mode.endswith('evaluate'):
                    ref_outfile = open(out_path + ".ref", 'wt')
                    pred_outfile = open(out_path + ".pred", 'wt')
                else:
                    outfile = open(out_path, 'wt')
                _decode_batches(sess, valid_graph, word_vocab, devDataStream, FLAGS,
                                mode, outfile, ref_outfile, pred_outfile)
            finally:
                # Close whatever was opened, even if decoding raised.
                for handle in (ref_outfile, pred_outfile, outfile):
                    if handle is not None:
                        handle.close()