예제 #1
0
    def test_empty_test(self):
        """Split a two-transition dataset in half and verify both halves.

        Despite the name, this test expects both halves to be non-empty:
        ``transitions[0]`` and ``transitions[1]`` are added, the dataset is
        split with ``test_size=0.5``, and entry 0 of each half is compared
        field by field against the input transition it is matched to.

        NOTE(review): the halves are matched to their source transition by
        action alone, so this presumably relies on ``transitions[0].a`` and
        ``transitions[1].a`` being different — confirm against the fixtures.
        """
        dataset = Dataset_Counts(state_shape=[2], nb_actions=3, count_param=0.2)
        for idx in (0, 1):
            dataset.add(*transitions[idx])

        train_part, test_part = dataset.train_validation_split(test_size=0.5)
        # _get_transition yields nine fields; the fifth and the last three
        # are not compared in this test.
        (got_s_tr, got_a_tr, got_p_tr, got_r_tr, _, got_t_tr,
         _, _, _) = train_part._get_transition(0)
        (got_s_te, got_a_te, got_p_te, got_r_te, _, got_t_te,
         _, _, _) = test_part._get_transition(0)

        first, second = transitions[0], transitions[1]
        expected_tr, expected_te = (
            (first, second) if got_a_tr == first.a else (second, first))

        def check(got_s, got_a, got_p, got_r, got_t, expected):
            # Same per-field comparison for both halves.
            assert_sequence_almost_equal(self, got_s, expected.s)
            self.assertEqual(got_a, expected.a)
            assert_sequence_almost_equal(self, got_p, expected.p)
            self.assertAlmostEqual(got_r, expected.r)
            self.assertEqual(got_t, expected.t)

        check(got_s_tr, got_a_tr, got_p_tr, got_r_tr, got_t_tr, expected_tr)
        check(got_s_te, got_a_te, got_p_te, got_r_te, got_t_te, expected_te)
예제 #2
0
class TestSplittingDataset(TestCase):
    """Tests for ``Dataset_Counts.train_validation_split``.

    Each test starts from a dataset of ``dataset_size`` transitions sampled
    (with replacement) from the module-level ``transitions`` fixtures.
    """

    # Fixed seed so the sampled fixture data, and therefore any failure,
    # is reproducible from run to run.
    random_seed = 0

    def setUp(self):
        self.rng = np.random.RandomState(self.random_seed)
        self.dataset = Dataset_Counts(state_shape=[2],
                                      nb_actions=3,
                                      count_param=0.2)
        self.dataset_size = 100
        for i in self.rng.randint(0, len(transitions), self.dataset_size):
            self.dataset.add(*transitions[i])

    def _snapshot(self):
        """Return an independent copy of every transition in the dataset."""
        import copy
        # Deep-copy so the snapshot cannot alias the dataset's own storage
        # (in case _get_transition returns views) and silently follow any
        # mutation made by the split.
        return [copy.deepcopy(self.dataset._get_transition(i))
                for i in range(self.dataset_size)]

    def _assert_transitions_equal(self, actual, expected):
        """Assert two ``_get_transition`` results match field by field."""
        s, a, p1, r, s2, t, c, p2, c1 = expected
        new_s, new_a, new_p1, new_r, new_s2, new_t, new_c, new_p2, new_c1 = actual

        assert_sequence_almost_equal(self, new_s, s)
        self.assertEqual(new_a, a)
        assert_sequence_almost_equal(self, new_p1, p1)
        self.assertAlmostEqual(new_r, r)
        assert_sequence_almost_equal(self, new_s2, s2)
        self.assertEqual(new_t, t)
        assert_sequence_almost_equal(self, new_c, c)
        assert_sequence_almost_equal(self, new_p2, p2)
        self.assertAlmostEqual(new_c1, c1)

    def test_empty_test(self):
        """``test_size=0`` puts every transition in the training split."""
        dataset_train, dataset_test = self.dataset.train_validation_split(
            test_size=0)
        self.assertEqual(dataset_test.size, 0)
        self.assertEqual(dataset_train.size, self.dataset_size)

    def test_default(self):
        """The default split holds out 20% of the data for testing."""
        dataset_train, dataset_test = self.dataset.train_validation_split()
        # These literal values are tied to dataset_size == 100 in setUp.
        self.assertEqual(dataset_test.size, 20)
        self.assertEqual(dataset_train.size, 80)

    def test_empty_train(self):
        """``test_size=1`` puts every transition in the test split."""
        dataset_train, dataset_test = self.dataset.train_validation_split(
            test_size=1)
        self.assertEqual(dataset_test.size, self.dataset_size)
        self.assertEqual(dataset_train.size, 0)

    def test_original_data_set_does_not_change(self):
        """Splitting must leave every transition of the source dataset intact."""
        before = self._snapshot()
        _, _ = self.dataset.train_validation_split(test_size=self.rng.rand())

        self.assertEqual(self.dataset.size, self.dataset_size)
        for expected, actual in zip(before, self._snapshot()):
            self._assert_transitions_equal(actual, expected)