# Example 1
    def test_stage_to_global(self):
        """Check stage_to_global() rank mapping on a 2x2 pipe/data topology."""
        topo = Topo(axes=['pipe', 'data'], dims=[2, 2])
        grid = Grid(topology=topo)

        assert grid._is_grid_valid()

        # Explicit (stage, data) -> global rank mapping for dims [2, 2].
        for stage, data, expected in [(0, 0, 0), (0, 1, 1), (1, 0, 2), (1, 1, 3)]:
            assert grid.stage_to_global(stage_id=stage, data=data) == expected

        # When `data` is omitted, the caller's own data coordinate is used.
        my_coord = topo.get_coord(rank=dist.get_rank())
        if my_coord.data == 0:
            assert grid.stage_to_global(stage_id=0) == 0
            assert grid.stage_to_global(stage_id=1) == 2
        else:
            assert grid.stage_to_global(stage_id=0) == 1
            assert grid.stage_to_global(stage_id=1) == 3
# Example 2
    def test_grid_pipe_data(self):
        """Exercise the pipe- and data-parallel process groups on a 2x2 grid."""
        topo = Topo(axes=['pipe', 'data'], dims=[2, 2])
        grid = Grid(topology=topo)

        assert grid._is_grid_valid()

        my_rank = dist.get_rank()

        # First/last-stage flags must agree with the stage id.
        assert grid.is_first_stage == (grid.get_stage_id() == 0)
        last_stage = grid.get_pipe_parallel_world_size() - 1
        assert grid.is_last_stage == (grid.get_stage_id() == last_stage)

        # All-reduce over the pipeline-parallel group sums the member ranks.
        tensor = torch.LongTensor(data=[my_rank]).cuda()
        dist.all_reduce(tensor, group=grid.get_pipe_parallel_group())
        assert torch.all(tensor == sum(grid.pp_group))

        # All-reduce over the data-parallel group sums the member ranks.
        tensor = torch.LongTensor(data=[my_rank]).cuda()
        dist.all_reduce(tensor, group=grid.get_data_parallel_group())
        assert torch.all(tensor == sum(grid.dp_group))