def test_sample_with_hard_constraints(self):
    """Samples drawn under hard constraints must honour them exactly.

    Feature 0 is constrained to the "left" side of a BELOW-0.5 rule and
    feature 2 to the "right" side of a BELOW-0.5 rule, so every sample must
    have feature 0 == 0 and feature 2 == 1. Feature 1 is unconstrained and the
    model was fit on 50 zeros / 50 ones for it, so its empirical distribution
    is expected to stay near 0.5 / 0.5.
    """
    f1 = np.concatenate([np.zeros(75, dtype=np.float32), np.ones(25, dtype=np.float32)], axis=0)
    f2 = np.concatenate([np.zeros(50, dtype=np.float32), np.ones(50, dtype=np.float32)], axis=0)
    f3 = np.concatenate([np.zeros(20, dtype=np.float32), np.ones(80, dtype=np.float32)], axis=0)
    data = np.stack([f1, f2, f3], axis=1)
    model = trepan.DiscreteModel()
    model.fit(data)
    constraints = [
        ("left", trepan.Rule(0, 0.5, trepan.Rule.SplitType.BELOW)),
        ("right", trepan.Rule(2, 0.5, trepan.Rule.SplitType.BELOW)),
    ]
    oracle = trepan.Oracle(lambda x: x[:, 0], trepan.Oracle.DataType.DISCRETE, 0.05, 0.05)

    num_samples = 1000
    samples = np.stack(
        [oracle.sample_with_constraints(model, constraints) for _ in range(num_samples)]
    )

    # Hard constraints: assert_array_equal reports the offending entries on
    # failure, unlike assertTrue(np.all(...)).
    np.testing.assert_array_equal(samples[:, 0], 0)
    np.testing.assert_array_equal(samples[:, 2], 1)

    # Unconstrained feature: |p - 0.5| <= 0.1 is the same window as
    # 0.4 <= p <= 0.6, but the failure message shows the actual rate.
    # If draws are independent with p ~= 0.5, the standard error at 1000
    # draws is ~0.016, so this tolerance is very wide.
    p1_0 = np.mean(samples[:, 1] == 0)
    p1_1 = np.mean(samples[:, 1] == 1)
    self.assertAlmostEqual(p1_0, 0.5, delta=0.1)
    self.assertAlmostEqual(p1_1, 0.5, delta=0.1)
def test_train_impossible(self):
    """Train on random features/labels and expect exactly 5 internal nodes.

    The 4th Trepan constructor argument (5) presumably caps the number of
    internal nodes; with random data the tree is expected to hit that cap.
    NOTE(review): confirm the meaning of the constructor arguments against
    the Trepan implementation.
    """
    data = np.random.uniform(0, 1, size=[100, 40])
    labels = np.random.randint(0, 30, size=100)
    oracle = trepan.Oracle(lambda x: x[:, 0], trepan.Oracle.DataType.DISCRETE, 0.05, 0.05)
    tp = trepan.Trepan(data, labels, oracle, 5, 50)
    tp.train()

    # Depth-first walk over every reachable node with an explicit stack,
    # counting the nodes that are not leaves.
    internal_count = 0
    to_visit = [tp.root]
    while to_visit:
        current = to_visit.pop()
        if current is None:
            continue
        if not current.leaf:
            internal_count += 1
        to_visit.extend((current.left_child, current.right_child))

    self.assertEqual(internal_count, 5)
def test_sample_failure_mode(self):
    """Regression test: sampling under a specific tangled constraint set must not raise.

    The test makes no assertion on the returned sample; it only checks that
    sample_with_constraints completes.
    NOTE(review): in the hand-built model below, feature 4 can only take the
    value 1.0, so rule2 (feature 14 ABOVE and feature 4 BELOW, num_required=2)
    looks unsatisfiable. Presumably that is the failure mode exercised here;
    confirm.
    """
    # Hand-built discrete model over 15 features: features 5, 8 and 9 are
    # 50/50 over {0.0, 1.0}; every other feature is a constant (0.0 for
    # features 0 and 13, 1.0 otherwise).
    num_features = 15
    coin_flip_features = {5, 8, 9}
    constant_zero_features = {0, 13}

    model = trepan.DiscreteModel()
    model.distributions = [
        np.array([0.5, 0.5], dtype=np.float32)
        if idx in coin_flip_features
        else np.array([1.], dtype=np.float32)
        for idx in range(num_features)
    ]
    model.values = [
        [0.0, 1.0]
        if idx in coin_flip_features
        else ([0.0] if idx in constant_zero_features else [1.0])
        for idx in range(num_features)
    ]
    model.num_features = len(model.values)

    above = trepan.Rule.SplitType.ABOVE
    below = trepan.Rule.SplitType.BELOW

    rule1 = trepan.Rule(1, 0.5, above)
    rule1.add_split(5, 0.5, below)
    rule1.num_required = 2

    rule2 = trepan.Rule(14, 0.5, above)
    rule2.add_split(4, 0.5, below)
    rule2.num_required = 2

    rule3 = trepan.Rule(13, 0.5, above)
    rule3.add_split(7, 0.5, above)
    rule3.add_split(8, 0.5, below)

    rule4 = trepan.Rule(3, 0.5, above)
    rule4.add_split(10, 0.5, above)
    rule4.num_required = 2

    rule5 = trepan.Rule(2, 0.5, above)
    rule5.add_split(9, 0.5, above)

    constraints = [
        ("left", trepan.Rule(6, 0.5, above)),
        ("right", trepan.Rule(0, 0.5, above)),
        ("left", rule1),
        ("left", rule2),
        ("left", rule3),
        ("left", trepan.Rule(12, 0.5, above)),
        ("left", rule4),
        ("left", rule5),
    ]
    oracle = trepan.Oracle(lambda x: x[:, 0], trepan.Oracle.DataType.DISCRETE, 0.05, 0.05)
    oracle.sample_with_constraints(model, constraints)
def test_step_end(self):
    """One step on data whose single feature equals the label.

    After the step the root must have two children, each linked back to a
    parent and carrying no rule, and the work queue must be empty.
    """
    column = np.repeat(np.arange(2, dtype=np.float32), 100)  # 100 zeros, then 100 ones
    labels = column.copy()
    data = column[:, np.newaxis]

    oracle = trepan.Oracle(lambda x: x[:, 0], trepan.Oracle.DataType.DISCRETE, 0.05, 0.05)
    tp = trepan.Trepan(data, labels, oracle, 15, 50)
    tp.step()

    root = tp.root
    for child in (root.left_child, root.right_child):
        self.assertIsNotNone(child)
        self.assertIsNotNone(child.parent)
        self.assertIsNone(child.rule)
    self.assertTrue(tp.queue.is_empty())
def test_sample_with_disj_constraints(self):
    """Samples drawn under an m-of-n ("disjunctive") rule must satisfy it.

    The rule has three splits (feature 0 BELOW 0.5, feature 1 ABOVE 0.5,
    feature 2 BELOW 0.5) with num_required = 2 and is requested on the "left"
    side. Previously this test had no assertions and only checked that
    sampling did not raise.
    NOTE(review): assumes "left" means the rule is satisfied and that
    num_required means "at least m of the n splits hold"; confirm against
    trepan.Rule.
    """
    f1 = np.concatenate([np.zeros(75, dtype=np.float32), np.ones(25, dtype=np.float32)], axis=0)
    f2 = np.concatenate([np.zeros(50, dtype=np.float32), np.ones(50, dtype=np.float32)], axis=0)
    f3 = np.concatenate([np.zeros(20, dtype=np.float32), np.ones(80, dtype=np.float32)], axis=0)
    data = np.stack([f1, f2, f3], axis=1)
    model = trepan.DiscreteModel()
    model.fit(data)
    rule = trepan.Rule(0, 0.5, trepan.Rule.SplitType.BELOW)
    rule.add_split(1, 0.5, trepan.Rule.SplitType.ABOVE)
    rule.add_split(2, 0.5, trepan.Rule.SplitType.BELOW)
    rule.num_required = 2
    constraints = [
        ("left", rule),
    ]
    oracle = trepan.Oracle(lambda x: x[:, 0], trepan.Oracle.DataType.DISCRETE, 0.05, 0.05)

    for _ in range(100):
        sample = np.asarray(oracle.sample_with_constraints(model, constraints))
        # Count how many of the rule's three split conditions this sample meets.
        satisfied = int(sample[0] < 0.5) + int(sample[1] > 0.5) + int(sample[2] < 0.5)
        self.assertGreaterEqual(
            satisfied, rule.num_required, msg=f"sample {sample} violates the 2-of-3 rule"
        )
def test_prune_tree_none(self):
    """prune() must leave a root whose two leaves predict different classes untouched.

    Builds root("a") with leaf children "a" and "b" (all fidelity 1.0) and
    checks that, after pruning, the root keeps its class and fidelity and
    still points at the very same child objects.
    """
    def make_node(majority_class, leaf):
        # Create a node with perfect fidelity and the given class/leaf flag.
        node = trepan.Trepan.Node()
        node.majority_class = majority_class
        node.fidelity = 1.0
        node.leaf = leaf
        return node

    root = make_node("a", leaf=False)
    left = make_node("a", leaf=True)
    right = make_node("b", leaf=True)
    root.left_child = left
    left.parent = root
    root.right_child = right
    right.parent = root

    data = np.random.uniform(-1, 1, size=[100, 10])
    labels = np.random.randint(0, 10, size=100)
    oracle = trepan.Oracle(lambda x: x[:, 0], trepan.Oracle.DataType.DISCRETE, 0.05, 0.05)
    tp = trepan.Trepan(data, labels, oracle, 15, 50)
    tp.root = root
    tp.prune()

    self.assertEqual(tp.root.majority_class, "a")
    self.assertEqual(tp.root.fidelity, 1)
    # Identity, not equality: pruning must not replace the children.
    self.assertIs(tp.root.left_child, left)
    self.assertIs(tp.root.right_child, right)
def test_step_continue(self):
    """One step on data with four label classes over a single two-valued feature.

    After the step the root must have two children, each linked back to a
    parent and carrying no rule, and the work queue must hold two entries.
    """
    data = np.repeat(np.arange(2, dtype=np.float32), 20)[:, np.newaxis]  # 20 zeros, 20 ones
    labels = np.repeat(np.arange(4, dtype=np.float32), 10)  # 10 each of 0, 1, 2, 3

    oracle = trepan.Oracle(lambda x: x[:, 0], trepan.Oracle.DataType.DISCRETE, 0.05, 0.05)
    tp = trepan.Trepan(data, labels, oracle, 15, 50)
    tp.step()

    root = tp.root
    for child in (root.left_child, root.right_child):
        self.assertIsNotNone(child)
        self.assertIsNotNone(child.parent)
        self.assertIsNone(child.rule)
    self.assertEqual(tp.queue.size, 2)