Exemplo n.º 1
0
def bipartite_match(
    distance_mat,
    num_valid_rows,
    top_k=-1,
    name=None):
  """Find a bipartite matching based on a given distance matrix.

  A greedy bipartite matching algorithm is used to obtain the matching with
  the (greedy) minimum distance.

  Args:
    distance_mat: A 2-D float tensor of shape `[num_rows, num_columns]`. It is a
      pair-wise distance matrix between the entities represented by each row and
      each column. It need not be symmetric (`num_rows` and `num_columns` may
      differ). The smaller the distance is, the more similar the pairs are. The
      bipartite matching aims to minimize the distances.
    num_valid_rows: A scalar or a 1-D tensor with one element describing the
      number of valid rows of `distance_mat` to consider for the bipartite
      matching. If negative, all rows of `distance_mat` are used.
    top_k: A scalar that specifies the number of top-k matches to retrieve.
      If negative, it is set according to the maximum number of matches
      possible from `distance_mat`.
    name: A name for the operation (optional). The default `None` lets the op
      wrapper choose its default name, matching the previous behavior.

  Returns:
    The result of `gen_image_ops.bipartite_match`, which carries two outputs:
    row_to_col_match_indices: A vector of length `num_rows`, the number of
      rows of the input `distance_mat`. If `row_to_col_match_indices[i]` is
      not -1, row i is matched to column `row_to_col_match_indices[i]`.
    col_to_row_match_indices: A vector of length `num_columns`, the number of
      columns of the input `distance_mat`. If `col_to_row_match_indices[j]`
      is not -1, column j is matched to row `col_to_row_match_indices[j]`.
  """
  # NOTE(review): assumes the generated wrapper accepts `name=` (standard for
  # TensorFlow-generated op functions) -- confirm against gen_image_ops.
  return gen_image_ops.bipartite_match(
      distance_mat, num_valid_rows, top_k, name=name)
Exemplo n.º 2
0
def bipartite_match(
    distance_mat,
    num_valid_rows,
    top_k=-1,
    name=None):
  """Find a bipartite matching based on a given distance matrix.

  A greedy bipartite matching algorithm is used to obtain the matching with
  the (greedy) minimum distance.

  Args:
    distance_mat: A 2-D float tensor of shape `[num_rows, num_columns]`. It is a
      pair-wise distance matrix between the entities represented by each row and
      each column. It need not be symmetric (`num_rows` and `num_columns` may
      differ). The smaller the distance is, the more similar the pairs are. The
      bipartite matching aims to minimize the distances.
    num_valid_rows: A scalar or a 1-D tensor with one element describing the
      number of valid rows of `distance_mat` to consider for the bipartite
      matching. If negative, all rows of `distance_mat` are used.
    top_k: A scalar that specifies the number of top-k matches to retrieve.
      If negative, it is set according to the maximum number of matches
      possible from `distance_mat`.
    name: A name for the operation (optional). The default `None` lets the op
      wrapper choose its default name, matching the previous behavior.

  Returns:
    The result of `gen_image_ops.bipartite_match`, which carries two outputs:
    row_to_col_match_indices: A vector of length `num_rows`, the number of
      rows of the input `distance_mat`. If `row_to_col_match_indices[i]` is
      not -1, row i is matched to column `row_to_col_match_indices[i]`.
    col_to_row_match_indices: A vector of length `num_columns`, the number of
      columns of the input `distance_mat`. If `col_to_row_match_indices[j]`
      is not -1, column j is matched to row `col_to_row_match_indices[j]`.
  """
  # NOTE(review): assumes the generated wrapper accepts `name=` (standard for
  # TensorFlow-generated op functions) -- confirm against gen_image_ops.
  return gen_image_ops.bipartite_match(
      distance_mat, num_valid_rows, top_k, name=name)