def test_model_trainable_and_decodable(module, num_encs, model_dict):
    args = make_arg(num_encs=num_encs, **model_dict)
    batch = prepare_inputs("pytorch", num_encs)

    # test trainable
    m = importlib.import_module(module)
    model = m.E2E([40 for _ in range(num_encs)], 5, args)
    loss = model(*batch)
    loss.backward()  # trainable

    # test attention plot
    dummy_json = make_dummy_json(
        num_encs, [10, 20], [10, 20], idim=40, odim=5, num_inputs=num_encs
    )
    batchset = make_batchset(dummy_json, 2, 2**10, 2**10, shortest_first=True)
    att_ws = model.calculate_all_attentions(
        *convert_batch(batchset[0], "pytorch", idim=40, odim=5, num_inputs=num_encs)
    )
    from espnet.asr.asr_utils import PlotAttentionReport

    tmpdir = tempfile.mkdtemp()
    plot = PlotAttentionReport(
        model.calculate_all_attentions, batchset[0], tmpdir, None, None, None
    )
    for i in range(num_encs):
        # attention weights of the i-th encoder
        att_w = plot.trim_attention_weight("utt_%d" % 0, att_ws[i][0])
        plot._plot_and_save_attention(att_w, "{}/att{}.png".format(tmpdir, i))
    # hierarchical attention network (HAN) over the encoder streams
    att_w = plot.trim_attention_weight("utt_%d" % 0, att_ws[num_encs][0])
    plot._plot_and_save_attention(att_w, "{}/han.png".format(tmpdir), han_mode=True)

    # test CTC plot
    ctc_probs = model.calculate_all_ctc_probs(
        *convert_batch(batchset[0], "pytorch", idim=40, odim=5, num_inputs=num_encs)
    )
    from espnet.asr.asr_utils import PlotCTCReport

    tmpdir = tempfile.mkdtemp()
    plot = PlotCTCReport(
        model.calculate_all_ctc_probs, batchset[0], tmpdir, None, None, None
    )
    if args.mtlalpha > 0:
        for i in range(num_encs):
            # CTC posteriors of the i-th encoder
            plot._plot_and_save_ctc(ctc_probs[i][0], "{}/ctc{}.png".format(tmpdir, i))

    # test decodable
    with torch.no_grad(), chainer.no_backprop_mode():
        in_data = [np.random.randn(10, 40) for _ in range(num_encs)]
        model.recognize(in_data, args, args.char_list)  # decodable
        if "pytorch" in module:
            batch_in_data = [
                [np.random.randn(10, 40), np.random.randn(5, 40)]
                for _ in range(num_encs)
            ]
            model.recognize_batch(
                batch_in_data, args, args.char_list
            )  # batch decodable
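
# For context: `prepare_inputs` and the other helpers used above are defined
# elsewhere in this test module. The sketch below is a hedged, illustrative
# stand-in showing the batch structure the trainability check assumes (one
# padded feature tensor and one length tensor per encoder, plus padded targets,
# matching a forward signature like E2E.forward(xs_pad_list, ilens_list,
# ys_pad)). The helper name `_prepare_inputs_sketch` and the exact shapes are
# assumptions, not the real helper.
def _prepare_inputs_sketch(num_encs, batch=2, maxlen_in=10, idim=40, maxlen_out=4, odim=5):
    """Hypothetical stand-in for prepare_inputs(), for illustration only."""
    # one padded feature tensor per encoder stream: (batch, frames, idim)
    xs_pad_list = [torch.randn(batch, maxlen_in, idim) for _ in range(num_encs)]
    # true (unpadded) frame counts for each utterance in the batch
    ilens_list = [torch.LongTensor([maxlen_in, maxlen_in - 2]) for _ in range(num_encs)]
    # random label ids in [0, odim) as padded targets
    ys_pad = torch.randint(0, odim, (batch, maxlen_out)).long()
    return xs_pad_list, ilens_list, ys_pad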
def test_calculate_plot_attention_ctc(module, num_encs, model_dict):
    args = make_arg(num_encs=num_encs, **model_dict)
    m = importlib.import_module(module)
    model = m.E2E([2 for _ in range(num_encs)], 2, args)

    # test attention plot
    dummy_json = make_dummy_json(
        num_encs, [2, 3], [2, 3], idim=2, odim=2, num_inputs=num_encs
    )
    batchset = make_batchset(dummy_json, 2, 2**10, 2**10, shortest_first=True)
    att_ws = model.calculate_all_attentions(
        *convert_batch(batchset[0], "pytorch", idim=2, odim=2, num_inputs=num_encs)
    )
    from espnet.asr.asr_utils import PlotAttentionReport

    tmpdir = tempfile.mkdtemp()
    plot = PlotAttentionReport(
        model.calculate_all_attentions, batchset[0], tmpdir, None, None, None
    )
    for i in range(num_encs):
        # attention weights of the i-th encoder
        att_w = plot.trim_attention_weight("utt_%d" % 0, att_ws[i][0])
        plot._plot_and_save_attention(att_w, "{}/att{}.png".format(tmpdir, i))
    # hierarchical attention network (HAN) over the encoder streams
    att_w = plot.trim_attention_weight("utt_%d" % 0, att_ws[num_encs][0])
    plot._plot_and_save_attention(att_w, "{}/han.png".format(tmpdir), han_mode=True)

    # test CTC plot
    ctc_probs = model.calculate_all_ctc_probs(
        *convert_batch(batchset[0], "pytorch", idim=2, odim=2, num_inputs=num_encs)
    )
    from espnet.asr.asr_utils import PlotCTCReport

    tmpdir = tempfile.mkdtemp()
    plot = PlotCTCReport(
        model.calculate_all_ctc_probs, batchset[0], tmpdir, None, None, None
    )
    if args.mtlalpha > 0:
        for i in range(num_encs):
            # CTC posteriors of the i-th encoder
            plot._plot_and_save_ctc(ctc_probs[i][0], "{}/ctc{}.png".format(tmpdir, i))
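
# For context: both tests above assume `make_dummy_json` returns an ESPnet-style
# data JSON keyed by utterance id ("utt_0", "utt_1", ...), which is why the
# first utterance is addressed as "utt_%d" % 0 when trimming attention weights.
# Below is a hedged sketch of that structure; the field names follow the ESPnet
# data-JSON convention, while the helper name and default shapes are assumptions.
def _dummy_json_sketch(num_utts=2, ilen=3, olen=3, idim=2, odim=2, num_inputs=2):
    """Hypothetical stand-in for make_dummy_json(), for illustration only."""
    utts = {}
    for n in range(num_utts):
        utts["utt_%d" % n] = {
            # one input entry per encoder stream, each with its feature shape
            "input": [
                {"name": "input%d" % (i + 1), "shape": [ilen, idim]}
                for i in range(num_inputs)
            ],
            # a single output entry holding the target shape
            "output": [{"name": "target1", "shape": [olen, odim]}],
        }
    return utts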