def main():
    """Decode every evaluation set and plot attention weights per utterance.

    Uses the module-level ``args`` namespace. Options missing from ``args``
    are filled in from ``<args.model>/config.yml``. The ASR model (and, when
    ``args.rnnlm`` is set and ``args.rnnlm_weight > 0``, an RNNLM used for
    shallow fusion) is loaded while processing the first evaluation set.
    For each decoded utterance this:

    * writes an attention plot to
      ``<args.plot_dir>/att_weights/<speaker>/<utt_id>.png``;
    * logs the utterance id, reference and hypothesis through the logger
      created by ``set_logger`` for ``<args.plot_dir>/plot.log``.

    Raises:
        NotImplementedError: if ``args.unit`` is not one of ``'word'``,
            ``'wp'``, ``'char'`` or ``'phone'``.
    """
    # Load a config file
    config = load_config(os.path.join(args.model, 'config.yml'))

    # NOTE: vars() returns the live __dict__ of args, so the attributes merged
    # below are also visible through decode_params.
    decode_params = vars(args)

    # Merge config with args (values already present in args take precedence)
    for k, v in config.items():
        if not hasattr(args, k):
            setattr(args, k, v)

    # Setting for logging
    logger = set_logger(os.path.join(args.plot_dir, 'plot.log'), key='decoding')

    for i, eval_set_path in enumerate(args.eval_sets):
        # Load dataset
        dict_sub_path = os.path.join(args.model, 'dict_sub.txt')
        eval_set = Dataset(csv_path=eval_set_path,
                           dict_path=os.path.join(args.model, 'dict.txt'),
                           dict_path_sub=dict_sub_path if os.path.isfile(dict_sub_path) else None,
                           wp_model=os.path.join(args.model, 'wp.model'),
                           unit=args.unit,
                           batch_size=args.batch_size,
                           max_num_frames=args.max_num_frames,
                           min_num_frames=args.min_num_frames,
                           is_test=True)

        if i == 0:
            args.vocab = eval_set.vocab
            args.vocab_sub = eval_set.vocab_sub
            args.input_dim = eval_set.input_dim

            # TODO(hirofumi): For cold fusion
            args.rnnlm_cold_fusion = None
            args.rnnlm_init = None

            # Load the ASR model
            model = Seq2seq(args)
            epoch, _, _, _ = model.load_checkpoint(args.model, epoch=args.epoch)

            model.save_path = args.model

            # For shallow fusion
            if args.rnnlm_cold_fusion is None and args.rnnlm is not None and args.rnnlm_weight > 0:
                # Load a RNNLM config file
                config_rnnlm = load_config(os.path.join(args.rnnlm, 'config.yml'))

                # Build the RNNLM arguments purely from its own config
                args_rnnlm = argparse.Namespace()
                for k, v in config_rnnlm.items():
                    setattr(args_rnnlm, k, v)

                assert args.unit == args_rnnlm.unit
                args_rnnlm.vocab = eval_set.vocab

                # Load the pre-trained RNNLM
                rnnlm = RNNLM(args_rnnlm)
                rnnlm.load_checkpoint(args.rnnlm, epoch=-1)
                if args_rnnlm.backward:
                    model.rnnlm_bwd_0 = rnnlm
                else:
                    model.rnnlm_fwd_0 = rnnlm

                logger.info('RNNLM path: %s' % args.rnnlm)
                logger.info('RNNLM weight: %.3f' % args.rnnlm_weight)
                logger.info('RNNLM backward: %s' % str(config_rnnlm['backward']))

            # GPU setting
            model.cuda()

            logger.info('beam width: %d' % args.beam_width)
            logger.info('length penalty: %.3f' % args.length_penalty)
            logger.info('coverage penalty: %.3f' % args.coverage_penalty)
            logger.info('coverage threshold: %.3f' % args.coverage_threshold)
            logger.info('epoch: %d' % (epoch - 1))

        save_path = mkdir_join(args.plot_dir, 'att_weights')

        # Clean directory only before the first evaluation set; doing it for
        # every set would delete the plots of all previous sets.
        if i == 0 and save_path is not None and os.path.isdir(save_path):
            shutil.rmtree(save_path)
            os.mkdir(save_path)

        while True:
            batch, is_new_epoch = eval_set.next(decode_params['batch_size'])
            best_hyps, aws, perm_idx = model.decode(batch['xs'], decode_params, exclude_eos=False)

            # best_hyps/aws follow perm_idx order (that is why ys is re-ordered
            # with it), so every per-utterance batch entry paired with them must
            # be re-ordered the same way. A separate loop variable is used so the
            # outer index i is not clobbered (Python 2 comprehensions leak it).
            ys = [batch['ys'][j] for j in perm_idx]
            xs = [batch['xs'][j] for j in perm_idx]
            utt_ids = [batch['utt_ids'][j] for j in perm_idx]

            if model.bwd_weight > 0.5:
                # Reverse the order
                best_hyps = [hyp[::-1] for hyp in best_hyps]
                aws = [aw[::-1] for aw in aws]

            for b in range(len(xs)):
                # Map token ids to a list of token strings for the model unit
                if args.unit == 'word':
                    token_list = eval_set.idx2word(best_hyps[b], return_list=True)
                elif args.unit == 'wp':
                    token_list = eval_set.idx2wp(best_hyps[b], return_list=True)
                elif args.unit == 'char':
                    token_list = eval_set.idx2char(best_hyps[b], return_list=True)
                elif args.unit == 'phone':
                    token_list = eval_set.idx2phone(best_hyps[b], return_list=True)
                else:
                    raise NotImplementedError(args.unit)
                token_list = [unicode(t, 'utf-8') for t in token_list]

                # Speaker id: the utterance id without its last two fields
                # ('-' is treated as a '_' separator)
                speaker = '_'.join(utt_ids[b].replace('-', '_').split('_')[:-2])

                # error check
                assert len(xs[b]) <= 2000

                plot_attention_weights(
                    aws[b][:len(token_list)],
                    label_list=token_list,
                    spectrogram=xs[b][:, :eval_set.input_dim] if args.input_type == 'speech' else None,
                    save_path=mkdir_join(save_path, speaker, utt_ids[b] + '.png'),
                    figsize=(20, 8))

                ref = ys[b]
                # NOTE(review): for backward decoding the tokens are reversed
                # once above and again here; confirm the logged order is the
                # intended one.
                if model.bwd_weight > 0.5:
                    hyp = ' '.join(token_list[::-1])
                else:
                    hyp = ' '.join(token_list)

                logger.info('utt-id: %s' % utt_ids[b])
                logger.info('Ref: %s' % ref.lower())
                logger.info('Hyp: %s' % hyp)
                logger.info('-' * 50)

            if is_new_epoch:
                break
def main():
    """Decode every evaluation set and plot CTC posteriors per utterance.

    Uses the module-level ``args`` namespace. Options missing from ``args``
    are filled in from ``<args.model>/config.yml``. Any existing
    ``<args.plot_dir>/plot.log`` is removed before logging starts. The ASR
    model is loaded while processing the first evaluation set. For each
    utterance this:

    * plots the top-k CTC posteriors to
      ``<args.plot_dir>/att_weights/<speaker>/<utt_id>.png``;
    * logs the utterance id, reference and hypothesis.

    Raises:
        NotImplementedError: if ``args.unit`` is not one of ``'word'``,
            ``'wp'``, ``'char'`` or ``'phone'``.
    """
    # Load a config file
    config = load_config(os.path.join(args.model, 'config.yml'))

    # NOTE: vars() returns the live __dict__ of args, so the attributes merged
    # below are also visible through decode_params.
    decode_params = vars(args)

    # Merge config with args (values already present in args take precedence)
    for k, v in config.items():
        if not hasattr(args, k):
            setattr(args, k, v)

    # Setting for logging (start from an empty log file)
    log_path = os.path.join(args.plot_dir, 'plot.log')
    if os.path.isfile(log_path):
        os.remove(log_path)
    logger = set_logger(log_path, key='decoding')

    # Overall frame-subsampling factor passed to plot_ctc_probs: the product of
    # the first size of every conv pooling entry (presumably the time axis of
    # entries such as "(2,2)" — confirm) and of all '_'-separated factors in
    # args.subsample. It depends only on args, so compute it once.
    subsample = [int(s) for s in args.subsample.split('_')]
    subsample_factor = 1
    if args.conv_poolings:
        for pooling in args.conv_poolings.split('_'):
            p = int(pooling.split(',')[0].replace('(', ''))
            if p > 1:
                subsample_factor *= p
    subsample_factor *= np.prod(subsample)

    for i, eval_set_path in enumerate(args.eval_sets):
        # Load dataset
        dict_sub1_path = os.path.join(args.model, 'dict_sub.txt')
        dataset = Dataset(
            csv_path=eval_set_path,
            dict_path=os.path.join(args.model, 'dict.txt'),
            dict_path_sub1=dict_sub1_path if os.path.isfile(dict_sub1_path) else None,
            wp_model=os.path.join(args.model, 'wp.model'),
            unit=args.unit,
            unit_sub1=args.unit_sub1,
            batch_size=args.batch_size,
            is_test=True)

        if i == 0:
            args.vocab = dataset.vocab
            args.vocab_sub1 = dataset.vocab_sub1
            args.input_dim = dataset.input_dim

            # TODO(hirofumi): For cold fusion
            args.rnnlm_cold_fusion = None
            args.rnnlm_init = None

            # Load the ASR model
            model = Seq2seq(args)
            epoch, _, _, _ = model.load_checkpoint(args.model, epoch=args.epoch)

            model.save_path = args.model

            # GPU setting
            model.cuda()

            logger.info('epoch: %d' % (epoch - 1))

        # NOTE(review): CTC plots share the 'att_weights' directory name with
        # the attention-plot script; kept unchanged for compatibility.
        save_path = mkdir_join(args.plot_dir, 'att_weights')

        # Clean directory only before the first evaluation set; doing it for
        # every set would delete the plots of all previous sets.
        if i == 0 and save_path is not None and os.path.isdir(save_path):
            shutil.rmtree(save_path)
            os.mkdir(save_path)

        while True:
            batch, is_new_epoch = dataset.next(decode_params['batch_size'])
            best_hyps, _, perm_idx = model.decode(batch['xs'], decode_params, exclude_eos=False)
            # best_hyps follows perm_idx order; re-order the references to match.
            # A separate loop variable keeps the outer index i intact (Python 2
            # comprehensions leak their variable).
            ys = [batch['ys'][j] for j in perm_idx]

            # Get CTC probs
            ctc_probs, indices_topk, x_lens = model.get_ctc_posteriors(
                batch['xs'], temperature=1, topk=min(100, model.vocab))
            # NOTE: ctc_probs: '[B, T, topk]'

            for b in range(len(batch['xs'])):
                # Map token ids to a list of token strings for the model unit
                if args.unit == 'word':
                    token_list = dataset.idx2word(best_hyps[b], return_list=True)
                elif args.unit == 'wp':
                    token_list = dataset.idx2wp(best_hyps[b], return_list=True)
                elif args.unit == 'char':
                    token_list = dataset.idx2char(best_hyps[b], return_list=True)
                elif args.unit == 'phone':
                    token_list = dataset.idx2phone(best_hyps[b], return_list=True)
                else:
                    raise NotImplementedError(args.unit)
                token_list = [unicode(t, 'utf-8') for t in token_list]

                # Speaker id: the utterance id without its last two fields
                # ('-' is treated as a '_' separator)
                speaker = '_'.join(batch['utt_ids'][b].replace(
                    '-', '_').split('_')[:-2])

                plot_ctc_probs(
                    ctc_probs[b, :x_lens[b]],
                    indices_topk[b],
                    nframes=x_lens[b],
                    subsample_factor=subsample_factor,
                    spectrogram=batch['xs'][b][:, :dataset.input_dim],
                    save_path=mkdir_join(save_path, speaker, batch['utt_ids'][b] + '.png'),
                    figsize=(20, 8))

                # ys[b] and best_hyps[b] belong to batch entry perm_idx[b], so
                # log the utterance id of that entry next to them.
                ref = ys[b]
                hyp = ' '.join(token_list)

                logger.info('utt-id: %s' % batch['utt_ids'][perm_idx[b]])
                logger.info('Ref: %s' % ref.lower())
                logger.info('Hyp: %s' % hyp)
                logger.info('-' * 50)

            if is_new_epoch:
                break