def test_stage_to_global(self):
    """Verify stage/data coordinate to global rank mapping on a 2x2 pipe/data grid.

    Checks the fully-specified mapping (both ``stage_id`` and ``data`` given)
    against the expected row-major layout, then checks that omitting ``data``
    defaults to the calling rank's own data-parallel coordinate.
    """
    topology = Topo(axes=['pipe', 'data'], dims=[2, 2])
    pipeline_grid = Grid(topology=topology)
    assert pipeline_grid._is_grid_valid()

    # Fully specified lookups: global rank = stage_id * data_dim + data.
    expected = {(0, 0): 0, (0, 1): 1, (1, 0): 2, (1, 1): 3}
    for (stage, data_coord), global_rank in expected.items():
        assert pipeline_grid.stage_to_global(stage_id=stage, data=data_coord) == global_rank

    # With `data` omitted, the mapping should fill in this rank's own
    # data-parallel coordinate.
    my_coord = topology.get_coord(rank=dist.get_rank())
    if my_coord.data == 0:
        assert pipeline_grid.stage_to_global(stage_id=0) == 0
        assert pipeline_grid.stage_to_global(stage_id=1) == 2
    else:
        assert pipeline_grid.stage_to_global(stage_id=0) == 1
        assert pipeline_grid.stage_to_global(stage_id=1) == 3
def test_grid_pipe_data(self):
    """Exercise pipeline- and data-parallel process groups on a 2x2 grid.

    Validates the first/last-stage predicates against the stage id, then
    all-reduces this rank's id within each process group and checks the
    result equals the sum of the ranks in that group.
    """
    topology = Topo(axes=['pipe', 'data'], dims=[2, 2])
    pipeline_grid = Grid(topology=topology)
    assert pipeline_grid._is_grid_valid()

    my_rank = dist.get_rank()
    stage = pipeline_grid.get_stage_id()
    assert pipeline_grid.is_first_stage == (stage == 0)
    last_stage = pipeline_grid.get_pipe_parallel_world_size() - 1
    assert pipeline_grid.is_last_stage == (stage == last_stage)

    # All-reduce of rank ids over the pipeline-parallel group must equal the
    # sum of the ranks that belong to it.
    reduced = torch.LongTensor(data=[my_rank]).cuda()
    dist.all_reduce(reduced, group=pipeline_grid.get_pipe_parallel_group())
    assert torch.all(reduced == sum(pipeline_grid.pp_group))

    # Same invariant for the data-parallel group.
    reduced = torch.LongTensor(data=[my_rank]).cuda()
    dist.all_reduce(reduced, group=pipeline_grid.get_data_parallel_group())
    assert torch.all(reduced == sum(pipeline_grid.dp_group))