Пример #1
0
def replace_slice(
        dest_node,  # type: Node
        src_node,  # type: Node
        lower_bounds,  # type: List[int]
        upper_bounds,  # type: List[int]
        strides=None,  # type: List[int]
        name=None,  # type: str
):
    # type: (...) -> Node
    """Build a node that copies `dest_node` and overwrites one slice of it with `src_node`.

    :param dest_node: Node whose data is copied and partially overwritten.
    :param src_node: Node whose data is written into the selected slice.
    :param lower_bounds: Inclusive lower-bound coordinates of the slice.
    :param upper_bounds: Exclusive upper-bound coordinates of the slice.
    :param strides: Optional per-axis strides of the slice. When omitted, no
                    strides argument is passed to ``ReplaceSlice``.
    :param name: Optional name for the resulting node. NOTE(review): this body
                 never reads ``name``; presumably a decorator or the caller
                 applies it. Confirm, otherwise the argument is silently ignored.
    :return: The ``ReplaceSlice`` node built from the arguments above.
    """
    # Positional arguments common to both ReplaceSlice call forms.
    slice_args = [dest_node, src_node,
                  Coordinate(lower_bounds), Coordinate(upper_bounds)]
    # Append Strides only when given, so the no-strides call keeps its
    # four-argument form.
    if strides is not None:
        slice_args.append(Strides(strides))
    return ReplaceSlice(*slice_args)
Пример #2
0
def test_replace_slice():
    """Check ``ReplaceSlice`` against a NumPy slice-assignment reference.

    ``A`` is a 6x4 tensor of zeros. A 3x2 block of it is overwritten by ``B``,
    a tensor of ones. The test covers a contiguous slice first and a strided
    slice second.
    """
    element_type = Type.f32
    shape_a = [6, 4]
    shape_b = [3, 2]
    A = Parameter(element_type, Shape(shape_a))
    B = Parameter(element_type, Shape(shape_b))
    parameter_list = [A, B]

    input_arr_a = np.zeros(shape_a, dtype=np.float32)
    input_arr_b = np.ones(shape_b, dtype=np.float32)

    # Reuse one backend for the whole test, so the tensors created below and
    # every compiled handle share the same backend instance.
    backend = Backend.create(test.BACKEND_NAME)

    a = backend.create_tensor(element_type, Shape(shape_a))
    b = backend.create_tensor(element_type, Shape(shape_b))
    result = backend.create_tensor(element_type, Shape(shape_a))

    # Byte counts come from the arrays themselves, not hard-coded sizes.
    a.write(util.numpy_to_c(input_arr_a), 0, input_arr_a.nbytes)
    b.write(util.numpy_to_c(input_arr_b), 0, input_arr_b.nbytes)

    def run(node):
        """Compile a one-output function around `node`, execute it, and return the output array."""
        function = Function(NodeVector([node]), parameter_list, 'test')
        # Zero the output tensor before each run so a stale result from an
        # earlier case cannot make the comparison pass by accident.
        result_arr = np.zeros(shape_a, dtype=np.float32)
        result.write(util.numpy_to_c(result_arr), 0, result_arr.nbytes)
        handle = backend.compile(function)
        handle.call([result], [a, b])
        result.read(util.numpy_to_c(result_arr), 0, result_arr.nbytes)
        return result_arr

    def expected(lower_bounds, upper_bounds, strides=None):
        """Return a copy of `input_arr_a` with the selected slice set to `input_arr_b`."""
        if strides is None:
            strides = [1] * len(lower_bounds)
        # Use exactly the same bounds and strides as the op under test.
        index = tuple(slice(lo, hi, step)
                      for lo, hi, step in zip(lower_bounds, upper_bounds, strides))
        ref = np.copy(input_arr_a)
        ref[index] = input_arr_b
        return ref

    # Contiguous slice: rows [0, 3), columns [1, 3).
    lower_bounds = [0, 1]
    upper_bounds = [3, 3]
    result_arr = run(ReplaceSlice(A, B, Coordinate(lower_bounds),
                                  Coordinate(upper_bounds)))
    assert np.allclose(result_arr, expected(lower_bounds, upper_bounds))

    # Strided slice: rows 0, 2, 4 and columns 0, 2, a 3x2 selection matching B.
    lower_bounds = [0, 0]
    upper_bounds = [5, 3]
    strides = [2, 2]
    result_arr = run(ReplaceSlice(A, B, Coordinate(lower_bounds),
                                  Coordinate(upper_bounds), Strides(strides)))
    assert np.allclose(result_arr,
                       expected(lower_bounds, upper_bounds, strides))