def replace_slice(
        dest_node,  # type: Node
        src_node,  # type: Node
        lower_bounds,  # type: List[int]
        upper_bounds,  # type: List[int]
        strides=None,  # type: List[int]
        name=None,  # type: str
):
    # type: (...) -> Node
    """Build a ReplaceSlice node: a copy of `dest_node` whose slice is taken from `src_node`.

    :param dest_node: Node whose data is copied and then partially overwritten.
    :param src_node: Node supplying the data written into the slice.
    :param lower_bounds: Inclusive lower-bound coordinates of the slice.
    :param upper_bounds: Exclusive upper-bound coordinates of the slice.
    :param strides: Optional per-axis strides of the slice. When omitted, the
                    ReplaceSlice constructor is called without a Strides argument.
    :param name: Optional name for the resulting node.
                 NOTE(review): not referenced in this function body; presumably
                 applied by a decorator outside this view. Confirm.
    :return: The new ReplaceSlice node.
    """
    # Positional arguments shared by both the strided and the unstrided form.
    ctor_args = [dest_node, src_node, Coordinate(lower_bounds), Coordinate(upper_bounds)]
    if strides is not None:
        ctor_args.append(Strides(strides))
    return ReplaceSlice(*ctor_args)
def test_replace_slice():
    """Check ReplaceSlice against a NumPy slice assignment, with and without strides.

    A 6x4 destination ``A`` has a 3x2 region overwritten by ``B``. ``A`` holds
    non-zero values, so the test also verifies that elements outside the
    replaced slice are copied through from ``A``.
    """
    element_type = Type.f32
    A = Parameter(element_type, Shape([6, 4]))
    B = Parameter(element_type, Shape([3, 2]))
    parameter_list = [A, B]

    # Non-zero destination data: with an all-zero A and a zero-initialised
    # result tensor, a backend that never copied A would still pass.
    input_arr_a = np.arange(24, dtype=np.float32).reshape(6, 4)
    input_arr_b = np.ones(6, dtype=np.float32).reshape(3, 2)

    # Create one backend and its tensors, and reuse them for both sub-cases
    # instead of mixing tensors from one backend with a second backend instance.
    backend = Backend.create(test.BACKEND_NAME)
    a = backend.create_tensor(element_type, Shape([6, 4]))
    b = backend.create_tensor(element_type, Shape([3, 2]))
    result = backend.create_tensor(element_type, Shape([6, 4]))
    a.write(util.numpy_to_c(input_arr_a), 0, input_arr_a.nbytes)
    b.write(util.numpy_to_c(input_arr_b), 0, input_arr_b.nbytes)

    def run(replace_node):
        """Compile a function returning `replace_node`, run it on (a, b), return the output."""
        function = Function(NodeVector([replace_node]), parameter_list, 'test')
        result_arr = np.zeros((6, 4), dtype=np.float32)
        # Clear the output tensor so a stale value from an earlier call cannot pass.
        result.write(util.numpy_to_c(result_arr), 0, result_arr.nbytes)
        handle = backend.compile(function)
        handle.call([result], [a, b])
        result.read(util.numpy_to_c(result_arr), 0, result_arr.nbytes)
        return result_arr

    # Unit-stride case: rows 0..2, columns 1..2.
    lower_bounds = [0, 1]
    upper_bounds = [3, 3]
    result_arr = run(ReplaceSlice(A, B, Coordinate(lower_bounds), Coordinate(upper_bounds)))
    result_arr_ref = np.copy(input_arr_a)
    result_arr_ref[lower_bounds[0]:upper_bounds[0],
                   lower_bounds[1]:upper_bounds[1]] = input_arr_b
    assert np.allclose(result_arr, result_arr_ref)

    # Strided case: rows 0, 2, 4 and columns 0, 2.
    lower_bounds = [0, 0]
    upper_bounds = [5, 3]
    strides = [2, 2]
    result_arr = run(ReplaceSlice(A, B, Coordinate(lower_bounds), Coordinate(upper_bounds),
                                  Strides(strides)))
    result_arr_ref = np.copy(input_arr_a)
    # Build the reference from the same bounds and strides the op receives,
    # rather than a bare ``::stride`` that ignores the bounds.
    result_arr_ref[lower_bounds[0]:upper_bounds[0]:strides[0],
                   lower_bounds[1]:upper_bounds[1]:strides[1]] = input_arr_b
    assert np.allclose(result_arr, result_arr_ref)