def test_split_1d():
    """Exercise ONNX ``Split`` on a 1-D float32 tensor of six elements.

    Cases, run in order:
      * an even two-way split along axis 0 with no explicit sizes;
      * sizes ``[2, 3, 1]`` supplied as a second input tensor;
      * an even three-way split using the default ``axis``;
      * sizes ``[2, 4]`` supplied both as an input tensor and as the
        ``split`` attribute.
    """
    values = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).astype(np.float32)

    def check(output_names, expected_rows, splits=None, **attributes):
        # Build the node, feed ``values`` (plus ``splits`` when given) and
        # compare each output with the matching float32 row.
        if splits is None:
            input_names, feeds = ["input"], [values]
        else:
            input_names, feeds = ["input", "splits"], [values, splits]
        node = onnx.helper.make_node(
            "Split", inputs=input_names, outputs=output_names, **attributes
        )
        expected = [np.array(row).astype(np.float32) for row in expected_rows]
        assert all_arrays_equal(run_node(node, feeds), expected)

    # Even split into two halves along axis 0.
    check(["z", "w"], [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], axis=0)

    # Uneven split with sizes taken from the second input.
    check(
        ["y", "z", "w"],
        [[1.0, 2.0], [3.0, 4.0, 5.0], [6.0]],
        splits=np.array([2, 3, 1]).astype(np.int64),
        axis=0,
    )

    # Default values: no axis attribute, three equal parts.
    check(["y", "z", "w"], [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])

    # NOTE(review): the sizes are given both as the ``splits`` input and as
    # the ``split`` attribute; presumably this checks that the two are
    # handled consistently when both are present. Confirm the intent.
    check(
        ["y", "z"],
        [[1.0, 2.0], [3.0, 4.0, 5.0, 6.0]],
        splits=np.array([2, 4]).astype(np.int64),
        split=[2, 4],
    )
def test_split_2d_splits_input():
    """Split a 2x4 int32 tensor along axis 1 with sizes passed as an input.

    The sizes ``[3, 1]`` produce a 2x3 left part and a 2x1 right part.
    """
    matrix = np.arange(8, dtype=np.int32).reshape(2, 4)
    sizes = np.array([3, 1]).astype(np.int64)
    split_node = onnx.helper.make_node(
        "Split", inputs=["x", "splits"], outputs=["a", "b"], axis=1
    )

    left = np.array([[0, 1, 2], [4, 5, 6]], dtype=np.int32)
    right = np.array([[3], [7]], dtype=np.int32)

    outcome = run_node(split_node, [matrix, sizes])
    assert all_arrays_equal(outcome, [left, right])
def test_split_2d(node, expected_output):
    """Run ``node`` on the 2x4 int32 tensor ``arange(8)`` and check its outputs.

    NOTE(review): ``node`` and ``expected_output`` are presumably supplied by a
    ``pytest.mark.parametrize`` decorator that is not visible in this chunk.
    Confirm that it is present; otherwise pytest will look for fixtures with
    these names.
    """
    results = run_node(node, [np.arange(8, dtype=np.int32).reshape(2, 4)])
    assert all_arrays_equal(results, expected_output)