Ejemplo n.º 1
0
    def _test(self, rampup_batch_size: Optional[List[int]]) -> None:
        """Check the microbatch calculator for each usable data-parallel size.

        Every ``data_parallel_size`` in ``1..self.world_size`` that is either 1
        or even, and that divides ``self.world_size``, is exercised in its own
        ``subTest``: model-parallel state is initialized, the microbatch
        calculator is reconfigured, and the reported micro batch size, number
        of microbatches and current global batch size are compared with the
        expected values.

        Args:
            rampup_batch_size: Ramp-up specification forwarded unchanged to
                ``_reconfigure_microbatch_calculator``. When truthy, its first
                element is the expected initial global batch size and the test
                additionally checks that the global batch size reaches
                ``self.GLOBAL_BATCH_SIZE`` after repeated updates. When falsy,
                the expected global batch size is ``self.GLOBAL_BATCH_SIZE``.
        """
        # These expectations do not depend on the data-parallel size.
        expected_micro_batch_size = self.MICRO_BATCH_SIZE
        expected_global_batch_size = (rampup_batch_size[0] if rampup_batch_size
                                      else self.GLOBAL_BATCH_SIZE)

        for data_parallel_size in range(1, self.world_size + 1):
            # Only 1 and even sizes that evenly divide the world are tested.
            if data_parallel_size > 1 and data_parallel_size % 2 != 0:
                continue
            if self.world_size % data_parallel_size != 0:
                continue

            with self.subTest(data_parallel_size=data_parallel_size):
                parallel_state.initialize_model_parallel(
                    tensor_model_parallel_size_=self.world_size //
                    data_parallel_size,
                    pipeline_model_parallel_size_=1,
                )
                # `subTest` records a failure and lets the loop continue, so
                # the parallel state must be torn down even when an assertion
                # fails; otherwise the next iteration starts from stale state.
                try:
                    self.assertEqual(
                        data_parallel_size,
                        parallel_state.get_data_parallel_world_size())

                    _reconfigure_microbatch_calculator(
                        self.rank,
                        rampup_batch_size,
                        self.GLOBAL_BATCH_SIZE,
                        self.MICRO_BATCH_SIZE,
                        data_parallel_size,
                    )

                    self.assertEqual(get_micro_batch_size(),
                                     expected_micro_batch_size)
                    self.assertEqual(
                        get_num_microbatches(), expected_global_batch_size /
                        expected_micro_batch_size / data_parallel_size)
                    current_global_batch_size = get_current_global_batch_size()
                    self.assertEqual(current_global_batch_size,
                                     expected_global_batch_size)

                    # Make sure the global batch size equals the final global
                    # batch size after a certain number of updates.
                    if rampup_batch_size:
                        # The current global batch size is fed back as the
                        # argument of each update (kept as in the original
                        # test).
                        update_num_microbatches(current_global_batch_size)
                        for _ in range(100):
                            current_global_batch_size = (
                                get_current_global_batch_size())
                            update_num_microbatches(current_global_batch_size)
                        self.assertEqual(get_current_global_batch_size(),
                                         self.GLOBAL_BATCH_SIZE)
                finally:
                    parallel_state.destroy_model_parallel()
Ejemplo n.º 2
0
    def _forward_backward_test_impl(
        self,
        forward_only: bool,
        fwd_bwd_func: FwdStepFunc,
        pipeline_model_parallel_world_size: Optional[int],
        virtual_pipeline_model_parallel_size: Optional[int],
        async_comm: bool = False,
        *,
        default_backend: Optional[str] = None,
        p2p_backend: Optional[str] = None,
    ) -> None:
        """Run ``fwd_bwd_func`` for every (dtype, deallocate option) pair.

        For each combination the model-parallel state is initialized, the
        microbatch calculator is reconfigured, a model is built and
        ``fwd_bwd_func`` is invoked once. For ``torch.double`` the losses (and,
        unless ``forward_only``, parameters and gradients) are compared with
        the reference from ``get_target_loss_and_model``. The parallel state is
        destroyed at the end of each iteration.

        Args:
            forward_only: Forwarded to ``fwd_bwd_func``. When False, every
                parameter is additionally required to have a gradient and one
                optimizer step is taken.
            fwd_bwd_func: Schedule under test. If it is
                ``_forward_backward_pipelining_with_interleaving``,
                ``virtual_pipeline_model_parallel_size`` must be set and > 1.
            pipeline_model_parallel_world_size: Passed to
                ``_get_default_world_sizes_model_parallel_world_size``, whose
                result determines the sizes actually used.
            virtual_pipeline_model_parallel_size: Forwarded to
                ``initialize_model_parallel`` and ``build_model``.
            async_comm: Forwarded to ``fwd_bwd_func``.
            default_backend: Forwarded to ``initialize_model_parallel``.
            p2p_backend: Forwarded to ``initialize_model_parallel``.
        """
        if fwd_bwd_func == _forward_backward_pipelining_with_interleaving:
            self.assertIsNotNone(virtual_pipeline_model_parallel_size)
            self.assertGreater(virtual_pipeline_model_parallel_size, 1)
        # `+` binds tighter than `or`: a truthy `self.dtypes` replaces the
        # whole default list (float32, double and the autocast dtypes).
        dtype_options = self.dtypes or [torch.float32, torch.double
                                        ] + _get_autocast_dtypes()

        for dtype, deallocate_pipeline_outputs in itertools.product(
                dtype_options,
                self.deallocate_options,
        ):
            # A gradient scaler is only created for half precision.
            grad_scaler = (torch.cuda.amp.GradScaler(
                init_scale=4.0) if dtype == torch.half else None)

            # NOTE: this rebinds the `pipeline_model_parallel_world_size`
            # argument, so later iterations pass the resolved value back in.
            (tensor_model_parallel_world_size, data_parallel_size,
             pipeline_model_parallel_world_size
             ) = _get_default_world_sizes_model_parallel_world_size(
                 pipeline_model_parallel_world_size)

            parallel_state.initialize_model_parallel(
                tensor_model_parallel_size_=tensor_model_parallel_world_size,
                pipeline_model_parallel_size_=
                pipeline_model_parallel_world_size,
                virtual_pipeline_model_parallel_size_=
                virtual_pipeline_model_parallel_size,
                default_backend=default_backend,
                p2p_backend=p2p_backend,
            )
            pp_utils._reconfigure_microbatch_calculator(
                rank=parallel_state.get_tensor_model_parallel_rank(),
                rampup_batch_size=None,
                global_batch_size=self.GLOBAL_BATCH_SIZE,
                micro_batch_size=self.MICRO_BATCH_SIZE,
                data_parallel_size=parallel_state.get_data_parallel_world_size(
                ),
            )

            # Shape of the batch handled by one data-parallel rank.
            global_batch_shape = (
                self.GLOBAL_BATCH_SIZE //
                parallel_state.get_data_parallel_world_size(),
                self.HIDDEN_SIZE,
                self.HIDDEN_SIZE,
            )

            # Only the first pipeline stage is given an input batch; every
            # other stage passes `None` to the schedule.
            batch = None
            if parallel_state.is_pipeline_first_stage():
                batch = (torch.ones(global_batch_shape, dtype=dtype).cuda(), )

            model = build_model(
                testing_utils.model_provider_func,
                # Wrap with DDP only when there is more than one data-parallel rank.
                wrap_with_ddp=data_parallel_size > 1,
                virtual_pipeline_model_parallel_size=
                virtual_pipeline_model_parallel_size,
                hidden_size=self.HIDDEN_SIZE,
            )

            # Each model chunk `idx` is initialized via
            # `get_init_weights_func(idx * offset)`; without virtual pipelining
            # the offset is 0, so the argument is always 0.
            offset = pipeline_model_parallel_world_size if virtual_pipeline_model_parallel_size is not None else 0
            for idx, model_module in enumerate(model):
                model_module = model_module.to(dtype)
                model_module.apply(get_init_weights_func(idx * offset))

            _param_groups = _get_params_for_weight_decay_optimization(model)
            optimizer = torch.optim.Adam(_param_groups, lr=1e-3)

            pp_utils.update_num_microbatches(0)

            loss = fwd_bwd_func(
                testing_utils.fwd_step_func,
                batch,
                model,
                forward_only=forward_only,
                # `tensor_shape` is the shape of micro batch.
                tensor_shape=(
                    self.MICRO_BATCH_SIZE,
                    self.HIDDEN_SIZE,
                    self.HIDDEN_SIZE,
                ),
                dtype=dtype,
                async_comm=async_comm,
                grad_scaler=grad_scaler,
                deallocate_pipeline_output=deallocate_pipeline_outputs,
            )

            # The comparison against the reference model is only performed in
            # double precision.
            if dtype == torch.double:
                hidden_size = self.HIDDEN_SIZE
                microbatch_size = self.MICRO_BATCH_SIZE
                total_layers = pipeline_model_parallel_world_size
                if virtual_pipeline_model_parallel_size is not None:
                    total_layers *= virtual_pipeline_model_parallel_size
                target_loss, target_model = get_target_loss_and_model(
                    global_batch_shape, hidden_size, total_layers)

                # Each reported 'avg' loss is divided by the micro batch size
                # before being compared with the reference loss.
                for loss_item in loss:
                    x = loss_item['avg']
                    torch.testing.assert_close(x.item() / microbatch_size,
                                               target_loss.item())

                if not forward_only:
                    for vm_id, model_module in enumerate(model):
                        params = list(model_module.parameters())
                        # The device index of the first parameter is used as
                        # the rank -- assumes device index == rank; TODO confirm.
                        rank = params[0].get_device()
                        offset = pipeline_model_parallel_world_size
                        # Index of the matching entry in the reference model.
                        param_id = rank // data_parallel_size + vm_id * offset
                        target_params = target_model[param_id]

                        torch.testing.assert_close(params[0].cpu(),
                                                   target_params[0])
                        torch.testing.assert_close(params[1].cpu(),
                                                   target_params[1])
                        # Gradients get the same micro-batch-size division as
                        # the loss above.
                        torch.testing.assert_close(
                            params[0].grad.cpu() / microbatch_size,
                            target_params[0].grad)
                        torch.testing.assert_close(
                            params[1].grad.cpu() / microbatch_size,
                            target_params[1].grad)

            # After a backward pass every parameter must have a gradient.
            if not forward_only:
                for m in model:
                    for p in m.parameters():
                        self.assertIsNotNone(p.grad)
                optimizer.step()
                optimizer.zero_grad(set_to_none=True)

            parallel_state.destroy_model_parallel()
Ejemplo n.º 3
0
def forward_backward_func_template(
    name: str,
    forward_backward_func,
    pipeline_model_parallel_size: int,
    forward_only: bool,
) -> None:
    """Smoke-test one forward/backward schedule on a freshly built model.

    Initializes model-parallel state according to ``name``, builds a
    DDP-wrapped model, runs ``forward_backward_func`` once on a random batch
    and, unless ``forward_only``, verifies that every parameter received a
    gradient. Rank 0 prints ``TEST_SUCCESS_MESSAGE`` at the end.

    Args:
        name: Schedule name. ``"interleaving"`` enables a virtual pipeline
            size of 2; ``"no_pipelining"`` initializes every parallel size
            to 1.
        forward_backward_func: The schedule to invoke.
        pipeline_model_parallel_size: Requested pipeline size (ignored for
            ``"no_pipelining"``).
        forward_only: Forwarded to the schedule; when False, gradients are
            checked afterwards.

    Raises:
        RuntimeError: If ``forward_only`` is False and some parameter has no
            gradient after the schedule ran.
    """
    print_separator(
        f"name: {name}, pipeline model parallel size: {pipeline_model_parallel_size}"
    )

    virtual_pipeline_model_parallel_size = None
    if name == "interleaving":
        virtual_pipeline_model_parallel_size = 2

    if name != "no_pipelining":
        # NOTE (mkozuki): `virtual_pipeline_model_parallel_size` is needed to
        # enable the interleaving schedule. In megatron,
        # `args.virtual_pipeline_model_parallel_size` is computed in
        # megatron/arguments.py and used ubiquitously, but this test uses a
        # custom model so it's safe to abuse.
        parallel_state.initialize_model_parallel(
            1,
            pipeline_model_parallel_size,
            virtual_pipeline_model_parallel_size,
        )
        if virtual_pipeline_model_parallel_size is not None:
            # Check the experimental warning message.
            get_forward_backward_func(
                virtual_pipeline_model_parallel_size,
                pipeline_model_parallel_size,
            )
    else:
        # note (mkozuki): `forward_backward_no_pipelining` is **not**
        # compatible with pipeline_model_parallel_size > 1, so every parallel
        # size is set to 1 here regardless of the requested pipeline size.
        parallel_state.initialize_model_parallel(1, 1, None)

    pipeline_model_parallel_size = (
        parallel_state.get_pipeline_model_parallel_world_size())

    model = build_model(
        model_provider_func,
        wrap_with_ddp=True,
        virtual_pipeline_model_parallel_size=
        virtual_pipeline_model_parallel_size,
    )
    assert isinstance(model, list)
    # One model chunk per virtual pipeline stage, or a single chunk otherwise.
    if virtual_pipeline_model_parallel_size is None:
        expected_num_chunks = 1
    else:
        expected_num_chunks = virtual_pipeline_model_parallel_size
    assert len(model) == expected_num_chunks

    # The optimizer is only constructed, never stepped.
    torch.optim.Adam(_get_params_for_weight_decay_optimization(model),
                     lr=1e-4)

    # The batch holds this rank's share of samples; the schedule itself is
    # told the shape of a single micro batch.
    local_batch_size = (batch_size //
                        parallel_state.get_data_parallel_world_size())
    batch = (torch.randn([local_batch_size, hidden_size]).cuda(), )
    tensor_shape = [micro_batch_size, hidden_size]

    update_num_microbatches(0)
    forward_backward_func(
        fwd_step_func,
        batch,
        model,
        forward_only=forward_only,
        tensor_shape=tensor_shape,
    )

    if not forward_only:
        missing_grad = any(param.grad is None for module in model
                           for param in module.parameters())
        if missing_grad:
            raise RuntimeError("grad not found")

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(TEST_SUCCESS_MESSAGE)
Ejemplo n.º 4
0
def run_interleaved_with_dynamic_batch_size(
    pipeline_model_parallel_size: int,
    forward_only: bool,
    BatchSamplerCls,
) -> None:
    """Run the interleaved schedule while the local batch size changes.

    The microbatch calculator is reconfigured from the global args (with the
    data-parallel size hard-coded to 1), model-parallel state is initialized
    with a virtual pipeline size of 2, and a DDP-wrapped model is built. For
    ``NUM_ITERATIONS`` iterations the number of microbatches is updated from
    the consumed-sample count, the sampler's local minibatch size is resized
    to match, and ``_forward_backward_pipelining_with_interleaving`` is run on
    the next minibatch. Rank 0 prints ``TEST_SUCCESS_MESSAGE`` at the end.

    Args:
        pipeline_model_parallel_size: Requested pipeline-parallel size.
        forward_only: Forwarded to the schedule. When False, every parameter
            must have a gradient after each iteration; gradients are then
            cleared (no optimizer step is taken).
        BatchSamplerCls: Batch sampler class used to build the data loader.

    Raises:
        RuntimeError: If ``forward_only`` is False and some parameter has no
            gradient after an iteration.
    """
    args = global_vars.get_args()
    _reconfigure_microbatch_calculator(
        args.rank,
        args.rampup_batch_size,
        args.global_batch_size,
        args.micro_batch_size,
        1,  # args.data_parallel_size,
    )
    virtual_pipeline_model_parallel_size = 2
    # NOTE (mkozuki): `virtual_pipeline_model_parallel_size` is a requisite for the interleaving scheduling
    # In megatron, `args.virtual_pipeline_model_parallel_size` is computed in megatron/arguments.py and
    # used ubiquitously but this test uses custom model so it's safe to abuse.
    parallel_state.initialize_model_parallel(
        1, pipeline_model_parallel_size, virtual_pipeline_model_parallel_size)
    pipeline_model_parallel_size = (
        parallel_state.get_pipeline_model_parallel_world_size())

    print_separator(
        f"BatchSamplerCls: {BatchSamplerCls.__name__}, forward_only: {forward_only}"
    )

    model = build_model(
        model_provider_func,
        wrap_with_ddp=True,
        virtual_pipeline_model_parallel_size=
        virtual_pipeline_model_parallel_size,
        hidden_size=HIDDEN_SIZE,
    )
    assert isinstance(model, list)
    assert len(model) == virtual_pipeline_model_parallel_size
    optimizer = torch.optim.Adam(
        _get_params_for_weight_decay_optimization(model))

    initial_local_minibatch_size = get_num_microbatches() * micro_batch_size
    dataset = Dataset(NUM_SAMPLES)
    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_sampler=BatchSamplerCls(
            NUM_SAMPLES,
            0,
            initial_local_minibatch_size,
            parallel_state.get_data_parallel_rank(),
            parallel_state.get_data_parallel_world_size(),
        ),
    )
    data_iter = iter(data_loader)

    def get_num_samples(batch):
        """Return ``len`` of a tensor, or a nested list of lengths."""
        if isinstance(batch, torch.Tensor):
            return len(batch)
        assert isinstance(batch, (list, tuple))
        return [get_num_samples(b) for b in batch]

    tensor_shape = [micro_batch_size, HIDDEN_SIZE, HIDDEN_SIZE]
    consumed_samples = 0
    for i in range(NUM_ITERATIONS):
        update_num_microbatches(consumed_samples, consistency_check=False)
        local_batch_size = get_num_microbatches() * micro_batch_size
        # Resize the sampler in place so the next minibatch matches the
        # current number of microbatches.
        data_iter._index_sampler.local_minibatch_size = local_batch_size
        local_mini_batch = next(data_iter)

        _logger.info(f"iter: {i} / {NUM_ITERATIONS} "
                     f"local batchsize: {get_num_samples(local_mini_batch)} "
                     f"consumed_samples: {consumed_samples} / {NUM_SAMPLES}")
        _forward_backward_pipelining_with_interleaving(
            fwd_step_func,
            local_mini_batch,
            model,
            forward_only=forward_only,
            tensor_shape=tensor_shape,
        )

        consumed_samples += (parallel_state.get_data_parallel_world_size() *
                             get_num_microbatches() * micro_batch_size)

        if not forward_only:
            for m in model:
                for p in m.parameters():
                    if p.grad is None:
                        raise RuntimeError("grad not found")
            # Reached only when every parameter has a gradient. This was a
            # `for ... else` clause, which always runs when the loop has no
            # `break`, so a plain statement is equivalent and clearer.
            optimizer.zero_grad(set_to_none=True)

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(TEST_SUCCESS_MESSAGE)