示例#1
0
    def test_share_and_gather(self, store_mock):
        """Check that ``_share_and_gather`` publishes this agent's role info.

        ``store_mock`` is a mock injected by the caller (its patch target is
        outside this view). It is primed to return the serialized role infos
        of all agents.

        The test asserts three things:

        * the agent writes its own serialized ``_RoleInstanceInfo`` under the
          key ``torchelastic/role_info1``;
        * the written value matches the agent's own role info;
        * the mock is called exactly once.
        """
        spec = self._get_worker_spec(max_restarts=100, monitor_interval=0.1)
        agent = TestAgent(spec)
        expected_agent_infos = [
            _RoleInstanceInfo("trainer", 0, 10),
            _RoleInstanceInfo("trainer", 1, 10),
            _RoleInstanceInfo("validator", 2, 10),
        ]

        store_mock.return_value = [obj.serialize() for obj in expected_agent_infos]

        class DummyStore:
            """Minimal store stub that records the most recent ``set`` call."""

            def __init__(self):
                self.key = None
                self.value = None

            def set(self, key, value):
                self.key = key
                self.value = value

            def set_timeout(self, timeout):
                # Timeouts are irrelevant for this in-memory stub.
                pass

        store = DummyStore()
        # This agent is presumably rank 1 of 3. The key suffix and the expected
        # rank asserted below both use 1.
        agent._share_and_gather(store, 1, 3, spec)
        self.assertEqual("torchelastic/role_info1", store.key)
        expected_info = _RoleInstanceInfo(spec.role, 1, spec.local_world_size)
        self.assertEqual(expected_info.serialize(), store.value)
        store_mock.assert_called_once()
示例#2
0
 def test_find_boundaries(self):
     """The "trainer" entries sit at indices 0..2, so the boundaries
     reported for that role are expected to be 0 and 2."""
     infos = [
         _RoleInstanceInfo("trainer", 1, 1),
         _RoleInstanceInfo("trainer", 2, 2),
         _RoleInstanceInfo("trainer", 3, 3),
         _RoleInstanceInfo("parameter_server", 4, 5),
         _RoleInstanceInfo("parameter_server", 0, 4),
     ]
     first, last = _RoleInstanceInfo.find_role_boundaries(infos, "trainer")
     self.assertEqual(first, 0)
     self.assertEqual(last, 2)
示例#3
0
 def test_get_ranks(self):
     """Check ``_get_ranks`` for the entry at index 0 across the whole list.

     The local world sizes sum to 4 + 1 + 2 + 3 + 5 = 15. The first entry
     (local_world_size 4) is expected to receive ranks 0..3.
     """
     role_infos = [
         _RoleInstanceInfo("parameter_server", 0, 4),
         _RoleInstanceInfo("trainer", 1, 1),
         _RoleInstanceInfo("trainer", 2, 2),
         _RoleInstanceInfo("trainer", 3, 3),
         _RoleInstanceInfo("parameter_server", 4, 5),
     ]
     spec = self._get_worker_spec(
         max_restarts=3, monitor_interval=0.1, role="not_used", local_world_size=8
     )
     agent = TestAgent(spec)
     # Arguments after role_infos: presumably (index of the entry, start, end)
     # with the range covering the entire list.
     total_sum, ranks = agent._get_ranks(role_infos, 0, 0, len(role_infos))
     # assertEqual rather than the deprecated assertEquals alias, which was
     # removed in Python 3.12.
     self.assertEqual(15, total_sum)
     self.assertEqual([0, 1, 2, 3], list(ranks))
示例#4
0
 def test_assign_worker_ranks(self):
     """Check the ranks assigned to several agents against a fixed role table.

     ``_share_and_gather`` is patched to return ``role_infos``. Each case is
     ``(index into role_infos, expected ranks, expected ranks)``; the two rank
     lists are passed through to ``verify_worker_ranks`` unchanged. They look
     like global ranks and per-role ranks respectively, but confirm that
     against the helper's definition.
     """
     role_infos = [
         _RoleInstanceInfo("parameter_server", 0, 4),
         _RoleInstanceInfo("trainer", 1, 1),
         _RoleInstanceInfo("trainer", 2, 2),
         _RoleInstanceInfo("trainer", 3, 3),
         _RoleInstanceInfo("parameter_server", 4, 5),
     ]
     agent_count = len(role_infos)
     cases = (
         (0, [0, 1, 2, 3], [0, 1, 2, 3]),
         (1, [4], [0]),
         (2, [5, 6], [1, 2]),
         (3, [7, 8, 9], [3, 4, 5]),
     )
     with patch.object(TestAgent, "_share_and_gather", return_value=role_infos):
         for idx, global_ranks, role_ranks in cases:
             self.verify_worker_ranks(
                 role_infos[idx], agent_count, global_ranks, role_ranks
             )
示例#5
0
 def test_serde(self):
     """A ``_RoleInstanceInfo`` keeps its role, rank and local world size
     across a serialize/deserialize round trip."""
     original = _RoleInstanceInfo("role", 1, 10)
     restored = _RoleInstanceInfo.deserialize(original.serialize())
     for field_name in ("role", "rank", "local_world_size"):
         self.assertEqual(
             getattr(original, field_name), getattr(restored, field_name)
         )
示例#6
0
 def test_compare(self):
     """``_RoleInstanceInfo.compare`` orders by role name first, then rank.

     Each case is (left ctor args, right ctor args, expected result of
     ``compare(left, right)``).
     """
     cases = (
         (("role", 2, 10), ("role", 1, 10), 1),
         (("role1", 1, 10), ("role2", 2, 10), -1),
         (("role1", 1, 10), ("role2", 1, 10), -1),
     )
     for lhs_args, rhs_args, expected in cases:
         lhs = _RoleInstanceInfo(*lhs_args)
         rhs = _RoleInstanceInfo(*rhs_args)
         self.assertEqual(expected, _RoleInstanceInfo.compare(lhs, rhs))