def recog(args): """Decode with the given args. Args: args (namespace): The program arguments. """ set_deterministic_pytorch(args) model, train_args = load_trained_model(args.model) assert isinstance(model, ASRInterface) model.recog_args = args # read rnnlm if args.rnnlm: rnnlm_args = get_model_conf(args.rnnlm, args.rnnlm_conf) if getattr(rnnlm_args, "model_module", "default") != "default": raise ValueError( "use '--api v2' option to decode with non-default language model" ) rnnlm = lm_pytorch.ClassifierWithState( lm_pytorch.RNNLM( len(train_args.char_list), rnnlm_args.layer, rnnlm_args.unit, getattr(rnnlm_args, "embed_unit", None), # for backward compatibility ) ) torch_load(args.rnnlm, rnnlm) rnnlm.eval() else: rnnlm = None if args.word_rnnlm: rnnlm_args = get_model_conf(args.word_rnnlm, args.word_rnnlm_conf) word_dict = rnnlm_args.char_list_dict char_dict = {x: i for i, x in enumerate(train_args.char_list)} word_rnnlm = lm_pytorch.ClassifierWithState( lm_pytorch.RNNLM(len(word_dict), rnnlm_args.layer, rnnlm_args.unit) ) torch_load(args.word_rnnlm, word_rnnlm) word_rnnlm.eval() if rnnlm is not None: rnnlm = lm_pytorch.ClassifierWithState( extlm_pytorch.MultiLevelLM( word_rnnlm.predictor, rnnlm.predictor, word_dict, char_dict ) ) else: rnnlm = lm_pytorch.ClassifierWithState( extlm_pytorch.LookAheadWordLM( word_rnnlm.predictor, word_dict, char_dict ) ) # gpu if args.ngpu == 1: gpu_id = list(range(args.ngpu)) logging.info("gpu id: " + str(gpu_id)) model.cuda() if rnnlm: rnnlm.cuda() # read json data with open(args.recog_json, "rb") as f: js = json.load(f)["utts"] new_js = {} load_inputs_and_targets = LoadInputsAndTargets( mode="asr", load_output=False, sort_in_input_length=False, preprocess_conf=train_args.preprocess_conf if args.preprocess_conf is None else args.preprocess_conf, preprocess_args={"train": False}, ) if args.batchsize == 0: with torch.no_grad(): for idx, name in enumerate(js.keys(), 1): logging.info("(%d/%d) decoding " + name, idx, len(js.keys())) batch = [(name, js[name])] feat = load_inputs_and_targets(batch)[0][0] nbest_hyps = model.recognize(feat, args, train_args.char_list, rnnlm) new_js[name] = add_results_to_json( js[name], nbest_hyps, train_args.char_list ) else: def grouper(n, iterable, fillvalue=None): kargs = [iter(iterable)] * n return zip_longest(*kargs, fillvalue=fillvalue) # sort data if batchsize > 1 keys = list(js.keys()) if args.batchsize > 1: feat_lens = [js[key]["input"][0]["shape"][0] for key in keys] sorted_index = sorted(range(len(feat_lens)), key=lambda i: -feat_lens[i]) keys = [keys[i] for i in sorted_index] with torch.no_grad(): for names in grouper(args.batchsize, keys, None): names = [name for name in names if name] batch = [(name, js[name]) for name in names] feats = load_inputs_and_targets(batch)[0] nbest_hyps = model.recognize_batch( feats, args, train_args.char_list, rnnlm=rnnlm ) for i, name in enumerate(names): nbest_hyp = [hyp[i] for hyp in nbest_hyps] new_js[name] = add_results_to_json( js[name], nbest_hyp, train_args.char_list ) with open(args.result_label, "wb") as f: f.write( json.dumps( {"utts": new_js}, indent=4, ensure_ascii=False, sort_keys=True ).encode("utf_8") )
def __init__(self, subsampling_factor=1, preprocess_conf=None):
    self.subsampling_factor = subsampling_factor
    self.load_inputs_and_targets = LoadInputsAndTargets(
        mode='asr', load_output=True, preprocess_conf=preprocess_conf)
    self.ignore_id = -1
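# --- Minimal sketch (an assumption, not the converter's actual __call__)
# --- showing why ignore_id = -1 is stored: variable-length targets are padded
# --- with -1 so the loss computation can mask the padded positions.
import torch

ys = [torch.tensor([3, 7, 9]), torch.tensor([5, 2])]  # hypothetical targets
ys_pad = torch.nn.utils.rnn.pad_sequence(ys, batch_first=True, padding_value=-1)
print(ys_pad)  # tensor([[ 3,  7,  9], [ 5,  2, -1]])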
def decode(args): """Decode with E2E VC model.""" set_deterministic_pytorch(args) # read training config idim, odim, train_args = get_model_conf(args.model, args.model_conf) # show arguments for key in sorted(vars(args).keys()): logging.info("args: " + key + ": " + str(vars(args)[key])) # define model model_class = dynamic_import(train_args.model_module) model = model_class(idim, odim, train_args) assert isinstance(model, TTSInterface) logging.info(model) # load trained model parameters logging.info("reading model parameters from " + args.model) torch_load(args.model, model) model.eval() # set torch device device = torch.device("cuda" if args.ngpu > 0 else "cpu") model = model.to(device) # read json data with open(args.json, "rb") as f: js = json.load(f)["utts"] # check directory outdir = os.path.dirname(args.out) if len(outdir) != 0 and not os.path.exists(outdir): os.makedirs(outdir) load_inputs_and_targets = LoadInputsAndTargets( mode="vc", load_output=False, sort_in_input_length=False, use_speaker_embedding=train_args.use_speaker_embedding, preprocess_conf=train_args.preprocess_conf if args.preprocess_conf is None else args.preprocess_conf, preprocess_args={"train": False}, # Switch the mode of preprocessing ) # define function for plot prob and att_ws def _plot_and_save(array, figname, figsize=(6, 4), dpi=150): import matplotlib matplotlib.use("Agg") import matplotlib.pyplot as plt shape = array.shape if len(shape) == 1: # for eos probability plt.figure(figsize=figsize, dpi=dpi) plt.plot(array) plt.xlabel("Frame") plt.ylabel("Probability") plt.ylim([0, 1]) elif len(shape) == 2: # for tacotron 2 attention weights, whose shape is (out_length, in_length) plt.figure(figsize=figsize, dpi=dpi) plt.imshow(array, aspect="auto") plt.xlabel("Input") plt.ylabel("Output") elif len(shape) == 4: # for transformer attention weights, # whose shape is (#leyers, #heads, out_length, in_length) plt.figure(figsize=(figsize[0] * shape[0], figsize[1] * shape[1]), dpi=dpi) for idx1, xs in enumerate(array): for idx2, x in enumerate(xs, 1): plt.subplot(shape[0], shape[1], idx1 * shape[1] + idx2) plt.imshow(x, aspect="auto") plt.xlabel("Input") plt.ylabel("Output") else: raise NotImplementedError("Support only from 1D to 4D array.") plt.tight_layout() if not os.path.exists(os.path.dirname(figname)): # NOTE: exist_ok = True is needed for parallel process decoding os.makedirs(os.path.dirname(figname), exist_ok=True) plt.savefig(figname) plt.close() # define function to calculate focus rate # (see section 3.3 in https://arxiv.org/abs/1905.09263) def _calculate_focus_rete(att_ws): if att_ws is None: # fastspeech case -> None return 1.0 elif len(att_ws.shape) == 2: # tacotron 2 case -> (L, T) return float(att_ws.max(dim=-1)[0].mean()) elif len(att_ws.shape) == 4: # transformer case -> (#layers, #heads, L, T) return float(att_ws.max(dim=-1)[0].mean(dim=-1).max()) else: raise ValueError("att_ws should be 2 or 4 dimensional tensor.") # define function to convert attention to duration def _convert_att_to_duration(att_ws): if len(att_ws.shape) == 2: # tacotron 2 case -> (L, T) pass elif len(att_ws.shape) == 4: # transformer case -> (#layers, #heads, L, T) # get the most diagonal head according to focus rate att_ws = torch.cat([att_w for att_w in att_ws], dim=0) # (#heads * #layers, L, T) diagonal_scores = att_ws.max(dim=-1)[0].mean( dim=-1) # (#heads * #layers,) diagonal_head_idx = diagonal_scores.argmax() att_ws = att_ws[diagonal_head_idx] # (L, T) else: raise ValueError("att_ws should be 2 or 4 dimensional tensor.") # 
calculate duration from 2d attention weight durations = torch.stack( [att_ws.argmax(-1).eq(i).sum() for i in range(att_ws.shape[1])]) return durations.view(-1, 1).float() # define writer instances feat_writer = kaldiio.WriteHelper( "ark,scp:{o}.ark,{o}.scp".format(o=args.out)) if args.save_durations: dur_writer = kaldiio.WriteHelper("ark,scp:{o}.ark,{o}.scp".format( o=args.out.replace("feats", "durations"))) if args.save_focus_rates: fr_writer = kaldiio.WriteHelper("ark,scp:{o}.ark,{o}.scp".format( o=args.out.replace("feats", "focus_rates"))) # start decoding for idx, utt_id in enumerate(js.keys()): # setup inputs batch = [(utt_id, js[utt_id])] data = load_inputs_and_targets(batch) x = torch.FloatTensor(data[0][0]).to(device) spemb = None if train_args.use_speaker_embedding: spemb = torch.FloatTensor(data[1][0]).to(device) # decode and write start_time = time.time() outs, probs, att_ws = model.inference(x, args, spemb=spemb) logging.info("inference speed = %.1f frames / sec." % (int(outs.size(0)) / (time.time() - start_time))) if outs.size(0) == x.size(0) * args.maxlenratio: logging.warning("output length reaches maximum length (%s)." % utt_id) focus_rate = _calculate_focus_rete(att_ws) logging.info("(%d/%d) %s (size: %d->%d, focus rate: %.3f)" % (idx + 1, len(js.keys()), utt_id, x.size(0), outs.size(0), focus_rate)) feat_writer[utt_id] = outs.cpu().numpy() if args.save_durations: ds = _convert_att_to_duration(att_ws) dur_writer[utt_id] = ds.cpu().numpy() if args.save_focus_rates: fr_writer[utt_id] = np.array(focus_rate).reshape(1, 1) # plot and save prob and att_ws if probs is not None: _plot_and_save( probs.cpu().numpy(), os.path.dirname(args.out) + "/probs/%s_prob.png" % utt_id, ) if att_ws is not None: _plot_and_save( att_ws.cpu().numpy(), os.path.dirname(args.out) + "/att_ws/%s_att_ws.png" % utt_id, ) # close file object feat_writer.close() if args.save_durations: dur_writer.close() if args.save_focus_rates: fr_writer.close()
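# --- Toy illustration of the attention-to-duration conversion defined above:
# --- for each input token i, count how many output frames attend most to i.
import torch

att_ws = torch.tensor([  # (L=4 output frames, T=3 input tokens)
    [0.9, 0.05, 0.05],
    [0.8, 0.10, 0.10],
    [0.2, 0.70, 0.10],
    [0.1, 0.20, 0.70],
])
durations = torch.stack(
    [att_ws.argmax(-1).eq(i).sum() for i in range(att_ws.shape[1])])
print(durations)  # tensor([2, 1, 1]) -> token 0 covers 2 frames, the others 1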
def enhance(args): """Dumping enhanced speech and mask. Args: args (namespace): The program arguments. """ set_deterministic_pytorch(args) # read training config idim, odim, train_args = get_model_conf(args.model, args.model_conf) # TODO(ruizhili): implement enhance for multi-encoder model assert args.num_encs == 1, "number of encoder should be 1 ({} is given)".format( args.num_encs ) # load trained model parameters logging.info("reading model parameters from " + args.model) model_class = dynamic_import(train_args.model_module) model = model_class(idim, odim, train_args) assert isinstance(model, ASRInterface) torch_load(args.model, model) model.recog_args = args # gpu if args.ngpu == 1: gpu_id = list(range(args.ngpu)) logging.info("gpu id: " + str(gpu_id)) model.cuda() # read json data with open(args.recog_json, "rb") as f: js = json.load(f)["utts"] load_inputs_and_targets = LoadInputsAndTargets( mode="asr", load_output=False, sort_in_input_length=False, preprocess_conf=None, # Apply pre_process in outer func ) if args.batchsize == 0: args.batchsize = 1 # Creates writers for outputs from the network if args.enh_wspecifier is not None: enh_writer = file_writer_helper(args.enh_wspecifier, filetype=args.enh_filetype) else: enh_writer = None # Creates a Transformation instance preprocess_conf = ( train_args.preprocess_conf if args.preprocess_conf is None else args.preprocess_conf ) if preprocess_conf is not None: logging.info(f"Use preprocessing: {preprocess_conf}") transform = Transformation(preprocess_conf) else: transform = None # Creates a IStft instance istft = None frame_shift = args.istft_n_shift # Used for plot the spectrogram if args.apply_istft: if preprocess_conf is not None: # Read the conffile and find stft setting with open(preprocess_conf) as f: # Json format: e.g. # {"process": [{"type": "stft", # "win_length": 400, # "n_fft": 512, "n_shift": 160, # "window": "han"}, # {"type": "foo", ...}, ...]} conf = json.load(f) assert "process" in conf, conf # Find stft setting for p in conf["process"]: if p["type"] == "stft": istft = IStft( win_length=p["win_length"], n_shift=p["n_shift"], window=p.get("window", "hann"), ) logging.info( "stft is found in {}. " "Setting istft config from it\n{}".format( preprocess_conf, istft ) ) frame_shift = p["n_shift"] break if istft is None: # Set from command line arguments istft = IStft( win_length=args.istft_win_length, n_shift=args.istft_n_shift, window=args.istft_window, ) logging.info( "Setting istft config from the command line args\n{}".format(istft) ) # sort data keys = list(js.keys()) feat_lens = [js[key]["input"][0]["shape"][0] for key in keys] sorted_index = sorted(range(len(feat_lens)), key=lambda i: -feat_lens[i]) keys = [keys[i] for i in sorted_index] def grouper(n, iterable, fillvalue=None): kargs = [iter(iterable)] * n return zip_longest(*kargs, fillvalue=fillvalue) num_images = 0 if not os.path.exists(args.image_dir): os.makedirs(args.image_dir) for names in grouper(args.batchsize, keys, None): batch = [(name, js[name]) for name in names] # May be in time region: (Batch, [Time, Channel]) org_feats = load_inputs_and_targets(batch)[0] if transform is not None: # May be in time-freq region: : (Batch, [Time, Channel, Freq]) feats = transform(org_feats, train=False) else: feats = org_feats with torch.no_grad(): enhanced, mask, ilens = model.enhance(feats) for idx, name in enumerate(names): # Assuming mask, feats : [Batch, Time, Channel. 
Freq] # enhanced : [Batch, Time, Freq] enh = enhanced[idx][: ilens[idx]] mas = mask[idx][: ilens[idx]] feat = feats[idx] # Plot spectrogram if args.image_dir is not None and num_images < args.num_images: import matplotlib.pyplot as plt num_images += 1 ref_ch = 0 plt.figure(figsize=(20, 10)) plt.subplot(4, 1, 1) plt.title("Mask [ref={}ch]".format(ref_ch)) plot_spectrogram( plt, mas[:, ref_ch].T, fs=args.fs, mode="linear", frame_shift=frame_shift, bottom=False, labelbottom=False, ) plt.subplot(4, 1, 2) plt.title("Noisy speech [ref={}ch]".format(ref_ch)) plot_spectrogram( plt, feat[:, ref_ch].T, fs=args.fs, mode="db", frame_shift=frame_shift, bottom=False, labelbottom=False, ) plt.subplot(4, 1, 3) plt.title("Masked speech [ref={}ch]".format(ref_ch)) plot_spectrogram( plt, (feat[:, ref_ch] * mas[:, ref_ch]).T, frame_shift=frame_shift, fs=args.fs, mode="db", bottom=False, labelbottom=False, ) plt.subplot(4, 1, 4) plt.title("Enhanced speech") plot_spectrogram( plt, enh.T, fs=args.fs, mode="db", frame_shift=frame_shift ) plt.savefig(os.path.join(args.image_dir, name + ".png")) plt.clf() # Write enhanced wave files if enh_writer is not None: if istft is not None: enh = istft(enh) else: enh = enh if args.keep_length: if len(org_feats[idx]) < len(enh): # Truncate the frames added by stft padding enh = enh[: len(org_feats[idx])] elif len(org_feats) > len(enh): padwidth = [(0, (len(org_feats[idx]) - len(enh)))] + [ (0, 0) ] * (enh.ndim - 1) enh = np.pad(enh, padwidth, mode="constant") if args.enh_filetype in ("sound", "sound.hdf5"): enh_writer[name] = (args.fs, enh) else: # Hint: To dump stft_signal, mask or etc, # enh_filetype='hdf5' might be convenient. enh_writer[name] = enh if num_images >= args.num_images and enh_writer is None: logging.info("Breaking the process.") break
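# --- Standalone numpy sketch of the keep_length handling above: the ISTFT
# --- output can be longer or shorter than the original waveform, so it is
# --- truncated or zero-padded back to the original number of samples.
import numpy as np

org = np.zeros(16000)  # original waveform
enh = np.zeros(16240)  # ISTFT output, lengthened by STFT padding
if len(org) < len(enh):
    enh = enh[: len(org)]  # truncate the frames added by stft padding
elif len(org) > len(enh):
    padwidth = [(0, len(org) - len(enh))] + [(0, 0)] * (enh.ndim - 1)
    enh = np.pad(enh, padwidth, mode="constant")
print(enh.shape)  # (16000,)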
def recog(args): """Decode with the given args. Args: args (namespace): The program arguments. """ set_deterministic_pytorch(args) model, train_args = load_trained_model(args.model) assert isinstance(model, ASRInterface) model.recog_args = args if args.streaming_mode and "transformer" in train_args.model_module: raise NotImplementedError("streaming mode for transformer is not implemented") logging.info( " Total parameter of the model = " + str(sum(p.numel() for p in model.parameters())) ) # read rnnlm if args.rnnlm: rnnlm_args = get_model_conf(args.rnnlm, args.rnnlm_conf) if getattr(rnnlm_args, "model_module", "default") != "default": raise ValueError( "use '--api v2' option to decode with non-default language model" ) rnnlm = lm_pytorch.ClassifierWithState( lm_pytorch.RNNLM( len(train_args.char_list), rnnlm_args.layer, rnnlm_args.unit, getattr(rnnlm_args, "embed_unit", None), # for backward compatibility ) ) torch_load(args.rnnlm, rnnlm) rnnlm.eval() else: rnnlm = None if args.word_rnnlm: rnnlm_args = get_model_conf(args.word_rnnlm, args.word_rnnlm_conf) word_dict = rnnlm_args.char_list_dict char_dict = {x: i for i, x in enumerate(train_args.char_list)} word_rnnlm = lm_pytorch.ClassifierWithState( lm_pytorch.RNNLM( len(word_dict), rnnlm_args.layer, rnnlm_args.unit, getattr(rnnlm_args, "embed_unit", None), # for backward compatibility ) ) torch_load(args.word_rnnlm, word_rnnlm) word_rnnlm.eval() if rnnlm is not None: rnnlm = lm_pytorch.ClassifierWithState( extlm_pytorch.MultiLevelLM( word_rnnlm.predictor, rnnlm.predictor, word_dict, char_dict ) ) else: rnnlm = lm_pytorch.ClassifierWithState( extlm_pytorch.LookAheadWordLM( word_rnnlm.predictor, word_dict, char_dict ) ) # gpu if args.ngpu == 1: gpu_id = list(range(args.ngpu)) logging.info("gpu id: " + str(gpu_id)) model.cuda() if rnnlm: rnnlm.cuda() # read json data with open(args.recog_json, "rb") as f: js = json.load(f)["utts"] new_js = {} load_inputs_and_targets = LoadInputsAndTargets( mode="asr", load_output=False, sort_in_input_length=False, preprocess_conf=train_args.preprocess_conf if args.preprocess_conf is None else args.preprocess_conf, preprocess_args={"train": False}, ) if args.batchsize == 0: with torch.no_grad(): for idx, name in enumerate(js.keys(), 1): logging.info("(%d/%d) decoding " + name, idx, len(js.keys())) batch = [(name, js[name])] feat = load_inputs_and_targets(batch) feat = ( feat[0][0] if args.num_encs == 1 else [feat[idx][0] for idx in range(model.num_encs)] ) if args.streaming_mode == "window" and args.num_encs == 1: logging.info( "Using streaming recognizer with window size %d frames", args.streaming_window, ) se2e = WindowStreamingE2E(e2e=model, recog_args=args, rnnlm=rnnlm) for i in range(0, feat.shape[0], args.streaming_window): logging.info( "Feeding frames %d - %d", i, i + args.streaming_window ) se2e.accept_input(feat[i : i + args.streaming_window]) logging.info("Running offline attention decoder") se2e.decode_with_attention_offline() logging.info("Offline attention decoder finished") nbest_hyps = se2e.retrieve_recognition() elif args.streaming_mode == "segment" and args.num_encs == 1: logging.info( "Using streaming recognizer with threshold value %d", args.streaming_min_blank_dur, ) nbest_hyps = [] for n in range(args.nbest): nbest_hyps.append({"yseq": [], "score": 0.0}) se2e = SegmentStreamingE2E(e2e=model, recog_args=args, rnnlm=rnnlm) r = np.prod(model.subsample) for i in range(0, feat.shape[0], r): hyps = se2e.accept_input(feat[i : i + r]) if hyps is not None: text = "".join( [ 
train_args.char_list[int(x)] for x in hyps[0]["yseq"][1:-1] if int(x) != -1 ] ) text = text.replace( "\u2581", " " ).strip() # for SentencePiece text = text.replace(model.space, " ") text = text.replace(model.blank, "") logging.info(text) for n in range(args.nbest): nbest_hyps[n]["yseq"].extend(hyps[n]["yseq"]) nbest_hyps[n]["score"] += hyps[n]["score"] else: nbest_hyps = model.recognize( feat, args, train_args.char_list, rnnlm ) new_js[name] = add_results_to_json( js[name], nbest_hyps, train_args.char_list ) else: def grouper(n, iterable, fillvalue=None): kargs = [iter(iterable)] * n return zip_longest(*kargs, fillvalue=fillvalue) # sort data if batchsize > 1 keys = list(js.keys()) if args.batchsize > 1: feat_lens = [js[key]["input"][0]["shape"][0] for key in keys] sorted_index = sorted(range(len(feat_lens)), key=lambda i: -feat_lens[i]) keys = [keys[i] for i in sorted_index] with torch.no_grad(): for names in grouper(args.batchsize, keys, None): names = [name for name in names if name] batch = [(name, js[name]) for name in names] feats = ( load_inputs_and_targets(batch)[0] if args.num_encs == 1 else load_inputs_and_targets(batch) ) if args.streaming_mode == "window" and args.num_encs == 1: raise NotImplementedError elif args.streaming_mode == "segment" and args.num_encs == 1: if args.batchsize > 1: raise NotImplementedError feat = feats[0] nbest_hyps = [] for n in range(args.nbest): nbest_hyps.append({"yseq": [], "score": 0.0}) se2e = SegmentStreamingE2E(e2e=model, recog_args=args, rnnlm=rnnlm) r = np.prod(model.subsample) for i in range(0, feat.shape[0], r): hyps = se2e.accept_input(feat[i : i + r]) if hyps is not None: text = "".join( [ train_args.char_list[int(x)] for x in hyps[0]["yseq"][1:-1] if int(x) != -1 ] ) text = text.replace( "\u2581", " " ).strip() # for SentencePiece text = text.replace(model.space, " ") text = text.replace(model.blank, "") logging.info(text) for n in range(args.nbest): nbest_hyps[n]["yseq"].extend(hyps[n]["yseq"]) nbest_hyps[n]["score"] += hyps[n]["score"] nbest_hyps = [nbest_hyps] else: nbest_hyps = model.recognize_batch( feats, args, train_args.char_list, rnnlm=rnnlm ) for i, nbest_hyp in enumerate(nbest_hyps): name = names[i] new_js[name] = add_results_to_json( js[name], nbest_hyp, train_args.char_list ) with open(args.result_label, "wb") as f: f.write( json.dumps( {"utts": new_js}, indent=4, ensure_ascii=False, sort_keys=True ).encode("utf_8") )
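# --- Illustrative sketch of the hypothesis-to-text post-processing used in the
# --- streaming branches above (char_list, space and blank symbols are made up):
char_list = ["<blank>", "\u2581", "h", "i", "\u2581", "a", "l", "l", "<eos>"]
yseq = [0, 2, 3, 4, 5, 6, 7, 8]  # hypothetical 1-best token id sequence
space, blank = "<space>", "<blank>"
text = "".join(char_list[int(x)] for x in yseq[1:-1] if int(x) != -1)
text = text.replace("\u2581", " ").strip()  # for SentencePiece
text = text.replace(space, " ").replace(blank, "")
print(text)  # hi all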
output_path.mkdir(parents=True, exist_ok=True)
Path("tmp").mkdir(parents=True, exist_ok=True)

for path_wav in data_path.glob("*.wav"):
    output_file = output_path / (path_wav.name.replace(".wav", ".npz"))
    print("Predicting: " + path_wav.name)

    # Compute fbanks features
    with open("tmp/wav.scp", "w+") as f:
        f.write("file " + str(path_wav.resolve()))
    os.system("./fbanks.sh " + cmdargs.cmvn_path)
    print("Finished fbanks")

    load_inputs_and_targets = LoadInputsAndTargets(
        mode='asr', load_output=False, sort_in_input_length=False)

    with torch.no_grad():
        # Load input frames
        data = {
            "input": [{
                "name": "input1",
                "feat": str(Path("tmp/feats.1.ark:5").resolve())
            }]
        }
        full_feat = load_inputs_and_targets([("data", data)])[0][0]

        if cmdargs.split is not None:
            # Split audio in multiple parts and decode each one individually
            all_probs = []
def recog(args): """Decode with the given args :param Namespace args: The program arguments """ # display chainer version logging.info('chainer version = ' + chainer.__version__) set_deterministic_chainer(args) # read training config idim, odim, train_args = get_model_conf(args.model, args.model_conf) for key in sorted(vars(args).keys()): logging.info('ARGS: ' + key + ': ' + str(vars(args)[key])) # specify model architecture logging.info('reading model parameters from ' + args.model) model = E2E(idim, odim, train_args) chainer_load(args.model, model) # read rnnlm if args.rnnlm: rnnlm_args = get_model_conf(args.rnnlm, args.rnnlm_conf) rnnlm = lm_chainer.ClassifierWithState( lm_chainer.RNNLM(len(train_args.char_list), rnnlm_args.layer, rnnlm_args.unit)) chainer_load(args.rnnlm, rnnlm) else: rnnlm = None if args.word_rnnlm: rnnlm_args = get_model_conf(args.word_rnnlm, args.word_rnnlm_conf) word_dict = rnnlm_args.char_list_dict char_dict = {x: i for i, x in enumerate(train_args.char_list)} word_rnnlm = lm_chainer.ClassifierWithState( lm_chainer.RNNLM(len(word_dict), rnnlm_args.layer, rnnlm_args.unit)) chainer_load(args.word_rnnlm, word_rnnlm) if rnnlm is not None: rnnlm = lm_chainer.ClassifierWithState( extlm_chainer.MultiLevelLM(word_rnnlm.predictor, rnnlm.predictor, word_dict, char_dict)) else: rnnlm = lm_chainer.ClassifierWithState( extlm_chainer.LookAheadWordLM(word_rnnlm.predictor, word_dict, char_dict)) # read json data with open(args.recog_json, 'rb') as f: js = json.load(f)['utts'] load_inputs_and_targets = LoadInputsAndTargets( mode='asr', load_output=False, sort_in_input_length=False, preprocess_conf=train_args.preprocess_conf if args.preprocess_conf is None else args.preprocess_conf) # decode each utterance new_js = {} with chainer.no_backprop_mode(): for idx, name in enumerate(js.keys(), 1): logging.info('(%d/%d) decoding ' + name, idx, len(js.keys())) batch = [(name, js[name])] with using_transform_config({'train': False}): feat = load_inputs_and_targets(batch)[0][0] nbest_hyps = model.recognize(feat, args, train_args.char_list, rnnlm) new_js[name] = add_results_to_json(js[name], nbest_hyps, train_args.char_list) # TODO(watanabe) fix character coding problems when saving it with open(args.result_label, 'wb') as f: f.write( json.dumps({ 'utts': new_js }, indent=4, sort_keys=True).encode('utf_8'))
def train(args): """Train with the given args :param Namespace args: The program arguments """ set_deterministic_pytorch(args) # check cuda availability if not torch.cuda.is_available(): logging.warning('cuda is not available') # get input and output dimension info with open(args.valid_json, 'rb') as f: valid_json = json.load(f)['utts'] utts = list(valid_json.keys()) # reverse input and output dimension idim = int(valid_json[utts[0]]['output'][0]['shape'][1]) odim = int(valid_json[utts[0]]['input'][0]['shape'][1]) logging.info('#input dims : ' + str(idim)) logging.info('#output dims: ' + str(odim)) # get extra input and output dimenstion if args.use_speaker_embedding: args.spk_embed_dim = int(valid_json[utts[0]]['input'][1]['shape'][0]) else: args.spk_embed_dim = None if args.use_second_target: args.spc_dim = int(valid_json[utts[0]]['input'][1]['shape'][1]) else: args.spc_dim = None # write model config if not os.path.exists(args.outdir): os.makedirs(args.outdir) model_conf = args.outdir + '/model.json' with open(model_conf, 'wb') as f: logging.info('writing a model config file to' + model_conf) f.write(json.dumps((idim, odim, vars(args)), indent=4, ensure_ascii=False, sort_keys=True).encode('utf_8')) for key in sorted(vars(args).keys()): logging.info('ARGS: ' + key + ': ' + str(vars(args)[key])) # specify model architecture model_class = dynamic_import(args.model_module) model = model_class(idim, odim, args) assert isinstance(model, TTSInterface) logging.info(model) reporter = model.reporter # check the use of multi-gpu if args.ngpu > 1: model = torch.nn.DataParallel(model, device_ids=list(range(args.ngpu))) logging.info('batch size is automatically increased (%d -> %d)' % ( args.batch_size, args.batch_size * args.ngpu)) args.batch_size *= args.ngpu # set torch device device = torch.device("cuda" if args.ngpu > 0 else "cpu") model = model.to(device) # Setup an optimizer if args.opt == 'adam': optimizer = torch.optim.Adam( model.parameters(), args.lr, eps=args.eps, weight_decay=args.weight_decay) elif args.opt == 'noam': from espnet.nets.pytorch_backend.transformer.optimizer import get_std_opt optimizer = get_std_opt(model, args.adim, args.transformer_warmup_steps, args.transformer_lr) else: raise NotImplementedError("unknown optimizer: " + args.opt) # FIXME: TOO DIRTY HACK setattr(optimizer, 'target', reporter) setattr(optimizer, 'serialize', lambda s: reporter.serialize(s)) # read json data with open(args.train_json, 'rb') as f: train_json = json.load(f)['utts'] with open(args.valid_json, 'rb') as f: valid_json = json.load(f)['utts'] use_sortagrad = args.sortagrad == -1 or args.sortagrad > 0 if use_sortagrad: args.batch_sort_key = "input" # make minibatch list (variable length) train_batchset = make_batchset(train_json, args.batch_size, args.maxlen_in, args.maxlen_out, args.minibatches, batch_sort_key=args.batch_sort_key, min_batch_size=args.ngpu if args.ngpu > 1 else 1, shortest_first=use_sortagrad, count=args.batch_count, batch_bins=args.batch_bins, batch_frames_in=args.batch_frames_in, batch_frames_out=args.batch_frames_out, batch_frames_inout=args.batch_frames_inout, swap_io=True) valid_batchset = make_batchset(valid_json, args.batch_size, args.maxlen_in, args.maxlen_out, args.minibatches, batch_sort_key=args.batch_sort_key, min_batch_size=args.ngpu if args.ngpu > 1 else 1, count=args.batch_count, batch_bins=args.batch_bins, batch_frames_in=args.batch_frames_in, batch_frames_out=args.batch_frames_out, batch_frames_inout=args.batch_frames_inout, swap_io=True) load_tr = 
LoadInputsAndTargets( mode='tts', use_speaker_embedding=args.use_speaker_embedding, use_second_target=args.use_second_target, preprocess_conf=args.preprocess_conf, preprocess_args={'train': True}, # Switch the mode of preprocessing keep_all_data_on_mem=args.keep_all_data_on_mem, ) load_cv = LoadInputsAndTargets( mode='tts', use_speaker_embedding=args.use_speaker_embedding, use_second_target=args.use_second_target, preprocess_conf=args.preprocess_conf, preprocess_args={'train': False}, # Switch the mode of preprocessing keep_all_data_on_mem=args.keep_all_data_on_mem, ) # hack to make batchsize argument as 1 # actual bathsize is included in a list if args.num_iter_processes > 0: train_iter = ToggleableShufflingMultiprocessIterator( TransformDataset(train_batchset, load_tr), batch_size=1, n_processes=args.num_iter_processes, n_prefetch=8, maxtasksperchild=20, shuffle=not use_sortagrad) valid_iter = ToggleableShufflingMultiprocessIterator( TransformDataset(valid_batchset, load_cv), batch_size=1, repeat=False, shuffle=False, n_processes=args.num_iter_processes, n_prefetch=8, maxtasksperchild=20) else: train_iter = ToggleableShufflingSerialIterator( TransformDataset(train_batchset, load_tr), batch_size=1, shuffle=not use_sortagrad) valid_iter = ToggleableShufflingSerialIterator( TransformDataset(valid_batchset, load_cv), batch_size=1, repeat=False, shuffle=False) # Set up a trainer converter = CustomConverter() updater = CustomUpdater(model, args.grad_clip, train_iter, optimizer, converter, device, args.accum_grad) trainer = training.Trainer(updater, (args.epochs, 'epoch'), out=args.outdir) # Resume from a snapshot if args.resume: logging.info('resumed from %s' % args.resume) torch_resume(args.resume, trainer) # Evaluate the model with the test dataset for each epoch trainer.extend(CustomEvaluator(model, valid_iter, reporter, converter, device)) # set intervals save_interval = (args.save_interval_epochs, 'epoch') report_interval = (args.report_interval_iters, 'iteration') # Save snapshot for each epoch trainer.extend(torch_snapshot(), trigger=save_interval) # Save best models trainer.extend(snapshot_object(model, 'model.loss.best'), trigger=training.triggers.MinValueTrigger('validation/main/loss', trigger=save_interval)) # Save attention figure for each epoch if args.num_save_attention > 0: data = sorted(list(valid_json.items())[:args.num_save_attention], key=lambda x: int(x[1]['input'][0]['shape'][1]), reverse=True) if hasattr(model, "module"): att_vis_fn = model.module.calculate_all_attentions plot_class = model.module.attention_plot_class else: att_vis_fn = model.calculate_all_attentions plot_class = model.attention_plot_class att_reporter = plot_class( att_vis_fn, data, args.outdir + '/att_ws', converter=converter, transform=load_cv, device=device, reverse=True) trainer.extend(att_reporter, trigger=save_interval) else: att_reporter = None # Make a plot for training and validation values if hasattr(model, "module"): base_plot_keys = model.module.base_plot_keys else: base_plot_keys = model.base_plot_keys plot_keys = [] for key in base_plot_keys: plot_key = ['main/' + key, 'validation/main/' + key] trainer.extend(extensions.PlotReport(plot_key, 'epoch', file_name=key + '.png')) plot_keys += plot_key trainer.extend(extensions.PlotReport(plot_keys, 'epoch', file_name='all_loss.png')) # Write a log of evaluation statistics for each epoch trainer.extend(extensions.LogReport(trigger=report_interval)) report_keys = ['epoch', 'iteration', 'elapsed_time'] + plot_keys 
    trainer.extend(extensions.PrintReport(report_keys), trigger=report_interval)
    trainer.extend(extensions.ProgressBar())

    set_early_stop(trainer, args)
    if args.tensorboard_dir is not None and args.tensorboard_dir != "":
        writer = SummaryWriter(args.tensorboard_dir)
        trainer.extend(TensorboardLogger(writer, att_reporter),
                       trigger=report_interval)

    if use_sortagrad:
        trainer.extend(ShufflingEnabler([train_iter]),
                       trigger=(args.sortagrad if args.sortagrad != -1
                                else args.epochs, 'epoch'))

    # Run the training
    trainer.run()
    check_early_stop(trainer, args.epochs)
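# --- Sketch of how the report/plot keys above are assembled: each base key
# --- gets a training and a validation variant (base keys here are made up).
base_plot_keys = ["loss", "l1_loss", "mse_loss"]  # hypothetical
plot_keys = []
for key in base_plot_keys:
    plot_keys += ["main/" + key, "validation/main/" + key]
report_keys = ["epoch", "iteration", "elapsed_time"] + plot_keys
print(report_keys)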
def decode(args): """Decode with the given args :param Namespace args: The program arguments """ set_deterministic_pytorch(args) # read training config idim, odim, train_args = get_model_conf(args.model, args.model_conf) # show arguments for key in sorted(vars(args).keys()): logging.info('args: ' + key + ': ' + str(vars(args)[key])) # define model model_class = dynamic_import(train_args.model_module) model = model_class(idim, odim, train_args) assert isinstance(model, TTSInterface) logging.info(model) # load trained model parameters logging.info('reading model parameters from ' + args.model) torch_load(args.model, model) model.eval() # set torch device device = torch.device("cuda" if args.ngpu > 0 else "cpu") model = model.to(device) # read json data with open(args.json, 'rb') as f: js = json.load(f)['utts'] # check directory outdir = os.path.dirname(args.out) if len(outdir) != 0 and not os.path.exists(outdir): os.makedirs(outdir) load_inputs_and_targets = LoadInputsAndTargets( mode='tts', load_input=False, sort_in_input_length=False, use_speaker_embedding=train_args.use_speaker_embedding, preprocess_conf=train_args.preprocess_conf if args.preprocess_conf is None else args.preprocess_conf, preprocess_args={'train': False} # Switch the mode of preprocessing ) # define function for plot prob and att_ws def _plot_and_save(array, figname, figsize=(6, 4), dpi=150): import matplotlib.pyplot as plt shape = array.shape if len(shape) == 1: # for eos probability plt.figure(figsize=figsize, dpi=dpi) plt.plot(array) plt.xlabel("Frame") plt.ylabel("Probability") plt.ylim([0, 1]) elif len(shape) == 2: # for tacotron 2 attention weights, whose shape is (out_length, in_length) plt.figure(figsize=figsize, dpi=dpi) plt.imshow(array, aspect="auto") plt.xlabel("Input") plt.ylabel("Output") elif len(shape) == 4: # for transformer attention weights, whose shape is (#leyers, #heads, out_length, in_length) plt.figure(figsize=(figsize[0] * shape[0], figsize[1] * shape[1]), dpi=dpi) for idx1, xs in enumerate(array): for idx2, x in enumerate(xs, 1): plt.subplot(shape[0], shape[1], idx1 * shape[1] + idx2) plt.imshow(x, aspect="auto") plt.xlabel("Input") plt.ylabel("Output") else: raise NotImplementedError("Support only from 1D to 4D array.") plt.tight_layout() if not os.path.exists(os.path.dirname(figname)): os.makedirs(os.path.dirname(figname)) plt.savefig(figname) plt.clf() with torch.no_grad(), \ kaldiio.WriteHelper('ark,scp:{o}.ark,{o}.scp'.format(o=args.out)) as f: for idx, utt_id in enumerate(js.keys()): batch = [(utt_id, js[utt_id])] data = load_inputs_and_targets(batch) if train_args.use_speaker_embedding: spemb = data[1][0] spemb = torch.FloatTensor(spemb).to(device) else: spemb = None x = data[0][0] x = torch.LongTensor(x).to(device) # decode and write start_time = time.time() outs, probs, att_ws = model.inference(x, args, spemb=spemb) logging.info("inference speed = %s msec / frame." % ( (time.time() - start_time) / (int(outs.size(0)) * 1000))) if outs.size(0) == x.size(0) * args.maxlenratio: logging.warning("output length reaches maximum length (%s)." % utt_id) logging.info('(%d/%d) %s (size:%d->%d)' % ( idx + 1, len(js.keys()), utt_id, x.size(0), outs.size(0))) f[utt_id] = outs.cpu().numpy() # plot prob and att_ws if probs is not None: _plot_and_save(probs.cpu().numpy(), os.path.dirname(args.out) + "/probs/%s_prob.png" % utt_id) if att_ws is not None: _plot_and_save(att_ws.cpu().numpy(), os.path.dirname(args.out) + "/att_ws/%s_att_ws.png" % utt_id)
def save_alignment(args): set_deterministic_pytorch(args) model, train_args = load_trained_model(args.model) assert isinstance(model, ASRInterface) model.recog_args = args # read rnnlm if args.rnnlm: rnnlm_args = get_model_conf(args.rnnlm, args.rnnlm_conf) if getattr(rnnlm_args, "model_module", "default") != "default": raise ValueError("use '--api v2' option to decode with non-default language model") rnnlm = lm_pytorch.ClassifierWithState( lm_pytorch.RNNLM( len(train_args.char_list), rnnlm_args.layer, rnnlm_args.unit)) torch_load(args.rnnlm, rnnlm) rnnlm.eval() else: rnnlm = None if args.word_rnnlm: rnnlm_args = get_model_conf(args.word_rnnlm, args.word_rnnlm_conf) word_dict = rnnlm_args.char_list_dict char_dict = {x: i for i, x in enumerate(train_args.char_list)} word_rnnlm = lm_pytorch.ClassifierWithState(lm_pytorch.RNNLM( len(word_dict), rnnlm_args.layer, rnnlm_args.unit)) torch_load(args.word_rnnlm, word_rnnlm) word_rnnlm.eval() if rnnlm is not None: rnnlm = lm_pytorch.ClassifierWithState( extlm_pytorch.MultiLevelLM(word_rnnlm.predictor, rnnlm.predictor, word_dict, char_dict)) else: rnnlm = lm_pytorch.ClassifierWithState( extlm_pytorch.LookAheadWordLM(word_rnnlm.predictor, word_dict, char_dict)) # set torch device device = torch.device("cuda" if args.ngpu > 0 else "cpu") dtype = next(model.parameters()).dtype model = model.to(device=device) if rnnlm: rnnlm = rnnlm.to(device=device) # read json data with open(args.json, 'rb') as f: js = json.load(f)['utts'] load_inputs_and_targets = LoadInputsAndTargets( mode='asr', load_output=True, sort_in_input_length=False, preprocess_conf=train_args.preprocess_conf if args.preprocess_conf is None else args.preprocess_conf, preprocess_args={'train': False}) # sort data if batchsize > 1 keys = list(js.keys()) if args.batchsize > 1: feat_lens = [js[key]['input'][0]['shape'][0] for key in keys] sorted_index = sorted(range(len(feat_lens)), key=lambda i: -feat_lens[i]) keys = [keys[i] for i in sorted_index] def grouper(n, iterable, fillvalue=None): kargs = [iter(iterable)] * n return zip_longest(*kargs, fillvalue=fillvalue) # Setup a converter if args.num_encs == 1: converter = CustomConverter(subsampling_factor=model.subsample[0], dtype=dtype) else: converter = CustomConverterMulEnc([i[0] for i in model.subsample_list], dtype=dtype) import matplotlib.pyplot as plt outdir = args.outdir if not os.path.exists(outdir): os.makedirs(outdir) with torch.no_grad(): for names in grouper(args.batchsize, keys, None): names = [name for name in names if name] batch = [(name, js[name]) for name in names] x = converter([load_inputs_and_targets(batch)], device) alignments = model.calculate_alignments(*x) for i in range(len(alignments)): alignment = np.transpose(np.exp(alignments[i].astype(np.float32))) np_filename = "%s/%s.npy" % (outdir, names[i]) np.save(np_filename, alignment) plt.imshow(alignment, aspect="auto") plt.xlabel("Input Index") plt.ylabel("Label Index") plt.tight_layout() fig_filename = "%s/%s.png" % (outdir, names[i]) plt.savefig(fig_filename) plt.close()
def recog(args): """Decode with the given args :param Namespace args: The program arguments """ set_deterministic_pytorch(args) # read training config idim, odim, train_args = get_model_conf(args.model, args.model_conf) # load trained model parameters logging.info('reading model parameters from ' + args.model) # To be compatible with v.0.3.0 models if hasattr(train_args, "model_module"): model_module = train_args.model_module else: model_module = "espnet.nets.pytorch_backend.e2e_asr:E2E" model_class = dynamic_import(model_module) model = model_class(idim, odim, train_args) assert isinstance(model, ASRInterface) torch_load(args.model, model) model.recog_args = args # read rnnlm if args.rnnlm: rnnlm_args = get_model_conf(args.rnnlm, args.rnnlm_conf) rnnlm = lm_pytorch.ClassifierWithState( lm_pytorch.RNNLM(len(train_args.char_list), rnnlm_args.layer, rnnlm_args.unit)) torch_load(args.rnnlm, rnnlm) rnnlm.eval() else: rnnlm = None if args.word_rnnlm: rnnlm_args = get_model_conf(args.word_rnnlm, args.word_rnnlm_conf) word_dict = rnnlm_args.char_list_dict char_dict = {x: i for i, x in enumerate(train_args.char_list)} word_rnnlm = lm_pytorch.ClassifierWithState( lm_pytorch.RNNLM(len(word_dict), rnnlm_args.layer, rnnlm_args.unit)) torch_load(args.word_rnnlm, word_rnnlm) word_rnnlm.eval() if rnnlm is not None: rnnlm = lm_pytorch.ClassifierWithState( extlm_pytorch.MultiLevelLM(word_rnnlm.predictor, rnnlm.predictor, word_dict, char_dict)) else: rnnlm = lm_pytorch.ClassifierWithState( extlm_pytorch.LookAheadWordLM(word_rnnlm.predictor, word_dict, char_dict)) # gpu if args.ngpu == 1: gpu_id = list(range(args.ngpu)) logging.info('gpu id: ' + str(gpu_id)) model.cuda() if rnnlm: rnnlm.cuda() # read json data with open(args.recog_json, 'rb') as f: js = json.load(f)['utts'] new_js = {} load_inputs_and_targets = LoadInputsAndTargets( mode='asr', load_output=False, sort_in_input_length=False, preprocess_conf=train_args.preprocess_conf if args.preprocess_conf is None else args.preprocess_conf, preprocess_args={'train': False}) if args.batchsize == 0: with torch.no_grad(): for idx, name in enumerate(js.keys(), 1): logging.info('(%d/%d) decoding ' + name, idx, len(js.keys())) batch = [(name, js[name])] feat = load_inputs_and_targets(batch)[0][0] if args.streaming_mode == 'window': logging.info( 'Using streaming recognizer with window size %d frames', args.streaming_window) se2e = WindowStreamingE2E(e2e=model, recog_args=args, rnnlm=rnnlm) for i in range(0, feat.shape[0], args.streaming_window): logging.info('Feeding frames %d - %d', i, i + args.streaming_window) se2e.accept_input(feat[i:i + args.streaming_window]) logging.info('Running offline attention decoder') se2e.decode_with_attention_offline() logging.info('Offline attention decoder finished') nbest_hyps = se2e.retrieve_recognition() elif args.streaming_mode == 'segment': logging.info( 'Using streaming recognizer with threshold value %d', args.streaming_min_blank_dur) nbest_hyps = [] for n in range(args.nbest): nbest_hyps.append({'yseq': [], 'score': 0.0}) se2e = SegmentStreamingE2E(e2e=model, recog_args=args, rnnlm=rnnlm) r = np.prod(model.subsample) for i in range(0, feat.shape[0], r): hyps = se2e.accept_input(feat[i:i + r]) if hyps is not None: text = ''.join([ train_args.char_list[int(x)] for x in hyps[0]['yseq'][1:-1] if int(x) != -1 ]) text = text.replace(model.space, ' ') text = text.replace(model.blank, '') logging.info(text) for n in range(args.nbest): nbest_hyps[n]['yseq'].extend(hyps[n]['yseq']) nbest_hyps[n]['score'] += hyps[n]['score'] else: 
nbest_hyps = model.recognize(feat, args, train_args.char_list, rnnlm) new_js[name] = add_results_to_json(js[name], nbest_hyps, train_args.char_list) else: def grouper(n, iterable, fillvalue=None): kargs = [iter(iterable)] * n return zip_longest(*kargs, fillvalue=fillvalue) # sort data keys = list(js.keys()) feat_lens = [js[key]['input'][0]['shape'][0] for key in keys] sorted_index = sorted(range(len(feat_lens)), key=lambda i: -feat_lens[i]) keys = [keys[i] for i in sorted_index] with torch.no_grad(): for names in grouper(args.batchsize, keys, None): names = [name for name in names if name] batch = [(name, js[name]) for name in names] feats = load_inputs_and_targets(batch)[0] nbest_hyps = model.recognize_batch(feats, args, train_args.char_list, rnnlm=rnnlm) for i, nbest_hyp in enumerate(nbest_hyps): name = names[i] new_js[name] = add_results_to_json(js[name], nbest_hyp, train_args.char_list) with open(args.result_label, 'wb') as f: f.write( json.dumps({ 'utts': new_js }, indent=4, ensure_ascii=False, sort_keys=True).encode('utf_8'))
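# --- Minimal sketch of the longest-first sorting applied before batch decoding
# --- above, so that each minibatch contains utterances of similar length.
js = {  # hypothetical data.json entries
    "utt1": {"input": [{"shape": [120, 83]}]},
    "utt2": {"input": [{"shape": [480, 83]}]},
    "utt3": {"input": [{"shape": [300, 83]}]},
}
keys = list(js.keys())
feat_lens = [js[key]["input"][0]["shape"][0] for key in keys]
sorted_index = sorted(range(len(feat_lens)), key=lambda i: -feat_lens[i])
keys = [keys[i] for i in sorted_index]
print(keys)  # ['utt2', 'utt3', 'utt1']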
class ASRConverter(Converter): ''' ASR preprocess ''' def __init__(self, config): super().__init__(config) taskconf = self.config['data']['task'] assert taskconf['type'] == TASK_SET['asr'] self.subsampling_factor = taskconf['src']['subsampling_factor'] self.preprocess_conf = taskconf['src']['preprocess_conf'] # mode: asr or tts self.load_inputs_and_targets = LoadInputsAndTargets( mode=taskconf['type'], load_output=True, preprocess_conf=self.preprocess_conf) #pylint: disable=arguments-differ #pylint: disable=too-many-branches def transform(self, batch): """Function to load inputs, targets and uttid from list of dicts :param List[Tuple[str, dict]] batch: list of dict which is subset of loaded data.json :return: list of input token id sequences [(L_1), (L_2), ..., (L_B)] :return: list of input feature sequences [(T_1, D), (T_2, D), ..., (T_B, D)] :rtype: list of float ndarray :return: list of target token id sequences [(L_1), (L_2), ..., (L_B)] :rtype: list of int ndarray Reference: Espnet source code, /espnet/utils/io_utils.py https://github.com/espnet/espnet/blob/master/espnet/utils/io_utils.py """ x_feats_dict = OrderedDict() # OrderedDict[str, List[np.ndarray]] y_feats_dict = OrderedDict() # OrderedDict[str, List[np.ndarray]] uttid_list = [] # List[str] mode = self.load_inputs_and_targets.mode for uttid, info in batch: uttid_list.append(uttid) if self.load_inputs_and_targets.load_input: # Note(kamo): This for-loop is for multiple inputs for idx, inp in enumerate(info['input']): # {"input": # [{"feat": "some/path.h5:F01_050C0101_PED_REAL", # "filetype": "hdf5", # "name": "input1", ...}], ...} #pylint: disable=protected-access x_data = self.load_inputs_and_targets._get_from_loader( filepath=inp['feat'], filetype=inp.get('filetype', 'mat')) x_feats_dict.setdefault(inp['name'], []).append(x_data) elif mode == 'tts' and self.load_inputs_and_targets.use_speaker_embedding: for idx, inp in enumerate(info['input']): if idx != 1 and len(info['input']) > 1: x_data = None else: x_data = self.load_inputs_and_targets._get_from_loader( #pylint: disable=protected-access filepath=inp['feat'], filetype=inp.get('filetype', 'mat')) x_feats_dict.setdefault(inp['name'], []).append(x_data) if self.load_inputs_and_targets.load_output: for idx, inp in enumerate(info['output']): if 'tokenid' in inp: # ======= Legacy format for output ======= # {"output": [{"tokenid": "1 2 3 4"}]) x_data = np.fromiter(map(int, inp['tokenid'].split()), dtype=np.int64) else: # ======= New format ======= # {"input": # [{"feat": "some/path.h5:F01_050C0101_PED_REAL", # "filetype": "hdf5", # "name": "target1", ...}], ...} x_data = self.load_inputs_and_targets._get_from_loader( #pylint: disable=protected-access filepath=inp['feat'], filetype=inp.get('filetype', 'mat')) y_feats_dict.setdefault(inp['name'], []).append(x_data) if self.load_inputs_and_targets.mode == 'asr': #pylint: disable=protected-access return_batch, uttid_list = self.load_inputs_and_targets._create_batch_asr( x_feats_dict, y_feats_dict, uttid_list) elif self.load_inputs_and_targets.mode == 'tts': _, info = batch[0] eos = int(info['output'][0]['shape'][1]) - 1 #pylint: disable=protected-access return_batch, uttid_list = self.load_inputs_and_targets._create_batch_tts( x_feats_dict, y_feats_dict, uttid_list, eos) else: raise NotImplementedError if self.load_inputs_and_targets.preprocessing is not None: # Apply pre-processing only to input1 feature, now if 'input1' in return_batch: return_batch['input1'] = \ self.load_inputs_and_targets.preprocessing(return_batch['input1'], 
                        uttid_list,
                        **self.load_inputs_and_targets.preprocess_args)

        # Doesn't return the names now.
        return tuple(return_batch.values()), uttid_list
import torch
import json

from espnet.utils.training.batchfy import make_batchset
from espnet.utils.dataset import TransformDataset
from espnet.asr.pytorch_backend.asr import CustomConverter
from espnet.utils.io_utils import LoadInputsAndTargets

converter = CustomConverter(subsampling_factor=1, dtype=torch.float32)

with open("dump/train_nodev/deltafalse/data.json", "rb") as f:
    train_json = json.load(f)["utts"]
train = make_batchset(train_json, batch_size=30)

load_tr = LoadInputsAndTargets(
    mode="asr",
    load_output=True,
    preprocess_conf=None,
    preprocess_args={"train": True},  # Switch the mode of preprocessing
)
dataset = TransformDataset(train, lambda data: converter([load_tr(data)]))

model = torch.load("model.loss.best.entire")
model.train()

for i in range(1):
    for key in train[0]:
        print(key[0])
    x = dataset[i][0]
    ilen = dataset[i][1]
    y = dataset[i][2]
    print("x:", x)
    print("x size:", x.size())
    print("ilen:", ilen)
def ctc_align(args): """CTC forced alignments with the given args. Args: args (namespace): The program arguments. """ def add_alignment_to_json(js, alignment, char_list): """Add N-best results to json. Args: js (dict[str, Any]): Groundtruth utterance dict. alignment (list[int]): List of alignment. char_list (list[str]): List of characters. Returns: dict[str, Any]: N-best results added utterance dict. """ # copy old json info new_js = dict() new_js["ctc_alignment"] = [] alignment_tokens = [] for idx, a in enumerate(alignment): alignment_tokens.append(char_list[a]) alignment_tokens = " ".join(alignment_tokens) new_js["ctc_alignment"] = alignment_tokens return new_js set_deterministic_pytorch(args) model, train_args = load_trained_model(args.model) assert isinstance(model, ASRInterface) model.eval() load_inputs_and_targets = LoadInputsAndTargets( mode="asr", load_output=True, sort_in_input_length=False, preprocess_conf=train_args.preprocess_conf if args.preprocess_conf is None else args.preprocess_conf, preprocess_args={"train": False}, ) if args.ngpu > 1: raise NotImplementedError("only single GPU decoding is supported") if args.ngpu == 1: device = "cuda" else: device = "cpu" dtype = getattr(torch, args.dtype) logging.info(f"Decoding device={device}, dtype={dtype}") model.to(device=device, dtype=dtype).eval() # read json data with open(args.align_json, "rb") as f: js = json.load(f)["utts"] new_js = {} if args.batchsize == 0: with torch.no_grad(): for idx, name in enumerate(js.keys(), 1): logging.info("(%d/%d) aligning " + name, idx, len(js.keys())) batch = [(name, js[name])] feat, label = load_inputs_and_targets(batch) feat = feat[0] label = label[0] enc = model.encode(torch.as_tensor(feat).to(device)).unsqueeze(0) alignment = model.ctc.forced_align(enc, label) new_js[name] = add_alignment_to_json( js[name], alignment, train_args.char_list ) else: raise NotImplementedError("Align_batch is not implemented.") with open(args.result_label, "wb") as f: f.write( json.dumps( {"utts": new_js}, indent=4, ensure_ascii=False, sort_keys=True ).encode("utf_8") )
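# --- Toy illustration of the add_alignment_to_json helper defined above:
# --- frame-level token ids are mapped through char_list and joined with spaces.
char_list = ["<blank>", "a", "b", "c"]
alignment = [0, 1, 1, 0, 2, 3]  # hypothetical frame-level CTC alignment
new_js = {"ctc_alignment": " ".join(char_list[a] for a in alignment)}
print(new_js)  # {'ctc_alignment': '<blank> a a <blank> b c'}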
def dist_train(gpu, args): """Initialize torch.distributed.""" args.gpu = gpu args.rank = gpu logging.warning("Hi Master, I am gpu {0}".format(gpu)) if not os.path.exists(args.outdir): os.makedirs(args.outdir) init_method = "tcp://localhost:{port}".format(port=args.port) torch.distributed.init_process_group( backend='nccl', world_size=args.ngpu, rank=args.gpu, init_method=init_method) torch.cuda.set_device(args.gpu) if args.streaming: converter = StreamingConverter(args.gpu, args) validconverter = StreamingConverter(args.gpu, args) else: converter = CustomConverter(args.gpu, reverse=False) validconverter = CustomConverter(args.gpu, reverse=False) with open(args.train_json, 'rb') as f: train_json = json.load(f)['utts'] with open(args.valid_json, 'rb') as f: valid_json = json.load(f)['utts'] utts = list(valid_json.keys()) idim = int(valid_json[utts[0]]['input'][0]['shape'][-1]) odim = int(valid_json[utts[0]]['output'][0]['shape'][-1]) if args.enc_init is not None or args.dec_init is not None: model = load_trained_modules(idim, odim, args) else: model_class = dynamic_import(args.model_module) model = model_class(idim, odim, args) model_conf = args.outdir + '/model.json' with open(model_conf, 'wb') as f: logging.warning('writing a model config file to ' + model_conf) f.write(json.dumps((idim, odim, vars(args)), indent=4, ensure_ascii=False, sort_keys=True).encode('utf_8')) model.cuda(args.gpu) model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) if args.opt == 'adadelta': optimizer = torch.optim.Adadelta( model.parameters(), rho=0.95, eps=args.eps, weight_decay=args.weight_decay) elif args.opt == 'adam': optimizer = torch.optim.Adam(model.parameters(), weight_decay=args.weight_decay) elif args.opt == 'noam': from espnet.nets.pytorch_backend.transformer.optimizer import get_std_opt optimizer = get_std_opt(model, args.adim, args.transformer_warmup_steps, args.transformer_lr) else: raise NotImplementedError("unknown optimizer: " + args.opt) train = make_batchset(train_json, args.batch_size, args.maxlen_in, args.maxlen_out, args.minibatches, min_batch_size=1, shortest_first=False, count=args.batch_count, batch_bins=args.batch_bins, batch_frames_in=args.batch_frames_in, batch_frames_out=args.batch_frames_out, batch_frames_inout=args.batch_frames_inout) valid = make_batchset(valid_json, 1, args.maxlen_in, args.maxlen_out, args.minibatches, min_batch_size=1, count=args.batch_count, batch_bins=args.batch_bins, batch_frames_in=args.batch_frames_in, batch_frames_out=args.batch_frames_out, batch_frames_inout=args.batch_frames_inout) load_tr = LoadInputsAndTargets( mode='asr', load_output=True, preprocess_conf=args.preprocess_conf, preprocess_args={'train': True} # Switch the mode of preprocessing ) load_cv = LoadInputsAndTargets( mode='asr', load_output=True, preprocess_conf=args.preprocess_conf, preprocess_args={'train': False} # Switch the mode of preprocessing ) train_dataset = TransformDataset(train, lambda data: converter(load_tr(data))) valid_dataset = TransformDataset(valid, lambda data: validconverter(load_cv(data))) train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) train_loader = torch.utils.data.DataLoader( train_dataset, batch_size=1, shuffle=False, num_workers=args.n_iter_processes, pin_memory=False, sampler=train_sampler) valid_loader = torch.utils.data.DataLoader( valid_dataset, batch_size=1, shuffle=False, num_workers=args.n_iter_processes, pin_memory=False) start_epoch = 0 latest = [int(f.split('.')[-1]) for f in 
                  os.listdir(args.outdir) if 'snapshot.ep' in f]
    if not args.resume and len(latest):
        latest_snapshot = os.path.join(
            args.outdir, 'snapshot.ep.{}'.format(str(max(latest))))
        args.resume = latest_snapshot

    if args.resume:
        logging.warning("=> loading checkpoint '{}'".format(args.resume))
        loc = 'cuda:{}'.format(args.gpu)
        checkpoint = torch.load(args.resume, map_location=loc)
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        logging.warning("=> loaded checkpoint '{}' (epoch {})".format(
            args.resume, checkpoint['epoch']))
    else:
        start_epoch = 0

    for epoch in range(start_epoch, args.epochs):
        train_sampler.set_epoch(epoch)
        train_epoch(train_loader, model, optimizer, epoch, args)
        loss = validate(valid_loader, model, args)

        if args.rank == 0:
            save_checkpoint({
                'epoch': epoch + 1,
                'arch': args.model_module,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            }, filename=os.path.join(args.outdir,
                                     'snapshot.ep.{}'.format(epoch)))
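# --- Standalone sketch of the auto-resume logic above: pick the newest
# --- "snapshot.ep.N" file in the output directory when --resume is not given.
import os

outdir = "exp/train"  # hypothetical output directory
files = os.listdir(outdir) if os.path.isdir(outdir) else []
latest = [int(f.split(".")[-1]) for f in files if "snapshot.ep" in f]
if latest:
    latest_snapshot = os.path.join(outdir, "snapshot.ep.{}".format(max(latest)))
    print("resuming from", latest_snapshot)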
def trans(args): """Decode with the given args. Args: args (namespace): The program arguments. """ set_deterministic_pytorch(args) model, train_args = load_trained_model(args.model) assert isinstance(model, STInterface) # args.ctc_weight = 0.0 model.trans_args = args # read rnnlm if args.rnnlm: rnnlm_args = get_model_conf(args.rnnlm, args.rnnlm_conf) if getattr(rnnlm_args, "model_module", "default") != "default": raise ValueError( "use '--api v2' option to decode with non-default language model" ) rnnlm = lm_pytorch.ClassifierWithState( lm_pytorch.RNNLM(len(train_args.char_list), rnnlm_args.layer, rnnlm_args.unit)) torch_load(args.rnnlm, rnnlm) rnnlm.eval() else: rnnlm = None # gpu if args.ngpu == 1: gpu_id = list(range(args.ngpu)) logging.info('gpu id: ' + str(gpu_id)) model.cuda() if rnnlm: rnnlm.cuda() # read json data with open(args.trans_json, 'rb') as f: js = json.load(f)['utts'] new_js = {} load_inputs_and_targets = LoadInputsAndTargets( mode='asr', load_output=False, sort_in_input_length=False, preprocess_conf=train_args.preprocess_conf if args.preprocess_conf is None else args.preprocess_conf, preprocess_args={'train': False}) if args.batchsize == 0: with torch.no_grad(): for idx, name in enumerate(js.keys(), 1): logging.info('(%d/%d) decoding ' + name, idx, len(js.keys())) batch = [(name, js[name])] feat = load_inputs_and_targets(batch)[0][0] nbest_hyps = model.translate(feat, args, train_args.char_list, rnnlm) new_js[name] = add_results_to_json(js[name], nbest_hyps, train_args.char_list) else: def grouper(n, iterable, fillvalue=None): kargs = [iter(iterable)] * n return zip_longest(*kargs, fillvalue=fillvalue) # sort data if batchsize > 1 keys = list(js.keys()) if args.batchsize > 1: feat_lens = [js[key]['input'][0]['shape'][0] for key in keys] sorted_index = sorted(range(len(feat_lens)), key=lambda i: -feat_lens[i]) keys = [keys[i] for i in sorted_index] with torch.no_grad(): for names in grouper(args.batchsize, keys, None): names = [name for name in names if name] batch = [(name, js[name]) for name in names] feats = load_inputs_and_targets(batch)[0] nbest_hyps = model.translate_batch(feats, args, train_args.char_list, rnnlm=rnnlm) for i, nbest_hyp in enumerate(nbest_hyps): name = names[i] new_js[name] = add_results_to_json(js[name], nbest_hyp, train_args.char_list) with open(args.result_label, 'wb') as f: f.write( json.dumps({ 'utts': new_js }, indent=4, ensure_ascii=False, sort_keys=True).encode('utf_8'))
def trans(args): """Decode with the given args. Args: args (namespace): The program arguments. """ set_deterministic_pytorch(args) model, train_args = load_trained_model(args.model) assert isinstance(model, STInterface) model.trans_args = args # gpu if args.ngpu == 1: gpu_id = list(range(args.ngpu)) logging.info("gpu id: " + str(gpu_id)) model.cuda() # read json data with open(args.trans_json, "rb") as f: js = json.load(f)["utts"] new_js = {} load_inputs_and_targets = LoadInputsAndTargets( mode="asr", load_output=False, sort_in_input_length=False, preprocess_conf=train_args.preprocess_conf if args.preprocess_conf is None else args.preprocess_conf, preprocess_args={"train": False}, ) if args.batchsize == 0: with torch.no_grad(): for idx, name in enumerate(js.keys(), 1): logging.info("(%d/%d) decoding " + name, idx, len(js.keys())) batch = [(name, js[name])] feat = load_inputs_and_targets(batch)[0][0] nbest_hyps = model.translate( feat, args, train_args.char_list, ) new_js[name] = add_results_to_json(js[name], nbest_hyps, train_args.char_list) else: def grouper(n, iterable, fillvalue=None): kargs = [iter(iterable)] * n return zip_longest(*kargs, fillvalue=fillvalue) # sort data if batchsize > 1 keys = list(js.keys()) if args.batchsize > 1: feat_lens = [js[key]["input"][0]["shape"][0] for key in keys] sorted_index = sorted(range(len(feat_lens)), key=lambda i: -feat_lens[i]) keys = [keys[i] for i in sorted_index] with torch.no_grad(): for names in grouper(args.batchsize, keys, None): names = [name for name in names if name] batch = [(name, js[name]) for name in names] feats = load_inputs_and_targets(batch)[0] nbest_hyps = model.translate_batch( feats, args, train_args.char_list, ) for i, nbest_hyp in enumerate(nbest_hyps): name = names[i] new_js[name] = add_results_to_json(js[name], nbest_hyp, train_args.char_list) with open(args.result_label, "wb") as f: f.write( json.dumps({ "utts": new_js }, indent=4, ensure_ascii=False, sort_keys=True).encode("utf_8"))
def train(args): """Train E2E-TTS model.""" set_deterministic_pytorch(args) # check cuda availability if not torch.cuda.is_available(): logging.warning("cuda is not available") # get input and output dimension info with open(args.valid_json, "rb") as f: valid_json = json.load(f)["utts"] utts = list(valid_json.keys()) # reverse input and output dimension idim = int(valid_json[utts[0]]["output"][0]["shape"][1]) odim = int(valid_json[utts[0]]["input"][0]["shape"][1]) logging.info("#input dims : " + str(idim)) logging.info("#output dims: " + str(odim)) # get extra input and output dimenstion if args.use_speaker_embedding: args.spk_embed_dim = int(valid_json[utts[0]]["input"][1]["shape"][0]) else: args.spk_embed_dim = None if args.use_second_target: args.spc_dim = int(valid_json[utts[0]]["input"][1]["shape"][1]) else: args.spc_dim = None if args.use_character_embedding: logging.info("\nUsing character embeddings? Hahah!\n") args.char_embed_dim = 768 else: args.char_embed_dim = None # Manually set the number of intonation types args.into_type_num = 3 if args.into_embed_dim is not None: if args.into_embed_dim <= 0: args.into_embed_dim = None elif not args.use_intonation_type: raise ValueError( "use_intonation_type should be true when into_embed_dim is given." ) if args.use_intotype_loss: if not args.use_intonation_type: raise ValueError( "use_intonation_type should be true when use_intotype_loss is true." ) if not args.use_character_embedding: raise ValueError( "use_character_embedding should be true when use_intotype_loss is true." ) # write model config if args.local_rank == 0: if not os.path.exists(args.outdir): os.makedirs(args.outdir) model_conf = args.outdir + "/model.json" with open(model_conf, "wb") as f: logging.info("writing a model config file to" + model_conf) f.write( json.dumps((idim, odim, vars(args)), indent=4, ensure_ascii=False, sort_keys=True).encode("utf_8")) for key in sorted(vars(args).keys()): logging.info("ARGS: " + key + ": " + str(vars(args)[key])) # specify model architecture if args.enc_init is not None or args.dec_init is not None: model = load_trained_modules(idim, odim, args, TTSInterface) else: model_class = dynamic_import(args.model_module) model = model_class(idim, odim, args) assert isinstance(model, TTSInterface) logging.info(model) reporter = model.reporter # check the use of multi-gpu if args.ngpu > 1: # model = torch.nn.DataParallel(model, device_ids=list(range(args.ngpu))) if args.batch_size != 0: logging.warning( "batch size is automatically increased (%d -> %d)" % (args.batch_size, args.batch_size * args.ngpu)) # args.batch_size *= args.ngpu # set torch device # device = torch.device("cuda" if args.ngpu > 0 else "cpu") device = torch.device("cuda", args.local_rank) model = model.to(device) model = DistributedDataParallel( model, device_ids=[args.local_rank], output_device=args.local_rank, ) # freeze modules, if specified if args.freeze_mods: if hasattr(model, "module"): freeze_mods = ["module." 
+ x for x in args.freeze_mods] else: freeze_mods = args.freeze_mods for mod, param in model.named_parameters(): if any(mod.startswith(key) for key in freeze_mods): logging.info(f"{mod} is frozen not to be updated.") param.requires_grad = False model_params = filter(lambda x: x.requires_grad, model.parameters()) else: model_params = model.parameters() # Setup an optimizer if args.opt == "adam": optimizer = torch.optim.Adam(model_params, args.lr, eps=args.eps, weight_decay=args.weight_decay) elif args.opt == "noam": from espnet.nets.pytorch_backend.transformer.optimizer import get_std_opt optimizer = get_std_opt(model_params, args.adim, args.transformer_warmup_steps, args.transformer_lr) else: raise NotImplementedError("unknown optimizer: " + args.opt) # FIXME: TOO DIRTY HACK setattr(optimizer, "target", reporter) setattr(optimizer, "serialize", lambda s: reporter.serialize(s)) # read json data with open(args.train_json, "rb") as f: train_json = json.load(f)["utts"] with open(args.valid_json, "rb") as f: valid_json = json.load(f)["utts"] use_sortagrad = args.sortagrad == -1 or args.sortagrad > 0 if use_sortagrad: args.batch_sort_key = "input" # make minibatch list (variable length) train_batchset = make_batchset( train_json, args.batch_size, args.maxlen_in, args.maxlen_out, args.minibatches, batch_sort_key=args.batch_sort_key, min_batch_size=args.ngpu if args.ngpu > 1 else 1, shortest_first=use_sortagrad, count=args.batch_count, batch_bins=args.batch_bins, batch_frames_in=args.batch_frames_in, batch_frames_out=args.batch_frames_out, batch_frames_inout=args.batch_frames_inout, swap_io=True, iaxis=0, oaxis=0, ) valid_batchset = make_batchset( valid_json, args.batch_size, args.maxlen_in, args.maxlen_out, args.minibatches, batch_sort_key=args.batch_sort_key, min_batch_size=args.ngpu if args.ngpu > 1 else 1, count=args.batch_count, batch_bins=args.batch_bins, batch_frames_in=args.batch_frames_in, batch_frames_out=args.batch_frames_out, batch_frames_inout=args.batch_frames_inout, swap_io=True, iaxis=0, oaxis=0, ) load_tr = LoadInputsAndTargets( mode="tts", use_speaker_embedding=args.use_speaker_embedding, use_second_target=args.use_second_target, use_character_embedding=args.use_character_embedding, use_intonation_type=args.use_intonation_type, preprocess_conf=args.preprocess_conf, preprocess_args={"train": True}, # Switch the mode of preprocessing keep_all_data_on_mem=args.keep_all_data_on_mem, ) load_cv = LoadInputsAndTargets( mode="tts", use_speaker_embedding=args.use_speaker_embedding, use_second_target=args.use_second_target, use_character_embedding=args.use_character_embedding, use_intonation_type=args.use_intonation_type, preprocess_conf=args.preprocess_conf, preprocess_args={"train": False}, # Switch the mode of preprocessing keep_all_data_on_mem=args.keep_all_data_on_mem, ) converter = CustomConverter() # hack to make batchsize argument as 1 # actual bathsize is included in a list train_dataset = TransformDataset(train_batchset, lambda data: converter([load_tr(data)])) train_sampler = torch.utils.data.distributed.DistributedSampler( train_dataset) train_iter = { "main": ChainerDataLoader( dataset=train_dataset, batch_size=1, # num_workers=args.num_iter_processes, # shuffle=not use_sortagrad, collate_fn=lambda x: x[0], sampler=train_sampler, ) } valid_iter = { "main": ChainerDataLoader( dataset=TransformDataset(valid_batchset, lambda data: converter([load_cv(data)])), batch_size=1, shuffle=False, collate_fn=lambda x: x[0], num_workers=args.num_iter_processes, ) } # Set up a trainer updater 
= CustomUpdater( model, args.grad_clip, train_iter, optimizer, device, args.accum_grad, local_rank=args.local_rank, ) trainer = training.Trainer(updater, (args.epochs, "epoch"), out=args.outdir) # Resume from a snapshot if args.resume: logging.info("resumed from %s" % args.resume) torch_resume(args.resume, trainer) # Only the major device would evaluate and report if args.local_rank == 0: # set intervals eval_interval = (args.eval_interval_epochs, "epoch") save_interval = (args.save_interval_epochs, "epoch") report_interval = (args.report_interval_iters, "iteration") # Evaluate the model with the test dataset for each epoch trainer.extend(CustomEvaluator(model, valid_iter, reporter, device), trigger=eval_interval) # Save snapshot for each epoch trainer.extend(torch_snapshot(), trigger=save_interval) # Save best models trainer.extend( snapshot_object(model, "model.loss.best"), trigger=training.triggers.MinValueTrigger("validation/main/loss", trigger=eval_interval), ) # Save attention figure for each epoch if args.num_save_attention > 0: data = sorted( list(valid_json.items())[:args.num_save_attention], key=lambda x: int(x[1]["output"][0]["shape"][0]), ) if hasattr(model, "module"): att_vis_fn = model.module.calculate_all_attentions plot_class = model.module.attention_plot_class reduction_factor = model.module.reduction_factor else: att_vis_fn = model.calculate_all_attentions plot_class = model.attention_plot_class reduction_factor = model.reduction_factor if reduction_factor > 1: # fix the length to crop attention weight plot correctly data = copy.deepcopy(data) for idx in range(len(data)): ilen = data[idx][1]["input"][0]["shape"][0] data[idx][1]["input"][0]["shape"][ 0] = ilen // reduction_factor att_reporter = plot_class( att_vis_fn, data, args.outdir + "/att_ws", converter=converter, transform=load_cv, device=device, reverse=True, ) trainer.extend(att_reporter, trigger=eval_interval) else: att_reporter = None # Make a plot for training and validation values if hasattr(model, "module"): base_plot_keys = model.module.base_plot_keys else: base_plot_keys = model.base_plot_keys plot_keys = [] for key in base_plot_keys: plot_key = ["main/" + key, "validation/main/" + key] trainer.extend( extensions.PlotReport(plot_key, "epoch", file_name=key + ".png"), trigger=eval_interval, ) plot_keys += plot_key trainer.extend( extensions.PlotReport(plot_keys, "epoch", file_name="all_loss.png"), trigger=eval_interval, ) # Write a log of evaluation statistics for each epoch trainer.extend(extensions.LogReport(trigger=report_interval)) report_keys = ["epoch", "iteration", "elapsed_time"] + plot_keys trainer.extend(extensions.PrintReport(report_keys), trigger=report_interval) trainer.extend(extensions.ProgressBar(), trigger=report_interval) set_early_stop(trainer, args) # Again, only the major device would report if args.local_rank == 0: if args.tensorboard_dir is not None and args.tensorboard_dir != "": writer = SummaryWriter(args.tensorboard_dir) trainer.extend(TensorboardLogger(writer, att_reporter), trigger=report_interval) if use_sortagrad: trainer.extend( ShufflingEnabler([train_iter]), trigger=(args.sortagrad if args.sortagrad != -1 else args.epochs, "epoch"), ) # Run the training trainer.run() check_early_stop(trainer, args.epochs)
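The TTS training function above runs one process per GPU: it binds the model to cuda:args.local_rank, wraps it in DistributedDataParallel, and shards the data with torch.utils.data.distributed.DistributedSampler. A minimal, generic sketch of that wrapping step, not ESPnet's helper; it assumes the usual torchrun/launch environment variables (RANK, WORLD_SIZE, MASTER_ADDR, ...) are already set:

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel

def wrap_ddp(model, local_rank):
    # join the process group (env:// init reads rank/world size from the environment)
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(local_rank)
    model = model.to(torch.device("cuda", local_rank))
    # gradients are averaged across ranks on every backward pass
    return DistributedDataParallel(model, device_ids=[local_rank], output_device=local_rank)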
def train(args): """Train with the given args. Args: args (namespace): The program arguments. """ set_deterministic_pytorch(args) # check cuda availability if not torch.cuda.is_available(): logging.warning("cuda is not available") # get input and output dimension info with open(args.valid_json, "rb") as f: valid_json = json.load(f)["utts"] utts = list(valid_json.keys()) idim = int(valid_json[utts[0]]["output"][1]["shape"][1]) odim = int(valid_json[utts[0]]["output"][0]["shape"][1]) logging.info("#input dims : " + str(idim)) logging.info("#output dims: " + str(odim)) # specify model architecture model_class = dynamic_import(args.model_module) model = model_class(idim, odim, args) assert isinstance(model, MTInterface) # write model config if not os.path.exists(args.outdir): os.makedirs(args.outdir) model_conf = args.outdir + "/model.json" with open(model_conf, "wb") as f: logging.info("writing a model config file to " + model_conf) f.write( json.dumps((idim, odim, vars(args)), indent=4, ensure_ascii=False, sort_keys=True).encode("utf_8")) for key in sorted(vars(args).keys()): logging.info("ARGS: " + key + ": " + str(vars(args)[key])) reporter = model.reporter # check the use of multi-gpu if args.ngpu > 1: if args.batch_size != 0: logging.warning( "batch size is automatically increased (%d -> %d)" % (args.batch_size, args.batch_size * args.ngpu)) args.batch_size *= args.ngpu # set torch device device = torch.device("cuda" if args.ngpu > 0 else "cpu") if args.train_dtype in ("float16", "float32", "float64"): dtype = getattr(torch, args.train_dtype) else: dtype = torch.float32 model = model.to(device=device, dtype=dtype) logging.warning( "num. model params: {:,} (num. trained: {:,} ({:.1f}%))".format( sum(p.numel() for p in model.parameters()), sum(p.numel() for p in model.parameters() if p.requires_grad), sum(p.numel() for p in model.parameters() if p.requires_grad) * 100.0 / sum(p.numel() for p in model.parameters()), )) # Setup an optimizer if args.opt == "adadelta": optimizer = torch.optim.Adadelta(model.parameters(), rho=0.95, eps=args.eps, weight_decay=args.weight_decay) elif args.opt == "adam": optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) elif args.opt == "noam": from espnet.nets.pytorch_backend.transformer.optimizer import get_std_opt optimizer = get_std_opt( model.parameters(), args.adim, args.transformer_warmup_steps, args.transformer_lr, ) else: raise NotImplementedError("unknown optimizer: " + args.opt) # setup apex.amp if args.train_dtype in ("O0", "O1", "O2", "O3"): try: from apex import amp except ImportError as e: logging.error( f"You need to install apex for --train-dtype {args.train_dtype}. 
" "See https://github.com/NVIDIA/apex#linux") raise e if args.opt == "noam": model, optimizer.optimizer = amp.initialize( model, optimizer.optimizer, opt_level=args.train_dtype) else: model, optimizer = amp.initialize(model, optimizer, opt_level=args.train_dtype) use_apex = True else: use_apex = False # FIXME: TOO DIRTY HACK setattr(optimizer, "target", reporter) setattr(optimizer, "serialize", lambda s: reporter.serialize(s)) # Setup a converter converter = CustomConverter() # read json data with open(args.train_json, "rb") as f: train_json = json.load(f)["utts"] with open(args.valid_json, "rb") as f: valid_json = json.load(f)["utts"] use_sortagrad = args.sortagrad == -1 or args.sortagrad > 0 # make minibatch list (variable length) train = make_batchset( train_json, args.batch_size, args.maxlen_in, args.maxlen_out, args.minibatches, min_batch_size=args.ngpu if args.ngpu > 1 else 1, shortest_first=use_sortagrad, count=args.batch_count, batch_bins=args.batch_bins, batch_frames_in=args.batch_frames_in, batch_frames_out=args.batch_frames_out, batch_frames_inout=args.batch_frames_inout, mt=True, iaxis=1, oaxis=0, ) valid = make_batchset( valid_json, args.batch_size, args.maxlen_in, args.maxlen_out, args.minibatches, min_batch_size=args.ngpu if args.ngpu > 1 else 1, count=args.batch_count, batch_bins=args.batch_bins, batch_frames_in=args.batch_frames_in, batch_frames_out=args.batch_frames_out, batch_frames_inout=args.batch_frames_inout, mt=True, iaxis=1, oaxis=0, ) load_tr = LoadInputsAndTargets(mode="mt", load_output=True) load_cv = LoadInputsAndTargets(mode="mt", load_output=True) # hack to make batchsize argument as 1 # actual bathsize is included in a list # default collate function converts numpy array to pytorch tensor # we used an empty collate function instead which returns list train_iter = ChainerDataLoader( dataset=TransformDataset(train, lambda data: converter([load_tr(data)])), batch_size=1, num_workers=args.n_iter_processes, shuffle=not use_sortagrad, collate_fn=lambda x: x[0], ) valid_iter = ChainerDataLoader( dataset=TransformDataset(valid, lambda data: converter([load_cv(data)])), batch_size=1, shuffle=False, collate_fn=lambda x: x[0], num_workers=args.n_iter_processes, ) # Set up a trainer updater = CustomUpdater( model, args.grad_clip, {"main": train_iter}, optimizer, device, args.ngpu, False, args.accum_grad, use_apex=use_apex, ) trainer = training.Trainer(updater, (args.epochs, "epoch"), out=args.outdir) if use_sortagrad: trainer.extend( ShufflingEnabler([train_iter]), trigger=(args.sortagrad if args.sortagrad != -1 else args.epochs, "epoch"), ) # Resume from a snapshot if args.resume: logging.info("resumed from %s" % args.resume) torch_resume(args.resume, trainer) # Evaluate the model with the test dataset for each epoch if args.save_interval_iters > 0: trainer.extend( CustomEvaluator(model, {"main": valid_iter}, reporter, device, args.ngpu), trigger=(args.save_interval_iters, "iteration"), ) else: trainer.extend( CustomEvaluator(model, {"main": valid_iter}, reporter, device, args.ngpu)) # Save attention weight each epoch if args.num_save_attention > 0: # NOTE: sort it by output lengths data = sorted( list(valid_json.items())[:args.num_save_attention], key=lambda x: int(x[1]["output"][0]["shape"][0]), reverse=True, ) if hasattr(model, "module"): att_vis_fn = model.module.calculate_all_attentions plot_class = model.module.attention_plot_class else: att_vis_fn = model.calculate_all_attentions plot_class = model.attention_plot_class att_reporter = plot_class( att_vis_fn, 
data, args.outdir + "/att_ws", converter=converter, transform=load_cv, device=device, ikey="output", iaxis=1, ) trainer.extend(att_reporter, trigger=(1, "epoch")) else: att_reporter = None # Make a plot for training and validation values trainer.extend( extensions.PlotReport(["main/loss", "validation/main/loss"], "epoch", file_name="loss.png")) trainer.extend( extensions.PlotReport(["main/acc", "validation/main/acc"], "epoch", file_name="acc.png")) trainer.extend( extensions.PlotReport(["main/ppl", "validation/main/ppl"], "epoch", file_name="ppl.png")) trainer.extend( extensions.PlotReport(["main/bleu", "validation/main/bleu"], "epoch", file_name="bleu.png")) # Save best models trainer.extend( snapshot_object(model, "model.loss.best"), trigger=training.triggers.MinValueTrigger("validation/main/loss"), ) trainer.extend( snapshot_object(model, "model.acc.best"), trigger=training.triggers.MaxValueTrigger("validation/main/acc"), ) # save snapshot which contains model and optimizer states if args.save_interval_iters > 0: trainer.extend( torch_snapshot(filename="snapshot.iter.{.updater.iteration}"), trigger=(args.save_interval_iters, "iteration"), ) else: trainer.extend(torch_snapshot(), trigger=(1, "epoch")) # epsilon decay in the optimizer if args.opt == "adadelta": if args.criterion == "acc": trainer.extend( restore_snapshot(model, args.outdir + "/model.acc.best", load_fn=torch_load), trigger=CompareValueTrigger( "validation/main/acc", lambda best_value, current_value: best_value > current_value, ), ) trainer.extend( adadelta_eps_decay(args.eps_decay), trigger=CompareValueTrigger( "validation/main/acc", lambda best_value, current_value: best_value > current_value, ), ) elif args.criterion == "loss": trainer.extend( restore_snapshot(model, args.outdir + "/model.loss.best", load_fn=torch_load), trigger=CompareValueTrigger( "validation/main/loss", lambda best_value, current_value: best_value < current_value, ), ) trainer.extend( adadelta_eps_decay(args.eps_decay), trigger=CompareValueTrigger( "validation/main/loss", lambda best_value, current_value: best_value < current_value, ), ) elif args.opt == "adam": if args.criterion == "acc": trainer.extend( restore_snapshot(model, args.outdir + "/model.acc.best", load_fn=torch_load), trigger=CompareValueTrigger( "validation/main/acc", lambda best_value, current_value: best_value > current_value, ), ) trainer.extend( adam_lr_decay(args.lr_decay), trigger=CompareValueTrigger( "validation/main/acc", lambda best_value, current_value: best_value > current_value, ), ) elif args.criterion == "loss": trainer.extend( restore_snapshot(model, args.outdir + "/model.loss.best", load_fn=torch_load), trigger=CompareValueTrigger( "validation/main/loss", lambda best_value, current_value: best_value < current_value, ), ) trainer.extend( adam_lr_decay(args.lr_decay), trigger=CompareValueTrigger( "validation/main/loss", lambda best_value, current_value: best_value < current_value, ), ) # Write a log of evaluation statistics for each epoch trainer.extend( extensions.LogReport(trigger=(args.report_interval_iters, "iteration"))) report_keys = [ "epoch", "iteration", "main/loss", "validation/main/loss", "main/acc", "validation/main/acc", "main/ppl", "validation/main/ppl", "elapsed_time", ] if args.opt == "adadelta": trainer.extend( extensions.observe_value( "eps", lambda trainer: trainer.updater.get_optimizer("main"). 
param_groups[0]["eps"], ), trigger=(args.report_interval_iters, "iteration"), ) report_keys.append("eps") elif args.opt in ["adam", "noam"]: trainer.extend( extensions.observe_value( "lr", lambda trainer: trainer.updater.get_optimizer("main"). param_groups[0]["lr"], ), trigger=(args.report_interval_iters, "iteration"), ) report_keys.append("lr") if args.report_bleu: report_keys.append("main/bleu") report_keys.append("validation/main/bleu") trainer.extend( extensions.PrintReport(report_keys), trigger=(args.report_interval_iters, "iteration"), ) trainer.extend( extensions.ProgressBar(update_interval=args.report_interval_iters)) set_early_stop(trainer, args) if args.tensorboard_dir is not None and args.tensorboard_dir != "": trainer.extend( TensorboardLogger(SummaryWriter(args.tensorboard_dir), att_reporter), trigger=(args.report_interval_iters, "iteration"), ) # Run the training trainer.run() check_early_stop(trainer, args.epochs)
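The adadelta_eps_decay / adam_lr_decay extensions used above fire when CompareValueTrigger sees no improvement in the tracked validation value, and their effect is simply to rescale the corresponding optimizer hyperparameter. A hedged, generic stand-in for that rescaling step; the helper name is mine and only torch.optim semantics are assumed:

def decay_param_group_value(optimizer, key, factor):
    """Scale e.g. 'eps' (adadelta) or 'lr' (adam/noam) in every parameter group."""
    for group in optimizer.param_groups:
        group[key] *= factor

# illustrative usage: decay_param_group_value(optimizer, "eps", args.eps_decay)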
def gta_inference(args): """Generate features with a trained E2E-TTS model in teacher-forced (ground-truth-aligned) mode.""" set_deterministic_pytorch(args) # read training config idim, odim, train_args = get_model_conf(args.model, args.model_conf) # show arguments for key in sorted(vars(args).keys()): logging.info("args: " + key + ": " + str(vars(args)[key])) # define model model_class = dynamic_import(train_args.model_module) model = model_class(idim, odim, train_args) assert isinstance(model, TTSInterface) logging.info(model) # load trained model parameters logging.info("reading model parameters from " + args.model) torch_load(args.model, model) model.eval() # set torch device device = torch.device("cuda" if args.ngpu > 0 else "cpu") model = model.to(device) # read json data with open(args.json, "rb") as f: js = json.load(f)["utts"] # check directory outdir = os.path.dirname(args.out) if len(outdir) != 0 and not os.path.exists(outdir): os.makedirs(outdir) use_sortagrad = train_args.sortagrad == -1 or train_args.sortagrad > 0 if use_sortagrad: train_args.batch_sort_key = "input" batch_size = args.batch_size if batch_size is not None: assert batch_size > 0 # make minibatch list (variable length) train_batchset = make_batchset( js, batch_size, train_args.maxlen_in, train_args.maxlen_out, train_args.minibatches, batch_sort_key=train_args.batch_sort_key, min_batch_size=train_args.ngpu if train_args.ngpu > 1 else 1, shortest_first=use_sortagrad, count=train_args.batch_count, batch_bins=train_args.batch_bins, batch_frames_in=train_args.batch_frames_in, batch_frames_out=train_args.batch_frames_out, batch_frames_inout=train_args.batch_frames_inout, swap_io=True, iaxis=0, oaxis=0, ) load_tr = LoadInputsAndTargets( mode="tts", use_speaker_embedding=train_args.use_speaker_embedding, use_second_target=train_args.use_second_target, use_character_embedding=train_args.use_character_embedding, use_intonation_type=train_args.use_intonation_type, preprocess_conf=train_args.preprocess_conf, preprocess_args={"train": True}, # Switch the mode of preprocessing keep_all_data_on_mem=train_args.keep_all_data_on_mem, ) converter = CustomConverter() # hack to make batchsize argument as 1 # actual batch size is included in a list def transform(data, loader, converter): batch, utt_list = loader(data, return_uttid=True) batch = converter([batch]) return batch, utt_list train_dataset = TransformDataset( train_batchset, lambda data: transform(data, load_tr, converter)) feat_writer = kaldiio.WriteHelper( "ark,scp:{o}.ark,{o}.scp".format(o=args.out)) for batch, utt_list in train_dataset: x = batch for key in x.keys(): x[key] = x[key].to(device) outputs = model.gta_inference(**x) olens = x['olens'] batch_size = olens.shape[0] for i in range(batch_size): utt_id = utt_list[i] mlspec = outputs[i] ol = olens[i] feat_writer[utt_id] = mlspec[:ol].cpu().numpy() feat_writer.close()
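gta_inference above streams its outputs through kaldiio.WriteHelper, which writes a Kaldi ark/scp pair. A minimal, self-contained sketch of that write pattern; the file names and arrays below are made up:

import numpy as np
import kaldiio

with kaldiio.WriteHelper("ark,scp:feats.ark,feats.scp") as writer:
    writer["utt1"] = np.random.randn(120, 80).astype(np.float32)
    writer["utt2"] = np.random.randn(95, 80).astype(np.float32)
# feats.scp then maps each utterance id to its offset inside feats.ark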
def recog_v2(args): """Decode with custom models that implements ScorerInterface. Notes: The previous backend espnet.asr.pytorch_backend.asr.recog only supports E2E and RNNLM Args: args (namespace): The program arguments. See py:func:`espnet.bin.asr_recog.get_parser` for details """ logging.warning("experimental API for custom LMs is selected by --api v2") if args.batchsize > 1: raise NotImplementedError("multi-utt batch decoding is not implemented") if args.streaming_mode is not None: raise NotImplementedError("streaming mode is not implemented") if args.word_rnnlm: raise NotImplementedError("word LM is not implemented") set_deterministic_pytorch(args) model, train_args = load_trained_model(args.model) assert isinstance(model, ASRInterface) model.eval() load_inputs_and_targets = LoadInputsAndTargets( mode="asr", load_output=False, sort_in_input_length=False, preprocess_conf=train_args.preprocess_conf if args.preprocess_conf is None else args.preprocess_conf, preprocess_args={"train": False}, ) if args.rnnlm: lm_args = get_model_conf(args.rnnlm, args.rnnlm_conf) # NOTE: for a compatibility with less than 0.5.0 version models lm_model_module = getattr(lm_args, "model_module", "default") lm_class = dynamic_import_lm(lm_model_module, lm_args.backend) lm = lm_class(len(train_args.char_list), lm_args) torch_load(args.rnnlm, lm) lm.eval() else: lm = None if args.ngram_model: from espnet.nets.scorers.ngram import NgramFullScorer from espnet.nets.scorers.ngram import NgramPartScorer if args.ngram_scorer == "full": ngram = NgramFullScorer(args.ngram_model, train_args.char_list) else: ngram = NgramPartScorer(args.ngram_model, train_args.char_list) else: ngram = None scorers = model.scorers() scorers["lm"] = lm scorers["ngram"] = ngram scorers["length_bonus"] = LengthBonus(len(train_args.char_list)) weights = dict( decoder=1.0 - args.ctc_weight, ctc=args.ctc_weight, lm=args.lm_weight, ngram=args.ngram_weight, length_bonus=args.penalty, ) beam_search = BeamSearch( beam_size=args.beam_size, vocab_size=len(train_args.char_list), weights=weights, scorers=scorers, sos=model.sos, eos=model.eos, token_list=train_args.char_list, pre_beam_score_key=None if args.ctc_weight == 1.0 else "decoder", ) # TODO(karita): make all scorers batchfied if args.batchsize == 1: non_batch = [ k for k, v in beam_search.full_scorers.items() if not isinstance(v, BatchScorerInterface) ] if len(non_batch) == 0: beam_search.__class__ = BatchBeamSearch logging.info("BatchBeamSearch implementation is selected.") else: logging.warning( f"As non-batch scorers {non_batch} are found, " f"fall back to non-batch implementation." 
) if args.ngpu > 1: raise NotImplementedError("only single GPU decoding is supported") if args.ngpu == 1: device = "cuda" else: device = "cpu" dtype = getattr(torch, args.dtype) logging.info(f"Decoding device={device}, dtype={dtype}") model.to(device=device, dtype=dtype).eval() beam_search.to(device=device, dtype=dtype).eval() # read json data with open(args.recog_json, "rb") as f: js = json.load(f)["utts"] new_js = {} with torch.no_grad(): for idx, name in enumerate(js.keys(), 1): logging.info("(%d/%d) decoding " + name, idx, len(js.keys())) batch = [(name, js[name])] feat = load_inputs_and_targets(batch)[0][0] enc = model.encode(torch.as_tensor(feat).to(device=device, dtype=dtype)) nbest_hyps = beam_search( x=enc, maxlenratio=args.maxlenratio, minlenratio=args.minlenratio ) nbest_hyps = [ h.asdict() for h in nbest_hyps[: min(len(nbest_hyps), args.nbest)] ] new_js[name] = add_results_to_json( js[name], nbest_hyps, train_args.char_list ) with open(args.result_label, "wb") as f: f.write( json.dumps( {"utts": new_js}, indent=4, ensure_ascii=False, sort_keys=True ).encode("utf_8") )
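recog_v2 (like the decode/train functions elsewhere in this file set) resolves model and LM classes from strings via dynamic_import / dynamic_import_lm. A generic importlib-based equivalent, shown only to illustrate the idea; the exact string format ESPnet accepts may differ:

import importlib

def import_class(path):
    """Resolve a 'package.module:ClassName' style string to the class object."""
    module_name, cls_name = path.split(":")
    return getattr(importlib.import_module(module_name), cls_name)

# e.g. import_class("collections:OrderedDict") returns <class 'collections.OrderedDict'>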
def train(args): """Train with the given args :param Namespace args: The program arguments """ # display chainer version logging.info('chainer version = ' + chainer.__version__) set_deterministic_chainer(args) # check cuda and cudnn availability if not chainer.cuda.available: logging.warning('cuda is not available') if not chainer.cuda.cudnn_enabled: logging.warning('cudnn is not available') # get input and output dimension info with open(args.valid_json, 'rb') as f: valid_json = json.load(f)['utts'] utts = list(valid_json.keys()) idim = int(valid_json[utts[0]]['input'][0]['shape'][1]) odim = int(valid_json[utts[0]]['output'][0]['shape'][1]) logging.info('#input dims : ' + str(idim)) logging.info('#output dims: ' + str(odim)) # check attention type if args.atype not in ['noatt', 'dot', 'location']: raise NotImplementedError( 'chainer supports only noatt, dot, and location attention.') # specify attention, CTC, hybrid mode if args.mtlalpha == 1.0: mtl_mode = 'ctc' logging.info('Pure CTC mode') elif args.mtlalpha == 0.0: mtl_mode = 'att' logging.info('Pure attention mode') else: mtl_mode = 'mtl' logging.info('Multitask learning mode') # specify model architecture logging.info('import model module: ' + args.model_module) model_class = dynamic_import(args.model_module) model = model_class(idim, odim, args, flag_return=False) assert isinstance(model, ASRInterface) # write model config if not os.path.exists(args.outdir): os.makedirs(args.outdir) model_conf = args.outdir + '/model.json' with open(model_conf, 'wb') as f: logging.info('writing a model config file to ' + model_conf) f.write( json.dumps((idim, odim, vars(args)), indent=4, ensure_ascii=False, sort_keys=True).encode('utf_8')) for key in sorted(vars(args).keys()): logging.info('ARGS: ' + key + ': ' + str(vars(args)[key])) # Set gpu ngpu = args.ngpu if ngpu == 1: gpu_id = 0 # Make a specified GPU current chainer.cuda.get_device_from_id(gpu_id).use() model.to_gpu() # Copy the model to the GPU logging.info('single gpu calculation.') elif ngpu > 1: gpu_id = 0 devices = {'main': gpu_id} for gid in six.moves.xrange(1, ngpu): devices['sub_%d' % gid] = gid logging.info('multi gpu calculation (#gpus = %d).' 
% ngpu) logging.info('batch size is automatically increased (%d -> %d)' % (args.batch_size, args.batch_size * args.ngpu)) else: gpu_id = -1 logging.info('cpu calculation') # Setup an optimizer if args.opt == 'adadelta': optimizer = chainer.optimizers.AdaDelta(eps=args.eps) elif args.opt == 'adam': optimizer = chainer.optimizers.Adam() elif args.opt == 'noam': optimizer = chainer.optimizers.Adam(alpha=0, beta1=0.9, beta2=0.98, eps=1e-9) else: raise NotImplementedError('args.opt={}'.format(args.opt)) optimizer.setup(model) optimizer.add_hook(chainer.optimizer.GradientClipping(args.grad_clip)) # Setup a converter converter = CustomConverter(subsampling_factor=model.subsample[0]) # read json data with open(args.train_json, 'rb') as f: train_json = json.load(f)['utts'] with open(args.valid_json, 'rb') as f: valid_json = json.load(f)['utts'] # set up training iterator and updater load_tr = LoadInputsAndTargets( mode='asr', load_output=True, preprocess_conf=args.preprocess_conf, preprocess_args={'train': True} # Switch the mode of preprocessing ) load_cv = LoadInputsAndTargets( mode='asr', load_output=True, preprocess_conf=args.preprocess_conf, preprocess_args={'train': False} # Switch the mode of preprocessing ) use_sortagrad = args.sortagrad == -1 or args.sortagrad > 0 accum_grad = args.accum_grad if ngpu <= 1: # make minibatch list (variable length) train = make_batchset(train_json, args.batch_size, args.maxlen_in, args.maxlen_out, args.minibatches, min_batch_size=args.ngpu if args.ngpu > 1 else 1, shortest_first=use_sortagrad, count=args.batch_count, batch_bins=args.batch_bins, batch_frames_in=args.batch_frames_in, batch_frames_out=args.batch_frames_out, batch_frames_inout=args.batch_frames_inout) # hack to make batchsize argument as 1 # actual batchsize is included in a list if args.n_iter_processes > 0: train_iters = [ ToggleableShufflingMultiprocessIterator( TransformDataset(train, load_tr), batch_size=1, n_processes=args.n_iter_processes, n_prefetch=8, maxtasksperchild=20, shuffle=not use_sortagrad) ] else: train_iters = [ ToggleableShufflingSerialIterator(TransformDataset( train, load_tr), batch_size=1, shuffle=not use_sortagrad) ] # set up updater updater = CustomUpdater(train_iters[0], optimizer, converter=converter, device=gpu_id, accum_grad=accum_grad) else: if args.batch_count not in ("auto", "seq") and args.batch_size == 0: raise NotImplementedError( "--batch-count 'bin' and 'frame' are not implemented in chainer multi gpu" ) # set up minibatches train_subsets = [] for gid in six.moves.xrange(ngpu): # make subset train_json_subset = { k: v for i, (k, v) in enumerate(train_json.items()) if i % ngpu == gid } # make minibatch list (variable length) train_subsets += [ make_batchset(train_json_subset, args.batch_size, args.maxlen_in, args.maxlen_out, args.minibatches) ] # each subset must have same length for MultiprocessParallelUpdater maxlen = max([len(train_subset) for train_subset in train_subsets]) for train_subset in train_subsets: if maxlen != len(train_subset): for i in six.moves.xrange(maxlen - len(train_subset)): train_subset += [train_subset[i]] # hack to make batchsize argument as 1 # actual batchsize is included in a list if args.n_iter_processes > 0: train_iters = [ ToggleableShufflingMultiprocessIterator( TransformDataset(train_subsets[gid], load_tr), batch_size=1, n_processes=args.n_iter_processes, n_prefetch=8, maxtasksperchild=20, shuffle=not use_sortagrad) for gid in six.moves.xrange(ngpu) ] else: train_iters = [ ToggleableShufflingSerialIterator(TransformDataset( 
train_subsets[gid], load_tr), batch_size=1, shuffle=not use_sortagrad) for gid in six.moves.xrange(ngpu) ] # set up updater updater = CustomParallelUpdater(train_iters, optimizer, converter=converter, devices=devices) # Set up a trainer trainer = training.Trainer(updater, (args.epochs, 'epoch'), out=args.outdir) if use_sortagrad: trainer.extend( ShufflingEnabler(train_iters), trigger=(args.sortagrad if args.sortagrad != -1 else args.epochs, 'epoch')) if args.opt == 'noam': from espnet.nets.chainer_backend.e2e_asr_transformer import VaswaniRule trainer.extend(VaswaniRule('alpha', d=args.adim, warmup_steps=args.transformer_warmup_steps, scale=args.transformer_lr), trigger=(1, 'iteration')) # Resume from a snapshot if args.resume: chainer.serializers.load_npz(args.resume, trainer) # set up validation iterator valid = make_batchset(valid_json, args.batch_size, args.maxlen_in, args.maxlen_out, args.minibatches, min_batch_size=args.ngpu if args.ngpu > 1 else 1, count=args.batch_count, batch_bins=args.batch_bins, batch_frames_in=args.batch_frames_in, batch_frames_out=args.batch_frames_out, batch_frames_inout=args.batch_frames_inout) if args.n_iter_processes > 0: valid_iter = chainer.iterators.MultiprocessIterator( TransformDataset(valid, load_cv), batch_size=1, repeat=False, shuffle=False, n_processes=args.n_iter_processes, n_prefetch=8, maxtasksperchild=20) else: valid_iter = chainer.iterators.SerialIterator(TransformDataset( valid, load_cv), batch_size=1, repeat=False, shuffle=False) # Evaluate the model with the test dataset for each epoch trainer.extend( extensions.Evaluator(valid_iter, model, converter=converter, device=gpu_id)) # Save attention weight each epoch if args.num_save_attention > 0 and args.mtlalpha != 1.0: data = sorted(list(valid_json.items())[:args.num_save_attention], key=lambda x: int(x[1]['input'][0]['shape'][1]), reverse=True) if hasattr(model, "module"): att_vis_fn = model.module.calculate_all_attentions plot_class = model.module.attention_plot_class else: att_vis_fn = model.calculate_all_attentions plot_class = model.attention_plot_class logging.info('Using custom PlotAttentionReport') att_reporter = plot_class(att_vis_fn, data, args.outdir + "/att_ws", converter=converter, transform=load_cv, device=gpu_id) trainer.extend(att_reporter, trigger=(1, 'epoch')) else: att_reporter = None # Take a snapshot for each specified epoch trainer.extend( extensions.snapshot(filename='snapshot.ep.{.updater.epoch}'), trigger=(1, 'epoch')) # Make a plot for training and validation values trainer.extend( extensions.PlotReport([ 'main/loss', 'validation/main/loss', 'main/loss_ctc', 'validation/main/loss_ctc', 'main/loss_att', 'validation/main/loss_att' ], 'epoch', file_name='loss.png')) trainer.extend( extensions.PlotReport(['main/acc', 'validation/main/acc'], 'epoch', file_name='acc.png')) # Save best models trainer.extend( extensions.snapshot_object(model, 'model.loss.best'), trigger=training.triggers.MinValueTrigger('validation/main/loss')) if mtl_mode != 'ctc': trainer.extend( extensions.snapshot_object(model, 'model.acc.best'), trigger=training.triggers.MaxValueTrigger('validation/main/acc')) # epsilon decay in the optimizer if args.opt == 'adadelta': if args.criterion == 'acc' and mtl_mode != 'ctc': trainer.extend(restore_snapshot(model, args.outdir + '/model.acc.best'), trigger=CompareValueTrigger( 'validation/main/acc', lambda best_value, current_value: best_value > current_value)) trainer.extend(adadelta_eps_decay(args.eps_decay), trigger=CompareValueTrigger( 
'validation/main/acc', lambda best_value, current_value: best_value > current_value)) elif args.criterion == 'loss': trainer.extend(restore_snapshot(model, args.outdir + '/model.loss.best'), trigger=CompareValueTrigger( 'validation/main/loss', lambda best_value, current_value: best_value < current_value)) trainer.extend(adadelta_eps_decay(args.eps_decay), trigger=CompareValueTrigger( 'validation/main/loss', lambda best_value, current_value: best_value < current_value)) # Write a log of evaluation statistics for each epoch trainer.extend(extensions.LogReport(trigger=(REPORT_INTERVAL, 'iteration'))) report_keys = [ 'epoch', 'iteration', 'main/loss', 'main/loss_ctc', 'main/loss_att', 'validation/main/loss', 'validation/main/loss_ctc', 'validation/main/loss_att', 'main/acc', 'validation/main/acc', 'elapsed_time' ] if args.opt == 'adadelta': trainer.extend(extensions.observe_value( 'eps', lambda trainer: trainer.updater.get_optimizer('main').eps), trigger=(REPORT_INTERVAL, 'iteration')) report_keys.append('eps') trainer.extend(extensions.PrintReport(report_keys), trigger=(REPORT_INTERVAL, 'iteration')) trainer.extend(extensions.ProgressBar(update_interval=REPORT_INTERVAL)) set_early_stop(trainer, args) if args.tensorboard_dir is not None and args.tensorboard_dir != "": writer = SummaryWriter(args.tensorboard_dir) trainer.extend(TensorboardLogger(writer, att_reporter), trigger=(REPORT_INTERVAL, 'iteration')) # Run the training trainer.run() check_early_stop(trainer, args.epochs)
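The 'noam' option above (VaswaniRule with d=args.adim, warmup_steps=args.transformer_warmup_steps, scale=args.transformer_lr) follows the Transformer learning-rate schedule from Vaswani et al. (2017). A standalone sketch of that schedule, not ESPnet's implementation:

def noam_lr(step, d_model, warmup, scale=1.0):
    """Transformer schedule: linear warmup, then inverse-square-root decay."""
    step = max(step, 1)
    return scale * d_model ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)

# e.g. noam_lr(step, d_model=320, warmup=25000) rises roughly linearly for the first
# 25000 steps and decays as step**-0.5 afterwards.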
def train(args): """Train with the given args. Args: args (namespace): The program arguments. """ set_deterministic_pytorch(args) if args.num_encs > 1: args = format_mulenc_args(args) # check cuda availability if not torch.cuda.is_available(): logging.warning("cuda is not available") # get input and output dimension info with open(args.valid_json, "rb") as f: valid_json = json.load(f)["utts"] utts = list(valid_json.keys()) idim_list = [ int(valid_json[utts[0]]["input"][i]["shape"][-1]) for i in range(args.num_encs) ] odim = int(valid_json[utts[0]]["output"][0]["shape"][-1]) for i in range(args.num_encs): logging.info("stream{}: input dims : {}".format(i + 1, idim_list[i])) logging.info("#output dims: " + str(odim)) # specify attention, CTC, hybrid mode if "transducer" in args.model_module: assert args.mtlalpha == 1.0 mtl_mode = "transducer" logging.info("Pure transducer mode") if args.mtlalpha == 1.0: mtl_mode = "ctc" logging.info("Pure CTC mode") elif args.mtlalpha == 0.0: mtl_mode = "att" logging.info("Pure attention mode") else: mtl_mode = "mtl" logging.info("Multitask learning mode") if (args.enc_init is not None or args.dec_init is not None) and args.num_encs == 1: model = load_trained_modules(idim_list[0], odim, args) else: model_class = dynamic_import(args.model_module) model = model_class( idim_list[0] if args.num_encs == 1 else idim_list, odim, args ) assert isinstance(model, ASRInterface) logging.info( " Total parameter of the model = " + str(sum(p.numel() for p in model.parameters())) ) if args.rnnlm is not None: rnnlm_args = get_model_conf(args.rnnlm, args.rnnlm_conf) rnnlm = lm_pytorch.ClassifierWithState( lm_pytorch.RNNLM(len(args.char_list), rnnlm_args.layer, rnnlm_args.unit) ) torch_load(args.rnnlm, rnnlm) model.rnnlm = rnnlm # write model config if not os.path.exists(args.outdir): os.makedirs(args.outdir) model_conf = args.outdir + "/model.json" with open(model_conf, "wb") as f: logging.info("writing a model config file to " + model_conf) f.write( json.dumps( (idim_list[0] if args.num_encs == 1 else idim_list, odim, vars(args)), indent=4, ensure_ascii=False, sort_keys=True, ).encode("utf_8") ) for key in sorted(vars(args).keys()): logging.info("ARGS: " + key + ": " + str(vars(args)[key])) reporter = model.reporter # check the use of multi-gpu if args.ngpu > 1: if args.batch_size != 0: logging.warning( "batch size is automatically increased (%d -> %d)" % (args.batch_size, args.batch_size * args.ngpu) ) args.batch_size *= args.ngpu if args.num_encs > 1: # TODO(ruizhili): implement data parallel for multi-encoder setup. raise NotImplementedError( "Data parallel is not supported for multi-encoder setup." 
) # set torch device device = torch.device("cuda" if args.ngpu > 0 else "cpu") if args.train_dtype in ("float16", "float32", "float64"): dtype = getattr(torch, args.train_dtype) else: dtype = torch.float32 model = model.to(device=device, dtype=dtype) if args.freeze_mods: model, model_params = freeze_modules(model, args.freeze_mods) else: model_params = model.parameters() # Setup an optimizer if args.opt == "adadelta": optimizer = torch.optim.Adadelta( model_params, rho=0.95, eps=args.eps, weight_decay=args.weight_decay ) elif args.opt == "adam": optimizer = torch.optim.Adam(model_params, weight_decay=args.weight_decay) elif args.opt == "noam": from espnet.nets.pytorch_backend.transformer.optimizer import get_std_opt optimizer = get_std_opt( model_params, args.adim, args.transformer_warmup_steps, args.transformer_lr ) else: raise NotImplementedError("unknown optimizer: " + args.opt) # setup apex.amp if args.train_dtype in ("O0", "O1", "O2", "O3"): try: from apex import amp except ImportError as e: logging.error( f"You need to install apex for --train-dtype {args.train_dtype}. " "See https://github.com/NVIDIA/apex#linux" ) raise e if args.opt == "noam": model, optimizer.optimizer = amp.initialize( model, optimizer.optimizer, opt_level=args.train_dtype ) else: model, optimizer = amp.initialize( model, optimizer, opt_level=args.train_dtype ) use_apex = True from espnet.nets.pytorch_backend.ctc import CTC amp.register_float_function(CTC, "loss_fn") amp.init() logging.warning("register ctc as float function") else: use_apex = False # FIXME: TOO DIRTY HACK setattr(optimizer, "target", reporter) setattr(optimizer, "serialize", lambda s: reporter.serialize(s)) # Setup a converter if args.num_encs == 1: converter = CustomConverter(subsampling_factor=model.subsample[0], dtype=dtype) else: converter = CustomConverterMulEnc( [i[0] for i in model.subsample_list], dtype=dtype ) # read json data with open(args.train_json, "rb") as f: train_json = json.load(f)["utts"] with open(args.valid_json, "rb") as f: valid_json = json.load(f)["utts"] use_sortagrad = args.sortagrad == -1 or args.sortagrad > 0 # make minibatch list (variable length) train = make_batchset( train_json, args.batch_size, args.maxlen_in, args.maxlen_out, args.minibatches, min_batch_size=args.ngpu if args.ngpu > 1 else 1, shortest_first=use_sortagrad, count=args.batch_count, batch_bins=args.batch_bins, batch_frames_in=args.batch_frames_in, batch_frames_out=args.batch_frames_out, batch_frames_inout=args.batch_frames_inout, iaxis=0, oaxis=0, ) valid = make_batchset( valid_json, args.batch_size, args.maxlen_in, args.maxlen_out, args.minibatches, min_batch_size=args.ngpu if args.ngpu > 1 else 1, count=args.batch_count, batch_bins=args.batch_bins, batch_frames_in=args.batch_frames_in, batch_frames_out=args.batch_frames_out, batch_frames_inout=args.batch_frames_inout, iaxis=0, oaxis=0, ) load_tr = LoadInputsAndTargets( mode="asr", load_output=True, preprocess_conf=args.preprocess_conf, preprocess_args={"train": True}, # Switch the mode of preprocessing ) load_cv = LoadInputsAndTargets( mode="asr", load_output=True, preprocess_conf=args.preprocess_conf, preprocess_args={"train": False}, # Switch the mode of preprocessing ) # hack to make batchsize argument as 1 # actual bathsize is included in a list # default collate function converts numpy array to pytorch tensor # we used an empty collate function instead which returns list train_iter = ChainerDataLoader( dataset=TransformDataset(train, lambda data: converter([load_tr(data)])), batch_size=1, 
num_workers=args.n_iter_processes, shuffle=not use_sortagrad, collate_fn=lambda x: x[0], ) valid_iter = ChainerDataLoader( dataset=TransformDataset(valid, lambda data: converter([load_cv(data)])), batch_size=1, shuffle=False, collate_fn=lambda x: x[0], num_workers=args.n_iter_processes, ) # Set up a trainer updater = CustomUpdater( model, args.grad_clip, {"main": train_iter}, optimizer, device, args.ngpu, args.grad_noise, args.accum_grad, use_apex=use_apex, ) trainer = training.Trainer(updater, (args.epochs, "epoch"), out=args.outdir) if use_sortagrad: trainer.extend( ShufflingEnabler([train_iter]), trigger=(args.sortagrad if args.sortagrad != -1 else args.epochs, "epoch"), ) # Resume from a snapshot if args.resume: logging.info("resumed from %s" % args.resume) torch_resume(args.resume, trainer) # Evaluate the model with the test dataset for each epoch if args.save_interval_iters > 0: trainer.extend( CustomEvaluator(model, {"main": valid_iter}, reporter, device, args.ngpu), trigger=(args.save_interval_iters, "iteration"), ) else: trainer.extend( CustomEvaluator(model, {"main": valid_iter}, reporter, device, args.ngpu) ) # Save attention weight each epoch if args.num_save_attention > 0 and ( mtl_mode == "transducer" and getattr(args, "rnnt_mode", False) == "rnnt" ): data = sorted( list(valid_json.items())[: args.num_save_attention], key=lambda x: int(x[1]["input"][0]["shape"][1]), reverse=True, ) if hasattr(model, "module"): att_vis_fn = model.module.calculate_all_attentions plot_class = model.module.attention_plot_class else: att_vis_fn = model.calculate_all_attentions plot_class = model.attention_plot_class att_reporter = plot_class( att_vis_fn, data, args.outdir + "/att_ws", converter=converter, transform=load_cv, device=device, ) trainer.extend(att_reporter, trigger=(1, "epoch")) else: att_reporter = None # Make a plot for training and validation values if args.num_encs > 1: report_keys_loss_ctc = [ "main/loss_ctc{}".format(i + 1) for i in range(model.num_encs) ] + ["validation/main/loss_ctc{}".format(i + 1) for i in range(model.num_encs)] report_keys_cer_ctc = [ "main/cer_ctc{}".format(i + 1) for i in range(model.num_encs) ] + ["validation/main/cer_ctc{}".format(i + 1) for i in range(model.num_encs)] trainer.extend( extensions.PlotReport( [ "main/loss", "validation/main/loss", "main/loss_ctc", "validation/main/loss_ctc", "main/loss_att", "validation/main/loss_att", ] + ([] if args.num_encs == 1 else report_keys_loss_ctc), "epoch", file_name="loss.png", ) ) trainer.extend( extensions.PlotReport( ["main/acc", "validation/main/acc"], "epoch", file_name="acc.png" ) ) trainer.extend( extensions.PlotReport( ["main/cer_ctc", "validation/main/cer_ctc"] + ([] if args.num_encs == 1 else report_keys_loss_ctc), "epoch", file_name="cer.png", ) ) # Save best models trainer.extend( snapshot_object(model, "model.loss.best"), trigger=training.triggers.MinValueTrigger("validation/main/loss"), ) if mtl_mode not in ["ctc", "transducer"]: trainer.extend( snapshot_object(model, "model.acc.best"), trigger=training.triggers.MaxValueTrigger("validation/main/acc"), ) # save snapshot which contains model and optimizer states if args.save_interval_iters > 0: trainer.extend( torch_snapshot(filename="snapshot.iter.{.updater.iteration}"), trigger=(args.save_interval_iters, "iteration"), ) else: trainer.extend(torch_snapshot(), trigger=(1, "epoch")) # epsilon decay in the optimizer if args.opt == "adadelta": if args.criterion == "acc" and mtl_mode != "ctc": trainer.extend( restore_snapshot( model, args.outdir + 
"/model.acc.best", load_fn=torch_load ), trigger=CompareValueTrigger( "validation/main/acc", lambda best_value, current_value: best_value > current_value, ), ) trainer.extend( adadelta_eps_decay(args.eps_decay), trigger=CompareValueTrigger( "validation/main/acc", lambda best_value, current_value: best_value > current_value, ), ) elif args.criterion == "loss": trainer.extend( restore_snapshot( model, args.outdir + "/model.loss.best", load_fn=torch_load ), trigger=CompareValueTrigger( "validation/main/loss", lambda best_value, current_value: best_value < current_value, ), ) trainer.extend( adadelta_eps_decay(args.eps_decay), trigger=CompareValueTrigger( "validation/main/loss", lambda best_value, current_value: best_value < current_value, ), ) # Write a log of evaluation statistics for each epoch trainer.extend( extensions.LogReport(trigger=(args.report_interval_iters, "iteration")) ) report_keys = [ "epoch", "iteration", "main/loss", "main/loss_ctc", "main/loss_att", "validation/main/loss", "validation/main/loss_ctc", "validation/main/loss_att", "main/acc", "validation/main/acc", "main/cer_ctc", "validation/main/cer_ctc", "elapsed_time", ] + ([] if args.num_encs == 1 else report_keys_cer_ctc + report_keys_loss_ctc) if args.opt == "adadelta": trainer.extend( extensions.observe_value( "eps", lambda trainer: trainer.updater.get_optimizer("main").param_groups[0][ "eps" ], ), trigger=(args.report_interval_iters, "iteration"), ) report_keys.append("eps") if args.report_cer: report_keys.append("validation/main/cer") if args.report_wer: report_keys.append("validation/main/wer") trainer.extend( extensions.PrintReport(report_keys), trigger=(args.report_interval_iters, "iteration"), ) trainer.extend(extensions.ProgressBar(update_interval=args.report_interval_iters)) set_early_stop(trainer, args) if args.tensorboard_dir is not None and args.tensorboard_dir != "": trainer.extend( TensorboardLogger(SummaryWriter(args.tensorboard_dir), att_reporter), trigger=(args.report_interval_iters, "iteration"), ) # Run the training trainer.run() check_early_stop(trainer, args.epochs)
def recog(args): """Decode with the given args :param Namespace args: The program arguments """ set_deterministic_pytorch(args) # read training config idim, odim, train_args = get_model_conf(args.model, args.model_conf) # load trained model parameters logging.info('reading model parameters from ' + args.model) model = E2E(idim, odim, train_args) torch_load(args.model, model) model.recog_args = args # read rnnlm if args.rnnlm: rnnlm_args = get_model_conf(args.rnnlm, args.rnnlm_conf) rnnlm = lm_pytorch.ClassifierWithState( lm_pytorch.RNNLM(len(train_args.char_list), rnnlm_args.layer, rnnlm_args.unit)) torch_load(args.rnnlm, rnnlm) rnnlm.eval() else: rnnlm = None if args.word_rnnlm: rnnlm_args = get_model_conf(args.word_rnnlm, args.word_rnnlm_conf) word_dict = rnnlm_args.char_list_dict char_dict = {x: i for i, x in enumerate(train_args.char_list)} word_rnnlm = lm_pytorch.ClassifierWithState( lm_pytorch.RNNLM(len(word_dict), rnnlm_args.layer, rnnlm_args.unit)) torch_load(args.word_rnnlm, word_rnnlm) word_rnnlm.eval() if rnnlm is not None: rnnlm = lm_pytorch.ClassifierWithState( extlm_pytorch.MultiLevelLM(word_rnnlm.predictor, rnnlm.predictor, word_dict, char_dict)) else: rnnlm = lm_pytorch.ClassifierWithState( extlm_pytorch.LookAheadWordLM(word_rnnlm.predictor, word_dict, char_dict)) # gpu if args.ngpu == 1: gpu_id = range(args.ngpu) logging.info('gpu id: ' + str(gpu_id)) model.cuda() if rnnlm: rnnlm.cuda() # read json data with open(args.recog_json, 'rb') as f: js = json.load(f)['utts'] new_js = {} load_inputs_and_targets = LoadInputsAndTargets( mode='asr', load_output=False, sort_in_input_length=False, preprocess_conf=train_args.preprocess_conf if args.preprocess_conf is None else args.preprocess_conf) if args.batchsize == 0: with torch.no_grad(): for idx, name in enumerate(js.keys(), 1): logging.info('(%d/%d) decoding ' + name, idx, len(js.keys())) batch = [(name, js[name])] with using_transform_config({'train': True}): feat = load_inputs_and_targets(batch)[0][0] if args.streaming_window: logging.info( 'Using streaming recognizer with window size %d frames', args.streaming_window) se2e = StreamingE2E(e2e=model, recog_args=args, char_list=train_args.char_list, rnnlm=rnnlm) for i in range(0, feat.shape[0], args.streaming_window): logging.info('Feeding frames %d - %d', i, i + args.streaming_window) se2e.accept_input(feat[i:i + args.streaming_window]) logging.info('Running offline attention decoder') se2e.decode_with_attention_offline() logging.info('Offline attention decoder finished') nbest_hyps = se2e.retrieve_recognition() else: nbest_hyps = model.recognize(feat, args, train_args.char_list, rnnlm) new_js[name] = add_results_to_json(js[name], nbest_hyps, train_args.char_list) else: try: from itertools import zip_longest as zip_longest except Exception: from itertools import izip_longest as zip_longest def grouper(n, iterable, fillvalue=None): kargs = [iter(iterable)] * n return zip_longest(*kargs, fillvalue=fillvalue) # sort data keys = list(js.keys()) feat_lens = [js[key]['input'][0]['shape'][0] for key in keys] sorted_index = sorted(range(len(feat_lens)), key=lambda i: -feat_lens[i]) keys = [keys[i] for i in sorted_index] with torch.no_grad(): for names in grouper(args.batchsize, keys, None): names = [name for name in names if name] batch = [(name, js[name]) for name in names] with using_transform_config({'train': False}): feats = load_inputs_and_targets(batch)[0] nbest_hyps = model.recognize_batch(feats, args, train_args.char_list, rnnlm=rnnlm) for i, nbest_hyp in enumerate(nbest_hyps): 
name = names[i] new_js[name] = add_results_to_json(js[name], nbest_hyp, train_args.char_list) with open(args.result_label, 'wb') as f: f.write( json.dumps({ 'utts': new_js }, indent=4, ensure_ascii=False, sort_keys=True).encode('utf_8'))
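The streaming branch of the older recog() above feeds the recognizer fixed windows of args.streaming_window frames. A minimal stand-alone sketch of that chunking loop; feat is dummy data, and the accept_input call is left as a comment because no StreamingE2E object is constructed here:

import numpy as np

feat = np.zeros((1000, 83), dtype=np.float32)  # dummy (frames, dims) feature matrix
window = 10
for start in range(0, feat.shape[0], window):
    chunk = feat[start:start + window]
    # se2e.accept_input(chunk)  # what the real loop does for each window
    print(start, start + chunk.shape[0])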
def recog(args): """Decode with the given args. Args: args (namespace): The program arguments. """ set_deterministic_pytorch(args) # read training config idim, odim, train_args = get_model_conf(args.model, args.model_conf) # load trained model parameters logging.info('reading model parameters from ' + args.model) model = E2E(idim, odim, train_args) torch_load(args.model, model) model.recog_args = args # read rnnlm if args.rnnlm: rnnlm_args = get_model_conf(args.rnnlm, args.rnnlm_conf) rnnlm = lm_pytorch.ClassifierWithState( lm_pytorch.RNNLM(len(train_args.char_list), rnnlm_args.layer, rnnlm_args.unit)) torch_load(args.rnnlm, rnnlm) rnnlm.eval() else: rnnlm = None if args.word_rnnlm: rnnlm_args = get_model_conf(args.word_rnnlm, args.word_rnnlm_conf) word_dict = rnnlm_args.char_list_dict char_dict = {x: i for i, x in enumerate(train_args.char_list)} word_rnnlm = lm_pytorch.ClassifierWithState( lm_pytorch.RNNLM(len(word_dict), rnnlm_args.layer, rnnlm_args.unit)) torch_load(args.word_rnnlm, word_rnnlm) word_rnnlm.eval() if rnnlm is not None: rnnlm = lm_pytorch.ClassifierWithState( extlm_pytorch.MultiLevelLM(word_rnnlm.predictor, rnnlm.predictor, word_dict, char_dict)) else: rnnlm = lm_pytorch.ClassifierWithState( extlm_pytorch.LookAheadWordLM(word_rnnlm.predictor, word_dict, char_dict)) # gpu if args.ngpu == 1: gpu_id = range(args.ngpu) logging.info('gpu id: ' + str(gpu_id)) model.cuda() if rnnlm: rnnlm.cuda() # read json data with open(args.recog_json, 'rb') as f: js = json.load(f)['utts'] new_js = {} load_inputs_and_targets = LoadInputsAndTargets( mode='asr', load_output=False, sort_in_input_length=False, preprocess_conf=train_args.preprocess_conf if args.preprocess_conf is None else args.preprocess_conf) if args.batchsize == 0: with torch.no_grad(): for idx, name in enumerate(js.keys(), 1): logging.info('(%d/%d) decoding ' + name, idx, len(js.keys())) batch = [(name, js[name])] with using_transform_config({'train': True}): feat = load_inputs_and_targets(batch)[0][0] nbest_hyps = model.recognize(feat, args, train_args.char_list, rnnlm) new_js[name] = add_results_to_json(js[name], nbest_hyps, train_args.char_list) else: try: from itertools import zip_longest as zip_longest except Exception: from itertools import izip_longest as zip_longest def grouper(n, iterable, fillvalue=None): kargs = [iter(iterable)] * n return zip_longest(*kargs, fillvalue=fillvalue) # sort data if batchsize > 1 keys = list(js.keys()) if args.batchsize > 1: feat_lens = [js[key]['input'][0]['shape'][0] for key in keys] sorted_index = sorted(range(len(feat_lens)), key=lambda i: -feat_lens[i]) keys = [keys[i] for i in sorted_index] with torch.no_grad(): for names in grouper(args.batchsize, keys, None): names = [name for name in names if name] batch = [(name, js[name]) for name in names] with using_transform_config({'train': False}): feats = load_inputs_and_targets(batch)[0] nbest_hyps = model.recognize_batch(feats, args, train_args.char_list, rnnlm=rnnlm) for i, name in enumerate(names): nbest_hyp = [hyp[i] for hyp in nbest_hyps] new_js[name] = add_results_to_json(js[name], nbest_hyp, train_args.char_list) with open(args.result_label, 'wb') as f: f.write( json.dumps({ 'utts': new_js }, indent=4, ensure_ascii=False, sort_keys=True).encode('utf_8'))
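In the batch branch above, the per-utterance n-best list is built with nbest_hyp = [hyp[i] for hyp in nbest_hyps], i.e. by taking the i-th entry of every element returned by recognize_batch. A toy illustration of that regrouping with made-up strings standing in for hypothesis dicts, under the assumption that each returned element is indexed by utterance position:

nbest_hyps = [["u1-hyp1", "u2-hyp1"], ["u1-hyp2", "u2-hyp2"]]  # dummy data
names = ["u1", "u2"]
per_utt = {name: [hyp[i] for hyp in nbest_hyps] for i, name in enumerate(names)}
# per_utt == {"u1": ["u1-hyp1", "u1-hyp2"], "u2": ["u2-hyp1", "u2-hyp2"]}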
def train(args): """Train with the given args. Args: args (namespace): The program arguments. """ set_deterministic_pytorch(args) # check cuda availability if not torch.cuda.is_available(): logging.warning('cuda is not available') # get input and output dimension info with open(args.valid_json, 'rb') as f: valid_json = json.load(f)['utts'] utts = list(valid_json.keys()) idim = int(valid_json[utts[0]]['input'][0]['shape'][-1]) odim = int(valid_json[utts[0]]['output'][0]['shape'][-1]) logging.info('#input dims : ' + str(idim)) logging.info('#output dims: ' + str(odim)) # specify attention, CTC, hybrid mode if args.mtlalpha == 1.0: mtl_mode = 'ctc' logging.info('Pure CTC mode') elif args.mtlalpha == 0.0: mtl_mode = 'att' logging.info('Pure attention mode') else: mtl_mode = 'mtl' logging.info('Multitask learning mode') # specify model architecture model = E2E(idim, odim, args) subsampling_factor = model.subsample[0] if args.rnnlm is not None: rnnlm_args = get_model_conf(args.rnnlm, args.rnnlm_conf) rnnlm = lm_pytorch.ClassifierWithState( lm_pytorch.RNNLM(len(args.char_list), rnnlm_args.layer, rnnlm_args.unit)) torch.load(args.rnnlm, rnnlm) model.rnnlm = rnnlm # write model config if not os.path.exists(args.outdir): os.makedirs(args.outdir) model_conf = args.outdir + '/model.json' with open(model_conf, 'wb') as f: logging.info('writing a model config file to ' + model_conf) f.write( json.dumps((idim, odim, vars(args)), indent=4, ensure_ascii=False, sort_keys=True).encode('utf_8')) for key in sorted(vars(args).keys()): logging.info('ARGS: ' + key + ': ' + str(vars(args)[key])) reporter = model.reporter # check the use of multi-gpu if args.ngpu > 1: if args.batch_size != 0: logging.warning( 'batch size is automatically increased (%d -> %d)' % (args.batch_size, args.batch_size * args.ngpu)) args.batch_size *= args.ngpu # set torch device device = torch.device("cuda" if args.ngpu > 0 else "cpu") if args.train_dtype in ("float16", "float32", "float64"): dtype = getattr(torch, args.train_dtype) else: dtype = torch.float32 model = model.to(device=device, dtype=dtype) # Setup an optimizer if args.opt == 'adadelta': optimizer = torch.optim.Adadelta(model.parameters(), rho=0.95, eps=args.eps, weight_decay=args.weight_decay) elif args.opt == 'adam': optimizer = torch.optim.Adam(model.parameters(), weight_decay=args.weight_decay) elif args.opt == 'noam': from espnet.nets.pytorch_backend.transformer.optimizer import get_std_opt optimizer = get_std_opt(model, args.adim, args.transformer_warmup_steps, args.transformer_lr) else: raise NotImplementedError("unknown optimizer: " + args.opt) # setup apex.amp if args.train_dtype in ("O0", "O1", "O2", "O3"): try: from apex import amp except ImportError as e: logging.error( f"You need to install apex for --train-dtype {args.train_dtype}. 
" "See https://github.com/NVIDIA/apex#linux") raise e if args.opt == 'noam': model, optimizer.optimizer = amp.initialize( model, optimizer.optimizer, opt_level=args.train_dtype) else: model, optimizer = amp.initialize(model, optimizer, opt_level=args.train_dtype) use_apex = True else: use_apex = False # FIXME: TOO DIRTY HACK setattr(optimizer, "target", reporter) setattr(optimizer, "serialize", lambda s: reporter.serialize(s)) # Setup a converter converter = CustomConverter(subsampling_factor=subsampling_factor, dtype=dtype) # read json data with open(args.train_json, 'rb') as f: train_json = json.load(f)['utts'] with open(args.valid_json, 'rb') as f: valid_json = json.load(f)['utts'] use_sortagrad = args.sortagrad == -1 or args.sortagrad > 0 # make minibatch list (variable length) train = make_batchset(train_json, args.batch_size, args.maxlen_in, args.maxlen_out, args.minibatches, min_batch_size=args.ngpu if args.ngpu > 1 else 1, shortest_first=use_sortagrad, count=args.batch_count, batch_bins=args.batch_bins, batch_frames_in=args.batch_frames_in, batch_frames_out=args.batch_frames_out, batch_frames_inout=args.batch_frames_inout, iaxis=0, oaxis=-1) valid = make_batchset(valid_json, args.batch_size, args.maxlen_in, args.maxlen_out, args.minibatches, min_batch_size=args.ngpu if args.ngpu > 1 else 1, count=args.batch_count, batch_bins=args.batch_bins, batch_frames_in=args.batch_frames_in, batch_frames_out=args.batch_frames_out, batch_frames_inout=args.batch_frames_inout, iaxis=0, oaxis=-1) load_tr = LoadInputsAndTargets( mode='asr', load_output=True, preprocess_conf=args.preprocess_conf, preprocess_args={'train': True} # Switch the mode of preprocessing ) load_cv = LoadInputsAndTargets( mode='asr', load_output=True, preprocess_conf=args.preprocess_conf, preprocess_args={'train': False} # Switch the mode of preprocessing ) # hack to make batchsize argument as 1 # actual bathsize is included in a list # default collate function converts numpy array to pytorch tensor # we used an empty collate function instead which returns list train_iter = { 'main': ChainerDataLoader(dataset=TransformDataset( train, lambda data: converter([load_tr(data)])), batch_size=1, num_workers=args.n_iter_processes, shuffle=True, collate_fn=lambda x: x[0]) } valid_iter = { 'main': ChainerDataLoader(dataset=TransformDataset( valid, lambda data: converter([load_cv(data)])), batch_size=1, shuffle=False, collate_fn=lambda x: x[0], num_workers=args.n_iter_processes) } # Set up a trainer updater = CustomUpdater(model, args.grad_clip, train_iter, optimizer, device, args.ngpu, args.grad_noise, args.accum_grad, use_apex=use_apex) trainer = training.Trainer(updater, (args.epochs, 'epoch'), out=args.outdir) if use_sortagrad: trainer.extend( ShufflingEnabler([train_iter]), trigger=(args.sortagrad if args.sortagrad != -1 else args.epochs, 'epoch')) # Resume from a snapshot if args.resume: logging.info('resumed from %s' % args.resume) torch_resume(args.resume, trainer) # Evaluate the model with the test dataset for each epoch trainer.extend( CustomEvaluator(model, valid_iter, reporter, device, args.ngpu)) # Save attention weight each epoch if args.num_save_attention > 0 and args.mtlalpha != 1.0: data = sorted(list(valid_json.items())[:args.num_save_attention], key=lambda x: int(x[1]['input'][0]['shape'][1]), reverse=True) if hasattr(model, "module"): att_vis_fn = model.module.calculate_all_attentions plot_class = model.module.attention_plot_class else: att_vis_fn = model.calculate_all_attentions plot_class = 
model.attention_plot_class att_reporter = plot_class(att_vis_fn, data, args.outdir + "/att_ws", converter=converter, transform=load_cv, device=device) trainer.extend(att_reporter, trigger=(1, 'epoch')) else: att_reporter = None # Make a plot for training and validation values trainer.extend( extensions.PlotReport([ 'main/loss', 'validation/main/loss', 'main/loss_ctc', 'validation/main/loss_ctc', 'main/loss_att', 'validation/main/loss_att' ], 'epoch', file_name='loss.png')) trainer.extend( extensions.PlotReport(['main/acc', 'validation/main/acc'], 'epoch', file_name='acc.png')) trainer.extend( extensions.PlotReport(['main/cer_ctc', 'validation/main/cer_ctc'], 'epoch', file_name='cer.png')) # Save best models trainer.extend( snapshot_object(model, 'model.loss.best'), trigger=training.triggers.MinValueTrigger('validation/main/loss')) if mtl_mode != 'ctc': trainer.extend( snapshot_object(model, 'model.acc.best'), trigger=training.triggers.MaxValueTrigger('validation/main/acc')) # save snapshot which contains model and optimizer states trainer.extend(torch_snapshot(), trigger=(1, 'epoch')) # epsilon decay in the optimizer if args.opt == 'adadelta': if args.criterion == 'acc' and mtl_mode != 'ctc': trainer.extend(restore_snapshot(model, args.outdir + '/model.acc.best', load_fn=torch_load), trigger=CompareValueTrigger( 'validation/main/acc', lambda best_value, current_value: best_value > current_value)) trainer.extend(adadelta_eps_decay(args.eps_decay), trigger=CompareValueTrigger( 'validation/main/acc', lambda best_value, current_value: best_value > current_value)) elif args.criterion == 'loss': trainer.extend(restore_snapshot(model, args.outdir + '/model.loss.best', load_fn=torch_load), trigger=CompareValueTrigger( 'validation/main/loss', lambda best_value, current_value: best_value < current_value)) trainer.extend(adadelta_eps_decay(args.eps_decay), trigger=CompareValueTrigger( 'validation/main/loss', lambda best_value, current_value: best_value < current_value)) # Write a log of evaluation statistics for each epoch trainer.extend( extensions.LogReport(trigger=(args.report_interval_iters, 'iteration'))) report_keys = [ 'epoch', 'iteration', 'main/loss', 'main/loss_ctc', 'main/loss_att', 'validation/main/loss', 'validation/main/loss_ctc', 'validation/main/loss_att', 'main/acc', 'validation/main/acc', 'main/cer_ctc', 'validation/main/cer_ctc', 'elapsed_time' ] if args.opt == 'adadelta': trainer.extend(extensions.observe_value( 'eps', lambda trainer: trainer.updater.get_optimizer('main'). param_groups[0]["eps"]), trigger=(args.report_interval_iters, 'iteration')) report_keys.append('eps') if args.report_cer: report_keys.append('validation/main/cer') if args.report_wer: report_keys.append('validation/main/wer') trainer.extend(extensions.PrintReport(report_keys), trigger=(args.report_interval_iters, 'iteration')) trainer.extend( extensions.ProgressBar(update_interval=args.report_interval_iters)) set_early_stop(trainer, args) if args.tensorboard_dir is not None and args.tensorboard_dir != "": trainer.extend(TensorboardLogger(SummaryWriter(args.tensorboard_dir), att_reporter), trigger=(args.report_interval_iters, "iteration")) # Run the training trainer.run() check_early_stop(trainer, args.epochs)
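# --- Illustrative sketch (not part of the original ESPnet code) ------------
# The adadelta eps-decay schedule wired up above: CompareValueTrigger fires
# whenever the monitored validation metric fails to beat its best value, and
# then the best snapshot is restored and every parameter group's eps is
# multiplied by args.eps_decay.  The metric values below are hypothetical.
def _example_eps_decay_sketch(eps=1e-8, eps_decay=0.01):
    """Simulate the eps decay driven by a non-improving validation metric."""
    val_accs = [0.62, 0.68, 0.67, 0.71, 0.70]  # hypothetical per-epoch accuracies
    best = None
    history = []
    for acc in val_accs:
        if best is not None and acc < best:
            eps *= eps_decay  # what adadelta_eps_decay(args.eps_decay) applies
        best = acc if best is None else max(best, acc)
        history.append(eps)
    # -> eps stays at 1e-8 while acc improves, shrinks to 1e-10 and 1e-12 otherwise
    return history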
def train(args): """Train E2E VC model.""" set_deterministic_pytorch(args) # check cuda availability if not torch.cuda.is_available(): logging.warning("cuda is not available") # get input and output dimension info with open(args.valid_json, "rb") as f: valid_json = json.load(f)["utts"] utts = list(valid_json.keys()) # In TTS, this is reversed, but not in VC. See `espnet.utils.training.batchfy` idim = int(valid_json[utts[0]]["input"][0]["shape"][1]) odim = int(valid_json[utts[0]]["output"][0]["shape"][1]) logging.info("#input dims : " + str(idim)) logging.info("#output dims: " + str(odim)) # get extra input and output dimension if args.use_speaker_embedding: args.spk_embed_dim = int(valid_json[utts[0]]["input"][1]["shape"][0]) else: args.spk_embed_dim = None if args.use_second_target: args.spc_dim = int(valid_json[utts[0]]["input"][1]["shape"][1]) else: args.spc_dim = None # write model config if not os.path.exists(args.outdir): os.makedirs(args.outdir) model_conf = args.outdir + "/model.json" with open(model_conf, "wb") as f: logging.info("writing a model config file to " + model_conf) f.write( json.dumps((idim, odim, vars(args)), indent=4, ensure_ascii=False, sort_keys=True).encode("utf_8")) for key in sorted(vars(args).keys()): logging.info("ARGS: " + key + ": " + str(vars(args)[key])) # specify model architecture if args.enc_init is not None or args.dec_init is not None: model = load_trained_modules(idim, odim, args, TTSInterface) else: model_class = dynamic_import(args.model_module) model = model_class(idim, odim, args) assert isinstance(model, TTSInterface) logging.info(model) reporter = model.reporter # freeze modules, if specified if args.freeze_mods: for mod, param in model.named_parameters(): if any(mod.startswith(key) for key in args.freeze_mods): logging.info("freezing %s" % mod) param.requires_grad = False for mod, param in model.named_parameters(): if not param.requires_grad: logging.info("Frozen module %s" % mod) # check the use of multi-gpu if args.ngpu > 1: model = torch.nn.DataParallel(model, device_ids=list(range(args.ngpu))) if args.batch_size != 0: logging.warning( "batch size is automatically increased (%d -> %d)" % (args.batch_size, args.batch_size * args.ngpu)) args.batch_size *= args.ngpu # set torch device device = torch.device("cuda" if args.ngpu > 0 else "cpu") model = model.to(device) logging.warning( "num. model params: {:,} (num. 
trained: {:,} ({:.1f}%))".format( sum(p.numel() for p in model.parameters()), sum(p.numel() for p in model.parameters() if p.requires_grad), sum(p.numel() for p in model.parameters() if p.requires_grad) * 100.0 / sum(p.numel() for p in model.parameters()), )) # Setup an optimizer if args.opt == "adam": optimizer = torch.optim.Adam(model.parameters(), args.lr, eps=args.eps, weight_decay=args.weight_decay) elif args.opt == "noam": from espnet.nets.pytorch_backend.transformer.optimizer import get_std_opt optimizer = get_std_opt( model.parameters(), args.adim, args.transformer_warmup_steps, args.transformer_lr, ) elif args.opt == "lamb": from pytorch_lamb import Lamb optimizer = Lamb(model.parameters(), lr=args.lr, weight_decay=0.01, betas=(0.9, 0.999)) else: raise NotImplementedError("unknown optimizer: " + args.opt) # FIXME: TOO DIRTY HACK setattr(optimizer, "target", reporter) setattr(optimizer, "serialize", lambda s: reporter.serialize(s)) # read json data with open(args.train_json, "rb") as f: train_json = json.load(f)["utts"] with open(args.valid_json, "rb") as f: valid_json = json.load(f)["utts"] use_sortagrad = args.sortagrad == -1 or args.sortagrad > 0 if use_sortagrad: args.batch_sort_key = "input" # make minibatch list (variable length) train_batchset = make_batchset( train_json, args.batch_size, args.maxlen_in, args.maxlen_out, args.minibatches, batch_sort_key=args.batch_sort_key, min_batch_size=args.ngpu if args.ngpu > 1 else 1, shortest_first=use_sortagrad, count=args.batch_count, batch_bins=args.batch_bins, batch_frames_in=args.batch_frames_in, batch_frames_out=args.batch_frames_out, batch_frames_inout=args.batch_frames_inout, swap_io=False, iaxis=0, oaxis=0, ) valid_batchset = make_batchset( valid_json, args.batch_size, args.maxlen_in, args.maxlen_out, args.minibatches, batch_sort_key=args.batch_sort_key, min_batch_size=args.ngpu if args.ngpu > 1 else 1, count=args.batch_count, batch_bins=args.batch_bins, batch_frames_in=args.batch_frames_in, batch_frames_out=args.batch_frames_out, batch_frames_inout=args.batch_frames_inout, swap_io=False, iaxis=0, oaxis=0, ) load_tr = LoadInputsAndTargets( mode="vc", use_speaker_embedding=args.use_speaker_embedding, use_second_target=args.use_second_target, preprocess_conf=args.preprocess_conf, preprocess_args={"train": True}, # Switch the mode of preprocessing keep_all_data_on_mem=args.keep_all_data_on_mem, ) load_cv = LoadInputsAndTargets( mode="vc", use_speaker_embedding=args.use_speaker_embedding, use_second_target=args.use_second_target, preprocess_conf=args.preprocess_conf, preprocess_args={"train": False}, # Switch the mode of preprocessing keep_all_data_on_mem=args.keep_all_data_on_mem, ) converter = CustomConverter() # hack to make batchsize argument as 1 # actual bathsize is included in a list train_iter = { "main": ChainerDataLoader( dataset=TransformDataset(train_batchset, lambda data: converter([load_tr(data)])), batch_size=1, num_workers=args.num_iter_processes, shuffle=not use_sortagrad, collate_fn=lambda x: x[0], ) } valid_iter = { "main": ChainerDataLoader( dataset=TransformDataset(valid_batchset, lambda data: converter([load_cv(data)])), batch_size=1, shuffle=False, collate_fn=lambda x: x[0], num_workers=args.num_iter_processes, ) } # Set up a trainer updater = CustomUpdater(model, args.grad_clip, train_iter, optimizer, device, args.accum_grad) trainer = training.Trainer(updater, (args.epochs, "epoch"), out=args.outdir) # Resume from a snapshot if args.resume: logging.info("resumed from %s" % args.resume) 
torch_resume(args.resume, trainer) # set intervals eval_interval = (args.eval_interval_epochs, "epoch") save_interval = (args.save_interval_epochs, "epoch") report_interval = (args.report_interval_iters, "iteration") # Evaluate the model with the test dataset for each epoch trainer.extend(CustomEvaluator(model, valid_iter, reporter, device), trigger=eval_interval) # Save snapshot for each epoch trainer.extend(torch_snapshot(), trigger=save_interval) # Save best models trainer.extend( snapshot_object(model, "model.loss.best"), trigger=training.triggers.MinValueTrigger("validation/main/loss", trigger=eval_interval), ) # Save attention figure for each epoch if args.num_save_attention > 0: data = sorted( list(valid_json.items())[:args.num_save_attention], key=lambda x: int(x[1]["input"][0]["shape"][1]), reverse=True, ) if hasattr(model, "module"): att_vis_fn = model.module.calculate_all_attentions plot_class = model.module.attention_plot_class else: att_vis_fn = model.calculate_all_attentions plot_class = model.attention_plot_class att_reporter = plot_class( att_vis_fn, data, args.outdir + "/att_ws", converter=converter, transform=load_cv, device=device, reverse=True, ) trainer.extend(att_reporter, trigger=eval_interval) else: att_reporter = None # Make a plot for training and validation values if hasattr(model, "module"): base_plot_keys = model.module.base_plot_keys else: base_plot_keys = model.base_plot_keys plot_keys = [] for key in base_plot_keys: plot_key = ["main/" + key, "validation/main/" + key] trainer.extend( extensions.PlotReport(plot_key, "epoch", file_name=key + ".png"), trigger=eval_interval, ) plot_keys += plot_key trainer.extend( extensions.PlotReport(plot_keys, "epoch", file_name="all_loss.png"), trigger=eval_interval, ) # Write a log of evaluation statistics for each epoch trainer.extend(extensions.LogReport(trigger=report_interval)) report_keys = ["epoch", "iteration", "elapsed_time"] + plot_keys trainer.extend(extensions.PrintReport(report_keys), trigger=report_interval) trainer.extend(extensions.ProgressBar(), trigger=report_interval) set_early_stop(trainer, args) if args.tensorboard_dir is not None and args.tensorboard_dir != "": from torch.utils.tensorboard import SummaryWriter writer = SummaryWriter(args.tensorboard_dir) trainer.extend(TensorboardLogger(writer, att_reporter), trigger=report_interval) if use_sortagrad: trainer.extend( ShufflingEnabler([train_iter]), trigger=(args.sortagrad if args.sortagrad != -1 else args.epochs, "epoch"), ) # Run the training trainer.run() check_early_stop(trainer, args.epochs)
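# --- Illustrative sketch (not part of the original ESPnet code) ------------
# The --freeze-mods handling in the VC train() above turns off gradients for
# every parameter whose name starts with one of the given prefixes.  The toy
# module and prefix list below are hypothetical.
def _example_freeze_mods_sketch():
    """Freeze parameters by name prefix, as done for args.freeze_mods."""
    import torch

    model = torch.nn.Sequential()
    model.add_module("enc", torch.nn.Linear(80, 80))
    model.add_module("dec", torch.nn.Linear(80, 80))
    freeze_mods = ["enc"]
    for mod, param in model.named_parameters():
        if any(mod.startswith(key) for key in freeze_mods):
            param.requires_grad = False
    frozen = [mod for mod, p in model.named_parameters() if not p.requires_grad]
    # -> ['enc.weight', 'enc.bias']
    return frozen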
def train(args): """Train with the given args. Args: args (namespace): The program arguments. """ set_deterministic_pytorch(args) # check cuda availability if not torch.cuda.is_available(): logging.warning('cuda is not available') # get input and output dimension info with open(args.valid_json, 'rb') as f: valid_json = json.load(f)['utts'] utts = list(valid_json.keys()) idim = int(valid_json[utts[0]]['output'][1]['shape'][1]) odim = int(valid_json[utts[0]]['output'][0]['shape'][1]) logging.info('#input dims : ' + str(idim)) logging.info('#output dims: ' + str(odim)) # specify model architecture model_class = dynamic_import(args.model_module) model = model_class(idim, odim, args) assert isinstance(model, MTInterface) if args.rnnlm is not None: rnnlm_args = get_model_conf(args.rnnlm, args.rnnlm_conf) rnnlm = lm_pytorch.ClassifierWithState( lm_pytorch.RNNLM( len(args.char_list), rnnlm_args.layer, rnnlm_args.unit)) torch_load(args.rnnlm, rnnlm) model.rnnlm = rnnlm # write model config if not os.path.exists(args.outdir): os.makedirs(args.outdir) model_conf = args.outdir + '/model.json' with open(model_conf, 'wb') as f: logging.info('writing a model config file to ' + model_conf) f.write(json.dumps((idim, odim, vars(args)), indent=4, ensure_ascii=False, sort_keys=True).encode('utf_8')) for key in sorted(vars(args).keys()): logging.info('ARGS: ' + key + ': ' + str(vars(args)[key])) reporter = model.reporter # check the use of multi-gpu if args.ngpu > 1: model = torch.nn.DataParallel(model, device_ids=list(range(args.ngpu))) if args.batch_size != 0: logging.info('batch size is automatically increased (%d -> %d)' % ( args.batch_size, args.batch_size * args.ngpu)) args.batch_size *= args.ngpu # set torch device device = torch.device("cuda" if args.ngpu > 0 else "cpu") model = model.to(device) # Setup an optimizer if args.opt == 'adadelta': optimizer = torch.optim.Adadelta( model.parameters(), rho=0.95, eps=args.eps, weight_decay=args.weight_decay) elif args.opt == 'adam': optimizer = torch.optim.Adam(model.parameters(), weight_decay=args.weight_decay) elif args.opt == 'noam': from espnet.nets.pytorch_backend.transformer.optimizer import get_std_opt optimizer = get_std_opt(model, args.adim, args.transformer_warmup_steps, args.transformer_lr) else: raise NotImplementedError("unknown optimizer: " + args.opt) # FIXME: TOO DIRTY HACK setattr(optimizer, "target", reporter) setattr(optimizer, "serialize", lambda s: reporter.serialize(s)) # Setup a converter converter = CustomConverter(idim=idim) # read json data with open(args.train_json, 'rb') as f: train_json = json.load(f)['utts'] with open(args.valid_json, 'rb') as f: valid_json = json.load(f)['utts'] use_sortagrad = args.sortagrad == -1 or args.sortagrad > 0 # make minibatch list (variable length) train = make_batchset(train_json, args.batch_size, args.maxlen_in, args.maxlen_out, args.minibatches, min_batch_size=args.ngpu if args.ngpu > 1 else 1, shortest_first=use_sortagrad, count=args.batch_count, batch_bins=args.batch_bins, batch_frames_in=args.batch_frames_in, batch_frames_out=args.batch_frames_out, batch_frames_inout=args.batch_frames_inout, mt=True, iaxis=1, oaxis=0) valid = make_batchset(valid_json, args.batch_size, args.maxlen_in, args.maxlen_out, args.minibatches, min_batch_size=args.ngpu if args.ngpu > 1 else 1, count=args.batch_count, batch_bins=args.batch_bins, batch_frames_in=args.batch_frames_in, batch_frames_out=args.batch_frames_out, batch_frames_inout=args.batch_frames_inout, mt=True, iaxis=1, oaxis=0) load_tr = LoadInputsAndTargets( 
mode='mt', load_output=True, preprocess_conf=args.preprocess_conf, preprocess_args={'train': True} # Switch the mode of preprocessing ) load_cv = LoadInputsAndTargets( mode='mt', load_output=True, preprocess_conf=args.preprocess_conf, preprocess_args={'train': False} # Switch the mode of preprocessing ) # hack to make batchsize argument as 1 # actual bathsize is included in a list if args.n_iter_processes > 0: train_iter = ToggleableShufflingMultiprocessIterator( TransformDataset(train, load_tr), batch_size=1, n_processes=args.n_iter_processes, n_prefetch=8, maxtasksperchild=20, shuffle=not use_sortagrad) valid_iter = ToggleableShufflingMultiprocessIterator( TransformDataset(valid, load_cv), batch_size=1, repeat=False, shuffle=False, n_processes=args.n_iter_processes, n_prefetch=8, maxtasksperchild=20) else: train_iter = ToggleableShufflingSerialIterator( TransformDataset(train, load_tr), batch_size=1, shuffle=not use_sortagrad) valid_iter = ToggleableShufflingSerialIterator( TransformDataset(valid, load_cv), batch_size=1, repeat=False, shuffle=False) # Set up a trainer updater = CustomUpdater( model, args.grad_clip, train_iter, optimizer, converter, device, args.ngpu, args.accum_grad) trainer = training.Trainer( updater, (args.epochs, 'epoch'), out=args.outdir) if use_sortagrad: trainer.extend(ShufflingEnabler([train_iter]), trigger=(args.sortagrad if args.sortagrad != -1 else args.epochs, 'epoch')) # Resume from a snapshot if args.resume: logging.info('resumed from %s' % args.resume) torch_resume(args.resume, trainer) # Evaluate the model with the test dataset for each epoch trainer.extend(CustomEvaluator(model, valid_iter, reporter, converter, device)) # Save attention weight each epoch if args.num_save_attention > 0: # sort it by output lengths data = sorted(list(valid_json.items())[:args.num_save_attention], key=lambda x: int(x[1]['output'][0]['shape'][0]), reverse=True) if hasattr(model, "module"): att_vis_fn = model.module.calculate_all_attentions plot_class = model.module.attention_plot_class else: att_vis_fn = model.calculate_all_attentions plot_class = model.attention_plot_class att_reporter = plot_class( att_vis_fn, data, args.outdir + "/att_ws", converter=converter, transform=load_cv, device=device, ikey="output", iaxis=1) trainer.extend(att_reporter, trigger=(1, 'epoch')) else: att_reporter = None # Make a plot for training and validation values trainer.extend(extensions.PlotReport(['main/loss', 'validation/main/loss', 'main/loss_att', 'validation/main/loss_att'], 'epoch', file_name='loss.png')) trainer.extend(extensions.PlotReport(['main/acc', 'validation/main/acc'], 'epoch', file_name='acc.png')) trainer.extend(extensions.PlotReport(['main/ppl', 'validation/main/ppl'], 'epoch', file_name='ppl.png')) # Save best models trainer.extend(snapshot_object(model, 'model.loss.best'), trigger=training.triggers.MinValueTrigger('validation/main/loss')) trainer.extend(snapshot_object(model, 'model.acc.best'), trigger=training.triggers.MaxValueTrigger('validation/main/acc')) # save snapshot which contains model and optimizer states trainer.extend(torch_snapshot(), trigger=(1, 'epoch')) # epsilon decay in the optimizer if args.opt == 'adadelta': if args.criterion == 'acc': trainer.extend(restore_snapshot(model, args.outdir + '/model.acc.best', load_fn=torch_load), trigger=CompareValueTrigger( 'validation/main/acc', lambda best_value, current_value: best_value > current_value)) trainer.extend(adadelta_eps_decay(args.eps_decay), trigger=CompareValueTrigger( 'validation/main/acc', lambda 
best_value, current_value: best_value > current_value)) elif args.criterion == 'loss': trainer.extend(restore_snapshot(model, args.outdir + '/model.loss.best', load_fn=torch_load), trigger=CompareValueTrigger( 'validation/main/loss', lambda best_value, current_value: best_value < current_value)) trainer.extend(adadelta_eps_decay(args.eps_decay), trigger=CompareValueTrigger( 'validation/main/loss', lambda best_value, current_value: best_value < current_value)) # Write a log of evaluation statistics for each epoch trainer.extend(extensions.LogReport(trigger=(args.report_interval_iters, 'iteration'))) report_keys = ['epoch', 'iteration', 'main/loss', 'validation/main/loss', 'main/acc', 'validation/main/acc', 'main/ppl', 'validation/main/ppl', 'elapsed_time'] if args.opt == 'adadelta': trainer.extend(extensions.observe_value( 'eps', lambda trainer: trainer.updater.get_optimizer('main').param_groups[0]["eps"]), trigger=(args.report_interval_iters, 'iteration')) report_keys.append('eps') trainer.extend(extensions.PrintReport( report_keys), trigger=(args.report_interval_iters, 'iteration')) trainer.extend(extensions.ProgressBar(update_interval=args.report_interval_iters)) set_early_stop(trainer, args) if args.tensorboard_dir is not None and args.tensorboard_dir != "": writer = SummaryWriter(args.tensorboard_dir) trainer.extend(TensorboardLogger(writer, att_reporter), trigger=(args.report_interval_iters, 'iteration')) # Run the training trainer.run() check_early_stop(trainer, args.epochs)
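# --- Illustrative sketch (not part of the original ESPnet code) ------------
# How the attention-plot utterances are selected in the MT train() above:
# take the first num_save_attention entries of the validation json and sort
# them by output length, longest first.  The json snippet is a minimal
# hypothetical stand-in for the real data.json "utts" structure.
def _example_att_plot_selection_sketch(num_save_attention=2):
    """Pick and length-sort validation utterances for attention plotting."""
    valid_json = {
        "utt_a": {"output": [{"shape": [13, 52]}]},
        "utt_b": {"output": [{"shape": [27, 52]}]},
        "utt_c": {"output": [{"shape": [8, 52]}]},
    }
    data = sorted(list(valid_json.items())[:num_save_attention],
                  key=lambda x: int(x[1]['output'][0]['shape'][0]),
                  reverse=True)
    # -> [('utt_b', ...), ('utt_a', ...)]
    return [name for name, _ in data]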
def recog(args): """Decode with the given args. Args: args (namespace): The program arguments. """ # display chainer version logging.info('chainer version = ' + chainer.__version__) set_deterministic_chainer(args) # read training config idim, odim, train_args = get_model_conf(args.model, args.model_conf) for key in sorted(vars(args).keys()): logging.info('ARGS: ' + key + ': ' + str(vars(args)[key])) # specify model architecture logging.info('reading model parameters from ' + args.model) # To be compatible with v.0.3.0 models if hasattr(train_args, "model_module"): model_module = train_args.model_module else: model_module = "espnet.nets.chainer_backend.e2e_asr:E2E" model_class = dynamic_import(model_module) model = model_class(idim, odim, train_args) assert isinstance(model, ASRInterface) chainer_load(args.model, model) # read rnnlm if args.rnnlm: rnnlm_args = get_model_conf(args.rnnlm, args.rnnlm_conf) rnnlm = lm_chainer.ClassifierWithState(lm_chainer.RNNLM( len(train_args.char_list), rnnlm_args.layer, rnnlm_args.unit)) chainer_load(args.rnnlm, rnnlm) else: rnnlm = None if args.word_rnnlm: rnnlm_args = get_model_conf(args.word_rnnlm, args.word_rnnlm_conf) word_dict = rnnlm_args.char_list_dict char_dict = {x: i for i, x in enumerate(train_args.char_list)} word_rnnlm = lm_chainer.ClassifierWithState(lm_chainer.RNNLM( len(word_dict), rnnlm_args.layer, rnnlm_args.unit)) chainer_load(args.word_rnnlm, word_rnnlm) if rnnlm is not None: rnnlm = lm_chainer.ClassifierWithState( extlm_chainer.MultiLevelLM(word_rnnlm.predictor, rnnlm.predictor, word_dict, char_dict)) else: rnnlm = lm_chainer.ClassifierWithState( extlm_chainer.LookAheadWordLM(word_rnnlm.predictor, word_dict, char_dict)) # read json data with open(args.recog_json, 'rb') as f: js = json.load(f)['utts'] load_inputs_and_targets = LoadInputsAndTargets( mode='asr', load_output=False, sort_in_input_length=False, preprocess_conf=train_args.preprocess_conf if args.preprocess_conf is None else args.preprocess_conf, preprocess_args={'train': False} # Switch the mode of preprocessing ) # decode each utterance new_js = {} with chainer.no_backprop_mode(): for idx, name in enumerate(js.keys(), 1): logging.info('(%d/%d) decoding ' + name, idx, len(js.keys())) batch = [(name, js[name])] feat = load_inputs_and_targets(batch)[0][0] nbest_hyps = model.recognize(feat, args, train_args.char_list, rnnlm) new_js[name] = add_results_to_json(js[name], nbest_hyps, train_args.char_list) with open(args.result_label, 'wb') as f: f.write(json.dumps({'utts': new_js}, indent=4, ensure_ascii=False, sort_keys=True).encode('utf_8'))
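# --- Illustrative sketch (not part of the original ESPnet code) ------------
# The word-level LM path above needs two lookup tables: the word LM ships
# with a word->id dict (char_list_dict), while a char->id dict is rebuilt
# from the training character list so the MultiLevelLM / LookAheadWordLM
# wrappers can relate character-level hypotheses to word boundaries.  The
# character list below is a hypothetical example.
def _example_lm_dict_sketch():
    """Rebuild the char->id mapping used by the word-level LM wrappers."""
    char_list = ["<blank>", "<unk>", "<space>", "a", "b", "c", "<eos>"]
    char_dict = {x: i for i, x in enumerate(char_list)}
    # -> {'<blank>': 0, '<unk>': 1, '<space>': 2, 'a': 3, ...}
    return char_dict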
def ctc_align(args, device): """ESPnet-specific interface for CTC segmentation. Parses configuration, infers the CTC posterior probabilities, and then aligns start and end of utterances using CTC segmentation. Results are written to the output file given in the args. :param args: given configuration :param device: for inference; one of ['cuda', 'cpu'] :return: 0 on success """ model, train_args = load_trained_model(args.model) assert isinstance(model, ASRInterface) load_inputs_and_targets = LoadInputsAndTargets( mode="asr", load_output=True, sort_in_input_length=False, preprocess_conf=train_args.preprocess_conf if args.preprocess_conf is None else args.preprocess_conf, preprocess_args={"train": False}, ) logging.info(f"Decoding device={device}") # Warn for nets with high memory consumption on long audio files if hasattr(model, "enc"): encoder_module = model.enc.__class__.__module__ elif hasattr(model, "encoder"): encoder_module = model.encoder.__class__.__module__ else: encoder_module = "Unknown" logging.info(f"Encoder module: {encoder_module}") logging.info(f"CTC module: {model.ctc.__class__.__module__}") if "rnn" not in encoder_module: logging.warning( "No BLSTM model detected; memory consumption may be high.") model.to(device=device).eval() # read audio and text json data with open(args.data_json, "rb") as f: js = json.load(f)["utts"] with open(args.utt_text, "r", encoding="utf-8") as f: lines = f.readlines() i = 0 text = {} segment_names = {} for name in js.keys(): text_per_audio = [] segment_names_per_audio = [] while i < len(lines) and lines[i].startswith(name): text_per_audio.append(lines[i][lines[i].find(" ") + 1:]) segment_names_per_audio.append(lines[i][:lines[i].find(" ")]) i += 1 text[name] = text_per_audio segment_names[name] = segment_names_per_audio # apply configuration config = CtcSegmentationParameters() subsampling_factor = 1 frame_duration_ms = 10 if args.subsampling_factor is not None: subsampling_factor = args.subsampling_factor if args.frame_duration is not None: frame_duration_ms = args.frame_duration # Backwards compatibility to ctc_segmentation <= 1.5.3 if hasattr(config, "index_duration"): config.index_duration = frame_duration_ms * subsampling_factor / 1000 else: config.subsampling_factor = subsampling_factor config.frame_duration_ms = frame_duration_ms if args.min_window_size is not None: config.min_window_size = args.min_window_size if args.max_window_size is not None: config.max_window_size = args.max_window_size config.char_list = train_args.char_list if args.use_dict_blank is not None: logging.warning("The option --use-dict-blank is deprecated. If needed," " use --set-blank instead.") if args.set_blank is not None: config.blank = args.set_blank if args.replace_spaces_with_blanks is not None: if args.replace_spaces_with_blanks: config.replace_spaces_with_blanks = True else: config.replace_spaces_with_blanks = False if args.gratis_blank: config.blank_transition_cost_zero = True if config.blank_transition_cost_zero and args.replace_spaces_with_blanks: logging.error( "Blanks are inserted between words, and also the transition cost of blank" " is zero. 
This configuration may lead to misalignments!") if args.scoring_length is not None: config.score_min_mean_over_L = args.scoring_length logging.info( f"Frame timings: {frame_duration_ms}ms * {subsampling_factor}") # Iterate over audio files to decode and align for idx, name in enumerate(js.keys(), 1): logging.info("(%d/%d) Aligning " + name, idx, len(js.keys())) batch = [(name, js[name])] feat, label = load_inputs_and_targets(batch) feat = feat[0] with torch.no_grad(): # Encode input frames enc_output = model.encode( torch.as_tensor(feat).to(device)).unsqueeze(0) # Apply ctc layer to obtain log character probabilities lpz = model.ctc.log_softmax(enc_output)[0].cpu().numpy() # Prepare the text for aligning ground_truth_mat, utt_begin_indices = prepare_text(config, text[name]) # Align using CTC segmentation timings, char_probs, state_list = ctc_segmentation( config, lpz, ground_truth_mat) logging.debug(f"state_list = {state_list}") # Obtain list of utterances with time intervals and confidence score segments = determine_utterance_segments(config, utt_begin_indices, char_probs, timings, text[name]) # Write to "segments" file for i, boundary in enumerate(segments): utt_segment = (f"{segment_names[name][i]} {name} {boundary[0]:.2f}" f" {boundary[1]:.2f} {boundary[2]:.9f}\n") args.output.write(utt_segment) return 0
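# --- Illustrative sketch (not part of the original ESPnet code) ------------
# The timing configuration in ctc_align() above: with 10 ms input frames and
# an encoder subsampling factor of 4, one CTC output index covers 40 ms, so
# index_duration converts CTC frame indices into seconds.  The numbers below
# are hypothetical.
def _example_index_duration_sketch(frame_duration_ms=10, subsampling_factor=4):
    """Convert a CTC output index into a time in seconds."""
    index_duration = frame_duration_ms * subsampling_factor / 1000  # seconds per CTC index
    ctc_index = 125
    # -> 125 * 0.04 = 5.0 seconds into the audio
    return ctc_index * index_duration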