def test_augmented_L_construction(self):
    """Check the augmented label matrix built with ``higher_order=True``.

    A 3 x 5 label matrix (5 LFs, cardinality 2) is shifted by one and fed
    through ``_set_constants`` / ``_create_tree`` /
    ``_get_augmented_label_matrix``. The test then verifies the shape of
    the result, its total sum, the indicator entry for every positive
    shifted label, and the entries addressed by the first two clique-tree
    nodes.
    """
    # 5 LFs
    k = 2
    L = np.array([[0, 0, 0, 1, 0], [0, 1, 1, 0, -1], [0, 0, 0, 0, -1]])
    L_shift = L + 1

    lm = LabelModel(cardinality=k, verbose=False)
    lm._set_constants(L_shift)
    lm._create_tree()
    L_aug = lm._get_augmented_label_matrix(L_shift, higher_order=True)

    # Should have 10 columns:
    # - 5 * 2 = 10 for the sources
    self.assertEqual(L_aug.shape, (3, 10))

    # 13 total nonzero entries
    self.assertEqual(L_aug.sum(), 13)

    # Next, check the singleton entries: every positive shifted label must
    # switch on column ``j * k + label - 1`` in its row. Entries that are 0
    # after the shift are skipped.
    for i, j in np.argwhere(L_shift > 0):
        self.assertEqual(L_aug[i, j * k + L_shift[i, j] - 1], 1)

    # Finally, check the clique entries.
    # NOTE(review): ``c_tree`` looks like a networkx graph; ``Graph.node``
    # was removed in networkx 2.4, so node data is read via ``Graph.nodes``.
    # Singleton clique 1
    clique = lm.c_tree.nodes[1]
    self.assertEqual(len(clique["members"]), 1)
    self.assertEqual(L_aug[0, clique["start_index"]], 1)

    # Singleton clique 2
    clique = lm.c_tree.nodes[2]
    self.assertEqual(len(clique["members"]), 1)
    self.assertEqual(L_aug[0, clique["start_index"] + 1], 0)
def _set_up_model(self, L: np.ndarray, class_balance: List[float] = [0.5, 0.5]):
    """Build a cardinality-2 ``LabelModel`` initialised from ``L``.

    The label matrix is shifted by one and then run through the model's
    private setup steps, in the order the calls appear below, without
    running any training.

    Args:
        L: Label matrix; ``L + 1`` is what the model is given.
        class_balance: Class balance forwarded to ``_set_class_balance``.

    Returns:
        The ``LabelModel`` after ``_init_params`` has been called.
    """
    # The default list is created once and shared by every call, so hand
    # the model its own copy; that way nothing downstream can alter the
    # default seen by later calls. An explicit ``None`` is forwarded as-is.
    balance = None if class_balance is None else list(class_balance)

    label_model = LabelModel(cardinality=2, verbose=False)
    label_model.train_config = TrainConfig()  # type: ignore

    L_aug = L + 1
    label_model._set_constants(L_aug)
    label_model._create_tree()
    label_model._generate_O(L_aug)
    label_model._build_mask()
    # Return value is intentionally discarded; presumably called for the
    # state it sets on the model — confirm against LabelModel.
    label_model._get_augmented_label_matrix(L_aug)
    label_model._set_class_balance(class_balance=balance, Y_dev=None)
    label_model._init_params()
    return label_model