Пример #1
0
    def _print_SympyAssignment(self, node):
        if node.is_declaration:
            if node.is_const:
                prefix = 'const '
            else:
                prefix = ''
            data_type = prefix + self._print(node.lhs.dtype) + " "
            return "%s%s = %s;" % (data_type,
                                   self.sympy_printer.doprint(node.lhs),
                                   self.sympy_printer.doprint(node.rhs))
        else:
            lhs_type = get_type_of_expression(node.lhs)
            if type(lhs_type) is VectorType and isinstance(
                    node.lhs, cast_func):
                arg, data_type, aligned, nontemporal = node.lhs.args
                instr = 'storeU'
                if aligned:
                    instr = 'stream' if nontemporal else 'storeA'

                rhs_type = get_type_of_expression(node.rhs)
                if type(rhs_type) is not VectorType:
                    rhs = cast_func(node.rhs, VectorType(rhs_type))
                else:
                    rhs = node.rhs

                return self._vector_instruction_set[instr].format(
                    "&" + self.sympy_printer.doprint(node.lhs.args[0]),
                    self.sympy_printer.doprint(rhs)) + ';'
            else:
                return "%s = %s;" % (self.sympy_printer.doprint(
                    node.lhs), self.sympy_printer.doprint(node.rhs))
Пример #2
0
    def _print_Function(self, expr):
        """Print vectorization-specific function nodes via ``self.instruction_set``.

        Handles vector memory loads, casts to a vector type (broadcast of one
        value or construction from a tuple of values), ``fast_division``,
        ``fast_sqrt``, ``fast_inv_sqrt`` and the ``vec_any`` / ``vec_all``
        reductions. Everything that no branch returns for is delegated to the
        parent class' ``_print_Function``.
        """
        if isinstance(expr, vector_memory_access):
            arg, _, aligned, _, _, stride = expr.args
            if stride != 1:
                strided_load = self.instruction_set['loadS']
                return strided_load.format(f"& {self._print(arg)}", stride, **self._kwargs)
            load = self.instruction_set['loadA' if aligned else 'loadU']
            return load.format(f"& {self._print(arg)}", **self._kwargs)
        elif isinstance(expr, cast_func):
            arg, target_type = expr.args
            if type(target_type) is VectorType:
                # vector_memory_access is a cast_func itself, so it must not sit directly inside a cast_func
                assert not isinstance(arg, vector_memory_access)

                if not isinstance(arg, sp.Tuple):
                    # A single value: pick the broadcast instruction matching its type.
                    scalar_type = get_type_of_expression(arg)
                    is_boolean = scalar_type == create_type("bool")
                    is_integer = scalar_type == create_type("int") or (
                        isinstance(arg, TypedSymbol)
                        and not isinstance(arg.dtype, VectorType)
                        and arg.dtype.is_int())
                    if is_boolean:
                        instruction = 'makeVecConstBool'
                    elif is_integer:
                        instruction = 'makeVecConstInt'
                    else:
                        instruction = 'makeVecConst'
                    return self.instruction_set[instruction].format(self._print(arg), **self._kwargs)

                # A tuple of values: the type of the first entry selects the instruction.
                element_type = get_type_of_expression(arg[0])
                is_boolean = element_type == create_type("bool")
                is_integer = element_type == create_type("int")
                printed_args = [self._print(a) for a in arg]
                if is_boolean:
                    instruction = 'makeVecBool'
                elif is_integer:
                    instruction = 'makeVecInt'
                else:
                    instruction = 'makeVec'
                if instruction == 'makeVecInt' and 'makeVecIndex' in self.instruction_set:
                    # Equal differences between neighbouring entries allow the index instruction.
                    values = np.array(arg)
                    increments = values[1:] - values[:-1]
                    if len(set(increments)) == 1:
                        return self.instruction_set['makeVecIndex'].format(printed_args[0], increments[0],
                                                                           **self._kwargs)
                return self.instruction_set[instruction].format(*printed_args, **self._kwargs)
        elif expr.func == fast_division:
            result = self._scalarFallback('_print_Function', expr)
            if result:
                return result
            return self.instruction_set['/'].format(self._print(expr.args[0]), self._print(expr.args[1]),
                                                    **self._kwargs)
        elif expr.func == fast_sqrt:
            return "({})".format(self._print(sp.sqrt(expr.args[0])))
        elif expr.func == fast_inv_sqrt:
            if not self._scalarFallback('_print_Function', expr):
                if 'rsqrt' not in self.instruction_set:
                    return "({})".format(self._print(1 / sp.sqrt(expr.args[0])))
                return self.instruction_set['rsqrt'].format(self._print(expr.args[0]), **self._kwargs)
            # With a truthy scalar fallback, control continues to the parent call at the end.
        elif isinstance(expr, (vec_any, vec_all)):
            instr = 'any' if isinstance(expr, vec_any) else 'all'
            operand = expr.args[0]
            if type(get_type_of_expression(operand)) is not VectorType:
                return self._print(operand)
            if isinstance(operand, sp.Rel):
                # Prefer an instruction registered for this (reduction, relation) pair.
                fused_key = (instr, operand.rel_op)
                if fused_key in self.instruction_set:
                    return self.instruction_set[fused_key].format(*[self._print(a) for a in operand.args],
                                                                  **self._kwargs)
            return self.instruction_set[instr].format(self._print(operand), **self._kwargs)

        return super(VectorizedCustomSympyPrinter, self)._print_Function(expr)
Пример #3
0
 def _print_Number(self, n):
     """Return an ``ir.Constant`` for the number ``n``.

     Numbers typed "int" use ``self.integer``, numbers typed "double" use
     ``self.fp_type``; any other type raises ``NotImplementedError``.
     """
     number_type = get_type_of_expression(n)
     if number_type == create_type("int"):
         return ir.Constant(self.integer, int(n))
     if number_type == create_type("double"):
         return ir.Constant(self.fp_type, float(n))
     raise NotImplementedError("Numbers can only have int and double", n)
Пример #4
0
 def to_c(self, print_func):
     """Return a C division of the two arguments, each cast to their collated type.

     ``print_func`` is called on each argument to obtain its code. The collated
     type must report ``is_int()``, otherwise the assertion fails.
     """
     lhs, rhs = self.args[0], self.args[1]
     dtype = collate_types((get_type_of_expression(lhs), get_type_of_expression(rhs)))
     assert dtype.is_int()
     numerator = print_func(lhs)
     denominator = print_func(rhs)
     return "(({dtype})({num}) / ({dtype})({den}))".format(dtype=dtype, num=numerator, den=denominator)
Пример #5
0
    def visit_expr(expr):
        """Return ``expr`` with casts inserted wherever argument types differ.

        Casts and vector memory accesses are returned untouched. For handled
        functions, relations and boolean functions the arguments are visited
        first and, if any of them has a vector type, cast to their collated
        type. For ``Pow`` only the base is visited. For ``Piecewise`` results
        and conditions are each brought to a common type. Anything else is
        returned unchanged.
        """
        if isinstance(expr, (cast_func, vector_memory_access)):
            return expr

        if expr.func in handled_functions or isinstance(expr, (sp.Rel, sp.boolalg.BooleanFunction)):
            new_args = [visit_expr(a) for a in expr.args]
            arg_types = [get_type_of_expression(a) for a in new_args]
            if all(type(t) is not VectorType for t in arg_types):
                # Purely scalar: keep the original expression.
                return expr
            target_type = collate_types(arg_types)
            casted_args = []
            for a, t in zip(new_args, arg_types):
                casted_args.append(a if t == target_type else cast_func(a, target_type))
            return expr.func(*casted_args)

        if expr.func is sp.Pow:
            # Only the base is visited; the exponent is passed through as is.
            return expr.func(visit_expr(expr.args[0]), expr.args[1])

        if expr.func == sp.Piecewise:
            new_results = [visit_expr(a[0]) for a in expr.args]
            new_conditions = [visit_expr(a[1]) for a in expr.args]
            types_of_results = [get_type_of_expression(a) for a in new_results]
            types_of_conditions = [get_type_of_expression(a) for a in new_conditions]

            result_target_type = get_type_of_expression(expr)
            condition_target_type = collate_types(types_of_conditions)

            # If only one side is vectorized, widen the other to the same width.
            conditions_vectorized = type(condition_target_type) is VectorType
            results_vectorized = type(result_target_type) is VectorType
            if conditions_vectorized and not results_vectorized:
                result_target_type = VectorType(result_target_type, width=condition_target_type.width)
            elif results_vectorized and not conditions_vectorized:
                condition_target_type = VectorType(condition_target_type, width=result_target_type.width)

            casted_results = [
                a if t == result_target_type else cast_func(a, result_target_type)
                for a, t in zip(new_results, types_of_results)
            ]
            # A literal ``True`` condition is never wrapped in a cast.
            casted_conditions = [
                cast_func(a, condition_target_type)
                if t != condition_target_type and a is not True else a
                for a, t in zip(new_conditions, types_of_conditions)
            ]
            return sp.Piecewise(*zip(casted_results, casted_conditions))

        return expr
Пример #6
0
def test_dtype_of_constants():
    """Smoke test: ``get_type_of_expression`` is called on ``sp.pi``; only an exception fails it."""
    # Some constants are neither of type Integer, Float nor Rational and don't have args:
    # >>> isinstance(pi, Integer)
    # False
    # >>> isinstance(pi, Float)
    # False
    # >>> isinstance(pi, Rational)
    # False
    # >>> pi.args
    # ()
    constant = sp.pi
    get_type_of_expression(constant)
Пример #7
0
    def __new__(cls, flag_bit, mask_expression, *expressions):
        """Create the object after checking that both leading arguments are integer typed.

        Raises:
            ValueError: if the type of ``flag_bit`` or of ``mask_expression``
                does not report ``is_int()``; ``flag_bit`` is checked first.
        """
        checked = (('flag_bit', flag_bit), ('mask_expression', mask_expression))
        for name, argument in checked:
            if not get_type_of_expression(argument).is_int():
                raise ValueError('Argument %s must be of integer type.' % name)

        return super().__new__(cls, flag_bit, mask_expression, *expressions)
Пример #8
0
    def _print_cast_func(self, conversion):
        """Emit the cast of ``conversion.args[0]`` to the type of ``conversion``.

        The operand is printed exactly once. When source and target type are
        equal, that printed operand is returned as is; otherwise the
        ``(from, to)`` type pair selects a builder call from the table below,
        which is then executed.

        Raises:
            KeyError: if the table has no entry for the ``(from, to)`` pair.
        """
        node = self._print(conversion.args[0])
        to_dtype = get_type_of_expression(conversion)
        from_dtype = get_type_of_expression(conversion.args[0])
        if from_dtype == to_dtype:
            # Reuse the operand printed above instead of printing it a second time.
            return node

        # (From, to)
        decision = {
            (create_composite_type_from_string("int16"),
             create_composite_type_from_string("int64")):
            lambda: ir.Constant(self.integer, node),
            (create_composite_type_from_string("int"),
             create_composite_type_from_string("double")):
            functools.partial(self.builder.sitofp, node, self.fp_type),
            (create_composite_type_from_string("int16"),
             create_composite_type_from_string("double")):
            functools.partial(self.builder.sitofp, node, self.fp_type),
            (create_composite_type_from_string("double"),
             create_composite_type_from_string("int")):
            functools.partial(self.builder.fptosi, node, self.integer),
            (create_composite_type_from_string("double *"),
             create_composite_type_from_string("int")):
            functools.partial(self.builder.ptrtoint, node, self.integer),
            (create_composite_type_from_string("int"),
             create_composite_type_from_string("double *")):
            functools.partial(self.builder.inttoptr, node, self.fp_pointer),
            (create_composite_type_from_string("double * restrict"),
             create_composite_type_from_string("int")):
            functools.partial(self.builder.ptrtoint, node, self.integer),
            (create_composite_type_from_string("int"),
             create_composite_type_from_string("double * restrict")):
            functools.partial(self.builder.inttoptr, node, self.fp_pointer),
            (create_composite_type_from_string("double * restrict const"),
             create_composite_type_from_string("int")):
            functools.partial(self.builder.ptrtoint, node, self.integer),
            (create_composite_type_from_string("int"),
             create_composite_type_from_string("double * restrict const")):
            functools.partial(self.builder.inttoptr, node, self.fp_pointer),
        }
        # TODO float, TEST: const, restrict
        # TODO bitcast, addrspacecast
        # TODO unsigned/signed fills
        # The entries are callables so that only the selected conversion is executed.
        return decision[(from_dtype, to_dtype)]()
Пример #9
0
 def _comparison(self, cmpop, expr):
     """Emit a comparison of ``expr.lhs`` and ``expr.rhs`` with operator ``cmpop``.

     ``self.builder.fcmp_unordered`` is used when the collated type of all
     arguments equals "double", ``self.builder.icmp_signed`` otherwise.
     """
     operand_type = collate_types([get_type_of_expression(arg) for arg in expr.args])
     if operand_type == create_type('double'):
         emit_compare = self.builder.fcmp_unordered
     else:
         emit_compare = self.builder.icmp_signed
     lhs_value = self._print(expr.lhs)
     rhs_value = self._print(expr.rhs)
     return emit_compare(cmpop, lhs_value, rhs_value)
Пример #10
0
 def _scalarFallback(self, func_name, expr, *args, **kwargs):
     """Print ``expr`` with the parent class' method ``func_name`` unless it is vector typed.

     Returns ``None`` for vector-typed expressions, after asserting that their
     width equals ``self.instruction_set['width']``.
     """
     expr_type = get_type_of_expression(expr)
     if type(expr_type) is VectorType:
         assert self.instruction_set['width'] == expr_type.width
         return None
     parent_method = getattr(super(VectorizedCustomSympyPrinter, self), func_name)
     return parent_method(expr, *args, **kwargs)
Пример #11
0
 def visit_node(node, substitution_dict, default_type='double'):
     """Recursively apply ``visit_expr`` to the assignments and conditionals below ``node``.

     Works on a copy of ``substitution_dict``, so the caller's mapping is not
     modified. When an assignment gets a vector-typed right-hand side while
     its ``TypedSymbol`` left-hand side is not vector typed, the symbol is
     replaced by a vector-typed one and that replacement is recorded for the
     nodes that follow.
     """
     substitution_dict = substitution_dict.copy()

     def substitute(expression):
         # Apply the replacements collected so far, skipping resolved field accesses.
         return fast_subs(expression, substitution_dict,
                          skip=lambda e: isinstance(e, ast.ResolvedFieldAccess))

     for child in node.args:
         if isinstance(child, ast.SympyAssignment):
             child.rhs = visit_expr(substitute(child.rhs), default_type)
             rhs_type = get_type_of_expression(child.rhs)
             lhs = child.lhs
             if isinstance(lhs, TypedSymbol):
                 lhs_type = lhs.dtype
                 if type(rhs_type) is VectorType and type(lhs_type) is not VectorType:
                     vector_lhs = TypedSymbol(lhs.name, VectorType(lhs_type, rhs_type.width))
                     substitution_dict[lhs] = vector_lhs
                     child.lhs = vector_lhs
             elif isinstance(lhs, vector_memory_access):
                 child.lhs = visit_expr(lhs, default_type)
             continue

         if isinstance(child, ast.Conditional):
             child.condition_expr = substitute(child.condition_expr)
             child.condition_expr = visit_expr(child.condition_expr, default_type)
         # Conditionals and every other node type are descended into.
         visit_node(child, substitution_dict, default_type)
Пример #12
0
    def _print_Piecewise(self, expr):
        result = self._scalarFallback('_print_Piecewise', expr)
        if result:
            return result

        if expr.args[-1].cond.args[0] is not sp.sympify(True):
            # We need the last conditional to be a True, otherwise the resulting
            # function may not return a result.
            raise ValueError("All Piecewise expressions must contain an "
                             "(expr, True) statement to be used as a default "
                             "condition. Without one, the generated "
                             "expression may not evaluate to anything under "
                             "some condition.")

        result = self._print(expr.args[-1][0])
        for true_expr, condition in reversed(expr.args[:-1]):
            if isinstance(condition, cast_func) and get_type_of_expression(
                    condition.args[0]) == create_type("bool"):
                if not KERNCRAFT_NO_TERNARY_MODE:
                    result = "(({}) ? ({}) : ({}))".format(
                        self._print(condition.args[0]), self._print(true_expr),
                        result)
                else:
                    print("Warning - skipping ternary op")
            else:
                # noinspection SpellCheckingInspection
                result = self.instruction_set['blendv'].format(
                    result, self._print(true_expr), self._print(condition))
        return result
Пример #13
0
 def _print_Mul(self, expr):
     """Emit a left-to-right chain of multiplications over the printed arguments.

     ``self.builder.fmul`` is used when the expression type equals "double",
     ``self.builder.mul`` otherwise.
     """
     factors = [self._print(arg) for arg in expr.args]
     product = factors[0]
     # int TODO unsigned/signed
     is_double = get_type_of_expression(expr) == create_type('double')
     multiply = self.builder.fmul if is_double else self.builder.mul
     for factor in factors[1:]:
         product = multiply(product, factor)
     return product
Пример #14
0
    def _print_Function(self, expr):
        """Print vectorization-specific function nodes via ``self.instruction_set``.

        Handles vector memory loads, casts to a vector type, ``fast_division``,
        ``fast_sqrt``, ``fast_inv_sqrt`` and the ``vec_any`` / ``vec_all``
        reductions. Everything that no branch returns for is delegated to the
        parent class' ``_print_Function``.
        """
        if isinstance(expr, vector_memory_access):
            arg, _, aligned, _ = expr.args
            load = self.instruction_set['loadA' if aligned else 'loadU']
            return load.format("& " + self._print(arg))
        elif isinstance(expr, cast_func):
            arg, target_type = expr.args
            if type(target_type) is VectorType:
                return self.instruction_set['makeVec'].format(self._print(arg))
        elif expr.func == fast_division:
            result = self._scalarFallback('_print_Function', expr)
            if result:
                return result
            return self.instruction_set['/'].format(
                self._print(expr.args[0]), self._print(expr.args[1]))
        elif expr.func == fast_sqrt:
            root = sp.sqrt(expr.args[0])
            return "({})".format(self._print(root))
        elif expr.func == fast_inv_sqrt:
            if not self._scalarFallback('_print_Function', expr):
                if not self.instruction_set['rsqrt']:
                    inverse_root = 1 / sp.sqrt(expr.args[0])
                    return "({})".format(self._print(inverse_root))
                return self.instruction_set['rsqrt'].format(self._print(expr.args[0]))
            # With a truthy scalar fallback, control continues to the parent call at the end.
        elif isinstance(expr, (vec_any, vec_all)):
            instr = 'any' if isinstance(expr, vec_any) else 'all'
            operand = expr.args[0]
            if type(get_type_of_expression(operand)) is not VectorType:
                return self._print(operand)
            return self.instruction_set[instr].format(self._print(operand))

        return super(VectorizedCustomSympyPrinter, self)._print_Function(expr)
Пример #15
0
 def _print_Add(self, expr):
     """Emit a left-to-right chain of additions over the printed arguments.

     ``self.builder.fadd`` is used when the expression type equals "double",
     ``self.builder.add`` otherwise.
     """
     summands = [self._print(arg) for arg in expr.args]
     total = summands[0]
     # int TODO unsigned/signed
     is_double = get_type_of_expression(expr) == create_type('double')
     accumulate = self.builder.fadd if is_double else self.builder.add
     for summand in summands[1:]:
         total = accumulate(total, summand)
     return total
Пример #16
0
    def _print_Pow(self, expr):
        """Print powers, writing small integer exponents as repeated multiplication.

        Expressions without free symbols are evaluated and printed as a typed
        number. Integer exponents in 1..7 become a product of the base,
        integer exponents in -7..-1 become ``1 / (product)``. Everything else
        is left to the parent implementation.
        """
        if not expr.free_symbols:
            return self._typed_number(expr.evalf(17), get_type_of_expression(expr.base))

        exponent = expr.exp
        if exponent.is_integer and exponent.is_number:
            if 0 < exponent < 8:
                repeated = sp.Mul(*[expr.base] * exponent, evaluate=False)
                return "({})".format(self._print(repeated))
            if -8 < exponent < 0:
                repeated = sp.Mul(*([expr.base] * -exponent), evaluate=False)
                return "1 / ({})".format(self._print(repeated))
        return super(CustomSympyPrinter, self)._print_Pow(expr)
Пример #17
0
 def check_type(e):
     """Tell whether the base type of ``e`` matches ``only_type`` from the enclosing scope.

     ``None`` accepts everything, ``'int'`` accepts signed and unsigned
     integers, ``'real'`` accepts floats; otherwise the base type is compared
     to ``only_type`` directly. Expressions whose type lookup raises
     ``ValueError`` are rejected.
     """
     if only_type is None:
         return True

     try:
         base_type = get_base_type(get_type_of_expression(e))
     except ValueError:
         return False

     wants_int = only_type == 'int'
     if wants_int and (base_type.is_int() or base_type.is_uint()):
         return True
     wants_real = only_type == 'real'
     if wants_real and base_type.is_float():
         return True
     return base_type == only_type
Пример #18
0
 def _print_Conditional(self, node):
     """Print a conditional node as a C ``if`` with an optional ``else`` block.

     Raises:
         ValueError: if the condition has a vector type.
     """
     if isinstance(get_type_of_expression(node.condition_expr), VectorType):
         raise ValueError(
             "Problem with Conditional inside vectorized loop - use vec_any or vec_all"
         )

     condition_code = self.sympy_printer.doprint(node.condition_expr)
     true_code = self._print_Block(node.true_block)
     pieces = ["if (%s)\n%s " % (condition_code, true_code)]
     if node.false_block:
         pieces.append("else " + self._print_Block(node.false_block))
     return "".join(pieces)
Пример #19
0
 def _print_Conditional(self, node):
     """Print a conditional node as a C ``if`` with an optional ``else`` block.

     A condition of exact type ``BooleanTrue`` / ``BooleanFalse`` is resolved
     here: only the true block, respectively the false block, is printed.

     Raises:
         ValueError: if the condition has a vector type.
     """
     condition_class = type(node.condition_expr)
     if condition_class is BooleanTrue:
         return self._print_Block(node.true_block)
     if condition_class is BooleanFalse:
         # NOTE(review): node.false_block is printed without the truthiness check used
         # further down - confirm _print_Block copes with a missing false block.
         return self._print_Block(node.false_block)

     if isinstance(get_type_of_expression(node.condition_expr), VectorType):
         raise ValueError("Problem with Conditional inside vectorized loop - use vec_any or vec_all")

     condition_code = self.sympy_printer.doprint(node.condition_expr)
     true_code = self._print_Block(node.true_block)
     pieces = ["if ({})\n{} ".format(condition_code, true_code)]
     if node.false_block:
         pieces.append("else {}".format(self._print_Block(node.false_block)))
     return "".join(pieces)
Пример #20
0
    def __new__(cls, arg1, arg2):
        """Create the two-argument integer function after typing and checking both arguments.

        ``sp.Number`` and ``int`` arguments are cast to "int", numpy scalars
        are cast to their own dtype, everything else is passed on unchanged.

        Raises:
            ValueError: if an argument's type does not report ``is_int()``, or
                if determining its type raises ``NotImplementedError``.
        """
        typed_args = []
        for raw in (arg1, arg2):
            if isinstance(raw, (sp.Number, int)):
                raw = cast_func(raw, create_type("int"))
            elif isinstance(raw, np.generic):
                raw = cast_func(raw, raw.dtype)
            typed_args.append(raw)

        for operand in typed_args:
            try:
                operand_type = get_type_of_expression(operand)
                if not operand_type.is_int():
                    raise ValueError("Argument to integer function is not an int but " + str(operand_type))
            except NotImplementedError:
                raise ValueError("Integer functions can only be constructed with typed expressions")
        return super().__new__(cls, *typed_args)
Пример #21
0
    def check_type(e):
        """Tell whether the type of ``e`` matches ``only_type`` from the enclosing scope.

        ``None`` accepts everything. Vector types are always rejected, pointer
        types count as ``'int'``. ``'int'`` accepts signed and unsigned
        integers, ``'real'`` accepts floats; otherwise the type is compared to
        ``only_type`` directly. Expressions whose type lookup raises
        ``ValueError`` are rejected.
        """
        if only_type is None:
            return True

        if only_type == "real" and isinstance(e, FieldPointerSymbol):
            # For a plain string ``only_type`` this comparison is False.
            return only_type == "int"

        try:
            base_type = get_type_of_expression(e)
        except ValueError:
            return False

        if isinstance(base_type, VectorType):
            return False
        if isinstance(base_type, PointerType):
            return only_type == 'int'

        wants_int = only_type == 'int'
        if wants_int and (base_type.is_int() or base_type.is_uint()):
            return True
        wants_real = only_type == 'real'
        if wants_real and base_type.is_float():
            return True
        return base_type == only_type
Пример #22
0
    def _print_Product(self, expr):
        template = """[&]() {{
    {dtype} product = ({dtype}) 1;
    for ( {iterator_dtype} {var} = {start}; {condition}; {var} += {increment} ) {{
        product *= {expr};
    }}
    return product;
}}()"""
        var = expr.limits[0][0]
        start = expr.limits[0][1]
        end = expr.limits[0][2]
        code = template.format(
            dtype=get_type_of_expression(expr.args[0]),
            iterator_dtype='int',
            var=self._print(var),
            start=self._print(start),
            end=self._print(end),
            expr=self._print(expr.function),
            increment=str(1),
            condition=self._print(var) + ' <= ' + self._print(end)  # if start < end else '>='
        )
        return code
Пример #23
0
    def _print_Piecewise(self, piece):
        if not piece.args[-1].cond:
            # We need the last conditional to be a True, otherwise the resulting
            # function may not return a result.
            raise ValueError("All Piecewise expressions must contain an "
                             "(expr, True) statement to be used as a default "
                             "condition. Without one, the generated "
                             "expression may not evaluate to anything under "
                             "some condition.")
        if piece.has(Assignment):
            raise NotImplementedError(
                'The llvm-backend does not support assignments'
                'in the Piecewise function. It is questionable'
                'whether to implement it. So far there is no'
                'use-case to test it.')
        else:
            phi_data = []
            after_block = self.builder.append_basic_block()
            for (expr, condition) in piece.args:
                if condition == sp.sympify(True):  # Don't use 'is' use '=='!
                    phi_data.append((self._print(expr), self.builder.block))
                    self.builder.branch(after_block)
                    self.builder.position_at_end(after_block)
                else:
                    cond = self._print(condition)
                    true_block = self.builder.append_basic_block()
                    false_block = self.builder.append_basic_block()
                    self.builder.cbranch(cond, true_block, false_block)
                    self.builder.position_at_end(true_block)
                    phi_data.append((self._print(expr), true_block))
                    self.builder.branch(after_block)
                    self.builder.position_at_end(false_block)

            phi = self.builder.phi(to_llvm_type(get_type_of_expression(piece)))
            for (val, block) in phi_data:
                phi.add_incoming(val, block)
            return phi
Пример #24
0
    def _print_SympyAssignment(self, node):
        """Render an assignment node as C code.

        Declarations are printed as ``<type> <lhs> = <rhs>;`` where the type is
        ``auto`` if ``node.use_auto`` is set, otherwise the printed dtype of
        the left-hand side with an optional leading ``const``.

        A non-declaration whose left-hand side is a ``cast_func`` of vector
        type is printed through a store instruction taken from
        ``self._vector_instruction_set`` (plain, streaming, masked or strided),
        optionally surrounded by cacheline zeroing / flushing code for
        nontemporal stores. Any other assignment becomes ``<lhs> = <rhs>;``.

        Note that a missing masked-store instruction is synthesized and stored
        back into ``self._vector_instruction_set``.
        """
        if node.is_declaration:
            if node.use_auto:
                data_type = 'auto '
            else:
                if node.is_const:
                    prefix = 'const '
                else:
                    prefix = ''
                # ' const' is stripped from the printed type; constness is expressed by the prefix only
                data_type = prefix + self._print(node.lhs.dtype).replace(' const', '') + " "

            return "%s%s = %s;" % (data_type,
                                   self.sympy_printer.doprint(node.lhs),
                                   self.sympy_printer.doprint(node.rhs))
        else:
            lhs_type = get_type_of_expression(node.lhs)
            printed_mask = ""
            if type(lhs_type) is VectorType and isinstance(node.lhs, cast_func):
                arg, data_type, aligned, nontemporal, mask, stride = node.lhs.args
                # Instruction choice: unaligned store by default; aligned stores stream when
                # requested and available, otherwise use the aligned store.
                instr = 'storeU'
                if aligned:
                    instr = 'stream' if nontemporal and 'stream' in self._vector_instruction_set else 'storeA'
                if mask != True:  # NOQA
                    instr = 'maskStoreA' if aligned else 'maskStoreU'
                    if instr not in self._vector_instruction_set:
                        # No masked store available: build store(ptr, blendv(load(ptr), {1}, {2})) from the
                        # store/load variant with the same suffix (instr[-1] is 'A' or 'U') and cache it.
                        self._vector_instruction_set[instr] = self._vector_instruction_set['store' + instr[-1]].format(
                            '{0}', self._vector_instruction_set['blendv'].format(
                                self._vector_instruction_set['load' + instr[-1]].format('{0}', **self._kwargs),
                                '{1}', '{2}', **self._kwargs), **self._kwargs)
                    printed_mask = self.sympy_printer.doprint(mask)
                    # For the listed vector register types the printed mask is wrapped in a cast intrinsic.
                    if data_type.base_type.base_name == 'double':
                        if self._vector_instruction_set['double'] == '__m256d':
                            printed_mask = f"_mm256_castpd_si256({printed_mask})"
                        elif self._vector_instruction_set['double'] == '__m128d':
                            printed_mask = f"_mm_castpd_si128({printed_mask})"
                    elif data_type.base_type.base_name == 'float':
                        if self._vector_instruction_set['float'] == '__m256':
                            printed_mask = f"_mm256_castps_si256({printed_mask})"
                        elif self._vector_instruction_set['float'] == '__m128':
                            printed_mask = f"_mm_castps_si128({printed_mask})"

                # A scalar right-hand side is wrapped in a cast to the corresponding vector type.
                rhs_type = get_type_of_expression(node.rhs)
                if type(rhs_type) is not VectorType:
                    rhs = cast_func(node.rhs, VectorType(rhs_type))
                else:
                    rhs = node.rhs

                ptr = "&" + self.sympy_printer.doprint(node.lhs.args[0])

                # Strided stores use their own instructions and return right away.
                if stride != 1:
                    instr = 'maskStoreS' if mask != True else 'storeS'  # NOQA
                    return self._vector_instruction_set[instr].format(ptr, self.sympy_printer.doprint(rhs),
                                                                      stride, printed_mask, **self._kwargs) + ';'

                # Optional code placed in front of a nontemporal store: zero the cacheline when the
                # pointer satisfies the mask condition below and the size condition holds.
                pre_code = ''
                if nontemporal and 'cachelineZero' in self._vector_instruction_set:
                    first_cond = f"((uintptr_t) {ptr} & {CachelineSize.mask_symbol}) == 0"
                    offset = sp.Add(*[sp.Symbol(LoopOverCoordinate.get_loop_counter_name(i))
                                      * node.lhs.args[0].field.spatial_strides[i] for i in
                                      range(len(node.lhs.args[0].field.spatial_strides))])
                    # Always true at this point: strided stores have returned above.
                    if stride == 1:
                        offset = offset.subs({node.lhs.args[0].field.spatial_strides[0]: 1})
                    size = sp.Mul(*node.lhs.args[0].field.spatial_shape)
                    # Element size used in the condition: 8 for 'double', 4 for everything else.
                    element_size = 8 if data_type.base_type.base_name == 'double' else 4
                    size_cond = f"({offset} + {CachelineSize.symbol/element_size}) < {size}"
                    pre_code = f"if ({first_cond} && {size_cond}) " + "{\n\t" + \
                        self._vector_instruction_set['cachelineZero'].format(ptr, **self._kwargs) + ';\n}\n'

                # The store itself.
                code = self._vector_instruction_set[instr].format(ptr, self.sympy_printer.doprint(rhs),
                                                                  printed_mask, **self._kwargs) + ';'
                # Condition comparing the masked pointer with CachelineSize.last_symbol.
                flushcond = f"((uintptr_t) {ptr} & {CachelineSize.mask_symbol}) == {CachelineSize.last_symbol}"
                if nontemporal and 'flushCacheline' in self._vector_instruction_set:
                    # Store first, then flush conditionally.
                    code2 = self._vector_instruction_set['flushCacheline'].format(
                        ptr, self.sympy_printer.doprint(rhs), **self._kwargs) + ';'
                    code = f"{code}\nif ({flushcond}) {{\n\t{code2}\n}}"
                elif nontemporal and 'storeAAndFlushCacheline' in self._vector_instruction_set:
                    # The right-hand side is stored in a const temporary whose name is derived from the
                    # SHA-1 of its printed form; then either the combined store-and-flush instruction or
                    # the regular store is used. This replaces the store code assembled above.
                    tmpvar = '_tmp_' + hashlib.sha1(self.sympy_printer.doprint(rhs).encode('ascii')).hexdigest()[:8]
                    code = 'const ' + self._print(node.lhs.dtype).replace(' const', '') + ' ' + tmpvar + ' = ' \
                        + self.sympy_printer.doprint(rhs) + ';'
                    code1 = self._vector_instruction_set[instr].format(ptr, tmpvar, printed_mask, **self._kwargs) + ';'
                    code2 = self._vector_instruction_set['storeAAndFlushCacheline'].format(ptr, tmpvar, printed_mask,
                                                                                           **self._kwargs) + ';'
                    code += f"\nif ({flushcond}) {{\n\t{code2}\n}} else {{\n\t{code1}\n}}"
                return pre_code + code
            else:
                return f"{self.sympy_printer.doprint(node.lhs)} = {self.sympy_printer.doprint(node.rhs)};"
Пример #25
0
    def visit_expr(expr, default_type='double'):
        """Return ``expr`` with casts inserted wherever argument types differ.

        ``default_type`` is passed to ``get_type_of_expression`` as
        ``default_float_type`` when the argument types of handled functions
        are determined, and is handed down to recursive calls.

        Uses ``ast_node`` and ``handled_functions`` from the enclosing scope.
        """
        if isinstance(expr, vector_memory_access):
            # Only the argument at index 4 is visited; all other arguments are kept as they are.
            return vector_memory_access(*expr.args[0:4], visit_expr(expr.args[4], default_type), *expr.args[5:])
        elif isinstance(expr, cast_func):
            return expr
        elif expr.func is sp.Abs and 'abs' not in ast_node.instruction_set:
            # No 'abs' instruction available: rewrite Abs as a Piecewise and visit that instead.
            new_arg = visit_expr(expr.args[0], default_type)
            base_type = get_type_of_expression(expr.args[0]).base_type if type(expr.args[0]) is vector_memory_access \
                else get_type_of_expression(expr.args[0])
            # The zero in the comparison is created with the numpy type of base_type.
            pw = sp.Piecewise((-new_arg, new_arg < base_type.numpy_dtype.type(0)),
                              (new_arg, True))
            return visit_expr(pw, default_type)
        elif expr.func in handled_functions or isinstance(expr, sp.Rel) or isinstance(expr, BooleanFunction):
            if expr.func is sp.Mul and expr.args[0] == -1:
                # special treatment for the unary minus: make sure that the -1 has the same type as the argument
                dtype = int
                # Any float-typed vector access or float vector symbol inside expr determines the type.
                for arg in expr.atoms(vector_memory_access):
                    if arg.dtype.base_type.is_float():
                        dtype = arg.dtype.base_type.numpy_dtype.type
                for arg in expr.atoms(TypedSymbol):
                    if type(arg.dtype) is VectorType and arg.dtype.base_type.is_float():
                        dtype = arg.dtype.base_type.numpy_dtype.type
                if dtype is not int:
                    if dtype is np.float32:
                        # From here on this call (and its recursive calls) works with 'float' as default type.
                        default_type = 'float'
                    # Rebuild the product with the -1 converted to the chosen numpy type.
                    expr = sp.Mul(dtype(expr.args[0]), *expr.args[1:])
            new_args = [visit_expr(a, default_type) for a in expr.args]
            arg_types = [get_type_of_expression(a, default_float_type=default_type) for a in new_args]
            if not any(type(t) is VectorType for t in arg_types):
                # No vector-typed argument: the expression is returned unchanged.
                return expr
            else:
                target_type = collate_types(arg_types)
                # Vector memory accesses are never wrapped in a cast.
                casted_args = [
                    cast_func(a, target_type) if t != target_type and not isinstance(a, vector_memory_access) else a
                    for a, t in zip(new_args, arg_types)]
                return expr.func(*casted_args)
        elif expr.func is sp.Pow:
            # Only the base is visited; the exponent is passed through as is.
            new_arg = visit_expr(expr.args[0], default_type)
            return expr.func(new_arg, expr.args[1])
        elif expr.func == sp.Piecewise:
            new_results = [visit_expr(a[0], default_type) for a in expr.args]
            new_conditions = [visit_expr(a[1], default_type) for a in expr.args]
            types_of_results = [get_type_of_expression(a) for a in new_results]
            types_of_conditions = [get_type_of_expression(a) for a in new_conditions]

            result_target_type = get_type_of_expression(expr)
            condition_target_type = collate_types(types_of_conditions)
            # If only one of results / conditions is vector typed, widen the other to the same width.
            if type(condition_target_type) is VectorType and type(result_target_type) is not VectorType:
                result_target_type = VectorType(result_target_type, width=condition_target_type.width)
            if type(condition_target_type) is not VectorType and type(result_target_type) is VectorType:
                condition_target_type = VectorType(condition_target_type, width=result_target_type.width)

            casted_results = [cast_func(a, result_target_type) if t != result_target_type else a
                              for a, t in zip(new_results, types_of_results)]

            # A literal True condition is never wrapped in a cast.
            casted_conditions = [cast_func(a, condition_target_type)
                                 if t != condition_target_type and a is not True else a
                                 for a, t in zip(new_conditions, types_of_conditions)]

            return sp.Piecewise(*[(r, c) for r, c in zip(casted_results, casted_conditions)])
        else:
            return expr