def generate_kernels(self):
    """Generator for ReductionKernel objects.

    Every expression variant in ``self.expr_variant.arg_vals`` is combined
    with every combination of (argument, value) pairs of the kernel
    variants in ``self.variants`` (Cartesian product). For each such
    combination, one kernel is yielded per entry of ``self.kernel_types``.

    Yields:
        ReductionKernel: Each object represents one instance of a BLAS-like
            kernel.
    """
    # TODO using dictionaries for arg_vals does not make a lot of sense;
    # the lookup functionality is never used!

    types = [type(variant) for variant in self.variants]
    variant_items = [variant.arg_vals.items() for variant in self.variants]

    # Iterating over expression variants.
    for expr_arg, expr in self.expr_variant.arg_vals.items():
        # Shallow copies: arguments are modified per kernel below (their
        # operand attribute is reassigned), so the originals must not be
        # shared. A deepcopy is not necessary.
        base_arguments = [copy.copy(argument) for argument in self.arguments]

        if self.expr_variant.arg_name is not None:
            # The expression variant is selected by an explicit kernel
            # argument, which is passed as a constant.
            base_arguments.append(
                ConstantArgument(self.expr_variant.arg_name, expr_arg))

        # Dealing with the remaining kernel variants.
        # Step 1: Generating all combinations.
        for arg_val_pairs in itertools.product(*variant_items):
            variant_expr = expr
            arguments = [copy.copy(argument) for argument in base_arguments]
            constraints_dict = dict()
            kernel_io = KernelIO()

            # Step 2: For each combination, iterating over all kernel
            # variant objects.
            for variant, variant_type, (variant_arg, variant_value) in zip(
                    self.variants, types, arg_val_pairs):
                if variant_type is PropertyKV:
                    # Adding property constraints to constraint object.
                    # Properties can not be stored on the symbol, because
                    # the same Symbol instance is used for multiple
                    # generated kernels (the alternative would be to
                    # traverse the expression, which contains a copy of
                    # the symbol).
                    constraints_dict.setdefault(
                        variant.wildcard_name, set()).add(variant_value)
                    arguments.append(
                        ConstantArgument(variant.arg_name, variant_arg))
                elif variant_type is DefaultValueKV:
                    # Replacing symbol with default value.
                    # NOTE(review): both statements are assumed to be
                    # guarded by the None check - confirm against the
                    # original indentation.
                    if variant_arg is not None:
                        variant_expr = replace_symbol(
                            variant_expr, variant.operand, variant_value)
                        kernel_io.add_input(
                            variant.operand, variant_value,
                            storage_format.StorageFormat.full)
                elif variant_type is OperatorKV:
                    # Replacing operator placeholder with actual operator,
                    # both in the expression and in the argument operands.
                    variant_expr = replace_operator(
                        variant_expr, variant.operator, variant_value)
                    for argument in arguments:
                        if argument.operand:
                            argument.operand = replace_operator(
                                argument.operand, variant.operator,
                                variant_value)
                    arguments.append(
                        ConstantArgument(variant.arg_name, variant_arg))

            variant_expr = simplify(variant_expr)

            # Collect every symbol that is still part of the expression,
            # once per name and in preorder.
            remaining_operands = []
            seen_before = set()
            for node, _ in variant_expr.preorder_iter():
                if isinstance(node, Symbol) and node.name not in seen_before:
                    # Without recording the name, a symbol that occurs
                    # several times would be collected several times and
                    # produce duplicate property constraints.
                    seen_before.add(node.name)
                    remaining_operands.append(node)

            # Add property constraints (including Matrix, Vector and
            # Scalar). This has to be done after simplifying to make sure
            # that only those operands are included that actually show up
            # in the expression.
            constraints_list = []
            for operand in remaining_operands:
                wildcard_name = self.wildcards[operand.name].variable_name
                if operand.properties or wildcard_name in constraints_dict:
                    property_set = operand.properties.union(
                        constraints_dict.get(wildcard_name, set()))
                    # Those properties can be removed because variables
                    # use symbol_type.
                    property_set.difference_update(
                        (Property.MATRIX, Property.VECTOR, Property.SCALAR))
                    if property_set:
                        constraints_list.append(
                            PropertyConstraint(wildcard_name, property_set))

            if self.constraints:
                # Only constraints whose variables are all still present
                # are kept; they are renamed from symbol names to wildcard
                # names.
                operand_names = {op.name for op in remaining_operands}
                renaming = {
                    name: self.wildcards[name].variable_name
                    for name in operand_names
                }
                for constraint in self.constraints:
                    if constraint.variables <= operand_names:
                        constraints_list.append(
                            constraint.with_renamed_vars(renaming))

            remaining_wildcards = [
                self.wildcards[operand.name] for operand in remaining_operands
            ]
            remaining_input_operands = [
                input_operand for input_operand in self.input_operands
                if input_operand.operand in remaining_wildcards
            ]

            # Replace symbols with wildcards in the expression.
            variant_expr = replace_symbols(variant_expr, self.wildcards)
            kernel_io.replace_variables(self.wildcards)

            # NOTE: all kernels yielded below share the same argument list
            # and KernelIO object; only the pattern and the kernel type
            # differ.
            for kernel_type in self.kernel_types:
                if kernel_type is KernelType.scaling:
                    # For scaling kernels the pattern is a product of the
                    # operands in reversed order.
                    pattern_expr = Times(*reversed(variant_expr.operands))
                else:
                    pattern_expr = variant_expr

                yield ReductionKernel(
                    matchpy.Pattern(pattern_expr, *constraints_list),
                    remaining_input_operands,
                    self.return_value,
                    self.cost_function,
                    self.pre_code,
                    self.signature,
                    self.post_code,
                    arguments,
                    kernel_type,
                    kernel_io)
def set_match(self, match_dict, context, blocked_products=False,
              set_equivalent=True, equiv_expr=None):
    """Instantiate this kernel for one concrete match.

    Args:
        match_dict: Substitution for the kernel's wildcards. It is passed
            to ``matchpy.substitute`` and indexed by the variable names of
            the input operands.
        context: If truthy, the replacement is built from
            ``self.replacement_with_context_template`` and the match is
            substituted into it as well; otherwise
            ``self.replacement_template`` is used.
        blocked_products: Unused here; not relevant for reductions.
        set_equivalent: Forwarded to ``temporaries.create_tmp``.
        equiv_expr: Optional expression. If truthy, it is substituted for
            ``_op`` in ``self.replacement_template``, simplified, and
            forwarded to ``temporaries.create_tmp``.

    Returns:
        The object returned by ``super().set_match(match_dict)``, with its
        ``operation``, ``kernel_io``, ``operand_dict``, ``replacement`` and
        ``other_replacements`` attributes set.
    """
    result = super().set_match(match_dict)

    # --- Operation ---------------------------------------------------
    # Using the equivalent expression only when it is supplied is purely
    # a performance consideration; ideally it would always be used (even
    # for matrix chains).
    if equiv_expr:
        equiv_expr = simplify(
            matchpy.substitute(self.replacement_template,
                               {"_op": equiv_expr}))

    operation = matchpy.substitute(self.operation_template, match_dict)
    tmp = temporaries.create_tmp(operation, set_equivalent, equiv_expr)
    result.operation = Equal(tmp, operation)

    # --- operand_dict & kernel_io --------------------------------------
    # Work on a private copy so the kernel's own KernelIO is not modified.
    io = copy.deepcopy(self.kernel_io)
    operands_by_name = {}
    for inp in self.input_operands:
        name = inp.operand.variable_name
        matched = match_dict[name]
        operands_by_name[name] = matched
        io.add_input(inp.operand, matched, inp.storage_format)

    out = self.output_operand
    io.add_output(out.operand, tmp, out.storage_format)

    # The temporary is only registered under the output operand's name
    # when the kernel is not overwriting.
    if not self.is_overwriting:
        operands_by_name[out.operand.variable_name] = tmp

    result.kernel_io = io
    result.operand_dict = operands_by_name

    # --- Replacement ---------------------------------------------------
    tmp_binding = {"_op": tmp}
    if not context:
        replacement = matchpy.substitute(self.replacement_template,
                                         tmp_binding)
    else:
        # First plug in the temporary, then the matched context.
        with_tmp = matchpy.substitute(
            self.replacement_with_context_template, tmp_binding)
        replacement = matchpy.substitute(with_tmp, match_dict)
    result.replacement = replacement

    # --- Other replacements --------------------------------------------
    # TODO This is language dependent
    other = dict()
    other["type"] = config.data_type_string
    other["type_prefix"] = config.blas_data_type_prefix
    result.other_replacements = other

    # Blocked products: not relevant for reductions.
    return result