コード例 #1
0
ファイル: propagation.py プロジェクト: fthaler/dace
    def can_be_applied(self, dim_exprs, variable_context, node_range,
                       orig_edges, dim_index, total_dims):
        """
        Decide whether this pattern applies to one dimension's expression.

        The pattern applies when the (single) expression is constant, or is
        a 3-tuple range whose every component is constant. An expression
        counts as constant if ``dtypes.isconstant`` accepts it, or if it
        matches a wildcard that excludes the symbols in
        ``variable_context[-1]`` and the matched value reports
        ``is_constant()``.

        :param dim_exprs: Expressions for this dimension; more than one
                          (a union) is not supported.
        :param variable_context: Sequence whose last entry lists the symbols
                                 the wildcard must not match.
        :param node_range: Unused by this check.
        :param orig_edges: Unused by this check.
        :param dim_index: Unused by this check.
        :param total_dims: Unused by this check.
        :return: True if the pattern can be applied, False otherwise.
        """
        # Unions of expressions are not supported. TODO: Support
        if len(dim_exprs) > 1:
            return False
        expr = dim_exprs[0]

        # Wildcard that refuses to match the current map's parameters
        wildcard = sympy.Wild('cst', exclude=variable_context[-1])

        def is_constant_expr(candidate):
            """ Apply the constant test described above to one expression. """
            if dtypes.isconstant(candidate):
                return True
            found = candidate.match(wildcard)
            if found is None or len(found) != 1:
                return False
            return bool(found[wildcard].is_constant())

        # Range case: a 3-tuple, each component of which must be constant
        if isinstance(expr, tuple) and len(expr) == 3:
            return all(is_constant_expr(part) for part in expr)

        # Single element case
        return is_constant_expr(expr)
コード例 #2
0
def symstr(sym, arrayexprs: Optional[Set[str]] = None) -> str:
    """
    Render a symbolic expression as a C++ compilable string.

    Plain symbols, SymPy numbers and constants are returned as printed;
    every other expression is wrapped in parentheses.

    :param sym: Symbolic expression to convert.
    :param arrayexprs: Set of names of arrays, used to convert SymPy
                       user-functions back to array expressions.
    :return: C++-compilable expression.
    """
    # SymExpr objects are unwrapped and their main expression is converted
    if isinstance(sym, SymExpr):
        return symstr(sym.expr, arrayexprs)

    def lower_minmax(text):
        # Rewrite 'Min'/'Max' in the printed text to lowercase 'min'/'max'
        return text.replace('Min', 'min').replace('Max', 'max')

    try:
        # Run the expression through the fixers in order; ``sym`` is
        # rebound after each step so the fallback below sees whatever
        # stage was reached before a failure.
        for fix in (sympy_numeric_fix, sympy_intdiv_fix, sympy_divide_fix):
            sym = fix(sym)

        raw = DaceSympyPrinter(arrayexprs).doprint(sym)

        is_atomic = (isinstance(sym, (symbol, sympy.Symbol, sympy.Number))
                     or dtypes.isconstant(sym))
        text = lower_minmax(raw)
        return text if is_atomic else '(' + text + ')'
    except (AttributeError, TypeError, ValueError):
        # Fallback: print the expression as-is, always parenthesized
        raw = DaceSympyPrinter(arrayexprs).doprint(sym)
        return '(' + lower_minmax(raw) + ')'
コード例 #3
0
ファイル: labeling.py プロジェクト: HappySky2046/dace
    def match(self, expressions, variable_context, node_range, orig_edges):
        """
        Check that every element of every range in ``node_range`` is constant.

        An element is accepted if ``dtypes.isconstant`` returns True for it
        or if it is a ``sympy.Number``. On success, the last entry of
        ``variable_context`` is stored in ``self.params``.

        :param expressions: Unused by this check.
        :param variable_context: Sequence whose last entry is stored as
                                 ``self.params`` when the match succeeds.
        :param node_range: Iterable of dimensions, each an iterable of range
                           elements (begin, end, skip).
        :param orig_edges: Unused by this check.
        :return: True if all range elements are constant, False otherwise
                 (in which case ``self.params`` is left untouched).
        """
        for dim in node_range:
            for rngelem in dim:  # For (begin, end, skip)
                # Stop at the first non-constant element: the answer can no
                # longer change, so remaining dimensions need not be scanned.
                if not (dtypes.isconstant(rngelem)
                        or isinstance(rngelem, sympy.Number)):
                    return False

        self.params = variable_context[-1]

        return True
コード例 #4
0
ファイル: symbolic.py プロジェクト: mratsim/dace
def symstr(sym):
    """
    Render a symbolic expression as a C++ compilable string.

    Plain symbols, SymPy numbers and constants are returned as printed;
    every other expression is wrapped in parentheses.
    """
    # SymExpr objects are unwrapped and their main expression is converted
    if isinstance(sym, SymExpr):
        return symstr(sym.expr)

    def lower_minmax(text):
        # Rewrite 'Min'/'Max' in the printed text to lowercase 'min'/'max'
        return text.replace('Min', 'min').replace('Max', 'max')

    try:
        # Run the expression through the fixers in order; ``sym`` is
        # rebound after each step so the fallback below sees whatever
        # stage was reached before a failure.
        for fix in (sympy_numeric_fix, sympy_intdiv_fix, sympy_divide_fix):
            sym = fix(sym)

        raw = DaceSympyPrinter().doprint(sym)

        is_atomic = (isinstance(sym, (symbol, sympy.Symbol, sympy.Number))
                     or dtypes.isconstant(sym))
        text = lower_minmax(raw)
        return text if is_atomic else '(' + text + ')'
    except (AttributeError, TypeError, ValueError):
        # Fallback: print the expression as-is, always parenthesized
        raw = DaceSympyPrinter().doprint(sym)
        return '(' + lower_minmax(raw) + ')'
コード例 #5
0
    def global_value_to_node(self,
                             value,
                             parent_node,
                             qualname,
                             recurse=False,
                             detect_callables=False):
        """
        Build an AST node that stands in for a value from the globals.

        Depending on the type of ``value``, the value is also registered in
        ``self.closure`` (as a constant, SDFG, array or callback).

        :param value: The Python object to convert.
        :param parent_node: The AST node being replaced. Its ``ctx`` is
                            reused for list/tuple nodes, a deep copy is
                            attached to constant nodes as ``oldnode``, and
                            its location is copied onto the result when it
                            is not None.
        :param qualname: Qualified name used as the key when registering
                         the value in the closure.
        :param recurse: If False, lists and tuples are rejected and None is
                        returned.
        :param detect_callables: If True, callable values are handled as
                                 well (SDFG-convertible ``__call__``
                                 attributes, or arbitrary callables that are
                                 registered as callbacks and, if possible,
                                 parsed as DaCe programs).
        :return: The replacement AST node, or None if the value is not
                 supported (or a replacement exists for the callable).
        """
        # if recurse is false, we don't allow recursion into lists
        # this should not happen anyway; the globals dict should only contain
        # single "level" lists
        if not recurse and isinstance(value, (list, tuple)):
            # bail after more than one level of lists
            return None

        if isinstance(value, list):
            # Convert each element; ``recurse`` is not forwarded, so a nested
            # list/tuple element yields None and the whole list is rejected
            elts = [
                self.global_value_to_node(v,
                                          parent_node,
                                          qualname + f'[{i}]',
                                          detect_callables=detect_callables)
                for i, v in enumerate(value)
            ]
            if any(e is None for e in elts):
                return None
            newnode = ast.List(elts=elts, ctx=parent_node.ctx)
        elif isinstance(value, tuple):
            # Same element-wise conversion as for lists, producing a Tuple
            elts = [
                self.global_value_to_node(v,
                                          parent_node,
                                          qualname + f'[{i}]',
                                          detect_callables=detect_callables)
                for i, v in enumerate(value)
            ]
            if any(e is None for e in elts):
                return None
            newnode = ast.Tuple(elts=elts, ctx=parent_node.ctx)
        elif isinstance(value, symbolic.symbol):
            # Symbols resolve to the symbol name
            newnode = ast.Name(id=value.name, ctx=ast.Load())
        elif (dtypes.isconstant(value) or isinstance(value, SDFG)
              or hasattr(value, '__sdfg__')):
            # Could be a constant, an SDFG, or SDFG-convertible object
            if isinstance(value, SDFG) or hasattr(value, '__sdfg__'):
                self.closure.closure_sdfgs[qualname] = value
            else:
                self.closure.closure_constants[qualname] = value

            # Compatibility check since Python changed their AST nodes
            # (ast.Constant on 3.8+, NameConstant/Str/Num before that)
            if sys.version_info >= (3, 8):
                newnode = ast.Constant(value=value, kind='')
            else:
                if value is None:
                    newnode = ast.NameConstant(value=None)
                elif isinstance(value, str):
                    newnode = ast.Str(s=value)
                else:
                    newnode = ast.Num(n=value)

            # Keep a copy of the node being replaced on the new node
            newnode.oldnode = copy.deepcopy(parent_node)

        elif detect_callables and hasattr(value, '__call__') and hasattr(
                value.__call__, '__sdfg__'):
            # Callable object whose __call__ is SDFG-convertible: convert the
            # __call__ attribute instead of the object itself
            return self.global_value_to_node(value.__call__, parent_node,
                                             qualname, recurse,
                                             detect_callables)
        elif isinstance(value, numpy.ndarray):
            # Arrays need to be stored as a new name and fed as an argument
            # Arrays already registered (keyed by id(value)) reuse their name
            if id(value) in self.closure.array_mapping:
                arrname = self.closure.array_mapping[id(value)]
            else:
                arrname = self._qualname_to_array_name(qualname)
                desc = data.create_datadescriptor(value)
                # The lambda re-evaluates qualname in self.globals on demand.
                # NOTE(review): this uses eval(); confirm qualname can never
                # originate from untrusted input. The meaning of the trailing
                # False is not visible here -- TODO confirm.
                self.closure.closure_arrays[arrname] = (
                    qualname, desc, lambda: eval(qualname, self.globals),
                    False)
                self.closure.array_mapping[id(value)] = arrname

            newnode = ast.Name(id=arrname, ctx=ast.Load())
        elif detect_callables and callable(value):
            # Try parsing the function as a dace function/method
            # newnode stays None until the callback is registered below, so
            # a failure before that point makes this method return None
            newnode = None
            try:
                from dace.frontend.python import parser  # Avoid import loops

                parent_object = None
                if hasattr(value, '__self__'):
                    parent_object = value.__self__

                # If it is a callable object
                if (not inspect.isfunction(value)
                        and not inspect.ismethod(value)
                        and not inspect.isbuiltin(value)
                        and hasattr(value, '__call__')):
                    parent_object = value
                    value = value.__call__

                # Replacements take precedence over auto-parsing
                # Errors raised by the replacement lookup are ignored and
                # treated as "no replacement found"
                try:
                    if has_replacement(value, parent_object, parent_node):
                        return None
                except Exception:
                    pass

                # Store the handle to the original callable, in case parsing fails
                # (the meaning of the trailing False is not visible here --
                # TODO confirm)
                cbqualname = astutils.rname(parent_node)
                cbname = self._qualname_to_array_name(cbqualname, prefix='')
                self.closure.callbacks[cbname] = (cbqualname, value, False)

                # From this point on, any failure will result in a callback
                newnode = ast.Name(id=cbname, ctx=ast.Load())

                # Decorated or functions with missing source code
                # (decorated functions are returned as callbacks directly; a
                # failure in function_to_ast is caught by the except below)
                sast, _, _, _ = astutils.function_to_ast(value)
                if len(sast.body[0].decorator_list) > 0:
                    return newnode

                parsed = parser.DaceProgram(value, [], {}, False,
                                            dtypes.DeviceType.CPU)
                # If method, add the first argument (which disappears due to
                # being a bound method) and the method's object
                if parent_object is not None:
                    parsed.methodobj = parent_object
                    parsed.objname = inspect.getfullargspec(value).args[0]

                # Convert the parsed program through this same method
                res = self.global_value_to_node(parsed, parent_node, qualname,
                                                recurse, detect_callables)
                # Keep callback in callbacks in case of parsing failure
                # del self.closure.callbacks[cbname]
                return res
            except Exception:  # Parsing failed (almost any exception can occur)
                # Returns the callback name node if it was already created,
                # otherwise None
                return newnode
        else:
            # Unsupported value type
            return None

        # Give the new node the source location of the node it replaces
        if parent_node is not None:
            return ast.copy_location(newnode, parent_node)
        else:
            return newnode