def train(args, model, tokenizer):
    """ Train the model """
    if args.local_rank in [-1, 0]:
        tb_dir = os.path.join("tensorboard", args.model_name)
        os.makedirs(tb_dir, exist_ok=True)
        tb_writer = SummaryWriter(tb_dir)

    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ['bias', 'LayerNorm', 'gamma', 'beta']
    optimizer_grouped_parameters = [
        {
            'params': [p for n, p in model.named_parameters()
                       if not any(nd in n for nd in no_decay)],
            'weight_decay': args.weight_decay
        },
        {
            'params': [p for n, p in model.named_parameters()
                       if any(nd in n for nd in no_decay)],
            'weight_decay': 0.0
        },
    ]
    optimizer = FusedLAMB(optimizer_grouped_parameters,
                          lr=args.learning_rate,
                          eps=args.adam_epsilon)
    scheduler = WarmupLinearSchedule(optimizer=optimizer,
                                     warmup_steps=args.warmup_steps,
                                     t_total=args.num_steps)

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
        model, optimizer = amp.initialize(models=model,
                                          optimizers=optimizer,
                                          opt_level=args.fp16_opt_level,
                                          cast_model_outputs=torch.float16)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = DDP(model,
                    message_size=250000000,
                    gradient_predivide_factor=torch.distributed.get_world_size())

    train_dataset = LMDataset(corpus_path=args.corpus_path,
                              tokenizer=tokenizer,
                              local_rank=args.local_rank,
                              seq_len=args.max_seq_length,
                              vocab_size=args.vocab_size,
                              mask_prob=args.mask_prob)

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Total optimization steps = %d", args.num_steps)
    logger.info("  Instantaneous batch size per GPU = %d", args.train_batch_size)
    logger.info("  Total train batch size (w. parallel, distributed & accumulation) = %d",
                args.train_batch_size * args.gradient_accumulation_steps *
                (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
    logger.info("  Gradient Accumulation steps = %d", args.gradient_accumulation_steps)

    global_step = 0
    iters = 0
    model.zero_grad()
    model.train()
    set_seed(args)  # Added here for reproducibility (even between python 2 and 3)

    while True:
        train_dataset.gen_segment()
        train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 \
            else DistributedSampler(train_dataset)
        train_dataloader = DataLoader(train_dataset,
                                      sampler=train_sampler,
                                      batch_size=args.train_batch_size,
                                      num_workers=4,
                                      pin_memory=True)
        epoch_iterator = tqdm(
            train_dataloader,
            desc="Training (X iter) (XX / XX Steps) (Total Loss=X.X) "
                 "(Generator Loss=X.X) (Discriminator Loss=X.X)",
            disable=args.local_rank not in [-1, 0])
        tr_loss = 0.0

        for step, batch in enumerate(epoch_iterator):
            batch = tuple(t.to(args.device) for t in batch)
            input_ids, input_mask, segment_ids, lm_label_ids = batch

            gen_loss, disc_loss = model(input_ids, segment_ids, input_mask, lm_label_ids)
            loss = gen_loss + disc_loss
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            tr_loss += loss.item()
            mean_loss = tr_loss * args.gradient_accumulation_steps / (step + 1)

            if (step + 1) % args.gradient_accumulation_steps == 0:
                scheduler.step()  # learning rate warmup
                optimizer.step()
                # reset gradients by detaching them instead of calling zero_grad()
                for param in model.parameters():
                    param.grad = None
                global_step += 1

                # the discriminator loss is reported scaled down by a factor of 50
                epoch_iterator.set_description(
                    "Training (%d iter) (%d / %d Steps) (Mean Loss=%2.5f) "
                    "(Generator Loss=%2.5f) (Discriminator Loss=%2.5f)"
                    % (iters, global_step, args.num_steps, mean_loss,
                       gen_loss, disc_loss / 50.0))

                if args.local_rank in [-1, 0] and args.logging_steps > 0 \
                        and global_step % args.logging_steps == 0:
                    tb_writer.add_scalar('lr', scheduler.get_lr()[0], global_step)
                    tb_writer.add_scalar('Mean_Loss', mean_loss, global_step)
                    tb_writer.add_scalar('Gen_Loss', gen_loss, global_step)
                    tb_writer.add_scalar('Disc_Loss', disc_loss / 50.0, global_step)

                if args.local_rank in [-1, 0] and args.save_steps > 0 \
                        and global_step % args.save_steps == 0:
                    model_to_save = model.module if hasattr(model, 'module') else model
                    model_checkpoint = os.path.join(
                        args.output_dir,
                        args.model_name + '_' + str(global_step) + '.bin')
                    model_layer_checkpoint = os.path.join(
                        args.output_dir,
                        args.model_name + '_' + str(global_step) + '_disc.bin')
                    torch.save(model_to_save.state_dict(), model_checkpoint)
                    torch.save(model_to_save.discriminator.model.state_dict(),
                               model_layer_checkpoint)
                    logger.info("Saving model checkpoint to %s", args.output_dir)

            if args.num_steps > 0 and global_step == args.num_steps:
                epoch_iterator.close()
                break

        if args.num_steps > 0 and global_step == args.num_steps:
            epoch_iterator.close()
            break
        iters += 1

    if args.local_rank in [-1, 0]:
        model_to_save = model.module if hasattr(model, 'module') else model
        model_checkpoint = os.path.join(
            args.output_dir, args.model_name + '_' + str(global_step) + '.bin')
        model_layer_checkpoint = os.path.join(
            args.output_dir, args.model_name + '_' + str(global_step) + '_disc.bin')
        torch.save(model_to_save.state_dict(), model_checkpoint)
        torch.save(model_to_save.discriminator.model.state_dict(), model_layer_checkpoint)
        logger.info("Saving model checkpoint to %s", args.output_dir)

    logger.info("End Training!")
    if args.local_rank in [-1, 0]:
        tb_writer.close()  # tb_writer only exists on the main process
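# The train() function above is an excerpt; the module-level imports and the
# set_seed() helper it relies on are not part of this snippet. A minimal sketch
# of what it appears to assume follows. The exact import paths (apex's FusedLAMB
# and DistributedDataParallel, the pytorch_transformers scheduler, the torch
# TensorBoard writer) and the body of set_seed() are assumptions, not taken
# from this file.
import os
import random

import numpy as np
import torch
from torch.utils.data import DataLoader, RandomSampler
from torch.utils.data.distributed import DistributedSampler
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from apex.optimizers import FusedLAMB
from apex.parallel import DistributedDataParallel as DDP
from pytorch_transformers import WarmupLinearSchedule


def set_seed(args):
    # Hypothetical helper: seed Python, NumPy and PyTorch from args.seed
    # so that repeated runs are reproducible.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)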
loss, _ = metrics.compute(yhat, y, w=weight, m=loss_mask, plot_m=qm, x=x)
loss /= config.grad_acc

if config.attention_band is not None:
    with amp.scale_loss(loss, opt) as scaled_loss:
        scaled_loss.backward()
else:
    loss.backward()

if (n_batch + 1) % config.grad_acc == 0:
    torch.nn.utils.clip_grad_norm_(model.parameters(), 5)
    opt.step()
    opt.zero_grad()
    ready_for_lr = True

metrics_puppi.compute(p, y, w=weight, m=loss_mask, plot_m=qm)
avg_loss_tensor += loss

if config.beta:
    p, q = yhat[:, :, 0], yhat[:, :, 1]
    # logger.info(' '.join([str(x) for x in [p.max(), p.min(), q.max(), q.min()]]))
    yhat = p / (p + q + 1e-5)

score = t2n(torch.clamp(yhat.squeeze(-1), 0, 1))
charged_mask = ~batch['neutral_mask']
score[charged_mask] = batch['y'][charged_mask]
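# t2n() used above is defined elsewhere in this repo. Judging from how score
# is consumed, it looks like the usual tensor-to-numpy helper; a hypothetical
# one-liner consistent with that usage would be:
def t2n(t):
    # detach from the autograd graph, move to host memory, convert to numpy
    return t.detach().cpu().numpy()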
def main():
    logging.configure_logger('RNNT')
    logging.log_start(logging.constants.INIT_START)

    args = parse_args()

    assert torch.cuda.is_available()
    assert args.prediction_frequency is None or args.prediction_frequency % args.log_frequency == 0

    torch.backends.cudnn.benchmark = args.cudnn_benchmark

    # set up distributed training
    multi_gpu = int(os.environ.get('WORLD_SIZE', 1)) > 1
    if multi_gpu:
        torch.cuda.set_device(args.local_rank)
        dist.init_process_group(backend='nccl', init_method='env://')
        world_size = dist.get_world_size()
        print_once(f'Distributed training with {world_size} GPUs\n')
    else:
        world_size = 1

    if args.seed is not None:
        logging.log_event(logging.constants.SEED, value=args.seed)
        torch.manual_seed(args.seed + args.local_rank)
        np.random.seed(args.seed + args.local_rank)
        random.seed(args.seed + args.local_rank)
        # np_rng is used for buckets generation, and needs the same seed on every worker
        np_rng = np.random.default_rng(seed=args.seed)

    init_log(args)

    cfg = config.load(args.model_config)
    config.apply_duration_flags(cfg, args.max_duration)

    assert args.grad_accumulation_steps >= 1
    assert args.batch_size % args.grad_accumulation_steps == 0, \
        f'{args.batch_size} % {args.grad_accumulation_steps} != 0'
    logging.log_event(logging.constants.GRADIENT_ACCUMULATION_STEPS,
                      value=args.grad_accumulation_steps)
    batch_size = args.batch_size // args.grad_accumulation_steps

    logging.log_event(logging.constants.SUBMISSION_BENCHMARK, value=logging.constants.RNNT)
    logging.log_event(logging.constants.SUBMISSION_ORG, value='my-organization')
    logging.log_event(logging.constants.SUBMISSION_DIVISION, value=logging.constants.CLOSED)  # closed or open
    logging.log_event(logging.constants.SUBMISSION_STATUS, value=logging.constants.ONPREM)  # on-prem/cloud/research
    logging.log_event(logging.constants.SUBMISSION_PLATFORM, value='my platform')

    logging.log_end(logging.constants.INIT_STOP)
    if multi_gpu:
        torch.distributed.barrier()
    logging.log_start(logging.constants.RUN_START)
    if multi_gpu:
        torch.distributed.barrier()

    print_once('Setting up datasets...')
    (
        train_dataset_kw,
        train_features_kw,
        train_splicing_kw,
        train_specaugm_kw,
    ) = config.input(cfg, 'train')
    (
        val_dataset_kw,
        val_features_kw,
        val_splicing_kw,
        val_specaugm_kw,
    ) = config.input(cfg, 'val')

    logging.log_event(logging.constants.DATA_TRAIN_MAX_DURATION,
                      value=train_dataset_kw['max_duration'])
    logging.log_event(logging.constants.DATA_SPEED_PERTURBATON_MAX,
                      value=train_dataset_kw['speed_perturbation']['max_rate'])
    logging.log_event(logging.constants.DATA_SPEED_PERTURBATON_MIN,
                      value=train_dataset_kw['speed_perturbation']['min_rate'])
    logging.log_event(logging.constants.DATA_SPEC_AUGMENT_FREQ_N,
                      value=train_specaugm_kw['freq_masks'])
    logging.log_event(logging.constants.DATA_SPEC_AUGMENT_FREQ_MIN,
                      value=train_specaugm_kw['min_freq'])
    logging.log_event(logging.constants.DATA_SPEC_AUGMENT_FREQ_MAX,
                      value=train_specaugm_kw['max_freq'])
    logging.log_event(logging.constants.DATA_SPEC_AUGMENT_TIME_N,
                      value=train_specaugm_kw['time_masks'])
    logging.log_event(logging.constants.DATA_SPEC_AUGMENT_TIME_MIN,
                      value=train_specaugm_kw['min_time'])
    logging.log_event(logging.constants.DATA_SPEC_AUGMENT_TIME_MAX,
                      value=train_specaugm_kw['max_time'])
    logging.log_event(logging.constants.GLOBAL_BATCH_SIZE,
                      value=batch_size * world_size * args.grad_accumulation_steps)

    tokenizer_kw = config.tokenizer(cfg)
    tokenizer = Tokenizer(**tokenizer_kw)

    class PermuteAudio(torch.nn.Module):
        def forward(self, x):
            return (x[0].permute(2, 0, 1), *x[1:])

    train_augmentations = torch.nn.Sequential(
        train_specaugm_kw and features.SpecAugment(optim_level=args.amp, **train_specaugm_kw)
        or torch.nn.Identity(),
        features.FrameSplicing(optim_level=args.amp, **train_splicing_kw),
        PermuteAudio(),
    )
    val_augmentations = torch.nn.Sequential(
        val_specaugm_kw and features.SpecAugment(optim_level=args.amp, **val_specaugm_kw)
        or torch.nn.Identity(),
        features.FrameSplicing(optim_level=args.amp, **val_splicing_kw),
        PermuteAudio(),
    )

    logging.log_event(logging.constants.DATA_TRAIN_NUM_BUCKETS, value=args.num_buckets)

    if args.num_buckets is not None:
        sampler = dali_sampler.BucketingSampler(
            args.num_buckets,
            batch_size,
            world_size,
            args.epochs,
            np_rng
        )
    else:
        sampler = dali_sampler.SimpleSampler()

    train_loader = DaliDataLoader(gpu_id=args.local_rank,
                                  dataset_path=args.dataset_dir,
                                  config_data=train_dataset_kw,
                                  config_features=train_features_kw,
                                  json_names=args.train_manifests,
                                  batch_size=batch_size,
                                  sampler=sampler,
                                  grad_accumulation_steps=args.grad_accumulation_steps,
                                  pipeline_type="train",
                                  device_type=args.dali_device,
                                  tokenizer=tokenizer)

    val_loader = DaliDataLoader(gpu_id=args.local_rank,
                                dataset_path=args.dataset_dir,
                                config_data=val_dataset_kw,
                                config_features=val_features_kw,
                                json_names=args.val_manifests,
                                batch_size=args.val_batch_size,
                                sampler=dali_sampler.SimpleSampler(),
                                pipeline_type="val",
                                device_type=args.dali_device,
                                tokenizer=tokenizer)

    train_feat_proc = train_augmentations
    val_feat_proc = val_augmentations
    train_feat_proc.cuda()
    val_feat_proc.cuda()

    steps_per_epoch = len(train_loader) // args.grad_accumulation_steps

    logging.log_event(logging.constants.TRAIN_SAMPLES, value=train_loader.dataset_size)
    logging.log_event(logging.constants.EVAL_SAMPLES, value=val_loader.dataset_size)

    # set up the model
    rnnt_config = config.rnnt(cfg)
    logging.log_event(logging.constants.MODEL_WEIGHTS_INITIALIZATION_SCALE,
                      value=args.weights_init_scale)
    if args.weights_init_scale is not None:
        rnnt_config['weights_init_scale'] = args.weights_init_scale
    if args.hidden_hidden_bias_scale is not None:
        rnnt_config['hidden_hidden_bias_scale'] = args.hidden_hidden_bias_scale
    model = RNNT(n_classes=tokenizer.num_labels + 1, **rnnt_config)
    model.cuda()
    blank_idx = tokenizer.num_labels
    loss_fn = RNNTLoss(blank_idx=blank_idx)
    logging.log_event(logging.constants.EVAL_MAX_PREDICTION_SYMBOLS,
                      value=args.max_symbol_per_sample)
    greedy_decoder = RNNTGreedyDecoder(blank_idx=blank_idx,
                                       max_symbol_per_sample=args.max_symbol_per_sample)

    print_once(f'Model size: {num_weights(model) / 10**6:.1f}M params\n')

    opt_eps = 1e-9
    logging.log_event(logging.constants.OPT_NAME, value='lamb')
    logging.log_event(logging.constants.OPT_BASE_LR, value=args.lr)
    logging.log_event(logging.constants.OPT_LAMB_EPSILON, value=opt_eps)
    logging.log_event(logging.constants.OPT_LAMB_LR_DECAY_POLY_POWER, value=args.lr_exp_gamma)
    logging.log_event(logging.constants.OPT_LR_WARMUP_EPOCHS, value=args.warmup_epochs)
    logging.log_event(logging.constants.OPT_LAMB_LR_HOLD_EPOCHS, value=args.hold_epochs)
    logging.log_event(logging.constants.OPT_LAMB_BETA_1, value=args.beta1)
    logging.log_event(logging.constants.OPT_LAMB_BETA_2, value=args.beta2)
    logging.log_event(logging.constants.OPT_GRADIENT_CLIP_NORM, value=args.clip_norm)
    logging.log_event(logging.constants.OPT_LR_ALT_DECAY_FUNC, value=True)
    logging.log_event(logging.constants.OPT_LR_ALT_WARMUP_FUNC, value=True)
    logging.log_event(logging.constants.OPT_LAMB_LR_MIN, value=args.min_lr)
    logging.log_event(logging.constants.OPT_WEIGHT_DECAY, value=args.weight_decay)

    # optimization
    kw = {'params': model.param_groups(args.lr),
          'lr': args.lr,
          'weight_decay': args.weight_decay}

    initial_lrs = [group['lr'] for group in kw['params']]
    print_once(f'Starting with LRs: {initial_lrs}')

    optimizer = FusedLAMB(betas=(args.beta1, args.beta2), eps=opt_eps, **kw)

    adjust_lr = lambda step, epoch: lr_policy(
        step, epoch, initial_lrs, optimizer, steps_per_epoch=steps_per_epoch,
        warmup_epochs=args.warmup_epochs, hold_epochs=args.hold_epochs,
        min_lr=args.min_lr, exp_gamma=args.lr_exp_gamma)

    if args.amp:
        model, optimizer = amp.initialize(
            models=model,
            optimizers=optimizer,
            opt_level='O1',
            max_loss_scale=512.0)

    if args.ema > 0:
        ema_model = copy.deepcopy(model).cuda()
    else:
        ema_model = None
    logging.log_event(logging.constants.MODEL_EVAL_EMA_FACTOR, value=args.ema)

    if multi_gpu:
        model = DistributedDataParallel(model)

    # load checkpoint
    meta = {'best_wer': 10**6, 'start_epoch': 0}
    checkpointer = Checkpointer(args.output_dir, 'RNN-T', args.keep_milestones, args.amp)
    if args.resume:
        args.ckpt = checkpointer.last_checkpoint() or args.ckpt
    if args.ckpt is not None:
        checkpointer.load(args.ckpt, model, ema_model, optimizer, meta)
    start_epoch = meta['start_epoch']
    best_wer = meta['best_wer']
    last_wer = meta['best_wer']
    epoch = 1
    step = start_epoch * steps_per_epoch + 1

    # training loop
    model.train()
    for epoch in range(start_epoch + 1, args.epochs + 1):
        logging.log_start(logging.constants.BLOCK_START,
                          metadata=dict(first_epoch_num=epoch, epoch_count=1))
        logging.log_start(logging.constants.EPOCH_START, metadata=dict(epoch_num=epoch))

        epoch_utts = 0
        accumulated_batches = 0
        epoch_start_time = time.time()

        for batch in train_loader:
            if accumulated_batches == 0:
                adjust_lr(step, epoch)
                optimizer.zero_grad()
                step_utts = 0
                step_start_time = time.time()
                all_feat_lens = []

            audio, audio_lens, txt, txt_lens = batch
            feats, feat_lens = train_feat_proc([audio, audio_lens])
            all_feat_lens += feat_lens

            log_probs, log_prob_lens = model(feats, feat_lens, txt, txt_lens)
            loss = loss_fn(log_probs[:, :log_prob_lens.max().item()],
                           log_prob_lens, txt, txt_lens)
            loss /= args.grad_accumulation_steps
            del log_probs, log_prob_lens

            if torch.isnan(loss).any():
                print_once('WARNING: loss is NaN; skipping update')
            else:
                if args.amp:
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()
                loss_item = loss.item()
            del loss

            step_utts += batch[0].size(0) * world_size
            epoch_utts += batch[0].size(0) * world_size
            accumulated_batches += 1

            if accumulated_batches % args.grad_accumulation_steps == 0:
                if args.clip_norm is not None:
                    torch.nn.utils.clip_grad_norm_(
                        getattr(model, 'module', model).parameters(),
                        max_norm=args.clip_norm, norm_type=2)

                total_norm = 0.0
                try:
                    if args.log_norm:
                        for p in getattr(model, 'module', model).parameters():
                            param_norm = p.grad.data.norm(2)
                            total_norm += param_norm.item() ** 2
                        total_norm = total_norm ** (1. / 2)
                except AttributeError as e:
                    print_once(f'Exception happened: {e}')
                    total_norm = 0.0

                optimizer.step()
                apply_ema(model, ema_model, args.ema)

                if step % args.log_frequency == 0:
                    if args.prediction_frequency is None or step % args.prediction_frequency == 0:
                        preds = greedy_decoder.decode(model, feats, feat_lens)
                        wer, pred_utt, ref = greedy_wer(preds, txt, txt_lens, tokenizer.detokenize)
                        print_once(f'  Decoded:   {pred_utt[:90]}')
                        print_once(f'  Reference: {ref[:90]}')
                        wer = {'wer': 100 * wer}
                    else:
                        wer = {}

                    step_time = time.time() - step_start_time
                    log((epoch, step % steps_per_epoch or steps_per_epoch, steps_per_epoch),
                        step, 'train',
                        {'loss': loss_item,
                         **wer,  # optional entry
                         'throughput': step_utts / step_time,
                         'took': step_time,
                         'grad-norm': total_norm,
                         'seq-len-min': min(all_feat_lens).item(),
                         'seq-len-max': max(all_feat_lens).item(),
                         'lrate': optimizer.param_groups[0]['lr']})

                step_start_time = time.time()
                step += 1
                accumulated_batches = 0
                # end of step

        logging.log_end(logging.constants.EPOCH_STOP, metadata=dict(epoch_num=epoch))

        epoch_time = time.time() - epoch_start_time
        log((epoch,), None, 'train_avg',
            {'throughput': epoch_utts / epoch_time, 'took': epoch_time})

        if epoch % args.val_frequency == 0:
            wer = evaluate(epoch, step, val_loader, val_feat_proc,
                           tokenizer.detokenize, ema_model, loss_fn,
                           greedy_decoder, args.amp)
            last_wer = wer
            if wer < best_wer and epoch >= args.save_best_from:
                checkpointer.save(model, ema_model, optimizer, epoch, step, best_wer, is_best=True)
                best_wer = wer

        save_this_epoch = (args.save_frequency is not None and epoch % args.save_frequency == 0) \
            or (epoch in args.keep_milestones)
        if save_this_epoch:
            checkpointer.save(model, ema_model, optimizer, epoch, step, best_wer)

        logging.log_end(logging.constants.BLOCK_STOP, metadata=dict(first_epoch_num=epoch))

        if last_wer <= args.target:
            logging.log_end(logging.constants.RUN_STOP, metadata={'status': 'success'})
            print_once(f'Finished after {args.epochs_this_job} epochs.')
            break
        if 0 < args.epochs_this_job <= epoch - start_epoch:
            print_once(f'Finished after {args.epochs_this_job} epochs.')
            break
        # end of epoch

    log((), None, 'train_avg', {'throughput': epoch_utts / epoch_time})

    if last_wer > args.target:
        logging.log_end(logging.constants.RUN_STOP, metadata={'status': 'aborted'})

    if epoch == args.epochs:
        evaluate(epoch, step, val_loader, val_feat_proc, tokenizer.detokenize,
                 ema_model, loss_fn, greedy_decoder, args.amp)

    flush_log()
    if args.save_at_the_end:
        checkpointer.save(model, ema_model, optimizer, epoch, step, best_wer)
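# apply_ema() called in the update step above is defined elsewhere in the repo
# and is not shown in this excerpt. A minimal sketch, assuming args.ema is the
# decay factor and that model and ema_model share parameter names, would look
# like the following (hypothetical helper, not the repo's actual implementation):
@torch.no_grad()
def apply_ema(model, ema_model, decay):
    if ema_model is None or decay <= 0:
        return
    # read from the unwrapped model in case it is wrapped in DistributedDataParallel
    src = getattr(model, 'module', model).state_dict()
    for name, ema_param in ema_model.state_dict().items():
        src_param = src[name]
        if ema_param.dtype.is_floating_point:
            # exponential moving average: ema = decay * ema + (1 - decay) * current
            ema_param.mul_(decay).add_(src_param.detach().to(ema_param.dtype),
                                       alpha=1.0 - decay)
        else:
            # integer buffers (e.g. step counters) are copied verbatim
            ema_param.copy_(src_param)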