def evaluate():
  """Eval happens on GPU or CPU, and evals each checkpoint as it appears."""
  tf.compat.v1.enable_eager_execution()

  candidate_checkpoint = None
  uflow = uflow_main.create_uflow()
  evaluate_fn, _ = uflow_data.make_eval_function(
      FLAGS.eval_on,
      FLAGS.height,
      FLAGS.width,
      progress_bar=True,
      plot_dir=FLAGS.plot_dir,
      num_plots=50)

  latest_checkpoint = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
  while True:
    # Wait for a new checkpoint.
    while candidate_checkpoint == latest_checkpoint:
      logging.log_every_n(logging.INFO,
                          'Waiting for a new checkpoint, at %s, latest is %s',
                          20, FLAGS.checkpoint_dir, latest_checkpoint)
      time.sleep(0.5)
      candidate_checkpoint = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
    candidate_checkpoint = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
    latest_checkpoint = candidate_checkpoint
    logging.info('New checkpoint found: %s', candidate_checkpoint)
    # This forces the checkpoint manager to reexamine the checkpoint directory
    # and become aware of the new checkpoint.
    uflow.update_checkpoint_dir(FLAGS.checkpoint_dir)
    uflow.restore()
    eval_results = evaluate_fn(uflow)
    uflow_plotting.print_eval(eval_results)
    step = tf.compat.v1.train.get_global_step().numpy()
    if step >= FLAGS.num_train_steps:
      logging.info(
          'Evaluator terminating - completed evaluation of checkpoint '
          'from step %d', step)
      return
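
# A minimal sketch (not part of uflow) of the same checkpoint-watching loop,
# written with tf.train.checkpoints_iterator instead of manual polling. The
# function name and the stop_step argument are illustrative assumptions; the
# uflow object and evaluate_fn stand in for the ones created in evaluate()
# above.
def _evaluate_with_checkpoints_iterator(uflow, evaluate_fn, checkpoint_dir,
                                        stop_step):
  # checkpoints_iterator blocks until a new checkpoint appears in
  # checkpoint_dir and then yields its path.
  for checkpoint_path in tf.train.checkpoints_iterator(
      checkpoint_dir, min_interval_secs=0.5):
    logging.info('New checkpoint found: %s', checkpoint_path)
    uflow.update_checkpoint_dir(checkpoint_dir)
    uflow.restore()
    uflow_plotting.print_eval(evaluate_fn(uflow))
    if tf.compat.v1.train.get_global_step().numpy() >= stop_step:
      break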
def main(unused_argv):

  if FLAGS.no_tf_function:
    tf.config.experimental_run_functions_eagerly(True)
    print('TFFUNCTION DISABLED')

  gin.parse_config_files_and_bindings(FLAGS.config_file, FLAGS.gin_bindings)

  # Make directories if they do not exist yet.
  if FLAGS.checkpoint_dir and not tf.io.gfile.exists(FLAGS.checkpoint_dir):
    print('Making new checkpoint directory', FLAGS.checkpoint_dir)
    tf.io.gfile.makedirs(FLAGS.checkpoint_dir)
  if FLAGS.plot_dir and not tf.io.gfile.exists(FLAGS.plot_dir):
    print('Making new plot directory', FLAGS.plot_dir)
    tf.io.gfile.makedirs(FLAGS.plot_dir)

  uflow = create_uflow()
  if not FLAGS.from_scratch:
    # First restore from init_checkpoint_dir, which is only restored from but
    # not saved to, and then restore from checkpoint_dir if there is already
    # a model there (e.g. if the run was stopped and restarted).
    if FLAGS.init_checkpoint_dir:
      print('Initializing model from checkpoint {}.'.format(
          FLAGS.init_checkpoint_dir))
      uflow.update_checkpoint_dir(FLAGS.init_checkpoint_dir)
      uflow.restore(
          reset_optimizer=FLAGS.reset_optimizer,
          reset_global_step=FLAGS.reset_global_step)
      uflow.update_checkpoint_dir(FLAGS.checkpoint_dir)

    if FLAGS.checkpoint_dir:
      print('Restoring model from checkpoint {}.'.format(FLAGS.checkpoint_dir))
      uflow.restore()
  else:
    print('Starting from scratch.')

  print('Making eval datasets and eval functions.')
  if FLAGS.eval_on:
    evaluate, _ = uflow_data.make_eval_function(
        FLAGS.eval_on,
        FLAGS.height,
        FLAGS.width,
        progress_bar=True,
        plot_dir=FLAGS.plot_dir,
        num_plots=50)

  if FLAGS.train_on:
    # Build training iterator.
    print('Making training iterator.')
    train_it = uflow_data.make_train_iterator(
        FLAGS.train_on,
        FLAGS.height,
        FLAGS.width,
        FLAGS.shuffle_buffer_size,
        FLAGS.batch_size,
        FLAGS.seq_len,
        crop_instead_of_resize=FLAGS.crop_instead_of_resize,
        apply_augmentation=True,
        include_ground_truth=FLAGS.use_supervision,
        resize_gt_flow=FLAGS.resize_gt_flow_supervision,
        include_occlusions=FLAGS.use_gt_occlusions,
    )

    if FLAGS.use_supervision:
      # Since this is the only loss in this setting, and the Adam optimizer
      # is scale invariant, the actual weight here does not matter for now.
      weights = {'supervision': 1.}
    else:
      # Note that self-supervision loss is added during training.
      weights = {
          'photo': FLAGS.weight_photo,
          'ssim': FLAGS.weight_ssim,
          'census': FLAGS.weight_census,
          'smooth1': FLAGS.weight_smooth1,
          'smooth2': FLAGS.weight_smooth2,
          'edge_constant': FLAGS.smoothness_edge_constant,
      }
      # Switch off loss-terms that have weights < 1e-7.
      weights = {
          k: v for (k, v) in weights.items()
          if v > 1e-7 or k == 'edge_constant'
      }

    # (A standalone plain-Python sketch of this ramp-up schedule appears after
    # main().)
    def weight_selfsup_fn():
      step = tf.compat.v1.train.get_or_create_global_step(
      ) % FLAGS.selfsup_step_cycle
      # Start self-supervision only after a certain number of steps.
      # Linearly increase self-supervision weight for a number of steps.
      ramp_up_factor = tf.clip_by_value(
          float(step - (FLAGS.selfsup_after_num_steps - 1)) /
          float(max(FLAGS.selfsup_ramp_up_steps, 1)), 0., 1.)
      return FLAGS.weight_selfsup * ramp_up_factor

    distance_metrics = {
        'photo': FLAGS.distance_photo,
        'census': FLAGS.distance_census,
    }

    print('Starting training loop.')
    log = dict()
    epoch = 0

    teacher_feature_model = None
    teacher_flow_model = None
    test_frozen_flow = None

    while True:
      current_step = tf.compat.v1.train.get_or_create_global_step().numpy()

      # Set which occlusion estimation methods could be active at this point.
      # (They will only be used if occlusion_estimation is set accordingly.)
      occ_active = {
          'uflow': FLAGS.occlusion_estimation == 'uflow',
          'brox': current_step > FLAGS.occ_after_num_steps_brox,
          'wang': current_step > FLAGS.occ_after_num_steps_wang,
          'wang4': current_step > FLAGS.occ_after_num_steps_wang,
          'wangthres': current_step > FLAGS.occ_after_num_steps_wang,
          'wang4thres': current_step > FLAGS.occ_after_num_steps_wang,
          'fb_abs': current_step > FLAGS.occ_after_num_steps_fb_abs,
          'forward_collision':
              current_step > FLAGS.occ_after_num_steps_forward_collision,
          'backward_zero':
              current_step > FLAGS.occ_after_num_steps_backward_zero,
      }

      current_weights = {k: v for k, v in weights.items()}

      # Prepare self-supervision if it will be used in the next epoch.
      if FLAGS.weight_selfsup > 1e-7 and (
          current_step % FLAGS.selfsup_step_cycle
      ) + FLAGS.epoch_length > FLAGS.selfsup_after_num_steps:
        # Add selfsup weight with a ramp-up schedule. This will cause a
        # recompilation of the training graph defined in uflow.train(...).
        current_weights['selfsup'] = weight_selfsup_fn

        # Freeze model for teacher distillation.
        if teacher_feature_model is None and FLAGS.frozen_teacher:
          # Create a copy of the existing models and freeze them as a teacher.
          # Tell uflow about the new, frozen teacher model.
          teacher_feature_model, teacher_flow_model = (
              create_frozen_teacher_models(uflow))
          uflow.set_teacher_models(
              teacher_feature_model=teacher_feature_model,
              teacher_flow_model=teacher_flow_model)
          test_frozen_flow = check_model_frozen(
              teacher_feature_model, teacher_flow_model, prev_flow_output=None)

        # Check that the model actually is frozen.
        if FLAGS.frozen_teacher and test_frozen_flow is not None:
          check_model_frozen(
              teacher_feature_model,
              teacher_flow_model,
              prev_flow_output=test_frozen_flow)

      # Train for an epoch and save the results.
      log_update = uflow.train(
          train_it,
          weights=current_weights,
          num_steps=FLAGS.epoch_length,
          progress_bar=True,
          plot_dir=FLAGS.plot_dir if FLAGS.plot_debug_info else None,
          distance_metrics=distance_metrics,
          occ_active=occ_active)

      for key in log_update:
        if key in log:
          log[key].append(log_update[key])
        else:
          log[key] = [log_update[key]]

      if FLAGS.checkpoint_dir and not FLAGS.no_checkpointing:
        uflow.save()

      # Print losses from last epoch.
      uflow_plotting.print_log(log, epoch)

      if FLAGS.eval_on and FLAGS.evaluate_during_train:
        # Evaluate.
        eval_results = evaluate(uflow)
        uflow_plotting.print_eval(eval_results)

      if current_step >= FLAGS.num_train_steps:
        break

      epoch += 1

  else:
    print('Specify the flag train_on as <format>:<path>;... to enable '
          'training.')
    print('Just doing evaluation now.')
    eval_results = evaluate(uflow)
    if eval_results:
      uflow_plotting.print_eval(eval_results)
    print('Evaluation complete.')
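
# A minimal sketch (not part of uflow) of the self-supervision ramp-up schedule
# computed by weight_selfsup_fn inside main(), written with plain Python floats
# so the shape of the schedule is easy to inspect. The default argument values
# below are illustrative assumptions, not uflow's flag defaults.
def _selfsup_weight_sketch(step,
                           weight_selfsup=0.3,
                           selfsup_step_cycle=1000000,
                           selfsup_after_num_steps=500000,
                           selfsup_ramp_up_steps=100000):
  step_in_cycle = step % selfsup_step_cycle
  # Zero before selfsup_after_num_steps, then a linear ramp up to
  # weight_selfsup over selfsup_ramp_up_steps, then constant until the cycle
  # wraps around.
  ramp_up_factor = min(
      max(
          float(step_in_cycle - (selfsup_after_num_steps - 1)) /
          float(max(selfsup_ramp_up_steps, 1)), 0.), 1.)
  return weight_selfsup * ramp_up_factor


# With the illustrative defaults above, the weight is 0 at step 400000, roughly
# half of weight_selfsup at step 550000, and the full weight_selfsup from step
# 600000 until the cycle wraps around.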