def test_node_output():
    """Check the per-output metadata exposed by a multi-output Split node.

    Splitting a 6-element constant (cast to int32) into 3 parts along axis 0
    should yield three outputs indexed 0..2, each carrying the input's element
    type and a static shape of [2].
    """
    input_array = np.array([0, 1, 2, 3, 4, 5])
    splits = 3
    # Length of each equal-sized chunk along axis 0 (an int, not a shape).
    expected_length = len(input_array) // splits

    input_tensor = ops.constant(input_array, dtype=np.int32)
    axis = ops.constant(0, dtype=np.int64)

    split_node = ops.split(input_tensor, axis, splits)

    split_node_outputs = split_node.outputs()

    assert len(split_node_outputs) == splits
    assert [output_node.get_index() for output_node in split_node_outputs] == [0, 1, 2]

    # Compare each output object directly instead of going through np.equal:
    # wrapping these binding objects in a NumPy array may coerce iterable ones
    # (e.g. shapes) into numeric arrays, so the old check compared whatever
    # NumPy produced rather than the objects themselves.
    expected_type = input_tensor.get_element_type()
    assert all(
        output_node.get_element_type() == expected_type
        for output_node in split_node_outputs
    )
    assert all(
        output_node.get_shape() == Shape([expected_length])
        for output_node in split_node_outputs
    )
    assert all(
        output_node.get_partial_shape() == PartialShape([expected_length])
        for output_node in split_node_outputs
    )

    # Outputs fetched individually by index must report that same index.
    output0 = split_node.output(0)
    output1 = split_node.output(1)
    output2 = split_node.output(2)
    assert [output0.get_index(), output1.get_index(), output2.get_index()] == [0, 1, 2]
def test_split():
    """Run a Split node end to end and check every produced chunk.

    A 6-element int32 constant split into 3 parts along axis 0 must produce
    the chunks [0, 1], [2, 3] and [4, 5].
    """
    runtime = get_runtime()
    # Build the graph with `ops`, the same module the sibling tests in this
    # file use for `constant` and `split`.
    input_tensor = ops.constant(np.array([0, 1, 2, 3, 4, 5], dtype=np.int32))
    axis = ops.constant(0, dtype=np.int64)
    splits = 3

    split_node = ops.split(input_tensor, axis, splits)
    computation = runtime.computation(split_node)
    split_results = computation()

    expected_results = np.array([[0, 1], [2, 3], [4, 5]], dtype=np.int32)
    # Integer data: require exact equality rather than np.allclose's tolerance.
    assert np.equal(split_results, expected_results).all()
def test_multiple_outputs():
    """Feed one output of a multi-output Split node into a downstream ReLU.

    The graph splits a (4, 4) float32 parameter in two along axis 1 and applies
    ReLU to the first half only. The result must match the NumPy equivalent.
    """
    input_shape = [4, 4]
    input_data = np.arange(-8, 8).reshape(input_shape).astype(np.float32)
    # np.split returns views into input_data. Clamping that view in place (as
    # the previous version did) also zeroed the negatives of the data fed to
    # the graph, so the ReLU never saw a negative value. np.maximum builds a
    # new array and leaves input_data untouched.
    expected_output = np.maximum(np.split(input_data, 2, axis=1)[0], 0)

    test_param = ops.parameter(input_shape, dtype=np.float32, name="A")
    split = ops.split(test_param, axis=1, num_splits=2)
    split_first_output = split.output(0)
    relu = ops.relu(split_first_output)

    runtime = get_runtime()
    computation = runtime.computation(relu, test_param)
    output = computation(input_data)

    assert np.equal(output, expected_output).all()