コード例 #1
0
ファイル: test_action.py プロジェクト: dohnala/GridWorld
    def test_apply_on_boundary(self):
        state = GridWorld(10, 10)
        state.add_agent(Agent(9, 0))

        next_state = MoveRight().apply(state)

        self.assertEqual(1, next_state.step)
        self.assertEqual(9, next_state.agent.x)
コード例 #2
0
ファイル: test_action.py プロジェクト: dohnala/GridWorld
    def test_apply_boundary(self):
        state = GridWorld(10, 10)
        state.add_agent(Agent(0, 9))

        next_state = MoveUp().apply(state)

        self.assertEqual(1, next_state.step)
        self.assertEqual(9, next_state.agent.y)
コード例 #3
0
ファイル: test_action.py プロジェクト: dohnala/GridWorld
    def test_apply(self):
        state = GridWorld(10, 10)
        state.add_agent(Agent(0, 5))

        next_state = MoveDown().apply(state)

        self.assertEqual(1, next_state.step)
        self.assertEqual(4, next_state.agent.y)
コード例 #4
0
    def test_encode_with_treasure_position_layer(self):
        state = GridWorld(4, 4)
        state.add_object(Treasure(2, 3))

        encoder = OneHotEncoder(4, 4, agent_position=False, treasure_position=True)

        expected = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0]

        assert_array_equal(expected, encoder.encode(state))
コード例 #5
0
ファイル: test_state.py プロジェクト: dohnala/GridWorld
    def test_init(self):
        world = GridWorld(10, 10)

        self.assertEqual(0, world.step)
        self.assertEqual(10, world.width)
        self.assertEqual(10, world.height)

        self.assertIsNone(world.agent)
        self.assertEqual(0, len(world.get_objects()))
        self.assertEqual(0, len(world.get_object_types()))
コード例 #6
0
    def test_encode_state_with_agent_position_layer(self):
        state = GridWorld(4, 4)
        state.add_agent(Agent(0, 0))

        encoder = LayerEncoder(4, 4, agent_position=True, treasure_position=False)

        expected = [[[1, 0, 0, 0],
                     [0, 0, 0, 0],
                     [0, 0, 0, 0],
                     [0, 0, 0, 0]]]

        assert_array_equal(expected, encoder.encode(state))
コード例 #7
0
    def generate_grid_world(self):
        grid_world = GridWorld(self.width, self.height)

        if self.treasure_position:
            grid_world.add_object(Treasure(*self.treasure_position))
        else:
            grid_world.add_object(
                Treasure(*grid_world.get_random_free_position()))

        grid_world.add_agent(Agent(*grid_world.get_random_free_position()))

        return grid_world
コード例 #8
0
ファイル: test_state.py プロジェクト: dohnala/GridWorld
    def test_next_step(self):
        world = GridWorld(10, 10)
        world.add_agent(Agent(0, 0))
        world.add_object(Treasure(8, 8))

        world.next_step()
        self.assertEqual(1, world.step)
コード例 #9
0
ファイル: test_state.py プロジェクト: dohnala/GridWorld
    def test_add_treasure(self):
        world = GridWorld(10, 10)
        world.add_object(Treasure(8, 8))

        self.assertEqual(1, len(world.get_objects()))
        self.assertListEqual([Treasure], world.get_object_types())
        self.assertEqual(1, len(world.get_objects_by_type(Treasure)))
コード例 #10
0
ファイル: test_state.py プロジェクト: dohnala/GridWorld
    def test_agent_is_at_treasure(self):
        world = GridWorld(10, 10)
        world.add_agent(Agent(0, 0))
        world.add_object(Treasure(8, 8))

        self.assertFalse(world.agent.is_at_any_object(world.get_objects_by_type(Treasure)))

        world.agent.x = 8
        world.agent.y = 8

        self.assertTrue(world.agent.is_at_any_object(world.get_objects_by_type(Treasure)))
コード例 #11
0
ファイル: test_state.py プロジェクト: dohnala/GridWorld
    def test_add_agent(self):
        world = GridWorld(10, 10)
        world.add_agent(Agent(0, 0))

        self.assertEqual(0, world.agent.x)
        self.assertEqual(0, world.agent.y)

        self.assertEqual(1, len(world.get_objects()))
        self.assertListEqual([Agent], world.get_object_types())
        self.assertEqual(1, len(world.get_objects_by_type(Agent)))
コード例 #12
0
ファイル: test_state.py プロジェクト: dohnala/GridWorld
    def test_copy(self):
        world = GridWorld(10, 10)
        world.add_agent(Agent(0, 0))
        world.add_object(Treasure(8, 8))

        copy = world.copy()
        copy.step = 1
        copy.width = 5
        copy.height = 5
        copy.agent.x = 1
        copy.agent.y = 2

        copy_treasure = copy.get_object_by_type(Treasure)

        copy_treasure.x = 2
        copy_treasure.y = 3

        self.assertEqual(0, world.step)
        self.assertEqual(10, world.width)
        self.assertEqual(10, world.height)

        self.assertEqual(0, world.agent.x)
        self.assertEqual(0, world.agent.y)

        treasure = world.get_object_by_type(Treasure)

        self.assertIsNotNone(treasure)
        self.assertEqual(8, treasure.x)
        self.assertEqual(8, treasure.y)

        self.assertEqual(1, copy.step)
        self.assertEqual(5, copy.width)
        self.assertEqual(5, copy.height)

        self.assertEqual(1, copy.agent.x)
        self.assertEqual(2, copy.agent.y)

        self.assertEqual(2, copy_treasure.x)
        self.assertEqual(3, copy_treasure.y)
コード例 #13
0
    def test_encode_with_multiple_layers(self):
        state = GridWorld(4, 4)
        state.add_agent(Agent(0, 0))
        state.add_object(Treasure(2, 3))

        encoder = LayerEncoder(4, 4, agent_position=True, treasure_position=True)

        expected = [[[1, 0, 0, 0],
                     [0, 0, 0, 0],
                     [0, 0, 0, 0],
                     [0, 0, 0, 0]],
                    [[0, 0, 0, 0],
                     [0, 0, 0, 0],
                     [0, 0, 0, 1],
                     [0, 0, 0, 0]]]

        assert_array_equal(expected, encoder.encode(state))
コード例 #14
0
    def test_encode_with_no_layers(self):
        state = GridWorld(4, 4)

        encoder = OneHotEncoder(4, 4, agent_position=False, treasure_position=False)

        assert_array_equal(np.empty(0), encoder.encode(state))
コード例 #15
0
ファイル: test_action.py プロジェクト: dohnala/GridWorld
    def test_is_valid_on_boundary(self):
        state = GridWorld(10, 10)
        state.add_agent(Agent(9, 0))

        self.assertFalse(MoveRight().__is_valid__(state))
コード例 #16
0
ファイル: test_action.py プロジェクト: dohnala/GridWorld
    def test_is_valid(self):
        state = GridWorld(10, 10)
        state.add_agent(Agent(0, 0))

        self.assertTrue(MoveRight().__is_valid__(state))