def compute_distance_matrix(starts, ends, distance_fn):
    """Computes all-pair distance matrix.

    Computes distance matrix as:
      [d(s_1, e_1), d(s_1, e_2), ..., d(s_1, e_N)]
      [d(s_2, e_1), d(s_2, e_2), ..., d(s_2, e_N)]
      [...,         ...,         ..., ...        ]
      [d(s_M, e_1), d(s_M, e_2), ..., d(s_M, e_N)]

    Args:
      starts: A tensor for starts. Shape = [num_starts, ...].
      ends: A tensor for ends. Shape = [num_ends, ...].
      distance_fn: A function handle for computing distance matrix. It is
        called with two tensors that have been broadcast (by tiling) to
        Shape = [num_starts, num_ends, ...], and returns an element-wise
        distance matrix tensor.

    Returns:
      A tensor for distance matrix. Shape = [num_starts, num_ends, ...].
    """
    # Read the pair counts from the un-expanded inputs so the tiling
    # multiples do not depend on the order of the reshaping steps below.
    num_starts = tf.shape(starts)[0]
    num_ends = tf.shape(ends)[0]

    # [num_starts, ...] -> [num_starts, 1, ...] -> [num_starts, num_ends, ...].
    tiled_starts = data_utils.tile_first_dims(
        tf.expand_dims(starts, axis=1), first_dim_multiples=[1, num_ends])
    # [num_ends, ...] -> [1, num_ends, ...] -> [num_starts, num_ends, ...].
    tiled_ends = data_utils.tile_first_dims(
        tf.expand_dims(ends, axis=0), first_dim_multiples=[num_starts, 1])
    return distance_fn(tiled_starts, tiled_ends)
def expand_and_tile_axis_01(x, target_axis, target_dim):
    """Expands and tiles tensor along target axis 0 or 1.

    A new axis is inserted at `target_axis` with `tf.expand_dims`. The result
    is then passed to `data_utils.tile_first_dims`, using multiple
    `target_dim` at `target_axis` and 1 at the other of the first two dims.

    Args:
      x: Input tensor.
      target_axis: Position of the new axis; must be 0 or 1.
      target_dim: Tiling multiple applied along the new axis.

    Returns:
      The tiled tensor produced by `data_utils.tile_first_dims`.

    Raises:
      ValueError: If `target_axis` is not 0 or 1.
    """
    if target_axis not in (0, 1):
        raise ValueError(
            'Only supports 0 or 1 as target axis: %s.' % str(target_axis))

    expanded = tf.expand_dims(x, axis=target_axis)

    # Only the freshly inserted axis is repeated; the other leading dim keeps 1.
    multiples = [1] * 2
    multiples[target_axis] = target_dim
    return data_utils.tile_first_dims(expanded, first_dim_multiples=multiples)
def test_tile_first_dims(self):
    """Checks tile_first_dims repeats the two leading dims of a tensor."""
    # Input has shape [1, 2, 1]; tiling by [2, 2] should give shape [2, 4, 1].
    inputs = tf.constant([[[1], [2]]])
    expected = [
        [[1], [2], [1], [2]],
        [[1], [2], [1], [2]],
    ]
    outputs = data_utils.tile_first_dims(inputs, first_dim_multiples=[2, 2])
    self.assertAllEqual(outputs, expected)