Example #1
0
class test_Mondrian_Tree(unittest.TestCase):

    # TODO: Run tests with inputing data with numpy arrays

    def setUp(self):

        self.d = 3
        self.a = 1
        self.linear_dims = [[0, 1]] * self.d
        self.mt1 = Mondrian_Tree(self.linear_dims)

        self.n_points = 100
        self.n_labelled = 20
        random.seed(123)
        self.labels = [random.random() for i in range(self.n_labelled)]
        self.labelled_indices = range(self.n_labelled)
        self.data = []
        random.seed(1)
        for i in range(self.n_points):
            point = []
            for j in range(self.d):
                point.append(random.random())
            self.data.append(point)

    ###########################################

    # Basic life time updating tests

    def test_update_first_life_time_0p1(self):
        temp = Mondrian_Tree([[0, 1]] * 5)
        temp.update_life_time(0.1, set_seed=1)
        self.assertEqual(temp._num_leaves, 3)

    def test_update_first_life_time_0p5(self):
        temp = Mondrian_Tree([[0, 1]] * 5)
        temp.update_life_time(0.5, set_seed=1)
        self.assertEqual(temp._num_leaves, 8)

    def test_update_first_life_time_1(self):
        temp = Mondrian_Tree([[0, 1]] * 5)
        temp.update_life_time(1, set_seed=1)
        self.assertEqual(temp._num_leaves, 36)

    def test_update_life_time_tuned(self):

        self.mt1.update_life_time(self.a * self.n_labelled**(1 /
                                                             (2 + self.d) - 1),
                                  set_seed=1)
        self.assertEqual(self.mt1._num_leaves, 2)

    def test__test_point(self):
        with self.assertRaises(ValueError):
            self.mt1._test_point([])
            self.mt1._test_point([1, 2, 3, 4, 5])
        with self.assertRaises(TypeError):
            self.mt1._test_point(1)
            self.mt1._test_point({})
            self.mt1._test_point('abc')

        good_point = [1] * self.d
        self.assertEqual(good_point, self.mt1._test_point(good_point))
        self.assertEqual(good_point,
                         self.mt1._test_point(np.array(good_point)))

    ###########################################

    # Testing input data

    def test_input_data_empty_root(self):
        self.mt1.input_data([], [], [])
        self.assertEqual(self.mt1._num_points, 0)

    def test_input_data_empty(self):
        lbda = 1
        self.mt1.update_life_time(lbda, set_seed=1)
        self.mt1.input_data([], [], [])
        self.assertEqual(self.mt1._num_points, 0)

    def test_input_data_root_no_labels(self):
        self.mt1.input_data([[0] * self.d], [], [])
        self.assertEqual(self.mt1._num_points, 1)
        self.assertEqual(self.mt1.labels, [None])
        self.assertEqual(self.mt1._root.unlabelled_index, [0])
        self.assertEqual(self.mt1._root.labelled_index, [])

    def test_input_data_root_labels(self):
        self.mt1.input_data([[0] * self.d], [0], [3.141])
        self.assertEqual(self.mt1._num_points, 1)
        self.assertEqual(self.mt1.labels, [3.141])
        self.assertEqual(self.mt1._root.unlabelled_index, [])
        self.assertEqual(self.mt1._root.labelled_index, [0])

    def test_input_data(self):
        lbda = 0.5
        self.mt1.update_life_time(lbda, set_seed=100)
        self.mt1.input_data(self.data, self.labelled_indices, self.labels)

        self.assertEqual(self.mt1._num_points, self.n_points)
        self.assertEqual(self.mt1._num_labelled, self.n_labelled)

        self.mt1.make_full_leaf_list()
        for node in self.mt1._full_leaf_list:
            # print(len(node.labelled_index), len(node.unlabelled_index))

            linear_dims = node.linear_dims
            for ind in node.labelled_index:
                point = self.mt1.points[ind]
                for dim in range(self.d):
                    # print(linear_dims[dim][0], point[dim], linear_dims[dim][1])
                    self.assertTrue(point[dim] >= linear_dims[dim][0])
                    self.assertTrue(point[dim] <= linear_dims[dim][1])

            for ind in node.unlabelled_index:
                point = self.mt1.points[ind]
                for dim in range(self.d):
                    # print(linear_dims[dim][0], point[dim], linear_dims[dim][1])
                    self.assertTrue(point[dim] >= linear_dims[dim][0])
                    self.assertTrue(point[dim] <= linear_dims[dim][1])

    def test_update_life_time_with_data(self):
        lbda = 0.5
        lbda2 = 1
        self.mt1.update_life_time(lbda, set_seed=100)
        # print(self.mt1._num_points)

        self.mt1.input_data(self.data, self.labelled_indices, self.labels)
        self.mt1.update_life_time(lbda2, set_seed=100)
        # print(self.mt1._num_points)

        self.mt1.make_full_leaf_list()
        for node in self.mt1._full_leaf_list:
            # print(len(node.labelled_index), len(node.unlabelled_index))

            linear_dims = node.linear_dims
            for ind in node.labelled_index:
                point = self.mt1.points[ind]
                for dim in range(self.d):
                    # print(linear_dims[dim][0], point[dim], linear_dims[dim][1])
                    self.assertTrue(point[dim] >= linear_dims[dim][0])
                    self.assertTrue(point[dim] <= linear_dims[dim][1])

            for ind in node.unlabelled_index:
                point = self.mt1.points[ind]
                for dim in range(self.d):
                    # print(linear_dims[dim][0], point[dim], linear_dims[dim][1])
                    self.assertTrue(point[dim] >= linear_dims[dim][0])
                    self.assertTrue(point[dim] <= linear_dims[dim][1])

    def test_input_data_with_numpy(self):
        lbda = 0.5
        self.mt1.update_life_time(lbda, set_seed=100)
        np_data = np.array(self.data)
        # print(np_data)
        np_labels = np.array(self.labels)
        self.mt1.input_data(np_data, self.labelled_indices, np_labels)

        self.assertEqual(self.mt1._num_points, self.n_points)
        self.assertEqual(self.mt1._num_labelled, self.n_labelled)

        self.mt1.make_full_leaf_list()
        for node in self.mt1._full_leaf_list:
            # print(len(node.labelled_index), len(node.unlabelled_index))

            linear_dims = node.linear_dims
            for ind in node.labelled_index:
                point = self.mt1.points[ind]
                for dim in range(self.d):
                    # print(linear_dims[dim][0], point[dim], linear_dims[dim][1])
                    self.assertTrue(point[dim] >= linear_dims[dim][0])
                    self.assertTrue(point[dim] <= linear_dims[dim][1])

            for ind in node.unlabelled_index:
                point = self.mt1.points[ind]
                for dim in range(self.d):
                    # print(linear_dims[dim][0], point[dim], linear_dims[dim][1])
                    self.assertTrue(point[dim] >= linear_dims[dim][0])
                    self.assertTrue(point[dim] <= linear_dims[dim][1])

    def test_label_point_root(self):
        val = 1
        self.mt1.input_data(self.data, self.labelled_indices, self.labels)
        self.mt1.label_point(self.n_labelled, val)
        self.assertEqual(self.mt1.labels[self.n_labelled], val)
        self.assertEqual(self.mt1._num_labelled, self.n_labelled + 1)
        leaf = self.mt1._root.leaf_for_point(self.data[self.n_labelled])
        self.assertTrue(self.n_labelled in leaf.labelled_index)

    def test_label_point_empty(self):
        with self.assertRaises(RuntimeError):
            self.mt1.label_point(1, 1)

    def test_label_point_too_big(self):
        self.mt1.input_data(self.data, self.labelled_indices, self.labels)
        with self.assertRaises(ValueError):
            self.mt1.label_point(self.n_points + 1, 1)

    def test_label_point(self):
        val = 1
        lbda = 0.5
        seed = 1
        self.mt1.input_data(self.data, self.labelled_indices, self.labels)
        self.mt1.update_life_time(lbda, set_seed=seed)
        self.mt1.label_point(self.n_labelled, val)
        self.assertEqual(self.mt1.labels[self.n_labelled], val)
        self.assertEqual(self.mt1._num_labelled, self.n_labelled + 1)
        leaf = self.mt1._root.leaf_for_point(self.data[self.n_labelled])
        # print(leaf.labelled_index)
        self.assertTrue(self.n_labelled in leaf.labelled_index)

    def test_add_data_point_bad(self):
        with self.assertRaises(TypeError):
            self.mt1.add_data_point(1)

    def test_add_data_point_empty(self):
        self.mt1.add_data_point([1] * self.d)
        self.assertEqual([[1] * self.d], self.mt1.points)
        self.assertEqual([None], self.mt1.labels)
        self.assertEqual(self.mt1._root.unlabelled_index, [0])
        self.assertEqual(self.mt1._root.labelled_index, [])

    def test_add_data_point_empty_2(self):
        self.mt1.add_data_point([1] * self.d, 2)
        self.assertEqual([[1] * self.d], self.mt1.points)
        self.assertEqual([2], self.mt1.labels)
        self.assertEqual(self.mt1._root.unlabelled_index, [])
        self.assertEqual(self.mt1._root.labelled_index, [0])

    def test_add_data_point(self):
        lbda = 0.5
        seed = 1
        self.mt1.input_data(self.data, self.labelled_indices, self.labels)
        self.mt1.update_life_time(lbda, set_seed=seed)

        self.mt1.add_data_point([1] * self.d)
        self.assertEqual(self.data + [[1] * self.d], self.mt1.points)
        self.assertEqual(
            self.labels + [None] * (self.n_points - self.n_labelled) + [None],
            self.mt1.labels)
        leaf = self.mt1._root.leaf_for_point([1] * self.d)
        self.assertEqual(leaf.unlabelled_index[-1], self.n_points)
        self.assertEqual(len(self.mt1.points), self.mt1._num_points)

    def test_add_data_point_labelled(self):
        lbda = 0.5
        seed = 1
        self.mt1.input_data(self.data, self.labelled_indices, self.labels)
        self.mt1.update_life_time(lbda, set_seed=seed)

        self.mt1.add_data_point([1] * self.d, 1)
        self.assertEqual(self.data + [[1] * self.d], self.mt1.points)
        self.assertEqual(
            self.labels + [None] * (self.n_points - self.n_labelled) + [1],
            self.mt1.labels)
        leaf = self.mt1._root.leaf_for_point([1] * self.d)
        self.assertEqual(leaf.labelled_index[-1], self.n_points)
        self.assertEqual(len(self.mt1.points), self.mt1._num_points)
        self.assertEqual(self.n_labelled + 1, self.mt1._num_labelled)

    ###########################################

    # Testing make leaf list

    def test_make_full_leaf_list_root(self):
        self.mt1.make_full_leaf_list()
        self.assertEqual(len(self.mt1._full_leaf_list), 1)

    def test_make_full_leaf_list(self):
        lbda = 1
        self.mt1.update_life_time(lbda, set_seed=1)
        self.mt1.make_full_leaf_list()
        # print(len(self.mt1._full_leaf_list))
        # for node in self.mt1._full_leaf_list:
        #     print(node.leaf_id)
        self.assertEqual(len(self.mt1._full_leaf_list), self.mt1._num_leaves)

    # Testing calculating leaf variances

    def test_make_full_leaf_var_list_root(self):
        self.mt1.input_data(self.data, self.labelled_indices, self.labels)
        self.mt1.make_full_leaf_list()
        self.mt1.make_full_leaf_var_list()
        self.assertEqual(utils.unbiased_var(self.labels),
                         self.mt1._full_leaf_var_list[0])

    def test_make_full_leaf_var_list_empty(self):
        self.mt1.make_full_leaf_list()
        self.mt1.make_full_leaf_var_list()
        self.assertEqual(0, self.mt1._full_leaf_var_list[0])

    def test_make_full_leaf_var_list(self):
        lbda = 0.5
        self.mt1.update_life_time(lbda, set_seed=100)

        self.mt1.input_data(self.data, self.labelled_indices, self.labels)
        self.mt1.make_full_leaf_list()
        self.mt1.make_full_leaf_var_list()

        for i, node in enumerate(self.mt1._full_leaf_list):
            node_labels = [self.mt1.labels[x] for x in node.labelled_index]
            # print(node_labels)
            if len(node_labels) != 0:
                temp_mean = sum(node_labels) / len(node_labels)
                # print(temp_mean)
                temp_var = 1 / (len(node_labels) - 1) * sum(
                    [(x - temp_mean)**2 for x in node_labels])
                self.assertTrue(
                    abs(self.mt1._full_leaf_var_list[i] - temp_var) < 1e-9)
            else:
                self.assertEqual(self.mt1._full_leaf_var_list[i], 0)

    # Testing calculating leaf mean

    def test_make_full_leaf_mean_list_root(self):
        self.mt1.input_data(self.data, self.labelled_indices, self.labels)
        self.mt1.make_full_leaf_list()
        self.mt1.make_full_leaf_mean_list()
        self.assertEqual(
            sum(self.labels) / self.n_labelled,
            self.mt1._full_leaf_mean_list[0])

    def test_make_full_leaf_mean_list_empty(self):
        self.mt1.make_full_leaf_list()
        self.mt1.make_full_leaf_mean_list()
        self.assertEqual(0, self.mt1._full_leaf_mean_list[0])

    def test_make_full_leaf_mean_list(self):
        lbda = 0.5
        self.mt1.update_life_time(lbda, set_seed=100)

        self.mt1.input_data(self.data, self.labelled_indices, self.labels)
        self.mt1.make_full_leaf_list()
        self.mt1.make_full_leaf_mean_list()

        for i, node in enumerate(self.mt1._full_leaf_list):
            node_labels = [self.mt1.labels[x] for x in node.labelled_index]
            # print(node_labels)
            if len(node_labels) != 0:
                temp_mean = sum(node_labels) / len(node_labels)
                # print(temp_mean)
                self.assertTrue(
                    abs(self.mt1._full_leaf_mean_list[i] - temp_mean) < 1e-9)
            else:
                self.assertEqual(self.mt1._full_leaf_mean_list[i], 0)

    # Testing calculating leaf marginal probabilities

    def test_make_full_leaf_marginal_list_root(self):
        self.mt1.input_data(self.data, self.labelled_indices, self.labels)
        self.mt1.make_full_leaf_list()
        self.mt1.make_full_leaf_marginal_list()
        self.assertEqual(1, self.mt1._full_leaf_marginal_list[0])

    def test_make_full_leaf_marginal_list_empty(self):
        self.mt1.make_full_leaf_list()
        self.mt1.make_full_leaf_marginal_list()
        self.assertEqual(0, self.mt1._full_leaf_marginal_list[0])

    def test_make_full_leaf_marginal_list(self):
        lbda = 0.5
        self.mt1.update_life_time(lbda, set_seed=100)

        self.mt1.input_data(self.data, self.labelled_indices, self.labels)
        self.mt1.make_full_leaf_list()
        self.mt1.make_full_leaf_marginal_list()

        for i, node in enumerate(self.mt1._full_leaf_list):
            node_points = ([self.mt1.points[x] for x in node.labelled_index] +
                           [self.mt1.points[x] for x in node.unlabelled_index])
            # print(node_labels)
            temp_marginal = len(node_points) / self.mt1._num_points
            self.assertTrue(
                abs(self.mt1._full_leaf_marginal_list[i] -
                    temp_marginal) < 1e-9)

    ###########################################

    # Testing using predict and get_point_in_same_leaf

    def test_predict_and_get_point_in_same_leaf_bad_input(self):
        with self.assertRaises(TypeError):
            self.mt1.predict(1)
            self.mt1.get_points_in_same_leaf(1)

    def test_predict_and_get_point_in_same_leaf_bad_input_2(self):
        with self.assertRaises(TypeError):
            self.mt1.predict('abc')
            self.mt1.get_points_in_same_leaf('abc')

    def test_predict_and_get_point_in_same_leaf_bad_length(self):
        with self.assertRaises(ValueError):
            self.mt1.predict([])
            self.mt1.get_points_in_same_leaf([])

    def test_predict_empty(self):
        with self.assertWarns(UserWarning):
            self.mt1.predict([0.5] * self.d)

    def test_predict_no_leaf_list(self):
        lbda = 0.5
        seed = 1
        self.mt1.update_life_time(lbda, set_seed=100)
        self.mt1.input_data(self.data, self.labelled_indices, self.labels)

        random.seed(seed)
        new_point = []
        for j in range(self.d):
            new_point.append(random.random())
        pred = self.mt1.predict(new_point)
        # print(pred)

        node = self.mt1._root.leaf_for_point(new_point)
        vals = [self.labels[x] for x in node.labelled_index]
        # print(len(vals))
        self.assertEqual(pred, sum(vals) / len(vals))

    def test_predict_leaf_list(self):
        lbda = 0.5
        seed = 1
        self.mt1.update_life_time(lbda, set_seed=100)
        self.mt1.make_full_leaf_list()
        self.mt1.input_data(self.data, self.labelled_indices, self.labels)

        random.seed(seed)
        new_point = []
        for j in range(self.d):
            new_point.append(random.random())
        pred = self.mt1.predict(new_point)
        # print(pred)

        node = self.mt1._root.leaf_for_point(new_point)
        vals = [self.labels[x] for x in node.labelled_index]
        # print(len(vals))
        self.assertEqual(pred, sum(vals) / len(vals))

    def test_predict_leaf_list_numpy(self):
        lbda = 0.5
        seed = 1
        self.mt1.update_life_time(lbda, set_seed=100)
        self.mt1.make_full_leaf_list()
        self.mt1.input_data(self.data, self.labelled_indices, self.labels)

        random.seed(seed)
        new_point = []
        for j in range(self.d):
            new_point.append(random.random())
        pred = self.mt1.predict(new_point)
        # print(pred)

        node = self.mt1._root.leaf_for_point(np.array(new_point))
        vals = [self.labels[x] for x in node.labelled_index]
        # print(len(vals))
        self.assertEqual(pred, sum(vals) / len(vals))

    def test_predict_multi_values(self):
        num_preds = 10
        lbda = 0.5
        seed = 1
        self.mt1.update_life_time(lbda, set_seed=100)
        self.mt1.make_full_leaf_list()
        self.mt1.input_data(self.data, self.labelled_indices, self.labels)

        random.seed(seed)
        new_points = []
        for i in range(num_preds):
            new_point = []
            for j in range(self.d):
                new_point.append(random.random())
            new_points.append(new_point)
        preds = self.mt1.predict(new_points)
        # print(pred)

        check_preds = []
        for i in range(num_preds):

            node = self.mt1._root.leaf_for_point(np.array(new_points[i]))
            vals = [self.labels[x] for x in node.labelled_index]
            check_preds.append(sum(vals) / len(vals))
        # print(len(vals))
        self.assertAlmostEqual(preds, check_preds)

    def test_get_point_in_same_leaf_labelled(self):
        lbda = 0.5
        seed = 1
        self.mt1.update_life_time(lbda, set_seed=100)
        self.mt1.make_full_leaf_list()
        self.mt1.input_data(self.data, self.labelled_indices, self.labels)

        random.seed(seed)
        new_point = []
        for j in range(self.d):
            new_point.append(random.random())
        labelled_list = self.mt1.get_points_in_same_leaf(new_point)

        node = self.mt1._root.leaf_for_point(new_point)
        self.assertEqual(labelled_list, node.labelled_index)

    def test_get_point_in_same_leaf_unlabelled(self):
        lbda = 0.5
        seed = 1
        self.mt1.update_life_time(lbda, set_seed=100)
        self.mt1.make_full_leaf_list()
        self.mt1.input_data(self.data, self.labelled_indices, self.labels)

        random.seed(seed)
        new_point = []
        for j in range(self.d):
            new_point.append(random.random())
        unlabelled_list = self.mt1.get_points_in_same_leaf(
            new_point, 'unlabelled')

        node = self.mt1._root.leaf_for_point(new_point)
        self.assertEqual(unlabelled_list, node.unlabelled_index)

    def test_get_point_in_same_leaf_bad_list_name(self):
        lbda = 0.5
        seed = 1
        self.mt1.update_life_time(lbda, set_seed=100)
        self.mt1.make_full_leaf_list()
        self.mt1.input_data(self.data, self.labelled_indices, self.labels)

        random.seed(seed)
        new_point = []
        for j in range(self.d):
            new_point.append(random.random())
        with self.assertWarns(UserWarning):
            labelled_list = self.mt1.get_points_in_same_leaf(
                new_point, 'neither')

        node = self.mt1._root.leaf_for_point(new_point)
        self.assertEqual(labelled_list, node.labelled_index)

    ###########################################

    # Testing theoretical bounds

    def test_expected_split_bound(self):
        reps = 100
        tot = 0
        lbda = 5
        d = 3
        for i in range(reps):
            temp = Mondrian_Tree([[0, 1]] * d)
            temp.update_life_time(lbda, set_seed=i)
            tot += temp._num_leaves - 1
        # print(tot/reps, ((1+lbda)*math.exp(1))**d)
        self.assertTrue(tot / reps < ((1 + lbda) * math.exp(1))**d)

    def test_expected_cell_diameter_bounds(self):
        tot = 0
        lbda = 5
        self.mt1.update_life_time(lbda, set_seed=1)
        self.mt1.make_full_leaf_list()
        for node in self.mt1._full_leaf_list:
            tot += node.calculate_cell_l2_diameter()**2

        # print(tot/len(self.mt1._full_leaf_list), 4*self.d/lbda**2)
        self.assertTrue(
            tot / len(self.mt1._full_leaf_list) < 4 * self.d / lbda**2)

    def test_probabilistic_cell_diameter_bounds(self):
        lbda = 5
        deltas = [0.1, 0.2, 0.3, 0.4, 0.5, 1, 1.5]
        self.mt1.update_life_time(lbda, set_seed=1)
        self.mt1.make_full_leaf_list()
        # print(len(self.mt1._full_leaf_list))
        for delta in deltas:
            tot = 0
            for node in self.mt1._full_leaf_list:
                tot += int(node.calculate_cell_l2_diameter() > delta)

            # print(tot/len(self.mt1._full_leaf_list),
            #     self.d * (1+ (lbda * delta)/math.sqrt(self.d))*
            #     math.exp(-lbda * delta/math.sqrt(self.d)))
            self.assertTrue(tot / len(self.mt1._full_leaf_list) < self.d *
                            (1 + (lbda * delta) / math.sqrt(self.d)) *
                            math.exp(-lbda * delta / math.sqrt(self.d)))

    def test_limit_number_of_leaves(self):
        # n_values = [1,10,100,1000,10000,100000]
        if False:
            n_values = [1, 10, 100, 1000]
            d_values = [1, 3, 5, 10, 20, 50]
            seed = 1
            for d_val in d_values:
                test_tree = Mondrian_Tree([[0, 1]] * d_val)
                print('d = {}'.format(d_val))
                for n_val in n_values:
                    test_tree.update_life_time(n_val**(1 / (2 + d_val)) - 1,
                                               set_seed=seed + n_val + d_val)
                    print(n_val, test_tree._num_leaves,
                          test_tree._num_leaves / n_val)

    ###########################################

    # Testing data driven default values for prediction and active learning variance

    def test_set_default_pred_global_mean_empty(self):
        lbda = 0.5
        seed = 1
        self.mt1.update_life_time(lbda, set_seed=seed)
        self.mt1.make_full_leaf_list()
        self.mt1.prediction_default_value = 1
        self.mt1.set_default_pred_global_mean()
        self.assertEqual(self.mt1.prediction_default_value, 0)

    def test_set_default_pred_global_mean(self):
        lbda = 0.5
        seed = 1
        self.mt1.update_life_time(lbda, set_seed=seed)
        self.mt1.make_full_leaf_list()
        self.mt1.input_data(self.data, self.labelled_indices, self.labels)
        self.mt1.set_default_pred_global_mean()
        self.assertEqual(self.mt1.prediction_default_value,
                         sum(self.labels) / self.n_labelled)

    def test_al_set_default_var_global_var_empty(self):
        lbda = 0.5
        seed = 1
        self.mt1.update_life_time(lbda, set_seed=seed)
        self.mt1.make_full_leaf_list()
        self.mt1.al_default_var = 1
        self.mt1.al_set_default_var_global_var()
        self.assertEqual(self.mt1.al_default_var, 0)

    def test_al_set_default_var_global_var(self):
        lbda = 0.5
        seed = 1
        self.mt1.update_life_time(lbda, set_seed=seed)
        self.mt1.make_full_leaf_list()
        self.mt1.input_data(self.data, self.labelled_indices, self.labels)
        self.mt1.al_set_default_var_global_var()
        self.assertEqual(self.mt1.al_default_var,
                         utils.unbiased_var(self.labels))

    ###########################################

    # Testing active learning parts

    def test_al_calculate_leaf_proportions_empty_root(self):
        with self.assertWarns(UserWarning):
            self.mt1.al_calculate_leaf_proportions()
        self.assertEqual(self.mt1._al_proportions, [1])

    def test_al_calculate_leaf_proportions_empty(self):
        lbda = 0.5
        seed = 1
        self.mt1.update_life_time(lbda, set_seed=seed)
        with self.assertWarns(UserWarning):
            self.mt1.al_calculate_leaf_proportions()
        self.assertEqual(self.mt1._al_proportions,
                         [1 / self.mt1._num_leaves] * self.mt1._num_leaves)

    def test_al_calculate_leaf_proportions_root(self):
        self.mt1.input_data(self.data, self.labelled_indices, self.labels)
        self.mt1.al_calculate_leaf_proportions()
        self.assertEqual(self.mt1._al_proportions, [1])

    def test_al_calculate_leaf_proportions(self):
        lbda = 0.5
        seed = 1
        self.mt1.update_life_time(lbda, set_seed=seed)
        self.mt1.input_data(self.data, self.labelled_indices, self.labels)
        self.mt1.al_calculate_leaf_proportions()
        temp_prop_list = []
        for i, node in enumerate(self.mt1._full_leaf_list):
            node_points = [x for x in node.labelled_index
                           ] + [x for x in node.unlabelled_index]
            # print(len(node_points))
            # print(self.mt1._full_leaf_marginal_list[i])
            if len(node_points) != 0:
                temp_var = utils.unbiased_var(
                    [self.labels[x] for x in node.labelled_index])
                # print(self.mt1._al_proportions[i])
                # print(math.sqrt(temp_var * (len(node_labels)/self.n_points)))
                temp_prop_list.append(
                    math.sqrt(temp_var * (len(node_points) / self.n_points)))

            else:
                self.assertEqual(self.mt1._al_proportions[i], 0)

        normalizer = sum(temp_prop_list)
        temp_prop_list = [x / normalizer for x in temp_prop_list]
        # print(temp_prop_list)
        for i, val in enumerate(temp_prop_list):
            # print(val, self.mt1._al_proportions[i])
            self.assertTrue(abs(self.mt1._al_proportions[i] - val) < 1e-9)

    def test_al_calculate_leaf_number_new_labels_empty(self):
        with self.assertRaises(ValueError):
            self.mt1.al_calculate_leaf_number_new_labels(1)

    def test_al_calculate_leaf_number_new_labels_too_many_labelled(self):
        self.mt1.input_data(self.data, self.labelled_indices, self.labels)
        with self.assertRaises(ValueError):
            self.mt1.al_calculate_leaf_number_new_labels(1)

    def test_al_calculate_leaf_number_new_labels_too_few_points(self):
        self.mt1.input_data(self.data, self.labelled_indices, self.labels)
        with self.assertRaises(ValueError):
            self.mt1.al_calculate_leaf_number_new_labels(101)

    def test_al_calculate_leaf_number_new_labels_bad_round_by(self):
        self.mt1.input_data(self.data, self.labelled_indices, self.labels)
        self.mt1.al_calculate_leaf_proportions()
        with self.assertRaises(ValueError):
            self.mt1.al_calculate_leaf_number_new_labels(21, round_by='bad')

    def test_al_calculate_leaf_number_new_labels_root(self):
        self.mt1.input_data(self.data, self.labelled_indices, self.labels)
        self.mt1.al_calculate_leaf_proportions()
        self.mt1.al_calculate_leaf_number_new_labels(21)
        self.assertEqual(self.mt1._al_leaf_number_new_labels, [1])

    def test_al_calculate_leaf_number_new_labels_single(self):
        lbda = 0.5
        seed = 1
        self.mt1.update_life_time(lbda, set_seed=seed)
        self.mt1.input_data(self.data, self.labelled_indices, self.labels)
        self.mt1.al_calculate_leaf_proportions()
        with self.assertWarns(UserWarning):
            self.mt1.al_calculate_leaf_number_new_labels(21)
        # print(self.mt1._al_leaf_number_new_labels)
        self.assertEqual(sum(self.mt1._al_leaf_number_new_labels), 1)

        # for i, node in enumerate(self.mt1._full_leaf_list):
        #     curr_num = len(node.labelled_index)
        #     tot_num = curr_num + self.mt1._al_leaf_number_new_labels[i]
        #     print(self.mt1._al_leaf_number_new_labels[i],tot_num, self.mt1._al_proportions[i])

    def test_al_calculate_leaf_number_new_labels_many(self):
        lbda = 0.5
        seed = 1
        self.mt1.update_life_time(lbda, set_seed=seed)
        self.mt1.input_data(self.data, self.labelled_indices, self.labels)
        self.mt1.al_calculate_leaf_proportions()
        self.mt1.al_calculate_leaf_number_new_labels(40)
        # print(self.mt1._al_leaf_number_new_labels)
        self.assertEqual(sum(self.mt1._al_leaf_number_new_labels), 20)

        # for i, node in enumerate(self.mt1._full_leaf_list):
        #     curr_num = len(node.labelled_index)
        #     tot_num = curr_num + self.mt1._al_leaf_number_new_labels[i]
        #     print(self.mt1._al_leaf_number_new_labels[i],tot_num, self.mt1._al_proportions[i])

    # def test_al_calculate_leaf_number_new_labels_incomplete(self):
    # This is always a possible source of weird behaviour... Keep thinking about these heurisitcs
    #     self.assertTrue(False, 'Do more testing for this function. Both auto and visual. '
    #         'And read the function code for more bugs.')

    def test_al_calculate_point_probabilities_proportions_empty(self):
        with self.assertWarns(UserWarning):
            self.mt1.al_calculate_leaf_proportions()
            self.mt1.al_calculate_point_probabilities_proportions()
        self.assertEqual(self.mt1._al_point_weights_proportional, [])

    def test_al_calculate_point_probabilities_proportions_root(self):
        self.mt1.input_data(self.data, self.labelled_indices, self.labels)
        self.mt1.al_calculate_leaf_proportions()
        self.mt1.al_calculate_point_probabilities_proportions()
        self.assertEqual(self.mt1._al_point_weights_proportional, [None] * 20 +
                         [1 / (self.n_points - self.n_labelled)] *
                         (self.n_points - self.n_labelled))

    def test_al_calculate_point_probabilities_proportions(self):
        lbda = 0.5
        seed = 1
        self.mt1.update_life_time(lbda, set_seed=seed)
        self.mt1.input_data(self.data, self.labelled_indices, self.labels)
        self.mt1.al_calculate_leaf_proportions()
        self.mt1.al_calculate_point_probabilities_proportions()
        for i, node in enumerate(self.mt1._full_leaf_list):
            tot = 0
            for ind in node.unlabelled_index:
                tot += self.mt1._al_point_weights_proportional[ind]

            self.assertTrue(abs(tot - self.mt1._al_proportions[i]) < 1e-9)

    def test_al_calculate_point_probabilities_adjustment_empty(self):
        with self.assertWarns(UserWarning):
            self.mt1.al_calculate_leaf_proportions()
        with self.assertRaises(ValueError):
            self.mt1.al_calculate_point_probabilities_adjustment(1)
        self.assertEqual(self.mt1._al_point_weights_adjustment, None)

    def test_al_calculate_point_probabilities_adjustment_root(self):
        self.mt1.input_data(self.data, self.labelled_indices, self.labels)
        self.mt1.al_calculate_leaf_proportions()
        self.mt1.al_calculate_point_probabilities_adjustment(21)
        for val in self.mt1._al_point_weights_adjustment:
            if val is not None:
                self.assertAlmostEqual(val,
                                       1 / (self.n_points - self.n_labelled))

    def test_al_calculate_point_probabilities_adjustment(self):

        # TODO: Try and improve this test. It passing is not very informative

        lbda = 0.5
        seed = 1
        self.mt1.update_life_time(lbda, set_seed=seed)
        self.mt1.input_data(self.data, self.labelled_indices, self.labels)
        self.mt1.al_calculate_leaf_proportions()
        self.mt1.al_calculate_point_probabilities_adjustment(21)
        self.assertAlmostEqual(
            sum([x for x in self.mt1._al_point_weights_adjustment if x]), 1)

        # for val in self.mt1._al_point_weights_adjustment:
        #     if not None:
        #         print(val)

    ###########################################

    # Checking reasonable performance of tree

    def test_constant_value_performance_root(self):
        self.mt1.input_data(self.data, self.labelled_indices,
                            [1] * self.n_labelled)
        self.assertEqual(self.mt1.predict([0.5] * self.d), 1)

    def test_constant_value_performance(self):
        seed = 1
        self.mt1.input_data(self.data, self.labelled_indices,
                            [1] * self.n_labelled)
        self.mt1.update_life_time(0.5, set_seed=seed)
        self.assertEqual(self.mt1.predict([0.5] * self.d), 1)

    def test_OLS_noiseless_better_than_mean(self):
        seed = 1
        labs = []
        for i in self.data[:self.n_labelled]:
            labs.append(sum(i))
        self.mt1.input_data(self.data, self.labelled_indices, labs)
        self.mt1.update_life_time(0.5, set_seed=seed)
        self.mt1.make_full_leaf_list()
        # for i, node  in enumerate(self.mt1._full_leaf_list):
        #     print(len(node.labelled_index))
        test_labs = []
        for i in self.data[self.n_labelled:]:
            test_labs.append(sum(i))
        test_mean = sum(test_labs) / len(test_labs)
        SST = 0
        SSE = 0

        for point in self.data[self.n_labelled:]:
            SST += (sum(point) - test_mean)**2
            SSE += (sum(point) - self.mt1.predict(point))**2

        # print(SST, SSE)
        self.assertTrue(SST > SSE)

    def test_OLS_better_than_mean(self):
        seed = 1
        std = 0.1
        labs = []
        for i in self.data[:self.n_labelled]:
            val = sum(i) + random.normalvariate(0, std)
            # print(sum(i),val)
            labs.append(val)
        self.mt1.input_data(self.data, self.labelled_indices, labs)
        self.mt1.update_life_time(0.5, set_seed=seed)
        self.mt1.make_full_leaf_list()
        # for i, node  in enumerate(self.mt1._full_leaf_list):
        #     print(len(node.labelled_index))
        test_labs = []
        for i in self.data[self.n_labelled:]:
            test_labs.append(sum(i) + random.normalvariate(0, std))
        test_mean = sum(test_labs) / len(test_labs)
        SST = 0
        SSE = 0

        for point in self.data[self.n_labelled:]:
            SST += (sum(point) - test_mean)**2
            SSE += (sum(point) - self.mt1.predict(point))**2

        # print(SST, SSE)
        self.assertTrue(SST > SSE)
MT = Mondrian_Tree([[0,1]]*p)
MT.update_life_time(n**(1/(2+p))-1, set_seed=seed)
# print(MT._num_leaves)
MT.input_data(np.concatenate((X_train, X_test),axis=0), range(1000), y_train)

MT.make_full_leaf_list()
MT.make_full_leaf_var_list()
used_leaf_counter = 0
for node in MT._full_leaf_list:
    if len(node.labelled_index) != 0:
        print(len(node.labelled_index), node.calculate_cell_l2_diameter())
        # print('var = {}'.format(MT._full_leaf_var_list[node.full_leaf_list_pos]))
        used_leaf_counter += 1
print('number of used leaves = {}'.format(used_leaf_counter))
MT.al_set_default_var_global_var()
# print(MT.al_default_var)

MT.set_default_pred_global_mean()

with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    MT_preds = MT.predict(X_test)
MT_preds = np.array(MT_preds)

print(sum(1/X_test.shape[0]*(y_test - MT_preds)**2))

BT = DecisionTreeRegressor(random_state=seed, max_leaf_nodes = used_leaf_counter)
BT.fit(X_train, y_train)
BT_preds = BT.predict(X_test)

print(sum(1/X_test.shape[0]*(y_test - BT_preds)**2))
Example #3
0
                # print(i)
                curr_num = len(node.labelled_index)
                tot_num = curr_num + MT_al._al_leaf_number_new_labels[i]
                # print(curr_num,tot_num, MT_al._al_proportions[i] * n_final,node.rounded_linear_dims(2))
                num_new_points = MT_al._al_leaf_number_new_labels[i]
                labels_to_add = node.pick_new_points(num_new_points,self_update = False, set_seed = tree_seed*i)
                # print(labels_to_add)
                new_labelled_points.extend(labels_to_add)
                for ind in labels_to_add:
                    MT_al.label_point(ind, y[ind])

            MT_al.set_default_pred_global_mean()

            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                MT_al_preds = MT_al.predict(X_test)
            MT_al_preds = np.array(MT_al_preds)
            MT_al_MSE[n_final_ind] += sum(1/X_test.shape[0]*(y_test - MT_al_preds)**2)

            # print('Done MT_al')

            # MT_rn

            # MT_rn = Mondrian_Tree([[0,1]]*p)
            # MT_rn.update_life_time(n_final**(1/(2+p))-1, set_seed=tree_seed)
            # print(MT._num_leaves)
            MT_rn.input_data(X, range(n_final), y[:n_final])
            MT_rn.set_default_pred_global_mean()
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                MT_rn_preds = MT_rn.predict(X_test)
Example #4
0
    print(curr_num, tot_num, MT._al_proportions[i] * n_final,
          node.rounded_linear_dims(2))
    num_new_points = MT._al_leaf_number_new_labels[i]
    labels_to_add = node.pick_new_points(num_new_points,
                                         self_update=False,
                                         set_seed=seed * i)
    # print(labels_to_add)
    new_labelled_points.extend(labels_to_add)
    for ind in labels_to_add:
        MT.label_point(ind, y[ind])

MT.set_default_pred_global_mean()

with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    MT_preds = MT.predict(X_test)
MT_preds = np.array(MT_preds)

print('MSE from AL = {}'.format(
    sum(1 / X_test.shape[0] * (y_test - MT_preds)**2)))

MT2 = Mondrian_Tree([[0, 1]] * p)
MT2.update_life_time(n_final**(1 / (2 + p)) - 1, set_seed=seed)
# print(MT._num_leaves)
MT2.input_data(np.concatenate((X_train, X_test), axis=0), range(n_final),
               y_train_rn)
MT2.set_default_pred_global_mean()
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    MT2_preds = MT2.predict(X_test)
MT2_preds = np.array(MT2_preds)