def get_objf(batch: Dict,
             model: AcousticModel,
             ali_model: Optional[AcousticModel],
             P: k2.Fsa,
             device: torch.device,
             graph_compiler: MmiTrainingGraphCompiler,
             is_training: bool,
             is_update: bool,
             accum_grad: int = 1,
             den_scale: float = 1.0,
             att_rate: float = 0.0,
             tb_writer: Optional[SummaryWriter] = None,
             global_batch_idx_train: Optional[int] = None,
             optimizer: Optional[torch.optim.Optimizer] = None,
             scaler: Optional[GradScaler] = None):
    feature = batch['inputs']  # at entry, feature is [N, T, C]
    feature = feature.permute(0, 2, 1)  # now feature is [N, C, T]
    assert feature.ndim == 3
    feature = feature.to(device)

    supervisions = batch['supervisions']
    supervision_segments, texts = encode_supervisions(supervisions)

    loss_fn = LFMMILoss(graph_compiler=graph_compiler, P=P, den_scale=den_scale)

    grad_context = nullcontext if is_training else torch.no_grad
    with autocast(enabled=scaler.is_enabled()), grad_context():
        nnet_output, encoder_memory, memory_mask = model(feature, supervisions)
        if att_rate != 0.0:
            att_loss = model.module.decoder_forward(encoder_memory, memory_mask,
                                                    supervisions, graph_compiler)

        if (ali_model is not None and global_batch_idx_train is not None
                and global_batch_idx_train // accum_grad < 4000):
            with torch.no_grad():
                ali_model_output = ali_model(feature)
            # Subsampling is done slightly differently in the two models, so
            # there may be small length differences.
            min_len = min(ali_model_output.shape[2], nnet_output.shape[2])
            # Use a scale less than one so the model is encouraged to mimic
            # ali_model's output.
            ali_model_scale = 500.0 / (global_batch_idx_train // accum_grad + 500)
            # Clone nnet_output, or the log-softmax backprop will fail.
            nnet_output = nnet_output.clone()
            nnet_output[:, :, :min_len] += \
                ali_model_scale * ali_model_output[:, :, :min_len]

        # nnet_output is [N, C, T]
        nnet_output = nnet_output.permute(0, 2, 1)
        # now nnet_output is [N, T, C]
        mmi_loss, tot_frames, all_frames = loss_fn(nnet_output, texts,
                                                   supervision_segments)

    if is_training:

        def maybe_log_gradients(tag: str):
            if (tb_writer is not None and global_batch_idx_train is not None
                    and global_batch_idx_train % 200 == 0):
                tb_writer.add_scalars(tag,
                                      measure_gradient_norms(model, norm='l1'),
                                      global_step=global_batch_idx_train)

        if att_rate != 0.0:
            loss = (-(1.0 - att_rate) * mmi_loss +
                    att_rate * att_loss) / (len(texts) * accum_grad)
        else:
            loss = (-mmi_loss) / (len(texts) * accum_grad)
        scaler.scale(loss).backward()

        if is_update:
            maybe_log_gradients('train/grad_norms')
            scaler.unscale_(optimizer)
            clip_grad_value_(model.parameters(), 5.0)
            maybe_log_gradients('train/clipped_grad_norms')
            if (tb_writer is not None and
                    (global_batch_idx_train // accum_grad) % 200 == 0):
                # Once in a while we perform a more costly diagnostic
                # to check the relative parameter change per minibatch.
                deltas = optim_step_and_measure_param_change(
                    model, optimizer, scaler)
                tb_writer.add_scalars(
                    'train/relative_param_change_per_minibatch',
                    deltas,
                    global_step=global_batch_idx_train)
            else:
                scaler.step(optimizer)
            optimizer.zero_grad()
            scaler.update()

    ans = (-mmi_loss.detach().cpu().item(), tot_frames.cpu().item(),
           all_frames.cpu().item())
    return ans
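
# NOTE: illustration only, not part of the original script. A minimal sketch
# of how get_objf might be driven with gradient accumulation under AMP;
# `train_dl`, `model`, `ali_model`, `P`, `graph_compiler`, `device`, and
# `optimizer` are assumed to be constructed elsewhere (hypothetical names).
def example_mmi_training_loop(train_dl, model, ali_model, P, graph_compiler,
                              device, optimizer, accum_grad: int = 2,
                              use_amp: bool = True):
    scaler = GradScaler(enabled=use_amp)
    for batch_idx, batch in enumerate(train_dl):
        # Step the optimizer only every `accum_grad` batches; on the other
        # batches get_objf just accumulates scaled gradients.
        is_update = (batch_idx + 1) % accum_grad == 0
        mmi_obj, tot_frames, all_frames = get_objf(
            batch=batch, model=model, ali_model=ali_model, P=P, device=device,
            graph_compiler=graph_compiler, is_training=True,
            is_update=is_update, accum_grad=accum_grad,
            global_batch_idx_train=batch_idx, optimizer=optimizer,
            scaler=scaler)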
def train_one_epoch(model: torch.nn.Module,
                    criterion: torch.nn.Module,
                    scaler: amp.GradScaler,
                    data_loader: Iterable,
                    optimizer: torch.optim.Optimizer,
                    device: torch.device,
                    epoch: int,
                    max_norm: float = 0):
    model.train()
    criterion.train()
    metric_logger = utils.MetricLogger(delimiter=" ")
    metric_logger.add_meter(
        'lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    metric_logger.add_meter(
        'class_error', utils.SmoothedValue(window_size=1, fmt='{value:.2f}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 10

    for samples, targets in metric_logger.log_every(data_loader, print_freq,
                                                    header):
        samples = samples.to(device)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        # Run the forward pass under autocast; cast the outputs back to fp32
        # before computing the loss when AMP is enabled.
        with amp.autocast(enabled=scaler.is_enabled()):
            outputs = model(samples)
        outputs = to_fp32(outputs) if scaler.is_enabled() else outputs
        loss_dict = criterion(outputs, targets)
        weight_dict = criterion.weight_dict
        losses = sum(loss_dict[k] * weight_dict[k]
                     for k in loss_dict.keys() if k in weight_dict)

        # reduce losses over all GPUs for logging purposes
        loss_dict_reduced = utils.reduce_dict(loss_dict)
        loss_dict_reduced_unscaled = {
            f'{k}_unscaled': v
            for k, v in loss_dict_reduced.items()
        }
        loss_dict_reduced_scaled = {
            k: v * weight_dict[k]
            for k, v in loss_dict_reduced.items() if k in weight_dict
        }
        losses_reduced_scaled = sum(loss_dict_reduced_scaled.values())
        loss_value = losses_reduced_scaled.item()

        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            print(loss_dict_reduced)
            sys.exit(1)

        optimizer.zero_grad()
        scaler.scale(losses).backward()
        if max_norm > 0:
            # Unscale gradients before clipping so the threshold applies to
            # the true gradient magnitudes.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
        scaler.step(optimizer)
        scaler.update()

        metric_logger.update(loss=loss_value,
                             **loss_dict_reduced_scaled,
                             **loss_dict_reduced_unscaled)
        metric_logger.update(class_error=loss_dict_reduced['class_error'])
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
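
# NOTE: illustration only, not part of the original file. `to_fp32` is
# referenced above but not defined here; a plausible implementation (an
# assumption, not the project's actual helper) would recursively cast
# floating-point tensors back to fp32 after the autocast forward pass:
def to_fp32_sketch(obj):
    if isinstance(obj, torch.Tensor):
        return obj.float() if obj.is_floating_point() else obj
    if isinstance(obj, dict):
        return {k: to_fp32_sketch(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return type(obj)(to_fp32_sketch(v) for v in obj)
    return obj


# A minimal sketch of an epoch driver for train_one_epoch, assuming
# DETR-style `model`, `criterion`, and `data_loader_train` built elsewhere
# (hypothetical names). The GradScaler is created once and reused across
# epochs so its loss scale can adapt over the whole run.
def example_fit(model, criterion, data_loader_train, optimizer, device,
                epochs: int, use_amp: bool = True, max_norm: float = 0.1):
    scaler = amp.GradScaler(enabled=use_amp)
    for epoch in range(epochs):
        stats = train_one_epoch(model, criterion, scaler, data_loader_train,
                                optimizer, device, epoch, max_norm)
        print(f"epoch {epoch}:", stats)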
def run_epoch(model,
              optimizer,
              train_ldr,
              logger,
              debug_mode: bool,
              tbX_writer,
              iter_count: int,
              avg_loss: float,
              local_rank: int,
              loss_name: str,
              save_path: str,
              gcs_ckpt_handler,
              scaler: GradScaler = None) -> tuple:
    """Performs a forward and backward pass through the model.

    Args:
        iter_count (int): count of iterations
        save_path (str): path to the directory where the model is saved
        gcs_ckpt_handler: facilitates saving files to Google Cloud Storage
        scaler (GradScaler): gradient scaler that prevents gradient underflow
            when autocast uses float16 precision for the forward pass
    Returns:
        Tuple[int, float]: train state of # batch iterations and average loss
    """
    # booleans and constants for logging
    is_rank_0 = (torch.distributed.get_rank() == 0)
    use_log = (logger is not None and is_rank_0)
    log_modulus = 100  # limits certain logging functions to report less frequently
    exp_w = 0.985  # exponential weight for the exponential moving average loss
    avg_grad_norm = 0
    model_t, data_t = 0.0, 0.0
    end_t = time.time()

    # progress bar for the rank-0 process
    tq = tqdm.tqdm(train_ldr) if is_rank_0 else train_ldr
    # counter for model checkpointing
    batch_counter = 0
    device = torch.device("cuda:" + str(local_rank))

    # if the scaler is enabled, amp is being used
    use_amp = scaler.is_enabled()
    print(f"Amp is being used: {use_amp}")

    # training loop
    for batch in tq:
        if use_log:
            logger.info(
                f"train: ====== Iteration: {iter_count} in run_epoch =======")

        ############## Mid-epoch checkpoint ###############
        if is_rank_0 \
                and batch_counter % (len(train_ldr) // gcs_ckpt_handler.chkpt_per_epoch) == 0 \
                and batch_counter != 0:
            preproc = train_ldr.dataset.preproc
            save(model.module, preproc, save_path, tag='ckpt')
            gcs_ckpt_handler.upload_to_gcs("ckpt_model_state_dict.pth")
            gcs_ckpt_handler.upload_to_gcs("ckpt_preproc.pyc")
            # save the run state
            ckpt_state_path = os.path.join(save_path, "ckpt_run_state.pickle")
            write_pickle(ckpt_state_path, {'run_state': (iter_count, avg_loss)})
            gcs_ckpt_handler.upload_to_gcs("ckpt_run_state.pickle")
            # checkpoint tensorboard
            gcs_ckpt_handler.upload_tensorboard_ckpt()
        batch_counter += 1
        ####################################################

        # convert the temporary generator batch to a permanent list
        batch = list(batch)
        # save the batch information
        if use_log and debug_mode:
            save_batch_log_stats(batch, logger)
            log_batchnorm_mean_std(model.module.state_dict(), logger)

        start_t = time.time()
        # set grads to None for a modest performance improvement
        optimizer.zero_grad(set_to_none=True)

        # autocast to lower precision if amp is used; otherwise it's a no-op
        with autocast(enabled=use_amp):
            # unpack the batch
            inputs, labels, input_lens, label_lens = model.module.collate(*batch)
            inputs = inputs.cuda()
            out, rnn_args = model(inputs, softmax=False)

            # use the loss function selected by `loss_name`
            if loss_name == "native":
                loss = native_loss(out, labels, input_lens, label_lens,
                                   model.module.blank)
            elif loss_name == "awni":
                loss = awni_loss(out, labels, input_lens, label_lens,
                                 model.module.blank)
            elif loss_name == "naren":
                loss = naren_loss(out, labels, input_lens, label_lens,
                                  model.module.blank)

        # backward pass
        loss = loss.cuda()  # amp needs the loss to be on cuda
        scaler.scale(loss).backward()

        if use_log and debug_mode:
            plot_grad_flow_bar(model.module.named_parameters(),
                               get_logger_filename(logger))
            log_param_grad_norms(model.module.named_parameters(), logger)

        # gradient clipping and optimizer step; the un-scaling is a no-op
        # if amp is not used
        scaler.unscale_(optimizer)
        grad_norm = nn.utils.clip_grad_norm_(model.parameters(), 200).item()
        scaler.step(optimizer)
        scaler.update()

        # logging in the rank-0 process
        if is_rank_0:
            # update the timers
            prev_end_t = end_t
            end_t = time.time()
            model_t += end_t - start_t
            data_t += start_t - prev_end_t

            # convert grad_norm and loss to plain floats
            # TODO: needed with pytorch 0.4, may not be necessary anymore
            if isinstance(grad_norm, torch.Tensor):
                grad_norm = grad_norm.item()
            if isinstance(loss, torch.Tensor):
                loss = loss.item()

            # exponentially weighted moving averages of loss and grad_norm
            if iter_count == 0:
                avg_loss = loss
                avg_grad_norm = grad_norm
            else:
                avg_loss = exp_w * avg_loss + (1 - exp_w) * loss
                avg_grad_norm = exp_w * avg_grad_norm + (1 - exp_w) * grad_norm

            # write to the tensorboard log files
            tbX_writer.add_scalars('train/loss', {"loss": loss}, iter_count)
            tbX_writer.add_scalars('train/loss', {"avg_loss": avg_loss},
                                   iter_count)
            # suppress a tbX WARNING about inf values
            # TODO: this may or may not be a good idea as it masks inf
            # values in tensorboard
            if grad_norm == float('inf') or math.isnan(grad_norm):
                tbX_grad_norm = 1
            else:
                tbX_grad_norm = grad_norm
            tbX_writer.add_scalars('train/grad', {"grad_norm": tbX_grad_norm},
                                   iter_count)

            # progress bar update
            tq.set_postfix(it=iter_count, grd_nrm=grad_norm, lss=loss,
                           lss_av=avg_loss, t_mdl=model_t, t_data=data_t,
                           scl=scaler.get_scale())

            if use_log:
                logger.info(f'train: loss is inf: {loss == float("inf")}')
                logger.info(f"train: iter={iter_count}, loss={round(loss, 3)}, "
                            f"grad_norm={round(grad_norm, 3)}")

        if iter_count % log_modulus == 0:
            if use_log:
                log_cpu_mem_disk_usage(logger)

        # check for nan values in the gradients or parameters
        if check_nan_params_grads(model.module.parameters()):
            print("\n~~~ NaN value detected in gradients or parameters ~~~\n")
            if use_log:
                logger.error(
                    f"train: labels: {[labels]}, label_lens: {label_lens} "
                    f"state_dict: {model.module.state_dict()}")
                log_model_grads(model.module.named_parameters(), logger)
                save_batch_log_stats(batch, logger)
                log_param_grad_norms(model.module.named_parameters(), logger)
                plot_grad_flow_bar(model.module.named_parameters(),
                                   get_logger_filename(logger))

        iter_count += 1

    return iter_count, avg_loss
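
# NOTE: illustration only, not part of the original script. A minimal sketch
# of driving run_epoch across epochs, assuming the process group is already
# initialized (e.g. via torch.distributed.init_process_group) and that
# `train_ldr`, `logger`, `tbX_writer`, and `gcs_ckpt_handler` are the
# project's own objects (hypothetical wiring).
def example_train(model, optimizer, train_ldr, logger, tbX_writer,
                  local_rank: int, save_path: str, gcs_ckpt_handler,
                  epochs: int, use_amp: bool = True):
    scaler = GradScaler(enabled=use_amp)
    iter_count, avg_loss = 0, 0.0
    for _ in range(epochs):
        iter_count, avg_loss = run_epoch(
            model, optimizer, train_ldr, logger, debug_mode=False,
            tbX_writer=tbX_writer, iter_count=iter_count, avg_loss=avg_loss,
            local_rank=local_rank, loss_name="native", save_path=save_path,
            gcs_ckpt_handler=gcs_ckpt_handler, scaler=scaler)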