Example 1
def get_objf(batch: Dict,
             model: AcousticModel,
             P: k2.Fsa,
             device: torch.device,
             graph_compiler: MmiTrainingGraphCompiler,
             is_training: bool,
             den_scale: float = 1.0,
             tb_writer: Optional[SummaryWriter] = None,
             global_batch_idx_train: Optional[int] = None,
             optimizer: Optional[torch.optim.Optimizer] = None):
    feature = batch['inputs']
    # at entry, feature is [N, T, C]
    feature = feature.permute(0, 2, 1)  # now feature is [N, C, T]
    assert feature.ndim == 3
    feature = feature.to(device)

    supervisions = batch['supervisions']
    supervision_segments, texts = encode_supervisions(supervisions,
                                                      model.subsampling_factor)

    loss_fn = LFMMILoss(graph_compiler=graph_compiler,
                        P=P,
                        den_scale=den_scale)

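    # Keep autograd enabled during training; for validation wrap the
    # forward pass in torch.no_grad() so no graph is built.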
    grad_context = nullcontext if is_training else torch.no_grad

    with grad_context():
        nnet_output = model(feature)
        # nnet_output is [N, C, T]; permute to [N, T, C]
        nnet_output = nnet_output.permute(0, 2, 1)
        mmi_loss, tot_frames, all_frames = loss_fn(nnet_output, texts,
                                                   supervision_segments)

    if is_training:

        def maybe_log_gradients(tag: str):
            if (tb_writer is not None and global_batch_idx_train is not None
                    and global_batch_idx_train % 200 == 0):
                tb_writer.add_scalars(tag,
                                      measure_gradient_norms(model, norm='l1'),
                                      global_step=global_batch_idx_train)

        optimizer.zero_grad()
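        # The MMI objective is maximized, so negate it to obtain a loss
        # for the optimizer to minimize.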
        (-mmi_loss).backward()
        maybe_log_gradients('train/grad_norms')
        clip_grad_value_(model.parameters(), 5.0)
        maybe_log_gradients('train/clipped_grad_norms')
        if (tb_writer is not None and global_batch_idx_train is not None
                and global_batch_idx_train % 200 == 0):
            # From time to time we perform a more costly diagnostic
            # to check the relative parameter change per minibatch.
            deltas = optim_step_and_measure_param_change(model, optimizer)
            tb_writer.add_scalars('train/relative_param_change_per_minibatch',
                                  deltas,
                                  global_step=global_batch_idx_train)
        else:
            optimizer.step()

    ans = (-mmi_loss.detach().cpu().item(), tot_frames.cpu().item(),
           all_frames.cpu().item())
    return ans
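
For orientation, here is a minimal sketch of how this first variant might be driven from a training loop. Only the get_objf call mirrors the signature above; the dataloader argument, the function name, and the per-epoch bookkeeping are illustrative assumptions.

def train_one_epoch_sketch(dataloader, model, P, device, graph_compiler,
                           optimizer, tb_writer, global_batch_idx_train):
    # Accumulate the negated MMI objective and the frame counts so an
    # average objective per frame can be reported at the end of the epoch.
    total_objf = 0.0
    total_frames = 0.0
    for batch_idx, batch in enumerate(dataloader):
        objf, frames, _all_frames = get_objf(
            batch=batch,
            model=model,
            P=P,
            device=device,
            graph_compiler=graph_compiler,
            is_training=True,
            tb_writer=tb_writer,
            global_batch_idx_train=global_batch_idx_train + batch_idx,
            optimizer=optimizer)
        total_objf += objf
        total_frames += frames
    return total_objf / max(total_frames, 1.0)
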
Example 2
def get_objf(batch: Dict,
             model: AcousticModel,
             device: torch.device,
             graph_compiler: CtcTrainingGraphCompiler,
             is_training: bool,
             is_update: bool,
             accum_grad: int = 1,
             att_rate: float = 0.0,
             tb_writer: Optional[SummaryWriter] = None,
             global_batch_idx_train: Optional[int] = None,
             optimizer: Optional[torch.optim.Optimizer] = None):
    feature = batch['inputs']
    feature = feature.to(device)
    # at entry, feature is [N, T, C]
    feature = feature.permute(0, 2, 1)  # now feature is [N, C, T]
    supervisions = batch['supervisions']
    supervision_segments, texts = encode_supervisions(supervisions)

    loss_fn = CTCLoss(graph_compiler)
    grad_context = nullcontext if is_training else torch.no_grad

    with grad_context():
        nnet_output, encoder_memory, memory_mask = model(feature, supervisions)
        if att_rate != 0.0:
            att_loss = model.decoder_forward(encoder_memory, memory_mask,
                                             supervisions, graph_compiler)

        # nnet_output is [N, C, T]; permute to [N, T, C]
        nnet_output = nnet_output.permute(0, 2, 1)
        tot_score, tot_frames, all_frames = loss_fn(nnet_output, texts,
                                                    supervision_segments)

    if is_training:

        def maybe_log_gradients(tag: str):
            if (tb_writer is not None and global_batch_idx_train is not None
                    and global_batch_idx_train % 200 == 0):
                tb_writer.add_scalars(tag,
                                      measure_gradient_norms(model, norm='l1'),
                                      global_step=global_batch_idx_train)

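        # Normalize by the number of utterances and by accum_grad so that
        # gradients accumulated over accum_grad micro-batches match the
        # scale of a single larger batch.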
        if att_rate != 0.0:
            loss = (-(1.0 - att_rate) * tot_score +
                    att_rate * att_loss) / (len(texts) * accum_grad)
        else:
            loss = (-tot_score) / (len(texts) * accum_grad)
        loss.backward()
        if is_update:
            maybe_log_gradients('train/grad_norms')
            clip_grad_value_(model.parameters(), 5.0)
            maybe_log_gradients('train/clipped_grad_norms')
            if (tb_writer is not None and global_batch_idx_train is not None
                    and (global_batch_idx_train // accum_grad) % 200 == 0):
                # From time to time we perform a more costly diagnostic
                # to check the relative parameter change per minibatch.
                deltas = optim_step_and_measure_param_change(model, optimizer)
                tb_writer.add_scalars(
                    'train/relative_param_change_per_minibatch',
                    deltas,
                    global_step=global_batch_idx_train)
            else:
                optimizer.step()
            optimizer.zero_grad()

    ans = (-tot_score.detach().cpu().item(), tot_frames.cpu().item(),
           all_frames.cpu().item())
    return ans
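
A hedged sketch of how the accum_grad / is_update pair of this CTC variant might be driven: gradients are accumulated over accum_grad micro-batches and the optimizer is stepped only on the last one. Everything except the get_objf call itself is an illustrative assumption.

def run_with_grad_accumulation_sketch(dataloader, model, device,
                                      graph_compiler, optimizer,
                                      accum_grad: int = 2):
    for batch_idx, batch in enumerate(dataloader):
        # get_objf divides the loss by accum_grad, so stepping only every
        # accum_grad-th micro-batch yields gradients on the same scale as
        # a single larger batch.
        is_update = (batch_idx + 1) % accum_grad == 0
        get_objf(batch=batch,
                 model=model,
                 device=device,
                 graph_compiler=graph_compiler,
                 is_training=True,
                 is_update=is_update,
                 accum_grad=accum_grad,
                 global_batch_idx_train=batch_idx,
                 optimizer=optimizer)
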
Example 3
def get_objf(batch: Dict,
             model: AcousticModel,
             ali_model: Optional[AcousticModel],
             P: k2.Fsa,
             device: torch.device,
             graph_compiler: MmiTrainingGraphCompiler,
             is_training: bool,
             is_update: bool,
             accum_grad: int = 1,
             den_scale: float = 1.0,
             att_rate: float = 0.0,
             tb_writer: Optional[SummaryWriter] = None,
             global_batch_idx_train: Optional[int] = None,
             optimizer: Optional[torch.optim.Optimizer] = None,
             scaler: Optional[GradScaler] = None):
    feature = batch['inputs']
    # at entry, feature is [N, T, C]
    feature = feature.permute(0, 2, 1)  # now feature is [N, C, T]
    assert feature.ndim == 3
    feature = feature.to(device)

    supervisions = batch['supervisions']
    supervision_segments, texts = encode_supervisions(supervisions)

    loss_fn = LFMMILoss(graph_compiler=graph_compiler,
                        P=P,
                        den_scale=den_scale)

    grad_context = nullcontext if is_training else torch.no_grad

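    # Run the forward pass in mixed precision when the GradScaler is
    # enabled; grad_context() additionally disables autograd for validation.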
    with autocast(enabled=scaler.is_enabled()), grad_context():
        nnet_output, encoder_memory, memory_mask = model(feature, supervisions)
        if att_rate != 0.0:
            att_loss = model.module.decoder_forward(encoder_memory,
                                                    memory_mask, supervisions,
                                                    graph_compiler)

        if (ali_model is not None and global_batch_idx_train is not None
                and global_batch_idx_train // accum_grad < 4000):
            with torch.no_grad():
                ali_model_output = ali_model(feature)
            # subsampling is done slightly differently, may be small length
            # differences.
            min_len = min(ali_model_output.shape[2], nnet_output.shape[2])
            # scale less than one so it will be encouraged
            # to mimic ali_model's output
            ali_model_scale = 500.0 / (global_batch_idx_train // accum_grad +
                                       500)
            # Clone nnet_output, or backprop through the log-softmax will fail.
            nnet_output = nnet_output.clone()
            nnet_output[:, :, :min_len] += \
                ali_model_scale * ali_model_output[:, :, :min_len]

        # nnet_output is [N, C, T]; permute to [N, T, C]
        nnet_output = nnet_output.permute(0, 2, 1)

        mmi_loss, tot_frames, all_frames = loss_fn(nnet_output, texts,
                                                   supervision_segments)

    if is_training:

        def maybe_log_gradients(tag: str):
            if (tb_writer is not None and global_batch_idx_train is not None
                    and global_batch_idx_train % 200 == 0):
                tb_writer.add_scalars(tag,
                                      measure_gradient_norms(model, norm='l1'),
                                      global_step=global_batch_idx_train)

        if att_rate != 0.0:
            loss = (-(1.0 - att_rate) * mmi_loss +
                    att_rate * att_loss) / (len(texts) * accum_grad)
        else:
            loss = (-mmi_loss) / (len(texts) * accum_grad)
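        # Scale the loss before backward so that small fp16 gradients do
        # not underflow.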
        scaler.scale(loss).backward()
        if is_update:
            maybe_log_gradients('train/grad_norms')
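            # Unscale the gradients before clipping so the threshold is
            # applied to their true values.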
            scaler.unscale_(optimizer)
            clip_grad_value_(model.parameters(), 5.0)
            maybe_log_gradients('train/clipped_grad_norms')
            if (tb_writer is not None and global_batch_idx_train is not None
                    and (global_batch_idx_train // accum_grad) % 200 == 0):
                # From time to time we perform a more costly diagnostic
                # to check the relative parameter change per minibatch.
                deltas = optim_step_and_measure_param_change(
                    model, optimizer, scaler)
                tb_writer.add_scalars(
                    'train/relative_param_change_per_minibatch',
                    deltas,
                    global_step=global_batch_idx_train)
            else:
                scaler.step(optimizer)
            optimizer.zero_grad()
            scaler.update()

    ans = (-mmi_loss.detach().cpu().item(), tot_frames.cpu().item(),
           all_frames.cpu().item())
    return ans
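
This third variant assumes an AMP GradScaler is supplied. A minimal setup sketch; the use_fp16 flag is a hypothetical configuration switch, not part of the example above.

from torch.cuda.amp import GradScaler

# With enabled=False the scaler (and the autocast context inside get_objf)
# becomes a no-op, so the same code path runs in full precision.
use_fp16 = True  # hypothetical flag
scaler = GradScaler(enabled=use_fp16)
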
Example 4
def get_objf(batch: Dict,
             model: AcousticModel,
             P: k2.Fsa,
             device: torch.device,
             graph_compiler: MmiTrainingGraphCompiler,
             is_training: bool,
             den_scale: float = 1.0,
             tb_writer: Optional[SummaryWriter] = None,
             global_batch_idx_train: Optional[int] = None,
             optimizer: Optional[torch.optim.Optimizer] = None):
    feature = batch['features']
    supervisions = batch['supervisions']
    if isinstance(model, DDP):
        subsampling_factor = model.module.subsampling_factor
    else:
        subsampling_factor = model.subsampling_factor
    supervision_segments = torch.stack(
        (supervisions['sequence_idx'],
         torch.floor_divide(supervisions['start_frame'], subsampling_factor),
         torch.floor_divide(supervisions['num_frames'], subsampling_factor)),
        1).to(torch.int32)
    indices = torch.argsort(supervision_segments[:, 2], descending=True)
    supervision_segments = supervision_segments[indices]

    texts = supervisions['text']
    texts = [texts[idx] for idx in indices]
    assert feature.ndim == 3
    # print(supervision_segments[:, 1] + supervision_segments[:, 2])

    feature = feature.to(device)
    # at entry, feature is [N, T, C]
    feature = feature.permute(0, 2, 1)  # now feature is [N, C, T]
    if is_training:
        nnet_output = model(feature)
    else:
        with torch.no_grad():
            nnet_output = model(feature)

    # nnet_output is [N, C, T]
    nnet_output = nnet_output.permute(0, 2, 1)  # now nnet_output is [N, T, C]

    if is_training:
        num, den = graph_compiler.compile(texts, P)
    else:
        with torch.no_grad():
            num, den = graph_compiler.compile(texts, P)

    assert num.requires_grad == is_training
    assert den.requires_grad is False
    num = num.to(device)
    den = den.to(device)

    # nnet_output2 = nnet_output.clone()
    # blank_bias = -7.0
    # nnet_output2[:,:,0] += blank_bias

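    # Wrap the network log-probs as a dense FSA vector so k2 can intersect
    # them with the numerator and denominator graphs (output beam 10.0).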
    dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)
    assert nnet_output.device == device

    num = k2.intersect_dense(num, dense_fsa_vec, 10.0)
    den = k2.intersect_dense(den, dense_fsa_vec, 10.0)

    num_tot_scores = num.get_tot_scores(log_semiring=True,
                                        use_double_scores=True)
    den_tot_scores = den.get_tot_scores(log_semiring=True,
                                        use_double_scores=True)
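    # LF-MMI objective per utterance: numerator (reference-constrained)
    # log-likelihood minus the scaled denominator (all-paths) log-likelihood.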
    tot_scores = num_tot_scores - den_scale * den_tot_scores

    (tot_score, tot_frames,
     all_frames) = get_tot_objf_and_num_frames(tot_scores,
                                               supervision_segments[:, 2])

    if is_training:

        def maybe_log_gradients(tag: str):
            if (tb_writer is not None and global_batch_idx_train is not None
                    and global_batch_idx_train % 200 == 0):
                tb_writer.add_scalars(tag,
                                      measure_gradient_norms(model, norm='l1'),
                                      global_step=global_batch_idx_train)

        optimizer.zero_grad()
        (-tot_score).backward()
        maybe_log_gradients('train/grad_norms')
        clip_grad_value_(model.parameters(), 5.0)
        maybe_log_gradients('train/clipped_grad_norms')
        if (tb_writer is not None and global_batch_idx_train is not None
                and global_batch_idx_train % 200 == 0):
            # From time to time we perform a more costly diagnostic
            # to check the relative parameter change per minibatch.
            deltas = optim_step_and_measure_param_change(model, optimizer)
            tb_writer.add_scalars('train/relative_param_change_per_minibatch',
                                  deltas,
                                  global_step=global_batch_idx_train)
        else:
            optimizer.step()

    ans = (-tot_score.detach().cpu().item(), tot_frames.cpu().item(),
           all_frames.cpu().item())
    return ans
Example 5
def get_objf(batch: Dict,
             model: AcousticModel,
             device: torch.device,
             graph_compiler: MmiTrainingGraphCompiler,
             is_training: bool,
             den_scale: float = 1.0,
             tb_writer: Optional[SummaryWriter] = None,
             global_batch_idx_train: Optional[int] = None,
             optimizer: Optional[torch.optim.Optimizer] = None):
    feature = batch['inputs']
    # at entry, feature is [N, T, C]
    assert feature.ndim == 3
    feature = feature.to(device)
    feature = feature.permute(0, 2, 1)  # now feature is [N, C, T]

    # The model may be wrapped in DistributedDataParallel, in which case
    # its attributes live on model.module.
    try:
        subsampling_factor = model.subsampling_factor
    except AttributeError:
        subsampling_factor = model.module.subsampling_factor

    supervisions = batch['supervisions']
    supervision_segments, texts = encode_supervisions(supervisions,
                                                      subsampling_factor)

    loss_fn = LFMMILoss(graph_compiler=graph_compiler, den_scale=den_scale)

    grad_context = nullcontext if is_training else torch.no_grad

    with grad_context():
        nnet_output = model(feature)
        # nnet_output is [N, C, T]; permute to [N, T, C]
        nnet_output = nnet_output.permute(0, 2, 1)
        mmi_loss, tot_frames, all_frames = loss_fn(nnet_output, texts,
                                                   supervision_segments)

    if is_training:

        def maybe_log_gradients(tag: str):
            if (tb_writer is not None and global_batch_idx_train is not None
                    and global_batch_idx_train % 200 == 0):
                tb_writer.add_scalars(tag,
                                      measure_gradient_norms(model, norm='l1'),
                                      global_step=global_batch_idx_train)

        optimizer.zero_grad()
        (-mmi_loss).backward()

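        # Debugging aid: report any parameter that did not receive a gradient.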
        for name, param in model.named_parameters():
            if param.grad is None:
                print(name)

        maybe_log_gradients('train/grad_norms')
        #clip_grad_value_(model.parameters(), 5.0)
        clip_grad_norm_(model.parameters(), max_norm=5.0, norm_type=2.0)
        maybe_log_gradients('train/clipped_grad_norms')
        if (tb_writer is not None and global_batch_idx_train is not None
                and global_batch_idx_train % 200 == 0):
            # From time to time we perform a more costly diagnostic
            # to check the relative parameter change per minibatch.
            deltas = optim_step_and_measure_param_change(model, optimizer)
            tb_writer.add_scalars('train/relative_param_change_per_minibatch',
                                  deltas,
                                  global_step=global_batch_idx_train)
        else:
            optimizer.step()

    ans = (-mmi_loss.detach().cpu().item(), tot_frames.cpu().item(),
           all_frames.cpu().item())
    return ans