Пример #1
0
    def test_symmetry_breaking_multiclass(self):
        """Check column-permutation symmetry breaking for a 3-class model.

        ``mu`` stacks one 3x3 block per LF (LF 0, LF 1, LF 2). Its columns are
        then scrambled with the permutation ``[1, 2, 0]``. With no class
        balance given, the model is expected to undo the scramble so that LF 0's
        0.75 entries return to the diagonal. With the non-uniform class balance
        ``[0.7, 0.2, 0.1]``, the "correct" permutation must not be chosen, and
        the scrambled column order is expected to be kept.
        """
        mu = np.array([
            # LF 0
            [0.75, 0.15, 0.1],
            [0.20, 0.75, 0.3],
            [0.05, 0.10, 0.6],
            # LF 1
            [0.25, 0.55, 0.3],
            [0.15, 0.45, 0.4],
            [0.20, 0.00, 0.3],
            # LF 2
            [0.5, 0.15, 0.2],
            [0.3, 0.65, 0.2],
            [0.2, 0.20, 0.6],
        ])
        # Scramble the columns; the model has to recover (or keep) this order.
        mu = mu[:, [1, 2, 0]]

        def break_symmetry(class_balance):
            """Run symmetry breaking on a fresh model seeded with ``mu``.

            Returns the model's ``mu`` tensor after the call.
            """
            label_model = LabelModel(cardinality=3, verbose=False)
            label_model._set_class_balance(class_balance, None)
            label_model.m = 3
            # torch.tensor copies the data. torch.from_numpy would share the
            # NumPy buffer between both scenarios below, so any in-place update
            # made by the method under test would leak into the next scenario.
            label_model.mu = nn.Parameter(torch.tensor(mu))
            label_model._break_col_permutation_symmetry()
            return label_model.mu.data

        # First test: Two "good" LFs
        mu_hat = break_symmetry(None)
        # Compare Python floats with a tolerance, so the check does not depend
        # on the dtype the method stores mu in.
        self.assertAlmostEqual(mu_hat[0, 0].item(), 0.75)
        self.assertAlmostEqual(mu_hat[1, 1].item(), 0.75)

        # Test with non-uniform class balance.
        # It should not consider the "correct" permutation, as it does not
        # commute with the class balance.
        mu_hat = break_symmetry([0.7, 0.2, 0.1])
        self.assertAlmostEqual(mu_hat[0, 0].item(), 0.15)
        self.assertAlmostEqual(mu_hat[1, 1].item(), 0.3)
Пример #2
0
    def test_symmetry_breaking(self):
        """Check column-permutation symmetry breaking for a binary model.

        ``mu`` stacks one 2x2 block per LF (LF 0, LF 1, LF 2). Its two columns
        are then swapped. With no class balance given, the model is expected to
        swap them back so that LF 0's 0.75 entry returns to position ``[0, 0]``.
        With the non-uniform class balance ``[0.9, 0.1]``, the "correct"
        permutation must not be chosen, and the swapped order is expected to be
        kept.
        """
        mu = np.array([
            # LF 0
            [0.75, 0.25],
            [0.25, 0.75],
            # LF 1
            [0.25, 0.75],
            [0.15, 0.25],
            # LF 2
            [0.75, 0.25],
            [0.25, 0.75],
        ])
        # Swap the columns; the model has to recover (or keep) this order.
        mu = mu[:, [1, 0]]

        def break_symmetry(class_balance):
            """Run symmetry breaking on a fresh model seeded with ``mu``.

            Returns the model's ``mu`` tensor after the call.
            """
            label_model = LabelModel(verbose=False)
            label_model._set_class_balance(class_balance, None)
            label_model.m = 3
            # torch.tensor copies the data. torch.from_numpy would share the
            # NumPy buffer between both scenarios below, so any in-place update
            # made by the method under test would leak into the next scenario.
            label_model.mu = nn.Parameter(torch.tensor(mu))
            label_model._break_col_permutation_symmetry()
            return label_model.mu.data

        # First test: Two "good" LFs
        mu_hat = break_symmetry(None)
        # Compare Python floats with a tolerance, so the check does not depend
        # on the dtype the method stores mu in.
        self.assertAlmostEqual(mu_hat[0, 0].item(), 0.75)

        # Test with non-uniform class balance.
        # It should not consider the "correct" permutation, as it does not
        # commute with the class balance.
        mu_hat = break_symmetry([0.9, 0.1])
        self.assertAlmostEqual(mu_hat[0, 0].item(), 0.25)