示例#1
0
def test_multi_source_function_wrapper_too_few_outputs_outputs_fails():
    functions = [lambda x: 2 * x,
                 lambda x: 4 * x]
    function_input = np.array([[1, 0], [2, 1], [3, 0], [4, 0], [5, 1]])
    source_index = -1
    msfw = MultiSourceFunctionWrapper(functions, source_index, extra_output_names=['cost'])

    with pytest.raises(ValueError):
        msfw.evaluate(function_input)
示例#2
0
def test_multi_source_function_wrapper_too_many_outputs_outputs_fails():
    functions = [lambda x: (2 * x, np.array([[1]] * x.shape[0]), np.array([[1]] * x.shape[0])),
                 lambda x: (4 * x, np.array([[2]] * x.shape[0]), np.array([[1]] * x.shape[0]))]
    function_input = np.array([[1, 0], [2, 1], [3, 0], [4, 0], [5, 1]])
    source_index = -1
    msfw = MultiSourceFunctionWrapper(functions, source_index)

    with pytest.raises(ValueError):
        msfw.evaluate(function_input)
示例#3
0
def test_multi_source_function_wrapper_evaluation_with_multiple_extra_arguments(
):
    functions = [
        lambda x:
        (2 * x, np.array([[1]] * x.shape[0]), np.array([[1]] * x.shape[0])),
        lambda x:
        (4 * x, np.array([[2]] * x.shape[0]), np.array([[1]] * x.shape[0]))
    ]
    function_input = np.array([[1, 0], [2, 1], [3, 0], [4, 0], [5, 1]])
    source_index = -1
    msfw = MultiSourceFunctionWrapper(
        functions, source_index, extra_output_names=['cost', 'constraint'])

    output = msfw.evaluate(function_input)

    assert len(output) == function_input.shape[0]
    for i, record in enumerate(output):
        assert_array_equal(output[i].X, function_input[i])
        this_function = functions[function_input[i, source_index]]
        this_function_input = np.delete(function_input[i], source_index)
        assert_array_equal(output[i].Y, this_function(this_function_input)[0])
        assert_array_equal(output[i].cost,
                           this_function(this_function_input)[1][0])
        assert_array_equal(output[i].constraint,
                           this_function(this_function_input)[2][0])
示例#4
0
def test_multi_source_function_wrapper_evaluation_single_output():
    functions = [lambda x: 2 * x, lambda x: 4 * x]
    function_input = np.array([[1, 0], [2, 1], [3, 0], [4, 0], [5, 1]])
    source_index = -1
    msfw = MultiSourceFunctionWrapper(functions, source_index)

    output = msfw.evaluate(function_input)

    assert len(output) == function_input.shape[0]
    for i, record in enumerate(output):
        assert_array_equal(output[i].X, function_input[i])
        this_function = functions[function_input[i, source_index]]
        this_function_input = np.delete(function_input[i], source_index)
        assert_array_equal(output[i].Y, this_function(this_function_input))
示例#5
0
def test_multi_source_function_wrapper_evaluation_no_cost():
    functions = [lambda x: 2 * x, lambda x: 4 * x]
    function_input = np.array([[1, 0], [2, 1], [3, 0], [4, 0], [5, 1]])
    source_index = -1
    msfw = MultiSourceFunctionWrapper(functions, source_index)

    output = msfw.evaluate(function_input)

    assert len(output) == function_input.shape[0]
    for i, record in enumerate(output):
        assert_array_equal(output[i].X, function_input[i])
        this_function = functions[function_input[i, source_index]]
        this_function_input = np.delete(function_input[i], source_index)
        assert_array_equal(output[i].Y, this_function(this_function_input))
        assert output[i].cost is None
示例#6
0
def test_multi_source_function_wrapper_invalid_input():
    # invalid input
    with pytest.raises(ValueError):
        functions = [lambda x: 2 * x]
        function_input = np.array([1, 0])
        msfw = MultiSourceFunctionWrapper(functions)
        msfw.evaluate(function_input)

    # invalid function output
    with pytest.raises(ValueError):
        functions = [lambda x: np.array([2])]
        function_input = np.array([[1, 0]])
        msfw = MultiSourceFunctionWrapper(functions)
        msfw.evaluate(function_input)

    # invalid function output type
    with pytest.raises(ValueError):
        functions = [lambda x: [2]]
        function_input = np.array([[1, 0]])
        msfw = MultiSourceFunctionWrapper(functions)
        msfw.evaluate(function_input)
示例#7
0
def test_multi_source_function_wrapper_invalid_input():
    # invalid input
    with pytest.raises(ValueError):
        functions = [lambda x: 2 * x]
        function_input = np.array([1, 0])
        msfw = MultiSourceFunctionWrapper(functions)
        msfw.evaluate(function_input)

    # invalid function output
    with pytest.raises(ValueError):
        functions = [lambda x: np.array([2])]
        function_input = np.array([[1, 0]])
        msfw = MultiSourceFunctionWrapper(functions)
        msfw.evaluate(function_input)

    # invalid function output type
    with pytest.raises(ValueError):
        functions = [lambda x: [2]]
        function_input = np.array([[1, 0]])
        msfw = MultiSourceFunctionWrapper(functions)
        msfw.evaluate(function_input)