Exemplo n.º 1
0
def _load_gan_train_data(arch):
    """Return the training images for ``arch`` with the held-out split appended.

    The GAN is trained without labels, so the val split is folded into the
    training set.

    Args:
        arch: architecture name; must start with 'simple' or 'clevr'.

    Returns:
        A 4-D image array: train images followed by val images along axis 0.

    Raises:
        Exception: if ``arch`` belongs to neither supported family.
    """
    if arch.startswith('simple'):
        # Context manager so the HDF5 handle is always released; np.array
        # copies the datasets into memory before the file closes.
        with h5py.File('data/rectangle_4_uniform.h5', 'r') as fd:
            train_x = np.array(fd['train_imagegray'],
                               dtype=float) / 255.0  # shape (2368, 64, 64, 1)
            val_x = np.array(fd['val_imagegray'],
                             dtype=float) / 255.0  # shape (768, 64, 64, 1)
        # merged shape (3136, 64, 64, 1)
    elif arch.startswith('clevr'):
        train_x, val_x = load_sort_of_clevr()
        # merged shape (50000, 64, 64, 3)
    else:
        raise Exception('Unknown network architecture: %s' % arch)
    return np.concatenate((train_x, val_x), axis=0)


def _make_gan_optimizers(args, lr_gen):
    """Build the (discriminator, generator) optimizer pair selected by ``args.opt``.

    The discriminator uses ``args.lr``; the generator uses ``lr_gen``.

    Raises:
        ValueError: if ``args.opt`` is not 'sgd', 'rmsprop' or 'adam'.
    """
    if args.opt == 'sgd':
        return (tf.train.MomentumOptimizer(args.lr, args.mom),
                tf.train.MomentumOptimizer(lr_gen, args.mom))
    if args.opt == 'rmsprop':
        return (tf.train.RMSPropOptimizer(args.lr, momentum=args.mom),
                tf.train.RMSPropOptimizer(lr_gen, momentum=args.mom))
    if args.opt == 'adam':
        return (tf.train.AdamOptimizer(args.lr, args.beta1, args.beta2),
                tf.train.AdamOptimizer(lr_gen, args.beta1, args.beta2))
    raise ValueError('Unknown optimizer: %s' % args.opt)


def _is_train_log_iter(it, train_iters):
    """Return True if global iteration ``it`` should be logged.

    During the first epoch, logging happens on iteration 0 and on powers of
    two. After that, it happens once per epoch (every ``train_iters``
    iterations).
    """
    if it < train_iters:
        return it & (it - 1) == 0
    return it % train_iters == 0


def _generate_fake_images(sess, model, noise, chunk_size=100):
    """Run the generator on ``noise`` with learning_phase() = 0.

    The noise rows are fed ``chunk_size`` at a time because coordconv cannot
    handle too big of a batch. The per-chunk outputs are concatenated once
    along axis 0.
    """
    chunks = [
        sess.run(model.fake_images,
                 feed_dict={
                     model.input_noise: noise[start:start + chunk_size],
                     learning_phase(): 0
                 }) for start in range(0, noise.shape[0], chunk_size)
    ]
    return np.concatenate(chunks, axis=0)


def main():
    """Train a GAN on simple square images or Clevr two-object color images.

    Parses command-line flags, builds the model and one optimizer per player,
    then loops for ``args.epochs`` epochs plus one extra pass. Every pass:

    * renders generator samples from a fixed noise batch (learning_phase 0
      and 1) plus 1000 fresh samples for an averaged image;
    * optionally snapshots the model, stats and samples;
    * optionally re-evaluates the whole training set in both learning phases.

    Every pass except the last then trains over the training set, taking one
    discriminator step and two generator steps per minibatch.
    """
    parser = make_standard_parser(
        'Train a GAN model on simple square images or Clevr two-object color images',
        arch_choices=arch_choices,
        skip_train=True,
        skip_val=True)
    parser.add_argument('--z_dim',
                        type=int,
                        default=10,
                        help='Dimension of noise vector')
    parser.add_argument('--lr2',
                        type=float,
                        default=None,
                        help='learning rate for generator')
    parser.add_argument('--feature_match',
                        '-fm',
                        action='store_true',
                        help='use feature matching loss for generator.')
    parser.add_argument(
        '--feature_match_loss_weight',
        '-fmalpha',
        type=float,
        default=1.0,
        help='weight on the feature matching loss for generator.')
    parser.add_argument(
        '--pairedz',
        action='store_true',
        help='If True, pair the same z with a training batch each epoch')
    parser.add_argument(
        '--eval-train-every',
        type=int,
        default=0,
        help='evaluate whole training set every N epochs. 0 to disable.')

    args = parser.parse_args()

    # Validation is never run: the val split is merged into the training set
    # below. The flag stays on args because build_model receives args and
    # may consult it.
    args.skipval = True

    minibatch_size = args.minibatch
    if args.nocolor:
        train_style, val_style = '', ''
    else:
        train_style, val_style = colorama.Fore.BLUE, colorama.Fore.MAGENTA
    evaltrain_style = ('' if args.nocolor or args.eval_train_every <= 0 else
                       colorama.Fore.CYAN)

    # Saved sample grids for clevr archs are written with black_divider=True.
    black_divider = args.arch.startswith('clevr')

    # Get a TF session and set numpy and TF seeds
    sess = setup_session_and_seeds(args.seed, assert_gpu=not args.cpu)

    # 0. LOAD DATA
    train_x = _load_gan_train_data(args.arch)
    print('Train data loaded: {} images, size {}'.format(
        train_x.shape[0], train_x.shape[1:]))

    # 1. CREATE MODEL
    assert len(train_x.shape) == 4, "image data must be of 4 dimensions"
    image_h, image_w, image_c = train_x.shape[1:4]

    model = build_model(args, image_h, image_w, image_c)

    print('All model weights:')
    summarize_weights(model.trainable_weights)
    print('Model summary:')
    # model.summary()      # TOREPLACE
    print('Another model summary:')
    model.summarize_named(prefix='  ')
    print_trainable_warnings(model)

    # Named tensors the model exposes; decides which placeholders get fed.
    model_keys = set(model.named_keys())

    # The second real batch (input_images2, used for feature matching) is
    # drawn from a second shuffled ordering, which only exists with
    # shuffletrain. Fail early instead of with a NameError mid-training.
    if ((args.feature_match or 'input_images2' in model_keys) and
            not args.shuffletrain):
        raise ValueError('feature matching loss requires shuffle train')

    # 2. COMPUTE GRADS AND CREATE OPTIMIZER
    # The generator may use its own learning rate (--lr2); default is --lr.
    lr_gen = args.lr2 if args.lr2 else args.lr
    d_opt, g_opt = _make_gan_optimizers(args, lr_gen)

    # Split trainable params between the two players by variable name.
    all_vars = model.trainable_variables
    d_vars = [var for var in all_vars if 'discriminator' in var.name]
    g_vars = [var for var in all_vars if 'generator' in var.name]

    d_grads_and_vars = d_opt.compute_gradients(
        model.d_loss, d_vars, gate_gradients=tf.train.Optimizer.GATE_GRAPH)
    d_train_step = d_opt.apply_gradients(d_grads_and_vars)
    g_grads_and_vars = g_opt.compute_gradients(
        model.g_loss, g_vars, gate_gradients=tf.train.Optimizer.GATE_GRAPH)
    g_train_step = g_opt.apply_gradients(g_grads_and_vars)

    hist_summaries_traintest(model.d_real_logits, model.d_fake_logits)

    add_grads_and_vars_hist_summaries(d_grads_and_vars)
    add_grads_and_vars_hist_summaries(g_grads_and_vars)
    image_summaries_traintest(model.fake_images)

    # 3. OPTIONALLY SAVE OR LOAD VARIABLES (e.g. model params, model running
    # BN means, optimization momentum, ...) and then finalize initialization
    saver = tf.train.Saver(
        max_to_keep=None) if (args.output or args.load) else None

    # Initialise first. Running the initializer after saver.restore would
    # overwrite every restored value.
    tf.global_variables_initializer().run()

    if args.load:
        # --load is '<checkpoint>:<misc pickle>' (colon separated).
        ckptfile, miscfile = args.load.split(':')
        # Restore values directly to graph
        saver.restore(sess, ckptfile)
        # NOTE: unpickling can execute arbitrary code; only load trusted
        # snapshots.
        with gzip.open(miscfile) as ff:
            buddy = pickle.load(ff)['buddy']
    elif args.eval_train_every > 0:
        buddy = StatsBuddy(pretty_replaces=[('evaltrain_', ''), ('eval', '')])
    else:
        buddy = StatsBuddy()

    buddy.tic()  # call if new run OR resumed run

    # 4. SETUP TENSORBOARD LOGGING
    param_histogram_summaries = get_collection_intersection_summary(
        'param_collection', 'orig_histogram')
    train_histogram_summaries = get_collection_intersection_summary(
        'train_collection', 'orig_histogram')
    train_scalar_summaries = get_collection_intersection_summary(
        'train_collection', 'orig_scalar')
    train_image_summaries = get_collection_intersection_summary(
        'train_collection', 'orig_image')

    writer = None
    if args.output:
        mkdir_p(args.output)
        writer = tf.summary.FileWriter(args.output, sess.graph)

    # 5. TRAIN
    train_iters = train_x.shape[0] // minibatch_size

    if args.ipy:
        print('Embed: before train / val loop (Ctrl-D to continue)')
        embed()

    # Fixed noise batch, drawn once so every epoch renders the same 100 z's.
    np.random.seed()
    eval_batch_size = 100
    eval_z = np.random.uniform(-1, 1, size=(eval_batch_size, args.z_dim))

    # args.epochs training epochs plus one final evaluation-only pass.
    while buddy.epoch < args.epochs + 1:
        # 0. Log params
        if args.output and param_histogram_summaries is not None:
            params_summary_str, = sess.run([param_histogram_summaries])
            writer.add_summary(params_summary_str, buddy.train_iter)

        # 1. Render generator samples: the fixed eval_z batch with
        #    learning_phase 0 ("bn0") and 1 ("bn1"), plus 1000 fresh samples
        #    that are averaged into one image.
        np.random.seed()
        eval_more = np.random.uniform(-1, 1, size=(1000, args.z_dim))
        feed_dict2 = {model.input_noise: eval_z, learning_phase(): 0}
        eval_samples_bn0 = sess.run(model.fake_images, feed_dict=feed_dict2)
        feed_dict2[learning_phase()] = 1
        eval_samples_bn1 = sess.run(model.fake_images, feed_dict=feed_dict2)
        eval_more_samples = _generate_fake_images(sess, model, eval_more, 100)

        if args.output:
            fake_dir = '{}/fake_images'.format(args.output)
            mkdir_p(fake_dir)
            save_images(eval_samples_bn0, [10, 10],
                        '{}/g_out_bn0_epoch_{}_iter_{}.png'.format(
                            fake_dir, buddy.epoch, buddy.train_iter),
                        black_divider=black_divider)
            save_images(eval_samples_bn1, [10, 10],
                        '{}/g_out_bn1_epoch_{}.png'.format(
                            fake_dir, buddy.epoch),
                        black_divider=black_divider)
            save_average_image(
                eval_more_samples,
                '{}/g_out_averaged_epoch_{}_iter_{}.png'.format(
                    fake_dir, buddy.epoch, buddy.train_iter))

        # 2. Possibly snapshot the model, stats and sampled images
        if args.output and args.snapshot_to and args.snapshot_every:
            snap_intermed = (args.snapshot_every > 0 and
                             buddy.train_iter % args.snapshot_every == 0)
            snap_end = buddy.epoch == args.epochs
            if snap_intermed or snap_end:
                save_path = saver.save(
                    sess, '%s/%s_%04d.ckpt' %
                    (args.output, args.snapshot_to, buddy.epoch))
                print('snappshotted model to %s' % save_path)
                misc_path = '%s/%s_misc_%04d.pkl.gz' % (
                    args.output, args.snapshot_to, buddy.epoch)
                with gzip.open(misc_path, 'w') as ff:
                    pickle.dump({'buddy': buddy}, ff)
                # snapshot sampled images too
                samples_path = '%s/sampled_images_%04d.h5' % (args.output,
                                                             buddy.epoch)
                with h5py.File(samples_path, 'w') as ff:
                    ff.create_dataset('eval_samples_bn0', data=eval_samples_bn0)
                    ff.create_dataset('eval_samples_bn1', data=eval_samples_bn1)
                    ff.create_dataset('eval_z', data=eval_z)
                    ff.create_dataset('eval_z_more', data=eval_more)
                    ff.create_dataset('eval_more_samples',
                                      data=eval_more_samples)

        # 3. Possibly evaluate the whole training set in both learning phases
        if (args.eval_train_every > 0 and
                buddy.epoch % args.eval_train_every == 0):
            tic2()
            for ii in range(train_iters):
                start_idx = ii * minibatch_size
                batch_slice = slice(start_idx, start_idx + minibatch_size)
                if args.pairedz:
                    np.random.seed(args.seed + ii)
                else:
                    np.random.seed()
                batch_z = np.random.uniform(-1,
                                            1,
                                            size=(minibatch_size, args.z_dim))
                batch_x = train_x[batch_slice]

                fetch_dict = model.trackable_dict()
                for phase in (0, 1):
                    feed_dict = {
                        model.input_images: batch_x,
                        model.input_noise: batch_z,
                        learning_phase(): phase,
                    }
                    if 'input_labels' in model_keys:
                        # train_y is not defined in this function; presumably
                        # a module-level global for label-conditioned models.
                        feed_dict[model.input_labels] = train_y[batch_slice]
                    result_eval_train = sess_run_dict(sess,
                                                      fetch_dict,
                                                      feed_dict=feed_dict)
                    buddy.note_weighted_list(
                        batch_x.shape[0],
                        model.trackable_names(),
                        [
                            result_eval_train[k]
                            for k in model.trackable_names()
                        ],
                        prefix='evaltrain_bn%d_' % phase)

                if args.output:
                    for phase in (0, 1):
                        log_scalars(writer,
                                    buddy.train_iter, {
                                        'batch_%s' % name: value
                                        for name, value in buddy.last_list_re(
                                            '^evaltrain_bn%d_' % phase)
                                    },
                                    prefix='buddy')
            if args.output:
                for phase in (0, 1):
                    log_scalars(writer,
                                buddy.epoch, {
                                    'mean_%s' % name: value
                                    for name, value in buddy.epoch_mean_list_re(
                                        '^evaltrain_bn%d_' % phase)
                                },
                                prefix='buddy')

            for phase in (0, 1):
                print('%3d (ep %d) evaltrain: %s (%.3gs/ep)' %
                      (buddy.train_iter, buddy.epoch,
                       buddy.epoch_mean_pretty_re('^evaltrain_bn%d_' % phase,
                                                  style=evaltrain_style),
                       toc2()))

        if buddy.epoch == args.epochs:
            if args.ipy:
                print('Embed: at end of training (Ctrl-D to exit)')
                embed()
            break  # Extra pass at end: just report stats and skip training

        # 4. Train on training set
        if args.shuffletrain:
            train_order = np.random.permutation(train_x.shape[0])
            train_order2 = np.random.permutation(train_x.shape[0])
        tic3()
        for ii in range(train_iters):
            tic2()
            start_idx = ii * minibatch_size
            batch_slice = slice(start_idx, start_idx + minibatch_size)
            if args.pairedz:
                np.random.seed(args.seed + ii)
            else:
                np.random.seed()

            batch_z = np.random.uniform(-1,
                                        1,
                                        size=(minibatch_size, args.z_dim))

            if args.shuffletrain:
                batch_idx = sorted(train_order[batch_slice].tolist())
                batch_x = train_x[batch_idx]
                if 'input_images2' in model_keys:
                    # Independent second real batch for feature matching.
                    batch_x2 = train_x[sorted(
                        train_order2[batch_slice].tolist())]
                if 'input_labels' in model_keys:
                    batch_y = train_y[batch_idx]
            else:
                batch_x = train_x[batch_slice]
                if 'input_labels' in model_keys:
                    batch_y = train_y[batch_slice]

            feed_dict = {
                model.input_images: batch_x,
                model.input_noise: batch_z,
                learning_phase(): 1,
            }
            if 'input_labels' in model_keys:
                feed_dict[model.input_labels] = batch_y
            if 'input_images2' in model_keys:
                feed_dict[model.input_images2] = batch_x2

            # buddy.train_iter does not change until inc_train_iter below.
            log_this_iter = _is_train_log_iter(buddy.train_iter, train_iters)

            fetch_dict = model.trackable_and_update_dict()
            if args.output and log_this_iter:
                if train_histogram_summaries is not None:
                    fetch_dict['train_histogram_summaries'] = (
                        train_histogram_summaries)
                if train_scalar_summaries is not None:
                    fetch_dict['train_scalar_summaries'] = train_scalar_summaries
                if train_image_summaries is not None:
                    fetch_dict['train_image_summaries'] = train_image_summaries

            with WithTimer('sess.run train iter', quiet=not args.verbose):
                result_train = sess_run_dict(sess,
                                             fetch_dict,
                                             feed_dict=feed_dict)
                # One discriminator step, then two generator steps.
                sess.run(d_train_step, feed_dict=feed_dict)
                sess.run(g_train_step, feed_dict=feed_dict)
                sess.run(g_train_step, feed_dict=feed_dict)

            if log_this_iter:
                buddy.note_weighted_list(
                    batch_x.shape[0],
                    model.trackable_names(),
                    [result_train[k] for k in model.trackable_names()],
                    prefix='train_')
                print('[%5d] [%2d/%2d] train: %s (%.3gs/i)' %
                      (buddy.train_iter, buddy.epoch, args.epochs,
                       buddy.epoch_mean_pretty_re('^train_',
                                                  style=train_style), toc2()))

                if args.output:
                    if train_histogram_summaries is not None:
                        writer.add_summary(
                            result_train['train_histogram_summaries'],
                            buddy.train_iter)
                    if train_scalar_summaries is not None:
                        writer.add_summary(
                            result_train['train_scalar_summaries'],
                            buddy.train_iter)
                    if train_image_summaries is not None:
                        writer.add_summary(
                            result_train['train_image_summaries'],
                            buddy.train_iter)
                    log_scalars(
                        writer,
                        buddy.train_iter, {
                            'batch_%s' % name: value
                            for name, value in buddy.last_list_re('^train_')
                        },
                        prefix='buddy')

            if ii > 0 and ii % 100 == 0:
                print('  %d: Average iteration time over last 100 train iters: %.3gs'
                      % (ii, toc3() / 100))
                tic3()

            buddy.inc_train_iter()  # after finished training a mini-batch

        buddy.inc_epoch()  # after finished training whole pass through set

        if args.output and _is_train_log_iter(buddy.train_iter, train_iters):
            log_scalars(
                writer,
                buddy.train_iter, {
                    'mean_%s' % name: value
                    for name, value in buddy.epoch_mean_list_re('^train_')
                },
                prefix='buddy')

    print('\nFinal')
    # No val_ stats are ever recorded (validation is disabled); the line is
    # kept so the output format stays stable.
    print('%02d:%d val:   %s' % (buddy.epoch, buddy.train_iter,
                                 buddy.epoch_mean_pretty_re('^val_',
                                                            style=val_style)))
    print('%02d:%d train: %s' % (buddy.epoch, buddy.train_iter,
                                 buddy.epoch_mean_pretty_re('^train_',
                                                            style=train_style)))

    print('\nfinal_stats epochs %g' % buddy.epoch)
    print('final_stats iters %g' % buddy.train_iter)
    print('final_stats time %g' % buddy.toc())
    for name, value in buddy.epoch_mean_list_all():
        print('final_stats %s %g' % (name, value))

    if writer is not None:
        writer.close()  # Flush and close
Exemplo n.º 2
0
def main():
    parser = make_standard_parser(
        'Coordconv',
        arch_choices=arch_choices,
        skip_train=True,
        skip_val=True)
    # re-add train and val h5s as optional
    parser.add_argument('--data_h5', type=str,
                        default='./data/rectangle_4_uniform.h5',
                        help='data file in hdf5.')
    parser.add_argument('--x_dim', type=int, default=64,
                        help='x dimension of the output image')
    parser.add_argument('--y_dim', type=int, default=64,
                        help='y dimension of the output image')
    parser.add_argument('--lrpolicy', type=str, default='constant',
                        choices=lr_policy_choices, help='LR policy.')
    parser.add_argument('--lrstepratio', type=float,
                        default=.1, help='LR policy step ratio.')
    parser.add_argument('--lrmaxsteps', type=int, default=5,
                        help='LR policy step ratio.')
    parser.add_argument('--lrstepevery', type=int, default=50,
                        help='LR policy step ratio.')
    parser.add_argument('--filter_size', '-fs', type=int, default=3,
                        help='filter size in deconv network')
    parser.add_argument('--channel_mul', '-mul', type=int, default=2,
        help='Deconv model channel multiplier to make bigger models')
    parser.add_argument('--use_mse_loss', '-mse', action='store_true',
                        help='use mse loss instead of cross entropy')
    parser.add_argument('--use_sigm_loss', '-sig', action='store_true',
                        help='use sigmoid loss instead of cross entropy')
    parser.add_argument('--interm_loss', '-interm', default=None,
        choices=(None, 'softmax', 'mse'),
        help='add intermediate loss to end-to-end painter model')
    parser.add_argument('--no_softmax', '-nosfmx', action='store_true',
                        help='Remove softmax sharpening layer in model')

    args = parser.parse_args()

    if args.lrpolicy == 'step':
        lr_policy = LRPolicyStep(args)
    elif args.lrpolicy == 'valstep':
        lr_policy = LRPolicyValStep(args)
    else:
        lr_policy = LRPolicyConstant(args)

    minibatch_size = args.minibatch
    train_style, val_style = (
        '', '') if args.nocolor else (
        colorama.Fore.BLUE, colorama.Fore.MAGENTA)

    sess = setup_session_and_seeds(args.seed, assert_gpu=not args.cpu)

    # 0. Load data or generate data on the fly
    print 'Loading data: {}'.format(args.data_h5)

    if args.arch in ['deconv_classification',
                     'coordconv_classification',
                     'upsample_conv_coords',
                     'upsample_coordconv_coords']:

        # option a: generate data on the fly
        #data = list(itertools.product(range(args.x_dim),range(args.y_dim)))
        # random.shuffle(data)

        #train_test_split = .8
        #val_reps = int(args.x_dim * args.x_dim * train_test_split) // minibatch_size
        #val_size = val_reps * minibatch_size
        #train_end = args.x_dim * args.x_dim - val_size
        #train_x, val_x = np.array(data[:train_end]).astype('int'), np.array(data[train_end:]).astype('int')
        #train_y, val_y = None, None
        #DATA_GEN_ON_THE_FLY = True

        # option b: load the data
        fd = h5py.File(args.data_h5, 'r')

        train_x = np.array(fd['train_locations'], dtype=int)  # shape (2368, 2)
        train_y = np.array(fd['train_onehots'], dtype=float)  # shape (2368, 64, 64, 1)
        val_x = np.array(fd['val_locations'], dtype=float)  # shape (768, 2)
        val_y = np.array(fd['val_onehots'], dtype=float)  # shape (768, 64, 64, 1)
        DATA_GEN_ON_THE_FLY = False

        # number of image channels
        image_c = train_y.shape[-1] if train_y is not None and len(train_y.shape) == 4 else 1

    elif args.arch == 'conv_onehot_image':
        fd = h5py.File(args.data_h5, 'r')
        train_x = np.array(
            fd['train_onehots'],
            dtype=int)  # shape (2368, 64, 64, 1)
        train_y = np.array(fd['train_imagegray'],
                           dtype=float) / 255.0  # shape (2368, 64, 64, 1)
        val_x = np.array(
            fd['val_onehots'],
            dtype=float)  # shape (768, 64, 64, 1)
        val_y = np.array(fd['val_imagegray'], dtype=float) / \
            255.0  # shape (768, 64, 64, 1)

        image_c = train_y.shape[-1]

    elif args.arch == 'deconv_rendering':
        fd = h5py.File(args.data_h5, 'r')
        train_x = np.array(fd['train_locations'], dtype=int)  # shape (2368, 2)
        train_y = np.array(fd['train_imagegray'],
                           dtype=float) / 255.0  # shape (2368, 64, 64, 1)
        val_x = np.array(fd['val_locations'], dtype=float)  # shape (768, 2)
        val_y = np.array(fd['val_imagegray'], dtype=float) / \
            255.0  # shape (768, 64, 64, 1)

        image_c = train_y.shape[-1]

    elif args.arch == 'conv_regressor' or args.arch == 'coordconv_regressor':
        fd = h5py.File(args.data_h5, 'r')
        train_y = np.array(
            fd['train_normalized_locations'],
            dtype=float)  # shape (2368, 2)
        # /255.0 # shape (2368, 64, 64, 1)
        train_x = np.array(fd['train_onehots'], dtype=float)
        val_y = np.array(
            fd['val_normalized_locations'],
            dtype=float)  # shape (768, 2)
        val_x = np.array(
            fd['val_onehots'],
            dtype=float)  # shape (768, 64, 64, 1)

        image_c = train_x.shape[-1]

    elif args.arch == 'coordconv_rendering' or args.arch == 'deconv_bottleneck':
        fd = h5py.File(args.data_h5, 'r')
        train_x = np.array(fd['train_locations'], dtype=int)  # shape (2368, 2)
        train_y = np.array(fd['train_imagegray'],
                           dtype=float) / 255.0  # shape (2368, 64, 64, 1)
        val_x = np.array(fd['val_locations'], dtype=float)  # shape (768, 2)
        val_y = np.array(fd['val_imagegray'], dtype=float) / 255.0  # shape (768, 64, 64, 1)

        # add one-hot anyways to track accuracy etc. even if not used in loss
        train_onehot = np.array(
            fd['train_onehots'],
            dtype=int)  # shape (2368, 64, 64, 1)
        val_onehot = np.array(
            fd['val_onehots'],
            dtype=int)  # shape (768, 64, 64, 1)

        image_c = train_y.shape[-1]

    train_size = train_x.shape[0]
    val_size = val_x.shape[0]

    # 1. CREATE MODEL
    input_coords = tf.placeholder(
        shape=(None,2),
        dtype='float32',
        name='input_coords')  # cast later in model into float
    input_onehot = tf.placeholder(
        shape=(None, args.x_dim, args.y_dim, 1),
        dtype='float32',
        name='input_onehot')
    input_images = tf.placeholder(
        shape=(None, args.x_dim, args.y_dim, image_c),
        dtype='float32',
        name='input_images')

    if args.arch == 'deconv_classification':
        model = DeconvPainter(l2=args.l2, x_dim=args.x_dim, y_dim=args.y_dim,
                              fs=args.filter_size, mul=args.channel_mul,
                              onthefly=DATA_GEN_ON_THE_FLY,
                              use_mse_loss=args.use_mse_loss,
                              use_sigm_loss=args.use_sigm_loss)

        model.a('input_coords', input_coords)

        if not DATA_GEN_ON_THE_FLY:
            model.a('input_onehot', input_onehot)

        model([input_coords]) if DATA_GEN_ON_THE_FLY else model([input_coords, input_onehot])

    if args.arch == 'conv_regressor':
        regress_type = 'conv_uniform' if 'uniform' in args.data_h5 else 'conv_quarant'
        model = ConvRegressor(l2=args.l2, mul=args.channel_mul,
                              _type=regress_type)
        model.a('input_coords', input_coords)
        model.a('input_onehot', input_onehot)
        # call model on inputs
        model([input_onehot, input_coords])

    if args.arch == 'coordconv_regressor':
        model = ConvRegressor(l2=args.l2, mul=args.channel_mul,
                              _type='coordconv')
        model.a('input_coords', input_coords)
        model.a('input_onehot', input_onehot)
        # call model on inputs
        model([input_onehot, input_coords])

    if args.arch == 'conv_onehot_image':
        model = ConvImagePainter(l2=args.l2, fs=args.filter_size, mul=args.channel_mul,
            use_mse_loss=args.use_mse_loss, use_sigm_loss=args.use_sigm_loss,
            version='working')
            # version='simple') # version='simple' to hack a 9x9 all-ones filter solution
        model.a('input_onehot', input_onehot)
        model.a('input_images', input_images)
        # call model on inputs
        model([input_onehot, input_images])

    if args.arch == 'deconv_rendering':
        model = DeconvPainter(l2=args.l2, x_dim=args.x_dim, y_dim=args.y_dim,
                              fs=args.filter_size, mul=args.channel_mul,
                              onthefly=False,
                              use_mse_loss=args.use_mse_loss,
                              use_sigm_loss=args.use_sigm_loss)
        model.a('input_coords', input_coords)
        model.a('input_images', input_images)
        # call model on inputs
        model([input_coords, input_images])

    elif args.arch == 'coordconv_classification':
        model = CoordConvPainter(
            l2=args.l2,
            x_dim=args.x_dim,
            y_dim=args.y_dim,
            include_r=False,
            mul=args.channel_mul,
            use_mse_loss=args.use_mse_loss,
            use_sigm_loss=args.use_sigm_loss)

        model.a('input_coords', input_coords)
        model.a('input_onehot', input_onehot)

        model([input_coords, input_onehot])
        #raise Exception('Not implemented yet')

    elif args.arch == 'coordconv_rendering':
        model = CoordConvImagePainter(
            l2=args.l2,
            x_dim=args.x_dim,
            y_dim=args.y_dim,
            include_r=False,
            mul=args.channel_mul,
            fs=args.filter_size,
            use_mse_loss=args.use_mse_loss,
            use_sigm_loss=args.use_sigm_loss,
            interm_loss=args.interm_loss,
            no_softmax=args.no_softmax,
            version='working')
        # version='simple') # version='simple' to hack a 9x9 all-ones filter solution
        model.a('input_coords', input_coords)
        model.a('input_onehot', input_onehot)
        model.a('input_images', input_images)

        # always input three things to calculate relevant metrics
        model([input_coords, input_onehot, input_images])
    elif args.arch == 'deconv_bottleneck':
        model = DeconvBottleneckPainter(
            l2=args.l2,
            x_dim=args.x_dim,
            y_dim=args.y_dim,
            mul=args.channel_mul,
            fs=args.filter_size,
            use_mse_loss=args.use_mse_loss,
            use_sigm_loss=args.use_sigm_loss,
            interm_loss=args.interm_loss,
            no_softmax=args.no_softmax,
            version='working')  # version='simple' to hack a 9x9 all-ones filter solution
        model.a('input_coords', input_coords)
        model.a('input_onehot', input_onehot)
        model.a('input_images', input_images)

        # always input three things to calculate relevant metrics
        model([input_coords, input_onehot, input_images])

    elif args.arch == 'upsample_conv_coords' or args.arch == 'upsample_coordconv_coords':
        _coordconv = True if args.arch == 'upsample_coordconv_coords' else False
        model = UpsampleConvPainter(
            l2=args.l2,
            x_dim=args.x_dim,
            y_dim=args.y_dim,
            mul=args.channel_mul,
            fs=args.filter_size,
            use_mse_loss=args.use_mse_loss,
            use_sigm_loss=args.use_sigm_loss,
            coordconv=_coordconv)
        model.a('input_coords', input_coords)
        model.a('input_onehot', input_onehot)
        model([input_coords, input_onehot])

    print 'All model weights:'
    summarize_weights(model.trainable_weights)
    #print 'Model summary:'
    print 'Another model summary:'
    model.summarize_named(prefix='  ')
    print_trainable_warnings(model)

    # 2. COMPUTE GRADS AND CREATE OPTIMIZER
    # a placeholder for dynamic learning rate
    input_lr = tf.placeholder(tf.float32, shape=[])
    if args.opt == 'sgd':
        opt = tf.train.MomentumOptimizer(input_lr, args.mom)
    elif args.opt == 'rmsprop':
        opt = tf.train.RMSPropOptimizer(input_lr, momentum=args.mom)
    elif args.opt == 'adam':
        opt = tf.train.AdamOptimizer(input_lr, args.beta1, args.beta2)

    grads_and_vars = opt.compute_gradients(
        model.loss,
        model.trainable_weights,
        gate_gradients=tf.train.Optimizer.GATE_GRAPH)
    train_step = opt.apply_gradients(grads_and_vars)
    # added to train_ and param_ collections
    add_grads_and_vars_hist_summaries(grads_and_vars)

    summarize_opt(opt)
    print 'LR Policy:', lr_policy

    # add_grad_summaries(grads_and_vars)
    if not args.arch.endswith('regressor'):
        image_summaries_traintest(model.logits)

    if 'input_onehot' in model.named_keys():
        image_summaries_traintest(model.input_onehot)
    if 'input_images' in model.named_keys():
        image_summaries_traintest(model.input_images)
    if 'prob' in model.named_keys():
        image_summaries_traintest(model.prob)
    if 'center_prob' in model.named_keys():
        image_summaries_traintest(model.center_prob)
    if 'center_logits' in model.named_keys():
        image_summaries_traintest(model.center_logits)
    if 'pixelwise_prob' in model.named_keys():
        image_summaries_traintest(model.pixelwise_prob)
    if 'center_logits' in model.named_keys():
        image_summaries_traintest(model.center_logits)
    if 'sharpened_logits' in model.named_keys():
        image_summaries_traintest(model.sharpened_logits)

    # 3. OPTIONALLY SAVE OR LOAD VARIABLES (e.g. model params, model running
    # BN means, optimization momentum, ...) and then finalize initialization
    saver = tf.train.Saver(
        max_to_keep=None) if (
        args.output or args.load) else None
    if args.load:
        ckptfile, miscfile = args.load.split(':')
        # Restore values directly to graph
        saver.restore(sess, ckptfile)
        with gzip.open(miscfile) as ff:
            saved = pickle.load(ff)
            buddy = saved['buddy']
    else:
        buddy = StatsBuddy()

    buddy.tic()    # call if new run OR resumed run

    # Check if special layers are initialized right
    #last_layer_w = [var for var in tf.global_variables() if 'painting_layer/kernel:0' in var.name][0]
    #last_layer_b = [var for var in tf.global_variables() if 'painting_layer/bias:0' in var.name][0]

    # Initialize any missed vars (e.g. optimization momentum, ... if not
    # loaded from checkpoint)
    uninitialized_vars = tf_get_uninitialized_variables(sess)
    init_missed_vars = tf.variables_initializer(
        uninitialized_vars, 'init_missed_vars')
    sess.run(init_missed_vars)
    # Print warnings about any TF vs. Keras shape mismatches
    # warn_misaligned_shapes(model)
    # Make sure all variables, which are model variables, have been
    # initialized (e.g. model params and model running BN means)
    tf_assert_all_init(sess)
    # tf.global_variables_initializer().run()

    # 4. SETUP TENSORBOARD LOGGING with tf.summary.merge

    train_histogram_summaries = get_collection_intersection_summary(
        'train_collection', 'orig_histogram')
    train_scalar_summaries = get_collection_intersection_summary(
        'train_collection', 'orig_scalar')
    test_histogram_summaries = get_collection_intersection_summary(
        'test_collection', 'orig_histogram')
    test_scalar_summaries = get_collection_intersection_summary(
        'test_collection', 'orig_scalar')
    param_histogram_summaries = get_collection_intersection_summary(
        'param_collection', 'orig_histogram')
    train_image_summaries = get_collection_intersection_summary(
        'train_collection', 'orig_image')
    test_image_summaries = get_collection_intersection_summary(
        'test_collection', 'orig_image')

    writer = None
    if args.output:
        mkdir_p(args.output)
        writer = tf.summary.FileWriter(args.output, sess.graph)

    # 5. TRAIN

    train_iters = (train_size) // minibatch_size + \
        int(train_size % minibatch_size > 0)
    if not args.skipval:
        val_iters = (val_size) // minibatch_size + \
            int(val_size % minibatch_size > 0)

    if args.ipy:
        print 'Embed: before train / val loop (Ctrl-D to continue)'
        embed()

    while buddy.epoch < args.epochs + 1:
        # How often to log data
        def do_log_params(ep, it, ii): return True
        def do_log_val(ep, it, ii): return True

        def do_log_train(
            ep,
            it,
            ii): return (
            it < train_iters and it & it -
            1 == 0 or it >= train_iters and it %
            train_iters == 0)  # Log on powers of two then every epoch

        # 0. Log params
        if args.output and do_log_params(
                buddy.epoch,
                buddy.train_iter,
                0) and param_histogram_summaries is not None:
            params_summary_str, = sess.run([param_histogram_summaries])
            writer.add_summary(params_summary_str, buddy.train_iter)

        # 1. Forward test on validation set
        if not args.skipval:
            feed_dict = {learning_phase(): 0}
            if 'input_coords' in model.named_keys():
                val_coords = val_y if args.arch.endswith(
                    'regressor') else val_x
                feed_dict.update({model.input_coords: val_coords})

            if 'input_onehot' in model.named_keys():
                # if 'val_onehot' not in locals():
                if not args.arch == 'coordconv_rendering' and not args.arch == 'deconv_bottleneck':
                    if args.arch == 'conv_onehot_image' or args.arch.endswith('regressor'):
                        val_onehot = val_x
                    else:
                        val_onehot = val_y
                feed_dict.update({
                    model.input_onehot: val_onehot,
                })
            if 'input_images' in model.named_keys():
                feed_dict.update({
                    model.input_images: val_images,
                })

            fetch_dict = model.trackable_dict()

            if args.output and do_log_val(buddy.epoch, buddy.train_iter, 0):
                if test_image_summaries is not None:
                    fetch_dict.update(
                        {'test_image_summaries': test_image_summaries})
                if test_scalar_summaries is not None:
                    fetch_dict.update(
                        {'test_scalar_summaries': test_scalar_summaries})
                if test_histogram_summaries is not None:
                    fetch_dict.update(
                        {'test_histogram_summaries': test_histogram_summaries})

            with WithTimer('sess.run val iter', quiet=not args.verbose):
                result_val = sess_run_dict(
                    sess, fetch_dict, feed_dict=feed_dict)

            buddy.note_list(
                model.trackable_names(), [
                    result_val[k] for k in model.trackable_names()], prefix='val_')
            print (
                '[%5d] [%2d/%2d] val: %s (%.3gs/i)' %
                (buddy.train_iter,
                 buddy.epoch,
                 args.epochs,
                 buddy.epoch_mean_pretty_re(
                     '^val_',
                     style=val_style),
                    toc2()))

            if args.output and do_log_val(buddy.epoch, buddy.train_iter, 0):
                log_scalars(
                    writer, buddy.train_iter, {
                        'mean_%s' %
                        name: value for name, value in buddy.epoch_mean_list_re('^val_')}, prefix='val')
                if test_image_summaries is not None:
                    image_summary_str = result_val['test_image_summaries']
                    writer.add_summary(image_summary_str, buddy.train_iter)
                if test_scalar_summaries is not None:
                    scalar_summary_str = result_val['test_scalar_summaries']
                    writer.add_summary(scalar_summary_str, buddy.train_iter)
                if test_histogram_summaries is not None:
                    hist_summary_str = result_val['test_histogram_summaries']
                    writer.add_summary(hist_summary_str, buddy.train_iter)

        # 2. Possiby Snapshot, possibly quit
        if args.output and args.snapshot_to and args.snapshot_every:
            snap_intermed = args.snapshot_every > 0 and buddy.train_iter % args.snapshot_every == 0
            #snap_end = buddy.epoch == args.epochs
            snap_end = lr_policy.train_done(buddy)
            if snap_intermed or snap_end:
                # Snapshot network and buddy
                save_path = saver.save(
                    sess, '%s/%s_%04d.ckpt' %
                    (args.output, args.snapshot_to, buddy.epoch))
                print 'snappshotted model to', save_path
                with gzip.open('%s/%s_misc_%04d.pkl.gz' % (args.output, args.snapshot_to, buddy.epoch), 'w') as ff:
                    saved = {'buddy': buddy}
                    pickle.dump(saved, ff)
                # Snapshot evaluation data and metrics
                _, _ = evaluate_net(
                    args, buddy, model, train_size, train_x, train_y, val_x, val_y, fd, sess)

        lr = lr_policy.get_lr(buddy)

        if buddy.epoch == args.epochs:
            if args.ipy:
                print 'Embed: at end of training (Ctrl-D to exit)'
                embed()
            break   # Extra pass at end: just report val stats and skip training

        print '********* at epoch %d, LR is %g' % (buddy.epoch, lr)

        # 3. Train on training set
        if args.shuffletrain:
            train_order = np.random.permutation(train_size)
        tic3()
        for ii in xrange(train_iters):
            tic2()
            start_idx = ii * minibatch_size
            end_idx = min(start_idx + minibatch_size, train_size)

            if args.shuffletrain:  # default true
                batch_x = train_x[sorted(
                    train_order[start_idx:end_idx].tolist())]
                if train_y is not None:
                    batch_y = train_y[sorted(
                        train_order[start_idx:end_idx].tolist())]
                # if 'train_onehot' in locals():
                if args.arch == 'coordconv_rendering' or args.arch == 'deconv_bottleneck':
                    batch_onehot = train_onehot[sorted(
                        train_order[start_idx:end_idx].tolist())]
            else:
                batch_x = train_x[start_idx:end_idx]
                if train_y is not None:
                    batch_y = train_y[start_idx:end_idx]
                # if 'train_onehot' in locals():
                if args.arch == 'coordconv_rendering' or args.arch == 'deconv_bottleneck':
                    batch_onehot = train_onehot[start_idx:end_idx]

            feed_dict = {learning_phase(): 1, input_lr: lr}
            if 'input_coords' in model.named_keys():
                batch_coords = batch_y if args.arch.endswith(
                    'regressor') else batch_x
                feed_dict.update({model.input_coords: batch_coords})
            if 'input_onehot' in model.named_keys():
                # if 'batch_onehot' not in locals():
                # if not (args.arch == 'coordconv_rendering' and
                # args.add_interm_loss):
                if not args.arch == 'coordconv_rendering' and not args.arch == 'deconv_bottleneck':
                    if args.arch == 'conv_onehot_image' or args.arch.endswith(
                            'regressor'):
                        batch_onehot = batch_x
                    else:
                        batch_onehot = batch_y
                feed_dict.update({
                    model.input_onehot: batch_onehot,
                })
            if 'input_images' in model.named_keys():
                feed_dict.update({
                    model.input_images: batch_images,
                })

            fetch_dict = model.trackable_and_update_dict()

            fetch_dict.update({'train_step': train_step})

            if args.output and do_log_train(buddy.epoch, buddy.train_iter, ii):
                if train_histogram_summaries is not None:
                    fetch_dict.update(
                        {'train_histogram_summaries': train_histogram_summaries})
                if train_scalar_summaries is not None:
                    fetch_dict.update(
                        {'train_scalar_summaries': train_scalar_summaries})
                if train_image_summaries is not None:
                    fetch_dict.update(
                        {'train_image_summaries': train_image_summaries})

            with WithTimer('sess.run train iter', quiet=not args.verbose):
                result_train = sess_run_dict(
                    sess, fetch_dict, feed_dict=feed_dict)

            buddy.note_weighted_list(
                batch_x.shape[0], model.trackable_names(), [
                    result_train[k] for k in model.trackable_names()], prefix='train_')

            if do_log_train(buddy.epoch, buddy.train_iter, ii):
                print (
                    '[%5d] [%2d/%2d] train: %s (%.3gs/i)' %
                    (buddy.train_iter,
                     buddy.epoch,
                     args.epochs,
                     buddy.epoch_mean_pretty_re(
                         '^train_',
                         style=train_style),
                        toc2()))

            if args.output and do_log_train(buddy.epoch, buddy.train_iter, ii):
                if train_histogram_summaries is not None:
                    hist_summary_str = result_train['train_histogram_summaries']
                    writer.add_summary(hist_summary_str, buddy.train_iter)
                if train_scalar_summaries is not None:
                    scalar_summary_str = result_train['train_scalar_summaries']
                    writer.add_summary(scalar_summary_str, buddy.train_iter)
                if train_image_summaries is not None:
                    image_summary_str = result_train['train_image_summaries']
                    writer.add_summary(image_summary_str, buddy.train_iter)
                log_scalars(
                    writer, buddy.train_iter, {
                        'batch_%s' %
                        name: value for name, value in buddy.last_list_re('^train_')}, prefix='train')

            if ii > 0 and ii % 100 == 0:
                print '  %d: Average iteration time over last 100 train iters: %.3gs' % (
                    ii, toc3() / 100)
                tic3()

            buddy.inc_train_iter()   # after finished training a mini-batch

        buddy.inc_epoch()   # after finished training whole pass through set

        if args.output and do_log_train(buddy.epoch, buddy.train_iter, 0):
            log_scalars(
                writer, buddy.train_iter, {
                    'mean_%s' %
                    name: value for name, value in buddy.epoch_mean_list_re('^train_')}, prefix='train')

    print '\nFinal'
    print '%02d:%d val:   %s' % (buddy.epoch,
                                 buddy.train_iter,
                                 buddy.epoch_mean_pretty_re(
                                     '^val_',
                                     style=val_style))
    print '%02d:%d train: %s' % (buddy.epoch,
                                 buddy.train_iter,
                                 buddy.epoch_mean_pretty_re(
                                     '^train_',
                                     style=train_style))

    print '\nEnd of training. Saving evaluation results on whole train and val set.'

    final_tr_metrics, final_va_metrics = evaluate_net(
        args, buddy, model, train_size, train_x, train_y, val_x, val_y, fd, sess)

    print '\nFinal evaluation on whole train and val'
    for name, value in final_tr_metrics.iteritems():
        print 'final_stats_eval train_%s %g' % (name, value)
    for name, value in final_va_metrics.iteritems():
        print 'final_stats_eval val_%s %g' % (name, value)

    print '\nfinal_stats epochs %g' % buddy.epoch
    print 'final_stats iters %g' % buddy.train_iter
    print 'final_stats time %g' % buddy.toc()
    for name, value in buddy.epoch_mean_list_all():
        print 'final_stats %s %g' % (name, value)

    if args.output:
        writer.close()   # Flush and close