def test_symmetry_breaking(self):
    """Check the column permutation chosen by symmetry breaking (binary case).

    A hand-written ``mu`` has its two columns swapped, then
    ``_break_col_permutation_symmetry`` is run under two class balances and
    the entry ``mu[0, 0]`` is compared against the expected value.
    """
    unpermuted = np.array(
        [
            # LF 0
            [0.75, 0.25],
            [0.25, 0.75],
            # LF 1
            [0.25, 0.75],
            [0.15, 0.25],
            # LF 2
            [0.75, 0.25],
            [0.25, 0.75],
        ]
    )
    # Swap the two columns so the model starts from the permuted solution.
    swapped = unpermuted[:, [1, 0]]

    scenarios = (
        # Default class balance, two "good" LFs: the swap is undone.
        (None, 0.75),
        # Non-uniform class balance: the permutation no longer commutes, so
        # the expected entry is the one from the swapped input.
        ([0.9, 0.1], 0.25),
    )
    for prior, expected_entry in scenarios:
        model = LabelModel(verbose=False)
        model._set_class_balance(prior, None)
        model.m = 3
        model.mu = nn.Parameter(torch.from_numpy(swapped))
        model._break_col_permutation_symmetry()
        self.assertEqual(model.mu.data[0, 0], expected_entry)
def test_count_accurate_lfs(self):
    """Check ``_count_accurate_lfs`` on a fixed ``mu`` under two class balances."""
    cond_probs = np.array(
        [
            # LF 0
            [0.75, 0.25],
            [0.25, 0.75],
            # LF 1
            [0.25, 0.75],
            [0.15, 0.25],
            # LF 2
            [0.75, 0.25],
            [0.25, 0.75],
        ]
    )

    expectations = (
        # Default class balance: two "good" LFs.
        (None, 2),
        # Skewed class balance: all three count as "good", because accuracy
        # (not conditional probability) is what gets counted.
        ([0.9, 0.1], 3),
    )
    for prior, n_accurate in expectations:
        model = LabelModel(verbose=False)
        model._set_class_balance(prior, None)
        model.m = 3
        self.assertEqual(model._count_accurate_lfs(cond_probs), n_accurate)
def test_class_balance(self):
    """Check that the class balance is estimated from ``Y_dev`` when no prior is given."""
    # Six zeros and four ones -> expected balance of [0.6, 0.4].
    dev_labels = np.array([0, 0, 1, 1, 0, 0, 0, 0, 1, 1])
    expected_balance = np.array([0.6, 0.4])

    model = LabelModel(cardinality=2, verbose=False)
    model._set_class_balance(class_balance=None, Y_dev=dev_labels)

    np.testing.assert_array_almost_equal(model.p, expected_balance)
def _set_up_model(self, L: np.ndarray, class_balance: List[float] = [0.5, 0.5]):
    """Create a binary ``LabelModel`` and run its private setup steps on ``L``.

    Args:
        L: Label matrix. It is shifted by +1 before being handed to the model
            (presumably so an abstain value of -1 becomes 0 -- TODO confirm).
        class_balance: Class prior forwarded to ``_set_class_balance``.
            ``None`` is forwarded unchanged.

    Returns:
        The ``LabelModel`` after constants, tree, ``O``, mask, augmented label
        matrix, class balance and parameters have been initialised.
    """
    # The default list in the signature is created once and shared by every
    # call. Pass the model a private copy so nothing downstream can store or
    # mutate that shared object and leak state between tests.
    balance = None if class_balance is None else list(class_balance)

    label_model = LabelModel(cardinality=2, verbose=False)
    label_model.train_config = TrainConfig()  # type: ignore

    L_aug = L + 1
    label_model._set_constants(L_aug)
    label_model._create_tree()
    label_model._generate_O(L_aug)
    label_model._build_mask()
    label_model._get_augmented_label_matrix(L_aug)
    label_model._set_class_balance(class_balance=balance, Y_dev=None)
    label_model._init_params()
    return label_model
def test_symmetry_breaking_multiclass(self):
    """Check the column permutation chosen by symmetry breaking (three classes).

    A hand-written ``mu`` has its columns reordered with ``[1, 2, 0]``, then
    ``_break_col_permutation_symmetry`` is run under two class balances and
    the entries ``mu[0, 0]`` and ``mu[1, 1]`` are compared against the
    expected values.
    """
    unpermuted = np.array(
        [
            # LF 0
            [0.75, 0.15, 0.1],
            [0.20, 0.75, 0.3],
            [0.05, 0.10, 0.6],
            # LF 1
            [0.25, 0.55, 0.3],
            [0.15, 0.45, 0.4],
            [0.20, 0.00, 0.3],
            # LF 2
            [0.5, 0.15, 0.2],
            [0.3, 0.65, 0.2],
            [0.2, 0.20, 0.6],
        ]
    )
    # Reorder the columns so the model starts from a permuted solution.
    permuted = unpermuted[:, [1, 2, 0]]

    scenarios = (
        # Default class balance, two "good" LFs: original column order is
        # recovered.
        (None, 0.75, 0.75),
        # Non-uniform class balance: the permutation does not commute, so the
        # expected entries are those of the permuted input.
        ([0.7, 0.2, 0.1], 0.15, 0.3),
    )
    for prior, expected_00, expected_11 in scenarios:
        model = LabelModel(cardinality=3, verbose=False)
        model._set_class_balance(prior, None)
        model.m = 3
        model.mu = nn.Parameter(torch.from_numpy(permuted))
        model._break_col_permutation_symmetry()
        self.assertEqual(model.mu.data[0, 0], expected_00)
        self.assertEqual(model.mu.data[1, 1], expected_11)
def test_class_balance(self):
    """Check class balance estimation from ``Y_dev`` and its validation errors."""
    model = LabelModel(cardinality=2, verbose=False)

    # Six zeros and four ones -> expected balance of [0.6, 0.4].
    dev_labels = np.array([0, 0, 1, 1, 0, 0, 0, 0, 1, 1])
    model._set_class_balance(class_balance=None, Y_dev=dev_labels)
    np.testing.assert_array_almost_equal(model.p, np.array([0.6, 0.4]))

    # Each entry: keyword arguments for ``_set_class_balance`` and the pattern
    # the resulting ValueError message must match.
    rejected_inputs = (
        # A class with zero prior.
        (
            {"class_balance": np.array([0.0, 1.0]), "Y_dev": dev_labels},
            "Class balance prior is 0",
        ),
        # Wrong number of entries for a cardinality-2 model.
        (
            {"class_balance": np.array([0.0]), "Y_dev": dev_labels},
            "class_balance has 1 entries.",
        ),
        # Dev labels that contain only a single class.
        (
            {"class_balance": None, "Y_dev": np.array([0, 0, 0])},
            "Does not match LabelModel cardinality",
        ),
    )
    for call_kwargs, message_pattern in rejected_inputs:
        with self.assertRaisesRegex(ValueError, message_pattern):
            model._set_class_balance(**call_kwargs)