Example #1
0
def test_split_1d():
    """Check the ONNX ``Split`` operator on a 1-D float32 tensor.

    Covers equal splitting when no split sizes are given (with an explicit
    ``axis=0`` and with the default axis), and unequal splitting where the
    split sizes are fed to the node as its optional second input.
    """
    data = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).astype(np.float32)

    # Explicit axis=0 and no split sizes: the 6 elements are divided equally
    # between the 2 outputs.
    node = onnx.helper.make_node("Split",
                                 inputs=["input"],
                                 outputs=["z", "w"],
                                 axis=0)
    expected_outputs = [
        np.array([1.0, 2.0, 3.0]).astype(np.float32),
        np.array([4.0, 5.0, 6.0]).astype(np.float32),
    ]
    ng_results = run_node(node, [data])
    assert all_arrays_equal(ng_results, expected_outputs)

    # Unequal parts of sizes 2, 3 and 1, supplied through the second input.
    splits = np.array([2, 3, 1]).astype(np.int64)
    node = onnx.helper.make_node("Split",
                                 inputs=["input", "splits"],
                                 outputs=["y", "z", "w"],
                                 axis=0)
    expected_outputs = [
        np.array([1.0, 2.0]).astype(np.float32),
        np.array([3.0, 4.0, 5.0]).astype(np.float32),
        np.array([6.0]).astype(np.float32),
    ]
    ng_results = run_node(node, [data, splits])
    assert all_arrays_equal(ng_results, expected_outputs)

    # Default values: no axis attribute and no split sizes, so the 6 elements
    # are divided equally between the 3 outputs.
    node = onnx.helper.make_node("Split",
                                 inputs=["input"],
                                 outputs=["y", "z", "w"])
    expected_outputs = [
        np.array([1.0, 2.0]).astype(np.float32),
        np.array([3.0, 4.0]).astype(np.float32),
        np.array([5.0, 6.0]).astype(np.float32),
    ]
    ng_results = run_node(node, [data])
    assert all_arrays_equal(ng_results, expected_outputs)

    # Default axis with sizes 2 and 4 given only through the second input.
    # They must not also be set as a ``split`` attribute: in the opset that
    # accepts split sizes as an input, ``split`` is no longer an attribute of
    # Split, so passing both produces an invalid node.
    splits = np.array([2, 4]).astype(np.int64)
    node = onnx.helper.make_node("Split",
                                 inputs=["input", "splits"],
                                 outputs=["y", "z"])
    expected_outputs = [
        np.array([1.0, 2.0]).astype(np.float32),
        np.array([3.0, 4.0, 5.0, 6.0]).astype(np.float32),
    ]
    ng_results = run_node(node, [data, splits])
    assert all_arrays_equal(ng_results, expected_outputs)


def test_split_2d_splits_input():
    """Split a 2x4 int32 matrix along axis 1 into column groups of width 3 and 1.

    The split sizes are passed to the node as its second input.
    """
    matrix = np.arange(8, dtype=np.int32).reshape(2, 4)
    split_sizes = np.array([3, 1]).astype(np.int64)

    split_node = onnx.helper.make_node(
        "Split",
        inputs=["x", "splits"],
        outputs=["a", "b"],
        axis=1,
    )

    # First three columns go to "a", the last column goes to "b".
    left_part = np.array([[0, 1, 2], [4, 5, 6]], dtype=np.int32)
    right_part = np.array([[3], [7]], dtype=np.int32)

    results = run_node(split_node, [matrix, split_sizes])
    assert all_arrays_equal(results, [left_part, right_part])


def test_split_2d(node, expected_output):
    """Run ``node`` on a fixed 2x4 int32 input and compare against ``expected_output``.

    ``node`` and ``expected_output`` are supplied by the caller — presumably a
    ``pytest.mark.parametrize`` decorator not visible here (TODO confirm).
    """
    node_inputs = [np.arange(8, dtype=np.int32).reshape(2, 4)]
    assert all_arrays_equal(run_node(node, node_inputs), expected_output)