def test_get_host_assignments(self):
    """Two hosts with two slots each: ranks 0-1 are expected on worker-0, ranks 2-3 on worker-1."""
    num_proc = 4
    host_list = parse_hosts('worker-0:2,worker-1:2')
    actual = get_host_assignments(host_list, num_proc)

    # Every slot in this homogeneous layout shares the same size fields.
    common = dict(size=4, local_size=2, cross_size=2)

    def slot(hostname, rank, local_rank, cross_rank):
        # Build the expected SlotInfo for one rank, filling in the shared sizes.
        return SlotInfo(hostname=hostname, rank=rank, local_rank=local_rank,
                        cross_rank=cross_rank, **common)

    expected = [
        slot('worker-0', 0, 0, 0),
        slot('worker-0', 1, 1, 0),
        slot('worker-1', 2, 0, 1),
        slot('worker-1', 3, 1, 1),
    ]
    self.assertListEqual(actual, expected)
def test_get_host_assignments_heterogeneous(self):
    """Uneven slot counts: local_size follows each host, cross_size follows each local rank."""
    num_proc = 3
    host_list = parse_hosts('worker-0:1,worker-1:2')
    actual = get_host_assignments(host_list, num_proc)

    # Columns: hostname, rank, local_rank, local_size, cross_rank, cross_size
    table = [
        ('worker-0', 0, 0, 1, 0, 2),
        ('worker-1', 1, 0, 2, 1, 2),
        # Only worker-1 has a local_rank 1, so that "cross group" has a single member.
        ('worker-1', 2, 1, 2, 0, 1),
    ]
    expected = [
        SlotInfo(hostname=host, rank=rank, local_rank=local_rank, cross_rank=cross_rank,
                 size=num_proc, local_size=local_size, cross_size=cross_size)
        for host, rank, local_rank, local_size, cross_rank, cross_size in table
    ]
    self.assertListEqual(actual, expected)
def test_get_host_assignments_elastic(self):
    """With min_np=1 and max_np=2 over four slots, only two ranks are expected, both on worker-0."""
    host_list = parse_hosts('worker-0:2,worker-1:2')
    actual = get_host_assignments(host_list, min_np=1, max_np=2)

    # Both assigned slots share these sizes: two ranks total, all on one host.
    shared = dict(size=2, local_size=2, cross_size=1)
    expected = [
        SlotInfo(hostname='worker-0', rank=rank, local_rank=rank, cross_rank=0, **shared)
        for rank in range(2)
    ]
    self.assertListEqual(actual, expected)