def test_loop(
        config,
        sess,
        summary_op,
        summary_writer,
        saver,
        restore_saver,
        directories,
        test_dict,
        exp_label,
        num_params,
        log,
        map_out='test_maps',
        num_batches=None,
        placeholders=False,
        checkpoint=None,
        save_weights=False,
        save_checkpoints=False,
        save_activities=False,
        save_gradients=False):
    """Run the model test loop."""
    if checkpoint is not None:
        restore_saver.restore(sess, checkpoint)
        print('Restored checkpoint %s' % checkpoint)
    if placeholders:
        # Feed numpy data through placeholders, one full batch at a time.
        test_images = placeholders['test']['images']
        test_labels = placeholders['test']['labels']
        test_batches = len(test_images) // config.test_batch_size
        test_batch_idx = np.arange(test_batches).reshape(-1, 1).repeat(
            config.test_batch_size)
        # Truncate to an integer number of batches.
        test_images = test_images[:len(test_batch_idx)]
        test_labels = test_labels[:len(test_batch_idx)]
        assert len(test_labels), 'Test labels not found.'
        assert len(test_images), 'Test images not found.'

        # Check that labels are an appropriate shape
        tf_label_shape = test_dict['test_labels'].get_shape().as_list()
        np_label_shape = test_labels.shape
        if len(tf_label_shape) == 2 and len(np_label_shape) == 1:
            test_labels = test_labels[..., None]
        elif len(tf_label_shape) == len(np_label_shape):
            pass
        elif len(tf_label_shape) == 4 and np.all(
                np.array(tf_label_shape)[1:3] == np_label_shape[1:3]):
            pass
        else:
            raise RuntimeError(
                'Mismatch label shape np: %s vs. tf: %s' % (
                    np_label_shape, tf_label_shape))

        # Loop through all the images
        if num_batches is not None:
            config.validation_steps = num_batches
        else:
            config.validation_steps = test_batches
        test_score, test_lo, it_test_dict, duration = validation_step(
            sequential=True,
            sess=sess,
            val_dict=test_dict,
            config=config,
            log=log,
            dict_image_key='test_images',
            dict_label_key='test_labels',
            eval_score_key='test_score',
            eval_loss_key='test_loss',
            map_im_key='test_proc_images',
            map_log_key='test_logits',
            map_lab_key='test_proc_labels',
            val_images=test_images,
            val_labels=test_labels,
            val_batch_idx=test_batch_idx,
            val_batches=test_batches)
        if hasattr(config, 'get_map') and config.get_map and map_out is not None:
            maps, arands = tf_fun.calculate_map(
                it_test_dict, exp_label, config, map_dir=map_out)
            return {
                'scores': test_score,
                'losses': test_lo,
                'maps': maps,
                'arands': arands,
                'exp_label': exp_label,
                'test_dict': it_test_dict,
                'duration': duration
            }
        else:
            return {
                'scores': test_score,
                'losses': test_lo,
                'exp_label': exp_label,
                'test_dict': it_test_dict,
                'duration': duration
            }
    else:
        # Data comes from input queues; run a single sequential pass.
        test_score, test_lo, it_test_dict, duration = validation_step(
            sess=sess,
            val_dict=test_dict,
            config=config,
            log=log,
            dict_image_key='test_images',
            dict_label_key='test_labels',
            eval_score_key='test_score',
            eval_loss_key='test_loss',
            map_im_key='test_proc_images',
            map_log_key='test_logits',
            map_lab_key='test_proc_labels')
        return {
            'scores': test_score,
            'losses': test_lo,
            'exp_label': exp_label,
            'test_dict': it_test_dict,
            'duration': duration
        }
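# A minimal standalone sketch of the label-shape reconciliation performed in
# both test_loop and training_loop before numpy labels are fed through
# placeholders. The helper name `reconcile_label_shape` is illustrative and
# not part of this module.
def reconcile_label_shape(np_labels, tf_label_shape):
    """Match numpy label rank to a placeholder's static shape.

    Mirrors the checks above: add a trailing singleton axis when the graph
    expects [N, 1] but labels arrive as [N]; pass equal-rank labels through;
    accept 4-D placeholders whose spatial dims agree; otherwise fail loudly.

    Example: reconcile_label_shape(np.zeros(8), [None, 1]).shape == (8, 1).
    """
    np_label_shape = np_labels.shape
    if len(tf_label_shape) == 2 and len(np_label_shape) == 1:
        return np_labels[..., None]
    elif len(tf_label_shape) == len(np_label_shape):
        return np_labels
    elif len(tf_label_shape) == 4 and np.all(
            np.array(tf_label_shape)[1:3] == np.array(np_label_shape[1:3])):
        return np_labels
    raise RuntimeError(
        'Mismatch label shape np: %s vs. tf: %s' % (
            np_label_shape, tf_label_shape))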
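# Both test_loop and training_loop build their placeholder batch schedules
# with the same trick: np.arange(batches).reshape(-1, 1).repeat(batch_size)
# expands to one batch id per sample, and truncating the data arrays to its
# length drops the remainder that cannot fill a whole batch. A standalone
# sketch with toy data; the demo function name is illustrative only.
def _batch_schedule_demo():
    """Show the batch-id indexing used by the placeholder code paths."""
    batch_size = 4
    images = np.arange(10)  # stand-in for 10 samples; only 2 full batches fit
    num_batches = len(images) // batch_size
    batch_idx = np.arange(num_batches).reshape(-1, 1).repeat(batch_size)
    # batch_idx is [0 0 0 0 1 1 1 1]; truncation drops the 2 trailing samples.
    images = images[:len(batch_idx)]
    for batch in range(num_batches):
        print(images[batch_idx == batch])  # [0 1 2 3], then [4 5 6 7]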
def training_loop(
        config,
        coord,
        sess,
        summary_op,
        summary_writer,
        saver,
        restore_saver,
        threads,
        directories,
        train_dict,
        val_dict,
        exp_label,
        num_params,
        use_db,
        log,
        placeholders=False,
        checkpoint=None,
        save_weights=False,
        save_checkpoints=False,
        save_activities=False,
        save_gradients=False):
    """Run the model training loop."""
    if checkpoint is not None:
        restore_saver.restore(sess, checkpoint)
        print('Restored checkpoint %s' % checkpoint)
    if not hasattr(config, 'early_stop'):
        config.early_stop = np.inf
    val_perf = np.asarray([np.inf])
    step = 0
    best_val_dict = None
    if save_weights:
        try:
            weight_dict = {v.name: v for v in tf.trainable_variables()}
            val_dict = dict(val_dict, **weight_dict)
        except Exception:
            raise RuntimeError('Failed to find weights to save.')
    else:
        weight_dict = None
    it_early_stop = config.early_stop
    if hasattr(config, 'adaptive_train'):
        # Optionally end training once train loss falls to this threshold.
        adaptive_train = config.adaptive_train
    else:
        adaptive_train = False
    if placeholders:
        train_images = placeholders['train']['images']
        val_images = placeholders['val']['images']
        train_labels = placeholders['train']['labels']
        val_labels = placeholders['val']['labels']
        train_batches = len(train_images) // config.train_batch_size
        train_batch_idx = np.arange(train_batches).reshape(-1, 1).repeat(
            config.train_batch_size)
        train_images = train_images[:len(train_batch_idx)]
        train_labels = train_labels[:len(train_batch_idx)]
        val_batches = len(val_images) // config.val_batch_size
        val_batch_idx = np.arange(val_batches).reshape(-1, 1).repeat(
            config.val_batch_size)
        val_images = val_images[:len(val_batch_idx)]
        val_labels = val_labels[:len(val_batch_idx)]

        # Check that labels are an appropriate shape
        tf_label_shape = train_dict['train_labels'].get_shape().as_list()
        np_label_shape = train_labels.shape
        if len(tf_label_shape) == 2 and len(np_label_shape) == 1:
            train_labels = train_labels[..., None]
            val_labels = val_labels[..., None]
        elif len(tf_label_shape) == len(np_label_shape):
            pass
        else:
            raise RuntimeError(
                'Mismatch label shape np: %s vs. tf: %s' % (
                    np_label_shape, tf_label_shape))

        # Start training
        train_losses = []
        train_logits = []
        for epoch in tqdm(
                range(config.epochs), desc='Epoch', total=config.epochs):
            for train_batch in range(train_batches):
                io_start_time = time.time()
                data_idx = train_batch_idx == train_batch
                it_train_images = train_images[data_idx]
                it_train_labels = train_labels[data_idx]
                if isinstance(it_train_images[0], basestring):
                    # Images are paths; load them on the fly.
                    it_train_images = np.asarray([
                        data_to_tfrecords.load_image(im)
                        for im in it_train_images])
                feed_dict = {
                    train_dict['train_images']: it_train_images,
                    train_dict['train_labels']: it_train_labels
                }
                (train_score,
                 train_loss,
                 it_train_dict,
                 timer) = training_step(
                    sess=sess,
                    train_dict=train_dict,
                    config=config,
                    feed_dict=feed_dict)
                train_losses.append(train_loss)
                if step % config.validation_period == 0:
                    val_score, val_lo, it_val_dict, duration = validation_step(
                        sess=sess,
                        val_dict=val_dict,
                        config=config,
                        log=log,
                        val_images=val_images,
                        val_labels=val_labels,
                        val_batch_idx=val_batch_idx,
                        val_batches=val_batches)

                    # Save progress and important data
                    try:
                        val_check = np.where(val_lo < val_perf)[0]
                        if not len(val_check):
                            it_early_stop -= 1
                            print('Deducted from early stop count.')
                        else:
                            it_early_stop = config.early_stop
                            best_val_dict = it_val_dict
                            print('Reset early stop count.')
                        if it_early_stop <= 0:
                            print('Early stop triggered. Ending early.')
                            print('Best validation loss: %s' % np.min(
                                val_perf))
                            return
                        save_progress(
                            config=config,
                            val_check=val_check,
                            weight_dict=weight_dict,
                            it_val_dict=it_val_dict,
                            exp_label=exp_label,
                            step=step,
                            directories=directories,
                            sess=sess,
                            saver=saver,
                            val_score=val_score,
                            val_loss=val_lo,
                            train_score=train_score,
                            train_loss=train_loss,
                            timer=duration,
                            num_params=num_params,
                            log=log,
                            summary_op=summary_op,
                            summary_writer=summary_writer,
                            save_activities=save_activities,
                            save_gradients=save_gradients,
                            save_checkpoints=save_checkpoints)
                    except Exception as e:
                        log.info('Failed to save checkpoint: %s' % e)

                    # Hack to get the visualizations; clean this up later.
                    if 'BSDS500_test_orientation_viz' in config.experiment:
                        train_logits.append(
                            [it_train_dict['train_logits'].ravel()])
                        out_dir = 'circuits_{}'.format(config.out_dir)
                        py_utils.make_dir(out_dir)
                        out_target = os.path.join(
                            out_dir,
                            '{}_{}'.format(config.model, config.train_dataset))
                        np.save(
                            '{}_optim'.format(out_target),
                            [sess.run(tf.trainable_variables())])
                        np.save('{}_perf'.format(out_target), train_losses)
                        np.save('{}_curves'.format(out_target), train_logits)
                        np.save(
                            '{}_label'.format(out_target),
                            it_train_dict['train_labels'])

                    # Training status and validation accuracy
                    val_status(
                        log=log,
                        dt=datetime.now(),
                        step=step,
                        train_loss=train_loss,
                        rate=config.val_batch_size / duration,
                        timer=float(duration),
                        score_function=config.score_function,
                        train_score=train_score,
                        val_score=val_score,
                        val_loss=val_lo,
                        best_val_loss=np.min(val_perf),
                        summary_dir=directories['summaries'])
                else:
                    # Training status; use the step timer, since no
                    # validation ran this iteration.
                    io_duration = time.time() - io_start_time
                    train_status(
                        log=log,
                        dt=datetime.now(),
                        step=step,
                        train_loss=train_loss,
                        rate=config.train_batch_size / timer,
                        timer=float(timer),
                        io_timer=float(io_duration),
                        lr=it_train_dict['lr'],
                        score_function=config.score_function,
                        train_score=train_score)

                # End iteration
                val_perf = np.concatenate([val_perf, [val_lo]])
                step += 1

                # Adaptive ending
                if adaptive_train and train_loss <= adaptive_train:
                    break
            if adaptive_train and train_loss <= adaptive_train:
                break
    else:
        try:
            while not coord.should_stop():
                (train_score,
                 train_loss,
                 it_train_dict,
                 duration) = training_step(
                    sess=sess,
                    config=config,
                    train_dict=train_dict)
                io_start_time = time.time()
                if step % config.validation_period == 0:
                    val_score, val_lo, it_val_dict, duration = validation_step(
                        sess=sess,
                        val_dict=val_dict,
                        config=config,
                        log=log)

                    # Save progress and important data
                    try:
                        val_check = np.where(val_lo < val_perf)[0]
                        if not len(val_check):
                            it_early_stop -= 1
                            print('Deducted from early stop count (%s/%s).' % (
                                it_early_stop, config.early_stop))
                        else:
                            it_early_stop = config.early_stop
                            best_val_dict = it_val_dict
                            print('Reset early stop count (%s/%s).' % (
                                it_early_stop, config.early_stop))
                        if it_early_stop <= 0:
                            print('Early stop triggered. Ending early.')
                            print('Best validation loss: %s' % np.min(
                                val_perf))
                            break
                        val_perf = save_progress(
                            config=config,
                            val_check=val_check,
                            weight_dict=weight_dict,
                            it_val_dict=it_val_dict,
                            exp_label=exp_label,
                            step=step,
                            directories=directories,
                            sess=sess,
                            saver=saver,
                            val_score=val_score,
                            val_loss=val_lo,
                            val_perf=val_perf,
                            train_score=train_score,
                            train_loss=train_loss,
                            timer=duration,
                            num_params=num_params,
                            log=log,
                            use_db=use_db,
                            summary_op=summary_op,
                            summary_writer=summary_writer,
                            save_activities=save_activities,
                            save_gradients=save_gradients,
                            save_checkpoints=save_checkpoints)
                    except Exception as e:
                        log.info('Failed to save checkpoint: %s' % e)

                    # Training status and validation accuracy
                    val_status(
                        log=log,
                        dt=datetime.now(),
                        step=step,
                        train_loss=train_loss,
                        rate=config.val_batch_size / duration,
                        timer=float(duration),
                        score_function=config.score_function,
                        train_score=train_score,
                        val_score=val_score,
                        val_loss=val_lo,
                        best_val_loss=np.min(val_perf),
                        summary_dir=directories['summaries'])
                else:
                    # Training status
                    io_duration = time.time() - io_start_time
                    train_status(
                        log=log,
                        dt=datetime.now(),
                        step=step,
                        train_loss=train_loss,
                        rate=config.val_batch_size / duration,
                        timer=float(duration),
                        io_timer=float(io_duration),
                        lr=it_train_dict['lr'],
                        score_function=config.score_function,
                        train_score=train_score)

                # End iteration
                step += 1
        except tf.errors.OutOfRangeError:
            log.info('Done training for %d epochs, %d steps.' % (
                config.epochs, step))
            log.info('Saved to: %s' % directories['checkpoints'])
        finally:
            coord.request_stop()
            coord.join(threads)
            sess.close()
    try:
        print('Best %s loss: %s' % (config.val_loss_function, val_perf[0]))
    except Exception:
        # config may not define val_loss_function; report the loss alone.
        print('Best loss: %s' % val_perf[0])
    if hasattr(config, 'get_map') and config.get_map:
        tf_fun.calculate_map(best_val_dict, exp_label, config)
    return
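# The early stopping in training_loop is a patience counter: a validation
# loss that fails to improve decrements the counter, an improvement resets
# it, and training ends when it reaches zero. A simplified sketch of that
# pattern; it compares against the best loss seen, whereas the loop above
# checks new losses against the full val_perf history. The EarlyStopper name
# is illustrative and not part of this module.
class EarlyStopper(object):
    """Patience counter mirroring the it_early_stop bookkeeping above."""

    def __init__(self, patience):
        self.patience = patience
        self.counter = patience
        self.best_loss = np.inf

    def update(self, val_loss):
        """Record a validation loss; return True when training should stop."""
        if val_loss < self.best_loss:
            self.best_loss = val_loss
            self.counter = self.patience  # reset on improvement
        else:
            self.counter -= 1  # deduct on a non-improving validation
        return self.counter <= 0


# Usage sketch: stops after `patience` consecutive non-improving validations.
# stopper = EarlyStopper(patience=5)
# for val_loss in validation_losses:
#     if stopper.update(val_loss):
#         break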