Example #1
0
    def test_batch_scan(self):
        """Check that an LSTM model converted to a Scan-based model ('to_scan'
        mode of onnxruntime.nuphar.model_editor) produces the same outputs as
        the original LSTM, for batch size 1 and batch size 2, and again for
        batch size 1 after a larger batch has been run.

        NOTE(review): this method is shadowed by a later definition with the
        same name in this file; only the last definition is picked up by the
        test runner.
        """
        input_dim = 3
        hidden_dim = 5
        bidirectional = False
        layers = 3

        lstm_model_name = 'test_batch_rnn_lstm.onnx'
        # create an LSTM model for generating baseline data
        generate_model('lstm',
                       input_dim,
                       hidden_dim,
                       bidirectional,
                       layers,
                       lstm_model_name,
                       batch_one=False,
                       has_seq_len=True)

        seq_len = 8
        batch_size = 2
        # prepare input: values in [-1, 1), shape (seq_len, batch, input_dim)
        data_input = (np.random.rand(seq_len, batch_size, input_dim) * 2 - 1).astype(np.float32)
        # per-batch-element sequence lengths drawn from [1, seq_len)
        data_seq_len = np.random.randint(1, seq_len, size=(batch_size,), dtype=np.int32)

        # run lstm as baseline, one batch element at a time
        sess = onnxrt.InferenceSession(lstm_model_name)
        first_lstm_data_output = sess.run([], {'input': data_input[:, 0:1, :], 'seq_len': data_seq_len[0:1]})

        # accumulate per-element outputs, then stitch them along the batch axis
        # (was: a dead `lstm_data_output = []` immediately overwritten — removed)
        lstm_data_output = first_lstm_data_output
        for b in range(1, batch_size):
            lstm_data_output = lstm_data_output + sess.run([], {
                'input': data_input[:, b:(b + 1), :],
                'seq_len': data_seq_len[b:(b + 1)]
            })
        lstm_data_output = np.concatenate(lstm_data_output, axis=1)

        # generate a batch scan model from the LSTM model
        scan_model_name = 'test_batch_rnn_scan.onnx'
        subprocess.run([
            sys.executable, '-m', 'onnxruntime.nuphar.model_editor', '--input', lstm_model_name, '--output',
            scan_model_name, '--mode', 'to_scan'
        ],
                       check=True)

        # run scan_batch with batch size 1
        sess = onnxrt.InferenceSession(scan_model_name)
        scan_batch_data_output = sess.run([], {'input': data_input[:, 0:1, :], 'seq_len': data_seq_len[0:1]})
        assert np.allclose(first_lstm_data_output, scan_batch_data_output)

        # run scan_batch with batch size 2
        scan_batch_data_output = sess.run([], {'input': data_input, 'seq_len': data_seq_len})
        assert np.allclose(lstm_data_output, scan_batch_data_output)

        # run scan_batch with batch size 1 again — results must match the
        # first batch-size-1 run even after a batch-size-2 run
        scan_batch_data_output = sess.run([], {'input': data_input[:, 0:1, :], 'seq_len': data_seq_len[0:1]})
        assert np.allclose(first_lstm_data_output, scan_batch_data_output)
    def test_batch_scan(self):
        """Check that an LSTM model converted to a Scan-based model ('to_scan'
        mode of onnxruntime.nuphar.model_editor) produces the same outputs as
        the original LSTM, for batch size 1 and batch size 2, and again for
        batch size 1 after a larger batch has been run. Repeated for ONNX
        opset versions 7 and 13.

        NOTE(review): an earlier method in this file has the same name; this
        later definition shadows it, so only this one runs.
        """
        input_dim = 3
        hidden_dim = 5
        bidirectional = False
        layers = 3

        for onnx_opset_ver in [7, 13]:
            lstm_model_name = "test_batch_rnn_lstm.onnx"
            # create an LSTM model for generating baseline data
            generate_model(
                "lstm",
                input_dim,
                hidden_dim,
                bidirectional,
                layers,
                lstm_model_name,
                batch_one=False,
                has_seq_len=True,
                onnx_opset_ver=onnx_opset_ver,
            )

            seq_len = 8
            batch_size = 2
            # prepare input: values in [-1, 1), shape (seq_len, batch, input_dim)
            data_input = (np.random.rand(seq_len, batch_size, input_dim) * 2 -
                          1).astype(np.float32)
            # per-batch-element sequence lengths drawn from [1, seq_len)
            data_seq_len = np.random.randint(1,
                                             seq_len,
                                             size=(batch_size, ),
                                             dtype=np.int32)

            # run lstm as baseline, one batch element at a time
            sess = onnxrt.InferenceSession(
                lstm_model_name, providers=onnxrt.get_available_providers())
            first_lstm_data_output = sess.run([], {
                "input": data_input[:, 0:1, :],
                "seq_len": data_seq_len[0:1]
            })

            # accumulate per-element outputs, then stitch them along the batch
            # axis (was: a dead `lstm_data_output = []` immediately
            # overwritten — removed)
            lstm_data_output = first_lstm_data_output
            for b in range(1, batch_size):
                lstm_data_output = lstm_data_output + sess.run(
                    [],
                    {
                        "input": data_input[:, b:(b + 1), :],
                        "seq_len": data_seq_len[b:(b + 1)],
                    },
                )
            lstm_data_output = np.concatenate(lstm_data_output, axis=1)

            # generate a batch scan model from the LSTM model
            scan_model_name = "test_batch_rnn_scan.onnx"
            subprocess.run(
                [
                    sys.executable,
                    "-m",
                    "onnxruntime.nuphar.model_editor",
                    "--input",
                    lstm_model_name,
                    "--output",
                    scan_model_name,
                    "--mode",
                    "to_scan",
                ],
                check=True,
            )

            # run scan_batch with batch size 1
            sess = onnxrt.InferenceSession(
                scan_model_name, providers=onnxrt.get_available_providers())
            scan_batch_data_output = sess.run([], {
                "input": data_input[:, 0:1, :],
                "seq_len": data_seq_len[0:1]
            })
            assert np.allclose(first_lstm_data_output, scan_batch_data_output)

            # run scan_batch with batch size 2
            scan_batch_data_output = sess.run([], {
                "input": data_input,
                "seq_len": data_seq_len
            })
            assert np.allclose(lstm_data_output, scan_batch_data_output)

            # run scan_batch with batch size 1 again — results must match the
            # first batch-size-1 run even after a batch-size-2 run
            scan_batch_data_output = sess.run([], {
                "input": data_input[:, 0:1, :],
                "seq_len": data_seq_len[0:1]
            })
            assert np.allclose(first_lstm_data_output, scan_batch_data_output)