def test_count_accurate_lfs(self):
    """_count_accurate_lfs depends on the class balance, not just on mu.

    The same fixture yields two accurate LFs under the default (None) class
    balance, and three under a skewed [0.9, 0.1] balance, because the count
    is based on accuracy rather than on the raw conditional probabilities.
    """
    # Rows are grouped two per LF; LF 1 is the "weak" one in this fixture.
    lf0_rows = [[0.75, 0.25], [0.25, 0.75]]
    lf1_rows = [[0.25, 0.75], [0.15, 0.25]]
    lf2_rows = [[0.75, 0.25], [0.25, 0.75]]
    mu = np.array(lf0_rows + lf1_rows + lf2_rows)

    # (class balance passed to _set_class_balance, expected accurate-LF count)
    cases = (
        (None, 2),
        ([0.9, 0.1], 3),
    )
    for balance, expected_count in cases:
        # A fresh model per case so no state carries over between cases.
        model = LabelModel(verbose=False)
        model._set_class_balance(balance, None)
        model.m = 3
        self.assertEqual(model._count_accurate_lfs(mu), expected_count)
def test_symmetry_breaking(self):
    """_break_col_permutation_symmetry resolves a column swap of mu.

    The fixture has rows grouped per LF (see the ``# LF`` markers) and one
    column per class. Its two columns are swapped before being handed to the
    model. Under the default class balance the model must undo the swap;
    under a skewed [0.9, 0.1] balance the swapped permutation must be kept,
    because it no longer commutes with the class balance.
    """
    mu = np.array(
        [
            # LF 0
            [0.75, 0.25],
            [0.25, 0.75],
            # LF 1
            [0.25, 0.75],
            [0.15, 0.25],
            # LF 2
            [0.75, 0.25],
            [0.25, 0.75],
        ]
    )
    # Swap the class columns so the model starts from a permuted solution.
    mu = mu[:, [1, 0]]

    # First test: Two "good" LFs, default class balance -> the swap is undone.
    label_model = LabelModel(verbose=False)
    label_model._set_class_balance(None, None)
    label_model.m = 3
    # torch.tensor copies the data. torch.from_numpy would share memory with
    # the ``mu`` fixture reused below, so an in-place update here could leak
    # into the second sub-test.
    label_model.mu = nn.Parameter(torch.tensor(mu))
    label_model._break_col_permutation_symmetry()
    # Compare as Python floats with a tolerance so the result does not depend
    # on the dtype the model stores mu in.
    self.assertAlmostEqual(label_model.mu.data[0, 0].item(), 0.75)

    # Test with non-uniform class balance
    # It should not consider the "correct" permutation as it does not commute now
    label_model = LabelModel(verbose=False)
    label_model._set_class_balance([0.9, 0.1], None)
    label_model.m = 3
    label_model.mu = nn.Parameter(torch.tensor(mu))
    label_model._break_col_permutation_symmetry()
    self.assertAlmostEqual(label_model.mu.data[0, 0].item(), 0.25)
def test_symmetry_breaking_multiclass(self):
    """Multiclass (cardinality=3) counterpart of the binary symmetry test.

    The fixture's columns are rotated with [1, 2, 0] before being handed to
    the model. Under the default class balance the model must restore the
    original diagonal (0.75 at [0, 0] and [1, 1]). Under the skewed
    [0.7, 0.2, 0.1] balance the rotated permutation must be kept instead
    (0.15 at [0, 0], 0.3 at [1, 1]).
    """
    mu = np.array(
        [
            # LF 0
            [0.75, 0.15, 0.1],
            [0.20, 0.75, 0.3],
            [0.05, 0.10, 0.6],
            # LF 1
            [0.25, 0.55, 0.3],
            [0.15, 0.45, 0.4],
            [0.20, 0.00, 0.3],
            # LF 2
            [0.5, 0.15, 0.2],
            [0.3, 0.65, 0.2],
            [0.2, 0.20, 0.6],
        ]
    )
    # Rotate the class columns so the model starts from a permuted solution.
    mu = mu[:, [1, 2, 0]]

    # First test: Two "good" LFs (LF 0 and LF 2 are diagonal-dominant).
    label_model = LabelModel(cardinality=3, verbose=False)
    label_model._set_class_balance(None, None)
    label_model.m = 3
    # torch.tensor copies the data. torch.from_numpy would share memory with
    # the ``mu`` fixture reused below, so an in-place update here could leak
    # into the second sub-test.
    label_model.mu = nn.Parameter(torch.tensor(mu))
    label_model._break_col_permutation_symmetry()
    # 0.15 and 0.3 are not exactly representable. Compare with a tolerance so
    # the check holds whether the model keeps mu in float64 or float32.
    self.assertAlmostEqual(label_model.mu.data[0, 0].item(), 0.75)
    self.assertAlmostEqual(label_model.mu.data[1, 1].item(), 0.75)

    # Test with non-uniform class balance
    # It should not consider the "correct" permutation as it does not commute
    label_model = LabelModel(cardinality=3, verbose=False)
    label_model._set_class_balance([0.7, 0.2, 0.1], None)
    label_model.m = 3
    label_model.mu = nn.Parameter(torch.tensor(mu))
    label_model._break_col_permutation_symmetry()
    self.assertAlmostEqual(label_model.mu.data[0, 0].item(), 0.15)
    self.assertAlmostEqual(label_model.mu.data[1, 1].item(), 0.3)