Example #1
0
def test_block():
    """Check that Block reports the symbols it defines and supports append / insert_front."""
    dst_0 = dst[0, 0](0)
    dst_2 = dst[0, 0](2)

    block = Block([Assignment(dst_0, s[0]), Assignment(x, dst_2)])
    initial_symbols = {dst_0, dst_2, s[0], x}
    assert block.symbols_defined == initial_symbols

    # append() is handed a list here; the block then holds three nodes
    block.append([Assignment(y, 10)])
    assert block.symbols_defined == initial_symbols | {y}
    assert len(block.args) == 3

    # insert_front() is handed a one-shot iterator rather than a list
    front = Assignment(s[1], 11)
    block.insert_front(iter([front]))
    assert block.args[0] == front
Example #2
0
def create_indexed_kernel(
    assignments: AssignmentOrAstNodeList,
    index_fields,
    function_name="kernel",
    type_info=None,
    coordinate_names=('x', 'y', 'z')) -> KernelFunction:
    """
    Similar to :func:`create_kernel`, but here not all cells of a field are updated but only cells with
    coordinates which are stored in an index field. This traversal method can e.g. be used for boundary handling.

    The coordinates are stored in a separate index_field, which is a one dimensional array with struct data type.
    This struct has to contain fields named 'x', 'y' and for 3D fields ('z'). These names are configurable with the
    'coordinate_names' parameter. The struct can have also other fields that can be read and written in the kernel, for
    example boundary parameters.

    Args:
        assignments: list of assignments
        index_fields: non-empty sequence of index fields, i.e. 1D fields with struct data type. The generated loop
                      runs over the first dimension of the first index field. As a side effect, the ``field_type``
                      of every index field is set to ``FieldType.INDEXED``.
        function_name: see documentation of :func:`create_kernel`
        type_info: see documentation of :func:`create_kernel`
        coordinate_names: names of the coordinate members in the struct data type; the first ``d`` names are used,
                          where ``d`` is the number of spatial dimensions of the non-index fields

    Returns:
        a :class:`KernelFunction` AST node whose body is a single loop over the entries of the first index field

    Raises:
        ValueError: if ``index_fields`` is empty, if ``coordinate_names`` has fewer entries than the non-index
                    fields have spatial dimensions, or if a coordinate name is not a member of any index field's
                    struct data type
    """
    # Fail early with a clear message; previously this surfaced later as a misleading
    # "Index ... not found" error or an IndexError on index_fields[0].
    if not index_fields:
        raise ValueError("At least one index field has to be passed")

    fields_read, fields_written, assignments = add_types(
        assignments, type_info, check_independence_condition=False)
    all_fields = fields_read.union(fields_written)

    # NOTE: the consistency checks below use ``assert`` (stripped under ``python -O``);
    # they are kept as asserts so callers continue to see AssertionError for these cases.
    for index_field in index_fields:
        index_field.field_type = FieldType.INDEXED
        assert FieldType.is_indexed(index_field)
        assert index_field.spatial_dimensions == 1, "Index fields have to be 1D"

    non_index_fields = [f for f in all_fields if f not in index_fields]
    dimension_counts = {f.spatial_dimensions for f in non_index_fields}
    assert len(
        dimension_counts
    ) == 1, "Non-index fields do not have the same number of spatial coordinates"
    spatial_dims = next(iter(dimension_counts))

    # Slicing below would silently drop coordinates if too few names were given,
    # leaving the non-index fields with an incomplete coordinate list.
    if len(coordinate_names) < spatial_dims:
        raise ValueError(
            "Need %d coordinate names for %dD fields, got %d: %r" %
            (spatial_dims, spatial_dims, len(coordinate_names), tuple(coordinate_names)))

    def get_coordinate_symbol_assignment(name):
        """Return an assignment ``name = <first index field containing name>[0](name)``."""
        for idx_field in index_fields:
            assert isinstance(
                idx_field.dtype,
                StructType), "Index fields have to have a struct data type"
            data_type = idx_field.dtype
            if data_type.has_element(name):
                rhs = idx_field[0](name)
                lhs = TypedSymbol(name,
                                  BasicType(data_type.get_element_type(name)))
                return SympyAssignment(lhs, rhs)
        raise ValueError(
            "Index %s not found in any of the passed index fields" % (name, ))

    coordinate_symbol_assignments = [
        get_coordinate_symbol_assignment(n)
        for n in coordinate_names[:spatial_dims]
    ]
    coordinate_typed_symbols = [eq.lhs for eq in coordinate_symbol_assignments]
    # Coordinates are read from the index field first so the user assignments can use them.
    assignments = coordinate_symbol_assignments + assignments

    # make 1D loop over index fields
    loop_body = Block([])
    loop_node = LoopOverCoordinate(loop_body,
                                   coordinate_to_loop_over=0,
                                   start=0,
                                   stop=index_fields[0].shape[0])

    for assignment in assignments:
        loop_body.append(assignment)

    function_body = Block([loop_node])
    ast_node = KernelFunction(function_body,
                              "cpu",
                              "c",
                              make_python_function,
                              ghost_layers=None,
                              function_name=function_name)

    # Every non-index field is accessed at the coordinates loaded from the index field.
    fixed_coordinate_mapping = {
        f.name: coordinate_typed_symbols
        for f in non_index_fields
    }

    read_only_fields = {f.name for f in fields_read - fields_written}
    resolve_field_accesses(ast_node,
                           read_only_fields,
                           field_to_fixed_coordinates=fixed_coordinate_mapping)
    move_constants_before_loop(ast_node)
    return ast_node