# Shared imports for the train_one_epoch variants collected below (added for
# completeness). The project-local helpers (AcousticModel, the graph
# compilers, get_objf/get_loss and their validation counterparts) are assumed
# to be in scope from the surrounding training scripts.
import logging
from datetime import datetime
from typing import Optional

import k2
import torch
import torch.distributed as dist
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter


def train_one_epoch(dataloader: torch.utils.data.DataLoader,
                    valid_dataloader: torch.utils.data.DataLoader,
                    model: AcousticModel,
                    device: torch.device,
                    graph_compiler: TrainingGraphCompiler,
                    optimizer: torch.optim.Optimizer,
                    current_epoch: int,
                    num_epochs: int):
    total_objf, total_frames, total_all_frames = 0., 0., 0.
    time_waiting_for_batch = 0
    prev_timestamp = datetime.now()

    model.train()
    for batch_idx, batch in enumerate(dataloader):
        timestamp = datetime.now()
        time_waiting_for_batch += (timestamp - prev_timestamp).total_seconds()
        curr_batch_objf, curr_batch_frames, curr_batch_all_frames = \
            get_objf(batch, model, device, graph_compiler, True, optimizer)

        total_objf += curr_batch_objf
        total_frames += curr_batch_frames
        total_all_frames += curr_batch_all_frames

        if batch_idx % 10 == 0:
            logging.info(
                'batch {}, epoch {}/{} '
                'global average objf: {:.6f} over {} '
                'frames ({:.1f}% kept), current batch average objf: {:.6f} over {} frames ({:.1f}% kept) '
                'avg time waiting for batch {:.3f}s'.format(
                    batch_idx, current_epoch, num_epochs,
                    total_objf / total_frames, total_frames,
                    100.0 * total_frames / total_all_frames,
                    curr_batch_objf / (curr_batch_frames + 0.001),
                    curr_batch_frames,
                    100.0 * curr_batch_frames / curr_batch_all_frames,
                    time_waiting_for_batch / max(1, batch_idx)))

        # if batch_idx >= 10:
        #     print("Exiting early to get profile info")
        #     sys.exit(0)

        if batch_idx > 0 and batch_idx % 200 == 0:
            total_valid_objf, total_valid_frames, total_valid_all_frames = get_validation_objf(
                dataloader=valid_dataloader,
                model=model,
                device=device,
                graph_compiler=graph_compiler)
            model.train()
            logging.info(
                'Validation average objf: {:.6f} over {} frames ({:.1f}% kept)'
                .format(total_valid_objf / total_valid_frames,
                        total_valid_frames,
                        100.0 * total_valid_frames / total_valid_all_frames))
        prev_timestamp = datetime.now()
    return total_objf
def train_one_epoch(dataloader: torch.utils.data.DataLoader,
                    valid_dataloader: torch.utils.data.DataLoader,
                    model: AcousticModel,
                    device: torch.device,
                    graph_compiler: CtcTrainingGraphCompiler,
                    optimizer: torch.optim.Optimizer,
                    accum_grad: int,
                    att_rate: float,
                    current_epoch: int,
                    tb_writer: SummaryWriter,
                    num_epochs: int,
                    global_batch_idx_train: int):
    """One epoch training and validation.

    Args:
        dataloader: Training dataloader.
        valid_dataloader: Validation dataloader.
        model: Acoustic model to be trained.
        device: Training device, torch.device("cpu") or torch.device("cuda", device_id).
        graph_compiler: CTC training graph compiler.
        optimizer: Training optimizer.
        accum_grad: Number of gradient accumulation steps.
        att_rate: Attention loss rate; the final loss is
            att_rate * att_loss + (1 - att_rate) * other_loss.
        current_epoch: Current training epoch, for logging only.
        tb_writer: Tensorboard SummaryWriter.
        num_epochs: Total number of training epochs, for logging only.
        global_batch_idx_train: Global training batch index before this epoch,
            for logging only.

    Returns:
        A tuple of 3 scalars: (total_objf / total_frames, valid_average_objf,
        global_batch_idx_train)

        - `total_objf / total_frames` is the average training loss
        - `valid_average_objf` is the average validation loss
        - `global_batch_idx_train` is the global training batch index after
          this epoch
    """
    total_objf, total_frames, total_all_frames = 0., 0., 0.
    valid_average_objf = float('inf')
    time_waiting_for_batch = 0
    forward_count = 0
    prev_timestamp = datetime.now()

    model.train()
    for batch_idx, batch in enumerate(dataloader):
        forward_count += 1
        if forward_count == accum_grad:
            is_update = True
            forward_count = 0
        else:
            is_update = False

        global_batch_idx_train += 1
        timestamp = datetime.now()
        time_waiting_for_batch += (timestamp - prev_timestamp).total_seconds()

        curr_batch_objf, curr_batch_frames, curr_batch_all_frames = get_objf(
            batch=batch,
            model=model,
            device=device,
            graph_compiler=graph_compiler,
            is_training=True,
            is_update=is_update,
            accum_grad=accum_grad,
            att_rate=att_rate,
            tb_writer=tb_writer,
            global_batch_idx_train=global_batch_idx_train,
            optimizer=optimizer)

        total_objf += curr_batch_objf
        total_frames += curr_batch_frames
        total_all_frames += curr_batch_all_frames

        if batch_idx % 10 == 0:
            logging.info(
                'batch {}, epoch {}/{} '
                'global average objf: {:.6f} over {} '
                'frames ({:.1f}% kept), current batch average objf: {:.6f} over {} frames ({:.1f}% kept) '
                'avg time waiting for batch {:.3f}s'.format(
                    batch_idx, current_epoch, num_epochs,
                    total_objf / total_frames, total_frames,
                    100.0 * total_frames / total_all_frames,
                    curr_batch_objf / (curr_batch_frames + 0.001),
                    curr_batch_frames,
                    100.0 * curr_batch_frames / curr_batch_all_frames,
                    time_waiting_for_batch / max(1, batch_idx)))
            if tb_writer is not None:
                tb_writer.add_scalar('train/global_average_objf',
                                     total_objf / total_frames,
                                     global_batch_idx_train)
                tb_writer.add_scalar(
                    'train/current_batch_average_objf',
                    curr_batch_objf / (curr_batch_frames + 0.001),
                    global_batch_idx_train)

        # if batch_idx >= 10:
        #     print("Exiting early to get profile info")
        #     sys.exit(0)

        if batch_idx > 0 and batch_idx % 200 == 0:
            total_valid_objf, total_valid_frames, total_valid_all_frames = get_validation_objf(
                dataloader=valid_dataloader,
                model=model,
                device=device,
                graph_compiler=graph_compiler)
            valid_average_objf = total_valid_objf / total_valid_frames
            model.train()
            logging.info(
                'Validation average objf: {:.6f} over {} frames ({:.1f}% kept)'
                .format(valid_average_objf,
                        total_valid_frames,
                        100.0 * total_valid_frames / total_valid_all_frames))
            if tb_writer is not None:
                tb_writer.add_scalar('train/global_valid_average_objf',
                                     valid_average_objf,
                                     global_batch_idx_train)
                model.write_tensorboard_diagnostics(
                    tb_writer, global_step=global_batch_idx_train)
        prev_timestamp = datetime.now()
    return total_objf / total_frames, valid_average_objf, global_batch_idx_train
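# A minimal sketch (not in the original code) of how a training script might
# drive the accum_grad/att_rate variant above across epochs, threading
# `global_batch_idx_train` through successive calls so TensorBoard steps stay
# monotonic. Every parameter here is an assumed placeholder supplied by the
# caller, and the accum_grad/att_rate values are illustrative only.
def example_training_driver(train_dl, valid_dl, model, device, graph_compiler,
                            optimizer, tb_writer, num_epochs=10):
    global_batch_idx_train = 0
    for epoch in range(num_epochs):
        train_objf, valid_objf, global_batch_idx_train = train_one_epoch(
            dataloader=train_dl,
            valid_dataloader=valid_dl,
            model=model,
            device=device,
            graph_compiler=graph_compiler,
            optimizer=optimizer,
            accum_grad=2,          # assumed: update weights every 2 batches
            att_rate=0.7,          # assumed attention-loss mixing weight
            current_epoch=epoch,
            tb_writer=tb_writer,   # SummaryWriter, or None to skip scalars
            num_epochs=num_epochs,
            global_batch_idx_train=global_batch_idx_train)
        logging.info('epoch %d: train objf %.6f, valid objf %.6f',
                     epoch, train_objf, valid_objf)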
def train_one_epoch(dataloader: torch.utils.data.DataLoader,
                    valid_dataloader: torch.utils.data.DataLoader,
                    model: AcousticModel,
                    P: k2.Fsa,
                    device: torch.device,
                    graph_compiler: MmiMbrTrainingGraphCompiler,
                    optimizer: torch.optim.Optimizer,
                    current_epoch: int,
                    tb_writer: SummaryWriter,
                    num_epochs: int,
                    global_batch_idx_train: int):
    total_loss, total_mmi_loss, total_mbr_loss, total_frames, total_all_frames = 0., 0., 0., 0., 0.
    valid_average_loss = float('inf')
    time_waiting_for_batch = 0
    prev_timestamp = datetime.now()

    model.train()
    for batch_idx, batch in enumerate(dataloader):
        global_batch_idx_train += 1
        timestamp = datetime.now()
        time_waiting_for_batch += (timestamp - prev_timestamp).total_seconds()

        P.set_scores_stochastic_(model.P_scores)
        assert P.requires_grad is True

        curr_batch_mmi_loss, curr_batch_mbr_loss, curr_batch_frames, curr_batch_all_frames = get_loss(
            batch=batch,
            model=model,
            P=P,
            device=device,
            graph_compiler=graph_compiler,
            is_training=True,
            optimizer=optimizer)

        total_mmi_loss += curr_batch_mmi_loss
        total_mbr_loss += curr_batch_mbr_loss
        curr_batch_loss = curr_batch_mmi_loss + curr_batch_mbr_loss
        total_loss += curr_batch_loss
        total_frames += curr_batch_frames
        total_all_frames += curr_batch_all_frames

        if batch_idx % 10 == 0:
            logging.info('batch {}, epoch {}/{} '
                         'global average loss: {:.6f}, '
                         'global average mmi loss: {:.6f}, '
                         'global average mbr loss: {:.6f} over {} '
                         'frames ({:.1f}% kept), '
                         'current batch average loss: {:.6f}, '
                         'current batch average mmi loss: {:.6f}, '
                         'current batch average mbr loss: {:.6f} '
                         'over {} frames ({:.1f}% kept) '
                         'avg time waiting for batch {:.3f}s'.format(
                             batch_idx, current_epoch, num_epochs,
                             total_loss / total_frames,
                             total_mmi_loss / total_frames,
                             total_mbr_loss / total_frames, total_frames,
                             100.0 * total_frames / total_all_frames,
                             curr_batch_loss / (curr_batch_frames + 0.001),
                             curr_batch_mmi_loss / (curr_batch_frames + 0.001),
                             curr_batch_mbr_loss / (curr_batch_frames + 0.001),
                             curr_batch_frames,
                             100.0 * curr_batch_frames / curr_batch_all_frames,
                             time_waiting_for_batch / max(1, batch_idx)))

            tb_writer.add_scalar('train/global_average_loss',
                                 total_loss / total_frames,
                                 global_batch_idx_train)
            tb_writer.add_scalar('train/global_average_mmi_loss',
                                 total_mmi_loss / total_frames,
                                 global_batch_idx_train)
            tb_writer.add_scalar('train/global_average_mbr_loss',
                                 total_mbr_loss / total_frames,
                                 global_batch_idx_train)
            tb_writer.add_scalar('train/current_batch_average_loss',
                                 curr_batch_loss / (curr_batch_frames + 0.001),
                                 global_batch_idx_train)
            tb_writer.add_scalar(
                'train/current_batch_average_mmi_loss',
                curr_batch_mmi_loss / (curr_batch_frames + 0.001),
                global_batch_idx_train)
            tb_writer.add_scalar(
                'train/current_batch_average_mbr_loss',
                curr_batch_mbr_loss / (curr_batch_frames + 0.001),
                global_batch_idx_train)

        # if batch_idx >= 10:
        #     print("Exiting early to get profile info")
        #     sys.exit(0)

        if batch_idx > 0 and batch_idx % 3000 == 0:
            total_valid_loss, total_valid_mmi_loss, total_valid_mbr_loss, \
                total_valid_frames, total_valid_all_frames = get_validation_loss(
                    dataloader=valid_dataloader,
                    model=model,
                    P=P,
                    device=device,
                    graph_compiler=graph_compiler)
            valid_average_loss = total_valid_loss / total_valid_frames
            model.train()
            logging.info('Validation average loss: {:.6f}, '
                         'Validation average mmi loss: {:.6f}, '
                         'Validation average mbr loss: {:.6f} '
                         'over {} frames ({:.1f}% kept)'.format(
                             total_valid_loss / total_valid_frames,
                             total_valid_mmi_loss / total_valid_frames,
                             total_valid_mbr_loss / total_valid_frames,
                             total_valid_frames,
                             100.0 * total_valid_frames / total_valid_all_frames))

            tb_writer.add_scalar('train/global_valid_average_loss',
                                 total_valid_loss / total_valid_frames,
                                 global_batch_idx_train)
            tb_writer.add_scalar('train/global_valid_average_mmi_loss',
                                 total_valid_mmi_loss / total_valid_frames,
                                 global_batch_idx_train)
            tb_writer.add_scalar('train/global_valid_average_mbr_loss',
                                 total_valid_mbr_loss / total_valid_frames,
                                 global_batch_idx_train)
        prev_timestamp = datetime.now()
    return total_loss / total_frames, valid_average_loss, global_batch_idx_train
def train_one_epoch(dataloader: torch.utils.data.DataLoader,
                    valid_dataloader: torch.utils.data.DataLoader,
                    model: AcousticModel,
                    P: k2.Fsa,
                    device: torch.device,
                    graph_compiler: MmiTrainingGraphCompiler,
                    optimizer: torch.optim.Optimizer,
                    current_epoch: int,
                    tb_writer: Optional[SummaryWriter],
                    num_epochs: int,
                    global_batch_idx_train: int):
    total_objf, total_frames, total_all_frames = 0., 0., 0.
    valid_average_objf = float('inf')
    time_waiting_for_batch = 0
    prev_timestamp = datetime.now()

    model.train()
    for batch_idx, batch in enumerate(dataloader):
        global_batch_idx_train += 1
        timestamp = datetime.now()
        time_waiting_for_batch += (timestamp - prev_timestamp).total_seconds()

        if isinstance(model, DDP):
            P.set_scores_stochastic_(model.module.P_scores)
        else:
            P.set_scores_stochastic_(model.P_scores)
        assert P.is_cpu
        assert P.requires_grad is True

        curr_batch_objf, curr_batch_frames, curr_batch_all_frames = get_objf(
            batch=batch,
            model=model,
            P=P,
            device=device,
            graph_compiler=graph_compiler,
            is_training=True,
            tb_writer=tb_writer,
            global_batch_idx_train=global_batch_idx_train,
            optimizer=optimizer)

        total_objf += curr_batch_objf
        total_frames += curr_batch_frames
        total_all_frames += curr_batch_all_frames

        if batch_idx % 10 == 0 and dist.get_rank() == 0:
            logging.info(
                'batch {}, epoch {}/{} '
                'global average objf: {:.6f} over {} '
                'frames ({:.1f}% kept), current batch average objf: {:.6f} over {} frames ({:.1f}% kept) '
                'avg time waiting for batch {:.3f}s'.format(
                    batch_idx, current_epoch, num_epochs,
                    total_objf / total_frames, total_frames,
                    100.0 * total_frames / total_all_frames,
                    curr_batch_objf / (curr_batch_frames + 0.001),
                    curr_batch_frames,
                    100.0 * curr_batch_frames / curr_batch_all_frames,
                    time_waiting_for_batch / max(1, batch_idx)))
            # tb_writer is Optional; guard against None as the validation
            # branch below already does.
            if tb_writer is not None:
                tb_writer.add_scalar('train/global_average_objf',
                                     total_objf / total_frames,
                                     global_batch_idx_train)
                tb_writer.add_scalar(
                    'train/current_batch_average_objf',
                    curr_batch_objf / (curr_batch_frames + 0.001),
                    global_batch_idx_train)

        # if batch_idx >= 10:
        #     print("Exiting early to get profile info")
        #     sys.exit(0)

        if batch_idx > 0 and batch_idx % 1000 == 0:
            total_valid_objf, total_valid_frames, total_valid_all_frames = get_validation_objf(
                dataloader=valid_dataloader,
                model=model,
                P=P,
                device=device,
                graph_compiler=graph_compiler)
            # Synchronize the loss to the master node so that we display it
            # correctly. dist.reduce performs sum reduction by default.
            valid_average_objf = total_valid_objf / total_valid_frames
            model.train()
            if dist.get_rank() == 0:
                logging.info(
                    'Validation average objf: {:.6f} over {} frames ({:.1f}% kept)'
                    .format(valid_average_objf, total_valid_frames,
                            100.0 * total_valid_frames / total_valid_all_frames))
                if tb_writer is not None:
                    tb_writer.add_scalar('train/global_valid_average_objf',
                                         valid_average_objf,
                                         global_batch_idx_train)
                    (model.module if isinstance(model, DDP) else
                     model).write_tensorboard_diagnostics(
                         tb_writer, global_step=global_batch_idx_train)
        prev_timestamp = datetime.now()
    return total_objf / total_frames, valid_average_objf, global_batch_idx_train
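# The DDP variant above assumes torch.distributed is already initialized (it
# calls dist.get_rank()) and that `model` may be wrapped in
# DistributedDataParallel. A minimal sketch of that setup; the rendezvous
# address/port and the build_acoustic_model() helper are assumed placeholders.
import os


def setup_ddp(rank: int, world_size: int):
    # Assumed single-node rendezvous; adjust MASTER_ADDR/PORT for your cluster.
    os.environ.setdefault('MASTER_ADDR', 'localhost')
    os.environ.setdefault('MASTER_PORT', '12355')
    dist.init_process_group('nccl', rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

# Inside each worker process (model construction is assumed):
#   setup_ddp(rank, world_size)
#   model = build_acoustic_model().to(rank)   # assumed helper
#   model = DDP(model, device_ids=[rank])
#   ... call train_one_epoch(...) as above ...
#   dist.destroy_process_group()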
def train_one_epoch(
    dataloader: torch.utils.data.DataLoader,
    valid_dataloader: torch.utils.data.DataLoader,
    model: AcousticModel,
    device: torch.device,
    graph_compiler: CtcTrainingGraphCompiler,
    optimizer: torch.optim.Optimizer,
    current_epoch: int,
    tb_writer: SummaryWriter,
    num_epochs: int,
    global_batch_idx_train: int,
):
    total_objf, total_frames, total_all_frames = 0.0, 0.0, 0.0
    valid_average_objf = float("inf")
    time_waiting_for_batch = 0
    prev_timestamp = datetime.now()

    model.train()
    for batch_idx, batch in enumerate(dataloader):
        global_batch_idx_train += 1
        timestamp = datetime.now()
        time_waiting_for_batch += (timestamp - prev_timestamp).total_seconds()
        curr_batch_objf, curr_batch_frames, curr_batch_all_frames = get_objf(
            batch, model, device, graph_compiler, True, optimizer)

        total_objf += curr_batch_objf
        total_frames += curr_batch_frames
        total_all_frames += curr_batch_all_frames

        if batch_idx % 10 == 0:
            logging.info(
                "batch {}, epoch {}/{} "
                "global average objf: {:.6f} over {} "
                "frames ({:.1f}% kept), current batch average objf: {:.6f} over {} frames ({:.1f}% kept) "
                "avg time waiting for batch {:.3f}s".format(
                    batch_idx,
                    current_epoch,
                    num_epochs,
                    total_objf / total_frames,
                    total_frames,
                    100.0 * total_frames / total_all_frames,
                    curr_batch_objf / (curr_batch_frames + 0.001),
                    curr_batch_frames,
                    100.0 * curr_batch_frames / curr_batch_all_frames,
                    time_waiting_for_batch / max(1, batch_idx),
                ))
            tb_writer.add_scalar(
                "train/global_average_objf",
                total_objf / total_frames,
                global_batch_idx_train,
            )
            tb_writer.add_scalar(
                "train/current_batch_average_objf",
                curr_batch_objf / (curr_batch_frames + 0.001),
                global_batch_idx_train,
            )

        # if batch_idx >= 10:
        #     print("Exiting early to get profile info")
        #     sys.exit(0)

        if batch_idx > 0 and batch_idx % 200 == 0:
            (
                total_valid_objf,
                total_valid_frames,
                total_valid_all_frames,
            ) = get_validation_objf(
                dataloader=valid_dataloader,
                model=model,
                device=device,
                graph_compiler=graph_compiler,
            )
            valid_average_objf = total_valid_objf / total_valid_frames
            model.train()
            logging.info(
                "Validation average objf: {:.6f} over {} frames ({:.1f}% kept)"
                .format(
                    valid_average_objf,
                    total_valid_frames,
                    100.0 * total_valid_frames / total_valid_all_frames,
                ))
            tb_writer.add_scalar(
                "train/global_valid_average_objf",
                valid_average_objf,
                global_batch_idx_train,
            )
        prev_timestamp = datetime.now()
    return total_objf / total_frames, valid_average_objf, global_batch_idx_train
def train_one_epoch(dataloader: torch.utils.data.DataLoader,
                    valid_dataloader: torch.utils.data.DataLoader,
                    model: AcousticModel,
                    P: k2.Fsa,
                    device: torch.device,
                    graph_compiler: MmiTrainingGraphCompiler,
                    optimizer: torch.optim.Optimizer,
                    current_epoch: int,
                    tb_writer: SummaryWriter,
                    num_epochs: int,
                    global_batch_idx_train: int,
                    global_batch_idx_valid: int):
    total_objf, total_frames, total_all_frames = 0., 0., 0.
    time_waiting_for_batch = 0
    prev_timestamp = datetime.now()

    model.train()
    # Note: `ragged_shape` is computed once here but never used below.
    ragged_shape = P.arcs.shape().to(device)
    for batch_idx, batch in enumerate(dataloader):
        global_batch_idx_train += 1
        timestamp = datetime.now()
        time_waiting_for_batch += (timestamp - prev_timestamp).total_seconds()

        P.set_scores_stochastic_(model.P_scores)
        assert P.is_cpu
        assert P.requires_grad is True

        curr_batch_objf, curr_batch_frames, curr_batch_all_frames = \
            get_objf(batch, model, P, device, graph_compiler, True, optimizer)

        total_objf += curr_batch_objf
        total_frames += curr_batch_frames
        total_all_frames += curr_batch_all_frames

        if batch_idx % 10 == 0:
            logging.info(
                'batch {}, epoch {}/{} '
                'global average objf: {:.6f} over {} '
                'frames ({:.1f}% kept), current batch average objf: {:.6f} over {} frames ({:.1f}% kept) '
                'avg time waiting for batch {:.3f}s'.format(
                    batch_idx, current_epoch, num_epochs,
                    total_objf / total_frames, total_frames,
                    100.0 * total_frames / total_all_frames,
                    curr_batch_objf / (curr_batch_frames + 0.001),
                    curr_batch_frames,
                    100.0 * curr_batch_frames / curr_batch_all_frames,
                    time_waiting_for_batch / max(1, batch_idx)))
            tb_writer.add_scalar('train/global_average_objf',
                                 total_objf / total_frames,
                                 global_batch_idx_train)
            tb_writer.add_scalar('train/current_batch_average_objf',
                                 curr_batch_objf / (curr_batch_frames + 0.001),
                                 global_batch_idx_train)

        # if batch_idx >= 10:
        #     print("Exiting early to get profile info")
        #     sys.exit(0)

        if batch_idx > 0 and batch_idx % 200 == 0:
            total_valid_objf, total_valid_frames, total_valid_all_frames = get_validation_objf(
                dataloader=valid_dataloader,
                model=model,
                P=P,
                device=device,
                graph_compiler=graph_compiler)
            global_batch_idx_valid += 1
            model.train()
            logging.info(
                'Validation average objf: {:.6f} over {} frames ({:.1f}% kept)'
                .format(total_valid_objf / total_valid_frames,
                        total_valid_frames,
                        100.0 * total_valid_frames / total_valid_all_frames))
            tb_writer.add_scalar('train/global_valid_average_objf',
                                 total_valid_objf / total_valid_frames,
                                 global_batch_idx_valid)
        prev_timestamp = datetime.now()
    return total_objf / total_frames
def train_one_epoch(dataloader: torch.utils.data.DataLoader,
                    valid_dataloader: torch.utils.data.DataLoader,
                    model: AcousticModel,
                    ali_model: Optional[AcousticModel],
                    P: k2.Fsa,
                    device: torch.device,
                    graph_compiler: MmiTrainingGraphCompiler,
                    use_pruned_intersect: bool,
                    optimizer: torch.optim.Optimizer,
                    accum_grad: int,
                    den_scale: float,
                    att_rate: float,
                    current_epoch: int,
                    tb_writer: SummaryWriter,
                    num_epochs: int,
                    global_batch_idx_train: int,
                    world_size: int,
                    scaler: GradScaler):
    """One epoch training and validation.

    Args:
        dataloader: Training dataloader.
        valid_dataloader: Validation dataloader.
        model: Acoustic model to be trained.
        ali_model: Optional auxiliary alignment model, passed through to
            get_objf.
        P: An FSA representing the bigram phone LM.
        device: Training device, torch.device("cpu") or torch.device("cuda", device_id).
        graph_compiler: MMI training graph compiler.
        use_pruned_intersect: Whether to use pruned intersection; passed
            through to get_objf.
        optimizer: Training optimizer.
        accum_grad: Number of gradient accumulation steps.
        den_scale: Denominator scale in the MMI loss.
        att_rate: Attention loss rate; the final loss is
            att_rate * att_loss + (1 - att_rate) * other_loss.
        current_epoch: Current training epoch, for logging only.
        tb_writer: Tensorboard SummaryWriter.
        num_epochs: Total number of training epochs, for logging only.
        global_batch_idx_train: Global training batch index before this epoch,
            for logging only.
        world_size: Number of distributed workers; validation statistics are
            summed across workers when it is greater than 1.
        scaler: GradScaler for automatic mixed-precision training.

    Returns:
        A tuple of 3 scalars: (total_objf / total_frames, valid_average_objf,
        global_batch_idx_train)

        - `total_objf / total_frames` is the average training loss
        - `valid_average_objf` is the average validation loss
        - `global_batch_idx_train` is the global training batch index after
          this epoch
    """
    total_objf, total_frames, total_all_frames = 0., 0., 0.
    valid_average_objf = float('inf')
    time_waiting_for_batch = 0
    forward_count = 0
    prev_timestamp = datetime.now()

    model.train()
    for batch_idx, batch in enumerate(dataloader):
        forward_count += 1
        if forward_count == accum_grad:
            is_update = True
            forward_count = 0
        else:
            is_update = False

        global_batch_idx_train += 1
        timestamp = datetime.now()
        time_waiting_for_batch += (timestamp - prev_timestamp).total_seconds()

        if forward_count == 1 or accum_grad == 1:
            P.set_scores_stochastic_(model.module.P_scores)
            assert P.requires_grad is True

        curr_batch_objf, curr_batch_frames, curr_batch_all_frames = get_objf(
            batch=batch,
            model=model,
            ali_model=ali_model,
            P=P,
            device=device,
            graph_compiler=graph_compiler,
            use_pruned_intersect=use_pruned_intersect,
            is_training=True,
            is_update=is_update,
            accum_grad=accum_grad,
            den_scale=den_scale,
            att_rate=att_rate,
            tb_writer=tb_writer,
            global_batch_idx_train=global_batch_idx_train,
            optimizer=optimizer,
            scaler=scaler)

        total_objf += curr_batch_objf
        total_frames += curr_batch_frames
        total_all_frames += curr_batch_all_frames

        if batch_idx % 10 == 0:
            logging.info(
                'batch {}, epoch {}/{} '
                'global average objf: {:.6f} over {} '
                'frames ({:.1f}% kept), current batch average objf: {:.6f} over {} frames ({:.1f}% kept) '
                'avg time waiting for batch {:.3f}s'.format(
                    batch_idx, current_epoch, num_epochs,
                    total_objf / total_frames, total_frames,
                    100.0 * total_frames / total_all_frames,
                    curr_batch_objf / (curr_batch_frames + 0.001),
                    curr_batch_frames,
                    100.0 * curr_batch_frames / curr_batch_all_frames,
                    time_waiting_for_batch / max(1, batch_idx)))
            if tb_writer is not None:
                tb_writer.add_scalar('train/global_average_objf',
                                     total_objf / total_frames,
                                     global_batch_idx_train)
                tb_writer.add_scalar(
                    'train/current_batch_average_objf',
                    curr_batch_objf / (curr_batch_frames + 0.001),
                    global_batch_idx_train)

        # if batch_idx >= 10:
        #     print("Exiting early to get profile info")
        #     sys.exit(0)

        if batch_idx > 0 and batch_idx % 200 == 0:
            total_valid_objf, total_valid_frames, total_valid_all_frames = get_validation_objf(
                dataloader=valid_dataloader,
                model=model,
                ali_model=ali_model,
                P=P,
                device=device,
                graph_compiler=graph_compiler,
                use_pruned_intersect=use_pruned_intersect,
                scaler=scaler)
            if world_size > 1:
                s = torch.tensor([
                    total_valid_objf, total_valid_frames,
                    total_valid_all_frames
                ]).to(device)
                dist.all_reduce(s, op=dist.ReduceOp.SUM)
                total_valid_objf, total_valid_frames, total_valid_all_frames = s.cpu().tolist()
            valid_average_objf = total_valid_objf / total_valid_frames
            model.train()
            logging.info(
                'Validation average objf: {:.6f} over {} frames ({:.1f}% kept)'
                .format(valid_average_objf, total_valid_frames,
                        100.0 * total_valid_frames / total_valid_all_frames))
            if tb_writer is not None:
                tb_writer.add_scalar('train/global_valid_average_objf',
                                     valid_average_objf,
                                     global_batch_idx_train)
                model.module.write_tensorboard_diagnostics(
                    tb_writer, global_step=global_batch_idx_train)
        prev_timestamp = datetime.now()
    return total_objf / total_frames, valid_average_objf, global_batch_idx_train
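# The `world_size > 1` branch above sums the validation statistics across
# ranks with dist.all_reduce. A standalone sketch of that pattern, usable
# inside any initialized process group; the helper name is illustrative.
def sum_scalars_across_ranks(*values: float, device: torch.device) -> list:
    """Sum Python scalars over all ranks and return them as a list.

    Mirrors the validation aggregation above: pack the scalars into one
    tensor, all-reduce with SUM, then unpack on every rank.
    """
    t = torch.tensor(values, dtype=torch.float64, device=device)
    dist.all_reduce(t, op=dist.ReduceOp.SUM)
    return t.cpu().tolist()

# Usage (assumes dist.init_process_group was already called):
#   total_valid_objf, total_valid_frames, total_valid_all_frames = \
#       sum_scalars_across_ranks(total_valid_objf, total_valid_frames,
#                                total_valid_all_frames, device=device)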
def train_one_epoch(dataloader: torch.utils.data.DataLoader,
                    valid_dataloader: torch.utils.data.DataLoader,
                    model: AcousticModel,
                    device: torch.device,
                    graph_compiler: MmiTrainingGraphCompiler,
                    optimizer: torch.optim.Optimizer,
                    current_epoch: int,
                    tb_writer: SummaryWriter,
                    num_epochs: int,
                    global_batch_idx_train: int):
    total_objf, total_frames, total_all_frames = 0., 0., 0.
    valid_average_objf = float('inf')
    time_waiting_for_batch = 0
    prev_timestamp = datetime.now()

    model.train()
    for batch_idx, batch in enumerate(dataloader):
        global_batch_idx_train += 1
        timestamp = datetime.now()
        time_waiting_for_batch += (timestamp - prev_timestamp).total_seconds()

        curr_batch_objf, curr_batch_frames, curr_batch_all_frames = get_objf(
            batch=batch,
            model=model,
            device=device,
            graph_compiler=graph_compiler,
            is_training=True,
            tb_writer=tb_writer,
            global_batch_idx_train=global_batch_idx_train,
            optimizer=optimizer)

        total_objf += curr_batch_objf
        total_frames += curr_batch_frames
        total_all_frames += curr_batch_all_frames

        if batch_idx % 10 == 0:
            # The original code also print()-ed the identical message right
            # after this logging call; the redundant duplicate was dropped.
            logging.info(
                'batch {}, epoch {}/{} '
                'global average objf: {:.6f} over {} '
                'frames ({:.1f}% kept), current batch average objf: {:.6f} over {} frames ({:.1f}% kept) '
                'avg time waiting for batch {:.3f}s'.format(
                    batch_idx, current_epoch, num_epochs,
                    total_objf / total_frames, total_frames,
                    100.0 * total_frames / total_all_frames,
                    curr_batch_objf / (curr_batch_frames + 0.001),
                    curr_batch_frames,
                    100.0 * curr_batch_frames / curr_batch_all_frames,
                    time_waiting_for_batch / max(1, batch_idx)))
            tb_writer.add_scalar('train/global_average_objf',
                                 total_objf / total_frames,
                                 global_batch_idx_train)
            tb_writer.add_scalar('train/current_batch_average_objf',
                                 curr_batch_objf / (curr_batch_frames + 0.001),
                                 global_batch_idx_train)

        # if batch_idx >= 10:
        #     print("Exiting early to get profile info")
        #     sys.exit(0)

        if batch_idx > 0 and batch_idx % 200 == 0:
            total_valid_objf, total_valid_frames, total_valid_all_frames = get_validation_objf(
                dataloader=valid_dataloader,
                model=model,
                device=device,
                graph_compiler=graph_compiler)
            valid_average_objf = total_valid_objf / total_valid_frames
            model.train()
            logging.info(
                'Validation average objf: {:.6f} over {} frames ({:.1f}% kept)'
                .format(valid_average_objf,
                        total_valid_frames,
                        100.0 * total_valid_frames / total_valid_all_frames))
            tb_writer.add_scalar('train/global_valid_average_objf',
                                 valid_average_objf,
                                 global_batch_idx_train)
            model.write_tensorboard_diagnostics(
                tb_writer, global_step=global_batch_idx_train)
        prev_timestamp = datetime.now()
    return total_objf / total_frames, valid_average_objf, global_batch_idx_train