def main():
    # parse arguments
    args = parse_agrs()

    # fix random seeds
    torch.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(args.seed)

    # create tokenizer
    tokenizer = Tokenizer(args)

    # create data loaders
    train_dataloader = R2DataLoader(args, tokenizer, split='train', shuffle=True)
    val_dataloader = R2DataLoader(args, tokenizer, split='val', shuffle=False)
    test_dataloader = R2DataLoader(args, tokenizer, split='test', shuffle=False)

    # build model architecture
    model = R2GenModel(args, tokenizer)

    # get function handles of loss and metrics
    criterion = compute_loss
    metrics = compute_scores

    # build optimizer and learning rate scheduler
    optimizer = build_optimizer(args, model)
    lr_scheduler = build_lr_scheduler(args, optimizer)

    # build trainer and start training
    trainer = Trainer(model, criterion, metrics, optimizer, args, lr_scheduler,
                      train_dataloader, val_dataloader, test_dataloader)
    trainer.train()
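
# NOTE: parse_agrs() is defined elsewhere in the project and is not shown in
# this listing. The sketch below is only a hypothetical, minimal version of
# what it might look like: main() above only implies a --seed argument, so
# every other flag here is an illustrative placeholder, not the project's
# actual option list.
import argparse

def parse_agrs():
    parser = argparse.ArgumentParser(description='Training configuration (sketch).')
    parser.add_argument('--seed', type=int, default=42,
                        help='random seed (default is a placeholder)')
    # Placeholder flags; the real parser defines the options consumed by
    # Tokenizer, R2DataLoader, R2GenModel, the optimizer and the Trainer.
    parser.add_argument('--batch_size', type=int, default=16)
    parser.add_argument('--epochs', type=int, default=100)
    return parser.parse_args()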
def train(hparams, distributed_run=False, rank=0, n_gpus=None):
    """Training and validation, logging results to tensorboard and stdout."""
    if distributed_run:
        assert n_gpus is not None

    torch.manual_seed(hparams.seed)
    torch.cuda.manual_seed(hparams.seed)

    model = load_model(hparams, distributed_run)
    optimizer = build_optimizer(model, hparams)
    lr_scheduler = build_scheduler(optimizer, hparams)
    criterion = OverallLoss(hparams)

    if hparams.fp16_run:
        from apex import amp
        model, optimizer = amp.initialize(model, optimizer, opt_level="O2")

    if distributed_run:
        model = apply_gradient_allreduce(model)

    logger = prepare_directories_and_logger(hparams.output_dir, hparams.log_dir, rank)
    copyfile(hparams.path, os.path.join(hparams.output_dir, 'hparams.yaml'))

    train_loader, valset, collate_fn = prepare_dataloaders(hparams, distributed_run)

    # Load checkpoint if one exists
    iteration = 0
    epoch_offset = 0
    if hparams.checkpoint is not None:
        if hparams.warm_start:
            model = warm_start_model(hparams.checkpoint, model, hparams.ignore_layers)
        else:
            model, optimizer, lr_scheduler, mmi_criterion, iteration = load_checkpoint(
                hparams.checkpoint, model, optimizer, lr_scheduler, criterion,
                hparams.restore_scheduler_state)
            iteration += 1  # next iteration is iteration + 1
            epoch_offset = max(0, int(iteration / len(train_loader)))

    model.train()
    is_overflow = False

    # ================ MAIN TRAINING LOOP ===================
    for epoch in range(epoch_offset, hparams.epochs):
        print("Epoch: {}".format(epoch))
        for i, batch in enumerate(train_loader):
            torch.cuda.empty_cache()
            start = time.perf_counter()
            model.zero_grad()

            inputs, alignments, inputs_ctc = model.parse_batch(batch)
            outputs, decoder_outputs = model(inputs)
            losses = criterion(outputs, inputs, alignments=alignments,
                               inputs_ctc=inputs_ctc, decoder_outputs=decoder_outputs)

            if hparams.use_mmi and hparams.use_gaf \
                    and i % gradient_adaptive_factor.UPDATE_GAF_EVERY_N_STEP == 0:
                mi_loss = losses["mi/loss"]
                overall_loss = losses["overall/loss"]
                gaf = calc_gaf(model, optimizer, overall_loss, mi_loss, hparams.max_gaf)
                losses["mi/loss"] = gaf * mi_loss
                losses["overall/loss"] = overall_loss - mi_loss * (1 - gaf)

            reduced_losses = {
                key: reduce_loss(value, distributed_run, n_gpus)
                for key, value in losses.items()
            }
            loss = losses["overall/loss"]

            if hparams.fp16_run:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    amp.master_params(optimizer), hparams.grad_clip_thresh)
                is_overflow = math.isnan(grad_norm)
            else:
                loss.backward()
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    model.parameters(), hparams.grad_clip_thresh)

            optimizer.step()

            if not is_overflow and rank == 0:
                learning_rate = lr_scheduler.get_last_lr()[0]
                duration = time.perf_counter() - start
                print("Iteration {}: overall loss {:.6f} Grad Norm {:.6f} {:.2f}s/it LR {:.3E}"
                      .format(iteration, reduced_losses["overall/loss"],
                              grad_norm, duration, learning_rate))
                logger.log_training(reduced_losses, grad_norm, learning_rate,
                                    duration, iteration)

            if not is_overflow and (iteration % hparams.iters_per_checkpoint == 0):
                val_loss = validate(model, criterion, valset, iteration,
                                    hparams.batch_size, collate_fn, logger,
                                    distributed_run, rank, n_gpus)
                if rank == 0:
                    checkpoint = os.path.join(hparams.output_dir,
                                              "checkpoint_{}".format(iteration))
                    save_checkpoint(model, optimizer, lr_scheduler, criterion,
                                    iteration, hparams, checkpoint)

            iteration += 1

            if hparams.lr_scheduler == SchedulerTypes.cyclic:
                lr_scheduler.step()

        if not hparams.lr_scheduler == SchedulerTypes.cyclic:
            # TODO: for the plateau scheduler the validation loss should
            # ideally be recomputed at the end of every epoch
            scheduler_args = () if hparams.lr_scheduler != SchedulerTypes.plateau else (val_loss,)
            lr_scheduler.step(*scheduler_args)
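
# NOTE: calc_gaf() is not shown in this listing. The following is a hedged
# sketch of a gradient adaptive factor, assuming (as the call site suggests)
# that it scales the MI loss so that its gradient norm never dominates the
# gradient of the overall loss, clipped to max_gaf. The name and signature
# come from the call above; the body is an assumption, not the project's code.
import torch

def calc_gaf(model, optimizer, overall_loss, mi_loss, max_gaf):
    # optimizer is accepted to match the call site but unused in this sketch
    params = [p for p in model.parameters() if p.requires_grad]

    # Gradient norm of the overall loss; keep the graph for the later backward().
    grads_main = torch.autograd.grad(overall_loss, params,
                                     retain_graph=True, allow_unused=True)
    norm_main = torch.norm(torch.stack(
        [g.norm() for g in grads_main if g is not None]))

    # Gradient norm of the MI loss.
    grads_mi = torch.autograd.grad(mi_loss, params,
                                   retain_graph=True, allow_unused=True)
    mi_norms = [g.norm() for g in grads_mi if g is not None]
    if not mi_norms:
        return torch.tensor(1.0)
    norm_mi = torch.norm(torch.stack(mi_norms))

    # Factor that caps the MI gradient at the size of the main-loss gradient.
    gaf = (norm_main / (norm_mi + 1e-8)).clamp(max=max_gaf)
    return gaf.detach()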
def train(hparams, distributed_run=False, rank=0, n_gpus=None):
    """Training and validation, logging results to tensorboard and stdout."""
    if distributed_run:
        assert n_gpus is not None

    torch.manual_seed(hparams.seed)
    torch.cuda.manual_seed(hparams.seed)

    model = load_model(hparams, distributed_run)
    criterion = OverallLoss(hparams)
    if criterion.mmi_criterion is not None:
        parameters = chain(model.parameters(), criterion.mmi_criterion.parameters())
    else:
        parameters = model.parameters()
    optimizer = build_optimizer(parameters, hparams)
    lr_scheduler = build_scheduler(optimizer, hparams)

    if distributed_run:
        model = apply_gradient_allreduce(model)

    scaler = amp.GradScaler(enabled=hparams.fp16_run)  # native AMP (torch.cuda.amp)

    logger = prepare_directories_and_logger(hparams.output_dir, hparams.log_dir, rank)
    copyfile(hparams.path, os.path.join(hparams.output_dir, 'hparams.yaml'))

    train_loader, valset, collate_fn = prepare_dataloaders(hparams, distributed_run)

    # Load checkpoint if one exists
    iteration = 0
    epoch_offset = 0
    if hparams.checkpoint is not None:
        if hparams.warm_start:
            model = warm_start_model(hparams.checkpoint, model, hparams.ignore_layers,
                                     hparams.ignore_mismatched_layers)
        else:
            model, optimizer, lr_scheduler, mmi_criterion, iteration = load_checkpoint(
                hparams.checkpoint, model, optimizer, lr_scheduler, criterion,
                hparams.restore_scheduler_state)
            iteration += 1  # next iteration is iteration + 1
            epoch_offset = max(0, int(iteration / len(train_loader)))

    model.train()

    # ================ MAIN TRAINING LOOP ===================
    for epoch in range(epoch_offset, hparams.epochs):
        print("Epoch: {}".format(epoch))
        for i, batch in enumerate(train_loader):
            start = time.perf_counter()
            model.zero_grad()

            inputs, alignments, inputs_ctc = model.parse_batch(batch)
            with amp.autocast(enabled=hparams.fp16_run):
                outputs, decoder_outputs = model(inputs)
                losses = criterion(outputs, inputs, alignments=alignments,
                                   inputs_ctc=inputs_ctc, decoder_outputs=decoder_outputs)

            if hparams.use_mmi and hparams.use_gaf \
                    and i % gradient_adaptive_factor.UPDATE_GAF_EVERY_N_STEP == 0:
                mi_loss = losses["mi/loss"]
                overall_loss = losses["overall/loss"]
                gaf = calc_gaf(model, optimizer, overall_loss, mi_loss, hparams.max_gaf)
                losses["mi/loss"] = gaf * mi_loss
                losses["overall/loss"] = overall_loss - mi_loss * (1 - gaf)

            reduced_losses = {
                key: reduce_loss(value, distributed_run, n_gpus)
                for key, value in losses.items()
            }
            loss = losses["overall/loss"]

            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            grad_norm = torch.nn.utils.clip_grad_norm_(
                model.parameters(), hparams.grad_clip_thresh)
            scaler.step(optimizer)
            scaler.update()

            if rank == 0:
                learning_rate = lr_scheduler.get_last_lr()[0]
                duration = time.perf_counter() - start
                print("Iteration {} ({} epoch): overall loss {:.6f} Grad Norm {:.6f} {:.2f}s/it LR {:.3E}"
                      .format(iteration, epoch, reduced_losses["overall/loss"],
                              grad_norm, duration, learning_rate))
                # Do not log the gradient norm when the scaler detected an overflow
                grad_norm = None if torch.isnan(grad_norm) or torch.isinf(grad_norm) else grad_norm
                logger.log_training(reduced_losses, grad_norm, learning_rate,
                                    duration, iteration)

            if iteration % hparams.iters_per_checkpoint == 0:
                validate(model, criterion, valset, iteration, hparams.batch_size,
                         collate_fn, logger, distributed_run, rank, n_gpus)
                if rank == 0:
                    checkpoint = os.path.join(hparams.output_dir,
                                              "checkpoint_{}".format(iteration))
                    save_checkpoint(model, optimizer, lr_scheduler, criterion,
                                    iteration, hparams, checkpoint)

            iteration += 1

            if hparams.lr_scheduler == SchedulerTypes.cyclic:
                lr_scheduler.step()

        if hparams.lr_scheduler != SchedulerTypes.cyclic:
            if hparams.lr_scheduler == SchedulerTypes.plateau:
                lr_scheduler.step(validate(model, criterion, valset, iteration,
                                           hparams.batch_size, collate_fn, logger,
                                           distributed_run, rank, n_gpus))
            else:
                lr_scheduler.step()
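
# NOTE: reduce_loss() is not shown in this listing. The sketch below assumes
# it averages each scalar loss across ranks for logging when running
# distributed, and simply detaches it otherwise; the name and signature come
# from the call sites above, the body is an assumption.
import torch.distributed as dist

def reduce_loss(value, distributed_run, n_gpus):
    if not distributed_run:
        return value.detach()
    reduced = value.detach().clone()
    dist.all_reduce(reduced, op=dist.ReduceOp.SUM)  # sum the scalar over all ranks
    return reduced / n_gpus                         # average per GPU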