Exemplo n.º 1
0
def test_replace_parameter():
    """Check that ``replace_parameter`` moves the slot to the new parameter.

    A one-input relu function is built on ``old_param``. That parameter's
    slot is then overwritten with ``new_param``. Afterwards the new parameter
    must report the old slot's index. The original parameter must no longer
    be found, which ``get_parameter_index`` signals with -1.
    """
    old_param = ops.parameter(PartialShape([1]), dtype=np.float32, name="data")
    new_param = ops.parameter(PartialShape([2]), dtype=np.float32, name="data")
    relu_node = ops.relu(old_param, name="relu")

    model = Function(relu_node, [old_param], "TestFunction")
    slot = model.get_parameter_index(old_param)
    model.replace_parameter(slot, new_param)

    # The replacement takes over the slot; the old parameter is gone.
    assert model.get_parameter_index(new_param) == slot
    assert model.get_parameter_index(old_param) == -1
Exemplo n.º 2
0
def test_parameter_index_invalid():
    """Check that a parameter not in the function's inputs reports index -1."""
    member = ops.parameter(PartialShape([1]), dtype=np.float32, name="data1")
    model = Function(ops.relu(member, name="relu"), [member], "TestFunction")

    # Built after the function and never attached to it.
    outsider = ops.parameter(PartialShape([2]), dtype=np.float32, name="data2")
    assert model.get_parameter_index(outsider) == -1
Exemplo n.º 3
0
def test_parameter_index():
    """Check that the only parameter of a single-input function is at index 0."""
    data = ops.parameter(PartialShape([1]), dtype=np.float32, name="data")
    model = Function(ops.relu(data, name="relu"), [data], "TestFunction")
    assert model.get_parameter_index(data) == 0