Esempio n. 1
0
    def set_amp_args(self):
        """
        Configure automatic mixed precision (AMP) from ``self.config.MODEL.AMP_PARAMS``.

        Two implementations can be selected through ``AMP_TYPE``:

        - Apex: ``self.amp_args`` is set to ``AMP_PARAMS.AMP_ARGS``, the dictionary
          of arguments meant for ``amp.initialize`` (for instance
          ``{"opt_level": "O1"}``). See https://nvidia.github.io/apex/amp.html.

        - PyTorch: no arguments are needed, only a gradient scaler is created and
          stored in ``self.amp_grad_scaler``. A sharded scaler from fairscale is
          used for the FSDP task or the "zero" optimizer.

        When ``USE_AMP`` is false, ``self.amp_args`` and ``self.amp_type`` are
        both set to None.
        """
        amp_params = self.config.MODEL.AMP_PARAMS

        # Disabled: clear the AMP state and stop here.
        if not amp_params.USE_AMP:
            self.amp_args = None
            self.amp_type = None
            logging.info("Not using Automatic Mixed Precision")
            return

        assert (
            self.device.type == "cuda"
        ), "Mixed precision is only available on CUDA devices for now"

        # An unknown AMP_TYPE raises a KeyError here, which is the desired outcome.
        self.amp_type = AmpType[amp_params.AMP_TYPE.upper()]

        if self.amp_type == AmpType.APEX:
            if not is_apex_available():
                raise RuntimeError("Apex is not available. Can't use mixed precision")

            # Despite the generic name, "amp_args" only carries Apex arguments.
            self.amp_args = amp_params.AMP_ARGS
            logging.info(f"Setting AMP: using apex, args {self.amp_args}")

        elif self.amp_type == AmpType.PYTORCH:
            # With FSDP or a sharded ("zero") optimizer the GradScaler must be
            # shard-aware. The `or` keeps the second lookup lazy.
            needs_sharded_scaler = (
                self.config["TRAINER"]["TASK_NAME"] == "self_supervision_fsdp_task"
                or self.config["OPTIMIZER"]["name"] == "zero"
            )
            if needs_sharded_scaler:
                assert is_fairscale_sharded_available(), (
                    "To use ZeRO with PyTorch AMP, ShardedGradScaler() "
                    "from fairscale is needed. Please upgrade fairscale"
                )
                from fairscale.optim.grad_scaler import ShardedGradScaler

                self.amp_grad_scaler = ShardedGradScaler()
                logging.info("Setting AMP: using sharded grad scaler")
            else:
                self.amp_grad_scaler = TorchGradScaler()
                logging.info("Setting AMP: using pytorch grad scaler")

        logging.info(f"Setting AMP: {self.amp_type} - args: {self.amp_args}")
Esempio n. 2
0
    def _init_pytorch_grad_scaler(self):
        if self.config["OPTIMIZER"]["name"] == "zero":
            assert is_fairscale_sharded_available(), (
                "To use ZeRO with PyTorch AMP, ShardedGradScaler() "
                "from fairscale is needed. Please upgrade fairscale")
            from fairscale.optim.grad_scaler import ShardedGradScaler

            self.amp_grad_scaler = ShardedGradScaler()
            logging.info("Setting AMP: using sharded grad scaler")
        else:
            self.amp_grad_scaler = TorchGradScaler()
            logging.info("Setting AMP: using pytorch grad scaler")
    def set_amp_args(self, amp_args: Optional[Dict[str, Any]]):
        """Disable / enable automatic mixed precision (AMP) and store its parameters.

        Two AMP implementations can be selected through the optional
        ``"amp_type"`` entry of ``amp_args``: Apex (the default when the entry
        is absent) and PyTorch. For PyTorch AMP a ``TorchGradScaler`` is created
        and stored in ``self.amp_grad_scaler``.

        Args:
            amp_args: Dictionary of AMP settings, or None to disable AMP. For
                Apex these are the arguments meant for ``amp.initialize``, e.g.
                ``{"opt_level": "O1"}``.
                See https://nvidia.github.io/apex/amp.html for more info.

        Returns:
            ``self``, so that calls can be chained.

        Raises:
            RuntimeError: If AMP is requested but CUDA is not available, or if
                Apex AMP is selected but Apex is not installed.
        """
        self.amp_args = amp_args

        if amp_args is None:
            logging.info("AMP disabled")
            return self

        # Resolve the requested AMP type. A missing key and an unrecognised value
        # both fall back to Apex, but they are reported differently so that a
        # typo in "amp_type" is not mistaken for "not specified".
        if "amp_type" not in amp_args:
            logging.info("AMP type not specified, defaulting to Apex")
            self.amp_type = AmpType.APEX
        else:
            requested_type = amp_args["amp_type"]
            try:
                self.amp_type = AmpType[requested_type.upper()]
            except KeyError:
                logging.warning(
                    f"Unknown AMP type {requested_type!r}, defaulting to Apex"
                )
                self.amp_type = AmpType.APEX

        # Check for CUDA availability, required for both Apex and Pytorch AMP
        if not torch.cuda.is_available():
            raise RuntimeError(
                "AMP is required but CUDA is not supported, cannot enable AMP"
            )

        # Check for Apex availability
        if self.amp_type == AmpType.APEX and not apex_available:
            raise RuntimeError(
                "Apex AMP is required but Apex is not installed, cannot enable AMP"
            )

        # Set Torch AMP grad scaler, used to prevent gradient underflow
        elif self.amp_type == AmpType.PYTORCH:
            self.amp_grad_scaler = TorchGradScaler()

        logging.info(f"AMP enabled with args {amp_args}")
        return self
Esempio n. 4
0
    def set_amp_args(self):
        """
        Configure automatic mixed precision (AMP) from ``self.config.MODEL.AMP_PARAMS``.

        Two implementations can be selected through ``AMP_TYPE``:

        - Apex: ``self.amp_args`` is set to ``AMP_PARAMS.AMP_ARGS``, the dictionary
          of arguments meant for ``amp.initialize`` (for instance
          ``{"opt_level": "O1"}``). See https://nvidia.github.io/apex/amp.html.

        - PyTorch: no arguments are needed; a gradient scaler is stored in
          ``self.amp_grad_scaler`` (sharded when the optimizer name is "zero").

        When ``USE_AMP`` is false, ``self.amp_args`` and ``self.amp_type`` are
        both set to None.
        """
        amp_params = self.config.MODEL.AMP_PARAMS

        # Disabled: clear the AMP state and stop here.
        if not amp_params.USE_AMP:
            self.amp_args = None
            self.amp_type = None
            logging.info("Not using Automatic Mixed Precision")
            return

        assert (
            self.device.type == "cuda"
        ), "Mixed precision is only available on CUDA devices for now"

        # An unknown AMP_TYPE raises a KeyError here, which is the desired outcome.
        self.amp_type = AmpType[amp_params.AMP_TYPE.upper()]

        if self.amp_type == AmpType.APEX:
            if not is_apex_available():
                raise RuntimeError("Apex is not available. Can't use mixed precision")

            # Despite the generic name, "amp_args" only carries Apex arguments.
            self.amp_args = amp_params.AMP_ARGS

        elif self.amp_type == AmpType.PYTORCH:
            # A sharded optimizer needs a shard-aware GradScaler.
            if self.config["OPTIMIZER"]["name"] == "zero":
                self.amp_grad_scaler = ShardedGradScaler()
            else:
                self.amp_grad_scaler = TorchGradScaler()

        logging.info(f"Setting AMP: {self.amp_type} - args: {self.amp_args}")
Esempio n. 5
0
def train(
    rank: int,
    args: argparse.Namespace,
    backend: str = "gloo",
    optim_type: OptimType = OptimType.vanilla,
    check_regression: bool = True,
):
    """Run the benchmark training loop on one rank and validate the measurements.

    The function initialises the process group, builds the model / dataloader /
    loss through ``get_problem``, wraps the model and optimizer according to
    ``optim_type``, trains for ``args.epochs`` epochs while timing every batch,
    and finally hands the per-epoch throughput and the last loss to
    ``validate_benchmark`` before destroying the process group.

    Args:
        rank: Index of this process. Also used as the CUDA device index when
            ``args.cpu`` is false.
        args: Parsed command line options. Fields read here: ``debug``,
            ``multi_tensor_optim``, ``world_size``, ``cpu``, ``batch_size``,
            ``model``, ``amp``, ``profile`` and ``epochs``.
        backend: Backend name forwarded to ``dist_init``. With "nccl", cuDNN is
            switched to deterministic mode.
        optim_type: Optimizer / data-parallel combination to benchmark for this
            call.
        check_regression: Forwarded to ``validate_benchmark``.
    """
    logging.basicConfig(
        level=logging.INFO if not args.debug else logging.DEBUG)

    use_multi_tensor = args.multi_tensor_optim and hasattr(
        torch.optim, "_multi_tensor")
    OPTIM = torch.optim._multi_tensor.RMSprop if use_multi_tensor else torch.optim.RMSprop  # type: ignore  # attr is  checked but mypy misses that
    logging.info("Multi tensor optimizer: {}".format(use_multi_tensor))

    # DDP
    dist_init(rank=rank, world_size=args.world_size, backend=backend)

    # Setup
    if not args.cpu:
        torch.cuda.set_device(rank)
        torch.cuda.manual_seed(0)
    torch.manual_seed(0)  # also sets the cuda seed
    np.random.seed(0)

    if backend == "nccl":
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    device = torch.device("cpu") if args.cpu else torch.device(rank)
    model, dataloader, loss_fn = get_problem(rank, args.world_size,
                                             args.batch_size, device,
                                             args.model)

    # Shard the optimizer
    optimizer: Optional[torch.optim.Optimizer] = None
    model = cast(nn.Module, model)

    # Pick the scaler from the variant this call actually runs (the
    # `optim_type` argument, used everywhere else below) rather than from
    # `args.optim_type`, which may differ from it.
    if not args.amp:
        scaler = None
    elif optim_type == OptimType.vanilla:
        scaler = TorchGradScaler()
    else:
        scaler = ShardedGradScaler()

    if optim_type == OptimType.oss_sharded_ddp:
        optimizer = OSS(params=model.parameters(),
                        optim=OPTIM,
                        lr=1e-4,
                        momentum=0.9)
        # Single node run typically, no need for reduce buckets
        model = ShardedDDP(model, optimizer, reduce_buffer_size=0)
    else:
        device_ids = None if args.cpu else [rank]
        model = DDP(model, device_ids=device_ids,
                    find_unused_parameters=False)  # type: ignore
        optimizer = (OSS(
            params=model.parameters(), optim=OPTIM, lr=1e-4, momentum=0.9)
                     if optim_type == OptimType.oss_ddp else OPTIM(
                         model.parameters(), lr=1e-4, momentum=0.9))
    optimizer = cast(torch.optim.Optimizer, optimizer)

    # Reset the memory use counter
    if not args.cpu:
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats(rank)
        torch.cuda.synchronize(rank)

    # Standard training loop
    training_start = time.monotonic()
    model.train()

    measurements = []
    final_loss: Optional[float] = -1.0
    need_profiling = args.profile

    for epoch in range(args.epochs):
        n_items = 0
        epoch_runtime = 0.0

        for batch in dataloader:
            if not args.cpu:
                torch.cuda.synchronize(rank)
            batch_start = time.monotonic()

            # `data=batch` binds the current batch at definition time.
            def closure(data=batch, grad_scaler=None):
                model.zero_grad()
                if args.debug and rank == 0 and next(
                        model.parameters()).grad is not None:
                    logging.debug("\nbefore:  param {} -- grad {}".format(
                        next(model.parameters()).norm().item(),
                        next(model.parameters()).grad.norm().item()))
                if grad_scaler is not None:
                    # Automatically computes the FW pass in half precision
                    with torch.cuda.amp.autocast():
                        outputs = model(data["inputs"])
                        loss = loss_fn(outputs, data["label"])

                        # Accumulates scaled gradients.
                        grad_scaler.scale(loss).backward()
                else:
                    outputs = model(data["inputs"])
                    loss = loss_fn(outputs, data["label"])
                    loss.backward()

                if args.debug and rank == 0 and next(
                        model.parameters()).grad is not None:
                    logging.debug("after BW: param {} -- grad {}".format(
                        next(model.parameters()).norm().item(),
                        next(model.parameters()).grad.norm().item()))
                return loss

            def run_closure(closure, scaler, optimizer):
                if scaler is not None:
                    final_loss = closure(
                        grad_scaler=scaler
                    )  # AMP scaler.step does not support closures
                    scaler.step(optimizer)
                    scaler.update()
                    return final_loss
                else:
                    return optimizer.step(closure)

            if need_profiling and not args.cpu:
                logging.info("Profiling the run")
                with profiler.profile(
                        use_cuda=True, record_shapes=True,
                        profile_memory=True) as prof:  # type: ignore
                    with profiler.record_function("batch"):
                        final_loss = run_closure(closure, scaler, optimizer)

                prof.export_chrome_trace(
                    f"{optim_type}_trace_rank_{rank}.json")
                need_profiling = False  # only profile once

            else:
                final_loss = run_closure(closure, scaler, optimizer)

            if args.debug and rank == 0:
                logging.debug("buffer: {}".format(
                    next(model.buffers()).norm().item()))
                logging.debug("after update: param {} -- grad {}".format(
                    next(model.parameters()).norm().item(),
                    next(model.parameters()).grad.norm().item()))

            n_items += args.batch_size

            if not args.cpu:
                # make sure that the cuda kernels are finished before taking a timestamp
                torch.cuda.synchronize(rank)

            batch_end = time.monotonic()
            epoch_runtime += batch_end - batch_start

        if optim_type == OptimType.oss_ddp or optim_type == OptimType.oss_sharded_ddp:
            # Check the checkpointing in the case of the OSS optimizer
            # Memory usage could spill over from there
            optimizer = cast(OSS, optimizer)
            optimizer.consolidate_state_dict()
            if dist.get_rank() == 0:
                _ = optimizer.state_dict()
                logging.info("... State dict collected")

        measurements.append(n_items / epoch_runtime)
        if dist.get_rank() == 0:
            logging.info(
                f"Epoch {epoch} - processed {measurements[-1]:.2f} img per sec. Loss {final_loss:.3f}"
            )

    training_stop = time.monotonic()
    # n_items is reset at the start of every epoch, so it holds the item count
    # of a single epoch; multiplying by the epoch count gives the overall rate.
    img_per_sec = n_items / (training_stop - training_start) * args.epochs
    logging.info(
        f"[{dist.get_rank()}] : Training done. {img_per_sec:.2f} img per sec inc. checkpoint"
    )

    validate_benchmark(measurements, final_loss, args, check_regression)

    dist.destroy_process_group()  # type: ignore
Esempio n. 6
0
def run_ddp_parity(
    rank,
    world_size,
    backend,
    temp_file_name,
    reduce_buffer_size,
    grad_accumulation,
    change_train_graph,
    fp16_reduction,
    clip_grad_norm,
    amp,
    manual_reduction,
    multiple_fw,
):
    """Train the same model with DDP and with ShardedDataParallel and compare them.

    The process group is initialised from ``temp_file_name``. One model is
    wrapped in ``ShardedDataParallel`` (with an ``OSS`` optimizer), and a deep
    copy of it is wrapped in ``DDP`` (with plain SGD). Both are stepped on the
    same inputs for a few batches and ``check_same_model_params`` is called
    after every step.

    Args:
        rank: Rank of this process; also the CUDA device index.
        world_size: Number of processes in the group.
        backend: Backend passed to ``dist.init_process_group``.
        temp_file_name: File used for the ``file://`` init method.
        reduce_buffer_size: Forwarded to ``ShardedDataParallel``.
        grad_accumulation: When true, each closure runs three forward/backward
            passes, the first two inside ``model.no_sync()``.
        change_train_graph: When true, the trainability of the first parameter
            is flipped after the first batch.
        fp16_reduction: Forwarded to ``ShardedDataParallel`` as ``reduce_fp16``;
            also registers the fp16 compression hook on the DDP model.
        clip_grad_norm: When true (and the train graph is not changed), the
            gradient norms returned by both clipping implementations are
            compared.
        amp: When true, both models are stepped through a gradient scaler under
            autocast.
        manual_reduction: When true, the sharded model's last pass runs inside
            ``no_sync()`` and is followed by an explicit ``model.reduce()``.
        multiple_fw: Forwarded to ``_get_mlp_emb``.
    """
    dist.init_process_group(init_method="file://" + temp_file_name,
                            backend=backend,
                            rank=rank,
                            world_size=world_size)

    device = torch.device("cuda")
    torch.cuda.set_device(rank)
    torch.manual_seed(rank)
    np.random.seed(rank)
    NUMBER_BATCHS = 5

    # Test all combinations: AMP, Accumulate, Change train graph, reduce buckets
    print(
        f"{rank}: Checking configuration: accumulate {grad_accumulation}" +
        f" - change train graph {change_train_graph}" + f" - amp {amp}" +
        f" - manual reduction {manual_reduction}" +
        f" - buffers {reduce_buffer_size}" + f" - multiple FW {multiple_fw}",
        flush=True,
    )

    # The API should be the exact same in between the sharded and non-sharded variants, generic closure
    def closure(model,
                scaler,
                input_tensor,
                should_accumulate,
                _manual_reduction=False):
        accumulate_steps = 3 if should_accumulate else 1

        model.zero_grad()

        def step():
            if scaler is not None:
                with torch.cuda.amp.autocast():
                    loss = model(input_tensor).abs().sum()
                    scaler.scale(loss).backward()
            else:
                loss = model(input_tensor).abs().sum()
                loss.backward()

        # All passes but the last one accumulate without gradient sync.
        with model.no_sync() if should_accumulate else suppress():
            for _ in range(accumulate_steps - 1):
                step()

        if not _manual_reduction:
            step()
        else:
            with model.no_sync():
                step()

            model.reduce()

    # Any model works. Add one different buffer per rank
    model = _get_mlp_emb(multiple_fw)
    model.register_buffer("test_buffer", torch.ones((1)) * rank)
    model.to(device)

    # Make sure that the model starts with non-trainable, so that we check for the buckets to be
    # properly reassigned when/if this changes
    next(model.parameters()).requires_grad = False

    sharded_optimizer = OSS(params=model.parameters(),
                            optim=torch.optim.SGD,
                            lr=1e-4,
                            momentum=0.99)
    sharded_ddp_model = ShardedDataParallel(
        module=model,
        sharded_optimizer=sharded_optimizer,
        broadcast_buffers=True,
        reduce_buffer_size=reduce_buffer_size,
        reduce_fp16=fp16_reduction,
    )

    ddp_model_single = copy.deepcopy(model)
    ddp_optimizer = torch.optim.SGD(ddp_model_single.parameters(),
                                    lr=1e-4,
                                    momentum=0.99)
    ddp_model = DDP(ddp_model_single,
                    device_ids=[rank],
                    broadcast_buffers=True,
                    find_unused_parameters=True)

    if fp16_reduction:
        # The hook lives in torch.distributed; a module alias such as `dist`
        # cannot be used as the root of a `from ... import` path.
        from torch.distributed.algorithms.ddp_comm_hooks.default_hooks import fp16_compress_hook

        ddp_model.register_comm_hook(state=None,
                                     hook=fp16_compress_hook)  # type: ignore

    ddp_scaler = TorchGradScaler() if amp else None
    sharded_scaler = ShardedGradScaler() if amp else None

    # The model should be synchronized in between the ranks at construction time, check that
    check_same_model_params(sharded_ddp_model, ddp_model)

    # Typical training loop, check that we get the exact same results as DDP
    for i in range(NUMBER_BATCHS):
        input_tensor = _get_random_inputs(device)

        def ddp_closure(input_tensor=input_tensor):
            return closure(ddp_model, ddp_scaler, input_tensor,
                           grad_accumulation)

        def sharded_closure(input_tensor=input_tensor):
            return closure(
                sharded_ddp_model,
                sharded_scaler,
                input_tensor,
                grad_accumulation,
                _manual_reduction=manual_reduction,
            )

        # Step/scale both
        for _scaler, _closure, _optimizer in (
            (ddp_scaler, ddp_closure, ddp_optimizer),
            (sharded_scaler, sharded_closure, sharded_optimizer),
        ):
            if _scaler is not None:
                _ = _closure(input_tensor)
                _scaler.step(_optimizer)
                _scaler.update()
            else:
                # The closure is run eagerly here; it returns None, so the
                # optimizer then steps on the gradients it just produced.
                _optimizer.step(_closure())

        check_same_model_params(sharded_ddp_model, ddp_model,
                                f"Rank: {rank} - Step {i} broke")

        # Check that the two grad norm are equivalent
        # NOTE: The grads can occasionally be NaNs, the scaler will skip the step in that case
        # This is not ShardedDDP specific. If the grads are not NaN for DDP then they should also
        # be valid for ShardedDDP
        # NOTE: DDP does not handle parameters trainability being changed after the fact, see
        # https://github.com/pytorch/pytorch/blob/5781aec74ef00284e0262817a649278c2e8072bf/torch/nn/parallel/distributed.py#L471
        if clip_grad_norm and not change_train_graph:
            if torch_version() >= (1, 9, 0):
                total_norm = torch.nn.utils.clip_grad_norm_(
                    ddp_model.parameters(),
                    0.3,
                    norm_type=2.0,
                    error_if_nonfinite=False)  # type: ignore
            else:
                total_norm = torch.nn.utils.clip_grad_norm_(
                    ddp_model.parameters(), 0.3, norm_type=2.0)  # type: ignore
            if not torch.isnan(total_norm):
                oss_total_norm = sharded_optimizer.clip_grad_norm(
                    0.3, norm_type=2.0)
                # A looser tolerance is used when AMP is on.
                allclose = torch.allclose(oss_total_norm,
                                          total_norm,
                                          atol=1e-2 if amp else 1e-8)

                if not allclose:
                    # Debug helper if this unit test does not pass, compare the gradients in between DDP and ShardedDDP
                    for idx, (p_ddp, p_sdp) in enumerate(
                            zip(ddp_model.parameters(),
                                sharded_ddp_model.parameters())):
                        if p_ddp.grad is not None:
                            if p_sdp.grad is not None:
                                print(rank,
                                      idx,
                                      torch.norm(p_ddp.grad),
                                      torch.norm(p_sdp.grad),
                                      flush=True)
                            else:
                                print(rank,
                                      idx,
                                      torch.norm(p_ddp.grad),
                                      "not owned",
                                      flush=True)

                assert (
                    allclose
                ), f"torch and fairscale should return the same grad norm\n {oss_total_norm} vs {total_norm}"
            else:
                print(rank, "NaN grad norm in DDP", flush=True)

        # Flip the trainability of the first parameter back and forth
        if i == 0 and change_train_graph:
            next(sharded_ddp_model.parameters()).requires_grad = not next(
                sharded_ddp_model.parameters()).requires_grad
            next(ddp_model.parameters()).requires_grad = not next(
                ddp_model.parameters()).requires_grad
            check_same_model_params(
                sharded_ddp_model, ddp_model,
                f"Rank: {rank} - Trainability refresh {i} broke")

    dist.destroy_process_group()
Esempio n. 7
0
    def check_parity(manual_reduction: bool):
        """Train one model with DDP and with ShardedDataParallel and compare them.

        Besides ``manual_reduction``, the configuration (``rank``, ``device``,
        ``amp``, ``grad_accumulation``, ``change_train_graph``,
        ``clip_grad_norm``, ``fp16_reduction``, ``reduce_buffer_size``,
        ``NUMBER_BATCHS``, ``BATCH_SIZE``) is read from the enclosing scope.

        Args:
            manual_reduction: When true, the sharded model's last pass runs
                inside ``no_sync()`` and is followed by an explicit
                ``model.reduce()``.
        """

        # The API should be the exact same in between the sharded and non-sharded variants, generic closure
        def closure(model,
                    scaler,
                    input_tensor,
                    should_accumulate,
                    _manual_reduction=False):
            accumulate_steps = 3 if should_accumulate else 1

            model.zero_grad()

            def step():
                if scaler is not None:
                    with torch.cuda.amp.autocast():
                        loss = model(input_tensor).abs().sum()
                        scaler.scale(loss).backward()
                else:
                    loss = model(input_tensor).abs().sum()
                    loss.backward()

            # All passes but the last one accumulate without gradient sync.
            with model.no_sync() if should_accumulate else suppress():
                for _ in range(accumulate_steps - 1):
                    step()

            if not _manual_reduction:
                step()
            else:
                with model.no_sync():
                    step()

                model.reduce()

        # Any model works. Add one different buffer per rank
        model = _get_mlp()
        model.register_buffer("test_buffer", torch.ones((1)) * rank)
        model.to(device)

        # Make sure that the model starts with non-trainable, so that we check for the buckets to be
        # properly reassigned when/if this changes
        next(model.parameters()).requires_grad = False

        sharded_optimizer = OSS(params=model.parameters(),
                                optim=torch.optim.SGD,
                                lr=1e-4,
                                momentum=0.99)
        sharded_ddp_model = ShardedDataParallel(
            module=model,
            sharded_optimizer=sharded_optimizer,
            broadcast_buffers=True,
            reduce_buffer_size=reduce_buffer_size,
            reduce_fp16=fp16_reduction,
        )

        ddp_model_single = copy.deepcopy(model)
        ddp_optimizer = torch.optim.SGD(ddp_model_single.parameters(),
                                        lr=1e-4,
                                        momentum=0.99)
        ddp_model = DDP(ddp_model_single,
                        device_ids=[rank],
                        broadcast_buffers=True,
                        find_unused_parameters=True)

        if fp16_reduction:
            # The hook lives in torch.distributed; a module alias such as
            # `dist` cannot be used as the root of a `from ... import` path.
            from torch.distributed.algorithms.ddp_comm_hooks.default_hooks import fp16_compress_hook

            ddp_model.register_comm_hook(
                state=None, hook=fp16_compress_hook)  # type: ignore

        ddp_scaler = TorchGradScaler() if amp else None
        sharded_scaler = ShardedGradScaler() if amp else None

        # The model should be synchronized in between the ranks at construction time, check that
        check_same_model_params(sharded_ddp_model, ddp_model)

        # Typical training loop, check that we get the exact same results as DDP
        for i in range(NUMBER_BATCHS):
            input_tensor = torch.rand((BATCH_SIZE, 2)).to(device)

            def ddp_closure(input_tensor=input_tensor):
                return closure(ddp_model, ddp_scaler, input_tensor,
                               grad_accumulation)

            def sharded_closure(input_tensor=input_tensor):
                return closure(
                    sharded_ddp_model,
                    sharded_scaler,
                    input_tensor,
                    grad_accumulation,
                    _manual_reduction=manual_reduction,
                )

            # Step/scale both
            for _scaler, _closure, _optimizer in (
                (ddp_scaler, ddp_closure, ddp_optimizer),
                (sharded_scaler, sharded_closure, sharded_optimizer),
            ):
                if _scaler is not None:
                    _ = _closure(input_tensor)
                    _scaler.step(_optimizer)
                    _scaler.update()
                else:
                    # Without AMP there is no scaler: let the optimizer run the
                    # closure itself, otherwise neither model is ever updated
                    # and the parity checks below compare untouched weights.
                    _optimizer.step(closure=_closure)

            check_same_model_params(sharded_ddp_model, ddp_model,
                                    f"Rank: {rank} - Step {i} broke")

            # Check that the two grad norm are equivalent
            # NOTE: The grads can occasionally be NaNs, the scaler will skip the step in that case
            # This is not ShardedDDP specific. If the grads are not NaN for DDP then they should also
            # be valid for ShardedDDP
            if clip_grad_norm:
                total_norm = torch.nn.utils.clip_grad_norm_(
                    ddp_model.parameters(), 0.3, norm_type=2.0)  # type: ignore
                if not torch.isnan(total_norm):
                    oss_total_norm = sharded_optimizer.clip_grad_norm(
                        0.3, norm_type=2.0)
                    # A looser tolerance is used when AMP is on.
                    assert torch.allclose(
                        oss_total_norm, total_norm, atol=1e-2 if amp else 1e-8
                    ), f"torch and fairscale should return the same grad norm\n {oss_total_norm} vs {total_norm}"
                else:
                    print(rank, "NaN grad norm in DDP", flush=True)

            # Flip the trainability of the first parameter back and forth
            if i == 0 and change_train_graph:
                next(sharded_ddp_model.parameters()).requires_grad = not next(
                    sharded_ddp_model.parameters()).requires_grad
                next(ddp_model.parameters()).requires_grad = not next(
                    ddp_model.parameters()).requires_grad
                check_same_model_params(
                    sharded_ddp_model, ddp_model,
                    f"Rank: {rank} - Trainability refresh {i} broke")
Esempio n. 8
0
    def check_parity(amp: bool, accumulate: bool, change_train_graph: bool):
        """Step a DDP model and a ShardedDataParallel model side by side and compare them.

        ``rank``, ``device``, ``INPUTS``, ``BATCH_SIZE`` and ``NUMBER_BATCHS``
        are read from the enclosing scope.

        Args:
            amp: Step both models through a gradient scaler under autocast.
            accumulate: Run three forward/backward passes per step, the first
                two inside ``model.no_sync()``.
            change_train_graph: Flip the trainability of the first parameter
                of both models after the first batch.
        """

        # Shared by both variants: the sharded and non-sharded APIs are expected to match.
        def closure(model, scaler, input_tensor, should_accumulate):
            unsynced_passes = 2 if should_accumulate else 0

            model.zero_grad()

            def step():
                if scaler is None:
                    loss = model(input_tensor).abs().sum()
                    loss.backward()
                    return

                with torch.cuda.amp.autocast():
                    loss = model(input_tensor).abs().sum()
                    scaler.scale(loss).backward()

            # Accumulation passes skip the gradient sync; the final pass does not.
            sync_guard = model.no_sync() if should_accumulate else suppress()
            with sync_guard:
                for _ in range(unsynced_passes):
                    step()

            step()

        # Any model works. Each rank gets a different buffer value.
        model = Sequential(Linear(INPUTS, 3), *(Linear(3, 3) for _ in range(5)))
        model.register_buffer("test_buffer", torch.ones((1)) * rank)
        model.to(device)

        # Start with a frozen first parameter, so that bucket reassignment is
        # exercised when/if its trainability changes later on.
        first_param = next(model.parameters())
        first_param.requires_grad = False

        sharded_optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-5, momentum=0.99)
        sharded_ddp_model = ShardedDataParallel(
            module=model, sharded_optimizer=sharded_optimizer, broadcast_buffers=True
        )

        reference_module = copy.deepcopy(model)
        ddp_optimizer = torch.optim.SGD(reference_module.parameters(), lr=1e-5, momentum=0.99)
        ddp_model = DDP(reference_module, device_ids=[rank], broadcast_buffers=True, find_unused_parameters=True)

        if amp:
            ddp_scaler = TorchGradScaler()
            sharded_ddp_scaler = ShardedGradScaler()
        else:
            ddp_scaler = None
            sharded_ddp_scaler = None

        # Both wrappers are expected to synchronize the ranks at construction time.
        check_same_model_params(sharded_ddp_model, ddp_model)

        # Typical training loop, the results must match DDP exactly.
        for i in range(NUMBER_BATCHS):
            input_tensor = torch.rand((BATCH_SIZE, INPUTS)).to(device)

            def run_ddp(input_tensor=input_tensor):
                return closure(ddp_model, ddp_scaler, input_tensor, accumulate)

            def run_sharded(input_tensor=input_tensor):
                return closure(sharded_ddp_model, sharded_ddp_scaler, input_tensor, accumulate)

            # Step/scale both, DDP first and then the sharded variant.
            for grad_scaler, step_closure, optimizer in (
                (ddp_scaler, run_ddp, ddp_optimizer),
                (sharded_ddp_scaler, run_sharded, sharded_optimizer),
            ):
                if grad_scaler is None:
                    optimizer.step(closure=step_closure)
                    continue

                step_closure(input_tensor)
                grad_scaler.step(optimizer)
                grad_scaler.update()

            check_same_model_params(sharded_ddp_model, ddp_model, f"Rank: {rank} - Step {i} broke")

            # Flip the trainability of the first parameter back and forth
            if i == 0 and change_train_graph:
                for wrapped in (sharded_ddp_model, ddp_model):
                    head_param = next(wrapped.parameters())
                    head_param.requires_grad = not head_param.requires_grad
                check_same_model_params(sharded_ddp_model, ddp_model, f"Rank: {rank} - Trainability refresh {i} broke")
Esempio n. 9
0
def train(
    rank: int,
    args: argparse.Namespace,
    backend: str = "gloo",
    optim_type: OptimType = OptimType.vanilla,
    check_regression: bool = True,
):
    """Run the benchmark training loop on one rank and optionally check for regressions.

    The function initializes the process group, seeds the RNGs, builds the
    problem through ``get_problem`` and wraps the model according to
    ``optim_type``:

    - ``OptimType.oss_sharded_ddp``: ``OSS`` optimizer + ``ShardedDDP`` wrapper
    - ``OptimType.oss_ddp``: ``DDP`` wrapper + ``OSS`` optimizer
    - anything else: ``DDP`` wrapper + plain ``OPTIM`` optimizer

    It then trains for ``args.epochs`` epochs, logs the per-epoch throughput,
    the peak CUDA memory and the median / median-absolute-deviation of the
    throughput, and finally destroys the process group.

    Args:
        rank: Rank of this process; also used as the CUDA device index.
        args: Parsed command line options. Fields read here: ``debug``,
            ``world_size``, ``cpu``, ``batch_size``, ``torchvision_model``,
            ``amp``, ``profile``, ``epochs``, ``reference_speed``,
            ``reference_memory`` and ``reference_loss``.
        backend: Backend name forwarded to ``dist_init``. ``"nccl"`` also
            switches cuDNN to deterministic mode.
        optim_type: Which optimizer / data-parallel wrapper combination to run.
        check_regression: When True, rank 0 asserts speed, memory and loss
            against the reference values carried by ``args``.

    Raises:
        AssertionError: If ``check_regression`` is set and, on rank 0, one of
            the speed / memory / loss checks fails.
    """
    logging.basicConfig(level=logging.INFO if not args.debug else logging.DEBUG)

    # DDP
    dist_init(rank=rank, world_size=args.world_size, backend=backend)

    # Setup
    if not args.cpu:
        torch.cuda.set_device(rank)
        torch.cuda.manual_seed(0)
    torch.manual_seed(0)  # also sets the cuda seed
    np.random.seed(0)

    if backend == "nccl":
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    device = torch.device("cpu") if args.cpu else torch.device(rank)
    model, dataloader, loss_fn = get_problem(
        rank, args.world_size, args.batch_size, device, args.torchvision_model
    )

    # Shard the optimizer
    optimizer: Optional[torch.optim.Optimizer] = None
    model = cast(nn.Module, model)

    # Pick the grad scaler from the optimizer type of *this* run (the
    # `optim_type` parameter), like every other decision in this function,
    # and not from whatever was requested globally on the command line.
    scaler = None
    if args.amp:
        scaler = TorchGradScaler() if optim_type == OptimType.vanilla else ShardedGradScaler()

    if optim_type == OptimType.oss_sharded_ddp:
        optimizer = OSS(params=model.parameters(), optim=OPTIM, lr=1e-4, momentum=0.9)
        model = ShardedDDP(model, optimizer)
    else:
        device_ids = None if args.cpu else [rank]
        model = DDP(model, device_ids=device_ids, find_unused_parameters=False)  # type: ignore
        if optim_type == OptimType.oss_ddp:
            optimizer = OSS(params=model.parameters(), optim=OPTIM, lr=1e-4, momentum=0.9)
        else:
            optimizer = OPTIM(model.parameters(), lr=1e-4, momentum=0.9)
    optimizer = cast(torch.optim.Optimizer, optimizer)

    # Reset the memory use counter
    if not args.cpu:
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats(rank)
        torch.cuda.synchronize(rank)

    def run_step(step_closure):
        """Run one optimizer step and return the loss computed by `step_closure`."""
        if scaler is None:
            return optimizer.step(step_closure)

        # AMP scaler.step does not support closures, so evaluate it by hand
        loss = step_closure(grad_scaler=scaler)
        scaler.step(optimizer)
        scaler.update()
        return loss

    # Standard training loop
    training_start = time.monotonic()
    model.train()

    measurements = []
    final_loss: Optional[float] = -1.0
    need_profiling = args.profile

    for epoch in range(args.epochs):
        n_items = 0
        epoch_runtime = 0.0

        for batch in dataloader:
            if not args.cpu:
                torch.cuda.synchronize(rank)
            batch__start = time.monotonic()

            # `data=batch` binds the current batch at definition time
            def closure(data=batch, grad_scaler=None):
                model.zero_grad()
                if args.debug and rank == 0 and next(model.parameters()).grad is not None:
                    logging.debug(
                        "\nbefore:  param {} -- grad {}".format(
                            next(model.parameters()).norm().item(),
                            next(model.parameters()).grad.norm().item(),
                        )
                    )
                if grad_scaler is not None:
                    # Automatically computes the FW pass in half precision
                    with torch.cuda.amp.autocast():
                        outputs = model(data["inputs"])
                        loss = loss_fn(outputs, data["label"])

                    # Accumulates scaled gradients. Kept outside of the autocast
                    # context: backward passes under autocast are not recommended.
                    grad_scaler.scale(loss).backward()
                else:
                    outputs = model(data["inputs"])
                    loss = loss_fn(outputs, data["label"])
                    loss.backward()

                if args.debug and rank == 0 and next(model.parameters()).grad is not None:
                    logging.debug(
                        "after BW: param {} -- grad {}".format(
                            next(model.parameters()).norm().item(),
                            next(model.parameters()).grad.norm().item(),
                        )
                    )
                return loss

            if need_profiling and not args.cpu:
                logging.info("Profiling the run")
                with profiler.profile(use_cuda=True, record_shapes=True, profile_memory=True) as prof:  # type: ignore
                    with profiler.record_function("batch"):
                        final_loss = run_step(closure)

                prof.export_chrome_trace(f"{optim_type}_trace_rank_{rank}.json")
                need_profiling = False  # only profile once
            else:
                final_loss = run_step(closure)

            if args.debug and rank == 0:
                # Same guards as in the closure: do not crash the run from a
                # debug statement if there is no buffer or no gradient to show.
                first_buffer = next(model.buffers(), None)
                if first_buffer is not None:
                    logging.debug("buffer: {}".format(first_buffer.norm().item()))

                first_param = next(model.parameters())
                if first_param.grad is not None:
                    logging.debug(
                        "after update: param {} -- grad {}".format(
                            first_param.norm().item(), first_param.grad.norm().item()
                        )
                    )

            n_items += args.batch_size

            if not args.cpu:
                # make sure that the cuda kernels are finished before taking a timestamp
                torch.cuda.synchronize(rank)

            batch_end = time.monotonic()
            epoch_runtime += batch_end - batch__start

        if optim_type in (OptimType.oss_ddp, OptimType.oss_sharded_ddp):
            # Check the checkpointing in the case of the OSS optimizer
            # Memory usage could spill over from there
            optimizer = cast(OSS, optimizer)
            optimizer.consolidate_state_dict()
            if dist.get_rank() == 0:
                _ = optimizer.state_dict()
                logging.info("... State dict collected")

        measurements.append(n_items / epoch_runtime)
        if dist.get_rank() == 0:
            logging.info(
                f"Epoch {epoch} - processed {measurements[-1]:.2f} img per sec. Loss {final_loss:.3f}"
            )

    # -1 is reported when no CUDA measurement is taken (CPU run)
    max_memory = -1.0
    if not args.cpu:
        torch.cuda.synchronize(rank)
        max_memory = torch.cuda.max_memory_allocated(rank) / 2**20
        logging.info(f"[{dist.get_rank()}] : Peak memory {max_memory:.1f}MiB")

    training_stop = time.monotonic()
    # n_items holds the item count of the last epoch, hence the scaling by the number of epochs
    img_per_sec = n_items / (training_stop - training_start) * args.epochs
    logging.info(
        f"[{dist.get_rank()}] : Training done. {img_per_sec:.2f} img per sec inc. checkpoint"
    )

    # Compute the median and median of absolute differences img per second
    measurements.sort()
    median = measurements[len(measurements) // 2]

    abs_diff = sorted(abs(speed - median) for speed in measurements)
    # With two epochs or fewer, -1 is used in place of the deviation
    mad = abs_diff[len(measurements) // 2] if args.epochs > 2 else -1

    logging.info(f"[{dist.get_rank()}] : Median speed: {median:.2f} +/- {mad:.2f}")

    if check_regression and dist.get_rank() == 0:
        assert (median + 3.0 * mad) > args.reference_speed, "Speed regression detected"
        assert max_memory < 1.05 * args.reference_memory, "Memory use regression detected"
        assert abs(cast(float, final_loss) - args.reference_loss) < 1e-3, "Loss regression detected"

        logging.info("[Regression Test] VALID")

    dist.destroy_process_group()  # type: ignore
Esempio n. 10
0
    def check_parity(amp: bool):
        """Train a ShardedDataParallel + OSS model side by side with a DDP + SGD copy.

        Both models are fed the same random batches for 10 steps. The parameters
        and buffers of the two are compared (``atol=1e-3``) right after
        construction and after every step.

        Args:
            amp: When True, both models run under ``torch.cuda.amp.autocast`` and
                are stepped through a grad scaler (``TorchGradScaler`` for DDP,
                ``ShardedGradScaler`` for the sharded model).
        """
        # Any model works. Add one different buffer per rank
        layers = [Linear(2, 3)] + [Linear(3, 3) for _ in range(5)]
        model = Sequential(*layers)
        model.register_buffer("test_buffer", torch.ones((1)) * rank)
        model.to(device)

        # Sharded flavour: OSS optimizer wrapped by ShardedDataParallel
        sharded_optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
        sharded_ddp_model = ShardedDataParallel(
            module=model, sharded_optimizer=sharded_optimizer, broadcast_buffers=True
        )

        # Reference flavour: a deep copy of the same model under vanilla DDP + SGD
        ddp_model_single = copy.deepcopy(model)
        ddp_optimizer = torch.optim.SGD(ddp_model_single.parameters(), lr=1e-3, momentum=0.99)
        ddp_model = DDP(ddp_model_single, device_ids=[rank], broadcast_buffers=True)

        if amp:
            ddp_scaler, sharded_ddp_scaler = TorchGradScaler(), ShardedGradScaler()
        else:
            ddp_scaler = sharded_ddp_scaler = None

        def check_same_model_params():
            """Assert that parameters and buffers match between the two models."""
            paired_groups = zip(sharded_optimizer.param_groups, ddp_optimizer.param_groups)
            for sharded_group, ddp_group in paired_groups:
                for sharded_p, ddp_p in zip(sharded_group["params"], ddp_group["params"]):
                    params_match = torch.allclose(sharded_p, ddp_p, atol=1e-3)
                    assert params_match, f"Model parameters differ in between DDP and ShardedDDP {sharded_p} {ddp_p}"

            for sharded_b, ddp_b in zip(sharded_ddp_model.buffers(), ddp_model.buffers()):
                buffers_match = torch.allclose(sharded_b, ddp_b, atol=1e-3)
                assert buffers_match, f"Model buffers differ in between DDP and ShardedDDP. AMP {amp}"

        def _forward_backward(net, optimizer, scaler, batch):
            """Zero the grads, then run one forward/backward pass and return the loss."""
            optimizer.zero_grad()

            if scaler is None:
                loss = net(batch).abs().sum()
                loss.backward()
                return loss

            with torch.cuda.amp.autocast():
                loss = net(batch).abs().sum()
                scaler.scale(loss).backward()
            return loss

        def _step(net, optimizer, scaler, batch):
            """Update `net` once, going through the scaler when there is one."""
            if scaler is None:
                optimizer.step(closure=lambda: _forward_backward(net, optimizer, scaler, batch))
                return

            # The scaler is stepped explicitly after a manual forward/backward pass
            _forward_backward(net, optimizer, scaler, batch)
            scaler.step(optimizer)
            scaler.update()

        # The model should be synchronized in between the ranks at construction time, check that
        check_same_model_params()

        # The models should stay the same in between the ranks
        for _ in range(10):
            batch = torch.rand((64, 2)).to(device)

            # Step/scale both, DDP first
            _step(ddp_model, ddp_optimizer, ddp_scaler, batch)
            _step(sharded_ddp_model, sharded_optimizer, sharded_ddp_scaler, batch)

            check_same_model_params()