def test_decay(self):
    """Epsilon decay should reduce the number of Sinkhorn iterations.

    Two solvers share the same target epsilon, power and threshold. One
    starts at epsilon_0=1e-3 with epsilon_decay=1.0 (no decay). The other
    starts at epsilon_0=1e-1 with epsilon_decay=0.95. The decaying solver is
    expected to report fewer iterations after a single call on the same
    fixture inputs.
    """
    common = dict(epsilon=1e-3, power=2.0, threshold=1e-3)
    constant = sinkhorn.Sinkhorn1D(epsilon_0=1e-3, epsilon_decay=1.0, **common)
    annealed = sinkhorn.Sinkhorn1D(epsilon_0=1e-1, epsilon_decay=0.95, **common)

    fixture = (self.x, self.y, self.a, self.b)
    # Run the non-decaying solver first, exactly as the original test did.
    constant(*fixture)
    annealed(*fixture)
    self.assertLess(annealed.iterations, constant.iterations)
Example no. 2
0
  def __init__(
      self, x=None, weights=None, num_targets=None, target_weights=None, y=None,
      descending=False, scale_input_fn=squash.group_rescale, **kwargs):
    """Builds the SoftSorter's internal state and loads the initial inputs.

    Args:
     x: the Tensor<float>[batch, n] to be soft-sorted.
     weights: Tensor<float>[n] or None; a batch dimension may be included.
      Uniform weights recover the usual sorting behaviour. When None, uniform
      weights are used.
     num_targets: only used when y is not given, to build a uniform target
      vector on [0, 1]. Defaults to n. A value smaller than n quantizes the
      input vector.
     target_weights: per-element weights for y. Uniform by default.
     y: Tensor<float>[m] or None; a batch dimension may be included. Leaving
      this as None is recommended in most cases. If provided, the list, array
      or tensor must be sorted in increasing order for a soft sort to be
      performed. When None, it is set to the num_targets values
      [0, 1/(num_targets-1), ..., 1], copied N times (N presumably being the
      batch size).
     descending: (bool) if True, the targets are reversed so that the result
      is a descending sort.
     scale_input_fn: function that rescales the input entries to fit in the
      [0, 1] segment. Besides stabilizing computations, this keeps the
      regularization parameter epsilon meaningful across gradient iterations,
      whatever the range of the input values.
     **kwargs: extra parameters forwarded to the Sinkhorn algorithm.
    """
    self._scale_input_fn = scale_input_fn
    self._descending = descending
    self.iterations = 0
    # Every extra keyword argument goes verbatim to the 1-D Sinkhorn solver.
    solver = sinkhorn.Sinkhorn1D(**kwargs)
    self._sinkhorn = solver
    # reset() receives the inputs positionally in (x, y, weights,
    # target_weights, num_targets) order, which differs from this signature.
    self.reset(x, y, weights, target_weights, num_targets)
 def test_routine(self):
   """A default-configured Sinkhorn1D should yield an output of shape [1, 8, 8]."""
   expected_shape = [1, 8, 8]
   result = sinkhorn.Sinkhorn1D()(self.x, self.y, self.a, self.b)
   self.assertEqual(result.shape.as_list(), expected_shape)