def can_be_applied(self, dim_exprs, variable_context, node_range, orig_edges, dim_index, total_dims):
    """
    Checks whether a single dimension expression is constant with respect to
    the current map's parameters.

    :param dim_exprs: List of expressions for this dimension. Exactly one
                      expression is supported; unions (more than one) and an
                      empty list are rejected.
    :param variable_context: Stack of parameter collections; the last entry
                             (the current map's parameters) is excluded from
                             the constant wildcard.
    :param node_range: Unused by this check.
    :param orig_edges: Unused by this check.
    :param dim_index: Unused by this check.
    :param total_dims: Unused by this check.
    :return: True if the expression (or every element of a 3-element range
             tuple) is a constant, or matches a wildcard that excludes the
             current map's parameters with a constant match; False otherwise.
    """
    # Pattern does not support unions of expressions (TODO: support). An empty
    # expression list has nothing to match and previously raised IndexError.
    if len(dim_exprs) != 1:
        return False
    dexpr = dim_exprs[0]
    # Create a wildcard that excludes current map's parameters
    cst = sympy.Wild('cst', exclude=variable_context[-1])

    def _is_invariant(expr):
        # Constants trivially qualify; otherwise the whole expression must be
        # captured by the wildcard, and the captured value must be constant.
        if dtypes.isconstant(expr):
            return True
        matches = expr.match(cst)
        if matches is None or len(matches) != 1:
            return False
        return bool(matches[cst].is_constant())

    # Range case: every element of the range tuple must be invariant
    if isinstance(dexpr, tuple) and len(dexpr) == 3:
        return all(_is_invariant(rngelem) for rngelem in dexpr)

    # Single element case
    return _is_invariant(dexpr)
def symstr(sym, arrayexprs: Optional[Set[str]] = None) -> str:
    """
    Convert a symbolic expression to a C++ compilable expression.

    Symbols, numbers and constants are emitted as-is; every other expression
    is wrapped in parentheses so it can be embedded in a larger expression.

    :param sym: Symbolic expression to convert.
    :param arrayexprs: Set of names of arrays, used to convert SymPy
                       user-functions back to array expressions.
    :return: C++-compilable expression.
    """
    # A SymExpr is converted through its underlying expression.
    if isinstance(sym, SymExpr):
        return symstr(sym.expr, arrayexprs)

    def lower_minmax(text):
        # SymPy prints Min/Max; the generated code uses lowercase min/max.
        return text.replace('Min', 'min').replace('Max', 'max')

    try:
        # Apply the numeric, integer-division and division fixups in order.
        sym = sympy_numeric_fix(sym)
        sym = sympy_intdiv_fix(sym)
        sym = sympy_divide_fix(sym)

        printed = DaceSympyPrinter(arrayexprs).doprint(sym)

        is_atom = (isinstance(sym, (symbol, sympy.Symbol, sympy.Number))
                   or dtypes.isconstant(sym))
        return lower_minmax(printed) if is_atom else '(' + lower_minmax(printed) + ')'
    except (AttributeError, TypeError, ValueError):
        # Conversion failed part-way: print ``sym`` in whatever form it
        # reached and always parenthesize it.
        printed = DaceSympyPrinter(arrayexprs).doprint(sym)
        return '(' + lower_minmax(printed) + ')'
def match(self, expressions, variable_context, node_range, orig_edges):
    """
    Matches only if every element of every dimension of ``node_range`` is
    constant.

    :param expressions: Unused by this pattern.
    :param variable_context: Stack of parameter collections; on a successful
                             match the last entry is stored in ``self.params``.
    :param node_range: Iterable of dimensions, each an iterable of range
                       elements (begin, end, skip).
    :param orig_edges: Unused by this pattern.
    :return: True (and sets ``self.params``) if all range elements are
             constants or SymPy numbers, False otherwise.
    """
    # all() stops at the first non-constant element across *all* dimensions;
    # a nested-loop ``break`` would only leave the inner loop.
    constant_range = all(
        dtypes.isconstant(rngelem) or isinstance(rngelem, sympy.Number)
        for dim in node_range
        for rngelem in dim)
    if not constant_range:
        return False

    self.params = variable_context[-1]
    return True
def symstr(sym, arrayexprs=None):
    """
    Convert a symbolic expression to a C++ compilable expression.

    Symbols, numbers and constants are emitted as-is; any other expression is
    wrapped in parentheses.

    :param sym: Symbolic expression to convert.
    :param arrayexprs: Optional set of names of arrays, used to convert SymPy
                       user-functions back to array expressions. It is
                       accepted so this definition is call-compatible with
                       the ``symstr(sym, arrayexprs)`` variant. When None (the
                       default), the printer is built without arguments,
                       exactly as before.
    :return: C++-compilable expression.
    """
    def repstr(s):
        # SymPy prints Min/Max; the generated code uses min/max
        return s.replace('Min', 'min').replace('Max', 'max')

    def make_printer():
        # Keep the original argument-less construction when no array names
        # are given, so printers that take no arguments keep working.
        if arrayexprs is None:
            return DaceSympyPrinter()
        return DaceSympyPrinter(arrayexprs)

    if isinstance(sym, SymExpr):
        return symstr(sym.expr, arrayexprs)

    try:
        sym = sympy_numeric_fix(sym)
        sym = sympy_intdiv_fix(sym)
        sym = sympy_divide_fix(sym)

        sstr = make_printer().doprint(sym)

        if isinstance(sym, symbol) or isinstance(sym, sympy.Symbol) or isinstance(
                sym, sympy.Number) or dtypes.isconstant(sym):
            return repstr(sstr)
        else:
            return '(' + repstr(sstr) + ')'
    except (AttributeError, TypeError, ValueError):
        # Print whatever form ``sym`` reached before the failure
        sstr = make_printer().doprint(sym)
        return '(' + repstr(sstr) + ')'
def global_value_to_node(self, value, parent_node, qualname, recurse=False, detect_callables=False):
    """
    Converts a global (closure) value into an AST node that can replace the
    reference to it. The value is registered in ``self.closure`` as needed.

    Cases, checked in order:
      * ``list`` / ``tuple``: only when ``recurse`` is True. Each element is
        converted with ``recurse=False`` and the results become an
        ``ast.List`` / ``ast.Tuple``. Any unconvertible element yields None.
      * ``symbolic.symbol``: a ``Name`` node with the symbol's name.
      * Constants, ``SDFG`` objects and objects with ``__sdfg__``: recorded in
        ``closure_constants`` or ``closure_sdfgs`` and turned into a constant
        node.
      * With ``detect_callables``, an object whose ``__call__`` has
        ``__sdfg__``: converted via its ``__call__``.
      * ``numpy.ndarray``: registered once per object id in
        ``closure_arrays`` and referenced by a generated array name.
      * With ``detect_callables``, any other callable: registered as a
        callback, then parsed as a DaCe program when possible.
      * Anything else: None.

    :param value: The global value to convert.
    :param parent_node: The AST node referencing the value. Its ``ctx`` is
                        used for lists/tuples. On non-callable paths its
                        location is copied onto the result. For callables it
                        is passed to ``astutils.rname`` to derive the callback
                        name.
    :param qualname: Qualified name of the value (element indices are
                     appended as ``[i]``). It is used as the closure key, and
                     for arrays it is lazily evaluated against
                     ``self.globals``.
    :param recurse: If False, lists and tuples are rejected (returns None).
    :param detect_callables: If True, callables are converted as well.
    :return: The new AST node, or None if the value cannot be converted.
             None is also returned when a replacement exists for a callable.
    """
    # if recurse is false, we don't allow recursion into lists
    # this should not happen anyway; the globals dict should only contain
    # single "level" lists
    if not recurse and isinstance(value, (list, tuple)):
        # bail after more than one level of lists
        return None

    if isinstance(value, list):
        # Elements are converted with the default recurse=False, so a nested
        # list/tuple element yields None and the whole conversion is dropped.
        elts = [
            self.global_value_to_node(v, parent_node, qualname + f'[{i}]', detect_callables=detect_callables)
            for i, v in enumerate(value)
        ]
        if any(e is None for e in elts):
            return None
        # NOTE(review): reads parent_node.ctx, so parent_node must not be None here.
        newnode = ast.List(elts=elts, ctx=parent_node.ctx)
    elif isinstance(value, tuple):
        elts = [
            self.global_value_to_node(v, parent_node, qualname + f'[{i}]', detect_callables=detect_callables)
            for i, v in enumerate(value)
        ]
        if any(e is None for e in elts):
            return None
        newnode = ast.Tuple(elts=elts, ctx=parent_node.ctx)
    elif isinstance(value, symbolic.symbol):
        # Symbols resolve to the symbol name
        newnode = ast.Name(id=value.name, ctx=ast.Load())
    elif (dtypes.isconstant(value) or isinstance(value, SDFG) or hasattr(value, '__sdfg__')):
        # Could be a constant, an SDFG, or SDFG-convertible object
        if isinstance(value, SDFG) or hasattr(value, '__sdfg__'):
            self.closure.closure_sdfgs[qualname] = value
        else:
            self.closure.closure_constants[qualname] = value

        # Compatibility check since Python changed their AST nodes
        # (ast.NameConstant / ast.Str / ast.Num are only used before 3.8)
        if sys.version_info >= (3, 8):
            newnode = ast.Constant(value=value, kind='')
        else:
            if value is None:
                newnode = ast.NameConstant(value=None)
            elif isinstance(value, str):
                newnode = ast.Str(s=value)
            else:
                newnode = ast.Num(n=value)

        # Keep a deep copy of the replaced node on the new constant node
        newnode.oldnode = copy.deepcopy(parent_node)

    elif detect_callables and hasattr(value, '__call__') and hasattr(value.__call__, '__sdfg__'):
        # Callable object whose __call__ is SDFG-convertible: convert that instead
        return self.global_value_to_node(value.__call__, parent_node, qualname, recurse, detect_callables)
    elif isinstance(value, numpy.ndarray):
        # Arrays need to be stored as a new name and fed as an argument
        # The same array object (by id) reuses the name it was first given.
        if id(value) in self.closure.array_mapping:
            arrname = self.closure.array_mapping[id(value)]
        else:
            arrname = self._qualname_to_array_name(qualname)
            desc = data.create_datadescriptor(value)
            # The lambda re-evaluates qualname in self.globals when called
            self.closure.closure_arrays[arrname] = (qualname, desc, lambda: eval(qualname, self.globals), False)
            self.closure.array_mapping[id(value)] = arrname

        newnode = ast.Name(id=arrname, ctx=ast.Load())
    elif detect_callables and callable(value):
        # Try parsing the function as a dace function/method
        newnode = None
        try:
            from dace.frontend.python import parser  # Avoid import loops

            # Bound methods carry their object in __self__
            parent_object = None
            if hasattr(value, '__self__'):
                parent_object = value.__self__

            # If it is a callable object
            if (not inspect.isfunction(value) and not inspect.ismethod(value) and not inspect.isbuiltin(value)
                    and hasattr(value, '__call__')):
                parent_object = value
                value = value.__call__

            # Replacements take precedence over auto-parsing
            # (errors while checking for a replacement are deliberately ignored)
            try:
                if has_replacement(value, parent_object, parent_node):
                    return None
            except Exception:
                pass

            # Store the handle to the original callable, in case parsing fails
            cbqualname = astutils.rname(parent_node)
            cbname = self._qualname_to_array_name(cbqualname, prefix='')
            self.closure.callbacks[cbname] = (cbqualname, value, False)

            # From this point on, any failure will result in a callback
            newnode = ast.Name(id=cbname, ctx=ast.Load())

            # Decorated or functions with missing source code
            sast, _, _, _ = astutils.function_to_ast(value)
            if len(sast.body[0].decorator_list) > 0:
                return newnode

            parsed = parser.DaceProgram(value, [], {}, False, dtypes.DeviceType.CPU)
            # If method, add the first argument (which disappears due to
            # being a bound method) and the method's object
            if parent_object is not None:
                parsed.methodobj = parent_object
                parsed.objname = inspect.getfullargspec(value).args[0]

            # Convert the parsed program (presumably via the SDFG-convertible
            # branch above -- confirm DaceProgram exposes __sdfg__)
            res = self.global_value_to_node(parsed, parent_node, qualname, recurse, detect_callables)
            # Keep callback in callbacks in case of parsing failure
            # del self.closure.callbacks[cbname]
            return res
        except Exception:  # Parsing failed (almost any exception can occur)
            # newnode is None if the failure happened before the callback was
            # registered, otherwise it is the callback Name node
            return newnode
    else:
        return None

    if parent_node is not None:
        return ast.copy_location(newnode, parent_node)
    else:
        return newnode