def test_empty_test(self):
    """Split a two-transition dataset 50/50 and check each half holds one intact transition.

    NOTE(review): despite the method name, neither split is empty here; the
    name is kept unchanged so test discovery and any references still work.
    """
    dataset = Dataset_Counts(state_shape=[2], nb_actions=3, count_param=0.2)
    dataset.add(*transitions[0])
    dataset.add(*transitions[1])
    dataset_train, dataset_test = dataset.train_validation_split(test_size=0.5)

    # Each half must hold exactly one transition; otherwise the
    # _get_transition(0) reads below would not be checking real split data.
    self.assertEqual(dataset_train.size, 1)
    self.assertEqual(dataset_test.size, 1)

    s_train, a_train, p_train, r_train, s2_train, t_train, _, _, _ = (
        dataset_train._get_transition(0))
    s_test, a_test, p_test, r_test, s2_test, t_test, _, _, _ = (
        dataset_test._get_transition(0))

    # Which original transition ended up in which half is not fixed by the
    # test, so recover it from the stored action.
    # NOTE(review): this assumes transitions[0] and transitions[1] have
    # different actions -- confirm against the fixture.
    if a_train == transitions[0].a:
        trans_train, trans_test = transitions[0], transitions[1]
    else:
        trans_train, trans_test = transitions[1], transitions[0]

    # NOTE(review): s2_train / s2_test are unpacked but not compared; add a
    # check once the transition fixture's next-state field name is confirmed.
    assert_sequence_almost_equal(self, s_train, trans_train.s)
    self.assertEqual(a_train, trans_train.a)
    assert_sequence_almost_equal(self, p_train, trans_train.p)
    self.assertAlmostEqual(r_train, trans_train.r)
    self.assertEqual(t_train, trans_train.t)

    assert_sequence_almost_equal(self, s_test, trans_test.s)
    self.assertEqual(a_test, trans_test.a)
    assert_sequence_almost_equal(self, p_test, trans_test.p)
    self.assertAlmostEqual(r_test, trans_test.r)
    self.assertEqual(t_test, trans_test.t)
class TestSplittingDataset(TestCase):
    """Tests for ``Dataset_Counts.train_validation_split`` on a 100-transition dataset."""

    def setUp(self):
        """Build a dataset of ``dataset_size`` transitions sampled with replacement from ``transitions``."""
        self.dataset = Dataset_Counts(state_shape=[2], nb_actions=3, count_param=0.2)
        self.dataset_size = 100
        for i in np.random.randint(0, len(transitions), self.dataset_size):
            self.dataset.add(*transitions[i])

    def test_empty_test(self):
        """``test_size=0`` puts every transition in the training split."""
        dataset_train, dataset_test = self.dataset.train_validation_split(
            test_size=0)
        self.assertEqual(dataset_test.size, 0)
        self.assertEqual(dataset_train.size, self.dataset_size)

    def test_default(self):
        """With no arguments, 20 of the 100 transitions go to the test split."""
        dataset_train, dataset_test = self.dataset.train_validation_split()
        self.assertEqual(dataset_test.size, 20)
        self.assertEqual(dataset_train.size, 80)

    def test_empty_train(self):
        """A test fraction of 1 puts every transition in the test split."""
        dataset_train, dataset_test = self.dataset.train_validation_split(1)
        self.assertEqual(dataset_test.size, self.dataset_size)
        self.assertEqual(dataset_train.size, 0)

    def test_original_data_set_does_not_change(self):
        """Splitting must leave the source dataset's size and every stored transition unchanged."""
        import copy

        # Deep-copy the snapshot. If _get_transition returns views into the
        # dataset's internal buffers (not visible here), an in-place mutation
        # by the split would also change un-copied "before" values and the
        # comparison below would pass vacuously.
        before = [copy.deepcopy(self.dataset._get_transition(i))
                  for i in range(self.dataset_size)]

        _, _ = self.dataset.train_validation_split(np.random.rand())

        self.assertEqual(self.dataset.size, self.dataset_size)
        # Check every transition, not just one random index, so corruption
        # of any entry is caught.
        for i, (s, a, p1, r, s2, t, c, p2, c1) in enumerate(before):
            with self.subTest(index=i):
                (new_s, new_a, new_p1, new_r, new_s2, new_t,
                 new_c, new_p2, new_c1) = self.dataset._get_transition(i)
                assert_sequence_almost_equal(self, new_s, s)
                self.assertEqual(new_a, a)
                assert_sequence_almost_equal(self, new_p1, p1)
                self.assertAlmostEqual(new_r, r)
                assert_sequence_almost_equal(self, new_s2, s2)
                self.assertEqual(new_t, t)
                assert_sequence_almost_equal(self, new_c, c)
                assert_sequence_almost_equal(self, new_p2, p2)
                self.assertAlmostEqual(new_c1, c1)