def test_scaler_cpu_offload_breaks():
    device = torch.device("cuda")
    torch.cuda.set_device(0)

    # Random port in case the next test runs quickly; the same port would cause a conflict.
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = str(random.randint(2000, 3000))
    torch.distributed.init_process_group(backend="nccl", rank=0, world_size=1)

    try:
        scaler = ShardedGradScaler()
        model = FullyShardedDataParallel(nn.Linear(5, 5), cpu_offload=True, mixed_precision=True)
        optim = torch.optim.SGD(model.parameters(), lr=1e-3)

        input = torch.rand((1, 5), dtype=torch.float).to(device)
        optim.zero_grad()
        with autocast():
            output = model(input)
            loss = F.mse_loss(input, output)

        scaler.scale(loss).backward()
        # TODO (Min): Need to fix. Details in issue #421.
        with pytest.raises(RuntimeError):
            scaler.step(optim)
            scaler.update()
    finally:
        # Clean-up is important or the next test in this file may fail to init the PG.
        torch.distributed.destroy_process_group()
        del os.environ["MASTER_ADDR"]
        del os.environ["MASTER_PORT"]
def new_process(self, process_idx, trainer, mp_queue):
    # Ensure that the scaler points to the correct process group
    # which is re-initialized in a new process
    precision_plugin = trainer.accelerator.precision_plugin
    if isinstance(precision_plugin, ShardedNativeMixedPrecisionPlugin):
        precision_plugin.scaler = ShardedGradScaler()
    super().new_process(process_idx, trainer, mp_queue)
def _init_pytorch_grad_scaler(self):
    assert is_fairscale_sharded_available(), (
        "To use FSDP with PyTorch AMP, ShardedGradScaler() "
        "from fairscale is needed. Please upgrade fairscale"
    )
    from fairscale.optim.grad_scaler import ShardedGradScaler

    self.amp_grad_scaler = ShardedGradScaler()
    logging.info("Setting AMP: using ShardedGradScaler")
def new_process(self, trainer: "pl.Trainer", mp_queue: SimpleQueue) -> None:
    # Ensure that the scaler points to the correct process group
    # which is re-initialized in a new process
    precision_plugin = trainer.accelerator.precision_plugin
    if isinstance(precision_plugin, ShardedNativeMixedPrecisionPlugin):
        precision_plugin.scaler = ShardedGradScaler()
    return super().new_process(trainer, mp_queue)
def _train_for_several_steps(model, num_steps, autocast, lr=0.01, norm_type=None):
    model_device = next(model.parameters()).device
    # use SGD with momentum instead of Adam, since Adam is scale invariant
    # and this makes it bad for tests
    optim = torch.optim.SGD(params=model.parameters(), lr=lr, momentum=0.9)
    scaler = ShardedGradScaler()
    for _ in range(num_steps):
        optim.zero_grad()
        with torch.cuda.amp.autocast(enabled=autocast):
            # Inputs always cuda regardless of move_grads_cpu, or model.device
            input = model.module.get_input(torch.device("cuda"))
            output = model(*input)
            loss = model.module.get_loss(input, output).to(model_device)
        loss = scaler.scale(loss)
        assert loss.dtype == torch.float32
        model.module.run_backward(loss)
        if norm_type is not None:
            clip_norm = 0.3
            if isinstance(model, FullyShardedDataParallel):
                model.clip_grad_norm_(clip_norm, norm_type)
            else:
                torch.nn.utils.clip_grad_norm_(model.parameters(), clip_norm, norm_type)
        scaler.step(optim)
        scaler.update()
    if hasattr(model, "assert_idle"):
        model.assert_idle()
    if isinstance(model, FullyShardedDataParallel):
        model.assert_state(TrainingState.IDLE)
    return loss.detach()
def __init__(
    self, precision: Union[str, int], device: str, scaler: Optional[torch.cuda.amp.GradScaler] = None
) -> None:
    if not _FAIRSCALE_AVAILABLE:
        raise MisconfigurationException(
            "You have asked for sharded AMP but you have not installed it."
            " Install `fairscale` using this guide: https://github.com/facebookresearch/fairscale"
        )
    super().__init__(precision, device, scaler=scaler or ShardedGradScaler())
def set_amp_args(self):
    """
    Two automatic mixed precision implementations are available: Apex's and PyTorch's.

    - If Apex's AMP is enabled, amp_args is a dictionary containing arguments to be
      passed to amp.initialize. Set to None to disable amp. To enable mixed precision
      training, pass amp_args={"opt_level": "O1"} here.
      See https://nvidia.github.io/apex/amp.html for more info.

    - If Pytorch's AMP is enabled, no arguments are needed.
    """
    if self.config.MODEL.AMP_PARAMS.USE_AMP:
        assert (
            self.device.type == "cuda"
        ), "Mixed precision is only available on CUDA devices for now"

        # This will rightly fail if the setting is not correct
        self.amp_type = AmpType[self.config.MODEL.AMP_PARAMS.AMP_TYPE.upper()]

        # Check Apex availability
        if self.amp_type == AmpType.APEX:
            if not is_apex_available():
                raise RuntimeError("Apex is not available. Can't use mixed precision")

            # "amp_args" are actually Apex Amp args
            self.amp_args = self.config.MODEL.AMP_PARAMS.AMP_ARGS
            logging.info(f"Setting AMP: using apex, args {self.amp_args}")
        elif self.amp_type == AmpType.PYTORCH:
            # if the optimizer is sharded or FSDP data parallel is used, then the GradScaler
            # needs to be shard-aware.
            if (
                self.config["TRAINER"]["TASK_NAME"] == "self_supervision_fsdp_task"
                or self.config["OPTIMIZER"]["name"] == "zero"
            ):
                assert is_fairscale_sharded_available(), (
                    "To use ZeRO with PyTorch AMP, ShardedGradScaler() "
                    "from fairscale is needed. Please upgrade fairscale"
                )
                from fairscale.optim.grad_scaler import ShardedGradScaler

                self.amp_grad_scaler = ShardedGradScaler()
                logging.info("Setting AMP: using sharded grad scaler")
            else:
                self.amp_grad_scaler = TorchGradScaler()
                logging.info("Setting AMP: using pytorch grad scaler")
        logging.info(f"Setting AMP: {self.amp_type} - args: {self.amp_args}")
    else:
        self.amp_args, self.amp_type = None, None
        logging.info("Not using Automatic Mixed Precision")
def __init__(
    self,
    address: str = None,
    port: Union[str, int] = None,
    ddp_kwargs: Dict[str, Any] = None,
    process_group_kwargs: Dict[str, Any] = None,
    scaler_kwargs: Dict[str, Any] = None,
):
    """Init."""
    super().__init__(
        address=address,
        port=port,
        ddp_kwargs=ddp_kwargs,
        process_group_kwargs=process_group_kwargs,
    )
    # @TODO: should we support scaler for each optimizer?
    if scaler_kwargs is None:
        scaler_kwargs = {}
    self.scaler_kwargs = scaler_kwargs
    self.scaler = ShardedGradScaler(**self.scaler_kwargs)
def _init_pytorch_grad_scaler(self):
    if self.config["OPTIMIZER"]["name"] == "zero":
        assert is_fairscale_sharded_available(), (
            "To use ZeRO with PyTorch AMP, ShardedGradScaler() "
            "from fairscale is needed. Please upgrade fairscale"
        )
        from fairscale.optim.grad_scaler import ShardedGradScaler

        self.amp_grad_scaler = ShardedGradScaler()
        logging.info("Setting AMP: using sharded grad scaler")
    else:
        self.amp_grad_scaler = TorchGradScaler()
        logging.info("Setting AMP: using pytorch grad scaler")
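Whichever scaler is selected above, it is consumed through the same scale/step/update sequence. Below is a minimal, self-contained sketch of that pattern; the model, optimizer, and tensors are illustrative placeholders and do not come from the snippets in this collection.

# Sketch of the scale/step/update pattern shared by torch.cuda.amp.GradScaler and
# fairscale's ShardedGradScaler. All names below are placeholders for illustration.
import torch
import torch.nn.functional as F
from fairscale.optim.grad_scaler import ShardedGradScaler


def amp_step(
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    inputs: torch.Tensor,
    targets: torch.Tensor,
    scaler: ShardedGradScaler,
) -> torch.Tensor:
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        loss = F.mse_loss(model(inputs), targets)
    scaler.scale(loss).backward()  # backward runs on the scaled loss
    scaler.step(optimizer)         # unscales gradients; skips the step if inf/nan gradients are found
    scaler.update()                # adjusts the loss scale for the next iteration
    return loss.detach()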
def set_amp_args(self):
    """
    Two automatic mixed precision implementations are available: Apex's and PyTorch's.

    - If Apex's AMP is enabled, amp_args is a dictionary containing arguments to be
      passed to amp.initialize. Set to None to disable amp. To enable mixed precision
      training, pass amp_args={"opt_level": "O1"} here.
      See https://nvidia.github.io/apex/amp.html for more info.

    - If Pytorch's AMP is enabled, no arguments are needed.
    """
    if self.config.MODEL.AMP_PARAMS.USE_AMP:
        assert (
            self.device.type == "cuda"
        ), "Mixed precision is only available on CUDA devices for now"

        # This will rightly fail if the setting is not correct
        self.amp_type = AmpType[self.config.MODEL.AMP_PARAMS.AMP_TYPE.upper()]

        # Check Apex availability
        if self.amp_type == AmpType.APEX:
            if not is_apex_available():
                raise RuntimeError("Apex is not available. Can't use mixed precision")

            # "amp_args" are actually Apex Amp args
            self.amp_args = self.config.MODEL.AMP_PARAMS.AMP_ARGS
        elif self.amp_type == AmpType.PYTORCH:
            # If the optimizer is sharded, then the GradScaler needs to be shard-aware
            self.amp_grad_scaler = (
                ShardedGradScaler()
                if self.config["OPTIMIZER"]["name"] == "zero"
                else TorchGradScaler()
            )
        logging.info(f"Setting AMP: {self.amp_type} - args: {self.amp_args}")
    else:
        self.amp_args, self.amp_type = None, None
        logging.info("Not using Automatic Mixed Precision")
def load_fp16_scaler(self):
    if self.training_config.fp16:
        assert (
            torch.__version__ >= "1.6"
        ), "Using fp16 requires torch version >= 1.6"
        assert self.device != torch.device("cpu"), "fp16 cannot be used on cpu"

    set_torch_grad_scaler = True
    if self.training_config.fp16 and self.distributed:
        try:
            from fairscale.optim.oss import OSS
            from fairscale.optim.grad_scaler import ShardedGradScaler

            if isinstance(self.optimizer, OSS):
                self.scaler = ShardedGradScaler()
                set_torch_grad_scaler = False
                logger.info("Using FairScale ShardedGradScaler")
        except ImportError:
            logger.info("Using Pytorch AMP GradScaler")

    if set_torch_grad_scaler:
        self.scaler = torch.cuda.amp.GradScaler(enabled=self.training_config.fp16)
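The `enabled=self.training_config.fp16` flag above is what lets the same training loop run with or without fp16: a disabled GradScaler turns scale/step/update into pass-throughs. A small illustrative sketch (the `use_fp16`, model, and optimizer names are placeholders, not from the snippet above):

# Sketch: with enabled=False the scaler leaves the loss and gradients unchanged,
# so identical step code works for both fp32 and fp16 runs.
import torch

use_fp16 = torch.cuda.is_available()
device = "cuda" if use_fp16 else "cpu"

scaler = torch.cuda.amp.GradScaler(enabled=use_fp16)
model = torch.nn.Linear(4, 4).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

x = torch.randn(2, 4, device=device)
with torch.cuda.amp.autocast(enabled=use_fp16):
    loss = model(x).sum()
scaler.scale(loss).backward()  # no-op scaling when enabled=False
scaler.step(optimizer)         # plain optimizer.step() when enabled=False
scaler.update()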
def check_parity(amp: bool):
    # Any model works. Add one different buffer per rank
    model = Sequential(Linear(2, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3))
    model.register_buffer("test_buffer", torch.ones((1)) * rank)
    model.to(device)

    sharded_optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
    sharded_ddp_model = ShardedDataParallel(
        module=model, sharded_optimizer=sharded_optimizer, broadcast_buffers=True
    )

    ddp_model_single = copy.deepcopy(model)
    ddp_optimizer = torch.optim.SGD(ddp_model_single.parameters(), lr=1e-3, momentum=0.99)
    ddp_model = DDP(ddp_model_single, device_ids=[rank], broadcast_buffers=True)

    ddp_scaler = TorchGradScaler() if amp else None
    sharded_ddp_scaler = ShardedGradScaler() if amp else None

    def check_same_model_params():
        for pg, ddp_pg in zip(sharded_optimizer.param_groups, ddp_optimizer.param_groups):
            for p, ddp_p in zip(pg["params"], ddp_pg["params"]):
                assert torch.allclose(
                    p, ddp_p, atol=1e-3
                ), f"Model parameters differ in between DDP and ShardedDDP {p} {ddp_p}"

        for b, ddp_b in zip(sharded_ddp_model.buffers(), ddp_model.buffers()):
            assert torch.allclose(
                b, ddp_b, atol=1e-3
            ), f"Model buffers differ in between DDP and ShardedDDP. AMP {amp}"

    # The model should be synchronized in between the ranks at construction time, check that
    check_same_model_params()

    # The models should stay the same in between the ranks
    for i in range(10):
        input_tensor = torch.rand((64, 2)).to(device)

        def closure_ddp(input_tensor=input_tensor):
            ddp_optimizer.zero_grad()

            if ddp_scaler is not None:
                with torch.cuda.amp.autocast():
                    ddp_loss = ddp_model(input_tensor).abs().sum()
                ddp_scaler.scale(ddp_loss).backward()
            else:
                ddp_loss = ddp_model(input_tensor).abs().sum()
                ddp_loss.backward()
            return ddp_loss

        def closure_sharded(input_tensor=input_tensor):
            sharded_optimizer.zero_grad()

            if sharded_ddp_scaler is not None:
                with torch.cuda.amp.autocast():
                    sharded_loss = sharded_ddp_model(input_tensor).abs().sum()
                sharded_ddp_scaler.scale(sharded_loss).backward()
            else:
                sharded_loss = sharded_ddp_model(input_tensor).abs().sum()
                sharded_loss.backward()
            return sharded_loss

        # Step/scale both
        if ddp_scaler is not None:
            _ = closure_ddp(input_tensor)
            ddp_scaler.step(ddp_optimizer)
            ddp_scaler.update()
        else:
            ddp_optimizer.step(closure=closure_ddp)

        if sharded_ddp_scaler is not None:
            _ = closure_sharded(input_tensor)
            sharded_ddp_scaler.step(sharded_optimizer)
            sharded_ddp_scaler.update()
        else:
            sharded_optimizer.step(closure=closure_sharded)

        check_same_model_params()
def train(
    rank: int,
    args: argparse.Namespace,
    backend: str = "gloo",
    optim_type: OptimType = OptimType.vanilla,
    check_regression: bool = True,
):
    logging.basicConfig(level=logging.INFO if not args.debug else logging.DEBUG)

    use_multi_tensor = args.multi_tensor_optim and hasattr(torch.optim, "_multi_tensor")
    OPTIM = torch.optim._multi_tensor.RMSprop if use_multi_tensor else torch.optim.RMSprop  # type: ignore  # attr is checked but mypy misses that
    logging.info("Multi tensor optimizer: {}".format(use_multi_tensor))

    # DDP
    dist_init(rank=rank, world_size=args.world_size, backend=backend)

    # Setup
    if not args.cpu:
        torch.cuda.set_device(rank)
        torch.cuda.manual_seed(0)
    torch.manual_seed(0)  # also sets the cuda seed
    np.random.seed(0)

    if backend == "nccl":
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    device = torch.device("cpu") if args.cpu else torch.device(rank)
    model, dataloader, loss_fn = get_problem(rank, args.world_size, args.batch_size, device, args.model)

    # Shard the optimizer
    optimizer: Optional[torch.optim.Optimizer] = None
    model = cast(nn.Module, model)

    scaler = (TorchGradScaler() if args.optim_type == OptimType.vanilla else ShardedGradScaler()) if args.amp else None

    if optim_type == OptimType.oss_sharded_ddp:
        optimizer = OSS(params=model.parameters(), optim=OPTIM, lr=1e-4, momentum=0.9)
        # Single node run typically, no need for reduce buckets
        model = ShardedDDP(model, optimizer, reduce_buffer_size=0)
    else:
        device_ids = None if args.cpu else [rank]
        model = DDP(model, device_ids=device_ids, find_unused_parameters=False)  # type: ignore
        optimizer = (
            OSS(params=model.parameters(), optim=OPTIM, lr=1e-4, momentum=0.9)
            if optim_type == OptimType.oss_ddp
            else OPTIM(model.parameters(), lr=1e-4, momentum=0.9)
        )
    optimizer = cast(torch.optim.Optimizer, optimizer)

    # Reset the memory use counter
    if not args.cpu:
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats(rank)
        torch.cuda.synchronize(rank)

    # Standard training loop
    training_start = time.monotonic()
    model.train()

    measurements = []
    final_loss: Optional[float] = -1.0
    need_profiling = args.profile

    for epoch in range(args.epochs):
        n_items = 0
        epoch_runtime = 0.0

        for batch in dataloader:
            if not args.cpu:
                torch.cuda.synchronize(rank)
            batch_start = time.monotonic()

            def closure(data=batch, grad_scaler=None):
                model.zero_grad()
                if args.debug and rank == 0 and next(model.parameters()).grad is not None:
                    logging.debug(
                        "\nbefore: param {} -- grad {}".format(
                            next(model.parameters()).norm().item(), next(model.parameters()).grad.norm().item()
                        )
                    )
                if grad_scaler is not None:
                    # Automatically computes the FW pass in half precision
                    with torch.cuda.amp.autocast():
                        outputs = model(data["inputs"])
                        loss = loss_fn(outputs, data["label"])

                        # Accumulates scaled gradients.
                        grad_scaler.scale(loss).backward()
                else:
                    outputs = model(data["inputs"])
                    loss = loss_fn(outputs, data["label"])
                    loss.backward()

                if args.debug and rank == 0 and next(model.parameters()).grad is not None:
                    logging.debug(
                        "after BW: param {} -- grad {}".format(
                            next(model.parameters()).norm().item(), next(model.parameters()).grad.norm().item()
                        )
                    )
                return loss

            def run_closure(closure, scaler, optimizer):
                if scaler is not None:
                    final_loss = closure(grad_scaler=scaler)  # AMP scaler.step does not support closures
                    scaler.step(optimizer)
                    scaler.update()
                    return final_loss
                else:
                    return optimizer.step(closure)

            if need_profiling and not args.cpu:
                logging.info("Profiling the run")
                with profiler.profile(use_cuda=True, record_shapes=True, profile_memory=True) as prof:  # type: ignore
                    with profiler.record_function("batch"):
                        final_loss = run_closure(closure, scaler, optimizer)

                prof.export_chrome_trace(f"{optim_type}_trace_rank_{rank}.json")
                need_profiling = False  # only profile once
            else:
                final_loss = run_closure(closure, scaler, optimizer)

            if args.debug and rank == 0:
                logging.debug("buffer: {}".format(next(model.buffers()).norm().item()))
                logging.debug(
                    "after update: param {} -- grad {}".format(
                        next(model.parameters()).norm().item(), next(model.parameters()).grad.norm().item()
                    )
                )

            n_items += args.batch_size

            if not args.cpu:
                # make sure that the cuda kernels are finished before taking a timestamp
                torch.cuda.synchronize(rank)

            batch_end = time.monotonic()
            epoch_runtime += batch_end - batch_start

        if optim_type == OptimType.oss_ddp or optim_type == OptimType.oss_sharded_ddp:
            # Check the checkpointing in the case of the OSS optimizer
            # Memory usage could spill over from there
            optimizer = cast(OSS, optimizer)
            optimizer.consolidate_state_dict()
            if dist.get_rank() == 0:
                _ = optimizer.state_dict()
                logging.info("... State dict collected")

        measurements.append(n_items / epoch_runtime)
        if dist.get_rank() == 0:
            logging.info(f"Epoch {epoch} - processed {measurements[-1]:.2f} img per sec. Loss {final_loss:.3f}")

    training_stop = time.monotonic()
    img_per_sec = n_items / (training_stop - training_start) * args.epochs
    logging.info(f"[{dist.get_rank()}] : Training done. {img_per_sec:.2f} img per sec inc. checkpoint")

    validate_benchmark(measurements, final_loss, args, check_regression)

    dist.destroy_process_group()  # type: ignore
class SharedDataParallelFairScaleAMPEngine(SharedDataParallelFairScaleEngine):
    """Distributed FairScale MultiGPU training device engine.

    Args:
        address: address to use for backend.
        port: port to use for backend.
        sync_bn: boolean flag for batchnorm synchronization during distributed training.
            If True, applies PyTorch `convert_sync_batchnorm`_ to the model
            for native torch distributed only. Default, False.
        ddp_kwargs: parameters for `fairscale.nn.data_parallel.ShardedDataParallel`.
            Docs for `fairscale.nn.ShardedDataParallel`:
            https://fairscale.readthedocs.io/en/latest/api/nn/sharded_ddp.html
        process_group_kwargs: parameters for `torch.distributed.init_process_group`.
            More info here:
            https://pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group
        scaler_kwargs: parameters for `fairscale.optim.grad_scaler.ShardedGradScaler`.
            Possible parameters:
            https://fairscale.readthedocs.io/en/latest/api/index.html

    Examples:

    .. code-block:: python

        from catalyst import dl

        runner = dl.SupervisedRunner()
        runner.train(
            engine=dl.SharedDataParallelFairScaleAMPEngine(),
            ...
        )

    .. code-block:: python

        from catalyst import dl

        class MyRunner(dl.IRunner):
            # ...
            def get_engine(self):
                return dl.SharedDataParallelFairScaleAMPEngine(
                    address="0.0.0.0",
                    port=23234,
                    ddp_kwargs={"find_unused_parameters": False},
                    process_group_kwargs={"port": 12345},
                    scaler_kwargs={"growth_factor": 1.5}
                )
            # ...

    .. code-block:: yaml

        args:
            logs: ...

        model:
            _target_: ...
            ...

        engine:
            _target_: SharedDataParallelFairScaleAMPEngine
            address: 0.0.0.0
            port: 23234
            ddp_kwargs:
                find_unused_parameters: false
            process_group_kwargs:
                port: 12345
            scaler_kwargs:
                growth_factor: 1.5

        stages:
            ...

    .. _convert_sync_batchnorm:
        https://pytorch.org/docs/stable/generated/torch.nn.SyncBatchNorm.html#
        torch.nn.SyncBatchNorm.convert_sync_batchnorm
    """

    def __init__(
        self,
        address: str = None,
        port: Union[str, int] = None,
        ddp_kwargs: Dict[str, Any] = None,
        process_group_kwargs: Dict[str, Any] = None,
        scaler_kwargs: Dict[str, Any] = None,
    ):
        """Init."""
        super().__init__(
            address=address,
            port=port,
            ddp_kwargs=ddp_kwargs,
            process_group_kwargs=process_group_kwargs,
        )
        # @TODO: should we support scaler for each optimizer?
        if scaler_kwargs is None:
            scaler_kwargs = {}
        self.scaler_kwargs = scaler_kwargs
        self.scaler = ShardedGradScaler(**self.scaler_kwargs)

    def zero_grad(self, loss, model, optimizer) -> None:
        """Abstraction over ``model.zero_grad()`` step."""
        optimizer.zero_grad()

    def backward_loss(self, loss, model, optimizer) -> None:
        """Abstraction over ``loss.backward()`` step."""
        self.scaler.scale(loss).backward()

    def optimizer_step(self, loss, model, optimizer) -> None:
        """Abstraction over ``optimizer.step()`` step."""
        self.scaler.step(optimizer)
        self.scaler.update()

    def autocast(self):
        """AMP context"""
        return amp.autocast()
def train(
    rank: int,
    args: argparse.Namespace,
    backend: str = "gloo",
    optim_type: OptimType = OptimType.vanilla,
    check_regression: bool = True,
):
    logging.basicConfig(level=logging.INFO if not args.debug else logging.DEBUG)

    # DDP
    dist_init(rank=rank, world_size=args.world_size, backend=backend)

    # Setup
    if not args.cpu:
        torch.cuda.set_device(rank)
        torch.cuda.manual_seed(0)
    torch.manual_seed(0)  # also sets the cuda seed
    np.random.seed(0)

    if backend == "nccl":
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    device = torch.device("cpu") if args.cpu else torch.device(rank)
    model, dataloader, loss_fn = get_problem(rank, args.world_size, args.batch_size, device, args.torchvision_model)

    # Shard the optimizer
    optimizer: Optional[torch.optim.Optimizer] = None
    model = cast(nn.Module, model)

    scaler = (TorchGradScaler() if args.optim_type == OptimType.vanilla else ShardedGradScaler()) if args.amp else None

    if optim_type == OptimType.oss_sharded_ddp:
        optimizer = OSS(params=model.parameters(), optim=OPTIM, lr=1e-4, momentum=0.9)
        model = ShardedDDP(model, optimizer)
    else:
        device_ids = None if args.cpu else [rank]
        model = DDP(model, device_ids=device_ids, find_unused_parameters=False)  # type: ignore
        optimizer = (
            OSS(params=model.parameters(), optim=OPTIM, lr=1e-4, momentum=0.9)
            if optim_type == OptimType.oss_ddp
            else OPTIM(model.parameters(), lr=1e-4, momentum=0.9)
        )
    optimizer = cast(torch.optim.Optimizer, optimizer)

    # Reset the memory use counter
    if not args.cpu:
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats(rank)
        torch.cuda.synchronize(rank)

    # Standard training loop
    training_start = time.monotonic()
    model.train()

    measurements = []
    final_loss: Optional[float] = -1.0
    need_profiling = args.profile

    for epoch in range(args.epochs):
        n_items = 0
        epoch_runtime = 0.0

        for batch in dataloader:
            if not args.cpu:
                torch.cuda.synchronize(rank)
            batch_start = time.monotonic()

            def closure(data=batch, grad_scaler=None):
                model.zero_grad()
                if args.debug and rank == 0 and next(model.parameters()).grad is not None:
                    logging.debug(
                        "\nbefore: param {} -- grad {}".format(
                            next(model.parameters()).norm().item(), next(model.parameters()).grad.norm().item()
                        )
                    )
                if grad_scaler is not None:
                    # Automatically computes the FW pass in half precision
                    with torch.cuda.amp.autocast():
                        outputs = model(data["inputs"])
                        loss = loss_fn(outputs, data["label"])

                        # Accumulates scaled gradients.
                        grad_scaler.scale(loss).backward()
                else:
                    outputs = model(data["inputs"])
                    loss = loss_fn(outputs, data["label"])
                    loss.backward()

                if args.debug and rank == 0 and next(model.parameters()).grad is not None:
                    logging.debug(
                        "after BW: param {} -- grad {}".format(
                            next(model.parameters()).norm().item(), next(model.parameters()).grad.norm().item()
                        )
                    )
                return loss

            if need_profiling and not args.cpu:
                logging.info("Profiling the run")
                with profiler.profile(use_cuda=True, record_shapes=True, profile_memory=True) as prof:  # type: ignore
                    with profiler.record_function("batch"):
                        if scaler is not None:
                            final_loss = closure(grad_scaler=scaler)  # AMP scaler.step does not support closures
                            scaler.step(optimizer)
                            scaler.update()
                        else:
                            final_loss = optimizer.step(closure)

                prof.export_chrome_trace(f"{optim_type}_trace_rank_{rank}.json")
                need_profiling = False  # only profile once
            else:
                if scaler is not None:
                    final_loss = closure(grad_scaler=scaler)  # AMP scaler.step does not support closures
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    final_loss = optimizer.step(closure)

            if args.debug and rank == 0:
                logging.debug("buffer: {}".format(next(model.buffers()).norm().item()))
                logging.debug(
                    "after update: param {} -- grad {}".format(
                        next(model.parameters()).norm().item(), next(model.parameters()).grad.norm().item()
                    )
                )

            n_items += args.batch_size

            if not args.cpu:
                # make sure that the cuda kernels are finished before taking a timestamp
                torch.cuda.synchronize(rank)

            batch_end = time.monotonic()
            epoch_runtime += batch_end - batch_start

        if optim_type == OptimType.oss_ddp or optim_type == OptimType.oss_sharded_ddp:
            # Check the checkpointing in the case of the OSS optimizer
            # Memory usage could spill over from there
            optimizer = cast(OSS, optimizer)
            optimizer.consolidate_state_dict()
            if dist.get_rank() == 0:
                _ = optimizer.state_dict()
                logging.info("... State dict collected")

        measurements.append(n_items / epoch_runtime)
        if dist.get_rank() == 0:
            logging.info(f"Epoch {epoch} - processed {measurements[-1]:.2f} img per sec. Loss {final_loss:.3f}")

    max_memory = -1.0
    if not args.cpu:
        torch.cuda.synchronize(rank)
        max_memory = torch.cuda.max_memory_allocated(rank) / 2**20
        logging.info(f"[{dist.get_rank()}] : Peak memory {max_memory:.1f}MiB")

    training_stop = time.monotonic()
    img_per_sec = n_items / (training_stop - training_start) * args.epochs
    logging.info(f"[{dist.get_rank()}] : Training done. {img_per_sec:.2f} img per sec inc. checkpoint")

    # Compute the median and median of absolute differences img per second
    measurements.sort()
    median = measurements[len(measurements) // 2]

    abs_diff = list(map(lambda x: abs(x - median), measurements))
    abs_diff.sort()
    mad = abs_diff[len(measurements) // 2] if args.epochs > 2 else -1

    logging.info(f"[{dist.get_rank()}] : Median speed: {median:.2f} +/- {mad:.2f}")

    if check_regression and dist.get_rank() == 0:
        assert (median + 3.0 * mad) > args.reference_speed, "Speed regression detected"
        assert max_memory < 1.05 * args.reference_memory, "Memory use regression detected"
        assert abs(cast(float, final_loss) - args.reference_loss) < 1e-3, "Loss regression detected"

        logging.info("[Regression Test] VALID")

    dist.destroy_process_group()  # type: ignore
def train(args, *, tbl):
    cfg, tokenizer, _, _ = nlp.models.bert.get_pretrained_bert(
        args.model_name, load_backbone=False, load_mlm=False)
    cfg = nlp.torch.models.bert.BertModel.get_cfg().clone_merge(cfg)
    model = nlp.torch.models.bert.QTBertForPretrain(cfg)
    model.to(args.device)

    if args.start_step:
        logging.info('Restart training from {}'.format(args.start_step))
        parameters_option(args.start_step, model, args, 'Loading')
    else:
        model.apply(nlp.torch.models.bert.init_weights)

    writer = None
    if args.local_rank in (-1, 0):
        writer = SummaryWriter(log_dir=os.path.join(args.ckpt_dir, 'tensorboard'))

    # pin_memory=False due to lack of
    # https://github.com/pytorch/pytorch/commit/54ce171f16c8859f829dde09f87c364c8a6b4130
    sampler = RandomSampler(tbl) if args.local_rank == -1 else DistributedSampler(tbl, seed=args.seed)
    # batch_size // 2 for QuickThought
    train_dataloader = DataLoader(np.arange(len(tbl)), sampler=sampler,
                                  collate_fn=functools.partial(collate_fn, args=args, tbl=tbl),
                                  batch_size=args.batch_size // 2,
                                  num_workers=args.num_dataloader_workers, pin_memory=True)

    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
        'weight_decay': args.weight_decay
    }, {
        'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
        'weight_decay': 0.0
    }]
    optimizer_arguments = {"lr": args.lr}
    if get_world_size(args) > 1 and args.ZeRO:
        optimizer = OSS(params=model.parameters(), optim=nlp.torch.optimizers.FusedLANS,
                        **optimizer_arguments)
        model = ShardedDataParallel(model, optimizer)
    elif get_world_size(args) > 1:
        optimizer = nlp.torch.optimizers.FusedLANS(optimizer_grouped_parameters, **optimizer_arguments)
        model = DistributedDataParallel(model, device_ids=[args.local_rank],
                                        output_device=args.local_rank, find_unused_parameters=True)
    else:
        optimizer = nlp.torch.optimizers.FusedLANS(optimizer_grouped_parameters, **optimizer_arguments)

    save_interval = args.ckpt_interval
    logging.info(f'#Total Training Steps={args.num_steps}, '
                 f'Warmup Steps={args.warmup_ratio * args.num_steps}, '
                 f'Save Interval={save_interval}')
    scheduler = nlp.torch.optimizers.schedules.get_warmup_linear_const_decay_poly_schedule(
        optimizer, total_steps=args.num_steps, warmup_ratio=args.warmup_ratio,
        const_ratio=args.const_ratio)

    if args.start_step:
        logging.info(f'Restart training from {args.start_step}')
        states_option(args.start_step, optimizer, args, 'Loading')

    ce_loss_fn = th.nn.CrossEntropyLoss()
    step_num = args.start_step
    if args.phase2:
        step_num -= args.phase1_num_steps
    running_num_tks, running_grad_norm = 0, 0
    running_mlm_loss, running_qt_loss, running_mlm_acc, running_qt_acc = 0, 0, 0, 0

    train_start_time = time.time()
    tic = time.time()
    model.zero_grad()
    if get_world_size(args) > 1 and args.ZeRO:
        scaler = ShardedGradScaler() if args.fp16 else None
    else:
        scaler = th.cuda.amp.GradScaler() if args.fp16 else None
    train_iter = repeat(train_dataloader, set_epoch=args.local_rank != -1)
    while step_num < args.num_steps:
        step_num += 1
        for accum_step in range(args.num_accumulated):
            (input_id, segment_id, valid_length, mlm_positions,
             mlm_labels) = (arr.to(args.device) for arr in next(train_iter))
            model.train()
            accumulation = ((accum_step + 1) % args.num_accumulated != 0)
            with model.no_sync() if get_world_size(args) > 1 and accumulation else suppress():
                with th.cuda.amp.autocast(enabled=args.fp16):
                    _, pooled_out, mlm_scores, qt_similarity = model(input_id, segment_id,
                                                                     valid_length, mlm_positions)
                    mlm_loss = ce_loss_fn(mlm_scores, mlm_labels)
                    qt_label = th.arange(len(input_id) // 2, device=args.device)
                    qt_loss = ce_loss_fn(qt_similarity, qt_label)
                    loss = mlm_loss + qt_loss
                    if args.num_accumulated > 1:
                        loss = loss / args.num_accumulated
                if args.fp16:
                    scaler.scale(loss).backward()
                else:
                    loss.backward()

            with th.no_grad():
                qt_acc = (qt_similarity.argmax(dim=1) == qt_label).sum() / (len(input_id) // 2)
                mlm_acc = (mlm_scores.argmax(dim=1) == mlm_labels).sum() / len(mlm_labels)

                # Gather information from all workers for accurate statistics
                reduced_num_tokens = valid_length.sum()
                if get_world_size(args) > 1:
                    distributed.all_reduce(reduced_num_tokens)
                reduced_num_mlm_tokens = th.tensor(len(mlm_labels), device=args.device)
                if get_world_size(args) > 1:
                    distributed.all_reduce(reduced_num_mlm_tokens)
                reduced_loss_mlm = mlm_loss.detach().clone() * len(mlm_labels) / reduced_num_mlm_tokens
                if get_world_size(args) > 1:
                    distributed.all_reduce(reduced_loss_mlm)
                reduced_acc_mlm = mlm_acc.detach().clone() * len(mlm_labels) / reduced_num_mlm_tokens
                if get_world_size(args) > 1:
                    distributed.all_reduce(reduced_acc_mlm)
                reduced_bs = th.tensor(len(input_id), device=args.device)
                if get_world_size(args) > 1:
                    distributed.all_reduce(reduced_bs)
                reduced_loss_qt = qt_loss.detach().clone() * len(input_id) / reduced_bs
                if get_world_size(args) > 1:
                    distributed.all_reduce(reduced_loss_qt)
                reduced_acc_qt = qt_acc.detach().clone() * len(input_id) / reduced_bs
                if get_world_size(args) > 1:
                    distributed.all_reduce(reduced_acc_qt)
                running_num_tks += reduced_num_tokens.item()
                running_mlm_loss += reduced_loss_mlm.item()
                running_mlm_acc += reduced_acc_mlm.item()
                running_qt_loss += reduced_loss_qt.item()
                running_qt_acc += reduced_acc_qt.item()

            if not accumulation:
                if args.fp16:
                    scaler.unscale_(optimizer)  # unscale for gradient clipping
                if get_world_size(args) > 1 and args.ZeRO:
                    total_norm = optimizer.clip_grad_norm(args.max_grad_norm)
                else:
                    total_norm = th.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                if get_world_size(args) > 1:
                    distributed.all_reduce(total_norm)
                    total_norm /= get_world_size(args)
                running_grad_norm += total_norm

                if args.fp16:
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    optimizer.step()
                with warnings.catch_warnings():
                    # Scheduler may warn if optimizer.step() call is skipped
                    # due to invalid gradients detected by scaler.
                    warnings.simplefilter("ignore", UserWarning)
                    scheduler.step()
                optimizer.zero_grad(set_to_none=True)

        if step_num % args.log_interval == 0:
            toc = time.time()
            wps = running_num_tks / (toc - tic)
            eta = (args.num_steps - step_num) / (step_num / (toc - train_start_time)) / 3600
            interval = args.log_interval * args.num_accumulated
            logging.info(f'[Step {step_num}], LR={scheduler.get_last_lr()[0]:.6f}, '
                         f'Loss MLM/QT={running_mlm_loss / interval:.4f}/'
                         f'{running_qt_loss / interval:.4f}, '
                         f'Acc MLM/QT={running_mlm_acc / interval:.4f}/'
                         f'{running_qt_acc / interval:.4f}, '
                         f'Grad_norm={running_grad_norm / interval:.4f}, '
                         f'Time cost={toc - tic:.2f}, '
                         f'Throughput={wps:.2f} tokens/s, ETA={eta:.2f}h')
            if args.local_rank in (-1, 0):
                writer.add_scalar('Throughput_wps', wps, step_num)
                writer.add_scalar('Loss/MLM', running_mlm_loss / interval, step_num)
                writer.add_scalar('Loss/QT', running_qt_loss / interval, step_num)
                writer.add_scalar('Acc/MLM', running_mlm_acc / interval, step_num)
                writer.add_scalar('Acc/QT', running_qt_acc / interval, step_num)
                writer.add_scalar('LR', scheduler.get_last_lr()[0], step_num)
                writer.add_scalar('Grad_norm', running_grad_norm / interval, step_num)
            running_num_tks, running_grad_norm = 0, 0
            running_mlm_loss, running_qt_loss, running_mlm_acc, running_qt_acc = 0, 0, 0, 0
            tic = time.time()

        # Saving
        if step_num % save_interval == 0 or step_num >= args.num_steps:
            states_option(step_num, optimizer, args, 'Saving')
            if args.local_rank in (0, -1):
                parameters_option(step_num, model, args, 'Saving')

    logging.info('Finish training step: %d', step_num)
    train_end_time = time.time()
    logging.info('Train cost={:.1f} s'.format(train_end_time - train_start_time))

    if args.local_rank in (0, -1):
        save_dir = os.path.join(args.ckpt_dir, args.model_name)
        final_save(model, save_dir, tokenizer.vocab, cfg)
def __init__(self, precision: Union[int, str] = 16, use_cpu: bool = False) -> None:
    super().__init__(precision, use_cpu=use_cpu)
    if not self.use_cpu:
        self.scaler = ShardedGradScaler()
def check_parity(amp: bool, accumulate: bool, change_train_graph: bool):
    # The API should be the exact same in between the sharded and non-sharded variants, generic closure
    def closure(model, scaler, input_tensor, should_accumulate):
        accumulate_steps = 3 if should_accumulate else 1
        model.zero_grad()

        def step():
            if scaler is not None:
                with torch.cuda.amp.autocast():
                    loss = model(input_tensor).abs().sum()
                scaler.scale(loss).backward()
            else:
                loss = model(input_tensor).abs().sum()
                loss.backward()

        with model.no_sync() if should_accumulate else suppress():
            for _ in range(accumulate_steps - 1):
                step()

        step()

    # Any model works. Add one different buffer per rank
    model = Sequential(Linear(INPUTS, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3))
    model.register_buffer("test_buffer", torch.ones((1)) * rank)
    model.to(device)

    # Make sure that the model starts with non-trainable, so that we check for the buckets to be
    # properly reassigned when/if this changes
    next(model.parameters()).requires_grad = False

    sharded_optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-5, momentum=0.99)
    sharded_ddp_model = ShardedDataParallel(
        module=model, sharded_optimizer=sharded_optimizer, broadcast_buffers=True
    )

    ddp_model_single = copy.deepcopy(model)
    ddp_optimizer = torch.optim.SGD(ddp_model_single.parameters(), lr=1e-5, momentum=0.99)
    ddp_model = DDP(ddp_model_single, device_ids=[rank], broadcast_buffers=True, find_unused_parameters=True)

    ddp_scaler = TorchGradScaler() if amp else None
    sharded_ddp_scaler = ShardedGradScaler() if amp else None

    # The model should be synchronized in between the ranks at construction time, check that
    check_same_model_params(sharded_ddp_model, ddp_model)

    # Typical training loop, check that we get the exact same results as DDP
    for i in range(NUMBER_BATCHS):
        input_tensor = torch.rand((BATCH_SIZE, INPUTS)).to(device)

        def closure_ddp(input_tensor=input_tensor):
            return closure(ddp_model, ddp_scaler, input_tensor, accumulate)

        def closure_sharded(input_tensor=input_tensor):
            return closure(sharded_ddp_model, sharded_ddp_scaler, input_tensor, accumulate)

        # Step/scale both
        if ddp_scaler is not None:
            _ = closure_ddp(input_tensor)
            ddp_scaler.step(ddp_optimizer)
            ddp_scaler.update()
        else:
            ddp_optimizer.step(closure=closure_ddp)

        if sharded_ddp_scaler is not None:
            _ = closure_sharded(input_tensor)
            sharded_ddp_scaler.step(sharded_optimizer)
            sharded_ddp_scaler.update()
        else:
            sharded_optimizer.step(closure=closure_sharded)

        check_same_model_params(sharded_ddp_model, ddp_model, f"Rank: {rank} - Step {i} broke")

        # Flip the trainability of the first parameter back and forth
        if i == 0 and change_train_graph:
            next(sharded_ddp_model.parameters()).requires_grad = not next(
                sharded_ddp_model.parameters()
            ).requires_grad
            next(ddp_model.parameters()).requires_grad = not next(ddp_model.parameters()).requires_grad
            check_same_model_params(sharded_ddp_model, ddp_model, f"Rank: {rank} - Trainability refresh {i} broke")
def run_ddp_parity(
    rank,
    world_size,
    backend,
    temp_file_name,
    reduce_buffer_size,
    grad_accumulation,
    change_train_graph,
    fp16_reduction,
    clip_grad_norm,
    amp,
    manual_reduction,
    multiple_fw,
):
    dist.init_process_group(init_method="file://" + temp_file_name, backend=backend, rank=rank, world_size=world_size)
    device = torch.device("cuda")
    torch.cuda.set_device(rank)
    torch.manual_seed(rank)
    np.random.seed(rank)
    NUMBER_BATCHS = 5

    # Test all combinations: AMP, Accumulate, Change train graph, reduce buckets
    print(
        f"{rank}: Checking configuration: accumulate {grad_accumulation}"
        + f" - change train graph {change_train_graph}"
        + f" - amp {amp}"
        + f" - manual reduction {manual_reduction}"
        + f" - buffers {reduce_buffer_size}"
        + f" - multiple FW {multiple_fw}",
        flush=True,
    )

    # The API should be the exact same in between the sharded and non-sharded variants, generic closure
    def closure(model, scaler, input_tensor, should_accumulate, _manual_reduction=False):
        accumulate_steps = 3 if should_accumulate else 1
        model.zero_grad()

        def step():
            if scaler is not None:
                with torch.cuda.amp.autocast():
                    loss = model(input_tensor).abs().sum()
                scaler.scale(loss).backward()
            else:
                loss = model(input_tensor).abs().sum()
                loss.backward()

        with model.no_sync() if should_accumulate else suppress():
            for _ in range(accumulate_steps - 1):
                step()

        if not _manual_reduction:
            step()
        else:
            with model.no_sync():
                step()
            model.reduce()

    # Any model works. Add one different buffer per rank
    model = _get_mlp_emb(multiple_fw)
    model.register_buffer("test_buffer", torch.ones((1)) * rank)
    model.to(device)

    # Make sure that the model starts with non-trainable, so that we check for the buckets to be
    # properly reassigned when/if this changes
    next(model.parameters()).requires_grad = False

    sharded_optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-4, momentum=0.99)
    sharded_ddp_model = ShardedDataParallel(
        module=model,
        sharded_optimizer=sharded_optimizer,
        broadcast_buffers=True,
        reduce_buffer_size=reduce_buffer_size,
        reduce_fp16=fp16_reduction,
    )

    ddp_model_single = copy.deepcopy(model)
    ddp_optimizer = torch.optim.SGD(ddp_model_single.parameters(), lr=1e-4, momentum=0.99)
    ddp_model = DDP(ddp_model_single, device_ids=[rank], broadcast_buffers=True, find_unused_parameters=True)

    if fp16_reduction:
        from dist.algorithms.ddp_com_hooks.default_hooks import fp16_compress_hook

        ddp_model.register_comm_hook(state=None, hook=fp16_compress_hook)  # type: ignore

    ddp_scaler = TorchGradScaler() if amp else None
    sharded_scaler = ShardedGradScaler() if amp else None

    # The model should be synchronized in between the ranks at construction time, check that
    check_same_model_params(sharded_ddp_model, ddp_model)

    # Typical training loop, check that we get the exact same results as DDP
    for i in range(NUMBER_BATCHS):
        input_tensor = _get_random_inputs(device)

        def ddp_closure(input_tensor=input_tensor):
            return closure(ddp_model, ddp_scaler, input_tensor, grad_accumulation)

        def sharded_closure(input_tensor=input_tensor):
            return closure(
                sharded_ddp_model,
                sharded_scaler,
                input_tensor,
                grad_accumulation,
                _manual_reduction=manual_reduction,
            )

        # Step/scale both
        for _scaler, _closure, _optimizer in (
            (ddp_scaler, ddp_closure, ddp_optimizer),
            (sharded_scaler, sharded_closure, sharded_optimizer),
        ):
            if _scaler is not None:
                _ = _closure(input_tensor)
                _scaler.step(_optimizer)
                _scaler.update()
            else:
                _optimizer.step(_closure())

        check_same_model_params(sharded_ddp_model, ddp_model, f"Rank: {rank} - Step {i} broke")

        # Check that the two grad norms are equivalent
        # NOTE: The grads can occasionally be NaNs, the scaler will skip the step in that case
        # This is not ShardedDDP specific. If the grads are not NaN for DDP then they should also
        # be valid for ShardedDDP
        # NOTE: DDP does not handle parameters trainability being changed after the fact, see
        # https://github.com/pytorch/pytorch/blob/5781aec74ef00284e0262817a649278c2e8072bf/torch/nn/parallel/distributed.py#L471
        if clip_grad_norm and not change_train_graph:
            if torch_version() >= (1, 9, 0):
                total_norm = torch.nn.utils.clip_grad_norm_(
                    ddp_model.parameters(), 0.3, norm_type=2.0, error_if_nonfinite=False
                )  # type: ignore
            else:
                total_norm = torch.nn.utils.clip_grad_norm_(ddp_model.parameters(), 0.3, norm_type=2.0)  # type: ignore
            if not torch.isnan(total_norm):
                oss_total_norm = sharded_optimizer.clip_grad_norm(0.3, norm_type=2.0)
                allclose = torch.allclose(oss_total_norm, total_norm, atol=1e-2 if amp else 1e-8)

                if not allclose:
                    # Debug helper if this unit test does not pass, compare the gradients in between DDP and ShardedDDP
                    for idx, (p_ddp, p_sdp) in enumerate(zip(ddp_model.parameters(), sharded_ddp_model.parameters())):
                        if p_ddp.grad is not None:
                            if p_sdp.grad is not None:
                                print(rank, idx, torch.norm(p_ddp.grad), torch.norm(p_sdp.grad), flush=True)
                            else:
                                print(rank, idx, torch.norm(p_ddp.grad), "not owned", flush=True)

                assert (
                    allclose
                ), f"torch and fairscale should return the same grad norm\n {oss_total_norm} vs {total_norm}"
            else:
                print(rank, "NaN grad norm in DDP", flush=True)

        # Flip the trainability of the first parameter back and forth
        if i == 0 and change_train_graph:
            next(sharded_ddp_model.parameters()).requires_grad = not next(sharded_ddp_model.parameters()).requires_grad
            next(ddp_model.parameters()).requires_grad = not next(ddp_model.parameters()).requires_grad
            check_same_model_params(sharded_ddp_model, ddp_model, f"Rank: {rank} - Trainability refresh {i} broke")

    dist.destroy_process_group()
def __init__(self) -> None:
    super().__init__()
    self.scaler = ShardedGradScaler()
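A scaler created this way also carries state (the current loss scale and growth tracking) that is usually checkpointed with the model. Since fairscale's ShardedGradScaler builds on torch.cuda.amp.GradScaler, it exposes the same state_dict/load_state_dict API; the sketch below illustrates that pattern, with the checkpoint layout and file name chosen as placeholders rather than taken from the plugin above.

# Sketch: saving and restoring the scaler state alongside a checkpoint.
# A distributed process group is assumed for actual sharded training.
import torch
from fairscale.optim.grad_scaler import ShardedGradScaler

scaler = ShardedGradScaler()

# ... training steps using scaler.scale(loss).backward(), scaler.step(optimizer), scaler.update() ...

checkpoint = {"scaler": scaler.state_dict()}
torch.save(checkpoint, "checkpoint.pt")

# Later, when resuming:
restored = torch.load("checkpoint.pt")
scaler.load_state_dict(restored["scaler"])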
def set_amp_args(self, amp_args: Optional[Dict[str, Any]]):
    """Disable / enable apex.amp and set the automatic mixed precision parameters.

    apex.amp can be utilized for mixed / half precision training.

    Args:
        amp_args: Dictionary containing arguments to be passed to amp.initialize.
            Set to None to disable amp. To enable mixed precision training,
            pass amp_args={"opt_level": "O1"} here.
            See https://nvidia.github.io/apex/amp.html for more info.

    Raises:
        RuntimeError: If opt_level is not None and apex is not installed.

    Warning: apex needs to be installed to utilize this feature.
    """
    self.amp_args = amp_args

    if amp_args is None:
        logging.info("AMP disabled")
    else:
        # Check that the requested AMP type is known
        try:
            self.amp_type = AmpType[self.amp_args["amp_type"].upper()]
        except KeyError:
            logging.info("AMP type not specified, defaulting to Apex")
            self.amp_type = AmpType.APEX

        # Check for CUDA availability, required for both Apex and Pytorch AMP
        if not torch.cuda.is_available():
            raise RuntimeError("AMP is required but CUDA is not supported, cannot enable AMP")

        # Check for Apex availability
        if self.amp_type == AmpType.APEX and not apex_available:
            raise RuntimeError("Apex AMP is required but Apex is not installed, cannot enable AMP")

        if self.use_sharded_ddp:
            if self.amp_type == AmpType.APEX:
                raise RuntimeError("ShardedDDP has been requested, which is incompatible with Apex AMP")

            if not fairscale_available:
                raise RuntimeError(
                    "ShardedDDP has been requested, but fairscale is not installed in the current environment"
                )

        # Set Torch AMP grad scaler, used to prevent gradient underflow
        elif self.amp_type == AmpType.PYTORCH:
            if self.use_sharded_ddp:
                logging.info("Using ShardedGradScaler to manage Pytorch AMP")
                self.amp_grad_scaler = ShardedGradScaler()
            else:
                self.amp_grad_scaler = TorchGradScaler()

        logging.info(f"AMP enabled with args {amp_args}")
    return self
def scaler(self):
    return ShardedGradScaler()
def check_parity(manual_reduction: bool):
    # The API should be the exact same in between the sharded and non-sharded variants, generic closure
    def closure(model, scaler, input_tensor, should_accumulate, _manual_reduction=False):
        accumulate_steps = 3 if should_accumulate else 1
        model.zero_grad()

        def step():
            if scaler is not None:
                with torch.cuda.amp.autocast():
                    loss = model(input_tensor).abs().sum()
                scaler.scale(loss).backward()
            else:
                loss = model(input_tensor).abs().sum()
                loss.backward()

        with model.no_sync() if should_accumulate else suppress():
            for _ in range(accumulate_steps - 1):
                step()

        if not _manual_reduction:
            step()
        else:
            with model.no_sync():
                step()
            model.reduce()

    # Any model works. Add one different buffer per rank
    model = _get_mlp()
    model.register_buffer("test_buffer", torch.ones((1)) * rank)
    model.to(device)

    # Make sure that the model starts with non-trainable, so that we check for the buckets to be
    # properly reassigned when/if this changes
    next(model.parameters()).requires_grad = False

    sharded_optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-4, momentum=0.99)
    sharded_ddp_model = ShardedDataParallel(
        module=model,
        sharded_optimizer=sharded_optimizer,
        broadcast_buffers=True,
        reduce_buffer_size=reduce_buffer_size,
        reduce_fp16=fp16_reduction,
    )

    ddp_model_single = copy.deepcopy(model)
    ddp_optimizer = torch.optim.SGD(ddp_model_single.parameters(), lr=1e-4, momentum=0.99)
    ddp_model = DDP(ddp_model_single, device_ids=[rank], broadcast_buffers=True, find_unused_parameters=True)

    if fp16_reduction:
        from dist.algorithms.ddp_com_hooks.default_hooks import fp16_compress_hook

        ddp_model.register_comm_hook(state=None, hook=fp16_compress_hook)  # type: ignore

    ddp_scaler = TorchGradScaler() if amp else None
    sharded_scaler = ShardedGradScaler() if amp else None

    # The model should be synchronized in between the ranks at construction time, check that
    check_same_model_params(sharded_ddp_model, ddp_model)

    # Typical training loop, check that we get the exact same results as DDP
    for i in range(NUMBER_BATCHS):
        input_tensor = torch.rand((BATCH_SIZE, 2)).to(device)

        def ddp_closure(input_tensor=input_tensor):
            return closure(ddp_model, ddp_scaler, input_tensor, grad_accumulation)

        def sharded_closure(input_tensor=input_tensor):
            return closure(
                sharded_ddp_model,
                sharded_scaler,
                input_tensor,
                grad_accumulation,
                _manual_reduction=manual_reduction,
            )

        # Step/scale both
        for _scaler, _closure, _optimizer in (
            (ddp_scaler, ddp_closure, ddp_optimizer),
            (sharded_scaler, sharded_closure, sharded_optimizer),
        ):
            if _scaler is not None:
                _ = _closure(input_tensor)
                _scaler.step(_optimizer)
                _scaler.update()
            else:
                _optimizer.step(_closure)

        check_same_model_params(sharded_ddp_model, ddp_model, f"Rank: {rank} - Step {i} broke")

        # Check that the two grad norms are equivalent
        # NOTE: The grads can occasionally be NaNs, the scaler will skip the step in that case
        # This is not ShardedDDP specific. If the grads are not NaN for DDP then they should also
        # be valid for ShardedDDP
        if clip_grad_norm:
            total_norm = torch.nn.utils.clip_grad_norm_(ddp_model.parameters(), 0.3, norm_type=2.0)  # type: ignore
            if not torch.isnan(total_norm):
                oss_total_norm = sharded_optimizer.clip_grad_norm(0.3, norm_type=2.0)
                assert torch.allclose(
                    oss_total_norm, total_norm, atol=1e-2 if amp else 1e-8
                ), f"torch and fairscale should return the same grad norm\n {oss_total_norm} vs {total_norm}"
            else:
                print(rank, "NaN grad norm in DDP", flush=True)

        # Flip the trainability of the first parameter back and forth
        if i == 0 and change_train_graph:
            next(sharded_ddp_model.parameters()).requires_grad = not next(sharded_ddp_model.parameters()).requires_grad
            next(ddp_model.parameters()).requires_grad = not next(ddp_model.parameters()).requires_grad
            check_same_model_params(sharded_ddp_model, ddp_model, f"Rank: {rank} - Trainability refresh {i} broke")
def _distributed_worker(
    rank,
    world_size,
    fsdp_config,
    fsdp_wrap_bn,
    ddp_mixed_precision,
    tempfile_name,
    unused,
    state_before,
    inputs,
    rank_0_output,
    state_after,
    sync_bn,
    conv_bias,
    linear_bias,
):
    torch.backends.cudnn.deterministic = True

    result = dist_init(rank, world_size, tempfile_name, unused)
    assert result, "Dist init failed"

    ddp = True
    if fsdp_config:
        ddp = False
        assert isinstance(fsdp_config, dict), str(fsdp_config)
        if fsdp_config["mixed_precision"]:
            # To match DDP in AMP -O1, we need fp32 reduce scatter.
            fsdp_config["fp32_reduce_scatter"] = True

    model = Model(conv_bias, linear_bias)
    model.load_state_dict(state_before)
    model = model.cuda()

    class DummyScaler:
        def scale(self, loss):
            return loss

        def step(self, optim):
            optim.step()

        def update(self):
            pass

    scaler = DummyScaler()
    if ddp:
        if sync_bn == "pytorch":
            model = pytorch_bn_converter(model)
        model = DDP(model, device_ids=[rank], broadcast_buffers=True)
        if ddp_mixed_precision:
            scaler = GradScaler()
    else:
        # Note, different rank may wrap in different order due to different random
        # seeds. But results should be the same.
        if random.randint(0, 1) == 0:
            print(f"auto_wrap_bn {fsdp_wrap_bn}, then sync_bn {sync_bn}")
            if fsdp_wrap_bn:
                model = auto_wrap_bn(model, _single_rank_pg)
            if sync_bn == "pytorch":
                model = pytorch_bn_converter(model)
        else:
            print(f"sync_bn {sync_bn}, then auto_wrap_bn {fsdp_wrap_bn}")
            if sync_bn == "pytorch":
                model = pytorch_bn_converter(model)
            if fsdp_wrap_bn:
                model = auto_wrap_bn(model, _single_rank_pg)
        model = FSDP(model, **fsdp_config).cuda()
        if fsdp_config["mixed_precision"]:
            scaler = ShardedGradScaler()

    # Print the model for verification.
    if rank == 0:
        print(model)

    optim = SGD(model.parameters(), lr=0.1)
    loss_func = CrossEntropyLoss()

    for in_data in inputs[rank]:
        in_data = in_data.cuda()
        context = contextlib.suppress()
        if ddp and ddp_mixed_precision:
            in_data = in_data.half()
            context = torch.cuda.amp.autocast(enabled=True)
        if not ddp and fsdp_config["mixed_precision"]:
            context = torch.cuda.amp.autocast(enabled=True)
        with context:
            out = model(in_data)
            fake_label = torch.zeros(1, dtype=torch.long).cuda()
            loss = loss_func(out.unsqueeze(0), fake_label)
        scaler.scale(loss).backward()
        scaler.step(optim)
        scaler.update()
        optim.zero_grad()

    if ddp:
        # Save the rank 0 state_dict to the output file.
        if rank == 0:
            state_after = model.module.cpu().state_dict()
            torch.save(state_after, rank_0_output)
    else:
        model.assert_state(TrainingState.IDLE)

        # Ensure final state equals to the state_after.
        fsdp_state = model.state_dict()
        # Move tensors to CPU to compare numerics.
        for k, v in fsdp_state.items():
            fsdp_state[k] = v.cpu()
        # Change False to True to enable this when you want to debug the mismatch.
        if False and rank == 0:

            def dump(d):
                for k, v in d.items():
                    print(k, v)

            dump(state_after)
            dump(fsdp_state)
        # If sync_bn is used, all ranks should have the same state, so we can compare with
        # rank 0 state on every rank. Otherwise, only compare rank 0 with rank 0.
        if sync_bn != "none" or rank == 0:
            assert objects_are_equal(state_after, fsdp_state, raise_exception=True)

    teardown()