def test_sinkhorn_knopp(self):
        # Epsilon = 1e-3
        sk = SinkhornKnopp()
        P = np.asarray([[1, 2], [3, 4]])
        Pt = sk.fit(P)
        self.assertFalse(np.all(P == Pt))

        # Make sure we stopped because of epsilon
        self.assertEqual(sk._stopping_condition, 'epsilon')
        col_sum = np.sum(Pt, axis=1)
        row_sum = np.sum(Pt, axis=0)

        self.assertTrue(np.all(1 + sk._epsilon > row_sum))
        self.assertTrue(np.all(1 - sk._epsilon < row_sum))
        self.assertTrue(np.all(1 + sk._epsilon > col_sum))
        self.assertTrue(np.all(1 - sk._epsilon < col_sum))

        # Epsilon = 1e-8
        sk = SinkhornKnopp(epsilon=1e-8)
        P = np.asarray([[1.4, .2, 4], [3, 4, .7], [.4, 6, 1]])
        Pt = sk.fit(P)
        self.assertFalse(np.all(P == Pt))

        # Make sure we stopped because of epsilon
        self.assertEqual(sk._stopping_condition, 'epsilon')

        col_sum = np.sum(Pt, axis=1)
        row_sum = np.sum(Pt, axis=0)

        self.assertTrue(np.all(1 + sk._epsilon > row_sum))
        self.assertTrue(np.all(1 - sk._epsilon < row_sum))
        self.assertTrue(np.all(1 + sk._epsilon > col_sum))
        self.assertTrue(np.all(1 - sk._epsilon < col_sum))
 def test_has_support_false_t_shape_sideways(self):
     with warnings.catch_warnings(record=True) as w:
         warnings.simplefilter("always")
         sk = SinkhornKnopp()
         P = np.asarray([[1, 1, 1, 1], [0, 0, 1, 0], [0, 0, 1, 0],
                         [0, 0, 1, 0]])
         self.assertFalse(sk.has_support(P))
         self.assertEqual(w[0].category, UserWarning)
 def test_diagonal_matrices(self):
     sk = SinkhornKnopp()
     P = np.eye(2)
     Pt = sk.fit(P)
     self.assertTrue(np.all(P == Pt))
     self.assertEqual(sk._D1.ndim, 2)
     self.assertEqual(sk._D2.ndim, 2)
     self.assertTrue(np.all(sk._D1 == P))
     self.assertTrue(np.all(sk._D2 == P))
 def test_diagonal_matrices(self):
     sk = SinkhornKnopp()
     P = np.eye(2)
     Pt = sk.fit(P)
     self.assertTrue(np.all(P == Pt))
     self.assertEqual(sk._D1.ndim, 2)
     self.assertEqual(sk._D2.ndim, 2)
     self.assertTrue(np.all(sk._D1 == P))
     self.assertTrue(np.all(sk._D2 == P))
    def test_stopping_condition(self):
        sk = SinkhornKnopp(max_iter=5)
        self.assertIsNone(sk._stopping_condition)

        P = np.eye(2)
        sk.fit(P)
        self.assertEqual(sk._stopping_condition, 'epsilon')

        P = np.array([[.011, .15], [1.71, .1]])
        sk.fit(P)
        self.assertEqual(sk._stopping_condition, 'max_iter')
        self.assertEqual(sk._iterations, 5)
 def test_support_warning(self):
     """
         A non-negative square is said to have total
         support if A =/= 0 and if every positive element of A
         lies on a positive diagonal. A diagonal of a matrix
         is defined, for any permutation of sigma = {1,...,N},
         as a[1,sigma(1)], ..., a[N,sigma(N)], where a[i, j]
         is an element of A at the ith row and jth column.
     """
     with warnings.catch_warnings(record=True) as w:
         warnings.simplefilter("always")
         sk = SinkhornKnopp()
         P = np.array([[0, 0], [1, 0]])
         sk.fit(P)
         self.assertEqual(w[0].category, UserWarning)
 def test_support_warning(self):
     """
         A non-negative square is said to have total
         support if A =/= 0 and if every positive element of A
         lies on a positive diagonal. A diagonal of a matrix
         is defined, for any permutation of sigma = {1,...,N},
         as a[1,sigma(1)], ..., a[N,sigma(N)], where a[i, j]
         is an element of A at the ith row and jth column.
     """
     with warnings.catch_warnings(record=True) as w:
         warnings.simplefilter("always")
         sk = SinkhornKnopp()
         P = np.array([[0, 0], [1, 0]])
         sk.fit(P)
         self.assertEqual(w[0].category, UserWarning)
    def test_init(self):
        sk = SinkhornKnopp()
        self.assertIsNotNone(sk)

        self.assertEqual(sk._max_iter, 1000)
        self.assertEqual(sk._epsilon, 1e-3)

        self.assertRaises(AssertionError, SinkhornKnopp, max_iter='1')
        self.assertRaises(AssertionError, SinkhornKnopp, epsilon='1')
        self.assertRaises(AssertionError, SinkhornKnopp, max_iter=-1)
        self.assertRaises(AssertionError, SinkhornKnopp, epsilon=1)
        self.assertRaises(AssertionError, SinkhornKnopp, epsilon=0)
    def test_sinkhorn_knopp(self):
        # Epsilon = 1e-3
        sk = SinkhornKnopp()
        P = np.asarray([[1, 2], [3, 4]])
        Pt = sk.fit(P)
        self.assertFalse(np.all(P == Pt))

        # Make sure we stopped because of epsilon
        self.assertEqual(sk._stopping_condition, 'epsilon')
        col_sum = np.sum(Pt, axis=1)
        row_sum = np.sum(Pt, axis=0)

        self.assertTrue(np.all(1 + sk._epsilon > row_sum))
        self.assertTrue(np.all(1 - sk._epsilon < row_sum))
        self.assertTrue(np.all(1 + sk._epsilon > col_sum))
        self.assertTrue(np.all(1 - sk._epsilon < col_sum))

        # Epsilon = 1e-8
        sk = SinkhornKnopp(epsilon=1e-8)
        P = np.asarray([[1.4, .2, 4], [3, 4, .7], [.4, 6, 1]])
        Pt = sk.fit(P)
        self.assertFalse(np.all(P == Pt))

        # Make sure we stopped because of epsilon
        self.assertEqual(sk._stopping_condition, 'epsilon')

        col_sum = np.sum(Pt, axis=1)
        row_sum = np.sum(Pt, axis=0)

        self.assertTrue(np.all(1 + sk._epsilon > row_sum))
        self.assertTrue(np.all(1 - sk._epsilon < row_sum))
        self.assertTrue(np.all(1 + sk._epsilon > col_sum))
        self.assertTrue(np.all(1 - sk._epsilon < col_sum))
    def test_stopping_condition(self):
        sk = SinkhornKnopp(max_iter=5)
        self.assertIsNone(sk._stopping_condition)

        P = np.eye(2)
        sk.fit(P)
        self.assertEqual(sk._stopping_condition, 'epsilon')

        P = np.array([[.011, .15], [1.71, .1]])
        sk.fit(P)
        self.assertEqual(sk._stopping_condition, 'max_iter')
        self.assertEqual(sk._iterations, 5)
 def test_fit_inputs(self):
     sk = SinkhornKnopp()
     self.assertRaises(AssertionError, sk.fit, 1)
     self.assertRaises(AssertionError, sk.fit, [[1, 2], [1, 2], [1, 2]])
     self.assertRaises(AssertionError, sk.fit, [[1, 2], [1, -1]])