def validation_step(
        sess, val_dict, config, log, val_images=False, val_labels=False,
        val_batch_idx=False, val_batches=False):
    """Run validation as the average of config.num_validation_evals batches."""
    it_val_acc = np.asarray([])
    it_val_loss = np.asarray([])
    start_time = time.time()
    if val_batch_idx is not False:
        shuff_val_batch_idx = val_batch_idx[
            np.random.permutation(len(val_batch_idx))]
    for num_vals in range(config.num_validation_evals):
        # Validation accuracy as the average of n batches
        if val_images is not False:
            it_idx = shuff_val_batch_idx == num_vals
            it_ims = val_images[it_idx]
            it_labs = val_labels[it_idx]
            if isinstance(it_labs[0], basestring):
                it_labs = np.asarray(
                    [data_to_tfrecords.load_image(im) for im in it_labs])
            feed_dict = {
                val_dict['val_images']: it_ims,
                val_dict['val_labels']: it_labs
            }
            it_val_dict = sess.run(val_dict, feed_dict=feed_dict)
        else:
            it_val_dict = sess.run(val_dict)
        it_val_acc = np.append(it_val_acc, it_val_dict['val_accuracy'])
        it_val_loss = np.append(it_val_loss, it_val_dict['val_loss'])
    val_acc = it_val_acc.mean()
    val_lo = it_val_loss.mean()
    duration = time.time() - start_time
    return val_acc, val_lo, it_val_dict, duration
def validation_step(
        sess, val_dict, config, log, sequential=False,
        dict_image_key='val_images', dict_label_key='val_labels',
        eval_score_key='val_score', eval_loss_key='val_loss',
        map_im_key='val_images', map_log_key='val_logits',
        map_lab_key='val_labels', debug_tilts=True, val_images=False,
        val_labels=False, val_batch_idx=None, val_batches=False):
    """Run validation as the average of config.validation_steps batches."""
    it_val_score = np.asarray([])
    it_val_loss = np.asarray([])
    start_time = time.time()
    if debug_tilts:
        jacs = []
        tilt_meta = np.load(
            '/media/data_cifs/tilt_illusion/test/metadata/filtered_te.npy')
    save_val_dicts = (hasattr(config, 'get_map') and config.get_map) or \
        (hasattr(config, 'all_results') and config.all_results)
    if save_val_dicts:
        it_val_dicts = []
    for idx in range(config.validation_steps):
        # Validation accuracy as the average of n batches
        file_paths = ''  # Init empty full-path image info for tfrecords
        if val_batch_idx is not None:
            if not sequential:
                it_val_batch_idx = val_batch_idx[
                    np.random.permutation(len(val_batch_idx))]
                val_step = np.random.randint(low=0, high=val_batch_idx.max())
            else:
                it_val_batch_idx = val_batch_idx
                val_step = idx
            it_idx = it_val_batch_idx == val_step
            it_ims = val_images[it_idx]
            # if debug_tilts and it_ims[0].split(os.path.sep)[-1] in tilt_meta[:, 1]:
            #     jidx = np.where(it_ims[0].split(os.path.sep)[-1] == tilt_meta[:, 1])[0][0]
            #     pim = data_to_tfrecords.load_image(it_ims[0])
            #     rot = rotate(pim, float(tilt_meta[jidx][4]))
            #     fig = plt.figure()
            #     plt.subplot(121)
            #     plt.imshow(pim)
            #     plt.subplot(122)
            #     plt.imshow(rot)
            #     plt.show()
            #     plt.close(fig)
            it_labs = val_labels[it_idx]
            if isinstance(it_ims[0], basestring):
                file_paths = np.copy(it_ims)
                it_ims = np.asarray(
                    [data_to_tfrecords.load_image(im) for im in it_ims])
            if isinstance(it_labs[0], basestring):
                it_labs = np.asarray(
                    [data_to_tfrecords.load_image(im) for im in it_labs])
            # if debug_tilts and val_images[it_idx][0].split(os.path.sep)[-1] in tilt_meta[:, 1]:
            #     jidx = np.where(val_images[it_idx][0].split(os.path.sep)[-1] == tilt_meta[:, 1])[0][0]
            #     it_ims[0] = rotate(it_ims[0], float(tilt_meta[jidx][4]))
            feed_dict = {
                val_dict[dict_image_key]: it_ims,
                val_dict[dict_label_key]: it_labs
            }
            it_val_dict = sess.run(val_dict, feed_dict=feed_dict)
        else:
            it_val_dict = sess.run(val_dict)
        try:
            if debug_tilts and val_images[it_idx][0].split(
                    os.path.sep)[-1] in tilt_meta[:, 1]:
                jidx = np.where(
                    val_images[it_idx][0].split(os.path.sep)[-1] ==
                    tilt_meta[:, 1])[0][0]
                # jacs.append(rotate(it_val_dict['test_jacobian'].squeeze(), float(tilt_meta[jidx][4])))
                jacs.append(it_val_dict['test_jacobian'].squeeze())
        except Exception:
            import ipdb
            ipdb.set_trace()
        if hasattr(config, 'plot_recurrence') and config.plot_recurrence:
            norms = tf_fun.visualize_recurrence(
                idx=idx, image=it_ims, label=it_labs,
                logits=it_val_dict['test_logits'],
                h2s=it_val_dict['h2_list'], ff=it_val_dict['ff'],
                config=config, debug=False)
        else:
            norms = -1

        # Patch for accuracy... TODO: Fix the TF op
        # if config.score_function == 'accuracy':
        #     preds = np.argsort(it_val_dict['val_logits'], -1)[:, -1]
        #     it_val_score = np.mean(
        #         preds == it_val_dict['val_labels'].astype(float))
        it_val_score = np.append(it_val_score, it_val_dict[eval_score_key])
        it_val_loss = np.append(it_val_loss, it_val_dict[eval_loss_key])
        if save_val_dicts:
            if map_im_key not in it_val_dict:
                map_im_key = map_im_key.replace('proc_', '')
                map_log_key = map_log_key.replace('proc_', '')
                map_lab_key = map_lab_key.replace('proc_', '')
            trim_dict = {
                'images': it_val_dict[map_im_key],  # im_key='val_images'
                'logits': it_val_dict[map_log_key],  # log_key='val_logits'
                'labels': it_val_dict[map_lab_key],  # lab_key='val_labels'
                'image_paths': file_paths,
                'norms': norms
            }
            it_val_dicts += [trim_dict]
    val_score = it_val_score.mean()
    val_lo = it_val_loss.mean()
    duration = time.time() - start_time
    # import ipdb; ipdb.set_trace()  # Debug leftover; disabled so validation can return
    if debug_tilts:
        jacs = np.array(jacs)
    if save_val_dicts:
        it_val_dict = it_val_dicts
    return val_score, val_lo, it_val_dict, duration
def training_loop(
        config, coord, sess, summary_op, summary_writer, saver,
        restore_saver, threads, directories, train_dict, val_dict,
        exp_label, num_params, use_db, log, placeholders=False,
        checkpoint=None, save_weights=False, save_checkpoints=False,
        save_activities=False, save_gradients=False):
    """Run the model training loop."""
    if checkpoint is not None:
        restore_saver.restore(sess, checkpoint)
        print 'Restored checkpoint %s' % checkpoint
    if not hasattr(config, 'early_stop'):
        config.early_stop = np.inf
    val_perf = np.asarray([np.inf])
    step = 0
    best_val_dict = None
    if save_weights:
        try:
            weight_dict = {v.name: v for v in tf.trainable_variables()}
            val_dict = dict(val_dict, **weight_dict)
        except Exception:
            raise RuntimeError('Failed to find weights to save.')
    else:
        weight_dict = None
    if hasattr(config, 'early_stop'):
        it_early_stop = config.early_stop
    else:
        it_early_stop = np.inf
    if placeholders:
        train_images = placeholders['train']['images']
        val_images = placeholders['val']['images']
        train_labels = placeholders['train']['labels']
        val_labels = placeholders['val']['labels']
        train_batches = len(train_images) / config.train_batch_size
        train_batch_idx = np.arange(train_batches).reshape(-1, 1).repeat(
            config.train_batch_size)
        train_images = train_images[:len(train_batch_idx)]
        train_labels = train_labels[:len(train_batch_idx)]
        val_batches = len(val_images) / config.val_batch_size
        val_batch_idx = np.arange(val_batches).reshape(-1, 1).repeat(
            config.val_batch_size)
        val_images = val_images[:len(val_batch_idx)]
        val_labels = val_labels[:len(val_batch_idx)]

        # Check that labels are appropriate shape
        tf_label_shape = train_dict['train_labels'].get_shape().as_list()
        np_label_shape = train_labels.shape
        if len(tf_label_shape) == 2 and len(np_label_shape) == 1:
            train_labels = train_labels[..., None]
            val_labels = val_labels[..., None]
        elif len(tf_label_shape) == len(np_label_shape):
            pass
        else:
            raise RuntimeError(
                'Mismatch label shape np: %s vs. tf: %s' % (
                    np_label_shape, tf_label_shape))

        # Start training
        for epoch in tqdm(
                range(config.epochs), desc='Epoch', total=config.epochs):
            for train_batch in range(train_batches):
                data_idx = train_batch_idx == train_batch
                it_train_images = train_images[data_idx]
                it_train_labels = train_labels[data_idx]
                if isinstance(it_train_images[0], basestring):
                    it_train_images = np.asarray(
                        [data_to_tfrecords.load_image(im)
                         for im in it_train_images])
                feed_dict = {
                    train_dict['train_images']: it_train_images,
                    train_dict['train_labels']: it_train_labels
                }
                (train_score, train_loss, it_train_dict, timer) = training_step(
                    sess=sess, train_dict=train_dict, feed_dict=feed_dict)
                if step % config.validation_period == 0:
                    val_score, val_lo, it_val_dict, duration = validation_step(
                        sess=sess, val_dict=val_dict, config=config, log=log,
                        val_images=val_images, val_labels=val_labels,
                        val_batch_idx=val_batch_idx, val_batches=val_batches)

                    # Save progress and important data
                    try:
                        val_check = np.where(val_lo < val_perf)[0]
                        if not len(val_check):
                            it_early_stop -= 1
                            print 'Deducted from early stop count.'
                        else:
                            it_early_stop = config.early_stop
                            best_val_dict = it_val_dict
                            print 'Reset early stop count.'
                        if it_early_stop <= 0:
                            print 'Early stop triggered. Ending early.'
                            print 'Best validation loss: %s' % np.min(val_perf)
                            return
                        save_progress(
                            config=config, val_check=val_check,
                            weight_dict=weight_dict, it_val_dict=it_val_dict,
                            exp_label=exp_label, step=step,
                            directories=directories, sess=sess, saver=saver,
                            val_score=val_score, val_loss=val_lo,
                            train_score=train_score, train_loss=train_loss,
                            timer=duration, num_params=num_params, log=log,
                            summary_op=summary_op,
                            summary_writer=summary_writer,
                            save_activities=save_activities,
                            save_gradients=save_gradients,
                            save_checkpoints=save_checkpoints)
                    except Exception as e:
                        log.info('Failed to save checkpoint: %s' % e)

                    # Training status and validation accuracy
                    val_status(
                        log=log, dt=datetime.now(), step=step,
                        train_loss=train_loss,
                        rate=config.val_batch_size / duration,
                        timer=float(duration),
                        score_function=config.score_function,
                        train_score=train_score, val_score=val_score,
                        val_loss=val_lo, best_val_loss=np.min(val_perf),
                        summary_dir=directories['summaries'])
                else:
                    # Training status
                    train_status(
                        log=log, dt=datetime.now(), step=step,
                        train_loss=train_loss,
                        rate=config.val_batch_size / duration,
                        timer=float(duration), lr=it_train_dict['lr'],
                        score_function=config.score_function,
                        train_score=train_score)

                # End iteration
                val_perf = np.concatenate([val_perf, [val_lo]])
                step += 1
    else:
        try:
            while not coord.should_stop():
                (train_score, train_loss, it_train_dict, duration) = training_step(
                    sess=sess, config=config, train_dict=train_dict)
                io_start_time = time.time()
                if step % config.validation_period == 0:
                    val_score, val_lo, it_val_dict, duration = validation_step(
                        sess=sess, val_dict=val_dict, config=config, log=log)

                    # Save progress and important data
                    try:
                        val_check = np.where(val_lo < val_perf)[0]
                        if not len(val_check):
                            it_early_stop -= 1
                            print('Deducted from early stop count (%s/%s).' % (
                                it_early_stop, config.early_stop))
                        else:
                            it_early_stop = config.early_stop
                            best_val_dict = it_val_dict
                            print('Reset early stop count (%s/%s).' % (
                                it_early_stop, config.early_stop))
                        if it_early_stop <= 0:
                            print 'Early stop triggered. Ending early.'
                            print 'Best validation loss: %s' % np.min(val_perf)
                            break
                        val_perf = save_progress(
                            config=config, val_check=val_check,
                            weight_dict=weight_dict, it_val_dict=it_val_dict,
                            exp_label=exp_label, step=step,
                            directories=directories, sess=sess, saver=saver,
                            val_score=val_score, val_loss=val_lo,
                            val_perf=val_perf, train_score=train_score,
                            train_loss=train_loss, timer=duration,
                            num_params=num_params, log=log, use_db=use_db,
                            summary_op=summary_op,
                            summary_writer=summary_writer,
                            save_activities=save_activities,
                            save_gradients=save_gradients,
                            save_checkpoints=save_checkpoints)
                    except Exception as e:
                        log.info('Failed to save checkpoint: %s' % e)

                    # Training status and validation accuracy
                    val_status(
                        log=log, dt=datetime.now(), step=step,
                        train_loss=train_loss,
                        rate=config.val_batch_size / duration,
                        timer=float(duration),
                        score_function=config.score_function,
                        train_score=train_score, val_score=val_score,
                        val_loss=val_lo, best_val_loss=np.min(val_perf),
                        summary_dir=directories['summaries'])
                else:
                    # Training status
                    io_duration = time.time() - io_start_time
                    train_status(
                        log=log, dt=datetime.now(), step=step,
                        train_loss=train_loss,
                        rate=config.val_batch_size / duration,
                        timer=float(duration), io_timer=float(io_duration),
                        lr=it_train_dict['lr'],
                        score_function=config.score_function,
                        train_score=train_score)

                # End iteration
                step += 1
        except tf.errors.OutOfRangeError:
            log.info('Done training for %d epochs, %d steps.' % (
                config.epochs, step))
            log.info('Saved to: %s' % directories['checkpoints'])
        finally:
            coord.request_stop()
            coord.join(threads)
            sess.close()
    try:
        print 'Best %s loss: %s' % (config.val_loss_function, val_perf[0])
        if hasattr(config, 'get_map') and config.get_map:
            tf_fun.calculate_map(best_val_dict, exp_label, config)
    except Exception:
        print 'Best loss: %s' % val_perf[0]
    return
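
def _batch_index_example():
    """Illustrative sketch, never called by the loops above.

    The placeholder paths in these training loops assign a batch id to every
    sample with np.arange(n_batches).reshape(-1, 1).repeat(batch_size), then
    truncate the data to a whole number of batches. This helper only
    demonstrates that indexing pattern; the names and values here are
    hypothetical and not part of the original code.
    """
    batch_size = 3
    images = np.arange(8)  # 8 samples -> 2 full batches; 2 samples dropped
    n_batches = len(images) // batch_size
    batch_idx = np.arange(n_batches).reshape(-1, 1).repeat(batch_size)
    # batch_idx == array([0, 0, 0, 1, 1, 1])
    images = images[:len(batch_idx)]  # drop the ragged tail
    first_batch = images[batch_idx == 0]  # -> array([0, 1, 2])
    return first_batch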
def training_loop(
        config, coord, sess, summary_op, summary_writer, saver,
        restore_saver, threads, directories, train_dict, val_dict,
        exp_label, num_params, use_db, log, placeholders=False,
        checkpoint=None, save_weights=False, save_checkpoints=False,
        save_activities=False, save_gradients=False):
    """Run the model training loop."""
    if checkpoint is not None:
        restore_saver.restore(sess, checkpoint)
        print 'Restored checkpoint %s' % checkpoint
    if not hasattr(config, 'early_stop'):
        config.early_stop = np.inf
    val_perf = np.asarray([np.inf])
    step = 0
    best_val_dict = None
    if save_weights:
        try:
            weight_dict = {v.name: v for v in tf.trainable_variables()}
            val_dict = dict(val_dict, **weight_dict)
        except Exception:
            raise RuntimeError('Failed to find weights to save.')
    else:
        weight_dict = None
    if hasattr(config, 'early_stop'):
        it_early_stop = config.early_stop
    else:
        it_early_stop = np.inf
    if hasattr(config, "adaptive_train"):
        adaptive_train = config.adaptive_train
    else:
        adaptive_train = False
    if placeholders:
        train_images = placeholders['train']['images']
        val_images = placeholders['val']['images']
        train_labels = placeholders['train']['labels']
        val_labels = placeholders['val']['labels']
        train_batches = len(train_images) / config.train_batch_size
        train_batch_idx = np.arange(train_batches).reshape(-1, 1).repeat(
            config.train_batch_size)
        train_images = train_images[:len(train_batch_idx)]
        train_labels = train_labels[:len(train_batch_idx)]
        val_batches = len(val_images) / config.val_batch_size
        val_batch_idx = np.arange(val_batches).reshape(-1, 1).repeat(
            config.val_batch_size)
        val_images = val_images[:len(val_batch_idx)]
        val_labels = val_labels[:len(val_batch_idx)]

        # Check that labels are appropriate shape
        tf_label_shape = train_dict['train_labels'].get_shape().as_list()
        np_label_shape = train_labels.shape
        if len(tf_label_shape) == 2 and len(np_label_shape) == 1:
            train_labels = train_labels[..., None]
            val_labels = val_labels[..., None]
        elif len(tf_label_shape) == len(np_label_shape):
            pass
        else:
            raise RuntimeError(
                'Mismatch label shape np: %s vs. tf: %s' % (
                    np_label_shape, tf_label_shape))

        # Start training
        train_losses = []
        train_logits = []
        for epoch in tqdm(
                range(config.epochs), desc='Epoch', total=config.epochs):
            for train_batch in range(train_batches):
                io_start_time = time.time()
                data_idx = train_batch_idx == train_batch
                it_train_images = train_images[data_idx]
                it_train_labels = train_labels[data_idx]
                if isinstance(it_train_images[0], basestring):
                    it_train_images = np.asarray(
                        [data_to_tfrecords.load_image(im)
                         for im in it_train_images])
                feed_dict = {
                    train_dict['train_images']: it_train_images,
                    train_dict['train_labels']: it_train_labels
                }
                (train_score, train_loss, it_train_dict, timer) = training_step(
                    sess=sess, train_dict=train_dict, config=config,
                    feed_dict=feed_dict)
                train_losses.append(train_loss)
                if step % config.validation_period == 0:
                    val_score, val_lo, it_val_dict, duration = validation_step(
                        sess=sess, val_dict=val_dict, config=config, log=log,
                        val_images=val_images, val_labels=val_labels,
                        val_batch_idx=val_batch_idx, val_batches=val_batches)

                    # Save progress and important data
                    try:
                        val_check = np.where(val_lo < val_perf)[0]
                        if not len(val_check):
                            it_early_stop -= 1
                            print 'Deducted from early stop count.'
                        else:
                            it_early_stop = config.early_stop
                            best_val_dict = it_val_dict
                            print 'Reset early stop count.'
                        if it_early_stop <= 0:
                            print 'Early stop triggered. Ending early.'
                            print 'Best validation loss: %s' % np.min(val_perf)
                            return
                        save_progress(
                            config=config, val_check=val_check,
                            weight_dict=weight_dict, it_val_dict=it_val_dict,
                            exp_label=exp_label, step=step,
                            directories=directories, sess=sess, saver=saver,
                            val_score=val_score, val_loss=val_lo,
                            train_score=train_score, train_loss=train_loss,
                            timer=duration, num_params=num_params, log=log,
                            summary_op=summary_op,
                            summary_writer=summary_writer,
                            save_activities=save_activities,
                            save_gradients=save_gradients,
                            save_checkpoints=save_checkpoints)
                    except Exception as e:
                        log.info('Failed to save checkpoint: %s' % e)

                    # Hack to get the visualizations... clean this up later
                    if "BSDS500_test_orientation_viz" in config.experiment:
                        # .model == "BSDS_inh_perturb" or config.model == "BSDS_exc_perturb":
                        # from matplotlib import pyplot as plt;plt.plot(it_train_dict['train_logits'].squeeze(), "r", label="Perturb");plt.plot(it_train_dict['train_labels'].squeeze()[-6:], 'b', label="GT");plt.legend();plt.show()
                        # from matplotlib import pyplot as plt;plt.imshow((it_train_dict['impatch'].squeeze() + np.asarray([123.68, 116.78, 103.94])[None, None]).astype(np.uint8));plt.show()
                        # from matplotlib import pyplot as plt;dd = it_train_dict["grad0"];plt.imshow(np.abs(dd.squeeze()).mean(-1) / (np.abs(dd.squeeze()).std(-1) + 1e-4));plt.show()
                        # from matplotlib import pyplot as plt;dd = it_train_dict['mask'];plt.imshow(dd.squeeze().mean(-1));plt.show()
                        train_logits.append(
                            [it_train_dict["train_logits"].ravel()])
                        out_dir = "circuits_{}".format(config.out_dir)
                        py_utils.make_dir(out_dir)
                        out_target = os.path.join(
                            out_dir,
                            "{}_{}".format(config.model, config.train_dataset))
                        np.save(
                            "{}_optim".format(out_target),
                            [sess.run(tf.trainable_variables())])  # , it_train_dict["conv"]]
                        np.save("{}_perf".format(out_target), train_losses)
                        np.save("{}_curves".format(out_target), train_logits)
                        np.save(
                            "{}_label".format(out_target),
                            it_train_dict["train_labels"])
                        """
                        if config.model == "BSDS_inh_perturb":
                            np.save("inh_perturbs/optim", sess.run(tf.trainable_variables()[0]))
                            np.save("inh_perturbs/perf", train_losses)
                            np.save("inh_perturbs/curves", train_logits)
                            np.save("inh_perturbs/label", it_train_dict["train_labels"])

                        if config.model == "BSDS_exc_perturb":
                            np.save("exc_perturbs/optim", sess.run(tf.trainable_variables()[0]))
                            np.save("exc_perturbs/perf", train_losses)
                            np.save("exc_perturbs/curves", train_logits)
                            np.save("exc_perturbs/label", it_train_dict["train_labels"])
                        """

                    # Training status and validation accuracy
                    val_status(
                        log=log, dt=datetime.now(), step=step,
                        train_loss=train_loss,
                        rate=config.val_batch_size / duration,
                        timer=float(duration),
                        score_function=config.score_function,
                        train_score=train_score, val_score=val_score,
                        val_loss=val_lo, best_val_loss=np.min(val_perf),
                        summary_dir=directories['summaries'])
                else:
                    # Training status
                    io_duration = time.time() - io_start_time
                    train_status(
                        log=log, dt=datetime.now(), step=step,
                        train_loss=train_loss,
                        rate=config.val_batch_size / duration,
                        timer=float(duration), io_timer=float(io_duration),
                        lr=it_train_dict['lr'],
                        score_function=config.score_function,
                        train_score=train_score)

                # End iteration
                val_perf = np.concatenate([val_perf, [val_lo]])
                step += 1

                # Adaptive ending
                if adaptive_train and train_loss <= adaptive_train:
                    break
            if adaptive_train and train_loss <= adaptive_train:
                break
    else:
        try:
            while not coord.should_stop():
                (train_score, train_loss, it_train_dict, duration) = training_step(
                    sess=sess, config=config, train_dict=train_dict)
                io_start_time = time.time()
                if step % config.validation_period == 0:
                    val_score, val_lo, it_val_dict, duration = validation_step(
                        sess=sess, val_dict=val_dict, config=config, log=log)

                    # Save progress and important data
                    try:
                        val_check = np.where(val_lo < val_perf)[0]
                        if not len(val_check):
                            it_early_stop -= 1
                            print 'Deducted from early stop count.'
                        else:
                            it_early_stop = config.early_stop
                            best_val_dict = it_val_dict
                            print 'Reset early stop count.'
                        if it_early_stop <= 0:
                            print 'Early stop triggered. Ending early.'
                            print 'Best validation loss: %s' % np.min(val_perf)
                            break
                        val_perf = save_progress(
                            config=config, val_check=val_check,
                            weight_dict=weight_dict, it_val_dict=it_val_dict,
                            exp_label=exp_label, step=step,
                            directories=directories, sess=sess, saver=saver,
                            val_score=val_score, val_loss=val_lo,
                            val_perf=val_perf, train_score=train_score,
                            train_loss=train_loss, timer=duration,
                            num_params=num_params, log=log, use_db=use_db,
                            summary_op=summary_op,
                            summary_writer=summary_writer,
                            save_activities=save_activities,
                            save_gradients=save_gradients,
                            save_checkpoints=save_checkpoints)
                    except Exception as e:
                        log.info('Failed to save checkpoint: %s' % e)

                    # Training status and validation accuracy
                    val_status(
                        log=log, dt=datetime.now(), step=step,
                        train_loss=train_loss,
                        rate=config.val_batch_size / duration,
                        timer=float(duration),
                        score_function=config.score_function,
                        train_score=train_score, val_score=val_score,
                        val_loss=val_lo, best_val_loss=np.min(val_perf),
                        summary_dir=directories['summaries'])
                else:
                    # Training status
                    io_duration = time.time() - io_start_time
                    train_status(
                        log=log, dt=datetime.now(), step=step,
                        train_loss=train_loss,
                        rate=config.val_batch_size / duration,
                        timer=float(duration), io_timer=float(io_duration),
                        lr=it_train_dict['lr'],
                        score_function=config.score_function,
                        train_score=train_score)

                # End iteration
                step += 1
        except tf.errors.OutOfRangeError:
            log.info('Done training for %d epochs, %d steps.' % (
                config.epochs, step))
            log.info('Saved to: %s' % directories['checkpoints'])
        finally:
            coord.request_stop()
            coord.join(threads)
            sess.close()
    print 'Best %s loss: %s' % (config.val_loss_function, val_perf[0])
    if hasattr(config, 'get_map') and config.get_map:
        tf_fun.calculate_map(best_val_dict, exp_label, config)
    return
def validation_step(
        sess, val_dict, config, log, sequential=False,
        dict_image_key='val_images', dict_label_key='val_labels',
        eval_score_key='val_score', eval_loss_key='val_loss',
        map_im_key='val_images', map_log_key='val_logits',
        map_lab_key='val_labels', val_images=False, val_labels=False,
        val_batch_idx=None, val_batches=False):
    """Run validation as the average of config.validation_steps batches."""
    it_val_score = np.asarray([])
    it_val_loss = np.asarray([])
    start_time = time.time()
    save_val_dicts = (hasattr(config, 'get_map') and config.get_map) or \
        (hasattr(config, 'all_results') and config.all_results)
    # map_log_key = "fgru"  # Hack to get activities for viz.
    if save_val_dicts:
        it_val_dicts = []
    for idx in range(config.validation_steps):
        # Validation accuracy as the average of n batches
        file_paths = ''  # Init empty full-path image info for tfrecords
        if val_batch_idx is not None:
            if not sequential and len(val_batch_idx) > 1:
                it_val_batch_idx = val_batch_idx[
                    np.random.permutation(len(val_batch_idx))]
                val_step = np.random.randint(low=0, high=val_batch_idx.max())
            else:
                it_val_batch_idx = val_batch_idx
                val_step = idx
            if val_images.shape > 1:
                it_idx = it_val_batch_idx == val_step
                it_ims = val_images[it_idx]
                it_labs = val_labels[it_idx]
            # Correct for batch-size=1 cases
            if isinstance(it_ims[0], basestring):
                file_paths = np.copy(it_ims)
                it_ims = np.asarray(
                    [data_to_tfrecords.load_image(im) for im in it_ims])
            if isinstance(it_labs[0], basestring):
                it_labs = np.asarray(
                    [data_to_tfrecords.load_image(im) for im in it_labs])
            feed_dict = {
                val_dict[dict_image_key]: it_ims,
                val_dict[dict_label_key]: it_labs
            }
            it_val_dict = sess.run(val_dict, feed_dict=feed_dict)
        else:
            it_val_dict = sess.run(val_dict)
        if hasattr(config, 'plot_recurrence') and config.plot_recurrence:
            tf_fun.visualize_recurrence(
                idx=idx, image=it_ims, label=it_labs,
                logits=it_val_dict['test_logits'],
                h2s=it_val_dict['h2_list'], ff=it_val_dict['ff'],
                config=config, debug=False)

        # Patch for accuracy... TODO: Fix the TF op
        # if config.score_function == 'accuracy':
        #     preds = np.argsort(it_val_dict['val_logits'], -1)[:, -1]
        #     it_val_score = np.mean(
        #         preds == it_val_dict['val_labels'].astype(float))
        if config.score_function == 'fixed_accuracy':
            it_val_labels = it_val_dict[map_lab_key].astype(float)
            it_val_logits = np.round(
                sigmoid(it_val_dict[map_log_key])).astype(float)
            it_val_score = np.append(
                it_val_score, np.mean(it_val_labels == it_val_logits))
        else:
            it_val_score = np.append(
                it_val_score, it_val_dict[eval_score_key])
        it_val_loss = np.append(it_val_loss, it_val_dict[eval_loss_key])
        if save_val_dicts:
            trim_dict = {
                'images': it_val_dict[map_im_key],  # im_key='val_images'
                'logits': it_val_dict[map_log_key],  # log_key='val_logits'
                'labels': it_val_dict[map_lab_key],  # lab_key='val_labels'
                'image_paths': file_paths,
            }
            it_val_dicts += [trim_dict]
    val_score = it_val_score.mean()
    val_lo = it_val_loss.mean()
    duration = time.time() - start_time
    if save_val_dicts:
        it_val_dict = it_val_dicts
    return val_score, val_lo, it_val_dict, duration
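
# The 'fixed_accuracy' patch in the validation_step above assumes a numpy
# `sigmoid` helper is in scope. If one is not already defined or imported in
# this module, a minimal sketch could look like the function below; the name
# and placement are assumptions, not part of the original code.
def sigmoid(x):
    """Elementwise logistic sigmoid used to binarize validation logits."""
    return 1. / (1. + np.exp(-x))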
def training_loop(
        config, coord, sess, summary_op, summary_writer, saver, threads,
        directories, train_dict, val_dict, exp_label, data_structure,
        placeholders=False, save_checkpoints=False, save_activities=True,
        save_gradients=False):
    """Run the model training loop."""
    log = logger.get(os.path.join(config.log_dir, exp_label))
    step = 0
    if config.save_weights:
        try:
            weight_dict = {v.name: v for v in tf.trainable_variables()}
            val_dict = dict(val_dict, **weight_dict)
        except Exception:
            raise RuntimeError('Failed to find weights to save.')
    else:
        weight_dict = None
    if placeholders:
        placeholder_images = placeholders[0]
        placeholder_labels = placeholders[1]
        train_images = placeholder_images['train']
        val_images = placeholder_images['val']
        train_labels = placeholder_labels['train']
        val_labels = placeholder_labels['val']
        train_batches = len(train_images) / config.batch_size
        train_batch_idx = np.arange(train_batches).reshape(-1, 1).repeat(
            config.batch_size)
        train_images = train_images[:len(train_batch_idx)]
        train_labels = train_labels[:len(train_batch_idx)]
        val_batches = len(val_images) / config.batch_size
        val_batch_idx = np.arange(val_batches).reshape(-1, 1).repeat(
            config.batch_size)
        val_images = val_images[:len(val_batch_idx)]
        val_labels = val_labels[:len(val_batch_idx)]
        for epoch in tqdm(
                range(config.epochs), desc='Epoch', total=config.epochs):
            for train_batch in range(train_batches):
                data_idx = train_batch_idx == train_batch
                it_train_images = train_images[data_idx]
                it_train_labels = train_labels[data_idx]
                if isinstance(it_train_images[0], basestring):
                    it_train_images = np.asarray(
                        [data_to_tfrecords.load_image(im)
                         for im in it_train_images])
                feed_dict = {
                    train_dict['train_images']: it_train_images,
                    train_dict['train_labels']: it_train_labels
                }
                (train_acc, train_loss, it_train_dict, timesteps) = training_step(
                    sess=sess, train_dict=train_dict, feed_dict=feed_dict)
                if step % config.validation_iters == 0:
                    val_acc, val_lo, it_val_dict, duration = validation_step(
                        sess=sess, val_dict=val_dict,
                        data_structure=data_structure, config=config, log=log,
                        val_images=val_images, val_labels=val_labels,
                        val_batch_idx=val_batch_idx, val_batches=val_batches)

                    # Save progress and important data
                    save_progress(
                        config=config, weight_dict=weight_dict,
                        it_val_dict=it_val_dict, exp_label=exp_label,
                        step=step, directories=directories, sess=sess,
                        saver=saver, data_structure=data_structure,
                        val_acc=val_acc, val_lo=val_lo, train_acc=train_acc,
                        train_loss=train_loss, timesteps=timesteps, log=log,
                        summary_op=summary_op, summary_writer=summary_writer,
                        save_activities=save_activities,
                        save_checkpoints=save_checkpoints)

                    # Training status and validation accuracy
                    format_str = (
                        '%s: step %d, loss = %.2f (%.1f examples/sec; '
                        '%.3f sec/batch) | Training accuracy = %s | '
                        'Validation accuracy = %s | logdir = %s')
                    log.info(format_str % (
                        datetime.now(), step, train_loss,
                        config.batch_size / duration, float(duration),
                        train_acc, val_acc, directories['summaries']))
                else:
                    # Training status
                    format_str = (
                        '%s: step %d, loss = %.5f (%.1f examples/sec; '
                        '%.3f sec/batch) | Training accuracy = %s')
                    log.info(format_str % (
                        datetime.now(), step, train_loss,
                        config.batch_size / duration, float(duration),
                        train_acc))

                # End iteration
                step += 1
    else:
        try:
            while not coord.should_stop():
                (train_acc, train_loss, it_train_dict, duration) = training_step(
                    sess=sess, train_dict=train_dict)
                if step % config.validation_iters == 0:
                    val_acc, val_lo, it_val_dict, duration = validation_step(
                        sess=sess, val_dict=val_dict, config=config, log=log)

                    # Save progress and important data
                    save_progress(
                        config=config, weight_dict=weight_dict,
                        it_val_dict=it_val_dict, exp_label=exp_label,
                        step=step, directories=directories, sess=sess,
                        saver=saver, data_structure=data_structure,
                        val_acc=val_acc, val_lo=val_lo, train_acc=train_acc,
                        train_loss=train_loss, timesteps=duration, log=log,
                        summary_op=summary_op, summary_writer=summary_writer,
                        save_activities=save_activities,
                        save_checkpoints=save_checkpoints)

                    # Training status and validation accuracy
                    format_str = (
                        '%s: step %d, loss = %.2f (%.1f examples/sec; '
                        '%.3f sec/batch) | Training accuracy = %s | '
                        'Validation accuracy = %s | logdir = %s')
                    log.info(format_str % (
                        datetime.now(), step, train_loss,
                        config.batch_size / duration, float(duration),
                        train_acc, val_acc, directories['summaries']))
                else:
                    # Training status
                    format_str = (
                        '%s: step %d, loss = %.5f (%.1f examples/sec; '
                        '%.3f sec/batch) | Training accuracy = %s')
                    log.info(format_str % (
                        datetime.now(), step, train_loss,
                        config.batch_size / duration, float(duration),
                        train_acc))

                # End iteration
                step += 1
        except tf.errors.OutOfRangeError:
            log.info('Done training for %d epochs, %d steps.' % (
                config.epochs, step))
            log.info('Saved to: %s' % directories['checkpoints'])
        finally:
            coord.request_stop()
            coord.join(threads)
            sess.close()

    # Package output variables into a dictionary
    if save_gradients:
        np.savez(
            os.path.join(config.results, '%s_train_gradients' % exp_label),
            **it_train_dict)
        np.savez(
            os.path.join(config.results, '%s_val_gradients' % exp_label),
            **it_val_dict)
    return
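
def _load_saved_gradients_example(results_dir, exp_label='example'):
    """Illustrative sketch, never called by the loop above.

    The save_gradients branch writes the final train/val dictionaries with
    np.savez, which appends a '.npz' extension. This helper only shows how
    such a dump could be read back; `results_dir` and the 'example' label are
    placeholders, not values from the original code.
    """
    data = np.load(
        os.path.join(results_dir, '%s_train_gradients.npz' % exp_label))
    return {k: data[k] for k in data.files}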