Пример #1
0
def test_pytorch_freezable():
    """Build an RNN E2E model, freeze ``enc.enc.0.`` and switch to train mode."""
    from espnet.nets.pytorch_backend.e2e_asr import E2E

    # Only the two dimensions are used here; the length values are discarded.
    in_dim, out_dim, _ilens, _olens = get_default_scope_inputs()
    rnn_args = get_rnn_args(freeze_mods="enc.enc.0.")

    # freeze_modules returns a pair; the second element is not used below.
    frozen_model, _unused_params = freeze_modules(
        E2E(in_dim, out_dim, rnn_args), rnn_args.freeze_mods)

    frozen_model.train()
Пример #2
0
 def add_arguments(parser):
     """Run the encoder, encoder-mix, attention and decoder argument hooks.

     Each hook receives ``parser``; the same ``parser`` object is returned.
     """
     hooks = (
         E2EASR.encoder_add_arguments,
         E2E.encoder_mix_add_arguments,
         E2EASR.attention_add_arguments,
         E2EASR.decoder_add_arguments,
     )
     # Hooks are invoked in the listed order; their return values are ignored.
     for hook in hooks:
         hook(parser)
     return parser
Пример #3
0
def load_espnet_encoder(model_path, pretrained=True):
    """Wrap the encoder of an ESPnet ``E2E`` model in an ``E2EASREncoder``.

    The model configuration (``idim``, ``odim`` and the argument dict) is read
    from ``model.json`` located in the same directory as ``model_path``.

    :param model_path: path to the model checkpoint file
    :param bool pretrained: when true, the state dict stored at ``model_path``
        is loaded (mapped to CPU) into the model before it is wrapped
    :return: ``E2EASREncoder`` built from ``model.enc``
    """
    import argparse
    import json
    from pathlib import Path

    from espnet.nets.pytorch_backend.e2e_asr import E2E

    conf_file = Path(model_path).parent / "model.json"
    with open(str(conf_file), "r") as handle:
        idim, odim, conf = json.load(handle)

    train_args = argparse.Namespace(**conf)
    model = E2E(idim, odim, train_args)

    if pretrained:
        weights = torch.load(model_path, map_location="cpu")
        model.load_state_dict(weights)

    return E2EASREncoder(model.enc)
Пример #4
0
 def add_arguments(parser):
     """Delegate argument registration to ``E2E_pytorch.add_arguments``."""
     extended_parser = E2E_pytorch.add_arguments(parser)
     return extended_parser
Пример #5
0
def recog(args):
    """Decode with the given args

    Loads the trained ``E2E`` model named by ``args.model``, optionally loads
    a character-level and/or word-level RNNLM, decodes every utterance listed
    in ``args.recog_json`` and writes the hypotheses as JSON to
    ``args.result_label``.

    :param Namespace args: The program arguments
    """
    set_deterministic_pytorch(args)
    # read training config
    idim, odim, train_args = get_model_conf(args.model, args.model_conf)

    # load trained model parameters
    logging.info('reading model parameters from ' + args.model)
    model = E2E(idim, odim, train_args)
    torch_load(args.model, model)
    model.recog_args = args

    # read rnnlm
    if args.rnnlm:
        rnnlm_args = get_model_conf(args.rnnlm, args.rnnlm_conf)
        rnnlm = lm_pytorch.ClassifierWithState(
            lm_pytorch.RNNLM(len(train_args.char_list), rnnlm_args.layer,
                             rnnlm_args.unit))
        torch_load(args.rnnlm, rnnlm)
        rnnlm.eval()
    else:
        rnnlm = None

    if args.word_rnnlm:
        rnnlm_args = get_model_conf(args.word_rnnlm, args.word_rnnlm_conf)
        word_dict = rnnlm_args.char_list_dict
        char_dict = {x: i for i, x in enumerate(train_args.char_list)}
        word_rnnlm = lm_pytorch.ClassifierWithState(
            lm_pytorch.RNNLM(len(word_dict), rnnlm_args.layer,
                             rnnlm_args.unit))
        torch_load(args.word_rnnlm, word_rnnlm)
        word_rnnlm.eval()

        # Combine with the character LM when both are given, otherwise use
        # the word LM alone through the look-ahead wrapper.
        if rnnlm is not None:
            rnnlm = lm_pytorch.ClassifierWithState(
                extlm_pytorch.MultiLevelLM(word_rnnlm.predictor,
                                           rnnlm.predictor, word_dict,
                                           char_dict))
        else:
            rnnlm = lm_pytorch.ClassifierWithState(
                extlm_pytorch.LookAheadWordLM(word_rnnlm.predictor, word_dict,
                                              char_dict))

    # gpu
    if args.ngpu == 1:
        # Materialise the range so the log shows the ids, not "range(0, 1)".
        gpu_id = list(range(args.ngpu))
        logging.info('gpu id: ' + str(gpu_id))
        model.cuda()
        if rnnlm is not None:
            rnnlm.cuda()

    # read json data
    with open(args.recog_json, 'rb') as f:
        js = json.load(f)['utts']
    new_js = {}

    load_inputs_and_targets = LoadInputsAndTargets(
        mode='asr',
        load_output=False,
        sort_in_input_length=False,
        preprocess_conf=train_args.preprocess_conf
        if args.preprocess_conf is None else args.preprocess_conf)

    if args.batchsize == 0:
        with torch.no_grad():
            for idx, name in enumerate(js.keys(), 1):
                # Pass the name as a logging argument: concatenating it into
                # the format string breaks on names that contain '%'.
                logging.info('(%d/%d) decoding %s', idx, len(js), name)
                batch = [(name, js[name])]
                # Decoding is inference, so the transform runs with
                # train=False, consistent with the batched branch below.
                with using_transform_config({'train': False}):
                    feat = load_inputs_and_targets(batch)[0][0]
                nbest_hyps = model.recognize(feat, args, train_args.char_list,
                                             rnnlm)
                new_js[name] = add_results_to_json(js[name], nbest_hyps,
                                                   train_args.char_list)
    else:
        try:
            from itertools import zip_longest as zip_longest
        except ImportError:
            # Python 2 fallback
            from itertools import izip_longest as zip_longest

        def grouper(n, iterable, fillvalue=None):
            # n references to the same iterator yield consecutive n-sized
            # chunks; the last chunk is padded with fillvalue.
            kargs = [iter(iterable)] * n
            return zip_longest(*kargs, fillvalue=fillvalue)

        # sort data (longest input first)
        keys = list(js.keys())
        feat_lens = [js[key]['input'][0]['shape'][0] for key in keys]
        sorted_index = sorted(range(len(feat_lens)),
                              key=lambda i: -feat_lens[i])
        keys = [keys[i] for i in sorted_index]

        with torch.no_grad():
            for names in grouper(args.batchsize, keys, None):
                # drop the None padding added by grouper to the last chunk
                names = [name for name in names if name is not None]
                batch = [(name, js[name]) for name in names]
                with using_transform_config({'train': False}):
                    feats = load_inputs_and_targets(batch)[0]
                nbest_hyps = model.recognize_batch(feats,
                                                   args,
                                                   train_args.char_list,
                                                   rnnlm=rnnlm)
                for i, nbest_hyp in enumerate(nbest_hyps):
                    name = names[i]
                    new_js[name] = add_results_to_json(js[name], nbest_hyp,
                                                       train_args.char_list)

    # TODO(watanabe) fix character coding problems when saving it
    with open(args.result_label, 'wb') as f:
        f.write(
            json.dumps({
                'utts': new_js
            }, indent=4, sort_keys=True).encode('utf_8'))
Пример #6
0
def train(args):
    """Train with the given args

    Builds an ``E2E`` model whose dimensions are read from
    ``args.valid_json``, writes ``model.json`` into ``args.outdir``, sets up
    the optimizer, data iterators and trainer extensions, and runs the
    training loop.

    :param Namespace args: The program arguments
    :raises NotImplementedError: if ``args.opt`` is neither ``'adadelta'``
        nor ``'adam'``
    """
    set_deterministic_pytorch(args)

    # check cuda availability
    if not torch.cuda.is_available():
        logging.warning('cuda is not available')

    # get input and output dimension info
    with open(args.valid_json, 'rb') as f:
        valid_json = json.load(f)['utts']
    utts = list(valid_json.keys())
    idim = int(valid_json[utts[0]]['input'][0]['shape'][1])
    odim = int(valid_json[utts[0]]['output'][0]['shape'][1])
    logging.info('#input dims : ' + str(idim))
    logging.info('#output dims: ' + str(odim))

    # specify attention, CTC, hybrid mode
    if args.mtlalpha == 1.0:
        mtl_mode = 'ctc'
        logging.info('Pure CTC mode')
    elif args.mtlalpha == 0.0:
        mtl_mode = 'att'
        logging.info('Pure attention mode')
    else:
        mtl_mode = 'mtl'
        logging.info('Multitask learning mode')

    # specify model architecture
    model = E2E(idim, odim, args)
    subsampling_factor = model.subsample[0]

    if args.rnnlm is not None:
        rnnlm_args = get_model_conf(args.rnnlm, args.rnnlm_conf)
        rnnlm = lm_pytorch.ClassifierWithState(
            lm_pytorch.RNNLM(len(args.char_list), rnnlm_args.layer,
                             rnnlm_args.unit))
        # Use the project helper that loads the snapshot INTO the module.
        # torch.load(path, rnnlm) would treat rnnlm as map_location and
        # leave the module's weights untouched.
        torch_load(args.rnnlm, rnnlm)
        model.rnnlm = rnnlm

    # write model config
    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)
    model_conf = args.outdir + '/model.json'
    with open(model_conf, 'wb') as f:
        logging.info('writing a model config file to ' + model_conf)
        f.write(
            json.dumps((idim, odim, vars(args)), indent=4,
                       sort_keys=True).encode('utf_8'))
    for key in sorted(vars(args).keys()):
        logging.info('ARGS: ' + key + ': ' + str(vars(args)[key]))

    # keep a handle on the reporter before the model may get wrapped below
    reporter = model.reporter

    # check the use of multi-gpu
    if args.ngpu > 1:
        model = torch.nn.DataParallel(model, device_ids=list(range(args.ngpu)))
        logging.info('batch size is automatically increased (%d -> %d)' %
                     (args.batch_size, args.batch_size * args.ngpu))
        args.batch_size *= args.ngpu

    # set torch device
    device = torch.device("cuda" if args.ngpu > 0 else "cpu")
    model = model.to(device)

    # Setup an optimizer
    if args.opt == 'adadelta':
        optimizer = torch.optim.Adadelta(model.parameters(),
                                         rho=0.95,
                                         eps=args.eps,
                                         weight_decay=args.weight_decay)
    elif args.opt == 'adam':
        optimizer = torch.optim.Adam(model.parameters(),
                                     weight_decay=args.weight_decay)
    else:
        # Fail here with a clear message instead of an UnboundLocalError
        # on the first use of `optimizer` below.
        raise NotImplementedError('unknown optimizer: ' + str(args.opt))

    # FIXME: TOO DIRTY HACK
    setattr(optimizer, "target", reporter)
    setattr(optimizer, "serialize", lambda s: reporter.serialize(s))

    # Setup a converter
    converter = CustomConverter(subsampling_factor=subsampling_factor,
                                preprocess_conf=args.preprocess_conf)

    # read json data
    with open(args.train_json, 'rb') as f:
        train_json = json.load(f)['utts']
    with open(args.valid_json, 'rb') as f:
        valid_json = json.load(f)['utts']

    # sortagrad is active when set to -1 or to a positive epoch count
    use_sortagrad = args.sortagrad == -1 or args.sortagrad > 0
    # make minibatch list (variable length)
    train = make_batchset(train_json,
                          args.batch_size,
                          args.maxlen_in,
                          args.maxlen_out,
                          args.minibatches,
                          min_batch_size=args.ngpu if args.ngpu > 1 else 1,
                          shortest_first=use_sortagrad)
    valid = make_batchset(valid_json,
                          args.batch_size,
                          args.maxlen_in,
                          args.maxlen_out,
                          args.minibatches,
                          min_batch_size=args.ngpu if args.ngpu > 1 else 1)
    # hack to make batchsize argument as 1
    # actual batchsize is included in a list
    if args.n_iter_processes > 0:
        train_iter = ToggleableShufflingMultiprocessIterator(
            TransformDataset(train, converter.transform),
            batch_size=1,
            n_processes=args.n_iter_processes,
            n_prefetch=8,
            maxtasksperchild=20,
            shuffle=not use_sortagrad)
        valid_iter = ToggleableShufflingMultiprocessIterator(
            TransformDataset(valid, converter.transform),
            batch_size=1,
            repeat=False,
            shuffle=False,
            n_processes=args.n_iter_processes,
            n_prefetch=8,
            maxtasksperchild=20)
    else:
        train_iter = ToggleableShufflingSerialIterator(
            TransformDataset(train, converter.transform),
            batch_size=1,
            shuffle=not use_sortagrad)
        valid_iter = ToggleableShufflingSerialIterator(
            TransformDataset(valid, converter.transform),
            batch_size=1,
            repeat=False,
            shuffle=False)

    # Set up a trainer
    updater = CustomUpdater(model, args.grad_clip, train_iter, optimizer,
                            converter, device, args.ngpu)
    trainer = training.Trainer(updater, (args.epochs, 'epoch'),
                               out=args.outdir)

    if use_sortagrad:
        trainer.extend(
            ShufflingEnabler([train_iter]),
            trigger=(args.sortagrad if args.sortagrad != -1 else args.epochs,
                     'epoch'))

    # Resume from a snapshot
    if args.resume:
        logging.info('resumed from %s' % args.resume)
        torch_resume(args.resume, trainer)

    # Evaluate the model with the test dataset for each epoch
    trainer.extend(
        CustomEvaluator(model, valid_iter, reporter, converter, device))

    # Save attention weight each epoch
    if args.num_save_attention > 0 and args.mtlalpha != 1.0:
        data = sorted(list(valid_json.items())[:args.num_save_attention],
                      key=lambda x: int(x[1]['input'][0]['shape'][1]),
                      reverse=True)
        # DataParallel hides the real model behind `.module`
        if hasattr(model, "module"):
            att_vis_fn = model.module.calculate_all_attentions
        else:
            att_vis_fn = model.calculate_all_attentions
        att_reporter = PlotAttentionReport(att_vis_fn,
                                           data,
                                           args.outdir + "/att_ws",
                                           converter=converter,
                                           device=device)
        trainer.extend(att_reporter, trigger=(1, 'epoch'))
    else:
        att_reporter = None

    # Make a plot for training and validation values
    trainer.extend(
        extensions.PlotReport([
            'main/loss', 'validation/main/loss', 'main/loss_ctc',
            'validation/main/loss_ctc', 'main/loss_att',
            'validation/main/loss_att'
        ],
                              'epoch',
                              file_name='loss.png'))
    trainer.extend(
        extensions.PlotReport(['main/acc', 'validation/main/acc'],
                              'epoch',
                              file_name='acc.png'))
    trainer.extend(
        extensions.PlotReport(['main/cer_ctc', 'validation/main/cer_ctc'],
                              'epoch',
                              file_name='cer.png'))

    # Save best models
    trainer.extend(
        extensions.snapshot_object(model,
                                   'model.loss.best',
                                   savefun=torch_save),
        trigger=training.triggers.MinValueTrigger('validation/main/loss'))
    # Compare by value: `is not` on a string literal tests object identity.
    if mtl_mode != 'ctc':
        trainer.extend(
            extensions.snapshot_object(model,
                                       'model.acc.best',
                                       savefun=torch_save),
            trigger=training.triggers.MaxValueTrigger('validation/main/acc'))

    # save snapshot which contains model and optimizer states
    trainer.extend(torch_snapshot(), trigger=(1, 'epoch'))

    # epsilon decay in the optimizer
    if args.opt == 'adadelta':
        if args.criterion == 'acc' and mtl_mode != 'ctc':
            trainer.extend(restore_snapshot(model,
                                            args.outdir + '/model.acc.best',
                                            load_fn=torch_load),
                           trigger=CompareValueTrigger(
                               'validation/main/acc', lambda best_value,
                               current_value: best_value > current_value))
            trainer.extend(adadelta_eps_decay(args.eps_decay),
                           trigger=CompareValueTrigger(
                               'validation/main/acc', lambda best_value,
                               current_value: best_value > current_value))
        elif args.criterion == 'loss':
            trainer.extend(restore_snapshot(model,
                                            args.outdir + '/model.loss.best',
                                            load_fn=torch_load),
                           trigger=CompareValueTrigger(
                               'validation/main/loss', lambda best_value,
                               current_value: best_value < current_value))
            trainer.extend(adadelta_eps_decay(args.eps_decay),
                           trigger=CompareValueTrigger(
                               'validation/main/loss', lambda best_value,
                               current_value: best_value < current_value))

    # Write a log of evaluation statistics for each epoch
    trainer.extend(extensions.LogReport(trigger=(REPORT_INTERVAL,
                                                 'iteration')))
    report_keys = [
        'epoch', 'iteration', 'main/loss', 'main/loss_ctc', 'main/loss_att',
        'validation/main/loss', 'validation/main/loss_ctc',
        'validation/main/loss_att', 'main/acc', 'validation/main/acc',
        'main/cer_ctc', 'validation/main/cer_ctc', 'elapsed_time'
    ]
    if args.opt == 'adadelta':
        trainer.extend(extensions.observe_value(
            'eps', lambda trainer: trainer.updater.get_optimizer('main').
            param_groups[0]["eps"]),
                       trigger=(REPORT_INTERVAL, 'iteration'))
        report_keys.append('eps')
    if args.report_cer:
        report_keys.append('validation/main/cer')
    if args.report_wer:
        report_keys.append('validation/main/wer')
    trainer.extend(extensions.PrintReport(report_keys),
                   trigger=(REPORT_INTERVAL, 'iteration'))

    trainer.extend(extensions.ProgressBar(update_interval=REPORT_INTERVAL))
    set_early_stop(trainer, args)

    if args.tensorboard_dir is not None and args.tensorboard_dir != "":
        writer = SummaryWriter(log_dir=args.tensorboard_dir)
        trainer.extend(TensorboardLogger(writer, att_reporter))
    # Run the training
    trainer.run()
    check_early_stop(trainer, args.epochs)
Пример #7
0
 def add_arguments(parser):
     """Run the conformer-encoder, attention and decoder argument hooks.

     Each hook receives ``parser``; the same ``parser`` object is returned.
     """
     hooks = (
         E2E.add_conformer_encoder_arguments,
         E2ELas.attention_add_arguments,
         E2ELas.decoder_add_arguments,
     )
     # Hooks are invoked in the listed order; their return values are ignored.
     for hook in hooks:
         hook(parser)
     return parser
Пример #8
0
 def add_arguments(parser):
     """Forward ``parser`` to ``E2E_pytorch.add_arguments`` and return its result."""
     extended_parser = E2E_pytorch.add_arguments(parser)
     return extended_parser
Пример #9
0
"""## 4.2 Load pretrained model

Let's load it from Python.
"""

import json
import torch
from espnet.nets.pytorch_backend.e2e_asr import E2E

model_dir = "espnet/egs/an4/asr1/exp/train_nodev_pytorch_train_mtlalpha0.5/results"

# load model
# model.json stores the input/output dimensions and the argument dict.
with open(model_dir + "/model.json", "r") as f:
    idim, odim, conf = json.load(f)
model = E2E.build(idim, odim, **conf)
# map_location="cpu" lets a checkpoint that was saved from a GPU be restored
# on a CPU-only machine; the model is kept on the CPU below anyway.
model.load_state_dict(
    torch.load(model_dir + "/model.acc.best", map_location="cpu"))
model.cpu().eval()
vocab = conf["char_list"]
print(vocab)
model

"""## 4.3 Recognize the speech by the model

You can perform joint decoding with all the models (S2S, CTC, LM, etc.) in ESPnet.
"""

import re
from espnet.nets.beam_search import BeamSearch

# Take the entry at index 10 of test_json as a (key, info) pair.
key, info = list(test_json.items())[10]
Пример #10
0
def test_pytorch_trainable_and_transferable(model_type, finetune_dic):
    """Run one training step, save a model or LM, reload it and decode once.

    The model (or an LM when ``finetune_dic["use_lm"]`` is set) is saved into
    a private temporary directory together with a dummy ``model.json``, then
    ``load_trained_modules`` rebuilds a model from the fine-tuning arguments.
    The rebuilt model is checked with one forward/backward pass and one
    recognition call.

    :param str model_type: ``"rnn"`` selects the attention E2E model and
        ``get_rnn_args``; any other value selects the transducer E2E model
        and ``get_rnnt_args``
    :param dict finetune_dic: keyword arguments for the argument factory;
        non-``None`` ``enc_init`` / ``dec_init`` entries are overwritten in
        place with the path of the saved file
    """
    import shutil

    idim, odim, ilens, olens = get_default_scope_inputs()

    if model_type == "rnn":
        from espnet.nets.pytorch_backend.e2e_asr import E2E

        arg_function = get_rnn_args
    else:
        from espnet.nets.pytorch_backend.e2e_asr_transducer import E2E

        arg_function = get_rnnt_args

    args = arg_function()

    model = E2E(idim, odim, args)

    batch = pytorch_prepare_inputs(idim, odim, ilens, olens)

    loss = model(*batch)
    loss.backward()

    # NOTE(review): nothing below writes into .pytest_cache; the directory
    # creation is kept in case other code relies on it - confirm.
    if not os.path.exists(".pytest_cache"):
        os.makedirs(".pytest_cache")

    # A private directory replaces the deprecated, race-prone
    # tempfile.mktemp() and keeps model.json out of the shared temp dir,
    # where concurrent test runs would overwrite each other's file.
    tmpdir = tempfile.mkdtemp()
    try:
        tmppath = os.path.join(tmpdir, "model")

        if finetune_dic["use_lm"] is not None:
            lm = get_lm(args.dlayers, args.dunits, args.char_list)
            tmppath += "_rnnlm"

            torch_save(tmppath, lm)
        else:
            torch_save(tmppath, model)

        if finetune_dic["enc_init"] is not None:
            finetune_dic["enc_init"] = tmppath
        if finetune_dic["dec_init"] is not None:
            finetune_dic["dec_init"] = tmppath

        finetune_args = arg_function(**finetune_dic)

        # create dummy model.json for saved model to go through
        # get_model_conf(...) called in load_trained_modules method.
        model_conf = os.path.dirname(tmppath) + "/model.json"
        with open(model_conf, "wb") as f:
            f.write(
                json.dumps(
                    (idim, odim, vars(finetune_args)),
                    indent=4,
                    ensure_ascii=False,
                    sort_keys=True,
                ).encode("utf_8"))

        model = load_trained_modules(idim, odim, finetune_args)
    finally:
        # The saved files are only needed while the modules are loaded.
        shutil.rmtree(tmpdir, ignore_errors=True)

    loss = model(*batch)
    loss.backward()

    if model_type == "rnnt":
        beam_search = BeamSearchTransducer(
            decoder=model.dec,
            joint_network=model.joint_network,
            beam_size=1,
            lm=None,
            lm_weight=0.0,
            search_type="default",
            max_sym_exp=2,
            u_max=10,
            nstep=1,
            prefix_alpha=1,
            score_norm=False,
        )

        with torch.no_grad():
            in_data = np.random.randn(10, idim)
            model.recognize(in_data, beam_search)
    else:
        with torch.no_grad():
            in_data = np.random.randn(10, idim)
            model.recognize(in_data, args, args.char_list)
Пример #11
0
# Sanity check: the first element of devset must hold exactly batch_size items.
assert len(devset[0]) == batch_size
# Bare expression with no effect in a script; presumably a notebook cell
# preview of the first three items.
devset[0][:3]

"""### Build neural networks (3/4)

For simplicity, we use a predefined model: [Transformer](https://papers.nips.cc/paper/7181-attention-is-all-you-need.pdf). 

NOTE: You can also use your custom model in command line tools as `asr_train.py --model-module your_module:YourModel`
"""

import argparse
from espnet.bin.asr_train import get_parser
from espnet.nets.pytorch_backend.e2e_asr import E2E

# Start from the stock asr_train parser and let E2E add its own options.
parser = E2E.add_arguments(get_parser())
config = parser.parse_args([
    "--mtlalpha", "0.0",  # weight for cross entropy and CTC loss
    "--outdir", "out", "--dict", ""])  # TODO: allow no arg

# The second entry of each "shape" list is used as the model dimension.
idim = info["input"][0]["shape"][1]
odim = info["output"][0]["shape"][1]
config.char_list = []
model = E2E(idim, odim, config)
model

"""### Update neural networks by iterating datasets (4/4)

Finally, we get to the training part.
"""