def test_decay(self):
     """Checks that the epsilon decay scheme makes Sinkhorn converge faster.

     Runs `sinkhorn.sinkhorn_iterations` on the same inputs twice, once with
     the `self.decay` options and once with `self.no_decay`. The decayed run
     must need strictly fewer iterations.
     """
     def count_iterations(options):
         # The iteration count is the final element of the returned tuple.
         return sinkhorn.sinkhorn_iterations(
             self.x, self.y, self.a, self.b, **options)[-1]

     with_decay = count_iterations(self.decay)
     without_decay = count_iterations(self.no_decay)
     self.assertLess(with_decay, without_decay)
 def test_sinkhorn(self):
     """Checks the transport map shape and the iteration budget.

     Runs `sinkhorn.sinkhorn_iterations` with the `self.decay` options, then
     builds a transport map with `sinkhorn.transport`. It verifies that:
       * the map's shape equals `self.x.shape` extended by `self.y.shape[1]`;
       * the solver stopped within `self.decay['max_iterations']` iterations.
     """
     solver_outputs = sinkhorn.sinkhorn_iterations(
         self.x, self.y, self.a, self.b, **self.decay)
     # The fifth output is not needed by this test.
     f, g, eps, cost, _, num_iterations = solver_outputs

     transport_map = sinkhorn.transport(cost, f, g, eps)
     actual_shape = transport_map.shape
     expected_shape = self.x.shape + (self.y.shape[1], )
     self.assertTupleEqual(actual_shape, expected_shape)

     iteration_budget = self.decay['max_iterations']
     self.assertLessEqual(num_iterations, iteration_budget)