def load_trained_model(model_path, training=True):
    """Load the trained model for recognition.

    Args:
        model_path (str): Path to model.***.best
        training (bool): Training mode specification for transducer model.

    """
    idim, odim, train_args = get_model_conf(
        model_path, os.path.join(os.path.dirname(model_path), "model.json"))

    logging.warning("reading model parameters from " + model_path)

    if hasattr(train_args, "model_module"):
        model_module = train_args.model_module
    else:
        model_module = "espnet.nets.pytorch_backend.e2e_asr:E2E"
    # CTC Loss is not needed, default to builtin to prevent import errors
    if hasattr(train_args, "ctc_type"):
        train_args.ctc_type = "builtin"

    model_class = dynamic_import(model_module)

    if "transducer" in model_module:
        model = model_class(idim, odim, train_args, training)
    else:
        model = model_class(idim, odim, train_args)
    torch_load(model_path, model)

    return model, train_args
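
A minimal usage sketch for the loader above, assuming the usual exp/.../results checkpoint layout (the path here is hypothetical) with model.json sitting next to the checkpoint:

import torch

# Hypothetical checkpoint path; get_model_conf() looks for model.json in the
# same directory.
model, train_args = load_trained_model(
    "exp/train_asr/results/model.acc.best", training=False)
model.eval()  # disable dropout etc. for recognition

n_params = sum(p.numel() for p in model.parameters())
print("loaded {} with {} parameters".format(type(model).__name__, n_params))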
Example 2
    def __init__(self, config):

        logging.info("Worker2 init start:")

        self.tmp_dir = config['tmp']
        os.makedirs(self.tmp_dir, exist_ok=True)
        self.debug_mode = config['debug_mode']

        self.phone2id = frontend.LoadDictionary(config['dict_path'])

        self.idim, odim, train_args = get_model_conf(
            config['acoustic_model_path'])

        model_class = dynamic_import(train_args.model_module)
        model = model_class(self.idim, odim, train_args)
        torch_load(config['acoustic_model_path'], model)

        logging.info("Model Load Done {}".format(
            config['acoustic_model_path']))

        self.model = model.eval().to(device)
        self.inference_args = Namespace(
            **{
                "threshold": config['threshold'],
                "minlenratio": config['minlenratio'],
                "maxlenratio": config['maxlenratio']
            })

        logging.info("TacotronLPCNetWorker Init Done")

        self.num_chunk_frame = config["num_chunk_frame"]
        logging.info("Chunk size: {}".format(self.num_chunk_frame))
        self.overlap = config["overlap"]
        logging.info("Overlap: {}".format(self.overlap))
Example 3
def get_trained_model_state_dict(model_path):
    """Extract the trained model state dict for pre-initialization.

    Args:
        model_path (str): Path to model.***.best

    Return:
        model.state_dict() (OrderedDict): the loaded model state_dict
        (the LM state dict when "rnnlm" appears in model_path)

    """
    conf_path = os.path.join(os.path.dirname(model_path), "model.json")
    if "rnnlm" in model_path:
        logging.warning("reading model parameters from %s", model_path)

        return get_lm_state_dict(torch.load(model_path))

    idim, odim, args = get_model_conf(model_path, conf_path)

    logging.warning("reading model parameters from " + model_path)

    if hasattr(args, "model_module"):
        model_module = args.model_module
    else:
        model_module = "espnet.nets.pytorch_backend.e2e_asr:E2E"

    model_class = dynamic_import(model_module)
    model = model_class(idim, odim, args)
    torch_load(model_path, model)
    assert (isinstance(model, MTInterface) or isinstance(model, ASRInterface)
            or isinstance(model, TTSInterface))

    return model.state_dict()
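
A hedged sketch of the pre-initialization this helper enables: copy every tensor whose name and shape match into a freshly built model via PyTorch's non-strict loading (new_model is an assumed torch.nn.Module, the path is hypothetical):

source_state = get_trained_model_state_dict(
    "exp/pretrained/results/model.acc.best")

# load_state_dict(strict=False) skips keys the new model does not have and
# reports both directions of mismatch.
missing, unexpected = new_model.load_state_dict(source_state, strict=False)
print("missing keys:", missing)
print("unexpected keys:", unexpected)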
Example 4
    def __init__(self, device='cpu'):
        dict_path = "downloads/data/lang_1char/train_no_dev_units.txt"
        model_path = "downloads/exp/train_no_dev_pytorch_train_pytorch_tacotron2.v3/results/model.last1.avg.best"
        vocoder_path = "downloads/ljspeech.parallel_wavegan.v1/checkpoint-400000steps.pkl"
        vocoder_conf = "downloads/ljspeech.parallel_wavegan.v1/config.yml"

        device = torch.device(device)

        idim, odim, train_args = get_model_conf(model_path)
        model_class = dynamic_import(train_args.model_module)
        model = model_class(idim, odim, train_args)
        torch_load(model_path, model)
        model = model.eval().to(device)
        inference_args = Namespace(**{"threshold": 0.5, "minlenratio": 0.0, "maxlenratio": 10.0})

        with open(vocoder_conf) as f:
            config = yaml.load(f, Loader=yaml.Loader)
        vocoder = ParallelWaveGANGenerator(**config["generator_params"])
        vocoder.load_state_dict(torch.load(vocoder_path, map_location="cpu")["model"]["generator"])
        vocoder.remove_weight_norm()
        vocoder = vocoder.eval().to(device)

        with open(dict_path) as f:
            lines = f.readlines()
        lines = [line.replace("\n", "").split(" ") for line in lines]
        char_to_id = {c: int(i) for c, i in lines}

        self.device = device
        self.char_to_id = char_to_id
        self.idim = idim
        self.model = model
        self.inference_args = inference_args
        self.config = config
        self.vocoder = vocoder
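
The dictionary built at the end of __init__ maps symbols to ids; a minimal sketch of the text frontend it implies (that the eos id is idim - 1 follows the Tacotron example later in this listing; the uppercasing is an assumption about the dictionary's casing):

import torch

def text_to_ids(text, char_to_id, idim):
    # drop characters missing from the dictionary, then append <eos>
    ids = [char_to_id[c] for c in text.upper() if c in char_to_id]
    ids.append(idim - 1)
    return torch.LongTensor(ids)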
Example 5
def load_trained_model(model_path):
    """Load the trained model.

    Args:
        model_path(str): Path to model.***.best

    """
    # read training config
    idim, odim, train_args = get_model_conf(
        model_path, os.path.join(os.path.dirname(model_path), 'model.json'))

    # load trained model parameters
    logging.info('reading model parameters from ' + model_path)
    # To be compatible with v.0.3.0 models
    if hasattr(train_args, "model_module"):
        model_module = train_args.model_module
    else:
        model_module = "espnet.nets.pytorch_backend.e2e_asr_ftransformer:E2E"
    model_class = dynamic_import(model_module)
    model = model_class(idim, odim, train_args)
    # Load the full pickled module (not just its state dict) and move it to CPU
    model = torch.load(model_path).cpu()

    return model, train_args
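
One caveat on deserializing the whole pickled module as above: a checkpoint saved on GPU fails to load on a CPU-only machine unless its storages are remapped. A safer variant of the same call:

# map_location remaps any GPU storages at load time; .cpu() is then redundant
model = torch.load(model_path, map_location="cpu")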
Example 6
def decode(args):
    """Decode with E2E-TTS model."""
    set_deterministic_pytorch(args)
    # read training config
    idim, odim, train_args = get_model_conf(args.model, args.model_conf)

    # show arguments
    for key in sorted(vars(args).keys()):
        logging.info('args: ' + key + ': ' + str(vars(args)[key]))

    # define model
    model_class = dynamic_import(train_args.model_module)
    model = model_class(idim, odim, train_args)
    assert isinstance(model, TTSInterface)
    logging.info(model)

    # load trained model parameters
    logging.info('reading model parameters from ' + args.model)
    torch_load(args.model, model)
    model.eval()

    # set torch device
    device = torch.device("cuda" if args.ngpu > 0 else "cpu")
    model = model.to(device)

    # generate speaker embeddings
    SequenceGenerator(model.resnet_spkid, device, args.feat_scp, args.out_file)
Example 7
def load_trained_model(model_path):
    """Load the trained model.

    Args:
        model_path(str): Path to model.***.best

    """
    # read training config
    idim, odim, train_args = get_model_conf(
        model_path, os.path.join(os.path.dirname(model_path), 'model.json'))

    # load trained model parameters
    logging.info('reading model parameters from ' + model_path)
    # To be compatible with v.0.3.0 models
    if hasattr(train_args, "model_module"):
        model_module = train_args.model_module
    else:
        model_module = "espnet.nets.pytorch_backend.e2e_asr:E2E"
    model_class = dynamic_import(model_module)
    model = model_class(idim, odim, train_args)

    if hasattr(train_args,
               "slu_model") and train_args.slu_model and train_args.slu_loss:
        if hasattr(train_args, 'slu_pooling'):
            model.add_slu(train_args.slu_model, train_args.slu_loss,
                          train_args.slu_tune_weights, train_args.slu_pooling)
        else:
            model.add_slu(train_args.slu_model, train_args.slu_loss,
                          train_args.slu_tune_weights, '')

    torch_load(model_path, model)

    return model, train_args
Example 8
def get_trained_model_state_dict(model_path):
    """Extract the trained model state dict for pre-initialization.

    Args:
        model_path (str): Path to model.***.best

    Return:
        model.state_dict() (OrderedDict): the loaded model state_dict
        (str): Type of model. Either ASR/MT or LM.

    """
    conf_path = os.path.join(os.path.dirname(model_path), 'model.json')
    if 'rnnlm' in model_path:
        logging.warning('reading model parameters from %s', model_path)

        return torch.load(model_path), 'lm'

    idim, odim, args = get_model_conf(model_path, conf_path)

    logging.warning('reading model parameters from ' + model_path)

    if hasattr(args, "model_module"):
        model_module = args.model_module
    else:
        model_module = "espnet.nets.pytorch_backend.e2e_asr:E2E"

    model_class = dynamic_import(model_module)
    model = model_class(idim, odim, args)
    torch_load(model_path, model)
    assert isinstance(model, MTInterface) or isinstance(model, ASRInterface)

    return model.state_dict(), 'asr-mt'
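
A hypothetical caller for this variant, branching on the returned type string ('lm' or 'asr-mt'); lm_model and enc_dec_model stand in for whatever modules are being pre-initialized:

state_dict, model_type = get_trained_model_state_dict(model_path)
if model_type == "lm":
    lm_model.load_state_dict(state_dict)
else:  # "asr-mt"
    enc_dec_model.load_state_dict(state_dict, strict=False)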
Example 9
def viterbi_decode(args):
    set_deterministic_pytorch(args)
    idim, odim, train_args = get_model_conf(
        args.model, os.path.join(os.path.dirname(args.model), 'model.json'))
    model_class = dynamic_import(train_args.model_module)
    model = model_class(idim, odim, train_args)
    if args.model is not None:
        load_params = dict(torch.load(args.model))
        if 'model' in load_params:
            load_params = dict(load_params['model'])
        if 'state_dict' in load_params:
            load_params = dict(load_params['state_dict'])
        model_params = dict(model.named_parameters())
        for k, v in load_params.items():
            k = k.replace('module.', '')
            if k in model_params and v.size() == model_params[k].size():
                model_params[k].data = v.data
                logging.warning('load parameters {}'.format(k))
    model.recog_args = args

    if args.ngpu == 1:
        gpu_id = list(range(args.ngpu))
        logging.info('gpu id: ' + str(gpu_id))
        model.cuda()

    with open(args.recog_json, 'rb') as f:
        js = json.load(f)['utts']
    new_js = {}

    load_inputs_and_targets = LoadInputsAndTargets(
        mode='asr',
        load_output=False,
        sort_in_input_length=False,
        preprocess_conf=train_args.preprocess_conf
        if args.preprocess_conf is None else args.preprocess_conf,
        preprocess_args={'train': False})

    with torch.no_grad():
        for idx, name in enumerate(js.keys(), 1):
            logging.info('(%d/%d) decoding ' + name, idx, len(js.keys()))
            batch = [(name, js[name])]
            feat = load_inputs_and_targets(batch)
            y = np.fromiter(map(int,
                                batch[0][1]['output'][0]['tokenid'].split()),
                            dtype=np.int64)
            align = model.viterbi_decode(feat[0][0], y)
            assert len(align) == len(y)
            new_js[name] = js[name]
            new_js[name]['output'][0]['align'] = ' '.join(
                [str(i) for i in list(align)])

    with open(args.result_label, 'wb') as f:
        f.write(
            json.dumps({
                'utts': new_js
            },
                       indent=4,
                       ensure_ascii=False,
                       sort_keys=True).encode('utf_8'))
Example 10
def decode(args):
    '''RUN DECODING'''
    # read training config
    idim, odim, train_args = get_model_conf(args.model, args.model_conf)

    # show arguments
    for key in sorted(vars(args).keys()):
        logging.info('ARGS: ' + key + ': ' + str(vars(args)[key]))

    # define model
    tacotron2 = Tacotron2(idim, odim, train_args)
    eos = str(tacotron2.idim - 1)

    # load trained model parameters
    logging.info('reading model parameters from ' + args.model)
    torch_load(args.model, tacotron2)
    tacotron2.eval()

    # set torch device
    device = torch.device("cuda" if args.ngpu > 0 else "cpu")
    tacotron2 = tacotron2.to(device)

    # read json data
    with open(args.json, 'rb') as f:
        js = json.load(f)['utts']

    # check directory
    outdir = os.path.dirname(args.out)
    if len(outdir) != 0 and not os.path.exists(outdir):
        os.makedirs(outdir)

    # write to ark and scp file (see https://github.com/vesis84/kaldi-io-for-python)
    arkscp = 'ark:| copy-feats --print-args=false ark:- ark,scp:%s.ark,%s.scp' % (
        args.out, args.out)
    with torch.no_grad(), kaldi_io_py.open_or_fd(arkscp, 'wb') as f:
        for idx, utt_id in enumerate(js.keys()):
            x = js[utt_id]['output'][0]['tokenid'].split() + [eos]
            x = np.fromiter(map(int, x), dtype=np.int64)
            x = torch.LongTensor(x).to(device)

            # get speaker embedding
            if train_args.use_speaker_embedding:
                spemb = kaldi_io_py.read_vec_flt(
                    js[utt_id]['input'][1]['feat'])
                spemb = torch.FloatTensor(spemb).to(device)
            else:
                spemb = None

            # decode and write
            outs, _, _ = tacotron2.inference(x, args, spemb)
            if outs.size(0) == x.size(0) * args.maxlenratio:
                logging.warning("output length reaches maximum length (%s)." %
                                utt_id)
            logging.info(
                '(%d/%d) %s (size:%d->%d)' %
                (idx + 1, len(js.keys()), utt_id, x.size(0), outs.size(0)))
            kaldi_io_py.write_mat(f, outs.cpu().numpy(), utt_id)
Example 11
    def __make_lm_module(self, lm_path: str):
        if len(lm_path) <= 0:
            return None
        lm_config = get_model_conf(model_path=lm_path)
        lm_model_module = getattr(lm_config, "model_module", "default")
        lm_class = dynamic_import_lm(lm_model_module, lm_config.backend)
        lm = lm_class(len(self.train_args.char_list), lm_config)
        torch_load(lm_path, lm)
        lm.eval()
        return lm
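
Hypothetical usage from inside the owning class; an empty path disables LM fusion, otherwise the returned module can be handed to the recognizer (the recognize call is illustrative, not taken from this listing):

lm = self.__make_lm_module("exp/train_rnnlm/rnnlm.model.best")
# e.g. nbest = model.recognize(feat, recog_args, self.train_args.char_list, lm)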
Example 12
    def __init__(self):
        # TODO: Move this into a config file and offer a choice of models
        self.trans_type = "phn"
        dict_path = "/home/ntrusse2/espnet/downloads/en/fastspeech/data/lang_1phn/phn_train_no_dev_units.txt"
        model_path = "/home/ntrusse2/espnet/downloads/en/fastspeech/exp/phn_train_no_dev_pytorch_train_tacotron2.v3_fastspeech.v4.single/results/model.last1.avg.best"
        vocoder_path = "/home/ntrusse2/espnet/downloads/en/parallel_wavegan/ljspeech.parallel_wavegan.v2/checkpoint-400000steps.pkl"
        vocoder_conf = "/home/ntrusse2/espnet/downloads/en/parallel_wavegan/ljspeech.parallel_wavegan.v2/config.yml"

        # Copied straight from the ESPnet demo examples
        self.device = torch.device("cuda")
        print("Loading Torch Model...")
        self.idim, odim, train_args = get_model_conf(model_path)
        model_class = dynamic_import(train_args.model_module)
        model = model_class(self.idim, odim, train_args)
        torch_load(model_path, model)
        self.model = model.eval().to(self.device)
        self.inference_args = Namespace(
            **{
                "threshold": 0.5,
                "minlenratio": 0.0,
                "maxlenratio": 10.0,
                "use_attention_constraint": True,
                "backward_window": 1,
                "forward_window": 3,
            })

        print("Loading Vocoder...")
        with open(vocoder_conf) as f:
            self.config = yaml.load(f, Loader=yaml.Loader)
        vocoder_class = self.config.get("generator_type",
                                        "ParallelWaveGANGenerator")
        vocoder = getattr(parallel_wavegan.models,
                          vocoder_class)(**self.config["generator_params"])
        vocoder.load_state_dict(
            torch.load(vocoder_path, map_location="cpu")["model"]["generator"])
        vocoder.remove_weight_norm()
        self.vocoder = vocoder.eval().to(self.device)

        print("Loading Text Frontend...")
        with open(dict_path) as f:
            lines = f.readlines()
        lines = [line.replace("\n", "").split(" ") for line in lines]
        self.char_to_id = {c: int(i) for c, i in lines}
        self.g2p = G2p()

        self.pad_fn = torch.nn.ReplicationPad1d(
            self.config["generator_params"].get("aux_context_window", 0))
        self.use_noise_input = vocoder_class == "ParallelWaveGANGenerator"
Example 13
def recog(args):
    set_deterministic_pytorch(args)

    from espnet.asr.asr_utils import get_model_conf, torch_load
    # read training config
    idim, odim, train_args = get_model_conf(args.model, args.model_conf)

    # load trained model parameters
    model_class = dynamic_import(train_args.model_module)
    model = model_class(idim, odim, train_args)
    torch_load(args.model, model)
    # model, train_args = load_trained_model(args.model)
    model.recog_args = args

    # gpu
    if args.ngpu == 1:
        gpu_id = list(range(args.ngpu))
        logging.info("gpu id: " + str(gpu_id))
        model.cuda()

    # read json data
    with open(args.recog_json, "rb") as f:
        js = json.load(f)["utts"]

    load_inputs_and_targets = LoadInputsAndTargets(
        mode="asr",
        load_output=False,
        sort_in_input_length=False,
        preprocess_conf=None,
        preprocess_args={"train": False},
    )

    ark_file = open(args.result_ark, 'wb')
    if args.batchsize == 0:
        with torch.no_grad():
            for idx, name in enumerate(js.keys(), 1):
                logging.info("(%d/%d) decoding " + name, idx, len(js.keys()))
                batch = [(name, js[name])]

                feat = load_inputs_and_targets(batch)
                feat = feat[0][0]

                hyps = model.recognize(feat)
                hyps = hyps.squeeze(0)
                hyps = hyps.data.numpy()

                write_mat(ark_file, hyps, key=name)
Example 14
    def _load_teacher_model(self, model_path):
        # get teacher model config
        idim, odim, args = get_model_conf(model_path)

        # assert dimensions are the same between teacher and student
        assert idim == self.idim
        assert odim == self.odim
        assert args.reduction_factor == self.reduction_factor

        # load teacher model
        model = Transformer(idim, odim, args)
        torch_load(model_path, model)

        # freeze teacher model parameters
        for p in model.parameters():
            p.requires_grad = False

        return model
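
A quick sanity check one might add after calling this helper inside the owning class (path hypothetical): the teacher should be in eval mode with every parameter frozen, so the student distills against deterministic targets:

teacher = self._load_teacher_model("exp/teacher/results/model.loss.best")
teacher.eval()
assert not any(p.requires_grad for p in teacher.parameters())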
Example 15
def perform_tts(input_text):
    idim, odim, train_args = get_model_conf(model_path)
    model_class = dynamic_import(train_args.model_module)
    model = model_class(idim, odim, train_args)
    torch_load(model_path, model)
    model = model.eval().to(device)
    inference_args = Namespace(
        **{
            "threshold": 0.5,
            "minlenratio": 0.0,
            "maxlenratio": 10.0,
            # Only for Tacotron 2
            "use_attention_constraint": True,
            "backward_window": 1,
            "forward_window": 3,
            # Only for fastspeech (lower than 1.0 is faster speech, higher than 1.0 is slower speech)
            "fastspeech_alpha": 1.0,
        })

    # define neural vocoder
    fs = 22050
    vocoder = load_model(vocoder_path)
    vocoder.remove_weight_norm()
    vocoder = vocoder.eval().to(device)

    # define text frontend

    with open(dict_path) as f:
        lines = f.readlines()
    lines = [line.replace("\n", "").split(" ") for line in lines]
    char_to_id = {c: int(i) for c, i in lines}
    g2p = G2p()

    print('input : ', input_text)
    with torch.no_grad():
        start = time.time()
        x = frontend(input_text, g2p, char_to_id, idim)
        c, _, _ = model.inference(x, inference_args)
        y = vocoder.inference(c)
    rtf = (time.time() - start) / (len(y) / fs)
    print(f"RTF = {rtf:5f}")
    print(y)
    write("static/test.wav", fs, y.view(-1).cpu().numpy())
Example 16
    def _load_teacher_model(self, model_path):
        # get teacher model config
        idim, odim, args = get_model_conf(model_path)

        # assert dimensions are the same between teacher and student
        assert idim == self.idim
        assert odim == self.odim
        assert args.reduction_factor == self.reduction_factor

        # load teacher model
        from espnet.utils.dynamic_import import dynamic_import
        model_class = dynamic_import(args.model_module)
        model = model_class(idim, odim, args)
        torch_load(model_path, model)

        # freeze teacher model parameters
        for p in model.parameters():
            p.requires_grad = False

        return model
Example 17
def load_trained_model(model_path):
    """Load the trained model for recognition.

    Args:
        model_path(str): Path to model.***.best

    """
    idim, odim, train_args = get_model_conf(
        model_path, os.path.join(os.path.dirname(model_path), 'model.json'))

    logging.info('reading model parameters from ' + model_path)

    if hasattr(train_args, "model_module"):
        model_module = train_args.model_module
    else:
        model_module = "espnet.nets.pytorch_backend.e2e_asr:E2E"
    model_class = dynamic_import(model_module)
    model = model_class(idim, odim, train_args)
    torch_load(model_path, model)

    return model, train_args
Example 18
def get_trained_model_state_dict(model_path, new_is_transducer):
    """Extract the trained model state dict for pre-initialization.

    Args:
        model_path (str): Path to trained model.
        new_is_transducer (bool): Whether the new model is Transducer-based.

    Return:
        (Dict): Trained model state dict.

    """
    logging.info(f"Reading model parameters from {model_path}")

    conf_path = os.path.join(os.path.dirname(model_path), "model.json")

    if "rnnlm" in model_path:
        return get_lm_state_dict(torch.load(model_path))

    idim, odim, args = get_model_conf(model_path, conf_path)

    if hasattr(args, "model_module"):
        model_module = args.model_module
    else:
        model_module = "espnet.nets.pytorch_backend.e2e_asr:E2E"

    model_class = dynamic_import(model_module)
    model = model_class(idim, odim, args)
    torch_load(model_path, model)

    assert (isinstance(model, MTInterface) or isinstance(model, ASRInterface)
            or isinstance(model, TTSInterface))

    if new_is_transducer and "transducer" not in args.model_module:
        return create_transducer_compatible_state_dict(
            model.state_dict(),
            args.etype,
            args.eunits,
        )

    return model.state_dict()
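
A hypothetical pre-initialization call for this variant; when the new model is Transducer-based and the checkpoint is not, the returned dict has been made encoder-compatible, so a non-strict load fills whatever matches (new_model is an assumed module):

state_dict = get_trained_model_state_dict(
    "exp/pretrained_ctc/results/model.acc.best", new_is_transducer=True)
new_model.load_state_dict(state_dict, strict=False)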
Example 19
def load_trained_model(model_path, use_ema=True, training=True):
    """Load the trained model for recognition.

    Args:
        model_path (str): Path to model.***.best
        use_ema (bool): Use EMA parameters when available
        training (bool): Training mode specification for transducer model.

    """
    idim, odim, train_args = get_model_conf(
        model_path, os.path.join(os.path.dirname(model_path), "model.json"))

    logging.warning("reading model parameters from " + model_path)

    if hasattr(train_args, "model_module"):
        model_module = train_args.model_module
    else:
        model_module = "espnet.nets.pytorch_backend.e2e_asr:E2E"
    # CTC Loss is not needed, default to builtin to prevent import errors
    if hasattr(train_args, "ctc_type"):
        train_args.ctc_type = "builtin"

    model_class = dynamic_import(model_module)

    if "transducer" in model_module:
        model = model_class(idim, odim, train_args, training=training)
        loading_fn = custom_torch_load
    else:
        model = model_class(idim, odim, train_args)
        loading_fn = torch_load

    if hasattr(train_args, "ema_decay") and train_args.ema_decay > 0:
        model = EMA(model, train_args.ema_decay)
    loading_fn(model_path, model)

    if use_ema and hasattr(model, "shadow"):
        logging.warning("use EMA parameters")
        model = model.shadow

    return model, train_args
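
The EMA wrapper referenced above is not part of this listing; a minimal sketch of the usual pattern it implies (an assumption, not the project's actual class), with shadow holding the averaged copy that use_ema unwraps:

import copy

import torch

class EMA(torch.nn.Module):
    def __init__(self, model, decay):
        super().__init__()
        self.model = model  # live weights, updated by the optimizer
        self.decay = decay
        self.shadow = copy.deepcopy(model)  # exponentially averaged weights
        for p in self.shadow.parameters():
            p.requires_grad = False

    @torch.no_grad()
    def update(self):
        # shadow <- decay * shadow + (1 - decay) * model
        for s, p in zip(self.shadow.parameters(), self.model.parameters()):
            s.mul_(self.decay).add_(p, alpha=1.0 - self.decay)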
Example 20
def test_downloaded_asr_model_decodable(module, download_info):
    # download model
    print(download_info[0])
    tmpdir = tempfile.mkdtemp(prefix="tmp_", dir=".")
    model_path = download_zip_from_google_drive(tmpdir, download_info[1])

    # load trained model parameters
    m = importlib.import_module(module)
    idim, odim, train_args = get_model_conf(model_path)
    model = m.E2E(idim, odim, train_args)
    if "chainer" in module:
        chainer_load(model_path, model)
    else:
        torch_load(model_path, model)

    with torch.no_grad(), chainer.no_backprop_mode():
        in_data = np.random.randn(128, idim)
        model.recognize(in_data, train_args, train_args.char_list)  # decodable

    # remove
    if os.path.exists(tmpdir):
        shutil.rmtree(tmpdir)
Example 21
def main(args):
    """Run the main function"""

    model_loc = "exp/" + args[0] + "/rnnlm.model.best"
    print(model_loc)
    train_args = get_model_conf(model_loc)

    #model_module = "espnet.nets.pytorch_backend.lm.default:DefaultRNNLM"
    model_class = dynamic_import_lm(train_args.model_module,
                                    train_args.backend)
    if train_args.train_dtype in ("float16", "float32", "float64"):
        dtype = getattr(torch, train_args.train_dtype)
    else:
        dtype = torch.float32

    model = model_class(train_args.n_vocab, train_args).to(dtype=dtype)
    torch_load(model_loc, model)

    model_parameters = filter(lambda p: p.requires_grad, model.parameters())
    params = sum([np.prod(p.size()) for p in model_parameters])

    print("Trainable parameters: ", params)
    print("Length vocabulary: ", train_args.n_vocab)
Example 22
def load_tacotron_loss(tts_model_conf,
                       tts_model_file,
                       args,
                       reporter=None,
                       train_mode=None,
                       asr2tts_policy=None):
    # Read model
    if 'conf' in tts_model_conf:
        with open(tts_model_conf, 'rb') as f:
            idim_taco, odim_taco, train_args_taco = pickle.load(f)
    elif 'json' in tts_model_conf:
        from espnet.asr.asr_utils import get_model_conf
        idim_taco, odim_taco, train_args_taco = get_model_conf(
            tts_model_file, conf_path=tts_model_conf)
    if args.modify_output:
        import json
        with open(args.valid_json, 'rb') as f:
            valid_json = json.load(f)['utts']
        utts = list(valid_json.keys())
        idim_taco = int(valid_json[utts[0]]['output'][0]['shape'][1])
        from espnet.asr.asrtts_utils import remove_output_layer
        pretrained_model = remove_output_layer(torch.load(tts_model_file),
                                               idim_taco, args.eprojs,
                                               train_args_taco.embed_dim,
                                               'tts')

        torch.save(pretrained_model, 'tmp.model')
        tts_model_file = 'tmp.model'
    # Load loss
    return TacotronRewardLoss(tts_model_file,
                              idim=idim_taco,
                              odim=odim_taco,
                              train_args=train_args_taco,
                              update_asr_only=args.update_asr_only,
                              reporter=reporter,
                              train_mode=train_mode,
                              asr2tts_policy=asr2tts_policy), train_args_taco
Example 23
def load_trained_model(model_path, training=True):
    """Load the trained model for recognition.

    Args:
        model_path (str): Path to model.***.best
        training (bool): Training mode specification for transducer model.

    Returns:
        model (torch.nn.Module): Trained model.
        train_args (Namespace): Trained model arguments.

    """
    idim, odim, train_args = get_model_conf(
        model_path, os.path.join(os.path.dirname(model_path), "model.json"))

    logging.info(f"Reading model parameters from {model_path}")

    if hasattr(train_args, "model_module"):
        model_module = train_args.model_module
    else:
        model_module = "espnet.nets.pytorch_backend.e2e_asr:E2E"

    # CTC Loss is not needed, default to builtin to prevent import errors
    if hasattr(train_args, "ctc_type"):
        train_args.ctc_type = "builtin"

    model_class = dynamic_import(model_module)

    if "transducer" in model_module:
        model = model_class(idim, odim, train_args, training=training)
        custom_torch_load(model_path, model, training=training)
    else:
        model = model_class(idim, odim, train_args)
        torch_load(model_path, model)

    return model, train_args
Example 24
def recog(args):
    """Decode with the given args

    :param Namespace args: The program arguments
    """
    set_deterministic_pytorch(args)
    # read training config
    idim, odim, train_args = get_model_conf(args.model, args.model_conf)

    # load trained model parameters
    logging.info('reading model parameters from ' + args.model)
    model = E2E(idim, odim, train_args)
    torch_load(args.model, model)
    model.recog_args = args

    # read rnnlm
    if args.rnnlm:
        rnnlm_args = get_model_conf(args.rnnlm, args.rnnlm_conf)
        rnnlm = lm_pytorch.ClassifierWithState(
            lm_pytorch.RNNLM(
                len(train_args.char_list), rnnlm_args.layer, rnnlm_args.unit))
        torch_load(args.rnnlm, rnnlm)
        rnnlm.eval()
    else:
        rnnlm = None

    if args.word_rnnlm:
        rnnlm_args = get_model_conf(args.word_rnnlm, args.word_rnnlm_conf)
        word_dict = rnnlm_args.char_list_dict
        char_dict = {x: i for i, x in enumerate(train_args.char_list)}
        word_rnnlm = lm_pytorch.ClassifierWithState(lm_pytorch.RNNLM(
            len(word_dict), rnnlm_args.layer, rnnlm_args.unit))
        torch_load(args.word_rnnlm, word_rnnlm)
        word_rnnlm.eval()

        if rnnlm is not None:
            rnnlm = lm_pytorch.ClassifierWithState(
                extlm_pytorch.MultiLevelLM(word_rnnlm.predictor,
                                           rnnlm.predictor, word_dict, char_dict))
        else:
            rnnlm = lm_pytorch.ClassifierWithState(
                extlm_pytorch.LookAheadWordLM(word_rnnlm.predictor,
                                              word_dict, char_dict))

    # gpu
    if args.ngpu == 1:
        gpu_id = range(args.ngpu)
        logging.info('gpu id: ' + str(gpu_id))
        model.cuda()
        if rnnlm:
            rnnlm.cuda()

    # read json data
    with open(args.recog_json, 'rb') as f:
        js = json.load(f)['utts']
    new_js = {}

    load_inputs_and_targets = LoadInputsAndTargets(
        mode='asr', load_output=False, sort_in_input_length=False,
        preprocess_conf=train_args.preprocess_conf
        if args.preprocess_conf is None else args.preprocess_conf)

    if args.batchsize == 0:
        with torch.no_grad():
            for idx, name in enumerate(js.keys(), 1):
                logging.info('(%d/%d) decoding ' + name, idx, len(js.keys()))
                batch = [(name, js[name])]
                with using_transform_config({'train': True}):
                    feat = load_inputs_and_targets(batch)[0][0]
                nbest_hyps = model.recognize(feat, args, train_args.char_list, rnnlm)
                new_js[name] = add_results_to_json(js[name], nbest_hyps, train_args.char_list)
    else:
        try:
            from itertools import zip_longest as zip_longest
        except Exception:
            from itertools import izip_longest as zip_longest

        def grouper(n, iterable, fillvalue=None):
            kargs = [iter(iterable)] * n
            return zip_longest(*kargs, fillvalue=fillvalue)

        # sort data
        keys = list(js.keys())
        feat_lens = [js[key]['input'][0]['shape'][0] for key in keys]
        sorted_index = sorted(range(len(feat_lens)), key=lambda i: -feat_lens[i])
        keys = [keys[i] for i in sorted_index]

        with torch.no_grad():
            for names in grouper(args.batchsize, keys, None):
                names = [name for name in names if name]
                batch = [(name, js[name]) for name in names]
                with using_transform_config({'train': False}):
                    feats = load_inputs_and_targets(batch)[0]
                nbest_hyps = model.recognize_batch(feats, args, train_args.char_list, rnnlm=rnnlm)
                for i, name in enumerate(names):
                    nbest_hyp = [hyp[i] for hyp in nbest_hyps]
                    new_js[name] = add_results_to_json(js[name], nbest_hyp, train_args.char_list)

    # TODO(watanabe) fix character coding problems when saving it
    with open(args.result_label, 'wb') as f:
        f.write(json.dumps({'utts': new_js}, indent=4, sort_keys=True).encode('utf_8'))
Example 25
def train(args):
    """Train with the given args

    :param Namespace args: The program arguments
    """
    set_deterministic_pytorch(args)

    # check cuda availability
    if not torch.cuda.is_available():
        logging.warning('cuda is not available')

    # get input and output dimension info
    with open(args.valid_json, 'rb') as f:
        valid_json = json.load(f)['utts']
    utts = list(valid_json.keys())
    idim = int(valid_json[utts[0]]['input'][0]['shape'][1])
    odim = int(valid_json[utts[0]]['output'][0]['shape'][1])
    logging.info('#input dims : ' + str(idim))
    logging.info('#output dims: ' + str(odim))

    # specify attention, CTC, hybrid mode
    if args.mtlalpha == 1.0:
        mtl_mode = 'ctc'
        logging.info('Pure CTC mode')
    elif args.mtlalpha == 0.0:
        mtl_mode = 'att'
        logging.info('Pure attention mode')
    else:
        mtl_mode = 'mtl'
        logging.info('Multitask learning mode')

    # specify model architecture
    model = E2E(idim, odim, args)
    subsampling_factor = model.subsample[0]

    if args.rnnlm is not None:
        rnnlm_args = get_model_conf(args.rnnlm, args.rnnlm_conf)
        rnnlm = lm_pytorch.ClassifierWithState(
            lm_pytorch.RNNLM(
                len(args.char_list), rnnlm_args.layer, rnnlm_args.unit))
        torch_load(args.rnnlm, rnnlm)
        model.rnnlm = rnnlm

    # write model config
    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)
    model_conf = args.outdir + '/model.json'
    with open(model_conf, 'wb') as f:
        logging.info('writing a model config file to ' + model_conf)
        f.write(json.dumps((idim, odim, vars(args)), indent=4, sort_keys=True).encode('utf_8'))
    for key in sorted(vars(args).keys()):
        logging.info('ARGS: ' + key + ': ' + str(vars(args)[key]))

    reporter = model.reporter

    # check the use of multi-gpu
    if args.ngpu > 1:
        model = torch.nn.DataParallel(model, device_ids=list(range(args.ngpu)))
        logging.info('batch size is automatically increased (%d -> %d)' % (
            args.batch_size, args.batch_size * args.ngpu))
        args.batch_size *= args.ngpu

    # set torch device
    device = torch.device("cuda" if args.ngpu > 0 else "cpu")
    model = model.to(device)

    # Setup an optimizer
    if args.opt == 'adadelta':
        optimizer = torch.optim.Adadelta(
            model.parameters(), rho=0.95, eps=args.eps,
            weight_decay=args.weight_decay)
    elif args.opt == 'adam':
        optimizer = torch.optim.Adam(model.parameters(),
                                     weight_decay=args.weight_decay)

    # FIXME: TOO DIRTY HACK
    setattr(optimizer, "target", reporter)
    setattr(optimizer, "serialize", lambda s: reporter.serialize(s))

    # Setup a converter
    converter = CustomConverter(subsampling_factor=subsampling_factor,
                                preprocess_conf=args.preprocess_conf)

    # read json data
    with open(args.train_json, 'rb') as f:
        train_json = json.load(f)['utts']
    with open(args.valid_json, 'rb') as f:
        valid_json = json.load(f)['utts']

    # make minibatch list (variable length)
    train = make_batchset(train_json, args.batch_size,
                          args.maxlen_in, args.maxlen_out, args.minibatches,
                          min_batch_size=args.ngpu if args.ngpu > 1 else 1)
    valid = make_batchset(valid_json, args.batch_size,
                          args.maxlen_in, args.maxlen_out, args.minibatches,
                          min_batch_size=args.ngpu if args.ngpu > 1 else 1)
    # hack to make the batchsize argument 1
    # the actual batchsize is included in each list element
    if args.n_iter_processes > 0:
        train_iter = chainer.iterators.MultiprocessIterator(
            TransformDataset(train, converter.transform),
            batch_size=1, n_processes=args.n_iter_processes, n_prefetch=8, maxtasksperchild=20)
        valid_iter = chainer.iterators.MultiprocessIterator(
            TransformDataset(valid, converter.transform),
            batch_size=1, repeat=False, shuffle=False,
            n_processes=args.n_iter_processes, n_prefetch=8, maxtasksperchild=20)
    else:
        train_iter = chainer.iterators.SerialIterator(
            TransformDataset(train, converter.transform),
            batch_size=1)
        valid_iter = chainer.iterators.SerialIterator(
            TransformDataset(valid, converter.transform),
            batch_size=1, repeat=False, shuffle=False)

    # Set up a trainer
    updater = CustomUpdater(
        model, args.grad_clip, train_iter, optimizer, converter, device, args.ngpu)
    trainer = training.Trainer(
        updater, (args.epochs, 'epoch'), out=args.outdir)

    # Resume from a snapshot
    if args.resume:
        logging.info('resumed from %s' % args.resume)
        torch_resume(args.resume, trainer)

    # Evaluate the model with the test dataset for each epoch
    trainer.extend(CustomEvaluator(model, valid_iter, reporter, converter, device))

    # Save attention weight each epoch
    if args.num_save_attention > 0 and args.mtlalpha != 1.0:
        data = sorted(list(valid_json.items())[:args.num_save_attention],
                      key=lambda x: int(x[1]['input'][0]['shape'][1]), reverse=True)
        if hasattr(model, "module"):
            att_vis_fn = model.module.calculate_all_attentions
        else:
            att_vis_fn = model.calculate_all_attentions
        att_reporter = PlotAttentionReport(
            att_vis_fn, data, args.outdir + "/att_ws",
            converter=converter, device=device)
        trainer.extend(att_reporter, trigger=(1, 'epoch'))
    else:
        att_reporter = None

    # Make a plot for training and validation values
    trainer.extend(extensions.PlotReport(['main/loss', 'validation/main/loss',
                                          'main/loss_ctc', 'validation/main/loss_ctc',
                                          'main/loss_att', 'validation/main/loss_att'],
                                         'epoch', file_name='loss.png'))
    trainer.extend(extensions.PlotReport(['main/acc', 'validation/main/acc'],
                                         'epoch', file_name='acc.png'))

    # Save best models
    trainer.extend(extensions.snapshot_object(model, 'model.loss.best', savefun=torch_save),
                   trigger=training.triggers.MinValueTrigger('validation/main/loss'))
    if mtl_mode != 'ctc':
        trainer.extend(extensions.snapshot_object(model, 'model.acc.best', savefun=torch_save),
                       trigger=training.triggers.MaxValueTrigger('validation/main/acc'))

    # save snapshot which contains model and optimizer states
    trainer.extend(torch_snapshot(), trigger=(1, 'epoch'))

    # epsilon decay in the optimizer
    if args.opt == 'adadelta':
        if args.criterion == 'acc' and mtl_mode != 'ctc':
            trainer.extend(restore_snapshot(model, args.outdir + '/model.acc.best', load_fn=torch_load),
                           trigger=CompareValueTrigger(
                               'validation/main/acc',
                               lambda best_value, current_value: best_value > current_value))
            trainer.extend(adadelta_eps_decay(args.eps_decay),
                           trigger=CompareValueTrigger(
                               'validation/main/acc',
                               lambda best_value, current_value: best_value > current_value))
        elif args.criterion == 'loss':
            trainer.extend(restore_snapshot(model, args.outdir + '/model.loss.best', load_fn=torch_load),
                           trigger=CompareValueTrigger(
                               'validation/main/loss',
                               lambda best_value, current_value: best_value < current_value))
            trainer.extend(adadelta_eps_decay(args.eps_decay),
                           trigger=CompareValueTrigger(
                               'validation/main/loss',
                               lambda best_value, current_value: best_value < current_value))

    # Write a log of evaluation statistics for each epoch
    trainer.extend(extensions.LogReport(trigger=(REPORT_INTERVAL, 'iteration')))
    report_keys = ['epoch', 'iteration', 'main/loss', 'main/loss_ctc', 'main/loss_att',
                   'validation/main/loss', 'validation/main/loss_ctc', 'validation/main/loss_att',
                   'main/acc', 'validation/main/acc', 'elapsed_time']
    if args.opt == 'adadelta':
        trainer.extend(extensions.observe_value(
            'eps', lambda trainer: trainer.updater.get_optimizer('main').param_groups[0]["eps"]),
            trigger=(REPORT_INTERVAL, 'iteration'))
        report_keys.append('eps')
    if args.report_cer:
        report_keys.append('validation/main/cer')
    if args.report_wer:
        report_keys.append('validation/main/wer')
    trainer.extend(extensions.PrintReport(
        report_keys), trigger=(REPORT_INTERVAL, 'iteration'))

    trainer.extend(extensions.ProgressBar(update_interval=REPORT_INTERVAL))
    set_early_stop(trainer, args)

    if args.tensorboard_dir is not None and args.tensorboard_dir != "":
        writer = SummaryWriter(args.tensorboard_dir)
        trainer.extend(TensorboardLogger(writer, att_reporter))
    # Run the training
    trainer.run()
    check_early_stop(trainer, args.epochs)
Example 26
File: tts.py Project: zy1022/espnet
def decode(args):
    """Decode with E2E-TTS model."""
    set_deterministic_pytorch(args)
    # read training config
    idim, odim, train_args = get_model_conf(args.model, args.model_conf)

    # show arguments
    for key in sorted(vars(args).keys()):
        logging.info("args: " + key + ": " + str(vars(args)[key]))

    # define model
    model_class = dynamic_import(train_args.model_module)
    model = model_class(idim, odim, train_args)
    assert isinstance(model, TTSInterface)
    logging.info(model)

    # load trained model parameters
    logging.info("reading model parameters from " + args.model)
    torch_load(args.model, model)
    model.eval()

    # set torch device
    device = torch.device("cuda" if args.ngpu > 0 else "cpu")
    model = model.to(device)

    # read json data
    with open(args.json, "rb") as f:
        js = json.load(f)["utts"]

    # check directory
    outdir = os.path.dirname(args.out)
    if len(outdir) != 0 and not os.path.exists(outdir):
        os.makedirs(outdir)

    load_inputs_and_targets = LoadInputsAndTargets(
        mode="tts",
        load_input=False,
        sort_in_input_length=False,
        use_speaker_embedding=train_args.use_speaker_embedding,
        preprocess_conf=train_args.preprocess_conf
        if args.preprocess_conf is None else args.preprocess_conf,
        preprocess_args={"train": False},  # Switch the mode of preprocessing
    )

    # define function for plot prob and att_ws
    def _plot_and_save(array, figname, figsize=(6, 4), dpi=150):
        import matplotlib.pyplot as plt

        shape = array.shape
        if len(shape) == 1:
            # for eos probability
            plt.figure(figsize=figsize, dpi=dpi)
            plt.plot(array)
            plt.xlabel("Frame")
            plt.ylabel("Probability")
            plt.ylim([0, 1])
        elif len(shape) == 2:
            # for tacotron 2 attention weights, whose shape is (out_length, in_length)
            plt.figure(figsize=figsize, dpi=dpi)
            plt.imshow(array, aspect="auto")
            plt.xlabel("Input")
            plt.ylabel("Output")
        elif len(shape) == 4:
            # for transformer attention weights,
            # whose shape is (#layers, #heads, out_length, in_length)
            plt.figure(figsize=(figsize[0] * shape[0], figsize[1] * shape[1]),
                       dpi=dpi)
            for idx1, xs in enumerate(array):
                for idx2, x in enumerate(xs, 1):
                    plt.subplot(shape[0], shape[1], idx1 * shape[1] + idx2)
                    plt.imshow(x, aspect="auto")
                    plt.xlabel("Input")
                    plt.ylabel("Output")
        else:
            raise NotImplementedError("Support only from 1D to 4D array.")
        plt.tight_layout()
        if not os.path.exists(os.path.dirname(figname)):
            # NOTE: exist_ok = True is needed for parallel process decoding
            os.makedirs(os.path.dirname(figname), exist_ok=True)
        plt.savefig(figname)
        plt.close()

    # define function to calculate focus rate
    # (see section 3.3 in https://arxiv.org/abs/1905.09263)
    def _calculate_focus_rate(att_ws):
        if att_ws is None:
            # fastspeech case -> None
            return 1.0
        elif len(att_ws.shape) == 2:
            # tacotron 2 case -> (L, T)
            return float(att_ws.max(dim=-1)[0].mean())
        elif len(att_ws.shape) == 4:
            # transformer case -> (#layers, #heads, L, T)
            return float(att_ws.max(dim=-1)[0].mean(dim=-1).max())
        else:
            raise ValueError("att_ws should be 2 or 4 dimensional tensor.")

    # define function to convert attention to duration
    def _convert_att_to_duration(att_ws):
        if len(att_ws.shape) == 2:
            # tacotron 2 case -> (L, T)
            pass
        elif len(att_ws.shape) == 4:
            # transformer case -> (#layers, #heads, L, T)
            # get the most diagonal head according to focus rate
            att_ws = torch.cat([att_w for att_w in att_ws],
                               dim=0)  # (#heads * #layers, L, T)
            diagonal_scores = att_ws.max(dim=-1)[0].mean(
                dim=-1)  # (#heads * #layers,)
            diagonal_head_idx = diagonal_scores.argmax()
            att_ws = att_ws[diagonal_head_idx]  # (L, T)
        else:
            raise ValueError("att_ws should be 2 or 4 dimensional tensor.")
        # calculate duration from 2d attention weight
        durations = torch.stack(
            [att_ws.argmax(-1).eq(i).sum() for i in range(att_ws.shape[1])])
        return durations.view(-1, 1).float()

    # define writer instances
    feat_writer = kaldiio.WriteHelper(
        "ark,scp:{o}.ark,{o}.scp".format(o=args.out))
    if args.save_durations:
        dur_writer = kaldiio.WriteHelper("ark,scp:{o}.ark,{o}.scp".format(
            o=args.out.replace("feats", "durations")))
    if args.save_focus_rates:
        fr_writer = kaldiio.WriteHelper("ark,scp:{o}.ark,{o}.scp".format(
            o=args.out.replace("feats", "focus_rates")))

    # start decoding
    for idx, utt_id in enumerate(js.keys()):
        # setup inputs
        batch = [(utt_id, js[utt_id])]
        data = load_inputs_and_targets(batch)
        x = torch.LongTensor(data[0][0]).to(device)
        spemb = None
        if train_args.use_speaker_embedding:
            spemb = torch.FloatTensor(data[1][0]).to(device)

        # decode and write
        start_time = time.time()
        outs, probs, att_ws = model.inference(x, args, spemb=spemb)
        logging.info("inference speed = %.1f frames / sec." %
                     (int(outs.size(0)) / (time.time() - start_time)))
        if outs.size(0) == x.size(0) * args.maxlenratio:
            logging.warning("output length reaches maximum length (%s)." %
                            utt_id)
        focus_rate = _calculate_focus_rate(att_ws)
        logging.info("(%d/%d) %s (size: %d->%d, focus rate: %.3f)" %
                     (idx + 1, len(js.keys()), utt_id, x.size(0), outs.size(0),
                      focus_rate))
        feat_writer[utt_id] = outs.cpu().numpy()
        if args.save_durations:
            ds = _convert_att_to_duration(att_ws)
            dur_writer[utt_id] = ds.cpu().numpy()
        if args.save_focus_rates:
            fr_writer[utt_id] = np.array(focus_rate).reshape(1, 1)

        # plot and save prob and att_ws
        if probs is not None:
            _plot_and_save(
                probs.cpu().numpy(),
                os.path.dirname(args.out) + "/probs/%s_prob.png" % utt_id,
            )
        if att_ws is not None:
            _plot_and_save(
                att_ws.cpu().numpy(),
                os.path.dirname(args.out) + "/att_ws/%s_att_ws.png" % utt_id,
            )

    # close file object
    feat_writer.close()
    if args.save_durations:
        dur_writer.close()
    if args.save_focus_rates:
        fr_writer.close()
Example 27
def enhance(args):
    """Dumping enhanced speech and mask.

    Args:
        args (namespace): The program arguments.
    """
    set_deterministic_pytorch(args)
    # read training config
    idim, odim, train_args = get_model_conf(args.model, args.model_conf)

    # load trained model parameters
    logging.info('reading model parameters from ' + args.model)
    model_class = dynamic_import(train_args.model_module)
    model = model_class(idim, odim, train_args)
    assert isinstance(model, ASRInterface)
    torch_load(args.model, model)
    model.recog_args = args

    # gpu
    if args.ngpu == 1:
        gpu_id = list(range(args.ngpu))
        logging.info('gpu id: ' + str(gpu_id))
        model.cuda()

    # read json data
    with open(args.recog_json, 'rb') as f:
        js = json.load(f)['utts']

    load_inputs_and_targets = LoadInputsAndTargets(
        mode='asr',
        load_output=False,
        sort_in_input_length=False,
        preprocess_conf=None  # Apply pre_process in outer func
    )
    if args.batchsize == 0:
        args.batchsize = 1

    # Creates writers for outputs from the network
    if args.enh_wspecifier is not None:
        enh_writer = file_writer_helper(args.enh_wspecifier,
                                        filetype=args.enh_filetype)
    else:
        enh_writer = None

    # Creates a Transformation instance
    preprocess_conf = (train_args.preprocess_conf if
                       args.preprocess_conf is None else args.preprocess_conf)
    if preprocess_conf is not None:
        logging.info('Use preprocessing: {}'.format(preprocess_conf))
        transform = Transformation(preprocess_conf)
    else:
        transform = None

    # Creates a IStft instance
    istft = None
    frame_shift = args.istft_n_shift  # Used for plot the spectrogram
    if args.apply_istft:
        if preprocess_conf is not None:
            # Read the conffile and find stft setting
            with open(preprocess_conf) as f:
                # Json format: e.g.
                #    {"process": [{"type": "stft",
                #                  "win_length": 400,
                #                  "n_fft": 512, "n_shift": 160,
                #                  "window": "han"},
                #                 {"type": "foo", ...}, ...]}
                conf = json.load(f)
                assert 'process' in conf, conf
                # Find stft setting
                for p in conf['process']:
                    if p['type'] == 'stft':
                        istft = IStft(win_length=p['win_length'],
                                      n_shift=p['n_shift'],
                                      window=p.get('window', 'hann'))
                        logging.info('stft is found in {}. '
                                     'Setting istft config from it\n{}'.format(
                                         preprocess_conf, istft))
                        frame_shift = p['n_shift']
                        break
        if istft is None:
            # Set from command line arguments
            istft = IStft(win_length=args.istft_win_length,
                          n_shift=args.istft_n_shift,
                          window=args.istft_window)
            logging.info(
                'Setting istft config from the command line args\n{}'.format(
                    istft))

    # sort data
    keys = list(js.keys())
    feat_lens = [js[key]['input'][0]['shape'][0] for key in keys]
    sorted_index = sorted(range(len(feat_lens)), key=lambda i: -feat_lens[i])
    keys = [keys[i] for i in sorted_index]

    def grouper(n, iterable, fillvalue=None):
        kargs = [iter(iterable)] * n
        return zip_longest(*kargs, fillvalue=fillvalue)

    num_images = 0
    if not os.path.exists(args.image_dir):
        os.makedirs(args.image_dir)

    for names in grouper(args.batchsize, keys, None):
        batch = [(name, js[name]) for name in names]

        # May be in time region: (Batch, [Time, Channel])
        org_feats = load_inputs_and_targets(batch)[0]
        if transform is not None:
            # May be in time-freq region: : (Batch, [Time, Channel, Freq])
            feats = transform(org_feats, train=False)
        else:
            feats = org_feats

        with torch.no_grad():
            enhanced, mask, ilens = model.enhance(feats)

        for idx, name in enumerate(names):
            # Assuming mask, feats : [Batch, Time, Channel. Freq]
            #          enhanced    : [Batch, Time, Freq]
            enh = enhanced[idx][:ilens[idx]]
            mas = mask[idx][:ilens[idx]]
            feat = feats[idx]

            # Plot spectrogram
            if args.image_dir is not None and num_images < args.num_images:
                import matplotlib.pyplot as plt
                num_images += 1
                ref_ch = 0

                plt.figure(figsize=(20, 10))
                plt.subplot(4, 1, 1)
                plt.title('Mask [ref={}ch]'.format(ref_ch))
                plot_spectrogram(plt,
                                 mas[:, ref_ch].T,
                                 fs=args.fs,
                                 mode='linear',
                                 frame_shift=frame_shift,
                                 bottom=False,
                                 labelbottom=False)

                plt.subplot(4, 1, 2)
                plt.title('Noisy speech [ref={}ch]'.format(ref_ch))
                plot_spectrogram(plt,
                                 feat[:, ref_ch].T,
                                 fs=args.fs,
                                 mode='db',
                                 frame_shift=frame_shift,
                                 bottom=False,
                                 labelbottom=False)

                plt.subplot(4, 1, 3)
                plt.title('Masked speech [ref={}ch]'.format(ref_ch))
                plot_spectrogram(plt, (feat[:, ref_ch] * mas[:, ref_ch]).T,
                                 frame_shift=frame_shift,
                                 fs=args.fs,
                                 mode='db',
                                 bottom=False,
                                 labelbottom=False)

                plt.subplot(4, 1, 4)
                plt.title('Enhanced speech')
                plot_spectrogram(plt,
                                 enh.T,
                                 fs=args.fs,
                                 mode='db',
                                 frame_shift=frame_shift)

                plt.savefig(os.path.join(args.image_dir, name + '.png'))
                plt.clf()

            # Write enhanced wave files
            if enh_writer is not None:
                if istft is not None:
                    enh = istft(enh)

                if args.keep_length:
                    if len(org_feats[idx]) < len(enh):
                        # Truncate the frames added by stft padding
                        enh = enh[:len(org_feats[idx])]
                    elif len(org_feats[idx]) > len(enh):
                        padwidth = [(0, (len(org_feats[idx]) - len(enh)))] \
                            + [(0, 0)] * (enh.ndim - 1)
                        enh = np.pad(enh, padwidth, mode='constant')

                if args.enh_filetype in ('sound', 'sound.hdf5'):
                    enh_writer[name] = (args.fs, enh)
                else:
                    # Hint: To dump stft_signal, mask or etc,
                    # enh_filetype='hdf5' might be convenient.
                    enh_writer[name] = enh

            if num_images >= args.num_images and enh_writer is None:
                logging.info('Generated all requested images; stopping.')
                break
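
A minimal sketch of the keep_length trim/pad logic above, applied to a dummy 1-D array (the lengths here are purely illustrative):

import numpy as np

org_len = 16000          # length of the original feature sequence
enh = np.zeros(15872)    # enhanced signal as returned by istft
if org_len < len(enh):
    # Truncate the frames added by STFT padding
    enh = enh[:org_len]
elif org_len > len(enh):
    # Zero-pad the tail so the output matches the input length
    padwidth = [(0, org_len - len(enh))] + [(0, 0)] * (enh.ndim - 1)
    enh = np.pad(enh, padwidth, mode='constant')
assert len(enh) == org_len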
Example n. 28
def recog(args):
    """Decode with the given args.

    Args:
        args (namespace): The program arguments.
    """
    set_deterministic_pytorch(args)
    model, train_args = load_trained_model(args.model)
    assert isinstance(model, ASRInterface)
    model.recog_args = args

    # read rnnlm
    if args.rnnlm:
        rnnlm_args = get_model_conf(args.rnnlm, args.rnnlm_conf)
        if getattr(rnnlm_args, "model_module", "default") != "default":
            raise ValueError(
                "use '--api v2' option to decode with non-default language model"
            )
        rnnlm = lm_pytorch.ClassifierWithState(
            lm_pytorch.RNNLM(len(train_args.char_list), rnnlm_args.layer,
                             rnnlm_args.unit))
        torch_load(args.rnnlm, rnnlm)
        rnnlm.eval()
    else:
        rnnlm = None

    if args.word_rnnlm:
        rnnlm_args = get_model_conf(args.word_rnnlm, args.word_rnnlm_conf)
        word_dict = rnnlm_args.char_list_dict
        char_dict = {x: i for i, x in enumerate(train_args.char_list)}
        word_rnnlm = lm_pytorch.ClassifierWithState(
            lm_pytorch.RNNLM(len(word_dict), rnnlm_args.layer,
                             rnnlm_args.unit))
        torch_load(args.word_rnnlm, word_rnnlm)
        word_rnnlm.eval()

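        # Combine the two LMs: MultiLevelLM scores with both the word-level and
        # character-level models, while LookAheadWordLM drives the word-level LM
        # alone with character-level look-ahead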
        if rnnlm is not None:
            rnnlm = lm_pytorch.ClassifierWithState(
                extlm_pytorch.MultiLevelLM(word_rnnlm.predictor,
                                           rnnlm.predictor, word_dict,
                                           char_dict))
        else:
            rnnlm = lm_pytorch.ClassifierWithState(
                extlm_pytorch.LookAheadWordLM(word_rnnlm.predictor, word_dict,
                                              char_dict))

    # gpu
    if args.ngpu == 1:
        gpu_id = list(range(args.ngpu))
        logging.info('gpu id: ' + str(gpu_id))
        model.cuda()
        if rnnlm:
            rnnlm.cuda()

    # read json data
    with open(args.recog_json, 'rb') as f:
        js = json.load(f)['utts']
    new_js = {}

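    # Build the feature loader; a preprocessing config given at decode time
    # takes precedence over the one stored with the trained model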
    load_inputs_and_targets = LoadInputsAndTargets(
        mode='asr',
        load_output=False,
        sort_in_input_length=False,
        preprocess_conf=train_args.preprocess_conf
        if args.preprocess_conf is None else args.preprocess_conf,
        preprocess_args={'train': False})

    if args.batchsize == 0:
        with torch.no_grad():
            for idx, name in enumerate(js.keys(), 1):
                logging.info('(%d/%d) decoding ' + name, idx, len(js.keys()))
                batch = [(name, js[name])]
                feat = load_inputs_and_targets(batch)[0][0]
                if args.streaming_mode == 'window':
                    logging.info(
                        'Using streaming recognizer with window size %d frames',
                        args.streaming_window)
                    se2e = WindowStreamingE2E(e2e=model,
                                              recog_args=args,
                                              rnnlm=rnnlm)
                    for i in range(0, feat.shape[0], args.streaming_window):
                        logging.info('Feeding frames %d - %d', i,
                                     i + args.streaming_window)
                        se2e.accept_input(feat[i:i + args.streaming_window])
                    logging.info('Running offline attention decoder')
                    se2e.decode_with_attention_offline()
                    logging.info('Offline attention decoder finished')
                    nbest_hyps = se2e.retrieve_recognition()
                elif args.streaming_mode == 'segment':
                    logging.info(
                        'Using streaming recognizer with threshold value %d',
                        args.streaming_min_blank_dur)
                    nbest_hyps = []
                    for n in range(args.nbest):
                        nbest_hyps.append({'yseq': [], 'score': 0.0})
                    se2e = SegmentStreamingE2E(e2e=model,
                                               recog_args=args,
                                               rnnlm=rnnlm)
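                    # Feed the segment recognizer in chunks matching the
                    # encoder's total subsampling factor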
                    r = np.prod(model.subsample)
                    for i in range(0, feat.shape[0], r):
                        hyps = se2e.accept_input(feat[i:i + r])
                        if hyps is not None:
                            text = ''.join([
                                train_args.char_list[int(x)]
                                for x in hyps[0]['yseq'][1:-1] if int(x) != -1
                            ])
                            text = text.replace(
                                '\u2581', ' ').strip()  # for SentencePiece
                            text = text.replace(model.space, ' ')
                            text = text.replace(model.blank, '')
                            logging.info(text)
                            for n in range(args.nbest):
                                nbest_hyps[n]['yseq'].extend(hyps[n]['yseq'])
                                nbest_hyps[n]['score'] += hyps[n]['score']
                else:
                    nbest_hyps = model.recognize(feat, args,
                                                 train_args.char_list, rnnlm)
                new_js[name] = add_results_to_json(js[name], nbest_hyps,
                                                   train_args.char_list)

    else:

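        # itertools "grouper" recipe: collect keys into fixed-size batches;
        # the trailing None fill values are filtered out below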
        def grouper(n, iterable, fillvalue=None):
            kargs = [iter(iterable)] * n
            return zip_longest(*kargs, fillvalue=fillvalue)

        # sort data if batchsize > 1
        keys = list(js.keys())
        if args.batchsize > 1:
            feat_lens = [js[key]['input'][0]['shape'][0] for key in keys]
            sorted_index = sorted(range(len(feat_lens)),
                                  key=lambda i: -feat_lens[i])
            keys = [keys[i] for i in sorted_index]

        with torch.no_grad():
            for names in grouper(args.batchsize, keys, None):
                names = [name for name in names if name]
                batch = [(name, js[name]) for name in names]
                feats = load_inputs_and_targets(batch)[0]
                nbest_hyps = model.recognize_batch(feats,
                                                   args,
                                                   train_args.char_list,
                                                   rnnlm=rnnlm)

                for i, nbest_hyp in enumerate(nbest_hyps):
                    name = names[i]
                    new_js[name] = add_results_to_json(js[name], nbest_hyp,
                                                       train_args.char_list)

    with open(args.result_label, 'wb') as f:
        f.write(
            json.dumps({
                'utts': new_js
            },
                       indent=4,
                       ensure_ascii=False,
                       sort_keys=True).encode('utf_8'))
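
A minimal sketch of reading the dumped result file back; the 'rec_text' key and the 'result.json' filename are assumptions based on the usual layout written by add_results_to_json:

import json

with open('result.json', 'rb') as f:
    utts = json.load(f)['utts']
for name, info in utts.items():
    # Print the best hypothesis text for each utterance (assumed key: 'rec_text')
    print(name, info['output'][0]['rec_text'])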
Example n. 29
def train(args):
    """Train with the given args.

    Args:
        args (namespace): The program arguments.

    """
    set_deterministic_pytorch(args)

    # check cuda availability
    if not torch.cuda.is_available():
        logging.warning('cuda is not available')

    # get input and output dimension info
    with open(args.valid_json, 'rb') as f:
        valid_json = json.load(f)['utts']
    utts = list(valid_json.keys())
    idim = int(valid_json[utts[0]]['input'][0]['shape'][-1])
    odim = int(valid_json[utts[0]]['output'][0]['shape'][-1])
    logging.info('#input dims : ' + str(idim))
    logging.info('#output dims: ' + str(odim))

    # specify attention, CTC, hybrid mode
    if args.mtlalpha == 1.0:
        mtl_mode = 'ctc'
        logging.info('Pure CTC mode')
    elif args.mtlalpha == 0.0:
        mtl_mode = 'att'
        logging.info('Pure attention mode')
    else:
        mtl_mode = 'mtl'
        logging.info('Multitask learning mode')

    if args.enc_init is not None or args.dec_init is not None:
        model = load_trained_modules(idim, odim, args)
    elif args.asr_init is not None:
        model, _ = load_trained_model(args.asr_init)
    else:
        model_class = dynamic_import(args.model_module)
        model = model_class(idim, odim, args)
    assert isinstance(model, ASRInterface)

    subsampling_factor = model.subsample[0]

    if args.rnnlm is not None:
        rnnlm_args = get_model_conf(args.rnnlm, args.rnnlm_conf)
        rnnlm = lm_pytorch.ClassifierWithState(
            lm_pytorch.RNNLM(len(args.char_list), rnnlm_args.layer,
                             rnnlm_args.unit))
        torch_load(args.rnnlm, rnnlm)
        model.rnnlm = rnnlm

    # write model config
    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)
    model_conf = args.outdir + '/model.json'
    with open(model_conf, 'wb') as f:
        logging.info('writing a model config file to ' + model_conf)
        f.write(
            json.dumps((idim, odim, vars(args)),
                       indent=4,
                       ensure_ascii=False,
                       sort_keys=True).encode('utf_8'))
    for key in sorted(vars(args).keys()):
        logging.info('ARGS: ' + key + ': ' + str(vars(args)[key]))

    reporter = model.reporter

    # check the use of multi-gpu
    if args.ngpu > 1:
        if args.batch_size != 0:
            logging.info('batch size is automatically increased (%d -> %d)' %
                         (args.batch_size, args.batch_size * args.ngpu))
            args.batch_size *= args.ngpu

    # set torch device
    device = torch.device("cuda" if args.ngpu > 0 else "cpu")
    if args.train_dtype in ("float16", "float32", "float64"):
        dtype = getattr(torch, args.train_dtype)
    else:
        dtype = torch.float32
    logging.info(device)
    logging.info(dtype)
    model = model.to(device=device, dtype=dtype)

    # Setup an optimizer
    if args.opt == 'adadelta':
        optimizer = torch.optim.Adadelta(model.parameters(),
                                         rho=0.95,
                                         eps=args.eps,
                                         weight_decay=args.weight_decay)
    elif args.opt == 'adam':
        optimizer = torch.optim.Adam(model.parameters(),
                                     weight_decay=args.weight_decay)
    elif args.opt == 'noam':
        from espnet.nets.pytorch_backend.transformer.optimizer import get_std_opt
        optimizer = get_std_opt(model, args.adim,
                                args.transformer_warmup_steps,
                                args.transformer_lr)
    else:
        raise NotImplementedError("unknown optimizer: " + args.opt)

    # setup apex.amp
    if args.train_dtype in ("O0", "O1", "O2", "O3"):
        try:
            from apex import amp
        except ImportError as e:
            logging.error(
                f"You need to install apex for --train-dtype {args.train_dtype}. "
                "See https://github.com/NVIDIA/apex#linux")
            raise e
        if args.opt == 'noam':
            model, optimizer.optimizer = amp.initialize(
                model, optimizer.optimizer, opt_level=args.train_dtype)
        else:
            model, optimizer = amp.initialize(model,
                                              optimizer,
                                              opt_level=args.train_dtype)
        use_apex = True
    else:
        use_apex = False

    # FIXME: TOO DIRTY HACK
    setattr(optimizer, "target", reporter)
    setattr(optimizer, "serialize", lambda s: reporter.serialize(s))

    # Setup a converter
    converter = CustomConverter(subsampling_factor=subsampling_factor,
                                dtype=dtype)

    # read json data
    with open(args.train_json, 'rb') as f:
        train_json = json.load(f)['utts']
    with open(args.valid_json, 'rb') as f:
        valid_json = json.load(f)['utts']

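    # SortaGrad: present utterances shortest-first for the first args.sortagrad
    # epochs (-1 keeps it on for all epochs); the ShufflingEnabler extension
    # below restores shuffling afterwards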
    use_sortagrad = args.sortagrad == -1 or args.sortagrad > 0
    # make minibatch list (variable length)
    train = make_batchset(train_json,
                          args.batch_size,
                          args.maxlen_in,
                          args.maxlen_out,
                          args.minibatches,
                          min_batch_size=args.ngpu if args.ngpu > 1 else 1,
                          shortest_first=use_sortagrad,
                          count=args.batch_count,
                          batch_bins=args.batch_bins,
                          batch_frames_in=args.batch_frames_in,
                          batch_frames_out=args.batch_frames_out,
                          batch_frames_inout=args.batch_frames_inout,
                          iaxis=0,
                          oaxis=0)
    valid = make_batchset(valid_json,
                          args.batch_size,
                          args.maxlen_in,
                          args.maxlen_out,
                          args.minibatches,
                          min_batch_size=args.ngpu if args.ngpu > 1 else 1,
                          count=args.batch_count,
                          batch_bins=args.batch_bins,
                          batch_frames_in=args.batch_frames_in,
                          batch_frames_out=args.batch_frames_out,
                          batch_frames_inout=args.batch_frames_inout,
                          iaxis=0,
                          oaxis=0)

    load_tr = LoadInputsAndTargets(
        mode='asr',
        load_output=True,
        preprocess_conf=args.preprocess_conf,
        preprocess_args={'train': True}  # Switch the mode of preprocessing
    )
    load_cv = LoadInputsAndTargets(
        mode='asr',
        load_output=True,
        preprocess_conf=args.preprocess_conf,
        preprocess_args={'train': False}  # Switch the mode of preprocessing
    )
    # hack to fix the batchsize argument at 1
    # the actual batchsize is encoded inside each list element
    if args.n_iter_processes > 0:
        train_iter = ToggleableShufflingMultiprocessIterator(
            TransformDataset(train, load_tr),
            batch_size=1,
            n_processes=args.n_iter_processes,
            n_prefetch=8,
            maxtasksperchild=20,
            shuffle=not use_sortagrad)
        valid_iter = ToggleableShufflingMultiprocessIterator(
            TransformDataset(valid, load_cv),
            batch_size=1,
            repeat=False,
            shuffle=False,
            n_processes=args.n_iter_processes,
            n_prefetch=8,
            maxtasksperchild=20)
    else:
        train_iter = ToggleableShufflingSerialIterator(
            TransformDataset(train, load_tr),
            batch_size=1,
            shuffle=not use_sortagrad)
        valid_iter = ToggleableShufflingSerialIterator(TransformDataset(
            valid, load_cv),
                                                       batch_size=1,
                                                       repeat=False,
                                                       shuffle=False)

    # Set up a trainer
    updater = CustomUpdater(model,
                            args.grad_clip,
                            train_iter,
                            optimizer,
                            converter,
                            device,
                            args.ngpu,
                            args.grad_noise,
                            args.accum_grad,
                            use_apex=use_apex)
    trainer = training.Trainer(updater, (args.epochs, 'epoch'),
                               out=args.outdir)

    if use_sortagrad:
        trainer.extend(
            ShufflingEnabler([train_iter]),
            trigger=(args.sortagrad if args.sortagrad != -1 else args.epochs,
                     'epoch'))

    # Resume from a snapshot
    if args.resume:
        logging.info('resumed from %s' % args.resume)
        torch_resume(args.resume, trainer)

    # Evaluate the model with the test dataset for each epoch
    trainer.extend(
        CustomEvaluator(model, valid_iter, reporter, converter, device,
                        args.ngpu))

    # Save attention weight each epoch
    if args.num_save_attention > 0 and args.mtlalpha != 1.0:
        data = sorted(list(valid_json.items())[:args.num_save_attention],
                      key=lambda x: int(x[1]['input'][0]['shape'][1]),
                      reverse=True)
        if hasattr(model, "module"):
            att_vis_fn = model.module.calculate_all_attentions
            plot_class = model.module.attention_plot_class
        else:
            att_vis_fn = model.calculate_all_attentions
            plot_class = model.attention_plot_class
        att_reporter = plot_class(att_vis_fn,
                                  data,
                                  args.outdir + "/att_ws",
                                  converter=converter,
                                  transform=load_cv,
                                  device=device)
        trainer.extend(att_reporter, trigger=(1, 'epoch'))
    else:
        att_reporter = None

    # Make a plot for training and validation values
    trainer.extend(
        extensions.PlotReport([
            'main/loss', 'validation/main/loss', 'main/loss_ctc',
            'validation/main/loss_ctc', 'main/loss_att',
            'validation/main/loss_att'
        ],
                              'epoch',
                              file_name='loss.png'))
    trainer.extend(
        extensions.PlotReport(['main/acc', 'validation/main/acc'],
                              'epoch',
                              file_name='acc.png'))
    trainer.extend(
        extensions.PlotReport(['main/cer_ctc', 'validation/main/cer_ctc'],
                              'epoch',
                              file_name='cer.png'))

    # Save best models
    trainer.extend(
        snapshot_object(model, 'model.loss.best'),
        trigger=training.triggers.MinValueTrigger('validation/main/loss'))
    if mtl_mode != 'ctc':
        trainer.extend(
            snapshot_object(model, 'model.acc.best'),
            trigger=training.triggers.MaxValueTrigger('validation/main/acc'))

    # save snapshot which contains model and optimizer states
    trainer.extend(torch_snapshot(), trigger=(1, 'epoch'))

    # epsilon decay in the optimizer
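    # when the validation criterion stops improving, reload the best snapshot
    # and shrink Adadelta's eps by args.eps_decay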
    if args.opt == 'adadelta':
        if args.criterion == 'acc' and mtl_mode != 'ctc':
            trainer.extend(restore_snapshot(model,
                                            args.outdir + '/model.acc.best',
                                            load_fn=torch_load),
                           trigger=CompareValueTrigger(
                               'validation/main/acc', lambda best_value,
                               current_value: best_value > current_value))
            trainer.extend(adadelta_eps_decay(args.eps_decay),
                           trigger=CompareValueTrigger(
                               'validation/main/acc', lambda best_value,
                               current_value: best_value > current_value))
        elif args.criterion == 'loss':
            trainer.extend(restore_snapshot(model,
                                            args.outdir + '/model.loss.best',
                                            load_fn=torch_load),
                           trigger=CompareValueTrigger(
                               'validation/main/loss', lambda best_value,
                               current_value: best_value < current_value))
            trainer.extend(adadelta_eps_decay(args.eps_decay),
                           trigger=CompareValueTrigger(
                               'validation/main/loss', lambda best_value,
                               current_value: best_value < current_value))

    # Write a log of evaluation statistics for each epoch
    trainer.extend(
        extensions.LogReport(trigger=(args.report_interval_iters,
                                      'iteration')))
    report_keys = [
        'epoch', 'iteration', 'main/loss', 'main/loss_ctc', 'main/loss_att',
        'validation/main/loss', 'validation/main/loss_ctc',
        'validation/main/loss_att', 'main/acc', 'validation/main/acc',
        'main/cer_ctc', 'validation/main/cer_ctc', 'elapsed_time'
    ]
    if args.opt == 'adadelta':
        trainer.extend(extensions.observe_value(
            'eps', lambda trainer: trainer.updater.get_optimizer('main').
            param_groups[0]["eps"]),
                       trigger=(args.report_interval_iters, 'iteration'))
        report_keys.append('eps')
    if args.report_cer:
        report_keys.append('validation/main/cer')
    if args.report_wer:
        report_keys.append('validation/main/wer')
    trainer.extend(extensions.PrintReport(report_keys),
                   trigger=(args.report_interval_iters, 'iteration'))

    trainer.extend(
        extensions.ProgressBar(update_interval=args.report_interval_iters))
    set_early_stop(trainer, args)

    if args.tensorboard_dir is not None and args.tensorboard_dir != "":
        trainer.extend(TensorboardLogger(SummaryWriter(args.tensorboard_dir),
                                         att_reporter),
                       trigger=(args.report_interval_iters, "iteration"))
    # Run the training
    trainer.run()
    check_early_stop(trainer, args.epochs)
Example n. 30
def trans(args):
    """Decode with the given args

    :param Namespace args: The program arguments
    """
    set_deterministic_pytorch(args)
    model, train_args = load_trained_model(args.model)
    assert isinstance(model, MTInterface)
    model.recog_args = args

    # read rnnlm
    if args.rnnlm:
        rnnlm_args = get_model_conf(args.rnnlm, args.rnnlm_conf)
        rnnlm = lm_pytorch.ClassifierWithState(
            lm_pytorch.RNNLM(len(train_args.char_list), rnnlm_args.layer,
                             rnnlm_args.unit))
        torch_load(args.rnnlm, rnnlm)
        rnnlm.eval()
    else:
        rnnlm = None

    # gpu
    if args.ngpu == 1:
        gpu_id = list(range(args.ngpu))
        logging.info('gpu id: ' + str(gpu_id))
        model.cuda()
        if rnnlm:
            rnnlm.cuda()

    # read json data
    with open(args.recog_json, 'rb') as f:
        js = json.load(f)['utts']
    new_js = {}

    # remove empty utterances
    if train_args.replace_sos:
        js = {
            k: v
            for k, v in js.items() if v['output'][0]['shape'][0] > 1
            and v['output'][1]['shape'][0] > 1
        }
    else:
        js = {
            k: v
            for k, v in js.items() if v['output'][0]['shape'][0] > 0
            and v['output'][1]['shape'][0] > 0
        }

    if args.batchsize == 0:
        with torch.no_grad():
            for idx, name in enumerate(js.keys(), 1):
                logging.info('(%d/%d) decoding ' + name, idx, len(js.keys()))
                feat = [js[name]['output'][1]['tokenid'].split()]
                nbest_hyps = model.translate(feat, args, train_args.char_list,
                                             rnnlm)
                new_js[name] = add_results_to_json(js[name], nbest_hyps,
                                                   train_args.char_list)
    else:

        def grouper(n, iterable, fillvalue=None):
            kargs = [iter(iterable)] * n
            return zip_longest(*kargs, fillvalue=fillvalue)

        # sort data
        keys = list(js.keys())
        feat_lens = [js[key]['output'][1]['shape'][0] for key in keys]
        sorted_index = sorted(range(len(feat_lens)),
                              key=lambda i: -feat_lens[i])
        keys = [keys[i] for i in sorted_index]

        with torch.no_grad():
            for names in grouper(args.batchsize, keys, None):
                names = [name for name in names if name]
                feats = [
                    np.fromiter(map(int,
                                    js[name]['output'][1]['tokenid'].split()),
                                dtype=np.int64) for name in names
                ]
                nbest_hyps = model.translate_batch(feats,
                                                   args,
                                                   train_args.char_list,
                                                   rnnlm=rnnlm)
                for i, nbest_hyp in enumerate(nbest_hyps):
                    name = names[i]
                    new_js[name] = add_results_to_json(js[name], nbest_hyp,
                                                       train_args.char_list)

    with open(args.result_label, 'wb') as f:
        f.write(
            json.dumps({
                'utts': new_js
            },
                       indent=4,
                       ensure_ascii=False,
                       sort_keys=True).encode('utf_8'))