def test_replace_parameter():
    """Replacing a function parameter hands its index to the new node.

    Builds a single-input relu function, swaps its parameter for a new one
    with a different shape, and checks that the new node now reports the
    original index while the old node is no longer found (index -1).
    """
    original_param = ops.parameter(PartialShape([1]), dtype=np.float32, name="data")
    replacement_param = ops.parameter(PartialShape([2]), dtype=np.float32, name="data")
    relu_node = ops.relu(original_param, name="relu")
    func = Function(relu_node, [original_param], "TestFunction")

    idx = func.get_parameter_index(original_param)
    func.replace_parameter(idx, replacement_param)

    # The replacement occupies the old slot; the original is detached.
    assert func.get_parameter_index(replacement_param) == idx
    assert func.get_parameter_index(original_param) == -1
def test_parameter_index_invalid():
    """A parameter node that is not an input of the function reports index -1."""
    data1 = ops.parameter(PartialShape([1]), dtype=np.float32, name="data1")
    func = Function(ops.relu(data1, name="relu"), [data1], "TestFunction")

    # Created after the function and never attached to it.
    stranger = ops.parameter(PartialShape([2]), dtype=np.float32, name="data2")
    assert func.get_parameter_index(stranger) == -1
def test_parameter_index():
    """The only parameter of a single-input function is at index 0."""
    data = ops.parameter(PartialShape([1]), dtype=np.float32, name="data")
    func = Function(ops.relu(data, name="relu"), [data], "TestFunction")
    assert func.get_parameter_index(data) == 0