Example #1
def train_one_epoch(dataloader: torch.utils.data.DataLoader,
                    valid_dataloader: torch.utils.data.DataLoader,
                    model: AcousticModel, device: torch.device,
                    graph_compiler: TrainingGraphCompiler,
                    optimizer: torch.optim.Optimizer, current_epoch: int,
                    num_epochs: int):
    total_objf, total_frames, total_all_frames = 0., 0., 0.
    time_waiting_for_batch = 0
    prev_timestamp = datetime.now()

    model.train()
    for batch_idx, batch in enumerate(dataloader):
        timestamp = datetime.now()
        time_waiting_for_batch += (timestamp - prev_timestamp).total_seconds()
        curr_batch_objf, curr_batch_frames, curr_batch_all_frames = \
            get_objf(batch, model, device, graph_compiler, True, optimizer)

        total_objf += curr_batch_objf
        total_frames += curr_batch_frames
        total_all_frames += curr_batch_all_frames

        if batch_idx % 10 == 0:
            logging.info(
                'batch {}, epoch {}/{} '
                'global average objf: {:.6f} over {} '
                'frames ({:.1f}% kept), current batch average objf: {:.6f} over {} frames ({:.1f}% kept) '
                'avg time waiting for batch {:.3f}s'.format(
                    batch_idx, current_epoch, num_epochs,
                    total_objf / total_frames, total_frames,
                    100.0 * total_frames / total_all_frames,
                    curr_batch_objf / (curr_batch_frames + 0.001),
                    curr_batch_frames,
                    100.0 * curr_batch_frames / curr_batch_all_frames,
                    time_waiting_for_batch / max(1, batch_idx)))
            # if batch_idx >= 10:
            #    print("Exiting early to get profile info")
            #    sys.exit(0)

        if batch_idx > 0 and batch_idx % 200 == 0:
            total_valid_objf, total_valid_frames, total_valid_all_frames = get_validation_objf(
                dataloader=valid_dataloader,
                model=model,
                device=device,
                graph_compiler=graph_compiler)
            model.train()
            logging.info(
                'Validation average objf: {:.6f} over {} frames ({:.1f}% kept)'
                .format(total_valid_objf / total_valid_frames,
                        total_valid_frames,
                        100.0 * total_valid_frames / total_valid_all_frames))
        prev_timestamp = datetime.now()
    return total_objf
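Every variant in this file leans on the same return contract from get_objf / get_validation_objf: a tuple of (objective summed over kept frames, number of kept frames, number of frames before filtering). The real implementations are not shown here; the toy stand-in below only illustrates that contract, and the batch key 'inputs' and its (N, T, F) layout are hypothetical placeholders.

from typing import Dict, Optional, Tuple

import torch


def get_objf_sketch(batch: Dict[str, torch.Tensor],
                    model: torch.nn.Module,
                    device: torch.device,
                    is_training: bool,
                    optimizer: Optional[torch.optim.Optimizer] = None
                    ) -> Tuple[float, int, int]:
    # Toy stand-in honoring the (objf_sum, kept_frames, all_frames) contract
    # assumed by train_one_epoch; the objective itself is a dummy mean.
    feats = batch['inputs'].to(device)      # hypothetical (N, T, F) layout
    all_frames = feats.shape[0] * feats.shape[1]
    loss = model(feats).mean()
    if is_training:
        assert optimizer is not None
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    kept_frames = all_frames                # pretend nothing was filtered out
    return loss.item() * kept_frames, kept_frames, all_frames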
Example #2
def train_one_epoch(dataloader: torch.utils.data.DataLoader,
                    valid_dataloader: torch.utils.data.DataLoader,
                    model: AcousticModel, device: torch.device,
                    graph_compiler: CtcTrainingGraphCompiler,
                    optimizer: torch.optim.Optimizer, accum_grad: int,
                    att_rate: float, current_epoch: int,
                    tb_writer: SummaryWriter, num_epochs: int,
                    global_batch_idx_train: int):
    """One epoch training and validation.

    Args:
        dataloader: Training dataloader
        valid_dataloader: Validation dataloader
        model: Acoustic model to be trained
        device: Training device, torch.device("cpu") or torch.device("cuda", device_id)
        graph_compiler: CTC training graph compiler
        optimizer: Training optimizer
        accum_grad: Number of batches over which gradients are accumulated before each optimizer step
        att_rate: Attention loss rate, final loss is att_rate * att_loss + (1-att_rate) * other_loss
        current_epoch: current training epoch, for logging only
        tb_writer: tensorboard SummaryWriter
        num_epochs: total number of training epochs, for logging only
        global_batch_idx_train: global training batch index before this epoch, for logging only

    Returns:
        A tuple of 3 scalars: (total_objf / total_frames, valid_average_objf, global_batch_idx_train)
        - `total_objf / total_frames` is the average training loss
        - `valid_average_objf` is the average validation loss
        - `global_batch_idx_train` is the global training batch index after this epoch
    """
    total_objf, total_frames, total_all_frames = 0., 0., 0.
    valid_average_objf = float('inf')
    time_waiting_for_batch = 0
    forward_count = 0
    prev_timestamp = datetime.now()

    model.train()
    for batch_idx, batch in enumerate(dataloader):
        forward_count += 1
        if forward_count == accum_grad:
            is_update = True
            forward_count = 0
        else:
            is_update = False

        global_batch_idx_train += 1
        timestamp = datetime.now()
        time_waiting_for_batch += (timestamp - prev_timestamp).total_seconds()
        curr_batch_objf, curr_batch_frames, curr_batch_all_frames = get_objf(
            batch=batch,
            model=model,
            device=device,
            graph_compiler=graph_compiler,
            is_training=True,
            is_update=is_update,
            accum_grad=accum_grad,
            att_rate=att_rate,
            tb_writer=tb_writer,
            global_batch_idx_train=global_batch_idx_train,
            optimizer=optimizer)

        total_objf += curr_batch_objf
        total_frames += curr_batch_frames
        total_all_frames += curr_batch_all_frames

        if batch_idx % 10 == 0:
            logging.info(
                'batch {}, epoch {}/{} '
                'global average objf: {:.6f} over {} '
                'frames ({:.1f}% kept), current batch average objf: {:.6f} over {} frames ({:.1f}% kept) '
                'avg time waiting for batch {:.3f}s'.format(
                    batch_idx, current_epoch, num_epochs,
                    total_objf / total_frames, total_frames,
                    100.0 * total_frames / total_all_frames,
                    curr_batch_objf / (curr_batch_frames + 0.001),
                    curr_batch_frames,
                    100.0 * curr_batch_frames / curr_batch_all_frames,
                    time_waiting_for_batch / max(1, batch_idx)))

            if tb_writer is not None:
                tb_writer.add_scalar('train/global_average_objf',
                                     total_objf / total_frames,
                                     global_batch_idx_train)

                tb_writer.add_scalar(
                    'train/current_batch_average_objf',
                    curr_batch_objf / (curr_batch_frames + 0.001),
                    global_batch_idx_train)
            # if batch_idx >= 10:
            #    print("Exiting early to get profile info")
            #    sys.exit(0)

        if batch_idx > 0 and batch_idx % 200 == 0:
            total_valid_objf, total_valid_frames, total_valid_all_frames = get_validation_objf(
                dataloader=valid_dataloader,
                model=model,
                device=device,
                graph_compiler=graph_compiler)
            valid_average_objf = total_valid_objf / total_valid_frames
            model.train()
            logging.info(
                'Validation average objf: {:.6f} over {} frames ({:.1f}% kept)'
                .format(valid_average_objf, total_valid_frames,
                        100.0 * total_valid_frames / total_valid_all_frames))

            if tb_writer is not None:
                tb_writer.add_scalar('train/global_valid_average_objf',
                                     valid_average_objf,
                                     global_batch_idx_train)
                model.write_tensorboard_diagnostics(
                    tb_writer, global_step=global_batch_idx_train)
        prev_timestamp = datetime.now()
    return total_objf / total_frames, valid_average_objf, global_batch_idx_train
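The docstring's loss formula and accum_grad interact in a simple way: the combined loss is an affine mix of the attention and CTC losses, and gradient accumulation typically divides each loss by accum_grad so that accum_grad backward passes sum to one effective batch gradient. Whether get_objf scales internally is not shown above; the numeric sketch below assumes that common pattern, and all values are hypothetical.

import torch

att_loss = torch.tensor(2.3)   # hypothetical attention-decoder loss
ctc_loss = torch.tensor(1.7)   # hypothetical CTC loss
att_rate = 0.7
accum_grad = 4

# final loss = att_rate * att_loss + (1 - att_rate) * other_loss
loss = att_rate * att_loss + (1.0 - att_rate) * ctc_loss
scaled = loss / accum_grad     # what each accumulated backward would see
print(f'combined {loss.item():.4f}, per-accumulation-step {scaled.item():.4f}')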
Example #3
def train_one_epoch(dataloader: torch.utils.data.DataLoader,
                    valid_dataloader: torch.utils.data.DataLoader,
                    model: AcousticModel, P: k2.Fsa, device: torch.device,
                    graph_compiler: MmiMbrTrainingGraphCompiler,
                    optimizer: torch.optim.Optimizer, current_epoch: int,
                    tb_writer: SummaryWriter, num_epochs: int,
                    global_batch_idx_train: int):
    total_loss, total_mmi_loss, total_mbr_loss, total_frames, total_all_frames = 0., 0., 0., 0., 0.
    valid_average_loss = float('inf')
    time_waiting_for_batch = 0
    prev_timestamp = datetime.now()

    model.train()
    for batch_idx, batch in enumerate(dataloader):
        global_batch_idx_train += 1
        timestamp = datetime.now()
        time_waiting_for_batch += (timestamp - prev_timestamp).total_seconds()

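        # P is the bigram phone LM as a k2 FSA. Re-drawing its scores from
        # model.P_scores every batch (set_scores_stochastic_ re-normalizes
        # them so each state's outgoing arc probabilities sum to one) keeps
        # the LM parameters in the autograd graph, so they are trained
        # jointly with the acoustic model.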
        P.set_scores_stochastic_(model.P_scores)
        assert P.requires_grad is True

        curr_batch_mmi_loss, curr_batch_mbr_loss, curr_batch_frames, curr_batch_all_frames = get_loss(
            batch=batch,
            model=model,
            P=P,
            device=device,
            graph_compiler=graph_compiler,
            is_training=True,
            optimizer=optimizer)

        total_mmi_loss += curr_batch_mmi_loss
        total_mbr_loss += curr_batch_mbr_loss
        curr_batch_loss = curr_batch_mmi_loss + curr_batch_mbr_loss
        total_loss += curr_batch_loss
        total_frames += curr_batch_frames
        total_all_frames += curr_batch_all_frames

        if batch_idx % 10 == 0:
            logging.info('batch {}, epoch {}/{} '
                         'global average loss: {:.6f}, '
                         'global average mmi loss: {:.6f}, '
                         'global average mbr loss: {:.6f} over {} '
                         'frames ({:.1f}% kept), '
                         'current batch average loss: {:.6f}, '
                         'current batch average mmi loss: {:.6f}, '
                         'current batch average mbr loss: {:.6f} '
                         'over {} frames ({:.1f}% kept) '
                         'avg time waiting for batch {:.3f}s'.format(
                             batch_idx, current_epoch, num_epochs, total_loss /
                             total_frames, total_mmi_loss / total_frames,
                             total_mbr_loss / total_frames, total_frames,
                             100.0 * total_frames / total_all_frames,
                             curr_batch_loss / (curr_batch_frames + 0.001),
                             curr_batch_mmi_loss / (curr_batch_frames + 0.001),
                             curr_batch_mbr_loss / (curr_batch_frames + 0.001),
                             curr_batch_frames,
                             100.0 * curr_batch_frames / curr_batch_all_frames,
                             time_waiting_for_batch / max(1, batch_idx)))

            tb_writer.add_scalar('train/global_average_loss',
                                 total_loss / total_frames,
                                 global_batch_idx_train)

            tb_writer.add_scalar('train/global_average_mmi_loss',
                                 total_mmi_loss / total_frames,
                                 global_batch_idx_train)

            tb_writer.add_scalar('train/global_average_mbr_loss',
                                 total_mbr_loss / total_frames,
                                 global_batch_idx_train)

            tb_writer.add_scalar('train/current_batch_average_loss',
                                 curr_batch_loss / (curr_batch_frames + 0.001),
                                 global_batch_idx_train)

            tb_writer.add_scalar(
                'train/current_batch_average_mmi_loss',
                curr_batch_mmi_loss / (curr_batch_frames + 0.001),
                global_batch_idx_train)

            tb_writer.add_scalar(
                'train/current_batch_average_mbr_loss',
                curr_batch_mbr_loss / (curr_batch_frames + 0.001),
                global_batch_idx_train)
            # if batch_idx >= 10:
            #    print("Exiting early to get profile info")
            #    sys.exit(0)

        if batch_idx > 0 and batch_idx % 3000 == 0:
            total_valid_loss, total_valid_mmi_loss, total_valid_mbr_loss, \
                total_valid_frames, total_valid_all_frames = get_validation_loss(
                    dataloader=valid_dataloader,
                    model=model,
                    P=P,
                    device=device,
                    graph_compiler=graph_compiler)
            valid_average_loss = total_valid_loss / total_valid_frames
            model.train()
            logging.info('Validation average loss: {:.6f}, '
                         'Validation average mmi loss: {:.6f}, '
                         'Validation average mbr loss: {:.6f} '
                         'over {} frames ({:.1f}% kept)'.format(
                             total_valid_loss / total_valid_frames,
                             total_valid_mmi_loss / total_valid_frames,
                             total_valid_mbr_loss / total_valid_frames,
                             total_valid_frames, 100.0 * total_valid_frames /
                             total_valid_all_frames))

            tb_writer.add_scalar('train/global_valid_average_loss',
                                 total_valid_loss / total_valid_frames,
                                 global_batch_idx_train)

            tb_writer.add_scalar('train/global_valid_average_mmi_loss',
                                 total_valid_mmi_loss / total_valid_frames,
                                 global_batch_idx_train)

            tb_writer.add_scalar('train/global_valid_average_mbr_loss',
                                 total_valid_mbr_loss / total_valid_frames,
                                 global_batch_idx_train)

        prev_timestamp = datetime.now()
    return total_loss / total_frames, valid_average_loss, global_batch_idx_train
Example #4
def train_one_epoch(dataloader: torch.utils.data.DataLoader,
                    valid_dataloader: torch.utils.data.DataLoader,
                    model: AcousticModel, P: k2.Fsa, device: torch.device,
                    graph_compiler: MmiTrainingGraphCompiler,
                    optimizer: torch.optim.Optimizer, current_epoch: int,
                    tb_writer: Optional[SummaryWriter], num_epochs: int,
                    global_batch_idx_train: int):
    total_objf, total_frames, total_all_frames = 0., 0., 0.
    valid_average_objf = float('inf')
    time_waiting_for_batch = 0
    prev_timestamp = datetime.now()

    model.train()
    for batch_idx, batch in enumerate(dataloader):
        global_batch_idx_train += 1
        timestamp = datetime.now()
        time_waiting_for_batch += (timestamp - prev_timestamp).total_seconds()

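        # Under DistributedDataParallel the underlying module (and hence
        # P_scores) lives at model.module; a bare model exposes it directly.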
        if isinstance(model, DDP):
            P.set_scores_stochastic_(model.module.P_scores)
        else:
            P.set_scores_stochastic_(model.P_scores)
        assert P.is_cpu
        assert P.requires_grad is True

        curr_batch_objf, curr_batch_frames, curr_batch_all_frames = get_objf(
            batch=batch,
            model=model,
            P=P,
            device=device,
            graph_compiler=graph_compiler,
            is_training=True,
            tb_writer=tb_writer,
            global_batch_idx_train=global_batch_idx_train,
            optimizer=optimizer)

        total_objf += curr_batch_objf
        total_frames += curr_batch_frames
        total_all_frames += curr_batch_all_frames

        if batch_idx % 10 == 0 and dist.get_rank() == 0:
            logging.info(
                'batch {}, epoch {}/{} '
                'global average objf: {:.6f} over {} '
                'frames ({:.1f}% kept), current batch average objf: {:.6f} over {} frames ({:.1f}% kept) '
                'avg time waiting for batch {:.3f}s'.format(
                    batch_idx, current_epoch, num_epochs,
                    total_objf / total_frames, total_frames,
                    100.0 * total_frames / total_all_frames,
                    curr_batch_objf / (curr_batch_frames + 0.001),
                    curr_batch_frames,
                    100.0 * curr_batch_frames / curr_batch_all_frames,
                    time_waiting_for_batch / max(1, batch_idx)))
            tb_writer.add_scalar('train/global_average_objf',
                                 total_objf / total_frames,
                                 global_batch_idx_train)
            tb_writer.add_scalar('train/current_batch_average_objf',
                                 curr_batch_objf / (curr_batch_frames + 0.001),
                                 global_batch_idx_train)
            # if batch_idx >= 10:
            #    print("Exiting early to get profile info")
            #    sys.exit(0)

        if batch_idx > 0 and batch_idx % 1000 == 0:
            total_valid_objf, total_valid_frames, total_valid_all_frames = get_validation_objf(
                dataloader=valid_dataloader,
                model=model,
                P=P,
                device=device,
                graph_compiler=graph_compiler)
            # Synchronize the loss to the master node so that we display it correctly.
            # dist.reduce performs sum reduction by default.
            valid_average_objf = total_valid_objf / total_valid_frames
            model.train()
            if dist.get_rank() == 0:
                logging.info(
                    'Validation average objf: {:.6f} over {} frames ({:.1f}% kept)'
                    .format(
                        valid_average_objf, total_valid_frames,
                        100.0 * total_valid_frames / total_valid_all_frames))
            if tb_writer is not None:
                tb_writer.add_scalar('train/global_valid_average_objf',
                                     valid_average_objf,
                                     global_batch_idx_train)
                (model.module if isinstance(model, DDP) else
                 model).write_tensorboard_diagnostics(
                     tb_writer, global_step=global_batch_idx_train)
        prev_timestamp = datetime.now()
    return total_objf / total_frames, valid_average_objf, global_batch_idx_train
Example #5
def train_one_epoch(
    dataloader: torch.utils.data.DataLoader,
    valid_dataloader: torch.utils.data.DataLoader,
    model: AcousticModel,
    device: torch.device,
    graph_compiler: CtcTrainingGraphCompiler,
    optimizer: torch.optim.Optimizer,
    current_epoch: int,
    tb_writer: SummaryWriter,
    num_epochs: int,
    global_batch_idx_train: int,
):
    total_objf, total_frames, total_all_frames = 0.0, 0.0, 0.0
    valid_average_objf = float("inf")
    time_waiting_for_batch = 0
    prev_timestamp = datetime.now()

    model.train()
    for batch_idx, batch in enumerate(dataloader):
        global_batch_idx_train += 1
        timestamp = datetime.now()
        time_waiting_for_batch += (timestamp - prev_timestamp).total_seconds()
        curr_batch_objf, curr_batch_frames, curr_batch_all_frames = get_objf(
            batch, model, device, graph_compiler, True, optimizer)

        total_objf += curr_batch_objf
        total_frames += curr_batch_frames
        total_all_frames += curr_batch_all_frames

        if batch_idx % 10 == 0:
            logging.info(
                "batch {}, epoch {}/{} "
                "global average objf: {:.6f} over {} "
                "frames ({:.1f}% kept), current batch average objf: {:.6f} over {} frames ({:.1f}% kept) "
                "avg time waiting for batch {:.3f}s".format(
                    batch_idx,
                    current_epoch,
                    num_epochs,
                    total_objf / total_frames,
                    total_frames,
                    100.0 * total_frames / total_all_frames,
                    curr_batch_objf / (curr_batch_frames + 0.001),
                    curr_batch_frames,
                    100.0 * curr_batch_frames / curr_batch_all_frames,
                    time_waiting_for_batch / max(1, batch_idx),
                ))

            tb_writer.add_scalar(
                "train/global_average_objf",
                total_objf / total_frames,
                global_batch_idx_train,
            )

            tb_writer.add_scalar(
                "train/current_batch_average_objf",
                curr_batch_objf / (curr_batch_frames + 0.001),
                global_batch_idx_train,
            )
            # if batch_idx >= 10:
            #    print("Exiting early to get profile info")
            #    sys.exit(0)

        if batch_idx > 0 and batch_idx % 200 == 0:
            (
                total_valid_objf,
                total_valid_frames,
                total_valid_all_frames,
            ) = get_validation_objf(
                dataloader=valid_dataloader,
                model=model,
                device=device,
                graph_compiler=graph_compiler,
            )
            valid_average_objf = total_valid_objf / total_valid_frames
            model.train()
            logging.info(
                "Validation average objf: {:.6f} over {} frames ({:.1f}% kept)"
                .format(
                    valid_average_objf,
                    total_valid_frames,
                    100.0 * total_valid_frames / total_valid_all_frames,
                ))

            tb_writer.add_scalar(
                "train/global_valid_average_objf",
                valid_average_objf,
                global_batch_idx_train,
            )
        prev_timestamp = datetime.now()
    return total_objf / total_frames, valid_average_objf, global_batch_idx_train
Example #6
def train_one_epoch(dataloader: torch.utils.data.DataLoader,
                    valid_dataloader: torch.utils.data.DataLoader,
                    model: AcousticModel, P: k2.Fsa, device: torch.device,
                    graph_compiler: MmiTrainingGraphCompiler,
                    optimizer: torch.optim.Optimizer, current_epoch: int,
                    tb_writer: SummaryWriter, num_epochs: int,
                    global_batch_idx_train: int, global_batch_idx_valid: int):
    total_objf, total_frames, total_all_frames = 0., 0., 0.
    time_waiting_for_batch = 0
    prev_timestamp = datetime.now()

    model.train()
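    # Move the ragged arc shape of P to the training device once, up front;
    # the loop below does not reference it again.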
    ragged_shape = P.arcs.shape().to(device)
    for batch_idx, batch in enumerate(dataloader):
        global_batch_idx_train += 1
        timestamp = datetime.now()
        time_waiting_for_batch += (timestamp - prev_timestamp).total_seconds()

        P.set_scores_stochastic_(model.P_scores)
        assert P.is_cpu
        assert P.requires_grad is True

        curr_batch_objf, curr_batch_frames, curr_batch_all_frames = \
            get_objf(batch, model, P, device, graph_compiler, True, optimizer)

        total_objf += curr_batch_objf
        total_frames += curr_batch_frames
        total_all_frames += curr_batch_all_frames

        if batch_idx % 10 == 0:
            logging.info(
                'batch {}, epoch {}/{} '
                'global average objf: {:.6f} over {} '
                'frames ({:.1f}% kept), current batch average objf: {:.6f} over {} frames ({:.1f}% kept) '
                'avg time waiting for batch {:.3f}s'.format(
                    batch_idx, current_epoch, num_epochs,
                    total_objf / total_frames, total_frames,
                    100.0 * total_frames / total_all_frames,
                    curr_batch_objf / (curr_batch_frames + 0.001),
                    curr_batch_frames,
                    100.0 * curr_batch_frames / curr_batch_all_frames,
                    time_waiting_for_batch / max(1, batch_idx)))

            tb_writer.add_scalar('train/global_average_objf',
                                 total_objf / total_frames,
                                 global_batch_idx_train)

            tb_writer.add_scalar('train/current_batch_average_objf',
                                 curr_batch_objf / (curr_batch_frames + 0.001),
                                 global_batch_idx_train)
            # if batch_idx >= 10:
            #    print("Exiting early to get profile info")
            #    sys.exit(0)

        if batch_idx > 0 and batch_idx % 200 == 0:
            total_valid_objf, total_valid_frames, total_valid_all_frames = get_validation_objf(
                dataloader=valid_dataloader,
                model=model,
                P=P,
                device=device,
                graph_compiler=graph_compiler)
            global_batch_idx_valid += 1
            model.train()
            logging.info(
                'Validation average objf: {:.6f} over {} frames ({:.1f}% kept)'
                .format(total_valid_objf / total_valid_frames,
                        total_valid_frames,
                        100.0 * total_valid_frames / total_valid_all_frames))

            tb_writer.add_scalar('train/global_valid_average_objf',
                                 total_valid_objf / total_valid_frames,
                                 global_batch_idx_valid)
        prev_timestamp = datetime.now()
    return total_objf / total_frames
Example #7
def train_one_epoch(dataloader: torch.utils.data.DataLoader,
                    valid_dataloader: torch.utils.data.DataLoader,
                    model: AcousticModel,
                    ali_model: Optional[AcousticModel],
                    P: k2.Fsa,
                    device: torch.device,
                    graph_compiler: MmiTrainingGraphCompiler,
                    use_pruned_intersect: bool,
                    optimizer: torch.optim.Optimizer,
                    accum_grad: int,
                    den_scale: float,
                    att_rate: float,
                    current_epoch: int,
                    tb_writer: SummaryWriter,
                    num_epochs: int,
                    global_batch_idx_train: int,
                    world_size: int,
                    scaler: GradScaler
                    ):
    """One epoch training and validation.

    Args:
        dataloader: Training dataloader
        valid_dataloader: Validation dataloader
        model: Acoustic model to be trained
        ali_model: Optional auxiliary alignment model, forwarded to get_objf and get_validation_objf
        P: An FSA representing the bigram phone LM
        device: Training device, torch.device("cpu") or torch.device("cuda", device_id)
        graph_compiler: MMI training graph compiler
        use_pruned_intersect: Whether to use pruned intersection when computing the MMI loss
        optimizer: Training optimizer
        accum_grad: Number of batches over which gradients are accumulated before each optimizer step
        den_scale: Denominator scale in the MMI loss
        att_rate: Attention loss rate, final loss is att_rate * att_loss + (1-att_rate) * other_loss
        current_epoch: current training epoch, for logging only
        tb_writer: tensorboard SummaryWriter
        num_epochs: total number of training epochs, for logging only
        global_batch_idx_train: global training batch index before this epoch, for logging only
        world_size: number of distributed workers; when > 1, validation statistics are summed across workers via all_reduce
        scaler: GradScaler for automatic mixed-precision training

    Returns:
        A tuple of 3 scalars: (total_objf / total_frames, valid_average_objf, global_batch_idx_train)
        - `total_objf / total_frames` is the average training loss
        - `valid_average_objf` is the average validation loss
        - `global_batch_idx_train` is the global training batch index after this epoch
    """
    total_objf, total_frames, total_all_frames = 0., 0., 0.
    valid_average_objf = float('inf')
    time_waiting_for_batch = 0
    forward_count = 0
    prev_timestamp = datetime.now()

    model.train()
    for batch_idx, batch in enumerate(dataloader):
        forward_count += 1
        if forward_count == accum_grad:
            is_update = True
            forward_count = 0
        else:
            is_update = False

        global_batch_idx_train += 1
        timestamp = datetime.now()
        time_waiting_for_batch += (timestamp - prev_timestamp).total_seconds()

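        # Refresh P's scores only on the first forward pass of an
        # accumulation cycle (or on every batch when accum_grad == 1),
        # presumably so that all accumulated gradients flow through one
        # consistent set of LM scores.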
        if forward_count == 1 or accum_grad == 1:
            P.set_scores_stochastic_(model.module.P_scores)
            assert P.requires_grad is True

        curr_batch_objf, curr_batch_frames, curr_batch_all_frames = get_objf(
            batch=batch,
            model=model,
            ali_model=ali_model,
            P=P,
            device=device,
            graph_compiler=graph_compiler,
            use_pruned_intersect=use_pruned_intersect,
            is_training=True,
            is_update=is_update,
            accum_grad=accum_grad,
            den_scale=den_scale,
            att_rate=att_rate,
            tb_writer=tb_writer,
            global_batch_idx_train=global_batch_idx_train,
            optimizer=optimizer,
            scaler=scaler
        )

        total_objf += curr_batch_objf
        total_frames += curr_batch_frames
        total_all_frames += curr_batch_all_frames

        if batch_idx % 10 == 0:
            logging.info(
                'batch {}, epoch {}/{} '
                'global average objf: {:.6f} over {} '
                'frames ({:.1f}% kept), current batch average objf: {:.6f} over {} frames ({:.1f}% kept) '
                'avg time waiting for batch {:.3f}s'.format(
                    batch_idx, current_epoch, num_epochs,
                    total_objf / total_frames, total_frames,
                    100.0 * total_frames / total_all_frames,
                    curr_batch_objf / (curr_batch_frames + 0.001),
                    curr_batch_frames,
                    100.0 * curr_batch_frames / curr_batch_all_frames,
                    time_waiting_for_batch / max(1, batch_idx)))

            if tb_writer is not None:
                tb_writer.add_scalar('train/global_average_objf',
                                     total_objf / total_frames, global_batch_idx_train)

                tb_writer.add_scalar('train/current_batch_average_objf',
                                     curr_batch_objf / (curr_batch_frames + 0.001),
                                     global_batch_idx_train)
            # if batch_idx >= 10:
            #    print("Exiting early to get profile info")
            #    sys.exit(0)

        if batch_idx > 0 and batch_idx % 200 == 0:
            total_valid_objf, total_valid_frames, total_valid_all_frames = get_validation_objf(
                dataloader=valid_dataloader,
                model=model,
                ali_model=ali_model,
                P=P,
                device=device,
                graph_compiler=graph_compiler,
                use_pruned_intersect=use_pruned_intersect,
                scaler=scaler)
            if world_size > 1:
                s = torch.tensor([
                    total_valid_objf, total_valid_frames,
                    total_valid_all_frames
                ]).to(device)

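                # Sum the (objf, kept frames, all frames) statistics across
                # workers so every rank computes the same global averages.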
                dist.all_reduce(s, op=dist.ReduceOp.SUM)
                total_valid_objf, total_valid_frames, total_valid_all_frames = s.cpu().tolist()

            valid_average_objf = total_valid_objf / total_valid_frames
            model.train()
            logging.info(
                'Validation average objf: {:.6f} over {} frames ({:.1f}% kept)'
                    .format(valid_average_objf,
                            total_valid_frames,
                            100.0 * total_valid_frames / total_valid_all_frames))

            if tb_writer is not None:
                tb_writer.add_scalar('train/global_valid_average_objf',
                                     valid_average_objf,
                                     global_batch_idx_train)
                model.module.write_tensorboard_diagnostics(tb_writer, global_step=global_batch_idx_train)
        prev_timestamp = datetime.now()
    return total_objf / total_frames, valid_average_objf, global_batch_idx_train
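The scaler argument suggests mixed-precision training inside get_objf; how it is used there is not shown above. The sketch below is only the standard torch.cuda.amp pattern for gating an optimizer step under gradient accumulation, and every name other than scaler, optimizer, and accum_grad is hypothetical.

import torch
from torch.cuda.amp import GradScaler, autocast


def amp_accum_step_sketch(model: torch.nn.Module,
                          feats: torch.Tensor,
                          optimizer: torch.optim.Optimizer,
                          scaler: GradScaler,
                          accum_grad: int,
                          is_update: bool) -> float:
    # Forward in mixed precision; dummy objective for illustration.
    with autocast():
        loss = model(feats).mean() / accum_grad
    # Scale to avoid fp16 underflow, then accumulate gradients.
    scaler.scale(loss).backward()
    if is_update:
        scaler.step(optimizer)   # unscales, checks for inf/nan, then steps
        scaler.update()
        optimizer.zero_grad()
    return loss.item() * accum_grad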
Example #8
def train_one_epoch(dataloader: torch.utils.data.DataLoader,
                    valid_dataloader: torch.utils.data.DataLoader,
                    model: AcousticModel, device: torch.device,
                    graph_compiler: MmiTrainingGraphCompiler,
                    optimizer: torch.optim.Optimizer, current_epoch: int,
                    tb_writer: SummaryWriter, num_epochs: int,
                    global_batch_idx_train: int):
    total_objf, total_frames, total_all_frames = 0., 0., 0.
    valid_average_objf = float('inf')
    time_waiting_for_batch = 0
    prev_timestamp = datetime.now()

    model.train()
    for batch_idx, batch in enumerate(dataloader):
        global_batch_idx_train += 1
        timestamp = datetime.now()
        time_waiting_for_batch += (timestamp - prev_timestamp).total_seconds()

        curr_batch_objf, curr_batch_frames, curr_batch_all_frames = get_objf(
            batch=batch,
            model=model,
            device=device,
            graph_compiler=graph_compiler,
            is_training=True,
            tb_writer=tb_writer,
            global_batch_idx_train=global_batch_idx_train,
            optimizer=optimizer)

        total_objf += curr_batch_objf
        total_frames += curr_batch_frames
        total_all_frames += curr_batch_all_frames

        if batch_idx % 10 == 0:
            logging.info(
                'batch {}, epoch {}/{} '
                'global average objf: {:.6f} over {} '
                'frames ({:.1f}% kept), current batch average objf: {:.6f} over {} frames ({:.1f}% kept) '
                'avg time waiting for batch {:.3f}s'.format(
                    batch_idx, current_epoch, num_epochs,
                    total_objf / total_frames, total_frames,
                    100.0 * total_frames / total_all_frames,
                    curr_batch_objf / (curr_batch_frames + 0.001),
                    curr_batch_frames,
                    100.0 * curr_batch_frames / curr_batch_all_frames,
                    time_waiting_for_batch / max(1, batch_idx)))

            tb_writer.add_scalar('train/global_average_objf',
                                 total_objf / total_frames,
                                 global_batch_idx_train)

            tb_writer.add_scalar('train/current_batch_average_objf',
                                 curr_batch_objf / (curr_batch_frames + 0.001),
                                 global_batch_idx_train)
            # if batch_idx >= 10:
            #    print("Exiting early to get profile info")
            #    sys.exit(0)

        if batch_idx > 0 and batch_idx % 200 == 0:
            total_valid_objf, total_valid_frames, total_valid_all_frames = get_validation_objf(
                dataloader=valid_dataloader,
                model=model,
                device=device,
                graph_compiler=graph_compiler)
            valid_average_objf = total_valid_objf / total_valid_frames
            model.train()
            logging.info(
                'Validation average objf: {:.6f} over {} frames ({:.1f}% kept)'
                .format(valid_average_objf, total_valid_frames,
                        100.0 * total_valid_frames / total_valid_all_frames))

            tb_writer.add_scalar('train/global_valid_average_objf',
                                 valid_average_objf, global_batch_idx_train)
            model.write_tensorboard_diagnostics(
                tb_writer, global_step=global_batch_idx_train)
        prev_timestamp = datetime.now()
    return total_objf / total_frames, valid_average_objf, global_batch_idx_train
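For completeness, a hypothetical driver loop showing how the (average objf, validation objf, global batch index) return contract of the later variants threads through epochs; train_dl, valid_dl, compiler, and the checkpoint path are placeholders, not names from the recipes above.

global_batch_idx_train = 0
best_valid_objf = float('inf')
for epoch in range(num_epochs):
    avg_objf, valid_objf, global_batch_idx_train = train_one_epoch(
        dataloader=train_dl,
        valid_dataloader=valid_dl,
        model=model,
        device=device,
        graph_compiler=compiler,
        optimizer=optimizer,
        current_epoch=epoch,
        tb_writer=tb_writer,
        num_epochs=num_epochs,
        global_batch_idx_train=global_batch_idx_train)
    logging.info(f'epoch {epoch}: train objf {avg_objf:.6f}, '
                 f'valid objf {valid_objf:.6f}')
    if valid_objf < best_valid_objf:        # keep the best model so far
        best_valid_objf = valid_objf
        torch.save(model.state_dict(), f'best-epoch-{epoch}.pt')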