Exemple #1
0
def evaluate():
    """Evaluate each new checkpoint in FLAGS.checkpoint_dir as it appears.

    Eval happens on GPU or CPU. The function polls the checkpoint directory;
    every time the latest checkpoint path differs from the one evaluated
    last, the model is restored from it, the evaluation function is run and
    the results are printed. The function returns after evaluating a
    checkpoint whose global step is at least FLAGS.num_train_steps.
    """
    tf.compat.v1.enable_eager_execution()

    model = uflow_main.create_uflow()
    evaluate_fn, _ = uflow_data.make_eval_function(
        FLAGS.eval_on,
        FLAGS.height,
        FLAGS.width,
        progress_bar=True,
        plot_dir=FLAGS.plot_dir,
        num_plots=50,
    )

    # `last_evaluated` is the checkpoint path the loop compares against;
    # `newest` is the most recent path observed on disk. Starting `newest`
    # at None means an already existing checkpoint is evaluated right away.
    last_evaluated = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
    newest = None
    while True:
        # Block until the directory reports a path we have not handled yet.
        while newest == last_evaluated:
            logging.log_every_n(
                logging.INFO,
                'Waiting for a new checkpoint, at %s, latest is %s',
                20,
                FLAGS.checkpoint_dir,
                last_evaluated,
            )
            time.sleep(0.5)
            newest = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)

        # Query once more so the freshest path is the one that gets recorded.
        newest = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
        last_evaluated = newest
        logging.info('New checkpoint found: %s', newest)

        # This forces the checkpoint manager to reexamine the checkpoint
        # directory and become aware of the new checkpoint.
        model.update_checkpoint_dir(FLAGS.checkpoint_dir)
        model.restore()

        results = evaluate_fn(model)
        uflow_plotting.print_eval(results)

        step = tf.compat.v1.train.get_global_step().numpy()
        if step < FLAGS.num_train_steps:
            continue

        logging.info(
            'Evaluator terminating - completed evaluation of checkpoint '
            'from step %d', step)
        return
def main(unused_argv):
    """Train and/or evaluate a uflow model as configured by FLAGS.

    Steps, in order:
      1. Optionally run tf.functions eagerly (FLAGS.no_tf_function) and parse
         the gin configuration.
      2. Create the checkpoint and plot directories if they are missing.
      3. Create the model and, unless FLAGS.from_scratch is set, restore it
         from FLAGS.init_checkpoint_dir and then from FLAGS.checkpoint_dir.
      4. Build the evaluation function when FLAGS.eval_on is set.
      5. If FLAGS.train_on is set, train epoch by epoch until the global step
         reaches FLAGS.num_train_steps, saving, logging and optionally
         evaluating after each epoch. Otherwise run a single evaluation.

    Args:
        unused_argv: Ignored.

    Raises:
        ValueError: If neither FLAGS.train_on nor FLAGS.eval_on is set, so
            there is nothing to train on or evaluate.
    """
    if FLAGS.no_tf_function:
        tf.config.experimental_run_functions_eagerly(True)
        print('TFFUNCTION DISABLED')

    gin.parse_config_files_and_bindings(FLAGS.config_file, FLAGS.gin_bindings)
    # Make directories if they do not exist yet.
    if FLAGS.checkpoint_dir and not tf.io.gfile.exists(FLAGS.checkpoint_dir):
        print('Making new checkpoint directory', FLAGS.checkpoint_dir)
        tf.io.gfile.makedirs(FLAGS.checkpoint_dir)
    if FLAGS.plot_dir and not tf.io.gfile.exists(FLAGS.plot_dir):
        print('Making new plot directory', FLAGS.plot_dir)
        tf.io.gfile.makedirs(FLAGS.plot_dir)

    uflow = create_uflow()

    if not FLAGS.from_scratch:
        # First restore from init_checkpoint_dir, which is only restored from but
        # not saved to, and then restore from checkpoint_dir if there is already
        # a model there (e.g. if the run was stopped and restarted).
        if FLAGS.init_checkpoint_dir:
            print('Initializing model from checkpoint {}.'.format(
                FLAGS.init_checkpoint_dir))
            uflow.update_checkpoint_dir(FLAGS.init_checkpoint_dir)
            uflow.restore(reset_optimizer=FLAGS.reset_optimizer,
                          reset_global_step=FLAGS.reset_global_step)
            uflow.update_checkpoint_dir(FLAGS.checkpoint_dir)

        if FLAGS.checkpoint_dir:
            print('Restoring model from checkpoint {}.'.format(
                FLAGS.checkpoint_dir))
            uflow.restore()
    else:
        print('Starting from scratch.')

    print('Making eval datasets and eval functions.')
    # Stays None when no evaluation data is configured, so later code can
    # detect that case instead of hitting an unbound local.
    evaluate_fn = None
    if FLAGS.eval_on:
        evaluate_fn, _ = uflow_data.make_eval_function(FLAGS.eval_on,
                                                       FLAGS.height,
                                                       FLAGS.width,
                                                       progress_bar=True,
                                                       plot_dir=FLAGS.plot_dir,
                                                       num_plots=50)

    if FLAGS.train_on:
        # Build training iterator.
        print('Making training iterator.')
        train_it = uflow_data.make_train_iterator(
            FLAGS.train_on,
            FLAGS.height,
            FLAGS.width,
            FLAGS.shuffle_buffer_size,
            FLAGS.batch_size,
            FLAGS.seq_len,
            crop_instead_of_resize=FLAGS.crop_instead_of_resize,
            apply_augmentation=True,
            include_ground_truth=FLAGS.use_supervision,
            resize_gt_flow=FLAGS.resize_gt_flow_supervision,
            include_occlusions=FLAGS.use_gt_occlusions,
        )

        if FLAGS.use_supervision:
            # Since this is the only loss in this setting, and the Adam optimizer
            # is scale invariant, the actual weight here does not matter for now.
            weights = {'supervision': 1.}
        else:
            # Note that self-supervision loss is added during training.
            weights = {
                'photo': FLAGS.weight_photo,
                'ssim': FLAGS.weight_ssim,
                'census': FLAGS.weight_census,
                'smooth1': FLAGS.weight_smooth1,
                'smooth2': FLAGS.weight_smooth2,
                'edge_constant': FLAGS.smoothness_edge_constant,
            }

            # Switch off loss-terms that have weights < 1e-7.
            weights = {
                k: v
                for (k, v) in weights.items()
                if v > 1e-7 or k == 'edge_constant'
            }

        def weight_selfsup_fn():
            # Returns the self-supervision weight for the current global step
            # (taken modulo FLAGS.selfsup_step_cycle).
            step = tf.compat.v1.train.get_or_create_global_step(
            ) % FLAGS.selfsup_step_cycle
            # Start self-supervision only after a certain number of steps.
            # Linearly increase self-supervision weight for a number of steps.
            ramp_up_factor = tf.clip_by_value(
                float(step - (FLAGS.selfsup_after_num_steps - 1)) /
                float(max(FLAGS.selfsup_ramp_up_steps, 1)), 0., 1.)
            return FLAGS.weight_selfsup * ramp_up_factor

        distance_metrics = {
            'photo': FLAGS.distance_photo,
            'census': FLAGS.distance_census,
        }

        print('Starting training loop.')
        log = {}
        epoch = 0

        teacher_feature_model = None
        teacher_flow_model = None
        test_frozen_flow = None

        while True:
            # Global step at the START of this epoch; used for the schedules
            # below.
            current_step = tf.compat.v1.train.get_or_create_global_step(
            ).numpy()

            # Set which occlusion estimation methods could be active at this point.
            # (They will only be used if occlusion_estimation is set accordingly.)
            # All four 'wang' variants share the same activation threshold.
            wang_active = current_step > FLAGS.occ_after_num_steps_wang
            occ_active = {
                'uflow':
                FLAGS.occlusion_estimation == 'uflow',
                'brox':
                current_step > FLAGS.occ_after_num_steps_brox,
                'wang':
                wang_active,
                'wang4':
                wang_active,
                'wangthres':
                wang_active,
                'wang4thres':
                wang_active,
                'fb_abs':
                current_step > FLAGS.occ_after_num_steps_fb_abs,
                'forward_collision':
                current_step > FLAGS.occ_after_num_steps_forward_collision,
                'backward_zero':
                current_step > FLAGS.occ_after_num_steps_backward_zero,
            }

            # Per-epoch copy so adding 'selfsup' never mutates `weights`.
            current_weights = dict(weights)

            # Prepare self-supervision if it will be used in the next epoch.
            if FLAGS.weight_selfsup > 1e-7 and (
                    current_step % FLAGS.selfsup_step_cycle
            ) + FLAGS.epoch_length > FLAGS.selfsup_after_num_steps:

                # Add selfsup weight with a ramp-up schedule. This will cause a
                # recompilation of the training graph defined in uflow.train(...).
                current_weights['selfsup'] = weight_selfsup_fn

                # Freeze model for teacher distillation.
                if teacher_feature_model is None and FLAGS.frozen_teacher:
                    # Create a copy of the existing models and freeze them as a teacher.
                    # Tell uflow about the new, frozen teacher model.
                    teacher_feature_model, teacher_flow_model = create_frozen_teacher_models(
                        uflow)
                    uflow.set_teacher_models(
                        teacher_feature_model=teacher_feature_model,
                        teacher_flow_model=teacher_flow_model)
                    test_frozen_flow = check_model_frozen(
                        teacher_feature_model,
                        teacher_flow_model,
                        prev_flow_output=None)

                    # Check that the model actually is frozen.
                    if FLAGS.frozen_teacher and test_frozen_flow is not None:
                        check_model_frozen(teacher_feature_model,
                                           teacher_flow_model,
                                           prev_flow_output=test_frozen_flow)

            # Train for an epoch and save the results.
            log_update = uflow.train(
                train_it,
                weights=current_weights,
                num_steps=FLAGS.epoch_length,
                progress_bar=True,
                plot_dir=FLAGS.plot_dir if FLAGS.plot_debug_info else None,
                distance_metrics=distance_metrics,
                occ_active=occ_active)

            # Accumulate this epoch's values into the per-key history.
            for key, value in log_update.items():
                log.setdefault(key, []).append(value)

            if FLAGS.checkpoint_dir and not FLAGS.no_checkpointing:
                uflow.save()

            # Print losses from last epoch.
            uflow_plotting.print_log(log, epoch)

            if FLAGS.eval_on and FLAGS.evaluate_during_train:
                # Evaluate
                eval_results = evaluate_fn(uflow)
                uflow_plotting.print_eval(eval_results)

            # Re-read the global step for the stopping test. `current_step`
            # was sampled before this epoch's training, so comparing it would
            # run one extra epoch past num_train_steps.
            # NOTE(review): assumes uflow.train advances the global step; if
            # it does not, this value equals `current_step`.
            step_after_epoch = tf.compat.v1.train.get_or_create_global_step(
            ).numpy()
            if step_after_epoch >= FLAGS.num_train_steps:
                break

            epoch += 1

    else:
        print(
            'Specify flag train_on to enable training to <format>:<path>;... .'
        )
        if evaluate_fn is None:
            # Without eval_on no evaluation function was built; fail with a
            # clear message rather than an UnboundLocalError.
            raise ValueError(
                'Neither train_on nor eval_on is set; there is nothing to '
                'train on or evaluate.')
        print('Just doing evaluation now.')
        eval_results = evaluate_fn(uflow)
        if eval_results:
            uflow_plotting.print_eval(eval_results)
        print('Evaluation complete.')