def test_valid_railenv_transitions():
    """Check the get/set/rotate helpers of RailEnvTransitions on small cells.

    Direction indices used throughout: N=0, E=1, S=2, W=3.
    """
    trans = RailEnvTransitions()

    # Headings 0 and 1 share one exit pattern, headings 2 and 3 the other.
    sample_cell = int('1100110000110011', 2)
    for heading in (0, 1):
        assert trans.get_transitions(sample_cell, heading) == (1, 1, 0, 0)
        assert trans.get_transitions(sample_cell, heading + 2) == (0, 0, 1, 1)

    # A cell with no bits set reports no exits for any heading.
    empty_cell = int('0000000000000000', 2)
    for heading in (0, 1, 2, 3):
        assert trans.get_transitions(empty_cell, heading) == (0, 0, 0, 0)

    # Facing south, going south.
    ns_cell = trans.set_transitions(empty_cell, 2, (0, 0, 1, 0))
    # Switching that single transition off again yields the empty cell.
    cleared_cell = trans.set_transition(ns_cell, 2, 2, 0)
    assert cleared_cell == empty_cell
    assert trans.get_transition(ns_cell, 2, 2)

    # Facing north, going east.
    se_cell = trans.set_transition(empty_cell, 0, 1, 1)
    assert trans.get_transition(se_cell, 0, 1)

    # The opposite transitions are not feasible.
    assert not trans.get_transition(ns_cell, 2, 0)
    assert not trans.get_transition(se_cell, 2, 1)

    ew_cell = trans.rotate_transition(ns_cell, 90)
    nw_cell = trans.rotate_transition(se_cell, 180)

    # Facing west, going west.
    assert trans.get_transition(ew_cell, 3, 3)
    # Facing south, going west.
    assert trans.get_transition(nw_cell, 2, 3)

    # Rotating by a full turn must leave the cell unchanged.
    assert se_cell == trans.rotate_transition(se_cell, 360)
def test_rail_environment_single_agent():
    """Drive one randomly-acting agent around a hand-built 3x3 rail map."""
    # The map instantiated on the 3x3 grid:
    #  _  _
    # / \/ \
    # | |  |
    # \_/\_/

    transitions = RailEnvTransitions()
    cell_types = transitions.transition_list
    straight_vertical = cell_types[1]
    switch_south = cell_types[6]
    switch_north = transitions.rotate_transition(switch_south, 180)
    turn_se = int('0100000000000010', 2)
    turn_sw = transitions.rotate_transition(turn_se, 90)
    turn_ne = transitions.rotate_transition(turn_se, 270)
    turn_nw = transitions.rotate_transition(turn_se, 180)

    top_row = [turn_se, switch_south, turn_sw]
    middle_row = [straight_vertical, straight_vertical, straight_vertical]
    bottom_row = [turn_ne, switch_north, turn_nw]
    rail_map = np.array([top_row, middle_row, bottom_row], dtype=np.uint16)

    rail = GridTransitionMap(width=3, height=3, transitions=transitions)
    rail.grid = rail_map
    rail_env = RailEnv(
        width=3,
        height=3,
        rail_generator=rail_from_grid_transition_map(rail),
        schedule_generator=random_schedule_generator(),
        number_of_agents=1,
        obs_builder_object=GlobalObsForRailEnv(),
    )

    for _ in range(200):
        rail_env.reset(False, False, True)

        # The target is irrelevant for this part of the test.
        agent = rail_env.agents[0]
        agent.target = [-1, -1]

        # After a reset the train must sit on a cell, with a heading, from
        # which at least one move is possible.
        exits = transitions.get_transitions(rail_map[agent.position],
                                            agent.direction)
        assert exits != (0, 0, 0, 0)

        start_pos = agent.position

        moves_made = 0
        current_pos = start_pos
        while moves_made < 6:
            # Pick a random action; only steps that change the position count.
            action = np.random.randint(4)

            _, _, _, _ = rail_env.step({0: action})

            previous_pos = current_pos
            current_pos = agent.position
            if previous_pos != current_pos:
                moves_made += 1

        # Six movements on this network bring the train back to the row it
        # started on.
        assert start_pos[0] == agent.position[0]

        # The train must always reach its target eventually.
        for _ in range(10):
            rail_env.reset()

            while True:
                action = np.random.randint(4)

                _, _, dones, _ = rail_env.step({0: action})
                if dones['__all__']:
                    break
    def generator(width: int,
                  height: int,
                  num_agents: int,
                  num_resets: int = 0,
                  np_random: RandomState = None) -> RailGenerator:
        """
        Build a random rail grid by filling interior cells one at a time with a
        transition whose open sides agree with the already-placed neighbours,
        then setting every border cell to either a dead-end or an empty cell.

        Reads `cell_type_relative_proportion` from the enclosing scope (not
        visible in this block); it is indexed in parallel with
        `t_utils.transitions` to weight the random choice of cell types.

        Parameters:
        ------
        width:int
            Number of columns of the grid.
        height:int
            Number of rows of the grid.
        num_agents:int
            Not used in this function body.
        num_resets:int
            Not used in this function body.
        np_random:RandomState
            Random source; `np_random.choice` is called on it, so despite the
            `None` default it must be supplied whenever there is an interior
            cell to fill.

        Returns:
        ------
        Tuple (GridTransitionMap, None)
            The generated map and `None` as second element.
            NOTE(review): the annotation says `RailGenerator` but a tuple is
            returned -- confirm which is intended.
        """
        t_utils = RailEnvTransitions()

        transition_probability = cell_type_relative_proportion

        # Parallel lists: transitions_templates_[n] is a pair
        # (4-entry 0/1 template, 16-bit transition value) and
        # transition_probabilities[n] is the weight of that entry.
        transitions_templates_ = []
        transition_probabilities = []
        for i in range(len(t_utils.transitions)):  # don't include dead-ends
            if t_utils.transitions[i] == int('0010000000000000', 2):
                continue

            # OR together the 4-tuples returned for every facing direction and
            # pack them into a 4-bit value (tuple index 0 -> most significant
            # bit), i.e. "is side d used by this cell type at all".
            all_transitions = 0
            for dir_ in range(4):
                trans = t_utils.get_transitions(t_utils.transitions[i], dir_)
                all_transitions |= (trans[0] << 3) | \
                                   (trans[1] << 2) | \
                                   (trans[2] << 1) | \
                                   (trans[3])

            # Expand the 4-bit value into a list of four 0/1 ints, left-padded
            # with zeros.
            template = [int(x) for x in bin(all_transitions)[2:]]
            template = [0] * (4 - len(template)) + template

            # add all rotations
            for rot in [0, 90, 180, 270]:
                transitions_templates_.append(
                    (template,
                     t_utils.rotate_transition(t_utils.transitions[i], rot)))
                transition_probabilities.append(transition_probability[i])
                # Each further 90 degree step shifts the template cyclically
                # one position to the right (a new list is built, so the entry
                # appended above is not modified).
                template = [template[-1]] + template[:-1]

        def get_matching_templates(template):
            """
            Returns the candidate cells compatible with a given template

            Parameters:
            ------
            template:List[int]
                Four entries; a negative entry means "unconstrained", any other
                value must equal the stored template entry at that index.

            Returns:
            ------
            List[Tuple[int, probability]]
                One (transition value, weight) pair per stored template that
                matches every constrained entry.
            """
            ret = []
            for i in range(len(transitions_templates_)):
                is_match = True
                for j in range(4):
                    if template[j] >= 0 and template[
                            j] != transitions_templates_[i][0][j]:
                        is_match = False
                        break
                if is_match:
                    ret.append((transitions_templates_[i][1],
                                transition_probabilities[i]))
            return ret

        # Cap on placements per attempt: ten times the number of interior cells.
        MAX_INSERTIONS = (width - 2) * (height - 2) * 10
        MAX_ATTEMPTS_FROM_SCRATCH = 10

        attempt_number = 0
        while attempt_number < MAX_ATTEMPTS_FROM_SCRATCH:
            # Start from an all-None grid; only the interior cells (border rows
            # and columns excluded) are queued for filling.
            cells_to_fill = []
            rail = []
            for r in range(height):
                rail.append([None] * width)
                if r > 0 and r < height - 1:
                    cells_to_fill = cells_to_fill + [
                        (r, c) for c in range(1, width - 1)
                    ]

            num_insertions = 0
            while num_insertions < MAX_INSERTIONS and len(cells_to_fill) > 0:
                # Pick one of the pending cells at random and take it off the
                # queue.
                cell = cells_to_fill[np_random.choice(len(cells_to_fill),
                                                      1)[0]]
                cells_to_fill.remove(cell)
                row = cell[0]
                col = cell[1]

                # look at its neighbors and see what are the possible transitions
                # that can be chosen from, if any.
                # -1 = no constraint yet, 1 = that side must be connected,
                # 0 = that side must stay closed.
                valid_template = [-1, -1, -1, -1]

                # Each entry: (index into valid_template, direction looked up
                # in the neighbour -- the opposite one --, (row, col) offset of
                # the neighbour).
                for el in [(0, 2, (-1, 0)), (1, 3, (0, 1)), (2, 0, (1, 0)),
                           (3, 1, (0, -1))]:  # N, E, S, W
                    neigh_trans = rail[row + el[2][0]][col + el[2][1]]
                    if neigh_trans is not None:
                        # Does the neighbour allow, from any facing direction
                        # k, a move towards direction el[1] (i.e. towards the
                        # current cell)?
                        max_bit = 0
                        for k in range(4):
                            max_bit |= t_utils.get_transition(
                                neigh_trans, k, el[1])

                        if max_bit:
                            valid_template[el[0]] = 1
                        else:
                            valid_template[el[0]] = 0

                possible_cell_transitions = get_matching_templates(
                    valid_template)

                if len(possible_cell_transitions) == 0:  # NO VALID TRANSITIONS
                    # no cell can be filled in without violating some transitions
                    # can a dead-end solve the problem?
                    if valid_template.count(1) == 1:
                        for k in range(4):
                            if valid_template[k] == 1:
                                # Rotation applied to the base dead-end so that
                                # it connects to side k.
                                rot = 0
                                if k == 0:
                                    rot = 180
                                elif k == 1:
                                    rot = 270
                                elif k == 2:
                                    rot = 0
                                elif k == 3:
                                    rot = 90

                                rail[row][col] = t_utils.rotate_transition(
                                    int('0010000000000000', 2), rot)
                                num_insertions += 1

                                break

                    else:
                        # can I get valid transitions by removing a single
                        # neighboring cell?
                        # Drop one constraint at a time and keep the side that
                        # yields the most candidates (first one wins on ties).
                        bestk = -1
                        besttrans = []
                        for k in range(4):
                            tmp_template = valid_template[:]
                            tmp_template[k] = -1
                            possible_cell_transitions = get_matching_templates(
                                tmp_template)
                            if len(possible_cell_transitions) > len(besttrans):
                                besttrans = possible_cell_transitions
                                bestk = k

                        if bestk >= 0:
                            # Replace the corresponding cell with None, append it
                            # to cells to fill, fill in a transition in the current
                            # cell.
                            # Default coordinates are for bestk == 0 (north).
                            replace_row = row - 1
                            replace_col = col
                            if bestk == 1:
                                replace_row = row
                                replace_col = col + 1
                            elif bestk == 2:
                                replace_row = row + 1
                                replace_col = col
                            elif bestk == 3:
                                replace_row = row
                                replace_col = col - 1

                            cells_to_fill.append((replace_row, replace_col))
                            rail[replace_row][replace_col] = None

                            # Split the (transition, weight) pairs and normalise
                            # the weights so they sum to 1.
                            possible_transitions, possible_probabilities = zip(
                                *besttrans)
                            possible_probabilities = [
                                p / sum(possible_probabilities)
                                for p in possible_probabilities
                            ]

                            rail[row][col] = np_random.choice(
                                possible_transitions, p=possible_probabilities)
                            num_insertions += 1

                        else:
                            # Nothing fits even after relaxing one side: leave
                            # the cell empty.
                            print('WARNING: still nothing!')
                            rail[row][col] = int('0000000000000000', 2)
                            num_insertions += 1
                            pass

                else:
                    # At least one candidate fits: draw one according to the
                    # normalised weights.
                    possible_transitions, possible_probabilities = zip(
                        *possible_cell_transitions)
                    possible_probabilities = [
                        p / sum(possible_probabilities)
                        for p in possible_probabilities
                    ]

                    rail[row][col] = np_random.choice(possible_transitions,
                                                      p=possible_probabilities)
                    num_insertions += 1

            if num_insertions == MAX_INSERTIONS:
                # Failed to generate a valid level; try again for a number of times
                attempt_number += 1
            else:
                break

        if attempt_number == MAX_ATTEMPTS_FROM_SCRATCH:
            # Only reported; the code below still runs on the last grid built.
            print('ERROR: failed to generate level')

        # Finally pad the border of the map with dead-ends to avoid border issues;
        # at most 1 transition in the neigh cell
        # In the loops below, (neigh_trans >> ((3 - k) * 4)) & 15 extracts the
        # k-th 4-bit group of the 16-bit cell value (k = 0 is the most
        # significant group); the mask then tests one bit of that group.
        for r in range(height):
            # Check for transitions coming from [r][1] to WEST
            max_bit = 0
            neigh_trans = rail[r][1]
            if neigh_trans is not None:
                for k in range(4):
                    neigh_trans_from_direction = (neigh_trans >>
                                                  ((3 - k) * 4)) & (2**4 - 1)
                    max_bit = max_bit | (neigh_trans_from_direction & 1)
            if max_bit:
                rail[r][0] = t_utils.rotate_transition(
                    int('0010000000000000', 2), 270)
            else:
                rail[r][0] = int('0000000000000000', 2)

            # Check for transitions coming from [r][-2] to EAST
            max_bit = 0
            neigh_trans = rail[r][-2]
            if neigh_trans is not None:
                for k in range(4):
                    neigh_trans_from_direction = (neigh_trans >>
                                                  ((3 - k) * 4)) & (2**4 - 1)
                    max_bit = max_bit | (neigh_trans_from_direction & (1 << 2))
            if max_bit:
                rail[r][-1] = t_utils.rotate_transition(
                    int('0010000000000000', 2), 90)
            else:
                rail[r][-1] = int('0000000000000000', 2)

        for c in range(width):
            # Check for transitions coming from [1][c] to NORTH
            max_bit = 0
            neigh_trans = rail[1][c]
            if neigh_trans is not None:
                for k in range(4):
                    neigh_trans_from_direction = (neigh_trans >>
                                                  ((3 - k) * 4)) & (2**4 - 1)
                    max_bit = max_bit | (neigh_trans_from_direction & (1 << 3))
            if max_bit:
                rail[0][c] = int('0010000000000000', 2)
            else:
                rail[0][c] = int('0000000000000000', 2)

            # Check for transitions coming from [-2][c] to SOUTH
            max_bit = 0
            neigh_trans = rail[-2][c]
            if neigh_trans is not None:
                for k in range(4):
                    neigh_trans_from_direction = (neigh_trans >>
                                                  ((3 - k) * 4)) & (2**4 - 1)
                    max_bit = max_bit | (neigh_trans_from_direction & (1 << 1))
            if max_bit:
                rail[-1][c] = t_utils.rotate_transition(
                    int('0010000000000000', 2), 180)
            else:
                rail[-1][c] = int('0000000000000000', 2)

        # For display only, wrong levels
        # Any cell still None (e.g. after a failed attempt) becomes empty so
        # the grid can be converted to a numeric array.
        for r in range(height):
            for c in range(width):
                if rail[r][c] is None:
                    rail[r][c] = int('0000000000000000', 2)

        tmp_rail = np.asarray(rail, dtype=np.uint16)

        return_rail = GridTransitionMap(width=width,
                                        height=height,
                                        transitions=t_utils)
        return_rail.grid = tmp_rail

        return return_rail, None
# Example no. 4
# 0
def test_rail_environment_single_agent(show=False):
    """Smoke-test one randomly-acting agent on the saved ``test_env_loop.pkl`` map.

    The environment is loaded from the packaged ``env_data.tests`` resource.
    An earlier in-code construction of a 3x3 two-loop grid sat behind
    ``if False:`` because it "doesn't quite work right"; being unreachable,
    it has been removed.

    Two things are checked, each bounded by a 100-step limit so that a stuck
    agent fails the test instead of hanging it:

    * after each of 5 resets the agent changes position six times under
      random actions;
    * in 10 further episodes ``dones['__all__']`` eventually becomes true.

    Args:
        show: Forwarded to the renderer; when true the test also prints
            per-reset and per-step debugging information.
    """
    transitions = RailEnvTransitions()

    # The second element returned by the loader is not needed here.
    rail_env, _ = RailEnvPersister.load_new("test_env_loop.pkl", "env_data.tests")
    rail_map = rail_env.rail.grid

    rail_env._max_episode_steps = 1000

    rail_env.reset(False, False, True)

    action_values = [int(a) for a in RailEnvActions]

    env_renderer = RenderTool(rail_env)

    for _ in range(5):
        rail_env.reset(False, False, True)

        # We do not care about the target for the moment.
        agent = rail_env.agents[0]
        agent.target = [-1, -1]

        if show:
            print("After reset - agent pos:", agent.position, "dir: ", agent.direction)
            print(transitions.get_transitions(rail_map[agent.position], agent.direction))

        # NOTE: the former check that the agent starts on a cell with at least
        # one possible transition is disabled; instead the direction is forced.
        # HACK: force the direction to one we know is good.
        agent.initial_direction = agent.direction = 0

        if show:
            print("handle:", agent.handle)

        moves_done = 0
        pos = agent.position

        if show:
            env_renderer.render_env(show=show, show_agents=True)
            time.sleep(0.01)

        step_count = 0
        while moves_done < 6:
            # Pick a random action; only steps that change the position count.
            action = np.random.choice(action_values)

            _, _, dict_done, _ = rail_env.step({0: action})

            prev_pos = pos
            pos = agent.position

            # Debug output is gated on `show`, like every other diagnostic here.
            if show:
                print("action:", action, "pos:", agent.position, "prev:", prev_pos, agent.direction)
                print(dict_done)
            if prev_pos != pos:
                moves_done += 1
            step_count += 1

            if show:
                env_renderer.render_env(show=show, show_agents=True, step=step_count)
                time.sleep(0.01)
            assert step_count < 100, "valid actions should have been performed by now - hung agent"

        # NOTE: the former check that six movements bring the train back to its
        # original row is disabled for this saved map.

        # We check that the train always attains its target after some time.
        for _ in range(10):
            rail_env.reset()

            rail_env.agents[0].direction = 0

            step_count = 0
            while step_count < 100:
                action = np.random.choice(action_values)

                _, _, dones, _ = rail_env.step({0: action})
                if dones['__all__']:
                    break
                step_count += 1
                assert step_count < 100, "agent should have finished by now"
                env_renderer.render_env(show=show)