def softranks(self, axis, direction):
    """Test ops.softranks for a given shape, axis and direction.

    Draws one random zero-based permutation per row, lays the result out as a
    (3, 8, 6) tensor via ops._postprocess, and maps it to values with an affine
    map whose sign depends on `direction`. The soft ranks of those values along
    `axis` are then checked against the permutations, once 1-based and once
    0-based, with a loose tolerance.
    """
    shape = tf.TensorShape((3, 8, 6))
    n = shape[axis]
    num_rows = int(np.prod(shape) / n)

    # Rank-2 target: one zero-based permutation of range(n) per row.
    permutations = [np.random.permutation(n) for _ in range(num_rows)]
    target = tf.constant(permutations, dtype=tf.float32)
    # Bring it to the desired shape (layout handled by ops._postprocess).
    target = ops._postprocess(target, shape, axis)

    # Monotonic transformation of the ranks into values: increasing when
    # ASCENDING, decreasing otherwise.
    sign = 1.0 if direction == 'ASCENDING' else -1.0
    x = sign * (1.2 * target - 0.4)

    # The soft ranks of x along the axis should be close to the target.
    epsilon = 1e-3
    threshold = 1e-3
    tol = 0.5
    for zero_based in (False, True):
        expected = target if zero_based else target + 1
        soft = ops.softranks(
            x,
            direction=direction,
            axis=axis,
            zero_based=zero_based,
            epsilon=epsilon,
            sinkhorn_threshold=threshold)
        self.assertAllClose(soft, expected, rtol=tol, atol=tol)
def _soft_topk_accuracy(self, y_true, y_pred):
    """Computes the soft topk accuracy of the prediction w.r.t the true values.

    With zero-based ascending soft ranks r over n activations and k = topk,
    each activation scores min(1, r / (n - k)). Ranks at or above n - k
    (the k largest activations) score 1, and the score decreases linearly to
    0 as the rank goes to 0. When k >= n every label is trivially in the
    top-k, so every score is 1. This avoids the division by zero (nan/inf) or
    by a negative number the plain ratio would produce.

    Args:
      y_true: Tensor<float>[batch]: the true labels in [0, n-1].
      y_pred: Tensor<float>[batch, n]: n activation values for each input.

    Returns:
      A Tensor<float>[batch] with the soft accuracy of each example.
    """
    dtype = y_pred.dtype
    num_activations = tf.shape(y_pred)[-1]
    topk = tf.cast(self._topk, dtype=dtype)
    ranks = ops.softranks(y_pred, direction='ASCENDING', axis=-1,
                          zero_based=True, **self._kwargs)

    # Ranks at or above n - topk give an accuracy of 1; below that threshold
    # the accuracy decreases linearly to 0.
    denominator = tf.cast(num_activations, dtype=dtype) - topk
    has_room = denominator > 0
    # Substitute a harmless positive denominator when k >= n. The discarded
    # tf.where branch then stays finite, so neither values nor gradients
    # pick up nan/inf.
    safe_denominator = tf.where(has_room, denominator,
                                tf.ones_like(denominator))
    accuracies = tf.where(
        tf.broadcast_to(has_room, tf.shape(ranks)),
        tf.math.minimum(1.0, ranks / safe_denominator),
        tf.ones_like(ranks))

    # Multiply with the one hot encoding of the label to select only the soft
    # topk accuracy of the true labels.
    # NOTE(review): assumes y_true has shape [batch] as documented; a
    # [batch, 1] label tensor would broadcast incorrectly here - confirm.
    true_labels = tf.one_hot(tf.cast(y_true, dtype=tf.int32),
                             depth=num_activations, dtype=dtype)
    return tf.reduce_sum(accuracies * true_labels, axis=-1)
def call(self, inputs):
    """Returns the soft ranks of `inputs` along `self._axis`.

    Extra options in `self._kwargs` are forwarded to ops.softranks. The result
    is reshaped to the dynamic shape of `inputs`, so the output has the same
    shape as the input.
    """
    ranked = ops.softranks(inputs, axis=self._axis, **self._kwargs)
    input_shape = tf.shape(inputs)
    return tf.reshape(ranked, input_shape)
def get_ranks(self, y):
    """Returns zero-based, ascending soft ranks of `y` along its last axis.

    Additional options stored in `self._kwargs` are forwarded to
    ops.softranks.
    """
    soft_ranks = ops.softranks(
        y,
        direction='ASCENDING',
        axis=-1,
        zero_based=True,
        **self._kwargs)
    return soft_ranks