Пример #1
0
    def test_softranks(self, axis, direction):
        """Test ops.softranks for a given shape, axis and direction.

        One independent random permutation of ``0..n-1`` is laid out along
        ``axis`` of a (3, 8, 6) tensor. It is mapped through a strictly
        monotonic transform, and the soft ranks of the result are compared
        against the permutation. Both zero-based and one-based ranks are
        checked.

        Args:
          axis: axis of the (3, 8, 6) test tensor along which ranks are taken.
          direction: passed through to ops.softranks. The transform increases
            when direction == 'ASCENDING' and decreases otherwise (presumably
            'DESCENDING' — confirm against ops.softranks).
        """
        shape = tf.TensorShape((3, 8, 6))
        n = shape[axis]
        # Number of independent 1-D slices along `axis`. Integer division keeps
        # the count exact instead of round-tripping through a float.
        p = shape.num_elements() // n

        # Build a (p, n) target tensor: each row is a zero-based permutation,
        # i.e. the exact ranks of the values generated for that row below.
        target = tf.constant([np.random.permutation(n) for _ in range(p)],
                             dtype=tf.float32)

        # Permutation swapping `axis` with the last axis (its own inverse).
        # `transposed_shape` is the test shape with `axis` moved last; computed
        # directly from the permutation rather than by transposing a dummy
        # tensor. ops.postprocess (the inverse of ops.preprocess) is expected
        # to fold the (p, n) rows back into `shape` with ranks along `axis`.
        dims = np.arange(shape.rank)
        dims[axis], dims[-1] = dims[-1], dims[axis]
        sizes = shape.as_list()
        transposed_shape = tf.TensorShape([sizes[d] for d in dims])
        target = ops.postprocess(target, dims, transposed_shape)

        # Strictly monotonic map from ranks to values: increasing for
        # 'ASCENDING', decreasing otherwise, so that in the requested direction
        # the hard ranks of x are exactly `target`.
        sign = 1.0 if direction == 'ASCENDING' else -1.0
        x = sign * (1.2 * target - 0.4)

        # Settings forwarded to ops.softranks. Soft ranks only approximate hard
        # ranks, hence the loose comparison tolerance (used as rtol and atol).
        eps = 1e-3
        sinkhorn_threshold = 1e-3
        tolerance = 0.5
        for zero_based in (False, True):
            ranks = ops.softranks(x,
                                  direction=direction,
                                  axis=axis,
                                  zero_based=zero_based,
                                  epsilon=eps,
                                  threshold=sinkhorn_threshold)
            # `target` holds zero-based ranks; shift by one for one-based output.
            expected = target if zero_based else target + 1
            self.assertAllClose(ranks, expected, tolerance, tolerance)
Пример #2
0
 def test_postprocess(self):
     """Tests that postprocess is the inverse of preprocess.

     Random tensors of every rank from 1 up to len(shape) are built from
     prefixes of `shape`. Each one is round-tripped along each of its axes
     through ops.preprocess and ops.postprocess, and must come back unchanged.
     """
     shape = (4, 21, 7, 10)
     # Upper bound is len(shape) + 1 so that the full-rank shape (and its last
     # dimension) is exercised too, not just ranks 1..len(shape) - 1.
     for rank in range(1, len(shape) + 1):
         x = tf.random.uniform(shape[:rank])
         for axis in range(rank):
             y, transp, s = ops.preprocess(x, axis)
             z = ops.postprocess(y, transp, s)
             # preprocess/postprocess only rearrange values, so equality is exact.
             self.assertAllEqual(x, z)