def test_get_result_index():
    """A Function with one output reports result index 0 for that output."""
    data = ops.parameter(PartialShape([1]), dtype=np.float32, name="data")
    activation = ops.relu(data, name="relu")
    model = Function(activation, [data], "TestFunction")

    assert len(model.outputs) == 1
    only_output = model.outputs[0]
    assert model.get_result_index(only_output) == 0
def test_get_result_index_invalid():
    """An output that does not belong to the Function yields result index -1."""
    # Graph wrapped into the Function under test.
    data1 = ops.parameter(PartialShape([1]), dtype=np.float32, name="data1")
    relu1 = ops.relu(data1, name="relu1")
    model = Function(relu1, [data1], "TestFunction")

    # A separate graph; its output is never part of `model`.
    data2 = ops.parameter(PartialShape([2]), dtype=np.float32, name="data2")
    relu2 = ops.relu(data2, name="relu2")
    foreign_output = relu2.outputs()[0]

    assert len(model.outputs) == 1
    assert model.get_result_index(foreign_output) == -1