def main():

    # Load a config file
    config = load_config(os.path.join(args.model, 'config.yml'))
    decode_params = vars(args)

    # Merge config with args
    for k, v in config.items():
        if not hasattr(args, k):
            setattr(args, k, v)

    # Setting for logging
    logger = set_logger(os.path.join(args.plot_dir, 'plot.log'), key='decoding')

    for i, set in enumerate(args.eval_sets):
        # Load dataset
        eval_set = Dataset(csv_path=set,
                           dict_path=os.path.join(args.model, 'dict.txt'),
                           dict_path_sub=os.path.join(args.model, 'dict_sub.txt') if os.path.isfile(
                               os.path.join(args.model, 'dict_sub.txt')) else None,
                           wp_model=os.path.join(args.model, 'wp.model'),
                           unit=args.unit,
                           batch_size=args.batch_size,
                           max_num_frames=args.max_num_frames,
                           min_num_frames=args.min_num_frames,
                           is_test=True)

        if i == 0:
            args.vocab = eval_set.vocab
            args.vocab_sub = eval_set.vocab_sub
            args.input_dim = eval_set.input_dim

            # TODO(hirofumi): For cold fusion
            args.rnnlm_cold_fusion = None
            args.rnnlm_init = None

            # Load the ASR model
            model = Seq2seq(args)
            epoch, _, _, _ = model.load_checkpoint(args.model, epoch=args.epoch)
            model.save_path = args.model

            # For shallow fusion
            if args.rnnlm_cold_fusion is None and args.rnnlm is not None and args.rnnlm_weight > 0:
                # Load a RNNLM config file
                config_rnnlm = load_config(os.path.join(args.rnnlm, 'config.yml'))

                # Merge config with args
                args_rnnlm = argparse.Namespace()
                for k, v in config_rnnlm.items():
                    setattr(args_rnnlm, k, v)

                assert args.unit == args_rnnlm.unit
                args_rnnlm.vocab = eval_set.vocab

                # Load the pre-trained RNNLM
                rnnlm = RNNLM(args_rnnlm)
                rnnlm.load_checkpoint(args.rnnlm, epoch=-1)
                if args_rnnlm.backward:
                    model.rnnlm_bwd_0 = rnnlm
                else:
                    model.rnnlm_fwd_0 = rnnlm

                logger.info('RNNLM path: %s' % args.rnnlm)
                logger.info('RNNLM weight: %.3f' % args.rnnlm_weight)
                logger.info('RNNLM backward: %s' % str(config_rnnlm['backward']))

            # GPU setting
            model.cuda()

            logger.info('beam width: %d' % args.beam_width)
            logger.info('length penalty: %.3f' % args.length_penalty)
            logger.info('coverage penalty: %.3f' % args.coverage_penalty)
            logger.info('coverage threshold: %.3f' % args.coverage_threshold)
            logger.info('epoch: %d' % (epoch - 1))

        save_path = mkdir_join(args.plot_dir, 'att_weights')

        # Clean directory
        if save_path is not None and os.path.isdir(save_path):
            shutil.rmtree(save_path)
            os.mkdir(save_path)

        while True:
            batch, is_new_epoch = eval_set.next(decode_params['batch_size'])
            best_hyps, aws, perm_idx = model.decode(batch['xs'], decode_params, exclude_eos=False)
            ys = [batch['ys'][i] for i in perm_idx]

            if model.bwd_weight > 0.5:
                # Reverse the order
                best_hyps = [hyp[::-1] for hyp in best_hyps]
                aws = [aw[::-1] for aw in aws]

            for b in range(len(batch['xs'])):
                if args.unit == 'word':
                    token_list = eval_set.idx2word(best_hyps[b], return_list=True)
                elif args.unit == 'wp':
                    token_list = eval_set.idx2wp(best_hyps[b], return_list=True)
                elif args.unit == 'char':
                    token_list = eval_set.idx2char(best_hyps[b], return_list=True)
                elif args.unit == 'phone':
                    token_list = eval_set.idx2phone(best_hyps[b], return_list=True)
                else:
                    raise NotImplementedError(args.unit)
                token_list = [unicode(t, 'utf-8') for t in token_list]
                speaker = '_'.join(batch['utt_ids'][b].replace('-', '_').split('_')[:-2])

                # error check
                assert len(batch['xs'][b]) <= 2000

                plot_attention_weights(aws[b][:len(token_list)],
                                       label_list=token_list,
                                       spectrogram=batch['xs'][b][:, :eval_set.input_dim]
                                       if args.input_type == 'speech' else None,
                                       save_path=mkdir_join(save_path, speaker,
                                                            batch['utt_ids'][b] + '.png'),
                                       figsize=(20, 8))

                ref = ys[b]
                if model.bwd_weight > 0.5:
                    hyp = ' '.join(token_list[::-1])
                else:
                    hyp = ' '.join(token_list)
                logger.info('utt-id: %s' % batch['utt_ids'][b])
                logger.info('Ref: %s' % ref.lower())
                logger.info('Hyp: %s' % hyp)
                logger.info('-' * 50)

            if is_new_epoch:
                break
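# NOTE: `plot_attention_weights` used above is assumed to come from the repository's
# plotting utilities. As a hedged reference only, a minimal matplotlib sketch of what
# such a helper might do is given below; the `_sketch` name is hypothetical and the
# actual implementation in this repository may differ.
def plot_attention_weights_sketch(aw, label_list, spectrogram=None,
                                  save_path=None, figsize=(20, 8)):
    """Plot an attention matrix (output tokens x input frames), optionally with the input features."""
    import matplotlib
    matplotlib.use('Agg')  # render to file without a display
    import matplotlib.pyplot as plt
    import numpy as np

    n_rows = 2 if spectrogram is not None else 1
    fig, axes = plt.subplots(n_rows, 1, figsize=figsize, squeeze=False)

    # Attention matrix: one row per output token, one column per encoder frame
    axes[0][0].imshow(np.asarray(aw), aspect='auto', interpolation='nearest')
    axes[0][0].set_yticks(np.arange(len(label_list)))
    axes[0][0].set_yticklabels(label_list)
    axes[0][0].set_xlabel('Input frames')

    if spectrogram is not None:
        # Input features: frames x feature dim, transposed so time runs along x
        axes[1][0].imshow(np.asarray(spectrogram).T, aspect='auto', origin='lower')
        axes[1][0].set_xlabel('Input frames')
        axes[1][0].set_ylabel('Feature dim')

    fig.tight_layout()
    if save_path is not None:
        fig.savefig(save_path)
    plt.close(fig)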
def main():

    # Load a config file
    config = load_config(os.path.join(args.model, 'config.yml'))
    decode_params = vars(args)

    # Merge config with args
    for k, v in config.items():
        if not hasattr(args, k):
            setattr(args, k, v)

    # Setting for logging
    logger = set_logger(os.path.join(args.model, 'decode.log'), key='decoding')

    for i, set in enumerate(args.eval_sets):
        # Load dataset
        eval_set = Dataset(csv_path=set,
                           dict_path=os.path.join(args.model, 'dict.txt'),
                           dict_path_sub=os.path.join(args.model, 'dict_sub.txt')
                           if os.path.isfile(os.path.join(args.model, 'dict_sub.txt')) else None,
                           label_type=args.label_type,
                           batch_size=args.batch_size,
                           max_epoch=args.num_epochs,
                           max_num_frames=args.max_num_frames,
                           min_num_frames=args.min_num_frames,
                           is_test=False)

        if i == 0:
            args.num_classes = eval_set.num_classes
            args.input_dim = eval_set.input_dim
            args.num_classes_sub = eval_set.num_classes_sub

            # TODO(hirofumi): For cold fusion
            args.rnnlm_cf = None
            args.rnnlm_init = None

            # Load the ASR model
            model = Seq2seq(args)

            # Restore the saved parameters
            epoch, _, _, _ = model.load_checkpoint(args.model, epoch=args.epoch)
            model.save_path = args.model

            # For shallow fusion
            if args.rnnlm_cf is None and args.rnnlm is not None and args.rnnlm_weight > 0:
                # Load a RNNLM config file
                config_rnnlm = load_config(os.path.join(args.rnnlm, 'config.yml'))

                # Merge config with args
                args_rnnlm = argparse.Namespace()
                for k, v in config_rnnlm.items():
                    setattr(args_rnnlm, k, v)

                assert args.label_type == args_rnnlm.label_type
                args_rnnlm.num_classes = eval_set.num_classes

                # Load the pre-trained RNNLM
                rnnlm = RNNLM(args_rnnlm)
                rnnlm.load_checkpoint(args.rnnlm, epoch=-1)
                if args_rnnlm.backward:
                    model.rnnlm_bwd_0 = rnnlm
                else:
                    model.rnnlm_fwd_0 = rnnlm

                logger.info('RNNLM path: %s' % args.rnnlm)
                logger.info('RNNLM weight: %.3f' % args.rnnlm_weight)
                logger.info('RNNLM backward: %s' % str(config_rnnlm['backward']))

            # GPU setting
            model.set_cuda(deterministic=False, benchmark=True)

            logger.info('beam width: %d' % args.beam_width)
            logger.info('length penalty: %.3f' % args.length_penalty)
            logger.info('coverage penalty: %.3f' % args.coverage_penalty)
            logger.info('coverage threshold: %.3f' % args.coverage_threshold)
            logger.info('epoch: %d' % (epoch - 1))

        save_path = mkdir_join(args.model, 'att_weights')

        # Clean directory
        if save_path is not None and os.path.isdir(save_path):
            shutil.rmtree(save_path)
            os.mkdir(save_path)

        while True:
            batch, is_new_epoch = eval_set.next(decode_params['batch_size'])
            best_hyps, aw, perm_idx = model.decode(batch['xs'], decode_params, exclude_eos=False)
            ys = [batch['ys'][i] for i in perm_idx]

            for b in range(len(batch['xs'])):
                if args.label_type in ['word', 'wordpiece']:
                    token_list = eval_set.idx2word(best_hyps[b], return_list=True)
                elif args.label_type == 'char':
                    token_list = eval_set.idx2char(best_hyps[b], return_list=True)
                elif args.label_type == 'phone':
                    token_list = eval_set.idx2phone(best_hyps[b], return_list=True)
                else:
                    raise NotImplementedError()
                token_list = [unicode(t, 'utf-8') for t in token_list]
                speaker = '_'.join(batch['utt_ids'][b].replace('-', '_').split('_')[:-2])

                # error check
                assert len(batch['xs'][b]) <= 2000

                plot_attention_weights(aw[b][:len(token_list)],
                                       label_list=token_list,
                                       spectrogram=batch['xs'][b][:, :eval_set.input_dim]
                                       if args.input_type == 'speech' else None,
                                       save_path=mkdir_join(save_path, speaker,
                                                            batch['utt_ids'][b] + '.png'),
                                       figsize=(20, 8))

                # Reference
                if eval_set.is_test:
                    text_ref = ys[b]
                else:
                    if args.label_type in ['word', 'wordpiece']:
                        text_ref = eval_set.idx2word(ys[b])
                    elif args.label_type == 'char':
                        text_ref = eval_set.idx2char(ys[b])
                    elif args.label_type == 'phone':
                        text_ref = eval_set.idx2phone(ys[b])

                # Hypothesis
                text_hyp = ' '.join(token_list)

                # Redirect stdout so the alignment and LER are written to a per-utterance text file
                sys.stdout = open(os.path.join(save_path, speaker,
                                               batch['utt_ids'][b] + '.txt'), 'w')

                ler = wer_align(ref=text_ref.split(' '),
                                hyp=text_hyp.encode('utf-8').split(' '),
                                normalize=True,
                                double_byte=False)[0]  # TODO(hirofumi): add corpus to args
                print('\nLER: %.3f %%\n\n' % ler)

            if is_new_epoch:
                break
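# NOTE: `wer_align` above is assumed to be the repository's alignment-based error-rate
# helper. As a hedged reference only, the word-level Levenshtein distance underlying
# such a metric is sketched below; the `_sketch` name is hypothetical, and the real
# helper additionally returns substitution/insertion/deletion counts and an alignment.
def word_error_rate_sketch(ref, hyp, normalize=True):
    """Return the edit distance between two token lists, as a percentage of len(ref) if normalize."""
    d = [[0] * (len(hyp) + 1) for _ in range(len(ref) + 1)]
    for i in range(len(ref) + 1):
        d[i][0] = i
    for j in range(len(hyp) + 1):
        d[0][j] = j
    for i in range(1, len(ref) + 1):
        for j in range(1, len(hyp) + 1):
            cost = 0 if ref[i - 1] == hyp[j - 1] else 1
            d[i][j] = min(d[i - 1][j] + 1,         # deletion
                          d[i][j - 1] + 1,         # insertion
                          d[i - 1][j - 1] + cost)  # substitution
    dist = d[len(ref)][len(hyp)]
    if normalize and len(ref) > 0:
        return 100.0 * dist / len(ref)
    return float(dist)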
def main():

    # Load a config file
    if args.resume_model is None:
        config = load_config(args.config)
    else:
        # Restart from the last checkpoint
        config = load_config(os.path.join(args.resume_model, 'config.yml'))

    # Check differences between args and the yaml configuration
    for k, v in vars(args).items():
        if k not in config.keys():
            warnings.warn("key %s is automatically set to %s" % (k, str(v)))

    # Merge config with args
    for k, v in config.items():
        setattr(args, k, v)

    # Automatically reduce batch size in multi-GPU setting
    if args.ngpus > 1:
        args.batch_size -= 10
        args.print_step //= args.ngpus

    subsample_factor = 1
    subsample_factor_sub = 1
    for p in args.conv_poolings:
        if len(p) > 0:
            subsample_factor *= p[0]
    if args.train_set_sub is not None:
        subsample_factor_sub = subsample_factor * (2**sum(
            args.subsample[:args.enc_num_layers_sub - 1]))
    subsample_factor *= 2**sum(args.subsample)

    # Load dataset
    train_set = Dataset(csv_path=args.train_set,
                        dict_path=args.dict,
                        label_type=args.label_type,
                        batch_size=args.batch_size * args.ngpus,
                        max_epoch=args.num_epochs,
                        max_num_frames=args.max_num_frames,
                        min_num_frames=args.min_num_frames,
                        sort_by_input_length=True,
                        short2long=True,
                        sort_stop_epoch=args.sort_stop_epoch,
                        dynamic_batching=True,
                        use_ctc=args.ctc_weight > 0,
                        subsample_factor=subsample_factor,
                        csv_path_sub=args.train_set_sub,
                        dict_path_sub=args.dict_sub,
                        label_type_sub=args.label_type_sub,
                        use_ctc_sub=args.ctc_weight_sub > 0,
                        subsample_factor_sub=subsample_factor_sub,
                        skip_speech=(args.input_type != 'speech'))
    dev_set = Dataset(csv_path=args.dev_set,
                      dict_path=args.dict,
                      label_type=args.label_type,
                      batch_size=args.batch_size * args.ngpus,
                      max_epoch=args.num_epochs,
                      max_num_frames=args.max_num_frames,
                      min_num_frames=args.min_num_frames,
                      shuffle=True,
                      use_ctc=args.ctc_weight > 0,
                      subsample_factor=subsample_factor,
                      csv_path_sub=args.dev_set_sub,
                      dict_path_sub=args.dict_sub,
                      label_type_sub=args.label_type_sub,
                      use_ctc_sub=args.ctc_weight_sub > 0,
                      subsample_factor_sub=subsample_factor_sub,
                      skip_speech=(args.input_type != 'speech'))
    eval_sets = []
    for set in args.eval_sets:
        eval_sets += [Dataset(csv_path=set,
                              dict_path=args.dict,
                              label_type=args.label_type,
                              batch_size=1,
                              max_epoch=args.num_epochs,
                              is_test=True,
                              skip_speech=(args.input_type != 'speech'))]

    args.num_classes = train_set.num_classes
    args.input_dim = train_set.input_dim
    args.num_classes_sub = train_set.num_classes_sub

    # Load a RNNLM config file for cold fusion & RNNLM initialization
    # if config['rnnlm_cf']:
    #     if args.model is not None:
    #         config['rnnlm_config_cold_fusion'] = load_config(
    #             os.path.join(config['rnnlm_cf'], 'config.yml'), is_eval=True)
    #     elif args.resume_model is not None:
    #         config = load_config(os.path.join(
    #             args.resume_model, 'config_rnnlm_cf.yml'))
    #     assert args.label_type == config['rnnlm_config_cold_fusion']['label_type']
    #     config['rnnlm_config_cold_fusion']['num_classes'] = train_set.num_classes
    args.rnnlm_cf = None
    args.rnnlm_init = None

    # Model setting
    model = Seq2seq(args)
    model.name = args.enc_type
    if len(args.conv_channels) > 0:
        tmp = model.name
        model.name = 'conv' + str(len(args.conv_channels)) + 'L'
        if args.conv_batch_norm:
            model.name += 'bn'
        model.name += tmp
    model.name += str(args.enc_num_units) + 'H'
    model.name += str(args.enc_num_projs) + 'P'
    model.name += str(args.enc_num_layers) + 'L'
    model.name += '_subsample' + str(subsample_factor)
    model.name += '_' + args.dec_type
    model.name += str(args.dec_num_units) + 'H'
    # model.name += str(args.dec_num_projs) + 'P'
    model.name += str(args.dec_num_layers) + 'L'
    model.name += '_' + args.att_type
    if args.att_num_heads > 1:
        model.name += '_head' + str(args.att_num_heads)
    model.name += '_' + args.optimizer
    model.name += '_lr' + str(args.learning_rate)
    model.name += '_bs' + str(args.batch_size)
    model.name += '_ss' + str(args.ss_prob)
    model.name += '_ls' + str(args.lsm_prob)
    if args.ctc_weight > 0:
        model.name += '_ctc' + str(args.ctc_weight)
    if args.bwd_weight > 0:
        model.name += '_bwd' + str(args.bwd_weight)
    if args.main_task_weight < 1:
        model.name += '_main' + str(args.main_task_weight)
        if args.ctc_weight_sub > 0:
            model.name += '_ctcsub' + str(args.ctc_weight_sub * (1 - args.main_task_weight))
        else:
            model.name += '_attsub' + str(1 - args.main_task_weight)

    if args.resume_model is None:
        # Load pre-trained RNNLM
        # if config['rnnlm_cf']:
        #     rnnlm = RNNLM(args)
        #     rnnlm.load_checkpoint(save_path=config['rnnlm_cf'], epoch=-1)
        #     rnnlm.flatten_parameters()
        #
        #     # Fix RNNLM parameters
        #     for param in rnnlm.parameters():
        #         param.requires_grad = False
        #
        #     # Set pre-trained parameters
        #     if config['rnnlm_config_cold_fusion']['backward']:
        #         model.dec_0_bwd.rnnlm = rnnlm
        #     else:
        #         model.dec_0_fwd.rnnlm = rnnlm
        # TODO(hirofumi): copy the RNNLM model first

        # Set save path
        save_path = mkdir_join(
            args.model,
            '_'.join(os.path.basename(args.train_set).split('.')[:-1]),
            model.name)
        model.set_save_path(save_path)  # avoid overwriting

        # Save the config file as a yaml file
        save_config(vars(args), model.save_path)

        # Save the dictionary & wp_model
        shutil.copy(args.dict, os.path.join(save_path, 'dict.txt'))
        if args.dict_sub is not None:
            shutil.copy(args.dict_sub, os.path.join(save_path, 'dict_sub.txt'))
        if args.label_type == 'wordpiece':
            shutil.copy(args.wp_model, os.path.join(save_path, 'wp.model'))

        # Setting for logging
        logger = set_logger(os.path.join(model.save_path, 'train.log'), key='training')

        for k, v in sorted(vars(args).items(), key=lambda x: x[0]):
            logger.info('%s: %s' % (k, str(v)))

        # if os.path.isdir(args.pretrained_model):
        #     # NOTE: Start training from the pre-trained model
        #     # This is different from resuming training
        #     model.load_checkpoint(args.pretrained_model, epoch=-1,
        #                           load_pretrained_model=True)

        # Count total parameters
        for name in sorted(list(model.num_params_dict.keys())):
            num_params = model.num_params_dict[name]
            logger.info("%s %d" % (name, num_params))
        logger.info("Total %.2f M parameters" % (model.total_parameters / 1000000))

        # Set optimizer
        model.set_optimizer(optimizer=args.optimizer,
                            learning_rate_init=float(args.learning_rate),
                            weight_decay=float(args.weight_decay),
                            clip_grad_norm=args.clip_grad_norm,
                            lr_schedule=False,
                            factor=args.decay_rate,
                            patience_epoch=args.decay_patient_epoch)

        epoch, step = 1, 0
        learning_rate = float(args.learning_rate)
        metric_dev_best = 10000

    # NOTE: Restart from the last checkpoint
    # elif args.resume_model is not None:
    #     # Set save path
    #     model.save_path = args.resume_model
    #
    #     # Setting for logging
    #     logger = set_logger(os.path.join(model.save_path, 'train.log'), key='training')
    #
    #     # Set optimizer
    #     model.set_optimizer(
    #         optimizer=config['optimizer'],
    #         learning_rate_init=float(config['learning_rate']),  # on-the-fly
    #         weight_decay=float(config['weight_decay']),
    #         clip_grad_norm=config['clip_grad_norm'],
    #         lr_schedule=False,
    #         factor=config['decay_rate'],
    #         patience_epoch=config['decay_patient_epoch'])
    #
    #     # Restore the last saved model
    #     epoch, step, learning_rate, metric_dev_best = model.load_checkpoint(
    #         save_path=args.resume_model, epoch=-1, restart=True)
    #
    #     if epoch >= config['convert_to_sgd_epoch']:
    #         model.set_optimizer(
    #             optimizer='sgd',
    #             learning_rate_init=float(config['learning_rate']),  # on-the-fly
    #             weight_decay=float(config['weight_decay']),
    #             clip_grad_norm=config['clip_grad_norm'],
    #             lr_schedule=False,
    #             factor=config['decay_rate'],
    #             patience_epoch=config['decay_patient_epoch'])
    #
    # if config['rnnlm_cf']:
    #     if config['rnnlm_config_cold_fusion']['backward']:
    #         model.rnnlm_0_bwd.flatten_parameters()
    #     else:
    #         model.rnnlm_0_fwd.flatten_parameters()

    train_set.epoch = epoch - 1  # start from index:0

    # GPU setting
    if args.ngpus >= 1:
        model = CustomDataParallel(model,
                                   device_ids=list(range(0, args.ngpus, 1)),
                                   deterministic=False,
                                   benchmark=True)
        model.cuda()

    logger.info('PID: %s' % os.getpid())
    logger.info('USERNAME: %s' % os.uname()[1])

    # Set process name
    # setproctitle(args.job_name)

    # Set learning rate controller
    lr_controller = Controller(learning_rate_init=learning_rate,
                               decay_type=args.decay_type,
                               decay_start_epoch=args.decay_start_epoch,
                               decay_rate=args.decay_rate,
                               decay_patient_epoch=args.decay_patient_epoch,
                               lower_better=True,
                               best_value=metric_dev_best)

    # Set reporter
    reporter = Reporter(model.module.save_path, max_loss=300)

    # Set the updater
    updater = Updater(args.clip_grad_norm)

    # Setting for tensorboard
    tf_writer = SummaryWriter(model.module.save_path)

    start_time_train = time.time()
    start_time_epoch = time.time()
    start_time_step = time.time()
    not_improved_epoch = 0
    loss_train_mean, acc_train_mean = 0., 0.
    pbar_epoch = tqdm(total=len(train_set))
    pbar_all = tqdm(total=len(train_set) * args.num_epochs)
    while True:
        # Compute loss in the training set (including parameter update)
        batch_train, is_new_epoch = train_set.next()
        model, loss_train, acc_train = updater(model, batch_train)
        loss_train_mean += loss_train
        acc_train_mean += acc_train
        pbar_epoch.update(len(batch_train['utt_ids']))

        if (step + 1) % args.print_step == 0:
            # Compute loss in the dev set
            batch_dev = dev_set.next()[0]
            model, loss_dev, acc_dev = updater(model, batch_dev, is_eval=True)

            loss_train_mean /= args.print_step
            acc_train_mean /= args.print_step
            reporter.step(step, loss_train_mean, loss_dev, acc_train_mean, acc_dev)

            # Logging by tensorboard
            tf_writer.add_scalar('train/loss', loss_train_mean, step + 1)
            tf_writer.add_scalar('dev/loss', loss_dev, step + 1)
            # for n, p in model.module.named_parameters():
            #     n = n.replace('.', '/')
            #     if p.grad is not None:
            #         tf_writer.add_histogram(n, p.data.cpu().numpy(), step + 1)
            #         tf_writer.add_histogram(n + '/grad', p.grad.data.cpu().numpy(), step + 1)

            duration_step = time.time() - start_time_step
            if args.input_type == 'speech':
                x_len = max(len(x) for x in batch_train['xs'])
            elif args.input_type == 'text':
                x_len = max(len(x) for x in batch_train['ys_sub'])
            logger.info(
                "...Step:%d(ep:%.2f) loss:%.2f(%.2f)/acc:%.2f(%.2f)/lr:%.5f/bs:%d/x_len:%d (%.2f min)" %
                (step + 1, train_set.epoch_detail,
                 loss_train_mean, loss_dev, acc_train_mean, acc_dev,
                 learning_rate, train_set.current_batch_size, x_len, duration_step / 60))
            start_time_step = time.time()
            loss_train_mean, acc_train_mean = 0., 0.
        step += args.ngpus

        # Save checkpoint and evaluate model per epoch
        if is_new_epoch:
            duration_epoch = time.time() - start_time_epoch
            logger.info('===== EPOCH:%d (%.2f min) =====' % (epoch, duration_epoch / 60))

            # Save figures of loss and accuracy
            reporter.epoch()

            if epoch < args.eval_start_epoch:
                # Save the model
                model.module.save_checkpoint(model.module.save_path, epoch, step,
                                             learning_rate, metric_dev_best)
            else:
                start_time_eval = time.time()
                # dev
                if args.metric == 'ler':
                    if args.label_type == 'word':
                        metric_dev = eval_word([model.module], dev_set, decode_params,
                                               epoch=epoch)[0]
                        logger.info(' WER (%s): %.3f %%' % (dev_set.set, metric_dev))
                    elif args.label_type == 'wordpiece':
                        metric_dev = eval_wordpiece([model.module], dev_set, decode_params,
                                                    args.wp_model, epoch=epoch)[0]
                        logger.info(' WER (%s): %.3f %%' % (dev_set.set, metric_dev))
                    elif 'char' in args.label_type:
                        metric_dev = eval_char([model.module], dev_set, decode_params,
                                               epoch=epoch)[1][0]
                        logger.info(' CER (%s): %.3f %%' % (dev_set.set, metric_dev))
                    elif 'phone' in args.label_type:
                        metric_dev = eval_phone([model.module], dev_set, decode_params,
                                                epoch=epoch)[0]
                        logger.info(' PER (%s): %.3f %%' % (dev_set.set, metric_dev))
                elif args.metric == 'loss':
                    metric_dev = eval_loss([model.module], dev_set, decode_params)
                    logger.info(' Loss (%s): %.3f %%' % (dev_set.set, metric_dev))
                else:
                    raise NotImplementedError()

                if metric_dev < metric_dev_best:
                    metric_dev_best = metric_dev
                    not_improved_epoch = 0
                    logger.info('||||| Best Score |||||')

                    # Update learning rate
                    model.module.optimizer, learning_rate = lr_controller.decay_lr(
                        optimizer=model.module.optimizer,
                        learning_rate=learning_rate,
                        epoch=epoch,
                        value=metric_dev)

                    # Save the model
                    model.module.save_checkpoint(model.module.save_path, epoch, step,
                                                 learning_rate, metric_dev_best)

                    # test
                    for eval_set in eval_sets:
                        if args.metric == 'ler':
                            if args.label_type == 'word':
                                wer_test = eval_word([model.module], eval_set, decode_params,
                                                     epoch=epoch)[0]
                                logger.info(' WER (%s): %.3f %%' % (eval_set.set, wer_test))
                            elif args.label_type == 'wordpiece':
                                wer_test = eval_wordpiece([model.module], eval_set, decode_params,
                                                          epoch=epoch)[0]
                                logger.info(' WER (%s): %.3f %%' % (eval_set.set, wer_test))
                            elif 'char' in args.label_type:
                                cer_test = eval_char([model.module], eval_set, decode_params,
                                                     epoch=epoch)[1][0]
                                logger.info(' CER (%s): %.3f %%' % (eval_set.set, cer_test))
                            elif 'phone' in args.label_type:
                                per_test = eval_phone([model.module], eval_set, decode_params,
                                                      epoch=epoch)[0]
                                logger.info(' PER (%s): %.3f %%' % (eval_set.set, per_test))
                        elif args.metric == 'loss':
                            loss_test = eval_loss([model.module], eval_set, decode_params)
                            logger.info(' Loss (%s): %.3f %%' % (eval_set.set, loss_test))
                        else:
                            raise NotImplementedError()
                else:
                    # Update learning rate
                    model.module.optimizer, learning_rate = lr_controller.decay_lr(
                        optimizer=model.module.optimizer,
                        learning_rate=learning_rate,
                        epoch=epoch,
                        value=metric_dev)
                    not_improved_epoch += 1

                duration_eval = time.time() - start_time_eval
                logger.info('Evaluation time: %.2f min' % (duration_eval / 60))

                # Early stopping
                if not_improved_epoch == args.not_improved_patient_epoch:
                    break

                if epoch == args.convert_to_sgd_epoch:
                    # Convert to fine-tuning stage
                    model.module.set_optimizer(
                        'sgd',
                        learning_rate_init=float(args.learning_rate),  # TODO: ?
                        weight_decay=float(args.weight_decay),
                        clip_grad_norm=args.clip_grad_norm,
                        lr_schedule=False,
                        factor=args.decay_rate,
                        patience_epoch=args.decay_patient_epoch)
                    logger.info('========== Convert to SGD ==========')

            pbar_epoch = tqdm(total=len(train_set))
            pbar_all.update(len(train_set))

            if epoch == args.num_epochs:
                break

            start_time_step = time.time()
            start_time_epoch = time.time()
            epoch += 1

    duration_train = time.time() - start_time_train
    logger.info('Total time: %.2f hour' % (duration_train / 3600))

    tf_writer.close()
    pbar_epoch.close()
    pbar_all.close()

    return model.module.save_path
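# NOTE: `Controller.decay_lr` above belongs to the repository's training utilities. As a
# hedged reference only, the patience-based decay logic such a controller might implement
# is sketched below; the `Sketch` names are hypothetical and the real class may differ.
class LRControllerSketch(object):
    """Decay the learning rate when the validation metric stops improving."""

    def __init__(self, learning_rate_init, decay_start_epoch, decay_rate,
                 decay_patient_epoch=0, lower_better=True, best_value=10000):
        self.decay_start_epoch = decay_start_epoch
        self.decay_rate = decay_rate
        self.decay_patient_epoch = decay_patient_epoch
        self.lower_better = lower_better
        self.best_value = best_value
        self.not_improved = 0

    def decay_lr(self, optimizer, learning_rate, epoch, value):
        if not self.lower_better:
            value = -value
        improved = value < self.best_value
        if improved:
            self.best_value = value
            self.not_improved = 0

        # No decay before the warmup period or while the metric keeps improving
        if epoch < self.decay_start_epoch or improved:
            return optimizer, learning_rate

        self.not_improved += 1
        if self.not_improved < self.decay_patient_epoch:
            return optimizer, learning_rate

        # Decay and write the new rate back into the (PyTorch-style) optimizer
        self.not_improved = 0
        learning_rate *= self.decay_rate
        for param_group in optimizer.param_groups:
            param_group['lr'] = learning_rate
        return optimizer, learning_rate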
def main():

    # Load a config file
    if args.resume_model is None:
        config = load_config(args.config)
    else:
        # Restart from the last checkpoint
        config = load_config(os.path.join(args.resume_model, 'config.yml'))

    # Check differences between args and the yaml configuration
    for k, v in vars(args).items():
        if k not in config.keys():
            warnings.warn("key %s is automatically set to %s" % (k, str(v)))

    # Merge config with args
    for k, v in config.items():
        setattr(args, k, v)

    # Load dataset
    train_set = Dataset(csv_path=args.train_set,
                        dict_path=args.dict,
                        label_type=args.label_type,
                        batch_size=args.batch_size * args.ngpus,
                        bptt=args.bptt,
                        eos=args.eos,
                        max_epoch=args.num_epochs,
                        shuffle=True)
    dev_set = Dataset(csv_path=args.dev_set,
                      dict_path=args.dict,
                      label_type=args.label_type,
                      batch_size=args.batch_size * args.ngpus,
                      bptt=args.bptt,
                      eos=args.eos,
                      shuffle=True)
    eval_sets = []
    for set in args.eval_sets:
        eval_sets += [Dataset(csv_path=set,
                              dict_path=args.dict,
                              label_type=args.label_type,
                              batch_size=1,
                              bptt=args.bptt,
                              eos=args.eos,
                              is_test=True)]

    args.num_classes = train_set.num_classes

    # Model setting
    model = RNNLM(args)
    model.name = args.rnn_type
    model.name += str(args.num_units) + 'H'
    model.name += str(args.num_projs) + 'P'
    model.name += str(args.num_layers) + 'L'
    model.name += '_emb' + str(args.emb_dim)
    model.name += '_' + args.optimizer
    model.name += '_lr' + str(args.learning_rate)
    model.name += '_bs' + str(args.batch_size)
    if args.tie_weights:
        model.name += '_tie'
    if args.residual:
        model.name += '_residual'
    if args.backward:
        model.name += '_bwd'

    if args.resume_model is None:
        # Set save path
        save_path = mkdir_join(
            args.model,
            '_'.join(os.path.basename(args.train_set).split('.')[:-1]),
            model.name)
        model.set_save_path(save_path)  # avoid overwriting

        # Save the config file as a yaml file
        save_config(vars(args), model.save_path)

        # Save the dictionary & wp_model
        shutil.copy(args.dict, os.path.join(save_path, 'dict.txt'))
        if args.label_type == 'wordpiece':
            shutil.copy(args.wp_model, os.path.join(save_path, 'wp.model'))

        # Setting for logging
        logger = set_logger(os.path.join(model.save_path, 'train.log'), key='training')

        for k, v in sorted(vars(args).items(), key=lambda x: x[0]):
            logger.info('%s: %s' % (k, str(v)))

        # Count total parameters
        for name in sorted(list(model.num_params_dict.keys())):
            num_params = model.num_params_dict[name]
            logger.info("%s %d" % (name, num_params))
        logger.info("Total %.2f M parameters" % (model.total_parameters / 1000000))

        # Set optimizer
        model.set_optimizer(optimizer=args.optimizer,
                            learning_rate_init=float(args.learning_rate),
                            weight_decay=float(args.weight_decay),
                            clip_grad_norm=args.clip_grad_norm,
                            lr_schedule=False,
                            factor=args.decay_rate,
                            patience_epoch=args.decay_patient_epoch)

        epoch, step = 1, 0
        learning_rate = float(args.learning_rate)
        metric_dev_best = 10000
    else:
        raise NotImplementedError()

    train_set.epoch = epoch - 1

    # GPU setting
    if args.ngpus >= 1:
        model = CustomDataParallel(model,
                                   device_ids=list(range(0, args.ngpus, 1)),
                                   deterministic=True,
                                   benchmark=False)
        model.cuda()

    logger.info('PID: %s' % os.getpid())
    logger.info('USERNAME: %s' % os.uname()[1])

    # Set process name
    # setproctitle(args.job_name)

    # Set learning rate controller
    lr_controller = Controller(learning_rate_init=learning_rate,
                               decay_type=args.decay_type,
                               decay_start_epoch=args.decay_start_epoch,
                               decay_rate=args.decay_rate,
                               decay_patient_epoch=args.decay_patient_epoch,
                               lower_better=True,
                               best_value=metric_dev_best)

    # Set reporter
    reporter = Reporter(model.module.save_path, max_loss=10)

    # Set the updater
    updater = Updater(args.clip_grad_norm)

    # Setting for tensorboard
    tf_writer = SummaryWriter(model.module.save_path)

    start_time_train = time.time()
    start_time_epoch = time.time()
    start_time_step = time.time()
    not_improved_epoch = 0
    loss_train_mean, acc_train_mean = 0., 0.
    pbar_epoch = tqdm(total=len(train_set))
    pbar_all = tqdm(total=len(train_set) * args.num_epochs)
    while True:
        # Compute loss in the training set (including parameter update)
        ys_train, is_new_epoch = train_set.next()
        model, loss_train, acc_train = updater(model, ys_train, args.bptt)
        loss_train_mean += loss_train
        acc_train_mean += acc_train
        pbar_epoch.update(np.sum([len(y) for y in ys_train]))

        if (step + 1) % args.print_step == 0:
            # Compute loss in the dev set
            ys_dev = dev_set.next()[0]
            model, loss_dev, acc_dev = updater(model, ys_dev, args.bptt, is_eval=True)

            loss_train_mean /= args.print_step
            acc_train_mean /= args.print_step
            reporter.step(step, loss_train_mean, loss_dev, acc_train_mean, acc_dev)

            # Logging by tensorboard
            tf_writer.add_scalar('train/loss', loss_train_mean, step + 1)
            tf_writer.add_scalar('dev/loss', loss_dev, step + 1)
            for n, p in model.module.named_parameters():
                n = n.replace('.', '/')
                if p.grad is not None:
                    tf_writer.add_histogram(n, p.data.cpu().numpy(), step + 1)
                    tf_writer.add_histogram(n + '/grad', p.grad.data.cpu().numpy(), step + 1)

            duration_step = time.time() - start_time_step
            logger.info(
                "...Step:%d(ep:%.2f) loss:%.2f(%.2f)/acc:%.2f(%.2f)/ppl:%.2f(%.2f)/lr:%.5f/bs:%d (%.2f min)" %
                (step + 1, train_set.epoch_detail,
                 loss_train_mean, loss_dev, acc_train_mean, acc_dev,
                 math.exp(loss_train_mean), math.exp(loss_dev),
                 learning_rate, len(ys_train), duration_step / 60))
            start_time_step = time.time()
            loss_train_mean, acc_train_mean = 0., 0.
        step += args.ngpus

        # Save checkpoint and evaluate model per epoch
        if is_new_epoch:
            duration_epoch = time.time() - start_time_epoch
            logger.info('===== EPOCH:%d (%.2f min) =====' % (epoch, duration_epoch / 60))

            # Save figures of loss and accuracy
            reporter.epoch()

            if epoch < args.eval_start_epoch:
                # Save the model
                model.module.save_checkpoint(model.module.save_path, epoch, step,
                                             learning_rate, metric_dev_best)
            else:
                start_time_eval = time.time()
                # dev
                ppl_dev = eval_ppl([model.module], dev_set, args.bptt)
                logger.info(' PPL (%s): %.3f' % (dev_set.set, ppl_dev))

                if ppl_dev < metric_dev_best:
                    metric_dev_best = ppl_dev
                    not_improved_epoch = 0
                    logger.info('||||| Best Score |||||')

                    # Update learning rate
                    model.module.optimizer, learning_rate = lr_controller.decay_lr(
                        optimizer=model.module.optimizer,
                        learning_rate=learning_rate,
                        epoch=epoch,
                        value=ppl_dev)

                    # Save the model
                    model.module.save_checkpoint(model.module.save_path, epoch, step,
                                                 learning_rate, metric_dev_best)

                    # test
                    ppl_test_mean = 0.
                    for eval_set in eval_sets:
                        ppl_test = eval_ppl([model.module], eval_set, args.bptt)
                        logger.info(' PPL (%s): %.3f' % (eval_set.set, ppl_test))
                        ppl_test_mean += ppl_test
                    if len(eval_sets) > 0:
                        logger.info(' PPL (mean): %.3f' % (ppl_test_mean / len(eval_sets)))
                else:
                    # Update learning rate
                    model.module.optimizer, learning_rate = lr_controller.decay_lr(
                        optimizer=model.module.optimizer,
                        learning_rate=learning_rate,
                        epoch=epoch,
                        value=ppl_dev)
                    not_improved_epoch += 1

                duration_eval = time.time() - start_time_eval
                logger.info('Evaluation time: %.2f min' % (duration_eval / 60))

                # Early stopping
                if not_improved_epoch == args.not_improved_patient_epoch:
                    break

                if epoch == args.convert_to_sgd_epoch:
                    # Convert to fine-tuning stage
                    model.module.set_optimizer(
                        'sgd',
                        learning_rate_init=float(args.learning_rate),  # TODO: ?
                        weight_decay=float(args.weight_decay),
                        clip_grad_norm=args.clip_grad_norm,
                        lr_schedule=False,
                        factor=args.decay_rate,
                        patience_epoch=args.decay_patient_epoch)
                    logger.info('========== Convert to SGD ==========')

            pbar_epoch = tqdm(total=len(train_set))
            pbar_all.update(len(train_set))

            if epoch == args.num_epochs:
                break

            start_time_step = time.time()
            start_time_epoch = time.time()
            epoch += 1

    duration_train = time.time() - start_time_train
    logger.info('Total time: %.2f hour' % (duration_train / 3600))

    tf_writer.close()
    pbar_epoch.close()
    pbar_all.close()

    return model.module.save_path
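# NOTE: `eval_ppl` above is the repository's perplexity evaluation. As a hedged reference
# only, the relationship between the logged cross-entropy loss and the reported perplexity
# is sketched below (the `_sketch` name is hypothetical), matching math.exp(loss) in the
# step logging above.
import math


def perplexity_sketch(total_xent_loss, num_tokens):
    """Perplexity = exp(mean per-token cross-entropy in nats)."""
    return math.exp(total_xent_loss / max(1, num_tokens))

# Example: a mean loss of ~4.6 nats per token corresponds to a perplexity of ~100.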
def main():

    # Load a config file
    config = load_config(os.path.join(args.model, 'config.yml'))
    decode_params = vars(args)

    # Merge config with args
    for k, v in config.items():
        if not hasattr(args, k):
            setattr(args, k, v)

    # Setting for logging
    logger = set_logger(os.path.join(args.model, 'decode.log'), key='decoding')

    wer_mean, cer_mean, per_mean = 0, 0, 0
    for i, set in enumerate(args.eval_sets):
        # Load dataset
        eval_set = Dataset(csv_path=set,
                           dict_path=os.path.join(args.model, 'dict.txt'),
                           dict_path_sub=os.path.join(args.model, 'dict_sub.txt')
                           if os.path.isfile(os.path.join(args.model, 'dict_sub.txt')) else None,
                           label_type=args.label_type,
                           batch_size=args.batch_size,
                           max_epoch=args.num_epochs,
                           is_test=True)

        if i == 0:
            args.num_classes = eval_set.num_classes
            args.input_dim = eval_set.input_dim
            args.num_classes_sub = eval_set.num_classes_sub

            # For cold fusion
            # if args.rnnlm_cf:
            #     # Load a RNNLM config file
            #     config['rnnlm_config'] = load_config(os.path.join(args.model, 'config_rnnlm.yml'))
            #
            #     assert args.label_type == config['rnnlm_config']['label_type']
            #     rnnlm_args.num_classes = eval_set.num_classes
            #     logger.info('RNNLM path: %s' % config['rnnlm'])
            #     logger.info('RNNLM weight: %.3f' % args.rnnlm_weight)
            # else:
            #     pass
            args.rnnlm_cf = None
            args.rnnlm_init = None

            # Load the ASR model
            model = Seq2seq(args)

            # Restore the saved parameters
            epoch, _, _, _ = model.load_checkpoint(args.model, epoch=args.epoch)
            model.save_path = args.model

            # For shallow fusion
            if args.rnnlm_cf is None and args.rnnlm is not None and args.rnnlm_weight > 0:
                # Load a RNNLM config file
                config_rnnlm = load_config(os.path.join(args.rnnlm, 'config.yml'))

                # Merge config with args
                args_rnnlm = argparse.Namespace()
                for k, v in config_rnnlm.items():
                    setattr(args_rnnlm, k, v)

                assert args.label_type == args_rnnlm.label_type
                args_rnnlm.num_classes = eval_set.num_classes

                # Load the pre-trained RNNLM
                rnnlm = RNNLM(args_rnnlm)
                rnnlm.load_checkpoint(args.rnnlm, epoch=-1)
                if args_rnnlm.backward:
                    model.rnnlm_bwd_0 = rnnlm
                else:
                    model.rnnlm_fwd_0 = rnnlm

                logger.info('RNNLM path: %s' % args.rnnlm)
                logger.info('RNNLM weight: %.3f' % args.rnnlm_weight)
                logger.info('RNNLM backward: %s' % str(config_rnnlm['backward']))

            # GPU setting
            model.set_cuda(deterministic=False, benchmark=True)

            logger.info('beam width: %d' % args.beam_width)
            logger.info('length penalty: %.3f' % args.length_penalty)
            logger.info('coverage penalty: %.3f' % args.coverage_penalty)
            logger.info('coverage threshold: %.3f' % args.coverage_threshold)
            logger.info('epoch: %d' % (epoch - 1))

        start_time = time.time()

        if args.label_type == 'word':
            wer, _, _, _, decode_dir = eval_word([model], eval_set, decode_params,
                                                 epoch=epoch - 1, progressbar=True)
            wer_mean += wer
            logger.info(' WER (%s): %.3f %%' % (eval_set.set, wer))
        elif args.label_type == 'wordpiece':
            wer, _, _, _, decode_dir = eval_wordpiece([model], eval_set, decode_params,
                                                      os.path.join(args.model, 'wp.model'),
                                                      epoch=epoch - 1, progressbar=True)
            wer_mean += wer
            logger.info(' WER (%s): %.3f %%' % (eval_set.set, wer))
        elif 'char' in args.label_type:
            (wer, _, _, _), (cer, _, _, _), decode_dir = eval_char([model], eval_set, decode_params,
                                                                   epoch=epoch - 1, progressbar=True)
            wer_mean += wer
            cer_mean += cer
            logger.info(' WER / CER (%s): %.3f / %.3f %%' % (eval_set.set, wer, cer))
        elif 'phone' in args.label_type:
            per, _, _, _, decode_dir = eval_phone([model], eval_set, decode_params,
                                                  epoch=epoch - 1, progressbar=True)
            per_mean += per
            logger.info(' PER (%s): %.3f %%' % (eval_set.set, per))
        else:
            raise ValueError(args.label_type)

        logger.info('Elapsed time: %.2f [sec.]' % (time.time() - start_time))

    if args.label_type == 'word':
        logger.info(' WER (mean): %.3f %%\n' % (wer_mean / len(args.eval_sets)))
    elif args.label_type == 'wordpiece':
        logger.info(' WER (mean): %.3f %%\n' % (wer_mean / len(args.eval_sets)))
    elif 'char' in args.label_type:
        logger.info(' WER / CER (mean): %.3f / %.3f %%\n' %
                    (wer_mean / len(args.eval_sets), cer_mean / len(args.eval_sets)))
    elif 'phone' in args.label_type:
        logger.info(' PER (mean): %.3f %%\n' % (per_mean / len(args.eval_sets)))

    print(decode_dir)
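# NOTE: When an external RNNLM is attached for shallow fusion above, decoding interpolates
# the ASR and LM scores during beam search. As a hedged reference only, the score
# combination is sketched below; the `_sketch` name is hypothetical, and the repository's
# decoder applies this per hypothesis and per output step rather than as a standalone call.
import numpy as np


def shallow_fusion_scores_sketch(asr_log_probs, lm_log_probs, rnnlm_weight):
    """Combine per-token scores: log p_asr(y | x) + rnnlm_weight * log p_lm(y)."""
    return np.asarray(asr_log_probs) + rnnlm_weight * np.asarray(lm_log_probs)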