def main(_):
    """Build the VariationalBayesianMemory model and run the selected mode.

    All configuration is read from the module-level ``FLAGS``. The function
    builds the hyperparameters, the train/val/test datasets, the model, a
    TensorFlow session, summary writers and a saver, initializes variables,
    optionally restores a checkpoint, and finally dispatches on
    ``FLAGS.mode``.

    Args:
        _: Unused positional argument (presumably the argv list handed over
            by the app runner -- TODO confirm against the call site).

    Returns:
        None. Returns early, before running any routine, if the checkpoint
        was not found and the user does not answer 'y' at the prompt.

    Raises:
        KeyError: If ``FLAGS.mode`` is not one of 'train', 'eval',
            'generate' or 'copy'.
    """
    ## hyperparams
    hps = tf.contrib.training.HParams(
        img_height=FLAGS.img_height,
        img_width=FLAGS.img_width,
        img_channels=FLAGS.img_channels,
        discrete_outputs=FLAGS.discrete_outputs,
        batch_size=FLAGS.batch_size,
        episode_len=FLAGS.episode_len,
        memory_size=FLAGS.memory_size,
        code_size=FLAGS.code_size,
        opt_iters=FLAGS.opt_iters,
        enc_blocks=FLAGS.enc_blocks,
        dec_blocks=FLAGS.dec_blocks,
        num_filters=FLAGS.num_filters,
        trainable_memory=FLAGS.trainable_memory,
        use_bn=FLAGS.use_bn,
        use_ddi=FLAGS.use_ddi,
        lr=FLAGS.lr,
        epochs=FLAGS.epochs)

    ## dataset
    ds_train, ds_val, ds_test = data.get_dataset(name=FLAGS.dataset, hps=hps)

    ## model and session
    model = VBMC.VariationalBayesianMemory(hps)
    sess = tf.Session()

    ## tensorboard
    train_writer = tf.summary.FileWriter(
        FLAGS.summaries_dir + '/train', sess.graph)
    # NOTE: test_writer is created here but not used again in this function.
    test_writer = tf.summary.FileWriter(FLAGS.summaries_dir + '/test')

    ## checkpointing
    saver = tf.train.Saver()

    ## init op
    init_op = tf.global_variables_initializer()
    _ = sess.run(init_op)

    ## restoring
    checkpoint_path = FLAGS.load_checkpoint
    # saver.restore() takes a checkpoint *prefix*. A default (V2-format)
    # tf.train.Saver stores it as `<prefix>.index` plus `<prefix>.data-*`,
    # so the bare prefix is normally not a file on disk. Accept either the
    # literal path or its `.index` companion as proof the checkpoint exists.
    checkpoint_found = checkpoint_path != '' and (
        os.path.exists(checkpoint_path)
        or os.path.exists(checkpoint_path + '.index'))
    if checkpoint_found:
        saver.restore(sess, checkpoint_path)
    else:
        # This branch is also taken when no checkpoint path was given at all.
        print('load checkpoint "{}" does not exist.'.format(checkpoint_path))
        print('continue anyway? [y/N]')
        yn = input('> ')
        if yn.strip().lower() != 'y':
            print('program exiting.')
            return

    ## helper functions for the various modes supported by this application
    mode_to_routine = {
        'train': routines.train,
        'eval': routines.evaluate,
        'generate': routines.generate,
        'copy': routines.copy
    }
    # Looked up before branching so an unsupported mode fails right away.
    routine = mode_to_routine[FLAGS.mode]

    ## rather than pass around tons of arguments,
    # just use callbacks to perform the required functionality
    if FLAGS.mode == 'train':
        checkpoint_dir = FLAGS.checkpoint_dir
        callbacks = {
            'tensorboard': calls.tensorboard(train_writer),
            'checkpointing': calls.checkpointing(sess, saver, checkpoint_dir)
        }
        routines.train(
            ds_train, ds_val, sess, model, callbacks, epochs=hps.epochs)
    elif FLAGS.mode == 'eval':
        routines.evaluate(ds_test, sess, model)
    else:
        # 'generate' and 'copy' share the same call shape and output callbacks.
        output_dir = FLAGS.output_dir
        callbacks = {
            'save_png': calls.save_png(output_dir),
            'save_gif': calls.save_gif(output_dir)
        }
        routine(ds_test, sess, model, callbacks)
def main(_):
    """Assemble the DRAW model from ``FLAGS`` and execute the requested mode.

    Builds hyperparameters, datasets, the model, a TensorFlow session,
    summary writers and a saver; initializes variables; restores a
    checkpoint when ``FLAGS.load_checkpoint`` is non-empty; then runs the
    routine registered for ``FLAGS.mode`` with mode-specific callbacks.

    Args:
        _: Unused positional argument.

    Raises:
        KeyError: If ``FLAGS.mode`` is not one of the registered modes
            ('train', 'eval', 'generate', 'reconstruct', 'generate_gif',
            'reconstruct_gif').
    """
    ## hyperparams: each field is copied from the flag of the same name
    hparam_names = (
        'batch_size', 'img_height', 'img_width', 'img_channels',
        'num_timesteps', 'z_dim', 'encoder_hidden_dim', 'decoder_hidden_dim',
        'read_dim', 'write_dim', 'init_scale', 'forget_bias', 'lr', 'epochs',
    )
    hps = tf.contrib.training.HParams(
        **{name: getattr(FLAGS, name) for name in hparam_names})

    ## dataset
    train_data, test_data = get_dataset(name=FLAGS.dataset, hps=hps)

    ## model and session
    model = DRAW(hps)
    sess = tf.Session()

    ## tensorboard
    summaries_root = FLAGS.summaries_dir
    train_writer = tf.summary.FileWriter(summaries_root + '/train', sess.graph)
    # Created for its side effects only; not referenced again below.
    test_writer = tf.summary.FileWriter(summaries_root + '/test')

    ## checkpointing
    saver = tf.train.Saver()

    ## init op
    sess.run(tf.global_variables_initializer())

    ## restoring
    checkpoint_to_load = FLAGS.load_checkpoint
    if checkpoint_to_load != '':
        saver.restore(sess, checkpoint_to_load)

    ## helper functions for the various modes supported by this application
    mode = FLAGS.mode
    routine = {
        'train': routines.train,
        'eval': routines.evaluate,
        'generate': routines.generate,
        'reconstruct': routines.reconstruct,
        'generate_gif': routines.generate_gif,
        'reconstruct_gif': routines.reconstruct_gif,
    }[mode]

    ## rather than pass around tons of arguments,
    # just use callbacks to perform the required functionality
    if mode == 'train':
        checkpoint_dir = FLAGS.checkpoint_dir
        checkpoint_frequency = FLAGS.checkpoint_frequency
        dataset = train_data
        callbacks = {
            'tensorboard': calls.tensorboard(train_writer),
            'checkpointing': calls.checkpointing(
                sess, saver, checkpoint_dir, checkpoint_frequency),
        }
    elif mode == 'eval':
        dataset = test_data
        callbacks = {}
    else:
        # Every remaining mode writes image output to FLAGS.output_dir.
        output_dir = FLAGS.output_dir
        dataset = test_data
        callbacks = {
            'save_png': calls.save_png(output_dir),
            'save_gif': calls.save_gif(output_dir),
        }

    # All modes share the same call shape; only dataset and callbacks differ.
    routine(dataset, sess, model, callbacks)