Example #1
0
def upsample(input: {'field_type': pystencils.field.FieldType.CUSTOM}, result,
             factor):
    """Build the assignments that fill ``result`` from a coarser ``input`` field.

    Each cell of ``result`` reads ``input`` at the absolute position obtained
    by scaling the current coordinates with ``1 / factor`` and casting to
    ``int64``.  The access is wrapped in a ``ConditionalFieldAccess`` whose
    condition is true whenever at least one coordinate is not a multiple of
    ``factor``.

    The returned collection gets a ``_create_autodiff`` method attached that
    registers ``downsample`` on the adjoint fields as the backward pass.
    """
    coordinates = pystencils.x_vector(input.spatial_dimensions)
    target = result.center

    source_index = tuple(
        cast_func(sympy.S(1) / factor * c, create_type('int64'))
        for c in coordinates)
    # True if any coordinate leaves a remainder when divided by ``factor``.
    off_grid = sympy.Or(
        *[c % cast_func(factor, 'int64') > 0 for c in coordinates])
    sampled_value = pystencils.astnodes.ConditionalFieldAccess(
        input.absolute_access(source_index, ()), off_grid)

    assignments = AssignmentCollection({target: sampled_value})

    def create_autodiff(self, constant_fields=None, **kwargs):
        # The backward pass of upsampling is downsampling of the adjoints.
        adjoint_assignments = downsample(AdjointField(result),
                                         AdjointField(input), factor)
        self._autodiff = pystencils.autodiff.AutoDiffOp(
            assignments,
            "",
            backward_assignments=adjoint_assignments,
            **kwargs)

    assignments._create_autodiff = types.MethodType(create_autodiff,
                                                    assignments)
    return assignments
Example #2
0
    def visit_expr(expr):
        """Return ``expr`` with casts inserted so vector and scalar operands agree.

        Casts and vector memory accesses are returned untouched.  For handled
        functions, relations and boolean functions the arguments are visited
        recursively and, if any of them is a vector, all are cast to the
        collated type.  For ``Pow`` only the base is visited.  For
        ``Piecewise`` results and conditions are each brought to a common
        type, promoting the scalar side to a vector when the other side is
        one.  Anything else is returned unchanged.
        """
        if isinstance(expr, (cast_func, vector_memory_access)):
            return expr

        if expr.func in handled_functions or isinstance(
                expr, (sp.Rel, sp.boolalg.BooleanFunction)):
            visited = [visit_expr(operand) for operand in expr.args]
            visited_types = [get_type_of_expression(v) for v in visited]
            if not any(type(t) is VectorType for t in visited_types):
                # Purely scalar sub-expression: keep the original untouched.
                return expr
            common_type = collate_types(visited_types)
            return expr.func(*[
                cast_func(v, common_type) if t != common_type else v
                for v, t in zip(visited, visited_types)
            ])

        if expr.func is sp.Pow:
            # Only the base is processed; the exponent is passed through.
            return expr.func(visit_expr(expr.args[0]), expr.args[1])

        if expr.func == sp.Piecewise:
            branch_values = [visit_expr(pair[0]) for pair in expr.args]
            branch_conditions = [visit_expr(pair[1]) for pair in expr.args]
            value_types = [get_type_of_expression(v) for v in branch_values]
            cond_types = [
                get_type_of_expression(c) for c in branch_conditions
            ]

            result_target_type = get_type_of_expression(expr)
            condition_target_type = collate_types(cond_types)

            # Promote whichever side is still scalar to the other's width.
            cond_is_vector = type(condition_target_type) is VectorType
            result_is_vector = type(result_target_type) is VectorType
            if cond_is_vector and not result_is_vector:
                result_target_type = VectorType(
                    result_target_type, width=condition_target_type.width)
            elif result_is_vector and not cond_is_vector:
                condition_target_type = VectorType(
                    condition_target_type, width=result_target_type.width)

            casted_values = [
                cast_func(v, result_target_type)
                if t != result_target_type else v
                for v, t in zip(branch_values, value_types)
            ]
            casted_conditions = [
                cast_func(c, condition_target_type)
                if t != condition_target_type and c is not True else c
                for c, t in zip(branch_conditions, cond_types)
            ]
            return sp.Piecewise(*zip(casted_values, casted_conditions))

        return expr
Example #3
0
def test_address_of_with_cse():
    """Kernels using ``address_of`` must be creatable before and after CSE."""
    x, y = pystencils.fields('x,y: int64[2d]')
    pointer = pystencils.TypedSymbol('s', PointerType(create_type('int64')))

    plain = pystencils.AssignmentCollection({
        y[0, 0]: cast_func(address_of(x[0, 0]), create_type('int64')) + pointer,
        x[0, 0]: cast_func(address_of(x[0, 0]), create_type('int64')) + 1
    }, {})

    # Without common-subexpression elimination.
    pystencils.show_code(pystencils.create_kernel(plain))

    # With common-subexpression elimination applied first.
    optimized = sympy_cse(plain)
    pystencils.show_code(pystencils.create_kernel(optimized))
Example #4
0
 def visit_node(node, substitution_dict):
     """Recursively insert vector casts into all assignments below ``node``.

     ``substitution_dict`` maps scalar symbols to their vector-typed
     replacements.  It is copied on entry, so replacements recorded here are
     passed down to child nodes but never leak back to the caller.

     For every ``SympyAssignment`` the right-hand side is substituted and
     passed through ``visit_expr``.  If the left-hand side is a scalar
     ``TypedSymbol`` but the right-hand side became a vector, the symbol is
     re-typed to a vector and recorded for later substitution.  If the
     left-hand side is a vector-typed cast and the right-hand side is still
     scalar, the right-hand side is cast to the vector type.
     ``Conditional`` nodes get their condition processed the same way before
     their children are visited.
     """
     def is_resolved_access(e):
         # Resolved field accesses must not be touched by the substitution.
         return isinstance(e, ast.ResolvedFieldAccess)

     substitution_dict = substitution_dict.copy()
     for arg in node.args:
         if isinstance(arg, ast.SympyAssignment):
             assignment = arg
             subs_expr = fast_subs(assignment.rhs,
                                   substitution_dict,
                                   skip=is_resolved_access)
             assignment.rhs = visit_expr(subs_expr)
             rhs_type = get_type_of_expression(assignment.rhs)
             if isinstance(assignment.lhs, TypedSymbol):
                 lhs_type = assignment.lhs.dtype
                 if type(rhs_type) is VectorType and type(
                         lhs_type) is not VectorType:
                     new_lhs_type = VectorType(lhs_type, rhs_type.width)
                     new_lhs = TypedSymbol(assignment.lhs.name,
                                           new_lhs_type)
                     substitution_dict[assignment.lhs] = new_lhs
                     assignment.lhs = new_lhs
             # The original tested ``assignment.lhs.func`` (a class object)
             # against ``cast_func``, which is never an instance of it, so
             # this branch was unreachable.  Test the expression itself.
             elif isinstance(assignment.lhs, cast_func):
                 lhs_type = assignment.lhs.args[1]
                 if type(lhs_type) is VectorType and type(
                         rhs_type) is not VectorType:
                     assignment.rhs = cast_func(assignment.rhs, lhs_type)
         elif isinstance(arg, ast.Conditional):
             arg.condition_expr = fast_subs(arg.condition_expr,
                                            substitution_dict,
                                            skip=is_resolved_access)
             arg.condition_expr = visit_expr(arg.condition_expr)
             visit_node(arg, substitution_dict)
         else:
             visit_node(arg, substitution_dict)
Example #5
0
    def _print_SympyAssignment(self, node):
        """Print an assignment node as a C statement.

        Declarations are printed as ``[const ]<type> <lhs> = <rhs>;``.
        Assignments whose left-hand side is a vector-typed cast are printed
        with a store instruction from the vector instruction set; all other
        assignments are printed as ``<lhs> = <rhs>;``.
        """
        if node.is_declaration:
            qualifier = 'const ' if node.is_const else ''
            declared_type = qualifier + self._print(node.lhs.dtype) + " "
            return "%s%s = %s;" % (declared_type,
                                   self.sympy_printer.doprint(node.lhs),
                                   self.sympy_printer.doprint(node.rhs))

        lhs_type = get_type_of_expression(node.lhs)
        is_vector_store = type(lhs_type) is VectorType and isinstance(
            node.lhs, cast_func)
        if not is_vector_store:
            return "%s = %s;" % (self.sympy_printer.doprint(node.lhs),
                                 self.sympy_printer.doprint(node.rhs))

        _target, _dtype, aligned, nontemporal = node.lhs.args
        # Unaligned store by default; aligned stores may additionally stream.
        if not aligned:
            store_instr = 'storeU'
        elif nontemporal:
            store_instr = 'stream'
        else:
            store_instr = 'storeA'

        rhs_type = get_type_of_expression(node.rhs)
        if type(rhs_type) is VectorType:
            value = node.rhs
        else:
            # A scalar right-hand side is cast to a vector before storing.
            value = cast_func(node.rhs, VectorType(rhs_type))

        return self._vector_instruction_set[store_instr].format(
            "&" + self.sympy_printer.doprint(node.lhs.args[0]),
            self.sympy_printer.doprint(value)) + ';'
Example #6
0
def test_address_of_with_cse():
    """Create kernels using ``address_of`` without and with CSE applied."""
    x, y = pystencils.fields('x,y: int64[2d]')

    plain = pystencils.AssignmentCollection(
        {
            y[0, 0]: cast_func(address_of(x[0, 0]), 'int64'),
            x[0, 0]: cast_func(address_of(x[0, 0]), 'int64') + 1
        }, {})

    # First the collection as written ...
    pystencils.show_code(pystencils.create_kernel(plain))

    # ... then after common-subexpression elimination; only this one is printed.
    optimized = sympy_cse(plain)
    optimized_kernel = pystencils.create_kernel(optimized)
    print(pystencils.show_code(optimized_kernel))
    def __new__(cls, arg1, arg2):
        """Create the integer function after validating both arguments.

        Plain Python ints and sympy numbers are wrapped in a cast to the
        ``"int"`` type, numpy scalars in a cast to their own dtype; all other
        arguments are passed through unchanged.

        Raises:
            ValueError: if an argument's type is not an integer type, or if
                its type cannot be determined (``get_type_of_expression``
                raised ``NotImplementedError``).
        """
        args = []
        for a in (arg1, arg2):
            if isinstance(a, (sp.Number, int)):
                args.append(cast_func(a, create_type("int")))
            elif isinstance(a, np.generic):
                args.append(cast_func(a, a.dtype))
            else:
                args.append(a)

        for a in args:
            try:
                # Named ``arg_type`` so the builtin ``type`` is not shadowed.
                arg_type = get_type_of_expression(a)
                if not arg_type.is_int():
                    raise ValueError("Argument to integer function is not an int but " + str(arg_type))
            except NotImplementedError as err:
                # Keep the original failure attached as the explicit cause.
                raise ValueError("Integer functions can only be constructed with typed expressions") from err
        return super().__new__(cls, *args)
def test_type_interference():
    """Check the C types chosen for temporaries in the generated kernel code."""
    x = pystencils.fields('x:  float32[3d]')
    assignments = pystencils.AssignmentCollection({
        a: cast_func(10, create_type('float64')),
        b: cast_func(10, create_type('uint16')),
        e: 11,
        c: b,
        f: c + b,
        d: c + b + x.center + e,
        x.center: c + b + x.center
    })

    kernel = pystencils.create_kernel(assignments)
    code = str(pystencils.get_code_str(kernel))

    # Each temporary must be declared with the expected C type.
    for declaration in ('double a', 'uint16_t b', 'uint16_t f', 'int64_t e'):
        assert declaration in code
Example #9
0
def test_address_of():
    """Create and print kernels that take the address of a field access."""
    x, y = pystencils.fields('x,y: int64[2d]')
    pointer = pystencils.TypedSymbol('s', PointerType('int64'))

    def print_kernel(assignment_dict):
        # Build the collection, generate the kernel and print its code.
        collection = pystencils.AssignmentCollection(assignment_dict, {})
        kernel = pystencils.create_kernel(collection)
        print(pystencils.show_code(kernel))

    # Address stored in a pointer symbol first, then cast to an integer.
    print_kernel({
        pointer: address_of(x[0, 0]),
        y[0, 0]: cast_func(pointer, 'int64')
    })

    # Address cast to an integer directly.
    print_kernel({y[0, 0]: cast_func(address_of(x[0, 0]), 'int64')})
Example #10
0
def test_address_of():
    """Check ``address_of`` properties and create kernels that use it."""
    x, y = pystencils.fields('x,y: int64[2d]')
    pointer = pystencils.TypedSymbol('s', PointerType(create_type('int64')))

    # Basic properties of the address_of node.
    assert address_of(x[0, 0]).canonical() == x[0, 0]
    assert address_of(x[0, 0]).dtype == PointerType(x[0, 0].dtype, restrict=True)
    assert address_of(sp.Symbol("a")).dtype == PointerType('void', restrict=True)

    def show_kernel(assignment_dict):
        collection = pystencils.AssignmentCollection(assignment_dict, {})
        pystencils.show_code(pystencils.create_kernel(collection))

    # Address stored in a pointer symbol first, then cast to an integer.
    show_kernel({
        pointer: address_of(x[0, 0]),
        y[0, 0]: cast_func(pointer, create_type('int64'))
    })

    # Address cast to an integer directly.
    show_kernel({
        y[0, 0]: cast_func(address_of(x[0, 0]), create_type('int64'))
    })
Example #11
0
def test_abs():
    """``Abs`` of an integer-cast value must not be printed as ``fabs`` on GPU."""
    x, y, z = ps.fields('x, y, z:  float64[2d]')

    int_type = create_type('int64')
    integer_abs = sympy.Abs(cast_func(y[0, 0], int_type))
    assignments = ps.AssignmentCollection({x[0, 0]: integer_abs})

    gpu_config = ps.CreateKernelConfig(target=ps.Target.GPU)
    kernel = ps.create_kernel(assignments, config=gpu_config)
    code = ps.get_code_str(kernel)
    print(code)
    assert 'fabs(' not in code
Example #12
0
def upsample(input: {'field_type': FieldType.CUSTOM},
             result,
             sampling_factor=2):
    """Return assignments that fill ``result`` by repeating cells of ``input``.

    For every index ``i`` along the first index dimension, ``result.center(i)``
    is assigned the value of ``input`` at the absolute position obtained by
    floor-dividing the current coordinates by ``sampling_factor``.

    Args:
        input: source field; unless it is a ``FieldType.CUSTOM`` field, its
            spatial shape times ``sampling_factor`` must equal that of
            ``result``.
        result: destination field with the same number of spatial dimensions
            and the same index shape as ``input``.
        sampling_factor: upsampling factor applied along every axis.

    Returns:
        A list with one ``pystencils.Assignment`` per index of the first
        index dimension.
    """
    assert input.spatial_dimensions == result.spatial_dimensions
    # The shape check previously hard-coded a factor of 2 and therefore
    # rejected correctly sized fields for any other ``sampling_factor``.
    assert input.field_type == FieldType.CUSTOM \
        or result.spatial_shape == tuple(sampling_factor * x for x in input.spatial_shape)
    assert input.index_shape == result.index_shape

    ndim = input.spatial_dimensions
    # The coarse coordinates do not depend on the index, so build them once.
    coarse_coordinates = tuple(
        cast_func(x // sampling_factor, create_type("int"))
        for x in pystencils.x_vector(ndim))

    return [
        pystencils.Assignment(
            result.center(i),
            input.absolute_access(sympy.Matrix(coarse_coordinates), (i,)))
        for i in range(result.index_shape[0])
    ]
Example #13
0
    def _print_Add(self, expr, order=None):
        """Print a sum as nested calls of the instruction set's add/sub templates.

        The scalar fallback is tried first; if it yields a non-empty result
        (and does not raise) that result is returned.  Otherwise every summand
        is printed, positive terms are placed first, and the terms are folded
        left to right with the ``'+'`` / ``'-'`` templates.  If all summands
        are integers, the integer variants (``'+int'`` / ``'-int'``) are used.
        """
        try:
            scalar_code = self._scalarFallback('_print_Add', expr)
        except Exception:
            scalar_code = None
        if scalar_code:
            return scalar_code

        terms = expr.args

        def is_integer_term(e):
            # Integer cast, integer literal, or integer-typed scalar symbol.
            if type(e) is cast_func and str(e.dtype) == self.instruction_set['int']:
                return True
            if isinstance(e, sp.Integer):
                return True
            return type(e) is TypedSymbol and isinstance(e.dtype, BasicType) and e.dtype.is_int()

        # special treatment for all-integer args, for loop index arithmetic until we have proper int vectorization
        op_suffix = ""
        if all([is_integer_term(e) for e in terms]):
            cast_dtypes = {e.dtype for e in terms if type(e) is cast_func}
            assert len(cast_dtypes) == 1
            int_dtype = cast_dtypes.pop()
            terms = [cast_func(e, int_dtype) if isinstance(e, (sp.Integer, TypedSymbol)) else e
                     for e in terms]
            op_suffix = "int"

        summands = []
        for term in terms:
            if term.func == sp.Mul:
                sign, printed = self._print_Mul(term, inside_add=True)
            else:
                printed = self._print(term)
                sign = 1
            summands.append(self.SummandInfo(sign, printed))

        # Positive terms go first; the sort is stable, so relative order is kept.
        summands.sort(key=lambda info: info.sign, reverse=True)
        # Without any positive term, start the chain from a literal zero.
        if summands[0].sign == -1:
            summands.insert(0, self.SummandInfo(1, "0"))

        assert len(summands) >= 2
        first, *remaining = summands
        code = first.term
        for info in remaining:
            template_key = ('-' if info.sign == -1 else '+') + op_suffix
            code = self.instruction_set[template_key].format(code, info.term, **self._kwargs)
        return code
Example #14
0
    def _print_SympyAssignment(self, node):
        """Print an assignment node as C code.

        Declarations are printed as ``<type> <lhs> = <rhs>;`` where the type is
        ``auto`` or the (optionally ``const``) printed dtype of the left-hand
        side.  Assignments whose left-hand side is a vector-typed cast are
        printed as store instructions taken from the vector instruction set,
        covering aligned/unaligned, masked, strided and non-temporal stores
        plus optional cacheline zeroing/flushing code.  Everything else is
        printed as a plain ``<lhs> = <rhs>;``.
        """
        if node.is_declaration:
            if node.use_auto:
                data_type = 'auto '
            else:
                if node.is_const:
                    prefix = 'const '
                else:
                    prefix = ''
                # Strip a ' const' contained in the printed dtype; constness is expressed by the prefix only.
                data_type = prefix + self._print(node.lhs.dtype).replace(' const', '') + " "

            return "%s%s = %s;" % (data_type,
                                   self.sympy_printer.doprint(node.lhs),
                                   self.sympy_printer.doprint(node.rhs))
        else:
            lhs_type = get_type_of_expression(node.lhs)
            printed_mask = ""
            if type(lhs_type) is VectorType and isinstance(node.lhs, cast_func):
                arg, data_type, aligned, nontemporal, mask, stride = node.lhs.args
                # Default to an unaligned store; aligned stores stream if requested and available.
                instr = 'storeU'
                if aligned:
                    instr = 'stream' if nontemporal and 'stream' in self._vector_instruction_set else 'storeA'
                if mask != True:  # NOQA
                    instr = 'maskStoreA' if aligned else 'maskStoreU'
                    if instr not in self._vector_instruction_set:
                        # No masked store available: compose one from load + blendv + store
                        # (the trailing 'A'/'U' of the name selects the aligned/unaligned variant)
                        # and remember it in the instruction set for later calls.
                        self._vector_instruction_set[instr] = self._vector_instruction_set['store' + instr[-1]].format(
                            '{0}', self._vector_instruction_set['blendv'].format(
                                self._vector_instruction_set['load' + instr[-1]].format('{0}', **self._kwargs),
                                '{1}', '{2}', **self._kwargs), **self._kwargs)
                    printed_mask = self.sympy_printer.doprint(mask)
                    # For these x86 register types the mask is wrapped in a cast intrinsic to the integer vector type.
                    if data_type.base_type.base_name == 'double':
                        if self._vector_instruction_set['double'] == '__m256d':
                            printed_mask = f"_mm256_castpd_si256({printed_mask})"
                        elif self._vector_instruction_set['double'] == '__m128d':
                            printed_mask = f"_mm_castpd_si128({printed_mask})"
                    elif data_type.base_type.base_name == 'float':
                        if self._vector_instruction_set['float'] == '__m256':
                            printed_mask = f"_mm256_castps_si256({printed_mask})"
                        elif self._vector_instruction_set['float'] == '__m128':
                            printed_mask = f"_mm_castps_si128({printed_mask})"

                # A scalar right-hand side is cast to a vector before it is stored.
                rhs_type = get_type_of_expression(node.rhs)
                if type(rhs_type) is not VectorType:
                    rhs = cast_func(node.rhs, VectorType(rhs_type))
                else:
                    rhs = node.rhs

                ptr = "&" + self.sympy_printer.doprint(node.lhs.args[0])

                # Strided stores use dedicated instructions and return early.
                if stride != 1:
                    instr = 'maskStoreS' if mask != True else 'storeS'  # NOQA
                    return self._vector_instruction_set[instr].format(ptr, self.sympy_printer.doprint(rhs),
                                                                      stride, printed_mask, **self._kwargs) + ';'

                pre_code = ''
                if nontemporal and 'cachelineZero' in self._vector_instruction_set:
                    # Emit a guarded cacheline-zero instruction in front of the store: only when the
                    # pointer is at a cacheline start and offset + one cacheline is below the field size.
                    first_cond = f"((uintptr_t) {ptr} & {CachelineSize.mask_symbol}) == 0"
                    offset = sp.Add(*[sp.Symbol(LoopOverCoordinate.get_loop_counter_name(i))
                                      * node.lhs.args[0].field.spatial_strides[i] for i in
                                      range(len(node.lhs.args[0].field.spatial_strides))])
                    # NOTE(review): the strided case returned above, so this check looks always true here.
                    if stride == 1:
                        offset = offset.subs({node.lhs.args[0].field.spatial_strides[0]: 1})
                    size = sp.Mul(*node.lhs.args[0].field.spatial_shape)
                    # 8 bytes per element for 'double', 4 bytes for every other base type.
                    element_size = 8 if data_type.base_type.base_name == 'double' else 4
                    size_cond = f"({offset} + {CachelineSize.symbol/element_size}) < {size}"
                    pre_code = f"if ({first_cond} && {size_cond}) " + "{\n\t" + \
                        self._vector_instruction_set['cachelineZero'].format(ptr, **self._kwargs) + ';\n}\n'

                code = self._vector_instruction_set[instr].format(ptr, self.sympy_printer.doprint(rhs),
                                                                  printed_mask, **self._kwargs) + ';'
                # Condition that is true when the pointer's offset within the cacheline equals the "last" symbol.
                flushcond = f"((uintptr_t) {ptr} & {CachelineSize.mask_symbol}) == {CachelineSize.last_symbol}"
                if nontemporal and 'flushCacheline' in self._vector_instruction_set:
                    # Regular store, followed by a conditional cacheline flush.
                    code2 = self._vector_instruction_set['flushCacheline'].format(
                        ptr, self.sympy_printer.doprint(rhs), **self._kwargs) + ';'
                    code = f"{code}\nif ({flushcond}) {{\n\t{code2}\n}}"
                elif nontemporal and 'storeAAndFlushCacheline' in self._vector_instruction_set:
                    # Evaluate the right-hand side once into a temporary whose name is derived from
                    # the SHA-1 of its printed form, then either store-and-flush or store normally.
                    tmpvar = '_tmp_' + hashlib.sha1(self.sympy_printer.doprint(rhs).encode('ascii')).hexdigest()[:8]
                    code = 'const ' + self._print(node.lhs.dtype).replace(' const', '') + ' ' + tmpvar + ' = ' \
                        + self.sympy_printer.doprint(rhs) + ';'
                    code1 = self._vector_instruction_set[instr].format(ptr, tmpvar, printed_mask, **self._kwargs) + ';'
                    code2 = self._vector_instruction_set['storeAAndFlushCacheline'].format(ptr, tmpvar, printed_mask,
                                                                                           **self._kwargs) + ';'
                    code += f"\nif ({flushcond}) {{\n\t{code2}\n}} else {{\n\t{code1}\n}}"
                return pre_code + code
            else:
                return f"{self.sympy_printer.doprint(node.lhs)} = {self.sympy_printer.doprint(node.rhs)};"
Example #15
0
    def visit_expr(expr, default_type='double'):
        """Return ``expr`` with casts inserted so vector and scalar operands agree.

        ``default_type`` is forwarded to ``get_type_of_expression`` as the
        default float type when argument types of handled functions are
        determined, and is passed on to all recursive calls.
        """
        if isinstance(expr, vector_memory_access):
            # Only argument 4 is processed recursively; all other arguments are kept.
            # NOTE(review): argument 4 is presumably the store/load mask - confirm against vector_memory_access.
            return vector_memory_access(*expr.args[0:4], visit_expr(expr.args[4], default_type), *expr.args[5:])
        elif isinstance(expr, cast_func):
            return expr
        elif expr.func is sp.Abs and 'abs' not in ast_node.instruction_set:
            # The instruction set has no 'abs': rewrite Abs(x) as Piecewise((-x, x < 0), (x, True))
            # with a zero of the argument's numpy type, and process that instead.
            new_arg = visit_expr(expr.args[0], default_type)
            base_type = get_type_of_expression(expr.args[0]).base_type if type(expr.args[0]) is vector_memory_access \
                else get_type_of_expression(expr.args[0])
            pw = sp.Piecewise((-new_arg, new_arg < base_type.numpy_dtype.type(0)),
                              (new_arg, True))
            return visit_expr(pw, default_type)
        elif expr.func in handled_functions or isinstance(expr, sp.Rel) or isinstance(expr, BooleanFunction):
            if expr.func is sp.Mul and expr.args[0] == -1:
                # special treatment for the unary minus: make sure that the -1 has the same type as the argument
                dtype = int
                for arg in expr.atoms(vector_memory_access):
                    if arg.dtype.base_type.is_float():
                        dtype = arg.dtype.base_type.numpy_dtype.type
                for arg in expr.atoms(TypedSymbol):
                    if type(arg.dtype) is VectorType and arg.dtype.base_type.is_float():
                        dtype = arg.dtype.base_type.numpy_dtype.type
                if dtype is not int:
                    # A float32 operand also switches the default float type used for the arguments below.
                    if dtype is np.float32:
                        default_type = 'float'
                    expr = sp.Mul(dtype(expr.args[0]), *expr.args[1:])
            new_args = [visit_expr(a, default_type) for a in expr.args]
            arg_types = [get_type_of_expression(a, default_float_type=default_type) for a in new_args]
            if not any(type(t) is VectorType for t in arg_types):
                # No vector argument involved: nothing to cast.
                return expr
            else:
                # Cast every argument whose type differs from the collated type,
                # except vector memory accesses, which are left as they are.
                target_type = collate_types(arg_types)
                casted_args = [
                    cast_func(a, target_type) if t != target_type and not isinstance(a, vector_memory_access) else a
                    for a, t in zip(new_args, arg_types)]
                return expr.func(*casted_args)
        elif expr.func is sp.Pow:
            # Only the base is processed; the exponent is passed through unchanged.
            new_arg = visit_expr(expr.args[0], default_type)
            return expr.func(new_arg, expr.args[1])
        elif expr.func == sp.Piecewise:
            new_results = [visit_expr(a[0], default_type) for a in expr.args]
            new_conditions = [visit_expr(a[1], default_type) for a in expr.args]
            types_of_results = [get_type_of_expression(a) for a in new_results]
            types_of_conditions = [get_type_of_expression(a) for a in new_conditions]

            result_target_type = get_type_of_expression(expr)
            condition_target_type = collate_types(types_of_conditions)
            # If only one of results/conditions is a vector, promote the other to the same width.
            if type(condition_target_type) is VectorType and type(result_target_type) is not VectorType:
                result_target_type = VectorType(result_target_type, width=condition_target_type.width)
            if type(condition_target_type) is not VectorType and type(result_target_type) is VectorType:
                condition_target_type = VectorType(condition_target_type, width=result_target_type.width)

            casted_results = [cast_func(a, result_target_type) if t != result_target_type else a
                              for a, t in zip(new_results, types_of_results)]

            # NOTE(review): ``a is not True`` is an identity test against Python's True;
            # confirm it matches the condition objects that actually occur here.
            casted_conditions = [cast_func(a, condition_target_type)
                                 if t != condition_target_type and a is not True else a
                                 for a, t in zip(new_conditions, types_of_conditions)]

            return sp.Piecewise(*[(r, c) for r, c in zip(casted_results, casted_conditions)])
        else:
            return expr
Example #16
0
def vectorize_inner_loops_and_adapt_load_stores(ast_node, vector_width, assume_aligned, nontemporal_fields,
                                                strided, keep_loop_stop, assume_sufficient_line_padding):
    """Goes over all innermost loops, changes increment to vector width and replaces field accesses by vector type.

    The AST below ``ast_node`` is modified in place; nothing is returned.

    Args:
        ast_node: root of the tree to process; its ``instruction_set['intwidth']`` is read.
        vector_width: new step of the vectorized loops and width of the created vector types.
        assume_aligned: if true, accesses that have no offset besides the loop counters are marked as aligned.
        nontemporal_fields: fields (or field names) whose accesses are marked non-temporal.
        strided: if true, accesses where the loop counter is not a plain offset are allowed as long as
            their stride does not depend on the loop counter.
        keep_loop_stop: if true, the loop bounds are left untouched.
        assume_sufficient_line_padding: together with ``assume_aligned``, round the loop range up to a
            multiple of ``vector_width`` instead of cutting off a tail loop.
    """
    all_loops = filtered_tree_iteration(ast_node, ast.LoopOverCoordinate, stop_type=ast.SympyAssignment)
    inner_loops = [n for n in all_loops if n.is_innermost_loop]
    # NOTE(review): if filtered_tree_iteration returns a one-shot iterator, it is already exhausted by the
    # comprehension above and this dict would always be empty - confirm.
    zero_loop_counters = {l.loop_counter_symbol: 0 for l in all_loops}

    for loop_node in inner_loops:
        loop_range = loop_node.stop - loop_node.start

        # cut off the loop tail that is not a multiple of the vector width
        if keep_loop_stop:
            pass
        elif assume_aligned and assume_sufficient_line_padding:
            # Round the range up instead of splitting off a tail loop.
            loop_range = loop_node.stop - loop_node.start
            new_stop = loop_node.start + modulo_ceil(loop_range, vector_width)
            loop_node.stop = new_stop
        else:
            cutting_point = modulo_floor(loop_range, vector_width) + loop_node.start
            loop_nodes = [l for l in cut_loop(loop_node, [cutting_point]).args if isinstance(l, ast.LoopOverCoordinate)]
            assert len(loop_nodes) in (0, 1, 2)  # 2 for main and tail loop, 1 if loop range divisible by vector width
            if len(loop_nodes) == 0:
                continue
            # Continue with the first loop only.
            loop_node = loop_nodes[0]

        # Find all array accesses (indexed) that depend on the loop counter as offset
        loop_counter_symbol = ast.LoopOverCoordinate.get_loop_counter_symbol(loop_node.coordinate_to_loop_over)
        substitutions = {}
        successful = True
        for indexed in loop_node.atoms(sp.Indexed):
            base, index = indexed.args
            if loop_counter_symbol in index.atoms(sp.Symbol):
                # The counter is a plain offset if it no longer occurs after subtracting it once from the index.
                loop_counter_is_offset = loop_counter_symbol not in (index - loop_counter_symbol).atoms()
                aligned_access = (index - loop_counter_symbol).subs(zero_loop_counters) == 0
                # Change of the index when the loop counter advances by one.
                stride = sp.simplify(index.subs({loop_counter_symbol: loop_counter_symbol + 1}) - index)
                if not loop_counter_is_offset and (not strided or loop_counter_symbol in stride.atoms()):
                    successful = False
                    break
                typed_symbol = base.label
                assert type(typed_symbol.dtype) is PointerType, \
                    f"Type of access is {typed_symbol.dtype}, {indexed}"

                vec_type = VectorType(typed_symbol.dtype.base_type, vector_width)
                use_aligned_access = aligned_access and assume_aligned
                nontemporal = False
                if hasattr(indexed, 'field'):
                    nontemporal = (indexed.field in nontemporal_fields) or (indexed.field.name in nontemporal_fields)
                # The fifth argument (True) means "no mask" here; masks are handled by mask_conditionals below.
                # NOTE(review): inferred from the argument order - confirm against vector_memory_access.
                substitutions[indexed] = vector_memory_access(indexed, vec_type, use_aligned_access, nontemporal, True,
                                                              stride if strided else 1)
                if nontemporal:
                    # insert NontemporalFence after the outermost loop
                    parent = loop_node.parent
                    while type(parent.parent.parent) is not ast.KernelFunction:
                        parent = parent.parent
                    parent.parent.insert_after(NontemporalFence(), parent, if_not_exists=True)
                    # insert CachelineSize at the beginning of the kernel
                    parent.parent.insert_front(CachelineSize(), if_not_exists=True)
        if not successful:
            warnings.warn("Could not vectorize loop because of non-consecutive memory access")
            continue

        loop_node.step = vector_width
        loop_node.subs(substitutions)
        # Remaining uses of the loop counter become a vector: the broadcast counter plus a tuple of lane offsets.
        vector_int_width = ast_node.instruction_set['intwidth']
        vector_loop_counter = cast_func(loop_counter_symbol, VectorType(loop_counter_symbol.dtype, vector_int_width)) \
            + cast_func(tuple(range(vector_int_width if type(vector_int_width) is int else 2)),
                        VectorType(loop_counter_symbol.dtype, vector_int_width))

        # Resolved field accesses and vector memory accesses keep the scalar loop counter.
        fast_subs(loop_node, {loop_counter_symbol: vector_loop_counter},
                  skip=lambda e: isinstance(e, ast.ResolvedFieldAccess) or isinstance(e, vector_memory_access))

        mask_conditionals(loop_node)

        # Re-type the result symbols of all random number generator nodes in the loop as vectors.
        from pystencils.rng import RNGBase
        substitutions = {}
        for rng in loop_node.atoms(RNGBase):
            new_result_symbols = [TypedSymbol(s.name, VectorType(s.dtype, width=vector_width))
                                  for s in rng.result_symbols]
            substitutions.update({s[0]: s[1] for s in zip(rng.result_symbols, new_result_symbols)})
            rng._symbols_defined = set(new_result_symbols)
        fast_subs(loop_node, substitutions, skip=lambda e: isinstance(e, RNGBase))
Example #17
0
def test_cast_func():
    """Check ``canonical`` and the sign assumptions of ``cast_func``."""
    casted_symbol = cast_func(TypedSymbol("s", np.uint), np.int64)
    plain_symbol = TypedSymbol("s", np.uint)
    assert casted_symbol.canonical == plain_symbol.canonical

    # A value cast to an unsigned type is known to be non-negative.
    unsigned_five = cast_func(5, np.uint)
    assert unsigned_five.is_negative is False
    assert unsigned_five.is_nonnegative