def _iter_padded_batches(sents, batch_size):
  """Yield fixed-size batches of sentences from `sents`.

  `sents` is split into consecutive chunks of `batch_size`. A final partial
  chunk is padded up to `batch_size` by re-using its own sentences from the
  start of the chunk (the input placeholders have a fixed batch dimension).
  No batch is produced for an empty `sents`, so callers never see an empty
  chunk that could not be padded.

  Args:
    sents: list of sentences.
    batch_size: positive number of sentences per batch.

  Yields:
    `(batch, num_real)` tuples: `batch` is a list of exactly `batch_size`
    sentences and `num_real` is how many leading entries are real input
    (the rest are padding).
  """
  for start in range(0, len(sents), batch_size):
    chunk = sents[start:start + batch_size]
    num_real = len(chunk)  # Always >= 1 here.
    # Repeat the chunk enough times to cover batch_size, then truncate; this
    # yields the same order as repeatedly appending chunk[:missing].
    repeats = -(-batch_size // num_real)  # ceil(batch_size / num_real)
    yield (chunk * repeats)[:batch_size], num_real


def main(argv=()):
  """Restore a checkpoint and write model samples for each input file.

  For each label index in `range(FLAGS.Nlabels)`, reads the matching entry of
  the comma-separated `FLAGS.input_file` (one sentence per line, stripped),
  feeds the sentences through `model_sample` in batches of
  `FLAGS.batch_size`, converts the outputs for the target label back to
  strings with `vocab.convert_to_str`, and writes them newline-joined to
  `<FLAGS.samples_dir>/<FLAGS.mdl_name>_sample_<target>.txt`.
  `<target>` is `1 - label` when `FLAGS.flip_label` is set, else `label`.

  Args:
    argv: command-line arguments; unused.
  """
  del argv  # Unused.

  vocab = Vocab()
  shp_p = tf.placeholder(tf.int32, shape=(2,))
  sen_batch_p = tf.placeholder(tf.int32, shape=(FLAGS.batch_size, None))
  mask_batch_p = tf.placeholder(tf.int32, shape=(FLAGS.batch_size, None))
  labels_batch_p = tf.placeholder(tf.int32, shape=(FLAGS.batch_size,))
  max_sampling = FLAGS.sampling_mode == 'max'
  decoded_samples = model_sample(
      sen_batch_p, mask_batch_p, shp_p, labels_batch_p,
      max_sampling=max_sampling)
  saver = tf.train.Saver()

  input_files = FLAGS.input_file.split(',')
  # NOTE(review): the label placeholder is always fed zeros, independent of
  # `label`; this mirrors the original behavior -- confirm it is intended.
  labels_batch = np.array([0] * FLAGS.batch_size)

  with tf.Session() as sess:
    coord = tf.train.Coordinator()
    saver.restore(sess, FLAGS.restore_ckpt_path)
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    try:
      for label in range(FLAGS.Nlabels):
        target = 1 - label if FLAGS.flip_label else label

        with open(input_files[label], 'r') as input_fh:
          input_sents = [line.strip() for line in input_fh]

        samples = []
        for sents, num_real in _iter_padded_batches(input_sents,
                                                    FLAGS.batch_size):
          sen_batch, mask_batch, shp = vocab.construct_batch(sents)
          out = sess.run(decoded_samples, feed_dict={
              sen_batch_p: sen_batch,
              mask_batch_p: mask_batch,
              shp_p: shp,
              labels_batch_p: labels_batch,
          })
          # Presumably `out` is indexed by target label first, then by batch
          # position -- verify against model_sample. Padding rows are dropped.
          samples.extend(
              vocab.convert_to_str(out[target][k]) for k in range(num_real))

        fname = '%s/%s_sample_%d.txt' % (FLAGS.samples_dir, FLAGS.mdl_name,
                                         target)
        with open(fname, 'w') as results_file:
          results_file.write('\n'.join(samples))
    finally:
      # Always shut down queue-runner threads, even if sampling failed.
      coord.request_stop()
      coord.join(threads)