def test_softranks(self, axis, direction):
    """Test ops.softranks for a given shape, axis and direction.

    Builds a tensor of known zero-based ranks along ``axis``, maps it through a
    strictly monotonic affine transform and checks that ``ops.softranks``
    recovers those ranks (both one-based and zero-based) within a loose
    tolerance.

    Args:
      axis: axis of the (3, 8, 6) test tensor along which ranks are computed.
      direction: ranking direction forwarded to ``ops.softranks``; the test
        builds increasing values when it equals 'ASCENDING' and decreasing
        values otherwise.
    """
    shape = tf.TensorShape((3, 8, 6))
    n = shape[axis]
    # Number of independent rank vectors: product of all the other dims.
    p = shape.num_elements() // int(n)

    # Build a (p, n) target tensor of ranks whose rows are zero-based
    # permutations. A fixed-seed local generator keeps the test reproducible
    # without touching numpy's global random state.
    rng = np.random.RandomState(0)
    target = tf.constant([rng.permutation(n) for _ in range(p)],
                         dtype=tf.float32)

    # Permutation swapping `axis` with the last axis, and the static shape of
    # the input once transposed by it (computed directly rather than by
    # transposing a dummy tensor).
    dims = np.arange(shape.rank)
    dims[axis], dims[-1] = dims[-1], dims[axis]
    transposed_shape = tf.TensorShape([shape[int(d)] for d in dims])
    # Turn the (p, n) targets into a tensor of the desired shape.
    target = ops.postprocess(target, dims, transposed_shape)

    # Apply a strictly monotonic transformation to turn ranks into values:
    # increasing for 'ASCENDING', decreasing otherwise, so that ranking x in
    # `direction` should give back `target`.
    sign = 1.0 if direction == 'ASCENDING' else -1.0
    x = sign * (1.2 * target - 0.4)

    # The softranks of x along the axis should be close to the target.
    eps = 1e-3
    sinkhorn_threshold = 1e-3
    tolerance = 0.5
    for zero_based in (False, True):
        ranks = ops.softranks(
            x,
            direction=direction,
            axis=axis,
            zero_based=zero_based,
            epsilon=eps,
            threshold=sinkhorn_threshold)
        # `target` holds zero-based ranks; one-based ranks are shifted by one.
        expected = target if zero_based else target + 1
        self.assertAllClose(ranks, expected, rtol=tolerance, atol=tolerance)
def test_postprocess(self):
    """Tests that postprocess is the inverse of preprocess.

    For every prefix of (4, 21, 7, 10) -- from rank 1 up to and including the
    full rank-4 shape -- and for every axis of a random tensor of that shape,
    checks that ``ops.postprocess`` exactly undoes ``ops.preprocess``.
    """
    shape = (4, 21, 7, 10)
    # `len(shape) + 1` so the full shape (and its last dimension) is covered.
    for rank in range(1, len(shape) + 1):
        x = tf.random.uniform(shape[:rank])
        for axis in range(rank):
            y, transposition, transposed_shape = ops.preprocess(x, axis)
            z = ops.postprocess(y, transposition, transposed_shape)
            self.assertAllEqual(x, z)