# Standard-library and TensorFlow imports needed by this script. The
# project-local helpers used below (setup, DataSet, CopyAttention,
# placeholder_inputs, placeholder_inputs_single, fill_feed_dict,
# fill_feed_dict_single, do_eval) live elsewhere in the repository; their
# module paths, and the flag definitions themselves (FLAGS.n, FLAGS.nW, ...),
# are not shown in this excerpt. The flags/FLAGS aliases below follow the
# usual tf.app.flags convention of this TF generation.
import json
import os
import time
from pprint import pprint

import tensorflow as tf

flags = tf.app.flags
FLAGS = flags.FLAGS


def main(_):
    pprint(flags.FLAGS.__flags)

    # Number experiment folders sequentially: 1, 2, 3, ...
    if not os.path.exists(FLAGS.experiment_dir):
        os.makedirs(FLAGS.experiment_dir)
        expt_num = "1"
    else:
        expts = os.listdir(FLAGS.experiment_dir)
        # Guard against an existing but empty experiment_dir, where max()
        # of an empty sequence would raise ValueError.
        last_expt = max([int(folder) for folder in expts]) if expts else 0
        expt_num = str(last_expt + 1)
    expt_result_path = os.path.join(FLAGS.experiment_dir, expt_num)
    os.makedirs(expt_result_path)

    if not os.path.exists(FLAGS.checkpoint_dir):
        os.makedirs(FLAGS.checkpoint_dir)
    chkpt_result_path = os.path.join(FLAGS.checkpoint_dir, expt_num)
    os.makedirs(chkpt_result_path)

    # Record this run's flags alongside both the results and the checkpoints.
    params_e_path = os.path.join(expt_result_path, "params.json")
    params_c_path = os.path.join(chkpt_result_path, "params.json")
    with open(params_e_path, 'w') as params_e, \
            open(params_c_path, 'w') as params_c:
        json.dump(flags.FLAGS.__flags, params_e)
        json.dump(flags.FLAGS.__flags, params_c)

    # Generate the indexes
    word2idx, field2idx, qword2idx, nF, max_words_in_table, word_set = \
        setup(FLAGS.data_dir, '../embeddings', FLAGS.n, FLAGS.batch_size,
              FLAGS.nW, FLAGS.min_field_freq, FLAGS.nQ)

    # Create the dataset objects
    train_dataset = DataSet(FLAGS.data_dir, 'train', FLAGS.n, FLAGS.nW, nF,
                            FLAGS.nQ, FLAGS.l, FLAGS.batch_size, word2idx,
                            field2idx, qword2idx, FLAGS.max_fields,
                            FLAGS.word_max_fields, max_words_in_table,
                            word_set)
    num_train_examples = train_dataset.num_examples()
    valid_dataset = DataSet(FLAGS.data_dir, 'valid', FLAGS.n, FLAGS.nW, nF,
                            FLAGS.nQ, FLAGS.l, FLAGS.batch_size, word2idx,
                            field2idx, qword2idx, FLAGS.max_fields,
                            FLAGS.word_max_fields, max_words_in_table,
                            word_set)
    test_dataset = DataSet(FLAGS.data_dir, 'test', FLAGS.n, FLAGS.nW, nF,
                           FLAGS.nQ, FLAGS.l, FLAGS.batch_size, word2idx,
                           field2idx, qword2idx, FLAGS.max_fields,
                           FLAGS.word_max_fields, max_words_in_table,
                           word_set)

    # The sizes of the respective conditioning variables,
    # used for placeholder generation
    context_size = FLAGS.n - 1
    zp_size = context_size * FLAGS.word_max_fields
    zm_size = context_size * FLAGS.word_max_fields
    gf_size = FLAGS.max_fields
    gw_size = max_words_in_table
    copy_size = FLAGS.word_max_fields
    proj_size = FLAGS.nW + max_words_in_table

    # Build the TensorFlow graph
    with tf.Graph().as_default():
        # Set the random seed for reproducibility
        tf.set_random_seed(1234)

        # Create the CopyAttention model
        model = CopyAttention(FLAGS.n, FLAGS.d, FLAGS.g, FLAGS.nhu, FLAGS.nW,
                              nF, FLAGS.nQ, FLAGS.l, FLAGS.learning_rate,
                              max_words_in_table, FLAGS.max_fields,
                              FLAGS.word_max_fields, FLAGS.xavier)

        # Placeholders for train and validation
        context_pl, zp_pl, zm_pl, gf_pl, gw_pl, next_pl, copy_pl, proj_pl = \
            placeholder_inputs(FLAGS.batch_size, context_size, zp_size,
                               zm_size, gf_size, gw_size, copy_size,
                               proj_size)

        # Placeholders for test
        context_plt, zp_plt, zm_plt, gf_plt, gw_plt, copy_plt, proj_plt, \
            next_plt = placeholder_inputs_single(context_size, zp_size,
                                                 zm_size, gf_size, gw_size,
                                                 copy_size, proj_size)

        # Train and validation part of the model
        predict = model.inference(FLAGS.batch_size, context_pl, zp_pl, zm_pl,
                                  gf_pl, gw_pl, copy_pl, proj_pl)
        loss = model.loss(predict, next_pl)
        train_op = model.training(loss)
        # evaluate = model.evaluate(predict, next_pl)

        # Test component of the model.
        # The batch_size parameter is replaced with 1.
        pred_single = model.inference(1, context_plt, zp_plt, zm_plt,
                                      gf_plt, gw_plt, copy_plt, proj_plt)
        predicted_label = model.predict(pred_single)

        # Initialize the variables and start the session.
        # (tf.initialize_all_variables() is the initializer of this TF
        # generation; later releases renamed it global_variables_initializer.)
        init = tf.initialize_all_variables()
        saver = tf.train.Saver()
        sess = tf.Session()
        sess.run(init)

        for epoch in range(1, FLAGS.num_epochs + 1):
            train_dataset.generate_permutation()
            start_e = time.time()
            for i in range(num_train_examples):
                feed_dict = fill_feed_dict(train_dataset, i, context_pl,
                                           zp_pl, zm_pl, gf_pl, gw_pl,
                                           next_pl, copy_pl, proj_pl)
                _, loss_value = sess.run([train_op, loss],
                                         feed_dict=feed_dict)

                if i % FLAGS.print_every == 0:
                    print "Epoch : %d\tStep : %d\tLoss : %0.3f" \
                        % (epoch, i, loss_value)

                # Periodic mid-epoch validation. The original guard,
                # `i == -1`, could never hold, which made this block dead
                # code; `i != 0` mirrors the sampling condition below.
                if i != 0 and i % FLAGS.valid_every == 0:
                    print "Validation starting"
                    valid_loss = do_eval(sess, train_op, loss, valid_dataset,
                                         context_pl, zp_pl, zm_pl, gf_pl,
                                         gw_pl, next_pl, copy_pl, proj_pl)
                    print "Step : %d\tValidation loss: %0.5f" \
                        % (i, valid_loss)

                # Periodically sample a sentence from the first test infobox.
                if i != 0 and i % FLAGS.sample_every == 0:
                    test_dataset.reset_context()
                    pos = 0
                    len_sent = 0
                    prev_predict = word2idx['<start>']
                    res_path = os.path.join(expt_result_path, 'sample.txt')
                    with open(res_path, 'a') as exp:
                        # Decode greedily, one token at a time, until a full
                        # stop or a 50-token cap.
                        while pos != 1:
                            feed_dict_t, idx2wq = fill_feed_dict_single(
                                test_dataset, prev_predict, 0, context_plt,
                                zp_plt, zm_plt, gf_plt, gw_plt, next_plt,
                                copy_plt, proj_plt)
                            prev_predict = sess.run([predicted_label],
                                                    feed_dict=feed_dict_t)
                            prev = prev_predict[0][0][0]
                            if prev in idx2wq:
                                exp.write(idx2wq[prev] + ' ')
                            else:
                                exp.write('<unk> ')
                            len_sent += 1
                            if prev == word2idx['.']:
                                pos = 1
                            if len_sent == 50:
                                break
                            prev_predict = prev
                        # Terminate the sample line even when decoding is cut
                        # off before a full stop (the original wrote '\n'
                        # only after a '.').
                        exp.write('\n')

            duration_e = time.time() - start_e
            print "Time taken for epoch : %d is %0.3f minutes" \
                % (epoch, duration_e / 60)

            print "Saving checkpoint for epoch %d" % epoch
            checkpoint_file = os.path.join(chkpt_result_path,
                                           str(epoch) + '.ckpt')
            saver.save(sess, checkpoint_file)

            print "Validation starting"
            start = time.time()
            valid_loss = do_eval(sess, train_op, loss, valid_dataset,
                                 context_pl, zp_pl, zm_pl, gf_pl, gw_pl,
                                 next_pl, copy_pl, proj_pl)
            duration = time.time() - start
            print "Epoch : %d\tValidation loss: %0.5f" % (epoch, valid_loss)
            print "Time taken for validating epoch %d : %0.3f" \
                % (epoch, duration)
            valid_res = os.path.join(expt_result_path, 'valid_loss.txt')
            with open(valid_res, 'a') as vloss_f:
                # Append a newline so successive epochs land on separate
                # lines (the original write ran them together).
                vloss_f.write("Epoch : %d\tValidation loss: %0.5f\n"
                              % (epoch, valid_loss))
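
# Entry point (not part of the excerpt above; a minimal sketch assuming the
# script declares its flags with tf.app.flags): tf.app.run() parses the
# command-line flags and then invokes main().
if __name__ == '__main__':
    tf.app.run()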
# Generation script: a second main() that rebuilds the same graph, restores a
# trained checkpoint, and decodes one sentence per test infobox. It relies on
# the same imports, flag definitions, and project-local helpers as the
# training script above.
def main(_):
    pprint(flags.FLAGS.__flags)

    # Number experiment folders sequentially, as in the training script.
    if not os.path.exists(FLAGS.experiment_dir):
        os.makedirs(FLAGS.experiment_dir)
        expt_num = "1"
    else:
        expts = os.listdir(FLAGS.experiment_dir)
        # Guard against an existing but empty experiment_dir, where max()
        # of an empty sequence would raise ValueError.
        last_expt = max([int(folder) for folder in expts]) if expts else 0
        expt_num = str(last_expt + 1)
    expt_result_path = os.path.join(FLAGS.experiment_dir, expt_num)
    os.makedirs(expt_result_path)

    if not os.path.exists(FLAGS.checkpoint_dir):
        os.makedirs(FLAGS.checkpoint_dir)
    chkpt_result_path = os.path.join(FLAGS.checkpoint_dir, expt_num)
    os.makedirs(chkpt_result_path)

    params_e_path = os.path.join(expt_result_path, "params.json")
    params_c_path = os.path.join(chkpt_result_path, "params.json")
    with open(params_e_path, 'w') as params_e, \
            open(params_c_path, 'w') as params_c:
        json.dump(flags.FLAGS.__flags, params_e)
        json.dump(flags.FLAGS.__flags, params_c)

    # Generate the indexes
    word2idx, field2idx, qword2idx, nF, max_words_in_table, word_set = \
        setup(FLAGS.data_dir, '../embeddings', FLAGS.n, FLAGS.batch_size,
              FLAGS.nW, FLAGS.min_field_freq, FLAGS.nQ)

    # Create the dataset objects
    train_dataset = DataSet(FLAGS.data_dir, 'train', FLAGS.n, FLAGS.nW, nF,
                            FLAGS.nQ, FLAGS.l, FLAGS.batch_size, word2idx,
                            field2idx, qword2idx, FLAGS.max_fields,
                            FLAGS.word_max_fields, max_words_in_table,
                            word_set)
    num_train_examples = train_dataset.num_examples()
    valid_dataset = DataSet(FLAGS.data_dir, 'valid', FLAGS.n, FLAGS.nW, nF,
                            FLAGS.nQ, FLAGS.l, FLAGS.batch_size, word2idx,
                            field2idx, qword2idx, FLAGS.max_fields,
                            FLAGS.word_max_fields, max_words_in_table,
                            word_set)
    test_dataset = DataSet(FLAGS.data_dir, 'test', FLAGS.n, FLAGS.nW, nF,
                           FLAGS.nQ, FLAGS.l, FLAGS.batch_size, word2idx,
                           field2idx, qword2idx, FLAGS.max_fields,
                           FLAGS.word_max_fields, max_words_in_table,
                           word_set)

    # The sizes of the respective conditioning variables,
    # used for placeholder generation
    context_size = FLAGS.n - 1
    zp_size = context_size * FLAGS.word_max_fields
    zm_size = context_size * FLAGS.word_max_fields
    gf_size = FLAGS.max_fields
    gw_size = max_words_in_table
    copy_size = FLAGS.word_max_fields
    proj_size = FLAGS.nW + max_words_in_table

    # Build the TensorFlow graph
    with tf.Graph().as_default():
        # Create the CopyAttention model. Note: unlike the training script,
        # this call omits the trailing FLAGS.xavier argument, so the two
        # scripts only agree if that constructor parameter has a default.
        model = CopyAttention(FLAGS.n, FLAGS.d, FLAGS.g, FLAGS.nhu, FLAGS.nW,
                              nF, FLAGS.nQ, FLAGS.l, FLAGS.learning_rate,
                              max_words_in_table, FLAGS.max_fields,
                              FLAGS.word_max_fields)

        # Placeholders for train and validation
        context_pl, zp_pl, zm_pl, gf_pl, gw_pl, next_pl, copy_pl, proj_pl = \
            placeholder_inputs(FLAGS.batch_size, context_size, zp_size,
                               zm_size, gf_size, gw_size, copy_size,
                               proj_size)

        # Placeholders for test
        context_plt, zp_plt, zm_plt, gf_plt, gw_plt, copy_plt, proj_plt, \
            next_plt = placeholder_inputs_single(context_size, zp_size,
                                                 zm_size, gf_size, gw_size,
                                                 copy_size, proj_size)

        # Train and validation part of the model
        predict = model.inference(FLAGS.batch_size, context_pl, zp_pl, zm_pl,
                                  gf_pl, gw_pl, copy_pl, proj_pl)
        loss = model.loss(predict, next_pl)
        train_op = model.training(loss)
        # evaluate = model.evaluate(predict, next_pl)

        # Test component of the model.
        # The batch_size parameter is replaced with 1.
        pred_single = model.inference(1, context_plt, zp_plt, zm_plt,
                                      gf_plt, gw_plt, copy_plt, proj_plt)
        predicted_label = model.predict(pred_single)

        # Restore a trained model rather than initializing fresh variables.
        # The checkpoint and output paths are hard-coded to experiment 15,
        # epoch 16.
        init = tf.initialize_all_variables()
        saver = tf.train.Saver()
        sess = tf.Session()
        ckpt_file = os.path.join('../checkpoint', '15', '16.ckpt')
        saver.restore(sess, ckpt_file)
        # sess.run(init)

        start_g = time.time()
        num_test_boxes = test_dataset.num_infoboxes()
        res_path = os.path.join('../experiment/', '15', 'generated.txt')
        with open(res_path, 'a') as exp:
            # Decode one sentence per test infobox, greedily, one token at a
            # time, until a full stop or a 50-token cap.
            for k in range(num_test_boxes):
                test_dataset.reset_context()
                pos = 0
                len_sent = 0
                prev_predict = word2idx['<start>']
                while pos != 1:
                    feed_dict_t, idx2wq = fill_feed_dict_single(
                        test_dataset, prev_predict, k, context_plt, zp_plt,
                        zm_plt, gf_plt, gw_plt, next_plt, copy_plt, proj_plt)
                    prev_predict = sess.run([predicted_label],
                                            feed_dict=feed_dict_t)
                    prev = prev_predict[0][0][0]
                    if prev in idx2wq:
                        exp.write(idx2wq[prev] + ' ')
                    else:
                        exp.write('<unk> ')
                    len_sent += 1
                    if prev == word2idx['.']:
                        pos = 1
                    if len_sent == 50:
                        break
                    prev_predict = prev
                # Terminate every infobox's output with a newline, even when
                # decoding is cut off at 50 tokens (the original wrote '\n'
                # only after a '.', so truncated sentences ran together).
                exp.write('\n')

        duration_g = time.time() - start_g
        print "Time taken for generating sentences : %0.3f minutes" \
            % (duration_g / 60)
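
# The hard-coded '15'/'16.ckpt' paths above tie this script to one specific
# run. A hedged sketch of how they could be lifted into flags instead; the
# flag names gen_expt_num and gen_epoch are hypothetical, not part of the
# original code:
#
#     flags.DEFINE_string('gen_expt_num', '15', 'Experiment to load from')
#     flags.DEFINE_string('gen_epoch', '16', 'Epoch checkpoint to restore')
#     ckpt_file = os.path.join(FLAGS.checkpoint_dir, FLAGS.gen_expt_num,
#                              FLAGS.gen_epoch + '.ckpt')
#     res_path = os.path.join(FLAGS.experiment_dir, FLAGS.gen_expt_num,
#                             'generated.txt')

# Entry point, as in the training script.
if __name__ == '__main__':
    tf.app.run()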