def test_grid4_get_transitions():
    """Toggle single transitions of cell (0, 0) on a 2x2 Grid4 map.

    After every change, check both the per-heading transition tuples
    (ordered NORTH, EAST, SOUTH, WEST) and the packed 16-bit cell value.
    Only the row for an agent facing NORTH is ever modified, so the EAST,
    SOUTH and WEST rows must stay all-zero throughout.
    """
    grid4_map = GridTransitionMap(2, 2, Grid4Transitions([]))
    north, east, south, west = (
        Grid4TransitionsEnum.NORTH,
        Grid4TransitionsEnum.EAST,
        Grid4TransitionsEnum.SOUTH,
        Grid4TransitionsEnum.WEST,
    )
    blocked = (0, 0, 0, 0)
    cell = (0, 0, north)  # row, column, and heading of the agent in the cell

    def expect(north_row, packed):
        # Same query order as a hand-unrolled check: N, E, S, W, then packed.
        expected_rows = (
            (north, north_row),
            (east, blocked),
            (south, blocked),
            (west, blocked),
        )
        for heading, row in expected_rows:
            assert grid4_map.get_transitions(0, 0, heading) == row
        assert grid4_map.get_full_transitions(0, 0) == packed

    # Freshly built map: no transitions at all.
    expect(blocked, 0)

    # Allow NORTH -> NORTH: only the most significant bit is on.
    grid4_map.set_transition(cell, north, 1)
    expect((1, 0, 0, 0), 2 ** 15)

    # Also allow NORTH -> WEST: the most significant and the fourth most
    # significant bits are on.
    grid4_map.set_transition(cell, west, 1)
    expect((1, 0, 0, 1), 2 ** 15 + 2 ** 12)

    # Remove NORTH -> NORTH again: only the fourth most significant bit is on.
    grid4_map.set_transition(cell, north, 0)
    expect((0, 0, 0, 1), 2 ** 12)
def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None, num_resets: int = 0,
              np_random: RandomState = None) -> Schedule:
    """Randomly place ``num_agents`` agents on ``rail`` with random targets.

    Candidate cells are all cells of ``rail`` with at least one transition.
    Start cells and target cells are each drawn without replacement from them.
    Each agent's initial direction is picked at random among the headings at
    its start cell from which ``rail.check_path_exists`` reports a path to its
    target. If an agent has no such heading, a new start and a new target are
    drawn for it and all agents are checked again.

    ``seed`` and ``speed_ratio_map`` are taken from the enclosing scope (not
    visible here).

    Args:
        rail: grid on which agents are placed.
        num_agents: number of agents to place.
        hints: not used by this generator.
        num_resets: added to ``seed`` to form the seed handed to
            ``speed_initialization_helper``.
        np_random: random state used for every draw. It is required despite
            the ``None`` default: all draws call ``np_random.choice``.

    Returns:
        A ``Schedule``. If ``rail`` has no candidate cells, or fewer candidate
        cells than ``num_agents`` (a warning is emitted in this case), an empty
        schedule with ``max_episode_steps=0`` is returned.

    Raises:
        ValueError: an agent must be re-placed but every candidate cell is
            already used as a start (or target) by some agent.
        RuntimeError: no valid placement was found within 1000 passes.
    """
    _runtime_seed = seed + num_resets

    valid_positions = [(r, c)
                       for r in range(rail.height)
                       for c in range(rail.width)
                       if rail.get_full_transitions(r, c) > 0]
    if len(valid_positions) == 0:
        return Schedule(agent_positions=[], agent_directions=[], agent_targets=[], agent_speeds=[],
                        agent_malfunction_rates=None, max_episode_steps=0)

    if len(valid_positions) < num_agents:
        warnings.warn("schedule_generators: len(valid_positions) < num_agents")
        return Schedule(agent_positions=[], agent_directions=[], agent_targets=[], agent_speeds=[],
                        agent_malfunction_rates=None, max_episode_steps=0)

    # Loop-invariant pool of all candidate indices, used when redrawing.
    all_indices = np.arange(len(valid_positions))

    def _draw_unused(used_idx, what):
        """Draw a candidate index not contained in ``used_idx``."""
        free = np.setdiff1d(all_indices, used_idx)
        if free.size == 0:
            # Without this check np_random.choice fails with an opaque error.
            raise ValueError(
                "schedule_generators: cannot redraw agent {}: all {} valid cells are already in use".format(
                    what, len(valid_positions)))
        return np_random.choice(free)

    def _start_directions(position, target):
        """Return, in increasing order, the headings at ``position`` that have
        a transition whose next cell ``rail.check_path_exists`` connects to
        ``target``."""
        found = []
        for direction in range(4):
            moves = rail.get_transitions(position[0], position[1], direction)
            for move_index in range(4):
                if not moves[move_index]:
                    continue
                new_position = get_new_position(position, move_index)
                if rail.check_path_exists(new_position, move_index, target):
                    found.append(direction)
                    break  # this heading is already known to be valid
        return found

    # Draw order (starts, then targets) is kept so seeded runs are reproducible.
    agents_position_idx = list(np_random.choice(len(valid_positions), num_agents, replace=False))
    agents_position = [valid_positions[idx] for idx in agents_position_idx]
    agents_target_idx = list(np_random.choice(len(valid_positions), num_agents, replace=False))
    agents_target = [valid_positions[idx] for idx in agents_target_idx]

    # Index of the agent that failed in the previous pass. At most one agent
    # fails per pass, because the pass stops at the first failure.
    reset_agent = None
    cnt = 0
    while True:
        cnt += 1
        if cnt > 1:
            warnings.warn("re_generate cnt={}".format(cnt))
        if cnt > 1000:
            raise RuntimeError("After 1000 re_generates still not success, giving up.")

        if reset_agent is not None:
            # Give the failed agent a new start and a new target, each
            # distinct from those currently held by any agent.
            agents_position_idx[reset_agent] = _draw_unused(agents_position_idx, "start")
            agents_position[reset_agent] = valid_positions[agents_position_idx[reset_agent]]
            agents_target_idx[reset_agent] = _draw_unused(agents_target_idx, "target")
            agents_target[reset_agent] = valid_positions[agents_target_idx[reset_agent]]
            reset_agent = None

        # agents_direction must be a direction for which a solution is
        # guaranteed.
        agents_direction = [0] * num_agents
        for i in range(num_agents):
            directions = _start_directions(agents_position[i], agents_target[i])
            if not directions:
                warnings.warn(
                    "reset position for agent[{}]: {} -> {}".format(i, agents_position[i], agents_target[i]))
                reset_agent = i
                break
            agents_direction[i] = directions[np_random.choice(len(directions), 1)[0]]

        if reset_agent is None:
            break

    agents_speed = speed_initialization_helper(num_agents, speed_ratio_map, seed=_runtime_seed, np_random=np_random)

    # Compute max number of steps with given schedule
    extra_time_factor = 1.5  # factor to allow for more than the minimal time
    max_episode_steps = int(extra_time_factor * rail.height * rail.width)

    return Schedule(agent_positions=agents_position, agent_directions=agents_direction,
                    agent_targets=agents_target, agent_speeds=agents_speed,
                    agent_malfunction_rates=None, max_episode_steps=max_episode_steps)