def test_share_and_gather(self, store_mock):
    """Check that ``_share_and_gather`` writes this agent's role info to the store.

    ``store_mock`` is presumably injected by a ``@patch`` decorator that is
    outside this view (TODO confirm which callable it replaces); the test
    only programs its return value and checks that it was called once.
    """
    spec = self._get_worker_spec(max_restarts=100, monitor_interval=0.1)
    agent = TestAgent(spec)
    expected_agent_infos = [
        _RoleInstanceInfo("trainer", 0, 10),
        _RoleInstanceInfo("trainer", 1, 10),
        _RoleInstanceInfo("validator", 2, 10),
    ]

    # The patched callable hands back the serialized infos of all agents.
    store_mock.return_value = [obj.serialize() for obj in expected_agent_infos]

    class DummyStore:
        """Minimal store stand-in that remembers the last key/value set."""

        def __init__(self):
            self.key = None
            self.value = None

        def set(self, key, value):
            self.key = key
            self.value = value

        def set_timeout(self, timeout):
            # Timeouts are irrelevant for this in-memory stand-in.
            pass

    store = DummyStore()
    # NOTE(review): 1 and 3 look like (group_rank, group_world_size) -- confirm
    # against the signature of _share_and_gather.
    agent._share_and_gather(store, 1, 3, spec)

    # assertEqual replaces the deprecated assertEquals alias, which was
    # removed from unittest in Python 3.12.
    self.assertEqual("torchelastic/role_info1", store.key)
    expected_info = _RoleInstanceInfo(spec.role, 1, spec.local_world_size)
    self.assertEqual(expected_info.serialize(), store.value)
    store_mock.assert_called_once()
def test_find_boundaries(self):
    """Check the index span ``find_role_boundaries`` reports for "trainer"."""
    # Three trainer entries lead the list, two parameter servers follow.
    trainers = [_RoleInstanceInfo("trainer", n, n) for n in (1, 2, 3)]
    parameter_servers = [
        _RoleInstanceInfo("parameter_server", 4, 5),
        _RoleInstanceInfo("parameter_server", 0, 4),
    ]

    first, last = _RoleInstanceInfo.find_role_boundaries(
        trainers + parameter_servers, "trainer"
    )

    # The trainers occupy positions 0 through 2 (end index is inclusive).
    self.assertEqual(first, 0)
    self.assertEqual(last, 2)
def test_get_ranks(self):
    """Check the total and the rank range ``_get_ranks`` yields for entry 0.

    The five role infos have local world sizes 4, 1, 2, 3 and 5, so the
    expected total is 15, and the first entry (size 4) is expected to own
    ranks 0 through 3.
    """
    role_infos = [
        _RoleInstanceInfo("parameter_server", 0, 4),
        _RoleInstanceInfo("trainer", 1, 1),
        _RoleInstanceInfo("trainer", 2, 2),
        _RoleInstanceInfo("trainer", 3, 3),
        _RoleInstanceInfo("parameter_server", 4, 5),
    ]
    spec = self._get_worker_spec(
        max_restarts=3, monitor_interval=0.1, role="not_used", local_world_size=8
    )
    agent = TestAgent(spec)

    # NOTE(review): arguments look like (infos, index, start, end) -- confirm
    # against the signature of _get_ranks.
    total_sum, ranks = agent._get_ranks(role_infos, 0, 0, len(role_infos))

    # assertEqual replaces the deprecated assertEquals alias, which was
    # removed from unittest in Python 3.12.
    self.assertEqual(15, total_sum)
    self.assertEqual([0, 1, 2, 3], list(ranks))
def test_assign_worker_ranks(self):
    """Run ``verify_worker_ranks`` for the first four agents of a mixed group.

    ``TestAgent._share_and_gather`` is patched to return the fixed list of
    role infos so every verification sees the same gathered state.
    """
    infos = [
        _RoleInstanceInfo("parameter_server", 0, 4),
        _RoleInstanceInfo("trainer", 1, 1),
        _RoleInstanceInfo("trainer", 2, 2),
        _RoleInstanceInfo("trainer", 3, 3),
        _RoleInstanceInfo("parameter_server", 4, 5),
    ]
    agent_count = len(infos)

    # (index into infos, third argument, fourth argument) per verification.
    # NOTE(review): the two lists are presumably the expected global ranks
    # and expected role ranks -- confirm against verify_worker_ranks.
    cases = [
        (0, [0, 1, 2, 3], [0, 1, 2, 3]),
        (1, [4], [0]),
        (2, [5, 6], [1, 2]),
        (3, [7, 8, 9], [3, 4, 5]),
    ]

    with patch.object(TestAgent, "_share_and_gather", return_value=infos):
        for idx, first_ranks, second_ranks in cases:
            self.verify_worker_ranks(
                infos[idx], agent_count, first_ranks, second_ranks
            )
def test_serde(self):
    """Check that serialize followed by deserialize keeps every field."""
    original = _RoleInstanceInfo("role", 1, 10)
    restored = _RoleInstanceInfo.deserialize(original.serialize())

    # Compare field by field, in the same order as the constructor arguments.
    for field_name in ("role", "rank", "local_world_size"):
        self.assertEqual(
            getattr(original, field_name), getattr(restored, field_name)
        )
def test_compare(self):
    """Check the sign ``_RoleInstanceInfo.compare`` returns for three pairs."""
    compare = _RoleInstanceInfo.compare

    # Same role, higher rank passed first: expect 1.
    low_rank = _RoleInstanceInfo("role", 1, 10)
    high_rank = _RoleInstanceInfo("role", 2, 10)
    self.assertEqual(1, compare(high_rank, low_rank))

    # "role1" with rank 1 against "role2" with rank 2: expect -1.
    left = _RoleInstanceInfo("role1", 1, 10)
    right = _RoleInstanceInfo("role2", 2, 10)
    self.assertEqual(-1, compare(left, right))

    # "role1" against "role2" with equal ranks: expect -1.
    left = _RoleInstanceInfo("role1", 1, 10)
    right = _RoleInstanceInfo("role2", 1, 10)
    self.assertEqual(-1, compare(left, right))