コード例 #1
0
def get_effective_mask(self):
    """Build the effective mask tensor derived from ``self.kernel_mask``.

    The result is ``clamp * binary + (1 - clamp) * sigmoid(kernel_mask)``,
    where ``clamp`` selects which entries take the binary value and which
    keep the raw sigmoid probability.

    * ``self.round_mask`` truthy: ``binary`` is the rounded sigmoid. In the
      learning phase ``clamp`` is a rounded uniform draw from [0, 1), so
      roughly half of the entries are clamped; otherwise every entry is
      clamped.
    * ``self.round_mask`` falsy: every entry is clamped and ``binary`` is a
      Bernoulli sample with probability ``sigmoid(kernel_mask)``. In the
      learning phase ``sigmoid - stop_gradient(sigmoid)`` is added so the
      forward value stays the sample while gradients reach the
      probabilities.
    """
    logits = self.kernel_mask

    def bernoulli_draw():
        # One 0/1 draw per mask entry, cast to float for the arithmetic below.
        keep_prob = tf.nn.sigmoid(logits)
        draw = tf.distributions.Bernoulli(probs=keep_prob).sample()
        return tf.cast(draw, dtype=tf.float32)

    def bernoulli_draw_with_gradient():
        # Forward value equals the draw; the added term is zero in value but
        # carries the gradient of the sigmoid.
        hard = bernoulli_draw()
        hard_plus_soft = hard + tf.nn.sigmoid(logits)
        return hard_plus_soft - tf.stop_gradient(tf.nn.sigmoid(logits))

    if not self.round_mask:
        # Sampling is used in both phases; only the gradient path differs.
        which_to_clamp = tf.ones(logits.shape)
        binary_mask = tf.cond(learning_phase(),
                              bernoulli_draw_with_gradient,
                              bernoulli_draw)
    else:
        # Learning phase: clamp a random subset; otherwise clamp everything.
        which_to_clamp = tf.cond(
            learning_phase(),
            lambda: gen_math_ops.round(
                tf.random.uniform(logits.shape, minval=0, maxval=1)),
            lambda: tf.ones(logits.shape))
        binary_mask = gen_math_ops.round(tf.nn.sigmoid(logits))

    clamped_part = which_to_clamp * binary_mask
    free_part = (1 - which_to_clamp) * tf.nn.sigmoid(logits)
    return clamped_part + free_part
コード例 #2
0
def eval(sess, model, train_x, train_y, test_x, test_y, args, tb_writer, iterations):
    """Evaluate on the whole train and test sets and print a report.

    When ``'mask' in args.arch`` the function also prints the expected
    fraction of ones in the masks (mean of ``sigmoid`` over each trainable
    weight, per weight and in total). With ``args.dynamic_scaling`` it
    additionally prints the fraction of ones actually present in each
    'conv2D'/'fc' layer and the per-layer multipliers.

    Args:
        sess: session used to read variables and run the per-layer tensors.
        model: model exposing ``trainable_weights`` and ``layers``.
        train_x, train_y: training inputs and labels.
        test_x, test_y: test/validation inputs and labels.
        args: options; reads ``large_batch_size``, ``test_batch_size``,
            ``arch`` and ``dynamic_scaling``.
        tb_writer: summary writer, or a falsy value to disable tensorboard.
        iterations: current step, used in the printout and as summary step.

    Returns:
        ``(train_acc, test_acc, train_loss, test_loss)`` as returned by
        ``eval_on_entire_dataset`` for the two datasets.
    """
    timerstart = time.time()

    def evaluate_split(tb_prefix, data_x, data_y, batch_size):
        # Tag and step are only passed on when there is a writer to use them.
        tb_prefix_and_iter = (tb_prefix, iterations) if tb_writer else (None, None)
        return eval_on_entire_dataset(sess, model, data_x, data_y, batch_size,
                                      tb_prefix_and_iter, tb_writer)

    # eval on entire train set
    cur_train_acc, cur_train_loss = evaluate_split(
        'eval_train', train_x, train_y, args.large_batch_size)

    # eval on entire test/val set
    cur_test_acc, cur_test_loss = evaluate_split(
        'eval_test', test_x, test_y, args.test_batch_size)

    print(('{}: train acc = {:.4f}, test acc = {:.4f}, '
           'train loss = {:.4f}, test loss = {:.4f} ({:.2f} s)').format(
               iterations, cur_train_acc, cur_test_acc, cur_train_loss,
               cur_test_loss, time.time() - timerstart))

    if 'mask' in args.arch:
        mask_vars = list(model.trainable_weights)
        for var in mask_vars:
            assert 'mask' in var.name, "Should be just training masks"

        # Read the raw mask values in one session call and apply the sigmoid
        # in numpy. Building tf.sigmoid ops here would add new nodes to the
        # graph on every evaluation.
        percs, ones_all, size_all = [], 0, 0
        for raw_mask in sess.run(mask_vars):
            # sigmoid(x) == 0.5 * (1 + tanh(x / 2)); this form cannot overflow.
            mprobs = 0.5 * (1.0 + np.tanh(0.5 * raw_mask))
            num_ones = mprobs.sum()  # expected number of ones
            percs.append(num_ones / mprobs.size)
            ones_all += num_ones
            size_all += mprobs.size
        print('[Est] percent of 1s in mask (per layer):', percs)
        print('[Est] percent of 1s in mask (total):', ones_all/size_all)

        if args.dynamic_scaling:
            scaled_layers = [layer for layer in model.layers
                             if 'conv2D' in layer.name or 'fc' in layer.name]
            layer_ones = [layer.ones_in_mask for layer in scaled_layers]
            layer_mults = [layer.multiplier for layer in scaled_layers]
            # Element counts come from the static kernel shapes, so no graph
            # ops are created to obtain them.
            layer_sizes = [int(np.prod(layer.kernel.shape.as_list()))
                           for layer in scaled_layers]
            # Kept as two separate runs, as before, in case the fetched
            # tensors are stochastic.
            l_ones = sess.run(layer_ones, feed_dict={learning_phase(): 0})
            l_mults = sess.run(layer_mults, feed_dict={learning_phase(): 0})
            print('[Act] percent of 1s in mask (per layer):', (np.array(l_ones) / np.array(layer_sizes)).tolist())
            print('[Act] percent of 1s in mask (total):', np.sum(l_ones) / np.sum(layer_sizes))
            print('layer signed constant multipliers:', l_mults)

    return cur_train_acc, cur_test_acc, cur_train_loss, cur_test_loss
コード例 #3
0
def get_gradients_and_eval(sess,
                           model,
                           input_x,
                           input_y,
                           dim_sum,
                           batch_size,
                           get_eval=True,
                           get_grads=True):
    """Average the flattened gradients and the unregularized loss over a dataset.

    The data is processed in ``int(input_y.shape[0] / batch_size)`` consecutive
    batches; trailing samples that do not fill a whole batch are ignored.
    Every batch is run with ``learning_phase()`` fed as 0 and
    ``batchnorm_learning_phase()`` fed as 1.

    Args:
        sess: session handed to ``sess_run_dict``.
        model: model exposing ``input_images``, ``input_labels``,
            ``loss_cross_ent`` and ``grads_to_compute``.
        input_x: inputs, sliceable along the first axis.
        input_y: labels; ``input_y.shape[0]`` is the number of samples.
        dim_sum: length of the flattened gradient vector.
        batch_size: number of samples per batch.
        get_eval: when true, fetch and average ``model.loss_cross_ent``.
        get_grads: when true, fetch and average ``model.grads_to_compute``.

    Returns:
        ``(mean_gradients, mean_loss_no_reg)``. ``mean_gradients`` is an array
        of length ``dim_sum`` (all zeros when ``get_grads`` is false);
        ``mean_loss_no_reg`` is 0 when ``get_eval`` is false.

    Raises:
        ZeroDivisionError: if the dataset holds fewer than ``batch_size``
            samples, so that no batch is processed.
    """
    grad_sums = np.zeros(dim_sum)
    num_batches = int(input_y.shape[0] / batch_size)
    total_loss_no_reg = 0  # loss without counting l2 penalty

    for i in range(num_batches):
        # slice indices (should be large)
        s_start = batch_size * i
        s_end = s_start + batch_size

        fetch_dict = {}
        if get_eval:
            fetch_dict['loss_no_reg'] = model.loss_cross_ent
        if get_grads:
            fetch_dict['gradients'] = model.grads_to_compute

        result_dict = sess_run_dict(sess,
                                    fetch_dict,
                                    feed_dict={
                                        model.input_images:
                                        input_x[s_start:s_end],
                                        model.input_labels:
                                        input_y[s_start:s_end],
                                        learning_phase(): 0,
                                        batchnorm_learning_phase(): 1
                                    })

        if get_eval:
            total_loss_no_reg += result_dict['loss_no_reg']
        if get_grads:
            # The fetched gradients are a list of arrays; flatten them into
            # one vector so they can be summed across batches.
            grads = result_dict['gradients']
            grad_sums += np.concatenate([grad.flatten() for grad in grads])

    loss_no_reg = total_loss_no_reg / num_batches

    return np.divide(grad_sums, num_batches), loss_no_reg
コード例 #4
0
def eval_on_entire_dataset(sess, model, input_x, input_y, batch_size,
                           tb_prefix_and_iter, tb_writer):
    """Average accuracy and losses over all full batches of a dataset.

    Runs ``int(input_y.shape[0] / batch_size)`` consecutive batches (leftover
    samples are skipped), feeding ``learning_phase()`` as 0 and
    ``batchnorm_learning_phase()`` as 1. When ``tb_writer`` is truthy, the
    three averages are written as one summary at the step given by
    ``tb_prefix_and_iter = (tag_prefix, step)``.

    Returns:
        ``(mean_accuracy, mean_loss_no_reg)``. The mean of ``model.loss`` is
        only reported to tensorboard.
    """
    num_batches = int(input_y.shape[0] / batch_size)
    # running sums of the per-batch values, keyed like the fetches
    totals = {'accuracy': 0, 'loss': 0, 'loss_no_reg': 0}

    for batch_idx in range(num_batches):
        lo = batch_size * batch_idx
        window = slice(lo, lo + batch_size)

        fetches = {
            'accuracy': model.accuracy,
            'loss': model.loss,
            'loss_no_reg': model.loss_cross_ent
        }
        # do not use nor update moving averages
        feeds = {
            model.input_images: input_x[window],
            model.input_labels: input_y[window],
            learning_phase(): 0,
            batchnorm_learning_phase(): 1
        }
        batch_out = sess_run_dict(sess, fetches, feed_dict=feeds)

        for key in totals:
            totals[key] += batch_out[key]

    acc = totals['accuracy'] / num_batches
    loss = totals['loss'] / num_batches
    loss_no_reg = totals['loss_no_reg'] / num_batches  # excludes the l2 penalty

    if not tb_writer:
        return acc, loss_no_reg

    # tensorboard
    tb_prefix, iterations = tb_prefix_and_iter
    summary = tf.Summary()
    tagged_values = (
        ('%s_acc' % tb_prefix, acc),
        ('%s_loss' % tb_prefix, loss),
        ('%s_loss_no_reg' % tb_prefix, loss_no_reg),
    )
    for tag, value in tagged_values:
        summary.value.add(tag=tag, simple_value=value)
    tb_writer.add_summary(summary, iterations)

    return acc, loss_no_reg
コード例 #5
0
def calc_one_iter_grads(sess, model, train_x, train_y, snip_batch_size, dsets):
    """Return the flattened gradients for one random batch of training data.

    ``snip_batch_size`` distinct sample indices are drawn from the training
    set with ``np.random.choice`` (without replacement). The gradients in
    ``model.grads_to_compute`` are evaluated on those samples with
    ``learning_phase()`` fed as 0 and ``batchnorm_learning_phase()`` fed as 1.
    ``dsets`` is accepted but not used here.

    Returns:
        A 1-D array: every fetched gradient flattened and concatenated.
    """
    num_examples = train_x.shape[0]
    picked = np.random.choice(np.arange(num_examples),
                              size=snip_batch_size,
                              replace=False)

    fetches = {'gradients': model.grads_to_compute}
    feeds = {
        model.input_images: train_x[picked],
        model.input_labels: train_y[picked],
        learning_phase(): 0,
        batchnorm_learning_phase(): 1
    }
    outputs = sess_run_dict(sess, fetches, feed_dict=feeds)

    return np.concatenate([g.flatten() for g in outputs['gradients']])
コード例 #6
0
def eval_on_entire_dataset(sess, model, input_x, input_y, batch_size):
    """Average accuracy and unregularized loss over all full batches of a dataset.

    The data is processed in ``int(input_y.shape[0] / batch_size)`` consecutive
    batches; trailing samples that do not fill a whole batch are ignored.

    Args:
        sess: session handed to ``sess_run_dict``.
        model: model exposing ``input_images``, ``input_labels``, ``accuracy``
            and ``loss_cross_ent``.
        input_x: inputs, sliceable along the first axis.
        input_y: labels; ``input_y.shape[0]`` is the number of samples.
        batch_size: number of samples per batch.

    Returns:
        ``(mean_accuracy, mean_loss_no_reg)``.

    Raises:
        ZeroDivisionError: if the dataset holds fewer than ``batch_size``
            samples, so that no batch is processed.
    """
    num_batches = int(input_y.shape[0] / batch_size)
    total_acc = 0
    total_loss_no_reg = 0  # loss without counting l2 penalty

    for i in range(num_batches):
        # slice indices (should be large)
        s_start = batch_size * i
        s_end = s_start + batch_size

        # Only the values that are returned are fetched; the regularized loss
        # used to be fetched and summed but its average was never used.
        fetch_dict = {
            'accuracy': model.accuracy,
            'loss_no_reg': model.loss_cross_ent
        }

        # sess_run_dict is from tfutil and it returns a dictionary
        # NOTE(review): batch norm is fed phase 1 during evaluation; the
        # original note says this is so moving averages are neither used nor
        # updated, and flags it as uncertain. Confirm.
        result_dict = sess_run_dict(
            sess,
            fetch_dict,
            feed_dict={
                model.input_images: input_x[s_start:s_end],
                model.input_labels: input_y[s_start:s_end],
                learning_phase(): 0,
                batchnorm_learning_phase(): 1
            })

        total_acc += result_dict['accuracy']
        total_loss_no_reg += result_dict['loss_no_reg']

    acc = total_acc / num_batches
    loss_no_reg = total_loss_no_reg / num_batches

    return acc, loss_no_reg
コード例 #7
0
def train_and_eval(sess, model, train_x, train_y, test_x, test_y, tb_writer, dsets, args):
    """Train the model with periodic full-dataset evaluation, then save a checkpoint.

    Each epoch visits ``int(train_y.shape[0] / args.train_batch_size)``
    mini-batches; trailing samples that do not fill a batch are skipped. The
    learning rate starts at ``args.lr`` and is multiplied by 0.1 whenever the
    epoch index equals the next entry of ``args.decay_schedule``
    (comma-separated epoch numbers).

    Args:
        sess: session used for training, evaluation and checkpointing.
        model: model exposing the placeholders, train steps, ``update_dict()``
            and ``grads_to_compute`` used below.
        train_x, train_y: training inputs and labels, indexable by an index
            array.
        test_x, test_y: test/validation inputs and labels.
        tb_writer: summary writer, or a falsy value to skip training summaries.
        dsets: mapping of writable datasets; ``'all_weights'`` is written when
            ``args.save_weights`` is set and ``'training_grads'`` when
            ``args.save_training_grads`` is set.
        args: options; reads ``train_batch_size``, ``num_epochs``, ``lr``,
            ``decay_schedule``, ``no_shuffle``, ``save_weights``,
            ``save_every``, ``eval_every``, ``print_every``, ``log_every``,
            ``large_batch_size``, ``test_batch_size``, ``freeze_layers``,
            ``freeze_starting``, ``opt2_layers``, ``save_training_grads`` and
            ``output_dir``.
    """
    import os  # local import: only needed to build the checkpoint path

    # constants
    num_batches = int(train_y.shape[0] / args.train_batch_size)
    print('Training batch size {}, number of iterations: {} per epoch, {} total'.format(
        args.train_batch_size, num_batches, args.num_epochs*num_batches))
    trainable_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
    dim_sum = sum([tf.size(var).eval() for var in trainable_vars])

    def run_full_evals(step):
        """Evaluate on the entire train set, then the entire test/val set."""
        train_acc, train_loss = eval_on_entire_dataset(
            sess, model, train_x, train_y,
            dim_sum, args.large_batch_size, 'eval_train', tb_writer, step)
        test_acc, test_loss = eval_on_entire_dataset(
            sess, model, test_x, test_y,
            dim_sum, args.test_batch_size, 'eval_test', tb_writer, step)
        return train_acc, test_acc, train_loss, test_loss

    # adaptive learning schedule
    curr_lr = args.lr
    decay_epochs = [int(ep) for ep in args.decay_schedule.split(',')]
    if decay_epochs[-1] > 0:
        decay_epochs.append(-1) # end with something small to stop the decay
    decay_count = 0

    # initializations
    tb_summaries = tf.summary.merge(tf.get_collection('tb_train_step'))
    shuffled_indices = np.arange(train_y.shape[0])  # for no shuffling
    iterations = 0
    chunks_written = 0 # for args.save_every batches
    timerstart = time.time()

    for epoch in range(args.num_epochs):
        if not args.no_shuffle:
            shuffled_indices = np.random.permutation(train_y.shape[0])  # for shuffled mini-batches

        # The bounds check covers schedules without the trailing sentinel
        # (last entry <= 0), which would otherwise index past the end once
        # every entry has been consumed.
        if decay_count < len(decay_epochs) and epoch == decay_epochs[decay_count]:
            curr_lr *= 0.1
            decay_count += 1

        for i in range(num_batches):
            # store current weights and gradients
            if args.save_weights and iterations % args.save_every == 0:
                dsets['all_weights'][chunks_written] = flatten_all(trainable_vars)
                chunks_written += 1

            # less frequent, larger evals
            if iterations % args.eval_every == 0:
                cur_train_acc, cur_test_acc, cur_train_loss, cur_test_loss = run_full_evals(iterations)

            # print status update (shows the most recent evaluation results)
            if iterations % args.print_every == 0:
                print(('{}: train acc = {:.4f}, test acc = {:.4f}, '
                       'train loss = {:.4f}, test loss = {:.4f}, lr = {:.4f} ({:.2f} s)').format(
                           iterations, cur_train_acc, cur_test_acc, cur_train_loss,
                           cur_test_loss, curr_lr, time.time() - timerstart))

            # current slice for input data
            batch_indices = shuffled_indices[args.train_batch_size * i : args.train_batch_size * (i + 1)]

            # training: pick the train op(s) matching the freeze / two-optimizer options
            if len(args.freeze_layers) > 0 and iterations >= args.freeze_starting:
                fetch_dict = {'train_step': model.train_step_freeze}
            elif len(args.opt2_layers) > 0:
                fetch_dict = {'train_step_1': model.train_step_1,
                    'train_step_2': model.train_step_2}
            else:
                fetch_dict = {'train_step': model.train_step}
            fetch_dict.update(model.update_dict())

            if iterations % args.log_every == 0:
                fetch_dict.update({'tb': tb_summaries})
            if args.save_training_grads:
                fetch_dict['gradients'] = model.grads_to_compute

            result_train = sess_run_dict(sess, fetch_dict, feed_dict={
                model.input_images: train_x[batch_indices],
                model.input_labels: train_y[batch_indices],
                model.input_lr: curr_lr,
                learning_phase(): 1,
                batchnorm_learning_phase(): 1})

            # log to tensorboard
            if tb_writer and iterations % args.log_every == 0:
                tb_writer.add_summary(result_train['tb'], iterations)

            if args.save_training_grads:
                dsets['training_grads'][iterations] = np.concatenate(
                    [grad.flatten() for grad in result_train['gradients']])

            iterations += 1

    # save final weight values
    if args.save_weights:
        dsets['all_weights'][chunks_written] = flatten_all(trainable_vars)

    # Save the model. os.path.join keeps the checkpoint inside
    # args.output_dir on every platform; the previous hard-coded '\\' only
    # acted as a path separator on Windows.
    saver = tf.train.Saver()
    saver.save(sess, os.path.join(args.output_dir, 'model'))

    # save final evals (skipped when the last step is not an evaluation step;
    # the printout below then repeats the most recent results)
    if iterations % args.eval_every == 0:
        cur_train_acc, cur_test_acc, cur_train_loss, cur_test_loss = run_full_evals(iterations)

    # print last status update
    print(('{}: train acc = {:.4f}, test acc = {:.4f}, '
           'train loss = {:.4f}, test loss = {:.4f} ({:.2f} s)').format(
               iterations, cur_train_acc, cur_test_acc, cur_train_loss,
               cur_test_loss, time.time() - timerstart))
コード例 #8
0
def train_and_eval(sess, model, snip_batch_size, train_x, train_y, val_x,
                   val_y, test_x, test_y, tb_writer, dsets, args):
    """Train the model with periodic evaluation on the train, test and val sets.

    Each epoch visits ``int(train_y.shape[0] / args.train_batch_size)``
    mini-batches; trailing samples that do not fill a batch are skipped. With
    ``args.decay_lr`` set, the learning rate is multiplied by 0.1 whenever the
    epoch index equals the next entry of ``args.decay_schedule``
    (comma-separated epoch numbers). Once every entry has been used, no
    further decay happens.

    Args:
        sess: session used for training and evaluation.
        model: model exposing ``trainable_weights``, the placeholders,
            ``train_step``, ``accuracy``, ``loss`` and ``update_dict()``.
        snip_batch_size: batch size passed to ``calc_one_iter_grads`` before
            training starts.
        train_x, train_y: training inputs and labels, indexable by an index
            array.
        val_x, val_y: validation inputs and labels.
        test_x, test_y: test inputs and labels.
        tb_writer: summary writer, or a falsy value to skip training summaries.
        dsets: mapping of writable datasets (``'all_weights'``,
            ``'one_iter_grads'`` and the ``'*_accuracy'`` / ``'*_loss'``
            series).
        args: options; reads ``train_batch_size``, ``num_epochs``, ``lr``,
            ``decay_lr``, ``decay_schedule``, ``no_shuffle``, ``save_weights``,
            ``save_loss``, ``mode``, ``eval_every``, ``print_every``,
            ``log_every``, ``large_batch_size``, ``test_batch_size`` and
            ``val_batch_size``.
    """
    # constants
    num_batches = int(train_y.shape[0] / args.train_batch_size)
    # total number of scalar entries over all trainable weights
    dim_sum = sum([tf.size(var).eval() for var in model.trainable_weights])

    status_format = (
        '{}: train acc = {:.4f}, val acc = {:.4f}, test acc = {:.4f}, '
        'train loss = {:.4f}, val loss = {:.4f}, test loss = {:.4f} ({:.2f} s)')

    def evaluate_all(step):
        """Evaluate on the entire train, test and val sets, in that order."""
        train_acc, train_loss = eval_on_entire_dataset(
            sess, model, train_x, train_y, dim_sum, args.large_batch_size,
            'eval_train', tb_writer, step)
        test_acc, test_loss = eval_on_entire_dataset(
            sess, model, test_x, test_y, dim_sum, args.test_batch_size,
            'eval_test', tb_writer, step)
        val_acc, val_loss = eval_on_entire_dataset(
            sess, model, val_x, val_y, dim_sum, args.val_batch_size,
            'eval_val', tb_writer, step)
        return train_acc, val_acc, test_acc, train_loss, val_loss, test_loss

    def record_metrics(index, metrics):
        """Store one row of evaluation results in the metric datasets."""
        train_acc, val_acc, test_acc, train_loss, val_loss, test_loss = metrics
        dsets['train_accuracy'][index] = train_acc
        dsets['train_loss'][index] = train_loss
        dsets['val_accuracy'][index] = val_acc
        dsets['val_loss'][index] = val_loss
        dsets['test_accuracy'][index] = test_acc
        dsets['test_loss'][index] = test_loss

    # adaptive learning schedule
    curr_lr = args.lr
    decay_schedule = [int(x) for x in args.decay_schedule.split(',')]
    print(decay_schedule)
    decay_count = 0

    # initializations
    tb_summaries = tf.summary.merge(tf.get_collection('train_step'))

    shuffled_indices = np.arange(train_y.shape[0])  # for no shuffling
    iterations = 0
    chunks_written = 0
    timerstart = time.time()
    iter_index = 0

    # weights before any training step
    if args.save_weights:
        dsets['all_weights'][chunks_written] = flatten_all(
            model.trainable_weights)

    # NOTE(review): the slot counter advances even when save_weights is off,
    # so slot 0 is always reserved for the initial weights. Confirm intent.
    chunks_written += 1

    dsets['one_iter_grads'][0] = calc_one_iter_grads(sess, model, train_x,
                                                     train_y, snip_batch_size,
                                                     dsets)

    for epoch in range(args.num_epochs):
        if not args.no_shuffle:
            shuffled_indices = np.random.permutation(
                train_y.shape[0])  # for shuffled mini-batches

        # The bounds check stops the lookup from running past the end of the
        # schedule (IndexError) on the epochs after the last scheduled decay.
        if (args.decay_lr and decay_count < len(decay_schedule)
                and epoch == decay_schedule[decay_count]):
            curr_lr *= 0.1
            decay_count += 1
            print('dropping learning rate to ' + str(curr_lr))

        for i in range(num_batches):

            # less frequent, larger evals
            if iterations % args.eval_every == 0:
                metrics = evaluate_all(iterations)
                (cur_train_acc, cur_val_acc, cur_test_acc,
                 cur_train_loss, cur_val_loss, cur_test_loss) = metrics

                if args.save_loss:
                    record_metrics(iter_index, metrics)
                    iter_index += 1

            # print status update (shows the most recent evaluation results)
            if iterations % args.print_every == 0:
                print(status_format.format(
                    iterations, cur_train_acc, cur_val_acc, cur_test_acc,
                    cur_train_loss, cur_val_loss, cur_test_loss,
                    time.time() - timerstart))

            # current slice for input data
            batch_indices = shuffled_indices[args.train_batch_size *
                                             i:args.train_batch_size * (i + 1)]

            # training
            fetch_dict = {
                'train_step': model.train_step,
                'accuracy': model.accuracy,
                'loss': model.loss
            }
            fetch_dict.update(model.update_dict())
            if iterations % args.log_every == 0:
                fetch_dict.update({'tb': tb_summaries})
            result_train = sess_run_dict(sess,
                                         fetch_dict,
                                         feed_dict={
                                             model.input_images:
                                             train_x[batch_indices],
                                             model.input_labels:
                                             train_y[batch_indices],
                                             model.input_lr:
                                             curr_lr,
                                             learning_phase():
                                             1,
                                             batchnorm_learning_phase():
                                             1
                                         })

            # log to tensorboard
            if tb_writer and iterations % args.log_every == 0:
                tb_writer.add_summary(result_train['tb'], iterations)

            iterations += 1

            # weights after the very first training step
            # NOTE(review): this write is not guarded by args.save_weights,
            # unlike the other weight snapshots. Confirm that is intended.
            if iterations == 1:
                dsets['all_weights'][chunks_written] = flatten_all(
                    model.trainable_weights)
                chunks_written += 1

            # store current weights and gradients
            if args.mode == 'save_all' and args.save_weights and iterations % args.eval_every == 0:
                dsets['all_weights'][chunks_written] = flatten_all(
                    model.trainable_weights)
                chunks_written += 1

    # save final weight values
    if args.save_weights and iterations % args.eval_every != 0:
        dsets['all_weights'][chunks_written] = flatten_all(
            model.trainable_weights)

    # save final evals on the entire train, test and val sets
    metrics = evaluate_all(iterations)
    (cur_train_acc, cur_val_acc, cur_test_acc,
     cur_train_loss, cur_val_loss, cur_test_loss) = metrics

    if args.save_loss and iterations % args.eval_every != 0:
        record_metrics(iter_index, metrics)

    # print last status update
    print(status_format.format(
        iterations, cur_train_acc, cur_val_acc, cur_test_acc,
        cur_train_loss, cur_val_loss, cur_test_loss,
        time.time() - timerstart))
コード例 #9
0
def main():
    parser = make_standard_parser(
        'Train a GAN model on simple square images or Clevr two-object color images',
        arch_choices=arch_choices,
        skip_train=True,
        skip_val=True)
    parser.add_argument('--z_dim',
                        type=int,
                        default=10,
                        help='Dimension of noise vector')
    parser.add_argument('--lr2',
                        type=float,
                        default=None,
                        help='learning rate for generator')
    parser.add_argument('--feature_match',
                        '-fm',
                        action='store_true',
                        help='use feature matching loss for generator.')
    parser.add_argument(
        '--feature_match_loss_weight',
        '-fmalpha',
        type=float,
        default=1.0,
        help='weight on the feature matching loss for generator.')
    parser.add_argument(
        '--pairedz',
        action='store_true',
        help='If True, pair the same z with a training batch each epoch')
    parser.add_argument(
        '--eval-train-every',
        type=int,
        default=0,
        help='evaluate whole training set every N epochs. 0 to disable.')

    args = parser.parse_args()

    args.skipval = True

    minibatch_size = args.minibatch
    train_style, val_style = ('',
                              '') if args.nocolor else (colorama.Fore.BLUE,
                                                        colorama.Fore.MAGENTA)
    evaltrain_style = '' if args.nocolor or args.eval_train_every <= 0 else colorama.Fore.CYAN

    black_divider = True if args.arch.startswith('clevr') else False

    # Get a TF session and set numpy and TF seeds
    sess = setup_session_and_seeds(args.seed, assert_gpu=not args.cpu)

    # 0. LOAD DATA
    if args.arch.startswith('simple'):
        fd = h5py.File('data/rectangle_4_uniform.h5', 'r')
        train_x = np.array(fd['train_imagegray'],
                           dtype=float) / 255.0  # shape (2368, 64, 64, 1)
        val_x = np.array(fd['val_imagegray'],
                         dtype=float) / 255.0  # shape (768, 64, 64, 1)
        train_x = np.concatenate((train_x, val_x),
                                 axis=0)  # shape (3136, 64, 64, 1)

    elif args.arch.startswith('clevr'):
        (train_x, val_x) = load_sort_of_clevr()
        # shape (50000, 64, 64, 3)
        train_x = np.concatenate((train_x, val_x), axis=0)

    else:
        raise Exception('Unknown network architecture: %s' % args.arch)

    print 'Train data loaded: {} images, size {}'.format(
        train_x.shape[0], train_x.shape[1:])
    #print 'Val data loaded: {} images, size {}'.format(val_x.shape[0], val_x.shape[1:])

    #print 'Label dimension: {}'.format(val_y.shape[1:])

    # 1. CREATE MODEL
    assert len(train_x.shape) == 4, "image data must be of 4 dimensions"
    image_h, image_w, image_c = train_x.shape[1], train_x.shape[
        2], train_x.shape[3]

    model = build_model(args, image_h, image_w, image_c)

    print 'All model weights:'
    summarize_weights(model.trainable_weights)
    print 'Model summary:'
    # model.summary()      # TOREPLACE
    print 'Another model summary:'
    model.summarize_named(prefix='  ')
    print_trainable_warnings(model)

    # 2. COMPUTE GRADS AND CREATE OPTIMIZER
    lr_gen = args.lr2 if args.lr2 else args.lr

    if args.opt == 'sgd':
        d_opt = tf.train.MomentumOptimizer(args.lr, args.mom)
        g_opt = tf.train.MomentumOptimizer(lr_gen, args.mom)
    elif args.opt == 'rmsprop':
        d_opt = tf.train.RMSPropOptimizer(args.lr, momentum=args.mom)
        g_opt = tf.train.RMSPropOptimizer(lr_gen, momentum=args.mom)
    elif args.opt == 'adam':
        d_opt = tf.train.AdamOptimizer(args.lr, args.beta1, args.beta2)
        g_opt = tf.train.AdamOptimizer(lr_gen, args.beta1, args.beta2)

    # Optimize w.r.t all trainable params in the model

    all_vars = model.trainable_variables
    d_vars = [var for var in all_vars if 'discriminator' in var.name]
    g_vars = [var for var in all_vars if 'generator' in var.name]

    d_grads_and_vars = d_opt.compute_gradients(
        model.d_loss, d_vars, gate_gradients=tf.train.Optimizer.GATE_GRAPH)
    d_train_step = d_opt.apply_gradients(d_grads_and_vars)
    g_grads_and_vars = g_opt.compute_gradients(
        model.g_loss, g_vars, gate_gradients=tf.train.Optimizer.GATE_GRAPH)
    g_train_step = g_opt.apply_gradients(g_grads_and_vars)

    hist_summaries_traintest(model.d_real_logits, model.d_fake_logits)

    add_grads_and_vars_hist_summaries(d_grads_and_vars)
    add_grads_and_vars_hist_summaries(g_grads_and_vars)
    image_summaries_traintest(model.fake_images)

    # 3. OPTIONALLY SAVE OR LOAD VARIABLES (e.g. model params, model running
    # BN means, optimization momentum, ...) and then finalize initialization
    saver = tf.train.Saver(
        max_to_keep=None) if (args.output or args.load) else None
    if args.load:
        ckptfile, miscfile = args.load.split(':')
        # Restore values directly to graph
        saver.restore(sess, ckptfile)
        with gzip.open(miscfile) as ff:
            saved = pickle.load(ff)
            buddy = saved['buddy']
    else:
        buddy = StatsBuddy(pretty_replaces=[('evaltrain_', ''), (
            'eval', '')]) if args.eval_train_every > 0 else StatsBuddy()

    buddy.tic()  # call if new run OR resumed run

    tf.global_variables_initializer().run()

    # 4. SETUP TENSORBOARD LOGGING

    param_histogram_summaries = get_collection_intersection_summary(
        'param_collection', 'orig_histogram')
    train_histogram_summaries = get_collection_intersection_summary(
        'train_collection', 'orig_histogram')
    train_scalar_summaries = get_collection_intersection_summary(
        'train_collection', 'orig_scalar')
    test_histogram_summaries = get_collection_intersection_summary(
        'test_collection', 'orig_histogram')
    test_scalar_summaries = get_collection_intersection_summary(
        'test_collection', 'orig_scalar')
    train_image_summaries = get_collection_intersection_summary(
        'train_collection', 'orig_image')
    test_image_summaries = get_collection_intersection_summary(
        'test_collection', 'orig_image')

    writer = None
    if args.output:
        mkdir_p(args.output)
        writer = tf.summary.FileWriter(args.output, sess.graph)

    # 5. TRAIN
    train_iters = (train_x.shape[0]) // minibatch_size
    if not args.skipval:
        val_iters = (val_x.shape[0]) // minibatch_size

    if args.ipy:
        print 'Embed: before train / val loop (Ctrl-D to continue)'
        embed()

    # 2. use same noise, eval on 100 samples and save G(z),
    np.random.seed()
    eval_batch_size = 100
    eval_z = np.random.uniform(-1, 1, size=(eval_batch_size, args.z_dim))

    while buddy.epoch < args.epochs + 1:
        # How often to log data
        def do_log_params(ep, it, ii):
            return True

        def do_log_val(ep, it, ii):
            return True

        def do_log_train(ep, it, ii):
            return (it < train_iters and it & it - 1 == 0
                    or it >= train_iters and it % train_iters == 0
                    )  # Log on powers of two then every epoch

        # 0. Log params
        if args.output and do_log_params(
                buddy.epoch, buddy.train_iter,
                0) and param_histogram_summaries is not None:
            params_summary_str, = sess.run([param_histogram_summaries])
            writer.add_summary(params_summary_str, buddy.train_iter)

        # 1. Evaluate generator by showing random generated results
        #    Evaluate descriminator by showing seeing correct rate on generated and real (hold-out) results
        #assert(args.skipval), "only support training now"

        if not args.skipval:
            tic2()
            # use different noise, eval on larger number of samples and get
            # correct rate
            np.random.seed()
            val_z = np.random.uniform(-1, 1, size=(val_x.shape[0], args.z_dim))

            with WithTimer('sess.run val iter', quiet=not args.verbose):
                feed_dict = {
                    model.input_images: val_x,
                    model.input_noise: val_z,
                    learning_phase(): 0
                }

                if 'input_labels' in model.named_keys():
                    feed_dict.update({model.input_labels: val_y})

                val_corr_fake_bn0, val_corr_real_bn0 = sess.run(
                    [model.correct_fake, model.correct_real],
                    feed_dict=feed_dict)

                feed_dict[learning_phase()] = 1
                val_corr_fake_bn1, val_corr_real_bn1 = sess.run(
                    [model.correct_fake, model.correct_real],
                    feed_dict=feed_dict)

            if args.output and do_log_val(buddy.epoch, buddy.train_iter, 0):
                fetch_dict = {}
                if test_image_summaries is not None:
                    fetch_dict.update(
                        {'test_image_summaries': test_image_summaries})
                if test_scalar_summaries is not None:
                    fetch_dict.update(
                        {'test_scalar_summaries': test_scalar_summaries})
                if test_histogram_summaries is not None:
                    fetch_dict.update(
                        {'test_histogram_summaries': test_histogram_summaries})
                if fetch_dict:
                    summary_strs = sess_run_dict(sess,
                                                 fetch_dict,
                                                 feed_dict=feed_dict)

            buddy.note_list([
                'correct_real_bn0', 'correct_fake_bn0', 'correct_real_bn1',
                'correct_fake_bn1'
            ], [
                val_corr_real_bn0, val_corr_fake_bn0, val_corr_real_bn1,
                val_corr_fake_bn1
            ],
                            prefix='val_')

            print(
                '%3d (ep %d) val: %s (%.3gs/ep)' %
                (buddy.train_iter, buddy.epoch,
                 buddy.epoch_mean_pretty_re('^val_', style=val_style), toc2()))

            if args.output and do_log_val(buddy.epoch, buddy.train_iter, 0):
                log_scalars(
                    writer,
                    buddy.train_iter, {
                        'mean_%s' % name: value
                        for name, value in buddy.epoch_mean_list_re('^val_')
                    },
                    prefix='buddy')

                if test_image_summaries is not None:
                    image_summary_str = summary_strs['test_image_summaries']
                    writer.add_summary(image_summary_str, buddy.train_iter)
                if test_scalar_summaries is not None:
                    scalar_summary_str = summary_strs['test_scalar_summaries']
                    writer.add_summary(scalar_summary_str, buddy.train_iter)
                if test_histogram_summaries is not None:
                    hist_summary_str = summary_strs['test_histogram_summaries']
                    writer.add_summary(hist_summary_str, buddy.train_iter)

        # In addition, evalutate 1000 more images
        np.random.seed()
        eval_more = np.random.uniform(-1, 1, size=(1000, args.z_dim))
        feed_dict2 = {
            # (100,-) generated outside of loop to keep the same every round
            model.input_noise:
            eval_z,
            learning_phase():
            0
        }

        eval_samples_bn0 = sess.run(model.fake_images, feed_dict=feed_dict2)

        feed_dict2[learning_phase()] = 1
        eval_samples_bn1 = sess.run(model.fake_images, feed_dict=feed_dict2)

        # feed in 10 times because coordconv cannot handle too big of a batch
        for cc in range(10):
            eval_z2 = eval_more[cc * 100:(cc + 1) * 100, :]
            _eval_more_samples = sess.run(
                model.fake_images,
                feed_dict={
                    model.input_noise: eval_z2,  # (1000,-)
                    learning_phase(): 0
                })
            eval_more_samples = _eval_more_samples if cc == 0 else np.concatenate(
                (eval_more_samples, _eval_more_samples), axis=0)

        if args.output:
            mkdir_p('{}/fake_images'.format(args.output))
            # eval_samples_bn*: e.g. (100, 64, 64, 3)
            save_images(eval_samples_bn0, [10, 10],
                        '{}/fake_images/g_out_bn0_epoch_{}_iter_{}.png'.format(
                            args.output, buddy.epoch, buddy.train_iter),
                        black_divider=black_divider)
            save_images(eval_samples_bn1, [10, 10],
                        '{}/fake_images/g_out_bn1_epoch_{}.png'.format(
                            args.output, buddy.epoch),
                        black_divider=black_divider)
            save_average_image(
                eval_more_samples,
                '{}/fake_images/g_out_averaged_epoch_{}_iter_{}.png'.format(
                    args.output, buddy.epoch, buddy.train_iter))

        # 2. Possiby Snapshot, possibly quit
        if args.output and args.snapshot_to and args.snapshot_every:
            snap_intermed = args.snapshot_every > 0 and buddy.train_iter % args.snapshot_every == 0
            snap_end = buddy.epoch == args.epochs
            if snap_intermed or snap_end:
                # Snapshot
                save_path = saver.save(
                    sess, '%s/%s_%04d.ckpt' %
                    (args.output, args.snapshot_to, buddy.epoch))
                print 'snappshotted model to', save_path
                with gzip.open(
                        '%s/%s_misc_%04d.pkl.gz' %
                    (args.output, args.snapshot_to, buddy.epoch), 'w') as ff:
                    saved = {'buddy': buddy}
                    pickle.dump(saved, ff)
                # snapshot sampled images too
                ff = h5py.File(
                    '%s/sampled_images_%04d.h5' % (args.output, buddy.epoch),
                    'w')
                ff.create_dataset('eval_samples_bn0', data=eval_samples_bn0)
                ff.create_dataset('eval_samples_bn1', data=eval_samples_bn1)
                ff.create_dataset('eval_z', data=eval_z)
                ff.create_dataset('eval_z_more', data=eval_more)
                ff.create_dataset('eval_more_samples', data=eval_more_samples)
                ff.close()

        # 2. Possiby evaluate the training set
        if args.eval_train_every > 0:
            if buddy.epoch % args.eval_train_every == 0:
                tic2()
                for ii in xrange(train_iters):
                    start_idx = ii * minibatch_size
                    if args.pairedz:
                        np.random.seed(args.seed + ii)
                    else:
                        np.random.seed()
                    batch_z = np.random.uniform(-1,
                                                1,
                                                size=(minibatch_size,
                                                      args.z_dim))

                    batch_x = train_x[start_idx:start_idx + minibatch_size]
                    batch_y = train_y[start_idx:start_idx + minibatch_size]

                    feed_dict = {
                        model.input_images: batch_x,
                        # model.input_labels: batch_y,
                        model.input_noise: batch_z,
                        learning_phase(): 0,
                    }

                    if 'input_labels' in model.named_keys():
                        feed_dict.update({model.input_labels: val_y})

                    fetch_dict = model.trackable_dict()
                    result_eval_train = sess_run_dict(sess,
                                                      fetch_dict,
                                                      feed_dict=feed_dict)
                    buddy.note_weighted_list(
                        batch_x.shape[0],
                        model.trackable_names(), [
                            result_eval_train[k]
                            for k in model.trackable_names()
                        ],
                        prefix='evaltrain_bn0_')

                    feed_dict = {
                        model.input_images: batch_x,
                        # model.input_labels: batch_y,
                        model.input_noise: batch_z,
                        learning_phase(): 1,
                    }
                    if 'input_labels' in model.named_keys():
                        feed_dict.update({model.input_labels: val_y})

                    result_eval_train = sess_run_dict(sess,
                                                      fetch_dict,
                                                      feed_dict=feed_dict)
                    buddy.note_weighted_list(
                        batch_x.shape[0],
                        model.trackable_names(), [
                            result_eval_train[k]
                            for k in model.trackable_names()
                        ],
                        prefix='evaltrain_bn1_')

                    if args.output:
                        log_scalars(writer,
                                    buddy.train_iter, {
                                        'batch_%s' % name: value
                                        for name, value in buddy.last_list_re(
                                            '^evaltrain_bn0_')
                                    },
                                    prefix='buddy')
                        log_scalars(writer,
                                    buddy.train_iter, {
                                        'batch_%s' % name: value
                                        for name, value in buddy.last_list_re(
                                            '^evaltrain_bn1_')
                                    },
                                    prefix='buddy')
                if args.output:
                    log_scalars(writer,
                                buddy.epoch, {
                                    'mean_%s' % name: value
                                    for name, value in
                                    buddy.epoch_mean_list_re('^evaltrain_bn0_')
                                },
                                prefix='buddy')
                    log_scalars(writer,
                                buddy.epoch, {
                                    'mean_%s' % name: value
                                    for name, value in
                                    buddy.epoch_mean_list_re('^evaltrain_bn1_')
                                },
                                prefix='buddy')

                print('%3d (ep %d) evaltrain: %s (%.3gs/ep)' %
                      (buddy.train_iter, buddy.epoch,
                       buddy.epoch_mean_pretty_re(
                           '^evaltrain_bn0_', style=evaltrain_style), toc2()))
                print('%3d (ep %d) evaltrain: %s (%.3gs/ep)' %
                      (buddy.train_iter, buddy.epoch,
                       buddy.epoch_mean_pretty_re(
                           '^evaltrain_bn1_', style=evaltrain_style), toc2()))

        if buddy.epoch == args.epochs:
            if args.ipy:
                print 'Embed: at end of training (Ctrl-D to exit)'
                embed()
            break  # Extra pass at end: just report val stats and skip training

        # 3. Train on training set

        if args.shuffletrain:
            train_order = np.random.permutation(train_x.shape[0])
            train_order2 = np.random.permutation(train_x.shape[0])
        tic3()
        for ii in xrange(train_iters):
            tic2()
            start_idx = ii * minibatch_size
            if args.pairedz:
                np.random.seed(args.seed + ii)
            else:
                np.random.seed()

            batch_z = np.random.uniform(-1,
                                        1,
                                        size=(minibatch_size, args.z_dim))

            if args.shuffletrain:
                #batch_x = train_x[train_order[start_idx:start_idx + minibatch_size]]
                batch_x = train_x[sorted(train_order[start_idx:start_idx +
                                                     minibatch_size].tolist())]
                if args.feature_match:
                    assert args.shuffletrain, "feature matching loss requires shuffle train"
                    batch_x2 = train_x[sorted(
                        train_order2[start_idx:start_idx +
                                     minibatch_size].tolist())]
                if 'input_labels' in model.named_keys():
                    batch_y = train_y[sorted(
                        train_order[start_idx:start_idx +
                                    minibatch_size].tolist())]
            else:
                batch_x = train_x[start_idx:start_idx + minibatch_size]
                if 'input_labels' in model.named_keys():
                    batch_y = train_y[start_idx:start_idx + minibatch_size]

            feed_dict = {
                model.input_images: batch_x,
                # model.input_labels: batch_y,
                model.input_noise: batch_z,
                learning_phase(): 1,
            }

            if 'input_labels' in model.named_keys():
                feed_dict.update({model.input_labels: batch_y})
            if 'input_images2' in model.named_keys():
                feed_dict.update({model.input_images2: batch_x2})

            fetch_dict = model.trackable_and_update_dict()

            if args.output and do_log_train(buddy.epoch, buddy.train_iter, ii):
                if train_histogram_summaries is not None:
                    fetch_dict.update({
                        'train_histogram_summaries':
                        train_histogram_summaries
                    })
                if train_scalar_summaries is not None:
                    fetch_dict.update(
                        {'train_scalar_summaries': train_scalar_summaries})
                if train_image_summaries is not None:
                    fetch_dict.update(
                        {'train_image_summaries': train_image_summaries})

            with WithTimer('sess.run train iter', quiet=not args.verbose):
                result_train = sess_run_dict(sess,
                                             fetch_dict,
                                             feed_dict=feed_dict)

                # if result_train['d_loss'] < result_train['g_loss']:
                #    #print 'Only train G'
                #    sess.run(g_train_step, feed_dict=feed_dict)
                # else:
                #    #print 'Train both D and G'
                #    sess.run(d_train_step, feed_dict=feed_dict)
                #    sess.run(g_train_step, feed_dict=feed_dict)
                #    sess.run(g_train_step, feed_dict=feed_dict)
                sess.run(d_train_step, feed_dict=feed_dict)
                sess.run(g_train_step, feed_dict=feed_dict)
                sess.run(g_train_step, feed_dict=feed_dict)

            if do_log_train(buddy.epoch, buddy.train_iter, ii):
                buddy.note_weighted_list(
                    batch_x.shape[0],
                    model.trackable_names(),
                    [result_train[k] for k in model.trackable_names()],
                    prefix='train_')
                print('[%5d] [%2d/%2d] train: %s (%.3gs/i)' %
                      (buddy.train_iter, buddy.epoch, args.epochs,
                       buddy.epoch_mean_pretty_re('^train_',
                                                  style=train_style), toc2()))

            if args.output and do_log_train(buddy.epoch, buddy.train_iter, ii):
                if train_histogram_summaries is not None:
                    hist_summary_str = result_train[
                        'train_histogram_summaries']
                    writer.add_summary(hist_summary_str, buddy.train_iter)
                if train_scalar_summaries is not None:
                    scalar_summary_str = result_train['train_scalar_summaries']
                    writer.add_summary(scalar_summary_str, buddy.train_iter)
                if train_image_summaries is not None:
                    image_summary_str = result_train['train_image_summaries']
                    writer.add_summary(image_summary_str, buddy.train_iter)
                log_scalars(
                    writer,
                    buddy.train_iter, {
                        'batch_%s' % name: value
                        for name, value in buddy.last_list_re('^train_')
                    },
                    prefix='buddy')

            if ii > 0 and ii % 100 == 0:
                print '  %d: Average iteration time over last 100 train iters: %.3gs' % (
                    ii, toc3() / 100)
                tic3()

            buddy.inc_train_iter()  # after finished training a mini-batch

        buddy.inc_epoch()  # after finished training whole pass through set

        if args.output and do_log_train(buddy.epoch, buddy.train_iter, 0):
            log_scalars(
                writer,
                buddy.train_iter, {
                    'mean_%s' % name: value
                    for name, value in buddy.epoch_mean_list_re('^train_')
                },
                prefix='buddy')

    print '\nFinal'
    print '%02d:%d val:   %s' % (buddy.epoch, buddy.train_iter,
                                 buddy.epoch_mean_pretty_re('^val_',
                                                            style=val_style))
    print '%02d:%d train: %s' % (buddy.epoch, buddy.train_iter,
                                 buddy.epoch_mean_pretty_re('^train_',
                                                            style=train_style))

    print '\nfinal_stats epochs %g' % buddy.epoch
    print 'final_stats iters %g' % buddy.train_iter
    print 'final_stats time %g' % buddy.toc()
    for name, value in buddy.epoch_mean_list_all():
        print 'final_stats %s %g' % (name, value)

    if args.output:
        writer.close()  # Flush and close
コード例 #10
0
def _print_signed_kernel_values(sess, model):
    """Print shape and unique kernel values of each conv2D / fc layer.

    Layers are selected by name: any layer whose name contains 'conv2D' or
    'fc'. The kernel is fetched with the learning phase fed as 0.
    """
    for layer in list(model.layers):
        if 'conv2D' in layer.name or 'fc' in layer.name:
            signed_kernel = sess.run(layer.kernel,
                                     feed_dict={learning_phase(): 0})
            print('Layer {} signed kernel shape {}, has unique values {}'.
                  format(layer.name, signed_kernel.shape,
                         np.unique(signed_kernel).tolist()))


def train_and_eval(sess, model, train_x, train_y, test_x, test_y, tb_writer,
                   dsets, args):
    """Train `model` for `args.num_epochs` epochs with periodic evaluation.

    Each epoch runs `train_y.shape[0] / args.train_batch_size` (truncated)
    mini-batch steps, so a trailing partial batch is dropped. The learning
    rate starts at `args.lr` and is multiplied by 0.1 at each epoch listed in
    `args.decay_schedule`.

    Args:
        sess: TensorFlow session used for every `run` call.
        model: object exposing `train_step`, `update_dict()`, the
            placeholders `input_images`, `input_labels`, `input_lr`, and
            `layers` / `trainable_weights`.
        train_x, train_y: training inputs / labels, indexed with integer
            index arrays; `train_y.shape[0]` is the number of examples.
        test_x, test_y: held-out inputs / labels forwarded to `eval` and
            `eval_on_entire_dataset`.
        tb_writer: summary writer; train summaries are only written when it
            is truthy.
        dsets: mapping whose 'all_weights' entry receives the flattened
            trainable variables when `args.save_weights` is set.
        args: parsed options. Reads `train_batch_size`, `num_epochs`, `lr`,
            `decay_schedule` (comma-separated epoch numbers), `no_shuffle`,
            `save_weights`, `save_every`, `eval_every`, `log_every`,
            `print_every`, `signed_constant`, `signed_constant_multiplier`,
            `dynamic_scaling`, `arch`, `test_batch_size`. NOTE: `args.verbose`
            is overwritten before every periodic evaluation.

    Returns:
        None. Results are printed, logged via `tb_writer` and stored in
        `dsets`.
    """
    # constants
    num_train = train_y.shape[0]
    num_batches = int(num_train / args.train_batch_size)
    print(
        'Training batch size {}, number of iterations: {} per epoch, {} total'.
        format(args.train_batch_size, num_batches,
               args.num_epochs * num_batches))

    # adaptive learning schedule: entries are consumed in order; the bounds
    # check in the epoch loop stops decaying once the list is exhausted, and
    # non-positive entries after epoch 0 never match.
    curr_lr = args.lr
    decay_epochs = [int(ep) for ep in args.decay_schedule.split(',')]
    decay_count = 0

    # initializations
    tb_summaries = tf.summary.merge(tf.get_collection('tb_train_step'))
    shuffled_indices = np.arange(num_train)  # identity order (no shuffling)
    iterations = 0
    chunks_written = 0  # number of weight snapshots stored so far

    for epoch in range(args.num_epochs):
        # Apply the decay first so the rate printed below is the one that is
        # actually fed to the optimizer during this epoch.
        if (decay_count < len(decay_epochs)
                and epoch == decay_epochs[decay_count]):
            curr_lr *= 0.1
            decay_count += 1

        print('-' * 100)
        print('epoch {}  current lr {:.3g}'.format(epoch, curr_lr))
        if not args.no_shuffle:
            # fresh permutation per epoch for shuffled mini-batches
            shuffled_indices = np.random.permutation(num_train)

        for i in range(num_batches):
            # store current weights every args.save_every iterations
            if args.save_weights and iterations % args.save_every == 0:
                dsets['all_weights'][chunks_written] = flatten_all(
                    tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES))
                chunks_written += 1

            # less frequent, larger evals
            if iterations % args.eval_every == 0:
                args.verbose = epoch < 3
                # `eval` here is called with the project's evaluation
                # signature, not as the builtin.
                eval(sess, model, train_x, train_y, test_x, test_y, args,
                     tb_writer, iterations)

                # validate 3 times
                if args.signed_constant and iterations < args.print_every * 3:
                    print('Sanity check: signed constant values')
                    if args.signed_constant_multiplier:
                        print('Note: signed constant multiplier is {}'.format(
                            args.signed_constant_multiplier))
                    if args.dynamic_scaling:
                        print('Note: dynamic signed constant multiplier')
                    _print_signed_kernel_values(sess, model)

            # current slice for input data
            batch_start = args.train_batch_size * i
            batch_indices = shuffled_indices[batch_start:batch_start +
                                             args.train_batch_size]

            # training
            fetch_dict = {'train_step': model.train_step}
            fetch_dict.update(model.update_dict())
            if iterations % args.log_every == 0:
                fetch_dict.update({'tb': tb_summaries})

            result_train = sess_run_dict(
                sess,
                fetch_dict,
                feed_dict={
                    model.input_images: train_x[batch_indices],
                    model.input_labels: train_y[batch_indices],
                    model.input_lr: curr_lr,
                    learning_phase(): 1,
                    batchnorm_learning_phase(): 1,
                })

            # log to tensorboard
            if tb_writer and iterations % args.log_every == 0:
                tb_writer.add_summary(result_train['tb'], iterations)

            iterations += 1

    # save final weight values
    if args.save_weights:
        dsets['all_weights'][chunks_written] = flatten_all(
            tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES))

    # save final evals (only when the last iteration lands on the eval grid)
    if iterations % args.eval_every == 0:
        eval(sess, model, train_x, train_y, test_x, test_y, args, tb_writer,
             iterations)

    if 'mask' in args.arch:
        # for supermask: eval 10 times from different random bernoullies
        testaccs = []
        testlosses = []
        for _ in range(10):
            cur_test_acc, cur_test_loss = eval_on_entire_dataset(
                sess, model, test_x, test_y, args.test_batch_size,
                ('eval_test', iterations), tb_writer)
            testaccs.append(cur_test_acc)
            testlosses.append(cur_test_loss)
        print("all test accs:", testaccs)
        print("all test losses:", testlosses)
        print('final test acc = {:.5f}, test loss = {:.5f}'.format(
            np.mean(testaccs), np.mean(testlosses)))

        # Expected fraction of ones per trainable weight: sigmoid of each
        # value, evaluated through the explicit session rather than an
        # implicit default one.
        percs, ones_all, size_all = [], 0, 0
        for weight in model.trainable_weights:
            mprobs = sess.run(tf.nn.sigmoid(weight))
            num_ones = mprobs.sum()  # expected value
            percs.append(num_ones / mprobs.size)
            ones_all += num_ones
            size_all += mprobs.size
        print('[Est] percent of 1s in mask (per layer):', percs)
        print('[Est] percent of 1s in mask (total):', ones_all / size_all)

    if args.signed_constant:  # validate in the end
        print('Sanity check: signed constant values')
        if args.dynamic_scaling:
            print('Note: dynamic signed constant multiplier')
        elif args.signed_constant_multiplier:
            print('Note: signed constant multiplier is {}'.format(
                args.signed_constant_multiplier))
        _print_signed_kernel_values(sess, model)
コード例 #11
0
ファイル: train.py プロジェクト: tomfisher/CoordConv-1
def main():
    parser = make_standard_parser(
        'Coordconv',
        arch_choices=arch_choices,
        skip_train=True,
        skip_val=True)
    # re-add train and val h5s as optional
    parser.add_argument('--data_h5', type=str,
                        default='./data/rectangle_4_uniform.h5',
                        help='data file in hdf5.')
    parser.add_argument('--x_dim', type=int, default=64,
                        help='x dimension of the output image')
    parser.add_argument('--y_dim', type=int, default=64,
                        help='y dimension of the output image')
    parser.add_argument('--lrpolicy', type=str, default='constant',
                        choices=lr_policy_choices, help='LR policy.')
    parser.add_argument('--lrstepratio', type=float,
                        default=.1, help='LR policy step ratio.')
    parser.add_argument('--lrmaxsteps', type=int, default=5,
                        help='LR policy step ratio.')
    parser.add_argument('--lrstepevery', type=int, default=50,
                        help='LR policy step ratio.')
    parser.add_argument('--filter_size', '-fs', type=int, default=3,
                        help='filter size in deconv network')
    parser.add_argument('--channel_mul', '-mul', type=int, default=2,
        help='Deconv model channel multiplier to make bigger models')
    parser.add_argument('--use_mse_loss', '-mse', action='store_true',
                        help='use mse loss instead of cross entropy')
    parser.add_argument('--use_sigm_loss', '-sig', action='store_true',
                        help='use sigmoid loss instead of cross entropy')
    parser.add_argument('--interm_loss', '-interm', default=None,
        choices=(None, 'softmax', 'mse'),
        help='add intermediate loss to end-to-end painter model')
    parser.add_argument('--no_softmax', '-nosfmx', action='store_true',
                        help='Remove softmax sharpening layer in model')

    args = parser.parse_args()

    if args.lrpolicy == 'step':
        lr_policy = LRPolicyStep(args)
    elif args.lrpolicy == 'valstep':
        lr_policy = LRPolicyValStep(args)
    else:
        lr_policy = LRPolicyConstant(args)

    minibatch_size = args.minibatch
    train_style, val_style = (
        '', '') if args.nocolor else (
        colorama.Fore.BLUE, colorama.Fore.MAGENTA)

    sess = setup_session_and_seeds(args.seed, assert_gpu=not args.cpu)

    # 0. Load data or generate data on the fly
    print 'Loading data: {}'.format(args.data_h5)

    if args.arch in ['deconv_classification',
                     'coordconv_classification',
                     'upsample_conv_coords',
                     'upsample_coordconv_coords']:

        # option a: generate data on the fly
        #data = list(itertools.product(range(args.x_dim),range(args.y_dim)))
        # random.shuffle(data)

        #train_test_split = .8
        #val_reps = int(args.x_dim * args.x_dim * train_test_split) // minibatch_size
        #val_size = val_reps * minibatch_size
        #train_end = args.x_dim * args.x_dim - val_size
        #train_x, val_x = np.array(data[:train_end]).astype('int'), np.array(data[train_end:]).astype('int')
        #train_y, val_y = None, None
        #DATA_GEN_ON_THE_FLY = True

        # option b: load the data
        fd = h5py.File(args.data_h5, 'r')

        train_x = np.array(fd['train_locations'], dtype=int)  # shape (2368, 2)
        train_y = np.array(fd['train_onehots'], dtype=float)  # shape (2368, 64, 64, 1)
        val_x = np.array(fd['val_locations'], dtype=float)  # shape (768, 2)
        val_y = np.array(fd['val_onehots'], dtype=float)  # shape (768, 64, 64, 1)
        DATA_GEN_ON_THE_FLY = False

        # number of image channels
        image_c = train_y.shape[-1] if train_y is not None and len(train_y.shape) == 4 else 1

    elif args.arch == 'conv_onehot_image':
        fd = h5py.File(args.data_h5, 'r')
        train_x = np.array(
            fd['train_onehots'],
            dtype=int)  # shape (2368, 64, 64, 1)
        train_y = np.array(fd['train_imagegray'],
                           dtype=float) / 255.0  # shape (2368, 64, 64, 1)
        val_x = np.array(
            fd['val_onehots'],
            dtype=float)  # shape (768, 64, 64, 1)
        val_y = np.array(fd['val_imagegray'], dtype=float) / \
            255.0  # shape (768, 64, 64, 1)

        image_c = train_y.shape[-1]

    elif args.arch == 'deconv_rendering':
        fd = h5py.File(args.data_h5, 'r')
        train_x = np.array(fd['train_locations'], dtype=int)  # shape (2368, 2)
        train_y = np.array(fd['train_imagegray'],
                           dtype=float) / 255.0  # shape (2368, 64, 64, 1)
        val_x = np.array(fd['val_locations'], dtype=float)  # shape (768, 2)
        val_y = np.array(fd['val_imagegray'], dtype=float) / \
            255.0  # shape (768, 64, 64, 1)

        image_c = train_y.shape[-1]

    elif args.arch == 'conv_regressor' or args.arch == 'coordconv_regressor':
        fd = h5py.File(args.data_h5, 'r')
        train_y = np.array(
            fd['train_normalized_locations'],
            dtype=float)  # shape (2368, 2)
        # /255.0 # shape (2368, 64, 64, 1)
        train_x = np.array(fd['train_onehots'], dtype=float)
        val_y = np.array(
            fd['val_normalized_locations'],
            dtype=float)  # shape (768, 2)
        val_x = np.array(
            fd['val_onehots'],
            dtype=float)  # shape (768, 64, 64, 1)

        image_c = train_x.shape[-1]

    elif args.arch == 'coordconv_rendering' or args.arch == 'deconv_bottleneck':
        fd = h5py.File(args.data_h5, 'r')
        train_x = np.array(fd['train_locations'], dtype=int)  # shape (2368, 2)
        train_y = np.array(fd['train_imagegray'],
                           dtype=float) / 255.0  # shape (2368, 64, 64, 1)
        val_x = np.array(fd['val_locations'], dtype=float)  # shape (768, 2)
        val_y = np.array(fd['val_imagegray'], dtype=float) / 255.0  # shape (768, 64, 64, 1)

        # add one-hot anyways to track accuracy etc. even if not used in loss
        train_onehot = np.array(
            fd['train_onehots'],
            dtype=int)  # shape (2368, 64, 64, 1)
        val_onehot = np.array(
            fd['val_onehots'],
            dtype=int)  # shape (768, 64, 64, 1)

        image_c = train_y.shape[-1]

    train_size = train_x.shape[0]
    val_size = val_x.shape[0]

    # 1. CREATE MODEL
    input_coords = tf.placeholder(
        shape=(None,2),
        dtype='float32',
        name='input_coords')  # cast later in model into float
    input_onehot = tf.placeholder(
        shape=(None, args.x_dim, args.y_dim, 1),
        dtype='float32',
        name='input_onehot')
    input_images = tf.placeholder(
        shape=(None, args.x_dim, args.y_dim, image_c),
        dtype='float32',
        name='input_images')

    if args.arch == 'deconv_classification':
        model = DeconvPainter(l2=args.l2, x_dim=args.x_dim, y_dim=args.y_dim,
                              fs=args.filter_size, mul=args.channel_mul,
                              onthefly=DATA_GEN_ON_THE_FLY,
                              use_mse_loss=args.use_mse_loss,
                              use_sigm_loss=args.use_sigm_loss)

        model.a('input_coords', input_coords)

        if not DATA_GEN_ON_THE_FLY:
            model.a('input_onehot', input_onehot)

        model([input_coords]) if DATA_GEN_ON_THE_FLY else model([input_coords, input_onehot])

    if args.arch == 'conv_regressor':
        regress_type = 'conv_uniform' if 'uniform' in args.data_h5 else 'conv_quarant'
        model = ConvRegressor(l2=args.l2, mul=args.channel_mul,
                              _type=regress_type)
        model.a('input_coords', input_coords)
        model.a('input_onehot', input_onehot)
        # call model on inputs
        model([input_onehot, input_coords])

    if args.arch == 'coordconv_regressor':
        model = ConvRegressor(l2=args.l2, mul=args.channel_mul,
                              _type='coordconv')
        model.a('input_coords', input_coords)
        model.a('input_onehot', input_onehot)
        # call model on inputs
        model([input_onehot, input_coords])

    if args.arch == 'conv_onehot_image':
        model = ConvImagePainter(l2=args.l2, fs=args.filter_size, mul=args.channel_mul,
            use_mse_loss=args.use_mse_loss, use_sigm_loss=args.use_sigm_loss,
            version='working')
            # version='simple') # version='simple' to hack a 9x9 all-ones filter solution
        model.a('input_onehot', input_onehot)
        model.a('input_images', input_images)
        # call model on inputs
        model([input_onehot, input_images])

    if args.arch == 'deconv_rendering':
        model = DeconvPainter(l2=args.l2, x_dim=args.x_dim, y_dim=args.y_dim,
                              fs=args.filter_size, mul=args.channel_mul,
                              onthefly=False,
                              use_mse_loss=args.use_mse_loss,
                              use_sigm_loss=args.use_sigm_loss)
        model.a('input_coords', input_coords)
        model.a('input_images', input_images)
        # call model on inputs
        model([input_coords, input_images])

    elif args.arch == 'coordconv_classification':
        model = CoordConvPainter(
            l2=args.l2,
            x_dim=args.x_dim,
            y_dim=args.y_dim,
            include_r=False,
            mul=args.channel_mul,
            use_mse_loss=args.use_mse_loss,
            use_sigm_loss=args.use_sigm_loss)

        model.a('input_coords', input_coords)
        model.a('input_onehot', input_onehot)

        model([input_coords, input_onehot])
        #raise Exception('Not implemented yet')

    elif args.arch == 'coordconv_rendering':
        model = CoordConvImagePainter(
            l2=args.l2,
            x_dim=args.x_dim,
            y_dim=args.y_dim,
            include_r=False,
            mul=args.channel_mul,
            fs=args.filter_size,
            use_mse_loss=args.use_mse_loss,
            use_sigm_loss=args.use_sigm_loss,
            interm_loss=args.interm_loss,
            no_softmax=args.no_softmax,
            version='working')
        # version='simple') # version='simple' to hack a 9x9 all-ones filter solution
        model.a('input_coords', input_coords)
        model.a('input_onehot', input_onehot)
        model.a('input_images', input_images)

        # always input three things to calculate relevant metrics
        model([input_coords, input_onehot, input_images])
    elif args.arch == 'deconv_bottleneck':
        model = DeconvBottleneckPainter(
            l2=args.l2,
            x_dim=args.x_dim,
            y_dim=args.y_dim,
            mul=args.channel_mul,
            fs=args.filter_size,
            use_mse_loss=args.use_mse_loss,
            use_sigm_loss=args.use_sigm_loss,
            interm_loss=args.interm_loss,
            no_softmax=args.no_softmax,
            version='working')  # version='simple' to hack a 9x9 all-ones filter solution
        model.a('input_coords', input_coords)
        model.a('input_onehot', input_onehot)
        model.a('input_images', input_images)

        # always input three things to calculate relevant metrics
        model([input_coords, input_onehot, input_images])

    elif args.arch == 'upsample_conv_coords' or args.arch == 'upsample_coordconv_coords':
        _coordconv = True if args.arch == 'upsample_coordconv_coords' else False
        model = UpsampleConvPainter(
            l2=args.l2,
            x_dim=args.x_dim,
            y_dim=args.y_dim,
            mul=args.channel_mul,
            fs=args.filter_size,
            use_mse_loss=args.use_mse_loss,
            use_sigm_loss=args.use_sigm_loss,
            coordconv=_coordconv)
        model.a('input_coords', input_coords)
        model.a('input_onehot', input_onehot)
        model([input_coords, input_onehot])

    print 'All model weights:'
    summarize_weights(model.trainable_weights)
    #print 'Model summary:'
    print 'Another model summary:'
    model.summarize_named(prefix='  ')
    print_trainable_warnings(model)

    # 2. COMPUTE GRADS AND CREATE OPTIMIZER
    # a placeholder for dynamic learning rate
    input_lr = tf.placeholder(tf.float32, shape=[])
    if args.opt == 'sgd':
        opt = tf.train.MomentumOptimizer(input_lr, args.mom)
    elif args.opt == 'rmsprop':
        opt = tf.train.RMSPropOptimizer(input_lr, momentum=args.mom)
    elif args.opt == 'adam':
        opt = tf.train.AdamOptimizer(input_lr, args.beta1, args.beta2)

    grads_and_vars = opt.compute_gradients(
        model.loss,
        model.trainable_weights,
        gate_gradients=tf.train.Optimizer.GATE_GRAPH)
    train_step = opt.apply_gradients(grads_and_vars)
    # added to train_ and param_ collections
    add_grads_and_vars_hist_summaries(grads_and_vars)

    summarize_opt(opt)
    print 'LR Policy:', lr_policy

    # add_grad_summaries(grads_and_vars)
    if not args.arch.endswith('regressor'):
        image_summaries_traintest(model.logits)

    if 'input_onehot' in model.named_keys():
        image_summaries_traintest(model.input_onehot)
    if 'input_images' in model.named_keys():
        image_summaries_traintest(model.input_images)
    if 'prob' in model.named_keys():
        image_summaries_traintest(model.prob)
    if 'center_prob' in model.named_keys():
        image_summaries_traintest(model.center_prob)
    if 'center_logits' in model.named_keys():
        image_summaries_traintest(model.center_logits)
    if 'pixelwise_prob' in model.named_keys():
        image_summaries_traintest(model.pixelwise_prob)
    if 'center_logits' in model.named_keys():
        image_summaries_traintest(model.center_logits)
    if 'sharpened_logits' in model.named_keys():
        image_summaries_traintest(model.sharpened_logits)

    # 3. OPTIONALLY SAVE OR LOAD VARIABLES (e.g. model params, model running
    # BN means, optimization momentum, ...) and then finalize initialization
    saver = tf.train.Saver(
        max_to_keep=None) if (
        args.output or args.load) else None
    if args.load:
        ckptfile, miscfile = args.load.split(':')
        # Restore values directly to graph
        saver.restore(sess, ckptfile)
        with gzip.open(miscfile) as ff:
            saved = pickle.load(ff)
            buddy = saved['buddy']
    else:
        buddy = StatsBuddy()

    buddy.tic()    # call if new run OR resumed run

    # Check if special layers are initialized right
    #last_layer_w = [var for var in tf.global_variables() if 'painting_layer/kernel:0' in var.name][0]
    #last_layer_b = [var for var in tf.global_variables() if 'painting_layer/bias:0' in var.name][0]

    # Initialize any missed vars (e.g. optimization momentum, ... if not
    # loaded from checkpoint)
    uninitialized_vars = tf_get_uninitialized_variables(sess)
    init_missed_vars = tf.variables_initializer(
        uninitialized_vars, 'init_missed_vars')
    sess.run(init_missed_vars)
    # Print warnings about any TF vs. Keras shape mismatches
    # warn_misaligned_shapes(model)
    # Make sure all variables, which are model variables, have been
    # initialized (e.g. model params and model running BN means)
    tf_assert_all_init(sess)
    # tf.global_variables_initializer().run()

    # 4. SETUP TENSORBOARD LOGGING with tf.summary.merge

    train_histogram_summaries = get_collection_intersection_summary(
        'train_collection', 'orig_histogram')
    train_scalar_summaries = get_collection_intersection_summary(
        'train_collection', 'orig_scalar')
    test_histogram_summaries = get_collection_intersection_summary(
        'test_collection', 'orig_histogram')
    test_scalar_summaries = get_collection_intersection_summary(
        'test_collection', 'orig_scalar')
    param_histogram_summaries = get_collection_intersection_summary(
        'param_collection', 'orig_histogram')
    train_image_summaries = get_collection_intersection_summary(
        'train_collection', 'orig_image')
    test_image_summaries = get_collection_intersection_summary(
        'test_collection', 'orig_image')

    writer = None
    if args.output:
        mkdir_p(args.output)
        writer = tf.summary.FileWriter(args.output, sess.graph)

    # 5. TRAIN

    train_iters = (train_size) // minibatch_size + \
        int(train_size % minibatch_size > 0)
    if not args.skipval:
        val_iters = (val_size) // minibatch_size + \
            int(val_size % minibatch_size > 0)

    if args.ipy:
        print 'Embed: before train / val loop (Ctrl-D to continue)'
        embed()

    while buddy.epoch < args.epochs + 1:
        # How often to log data
        def do_log_params(ep, it, ii): return True
        def do_log_val(ep, it, ii): return True

        def do_log_train(
            ep,
            it,
            ii): return (
            it < train_iters and it & it -
            1 == 0 or it >= train_iters and it %
            train_iters == 0)  # Log on powers of two then every epoch

        # 0. Log params
        if args.output and do_log_params(
                buddy.epoch,
                buddy.train_iter,
                0) and param_histogram_summaries is not None:
            params_summary_str, = sess.run([param_histogram_summaries])
            writer.add_summary(params_summary_str, buddy.train_iter)

        # 1. Forward test on validation set
        if not args.skipval:
            feed_dict = {learning_phase(): 0}
            if 'input_coords' in model.named_keys():
                val_coords = val_y if args.arch.endswith(
                    'regressor') else val_x
                feed_dict.update({model.input_coords: val_coords})

            if 'input_onehot' in model.named_keys():
                # if 'val_onehot' not in locals():
                if not args.arch == 'coordconv_rendering' and not args.arch == 'deconv_bottleneck':
                    if args.arch == 'conv_onehot_image' or args.arch.endswith('regressor'):
                        val_onehot = val_x
                    else:
                        val_onehot = val_y
                feed_dict.update({
                    model.input_onehot: val_onehot,
                })
            if 'input_images' in model.named_keys():
                feed_dict.update({
                    model.input_images: val_images,
                })

            fetch_dict = model.trackable_dict()

            if args.output and do_log_val(buddy.epoch, buddy.train_iter, 0):
                if test_image_summaries is not None:
                    fetch_dict.update(
                        {'test_image_summaries': test_image_summaries})
                if test_scalar_summaries is not None:
                    fetch_dict.update(
                        {'test_scalar_summaries': test_scalar_summaries})
                if test_histogram_summaries is not None:
                    fetch_dict.update(
                        {'test_histogram_summaries': test_histogram_summaries})

            with WithTimer('sess.run val iter', quiet=not args.verbose):
                result_val = sess_run_dict(
                    sess, fetch_dict, feed_dict=feed_dict)

            buddy.note_list(
                model.trackable_names(), [
                    result_val[k] for k in model.trackable_names()], prefix='val_')
            print (
                '[%5d] [%2d/%2d] val: %s (%.3gs/i)' %
                (buddy.train_iter,
                 buddy.epoch,
                 args.epochs,
                 buddy.epoch_mean_pretty_re(
                     '^val_',
                     style=val_style),
                    toc2()))

            if args.output and do_log_val(buddy.epoch, buddy.train_iter, 0):
                log_scalars(
                    writer, buddy.train_iter, {
                        'mean_%s' %
                        name: value for name, value in buddy.epoch_mean_list_re('^val_')}, prefix='val')
                if test_image_summaries is not None:
                    image_summary_str = result_val['test_image_summaries']
                    writer.add_summary(image_summary_str, buddy.train_iter)
                if test_scalar_summaries is not None:
                    scalar_summary_str = result_val['test_scalar_summaries']
                    writer.add_summary(scalar_summary_str, buddy.train_iter)
                if test_histogram_summaries is not None:
                    hist_summary_str = result_val['test_histogram_summaries']
                    writer.add_summary(hist_summary_str, buddy.train_iter)

        # 2. Possibly snapshot, possibly quit
        if args.output and args.snapshot_to and args.snapshot_every:
            snap_intermed = args.snapshot_every > 0 and buddy.train_iter % args.snapshot_every == 0
            #snap_end = buddy.epoch == args.epochs
            snap_end = lr_policy.train_done(buddy)
            if snap_intermed or snap_end:
                # Snapshot network and buddy
                save_path = saver.save(
                    sess, '%s/%s_%04d.ckpt' %
                    (args.output, args.snapshot_to, buddy.epoch))
                print 'snappshotted model to', save_path
                with gzip.open('%s/%s_misc_%04d.pkl.gz' % (args.output, args.snapshot_to, buddy.epoch), 'w') as ff:
                    saved = {'buddy': buddy}
                    pickle.dump(saved, ff)
                # Snapshot evaluation data and metrics
                _, _ = evaluate_net(
                    args, buddy, model, train_size, train_x, train_y, val_x, val_y, fd, sess)

        lr = lr_policy.get_lr(buddy)

        if buddy.epoch == args.epochs:
            if args.ipy:
                print 'Embed: at end of training (Ctrl-D to exit)'
                embed()
            break   # Extra pass at end: just report val stats and skip training

        print '********* at epoch %d, LR is %g' % (buddy.epoch, lr)

        # 3. Train on training set
        if args.shuffletrain:
            train_order = np.random.permutation(train_size)
        tic3()
        for ii in xrange(train_iters):
            tic2()
            start_idx = ii * minibatch_size
            end_idx = min(start_idx + minibatch_size, train_size)

            if args.shuffletrain:  # default true
                batch_x = train_x[sorted(
                    train_order[start_idx:end_idx].tolist())]
                if train_y is not None:
                    batch_y = train_y[sorted(
                        train_order[start_idx:end_idx].tolist())]
                # if 'train_onehot' in locals():
                if args.arch == 'coordconv_rendering' or args.arch == 'deconv_bottleneck':
                    batch_onehot = train_onehot[sorted(
                        train_order[start_idx:end_idx].tolist())]
            else:
                batch_x = train_x[start_idx:end_idx]
                if train_y is not None:
                    batch_y = train_y[start_idx:end_idx]
                # if 'train_onehot' in locals():
                if args.arch == 'coordconv_rendering' or args.arch == 'deconv_bottleneck':
                    batch_onehot = train_onehot[start_idx:end_idx]

            feed_dict = {learning_phase(): 1, input_lr: lr}
            if 'input_coords' in model.named_keys():
                batch_coords = batch_y if args.arch.endswith(
                    'regressor') else batch_x
                feed_dict.update({model.input_coords: batch_coords})
            if 'input_onehot' in model.named_keys():
                # if 'batch_onehot' not in locals():
                # if not (args.arch == 'coordconv_rendering' and
                # args.add_interm_loss):
                if not args.arch == 'coordconv_rendering' and not args.arch == 'deconv_bottleneck':
                    if args.arch == 'conv_onehot_image' or args.arch.endswith(
                            'regressor'):
                        batch_onehot = batch_x
                    else:
                        batch_onehot = batch_y
                feed_dict.update({
                    model.input_onehot: batch_onehot,
                })
            if 'input_images' in model.named_keys():
                feed_dict.update({
                    model.input_images: batch_images,
                })

            fetch_dict = model.trackable_and_update_dict()

            fetch_dict.update({'train_step': train_step})

            if args.output and do_log_train(buddy.epoch, buddy.train_iter, ii):
                if train_histogram_summaries is not None:
                    fetch_dict.update(
                        {'train_histogram_summaries': train_histogram_summaries})
                if train_scalar_summaries is not None:
                    fetch_dict.update(
                        {'train_scalar_summaries': train_scalar_summaries})
                if train_image_summaries is not None:
                    fetch_dict.update(
                        {'train_image_summaries': train_image_summaries})

            with WithTimer('sess.run train iter', quiet=not args.verbose):
                result_train = sess_run_dict(
                    sess, fetch_dict, feed_dict=feed_dict)

            buddy.note_weighted_list(
                batch_x.shape[0], model.trackable_names(), [
                    result_train[k] for k in model.trackable_names()], prefix='train_')

            if do_log_train(buddy.epoch, buddy.train_iter, ii):
                print (
                    '[%5d] [%2d/%2d] train: %s (%.3gs/i)' %
                    (buddy.train_iter,
                     buddy.epoch,
                     args.epochs,
                     buddy.epoch_mean_pretty_re(
                         '^train_',
                         style=train_style),
                        toc2()))

            if args.output and do_log_train(buddy.epoch, buddy.train_iter, ii):
                if train_histogram_summaries is not None:
                    hist_summary_str = result_train['train_histogram_summaries']
                    writer.add_summary(hist_summary_str, buddy.train_iter)
                if train_scalar_summaries is not None:
                    scalar_summary_str = result_train['train_scalar_summaries']
                    writer.add_summary(scalar_summary_str, buddy.train_iter)
                if train_image_summaries is not None:
                    image_summary_str = result_train['train_image_summaries']
                    writer.add_summary(image_summary_str, buddy.train_iter)
                log_scalars(
                    writer, buddy.train_iter, {
                        'batch_%s' %
                        name: value for name, value in buddy.last_list_re('^train_')}, prefix='train')

            if ii > 0 and ii % 100 == 0:
                print '  %d: Average iteration time over last 100 train iters: %.3gs' % (
                    ii, toc3() / 100)
                tic3()

            buddy.inc_train_iter()   # after finished training a mini-batch

        buddy.inc_epoch()   # after finished training whole pass through set

        if args.output and do_log_train(buddy.epoch, buddy.train_iter, 0):
            log_scalars(
                writer, buddy.train_iter, {
                    'mean_%s' %
                    name: value for name, value in buddy.epoch_mean_list_re('^train_')}, prefix='train')

    print '\nFinal'
    print '%02d:%d val:   %s' % (buddy.epoch,
                                 buddy.train_iter,
                                 buddy.epoch_mean_pretty_re(
                                     '^val_',
                                     style=val_style))
    print '%02d:%d train: %s' % (buddy.epoch,
                                 buddy.train_iter,
                                 buddy.epoch_mean_pretty_re(
                                     '^train_',
                                     style=train_style))

    print '\nEnd of training. Saving evaluation results on whole train and val set.'

    final_tr_metrics, final_va_metrics = evaluate_net(
        args, buddy, model, train_size, train_x, train_y, val_x, val_y, fd, sess)

    print '\nFinal evaluation on whole train and val'
    for name, value in final_tr_metrics.iteritems():
        print 'final_stats_eval train_%s %g' % (name, value)
    for name, value in final_va_metrics.iteritems():
        print 'final_stats_eval val_%s %g' % (name, value)

    print '\nfinal_stats epochs %g' % buddy.epoch
    print 'final_stats iters %g' % buddy.train_iter
    print 'final_stats time %g' % buddy.toc()
    for name, value in buddy.epoch_mean_list_all():
        print 'final_stats %s %g' % (name, value)

    if args.output:
        writer.close()   # Flush and close
コード例 #12
0
ファイル: train.py プロジェクト: tomfisher/CoordConv-1
def evaluate_net(args, buddy, model, train_size, train_x, train_y,
                 val_x, val_y, fd, sess, write_x=True, write_y=True):
    """Evaluate ``model`` on the whole train and val sets.

    The train set is processed in minibatches of ``args.minibatch`` samples
    read from ``fd``; the val set is fed in a single pass. The learning phase
    is fed as 0 for both, because this is evaluation.

    When ``args.output`` is set, the fetched model outputs are written to
    ``<args.output>/evaluation_<epoch>.h5`` and the metrics are pickled to
    ``<args.output>/evaluation_<epoch>_metrics.pkl``. Otherwise the metrics
    are printed.

    Args:
        args: parsed options; reads ``minibatch``, ``output``, ``arch``,
            ``x_dim`` and ``y_dim``.
        buddy: stats tracker; ``buddy.epoch`` names the output files and
            ``buddy.toc()`` supplies the elapsed time.
        model: model exposing ``named_keys()``, ``trackable_dict()`` and the
            input / output tensors referenced below.
        train_size: number of training samples to evaluate.
        train_x, train_y, val_x, val_y: arrays stored verbatim in the HDF5
            dump (subject to ``write_x`` / ``write_y``); they are not fed to
            the model here, the model inputs are re-read from ``fd``.
        fd: mapping from dataset name to sliceable array (the caller passes
            an open ``h5py.File``).
        sess: session handed to ``sess_run_dict``.
        write_x: store ``val_x`` / ``train_x`` as ``inputs_*`` datasets.
        write_y: store ``val_y`` / ``train_y`` as ``labels_*`` datasets.

    Returns:
        ``(final_tr_metrics, final_va_metrics)``: the per-batch train metrics
        merged with ``merge_dict_append`` and reduced with
        ``average_dict_values``, and the metrics fetched on the val set.

    Raises:
        ValueError: if ``train_size`` is not positive.
    """
    if train_size <= 0:
        # Without any batch there are no train metrics to average; fail
        # early, before an output file is created.
        raise ValueError('evaluate_net needs a non-empty training set, '
                         'got train_size=%r' % (train_size,))

    minibatch_size = args.minibatch
    train_iters = (train_size) // minibatch_size + \
        int(train_size % minibatch_size > 0)

    is_regressor = args.arch.endswith('regressor')
    has_onehot = 'input_onehot' in model.named_keys()
    has_images = 'input_images' in model.named_keys()
    has_coords = 'input_coords' in model.named_keys()

    # Which location datasets (and dtype) feed ``input_coords``.
    if is_regressor:
        train_loc_key = 'train_normalized_locations'
        val_loc_key = 'val_normalized_locations'
        loc_dtype = 'float32'
    else:
        train_loc_key = 'train_locations'
        val_loc_key = 'val_locations'
        loc_dtype = 'int32'

    # 0 even for train set; because it's evaluation
    feed_dict_tr = {learning_phase(): 0}
    feed_dict_va = {learning_phase(): 0}

    # The val set is fed in one pass, so its feed is built once up front.
    if has_onehot:
        feed_dict_va[model.input_onehot] = np.array(
            fd['val_onehots'], dtype=float)
    if has_images:
        feed_dict_va[model.input_images] = np.array(
            fd['val_imagegray'], dtype=float) / 255.0
    if has_coords:
        feed_dict_va[model.input_coords] = np.array(
            fd[val_loc_key], dtype=loc_dtype)

    final_fetch = None
    ff = None
    if args.output:
        final_fetch = {'logits': model.logits}
        if 'prob' in model.named_keys():
            final_fetch['prob'] = model.prob
        if 'pixelwise_prob' in model.named_keys():
            final_fetch['pixelwise_prob'] = model.pixelwise_prob

        if args.arch == 'coordconv_rendering' or args.arch == 'deconv_bottleneck':
            final_fetch['center_logits'] = model.center_logits
            final_fetch['center_prob'] = model.center_prob

        ff = h5py.File(
            '%s/evaluation_%04d.h5' %
            (args.output, buddy.epoch), 'w')

    try:
        if ff is not None:
            # Create the train datasets now, fill them batch by batch below.
            # The initial row count is capped at train_size so it never
            # exceeds maxshape when the train set is smaller than a batch.
            first_rows = min(minibatch_size, train_size)
            if is_regressor:
                sample_shape = (2,)
            else:
                sample_shape = (args.x_dim, args.y_dim, 1)
            for kk in final_fetch:
                ff.create_dataset(kk + '_train',
                                  (first_rows,) + sample_shape,
                                  maxshape=(train_size,) + sample_shape,
                                  dtype=float, compression='lzf', chunks=True)

            # These are available already, so write them immediately.
            if write_x:
                ff.create_dataset('inputs_val', data=val_x)
                ff.create_dataset('inputs_train', data=train_x)
            if write_y:
                ff.create_dataset('labels_val', data=val_y)
                ff.create_dataset('labels_train', data=train_y)

        final_tr_metrics = None
        for ii in xrange(train_iters):
            start_idx = ii * minibatch_size
            end_idx = min(start_idx + minibatch_size, train_size)

            if has_onehot:
                feed_dict_tr[model.input_onehot] = np.array(
                    fd['train_onehots'][start_idx:end_idx], dtype=float)
            if has_images:
                feed_dict_tr[model.input_images] = np.array(
                    fd['train_imagegray'][start_idx:end_idx], dtype=float) / 255.0
            if has_coords:
                feed_dict_tr[model.input_coords] = np.array(
                    fd[train_loc_key][start_idx:end_idx], dtype=loc_dtype)

            _final_tr_metrics = sess_run_dict(
                sess, model.trackable_dict(), feed_dict=feed_dict_tr)
            # Batch size, recorded so the last (possibly short) batch can be
            # weighted accordingly when the metrics are averaged.
            _final_tr_metrics['weights'] = end_idx - start_idx

            if ii == 0:
                final_tr_metrics = _final_tr_metrics
            else:
                final_tr_metrics = merge_dict_append(
                    final_tr_metrics, _final_tr_metrics)

            if ff is not None:
                if ii == 0:  # do only once
                    final_va = sess_run_dict(
                        sess, final_fetch, feed_dict=feed_dict_va)
                    for kk in final_fetch:
                        ff.create_dataset(kk + '_val', data=final_va[kk])

                final_tr = sess_run_dict(
                    sess, final_fetch, feed_dict=feed_dict_tr)
                for kk in final_fetch:
                    dset = ff[kk + '_train']
                    if dset.shape[0] < end_idx:
                        # Grow along the sample axis to hold this batch.
                        dset.resize(end_idx, axis=0)
                    dset[start_idx:end_idx, ...] = final_tr[kk]

        final_va_metrics = sess_run_dict(
            sess, model.trackable_dict(), feed_dict=feed_dict_va)
        final_tr_metrics = average_dict_values(final_tr_metrics)

        if ff is not None:
            # Binary mode keeps the pickle byte-exact on every platform.
            metrics_path = '%s/evaluation_%04d_metrics.pkl' % (
                args.output, buddy.epoch)
            with open(metrics_path, 'wb') as ffmetrics:
                tosave = {'train': final_tr_metrics,
                          'val': final_va_metrics,
                          'time_elapsed': buddy.toc()
                          }
                pickle.dump(tosave, ffmetrics)
    finally:
        # Close the HDF5 dump even if a session run or a write failed.
        if ff is not None:
            ff.close()

    if not args.output:
        print ('\nEpoch %d evaluation on whole train and val' % buddy.epoch)
        print ('Time elapsed: {}'.format(buddy.toc()))
        for name, value in final_tr_metrics.iteritems():
            print ('final_stats_eval train_%s %g' % (name, value))
        for name, value in final_va_metrics.iteritems():
            print ('final_stats_eval val_%s %g' % (name, value))

    return final_tr_metrics, final_va_metrics
コード例 #13
0
def main():
    """Train a Region Proposal Net on Field-of-MNIST and report its progress.

    Steps, in order:
      0. parse command-line options and pick a learning-rate policy;
      1. load the cropped Field-of-MNIST HDF5 file and build a
         ``RegionProposalSampler`` (with ``coordconv=True`` for the
         ``coord_rpn_sampler`` architecture);
      2. create the optimizer and the training op;
      3. optionally restore a snapshot, then initialize remaining variables;
      4. collect TensorBoard summaries;
      5. loop over epochs: validate, plot boxes, optionally snapshot, train.

    When ``args.output`` is set, summaries, figures and snapshots are written
    under that directory.

    Raises:
        ValueError: if ``args.arch`` or ``args.opt`` is not one of the
            architectures / optimizers handled here.
    """
    lr_policy_choices = ('constant', 'step', 'valstep')

    parser = make_standard_parser('Region Proposal Net',
                                  arch_choices=arch_choices,
                                  skip_train=True,
                                  skip_val=True)
    parser.add_argument(
        '--num',
        '-N',
        type=int,
        default=2,
        help='Load the Field-of-MNIST dataset with NUM digits per image.')
    parser.add_argument('--lrpolicy',
                        type=str,
                        default='constant',
                        choices=lr_policy_choices,
                        help='LR policy.')
    parser.add_argument('--lrstepratio',
                        type=float,
                        default=.1,
                        help='LR policy step ratio.')
    # The two help strings below used to be copies of the --lrstepratio text.
    parser.add_argument('--lrmaxsteps',
                        type=int,
                        default=5,
                        help='LR policy maximum number of steps.')
    parser.add_argument('--lrstepevery',
                        type=int,
                        default=50,
                        help='LR policy step interval.')
    parser.add_argument('--clip',
                        action='store_true',
                        help='clip predicted and ground truth boxes.')
    parser.add_argument('--same',
                        action='store_true',
                        help='Use `same` filter instead of `valid` in conv.')
    parser.add_argument('--showbox',
                        action='store_true',
                        help='show moved box during training.')

    args = parser.parse_args()

    if args.lrpolicy == 'step':
        lr_policy = LRPolicyStep(args)
    elif args.lrpolicy == 'valstep':
        lr_policy = LRPolicyValStep(args)
    else:
        lr_policy = LRPolicyConstant(args)

    # The graph below is built for one image per step: the ground-truth box
    # placeholder has no batch axis and feeds use element [0] of each slice.
    minibatch_size = 1
    if args.nocolor:
        train_style, val_style = '', ''
    else:
        train_style, val_style = colorama.Fore.BLUE, colorama.Fore.MAGENTA

    sess = setup_session_and_seeds(args.seed, assert_gpu=not args.cpu)

    # 0. Load data
    #train_ims, train_pos, train_class, valid_ims, valid_pos, valid_class, _, _, _ = load_tvt_n_per_field(args.num)
    #train_ims = train_ims[:5000]     # (5000, 64, 64, 1)
    #train_pos = train_pos[:5000]     # (5000, 2, 4)
    #valid_ims = valid_ims[:1000]
    #valid_pos = valid_pos[:1000]
    #_ims, _pos, _class, _, _, _, _, _, _ = load_tvt_n_per_field_centercrop(args.num)

    # Context manager so the file is closed even if a dataset read fails.
    with h5py.File('data/field_of_mnist_cropped_64x64_5objs.h5', 'r') as ff:
        train_ims = np.array(ff['train_ims'])  # (9000, 64, 64, 1)
        # (9000, 5, 4), parts of boxes may be out of canvas
        train_pos = np.array(ff['train_pos'])
        train_class = np.array(ff['train_class'])  # (9000, 5)
        valid_ims = np.array(ff['valid_ims'])  # (1000, 64, 64, 1)
        # (1000, 5, 4), parts of boxes may be out of canvas
        valid_pos = np.array(ff['valid_pos'])
        valid_class = np.array(ff['valid_class'])  # (1000, 5)

    im_h, im_w, im_c = train_ims.shape[1:4]
    train_size = train_ims.shape[0]
    val_size = valid_ims.shape[0]

    print('Data loaded:\n\timage shape: {}x{}x{}'.format(im_h, im_w, im_c))
    print('\ttrain size: {}\n\ttest size: {}'.format(train_size, val_size))
    print('\tnumber of objects per image: {}'.format(train_pos.shape[1]))

    ####################
    # RPN parameters
    ####################
    rpn_params = RPNParams(anchors=np.array([(15, 15), (20, 20), (25, 25),
                                             (15, 20), (20, 25), (20, 15),
                                             (25, 20), (15, 25), (25, 15)]),
                           rpn_hidden_dim=32,
                           zero_box_conv=False,
                           weight_init_std=0.01,
                           anchor_scale=1.0)

    bsamp_params = BoxSamplerParams(hi_thresh=0.5,
                                    lo_thresh=0.1,
                                    sample_size=12)

    nms_params = NMSParams(
        nms_thresh=0.8,
        max_proposals=10,
    )

    # 1. CREATE MODEL

    input_images = tf.placeholder(shape=(None, im_h, im_w, im_c),
                                  dtype='float32',
                                  name='input_images')
    input_gtbox = tf.placeholder(shape=(train_pos.shape[1], 4),
                                 dtype='float32',
                                 name='input_gtbox')

    # Both supported architectures are the same sampler; they differ only in
    # whether CoordConv is switched on.
    if args.arch not in ('rpn_sampler', 'coord_rpn_sampler'):
        raise ValueError('Architecture {} unknown'.format(args.arch))
    model = RegionProposalSampler(rpn_params,
                                  bsamp_params,
                                  nms_params,
                                  l2=args.l2,
                                  im_h=im_h,
                                  im_w=im_w,
                                  coordconv=(args.arch == 'coord_rpn_sampler'),
                                  clip=args.clip,
                                  filtersame=args.same)

    # Anchor grid size depends on the conv padding mode.
    if args.same:
        anchor_grid_hw = (16, 16)
        anchors = make_anchors_mnist_same(
            anchor_grid_hw, minibatch_size,
            rpn_params.anchors)  # (batch, 16, 16, 4k)
    else:
        anchor_grid_hw = (13, 13)
        anchors = make_anchors_mnist(
            anchor_grid_hw, minibatch_size,
            rpn_params.anchors)  # (batch, 13, 13, 4k)
    input_anchors = tf.placeholder(
        shape=anchor_grid_hw + (4 * rpn_params.num_anchors, ),
        dtype='float32',
        name='input_anchors')
    # Drop the batch axis: the placeholder above has none.
    anchors = anchors[0]

    model.a('input_images', input_images)
    model.a('input_anchors', input_anchors)
    model.a('input_gtbox', input_gtbox)

    model([input_images, input_anchors, input_gtbox])

    print('All model weights:')
    summarize_weights(model.trainable_weights)
    #print 'Model summary:'
    print('Another model summary:')
    model.summarize_named(prefix='  ')
    print_trainable_warnings(model)

    # 2. COMPUTE GRADS AND CREATE OPTIMIZER
    input_lr = tf.placeholder(
        tf.float32, shape=[])  # a placeholder for dynamic learning rate
    if args.opt == 'sgd':
        opt = tf.train.MomentumOptimizer(input_lr, args.mom)
    elif args.opt == 'rmsprop':
        opt = tf.train.RMSPropOptimizer(input_lr, momentum=args.mom)
    elif args.opt == 'adam':
        opt = tf.train.AdamOptimizer(input_lr, args.beta1, args.beta2)
    else:
        # Fail here with a clear message instead of a NameError on `opt` below.
        raise ValueError('Optimizer {} unknown'.format(args.opt))

    grads_and_vars = opt.compute_gradients(
        model.loss,
        model.trainable_weights,
        gate_gradients=tf.train.Optimizer.GATE_GRAPH)
    train_step = opt.apply_gradients(grads_and_vars)
    add_grads_and_vars_hist_summaries(
        grads_and_vars)  # added to train_ and param_ collections

    summarize_opt(opt)
    print('LR Policy:', lr_policy)

    # 3. OPTIONALLY SAVE OR LOAD VARIABLES (e.g. model params, model running BN means, optimization momentum, ...) and then finalize initialization
    saver = tf.train.Saver(
        max_to_keep=None) if (args.output or args.load) else None
    if args.load:
        # --load is given as "<checkpoint file>:<misc pickle file>"
        ckptfile, miscfile = args.load.split(':')
        # Restore values directly to graph
        saver.restore(sess, ckptfile)
        # NOTE: pickle.load can execute arbitrary code; only resume from
        # snapshot files you trust.
        with gzip.open(miscfile) as misc_file:
            saved = pickle.load(misc_file)
            buddy = saved['buddy']
    else:
        buddy = StatsBuddy()

    buddy.tic()  # call if new run OR resumed run

    # Check if special layers are initialized right
    #last_layer_w = [var for var in tf.global_variables() if 'painting_layer/kernel:0' in var.name][0]
    #last_layer_b = [var for var in tf.global_variables() if 'painting_layer/bias:0' in var.name][0]

    # Initialize any missed vars (e.g. optimization momentum, ... if not loaded from checkpoint)
    uninitialized_vars = tf_get_uninitialized_variables(sess)
    init_missed_vars = tf.variables_initializer(uninitialized_vars,
                                                'init_missed_vars')
    sess.run(init_missed_vars)
    tf_assert_all_init(sess)

    # 4. SETUP TENSORBOARD LOGGING with tf.summary.merge

    train_histogram_summaries = get_collection_intersection_summary(
        'train_collection', 'orig_histogram')
    train_scalar_summaries = get_collection_intersection_summary(
        'train_collection', 'orig_scalar')
    test_histogram_summaries = get_collection_intersection_summary(
        'test_collection', 'orig_histogram')
    test_scalar_summaries = get_collection_intersection_summary(
        'test_collection', 'orig_scalar')
    param_histogram_summaries = get_collection_intersection_summary(
        'param_collection', 'orig_histogram')
    train_image_summaries = get_collection_intersection_summary(
        'train_collection', 'orig_image')
    test_image_summaries = get_collection_intersection_summary(
        'test_collection', 'orig_image')

    writer = None
    if args.output:
        mkdir_p(args.output)
        writer = tf.summary.FileWriter(args.output, sess.graph)

    # 5. TRAIN

    train_iters = (train_size) // minibatch_size
    if not args.skipval:
        val_iters = (val_size) // minibatch_size

    if args.output:
        # Validation images whose boxes are plotted and saved every epoch.
        show_indices = np.random.permutation(val_size)[:9]
        mkdir_p('{}/figures'.format(args.output))

    if args.ipy:
        print('Embed: before train / val loop (Ctrl-D to continue)')
        embed()

    # How often to log data. These do not change between epochs, so they are
    # defined once rather than rebuilt on every pass of the loop.
    def do_log_params(ep, it, ii):
        return True

    def do_log_val(ep, it, ii):
        return True

    def do_log_train(ep, it, ii):
        # Log on powers of two (and iteration 0) while it < train_iters,
        # then whenever `it` is a multiple of train_iters (every epoch).
        if it < train_iters:
            return (it & (it - 1)) == 0
        return it % train_iters == 0

    while buddy.epoch < args.epochs + 1:
        # 0. Log params
        if args.output and do_log_params(
                buddy.epoch, buddy.train_iter,
                0) and param_histogram_summaries is not None:
            params_summary_str, = sess.run([param_histogram_summaries])
            writer.add_summary(params_summary_str, buddy.train_iter)

        # 1. Forward test on validation set
        if not args.skipval:
            for ii in range(val_iters):
                tic2()
                start_idx = ii * minibatch_size
                end_idx = min(start_idx + minibatch_size, val_size)
                if not end_idx > start_idx:
                    continue

                feed_dict = {
                    model.input_images: valid_ims[start_idx:end_idx],
                    model.input_anchors: anchors,
                    model.input_gtbox: valid_pos[start_idx:end_idx][0],
                    learning_phase(): 0
                }

                fetch_dict = model.trackable_dict()

                if args.output and do_log_val(buddy.epoch, buddy.train_iter,
                                              0):
                    if test_image_summaries is not None:
                        fetch_dict.update(
                            {'test_image_summaries': test_image_summaries})
                    if test_scalar_summaries is not None:
                        fetch_dict.update(
                            {'test_scalar_summaries': test_scalar_summaries})
                    if test_histogram_summaries is not None:
                        fetch_dict.update({
                            'test_histogram_summaries':
                            test_histogram_summaries
                        })

                with WithTimer('sess.run val iter', quiet=not args.verbose):
                    result_val = sess_run_dict(sess,
                                               fetch_dict,
                                               feed_dict=feed_dict)

                ## DEBUG
                ## dynamic p_size and n_size, should vary slightly for every sample
                #if ii > 0 and ii % 100 == 0:
                #    print 'VALIDATION --- '
                #    print sess.run(model.p_size, feed_dict=feed_dict)
                #    print sess.run(model.n_size, feed_dict=feed_dict)
                ## END DEBUG

                buddy.note_weighted_list(
                    minibatch_size,
                    model.trackable_names(),
                    [result_val[k] for k in model.trackable_names()],
                    prefix='val_')

            # Done all val set
            print('[%5d] [%2d/%2d] val: %s (%.3gs/i)' %
                  (buddy.train_iter, buddy.epoch, args.epochs,
                   buddy.epoch_mean_pretty_re('^val_',
                                              style=val_style), toc2()))

            if args.output and do_log_val(buddy.epoch, buddy.train_iter, 0):
                log_scalars(
                    writer,
                    buddy.train_iter, {
                        'mean_%s' % name: value
                        for name, value in buddy.epoch_mean_list_re('^val_')
                    },
                    prefix='val')
                # Only the summaries fetched on the last validation iteration
                # are written (result_val holds the last sess.run result).
                if test_image_summaries is not None:
                    image_summary_str = result_val['test_image_summaries']
                    writer.add_summary(image_summary_str, buddy.train_iter)
                if test_scalar_summaries is not None:
                    scalar_summary_str = result_val['test_scalar_summaries']
                    writer.add_summary(scalar_summary_str, buddy.train_iter)
                if test_histogram_summaries is not None:
                    hist_summary_str = result_val['test_histogram_summaries']
                    writer.add_summary(hist_summary_str, buddy.train_iter)

        # show some boxes
        if args.showbox:  #and (valid_losses[epoch]/previous_best < 1- args.thresh):
            # Kept under its own name so it does not overwrite `show_indices`,
            # which the figure-saving code below relies on.
            showbox_indices = [55, 555, 678]
            for panel, show_idx in enumerate(showbox_indices, 1):
                [pos_box, pos_score, neg_box, neg_score] = sess.run(
                    [
                        model.pos_box, model.pos_score, model.neg_box,
                        model.neg_score
                    ],
                    feed_dict={
                        model.input_images: valid_ims[show_idx:show_idx + 1],
                        model.input_anchors: anchors,
                        model.input_gtbox: valid_pos[show_idx],
                        learning_phase(): 0
                    })
                subplot(1, 3, panel)
                #plot_boxes_pos_neg(valid_ims[show_idx], valid_pos[show_idx], pos_box, neg_box)
                plot_pos_boxes(valid_ims[show_idx],
                               valid_pos[show_idx],
                               pos_box,
                               pos_score,
                               showlabel=False)
            show()

        if args.output:
            switch_backend('Agg')
            plot_fetch_dict = {
                'pos_box': model.pos_box,
                'pos_score': model.pos_score,
                'neg_box': model.neg_box,
                'neg_score': model.neg_score,
                'nms_boxes': model.nms_boxes,
                'nms_scores': model.nms_scores,
            }

            #fig1, ax1 = subplots(3,3)  # plot train boxes
            #fig2, ax2 = subplots(3,3)  # plot test/nms boxes
            for cc, show_idx in enumerate(show_indices, 1):
                feed_dict = {
                    model.input_images: valid_ims[show_idx:show_idx + 1],
                    model.input_anchors: anchors,
                    model.input_gtbox: valid_pos[show_idx],
                    learning_phase(): 0
                }
                result_plots = sess_run_dict(sess,
                                             plot_fetch_dict,
                                             feed_dict=feed_dict)
                fig1 = figure(1)
                subplot(3, 3, cc)
                plot_boxes_pos_neg(valid_ims[show_idx], valid_pos[show_idx],
                                   result_plots['pos_box'],
                                   result_plots['neg_box'])
                fig2 = figure(2)
                subplot(3, 3, cc)
                #plot_pos_boxes(valid_ims[show_idx], valid_pos[show_idx], result_plots['nms_boxes'], result_plots['nms_scores'], showlabel=False)
                # normalize scores between 0 and 5, to be used as line width
                # NOTE(review): _score_as_lw is computed but the raw scores
                # are what get passed below -- confirm which of the two
                # plot_pos_boxes_thickness expects before changing this.
                _score_as_lw = 5 * (result_plots['nms_scores'] -
                                    result_plots['nms_scores'].min()) / (
                                        result_plots['nms_scores'].max() -
                                        result_plots['nms_scores'].min())
                plot_pos_boxes_thickness(valid_ims[show_idx],
                                         valid_pos[show_idx],
                                         result_plots['nms_boxes'],
                                         result_plots['nms_scores'])

            fig1.set_size_inches(10, 10)
            fig1.savefig('{}/figures/pos_neg_train_box_epoch_{}.png'.format(
                args.output, buddy.epoch),
                         dpi=100)
            fig2.set_size_inches(10, 10)
            fig2.savefig('{}/figures/nms_test_box_epoch_{}.png'.format(
                args.output, buddy.epoch),
                         dpi=100)

            # plot test/nms boxes
            # NOTE(review): the figure opened here is not used afterwards in
            # this function; looks like a leftover -- confirm before removing.
            fig, _ = subplots()

        # 2. Possibly Snapshot, possibly quit
        if args.output and args.snapshot_to and args.snapshot_every:
            snap_intermed = args.snapshot_every > 0 and buddy.train_iter % args.snapshot_every == 0
            #snap_end = buddy.epoch == args.epochs
            snap_end = lr_policy.train_done(buddy)
            if snap_intermed or snap_end:
                # Snapshot network and buddy
                save_path = saver.save(
                    sess, '%s/%s_%04d.ckpt' %
                    (args.output, args.snapshot_to, buddy.epoch))
                print('snappshotted model to', save_path)
                with gzip.open(
                        '%s/%s_misc_%04d.pkl.gz' %
                    (args.output, args.snapshot_to, buddy.epoch),
                        'w') as misc_file:
                    saved = {'buddy': buddy}
                    pickle.dump(saved, misc_file)

        lr = lr_policy.get_lr(buddy)

        if buddy.epoch == args.epochs:
            if args.ipy:
                print('Embed: at end of training (Ctrl-D to exit)')
                embed()
            break  # Extra pass at end: just report val stats and skip training

        print('********* at epoch %d, LR is %g' % (buddy.epoch, lr))

        # 3. Train on training set
        if args.shuffletrain:
            train_order = np.random.permutation(train_size)
        tic3()
        for ii in range(train_iters):
            tic2()
            start_idx = ii * minibatch_size
            end_idx = min(start_idx + minibatch_size, train_size)

            if not end_idx > start_idx:
                continue

            if args.shuffletrain:  # default true
                # Sorted index list, computed once and shared by both arrays.
                batch_idx = sorted(train_order[start_idx:end_idx].tolist())
                batch_ims = train_ims[batch_idx]
                batch_pos = train_pos[batch_idx]
            else:
                batch_ims = train_ims[start_idx:end_idx]
                batch_pos = train_pos[start_idx:end_idx]

            feed_dict = {
                model.input_images: batch_ims,
                model.input_anchors: anchors,
                model.input_gtbox: batch_pos[0],
                learning_phase(): 1,
                input_lr: lr
            }

            fetch_dict = model.trackable_and_update_dict()

            fetch_dict.update({'train_step': train_step})

            if args.output and do_log_train(buddy.epoch, buddy.train_iter, ii):
                if train_histogram_summaries is not None:
                    fetch_dict.update({
                        'train_histogram_summaries':
                        train_histogram_summaries
                    })
                if train_scalar_summaries is not None:
                    fetch_dict.update(
                        {'train_scalar_summaries': train_scalar_summaries})
                if train_image_summaries is not None:
                    fetch_dict.update(
                        {'train_image_summaries': train_image_summaries})

            with WithTimer('sess.run train iter', quiet=not args.verbose):
                result_train = sess_run_dict(sess,
                                             fetch_dict,
                                             feed_dict=feed_dict)

            buddy.note_weighted_list(
                minibatch_size,
                model.trackable_names(),
                [result_train[k] for k in model.trackable_names()],
                prefix='train_')

            if do_log_train(buddy.epoch, buddy.train_iter, ii):
                print('[%5d] [%2d/%2d] train: %s (%.3gs/i)' %
                      (buddy.train_iter, buddy.epoch, args.epochs,
                       buddy.epoch_mean_pretty_re(
                           '^train_', style=train_style), toc2()))

            if args.output and do_log_train(buddy.epoch, buddy.train_iter, ii):
                if train_histogram_summaries is not None:
                    hist_summary_str = result_train[
                        'train_histogram_summaries']
                    writer.add_summary(hist_summary_str, buddy.train_iter)
                if train_scalar_summaries is not None:
                    scalar_summary_str = result_train['train_scalar_summaries']
                    writer.add_summary(scalar_summary_str, buddy.train_iter)
                if train_image_summaries is not None:
                    image_summary_str = result_train['train_image_summaries']
                    writer.add_summary(image_summary_str, buddy.train_iter)
                log_scalars(
                    writer,
                    buddy.train_iter, {
                        'batch_%s' % name: value
                        for name, value in buddy.last_list_re('^train_')
                    },
                    prefix='train')

            if ii > 0 and ii % 100 == 0:
                print(
                    '  %d: Average iteration time over last 100 train iters: %.3gs'
                    % (ii, toc3() / 100))
                tic3()

            buddy.inc_train_iter()  # after finished training a mini-batch

        buddy.inc_epoch()  # after finished training whole pass through set

        if args.output and do_log_train(buddy.epoch, buddy.train_iter, 0):
            log_scalars(
                writer,
                buddy.train_iter, {
                    'mean_%s' % name: value
                    for name, value in buddy.epoch_mean_list_re('^train_')
                },
                prefix='train')

    print('\nFinal')
    print('%02d:%d val:   %s' %
          (buddy.epoch, buddy.train_iter,
           buddy.epoch_mean_pretty_re('^val_', style=val_style)))
    print('%02d:%d train: %s' %
          (buddy.epoch, buddy.train_iter,
           buddy.epoch_mean_pretty_re('^train_', style=train_style)))

    print(
        '\nEnd of training. Saving evaluation results on whole train and val set.'
    )

    if args.output:
        writer.close()  # Flush and close