def train(rank: int, world_size: int, epochs: int, use_oss: bool):
    # DDP
    dist_init(rank, world_size)

    # Problem statement
    model = getModel().to(rank)
    dataloader = getData()
    loss_fn = getLossFun()

    optimizer: Optional[Union[OSS, torch.optim.SGD]] = None

    if not use_oss:
        optimizer = torch.optim.SGD(params=model.parameters(), lr=1e-4)
    else:
        base_optimizer = torch.optim.SGD
        base_optimizer_arguments = {"lr": 1e-4}  # any optimizer specific arguments, LR, momentum, etc...
        optimizer = OSS(params=model.parameters(), optim=base_optimizer, **base_optimizer_arguments)

    training_start = time.monotonic()

    # Any relevant training loop, nothing specific to OSS. For example:
    model.train()
    for _ in range(epochs):
        for (data, target) in dataloader:
            data, target = data.to(rank), target.to(rank)

            # Train
            model.zero_grad()
            outputs = model(data)
            loss = loss_fn(outputs, target)
            loss.backward()

            # if you want to clip the gradients / get the current max:
            max_norm = 1000.0
            norm_type = 1
            if not use_oss:
                _total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm, norm_type=norm_type)  # type: ignore
            else:
                optimizer = cast(OSS, optimizer)
                _total_norm = optimizer.clip_grad_norm(max_norm, norm_type=norm_type)

            optimizer.step()
            print(f"Loss: {loss.item()}")

    training_end = time.monotonic()
    max_memory = torch.cuda.max_memory_allocated(rank) / 2**20

    print(f"[{dist.get_rank()}] : Training done. {training_end - training_start:.2f} sec")
    print(f"[{dist.get_rank()}] : Peak memory {max_memory:.1f}MiB")
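A minimal, hypothetical launcher sketch for the train() function above, assuming dist_init simply wraps torch.distributed.init_process_group; the rendezvous address, port, and the epochs/use_oss values are illustrative and not taken from the original example.

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def dist_init(rank: int, world_size: int, backend: str = "nccl") -> None:
    # Hypothetical helper: the original snippet only calls dist_init(rank, world_size)
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group(backend=backend, rank=rank, world_size=world_size)


if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    # mp.spawn calls train(rank, world_size, epochs, use_oss) once per GPU
    mp.spawn(train, args=(world_size, 10, True), nprocs=world_size, join=True)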
class ZeRO(ClassyOptimizer):
    def __init__(self, base_optimizer: ClassyOptimizer):
        """Wraps an arbitrary :class:`ClassyOptimizer <classy_vision.optim.ClassyOptimizer>`
        optimizer and shards its state as described by ZeRO_.

        ::

            opt = ZeRO(base_optimizer=my_classy_optimizer)

        .. _ZeRO: https://arxiv.org/abs/1910.02054

        This instance holds all of the parameters of the model (in the .param_groups
        attribute) but relies on the wrapped optimizer, which only processes a shard of
        the parameters. After every step, all the parameters are synced across replicas.

        The FairScale library is used: https://github.com/facebookresearch/fairscale
        """
        assert fairscale_available, "The Fairscale library needs to be installed to use this optimizer."

        super().__init__()
        self.base_optimizer = base_optimizer

    def prepare(self, param_groups) -> None:
        assert dist.is_initialized(), "torch.distributed needs to be initialized to prepare this rank"

        def optimizer_constructor(param_groups: Any, *args, **kwargs):
            # ClassyOptimizers have deferred initialization, while OSS needs access to the
            # raw optimizer instance, hence the trampoline
            logging.debug("Building a ZeRO enabled optimizer")
            self.base_optimizer.prepare(param_groups)
            return self.base_optimizer.optimizer

        self.optimizer = OSS(params=param_groups, optim=optimizer_constructor)

    @classmethod
    def from_config(cls, config):
        return cls(base_optimizer=build_optimizer(config["base_optimizer"]))

    def on_epoch(self, where: float) -> None:
        # Run the normal LR schedulers
        super().on_epoch(where)

        # Materialize the optimizer state on the replica in charge of checkpointing
        logging.info("Consolidating sharded state on primary rank. Where: %.2f" % where)
        self.consolidate_state_dict()

    def consolidate_state_dict(self) -> None:
        self.optimizer.consolidate_state_dict(recipient_rank=get_primary_rank())
def make_adam(params):
    if args.ddp_zero:
        return OSS(params=params, optim=Adam, group=get_data_parallel_group(), lr=lr)
    else:
        return Adam(params, lr=lr)
def make_adam(model):
    if args.ddp_zero:
        return OSS(params=model.parameters(), optim=Adam, group=get_data_parallel_group(), lr=lr)
    else:
        return Adam(model.parameters(), lr=lr)
def train(rank: int, world_size: int, epochs: int, use_oss: bool):
    # DDP
    dist_init(rank, world_size)

    # Problem statement
    model = getModel().to(rank)
    dataloader = getData()
    loss_fn = getLossFun()

    base_optimizer_arguments = {"lr": 1e-4}  # any optimizer specific arguments, LR, momentum, etc...
    if not use_oss:
        optimizer = torch.optim.SGD(params=model.parameters(), **base_optimizer_arguments)
    else:
        base_optimizer = torch.optim.SGD
        optimizer = OSS(params=model.parameters(), optim=base_optimizer, **base_optimizer_arguments)

    training_start = time.monotonic()

    # Any relevant training loop, nothing specific to OSS. For example:
    model.train()
    for _ in range(epochs):
        for (data, target) in dataloader:
            data, target = data.to(rank), target.to(rank)

            # Train
            model.zero_grad()
            outputs = model(data)
            loss = loss_fn(outputs, target)
            loss /= world_size
            loss.backward()
            optimizer.step()
            print(f"Loss: {loss.item()}")

    training_end = time.monotonic()
    max_memory = torch.cuda.max_memory_allocated(rank) / 2**20

    print(f"[{dist.get_rank()}] : Training done. {training_end - training_start:.2f} sec")
    print(f"[{dist.get_rank()}] : Peak memory {max_memory:.1f}MiB")
def build_optimizer(model, config):
    optimizer_config = config.optimizer
    if "type" not in optimizer_config:
        raise ValueError(
            "Optimizer attributes must have a 'type' key "
            "specifying the type of optimizer. "
            "(Custom or PyTorch, e.g. 'adam_w' or 'SGD')"
        )
    optimizer_type = optimizer_config.type

    if "params" not in optimizer_config:
        warnings.warn("Optimizer attributes have no params defined, defaulting to {}.")

    params = optimizer_config.get("params", {})

    if hasattr(torch.optim, optimizer_type):
        optimizer_class = getattr(torch.optim, optimizer_type)
    else:
        optimizer_class = registry.get_optimizer_class(optimizer_type)
        if optimizer_class is None:
            raise ValueError(
                "No optimizer class of type {} present in "
                "either torch or registered to registry".format(optimizer_type)
            )

    parameters = get_optimizer_parameters(model, config)

    if optimizer_config.get("enable_state_sharding", False):
        # TODO(vedanuj): Remove once OSS is moved to PT upstream
        try:
            from fairscale.optim.oss import OSS
        except ImportError:
            print(
                "Optimizer state sharding requires fairscale. "
                + "Install using pip install fairscale."
            )
            raise

        assert is_dist_initialized(), "Optimizer state sharding can only be used in distributed mode."

        is_fp16 = config.get("training", {}).get("fp16", False)
        optimizer = OSS(params=parameters, optim=optimizer_class, broadcast_fp16=is_fp16, **params)
    else:
        optimizer = optimizer_class(parameters, **params)
    return optimizer
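For reference, a hypothetical configuration that would route build_optimizer() above through the fairscale OSS branch. OmegaConf is used purely to reproduce the attribute-style access (config.optimizer, optimizer_config.type) the function relies on; every field value is illustrative rather than taken from the source.

from omegaconf import OmegaConf

config = OmegaConf.create({
    "optimizer": {
        "type": "AdamW",                 # resolved via getattr(torch.optim, "AdamW")
        "params": {"lr": 5e-5, "eps": 1e-8},
        "enable_state_sharding": True,   # wrap the optimizer in fairscale's OSS
    },
    "training": {"fp16": True},          # forwarded to OSS as broadcast_fp16
})
# optimizer = build_optimizer(model, config)  # requires torch.distributed to be initialized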
def train(rank: int, args, world_size: int, epochs: int):
    # DDP init example
    dist_init(rank, world_size)

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Setup
    if not args.cpu:
        torch.cuda.set_device(rank)
        torch.cuda.manual_seed(0)
    torch.manual_seed(0)  # also sets the cuda seed
    np.random.seed(0)

    # Problem statement
    model = NeuralNet(input_size=784, hidden_size=500, num_classes=10).to(rank)

    if args.use_ortmodule:
        print("Converting to ORTModule....")
        model = ORTModule(model)

    train_dataloader, test_dataloader = get_dataloader(args, rank, args.batch_size)
    loss_fn = my_loss
    base_optimizer = torch.optim.SGD  # pick any pytorch compliant optimizer here
    base_optimizer_arguments = {}  # pass any optimizer specific arguments here, or directly below when instantiating OSS
    if args.use_sharded_optimizer:
        # Wrap the optimizer in its state sharding brethren
        optimizer = OSS(params=model.parameters(), optim=base_optimizer, lr=args.lr)

        # Wrap the model into ShardedDDP, which will reduce gradients to the proper ranks
        model = ShardedDDP(model, optimizer)
    else:
        device_ids = None if args.cpu else [rank]
        model = DDP(model, device_ids=device_ids, find_unused_parameters=False)  # type: ignore
        optimizer = torch.optim.SGD(model.parameters(), lr=args.lr)

    # Any relevant training loop, nothing specific to OSS. For example:
    model.train()
    total_training_time, total_test_time, epoch_0_training, validation_accuracy = 0, 0, 0, 0
    for epoch in range(epochs):
        total_training_time += train_step(args, model, rank, optimizer, loss_fn, train_dataloader, epoch)
        if epoch == 0:
            epoch_0_training = total_training_time
        if args.test_batch_size > 0:
            test_time, validation_accuracy = test(args, model, rank, loss_fn, test_dataloader)
            total_test_time += test_time

    print('\n======== Global stats ========')
    if args.use_ortmodule:
        estimated_export = 0
        if args.epochs > 1:
            estimated_export = epoch_0_training - (total_training_time - epoch_0_training) / (args.epochs - 1)
            print(" Estimated ONNX export took: {:.4f}s".format(estimated_export))
        else:
            print(" Estimated ONNX export took: Estimate available when epochs > 1 only")
        print(" Accumulated training without export took: {:.4f}s".format(total_training_time - estimated_export))
    print(" Accumulated training took: {:.4f}s".format(total_training_time))
    print(" Accumulated validation took: {:.4f}s".format(total_test_time))

    dist.destroy_process_group()
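One thing the script above does not show is checkpointing when use_sharded_optimizer is set. A hedged sketch of what could be added inside train() just before dist.destroy_process_group(): with OSS each rank only holds its own shard, so the full optimizer state has to be consolidated on one rank before saving. The output path is hypothetical.

    if args.use_sharded_optimizer:
        # Gather all optimizer shards onto rank 0 (fairscale OSS API)
        optimizer.consolidate_state_dict(recipient_rank=0)
        if dist.get_rank() == 0:
            torch.save(optimizer.state_dict(), "oss_optimizer.pt")  # hypothetical path
        dist.barrier()  # keep the other ranks alive until rank 0 has saved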
def train(
    rank: int,
    world_size: int,
    num_epochs: int = 10,
    batch_size: int = 32,
    data_size: int = 200,
    backend: str = "gloo",
    use_oss: bool = True,
    use_sdp: bool = False,
    check_regression: bool = True,
    reference_speed: float = -1.0,
    reference_memory: float = -1.0,
    reference_loss: float = -1.0,
):
    assert not use_sdp or (use_sdp and use_oss), "ShardedDataParallel requires OSS"
    # DDP
    dist_init(rank=rank, world_size=world_size, backend=backend)

    # Setup
    torch.cuda.set_device(rank)
    torch.cuda.manual_seed(0)
    torch.manual_seed(0)  # also sets the cuda seed
    np.random.seed(0)

    if backend == "nccl":
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    model, dataloader, loss_fn = get_problem(rank, data_size, batch_size)

    # Shard the optimizer
    optimizer: Optional[torch.optim.Optimizer] = None

    if use_sdp:
        ddp = ShardedDataParallel(
            module=model,
            optimizer=OPTIM,
            optimizer_params={"lr": 1e-4, "momentum": 0.9},
            world_size=world_size,
            broadcast_buffers=False,
        )
        ddp.train()
        optimizer = ddp.optimizer
        model = ddp
    else:
        optimizer = (
            OSS(params=model.parameters(), optim=OPTIM, lr=1e-4, momentum=0.9)
            if use_oss
            else OPTIM(model.parameters(), lr=1e-4, momentum=0.9)
        )

    # Reset the memory use counter
    torch.cuda.reset_peak_memory_stats(rank)

    # Dummy training loop
    torch.cuda.synchronize(rank)
    training_start = time.monotonic()
    model.train()

    measurements = []
    final_loss: Optional[float] = -1.0

    for epoch in range(num_epochs):
        epoch_start = time.monotonic()

        for batch in dataloader:

            def closure():
                model.zero_grad()
                outputs = model(batch["inputs"])
                loss = loss_fn(outputs, batch["label"])
                loss /= world_size
                loss.backward()

                dist.all_reduce(loss, op=dist.ReduceOp.SUM)

                if use_sdp:
                    ddp.reduce()  # Send the gradients to the appropriate shards

                return loss

            final_loss = optimizer.step(closure)

        epoch_end = time.monotonic()

        if use_oss:
            # Check the checkpointing in the case of the OSS optimizer
            # Memory usage could spill over from there
            optimizer = cast(OSS, optimizer)
            optimizer.consolidate_state_dict()
            if dist.get_rank() == 0:
                _ = optimizer.state_dict()
                print("... State dict collected")

        measurements.append(data_size / (epoch_end - epoch_start))
        if dist.get_rank() == 0:
            print(f"Epoch {epoch} - processed {measurements[-1]:.2f} img per sec. Loss {final_loss:.3f}")

    torch.cuda.synchronize(rank)
    training_stop = time.monotonic()
    img_per_sec = data_size / (training_stop - training_start) * num_epochs
    max_memory = torch.cuda.max_memory_allocated(rank) / 2**20

    print(f"[{dist.get_rank()}] : Training done. {img_per_sec:.2f} img per sec overall")
    print(f"[{dist.get_rank()}] : Peak memory {max_memory:.1f}MiB")

    # Compute the mean and standard deviation of the per-epoch throughput
    mean = sum(measurements) / len(measurements)
    diff = map(lambda x: pow(x - mean, 2.0), measurements)
    std = math.sqrt(sum(diff) / (len(measurements) - 1))
    print(f"[{dist.get_rank()}] : Mean speed: {mean:.2f} +/- {std:.2f}")

    if use_oss and check_regression and dist.get_rank() == 0:
        assert (mean + 3.0 * std) > reference_speed, "Speed regression detected"
        assert max_memory < 1.05 * reference_memory, "Memory use regression detected"
        assert abs(cast(float, final_loss) - reference_loss) < 1e-3, "Loss regression detected"
        print("[Regression Test] VALID")
def train(
    rank: int,
    world_size: int,
    num_epochs: int = 10,
    batch_size: int = 32,
    data_size: int = 200,
    use_oss: bool = True,
    check_regression: bool = True,
    reference_speed: float = -1.0,
    reference_memory: float = -1.0,
):
    # DDP
    dist_init(rank, world_size)

    # Standard RN101
    model = resnet101(pretrained=False, progress=True).to(rank)

    # Data setup, dummy data
    def collate(inputs: List[Any]):
        return {
            "inputs": torch.stack([i[0] for i in inputs]).to(torch.device(rank)),
            "label": torch.stack([i[1] for i in inputs]).to(torch.device(rank)),
        }

    dataloader = DataLoader(
        dataset=FakeData(transform=ToTensor(), size=data_size),
        batch_size=batch_size,
        collate_fn=collate,
    )
    loss_fn = nn.CrossEntropyLoss()

    # Reset the memory use counter
    torch.cuda.reset_peak_memory_stats(rank)

    # Shard the optimizer
    optimizer: Union[OSS, OPTIM] = (
        OSS(params=model.parameters(), optim=OPTIM, lr=1e-4, momentum=0.9)
        if use_oss
        else OPTIM(model.parameters(), lr=1e-4, momentum=0.9)
    )

    # Dummy training loop
    torch.cuda.synchronize(rank)
    training_start = time.monotonic()
    model.train()

    measurements = []

    for epoch in range(num_epochs):
        epoch_start = time.monotonic()

        for batch in dataloader:

            def closure():
                model.zero_grad()
                outputs = model(batch["inputs"])
                loss = loss_fn(outputs, batch["label"])
                dist.all_reduce(loss, op=dist.ReduceOp.SUM)
                loss /= world_size
                loss.backward()
                return loss

            optimizer.step(closure)

        epoch_end = time.monotonic()

        if use_oss:
            # Check the checkpointing in the case of the OSS optimizer
            # Memory usage could spill over from there
            optimizer = cast(OSS, optimizer)
            # optimizer.consolidate_state_dict()
            if dist.get_rank() == 0:
                # _ = optimizer.state_dict()
                print("... State dict collected")

        measurements.append(data_size / (epoch_end - epoch_start))
        if dist.get_rank() == 0:
            print(f"Epoch {epoch} - processed {measurements[-1]:.2f} img per sec")

    torch.cuda.synchronize(rank)
    training_stop = time.monotonic()
    img_per_sec = data_size / (training_stop - training_start) * num_epochs
    max_memory = torch.cuda.max_memory_allocated(rank) / 2**20

    print(f"[{dist.get_rank()}] : Training done. {img_per_sec:.2f} img per sec overall")
    print(f"[{dist.get_rank()}] : Peak memory {max_memory:.1f}MiB")

    # Compute the mean and standard deviation of the per-epoch throughput
    mean = sum(measurements) / len(measurements)
    diff = map(lambda x: pow(x - mean, 2.0), measurements)
    std = math.sqrt(sum(diff) / (len(measurements) - 1))
    print(f"[{dist.get_rank()}] : Mean speed: {mean:.2f} +/- {std:.2f}")

    if use_oss and check_regression and dist.get_rank() == 0:
        assert (mean + 3.0 * std) > reference_speed, "Speed regression detected"
        assert max_memory < 1.05 * reference_memory, "Memory use regression detected"
        print("[Regression Test] VALID")
def train(
    rank: int,
    world_size: int,
    num_epochs: int = 10,
    batch_size: int = 32,
    data_size: int = 200,
    backend: str = "gloo",
    use_oss: bool = True,
    check_regression: bool = True,
    reference_speed: float = -1.0,
    reference_memory: float = -1.0,
):
    # DDP
    dist_init(rank, world_size, backend)

    # Setup
    model, dataloader, loss_fn = get_problem(rank, data_size, batch_size)

    # Shard the optimizer
    optimizer: torch.optim.Optimizer = (
        OSS(params=model.parameters(), optim=OPTIM, lr=1e-4, momentum=0.9)
        if use_oss
        else OPTIM(model.parameters(), lr=1e-4, momentum=0.9)
    )

    # Reset the memory use counter
    torch.cuda.reset_peak_memory_stats(rank)

    # Dummy training loop
    torch.cuda.synchronize(rank)
    training_start = time.monotonic()
    model.train()

    measurements = []
    final_loss: Optional[float] = -1.0

    for epoch in range(num_epochs):
        epoch_start = time.monotonic()

        for batch in dataloader:

            def closure():
                model.zero_grad()
                outputs = model(batch["inputs"])
                loss = loss_fn(outputs, batch["label"])
                loss /= world_size
                loss.backward()

                dist.all_reduce(loss, op=dist.ReduceOp.SUM)

                return loss

            final_loss = optimizer.step(closure)

        epoch_end = time.monotonic()

        if use_oss:
            # Check the checkpointing in the case of the OSS optimizer
            # Memory usage could spill over from there
            optimizer = cast(OSS, optimizer)
            optimizer.consolidate_state_dict()
            if dist.get_rank() == 0:
                _ = optimizer.state_dict()
                print("... State dict collected")

        measurements.append(data_size / (epoch_end - epoch_start))
        if dist.get_rank() == 0:
            print(f"Epoch {epoch} - processed {measurements[-1]:.2f} img per sec. Loss {final_loss}")

    torch.cuda.synchronize(rank)
    training_stop = time.monotonic()
    img_per_sec = data_size / (training_stop - training_start) * num_epochs
    max_memory = torch.cuda.max_memory_allocated(rank) / 2**20

    print(f"[{dist.get_rank()}] : Training done. {img_per_sec:.2f} img per sec overall")
    print(f"[{dist.get_rank()}] : Peak memory {max_memory:.1f}MiB")

    # Compute the mean and standard deviation of the per-epoch throughput
    mean = sum(measurements) / len(measurements)
    diff = map(lambda x: pow(x - mean, 2.0), measurements)
    std = math.sqrt(sum(diff) / (len(measurements) - 1))
    print(f"[{dist.get_rank()}] : Mean speed: {mean:.2f} +/- {std:.2f}")

    if use_oss and check_regression and dist.get_rank() == 0:
        assert (mean + 3.0 * std) > reference_speed, "Speed regression detected"
        assert max_memory < 1.05 * reference_memory, "Memory use regression detected"
        print("[Regression Test] VALID")
def train(args, *, tbl):
    cfg, tokenizer, _, _ = nlp.models.bert.get_pretrained_bert(
        args.model_name, load_backbone=False, load_mlm=False)
    cfg = nlp.torch.models.bert.BertModel.get_cfg().clone_merge(cfg)
    model = nlp.torch.models.bert.QTBertForPretrain(cfg)
    model.to(args.device)

    if args.start_step:
        logging.info('Restart training from {}'.format(args.start_step))
        parameters_option(args.start_step, model, args, 'Loading')
    else:
        model.apply(nlp.torch.models.bert.init_weights)

    writer = None
    if args.local_rank in (-1, 0):
        writer = SummaryWriter(log_dir=os.path.join(args.ckpt_dir, 'tensorboard'))

    # pin_memory=False due to lack of
    # https://github.com/pytorch/pytorch/commit/54ce171f16c8859f829dde09f87c364c8a6b4130
    sampler = RandomSampler(tbl) if args.local_rank == -1 else DistributedSampler(tbl, seed=args.seed)
    # batch_size // 2 for QuickThought
    train_dataloader = DataLoader(np.arange(len(tbl)),
                                  sampler=sampler,
                                  collate_fn=functools.partial(collate_fn, args=args, tbl=tbl),
                                  batch_size=args.batch_size // 2,
                                  num_workers=args.num_dataloader_workers,
                                  pin_memory=True)

    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
        'weight_decay': args.weight_decay
    }, {
        'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
        'weight_decay': 0.0
    }]
    optimizer_arguments = {"lr": args.lr}
    if get_world_size(args) > 1 and args.ZeRO:
        optimizer = OSS(params=model.parameters(), optim=nlp.torch.optimizers.FusedLANS,
                        **optimizer_arguments)
        model = ShardedDataParallel(model, optimizer)
    elif get_world_size(args) > 1:
        optimizer = nlp.torch.optimizers.FusedLANS(optimizer_grouped_parameters,
                                                   **optimizer_arguments)
        model = DistributedDataParallel(model, device_ids=[args.local_rank],
                                        output_device=args.local_rank,
                                        find_unused_parameters=True)
    else:
        optimizer = nlp.torch.optimizers.FusedLANS(optimizer_grouped_parameters,
                                                   **optimizer_arguments)

    save_interval = args.ckpt_interval
    logging.info(f'#Total Training Steps={args.num_steps}, '
                 f'Warmup Steps={args.warmup_ratio * args.num_steps}, '
                 f'Save Interval={save_interval}')
    scheduler = nlp.torch.optimizers.schedules.get_warmup_linear_const_decay_poly_schedule(
        optimizer, total_steps=args.num_steps, warmup_ratio=args.warmup_ratio,
        const_ratio=args.const_ratio)

    if args.start_step:
        logging.info(f'Restart training from {args.start_step}')
        states_option(args.start_step, optimizer, args, 'Loading')

    ce_loss_fn = th.nn.CrossEntropyLoss()
    step_num = args.start_step
    if args.phase2:
        step_num -= args.phase1_num_steps
    running_num_tks, running_grad_norm = 0, 0
    running_mlm_loss, running_qt_loss, running_mlm_acc, running_qt_acc = 0, 0, 0, 0

    train_start_time = time.time()
    tic = time.time()
    model.zero_grad()
    if get_world_size(args) > 1 and args.ZeRO:
        scaler = ShardedGradScaler() if args.fp16 else None
    else:
        scaler = th.cuda.amp.GradScaler() if args.fp16 else None
    train_iter = repeat(train_dataloader, set_epoch=args.local_rank != -1)
    while step_num < args.num_steps:
        step_num += 1
        for accum_step in range(args.num_accumulated):
            # Fetch one batch and move it to the device
            (input_id, segment_id, valid_length, mlm_positions, mlm_labels) = \
                (arr.to(args.device) for arr in next(train_iter))
            model.train()
            accumulation = ((accum_step + 1) % args.num_accumulated != 0)
            with model.no_sync() if get_world_size(args) > 1 and accumulation else suppress():
                with th.cuda.amp.autocast(enabled=args.fp16):
                    _, pooled_out, mlm_scores, qt_similarity = model(
                        input_id, segment_id, valid_length, mlm_positions)
                    mlm_loss = ce_loss_fn(mlm_scores, mlm_labels)
                    qt_label = th.arange(len(input_id) // 2, device=args.device)
                    qt_loss = ce_loss_fn(qt_similarity, qt_label)
                    loss = mlm_loss + qt_loss
                    if args.num_accumulated > 1:
                        loss = loss / args.num_accumulated
                    if args.fp16:
                        scaler.scale(loss).backward()
                    else:
                        loss.backward()

            with th.no_grad():
                qt_acc = (qt_similarity.argmax(dim=1) == qt_label).sum() / (len(input_id) // 2)
                mlm_acc = (mlm_scores.argmax(dim=1) == mlm_labels).sum() / len(mlm_labels)
                # Gather information from all workers for accurate statistics
                reduced_num_tokens = valid_length.sum()
                if get_world_size(args) > 1:
                    distributed.all_reduce(reduced_num_tokens)
                reduced_num_mlm_tokens = th.tensor(len(mlm_labels), device=args.device)
                if get_world_size(args) > 1:
                    distributed.all_reduce(reduced_num_mlm_tokens)
                reduced_loss_mlm = mlm_loss.detach().clone() * len(mlm_labels) / reduced_num_mlm_tokens
                if get_world_size(args) > 1:
                    distributed.all_reduce(reduced_loss_mlm)
                reduced_acc_mlm = mlm_acc.detach().clone() * len(mlm_labels) / reduced_num_mlm_tokens
                if get_world_size(args) > 1:
                    distributed.all_reduce(reduced_acc_mlm)
                reduced_bs = th.tensor(len(input_id), device=args.device)
                if get_world_size(args) > 1:
                    distributed.all_reduce(reduced_bs)
                reduced_loss_qt = qt_loss.detach().clone() * len(input_id) / reduced_bs
                if get_world_size(args) > 1:
                    distributed.all_reduce(reduced_loss_qt)
                reduced_acc_qt = qt_acc.detach().clone() * len(input_id) / reduced_bs
                if get_world_size(args) > 1:
                    distributed.all_reduce(reduced_acc_qt)
                running_num_tks += reduced_num_tokens.item()
                running_mlm_loss += reduced_loss_mlm.item()
                running_mlm_acc += reduced_acc_mlm.item()
                running_qt_loss += reduced_loss_qt.item()
                running_qt_acc += reduced_acc_qt.item()

            if not accumulation:
                if args.fp16:
                    scaler.unscale_(optimizer)  # unscale for gradient clipping
                if get_world_size(args) > 1 and args.ZeRO:
                    total_norm = optimizer.clip_grad_norm(args.max_grad_norm)
                else:
                    total_norm = th.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                    if get_world_size(args) > 1:
                        distributed.all_reduce(total_norm)
                        total_norm /= get_world_size(args)
                running_grad_norm += total_norm
                if args.fp16:
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    optimizer.step()
                with warnings.catch_warnings():
                    # Scheduler may warn if optimizer.step() call is skipped
                    # due to invalid gradients detected by scaler.
                    warnings.simplefilter("ignore", UserWarning)
                    scheduler.step()
                optimizer.zero_grad(set_to_none=True)

        if step_num % args.log_interval == 0:
            toc = time.time()
            wps = running_num_tks / (toc - tic)
            eta = (args.num_steps - step_num) / (step_num / (toc - train_start_time)) / 3600
            interval = args.log_interval * args.num_accumulated
            logging.info(f'[Step {step_num}], LR={scheduler.get_last_lr()[0]:.6f}, '
                         f'Loss MLM/QT={running_mlm_loss / interval:.4f}/'
                         f'{running_qt_loss / interval:.4f}, '
                         f'Acc MLM/QT={running_mlm_acc / interval:.4f}/'
                         f'{running_qt_acc / interval:.4f}, '
                         f'Grad_norm={running_grad_norm / interval:.4f}, '
                         f'Time cost={toc - tic:.2f}, '
                         f'Throughput={wps:.2f} tokens/s, ETA={eta:.2f}h')
            if args.local_rank in (-1, 0):
                writer.add_scalar('Throughput_wps', wps, step_num)
                writer.add_scalar('Loss/MLM', running_mlm_loss / interval, step_num)
                writer.add_scalar('Loss/QT', running_qt_loss / interval, step_num)
                writer.add_scalar('Acc/MLM', running_mlm_acc / interval, step_num)
                writer.add_scalar('Acc/QT', running_qt_acc / interval, step_num)
                writer.add_scalar('LR', scheduler.get_last_lr()[0], step_num)
                writer.add_scalar('Grad_norm', running_grad_norm / interval, step_num)
            running_num_tks, running_grad_norm = 0, 0
            running_mlm_loss, running_qt_loss, running_mlm_acc, running_qt_acc = 0, 0, 0, 0
            tic = time.time()

        # Saving
        if step_num % save_interval == 0 or step_num >= args.num_steps:
            states_option(step_num, optimizer, args, 'Saving')
            if args.local_rank in (0, -1):
                parameters_option(step_num, model, args, 'Saving')

    logging.info('Finish training step: %d', step_num)
    train_end_time = time.time()
    logging.info('Train cost={:.1f} s'.format(train_end_time - train_start_time))

    if args.local_rank in (0, -1):
        save_dir = os.path.join(args.ckpt_dir, args.model_name)
        final_save(model, save_dir, tokenizer.vocab, cfg)