Пример #1
0
def test_model_trainable_and_decodable(module, num_encs, model_dict):
    """Check that a multi-encoder E2E model trains, plots attention and decodes.

    Args:
        module: Import path of the module that defines the ``E2E`` class.
        num_encs: Number of encoders / input streams the model is built with.
        model_dict: Keyword overrides forwarded to ``make_arg``.

    The test runs one forward/backward pass, plots the attention weights of
    every encoder plus the ``han_mode`` weights into a temporary directory,
    and calls ``recognize`` (and ``recognize_batch`` for pytorch modules) on
    random inputs.
    """
    # Input dimension of every encoder stream and output dimension; shared by
    # the model constructor, the dummy JSON and the batch conversion below.
    # NOTE(review): prepare_inputs() is assumed to produce idim=40 features
    # as well — TODO confirm if these values are ever changed.
    idim, odim = 40, 5

    args = make_arg(num_encs=num_encs, **model_dict)
    batch = prepare_inputs("pytorch", num_encs)

    # test trainable
    m = importlib.import_module(module)
    model = m.E2E([idim] * num_encs, odim, args)
    loss = model(*batch)
    loss.backward()  # trainable

    # test attention plot
    dummy_json = make_dummy_json(num_encs, [10, 20], [10, 20],
                                 idim=idim, odim=odim, num_inputs=num_encs)
    batchset = make_batchset(dummy_json, 2, 2 ** 10, 2 ** 10, shortest_first=True)
    att_ws = model.calculate_all_attentions(*convert_batch(
        batchset[0], "pytorch", idim=idim, odim=odim, num_inputs=num_encs))
    from espnet.asr.asr_utils import PlotAttentionReport

    # TemporaryDirectory deletes the plots when the block exits; a bare
    # mkdtemp() would leave a new directory behind on every test run.
    with tempfile.TemporaryDirectory() as tmpdir:
        plot = PlotAttentionReport(model.calculate_all_attentions, batchset[0],
                                   tmpdir, None, None, None)
        # att_ws[0 .. num_encs-1] are plotted per encoder; att_ws[num_encs]
        # is plotted separately with han_mode=True.
        for i in range(num_encs):
            # att-encoder
            att_w = plot.get_attention_weight(0, att_ws[i][0])
            plot._plot_and_save_attention(att_w, '{}/att{}.png'.format(tmpdir, i))
        # han
        att_w = plot.get_attention_weight(0, att_ws[num_encs][0])
        plot._plot_and_save_attention(att_w, '{}/han.png'.format(tmpdir), han_mode=True)

    # test decodable
    with torch.no_grad(), chainer.no_backprop_mode():
        in_data = [np.random.randn(10, idim) for _ in range(num_encs)]
        model.recognize(in_data, args, args.char_list)  # decodable
        # Batch decoding is only exercised for pytorch modules.
        if "pytorch" in module:
            batch_in_data = [[np.random.randn(10, idim), np.random.randn(5, idim)]
                             for _ in range(num_encs)]
            model.recognize_batch(batch_in_data, args, args.char_list)  # batch decodable
Пример #2
0
def test_calculate_plot_attention_ctc(module, num_encs, model_dict):
    """Check attention and CTC-probability plotting for a multi-encoder model.

    Args:
        module: Import path of the module that defines the ``E2E`` class.
        num_encs: Number of encoders / input streams the model is built with.
        model_dict: Keyword overrides forwarded to ``make_arg``.

    Attention weights (per encoder plus the ``han_mode`` entry) are plotted
    with ``PlotAttentionReport``. Per-encoder CTC probabilities are plotted
    with ``PlotCTCReport`` when ``args.mtlalpha > 0``. Each report writes
    into its own self-cleaning temporary directory.
    """
    # Tiny input/output dimensions shared by the model and the dummy data.
    idim, odim = 2, 2

    args = make_arg(num_encs=num_encs, **model_dict)
    m = importlib.import_module(module)
    model = m.E2E([idim] * num_encs, odim, args)

    # test attention plot
    dummy_json = make_dummy_json(num_encs, [2, 3], [2, 3],
                                 idim=idim,
                                 odim=odim,
                                 num_inputs=num_encs)
    batchset = make_batchset(dummy_json, 2, 2**10, 2**10, shortest_first=True)
    att_ws = model.calculate_all_attentions(*convert_batch(
        batchset[0], "pytorch", idim=idim, odim=odim, num_inputs=num_encs))
    from espnet.asr.asr_utils import PlotAttentionReport

    # TemporaryDirectory removes the PNGs on exit; mkdtemp() would leak a
    # directory per run (previously two per run in this test).
    with tempfile.TemporaryDirectory() as tmpdir:
        plot = PlotAttentionReport(model.calculate_all_attentions, batchset[0],
                                   tmpdir, None, None, None)
        # att_ws[0 .. num_encs-1] are plotted per encoder; att_ws[num_encs]
        # is plotted separately with han_mode=True.
        for i in range(num_encs):
            # att-encoder
            att_w = plot.get_attention_weight(0, att_ws[i][0])
            plot._plot_and_save_attention(att_w, "{}/att{}.png".format(tmpdir, i))
        # han
        att_w = plot.get_attention_weight(0, att_ws[num_encs][0])
        plot._plot_and_save_attention(att_w,
                                      "{}/han.png".format(tmpdir),
                                      han_mode=True)

    # test CTC plot
    # calculate_all_ctc_probs is called unconditionally so it is exercised
    # even when mtlalpha == 0; only the plotting below is guarded.
    ctc_probs = model.calculate_all_ctc_probs(*convert_batch(
        batchset[0], "pytorch", idim=idim, odim=odim, num_inputs=num_encs))
    from espnet.asr.asr_utils import PlotCTCReport

    with tempfile.TemporaryDirectory() as tmpdir:
        plot = PlotCTCReport(model.calculate_all_ctc_probs, batchset[0], tmpdir,
                             None, None, None)
        # NOTE(review): mtlalpha is presumably the CTC loss weight; with 0
        # there may be no per-encoder CTC probabilities to index — confirm.
        if args.mtlalpha > 0:
            for i in range(num_encs):
                # ctc-encoder
                plot._plot_and_save_ctc(ctc_probs[i][0],
                                        "{}/ctc{}.png".format(tmpdir, i))