Example #1
0
def get_validation_loss(dataloader: torch.utils.data.DataLoader,
                        model: AcousticModel, P: k2.Fsa, device: torch.device,
                        graph_compiler: MmiMbrTrainingGraphCompiler):
    """Sum the MMI and MBR losses returned by ``get_loss`` over a dataloader.

    The model is put into eval mode first and is NOT switched back before
    returning; the caller must call ``model.train()`` itself if needed.

    Args:
        dataloader: source of validation batches; each batch is handed to
            ``get_loss`` unchanged.
        model: acoustic model evaluated by ``get_loss``.
        P: forwarded unchanged to ``get_loss``.
        device: forwarded unchanged to ``get_loss``.
        graph_compiler: forwarded unchanged to ``get_loss``.

    Returns:
        ``(total_loss, total_mmi_loss, total_mbr_loss, total_frames,
        total_all_frames)`` where ``total_loss`` accumulates
        ``mmi_loss + mbr_loss`` batch by batch and the other entries are the
        per-batch values returned by ``get_loss`` summed over all batches.
    """
    model.eval()

    combined_sum = mmi_sum = mbr_sum = 0.
    # for display only
    # NOTE(review): presumably only frames of sequences that succeeded; confirm in get_loss.
    counted_frames = 0.
    # all frames including those seqs that failed.
    seen_frames = 0.

    for batch in dataloader:
        batch_mmi, batch_mbr, batch_frames, batch_all_frames = get_loss(
            batch=batch,
            model=model,
            P=P,
            device=device,
            graph_compiler=graph_compiler,
            is_training=False,
        )
        combined_sum += batch_mmi + batch_mbr
        mmi_sum += batch_mmi
        mbr_sum += batch_mbr
        counted_frames += batch_frames
        seen_frames += batch_all_frames

    return combined_sum, mmi_sum, mbr_sum, counted_frames, seen_frames
def get_validation_objf(dataloader: torch.utils.data.DataLoader,
                        model: AcousticModel,
                        P: k2.Fsa,
                        device: torch.device,
                        graph_compiler: MmiTrainingGraphCompiler,
                        den_scale: float = 1):
    """Sum the objective and frame counts returned by ``get_objf`` over a dataloader.

    The model is put into eval mode first and is left in eval mode on
    return. Every batch is evaluated with ``is_training=False`` and
    ``is_update=False``.

    Args:
        dataloader: source of validation batches.
        model: acoustic model evaluated by ``get_objf``.
        P: forwarded unchanged to ``get_objf``.
        device: forwarded unchanged to ``get_objf``.
        graph_compiler: forwarded unchanged to ``get_objf``.
        den_scale: forwarded unchanged to ``get_objf`` (default 1).

    Returns:
        ``(total_objf, total_frames, total_all_frames)``: the per-batch
        values returned by ``get_objf`` summed over all batches.
    """
    model.eval()

    # Arguments that are identical for every batch.
    shared_kwargs = dict(model=model,
                         P=P,
                         device=device,
                         graph_compiler=graph_compiler,
                         is_training=False,
                         is_update=False,
                         den_scale=den_scale)

    objf_sum = 0.
    counted_frames = 0.  # for display only
    seen_frames = 0.  # all frames including those seqs that failed.

    for batch in dataloader:
        batch_objf, batch_frames, batch_all_frames = get_objf(batch=batch,
                                                              **shared_kwargs)
        objf_sum += batch_objf
        counted_frames += batch_frames
        seen_frames += batch_all_frames

    return objf_sum, counted_frames, seen_frames
def get_validation_objf(
    dataloader: torch.utils.data.DataLoader,
    model: AcousticModel,
    ali_model: Optional[AcousticModel],
    P: k2.Fsa,
    device: torch.device,
    graph_compiler: MmiTrainingGraphCompiler,
    scaler: GradScaler,
    den_scale: float = 1,
):
    """Sum the objective and frame counts returned by ``get_objf`` over a dataloader.

    Batches are pulled through torchaudio's ``bg_iterator`` (maxsize 2) so
    loading can overlap with evaluation. The model is put into eval mode
    first and is left in eval mode on return. Every batch is evaluated with
    ``is_training=False`` and ``is_update=False``.

    Args:
        dataloader: source of validation batches.
        model: acoustic model evaluated by ``get_objf``.
        ali_model: optional model forwarded unchanged to ``get_objf``.
        P: forwarded unchanged to ``get_objf``.
        device: forwarded unchanged to ``get_objf``.
        graph_compiler: forwarded unchanged to ``get_objf``.
        scaler: forwarded unchanged to ``get_objf``. Whether it is used when
            ``is_training=False`` is decided there.
        den_scale: forwarded unchanged to ``get_objf`` (default 1).

    Returns:
        ``(total_objf, total_frames, total_all_frames)``: the per-batch
        values returned by ``get_objf`` summed over all batches.
    """
    model.eval()

    # Arguments that are identical for every batch.
    shared_kwargs = dict(
        model=model,
        ali_model=ali_model,
        P=P,
        device=device,
        graph_compiler=graph_compiler,
        is_training=False,
        is_update=False,
        den_scale=den_scale,
        scaler=scaler,
    )

    objf_sum = 0.
    counted_frames = 0.  # for display only
    seen_frames = 0.  # all frames including those seqs that failed.

    # NOTE(review): bg_iterator may not exist in newer torchaudio releases;
    # confirm against the pinned torchaudio version.
    from torchaudio.datasets.utils import bg_iterator
    prefetched = bg_iterator(dataloader, 2)

    for batch in prefetched:
        batch_objf, batch_frames, batch_all_frames = get_objf(batch=batch,
                                                              **shared_kwargs)
        objf_sum += batch_objf
        counted_frames += batch_frames
        seen_frames += batch_all_frames

    return objf_sum, counted_frames, seen_frames
Example #4
0
def get_validation_objf(
    dataloader: torch.utils.data.DataLoader,
    model: AcousticModel,
    device: torch.device,
):
    """Sum the objective and frame count returned by ``get_objf`` over a dataloader.

    The model is put into eval mode first and is left in eval mode on
    return. ``get_objf`` is called positionally as
    ``get_objf(batch, model, device, False)`` for every batch.

    Args:
        dataloader: source of validation batches.
        model: acoustic model evaluated by ``get_objf``.
        device: forwarded unchanged to ``get_objf``.

    Returns:
        ``(total_objf, total_frames)``: the per-batch values returned by
        ``get_objf`` summed over all batches.
    """
    model.eval()

    objf_sum, frame_sum = 0.0, 0.0  # frame_sum is for display only

    for batch in dataloader:
        batch_objf, batch_frames = get_objf(batch, model, device, False)
        objf_sum += batch_objf
        frame_sum += batch_frames

    return objf_sum, frame_sum
Example #5
0
def get_validation_objf(dataloader: torch.utils.data.DataLoader,
                        model: AcousticModel, device: torch.device,
                        graph_compiler: CtcTrainingGraphCompiler):
    """Sum the objective and frame counts returned by ``get_objf`` over a dataloader.

    The model is put into eval mode first and is left in eval mode on
    return. ``get_objf`` is called positionally as
    ``get_objf(batch, model, device, graph_compiler, False)`` for every batch.

    Args:
        dataloader: source of validation batches.
        model: acoustic model evaluated by ``get_objf``.
        device: forwarded unchanged to ``get_objf``.
        graph_compiler: forwarded unchanged to ``get_objf``.

    Returns:
        ``(total_objf, total_frames, total_all_frames)``: the per-batch
        values returned by ``get_objf`` summed over all batches.
    """
    model.eval()

    objf_sum = 0.
    counted_frames = 0.  # for display only
    seen_frames = 0.  # all frames including those seqs that failed.

    for batch in dataloader:
        batch_objf, batch_frames, batch_all_frames = get_objf(
            batch, model, device, graph_compiler, False)
        objf_sum += batch_objf
        counted_frames += batch_frames
        seen_frames += batch_all_frames

    return objf_sum, counted_frames, seen_frames