def evaluate(
    inference_fn,
    dataset,
    height,
    width,
    progress_bar=False,
    plot_dir='',
    num_plots=0,
    max_num_evals=10000,
    prefix='',
    has_occlusion=True,
):
    """Evaluate an inference function for flow.

    For every evaluated dataset element the function runs (and times)
    inference, then computes the mean end point error (EPE) and the outlier
    rate (ER) against the ground-truth flow. When ground-truth occlusion is
    available, it also accumulates per-threshold occlusion confusion counts so
    that an overall maximal F-score can be reported.

    Args:
      inference_fn: An inference function that produces a flow_field from two
        images, e.g. the infer method of UFlow. It is called as
        inference_fn(image1, image2, input_height=height, input_width=width,
        infer_occlusion=True) and must return a (flow, soft_occlusion_mask)
        pair.
      dataset: A dataset produced by the method above with for_eval=True. Its
        elements are unpacked as (image_batch, flow_gt, _, occ_mask_gt) when
        has_occlusion is True and as (image_batch, flow_gt, _) otherwise.
      height: int, the height to which the images should be resized for
        inference.
      width: int, the width to which the images should be resized for
        inference.
      progress_bar: boolean, flag to indicate whether the function should print
        a progress bar (one ':' per evaluated element) during evaluation.
      plot_dir: string, optional path to a directory in which plots are saved
        (if num_plots > 0).
      num_plots: int, maximum number of qualitative results to plot for the
        evaluation.
      max_num_evals: int, maximum number of dataset elements to evaluate.
      prefix: str, if non-empty it is prepended (joined with '-') to every key
        of the results dictionary.
      has_occlusion: bool indicating whether or not the dataset includes ground
        truth occlusion. When False, an all-ones occlusion mask is substituted,
        a fixed threshold of 0.5 is used for plotting, and no occlusion
        confusion counts are accumulated.

    Returns:
      A dictionary of floats that represent different evaluation metrics, with
      keys 'occl-f-max', 'best-occl-thresh', 'EPE', 'ER', 'inf-time(ms)' and
      'eval-time(s)' (prefixed as described for `prefix`). The keys of this
      dictionary are returned by the method list_eval_keys.
    """
    eval_start_in_s = time.time()

    it = tf.compat.v1.data.make_one_shot_iterator(dataset)
    epe_occ = []  # Per-element mean end point errors.
    errors_occ = []  # Per-element outlier rates.
    inference_times = []
    # Maps threshold -> {'tp', 'fp', 'tn', 'fn'} counts summed over elements.
    all_occlusion_results = defaultdict(lambda: defaultdict(int))

    plot_count = 0
    for eval_count, test_batch in enumerate(it):
        # eval_count is 0-based, so exactly max_num_evals elements are used.
        if eval_count >= max_num_evals:
            break

        if progress_bar:
            sys.stdout.write(':')
            sys.stdout.flush()

        if has_occlusion:
            (image_batch, flow_gt, _, occ_mask_gt) = test_batch
        else:
            (image_batch, flow_gt, _) = test_batch
            # No ground-truth occlusion: substitute an all-ones mask shaped
            # like flow_gt but with a single trailing channel.
            occ_mask_gt = tf.ones_like(flow_gt[Ellipsis, -1:])

        # pylint:disable=cell-var-from-loop
        # pylint:disable=g-long-lambda
        # The lambda is only invoked by time_it within this iteration, so the
        # late binding of image_batch is harmless.
        f = lambda: inference_fn(image_batch[0],
                                 image_batch[1],
                                 input_height=height,
                                 input_width=width,
                                 infer_occlusion=True)
        # Request the execute-once-before behaviour on the first element
        # (eval_count == 0). The previous `== 1` only did this on the second
        # element. Presumably this keeps one-time setup cost out of the timing.
        inference_time_in_ms, (flow, soft_occlusion_mask) = uflow_utils.time_it(
            f, execute_once_before=eval_count == 0)
        inference_times.append(inference_time_in_ms)

        if has_occlusion:
            f_dict = compute_f_metrics(soft_occlusion_mask, occ_mask_gt)
            best_thresh = -1.
            best_f_score = -1.
            for thresh, metrics in f_dict.items():
                # The 1e-6 terms guard against division by zero.
                precision = metrics['tp'] / (metrics['tp'] + metrics['fp'] +
                                             1e-6)
                recall = metrics['tp'] / (metrics['tp'] + metrics['fn'] + 1e-6)
                f1 = 2 * precision * recall / (precision + recall + 1e-6)
                # Per-element best threshold; only used for plotting below.
                if f1 > best_f_score:
                    best_thresh = thresh
                    best_f_score = f1
                for key in ('tp', 'fp', 'tn', 'fn'):
                    all_occlusion_results[thresh][key] += metrics[key]
        else:
            best_thresh = .5

        endpoint_error_occ = tf.reduce_sum(
            input_tensor=(flow - flow_gt)**2, axis=-1, keepdims=True)**0.5
        gt_flow_abs = tf.reduce_sum(
            input_tensor=flow_gt**2, axis=-1, keepdims=True)**0.5
        # A pixel is an outlier when its EPE exceeds both 3 and 5% of the
        # ground-truth flow magnitude.
        outliers_occ = tf.cast(
            tf.logical_and(endpoint_error_occ > 3.,
                           endpoint_error_occ > 0.05 * gt_flow_abs), 'float32')
        epe_occ.append(tf.reduce_mean(input_tensor=endpoint_error_occ))
        errors_occ.append(tf.reduce_mean(input_tensor=outliers_occ))

        if plot_dir and plot_count < num_plots:
            plot_count += 1
            mask_thresh = tf.cast(
                tf.math.greater(soft_occlusion_mask, best_thresh), tf.float32)
            uflow_plotting.complete_paper_plot(
                plot_dir,
                plot_count,
                image_batch[0].numpy(),
                image_batch[1].numpy(),
                flow.numpy(),
                flow_gt.numpy(),
                np.ones_like(mask_thresh.numpy()),
                1. - mask_thresh.numpy(),
                1. - occ_mask_gt.numpy().astype('float32'),
                frame_skip=None)

    if progress_bar:
        sys.stdout.write('\n')
        sys.stdout.flush()

    fmax, best_thresh = get_fmax_and_best_thresh(all_occlusion_results)
    eval_stop_in_s = time.time()

    results = {
        'occl-f-max': fmax,
        'best-occl-thresh': best_thresh,
        'EPE': np.mean(np.array(epe_occ)),
        'ER': np.mean(np.array(errors_occ)),
        'inf-time(ms)': np.mean(inference_times),
        'eval-time(s)': eval_stop_in_s - eval_start_in_s
    }
    if prefix:
        return {prefix + '-' + k: v for k, v in results.items()}
    return results
def evaluate(inference_fn,
             dataset,
             height,
             width,
             progress_bar=False,
             plot_dir='',
             num_plots=0,
             prefix='kitti'):
  """Evaluate an inference function for flow on a KITTI evaluation dataset.

  Each dataset element is unpacked as (image_batch, flow_uv_occ, flow_uv_noc,
  flow_valid_occ, flow_valid_noc). The prediction is scored against both the
  'occ' and the 'noc' ground truth (presumably KITTI's all-pixel and
  non-occluded variants; verify against the dataset builder), restricted to
  the corresponding valid pixels. The occlusion target is taken to be the
  pixels valid in 'occ' but not in 'noc'.

  Args:
    inference_fn: An inference function that produces a flow field from two
      images, e.g. the infer method of UFlow. It is called with
      input_height/input_width and infer_occlusion=True and must return a
      (flow, soft_occlusion_mask) pair.
    dataset: A dataset produced by the method above with for_eval=True.
    height: int, the height to which the images should be resized for
      inference.
    width: int, the width to which the images should be resized for inference.
    progress_bar: boolean, whether to print one ':' per element while
      evaluating.
    plot_dir: string, optional path to a directory in which plots are saved
      (if num_plots > 0).
    num_plots: int, number of leading dataset elements to plot.
    prefix: str joined with '-' in front of every key of the returned
      dictionary.

  Returns:
    A dictionary with the keys '<prefix>-occl-f-max',
    '<prefix>-best-occl-thresh', '<prefix>-EPE(occ)', '<prefix>-ER(occ)',
    '<prefix>-EPE(noc)', '<prefix>-ER(noc)', '<prefix>-inf-time(ms)' and
    '<prefix>-eval-time(s)'. EPE and ER are computed per element over valid
    pixels and then averaged over elements; the EPE values are clipped to
    [0, 50]. The keys are also returned by the method list_eval_keys.
  """
  started_at = time.time()

  def write_progress(text):
    if progress_bar:
      sys.stdout.write(text)
      sys.stdout.flush()

  def error_maps(predicted, reference):
    """Return per-pixel end point error and outlier-indicator maps."""
    epe = tf.reduce_sum(
        input_tensor=(predicted - reference)**2, axis=-1, keepdims=True)**0.5
    magnitude = tf.reduce_sum(
        input_tensor=reference**2, axis=-1, keepdims=True)**0.5
    # Outlier: EPE above 3 and above 5% of the reference flow magnitude.
    outlier = tf.cast(
        tf.logical_and(epe > 3., epe > 0.05 * magnitude), 'float32')
    return epe, outlier

  iterator = tf.compat.v1.data.make_one_shot_iterator(dataset)
  variants = ('occ', 'noc')
  # Per-element sums over valid pixels, keyed by ground-truth variant.
  epe_totals = {name: [] for name in variants}
  outlier_totals = {name: [] for name in variants}
  valid_totals = {name: [] for name in variants}
  timings_ms = []
  # Maps threshold -> {'tp', 'fp', 'tn', 'fn'} counts summed over elements.
  occlusion_counts = defaultdict(lambda: defaultdict(int))

  for index, element in enumerate(iterator):
    write_progress(':')
    (image_batch, flow_uv_occ, flow_uv_noc, valid_occ_raw,
     valid_noc_raw) = element
    valid = {
        'occ': tf.cast(valid_occ_raw, 'float32'),
        'noc': tf.cast(valid_noc_raw, 'float32'),
    }
    references = {'occ': flow_uv_occ, 'noc': flow_uv_noc}

    # pylint:disable=cell-var-from-loop
    def run_inference():
      return inference_fn(
          image_batch[0],
          image_batch[1],
          input_height=height,
          input_width=width,
          infer_occlusion=True)

    elapsed_ms, (flow, soft_occlusion_mask) = uflow_utils.time_it(
        run_inference, execute_once_before=(index == 0))
    timings_ms.append(elapsed_ms)

    occ_mask_gt = valid['occ'] - valid['noc']
    f_dict = data_utils.compute_f_metrics(soft_occlusion_mask * valid['occ'],
                                          occ_mask_gt * valid['occ'])
    best_thresh, best_f_score = -1., -1.
    for thresh, counts in f_dict.items():
      # The 1e-6 terms guard against division by zero.
      precision = counts['tp'] / (counts['tp'] + counts['fp'] + 1e-6)
      recall = counts['tp'] / (counts['tp'] + counts['fn'] + 1e-6)
      f1 = 2 * precision * recall / (precision + recall + 1e-6)
      if f1 > best_f_score:
        best_thresh, best_f_score = thresh, f1
      for field in ('tp', 'fp', 'tn', 'fn'):
        occlusion_counts[thresh][field] += counts[field]

    # Binarised occlusion at this element's best threshold (used for plots).
    mask_thresh = tf.cast(
        tf.math.greater(soft_occlusion_mask, best_thresh), tf.float32)
    # Image coordinates are swapped in labels.
    final_flow = flow[Ellipsis, ::-1]

    for name in variants:
      epe, outliers = error_maps(final_flow, references[name])
      epe_totals[name].append(tf.reduce_sum(input_tensor=valid[name] * epe))
      outlier_totals[name].append(
          tf.reduce_sum(input_tensor=valid[name] * outliers))
      valid_totals[name].append(tf.reduce_sum(input_tensor=valid[name]))

    if plot_dir and index < num_plots:
      uflow_plotting.complete_paper_plot(
          plot_dir,
          index,
          image_batch[0].numpy(),
          image_batch[1].numpy(),
          final_flow.numpy(),
          flow_uv_occ.numpy(),
          valid['occ'].numpy(),
          (1. - mask_thresh).numpy(),
          (1. - occ_mask_gt).numpy(),
          frame_skip=None)

  write_progress('\n')

  fmax, best_thresh = data_utils.get_fmax_and_best_thresh(occlusion_counts)
  finished_at = time.time()

  def per_element_mean(name, totals):
    # Normalise each element's sum by its own valid-pixel count, then average.
    return np.mean(np.array(totals[name]) / np.array(valid_totals[name]))

  return {
      prefix + '-occl-f-max':
          fmax,
      prefix + '-best-occl-thresh':
          best_thresh,
      prefix + '-EPE(occ)':
          np.clip(per_element_mean('occ', epe_totals), 0.0, 50.0),
      prefix + '-ER(occ)':
          per_element_mean('occ', outlier_totals),
      prefix + '-EPE(noc)':
          np.clip(per_element_mean('noc', epe_totals), 0.0, 50.0),
      prefix + '-ER(noc)':
          per_element_mean('noc', outlier_totals),
      prefix + '-inf-time(ms)':
          np.mean(timings_ms),
      prefix + '-eval-time(s)':
          finished_at - started_at,
  }