def test_split_1d():
    """Check the ONNX ``Split`` operator on a 1-D float32 tensor.

    Covers four node configurations over the same six-element input:
    an equal two-way split with ``axis=0``, an uneven split given by the
    ``split`` attribute with ``axis=0``, and the same two styles with the
    ``axis`` attribute omitted ("default values").
    """
    # The input is identical for every case, so build it once.
    data = np.array([1., 2., 3., 4., 5., 6.]).astype(np.float32)

    # Each case: (node output names, extra node attributes, expected values).
    cases = [
        # Explicit axis, equal split into two parts.
        (['z', 'w'], {'axis': 0},
         [[1., 2., 3.], [4., 5., 6.]]),
        # Explicit axis, uneven sizes given by ``split``.
        (['y', 'z', 'w'], {'axis': 0, 'split': [2, 3, 1]},
         [[1., 2.], [3., 4., 5.], [6.]]),
        # Default values: no attributes, equal split into three parts.
        (['y', 'z', 'w'], {},
         [[1., 2.], [3., 4.], [5., 6.]]),
        # ``axis`` omitted, uneven sizes given by ``split``.
        (['y', 'z'], {'split': [2, 4]},
         [[1., 2.], [3., 4., 5., 6.]]),
    ]

    for outputs, attributes, expected_values in cases:
        node = onnx.helper.make_node('Split', inputs=['input'], outputs=outputs,
                                     **attributes)
        expected_outputs = [np.array(values).astype(np.float32)
                            for values in expected_values]
        ng_results = run_node(node, [data])
        # Name the failing configuration so a mismatch is easy to locate.
        assert all_arrays_equal(ng_results, expected_outputs), (
            'Split mismatch for outputs={} attributes={}'.format(outputs, attributes))
def test_split_2d(node, expected_output):
    """Run ``node`` on a 2x4 integer matrix and compare all its outputs.

    ``node`` and ``expected_output`` are supplied by the caller (presumably a
    pytest parametrization defined elsewhere; confirm).
    """
    # Values 0..7 laid out as two rows of four.
    matrix = np.arange(8).reshape(2, 4)
    assert all_arrays_equal(run_node(node, [matrix]), expected_output)
def test_split(node, expected_output):
    """Convert and evaluate ``node`` on a 2x4 integer matrix, then compare.

    ``expected_output`` is passed both to ``convert_and_calculate`` and to the
    final comparison. Both arguments presumably come from a pytest
    parametrization defined elsewhere; confirm.
    """
    node_inputs = [np.arange(8).reshape(2, 4)]
    computed = convert_and_calculate(node, node_inputs, expected_output)
    assert all_arrays_equal(computed, expected_output)