Exemplo n.º 1
0
class TestPlanning(unittest.TestCase):
    """Tests for the planning algorithms on a small n x n GridWorld.

    Covers the transition matrix induced by a fixed policy, exact policy
    evaluation, and the policies produced by policy iteration, value
    iteration and modified policy iteration under two discount factors.
    """

    def setUp(self):
        """Build the shared grid world, fixed policies and expected results."""
        self.n = 5
        # Second GridWorld argument; presumably the probability that an
        # action succeeds (1 => deterministic moves) -- TODO confirm.
        self.p = 1
        self.gridworld = GridWorld(self.n, self.p)
        # Action index 1 in every state. test_transition_matrix expects this
        # policy to move each non-terminal state one column to the right.
        self.go_right_policy = np.ones(self.n * self.n, dtype=int)
        self.discount = 0.9
        # "Large" means heavy discounting: this factor is *smaller* than
        # self.discount, so far-away rewards are worth much less.
        self.large_discount = 0.2
        # Expected optimal policy under self.discount, as an n x n grid.
        self.policy = np.array(
            [['TERMINAL', 'RIGHT', 'RIGHT', 'RIGHT', 'TERMINAL'],
             ['RIGHT', 'RIGHT', 'RIGHT', 'RIGHT', 'UP'],
             ['RIGHT', 'RIGHT', 'RIGHT', 'RIGHT', 'UP'],
             ['RIGHT', 'RIGHT', 'RIGHT', 'RIGHT', 'UP'],
             ['RIGHT', 'RIGHT', 'RIGHT', 'RIGHT', 'UP']])
        # Expected optimal policy under self.large_discount: the two
        # left-most columns now head for the (0, 0) corner instead of
        # travelling to the far (0, n - 1) corner.
        self.policy_large_discount = np.array(
            [['TERMINAL', 'LEFT', 'RIGHT', 'RIGHT', 'TERMINAL'],
             ['UP', 'LEFT', 'RIGHT', 'RIGHT', 'UP'],
             ['UP', 'LEFT', 'RIGHT', 'RIGHT', 'UP'],
             ['UP', 'LEFT', 'RIGHT', 'RIGHT', 'UP'],
             ['UP', 'LEFT', 'RIGHT', 'RIGHT', 'UP']])

    def _assert_policy_matches(self, policy, expected):
        """Assert that an integer-action policy equals an n x n grid of names.

        Not prefixed with ``test`` so the unittest loader does not collect it.
        """
        actual = int_policy_to_str_policy(policy).reshape(self.n, self.n)
        assert_array_equal(expected, actual)

    def test_transition_matrix(self):
        """The go-right policy shifts each non-terminal state one column right.

        Moves are clamped at the right edge, so last-column states map to
        themselves. The terminal states (0, 0) and (0, n - 1) are expected to
        have an all-zero row. Rows and columns follow row-major state order.
        """
        last = self.n - 1
        # Loop-invariant: build the terminal set once, not once per cell.
        terminal_states = {(0, 0), (0, last)}
        transition_rows = []
        for i in range(self.n):
            for j in range(self.n):
                transition_row = np.zeros((self.n, self.n))
                if (i, j) not in terminal_states:
                    transition_row[i, min(j + 1, last)] = 1
                transition_rows.append(transition_row.flatten())
        expected = np.vstack(transition_rows)

        actual = policy_transition_matrix(self.go_right_policy, self.gridworld)
        assert_array_equal(expected, actual)

    def test_full_policy_eval(self):
        """Exact evaluation of the go-right policy.

        Along the top row the expected value is 10 at (0, n - 1) and decays
        by one factor of self.discount per column of distance from it.
        (0, 0) is expected to be 1. All other states are expected to be 0,
        presumably because they carry no reward, and going right from rows
        1..n-1 never reaches a terminal state.
        """
        transition = policy_transition_matrix(
            self.go_right_policy, self.gridworld)
        reward = self.gridworld.get_rewards()
        actual = full_policy_evaluation(transition, reward, self.discount)

        expected = np.zeros((self.n, self.n))
        expected[0, :] = [10 * self.discount ** (self.n - 1 - i)
                          for i in range(self.n)]
        expected[0, 0] = 1

        assert_array_almost_equal(expected.flatten(), actual)

    def test_policy_iteration(self):
        """Policy iteration recovers the expected policy under self.discount."""
        policy = policy_iteration(self.gridworld, self.discount)
        self._assert_policy_matches(policy, self.policy)

    def test_policy_iteration_more_discount(self):
        """Policy iteration under the heavier discount (self.large_discount)."""
        policy = policy_iteration(self.gridworld, self.large_discount)
        self._assert_policy_matches(policy, self.policy_large_discount)

    def test_value_iteration(self):
        """Value iteration recovers the expected policy under self.discount."""
        policy = value_iteration(self.gridworld, self.discount)
        self._assert_policy_matches(policy, self.policy)

    def test_value_iteration_more_discount(self):
        """Value iteration under the heavier discount (self.large_discount)."""
        policy = value_iteration(self.gridworld, self.large_discount)
        self._assert_policy_matches(policy, self.policy_large_discount)

    def test_modified_policy_iteration(self):
        """Modified policy iteration (3 evaluation sweeps), self.discount."""
        policy = modified_policy_iteration(self.gridworld, self.discount,
                                           num_eval_iters=3)
        self._assert_policy_matches(policy, self.policy)

    def test_modified_policy_iteration_more_discount(self):
        """Modified policy iteration (3 sweeps), self.large_discount."""
        policy = modified_policy_iteration(self.gridworld, self.large_discount,
                                           num_eval_iters=3)
        self._assert_policy_matches(policy, self.policy_large_discount)