def manhattan_distance(env: MapfEnv, s, a1, a2):
    """Return the Manhattan distance between agents ``a1`` and ``a2`` in state ``s``.

    :param env: environment used to decode ``s`` into per-agent locations.
    :param s: the state to decode.
    :param a1: index of the first agent.
    :param a2: index of the second agent.
    :return: sum of the absolute differences of the two agents' first and
        second location coordinates.
    """
    agent_locations = env.state_to_locations(s)
    loc_a = agent_locations[a1]
    loc_b = agent_locations[a2]

    gap_0 = abs(loc_a[0] - loc_b[0])
    gap_1 = abs(loc_a[1] - loc_b[1])
    return gap_0 + gap_1
def test_couple_detect_conflict_3_agents_multiple_agents_in_group(self):
    """Detect a conflict for one couple of agents when one of them is grouped.

    Three agents share an open 3x3 grid. Agent 0 travels from (0, 0) to
    (0, 2) and is expected to have no conflict with agents 1 or 2. Agent 1
    travels from (2, 0) to (2, 2) and agent 2 from (2, 2) to (2, 0), and
    these two are expected to conflict.

    Agent 1 is controlled as part of the group [0, 1], so its index inside
    its group is 1 rather than 0. This scenario catches a bug that was
    encountered previously. It can also serve as a test of detecting a
    conflict for only a couple of agents.
    """
    grid = MapfGrid(['...', '...', '...'])
    starts = ((0, 0), (2, 0), (2, 2))
    goals = ((0, 2), (2, 2), (2, 0))

    env = MapfEnv(grid, 3, starts, goals, 0, 0, -1, 1, -0.01)
    # One-agent env, used only to decode single-agent states into locations.
    single_agent_env = MapfEnv(grid, 1, (starts[0],), (goals[0],), 0, 0, -1, 1, -0.01)
    env01 = get_local_view(env, [0, 1])

    stay = ACTIONS.index(STAY)
    right = ACTIONS.index(RIGHT)
    left = ACTIONS.index(LEFT)

    def stay_everywhere_except(moves):
        """Build a 9-state single-agent policy that STAYs except at ``moves``."""
        table = {state: stay for state in range(9)}
        table.update(moves)
        return table

    # >>S
    # SSS
    # SSS
    policy0 = stay_everywhere_except({0: right, 3: right})

    # SSS
    # SSS
    # >>S
    policy1 = stay_everywhere_except({2: right, 5: right})

    def location_of(single_state):
        return single_agent_env.state_to_locations(single_state)[0]

    def action_of(single_action):
        return integer_action_to_vector(single_action, 1)[0]

    # policy01 crosses agent 0's and agent 1's policies over the joint [0, 1] view.
    policy01 = {
        env01.locations_to_state((location_of(s0), location_of(s1))):
            vector_action_to_integer((action_of(policy0[s0]), action_of(policy1[s1])))
        for s0 in range(9)
        for s1 in range(9)
    }

    # SSS
    # SSS
    # S<<
    policy2 = stay_everywhere_except({5: left, 8: left})

    joint_policy = CrossedPolicy(
        env,
        [
            DictPolicy(env01, 1.0, policy01),
            DictPolicy(get_local_view(env, [2]), 1.0, policy2),
        ],
        [[0, 1], [2]],
    )

    single_view = get_local_view(env, [0])

    def local_state(location):
        return single_view.locations_to_state((location,))

    # Agents 1 and 2 must conflict. The expected value pairs agent 2 with
    # (2, 2) -> (2, 1) and agent 1 with (2, 0) -> (2, 1).
    detected = couple_detect_conflict(env, joint_policy, 2, 1)
    self.assertEqual(
        detected,
        (
            (2, local_state((2, 2)), local_state((2, 1))),
            (1, local_state((2, 0)), local_state((2, 1))),
        ),
    )

    # No conflict is expected between agent 0 and either of the other agents.
    self.assertIsNone(couple_detect_conflict(env, joint_policy, 0, 1))
    self.assertIsNone(couple_detect_conflict(env, joint_policy, 0, 2))