def _print_SympyAssignment(self, node):
    """Render a SympyAssignment node as a C statement.

    Declarations become ``[const ]<type> lhs = rhs;``. A non-declaration whose
    left-hand side is a vector-typed ``cast_func`` becomes a vector store
    (``storeU``, ``storeA`` or ``stream``); every other assignment is printed as
    ``lhs = rhs;``.
    """
    doprint = self.sympy_printer.doprint

    if node.is_declaration:
        prefix = 'const ' if node.is_const else ''
        data_type = prefix + self._print(node.lhs.dtype) + " "
        return "%s%s = %s;" % (data_type, doprint(node.lhs), doprint(node.rhs))

    lhs_type = get_type_of_expression(node.lhs)
    is_vector_store = type(lhs_type) is VectorType and isinstance(node.lhs, cast_func)
    if not is_vector_store:
        return "%s = %s;" % (doprint(node.lhs), doprint(node.rhs))

    # The store target's args are (pointer, data type, aligned, nontemporal).
    _, _, aligned, nontemporal = node.lhs.args
    if not aligned:
        instr = 'storeU'
    elif nontemporal:
        instr = 'stream'
    else:
        instr = 'storeA'

    # A scalar right-hand side is wrapped in a vector cast before storing.
    rhs_type = get_type_of_expression(node.rhs)
    rhs = node.rhs if type(rhs_type) is VectorType else cast_func(node.rhs, VectorType(rhs_type))
    target = "&" + doprint(node.lhs.args[0])
    return self._vector_instruction_set[instr].format(target, doprint(rhs)) + ';'
def _print_Function(self, expr):
    """Print vector-aware function nodes using the active instruction set.

    Handles vector loads (``vector_memory_access``), casts/broadcasts to a
    ``VectorType`` (``cast_func``), ``fast_division``, ``fast_sqrt``,
    ``fast_inv_sqrt`` and the ``vec_any``/``vec_all`` reductions. Anything else
    is delegated to the base-class printer.
    """
    if isinstance(expr, vector_memory_access):
        arg, data_type, aligned, _, mask, stride = expr.args
        if stride != 1:
            return self.instruction_set['loadS'].format(f"& {self._print(arg)}", stride, **self._kwargs)
        instruction = self.instruction_set['loadA'] if aligned else self.instruction_set['loadU']
        return instruction.format(f"& {self._print(arg)}", **self._kwargs)
    elif isinstance(expr, cast_func):
        arg, data_type = expr.args
        if type(data_type) is VectorType:
            # vector_memory_access is a cast_func itself so it shouldn't be directly inside a cast_func
            assert not isinstance(arg, vector_memory_access)
            if isinstance(arg, sp.Tuple):
                # Build a vector from explicitly given lanes; the first lane decides the kind.
                first_type = get_type_of_expression(arg[0])
                is_boolean = first_type == create_type("bool")
                is_integer = first_type == create_type("int")
                printed_args = [self._print(a) for a in arg]
                instruction = 'makeVecBool' if is_boolean else 'makeVecInt' if is_integer else 'makeVec'
                if instruction == 'makeVecInt' and 'makeVecIndex' in self.instruction_set:
                    # Evenly spaced integer lanes can use the dedicated index instruction.
                    increments = np.array(arg)[1:] - np.array(arg)[:-1]
                    if len(set(increments)) == 1:
                        return self.instruction_set['makeVecIndex'].format(printed_args[0], increments[0],
                                                                           **self._kwargs)
                return self.instruction_set[instruction].format(*printed_args, **self._kwargs)
            else:
                # Broadcast a single scalar to all lanes.
                arg_type = get_type_of_expression(arg)
                is_boolean = arg_type == create_type("bool")
                is_integer = arg_type == create_type("int") or \
                    (isinstance(arg, TypedSymbol) and not isinstance(arg.dtype, VectorType) and arg.dtype.is_int())
                instruction = 'makeVecConstBool' if is_boolean else \
                    'makeVecConstInt' if is_integer else 'makeVecConst'
                return self.instruction_set[instruction].format(self._print(arg), **self._kwargs)
    elif expr.func == fast_division:
        result = self._scalarFallback('_print_Function', expr)
        if not result:
            result = self.instruction_set['/'].format(self._print(expr.args[0]), self._print(expr.args[1]),
                                                      **self._kwargs)
        return result
    elif expr.func == fast_sqrt:
        return f"({self._print(sp.sqrt(expr.args[0]))})"
    elif expr.func == fast_inv_sqrt:
        result = self._scalarFallback('_print_Function', expr)
        if result:
            # Scalar expression: the base printer has already produced the code,
            # so return it instead of printing the expression a second time.
            return result
        if 'rsqrt' in self.instruction_set:
            return self.instruction_set['rsqrt'].format(self._print(expr.args[0]), **self._kwargs)
        return f"({self._print(1 / sp.sqrt(expr.args[0]))})"
    elif isinstance(expr, (vec_any, vec_all)):
        instr = 'any' if isinstance(expr, vec_any) else 'all'
        expr_type = get_type_of_expression(expr.args[0])
        if type(expr_type) is not VectorType:
            return self._print(expr.args[0])
        if isinstance(expr.args[0], sp.Rel):
            # Use a dedicated (reduction, comparison) instruction when the set provides one.
            op = expr.args[0].rel_op
            if (instr, op) in self.instruction_set:
                return self.instruction_set[(instr, op)].format(*[self._print(a) for a in expr.args[0].args],
                                                                **self._kwargs)
        return self.instruction_set[instr].format(self._print(expr.args[0]), **self._kwargs)
    return super(VectorizedCustomSympyPrinter, self)._print_Function(expr)
def _print_Number(self, n):
    """Lower a sympy number to an LLVM IR constant.

    Only int-typed and double-typed numbers are supported.

    Raises:
        NotImplementedError: for numbers of any other type.
    """
    number_type = get_type_of_expression(n)
    if number_type == create_type("int"):
        return ir.Constant(self.integer, int(n))
    if number_type == create_type("double"):
        return ir.Constant(self.fp_type, float(n))
    raise NotImplementedError("Numbers can only have int and double", n)
def to_c(self, print_func):
    """Render this integer division as C, casting both operands to their collated type.

    ``print_func`` turns each operand into C source text.
    """
    numerator = self.args[0]
    denominator = self.args[1]
    dtype = collate_types((get_type_of_expression(numerator), get_type_of_expression(denominator)))
    assert dtype.is_int()
    template = "(({dtype})({0}) / ({dtype})({1}))"
    return template.format(print_func(numerator), print_func(denominator), dtype=dtype)
def visit_expr(expr):
    """Recursively insert vector casts so that mixed scalar/vector operands agree.

    Returns ``expr`` unchanged when none of its (visited) arguments are
    vector-typed; otherwise rebuilds it with ``cast_func`` wrappers on the
    arguments whose type differs from the collated target type.
    """
    # Casts and vector loads are already typed; leave them alone.
    if isinstance(expr, cast_func) or isinstance(expr, vector_memory_access):
        return expr
    elif expr.func in handled_functions or isinstance(
            expr, sp.Rel) or isinstance(expr, sp.boolalg.BooleanFunction):
        new_args = [visit_expr(a) for a in expr.args]
        arg_types = [get_type_of_expression(a) for a in new_args]
        if not any(type(t) is VectorType for t in arg_types):
            # Purely scalar: keep the original expression (not rebuilt from new_args).
            return expr
        else:
            target_type = collate_types(arg_types)
            casted_args = [
                cast_func(a, target_type) if t != target_type else a
                for a, t in zip(new_args, arg_types)
            ]
            return expr.func(*casted_args)
    elif expr.func is sp.Pow:
        # Only the base is visited; the exponent is kept as-is.
        new_arg = visit_expr(expr.args[0])
        return expr.func(new_arg, expr.args[1])
    elif expr.func == sp.Piecewise:
        new_results = [visit_expr(a[0]) for a in expr.args]
        new_conditions = [visit_expr(a[1]) for a in expr.args]
        types_of_results = [get_type_of_expression(a) for a in new_results]
        types_of_conditions = [
            get_type_of_expression(a) for a in new_conditions
        ]

        result_target_type = get_type_of_expression(expr)
        condition_target_type = collate_types(types_of_conditions)
        # If either results or conditions are vectors, promote the other side to
        # a vector of the same width.
        if type(condition_target_type) is VectorType and type(
                result_target_type) is not VectorType:
            result_target_type = VectorType(
                result_target_type, width=condition_target_type.width)
        if type(condition_target_type) is not VectorType and type(
                result_target_type) is VectorType:
            condition_target_type = VectorType(
                condition_target_type, width=result_target_type.width)

        casted_results = [
            cast_func(a, result_target_type) if t != result_target_type else a
            for a, t in zip(new_results, types_of_results)
        ]

        # NOTE(review): `a is not True` compares against Python's True object;
        # sympy's true singleton is a different object, so this guard presumably
        # does not exclude sympy.true (the default branch's condition gets cast too).
        casted_conditions = [
            cast_func(a, condition_target_type)
            if t != condition_target_type and a is not True else a
            for a, t in zip(new_conditions, types_of_conditions)
        ]

        return sp.Piecewise(
            *[(r, c) for r, c in zip(casted_results, casted_conditions)])
    else:
        return expr
def test_dtype_of_constants():
    """Smoke test: get_type_of_expression must not fail on argument-less constants like pi."""
    # Some constants are neither of type Integer, Float nor Rational and don't have args
    # >>> isinstance(pi, Integer)
    # False
    # >>> isinstance(pi, Float)
    # False
    # >>> isinstance(pi, Rational)
    # False
    # >>> pi.args
    # ()
    get_type_of_expression(sp.pi)
def __new__(cls, flag_bit, mask_expression, *expressions):
    """Construct the node after checking that flag_bit and mask_expression are integer-typed.

    Raises:
        ValueError: if either checked argument does not have an integer type.
    """
    checked = (
        (flag_bit, 'Argument flag_bit must be of integer type.'),
        (mask_expression, 'Argument mask_expression must be of integer type.'),
    )
    for value, message in checked:
        if not get_type_of_expression(value).is_int():
            raise ValueError(message)
    return super().__new__(cls, flag_bit, mask_expression, *expressions)
def _print_cast_func(self, conversion):
    """Lower a ``cast_func`` node to the matching LLVM conversion instruction.

    The operand is printed exactly once. If the source and target types are
    equal, that printed operand is returned unchanged. Otherwise the conversion
    is selected by the ``(from_dtype, to_dtype)`` pair.

    Raises:
        KeyError: if no lowering is registered for the ``(from, to)`` type pair.
    """
    node = self._print(conversion.args[0])
    to_dtype = get_type_of_expression(conversion)
    from_dtype = get_type_of_expression(conversion.args[0])
    if from_dtype == to_dtype:
        # Reuse the operand printed above: printing it again would lower the
        # sub-expression a second time and may emit redundant IR instructions.
        return node

    parse = create_composite_type_from_string
    # (From, to) -> deferred builder call; only the selected entry is executed.
    decision = {
        (parse("int16"), parse("int64")):
            lambda: ir.Constant(self.integer, node),
        (parse("int"), parse("double")):
            functools.partial(self.builder.sitofp, node, self.fp_type),
        (parse("int16"), parse("double")):
            functools.partial(self.builder.sitofp, node, self.fp_type),
        (parse("double"), parse("int")):
            functools.partial(self.builder.fptosi, node, self.integer),
        (parse("double *"), parse("int")):
            functools.partial(self.builder.ptrtoint, node, self.integer),
        (parse("int"), parse("double *")):
            functools.partial(self.builder.inttoptr, node, self.fp_pointer),
        (parse("double * restrict"), parse("int")):
            functools.partial(self.builder.ptrtoint, node, self.integer),
        (parse("int"), parse("double * restrict")):
            functools.partial(self.builder.inttoptr, node, self.fp_pointer),
        (parse("double * restrict const"), parse("int")):
            functools.partial(self.builder.ptrtoint, node, self.integer),
        (parse("int"), parse("double * restrict const")):
            functools.partial(self.builder.inttoptr, node, self.fp_pointer),
    }
    # TODO float, TEST: const, restrict
    # TODO bitcast, addrspacecast
    # TODO unsigned/signed fills
    return decision[(from_dtype, to_dtype)]()
def _comparison(self, cmpop, expr):
    """Emit an LLVM comparison of ``expr.lhs`` and ``expr.rhs`` using operator ``cmpop``.

    Double-typed operands use an unordered float compare; everything else uses
    a signed integer compare.
    """
    operand_type = collate_types([get_type_of_expression(arg) for arg in expr.args])
    compare = (self.builder.fcmp_unordered if operand_type == create_type('double')
               else self.builder.icmp_signed)
    return compare(cmpop, self._print(expr.lhs), self._print(expr.rhs))
def _scalarFallback(self, func_name, expr, *args, **kwargs):
    """Delegate printing of a non-vector expression to the scalar base printer.

    Returns the base printer's output for scalar ``expr``. For vector-typed
    ``expr`` it checks that the vector width matches the instruction set and
    returns None so the caller emits vector code itself.
    """
    expr_type = get_type_of_expression(expr)
    if type(expr_type) is VectorType:
        assert self.instruction_set['width'] == expr_type.width
        return None
    scalar_printer = super(VectorizedCustomSympyPrinter, self)
    return getattr(scalar_printer, func_name)(expr, *args, **kwargs)
def visit_node(node, substitution_dict, default_type='double'):
    """Vectorize the assignments and conditions below ``node`` in place.

    Right-hand sides and conditions are first substituted with the symbols
    already widened in this scope, then passed through ``visit_expr``. A scalar
    ``TypedSymbol`` assigned a vector value is replaced by a vector-typed symbol,
    and the replacement is recorded for later statements. Each call works on a
    copy of ``substitution_dict``, so child scopes do not leak substitutions
    back to the caller.
    """
    substitution_dict = substitution_dict.copy()

    def skip_field_access(e):
        return isinstance(e, ast.ResolvedFieldAccess)

    for child in node.args:
        if isinstance(child, ast.SympyAssignment):
            substituted_rhs = fast_subs(child.rhs, substitution_dict, skip=skip_field_access)
            child.rhs = visit_expr(substituted_rhs, default_type)
            rhs_type = get_type_of_expression(child.rhs)
            if isinstance(child.lhs, TypedSymbol):
                lhs_type = child.lhs.dtype
                if type(rhs_type) is VectorType and type(lhs_type) is not VectorType:
                    widened_type = VectorType(lhs_type, rhs_type.width)
                    widened_lhs = TypedSymbol(child.lhs.name, widened_type)
                    substitution_dict[child.lhs] = widened_lhs
                    child.lhs = widened_lhs
            elif isinstance(child.lhs, vector_memory_access):
                child.lhs = visit_expr(child.lhs, default_type)
        elif isinstance(child, ast.Conditional):
            child.condition_expr = fast_subs(child.condition_expr, substitution_dict, skip=skip_field_access)
            child.condition_expr = visit_expr(child.condition_expr, default_type)
            visit_node(child, substitution_dict, default_type)
        else:
            visit_node(child, substitution_dict, default_type)
def _print_Piecewise(self, expr):
    """Print a Piecewise as nested ternaries or ``blendv`` vector selects.

    Scalar Piecewise expressions are delegated to the base printer. For vector
    ones the last branch is the default, and earlier branches are folded in from
    last to first.

    Raises:
        ValueError: if the last condition is not the default ``True`` condition.
    """
    result = self._scalarFallback('_print_Piecewise', expr)
    if result:
        return result

    # Presumably the last condition is a cast_func wrapping sympy true (as the
    # vectorizer produces it), hence the `.args[0]` access — TODO confirm.
    if expr.args[-1].cond.args[0] is not sp.sympify(True):
        # We need the last conditional to be a True, otherwise the resulting
        # function may not return a result.
        raise ValueError("All Piecewise expressions must contain an "
                         "(expr, True) statement to be used as a default "
                         "condition. Without one, the generated "
                         "expression may not evaluate to anything under "
                         "some condition.")

    result = self._print(expr.args[-1][0])
    for true_expr, condition in reversed(expr.args[:-1]):
        # Scalar boolean conditions broadcast to a vector can use a C ternary.
        if isinstance(condition, cast_func) and get_type_of_expression(condition.args[0]) == create_type("bool"):
            if not KERNCRAFT_NO_TERNARY_MODE:
                result = "(({}) ? ({}) : ({}))".format(self._print(condition.args[0]), self._print(true_expr),
                                                       result)
            else:
                # NOTE(review): this branch is dropped entirely from the output.
                print("Warning - skipping ternary op")
        else:
            # noinspection SpellCheckingInspection
            result = self.instruction_set['blendv'].format(result, self._print(true_expr),
                                                           self._print(condition))
    return result
def _print_Mul(self, expr):
    """Lower a product to a left-to-right chain of LLVM multiply instructions.

    Uses ``fmul`` when the product is double-typed and integer ``mul`` otherwise.
    """
    operands = [self._print(arg) for arg in expr.args]
    product = operands[0]
    # int TODO unsigned/signed
    multiply = self.builder.fmul if get_type_of_expression(expr) == create_type('double') else self.builder.mul
    for operand in operands[1:]:
        product = multiply(product, operand)
    return product
def _print_Function(self, expr):
    """Print vector-aware function nodes using the active instruction set.

    Handles vector loads, casts to a ``VectorType``, ``fast_division``,
    ``fast_sqrt``, ``fast_inv_sqrt`` and the ``vec_any``/``vec_all``
    reductions. Anything else is delegated to the base-class printer.
    """
    if isinstance(expr, vector_memory_access):
        arg, data_type, aligned, _ = expr.args
        instruction = self.instruction_set['loadA'] if aligned else self.instruction_set['loadU']
        return instruction.format("& " + self._print(arg))
    elif isinstance(expr, cast_func):
        arg, data_type = expr.args
        if type(data_type) is VectorType:
            return self.instruction_set['makeVec'].format(self._print(arg))
    elif expr.func == fast_division:
        result = self._scalarFallback('_print_Function', expr)
        if not result:
            result = self.instruction_set['/'].format(self._print(expr.args[0]), self._print(expr.args[1]))
        return result
    elif expr.func == fast_sqrt:
        return "({})".format(self._print(sp.sqrt(expr.args[0])))
    elif expr.func == fast_inv_sqrt:
        result = self._scalarFallback('_print_Function', expr)
        if result:
            # Scalar expression: the base printer already produced the code.
            return result
        # Instruction sets without a (truthy) 'rsqrt' entry fall back to 1 / sqrt(x)
        # instead of raising KeyError.
        if self.instruction_set.get('rsqrt'):
            return self.instruction_set['rsqrt'].format(self._print(expr.args[0]))
        return "({})".format(self._print(1 / sp.sqrt(expr.args[0])))
    elif isinstance(expr, (vec_any, vec_all)):
        expr_type = get_type_of_expression(expr.args[0])
        if type(expr_type) is not VectorType:
            return self._print(expr.args[0])
        instr = 'any' if isinstance(expr, vec_any) else 'all'
        return self.instruction_set[instr].format(self._print(expr.args[0]))
    return super(VectorizedCustomSympyPrinter, self)._print_Function(expr)
def _print_Add(self, expr):
    """Lower a sum to a left-to-right chain of LLVM add instructions.

    Uses ``fadd`` when the sum is double-typed and integer ``add`` otherwise.
    """
    terms = [self._print(arg) for arg in expr.args]
    total = terms[0]
    # int TODO unsigned/signed
    add = self.builder.fadd if get_type_of_expression(expr) == create_type('double') else self.builder.add
    for term in terms[1:]:
        total = add(total, term)
    return total
def _print_Pow(self, expr):
    """Print a power, expanding small integer exponents into explicit products.

    Constant powers are evaluated numerically. Integer exponents in (0, 8) become
    a product of the base, and exponents in (-8, 0) become ``1 / (product)``,
    which avoids calling std::pow. Everything else uses the base printer.
    """
    if not expr.free_symbols:
        return self._typed_number(expr.evalf(17), get_type_of_expression(expr.base))
    exponent = expr.exp
    if exponent.is_integer and exponent.is_number:
        if 0 < exponent < 8:
            product = sp.Mul(*[expr.base] * exponent, evaluate=False)
            return f"({self._print(product)})"
        if -8 < exponent < 0:
            product = sp.Mul(*([expr.base] * -exponent), evaluate=False)
            return f"1 / ({self._print(product)})"
    return super(CustomSympyPrinter, self)._print_Pow(expr)
def check_type(e):
    """Return True if expression ``e`` passes the enclosing ``only_type`` filter.

    ``only_type`` None accepts everything, ``'int'`` accepts signed/unsigned
    integer base types, and ``'real'`` accepts float base types. Otherwise the
    base type is compared with ``only_type``. Expressions whose type lookup
    raises ValueError are rejected.
    """
    if only_type is None:
        return True
    try:
        base_type = get_base_type(get_type_of_expression(e))
    except ValueError:
        return False
    if only_type == 'int':
        if base_type.is_int() or base_type.is_uint():
            return True
    elif only_type == 'real':
        if base_type.is_float():
            return True
    return base_type == only_type
def _print_Conditional(self, node):
    """Print a Conditional node as a C ``if``/``else`` statement.

    Raises:
        ValueError: if the condition is vector-typed (use vec_any or vec_all instead).
    """
    cond_type = get_type_of_expression(node.condition_expr)
    if isinstance(cond_type, VectorType):
        raise ValueError(
            "Problem with Conditional inside vectorized loop - use vec_any or vec_all"
        )
    printed_condition = self.sympy_printer.doprint(node.condition_expr)
    printed_true = self._print_Block(node.true_block)
    parts = [f"if ({printed_condition})\n{printed_true} "]
    if node.false_block:
        parts.append("else " + self._print_Block(node.false_block))
    return "".join(parts)
def _print_Conditional(self, node):
    """Print a Conditional node as a C ``if``/``else`` statement.

    A statically true condition prints only the true block. A statically false
    condition prints only the false block, or nothing when there is no else
    branch.

    Raises:
        ValueError: if the condition is vector-typed (use vec_any or vec_all instead).
    """
    if type(node.condition_expr) is BooleanTrue:
        return self._print_Block(node.true_block)
    elif type(node.condition_expr) is BooleanFalse:
        # The else branch is optional; without one nothing is executed.
        if node.false_block is None:
            return ""
        return self._print_Block(node.false_block)
    cond_type = get_type_of_expression(node.condition_expr)
    if isinstance(cond_type, VectorType):
        raise ValueError("Problem with Conditional inside vectorized loop - use vec_any or vec_all")
    condition_expr = self.sympy_printer.doprint(node.condition_expr)
    true_block = self._print_Block(node.true_block)
    result = f"if ({condition_expr})\n{true_block} "
    if node.false_block:
        false_block = self._print_Block(node.false_block)
        result += f"else {false_block}"
    return result
def __new__(cls, arg1, arg2):
    """Construct a binary integer function after typing and validating both operands.

    Plain Python/sympy numbers are cast to ``int``, and numpy scalars are cast
    to their own dtype. Other operands are used as given.

    Raises:
        ValueError: if an operand is not integer-typed or its type cannot be determined.
    """
    args = []
    for a in (arg1, arg2):
        if isinstance(a, (sp.Number, int)):
            args.append(cast_func(a, create_type("int")))
        elif isinstance(a, np.generic):
            args.append(cast_func(a, a.dtype))
        else:
            args.append(a)

    for a in args:
        try:
            # Named `dtype` so the builtin `type` is not shadowed.
            dtype = get_type_of_expression(a)
            if not dtype.is_int():
                raise ValueError("Argument to integer function is not an int but " + str(dtype))
        except NotImplementedError as err:
            raise ValueError("Integer functions can only be constructed with typed expressions") from err
    return super().__new__(cls, *args)
def check_type(e):
    """Return True if expression ``e`` passes the enclosing ``only_type`` filter.

    ``only_type`` None accepts everything. Field pointer symbols are never
    'real'. Expressions whose type lookup raises ValueError and vector-typed
    expressions are rejected. Pointer types count only as 'int'. Otherwise
    'int' accepts signed/unsigned integer types, 'real' accepts float types,
    and any remaining case compares the type with ``only_type``.
    """
    if only_type is None:
        return True
    if isinstance(e, FieldPointerSymbol) and only_type == "real":
        # only_type is "real" here, so it can never equal "int".
        return False
    try:
        base_type = get_type_of_expression(e)
    except ValueError:
        return False
    if isinstance(base_type, VectorType):
        return False
    if isinstance(base_type, PointerType):
        return only_type == 'int'
    if only_type == 'int':
        if base_type.is_int() or base_type.is_uint():
            return True
    elif only_type == 'real':
        if base_type.is_float():
            return True
    return base_type == only_type
def _print_Product(self, expr):
    """Print a sympy Product as an immediately invoked C++ lambda with a loop.

    Only the first limit ``(var, start, end)`` is used, and the loop always
    counts upwards with increment 1 up to and including ``end``.
    """
    template = """[&]() {{ {dtype} product = ({dtype}) 1; for ( {iterator_dtype} {var} = {start}; {condition}; {var} += {increment} ) {{ product *= {expr}; }} return product; }}()"""
    first_limit = expr.limits[0]
    loop_var, lower, upper = first_limit[0], first_limit[1], first_limit[2]

    dtype = get_type_of_expression(expr.args[0])
    printed_var = self._print(loop_var)
    printed_start = self._print(lower)
    printed_end = self._print(upper)
    printed_body = self._print(expr.function)
    condition = printed_var + ' <= ' + printed_end  # if start < end else '>='
    return template.format(dtype=dtype,
                           iterator_dtype='int',
                           var=printed_var,
                           start=printed_start,
                           end=printed_end,
                           expr=printed_body,
                           increment=str(1),
                           condition=condition)
def _print_Piecewise(self, piece):
    """Lower a Piecewise to LLVM basic blocks joined by a phi node.

    Each conditional branch gets its own true/false blocks, and every branch
    jumps to a common ``after_block``. The phi there selects the value of the
    branch that was taken.

    Raises:
        ValueError: if the last condition is not the default ``True`` condition.
        NotImplementedError: if the Piecewise contains an Assignment.
    """
    # Compare explicitly: `not cond` raises TypeError for relationals and is
    # always False for other Boolean expressions such as And(...).
    if piece.args[-1].cond != sp.sympify(True):
        # We need the last conditional to be a True, otherwise the resulting
        # function may not return a result.
        raise ValueError("All Piecewise expressions must contain an "
                         "(expr, True) statement to be used as a default "
                         "condition. Without one, the generated "
                         "expression may not evaluate to anything under "
                         "some condition.")
    if piece.has(Assignment):
        raise NotImplementedError(
            'The llvm-backend does not support assignments '
            'in the Piecewise function. It is questionable '
            'whether to implement it. So far there is no '
            'use-case to test it.')

    phi_data = []
    after_block = self.builder.append_basic_block()
    for (expr, condition) in piece.args:
        if condition == sp.sympify(True):  # Don't use 'is' use '=='!
            phi_data.append((self._print(expr), self.builder.block))
            self.builder.branch(after_block)
            self.builder.position_at_end(after_block)
        else:
            cond = self._print(condition)
            true_block = self.builder.append_basic_block()
            false_block = self.builder.append_basic_block()
            self.builder.cbranch(cond, true_block, false_block)
            self.builder.position_at_end(true_block)
            value = self._print(expr)
            # Printing expr may have moved the builder (e.g. a nested Piecewise
            # ends in its own after-block), so record the block we branch from.
            phi_data.append((value, self.builder.block))
            self.builder.branch(after_block)
            self.builder.position_at_end(false_block)

    phi = self.builder.phi(to_llvm_type(get_type_of_expression(piece)))
    for (val, block) in phi_data:
        phi.add_incoming(val, block)
    return phi
def _print_SympyAssignment(self, node):
    """Print a SympyAssignment as a C statement, using vector stores where needed.

    Declarations print as ``auto lhs = rhs;`` or ``[const ]<type> lhs = rhs;``.
    A non-declaration whose lhs is a vector-typed cast_func (the store target)
    is printed with the instruction set's store instructions. The variant is
    chosen from the target's args (aligned, nontemporal, mask, stride) and from
    the keys present in ``self._vector_instruction_set``: unaligned/aligned/
    stream, masked, strided, cacheline-zero and cacheline-flush. Anything else
    prints as ``lhs = rhs;``.

    Side effect: if no masked-store instruction exists, one is synthesized from
    store/blendv/load and cached in ``self._vector_instruction_set``.
    """
    if node.is_declaration:
        if node.use_auto:
            data_type = 'auto '
        else:
            if node.is_const:
                prefix = 'const '
            else:
                prefix = ''
            # Drop a trailing ' const' from the printed type; constness comes from prefix.
            data_type = prefix + self._print(node.lhs.dtype).replace(' const', '') + " "

        return "%s%s = %s;" % (data_type,
                               self.sympy_printer.doprint(node.lhs),
                               self.sympy_printer.doprint(node.rhs))
    else:
        lhs_type = get_type_of_expression(node.lhs)
        printed_mask = ""
        if type(lhs_type) is VectorType and isinstance(node.lhs, cast_func):
            arg, data_type, aligned, nontemporal, mask, stride = node.lhs.args
            instr = 'storeU'
            if aligned:
                instr = 'stream' if nontemporal and 'stream' in self._vector_instruction_set else 'storeA'
            if mask != True:  # NOQA
                instr = 'maskStoreA' if aligned else 'maskStoreU'
                if instr not in self._vector_instruction_set:
                    # Emulate a masked store as load -> blendv -> store and cache it.
                    self._vector_instruction_set[instr] = self._vector_instruction_set['store' + instr[-1]].format(
                        '{0}', self._vector_instruction_set['blendv'].format(
                            self._vector_instruction_set['load' + instr[-1]].format('{0}', **self._kwargs),
                            '{1}', '{2}', **self._kwargs), **self._kwargs)
                printed_mask = self.sympy_printer.doprint(mask)
                # For SSE/AVX float/double vectors the mask is reinterpreted as an integer vector.
                if data_type.base_type.base_name == 'double':
                    if self._vector_instruction_set['double'] == '__m256d':
                        printed_mask = f"_mm256_castpd_si256({printed_mask})"
                    elif self._vector_instruction_set['double'] == '__m128d':
                        printed_mask = f"_mm_castpd_si128({printed_mask})"
                elif data_type.base_type.base_name == 'float':
                    if self._vector_instruction_set['float'] == '__m256':
                        printed_mask = f"_mm256_castps_si256({printed_mask})"
                    elif self._vector_instruction_set['float'] == '__m128':
                        printed_mask = f"_mm_castps_si128({printed_mask})"

            # A scalar right-hand side is wrapped in a vector cast before storing.
            rhs_type = get_type_of_expression(node.rhs)
            if type(rhs_type) is not VectorType:
                rhs = cast_func(node.rhs, VectorType(rhs_type))
            else:
                rhs = node.rhs

            ptr = "&" + self.sympy_printer.doprint(node.lhs.args[0])

            # Strided stores return early; none of the cacheline handling below applies.
            if stride != 1:
                instr = 'maskStoreS' if mask != True else 'storeS'  # NOQA
                return self._vector_instruction_set[instr].format(ptr, self.sympy_printer.doprint(rhs),
                                                                  stride, printed_mask, **self._kwargs) + ';'

            pre_code = ''
            if nontemporal and 'cachelineZero' in self._vector_instruction_set:
                # Zero the cacheline first when ptr is at a cacheline boundary and
                # the whole cacheline lies inside the field.
                first_cond = f"((uintptr_t) {ptr} & {CachelineSize.mask_symbol}) == 0"
                offset = sp.Add(*[sp.Symbol(LoopOverCoordinate.get_loop_counter_name(i))
                                  * node.lhs.args[0].field.spatial_strides[i] for i in
                                  range(len(node.lhs.args[0].field.spatial_strides))])
                # Always true here: the stride != 1 case returned above.
                if stride == 1:
                    offset = offset.subs({node.lhs.args[0].field.spatial_strides[0]: 1})
                size = sp.Mul(*node.lhs.args[0].field.spatial_shape)
                element_size = 8 if data_type.base_type.base_name == 'double' else 4
                size_cond = f"({offset} + {CachelineSize.symbol/element_size}) < {size}"
                pre_code = f"if ({first_cond} && {size_cond}) " + "{\n\t" + \
                    self._vector_instruction_set['cachelineZero'].format(ptr, **self._kwargs) + ';\n}\n'

            code = self._vector_instruction_set[instr].format(ptr, self.sympy_printer.doprint(rhs),
                                                              printed_mask, **self._kwargs) + ';'
            # True when ptr addresses the last vector of a cacheline.
            flushcond = f"((uintptr_t) {ptr} & {CachelineSize.mask_symbol}) == {CachelineSize.last_symbol}"
            if nontemporal and 'flushCacheline' in self._vector_instruction_set:
                code2 = self._vector_instruction_set['flushCacheline'].format(
                    ptr, self.sympy_printer.doprint(rhs), **self._kwargs) + ';'
                code = f"{code}\nif ({flushcond}) {{\n\t{code2}\n}}"
            elif nontemporal and 'storeAAndFlushCacheline' in self._vector_instruction_set:
                # Bind rhs to a temporary (name derived from the printed rhs) so
                # both store variants reuse one computed value.
                tmpvar = '_tmp_' + hashlib.sha1(self.sympy_printer.doprint(rhs).encode('ascii')).hexdigest()[:8]
                code = 'const ' + self._print(node.lhs.dtype).replace(' const', '') + ' ' + tmpvar + ' = ' \
                    + self.sympy_printer.doprint(rhs) + ';'
                code1 = self._vector_instruction_set[instr].format(ptr, tmpvar, printed_mask, **self._kwargs) + ';'
                code2 = self._vector_instruction_set['storeAAndFlushCacheline'].format(ptr, tmpvar, printed_mask,
                                                                                       **self._kwargs) + ';'
                code += f"\nif ({flushcond}) {{\n\t{code2}\n}} else {{\n\t{code1}\n}}"
            return pre_code + code
        else:
            return f"{self.sympy_printer.doprint(node.lhs)} = {self.sympy_printer.doprint(node.rhs)};"
def visit_expr(expr, default_type='double'):
    """Recursively insert vector casts so that mixed scalar/vector operands agree.

    ``default_type`` is the float type name forwarded to get_type_of_expression
    for untyped floats. It is switched to 'float' when a unary minus applies to
    float32 vector data.

    Returns ``expr`` unchanged when none of its (visited) arguments are
    vector-typed; otherwise returns a rebuilt expression with ``cast_func``
    wrappers on the mismatching arguments.
    """
    # vector_memory_access is itself a cast_func, so it must be checked first;
    # only its mask argument (args[4]) is visited.
    if isinstance(expr, vector_memory_access):
        return vector_memory_access(*expr.args[0:4], visit_expr(expr.args[4], default_type), *expr.args[5:])
    elif isinstance(expr, cast_func):
        return expr
    elif expr.func is sp.Abs and 'abs' not in ast_node.instruction_set:
        # No vector abs instruction: rewrite |x| as Piecewise((-x, x < 0), (x, True)).
        new_arg = visit_expr(expr.args[0], default_type)
        base_type = get_type_of_expression(expr.args[0]).base_type if type(expr.args[0]) is vector_memory_access \
            else get_type_of_expression(expr.args[0])
        pw = sp.Piecewise((-new_arg, new_arg < base_type.numpy_dtype.type(0)),
                          (new_arg, True))
        return visit_expr(pw, default_type)
    elif expr.func in handled_functions or isinstance(expr, sp.Rel) or isinstance(expr, BooleanFunction):
        if expr.func is sp.Mul and expr.args[0] == -1:
            # special treatment for the unary minus: make sure that the -1 has the same type as the argument
            dtype = int
            for arg in expr.atoms(vector_memory_access):
                if arg.dtype.base_type.is_float():
                    dtype = arg.dtype.base_type.numpy_dtype.type
            for arg in expr.atoms(TypedSymbol):
                if type(arg.dtype) is VectorType and arg.dtype.base_type.is_float():
                    dtype = arg.dtype.base_type.numpy_dtype.type
            if dtype is not int:
                if dtype is np.float32:
                    default_type = 'float'
                expr = sp.Mul(dtype(expr.args[0]), *expr.args[1:])
        new_args = [visit_expr(a, default_type) for a in expr.args]
        arg_types = [get_type_of_expression(a, default_float_type=default_type) for a in new_args]
        if not any(type(t) is VectorType for t in arg_types):
            # Purely scalar: keep expr (possibly with the retyped -1) rather than rebuilding.
            return expr
        else:
            target_type = collate_types(arg_types)
            # Vector loads are left uncast even if their type differs from target_type.
            casted_args = [
                cast_func(a, target_type) if t != target_type and not isinstance(a, vector_memory_access) else a
                for a, t in zip(new_args, arg_types)]
            return expr.func(*casted_args)
    elif expr.func is sp.Pow:
        # Only the base is visited; the exponent is kept as-is.
        new_arg = visit_expr(expr.args[0], default_type)
        return expr.func(new_arg, expr.args[1])
    elif expr.func == sp.Piecewise:
        new_results = [visit_expr(a[0], default_type) for a in expr.args]
        new_conditions = [visit_expr(a[1], default_type) for a in expr.args]
        types_of_results = [get_type_of_expression(a) for a in new_results]
        types_of_conditions = [get_type_of_expression(a) for a in new_conditions]

        result_target_type = get_type_of_expression(expr)
        condition_target_type = collate_types(types_of_conditions)
        # If either results or conditions are vectors, promote the other side to
        # a vector of the same width.
        if type(condition_target_type) is VectorType and type(result_target_type) is not VectorType:
            result_target_type = VectorType(result_target_type, width=condition_target_type.width)
        if type(condition_target_type) is not VectorType and type(result_target_type) is VectorType:
            condition_target_type = VectorType(condition_target_type, width=result_target_type.width)

        casted_results = [cast_func(a, result_target_type) if t != result_target_type else a
                          for a, t in zip(new_results, types_of_results)]

        # NOTE(review): `a is not True` compares against Python's True object;
        # sympy's true singleton is a different object, so the default condition
        # presumably still gets cast here — confirm before changing.
        casted_conditions = [cast_func(a, condition_target_type)
                             if t != condition_target_type and a is not True else a
                             for a, t in zip(new_conditions, types_of_conditions)]

        return sp.Piecewise(*[(r, c) for r, c in zip(casted_results, casted_conditions)])
    else:
        return expr