Example #1
0
    def generate_kernels(self):
        """Generator for ReductionKernel objects.

        Enumerates every value of the expression variant
        (``self.expr_variant.arg_vals``) combined with every combination of
        the remaining kernel variants (``self.variants``). For each such
        combination one ``ReductionKernel`` is yielded per entry of
        ``self.kernel_types``.

        Per combination, the kernel variants are applied as follows:
          * ``PropertyKV``: adds a property constraint for the variant's
            wildcard and records the variant argument as a
            ``ConstantArgument``.
          * ``DefaultValueKV``: if the variant argument is not None, replaces
            the variant's operand by the default value in the expression and
            registers it as a kernel input with full storage format.
          * ``OperatorKV``: replaces the operator placeholder both in the
            expression and in the operands of the kernel arguments, and
            records the variant argument as a ``ConstantArgument``.

        Yields:
            ReductionKernel: Each object represents one instance of a BLAS-like
                kernel.
        """
        # TODO using dictionaries for arg_vals does not make a lot of sense
        # the lookup functionality is never used!

        variant_types = [type(variant) for variant in self.variants]

        # iterating over expression variants
        for expr_arg, expr in self.expr_variant.arg_vals.items():

            # Each argument object is copied (copying only the list would not
            # be enough because the OperatorKV branch reassigns
            # `argument.operand`); a deep copy is not necessary.
            arguments_copy1 = [copy.copy(argument) for argument in self.arguments]

            if self.expr_variant.arg_name is not None:
                # The expression variant is selected via an argument of the
                # generated kernel.
                arguments_copy1.append(
                    ConstantArgument(self.expr_variant.arg_name, expr_arg))

            # Dealing with the remaining kernel variants
            # Step 1: Generating all combinations.
            for arg_val_pairs in itertools.product(
                    *[variant.arg_vals.items() for variant in self.variants]):
                expr_copy2 = expr
                # Same reasoning as for arguments_copy1: per-object shallow copy.
                arguments_copy2 = [
                    copy.copy(argument) for argument in arguments_copy1
                ]

                constraints_dict = {}
                kernel_io = KernelIO()

                # Step 2: For each combination, iterating over all kernel
                # variant objects (arg_val_pairs[i] belongs to self.variants[i]).
                for variant, variant_type, (arg, value) in zip(
                        self.variants, variant_types, arg_val_pairs):
                    if variant_type == PropertyKV:
                        # Adding property constraints to constraint object.
                        # Properties can not be stored on the symbol, because
                        # the same Symbol instance is used for multiple
                        # generated kernels (the alternative would be to
                        # traverse expr_copy2, which contains a copy of the
                        # symbol).
                        constraints_dict.setdefault(variant.wildcard_name,
                                                    set()).add(value)
                        arguments_copy2.append(
                            ConstantArgument(variant.arg_name, arg))
                    elif variant_type == DefaultValueKV:
                        # replacing symbol with default value
                        if arg is not None:
                            expr_copy2 = replace_symbol(
                                expr_copy2, variant.operand, value)
                            kernel_io.add_input(
                                variant.operand, value,
                                storage_format.StorageFormat.full)
                    elif variant_type == OperatorKV:
                        # replacing operator placeholder with actual operator,
                        # both in the expression and in the kernel arguments
                        expr_copy2 = replace_operator(
                            expr_copy2, variant.operator, value)
                        for argument in arguments_copy2:
                            if argument.operand:
                                argument.operand = replace_operator(
                                    argument.operand, variant.operator, value)
                        arguments_copy2.append(
                            ConstantArgument(variant.arg_name, arg))

                expr_copy2 = simplify(expr_copy2)

                # Distinct symbols (by name, in preorder) still present after
                # simplification. Previously `seen_before` was never updated,
                # so repeated symbols were collected (and constrained) twice.
                remaining_operands = []
                seen_before = set()
                for node, _ in expr_copy2.preorder_iter():
                    if isinstance(node, Symbol) and node.name not in seen_before:
                        seen_before.add(node.name)
                        remaining_operands.append(node)

                # Add property constraints (including Matrix, Vector and Scalar)
                # This has to be done after simplifying to make sure that only
                # those operands are included that actually show up in the
                # expression.
                constraints_list = []
                for operand in remaining_operands:
                    wildcard_name = self.wildcards[operand.name].variable_name
                    if operand.properties or wildcard_name in constraints_dict:
                        # Build a fresh mutable set so neither the symbol's
                        # properties nor constraints_dict are modified.
                        property_set = set(operand.properties)
                        property_set.update(
                            constraints_dict.get(wildcard_name, ()))
                        # Those properties can be removed because variables
                        # use symbol_type
                        property_set.difference_update(
                            (Property.MATRIX, Property.VECTOR,
                             Property.SCALAR))
                        if property_set:
                            constraints_list.append(
                                PropertyConstraint(wildcard_name,
                                                   property_set))

                if self.constraints:
                    operand_names = {op.name for op in remaining_operands}
                    renaming = {
                        name: self.wildcards[name].variable_name
                        for name in operand_names
                    }
                    # Only keep constraints whose variables all survived
                    # simplification.
                    for constraint in self.constraints:
                        if constraint.variables <= operand_names:
                            constraints_list.append(
                                constraint.with_renamed_vars(renaming))

                remaining_wildcards = [
                    self.wildcards[operand.name]
                    for operand in remaining_operands
                ]
                remaining_input_operands = [
                    input_operand for input_operand in self.input_operands
                    if input_operand.operand in remaining_wildcards
                ]

                # Replace symbols with wildcards in the expression
                expr_copy2 = replace_symbols(expr_copy2, self.wildcards)
                kernel_io.replace_variables(self.wildcards)

                for kernel_type in self.kernel_types:
                    if kernel_type is KernelType.scaling:
                        # Scaling kernels use the expression's operands in
                        # reverse order (assumes expr_copy2 is a product).
                        expr_copy3 = Times(*reversed(expr_copy2.operands))
                    else:
                        expr_copy3 = expr_copy2

                    yield ReductionKernel(
                        matchpy.Pattern(expr_copy3, *constraints_list),
                        remaining_input_operands, self.return_value,
                        self.cost_function, self.pre_code, self.signature,
                        self.post_code, arguments_copy2, kernel_type,
                        kernel_io)
Example #2
0
    def set_match(self,
                  match_dict,
                  context,
                  blocked_products=False,
                  set_equivalent=True,
                  equiv_expr=None):
        """Instantiate this reduction kernel for a concrete match.

        Args:
            match_dict: Substitution obtained by matching this kernel; it is
                indexed by the wildcard variable names of the input operands.
            context: If truthy, the replacement is built from
                ``replacement_with_context_template`` and the match is also
                substituted into it; otherwise ``replacement_template`` is used.
            blocked_products: Unused; blocked products are not relevant for
                reductions.
            set_equivalent: Forwarded to ``temporaries.create_tmp``.
            equiv_expr: Optional expression. When truthy, it is plugged into
                ``replacement_template`` as ``_op``, simplified, and forwarded
                to ``temporaries.create_tmp``.

        Returns:
            The kernel produced by the parent class's ``set_match`` with its
            ``operation``, ``kernel_io``, ``operand_dict``, ``replacement`` and
            ``other_replacements`` attributes filled in.
        """
        kernel = super().set_match(match_dict)

        # ---- operation ----
        # Using equiv_expr only when it is given is purely a performance
        # consideration; ideally an equivalent expression would always be
        # used (even for matrix chains).
        if equiv_expr:
            equiv_expr = simplify(
                matchpy.substitute(self.replacement_template,
                                   {"_op": equiv_expr}))

        operation = matchpy.substitute(self.operation_template, match_dict)
        tmp = temporaries.create_tmp(operation, set_equivalent, equiv_expr)
        kernel.operation = Equal(tmp, operation)

        # ---- operand_dict & kernel_io ----
        # Work on a deep copy so the kernel's own KernelIO is left untouched.
        io = copy.deepcopy(self.kernel_io)
        operands = {}
        for inp in self.input_operands:
            var_name = inp.operand.variable_name
            matched = match_dict[var_name]
            operands[var_name] = matched
            io.add_input(inp.operand, matched, inp.storage_format)

        out = self.output_operand
        io.add_output(out.operand, tmp, out.storage_format)
        if not self.is_overwriting:
            operands[out.operand.variable_name] = tmp

        kernel.kernel_io = io
        kernel.operand_dict = operands

        # ---- replacement ----
        # The temporary is always plugged in as "_op"; with context, the match
        # itself is substituted afterwards to fill in the context.
        template = (self.replacement_with_context_template if context
                    else self.replacement_template)
        replacement = matchpy.substitute(template, {"_op": tmp})
        if context:
            replacement = matchpy.substitute(replacement, match_dict)
        kernel.replacement = replacement

        # ---- other replacements ----
        # TODO This is language dependent
        kernel.other_replacements = {
            "type": config.data_type_string,
            "type_prefix": config.blas_data_type_prefix,
        }

        # Blocked products: not relevant for reductions.
        return kernel