def __init__(self, width, height, transitions: Transitions = None, random_seed=None):
    """
    Builder for GridTransitionMap object.

    Parameters
    ----------
    width : int
        Width of the grid.
    height : int
        Height of the grid.
    transitions : Transitions object, optional
        The Transitions object to use to encode/decode transitions over the
        grid. When omitted (or None), a fresh ``Grid4Transitions([])`` is
        created for this instance.
    random_seed : int, optional
        Seed for this instance's ``np.random.RandomState``. When None, the
        generator is seeded with 12, so default construction is
        deterministic.
    """
    # Build the default per call rather than sharing a single instance that
    # was created once at function-definition time across every map.
    if transitions is None:
        transitions = Grid4Transitions([])

    self.width = width
    self.height = height
    self.transitions = transitions

    # Fall back to a fixed seed so runs without an explicit seed repeat.
    seed = 12 if random_seed is None else random_seed
    self.random_generator = np.random.RandomState(seed)

    # One cell per (row, column); the cell dtype is dictated by the
    # transitions encoding.
    self.grid = np.zeros((height, width), dtype=self.transitions.get_type())
def __init__(self, width, height, transitions: Transitions = None):
    """
    Builder for GridTransitionMap object.

    Parameters
    ----------
    width : int
        Width of the grid.
    height : int
        Height of the grid.
    transitions : Transitions object, optional
        The Transitions object to use to encode/decode transitions over the
        grid. When omitted (or None), a fresh ``Grid4Transitions([])`` is
        created for this instance.
    """
    # Build the default per call rather than sharing a single instance that
    # was created once at function-definition time across every map.
    if transitions is None:
        transitions = Grid4Transitions([])

    self.width = width
    self.height = height
    self.transitions = transitions

    # One cell per (row, column); the cell dtype is dictated by the
    # transitions encoding.
    self.grid = np.zeros((height, width), dtype=self.transitions.get_type())
def test_grid4_get_transitions():
    grid4_map = GridTransitionMap(2, 2, Grid4Transitions([]))
    north = Grid4TransitionsEnum.NORTH
    west = Grid4TransitionsEnum.WEST
    other_headings = (
        Grid4TransitionsEnum.EAST,
        Grid4TransitionsEnum.SOUTH,
        Grid4TransitionsEnum.WEST,
    )

    def check_cell(expected_north, expected_full):
        """Check cell (0, 0): the NORTH row, all-zero rows for the other
        headings, and the packed full-transition value."""
        assert grid4_map.get_transitions(0, 0, north) == expected_north
        for heading in other_headings:
            assert grid4_map.get_transitions(0, 0, heading) == (0, 0, 0, 0)
        assert grid4_map.get_full_transitions(0, 0) == expected_full

    # Fresh cell: nothing is set.
    check_cell((0, 0, 0, 0), 0)

    # Allow NORTH -> NORTH: the most significant bit is on.
    grid4_map.set_transition((0, 0, north), north, 1)
    check_cell((1, 0, 0, 0), 2 ** 15)

    # Also allow NORTH -> WEST: the most significant and the fourth most
    # significant bits are on.
    grid4_map.set_transition((0, 0, north), west, 1)
    check_cell((1, 0, 0, 1), 2 ** 15 + 2 ** 12)

    # Revoke NORTH -> NORTH: only the fourth most significant bit is on.
    grid4_map.set_transition((0, 0, north), north, 0)
    check_cell((0, 0, 0, 1), 2 ** 12)
def get_valid_directions_on_grid(self, row: int, col: int) -> List[int]:
    """
    Return the directions in which the agent can move from a grid cell.

    The cell's full transition value is read from ``self.rail`` and decoded
    with ``Grid4Transitions.get_entry_directions``.

    Parameters
    ----------
    row : int
        Row index of the cell.
    col : int
        Column index of the cell.

    Returns
    -------
    List[int]
        Whatever ``Grid4Transitions.get_entry_directions`` yields for the
        cell (presumably one entry per direction -- confirm against that
        helper).
    """
    cell_transitions = self.rail.get_full_transitions(row, col)
    return Grid4Transitions.get_entry_directions(cell_transitions)
def test_rail_env_remove_deadend():
    """Each single dead-end bit pattern is cleared to 0 by remove_deadends,
    and only non-dead-end bits survive when patterns are combined."""
    ret = Grid4Transitions([])
    rail_env_deadends = {
        int(rw('0010 0000 0000 0000'), 2),
        int(rw('0000 0001 0000 0000'), 2),
        int(rw('0000 0000 1000 0000'), 2),
        int(rw('0000 0000 0000 0100'), 2),
    }
    # Removing dead ends from a pure dead-end pattern must leave nothing.
    expected_has_deadend = 0
    for t in rail_env_deadends:
        actual_had_deadend = ret.remove_deadends(t)
        # Supply all three placeholders; the original passed only `t`, so a
        # failing assertion raised IndexError instead of showing this message.
        assert actual_had_deadend == expected_has_deadend, \
            "{} should be deadend = {}, actual = {}".format(
                t, expected_has_deadend, actual_had_deadend)
    # All four dead-end bits together are also cleared entirely.
    assert ret.remove_deadends(int(rw('0010 0001 1000 0100'), 2)) == 0
    # An extra non-dead-end bit is preserved.
    assert ret.remove_deadends(int(rw('0010 0001 1000 0110'), 2)) == int(rw('0000 0000 0000 0010'), 2)
def _assert(transition, expected):
    """Check that ``Grid4Transitions.get_entry_directions(transition)``
    equals ``expected``; on mismatch the message shows both values."""
    found = Grid4Transitions.get_entry_directions(transition)
    message_template = "Found {}, expected {}."
    assert found == expected, message_template.format(found, expected)