Ejemplo n.º 1
0
def main(_):
    """Build the Variational Bayesian Memory model, optionally restore a
    checkpoint, and dispatch to the routine selected by ``FLAGS.mode``
    ('train' | 'eval' | 'generate' | 'copy').
    """

    ## hyperparams
    hps = tf.contrib.training.HParams(img_height=FLAGS.img_height,
                                      img_width=FLAGS.img_width,
                                      img_channels=FLAGS.img_channels,
                                      discrete_outputs=FLAGS.discrete_outputs,
                                      batch_size=FLAGS.batch_size,
                                      episode_len=FLAGS.episode_len,
                                      memory_size=FLAGS.memory_size,
                                      code_size=FLAGS.code_size,
                                      opt_iters=FLAGS.opt_iters,
                                      enc_blocks=FLAGS.enc_blocks,
                                      dec_blocks=FLAGS.dec_blocks,
                                      num_filters=FLAGS.num_filters,
                                      trainable_memory=FLAGS.trainable_memory,
                                      use_bn=FLAGS.use_bn,
                                      use_ddi=FLAGS.use_ddi,
                                      lr=FLAGS.lr,
                                      epochs=FLAGS.epochs)

    ## dataset
    ds_train, ds_val, ds_test = data.get_dataset(name=FLAGS.dataset, hps=hps)

    ## model and session
    model = VBMC.VariationalBayesianMemory(hps)
    sess = tf.Session()

    ## tensorboard
    train_writer = tf.summary.FileWriter(FLAGS.summaries_dir + '/train',
                                         sess.graph)
    # NOTE(review): test_writer is never used below — presumably the '/test'
    # event directory is consumed elsewhere (e.g. via callbacks); confirm
    # before removing.
    test_writer = tf.summary.FileWriter(FLAGS.summaries_dir + '/test')

    ## checkpointing
    saver = tf.train.Saver()

    ## init op
    init_op = tf.global_variables_initializer()
    _ = sess.run(init_op)

    ## restoring
    # BUG FIX: previously `if FLAGS.load_checkpoint != '' and
    # os.path.exists(...)` sent the empty-flag case (no checkpoint requested)
    # into the "does not exist" warning and interactive prompt. Only warn and
    # prompt when a checkpoint was explicitly requested but cannot be found.
    # NOTE(review): os.path.exists on a TF checkpoint *prefix* can be False
    # even when the .index/.data files exist — verify the flag points at an
    # actual path in this project's usage.
    if FLAGS.load_checkpoint != '':
        if os.path.exists(FLAGS.load_checkpoint):
            saver.restore(sess, FLAGS.load_checkpoint)
        else:
            print('load checkpoint "{}" does not exist.'.format(
                FLAGS.load_checkpoint))
            print('continue anyway? [y/N]')
            yn = input('> ')
            if yn.strip().lower() != 'y':
                print('program exiting.')
                return

    ## helper functions for the various modes supported by this application
    mode_to_routine = {
        'train': routines.train,
        'eval': routines.evaluate,
        'generate': routines.generate,
        'copy': routines.copy
    }
    routine = mode_to_routine[FLAGS.mode]

    ## rather than pass around tons of arguments,
    #  just use callbacks to perform the required functionality
    if FLAGS.mode == 'train':
        checkpoint_dir = FLAGS.checkpoint_dir
        callbacks = {
            'tensorboard': calls.tensorboard(train_writer),
            'checkpointing': calls.checkpointing(sess, saver, checkpoint_dir)
        }
        routines.train(ds_train,
                       ds_val,
                       sess,
                       model,
                       callbacks,
                       epochs=hps.epochs)

    elif FLAGS.mode == 'eval':
        routines.evaluate(ds_test, sess, model)

    else:
        output_dir = FLAGS.output_dir
        callbacks = {
            'save_png': calls.save_png(output_dir),
            'save_gif': calls.save_gif(output_dir)
        }
        routine(ds_test, sess, model, callbacks)
Ejemplo n.º 2
0
def main(_):
    """Wire up the DRAW model, session, and writers, then run the routine
    selected by ``FLAGS.mode``.
    """

    ## gather every hyperparameter into a single HParams object
    hps = tf.contrib.training.HParams(
        batch_size=FLAGS.batch_size,
        img_height=FLAGS.img_height,
        img_width=FLAGS.img_width,
        img_channels=FLAGS.img_channels,
        num_timesteps=FLAGS.num_timesteps,
        z_dim=FLAGS.z_dim,
        encoder_hidden_dim=FLAGS.encoder_hidden_dim,
        decoder_hidden_dim=FLAGS.decoder_hidden_dim,
        read_dim=FLAGS.read_dim,
        write_dim=FLAGS.write_dim,
        init_scale=FLAGS.init_scale,
        forget_bias=FLAGS.forget_bias,
        lr=FLAGS.lr,
        epochs=FLAGS.epochs)

    ## data pipelines
    ds_train, ds_test = get_dataset(name=FLAGS.dataset, hps=hps)

    ## build the model and open a session
    net = DRAW(hps)
    session = tf.Session()

    ## tensorboard event writers (train writer also records the graph)
    writer_train = tf.summary.FileWriter(FLAGS.summaries_dir + '/train',
                                         session.graph)
    writer_test = tf.summary.FileWriter(FLAGS.summaries_dir + '/test')

    ## saver for checkpoints
    saver = tf.train.Saver()

    ## initialize all variables
    session.run(tf.global_variables_initializer())

    ## restore weights when a checkpoint path was supplied
    if FLAGS.load_checkpoint != '':
        saver.restore(session, FLAGS.load_checkpoint)

    ## dispatch table: mode name -> routine implementing it
    routine = {
        'train': routines.train,
        'eval': routines.evaluate,
        'generate': routines.generate,
        'reconstruct': routines.reconstruct,
        'generate_gif': routines.generate_gif,
        'reconstruct_gif': routines.reconstruct_gif
    }[FLAGS.mode]

    ## callbacks carry the mode-specific side effects so the routines
    #  themselves stay free of extra arguments
    if FLAGS.mode == 'train':
        cbs = {
            'tensorboard':
            calls.tensorboard(writer_train),
            'checkpointing':
            calls.checkpointing(session, saver, FLAGS.checkpoint_dir,
                                FLAGS.checkpoint_frequency)
        }
        routine(ds_train, session, net, cbs)

    elif FLAGS.mode == 'eval':
        routine(ds_test, session, net, {})

    else:
        cbs = {
            'save_png': calls.save_png(FLAGS.output_dir),
            'save_gif': calls.save_gif(FLAGS.output_dir)
        }
        routine(ds_test, session, net, cbs)