コード例 #1
0
    def test_sample_with_hard_constraints(self):
        """Check constrained sampling when two features are pinned by single-split rules.

        The test expects feature 0 to always be sampled as 0, feature 2 to
        always be sampled as 1, and the unconstrained feature 1 to take each
        of its two values in 40%-60% of the draws.
        """

        def binary_column(num_zeros, num_ones):
            # float32 column: num_zeros zeros followed by num_ones ones
            return np.concatenate(
                [np.zeros(num_zeros, dtype=np.float32), np.ones(num_ones, dtype=np.float32)], axis=0
            )

        columns = [binary_column(75, 25), binary_column(50, 50), binary_column(20, 80)]
        data = np.stack(columns, axis=1)

        model = trepan.DiscreteModel()
        model.fit(data)

        below = trepan.Rule.SplitType.BELOW
        constraints = [
            ("left", trepan.Rule(0, 0.5, below)),
            ("right", trepan.Rule(2, 0.5, below)),
        ]

        oracle = trepan.Oracle(lambda x: x[:, 0], trepan.Oracle.DataType.DISCRETE, 0.05, 0.05)

        num_samples = 1000
        samples = np.stack([
            oracle.sample_with_constraints(model, constraints) for _ in range(num_samples)
        ])

        # the constrained features must be fixed in every single sample
        self.assertTrue(np.all(samples[:, 0] == 0))
        self.assertTrue(np.all(samples[:, 2] == 1))

        # the free feature was fitted on a 50/50 column, so allow sampling noise around one half
        total = samples.shape[0]
        free_feature = samples[:, 1]
        fraction_zero = np.sum(free_feature == 0) / total
        fraction_one = np.sum(free_feature == 1) / total

        for fraction in (fraction_zero, fraction_one):
            self.assertTrue(0.4 <= fraction <= 0.6)
コード例 #2
0
    def test_train_impossible(self):
        """Train on random features with random labels and count the internal nodes of the tree.

        After ``train()`` the tree is expected to contain exactly 5 nodes
        whose ``leaf`` flag is false.
        NOTE(review): the 5 passed to ``Trepan`` is presumably the internal-node budget -- confirm.
        """

        data = np.random.uniform(0, 1, size=[100, 40])
        labels = np.random.randint(0, 30, size=100)

        oracle = trepan.Oracle(lambda x: x[:, 0], trepan.Oracle.DataType.DISCRETE, 0.05, 0.05)
        tp = trepan.Trepan(data, labels, oracle, 5, 50)

        tp.train()

        def tally_nodes(start):
            # Walk the tree with an explicit stack; children are followed for
            # every non-None node, whether or not it is flagged as a leaf.
            num_internal = 0
            num_leaves = 0
            pending = [start]

            while pending:
                current = pending.pop()
                if current is None:
                    continue

                if current.leaf:
                    num_leaves += 1
                else:
                    num_internal += 1

                pending.append(current.left_child)
                pending.append(current.right_child)

            return num_internal, num_leaves

        internal_count, _ = tally_nodes(tp.root)

        self.assertEqual(internal_count, 5)
コード例 #3
0
    def test_sample_failure_mode(self):
        """Exercise constrained sampling on a hand-built model for a specific failure mode.

        The model attributes are assigned directly instead of being fitted.
        The test contains no assertions: it passes as long as
        ``sample_with_constraints`` returns without raising.
        """

        # test with a specific failure mode
        feature_values = [
            [0.0], [1.0], [1.0], [1.0], [1.0], [0.0, 1.0], [1.0], [1.0], [0.0, 1.0], [0.0, 1.0], [1.0], [1.0], [1.0],
            [0.0], [1.0]
        ]

        model = trepan.DiscreteModel()
        # one uniform float32 distribution per feature: [1.] for one value, [0.5, 0.5] for two
        model.distributions = [
            np.array([1. / len(options)] * len(options), dtype=np.float32) for options in feature_values
        ]
        model.values = feature_values
        model.num_features = len(model.values)

        above = trepan.Rule.SplitType.ABOVE
        below = trepan.Rule.SplitType.BELOW

        def build_rule(splits, num_required=None):
            # First (feature, split type) pair goes to the constructor, the rest
            # through add_split; num_required is only touched when given.
            (first_feature, first_type), *remaining = splits
            rule = trepan.Rule(first_feature, 0.5, first_type)
            for feature, split_type in remaining:
                rule.add_split(feature, 0.5, split_type)
            if num_required is not None:
                rule.num_required = num_required
            return rule

        rule1 = build_rule([(1, above), (5, below)], num_required=2)
        rule2 = build_rule([(14, above), (4, below)], num_required=2)
        rule3 = build_rule([(13, above), (7, above), (8, below)])
        rule4 = build_rule([(3, above), (10, above)], num_required=2)
        rule5 = build_rule([(2, above), (9, above)])

        constraints = [
            ("left", trepan.Rule(6, 0.5, above)),
            ("right", trepan.Rule(0, 0.5, above)),
            ("left", rule1),
            ("left", rule2),
            ("left", rule3),
            ("left", trepan.Rule(12, 0.5, above)),
            ("left", rule4),
            ("left", rule5),
        ]

        oracle = trepan.Oracle(lambda x: x[:, 0], trepan.Oracle.DataType.DISCRETE, 0.05, 0.05)

        oracle.sample_with_constraints(model, constraints)
コード例 #4
0
    def test_step_end(self):
        """Run a single ``step()`` on data whose labels are a copy of the only feature.

        After the step the root is expected to have two children, each with
        a parent set and no rule, and the queue is expected to be empty.
        """

        # 100 zeros followed by 100 ones
        feature = np.repeat(np.array([0, 1], dtype=np.float32), 100)
        labels = cp.deepcopy(feature)
        data = feature[:, np.newaxis]

        oracle = trepan.Oracle(lambda x: x[:, 0], trepan.Oracle.DataType.DISCRETE, 0.05, 0.05)
        tp = trepan.Trepan(data, labels, oracle, 15, 50)

        tp.step()

        for side in ("left_child", "right_child"):
            child = getattr(tp.root, side)
            self.assertIsNotNone(child)
            self.assertIsNotNone(child.parent)

        for side in ("left_child", "right_child"):
            self.assertIsNone(getattr(tp.root, side).rule)

        self.assertTrue(tp.queue.is_empty())
コード例 #5
0
    def test_sample_with_disj_constraints(self):
        """Smoke-test constrained sampling with one three-split rule that has ``num_required = 2``.

        The test contains no assertions: it passes as long as
        ``sample_with_constraints`` returns without raising.
        """

        def binary_column(num_zeros, num_ones):
            # float32 column: num_zeros zeros followed by num_ones ones
            return np.concatenate(
                [np.zeros(num_zeros, dtype=np.float32), np.ones(num_ones, dtype=np.float32)], axis=0
            )

        data = np.stack([binary_column(75, 25), binary_column(50, 50), binary_column(20, 80)], axis=1)

        model = trepan.DiscreteModel()
        model.fit(data)

        split_types = trepan.Rule.SplitType
        rule = trepan.Rule(0, 0.5, split_types.BELOW)
        for feature, split_type in ((1, split_types.ABOVE), (2, split_types.BELOW)):
            rule.add_split(feature, 0.5, split_type)
        rule.num_required = 2

        constraints = [("left", rule)]

        oracle = trepan.Oracle(lambda x: x[:, 0], trepan.Oracle.DataType.DISCRETE, 0.05, 0.05)

        oracle.sample_with_constraints(model, constraints)
コード例 #6
0
    def test_prune_tree_none(self):
        """Check that ``prune()`` leaves a root with two leaves of different majority classes unchanged.

        The tree is wired up by hand and installed as ``tp.root``; after
        pruning, the root's majority class, fidelity and both children are
        expected to be the same as before.
        """

        root = trepan.Trepan.Node()
        left_leaf = trepan.Trepan.Node()
        right_leaf = trepan.Trepan.Node()

        # (node, majority class, leaf flag); every node gets fidelity 1.0
        node_specs = (
            (left_leaf, "a", True),
            (right_leaf, "b", True),
            (root, "a", False),
        )
        for node, majority, is_leaf in node_specs:
            node.majority_class = majority
            node.fidelity = 1.0
            node.leaf = is_leaf

        root.left_child = left_leaf
        left_leaf.parent = root

        root.right_child = right_leaf
        right_leaf.parent = root

        data = np.random.uniform(-1, 1, size=[100, 10])
        labels = np.random.randint(0, 10, size=100)

        oracle = trepan.Oracle(lambda x: x[:, 0], trepan.Oracle.DataType.DISCRETE, 0.05, 0.05)
        tp = trepan.Trepan(data, labels, oracle, 15, 50)
        # replace whatever root the constructor set up with the hand-built tree
        tp.root = root

        tp.prune()

        pruned_root = tp.root
        self.assertEqual(pruned_root.majority_class, "a")
        self.assertEqual(pruned_root.fidelity, 1)
        self.assertEqual(pruned_root.left_child, left_leaf)
        self.assertEqual(pruned_root.right_child, right_leaf)
コード例 #7
0
    def test_step_continue(self):
        """Run a single ``step()`` on one binary feature with four label classes.

        After the step the root is expected to have two children, each with
        a parent set and no rule, and the queue is expected to have size 2.
        """

        # 20 zeros followed by 20 ones, as a single-column matrix
        data = np.repeat(np.array([0, 1], dtype=np.float32), 20)[:, np.newaxis]

        # ten samples each of classes 0, 1, 2 and 3
        labels = np.repeat(np.array([0, 1, 2, 3], dtype=np.float32), 10)

        oracle = trepan.Oracle(lambda x: x[:, 0], trepan.Oracle.DataType.DISCRETE, 0.05, 0.05)
        tp = trepan.Trepan(data, labels, oracle, 15, 50)

        tp.step()

        for side in ("left_child", "right_child"):
            child = getattr(tp.root, side)
            self.assertIsNotNone(child)
            self.assertIsNotNone(child.parent)

        for side in ("left_child", "right_child"):
            self.assertIsNone(getattr(tp.root, side).rule)

        self.assertEqual(tp.queue.size, 2)