def validate_one_epoch(
    cls,
    model: torch.nn.Module,
    iterator: Iterable[Dict[str, torch.Tensor]],
    reporter: SubReporter,
    options: GANTrainerOptions,
    distributed_option: DistributedOption,
) -> None:
    """Validate one epoch.

    For every batch the model is called once per turn (generator and
    discriminator, ordered by ``options.generator_first``) and the returned
    statistics are registered with ``reporter``.

    Args:
        cls: Class this method is bound to (not used in the body).
        model: Called as ``model(forward_generator=..., **batch)``. It must
            return either a dict holding ``"stats"`` and ``"weight"`` or a
            3-item sequence whose last two items are stats and weight.
        iterator: Yields 2-item entries; the second item is the batch dict.
        reporter: Receives ``register(stats, weight)`` once per turn and
            ``next()`` once per processed batch.
        options: Supplies ``ngpu``, ``no_forward_run`` and
            ``generator_first``.
        distributed_option: Supplies the ``distributed`` flag.
    """
    assert check_argument_types()
    ngpu = options.ngpu
    skip_forward = options.no_forward_run
    is_distributed = distributed_option.distributed
    if options.generator_first:
        turn_order = ["generator", "discriminator"]
    else:
        turn_order = ["discriminator", "generator"]

    model.eval()

    device = "cuda" if ngpu > 0 else "cpu"
    # [For distributed] Because iteration counts are not always equal between
    # processes, send a stop-flag to the other processes once this iterator
    # is finished.
    stop_flag = torch.tensor(0).to(device)
    for _, batch in iterator:
        assert isinstance(batch, dict), type(batch)
        if is_distributed:
            torch.distributed.all_reduce(stop_flag, ReduceOp.SUM)
            if stop_flag > 0:
                # Another process has already exhausted its iterator.
                break

        batch = to_device(batch, device)
        if skip_forward:
            continue

        for turn in turn_order:
            output = model(forward_generator=(turn == "generator"), **batch)
            if not isinstance(output, dict):
                _, stats, weight = output
            else:
                stats = output["stats"]
                weight = output["weight"]
            if ngpu > 1 or is_distributed:
                # Apply weighted averaging for stats.
                # if distributed, this method can also apply all_reduce()
                stats, weight = recursive_average(stats, weight, is_distributed)
            reporter.register(stats, weight)
        reporter.next()
    else:
        # The loop ended without a break: tell the other processes that this
        # iterator is finished.
        if is_distributed:
            stop_flag.fill_(1)
            torch.distributed.all_reduce(stop_flag, ReduceOp.SUM)
def validate_one_epoch(
    cls,
    model: torch.nn.Module,
    iterator: Iterable[Dict[str, torch.Tensor]],
    reporter: SubReporter,
    options: TrainerOptions,
) -> None:
    """Run the model over every batch and register its statistics.

    Args:
        cls: Class this method is bound to (not used in the body).
        model: Called as ``model(**batch)``; must return a 3-item sequence
            whose last two items are stats and weight. Distributed mode is
            enabled when it is a ``DistributedDataParallel`` instance.
        iterator: Yields 2-item entries; the second item is the batch dict.
        reporter: Receives ``register(stats, weight)`` and ``next()`` once
            per processed batch.
        options: Supplies ``ngpu`` and ``no_forward_run``.
    """
    assert check_argument_types()
    ngpu = options.ngpu
    skip_forward = options.no_forward_run
    is_distributed = isinstance(model, torch.nn.parallel.DistributedDataParallel)

    model.eval()

    device = "cuda" if ngpu > 0 else "cpu"
    # [For distributed] Because iteration counts are not always equal between
    # processes, send a stop-flag to the other processes once this iterator
    # is finished.
    stop_flag = torch.tensor(0).to(device)
    for _, batch in iterator:
        assert isinstance(batch, dict), type(batch)
        if is_distributed:
            torch.distributed.all_reduce(stop_flag, ReduceOp.SUM)
            if stop_flag > 0:
                # Another process has already exhausted its iterator.
                break

        batch = to_device(batch, device)
        if skip_forward:
            continue

        _, stats, weight = model(**batch)
        if ngpu > 1 or is_distributed:
            # Apply weighted averaging for stats.
            # if distributed, this method can also apply all_reduce()
            stats, weight = recursive_average(stats, weight, is_distributed)

        reporter.register(stats, weight)
        reporter.next()
    else:
        # The loop ended without a break: tell the other processes that this
        # iterator is finished.
        if is_distributed:
            stop_flag.fill_(1)
            torch.distributed.all_reduce(stop_flag, ReduceOp.SUM)
def train_one_epoch(
    cls,
    model: torch.nn.Module,
    iterator: Iterable[Tuple[List[str], Dict[str, torch.Tensor]]],
    optimizers: Sequence[torch.optim.Optimizer],
    schedulers: Sequence[Optional[AbsScheduler]],
    scaler: Optional[GradScaler],
    reporter: SubReporter,
    summary_writer,
    options: GANTrainerOptions,
    distributed_option: DistributedOption,
) -> bool:
    """Train one epoch.

    For every batch the model is run once per turn (generator and
    discriminator, ordered by ``options.generator_first``); each turn does
    its own forward, backward and optimizer step.

    Args:
        cls: Class this method is bound to (not used in the body).
        model: Called as ``model(forward_generator=..., **batch)``; must
            return a dict with ``"loss"``, ``"stats"``, ``"weight"`` and
            optionally ``"optim_idx"``.
        iterator: Yields ``(ids, batch)`` pairs; only ``batch`` is used.
        optimizers: Optimizers; the one at ``optim_idx`` is stepped, or all
            of them when ``optim_idx`` is ``None``.
        schedulers: Schedulers paired with ``optimizers``; only
            ``AbsBatchStepScheduler`` instances are stepped here.
        scaler: Gradient scaler; ``None`` disables autocast and scaling.
        reporter: Collects statistics and timings.
        summary_writer: Optional writer passed to
            ``reporter.tensorboard_add_scalar``.
        options: Training options (see the attributes read below).
        distributed_option: Supplies the ``distributed`` flag.

    Returns:
        ``True`` if no optimizer step was applied during the epoch (and
        ``no_forward_run`` never skipped a batch), otherwise ``False``.

    Raises:
        NotImplementedError: If ``accum_grad > 1`` or ``grad_noise`` is set.
        RuntimeError: If the model output is not a dict, or ``optim_idx`` is
            neither an int nor a 0/1-dim tensor with identical entries.
    """
    assert check_argument_types()

    grad_noise = options.grad_noise
    accum_grad = options.accum_grad
    grad_clip = options.grad_clip
    grad_clip_type = options.grad_clip_type
    log_interval = options.log_interval
    no_forward_run = options.no_forward_run
    ngpu = options.ngpu
    use_wandb = options.use_wandb
    generator_first = options.generator_first
    distributed = distributed_option.distributed

    # Check unavailable options
    # TODO(kan-bayashi): Support the use of these options
    if accum_grad > 1:
        raise NotImplementedError(
            "accum_grad > 1 is not supported in GAN-based training."
        )
    if grad_noise:
        raise NotImplementedError(
            "grad_noise is not supported in GAN-based training."
        )

    if log_interval is None:
        try:
            log_interval = max(len(iterator) // 20, 10)
        except TypeError:
            # The iterator has no length.
            log_interval = 100

    model.train()
    all_steps_are_invalid = True
    # [For distributed] Because iteration counts are not always equal between
    # processes, send a stop-flag to the other processes once this iterator
    # is finished.
    iterator_stop = torch.tensor(0).to("cuda" if ngpu > 0 else "cpu")

    start_time = time.perf_counter()
    for iiter, (_, batch) in enumerate(
        reporter.measure_iter_time(iterator, "iter_time"), 1
    ):
        assert isinstance(batch, dict), type(batch)

        if distributed:
            torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)
            if iterator_stop > 0:
                break

        batch = to_device(batch, "cuda" if ngpu > 0 else "cpu")
        if no_forward_run:
            all_steps_are_invalid = False
            continue

        turn_start_time = time.perf_counter()
        if generator_first:
            turns = ["generator", "discriminator"]
        else:
            turns = ["discriminator", "generator"]
        for turn in turns:
            with autocast(scaler is not None):
                with reporter.measure_time(f"{turn}_forward_time"):
                    retval = model(forward_generator=turn == "generator", **batch)

                    # Note(kamo):
                    # Supporting two patterns for the returned value from the model
                    # a. dict type
                    if isinstance(retval, dict):
                        loss = retval["loss"]
                        stats = retval["stats"]
                        weight = retval["weight"]
                        optim_idx = retval.get("optim_idx")
                        if optim_idx is not None and not isinstance(optim_idx, int):
                            if not isinstance(optim_idx, torch.Tensor):
                                raise RuntimeError(
                                    "optim_idx must be int or 1dim torch.Tensor, "
                                    f"but got {type(optim_idx)}"
                                )
                            if optim_idx.dim() >= 2:
                                raise RuntimeError(
                                    "optim_idx must be int or 1dim torch.Tensor, "
                                    f"but got {optim_idx.dim()}dim tensor"
                                )
                            if optim_idx.dim() == 1:
                                for v in optim_idx:
                                    if v != optim_idx[0]:
                                        raise RuntimeError(
                                            "optim_idx must be 1dim tensor "
                                            "having same values for all entries"
                                        )
                                optim_idx = optim_idx[0].item()
                            else:
                                optim_idx = optim_idx.item()

                    # b. tuple or list type
                    else:
                        raise RuntimeError("model output must be dict.")

                stats = {k: v for k, v in stats.items() if v is not None}
                if ngpu > 1 or distributed:
                    # Apply weighted averaging for loss and stats
                    loss = (loss * weight.type(loss.dtype)).sum()

                    # if distributed, this method can also apply all_reduce()
                    stats, weight = recursive_average(stats, weight, distributed)

                    # Now weight is summation over all workers
                    loss /= weight
                if distributed:
                    # NOTE(kamo): Multiply world_size since DistributedDataParallel
                    # automatically normalizes the gradient by world_size.
                    loss *= torch.distributed.get_world_size()

            reporter.register(stats, weight)

            with reporter.measure_time(f"{turn}_backward_time"):
                if scaler is not None:
                    # Scales loss. Calls backward() on scaled loss
                    # to create scaled gradients.
                    # Backward passes under autocast are not recommended.
                    # Backward ops run in the same dtype autocast chose
                    # for corresponding forward ops.
                    scaler.scale(loss).backward()
                else:
                    loss.backward()

            if scaler is not None:
                # Unscales the gradients of optimizer's assigned params in-place
                for iopt, optimizer in enumerate(optimizers):
                    if optim_idx is not None and iopt != optim_idx:
                        continue
                    scaler.unscale_(optimizer)

            # TODO(kan-bayashi): Compute grad norm without clipping
            grad_norm = None
            if grad_clip > 0.0:
                # compute the gradient norm to check if it is normal or not
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    model.parameters(),
                    max_norm=grad_clip,
                    norm_type=grad_clip_type,
                )
                # PyTorch<=1.4, clip_grad_norm_ returns float value
                if not isinstance(grad_norm, torch.Tensor):
                    grad_norm = torch.tensor(grad_norm)

            if grad_norm is None or torch.isfinite(grad_norm):
                all_steps_are_invalid = False
                with reporter.measure_time(f"{turn}_optim_step_time"):
                    for iopt, (optimizer, scheduler) in enumerate(
                        zip(optimizers, schedulers)
                    ):
                        if optim_idx is not None and iopt != optim_idx:
                            continue
                        if scaler is not None:
                            # scaler.step() first unscales the gradients of
                            # the optimizer's assigned params.
                            scaler.step(optimizer)
                            # Updates the scale for next iteration.
                            scaler.update()
                        else:
                            optimizer.step()
                        if isinstance(scheduler, AbsBatchStepScheduler):
                            scheduler.step()
            else:
                logging.warning(
                    f"The grad norm is {grad_norm}. "
                    "Skipping updating the model."
                )
                # Must invoke scaler.update() if unscale_() is used in the
                # iteration to avoid the following error:
                #   RuntimeError: unscale_() has already been called
                #   on this optimizer since the last update().
                # Note that if the gradient has inf/nan values,
                # scaler.step skips optimizer.step().
                if scaler is not None:
                    for iopt, optimizer in enumerate(optimizers):
                        if optim_idx is not None and iopt != optim_idx:
                            continue
                        scaler.step(optimizer)
                        scaler.update()

            for iopt, optimizer in enumerate(optimizers):
                # NOTE(kan-bayashi): In the case of GAN, we need to clear
                #   the gradient of both optimizers after every update.
                optimizer.zero_grad()

            # Register lr and train/load time[sec/step],
            # where step refers to accum_grad * mini-batch
            # optim_idx may be None (the model did not return one); every
            # other use in this function treats None as "all optimizers", so
            # report all of them instead of indexing the sequence with None.
            if optim_idx is None:
                reported_optimizers = list(enumerate(optimizers))
            else:
                reported_optimizers = [(optim_idx, optimizers[optim_idx])]
            reporter.register(
                {
                    f"optim{idx}_lr{i}": pg["lr"]
                    for idx, reported in reported_optimizers
                    for i, pg in enumerate(reported.param_groups)
                    if "lr" in pg
                },
            )
            reporter.register(
                {f"{turn}_train_time": time.perf_counter() - turn_start_time}
            )
            turn_start_time = time.perf_counter()

        reporter.register({"train_time": time.perf_counter() - start_time})
        start_time = time.perf_counter()

        # NOTE(kamo): Call log_message() after next()
        reporter.next()
        if iiter % log_interval == 0:
            logging.info(reporter.log_message(-log_interval))
            if summary_writer is not None:
                reporter.tensorboard_add_scalar(summary_writer, -log_interval)
            if use_wandb:
                reporter.wandb_log()

    else:
        # The loop ended without a break: tell the other processes that this
        # iterator is finished.
        if distributed:
            iterator_stop.fill_(1)
            torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)
    return all_steps_are_invalid
def train_one_epoch(
    cls,
    model: torch.nn.Module,
    iterator: Iterable[Tuple[List[str], Dict[str, torch.Tensor]]],
    optimizers: Sequence[torch.optim.Optimizer],
    schedulers: Sequence[Optional[AbsScheduler]],
    scaler: Optional[GradScaler],
    reporter: SubReporter,
    summary_writer: Optional[SummaryWriter],
    options: TrainerOptions,
) -> bool:
    """Train the model for one epoch with a single optimizer.

    Args:
        cls: Class this method is bound to; ``cls.num_optimizers`` must be 1.
        model: Called as ``model(**batch)``; must return
            ``(loss, stats, weight)``. Distributed mode is enabled when it is
            a ``DistributedDataParallel`` instance.
        iterator: Yields ``(ids, batch)`` pairs; only ``batch`` is used.
        optimizers: Sequence holding exactly one optimizer.
        schedulers: Sequence whose first item is the scheduler; it is stepped
            here only if it is an ``AbsBatchStepScheduler``.
        scaler: Gradient scaler; ``None`` disables autocast and scaling.
        reporter: Collects statistics and timings.
        summary_writer: Optional writer passed to
            ``reporter.tensorboard_add_scalar``.
        options: Supplies ``grad_noise``, ``accum_grad``, ``grad_clip``,
            ``log_interval``, ``no_forward_run`` and ``ngpu``.

    Returns:
        ``True`` if no optimizer step was applied during the epoch (and
        ``no_forward_run`` never skipped a batch), otherwise ``False``.
    """
    assert check_argument_types()

    # Note(kamo): assumes one optimizer
    assert cls.num_optimizers == 1, cls.num_optimizers
    assert len(optimizers) == 1, len(optimizers)
    optimizer, scheduler = optimizers[0], schedulers[0]

    grad_noise = options.grad_noise
    accum_grad = options.accum_grad
    grad_clip = options.grad_clip
    log_interval = options.log_interval
    skip_forward = options.no_forward_run
    ngpu = options.ngpu
    is_distributed = isinstance(model, torch.nn.parallel.DistributedDataParallel)
    use_scaler = scaler is not None

    if log_interval is None:
        try:
            log_interval = max(len(iterator) // 20, 10)
        except TypeError:
            # The iterator has no length.
            log_interval = 100

    model.train()
    all_steps_are_invalid = True

    device = "cuda" if ngpu > 0 else "cpu"
    # [For distributed] Because iteration counts are not always equal between
    # processes, send a stop-flag to the other processes once this iterator
    # is finished.
    stop_flag = torch.tensor(0).to(device)

    start_time = time.perf_counter()
    timed_iterator = reporter.measure_iter_time(iterator, "iter_time")
    for iiter, (_, batch) in enumerate(timed_iterator, 1):
        assert isinstance(batch, dict), type(batch)

        if is_distributed:
            torch.distributed.all_reduce(stop_flag, ReduceOp.SUM)
            if stop_flag > 0:
                break

        batch = to_device(batch, device)
        if skip_forward:
            all_steps_are_invalid = False
            continue

        with autocast(use_scaler):
            with reporter.measure_time("forward_time"):
                loss, stats, weight = model(**batch)
            # Drop statistics the model left unset.
            stats = {name: value for name, value in stats.items() if value is not None}
            if ngpu > 1 or is_distributed:
                # Apply weighted averaging for loss and stats
                loss = (loss * weight.type(loss.dtype)).sum()

                # if distributed, this method can also apply all_reduce()
                stats, weight = recursive_average(stats, weight, is_distributed)

                # Now weight is summation over all workers
                loss /= weight
            if is_distributed:
                # NOTE(kamo): Multiply world_size because DistributedDataParallel
                # automatically normalizes the gradient by world_size.
                loss *= torch.distributed.get_world_size()

            loss /= accum_grad

        reporter.register(stats, weight)

        with reporter.measure_time("backward_time"):
            if use_scaler:
                # Scales loss. Calls backward() on scaled loss
                # to create scaled gradients.
                # Backward passes under autocast are not recommended.
                # Backward ops run in the same dtype autocast chose
                # for corresponding forward ops.
                scaler.scale(loss).backward()
            else:
                loss.backward()

        if iiter % accum_grad == 0:
            if use_scaler:
                # Unscales the gradients of optimizer's assigned params in-place
                scaler.unscale_(optimizer)

            # gradient noise injection
            if grad_noise:
                add_gradient_noise(
                    model,
                    reporter.get_total_count(),
                    duration=100,
                    eta=1.0,
                    scale_factor=0.55,
                )

            # compute the gradient norm to check if it is normal or not
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            # PyTorch<=1.4, clip_grad_norm_ returns float value
            if not isinstance(grad_norm, torch.Tensor):
                grad_norm = torch.tensor(grad_norm)

            if torch.isfinite(grad_norm):
                all_steps_are_invalid = False
                with reporter.measure_time("optim_step_time"):
                    if use_scaler:
                        # scaler.step() first unscales the gradients of
                        # the optimizer's assigned params.
                        scaler.step(optimizer)
                        # Updates the scale for next iteration.
                        scaler.update()
                    else:
                        optimizer.step()
                if isinstance(scheduler, AbsBatchStepScheduler):
                    scheduler.step()
            else:
                logging.warning(
                    f"The grad norm is {grad_norm}. Skipping updating the model."
                )

                # Must invoke scaler.update() if unscale_() is used in the iteration
                # to avoid the following error:
                #   RuntimeError: unscale_() has already been called
                #   on this optimizer since the last update().
                # Note that if the gradient has inf/nan values,
                # scaler.step skips optimizer.step().
                if use_scaler:
                    scaler.step(optimizer)
                    scaler.update()
            optimizer.zero_grad()

            # Register lr and train/load time[sec/step],
            # where step refers to accum_grad * mini-batch
            step_report = {
                f"lr_{i}": group["lr"]
                for i, group in enumerate(optimizer.param_groups)
                if "lr" in group
            }
            step_report["train_time"] = time.perf_counter() - start_time
            reporter.register(step_report)
            start_time = time.perf_counter()

        # NOTE(kamo): Call log_message() after next()
        reporter.next()
        if iiter % log_interval == 0:
            logging.info(reporter.log_message(-log_interval))
            if summary_writer is not None:
                reporter.tensorboard_add_scalar(summary_writer, -log_interval)

    else:
        # The loop ended without a break: tell the other processes that this
        # iterator is finished.
        if is_distributed:
            stop_flag.fill_(1)
            torch.distributed.all_reduce(stop_flag, ReduceOp.SUM)

    return all_steps_are_invalid
def plot_attention(
    cls,
    model: torch.nn.Module,
    output_dir: Optional[Path],
    summary_writer: Optional[SummaryWriter],
    iterator: Iterable[Tuple[List[str], Dict[str, torch.Tensor]]],
    reporter: SubReporter,
    options: TrainerOptions,
) -> None:
    """Plot the attention weights of every sample yielded by ``iterator``.

    Args:
        cls: Class this method is bound to (not used in the body).
        model: Module passed to ``calculate_all_attentions``.
        output_dir: If given, figures are saved as
            ``<output_dir>/<id>/<key>.<epoch>ep.png``.
        summary_writer: If given, each figure is also passed to its
            ``add_figure`` method.
        iterator: Yields ``(ids, batch)`` pairs; the first value of ``batch``
            must have as many entries as ``ids``.
        reporter: Supplies the epoch number via ``get_epoch()``.
        options: Supplies ``ngpu`` and ``no_forward_run``.

    Raises:
        RuntimeError: If an attention array has 1 or more than 3 dimensions.
    """
    assert check_argument_types()
    import matplotlib

    ngpu = options.ngpu
    skip_forward = options.no_forward_run

    # Select the backend before pyplot is imported.
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt
    from matplotlib.ticker import MaxNLocator

    model.eval()
    device = "cuda" if ngpu > 0 else "cpu"
    for ids, batch in iterator:
        assert isinstance(batch, dict), type(batch)
        assert len(next(iter(batch.values()))) == len(ids), (
            len(next(iter(batch.values()))),
            len(ids),
        )
        batch = to_device(batch, device)
        if skip_forward:
            continue

        # 1. Forwarding model and gathering all attentions
        #    calculate_all_attentions() uses single gpu only.
        attentions = calculate_all_attentions(model, batch)

        # 2. Plot attentions: This part is slow due to matplotlib
        for k, att_list in attentions.items():
            assert len(att_list) == len(ids), (len(att_list), len(ids))
            for id_, att_w in zip(ids, att_list):

                if isinstance(att_w, torch.Tensor):
                    att_w = att_w.detach().cpu().numpy()

                if att_w.ndim == 2:
                    # Add a leading axis so a single map is handled like a
                    # stack of maps.
                    att_w = att_w[None]
                elif att_w.ndim == 1 or att_w.ndim > 3:
                    raise RuntimeError(f"Must be 2 or 3 dimension: {att_w.ndim}")

                n_maps = len(att_w)
                w, h = plt.figaspect(1.0 / n_maps)
                fig = plt.Figure(figsize=(w * 1.3, h * 1.3))
                axes = fig.subplots(1, n_maps)
                if n_maps == 1:
                    axes = [axes]

                for ax, aw in zip(axes, att_w):
                    ax.imshow(aw.astype(np.float32), aspect="auto")
                    ax.set_title(f"{k}_{id_}")
                    ax.set_xlabel("Input")
                    ax.set_ylabel("Output")
                    ax.xaxis.set_major_locator(MaxNLocator(integer=True))
                    ax.yaxis.set_major_locator(MaxNLocator(integer=True))

                if output_dir is not None:
                    p = output_dir / id_ / f"{k}.{reporter.get_epoch()}ep.png"
                    p.parent.mkdir(parents=True, exist_ok=True)
                    fig.savefig(p)

                if summary_writer is not None:
                    summary_writer.add_figure(f"{k}_{id_}", fig, reporter.get_epoch())

        # Dummy register() stimulates to increment the counter
        reporter.register({})
def train_one_epoch(
    cls,
    model: torch.nn.Module,
    iterator: Iterable[Tuple[List[str], Dict[str, torch.Tensor]]],
    optimizers: Sequence[torch.optim.Optimizer],
    schedulers: Sequence[Optional[AbsScheduler]],
    reporter: SubReporter,
    options: TrainerOptions,
) -> bool:
    """Train the model for one epoch with a single optimizer (apex-aware).

    Args:
        cls: Class this method is bound to; ``cls.num_optimizers`` must be 1.
        model: Called as ``model(**batch)``; must return
            ``(loss, stats, weight)``. Distributed mode is enabled when it is
            a ``DistributedDataParallel`` instance.
        iterator: Yields ``(ids, batch)`` pairs; only ``batch`` is used.
        optimizers: Sequence holding exactly one optimizer.
        schedulers: Sequence whose first item is the scheduler; it is stepped
            here only if it is an ``AbsBatchStepScheduler``.
        reporter: Collects statistics and timings.
        options: Supplies ``grad_noise``, ``accum_grad``, ``grad_clip``,
            ``log_interval``, ``no_forward_run``, ``ngpu`` and
            ``train_dtype`` (``"O0"``-``"O3"`` selects apex loss scaling).

    Returns:
        ``True`` if no optimizer step was applied during the epoch (and
        ``no_forward_run`` never skipped a batch), otherwise ``False``.

    Raises:
        ImportError: If apex loss scaling is requested but apex cannot be
            imported.
    """
    assert check_argument_types()

    # Note(kamo): assumes one optimizer
    assert cls.num_optimizers == 1, cls.num_optimizers
    assert len(optimizers) == 1, len(optimizers)
    optimizer = optimizers[0]
    scheduler = schedulers[0]

    grad_noise = options.grad_noise
    accum_grad = options.accum_grad
    grad_clip = options.grad_clip
    log_interval = options.log_interval
    no_forward_run = options.no_forward_run
    ngpu = options.ngpu
    distributed = isinstance(model, torch.nn.parallel.DistributedDataParallel)
    use_apex = options.train_dtype in ("O0", "O1", "O2", "O3")

    if log_interval is None:
        try:
            log_interval = max(len(iterator) // 20, 10)
        except TypeError:
            # The iterator has no length.
            log_interval = 100

    model.train()
    all_steps_are_invalid = True
    # [For distributed] Because iteration counts are not always equal between
    # processes, send a stop-flag to the other processes once this iterator
    # is finished.
    iterator_stop = torch.tensor(0).to("cuda" if ngpu > 0 else "cpu")

    start_time = time.perf_counter()
    for iiter, (_, batch) in enumerate(
        reporter.measure_iter_time(iterator, "iter_time"), 1
    ):
        assert isinstance(batch, dict), type(batch)

        if distributed:
            torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)
            if iterator_stop > 0:
                break

        batch = to_device(batch, "cuda" if ngpu > 0 else "cpu")
        if no_forward_run:
            all_steps_are_invalid = False
            reporter.register({})
            continue

        with reporter.measure_time("forward_time"):
            loss, stats, weight = model(**batch)
        if ngpu > 1 or distributed:
            # Apply weighted averaging for loss and stats
            loss = (loss * weight.type(loss.dtype)).sum()

            # if distributed, this method can also apply all_reduce()
            stats, weight = recursive_average(stats, weight, distributed)

            # Now weight is summation over all workers
            loss /= weight
        if distributed:
            # NOTE(kamo): Multiply world_size because DistributedDataParallel
            # automatically normalizes the gradient by world_size.
            loss *= torch.distributed.get_world_size()

        reporter.register(stats, weight)

        loss /= accum_grad
        with reporter.measure_time("backward_time"):
            if use_apex:
                try:
                    from apex import amp
                except ImportError:
                    logging.error(
                        "You need to install apex. "
                        "See https://github.com/NVIDIA/apex#linux"
                    )
                    # Re-raise: continuing would only fail on the next line
                    # with an unbound `amp`, hiding the real cause.
                    raise
                with amp.scale_loss(loss, optimizers) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

        if iiter % accum_grad == 0:
            # gradient noise injection
            if grad_noise:
                add_gradient_noise(
                    model,
                    reporter.get_total_count(),
                    duration=100,
                    eta=1.0,
                    scale_factor=0.55,
                )

            # compute the gradient norm to check if it is normal or not
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            # PyTorch<=1.4, clip_grad_norm_ returns float value. Normalise to
            # a tensor and test with torch.isfinite, since a CUDA tensor
            # cannot be handed to NumPy.
            if not isinstance(grad_norm, torch.Tensor):
                grad_norm = torch.tensor(grad_norm)

            if not torch.isfinite(grad_norm):
                logging.warning(
                    f"The grad norm is {grad_norm}. Skipping updating the model."
                )
            else:
                all_steps_are_invalid = False
                with reporter.measure_time("optim_step_time"):
                    optimizer.step()
                if isinstance(scheduler, AbsBatchStepScheduler):
                    scheduler.step()
            optimizer.zero_grad()

            # Register lr and train/load time[sec/step],
            # where step refers to accum_grad * mini-batch
            reporter.register(
                dict(
                    {
                        f"lr_{i}": pg["lr"]
                        for i, pg in enumerate(optimizer.param_groups)
                        if "lr" in pg
                    },
                    train_time=time.perf_counter() - start_time,
                ),
                # Suppress to increment the internal counter.
                not_increment_count=True,
            )
            start_time = time.perf_counter()

        if iiter % log_interval == 0:
            logging.info(reporter.log_message())

    else:
        # The loop ended without a break: tell the other processes that this
        # iterator is finished.
        if distributed:
            iterator_stop.fill_(1)
            torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)

    return all_steps_are_invalid
def train_one_epoch(
    cls,
    model: torch.nn.Module,
    iterator: Iterable[Tuple[List[str], Dict[str, torch.Tensor]]],
    optimizers: Sequence[torch.optim.Optimizer],
    schedulers: Sequence[Optional[AbsScheduler]],
    scaler: Optional[GradScaler],
    reporter: SubReporter,
    summary_writer,
    options: TrainerOptions,
    distributed_option: DistributedOption,
) -> bool:
    """Train the model for one epoch, optionally with several optimizers.

    Args:
        cls: Class this method is bound to (not used in the body).
        model: Called as ``model(**batch)``; returns either a dict with
            ``"loss"``, ``"stats"``, ``"weight"`` and optionally
            ``"optim_idx"``, or a ``(loss, stats, weight)`` sequence.
        iterator: Yields ``(utt_id, batch)`` pairs; ``utt_id`` is stored in
            the batch under the ``"utt_id"`` key before the forward pass.
        optimizers: Optimizers; the one at ``optim_idx`` is used, or all of
            them when ``optim_idx`` is ``None``.
        schedulers: Schedulers paired with ``optimizers``; only
            ``AbsBatchStepScheduler`` instances are stepped here.
        scaler: Gradient scaler; ``None`` disables autocast and scaling.
        reporter: Collects statistics and timings.
        summary_writer: Optional writer used for the model graph and for
            ``reporter.tensorboard_add_scalar``.
        options: Training options (see the attributes read below).
        distributed_option: Supplies the ``distributed`` flag.

    Returns:
        ``True`` if no optimizer step was applied during the epoch (and
        ``no_forward_run`` never skipped a batch), otherwise ``False``.

    Raises:
        RuntimeError: If ``optim_idx`` is neither an int nor a 0/1-dim
            tensor with identical entries.
    """
    assert check_argument_types()

    grad_noise = options.grad_noise
    accum_grad = options.accum_grad
    grad_clip = options.grad_clip
    grad_clip_type = options.grad_clip_type
    log_interval = options.log_interval
    no_forward_run = options.no_forward_run
    ngpu = options.ngpu
    use_wandb = options.use_wandb
    create_graph_in_tensorboard = options.create_graph_in_tensorboard
    distributed = distributed_option.distributed

    if log_interval is None:
        try:
            log_interval = max(len(iterator) // 20, 10)
        except TypeError:
            # The iterator has no length.
            log_interval = 100

    model.train()
    all_steps_are_invalid = True
    # [For distributed] Because iteration counts are not always equal between
    # processes, send a stop-flag to the other processes once this iterator
    # is finished.
    iterator_stop = torch.tensor(0).to("cuda" if ngpu > 0 else "cpu")

    start_time = time.perf_counter()
    for iiter, (utt_id, batch) in enumerate(
        reporter.measure_iter_time(iterator, "iter_time"), 1
    ):
        assert isinstance(batch, dict), type(batch)

        if distributed:
            torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)
            if iterator_stop > 0:
                break

        batch["utt_id"] = utt_id

        batch = to_device(batch, "cuda" if ngpu > 0 else "cpu")
        if no_forward_run:
            all_steps_are_invalid = False
            continue

        # Add the model graph to tensorboard once, on the first iteration.
        if (
            create_graph_in_tensorboard
            and iiter == 1
            and summary_writer is not None
        ):
            if distributed:
                # Default to None so that a missing attribute reaches the
                # warning below instead of raising AttributeError.
                _model = getattr(model, "module", None)
            else:
                _model = model
            if _model is not None:
                try:
                    _args = kwargs2args(_model.forward, batch)
                except (ValueError, TypeError):
                    logging.warning(
                        "inspect.signature() is failed for the model. "
                        "The graph can't be added for tensorboard."
                    )
                else:
                    try:
                        summary_writer.add_graph(
                            _model, _args, use_strict_trace=False
                        )
                    except Exception:
                        # Best effort: the graph is optional, so training
                        # continues without it.
                        logging.warning(
                            "summary_writer.add_graph() "
                            "is failed for the model. "
                            "The graph can't be added for tensorboard."
                        )
                    del _args
            else:
                logging.warning(
                    "model.module is not found (This should be a bug.)"
                )
            del _model

        with autocast(scaler is not None):
            with reporter.measure_time("forward_time"):
                retval = model(**batch)

                # Note(kamo):
                # Supporting two patterns for the returned value from the model
                # a. dict type
                if isinstance(retval, dict):
                    loss = retval["loss"]
                    stats = retval["stats"]
                    weight = retval["weight"]
                    optim_idx = retval.get("optim_idx")
                    if optim_idx is not None and not isinstance(optim_idx, int):
                        if not isinstance(optim_idx, torch.Tensor):
                            raise RuntimeError(
                                "optim_idx must be int or 1dim torch.Tensor, "
                                f"but got {type(optim_idx)}"
                            )
                        if optim_idx.dim() >= 2:
                            raise RuntimeError(
                                "optim_idx must be int or 1dim torch.Tensor, "
                                f"but got {optim_idx.dim()}dim tensor"
                            )
                        if optim_idx.dim() == 1:
                            for v in optim_idx:
                                if v != optim_idx[0]:
                                    raise RuntimeError(
                                        "optim_idx must be 1dim tensor "
                                        "having same values for all entries"
                                    )
                            optim_idx = optim_idx[0].item()
                        else:
                            optim_idx = optim_idx.item()

                # b. tuple or list type
                else:
                    loss, stats, weight = retval
                    optim_idx = None

            stats = {k: v for k, v in stats.items() if v is not None}
            if ngpu > 1 or distributed:
                # Apply weighted averaging for loss and stats
                loss = (loss * weight.type(loss.dtype)).sum()

                # if distributed, this method can also apply all_reduce()
                stats, weight = recursive_average(stats, weight, distributed)

                # Now weight is summation over all workers
                loss /= weight
            if distributed:
                # NOTE(kamo): Multiply world_size because DistributedDataParallel
                # automatically normalizes the gradient by world_size.
                loss *= torch.distributed.get_world_size()

            loss /= accum_grad

        reporter.register(stats, weight)

        with reporter.measure_time("backward_time"):
            if scaler is not None:
                # Scales loss. Calls backward() on scaled loss
                # to create scaled gradients.
                # Backward passes under autocast are not recommended.
                # Backward ops run in the same dtype autocast chose
                # for corresponding forward ops.
                scaler.scale(loss).backward()
            else:
                loss.backward()

        if iiter % accum_grad == 0:
            if scaler is not None:
                # Unscales the gradients of optimizer's assigned params in-place
                for iopt, optimizer in enumerate(optimizers):
                    if optim_idx is not None and iopt != optim_idx:
                        continue
                    scaler.unscale_(optimizer)

            # gradient noise injection
            if grad_noise:
                add_gradient_noise(
                    model,
                    reporter.get_total_count(),
                    duration=100,
                    eta=1.0,
                    scale_factor=0.55,
                )

            # compute the gradient norm to check if it is normal or not
            grad_norm = torch.nn.utils.clip_grad_norm_(
                model.parameters(),
                max_norm=grad_clip,
                norm_type=grad_clip_type,
            )
            # PyTorch<=1.4, clip_grad_norm_ returns float value
            if not isinstance(grad_norm, torch.Tensor):
                grad_norm = torch.tensor(grad_norm)

            if not torch.isfinite(grad_norm):
                logging.warning(
                    f"The grad norm is {grad_norm}. Skipping updating the model."
                )

                # Must invoke scaler.update() if unscale_() is used in the iteration
                # to avoid the following error:
                #   RuntimeError: unscale_() has already been called
                #   on this optimizer since the last update().
                # Note that if the gradient has inf/nan values,
                # scaler.step skips optimizer.step().
                if scaler is not None:
                    for iopt, optimizer in enumerate(optimizers):
                        if optim_idx is not None and iopt != optim_idx:
                            continue
                        scaler.step(optimizer)
                        scaler.update()

            else:
                all_steps_are_invalid = False
                with reporter.measure_time("optim_step_time"):
                    for iopt, (optimizer, scheduler) in enumerate(
                        zip(optimizers, schedulers)
                    ):
                        if optim_idx is not None and iopt != optim_idx:
                            continue
                        if scaler is not None:
                            # scaler.step() first unscales the gradients of
                            # the optimizer's assigned params.
                            scaler.step(optimizer)
                            # Updates the scale for next iteration.
                            scaler.update()
                        else:
                            optimizer.step()
                        if isinstance(scheduler, AbsBatchStepScheduler):
                            scheduler.step()
            for iopt, optimizer in enumerate(optimizers):
                if optim_idx is not None and iopt != optim_idx:
                    continue
                optimizer.zero_grad()

            # Register lr and train/load time[sec/step],
            # where step refers to accum_grad * mini-batch
            reporter.register(
                dict(
                    {
                        f"optim{i}_lr{j}": pg["lr"]
                        for i, optimizer in enumerate(optimizers)
                        for j, pg in enumerate(optimizer.param_groups)
                        if "lr" in pg
                    },
                    train_time=time.perf_counter() - start_time,
                ),
            )
            start_time = time.perf_counter()

        # NOTE(kamo): Call log_message() after next()
        reporter.next()
        if iiter % log_interval == 0:
            logging.info(reporter.log_message(-log_interval))
            if summary_writer is not None:
                reporter.tensorboard_add_scalar(summary_writer, -log_interval)
            if use_wandb:
                reporter.wandb_log()

    else:
        # The loop ended without a break: tell the other processes that this
        # iterator is finished.
        if distributed:
            iterator_stop.fill_(1)
            torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)
    return all_steps_are_invalid