def test_batch_gather_by_one_hot_with_indices_out_of_range(self):
    """Checks the gathered rows when some indices are out of range.

    Each batch element of `params` holds 4 rows, so the indices -1 and 4
    are out of range; the expected output has `[0.0, 0.0]` at exactly
    those positions.
    """
    params = tf.constant([
        [[0.0, 0.0], [0.1, -0.1], [0.2, -0.2], [0.3, -0.3]],
        [[1.0, -1.0], [1.1, -1.1], [1.2, -1.2], [1.3, -1.3]],
    ])
    indices = [[2, -1, 1], [0, 3, 4]]
    expected = [
        [[0.2, -0.2], [0.0, 0.0], [0.1, -0.1]],
        [[1.0, -1.0], [1.3, -1.3], [0.0, 0.0]],
    ]

    self.assertAllClose(
        expected, tensor_utils.batch_gather_by_one_hot(params, indices))
def test_batch_gather_by_one_hot_2_batch_dims_plus_2d_params(
        self, gather_fn):
    """Checks gathering with two batch dimensions over per-batch 2-D params.

    Args:
        gather_fn: Gather implementation to exercise in the second scenario;
            it is called as `gather_fn(params, indices, batch_dims=2)`.
            Presumably injected by a parameterized-test decorator that is
            outside this view.
    """
    params = tf.constant([
        [
            [[0.0, 0.0], [0.1, -0.1], [0.2, -0.2], [0.3, -0.3]],
            [[1.0, -1.0], [1.1, -1.1], [1.2, -1.2], [1.3, -1.3]],
        ],
        [
            [[2.0, -2.0], [2.1, -2.1], [2.2, -2.2], [2.3, -2.3]],
            [[3.0, -3.0], [3.1, -3.1], [3.2, -3.2], [3.3, -3.3]],
        ],
    ])

    # Scenario 1: one index per batch entry. `tf.gather` does not handle
    # this case, so the one-hot implementation is called directly instead
    # of going through `gather_fn`.
    single_indices = [[2, 1], [2, 3]]
    single_expected = [
        [[0.2, -0.2], [1.1, -1.1]],
        [[2.2, -2.2], [3.3, -3.3]],
    ]
    self.assertAllClose(
        single_expected,
        tensor_utils.batch_gather_by_one_hot(
            params, single_indices, batch_dims=2))

    # Scenario 2: two indices per batch entry, via the supplied `gather_fn`.
    pair_indices = [
        [[2, 0], [2, 3]],
        [[1, 0], [2, 1]],
    ]
    pair_expected = [
        [
            [[0.2, -0.2], [0.0, 0.0]],
            [[1.2, -1.2], [1.3, -1.3]],
        ],
        [
            [[2.1, -2.1], [2.0, -2.0]],
            [[3.2, -3.2], [3.1, -3.1]],
        ],
    ]
    self.assertAllClose(
        pair_expected, gather_fn(params, pair_indices, batch_dims=2))
def test_batch_gather_by_one_hot_1_batch_dim_plus_1d_params(
        self, gather_fn):
    """Checks gathering with one batch dimension over per-batch 1-D params.

    Args:
        gather_fn: Gather implementation to exercise in the second and third
            scenarios; it is called as
            `gather_fn(params, indices, batch_dims=1)`. Presumably injected
            by a parameterized-test decorator that is outside this view.
    """
    params = tf.constant([
        [0.0, 0.1, 0.2, 0.3],
        [1.0, 1.1, 1.2, 1.3],
    ])

    # Scenario 1: a single scalar index per batch entry. `tf.gather` does
    # not handle this case, so the one-hot implementation is called
    # directly instead of going through `gather_fn`.
    scalar_result = tensor_utils.batch_gather_by_one_hot(
        params, [2, 1], batch_dims=1)
    self.assertAllClose([0.2, 1.1], scalar_result)

    # Scenario 2: a vector of indices per batch entry.
    vector_indices = [[2, 0], [2, 3]]
    vector_expected = [[0.2, 0.0], [1.2, 1.3]]
    self.assertAllClose(
        vector_expected, gather_fn(params, vector_indices, batch_dims=1))

    # Scenario 3: a matrix of indices per batch entry.
    matrix_indices = [
        [[2, 0], [2, 3]],
        [[1, 0], [2, 1]],
    ]
    matrix_expected = [
        [[0.2, 0], [0.2, 0.3]],
        [[1.1, 1.0], [1.2, 1.1]],
    ]
    self.assertAllClose(
        matrix_expected, gather_fn(params, matrix_indices, batch_dims=1))
def test_batch_gather_by_one_hot_with_default_batch_dim(self):
    """Checks the result when `batch_dims` is not passed explicitly."""
    params = tf.constant([
        [[0.0, 0.0], [0.1, -0.1], [0.2, -0.2], [0.3, -0.3]],
        [[1.0, -1.0], [1.1, -1.1], [1.2, -1.2], [1.3, -1.3]],
    ])
    indices = [[2, 0], [2, 3]]
    expected = [
        [[0.2, -0.2], [0.0, 0.0]],
        [[1.2, -1.2], [1.3, -1.3]],
    ]

    # No `batch_dims` argument is given; it should be inferred as 1.
    self.assertAllClose(
        expected, tensor_utils.batch_gather_by_one_hot(params, indices))