def plot_ctc(self):
    """Plot CTC posteriors during training."""
    # Primary forward decoder is always present.
    self.dec_fwd._plot_ctc(mkdir_join(self.save_path, 'ctc'))
    # Auxiliary decoders exist only for multi-task configurations.
    for sub in ['sub1', 'sub2']:
        dec_sub = getattr(self, 'dec_fwd_' + sub, None)
        if dec_sub is not None:
            dec_sub._plot_ctc(mkdir_join(self.save_path, 'ctc_' + sub))
def plot_attention(self):
    """Plot attention weights during training."""
    # encoder
    self.enc._plot_attention(mkdir_join(self.save_path, 'enc_att_weights'))
    # decoder(s); sub-task decoders are optional
    self.dec_fwd._plot_attention(mkdir_join(self.save_path, 'dec_att_weights'))
    for sub in ['sub1', 'sub2']:
        dec_sub = getattr(self, 'dec_fwd_' + sub, None)
        if dec_sub is not None:
            dec_sub._plot_attention(
                mkdir_join(self.save_path, 'dec_att_weights_' + sub))
def _plot_attention(self, save_path, n_cols=2):
    """Plot attention for each head in all encoder layers.

    Args:
        save_path (str): directory to save figures in; a fresh
            'enc_att_weights' sub-directory is (re)created under it
        n_cols (int): number of columns in the subplot grid

    """
    from matplotlib import pyplot as plt
    from matplotlib.ticker import MaxNLocator

    _save_path = mkdir_join(save_path, 'enc_att_weights')

    # Clean directory
    if _save_path is not None and os.path.isdir(_save_path):
        shutil.rmtree(_save_path)
        os.mkdir(_save_path)

    elens = self.data_dict['elens']

    for k, aw in self.aws_dict.items():
        plt.clf()
        n_heads = aw.shape[1]
        n_cols_tmp = 1 if n_heads == 1 else n_cols
        fig, axes = plt.subplots(max(1, n_heads // n_cols_tmp), n_cols_tmp,
                                 figsize=(20, 8), squeeze=False)
        for h in range(n_heads):
            ax = axes[h // n_cols_tmp, h % n_cols_tmp]
            # NOTE: show the last utterance in the mini-batch
            ax.imshow(aw[-1, h, :elens[-1], :elens[-1]], aspect="auto")
            ax.grid(False)
            ax.set_xlabel("Input (head%d)" % h)
            ax.set_ylabel("Output (head%d)" % h)
            ax.xaxis.set_major_locator(MaxNLocator(integer=True))
            ax.yaxis.set_major_locator(MaxNLocator(integer=True))
        fig.tight_layout()
        # BUGFIX: was `dvi=500` — `savefig` has no such keyword; `dpi` is intended.
        fig.savefig(os.path.join(_save_path, '%s.png' % k), dpi=500)
        plt.close()
def plot_attention(self, n_cols=4):
    """Plot self-attention for each head in all layers.

    Args:
        n_cols (int): number of columns in the subplot grid

    """
    from matplotlib import pyplot as plt
    from matplotlib.ticker import MaxNLocator

    save_path = mkdir_join(self.save_path, 'att_weights')

    # Clean directory
    if save_path is not None and os.path.isdir(save_path):
        shutil.rmtree(save_path)
        os.mkdir(save_path)

    for lth in range(self.n_layers):
        if not hasattr(self, 'yy_aws_layer%d' % lth):
            continue
        yy_aws = getattr(self, 'yy_aws_layer%d' % lth)

        plt.clf()
        # BUGFIX: `max(1, ...)` + `squeeze=False` keep `axes` a 2-D grid even
        # when n_heads <= n_cols; the old code created a 0-row figure (or a
        # 1-D axes array indexed with two subscripts) and crashed there.
        fig, axes = plt.subplots(max(1, self.n_heads // n_cols), n_cols,
                                 figsize=(20, 8), squeeze=False)
        for h in range(self.n_heads):
            ax = axes[h // n_cols, h % n_cols]
            ax.imshow(yy_aws[-1, h, :, :], aspect="auto")
            ax.grid(False)
            ax.set_xlabel("Input (head%d)" % h)
            ax.set_ylabel("Output (head%d)" % h)
            ax.xaxis.set_major_locator(MaxNLocator(integer=True))
            ax.yaxis.set_major_locator(MaxNLocator(integer=True))
        fig.tight_layout()
        # BUGFIX: was `dvi=500`; the correct savefig keyword is `dpi`.
        fig.savefig(os.path.join(save_path, 'layer%d.png' % (lth)), dpi=500)
        plt.close()
def _plot_attention(self, save_path, n_cols=1):
    """Plot attention weights.

    Args:
        save_path (str): directory to save figures in; a fresh
            'dec_att_weights' sub-directory is (re)created under it
        n_cols (int): number of columns in the subplot grid

    """
    from matplotlib import pyplot as plt
    from matplotlib.ticker import MaxNLocator

    _save_path = mkdir_join(save_path, 'dec_att_weights')

    # Clean directory
    if _save_path is not None and os.path.isdir(_save_path):
        shutil.rmtree(_save_path)
        os.mkdir(_save_path)

    if hasattr(self, 'aws'):
        plt.clf()
        fig, axes = plt.subplots(max(1, self.score.n_heads // n_cols), n_cols,
                                 figsize=(20, 8), squeeze=False)
        for h in range(self.score.n_heads):
            ax = axes[h // n_cols, h % n_cols]
            # NOTE: show the last utterance in the mini-batch
            ax.imshow(self.aws[-1, h, :, :], aspect="auto")
            ax.grid(False)
            ax.set_xlabel("Input (head%d)" % h)
            ax.set_ylabel("Output (head%d)" % h)
            ax.xaxis.set_major_locator(MaxNLocator(integer=True))
            ax.yaxis.set_major_locator(MaxNLocator(integer=True))
        fig.tight_layout()
        # BUGFIX: was `dvi=500`; the correct savefig keyword is `dpi`.
        fig.savefig(os.path.join(_save_path, 'attention.png'), dpi=500)
        plt.close()
def _plot_attention(self, save_path, n_cols=2):
    """Plot attention for each head in all layers.

    Both the self-attention ('yy') and the source attention ('xy') weights
    are plotted, one figure per layer, into separate sub-directories.

    Args:
        save_path (str): directory to save figures in; fresh
            'dec_yy_att_weights'/'dec_xy_att_weights' sub-directories are
            (re)created under it
        n_cols (int): number of columns in the subplot grid

    """
    from matplotlib import pyplot as plt
    from matplotlib.ticker import MaxNLocator

    for attn in ['yy', 'xy']:
        _save_path = mkdir_join(save_path, 'dec_%s_att_weights' % attn)

        # Clean directory
        if _save_path is not None and os.path.isdir(_save_path):
            shutil.rmtree(_save_path)
            os.mkdir(_save_path)

        # NOTE: `lth` instead of `l` (ambiguous single-letter name, PEP 8 E741)
        for lth in range(self.n_layers):
            if hasattr(self, '%s_aws_layer%d' % (attn, lth)):
                aws = getattr(self, '%s_aws_layer%d' % (attn, lth))
                plt.clf()
                fig, axes = plt.subplots(max(1, self.n_heads // n_cols), n_cols,
                                         figsize=(20, 8), squeeze=False)
                for h in range(self.n_heads):
                    ax = axes[h // n_cols, h % n_cols]
                    ax.imshow(aws[-1, h, :, :], aspect="auto")
                    ax.grid(False)
                    ax.set_xlabel("Input (head%d)" % h)
                    ax.set_ylabel("Output (head%d)" % h)
                    ax.xaxis.set_major_locator(MaxNLocator(integer=True))
                    ax.yaxis.set_major_locator(MaxNLocator(integer=True))
                fig.tight_layout()
                # BUGFIX: was `dvi=500`; the correct savefig keyword is `dpi`.
                fig.savefig(os.path.join(_save_path, 'layer%d.png' % (lth)), dpi=500)
                plt.close()
def _plot_attention(self, save_path, n_cols=2):
    """Plot attention for each head in all decoder layers.

    Args:
        save_path (str): directory to save figures in; a fresh
            'dec_att_weights' sub-directory is (re)created under it
        n_cols (int): base number of columns in the subplot grid
            (scaled up for many-head attention)

    """
    # No attention module to visualize (pure CTC / transducer decoders).
    if getattr(self, 'att_weight', 0) == 0:
        return

    from matplotlib import pyplot as plt
    from matplotlib.ticker import MaxNLocator

    _save_path = mkdir_join(save_path, 'dec_att_weights')

    # Clean directory
    if _save_path is not None and os.path.isdir(_save_path):
        shutil.rmtree(_save_path)
        os.mkdir(_save_path)

    elens = self.data_dict['elens']
    ylens = self.data_dict['ylens']

    for k, aw in self.aws_dict.items():
        plt.clf()
        n_heads = aw.shape[1]
        # Widen the grid (and the figure) for attention with many heads.
        n_cols_tmp = 1 if n_heads == 1 else n_cols * max(1, n_heads // 4)
        fig, axes = plt.subplots(max(1, n_heads // n_cols_tmp), n_cols_tmp,
                                 figsize=(20 * max(1, n_heads // 4), 8),
                                 squeeze=False)
        for h in range(n_heads):
            ax = axes[h // n_cols_tmp, h % n_cols_tmp]
            if 'yy' in k:
                # self-attention: token-by-token
                ax.imshow(aw[-1, h, :ylens[-1], :ylens[-1]], aspect="auto")
            else:
                # source attention: token-by-frame
                ax.imshow(aw[-1, h, :ylens[-1], :elens[-1]], aspect="auto")
            # NOTE: show the last utterance in a mini-batch
            ax.grid(False)
            ax.set_xlabel("Input (head%d)" % h)
            ax.set_ylabel("Output (head%d)" % h)
            ax.xaxis.set_major_locator(MaxNLocator(integer=True))
            ax.yaxis.set_major_locator(MaxNLocator(integer=True))
        fig.tight_layout()
        # BUGFIX: was `dvi=500`; the correct savefig keyword is `dpi`.
        fig.savefig(os.path.join(_save_path, '%s.png' % k), dpi=500)
        plt.close()
def _plot_ctc(self, save_path, topk=10):
    """Plot CTC posteriors for the last utterance in the mini-batch.

    Args:
        save_path (str): directory to save the figure in; a fresh 'ctc'
            sub-directory is (re)created under it
        topk (int): number of the most probable token ids per frame whose
            posterior curves are drawn

    """
    if self.ctc_weight == 0:
        return

    from matplotlib import pyplot as plt

    _save_path = mkdir_join(save_path, 'ctc')

    # Clean directory
    if _save_path is not None and os.path.isdir(_save_path):
        shutil.rmtree(_save_path)
        os.mkdir(_save_path)

    elen = self.ctc.data_dict['elens'][-1]
    probs = self.ctc.prob_dict['probs'][-1, :elen]  # `[T, vocab]`
    # NOTE: show the last utterance in a mini-batch

    # BUGFIX: the old code kept the full argsort, so `topk` was silently
    # ignored and every vocabulary entry was plotted; keep only the top-k
    # most probable ids per frame.
    topk_ids = np.argsort(probs, axis=1)[:, ::-1][:, :topk]

    plt.clf()
    n_frames = probs.shape[0]
    times_probs = np.arange(n_frames)

    # NOTE: index 0 is reserved for blank
    for idx in set(topk_ids.reshape(-1).tolist()):
        if idx == 0:
            plt.plot(times_probs, probs[:, 0], ':', label='<blank>', color='grey')
        else:
            plt.plot(times_probs, probs[:, idx])
    plt.xlabel(u'Time [frame]', fontsize=12)
    plt.ylabel('Posteriors', fontsize=12)
    plt.xticks(list(range(0, int(n_frames) + 1, 10)))
    plt.yticks(list(range(0, 2, 1)))
    plt.tight_layout()
    # BUGFIX: was `dvi=500`; the correct savefig keyword is `dpi`.
    plt.savefig(os.path.join(_save_path, '%s.png' % 'prob'), dpi=500)
    plt.close()
# NOTE(review): LM evaluation entry point with a cache mechanism. It loads the
# training conf next to the checkpoint, decodes each recog set one token pair
# at a time, and every `recog_n_caches` tokens renders the sliding cache
# attention matrix via `plot_cache_weights` (lower-triangular positions beyond
# the cache window are masked out). The code below is kept byte-identical: the
# source layout is mangled (one statement is split across the two physical
# lines at `loss, ` / `hidden = ...`), so a safe reformat cannot be verified.
# `toknen_count` is presumably a typo for `token_count` — TODO confirm with the
# original repository before renaming.
def main(): args = parse() # Load a conf file dir_name = os.path.dirname(args.recog_model[0]) conf = load_config(os.path.join(dir_name, 'conf.yml')) # Overwrite conf for k, v in conf.items(): if 'recog' not in k: setattr(args, k, v) # Setting for logging if os.path.isfile(os.path.join(args.recog_dir, 'plot.log')): os.remove(os.path.join(args.recog_dir, 'plot.log')) set_logger(os.path.join(args.recog_dir, 'plot.log'), stdout=args.recog_stdout) for i, s in enumerate(args.recog_sets): # Load dataset dataset = Dataset(corpus=args.corpus, tsv_path=s, dict_path=os.path.join(dir_name, 'dict.txt'), wp_model=os.path.join(dir_name, 'wp.model'), unit=args.unit, batch_size=args.recog_batch_size, bptt=args.bptt, backward=args.backward, serialize=args.serialize, is_test=True) if i == 0: # Load the LM model = build_lm(args, dir_name) topk_list = load_checkpoint(model, args.recog_model[0]) epoch = int(args.recog_model[0].split('-')[-1]) # Model averaging for Transformer if conf['lm_type'] == 'transformer': model = average_checkpoints(model, args.recog_model[0], n_average=args.recog_n_average, topk_list=topk_list) logger.info('epoch: %d' % (epoch - 1)) logger.info('batch size: %d' % args.recog_batch_size) # logger.info('recog unit: %s' % args.recog_unit) # logger.info('ensemble: %d' % (len(ensemble_models))) logger.info('BPTT: %d' % (args.bptt)) logger.info('cache size: %d' % (args.recog_n_caches)) logger.info('cache theta: %.3f' % (args.recog_cache_theta)) logger.info('cache lambda: %.3f' % (args.recog_cache_lambda)) model.cache_theta = args.recog_cache_theta model.cache_lambda = args.recog_cache_lambda # GPU setting model.cuda() assert args.recog_n_caches > 0 save_path = mkdir_join(args.recog_dir, 'cache') # Clean directory if save_path is not None and os.path.isdir(save_path): shutil.rmtree(save_path) os.mkdir(save_path) hidden = None fig_count = 0 toknen_count = 0 n_tokens = args.recog_n_caches while True: ys, is_new_epoch = dataset.next() for t in range(ys.shape[1] - 1): loss, 
hidden = model(ys[:, t:t + 2], hidden, is_eval=True, n_caches=args.recog_n_caches)[:2] if len(model.cache_attn) > 0: if toknen_count == n_tokens: tokens_keys = dataset.idx2token[0]( model.cache_ids[:args.recog_n_caches], return_list=True) tokens_query = dataset.idx2token[0]( model.cache_ids[-n_tokens:], return_list=True) # Slide attention matrix n_keys = len(tokens_keys) n_queries = len(tokens_query) cache_probs = np.zeros( (n_keys, n_queries)) # `[n_keys, n_queries]` mask = np.zeros((n_keys, n_queries)) for i, aw in enumerate(model.cache_attn[-n_tokens:]): cache_probs[:(n_keys - n_queries + i + 1), i] = aw[0, -(n_keys - n_queries + i + 1):] mask[(n_keys - n_queries + i + 1):, i] = 1 plot_cache_weights(cache_probs, keys=tokens_keys, queries=tokens_query, save_path=mkdir_join( save_path, str(fig_count) + '.png'), figsize=(40, 16), mask=mask) toknen_count = 0 fig_count += 1 else: toknen_count += 1 if is_new_epoch: break
# NOTE(review): end-to-end ASR training driver. High-level flow, in order:
# seed RNGs -> (optionally) restore conf for resuming -> build train/dev/eval
# dataloaders -> derive save path + logger -> build Speech2Text (optionally
# initialized from a pre-trained ASR via parameter-name/shape matching, and
# optionally paired with teacher ASR/LM models for distillation) -> set up
# optimizer + LR scheduler (noam/SGD conversion) -> GPU/AMP/DDP wrapping ->
# per-epoch loop of train_one_epoch / validate / checkpointing with early
# stopping and scheduled-sampling & quantity-loss triggers. Kept byte-identical:
# the source layout is mangled (several statements are split across the
# physical-line boundaries, e.g. `mkdir_join(args.model_save_dir, ` /
# `'_'.join(...)` and `p.size() == ` / `param_dict[n].size():`), so a safe
# reformat cannot be verified from this view.
# NOTE(review): `args.get(...)` on what is elsewhere an argparse.Namespace
# suggests a dict-like config object (omegaconf?) — confirm before refactoring.
def main(args): torch.manual_seed(args.seed) torch.cuda.manual_seed_all(args.seed) args_init = copy.deepcopy(args) args_teacher = copy.deepcopy(args) # Load a conf file if args.resume: conf = load_config( os.path.join(os.path.dirname(args.resume), 'conf.yml')) for k, v in conf.items(): if k not in ['resume', 'local_rank']: setattr(args, k, v) args = compute_subsampling_factor(args) resume_epoch = int(args.resume.split('-')[-1]) if args.resume else 0 # Load dataset train_set = build_dataloader(args=args, tsv_path=args.train_set, tsv_path_sub1=args.train_set_sub1, tsv_path_sub2=args.train_set_sub2, batch_size=args.batch_size, batch_size_type=args.batch_size_type, max_n_frames=args.max_n_frames, resume_epoch=resume_epoch, sort_by=args.sort_by, short2long=args.sort_short2long, sort_stop_epoch=args.sort_stop_epoch, num_workers=args.workers, pin_memory=args.pin_memory, distributed=args.distributed, word_alignment_dir=args.train_word_alignment, ctc_alignment_dir=args.train_ctc_alignment) dev_set = build_dataloader( args=args, tsv_path=args.dev_set, tsv_path_sub1=args.dev_set_sub1, tsv_path_sub2=args.dev_set_sub2, batch_size=1 if 'transducer' in args.dec_type else args.batch_size, batch_size_type='seq' if 'transducer' in args.dec_type else args.batch_size_type, max_n_frames=1600, word_alignment_dir=args.dev_word_alignment, ctc_alignment_dir=args.dev_ctc_alignment) eval_sets = [ build_dataloader(args=args, tsv_path=s, batch_size=1, is_test=True) for s in args.eval_sets ] args.vocab = train_set.vocab args.vocab_sub1 = train_set.vocab_sub1 args.vocab_sub2 = train_set.vocab_sub2 args.input_dim = train_set.input_dim # Set save path if args.resume: args.save_path = os.path.dirname(args.resume) dir_name = os.path.basename(args.save_path) else: dir_name = set_asr_model_name(args) if args.mbr_training: assert args.asr_init args.save_path = mkdir_join(os.path.dirname(args.asr_init), dir_name) else: args.save_path = mkdir_join( args.model_save_dir, 
'_'.join(os.path.basename(args.train_set).split('.')[:-1]), dir_name) if args.local_rank > 0: time.sleep(1) args.save_path = set_save_path(args.save_path) # avoid overwriting # Set logger set_logger(os.path.join(args.save_path, 'train.log'), args.stdout, args.local_rank) # Load a LM conf file for LM fusion & LM initialization if not args.resume and args.external_lm: lm_conf = load_config( os.path.join(os.path.dirname(args.external_lm), 'conf.yml')) args.lm_conf = argparse.Namespace() for k, v in lm_conf.items(): setattr(args.lm_conf, k, v) assert args.unit == args.lm_conf.unit assert args.vocab == args.lm_conf.vocab # Model setting model = Speech2Text(args, args.save_path, train_set.idx2token[0]) if not args.resume: # Save nlsyms, dictionary, and wp_model if args.nlsyms: shutil.copy(args.nlsyms, os.path.join(args.save_path, 'nlsyms.txt')) for sub in ['', '_sub1', '_sub2']: if args.get('dict' + sub): shutil.copy( args.get('dict' + sub), os.path.join(args.save_path, 'dict' + sub + '.txt')) if args.get('unit' + sub) == 'wp': shutil.copy( args.get('wp_model' + sub), os.path.join(args.save_path, 'wp' + sub + '.model')) for k, v in sorted(args.items(), key=lambda x: x[0]): logger.info('%s: %s' % (k, str(v))) # Count total parameters for n in sorted(list(model.num_params_dict.keys())): n_params = model.num_params_dict[n] logger.info("%s %d" % (n, n_params)) logger.info("Total %.2f M parameters" % (model.total_parameters / 1000000)) logger.info('torch version: %s' % str(torch.__version__)) logger.info(model) # Initialize with pre-trained model's parameters if args.asr_init: # Load ASR model (full model) conf_init = load_config( os.path.join(os.path.dirname(args.asr_init), 'conf.yml')) for k, v in conf_init.items(): setattr(args_init, k, v) model_init = Speech2Text(args_init) load_checkpoint(args.asr_init, model_init) # Overwrite parameters param_dict = dict(model_init.named_parameters()) for n, p in model.named_parameters(): if n in param_dict.keys() and p.size() == 
param_dict[n].size(): if args.asr_init_enc_only and 'enc' not in n: continue p.data = param_dict[n].data logger.info('Overwrite %s' % n) # Set optimizer optimizer = set_optimizer( model, 'sgd' if resume_epoch > args.convert_to_sgd_epoch else args.optimizer, args.lr, args.weight_decay) # Wrap optimizer by learning rate scheduler is_transformer = 'former' in args.enc_type or 'former' in args.dec_type or 'former' in args.dec_type_sub1 scheduler = LRScheduler( optimizer, args.lr, decay_type=args.lr_decay_type, decay_start_epoch=args.lr_decay_start_epoch, decay_rate=args.lr_decay_rate, decay_patient_n_epochs=args.lr_decay_patient_n_epochs, early_stop_patient_n_epochs=args.early_stop_patient_n_epochs, lower_better=args.metric not in ['accuracy', 'bleu'], warmup_start_lr=args.warmup_start_lr, warmup_n_steps=args.warmup_n_steps, peak_lr=0.05 / (args.get('transformer_enc_d_model', 0)**0.5) if 'conformer' in args.enc_type else 1e6, model_size=args.get('transformer_enc_d_model', args.get('transformer_dec_d_model', 0)), factor=args.lr_factor, noam=args.optimizer == 'noam', save_checkpoints_topk=10 if is_transformer else 1) if args.resume: # Restore the last saved model load_checkpoint(args.resume, model, scheduler) # Resume between convert_to_sgd_epoch -1 and convert_to_sgd_epoch if resume_epoch == args.convert_to_sgd_epoch: scheduler.convert_to_sgd(model, args.lr, args.weight_decay, decay_type='always', decay_rate=0.5) # Load teacher ASR model teacher = None if args.teacher: assert os.path.isfile(args.teacher), 'There is no checkpoint.' conf_teacher = load_config( os.path.join(os.path.dirname(args.teacher), 'conf.yml')) for k, v in conf_teacher.items(): setattr(args_teacher, k, v) # Setting for knowledge distillation args_teacher.ss_prob = 0 args.lsm_prob = 0 teacher = Speech2Text(args_teacher) load_checkpoint(args.teacher, teacher) # Load teacher LM teacher_lm = None if args.teacher_lm: assert os.path.isfile(args.teacher_lm), 'There is no checkpoint.' 
conf_lm = load_config( os.path.join(os.path.dirname(args.teacher_lm), 'conf.yml')) args_lm = argparse.Namespace() for k, v in conf_lm.items(): setattr(args_lm, k, v) teacher_lm = build_lm(args_lm) load_checkpoint(args.teacher_lm, teacher_lm) # GPU setting args.use_apex = args.train_dtype in ["O0", "O1", "O2", "O3"] amp, scaler = None, None if args.n_gpus >= 1: model.cudnn_setting( deterministic=((not is_transformer) and (not args.cudnn_benchmark)) or args.cudnn_deterministic, benchmark=(not is_transformer) and args.cudnn_benchmark) # Mixed precision training setting if args.use_apex: if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"): scaler = torch.cuda.amp.GradScaler() else: from apex import amp model, scheduler.optimizer = amp.initialize( model, scheduler.optimizer, opt_level=args.train_dtype) from neural_sp.models.seq2seq.decoders.ctc import CTC amp.register_float_function(CTC, "loss_fn") # NOTE: see https://github.com/espnet/espnet/pull/1779 amp.init() if args.resume: load_checkpoint(args.resume, amp=amp) n = torch.cuda.device_count() // args.local_world_size device_ids = list(range(args.local_rank * n, (args.local_rank + 1) * n)) torch.cuda.set_device(device_ids[0]) model.cuda(device_ids[0]) scheduler.cuda(device_ids[0]) if args.distributed: model = DDP(model, device_ids=device_ids) else: model = CustomDataParallel(model, device_ids=list(range(args.n_gpus))) if teacher is not None: teacher.cuda() if teacher_lm is not None: teacher_lm.cuda() else: model = CPUWrapperASR(model) # Set process name logger.info('PID: %s' % os.getpid()) logger.info('USERNAME: %s' % os.uname()[1]) logger.info('#GPU: %d' % torch.cuda.device_count()) setproctitle(args.job_name if args.job_name else dir_name) # Set reporter reporter = Reporter(args, model, args.local_rank) args.wandb_id = reporter.wandb_id if args.resume: n_steps = scheduler.n_steps * max( 1, args.accum_grad_n_steps // args.local_world_size) reporter.resume(n_steps, resume_epoch) # Save conf file as a yaml file 
if args.local_rank == 0: save_config(args, os.path.join(args.save_path, 'conf.yml')) if args.external_lm: save_config(args.lm_conf, os.path.join(args.save_path, 'conf_lm.yml')) # NOTE: save after reporter for wandb ID # Define tasks if args.mtl_per_batch: # NOTE: from easier to harder tasks tasks = [] if args.total_weight - args.bwd_weight - args.ctc_weight - args.sub1_weight - args.sub2_weight > 0: tasks += ['ys'] if args.bwd_weight > 0: tasks = ['ys.bwd'] + tasks if args.ctc_weight > 0: tasks = ['ys.ctc'] + tasks if args.mbr_ce_weight > 0: tasks = ['ys.mbr'] + tasks for sub in ['sub1', 'sub2']: if args.get('train_set_' + sub) is not None: if args.get(sub + '_weight', 0) - args.get( 'ctc_weight_' + sub, 0) > 0: tasks = ['ys_' + sub] + tasks if args.get('ctc_weight_' + sub, 0) > 0: tasks = ['ys_' + sub + '.ctc'] + tasks else: tasks = ['all'] if args.get('ss_start_epoch', 0) <= resume_epoch: model.module.trigger_scheduled_sampling() if args.get('mocha_quantity_loss_start_epoch', 0) <= resume_epoch: model.module.trigger_quantity_loss() start_time_train = time.time() for ep in range(resume_epoch, args.n_epochs): train_one_epoch(model, train_set, dev_set, eval_sets, scheduler, reporter, logger, args, amp, scaler, tasks, teacher, teacher_lm) # Save checkpoint and validate model per epoch if reporter.n_epochs + 1 < args.eval_start_epoch: scheduler.epoch() # lr decay reporter.epoch() # plot # Save model if args.local_rank == 0: scheduler.save_checkpoint(model, args.save_path, amp=amp, remove_old=(not is_transformer) and args.remove_old_checkpoints) else: start_time_eval = time.time() # dev metric_dev = validate([model.module], dev_set, args, reporter.n_epochs + 1, logger) scheduler.epoch(metric_dev) # lr decay reporter.epoch(metric_dev, name=args.metric) # plot reporter.add_scalar('dev/' + args.metric, metric_dev) if scheduler.is_topk or is_transformer: # Save model if args.local_rank == 0: scheduler.save_checkpoint(model, args.save_path, amp=amp, remove_old=(not 
is_transformer) and args.remove_old_checkpoints) # test if scheduler.is_topk: for eval_set in eval_sets: validate([model.module], eval_set, args, reporter.n_epochs, logger) logger.info('Evaluation time: %.2f min' % ((time.time() - start_time_eval) / 60)) # Early stopping if scheduler.is_early_stop: break # Convert to fine-tuning stage if reporter.n_epochs == args.convert_to_sgd_epoch: scheduler.convert_to_sgd(model, args.lr, args.weight_decay, decay_type='always', decay_rate=0.5) if reporter.n_epochs >= args.n_epochs: break if args.get('ss_start_epoch', 0) == (ep + 1): model.module.trigger_scheduled_sampling() if args.get('mocha_quantity_loss_start_epoch', 0) == (ep + 1): model.module.trigger_quantity_loss() logger.info('Total time: %.2f hour' % ((time.time() - start_time_train) / 3600)) reporter.close() return args.save_path
# NOTE(review): word-level WER/CER evaluation (older Dataset-based API; a
# newer DataLoader-based variant exists elsewhere in this project). Decodes
# each mini-batch (optionally streaming / ensemble), optionally re-decodes
# <unk> hypotheses with a character-level sub-task and splices the result in
# via `resolve_unk`, writes ref/hyp trn files, and accumulates WER/CER
# sub/ins/del counts. Kept byte-identical: the source layout is mangled (one
# statement is split across the physical lines at `wer += ` / `wer_b ...`),
# so a safe reformat cannot be verified from this view.
# NOTE(review): final WER/CER are averaged by word/char counts but returned as
# fractions logged with '%%' — presumably compute_wer returns percentages;
# confirm against compute_wer before changing the arithmetic.
def eval_word(models, dataset, recog_params, epoch, recog_dir=None, streaming=False, progressbar=False): """Evaluate the word-level model by WER. Args: models (list): models to evaluate dataset (Dataset): evaluation dataset recog_params (dict): epoch (int): recog_dir (str): streaming (bool): streaming decoding for the session-level evaluation progressbar (bool): visualize the progressbar Returns: wer (float): Word error rate cer (float): Character error rate n_oov_total (int): total number of OOV """ # Reset data counter dataset.reset(recog_params['recog_batch_size']) if recog_dir is None: recog_dir = 'decode_' + dataset.set + '_ep' + str( epoch) + '_beam' + str(recog_params['recog_beam_width']) recog_dir += '_lp' + str(recog_params['recog_length_penalty']) recog_dir += '_cp' + str(recog_params['recog_coverage_penalty']) recog_dir += '_' + str( recog_params['recog_min_len_ratio']) + '_' + str( recog_params['recog_max_len_ratio']) recog_dir += '_lm' + str(recog_params['recog_lm_weight']) ref_trn_save_path = mkdir_join(models[0].save_path, recog_dir, 'ref.trn') hyp_trn_save_path = mkdir_join(models[0].save_path, recog_dir, 'hyp.trn') else: ref_trn_save_path = mkdir_join(recog_dir, 'ref.trn') hyp_trn_save_path = mkdir_join(recog_dir, 'hyp.trn') wer, cer = 0, 0 n_sub_w, n_ins_w, n_del_w = 0, 0, 0 n_sub_c, n_ins_c, n_del_c = 0, 0, 0 n_word, n_char = 0, 0 n_oov_total = 0 if progressbar: pbar = tqdm(total=len(dataset)) with open(hyp_trn_save_path, 'w') as f_hyp, open(ref_trn_save_path, 'w') as f_ref: while True: batch, is_new_epoch = dataset.next( recog_params['recog_batch_size']) if streaming or recog_params['recog_chunk_sync']: best_hyps_id, _ = models[0].decode_streaming( batch['xs'], recog_params, dataset.idx2token[0], exclude_eos=True) else: best_hyps_id, aws = models[0].decode( batch['xs'], recog_params, idx2token=dataset.idx2token[0] if progressbar else None, exclude_eos=True, refs_id=batch['ys'], utt_ids=batch['utt_ids'], speakers=batch['sessions' if 
dataset.corpus == 'swbd' else 'speakers'], ensemble_models=models[1:] if len(models) > 1 else []) for b in range(len(batch['xs'])): ref = batch['text'][b] hyp = dataset.idx2token[0](best_hyps_id[b]) n_oov_total += hyp.count('<unk>') # Resolving UNK if recog_params['recog_resolving_unk'] and '<unk>' in hyp: recog_params_char = copy.deepcopy(recog_params) recog_params_char['recog_lm_weight'] = 0 recog_params_char['recog_beam_width'] = 1 best_hyps_id_char, aw_char = models[0].decode( batch['xs'][b:b + 1], recog_params_char, idx2token=dataset.idx2token[1] if progressbar else None, exclude_eos=True, refs_id=batch['ys_sub1'], utt_ids=batch['utt_ids'], speakers=batch['sessions'] if dataset.corpus == 'swbd' else batch['speakers'], task='ys_sub1') # TODO(hirofumi): support ys_sub2 and ys_sub3 assert not streaming hyp = resolve_unk( hyp, best_hyps_id_char[0], aws[b], aw_char[0], dataset.idx2token[1], subsample_factor_word=np.prod(models[0].subsample), subsample_factor_char=np.prod( models[0].subsample[:models[0].enc_n_layers_sub1 - 1])) logger.debug('Hyp (after OOV resolution): %s' % hyp) hyp = hyp.replace('*', '') # Compute CER ref_char = ref hyp_char = hyp if dataset.corpus == 'csj': ref_char = ref.replace(' ', '') hyp_char = hyp.replace(' ', '') cer_b, sub_b, ins_b, del_b = compute_wer( ref=list(ref_char), hyp=list(hyp_char), normalize=False) cer += cer_b n_sub_c += sub_b n_ins_c += ins_b n_del_c += del_b n_char += len(ref_char) # Write to trn speaker = str(batch['speakers'][b]).replace('-', '_') if streaming: utt_id = str(batch['utt_ids'][b]) + '_0000000_0000001' else: utt_id = str(batch['utt_ids'][b]) f_ref.write(ref + ' (' + speaker + '-' + utt_id + ')\n') f_hyp.write(hyp + ' (' + speaker + '-' + utt_id + ')\n') logger.debug('utt-id: %s' % utt_id) logger.debug('Ref: %s' % ref) logger.debug('Hyp: %s' % hyp) logger.debug('-' * 150) if not streaming: # Compute WER wer_b, sub_b, ins_b, del_b = compute_wer( ref=ref.split(' '), hyp=hyp.split(' '), normalize=False) wer += 
wer_b n_sub_w += sub_b n_ins_w += ins_b n_del_w += del_b n_word += len(ref.split(' ')) if progressbar: pbar.update(1) if is_new_epoch: break if progressbar: pbar.close() # Reset data counters dataset.reset() if not streaming: wer /= n_word n_sub_w /= n_word n_ins_w /= n_word n_del_w /= n_word if n_char > 0: cer /= n_char n_sub_c /= n_char n_ins_c /= n_char n_del_c /= n_char logger.debug('WER (%s): %.2f %%' % (dataset.set, wer)) logger.debug('SUB: %.2f / INS: %.2f / DEL: %.2f' % (n_sub_w, n_ins_w, n_del_w)) logger.debug('CER (%s): %.2f %%' % (dataset.set, cer)) logger.debug('SUB: %.2f / INS: %.2f / DEL: %.2f' % (n_sub_c, n_ins_c, n_del_c)) logger.debug('OOV (total): %d' % (n_oov_total)) return wer, cer, n_oov_total
def main():
    """Dump CTC forced alignments for every utterance in the recog sets.

    For each dataset listed in ``args.recog_sets``, loads the ASR model once,
    runs ``ctc_forced_align`` batch by batch, and writes one
    ``<speaker>/<utt_id>.txt`` file per utterance containing
    '<token> <trigger_frame>' lines plus a final '<eos>' line.
    """
    # Load configuration
    args, dir_name = parse_args_eval(sys.argv[1:])

    # Setting for logging
    if os.path.isfile(os.path.join(args.recog_dir, 'align.log')):
        os.remove(os.path.join(args.recog_dir, 'align.log'))
    set_logger(os.path.join(args.recog_dir, 'align.log'),
               stdout=args.recog_stdout)

    for i, s in enumerate(args.recog_sets):
        # Align all utterances (disable any length-based filtering)
        args.min_n_frames = 0
        args.max_n_frames = 1e5

        # Load dataloader
        dataloader = build_dataloader(args=args, tsv_path=s,
                                      batch_size=args.recog_batch_size)

        if i == 0:
            # Load ASR model
            model = Speech2Text(args, dir_name)
            epoch = int(args.recog_model[0].split('-')[-1])
            if args.recog_n_average > 1:
                # Model averaging for Transformer
                model = average_checkpoints(model, args.recog_model[0],
                                            n_average=args.recog_n_average)
            else:
                load_checkpoint(args.recog_model[0], model)

            if not args.recog_unit:
                args.recog_unit = args.unit

            logger.info('recog unit: %s' % args.recog_unit)
            logger.info('epoch: %d' % epoch)
            logger.info('batch size: %d' % args.recog_batch_size)

            # GPU setting
            if args.recog_n_gpus >= 1:
                model.cudnn_setting(deterministic=True, benchmark=False)
                model.cuda()

        save_path = mkdir_join(args.recog_dir, 'ctc_forced_alignments')

        # Clean directory
        if save_path is not None and os.path.isdir(save_path):
            shutil.rmtree(save_path)
            os.mkdir(save_path)

        pbar = tqdm(total=len(dataloader))
        while True:
            batch, is_new_epoch = dataloader.next()
            trigger_points = model.ctc_forced_align(batch['xs'], batch['ys'])  # `[B, L]`
            for b in range(len(batch['xs'])):
                save_path_spk = mkdir_join(save_path, batch['speakers'][b])
                save_path_utt = mkdir_join(save_path_spk,
                                           batch['utt_ids'][b] + '.txt')
                tokens = dataloader.idx2token[0](batch['ys'][b],
                                                 return_list=True)
                with codecs.open(save_path_utt, 'w', encoding="utf-8") as f:
                    # BUGFIX: the token index was previously named `i`,
                    # shadowing the outer recog-set index; renamed to `t`.
                    for t, tok in enumerate(tokens):
                        f.write('%s %d\n' % (tok, trigger_points[b, t]))
                    f.write('%s %d\n' % ('<eos>', trigger_points[b, len(tokens)]))
            pbar.update(len(batch['xs']))
            if is_new_epoch:
                break
        pbar.close()
# NOTE(review): word-level WER/CER evaluation (newer DataLoader-based API).
# Extends the older Dataset-based variant with oracle WER over n-best lists,
# fine-grained WER distribution bucketed by input length (200-frame bins),
# and optional skipping of edit-distance computation for RTF measurement.
# Kept byte-identical: the source layout is mangled (statements split across
# the physical-line boundaries, e.g. `pbar = ` / `tqdm(...)` and
# `n_sub_c ` / `+= sub_b`), so a safe reformat cannot be verified here.
# NOTE(review): `teacher_force` is accepted but never read in the visible
# body — presumably handled upstream or vestigial; confirm before removing.
def eval_word(models, dataloader, recog_params, epoch, recog_dir=None, streaming=False, progressbar=False, edit_distance=True, fine_grained=False, oracle=False, teacher_force=False): """Evaluate a word-level model by WER. Args: models (List): models to evaluate dataloader (torch.utils.data.DataLoader): evaluation dataloader recog_params (omegaconf.dictconfig.DictConfig): decoding hyperparameters epoch (int): current epoch recog_dir (str): directory path to save hypotheses streaming (bool): streaming decoding for session-level evaluation progressbar (bool): visualize progressbar edit_distance (bool): calculate edit-distance (can be skipped for RTF calculation) fine_grained (bool): calculate fine-grained WER distributions based on input lengths oracle (bool): calculate oracle WER teacher_force (bool): conduct decoding in teacher-forcing mode Returns: wer (float): Word error rate cer (float): Character error rate n_oov_total (int): total number of OOV """ if recog_dir is None: recog_dir = 'decode_' + dataloader.set + '_ep' + \
str(epoch) + '_beam' + str(recog_params.get('recog_beam_width')) recog_dir += '_lp' + str(recog_params.get('recog_length_penalty')) recog_dir += '_cp' + str(recog_params.get('recog_coverage_penalty')) recog_dir += '_' + str(recog_params.get('recog_min_len_ratio')) + '_' + \
str(recog_params.get('recog_max_len_ratio')) recog_dir += '_lm' + str(recog_params.get('recog_lm_weight')) ref_trn_path = mkdir_join(models[0].save_path, recog_dir, 'ref.trn') hyp_trn_path = mkdir_join(models[0].save_path, recog_dir, 'hyp.trn') else: ref_trn_path = mkdir_join(recog_dir, 'ref.trn') hyp_trn_path = mkdir_join(recog_dir, 'hyp.trn') wer, cer = 0, 0 n_sub_w, n_ins_w, n_del_w = 0, 0, 0 n_sub_c, n_ins_c, n_del_c = 0, 0, 0 n_word, n_char = 0, 0 wer_dist = {} # calculate WER distribution based on input lengths n_oov_total = 0 wer_oracle = 0 n_oracle_hit = 0 n_utt = 0 # Reset data counter dataloader.reset(recog_params.get('recog_batch_size')) if progressbar: pbar = 
tqdm(total=len(dataloader)) with codecs.open(hyp_trn_path, 'w', encoding='utf-8') as f_hyp, \
codecs.open(ref_trn_path, 'w', encoding='utf-8') as f_ref: for batch in dataloader: speakers = batch['sessions' if dataloader.corpus == 'swbd' else 'speakers'] if streaming or recog_params.get('recog_block_sync'): nbest_hyps_id = models[0].decode_streaming( batch['xs'], recog_params, dataloader.idx2token[0], exclude_eos=True, speaker=speakers[0])[0] else: nbest_hyps_id, aws = models[0].decode( batch['xs'], recog_params, idx2token=dataloader.idx2token[0], exclude_eos=True, refs_id=batch['ys'], utt_ids=batch['utt_ids'], speakers=speakers, ensemble_models=models[1:] if len(models) > 1 else []) for b in range(len(batch['xs'])): ref = batch['text'][b] nbest_hyps = [ dataloader.idx2token[0](hyp_id) for hyp_id in nbest_hyps_id[b] ] n_oov_total += nbest_hyps[0].count('<unk>') # Resolving UNK if recog_params.get( 'recog_resolving_unk') and '<unk>' in nbest_hyps[0]: recog_params_char = copy.deepcopy(recog_params) recog_params_char['recog_lm_weight'] = 0 recog_params_char['recog_beam_width'] = 1 best_hyps_id_char, aw_char = models[0].decode( batch['xs'][b:b + 1], recog_params_char, idx2token=dataloader.idx2token[1], exclude_eos=True, refs_id=batch['ys_sub1'], utt_ids=batch['utt_ids'], speakers=speakers, task='ys_sub1') # TODO(hirofumi): support ys_sub2 assert not streaming nbest_hyps[0] = resolve_unk( nbest_hyps[0], best_hyps_id_char[0], aws[b], aw_char[0], dataloader.idx2token[1], subsample_factor_word=np.prod(models[0].subsample), subsample_factor_char=np.prod( models[0].subsample[:models[0].enc_n_layers_sub1 - 1])) logger.debug('Hyp (after OOV resolution): %s' % nbest_hyps[0]) nbest_hyps[0] = nbest_hyps[0].replace('*', '') # Compute CER ref_char = ref hyp_char = nbest_hyps[0] if dataloader.corpus == 'csj': ref_char = ref_char.replace(' ', '') hyp_char = hyp_char.replace(' ', '') err_b, sub_b, ins_b, del_b = compute_wer( ref=list(ref_char), hyp=list(hyp_char)) cer += err_b n_sub_c 
+= sub_b n_ins_c += ins_b n_del_c += del_b n_char += len(ref_char) # Write to trn speaker = str(batch['speakers'][b]).replace('-', '_') if streaming: utt_id = str(batch['utt_ids'][b]) + '_0000000_0000001' else: utt_id = str(batch['utt_ids'][b]) f_ref.write(ref + ' (' + speaker + '-' + utt_id + ')\n') f_hyp.write(nbest_hyps[0] + ' (' + speaker + '-' + utt_id + ')\n') logger.debug('utt-id (%d/%d): %s' % (n_utt + 1, len(dataloader), utt_id)) logger.debug('Ref: %s' % ref) logger.debug('Hyp: %s' % nbest_hyps[0]) logger.debug('-' * 150) if edit_distance and not streaming: # Compute WER err_b, sub_b, ins_b, del_b = compute_wer( ref=ref.split(' '), hyp=nbest_hyps[0].split(' ')) wer += err_b n_sub_w += sub_b n_ins_w += ins_b n_del_w += del_b n_word += len(ref.split(' ')) # Compute oracle WER if oracle and len(nbest_hyps) > 1: wers_b = [err_b] + [ compute_wer(ref=ref.split(' '), hyp=hyp_n.split(' '))[0] for hyp_n in nbest_hyps[1:] ] oracle_idx = np.argmin(np.array(wers_b)) if oracle_idx == 0: n_oracle_hit += len(batch['utt_ids']) wer_oracle += wers_b[oracle_idx] # NOTE: OOV resolution is not considered if fine_grained: xlen_bin = (batch['xlens'][b] // 200 + 1) * 200 if xlen_bin in wer_dist.keys(): wer_dist[xlen_bin] += [err_b / 100] else: wer_dist[xlen_bin] = [err_b / 100] n_utt += len(batch['utt_ids']) if progressbar: pbar.update(len(batch['utt_ids'])) if progressbar: pbar.close() # Reset data counters dataloader.reset(is_new_epoch=True) if edit_distance and not streaming: wer /= n_word n_sub_w /= n_word n_ins_w /= n_word n_del_w /= n_word if n_char > 0: cer /= n_char n_sub_c /= n_char n_ins_c /= n_char n_del_c /= n_char if recog_params.get('recog_beam_width') > 1: logger.info('WER (%s): %.2f %%' % (dataloader.set, wer)) logger.info('SUB: %.2f / INS: %.2f / DEL: %.2f' % (n_sub_w, n_ins_w, n_del_w)) logger.info('CER (%s): %.2f %%' % (dataloader.set, cer)) logger.info('SUB: %.2f / INS: %.2f / DEL: %.2f' % (n_sub_c, n_ins_c, n_del_c)) logger.info('OOV (total): %d' % 
(n_oov_total)) if oracle: wer_oracle /= n_word oracle_hit_rate = n_oracle_hit * 100 / n_utt logger.info('Oracle WER (%s): %.2f %%' % (dataloader.set, wer_oracle)) logger.info('Oracle hit rate (%s): %.2f %%' % (dataloader.set, oracle_hit_rate)) if fine_grained: for len_bin, wers in sorted(wer_dist.items(), key=lambda x: x[0]): logger.info(' WER (%s): %.2f %% (%d)' % (dataloader.set, sum(wers) / len(wers), len_bin)) return wer, cer, n_oov_total
def eval_wordpiece(models, dataset, recog_params, epoch, recog_dir=None, streaming=False, progressbar=False):
    """Evaluate the wordpiece-level model by WER.

    Args:
        models (list): models to evaluate; models[0] decodes, the rest are ensembled
        dataset (Dataset): evaluation dataset
        recog_params (dict): decoding hyperparameters (beam width, penalties, LM weight, ...)
        epoch (int): epoch of the evaluated checkpoint (used in the directory name)
        recog_dir (str): output directory for ref.trn/hyp.trn; auto-generated when None
        streaming (bool): streaming decoding for the session-level evaluation
        progressbar (bool): visualize the progressbar
    Returns:
        wer (float): Word error rate
        cer (float): Character error rate

    """
    # Reset data counter
    dataset.reset()

    if recog_dir is None:
        # Encode the decoding configuration into the directory name
        recog_dir = 'decode_' + dataset.set + '_ep' + str(epoch) + '_beam' + str(recog_params['recog_beam_width'])
        recog_dir += '_lp' + str(recog_params['recog_length_penalty'])
        recog_dir += '_cp' + str(recog_params['recog_coverage_penalty'])
        recog_dir += '_' + str(recog_params['recog_min_len_ratio']) + '_' + str(recog_params['recog_max_len_ratio'])
        recog_dir += '_lm' + str(recog_params['recog_lm_weight'])

        ref_trn_save_path = mkdir_join(models[0].save_path, recog_dir, 'ref.trn')
        hyp_trn_save_path = mkdir_join(models[0].save_path, recog_dir, 'hyp.trn')
    else:
        ref_trn_save_path = mkdir_join(recog_dir, 'ref.trn')
        hyp_trn_save_path = mkdir_join(recog_dir, 'hyp.trn')

    # Accumulators: edit counts (word/char level) and token totals
    wer, cer = 0, 0
    n_sub_w, n_ins_w, n_del_w = 0, 0, 0
    n_sub_c, n_ins_c, n_del_c = 0, 0, 0
    n_word, n_char = 0, 0

    if progressbar:
        pbar = tqdm(total=len(dataset))

    with open(hyp_trn_save_path, 'w') as f_hyp, open(ref_trn_save_path, 'w') as f_ref:
        while True:
            batch, is_new_epoch = dataset.next(recog_params['recog_batch_size'])
            if streaming or recog_params['recog_chunk_sync']:
                best_hyps_id, _ = models[0].decode_streaming(
                    batch['xs'], recog_params, dataset.idx2token[0],
                    exclude_eos=True)
            else:
                best_hyps_id, _ = models[0].decode(
                    batch['xs'], recog_params, dataset.idx2token[0],
                    exclude_eos=True,
                    refs_id=batch['ys'],
                    utt_ids=batch['utt_ids'],
                    speakers=batch['sessions' if dataset.corpus == 'swbd' else 'speakers'],
                    ensemble_models=models[1:] if len(models) > 1 else [])

            for b in range(len(batch['xs'])):
                ref = batch['text'][b]
                if ref[0] == '<':
                    # Strip a leading tag such as '<...>' from the reference
                    ref = ref.split('>')[1]
                hyp = dataset.idx2token[0](best_hyps_id[b])

                # Write to trn
                speaker = str(batch['speakers'][b]).replace('-', '_')
                if streaming:
                    # Fixed pseudo time-range suffix for session-level utterances
                    utt_id = str(batch['utt_ids'][b]) + '_0000000_0000001'
                else:
                    utt_id = str(batch['utt_ids'][b])
                f_ref.write(ref + ' (' + speaker + '-' + utt_id + ')\n')
                f_hyp.write(hyp + ' (' + speaker + '-' + utt_id + ')\n')
                logger.debug('utt-id: %s' % utt_id)
                logger.debug('Ref: %s' % ref)
                logger.debug('Hyp: %s' % hyp)
                logger.debug('-' * 150)

                if not streaming:
                    # Compute WER
                    wer_b, sub_b, ins_b, del_b = compute_wer(ref=ref.split(' '),
                                                             hyp=hyp.split(' '),
                                                             normalize=False)
                    wer += wer_b
                    n_sub_w += sub_b
                    n_ins_w += ins_b
                    n_del_w += del_b
                    n_word += len(ref.split(' '))

                    # Compute CER
                    if dataset.corpus == 'csj':
                        # CSJ: remove word boundaries before character-level scoring
                        ref = ref.replace(' ', '')
                        hyp = hyp.replace(' ', '')
                    cer_b, sub_b, ins_b, del_b = compute_wer(ref=list(ref),
                                                             hyp=list(hyp),
                                                             normalize=False)
                    cer += cer_b
                    n_sub_c += sub_b
                    n_ins_c += ins_b
                    n_del_c += del_b
                    n_char += len(ref)

                if progressbar:
                    pbar.update(1)

            if is_new_epoch:
                break

    if progressbar:
        pbar.close()

    # Reset data counters
    dataset.reset()

    if not streaming:
        # Normalize accumulated edit counts into rates
        wer /= n_word
        n_sub_w /= n_word
        n_ins_w /= n_word
        n_del_w /= n_word

        cer /= n_char
        n_sub_c /= n_char
        n_ins_c /= n_char
        n_del_c /= n_char

        logger.debug('WER (%s): %.2f %%' % (dataset.set, wer))
        logger.debug('SUB: %.2f / INS: %.2f / DEL: %.2f' % (n_sub_w, n_ins_w, n_del_w))
        logger.debug('CER (%s): %.2f %%' % (dataset.set, cer))
        logger.debug('SUB: %.2f / INS: %.2f / DEL: %.2f' % (n_sub_c, n_ins_c, n_del_c))

    return wer, cer
def eval_phone(models, dataloader, recog_params, epoch, recog_dir=None,
               streaming=False, progressbar=False, fine_grained=False,
               oracle=False, teacher_force=False):
    """Evaluate a phone-level model by PER.

    Args:
        models (List): models to evaluate; models[0] decodes, the rest are ensembled
        dataloader (torch.utils.data.DataLoader): evaluation dataloader
        recog_params (dict): decoding hyperparameters
        epoch (int): epoch of the evaluated checkpoint (used in the directory name)
        recog_dir (str): output directory for ref.trn/hyp.trn; auto-generated when None
        streaming (bool): streaming decoding for the session-level evaluation
        progressbar (bool): visualize the progressbar
        oracle (bool): calculate oracle PER
        fine_grained (bool): calculate fine-grained PER distributions based on input lengths
        teacher_force (bool): conduct decoding in teacher-forcing mode
            NOTE(review): currently not forwarded to decode() — confirm intended
    Returns:
        per (float): Phone error rate

    """
    if recog_dir is None:
        # Encode the decoding configuration into the directory name
        recog_dir = 'decode_' + dataloader.set + '_ep' + str(epoch) + '_beam' + str(recog_params['recog_beam_width'])
        recog_dir += '_lp' + str(recog_params['recog_length_penalty'])
        recog_dir += '_cp' + str(recog_params['recog_coverage_penalty'])
        recog_dir += '_' + str(recog_params['recog_min_len_ratio']) + '_' + str(recog_params['recog_max_len_ratio'])

        ref_trn_path = mkdir_join(models[0].save_path, recog_dir, 'ref.trn')
        hyp_trn_path = mkdir_join(models[0].save_path, recog_dir, 'hyp.trn')
    else:
        ref_trn_path = mkdir_join(recog_dir, 'ref.trn')
        hyp_trn_path = mkdir_join(recog_dir, 'hyp.trn')

    per = 0
    n_sub, n_ins, n_del = 0, 0, 0
    n_phone = 0
    per_dist = {}  # calculate PER distribution based on input lengths
    per_oracle = 0
    n_oracle_hit = 0
    n_utt = 0

    # Reset data counter
    dataloader.reset(recog_params['recog_batch_size'])

    if progressbar:
        pbar = tqdm(total=len(dataloader))

    with codecs.open(hyp_trn_path, 'w', encoding='utf-8') as f_hyp, \
            codecs.open(ref_trn_path, 'w', encoding='utf-8') as f_ref:
        while True:
            batch, is_new_epoch = dataloader.next(recog_params['recog_batch_size'])
            if streaming or recog_params['recog_block_sync']:
                nbest_hyps_id = models[0].decode_streaming(
                    batch['xs'], recog_params, dataloader.idx2token[0],
                    exclude_eos=True)[0]
            else:
                nbest_hyps_id = models[0].decode(
                    batch['xs'], recog_params,
                    idx2token=dataloader.idx2token[0] if progressbar else None,
                    exclude_eos=True,
                    refs_id=batch['ys'],
                    utt_ids=batch['utt_ids'],
                    speakers=batch['sessions' if dataloader.corpus == 'swbd' else 'speakers'],
                    ensemble_models=models[1:] if len(models) > 1 else [])[0]

            for b in range(len(batch['xs'])):
                ref = batch['text'][b]
                nbest_hyps = [dataloader.idx2token[0](hyp_id) for hyp_id in nbest_hyps_id[b]]

                # Write to trn
                speaker = str(batch['speakers'][b]).replace('-', '_')
                if streaming:
                    # Fixed pseudo time-range suffix for session-level utterances
                    utt_id = str(batch['utt_ids'][b]) + '_0000000_0000001'
                else:
                    utt_id = str(batch['utt_ids'][b])
                f_ref.write(ref + ' (' + speaker + '-' + utt_id + ')\n')
                f_hyp.write(nbest_hyps[0] + ' (' + speaker + '-' + utt_id + ')\n')
                logger.debug('utt-id: %s' % utt_id)
                logger.debug('Ref: %s' % ref)
                logger.debug('Hyp: %s' % nbest_hyps[0])
                logger.debug('-' * 150)

                if not streaming:
                    # Compute PER
                    err_b, sub_b, ins_b, del_b = compute_wer(ref=ref.split(' '),
                                                             hyp=nbest_hyps[0].split(' '))
                    per += err_b
                    n_sub += sub_b
                    n_ins += ins_b
                    n_del += del_b
                    n_phone += len(ref.split(' '))

                    # Compute oracle PER (best hypothesis in the n-best list)
                    if oracle and len(nbest_hyps) > 1:
                        pers_b = [err_b] + [compute_wer(ref=ref.split(' '),
                                                        hyp=hyp_n.split(' '))[0]
                                            for hyp_n in nbest_hyps[1:]]
                        oracle_idx = np.argmin(np.array(pers_b))
                        if oracle_idx == 0:
                            n_oracle_hit += 1
                        per_oracle += pers_b[oracle_idx]

                    # FIX: previously per_dist was never populated, so the
                    # fine-grained report below always iterated an empty dict.
                    # Bin per-utterance PER by input length (200-frame buckets),
                    # mirroring the fine-grained WER accounting in the sibling
                    # word/wordpiece evaluators.
                    if fine_grained:
                        xlen_bin = (batch['xlens'][b] // 200 + 1) * 200
                        if xlen_bin in per_dist.keys():
                            per_dist[xlen_bin] += [err_b / 100]
                        else:
                            per_dist[xlen_bin] = [err_b / 100]

                n_utt += 1
                if progressbar:
                    pbar.update(1)

            if is_new_epoch:
                break

    if progressbar:
        pbar.close()

    # Reset data counters
    dataloader.reset()

    if not streaming:
        # Normalize accumulated edit counts into rates
        per /= n_phone
        n_sub /= n_phone
        n_ins /= n_phone
        n_del /= n_phone

        if recog_params['recog_beam_width'] > 1:
            logger.info('PER (%s): %.2f %%' % (dataloader.set, per))
            logger.info('SUB: %.2f / INS: %.2f / DEL: %.2f' % (n_sub, n_ins, n_del))

        if oracle:
            per_oracle /= n_phone
            oracle_hit_rate = n_oracle_hit * 100 / n_utt
            logger.info('Oracle PER (%s): %.2f %%' % (dataloader.set, per_oracle))
            logger.info('Oracle hit rate (%s): %.2f %%' % (dataloader.set, oracle_hit_rate))

        if fine_grained:
            for len_bin, pers in sorted(per_dist.items(), key=lambda x: x[0]):
                logger.info('  PER (%s): %.2f %% (%d)' % (dataloader.set, sum(pers) / len(pers), len_bin))

    return per
def main():

    args = parse_args_train(sys.argv[1:])

    # Load a conf file
    if args.resume:
        conf = load_config(os.path.join(os.path.dirname(args.resume), 'conf.yml'))
        for k, v in conf.items():
            if k != 'resume':
                setattr(args, k, v)

    # Load dataset
    batch_size = args.batch_size * args.n_gpus if args.n_gpus >= 1 else args.batch_size
    train_set = Dataset(corpus=args.corpus,
                        tsv_path=args.train_set,
                        dict_path=args.dict,
                        nlsyms=args.nlsyms,
                        unit=args.unit,
                        wp_model=args.wp_model,
                        batch_size=batch_size,
                        n_epochs=args.n_epochs,
                        min_n_tokens=args.min_n_tokens,
                        bptt=args.bptt,
                        shuffle=args.shuffle,
                        backward=args.backward,
                        serialize=args.serialize)
    dev_set = Dataset(corpus=args.corpus,
                      tsv_path=args.dev_set,
                      dict_path=args.dict,
                      nlsyms=args.nlsyms,
                      unit=args.unit,
                      wp_model=args.wp_model,
                      batch_size=batch_size,
                      bptt=args.bptt,
                      backward=args.backward,
                      serialize=args.serialize)
    eval_sets = [Dataset(corpus=args.corpus,
                         tsv_path=s,
                         dict_path=args.dict,
                         nlsyms=args.nlsyms,
                         unit=args.unit,
                         wp_model=args.wp_model,
                         batch_size=1,
                         bptt=args.bptt,
                         backward=args.backward,
                         serialize=args.serialize) for s in args.eval_sets]

    args.vocab = train_set.vocab

    # Set save path
    if args.resume:
        save_path = os.path.dirname(args.resume)
        dir_name = os.path.basename(save_path)
    else:
        dir_name = set_lm_name(args)
        save_path = mkdir_join(args.model_save_dir,
                               '_'.join(os.path.basename(args.train_set).split('.')[:-1]),
                               dir_name)
        save_path = set_save_path(save_path)  # avoid overwriting

    # Set logger
    set_logger(os.path.join(save_path, 'train.log'), stdout=args.stdout)

    # Model setting
    model = build_lm(args, save_path)

    if not args.resume:
        # Save the conf file as a yaml file
        save_config(vars(args), os.path.join(save_path, 'conf.yml'))

        # Save the nlsyms, dictionary, and wp_model
        if args.nlsyms:
            shutil.copy(args.nlsyms, os.path.join(save_path, 'nlsyms.txt'))
        shutil.copy(args.dict, os.path.join(save_path, 'dict.txt'))
        if args.unit == 'wp':
            shutil.copy(args.wp_model, os.path.join(save_path, 'wp.model'))

    for k, v in sorted(vars(args).items(), key=lambda x: x[0]):
        logger.info('%s: %s' % (k, str(v)))

    # Count total parameters
    for n in sorted(list(model.num_params_dict.keys())):
        n_params = model.num_params_dict[n]
        logger.info("%s %d" % (n, n_params))
    logger.info("Total %.2f M parameters" % (model.total_parameters / 1000000))
    logger.info(model)

    # Set optimizer
    resume_epoch = 0
    if args.resume:
        # FIX: the parsed epoch was previously bound to a different name
        # ('epoch'), leaving resume_epoch stuck at 0, so the convert-to-SGD
        # resume branch below could never trigger (unless the target epoch
        # was 0). Parse it into resume_epoch and use it consistently.
        resume_epoch = int(args.resume.split('-')[-1])
        optimizer = set_optimizer(model,
                                  'sgd' if resume_epoch > args.convert_to_sgd_epoch else args.optimizer,
                                  args.lr, args.weight_decay)
    else:
        optimizer = set_optimizer(model, args.optimizer, args.lr, args.weight_decay)

    # Wrap optimizer by learning rate scheduler
    is_transformer = args.lm_type in ['transformer', 'transformer_xl']
    optimizer = LRScheduler(optimizer, args.lr,
                            decay_type=args.lr_decay_type,
                            decay_start_epoch=args.lr_decay_start_epoch,
                            decay_rate=args.lr_decay_rate,
                            decay_patient_n_epochs=args.lr_decay_patient_n_epochs,
                            early_stop_patient_n_epochs=args.early_stop_patient_n_epochs,
                            warmup_start_lr=args.warmup_start_lr,
                            warmup_n_steps=args.warmup_n_steps,
                            model_size=getattr(args, 'transformer_d_model', 0),
                            factor=args.lr_factor,
                            noam=is_transformer,
                            save_checkpoints_topk=1)

    if args.resume:
        # Restore the last saved model
        load_checkpoint(args.resume, model, optimizer)

        # Resume between convert_to_sgd_epoch -1 and convert_to_sgd_epoch
        if resume_epoch == args.convert_to_sgd_epoch:
            optimizer.convert_to_sgd(model, args.lr, args.weight_decay,
                                     decay_type='always', decay_rate=0.5)

    # GPU setting
    use_apex = args.train_dtype in ["O0", "O1", "O2", "O3"]
    amp = None
    if args.n_gpus >= 1:
        model.cudnn_setting(deterministic=not (is_transformer or args.cudnn_benchmark),
                            benchmark=args.cudnn_benchmark)
        model.cuda()

        # Mix precision training setting
        if use_apex:
            from apex import amp
            model, optimizer.optimizer = amp.initialize(model, optimizer.optimizer,
                                                        opt_level=args.train_dtype)
            # NOTE(review): amp.init() belongs to the old apex API; mixing it
            # with amp.initialize() looks redundant — confirm against the apex
            # version in use before removing.
            amp.init()
            if args.resume:
                load_checkpoint(args.resume, amp=amp)
        model = CustomDataParallel(model, device_ids=list(range(0, args.n_gpus)))
    else:
        model = CPUWrapperLM(model)

    # Set process name
    logger.info('PID: %s' % os.getpid())
    logger.info('USERNAME: %s' % os.uname()[1])
    logger.info('#GPU: %d' % torch.cuda.device_count())
    setproctitle(args.job_name if args.job_name else dir_name)

    # Set reporter
    reporter = Reporter(save_path)

    hidden = None
    start_time_train = time.time()
    start_time_epoch = time.time()
    start_time_step = time.time()
    pbar_epoch = tqdm(total=len(train_set))
    accum_n_steps = 0
    n_steps = optimizer.n_steps * args.accum_grad_n_steps
    while True:
        # Compute loss in the training set
        ys_train, is_new_epoch = train_set.next()
        accum_n_steps += 1

        loss, hidden, observation = model(ys_train, hidden)
        reporter.add(observation)
        if use_apex:
            with amp.scale_loss(loss, optimizer.optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        loss.detach()  # Truncate the graph
        if args.accum_grad_n_steps == 1 or accum_n_steps >= args.accum_grad_n_steps:
            if args.clip_grad_norm > 0:
                total_norm = torch.nn.utils.clip_grad_norm_(
                    model.module.parameters(), args.clip_grad_norm)
                reporter.add_tensorboard_scalar('total_norm', total_norm)
            optimizer.step()
            optimizer.zero_grad()
            accum_n_steps = 0
        loss_train = loss.item()
        del loss
        hidden = model.module.repackage_state(hidden)
        reporter.add_tensorboard_scalar('learning_rate', optimizer.lr)
        # NOTE: loss/acc/ppl are already added in the model
        reporter.step()
        pbar_epoch.update(ys_train.shape[0] * (ys_train.shape[1] - 1))
        n_steps += 1
        # NOTE: n_steps is different from the step counter in Noam Optimizer

        if n_steps % args.print_step == 0:
            # Compute loss in the dev set
            ys_dev = dev_set.next(bptt=args.bptt)[0]
            loss, _, observation = model(ys_dev, None, is_eval=True)
            reporter.add(observation, is_eval=True)
            loss_dev = loss.item()
            del loss
            reporter.step(is_eval=True)

            duration_step = time.time() - start_time_step
            logger.info("step:%d(ep:%.2f) loss:%.3f(%.3f)/lr:%.5f/bs:%d (%.2f min)" % (
                n_steps, optimizer.n_epochs + train_set.epoch_detail,
                loss_train, loss_dev, optimizer.lr, ys_train.shape[0],
                duration_step / 60))
            start_time_step = time.time()

            # Save figures of loss and accuracy
            if n_steps % (args.print_step * 10) == 0:
                reporter.snapshot()
                model.module.plot_attention()

        # Save checkpoint and evaluate model per epoch
        if is_new_epoch:
            duration_epoch = time.time() - start_time_epoch
            logger.info('========== EPOCH:%d (%.2f min) ==========' % (
                optimizer.n_epochs + 1, duration_epoch / 60))

            if optimizer.n_epochs + 1 < args.eval_start_epoch:
                optimizer.epoch()  # lr decay
                reporter.epoch()  # plot

                # Save the model
                optimizer.save_checkpoint(model, save_path, remove_old=not is_transformer, amp=amp)
            else:
                start_time_eval = time.time()
                # dev
                model.module.reset_length(args.bptt)
                ppl_dev, _ = eval_ppl([model.module], dev_set,
                                      batch_size=1, bptt=args.bptt)
                model.module.reset_length(args.bptt)
                optimizer.epoch(ppl_dev)  # lr decay
                reporter.epoch(ppl_dev, name='perplexity')  # plot
                logger.info('PPL (%s, ep:%d): %.2f' % (dev_set.set, optimizer.n_epochs, ppl_dev))

                if optimizer.is_topk or is_transformer:
                    # Save the model
                    optimizer.save_checkpoint(model, save_path, remove_old=not is_transformer, amp=amp)

                    # test
                    ppl_test_avg = 0.
                    for eval_set in eval_sets:
                        model.module.reset_length(args.bptt)
                        ppl_test, _ = eval_ppl([model.module], eval_set,
                                               batch_size=1, bptt=args.bptt)
                        model.module.reset_length(args.bptt)
                        logger.info('PPL (%s, ep:%d): %.2f' % (eval_set.set, optimizer.n_epochs, ppl_test))
                        ppl_test_avg += ppl_test
                    if len(eval_sets) > 0:
                        logger.info('PPL (avg., ep:%d): %.2f' % (optimizer.n_epochs, ppl_test_avg / len(eval_sets)))

                duration_eval = time.time() - start_time_eval
                logger.info('Evaluation time: %.2f min' % (duration_eval / 60))

                # Early stopping
                if optimizer.is_early_stop:
                    break

                # Convert to fine-tuning stage
                if optimizer.n_epochs == args.convert_to_sgd_epoch:
                    optimizer.convert_to_sgd(model, args.lr, args.weight_decay,
                                             decay_type='always', decay_rate=0.5)

            pbar_epoch = tqdm(total=len(train_set))

            if optimizer.n_epochs >= args.n_epochs:
                break

            start_time_step = time.time()
            start_time_epoch = time.time()

    duration_train = time.time() - start_time_train
    logger.info('Total time: %.2f hour' % (duration_train / 3600))

    reporter.tf_writer.close()
    pbar_epoch.close()

    return save_path
def eval_wordpiece(models, dataset, recog_params, epoch, recog_dir=None,
                   streaming=False, progressbar=False, fine_grained=False):
    """Evaluate the wordpiece-level model by WER.

    Args:
        models (list): models to evaluate; models[0] decodes, the rest are ensembled
        dataset (Dataset): evaluation dataset
        recog_params (dict): decoding hyperparameters (beam width, penalties, LM weight, ...)
        epoch (int): epoch of the evaluated checkpoint (used in the directory name)
        recog_dir (str): output directory for ref.trn/hyp.trn; auto-generated when None
        streaming (bool): streaming decoding for the session-level evaluation
        progressbar (bool): visualize the progressbar
        fine_grained (bool): calculate fine-grained WER distributions based on input lengths
    Returns:
        wer (float): Word error rate
        cer (float): Character error rate

    """
    # Reset data counter
    dataset.reset(recog_params['recog_batch_size'])

    if recog_dir is None:
        # Encode the decoding configuration into the directory name
        recog_dir = 'decode_' + dataset.set + '_ep' + str(epoch) + '_beam' + str(recog_params['recog_beam_width'])
        recog_dir += '_lp' + str(recog_params['recog_length_penalty'])
        recog_dir += '_cp' + str(recog_params['recog_coverage_penalty'])
        recog_dir += '_' + str(recog_params['recog_min_len_ratio']) + '_' + str(recog_params['recog_max_len_ratio'])
        recog_dir += '_lm' + str(recog_params['recog_lm_weight'])

        ref_trn_save_path = mkdir_join(models[0].save_path, recog_dir, 'ref.trn')
        hyp_trn_save_path = mkdir_join(models[0].save_path, recog_dir, 'hyp.trn')
    else:
        ref_trn_save_path = mkdir_join(recog_dir, 'ref.trn')
        hyp_trn_save_path = mkdir_join(recog_dir, 'hyp.trn')

    # Accumulators: edit counts (word/char level), token totals, and
    # streamability statistics reported at the end
    wer, cer = 0, 0
    n_sub_w, n_ins_w, n_del_w = 0, 0, 0
    n_sub_c, n_ins_c, n_del_c = 0, 0, 0
    n_word, n_char = 0, 0
    n_streamable, quantity_rate, n_utt = 0, 0, 0
    last_success_frame_ratio = 0

    if progressbar:
        pbar = tqdm(total=len(dataset))

    # calculate WER distribution based on input lengths
    wer_dist = {}

    with open(hyp_trn_save_path, 'w') as f_hyp, open(ref_trn_save_path, 'w') as f_ref:
        while True:
            batch, is_new_epoch = dataset.next(recog_params['recog_batch_size'])
            if streaming or recog_params['recog_chunk_sync']:
                best_hyps_id, _ = models[0].decode_streaming(
                    batch['xs'], recog_params, dataset.idx2token[0],
                    exclude_eos=True)
            else:
                best_hyps_id, _ = models[0].decode(
                    batch['xs'], recog_params,
                    idx2token=dataset.idx2token[0] if progressbar else None,
                    exclude_eos=True,
                    refs_id=batch['ys'],
                    utt_ids=batch['utt_ids'],
                    speakers=batch['sessions' if dataset.corpus == 'swbd' else 'speakers'],
                    ensemble_models=models[1:] if len(models) > 1 else [])

            for b in range(len(batch['xs'])):
                ref = batch['text'][b]
                if ref[0] == '<':
                    # Strip a leading tag such as '<...>' from the reference
                    ref = ref.split('>')[1]
                hyp = dataset.idx2token[0](best_hyps_id[b])

                # Write to trn
                speaker = str(batch['speakers'][b]).replace('-', '_')
                if streaming:
                    # Fixed pseudo time-range suffix for session-level utterances
                    utt_id = str(batch['utt_ids'][b]) + '_0000000_0000001'
                else:
                    utt_id = str(batch['utt_ids'][b])
                f_ref.write(ref + ' (' + speaker + '-' + utt_id + ')\n')
                f_hyp.write(hyp + ' (' + speaker + '-' + utt_id + ')\n')
                logger.debug('utt-id: %s' % utt_id)
                logger.debug('Ref: %s' % ref)
                logger.debug('Hyp: %s' % hyp)
                logger.debug('-' * 150)

                if not streaming:
                    # Compute WER
                    wer_b, sub_b, ins_b, del_b = compute_wer(ref=ref.split(' '),
                                                             hyp=hyp.split(' '),
                                                             normalize=False)
                    wer += wer_b
                    n_sub_w += sub_b
                    n_ins_w += ins_b
                    n_del_w += del_b
                    n_word += len(ref.split(' '))

                    # Bin per-utterance WER by input length (200-frame buckets)
                    if fine_grained:
                        xlen_bin = (batch['xlens'][b] // 200 + 1) * 200
                        if xlen_bin in wer_dist.keys():
                            wer_dist[xlen_bin] += [wer_b / 100]
                        else:
                            wer_dist[xlen_bin] = [wer_b / 100]

                    # Compute CER
                    if dataset.corpus == 'csj':
                        # CSJ: remove word boundaries before character-level scoring
                        ref = ref.replace(' ', '')
                        hyp = hyp.replace(' ', '')
                    cer_b, sub_b, ins_b, del_b = compute_wer(ref=list(ref),
                                                             hyp=list(hyp),
                                                             normalize=False)
                    cer += cer_b
                    n_sub_c += sub_b
                    n_ins_c += ins_b
                    n_del_c += del_b
                    n_char += len(ref)

                    # Streamability statistics from the decoder
                    if models[0].streamable():
                        n_streamable += 1
                    else:
                        last_success_frame_ratio += models[0].last_success_frame_ratio()
                    quantity_rate += models[0].quantity_rate()
                    n_utt += 1

                if progressbar:
                    pbar.update(1)

            if is_new_epoch:
                break

    if progressbar:
        pbar.close()

    # Reset data counters
    dataset.reset()

    if not streaming:
        # Normalize accumulated edit counts into rates
        wer /= n_word
        n_sub_w /= n_word
        n_ins_w /= n_word
        n_del_w /= n_word

        cer /= n_char
        n_sub_c /= n_char
        n_ins_c /= n_char
        n_del_c /= n_char

        # Only non-streamable utterances contribute to this ratio
        if n_utt - n_streamable > 0:
            last_success_frame_ratio /= (n_utt - n_streamable)
        n_streamable /= n_utt
        quantity_rate /= n_utt

        if fine_grained:
            for len_bin, wers in sorted(wer_dist.items(), key=lambda x: x[0]):
                logger.info('  WER (%s): %.2f %% (%d)' % (dataset.set, sum(wers) / len(wers), len_bin))

        logger.debug('WER (%s): %.2f %%' % (dataset.set, wer))
        logger.debug('SUB: %.2f / INS: %.2f / DEL: %.2f' % (n_sub_w, n_ins_w, n_del_w))
        logger.debug('CER (%s): %.2f %%' % (dataset.set, cer))
        logger.debug('SUB: %.2f / INS: %.2f / DEL: %.2f' % (n_sub_c, n_ins_c, n_del_c))
        logger.info('Streamablility (%s): %.2f %%' % (dataset.set, n_streamable * 100))
        logger.info('Quantity rate (%s): %.2f %%' % (dataset.set, quantity_rate * 100))
        logger.info('Last success frame ratio (%s): %.2f %%' % (dataset.set, last_success_frame_ratio))

    return wer, cer
def eval_wordpiece_bleu(models, dataloader, params, epoch=-1, rank=0,
                        save_dir=None, streaming=False, progressbar=False,
                        edit_distance=True, fine_grained=False, oracle=False,
                        teacher_force=False):
    """Evaluate a wordpiece-level model by corpus-level BLEU.

    Args:
        models (List): models to evaluate
        dataloader (torch.utils.data.DataLoader): evaluation dataloader
        params (omegaconf.dictconfig.DictConfig): decoding hyperparameters
        epoch (int): current epoch
        rank (int): rank of current process group
        save_dir (str): directory path to save hypotheses
        streaming (bool): streaming decoding for session-level evaluation
        progressbar (bool): visualize progressbar
        edit_distance (bool): calculate edit-distance (can be skipped for RTF calculation)
        fine_grained (bool): calculate fine-grained corpus-level BLEU distributions
            based on input lengths
        oracle (bool): calculate oracle corpus-level BLEU
        teacher_force (bool): conduct decoding in teacher-forcing mode
    Returns:
        c_bleu (float): corpus-level 4-gram BLEU (0.0 when no hypotheses
            were scored, i.e. streaming=True or edit_distance=False)

    """
    if save_dir is None:
        # Encode the decoding configuration into the directory name
        save_dir = 'decode_' + dataloader.set + '_ep' + \
            str(epoch) + '_beam' + str(params.get('recog_beam_width'))
        save_dir += '_lp' + str(params.get('recog_length_penalty'))
        save_dir += '_cp' + str(params.get('recog_coverage_penalty'))
        save_dir += '_' + str(params.get('recog_min_len_ratio')) + '_' + \
            str(params.get('recog_max_len_ratio'))
        save_dir += '_lm' + str(params.get('recog_lm_weight'))

        ref_trn_path = mkdir_join(models[0].save_path, save_dir, 'ref.trn', rank=rank)
        hyp_trn_path = mkdir_join(models[0].save_path, save_dir, 'hyp.trn', rank=rank)
    else:
        ref_trn_path = mkdir_join(save_dir, 'ref.trn', rank=rank)
        hyp_trn_path = mkdir_join(save_dir, 'hyp.trn', rank=rank)

    list_of_references_dist = {}  # calculate corpus-level BLEU distribution bucketed by input lengths
    hypotheses_dist = {}

    hypotheses_oracle = []
    n_oracle_hit = 0
    n_utt = 0

    # Reset data counter
    dataloader.reset(params.get('recog_batch_size'), 'seq')

    if progressbar:
        pbar = tqdm(total=len(dataloader))

    list_of_references = []
    hypotheses = []

    # Only rank 0 writes the trn files to avoid clobbering in DDP runs
    if rank == 0:
        f_hyp = codecs.open(hyp_trn_path, 'w', encoding='utf-8')
        f_ref = codecs.open(ref_trn_path, 'w', encoding='utf-8')

    for batch in dataloader:
        if streaming or params.get('recog_block_sync'):
            nbest_hyps_id = models[0].decode_streaming(
                batch['xs'], params, dataloader.idx2token[0],
                exclude_eos=True,
                speaker=batch['speakers'][0])[0]
        else:
            nbest_hyps_id = models[0].decode(
                batch['xs'], params,
                idx2token=dataloader.idx2token[0],
                exclude_eos=True,
                refs_id=batch['ys'],
                utt_ids=batch['utt_ids'],
                speakers=batch['speakers'],
                ensemble_models=models[1:] if len(models) > 1 else [],
                teacher_force=teacher_force)[0]

        for b in range(len(batch['xs'])):
            ref = batch['text'][b]
            if ref[0] == '<':
                # Strip a leading tag such as '<...>' from the reference
                ref = ref.split('>')[1]
            nbest_hyps = [dataloader.idx2token[0](hyp_id) for hyp_id in nbest_hyps_id[b]]

            # Write to trn
            speaker = str(batch['speakers'][b]).replace('-', '_')
            if streaming:
                # Fixed pseudo time-range suffix for session-level utterances
                utt_id = str(batch['utt_ids'][b]) + '_0000000_0000001'
            else:
                utt_id = str(batch['utt_ids'][b])
            if rank == 0:
                f_ref.write(ref + ' (' + speaker + '-' + utt_id + ')\n')
                f_hyp.write(nbest_hyps[0] + ' (' + speaker + '-' + utt_id + ')\n')
            logger.debug('utt-id (%d/%d): %s' % (n_utt + 1, len(dataloader), utt_id))
            logger.debug('Ref: %s' % ref)
            logger.debug('Hyp: %s' % nbest_hyps[0])
            logger.debug('-' * 150)

            if edit_distance and not streaming:
                list_of_references += [[ref.split(' ')]]
                hypotheses += [nbest_hyps[0].split(' ')]

                # Bin by input length (200-frame buckets)
                if fine_grained:
                    xlen_bin = (batch['xlens'][b] // 200 + 1) * 200
                    if xlen_bin in hypotheses_dist.keys():
                        list_of_references_dist[xlen_bin] += [[ref.split(' ')]]
                        hypotheses_dist[xlen_bin] += [hypotheses[-1]]
                    else:
                        list_of_references_dist[xlen_bin] = [[ref.split(' ')]]
                        hypotheses_dist[xlen_bin] = [hypotheses[-1]]

                # Compute oracle corpus-level BLEU (selected by sentence-level BLEU)
                if oracle and len(nbest_hyps) > 1:
                    # FIX: nltk's sentence_bleu expects a *list* of reference
                    # token lists; the flat ref.split(' ') previously made each
                    # word a separate one-token reference.
                    s_blues_b = [sentence_bleu([ref.split(' ')], hyp_n.split(' '))
                                 for hyp_n in nbest_hyps]
                    oracle_idx = np.argmax(np.array(s_blues_b))
                    if oracle_idx == 0:
                        n_oracle_hit += len(batch['utt_ids'])
                    hypotheses_oracle += [nbest_hyps[oracle_idx].split(' ')]

        n_utt += len(batch['utt_ids'])
        if progressbar:
            pbar.update(len(batch['utt_ids']))

    if rank == 0:
        f_hyp.close()
        f_ref.close()
    if progressbar:
        pbar.close()

    # Reset data counters
    dataloader.reset(is_new_epoch=True)

    # FIX: guard against empty corpora — when streaming=True or
    # edit_distance=False no hypotheses are accumulated and corpus_bleu
    # would divide by zero; report 0.0 instead.
    c_bleu = corpus_bleu(list_of_references, hypotheses) * 100 if hypotheses else 0.0

    if edit_distance and not streaming:
        if oracle:
            c_bleu_oracle = corpus_bleu(list_of_references, hypotheses_oracle) * 100
            oracle_hit_rate = n_oracle_hit * 100 / n_utt
            logger.info('Oracle corpus-level BLEU (%s): %.2f %%' % (dataloader.set, c_bleu_oracle))
            logger.info('Oracle hit rate (%s): %.2f %%' % (dataloader.set, oracle_hit_rate))

        if fine_grained:
            for len_bin, hypotheses_bin in sorted(hypotheses_dist.items(), key=lambda x: x[0]):
                c_bleu_bin = corpus_bleu(list_of_references_dist[len_bin], hypotheses_bin) * 100
                logger.info('  corpus-level BLEU (%s): %.2f %% (%d)' % (dataloader.set, c_bleu_bin, len_bin))

    logger.info('Corpus-level BLEU (%s): %.2f %%' % (dataloader.set, c_bleu))

    return c_bleu
def main():

    args = parse()

    # Load a conf file
    if args.resume:
        conf = load_config(os.path.join(os.path.dirname(args.resume), 'conf.yml'))
        for k, v in conf.items():
            if k != 'resume':
                setattr(args, k, v)

    # Load dataset
    train_set = Dataset(corpus=args.corpus,
                        tsv_path=args.train_set,
                        dict_path=args.dict,
                        nlsyms=args.nlsyms,
                        unit=args.unit,
                        wp_model=args.wp_model,
                        batch_size=args.batch_size * args.n_gpus,
                        n_epochs=args.n_epochs,
                        min_n_tokens=args.min_n_tokens,
                        bptt=args.bptt,
                        backward=args.backward,
                        serialize=args.serialize)
    dev_set = Dataset(corpus=args.corpus,
                      tsv_path=args.dev_set,
                      dict_path=args.dict,
                      nlsyms=args.nlsyms,
                      unit=args.unit,
                      wp_model=args.wp_model,
                      batch_size=args.batch_size * args.n_gpus,
                      bptt=args.bptt,
                      backward=args.backward,
                      serialize=args.serialize)
    eval_sets = []
    for s in args.eval_sets:
        eval_sets += [Dataset(corpus=args.corpus,
                              tsv_path=s,
                              dict_path=args.dict,
                              nlsyms=args.nlsyms,
                              unit=args.unit,
                              wp_model=args.wp_model,
                              batch_size=1,
                              bptt=args.bptt,
                              backward=args.backward,
                              serialize=args.serialize)]

    args.vocab = train_set.vocab

    # Set save path
    if args.resume:
        save_path = os.path.dirname(args.resume)
        dir_name = os.path.basename(save_path)
    else:
        dir_name = make_model_name(args)
        save_path = mkdir_join(args.model,
                               '_'.join(os.path.basename(args.train_set).split('.')[:-1]),
                               dir_name)
        save_path = set_save_path(save_path)  # avoid overwriting

    # Set logger (shadows any module-level logger inside this function)
    logger = set_logger(os.path.join(save_path, 'train.log'), key='training')

    # Model setting
    if 'gated_conv' in args.lm_type:
        model = GatedConvLM(args)
    else:
        model = RNNLM(args)
    model.save_path = save_path

    if args.resume:
        # Set optimizer
        epoch = int(args.resume.split('-')[-1])
        model.set_optimizer(
            optimizer='sgd' if epoch > conf['convert_to_sgd_epoch'] + 1 else conf['optimizer'],
            learning_rate=float(conf['learning_rate']),  # on-the-fly
            weight_decay=float(conf['weight_decay']))

        # Restore the last saved model
        model, checkpoint = load_checkpoint(model, args.resume, resume=True)
        lr_controller = checkpoint['lr_controller']
        epoch = checkpoint['epoch']
        step = checkpoint['step']
        ppl_dev_best = checkpoint['metric_dev_best']

        # Resume between convert_to_sgd_epoch and convert_to_sgd_epoch + 1
        if epoch == conf['convert_to_sgd_epoch'] + 1:
            model.set_optimizer(optimizer='sgd',
                                learning_rate=args.learning_rate,
                                weight_decay=float(conf['weight_decay']))
            logger.info('========== Convert to SGD ==========')
    else:
        # Save the conf file as a yaml file
        save_config(vars(args), os.path.join(model.save_path, 'conf.yml'))

        # Save the nlsyms, dictionar, and wp_model
        if args.nlsyms:
            shutil.copy(args.nlsyms, os.path.join(model.save_path, 'nlsyms.txt'))
        shutil.copy(args.dict, os.path.join(model.save_path, 'dict.txt'))
        if args.unit == 'wp':
            shutil.copy(args.wp_model, os.path.join(model.save_path, 'wp.model'))

        for k, v in sorted(vars(args).items(), key=lambda x: x[0]):
            logger.info('%s: %s' % (k, str(v)))

        # Count total parameters
        for n in sorted(list(model.num_params_dict.keys())):
            nparams = model.num_params_dict[n]
            logger.info("%s %d" % (n, nparams))
        logger.info("Total %.2f M parameters" % (model.total_parameters / 1000000))
        logger.info(model)

        # Set optimizer
        model.set_optimizer(optimizer=args.optimizer,
                            learning_rate=float(args.learning_rate),
                            weight_decay=float(args.weight_decay))

        epoch, step = 1, 1
        ppl_dev_best = 10000

        # Set learning rate controller
        lr_controller = Controller(learning_rate=float(args.learning_rate),
                                   decay_type=args.decay_type,
                                   decay_start_epoch=args.decay_start_epoch,
                                   decay_rate=args.decay_rate,
                                   decay_patient_n_epochs=args.decay_patient_n_epochs,
                                   lower_better=True,
                                   best_value=ppl_dev_best)

    train_set.epoch = epoch - 1  # start from index:0

    # GPU setting
    # NOTE(review): with args.n_gpus == 0 the model is never wrapped, yet
    # model.module is accessed unconditionally below — confirm CPU-only runs
    # are actually supported by this entry point.
    if args.n_gpus >= 1:
        model = CustomDataParallel(model,
                                   device_ids=list(range(0, args.n_gpus, 1)),
                                   deterministic=False,
                                   benchmark=True)
        model.cuda()

    logger.info('PID: %s' % os.getpid())
    logger.info('USERNAME: %s' % os.uname()[1])

    # Set process name
    if args.job_name:
        setproctitle(args.job_name)
    else:
        setproctitle(dir_name)

    # Set reporter
    reporter = Reporter(model.module.save_path, tensorboard=True)

    hidden = None
    start_time_train = time.time()
    start_time_epoch = time.time()
    start_time_step = time.time()
    not_improved_epoch = 0
    pbar_epoch = tqdm(total=len(train_set))
    while True:
        # Compute loss in the training set
        ys_train, is_new_epoch = train_set.next()

        model.module.optimizer.zero_grad()
        loss, hidden, reporter = model(ys_train, hidden, reporter)
        if len(model.device_ids) > 1:
            # Multi-GPU: backprop a per-device gradient of ones
            loss.backward(torch.ones(len(model.device_ids)))
        else:
            loss.backward()
        loss.detach()  # Trancate the graph
        if args.clip_grad_norm > 0:
            torch.nn.utils.clip_grad_norm_(model.module.parameters(), args.clip_grad_norm)
        model.module.optimizer.step()
        loss_train = loss.item()
        del loss
        if 'gated_conv' not in args.lm_type:
            # Detach recurrent state so BPTT does not span step boundaries
            hidden = model.module.repackage_hidden(hidden)
        reporter.step(is_eval=False)

        if step % args.print_step == 0:
            # Compute loss in the dev set
            ys_dev = dev_set.next()[0]
            loss, _, reporter = model(ys_dev, None, reporter, is_eval=True)
            loss_dev = loss.item()
            del loss
            reporter.step(is_eval=True)

            duration_step = time.time() - start_time_step
            logger.info("step:%d(ep:%.2f) loss:%.3f(%.3f)/ppl:%.3f(%.3f)/lr:%.5f/bs:%d (%.2f min)" % (
                step, train_set.epoch_detail, loss_train, loss_dev,
                np.exp(loss_train), np.exp(loss_dev), lr_controller.lr,
                ys_train.shape[0], duration_step / 60))
            start_time_step = time.time()
        step += args.n_gpus
        pbar_epoch.update(ys_train.shape[0] * (ys_train.shape[1] - 1))

        # Save fugures of loss and accuracy
        if step % (args.print_step * 10) == 0:
            reporter.snapshot()

        # Save checkpoint and evaluate model per epoch
        if is_new_epoch:
            duration_epoch = time.time() - start_time_epoch
            logger.info('========== EPOCH:%d (%.2f min) ==========' % (epoch, duration_epoch / 60))

            if epoch < args.eval_start_epoch:
                # Save the model
                save_checkpoint(model.module, model.module.save_path, lr_controller,
                                epoch, step - 1, ppl_dev_best,
                                remove_old_checkpoints=True)
            else:
                start_time_eval = time.time()
                # dev
                ppl_dev, _ = eval_ppl([model.module], dev_set,
                                      batch_size=1, bptt=args.bptt)
                logger.info('PPL (%s): %.2f' % (dev_set.set, ppl_dev))

                # Update learning rate
                model.module.optimizer = lr_controller.decay(
                    model.module.optimizer, epoch=epoch, value=ppl_dev)

                if ppl_dev < ppl_dev_best:
                    ppl_dev_best = ppl_dev
                    not_improved_epoch = 0
                    logger.info('||||| Best Score |||||')

                    # Save the model
                    save_checkpoint(model.module, model.module.save_path, lr_controller,
                                    epoch, step - 1, ppl_dev_best,
                                    remove_old_checkpoints=True)

                    # test (only when the dev score improved)
                    ppl_test_avg = 0.
                    for eval_set in eval_sets:
                        ppl_test, _ = eval_ppl([model.module], eval_set,
                                               batch_size=1, bptt=args.bptt)
                        logger.info('PPL (%s): %.2f' % (eval_set.set, ppl_test))
                        ppl_test_avg += ppl_test
                    if len(eval_sets) > 0:
                        logger.info('PPL (avg.): %.2f' % (ppl_test_avg / len(eval_sets)))
                else:
                    not_improved_epoch += 1

                duration_eval = time.time() - start_time_eval
                logger.info('Evaluation time: %.2f min' % (duration_eval / 60))

                # Early stopping
                if not_improved_epoch == args.not_improved_patient_n_epochs:
                    break

                # Convert to fine-tuning stage
                if epoch == args.convert_to_sgd_epoch:
                    model.module.set_optimizer('sgd',
                                               learning_rate=args.learning_rate,
                                               weight_decay=float(args.weight_decay))
                    lr_controller = Controller(learning_rate=args.learning_rate,
                                               decay_type='epoch',
                                               decay_start_epoch=epoch,
                                               decay_rate=0.5,
                                               lower_better=True)
                    logger.info('========== Convert to SGD ==========')

            pbar_epoch = tqdm(total=len(train_set))

            if epoch == args.n_epochs:
                break

            start_time_step = time.time()
            start_time_epoch = time.time()
            epoch += 1

    duration_train = time.time() - start_time_train
    logger.info('Total time: %.2f hour' % (duration_train / 3600))

    if reporter.tensorboard:
        reporter.tf_writer.close()
    pbar_epoch.close()

    return model.module.save_path
def main():
    """Entry point: plot CTC posterior probabilities for every evaluation set.

    Loads the trained ASR model referenced by ``args.recog_model[0]``, decodes
    each utterance in each tsv in ``args.recog_sets``, and saves one CTC-prob
    figure per utterance under ``<recog_dir>/ctc_probs/<speaker>/``.
    """
    args = parse()

    # Load a conf file saved next to the checkpoint
    dir_name = os.path.dirname(args.recog_model[0])
    conf = load_config(os.path.join(dir_name, 'conf.yml'))

    # Overwrite conf: restore training-time options, but keep all
    # recognition-time ('recog*') options from the command line
    for k, v in conf.items():
        if 'recog' not in k:
            setattr(args, k, v)
    recog_params = vars(args)

    # Setting for logging (remove a stale log from a previous run first)
    if os.path.isfile(os.path.join(args.recog_dir, 'plot.log')):
        os.remove(os.path.join(args.recog_dir, 'plot.log'))
    logger = set_logger(os.path.join(args.recog_dir, 'plot.log'),
                        key='decoding', stdout=args.recog_stdout)

    for i, s in enumerate(args.recog_sets):
        # Total time-subsampling factor = conv-pooling factor * layer subsampling.
        # NOTE(review): the comprehension variable below shadows the loop
        # variable `s` but does not leak in Python 3, so `tsv_path=s` further
        # down still receives the tsv path.
        subsample_factor = 1
        subsample = [int(s) for s in args.subsample.split('_')]
        if args.conv_poolings:
            for p in args.conv_poolings.split('_'):
                # conv_poolings looks like '(2,2)_(2,2)'; only the time axis
                # (first element) contributes to frame-rate reduction
                p = int(p.split(',')[0].replace('(', ''))
                if p > 1:
                    subsample_factor *= p
        subsample_factor *= np.prod(subsample)

        # Load dataset
        dataset = Dataset(corpus=args.corpus,
                          tsv_path=s,
                          dict_path=os.path.join(dir_name, 'dict.txt'),
                          dict_path_sub1=os.path.join(dir_name, 'dict_sub1.txt') if os.path.isfile(
                              os.path.join(dir_name, 'dict_sub1.txt')) else False,
                          nlsyms=args.nlsyms,
                          wp_model=os.path.join(dir_name, 'wp.model'),
                          unit=args.unit,
                          unit_sub1=args.unit_sub1,
                          batch_size=args.recog_batch_size,
                          is_test=True)

        if i == 0:
            # Load the ASR model only once; it is reused for every recog set
            model = Speech2Text(args, dir_name)
            model = load_checkpoint(model, args.recog_model[0])[0]
            epoch = int(args.recog_model[0].split('-')[-1])

            if not args.recog_unit:
                args.recog_unit = args.unit

            logger.info('recog unit: %s' % args.recog_unit)
            logger.info('epoch: %d' % (epoch - 1))
            logger.info('batch size: %d' % args.recog_batch_size)

            # GPU setting
            model.cuda()

        save_path = mkdir_join(args.recog_dir, 'ctc_probs')

        # Clean directory: drop figures from a previous run, then recreate
        if save_path is not None and os.path.isdir(save_path):
            shutil.rmtree(save_path)
            os.mkdir(save_path)

        while True:
            batch, is_new_epoch = dataset.next(recog_params['recog_batch_size'])
            best_hyps_id, _, _ = model.decode(batch['xs'], recog_params,
                                              exclude_eos=False)

            # Get CTC probs (top-k classes only, to keep the figure readable)
            ctc_probs, indices_topk, xlens = model.get_ctc_probs(
                batch['xs'], temperature=1, topk=min(100, model.vocab))
            # NOTE: ctc_probs: '[B, T, topk]'

            for b in range(len(batch['xs'])):
                tokens = dataset.idx2token[0](best_hyps_id[b], return_list=True)
                spk = batch['speakers'][b]

                plot_ctc_probs(
                    ctc_probs[b, :xlens[b]],
                    indices_topk[b],
                    n_frames=xlens[b],
                    subsample_factor=subsample_factor,
                    spectrogram=batch['xs'][b][:, :dataset.input_dim],
                    save_path=mkdir_join(save_path, spk, batch['utt_ids'][b] + '.png'),
                    figsize=(20, 8))

                hyp = ' '.join(tokens)
                logger.info('utt-id: %s' % batch['utt_ids'][b])
                logger.info('Ref: %s' % batch['text'][b].lower())
                logger.info('Hyp: %s' % hyp)
                logger.info('-' * 50)

            if is_new_epoch:
                break
def eval_char(models, dataloader, params, epoch=-1, rank=0, save_dir=None,
              streaming=False, progressbar=False, task_idx=0,
              edit_distance=True, fine_grained=False, oracle=False,
              teacher_force=False):
    """Evaluate a character-level model by WER & CER.

    Decodes every batch in ``dataloader``, writes ref/hyp trn files (rank 0
    only), and accumulates word- and character-level edit-distance statistics.

    Args:
        models (List): models to evaluate; models[1:] are used as an ensemble
        dataloader (torch.utils.data.DataLoader): evaluation dataloader
        params (omegaconf.dictconfig.DictConfig): decoding hyperparameters
        epoch (int): current epoch (used in the auto-generated save_dir name)
        rank (int): rank of current process group; only rank 0 writes trn files
        save_dir (str): directory path to save hypotheses
        streaming (bool): streaming decoding for session-level evaluation
        progressbar (bool): visualize progressbar
        edit_distance (bool): calculate edit-distance (can be skipped for RTF calculation)
        task_idx (int): index of target task in interest
            0: main task
            1: sub task
            2: sub sub task
        fine_grained (bool): calculate fine-grained WER distributions based on input lengths
        oracle (bool): calculate oracle WER over the n-best hypotheses
        teacher_force (bool): conduct decoding in teacher-forcing mode
    Returns:
        wer (float): Word error rate
        cer (float): Character error rate

    """
    if save_dir is None:
        # Encode the main decoding hyperparameters into the directory name
        save_dir = 'decode_' + dataloader.set + '_ep' + str(epoch) + '_beam' + str(params.get('recog_beam_width'))
        save_dir += '_lp' + str(params.get('recog_length_penalty'))
        save_dir += '_cp' + str(params.get('recog_coverage_penalty'))
        save_dir += '_' + str(params.get('recog_min_len_ratio')) + '_' + str(params.get('recog_max_len_ratio'))
        save_dir += '_lm' + str(params.get('recog_lm_weight'))

        ref_trn_path = mkdir_join(models[0].save_path, save_dir, 'ref.trn', rank=rank)
        hyp_trn_path = mkdir_join(models[0].save_path, save_dir, 'hyp.trn', rank=rank)
    else:
        ref_trn_path = mkdir_join(save_dir, 'ref.trn', rank=rank)
        hyp_trn_path = mkdir_join(save_dir, 'hyp.trn', rank=rank)

    # Edit-distance accumulators (word-level *_w, character-level *_c)
    wer, cer = 0, 0
    n_sub_w, n_ins_w, n_del_w = 0, 0, 0
    n_sub_c, n_ins_c, n_del_c = 0, 0, 0
    n_word, n_char = 0, 0
    cer_dist = {}  # calculate CER distribution based on input lengths
    cer_oracle = 0
    n_oracle_hit = 0
    # Streaming-related statistics
    n_streamable, quantity_rate, n_utt = 0, 0, 0
    last_success_frame_ratio = 0

    # Reset data counter
    dataloader.reset(params.get('recog_batch_size'), 'seq')

    if progressbar:
        pbar = tqdm(total=len(dataloader))

    # Only rank 0 writes trn files in distributed evaluation
    if rank == 0:
        f_hyp = codecs.open(hyp_trn_path, 'w', encoding='utf-8')
        f_ref = codecs.open(ref_trn_path, 'w', encoding='utf-8')

    # Map task index to the key expected by Speech2Text.decode
    if task_idx == 0:
        task = 'ys'
    elif task_idx == 1:
        task = 'ys_sub1'
    elif task_idx == 2:
        task = 'ys_sub2'
    elif task_idx == 3:
        task = 'ys_sub3'

    for batch in dataloader:
        speakers = batch['sessions' if dataloader.corpus == 'swbd' else 'speakers']
        if streaming or params.get('recog_block_sync'):
            # Session-level streaming decoding (one speaker per call)
            nbest_hyps_id = models[0].decode_streaming(
                batch['xs'], params, dataloader.idx2token[0],
                exclude_eos=True,
                speaker=speakers[0])[0]
        else:
            nbest_hyps_id = models[0].decode(
                batch['xs'], params,
                idx2token=dataloader.idx2token[0],
                exclude_eos=True,
                refs_id=batch['ys'] if task_idx == 0 else batch['ys_sub' + str(task_idx)],
                utt_ids=batch['utt_ids'],
                speakers=speakers,
                task=task,
                ensemble_models=models[1:] if len(models) > 1 else [],
                teacher_force=teacher_force)[0]

        for b in range(len(batch['xs'])):
            # assert len(batch['xs']) == 1, 'batch is 1'
            ref = batch['text'][b]
            nbest_hyps_tmp = [dataloader.idx2token[0](hyp_id)
                              for hyp_id in nbest_hyps_id[b]]
            # print(nbest_hyps_id)
            # print(nbest_hyps_tmp)
            # assert False, 'vv'

            # Truncate the first and last spaces for the char_space unit
            nbest_hyps = []
            for hyp in nbest_hyps_tmp:
                if len(hyp) > 0 and hyp[0] == ' ':
                    hyp = hyp[1:]
                if len(hyp) > 0 and hyp[-1] == ' ':
                    hyp = hyp[:-1]
                nbest_hyps.append(hyp)

            # Write to trn ("text (speaker-uttid)" sclite format)
            speaker = str(batch['speakers'][b]).replace('-', '_')
            if streaming:
                utt_id = str(batch['utt_ids'][b]) + '_0000000_0000001'
            else:
                utt_id = str(batch['utt_ids'][b])
            if rank == 0:
                f_ref.write(ref + ' (' + speaker + '-' + utt_id + ')\n')
                f_hyp.write(nbest_hyps[0] + ' (' + speaker + '-' + utt_id + ')\n')
            logger.debug('utt-id (%d/%d): %s' % (n_utt + 1, len(dataloader), utt_id))
            logger.debug('Ref: %s' % ref)
            logger.debug('Hyp: %s' % nbest_hyps[0])
            logger.debug('-' * 150)

            if edit_distance and not streaming:
                # WER is meaningful only when the unit preserves word
                # boundaries (e.g. 'char' but not 'char_nowb')
                if ('char' in dataloader.unit and 'nowb' not in dataloader.unit) or (task_idx > 0 and dataloader.unit_sub1 == 'char'):
                    # Compute WER
                    err_b, sub_b, ins_b, del_b = compute_wer(ref=ref.split(' '),
                                                             hyp=nbest_hyps[0].split(' '))
                    wer += err_b
                    n_sub_w += sub_b
                    n_ins_w += ins_b
                    n_del_w += del_b
                    n_word += len(ref.split(' '))
                    # NOTE: sentence error rate for Chinese

                # Compute CER (always; spaces are removed for CSJ first)
                if dataloader.corpus == 'csj':
                    ref = ref.replace(' ', '')
                    nbest_hyps[0] = nbest_hyps[0].replace(' ', '')
                err_b, sub_b, ins_b, del_b = compute_wer(ref=list(ref),
                                                         hyp=list(nbest_hyps[0]))
                cer += err_b
                n_sub_c += sub_b
                n_ins_c += ins_b
                n_del_c += del_b
                n_char += len(ref)

                # Compute oracle CER: best CER among the n-best hypotheses
                if oracle and len(nbest_hyps) > 1:
                    cers_b = [err_b] + [compute_wer(ref=list(ref), hyp=list(hyp_n))[0]
                                        for hyp_n in nbest_hyps[1:]]
                    oracle_idx = np.argmin(np.array(cers_b))
                    if oracle_idx == 0:
                        # NOTE(review): increments by the batch size per
                        # utterance — looks like it should be += 1; confirm
                        # against upstream before changing
                        n_oracle_hit += len(batch['utt_ids'])
                    cer_oracle += cers_b[oracle_idx]

                if fine_grained:
                    # Bucket by input length in bins of 200 frames
                    xlen_bin = (batch['xlens'][b] // 200 + 1) * 200
                    if xlen_bin in cer_dist.keys():
                        cer_dist[xlen_bin] += [err_b / 100]
                    else:
                        cer_dist[xlen_bin] = [err_b / 100]

        # Per-batch streaming statistics
        if models[0].streamable():
            n_streamable += len(batch['utt_ids'])
        else:
            last_success_frame_ratio += models[0].last_success_frame_ratio()
        quantity_rate += models[0].quantity_rate()

        n_utt += len(batch['utt_ids'])
        if progressbar:
            pbar.update(len(batch['utt_ids']))

    if rank == 0:
        f_hyp.close()
        f_ref.close()
    if progressbar:
        pbar.close()

    # Reset data counters
    dataloader.reset(is_new_epoch=True)

    if edit_distance and not streaming:
        # Normalize accumulated edit-distance counts to percentages
        if ('char' in dataloader.unit and 'nowb' not in dataloader.unit) or (task_idx > 0 and dataloader.unit_sub1 == 'char'):
            wer /= n_word
            n_sub_w /= n_word
            n_ins_w /= n_word
            n_del_w /= n_word
        else:
            wer = n_sub_w = n_ins_w = n_del_w = 0

        cer /= n_char
        n_sub_c /= n_char
        n_ins_c /= n_char
        n_del_c /= n_char

        if n_utt - n_streamable > 0:
            last_success_frame_ratio /= (n_utt - n_streamable)
        n_streamable /= n_utt
        quantity_rate /= n_utt

        if params.get('recog_beam_width') > 1:
            logger.info('WER (%s): %.2f %%' % (dataloader.set, wer))
            logger.info('SUB: %.2f / INS: %.2f / DEL: %.2f' % (n_sub_w, n_ins_w, n_del_w))
            logger.info('CER (%s): %.2f %%' % (dataloader.set, cer))
            logger.info('SUB: %.2f / INS: %.2f / DEL: %.2f' % (n_sub_c, n_ins_c, n_del_c))

        if oracle:
            cer_oracle /= n_char
            oracle_hit_rate = n_oracle_hit * 100 / n_utt
            logger.info('Oracle CER (%s): %.2f %%' % (dataloader.set, cer_oracle))
            logger.info('Oracle hit rate (%s): %.2f %%' % (dataloader.set, oracle_hit_rate))

        if fine_grained:
            for len_bin, cers in sorted(cer_dist.items(), key=lambda x: x[0]):
                logger.info(' CER (%s): %.2f %% (%d)' % (dataloader.set, sum(cers) / len(cers), len_bin))

    logger.info('Streamability (%s): %.2f %%' % (dataloader.set, n_streamable * 100))
    logger.info('Quantity rate (%s): %.2f %%' % (dataloader.set, quantity_rate * 100))
    logger.info('Last success frame ratio (%s): %.2f %%' % (dataloader.set, last_success_frame_ratio))

    return wer, cer
def eval_phone(models, dataset, recog_params, epoch, recog_dir=None, progressbar=False): """Evaluate a phone-level model by PER. Args: models (list): models to evaluate dataset (Dataset): evaluation dataset recog_params (dict): epoch (int): recog_dir (str): progressbar (bool): visualize the progressbar Returns: per (float): Phone error rate """ # Reset data counter dataset.reset() if recog_dir is None: recog_dir = 'decode_' + dataset.set + '_ep' + str( epoch) + '_beam' + str(recog_params['recog_beam_width']) recog_dir += '_lp' + str(recog_params['recog_length_penalty']) recog_dir += '_cp' + str(recog_params['recog_coverage_penalty']) recog_dir += '_' + str( recog_params['recog_min_len_ratio']) + '_' + str( recog_params['recog_max_len_ratio']) ref_trn_save_path = mkdir_join(models[0].save_path, recog_dir, 'ref.trn') hyp_trn_save_path = mkdir_join(models[0].save_path, recog_dir, 'hyp.trn') else: ref_trn_save_path = mkdir_join(recog_dir, 'ref.trn') hyp_trn_save_path = mkdir_join(recog_dir, 'hyp.trn') per = 0 n_sub, n_ins, n_del = 0, 0, 0 n_phone = 0 if progressbar: pbar = tqdm(total=len(dataset)) with open(hyp_trn_save_path, 'w') as f_hyp, open(ref_trn_save_path, 'w') as f_ref: while True: batch, is_new_epoch = dataset.getitem( recog_params['recog_batch_size']) best_hyps_id, _, _ = models[0].decode( batch['xs'], recog_params, dataset.idx2token[0], exclude_eos=True, refs_id=batch['ys'], utt_ids=batch['utt_ids'], speakers=batch['sessions'] if dataset.corpus == 'swbd' else batch['speakers'], ensemble_models=models[1:] if len(models) > 1 else []) for b in range(len(batch['xs'])): ref = batch['text'][b] hyp = dataset.idx2token[0](best_hyps_id[b]) # Write to trn utt_id = str(batch['utt_ids'][b]) speaker = str(batch['speakers'][b]).replace('-', '_') f_ref.write(ref + ' (' + speaker + '-' + utt_id + ')\n') f_hyp.write(hyp + ' (' + speaker + '-' + utt_id + ')\n') logger.info('utt-id: %s' % batch['utt_ids'][b]) logger.info('Ref: %s' % ref) logger.info('Hyp: %s' % hyp) 
logger.info('-' * 150) # Compute PER per_b, sub_b, ins_b, del_b = compute_wer(ref=ref.split(' '), hyp=hyp.split(' '), normalize=False) per += per_b n_sub += sub_b n_ins += ins_b n_del += del_b n_phone += len(ref.split(' ')) if progressbar: pbar.update(1) if is_new_epoch: break if progressbar: pbar.close() # Reset data counters dataset.reset() per /= n_phone n_sub /= n_phone n_ins /= n_phone n_del /= n_phone logger.info('PER (%s): %.2f %%' % (dataset.set, per)) logger.info('SUB: %.2f / INS: %.2f / DEL: %.2f' % (n_sub, n_ins, n_del)) return per
def eval_wordpiece_bleu(models, dataset, recog_params, epoch, recog_dir=None,
                        streaming=False, progressbar=False, fine_grained=False):
    """Evaluate the wordpiece-level model by BLEU.

    Decodes the whole dataset, writes plain ref/hyp files, and returns the
    corpus-level 4-gram BLEU (as a percentage).

    Args:
        models (list): models to evaluate; models[1:] are used as an ensemble
        dataset (Dataset): evaluation dataset
        recog_params (dict): decoding hyperparameters
        epoch (int): current epoch (used in the auto-generated recog_dir name)
        recog_dir (str): directory to save ref/hyp files; auto-generated when None
        streaming (bool): streaming decoding for the session-level evaluation
        progressbar (bool): visualize the progressbar
        fine_grained (bool): calculate fine-grained BLEU distributions based on input lengths
    Returns:
        bleu (float): 4-gram BLEU

    """
    if recog_dir is None:
        # Encode the main decoding hyperparameters into the directory name
        recog_dir = 'decode_' + dataset.set + '_ep' + str(epoch) + '_beam' + str(recog_params['recog_beam_width'])
        recog_dir += '_lp' + str(recog_params['recog_length_penalty'])
        recog_dir += '_cp' + str(recog_params['recog_coverage_penalty'])
        recog_dir += '_' + str(recog_params['recog_min_len_ratio']) + '_' + str(recog_params['recog_max_len_ratio'])
        recog_dir += '_lm' + str(recog_params['recog_lm_weight'])

        ref_trn_path = mkdir_join(models[0].save_path, recog_dir, 'ref.trn')
        hyp_trn_path = mkdir_join(models[0].save_path, recog_dir, 'hyp.trn')
    else:
        ref_trn_path = mkdir_join(recog_dir, 'ref.trn')
        hyp_trn_path = mkdir_join(recog_dir, 'hyp.trn')

    s_bleu = 0
    n_sentence = 0
    s_bleu_dist = {}  # calculate sentence-level BLEU distribution based on input lengths

    # Reset data counter
    dataset.reset(recog_params['recog_batch_size'])

    if progressbar:
        pbar = tqdm(total=len(dataset))

    # Accumulated in nltk's corpus_bleu format: list of reference lists
    list_of_references = []
    hypotheses = []

    with codecs.open(hyp_trn_path, 'w', encoding='utf-8') as f_hyp, \
            codecs.open(ref_trn_path, 'w', encoding='utf-8') as f_ref:
        while True:
            batch, is_new_epoch = dataset.next(recog_params['recog_batch_size'])
            if streaming or recog_params['recog_chunk_sync']:
                best_hyps_id, _ = models[0].decode_streaming(
                    batch['xs'], recog_params, dataset.idx2token[0],
                    exclude_eos=True)
            else:
                best_hyps_id, _ = models[0].decode(
                    batch['xs'], recog_params,
                    idx2token=dataset.idx2token[0] if progressbar else None,
                    exclude_eos=True,
                    refs_id=batch['ys'],
                    utt_ids=batch['utt_ids'],
                    speakers=batch['sessions' if dataset.corpus == 'swbd' else 'speakers'],
                    ensemble_models=models[1:] if len(models) > 1 else [])

            for b in range(len(batch['xs'])):
                ref = batch['text'][b]
                # Strip a leading tag token such as '<en>' from the reference
                if ref[0] == '<':
                    ref = ref.split('>')[1]
                hyp = dataset.idx2token[0](best_hyps_id[b])

                # Write to trn
                # speaker = str(batch['speakers'][b]).replace('-', '_')
                if streaming:
                    utt_id = str(batch['utt_ids'][b]) + '_0000000_0000001'
                else:
                    utt_id = str(batch['utt_ids'][b])
                f_ref.write(ref + '\n')
                f_hyp.write(hyp + '\n')
                logger.debug('utt-id: %s' % utt_id)
                logger.debug('Ref: %s' % ref)
                logger.debug('Hyp: %s' % hyp)
                logger.debug('-' * 150)

                if not streaming:
                    list_of_references += [[ref.split(' ')]]
                    hypotheses += [hyp.split(' ')]
                    n_sentence += 1

                    # Compute sentence-level BLEU
                    if fine_grained:
                        s_bleu_b = sentence_bleu([ref.split(' ')], hyp.split(' '))
                        s_bleu += s_bleu_b * 100
                        # Bucket by input length in bins of 200 frames
                        xlen_bin = (batch['xlens'][b] // 200 + 1) * 200
                        if xlen_bin in s_bleu_dist.keys():
                            s_bleu_dist[xlen_bin] += [s_bleu_b / 100]
                        else:
                            s_bleu_dist[xlen_bin] = [s_bleu_b / 100]

                if progressbar:
                    pbar.update(1)

            if is_new_epoch:
                break

    if progressbar:
        pbar.close()

    # Reset data counters
    dataset.reset()

    c_bleu = corpus_bleu(list_of_references, hypotheses) * 100

    if not streaming and fine_grained:
        s_bleu /= n_sentence
        for len_bin, s_bleus in sorted(s_bleu_dist.items(), key=lambda x: x[0]):
            logger.info(' sentence-level BLEU (%s): %.2f %% (%d)' % (dataset.set, sum(s_bleus) / len(s_bleus), len_bin))

    logger.debug('Corpus-level BLEU (%s): %.2f %%' % (dataset.set, c_bleu))

    return c_bleu
def main():
    """Entry point: train a Speech2Text (ASR) model end-to-end.

    Handles dataset construction, (optional) resume / pre-trained
    initialization / knowledge-distillation teachers, apex mixed precision,
    multi-GPU wrapping, the epoch/step training loop with gradient
    accumulation, per-epoch evaluation, checkpointing, and early stopping.
    Returns the model save path.
    """
    args = parse_args_train(sys.argv[1:])
    # Keep pristine copies: conf files of init/teacher models are merged into
    # these separately from the live `args`
    args_init = copy.deepcopy(args)
    args_teacher = copy.deepcopy(args)

    # Load a conf file (restore all training options except 'resume' itself)
    if args.resume:
        conf = load_config(os.path.join(os.path.dirname(args.resume), 'conf.yml'))
        for k, v in conf.items():
            if k != 'resume':
                setattr(args, k, v)
    recog_params = vars(args)

    # NOTE(review): project helper name contains a typo ('susampling');
    # it is defined elsewhere with that spelling, so it is kept as-is
    args = compute_susampling_factor(args)

    # Load dataset
    batch_size = args.batch_size * args.n_gpus if args.n_gpus >= 1 else args.batch_size
    train_set = Dataset(corpus=args.corpus,
                        tsv_path=args.train_set,
                        tsv_path_sub1=args.train_set_sub1,
                        tsv_path_sub2=args.train_set_sub2,
                        dict_path=args.dict,
                        dict_path_sub1=args.dict_sub1,
                        dict_path_sub2=args.dict_sub2,
                        nlsyms=args.nlsyms,
                        unit=args.unit,
                        unit_sub1=args.unit_sub1,
                        unit_sub2=args.unit_sub2,
                        wp_model=args.wp_model,
                        wp_model_sub1=args.wp_model_sub1,
                        wp_model_sub2=args.wp_model_sub2,
                        batch_size=batch_size,
                        n_epochs=args.n_epochs,
                        min_n_frames=args.min_n_frames,
                        max_n_frames=args.max_n_frames,
                        shuffle_bucket=args.shuffle_bucket,
                        sort_by='input',
                        short2long=args.sort_short2long,
                        sort_stop_epoch=args.sort_stop_epoch,
                        dynamic_batching=args.dynamic_batching,
                        ctc=args.ctc_weight > 0,
                        ctc_sub1=args.ctc_weight_sub1 > 0,
                        ctc_sub2=args.ctc_weight_sub2 > 0,
                        subsample_factor=args.subsample_factor,
                        subsample_factor_sub1=args.subsample_factor_sub1,
                        subsample_factor_sub2=args.subsample_factor_sub2,
                        discourse_aware=args.discourse_aware)
    dev_set = Dataset(corpus=args.corpus,
                      tsv_path=args.dev_set,
                      tsv_path_sub1=args.dev_set_sub1,
                      tsv_path_sub2=args.dev_set_sub2,
                      dict_path=args.dict,
                      dict_path_sub1=args.dict_sub1,
                      dict_path_sub2=args.dict_sub2,
                      nlsyms=args.nlsyms,
                      unit=args.unit,
                      unit_sub1=args.unit_sub1,
                      unit_sub2=args.unit_sub2,
                      wp_model=args.wp_model,
                      wp_model_sub1=args.wp_model_sub1,
                      wp_model_sub2=args.wp_model_sub2,
                      batch_size=batch_size,
                      min_n_frames=args.min_n_frames,
                      max_n_frames=args.max_n_frames,
                      ctc=args.ctc_weight > 0,
                      ctc_sub1=args.ctc_weight_sub1 > 0,
                      ctc_sub2=args.ctc_weight_sub2 > 0,
                      subsample_factor=args.subsample_factor,
                      subsample_factor_sub1=args.subsample_factor_sub1,
                      subsample_factor_sub2=args.subsample_factor_sub2)
    eval_sets = [Dataset(corpus=args.corpus,
                         tsv_path=s,
                         dict_path=args.dict,
                         nlsyms=args.nlsyms,
                         unit=args.unit,
                         wp_model=args.wp_model,
                         batch_size=1,
                         is_test=True) for s in args.eval_sets]

    args.vocab = train_set.vocab
    args.vocab_sub1 = train_set.vocab_sub1
    args.vocab_sub2 = train_set.vocab_sub2
    args.input_dim = train_set.input_dim

    # Set save path
    if args.resume:
        save_path = os.path.dirname(args.resume)
        dir_name = os.path.basename(save_path)
    else:
        dir_name = set_asr_model_name(args)
        if args.mbr_training:
            assert args.asr_init
            save_path = mkdir_join(os.path.dirname(args.asr_init), dir_name)
        else:
            save_path = mkdir_join(args.model_save_dir, '_'.join(
                os.path.basename(args.train_set).split('.')[:-1]), dir_name)
        save_path = set_save_path(save_path)  # avoid overwriting

    # Set logger
    set_logger(os.path.join(save_path, 'train.log'), stdout=args.stdout)

    # Load a LM conf file for LM fusion & LM initialization
    if not args.resume and args.external_lm:
        lm_conf = load_config(os.path.join(os.path.dirname(args.external_lm), 'conf.yml'))
        args.lm_conf = argparse.Namespace()
        for k, v in lm_conf.items():
            setattr(args.lm_conf, k, v)
        # The external LM must share the tokenization with the ASR model
        assert args.unit == args.lm_conf.unit
        assert args.vocab == args.lm_conf.vocab

    # Model setting
    model = Speech2Text(args, save_path, train_set.idx2token[0])

    if not args.resume:
        # Save the conf file as a yaml file
        save_config(vars(args), os.path.join(save_path, 'conf.yml'))
        if args.external_lm:
            save_config(args.lm_conf, os.path.join(save_path, 'conf_lm.yml'))

        # Save the nlsyms, dictionary, and wp_model so the run is
        # self-contained for later decoding
        if args.nlsyms:
            shutil.copy(args.nlsyms, os.path.join(save_path, 'nlsyms.txt'))
        for sub in ['', '_sub1', '_sub2']:
            if getattr(args, 'dict' + sub):
                shutil.copy(getattr(args, 'dict' + sub),
                            os.path.join(save_path, 'dict' + sub + '.txt'))
            if getattr(args, 'unit' + sub) == 'wp':
                shutil.copy(getattr(args, 'wp_model' + sub),
                            os.path.join(save_path, 'wp' + sub + '.model'))

        for k, v in sorted(vars(args).items(), key=lambda x: x[0]):
            logger.info('%s: %s' % (k, str(v)))

        # Count total parameters
        for n in sorted(list(model.num_params_dict.keys())):
            n_params = model.num_params_dict[n]
            logger.info("%s %d" % (n, n_params))
        logger.info("Total %.2f M parameters" % (model.total_parameters / 1000000))
        logger.info(model)

        # Initialize with pre-trained model's parameters
        if args.asr_init:
            # Load the ASR model (full model)
            conf_init = load_config(os.path.join(os.path.dirname(args.asr_init), 'conf.yml'))
            for k, v in conf_init.items():
                setattr(args_init, k, v)
            model_init = Speech2Text(args_init)
            load_checkpoint(args.asr_init, model_init)

            # Overwrite parameters (only where name and shape match)
            param_dict = dict(model_init.named_parameters())
            for n, p in model.named_parameters():
                if n in param_dict.keys() and p.size() == param_dict[n].size():
                    if args.asr_init_enc_only and 'enc' not in n:
                        continue
                    p.data = param_dict[n].data
                    logger.info('Overwrite %s' % n)

    # Set optimizer
    resume_epoch = 0
    if args.resume:
        resume_epoch = int(args.resume.split('-')[-1])
        # If we already passed convert_to_sgd_epoch, resume directly with SGD
        optimizer = set_optimizer(model, 'sgd' if resume_epoch > args.convert_to_sgd_epoch else args.optimizer,
                                  args.lr, args.weight_decay)
    else:
        optimizer = set_optimizer(model, args.optimizer, args.lr, args.weight_decay)

    # Wrap optimizer by learning rate scheduler
    is_transformer = 'former' in args.enc_type or 'former' in args.dec_type
    optimizer = LRScheduler(optimizer, args.lr,
                            decay_type=args.lr_decay_type,
                            decay_start_epoch=args.lr_decay_start_epoch,
                            decay_rate=args.lr_decay_rate,
                            decay_patient_n_epochs=args.lr_decay_patient_n_epochs,
                            early_stop_patient_n_epochs=args.early_stop_patient_n_epochs,
                            lower_better=args.metric not in ['accuracy', 'bleu'],
                            warmup_start_lr=args.warmup_start_lr,
                            warmup_n_steps=args.warmup_n_steps,
                            model_size=getattr(args, 'transformer_d_model', 0),
                            factor=args.lr_factor,
                            noam=args.optimizer == 'noam',
                            save_checkpoints_topk=10 if is_transformer else 1)

    if args.resume:
        # Restore the last saved model
        load_checkpoint(args.resume, model, optimizer)

        # Resume between convert_to_sgd_epoch -1 and convert_to_sgd_epoch
        if resume_epoch == args.convert_to_sgd_epoch:
            optimizer.convert_to_sgd(model, args.lr, args.weight_decay,
                                     decay_type='always', decay_rate=0.5)

    # Load the teacher ASR model (for knowledge distillation)
    teacher = None
    if args.teacher:
        assert os.path.isfile(args.teacher), 'There is no checkpoint.'
        conf_teacher = load_config(os.path.join(os.path.dirname(args.teacher), 'conf.yml'))
        for k, v in conf_teacher.items():
            setattr(args_teacher, k, v)
        # Setting for knowledge distillation
        args_teacher.ss_prob = 0
        args.lsm_prob = 0
        teacher = Speech2Text(args_teacher)
        load_checkpoint(args.teacher, teacher)

    # Load the teacher LM
    teacher_lm = None
    if args.teacher_lm:
        assert os.path.isfile(args.teacher_lm), 'There is no checkpoint.'
        conf_lm = load_config(os.path.join(os.path.dirname(args.teacher_lm), 'conf.yml'))
        args_lm = argparse.Namespace()
        for k, v in conf_lm.items():
            setattr(args_lm, k, v)
        teacher_lm = build_lm(args_lm)
        load_checkpoint(args.teacher_lm, teacher_lm)

    # GPU setting
    use_apex = args.train_dtype in ["O0", "O1", "O2", "O3"]
    amp = None
    if args.n_gpus >= 1:
        model.cudnn_setting(deterministic=not (is_transformer or args.cudnn_benchmark),
                            benchmark=not is_transformer and args.cudnn_benchmark)
        model.cuda()

        # Mix precision training setting (NVIDIA apex)
        if use_apex:
            from apex import amp
            model, optimizer.optimizer = amp.initialize(model, optimizer.optimizer,
                                                        opt_level=args.train_dtype)
            from neural_sp.models.seq2seq.decoders.ctc import CTC
            # Keep the CTC loss in fp32 under AMP
            amp.register_float_function(CTC, "loss_fn")
            # NOTE: see https://github.com/espnet/espnet/pull/1779
            amp.init()
            if args.resume:
                # Restore AMP state alongside model/optimizer
                load_checkpoint(args.resume, amp=amp)
        model = CustomDataParallel(model, device_ids=list(range(0, args.n_gpus)))

        if teacher is not None:
            teacher.cuda()
        if teacher_lm is not None:
            teacher_lm.cuda()
    else:
        # CPU fallback wrapper exposing the same .module interface
        model = CPUWrapperASR(model)

    # Set process name
    logger.info('PID: %s' % os.getpid())
    logger.info('USERNAME: %s' % os.uname()[1])
    logger.info('#GPU: %d' % torch.cuda.device_count())
    setproctitle(args.job_name if args.job_name else dir_name)

    # Set reporter
    reporter = Reporter(save_path)

    if args.mtl_per_batch:
        # NOTE: from easier to harder tasks
        tasks = []
        if 1 - args.bwd_weight - args.ctc_weight - args.sub1_weight - args.sub2_weight > 0:
            tasks += ['ys']
        if args.bwd_weight > 0:
            tasks = ['ys.bwd'] + tasks
        if args.ctc_weight > 0:
            tasks = ['ys.ctc'] + tasks
        if args.mbr_ce_weight > 0:
            tasks = ['ys.mbr'] + tasks
        for sub in ['sub1', 'sub2']:
            if getattr(args, 'train_set_' + sub):
                if getattr(args, sub + '_weight') - getattr(args, 'ctc_weight_' + sub) > 0:
                    tasks = ['ys_' + sub] + tasks
                if getattr(args, 'ctc_weight_' + sub) > 0:
                    tasks = ['ys_' + sub + '.ctc'] + tasks
    else:
        tasks = ['all']

    start_time_train = time.time()
    start_time_epoch = time.time()
    start_time_step = time.time()
    accum_n_steps = 0
    # Rescale the scheduler's step counter by gradient-accumulation steps
    n_steps = optimizer.n_steps * args.accum_grad_n_steps
    epoch_detail_prev = 0
    for ep in range(resume_epoch, args.n_epochs):
        pbar_epoch = tqdm(total=len(train_set))
        session_prev = None

        for batch_train, is_new_epoch in train_set:
            # Compute loss in the training set
            if args.discourse_aware and batch_train['sessions'][0] != session_prev:
                # New dialogue session: reset the discourse-aware state
                model.module.reset_session()
            session_prev = batch_train['sessions'][0]
            accum_n_steps += 1

            # Change mini-batch depending on task
            if accum_n_steps == 1:
                loss_train = 0  # moving average over gradient accumulation
            for task in tasks:
                loss, observation = model(batch_train, task,
                                          teacher=teacher, teacher_lm=teacher_lm)
                reporter.add(observation)
                if use_apex:
                    with amp.scale_loss(loss, optimizer.optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()
                loss.detach()  # Truncate the graph
                loss_train = (loss_train * (accum_n_steps - 1) + loss.item()) / accum_n_steps

                if accum_n_steps >= args.accum_grad_n_steps or is_new_epoch:
                    if args.clip_grad_norm > 0:
                        total_norm = torch.nn.utils.clip_grad_norm_(
                            model.module.parameters(), args.clip_grad_norm)
                        reporter.add_tensorboard_scalar('total_norm', total_norm)
                    optimizer.step()
                    optimizer.zero_grad()
                    accum_n_steps = 0
                    # NOTE: parameters are forcibly updated at the end of every epoch
                del loss

            pbar_epoch.update(len(batch_train['utt_ids']))
            reporter.add_tensorboard_scalar('learning_rate', optimizer.lr)
            # NOTE: loss/acc/ppl are already added in the model
            reporter.step()
            n_steps += 1
            # NOTE: n_steps is different from the step counter in Noam Optimizer

            if n_steps % args.print_step == 0:
                # Compute loss in the dev set
                batch_dev = iter(dev_set).next(batch_size=1 if 'transducer' in args.dec_type else None)[0]
                # Change mini-batch depending on task
                for task in tasks:
                    loss, observation = model(batch_dev, task, is_eval=True)
                    reporter.add(observation, is_eval=True)
                    loss_dev = loss.item()
                    del loss
                reporter.step(is_eval=True)

                duration_step = time.time() - start_time_step
                if args.input_type == 'speech':
                    xlen = max(len(x) for x in batch_train['xs'])
                    ylen = max(len(y) for y in batch_train['ys'])
                elif args.input_type == 'text':
                    xlen = max(len(x) for x in batch_train['ys'])
                    ylen = max(len(y) for y in batch_train['ys_sub1'])
                logger.info("step:%d(ep:%.2f) loss:%.3f(%.3f)/lr:%.7f/bs:%d/xlen:%d/ylen:%d (%.2f min)" %
                            (n_steps, optimizer.n_epochs + train_set.epoch_detail,
                             loss_train, loss_dev, optimizer.lr, len(batch_train['utt_ids']),
                             xlen, ylen, duration_step / 60))
                start_time_step = time.time()

            # Save figures of loss and accuracy
            if n_steps % (args.print_step * 10) == 0:
                reporter.snapshot()
                model.module.plot_attention()
                model.module.plot_ctc()

            # Evaluate model every 0.1 epoch during MBR training
            if args.mbr_training:
                if int(train_set.epoch_detail * 10) != int(epoch_detail_prev * 10):
                    # dev
                    evaluate([model.module], dev_set, recog_params, args,
                             int(train_set.epoch_detail * 10) / 10, logger)
                    # Save the model
                    optimizer.save_checkpoint(
                        model, save_path, remove_old=False, amp=amp,
                        epoch_detail=train_set.epoch_detail)
                epoch_detail_prev = train_set.epoch_detail

            if is_new_epoch:
                break

        # Save checkpoint and evaluate model per epoch
        duration_epoch = time.time() - start_time_epoch
        logger.info('========== EPOCH:%d (%.2f min) ==========' %
                    (optimizer.n_epochs + 1, duration_epoch / 60))

        if optimizer.n_epochs + 1 < args.eval_start_epoch:
            optimizer.epoch()  # lr decay
            reporter.epoch()  # plot

            # Save the model
            optimizer.save_checkpoint(
                model, save_path, remove_old=not is_transformer and args.remove_old_checkpoints, amp=amp)
        else:
            start_time_eval = time.time()
            # dev
            metric_dev = evaluate([model.module], dev_set, recog_params,
                                  args, optimizer.n_epochs + 1, logger)
            optimizer.epoch(metric_dev)  # lr decay
            reporter.epoch(metric_dev, name=args.metric)  # plot

            if optimizer.is_topk or is_transformer:
                # Save the model
                optimizer.save_checkpoint(
                    model, save_path, remove_old=not is_transformer and args.remove_old_checkpoints, amp=amp)

                # test (only when the dev score ranks top-k)
                if optimizer.is_topk:
                    for eval_set in eval_sets:
                        evaluate([model.module], eval_set, recog_params,
                                 args, optimizer.n_epochs, logger)

            duration_eval = time.time() - start_time_eval
            logger.info('Evaluation time: %.2f min' % (duration_eval / 60))

            # Early stopping
            if optimizer.is_early_stop:
                break

            # Convert to fine-tuning stage
            if optimizer.n_epochs == args.convert_to_sgd_epoch:
                optimizer.convert_to_sgd(model, args.lr, args.weight_decay,
                                         decay_type='always', decay_rate=0.5)

        if optimizer.n_epochs >= args.n_epochs:
            break
        # if args.ss_prob > 0:
        #     model.module.scheduled_sampling_trigger()

        start_time_step = time.time()
        start_time_epoch = time.time()

    duration_train = time.time() - start_time_train
    logger.info('Total time: %.2f hour' % (duration_train / 3600))

    reporter.tf_writer.close()
    pbar_epoch.close()

    return save_path
def eval_char(models, dataloader, recog_params, epoch, recog_dir=None,
              streaming=False, progressbar=False, task_idx=0):
    """Evaluate the character-level model by WER & CER.

    Decodes every batch, writes ref/hyp trn files, and accumulates word- and
    character-level edit-distance statistics plus streaming metrics.

    Args:
        models (list): models to evaluate; models[1:] are used as an ensemble
        dataloader (torch.utils.data.DataLoader): evaluation dataloader
        recog_params (dict): decoding hyperparameters
        epoch (int): current epoch (used in the auto-generated recog_dir name)
        recog_dir (str): directory to save trn files; auto-generated when None
        streaming (bool): streaming decoding for the session-level evaluation
        progressbar (bool): visualize the progressbar
        task_idx (int): the index of the target task in interest
            0: main task
            1: sub task
            2: sub sub task
    Returns:
        wer (float): Word error rate
        cer (float): Character error rate

    """
    if recog_dir is None:
        # Encode the main decoding hyperparameters into the directory name
        recog_dir = 'decode_' + dataloader.set + '_ep' + str(epoch) + '_beam' + str(recog_params['recog_beam_width'])
        recog_dir += '_lp' + str(recog_params['recog_length_penalty'])
        recog_dir += '_cp' + str(recog_params['recog_coverage_penalty'])
        recog_dir += '_' + str(recog_params['recog_min_len_ratio']) + '_' + str(recog_params['recog_max_len_ratio'])
        recog_dir += '_lm' + str(recog_params['recog_lm_weight'])

        ref_trn_path = mkdir_join(models[0].save_path, recog_dir, 'ref.trn')
        hyp_trn_path = mkdir_join(models[0].save_path, recog_dir, 'hyp.trn')
    else:
        ref_trn_path = mkdir_join(recog_dir, 'ref.trn')
        hyp_trn_path = mkdir_join(recog_dir, 'hyp.trn')

    # Edit-distance accumulators (word-level *_w, character-level *_c)
    wer, cer = 0, 0
    n_sub_w, n_ins_w, n_del_w = 0, 0, 0
    n_sub_c, n_ins_c, n_del_c = 0, 0, 0
    n_word, n_char = 0, 0
    # Streaming-related statistics
    n_streamable, quantity_rate, n_utt = 0, 0, 0
    last_success_frame_ratio = 0

    # Reset data counter
    dataloader.reset(recog_params['recog_batch_size'])

    if progressbar:
        pbar = tqdm(total=len(dataloader))

    # Map task index to the key expected by Speech2Text.decode
    if task_idx == 0:
        task = 'ys'
    elif task_idx == 1:
        task = 'ys_sub1'
    elif task_idx == 2:
        task = 'ys_sub2'
    elif task_idx == 3:
        task = 'ys_sub3'

    with codecs.open(hyp_trn_path, 'w', encoding='utf-8') as f_hyp, \
            codecs.open(ref_trn_path, 'w', encoding='utf-8') as f_ref:
        while True:
            batch, is_new_epoch = dataloader.next(recog_params['recog_batch_size'])
            if streaming or recog_params['recog_chunk_sync']:
                best_hyps_id, _ = models[0].decode_streaming(
                    batch['xs'], recog_params, dataloader.idx2token[0],
                    exclude_eos=True)
            else:
                best_hyps_id, _ = models[0].decode(
                    batch['xs'], recog_params,
                    idx2token=dataloader.idx2token[task_idx] if progressbar else None,
                    exclude_eos=True,
                    refs_id=batch['ys'] if task_idx == 0 else batch['ys_sub' + str(task_idx)],
                    utt_ids=batch['utt_ids'],
                    speakers=batch['sessions' if dataloader.corpus == 'swbd' else 'speakers'],
                    task=task,
                    ensemble_models=models[1:] if len(models) > 1 else [])

            for b in range(len(batch['xs'])):
                ref = batch['text'][b]
                hyp = dataloader.idx2token[task_idx](best_hyps_id[b])

                # Truncate the first and last spaces for the char_space unit
                if len(hyp) > 0 and hyp[0] == ' ':
                    hyp = hyp[1:]
                if len(hyp) > 0 and hyp[-1] == ' ':
                    hyp = hyp[:-1]

                # Write to trn ("text (speaker-uttid)" sclite format)
                speaker = str(batch['speakers'][b]).replace('-', '_')
                if streaming:
                    utt_id = str(batch['utt_ids'][b]) + '_0000000_0000001'
                else:
                    utt_id = str(batch['utt_ids'][b])
                f_ref.write(ref + ' (' + speaker + '-' + utt_id + ')\n')
                f_hyp.write(hyp + ' (' + speaker + '-' + utt_id + ')\n')
                logger.debug('utt-id: %s' % utt_id)
                logger.debug('Ref: %s' % ref)
                logger.debug('Hyp: %s' % hyp)
                logger.debug('-' * 150)

                if not streaming:
                    # WER is meaningful only when the unit preserves word
                    # boundaries (e.g. 'char' but not 'char_nowb')
                    if ('char' in dataloader.unit and 'nowb' not in dataloader.unit) or (task_idx > 0 and dataloader.unit_sub1 == 'char'):
                        # Compute WER
                        wer_b, sub_b, ins_b, del_b = compute_wer(ref=ref.split(' '),
                                                                 hyp=hyp.split(' '),
                                                                 normalize=False)
                        wer += wer_b
                        n_sub_w += sub_b
                        n_ins_w += ins_b
                        n_del_w += del_b
                        n_word += len(ref.split(' '))
                        # NOTE: sentence error rate for Chinese

                    # Compute CER (spaces removed for CSJ first)
                    if dataloader.corpus == 'csj':
                        ref = ref.replace(' ', '')
                        hyp = hyp.replace(' ', '')
                    cer_b, sub_b, ins_b, del_b = compute_wer(ref=list(ref),
                                                             hyp=list(hyp),
                                                             normalize=False)
                    cer += cer_b
                    n_sub_c += sub_b
                    n_ins_c += ins_b
                    n_del_c += del_b
                    n_char += len(ref)

                # Per-utterance streaming statistics
                if models[0].streamable():
                    n_streamable += 1
                else:
                    last_success_frame_ratio += models[0].last_success_frame_ratio()
                quantity_rate += models[0].quantity_rate()

                n_utt += 1
                if progressbar:
                    pbar.update(1)

            if is_new_epoch:
                break

    if progressbar:
        pbar.close()

    # Reset data counters
    dataloader.reset()

    if not streaming:
        # Normalize accumulated counts to percentages
        if ('char' in dataloader.unit and 'nowb' not in dataloader.unit) or (task_idx > 0 and dataloader.unit_sub1 == 'char'):
            wer /= n_word
            n_sub_w /= n_word
            n_ins_w /= n_word
            n_del_w /= n_word
        else:
            wer = n_sub_w = n_ins_w = n_del_w = 0

        cer /= n_char
        n_sub_c /= n_char
        n_ins_c /= n_char
        n_del_c /= n_char

        if n_utt - n_streamable > 0:
            last_success_frame_ratio /= (n_utt - n_streamable)
        n_streamable /= n_utt
        quantity_rate /= n_utt

        logger.debug('WER (%s): %.2f %%' % (dataloader.set, wer))
        logger.debug('SUB: %.2f / INS: %.2f / DEL: %.2f' % (n_sub_w, n_ins_w, n_del_w))
        logger.debug('CER (%s): %.2f %%' % (dataloader.set, cer))
        logger.debug('SUB: %.2f / INS: %.2f / DEL: %.2f' % (n_sub_c, n_ins_c, n_del_c))

    logger.info('Streamability (%s): %.2f %%' % (dataloader.set, n_streamable * 100))
    logger.info('Quantity rate (%s): %.2f %%' % (dataloader.set, quantity_rate * 100))
    logger.info('Last success frame ratio (%s): %.2f %%' % (dataloader.set, last_success_frame_ratio))

    return wer, cer
def main():
    """Train an end-to-end ASR model (Speech2Text / SkipThought).

    Reads CLI/config arguments, builds train/dev/eval datasets, constructs the
    model, optimizer and LR scheduler (optionally resuming from a checkpoint),
    then runs the epoch-based training loop with periodic dev evaluation,
    checkpointing and early stopping. Returns the model save path.
    """
    args = parse()
    # Pristine copies: `args` is mutated below (resume/teacher configs),
    # but model initialization from other checkpoints needs the originals.
    args_init = copy.deepcopy(args)
    args_teacher = copy.deepcopy(args)

    # Load a conf file
    if args.resume:
        conf = load_config(
            os.path.join(os.path.dirname(args.resume), 'conf.yml'))
        for k, v in conf.items():
            if k != 'resume':
                setattr(args, k, v)
    recog_params = vars(args)

    # Compute subsampling factor
    subsample_factor = 1
    subsample_factor_sub1 = 1
    subsample_factor_sub2 = 1
    subsample = [int(s) for s in args.subsample.split('_')]
    if args.conv_poolings and 'conv' in args.enc_type:
        # e.g. "(2,2)_(2,2)": take the time-axis pooling of each conv layer
        for p in args.conv_poolings.split('_'):
            subsample_factor *= int(p.split(',')[0].replace('(', ''))
    else:
        subsample_factor = np.prod(subsample)
    if args.train_set_sub1:
        if args.conv_poolings and 'conv' in args.enc_type:
            subsample_factor_sub1 = subsample_factor * np.prod(
                subsample[:args.enc_n_layers_sub1 - 1])
        else:
            subsample_factor_sub1 = subsample_factor
    if args.train_set_sub2:
        if args.conv_poolings and 'conv' in args.enc_type:
            subsample_factor_sub2 = subsample_factor * np.prod(
                subsample[:args.enc_n_layers_sub2 - 1])
        else:
            subsample_factor_sub2 = subsample_factor

    # Set save path
    if args.resume:
        save_path = os.path.dirname(args.resume)
        dir_name = os.path.basename(save_path)
    else:
        dir_name = set_asr_model_name(args, subsample_factor)
        save_path = mkdir_join(
            args.model_save_dir,
            '_'.join(os.path.basename(args.train_set).split('.')[:-1]),
            dir_name)
        save_path = set_save_path(save_path)  # avoid overwriting

    # Set logger
    logger = set_logger(os.path.join(save_path, 'train.log'),
                        key='training', stdout=args.stdout)

    # for multi-GPUs
    if args.n_gpus > 1:
        logger.info("Batch size is automatically reduced from %d to %d" %
                    (args.batch_size, args.batch_size // 2))
        args.batch_size //= 2

    skip_thought = 'skip' in args.enc_type

    # Load dataset
    train_set = Dataset(corpus=args.corpus,
                        tsv_path=args.train_set,
                        tsv_path_sub1=args.train_set_sub1,
                        tsv_path_sub2=args.train_set_sub2,
                        dict_path=args.dict,
                        dict_path_sub1=args.dict_sub1,
                        dict_path_sub2=args.dict_sub2,
                        nlsyms=args.nlsyms,
                        unit=args.unit,
                        unit_sub1=args.unit_sub1,
                        unit_sub2=args.unit_sub2,
                        wp_model=args.wp_model,
                        wp_model_sub1=args.wp_model_sub1,
                        wp_model_sub2=args.wp_model_sub2,
                        batch_size=args.batch_size * args.n_gpus,
                        n_epochs=args.n_epochs,
                        min_n_frames=args.min_n_frames,
                        max_n_frames=args.max_n_frames,
                        sort_by='input',
                        short2long=True,
                        sort_stop_epoch=args.sort_stop_epoch,
                        dynamic_batching=args.dynamic_batching,
                        ctc=args.ctc_weight > 0,
                        ctc_sub1=args.ctc_weight_sub1 > 0,
                        ctc_sub2=args.ctc_weight_sub2 > 0,
                        subsample_factor=subsample_factor,
                        subsample_factor_sub1=subsample_factor_sub1,
                        subsample_factor_sub2=subsample_factor_sub2,
                        discourse_aware=args.discourse_aware,
                        skip_thought=skip_thought)
    dev_set = Dataset(corpus=args.corpus,
                      tsv_path=args.dev_set,
                      tsv_path_sub1=args.dev_set_sub1,
                      tsv_path_sub2=args.dev_set_sub2,
                      dict_path=args.dict,
                      dict_path_sub1=args.dict_sub1,
                      dict_path_sub2=args.dict_sub2,
                      nlsyms=args.nlsyms,
                      unit=args.unit,
                      unit_sub1=args.unit_sub1,
                      unit_sub2=args.unit_sub2,
                      wp_model=args.wp_model,
                      wp_model_sub1=args.wp_model_sub1,
                      wp_model_sub2=args.wp_model_sub2,
                      batch_size=args.batch_size * args.n_gpus,
                      min_n_frames=args.min_n_frames,
                      max_n_frames=args.max_n_frames,
                      ctc=args.ctc_weight > 0,
                      ctc_sub1=args.ctc_weight_sub1 > 0,
                      ctc_sub2=args.ctc_weight_sub2 > 0,
                      subsample_factor=subsample_factor,
                      subsample_factor_sub1=subsample_factor_sub1,
                      subsample_factor_sub2=subsample_factor_sub2,
                      discourse_aware=args.discourse_aware,
                      skip_thought=skip_thought)
    eval_sets = []
    for s in args.eval_sets:
        eval_sets += [
            Dataset(corpus=args.corpus,
                    tsv_path=s,
                    dict_path=args.dict,
                    nlsyms=args.nlsyms,
                    unit=args.unit,
                    wp_model=args.wp_model,
                    batch_size=1,
                    discourse_aware=args.discourse_aware,
                    skip_thought=skip_thought,
                    is_test=True)
        ]

    args.vocab = train_set.vocab
    args.vocab_sub1 = train_set.vocab_sub1
    args.vocab_sub2 = train_set.vocab_sub2
    args.input_dim = train_set.input_dim

    # Load a LM conf file for LM fusion & LM initialization
    if not args.resume and (args.lm_fusion or args.lm_init):
        if args.lm_fusion:
            lm_conf = load_config(
                os.path.join(os.path.dirname(args.lm_fusion), 'conf.yml'))
        elif args.lm_init:
            lm_conf = load_config(
                os.path.join(os.path.dirname(args.lm_init), 'conf.yml'))
        args.lm_conf = argparse.Namespace()
        for k, v in lm_conf.items():
            setattr(args.lm_conf, k, v)
        # LM and ASR model must share the output unit and vocabulary
        assert args.unit == args.lm_conf.unit
        assert args.vocab == args.lm_conf.vocab

    # Model setting
    model = Speech2Text(args, save_path) if not skip_thought else SkipThought(
        args, save_path)

    if args.resume:
        # Set optimizer
        epoch = int(args.resume.split('-')[-1])
        optimizer = set_optimizer(
            model,
            'sgd' if epoch > conf['convert_to_sgd_epoch'] else conf['optimizer'],
            conf['lr'], conf['weight_decay'])

        # Wrap optimizer by learning rate scheduler
        noam = 'transformer' in conf['enc_type'] or conf[
            'dec_type'] == 'transformer'
        optimizer = LRScheduler(
            optimizer, conf['lr'],
            decay_type=conf['lr_decay_type'],
            decay_start_epoch=conf['lr_decay_start_epoch'],
            decay_rate=conf['lr_decay_rate'],
            decay_patient_n_epochs=conf['lr_decay_patient_n_epochs'],
            early_stop_patient_n_epochs=conf['early_stop_patient_n_epochs'],
            warmup_start_lr=conf['warmup_start_lr'],
            warmup_n_steps=conf['warmup_n_steps'],
            model_size=conf['d_model'],
            factor=conf['lr_factor'],
            noam=noam)

        # Restore the last saved model
        model, optimizer = load_checkpoint(model, args.resume, optimizer,
                                           resume=True)

        # Resume between convert_to_sgd_epoch -1 and convert_to_sgd_epoch
        if epoch == conf['convert_to_sgd_epoch']:
            optimizer.convert_to_sgd(model, 'sgd', args.lr,
                                     conf['weight_decay'],
                                     decay_type='always', decay_rate=0.5)
    else:
        # Save the conf file as a yaml file
        save_config(vars(args), os.path.join(save_path, 'conf.yml'))
        if args.lm_fusion:
            save_config(args.lm_conf, os.path.join(save_path, 'conf_lm.yml'))

        # Save the nlsyms, dictionary, and wp_model
        if args.nlsyms:
            shutil.copy(args.nlsyms, os.path.join(save_path, 'nlsyms.txt'))
        for sub in ['', '_sub1', '_sub2']:
            if getattr(args, 'dict' + sub):
                shutil.copy(getattr(args, 'dict' + sub),
                            os.path.join(save_path, 'dict' + sub + '.txt'))
            if getattr(args, 'unit' + sub) == 'wp':
                shutil.copy(getattr(args, 'wp_model' + sub),
                            os.path.join(save_path, 'wp' + sub + '.model'))

        for k, v in sorted(vars(args).items(), key=lambda x: x[0]):
            logger.info('%s: %s' % (k, str(v)))

        # Count total parameters
        for n in sorted(list(model.num_params_dict.keys())):
            n_params = model.num_params_dict[n]
            logger.info("%s %d" % (n, n_params))
        logger.info("Total %.2f M parameters" %
                    (model.total_parameters / 1000000))
        logger.info(model)

        # Initialize with pre-trained model's parameters
        if args.asr_init and os.path.isfile(args.asr_init):
            # Load the ASR model
            conf_init = load_config(
                os.path.join(os.path.dirname(args.asr_init), 'conf.yml'))
            for k, v in conf_init.items():
                setattr(args_init, k, v)
            model_init = Speech2Text(args_init)
            model_init = load_checkpoint(model_init, args.asr_init)[0]

            # Overwrite parameters
            # Only the encoder can be transferred when layers/units/objective
            # differ between the source and target models.
            only_enc = (args.enc_n_layers != args_init.enc_n_layers) or (
                args.unit != args_init.unit) or args_init.ctc_weight == 1
            param_dict = dict(model_init.named_parameters())
            for n, p in model.named_parameters():
                if n in param_dict.keys() and p.size() == param_dict[n].size():
                    if only_enc and 'enc' not in n:
                        continue
                    if args.lm_fusion_type == 'cache' and 'output' in n:
                        continue
                    p.data = param_dict[n].data
                    logger.info('Overwrite %s' % n)

        # Set optimizer
        optimizer = set_optimizer(model, args.optimizer, args.lr,
                                  args.weight_decay)

        # Wrap optimizer by learning rate scheduler
        noam = 'transformer' in args.enc_type or args.dec_type == 'transformer'
        optimizer = LRScheduler(
            optimizer, args.lr,
            decay_type=args.lr_decay_type,
            decay_start_epoch=args.lr_decay_start_epoch,
            decay_rate=args.lr_decay_rate,
            decay_patient_n_epochs=args.lr_decay_patient_n_epochs,
            early_stop_patient_n_epochs=args.early_stop_patient_n_epochs,
            warmup_start_lr=args.warmup_start_lr,
            warmup_n_steps=args.warmup_n_steps,
            model_size=args.d_model,
            factor=args.lr_factor,
            noam=noam)

    # Load the teacher ASR model (for knowledge distillation)
    teacher = None
    if args.teacher and os.path.isfile(args.teacher):
        conf_teacher = load_config(
            os.path.join(os.path.dirname(args.teacher), 'conf.yml'))
        for k, v in conf_teacher.items():
            setattr(args_teacher, k, v)
        # Setting for knowledge distillation
        args_teacher.ss_prob = 0
        args.lsm_prob = 0
        teacher = Speech2Text(args_teacher)
        teacher = load_checkpoint(teacher, args.teacher)[0]

    # Load the teacher LM
    teacher_lm = None
    if args.teacher_lm and os.path.isfile(args.teacher_lm):
        conf_lm = load_config(
            os.path.join(os.path.dirname(args.teacher_lm), 'conf.yml'))
        args_lm = argparse.Namespace()
        for k, v in conf_lm.items():
            setattr(args_lm, k, v)
        teacher_lm = build_lm(args_lm)
        teacher_lm = load_checkpoint(teacher_lm, args.teacher_lm)[0]

    # GPU setting
    if args.n_gpus >= 1:
        torch.backends.cudnn.benchmark = True
        model = CustomDataParallel(model,
                                   device_ids=list(range(0, args.n_gpus)))
        model.cuda()
        if teacher is not None:
            teacher.cuda()
        if teacher_lm is not None:
            teacher_lm.cuda()

    # Set process name
    logger.info('PID: %s' % os.getpid())
    logger.info('USERNAME: %s' % os.uname()[1])
    setproctitle(args.job_name if args.job_name else dir_name)

    # Set reporter
    reporter = Reporter(save_path)

    if args.mtl_per_batch:
        # NOTE: from easier to harder tasks
        tasks = []
        if 1 - args.bwd_weight - args.ctc_weight - args.sub1_weight - args.sub2_weight > 0:
            tasks += ['ys']
        if args.bwd_weight > 0:
            tasks = ['ys.bwd'] + tasks
        if args.ctc_weight > 0:
            tasks = ['ys.ctc'] + tasks
        for sub in ['sub1', 'sub2']:
            if getattr(args, 'train_set_' + sub):
                if getattr(args, sub + '_weight') - getattr(
                        args, 'ctc_weight_' + sub) > 0:
                    tasks = ['ys_' + sub] + tasks
                if getattr(args, 'ctc_weight_' + sub) > 0:
                    tasks = ['ys_' + sub + '.ctc'] + tasks
    else:
        tasks = ['all']

    start_time_train = time.time()
    start_time_epoch = time.time()
    start_time_step = time.time()
    pbar_epoch = tqdm(total=len(train_set))
    accum_n_tokens = 0
    while True:
        # Compute loss in the training set
        batch_train, is_new_epoch = train_set.next()
        accum_n_tokens += sum([len(y) for y in batch_train['ys']])

        # Change mini-batch depending on task
        for task in tasks:
            if skip_thought:
                loss, reporter = model(batch_train['ys'],
                                       ys_prev=batch_train['ys_prev'],
                                       ys_next=batch_train['ys_next'],
                                       reporter=reporter)
            else:
                loss, reporter = model(batch_train, reporter, task,
                                       teacher=teacher,
                                       teacher_lm=teacher_lm)
            loss.backward()
            loss.detach()  # Truncate the graph
            # Gradient accumulation: only step once enough tokens were seen
            if args.accum_grad_n_tokens == 0 or accum_n_tokens >= args.accum_grad_n_tokens:
                if args.clip_grad_norm > 0:
                    total_norm = torch.nn.utils.clip_grad_norm_(
                        model.module.parameters(), args.clip_grad_norm)
                    reporter.add_tensorboard_scalar('total_norm', total_norm)
                optimizer.step()
                optimizer.zero_grad()
                accum_n_tokens = 0
            loss_train = loss.item()
            del loss

        reporter.add_tensorboard_scalar('learning_rate', optimizer.lr)
        # NOTE: loss/acc/ppl are already added in the model
        reporter.step()

        if optimizer.n_steps % args.print_step == 0:
            # Compute loss in the dev set
            batch_dev = dev_set.next()[0]
            # Change mini-batch depending on task
            for task in tasks:
                if skip_thought:
                    loss, reporter = model(batch_dev['ys'],
                                           ys_prev=batch_dev['ys_prev'],
                                           ys_next=batch_dev['ys_next'],
                                           reporter=reporter,
                                           is_eval=True)
                else:
                    loss, reporter = model(batch_dev, reporter, task,
                                           is_eval=True)
                loss_dev = loss.item()
                del loss
            # NOTE: this makes training slow
            # Compute WER/CER regardless of the output unit (greedy decoding)
            # best_hyps_id, _, _ = model.module.decode(
            #     batch_dev['xs'], recog_params, dev_set.idx2token[0], exclude_eos=True)
            # cer = 0.
            # ref_n_words, ref_n_chars = 0, 0
            # for b in range(len(batch_dev['xs'])):
            #     ref = batch_dev['text'][b]
            #     hyp = dev_set.idx2token[0](best_hyps_id[b])
            #     cer += editdistance.eval(hyp, ref)
            #     ref_n_words += len(ref.split())
            #     ref_n_chars += len(ref)
            # wer = cer / ref_n_words
            # cer /= ref_n_chars
            # reporter.add_tensorboard_scalar('dev/WER', wer)
            # reporter.add_tensorboard_scalar('dev/CER', cer)
            # logger.info('WER (dev)', wer)
            # logger.info('CER (dev)', cer)
            reporter.step(is_eval=True)

            duration_step = time.time() - start_time_step
            if args.input_type == 'speech':
                xlen = max(len(x) for x in batch_train['xs'])
                ylen = max(len(y) for y in batch_train['ys'])
            elif args.input_type == 'text':
                xlen = max(len(x) for x in batch_train['ys'])
                ylen = max(len(y) for y in batch_train['ys_sub1'])
            logger.info(
                "step:%d(ep:%.2f) loss:%.3f(%.3f)/lr:%.7f/bs:%d/xlen:%d/ylen:%d (%.2f min)" %
                (optimizer.n_steps,
                 optimizer.n_epochs + train_set.epoch_detail,
                 loss_train, loss_dev, optimizer.lr,
                 len(batch_train['utt_ids']), xlen, ylen,
                 duration_step / 60))
            start_time_step = time.time()
        pbar_epoch.update(len(batch_train['utt_ids']))

        # Save figures of loss and accuracy
        if optimizer.n_steps % (args.print_step * 10) == 0:
            reporter.snapshot()
            model.module.plot_attention()

        # Save checkpoint and evaluate model per epoch
        if is_new_epoch:
            duration_epoch = time.time() - start_time_epoch
            logger.info('========== EPOCH:%d (%.2f min) ==========' %
                        (optimizer.n_epochs + 1, duration_epoch / 60))

            if optimizer.n_epochs + 1 < args.eval_start_epoch:
                optimizer.epoch()  # lr decay
                reporter.epoch()  # plot

                # Save the model
                save_checkpoint(model, save_path, optimizer,
                                optimizer.n_epochs,
                                remove_old_checkpoints=not noam)
            else:
                start_time_eval = time.time()
                # dev
                metric_dev = eval_epoch([model.module], dev_set, recog_params,
                                        args, optimizer.n_epochs + 1, logger)
                optimizer.epoch(metric_dev)  # lr decay
                reporter.epoch(metric_dev)  # plot

                if optimizer.is_best:
                    # Save the model
                    save_checkpoint(model, save_path, optimizer,
                                    optimizer.n_epochs,
                                    remove_old_checkpoints=not noam)

                    # test
                    for eval_set in eval_sets:
                        eval_epoch([model.module], eval_set, recog_params,
                                   args, optimizer.n_epochs, logger)

                # start scheduled sampling
                if args.ss_prob > 0:
                    model.module.scheduled_sampling_trigger()

                duration_eval = time.time() - start_time_eval
                logger.info('Evaluation time: %.2f min' %
                            (duration_eval / 60))

                # Early stopping
                if optimizer.is_early_stop:
                    break

                # Convert to fine-tuning stage
                if optimizer.n_epochs == args.convert_to_sgd_epoch:
                    optimizer.convert_to_sgd(model, 'sgd', args.lr,
                                             args.weight_decay,
                                             decay_type='always',
                                             decay_rate=0.5)

            pbar_epoch = tqdm(total=len(train_set))

            if optimizer.n_epochs == args.n_epochs:
                break

            start_time_step = time.time()
            start_time_epoch = time.time()

    duration_train = time.time() - start_time_train
    logger.info('Total time: %.2f hour' % (duration_train / 3600))

    reporter.tf_writer.close()
    pbar_epoch.close()

    return save_path
def main():
    """Plot CTC posterior probabilities of a trained ASR model.

    For each recognition set, decodes every batch and saves a CTC-posterior
    figure per utterance under `<recog_dir>/ctc_probs/<speaker>/`.
    """
    # Load configuration
    args, recog_params, dir_name = parse_args_eval(sys.argv[1:])

    # Setting for logging
    if os.path.isfile(os.path.join(args.recog_dir, 'plot.log')):
        os.remove(os.path.join(args.recog_dir, 'plot.log'))
    set_logger(os.path.join(args.recog_dir, 'plot.log'),
               stdout=args.recog_stdout)
    # NOTE(review): `logger` below is presumably a module-level logger
    # configured by set_logger — confirm against the module header.

    for i, s in enumerate(args.recog_sets):
        # Load dataset
        dataset = Dataset(
            corpus=args.corpus,
            tsv_path=s,
            dict_path=os.path.join(dir_name, 'dict.txt'),
            dict_path_sub1=os.path.join(dir_name, 'dict_sub1.txt') if os.path.isfile(
                os.path.join(dir_name, 'dict_sub1.txt')) else False,
            nlsyms=args.nlsyms,
            wp_model=os.path.join(dir_name, 'wp.model'),
            unit=args.unit,
            unit_sub1=args.unit_sub1,
            batch_size=args.recog_batch_size,
            is_test=True)

        if i == 0:
            # Load the ASR model (only once; reused for all sets)
            model = Speech2Text(args, dir_name)
            epoch = int(args.recog_model[0].split('-')[-1])
            if args.recog_n_average > 1:
                # Model averaging for Transformer
                model = average_checkpoints(model, args.recog_model[0],
                                            n_average=args.recog_n_average)
            else:
                load_checkpoint(args.recog_model[0], model)

            if not args.recog_unit:
                args.recog_unit = args.unit

            logger.info('recog unit: %s' % args.recog_unit)
            logger.info('epoch: %d' % epoch)
            logger.info('batch size: %d' % args.recog_batch_size)

            # GPU setting
            if args.recog_n_gpus >= 1:
                model.cudnn_setting(deterministic=True, benchmark=False)
                model.cuda()

        save_path = mkdir_join(args.recog_dir, 'ctc_probs')

        # Clean directory
        if save_path is not None and os.path.isdir(save_path):
            shutil.rmtree(save_path)
            os.mkdir(save_path)

        while True:
            batch, is_new_epoch = dataset.next(
                recog_params['recog_batch_size'])
            best_hyps_id, _ = model.decode(batch['xs'], recog_params)

            # Get CTC probs
            ctc_probs, topk_ids, xlens = model.get_ctc_probs(
                batch['xs'], temperature=1, topk=min(100, model.vocab))
            # NOTE: ctc_probs: '[B, T, topk]'

            for b in range(len(batch['xs'])):
                tokens = dataset.idx2token[0](best_hyps_id[b],
                                              return_list=True)
                spk = batch['speakers'][b]

                plot_ctc_probs(
                    ctc_probs[b, :xlens[b]],
                    topk_ids[b],
                    subsample_factor=args.subsample_factor,
                    spectrogram=batch['xs'][b][:, :dataset.input_dim],
                    save_path=mkdir_join(save_path, spk,
                                         batch['utt_ids'][b] + '.png'),
                    figsize=(20, 8))

                hyp = ' '.join(tokens)
                logger.info('utt-id: %s' % batch['utt_ids'][b])
                logger.info('Ref: %s' % batch['text'][b].lower())
                logger.info('Hyp: %s' % hyp)
                logger.info('-' * 50)

            if is_new_epoch:
                break
def main():
    """Train a neural language model (RNNLM / Transformer LM).

    Builds train/dev/eval datasets, constructs the LM with optimizer and LR
    scheduler (optionally resuming), then runs truncated-BPTT training with
    periodic perplexity evaluation and checkpointing. Returns the save path.
    """
    args = parse()

    # Load a conf file
    if args.resume:
        conf = load_config(
            os.path.join(os.path.dirname(args.resume), 'conf.yml'))
        for k, v in conf.items():
            if k != 'resume':
                setattr(args, k, v)

    # Set save path
    if args.resume:
        save_path = os.path.dirname(args.resume)
        dir_name = os.path.basename(save_path)
    else:
        dir_name = set_lm_name(args)
        save_path = mkdir_join(
            args.model_save_dir,
            '_'.join(os.path.basename(args.train_set).split('.')[:-1]),
            dir_name)
        save_path = set_save_path(save_path)  # avoid overwriting

    # Set logger
    logger = set_logger(os.path.join(save_path, 'train.log'),
                        key='training', stdout=args.stdout)

    # Load dataset
    train_set = Dataset(corpus=args.corpus,
                        tsv_path=args.train_set,
                        dict_path=args.dict,
                        nlsyms=args.nlsyms,
                        unit=args.unit,
                        wp_model=args.wp_model,
                        batch_size=args.batch_size * args.n_gpus,
                        n_epochs=args.n_epochs,
                        min_n_tokens=args.min_n_tokens,
                        bptt=args.bptt,
                        backward=args.backward,
                        serialize=args.serialize)
    dev_set = Dataset(corpus=args.corpus,
                      tsv_path=args.dev_set,
                      dict_path=args.dict,
                      nlsyms=args.nlsyms,
                      unit=args.unit,
                      wp_model=args.wp_model,
                      batch_size=args.batch_size * args.n_gpus,
                      bptt=args.bptt,
                      backward=args.backward,
                      serialize=args.serialize)
    eval_sets = []
    for s in args.eval_sets:
        eval_sets += [
            Dataset(corpus=args.corpus,
                    tsv_path=s,
                    dict_path=args.dict,
                    nlsyms=args.nlsyms,
                    unit=args.unit,
                    wp_model=args.wp_model,
                    batch_size=1,
                    bptt=args.bptt,
                    backward=args.backward,
                    serialize=args.serialize)
        ]

    args.vocab = train_set.vocab

    # Model setting
    model = build_lm(args, save_path)

    if args.resume:
        # Set optimizer
        epoch = int(args.resume.split('-')[-1])
        optimizer = set_optimizer(
            model,
            'sgd' if epoch > conf['convert_to_sgd_epoch'] else conf['optimizer'],
            conf['lr'], conf['weight_decay'])

        # Wrap optimizer by learning rate scheduler
        optimizer = LRScheduler(
            optimizer, conf['lr'],
            decay_type=conf['lr_decay_type'],
            decay_start_epoch=conf['lr_decay_start_epoch'],
            decay_rate=conf['lr_decay_rate'],
            decay_patient_n_epochs=conf['lr_decay_patient_n_epochs'],
            early_stop_patient_n_epochs=conf['early_stop_patient_n_epochs'],
            warmup_start_lr=conf['warmup_start_lr'],
            warmup_n_steps=conf['warmup_n_steps'],
            model_size=conf['d_model'],
            factor=conf['lr_factor'],
            noam=conf['lm_type'] == 'transformer')

        # Restore the last saved model
        model, optimizer = load_checkpoint(model, args.resume, optimizer,
                                           resume=True)

        # Resume between convert_to_sgd_epoch -1 and convert_to_sgd_epoch
        if epoch == conf['convert_to_sgd_epoch']:
            # Carry the epoch/step counters over to the fresh SGD scheduler
            n_epochs = optimizer.n_epochs
            n_steps = optimizer.n_steps
            optimizer = set_optimizer(model, 'sgd', args.lr,
                                      conf['weight_decay'])
            optimizer = LRScheduler(optimizer, args.lr,
                                    decay_type='always',
                                    decay_start_epoch=0,
                                    decay_rate=0.5)
            optimizer._epoch = n_epochs
            optimizer._step = n_steps
            logger.info('========== Convert to SGD ==========')
    else:
        # Save the conf file as a yaml file
        save_config(vars(args), os.path.join(save_path, 'conf.yml'))

        # Save the nlsyms, dictionary, and wp_model
        if args.nlsyms:
            shutil.copy(args.nlsyms, os.path.join(save_path, 'nlsyms.txt'))
        shutil.copy(args.dict, os.path.join(save_path, 'dict.txt'))
        if args.unit == 'wp':
            shutil.copy(args.wp_model, os.path.join(save_path, 'wp.model'))

        for k, v in sorted(vars(args).items(), key=lambda x: x[0]):
            logger.info('%s: %s' % (k, str(v)))

        # Count total parameters
        for n in sorted(list(model.num_params_dict.keys())):
            n_params = model.num_params_dict[n]
            logger.info("%s %d" % (n, n_params))
        logger.info("Total %.2f M parameters" %
                    (model.total_parameters / 1000000))
        logger.info(model)

        # Set optimizer
        optimizer = set_optimizer(model, args.optimizer, args.lr,
                                  args.weight_decay)

        # Wrap optimizer by learning rate scheduler
        optimizer = LRScheduler(
            optimizer, args.lr,
            decay_type=args.lr_decay_type,
            decay_start_epoch=args.lr_decay_start_epoch,
            decay_rate=args.lr_decay_rate,
            decay_patient_n_epochs=args.lr_decay_patient_n_epochs,
            early_stop_patient_n_epochs=args.early_stop_patient_n_epochs,
            warmup_start_lr=args.warmup_start_lr,
            warmup_n_steps=args.warmup_n_steps,
            model_size=args.d_model,
            factor=args.lr_factor,
            noam=args.lm_type == 'transformer')

    # GPU setting
    if args.n_gpus >= 1:
        torch.backends.cudnn.benchmark = True
        model = CustomDataParallel(model,
                                   device_ids=list(range(0, args.n_gpus)))
        model.cuda()

    # Set process name
    logger.info('PID: %s' % os.getpid())
    logger.info('USERNAME: %s' % os.uname()[1])
    setproctitle(args.job_name if args.job_name else dir_name)

    # Set reporter
    reporter = Reporter(save_path)

    hidden = None  # recurrent state carried across truncated-BPTT chunks
    start_time_train = time.time()
    start_time_epoch = time.time()
    start_time_step = time.time()
    pbar_epoch = tqdm(total=len(train_set))
    accum_n_tokens = 0
    while True:
        # Compute loss in the training set
        ys_train, is_new_epoch = train_set.next()
        accum_n_tokens += sum([len(y) for y in ys_train])
        optimizer.zero_grad()
        loss, hidden, reporter = model(ys_train, hidden, reporter)
        loss.backward()
        loss.detach()  # Truncate the graph
        # Gradient accumulation: only step once enough tokens were seen
        if args.accum_grad_n_tokens == 0 or accum_n_tokens >= args.accum_grad_n_tokens:
            if args.clip_grad_norm > 0:
                total_norm = torch.nn.utils.clip_grad_norm_(
                    model.module.parameters(), args.clip_grad_norm)
                reporter.add_tensorboard_scalar('total_norm', total_norm)
            optimizer.step()
            optimizer.zero_grad()
            accum_n_tokens = 0
        loss_train = loss.item()
        del loss
        # Detach hidden state so the next chunk does not backprop into this one
        hidden = model.module.repackage_state(hidden)
        reporter.add_tensorboard_scalar('learning_rate', optimizer.lr)
        # NOTE: loss/acc/ppl are already added in the model
        reporter.step()

        if optimizer.n_steps % args.print_step == 0:
            # Compute loss in the dev set
            ys_dev = dev_set.next()[0]
            loss, _, reporter = model(ys_dev, None, reporter, is_eval=True)
            loss_dev = loss.item()
            del loss
            reporter.step(is_eval=True)

            duration_step = time.time() - start_time_step
            logger.info(
                "step:%d(ep:%.2f) loss:%.3f(%.3f)/ppl:%.3f(%.3f)/lr:%.5f/bs:%d (%.2f min)" %
                (optimizer.n_steps,
                 optimizer.n_epochs + train_set.epoch_detail,
                 loss_train, loss_dev,
                 np.exp(loss_train), np.exp(loss_dev), optimizer.lr,
                 ys_train.shape[0], duration_step / 60))
            start_time_step = time.time()
        pbar_epoch.update(ys_train.shape[0] * (ys_train.shape[1] - 1))

        # Save figures of loss and accuracy
        if optimizer.n_steps % (args.print_step * 10) == 0:
            reporter.snapshot()
            if args.lm_type == 'transformer':
                model.module.plot_attention()

        # Save checkpoint and evaluate model per epoch
        if is_new_epoch:
            duration_epoch = time.time() - start_time_epoch
            logger.info('========== EPOCH:%d (%.2f min) ==========' %
                        (optimizer.n_epochs + 1, duration_epoch / 60))

            if optimizer.n_epochs + 1 < args.eval_start_epoch:
                optimizer.epoch()  # lr decay
                reporter.epoch()  # plot

                # Save the model
                save_checkpoint(
                    model, save_path, optimizer, optimizer.n_epochs,
                    remove_old_checkpoints=args.lm_type != 'transformer')
            else:
                start_time_eval = time.time()
                # dev
                ppl_dev, _ = eval_ppl([model.module], dev_set,
                                      batch_size=1, bptt=args.bptt)
                logger.info('PPL (%s, epoch:%d): %.2f' %
                            (dev_set.set, optimizer.n_epochs, ppl_dev))
                optimizer.epoch(ppl_dev)  # lr decay
                reporter.epoch(ppl_dev, name='perplexity')  # plot

                if optimizer.is_best:
                    # Save the model
                    save_checkpoint(
                        model, save_path, optimizer, optimizer.n_epochs,
                        remove_old_checkpoints=args.lm_type != 'transformer')

                    # test
                    ppl_test_avg = 0.
                    for eval_set in eval_sets:
                        ppl_test, _ = eval_ppl([model.module], eval_set,
                                               batch_size=1, bptt=args.bptt)
                        logger.info(
                            'PPL (%s, epoch:%d): %.2f' %
                            (eval_set.set, optimizer.n_epochs, ppl_test))
                        ppl_test_avg += ppl_test
                    if len(eval_sets) > 0:
                        logger.info('PPL (avg., epoch:%d): %.2f' %
                                    (optimizer.n_epochs,
                                     ppl_test_avg / len(eval_sets)))

                duration_eval = time.time() - start_time_eval
                logger.info('Evaluation time: %.2f min' %
                            (duration_eval / 60))

                # Early stopping
                if optimizer.is_early_stop:
                    break

                # Convert to fine-tuning stage
                if optimizer.n_epochs == args.convert_to_sgd_epoch:
                    # Carry the epoch/step counters over to the SGD scheduler
                    n_epochs = optimizer.n_epochs
                    n_steps = optimizer.n_steps
                    optimizer = set_optimizer(model, 'sgd', args.lr,
                                              args.weight_decay)
                    optimizer = LRScheduler(optimizer, args.lr,
                                            decay_type='always',
                                            decay_start_epoch=0,
                                            decay_rate=0.5)
                    optimizer._epoch = n_epochs
                    optimizer._step = n_steps
                    logger.info('========== Convert to SGD ==========')

            pbar_epoch = tqdm(total=len(train_set))

            if optimizer.n_epochs == args.n_epochs:
                break

            start_time_step = time.time()
            start_time_epoch = time.time()

    duration_train = time.time() - start_time_train
    logger.info('Total time: %.2f hour' % (duration_train / 3600))

    reporter.tf_writer.close()
    pbar_epoch.close()

    return save_path
def main():
    """Plot CTC posterior probabilities (dataloader-based variant).

    For each recognition set, decodes every batch and saves a CTC-posterior
    figure per utterance under `<recog_dir>/ctc_probs/<speaker>/`.
    """
    # Load configuration
    args, dir_name = parse_args_eval(sys.argv[1:])

    # Setting for logging
    if os.path.isfile(os.path.join(args.recog_dir, 'plot.log')):
        os.remove(os.path.join(args.recog_dir, 'plot.log'))
    set_logger(os.path.join(args.recog_dir, 'plot.log'),
               stdout=args.recog_stdout)
    # NOTE(review): `logger` below is presumably a module-level logger
    # configured by set_logger — confirm against the module header.

    for i, s in enumerate(args.recog_sets):
        # Load dataloader
        dataloader = build_dataloader(
            args=args,
            tsv_path=s,
            batch_size=1,
            is_test=True,
            first_n_utterances=args.recog_first_n_utt,
            longform_max_n_frames=args.recog_longform_max_n_frames)

        if i == 0:
            # Load ASR model (only once; reused for all sets)
            model = Speech2Text(args, dir_name)
            # Checkpoint name encodes a (possibly fractional) epoch number;
            # keep one decimal place.
            epoch = int(float(args.recog_model[0].split('-')[-1]) * 10) / 10
            if args.recog_n_average > 1:
                # Model averaging for Transformer
                model = average_checkpoints(model, args.recog_model[0],
                                            n_average=args.recog_n_average)
            else:
                load_checkpoint(args.recog_model[0], model)

            if not args.recog_unit:
                args.recog_unit = args.unit

            logger.info('recog unit: %s' % args.recog_unit)
            logger.info('epoch: %d' % epoch)
            logger.info('batch size: %d' % args.recog_batch_size)

            # GPU setting
            if args.recog_n_gpus >= 1:
                model.cudnn_setting(deterministic=True, benchmark=False)
                model.cuda()

        save_path = mkdir_join(args.recog_dir, 'ctc_probs')

        # Clean directory
        if save_path is not None and os.path.isdir(save_path):
            shutil.rmtree(save_path)
            os.mkdir(save_path)

        for batch in dataloader:
            nbest_hyps_id, _ = model.decode(batch['xs'], args,
                                            dataloader.idx2token[0])
            best_hyps_id = [h[0] for h in nbest_hyps_id]

            # Get CTC probs
            ctc_probs, topk_ids, xlens = model.get_ctc_probs(
                batch['xs'], temperature=1, topk=min(100, model.vocab))
            # NOTE: ctc_probs: '[B, T, topk]'

            for b in range(len(batch['xs'])):
                tokens = dataloader.idx2token[0](best_hyps_id[b],
                                                 return_list=True)
                spk = batch['speakers'][b]

                plot_ctc_probs(
                    ctc_probs[b, :xlens[b]],
                    topk_ids[b],
                    factor=args.subsample_factor,
                    spectrogram=batch['xs'][b][:, :dataloader.input_dim],
                    save_path=mkdir_join(save_path, spk,
                                         batch['utt_ids'][b] + '.png'),
                    figsize=(20, 8))

                hyp = ' '.join(tokens)
                logger.info('utt-id: %s' % batch['utt_ids'][b])
                logger.info('Ref: %s' % batch['text'][b].lower())
                logger.info('Hyp: %s' % hyp)
                logger.info('-' * 50)
def main():
    """Plot decoder attention weights of a trained ASR model.

    Supports model ensembles and shallow LM fusion for decoding; saves one
    attention figure (optionally with CTC posteriors) per utterance under
    `<recog_dir>/att_weights/<speaker>/`.
    """
    # Load configuration
    args, dir_name = parse_args_eval(sys.argv[1:])

    # Setting for logging
    if os.path.isfile(os.path.join(args.recog_dir, 'plot.log')):
        os.remove(os.path.join(args.recog_dir, 'plot.log'))
    set_logger(os.path.join(args.recog_dir, 'plot.log'),
               stdout=args.recog_stdout)
    # NOTE(review): `logger` below is presumably a module-level logger
    # configured by set_logger — confirm against the module header.

    for i, s in enumerate(args.recog_sets):
        # Load dataloader
        dataloader = build_dataloader(
            args=args,
            tsv_path=s,
            batch_size=1,
            is_test=True,
            first_n_utterances=args.recog_first_n_utt,
            longform_max_n_frames=args.recog_longform_max_n_frames)

        if i == 0:
            # Load ASR model (only once; reused for all sets)
            model = Speech2Text(args, dir_name)
            # Checkpoint name encodes a (possibly fractional) epoch number;
            # keep one decimal place.
            epoch = int(float(args.recog_model[0].split('-')[-1]) * 10) / 10
            if args.recog_n_average > 1:
                # Model averaging for Transformer
                model = average_checkpoints(model, args.recog_model[0],
                                            n_average=args.recog_n_average)
            else:
                load_checkpoint(args.recog_model[0], model)

            # Ensemble (different models)
            ensemble_models = [model]
            if len(args.recog_model) > 1:
                for recog_model_e in args.recog_model[1:]:
                    conf_e = load_config(
                        os.path.join(os.path.dirname(recog_model_e),
                                     'conf.yml'))
                    args_e = copy.deepcopy(args)
                    for k, v in conf_e.items():
                        # keep the CLI recog_* options; take the rest from
                        # each ensemble member's own training conf
                        if 'recog' not in k:
                            setattr(args_e, k, v)
                    model_e = Speech2Text(args_e)
                    load_checkpoint(recog_model_e, model_e)
                    if args.recog_n_gpus >= 1:
                        model_e.cuda()
                    ensemble_models += [model_e]

            # Load LM for shallow fusion
            if not args.lm_fusion:
                # first path
                if args.recog_lm is not None and args.recog_lm_weight > 0:
                    conf_lm = load_config(
                        os.path.join(os.path.dirname(args.recog_lm),
                                     'conf.yml'))
                    args_lm = argparse.Namespace()
                    for k, v in conf_lm.items():
                        setattr(args_lm, k, v)
                    lm = build_lm(args_lm)
                    load_checkpoint(args.recog_lm, lm)
                    if args_lm.backward:
                        model.lm_bwd = lm
                    else:
                        model.lm_fwd = lm
                # NOTE: only support for first path

            if not args.recog_unit:
                args.recog_unit = args.unit

            logger.info('recog unit: %s' % args.recog_unit)
            logger.info('recog oracle: %s' % args.recog_oracle)
            logger.info('epoch: %d' % epoch)
            logger.info('batch size: %d' % args.recog_batch_size)
            logger.info('beam width: %d' % args.recog_beam_width)
            logger.info('min length ratio: %.3f' % args.recog_min_len_ratio)
            logger.info('max length ratio: %.3f' % args.recog_max_len_ratio)
            logger.info('length penalty: %.3f' % args.recog_length_penalty)
            logger.info('length norm: %s' % args.recog_length_norm)
            logger.info('coverage penalty: %.3f' % args.recog_coverage_penalty)
            logger.info('coverage threshold: %.3f' % args.recog_coverage_threshold)
            logger.info('CTC weight: %.3f' % args.recog_ctc_weight)
            logger.info('fist LM path: %s' % args.recog_lm)
            logger.info('LM weight: %.3f' % args.recog_lm_weight)
            logger.info('GNMT: %s' % args.recog_gnmt_decoding)
            logger.info('forward-backward attention: %s' % args.recog_fwd_bwd_attention)
            logger.info('resolving UNK: %s' % args.recog_resolving_unk)
            logger.info('ensemble: %d' % (len(ensemble_models)))
            logger.info('ASR decoder state carry over: %s' % (args.recog_asr_state_carry_over))
            logger.info('LM state carry over: %s' % (args.recog_lm_state_carry_over))
            logger.info('model average (Transformer): %d' % (args.recog_n_average))

            # GPU setting
            if args.recog_n_gpus >= 1:
                model.cudnn_setting(deterministic=True, benchmark=False)
                model.cuda()

        save_path = mkdir_join(args.recog_dir, 'att_weights')

        # Clean directory
        if save_path is not None and os.path.isdir(save_path):
            shutil.rmtree(save_path)
            os.mkdir(save_path)

        for batch in dataloader:
            nbest_hyps_id, aws = model.decode(
                batch['xs'], args, dataloader.idx2token[0],
                exclude_eos=False,
                refs_id=batch['ys'],
                ensemble_models=ensemble_models[1:] if len(ensemble_models) > 1 else [],
                speakers=batch['sessions'] if dataloader.corpus == 'swbd' else batch['speakers'])
            best_hyps_id = [h[0] for h in nbest_hyps_id]

            # Get CTC probs
            ctc_probs, topk_ids = None, None
            if args.ctc_weight > 0:
                ctc_probs, topk_ids, xlens = model.get_ctc_probs(
                    batch['xs'], task='ys', temperature=1,
                    topk=min(100, model.vocab))
                # NOTE: ctc_probs: '[B, T, topk]'
            ctc_probs_sub1, topk_ids_sub1 = None, None
            if args.ctc_weight_sub1 > 0:
                ctc_probs_sub1, topk_ids_sub1, xlens_sub1 = model.get_ctc_probs(
                    batch['xs'], task='ys_sub1', temperature=1,
                    topk=min(100, model.vocab_sub1))

            if model.bwd_weight > 0.5:
                # Reverse the order (right-to-left decoder dominates)
                best_hyps_id = [hyp[::-1] for hyp in best_hyps_id]
                aws = [[aw[0][:, ::-1]] for aw in aws]

            for b in range(len(batch['xs'])):
                tokens = dataloader.idx2token[0](best_hyps_id[b],
                                                 return_list=True)
                spk = batch['speakers'][b]

                plot_attention_weights(
                    aws[b][0][:, :len(tokens)],
                    tokens,
                    spectrogram=batch['xs'][b][:, :dataloader.input_dim] if args.input_type == 'speech' else None,
                    factor=args.subsample_factor,
                    ref=batch['text'][b].lower(),
                    save_path=mkdir_join(save_path, spk,
                                         batch['utt_ids'][b] + '.png'),
                    figsize=(20, 8),
                    ctc_probs=ctc_probs[b, :xlens[b]] if ctc_probs is not None else None,
                    ctc_topk_ids=topk_ids[b] if topk_ids is not None else None,
                    ctc_probs_sub1=ctc_probs_sub1[b, :xlens_sub1[b]] if ctc_probs_sub1 is not None else None,
                    ctc_topk_ids_sub1=topk_ids_sub1[b] if topk_ids_sub1 is not None else None)

                if model.bwd_weight > 0.5:
                    hyp = ' '.join(tokens[::-1])
                else:
                    hyp = ' '.join(tokens)
                logger.info('utt-id: %s' % batch['utt_ids'][b])
                logger.info('Ref: %s' % batch['text'][b].lower())
                logger.info('Hyp: %s' % hyp)
                logger.info('-' * 50)