Code example #1
def get_validation_objf(dataloader: torch.utils.data.DataLoader,
                        model: AcousticModel,
                        P: k2.Fsa,
                        device: torch.device,
                        graph_compiler: MmiTrainingGraphCompiler,
                        den_scale: float = 1):
    total_objf = 0.
    total_frames = 0.  # for display only
    total_all_frames = 0.  # all frames including those seqs that failed.

    model.eval()

    for batch_idx, batch in enumerate(dataloader):
        objf, frames, all_frames = get_objf(batch=batch,
                                            model=model,
                                            P=P,
                                            device=device,
                                            graph_compiler=graph_compiler,
                                            is_training=False,
                                            is_update=False,
                                            den_scale=den_scale)
        total_objf += objf
        total_frames += frames
        total_all_frames += all_frames

    return total_objf, total_frames, total_all_frames
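
In a training script, the totals returned here are reduced to an average objective for logging, as the train_one_epoch examples below do. A minimal usage sketch, assuming the same model, P, device and graph_compiler objects and a valid_dataloader are in scope:

valid_objf, valid_frames, valid_all_frames = get_validation_objf(
    dataloader=valid_dataloader,
    model=model,
    P=P,
    device=device,
    graph_compiler=graph_compiler)
model.train()  # get_validation_objf left the model in eval mode
logging.info('Validation average objf: {:.6f} over {} frames ({:.1f}% kept)'
             .format(valid_objf / valid_frames, valid_frames,
                     100.0 * valid_frames / valid_all_frames))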
Code example #2
File: mmi_mbr_train.py, Project: danpovey/snowfall
def get_validation_loss(dataloader: torch.utils.data.DataLoader,
                        model: AcousticModel, P: k2.Fsa, device: torch.device,
                        graph_compiler: MmiMbrTrainingGraphCompiler):
    total_loss = 0.
    total_mmi_loss = 0.
    total_mbr_loss = 0.
    total_frames = 0.  # for display only
    total_all_frames = 0.  # all frames including those seqs that failed.

    model.eval()

    for batch_idx, batch in enumerate(dataloader):
        mmi_loss, mbr_loss, frames, all_frames = get_loss(
            batch=batch,
            model=model,
            P=P,
            device=device,
            graph_compiler=graph_compiler,
            is_training=False)
        cur_loss = mmi_loss + mbr_loss
        total_loss += cur_loss
        total_mmi_loss += mmi_loss
        total_mbr_loss += mbr_loss
        total_frames += frames
        total_all_frames += all_frames

    return total_loss, total_mmi_loss, total_mbr_loss, total_frames, total_all_frames
Code example #3
def get_validation_objf(
    dataloader: torch.utils.data.DataLoader,
    model: AcousticModel,
    ali_model: Optional[AcousticModel],
    P: k2.Fsa,
    device: torch.device,
    graph_compiler: MmiTrainingGraphCompiler,
    scaler: GradScaler,
    den_scale: float = 1,
):
    total_objf = 0.
    total_frames = 0.  # for display only
    total_all_frames = 0.  # all frames including those seqs that failed.

    model.eval()

    from torchaudio.datasets.utils import bg_iterator

    # bg_iterator prefetches up to 2 batches in a background thread so that
    # data loading overlaps with the forward pass.
    for batch_idx, batch in enumerate(bg_iterator(dataloader, 2)):
        objf, frames, all_frames = get_objf(batch=batch,
                                            model=model,
                                            ali_model=ali_model,
                                            P=P,
                                            device=device,
                                            graph_compiler=graph_compiler,
                                            is_training=False,
                                            is_update=False,
                                            den_scale=den_scale,
                                            scaler=scaler)
        total_objf += objf
        total_frames += frames
        total_all_frames += all_frames

    return total_objf, total_frames, total_all_frames
Code example #4
def load_checkpoint(filename: Pathlike,
                    model: AcousticModel) -> Dict[str, Any]:
    logging.info('load checkpoint from {}'.format(filename))

    checkpoint = torch.load(filename, map_location='cpu')

    keys = [
        'state_dict', 'epoch', 'learning_rate', 'objf', 'valid_objf',
        'num_features', 'num_classes', 'subsampling_factor',
        'global_batch_idx_train'
    ]
    missing_keys = set(keys) - set(checkpoint.keys())
    if missing_keys:
        raise ValueError(f"Missing keys in checkpoint: {missing_keys}")

    if not list(model.state_dict().keys())[0].startswith('module.') \
            and list(checkpoint['state_dict'])[0].startswith('module.'):
        # the checkpoint was saved by DDP
        logging.info('load checkpoint from DDP')
        dst_state_dict = model.state_dict()
        src_state_dict = checkpoint['state_dict']
        for key in dst_state_dict.keys():
            src_key = '{}.{}'.format('module', key)
            dst_state_dict[key] = src_state_dict.pop(src_key)
        assert len(src_state_dict) == 0
        model.load_state_dict(dst_state_dict)
    else:
        model.load_state_dict(checkpoint['state_dict'])

    model.num_features = checkpoint['num_features']
    model.num_classes = checkpoint['num_classes']
    model.subsampling_factor = checkpoint['subsampling_factor']

    return checkpoint
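
The branch above remaps parameter names when the checkpoint was saved from a DistributedDataParallel (DDP) model, whose keys carry a 'module.' prefix. The same idea as a self-contained helper (the function name is ours, not snowfall's):

from typing import Dict
import torch

def strip_ddp_prefix(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    # Drop the 'module.' prefix that DDP adds to every parameter/buffer name.
    prefix = 'module.'
    return {(k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in state_dict.items()}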
Code example #5
File: train.py, Project: juxiangyu/snowfall
def train_one_epoch(dataloader: torch.utils.data.DataLoader,
                    valid_dataloader: torch.utils.data.DataLoader,
                    model: AcousticModel, device: torch.device,
                    graph_compiler: TrainingGraphCompiler,
                    optimizer: torch.optim.Optimizer, current_epoch: int,
                    num_epochs: int):
    total_objf, total_frames, total_all_frames = 0., 0., 0.
    time_waiting_for_batch = 0
    prev_timestamp = datetime.now()

    model.train()
    for batch_idx, batch in enumerate(dataloader):
        timestamp = datetime.now()
        time_waiting_for_batch += (timestamp - prev_timestamp).total_seconds()
        curr_batch_objf, curr_batch_frames, curr_batch_all_frames = \
            get_objf(batch, model, device, graph_compiler, True, optimizer)

        total_objf += curr_batch_objf
        total_frames += curr_batch_frames
        total_all_frames += curr_batch_all_frames

        if batch_idx % 10 == 0:
            logging.info(
                'batch {}, epoch {}/{} '
                'global average objf: {:.6f} over {} '
                'frames ({:.1f}% kept), current batch average objf: {:.6f} over {} frames ({:.1f}% kept) '
                'avg time waiting for batch {:.3f}s'.format(
                    batch_idx, current_epoch, num_epochs,
                    total_objf / total_frames, total_frames,
                    100.0 * total_frames / total_all_frames,
                    curr_batch_objf / (curr_batch_frames + 0.001),
                    curr_batch_frames,
                    100.0 * curr_batch_frames / curr_batch_all_frames,
                    time_waiting_for_batch / max(1, batch_idx)))
            # if batch_idx >= 10:
            #    print("Exiting early to get profile info")
            #    sys.exit(0)

        if batch_idx > 0 and batch_idx % 200 == 0:
            total_valid_objf, total_valid_frames, total_valid_all_frames = get_validation_objf(
                dataloader=valid_dataloader,
                model=model,
                device=device,
                graph_compiler=graph_compiler)
            model.train()
            logging.info(
                'Validation average objf: {:.6f} over {} frames ({:.1f}% kept)'
                .format(total_valid_objf / total_valid_frames,
                        total_valid_frames,
                        100.0 * total_valid_frames / total_valid_all_frames))
        prev_timestamp = datetime.now()
    return total_objf
Code example #6
def load_checkpoint(
    filename: Pathlike,
    model: AcousticModel,
    optimizer: Optional[object] = None,
    scheduler: Optional[object] = None,
    scaler: Optional[GradScaler] = None,
) -> Dict[str, Any]:
    logging.info('load checkpoint from {}'.format(filename))

    checkpoint = torch.load(filename, map_location='cpu')

    keys = [
        'state_dict', 'optimizer', 'scheduler', 'epoch', 'learning_rate',
        'objf', 'valid_objf', 'num_features', 'num_classes',
        'subsampling_factor', 'global_batch_idx_train'
    ]
    missing_keys = set(keys) - set(checkpoint.keys())
    if missing_keys:
        raise ValueError(f"Missing keys in checkpoint: {missing_keys}")

    if isinstance(model, DistributedDataParallel):
        model = model.module

    if not list(model.state_dict().keys())[0].startswith('module.') \
            and list(checkpoint['state_dict'])[0].startswith('module.'):
        # the checkpoint was saved by DDP
        logging.info('load checkpoint from DDP')
        dst_state_dict = model.state_dict()
        src_state_dict = checkpoint['state_dict']
        for key in dst_state_dict.keys():
            src_key = '{}.{}'.format('module', key)
            dst_state_dict[key] = src_state_dict.pop(src_key)
        assert len(src_state_dict) == 0
        model.load_state_dict(dst_state_dict, strict=False)
    else:
        model.load_state_dict(checkpoint['state_dict'], strict=False)
    # Note we used strict=False above so that the current code
    # can load models trained with P_scores.

    model.num_features = checkpoint['num_features']
    model.num_classes = checkpoint['num_classes']
    model.subsampling_factor = checkpoint['subsampling_factor']

    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer'])

    if scheduler is not None:
        scheduler.load_state_dict(checkpoint['scheduler'])

    if scaler is not None:
        scaler.load_state_dict(checkpoint['grad_scaler'])

    return checkpoint
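
A hedged sketch of resuming training with this variant; the checkpoint path is illustrative and assumes optimizer, scheduler and scaler were created beforehand:

checkpoint = load_checkpoint('exp/epoch-10.pt', model,
                             optimizer=optimizer,
                             scheduler=scheduler,
                             scaler=scaler)
start_epoch = checkpoint['epoch'] + 1  # continue from the following epoch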
Code example #7
def average_checkpoint(filenames: List[Pathlike],
                       model: AcousticModel) -> Dict[str, Any]:
    logging.info('average over checkpoints {}'.format(filenames))

    avg_model = None

    # sum
    for filename in filenames:
        checkpoint = torch.load(filename, map_location='cpu')
        checkpoint_model = checkpoint['state_dict']
        if avg_model is None:
            avg_model = checkpoint_model
        else:
            for k in avg_model.keys():
                avg_model[k] += checkpoint_model[k]
    # average
    for k in avg_model.keys():
        if avg_model[k] is not None:
            if avg_model[k].is_floating_point():
                avg_model[k] /= len(filenames)
            else:
                # integer tensors (e.g. batch-norm step counters) need
                # integer division
                avg_model[k] //= len(filenames)

    # `checkpoint` still holds the metadata of the last file in `filenames`;
    # only its 'state_dict' is replaced with the average.
    checkpoint['state_dict'] = avg_model

    keys = [
        'state_dict', 'optimizer', 'scheduler', 'epoch', 'learning_rate',
        'objf', 'valid_objf', 'num_features', 'num_classes',
        'subsampling_factor', 'global_batch_idx_train'
    ]
    missing_keys = set(keys) - set(checkpoint.keys())
    if missing_keys:
        raise ValueError(f"Missing keys in checkpoint: {missing_keys}")

    if not list(model.state_dict().keys())[0].startswith('module.') \
            and list(checkpoint['state_dict'])[0].startswith('module.'):
        # the checkpoint was saved by DDP
        logging.info('load checkpoint from DDP')
        dst_state_dict = model.state_dict()
        src_state_dict = checkpoint['state_dict']
        for key in dst_state_dict.keys():
            src_key = '{}.{}'.format('module', key)
            dst_state_dict[key] = src_state_dict.pop(src_key)
        assert len(src_state_dict) == 0
        model.load_state_dict(dst_state_dict, strict=False)
    else:
        model.load_state_dict(checkpoint['state_dict'], strict=False)

    model.num_features = checkpoint['num_features']
    model.num_classes = checkpoint['num_classes']
    model.subsampling_factor = checkpoint['subsampling_factor']

    return checkpoint
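
Checkpoint averaging is usually applied over the last few epochs of a run. A usage sketch with made-up paths:

epochs = [8, 9, 10]
filenames = [f'exp/epoch-{e}.pt' for e in epochs]
checkpoint = average_checkpoint(filenames, model)  # model now holds the average
torch.save(checkpoint, 'exp/epoch-avg.pt')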
Code example #8
def get_validation_objf(
    dataloader: torch.utils.data.DataLoader,
    model: AcousticModel,
    device: torch.device,
):
    total_objf = 0.0
    total_frames = 0.0  # for display only

    model.eval()

    for batch_idx, batch in enumerate(dataloader):
        objf, frames = get_objf(batch, model, device, False)
        total_objf += objf
        total_frames += frames

    return total_objf, total_frames
Code example #9
def save_checkpoint(filename: Pathlike,
                    model: AcousticModel,
                    epoch: int,
                    learning_rate: float,
                    objf: float,
                    valid_objf: float,
                    global_batch_idx_train: int,
                    local_rank: int = 0) -> None:
    if local_rank is not None and local_rank != 0:
        return
    logging.info(
        f'Save checkpoint to {filename}: epoch={epoch}, '
        f'learning_rate={learning_rate}, objf={objf}, valid_objf={valid_objf}')
    checkpoint = {
        'state_dict': model.state_dict(),
        'num_features': model.num_features,
        'num_classes': model.num_classes,
        'subsampling_factor': model.subsampling_factor,
        'epoch': epoch,
        'learning_rate': learning_rate,
        'objf': objf,
        'valid_objf': valid_objf,
        'global_batch_idx_train': global_batch_idx_train,
    }
    torch.save(checkpoint, filename)
Code example #10
def get_objf(batch: Dict,
             model: AcousticModel,
             P: k2.Fsa,
             device: torch.device,
             graph_compiler: MmiTrainingGraphCompiler,
             is_training: bool,
             den_scale: float = 1.0,
             tb_writer: Optional[SummaryWriter] = None,
             global_batch_idx_train: Optional[int] = None,
             optimizer: Optional[torch.optim.Optimizer] = None):
    feature = batch['inputs']
    # at entry, feature is [N, T, C]
    feature = feature.permute(0, 2, 1)  # now feature is [N, C, T]
    assert feature.ndim == 3
    feature = feature.to(device)

    supervisions = batch['supervisions']
    supervision_segments, texts = encode_supervisions(supervisions,
                                                      model.subsampling_factor)

    loss_fn = LFMMILoss(graph_compiler=graph_compiler,
                        P=P,
                        den_scale=den_scale)

    grad_context = nullcontext if is_training else torch.no_grad

    with grad_context():
        nnet_output = model(feature)
        # nnet_output is [N, C, T]
        nnet_output = nnet_output.permute(0, 2,
                                          1)  # now nnet_output is [N, T, C]
        mmi_loss, tot_frames, all_frames = loss_fn(nnet_output, texts,
                                                   supervision_segments)

    if is_training:

        def maybe_log_gradients(tag: str):
            if (tb_writer is not None and global_batch_idx_train is not None
                    and global_batch_idx_train % 200 == 0):
                tb_writer.add_scalars(tag,
                                      measure_gradient_norms(model, norm='l1'),
                                      global_step=global_batch_idx_train)

        optimizer.zero_grad()
        (-mmi_loss).backward()
        maybe_log_gradients('train/grad_norms')
        clip_grad_value_(model.parameters(), 5.0)
        maybe_log_gradients('train/clipped_grad_norms')
        if tb_writer is not None and global_batch_idx_train % 200 == 0:
            # Once in a time we will perform a more costly diagnostic
            # to check the relative parameter change per minibatch.
            deltas = optim_step_and_measure_param_change(model, optimizer)
            tb_writer.add_scalars('train/relative_param_change_per_minibatch',
                                  deltas,
                                  global_step=global_batch_idx_train)
        else:
            optimizer.step()

    ans = (-mmi_loss.detach().cpu().item(), tot_frames.cpu().item(),
           all_frames.cpu().item())
    return ans
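
The grad_context assignment above switches between a no-op context and torch.no_grad() in a single `with` statement. The pattern in isolation:

from contextlib import nullcontext
import torch

def forward(model, x, is_training: bool):
    grad_context = nullcontext if is_training else torch.no_grad
    with grad_context():
        # gradients are tracked only when is_training is True
        return model(x)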
Code example #11
def get_validation_objf(dataloader: torch.utils.data.DataLoader,
                        model: AcousticModel, device: torch.device,
                        graph_compiler: CtcTrainingGraphCompiler):
    total_objf = 0.
    total_frames = 0.  # for display only
    total_all_frames = 0.  # all frames including those seqs that failed.

    model.eval()

    for batch_idx, batch in enumerate(dataloader):
        objf, frames, all_frames = get_objf(batch, model, device,
                                            graph_compiler, False)
        total_objf += objf
        total_frames += frames
        total_all_frames += all_frames

    return total_objf, total_frames, total_all_frames
Code example #12
def get_objf(
    batch: Dict,
    model: AcousticModel,
    device: torch.device,
    training: bool,
    optimizer: Optional[torch.optim.Optimizer] = None,
    class_weights: Optional[torch.Tensor] = None,
):
    feature = batch["inputs"]  # (N, T, C)
    supervisions = batch["supervisions"]["is_voice"].unsqueeze(
        -1).long()  # (N, T, 1)

    feature = feature.to(device)
    supervisions = supervisions.to(device)
    if class_weights is not None:
        class_weights = class_weights.to(device)
    # at entry, feature is [N, T, C]
    feature = feature.permute(0, 2, 1)  # now feature is [N, C, T]
    if training:
        nnet_output = model(feature)
    else:
        with torch.no_grad():
            nnet_output = model(feature)

    # nnet_output is [N, C, T]
    nnet_output = nnet_output.permute(0, 2, 1)  # now nnet_output is [N, T, C]

    # Compute cross-entropy loss
    xent_loss = torch.nn.CrossEntropyLoss(reduction="sum",
                                          weight=class_weights)
    tot_score = xent_loss(nnet_output.contiguous().view(-1, 2),
                          supervisions.contiguous().view(-1))

    if training:
        optimizer.zero_grad()
        tot_score.backward()
        clip_grad_value_(model.parameters(), 5.0)
        optimizer.step()

    ans = (
        tot_score.detach().cpu().item(),  # total objective function value
        supervisions.numel(),  # number of frames
    )
    return ans
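
class_weights lets the cross-entropy compensate for the imbalance between voiced and non-voiced frames. One common choice (our assumption, not taken from the source) is inverse-frequency weights:

import torch

# Suppose 90% of frames are voiced (class 1) and 10% are not (class 0):
class_frequencies = torch.tensor([0.1, 0.9])
class_weights = 1.0 / class_frequencies
class_weights = class_weights / class_weights.sum()  # optional normalization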
Code example #13
def save_checkpoint(filename: Pathlike,
                    model: AcousticModel,
                    epoch: int,
                    learning_rate: float,
                    objf: float,
                    local_rank: int = 0) -> None:
    if local_rank is not None and local_rank != 0:
        return
    logging.info('Save checkpoint to {filename}: epoch={epoch}, '
                 'learning_rate={learning_rate}, objf={objf}'.format(
                     filename=filename,
                     epoch=epoch,
                     learning_rate=learning_rate,
                     objf=objf))
    checkpoint = {
        'state_dict': model.state_dict(),
        'num_features': model.num_features,
        'num_classes': model.num_classes,
        'subsampling_factor': model.subsampling_factor,
        'epoch': epoch,
        'learning_rate': learning_rate,
        'objf': objf
    }
    torch.save(checkpoint, filename)
Code example #14
def get_objf(batch: Dict,
             model: AcousticModel,
             P: k2.Fsa,
             device: torch.device,
             graph_compiler: MmiTrainingGraphCompiler,
             is_training: bool,
             den_scale: float = 1.0,
             optimizer: Optional[torch.optim.Optimizer] = None):
    feature = batch['features']
    supervisions = batch['supervisions']
    supervision_segments = torch.stack(
        (supervisions['sequence_idx'],
         torch.floor_divide(supervisions['start_frame'],
                            model.subsampling_factor),
         torch.floor_divide(supervisions['num_frames'],
                            model.subsampling_factor)), 1).to(torch.int32)
    indices = torch.argsort(supervision_segments[:, 2], descending=True)
    supervision_segments = supervision_segments[indices]

    texts = supervisions['text']
    texts = [texts[idx] for idx in indices]
    assert feature.ndim == 3
    # print(supervision_segments[:, 1] + supervision_segments[:, 2])

    feature = feature.to(device)
    # at entry, feature is [N, T, C]
    feature = feature.permute(0, 2, 1)  # now feature is [N, C, T]
    if is_training:
        nnet_output = model(feature)
    else:
        with torch.no_grad():
            nnet_output = model(feature)

    # nnet_output is [N, C, T]
    nnet_output = nnet_output.permute(0, 2, 1)  # now nnet_output is [N, T, C]

    if is_training:
        num, den = graph_compiler.compile(texts, P)
    else:
        with torch.no_grad():
            num, den = graph_compiler.compile(texts, P)

    assert num.requires_grad == is_training
    assert den.requires_grad is False
    num = num.to(device)
    den = den.to(device)

    # nnet_output2 = nnet_output.clone()
    # blank_bias = -7.0
    # nnet_output2[:,:,0] += blank_bias

    dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)
    assert nnet_output.device == device

    num = k2.intersect_dense(num, dense_fsa_vec, 10.0)
    den = k2.intersect_dense(den, dense_fsa_vec, 10.0)

    num_tot_scores = num.get_tot_scores(log_semiring=True,
                                        use_double_scores=True)
    den_tot_scores = den.get_tot_scores(log_semiring=True,
                                        use_double_scores=True)
    tot_scores = num_tot_scores - den_scale * den_tot_scores

    (tot_score, tot_frames,
     all_frames) = get_tot_objf_and_num_frames(tot_scores,
                                               supervision_segments[:, 2])

    if is_training:
        optimizer.zero_grad()
        (-tot_score).backward()
        clip_grad_value_(model.parameters(), 5.0)
        optimizer.step()

    ans = (-tot_score.detach().cpu().item(), tot_frames.cpu().item(),
           all_frames.cpu().item())
    return ans
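
The heart of the LF-MMI objective above is num_tot_scores - den_scale * den_tot_scores: the log-score of the numerator (supervision) graph minus a scaled log-score of the denominator graph. A toy numeric illustration with made-up values:

import torch

num_tot_scores = torch.tensor([-120.0, -85.0])  # illustrative log-scores
den_tot_scores = torch.tensor([-118.0, -80.0])
den_scale = 1.0
tot_scores = num_tot_scores - den_scale * den_tot_scores  # tensor([-2., -5.])
tot_score = tot_scores.sum()  # summed over sequences and maximized in training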
Code example #15
def get_objf(batch: Dict,
             model: AcousticModel,
             ali_model: Optional[AcousticModel],
             P: k2.Fsa,
             device: torch.device,
             graph_compiler: MmiTrainingGraphCompiler,
             is_training: bool,
             is_update: bool,
             accum_grad: int = 1,
             den_scale: float = 1.0,
             att_rate: float = 0.0,
             tb_writer: Optional[SummaryWriter] = None,
             global_batch_idx_train: Optional[int] = None,
             optimizer: Optional[torch.optim.Optimizer] = None,
             scaler: Optional[GradScaler] = None):
    feature = batch['inputs']
    # at entry, feature is [N, T, C]
    feature = feature.permute(0, 2, 1)  # now feature is [N, C, T]
    assert feature.ndim == 3
    feature = feature.to(device)

    supervisions = batch['supervisions']
    supervision_segments, texts = encode_supervisions(supervisions)

    loss_fn = LFMMILoss(graph_compiler=graph_compiler,
                        P=P,
                        den_scale=den_scale)

    grad_context = nullcontext if is_training else torch.no_grad

    with autocast(enabled=scaler.is_enabled()), grad_context():
        nnet_output, encoder_memory, memory_mask = model(feature, supervisions)
        if att_rate != 0.0:
            att_loss = model.module.decoder_forward(encoder_memory,
                                                    memory_mask, supervisions,
                                                    graph_compiler)

        if (ali_model is not None and global_batch_idx_train is not None
                and global_batch_idx_train // accum_grad < 4000):
            with torch.no_grad():
                ali_model_output = ali_model(feature)
            # subsampling is done slightly differently, may be small length
            # differences.
            min_len = min(ali_model_output.shape[2], nnet_output.shape[2])
            # scale less than one so it will be encouraged
            # to mimic ali_model's output
            ali_model_scale = 500.0 / (global_batch_idx_train // accum_grad +
                                       500)
            # clone() is required, or log-softmax backprop will fail.
            nnet_output = nnet_output.clone()
            nnet_output[:, :, :min_len] += \
                ali_model_scale * ali_model_output[:, :, :min_len]

        # nnet_output is [N, C, T]
        nnet_output = nnet_output.permute(0, 2,
                                          1)  # now nnet_output is [N, T, C]

        mmi_loss, tot_frames, all_frames = loss_fn(nnet_output, texts,
                                                   supervision_segments)

    if is_training:

        def maybe_log_gradients(tag: str):
            if tb_writer is not None and global_batch_idx_train is not None and global_batch_idx_train % 200 == 0:
                tb_writer.add_scalars(tag,
                                      measure_gradient_norms(model, norm='l1'),
                                      global_step=global_batch_idx_train)

        if att_rate != 0.0:
            loss = (-(1.0 - att_rate) * mmi_loss +
                    att_rate * att_loss) / (len(texts) * accum_grad)
        else:
            loss = (-mmi_loss) / (len(texts) * accum_grad)
        scaler.scale(loss).backward()
        if is_update:
            maybe_log_gradients('train/grad_norms')
            scaler.unscale_(optimizer)
            clip_grad_value_(model.parameters(), 5.0)
            maybe_log_gradients('train/clipped_grad_norms')
            if tb_writer is not None and (global_batch_idx_train //
                                          accum_grad) % 200 == 0:
                # Once in a time we will perform a more costly diagnostic
                # to check the relative parameter change per minibatch.
                deltas = optim_step_and_measure_param_change(
                    model, optimizer, scaler)
                tb_writer.add_scalars(
                    'train/relative_param_change_per_minibatch',
                    deltas,
                    global_step=global_batch_idx_train)
            else:
                scaler.step(optimizer)
            optimizer.zero_grad()
            scaler.update()

    ans = (-mmi_loss.detach().cpu().item(), tot_frames.cpu().item(),
           all_frames.cpu().item())
    return ans
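
The GradScaler calls above follow the standard PyTorch mixed-precision recipe combined with gradient accumulation: scale the loss before backward(), unscale before clipping, and step/update only on update batches. A generic minimal sketch (not snowfall code; assumes model, loader, criterion, optimizer and accum_grad are defined):

import torch
from torch.cuda.amp import GradScaler, autocast
from torch.nn.utils import clip_grad_value_

scaler = GradScaler()
for batch_idx, (x, y) in enumerate(loader):
    with autocast():
        loss = criterion(model(x), y) / accum_grad
    scaler.scale(loss).backward()
    if (batch_idx + 1) % accum_grad == 0:
        scaler.unscale_(optimizer)  # so clipping sees unscaled gradients
        clip_grad_value_(model.parameters(), 5.0)
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()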
Code example #16
def train_one_epoch(dataloader: torch.utils.data.DataLoader,
                    valid_dataloader: torch.utils.data.DataLoader,
                    model: AcousticModel, device: torch.device,
                    graph_compiler: CtcTrainingGraphCompiler,
                    optimizer: torch.optim.Optimizer, accum_grad: int,
                    att_rate: float, current_epoch: int,
                    tb_writer: SummaryWriter, num_epochs: int,
                    global_batch_idx_train: int):
    """One epoch training and validation.

    Args:
        dataloader: Training dataloader
        valid_dataloader: Validation dataloader
        model: Acoustic model to be trained
        device: Training device, torch.device("cpu") or torch.device("cuda", device_id)
        graph_compiler: MMI training graph compiler
        optimizer: Training optimizer
        accum_grad: Number of gradient accumulation steps
        att_rate: Attention loss rate, final loss is att_rate * att_loss + (1-att_rate) * other_loss
        current_epoch: current training epoch, for logging only
        tb_writer: tensorboard SummaryWriter
        num_epochs: total number of training epochs, for logging only
        global_batch_idx_train: global training batch index before this epoch, for logging only

    Returns:
        A tuple of 3 scalars: (total_objf / total_frames, valid_average_objf, global_batch_idx_train)
        - `total_objf / total_frames` is the average training loss
        - `valid_average_objf` is the average validation loss
        - `global_batch_idx_train` is the global training batch index after this epoch
    """
    total_objf, total_frames, total_all_frames = 0., 0., 0.
    valid_average_objf = float('inf')
    time_waiting_for_batch = 0
    forward_count = 0
    prev_timestamp = datetime.now()

    model.train()
    for batch_idx, batch in enumerate(dataloader):
        forward_count += 1
        if forward_count == accum_grad:
            is_update = True
            forward_count = 0
        else:
            is_update = False

        global_batch_idx_train += 1
        timestamp = datetime.now()
        time_waiting_for_batch += (timestamp - prev_timestamp).total_seconds()
        curr_batch_objf, curr_batch_frames, curr_batch_all_frames = get_objf(
            batch=batch,
            model=model,
            device=device,
            graph_compiler=graph_compiler,
            is_training=True,
            is_update=is_update,
            accum_grad=accum_grad,
            att_rate=att_rate,
            tb_writer=tb_writer,
            global_batch_idx_train=global_batch_idx_train,
            optimizer=optimizer)

        total_objf += curr_batch_objf
        total_frames += curr_batch_frames
        total_all_frames += curr_batch_all_frames

        if batch_idx % 10 == 0:
            logging.info(
                'batch {}, epoch {}/{} '
                'global average objf: {:.6f} over {} '
                'frames ({:.1f}% kept), current batch average objf: {:.6f} over {} frames ({:.1f}% kept) '
                'avg time waiting for batch {:.3f}s'.format(
                    batch_idx, current_epoch, num_epochs,
                    total_objf / total_frames, total_frames,
                    100.0 * total_frames / total_all_frames,
                    curr_batch_objf / (curr_batch_frames + 0.001),
                    curr_batch_frames,
                    100.0 * curr_batch_frames / curr_batch_all_frames,
                    time_waiting_for_batch / max(1, batch_idx)))

            if tb_writer is not None:
                tb_writer.add_scalar('train/global_average_objf',
                                     total_objf / total_frames,
                                     global_batch_idx_train)

                tb_writer.add_scalar(
                    'train/current_batch_average_objf',
                    curr_batch_objf / (curr_batch_frames + 0.001),
                    global_batch_idx_train)
            # if batch_idx >= 10:
            #    print("Exiting early to get profile info")
            #    sys.exit(0)

        if batch_idx > 0 and batch_idx % 200 == 0:
            total_valid_objf, total_valid_frames, total_valid_all_frames = get_validation_objf(
                dataloader=valid_dataloader,
                model=model,
                device=device,
                graph_compiler=graph_compiler)
            valid_average_objf = total_valid_objf / total_valid_frames
            model.train()
            logging.info(
                'Validation average objf: {:.6f} over {} frames ({:.1f}% kept)'
                .format(valid_average_objf, total_valid_frames,
                        100.0 * total_valid_frames / total_valid_all_frames))

            if tb_writer is not None:
                tb_writer.add_scalar('train/global_valid_average_objf',
                                     valid_average_objf,
                                     global_batch_idx_train)
                model.write_tensorboard_diagnostics(
                    tb_writer, global_step=global_batch_idx_train)
        prev_timestamp = datetime.now()
    return total_objf / total_frames, valid_average_objf, global_batch_idx_train
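
The forward_count bookkeeping above triggers an optimizer update on every accum_grad-th forward pass. The counter logic in isolation:

accum_grad = 4
forward_count = 0
for batch_idx in range(8):
    forward_count += 1
    is_update = forward_count == accum_grad
    if is_update:
        forward_count = 0
    # with accum_grad=4, batches 3 and 7 (0-based) trigger an update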
Code example #17
File: mmi_mbr_train.py, Project: danpovey/snowfall
def get_loss(batch: Dict,
             model: AcousticModel,
             P: k2.Fsa,
             device: torch.device,
             graph_compiler: MmiMbrTrainingGraphCompiler,
             is_training: bool,
             den_scale: float = 1.0,
             optimizer: Optional[torch.optim.Optimizer] = None):
    assert P.device == device
    feature = batch['features']
    supervisions = batch['supervisions']
    supervision_segments = torch.stack(
        (supervisions['sequence_idx'],
         torch.floor_divide(supervisions['start_frame'],
                            model.subsampling_factor),
         torch.floor_divide(supervisions['num_frames'],
                            model.subsampling_factor)), 1).to(torch.int32)
    indices = torch.argsort(supervision_segments[:, 2], descending=True)
    supervision_segments = supervision_segments[indices]

    texts = supervisions['text']
    texts = [texts[idx] for idx in indices]
    assert feature.ndim == 3
    # print(supervision_segments[:, 1] + supervision_segments[:, 2])

    feature = feature.to(device)
    # at entry, feature is [N, T, C]
    feature = feature.permute(0, 2, 1)  # now feature is [N, C, T]
    if is_training:
        nnet_output = model(feature)
    else:
        with torch.no_grad():
            nnet_output = model(feature)

    # nnet_output is [N, C, T]
    nnet_output = nnet_output.permute(0, 2, 1)  # now nnet_output is [N, T, C]

    if is_training:
        num_graph, den_graph, decoding_graph = graph_compiler.compile(texts, P)
    else:
        with torch.no_grad():
            num_graph, den_graph, decoding_graph = graph_compiler.compile(
                texts, P)

    assert num_graph.requires_grad == is_training
    assert den_graph.requires_grad is False
    assert decoding_graph.requires_grad is False
    assert len(
        decoding_graph.shape) == 2 or decoding_graph.shape == (1, None, None)

    num_graph = num_graph.to(device)
    den_graph = den_graph.to(device)

    decoding_graph = decoding_graph.to(device)

    dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)
    assert nnet_output.device == device

    num_lats = k2.intersect_dense(num_graph,
                                  dense_fsa_vec,
                                  10.0,
                                  seqframe_idx_name='seqframe_idx')

    mbr_lats = k2.intersect_dense_pruned(decoding_graph,
                                         dense_fsa_vec,
                                         20.0,
                                         7.0,
                                         30,
                                         10000,
                                         seqframe_idx_name='seqframe_idx')

    if True:
        # WARNING: the else branch is not working at present (the total loss is not stable)
        den_lats = k2.intersect_dense(den_graph, dense_fsa_vec, 10.0)
    else:
        # in this case, we can remove den_graph
        den_lats = mbr_lats

    num_tot_scores = num_lats.get_tot_scores(log_semiring=True,
                                             use_double_scores=True)

    den_tot_scores = den_lats.get_tot_scores(log_semiring=True,
                                             use_double_scores=True)

    if id(den_lats) == id(mbr_lats):
        # Some entries in den_tot_scores may be -inf.
        # The corresponding sequences are discarded/ignored.
        finite_indexes = torch.isfinite(den_tot_scores)
        den_tot_scores = den_tot_scores[finite_indexes]
        num_tot_scores = num_tot_scores[finite_indexes]
    else:
        finite_indexes = None

    tot_scores = num_tot_scores - den_scale * den_tot_scores

    (tot_score, tot_frames,
     all_frames) = get_tot_objf_and_num_frames(tot_scores,
                                               supervision_segments[:, 2],
                                               finite_indexes)

    num_rows = dense_fsa_vec.scores.shape[0]
    num_cols = dense_fsa_vec.scores.shape[1] - 1
    mbr_num_sparse = k2.create_sparse(rows=num_lats.seqframe_idx,
                                      cols=num_lats.phones,
                                      values=num_lats.get_arc_post(True,
                                                                   True).exp(),
                                      size=(num_rows, num_cols),
                                      min_col_index=0)

    mbr_den_sparse = k2.create_sparse(rows=mbr_lats.seqframe_idx,
                                      cols=mbr_lats.phones,
                                      values=mbr_lats.get_arc_post(True,
                                                                   True).exp(),
                                      size=(num_rows, num_cols),
                                      min_col_index=0)
    # NOTE: Due to limited support of PyTorch's autograd for sparse tensors,
    # we cannot use (mbr_num_sparse - mbr_den_sparse) here
    #
    # The following works only for torch >= 1.7.0
    mbr_loss = torch.sparse.sum(
        k2.sparse.abs((mbr_num_sparse + (-mbr_den_sparse)).coalesce()))

    mmi_loss = -tot_score

    total_loss = mmi_loss + mbr_loss

    if is_training:
        optimizer.zero_grad()
        total_loss.backward()
        clip_grad_value_(model.parameters(), 5.0)
        optimizer.step()

    ans = (
        mmi_loss.detach().cpu().item(),
        mbr_loss.detach().cpu().item(),
        tot_frames.cpu().item(),
        all_frames.cpu().item(),
    )
    return ans
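
The sparse algebra at the end computes an L1 distance between the arc posteriors of the numerator lattice and of the decoding (MBR) lattice, arranged as (seqframe, phone) matrices; the detour through k2.sparse works around PyTorch's limited autograd support for sparse tensors. A dense toy equivalent of that term:

import torch

num_post = torch.tensor([[0.9, 0.1],
                         [0.2, 0.8]])  # per-(seqframe, phone) posteriors
den_post = torch.tensor([[0.7, 0.3],
                         [0.4, 0.6]])
mbr_loss = (num_post - den_post).abs().sum()  # same quantity, densely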
Code example #18
def train_one_epoch(dataloader: torch.utils.data.DataLoader,
                    valid_dataloader: torch.utils.data.DataLoader,
                    model: AcousticModel, P: k2.Fsa, device: torch.device,
                    graph_compiler: MmiTrainingGraphCompiler,
                    optimizer: torch.optim.Optimizer, current_epoch: int,
                    tb_writer: SummaryWriter, num_epochs: int,
                    global_batch_idx_train: int, global_batch_idx_valid: int):
    total_objf, total_frames, total_all_frames = 0., 0., 0.
    time_waiting_for_batch = 0
    prev_timestamp = datetime.now()

    model.train()
    ragged_shape = P.arcs.shape().to(device)
    for batch_idx, batch in enumerate(dataloader):
        global_batch_idx_train += 1
        timestamp = datetime.now()
        time_waiting_for_batch += (timestamp - prev_timestamp).total_seconds()

        P.set_scores_stochastic_(model.P_scores)
        assert P.is_cpu
        assert P.requires_grad is True

        curr_batch_objf, curr_batch_frames, curr_batch_all_frames = \
            get_objf(batch, model, P, device, graph_compiler, True, optimizer)

        total_objf += curr_batch_objf
        total_frames += curr_batch_frames
        total_all_frames += curr_batch_all_frames

        if batch_idx % 10 == 0:
            logging.info(
                'batch {}, epoch {}/{} '
                'global average objf: {:.6f} over {} '
                'frames ({:.1f}% kept), current batch average objf: {:.6f} over {} frames ({:.1f}% kept) '
                'avg time waiting for batch {:.3f}s'.format(
                    batch_idx, current_epoch, num_epochs,
                    total_objf / total_frames, total_frames,
                    100.0 * total_frames / total_all_frames,
                    curr_batch_objf / (curr_batch_frames + 0.001),
                    curr_batch_frames,
                    100.0 * curr_batch_frames / curr_batch_all_frames,
                    time_waiting_for_batch / max(1, batch_idx)))

            tb_writer.add_scalar('train/global_average_objf',
                                 total_objf / total_frames,
                                 global_batch_idx_train)

            tb_writer.add_scalar('train/current_batch_average_objf',
                                 curr_batch_objf / (curr_batch_frames + 0.001),
                                 global_batch_idx_train)
            # if batch_idx >= 10:
            #    print("Exiting early to get profile info")
            #    sys.exit(0)

        if batch_idx > 0 and batch_idx % 200 == 0:
            total_valid_objf, total_valid_frames, total_valid_all_frames = get_validation_objf(
                dataloader=valid_dataloader,
                model=model,
                P=P,
                device=device,
                graph_compiler=graph_compiler)
            global_batch_idx_valid += 1
            model.train()
            logging.info(
                'Validation average objf: {:.6f} over {} frames ({:.1f}% kept)'
                .format(total_valid_objf / total_valid_frames,
                        total_valid_frames,
                        100.0 * total_valid_frames / total_valid_all_frames))

            tb_writer.add_scalar('train/global_valid_average_objf',
                                 total_valid_objf / total_valid_frames,
                                 global_batch_idx_valid)
        prev_timestamp = datetime.now()
    return total_objf / total_frames
Code example #19
def train_one_epoch(dataloader: torch.utils.data.DataLoader,
                    valid_dataloader: torch.utils.data.DataLoader,
                    model: AcousticModel,
                    ali_model: Optional[AcousticModel],
                    P: k2.Fsa,
                    device: torch.device,
                    graph_compiler: MmiTrainingGraphCompiler,
                    use_pruned_intersect: bool,
                    optimizer: torch.optim.Optimizer,
                    accum_grad: int,
                    den_scale: float,
                    att_rate: float,
                    current_epoch: int,
                    tb_writer: SummaryWriter,
                    num_epochs: int,
                    global_batch_idx_train: int,
                    world_size: int,
                    scaler: GradScaler
                    ):
    """One epoch training and validation.

    Args:
        dataloader: Training dataloader
        valid_dataloader: Validation dataloader
        model: Acoustic model to be trained
        P: An FSA representing the bigram phone LM
        device: Training device, torch.device("cpu") or torch.device("cuda", device_id)
        graph_compiler: MMI training graph compiler
        optimizer: Training optimizer
        accum_grad: Number of gradient accumulation steps
        den_scale: Denominator scale in mmi loss
        att_rate: Attention loss rate, final loss is att_rate * att_loss + (1-att_rate) * other_loss
        current_epoch: current training epoch, for logging only
        tb_writer: tensorboard SummaryWriter
        num_epochs: total number of training epochs, for logging only
        global_batch_idx_train: global training batch index before this epoch, for logging only

    Returns:
        A tuple of 3 scalars: (total_objf / total_frames, valid_average_objf, global_batch_idx_train)
        - `total_objf / total_frames` is the average training loss
        - `valid_average_objf` is the average validation loss
        - `global_batch_idx_train` is the global training batch index after this epoch
    """
    total_objf, total_frames, total_all_frames = 0., 0., 0.
    valid_average_objf = float('inf')
    time_waiting_for_batch = 0
    forward_count = 0
    prev_timestamp = datetime.now()

    model.train()
    for batch_idx, batch in enumerate(dataloader):
        forward_count += 1
        if forward_count == accum_grad:
            is_update = True
            forward_count = 0
        else:
            is_update = False

        global_batch_idx_train += 1
        timestamp = datetime.now()
        time_waiting_for_batch += (timestamp - prev_timestamp).total_seconds()

        if forward_count == 1 or accum_grad == 1:
            P.set_scores_stochastic_(model.module.P_scores)
            assert P.requires_grad is True

        curr_batch_objf, curr_batch_frames, curr_batch_all_frames = get_objf(
            batch=batch,
            model=model,
            ali_model=ali_model,
            P=P,
            device=device,
            graph_compiler=graph_compiler,
            use_pruned_intersect=use_pruned_intersect,
            is_training=True,
            is_update=is_update,
            accum_grad=accum_grad,
            den_scale=den_scale,
            att_rate=att_rate,
            tb_writer=tb_writer,
            global_batch_idx_train=global_batch_idx_train,
            optimizer=optimizer,
            scaler=scaler
        )

        total_objf += curr_batch_objf
        total_frames += curr_batch_frames
        total_all_frames += curr_batch_all_frames

        if batch_idx % 10 == 0:
            logging.info(
                'batch {}, epoch {}/{} '
                'global average objf: {:.6f} over {} '
                'frames ({:.1f}% kept), current batch average objf: {:.6f} over {} frames ({:.1f}% kept) '
                'avg time waiting for batch {:.3f}s'.format(
                    batch_idx, current_epoch, num_epochs,
                    total_objf / total_frames, total_frames,
                    100.0 * total_frames / total_all_frames,
                    curr_batch_objf / (curr_batch_frames + 0.001),
                    curr_batch_frames,
                    100.0 * curr_batch_frames / curr_batch_all_frames,
                    time_waiting_for_batch / max(1, batch_idx)))

            if tb_writer is not None:
                tb_writer.add_scalar('train/global_average_objf',
                                     total_objf / total_frames, global_batch_idx_train)

                tb_writer.add_scalar('train/current_batch_average_objf',
                                     curr_batch_objf / (curr_batch_frames + 0.001),
                                     global_batch_idx_train)
            # if batch_idx >= 10:
            #    print("Exiting early to get profile info")
            #    sys.exit(0)

        if batch_idx > 0 and batch_idx % 200 == 0:
            total_valid_objf, total_valid_frames, total_valid_all_frames = get_validation_objf(
                dataloader=valid_dataloader,
                model=model,
                ali_model=ali_model,
                P=P,
                device=device,
                graph_compiler=graph_compiler,
                use_pruned_intersect=use_pruned_intersect,
                scaler=scaler)
            if world_size > 1:
                s = torch.tensor([
                    total_valid_objf, total_valid_frames,
                    total_valid_all_frames
                ]).to(device)

                dist.all_reduce(s, op=dist.ReduceOp.SUM)
                total_valid_objf, total_valid_frames, total_valid_all_frames = s.cpu().tolist()

            valid_average_objf = total_valid_objf / total_valid_frames
            model.train()
            logging.info(
                'Validation average objf: {:.6f} over {} frames ({:.1f}% kept)'
                    .format(valid_average_objf,
                            total_valid_frames,
                            100.0 * total_valid_frames / total_valid_all_frames))

            if tb_writer is not None:
                tb_writer.add_scalar('train/global_valid_average_objf',
                                     valid_average_objf,
                                     global_batch_idx_train)
                model.module.write_tensorboard_diagnostics(tb_writer, global_step=global_batch_idx_train)
        prev_timestamp = datetime.now()
    return total_objf / total_frames, valid_average_objf, global_batch_idx_train
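
When world_size > 1, each rank sees only a shard of the validation set, so the per-rank totals are summed across processes before averaging. The same pattern in isolation (assumes torch.distributed has been initialized and the totals/device come from the surrounding code):

import torch
import torch.distributed as dist

stats = torch.tensor(
    [total_valid_objf, total_valid_frames, total_valid_all_frames],
    device=device)
dist.all_reduce(stats, op=dist.ReduceOp.SUM)
total_valid_objf, total_valid_frames, total_valid_all_frames = stats.tolist()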
Code example #20
def get_objf(batch: Dict,
             model: AcousticModel,
             device: torch.device,
             graph_compiler: MmiTrainingGraphCompiler,
             is_training: bool,
             den_scale: float = 1.0,
             tb_writer: Optional[SummaryWriter] = None,
             global_batch_idx_train: Optional[int] = None,
             optimizer: Optional[torch.optim.Optimizer] = None):
    feature = batch['inputs']
    # at entry, feature is [N, T, C]
    feature = feature.permute(0, 2, 1)  # now feature is [N, C, T]
    assert feature.ndim == 3
    feature = feature.to(device)

    # The model may be wrapped in DistributedDataParallel, in which case
    # its attributes live on model.module.
    try:
        subsampling_factor = model.subsampling_factor
    except AttributeError:
        subsampling_factor = model.module.subsampling_factor

    supervisions = batch['supervisions']
    supervision_segments, texts = encode_supervisions(supervisions,
                                                      subsampling_factor)

    loss_fn = LFMMILoss(graph_compiler=graph_compiler, den_scale=den_scale)

    grad_context = nullcontext if is_training else torch.no_grad

    with grad_context():
        nnet_output = model(feature)
        # nnet_output is [N, C, T]
        nnet_output = nnet_output.permute(0, 2,
                                          1)  # now nnet_output is [N, T, C]
        mmi_loss, tot_frames, all_frames = loss_fn(nnet_output, texts,
                                                   supervision_segments)

    if is_training:

        def maybe_log_gradients(tag: str):
            if (tb_writer is not None and global_batch_idx_train is not None
                    and global_batch_idx_train % 200 == 0):
                tb_writer.add_scalars(tag,
                                      measure_gradient_norms(model, norm='l1'),
                                      global_step=global_batch_idx_train)

        optimizer.zero_grad()
        (-mmi_loss).backward()

        # Debug aid: report any parameters that did not receive a gradient.
        for name, param in model.named_parameters():
            if param.grad is None:
                print(name)

        maybe_log_gradients('train/grad_norms')
        # clip_grad_value_(model.parameters(), 5.0)
        clip_grad_norm_(model.parameters(), max_norm=5.0, norm_type=2.0)
        maybe_log_gradients('train/clipped_grad_norms')
        if tb_writer is not None and global_batch_idx_train % 200 == 0:
            # Once in a time we will perform a more costly diagnostic
            # to check the relative parameter change per minibatch.
            deltas = optim_step_and_measure_param_change(model, optimizer)
            tb_writer.add_scalars('train/relative_param_change_per_minibatch',
                                  deltas,
                                  global_step=global_batch_idx_train)
        else:
            optimizer.step()

    ans = (-mmi_loss.detach().cpu().item(), tot_frames.cpu().item(),
           all_frames.cpu().item())
    return ans
Code example #21
def train_one_epoch(dataloader: torch.utils.data.DataLoader,
                    valid_dataloader: torch.utils.data.DataLoader,
                    model: AcousticModel, device: torch.device,
                    graph_compiler: MmiTrainingGraphCompiler,
                    optimizer: torch.optim.Optimizer, current_epoch: int,
                    tb_writer: SummaryWriter, num_epochs: int,
                    global_batch_idx_train: int):
    total_objf, total_frames, total_all_frames = 0., 0., 0.
    valid_average_objf = float('inf')
    time_waiting_for_batch = 0
    prev_timestamp = datetime.now()

    model.train()
    for batch_idx, batch in enumerate(dataloader):
        global_batch_idx_train += 1
        timestamp = datetime.now()
        time_waiting_for_batch += (timestamp - prev_timestamp).total_seconds()

        curr_batch_objf, curr_batch_frames, curr_batch_all_frames = get_objf(
            batch=batch,
            model=model,
            device=device,
            graph_compiler=graph_compiler,
            is_training=True,
            tb_writer=tb_writer,
            global_batch_idx_train=global_batch_idx_train,
            optimizer=optimizer)

        total_objf += curr_batch_objf
        total_frames += curr_batch_frames
        total_all_frames += curr_batch_all_frames

        if batch_idx % 10 == 0:
            logging.info(
                'batch {}, epoch {}/{} '
                'global average objf: {:.6f} over {} '
                'frames ({:.1f}% kept), current batch average objf: {:.6f} over {} frames ({:.1f}% kept) '
                'avg time waiting for batch {:.3f}s'.format(
                    batch_idx, current_epoch, num_epochs,
                    total_objf / total_frames, total_frames,
                    100.0 * total_frames / total_all_frames,
                    curr_batch_objf / (curr_batch_frames + 0.001),
                    curr_batch_frames,
                    100.0 * curr_batch_frames / curr_batch_all_frames,
                    time_waiting_for_batch / max(1, batch_idx)))

            tb_writer.add_scalar('train/global_average_objf',
                                 total_objf / total_frames,
                                 global_batch_idx_train)

            tb_writer.add_scalar('train/current_batch_average_objf',
                                 curr_batch_objf / (curr_batch_frames + 0.001),
                                 global_batch_idx_train)
            # if batch_idx >= 10:
            #    print("Exiting early to get profile info")
            #    sys.exit(0)

        if batch_idx > 0 and batch_idx % 200 == 0:
            total_valid_objf, total_valid_frames, total_valid_all_frames = get_validation_objf(
                dataloader=valid_dataloader,
                model=model,
                device=device,
                graph_compiler=graph_compiler)
            valid_average_objf = total_valid_objf / total_valid_frames
            model.train()
            logging.info(
                'Validation average objf: {:.6f} over {} frames ({:.1f}% kept)'
                .format(valid_average_objf, total_valid_frames,
                        100.0 * total_valid_frames / total_valid_all_frames))

            tb_writer.add_scalar('train/global_valid_average_objf',
                                 valid_average_objf, global_batch_idx_train)
            model.write_tensorboard_diagnostics(
                tb_writer, global_step=global_batch_idx_train)
        prev_timestamp = datetime.now()
    return total_objf / total_frames, valid_average_objf, global_batch_idx_train
Code example #22
File: mmi_mbr_train.py, Project: danpovey/snowfall
def train_one_epoch(dataloader: torch.utils.data.DataLoader,
                    valid_dataloader: torch.utils.data.DataLoader,
                    model: AcousticModel, P: k2.Fsa, device: torch.device,
                    graph_compiler: MmiMbrTrainingGraphCompiler,
                    optimizer: torch.optim.Optimizer, current_epoch: int,
                    tb_writer: SummaryWriter, num_epochs: int,
                    global_batch_idx_train: int):
    total_loss, total_mmi_loss, total_mbr_loss, total_frames, total_all_frames = 0., 0., 0., 0., 0.
    valid_average_loss = float('inf')
    time_waiting_for_batch = 0
    prev_timestamp = datetime.now()

    model.train()
    for batch_idx, batch in enumerate(dataloader):
        global_batch_idx_train += 1
        timestamp = datetime.now()
        time_waiting_for_batch += (timestamp - prev_timestamp).total_seconds()

        P.set_scores_stochastic_(model.P_scores)
        assert P.requires_grad is True

        curr_batch_mmi_loss, curr_batch_mbr_loss, curr_batch_frames, curr_batch_all_frames = get_loss(
            batch=batch,
            model=model,
            P=P,
            device=device,
            graph_compiler=graph_compiler,
            is_training=True,
            optimizer=optimizer)

        total_mmi_loss += curr_batch_mmi_loss
        total_mbr_loss += curr_batch_mbr_loss
        curr_batch_loss = curr_batch_mmi_loss + curr_batch_mbr_loss
        total_loss += curr_batch_loss
        total_frames += curr_batch_frames
        total_all_frames += curr_batch_all_frames

        if batch_idx % 10 == 0:
            logging.info('batch {}, epoch {}/{} '
                         'global average loss: {:.6f}, '
                         'global average mmi loss: {:.6f}, '
                         'global average mbr loss: {:.6f} over {} '
                         'frames ({:.1f}% kept), '
                         'current batch average loss: {:.6f}, '
                         'current batch average mmi loss: {:.6f}, '
                         'current batch average mbr loss: {:.6f} '
                         'over {} frames ({:.1f}% kept) '
                         'avg time waiting for batch {:.3f}s'.format(
                             batch_idx, current_epoch, num_epochs, total_loss /
                             total_frames, total_mmi_loss / total_frames,
                             total_mbr_loss / total_frames, total_frames,
                             100.0 * total_frames / total_all_frames,
                             curr_batch_loss / (curr_batch_frames + 0.001),
                             curr_batch_mmi_loss / (curr_batch_frames + 0.001),
                             curr_batch_mbr_loss / (curr_batch_frames + 0.001),
                             curr_batch_frames,
                             100.0 * curr_batch_frames / curr_batch_all_frames,
                             time_waiting_for_batch / max(1, batch_idx)))

            tb_writer.add_scalar('train/global_average_loss',
                                 total_loss / total_frames,
                                 global_batch_idx_train)

            tb_writer.add_scalar('train/global_average_mmi_loss',
                                 total_mmi_loss / total_frames,
                                 global_batch_idx_train)

            tb_writer.add_scalar('train/global_average_mbr_loss',
                                 total_mbr_loss / total_frames,
                                 global_batch_idx_train)

            tb_writer.add_scalar('train/current_batch_average_loss',
                                 curr_batch_loss / (curr_batch_frames + 0.001),
                                 global_batch_idx_train)

            tb_writer.add_scalar(
                'train/current_batch_average_mmi_loss',
                curr_batch_mmi_loss / (curr_batch_frames + 0.001),
                global_batch_idx_train)

            tb_writer.add_scalar(
                'train/current_batch_average_mbr_loss',
                curr_batch_mbr_loss / (curr_batch_frames + 0.001),
                global_batch_idx_train)
            # if batch_idx >= 10:
            #    print("Exiting early to get profile info")
            #    sys.exit(0)

        if batch_idx > 0 and batch_idx % 3000 == 0:
            total_valid_loss, total_valid_mmi_loss, total_valid_mbr_loss, \
                total_valid_frames, total_valid_all_frames = get_validation_loss(
                    dataloader=valid_dataloader,
                    model=model,
                    P=P,
                    device=device,
                    graph_compiler=graph_compiler)
            valid_average_loss = total_valid_loss / total_valid_frames
            model.train()
            logging.info('Validation average loss: {:.6f}, '
                         'Validation average mmi loss: {:.6f}, '
                         'Validation average mbr loss: {:.6f} '
                         'over {} frames ({:.1f}% kept)'.format(
                             total_valid_loss / total_valid_frames,
                             total_valid_mmi_loss / total_valid_frames,
                             total_valid_mbr_loss / total_valid_frames,
                             total_valid_frames, 100.0 * total_valid_frames /
                             total_valid_all_frames))

            tb_writer.add_scalar('train/global_valid_average_loss',
                                 total_valid_loss / total_valid_frames,
                                 global_batch_idx_train)

            tb_writer.add_scalar('train/global_valid_average_mmi_loss',
                                 total_valid_mmi_loss / total_valid_frames,
                                 global_batch_idx_train)

            tb_writer.add_scalar('train/global_valid_average_mbr_loss',
                                 total_valid_mbr_loss / total_valid_frames,
                                 global_batch_idx_train)

        prev_timestamp = datetime.now()
    return total_loss / total_frames, valid_average_loss, global_batch_idx_train
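For context, a sketch of how this train_one_epoch might be driven by an outer epoch loop; the dataloaders, model, P, graph_compiler, optimizer and tb_writer are assumed to be constructed elsewhere, and the checkpoint policy shown is hypothetical:

global_batch_idx_train = 0
best_valid_loss = float('inf')
for epoch in range(num_epochs):
    train_loss, valid_loss, global_batch_idx_train = train_one_epoch(
        dataloader=train_dataloader,
        valid_dataloader=valid_dataloader,
        model=model,
        P=P,
        device=device,
        graph_compiler=graph_compiler,
        optimizer=optimizer,
        current_epoch=epoch,
        tb_writer=tb_writer,
        num_epochs=num_epochs,
        global_batch_idx_train=global_batch_idx_train)
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss  # a checkpoint would be saved here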
Code example #23
File: ctc_train.py Project: underdogliu/snowfall
def train_one_epoch(
    dataloader: torch.utils.data.DataLoader,
    valid_dataloader: torch.utils.data.DataLoader,
    model: AcousticModel,
    device: torch.device,
    graph_compiler: CtcTrainingGraphCompiler,
    optimizer: torch.optim.Optimizer,
    current_epoch: int,
    tb_writer: SummaryWriter,
    num_epochs: int,
    global_batch_idx_train: int,
):
    total_objf, total_frames, total_all_frames = 0.0, 0.0, 0.0
    valid_average_objf = float("inf")
    time_waiting_for_batch = 0
    prev_timestamp = datetime.now()

    model.train()
    for batch_idx, batch in enumerate(dataloader):
        global_batch_idx_train += 1
        timestamp = datetime.now()
        time_waiting_for_batch += (timestamp - prev_timestamp).total_seconds()
        curr_batch_objf, curr_batch_frames, curr_batch_all_frames = get_objf(
            batch, model, device, graph_compiler, True, optimizer)

        total_objf += curr_batch_objf
        total_frames += curr_batch_frames
        total_all_frames += curr_batch_all_frames

        if batch_idx % 10 == 0:
            logging.info(
                "batch {}, epoch {}/{} "
                "global average objf: {:.6f} over {} "
                "frames ({:.1f}% kept), current batch average objf: {:.6f} over {} frames ({:.1f}% kept) "
                "avg time waiting for batch {:.3f}s".format(
                    batch_idx,
                    current_epoch,
                    num_epochs,
                    total_objf / total_frames,
                    total_frames,
                    100.0 * total_frames / total_all_frames,
                    curr_batch_objf / (curr_batch_frames + 0.001),
                    curr_batch_frames,
                    100.0 * curr_batch_frames / curr_batch_all_frames,
                    time_waiting_for_batch / max(1, batch_idx),
                ))

            tb_writer.add_scalar(
                "train/global_average_objf",
                total_objf / total_frames,
                global_batch_idx_train,
            )

            tb_writer.add_scalar(
                "train/current_batch_average_objf",
                curr_batch_objf / (curr_batch_frames + 0.001),
                global_batch_idx_train,
            )
            # if batch_idx >= 10:
            #    print("Exiting early to get profile info")
            #    sys.exit(0)

        if batch_idx > 0 and batch_idx % 200 == 0:
            (
                total_valid_objf,
                total_valid_frames,
                total_valid_all_frames,
            ) = get_validation_objf(
                dataloader=valid_dataloader,
                model=model,
                device=device,
                graph_compiler=graph_compiler,
            )
            valid_average_objf = total_valid_objf / total_valid_frames
            model.train()
            logging.info(
                "Validation average objf: {:.6f} over {} frames ({:.1f}% kept)"
                .format(
                    valid_average_objf,
                    total_valid_frames,
                    100.0 * total_valid_frames / total_valid_all_frames,
                ))

            tb_writer.add_scalar(
                "train/global_valid_average_objf",
                valid_average_objf,
                global_batch_idx_train,
            )
        prev_timestamp = datetime.now()
    return total_objf / total_frames, valid_average_objf, global_batch_idx_train
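A side note on the logging arithmetic: the denominators are written as curr_batch_frames + 0.001 because every sequence in a batch can fail, leaving zero kept frames; the epsilon turns a ZeroDivisionError into a harmless 0.0 in the log line. A stand-alone illustration:

curr_batch_objf, curr_batch_frames = 0.0, 0.0
print(curr_batch_objf / (curr_batch_frames + 0.001))  # 0.0 rather than an error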
Code example #24
def get_objf(batch: Dict,
             model: AcousticModel,
             device: torch.device,
             graph_compiler: CtcTrainingGraphCompiler,
             is_training: bool,
             is_update: bool,
             accum_grad: int = 1,
             att_rate: float = 0.0,
             optimizer: Optional[torch.optim.Optimizer] = None):
    feature = batch['features']
    supervisions = batch['supervisions']
    supervision_segments = torch.stack(
        (supervisions['sequence_idx'],
         (((supervisions['start_frame'] - 1) // 2 - 1) // 2),
         (((supervisions['num_frames'] - 1) // 2 - 1) // 2)), 1).to(torch.int32)
    supervision_segments = torch.clamp(supervision_segments, min=0)
    indices = torch.argsort(supervision_segments[:, 2], descending=True)
    supervision_segments = supervision_segments[indices]

    texts = supervisions['text']
    texts = [texts[idx] for idx in indices]
    assert feature.ndim == 3
    # print(supervision_segments[:, 1] + supervision_segments[:, 2])

    feature = feature.to(device)
    # at entry, feature is [N, T, C]
    feature = feature.permute(0, 2, 1)  # now feature is [N, C, T]
    if is_training:
        nnet_output, encoder_memory, memory_mask = model(feature, supervision_segments)
        if att_rate != 0.0:
            att_loss = model.decoder_forward(encoder_memory, memory_mask, supervisions, graph_compiler)
    else:
        with torch.no_grad():
            nnet_output, encoder_memory, memory_mask = model(feature, supervision_segments)
            if att_rate != 0.0:
                att_loss = model.decoder_forward(encoder_memory, memory_mask, supervisions, graph_compiler)

    # nnet_output is [N, C, T]
    nnet_output = nnet_output.permute(0, 2, 1)  # now nnet_output is [N, T, C]

    decoding_graph = graph_compiler.compile(texts).to(device)

    # nnet_output2 = nnet_output.clone()
    # blank_bias = -7.0
    # nnet_output2[:,:,0] += blank_bias

    dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)
    assert decoding_graph.is_cuda()
    assert decoding_graph.device == device
    assert nnet_output.device == device

    target_graph = k2.intersect_dense(decoding_graph, dense_fsa_vec, 10.0)

    tot_scores = target_graph.get_tot_scores(
        log_semiring=True,
        use_double_scores=True)

    (tot_score, tot_frames,
     all_frames) = get_tot_objf_and_num_frames(tot_scores,
                                               supervision_segments[:, 2])

    if is_training:
        if att_rate != 0.0:
            loss = (- (1.0 - att_rate) * tot_score + att_rate * att_loss) / (len(texts) * accum_grad)
        else:
            loss = (-tot_score) / (len(texts) * accum_grad)
        loss.backward()
        if is_update:
            clip_grad_value_(model.parameters(), 5.0)
            optimizer.step()
            optimizer.zero_grad()

    ans = (-tot_score.detach().cpu().item(),
           tot_frames.cpu().item(),
           all_frames.cpu().item())
    return ans
Code example #25
File: ctc_train.py Project: underdogliu/snowfall
def get_objf(
    batch: Dict,
    model: AcousticModel,
    device: torch.device,
    graph_compiler: CtcTrainingGraphCompiler,
    training: bool,
    optimizer: Optional[torch.optim.Optimizer] = None,
):
    feature = batch["inputs"]
    supervisions = batch["supervisions"]
    supervision_segments = torch.stack(
        (
            supervisions["sequence_idx"],
            torch.floor_divide(supervisions["start_frame"],
                               model.subsampling_factor),
            torch.floor_divide(supervisions["num_frames"],
                               model.subsampling_factor),
        ),
        1,
    ).to(torch.int32)
    indices = torch.argsort(supervision_segments[:, 2], descending=True)
    supervision_segments = supervision_segments[indices]

    texts = supervisions["text"]
    texts = [texts[idx] for idx in indices]
    assert feature.ndim == 3

    feature = feature.to(device)
    # at entry, feature is [N, T, C]
    feature = feature.permute(0, 2, 1)  # now feature is [N, C, T]
    if training:
        nnet_output = model(feature)
    else:
        with torch.no_grad():
            nnet_output = model(feature)

    # nnet_output is [N, C, T]
    nnet_output = nnet_output.permute(0, 2, 1)  # now nnet_output is [N, T, C]

    decoding_graph = graph_compiler.compile(texts).to(device)

    dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)
    assert decoding_graph.is_cuda()
    assert decoding_graph.device == device
    assert nnet_output.device == device

    target_graph = k2.intersect_dense(decoding_graph, dense_fsa_vec, 10.0)

    tot_scores = target_graph.get_tot_scores(log_semiring=True,
                                             use_double_scores=True)

    (tot_score, tot_frames,
     all_frames) = get_tot_objf_and_num_frames(tot_scores,
                                               supervision_segments[:, 2])

    if training:
        optimizer.zero_grad()
        (-tot_score).backward()
        clip_grad_value_(model.parameters(), 5.0)
        optimizer.step()

    ans = (
        -tot_score.detach().cpu().item(),
        tot_frames.cpu().item(),
        all_frames.cpu().item(),
    )
    return ans
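This variant divides frame indices by model.subsampling_factor with torch.floor_divide. On newer PyTorch releases the same computation is usually spelled torch.div(..., rounding_mode='floor'); a small stand-alone equivalence check with made-up values:

import torch

start_frame = torch.tensor([0, 37, 128])
subsampling_factor = 4
print(torch.floor_divide(start_frame, subsampling_factor))                # tensor([ 0,  9, 32])
print(torch.div(start_frame, subsampling_factor, rounding_mode='floor'))  # same result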
Code example #26
def get_objf(batch: Dict,
             model: AcousticModel,
             device: torch.device,
             graph_compiler: CtcTrainingGraphCompiler,
             is_training: bool,
             is_update: bool,
             accum_grad: int = 1,
             att_rate: float = 0.0,
             tb_writer: Optional[SummaryWriter] = None,
             global_batch_idx_train: Optional[int] = None,
             optimizer: Optional[torch.optim.Optimizer] = None):
    feature = batch['inputs']
    feature = feature.to(device)
    # at entry, feature is [N, T, C]
    feature = feature.permute(0, 2, 1)  # now feature is [N, C, T]
    supervisions = batch['supervisions']
    supervision_segments, texts = encode_supervisions(supervisions)

    loss_fn = CTCLoss(graph_compiler)
    grad_context = nullcontext if is_training else torch.no_grad

    with grad_context():
        nnet_output, encoder_memory, memory_mask = model(feature, supervisions)
        if att_rate != 0.0:
            att_loss = model.decoder_forward(encoder_memory, memory_mask,
                                             supervisions, graph_compiler)

        # nnet_output is [N, C, T]
        nnet_output = nnet_output.permute(0, 2,
                                          1)  # now nnet_output is [N, T, C]
        tot_score, tot_frames, all_frames = loss_fn(nnet_output, texts,
                                                    supervision_segments)

    if is_training:

        def maybe_log_gradients(tag: str):
            if tb_writer is not None and global_batch_idx_train is not None and global_batch_idx_train % 200 == 0:
                tb_writer.add_scalars(tag,
                                      measure_gradient_norms(model, norm='l1'),
                                      global_step=global_batch_idx_train)

        if att_rate != 0.0:
            loss = (-(1.0 - att_rate) * tot_score +
                    att_rate * att_loss) / (len(texts) * accum_grad)
        else:
            loss = (-tot_score) / (len(texts) * accum_grad)
        loss.backward()
        if is_update:
            maybe_log_gradients('train/grad_norms')
            clip_grad_value_(model.parameters(), 5.0)
            maybe_log_gradients('train/clipped_grad_norms')
            if tb_writer is not None and (global_batch_idx_train //
                                          accum_grad) % 200 == 0:
                # Once in a time we will perform a more costly diagnostic
                # to check the relative parameter change per minibatch.
                deltas = optim_step_and_measure_param_change(model, optimizer)
                tb_writer.add_scalars(
                    'train/relative_param_change_per_minibatch',
                    deltas,
                    global_step=global_batch_idx_train)
            else:
                optimizer.step()
            optimizer.zero_grad()

    ans = (-tot_score.detach().cpu().item(),
           tot_frames.cpu().item(),
           all_frames.cpu().item())
    return ans
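The accum_grad / is_update pair implements gradient accumulation: every call scales the loss by 1 / accum_grad and backpropagates, but only calls with is_update=True clip the gradients and step the optimizer. A generic self-contained sketch of that pattern with a toy model, not snowfall code:

import torch

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
accum_grad = 4

for step in range(16):
    x, y = torch.randn(8, 4), torch.randn(8, 2)
    loss = torch.nn.functional.mse_loss(model(x), y) / accum_grad
    loss.backward()                      # gradients accumulate in .grad
    if (step + 1) % accum_grad == 0:     # corresponds to is_update=True
        torch.nn.utils.clip_grad_value_(model.parameters(), 5.0)
        optimizer.step()
        optimizer.zero_grad()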
Code example #27
File: mmi_bigram_train.py Project: hegc/snowfall
def train_one_epoch(dataloader: torch.utils.data.DataLoader,
                    valid_dataloader: torch.utils.data.DataLoader,
                    model: AcousticModel, P: k2.Fsa, device: torch.device,
                    graph_compiler: MmiTrainingGraphCompiler,
                    optimizer: torch.optim.Optimizer, current_epoch: int,
                    tb_writer: Optional[SummaryWriter], num_epochs: int,
                    global_batch_idx_train: int):
    total_objf, total_frames, total_all_frames = 0., 0., 0.
    valid_average_objf = float('inf')
    time_waiting_for_batch = 0
    prev_timestamp = datetime.now()

    model.train()
    for batch_idx, batch in enumerate(dataloader):
        global_batch_idx_train += 1
        timestamp = datetime.now()
        time_waiting_for_batch += (timestamp - prev_timestamp).total_seconds()

        if isinstance(model, DDP):
            P.set_scores_stochastic_(model.module.P_scores)
        else:
            P.set_scores_stochastic_(model.P_scores)
        assert P.is_cpu
        assert P.requires_grad is True

        curr_batch_objf, curr_batch_frames, curr_batch_all_frames = get_objf(
            batch=batch,
            model=model,
            P=P,
            device=device,
            graph_compiler=graph_compiler,
            is_training=True,
            tb_writer=tb_writer,
            global_batch_idx_train=global_batch_idx_train,
            optimizer=optimizer)

        total_objf += curr_batch_objf
        total_frames += curr_batch_frames
        total_all_frames += curr_batch_all_frames

        if batch_idx % 10 == 0 and dist.get_rank() == 0:
            logging.info(
                'batch {}, epoch {}/{} '
                'global average objf: {:.6f} over {} '
                'frames ({:.1f}% kept), current batch average objf: {:.6f} over {} frames ({:.1f}% kept) '
                'avg time waiting for batch {:.3f}s'.format(
                    batch_idx, current_epoch, num_epochs,
                    total_objf / total_frames, total_frames,
                    100.0 * total_frames / total_all_frames,
                    curr_batch_objf / (curr_batch_frames + 0.001),
                    curr_batch_frames,
                    100.0 * curr_batch_frames / curr_batch_all_frames,
                    time_waiting_for_batch / max(1, batch_idx)))
            tb_writer.add_scalar('train/global_average_objf',
                                 total_objf / total_frames,
                                 global_batch_idx_train)
            tb_writer.add_scalar('train/current_batch_average_objf',
                                 curr_batch_objf / (curr_batch_frames + 0.001),
                                 global_batch_idx_train)
            # if batch_idx >= 10:
            #    print("Exiting early to get profile info")
            #    sys.exit(0)

        if batch_idx > 0 and batch_idx % 1000 == 0:
            total_valid_objf, total_valid_frames, total_valid_all_frames = get_validation_objf(
                dataloader=valid_dataloader,
                model=model,
                P=P,
                device=device,
                graph_compiler=graph_compiler)
            # Note: these validation totals are per-process; to display a
            # globally-correct value one would dist.reduce them to the master
            # node first (dist.reduce performs sum reduction by default).
            valid_average_objf = total_valid_objf / total_valid_frames
            model.train()
            if dist.get_rank() == 0:
                logging.info(
                    'Validation average objf: {:.6f} over {} frames ({:.1f}% kept)'
                    .format(
                        valid_average_objf, total_valid_frames,
                        100.0 * total_valid_frames / total_valid_all_frames))
            if tb_writer is not None:
                tb_writer.add_scalar('train/global_valid_average_objf',
                                     valid_average_objf,
                                     global_batch_idx_train)
                (model.module if isinstance(model, DDP) else
                 model).write_tensorboard_diagnostics(
                     tb_writer, global_step=global_batch_idx_train)
        prev_timestamp = datetime.now()
    return total_objf / total_frames, valid_average_objf, global_batch_idx_train
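The model.module if isinstance(model, DDP) else model idiom appears several times in this file; a tiny hypothetical helper would keep it in one place:

from torch.nn.parallel import DistributedDataParallel as DDP

def unwrap(model):
    # Return the underlying module whether or not `model` is DDP-wrapped.
    return model.module if isinstance(model, DDP) else model

# e.g. P.set_scores_stochastic_(unwrap(model).P_scores)
#      unwrap(model).write_tensorboard_diagnostics(tb_writer, global_step=...)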
Code example #28
def get_objf(batch: Dict,
             model: AcousticModel,
             device: torch.device,
             graph_compiler: CtcTrainingGraphCompiler,
             training: bool,
             optimizer: Optional[torch.optim.Optimizer] = None):
    feature = batch['features']
    supervisions = batch['supervisions']
    supervision_segments = torch.stack(
        (supervisions['sequence_idx'],
         torch.floor_divide(supervisions['start_frame'],
                            model.subsampling_factor),
         torch.floor_divide(supervisions['num_frames'],
                            model.subsampling_factor)), 1).to(torch.int32)
    indices = torch.argsort(supervision_segments[:, 2], descending=True)
    supervision_segments = supervision_segments[indices]

    texts = supervisions['text']
    texts = [texts[idx] for idx in indices]
    assert feature.ndim == 3
    # print(supervision_segments[:, 1] + supervision_segments[:, 2])

    feature = feature.to(device)
    # at entry, feature is [N, T, C]
    feature = feature.permute(0, 2, 1)  # now feature is [N, C, T]
    if training:
        nnet_output = model(feature)
    else:
        with torch.no_grad():
            nnet_output = model(feature)

    # nnet_output is [N, C, T]
    nnet_output = nnet_output.permute(0, 2, 1)  # now nnet_output is [N, T, C]

    decoding_graph = graph_compiler.compile(texts).to(device)

    # nnet_output2 = nnet_output.clone()
    # blank_bias = -7.0
    # nnet_output2[:,:,0] += blank_bias

    dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)
    assert decoding_graph.is_cuda()
    assert decoding_graph.device == device
    assert nnet_output.device == device
    # TODO(haowen): with a small `beam`, we may get empty `target_graph`,
    # thus `tot_scores` will be `inf`. Definitely we need to handle this later.
    target_graph = k2.intersect_dense(decoding_graph, dense_fsa_vec, 10.0)

    tot_scores = k2.get_tot_scores(target_graph,
                                   log_semiring=True,
                                   use_double_scores=True)

    (tot_score, tot_frames,
     all_frames) = get_tot_objf_and_num_frames(tot_scores,
                                               supervision_segments[:, 2])

    if training:
        optimizer.zero_grad()
        (-tot_score).backward()
        clip_grad_value_(model.parameters(), 5.0)
        optimizer.step()

    ans = (-tot_score.detach().cpu().item(),
           tot_frames.cpu().item(),
           all_frames.cpu().item())
    return ans
Code example #29
File: mmi_bigram_train.py Project: hegc/snowfall
def get_objf(batch: Dict,
             model: AcousticModel,
             P: k2.Fsa,
             device: torch.device,
             graph_compiler: MmiTrainingGraphCompiler,
             is_training: bool,
             den_scale: float = 1.0,
             tb_writer: Optional[SummaryWriter] = None,
             global_batch_idx_train: Optional[int] = None,
             optimizer: Optional[torch.optim.Optimizer] = None):
    feature = batch['features']
    supervisions = batch['supervisions']
    subsampling_factor = model.module.subsampling_factor if isinstance(
        model, DDP) else model.subsampling_factor
    supervision_segments = torch.stack(
        (supervisions['sequence_idx'],
         torch.floor_divide(supervisions['start_frame'], subsampling_factor),
         torch.floor_divide(supervisions['num_frames'], subsampling_factor)),
        1).to(torch.int32)
    indices = torch.argsort(supervision_segments[:, 2], descending=True)
    supervision_segments = supervision_segments[indices]

    texts = supervisions['text']
    texts = [texts[idx] for idx in indices]
    assert feature.ndim == 3
    # print(supervision_segments[:, 1] + supervision_segments[:, 2])

    feature = feature.to(device)
    # at entry, feature is [N, T, C]
    feature = feature.permute(0, 2, 1)  # now feature is [N, C, T]
    if is_training:
        nnet_output = model(feature)
    else:
        with torch.no_grad():
            nnet_output = model(feature)

    # nnet_output is [N, C, T]
    nnet_output = nnet_output.permute(0, 2, 1)  # now nnet_output is [N, T, C]

    if is_training:
        num, den = graph_compiler.compile(texts, P)
    else:
        with torch.no_grad():
            num, den = graph_compiler.compile(texts, P)

    assert num.requires_grad == is_training
    assert den.requires_grad is False
    num = num.to(device)
    den = den.to(device)

    # nnet_output2 = nnet_output.clone()
    # blank_bias = -7.0
    # nnet_output2[:,:,0] += blank_bias

    dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)
    assert nnet_output.device == device

    num = k2.intersect_dense(num, dense_fsa_vec, 10.0)
    den = k2.intersect_dense(den, dense_fsa_vec, 10.0)

    num_tot_scores = num.get_tot_scores(log_semiring=True,
                                        use_double_scores=True)
    den_tot_scores = den.get_tot_scores(log_semiring=True,
                                        use_double_scores=True)
    tot_scores = num_tot_scores - den_scale * den_tot_scores

    (tot_score, tot_frames,
     all_frames) = get_tot_objf_and_num_frames(tot_scores,
                                               supervision_segments[:, 2])

    if is_training:

        def maybe_log_gradients(tag: str):
            if (tb_writer is not None and global_batch_idx_train is not None
                    and global_batch_idx_train % 200 == 0):
                tb_writer.add_scalars(tag,
                                      measure_gradient_norms(model, norm='l1'),
                                      global_step=global_batch_idx_train)

        optimizer.zero_grad()
        (-tot_score).backward()
        maybe_log_gradients('train/grad_norms')
        clip_grad_value_(model.parameters(), 5.0)
        maybe_log_gradients('train/clipped_grad_norms')
        if tb_writer is not None and global_batch_idx_train % 200 == 0:
            # Once in a time we will perform a more costly diagnostic
            # to check the relative parameter change per minibatch.
            deltas = optim_step_and_measure_param_change(model, optimizer)
            tb_writer.add_scalars('train/relative_param_change_per_minibatch',
                                  deltas,
                                  global_step=global_batch_idx_train)
        else:
            optimizer.step()

    ans = (-tot_score.detach().cpu().item(),
           tot_frames.cpu().item(),
           all_frames.cpu().item())
    return ans