def main(): args = parse() args_pt = copy.deepcopy(args) # Load a conf file if args.resume: conf = load_config( os.path.join(os.path.dirname(args.resume), 'conf.yml')) for k, v in conf.items(): if k != 'resume': setattr(args, k, v) recog_params = vars(args) # Automatically reduce batch size in multi-GPU setting if args.n_gpus > 1: args.batch_size -= 10 args.print_step //= args.n_gpus subsample_factor = 1 subsample_factor_sub1 = 1 subsample_factor_sub2 = 1 subsample_factor_sub3 = 1 subsample = [int(s) for s in args.subsample.split('_')] if args.conv_poolings: for p in args.conv_poolings.split('_'): p = int(p.split(',')[1].replace(')', '')) if p > 1: subsample_factor *= p if args.train_set_sub1: subsample_factor_sub1 = subsample_factor * np.prod( subsample[:args.enc_n_layers_sub1 - 1]) if args.train_set_sub2: subsample_factor_sub2 = subsample_factor * np.prod( subsample[:args.enc_n_layers_sub2 - 1]) if args.train_set_sub3: subsample_factor_sub3 = subsample_factor * np.prod( subsample[:args.enc_n_layers_sub3 - 1]) subsample_factor *= np.prod(subsample) # Load dataset train_set = Dataset(corpus=args.corpus, tsv_path=args.train_set, tsv_path_sub1=args.train_set_sub1, tsv_path_sub2=args.train_set_sub2, tsv_path_sub3=args.train_set_sub3, dict_path=args.dict, dict_path_sub1=args.dict_sub1, dict_path_sub2=args.dict_sub2, dict_path_sub3=args.dict_sub3, nlsyms=args.nlsyms, unit=args.unit, unit_sub1=args.unit_sub1, unit_sub2=args.unit_sub2, unit_sub3=args.unit_sub3, wp_model=args.wp_model, wp_model_sub1=args.wp_model_sub1, wp_model_sub2=args.wp_model_sub2, wp_model_sub3=args.wp_model_sub3, batch_size=args.batch_size * args.n_gpus, n_epochs=args.n_epochs, min_n_frames=args.min_n_frames, max_n_frames=args.max_n_frames, sort_by_input_length=True, short2long=True, sort_stop_epoch=args.sort_stop_epoch, dynamic_batching=args.dynamic_batching, ctc=args.ctc_weight > 0, ctc_sub1=args.ctc_weight_sub1 > 0, ctc_sub2=args.ctc_weight_sub2 > 0, ctc_sub3=args.ctc_weight_sub3 > 0, subsample_factor=subsample_factor, subsample_factor_sub1=subsample_factor_sub1, subsample_factor_sub2=subsample_factor_sub2, subsample_factor_sub3=subsample_factor_sub3, concat_prev_n_utterances=args.concat_prev_n_utterances, n_caches=args.n_caches) dev_set = Dataset(corpus=args.corpus, tsv_path=args.dev_set, tsv_path_sub1=args.dev_set_sub1, tsv_path_sub2=args.dev_set_sub2, tsv_path_sub3=args.dev_set_sub3, dict_path=args.dict, dict_path_sub1=args.dict_sub1, dict_path_sub2=args.dict_sub2, dict_path_sub3=args.dict_sub3, unit=args.unit, unit_sub1=args.unit_sub1, unit_sub2=args.unit_sub2, unit_sub3=args.unit_sub3, wp_model=args.wp_model, wp_model_sub1=args.wp_model_sub1, wp_model_sub2=args.wp_model_sub2, wp_model_sub3=args.wp_model_sub3, batch_size=args.batch_size * args.n_gpus, min_n_frames=args.min_n_frames, max_n_frames=args.max_n_frames, shuffle=True if args.n_caches == 0 else False, ctc=args.ctc_weight > 0, ctc_sub1=args.ctc_weight_sub1 > 0, ctc_sub2=args.ctc_weight_sub2 > 0, ctc_sub3=args.ctc_weight_sub3 > 0, subsample_factor=subsample_factor, subsample_factor_sub1=subsample_factor_sub1, subsample_factor_sub2=subsample_factor_sub2, subsample_factor_sub3=subsample_factor_sub3, n_caches=args.n_caches) eval_sets = [] for s in args.eval_sets: eval_sets += [ Dataset(corpus=args.corpus, tsv_path=s, dict_path=args.dict, unit=args.unit, wp_model=args.wp_model, batch_size=1, n_caches=args.n_caches, is_test=True) ] args.vocab = train_set.vocab args.vocab_sub1 = train_set.vocab_sub1 args.vocab_sub2 = train_set.vocab_sub2 args.vocab_sub3 = 
train_set.vocab_sub3 args.input_dim = train_set.input_dim # Load a LM conf file for cold fusion & LM initialization if args.lm_fusion: if args.model: lm_conf = load_config( os.path.join(os.path.dirname(args.lm_fusion), 'conf.yml')) elif args.resume: lm_conf = load_config( os.path.join(os.path.dirname(args.resume), 'conf_lm.yml')) args.lm_conf = argparse.Namespace() for k, v in lm_conf.items(): setattr(args.lm_conf, k, v) assert args.unit == args.lm_conf.unit assert args.vocab == args.lm_conf.vocab if args.enc_type == 'transformer': args.decay_type = 'warmup' # Model setting model = Seq2seq(args) dir_name = make_model_name(args, subsample_factor) if args.resume: # Set save path model.save_path = os.path.dirname(args.resume) # Setting for logging logger = set_logger(os.path.join(os.path.dirname(args.resume), 'train.log'), key='training') # Set optimizer epoch = int(args.resume.split('-')[-1]) model.set_optimizer( optimizer='sgd' if epoch > conf['convert_to_sgd_epoch'] + 1 else conf['optimizer'], learning_rate=float(conf['learning_rate']), # on-the-fly weight_decay=float(conf['weight_decay'])) # Restore the last saved model checkpoints = model.load_checkpoint(args.resume, resume=True) lr_controller = checkpoints['lr_controller'] epoch = checkpoints['epoch'] step = checkpoints['step'] metric_dev_best = checkpoints['metric_dev_best'] # Resume between convert_to_sgd_epoch and convert_to_sgd_epoch + 1 if epoch == conf['convert_to_sgd_epoch'] + 1: model.set_optimizer(optimizer='sgd', learning_rate=args.learning_rate, weight_decay=float(conf['weight_decay'])) logger.info('========== Convert to SGD ==========') else: # Set save path save_path = mkdir_join( args.model, '_'.join(os.path.basename(args.train_set).split('.')[:-1]), dir_name) model.set_save_path(save_path) # avoid overwriting # Save the conf file as a yaml file save_config(vars(args), os.path.join(model.save_path, 'conf.yml')) if args.lm_fusion: save_config(args.lm_conf, os.path.join(model.save_path, 'conf_lm.yml')) # Save the nlsyms, dictionar, and wp_model if args.nlsyms: shutil.copy(args.nlsyms, os.path.join(model.save_path, 'nlsyms.txt')) for sub in ['', '_sub1', '_sub2', '_sub3']: if getattr(args, 'dict' + sub): shutil.copy( getattr(args, 'dict' + sub), os.path.join(model.save_path, 'dict' + sub + '.txt')) if getattr(args, 'unit' + sub) == 'wp': shutil.copy( getattr(args, 'wp_model' + sub), os.path.join(model.save_path, 'wp' + sub + '.model')) # Setting for logging logger = set_logger(os.path.join(model.save_path, 'train.log'), key='training') for k, v in sorted(vars(args).items(), key=lambda x: x[0]): logger.info('%s: %s' % (k, str(v))) # Count total parameters for n in sorted(list(model.num_params_dict.keys())): nparams = model.num_params_dict[n] logger.info("%s %d" % (n, nparams)) logger.info("Total %.2f M parameters" % (model.total_parameters / 1000000)) logger.info(model) # Initialize with pre-trained model's parameters if args.pretrained_model and os.path.isfile(args.pretrained_model): # Load a conf file conf_pt = load_config( os.path.join(os.path.dirname(args.pretrained_model), 'conf.yml')) # Merge conf with args for k, v in conf_pt.items(): setattr(args_pt, k, v) # Load the ASR model model_pt = Seq2seq(args_pt) model_pt.load_checkpoint(args.pretrained_model) # Overwrite parameters only_enc = (args.enc_n_layers != args_pt.enc_n_layers) or (args.unit != args_pt.unit) param_dict = dict(model_pt.named_parameters()) for n, p in model.named_parameters(): if n in param_dict.keys() and p.size() == param_dict[n].size(): if only_enc 
and 'enc' not in n: continue if args.lm_fusion_type == 'cache' and 'output' in n: continue p.data = param_dict[n].data logger.info('Overwrite %s' % n) # Set optimizer model.set_optimizer(optimizer=args.optimizer, learning_rate=float(args.learning_rate), weight_decay=float(args.weight_decay), transformer=True if args.enc_type == 'transformer' or args.dec_type == 'transformer' else False) epoch, step = 1, 1 metric_dev_best = 10000 # Set learning rate controller lr_controller = Controller( learning_rate=float(args.learning_rate), decay_type=args.decay_type, decay_start_epoch=args.decay_start_epoch, decay_rate=args.decay_rate, decay_patient_n_epochs=args.decay_patient_n_epochs, lower_better=True, best_value=metric_dev_best, model_size=args.d_model, warmup_start_learning_rate=args.warmup_start_learning_rate, warmup_n_steps=args.warmup_n_steps, factor=1) train_set.epoch = epoch - 1 # start from index:0 # GPU setting if args.n_gpus >= 1: model = CustomDataParallel(model, device_ids=list(range(0, args.n_gpus, 1)), deterministic=False, benchmark=True) model.cuda() logger.info('PID: %s' % os.getpid()) logger.info('USERNAME: %s' % os.uname()[1]) # Set process name if args.job_name: setproctitle(args.job_name) else: setproctitle(dir_name) # Set reporter reporter = Reporter(model.module.save_path, tensorboard=True) if args.mtl_per_batch: # NOTE: from easier to harder tasks tasks = [] if 1 - args.bwd_weight - args.ctc_weight - args.sub1_weight - args.sub2_weight - args.sub3_weight > 0: tasks += ['ys'] if args.bwd_weight > 0: tasks = ['ys.bwd'] + tasks if args.ctc_weight > 0: tasks = ['ys.ctc'] + tasks if args.lmobj_weight > 0: tasks = ['ys.lmobj'] + tasks if args.lm_fusion is not None and 'mtl' in args.lm_fusion_type: tasks = ['ys.lm'] + tasks for sub in ['sub1', 'sub2', 'sub3']: if getattr(args, 'train_set_' + sub): if getattr(args, sub + '_weight') - getattr( args, 'bwd_weight_' + sub) - getattr( args, 'ctc_weight_' + sub) > 0: tasks = ['ys_' + sub] + tasks if getattr(args, 'bwd_weight_' + sub) > 0: tasks = ['ys_' + sub + '.bwd'] + tasks if getattr(args, 'ctc_weight_' + sub) > 0: tasks = ['ys_' + sub + '.ctc'] + tasks if getattr(args, 'lmobj_weight_' + sub) > 0: tasks = ['ys_' + sub + '.lmobj'] + tasks else: tasks = ['all'] start_time_train = time.time() start_time_epoch = time.time() start_time_step = time.time() not_improved_n_epochs = 0 pbar_epoch = tqdm(total=len(train_set)) while True: # Compute loss in the training set batch_train, is_new_epoch = train_set.next() # Change tasks depending on task for task in tasks: model.module.optimizer.zero_grad() loss, reporter = model(batch_train, reporter=reporter, task=task) if len(model.device_ids) > 1: loss.backward(torch.ones(len(model.device_ids))) else: loss.backward() loss.detach() # Trancate the graph if args.clip_grad_norm > 0: torch.nn.utils.clip_grad_norm_(model.module.parameters(), args.clip_grad_norm) model.module.optimizer.step() loss_train = loss.item() del loss reporter.step(is_eval=False) # Update learning rate if args.decay_type == 'warmup' and step < args.warmup_n_steps: model.module.optimizer = lr_controller.warmup( model.module.optimizer, step=step) if step % args.print_step == 0: # Compute loss in the dev set batch_dev = dev_set.next()[0] # Change tasks depending on task for task in tasks: loss, reporter = model(batch_dev, reporter=reporter, task=task, is_eval=True) loss_dev = loss.item() del loss reporter.step(is_eval=True) duration_step = time.time() - start_time_step if args.input_type == 'speech': xlen = max(len(x) for x in 
batch_train['xs']) elif args.input_type == 'text': xlen = max(len(x) for x in batch_train['ys']) logger.info( "step:%d(ep:%.2f) loss:%.3f(%.3f)/lr:%.5f/bs:%d/xlen:%d (%.2f min)" % (step, train_set.epoch_detail, loss_train, loss_dev, lr_controller.lr, len( batch_train['utt_ids']), xlen, duration_step / 60)) start_time_step = time.time() step += args.n_gpus pbar_epoch.update(len(batch_train['utt_ids'])) # Save fugures of loss and accuracy if step % (args.print_step * 10) == 0: reporter.snapshot() # Save checkpoint and evaluate model per epoch if is_new_epoch: duration_epoch = time.time() - start_time_epoch logger.info('========== EPOCH:%d (%.2f min) ==========' % (epoch, duration_epoch / 60)) if epoch < args.eval_start_epoch: # Save the model model.module.save_checkpoint(model.module.save_path, lr_controller, epoch, step - 1, metric_dev_best) reporter._epoch += 1 # TODO(hirofumi): fix later else: start_time_eval = time.time() # dev if args.metric == 'edit_distance': if args.unit in ['word', 'word_char']: metric_dev = eval_word([model.module], dev_set, recog_params, epoch=epoch)[0] logger.info('WER (%s): %.2f %%' % (dev_set.set, metric_dev)) elif args.unit == 'wp': metric_dev, cer_dev = eval_wordpiece([model.module], dev_set, recog_params, epoch=epoch) logger.info('WER (%s): %.2f %%' % (dev_set.set, metric_dev)) logger.info('CER (%s): %.2f %%' % (dev_set.set, cer_dev)) elif 'char' in args.unit: metric_dev, cer_dev = eval_char([model.module], dev_set, recog_params, epoch=epoch) logger.info('WER (%s): %.2f %%' % (dev_set.set, metric_dev)) logger.info('CER (%s): %.2f %%' % (dev_set.set, cer_dev)) elif 'phone' in args.unit: metric_dev = eval_phone([model.module], dev_set, recog_params, epoch=epoch) logger.info('PER (%s): %.2f %%' % (dev_set.set, metric_dev)) elif args.metric == 'ppl': metric_dev = eval_ppl([model.module], dev_set, recog_params)[0] logger.info('PPL (%s): %.2f %%' % (dev_set.set, metric_dev)) elif args.metric == 'loss': metric_dev = eval_ppl([model.module], dev_set, recog_params)[1] logger.info('Loss (%s): %.2f %%' % (dev_set.set, metric_dev)) else: raise NotImplementedError(args.metric) reporter.epoch(metric_dev) # Update learning rate model.module.optimizer = lr_controller.decay( model.module.optimizer, epoch=epoch, value=metric_dev) if metric_dev < metric_dev_best: metric_dev_best = metric_dev not_improved_n_epochs = 0 logger.info('||||| Best Score |||||') # Save the model model.module.save_checkpoint(model.module.save_path, lr_controller, epoch, step - 1, metric_dev_best) # test for s in eval_sets: if args.metric == 'edit_distance': if args.unit in ['word', 'word_char']: wer_test = eval_word([model.module], s, recog_params, epoch=epoch)[0] logger.info('WER (%s): %.2f %%' % (s.set, wer_test)) elif args.unit == 'wp': wer_test, cer_test = eval_wordpiece( [model.module], s, recog_params, epoch=epoch) logger.info('WER (%s): %.2f %%' % (s.set, wer_test)) logger.info('CER (%s): %.2f %%' % (s.set, cer_test)) elif 'char' in args.unit: wer_test, cer_test = eval_char([model.module], s, recog_params, epoch=epoch) logger.info('WER (%s): %.2f %%' % (s.set, wer_test)) logger.info('CER (%s): %.2f %%' % (s.set, cer_test)) elif 'phone' in args.unit: per_test = eval_phone([model.module], s, recog_params, epoch=epoch) logger.info('PER (%s): %.2f %%' % (s.set, per_test)) elif args.metric == 'ppl': ppl_test = eval_ppl([model.module], s, recog_params)[0] logger.info('PPL (%s): %.2f %%' % (s.set, ppl_test)) elif args.metric == 'loss': loss_test = eval_ppl([model.module], s, recog_params)[1] 
                            logger.info('Loss (%s): %.2f %%' % (s.set, loss_test))
                        else:
                            raise NotImplementedError(args.metric)
                else:
                    not_improved_n_epochs += 1

                # start scheduled sampling
                if args.ss_prob > 0:
                    model.module.scheduled_sampling_trigger()

                duration_eval = time.time() - start_time_eval
                logger.info('Evaluation time: %.2f min' % (duration_eval / 60))

                # Early stopping
                if not_improved_n_epochs == args.not_improved_patient_n_epochs:
                    break

                # Convert to fine-tuning stage
                if epoch == args.convert_to_sgd_epoch:
                    model.module.set_optimizer(
                        'sgd',
                        learning_rate=args.learning_rate,
                        weight_decay=float(args.weight_decay))
                    lr_controller = Controller(
                        learning_rate=args.learning_rate,
                        decay_type='epoch',
                        decay_start_epoch=epoch,
                        decay_rate=0.5,
                        lower_better=True)
                    logger.info('========== Convert to SGD ==========')

            pbar_epoch = tqdm(total=len(train_set))

            if epoch == args.n_epochs:
                break

            start_time_step = time.time()
            start_time_epoch = time.time()
            epoch += 1

    duration_train = time.time() - start_time_train
    logger.info('Total time: %.2f hour' % (duration_train / 3600))

    if reporter.tensorboard:
        reporter.tf_writer.close()
    pbar_epoch.close()

    return model.module.save_path
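
# Illustrative sketch (not part of the original script): how the overall time-
# subsampling factor computed at the top of main() falls out of the CLI strings.
# The time-axis pooling sizes in --conv_poolings (e.g. "(2,2)_(2,2)") and the
# per-layer factors in --subsample (e.g. "1_2_2_1_1") are simply multiplied.
# The helper name is hypothetical; the parsing mirrors the code above.
import numpy as np


def total_subsample_factor(conv_poolings, subsample):
    factor = 1
    if conv_poolings:
        for p in conv_poolings.split('_'):
            # "(2,2)" -> time-axis pooling size 2
            p_time = int(p.split(',')[1].replace(')', ''))
            if p_time > 1:
                factor *= p_time
    factor *= int(np.prod([int(s) for s in subsample.split('_')]))
    return factor


# Example: two (2,2) poolings and no extra encoder subsampling -> factor 4,
# i.e. one encoder output frame per 4 input frames.
assert total_subsample_factor('(2,2)_(2,2)', '1_1_1_1_1') == 4
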
def main(): # Load a config file config = load_config(os.path.join(args.model, 'config.yml')) decode_params = vars(args) # Merge config with args for k, v in config.items(): if not hasattr(args, k): setattr(args, k, v) # Setting for logging logger = set_logger(os.path.join(args.plot_dir, 'plot.log'), key='decoding') for i, set in enumerate(args.eval_sets): # Load dataset eval_set = Dataset(csv_path=set, dict_path=os.path.join(args.model, 'dict.txt'), dict_path_sub=os.path.join(args.model, 'dict_sub.txt') if os.path.isfile( os.path.join(args.model, 'dict_sub.txt')) else None, wp_model=os.path.join(args.model, 'wp.model'), unit=args.unit, batch_size=args.batch_size, max_num_frames=args.max_num_frames, min_num_frames=args.min_num_frames, is_test=True) if i == 0: args.vocab = eval_set.vocab args.vocab_sub = eval_set.vocab_sub args.input_dim = eval_set.input_dim # TODO(hirofumi): For cold fusion args.rnnlm_cold_fusion = None args.rnnlm_init = None # Load the ASR model model = Seq2seq(args) epoch, _, _, _ = model.load_checkpoint(args.model, epoch=args.epoch) model.save_path = args.model # For shallow fusion if args.rnnlm_cold_fusion is None and args.rnnlm is not None and args.rnnlm_weight > 0: # Load a RNNLM config file config_rnnlm = load_config(os.path.join(args.rnnlm, 'config.yml')) # Merge config with args args_rnnlm = argparse.Namespace() for k, v in config_rnnlm.items(): setattr(args_rnnlm, k, v) assert args.unit == args_rnnlm.unit args_rnnlm.vocab = eval_set.vocab # Load the pre-trianed RNNLM rnnlm = RNNLM(args_rnnlm) rnnlm.load_checkpoint(args.rnnlm, epoch=-1) if args_rnnlm.backward: model.rnnlm_bwd_0 = rnnlm else: model.rnnlm_fwd_0 = rnnlm logger.info('RNNLM path: %s' % args.rnnlm) logger.info('RNNLM weight: %.3f' % args.rnnlm_weight) logger.info('RNNLM backward: %s' % str(config_rnnlm['backward'])) # GPU setting model.cuda() logger.info('beam width: %d' % args.beam_width) logger.info('length penalty: %.3f' % args.length_penalty) logger.info('coverage penalty: %.3f' % args.coverage_penalty) logger.info('coverage threshold: %.3f' % args.coverage_threshold) logger.info('epoch: %d' % (epoch - 1)) save_path = mkdir_join(args.plot_dir, 'att_weights') # Clean directory if save_path is not None and os.path.isdir(save_path): shutil.rmtree(save_path) os.mkdir(save_path) while True: batch, is_new_epoch = eval_set.next(decode_params['batch_size']) best_hyps, aws, perm_idx = model.decode(batch['xs'], decode_params, exclude_eos=False) ys = [batch['ys'][i] for i in perm_idx] if model.bwd_weight > 0.5: # Reverse the order best_hyps = [hyp[::-1] for hyp in best_hyps] aws = [aw[::-1] for aw in aws] for b in range(len(batch['xs'])): if args.unit == 'word': token_list = eval_set.idx2word(best_hyps[b], return_list=True) if args.unit == 'wp': token_list = eval_set.idx2wp(best_hyps[b], return_list=True) elif args.unit == 'char': token_list = eval_set.idx2char(best_hyps[b], return_list=True) elif args.unit == 'phone': token_list = eval_set.idx2phone(best_hyps[b], return_list=True) else: raise NotImplementedError(args.unit) token_list = [unicode(t, 'utf-8') for t in token_list] speaker = '_'.join(batch['utt_ids'][b].replace('-', '_').split('_')[:-2]) # error check assert len(batch['xs'][b]) <= 2000 plot_attention_weights(aws[b][:len(token_list)], label_list=token_list, spectrogram=batch['xs'][b][:, :eval_set.input_dim] if args.input_type == 'speech' else None, save_path=mkdir_join(save_path, speaker, batch['utt_ids'][b] + '.png'), figsize=(20, 8)) ref = ys[b] if model.bwd_weight > 0.5: hyp = ' 
'.join(token_list[::-1])
                else:
                    hyp = ' '.join(token_list)

                logger.info('utt-id: %s' % batch['utt_ids'][b])
                logger.info('Ref: %s' % ref.lower())
                logger.info('Hyp: %s' % hyp)
                logger.info('-' * 50)

            if is_new_epoch:
                break
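
# Illustrative sketch (hypothetical helper, not the repository's
# plot_attention_weights): render an (n_output_tokens x n_encoder_frames)
# attention matrix as a heatmap with the decoded tokens on the y-axis, which is
# essentially what the plotting loop above saves per utterance.
import matplotlib
matplotlib.use('Agg')  # write image files without a display
import matplotlib.pyplot as plt
import numpy as np


def save_attention_heatmap(aw, token_list, save_path, figsize=(20, 8)):
    """aw: (len(token_list), n_frames) attention weights."""
    fig, ax = plt.subplots(figsize=figsize)
    ax.imshow(aw, aspect='auto', origin='lower', interpolation='nearest')
    ax.set_yticks(np.arange(len(token_list)))
    ax.set_yticklabels(token_list)
    ax.set_xlabel('encoder frames')
    ax.set_ylabel('output tokens')
    fig.savefig(save_path, bbox_inches='tight')
    plt.close(fig)


# Usage: save_attention_heatmap(np.random.rand(5, 100), list('abcde'), 'att.png')
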
def main(): args = parse() # Load a conf file dir_name = os.path.dirname(args.recog_model[0]) conf = load_config(os.path.join(dir_name, 'conf.yml')) # Overwrite conf for k, v in conf.items(): if 'recog' not in k: setattr(args, k, v) recog_params = vars(args) # Setting for logging if os.path.isfile(os.path.join(args.recog_dir, 'decode.log')): os.remove(os.path.join(args.recog_dir, 'decode.log')) logger = set_logger(os.path.join(args.recog_dir, 'decode.log'), key='decoding') skip_thought = 'skip' in args.enc_type wer_avg, cer_avg, per_avg = 0, 0, 0 ppl_avg, loss_avg = 0, 0 for i, s in enumerate(args.recog_sets): # Load dataset dataset = Dataset( corpus=args.corpus, tsv_path=s, dict_path=os.path.join(dir_name, 'dict.txt'), dict_path_sub1=os.path.join(dir_name, 'dict_sub1.txt') if os.path.isfile(os.path.join(dir_name, 'dict_sub1.txt')) else False, dict_path_sub2=os.path.join(dir_name, 'dict_sub2.txt') if os.path.isfile(os.path.join(dir_name, 'dict_sub2.txt')) else False, nlsyms=os.path.join(dir_name, 'nlsyms.txt'), wp_model=os.path.join(dir_name, 'wp.model'), wp_model_sub1=os.path.join(dir_name, 'wp_sub1.model'), wp_model_sub2=os.path.join(dir_name, 'wp_sub2.model'), unit=args.unit, unit_sub1=args.unit_sub1, unit_sub2=args.unit_sub2, batch_size=args.recog_batch_size, skip_thought=skip_thought, is_test=True) if i == 0: # Load the ASR model if skip_thought: model = SkipThought(args) else: model = Seq2seq(args) model, checkpoint = load_checkpoint(model, args.recog_model[0]) epoch = checkpoint['epoch'] model.save_path = dir_name # ensemble (different models) ensemble_models = [model] if len(args.recog_model) > 1: for recog_model_e in args.recog_model[1:]: # Load a conf file conf_e = load_config( os.path.join(os.path.dirname(recog_model_e), 'conf.yml')) # Overwrite conf args_e = copy.deepcopy(args) for k, v in conf_e.items(): if 'recog' not in k: setattr(args_e, k, v) model_e = Seq2seq(args_e) model_e, _ = load_checkpoint(model_e, recog_model_e) model_e.cuda() ensemble_models += [model_e] # For shallow fusion if not args.lm_fusion: if args.recog_lm is not None and args.recog_lm_weight > 0: # Load a LM conf file conf_lm = load_config( os.path.join(os.path.dirname(args.recog_lm), 'conf.yml')) # Merge conf with args args_lm = argparse.Namespace() for k, v in conf_lm.items(): setattr(args_lm, k, v) # Load the pre-trianed LM if args_lm.lm_type == 'gated_cnn': lm = GatedConvLM(args_lm) else: lm = RNNLM(args_lm) lm, _ = load_checkpoint(lm, args.recog_lm) if args_lm.backward: model.lm_bwd = lm else: model.lm_fwd = lm if args.recog_lm_bwd is not None and args.recog_lm_weight > 0 \ and (args.recog_fwd_bwd_attention or args.recog_reverse_lm_rescoring): # Load a LM conf file conf_lm = load_config( os.path.join(args.recog_lm_bwd, 'conf.yml')) # Merge conf with args args_lm_bwd = argparse.Namespace() for k, v in conf_lm.items(): setattr(args_lm_bwd, k, v) # Load the pre-trianed LM if args_lm_bwd.lm_type == 'gated_cnn': lm_bwd = GatedConvLM(args_lm_bwd) else: lm_bwd = RNNLM(args_lm_bwd) lm_bwd, _ = load_checkpoint(lm_bwd, args.recog_lm_bwd) model.lm_bwd = lm_bwd if not args.recog_unit: args.recog_unit = args.unit logger.info('recog unit: %s' % args.recog_unit) logger.info('recog metric: %s' % args.recog_metric) logger.info('recog oracle: %s' % args.recog_oracle) logger.info('epoch: %d' % (epoch - 1)) logger.info('batch size: %d' % args.recog_batch_size) logger.info('beam width: %d' % args.recog_beam_width) logger.info('min length ratio: %.3f' % args.recog_min_len_ratio) logger.info('max length ratio: %.3f' % 
args.recog_max_len_ratio) logger.info('length penalty: %.3f' % args.recog_length_penalty) logger.info('coverage penalty: %.3f' % args.recog_coverage_penalty) logger.info('coverage threshold: %.3f' % args.recog_coverage_threshold) logger.info('CTC weight: %.3f' % args.recog_ctc_weight) logger.info('LM path: %s' % args.recog_lm) logger.info('LM path (bwd): %s' % args.recog_lm_bwd) logger.info('LM weight: %.3f' % args.recog_lm_weight) logger.info('GNMT: %s' % args.recog_gnmt_decoding) logger.info('forward-backward attention: %s' % args.recog_fwd_bwd_attention) logger.info('reverse LM rescoring: %s' % args.recog_reverse_lm_rescoring) logger.info('resolving UNK: %s' % args.recog_resolving_unk) logger.info('ensemble: %d' % (len(ensemble_models))) logger.info('ASR decoder state carry over: %s' % (args.recog_asr_state_carry_over)) logger.info('LM state carry over: %s' % (args.recog_lm_state_carry_over)) logger.info('cache size: %d' % (args.recog_n_caches)) logger.info('cache type: %s' % (args.recog_cache_type)) logger.info('cache word frequency threshold: %s' % (args.recog_cache_word_freq)) logger.info('cache theta (speech): %.3f' % (args.recog_cache_theta_speech)) logger.info('cache lambda (speech): %.3f' % (args.recog_cache_lambda_speech)) logger.info('cache theta (lm): %.3f' % (args.recog_cache_theta_lm)) logger.info('cache lambda (lm): %.3f' % (args.recog_cache_lambda_lm)) # GPU setting model.cuda() start_time = time.time() if args.recog_metric == 'edit_distance': if args.recog_unit in ['word', 'word_char']: wer, cer, _ = eval_word(ensemble_models, dataset, recog_params, epoch=epoch - 1, recog_dir=args.recog_dir, progressbar=True) wer_avg += wer cer_avg += cer elif args.recog_unit == 'wp': wer, cer = eval_wordpiece(ensemble_models, dataset, recog_params, epoch=epoch - 1, recog_dir=args.recog_dir, progressbar=True) wer_avg += wer cer_avg += cer elif 'char' in args.recog_unit: wer, cer = eval_char(ensemble_models, dataset, recog_params, epoch=epoch - 1, recog_dir=args.recog_dir, progressbar=True, task_idx=0) # task_idx=1 if args.recog_unit and 'char' in args.recog_unit else 0) wer_avg += wer cer_avg += cer elif 'phone' in args.recog_unit: per = eval_phone(ensemble_models, dataset, recog_params, epoch=epoch - 1, recog_dir=args.recog_dir, progressbar=True)[0] per_avg += per else: raise ValueError(args.recog_unit) elif args.recog_metric == 'acc': raise NotImplementedError elif args.recog_metric in ['ppl', 'loss']: ppl, loss = eval_ppl(ensemble_models, dataset, recog_params=recog_params, progressbar=True) ppl_avg += ppl loss_avg += loss elif args.recog_metric == 'bleu': raise NotImplementedError else: raise NotImplementedError logger.info('Elasped time: %.2f [sec]:' % (time.time() - start_time)) if args.recog_metric == 'edit_distance': if 'phone' in args.recog_unit: logger.info('PER (avg.): %.2f %%\n' % (per_avg / len(args.recog_sets))) else: logger.info('WER / CER (avg.): %.2f / %.2f %%\n' % (wer_avg / len(args.recog_sets), cer_avg / len(args.recog_sets))) elif args.recog_metric in ['ppl', 'loss']: logger.info('PPL (avg.): %.2f\n' % (ppl_avg / len(args.recog_sets))) print('PPL (avg.): %.2f' % (ppl_avg / len(args.recog_sets))) logger.info('Loss (avg.): %.2f\n' % (loss_avg / len(args.recog_sets))) print('Loss (avg.): %.2f' % (loss_avg / len(args.recog_sets)))
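
# Illustrative sketch (assumption about the 'ppl'/'loss' metrics evaluated via
# eval_ppl above): token-level perplexity is the exponential of the average
# per-token negative log-likelihood, so a set with average loss L nats per token
# has PPL = exp(L).  The helper name is hypothetical.
import math


def perplexity(total_nll, n_tokens):
    """total_nll: summed negative log-likelihood (nats) over the evaluation set."""
    return math.exp(total_nll / max(n_tokens, 1))


# Example: an average NLL of ~2.3026 nats per token gives a perplexity of ~10.
assert abs(perplexity(2302.585, 1000) - 10.0) < 1e-3
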
def main(): # Load a config file config = load_config(os.path.join(args.model, 'config.yml')) decode_params = vars(args) # Merge config with args for k, v in config.items(): if not hasattr(args, k): setattr(args, k, v) # Setting for logging logger = set_logger(os.path.join(args.model, 'decode.log'), key='decoding') for i, set in enumerate(args.eval_sets): # Load dataset eval_set = Dataset( csv_path=set, dict_path=os.path.join(args.model, 'dict.txt'), dict_path_sub=os.path.join(args.model, 'dict_sub.txt') if os.path.isfile(os.path.join(args.model, 'dict_sub.txt')) else None, label_type=args.label_type, batch_size=args.batch_size, max_epoch=args.num_epochs, max_num_frames=args.max_num_frames, min_num_frames=args.min_num_frames, is_test=False) if i == 0: args.num_classes = eval_set.num_classes args.input_dim = eval_set.input_dim args.num_classes_sub = eval_set.num_classes_sub # TODO(hirofumi): For cold fusion args.rnnlm_cf = None args.rnnlm_init = None # Load the ASR model model = Seq2seq(args) # Restore the saved parameters epoch, _, _, _ = model.load_checkpoint(args.model, epoch=args.epoch) model.save_path = args.model # For shallow fusion if args.rnnlm_cf is None and args.rnnlm is not None and args.rnnlm_weight > 0: # Load a RNNLM config file config_rnnlm = load_config( os.path.join(args.rnnlm, 'config.yml')) # Merge config with args args_rnnlm = argparse.Namespace() for k, v in config_rnnlm.items(): setattr(args_rnnlm, k, v) assert args.label_type == args_rnnlm.label_type args_rnnlm.num_classes = eval_set.num_classes # Load the pre-trianed RNNLM rnnlm = RNNLM(args_rnnlm) rnnlm.load_checkpoint(args.rnnlm, epoch=-1) if args_rnnlm.backward: model.rnnlm_bwd_0 = rnnlm else: model.rnnlm_fwd_0 = rnnlm logger.info('RNNLM path: %s' % args.rnnlm) logger.info('RNNLM weight: %.3f' % args.rnnlm_weight) logger.info('RNNLM backward: %s' % str(config_rnnlm['backward'])) # GPU setting model.set_cuda(deterministic=False, benchmark=True) logger.info('beam width: %d' % args.beam_width) logger.info('length penalty: %.3f' % args.length_penalty) logger.info('coverage penalty: %.3f' % args.coverage_penalty) logger.info('coverage threshold: %.3f' % args.coverage_threshold) logger.info('epoch: %d' % (epoch - 1)) save_path = mkdir_join(args.model, 'att_weights') # Clean directory if save_path is not None and os.path.isdir(save_path): shutil.rmtree(save_path) os.mkdir(save_path) while True: batch, is_new_epoch = eval_set.next(decode_params['batch_size']) best_hyps, aw, perm_idx = model.decode(batch['xs'], decode_params, exclude_eos=False) ys = [batch['ys'][i] for i in perm_idx] for b in range(len(batch['xs'])): if args.label_type in ['word', 'wordpiece']: token_list = eval_set.idx2word(best_hyps[b], return_list=True) elif args.label_type == 'char': token_list = eval_set.idx2char(best_hyps[b], return_list=True) elif args.label_type == 'phone': token_list = eval_set.idx2phone(best_hyps[b], return_list=True) else: raise NotImplementedError() token_list = [unicode(t, 'utf-8') for t in token_list] speaker = '_'.join(batch['utt_ids'][b].replace( '-', '_').split('_')[:-2]) # error check assert len(batch['xs'][b]) <= 2000 plot_attention_weights( aw[b][:len(token_list)], label_list=token_list, spectrogram=batch['xs'][b][:, :eval_set.input_dim] if args.input_type == 'speech' else None, save_path=mkdir_join(save_path, speaker, batch['utt_ids'][b] + '.png'), figsize=(20, 8)) # Reference if eval_set.is_test: text_ref = ys[b] else: if args.label_type in ['word', 'wordpiece']: text_ref = eval_set.idx2word(ys[b]) if args.label_type 
in ['word', 'wordpiece']:
                    token_list = eval_set.idx2word(ys[b])
                elif args.label_type == 'char':
                    token_list = eval_set.idx2char(ys[b])
                elif args.label_type == 'phone':
                    token_list = eval_set.idx2phone(ys[b])

            # Hypothesis
            text_hyp = ' '.join(token_list)

            sys.stdout = open(
                os.path.join(save_path, speaker, batch['utt_ids'][b] + '.txt'), 'w')

            ler = wer_align(
                ref=text_ref.split(' '),
                hyp=text_hyp.encode('utf-8').split(' '),
                normalize=True,
                double_byte=False)[0]  # TODO(hirofumi): add corpus to args
            print('\nLER: %.3f %%\n\n' % ler)

        if is_new_epoch:
            break
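
# Illustrative sketch (hypothetical helper, not the repository's wer_align):
# the error rate logged above is an edit (Levenshtein) distance between the
# reference and hypothesis word sequences, normalised by the reference length.
def word_error_rate(ref, hyp):
    r, h = ref.split(), hyp.split()
    # dp[i][j] = edit distance between r[:i] and h[:j]
    dp = [[0] * (len(h) + 1) for _ in range(len(r) + 1)]
    for i in range(len(r) + 1):
        dp[i][0] = i
    for j in range(len(h) + 1):
        dp[0][j] = j
    for i in range(1, len(r) + 1):
        for j in range(1, len(h) + 1):
            sub = dp[i - 1][j - 1] + (r[i - 1] != h[j - 1])
            dp[i][j] = min(sub, dp[i - 1][j] + 1, dp[i][j - 1] + 1)
    return 100.0 * dp[len(r)][len(h)] / max(len(r), 1)


# Example: one substitution out of three reference words -> 33.3 % WER.
assert round(word_error_rate('a b c', 'a x c'), 1) == 33.3
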
def main(): # Load a config file config = load_config(os.path.join(args.model, 'config.yml')) decode_params = vars(args) # Merge config with args for k, v in config.items(): if not hasattr(args, k): setattr(args, k, v) # Setting for logging logger = set_logger(os.path.join(args.model, 'decode.log'), key='decoding') wer_mean, cer_mean, per_mean = 0, 0, 0 for i, set in enumerate(args.eval_sets): # Load dataset eval_set = Dataset( csv_path=set, dict_path=os.path.join(args.model, 'dict.txt'), dict_path_sub=os.path.join(args.model, 'dict_sub.txt') if os.path.isfile(os.path.join(args.model, 'dict_sub.txt')) else None, label_type=args.label_type, batch_size=args.batch_size, max_epoch=args.num_epochs, is_test=True) if i == 0: args.num_classes = eval_set.num_classes args.input_dim = eval_set.input_dim args.num_classes_sub = eval_set.num_classes_sub # For cold fusion # if args.rnnlm_cf: # # Load a RNNLM config file # config['rnnlm_config'] = load_config(os.path.join(args.model, 'config_rnnlm.yml')) # # assert args.label_type == config['rnnlm_config']['label_type'] # rnnlm_args.num_classes = eval_set.num_classes # logger.info('RNNLM path: %s' % config['rnnlm']) # logger.info('RNNLM weight: %.3f' % args.rnnlm_weight) # else: # pass args.rnnlm_cf = None args.rnnlm_init = None # Load the ASR model model = Seq2seq(args) # Restore the saved parameters epoch, _, _, _ = model.load_checkpoint(args.model, epoch=args.epoch) model.save_path = args.model # For shallow fusion if args.rnnlm_cf is None and args.rnnlm is not None and args.rnnlm_weight > 0: # Load a RNNLM config file config_rnnlm = load_config( os.path.join(args.rnnlm, 'config.yml')) # Merge config with args args_rnnlm = argparse.Namespace() for k, v in config_rnnlm.items(): setattr(args_rnnlm, k, v) assert args.label_type == args_rnnlm.label_type args_rnnlm.num_classes = eval_set.num_classes # Load the pre-trianed RNNLM rnnlm = RNNLM(args_rnnlm) rnnlm.load_checkpoint(args.rnnlm, epoch=-1) if args_rnnlm.backward: model.rnnlm_bwd_0 = rnnlm else: model.rnnlm_fwd_0 = rnnlm logger.info('RNNLM path: %s' % args.rnnlm) logger.info('RNNLM weight: %.3f' % args.rnnlm_weight) logger.info('RNNLM backward: %s' % str(config_rnnlm['backward'])) # GPU setting model.set_cuda(deterministic=False, benchmark=True) logger.info('beam width: %d' % args.beam_width) logger.info('length penalty: %.3f' % args.length_penalty) logger.info('coverage penalty: %.3f' % args.coverage_penalty) logger.info('coverage threshold: %.3f' % args.coverage_threshold) logger.info('epoch: %d' % (epoch - 1)) start_time = time.time() if args.label_type == 'word': wer, _, _, _, decode_dir = eval_word([model], eval_set, decode_params, epoch=epoch - 1, progressbar=True) wer_mean += wer logger.info(' WER (%s): %.3f %%' % (eval_set.set, wer)) elif args.label_type == 'wordpiece': wer, _, _, _, decode_dir = eval_wordpiece([model], eval_set, decode_params, os.path.join( args.model, 'wp.model'), epoch=epoch - 1, progressbar=True) wer_mean += wer logger.info(' WER (%s): %.3f %%' % (eval_set.set, wer)) elif 'char' in args.label_type: (wer, _, _, _), (cer, _, _, _), decode_dir = eval_char([model], eval_set, decode_params, epoch=epoch - 1, progressbar=True) wer_mean += wer cer_mean += cer logger.info(' WER / CER (%s): %.3f / %.3f %%' % (eval_set.set, wer, cer)) elif 'phone' in args.label_type: per, _, _, _, decode_dir = eval_phone([model], eval_set, decode_params, epoch=epoch - 1, progressbar=True) per_mean += per logger.info(' PER (%s): %.3f %%' % (eval_set.set, per)) else: raise ValueError(args.label_type) 
        logger.info('Elapsed time: %.2f [sec]' % (time.time() - start_time))

    if args.label_type == 'word':
        logger.info(' WER (mean): %.3f %%\n' % (wer_mean / len(args.eval_sets)))
    elif args.label_type == 'wordpiece':
        logger.info(' WER (mean): %.3f %%\n' % (wer_mean / len(args.eval_sets)))
    elif 'char' in args.label_type:
        logger.info(' WER / CER (mean): %.3f / %.3f %%\n' %
                    (wer_mean / len(args.eval_sets), cer_mean / len(args.eval_sets)))
    elif 'phone' in args.label_type:
        logger.info(' PER (mean): %.3f %%\n' % (per_mean / len(args.eval_sets)))

    print(decode_dir)
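
# Illustrative sketch (assumption, not the repository's decoder internals):
# shallow fusion, as enabled above via --rnnlm / --rnnlm_weight, interpolates
# the ASR and LM scores in log space at every beam-search step:
#     score(y) = log p_asr(y | x, y_<t) + lm_weight * log p_lm(y | y_<t)
import numpy as np


def shallow_fusion_step(asr_log_probs, lm_log_probs, lm_weight):
    """Both inputs are per-step log-probability vectors over the same vocabulary."""
    return asr_log_probs + lm_weight * lm_log_probs


# Usage: pick the best next token for one beam hypothesis.
scores = shallow_fusion_step(np.log([0.6, 0.3, 0.1]),
                             np.log([0.2, 0.7, 0.1]), lm_weight=0.3)
best_token = int(np.argmax(scores))
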
def main(): # Load a config file if args.resume_model is None: config = load_config(args.config) else: # Restart from the last checkpoint config = load_config(os.path.join(args.resume_model, 'config.yml')) # Check differences between args and yaml comfiguraiton for k, v in vars(args).items(): if k not in config.keys(): warnings.warn("key %s is automatically set to %s" % (k, str(v))) # Merge config with args for k, v in config.items(): setattr(args, k, v) # Automatically reduce batch size in multi-GPU setting if args.ngpus > 1: args.batch_size -= 10 args.print_step //= args.ngpus subsample_factor = 1 subsample_factor_sub = 1 for p in args.conv_poolings: if len(p) > 0: subsample_factor *= p[0] if args.train_set_sub is not None: subsample_factor_sub = subsample_factor * (2**sum( args.subsample[:args.enc_num_layers_sub - 1])) subsample_factor *= 2**sum(args.subsample) # Load dataset train_set = Dataset(csv_path=args.train_set, dict_path=args.dict, label_type=args.label_type, batch_size=args.batch_size * args.ngpus, max_epoch=args.num_epochs, max_num_frames=args.max_num_frames, min_num_frames=args.min_num_frames, sort_by_input_length=True, short2long=True, sort_stop_epoch=args.sort_stop_epoch, dynamic_batching=True, use_ctc=args.ctc_weight > 0, subsample_factor=subsample_factor, csv_path_sub=args.train_set_sub, dict_path_sub=args.dict_sub, label_type_sub=args.label_type_sub, use_ctc_sub=args.ctc_weight_sub > 0, subsample_factor_sub=subsample_factor_sub, skip_speech=(args.input_type != 'speech')) dev_set = Dataset(csv_path=args.dev_set, dict_path=args.dict, label_type=args.label_type, batch_size=args.batch_size * args.ngpus, max_epoch=args.num_epochs, max_num_frames=args.max_num_frames, min_num_frames=args.min_num_frames, shuffle=True, use_ctc=args.ctc_weight > 0, subsample_factor=subsample_factor, csv_path_sub=args.dev_set_sub, dict_path_sub=args.dict_sub, label_type_sub=args.label_type_sub, use_ctc_sub=args.ctc_weight_sub > 0, subsample_factor_sub=subsample_factor_sub, skip_speech=(args.input_type != 'speech')) eval_sets = [] for set in args.eval_sets: eval_sets += [ Dataset(csv_path=set, dict_path=args.dict, label_type=args.label_type, batch_size=1, max_epoch=args.num_epochs, is_test=True, skip_speech=(args.input_type != 'speech')) ] args.num_classes = train_set.num_classes args.input_dim = train_set.input_dim args.num_classes_sub = train_set.num_classes_sub # Load a RNNLM config file for cold fusion & RNNLM initialization # if config['rnnlm_cf']: # if args.model is not None: # config['rnnlm_config_cold_fusion'] = load_config( # os.path.join(config['rnnlm_cf'], 'config.yml'), is_eval=True) # elif args.resume_model is not None: # config = load_config(os.path.join( # args.resume_model, 'config_rnnlm_cf.yml')) # assert args.label_type == config['rnnlm_config_cold_fusion']['label_type'] # config['rnnlm_config_cold_fusion']['num_classes'] = train_set.num_classes args.rnnlm_cf = None args.rnnlm_init = None # Model setting model = Seq2seq(args) model.name = args.enc_type if len(args.conv_channels) > 0: tmp = model.name model.name = 'conv' + str(len(args.conv_channels)) + 'L' if args.conv_batch_norm: model.name += 'bn' model.name += tmp model.name += str(args.enc_num_units) + 'H' model.name += str(args.enc_num_projs) + 'P' model.name += str(args.enc_num_layers) + 'L' model.name += '_subsample' + str(subsample_factor) model.name += '_' + args.dec_type model.name += str(args.dec_num_units) + 'H' # model.name += str(args.dec_num_projs) + 'P' model.name += str(args.dec_num_layers) + 'L' model.name += '_' 
+ args.att_type if args.att_num_heads > 1: model.name += '_head' + str(args.att_num_heads) model.name += '_' + args.optimizer model.name += '_lr' + str(args.learning_rate) model.name += '_bs' + str(args.batch_size) model.name += '_ss' + str(args.ss_prob) model.name += '_ls' + str(args.lsm_prob) if args.ctc_weight > 0: model.name += '_ctc' + str(args.ctc_weight) if args.bwd_weight > 0: model.name += '_bwd' + str(args.bwd_weight) if args.main_task_weight < 1: model.name += '_main' + str(args.main_task_weight) if args.ctc_weight_sub > 0: model.name += '_ctcsub' + str(args.ctc_weight_sub * (1 - args.main_task_weight)) else: model.name += '_attsub' + str(1 - args.main_task_weight) if args.resume_model is None: # Load pre-trained RNNLM # if config['rnnlm_cf']: # rnnlm = RNNLM(args) # rnnlm.load_checkpoint(save_path=config['rnnlm_cf'], epoch=-1) # rnnlm.flatten_parameters() # # # Fix RNNLM parameters # for param in rnnlm.parameters(): # param.requires_grad = False # # # Set pre-trained parameters # if config['rnnlm_config_cold_fusion']['backward']: # model.dec_0_bwd.rnnlm = rnnlm # else: # model.dec_0_fwd.rnnlm = rnnlm # TODO(hirofumi): 最初にRNNLMのモデルをコピー # Set save path save_path = mkdir_join( args.model, '_'.join(os.path.basename(args.train_set).split('.')[:-1]), model.name) model.set_save_path(save_path) # avoid overwriting # Save the config file as a yaml file save_config(vars(args), model.save_path) # Save the dictionary & wp_model shutil.copy(args.dict, os.path.join(save_path, 'dict.txt')) if args.dict_sub is not None: shutil.copy(args.dict_sub, os.path.join(save_path, 'dict_sub.txt')) if args.label_type == 'wordpiece': shutil.copy(args.wp_model, os.path.join(save_path, 'wp.model')) # Setting for logging logger = set_logger(os.path.join(model.save_path, 'train.log'), key='training') for k, v in sorted(vars(args).items(), key=lambda x: x[0]): logger.info('%s: %s' % (k, str(v))) # if os.path.isdir(args.pretrained_model): # # NOTE: Start training from the pre-trained model # # This is defferent from resuming training # model.load_checkpoint(args.pretrained_model, epoch=-1, # load_pretrained_model=True) # Count total parameters for name in sorted(list(model.num_params_dict.keys())): num_params = model.num_params_dict[name] logger.info("%s %d" % (name, num_params)) logger.info("Total %.2f M parameters" % (model.total_parameters / 1000000)) # Set optimizer model.set_optimizer(optimizer=args.optimizer, learning_rate_init=float(args.learning_rate), weight_decay=float(args.weight_decay), clip_grad_norm=args.clip_grad_norm, lr_schedule=False, factor=args.decay_rate, patience_epoch=args.decay_patient_epoch) epoch, step = 1, 0 learning_rate = float(args.learning_rate) metric_dev_best = 10000 # NOTE: Restart from the last checkpoint # elif args.resume_model is not None: # # Set save path # model.save_path = args.resume_model # # # Setting for logging # logger = set_logger(os.path.join(model.save_path, 'train.log'), key='training') # # # Set optimizer # model.set_optimizer( # optimizer=config['optimizer'], # learning_rate_init=float(config['learning_rate']), # on-the-fly # weight_decay=float(config['weight_decay']), # clip_grad_norm=config['clip_grad_norm'], # lr_schedule=False, # factor=config['decay_rate'], # patience_epoch=config['decay_patient_epoch']) # # # Restore the last saved model # epoch, step, learning_rate, metric_dev_best = model.load_checkpoint( # save_path=args.resume_model, epoch=-1, restart=True) # # if epoch >= config['convert_to_sgd_epoch']: # model.set_optimizer( # optimizer='sgd', # 
learning_rate_init=float(config['learning_rate']), # on-the-fly # weight_decay=float(config['weight_decay']), # clip_grad_norm=config['clip_grad_norm'], # lr_schedule=False, # factor=config['decay_rate'], # patience_epoch=config['decay_patient_epoch']) # # if config['rnnlm_cf']: # if config['rnnlm_config_cold_fusion']['backward']: # model.rnnlm_0_bwd.flatten_parameters() # else: # model.rnnlm_0_fwd.flatten_parameters() train_set.epoch = epoch - 1 # start from index:0 # GPU setting if args.ngpus >= 1: model = CustomDataParallel(model, device_ids=list(range(0, args.ngpus, 1)), deterministic=False, benchmark=True) model.cuda() logger.info('PID: %s' % os.getpid()) logger.info('USERNAME: %s' % os.uname()[1]) # Set process name # setproctitle(args.job_name) # Set learning rate controller lr_controller = Controller(learning_rate_init=learning_rate, decay_type=args.decay_type, decay_start_epoch=args.decay_start_epoch, decay_rate=args.decay_rate, decay_patient_epoch=args.decay_patient_epoch, lower_better=True, best_value=metric_dev_best) # Set reporter reporter = Reporter(model.module.save_path, max_loss=300) # Set the updater updater = Updater(args.clip_grad_norm) # Setting for tensorboard tf_writer = SummaryWriter(model.module.save_path) start_time_train = time.time() start_time_epoch = time.time() start_time_step = time.time() not_improved_epoch = 0. loss_train_mean, acc_train_mean = 0., 0. pbar_epoch = tqdm(total=len(train_set)) pbar_all = tqdm(total=len(train_set) * args.num_epochs) while True: # Compute loss in the training set (including parameter update) batch_train, is_new_epoch = train_set.next() model, loss_train, acc_train = updater(model, batch_train) loss_train_mean += loss_train acc_train_mean += acc_train pbar_epoch.update(len(batch_train['utt_ids'])) if (step + 1) % args.print_step == 0: # Compute loss in the dev set batch_dev = dev_set.next()[0] model, loss_dev, acc_dev = updater(model, batch_dev, is_eval=True) loss_train_mean /= args.print_step acc_train_mean /= args.print_step reporter.step(step, loss_train_mean, loss_dev, acc_train_mean, acc_dev) # Logging by tensorboard tf_writer.add_scalar('train/loss', loss_train_mean, step + 1) tf_writer.add_scalar('dev/loss', loss_dev, step + 1) # for n, p in model.module.named_parameters(): # n = n.replace('.', '/') # if p.grad is not None: # tf_writer.add_histogram(n, p.data.cpu().numpy(), step + 1) # tf_writer.add_histogram(n + '/grad', p.grad.data.cpu().numpy(), step + 1) duration_step = time.time() - start_time_step if args.input_type == 'speech': x_len = max(len(x) for x in batch_train['xs']) elif args.input_type == 'text': x_len = max(len(x) for x in batch_train['ys_sub']) logger.info( "...Step:%d(ep:%.2f) loss:%.2f(%.2f)/acc:%.2f(%.2f)/lr:%.5f/bs:%d/x_len:%d (%.2f min)" % (step + 1, train_set.epoch_detail, loss_train_mean, loss_dev, acc_train_mean, acc_dev, learning_rate, train_set.current_batch_size, x_len, duration_step / 60)) start_time_step = time.time() loss_train_mean, acc_train_mean = 0, 0 step += args.ngpus # Save checkpoint and evaluate model per epoch if is_new_epoch: duration_epoch = time.time() - start_time_epoch logger.info('===== EPOCH:%d (%.2f min) =====' % (epoch, duration_epoch / 60)) # Save fugures of loss and accuracy reporter.epoch() if epoch < args.eval_start_epoch: # Save the model model.module.save_checkpoint(model.module.save_path, epoch, step, learning_rate, metric_dev_best) else: start_time_eval = time.time() # dev if args.metric == 'ler': if args.label_type == 'word': metric_dev = 
eval_word([model.module], dev_set, decode_params, epoch=epoch)[0] logger.info(' WER (%s): %.3f %%' % (dev_set.set, metric_dev)) elif args.label_type == 'wordpiece': metric_dev = eval_wordpiece([model.module], dev_set, decode_params, args.wp_model, epoch=epoch)[0] logger.info(' WER (%s): %.3f %%' % (dev_set.set, metric_dev)) elif 'char' in args.label_type: metric_dev = eval_char([model.module], dev_set, decode_params, epoch=epoch)[1][0] logger.info(' CER (%s): %.3f %%' % (dev_set.set, metric_dev)) elif 'phone' in args.label_type: metric_dev = eval_phone([model.module], dev_set, decode_params, epoch=epoch)[0] logger.info(' PER (%s): %.3f %%' % (dev_set.set, metric_dev)) elif args.metric == 'loss': metric_dev = eval_loss([model.module], dev_set, decode_params) logger.info(' Loss (%s): %.3f %%' % (dev_set.set, metric_dev)) else: raise NotImplementedError() if metric_dev < metric_dev_best: metric_dev_best = metric_dev not_improved_epoch = 0 logger.info('||||| Best Score |||||') # Update learning rate model.module.optimizer, learning_rate = lr_controller.decay_lr( optimizer=model.module.optimizer, learning_rate=learning_rate, epoch=epoch, value=metric_dev) # Save the model model.module.save_checkpoint(model.module.save_path, epoch, step, learning_rate, metric_dev_best) # test for eval_set in eval_sets: if args.metric == 'ler': if args.label_type == 'word': wer_test = eval_word([model.module], eval_set, decode_params, epoch=epoch)[0] logger.info(' WER (%s): %.3f %%' % (eval_set.set, wer_test)) elif args.label_type == 'wordpiece': wer_test = eval_wordpiece([model.module], eval_set, decode_params, epoch=epoch)[0] logger.info(' WER (%s): %.3f %%' % (eval_set.set, wer_test)) elif 'char' in args.label_type: cer_test = eval_char([model.module], eval_set, decode_params, epoch=epoch)[1][0] logger.info(' CER (%s): %.3f / %.3f %%' % (eval_set.set, cer_test)) elif 'phone' in args.label_type: per_test = eval_phone([model.module], eval_set, decode_params, epoch=epoch)[0] logger.info(' PER (%s): %.3f %%' % (eval_set.set, per_test)) elif args.metric == 'loss': loss_test = eval_loss([model.module], eval_set, decode_params) logger.info(' Loss (%s): %.3f %%' % (eval_set.set, loss_test)) else: raise NotImplementedError() else: # Update learning rate model.module.optimizer, learning_rate = lr_controller.decay_lr( optimizer=model.module.optimizer, learning_rate=learning_rate, epoch=epoch, value=metric_dev) not_improved_epoch += 1 duration_eval = time.time() - start_time_eval logger.info('Evaluation time: %.2f min' % (duration_eval / 60)) # Early stopping if not_improved_epoch == args.not_improved_patient_epoch: break if epoch == args.convert_to_sgd_epoch: # Convert to fine-tuning stage model.module.set_optimizer( 'sgd', learning_rate_init=float( args.learning_rate), # TODO: ? weight_decay=float(args.weight_decay), clip_grad_norm=args.clip_grad_norm, lr_schedule=False, factor=args.decay_rate, patience_epoch=args.decay_patient_epoch) logger.info('========== Convert to SGD ==========') pbar_epoch = tqdm(total=len(train_set)) pbar_all.update(len(train_set)) if epoch == args.num_epochs: break start_time_step = time.time() start_time_epoch = time.time() epoch += 1 duration_train = time.time() - start_time_train logger.info('Total time: %.2f hour' % (duration_train / 3600)) tf_writer.close() pbar_epoch.close() pbar_all.close() return model.module.save_path
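
# Illustrative sketch (assumption about what lr_controller.decay_lr above
# implements): a patience-based schedule keeps the learning rate while the dev
# metric improves and multiplies it by decay_rate once the metric has not
# improved for decay_patient_epoch epochs after decay_start_epoch.
class PatienceDecay(object):

    def __init__(self, lr_init, decay_rate=0.9, decay_start_epoch=5,
                 decay_patient_epoch=1, lower_better=True):
        self.lr = lr_init
        self.decay_rate = decay_rate
        self.decay_start_epoch = decay_start_epoch
        self.decay_patient_epoch = decay_patient_epoch
        self.lower_better = lower_better
        self.best = float('inf') if lower_better else -float('inf')
        self.not_improved = 0

    def step(self, epoch, value):
        improved = value < self.best if self.lower_better else value > self.best
        if improved:
            self.best = value
            self.not_improved = 0
        else:
            self.not_improved += 1
            if epoch >= self.decay_start_epoch and self.not_improved >= self.decay_patient_epoch:
                self.lr *= self.decay_rate
                self.not_improved = 0
        return self.lr


# Usage: call scheduler.step(epoch, dev_wer) once per epoch and apply the
# returned learning rate to the optimizer.
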
def main(): # Load a config file if args.resume: config = load_config(os.path.join(args.resume, 'config.yml')) for k, v in config.items(): setattr(args, k, v) # Automatically reduce batch size in multi-GPU setting if args.ngpus > 1: args.batch_size -= 10 args.print_step //= args.ngpus subsample_factor = 1 subsample_factor_sub1 = 1 subsample_factor_sub2 = 1 subsample = [int(s) for s in args.subsample.split('_')] if args.conv_poolings: for p in args.conv_poolings.split('_'): p = int(p.split(',')[0].replace('(', '')) if p > 1: subsample_factor *= p if args.train_set_sub1: subsample_factor_sub1 = subsample_factor * np.prod( subsample[:args.enc_nlayers_sub1 - 1]) if args.train_set_sub2: subsample_factor_sub2 = subsample_factor * np.prod( subsample[:args.enc_nlayers_sub2 - 1]) subsample_factor *= np.prod(subsample) # Load dataset train_set = Dataset(csv_path=args.train_set, csv_path_sub1=args.train_set_sub1, csv_path_sub2=args.train_set_sub2, dict_path=args.dict, dict_path_sub1=args.dict_sub1, dict_path_sub2=args.dict_sub2, unit=args.unit, unit_sub1=args.unit_sub1, unit_sub2=args.unit_sub2, wp_model=args.wp_model, wp_model_sub1=args.wp_model_sub1, wp_model_sub2=args.wp_model_sub2, batch_size=args.batch_size * args.ngpus, nepochs=args.nepochs, min_nframes=args.min_nframes, max_nframes=args.max_nframes, sort_by_input_length=True, short2long=True, sort_stop_epoch=args.sort_stop_epoch, dynamic_batching=args.dynamic_batching, ctc=args.ctc_weight > 0, ctc_sub1=args.ctc_weight_sub1 > 0, ctc_sub2=args.ctc_weight_sub2 > 0, subsample_factor=subsample_factor, subsample_factor_sub1=subsample_factor_sub1, subsample_factor_sub2=subsample_factor_sub2, skip_speech=(args.input_type != 'speech')) dev_set = Dataset(csv_path=args.dev_set, csv_path_sub1=args.dev_set_sub1, csv_path_sub2=args.dev_set_sub2, dict_path=args.dict, dict_path_sub1=args.dict_sub1, dict_path_sub2=args.dict_sub2, unit=args.unit, unit_sub1=args.unit_sub1, unit_sub2=args.unit_sub2, wp_model=args.wp_model, wp_model_sub1=args.wp_model_sub1, wp_model_sub2=args.wp_model_sub2, batch_size=args.batch_size * args.ngpus, min_nframes=args.min_nframes, max_nframes=args.max_nframes, shuffle=True, ctc=args.ctc_weight > 0, ctc_sub1=args.ctc_weight_sub1 > 0, ctc_sub2=args.ctc_weight_sub2 > 0, subsample_factor=subsample_factor, subsample_factor_sub1=subsample_factor_sub1, subsample_factor_sub2=subsample_factor_sub2, skip_speech=(args.input_type != 'speech')) eval_sets = [] for set in args.eval_sets: eval_sets += [ Dataset(csv_path=set, dict_path=args.dict, unit=args.unit, wp_model=args.wp_model, batch_size=1, is_test=True, skip_speech=(args.input_type != 'speech')) ] args.vocab = train_set.vocab args.vocab_sub1 = train_set.vocab_sub1 args.vocab_sub2 = train_set.vocab_sub2 args.input_dim = train_set.input_dim # Load a RNNLM config file for cold fusion & RNNLM initialization # if config['rnnlm_cold_fusion']: # if args.model: # config['rnnlm_config_cold_fusion'] = load_config( # os.path.join(config['rnnlm_cold_fusion'], 'config.yml'), is_eval=True) # elif args.resume: # config = load_config(os.path.join( # args.resume, 'config_rnnlm_cf.yml')) # assert args.unit == config['rnnlm_config_cold_fusion']['unit'] # config['rnnlm_config_cold_fusion']['vocab'] = train_set.vocab args.rnnlm_cold_fusion = False # Model setting if args.transformer: model = Transformer(args) dir_name = 'transformer' if len(args.conv_channels) > 0: tmp = dir_name dir_name = 'conv' + str(len(args.conv_channels.split('_'))) + 'L' if args.conv_batch_norm: dir_name += 'bn' dir_name += tmp dir_name 
+= str(args.d_model) + 'H' dir_name += str(args.enc_nlayers) + 'L' dir_name += str(args.dec_nlayers) + 'L' dir_name += '_head' + str(args.attn_nheads) dir_name += '_' + args.optimizer dir_name += '_lr' + str(args.learning_rate) dir_name += '_bs' + str(args.batch_size) dir_name += '_ls' + str(args.lsm_prob) dir_name += '_' + str(args.pre_process) + 't' + str(args.post_process) if args.nstacks > 1: dir_name += '_stack' + str(args.nstacks) if args.bwd_weight > 0: dir_name += '_bwd' + str(args.bwd_weight) else: model = Seq2seq(args) dir_name = args.enc_type if args.conv_channels and len(args.conv_channels.split('_')) > 0: tmp = dir_name dir_name = 'conv' + str(len(args.conv_channels.split('_'))) + 'L' if args.conv_batch_norm: dir_name += 'bn' dir_name += tmp dir_name += str(args.enc_nunits) + 'H' dir_name += str(args.enc_nprojs) + 'P' dir_name += str(args.enc_nlayers) + 'L' dir_name += '_' + args.subsample_type + str(subsample_factor) dir_name += '_' + args.dec_type if args.internal_lm > 0: dir_name += 'LM' dir_name += str(args.dec_nunits) + 'H' # dir_name += str(args.dec_nprojs) + 'P' dir_name += str(args.dec_nlayers) + 'L' if args.tie_embedding: dir_name += '_tie' dir_name += '_' + args.attn_type if args.attn_nheads > 1: dir_name += '_head' + str(args.attn_nheads) if args.attn_sigmoid: dir_name += '_sig' dir_name += '_' + args.optimizer dir_name += '_lr' + str(args.learning_rate) dir_name += '_bs' + str(args.batch_size) dir_name += '_ss' + str(args.ss_prob) dir_name += '_ls' + str(args.lsm_prob) if args.focal_loss_weight > 0: dir_name += '_fl' + str(args.focal_loss_weight) if args.layer_norm: dir_name += '_layernorm' # MTL if args.mtl_per_batch: dir_name += '_mtlperbatch' if args.ctc_weight > 0: dir_name += '_' + args.unit + 'ctc' if args.bwd_weight > 0: dir_name += '_' + args.unit + 'bwd' if args.lmobj_weight > 0: dir_name += '_' + args.unit + 'lmobj' if args.train_set_sub1: dir_name += '_' + args.unit_sub1 if args.ctc_weight_sub1 == 0: dir_name += 'att' elif args.ctc_weight_sub1 == args.sub1_weight: dir_name += 'ctc' else: dir_name += 'attctc' if args.train_set_sub2: dir_name += '_' + args.unit_sub2 if args.ctc_weight_sub2 == 0: dir_name += 'att' elif args.ctc_weight_sub2 == args.sub2_weight: dir_name += 'ctc' else: dir_name += 'attctc' else: if args.ctc_weight > 0: dir_name += '_ctc' + str(args.ctc_weight) if args.bwd_weight > 0: dir_name += '_bwd' + str(args.bwd_weight) if args.lmobj_weight > 0: dir_name += '_lmobj' + str(args.lmobj_weight) if args.sub1_weight > 0: if args.ctc_weight_sub1 == args.sub1_weight: dir_name += '_ctcsub1' + str(args.ctc_weight_sub1) elif args.ctc_weight_sub1 == 0: dir_name += '_attsub1' + str(args.sub1_weight) else: dir_name += '_ctcsub1' + str(args.ctc_weight_sub1) + 'attsub1' + \ str(args.sub1_weight - args.ctc_weight_sub1) if args.sub2_weight > 0: if args.ctc_weight_sub2 == args.sub2_weight: dir_name += '_ctcsub2' + str(args.ctc_weight_sub2) elif args.ctc_weight_sub2 == 0: dir_name += '_attsub2' + str(args.sub2_weight) else: dir_name += '_ctcsub2' + str(args.ctc_weight_sub2) + 'attsub2' + \ str(args.sub2_weight - args.ctc_weight_sub2) if args.task_specific_layer: dir_name += '_tsl' # Pre-training if args.pretrained_model and os.path.isdir(args.pretrained_model): # Load a config file config_pre = load_config( os.path.join(args.pretrained_model, 'config.yml')) dir_name += '_' + config_pre['unit'] + 'pt' if not args.resume: # Load pre-trained RNNLM # if config['rnnlm_cold_fusion']: # rnnlm = RNNLM(args) # 
rnnlm.load_checkpoint(save_path=config['rnnlm_cold_fusion'], epoch=-1) # rnnlm.flatten_parameters() # # # Fix RNNLM parameters # for param in rnnlm.parameters(): # param.requires_grad = False # # # Set pre-trained parameters # if config['rnnlm_config_cold_fusion']['backward']: # model.dec_0_bwd.rnnlm = rnnlm # else: # model.dec_0_fwd.rnnlm = rnnlm # TODO(hirofumi): 最初にRNNLMのモデルをコピー # Set save path save_path = mkdir_join( args.model, '_'.join(os.path.basename(args.train_set).split('.')[:-1]), dir_name) model.set_save_path(save_path) # avoid overwriting # Save the config file as a yaml file save_config(vars(args), model.save_path) # Save the dictionary & wp_model shutil.copy(args.dict, os.path.join(model.save_path, 'dict.txt')) if args.dict_sub1: shutil.copy(args.dict_sub1, os.path.join(model.save_path, 'dict_sub1.txt')) if args.dict_sub2: shutil.copy(args.dict_sub2, os.path.join(model.save_path, 'dict_sub2.txt')) if args.unit == 'wp': shutil.copy(args.wp_model, os.path.join(model.save_path, 'wp.model')) # Setting for logging logger = set_logger(os.path.join(model.save_path, 'train.log'), key='training') for k, v in sorted(vars(args).items(), key=lambda x: x[0]): logger.info('%s: %s' % (k, str(v))) # Count total parameters for n in sorted(list(model.num_params_dict.keys())): nparams = model.num_params_dict[n] logger.info("%s %d" % (n, nparams)) logger.info("Total %.2f M parameters" % (model.total_parameters / 1000000)) logger.info(model) # Initialize with pre-trained model's parameters if args.pretrained_model and os.path.isdir(args.pretrained_model): # Merge config with args for k, v in config_pre.items(): setattr(args_pre, k, v) # Load the ASR model model_pre = Seq2seq(args_pre) model_pre.load_checkpoint(args.pretrained_model, epoch=-1) # Overwrite parameters param_dict = dict(model_pre.named_parameters()) for n, p in model.named_parameters(): if n in param_dict.keys() and p.size() == param_dict[n].size(): p.data = param_dict[n].data logger.info('Overwrite %s' % n) # Set optimizer model.set_optimizer(optimizer=args.optimizer, learning_rate_init=float(args.learning_rate), weight_decay=float(args.weight_decay), clip_grad_norm=args.clip_grad_norm, lr_schedule=False, factor=args.decay_rate, patience_epoch=args.decay_patient_epoch) epoch, step = 1, 1 learning_rate = float(args.learning_rate) metric_dev_best = 10000 # NOTE: Restart from the last checkpoint # elif args.resume: # # Set save path # model.save_path = args.resume # # # Setting for logging # logger = set_logger(os.path.join(model.save_path, 'train.log'), key='training') # # # Set optimizer # model.set_optimizer( # optimizer=config['optimizer'], # learning_rate_init=float(config['learning_rate']), # on-the-fly # weight_decay=float(config['weight_decay']), # clip_grad_norm=config['clip_grad_norm'], # lr_schedule=False, # factor=config['decay_rate'], # patience_epoch=config['decay_patient_epoch']) # # # Restore the last saved model # epoch, step, learning_rate, metric_dev_best = model.load_checkpoint( # save_path=args.resume, epoch=-1, restart=True) # # if epoch >= config['convert_to_sgd_epoch']: # model.set_optimizer( # optimizer='sgd', # learning_rate_init=float(config['learning_rate']), # on-the-fly # weight_decay=float(config['weight_decay']), # clip_grad_norm=config['clip_grad_norm'], # lr_schedule=False, # factor=config['decay_rate'], # patience_epoch=config['decay_patient_epoch']) # # if config['rnnlm_cold_fusion']: # if config['rnnlm_config_cold_fusion']['backward']: # model.rnnlm_0_bwd.flatten_parameters() # else: # 
model.rnnlm_0_fwd.flatten_parameters() train_set.epoch = epoch - 1 # start from index:0 # GPU setting if args.ngpus >= 1: model = CustomDataParallel(model, device_ids=list(range(0, args.ngpus, 1)), deterministic=False, benchmark=True) model.cuda() logger.info('PID: %s' % os.getpid()) logger.info('HOSTNAME: %s' % os.uname()[1]) # Set process name # if args.job_name: # setproctitle(args.job_name) # else: # setproctitle(dir_name) # Set learning rate controller lr_controller = Controller(learning_rate_init=learning_rate, decay_type=args.decay_type, decay_start_epoch=args.decay_start_epoch, decay_rate=args.decay_rate, decay_patient_epoch=args.decay_patient_epoch, lower_better=True, best_value=metric_dev_best, model_size=args.d_model, warmup_step=args.warmup_step, factor=1) # Set reporter reporter = Reporter(model.module.save_path, tensorboard=True) if args.mtl_per_batch: # NOTE: from easier to harder tasks tasks = ['ys'] if 0 < args.ctc_weight < 1: tasks = ['ys.ctc'] + tasks if 0 < args.bwd_weight < 1: tasks = ['ys.bwd'] + tasks if 0 < args.lmobj_weight < 1: tasks = ['ys.lmobj'] + tasks if args.train_set_sub1: if args.ctc_weight_sub1 > 0: tasks = ['ys_sub1.ctc'] + tasks else: tasks = ['ys_sub1'] + tasks if args.train_set_sub2: if args.ctc_weight_sub2 > 0: tasks = ['ys_sub2.ctc'] + tasks else: tasks = ['ys_sub2'] + tasks else: tasks = ['all'] start_time_train = time.time() start_time_epoch = time.time() start_time_step = time.time() not_improved_epoch = 0 pbar_epoch = tqdm(total=len(train_set)) while True: # Compute loss in the training set batch_train, is_new_epoch = train_set.next() # Iterate over tasks for task in tasks: model.module.optimizer.zero_grad() loss, reporter = model(batch_train, reporter=reporter, task=task) if len(model.device_ids) > 1: loss.backward(torch.ones(len(model.device_ids))) else: loss.backward() loss = loss.detach() # Truncate the graph if args.clip_grad_norm > 0: torch.nn.utils.clip_grad_norm_(model.module.parameters(), args.clip_grad_norm) model.module.optimizer.step() loss_train = loss.item() del loss reporter.step(is_eval=False) # Update learning rate if args.decay_type == 'warmup': model.module.optimizer, learning_rate = lr_controller.warmup_lr( optimizer=model.module.optimizer, learning_rate=learning_rate, step=step) if step % args.print_step == 0: # Compute loss in the dev set batch_dev = dev_set.next()[0] # Iterate over tasks for task in tasks: loss, reporter = model(batch_dev, reporter=reporter, task=task, is_eval=True) loss_dev = loss.item() del loss reporter.step(is_eval=True) duration_step = time.time() - start_time_step if args.input_type == 'speech': x_len = max(len(x) for x in batch_train['xs']) elif args.input_type == 'text': x_len = max(len(x) for x in batch_train['ys']) logger.info( "step:%d(ep:%.2f) loss:%.3f(%.3f)/lr:%.5f/bs:%d/x_len:%d (%.2f min)" % (step, train_set.epoch_detail, loss_train, loss_dev, learning_rate, len( batch_train['utt_ids']), x_len, duration_step / 60)) start_time_step = time.time() step += args.ngpus pbar_epoch.update(len(batch_train['utt_ids'])) # Save figures of loss and accuracy if step % (args.print_step * 10) == 0: reporter.snapshot() # Save checkpoint and evaluate model per epoch if is_new_epoch: duration_epoch = time.time() - start_time_epoch logger.info('========== EPOCH:%d (%.2f min) ==========' % (epoch, duration_epoch / 60)) if epoch < args.eval_start_epoch: # Save the model model.module.save_checkpoint(model.module.save_path, epoch, step - 1, learning_rate, metric_dev_best) else: start_time_eval =
time.time() # dev if args.metric == 'edit_distance': if args.unit in ['word', 'word_char']: metric_dev = eval_word([model.module], dev_set, decode_params, epoch=epoch)[0] logger.info('WER (%s): %.3f %%' % (dev_set.set, metric_dev)) elif args.unit == 'wp': metric_dev = eval_wordpiece([model.module], dev_set, decode_params, epoch=epoch)[0] logger.info('WER (%s): %.3f %%' % (dev_set.set, metric_dev)) elif 'char' in args.unit: dev_results = eval_char([model.module], dev_set, decode_params, epoch=epoch) metric_dev = dev_results[1][0] wer_dev = dev_results[0][0] logger.info('CER (%s): %.3f %%' % (dev_set.set, metric_dev)) logger.info('WER (%s): %.3f %%' % (dev_set.set, wer_dev)) elif 'phone' in args.unit: metric_dev = eval_phone([model.module], dev_set, decode_params, epoch=epoch)[0] logger.info('PER (%s): %.3f %%' % (dev_set.set, metric_dev)) elif args.metric == 'loss': metric_dev = eval_loss([model.module], dev_set, decode_params) logger.info('Loss (%s): %.3f %%' % (dev_set.set, metric_dev)) else: raise NotImplementedError() # Update learning rate if args.decay_type != 'warmup': model.module.optimizer, learning_rate = lr_controller.decay_lr( optimizer=model.module.optimizer, learning_rate=learning_rate, epoch=epoch, value=metric_dev) if metric_dev < metric_dev_best: metric_dev_best = metric_dev not_improved_epoch = 0 logger.info('||||| Best Score |||||') # Save the model model.module.save_checkpoint(model.module.save_path, epoch, step - 1, learning_rate, metric_dev_best) # test for eval_set in eval_sets: if args.metric == 'edit_distance': if args.unit in ['word', 'word_char']: wer_test = eval_word([model.module], eval_set, decode_params, epoch=epoch)[0] logger.info('WER (%s): %.3f %%' % (eval_set.set, wer_test)) elif args.unit == 'wp': wer_test = eval_wordpiece([model.module], eval_set, decode_params, epoch=epoch)[0] logger.info('WER (%s): %.3f %%' % (eval_set.set, wer_test)) elif 'char' in args.unit: test_results = eval_char([model.module], eval_set, decode_params, epoch=epoch) cer_test = test_results[1][0] wer_test = test_results[0][0] logger.info('CER (%s): %.3f %%' % (eval_set.set, cer_test)) logger.info('WER (%s): %.3f %%' % (eval_set.set, wer_test)) elif 'phone' in args.unit: per_test = eval_phone([model.module], eval_set, decode_params, epoch=epoch)[0] logger.info('PER (%s): %.3f %%' % (eval_set.set, per_test)) elif args.metric == 'loss': loss_test = eval_loss([model.module], eval_set, decode_params) logger.info('Loss (%s): %.3f %%' % (eval_set.set, loss_test)) else: raise NotImplementedError() else: not_improved_epoch += 1 duration_eval = time.time() - start_time_eval logger.info('Evaluation time: %.2f min' % (duration_eval / 60)) # Early stopping if not_improved_epoch == args.not_improved_patient_epoch: break if epoch == args.convert_to_sgd_epoch: # Convert to fine-tuning stage model.module.set_optimizer( 'sgd', learning_rate_init=float( args.learning_rate), # TODO: ? weight_decay=float(args.weight_decay), clip_grad_norm=args.clip_grad_norm, lr_schedule=False, factor=args.decay_rate, patience_epoch=args.decay_patient_epoch) logger.info('========== Convert to SGD ==========') pbar_epoch = tqdm(total=len(train_set)) if epoch == args.nepochs: break start_time_step = time.time() start_time_epoch = time.time() epoch += 1 duration_train = time.time() - start_time_train logger.info('Total time: %.2f hour' % (duration_train / 3600)) if reporter.tensorboard: reporter.tf_writer.close() pbar_epoch.close() return model.module.save_path
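# --- Illustrative sketch (not the repository's Controller implementation) ---
# The training loop above calls lr_controller.warmup_lr() when
# args.decay_type == 'warmup', and the Controller is built with
# model_size=args.d_model and warmup_step=args.warmup_step. A common choice
# for such a schedule is the Transformer "Noam" rule sketched below; the
# function name and the exact formula are assumptions for illustration only.
def noam_lr(step, d_model, warmup_step, factor=1.0):
    """Learning rate that grows linearly for `warmup_step` steps and then
    decays with the inverse square root of the step count."""
    step = max(step, 1)  # avoid division by zero at step 0
    return factor * d_model ** -0.5 * min(step ** -0.5, step * warmup_step ** -1.5)

# Example: with d_model=512 and warmup_step=4000, the peak learning rate is
# reached at step 4000 and is roughly 7e-4.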
def main(): args = parse() # Load a conf file dir_name = os.path.dirname(args.recog_model[0]) conf = load_config(os.path.join(dir_name, 'conf.yml')) # Overwrite conf for k, v in conf.items(): if 'recog' not in k: setattr(args, k, v) recog_params = vars(args) # Setting for logging if os.path.isfile(os.path.join(args.recog_dir, 'plot.log')): os.remove(os.path.join(args.recog_dir, 'plot.log')) logger = set_logger(os.path.join(args.recog_dir, 'plot.log'), key='decoding') for i, s in enumerate(args.recog_sets): # Load dataset dataset = Dataset(corpus=args.corpus, tsv_path=s, dict_path=os.path.join(dir_name, 'dict.txt'), dict_path_sub1=os.path.join(dir_name, 'dict_sub1.txt') if os.path.isfile( os.path.join(dir_name, 'dict_sub1.txt')) else False, nlsyms=args.nlsyms, wp_model=os.path.join(dir_name, 'wp.model'), unit=args.unit, unit_sub1=args.unit_sub1, batch_size=args.recog_batch_size, is_test=True) if i == 0: # Load the ASR model model = Seq2seq(args) model, checkpoint = load_checkpoint(model, args.recog_model[0]) epoch = checkpoint['epoch'] model.save_path = dir_name # ensemble (different models) ensemble_models = [model] if len(args.recog_model) > 1: for recog_model_e in args.recog_model[1:]: # Load a conf file conf_e = load_config(os.path.join(os.path.dirname(recog_model_e), 'conf.yml')) # Overwrite conf args_e = copy.deepcopy(args) for k, v in conf_e.items(): if 'recog' not in k: setattr(args_e, k, v) model_e = Seq2seq(args_e) model_e, _ = load_checkpoint(model_e, recog_model_e) model_e.cuda() ensemble_models += [model_e] # For shallow fusion if not args.lm_fusion: if args.recog_lm is not None and args.recog_lm_weight > 0: # Load a LM conf file conf_lm = load_config(os.path.join(os.path.dirname(args.recog_lm), 'conf.yml')) # Merge conf with args args_lm = argparse.Namespace() for k, v in conf_lm.items(): setattr(args_lm, k, v) # Load the pre-trianed LM if args_lm.lm_type == 'gated_cnn': lm = GatedConvLM(args_lm) else: lm = RNNLM(args_lm) lm, _ = load_checkpoint(lm, args.recog_lm) if args_lm.backward: model.lm_bwd = lm else: model.lm_fwd = lm if args.recog_lm_bwd is not None and args.recog_lm_weight > 0 and \ (args.recog_fwd_bwd_attention or args.recog_reverse_lm_rescoring): # Load a LM conf file conf_lm = load_config(os.path.join(args.recog_lm_bwd, 'conf.yml')) # Merge conf with args args_lm_bwd = argparse.Namespace() for k, v in conf_lm.items(): setattr(args_lm_bwd, k, v) # Load the pre-trianed LM if args_lm_bwd.lm_type == 'gated_cnn': lm_bwd = GatedConvLM(args_lm_bwd) else: lm_bwd = RNNLM(args_lm_bwd) lm_bwd, _ = load_checkpoint(lm_bwd, args.recog_lm_bwd) model.lm_bwd = lm_bwd if not args.recog_unit: args.recog_unit = args.unit logger.info('recog unit: %s' % args.recog_unit) logger.info('recog metric: %s' % args.recog_metric) logger.info('recog oracle: %s' % args.recog_oracle) logger.info('epoch: %d' % (epoch - 1)) logger.info('batch size: %d' % args.recog_batch_size) logger.info('beam width: %d' % args.recog_beam_width) logger.info('min length ratio: %.3f' % args.recog_min_len_ratio) logger.info('max length ratio: %.3f' % args.recog_max_len_ratio) logger.info('length penalty: %.3f' % args.recog_length_penalty) logger.info('coverage penalty: %.3f' % args.recog_coverage_penalty) logger.info('coverage threshold: %.3f' % args.recog_coverage_threshold) logger.info('CTC weight: %.3f' % args.recog_ctc_weight) logger.info('LM path: %s' % args.recog_lm) logger.info('LM path (bwd): %s' % args.recog_lm_bwd) logger.info('LM weight: %.3f' % args.recog_lm_weight) logger.info('GNMT: %s' % 
args.recog_gnmt_decoding) logger.info('forward-backward attention: %s' % args.recog_fwd_bwd_attention) logger.info('reverse LM rescoring: %s' % args.recog_reverse_lm_rescoring) logger.info('resolving UNK: %s' % args.recog_resolving_unk) logger.info('ensemble: %d' % (len(ensemble_models))) logger.info('ASR decoder state carry over: %s' % (args.recog_asr_state_carry_over)) logger.info('LM state carry over: %s' % (args.recog_lm_state_carry_over)) logger.info('cache size: %d' % (args.recog_n_caches)) logger.info('cache type: %s' % (args.recog_cache_type)) logger.info('cache word frequency threshold: %s' % (args.recog_cache_word_freq)) logger.info('cache theta (speech): %.3f' % (args.recog_cache_theta_speech)) logger.info('cache lambda (speech): %.3f' % (args.recog_cache_lambda_speech)) logger.info('cache theta (lm): %.3f' % (args.recog_cache_theta_lm)) logger.info('cache lambda (lm): %.3f' % (args.recog_cache_lambda_lm)) # GPU setting model.cuda() # TODO(hirofumi): move this save_path = mkdir_join(args.recog_dir, 'att_weights') if args.recog_n_caches > 0: save_path_cache = mkdir_join(args.recog_dir, 'cache') # Clean directory if save_path is not None and os.path.isdir(save_path): shutil.rmtree(save_path) os.mkdir(save_path) if args.recog_n_caches > 0: shutil.rmtree(save_path_cache) os.mkdir(save_path_cache) while True: batch, is_new_epoch = dataset.next(recog_params['recog_batch_size']) best_hyps_id, aws, (cache_attn_hist, cache_id_hist) = model.decode( batch['xs'], recog_params, dataset.idx2token[0], exclude_eos=False, refs_id=batch['ys'], ensemble_models=ensemble_models[1:] if len(ensemble_models) > 1 else [], speakers=batch['sessions'] if dataset.corpus == 'swbd' else batch['speakers']) if model.bwd_weight > 0.5: # Reverse the order best_hyps_id = [hyp[::-1] for hyp in best_hyps_id] aws = [aw[::-1] for aw in aws] for b in range(len(batch['xs'])): tokens = dataset.idx2token[0](best_hyps_id[b], return_list=True) spk = batch['speakers'][b] plot_attention_weights( aws[b][:len(tokens)], tokens, spectrogram=batch['xs'][b][:, :dataset.input_dim] if args.input_type == 'speech' else None, save_path=mkdir_join(save_path, spk, batch['utt_ids'][b] + '.png'), figsize=(20, 8)) if args.recog_n_caches > 0 and cache_id_hist is not None and cache_attn_hist is not None: n_keys, n_queries = cache_attn_hist[0].shape # mask = np.ones((n_keys, n_queries)) # for i in range(n_queries): # mask[:n_keys - i, -(i + 1)] = 0 mask = np.zeros((n_keys, n_queries)) plot_cache_weights( cache_attn_hist[0], keys=dataset.idx2token[0](cache_id_hist[-1], return_list=True), # fifo # keys=dataset.idx2token[0](cache_id_hist, return_list=True), # dict queries=tokens, save_path=mkdir_join(save_path_cache, spk, batch['utt_ids'][b] + '.png'), figsize=(40, 16), mask=mask) if model.bwd_weight > 0.5: hyp = ' '.join(tokens[::-1]) else: hyp = ' '.join(tokens) logger.info('utt-id: %s' % batch['utt_ids'][b]) logger.info('Ref: %s' % batch['text'][b].lower()) logger.info('Hyp: %s' % hyp) logger.info('-' * 50) if is_new_epoch: break
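# --- Illustrative sketch (not the repository's decoder) ---
# When --recog_lm and --recog_lm_weight are given above, the LM is attached
# to the model for shallow fusion inside model.decode(). The sketch below
# only shows the score interpolation that shallow fusion implies; the
# function name is hypothetical and length_penalty mirrors
# args.recog_length_penalty.
import torch

def shallow_fusion_step(asr_log_probs, lm_log_probs, lm_weight, length_penalty=0.0):
    """Fuse per-token scores for one beam-expansion step.

    asr_log_probs, lm_log_probs: [beam, vocab] log-probabilities.
    Returns [beam, vocab] fused scores used to rank candidate tokens.
    """
    return asr_log_probs + lm_weight * lm_log_probs + length_penalty

# Usage idea: with lm_weight around 0.3, a token strongly preferred by the LM
# can overtake an acoustically similar competitor without retraining the ASR model.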
def main(): # Load a config file config = load_config(os.path.join(args.model, 'config.yml')) decode_params = vars(args) # Merge config with args for k, v in config.items(): if not hasattr(args, k): setattr(args, k, v) # Setting for logging if os.path.isfile(os.path.join(args.plot_dir, 'plot.log')): os.remove(os.path.join(args.plot_dir, 'plot.log')) logger = set_logger(os.path.join(args.plot_dir, 'plot.log'), key='decoding') for i, set in enumerate(args.eval_sets): subsample_factor = 1 subsample_factor_sub1 = 1 subsample = [int(s) for s in args.subsample.split('_')] if args.conv_poolings: for p in args.conv_poolings.split('_'): p = int(p.split(',')[0].replace('(', '')) if p > 1: subsample_factor *= p if args.train_set_sub1 is not None: subsample_factor_sub1 = subsample_factor * np.prod( subsample[:args.enc_nlayers_sub1 - 1]) subsample_factor *= np.prod(subsample) # Load dataset dataset = Dataset( csv_path=set, dict_path=os.path.join(args.model, 'dict.txt'), dict_path_sub1=os.path.join(args.model, 'dict_sub.txt') if os.path.isfile(os.path.join(args.model, 'dict_sub.txt')) else None, wp_model=os.path.join(args.model, 'wp.model'), unit=args.unit, unit_sub1=args.unit_sub1, batch_size=args.batch_size, is_test=True) if i == 0: args.vocab = dataset.vocab args.vocab_sub1 = dataset.vocab_sub1 args.input_dim = dataset.input_dim # TODO(hirofumi): For cold fusion args.rnnlm_cold_fusion = None args.rnnlm_init = None # Load the ASR model model = Seq2seq(args) epoch, _, _, _ = model.load_checkpoint(args.model, epoch=args.epoch) model.save_path = args.model # GPU setting model.cuda() logger.info('epoch: %d' % (epoch - 1)) save_path = mkdir_join(args.plot_dir, 'att_weights') # Clean directory if save_path is not None and os.path.isdir(save_path): shutil.rmtree(save_path) os.mkdir(save_path) while True: batch, is_new_epoch = dataset.next(decode_params['batch_size']) best_hyps, aws, perm_idx = model.decode(batch['xs'], decode_params, exclude_eos=False) ys = [batch['ys'][i] for i in perm_idx] # Get CTC probs ctc_probs, indices_topk, x_lens = model.get_ctc_posteriors( batch['xs'], temperature=1, topk=min(100, model.vocab)) # NOTE: ctc_probs: '[B, T, topk]' for b in range(len(batch['xs'])): if args.unit == 'word': token_list = dataset.idx2word(best_hyps[b], return_list=True) elif args.unit == 'wp': token_list = dataset.idx2wp(best_hyps[b], return_list=True) elif args.unit == 'char': token_list = dataset.idx2char(best_hyps[b], return_list=True) elif args.unit == 'phone': token_list = dataset.idx2phone(best_hyps[b], return_list=True) else: raise NotImplementedError(args.unit) token_list = [unicode(t, 'utf-8') for t in token_list] speaker = '_'.join(batch['utt_ids'][b].replace( '-', '_').split('_')[:-2]) plot_ctc_probs( ctc_probs[b, :x_lens[b]], indices_topk[b], nframes=x_lens[b], subsample_factor=subsample_factor, spectrogram=batch['xs'][b][:, :dataset.input_dim], save_path=mkdir_join(save_path, speaker, batch['utt_ids'][b] + '.png'), figsize=(20, 8)) ref = ys[b] hyp = ' '.join(token_list) logger.info('utt-id: %s' % batch['utt_ids'][b]) logger.info('Ref: %s' % ref.lower()) logger.info('Hyp: %s' % hyp) logger.info('-' * 50) if is_new_epoch: break
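# --- Illustrative sketch (not the model's actual method) ---
# get_ctc_posteriors() above is expected to return top-k CTC probabilities of
# shape [B, T, topk] together with the matching vocabulary indices, which are
# then passed to plot_ctc_probs(). A minimal version of that computation,
# assuming raw CTC logits of shape [B, T, vocab], could look like this; the
# temperature handling in the real model may differ.
import torch

def ctc_posteriors(ctc_logits, topk=100, temperature=1.0):
    """ctc_logits: [B, T, vocab] raw CTC outputs.
    Returns (probs [B, T, topk], indices [B, T, topk])."""
    probs = torch.softmax(ctc_logits / temperature, dim=-1)
    topk_probs, topk_ids = probs.topk(topk, dim=-1)
    return topk_probs, topk_ids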
def main(): # Load a config file config = load_config(os.path.join(args.model, 'config.yml')) decode_params = vars(args) # Merge config with args for k, v in config.items(): if not hasattr(args, k): setattr(args, k, v) # Setting for logging if os.path.isfile(os.path.join(args.decode_dir, 'decode.log')): os.remove(os.path.join(args.decode_dir, 'decode.log')) logger = set_logger(os.path.join(args.decode_dir, 'decode.log'), key='decoding') wer_mean, cer_mean, per_mean = 0, 0, 0 for i, set in enumerate(args.eval_sets): # Load dataset eval_set = Dataset(csv_path=set, dict_path=os.path.join(args.model, 'dict.txt'), dict_path_sub1=os.path.join(args.model, 'dict_sub1.txt') if os.path.isfile( os.path.join(args.model, 'dict_sub1.txt')) else None, dict_path_sub2=os.path.join(args.model, 'dict_sub2.txt') if os.path.isfile( os.path.join(args.model, 'dict_sub2.txt')) else None, wp_model=os.path.join(args.model, 'wp.model'), unit=args.unit, unit_sub1=args.unit_sub1, unit_sub2=args.unit_sub2, batch_size=args.batch_size, is_test=True) if i == 0: args.vocab = eval_set.vocab args.vocab_sub1 = eval_set.vocab_sub1 args.input_dim = eval_set.input_dim # For cold fusion # if args.rnnlm_cold_fusion: # # Load a RNNLM config file # config['rnnlm_config'] = load_config(os.path.join(args.model, 'config_rnnlm.yml')) # # assert args.unit == config['rnnlm_config']['unit'] # rnnlm_args.vocab = eval_set.vocab # logger.info('RNNLM path: %s' % config['rnnlm']) # logger.info('RNNLM weight: %.3f' % args.rnnlm_weight) # else: # pass args.rnnlm_cold_fusion = None args.rnnlm_init = None # Load the ASR model model = Seq2seq(args) epoch, _, _, _ = model.load_checkpoint(args.model, epoch=args.epoch) model.save_path = args.model # For shallow fusion if (not args.rnnlm_cold_fusion) and args.rnnlm is not None and args.rnnlm_weight > 0: # Load a RNNLM config file config_rnnlm = load_config(os.path.join(args.rnnlm, 'config.yml')) # Merge config with args args_rnnlm = argparse.Namespace() for k, v in config_rnnlm.items(): setattr(args_rnnlm, k, v) assert args.unit == args_rnnlm.unit args_rnnlm.vocab = eval_set.vocab # Load the pre-trianed RNNLM seq_rnnlm = SeqRNNLM(args_rnnlm) seq_rnnlm.load_checkpoint(args.rnnlm, epoch=-1) # Copy parameters rnnlm = RNNLM(args_rnnlm) rnnlm.copy_from_seqrnnlm(seq_rnnlm) if args_rnnlm.backward: model.rnnlm_bwd = rnnlm else: model.rnnlm_fwd = rnnlm logger.info('RNNLM path: %s' % args.rnnlm) logger.info('RNNLM weight: %.3f' % args.rnnlm_weight) logger.info('RNNLM backward: %s' % str(config_rnnlm['backward'])) # GPU setting model.cuda() logger.info('beam width: %d' % args.beam_width) logger.info('length penalty: %.3f' % args.length_penalty) logger.info('coverage penalty: %.3f' % args.coverage_penalty) logger.info('coverage threshold: %.3f' % args.coverage_threshold) logger.info('epoch: %d' % (epoch - 1)) start_time = time.time() if args.unit in ['word', 'word_char'] and not args.recog_unit: wer, nsub, nins, ndel, noov_total = eval_word( [model], eval_set, decode_params, epoch=epoch - 1, decode_dir=args.decode_dir, progressbar=True) wer_mean += wer logger.info('WER (%s): %.3f %%' % (eval_set.set, wer)) logger.info('SUB: %.3f / INS: %.3f / DEL: %.3f' % (nsub, nins, ndel)) logger.info('OOV (total): %d' % (noov_total)) elif (args.unit == 'wp' and not args.recog_unit) or args.recog_unit == 'wp': wer, nsub, nins, ndel = eval_wordpiece( [model], eval_set, decode_params, epoch=epoch - 1, decode_dir=args.decode_dir, progressbar=True) wer_mean += wer logger.info('WER (%s): %.3f %%' % (eval_set.set, wer)) 
logger.info('SUB: %.3f / INS: %.3f / DEL: %.3f' % (nsub, nins, ndel)) elif ('char' in args.unit and not args.recog_unit) or 'char' in args.recog_unit: (wer, nsub, nins, ndel), (cer, _, _, _) = eval_char( [model], eval_set, decode_params, epoch=epoch - 1, decode_dir=args.decode_dir, progressbar=True, task_id=1 if args.recog_unit and 'char' in args.recog_unit else 0) wer_mean += wer cer_mean += cer logger.info('WER / CER (%s): %.3f / %.3f %%' % (eval_set.set, wer, cer)) logger.info('SUB: %.3f / INS: %.3f / DEL: %.3f' % (nsub, nins, ndel)) elif 'phone' in args.unit: per, nsub, nins, ndel = eval_phone( [model], eval_set, decode_params, epoch=epoch - 1, decode_dir=args.decode_dir, progressbar=True) per_mean += per logger.info('PER (%s): %.3f %%' % (eval_set.set, per)) logger.info('SUB: %.3f / INS: %.3f / DEL: %.3f' % (nsub, nins, ndel)) else: raise ValueError(args.unit) logger.info('Elapsed time: %.2f [sec]' % (time.time() - start_time)) if args.unit == 'word': logger.info('WER (mean): %.3f %%\n' % (wer_mean / len(args.eval_sets))) elif args.unit == 'wp': logger.info('WER (mean): %.3f %%\n' % (wer_mean / len(args.eval_sets))) elif 'char' in args.unit: logger.info('WER / CER (mean): %.3f / %.3f %%\n' % (wer_mean / len(args.eval_sets), cer_mean / len(args.eval_sets))) elif 'phone' in args.unit: logger.info('PER (mean): %.3f %%\n' % (per_mean / len(args.eval_sets)))
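# --- Illustrative sketch of the reported metric ---
# eval_word / eval_wordpiece / eval_char above report edit-distance-based
# error rates plus SUB/INS/DEL counts. The helper below is a hypothetical,
# minimal WER computation via Levenshtein alignment, returning only the
# overall rate as a percentage; the bundled evaluators do more (alignment
# dumps, OOV counting, per-operation statistics).
def wer(ref_words, hyp_words):
    """Word error rate between two token lists, in percent."""
    n, m = len(ref_words), len(hyp_words)
    d = [[0] * (m + 1) for _ in range(n + 1)]
    for i in range(n + 1):
        d[i][0] = i
    for j in range(m + 1):
        d[0][j] = j
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            sub = d[i - 1][j - 1] + (ref_words[i - 1] != hyp_words[j - 1])
            d[i][j] = min(sub, d[i - 1][j] + 1, d[i][j - 1] + 1)
    return 100.0 * d[n][m] / max(n, 1)

# Example: wer('a b c'.split(), 'a x c'.split()) -> 33.33 (one substitution out of three words).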
def main(): args = parse() # Load a conf file dir_name = os.path.dirname(args.recog_model[0]) conf = load_config(os.path.join(dir_name, 'conf.yml')) # Overwrite conf for k, v in conf.items(): if 'recog' not in k: setattr(args, k, v) recog_params = vars(args) # Setting for logging if os.path.isfile(os.path.join(args.recog_dir, 'plot.log')): os.remove(os.path.join(args.recog_dir, 'plot.log')) logger = set_logger(os.path.join(args.recog_dir, 'plot.log'), key='decoding') for i, s in enumerate(args.recog_sets): subsample_factor = 1 subsample = [int(s) for s in args.subsample.split('_')] if args.conv_poolings: for p in args.conv_poolings.split('_'): p = int(p.split(',')[0].replace('(', '')) if p > 1: subsample_factor *= p subsample_factor *= np.prod(subsample) # Load dataset dataset = Dataset(corpus=args.corpus, tsv_path=s, dict_path=os.path.join(dir_name, 'dict.txt'), dict_path_sub1=os.path.join(dir_name, 'dict_sub1.txt') if os.path.isfile( os.path.join(dir_name, 'dict_sub1.txt')) else False, nlsyms=args.nlsyms, wp_model=os.path.join(dir_name, 'wp.model'), unit=args.unit, unit_sub1=args.unit_sub1, batch_size=args.recog_batch_size, concat_prev_n_utterances=args.recog_concat_prev_n_utterances, is_test=True) if i == 0: # Load the ASR model model = Seq2seq(args) epoch = model.load_checkpoint(args.recog_model[0])['epoch'] model.save_path = dir_name if not args.recog_unit: args.recog_unit = args.unit logger.info('recog unit: %s' % args.recog_unit) logger.info('epoch: %d' % (epoch - 1)) logger.info('batch size: %d' % args.recog_batch_size) # GPU setting model.cuda() # TODO(hirofumi): move this save_path = mkdir_join(args.plot_dir, 'ctc_probs') # Clean directory if save_path is not None and os.path.isdir(save_path): shutil.rmtree(save_path) os.mkdir(save_path) while True: batch, is_new_epoch = dataset.next(recog_params['recog_batch_size']) best_hyps_id, aws, perm_ids, _ = model.decode(batch['xs'], recog_params, exclude_eos=False) ys = [batch['ys'][i] for i in perm_ids] # Get CTC probs ctc_probs, indices_topk, xlens = model.get_ctc_posteriors( batch['xs'], temperature=1, topk=min(100, model.vocab)) # NOTE: ctc_probs: '[B, T, topk]' for b in range(len(batch['xs'])): tokens = dataset.idx2token[0](best_hyps_id[b], return_list=True) tokens = [unicode(t, 'utf-8') for t in tokens] spk = '_'.join(batch['utt_ids'][b].replace('-', '_').split('_')[:-2]) plot_ctc_probs( ctc_probs[b, :xlens[b]], indices_topk[b], nframes=xlens[b], subsample_factor=subsample_factor, spectrogram=batch['xs'][b][:, :dataset.input_dim], save_path=mkdir_join(save_path, spk, batch['utt_ids'][b] + '.png'), figsize=(20, 8)) ref = ys[b] hyp = ' '.join(tokens) logger.info('utt-id: %s' % batch['utt_ids'][b]) logger.info('Ref: %s' % ref.lower()) logger.info('Hyp: %s' % hyp) logger.info('-' * 50) if is_new_epoch: break
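# --- Worked example of the subsample_factor computation above ---
# The parsing mirrors the code in this function (time-axis pooling is the
# first element of each '(t,f)' pair); the configuration strings in the
# example are hypothetical.
import numpy as np

def total_subsample_factor(conv_poolings, subsample):
    factor = 1
    if conv_poolings:
        for p in conv_poolings.split('_'):
            p = int(p.split(',')[0].replace('(', ''))  # time-axis pooling
            if p > 1:
                factor *= p
    factor *= int(np.prod([int(s) for s in subsample.split('_')]))
    return factor

# e.g. conv_poolings='(2,2)_(2,2)' and subsample='1_2_2_1_1'
# gives 2 * 2 * (1*2*2*1*1) = 16 input frames per encoder output frame,
# which is the factor passed to plot_ctc_probs() for aligning the spectrogram.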