def main():
    """Train a GAN on simple square images or Clevr two-object color images.

    Command-line entry point. The steps are:

    0. Load the image data (train and val splits are concatenated, since the
       GAN is trained on everything and validation is disabled).
    1. Build the model via ``build_model``.
    2. Create separate optimizers / train steps for the discriminator and
       the generator.
    3. Initialize variables, then optionally restore a checkpoint.
    4. Collect TensorBoard summaries.
    5. Run the epoch loop: log params, sample images from the generator,
       optionally snapshot, optionally evaluate on the training set, train.

    The loop performs one extra pass after the last epoch that only does the
    sampling / snapshot / evaluation work and then breaks before training.

    Raises:
        Exception: if ``args.arch`` or ``args.opt`` is not recognised.
    """
    parser = make_standard_parser(
        'Train a GAN model on simple square images or Clevr two-object color images',
        arch_choices=arch_choices, skip_train=True, skip_val=True)

    parser.add_argument('--z_dim', type=int, default=10,
                        help='Dimension of noise vector')
    parser.add_argument('--lr2', type=float, default=None,
                        help='learning rate for generator')
    parser.add_argument('--feature_match', '-fm', action='store_true',
                        help='use feature matching loss for generator.')
    parser.add_argument(
        '--feature_match_loss_weight', '-fmalpha', type=float, default=1.0,
        help='weight on the feature matching loss for generator.')
    parser.add_argument(
        '--pairedz', action='store_true',
        help='If True, pair the same z with a training batch each epoch')
    parser.add_argument(
        '--eval-train-every', type=int, default=0,
        help='evaluate whole training set every N epochs. 0 to disable.')

    args = parser.parse_args()
    # Validation is force-disabled: only training is supported by this script.
    args.skipval = True

    minibatch_size = args.minibatch
    train_style, val_style = ('', '') if args.nocolor else (
        colorama.Fore.BLUE, colorama.Fore.MAGENTA)
    evaltrain_style = (
        '' if args.nocolor or args.eval_train_every <= 0
        else colorama.Fore.CYAN)
    black_divider = args.arch.startswith('clevr')

    # Get a TF session and set numpy and TF seeds
    sess = setup_session_and_seeds(args.seed, assert_gpu=not args.cpu)

    # 0. LOAD DATA
    if args.arch.startswith('simple'):
        # np.array() copies the datasets into memory, so the file can be
        # closed as soon as both splits are read.
        with h5py.File('data/rectangle_4_uniform.h5', 'r') as fd:
            train_x = np.array(fd['train_imagegray'], dtype=float) / 255.0  # shape (2368, 64, 64, 1)
            val_x = np.array(fd['val_imagegray'], dtype=float) / 255.0  # shape (768, 64, 64, 1)
        train_x = np.concatenate((train_x, val_x), axis=0)  # shape (3136, 64, 64, 1)
    elif args.arch.startswith('clevr'):
        (train_x, val_x) = load_sort_of_clevr()  # shape (50000, 64, 64, 3)
        train_x = np.concatenate((train_x, val_x), axis=0)
    else:
        raise Exception('Unknown network architecture: %s' % args.arch)

    print('Train data loaded: {} images, size {}'.format(
        train_x.shape[0], train_x.shape[1:]))

    # 1. CREATE MODEL
    assert len(train_x.shape) == 4, "image data must be of 4 dimensions"

    image_h, image_w, image_c = train_x.shape[1:4]

    model = build_model(args, image_h, image_w, image_c)

    print('All model weights:')
    summarize_weights(model.trainable_weights)
    print('Model summary:')
    # model.summary() # TOREPLACE
    print('Another model summary:')
    model.summarize_named(prefix=' ')
    print_trainable_warnings(model)

    # 2. COMPUTE GRADS AND CREATE OPTIMIZER
    # The generator may use its own learning rate (--lr2); otherwise it
    # shares the discriminator's.
    lr_gen = args.lr2 if args.lr2 else args.lr

    if args.opt == 'sgd':
        d_opt = tf.train.MomentumOptimizer(args.lr, args.mom)
        g_opt = tf.train.MomentumOptimizer(lr_gen, args.mom)
    elif args.opt == 'rmsprop':
        d_opt = tf.train.RMSPropOptimizer(args.lr, momentum=args.mom)
        g_opt = tf.train.RMSPropOptimizer(lr_gen, momentum=args.mom)
    elif args.opt == 'adam':
        d_opt = tf.train.AdamOptimizer(args.lr, args.beta1, args.beta2)
        g_opt = tf.train.AdamOptimizer(lr_gen, args.beta1, args.beta2)
    else:
        # Fail here with a clear message instead of a NameError on d_opt below.
        raise Exception('Unknown optimizer: %s' % args.opt)

    # Optimize w.r.t all trainable params in the model, split by name into
    # the discriminator's and the generator's variables.
    all_vars = model.trainable_variables
    d_vars = [var for var in all_vars if 'discriminator' in var.name]
    g_vars = [var for var in all_vars if 'generator' in var.name]

    d_grads_and_vars = d_opt.compute_gradients(
        model.d_loss, d_vars, gate_gradients=tf.train.Optimizer.GATE_GRAPH)
    d_train_step = d_opt.apply_gradients(d_grads_and_vars)
    g_grads_and_vars = g_opt.compute_gradients(
        model.g_loss, g_vars, gate_gradients=tf.train.Optimizer.GATE_GRAPH)
    g_train_step = g_opt.apply_gradients(g_grads_and_vars)

    hist_summaries_traintest(model.d_real_logits, model.d_fake_logits)
    add_grads_and_vars_hist_summaries(d_grads_and_vars)
    add_grads_and_vars_hist_summaries(g_grads_and_vars)
    image_summaries_traintest(model.fake_images)

    # 3. OPTIONALLY SAVE OR LOAD VARIABLES (e.g. model params, model running
    # BN means, optimization momentum, ...) and then finalize initialization
    saver = tf.train.Saver(
        max_to_keep=None) if (args.output or args.load) else None

    # Initialize BEFORE restoring: running the initializer after
    # saver.restore() would overwrite the restored values with fresh ones.
    tf.global_variables_initializer().run()

    if args.load:
        ckptfile, miscfile = args.load.split(':')
        # Restore values directly to graph
        saver.restore(sess, ckptfile)
        # NOTE(security): pickle.load executes arbitrary code from the file;
        # only resume from misc files you produced yourself.
        with gzip.open(miscfile) as ff:
            saved = pickle.load(ff)
            buddy = saved['buddy']
    else:
        buddy = StatsBuddy(pretty_replaces=[('evaltrain_', ''), (
            'eval', '')]) if args.eval_train_every > 0 else StatsBuddy()

    buddy.tic()  # call if new run OR resumed run

    # 4. SETUP TENSORBOARD LOGGING
    param_histogram_summaries = get_collection_intersection_summary(
        'param_collection', 'orig_histogram')
    train_histogram_summaries = get_collection_intersection_summary(
        'train_collection', 'orig_histogram')
    train_scalar_summaries = get_collection_intersection_summary(
        'train_collection', 'orig_scalar')
    test_histogram_summaries = get_collection_intersection_summary(
        'test_collection', 'orig_histogram')
    test_scalar_summaries = get_collection_intersection_summary(
        'test_collection', 'orig_scalar')
    train_image_summaries = get_collection_intersection_summary(
        'train_collection', 'orig_image')
    test_image_summaries = get_collection_intersection_summary(
        'test_collection', 'orig_image')

    writer = None
    if args.output:
        mkdir_p(args.output)
        writer = tf.summary.FileWriter(args.output, sess.graph)

    # 5. TRAIN
    train_iters = (train_x.shape[0]) // minibatch_size

    if args.ipy:
        print('Embed: before train / val loop (Ctrl-D to continue)')
        embed()

    # Fixed noise batch: generated once so the same 100 z vectors are
    # rendered every epoch and the saved image grids are comparable.
    np.random.seed()
    eval_batch_size = 100
    eval_z = np.random.uniform(-1, 1, size=(eval_batch_size, args.z_dim))

    # How often to log data. These only depend on train_iters, so they are
    # defined once instead of being re-created every epoch.
    def do_log_params(ep, it, ii):
        return True

    def do_log_val(ep, it, ii):
        return True

    def do_log_train(ep, it, ii):
        # Log on powers of two then every epoch
        return (it < train_iters and it & it - 1 == 0
                or it >= train_iters and it % train_iters == 0)

    while buddy.epoch < args.epochs + 1:
        # 0. Log params
        if args.output and do_log_params(
                buddy.epoch, buddy.train_iter,
                0) and param_histogram_summaries is not None:
            params_summary_str, = sess.run([param_histogram_summaries])
            writer.add_summary(params_summary_str, buddy.train_iter)

        # 1. Evaluate generator by showing random generated results
        # Evaluate discriminator by its correct rate on generated and real
        # (hold-out) results.
        # NOTE: unreachable as written because args.skipval is forced to True
        # above. It also reads val_y, which this function never loads.
        if not args.skipval:
            tic2()
            # use different noise, eval on larger number of samples and get
            # correct rate
            np.random.seed()
            val_z = np.random.uniform(-1, 1,
                                      size=(val_x.shape[0], args.z_dim))

            with WithTimer('sess.run val iter', quiet=not args.verbose):
                feed_dict = {
                    model.input_images: val_x,
                    model.input_noise: val_z,
                    learning_phase(): 0
                }
                if 'input_labels' in model.named_keys():
                    feed_dict.update({model.input_labels: val_y})

                val_corr_fake_bn0, val_corr_real_bn0 = sess.run(
                    [model.correct_fake, model.correct_real],
                    feed_dict=feed_dict)

                feed_dict[learning_phase()] = 1
                val_corr_fake_bn1, val_corr_real_bn1 = sess.run(
                    [model.correct_fake, model.correct_real],
                    feed_dict=feed_dict)

            if args.output and do_log_val(buddy.epoch, buddy.train_iter, 0):
                fetch_dict = {}
                if test_image_summaries is not None:
                    fetch_dict.update(
                        {'test_image_summaries': test_image_summaries})
                if test_scalar_summaries is not None:
                    fetch_dict.update(
                        {'test_scalar_summaries': test_scalar_summaries})
                if test_histogram_summaries is not None:
                    fetch_dict.update(
                        {'test_histogram_summaries': test_histogram_summaries})
                if fetch_dict:
                    summary_strs = sess_run_dict(sess, fetch_dict,
                                                 feed_dict=feed_dict)

            buddy.note_list([
                'correct_real_bn0', 'correct_fake_bn0',
                'correct_real_bn1', 'correct_fake_bn1'
            ], [
                val_corr_real_bn0, val_corr_fake_bn0,
                val_corr_real_bn1, val_corr_fake_bn1
            ], prefix='val_')

            print('%3d (ep %d) val: %s (%.3gs/ep)' %
                  (buddy.train_iter, buddy.epoch,
                   buddy.epoch_mean_pretty_re('^val_', style=val_style),
                   toc2()))

            if args.output and do_log_val(buddy.epoch, buddy.train_iter, 0):
                log_scalars(
                    writer, buddy.train_iter, {
                        'mean_%s' % name: value
                        for name, value in buddy.epoch_mean_list_re('^val_')
                    },
                    prefix='buddy')

                if test_image_summaries is not None:
                    image_summary_str = summary_strs['test_image_summaries']
                    writer.add_summary(image_summary_str, buddy.train_iter)
                if test_scalar_summaries is not None:
                    scalar_summary_str = summary_strs['test_scalar_summaries']
                    writer.add_summary(scalar_summary_str, buddy.train_iter)
                if test_histogram_summaries is not None:
                    hist_summary_str = summary_strs['test_histogram_summaries']
                    writer.add_summary(hist_summary_str, buddy.train_iter)

        # In addition, evaluate 1000 more images from fresh noise
        np.random.seed()
        eval_more = np.random.uniform(-1, 1, size=(1000, args.z_dim))

        feed_dict2 = {
            # (100,-) generated outside of loop to keep the same every round
            model.input_noise: eval_z,
            learning_phase(): 0
        }
        eval_samples_bn0 = sess.run(model.fake_images, feed_dict=feed_dict2)
        feed_dict2[learning_phase()] = 1
        eval_samples_bn1 = sess.run(model.fake_images, feed_dict=feed_dict2)

        # feed in 10 times because coordconv cannot handle too big of a batch
        more_chunks = []
        for cc in range(10):
            eval_z2 = eval_more[cc * 100:(cc + 1) * 100, :]
            more_chunks.append(sess.run(
                model.fake_images,
                feed_dict={
                    model.input_noise: eval_z2,
                    learning_phase(): 0
                }))
        eval_more_samples = np.concatenate(more_chunks, axis=0)

        if args.output:
            mkdir_p('{}/fake_images'.format(args.output))
            # eval_samples_bn*: e.g. (100, 64, 64, 3)
            save_images(eval_samples_bn0, [10, 10],
                        '{}/fake_images/g_out_bn0_epoch_{}_iter_{}.png'.format(
                            args.output, buddy.epoch, buddy.train_iter),
                        black_divider=black_divider)
            save_images(eval_samples_bn1, [10, 10],
                        '{}/fake_images/g_out_bn1_epoch_{}.png'.format(
                            args.output, buddy.epoch),
                        black_divider=black_divider)
            save_average_image(
                eval_more_samples,
                '{}/fake_images/g_out_averaged_epoch_{}_iter_{}.png'.format(
                    args.output, buddy.epoch, buddy.train_iter))

        # 2. Possibly snapshot, possibly quit
        if args.output and args.snapshot_to and args.snapshot_every:
            snap_intermed = (args.snapshot_every > 0
                             and buddy.train_iter % args.snapshot_every == 0)
            snap_end = buddy.epoch == args.epochs
            if snap_intermed or snap_end:
                # Snapshot
                save_path = saver.save(
                    sess, '%s/%s_%04d.ckpt' %
                    (args.output, args.snapshot_to, buddy.epoch))
                print('snappshotted model to', save_path)
                with gzip.open(
                        '%s/%s_misc_%04d.pkl.gz' %
                        (args.output, args.snapshot_to, buddy.epoch),
                        'w') as ff:
                    saved = {'buddy': buddy}
                    pickle.dump(saved, ff)

                # snapshot sampled images too; the context manager closes the
                # file even if writing one of the datasets fails
                with h5py.File(
                        '%s/sampled_images_%04d.h5' %
                        (args.output, buddy.epoch), 'w') as ff:
                    ff.create_dataset('eval_samples_bn0',
                                      data=eval_samples_bn0)
                    ff.create_dataset('eval_samples_bn1',
                                      data=eval_samples_bn1)
                    ff.create_dataset('eval_z', data=eval_z)
                    ff.create_dataset('eval_z_more', data=eval_more)
                    ff.create_dataset('eval_more_samples',
                                      data=eval_more_samples)

        # 2. Possibly evaluate the training set
        if args.eval_train_every > 0:
            if buddy.epoch % args.eval_train_every == 0:
                tic2()
                for ii in range(train_iters):
                    start_idx = ii * minibatch_size
                    if args.pairedz:
                        np.random.seed(args.seed + ii)
                    else:
                        np.random.seed()

                    batch_z = np.random.uniform(
                        -1, 1, size=(minibatch_size, args.z_dim))
                    batch_x = train_x[start_idx:start_idx + minibatch_size]

                    # Only touch labels when the model takes them, mirroring
                    # the training loop below, and feed this batch's labels.
                    # NOTE(review): train_y is not loaded anywhere in this
                    # function; label-conditioned models need it loaded in
                    # step 0 first.
                    feeds_labels = 'input_labels' in model.named_keys()
                    if feeds_labels:
                        batch_y = train_y[start_idx:start_idx + minibatch_size]

                    fetch_dict = model.trackable_dict()

                    # Evaluate the same batch with learning phase 0, then 1.
                    for phase, prefix in ((0, 'evaltrain_bn0_'),
                                          (1, 'evaltrain_bn1_')):
                        feed_dict = {
                            model.input_images: batch_x,
                            model.input_noise: batch_z,
                            learning_phase(): phase,
                        }
                        if feeds_labels:
                            feed_dict.update({model.input_labels: batch_y})

                        result_eval_train = sess_run_dict(
                            sess, fetch_dict, feed_dict=feed_dict)

                        buddy.note_weighted_list(
                            batch_x.shape[0],
                            model.trackable_names(), [
                                result_eval_train[k]
                                for k in model.trackable_names()
                            ],
                            prefix=prefix)

                    if args.output:
                        log_scalars(writer, buddy.train_iter, {
                            'batch_%s' % name: value
                            for name, value in buddy.last_list_re(
                                '^evaltrain_bn0_')
                        }, prefix='buddy')
                        log_scalars(writer, buddy.train_iter, {
                            'batch_%s' % name: value
                            for name, value in buddy.last_list_re(
                                '^evaltrain_bn1_')
                        }, prefix='buddy')

                if args.output:
                    log_scalars(writer, buddy.epoch, {
                        'mean_%s' % name: value
                        for name, value in
                        buddy.epoch_mean_list_re('^evaltrain_bn0_')
                    }, prefix='buddy')
                    log_scalars(writer, buddy.epoch, {
                        'mean_%s' % name: value
                        for name, value in
                        buddy.epoch_mean_list_re('^evaltrain_bn1_')
                    }, prefix='buddy')

                print('%3d (ep %d) evaltrain: %s (%.3gs/ep)' %
                      (buddy.train_iter, buddy.epoch,
                       buddy.epoch_mean_pretty_re(
                           '^evaltrain_bn0_', style=evaltrain_style),
                       toc2()))
                print('%3d (ep %d) evaltrain: %s (%.3gs/ep)' %
                      (buddy.train_iter, buddy.epoch,
                       buddy.epoch_mean_pretty_re(
                           '^evaltrain_bn1_', style=evaltrain_style),
                       toc2()))

        if buddy.epoch == args.epochs:
            if args.ipy:
                print('Embed: at end of training (Ctrl-D to exit)')
                embed()
            break  # Extra pass at end: just report val stats and skip training

        # 3. Train on training set
        if args.shuffletrain:
            train_order = np.random.permutation(train_x.shape[0])
            # Second, independent ordering used for the feature-matching batch
            train_order2 = np.random.permutation(train_x.shape[0])

        tic3()
        for ii in range(train_iters):
            tic2()
            start_idx = ii * minibatch_size
            end_idx = start_idx + minibatch_size

            if args.pairedz:
                np.random.seed(args.seed + ii)
            else:
                np.random.seed()

            batch_z = np.random.uniform(-1, 1,
                                        size=(minibatch_size, args.z_dim))

            if args.shuffletrain:
                batch_x = train_x[sorted(
                    train_order[start_idx:end_idx].tolist())]
                if args.feature_match:
                    assert args.shuffletrain, "feature matching loss requires shuffle train"
                    batch_x2 = train_x[sorted(
                        train_order2[start_idx:end_idx].tolist())]
                if 'input_labels' in model.named_keys():
                    batch_y = train_y[sorted(
                        train_order[start_idx:end_idx].tolist())]
            else:
                batch_x = train_x[start_idx:end_idx]
                if 'input_labels' in model.named_keys():
                    batch_y = train_y[start_idx:end_idx]

            feed_dict = {
                model.input_images: batch_x,
                model.input_noise: batch_z,
                learning_phase(): 1,
            }
            if 'input_labels' in model.named_keys():
                feed_dict.update({model.input_labels: batch_y})
            if 'input_images2' in model.named_keys():
                feed_dict.update({model.input_images2: batch_x2})

            # buddy.train_iter does not change until the end of this
            # iteration, so the logging decision is computed once.
            log_this_iter = do_log_train(buddy.epoch, buddy.train_iter, ii)

            fetch_dict = model.trackable_and_update_dict()

            if args.output and log_this_iter:
                if train_histogram_summaries is not None:
                    fetch_dict.update({
                        'train_histogram_summaries': train_histogram_summaries
                    })
                if train_scalar_summaries is not None:
                    fetch_dict.update(
                        {'train_scalar_summaries': train_scalar_summaries})
                if train_image_summaries is not None:
                    fetch_dict.update(
                        {'train_image_summaries': train_image_summaries})

            with WithTimer('sess.run train iter', quiet=not args.verbose):
                result_train = sess_run_dict(sess, fetch_dict,
                                             feed_dict=feed_dict)

                # Update schedule: one discriminator step, then two
                # generator steps, all on the same batch.
                sess.run(d_train_step, feed_dict=feed_dict)
                sess.run(g_train_step, feed_dict=feed_dict)
                sess.run(g_train_step, feed_dict=feed_dict)

            if log_this_iter:
                buddy.note_weighted_list(
                    batch_x.shape[0],
                    model.trackable_names(),
                    [result_train[k] for k in model.trackable_names()],
                    prefix='train_')
                print('[%5d] [%2d/%2d] train: %s (%.3gs/i)' %
                      (buddy.train_iter, buddy.epoch, args.epochs,
                       buddy.epoch_mean_pretty_re(
                           '^train_', style=train_style), toc2()))

            if args.output and log_this_iter:
                if train_histogram_summaries is not None:
                    hist_summary_str = result_train[
                        'train_histogram_summaries']
                    writer.add_summary(hist_summary_str, buddy.train_iter)
                if train_scalar_summaries is not None:
                    scalar_summary_str = result_train['train_scalar_summaries']
                    writer.add_summary(scalar_summary_str, buddy.train_iter)
                if train_image_summaries is not None:
                    image_summary_str = result_train['train_image_summaries']
                    writer.add_summary(image_summary_str, buddy.train_iter)
                log_scalars(
                    writer, buddy.train_iter, {
                        'batch_%s' % name: value
                        for name, value in buddy.last_list_re('^train_')
                    },
                    prefix='buddy')

            if ii > 0 and ii % 100 == 0:
                print(
                    ' %d: Average iteration time over last 100 train iters: %.3gs'
                    % (ii, toc3() / 100))
                tic3()

            buddy.inc_train_iter()  # after finished training a mini-batch

        buddy.inc_epoch()  # after finished training whole pass through set

        if args.output and do_log_train(buddy.epoch, buddy.train_iter, 0):
            log_scalars(
                writer, buddy.train_iter, {
                    'mean_%s' % name: value
                    for name, value in buddy.epoch_mean_list_re('^train_')
                },
                prefix='buddy')

    print('\nFinal')
    print('%02d:%d val: %s' %
          (buddy.epoch, buddy.train_iter,
           buddy.epoch_mean_pretty_re('^val_', style=val_style)))
    print('%02d:%d train: %s' %
          (buddy.epoch, buddy.train_iter,
           buddy.epoch_mean_pretty_re('^train_', style=train_style)))

    print('\nfinal_stats epochs %g' % buddy.epoch)
    print('final_stats iters %g' % buddy.train_iter)
    print('final_stats time %g' % buddy.toc())
    for name, value in buddy.epoch_mean_list_all():
        print('final_stats %s %g' % (name, value))

    if args.output:
        writer.close()  # Flush and close
def main():
    """Distributed (Horovod) training of a Direct or RProj model on Imagenet.

    Command-line entry point. The steps are:

    0. Open the train / val HDF5 files (loaded fully into memory only when
       the size estimate used below is under 1e9).
    1. Build the model selected by ``--arch`` (and the projection flags).
    2. Wrap the chosen optimizer in ``hvd.DistributedOptimizer`` and
       broadcast the global variables from rank 0.
    3. Optionally restore a checkpoint.
    4. Collect TensorBoard summaries.
    5. Run the epoch loop: log params, evaluate on val, optionally snapshot
       (worker 0 only), train. Each worker processes its own
       ``worker_minibatch_size`` slice of every minibatch.

    The loop performs one extra pass after the last epoch that only reports
    val stats / snapshots and then breaks before training.

    All prints are written as single-argument ``print(...)`` calls so the
    output is the same under the Python 2 print statement and the Python 3
    print function.

    Raises:
        Exception: if ``args.arch`` / projection type / ``args.opt`` is not
            supported.
    """
    parser = make_standard_parser(
        'Distributed Training of Direct or RProj model on Imagenet',
        arch_choices=arch_choices)

    parser.add_argument('--vsize', type=int, default=100,
                        help='Dimension of intrinsic parmaeter space.')
    parser.add_argument('--minibatch', '--mb', type=int, default=256,
                        help='Size of minibatch.')
    parser.add_argument('--denseproj', action='store_true',
                        help='Use a dense projection.')
    parser.add_argument('--sparseproj', action='store_true',
                        help='Use a sparse projection.')
    parser.add_argument('--fastfoodproj', action='store_true',
                        help='Use a fastfood projection.')

    args = parser.parse_args()

    minibatch_size = args.minibatch
    train_style, val_style = ('', '') if args.nocolor else (
        colorama.Fore.BLUE, colorama.Fore.MAGENTA)

    # Projected architectures need exactly one projection flag; direct
    # architectures must not get any.
    n_proj_specified = sum(
        [args.denseproj, args.sparseproj, args.fastfoodproj])
    if args.arch in arch_choices_projected:
        assert n_proj_specified == 1, 'Arch "%s" requires projection. Specify exactly one of {denseproj, sparseproj, fastfoodproj} options.' % args.arch
    else:
        assert n_proj_specified == 0, 'Arch "%s" does not require projection, so do not specify any of {denseproj, sparseproj, fastfoodproj} options.' % args.arch

    if args.denseproj:
        proj_type = 'dense'
    elif args.sparseproj:
        proj_type = 'sparse'
    else:
        proj_type = 'fastfood'

    # Initialize Horovod
    hvd.init()

    # Floor division: this value is used as a slice offset below, so it must
    # stay an int regardless of the division semantics in effect.
    worker_minibatch_size = minibatch_size // hvd.size()

    # Pin GPU to be used to process local rank (one GPU per process)
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    my_rank = hvd.local_rank()
    print('I am worker  %s' % my_rank)
    config.gpu_options.visible_device_list = str(hvd.local_rank())
    K.set_session(tf.Session(config=config))

    # 0. LOAD DATA
    # The files stay open for the whole run: when the data is too large to
    # copy into memory, batches are sliced straight from the HDF5 datasets.
    train_h5 = h5py.File(args.train_h5, 'r')
    train_x = train_h5['images']
    train_y = train_h5['labels']
    val_h5 = h5py.File(args.val_h5, 'r')
    val_x = val_h5['images']
    val_y = val_h5['labels']

    # load into memory if less than 1 GB
    if train_x.size * 4 + val_x.size * 4 < 1e9:
        train_x, train_y = np.array(train_x), np.array(train_y)
        val_x, val_y = np.array(val_x), np.array(val_y)

    # 1. CREATE MODEL
    extra_feed_dict = {}

    with WithTimer('Make model'):
        if args.arch == 'alexnet_dir':
            shift_in = np.array([104, 117, 123], dtype='float32')
            model = build_alexnet_direct(weight_decay=args.l2,
                                         shift_in=shift_in)
            randmirrors = True
            randcrops = True
            cropsize = (227, 227)
        elif args.arch == 'squeeze_dir':
            model = build_squeezenet_direct(
                weight_decay=args.l2, shift_in=np.array([104, 117, 123]))
            randmirrors = True
            randcrops = True
            cropsize = (224, 224)
        elif args.arch == 'alexnet':
            if proj_type == 'fastfood':
                model = build_alexnet_fastfood(
                    weight_decay=args.l2,
                    shift_in=np.array([104, 117, 123]),
                    vsize=args.vsize)
            else:
                raise Exception('not implemented')
            randmirrors = True
            randcrops = True
            cropsize = (227, 227)
        elif args.arch == 'squeeze':
            if proj_type == 'fastfood':
                model = build_squeezenet_fastfood(
                    weight_decay=args.l2,
                    shift_in=np.array([104, 117, 123]),
                    vsize=args.vsize)
            else:
                raise Exception('not implemented')
            randmirrors = True
            randcrops = True
            cropsize = (224, 224)
        else:
            raise Exception('Unknown network architecture: %s' % args.arch)

    if my_rank == 0:
        print('All model weights:')
        summarize_weights(model.trainable_weights)
        print('Model summary:')
        model.summary()
        model.print_trainable_warnings()

    lr = args.lr

    if args.opt == 'sgd':
        opt = tf.train.MomentumOptimizer(lr, args.mom)
    elif args.opt == 'rmsprop':
        opt = tf.train.RMSPropOptimizer(lr, momentum=args.mom)
    elif args.opt == 'adam':
        opt = tf.train.AdamOptimizer(lr, args.beta1, args.beta2)
    else:
        # Fail here with a clear message instead of a NameError on opt below.
        raise Exception('Unknown optimizer: %s' % args.opt)

    # Add Horovod Distributed Optimizer
    opt = hvd.DistributedOptimizer(opt)

    global_step = tf.contrib.framework.get_or_create_global_step()
    train_step = opt.minimize(model.v.loss, global_step=global_step)

    sess = K.get_session()
    sess.run(hvd.broadcast_global_variables(0))

    # 3. OPTIONALLY SAVE OR LOAD VARIABLES (e.g. model params, model running
    # BN means, optimization momentum, ...) and then finalize initialization
    saver = tf.train.Saver(
        max_to_keep=None) if (args.output or args.load) else None
    if args.load:
        ckptfile, miscfile = args.load.split(':')
        # Restore values directly to graph
        saver.restore(sess, ckptfile)
        # NOTE(security): pickle.load executes arbitrary code from the file;
        # only resume from misc files you produced yourself.
        with gzip.open(miscfile) as ff:
            saved = pickle.load(ff)
            buddy = saved['buddy']
    else:
        buddy = StatsBuddy()
    buddy.tic()  # call if new run OR resumed run

    # 4. SETUP TENSORBOARD LOGGING
    train_histogram_summaries = get_collection_intersection_summary(
        'train_collection', 'orig_histogram')
    train_scalar_summaries = get_collection_intersection_summary(
        'train_collection', 'orig_scalar')
    val_histogram_summaries = get_collection_intersection_summary(
        'val_collection', 'orig_histogram')
    val_scalar_summaries = get_collection_intersection_summary(
        'val_collection', 'orig_scalar')
    param_histogram_summaries = get_collection_intersection_summary(
        'param_collection', 'orig_histogram')

    writer = None
    if args.output:
        mkdir_p(args.output)
        writer = tf.summary.FileWriter(args.output, sess.graph)

    # 5. TRAIN
    # Floor division: both counts are passed to range() and used with %.
    train_iters = (train_y.shape[0] - 1) // minibatch_size
    val_iters = (val_y.shape[0] - 1) // minibatch_size
    impreproc = ImagePreproc()

    if args.ipy:
        print('Embed: before train / val loop (Ctrl-D to continue)')
        embed()

    # How often to log data. These only depend on train_iters, so they are
    # defined once instead of being re-created every epoch.
    def do_log_params(ep, it, ii):
        return True

    def do_log_val(ep, it, ii):
        return True

    def do_log_train(ep, it, ii):
        # Log on powers of two then every epoch
        return (it < train_iters and it & it - 1 == 0
                or it >= train_iters and it % train_iters == 0)

    while buddy.epoch < args.epochs + 1:
        # 0. Log params
        if args.output and do_log_params(
                buddy.epoch, buddy.train_iter,
                0) and param_histogram_summaries is not None:
            params_summary_str, = sess.run([param_histogram_summaries])
            writer.add_summary(params_summary_str, buddy.train_iter)

        # 1. Evaluate val set performance
        if not args.skipval:
            tic2()
            for ii in range(val_iters):
                with WithTimer('(worker %d) val iter %d/%d' %
                               (my_rank, ii, val_iters),
                               quiet=not args.verbose):
                    start_idx = ii * minibatch_size

                    # each worker gets a portion of the minibatch
                    my_start = start_idx + my_rank * worker_minibatch_size
                    my_end = my_start + worker_minibatch_size

                    batch_x = val_x[my_start:my_end]
                    batch_y = val_y[my_start:my_end]

                    if randcrops:
                        batch_x = impreproc.center_crops(batch_x, cropsize)

                    feed_dict = {
                        model.v.input_images: batch_x,
                        model.v.input_labels: batch_y,
                        K.learning_phase(): 0,
                    }
                    feed_dict.update(extra_feed_dict)

                    fetch_dict = model.trackable_dict

                    with WithTimer('(worker %d) sess.run val iter' % my_rank,
                                   quiet=not args.verbose):
                        result_val = sess_run_dict(sess, fetch_dict,
                                                   feed_dict=feed_dict)

                    buddy.note_weighted_list(
                        batch_x.shape[0],
                        model.trackable_names,
                        [result_val[k] for k in model.trackable_names],
                        prefix='val_')

            if args.output and do_log_val(buddy.epoch, buddy.train_iter, 0):
                log_scalars(
                    writer, buddy.train_iter, {
                        'mean_%s' % name: value
                        for name, value in buddy.epoch_mean_list_re('^val_')
                    },
                    prefix='buddy')

            # max(..., 1) avoids a ZeroDivisionError when the val set is
            # smaller than one minibatch (val_iters == 0).
            print(
                '\ntime: %f. after training for %d epochs:\n%3d (worker %d) val: %s (%.3gs/i)'
                % (buddy.toc(), buddy.epoch, buddy.train_iter, my_rank,
                   buddy.epoch_mean_pretty_re('^val_', style=val_style),
                   toc2() / max(val_iters, 1)))

        # 2. Possibly snapshot, possibly quit
        # only worker 0 handles it
        if args.output and args.snapshot_to and args.snapshot_every:
            snap_intermed = (args.snapshot_every > 0
                             and buddy.train_iter % args.snapshot_every == 0)
            snap_end = buddy.epoch == args.epochs
            if snap_intermed or snap_end:
                # Snapshot
                if my_rank == 0:
                    save_path = saver.save(
                        sess, '%s/%s_%04d.ckpt' %
                        (args.output, args.snapshot_to, buddy.epoch))
                    print('snappshotted model to %s' % save_path)
                    with gzip.open(
                            '%s/%s_misc_%04d.pkl.gz' %
                            (args.output, args.snapshot_to, buddy.epoch),
                            'w') as ff:
                        saved = {'buddy': buddy}
                        pickle.dump(saved, ff)

        if buddy.epoch == args.epochs:
            if args.ipy:
                print('Embed: at end of training (Ctrl-D to exit)')
                embed()
            break  # Extra pass at end: just report val stats and skip training

        # 3. Train on training set
        tic3()
        for ii in range(train_iters):
            tic2()

            with WithTimer('(worker %d) train iter %d/%d' %
                           (my_rank, ii, train_iters),
                           quiet=not args.verbose):
                if args.shuffletrain:
                    # Random contiguous window rather than a permutation
                    start_idx = np.random.randint(
                        train_x.shape[0] - minibatch_size)
                else:
                    start_idx = ii * minibatch_size

                # each worker gets a portion of the minibatch
                my_start = start_idx + my_rank * worker_minibatch_size
                my_end = my_start + worker_minibatch_size

                batch_x = train_x[my_start:my_end]
                batch_y = train_y[my_start:my_end]

                if randcrops:
                    batch_x = impreproc.random_crops(batch_x, cropsize,
                                                     randmirrors)

                feed_dict = {
                    model.v.input_images: batch_x,
                    model.v.input_labels: batch_y,
                    K.learning_phase(): 1,
                }
                feed_dict.update(extra_feed_dict)

                # buddy.train_iter does not change until the end of this
                # iteration, so the logging decision is computed once.
                log_this_iter = do_log_train(buddy.epoch, buddy.train_iter,
                                             ii)

                fetch_dict = {'train_step': train_step}
                fetch_dict.update(model.trackable_and_update_dict)

                if args.output and log_this_iter:
                    if param_histogram_summaries is not None:
                        fetch_dict.update({
                            'param_histogram_summaries':
                            param_histogram_summaries
                        })
                    if train_histogram_summaries is not None:
                        fetch_dict.update({
                            'train_histogram_summaries':
                            train_histogram_summaries
                        })
                    if train_scalar_summaries is not None:
                        fetch_dict.update(
                            {'train_scalar_summaries': train_scalar_summaries})

                with WithTimer('(worker %d) sess.run train iter' % my_rank,
                               quiet=not args.verbose):
                    result_train = sess_run_dict(sess, fetch_dict,
                                                 feed_dict=feed_dict)

                buddy.note_weighted_list(
                    batch_x.shape[0],
                    model.trackable_names,
                    [result_train[k] for k in model.trackable_names],
                    prefix='train_')

            if log_this_iter:
                print('%3d (worker %d) train: %s (%.3gs/i)' %
                      (buddy.train_iter, my_rank,
                       buddy.epoch_mean_pretty_re(
                           '^train_', style=train_style), toc2()))
                if args.output:
                    if param_histogram_summaries is not None:
                        hist_summary_str = result_train[
                            'param_histogram_summaries']
                        writer.add_summary(hist_summary_str, buddy.train_iter)
                    if train_histogram_summaries is not None:
                        hist_summary_str = result_train[
                            'train_histogram_summaries']
                        writer.add_summary(hist_summary_str, buddy.train_iter)
                    if train_scalar_summaries is not None:
                        scalar_summary_str = result_train[
                            'train_scalar_summaries']
                        writer.add_summary(scalar_summary_str,
                                           buddy.train_iter)
                    log_scalars(writer, buddy.train_iter, {
                        'batch_%s' % name: value
                        for name, value in buddy.last_list_re('^train_')
                    }, prefix='buddy')

            if ii > 0 and ii % 100 == 0:
                print(
                    ' %d: Average iteration time over last 100 train iters: %.3gs'
                    % (ii, toc3() / 100))
                tic3()

            buddy.inc_train_iter()  # after finished training a mini-batch

        buddy.inc_epoch()  # after finished training whole pass through set

        if args.output and do_log_train(buddy.epoch, buddy.train_iter, 0):
            log_scalars(
                writer, buddy.train_iter, {
                    'mean_%s' % name: value
                    for name, value in buddy.epoch_mean_list_re('^train_')
                },
                prefix='buddy')

    print('\nFinal')
    print('%02d:%d val: %s' %
          (buddy.epoch, buddy.train_iter,
           buddy.epoch_mean_pretty_re('^val_', style=val_style)))
    print('%02d:%d train: %s' %
          (buddy.epoch, buddy.train_iter,
           buddy.epoch_mean_pretty_re('^train_', style=train_style)))

    print('\nfinal_stats epochs %g' % buddy.epoch)
    print('final_stats iters %g' % buddy.train_iter)
    print('final_stats time %g' % buddy.toc())
    for name, value in buddy.epoch_mean_list_all():
        print('final_stats %s %g' % (name, value))

    if args.output:
        writer.close()  # Flush and close

    # Release the HDF5 handles now that no more batches will be read.
    train_h5.close()
    val_h5.close()
def main():
    """Train one of the CoordConv-family models selected by ``--arch``.

    Steps, in order: parse command-line options, load the HDF5 data set,
    build the model and optimizer, optionally restore a checkpoint, set up
    TensorBoard summaries, then run the train / validation loop with
    optional snapshots. A final whole-set evaluation is written by
    ``evaluate_net`` and the ``final_stats`` lines are printed.

    Raises:
        ValueError: if ``--arch`` or ``--opt`` is not handled here.
    """
    parser = make_standard_parser(
        'Coordconv', arch_choices=arch_choices, skip_train=True, skip_val=True)
    # re-add train and val h5s as optional
    parser.add_argument('--data_h5', type=str,
                        default='./data/rectangle_4_uniform.h5',
                        help='data file in hdf5.')
    parser.add_argument('--x_dim', type=int, default=64,
                        help='x dimension of the output image')
    parser.add_argument('--y_dim', type=int, default=64,
                        help='y dimension of the output image')
    parser.add_argument('--lrpolicy', type=str, default='constant',
                        choices=lr_policy_choices, help='LR policy.')
    parser.add_argument('--lrstepratio', type=float, default=.1,
                        help='LR policy step ratio.')
    # NOTE(review): the next two help strings look copy-pasted from
    # --lrstepratio; confirm the intended wording against the LRPolicy
    # classes before changing the user-facing text.
    parser.add_argument('--lrmaxsteps', type=int, default=5,
                        help='LR policy step ratio.')
    parser.add_argument('--lrstepevery', type=int, default=50,
                        help='LR policy step ratio.')
    parser.add_argument('--filter_size', '-fs', type=int, default=3,
                        help='filter size in deconv network')
    parser.add_argument('--channel_mul', '-mul', type=int, default=2,
                        help='Deconv model channel multiplier to make bigger models')
    parser.add_argument('--use_mse_loss', '-mse', action='store_true',
                        help='use mse loss instead of cross entropy')
    parser.add_argument('--use_sigm_loss', '-sig', action='store_true',
                        help='use sigmoid loss instead of cross entropy')
    parser.add_argument('--interm_loss', '-interm', default=None,
                        choices=(None, 'softmax', 'mse'),
                        help='add intermediate loss to end-to-end painter model')
    parser.add_argument('--no_softmax', '-nosfmx', action='store_true',
                        help='Remove softmax sharpening layer in model')
    args = parser.parse_args()

    if args.lrpolicy == 'step':
        lr_policy = LRPolicyStep(args)
    elif args.lrpolicy == 'valstep':
        lr_policy = LRPolicyValStep(args)
    else:
        lr_policy = LRPolicyConstant(args)

    minibatch_size = args.minibatch
    if args.nocolor:
        train_style, val_style = '', ''
    else:
        train_style, val_style = colorama.Fore.BLUE, colorama.Fore.MAGENTA

    sess = setup_session_and_seeds(args.seed, assert_gpu=not args.cpu)

    # 0. Load data
    print('Loading data: {}'.format(args.data_h5))
    # On-the-fly data generation is not implemented in this script; the flag
    # is only forwarded to the 'deconv_classification' model below.
    DATA_GEN_ON_THE_FLY = False
    # Archs that feed both a one-hot map and a rendered image keep the
    # one-hot arrays separately (train_onehot / val_onehot).
    needs_separate_onehot = args.arch in ('coordconv_rendering',
                                          'deconv_bottleneck')

    if args.arch in ['deconv_classification', 'coordconv_classification',
                     'upsample_conv_coords', 'upsample_coordconv_coords']:
        fd = h5py.File(args.data_h5, 'r')
        train_x = np.array(fd['train_locations'], dtype=int)  # shape (2368, 2)
        train_y = np.array(fd['train_onehots'], dtype=float)  # shape (2368, 64, 64, 1)
        val_x = np.array(fd['val_locations'], dtype=float)  # shape (768, 2)
        val_y = np.array(fd['val_onehots'], dtype=float)  # shape (768, 64, 64, 1)
        # number of image channels
        image_c = train_y.shape[-1] if train_y is not None and len(train_y.shape) == 4 else 1
    elif args.arch == 'conv_onehot_image':
        fd = h5py.File(args.data_h5, 'r')
        train_x = np.array(fd['train_onehots'], dtype=int)  # shape (2368, 64, 64, 1)
        train_y = np.array(fd['train_imagegray'], dtype=float) / 255.0  # shape (2368, 64, 64, 1)
        val_x = np.array(fd['val_onehots'], dtype=float)  # shape (768, 64, 64, 1)
        val_y = np.array(fd['val_imagegray'], dtype=float) / 255.0  # shape (768, 64, 64, 1)
        image_c = train_y.shape[-1]
    elif args.arch == 'deconv_rendering':
        fd = h5py.File(args.data_h5, 'r')
        train_x = np.array(fd['train_locations'], dtype=int)  # shape (2368, 2)
        train_y = np.array(fd['train_imagegray'], dtype=float) / 255.0  # shape (2368, 64, 64, 1)
        val_x = np.array(fd['val_locations'], dtype=float)  # shape (768, 2)
        val_y = np.array(fd['val_imagegray'], dtype=float) / 255.0  # shape (768, 64, 64, 1)
        image_c = train_y.shape[-1]
    elif args.arch == 'conv_regressor' or args.arch == 'coordconv_regressor':
        fd = h5py.File(args.data_h5, 'r')
        train_y = np.array(fd['train_normalized_locations'], dtype=float)  # shape (2368, 2)
        train_x = np.array(fd['train_onehots'], dtype=float)  # shape (2368, 64, 64, 1)
        val_y = np.array(fd['val_normalized_locations'], dtype=float)  # shape (768, 2)
        val_x = np.array(fd['val_onehots'], dtype=float)  # shape (768, 64, 64, 1)
        image_c = train_x.shape[-1]
    elif needs_separate_onehot:
        fd = h5py.File(args.data_h5, 'r')
        train_x = np.array(fd['train_locations'], dtype=int)  # shape (2368, 2)
        train_y = np.array(fd['train_imagegray'], dtype=float) / 255.0  # shape (2368, 64, 64, 1)
        val_x = np.array(fd['val_locations'], dtype=float)  # shape (768, 2)
        val_y = np.array(fd['val_imagegray'], dtype=float) / 255.0  # shape (768, 64, 64, 1)
        # add one-hot anyways to track accuracy etc. even if not used in loss
        train_onehot = np.array(fd['train_onehots'], dtype=int)  # shape (2368, 64, 64, 1)
        val_onehot = np.array(fd['val_onehots'], dtype=int)  # shape (768, 64, 64, 1)
        image_c = train_y.shape[-1]
    else:
        # Previously fell through and failed later with a NameError.
        raise ValueError('Unknown network architecture: %s' % args.arch)

    train_size = train_x.shape[0]

    # 1. CREATE MODEL
    input_coords = tf.placeholder(
        shape=(None, 2), dtype='float32',
        name='input_coords')  # cast later in model into float
    input_onehot = tf.placeholder(
        shape=(None, args.x_dim, args.y_dim, 1), dtype='float32',
        name='input_onehot')
    input_images = tf.placeholder(
        shape=(None, args.x_dim, args.y_dim, image_c), dtype='float32',
        name='input_images')

    if args.arch == 'deconv_classification':
        model = DeconvPainter(l2=args.l2, x_dim=args.x_dim, y_dim=args.y_dim,
                              fs=args.filter_size, mul=args.channel_mul,
                              onthefly=DATA_GEN_ON_THE_FLY,
                              use_mse_loss=args.use_mse_loss,
                              use_sigm_loss=args.use_sigm_loss)
        model.a('input_coords', input_coords)
        if DATA_GEN_ON_THE_FLY:
            model([input_coords])
        else:
            model.a('input_onehot', input_onehot)
            model([input_coords, input_onehot])
    elif args.arch == 'conv_regressor':
        regress_type = 'conv_uniform' if 'uniform' in args.data_h5 else 'conv_quarant'
        model = ConvRegressor(l2=args.l2, mul=args.channel_mul,
                              _type=regress_type)
        model.a('input_coords', input_coords)
        model.a('input_onehot', input_onehot)
        # call model on inputs
        model([input_onehot, input_coords])
    elif args.arch == 'coordconv_regressor':
        model = ConvRegressor(l2=args.l2, mul=args.channel_mul,
                              _type='coordconv')
        model.a('input_coords', input_coords)
        model.a('input_onehot', input_onehot)
        # call model on inputs
        model([input_onehot, input_coords])
    elif args.arch == 'conv_onehot_image':
        # version='simple' to hack a 9x9 all-ones filter solution
        model = ConvImagePainter(l2=args.l2, fs=args.filter_size,
                                 mul=args.channel_mul,
                                 use_mse_loss=args.use_mse_loss,
                                 use_sigm_loss=args.use_sigm_loss,
                                 version='working')
        model.a('input_onehot', input_onehot)
        model.a('input_images', input_images)
        # call model on inputs
        model([input_onehot, input_images])
    elif args.arch == 'deconv_rendering':
        model = DeconvPainter(l2=args.l2, x_dim=args.x_dim, y_dim=args.y_dim,
                              fs=args.filter_size, mul=args.channel_mul,
                              onthefly=False,
                              use_mse_loss=args.use_mse_loss,
                              use_sigm_loss=args.use_sigm_loss)
        model.a('input_coords', input_coords)
        model.a('input_images', input_images)
        # call model on inputs
        model([input_coords, input_images])
    elif args.arch == 'coordconv_classification':
        model = CoordConvPainter(
            l2=args.l2, x_dim=args.x_dim, y_dim=args.y_dim,
            include_r=False, mul=args.channel_mul,
            use_mse_loss=args.use_mse_loss,
            use_sigm_loss=args.use_sigm_loss)
        model.a('input_coords', input_coords)
        model.a('input_onehot', input_onehot)
        model([input_coords, input_onehot])
    elif args.arch == 'coordconv_rendering':
        # version='simple' to hack a 9x9 all-ones filter solution
        model = CoordConvImagePainter(
            l2=args.l2, x_dim=args.x_dim, y_dim=args.y_dim,
            include_r=False, mul=args.channel_mul, fs=args.filter_size,
            use_mse_loss=args.use_mse_loss,
            use_sigm_loss=args.use_sigm_loss,
            interm_loss=args.interm_loss,
            no_softmax=args.no_softmax,
            version='working')
        model.a('input_coords', input_coords)
        model.a('input_onehot', input_onehot)
        model.a('input_images', input_images)
        # always input three things to calculate relevant metrics
        model([input_coords, input_onehot, input_images])
    elif args.arch == 'deconv_bottleneck':
        # version='simple' to hack a 9x9 all-ones filter solution
        model = DeconvBottleneckPainter(
            l2=args.l2, x_dim=args.x_dim, y_dim=args.y_dim,
            mul=args.channel_mul, fs=args.filter_size,
            use_mse_loss=args.use_mse_loss,
            use_sigm_loss=args.use_sigm_loss,
            interm_loss=args.interm_loss,
            no_softmax=args.no_softmax,
            version='working')
        model.a('input_coords', input_coords)
        model.a('input_onehot', input_onehot)
        model.a('input_images', input_images)
        # always input three things to calculate relevant metrics
        model([input_coords, input_onehot, input_images])
    elif args.arch == 'upsample_conv_coords' or args.arch == 'upsample_coordconv_coords':
        _coordconv = args.arch == 'upsample_coordconv_coords'
        model = UpsampleConvPainter(
            l2=args.l2, x_dim=args.x_dim, y_dim=args.y_dim,
            mul=args.channel_mul, fs=args.filter_size,
            use_mse_loss=args.use_mse_loss,
            use_sigm_loss=args.use_sigm_loss,
            coordconv=_coordconv)
        model.a('input_coords', input_coords)
        model.a('input_onehot', input_onehot)
        model([input_coords, input_onehot])
    else:
        raise ValueError('Unknown network architecture: %s' % args.arch)

    print('All model weights:')
    summarize_weights(model.trainable_weights)
    print('Another model summary:')
    model.summarize_named(prefix=' ')
    print_trainable_warnings(model)

    # 2. COMPUTE GRADS AND CREATE OPTIMIZER
    # a placeholder for dynamic learning rate
    input_lr = tf.placeholder(tf.float32, shape=[])
    if args.opt == 'sgd':
        opt = tf.train.MomentumOptimizer(input_lr, args.mom)
    elif args.opt == 'rmsprop':
        opt = tf.train.RMSPropOptimizer(input_lr, momentum=args.mom)
    elif args.opt == 'adam':
        opt = tf.train.AdamOptimizer(input_lr, args.beta1, args.beta2)
    else:
        # Previously left `opt` unbound and failed on the next line.
        raise ValueError('Unknown optimizer: %s' % args.opt)

    grads_and_vars = opt.compute_gradients(
        model.loss, model.trainable_weights,
        gate_gradients=tf.train.Optimizer.GATE_GRAPH)
    train_step = opt.apply_gradients(grads_and_vars)
    # added to train_ and param_ collections
    add_grads_and_vars_hist_summaries(grads_and_vars)

    summarize_opt(opt)
    print('LR Policy: %s' % (lr_policy,))

    # Image summaries for whichever named tensors this model exposes.
    # 'center_logits' used to be registered twice; it is listed once here.
    if not args.arch.endswith('regressor'):
        image_summaries_traintest(model.logits)
    if 'input_onehot' in model.named_keys():
        image_summaries_traintest(model.input_onehot)
    if 'input_images' in model.named_keys():
        image_summaries_traintest(model.input_images)
    if 'prob' in model.named_keys():
        image_summaries_traintest(model.prob)
    if 'center_prob' in model.named_keys():
        image_summaries_traintest(model.center_prob)
    if 'center_logits' in model.named_keys():
        image_summaries_traintest(model.center_logits)
    if 'pixelwise_prob' in model.named_keys():
        image_summaries_traintest(model.pixelwise_prob)
    if 'sharpened_logits' in model.named_keys():
        image_summaries_traintest(model.sharpened_logits)

    # 3. OPTIONALLY SAVE OR LOAD VARIABLES (e.g. model params, model running
    # BN means, optimization momentum, ...) and then finalize initialization
    saver = tf.train.Saver(
        max_to_keep=None) if (args.output or args.load) else None
    if args.load:
        ckptfile, miscfile = args.load.split(':')
        # Restore values directly to graph
        saver.restore(sess, ckptfile)
        # NOTE(security): unpickling executes arbitrary code; only pass
        # --load files that come from a trusted source.
        with gzip.open(miscfile) as ff:
            saved = pickle.load(ff)
            buddy = saved['buddy']
    else:
        buddy = StatsBuddy()
    buddy.tic()  # call if new run OR resumed run

    # Initialize any missed vars (e.g. optimization momentum, ... if not
    # loaded from checkpoint)
    uninitialized_vars = tf_get_uninitialized_variables(sess)
    init_missed_vars = tf.variables_initializer(
        uninitialized_vars, 'init_missed_vars')
    sess.run(init_missed_vars)

    # Make sure all variables, which are model variables, have been
    # initialized (e.g. model params and model running BN means)
    tf_assert_all_init(sess)

    # 4. SETUP TENSORBOARD LOGGING with tf.summary.merge
    train_histogram_summaries = get_collection_intersection_summary(
        'train_collection', 'orig_histogram')
    train_scalar_summaries = get_collection_intersection_summary(
        'train_collection', 'orig_scalar')
    test_histogram_summaries = get_collection_intersection_summary(
        'test_collection', 'orig_histogram')
    test_scalar_summaries = get_collection_intersection_summary(
        'test_collection', 'orig_scalar')
    param_histogram_summaries = get_collection_intersection_summary(
        'param_collection', 'orig_histogram')
    train_image_summaries = get_collection_intersection_summary(
        'train_collection', 'orig_image')
    test_image_summaries = get_collection_intersection_summary(
        'test_collection', 'orig_image')

    writer = None
    if args.output:
        mkdir_p(args.output)
        writer = tf.summary.FileWriter(args.output, sess.graph)

    # 5. TRAIN
    # Ceiling division: the last mini-batch may be smaller than the rest.
    train_iters = train_size // minibatch_size + \
        int(train_size % minibatch_size > 0)

    if args.ipy:
        print('Embed: before train / val loop (Ctrl-D to continue)')
        embed()

    # How often to log data. These do not change between epochs, so they are
    # defined once instead of on every pass through the loop.
    def do_log_params(ep, it, ii):
        return True

    def do_log_val(ep, it, ii):
        return True

    def do_log_train(ep, it, ii):
        # Log on powers of two then every epoch
        return (it < train_iters and (it & (it - 1)) == 0
                or it >= train_iters and it % train_iters == 0)

    while buddy.epoch < args.epochs + 1:
        # 0. Log params
        if (args.output
                and do_log_params(buddy.epoch, buddy.train_iter, 0)
                and param_histogram_summaries is not None):
            params_summary_str, = sess.run([param_histogram_summaries])
            writer.add_summary(params_summary_str, buddy.train_iter)

        # 1. Forward test on validation set (whole set in one batch)
        if not args.skipval:
            # Start the timer read by toc2() in the report below, mirroring
            # the tic2()/toc2() pairing used around each training iteration.
            tic2()
            feed_dict = {learning_phase(): 0}
            if 'input_coords' in model.named_keys():
                val_coords = val_y if args.arch.endswith('regressor') else val_x
                feed_dict.update({model.input_coords: val_coords})

            if 'input_onehot' in model.named_keys():
                if not needs_separate_onehot:
                    if args.arch == 'conv_onehot_image' or args.arch.endswith('regressor'):
                        val_onehot = val_x
                    else:
                        val_onehot = val_y
                feed_dict.update({model.input_onehot: val_onehot})

            if 'input_images' in model.named_keys():
                # Every arch that registers 'input_images' loads the gray
                # images as its y arrays; `val_images` was never assigned
                # before and raised NameError here.
                val_images = val_y
                feed_dict.update({model.input_images: val_images})

            log_val = args.output and do_log_val(
                buddy.epoch, buddy.train_iter, 0)

            fetch_dict = model.trackable_dict()
            if log_val:
                if test_image_summaries is not None:
                    fetch_dict.update(
                        {'test_image_summaries': test_image_summaries})
                if test_scalar_summaries is not None:
                    fetch_dict.update(
                        {'test_scalar_summaries': test_scalar_summaries})
                if test_histogram_summaries is not None:
                    fetch_dict.update(
                        {'test_histogram_summaries': test_histogram_summaries})

            with WithTimer('sess.run val iter', quiet=not args.verbose):
                result_val = sess_run_dict(
                    sess, fetch_dict, feed_dict=feed_dict)

            buddy.note_list(
                model.trackable_names(),
                [result_val[k] for k in model.trackable_names()],
                prefix='val_')
            print('[%5d] [%2d/%2d] val: %s (%.3gs/i)' %
                  (buddy.train_iter, buddy.epoch, args.epochs,
                   buddy.epoch_mean_pretty_re('^val_', style=val_style),
                   toc2()))

            if log_val:
                log_scalars(
                    writer, buddy.train_iter,
                    {'mean_%s' % name: value
                     for name, value in buddy.epoch_mean_list_re('^val_')},
                    prefix='val')
                if test_image_summaries is not None:
                    image_summary_str = result_val['test_image_summaries']
                    writer.add_summary(image_summary_str, buddy.train_iter)
                if test_scalar_summaries is not None:
                    scalar_summary_str = result_val['test_scalar_summaries']
                    writer.add_summary(scalar_summary_str, buddy.train_iter)
                if test_histogram_summaries is not None:
                    hist_summary_str = result_val['test_histogram_summaries']
                    writer.add_summary(hist_summary_str, buddy.train_iter)

        # 2. Possibly snapshot, possibly quit
        if args.output and args.snapshot_to and args.snapshot_every:
            snap_intermed = (args.snapshot_every > 0
                             and buddy.train_iter % args.snapshot_every == 0)
            snap_end = lr_policy.train_done(buddy)
            if snap_intermed or snap_end:
                # Snapshot network and buddy
                save_path = saver.save(
                    sess, '%s/%s_%04d.ckpt' %
                    (args.output, args.snapshot_to, buddy.epoch))
                print('snappshotted model to %s' % (save_path,))
                with gzip.open('%s/%s_misc_%04d.pkl.gz' %
                               (args.output, args.snapshot_to, buddy.epoch),
                               'w') as ff:
                    saved = {'buddy': buddy}
                    pickle.dump(saved, ff)
                # Snapshot evaluation data and metrics
                _, _ = evaluate_net(
                    args, buddy, model, train_size, train_x, train_y,
                    val_x, val_y, fd, sess)

        lr = lr_policy.get_lr(buddy)

        if buddy.epoch == args.epochs:
            if args.ipy:
                print('Embed: at end of training (Ctrl-D to exit)')
                embed()
            break  # Extra pass at end: just report val stats and skip training

        print('********* at epoch %d, LR is %g' % (buddy.epoch, lr))

        # 3. Train on training set
        if args.shuffletrain:
            train_order = np.random.permutation(train_size)

        tic3()
        for ii in range(train_iters):
            tic2()
            start_idx = ii * minibatch_size
            end_idx = min(start_idx + minibatch_size, train_size)

            if args.shuffletrain:  # default true
                # Sorted index list: same samples as the shuffled slice, in
                # ascending order.
                batch_idx = sorted(train_order[start_idx:end_idx].tolist())
            else:
                batch_idx = slice(start_idx, end_idx)

            batch_x = train_x[batch_idx]
            if train_y is not None:
                batch_y = train_y[batch_idx]
            if needs_separate_onehot:
                batch_onehot = train_onehot[batch_idx]

            feed_dict = {learning_phase(): 1, input_lr: lr}
            if 'input_coords' in model.named_keys():
                batch_coords = batch_y if args.arch.endswith('regressor') else batch_x
                feed_dict.update({model.input_coords: batch_coords})
            if 'input_onehot' in model.named_keys():
                if not needs_separate_onehot:
                    if args.arch == 'conv_onehot_image' or args.arch.endswith('regressor'):
                        batch_onehot = batch_x
                    else:
                        batch_onehot = batch_y
                feed_dict.update({model.input_onehot: batch_onehot})
            if 'input_images' in model.named_keys():
                # Same fix as in validation: the images are the y arrays;
                # `batch_images` was never assigned before.
                batch_images = batch_y
                feed_dict.update({model.input_images: batch_images})

            # buddy.train_iter only changes at the end of the iteration, so
            # the logging decision can be computed once.
            log_iter = do_log_train(buddy.epoch, buddy.train_iter, ii)

            fetch_dict = model.trackable_and_update_dict()
            fetch_dict.update({'train_step': train_step})
            if args.output and log_iter:
                if train_histogram_summaries is not None:
                    fetch_dict.update(
                        {'train_histogram_summaries': train_histogram_summaries})
                if train_scalar_summaries is not None:
                    fetch_dict.update(
                        {'train_scalar_summaries': train_scalar_summaries})
                if train_image_summaries is not None:
                    fetch_dict.update(
                        {'train_image_summaries': train_image_summaries})

            with WithTimer('sess.run train iter', quiet=not args.verbose):
                result_train = sess_run_dict(
                    sess, fetch_dict, feed_dict=feed_dict)

            buddy.note_weighted_list(
                batch_x.shape[0], model.trackable_names(),
                [result_train[k] for k in model.trackable_names()],
                prefix='train_')

            if log_iter:
                print('[%5d] [%2d/%2d] train: %s (%.3gs/i)' %
                      (buddy.train_iter, buddy.epoch, args.epochs,
                       buddy.epoch_mean_pretty_re('^train_', style=train_style),
                       toc2()))

            if args.output and log_iter:
                if train_histogram_summaries is not None:
                    hist_summary_str = result_train['train_histogram_summaries']
                    writer.add_summary(hist_summary_str, buddy.train_iter)
                if train_scalar_summaries is not None:
                    scalar_summary_str = result_train['train_scalar_summaries']
                    writer.add_summary(scalar_summary_str, buddy.train_iter)
                if train_image_summaries is not None:
                    image_summary_str = result_train['train_image_summaries']
                    writer.add_summary(image_summary_str, buddy.train_iter)
                log_scalars(
                    writer, buddy.train_iter,
                    {'batch_%s' % name: value
                     for name, value in buddy.last_list_re('^train_')},
                    prefix='train')

            if ii > 0 and ii % 100 == 0:
                print(' %d: Average iteration time over last 100 train iters: %.3gs' % (
                    ii, toc3() / 100))
                tic3()

            buddy.inc_train_iter()  # after finished training a mini-batch

        buddy.inc_epoch()  # after finished training whole pass through set

        if args.output and do_log_train(buddy.epoch, buddy.train_iter, 0):
            log_scalars(
                writer, buddy.train_iter,
                {'mean_%s' % name: value
                 for name, value in buddy.epoch_mean_list_re('^train_')},
                prefix='train')

    print('\nFinal')
    print('%02d:%d val: %s' % (buddy.epoch, buddy.train_iter,
                               buddy.epoch_mean_pretty_re('^val_', style=val_style)))
    print('%02d:%d train: %s' % (buddy.epoch, buddy.train_iter,
                                 buddy.epoch_mean_pretty_re('^train_', style=train_style)))

    print('\nEnd of training. Saving evaluation results on whole train and val set.')
    final_tr_metrics, final_va_metrics = evaluate_net(
        args, buddy, model, train_size, train_x, train_y, val_x, val_y, fd,
        sess)

    print('\nFinal evaluation on whole train and val')
    for name, value in final_tr_metrics.items():
        print('final_stats_eval train_%s %g' % (name, value))
    for name, value in final_va_metrics.items():
        print('final_stats_eval val_%s %g' % (name, value))

    print('\nfinal_stats epochs %g' % buddy.epoch)
    print('final_stats iters %g' % buddy.train_iter)
    print('final_stats time %g' % buddy.toc())
    for name, value in buddy.epoch_mean_list_all():
        print('final_stats %s %g' % (name, value))

    if args.output:
        writer.close()  # Flush and close
    # The data file is no longer needed once the final evaluation is done.
    fd.close()
def evaluate_net(args, buddy, model, train_size, train_x, train_y, val_x,
                 val_y, fd, sess, write_x=True, write_y=True):
    """Evaluate ``model`` on the whole train and validation sets.

    The training set is run in mini-batches of ``args.minibatch`` and the
    validation set in a single batch, both with the learning phase set to 0.
    Model inputs are re-read from the open HDF5 file ``fd``; the
    ``train_x``/``train_y``/``val_x``/``val_y`` arrays are only copied into
    the output file.

    When ``args.output`` is set, model outputs are written to
    ``<output>/evaluation_<epoch>.h5`` and the metrics are pickled to
    ``<output>/evaluation_<epoch>_metrics.pkl``; otherwise the metrics are
    printed.

    Args:
        args: parsed options; reads ``minibatch``, ``output``, ``arch``,
            ``x_dim`` and ``y_dim``.
        buddy: stats tracker; ``epoch`` and ``toc()`` are read.
        model: model exposing ``named_keys()``, ``trackable_dict()`` and the
            named tensors fetched below.
        train_size: number of training samples to evaluate.
        train_x, train_y, val_x, val_y: arrays stored in the output file when
            ``write_x`` / ``write_y`` are true.
        fd: open HDF5 file holding the data sets.
        sess: TensorFlow session.
        write_x: store the input arrays in the output file.
        write_y: store the label arrays in the output file.

    Returns:
        ``(final_tr_metrics, final_va_metrics)``: the per-batch training
        metrics combined by ``average_dict_values`` and the validation
        metrics from one run over the whole validation set.
    """
    minibatch_size = args.minibatch
    # Ceiling division: the last mini-batch may be smaller than the rest.
    train_iters = train_size // minibatch_size + \
        int(train_size % minibatch_size > 0)
    is_regressor = args.arch.endswith('regressor')
    named_keys = model.named_keys()

    # 0 even for train set; because it's evaluation
    feed_dict_tr = {learning_phase(): 0}
    feed_dict_va = {learning_phase(): 0}

    ff = None
    if args.output:
        final_fetch = {'logits': model.logits}
        if 'prob' in named_keys:
            final_fetch.update({'prob': model.prob})
        if 'pixelwise_prob' in named_keys:
            final_fetch.update({'pixelwise_prob': model.pixelwise_prob})
        if args.arch == 'coordconv_rendering' or args.arch == 'deconv_bottleneck':
            final_fetch.update({
                'center_logits': model.center_logits,
                'center_prob': model.center_prob,
            })
        ff = h5py.File(
            '%s/evaluation_%04d.h5' % (args.output, buddy.epoch), 'w')

    try:
        if ff is not None:
            # Create the resizable train datasets now and fill them batch by
            # batch. The initial row count must not exceed maxshape, so it is
            # clamped for training sets smaller than one mini-batch.
            first_rows = min(minibatch_size, train_size)
            for kk in final_fetch.keys():
                if is_regressor:
                    ff.create_dataset(kk + '_train', (first_rows, 2),
                                      maxshape=(train_size, 2),
                                      dtype=float, compression='lzf',
                                      chunks=True)
                else:
                    ff.create_dataset(
                        kk + '_train',
                        (first_rows, args.x_dim, args.y_dim, 1),
                        maxshape=(train_size, args.x_dim, args.y_dim, 1),
                        dtype=float, compression='lzf', chunks=True)
            # create dataset and write immediately
            if write_x:
                ff.create_dataset('inputs_val', data=val_x)
                ff.create_dataset('inputs_train', data=train_x)
            if write_y:
                ff.create_dataset('labels_val', data=val_y)
                ff.create_dataset('labels_train', data=train_y)

        final_tr_metrics = None
        for ii in range(train_iters):
            start_idx = ii * minibatch_size
            end_idx = min(start_idx + minibatch_size, train_size)
            first_batch = ii == 0  # validation feeds are filled only once

            if 'input_onehot' in named_keys:
                feed_dict_tr.update({model.input_onehot: np.array(
                    fd['train_onehots'][start_idx:end_idx], dtype=float)})
                if first_batch:
                    feed_dict_va.update({model.input_onehot: np.array(
                        fd['val_onehots'], dtype=float)})

            if 'input_images' in named_keys:
                feed_dict_tr.update({model.input_images: np.array(
                    fd['train_imagegray'][start_idx:end_idx],
                    dtype=float) / 255.0})
                if first_batch:
                    feed_dict_va.update({model.input_images: np.array(
                        fd['val_imagegray'], dtype=float) / 255.0})

            if 'input_coords' in named_keys:
                if is_regressor:
                    train_key, val_key, loc_dtype = (
                        'train_normalized_locations',
                        'val_normalized_locations', 'float32')
                else:
                    train_key, val_key, loc_dtype = (
                        'train_locations', 'val_locations', 'int32')
                feed_dict_tr.update({model.input_coords: np.array(
                    fd[train_key][start_idx:end_idx], dtype=loc_dtype)})
                if first_batch:
                    feed_dict_va.update({model.input_coords: np.array(
                        fd[val_key], dtype=loc_dtype)})

            batch_metrics = sess_run_dict(
                sess, model.trackable_dict(), feed_dict=feed_dict_tr)
            # Batch size, stored alongside the metrics for the averaging step.
            batch_metrics['weights'] = end_idx - start_idx
            if final_tr_metrics is None:
                final_tr_metrics = batch_metrics
            else:
                final_tr_metrics = merge_dict_append(
                    final_tr_metrics, batch_metrics)

            if args.output:
                if first_batch:  # do only once
                    final_va = sess_run_dict(
                        sess, final_fetch, feed_dict=feed_dict_va)
                    for kk in final_fetch.keys():
                        ff.create_dataset(kk + '_val', data=final_va[kk])

                final_tr = sess_run_dict(
                    sess, final_fetch, feed_dict=feed_dict_tr)
                for kk in final_fetch.keys():
                    dset = ff[kk + '_train']
                    if start_idx > 0:
                        # Grow the dataset by this batch before writing it.
                        dset.resize(
                            dset.shape[0] + end_idx - start_idx, axis=0)
                    dset[start_idx:, ...] = final_tr[kk]

        final_va_metrics = sess_run_dict(
            sess, model.trackable_dict(), feed_dict=feed_dict_va)
        final_tr_metrics = average_dict_values(final_tr_metrics)

        if args.output:
            # Pickle data is binary; text mode breaks on Python 3 and can
            # corrupt the stream on platforms that translate newlines.
            with open('%s/evaluation_%04d_metrics.pkl' %
                      (args.output, buddy.epoch), 'wb') as ffmetrics:
                tosave = {'train': final_tr_metrics,
                          'val': final_va_metrics,
                          'time_elapsed': buddy.toc()}
                pickle.dump(tosave, ffmetrics)
        else:
            print('\nEpoch %d evaluation on whole train and val' % buddy.epoch)
            print('Time elapsed: {}'.format(buddy.toc()))
            for name, value in final_tr_metrics.items():
                print('final_stats_eval train_%s %g' % (name, value))
            for name, value in final_va_metrics.items():
                print('final_stats_eval val_%s %g' % (name, value))
    finally:
        # Close the output file even if evaluation fails part-way.
        if ff is not None:
            ff.close()

    return final_tr_metrics, final_va_metrics
def main(): lr_policy_choices = ('constant', 'step', 'valstep') parser = make_standard_parser('Region Proposal Net', arch_choices=arch_choices, skip_train=True, skip_val=True) parser.add_argument( '--num', '-N', type=int, default=2, help='Load the Field-of-MNIST dataset with NUM digits per image.') parser.add_argument('--lrpolicy', type=str, default='constant', choices=lr_policy_choices, help='LR policy.') parser.add_argument('--lrstepratio', type=float, default=.1, help='LR policy step ratio.') parser.add_argument('--lrmaxsteps', type=int, default=5, help='LR policy step ratio.') parser.add_argument('--lrstepevery', type=int, default=50, help='LR policy step ratio.') parser.add_argument('--clip', action='store_true', help='clip predicted and ground truth boxes.') parser.add_argument('--same', action='store_true', help='Use `same` filter instead of `valid` in conv.') parser.add_argument('--showbox', action='store_true', help='show moved box during training.') args = parser.parse_args() if args.lrpolicy == 'step': lr_policy = LRPolicyStep(args) elif args.lrpolicy == 'valstep': lr_policy = LRPolicyValStep(args) else: lr_policy = LRPolicyConstant(args) minibatch_size = 1 train_style, val_style = ('', '') if args.nocolor else (colorama.Fore.BLUE, colorama.Fore.MAGENTA) sess = setup_session_and_seeds(args.seed, assert_gpu=not args.cpu) # 0. 
Load data #train_ims, train_pos, train_class, valid_ims, valid_pos, valid_class, _, _, _ = load_tvt_n_per_field(args.num) #train_ims = train_ims[:5000] # (5000, 64, 64, 1) #train_pos = train_pos[:5000] # (5000, 2, 4) #valid_ims = valid_ims[:1000] #valid_pos = valid_pos[:1000] #_ims, _pos, _class, _, _, _, _, _, _ = load_tvt_n_per_field_centercrop(args.num) ff = h5py.File('data/field_of_mnist_cropped_64x64_5objs.h5', 'r') train_ims = np.array(ff['train_ims']) # (9000, 64, 64, 1) train_pos = np.array( ff['train_pos']) # (9000, 5, 4), parts of boxes may be out of canvas train_class = np.array(ff['train_class']) # (9000, 5) valid_ims = np.array(ff['valid_ims']) # (1000, 64, 64, 1) valid_pos = np.array( ff['valid_pos']) # (1000, 5, 4), parts of boxes may be out of canvas valid_class = np.array(ff['valid_class']) # (1000, 5) ff.close() im_h, im_w, im_c = train_ims.shape[1], train_ims.shape[2], train_ims.shape[ 3] train_size = train_ims.shape[0] val_size = valid_ims.shape[0] print(('Data loaded:\n\timage shape: {}x{}x{}'.format(im_h, im_w, im_c))) print(('\ttrain size: {}\n\ttest size: {}'.format(train_size, val_size))) print(('\tnumber of objects per image: {}'.format(train_pos.shape[1]))) #################### # RPN prameters #################### rpn_params = RPNParams(anchors=np.array([(15, 15), (20, 20), (25, 25), (15, 20), (20, 25), (20, 15), (25, 20), (15, 25), (25, 15)]), rpn_hidden_dim=32, zero_box_conv=False, weight_init_std=0.01, anchor_scale=1.0) bsamp_params = BoxSamplerParams(hi_thresh=0.5, lo_thresh=0.1, sample_size=12) nms_params = NMSParams( nms_thresh=0.8, max_proposals=10, ) # 1. 
CREATE MODEL input_images = tf.placeholder(shape=(None, im_h, im_w, im_c), dtype='float32', name='input_images') input_gtbox = tf.placeholder(shape=(train_pos.shape[1], 4), dtype='float32', name='input_gtbox') if args.arch == 'rpn_sampler': model = RegionProposalSampler(rpn_params, bsamp_params, nms_params, l2=args.l2, im_h=im_h, im_w=im_w, coordconv=False, clip=args.clip, filtersame=args.same) elif args.arch == 'coord_rpn_sampler': model = RegionProposalSampler(rpn_params, bsamp_params, nms_params, l2=args.l2, im_h=im_h, im_w=im_w, coordconv=True, clip=args.clip, filtersame=args.same) else: raise ValueError('Architecture {} unknown'.format(args.arch)) if args.same: anchors = make_anchors_mnist_same( (16, 16), minibatch_size, rpn_params.anchors) # (batch, 16, 16, 4k) input_anchors = tf.placeholder(shape=(16, 16, 4 * rpn_params.num_anchors), dtype='float32', name='input_anchors') else: anchors = make_anchors_mnist((13, 13), minibatch_size, rpn_params.anchors) # (batch, 13, 13, 4k) input_anchors = tf.placeholder(shape=(13, 13, 4 * rpn_params.num_anchors), dtype='float32', name='input_anchors') anchors = anchors[0] model.a('input_images', input_images) model.a('input_anchors', input_anchors) model.a('input_gtbox', input_gtbox) model([input_images, input_anchors, input_gtbox]) print('All model weights:') summarize_weights(model.trainable_weights) #print 'Model summary:' print('Another model summary:') model.summarize_named(prefix=' ') print_trainable_warnings(model) # 2. 
COMPUTE GRADS AND CREATE OPTIMIZER input_lr = tf.placeholder( tf.float32, shape=[]) # a placeholder for dynamic learning rate if args.opt == 'sgd': opt = tf.train.MomentumOptimizer(input_lr, args.mom) elif args.opt == 'rmsprop': opt = tf.train.RMSPropOptimizer(input_lr, momentum=args.mom) elif args.opt == 'adam': opt = tf.train.AdamOptimizer(input_lr, args.beta1, args.beta2) grads_and_vars = opt.compute_gradients( model.loss, model.trainable_weights, gate_gradients=tf.train.Optimizer.GATE_GRAPH) train_step = opt.apply_gradients(grads_and_vars) add_grads_and_vars_hist_summaries( grads_and_vars) # added to train_ and param_ collections summarize_opt(opt) print(('LR Policy:', lr_policy)) # 3. OPTIONALLY SAVE OR LOAD VARIABLES (e.g. model params, model running BN means, optimization momentum, ...) and then finalize initialization saver = tf.train.Saver( max_to_keep=None) if (args.output or args.load) else None if args.load: ckptfile, miscfile = args.load.split(':') # Restore values directly to graph saver.restore(sess, ckptfile) with gzip.open(miscfile) as ff: saved = pickle.load(ff) buddy = saved['buddy'] else: buddy = StatsBuddy() buddy.tic() # call if new run OR resumed run # Check if special layers are initialized right #last_layer_w = [var for var in tf.global_variables() if 'painting_layer/kernel:0' in var.name][0] #last_layer_b = [var for var in tf.global_variables() if 'painting_layer/bias:0' in var.name][0] # Initialize any missed vars (e.g. optimization momentum, ... if not loaded from checkpoint) uninitialized_vars = tf_get_uninitialized_variables(sess) init_missed_vars = tf.variables_initializer(uninitialized_vars, 'init_missed_vars') sess.run(init_missed_vars) tf_assert_all_init(sess) # 4. 
SETUP TENSORBOARD LOGGING with tf.summary.merge train_histogram_summaries = get_collection_intersection_summary( 'train_collection', 'orig_histogram') train_scalar_summaries = get_collection_intersection_summary( 'train_collection', 'orig_scalar') test_histogram_summaries = get_collection_intersection_summary( 'test_collection', 'orig_histogram') test_scalar_summaries = get_collection_intersection_summary( 'test_collection', 'orig_scalar') param_histogram_summaries = get_collection_intersection_summary( 'param_collection', 'orig_histogram') train_image_summaries = get_collection_intersection_summary( 'train_collection', 'orig_image') test_image_summaries = get_collection_intersection_summary( 'test_collection', 'orig_image') writer = None if args.output: mkdir_p(args.output) writer = tf.summary.FileWriter(args.output, sess.graph) # 5. TRAIN train_iters = (train_size) // minibatch_size if not args.skipval: val_iters = (val_size) // minibatch_size if args.output: show_indices = np.random.permutation(val_size)[:9] mkdir_p('{}/figures'.format(args.output)) if args.ipy: print('Embed: before train / val loop (Ctrl-D to continue)') embed() while buddy.epoch < args.epochs + 1: # How often to log data do_log_params = lambda ep, it, ii: True do_log_val = lambda ep, it, ii: True do_log_train = lambda ep, it, ii: ( it < train_iters and it & it - 1 == 0 or it >= train_iters and it % train_iters == 0) # Log on powers of two then every epoch # 0. Log params if args.output and do_log_params( buddy.epoch, buddy.train_iter, 0) and param_histogram_summaries is not None: params_summary_str, = sess.run([param_histogram_summaries]) writer.add_summary(params_summary_str, buddy.train_iter) # 1. 
Forward test on validation set if not args.skipval: for ii in range(val_iters): tic2() start_idx = ii * minibatch_size end_idx = min(start_idx + minibatch_size, val_size) if not end_idx > start_idx: continue feed_dict = { model.input_images: valid_ims[start_idx:end_idx], model.input_anchors: anchors, model.input_gtbox: valid_pos[start_idx:end_idx][0], learning_phase(): 0 } fetch_dict = model.trackable_dict() if args.output and do_log_val(buddy.epoch, buddy.train_iter, 0): if test_image_summaries is not None: fetch_dict.update( {'test_image_summaries': test_image_summaries}) if test_scalar_summaries is not None: fetch_dict.update( {'test_scalar_summaries': test_scalar_summaries}) if test_histogram_summaries is not None: fetch_dict.update({ 'test_histogram_summaries': test_histogram_summaries }) with WithTimer('sess.run val iter', quiet=not args.verbose): result_val = sess_run_dict(sess, fetch_dict, feed_dict=feed_dict) ## DEBUG ## dynamic p_size and n_size, shouldn slightly very every sample #if ii > 0 and ii % 100 == 0: # print 'VALIDATION --- ' # print sess.run(model.p_size, feed_dict=feed_dict) # print sess.run(model.n_size, feed_dict=feed_dict) ## END DEBUG buddy.note_weighted_list( minibatch_size, model.trackable_names(), [result_val[k] for k in model.trackable_names()], prefix='val_') # Done all val set print(('[%5d] [%2d/%2d] val: %s (%.3gs/i)' % (buddy.train_iter, buddy.epoch, args.epochs, buddy.epoch_mean_pretty_re('^val_', style=val_style), toc2()))) if args.output and do_log_val(buddy.epoch, buddy.train_iter, 0): log_scalars( writer, buddy.train_iter, { 'mean_%s' % name: value for name, value in buddy.epoch_mean_list_re('^val_') }, prefix='val') if test_image_summaries is not None: image_summary_str = result_val['test_image_summaries'] writer.add_summary(image_summary_str, buddy.train_iter) if test_scalar_summaries is not None: scalar_summary_str = result_val['test_scalar_summaries'] writer.add_summary(scalar_summary_str, buddy.train_iter) if 
test_histogram_summaries is not None: hist_summary_str = result_val['test_histogram_summaries'] writer.add_summary(hist_summary_str, buddy.train_iter) # show some boxes if args.showbox: #and (valid_losses[epoch]/previous_best < 1- args.thresh): show_indices = [55, 555, 678] for show_idx in show_indices: [pos_box, pos_score, neg_box, neg_score] = sess.run( [ model.pos_box, model.pos_score, model.neg_box, model.neg_score ], feed_dict={ model.input_images: valid_ims[show_idx:show_idx + 1], model.input_anchors: anchors, model.input_gtbox: valid_pos[show_idx], learning_phase(): 0 }) subplot(1, 3, show_indices.index(show_idx) + 1) #plot_boxes_pos_neg(valid_ims[show_idx], valid_pos[show_idx], pos_box, neg_box) plot_pos_boxes(valid_ims[show_idx], valid_pos[show_idx], pos_box, pos_score, showlabel=False) show() if args.output: switch_backend('Agg') plot_fetch_dict = { 'pos_box': model.pos_box, 'pos_score': model.pos_score, 'neg_box': model.neg_box, 'neg_score': model.neg_score, 'nms_boxes': model.nms_boxes, 'nms_scores': model.nms_scores, } #fig1, ax1 = subplots(3,3) # plot train boxes #fig2, ax2 = subplots(3,3) # plot test/nms boxes for cc, show_idx in enumerate(show_indices, 1): feed_dict = { model.input_images: valid_ims[show_idx:show_idx + 1], model.input_anchors: anchors, model.input_gtbox: valid_pos[show_idx], learning_phase(): 0 } result_plots = sess_run_dict(sess, plot_fetch_dict, feed_dict=feed_dict) fig1 = figure(1) subplot(3, 3, cc) plot_boxes_pos_neg(valid_ims[show_idx], valid_pos[show_idx], result_plots['pos_box'], result_plots['neg_box']) fig2 = figure(2) subplot(3, 3, cc) #plot_pos_boxes(valid_ims[show_idx], valid_pos[show_idx], result_plots['nms_boxes'], result_plots['nms_scores'], showlabel=False) # normalize scores between 0 and 5, to be used as line width _score_as_lw = 5 * (result_plots['nms_scores'] - result_plots['nms_scores'].min()) / ( result_plots['nms_scores'].max() - result_plots['nms_scores'].min()) plot_pos_boxes_thickness(valid_ims[show_idx], 
valid_pos[show_idx], result_plots['nms_boxes'], result_plots['nms_scores']) fig1.set_size_inches(10, 10) fig1.savefig('{}/figures/pos_neg_train_box_epoch_{}.png'.format( args.output, buddy.epoch), dpi=100) fig2.set_size_inches(10, 10) fig2.savefig('{}/figures/nms_test_box_epoch_{}.png'.format( args.output, buddy.epoch), dpi=100) # plot test/nms boxes fig, _ = subplots() # 2. Possiby Snapshot, possibly quit if args.output and args.snapshot_to and args.snapshot_every: snap_intermed = args.snapshot_every > 0 and buddy.train_iter % args.snapshot_every == 0 #snap_end = buddy.epoch == args.epochs snap_end = lr_policy.train_done(buddy) if snap_intermed or snap_end: # Snapshot network and buddy save_path = saver.save( sess, '%s/%s_%04d.ckpt' % (args.output, args.snapshot_to, buddy.epoch)) print(('snappshotted model to', save_path)) with gzip.open( '%s/%s_misc_%04d.pkl.gz' % (args.output, args.snapshot_to, buddy.epoch), 'w') as ff: saved = {'buddy': buddy} pickle.dump(saved, ff) lr = lr_policy.get_lr(buddy) if buddy.epoch == args.epochs: if args.ipy: print('Embed: at end of training (Ctrl-D to exit)') embed() break # Extra pass at end: just report val stats and skip training print(('********* at epoch %d, LR is %g' % (buddy.epoch, lr))) # 3. 
Train on training set if args.shuffletrain: train_order = np.random.permutation(train_size) tic3() for ii in range(train_iters): tic2() start_idx = ii * minibatch_size end_idx = min(start_idx + minibatch_size, train_size) if not end_idx > start_idx: continue if args.shuffletrain: # default true batch_ims = train_ims[sorted( train_order[start_idx:end_idx].tolist())] batch_pos = train_pos[sorted( train_order[start_idx:end_idx].tolist())] else: batch_ims = train_ims[start_idx:end_idx] batch_pos = train_pos[start_idx:end_idx] feed_dict = { model.input_images: batch_ims, model.input_anchors: anchors, model.input_gtbox: batch_pos[0], learning_phase(): 1, input_lr: lr } fetch_dict = model.trackable_and_update_dict() fetch_dict.update({'train_step': train_step}) if args.output and do_log_train(buddy.epoch, buddy.train_iter, ii): if train_histogram_summaries is not None: fetch_dict.update({ 'train_histogram_summaries': train_histogram_summaries }) if train_scalar_summaries is not None: fetch_dict.update( {'train_scalar_summaries': train_scalar_summaries}) if train_image_summaries is not None: fetch_dict.update( {'train_image_summaries': train_image_summaries}) with WithTimer('sess.run train iter', quiet=not args.verbose): result_train = sess_run_dict(sess, fetch_dict, feed_dict=feed_dict) buddy.note_weighted_list( minibatch_size, model.trackable_names(), [result_train[k] for k in model.trackable_names()], prefix='train_') if do_log_train(buddy.epoch, buddy.train_iter, ii): print(('[%5d] [%2d/%2d] train: %s (%.3gs/i)' % (buddy.train_iter, buddy.epoch, args.epochs, buddy.epoch_mean_pretty_re( '^train_', style=train_style), toc2()))) if args.output and do_log_train(buddy.epoch, buddy.train_iter, ii): if train_histogram_summaries is not None: hist_summary_str = result_train[ 'train_histogram_summaries'] writer.add_summary(hist_summary_str, buddy.train_iter) if train_scalar_summaries is not None: scalar_summary_str = result_train['train_scalar_summaries'] 
writer.add_summary(scalar_summary_str, buddy.train_iter) if train_image_summaries is not None: image_summary_str = result_train['train_image_summaries'] writer.add_summary(image_summary_str, buddy.train_iter) log_scalars( writer, buddy.train_iter, { 'batch_%s' % name: value for name, value in buddy.last_list_re('^train_') }, prefix='train') if ii > 0 and ii % 100 == 0: print(( ' %d: Average iteration time over last 100 train iters: %.3gs' % (ii, toc3() / 100))) tic3() buddy.inc_train_iter() # after finished training a mini-batch buddy.inc_epoch() # after finished training whole pass through set if args.output and do_log_train(buddy.epoch, buddy.train_iter, 0): log_scalars( writer, buddy.train_iter, { 'mean_%s' % name: value for name, value in buddy.epoch_mean_list_re('^train_') }, prefix='train') print('\nFinal') print(('%02d:%d val: %s' % (buddy.epoch, buddy.train_iter, buddy.epoch_mean_pretty_re('^val_', style=val_style)))) print(('%02d:%d train: %s' % (buddy.epoch, buddy.train_iter, buddy.epoch_mean_pretty_re('^train_', style=train_style)))) print( '\nEnd of training. Saving evaluation results on whole train and val set.' ) if args.output: writer.close() # Flush and close
def main():
    """Run one random-projection training experiment from the command line.

    Steps, in order:
      0. parse options and load the train/val HDF5 image/label datasets,
      1. build the network selected by ``--arch`` (optionally projected
         through a dense / sparse / fastfood basis),
      2. create the optimizer and the gradient/update ops,
      3. optionally restore a checkpoint, then initialize remaining variables
         and normalize the projection basis matrices,
      4. set up TensorBoard summaries,
      5. alternate validation and training passes until ``args.epochs``
         epochs are done, snapshotting as requested, and finally print the
         ``final_stats`` lines.

    The function is written so that it runs unchanged on Python 2 and
    Python 3 (parenthesized single-argument prints, ``range`` and floor
    division for the iteration counts).

    Raises:
        AssertionError: if the projection flags do not match the arch.
        NotImplementedError: if the arch only has a fastfood builder and a
            different projection was requested.
        ValueError: if ``args.opt`` is not one of sgd / rmsprop / adam.
        Exception: for an unknown arch, or a non-fastfood ``alexnet``.
    """
    parser = make_standard_parser('Random Projection Experiments.',
                                  arch_choices=arch_choices)
    parser.add_argument('--vsize', type=int, default=100,
                        help='Dimension of intrinsic parmaeter space.')
    parser.add_argument('--d_rate', '--dr', type=float, default=0.0,
                        help='Dropout rate.')
    parser.add_argument('--depth', type=int, default=2,
                        help='Number of layers in FNN.')
    parser.add_argument('--width', type=int, default=100,
                        help='Width of layers in FNN.')
    parser.add_argument('--minibatch', '--mb', type=int, default=128,
                        help='Size of minibatch.')
    parser.add_argument('--lr_ratio', '--lrr', type=float, default=.1,
                        help='Ratio to decay LR by every LR_EPSTEP epochs.')
    parser.add_argument(
        '--lr_epochs', '--lrep', type=float, default=0,
        help='Decay LR every LR_EPSTEP epochs. 0 to turn off decay.')
    parser.add_argument('--lr_steps', '--lrst', type=float, default=3,
                        help='Max LR steps.')
    parser.add_argument('--c1', type=int, default=6,
                        help='Channels in first conv layer, for LeNet.')
    parser.add_argument('--c2', type=int, default=16,
                        help='Channels in second conv layer, for LeNet.')
    parser.add_argument('--d1', type=int, default=120,
                        help='Channels in first dense layer, for LeNet.')
    parser.add_argument('--d2', type=int, default=84,
                        help='Channels in second dense layer, for LeNet.')
    parser.add_argument('--denseproj', action='store_true',
                        help='Use a dense projection.')
    parser.add_argument('--sparseproj', action='store_true',
                        help='Use a sparse projection.')
    parser.add_argument('--fastfoodproj', action='store_true',
                        help='Use a fastfood projection.')
    parser.add_argument('--partial_data', '--pd', type=float, default=1.0,
                        help='Percentage of dataset.')
    parser.add_argument(
        '--skiptfevents', action='store_true',
        help='Skip writing tf events files even if output is used.')
    args = parser.parse_args()

    # Projected architectures need exactly one projection flag; direct
    # architectures must not be given any.
    n_proj_specified = sum(
        [args.denseproj, args.sparseproj, args.fastfoodproj])
    if args.arch in arch_choices_projected:
        assert n_proj_specified == 1, 'Arch "%s" requires projection. Specify exactly one of {denseproj, sparseproj, fastfoodproj} options.' % args.arch
    else:
        assert n_proj_specified == 0, 'Arch "%s" does not require projection, so do not specify any of {denseproj, sparseproj, fastfoodproj} options.' % args.arch

    # NOTE: 'fastfood' is also the fall-through value when no projection
    # flag was given (direct architectures).
    if args.denseproj:
        proj_type = 'dense'
    elif args.sparseproj:
        proj_type = 'sparse'
    else:
        proj_type = 'fastfood'

    train_style, val_style = ('', '') if args.nocolor else (
        colorama.Fore.BLUE, colorama.Fore.MAGENTA)

    # Get a TF session registered with Keras and set numpy and TF seeds
    sess = setup_session_and_seeds(args.seed)

    # 0. LOAD DATA
    train_h5 = h5py.File(args.train_h5, 'r')
    train_x = train_h5['images']
    train_y = train_h5['labels']
    val_h5 = h5py.File(args.val_h5, 'r')
    val_x = val_h5['images']
    val_y = val_h5['labels']

    train_x = np.array(train_x, dtype='float32')
    val_x = np.array(val_x, dtype='float32')

    # Optionally keep only the leading fraction of each split.
    if args.partial_data < 1.0:
        n_train_ = int(train_y.size * args.partial_data)
        n_test_ = int(val_y.size * args.partial_data)
        train_x = train_x[:n_train_]
        train_y = train_y[:n_train_]
        val_x = val_x[:n_test_]
        val_y = val_y[:n_test_]

    # load into memory if less than 1 GB
    if train_x.size * 4 + val_x.size * 4 < 1e9:
        train_x, train_y = np.array(train_x), np.array(train_y)
        val_x, val_y = np.array(val_x), np.array(val_y)

    # 1. CREATE MODEL
    randmirrors = False
    randcrops = False
    cropsize = None

    def require_fastfood():
        # Several architectures only have a fastfood-projected builder.
        # Fail here with a clear message instead of leaving `model` unbound
        # and crashing later with an unrelated NameError.
        if proj_type != 'fastfood':
            raise NotImplementedError(
                'Arch "%s" is only implemented with the fastfood projection '
                '(requested projection: %s)' % (args.arch, proj_type))

    with WithTimer('Make model'):
        if args.arch == 'mnistfc_dir':
            model = build_model_mnist_fc_dir(weight_decay=args.l2,
                                             depth=args.depth,
                                             width=args.width)
        elif args.arch == 'mnistfc':
            if proj_type == 'fastfood':
                model = build_model_mnist_fc_fastfood(weight_decay=args.l2,
                                                      vsize=args.vsize,
                                                      depth=args.depth,
                                                      width=args.width)
            else:
                model = build_model_mnist_fc(weight_decay=args.l2,
                                             vsize=args.vsize,
                                             depth=args.depth,
                                             width=args.width,
                                             proj_type=proj_type)
        elif args.arch == 'mnistconv':
            model = build_cnn_model_mnist(weight_decay=args.l2,
                                          vsize=args.vsize)
        elif args.arch == 'mnistconv_dir':
            model = build_cnn_model_direct_mnist(weight_decay=args.l2)
        elif args.arch == 'cifarfc_dir':
            model = build_model_cifar_fc_dir(weight_decay=args.l2,
                                             depth=args.depth,
                                             width=args.width)
        elif args.arch == 'cifarfc':
            if proj_type == 'fastfood':
                model = build_model_cifar_fc_fastfood(weight_decay=args.l2,
                                                      vsize=args.vsize,
                                                      depth=args.depth,
                                                      width=args.width)
            else:
                model = build_model_cifar_fc(weight_decay=args.l2,
                                             vsize=args.vsize,
                                             depth=args.depth,
                                             width=args.width,
                                             proj_type=proj_type)
        elif args.arch == 'mnistlenet_dir':
            model = build_LeNet_direct_mnist(weight_decay=args.l2,
                                             c1=args.c1, c2=args.c2,
                                             d1=args.d1, d2=args.d2)
        elif args.arch == 'mnistMLPlenet_dir':
            model = build_MLPLeNet_direct_mnist(weight_decay=args.l2)
        elif args.arch == 'mnistMLPlenet':
            require_fastfood()
            model = build_model_mnist_MLPLeNet_fastfood(
                weight_decay=args.l2, vsize=args.vsize)
        elif args.arch == 'mnistUntiedlenet_dir':
            model = build_UntiedLeNet_direct_mnist(weight_decay=args.l2)
        elif args.arch == 'mnistUntiedlenet':
            require_fastfood()
            model = build_model_mnist_UntiedLeNet_fastfood(
                weight_decay=args.l2, vsize=args.vsize)
        elif args.arch == 'cifarMLPlenet_dir':
            model = build_MLPLeNet_direct_cifar(weight_decay=args.l2)
        elif args.arch == 'cifarMLPlenet':
            require_fastfood()
            model = build_model_cifar_MLPLeNet_fastfood(
                weight_decay=args.l2, vsize=args.vsize)
        elif args.arch == 'cifarUntiedlenet_dir':
            model = build_UntiedLeNet_direct_cifar(weight_decay=args.l2)
        elif args.arch == 'cifarUntiedlenet':
            require_fastfood()
            model = build_model_cifar_UntiedLeNet_fastfood(
                weight_decay=args.l2, vsize=args.vsize)
        elif args.arch == 'mnistlenet':
            if proj_type == 'fastfood':
                model = build_model_mnist_LeNet_fastfood(
                    weight_decay=args.l2, vsize=args.vsize)
            else:
                model = build_LeNet_mnist(weight_decay=args.l2,
                                          vsize=args.vsize,
                                          proj_type=proj_type)
        elif args.arch == 'cifarlenet_dir':
            model = build_LeNet_direct_cifar(weight_decay=args.l2,
                                             d_rate=args.d_rate,
                                             c1=args.c1, c2=args.c2,
                                             d1=args.d1, d2=args.d2)
        elif args.arch == 'cifarlenet':
            if proj_type == 'fastfood':
                model = build_model_cifar_LeNet_fastfood(
                    weight_decay=args.l2, vsize=args.vsize,
                    d_rate=args.d_rate, c1=args.c1, c2=args.c2,
                    d1=args.d1, d2=args.d2)
            else:
                model = build_LeNet_cifar(weight_decay=args.l2,
                                          vsize=args.vsize,
                                          proj_type=proj_type,
                                          d_rate=args.d_rate)
        elif args.arch == 'cifarDenseNet_dir':
            model = build_DenseNet_direct_cifar(weight_decay=args.l2,
                                                depth=25,
                                                nb_dense_block=1,
                                                growth_rate=12)
        elif args.arch == 'cifarDenseNet':
            require_fastfood()
            model = build_DenseNet_cifar_fastfood(weight_decay=args.l2,
                                                  vsize=args.vsize,
                                                  depth=25,
                                                  nb_dense_block=1,
                                                  growth_rate=12)
        elif args.arch == 'alexnet_dir':
            model = build_alexnet_direct(weight_decay=args.l2,
                                         shift_in=np.array([104, 117, 123]))
            args.shuffletrain = False
            randmirrors = True
            randcrops = True
            cropsize = (227, 227)
        elif args.arch == 'squeeze_dir':
            model = build_squeezenet_direct(weight_decay=args.l2,
                                            shift_in=np.array(
                                                [104, 117, 123]))
            args.shuffletrain = False
            randmirrors = True
            randcrops = True
            cropsize = (224, 224)
        elif args.arch == 'alexnet':
            if proj_type == 'fastfood':
                model = build_alexnet_fastfood(weight_decay=args.l2,
                                               shift_in=np.array(
                                                   [104, 117, 123]),
                                               vsize=args.vsize)
            else:
                raise Exception('not implemented')
            args.shuffletrain = False
            randmirrors = True
            randcrops = True
            cropsize = (227, 227)
        else:
            raise Exception('Unknown network architecture: %s' % args.arch)

    print('All model weights:')
    total_params = summarize_weights(model.trainable_weights)
    print('Model summary:')
    model.summary()
    model.print_trainable_warnings()

    # The learning rate is fed each step so it can be decayed over time.
    input_lr = tf.placeholder(tf.float32, shape=[])
    lr_stepper = LRStepper(args.lr, args.lr_ratio, args.lr_epochs,
                           args.lr_steps)

    # 2. COMPUTE GRADS AND CREATE OPTIMIZER
    if args.opt == 'sgd':
        opt = tf.train.MomentumOptimizer(input_lr, args.mom)
    elif args.opt == 'rmsprop':
        opt = tf.train.RMSPropOptimizer(input_lr, momentum=args.mom)
    elif args.opt == 'adam':
        opt = tf.train.AdamOptimizer(input_lr, args.beta1, args.beta2)
    else:
        # Previously fell through and failed below with an unbound `opt`.
        raise ValueError('Unknown optimizer: %s' % args.opt)

    # Optimize w.r.t all trainable params in the model
    grads_and_vars = opt.compute_gradients(
        model.v.loss, model.trainable_weights,
        gate_gradients=tf.train.Optimizer.GATE_GRAPH)
    train_step = opt.apply_gradients(grads_and_vars)
    add_grad_summaries(grads_and_vars)
    summarize_opt(opt)

    # 3. OPTIONALLY SAVE OR LOAD VARIABLES (e.g. model params, model running
    # BN means, optimization momentum, ...) and then finalize initialization
    saver = tf.train.Saver(
        max_to_keep=None) if (args.output or args.load) else None
    if args.load:
        # --load is given as "<checkpoint file>:<misc pickle file>"
        ckptfile, miscfile = args.load.split(':')
        # Restore values directly to graph
        saver.restore(sess, ckptfile)
        # NOTE: unpickling executes arbitrary code; only load snapshot files
        # that come from a trusted source.
        with gzip.open(miscfile) as ff:
            saved = pickle.load(ff)
            buddy = saved['buddy']
    else:
        buddy = StatsBuddy()
    buddy.tic()  # call if new run OR resumed run

    # Initialize any missed vars (e.g. optimization momentum, ... if not
    # loaded from checkpoint)
    uninitialized_vars = tf_get_uninitialized_variables(sess)
    init_missed_vars = tf.variables_initializer(uninitialized_vars,
                                                'init_missed_vars')
    sess.run(init_missed_vars)

    # Print warnings about any TF vs. Keras shape mismatches
    warn_misaligned_shapes(model)

    # Make sure all variables, which are model variables, have been
    # initialized (e.g. model params and model running BN means)
    tf_assert_all_init(sess)

    # 3.5 Normalize the overall basis matrix across the (multiple)
    # unnormalized basis matrices for each layer
    basis_matrices = []
    normalizers = []
    for layer in model.layers:
        # Layers without an offset creator (or without these attributes)
        # contribute nothing.
        try:
            basis_matrices.extend(layer.offset_creator.basis_matrices)
        except AttributeError:
            continue
        try:
            normalizers.extend(layer.offset_creator.basis_matrix_normalizers)
        except AttributeError:
            continue

    if len(basis_matrices) > 0 and not args.load:
        if proj_type == 'sparse':
            # Norm of overall basis matrix rows (num elements in each sum ==
            # total parameters in model)
            bm_row_norms = tf.sqrt(
                tf.add_n([
                    tf.sparse_reduce_sum(tf.square(bm), 1)
                    for bm in basis_matrices
                ]))
            # Assign `normalizer` Variable to these row norms to achieve
            # normalization of the basis matrix in the TF computational graph
            rescale_basis_matrices = [
                tf.assign(var, tf.reshape(bm_row_norms, var.shape))
                for var in normalizers
            ]
            _ = sess.run(rescale_basis_matrices)
        elif proj_type == 'dense':
            bm_sums = [
                tf.reduce_sum(tf.square(bm), 1) for bm in basis_matrices
            ]
            divisor = tf.expand_dims(tf.sqrt(tf.add_n(bm_sums)), 1)
            rescale_basis_matrices = [
                tf.assign(var, var / divisor) for var in basis_matrices
            ]
            _ = sess.run(rescale_basis_matrices)
        else:
            # Unhandled case kept from the original: drop into an interactive
            # shell for inspection, then abort.
            print('\nhere\n')
            embed()
            assert False, 'what to do with fastfood?'

    # 4. SETUP TENSORBOARD LOGGING
    train_histogram_summaries = get_collection_intersection_summary(
        'train_collection', 'orig_histogram')
    train_scalar_summaries = get_collection_intersection_summary(
        'train_collection', 'orig_scalar')
    val_histogram_summaries = get_collection_intersection_summary(
        'val_collection', 'orig_histogram')
    val_scalar_summaries = get_collection_intersection_summary(
        'val_collection', 'orig_scalar')
    param_histogram_summaries = get_collection_intersection_summary(
        'param_collection', 'orig_histogram')

    writer = None
    if args.output:
        mkdir_p(args.output)
        if not args.skiptfevents:
            writer = tf.summary.FileWriter(args.output, sess.graph)

    # 5. TRAIN
    # Ceil-divide the dataset sizes by the minibatch size. Floor division
    # keeps these integers on Python 3 as well as Python 2.
    train_iters = (train_y.shape[0] - 1) // args.minibatch + 1
    val_iters = (val_y.shape[0] - 1) // args.minibatch + 1
    impreproc = ImagePreproc()

    if args.ipy:
        print('Embed: before train / val loop (Ctrl-D to continue)')
        embed()

    fastest_avg_iter_time = 1e9

    # How often to log data. These predicates do not change between epochs,
    # so they are defined once here rather than on every pass of the loop.
    def do_log_params(ep, it, ii):
        return False

    def do_log_val(ep, it, ii):
        return True

    def do_log_train(ep, it, ii):
        # Log on powers of two then every epoch
        return (it < train_iters and it & it - 1 == 0
                or it >= train_iters and it % train_iters == 0)

    while buddy.epoch < args.epochs + 1:
        # 0. Log params
        if args.output and do_log_params(
                buddy.epoch, buddy.train_iter, 0
        ) and param_histogram_summaries is not None and not args.skiptfevents:
            params_summary_str, = sess.run([param_histogram_summaries])
            writer.add_summary(params_summary_str, buddy.train_iter)

        # 1. Evaluate val set performance
        if not args.skipval:
            tic2()
            for ii in range(val_iters):
                start_idx = ii * args.minibatch
                batch_x = val_x[start_idx:start_idx + args.minibatch]
                batch_y = val_y[start_idx:start_idx + args.minibatch]
                if randcrops:
                    batch_x = impreproc.center_crops(batch_x, cropsize)
                feed_dict = {
                    model.v.input_images: batch_x,
                    model.v.input_labels: batch_y,
                    K.learning_phase(): 0,
                }
                fetch_dict = model.trackable_dict
                with WithTimer('sess.run val iter',
                               quiet=not args.verbose):
                    result_val = sess_run_dict(sess, fetch_dict,
                                               feed_dict=feed_dict)

                # Weight by the actual batch size; the last batch may be
                # smaller than args.minibatch.
                buddy.note_weighted_list(
                    batch_x.shape[0], model.trackable_names,
                    [result_val[k] for k in model.trackable_names],
                    prefix='val_')

            if args.output and not args.skiptfevents and do_log_val(
                    buddy.epoch, buddy.train_iter, 0):
                log_scalars(
                    writer, buddy.train_iter, {
                        'mean_%s' % name: value
                        for name, value in buddy.epoch_mean_list_re('^val_')
                    },
                    prefix='buddy')

            print(
                '\ntime: %f. after training for %d epochs:\n%3d val: %s (%.3gs/i)'
                % (buddy.toc(), buddy.epoch, buddy.train_iter,
                   buddy.epoch_mean_pretty_re('^val_', style=val_style),
                   toc2() / val_iters))

        # 2. Possiby Snapshot, possibly quit
        if args.output and args.snapshot_to and args.snapshot_every:
            snap_intermed = (args.snapshot_every > 0 and
                             buddy.train_iter % args.snapshot_every == 0)
            snap_end = buddy.epoch == args.epochs
            if snap_intermed or snap_end:
                # Snapshot
                save_path = saver.save(
                    sess, '%s/%s_%04d.ckpt' %
                    (args.output, args.snapshot_to, buddy.epoch))
                print('snappshotted model to %s' % (save_path,))
                with gzip.open(
                        '%s/%s_misc_%04d.pkl.gz' %
                        (args.output, args.snapshot_to, buddy.epoch),
                        'w') as ff:
                    saved = {'buddy': buddy}
                    pickle.dump(saved, ff)

        if buddy.epoch == args.epochs:
            if args.ipy:
                print('Embed: at end of training (Ctrl-D to exit)')
                embed()
            break  # Extra pass at end: just report val stats and skip training

        # 3. Train on training set
        if args.shuffletrain:
            train_order = np.random.permutation(train_x.shape[0])

        tic3()
        for ii in range(train_iters):
            tic2()
            start_idx = ii * args.minibatch
            end_idx = start_idx + args.minibatch
            if args.shuffletrain:
                batch_x = train_x[train_order[start_idx:end_idx]]
                batch_y = train_y[train_order[start_idx:end_idx]]
            else:
                batch_x = train_x[start_idx:end_idx]
                batch_y = train_y[start_idx:end_idx]
            if randcrops:
                batch_x = impreproc.random_crops(batch_x, cropsize,
                                                 randmirrors)

            feed_dict = {
                model.v.input_images: batch_x,
                model.v.input_labels: batch_y,
                input_lr: lr_stepper.lr(buddy),
                K.learning_phase(): 1,
            }

            fetch_dict = {'train_step': train_step}
            fetch_dict.update(model.trackable_and_update_dict)

            log_this_iter = do_log_train(buddy.epoch, buddy.train_iter, ii)
            # Summaries are only fetched on iterations that will write them.
            write_events = (args.output and not args.skiptfevents
                            and log_this_iter)

            if write_events:
                if param_histogram_summaries is not None:
                    fetch_dict.update({
                        'param_histogram_summaries':
                        param_histogram_summaries
                    })
                if train_histogram_summaries is not None:
                    fetch_dict.update({
                        'train_histogram_summaries':
                        train_histogram_summaries
                    })
                if train_scalar_summaries is not None:
                    fetch_dict.update(
                        {'train_scalar_summaries': train_scalar_summaries})

            with WithTimer('sess.run train iter', quiet=not args.verbose):
                result_train = sess_run_dict(sess, fetch_dict,
                                             feed_dict=feed_dict)

            buddy.note_weighted_list(
                batch_x.shape[0], model.trackable_names,
                [result_train[k] for k in model.trackable_names],
                prefix='train_')

            if log_this_iter:
                print('%3d train: %s (%.3gs/i)' %
                      (buddy.train_iter,
                       buddy.epoch_mean_pretty_re('^train_',
                                                  style=train_style),
                       toc2()))
                if write_events:
                    if param_histogram_summaries is not None:
                        hist_summary_str = result_train[
                            'param_histogram_summaries']
                        writer.add_summary(hist_summary_str,
                                           buddy.train_iter)
                    if train_histogram_summaries is not None:
                        hist_summary_str = result_train[
                            'train_histogram_summaries']
                        writer.add_summary(hist_summary_str,
                                           buddy.train_iter)
                    if train_scalar_summaries is not None:
                        scalar_summary_str = result_train[
                            'train_scalar_summaries']
                        writer.add_summary(scalar_summary_str,
                                           buddy.train_iter)
                    log_scalars(
                        writer, buddy.train_iter, {
                            'batch_%s' % name: value
                            for name, value in buddy.last_list_re('^train_')
                        },
                        prefix='buddy')

            if ii > 0 and ii % 100 == 0:
                avg_iter_time = toc3() / 100
                tic3()
                fastest_avg_iter_time = min(fastest_avg_iter_time,
                                            avg_iter_time)
                print(
                    ' %d: Average iteration time over last 100 train iters: %.3gs'
                    % (ii, avg_iter_time))

            buddy.inc_train_iter()  # after finished training a mini-batch

        buddy.inc_epoch()  # after finished training whole pass through set

        if args.output and not args.skiptfevents and do_log_train(
                buddy.epoch, buddy.train_iter, 0):
            log_scalars(
                writer, buddy.train_iter, {
                    'mean_%s' % name: value
                    for name, value in buddy.epoch_mean_list_re('^train_')
                },
                prefix='buddy')

    print('\nFinal')
    print('%02d:%d val: %s' %
          (buddy.epoch, buddy.train_iter,
           buddy.epoch_mean_pretty_re('^val_', style=val_style)))
    print('%02d:%d train: %s' %
          (buddy.epoch, buddy.train_iter,
           buddy.epoch_mean_pretty_re('^train_', style=train_style)))

    # Machine-greppable summary lines.
    print('\nfinal_stats epochs %g' % buddy.epoch)
    print('final_stats iters %g' % buddy.train_iter)
    print('final_stats time %g' % buddy.toc())
    print('final_stats total_params %g' % total_params)
    print('final_stats fastest_avg_iter_time %g' % fastest_avg_iter_time)
    for name, value in buddy.epoch_mean_list_all():
        print('final_stats %s %g' % (name, value))

    if args.output and not args.skiptfevents:
        writer.close()  # Flush and close
def main():
    """Run the 'Low Rank Basis experiments.' toy training script.

    Parses command-line arguments, builds the toy model via ``build_toy``
    (with or without a dense projection), creates the optimizer selected by
    ``args.opt``, optionally restores a checkpoint, normalizes any basis
    matrices found on the model layers, and then alternates validation /
    snapshot / training passes until ``buddy.epoch`` reaches ``args.epochs``.
    Final statistics are printed as ``final_stats <name> <value>`` lines.

    Raises:
        ValueError: if ``args.opt`` is not one of 'sgd', 'rmsprop', 'adam'.
        AssertionError: if basis matrices exist but the projection type is
            neither 'sparse' nor 'dense'.
    """
    parser = make_standard_parser('Low Rank Basis experiments.',
                                  skip_train=True, skip_val=True, arch_choices=['one'])
    parser.add_argument('--DD', type=int, default=1000,
                        help='Dimension of full parameter space.')
    parser.add_argument('--vsize', type=int, default=100,
                        help='Dimension of intrinsic parameter space.')
    parser.add_argument('--lr_ratio', '--lrr', type=float, default=.5,
                        help='Ratio to decay LR by every LR_EPSTEP epochs.')
    parser.add_argument('--lr_epochs', '--lrep', type=float, default=0,
                        help='Decay LR every LR_EPSTEP epochs. 0 to turn off decay.')
    parser.add_argument('--lr_steps', '--lrst', type=float, default=3,
                        help='Max LR steps.')
    parser.add_argument('--denseproj', action='store_true',
                        help='Use a dense projection.')
    parser.add_argument('--skiptfevents', action='store_true',
                        help='Skip writing tf events files even if output is used.')

    args = parser.parse_args()

    # Only 'dense' or None is ever selected here; the 'sparse' branch further
    # down is therefore not reachable from this entry point.
    proj_type = 'dense' if args.denseproj else None

    train_style, val_style = ('', '') if args.nocolor else (colorama.Fore.BLUE, colorama.Fore.MAGENTA)

    # Get a TF session registered with Keras and set numpy and TF seeds
    sess = setup_session_and_seeds(args.seed)

    # 1. CREATE MODEL
    with WithTimer('Make model'):
        if args.denseproj:
            model = build_toy(weight_decay=args.l2, DD=args.DD, groups=10, vsize=args.vsize, proj=True)
        else:
            model = build_toy(weight_decay=args.l2, DD=args.DD, proj=False)

        print('All model weights:')
        total_params = summarize_weights(model.trainable_weights)
        print('Model summary:')
        model.summary()
        model.print_trainable_warnings()

    # The learning rate is fed at every step so that LRStepper can change it.
    input_lr = tf.placeholder(tf.float32, shape=[])
    lr_stepper = LRStepper(args.lr, args.lr_ratio, args.lr_epochs, args.lr_steps)

    # 2. COMPUTE GRADS AND CREATE OPTIMIZER
    if args.opt == 'sgd':
        opt = tf.train.MomentumOptimizer(input_lr, args.mom)
    elif args.opt == 'rmsprop':
        opt = tf.train.RMSPropOptimizer(input_lr, momentum=args.mom)
    elif args.opt == 'adam':
        opt = tf.train.AdamOptimizer(input_lr, args.beta1, args.beta2)
    else:
        # Fail here with a clear message instead of a NameError on `opt` below.
        raise ValueError('Unknown optimizer: %s' % args.opt)

    # Optimize w.r.t all trainable params in the model
    grads_and_vars = opt.compute_gradients(model.v.loss, model.trainable_weights,
                                           gate_gradients=tf.train.Optimizer.GATE_GRAPH)
    train_step = opt.apply_gradients(grads_and_vars)
    add_grad_summaries(grads_and_vars)
    summarize_opt(opt)

    # 3. OPTIONALLY SAVE OR LOAD VARIABLES (e.g. model params, model running BN means,
    # optimization momentum, ...) and then finalize initialization
    saver = tf.train.Saver(max_to_keep=None) if (args.output or args.load) else None
    if args.load:
        ckptfile, miscfile = args.load.split(':')
        # Restore values directly to graph
        saver.restore(sess, ckptfile)
        # NOTE(security): pickle.load executes arbitrary code from the file;
        # only load misc files produced by a trusted run of this script.
        with gzip.open(miscfile) as ff:
            saved = pickle.load(ff)
            buddy = saved['buddy']
    else:
        buddy = StatsBuddy()
    buddy.tic()    # call if new run OR resumed run

    # Initialize any missed vars (e.g. optimization momentum, ... if not loaded from checkpoint)
    uninitialized_vars = tf_get_uninitialized_variables(sess)
    init_missed_vars = tf.variables_initializer(uninitialized_vars, 'init_missed_vars')
    sess.run(init_missed_vars)

    # Print warnings about any TF vs. Keras shape mismatches
    warn_misaligned_shapes(model)
    # Make sure all variables, which are model variables, have been initialized
    # (e.g. model params and model running BN means)
    tf_assert_all_init(sess)

    # 3.5 Normalize the overall basis matrix across the (multiple) unnormalized
    # basis matrices for each layer
    basis_matrices = []
    normalizers = []

    for layer in model.layers:
        # Layers without an offset_creator (or without these attributes) are skipped.
        try:
            basis_matrices.extend(layer.offset_creator.basis_matrices)
        except AttributeError:
            continue
        try:
            normalizers.extend(layer.offset_creator.basis_matrix_normalizers)
        except AttributeError:
            continue

    # Skip normalization when resuming: restored values are used as loaded.
    if len(basis_matrices) > 0 and not args.load:
        if proj_type == 'sparse':
            # Norm of overall basis matrix rows (num elements in each sum == total parameters in model)
            bm_row_norms = tf.sqrt(tf.add_n([tf.sparse_reduce_sum(tf.square(bm), 1) for bm in basis_matrices]))
            # Assign `normalizer` Variable to these row norms to achieve normalization
            # of the basis matrix in the TF computational graph
            rescale_basis_matrices = [tf.assign(var, tf.reshape(bm_row_norms, var.shape)) for var in normalizers]
            _ = sess.run(rescale_basis_matrices)
        elif proj_type == 'dense':
            bm_sums = [tf.reduce_sum(tf.square(bm), 1) for bm in basis_matrices]
            divisor = tf.expand_dims(tf.sqrt(tf.add_n(bm_sums)), 1)
            rescale_basis_matrices = [tf.assign(var, var / divisor) for var in basis_matrices]
            sess.run(rescale_basis_matrices)
        else:
            print('\nhere\n')
            embed()
            # Explicit raise (not `assert`) so the check survives `python -O`.
            raise AssertionError('what to do with fastfood?')

    # 4. SETUP TENSORBOARD LOGGING
    train_histogram_summaries = get_collection_intersection_summary('train_collection', 'orig_histogram')
    train_scalar_summaries = get_collection_intersection_summary('train_collection', 'orig_scalar')
    # NOTE(review): the two val summaries are never read below; they are kept in
    # case get_collection_intersection_summary has graph side effects -- confirm.
    val_histogram_summaries = get_collection_intersection_summary('val_collection', 'orig_histogram')
    val_scalar_summaries = get_collection_intersection_summary('val_collection', 'orig_scalar')
    param_histogram_summaries = get_collection_intersection_summary('param_collection', 'orig_histogram')

    writer = None
    if args.output:
        mkdir_p(args.output)
        if not args.skiptfevents:
            writer = tf.summary.FileWriter(args.output, sess.graph)

    # 5. TRAIN
    train_iters = 1
    val_iters = 1

    if args.ipy:
        print('Embed: before train / val loop (Ctrl-D to continue)')
        embed()

    fastest_avg_iter_time = 1e9

    # How often to log data. These predicates do not depend on loop state other
    # than their arguments, so they are defined once, outside the epoch loop.
    def do_log_params(ep, it, ii):
        return False

    def do_log_val(ep, it, ii):
        return True

    def do_log_train(ep, it, ii):
        # Log on powers of two then every epoch
        return (it < train_iters and it & it - 1 == 0
                or it >= train_iters and it % train_iters == 0)

    while buddy.epoch < args.epochs + 1:
        # 0. Log params
        if (args.output and do_log_params(buddy.epoch, buddy.train_iter, 0)
                and param_histogram_summaries is not None and not args.skiptfevents):
            params_summary_str, = sess.run([param_histogram_summaries])
            writer.add_summary(params_summary_str, buddy.train_iter)

        # 1. Evaluate val set performance
        if not args.skipval:
            tic2()
            for ii in range(val_iters):
                with WithTimer('val iter %d/%d'%(ii, val_iters), quiet=not args.verbose):
                    feed_dict = {
                        K.learning_phase(): 0,
                    }
                    fetch_dict = model.trackable_dict
                    with WithTimer('sess.run val iter', quiet=not args.verbose):
                        result_val = sess_run_dict(sess, fetch_dict, feed_dict=feed_dict)

                    buddy.note_weighted_list(1, model.trackable_names,
                                             [result_val[k] for k in model.trackable_names],
                                             prefix='val_')

            if args.output and not args.skiptfevents and do_log_val(buddy.epoch, buddy.train_iter, 0):
                log_scalars(writer, buddy.train_iter,
                            {'mean_%s' % name: value for name, value in buddy.epoch_mean_list_re('^val_')},
                            prefix='buddy')

            print('\ntime: %f. after training for %d epochs:\n%3d val: %s (%.3gs/i)'
                  % (buddy.toc(), buddy.epoch, buddy.train_iter,
                     buddy.epoch_mean_pretty_re('^val_', style=val_style), toc2() / val_iters))

        # 2. Possibly snapshot, possibly quit
        if args.output and args.snapshot_to and args.snapshot_every:
            snap_intermed = args.snapshot_every > 0 and buddy.train_iter % args.snapshot_every == 0
            snap_end = buddy.epoch == args.epochs
            if snap_intermed or snap_end:
                # Snapshot
                save_path = saver.save(sess, '%s/%s_%04d.ckpt' % (args.output, args.snapshot_to, buddy.epoch))
                print('snappshotted model to %s' % save_path)
                with gzip.open('%s/%s_misc_%04d.pkl.gz' % (args.output, args.snapshot_to, buddy.epoch), 'w') as ff:
                    saved = {'buddy': buddy}
                    pickle.dump(saved, ff)

        if buddy.epoch == args.epochs:
            if args.ipy:
                print('Embed: at end of training (Ctrl-D to exit)')
                embed()
            break   # Extra pass at end: just report val stats and skip training

        # 3. Train on training set
        tic3()
        for ii in range(train_iters):
            with WithTimer('train iter %d/%d'%(ii, train_iters), quiet=not args.verbose):
                tic2()
                feed_dict = {
                    input_lr: lr_stepper.lr(buddy),
                    K.learning_phase(): 1,
                }

                fetch_dict = {'train_step': train_step}
                fetch_dict.update(model.trackable_and_update_dict)

                # Only fetch the summaries on iterations where they will be written.
                if args.output and not args.skiptfevents and do_log_train(buddy.epoch, buddy.train_iter, ii):
                    if param_histogram_summaries is not None:
                        fetch_dict.update({'param_histogram_summaries': param_histogram_summaries})
                    if train_histogram_summaries is not None:
                        fetch_dict.update({'train_histogram_summaries': train_histogram_summaries})
                    if train_scalar_summaries is not None:
                        fetch_dict.update({'train_scalar_summaries': train_scalar_summaries})

                with WithTimer('sess.run train iter', quiet=not args.verbose):
                    result_train = sess_run_dict(sess, fetch_dict, feed_dict=feed_dict)

                buddy.note_weighted_list(1, model.trackable_names,
                                         [result_train[k] for k in model.trackable_names],
                                         prefix='train_')

                if do_log_train(buddy.epoch, buddy.train_iter, ii):
                    print('%3d train: %s (%.3gs/i)'
                          % (buddy.train_iter,
                             buddy.epoch_mean_pretty_re('^train_', style=train_style), toc2()))
                    if args.output and not args.skiptfevents:
                        if param_histogram_summaries is not None:
                            hist_summary_str = result_train['param_histogram_summaries']
                            writer.add_summary(hist_summary_str, buddy.train_iter)
                        if train_histogram_summaries is not None:
                            hist_summary_str = result_train['train_histogram_summaries']
                            writer.add_summary(hist_summary_str, buddy.train_iter)
                        if train_scalar_summaries is not None:
                            scalar_summary_str = result_train['train_scalar_summaries']
                            writer.add_summary(scalar_summary_str, buddy.train_iter)
                        log_scalars(writer, buddy.train_iter,
                                    {'batch_%s' % name: value for name, value in buddy.last_list_re('^train_')},
                                    prefix='buddy')
                        log_scalars(writer, buddy.train_iter,
                                    {'batch_lr': lr_stepper.lr(buddy)},
                                    prefix='buddy')

                if ii > 0 and ii % 100 == 0:
                    avg_iter_time = toc3() / 100
                    tic3()
                    fastest_avg_iter_time = min(fastest_avg_iter_time, avg_iter_time)
                    print(' %d: Average iteration time over last 100 train iters: %.3gs' % (ii, avg_iter_time))

                buddy.inc_train_iter()   # after finished training a mini-batch

        buddy.inc_epoch()   # after finished training whole pass through set

        if args.output and not args.skiptfevents and do_log_train(buddy.epoch, buddy.train_iter, 0):
            log_scalars(writer, buddy.train_iter,
                        {'mean_%s' % name: value for name, value in buddy.epoch_mean_list_re('^train_')},
                        prefix='buddy')

    print('\nFinal')
    print('%02d:%d val: %s' % (buddy.epoch, buddy.train_iter,
                               buddy.epoch_mean_pretty_re('^val_', style=val_style)))
    print('%02d:%d train: %s' % (buddy.epoch, buddy.train_iter,
                                 buddy.epoch_mean_pretty_re('^train_', style=train_style)))

    print('\nfinal_stats epochs %g' % buddy.epoch)
    print('final_stats iters %g' % buddy.train_iter)
    print('final_stats time %g' % buddy.toc())
    print('final_stats total_params %g' % total_params)
    print('final_stats fastest_avg_iter_time %g' % fastest_avg_iter_time)
    for name, value in buddy.epoch_mean_list_all():
        print('final_stats %s %g' % (name, value))

    if args.output and not args.skiptfevents:
        writer.close()   # Flush and close