def test_synchronize(self):
    """Check ``store_util.synchronize`` against a dict-backed stand-in store.

    The stand-in store is pre-populated with entries for indices 0, 1 and 2
    under the ``torchelastic/test`` prefix. The test calls ``synchronize``
    with rank 0 and world size 3, then expects exactly three byte payloads
    that decode, in order, to ``data0``, ``data1`` and ``data2``.
    """

    class DummyStore:
        """Minimal in-memory store exposing ``set``, ``get`` and ``set_timeout``."""

        def __init__(self):
            # Entries for every index are present up front, so lookups of
            # these keys succeed immediately.
            preloaded = {}
            preloaded["torchelastic/test0"] = "data0".encode(encoding="UTF-8")
            preloaded["torchelastic/test1"] = "data1".encode(encoding="UTF-8")
            preloaded["torchelastic/test2"] = "data2".encode(encoding="UTF-8")
            self._data = preloaded

        def set(self, key, value):
            self._data[key] = value

        def get(self, key):
            return self._data[key]

        def set_timeout(self, timeout):
            # Timeouts are meaningless for a dict-backed store; accept and ignore.
            pass

    payload = "data0".encode(encoding="UTF-8")
    fake_store = DummyStore()

    gathered = store_util.synchronize(
        fake_store, payload, 0, 3, key_prefix="torchelastic/test"
    )

    self.assertEqual(3, len(gathered))
    for idx, raw in enumerate(gathered):
        decoded = raw.decode(encoding="UTF-8")
        self.assertEqual(f"data{idx}", decoded)
def _share_and_gather(self, store, group_rank: int, group_world_size: int, spec: WorkerSpec) -> List:
    """Exchange this agent's role info through ``store`` and return all of them.

    A ``_RoleInstanceInfo`` is built from ``spec.role``, ``group_rank`` and
    ``spec.local_world_size``, serialized, and handed to
    ``store_util.synchronize`` under the ``torchelastic/role_info`` key
    prefix. Every payload that call returns is then passed through
    ``_RoleInstanceInfo.deserialize``.

    Args:
        store: Store object forwarded unchanged to ``store_util.synchronize``.
        group_rank: Rank recorded in the role info and forwarded to
            ``store_util.synchronize``.
        group_world_size: World size forwarded to ``store_util.synchronize``.
        spec: Worker spec supplying ``role`` and ``local_world_size``.

    Returns:
        A list holding the deserialized result for each payload returned by
        ``store_util.synchronize``, in the order that call returned them.
    """
    own_info = _RoleInstanceInfo(spec.role, group_rank, spec.local_world_size)
    serialized_own_info = own_info.serialize()

    gathered_payloads = store_util.synchronize(
        store,
        serialized_own_info,
        group_rank,
        group_world_size,
        "torchelastic/role_info",
    )

    all_infos = []
    for payload in gathered_payloads:
        all_infos.append(_RoleInstanceInfo.deserialize(payload))
    return all_infos