Example No. 1
def test_lm():
    n_vocab = 3
    n_layers = 2
    n_units = 2
    batchsize = 5
    for typ in ["lstm"]:  # TODO(anyone) gru
        rnnlm_ch = lm_chainer.ClassifierWithState(
            lm_chainer.RNNLM(n_vocab, n_layers, n_units, typ=typ)
        )
        rnnlm_th = lm_pytorch.ClassifierWithState(
            lm_pytorch.RNNLM(n_vocab, n_layers, n_units, typ=typ)
        )
        transfer_lm(rnnlm_ch.predictor, rnnlm_th.predictor)

        # test prediction equality
        x = torch.from_numpy(numpy.random.randint(n_vocab, size=batchsize)).long()
        with torch.no_grad(), chainer.no_backprop_mode(), chainer.using_config(
            "train", False
        ):
            rnnlm_th.predictor.eval()
            state_th, y_th = rnnlm_th.predictor(None, x)
            state_ch, y_ch = rnnlm_ch.predictor(None, x.data.numpy())
            for k in state_ch.keys():
                for n in range(len(state_th[k])):
                    print(k, n)
                    print(state_th[k][n].data.numpy())
                    print(state_ch[k][n].data)
                    numpy.testing.assert_allclose(
                        state_th[k][n].data.numpy(), state_ch[k][n].data, rtol=1e-5
                    )
            numpy.testing.assert_allclose(y_th.data.numpy(), y_ch.data, rtol=1e-5)
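
The predictor exercised above returns a (state, logits) pair and accepts None as the initial state. As a usage illustration, here is a minimal sketch of scoring a token sequence step by step with that interface; score_sequence is a hypothetical helper, and treating y as unnormalized logits is an assumption (the test itself only checks that the two backends agree):

import torch
import torch.nn.functional as F

def score_sequence(predictor, tokens):
    # Hypothetical helper: accumulate the log-probability of `tokens`
    # under the RNNLM, feeding the returned state back in at each step.
    state = None
    logp = 0.0
    for cur, nxt in zip(tokens[:-1], tokens[1:]):
        x = torch.tensor([cur], dtype=torch.long)
        state, y = predictor(state, x)  # y assumed to be raw logits
        logp += F.log_softmax(y, dim=-1)[0, nxt].item()
    return logp

# e.g. score_sequence(rnnlm_th.predictor, [0, 1, 2]) for the model above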
Example No. 2
def test_recognition_results_with_lm(etype, dtype, m_str, text_idx1):
    const = 1e-4
    numpy.random.seed(1)
    seq_true_texts = [
        ["o", "iuiuiuiuiuiuiuiuo", "iuiuiuiuiuiuiuiuo"],
        ["o", "o", "ieieieieieieieieo"],
        ["o", "iuiuiuiuiuiuiuiuo", "iuiuiuiuiuiuiuiuo"],
        ["o", "o", "ieieieieieieieieo"],
        ["o", "iuiuiuiuiuiuiuiuo", "iuiuiuiuiuiuiuiuo"],
        ["o", "o", "ieieieieieieieieo"],
        ["o", "iuiuiuiuiuiuiuiuo", "iuiuiuiuiuiuiuiuo"],
        ["o", "o", "ieieieieieieieieo"],
    ]

    # ctc_weight: 0.0 (attention), 0.5 (hybrid CTC/attention), 1.0 (CTC)
    for text_idx2, ctc_weight in enumerate([0.0, 0.5, 1.0]):
        seq_true_text = seq_true_texts[text_idx1][text_idx2]

        args = make_arg(
            etype=etype, rnnlm="dummy", ctc_weight=ctc_weight, lm_weight=0.3
        )
        m = importlib.import_module(m_str)
        model = m.E2E(40, 5, args)

        if "pytorch" in m_str:
            rnnlm = lm_pytorch.ClassifierWithState(
                lm_pytorch.RNNLM(len(args.char_list), 2, 10)
            )
            init_torch_weight_const(model, const)
            init_torch_weight_const(rnnlm, const)
        else:
            rnnlm = lm_chainer.ClassifierWithState(
                lm_chainer.RNNLM(len(args.char_list), 2, 10)
            )
            init_chainer_weight_const(model, const)
            init_chainer_weight_const(rnnlm, const)

        data = [
            (
                "aaa",
                dict(
                    feat=numpy.random.randn(100, 40).astype(numpy.float32),
                    token=seq_true_text,
                ),
            )
        ]

        in_data = data[0][1]["feat"]
        nbest_hyps = model.recognize(in_data, args, args.char_list, rnnlm)
        y_hat = nbest_hyps[0]["yseq"][1:]
        seq_hat = [args.char_list[int(idx)] for idx in y_hat]
        seq_hat_text = "".join(seq_hat).replace("<space>", " ")
        seq_true_text = data[0][1]["token"]

        assert seq_hat_text == seq_true_text
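
init_torch_weight_const and init_chainer_weight_const are helpers defined elsewhere in this test suite. A plausible minimal sketch (an assumption, not the actual implementation) simply overwrites every learnable parameter with the constant, which makes both backends decode deterministically and comparably:

def init_torch_weight_const(model, const):
    # sketch (assumed behavior): fill every pytorch parameter with `const`
    for p in model.parameters():
        p.data.fill_(const)

def init_chainer_weight_const(model, const):
    # sketch (assumed behavior): fill every chainer parameter with `const`
    for p in model.params():
        p.data[:] = const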
Example No. 3
def test_lm():
    n_vocab = 3
    n_layers = 2
    n_units = 2
    batchsize = 5
    for typ in ["lstm"]:  # TODO(anyone) gru
        rnnlm_ch = lm_chainer.ClassifierWithState(lm_chainer.RNNLM(n_vocab, n_layers, n_units, typ=typ))
        rnnlm_th = lm_pytorch.ClassifierWithState(lm_pytorch.RNNLM(n_vocab, n_layers, n_units, typ=typ))
        transfer_lm(rnnlm_ch.predictor, rnnlm_th.predictor)
        import numpy
        # TODO(karita) implement weight transfer
        # numpy.testing.assert_equal(rnnlm_ch.predictor.embed.W.data, rnnlm_th.predictor.embed.weight.data.numpy())
        # numpy.testing.assert_equal(rnnlm_ch.predictor.l1.upward.b.data, rnnlm_th.predictor.l1.bias_ih.data.numpy())
        # numpy.testing.assert_equal(rnnlm_ch.predictor.l1.upward.W.data, rnnlm_th.predictor.l1.weight_ih.data.numpy())
        # numpy.testing.assert_equal(rnnlm_ch.predictor.l1.lateral.W.data, rnnlm_th.predictor.l1.weight_hh.data.numpy())
        # numpy.testing.assert_equal(rnnlm_ch.predictor.l2.upward.b.data, rnnlm_th.predictor.l2.bias_ih.data.numpy())
        # numpy.testing.assert_equal(rnnlm_ch.predictor.l2.upward.W.data, rnnlm_th.predictor.l2.weight_ih.data.numpy())
        # numpy.testing.assert_equal(rnnlm_ch.predictor.l2.lateral.W.data, rnnlm_th.predictor.l2.weight_hh.data.numpy())
        # numpy.testing.assert_equal(rnnlm_ch.predictor.lo.b.data, rnnlm_th.predictor.lo.bias.data.numpy())
        # numpy.testing.assert_equal(rnnlm_ch.predictor.lo.W.data, rnnlm_th.predictor.lo.weight.data.numpy())

        # test prediction equality
        x = torch.from_numpy(numpy.random.randint(n_vocab, size=batchsize)).long()
        with torch.no_grad(), chainer.no_backprop_mode(), chainer.using_config('train', False):
            rnnlm_th.predictor.eval()
            state_th, y_th = rnnlm_th.predictor(None, x)
            state_ch, y_ch = rnnlm_ch.predictor(None, x.data.numpy())
            for k in state_ch.keys():
                for n in range(len(state_th[k])):
                    print(k, n)
                    print(state_th[k][n].data.numpy())
                    print(state_ch[k][n].data)
                    numpy.testing.assert_allclose(state_th[k][n].data.numpy(), state_ch[k][n].data, rtol=1e-5)
            print("y")
            print(y_th.data.numpy())
            print(y_ch.data)
            numpy.testing.assert_allclose(y_th.data.numpy(), y_ch.data, rtol=1e-5)
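
The commented-out assertions above spell out the parameter correspondence that transfer_lm is expected to establish. A minimal sketch along those lines, for the two-layer configuration used in this test, assuming matching shapes and ignoring any gate-order permutation between the chainer and pytorch LSTM layouts (a real transfer may need to reorder gates):

def transfer_lm(ch, th):
    # Sketch only: copy chainer parameters into the pytorch predictor,
    # following the mapping documented in the commented assertions above.
    th.embed.weight.data = torch.from_numpy(ch.embed.W.data)
    for ch_l, th_l in [(ch.l1, th.l1), (ch.l2, th.l2)]:
        th_l.bias_ih.data = torch.from_numpy(ch_l.upward.b.data)
        th_l.weight_ih.data = torch.from_numpy(ch_l.upward.W.data)
        th_l.weight_hh.data = torch.from_numpy(ch_l.lateral.W.data)
    th.lo.bias.data = torch.from_numpy(ch.lo.b.data)
    th.lo.weight.data = torch.from_numpy(ch.lo.W.data)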
Example No. 4
def recog(args):
    """Decode with the given args.

    Args:
        args (namespace): The program arguments.

    """
    # display chainer version
    logging.info('chainer version = ' + chainer.__version__)

    set_deterministic_chainer(args)

    # read training config
    idim, odim, train_args = get_model_conf(args.model, args.model_conf)

    for key in sorted(vars(args).keys()):
        logging.info('ARGS: ' + key + ': ' + str(vars(args)[key]))

    # specify model architecture
    logging.info('reading model parameters from ' + args.model)
    # To be compatible with v.0.3.0 models
    if hasattr(train_args, "model_module"):
        model_module = train_args.model_module
    else:
        model_module = "espnet.nets.chainer_backend.e2e_asr:E2E"
    model_class = dynamic_import(model_module)
    model = model_class(idim, odim, train_args)
    assert isinstance(model, ASRInterface)
    chainer_load(args.model, model)

    # read rnnlm
    if args.rnnlm:
        rnnlm_args = get_model_conf(args.rnnlm, args.rnnlm_conf)
        rnnlm = lm_chainer.ClassifierWithState(
            lm_chainer.RNNLM(len(train_args.char_list), rnnlm_args.layer,
                             rnnlm_args.unit))
        chainer_load(args.rnnlm, rnnlm)
    else:
        rnnlm = None

    if args.word_rnnlm:
        rnnlm_args = get_model_conf(args.word_rnnlm, args.word_rnnlm_conf)
        word_dict = rnnlm_args.char_list_dict
        char_dict = {x: i for i, x in enumerate(train_args.char_list)}
        word_rnnlm = lm_chainer.ClassifierWithState(
            lm_chainer.RNNLM(len(word_dict), rnnlm_args.layer,
                             rnnlm_args.unit))
        chainer_load(args.word_rnnlm, word_rnnlm)

        if rnnlm is not None:
            rnnlm = lm_chainer.ClassifierWithState(
                extlm_chainer.MultiLevelLM(word_rnnlm.predictor,
                                           rnnlm.predictor, word_dict,
                                           char_dict))
        else:
            rnnlm = lm_chainer.ClassifierWithState(
                extlm_chainer.LookAheadWordLM(word_rnnlm.predictor, word_dict,
                                              char_dict))

    # read json data
    with open(args.recog_json, 'rb') as f:
        js = json.load(f)['utts']

    load_inputs_and_targets = LoadInputsAndTargets(
        mode='asr',
        load_output=False,
        sort_in_input_length=False,
        preprocess_conf=train_args.preprocess_conf
        if args.preprocess_conf is None else args.preprocess_conf,
        preprocess_args={'train': False}  # Switch the mode of preprocessing
    )

    # decode each utterance
    new_js = {}
    with chainer.no_backprop_mode():
        for idx, name in enumerate(js.keys(), 1):
            logging.info('(%d/%d) decoding %s', idx, len(js.keys()), name)
            batch = [(name, js[name])]
            feat = load_inputs_and_targets(batch)[0][0]
            nbest_hyps = model.recognize(feat, args, train_args.char_list,
                                         rnnlm)
            new_js[name] = add_results_to_json(js[name], nbest_hyps,
                                               train_args.char_list)

    with open(args.result_label, 'wb') as f:
        f.write(
            json.dumps({
                'utts': new_js
            },
                       indent=4,
                       ensure_ascii=False,
                       sort_keys=True).encode('utf_8'))
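
The result file written above is plain UTF-8 JSON with all hypotheses under an 'utts' key, so reading it back is symmetric (the path below is a placeholder):

import json

with open('exp/results/result.json', 'rb') as f:  # placeholder path
    utts = json.loads(f.read().decode('utf_8'))['utts']
for name, info in sorted(utts.items()):
    print(name, info)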
Example No. 5
def recog(args):
    """Decode with the given args

    :param Namespace args: The program arguments
    """
    # display chainer version
    logging.info('chainer version = ' + chainer.__version__)

    set_deterministic_chainer(args)

    # read training config
    idim, odim, train_args = get_model_conf(args.model, args.model_conf)

    for key in sorted(vars(args).keys()):
        logging.info('ARGS: ' + key + ': ' + str(vars(args)[key]))

    # specify model architecture
    logging.info('reading model parameters from ' + args.model)
    model = E2E(idim, odim, train_args)
    chainer_load(args.model, model)

    # read rnnlm
    if args.rnnlm:
        rnnlm_args = get_model_conf(args.rnnlm, args.rnnlm_conf)
        rnnlm = lm_chainer.ClassifierWithState(
            lm_chainer.RNNLM(len(train_args.char_list), rnnlm_args.layer,
                             rnnlm_args.unit))
        chainer_load(args.rnnlm, rnnlm)
    else:
        rnnlm = None

    if args.word_rnnlm:
        rnnlm_args = get_model_conf(args.word_rnnlm, args.word_rnnlm_conf)
        word_dict = rnnlm_args.char_list_dict
        char_dict = {x: i for i, x in enumerate(train_args.char_list)}
        word_rnnlm = lm_chainer.ClassifierWithState(
            lm_chainer.RNNLM(len(word_dict), rnnlm_args.layer,
                             rnnlm_args.unit))
        chainer_load(args.word_rnnlm, word_rnnlm)

        if rnnlm is not None:
            rnnlm = lm_chainer.ClassifierWithState(
                extlm_chainer.MultiLevelLM(word_rnnlm.predictor,
                                           rnnlm.predictor, word_dict,
                                           char_dict))
        else:
            rnnlm = lm_chainer.ClassifierWithState(
                extlm_chainer.LookAheadWordLM(word_rnnlm.predictor, word_dict,
                                              char_dict))

    # read json data
    with open(args.recog_json, 'rb') as f:
        js = json.load(f)['utts']

    load_inputs_and_targets = LoadInputsAndTargets(
        mode='asr',
        load_output=False,
        sort_in_input_length=False,
        preprocess_conf=train_args.preprocess_conf
        if args.preprocess_conf is None else args.preprocess_conf)

    # decode each utterance
    new_js = {}
    with chainer.no_backprop_mode():
        for idx, name in enumerate(js.keys(), 1):
            logging.info('(%d/%d) decoding %s', idx, len(js.keys()), name)
            batch = [(name, js[name])]
            with using_transform_config({'train': False}):
                feat = load_inputs_and_targets(batch)[0][0]
            nbest_hyps = model.recognize(feat, args, train_args.char_list,
                                         rnnlm)
            new_js[name] = add_results_to_json(js[name], nbest_hyps,
                                               train_args.char_list)

    # TODO(watanabe) fix character coding problems when saving it
    with open(args.result_label, 'wb') as f:
        f.write(
            json.dumps({
                'utts': new_js
            }, indent=4, sort_keys=True).encode('utf_8'))
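
As a usage illustration, a hypothetical invocation sketch follows. The Namespace lists only the attributes that recog() dereferences directly, all values are placeholders, and model.recognize (and possibly set_deterministic_chainer) consume further decoding options that are omitted here:

from argparse import Namespace

args = Namespace(
    model='exp/results/model.acc.best',      # placeholder snapshot path
    model_conf=None,                         # assumed to be resolved beside args.model
    rnnlm=None,                              # optional character-level RNNLM snapshot
    word_rnnlm=None,                         # optional word-level RNNLM snapshot
    preprocess_conf=None,                    # None -> fall back to the training config's
    recog_json='dump/test/data.json',        # placeholder: utterances to decode
    result_label='exp/results/result.json',  # placeholder: output file
)
recog(args)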