def test_decay(self):
    """Checks that the epsilon decay scheme converges in fewer iterations.

    Both solver runs use the same inputs (``self.x``, ``self.y``,
    ``self.a``, ``self.b``). They differ only in the keyword configuration
    passed in: ``self.decay`` versus ``self.no_decay``. The last element of
    each returned tuple is treated as the number of iterations used.
    """
    inputs = (self.x, self.y, self.a, self.b)
    with_decay = sinkhorn.sinkhorn_iterations(*inputs, **self.decay)
    without_decay = sinkhorn.sinkhorn_iterations(*inputs, **self.no_decay)
    # Decaying epsilon should reach the stopping criterion strictly sooner.
    self.assertLess(with_decay[-1], without_decay[-1])
def test_sinkhorn(self):
    """Checks the transport map built from the Sinkhorn solver outputs.

    Runs ``sinkhorn.sinkhorn_iterations`` with the ``self.decay``
    configuration and feeds its outputs to ``sinkhorn.transport``. It then
    verifies two things:

    - the resulting map's shape is ``self.x.shape`` extended by
      ``self.y.shape[1]``;
    - the solver did not exceed ``self.decay['max_iterations']``.
    """
    outputs = sinkhorn.sinkhorn_iterations(
        self.x, self.y, self.a, self.b, **self.decay)
    # The fifth returned value is not needed by this test.
    f, g, epsilon, cost, _, num_iters = outputs
    plan = sinkhorn.transport(cost, f, g, epsilon)
    expected_shape = self.x.shape + (self.y.shape[1],)
    self.assertTupleEqual(plan.shape, expected_shape)
    self.assertLessEqual(num_iters, self.decay['max_iterations'])