示例#1
0
def test_grid8_set_transitions():
    """Check that setting and clearing one Grid8 transition shows up in `get_transitions`."""
    north = Grid8TransitionsEnum.NORTH
    cell_orientation = (0, 0, north)
    all_closed = (0, 0, 0, 0, 0, 0, 0, 0)
    north_open = (1, 0, 0, 0, 0, 0, 0, 0)

    transition_map = GridTransitionMap(2, 2, Grid8Transitions([]))

    # A freshly created map reports no transitions for the cell.
    assert transition_map.get_transitions(*cell_orientation) == all_closed

    # Enable the NORTH entry for the NORTH orientation of cell (0, 0).
    transition_map.set_transition(cell_orientation, north, 1)
    assert transition_map.get_transitions(*cell_orientation) == north_open

    # Clearing the same entry brings the cell back to its initial state.
    transition_map.set_transition(cell_orientation, north, 0)
    assert transition_map.get_transitions(*cell_orientation) == all_closed
示例#2
0
def get_valid_move_actions_(agent_direction: Grid4TransitionsEnum,
                            agent_position: Tuple[int, int],
                            rail: GridTransitionMap) -> Set[RailEnvNextAction]:
    """
    Collect the move actions (left, forward, right) available to an agent.

    TODO https://gitlab.aicrowd.com/flatland/flatland/issues/299 The implementation could probably be more efficient
    and more elegant. But given the few calls this has no priority now.

    Parameters
    ----------
    agent_direction : Grid4TransitionsEnum
        Direction the agent is currently facing.
    agent_position: Tuple[int,int]
        Cell the agent currently occupies.
    rail : GridTransitionMap
        Map queried for the transitions of the agent's cell.

    Returns
    -------
    Set of `RailEnvNextAction` (tuples of (action,position,direction))
        The available move actions together with the position/direction each one leads to,
        in insertion order left, forward, right.
        Whether the next cell is free is not checked.
    """
    moves: Set[RailEnvNextAction] = OrderedSet()
    transitions = rail.get_transitions(*agent_position, agent_direction)
    open_count = np.count_nonzero(transitions)

    # Dead end: the only candidate is turning around, reported as MOVE_FORWARD.
    if rail.is_dead_end(agent_position):
        turn_around = (agent_direction + 2) % 4
        if transitions[turn_around]:
            moves.add(RailEnvNextAction(RailEnvActions.MOVE_FORWARD,
                                        get_new_position(agent_position, turn_around),
                                        turn_around))
        return moves

    # With exactly one open transition the forward action is aligned with it,
    # whichever way it actually points.
    single_exit = open_count == 1

    # Candidates are visited in the order left, forward, right.
    for offset in (-1, 0, 1):
        candidate = (agent_direction + offset) % 4
        if not transitions[candidate]:
            continue

        if single_exit:
            action = RailEnvActions.MOVE_FORWARD
        elif candidate == agent_direction:
            action = RailEnvActions.MOVE_FORWARD
        elif candidate == (agent_direction + 1) % 4:
            action = RailEnvActions.MOVE_RIGHT
        elif candidate == (agent_direction - 1) % 4:
            action = RailEnvActions.MOVE_LEFT
        else:
            raise Exception("Illegal state")

        moves.add(RailEnvNextAction(action,
                                    get_new_position(agent_position, candidate),
                                    candidate))
    return moves
    def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None, num_resets: int = 0,
                  np_random: RandomState = None) -> Schedule:
        """
        Build a schedule with randomly drawn start cells, targets and start directions.

        Start and target cells are drawn from the cells of `rail` whose full transition
        value is greater than zero. For every agent a start direction is then picked among
        the orientations from which `rail.check_path_exists` reports a path to the agent's
        target. If an agent has no such orientation, its start cell and target are redrawn
        and the whole direction assignment is repeated.

        `seed` and `speed_ratio_map` are free variables taken from the enclosing scope,
        which is not visible in this block.

        Parameters
        ----------
        rail : GridTransitionMap
            Map that supplies the candidate cells and is queried for path existence.
        num_agents : int
            Number of agents to place.
        hints : Any
            Not used by this function.
        num_resets : int
            Added to the enclosing-scope `seed` to form the seed handed to
            `speed_initialization_helper`.
        np_random : RandomState
            Source of all random draws. The default is None, but the object is used
            unconditionally, so callers have to supply one.

        Returns
        -------
        Schedule
            Positions, directions, targets and speeds for all agents plus the episode
            step limit. An empty schedule with `max_episode_steps=0` is returned when
            there is no valid cell or when there are fewer valid cells than agents.

        Raises
        ------
        Exception
            If the attempt counter exceeds 1000 without finding a feasible assignment.
        """
        _runtime_seed = seed + num_resets

        # Every cell with a non-zero full transition value is a candidate start/target.
        valid_positions = []
        for r in range(rail.height):
            for c in range(rail.width):
                if rail.get_full_transitions(r, c) > 0:
                    valid_positions.append((r, c))
        if len(valid_positions) == 0:
            return Schedule(agent_positions=[], agent_directions=[],
                            agent_targets=[], agent_speeds=[], agent_malfunction_rates=None, max_episode_steps=0)

        # Not enough cells to give every agent its own start: warn and return an empty schedule.
        if len(valid_positions) < num_agents:
            warnings.warn("schedule_generators: len(valid_positions) < num_agents")
            return Schedule(agent_positions=[], agent_directions=[],
                            agent_targets=[], agent_speeds=[], agent_malfunction_rates=None, max_episode_steps=0)

        # Starts and targets are two independent draws without replacement: starts are
        # pairwise distinct and targets are pairwise distinct, but nothing here prevents
        # an agent's start from coinciding with a target.
        agents_position_idx = [i for i in np_random.choice(len(valid_positions), num_agents, replace=False)]
        agents_position = [valid_positions[agents_position_idx[i]] for i in range(num_agents)]
        agents_target_idx = [i for i in np_random.choice(len(valid_positions), num_agents, replace=False)]
        agents_target = [valid_positions[agents_target_idx[i]] for i in range(num_agents)]
        # update_agents[i] == 1 marks agent i for a redraw on the next pass.
        update_agents = np.zeros(num_agents)

        re_generate = True
        cnt = 0
        while re_generate:
            cnt += 1
            if cnt > 1:
                print("re_generate cnt={}".format(cnt))
            if cnt > 1000:
                raise Exception("After 1000 re_generates still not success, giving up.")
            # Redraw start and target for agents flagged in the previous pass. The new
            # index is drawn from the indices not currently assigned to any agent.
            # NOTE(review): if every valid cell is already assigned, `x` is empty and
            # `np_random.choice(x)` presumably raises -- confirm whether this can occur.
            for i in range(num_agents):
                if update_agents[i] == 1:
                    x = np.setdiff1d(np.arange(len(valid_positions)), agents_position_idx)
                    agents_position_idx[i] = np_random.choice(x)
                    agents_position[i] = valid_positions[agents_position_idx[i]]
                    x = np.setdiff1d(np.arange(len(valid_positions)), agents_target_idx)
                    agents_target_idx[i] = np_random.choice(x)
                    agents_target[i] = valid_positions[agents_target_idx[i]]
            update_agents = np.zeros(num_agents)

            # agents_direction must be a direction for which a solution is
            # guaranteed. The list is rebuilt from scratch on every pass, so directions
            # are drawn again for all agents, not only for the redrawn one.
            agents_direction = [0] * num_agents
            re_generate = False
            for i in range(num_agents):
                # Collect every (orientation, exit direction) pair the start cell allows.
                valid_movements = []
                for direction in range(4):
                    position = agents_position[i]
                    moves = rail.get_transitions(position[0], position[1], direction)
                    for move_index in range(4):
                        if moves[move_index]:
                            valid_movements.append((direction, move_index))

                # An orientation qualifies when, after taking one of its exits, a path
                # from the neighbouring cell to the target exists.
                valid_starting_directions = []
                for m in valid_movements:
                    new_position = get_new_position(agents_position[i], m[1])
                    if m[0] not in valid_starting_directions and rail.check_path_exists(new_position, m[1],
                                                                                        agents_target[i]):
                        valid_starting_directions.append(m[0])

                if len(valid_starting_directions) == 0:
                    # No feasible orientation: flag this agent and restart the pass.
                    # Because of the break, at most one agent is flagged per pass.
                    update_agents[i] = 1
                    warnings.warn(
                        "reset position for agent[{}]: {} -> {}".format(i, agents_position[i], agents_target[i]))
                    re_generate = True
                    break
                else:
                    # Pick one of the feasible orientations at random.
                    agents_direction[i] = valid_starting_directions[
                        np_random.choice(len(valid_starting_directions), 1)[0]]

        agents_speed = speed_initialization_helper(num_agents, speed_ratio_map, seed=_runtime_seed, np_random=np_random)

        # Compute max number of steps with given schedule
        extra_time_factor = 1.5  # Factor to allow for more than minimal time
        max_episode_steps = int(extra_time_factor * rail.height * rail.width)

        return Schedule(agent_positions=agents_position, agent_directions=agents_direction,
                        agent_targets=agents_target, agent_speeds=agents_speed, agent_malfunction_rates=None,
                        max_episode_steps=max_episode_steps)
示例#4
0
def test_grid4_get_transitions():
    """
    Check per-orientation and packed transition values of one Grid4 cell.

    Only entries of the NORTH orientation of cell (0, 0) are modified; the EAST, SOUTH
    and WEST orientations must stay empty throughout.
    """
    grid4_map = GridTransitionMap(2, 2, Grid4Transitions([]))

    north = Grid4TransitionsEnum.NORTH
    west = Grid4TransitionsEnum.WEST
    untouched_orientations = (Grid4TransitionsEnum.EAST,
                              Grid4TransitionsEnum.SOUTH,
                              Grid4TransitionsEnum.WEST)
    closed = (0, 0, 0, 0)
    north_cell = (0, 0, north)

    def check_cell(expected_north, expected_full):
        # Same assertion order as a hand-written sequence: NORTH, EAST, SOUTH, WEST, full value.
        assert grid4_map.get_transitions(0, 0, north) == expected_north
        for orientation in untouched_orientations:
            assert grid4_map.get_transitions(0, 0, orientation) == closed
        assert grid4_map.get_full_transitions(0, 0) == expected_full

    # Fresh map: nothing is set.
    check_cell(closed, 0)

    # NORTH entry on: the most significant bit is on.
    grid4_map.set_transition(north_cell, north, 1)
    check_cell((1, 0, 0, 0), 2 ** 15)

    # WEST entry on as well: the most significant and the fourth most significant bits are on.
    grid4_map.set_transition(north_cell, west, 1)
    check_cell((1, 0, 0, 1), 2 ** 15 + 2 ** 12)

    # NORTH entry off again: only the fourth most significant bit is on.
    grid4_map.set_transition(north_cell, north, 0)
    check_cell((0, 0, 0, 1), 2 ** 12)
示例#5
0
def get_new_position_for_action(
        agent_position: Tuple[int, int], agent_direction: Grid4TransitionsEnum,
        action: RailEnvActions,
        rail: GridTransitionMap) -> Tuple[Tuple[int, int], int]:
    """
    Get the next position and direction that result from taking `action`.

    TODO https://gitlab.aicrowd.com/flatland/flatland/issues/299 The implementation could probably be more efficient
    and more elegant. But given the few calls this has no priority now.

    Parameters
    ----------
    agent_position
        Cell the agent currently occupies.
    agent_direction
        Direction the agent is currently facing.
    action
        Move action to resolve (forward, left or right).
    rail
        Map queried for the transitions of the agent's cell.

    Returns
    -------
    Tuple[Tuple[int,int],int]
        `(new_position, new_direction)`: the position produced by `get_new_position`
        and the direction the agent faces after the move. The position is a nested
        element, not flattened into row and column.
        `None` is returned when `action` is not available from the current
        position/direction.
    """
    possible_transitions = rail.get_transitions(*agent_position,
                                                agent_direction)
    num_transitions = np.count_nonzero(possible_transitions)

    # Dead end: the only way out is turning around, which is expressed as MOVE_FORWARD.
    if rail.is_dead_end(agent_position):
        exit_direction = (agent_direction + 2) % 4
        if possible_transitions[exit_direction] and action == RailEnvActions.MOVE_FORWARD:
            return get_new_position(agent_position, exit_direction), exit_direction
        return None

    # If only one transition is possible, the forward action is aligned with it,
    # whichever way it actually points.
    single_exit = num_transitions == 1

    # Candidates are examined in the order left, forward, right, relative to the
    # current orientation.
    for new_direction in [(agent_direction + i) % 4 for i in range(-1, 2)]:
        if not possible_transitions[new_direction]:
            continue

        if single_exit:
            valid_action = RailEnvActions.MOVE_FORWARD
        elif new_direction == agent_direction:
            valid_action = RailEnvActions.MOVE_FORWARD
        elif new_direction == (agent_direction + 1) % 4:
            valid_action = RailEnvActions.MOVE_RIGHT
        elif new_direction == (agent_direction - 1) % 4:
            valid_action = RailEnvActions.MOVE_LEFT
        else:
            # Direction cannot be classified (agent_direction outside 0..3): skip it.
            continue

        if valid_action == action:
            return get_new_position(agent_position, new_direction), new_direction

    # The requested action is not possible from here.
    return None