def test_block():
    """Block keeps ``symbols_defined`` in sync through construction, append and insert_front."""
    first = Assignment(dst[0, 0](0), s[0])
    second = Assignment(x, dst[0, 0](2))
    block = Block([first, second])

    expected_symbols = {dst[0, 0](0), dst[0, 0](2), s[0], x}
    assert block.symbols_defined == expected_symbols

    # Appending a new assignment must add its left-hand side to the defined symbols.
    block.append([Assignment(y, 10)])
    expected_symbols = expected_symbols | {y}
    assert block.symbols_defined == expected_symbols
    assert len(block.args) == 3

    # insert_front is fed an iterator (not a list) on purpose.
    block.insert_front(iter([Assignment(s[1], 11)]))
    assert block.args[0] == Assignment(s[1], 11)
def create_indexed_kernel(assignments: AssignmentOrAstNodeList, index_fields, function_name="kernel",
                          type_info=None, coordinate_names=('x', 'y', 'z')) -> KernelFunction:
    """
    Similar to :func:`create_kernel`, but here not all cells of a field are updated but only cells with
    coordinates which are stored in an index field. This traversal method can e.g. be used for boundary handling.

    The coordinates are stored in a separate index_field, which is a one dimensional array with struct data type.
    This struct has to contain fields named 'x', 'y' and for 3D fields ('z'). These names are configurable with the
    'coordinate_names' parameter. The struct can have also other fields that can be read and written in the kernel,
    for example boundary parameters.

    Args:
        assignments: list of assignments
        index_fields: list of index fields, i.e. 1D fields with struct data type. The ``field_type`` of every
                      entry is set to ``FieldType.INDEXED`` in place. The loop extent is taken from the first entry.
        function_name: see documentation of :func:`create_kernel`
        type_info: see documentation of :func:`create_kernel`
        coordinate_names: name of the coordinate fields in the struct data type

    Returns:
        KernelFunction whose body is a 1D loop over ``index_fields[0].shape[0]`` entries

    Raises:
        ValueError: if one of the required coordinate names is not an element of any index field's struct type
        AssertionError: if an index field is not 1D or not struct-typed, if there is no non-index field,
                        or if the non-index fields disagree on their number of spatial dimensions
    """
    fields_read, fields_written, assignments = add_types(
        assignments, type_info, check_independence_condition=False)
    all_fields = fields_read.union(fields_written)

    # Mark every index field as INDEXED (this mutates the passed Field objects) and check it is 1D.
    for index_field in index_fields:
        index_field.field_type = FieldType.INDEXED
        assert FieldType.is_indexed(index_field)
        assert index_field.spatial_dimensions == 1, "Index fields have to be 1D"

    non_index_fields = [f for f in all_fields if f not in index_fields]
    # Without this check an empty list would fail below with a misleading "not the same number" message.
    assert non_index_fields, "Indexed kernels need at least one non-index field"
    spatial_dimension_set = {f.spatial_dimensions for f in non_index_fields}
    assert len(spatial_dimension_set) == 1, \
        "Non-index fields do not have the same number of spatial coordinates"
    spatial_dims = next(iter(spatial_dimension_set))

    def get_coordinate_symbol_assignment(name):
        # Use the first index field whose struct type has an element `name`; load that element into a typed symbol.
        for idx_field in index_fields:
            assert isinstance(idx_field.dtype, StructType), "Index fields have to have a struct data type"
            data_type = idx_field.dtype
            if data_type.has_element(name):
                rhs = idx_field[0](name)
                lhs = TypedSymbol(name, BasicType(data_type.get_element_type(name)))
                return SympyAssignment(lhs, rhs)
        raise ValueError("Index %s not found in any of the passed index fields" % (name,))

    coordinate_symbol_assignments = [get_coordinate_symbol_assignment(n)
                                     for n in coordinate_names[:spatial_dims]]
    coordinate_typed_symbols = [eq.lhs for eq in coordinate_symbol_assignments]
    # Coordinate loads are prepended so they precede every user assignment that may depend on them.
    assignments = coordinate_symbol_assignments + assignments

    # make 1D loop over index fields
    loop_body = Block([])
    loop_node = LoopOverCoordinate(loop_body, coordinate_to_loop_over=0, start=0,
                                   stop=index_fields[0].shape[0])
    for assignment in assignments:
        loop_body.append(assignment)

    function_body = Block([loop_node])
    ast_node = KernelFunction(function_body, "cpu", "c", make_python_function,
                              ghost_layers=None, function_name=function_name)

    # Map every non-index field to the coordinate symbols loaded above (passed as field_to_fixed_coordinates).
    fixed_coordinate_mapping = {f.name: coordinate_typed_symbols for f in non_index_fields}

    read_only_fields = {f.name for f in fields_read - fields_written}
    resolve_field_accesses(ast_node, read_only_fields, field_to_fixed_coordinates=fixed_coordinate_mapping)
    move_constants_before_loop(ast_node)
    return ast_node