def test_plot_error(self, mock_show):
    """Smoke-test ``plot_error`` with no option and each log-scale option.

    The dataset concatenates 8 linear pieces (intercept = coeff = k on
    x in [(k-1)*10, k*10], k = 1..8), giving the regression several
    segments to draw.

    ``mock_show`` is not used in the body; it is presumably a mock injected
    by a patch decorator outside this view so plotting stays
    non-interactive (TODO confirm).
    """
    pieces = [
        generate_dataset(intercept=k, coeff=k, size=50,
                         min_x=(k - 1) * 10, max_x=k * 10)
        for k in range(1, 9)
    ]
    dataset = [point for piece in pieces for point in piece]
    reg = compute_regression(dataset)
    option_sets = ({}, {'log': True}, {'log_x': True}, {'log_y': True})
    for options in option_sets:
        reg.plot_error(**options)
def generic_multiplesplits(self, cls, repeat):
    """Check that ``flatify`` preserves a multi-split regression.

    Builds 8 linear pieces (intercept = coeff = i on x in [(i-1)*10, i*10],
    i = 1..8) and fits a regression. The flattened regression must then:

    - keep the same points;
    - keep the same parameter count and breakpoints;
    - have null RSS;
    - predict every point;
    - be reproduced by refitting with its breakpoints passed explicitly.

    Args:
        cls: forwarded unchanged to ``generate_dataset``.
        repeat: forwarded unchanged to ``generate_dataset``.
    """
    # NOTE(review): another method named ``generic_multiplesplits`` appears in
    # this source; if both live in the same class, the later definition
    # silently replaces the earlier one. Confirm which one is intended.
    self.maxDiff = None
    pieces = [
        generate_dataset(intercept=i, coeff=i, size=50,
                         min_x=(i - 1) * 10, max_x=i * 10,
                         cls=cls, repeat=repeat)
        for i in range(1, 9)
    ]
    # Flatten with a comprehension; sum(lists, []) re-copies on every addition.
    dataset = [point for piece in pieces for point in piece]
    reg = compute_regression(dataset)
    flat_reg = reg.flatify()
    self.assertEqual(list(flat_reg), sorted(dataset))
    self.assertEqual(reg.nb_params, flat_reg.nb_params)
    self.assertEqual(reg.breakpoints, flat_reg.breakpoints)
    self.assertTrue(flat_reg.null_RSS)
    self.assertTrue(flat_reg.rss_equal(flat_reg.RSS, 0))
    # TODO should be 7, but is 8 in reality because of the non-optimality
    # of the algorithm
    self.assertIn(len(flat_reg.breakpoints), (7, 8))
    self.assertAlmostIncluded(range(10, 80, 10), flat_reg.breakpoints,
                              epsilon=2)
    for x, y in dataset:
        self.assertAlmostEqual(y, flat_reg.predict(x))
    # Refitting with the flat regression's breakpoints must give the same model.
    other_flat = compute_regression(dataset,
                                    breakpoints=flat_reg.breakpoints)
    self.assertEqual(str(other_flat), str(flat_reg))
def test_nosplit(self):
    """Data drawn from a single line must be fitted by a single ``Leaf``.

    The intercept and coefficient are drawn uniformly from [0, 100], and
    the points span x in [0, 100]. The fit must:

    - recover both parameters;
    - have RSS within 1e-3 of zero;
    - have no breakpoints.
    """
    intercept = random.uniform(0, 100)
    coeff = random.uniform(0, 100)
    dataset = generate_dataset(intercept=intercept, coeff=coeff, size=50,
                               min_x=0, max_x=100)
    reg = compute_regression(dataset)
    self.assertIsInstance(reg, Leaf)
    self.assertAlmostEqual(reg.intercept, intercept)
    self.assertAlmostEqual(reg.coeff, coeff)
    self.assertAlmostEqual(reg.RSS, 0, delta=1e-3)
    self.assertEqual(reg.breakpoints, [])
    # Iterating the regression must yield the input points in sorted order.
    self.assertEqual(list(reg), sorted(dataset))
def test_plot_dataset(self, mock_show):
    """Smoke-test ``plot_dataset`` across its scale and color options.

    The dataset concatenates 8 linear pieces (intercept = coeff = k on
    x in [(k-1)*10, k*10], k = 1..8). Each call only has to run without
    raising.

    ``mock_show`` is not used in the body; it is presumably a mock injected
    by a patch decorator outside this view (TODO confirm).
    """
    dataset = []
    for k in range(1, 9):
        dataset.extend(generate_dataset(intercept=k, coeff=k, size=50,
                                        min_x=(k - 1) * 10, max_x=k * 10))
    reg = compute_regression(dataset)
    variants = (
        {},
        {'log': True},
        {'log_x': True},
        {'log_y': True},
        {'plot_merged_reg': True},
        {'color': False},
        {'color': 'green'},
        {'color': ['green', 'blue', 'red']},
    )
    for kwargs in variants:
        reg.plot_dataset(**kwargs)
def generic_multiplesplits(self, cls, repeat):
    """Check a regression fitted on 8 consecutive linear pieces.

    Piece i (i = 1..8) has intercept = coeff = i on x in [(i-1)*10, i*10].
    The fit must:

    - iterate over the sorted input points;
    - have 7 or 8 breakpoints, each within 2 of a multiple of 10
      in [10, 70];
    - predict every input point.

    Args:
        cls: forwarded unchanged to ``generate_dataset``.
        repeat: forwarded unchanged to ``generate_dataset``.
    """
    # NOTE(review): another method named ``generic_multiplesplits`` appears in
    # this source; if both live in the same class, the later definition
    # silently replaces the earlier one. Confirm which one is intended.
    pieces = [
        generate_dataset(intercept=i, coeff=i, size=50,
                         min_x=(i - 1) * 10, max_x=i * 10,
                         cls=cls, repeat=repeat)
        for i in range(1, 9)
    ]
    # Flatten with a comprehension; sum(lists, []) re-copies on every addition.
    dataset = [point for piece in pieces for point in piece]
    reg = compute_regression(dataset)
    self.assertEqual(list(reg), sorted(dataset))
    # TODO should be 7, but is 8 in reality because of the non-optimality
    # of the algorithm
    self.assertIn(len(reg.breakpoints), (7, 8))
    self.assertAlmostIncluded(range(10, 80, 10), reg.breakpoints, epsilon=2)
    for x, y in dataset:
        self.assertAlmostEqual(y, reg.predict(x))
def generic_multiplesplits_simplify(self, cls, repeat):
    """Check ``simplify``, ``auto_simplify`` and ``to_pandas`` on a multi-split fit.

    The data are 8 linear pieces (intercept = coeff = i on x in
    [(i-1)*10, i*10], i = 1..8), so the ideal model has 7 breakpoints.

    Args:
        cls: forwarded unchanged to ``generate_dataset``.
        repeat: forwarded unchanged to ``generate_dataset``.
    """
    self.maxDiff = None
    pieces = [
        generate_dataset(intercept=i, coeff=i, size=50,
                         min_x=(i - 1) * 10, max_x=i * 10,
                         cls=cls, repeat=repeat)
        for i in range(1, 9)
    ]
    # Flatten with a comprehension; sum(lists, []) re-copies on every addition.
    dataset = [point for piece in pieces for point in piece]
    reg = compute_regression(dataset)
    merged = reg.merge()
    simple_df = reg.simplify()
    nb_full = len(reg.breakpoints)

    # One row per breakpoint count, from nb_full down to 0.
    self.assertEqual(len(simple_df), nb_full + 1)
    self.assertEqual(list(simple_df.nb_breakpoints),
                     list(range(nb_full, -1, -1)))
    # First row matches the unsimplified fit; last row matches the merged one.
    self.assertTrue(reg.rss_equal(reg.RSS, simple_df.RSS[0]))
    self.assertTrue(reg.rss_equal(list(simple_df.RSS)[-1], merged.RSS))
    self.assertTrue(reg.error_equal(reg.BIC, simple_df.BIC[0]))
    self.assertTrue(reg.error_equal(list(simple_df.BIC)[-1], merged.BIC))
    # Dropping breakpoints may only increase RSS (up to rss_equal tolerance).
    for old_rss, new_rss in zip(simple_df.RSS, simple_df.RSS[1:]):
        if not reg.rss_equal(old_rss, new_rss):
            self.assertLess(old_rss, new_rss)
    # Each candidate:
    # - iterates over the same points;
    # - has the advertised breakpoint count;
    # - keeps a subset of the original breakpoints.
    for nb_breakpoints, new_reg in zip(simple_df.nb_breakpoints,
                                       simple_df.regression):
        self.assertEqual(list(reg), list(new_reg))
        self.assertEqual(nb_breakpoints, len(new_reg.breakpoints))
        self.assertTrue(set(new_reg.breakpoints) <= set(reg.breakpoints))

    simple_reg = reg.auto_simplify()
    # The fit may carry one spurious extra breakpoint (7 or 8 in total), and
    # auto_simplify() is expected to land on the 7-breakpoint candidate.
    # Row k has nb_full - k breakpoints (checked above), so pick that row
    # by position. The previous hard-coded index 1 was only correct when
    # the fit had exactly 8 breakpoints.
    self.assertIn(nb_full, (7, 8))
    expected_reg = list(simple_df.regression)[nb_full - 7]
    self.assertEqual(simple_reg.breakpoints, expected_reg.breakpoints)
    self.assertEqual(simple_reg.RSS, expected_reg.RSS)
    self.assertEqual(simple_reg.BIC, expected_reg.BIC)
    # Checking that the auto_simplify() is a fix-point
    new_reg = simple_reg.auto_simplify()
    self.assertEqual(simple_reg.breakpoints, new_reg.breakpoints)

    # Checking to_pandas method: one row per segment, mirroring each leaf.
    df = new_reg.to_pandas()
    self.assertEqual(len(df), len(new_reg.segments))
    for (_, row), ((min_x, max_x), leaf) in zip(df.iterrows(),
                                                new_reg.segments):
        self.assertEqual(row['min_x'], min_x)
        self.assertEqual(row['max_x'], max_x)
        self.assertEqual(row['intercept'], leaf.intercept)
        self.assertEqual(row['coefficient'], leaf.coeff)
        self.assertEqual(row['RSS'], leaf.RSS)
        self.assertEqual(row['MSE'], leaf.MSE)
def test_singlesplit(self):
    """Two lines joined at a random split must give one ``Node`` with two ``Leaf`` children.

    The left line uses parameters drawn from [0, 50] on x in [0, split].
    The right line uses parameters drawn from [50, 100] on x in
    [split, 100], with split drawn from [30, 60]. Each leaf must recover
    its own line's parameters.
    """
    intercept_1 = random.uniform(0, 50)
    coeff_1 = random.uniform(0, 50)
    intercept_2 = random.uniform(50, 100)
    coeff_2 = random.uniform(50, 100)
    split = random.uniform(30, 60)
    dataset1 = generate_dataset(intercept=intercept_1, coeff=coeff_1,
                                size=50, min_x=0, max_x=split)
    dataset2 = generate_dataset(intercept=intercept_2, coeff=coeff_2,
                                size=50, min_x=split, max_x=100)
    # Shuffle the combined points so the input is not pre-sorted.
    dataset = dataset1 + dataset2
    random.shuffle(dataset)
    reg = compute_regression(dataset)
    self.assertIsInstance(reg, Node)
    self.assertAlmostEqual(reg.RSS, 0, delta=1e-3)
    self.assertIsInstance(reg.left, Leaf)
    self.assertAlmostEqual(reg.left.intercept, intercept_1)
    self.assertAlmostEqual(reg.left.coeff, coeff_1)
    self.assertIsInstance(reg.right, Leaf)
    self.assertAlmostEqual(reg.right.intercept, intercept_2)
    self.assertAlmostEqual(reg.right.coeff, coeff_2)
    # Points compare by x first, so max(dataset1)[0] is the largest left-hand x.
    self.assertEqual(reg.split, max(dataset1)[0])
    self.assertEqual(reg.breakpoints, [reg.split])
    self.assertEqual(list(reg), sorted(dataset))