    def test_symmetry_breaking_multiclass(self):
        mu = np.array([
            # LF 0
            [0.75, 0.15, 0.1],
            [0.20, 0.75, 0.3],
            [0.05, 0.10, 0.6],
            # LF 1
            [0.25, 0.55, 0.3],
            [0.15, 0.45, 0.4],
            [0.20, 0.00, 0.3],
            # LF 2
            [0.5, 0.15, 0.2],
            [0.3, 0.65, 0.2],
            [0.2, 0.20, 0.6],
        ])
        mu = mu[:, [1, 2, 0]]  # Cyclically permute the columns to simulate a permuted solution

        # First test: uniform class balance, two "good" LFs (LF 0 and LF 2)
        label_model = LabelModel(cardinality=3, verbose=False)
        label_model._set_class_balance(None, None)
        label_model.m = 3
        label_model.mu = nn.Parameter(torch.from_numpy(mu))
        label_model._break_col_permutation_symmetry()
        self.assertEqual(label_model.mu.data[0, 0], 0.75)
        self.assertEqual(label_model.mu.data[1, 1], 0.75)

        # Second test: non-uniform class balance
        # The "correct" permutation does not commute with the class balance,
        # so it should not be considered
        label_model = LabelModel(cardinality=3, verbose=False)
        label_model._set_class_balance([0.7, 0.2, 0.1], None)
        label_model.m = 3
        label_model.mu = nn.Parameter(torch.from_numpy(mu))
        label_model._break_col_permutation_symmetry()
        self.assertEqual(label_model.mu.data[0, 0], 0.15)
        self.assertEqual(label_model.mu.data[1, 1], 0.3)
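
    # Illustrative sketch (plain numpy, no LabelModel internals assumed): under
    # a uniform class balance P = I / k, the observable matrix O = mu @ P @ mu.T
    # is identical for every column permutation of mu, which is why
    # _break_col_permutation_symmetry has to pick a canonical column ordering.
    def test_column_permutation_invariance_sketch(self):
        mu = np.array([[0.75, 0.15, 0.1], [0.20, 0.75, 0.3], [0.05, 0.10, 0.6]])
        P = np.eye(3) / 3
        O = mu @ P @ mu.T
        O_permuted = mu[:, [1, 2, 0]] @ P @ mu[:, [1, 2, 0]].T
        np.testing.assert_array_almost_equal(O, O_permuted)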

    def test_symmetry_breaking(self):
        mu = np.array([
            # LF 0
            [0.75, 0.25],
            [0.25, 0.75],
            # LF 1
            [0.25, 0.75],
            [0.15, 0.25],
            # LF 2
            [0.75, 0.25],
            [0.25, 0.75],
        ])
        mu = mu[:, [1, 0]]  # Swap the columns to simulate a permuted solution

        # First test: uniform class balance, two "good" LFs (LF 0 and LF 2)
        label_model = LabelModel(verbose=False)
        label_model._set_class_balance(None, None)
        label_model.m = 3
        label_model.mu = nn.Parameter(torch.from_numpy(mu))
        label_model._break_col_permutation_symmetry()
        self.assertEqual(label_model.mu.data[0, 0], 0.75)

        # Second test: non-uniform class balance
        # The "correct" permutation does not commute with the class balance,
        # so it should not be considered
        label_model = LabelModel(verbose=False)
        label_model._set_class_balance([0.9, 0.1], None)
        label_model.m = 3
        label_model.mu = nn.Parameter(torch.from_numpy(mu))
        label_model._break_col_permutation_symmetry()
        self.assertEqual(label_model.mu.data[0, 0], 0.25)
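
    # Illustrative sketch (plain numpy): with a non-uniform class balance the
    # column swap used above no longer leaves O = mu @ P @ mu.T unchanged, so
    # the swapped solution is distinguishable from the original one.
    def test_non_uniform_balance_sketch(self):
        mu = np.array([[0.75, 0.25], [0.25, 0.75]])
        P = np.diag([0.9, 0.1])
        O = mu @ P @ mu.T
        O_swapped = mu[:, [1, 0]] @ P @ mu[:, [1, 0]].T
        self.assertFalse(np.allclose(O, O_swapped))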

    def test_score(self):
        L = np.array([[1, 1, 0], [-1, -1, -1], [1, 0, 1]])
        Y = np.array([1, 0, 1])
        label_model = LabelModel(cardinality=2, verbose=False)
        label_model.fit(L, n_epochs=100)
        results = label_model.score(L, Y, metrics=["accuracy", "coverage"])
        np.testing.assert_array_almost_equal(label_model.predict(L),
                                             np.array([1, -1, 1]))

        results_expected = dict(accuracy=1.0, coverage=2 / 3)
        self.assertEqual(results, results_expected)

        L = np.array([[1, 0, 1], [1, 0, 1]])
        label_model = self._set_up_model(L)
        label_model.mu = nn.Parameter(label_model.mu_init.clone().clamp(
            0.01, 0.99))

        results = label_model.score(L, Y=np.array([0, 1]))
        results_expected = dict(accuracy=0.5)
        self.assertEqual(results, results_expected)

        results = label_model.score(L=L,
                                    Y=np.array([1, 0]),
                                    metrics=["accuracy", "f1"])
        results_expected = dict(accuracy=0.5, f1=2 / 3)
        self.assertEqual(results, results_expected)
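
    # Illustrative arithmetic for the scores above (plain numpy): coverage is
    # the fraction of non-abstain predictions, and accuracy is taken over the
    # covered points only, so predictions [1, -1, 1] against Y = [1, 0, 1]
    # give accuracy 1.0 at coverage 2/3.
    def test_score_arithmetic_sketch(self):
        preds = np.array([1, -1, 1])
        Y = np.array([1, 0, 1])
        covered = preds != -1
        self.assertAlmostEqual(covered.mean(), 2 / 3)
        self.assertEqual((preds[covered] == Y[covered]).mean(), 1.0)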

    def test_loss(self):
        L = np.array([[0, -1, 0], [0, 1, -1]])
        label_model = LabelModel(cardinality=2, verbose=False)
        label_model.fit(L, n_epochs=1)
        label_model.mu = nn.Parameter(label_model.mu_init.clone() + 0.05)

        # mu has shape (M*K, K) = (6, 2), so with every entry shifted by 0.05:
        # l2_loss = l2 * ||mu - mu_init||_F^2 = 12 * 0.05^2 = 0.03
        self.assertAlmostEqual(label_model._loss_l2(l2=1.0).item(), 0.03)
        self.assertAlmostEqual(label_model._loss_l2(l2=np.ones(6)).item(), 0.03)

        # mu_loss = ||O - \mu P \mu^T||^2 + ||(\mu P) 1 - diag(O)||^2
        self.assertAlmostEqual(label_model._loss_mu().item(), 0.675, 3)
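
    # Worked check of the l2 arithmetic above (plain torch, no LabelModel
    # internals): a (6, 2) parameter shifted by 0.05 in every entry sits at
    # squared Frobenius distance 12 * 0.05^2 = 0.03 from its initialisation.
    def test_l2_arithmetic_sketch(self):
        diff = torch.full((6, 2), 0.05, dtype=torch.float64)
        self.assertAlmostEqual((torch.norm(diff) ** 2).item(), 0.03)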

    def test_predict(self):
        # Three LFs that always disagree or abstain lead to all-abstain predictions
        L = np.array([[-1, 1, 0], [0, -1, 1], [1, 0, -1]])
        label_model = LabelModel(cardinality=2, verbose=False)
        label_model.fit(L, n_epochs=100)
        np.testing.assert_array_almost_equal(label_model.predict(L),
                                             np.array([-1, -1, -1]))

        L = np.array([[0, 1, 0], [0, 1, 0]])
        label_model = self._set_up_model(L)

        label_model.mu = nn.Parameter(label_model.mu_init.clone().clamp(
            0.01, 0.99))
        preds = label_model.predict(L)

        true_preds = np.array([0, 0])
        np.testing.assert_array_equal(preds, true_preds)

        preds, probs = label_model.predict(L, return_probs=True)
        true_probs = np.array([[0.99, 0.01], [0.99, 0.01]])
        np.testing.assert_array_almost_equal(probs, true_probs)
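
    # Illustrative arithmetic for the probability check above (plain numpy):
    # with no ties, the predicted label is the argmax of each probability row,
    # so rows of [0.99, 0.01] map to class 0.
    def test_probs_to_preds_sketch(self):
        probs = np.array([[0.99, 0.01], [0.99, 0.01]])
        np.testing.assert_array_equal(probs.argmax(axis=1), np.array([0, 0]))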