def get_objf(batch: Dict,
             model: AcousticModel,
             ali_model: Optional[AcousticModel],
             P: k2.Fsa,
             device: torch.device,
             graph_compiler: MmiTrainingGraphCompiler,
             is_training: bool,
             is_update: bool,
             accum_grad: int = 1,
             den_scale: float = 1.0,
             att_rate: float = 0.0,
             tb_writer: Optional[SummaryWriter] = None,
             global_batch_idx_train: Optional[int] = None,
             optimizer: Optional[torch.optim.Optimizer] = None,
             scaler: Optional[GradScaler] = None):
    feature = batch['inputs']
    # at entry, feature is [N, T, C]
    feature = feature.permute(0, 2, 1)  # now feature is [N, C, T]
    assert feature.ndim == 3
    feature = feature.to(device)

    supervisions = batch['supervisions']
    supervision_segments, texts = encode_supervisions(supervisions)

    loss_fn = LFMMILoss(graph_compiler=graph_compiler,
                        P=P,
                        den_scale=den_scale)

    grad_context = nullcontext if is_training else torch.no_grad

    with autocast(enabled=scaler.is_enabled()), grad_context():
        nnet_output, encoder_memory, memory_mask = model(feature, supervisions)
        if att_rate != 0.0:
            att_loss = model.module.decoder_forward(encoder_memory,
                                                    memory_mask, supervisions,
                                                    graph_compiler)

        if (ali_model is not None and global_batch_idx_train is not None
                and global_batch_idx_train // accum_grad < 4000):
            with torch.no_grad():
                ali_model_output = ali_model(feature)
            # Subsampling is done slightly differently in the two models, so
            # the output lengths may differ by a few frames.
            min_len = min(ali_model_output.shape[2], nnet_output.shape[2])
            # The scale starts at 1.0 and decays towards zero, so the model is
            # encouraged to mimic ali_model's output mainly early in training.
            ali_model_scale = 500.0 / (global_batch_idx_train // accum_grad +
                                       500)
            # Clone first, otherwise backprop through the log-softmax fails
            # because of the in-place update below.
            nnet_output = nnet_output.clone()
            nnet_output[:, :, :min_len] += \
                ali_model_scale * ali_model_output[:, :, :min_len]

        # nnet_output is [N, C, T]; convert to [N, T, C]
        nnet_output = nnet_output.permute(0, 2, 1)

        mmi_loss, tot_frames, all_frames = loss_fn(nnet_output, texts,
                                                   supervision_segments)

    if is_training:

        def maybe_log_gradients(tag: str):
            if (tb_writer is not None and global_batch_idx_train is not None
                    and global_batch_idx_train % 200 == 0):
                tb_writer.add_scalars(tag,
                                      measure_gradient_norms(model, norm='l1'),
                                      global_step=global_batch_idx_train)

        if att_rate != 0.0:
            loss = (-(1.0 - att_rate) * mmi_loss +
                    att_rate * att_loss) / (len(texts) * accum_grad)
        else:
            loss = (-mmi_loss) / (len(texts) * accum_grad)
        scaler.scale(loss).backward()
        if is_update:
            maybe_log_gradients('train/grad_norms')
            scaler.unscale_(optimizer)
            clip_grad_value_(model.parameters(), 5.0)
            maybe_log_gradients('train/clipped_grad_norms')
            if tb_writer is not None and (global_batch_idx_train //
                                          accum_grad) % 200 == 0:
                # Once in a time we will perform a more costly diagnostic
                # to check the relative parameter change per minibatch.
                deltas = optim_step_and_measure_param_change(
                    model, optimizer, scaler)
                tb_writer.add_scalars(
                    'train/relative_param_change_per_minibatch',
                    deltas,
                    global_step=global_batch_idx_train)
            else:
                scaler.step(optimizer)
            optimizer.zero_grad()
            scaler.update()

    ans = (-mmi_loss.detach().cpu().item(), tot_frames.cpu().item(),
           all_frames.cpu().item())
    return ans
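The AMP pattern above generalizes beyond LF-MMI: run the forward pass under autocast, scale the loss before every backward pass, and only unscale, clip, and step the optimizer when `is_update` is true. Below is a minimal, self-contained sketch of that pattern using only plain PyTorch; `model`, `loss_fn`, and `data_loader` are placeholders, and `accum_grad` plays the same role as in `get_objf`.

import torch
from torch.cuda.amp import GradScaler, autocast


def train_with_grad_accumulation(model, loss_fn, data_loader, optimizer,
                                 accum_grad: int = 4, clip_value: float = 5.0):
    scaler = GradScaler()
    for batch_idx, (x, y) in enumerate(data_loader):
        with autocast(enabled=scaler.is_enabled()):
            # average the loss over the micro-batches of one virtual batch
            loss = loss_fn(model(x), y) / accum_grad
        # gradients from successive micro-batches accumulate in .grad
        scaler.scale(loss).backward()
        if (batch_idx + 1) % accum_grad == 0:  # the `is_update` condition
            scaler.unscale_(optimizer)  # clip against true, unscaled gradients
            torch.nn.utils.clip_grad_value_(model.parameters(), clip_value)
            scaler.step(optimizer)  # skipped internally if grads are non-finite
            scaler.update()
            optimizer.zero_grad()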
Example #2
def train_one_epoch(model: torch.nn.Module,
                    criterion: torch.nn.Module,
                    scaler: amp.GradScaler,
                    data_loader: Iterable,
                    optimizer: torch.optim.Optimizer,
                    device: torch.device,
                    epoch: int,
                    max_norm: float = 0):
    model.train()
    criterion.train()
    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter(
        'lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    metric_logger.add_meter(
        'class_error', utils.SmoothedValue(window_size=1, fmt='{value:.2f}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 10

    for samples, targets in metric_logger.log_every(data_loader, print_freq,
                                                    header):
        samples = samples.to(device)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        # forward pass under autocast; cast the outputs back to fp32 so the
        # loss computation below runs in full precision
        with amp.autocast(enabled=scaler.is_enabled()):
            outputs = model(samples)
        outputs = to_fp32(outputs) if scaler.is_enabled() else outputs
        loss_dict = criterion(outputs, targets)
        weight_dict = criterion.weight_dict
        losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys()
                     if k in weight_dict)

        # reduce losses over all GPUs for logging purposes
        loss_dict_reduced = utils.reduce_dict(loss_dict)
        loss_dict_reduced_unscaled = {
            f'{k}_unscaled': v
            for k, v in loss_dict_reduced.items()
        }
        loss_dict_reduced_scaled = {
            k: v * weight_dict[k]
            for k, v in loss_dict_reduced.items() if k in weight_dict
        }
        losses_reduced_scaled = sum(loss_dict_reduced_scaled.values())

        loss_value = losses_reduced_scaled.item()

        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            print(loss_dict_reduced)
            sys.exit(1)

        optimizer.zero_grad()
        # backward pass through the scaled loss (replaces a plain losses.backward())
        scaler.scale(losses).backward()
        if max_norm > 0:
            # unscale before clipping so max_norm applies to the true gradients
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
        # scaler.step() skips the update if any gradient is non-finite
        scaler.step(optimizer)
        scaler.update()

        metric_logger.update(loss=loss_value,
                             **loss_dict_reduced_scaled,
                             **loss_dict_reduced_unscaled)
        metric_logger.update(class_error=loss_dict_reduced['class_error'])
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])
    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
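`utils.reduce_dict` above averages the per-GPU loss dictionary across processes purely for logging. Its exact behavior is project-specific, but a minimal sketch of such a helper, assuming `torch.distributed` has already been initialized, could look like this:

import torch
import torch.distributed as dist


def reduce_loss_dict(loss_dict, average: bool = True):
    """Average (or sum) a dict of scalar tensors across processes, for logging only."""
    if not (dist.is_available() and dist.is_initialized()):
        return loss_dict
    world_size = dist.get_world_size()
    if world_size < 2:
        return loss_dict
    with torch.no_grad():
        # fixed key order so every rank reduces the same stacked tensor
        keys = sorted(loss_dict.keys())
        values = torch.stack([loss_dict[k] for k in keys])
        dist.all_reduce(values)  # element-wise sum over all processes
        if average:
            values /= world_size
        return dict(zip(keys, values))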
Example #3
File: train.py  Project: dzubke/speech-lite
def run_epoch(model,
              optimizer,
              train_ldr,
              logger,
              debug_mode: bool,
              tbX_writer,
              iter_count: int,
              avg_loss: float,
              local_rank: int,
              loss_name: str,
              save_path: str,
              gcs_ckpt_handler,
              scaler: GradScaler = None) -> tuple:
    """
    Performs a forwards and backward pass through the model
    Args:
        iter_count (int): count of iterations
        save_path (str): path to directory where model is saved
        gcs_ckpt_handler: facilitates saving files to Google Cloud Storage
        scaler (GradScaler): gradient scaler that prevents gradient underflow when
            autocast uses float16 precision for the forward pass
    Returns:
        Tuple[int, float]: the updated iteration count and the running average loss
    """
    # booleans and constants for logging
    is_rank_0 = (torch.distributed.get_rank() == 0)
    use_log = (logger is not None and is_rank_0)
    log_modulus = 100  # certain logging calls only run every log_modulus iterations
    exp_w = 0.985  # exponential weight for exponential moving average loss
    avg_grad_norm = 0
    model_t, data_t = 0.0, 0.0
    end_t = time.time()

    # progress bar for rank_0 process
    tq = tqdm.tqdm(train_ldr) if is_rank_0 else train_ldr

    # counter for model checkpointing
    batch_counter = 0
    device = torch.device("cuda:" + str(local_rank))

    # if scaler is enabled, amp is being used
    use_amp = scaler.is_enabled()
    print(f"Amp is being used: {use_amp}")

    # training loop
    for batch in tq:
        if use_log:
            logger.info(
                f"train: ====== Iteration: {iter_count} in run_epoch =======")

        ##############  Mid-epoch checkpoint ###############
        ckpt_interval = len(train_ldr) // gcs_ckpt_handler.chkpt_per_epoch
        if (is_rank_0 and batch_counter % ckpt_interval == 0
                and batch_counter != 0):
            preproc = train_ldr.dataset.preproc
            save(model.module, preproc, save_path, tag='ckpt')
            gcs_ckpt_handler.upload_to_gcs("ckpt_model_state_dict.pth")
            gcs_ckpt_handler.upload_to_gcs("ckpt_preproc.pyc")
            # save the run state
            ckpt_state_path = os.path.join(save_path, "ckpt_run_state.pickle")
            write_pickle(ckpt_state_path,
                         {'run_state': (iter_count, avg_loss)})
            gcs_ckpt_handler.upload_to_gcs("ckpt_run_state.pickle")
            # checkpoint tensorboard
            gcs_ckpt_handler.upload_tensorboard_ckpt()

        batch_counter += 1
        ####################################################

        # convert the temporary generator batch to a permanent list
        batch = list(batch)

        # save the batch information
        if use_log:
            if debug_mode:
                save_batch_log_stats(batch, logger)
                log_batchnorm_mean_std(model.module.state_dict(), logger)

        start_t = time.time()
        # set grads to None for a modest performance improvement
        optimizer.zero_grad(set_to_none=True)

        # autocast to lower precision if AMP is enabled; otherwise this is a no-op
        with autocast(enabled=use_amp):
            # unpack the batch
            inputs, labels, input_lens, label_lens = model.module.collate(
                *batch)
            inputs = inputs.cuda()  # equivalently: inputs.to(device)
            out, rnn_args = model(inputs, softmax=False)

            # use the loss function defined in `loss_name`
            if loss_name == "native":
                loss = native_loss(out, labels, input_lens, label_lens,
                                   model.module.blank)
            elif loss_name == "awni":
                loss = awni_loss(out, labels, input_lens, label_lens,
                                 model.module.blank)
            elif loss_name == "naren":
                loss = naren_loss(out, labels, input_lens, label_lens,
                                  model.module.blank)

        # backward pass
        loss = loss.cuda()  # amp needs the loss to be on cuda
        scaler.scale(loss).backward()

        if use_log:
            if debug_mode:
                plot_grad_flow_bar(model.module.named_parameters(),
                                   get_logger_filename(logger))
                log_param_grad_norms(model.module.named_parameters(), logger)

        # gradient clipping and optimizer step; unscaling is a no-op if AMP is not used
        scaler.unscale_(optimizer)
        grad_norm = nn.utils.clip_grad_norm_(model.parameters(), 200).item()
        scaler.step(optimizer)
        scaler.update()

        # logging in rank_0 process
        if is_rank_0:
            # calculate timers
            prev_end_t = end_t
            end_t = time.time()
            model_t += end_t - start_t
            data_t += start_t - prev_end_t

            # convert grad_norm and loss tensors to Python scalars for the
            # weighted averages below
            # TODO: needed with pytorch 0.4, may not be necessary anymore
            if isinstance(grad_norm, torch.Tensor):
                grad_norm = grad_norm.item()
            if isinstance(loss, torch.Tensor):
                loss = loss.item()

            # calculating the weighted average of loss and grad_norm
            if iter_count == 0:
                avg_loss = loss
                avg_grad_norm = grad_norm
            else:
                avg_loss = exp_w * avg_loss + (1 - exp_w) * loss
                avg_grad_norm = exp_w * avg_grad_norm + (1 - exp_w) * grad_norm

            # writing to the tensorboard log files
            tbX_writer.add_scalars('train/loss', {"loss": loss}, iter_count)
            tbX_writer.add_scalars('train/loss', {"avg_loss": avg_loss},
                                   iter_count)

            # adding this to suppress a tbX WARNING about inf values
            # TODO, this may or may not be a good idea as it masks inf in tensorboard
            if grad_norm == float('inf') or math.isnan(grad_norm):
                tbX_grad_norm = 1
            else:
                tbX_grad_norm = grad_norm
            tbX_writer.add_scalars('train/grad', {"grad_norm": tbX_grad_norm},
                                   iter_count)

            # progress bar update
            tq.set_postfix(it=iter_count,
                           grd_nrm=grad_norm,
                           lss=loss,
                           lss_av=avg_loss,
                           t_mdl=model_t,
                           t_data=data_t,
                           scl=scaler.get_scale())
            if use_log:
                logger.info(f'train: loss is inf: {loss == float("inf")}')
                logger.info(
                    f"train: iter={iter_count}, loss={round(loss,3)}, grad_norm={round(grad_norm,3)}"
                )

            if iter_count % log_modulus == 0 and use_log:
                log_cpu_mem_disk_usage(logger)

        # check for NaN values in parameters or gradients
        if check_nan_params_grads(model.module.parameters()):
            print("\n~~~ NaN value detected in gradients or parameters ~~~\n")
            if use_log:
                logger.error(
                    f"train: labels: {[labels]}, label_lens: {label_lens} state_dict: {model.module.state_dict()}"
                )
                log_model_grads(model.module.named_parameters(), logger)
                save_batch_log_stats(batch, logger)
                log_param_grad_norms(model.module.named_parameters(), logger)
                plot_grad_flow_bar(model.module.named_parameters(),
                                   get_logger_filename(logger))

            #debug_mode = True
            #torch.autograd.set_detect_anomaly(True)

        iter_count += 1

    return iter_count, avg_loss
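When float16 gradients overflow, `scaler.step()` silently skips the optimizer update and `scaler.update()` lowers the loss scale, which is one reason a loop like the one above can see `inf` gradient norms. If those skipped iterations should be detected (for logging, or to avoid advancing an LR scheduler), a common idiom is to compare the scale before and after the update. A minimal sketch, assuming a standard `torch.cuda.amp.GradScaler`:

from torch.cuda.amp import GradScaler


def step_and_detect_skip(scaler: GradScaler, optimizer) -> bool:
    """Run scaler.step/update and return True if the step was skipped (inf/NaN grads)."""
    scale_before = scaler.get_scale()
    scaler.step(optimizer)  # becomes a no-op when any unscaled gradient is non-finite
    scaler.update()  # the scale is reduced after an overflow
    return scaler.get_scale() < scale_before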