Exemplo n.º 1
0
def test_split_1d():
    """Exercise the ONNX ``Split`` node on a six-element float32 vector.

    Four configurations are checked in order:

    * ``axis=0`` with two outputs and no ``split`` attribute,
    * ``axis=0`` with ``split=[2, 3, 1]``,
    * three outputs with neither ``axis`` nor ``split`` given,
    * two outputs with only ``split=[2, 4]`` given.

    Each node is run through ``run_node`` and the result is compared with
    the expected chunks via ``all_arrays_equal``.
    """

    def float32(*items):
        # Build a 1-D float32 array from the given scalars.
        return np.array(list(items)).astype(np.float32)

    # Each case: (output names, extra node attributes, expected chunks).
    explicit_axis_cases = (
        (['z', 'w'],
         {'axis': 0},
         [float32(1., 2., 3.), float32(4., 5., 6.)]),
        (['y', 'z', 'w'],
         {'axis': 0, 'split': [2, 3, 1]},
         [float32(1., 2.), float32(3., 4., 5.), float32(6.)]),
    )
    # Cases that leave the ``axis`` attribute unset.
    default_axis_cases = (
        (['y', 'z', 'w'],
         {},
         [float32(1., 2.), float32(3., 4.), float32(5., 6.)]),
        (['y', 'z'],
         {'split': [2, 4]},
         [float32(1., 2.), float32(3., 4., 5., 6.)]),
    )

    for case_group in (explicit_axis_cases, default_axis_cases):
        # A fresh input vector is created for each group of cases.
        vector = float32(1., 2., 3., 4., 5., 6.)
        for output_names, attributes, expected_chunks in case_group:
            split_node = onnx.helper.make_node(
                'Split', inputs=['input'], outputs=output_names, **attributes)
            actual_chunks = run_node(split_node, [vector])
            assert all_arrays_equal(actual_chunks, expected_chunks)
Exemplo n.º 2
0
def test_split_2d(node, expected_output):
    """Run ``node`` on a 2x4 array holding 0..7 and compare the outputs.

    :param node: node handed to ``run_node``
        (presumably supplied by test parametrization; not visible here).
    :param expected_output: arrays the node is expected to produce.
    """
    matrix = np.arange(8).reshape(2, 4)
    # Evaluate the node and check every produced array against its reference.
    assert all_arrays_equal(run_node(node, [matrix]), expected_output)
Exemplo n.º 3
0
def test_split(node, expected_output):
    """Evaluate ``node`` on a 2x4 array holding 0..7 and compare the outputs.

    :param node: node handed to ``convert_and_calculate``
        (presumably supplied by test parametrization; not visible here).
    :param expected_output: reference arrays; also passed to
        ``convert_and_calculate`` as its third argument.
    """
    matrix = np.arange(8).reshape(2, 4)
    computed = convert_and_calculate(node, [matrix], expected_output)
    assert all_arrays_equal(computed, expected_output)