def test_symmetry_breaking_multiclass(self):
        """Check column-permutation symmetry breaking for a 3-class LabelModel.

        The columns of ``mu`` are shuffled up front. With a uniform prior the
        shuffle is expected to be undone; with a non-uniform prior the
        shuffled ``mu`` is expected to be left as supplied.
        """
        # One 3x3 block per LF, stacked into a single (9, 3) matrix.
        per_lf_blocks = [
            # LF 0
            [[0.75, 0.15, 0.1], [0.20, 0.75, 0.3], [0.05, 0.10, 0.6]],
            # LF 1
            [[0.25, 0.55, 0.3], [0.15, 0.45, 0.4], [0.20, 0.00, 0.3]],
            # LF 2
            [[0.5, 0.15, 0.2], [0.3, 0.65, 0.2], [0.2, 0.20, 0.6]],
        ]
        # Shuffle the columns so the model has to recover the ordering itself.
        column_shuffle = [1, 2, 0]
        mu = np.vstack(per_lf_blocks)[:, column_shuffle]

        def mu_after_symmetry_breaking(class_balance):
            # Fresh model per scenario. torch.from_numpy shares memory with
            # ``mu``, so both scenarios start from the very same array.
            model = LabelModel(cardinality=3, verbose=False)
            model._set_class_balance(class_balance, None)
            model.m = 3
            model.mu = nn.Parameter(torch.from_numpy(mu))
            model._break_col_permutation_symmetry()
            return model.mu.data

        # Uniform prior. LFs 0 and 2 are diagonally dominant ("good") and LF 1
        # is not. LF 0's 0.75 entries must land back on the diagonal.
        recovered = mu_after_symmetry_breaking(None)
        self.assertEqual(recovered[0, 0], 0.75)
        self.assertEqual(recovered[1, 1], 0.75)

        # Non-uniform prior: the "correct" permutation does not commute with
        # it, so it is not considered. The expected entries equal the shuffled
        # input, i.e. the columns are left as supplied.
        kept = mu_after_symmetry_breaking([0.7, 0.2, 0.1])
        self.assertEqual(kept[0, 0], 0.15)
        self.assertEqual(kept[1, 1], 0.3)
    def test_symmetry_breaking(self):
        """Check column-permutation symmetry breaking for a binary LabelModel.

        The two columns of ``mu`` are swapped up front. With a uniform prior
        the swap is expected to be undone; with a [0.9, 0.1] prior ``mu`` is
        expected to be left as supplied.
        """
        lf_rows = [
            # LF 0
            [0.75, 0.25],
            [0.25, 0.75],
            # LF 1
            [0.25, 0.75],
            [0.15, 0.25],
            # LF 2
            [0.75, 0.25],
            [0.25, 0.75],
        ]
        # Swap the columns. Index with a list rather than [:, ::-1]:
        # torch.from_numpy cannot wrap negatively-strided views.
        mu = np.array(lf_rows)[:, [1, 0]]

        # Each entry is (class_balance, expected mu[0, 0] after breaking).
        scenarios = (
            # Uniform prior, two "good" LFs (0 and 2): the swap is undone and
            # 0.75 returns to position [0, 0].
            (None, 0.75),
            # Non-uniform prior: the "correct" permutation does not commute
            # with it, so it is not considered and the swapped value remains.
            ([0.9, 0.1], 0.25),
        )
        for class_balance, expected in scenarios:
            label_model = LabelModel(verbose=False)
            label_model._set_class_balance(class_balance, None)
            label_model.m = 3
            label_model.mu = nn.Parameter(torch.from_numpy(mu))
            label_model._break_col_permutation_symmetry()
            self.assertEqual(label_model.mu.data[0, 0], expected)
 def _set_up_model(self, L: np.ndarray, class_balance: List[float] = [0.5, 0.5]):
     """Return a binary LabelModel whose parameters have been initialized.

     The model's internal setup steps are run on ``L + 1`` without training.

     Args:
         L: label matrix. It is shifted by +1 before every setup step
             (presumably to map an abstain value of -1 to 0 -- confirm
             against LabelModel).
         class_balance: prior forwarded to ``_set_class_balance``.

     Returns:
         The prepared LabelModel instance.
     """
     # NOTE(review): the default list object is shared across calls. This is
     # only safe while _set_class_balance never mutates its argument -- confirm.
     model = LabelModel(cardinality=2, verbose=False)
     model.train_config = TrainConfig()  # type: ignore
     shifted_L = L + 1

     # Keep this order: later steps presumably rely on state set by earlier ones.
     model._set_constants(shifted_L)
     model._create_tree()
     model._generate_O(shifted_L)
     model._build_mask()
     model._get_augmented_label_matrix(shifted_L)
     model._set_class_balance(class_balance=class_balance, Y_dev=None)
     model._init_params()
     return model
    def test_class_balance(self):
        """Exercise LabelModel._set_class_balance.

        Covers estimating the prior from Y_dev, plus the ValueError paths for
        an explicit prior with a zero entry, an explicit prior of the wrong
        length, and a Y_dev that does not match the model's cardinality.
        """
        label_model = LabelModel(cardinality=2, verbose=False)

        # Prior estimated from Y_dev: six 0s and four 1s out of ten labels.
        Y_dev = np.array([0, 0, 1, 1, 0, 0, 0, 0, 1, 1])
        label_model._set_class_balance(class_balance=None, Y_dev=Y_dev)
        np.testing.assert_array_almost_equal(label_model.p, np.array([0.6, 0.4]))

        # An explicit prior containing a zero entry is rejected.
        class_balance = np.array([0.0, 1.0])
        with self.assertRaisesRegex(ValueError, "Class balance prior is 0"):
            label_model._set_class_balance(class_balance=class_balance, Y_dev=Y_dev)

        # A length-1 prior for a 2-class model is rejected with the length
        # error, even though its single entry is also 0. The period is escaped
        # so the pattern matches a literal "." instead of any character.
        class_balance = np.array([0.0])
        with self.assertRaisesRegex(ValueError, r"class_balance has 1 entries\."):
            label_model._set_class_balance(class_balance=class_balance, Y_dev=Y_dev)

        # A Y_dev containing only class 0 is rejected with a cardinality
        # mismatch error.
        Y_dev_one_class = np.array([0, 0, 0])
        with self.assertRaisesRegex(
            ValueError, "Does not match LabelModel cardinality"
        ):
            label_model._set_class_balance(class_balance=None, Y_dev=Y_dev_one_class)