def compute_distance_matrix(starts, ends, distance_fn):
    """Computes all-pair distance matrix.

    Computes distance matrix as:
      [d(s_1, e_1),  d(s_1, e_2),  ...,  d(s_1, e_N)]
      [d(s_2, e_1),  d(s_2, e_2),  ...,  d(s_2, e_N)]
      [...,          ...,          ...,  ...        ]
      [d(s_M, e_1),  d(s_M, e_2),  ...,  d(s_M, e_N)]

    Both inputs are broadcast (by expanding and tiling) to a common shape
    [num_starts, num_ends, ...] so that `distance_fn` can be applied
    element-wise over every (start, end) pair.

    Args:
      starts: A tensor for starts. Shape = [num_starts, ...].
      ends: A tensor for ends. Shape = [num_ends, ...].
      distance_fn: A function handle for computing distance matrix, which takes
        two matrix tensors and returns an element-wise distance matrix tensor.

    Returns:
      A tensor for distance matrix. Shape = [num_starts, num_ends, ...].
    """
    # Read both sizes from the untouched inputs so neither tiling step depends
    # on the other having (or not having) run first.
    num_starts = tf.shape(starts)[0]
    num_ends = tf.shape(ends)[0]
    # [num_starts, ...] -> [num_starts, 1, ...] -> [num_starts, num_ends, ...].
    starts = data_utils.tile_first_dims(
        tf.expand_dims(starts, axis=1), first_dim_multiples=[1, num_ends])
    # [num_ends, ...] -> [1, num_ends, ...] -> [num_starts, num_ends, ...].
    ends = data_utils.tile_first_dims(
        tf.expand_dims(ends, axis=0), first_dim_multiples=[num_starts, 1])
    return distance_fn(starts, ends)
Ejemplo n.º 2
0
 def expand_and_tile_axis_01(x, target_axis, target_dim):
     """Inserts a new axis at position 0 or 1 and repeats it `target_dim` times.

     Args:
       x: Input tensor.
       target_axis: Position of the inserted axis; must be 0 or 1.
       target_dim: Number of copies along the inserted axis.

     Returns:
       The tiled tensor. For `x` of shape [A, ...] the result has shape
       [target_dim, A, ...] when `target_axis` is 0, and
       [A, target_dim, ...] when `target_axis` is 1.

     Raises:
       ValueError: If `target_axis` is neither 0 nor 1.
     """
     if target_axis not in (0, 1):
         raise ValueError('Only supports 0 or 1 as target axis: %s.' %
                          str(target_axis))
     # Repeat only the freshly inserted axis; the other leading axis keeps a
     # multiple of 1.
     multiples = [target_dim if axis == target_axis else 1 for axis in (0, 1)]
     expanded = tf.expand_dims(x, axis=target_axis)
     return data_utils.tile_first_dims(expanded, first_dim_multiples=multiples)
Ejemplo n.º 3
0
 def test_tile_first_dims(self):
     # Input shape is [1, 2, 1]; tiling the two leading dims by [2, 2] should
     # give shape [2, 4, 1] with the pattern repeated along both axes.
     inputs = tf.constant([[[1], [2]]])
     row = [[1], [2], [1], [2]]
     expected = [row, row]
     self.assertAllEqual(
         data_utils.tile_first_dims(inputs, first_dim_multiples=[2, 2]),
         expected)