Beispiel #1
0
    def _set_match(self, match_dict, input_expr):
        """Auxiliary function for set_match().

        Computes only those things that are independent of whether
        temporaries are reused or not.

        Args:
            match_dict: Mapping from wildcard variable names to the matched
                operands.
            input_expr: The matched operation; its ``indices`` are given to
                every newly created output operand.

        Returns:
            A tuple ``(output_expr, arg_dict, partial_operand_dict,
            kernel_io)``:

            * ``output_expr`` -- ``self.replacement_template`` with every
              output wildcard replaced by its new ``Matrix`` operand.
            * ``arg_dict`` -- maps the name of each ``SizeArgument`` in
              ``self.arguments`` to its value for this match.
            * ``partial_operand_dict`` -- maps the wildcard variable name of
              each non-overwriting output to its new operand.
            * ``kernel_io`` -- a ``KernelIO`` holding all inputs and outputs.
        """

        # Constructing input.
        kernel_io = KernelIO()
        for input_operand in self.input_operands:
            kernel_io.add_input(
                input_operand.operand,
                match_dict[input_operand.operand.variable_name],
                input_operand.storage_format)

        _arg_dict = {
            arg.name: arg.get_value(match_dict)
            for arg in self.arguments
            if isinstance(arg, SizeArgument)
        }

        _partial_operand_dict = dict()

        # replacement_dict maps wildcard names to operands
        replacement_dict = dict()
        for output_operand in self.output_operands:
            wildcard_name = output_operand.operand.variable_name
            # The operand is a Wildcard whose variable name starts with "_";
            # drop it and append an identifier from the temporaries module.
            name = "".join([wildcard_name[1:], temporaries.get_identifier()])
            size = (_arg_dict[output_operand.size[0]],
                    _arg_dict[output_operand.size[1]])
            # TODO what if the output is a scalar? Check sizes.

            operand = Matrix(name, size, input_expr.indices)
            operand.set_property(Property.FACTOR)
            # Label the factor with the names of the input operands; each
            # entry yielded by kernel_io.input_operands holds the operand at
            # index 0. Not hoisted out of the loop because add_output() below
            # mutates kernel_io.
            operand.factorization_labels = {
                entry[0].name for entry in kernel_io.input_operands}
            for prop in output_operand.properties:
                operand.set_property(prop)

            replacement_dict[wildcard_name] = operand

            # Constructing output. An overwriting output is registered under
            # the wildcard it overwrites and is not added to the partial
            # operand dict.
            target_wildcard = output_operand.overwriting or output_operand.operand
            kernel_io.add_output(target_wildcard, operand,
                                 output_operand.storage_format)
            if not output_operand.overwriting:
                _partial_operand_dict[wildcard_name] = operand

        _output_expr = matchpy.substitute(self.replacement_template,
                                          replacement_dict)

        return _output_expr, _arg_dict, _partial_operand_dict, kernel_io
Beispiel #2
0
 def get_value(self, match_dict, memory=None):
     """Return a size of ``self.operand`` after substituting ``match_dict``.

     Which size is returned depends on ``self.dimension``:
     "rows" -> number of rows, "columns" -> number of columns,
     "entries" -> rows * columns. ``memory`` is accepted but unused.

     Raises:
         ValueError: If ``self.dimension`` is none of the above.
     """
     # TODO don't use substitute here
     shape = matchpy.substitute(self.operand, match_dict).size
     dimension = self.dimension

     if dimension == "rows":
         return shape[0]
     if dimension == "columns":
         return shape[1]
     if dimension == "entries":
         return shape[0]*shape[1]

     raise ValueError("{} is not a valid dimension.".format(dimension))
Beispiel #3
0
# Pair the pattern Total(x) with a callable that builds Pi(Shape(x)).
# NOTE(review): presumably `register` installs this as a rewrite rule;
# its semantics are not visible here, so confirm.
register(Total(x), lambda x: Pi(Shape(x)))


class GetBySubstituting(matchpy.Operation):
    """matchpy operation written infix with the symbol ``->``.

    Constructed as ``GetBySubstituting(variable_name, form)`` and declared
    with ``matchpy.Arity(2, True)``.
    """

    # Printed between its operands, using "->" as the operator symbol.
    infix = True
    name = "->"
    arity = matchpy.Arity(2, True)


# Pair the pattern Get(x, GetBySubstituting(scalar_accessor, x1)) with a
# callable that substitutes x for the accessor inside x1, keyed by
# `scalar_accessor.value`.
# NOTE(review): presumably `register` installs this as a rewrite rule;
# confirm against its definition.
register(
    Get(x, GetBySubstituting(scalar_accessor, x1)),
    lambda x, scalar_accessor, x1: matchpy.substitute(
        x1, {scalar_accessor.value: x}),
)

# Module-level counter that get_index_accessor() increments to build unique
# "idx_<n>" variable names.
_counter = 0


def get_index_accessor():
    """
    Returns index variable and fn of array -> GetBySubstituting 
    """
    global _counter
    variable_name = f"idx_{_counter}"
    _counter += 1
    idx_variable = UnboundAccessor(variable_name=variable_name)
    return (
        idx_variable,
Beispiel #4
0
    def set_match(self, match_dict, context):
        """Instantiate this factorization kernel for a concrete match.

        The tuple returned by ``_set_match`` is cached in
        ``temporaries._table_of_factors[self.id][input_expr]``, so matching
        the same input expression again reuses the same output operands.

        Args:
            match_dict: Mapping from wildcard variable names to the matched
                operands.
            context: Must be falsy; context-dependent replacement is not
                implemented.

        Returns:
            The object from ``super().set_match(match_dict)`` with
            ``operation``, ``kernel_io``, ``operand_dict``, ``replacement``
            and ``other_replacements`` set.

        Raises:
            NotImplementedError: If ``context`` is truthy.
        """
        matched_kernel = super().set_match(match_dict)

        # ---- operation ---------------------------------------------------
        input_expr = matchpy.substitute(self.operation_template, match_dict)

        # Look up (or compute and cache) the output expression etc.
        try:
            cache_for_id = temporaries._table_of_factors[self.id]
        except KeyError:
            # First match for this factorization: create its cache entry.
            cached = self._set_match(match_dict, input_expr)
            temporaries._table_of_factors[self.id] = {input_expr: cached}
        else:
            try:
                cached = cache_for_id[input_expr]
            except KeyError:
                # Known factorization, but first time for this expression.
                cached = self._set_match(match_dict, input_expr)
                cache_for_id[input_expr] = cached

        output_expr, _, partial_operands, kernel_io = cached

        matched_kernel.operation = Equal(output_expr, input_expr)

        # ---- operand_dict ------------------------------------------------
        matched_kernel.kernel_io = kernel_io
        matched_kernel.operand_dict = copy.copy(match_dict)
        matched_kernel.operand_dict.update(partial_operands)

        # ---- replacement -------------------------------------------------
        if context:
            # When this gets implemented, don't forget to remove context
            # variables from match_dict.
            raise NotImplementedError()
        matched_kernel.replacement = output_expr

        # ---- other replacements ------------------------------------------
        # NOTE(review): hash() of a str is salted per process
        # (PYTHONHASHSEED); if self.signature is a string, "work_id" is not
        # reproducible across runs.
        matched_kernel.other_replacements = {
            "type": config.data_type_string,
            # TODO this is language dependent
            "type_prefix": config.blas_data_type_prefix,
            "work_id": hex(hash(self.signature))[-5:],
        }

        return matched_kernel
Beispiel #5
0
    def _set_match(self, match_dict, input_expr):
        """Auxiliary function for set_match().

        Computes only those things that are independent of whether
        temporaries are reused or not.

        Besides building the outputs, this registers rewrite rules with
        ``temporaries.equivalence_replacer``. Each rule rewrites a ``Times``
        product containing the new output expression back to one containing
        the equivalent of ``input_expr``. The rules cover the output itself,
        its inverse (square inputs only), and for non-symmetric
        factorizations its transpose and inverse-transpose.

        Args:
            match_dict: Mapping from wildcard variable names to the matched
                operands.
            input_expr: The matched operation; its ``indices`` are given to
                every newly created output operand.

        Returns:
            A tuple ``(output_expr, arg_dict, partial_operand_dict,
            kernel_io)``:

            * ``output_expr`` -- ``self.replacement_template`` with every
              output wildcard replaced by its new ``Matrix`` operand.
            * ``arg_dict`` -- maps the name of each ``SizeArgument`` in
              ``self.arguments`` to its value for this match.
            * ``partial_operand_dict`` -- maps the wildcard variable name of
              each non-overwriting output to its new operand.
            * ``kernel_io`` -- a ``KernelIO`` holding all inputs and outputs.
        """

        # Constructing input.
        kernel_io = KernelIO()
        for input_operand in self.input_operands:
            kernel_io.add_input(
                input_operand.operand,
                match_dict[input_operand.operand.variable_name],
                input_operand.storage_format)

        _arg_dict = {
            arg.name: arg.get_value(match_dict)
            for arg in self.arguments
            if isinstance(arg, SizeArgument)
        }

        _partial_operand_dict = dict()

        # replacement_dict maps wildcard names to operands
        replacement_dict = dict()
        for output_operand in self.output_operands:
            wildcard_name = output_operand.operand.variable_name
            # The operand is a Wildcard whose variable name starts with "_";
            # drop it and append an identifier from the temporaries module.
            name = "".join([wildcard_name[1:], temporaries.get_identifier()])
            size = (_arg_dict[output_operand.size[0]],
                    _arg_dict[output_operand.size[1]])
            # TODO what if the output is a scalar? Check sizes.

            operand = Matrix(name, size, input_expr.indices)
            operand.set_property(Property.FACTOR)
            # Label the factor with the names of the input operands; each
            # entry yielded by kernel_io.input_operands holds the operand at
            # index 0. Not hoisted out of the loop because add_output() below
            # mutates kernel_io.
            operand.factorization_labels = {
                entry[0].name for entry in kernel_io.input_operands}
            for prop in output_operand.properties:
                operand.set_property(prop)

            replacement_dict[wildcard_name] = operand

            # Constructing output. An overwriting output is registered under
            # the wildcard it overwrites and is not added to the partial
            # operand dict.
            target_wildcard = output_operand.overwriting or output_operand.operand
            kernel_io.add_output(target_wildcard, operand,
                                 output_operand.storage_format)
            if not output_operand.overwriting:
                _partial_operand_dict[wildcard_name] = operand

        _output_expr = matchpy.substitute(self.replacement_template,
                                          replacement_dict)

        input_equiv = temporaries.get_equivalent(input_expr)

        def add_rule(wrap):
            # Register: Times(ctx1, wrap(output), ctx2)
            #        -> Times(*ctx1, wrap(input_equiv), *ctx2).
            # ctx1/ctx2 in the pattern are the module-level wildcards; in the
            # lambda they are the matched sequences passed in by matchpy.
            temporaries.equivalence_replacer.add(
                matchpy.ReplacementRule(
                    matchpy.Pattern(Times(ctx1, wrap(_output_expr), ctx2)),
                    lambda ctx1, ctx2: Times(*ctx1, wrap(input_equiv), *ctx2)))

        is_square = input_expr.has_property(Property.SQUARE)

        add_rule(lambda expr: expr)
        if is_square:
            add_rule(invert)
        # There is no need to generate transposed patterns for factorizations
        # with symmetric output; Cholesky (id 0) and Eigen (id 4).
        if self.id in {1, 2, 3, 5, 6, 7}:
            add_rule(transpose)
            if is_square:
                add_rule(invert_transpose)

        return _output_expr, _arg_dict, _partial_operand_dict, kernel_io
Beispiel #6
0
    def set_match(self,
                  match_dict,
                  context,
                  blocked_products=False,
                  set_equivalent=True,
                  equiv_expr=None):
        """Instantiate this reduction kernel for a concrete match.

        Args:
            match_dict: Mapping from wildcard variable names to the matched
                operands.
            context: If truthy, the replacement is built from
                ``replacement_with_context_template`` (with the temporary
                plugged in as ``_op``) and then has ``match_dict``
                substituted into it.
            blocked_products: Unused; not relevant for reductions.
            set_equivalent: Forwarded to ``temporaries.create_tmp``.
            equiv_expr: Optional expression. When truthy, it is plugged into
                ``replacement_template`` as ``_op`` and simplified, and the
                result is forwarded to ``temporaries.create_tmp``.

        Returns:
            The object from ``super().set_match(match_dict)`` with
            ``operation``, ``kernel_io``, ``operand_dict``, ``replacement``
            and ``other_replacements`` set.
        """
        matched_kernel = super().set_match(match_dict)

        # ---- operation ---------------------------------------------------
        # Ideally an equivalent expression would always be used (even for
        # matrix chains); this branch exists purely for performance.
        # NOTE(review): this is a truthiness test, not `is not None`.
        if equiv_expr:
            equiv_expr = simplify(
                matchpy.substitute(self.replacement_template,
                                   {"_op": equiv_expr}))

        operation = matchpy.substitute(self.operation_template, match_dict)
        tmp = temporaries.create_tmp(operation, set_equivalent, equiv_expr)

        matched_kernel.operation = Equal(tmp, operation)

        # ---- operand_dict & kernel_io ------------------------------------
        kernel_io = copy.deepcopy(self.kernel_io)

        operand_dict = {}
        for input_operand in self.input_operands:
            wildcard = input_operand.operand
            matched = match_dict[wildcard.variable_name]
            operand_dict[wildcard.variable_name] = matched
            kernel_io.add_input(wildcard, matched,
                                input_operand.storage_format)

        output = self.output_operand
        kernel_io.add_output(output.operand, tmp, output.storage_format)
        if not self.is_overwriting:
            operand_dict[output.operand.variable_name] = tmp

        matched_kernel.kernel_io = kernel_io
        matched_kernel.operand_dict = operand_dict

        # ---- replacement -------------------------------------------------
        if context:
            # Plug in the temporary first, then the context from match_dict.
            with_tmp = matchpy.substitute(
                self.replacement_with_context_template, {"_op": tmp})
            replacement = matchpy.substitute(with_tmp, match_dict)
        else:
            replacement = matchpy.substitute(self.replacement_template,
                                             {"_op": tmp})

        matched_kernel.replacement = replacement

        # ---- other replacements ------------------------------------------
        # TODO This is language dependent
        matched_kernel.other_replacements = {
            "type": config.data_type_string,
            "type_prefix": config.blas_data_type_prefix,
        }

        return matched_kernel
Beispiel #7
0
    def set_match(self, match_dict, context, blocked_products=False):
        """Instantiate this kernel for a concrete match.

        Args:
            match_dict: Mapping from wildcard variable names to the matched
                operands.
            context: If truthy, the replacement is built from
                ``replacement_with_context_template`` (with the temporary
                plugged in as ``_op``) and then has ``match_dict``
                substituted into it.
            blocked_products: Unused; not relevant for kernels.

        Returns:
            The object from ``super().set_match(match_dict)`` with
            ``operation``, ``kernel_io``, ``operand_dict``, ``replacement``
            and ``other_replacements`` set.
        """
        matched_kernel = super().set_match(match_dict)

        # ---- operation ---------------------------------------------------
        operation = matchpy.substitute(self.operation_template, match_dict)
        tmp = temporaries.create_tmp(operation)

        matched_kernel.operation = Equal(tmp, operation)

        # ---- operand_dict & kernel_io ------------------------------------
        kernel_io = copy.deepcopy(self.kernel_io)

        operand_dict = {}
        for input_operand in self.input_operands:
            wildcard = input_operand.operand
            matched = match_dict[wildcard.variable_name]
            operand_dict[wildcard.variable_name] = matched
            kernel_io.add_input(wildcard, matched,
                                input_operand.storage_format)

        output = self.output_operand
        kernel_io.add_output(output.operand, tmp, output.storage_format)
        if not self.is_overwriting:
            operand_dict[output.operand.variable_name] = tmp

        matched_kernel.kernel_io = kernel_io
        matched_kernel.operand_dict = operand_dict

        # ---- replacement -------------------------------------------------
        if context:
            # Plug in the temporary first, then the context from match_dict.
            with_tmp = matchpy.substitute(
                self.replacement_with_context_template, {"_op": tmp})
            replacement = matchpy.substitute(with_tmp, match_dict)
        else:
            replacement = matchpy.substitute(self.replacement_template,
                                             {"_op": tmp})

        matched_kernel.replacement = replacement

        # ---- other replacements ------------------------------------------
        # TODO This is language dependent
        matched_kernel.other_replacements = {
            "type": config.data_type_string,
            "type_prefix": config.blas_data_type_prefix,
        }

        return matched_kernel