def main():
    """Plot per-frame CTC posteriors for every utterance of the recognition sets.

    Evaluation arguments are parsed from ``sys.argv``. The ASR model is built
    once (on the first recognition set), optionally from averaged Transformer
    checkpoints, and reused for the remaining sets. Each utterance is decoded,
    its top-k CTC posteriors are plotted to
    ``<recog_dir>/ctc_probs/<speaker>/<utt_id>.png``, and the reference and
    best hypothesis are logged to ``<recog_dir>/plot.log``.
    """
    # Load configuration
    args, dir_name = parse_args_eval(sys.argv[1:])

    # Setting for logging (start from a fresh log file)
    if os.path.isfile(os.path.join(args.recog_dir, 'plot.log')):
        os.remove(os.path.join(args.recog_dir, 'plot.log'))
    set_logger(os.path.join(args.recog_dir, 'plot.log'), stdout=args.recog_stdout)

    # Clean the output directory ONCE, before any recognition set is processed.
    # Cleaning inside the per-set loop deleted the figures of every set but the last.
    save_path = mkdir_join(args.recog_dir, 'ctc_probs')
    if save_path is not None and os.path.isdir(save_path):
        shutil.rmtree(save_path)
        os.mkdir(save_path)

    for i, s in enumerate(args.recog_sets):
        # Load dataloader
        dataloader = build_dataloader(args=args,
                                      tsv_path=s,
                                      batch_size=1,
                                      is_test=True,
                                      first_n_utterances=args.recog_first_n_utt,
                                      longform_max_n_frames=args.recog_longform_max_n_frames)

        if i == 0:
            # Load ASR model (built once, shared by all recognition sets)
            model = Speech2Text(args, dir_name)
            # Epoch = numeric suffix after the last '-' of the checkpoint path,
            # truncated to one decimal place.
            epoch = int(float(args.recog_model[0].split('-')[-1]) * 10) / 10
            if args.recog_n_average > 1:
                # Model averaging for Transformer
                model = average_checkpoints(model, args.recog_model[0],
                                            n_average=args.recog_n_average)
            else:
                load_checkpoint(args.recog_model[0], model)

            if not args.recog_unit:
                args.recog_unit = args.unit

            logger.info('recog unit: %s' % args.recog_unit)
            logger.info('epoch: %d' % epoch)
            logger.info('batch size: %d' % args.recog_batch_size)

            # GPU setting
            if args.recog_n_gpus >= 1:
                model.cudnn_setting(deterministic=True, benchmark=False)
                model.cuda()

        for batch in dataloader:
            nbest_hyps_id, _ = model.decode(batch['xs'], args, dataloader.idx2token[0])
            best_hyps_id = [h[0] for h in nbest_hyps_id]  # keep the 1-best of each n-best list

            # Get CTC probs
            ctc_probs, topk_ids, xlens = model.get_ctc_probs(
                batch['xs'], temperature=1, topk=min(100, model.vocab))
            # NOTE: ctc_probs: '[B, T, topk]'

            for b in range(len(batch['xs'])):
                tokens = dataloader.idx2token[0](best_hyps_id[b], return_list=True)
                spk = batch['speakers'][b]

                plot_ctc_probs(
                    ctc_probs[b, :xlens[b]],  # drop frames past this utterance's length
                    topk_ids[b],
                    factor=args.subsample_factor,
                    spectrogram=batch['xs'][b][:, :dataloader.input_dim],
                    save_path=mkdir_join(save_path, spk, batch['utt_ids'][b] + '.png'),
                    figsize=(20, 8))

                hyp = ' '.join(tokens)
                logger.info('utt-id: %s' % batch['utt_ids'][b])
                logger.info('Ref: %s' % batch['text'][b].lower())
                logger.info('Hyp: %s' % hyp)
                logger.info('-' * 50)
def main():
    """Evaluate an ASR model on every recognition set and log averaged metrics.

    Configuration is parsed from ``sys.argv``. On the first recognition set the
    main model is built and loaded, optionally averaged over Transformer
    checkpoints. Extra checkpoints in ``args.recog_model[1:]`` are loaded as
    ensemble members. Unless ``args.lm_fusion`` is set, up to three
    shallow-fusion LMs are attached: first pass, second pass (forward) and
    second pass (backward). All of this is reused for the remaining sets.

    Each set is scored according to ``args.recog_metric``:
    ``'edit_distance'`` (WER/CER, or PER for phone units), ``'ppl'``/``'loss'``,
    ``'accuracy'`` or ``'bleu'``. Per-set scores are summed and divided by
    ``len(args.recog_sets)`` at the end. The result is logged to
    ``<recog_dir>/decode.log``; for ppl/loss/accuracy/BLEU it is also printed.

    Raises:
        ValueError: ``recog_metric == 'edit_distance'`` with a recognition unit
            that is not word/word_char, wp, *char* or *phone*.
        NotImplementedError: unknown ``args.recog_metric``.
    """
    # Load configuration
    args, dir_name = parse_args_eval(sys.argv[1:])

    # Setting for logging (start from a fresh log file)
    if os.path.isfile(os.path.join(args.recog_dir, 'decode.log')):
        os.remove(os.path.join(args.recog_dir, 'decode.log'))
    set_logger(os.path.join(args.recog_dir, 'decode.log'), stdout=args.recog_stdout)

    # Running sums of per-set metrics; divided by the number of sets at the end
    wer_avg, cer_avg, per_avg = 0, 0, 0
    ppl_avg, loss_avg = 0, 0
    acc_avg = 0
    bleu_avg = 0
    for i, s in enumerate(args.recog_sets):
        # Load dataloader
        dataloader = build_dataloader(args=args,
                                      tsv_path=s,
                                      batch_size=1,
                                      is_test=True,
                                      first_n_utterances=args.recog_first_n_utt,
                                      longform_max_n_frames=args.recog_longform_max_n_frames)

        if i == 0:
            # Load ASR model
            model = Speech2Text(args, dir_name)
            # Epoch = numeric suffix after the last '-' of the checkpoint path,
            # truncated to one decimal place.
            epoch = int(float(args.recog_model[0].split('-')[-1]) * 10) / 10
            if args.recog_n_average > 1:
                # Model averaging for Transformer
                model = average_checkpoints(model, args.recog_model[0],
                                            n_average=args.recog_n_average)
            else:
                load_checkpoint(args.recog_model[0], model)

            # Ensemble (different models): each member takes its architecture
            # from its own conf.yml but keeps every 'recog*' option of `args`.
            ensemble_models = [model]
            if len(args.recog_model) > 1:
                for recog_model_e in args.recog_model[1:]:
                    conf_e = load_config(os.path.join(os.path.dirname(recog_model_e), 'conf.yml'))
                    args_e = copy.deepcopy(args)
                    for k, v in conf_e.items():
                        if 'recog' not in k:
                            setattr(args_e, k, v)
                    model_e = Speech2Text(args_e)
                    load_checkpoint(recog_model_e, model_e)
                    if args.recog_n_gpus >= 1:
                        model_e.cuda()
                    ensemble_models += [model_e]

            # Load LM for shallow fusion
            if not args.lm_fusion:
                # first path
                if args.recog_lm is not None and args.recog_lm_weight > 0:
                    conf_lm = load_config(os.path.join(os.path.dirname(args.recog_lm), 'conf.yml'))
                    args_lm = argparse.Namespace()
                    for k, v in conf_lm.items():
                        setattr(args_lm, k, v)
                    args_lm.recog_mem_len = args.recog_mem_len
                    lm = build_lm(args_lm, wordlm=args.recog_wordlm,
                                  lm_dict_path=os.path.join(os.path.dirname(args.recog_lm), 'dict.txt'),
                                  asr_dict_path=os.path.join(dir_name, 'dict.txt'))
                    load_checkpoint(args.recog_lm, lm)
                    # The LM's own config decides which slot it fills
                    if args_lm.backward:
                        model.lm_bwd = lm
                    else:
                        model.lm_fwd = lm
                # second path (forward)
                if args.recog_lm_second is not None and args.recog_lm_second_weight > 0:
                    conf_lm_second = load_config(
                        os.path.join(os.path.dirname(args.recog_lm_second), 'conf.yml'))
                    args_lm_second = argparse.Namespace()
                    for k, v in conf_lm_second.items():
                        setattr(args_lm_second, k, v)
                    args_lm_second.recog_mem_len = args.recog_mem_len
                    lm_second = build_lm(args_lm_second)
                    load_checkpoint(args.recog_lm_second, lm_second)
                    model.lm_second = lm_second
                # second path (backward)
                if args.recog_lm_bwd is not None and args.recog_lm_bwd_weight > 0:
                    conf_lm = load_config(os.path.join(os.path.dirname(args.recog_lm_bwd), 'conf.yml'))
                    args_lm_bwd = argparse.Namespace()
                    for k, v in conf_lm.items():
                        setattr(args_lm_bwd, k, v)
                    args_lm_bwd.recog_mem_len = args.recog_mem_len
                    lm_bwd = build_lm(args_lm_bwd)
                    load_checkpoint(args.recog_lm_bwd, lm_bwd)
                    model.lm_bwd = lm_bwd

            if not args.recog_unit:
                args.recog_unit = args.unit

            logger.info('recog unit: %s' % args.recog_unit)
            logger.info('recog metric: %s' % args.recog_metric)
            logger.info('recog oracle: %s' % args.recog_oracle)
            logger.info('epoch: %d' % epoch)
            logger.info('batch size: %d' % args.recog_batch_size)
            logger.info('beam width: %d' % args.recog_beam_width)
            logger.info('min length ratio: %.3f' % args.recog_min_len_ratio)
            logger.info('max length ratio: %.3f' % args.recog_max_len_ratio)
            logger.info('length penalty: %.3f' % args.recog_length_penalty)
            logger.info('length norm: %s' % args.recog_length_norm)
            logger.info('coverage penalty: %.3f' % args.recog_coverage_penalty)
            logger.info('coverage threshold: %.3f' % args.recog_coverage_threshold)
            logger.info('CTC weight: %.3f' % args.recog_ctc_weight)
            logger.info('fist LM path: %s' % args.recog_lm)
            logger.info('second LM path: %s' % args.recog_lm_second)
            logger.info('backward LM path: %s' % args.recog_lm_bwd)
            logger.info('LM weight (first-pass): %.3f' % args.recog_lm_weight)
            logger.info('LM weight (second-pass): %.3f' % args.recog_lm_second_weight)
            logger.info('LM weight (backward): %.3f' % args.recog_lm_bwd_weight)
            logger.info('GNMT: %s' % args.recog_gnmt_decoding)
            logger.info('forward-backward attention: %s' % args.recog_fwd_bwd_attention)
            logger.info('resolving UNK: %s' % args.recog_resolving_unk)
            logger.info('ensemble: %d' % (len(ensemble_models)))
            logger.info('ASR decoder state carry over: %s' % (args.recog_asr_state_carry_over))
            logger.info('LM state carry over: %s' % (args.recog_lm_state_carry_over))
            logger.info('model average (Transformer): %d' % (args.recog_n_average))

            # GPU setting (after LM attachment, so attached LMs are moved with the model)
            if args.recog_n_gpus >= 1:
                model.cudnn_setting(deterministic=True, benchmark=False)
                model.cuda()

        start_time = time.time()

        if args.recog_metric == 'edit_distance':
            if args.recog_unit in ['word', 'word_char']:
                wer, cer, _ = eval_word(ensemble_models, dataloader, args,
                                        epoch=epoch - 1,
                                        recog_dir=args.recog_dir,
                                        progressbar=True,
                                        fine_grained=True,
                                        oracle=True)
                wer_avg += wer
                cer_avg += cer
            elif args.recog_unit == 'wp':
                wer, cer = eval_wordpiece(ensemble_models, dataloader, args,
                                          epoch=epoch - 1,
                                          recog_dir=args.recog_dir,
                                          streaming=args.recog_streaming,
                                          progressbar=True,
                                          fine_grained=True,
                                          oracle=True)
                wer_avg += wer
                cer_avg += cer
            elif 'char' in args.recog_unit:
                wer, cer = eval_char(ensemble_models, dataloader, args,
                                     epoch=epoch - 1,
                                     recog_dir=args.recog_dir,
                                     progressbar=True,
                                     task_idx=0,
                                     fine_grained=True,
                                     oracle=True)
                wer_avg += wer
                cer_avg += cer
            elif 'phone' in args.recog_unit:
                per = eval_phone(ensemble_models, dataloader, args,
                                 epoch=epoch - 1,
                                 recog_dir=args.recog_dir,
                                 progressbar=True,
                                 fine_grained=True,
                                 oracle=True)
                per_avg += per
            else:
                raise ValueError(args.recog_unit)
        elif args.recog_metric in ['ppl', 'loss']:
            ppl, loss = eval_ppl(ensemble_models, dataloader, progressbar=True)
            ppl_avg += ppl
            loss_avg += loss
        elif args.recog_metric == 'accuracy':
            acc_avg += eval_accuracy(ensemble_models, dataloader, progressbar=True)
        elif args.recog_metric == 'bleu':
            bleu = eval_wordpiece_bleu(ensemble_models, dataloader, args,
                                       epoch=epoch - 1,
                                       recog_dir=args.recog_dir,
                                       streaming=args.recog_streaming,
                                       progressbar=True,
                                       fine_grained=True,
                                       oracle=True)
            bleu_avg += bleu
        else:
            raise NotImplementedError(args.recog_metric)
        elapsed_time = time.time() - start_time
        logger.info('Elapsed time: %.3f [sec]' % elapsed_time)
        # Real-time factor, treating each frame as 10 ms (the 0.01 factor)
        logger.info('RTF: %.3f' % (elapsed_time / (dataloader.n_frames * 0.01)))

    n_sets = len(args.recog_sets)
    if args.recog_metric == 'edit_distance':
        if 'phone' in args.recog_unit:
            logger.info('PER (avg.): %.2f %%\n' % (per_avg / n_sets))
        else:
            logger.info('WER / CER (avg.): %.2f / %.2f %%\n' %
                        (wer_avg / n_sets, cer_avg / n_sets))
    elif args.recog_metric in ['ppl', 'loss']:
        logger.info('PPL (avg.): %.2f\n' % (ppl_avg / n_sets))
        print('PPL (avg.): %.3f' % (ppl_avg / n_sets))
        logger.info('Loss (avg.): %.2f\n' % (loss_avg / n_sets))
        print('Loss (avg.): %.3f' % (loss_avg / n_sets))
    elif args.recog_metric == 'accuracy':
        logger.info('Accuracy (avg.): %.2f\n' % (acc_avg / n_sets))
        print('Accuracy (avg.): %.3f' % (acc_avg / n_sets))
    elif args.recog_metric == 'bleu':
        # Use the accumulated sum. Dividing only the last set's `bleu` by the
        # number of sets under-reported the average whenever n_sets > 1.
        logger.info('BLEU (avg.): %.2f\n' % (bleu_avg / n_sets))
        print('BLEU (avg.): %.3f' % (bleu_avg / n_sets))
def main():
    """Plot attention weights (optionally with CTC posteriors) for every utterance.

    Evaluation arguments are parsed from ``sys.argv``. The model, any ensemble
    members and an optional first-pass shallow-fusion LM are built on the first
    recognition set and reused for the rest. Each utterance is decoded and its
    attention weights are plotted to
    ``<recog_dir>/att_weights/<speaker>/<utt_id>.png``. CTC posteriors of the
    main task (when ``args.ctc_weight > 0``) and of the sub1 task (when
    ``args.ctc_weight_sub1 > 0``) are overlaid. Reference and hypothesis are
    logged to ``<recog_dir>/plot.log``.
    """
    # Load configuration
    args, dir_name = parse_args_eval(sys.argv[1:])

    # Setting for logging (start from a fresh log file)
    if os.path.isfile(os.path.join(args.recog_dir, 'plot.log')):
        os.remove(os.path.join(args.recog_dir, 'plot.log'))
    set_logger(os.path.join(args.recog_dir, 'plot.log'), stdout=args.recog_stdout)

    # Clean the output directory ONCE, before any recognition set is processed.
    # Cleaning inside the per-set loop deleted the figures of every set but the last.
    save_path = mkdir_join(args.recog_dir, 'att_weights')
    if save_path is not None and os.path.isdir(save_path):
        shutil.rmtree(save_path)
        os.mkdir(save_path)

    for i, s in enumerate(args.recog_sets):
        # Load dataloader
        dataloader = build_dataloader(args=args,
                                      tsv_path=s,
                                      batch_size=1,
                                      is_test=True,
                                      first_n_utterances=args.recog_first_n_utt,
                                      longform_max_n_frames=args.recog_longform_max_n_frames)

        if i == 0:
            # Load ASR model
            model = Speech2Text(args, dir_name)
            # Epoch = numeric suffix after the last '-' of the checkpoint path,
            # truncated to one decimal place.
            epoch = int(float(args.recog_model[0].split('-')[-1]) * 10) / 10
            if args.recog_n_average > 1:
                # Model averaging for Transformer
                model = average_checkpoints(model, args.recog_model[0],
                                            n_average=args.recog_n_average)
            else:
                load_checkpoint(args.recog_model[0], model)

            # Ensemble (different models): architecture from each member's own
            # conf.yml, every 'recog*' option kept from `args`.
            ensemble_models = [model]
            if len(args.recog_model) > 1:
                for recog_model_e in args.recog_model[1:]:
                    conf_e = load_config(os.path.join(os.path.dirname(recog_model_e), 'conf.yml'))
                    args_e = copy.deepcopy(args)
                    for k, v in conf_e.items():
                        if 'recog' not in k:
                            setattr(args_e, k, v)
                    model_e = Speech2Text(args_e)
                    load_checkpoint(recog_model_e, model_e)
                    if args.recog_n_gpus >= 1:
                        model_e.cuda()
                    ensemble_models += [model_e]

            # Load LM for shallow fusion
            if not args.lm_fusion:
                # first path
                if args.recog_lm is not None and args.recog_lm_weight > 0:
                    conf_lm = load_config(os.path.join(os.path.dirname(args.recog_lm), 'conf.yml'))
                    args_lm = argparse.Namespace()
                    for k, v in conf_lm.items():
                        setattr(args_lm, k, v)
                    lm = build_lm(args_lm)
                    load_checkpoint(args.recog_lm, lm)
                    if args_lm.backward:
                        model.lm_bwd = lm
                    else:
                        model.lm_fwd = lm
                # NOTE: only support for first path

            if not args.recog_unit:
                args.recog_unit = args.unit

            logger.info('recog unit: %s' % args.recog_unit)
            logger.info('recog oracle: %s' % args.recog_oracle)
            logger.info('epoch: %d' % epoch)
            logger.info('batch size: %d' % args.recog_batch_size)
            logger.info('beam width: %d' % args.recog_beam_width)
            logger.info('min length ratio: %.3f' % args.recog_min_len_ratio)
            logger.info('max length ratio: %.3f' % args.recog_max_len_ratio)
            logger.info('length penalty: %.3f' % args.recog_length_penalty)
            logger.info('length norm: %s' % args.recog_length_norm)
            logger.info('coverage penalty: %.3f' % args.recog_coverage_penalty)
            logger.info('coverage threshold: %.3f' % args.recog_coverage_threshold)
            logger.info('CTC weight: %.3f' % args.recog_ctc_weight)
            logger.info('fist LM path: %s' % args.recog_lm)
            logger.info('LM weight: %.3f' % args.recog_lm_weight)
            logger.info('GNMT: %s' % args.recog_gnmt_decoding)
            logger.info('forward-backward attention: %s' % args.recog_fwd_bwd_attention)
            logger.info('resolving UNK: %s' % args.recog_resolving_unk)
            logger.info('ensemble: %d' % (len(ensemble_models)))
            logger.info('ASR decoder state carry over: %s' % (args.recog_asr_state_carry_over))
            logger.info('LM state carry over: %s' % (args.recog_lm_state_carry_over))
            logger.info('model average (Transformer): %d' % (args.recog_n_average))

            # GPU setting
            if args.recog_n_gpus >= 1:
                model.cudnn_setting(deterministic=True, benchmark=False)
                model.cuda()

        for batch in dataloader:
            nbest_hyps_id, aws = model.decode(
                batch['xs'], args, dataloader.idx2token[0],
                exclude_eos=False,
                refs_id=batch['ys'],
                ensemble_models=ensemble_models[1:] if len(ensemble_models) > 1 else [],
                speakers=batch['sessions'] if dataloader.corpus == 'swbd' else batch['speakers'])
            best_hyps_id = [h[0] for h in nbest_hyps_id]  # keep the 1-best of each n-best list

            # Get CTC probs (all three stay None when the CTC branch is disabled)
            ctc_probs, topk_ids, xlens = None, None, None
            if args.ctc_weight > 0:
                ctc_probs, topk_ids, xlens = model.get_ctc_probs(
                    batch['xs'], task='ys', temperature=1, topk=min(100, model.vocab))
                # NOTE: ctc_probs: '[B, T, topk]'
            ctc_probs_sub1, topk_ids_sub1, xlens_sub1 = None, None, None
            if args.ctc_weight_sub1 > 0:
                ctc_probs_sub1, topk_ids_sub1, xlens_sub1 = model.get_ctc_probs(
                    batch['xs'], task='ys_sub1', temperature=1, topk=min(100, model.vocab_sub1))

            if model.bwd_weight > 0.5:
                # Reverse the order (hypotheses, and the second axis of each attention matrix)
                best_hyps_id = [hyp[::-1] for hyp in best_hyps_id]
                aws = [[aw[0][:, ::-1]] for aw in aws]

            for b in range(len(batch['xs'])):
                tokens = dataloader.idx2token[0](best_hyps_id[b], return_list=True)
                spk = batch['speakers'][b]

                plot_attention_weights(
                    aws[b][0][:, :len(tokens)],
                    tokens,
                    spectrogram=batch['xs'][b][:, :dataloader.input_dim] if args.input_type == 'speech' else None,
                    factor=args.subsample_factor,
                    ref=batch['text'][b].lower(),
                    save_path=mkdir_join(save_path, spk, batch['utt_ids'][b] + '.png'),
                    figsize=(20, 8),
                    ctc_probs=ctc_probs[b, :xlens[b]] if ctc_probs is not None else None,
                    ctc_topk_ids=topk_ids[b] if topk_ids is not None else None,
                    ctc_probs_sub1=ctc_probs_sub1[b, :xlens_sub1[b]] if ctc_probs_sub1 is not None else None,
                    ctc_topk_ids_sub1=topk_ids_sub1[b] if topk_ids_sub1 is not None else None)

                if model.bwd_weight > 0.5:
                    hyp = ' '.join(tokens[::-1])
                else:
                    hyp = ' '.join(tokens)
                logger.info('utt-id: %s' % batch['utt_ids'][b])
                logger.info('Ref: %s' % batch['text'][b].lower())
                logger.info('Hyp: %s' % hyp)
                logger.info('-' * 50)
def main(args):
    """Train a Speech2Text ASR model and return the directory it is saved in.

    Args:
        args: parsed training configuration. It must support attribute access
            plus ``.get()`` and ``.items()`` (both are used below). It is
            mutated in place:
            - When ``args.resume`` is set, values from the ``conf.yml`` next to
              the resume checkpoint override it (except ``resume`` and
              ``local_rank``).
            - ``vocab``, ``vocab_sub1``, ``vocab_sub2``, ``input_dim``,
              ``save_path``, ``use_apex`` and ``wandb_id`` are filled in.
            - ``lm_conf`` is filled in when a fresh run uses ``external_lm``.

    Returns:
        ``args.save_path``: the directory holding ``train.log``, ``conf.yml``
        and the checkpoints.
    """
    # Fix RNG seeds for CPU and all GPUs
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    # Snapshots taken BEFORE any resume config is merged; later overwritten
    # by the init / teacher model configs respectively.
    args_init = copy.deepcopy(args)
    args_teacher = copy.deepcopy(args)

    # Load a conf file
    if args.resume:
        conf = load_config(
            os.path.join(os.path.dirname(args.resume), 'conf.yml'))
        for k, v in conf.items():
            if k not in ['resume', 'local_rank']:
                setattr(args, k, v)
    args = compute_subsampling_factor(args)
    # Epoch number = integer suffix after the last '-' of the resume checkpoint path
    resume_epoch = int(args.resume.split('-')[-1]) if args.resume else 0

    # Load dataset
    train_set = build_dataloader(args=args,
                                 tsv_path=args.train_set,
                                 tsv_path_sub1=args.train_set_sub1,
                                 tsv_path_sub2=args.train_set_sub2,
                                 batch_size=args.batch_size,
                                 batch_size_type=args.batch_size_type,
                                 max_n_frames=args.max_n_frames,
                                 resume_epoch=resume_epoch,
                                 sort_by=args.sort_by,
                                 short2long=args.sort_short2long,
                                 sort_stop_epoch=args.sort_stop_epoch,
                                 num_workers=args.workers,
                                 pin_memory=args.pin_memory,
                                 distributed=args.distributed,
                                 word_alignment_dir=args.train_word_alignment,
                                 ctc_alignment_dir=args.train_ctc_alignment)
    # Transducer decoders are validated one utterance at a time ('seq' batching)
    dev_set = build_dataloader(
        args=args,
        tsv_path=args.dev_set,
        tsv_path_sub1=args.dev_set_sub1,
        tsv_path_sub2=args.dev_set_sub2,
        batch_size=1 if 'transducer' in args.dec_type else args.batch_size,
        batch_size_type='seq' if 'transducer' in args.dec_type else args.batch_size_type,
        max_n_frames=1600,
        word_alignment_dir=args.dev_word_alignment,
        ctc_alignment_dir=args.dev_ctc_alignment)
    eval_sets = [
        build_dataloader(args=args, tsv_path=s, batch_size=1, is_test=True)
        for s in args.eval_sets
    ]

    # Vocabulary sizes and input dimension are taken from the training set
    args.vocab = train_set.vocab
    args.vocab_sub1 = train_set.vocab_sub1
    args.vocab_sub2 = train_set.vocab_sub2
    args.input_dim = train_set.input_dim

    # Set save path
    if args.resume:
        args.save_path = os.path.dirname(args.resume)
        dir_name = os.path.basename(args.save_path)
    else:
        dir_name = set_asr_model_name(args)
        if args.mbr_training:
            # MBR training saves next to the initial ASR model
            assert args.asr_init
            args.save_path = mkdir_join(os.path.dirname(args.asr_init), dir_name)
        else:
            args.save_path = mkdir_join(
                args.model_save_dir,
                '_'.join(os.path.basename(args.train_set).split('.')[:-1]), dir_name)
        if args.local_rank > 0:
            # NOTE(review): presumably gives rank 0 a head start on creating
            # the directory before set_save_path runs — confirm.
            time.sleep(1)
        args.save_path = set_save_path(args.save_path)  # avoid overwriting

    # Set logger
    set_logger(os.path.join(args.save_path, 'train.log'), args.stdout, args.local_rank)

    # Load a LM conf file for LM fusion & LM initialization
    if not args.resume and args.external_lm:
        lm_conf = load_config(
            os.path.join(os.path.dirname(args.external_lm), 'conf.yml'))
        args.lm_conf = argparse.Namespace()
        for k, v in lm_conf.items():
            setattr(args.lm_conf, k, v)
        # The external LM must share the ASR unit type and vocabulary size
        assert args.unit == args.lm_conf.unit
        assert args.vocab == args.lm_conf.vocab

    # Model setting
    model = Speech2Text(args, args.save_path, train_set.idx2token[0])

    if not args.resume:
        # Save nlsyms, dictionary, and wp_model
        if args.nlsyms:
            shutil.copy(args.nlsyms, os.path.join(args.save_path, 'nlsyms.txt'))
        for sub in ['', '_sub1', '_sub2']:
            if args.get('dict' + sub):
                shutil.copy(
                    args.get('dict' + sub),
                    os.path.join(args.save_path, 'dict' + sub + '.txt'))
            if args.get('unit' + sub) == 'wp':
                shutil.copy(
                    args.get('wp_model' + sub),
                    os.path.join(args.save_path, 'wp' + sub + '.model'))

        # Log every configuration value, sorted by key
        for k, v in sorted(args.items(), key=lambda x: x[0]):
            logger.info('%s: %s' % (k, str(v)))

        # Count total parameters
        for n in sorted(list(model.num_params_dict.keys())):
            n_params = model.num_params_dict[n]
            logger.info("%s %d" % (n, n_params))
        logger.info("Total %.2f M parameters" % (model.total_parameters / 1000000))
        logger.info('torch version: %s' % str(torch.__version__))
        logger.info(model)

        # Initialize with pre-trained model's parameters
        if args.asr_init:
            # Load ASR model (full model)
            conf_init = load_config(
                os.path.join(os.path.dirname(args.asr_init), 'conf.yml'))
            for k, v in conf_init.items():
                setattr(args_init, k, v)
            model_init = Speech2Text(args_init)
            load_checkpoint(args.asr_init, model_init)

            # Overwrite parameters: copy only those whose name and shape match;
            # with asr_init_enc_only, only names containing 'enc' are copied.
            param_dict = dict(model_init.named_parameters())
            for n, p in model.named_parameters():
                if n in param_dict.keys() and p.size() == param_dict[n].size():
                    if args.asr_init_enc_only and 'enc' not in n:
                        continue
                    p.data = param_dict[n].data
                    logger.info('Overwrite %s' % n)

    # Set optimizer (SGD when resuming past the SGD conversion epoch)
    optimizer = set_optimizer(
        model,
        'sgd' if resume_epoch > args.convert_to_sgd_epoch else args.optimizer,
        args.lr, args.weight_decay)

    # Wrap optimizer by learning rate scheduler
    # "Transformer-like" = any encoder/decoder type name containing 'former'
    is_transformer = 'former' in args.enc_type or 'former' in args.dec_type or 'former' in args.dec_type_sub1
    scheduler = LRScheduler(
        optimizer, args.lr,
        decay_type=args.lr_decay_type,
        decay_start_epoch=args.lr_decay_start_epoch,
        decay_rate=args.lr_decay_rate,
        decay_patient_n_epochs=args.lr_decay_patient_n_epochs,
        early_stop_patient_n_epochs=args.early_stop_patient_n_epochs,
        lower_better=args.metric not in ['accuracy', 'bleu'],
        warmup_start_lr=args.warmup_start_lr,
        warmup_n_steps=args.warmup_n_steps,
        # NOTE(review): ZeroDivisionError for conformer encoders when
        # transformer_enc_d_model is missing (defaults to 0) — confirm it is always set.
        peak_lr=0.05 / (args.get('transformer_enc_d_model', 0)**0.5) if 'conformer' in args.enc_type else 1e6,
        model_size=args.get('transformer_enc_d_model', args.get('transformer_dec_d_model', 0)),
        factor=args.lr_factor,
        noam=args.optimizer == 'noam',
        save_checkpoints_topk=10 if is_transformer else 1)

    if args.resume:
        # Restore the last saved model
        load_checkpoint(args.resume, model, scheduler)

        # Resume between convert_to_sgd_epoch -1 and convert_to_sgd_epoch
        if resume_epoch == args.convert_to_sgd_epoch:
            scheduler.convert_to_sgd(model, args.lr, args.weight_decay,
                                     decay_type='always', decay_rate=0.5)

    # Load teacher ASR model
    teacher = None
    if args.teacher:
        assert os.path.isfile(args.teacher), 'There is no checkpoint.'
        conf_teacher = load_config(
            os.path.join(os.path.dirname(args.teacher), 'conf.yml'))
        for k, v in conf_teacher.items():
            setattr(args_teacher, k, v)
        # Setting for knowledge distillation
        # (ss_prob is zeroed on the teacher's args; lsm_prob on the student's `args`)
        args_teacher.ss_prob = 0
        args.lsm_prob = 0
        teacher = Speech2Text(args_teacher)
        load_checkpoint(args.teacher, teacher)

    # Load teacher LM
    teacher_lm = None
    if args.teacher_lm:
        assert os.path.isfile(args.teacher_lm), 'There is no checkpoint.'
        conf_lm = load_config(
            os.path.join(os.path.dirname(args.teacher_lm), 'conf.yml'))
        args_lm = argparse.Namespace()
        for k, v in conf_lm.items():
            setattr(args_lm, k, v)
        teacher_lm = build_lm(args_lm)
        load_checkpoint(args.teacher_lm, teacher_lm)

    # GPU setting
    # apex opt-levels O0-O3 in train_dtype switch on mixed precision
    args.use_apex = args.train_dtype in ["O0", "O1", "O2", "O3"]
    amp, scaler = None, None
    if args.n_gpus >= 1:
        model.cudnn_setting(
            deterministic=((not is_transformer) and (not args.cudnn_benchmark)) or args.cudnn_deterministic,
            benchmark=(not is_transformer) and args.cudnn_benchmark)

        # Mixed precision training setting
        if args.use_apex:
            # torch>=1.6 uses native AMP (GradScaler); older versions fall back to NVIDIA apex
            if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
                scaler = torch.cuda.amp.GradScaler()
            else:
                from apex import amp
                model, scheduler.optimizer = amp.initialize(
                    model, scheduler.optimizer, opt_level=args.train_dtype)
                from neural_sp.models.seq2seq.decoders.ctc import CTC
                amp.register_float_function(CTC, "loss_fn")
                # NOTE: see https://github.com/espnet/espnet/pull/1779
                amp.init()
                if args.resume:
                    load_checkpoint(args.resume, amp=amp)

        # Split the visible GPUs evenly among local processes; this process uses its slice
        n = torch.cuda.device_count() // args.local_world_size
        device_ids = list(range(args.local_rank * n, (args.local_rank + 1) * n))

        torch.cuda.set_device(device_ids[0])
        model.cuda(device_ids[0])
        scheduler.cuda(device_ids[0])
        if args.distributed:
            model = DDP(model, device_ids=device_ids)
        else:
            model = CustomDataParallel(model, device_ids=list(range(args.n_gpus)))

        if teacher is not None:
            teacher.cuda()
        if teacher_lm is not None:
            teacher_lm.cuda()
    else:
        # CPU path: wrapped too, since code below accesses `model.module`
        model = CPUWrapperASR(model)

    # Set process name
    logger.info('PID: %s' % os.getpid())
    logger.info('USERNAME: %s' % os.uname()[1])
    logger.info('#GPU: %d' % torch.cuda.device_count())
    setproctitle(args.job_name if args.job_name else dir_name)

    # Set reporter
    reporter = Reporter(args, model, args.local_rank)
    args.wandb_id = reporter.wandb_id
    if args.resume:
        # Convert scheduler steps into reporter steps, accounting for gradient accumulation
        n_steps = scheduler.n_steps * max(
            1, args.accum_grad_n_steps // args.local_world_size)
        reporter.resume(n_steps, resume_epoch)

    # Save conf file as a yaml file
    if args.local_rank == 0:
        save_config(args, os.path.join(args.save_path, 'conf.yml'))
        if args.external_lm:
            save_config(args.lm_conf, os.path.join(args.save_path, 'conf_lm.yml'))
        # NOTE: save after reporter for wandb ID

    # Define tasks
    if args.mtl_per_batch:
        # NOTE: from easier to harder tasks
        # (each later, harder-listed task is prepended, so the final list runs easiest-first)
        tasks = []
        if args.total_weight - args.bwd_weight - args.ctc_weight - args.sub1_weight - args.sub2_weight > 0:
            tasks += ['ys']
        if args.bwd_weight > 0:
            tasks = ['ys.bwd'] + tasks
        if args.ctc_weight > 0:
            tasks = ['ys.ctc'] + tasks
        if args.mbr_ce_weight > 0:
            tasks = ['ys.mbr'] + tasks
        for sub in ['sub1', 'sub2']:
            if args.get('train_set_' + sub) is not None:
                if args.get(sub + '_weight', 0) - args.get(
                        'ctc_weight_' + sub, 0) > 0:
                    tasks = ['ys_' + sub] + tasks
                if args.get('ctc_weight_' + sub, 0) > 0:
                    tasks = ['ys_' + sub + '.ctc'] + tasks
    else:
        tasks = ['all']

    # Re-enable epoch-triggered features that should already be active when resuming
    if args.get('ss_start_epoch', 0) <= resume_epoch:
        model.module.trigger_scheduled_sampling()
    if args.get('mocha_quantity_loss_start_epoch', 0) <= resume_epoch:
        model.module.trigger_quantity_loss()

    start_time_train = time.time()
    for ep in range(resume_epoch, args.n_epochs):
        train_one_epoch(model, train_set, dev_set, eval_sets,
                        scheduler, reporter, logger, args,
                        amp, scaler, tasks, teacher, teacher_lm)

        # Save checkpoint and validate model per epoch
        if reporter.n_epochs + 1 < args.eval_start_epoch:
            # Before eval_start_epoch: no validation, just decay and checkpoint
            scheduler.epoch()  # lr decay
            reporter.epoch()  # plot

            # Save model
            if args.local_rank == 0:
                scheduler.save_checkpoint(
                    model, args.save_path, amp=amp,
                    remove_old=(not is_transformer) and args.remove_old_checkpoints)
        else:
            start_time_eval = time.time()
            # dev
            metric_dev = validate([model.module], dev_set, args,
                                  reporter.n_epochs + 1, logger)
            scheduler.epoch(metric_dev)  # lr decay
            reporter.epoch(metric_dev, name=args.metric)  # plot
            reporter.add_scalar('dev/' + args.metric, metric_dev)

            if scheduler.is_topk or is_transformer:
                # Save model
                if args.local_rank == 0:
                    scheduler.save_checkpoint(
                        model, args.save_path, amp=amp,
                        remove_old=(not is_transformer) and args.remove_old_checkpoints)

                # test (only when the dev score entered the scheduler's top-k)
                if scheduler.is_topk:
                    for eval_set in eval_sets:
                        validate([model.module], eval_set, args,
                                 reporter.n_epochs, logger)

            logger.info('Evaluation time: %.2f min' %
                        ((time.time() - start_time_eval) / 60))

            # Early stopping
            if scheduler.is_early_stop:
                break

            # Convert to fine-tuning stage
            if reporter.n_epochs == args.convert_to_sgd_epoch:
                scheduler.convert_to_sgd(model, args.lr, args.weight_decay,
                                         decay_type='always', decay_rate=0.5)

        if reporter.n_epochs >= args.n_epochs:
            break
        # Turn on epoch-scheduled features at the start of their configured epoch
        if args.get('ss_start_epoch', 0) == (ep + 1):
            model.module.trigger_scheduled_sampling()
        if args.get('mocha_quantity_loss_start_epoch', 0) == (ep + 1):
            model.module.trigger_quantity_loss()

    logger.info('Total time: %.2f hour' %
                ((time.time() - start_time_train) / 3600))
    reporter.close()

    return args.save_path
def main():
    """Write CTC forced alignments for every utterance of the recognition sets.

    Evaluation arguments are parsed from ``sys.argv`` and the ASR model is
    loaded via ``average_checkpoints``. For every utterance a UTF-8 text file
    ``<recog_dir>/ctc_forced_alignments/<speaker>/<utt_id>.txt`` is written.
    It holds one ``<token> <trigger point>`` line per reference token, followed
    by a final ``<eos>`` line. Logging goes to ``<recog_dir>/align.log``.
    """
    # Load configuration
    args, dir_name = parse_args_eval(sys.argv[1:])

    # Setting for logging (start from a fresh log file)
    if os.path.isfile(os.path.join(args.recog_dir, 'align.log')):
        os.remove(os.path.join(args.recog_dir, 'align.log'))
    set_logger(os.path.join(args.recog_dir, 'align.log'), stdout=args.recog_stdout)

    # Load ASR model
    model = Speech2Text(args, dir_name)
    average_checkpoints(model, args.recog_model[0], n_average=args.recog_n_average)

    if not args.recog_unit:
        args.recog_unit = args.unit

    logger.info('recog unit: %s' % args.recog_unit)
    logger.info('batch size: %d' % args.recog_batch_size)

    # GPU setting
    if args.recog_n_gpus >= 1:
        model.cudnn_setting(deterministic=True, benchmark=False)
        model.cuda()

    # Align all utterances. These values are loop-invariant, so they are set once.
    # NOTE(review): presumably they disable length-based utterance filtering
    # in build_dataloader — confirm.
    args.min_n_frames = 0
    args.max_n_frames = 1e5

    # Clean the output directory ONCE, before any recognition set is processed.
    # Cleaning inside the per-set loop deleted the alignments of every set but the last.
    save_path = mkdir_join(args.recog_dir, 'ctc_forced_alignments')
    if save_path is not None and os.path.isdir(save_path):
        shutil.rmtree(save_path)
        os.mkdir(save_path)

    for s in args.recog_sets:
        # Load dataloader
        dataloader = build_dataloader(args=args,
                                      tsv_path=s,
                                      batch_size=args.recog_batch_size)

        # Context manager closes the progress bar even if alignment raises
        with tqdm(total=len(dataloader)) as pbar:
            for batch in dataloader:
                trigger_points = model.ctc_forced_align(batch['xs'], batch['ys'])  # `[B, L]`

                for b in range(len(batch['xs'])):
                    save_path_spk = mkdir_join(save_path, batch['speakers'][b])
                    save_path_utt = mkdir_join(save_path_spk, batch['utt_ids'][b] + '.txt')

                    tokens = dataloader.idx2token[0](batch['ys'][b], return_list=True)
                    with codecs.open(save_path_utt, 'w', encoding="utf-8") as f:
                        for i_tok, tok in enumerate(tokens):
                            f.write('%s %d\n' % (tok, trigger_points[b, i_tok]))
                        # Trigger point just past the last token is reported as <eos>
                        f.write('%s %d\n' % ('<eos>', trigger_points[b, len(tokens)]))

                pbar.update(len(batch['xs']))