def __init__(self,
                 width,
                 height,
                 transitions: Transitions = Grid4Transitions([]),
                 random_seed=None):
        """
        Builder for GridTransitionMap object.

        Parameters
        ----------
        width : int
            Width of the grid.
        height : int
            Height of the grid.
        transitions : Transitions object
            The Transitions object to use to encode/decode transitions over the
            grid. The default ``Grid4Transitions([])`` instance is created once,
            when this function is defined, so every map built without an
            explicit ``transitions`` argument shares that same object.
        random_seed : optional
            Seed for this map's private ``np.random.RandomState``. When None
            (the default) the generator is seeded with the fixed value 12, so
            it is deterministic rather than randomly seeded.

        """

        self.width = width
        self.height = height
        self.transitions = transitions
        # None falls back to a fixed seed (12), never to OS entropy.
        seed = 12 if random_seed is None else random_seed
        self.random_generator = np.random.RandomState(seed)
        # One packed transition value per cell, stored row-major as
        # (height, width) with the dtype the Transitions encoder asks for.
        self.grid = np.zeros((height, width),
                             dtype=self.transitions.get_type())
예제 #2
0
    def __init__(self,
                 width,
                 height,
                 transitions: Transitions = Grid4Transitions([])):
        """
        Create an empty grid map that encodes cell transitions.

        Parameters
        ----------
        width : int
            Number of columns in the grid.
        height : int
            Number of rows in the grid.
        transitions : Transitions object
            Encoder/decoder used for the transition value stored in each cell;
            it also supplies the dtype of the backing array.

        """
        self.width, self.height = width, height
        self.transitions = transitions

        # Every cell starts with no transitions (all zeros).
        cell_dtype = self.transitions.get_type()
        grid_shape = (height, width)
        self.grid = np.zeros(grid_shape, dtype=cell_dtype)
예제 #3
0
def test_grid4_get_transitions():
    """Set and clear individual transition bits on one Grid4 cell and check
    both the per-orientation view and the packed full-transition value."""
    grid4_map = GridTransitionMap(2, 2, Grid4Transitions([]))
    orientations = (Grid4TransitionsEnum.NORTH, Grid4TransitionsEnum.EAST,
                    Grid4TransitionsEnum.SOUTH, Grid4TransitionsEnum.WEST)
    no_transitions = (0, 0, 0, 0)

    def check_cell(north_row, full_value):
        # Only the NORTH-facing row is ever modified below, so the other
        # three orientations must always read back as all zeros.
        expected_rows = (north_row, no_transitions, no_transitions,
                         no_transitions)
        for orientation, expected in zip(orientations, expected_rows):
            assert grid4_map.get_transitions(0, 0, orientation) == expected
        assert grid4_map.get_full_transitions(0, 0) == full_value

    # A freshly built cell has no transitions at all.
    check_cell((0, 0, 0, 0), 0)

    # NORTH -> NORTH: the most significant bit is on.
    grid4_map.set_transition((0, 0, Grid4TransitionsEnum.NORTH),
                             Grid4TransitionsEnum.NORTH, 1)
    check_cell((1, 0, 0, 0), 2 ** 15)

    # Adding NORTH -> WEST: the most significant and the fourth most
    # significant bits are on.
    grid4_map.set_transition((0, 0, Grid4TransitionsEnum.NORTH),
                             Grid4TransitionsEnum.WEST, 1)
    check_cell((1, 0, 0, 1), 2 ** 15 + 2 ** 12)

    # Clearing NORTH -> NORTH: only the fourth most significant bit is on.
    grid4_map.set_transition((0, 0, Grid4TransitionsEnum.NORTH),
                             Grid4TransitionsEnum.NORTH, 0)
    check_cell((0, 0, 0, 1), 2 ** 12)
예제 #4
0
    def get_valid_directions_on_grid(self, row: int, col: int) -> List[int]:
        """
        Return the directions an agent can move in from cell (row, col).

        The cell's packed transition value is read from ``self.rail`` and
        decoded with ``Grid4Transitions.get_entry_directions``.

        Parameters
        ----------
        row : int
            Row index of the cell.
        col : int
            Column index of the cell.

        Returns
        -------
        List[int]
            Whatever ``Grid4Transitions.get_entry_directions`` yields for the cell.
        """
        cell_transitions = self.rail.get_full_transitions(row, col)
        return Grid4Transitions.get_entry_directions(cell_transitions)
def test_rail_env_remove_deadend():
    """Check ``Grid4Transitions.remove_deadends`` on packed 16-bit cell values.

    Bit patterns are written in groups of four and passed through ``rw``
    before being parsed as base-2 integers (presumably ``rw`` strips the
    grouping spaces -- confirm against its definition).
    """
    ret = Grid4Transitions([])
    # Each value has exactly one bit set and is treated as a dead end;
    # removing dead ends from any of them must leave nothing.
    rail_env_deadends = {
        int(rw('0010 0000 0000 0000'), 2),
        int(rw('0000 0001 0000 0000'), 2),
        int(rw('0000 0000 1000 0000'), 2),
        int(rw('0000 0000 0000 0100'), 2),
    }
    expected_has_deadend = 0
    for t in rail_env_deadends:
        actual_has_deadend = ret.remove_deadends(t)
        # All three placeholders get an argument, so a failure reports the
        # mismatch instead of raising IndexError while building the message.
        assert actual_has_deadend == expected_has_deadend, \
            "{} should be deadend = {}, actual = {}".format(
                t, expected_has_deadend, actual_has_deadend)

    # A cell made only of the four dead-end bits becomes empty.
    assert ret.remove_deadends(int(rw('0010 0001 1000 0100'), 2)) == 0
    # The extra non-dead-end bit survives the removal.
    assert ret.remove_deadends(int(rw('0010 0001 1000 0110'),
                                   2)) == int(rw('0000 0000 0000 0010'), 2)
예제 #6
0
 def _assert(transition, expected):
     """Assert that the entry directions decoded from ``transition`` equal ``expected``."""
     found = Grid4Transitions.get_entry_directions(transition)
     assert found == expected, "Found {}, expected {}.".format(found, expected)