예제 #1
0
def make_metrics(proto):
    """Scores the predicted mesh in one serialized `Results` record.

    Args:
      proto: A ``(key, serialized_results)`` pair. The second element is
        parsed with ``results_pb2.Results.FromString``.

    Returns:
      A dict holding the record key, the four values produced by
      ``metrics.all_mesh_metrics`` for (predicted mesh, ground-truth mesh),
      and the ``iou`` field stored in the proto.
    """
    record_key, serialized = proto
    results = results_pb2.Results.FromString(serialized)
    predicted = mesh_util.deserialize(results.mesh)
    ground_truth = mesh_util.deserialize(results.gt_mesh)

    # all_mesh_metrics returns (normal consistency, F@tau, F@2tau, chamfer).
    scores = metrics.all_mesh_metrics(predicted, ground_truth)
    normal_consistency, f_at_tau, f_at_2tau, chamfer_dist = scores

    row = {'key': record_key}
    row['Normal Consistency'] = normal_consistency
    row['F-Score (tau)'] = f_at_tau
    row['F-Score (2*tau)'] = f_at_2tau
    row['Chamfer'] = chamfer_dist
    row['IoU'] = results.iou
    return row
예제 #2
0
def make_metrics(proto):
  """Returns a single-element list containing a dictionary of metrics.

  The prediction mesh is read from ``FLAGS.occnet_dir`` at a path derived
  from the record key. It is optionally transformed, then written out for
  inspection next to the ground-truth mesh, and finally scored against it.

  Args:
    proto: A ``(key, serialized_results)`` pair. ``key`` must split on '/'
      into exactly three parts; the last two are used as synset and mesh
      hash.

  Returns:
    ``[metrics]`` where ``metrics`` maps 'key', 'Normal Consistency',
    'F-Score (tau)', 'F-Score (2*tau)' and 'Chamfer' to their values.
    Returns ``[]`` when loading or transforming the prediction mesh fails;
    the error is logged and the record is skipped.
  """
  key, s = proto
  p = results_pb2.Results.FromString(s)
  # occnet_dir is concatenated without a separator here, so it presumably
  # ends with '/'.
  mesh_path = f"{FLAGS.occnet_dir}{key.replace('test/', '')}.ply"
  log.warning('Mesh path: %s' % mesh_path)
  try:
    mesh = file_util.read_mesh(mesh_path)
    _, synset, mesh_hash = key.split('/')
    if FLAGS.transform:
      # Apply the example's gaps_to_occnet matrix to the prediction.
      ex = example.InferenceExample('test', synset, mesh_hash)
      tx = ex.gaps_to_occnet
      mesh.apply_transform(tx)
    log.info('Succeeded on %s' % mesh_path)
  # pylint:disable=broad-except
  except Exception as e:
    # pylint:enable=broad-except
    # Best-effort: a missing or malformed prediction only skips this record.
    log.error(f"Couldn't load {mesh_path}, skipping due to {repr(e)}.")
    return []

  gt_mesh = mesh_util.deserialize(p.gt_mesh)
  dir_out = FLAGS.occnet_dir + '/metrics-out-gt/%s' % key
  if not file_util.exists(dir_out):
    file_util.makedirs(dir_out)
  # dir_out has no trailing separator. Join with '/' so the files land inside
  # the directory created above rather than beside it as '<key>gt_mesh.ply'.
  file_util.write_mesh(f'{dir_out}/gt_mesh.ply', gt_mesh)
  file_util.write_mesh(f'{dir_out}/occnet_pred.ply', mesh)

  nc, fst, fs2t, chamfer = metrics.all_mesh_metrics(mesh, gt_mesh)
  return [{
      'key': key,
      'Normal Consistency': nc,
      'F-Score (tau)': fst,
      'F-Score (2*tau)': fs2t,
      'Chamfer': chamfer,
  }]
예제 #3
0
def make_metrics(proto):
  """Builds a dictionary containing proto elements.

  The prediction mesh is read from ``FLAGS.occnet_dir`` at a path derived
  from the record key. It is optionally transformed by the inverse of a
  per-key 4x4 ``occnet_to_gaps`` matrix, then written out for inspection
  next to the ground-truth mesh, and finally scored against it.

  Args:
    proto: A ``(key, serialized_results)`` pair; the payload is parsed with
      ``results_pb2.Results.FromString``.

  Returns:
    ``[metrics]`` where ``metrics`` maps 'key', 'Normal Consistency',
    'F-Score (tau)', 'F-Score (2*tau)' and 'Chamfer' to their values.
    Returns ``[]`` when loading or transforming the prediction mesh fails;
    the error is logged and the record is skipped.
  """
  key, s = proto
  p = results_pb2.Results.FromString(s)
  # occnet_dir is concatenated without a separator here, so it presumably
  # ends with '/'.
  mesh_path = FLAGS.occnet_dir + key.replace('test/', '') + '.ply'
  log.warning('Mesh path: %s' % mesh_path)
  try:
    mesh = file_util.read_mesh(mesh_path)
    if FLAGS.transform:
      # TODO(ldif-user) Set up the path to the transformation:
      tx_path = 'ROOT_DIR/%s/occnet_to_gaps.txt' % key
      occnet_to_gaps = file_util.read_txt_to_np(tx_path).reshape([4, 4])
      # Invert occnet->gaps to obtain the gaps->occnet transform.
      gaps_to_occnet = np.linalg.inv(occnet_to_gaps)
      mesh.apply_transform(gaps_to_occnet)
  # pylint: disable=broad-except
  except Exception as e:
    # pylint: enable=broad-except
    # Best-effort: a missing or malformed prediction only skips this record.
    log.error("Couldn't load %s, skipping due to %s." % (mesh_path, repr(e)))
    return []

  gt_mesh = mesh_util.deserialize(p.gt_mesh)
  dir_out = FLAGS.occnet_dir + '/metrics-out-gt/%s' % key
  if not file_util.exists(dir_out):
    file_util.makedirs(dir_out)
  # dir_out has no trailing separator. Join with '/' so the files land inside
  # the directory created above rather than beside it as '<key>gt_mesh.ply'.
  file_util.write_mesh(f'{dir_out}/gt_mesh.ply', gt_mesh)
  file_util.write_mesh(f'{dir_out}/occnet_pred.ply', mesh)

  nc, fst, fs2t, chamfer = metrics.all_mesh_metrics(mesh, gt_mesh)
  return [{
      'key': key,
      'Normal Consistency': nc,
      'F-Score (tau)': fst,
      'F-Score (2*tau)': fs2t,
      'Chamfer': chamfer,
  }]
예제 #4
0
def mesh_metrics(element, *, sample_count=100000, tau=1e-04):
    """Computes the chamfer distance, normal consistency and F-score metrics.

    Args:
      element: A record supporting item get/set. It is mutated in place and
        returned. It must hold the serialized prediction mesh under
        'mesh_str', plus whatever ``element_to_example`` needs to recover the
        ground-truth example.
      sample_count: Number of surface points sampled from each of the
        predicted and ground-truth meshes. Defaults to 100000, the previously
        hard-coded value.
      tau: Threshold passed to ``f_score``; the second F-score uses
        ``2.0 * tau``. Defaults to 1e-4, the previously hard-coded value.
        NOTE(review): whether this is a distance or a squared distance
        depends on ``f_score`` — confirm.

    Returns:
      ``element`` with 'chamfer', 'normal_consistency', 'f_score_tau',
      'f_score_2tau', 'split', 'synset', 'name' and 'class' set.

    Raises:
      ValueError: If 'mesh_str' is empty, or if it deserializes to an empty
        mesh.
    """
    log.info('Metric step input: %s' % repr(element))
    example_np = element_to_example(element)
    if not element['mesh_str']:
        raise ValueError(
            'Empty mesh string encountered for %s but mesh metrics required.' %
            repr(element))
    mesh = mesh_util.deserialize(element['mesh_str'])
    if mesh.is_empty:
        raise ValueError(
            'Empty mesh encountered for %s but mesh metrics required.' %
            repr(element))

    # Sample the same number of points (with per-point normals) from the
    # prediction and the ground truth.
    points_pred, normals_pred = sample_points_and_face_normals(
        mesh, sample_count)
    points_gt, normals_gt = sample_points_and_face_normals(
        example_np.gt_mesh, sample_count)

    # Nearest-neighbor distances and indices in both directions.
    pred_to_gt_dist, pred_to_gt_indices = pointcloud_neighbor_distances_indices(
        points_pred, points_gt)
    gt_to_pred_dist, gt_to_pred_indices = pointcloud_neighbor_distances_indices(
        points_gt, points_pred)

    # Normal of each point's nearest neighbor in the other point cloud.
    pred_to_gt_normals = normals_gt[pred_to_gt_indices]
    gt_to_pred_normals = normals_pred[gt_to_pred_indices]

    # We take abs because the OccNet code takes abs
    pred_to_gt_normal_consistency = np.abs(
        dot_product(normals_pred, pred_to_gt_normals))
    gt_to_pred_normal_consistency = np.abs(
        dot_product(normals_gt, gt_to_pred_normals))

    # Sum of the mean squared neighbor distances in both directions.
    # The 100 factor is because papers multiply by 100 for display purposes.
    chamfer = 100.0 * (np.mean(pred_to_gt_dist**2) +
                       np.mean(gt_to_pred_dist**2))

    # Symmetric average of the two directional consistencies.
    nc = 0.5 * np.mean(pred_to_gt_normal_consistency) + 0.5 * np.mean(
        gt_to_pred_normal_consistency)

    f_score_tau = f_score(pred_to_gt_dist, gt_to_pred_dist, tau)
    f_score_2tau = f_score(pred_to_gt_dist, gt_to_pred_dist, 2.0 * tau)

    element['chamfer'] = chamfer
    element['normal_consistency'] = nc
    element['f_score_tau'] = f_score_tau
    element['f_score_2tau'] = f_score_2tau
    element['split'] = example_np.split
    element['synset'] = example_np.synset
    element['name'] = example_np.mesh_hash
    element['class'] = example_np.cat
    return element
예제 #5
0
def _write_results(proto, xid=None):
    """Writes the prediction, ground truth, and representation to disk.

    Args:
      proto: A ``(key, serialized_results)`` pair; the payload is parsed with
        ``results_pb2.Results.FromString``.
      xid: Optional integer experiment id. When given, outputs go under
        ``<input_dir>/extracted/XID<xid>/<key>/``; otherwise under
        ``<input_dir>/extracted/<key>/``.

    Writes ``gt_mesh.ply``, ``pred_mesh.ply`` and ``sif.txt``. It then writes
    ``nrm_pred_mesh.ply`` and ``nrm_gt_mesh.ply``: the same meshes after
    applying the 4x4 matrix read from the key's ``occnet_to_gaps.txt``.
    """
    key, s = proto
    p = results_pb2.Results.FromString(s)
    if xid is None:
        dir_out = FLAGS.input_dir + '/extracted/' + key + '/'
    else:
        dir_out = FLAGS.input_dir + '/extracted/XID%i/%s/' % (xid, key)
    file_util.makedirs(dir_out)

    # Deserialize once. write_mesh is given mesh objects, as at the
    # normalized-mesh writes below (previously the raw serialized proto
    # fields were passed for the first two writes).
    pred_mesh = mesh_util.deserialize(p.mesh)
    gt_mesh = mesh_util.deserialize(p.gt_mesh)

    # dir_out already ends with '/', so file names are appended directly
    # (f'{dir_out}/...' produced a doubled '//' separator).
    file_util.write_mesh(f'{dir_out}gt_mesh.ply', gt_mesh)
    file_util.write_mesh(f'{dir_out}pred_mesh.ply', pred_mesh)
    file_util.writetxt(f'{dir_out}sif.txt', p.representation)

    # TODO(ldif-user) Set up the unnormalized2normalized path.
    path_to_tx = '/ROOT_DIR/%s/occnet_to_gaps.txt' % key
    occnet_to_gaps = file_util.read_txt_to_np(path_to_tx).reshape([4, 4])

    # apply_transform updates the mesh in place (its result is not used),
    # which is why the untransformed meshes are written above first.
    pred_mesh.apply_transform(occnet_to_gaps)
    file_util.write_mesh(f'{dir_out}nrm_pred_mesh.ply', pred_mesh)
    gt_mesh.apply_transform(occnet_to_gaps)
    file_util.write_mesh(f'{dir_out}nrm_gt_mesh.ply', gt_mesh)
예제 #6
0
def read_mesh(path):
    """Loads the mesh stored at ``path``.

    The whole file is read in binary mode through ``base_util.FS``. The file
    handle is closed before the bytes are passed to
    ``mesh_util.deserialize``, whose result is returned.
    """
    with base_util.FS.open(path, 'rb') as handle:
        raw_bytes = handle.read()
    loaded = mesh_util.deserialize(raw_bytes)
    return loaded