def main(_):
    # FILESYSTEM SETUP ------------------------------------------------------------
    assert FLAGS.data_dir, "Must specify data location!"
    assert FLAGS.log_dir, "Must specify experiment to log to!"
    assert FLAGS.exp_tag, "Must specify experiment tag subfolder of log_dir %s" % FLAGS.log_dir
    assert FLAGS.cnn_ckpt, "Must specify where to load CNN checkpoint from!"
    assert FLAGS.variant, "Must specify ShapeWorld variant!"

    # Build saving folders
    save_root = FLAGS.log_dir + os.sep + FLAGS.exp_tag
    train_path = save_root + os.sep + "train"
    eval_path = save_root + os.sep + "eval"
    test_path = save_root + os.sep + "test"
    if not tf.gfile.IsDirectory(train_path):
        tf.gfile.MakeDirs(train_path)
        tf.gfile.MakeDirs(eval_path)
        tf.gfile.MakeDirs(test_path)
        tf.logging.info("Creating training directory: %s", train_path)
        tf.logging.info("Creating eval directory: %s", eval_path)
        tf.logging.info("Creating test directory: %s", test_path)
    else:
        tf.logging.info("Using training directory: %s", train_path)
        tf.logging.info("Using eval directory: %s", eval_path)

    # Sanity check
    tf.reset_default_graph()
    tf.logging.info("Clean graph reset...")

    try:
        dataset = Dataset.create(dtype=FLAGS.dtype,
                                 name=FLAGS.name,
                                 variant=FLAGS.variant,
                                 config=FLAGS.data_dir)
        dataset.pixel_noise_stddev = 0.1
    except Exception:
        raise ValueError(
            "variant=%s did not point to a valid ShapeWorld dataset" %
            FLAGS.variant)

    # Get parsing and parameter feats
    params = Config(mode="train", sw_specification=dataset.specification())
    params.cnn_checkpoint = FLAGS.cnn_ckpt
    params.batch_size = FLAGS.batch_size

    # MODEL SETUP ------------------------------------------------------------
    g = tf.Graph()
    with g.as_default():
        parser = FullSequenceBatchParser(
            src_vocab=dataset.vocabularies['language'])
        params.vocab_size = len(parser.tgt_vocab)

        batch = tf_util.batch_records(dataset,
                                      mode="train",
                                      batch_size=params.batch_size)
        model = CaptioningModel(config=params, batch_parser=parser)

        if FLAGS.glove_dir:
            tf.logging.info("Loading GloVe Embeddings...")
            gl = GloveLoader(vocab=parser.tgt_vocab,
                             glove_dir=FLAGS.glove_dir,
                             dims=FLAGS.glove_dim,
                             load_new=False)
            glove_initials = gl.get_embeddings_matrix()
            tf.logging.info("Building model with GloVe initialisation...")
            model.build_model(batch, embedding_init=glove_initials)
        else:
            tf.logging.info("Building model without GloVe initialisation...")
            model.build_model(batch)

        tf.logging.info("Network built...")

        # TRAINING OPERATION SETUP ------------------------------------------------------------
        # Depend on UPDATE_OPS so ops such as batch-norm moving-average
        # updates run with every training step
        with tf.control_dependencies(
                tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
            train_op = tf.contrib.layers.optimize_loss(
                loss=model.batch_loss,
                global_step=model.global_step,
                learning_rate=params.initial_learning_rate,
                optimizer=params.optimizer,
                clip_gradients=params.clip_gradients,
            )

        logging_saver = tf.train.Saver(
            max_to_keep=params.max_checkpoints_to_keep)
        summary_op = tf.summary.merge_all()
        train_writer = tf.summary.FileWriter(logdir=train_path, graph=g)

    tf.logging.info('###' * 20)
    tf.logging.info("Beginning shape2seq network training for %d steps" %
                    params.num_total_steps)

    with tf.Session(graph=g,
                    config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        tf.logging.info("### Trainable Variables")
        for var in tf.trainable_variables():
            print("-> %s" % var.op.name)

        coordinator = tf.train.Coordinator()
        queue_threads = tf.train.start_queue_runners(sess=sess,
                                                     coord=coordinator)

        # Initialise everything
        sess.run([tf.global_variables_initializer(),
                  tf.tables_initializer()])
        tf.logging.info("Restoring CNN...")
        model.init_fn(sess)
        start_train_time = time.time()

        # Loss accumulator and logging interval generator at
        # [25%, 50%, 75%, 100%] * epoch
        logging_loss = []
        logging_points = np.linspace(0,
                                     params.num_steps_per_epoch,
                                     4,
                                     endpoint=False,
                                     dtype=np.int32)
        logging_points = np.fliplr(
            [params.num_steps_per_epoch - logging_points])[0]

        for c_epoch in range(0, params.num_epochs):
            tf.logging.info("Running epoch %d" % c_epoch)

            for c_step in trange(params.num_steps_per_epoch * c_epoch,
                                 params.num_steps_per_epoch * (c_epoch + 1)):
                # Compare the epoch-relative step (1..num_steps_per_epoch)
                # against the quarter points, so every epoch gets summaries
                rel_step = c_step + 1 - c_epoch * params.num_steps_per_epoch
                if rel_step in logging_points:
                    _, loss_, summaries = sess.run(
                        fetches=[train_op, model.batch_loss, summary_op])

                    recent_losses = logging_loss + [loss_]
                    logging_loss = []
                    avg_loss = np.mean(recent_losses).squeeze()

                    new_summ = tf.Summary()
                    new_summ.value.add(tag="train/avg_loss",
                                       simple_value=avg_loss)
                    train_writer.add_summary(
                        new_summ,
                        tf.train.global_step(sess, model.global_step))
                    train_writer.add_summary(
                        summaries,
                        tf.train.global_step(sess, model.global_step))
                    train_writer.flush()

                    tf.logging.info(
                        " -> Average loss step %d, for last %d steps: %.5f" %
                        (c_step, len(recent_losses), avg_loss))
                else:
                    # Run without summaries
                    _, loss_ = sess.run(fetches=[train_op, model.batch_loss])
                    logging_loss.append(loss_)

            logging_saver.save(sess=sess,
                               save_path=train_path + os.sep + "model",
                               global_step=tf.train.global_step(
                                   sess, model.global_step))

        coordinator.request_stop()
        coordinator.join(threads=queue_threads)

        end_time = time.time() - start_train_time
        tf.logging.info('Training complete in %.2f-secs/%.2f-mins/%.2f-hours',
                        end_time, end_time / 60, end_time / (60 * 60))
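# LOGGING-POINT SKETCH ------------------------------------------------------------
# A standalone sketch of the logging-interval arithmetic in the training loop
# above, assuming a hypothetical epoch length of 100 steps (illustrative value
# only). np.linspace with endpoint=False yields the offsets [0, 25, 50, 75];
# subtracting them from the epoch length and reversing gives summary points
# at 25%, 50%, 75% and 100% of the epoch.
import numpy as np

num_steps_per_epoch = 100  # hypothetical value
offsets = np.linspace(0, num_steps_per_epoch, 4, endpoint=False,
                      dtype=np.int32)
print(offsets)  # -> [ 0 25 50 75]

# np.fliplr needs a 2-D input, hence the list wrapper and the [0] unwrap;
# the 1-D equivalent is (num_steps_per_epoch - offsets)[::-1]
logging_points = np.fliplr([num_steps_per_epoch - offsets])[0]
print(logging_points)  # -> [ 25  50  75 100]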
            filehandle.write(',validation ' + name)
        filehandle.write('\n')

    iteration_end = iteration_start + args.iterations - 1

    with Model(name=args.model,
               learning_rate=parameters.pop('learning_rate', 1e-3),
               weight_decay=parameters.pop('weight_decay', None),
               clip_gradients=parameters.pop('clip_gradients', None),
               model_directory=args.model_dir,
               summary_directory=args.summary_dir) as model:
        dropout = parameters.pop('dropout_rate', None)
        module = import_module('models.{}.{}'.format(args.type, args.model))
        if args.tf_records:
            inputs = tf_util.batch_records(dataset=dataset,
                                           mode='train',
                                           batch_size=args.batch_size)
            module.model(model=model,
                         inputs=inputs,
                         dataset_parameters=dataset_parameters,
                         **parameters)
        else:
            module.model(model=model,
                         inputs=dict(),
                         dataset_parameters=dataset_parameters,
                         **parameters)  # no input tensors, hence None for placeholder creation
        model.finalize(restore=args.restore)
        if args.verbosity >= 1:
            filehandle.write(',validation ' + name)
        filehandle.write('\n')

    iteration_end = iteration_start + args.iterations - 1

    with Model(name=args.model,
               learning_rate=parameters.pop('learning_rate'),
               weight_decay=parameters.pop('weight_decay', 0.0),
               model_directory=args.model_dir,
               summary_directory=args.summary_dir) as model:
        dropout = parameters.pop('dropout_rate', 0.0)
        module = import_module('models.{}.{}'.format(args.type, args.model))
        if args.tf_records:
            module.model(model=model,
                         inputs=tf_util.batch_records(
                             dataset=dataset,
                             batch_size=args.batch_size,
                             noise_range=args.pixel_noise),
                         **parameters)
        else:
            module.model(model=model,
                         inputs=dict(),
                         **parameters)  # no input tensors, hence None for placeholder creation
        model.finalize(restore=args.restore)
        if args.verbosity >= 1:
            sys.stdout.write(' parameters: {:,}\n'.format(
                model.num_parameters))
            sys.stdout.write(' bytes: {:,}\n'.format(model.num_bytes))
            sys.stdout.write('{} train model...\n'.format(
                datetime.now().strftime('%H:%M:%S')))
            sys.stdout.write(' 0% {}/{} '.format(
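# MODEL PLUG-IN SKETCH ------------------------------------------------------------
# Both trainer fragments above resolve networks dynamically:
# import_module('models.{type}.{model}') must yield a module exposing a
# model() entry point, which receives the Model wrapper plus either the
# batched tensors from tf_util.batch_records() (when --tf_records is set) or
# an empty dict, in which case the plug-in creates its own placeholders. A
# minimal sketch of such a plug-in at the hypothetical path
# models/agreement/dense_baseline.py; the input key 'world', the 64x64x3
# shape and the layer stack are illustrative assumptions, not taken from the
# repository:
import tensorflow as tf


def model(model, inputs, dataset_parameters=None, **parameters):
    # Batched image tensor when fed from TFRecords, else a fresh placeholder
    image = inputs.get('world')
    if image is None:
        image = tf.placeholder(dtype=tf.float32,
                               shape=(None, 64, 64, 3),
                               name='world')
    # A deliberately tiny network, standing in for a real architecture
    net = tf.layers.flatten(image)
    net = tf.layers.dense(net,
                          units=parameters.get('hidden_size', 512),
                          activation=tf.nn.relu)
    return net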
def main(_):
    # FILESYSTEM SETUP ------------------------------------------------------------
    assert FLAGS.data_dir, "Must specify data location!"
    assert FLAGS.log_dir, "Must specify experiment to log to!"
    assert FLAGS.exp_tag, "Must specify experiment tag subfolder of log_dir %s" % FLAGS.log_dir
    assert FLAGS.parse_type, "Must specify batch parser type!"

    # Folder setup for saving summaries and loading checkpoints
    save_root = FLAGS.log_dir + os.sep + FLAGS.exp_tag
    test_path = save_root + os.sep + "test"
    if not tf.gfile.IsDirectory(test_path):
        tf.gfile.MakeDirs(test_path)

    train_path = save_root + os.sep + "train"
    model_ckpt = tf.train.latest_checkpoint(train_path)  # Checkpoint to load
    tf.logging.info("Loading checkpoint %s", model_ckpt)
    assert model_ckpt, "Checkpoints could not be loaded, check that train_path %s exists" % train_path

    # Sanity check graph reset
    tf.reset_default_graph()
    tf.logging.info("Clean graph reset...")

    try:
        dataset = Dataset.create(dtype=FLAGS.dtype,
                                 name=FLAGS.name,
                                 config=FLAGS.data_dir)
        dataset.pixel_noise_stddev = 0.1
        dataset.random_sampling = False
    except Exception:
        raise ValueError(
            "config=%s did not point to a valid ShapeWorld dataset" %
            FLAGS.data_dir)

    # Get parsing and parameter feats
    params = Config(mode="test", sw_specification=dataset.specification())

    # Parse decoding arg from CLI
    params.decode_type = FLAGS.decode_type
    assert params.decode_type in ['greedy', 'sample', 'beam']

    # MODEL SETUP ------------------------------------------------------------
    g = tf.Graph()
    with g.as_default():
        parser = SimpleBatchParser(src_vocab=dataset.vocabularies['language'],
                                   batch_type=FLAGS.parse_type)
        vocab, rev_vocab = parser.get_vocab()
        params.vocab_size = len(parser.tgt_vocab)

        batch = tf_util.batch_records(dataset,
                                      mode=FLAGS.data_partition,
                                      batch_size=params.batch_size)
        model = CaptioningModel(config=params, batch_parser=parser)
        model.build_model(batch)
        restore_model = tf.train.Saver()
        tf.logging.info("Network built...")

    # TESTING SETUP ------------------------------------------------------------
    if FLAGS.num_imgs < 1:
        num_imgs = params.instances_per_shard * params.num_shards
    else:
        num_imgs = FLAGS.num_imgs
    tf.logging.info("Running test for %d images", num_imgs)

    test_writer = tf.summary.FileWriter(logdir=test_path, graph=g)

    with tf.Session(graph=g,
                    config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        # Launch data loading queues
        coordinator = tf.train.Coordinator()
        queue_threads = tf.train.start_queue_runners(sess=sess,
                                                     coord=coordinator)

        # Model restoration
        restore_model.restore(sess, model_ckpt)
        tf.logging.info("Model restored!")

        # Trained model does not need initialisation.
        # Init the vocab conversion tables
        sess.run([tf.tables_initializer()])

        # Freeze graph
        sess.graph.finalize()

        # Get global step
        global_step = tf.train.global_step(sess, model.global_step)
        tf.logging.info("Successfully loaded %s at global step = %d.",
                        os.path.basename(model_ckpt), global_step)

        start_test_time = time.time()

        corrects = []
        incorrects = []  # For correctly formed, but wrong captions
        misses = []  # For incorrectly formed captions
        perplexities = []

        for b_idx in range(num_imgs):
            # idx_batch = dataset.generate(n=params.batch_size,
            #                              mode=FLAGS.data_partition,
            #                              include_model=True)
            reference_caps, inf_decoder_outputs, batch_perplexity = sess.run(
                fetches=[
                    model.reference_captions, model.inf_decoder_output,
                    model.batch_perplexity
                ],
                feed_dict={model.phase: 0})

            ref_cap = reference_caps.squeeze()
            inf_cap = inf_decoder_outputs.sample_id.squeeze()
            perplexities.append(batch_perplexity)

            if ref_cap.ndim > 0 and inf_cap.ndim > 0:
                print("%d REF -> %s | INF -> %s" %
                      (b_idx, " ".join(rev_vocab[r] for r in ref_cap),
                       " ".join(rev_vocab[r] for r in inf_cap)))

                # Strip <S>, </S> and any irrelevant tokens and convert to a
                # set for order insensitivity
                ref_cap = set(tok for tok in ref_cap
                              if int(tok) not in parser.token_filter)
                inf_cap = set(tok for tok in inf_cap
                              if int(tok) not in parser.token_filter)

                if np.all([i in ref_cap for i in inf_cap]):
                    corrects.append(1)
                else:
                    corrects.append(0)  # Count the miss towards accuracy
                    incorrects.append((ref_cap, inf_cap))
            else:
                print("Skipping %d as inf_cap %s is malformed" %
                      (b_idx, inf_cap))
                misses.append(1)

        # Overall scores for checkpoint
        avg_acc = np.mean(corrects).squeeze()
        std_acc = np.std(corrects).squeeze()
        print("Accuracy: %s -> %.5f ± %.5f | Misses: %d" %
              (FLAGS.parse_type, avg_acc, std_acc, len(misses)))

        avg_perplexity = np.mean(perplexities).squeeze()
        std_perplexity = np.std(perplexities).squeeze()
        print("------------")
        print("PERPLEXITY -> %.5f +- %.5f" % (avg_perplexity, std_perplexity))

        new_summ = tf.Summary()
        new_summ.value.add(tag="%s/avg_acc_%s" %
                           (FLAGS.data_partition, FLAGS.name),
                           simple_value=avg_acc)
        new_summ.value.add(tag="%s/std_acc_%s" %
                           (FLAGS.data_partition, FLAGS.name),
                           simple_value=std_acc)
        new_summ.value.add(tag="%s/perplexity_avg_%s" %
                           (FLAGS.data_partition, FLAGS.name),
                           simple_value=avg_perplexity)
        new_summ.value.add(tag="%s/perplexity_std_%s" %
                           (FLAGS.data_partition, FLAGS.name),
                           simple_value=std_perplexity)

        test_writer.add_summary(new_summ,
                                tf.train.global_step(sess, model.global_step))
        test_writer.flush()

        coordinator.request_stop()
        coordinator.join(threads=queue_threads)

        end_time = time.time() - start_test_time
        tf.logging.info('Testing complete in %.2f-secs/%.2f-mins/%.2f-hours',
                        end_time, end_time / 60, end_time / (60 * 60))
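# CAPTION-MATCH SKETCH ------------------------------------------------------------
# A standalone sketch of the order-insensitive accuracy check used in the
# test loop above, with a hypothetical token filter and token ids (0/1/2
# standing in for <S>, </S> and padding; the example captions are invented).
# An inferred caption counts as correct when every surviving token also
# occurs in the reference; the subset test inf_set <= ref_set is equivalent
# to the np.all([i in ref_cap for i in inf_cap]) check above.
import numpy as np

token_filter = {0, 1, 2}  # hypothetical ids for <S>, </S> and padding
ref_cap = np.array([0, 7, 4, 9, 1])  # e.g. "<S> a red square </S>"
inf_cap = np.array([0, 9, 4, 7, 1])  # same tokens, different order

ref_set = set(int(tok) for tok in ref_cap if int(tok) not in token_filter)
inf_set = set(int(tok) for tok in inf_cap if int(tok) not in token_filter)
print(inf_set <= ref_set)  # -> True, counted as correct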