def train(
    rank: int,
    args: argparse.Namespace,
    backend: str = "gloo",
    optim_type: OptimType = OptimType.vanilla,
    check_regression: bool = True,
):
    logging.basicConfig(level=logging.INFO if not args.debug else logging.DEBUG)

    use_multi_tensor = args.multi_tensor_optim and hasattr(torch.optim, "_multi_tensor")
    OPTIM = torch.optim._multi_tensor.RMSprop if use_multi_tensor else torch.optim.RMSprop  # type: ignore  # attr is checked but mypy misses that
    logging.info("Multi tensor optimizer: {}".format(use_multi_tensor))

    # DDP
    dist_init(rank=rank, world_size=args.world_size, backend=backend)

    # Setup
    if not args.cpu:
        torch.cuda.set_device(rank)
        torch.cuda.manual_seed(0)
    torch.manual_seed(0)  # also sets the cuda seed
    np.random.seed(0)

    if backend == "nccl":
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    device = torch.device("cpu") if args.cpu else torch.device(rank)
    model, dataloader, loss_fn = get_problem(rank, args.world_size, args.batch_size, device, args.model)

    # Shard the optimizer
    optimizer: Optional[torch.optim.Optimizer] = None
    model = cast(nn.Module, model)

    scaler = (TorchGradScaler() if args.optim_type == OptimType.vanilla else ShardedGradScaler()) if args.amp else None

    if optim_type == OptimType.oss_sharded_ddp:
        optimizer = OSS(params=model.parameters(), optim=OPTIM, lr=1e-4, momentum=0.9)
        # Single node run typically, no need for reduce buckets
        model = ShardedDDP(model, optimizer, reduce_buffer_size=0)
    else:
        device_ids = None if args.cpu else [rank]
        model = DDP(model, device_ids=device_ids, find_unused_parameters=False)  # type: ignore
        optimizer = (
            OSS(params=model.parameters(), optim=OPTIM, lr=1e-4, momentum=0.9)
            if optim_type == OptimType.oss_ddp
            else OPTIM(model.parameters(), lr=1e-4, momentum=0.9)
        )
    optimizer = cast(torch.optim.Optimizer, optimizer)

    # Reset the memory use counter
    if not args.cpu:
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats(rank)
        torch.cuda.synchronize(rank)

    # Standard training loop
    training_start = time.monotonic()
    model.train()

    measurements = []
    final_loss: Optional[float] = -1.0
    need_profiling = args.profile

    for epoch in range(args.epochs):
        n_items = 0
        epoch_runtime = 0.0

        for batch in dataloader:
            if not args.cpu:
                torch.cuda.synchronize(rank)
            batch_start = time.monotonic()

            def closure(data=batch, grad_scaler=None):
                model.zero_grad()
                if args.debug and rank == 0 and next(model.parameters()).grad is not None:
                    logging.debug(
                        "\nbefore: param {} -- grad {}".format(
                            next(model.parameters()).norm().item(), next(model.parameters()).grad.norm().item()
                        )
                    )
                if grad_scaler is not None:
                    # Automatically computes the FW pass in half precision
                    with torch.cuda.amp.autocast():
                        outputs = model(data["inputs"])
                        loss = loss_fn(outputs, data["label"])

                        # Accumulates scaled gradients.
                        grad_scaler.scale(loss).backward()
                else:
                    outputs = model(data["inputs"])
                    loss = loss_fn(outputs, data["label"])
                    loss.backward()

                if args.debug and rank == 0 and next(model.parameters()).grad is not None:
                    logging.debug(
                        "after BW: param {} -- grad {}".format(
                            next(model.parameters()).norm().item(), next(model.parameters()).grad.norm().item()
                        )
                    )
                return loss

            def run_closure(closure, scaler, optimizer):
                if scaler is not None:
                    final_loss = closure(grad_scaler=scaler)  # AMP scaler.step does not support closures
                    scaler.step(optimizer)
                    scaler.update()
                    return final_loss
                else:
                    return optimizer.step(closure)

            if need_profiling and not args.cpu:
                logging.info("Profiling the run")
                with profiler.profile(use_cuda=True, record_shapes=True, profile_memory=True) as prof:  # type: ignore
                    with profiler.record_function("batch"):
                        final_loss = run_closure(closure, scaler, optimizer)

                prof.export_chrome_trace(f"{optim_type}_trace_rank_{rank}.json")
                need_profiling = False  # only profile once
            else:
                final_loss = run_closure(closure, scaler, optimizer)

            if args.debug and rank == 0:
                logging.debug("buffer: {}".format(next(model.buffers()).norm().item()))
                logging.debug(
                    "after update: param {} -- grad {}".format(
                        next(model.parameters()).norm().item(), next(model.parameters()).grad.norm().item()
                    )
                )

            n_items += args.batch_size

            if not args.cpu:
                # make sure that the cuda kernels are finished before taking a timestamp
                torch.cuda.synchronize(rank)
            batch_end = time.monotonic()
            epoch_runtime += batch_end - batch_start

        if optim_type == OptimType.oss_ddp or optim_type == OptimType.oss_sharded_ddp:
            # Check the checkpointing in the case of the OSS optimizer
            # Memory usage could spill over from there
            optimizer = cast(OSS, optimizer)
            optimizer.consolidate_state_dict()
            if dist.get_rank() == 0:
                _ = optimizer.state_dict()
                logging.info("... State dict collected")

        measurements.append(n_items / epoch_runtime)
        if dist.get_rank() == 0:
            logging.info(f"Epoch {epoch} - processed {measurements[-1]:.2f} img per sec. Loss {final_loss:.3f}")

    training_stop = time.monotonic()
    img_per_sec = n_items / (training_stop - training_start) * args.epochs
    logging.info(f"[{dist.get_rank()}] : Training done. {img_per_sec:.2f} img per sec inc. checkpoint")

    validate_benchmark(measurements, final_loss, args, check_regression)

    dist.destroy_process_group()  # type: ignore
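
# The benchmark above assumes a per-process entry point that is not shown in this
# snippet. Below is a minimal launcher sketch using the usual torch.multiprocessing
# pattern. The flag names mirror the `args.*` attributes referenced above, but the
# actual argparse setup of the benchmark is not included here, so treat parser
# defaults and the spawn call as illustrative assumptions. Flags used only by the
# regression check (reference_speed, reference_memory, reference_loss) and the
# optim_type selection are deliberately omitted.
import argparse

import torch.multiprocessing as mp


def _parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser("oss benchmark launcher (sketch)")
    parser.add_argument("--world_size", type=int, default=2)
    parser.add_argument("--epochs", type=int, default=10)
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--model", type=str, default="resnet101")
    parser.add_argument("--cpu", action="store_true")
    parser.add_argument("--debug", action="store_true")
    parser.add_argument("--amp", action="store_true")
    parser.add_argument("--profile", action="store_true")
    parser.add_argument("--multi_tensor_optim", action="store_true")
    return parser.parse_args()


if __name__ == "__main__":
    args = _parse_args()
    # One process per rank; each process runs the train() defined above.
    # check_regression=False because the reference_* flags are not wired up in this sketch.
    mp.spawn(
        train,
        args=(args, "gloo" if args.cpu else "nccl", OptimType.oss_ddp, False),
        nprocs=args.world_size,
        join=True,
    )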
def train(
    rank: int,
    args: argparse.Namespace,
    backend: str = "gloo",
    optim_type: OptimType = OptimType.vanilla,
    check_regression: bool = True,
):
    logging.basicConfig(level=logging.INFO if not args.debug else logging.DEBUG)

    # DDP
    dist_init(rank=rank, world_size=args.world_size, backend=backend)

    # Setup
    if not args.cpu:
        torch.cuda.set_device(rank)
        torch.cuda.manual_seed(0)
    torch.manual_seed(0)  # also sets the cuda seed
    np.random.seed(0)

    if backend == "nccl":
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    device = torch.device("cpu") if args.cpu else torch.device(rank)
    model, dataloader, loss_fn = get_problem(rank, args.world_size, args.batch_size, device, args.torchvision_model)

    # Shard the optimizer
    optimizer: Optional[torch.optim.Optimizer] = None
    model = cast(nn.Module, model)

    scaler = (TorchGradScaler() if args.optim_type == OptimType.vanilla else ShardedGradScaler()) if args.amp else None

    if optim_type == OptimType.oss_sharded_ddp:
        optimizer = OSS(params=model.parameters(), optim=OPTIM, lr=1e-4, momentum=0.9)
        model = ShardedDDP(model, optimizer)
    else:
        device_ids = None if args.cpu else [rank]
        model = DDP(model, device_ids=device_ids, find_unused_parameters=False)  # type: ignore
        optimizer = (
            OSS(params=model.parameters(), optim=OPTIM, lr=1e-4, momentum=0.9)
            if optim_type == OptimType.oss_ddp
            else OPTIM(model.parameters(), lr=1e-4, momentum=0.9)
        )
    optimizer = cast(torch.optim.Optimizer, optimizer)

    # Reset the memory use counter
    if not args.cpu:
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats(rank)
        torch.cuda.synchronize(rank)

    # Standard training loop
    training_start = time.monotonic()
    model.train()

    measurements = []
    final_loss: Optional[float] = -1.0
    need_profiling = args.profile

    for epoch in range(args.epochs):
        n_items = 0
        epoch_runtime = 0.0

        for batch in dataloader:
            if not args.cpu:
                torch.cuda.synchronize(rank)
            batch_start = time.monotonic()

            def closure(data=batch, grad_scaler=None):
                model.zero_grad()
                if args.debug and rank == 0 and next(model.parameters()).grad is not None:
                    logging.debug(
                        "\nbefore: param {} -- grad {}".format(
                            next(model.parameters()).norm().item(), next(model.parameters()).grad.norm().item()
                        )
                    )
                if grad_scaler is not None:
                    # Automatically computes the FW pass in half precision
                    with torch.cuda.amp.autocast():
                        outputs = model(data["inputs"])
                        loss = loss_fn(outputs, data["label"])

                        # Accumulates scaled gradients.
                        grad_scaler.scale(loss).backward()
                else:
                    outputs = model(data["inputs"])
                    loss = loss_fn(outputs, data["label"])
                    loss.backward()

                if args.debug and rank == 0 and next(model.parameters()).grad is not None:
                    logging.debug(
                        "after BW: param {} -- grad {}".format(
                            next(model.parameters()).norm().item(), next(model.parameters()).grad.norm().item()
                        )
                    )
                return loss

            if need_profiling and not args.cpu:
                logging.info("Profiling the run")
                with profiler.profile(use_cuda=True, record_shapes=True, profile_memory=True) as prof:  # type: ignore
                    with profiler.record_function("batch"):
                        if scaler is not None:
                            final_loss = closure(grad_scaler=scaler)  # AMP scaler.step does not support closures
                            scaler.step(optimizer)
                            scaler.update()
                        else:
                            final_loss = optimizer.step(closure)

                prof.export_chrome_trace(f"{optim_type}_trace_rank_{rank}.json")
                need_profiling = False  # only profile once
            else:
                if scaler is not None:
                    final_loss = closure(grad_scaler=scaler)  # AMP scaler.step does not support closures
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    final_loss = optimizer.step(closure)

            if args.debug and rank == 0:
                logging.debug("buffer: {}".format(next(model.buffers()).norm().item()))
                logging.debug(
                    "after update: param {} -- grad {}".format(
                        next(model.parameters()).norm().item(), next(model.parameters()).grad.norm().item()
                    )
                )

            n_items += args.batch_size

            if not args.cpu:
                # make sure that the cuda kernels are finished before taking a timestamp
                torch.cuda.synchronize(rank)
            batch_end = time.monotonic()
            epoch_runtime += batch_end - batch_start

        if optim_type == OptimType.oss_ddp or optim_type == OptimType.oss_sharded_ddp:
            # Check the checkpointing in the case of the OSS optimizer
            # Memory usage could spill over from there
            optimizer = cast(OSS, optimizer)
            optimizer.consolidate_state_dict()
            if dist.get_rank() == 0:
                _ = optimizer.state_dict()
                logging.info("... State dict collected")

        measurements.append(n_items / epoch_runtime)
        if dist.get_rank() == 0:
            logging.info(f"Epoch {epoch} - processed {measurements[-1]:.2f} img per sec. Loss {final_loss:.3f}")

    max_memory = -1.0
    if not args.cpu:
        torch.cuda.synchronize(rank)
        max_memory = torch.cuda.max_memory_allocated(rank) / 2 ** 20
        logging.info(f"[{dist.get_rank()}] : Peak memory {max_memory:.1f}MiB")

    training_stop = time.monotonic()
    img_per_sec = n_items / (training_stop - training_start) * args.epochs
    logging.info(f"[{dist.get_rank()}] : Training done. {img_per_sec:.2f} img per sec inc. checkpoint")

    # Compute the median and median absolute deviation of the measured img per second
    measurements.sort()
    median = measurements[len(measurements) // 2]
    abs_diff = list(map(lambda x: abs(x - median), measurements))
    abs_diff.sort()
    mad = abs_diff[len(measurements) // 2] if args.epochs > 2 else -1
    logging.info(f"[{dist.get_rank()}] : Median speed: {median:.2f} +/- {mad:.2f}")

    if check_regression and dist.get_rank() == 0:
        assert (median + 3.0 * mad) > args.reference_speed, "Speed regression detected"
        assert max_memory < 1.05 * args.reference_memory, "Memory use regression detected"
        assert abs(cast(float, final_loss) - args.reference_loss) < 1e-3, "Loss regression detected"

        logging.info("[Regression Test] VALID")

    dist.destroy_process_group()  # type: ignore
def train(rank: int, args, world_size: int, epochs: int):
    # DDP init example
    dist_init(rank, world_size)

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Setup
    if not args.cpu:
        torch.cuda.set_device(rank)
        torch.cuda.manual_seed(0)
    torch.manual_seed(0)  # also sets the cuda seed
    np.random.seed(0)

    # Problem statement
    model = NeuralNet(input_size=784, hidden_size=500, num_classes=10).to(rank)

    if args.use_ortmodule:
        print("Converting to ORTModule....")
        model = ORTModule(model)

    train_dataloader, test_dataloader = get_dataloader(args, rank, args.batch_size)
    loss_fn = my_loss
    base_optimizer = torch.optim.SGD  # pick any pytorch compliant optimizer here
    base_optimizer_arguments = {}  # pass any optimizer-specific arguments here, or directly below when instantiating OSS

    if args.use_sharded_optimizer:
        # Wrap the optimizer in its state sharding brethren
        optimizer = OSS(params=model.parameters(), optim=base_optimizer, lr=args.lr)

        # Wrap the model into ShardedDDP, which will reduce gradients to the proper ranks
        model = ShardedDDP(model, optimizer)
    else:
        device_ids = None if args.cpu else [rank]
        model = DDP(model, device_ids=device_ids, find_unused_parameters=False)  # type: ignore
        optimizer = torch.optim.SGD(model.parameters(), lr=args.lr)

    # Any relevant training loop, nothing specific to OSS. For example:
    model.train()
    total_training_time, total_test_time, epoch_0_training, validation_accuracy = 0, 0, 0, 0
    for epoch in range(epochs):
        total_training_time += train_step(args, model, rank, optimizer, loss_fn, train_dataloader, epoch)
        if epoch == 0:
            epoch_0_training = total_training_time
        if args.test_batch_size > 0:
            test_time, validation_accuracy = test(args, model, rank, loss_fn, test_dataloader)
            total_test_time += test_time

    print('\n======== Global stats ========')
    if args.use_ortmodule:
        estimated_export = 0
        if args.epochs > 1:
            estimated_export = epoch_0_training - (total_training_time - epoch_0_training) / (args.epochs - 1)
            print(" Estimated ONNX export took: {:.4f}s".format(estimated_export))
        else:
            print(" Estimated ONNX export took: Estimate available when epochs > 1 only")
        print(" Accumulated training without export took: {:.4f}s".format(total_training_time - estimated_export))
    print(" Accumulated training took: {:.4f}s".format(total_training_time))
    print(" Accumulated validation took: {:.4f}s".format(total_test_time))

    dist.destroy_process_group()
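
# dist_init() is called above (and in the other snippets) but is not defined here.
# Below is a minimal sketch of such a helper, assuming a single-node setup with
# environment-variable rendezvous; the real helper may differ (backend choice,
# master address, timeout handling, etc.).
import os

import torch.distributed as dist


def dist_init(rank: int, world_size: int, backend: str = "gloo") -> None:
    # Default rendezvous point for a single-machine run; override via env vars if needed.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    # Gloo works on CPU and GPU; NCCL is the usual choice for multi-GPU runs.
    dist.init_process_group(backend=backend, rank=rank, world_size=world_size)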
def train(
    rank: int,
    args: argparse.Namespace,
    backend: str = "gloo",
    optim_type: OptimType = OptimType.vanilla,
    check_regression: bool = True,
):
    logging.basicConfig(level=logging.INFO if not args.debug else logging.DEBUG)

    # DDP
    dist_init(rank=rank, world_size=args.world_size, backend=backend)

    # Setup
    if not args.cpu:
        torch.cuda.set_device(rank)
        torch.cuda.manual_seed(0)
    torch.manual_seed(0)  # also sets the cuda seed
    np.random.seed(0)

    if backend == "nccl":
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    device = torch.device("cpu") if args.cpu else torch.device(rank)
    model, dataloader, loss_fn = get_problem(rank, args.world_size, args.batch_size, device, args.torchvision_model)

    # Shard the optimizer
    optimizer: Optional[torch.optim.Optimizer] = None
    model = cast(nn.Module, model)

    if optim_type == OptimType.oss_sharded_ddp:
        model = ShardedDDP(
            model,
            optimizer=OPTIM,
            optimizer_params={"lr": 1e-4, "momentum": 0.9},
            world_size=args.world_size,
            broadcast_buffers=True,
        )
        optimizer = model.sharded_optimizer
    else:
        device_ids = None if args.cpu else [rank]
        model = DDP(model, device_ids=device_ids, find_unused_parameters=False)  # type: ignore
        optimizer = (
            OSS(params=model.parameters(), optim=OPTIM, lr=1e-4, momentum=0.9)
            if optim_type == OptimType.oss_ddp
            else OPTIM(model.parameters(), lr=1e-4, momentum=0.9)
        )
    optimizer = cast(torch.optim.Optimizer, optimizer)

    # Reset the memory use counter
    if not args.cpu:
        torch.cuda.reset_peak_memory_stats(rank)
        torch.cuda.synchronize(rank)

    # Standard training loop
    training_start = time.monotonic()
    model.train()

    measurements = []
    final_loss: Optional[float] = -1.0
    need_profiling = args.profile

    for epoch in range(args.epochs):
        n_items = 0
        epoch_runtime = 0.0

        for batch in dataloader:
            batch_start = time.monotonic()

            def closure():
                model.zero_grad()
                if args.debug and rank == 0 and next(model.parameters()).grad is not None:
                    logging.debug(
                        "\nbefore: param {} -- grad {}".format(
                            next(model.parameters()).norm().item(), next(model.parameters()).grad.norm().item()
                        )
                    )
                outputs = model(batch["inputs"])
                loss = loss_fn(outputs, batch["label"])
                loss.backward()

                if optim_type == OptimType.oss_sharded_ddp:
                    model.reduce()

                if args.debug and rank == 0 and next(model.parameters()).grad is not None:
                    logging.debug(
                        "after BW: param {} -- grad {}".format(
                            next(model.parameters()).norm().item(), next(model.parameters()).grad.norm().item()
                        )
                    )
                return loss

            if need_profiling and not args.cpu:
                logging.info("Profiling the run")
                with profiler.profile(use_cuda=True, record_shapes=True, profile_memory=True) as prof:  # type: ignore
                    with profiler.record_function("batch"):
                        final_loss = optimizer.step(closure)

                logging.info("profiling done")

                if rank == 0:
                    prof.export_chrome_trace(f"{optim_type}_trace.json")

                need_profiling = False  # only profile once
            else:
                final_loss = optimizer.step(closure)

            if args.debug and rank == 0:
                logging.debug("buffer: {}".format(next(model.buffers()).norm().item()))
                logging.debug(
                    "after update: param {} -- grad {}".format(
                        next(model.parameters()).norm().item(), next(model.parameters()).grad.norm().item()
                    )
                )

            n_items += args.batch_size

            batch_end = time.monotonic()
            epoch_runtime += batch_end - batch_start

        if optim_type == OptimType.oss_ddp or optim_type == OptimType.oss_sharded_ddp:
            # Check the checkpointing in the case of the OSS optimizer
            # Memory usage could spill over from there
            optimizer = cast(OSS, optimizer)
            optimizer.consolidate_state_dict()
            if dist.get_rank() == 0:
                _ = optimizer.state_dict()
                logging.info("... State dict collected")

        measurements.append(n_items / epoch_runtime)
        if dist.get_rank() == 0:
            logging.info(f"Epoch {epoch} - processed {measurements[-1]:.2f} img per sec. Loss {final_loss:.3f}")

    max_memory = -1.0
    if not args.cpu:
        torch.cuda.synchronize(rank)
        max_memory = torch.cuda.max_memory_allocated(rank) / 2 ** 20
        logging.info(f"[{dist.get_rank()}] : Peak memory {max_memory:.1f}MiB")

    training_stop = time.monotonic()
    img_per_sec = n_items / (training_stop - training_start) * args.epochs
    logging.info(f"[{dist.get_rank()}] : Training done. {img_per_sec:.2f} img per sec inc. checkpoint")

    # Compute the mean and standard deviation of the measured img per second
    mean = sum(measurements) / len(measurements)
    diff = map(lambda x: pow(x - mean, 2.0), measurements)
    std = math.sqrt(sum(diff) / (len(measurements) - 1)) if args.epochs > 2 else -1
    logging.info(f"[{dist.get_rank()}] : Mean speed: {mean:.2f} +/- {std:.2f}")

    if check_regression and dist.get_rank() == 0:
        assert (mean + 3.0 * std) > args.reference_speed, "Speed regression detected"
        assert max_memory < 1.05 * args.reference_memory, "Memory use regression detected"
        assert abs(cast(float, final_loss) - args.reference_loss) < 1e-3, "Loss regression detected"

        logging.info("[Regression Test] VALID")

    dist.destroy_process_group()  # type: ignore