def _set_match(self, match_dict, input_expr):
    """Build the parts of a match that do not depend on temporary reuse.

    Helper for set_match().

    Args:
        match_dict: Maps wildcard variable names to the matched operands.
        input_expr: The matched input expression; its ``indices`` are
            handed to every output operand created here.

    Returns:
        A 4-tuple ``(output_expr, arg_dict, partial_operand_dict,
        kernel_io)``: the replacement template with the new output
        operands substituted in, the values of all size arguments keyed by
        argument name, the operand-dict entries created for outputs, and
        the populated KernelIO.
    """
    # Inputs: pair every declared input operand with what it matched.
    kernel_io = KernelIO()
    for in_op in self.input_operands:
        matched = match_dict[in_op.operand.variable_name]
        kernel_io.add_input(in_op.operand, matched, in_op.storage_format)

    # Only size arguments are evaluated; other argument kinds are skipped.
    _arg_dict = {
        arg.name: arg.get_value(match_dict)
        for arg in self.arguments
        if isinstance(arg, SizeArgument)
    }

    _partial_operand_dict = {}
    # Maps wildcard names to the freshly created output operands.
    replacement_dict = {}

    for out_op in self.output_operands:
        wildcard_name = out_op.operand.variable_name
        # The operand is a Wildcard, so the leading "_" of its name is
        # dropped before the unique identifier is appended.
        name = wildcard_name[1:] + temporaries.get_identifier()
        rows_arg, columns_arg = out_op.size[0], out_op.size[1]
        # TODO what if the output is a scalar? Check sizes.
        size = (_arg_dict[rows_arg], _arg_dict[columns_arg])

        operand = Matrix(name, size, input_expr.indices)
        operand.set_property(Property.FACTOR)
        # A new set per output operand, built from the names of the first
        # element of every registered input entry.
        operand.factorization_labels = {
            entry[0].name for entry in kernel_io.input_operands
        }
        for prop in out_op.properties:
            operand.set_property(prop)

        replacement_dict[wildcard_name] = operand

        # Outputs: an overwriting output is registered under the operand
        # it overwrites, any other output under its own wildcard.
        # NOTE(review): the operand-dict entry is recorded only in the
        # non-overwriting branch, as the flattened original was read --
        # confirm against the indented source.
        if out_op.overwriting:
            kernel_io.add_output(out_op.overwriting, operand,
                                 out_op.storage_format)
        else:
            kernel_io.add_output(out_op.operand, operand,
                                 out_op.storage_format)
            _partial_operand_dict[wildcard_name] = operand

    _output_expr = matchpy.substitute(self.replacement_template,
                                      replacement_dict)

    return _output_expr, _arg_dict, _partial_operand_dict, kernel_io
def get_value(self, match_dict, memory=None):
    """Return the requested dimension of the matched operand.

    ``match_dict`` is substituted into ``self.operand`` and the ``size`` of
    the result is inspected according to ``self.dimension``.

    Args:
        match_dict: Substitution applied to ``self.operand``.
        memory: Not used by this method.

    Returns:
        ``size[0]`` for ``"rows"``, ``size[1]`` for ``"columns"`` and
        ``size[0] * size[1]`` for ``"entries"``.

    Raises:
        ValueError: If ``self.dimension`` is none of the three names above.
    """
    # TODO don't use substitute here
    size = matchpy.substitute(self.operand, match_dict).size

    dimension = self.dimension
    if dimension == "rows":
        return size[0]
    if dimension == "columns":
        return size[1]
    if dimension == "entries":
        return size[0] * size[1]
    raise ValueError("{} is not a valid dimension.".format(dimension))
register(Total(x), lambda x: Pi(Shape(x))) class GetBySubstituting(matchpy.Operation): """ GetBySubstituting(variable_name, form) """ name = "->" arity = matchpy.Arity(2, True) infix = True register( Get(x, GetBySubstituting(scalar_accessor, x1)), lambda x, scalar_accessor, x1: matchpy.substitute( x1, {scalar_accessor.value: x}), ) _counter = 0 def get_index_accessor(): """ Returns index variable and fn of array -> GetBySubstituting """ global _counter variable_name = f"idx_{_counter}" _counter += 1 idx_variable = UnboundAccessor(variable_name=variable_name) return ( idx_variable,
def set_match(self, match_dict, context):
    """Create the matched kernel for a match of this factorization.

    The tuple produced by ``_set_match()`` is cached in
    ``temporaries._table_of_factors`` under ``self.id`` and the substituted
    input expression, so a later match of the same expression reuses the
    stored output expression, operands and KernelIO.

    Args:
        match_dict: Maps wildcard variable names to the matched operands.
        context: Must be falsy; replacement with context is not
            implemented.

    Returns:
        The matched kernel returned by ``super().set_match(match_dict)``
        with ``operation``, ``kernel_io``, ``operand_dict``,
        ``replacement`` and ``other_replacements`` filled in.

    Raises:
        NotImplementedError: If ``context`` is truthy. This is raised only
            after the cache lookup and after ``operation``, ``kernel_io``
            and ``operand_dict`` have been set.
    """
    matched_kernel = super().set_match(match_dict)

    #############
    # operation

    # The input expression is the operation template with the match
    # plugged in; it also serves as the cache key.
    _input_expr = matchpy.substitute(self.operation_template, match_dict)

    factor_table = temporaries._table_of_factors
    try:
        known_exprs = factor_table[self.id]
    except KeyError:
        # Nothing stored for this factorization yet: compute first, then
        # open its table with this expression as the only entry.
        ops = self._set_match(match_dict, _input_expr)
        factor_table[self.id] = {_input_expr: ops}
    else:
        try:
            ops = known_exprs[_input_expr]
        except KeyError:
            # The factorization is known but this expression is new.
            ops = self._set_match(match_dict, _input_expr)
            known_exprs[_input_expr] = ops

    # The second element (the size-argument values) is not needed here.
    _output_expr, _, _partial_operand_dict, kernel_io = ops

    matched_kernel.operation = Equal(_output_expr, _input_expr)

    #############
    # operand_dict

    matched_kernel.kernel_io = kernel_io
    operand_dict = copy.copy(match_dict)
    matched_kernel.operand_dict = operand_dict
    operand_dict.update(_partial_operand_dict)

    #############
    # Replacement

    if context:
        # When this gets implemented, don't forget to remove context
        # variables from match_dict
        raise NotImplementedError()
    matched_kernel.replacement = _output_expr

    #############
    # Other replacements

    # NOTE(review): if self.signature is a str, hash() is randomized per
    # process unless PYTHONHASHSEED is fixed, so work_id may differ
    # between runs -- confirm that this is acceptable.
    work_id = hex(hash(self.signature))[-5:]
    matched_kernel.other_replacements = {
        "type": config.data_type_string,
        "type_prefix": config.blas_data_type_prefix,  # TODO this is language dependent
        "work_id": work_id,
    }

    return matched_kernel
def _set_match(self, match_dict, input_expr):
    """Build the reuse-independent parts of a match and register equivalences.

    Helper for set_match(). In addition to creating the output operands
    and the KernelIO, this adds replacement rules to
    ``temporaries.equivalence_replacer`` that rewrite a product containing
    the output expression (optionally inverted and/or transposed) into the
    same product over the equivalent of the input expression.

    Args:
        match_dict: Maps wildcard variable names to the matched operands.
        input_expr: The matched input expression; its ``indices`` are
            handed to every output operand created here.

    Returns:
        A 4-tuple ``(output_expr, arg_dict, partial_operand_dict,
        kernel_io)``.
    """
    # Inputs: pair every declared input operand with what it matched.
    kernel_io = KernelIO()
    for in_op in self.input_operands:
        matched = match_dict[in_op.operand.variable_name]
        kernel_io.add_input(in_op.operand, matched, in_op.storage_format)

    # Only size arguments are evaluated; other argument kinds are skipped.
    _arg_dict = {
        arg.name: arg.get_value(match_dict)
        for arg in self.arguments
        if isinstance(arg, SizeArgument)
    }

    _partial_operand_dict = {}
    # Maps wildcard names to the freshly created output operands.
    replacement_dict = {}

    for out_op in self.output_operands:
        wildcard_name = out_op.operand.variable_name
        # The operand is a Wildcard, so the leading "_" of its name is
        # dropped before the unique identifier is appended.
        name = wildcard_name[1:] + temporaries.get_identifier()
        rows_arg, columns_arg = out_op.size[0], out_op.size[1]
        # TODO what if the output is a scalar? Check sizes.
        size = (_arg_dict[rows_arg], _arg_dict[columns_arg])

        operand = Matrix(name, size, input_expr.indices)
        operand.set_property(Property.FACTOR)
        # A new set per output operand, built from the names of the first
        # element of every registered input entry.
        operand.factorization_labels = {
            entry[0].name for entry in kernel_io.input_operands
        }
        for prop in out_op.properties:
            operand.set_property(prop)

        replacement_dict[wildcard_name] = operand

        # Outputs: an overwriting output is registered under the operand
        # it overwrites, any other output under its own wildcard.
        # NOTE(review): the operand-dict entry is recorded only in the
        # non-overwriting branch, as the flattened original was read --
        # confirm against the indented source.
        if out_op.overwriting:
            kernel_io.add_output(out_op.overwriting, operand,
                                 out_op.storage_format)
        else:
            kernel_io.add_output(out_op.operand, operand,
                                 out_op.storage_format)
            _partial_operand_dict[wildcard_name] = operand

    _output_expr = matchpy.substitute(self.replacement_template,
                                      replacement_dict)

    input_equiv = temporaries.get_equivalent(input_expr)

    def add_equivalence(wrap):
        # Register one rule: wrap(output) inside a product is replaced by
        # wrap(equivalent input) inside the same product. The lambda keeps
        # the parameter names ctx1/ctx2 used by the original rules.
        pattern = matchpy.Pattern(Times(ctx1, wrap(_output_expr), ctx2))
        temporaries.equivalence_replacer.add(
            matchpy.ReplacementRule(
                pattern,
                lambda ctx1, ctx2: Times(*ctx1, wrap(input_equiv), *ctx2)))

    is_square = input_expr.has_property(Property.SQUARE)

    add_equivalence(lambda expr: expr)
    if is_square:
        add_equivalence(invert)

    # There is no need to generate transposed pattern for factorizations
    # with symmetric output; Cholesky (id 0) and Eigen (id 4).
    # NOTE(review): the inverse-transpose rule is taken to be nested under
    # this id check -- confirm against the indented source.
    if self.id in {1, 2, 3, 5, 6, 7}:
        add_equivalence(transpose)
        if is_square:
            add_equivalence(invert_transpose)

    return _output_expr, _arg_dict, _partial_operand_dict, kernel_io
def set_match(self, match_dict, context, blocked_products=False,
              set_equivalent=True, equiv_expr=None):
    """Create the matched kernel for a match of this kernel.

    A temporary is created for the substituted operation, the KernelIO and
    operand dict are filled from the match, and the replacement expression
    is built from the appropriate template.

    Args:
        match_dict: Maps wildcard variable names to the matched operands.
        context: If truthy, ``replacement_with_context_template`` is used
            and the match is substituted into it as well; otherwise
            ``replacement_template`` is used.
        blocked_products: Not used by this method.
        set_equivalent: Passed through to ``temporaries.create_tmp``.
        equiv_expr: Optional expression. If truthy, it is plugged into
            ``replacement_template`` as ``_op``, simplified, and the result
            is passed to ``temporaries.create_tmp``.

    Returns:
        The matched kernel returned by ``super().set_match(match_dict)``
        with ``operation``, ``kernel_io``, ``operand_dict``,
        ``replacement`` and ``other_replacements`` filled in.
    """
    matched_kernel = super().set_match(match_dict)

    #############
    # operation

    # I don't like this part. I would prefer not to use it at all and
    # always use equivalent expression (even for matrix chain). It's
    # exclusively a performance consideration.
    if equiv_expr:
        equiv_expr = simplify(
            matchpy.substitute(self.replacement_template,
                               {"_op": equiv_expr}))

    _operation = matchpy.substitute(self.operation_template, match_dict)
    _tmp = temporaries.create_tmp(_operation, set_equivalent, equiv_expr)

    matched_kernel.operation = Equal(_tmp, _operation)

    #############
    # operand_dict & kernel_io

    # Work on a private copy so the kernel's own KernelIO stays untouched.
    kernel_io = copy.deepcopy(self.kernel_io)
    operand_dict = {}

    for in_op in self.input_operands:
        var_name = in_op.operand.variable_name
        matched = match_dict[var_name]
        operand_dict[var_name] = matched
        kernel_io.add_input(in_op.operand, matched, in_op.storage_format)

    out_op = self.output_operand
    kernel_io.add_output(out_op.operand, _tmp, out_op.storage_format)

    # The output only gets its own operand-dict entry when the kernel is
    # not overwriting.
    if not self.is_overwriting:
        operand_dict[out_op.operand.variable_name] = _tmp

    matched_kernel.kernel_io = kernel_io
    matched_kernel.operand_dict = operand_dict

    #############
    # Replacement

    tmp_substitution = {"_op": _tmp}
    if context:
        # Plug in the temporary first, then the context from the match.
        with_tmp = matchpy.substitute(
            self.replacement_with_context_template, tmp_substitution)
        _replacement = matchpy.substitute(with_tmp, match_dict)
    else:
        _replacement = matchpy.substitute(self.replacement_template,
                                          tmp_substitution)

    matched_kernel.replacement = _replacement

    #############
    # Other replacements

    # TODO This is language dependent
    matched_kernel.other_replacements = {
        "type": config.data_type_string,
        "type_prefix": config.blas_data_type_prefix,
    }

    #############
    # Blocked products

    # Not relevant for reductions.

    return matched_kernel
def set_match(self, match_dict, context, blocked_products=False):
    """Create the matched kernel for a match of this kernel.

    A temporary is created for the substituted operation, the KernelIO and
    operand dict are filled from the match, and the replacement expression
    is built from the appropriate template.

    Args:
        match_dict: Maps wildcard variable names to the matched operands.
        context: If truthy, ``replacement_with_context_template`` is used
            and the match is substituted into it as well; otherwise
            ``replacement_template`` is used.
        blocked_products: Not used by this method.

    Returns:
        The matched kernel returned by ``super().set_match(match_dict)``
        with ``operation``, ``kernel_io``, ``operand_dict``,
        ``replacement`` and ``other_replacements`` filled in.
    """
    matched_kernel = super().set_match(match_dict)

    #############
    # operation

    _operation = matchpy.substitute(self.operation_template, match_dict)
    _tmp = temporaries.create_tmp(_operation)

    matched_kernel.operation = Equal(_tmp, _operation)

    #############
    # operand_dict & kernel_io

    # Work on a private copy so the kernel's own KernelIO stays untouched.
    kernel_io = copy.deepcopy(self.kernel_io)
    operand_dict = {}

    for in_op in self.input_operands:
        var_name = in_op.operand.variable_name
        matched = match_dict[var_name]
        operand_dict[var_name] = matched
        kernel_io.add_input(in_op.operand, matched, in_op.storage_format)

    out_op = self.output_operand
    kernel_io.add_output(out_op.operand, _tmp, out_op.storage_format)

    # The output only gets its own operand-dict entry when the kernel is
    # not overwriting.
    if not self.is_overwriting:
        operand_dict[out_op.operand.variable_name] = _tmp

    matched_kernel.kernel_io = kernel_io
    matched_kernel.operand_dict = operand_dict

    #############
    # Replacement

    tmp_substitution = {"_op": _tmp}
    if context:
        # Plug in the temporary first, then the context from the match.
        with_tmp = matchpy.substitute(
            self.replacement_with_context_template, tmp_substitution)
        _replacement = matchpy.substitute(with_tmp, match_dict)
    else:
        _replacement = matchpy.substitute(self.replacement_template,
                                          tmp_substitution)

    matched_kernel.replacement = _replacement

    #############
    # Other replacements

    # TODO This is language dependent
    matched_kernel.other_replacements = {
        "type": config.data_type_string,
        "type_prefix": config.blas_data_type_prefix,
    }

    #############
    # Blocked products

    # Not relevant for kernels.

    return matched_kernel