Example #1
0
def parse_memlet(visitor, src: MemletType, dst: MemletType,
                 defined_arrays_and_symbols: Dict[str, data.Data]):
    """
    Parses the two sides of a memlet expression into a local variable name
    and a ``Memlet``.

    A side is treated as a local variable when it is an ``ast.Name`` whose
    name is not in ``defined_arrays_and_symbols``. Any other side is parsed
    with ``ParseMemlet``.

    :param visitor: The AST visitor, used for error reporting.
    :param src: AST node of the memlet source.
    :param dst: AST node of the memlet destination.
    :param defined_arrays_and_symbols: Mapping from names of defined data
                                       containers and symbols to their
                                       descriptors.
    :return: A tuple ``(localvar, memlet)`` of the local variable name and
             the memlet built from the other side.
    :raises DaceSyntaxError: If both sides are local variables.
    :raises NotImplementedError: If neither side is a local variable.
    """
    srcexpr, dstexpr, localvar = None, None, None
    if (isinstance(src, ast.Name)
            and rname(src) not in defined_arrays_and_symbols):
        localvar = rname(src)
    else:
        srcexpr = ParseMemlet(visitor, defined_arrays_and_symbols, src)
    if (isinstance(dst, ast.Name)
            and rname(dst) not in defined_arrays_and_symbols):
        if localvar is not None:
            raise DaceSyntaxError(
                visitor, src,
                'Memlet source and destination cannot both be local variables')
        localvar = rname(dst)
    else:
        dstexpr = ParseMemlet(visitor, defined_arrays_and_symbols, dst)

    if srcexpr is not None and dstexpr is not None:
        # This would require creating two memlets, which is not supported yet
        raise NotImplementedError(
            'Memlets where both source and destination are data accesses '
            'are not supported')
    expr = srcexpr if srcexpr is not None else dstexpr

    return localvar, Memlet(expr.name,
                            expr.accesses,
                            expr.subset,
                            1,
                            wcr=expr.wcr)
Example #2
0
    def _Call(self, t):
        """
        Unparses a function call node.

        If the inferred result type is not a ``dtypes.vector``, the call is
        emitted as a scalar call. Its module prefix is mapped through
        ``dtypes._ALLOWED_MODULES``. Otherwise the call is emitted as a
        predicated vector call ``<name>_x(<pred>, <args>...)``.

        :param t: The ``ast.Call`` node to unparse.
        :raises util.NotSupportedError: If no result type can be inferred.
        :raises NotImplementedError: If the called module or function has no
                                     known translation.
        """
        res_type = self.infer(t)[0]
        if not res_type:
            raise util.NotSupportedError('Unsupported call')

        if not isinstance(res_type, dtypes.vector):
            # Call does not involve any vectors (to our knowledge)
            # Replace default modules (e.g., math) with dace::math::
            attr_name = astutils.rname(t)
            # rpartition yields an empty module name for undotted calls.
            # Slicing with rfind() returned -1 there and dropped the name's
            # last character.
            module_name, _, func_name = attr_name.rpartition('.')
            if not module_name:
                raise NotImplementedError(
                    f'Function {attr_name} is not implemented')
            if module_name not in dtypes._ALLOWED_MODULES:
                raise NotImplementedError(
                    f'Module {module_name} is not implemented')
            cpp_mod_name = dtypes._ALLOWED_MODULES[module_name]
            name = cpp_mod_name + func_name

            self.write(name)
            self.write('(')
            for i, arg in enumerate(t.args):
                if i > 0:
                    self.write(", ")
                self.dispatch(arg)
            self.write(')')
            return

        if isinstance(t.func, ast.Name):
            # Could be an internal operation (provided by the preprocessor)
            if not util.is_sve_internal(t.func.id):
                raise NotImplementedError(
                    f'Function {t.func.id} is not implemented')
            name = util.internal_to_external(t.func.id)[0]
        elif isinstance(t.func, ast.Attribute):
            # Some module function (xxx.xxx), make sure it is available
            name = util.MATH_FUNCTION_TO_SVE.get(astutils.rname(t.func))
            if name is None:
                raise NotImplementedError(
                    f'Function {astutils.rname(t.func)} is not implemented')
        else:
            # Without this check, "None_x(...)" would be emitted
            raise NotImplementedError(
                f'Calls through {type(t.func).__name__} nodes are not '
                'implemented')

        # Vectorized function
        self.write('{}_x({}, '.format(name, self.pred_name))
        for i, arg in enumerate(t.args):
            if i > 0:
                self.write(", ")
            self.dispatch_expect(arg, res_type)
        self.write(')')
Example #3
0
    def visit_TopLevelExpr(self, node):
        """
        Translates memlet syntax found in a top-level tasklet expression.

        ``a << A[i]`` (input) becomes the assignment ``a = A[i]``, which is
        added to ``self.pre_statements``. For a dynamic memlet it is instead
        recorded in ``self.name_replacements``.

        ``a >> A[i]`` (output) becomes ``A[i] = a``, which is added to
        ``self.post_statements``. With a write-conflict resolution the value
        is wrapped in a call to the WCR lambda. For a dynamic memlet it is
        instead recorded in ``self.wcr_replacements`` or
        ``self.assign_replacements``.

        :return: None for memlet expressions, which removes them from the
                 final tasklet code. Otherwise the result of
                 ``generic_visit``.
        """
        expr = node.value
        if not isinstance(expr, ast.BinOp):
            return self.generic_visit(node)

        if isinstance(expr.op, ast.LShift):
            # Obtain memlet metadata and clean AST node from memlet syntax
            memlet_src, is_dynamic, _ = self._clean_memlet(expr.right)
            if is_dynamic:
                # In-place replacement
                self.name_replacements[rname(expr.left)] = memlet_src
            else:
                # Replace "a << A[i]" with "a = A[i]" at the beginning
                local_store = copy.deepcopy(expr.left)
                local_store.ctx = ast.Store()
                read = ast.Assign(targets=[local_store], value=memlet_src)
                self.pre_statements.append(_copy_location(read, node))
            return None  # Remove from final tasklet code

        if isinstance(expr.op, ast.RShift):
            # Obtain memlet metadata and clean AST node from memlet syntax
            memlet_dst, is_dynamic, wcr = self._clean_memlet(expr.right)
            if is_dynamic:
                local_name = rname(expr.left)
                if wcr is None:
                    # In-place replacement
                    self.assign_replacements[local_name] = memlet_dst
                else:
                    # Replace Assignments with lambda every time
                    self.wcr_replacements[local_name] = (memlet_dst, wcr)
            else:
                # Replace "a >> A[i]" with "A[i] = a" at the end
                written = expr.left
                if wcr is not None:
                    # "A[i] = (lambda a,b: a+b)(A[i], a)"
                    wcr_call = ast.Call(func=wcr,
                                        args=[memlet_dst, written],
                                        keywords=[])
                    written = _copy_location(wcr_call, written)
                dst_store = copy.deepcopy(memlet_dst)
                dst_store.ctx = ast.Store()
                write = ast.Assign(targets=[dst_store], value=written)
                self.post_statements.append(_copy_location(write, node))
            return None  # Remove from final tasklet code

        return self.generic_visit(node)
Example #4
0
def has_replacement(callobj: Callable,
                    parent_object: Optional[Any] = None,
                    node: Optional[ast.AST] = None) -> bool:
    """
    Returns True if the function/operator replacement repository
    has a registered replacement for the called function described by
    a live object.

    Objects whose module is ``dace`` or ``math`` (or a submodule of either)
    are always reported as replaceable, because nothing from those
    namespaces needs preprocessing.

    :param callobj: The live callable object being called.
    :param parent_object: The object the callable was accessed on, if any
                          (used for attribute/method replacements).
    :param node: The AST node of the called function, if available. It is
                 used to look up replacements by the name written in the
                 source.
    :return: True if a replacement is registered (or none is needed),
             False otherwise.
    """
    from dace.frontend.common import op_repository as oprepo

    # Nothing from the `dace` namespace needs preprocessing
    try:
        mod = callobj.__module__
    except AttributeError:
        mod = getattr(parent_object, '__module__', None)
    if isinstance(mod, str) and (mod in ('dace', 'math')
                                 or mod.startswith(('dace.', 'math.'))):
        return True

    # Attributes and methods
    classname = None
    if parent_object is not None:
        classname = type(parent_object).__name__
        # Not every callable exposes a name (e.g., functools.partial objects)
        attrname = getattr(callobj, '__name__', None)
        if attrname is not None:
            if oprepo.Replacements.get_attribute(classname,
                                                 attrname) is not None:
                return True
            if oprepo.Replacements.get_method(classname,
                                              attrname) is not None:
                return True

    # NumPy ufuncs
    if (isinstance(callobj, numpy.ufunc)
            or isinstance(parent_object, numpy.ufunc)):
        return True

    # Functions
    # Special case: Constructor method (e.g., numpy.ndarray)
    if classname == "type" and node is not None:
        if oprepo.Replacements.get(astutils.rname(node)) is not None:
            return True

    # Look up by fully qualified name. `__module__` may be None (e.g., for
    # some builtin methods) and `__qualname__` may be missing; in that case
    # "None + '.'" would raise a TypeError, so the lookup is skipped.
    func_module = getattr(callobj, '__module__', None)
    func_qualname = getattr(callobj, '__qualname__', None)
    if isinstance(func_module, str) and isinstance(func_qualname, str):
        full_func_name = func_module + '.' + func_qualname
        if oprepo.Replacements.get(full_func_name) is not None:
            return True

    # Also try the function as it is called in the AST (if one was given)
    if node is None:
        return False
    return oprepo.Replacements.get(astutils.rname(node)) is not None
Example #5
0
    def visit_Expr(self, node):
        """
        Drops expression statements that call DaCe definition helpers.

        Calls named ``define_local``, ``define_local_scalar``,
        ``define_stream`` or ``define_streamarray`` return None. In a node
        transformer this removes the statement. All other expressions are
        visited normally.
        """
        call = node.value
        # Some calls should not be parsed
        if isinstance(call, ast.Call) and rname(call.func) in (
                "define_local", "define_local_scalar", "define_stream",
                "define_streamarray"):
            return None
        return self.generic_visit(node)
Example #6
0
def has_replacement(callobj: Callable,
                    parent_object: Optional[Any] = None,
                    node: Optional[ast.AST] = None) -> bool:
    """
    Returns True if the function/operator replacement repository
    has a registered replacement for the called function described by
    a live object.

    Lookups are tried in this order:

    1. Attribute/method replacements on the class of ``parent_object``.
    2. NumPy ufuncs.
    3. The callable's fully qualified name, ``<module>.<qualname>``.
    4. The name of ``node`` as written in the AST.

    :param callobj: The live callable object being called.
    :param parent_object: The object the callable was accessed on, if any.
    :param node: The AST node of the called function, if available.
    :return: True if a replacement is registered, False otherwise.
    """
    from dace.frontend.common import op_repository as oprepo

    # Attributes and methods
    if parent_object is not None:
        classname = type(parent_object).__name__
        # Not every callable exposes a name (e.g., functools.partial objects)
        attrname = getattr(callobj, '__name__', None)
        if attrname is not None:
            if oprepo.Replacements.get_attribute(classname,
                                                 attrname) is not None:
                return True
            if oprepo.Replacements.get_method(classname,
                                              attrname) is not None:
                return True

    # NumPy ufuncs
    if isinstance(callobj, numpy.ufunc):
        return True

    # Functions. `__module__` may be None (e.g., for some builtin methods)
    # and `__qualname__` may be missing, so only concatenate real strings.
    func_module = getattr(callobj, '__module__', None)
    func_qualname = getattr(callobj, '__qualname__', None)
    if isinstance(func_module, str) and isinstance(func_qualname, str):
        full_func_name = func_module + '.' + func_qualname
        if oprepo.Replacements.get(full_func_name) is not None:
            return True

    # Also try the function as it is called in the AST (if one was given)
    if node is None:
        return False
    return oprepo.Replacements.get(astutils.rname(node)) is not None
Example #7
0
 def visit_Subscript(self, node: ast.Subscript) -> Any:
     """
     Replaces a subscript whose name is a key of ``self.conn_to_sym`` with
     a load of the mapped symbol name. Other subscripts are visited
     normally.
     """
     # Convert subscript to symbol name
     subscript_name = astutils.rname(node)
     if subscript_name not in self.conn_to_sym:
         return self.generic_visit(node)
     symbol = ast.Name(id=self.conn_to_sym[subscript_name], ctx=ast.Load())
     return ast.copy_location(symbol, node)
Example #8
0
 def visit_Subscript(self, node: ast.Subscript) -> Any:
     """
     Promotes subscripted connector accesses to fresh versioned names.

     Each access to an input (``self.iconns``) or output (``self.oconns``)
     connector increments ``self.latest`` and gets the name
     ``<connector>_<version>``. If the accessed range is promotable, the
     subscript is replaced by that name (Load for inputs, Store for
     outputs). The mapping is recorded in ``self.in_mapping`` or
     ``self.out_mapping``. Otherwise the connector is added to
     ``self.do_not_remove``.
     """
     conn = astutils.rname(node)
     if conn in self.iconns:
         is_output = False
     elif conn in self.oconns:
         is_output = True
     else:
         return self.generic_visit(node)

     self.latest[conn] += 1
     versioned_name = f'{conn}_{self.latest[conn]}'
     rng = subsets.Range(
         astutils.subscript_to_slice(node, self.sdfg.arrays)[1])
     # Check if range can be collapsed
     if not _range_is_promotable(rng, self.defined):
         self.do_not_remove.add(conn)
         return self.generic_visit(node)

     if is_output:
         self.out_mapping[versioned_name] = (conn, rng)
         ctx = ast.Store()
     else:
         self.in_mapping[versioned_name] = (conn, rng)
         ctx = ast.Load()
     return ast.copy_location(ast.Name(id=versioned_name, ctx=ctx), node)
Example #9
0
    def visit_Assign(self, node):
        """
        Rewrites assignments to variables in ``self.accumOnAssignment`` into
        accumulations of the form ``out = accum(out, value)``.

        ``out`` is the registered array name. It takes the target's
        subscript if the target is subscripted, and ``[:]`` if it has no
        subscript at all. Other assignments are visited normally.
        """
        first_target = node.targets[0]
        if rname(first_target) not in self.accumOnAssignment:
            return self.generic_visit(node)

        array_name, accum = self.accumOnAssignment[rname(first_target)]
        if isinstance(first_target, ast.Subscript):
            array_name += '[' + unparse(first_target.slice) + ']'
        if '[' not in array_name:
            # No subscript at all: use the full slice
            array_name += '[:]'

        accum_src = unparse(accum)
        value_src = unparse(node.value)
        code = f'{array_name} = {accum_src}({array_name}, {value_src})'
        return _copy_location(ast.parse(code).body[0], node)
Example #10
0
 def visit_Attribute(self, node):
     """
     Replaces an attribute access on an allowed module with a single name
     that carries the mapped C++ module prefix from
     ``dtypes._ALLOWED_MODULES``. Other attributes are visited normally.
     """
     attrname = rname(node)
     module_name, _, func_name = attrname.rpartition('.')
     if module_name in dtypes._ALLOWED_MODULES:
         cppmodname = dtypes._ALLOWED_MODULES[module_name]
         # `ctx` must be an expr_context instance (ast.Load()), not the class
         return ast.copy_location(
             ast.Name(id=(cppmodname + func_name), ctx=ast.Load()), node)
     return self.generic_visit(node)
Example #11
0
    def visit_Call(self, node: ast.Call) -> Any:
        """
        Sets ``self.detected`` when a call to an attribute function
        (``x.f(...)``) is found.

        Single-constant-argument calls on the ``dace`` module, such as
        ``dace.int64(2)``, are exempt and are traversed normally, as are
        all calls through non-attribute callees.
        """
        if not isinstance(node.func, ast.Attribute):
            return self.generic_visit(node)

        # Special case: calling attributed functions on constants (e.g., dace.int64(2))
        is_constant_cast = (len(node.args) == 1
                            and astutils.is_constant(node.args[0])
                            and astutils.rname(node.func.value) == 'dace')
        if is_constant_cast:
            return self.generic_visit(node)

        self.detected = True
        return None
Example #12
0
 def visit_Name(self, node: ast.Name):
     """
     Dereferences names of single-element pointer memlets.

     If ``node`` names an entry of ``self.memlets`` whose data type is a
     ``dtypes.pointer`` and whose subset has exactly one element, the node
     is replaced by a name reading ``(*<name>)`` with the same context.
     Otherwise the node is visited normally.
     """
     name = rname(node)
     if name not in self.memlets:
         return self.generic_visit(node)
     memlet, _, _, dtype = self.memlets[name]
     if (isinstance(dtype, dtypes.pointer)
             and memlet.subset.num_elements() == 1):
         # Preserve the source location of the replaced node, as the
         # other visitors do
         return ast.copy_location(
             ast.Name(id="(*{})".format(name), ctx=node.ctx), node)
     return self.generic_visit(node)
Example #13
0
    def visit_AugAssign(self, node):
        """
        Rejects augmented assignments (e.g. ``+=``) to subscripted memlets.

        :raises SyntaxError: If the target is a subscript whose name is in
                             ``self.memlets``.
        """
        is_memlet_subscript = (isinstance(node.target, ast.Subscript)
                               and rname(node.target) in self.memlets)
        if is_memlet_subscript:
            raise SyntaxError("Augmented assignments (e.g. +=) not allowed on "
                              "array memlets")
        return self.generic_visit(node)
Example #14
0
    def visit_Assign(self, node):
        """
        Rewrites assignments to memlet targets into C++ write calls.

        - Non-subscripted target with dynamic accesses
          (``memlet.num_accesses < 0``): becomes ``__<target>.write(<value>);``.
          With a write-conflict resolution it becomes the expression returned
          by ``write_and_resolve_expr``.
        - Subscripted target: becomes ``__<target>.write(<value>, <index>);``,
          or ``write_and_resolve_expr(..., indices=<index>)`` with a WCR.

        Assignments to non-memlet names are visited normally.

        :raises NotImplementedError: If the target is range-subscripted.
        """
        target = rname(node.targets[0])
        if target not in self.memlets:
            return self.generic_visit(node)

        memlet, nc, wcr = self.memlets[target]
        value = self.visit(node.value)

        if not isinstance(node.targets[0], ast.Subscript):
            # Dynamic accesses -> every access counts
            try:
                if memlet is not None and memlet.num_accesses < 0:
                    if wcr is not None:
                        newnode = ast.Name(id=write_and_resolve_expr(
                            self.sdfg, memlet, nc, '__' + target,
                            cppunparse.cppunparse(value,
                                                  expr_semicolon=False)))
                    else:
                        newnode = ast.Name(id="__%s.write(%s);" % (
                            target,
                            cppunparse.cppunparse(value, expr_semicolon=False),
                        ))

                    return ast.copy_location(newnode, node)
            except TypeError:  # cannot determine truth value of Relational
                pass

            return self.generic_visit(node)

        # Before Python 3.9, plain indices are wrapped in ast.Index and ranges
        # are ast.Slice/ast.ExtSlice, which are not expressions. From 3.9 on,
        # the index expression is stored directly, and ranges are ast.Slice
        # expressions (possibly inside an ast.Tuple). getattr() keeps this
        # working if ast.Index is ever removed.
        index = self.visit(node.targets[0].slice)
        if isinstance(index, getattr(ast, 'Index', ())):
            index_value = index.value
        elif isinstance(index, ast.expr):
            index_value = index
        else:
            index_value = None
        if index_value is None or isinstance(index_value, ast.Slice) or (
                isinstance(index_value, ast.Tuple)
                and any(isinstance(e, ast.Slice) for e in index_value.elts)):
            raise NotImplementedError("Range subscripting not implemented")

        if isinstance(index_value, ast.Tuple):
            # Drop the outer delimiters (presumably the tuple's parentheses)
            subscript = unparse(index)[1:-1]
        else:
            subscript = unparse(index)

        if wcr is not None:
            newnode = ast.Name(id=write_and_resolve_expr(
                self.sdfg,
                memlet,
                nc,
                "__" + target,
                cppunparse.cppunparse(value, expr_semicolon=False),
                indices=subscript,
            ))
        else:
            newnode = ast.Name(id="__%s.write(%s, %s);" % (
                target,
                cppunparse.cppunparse(value, expr_semicolon=False),
                subscript,
            ))

        return ast.copy_location(newnode, node)
Example #15
0
    def visit_Call(self, node: ast.Call):
        """
        Records called functions and resolves closures of nested
        SDFGConvertible calls.

        Callee names are added to ``self.seen_calls``. If the callee was
        replaced by a constant holding an object with ``__sdfg__`` that is
        not an ``SDFG``, its closure is appended to
        ``self.closure.nested_closures``. On failure a warning is issued, the
        entry is removed from ``self.closure.closure_sdfgs``, and the
        original callee (from ``node.func.oldnode``) is restored.

        :raises DaceRecursionError: Re-raised from nested parsing.
        """
        # Only parse calls to parsed SDFGConvertibles
        if not isinstance(node.func, (ast.Num, ast.Constant)):
            self.seen_calls.add(astutils.rname(node.func))
            return self.generic_visit(node)
        if hasattr(node.func, 'oldnode'):
            if isinstance(node.func.oldnode, ast.Call):
                self.seen_calls.add(astutils.rname(node.func.oldnode.func))
            else:
                self.seen_calls.add(astutils.rname(node.func.oldnode))
        if isinstance(node.func, ast.Num):
            value = node.func.n
        else:
            value = node.func.value

        if not hasattr(value, '__sdfg__') or isinstance(value, SDFG):
            return self.generic_visit(node)

        constant_args = self._eval_args(node)

        # Resolve nested closure as necessary
        qualname = None
        try:
            qualname = next(k for k, v in self.closure.closure_sdfgs.items()
                            if v is value)
            self.seen_calls.add(qualname)
            if hasattr(value, 'closure_resolver'):
                self.closure.nested_closures.append(
                    (qualname,
                     value.closure_resolver(constant_args, self.closure)))
            else:
                self.closure.nested_closures.append((qualname, SDFGClosure()))
        except DaceRecursionError:  # Parsing failed in a nested context, raise
            raise
        except Exception as ex:  # Parsing failed (anything can happen here)
            warnings.warn(f'Parsing SDFGConvertible {value} failed: {ex}')
            if qualname in self.closure.closure_sdfgs:
                del self.closure.closure_sdfgs[qualname]
            # Return old call AST instead. Without an `oldnode` there is
            # nothing to restore; propagate the original error rather than an
            # AttributeError raised inside this handler.
            if not hasattr(node.func, 'oldnode'):
                raise
            oldnode = node.func.oldnode
            # `oldnode` is either the original call or the original callee
            # expression (see the `seen_calls` bookkeeping above)
            node.func = oldnode.func if isinstance(oldnode,
                                                   ast.Call) else oldnode

            return self.generic_visit(node)
Example #16
0
    def visit_Subscript(self, node: ast.Subscript) -> ast.Subscript:
        """
        Re-expresses subscripts of ``self.to_refine`` relative to a refined
        range.

        The subscript's range is offset by ``self.subset`` (using
        ``self.indices``) and turned back into a subscript of
        ``self.to_refine``. Other subscripts are visited normally.
        """
        if astutils.rname(node.value) != self.to_refine:
            return self.generic_visit(node)

        index_ranges = astutils.subscript_to_slice(node,
                                                   self.sdfg.arrays,
                                                   without_array=True)
        refined = subsets.Range(index_ranges)
        refined.offset(self.subset, True, self.indices)
        new_subscript = astutils.slice_to_subscript(self.to_refine, refined)
        return ast.copy_location(new_subscript, node)
Example #17
0
    def visit_Call(self, node: ast.Call):
        """
        Records called functions and resolves closures of nested
        SDFGConvertible calls.

        Callee expressions are added to ``self.seen_calls``. If the callee
        was replaced by a constant holding an object with ``__sdfg__`` that
        is not an ``SDFG``, its closure is appended to
        ``self.closure.nested_closures``. On failure a warning is issued.
        If the ``frontend.raise_nested_parsing_errors`` config option is set,
        the error is re-raised. Otherwise the entry is removed from
        ``self.closure.closure_sdfgs`` and the original callee (from
        ``node.func.oldnode``) is restored.

        :raises DaceRecursionError: Re-raised from nested parsing.
        """
        # Only parse calls to parsed SDFGConvertibles
        if not isinstance(node.func, (ast.Num, ast.Constant)):
            self.seen_calls.add(astutils.unparse(node.func))
            return self.generic_visit(node)
        if hasattr(node.func, 'oldnode'):
            if isinstance(node.func.oldnode, ast.Call):
                self.seen_calls.add(astutils.unparse(node.func.oldnode.func))
            else:
                self.seen_calls.add(astutils.rname(node.func.oldnode))
        if isinstance(node.func, ast.Num):
            value = node.func.n
        else:
            value = node.func.value

        if not hasattr(value, '__sdfg__') or isinstance(value, SDFG):
            return self.generic_visit(node)

        constant_args = self._eval_args(node)

        # Resolve nested closure as necessary
        qualname = None
        try:
            if id(value) in self.closure.closure_sdfgs:
                qualname, _ = self.closure.closure_sdfgs[id(value)]
            elif hasattr(node.func, 'qualname'):
                qualname = node.func.qualname
            self.seen_calls.add(qualname)
            if hasattr(value, 'closure_resolver'):
                self.closure.nested_closures.append(
                    (qualname,
                     value.closure_resolver(constant_args, self.closure)))
            else:
                self.closure.nested_closures.append((qualname, SDFGClosure()))
        except DaceRecursionError:  # Parsing failed in a nested context, raise
            raise
        except Exception as ex:  # Parsing failed (anything can happen here)
            optional_qname = ''
            if qualname is not None:
                optional_qname = f' ("{qualname}")'
            warnings.warn(
                f'Preprocessing SDFGConvertible {value}{optional_qname} failed with {type(ex).__name__}: {ex}'
            )
            if Config.get_bool('frontend', 'raise_nested_parsing_errors'):
                raise
            if id(value) in self.closure.closure_sdfgs:
                del self.closure.closure_sdfgs[id(value)]
            # Return old call AST instead
            if not hasattr(node.func, 'oldnode'):
                raise
            oldnode = node.func.oldnode
            # `oldnode` is either the original call or the original callee
            # expression (see the `seen_calls` bookkeeping above). Taking
            # `.func` of a non-call would raise inside this handler.
            node.func = oldnode.func if isinstance(oldnode,
                                                   ast.Call) else oldnode

            return self.generic_visit(node)
Example #18
0
    def visit_Call(self, node: ast.Call):
        """
        Rewrites struct initializer calls into ``__DACESTRUCT_<name>`` calls.

        A call whose callee name is a key of ``self._structs`` must pass
        exactly the struct's fields as keyword arguments. The callee is then
        replaced by a ``__DACESTRUCT_<struct.name>`` name and the node is
        returned without visiting its children. Other calls are visited
        normally.

        :raises SyntaxError: If the keyword names do not match the struct's
                             fields.
        """
        # Struct initializer
        callee = astutils.rname(node.func)
        if callee not in self._structs:
            return self.generic_visit(node)

        struct = self._structs[callee]
        given_fields = {astutils.rname(kw.arg) for kw in node.keywords}
        if sorted(given_fields) != sorted(struct.fields.keys()):
            raise SyntaxError('Mismatch in fields in struct definition')

        marker = ast.Name(id='__DACESTRUCT_' + struct.name, ctx=ast.Load())
        node.func = ast.copy_location(marker, node.func)
        return node
Example #19
0
    def visit_Subscript(self, node):
        """
        Replaces subscripts of memlets or constants with a name node holding
        the C++ access text ``<target>[<index>]``. Other subscripts are
        visited normally.
        """
        target = rname(node)
        if target in self.memlets or target in self.constants:
            index_expr = self._subscript_expr(node.slice, target)
            # Emitted as an ast.Name rather than an ast.Subscript, so the
            # visitor does not recursively descend into the new expression
            # and modify it erroneously
            flat_access = ast.Name(id="%s[%s]" % (target, sym2cpp(index_expr)))
            return ast.copy_location(flat_access, node)
        return self.generic_visit(node)
Example #20
0
    def visit_Call(self, node):
        """
        Converts struct constructor calls into C++ brace initializers.

        A constructor call is a call to a name that starts with
        ``__DACESTRUCT_`` or is a key of ``self._structs``. It is replaced by
        a name holding ``<type> { .field = value, ... }``, with fields sorted
        by keyword name and the ``__DACESTRUCT_`` prefix removed from the
        type name. Other calls are visited normally.
        """
        prefix = '__DACESTRUCT_'
        func = node.func
        if not (isinstance(func, ast.Name) and
                (func.id.startswith(prefix) or func.id in self._structs)):
            return self.generic_visit(node)

        fields = ', '.join(
            '.%s = %s' % (rname(kw.arg), cppunparse.pyexpr2cpp(kw.value))
            for kw in sorted(node.keywords, key=lambda x: x.arg))

        tname = func.id[len(prefix):] if func.id.startswith(prefix) else func.id

        # `ctx` must be an expr_context instance (ast.Load()), not the class
        return ast.copy_location(
            ast.Name(id="%s { %s }" % (tname, fields), ctx=ast.Load()), node)
Example #21
0
def ParseMemlet(visitor,
                defined_arrays_and_symbols: Dict[str, Any],
                node: MemletType,
                parsed_slice: Any = None) -> MemletExpr:
    """
    Parses a memlet AST expression into a ``MemletExpr``.

    Supported forms include ``A[...]``, ``A(2)[...]``, ``A(300)`` and
    ``A(1, sum)[:]``. The first call argument is the number of accesses and
    the optional second one is the write-conflict resolution. Without a
    call, the number of accesses defaults to the subset size.

    :param visitor: The AST visitor, used for error reporting.
    :param defined_arrays_and_symbols: Mapping from defined data/symbol names
                                       to descriptors.
    :param node: The memlet AST node.
    :param parsed_slice: Optional pre-parsed slice passed through to
                         ``parse_memlet_subset``.
    :return: The parsed ``MemletExpr``.
    :raises DaceSyntaxError: On undefined data, or when the call has fewer
                             than one or more than three arguments.
    """
    das = defined_arrays_and_symbols
    arrname = rname(node)
    if arrname not in das:
        raise DaceSyntaxError(visitor, node,
                              'Use of undefined data "%s" in memlet' % arrname)
    array = das[arrname]

    # Detects expressions of the form "A(2)[...]", "A(300)", "A(1, sum)[:]"
    if isinstance(node, ast.Call):
        memlet_call = node
    elif isinstance(node, ast.Subscript) and isinstance(node.value, ast.Call):
        memlet_call = node.value
    else:
        memlet_call = None

    # Determine number of accesses to the memlet (default is the slice size)
    num_accesses = None
    write_conflict_resolution = None
    if memlet_call is not None:
        call_args = memlet_call.args
        if not 1 <= len(call_args) <= 3:
            raise DaceSyntaxError(
                visitor, node,
                'Number of accesses in memlet must be a number, symbolic '
                'expression, or -1 (dynamic)')
        num_accesses = pyexpr_to_symbolic(das, call_args[0])
        if len(call_args) >= 2:
            write_conflict_resolution = call_args[1]

    subset, new_axes, arrdims = parse_memlet_subset(array, node, das,
                                                    parsed_slice)

    # If undefined, default number of accesses is the slice size
    if num_accesses is None:
        num_accesses = subset.num_elements()

    return MemletExpr(arrname, num_accesses, write_conflict_resolution, subset,
                      new_axes, arrdims)
Example #22
0
    def visit_Subscript(self, node):
        """
        Replaces reads of subscripted memlets/constants with C++ accesses.

        - Constants become ``<target>[<ndslice_cpp(...)>]``, computed from the
          constant's shape.
        - Memlets become ``__<target>(<index>)``.

        Other subscripts are visited normally.

        :raises NotImplementedError: If the subscript is a range.
        """
        target = rname(node)
        if target not in self.memlets and target not in self.constants:
            return self.generic_visit(node)

        # Before Python 3.9, plain indices are wrapped in ast.Index and ranges
        # are ast.Slice/ast.ExtSlice, which are not expressions. From 3.9 on,
        # the index expression is stored directly, and ranges are ast.Slice
        # expressions (possibly inside an ast.Tuple). getattr() keeps this
        # working if ast.Index is ever removed.
        index = self.visit(node.slice)
        if isinstance(index, getattr(ast, 'Index', ())):
            index_value = index.value
        elif isinstance(index, ast.expr):
            index_value = index
        else:
            index_value = None
        if index_value is None or isinstance(index_value, ast.Slice) or (
                isinstance(index_value, ast.Tuple)
                and any(isinstance(e, ast.Slice) for e in index_value.elts)):
            raise NotImplementedError("Range subscripting not implemented")

        if isinstance(index_value, ast.Tuple):
            # Drop the outer delimiters (presumably the tuple's parentheses)
            subscript = unparse(index)[1:-1]
        else:
            subscript = unparse(index)

        if target in self.constants:
            slice_str = DaCeKeywordRemover.ndslice_cpp(
                subscript.split(", "), self.constants[target].shape)
            newnode = ast.parse("%s[%s]" % (target, slice_str)).body[0].value
        else:
            newnode = ast.parse("__%s(%s)" % (target, subscript)).body[0].value
        return ast.copy_location(newnode, node)
Example #23
0
    def visit_FunctionDef(self, node):
        """
        Converts DaCe-decorated function definitions into plain Python code.

        On the outermost call (``self.curprim is None``):

        - ``self.pdp`` becomes the current primitive.
        - The first decorator is stripped.
        - The name the decorator was accessed through (e.g., ``dace`` in
          ``@dace.program``) is stored in ``self.module_name``.

        A nested function whose first decorator names one of ``map``,
        ``reduce``, ``consume``, ``tasklet``, ``iterate``, ``loop`` or
        ``conditional`` advances to the next child of ``self.curprim`` and
        is rewritten according to that child's node type:

        - map / iterate: ``for <params> in <module>.ndrange(<range>): ...``
        - consume: ``while len(<stream>) > 0:``, with
          ``<param> = <stream>.popleft()`` inserted at the top of the body
        - tasklet: ``if True: ...``
        - reduce: the definition is kept (decorator stripped) and
          ``<module>.simulator.simulate_reduce(<name>, <decorator args>)`` is
          emitted after it
        - loop: ``while <first decorator argument>: ...``

        Any other primitive node type raises a RuntimeError. Functions
        without a decorator, or with other decorators, are visited normally.

        Expression statements in the body are processed with
        ``self.VisitTopLevelExpr``. Its "append" statements go to the end of
        the new body; its "prepend" statements are emitted before the
        rewritten node. Other statements are visited normally.

        :return: A list of statements: the prepended statements, then the
                 rewritten node, then any reduction calls.
        """
        after_nodes = []

        if self.curprim is None:
            self.curprim = self.pdp
            self.curchild = -1
            # The decorator may be a call (@mod.x(...)) or an attribute (@mod.x)
            if isinstance(node.decorator_list[0], ast.Call):
                self.module_name = node.decorator_list[0].func.value.id
            else:
                self.module_name = node.decorator_list[0].value.id
            # Strip decorator
            del node.decorator_list[0]

            # NOTE(review): these are captured after curprim/curchild were
            # updated above, so on exit curprim is left as self.pdp (not
            # None). Confirm this is intended.
            oldchild = self.curchild
            oldprim = self.curprim

        else:
            if len(node.decorator_list) == 0:
                return self.generic_visit(node)
            dec = node.decorator_list[0]
            # The primitive name is the decorator's attribute name
            if isinstance(dec, ast.Call):
                decname = rname(dec.func.attr)
            else:
                decname = rname(dec.attr)

            if decname in [
                    'map', 'reduce', 'consume', 'tasklet', 'iterate', 'loop',
                    'conditional'
            ]:
                # Advance to the next child primitive; save the parent state
                # so it can be restored after the body is processed
                self.curchild += 1

                oldchild = self.curchild
                oldprim = self.curprim
                self.curprim = self.curprim.children[self.curchild]
                self.curchild = -1

                if isinstance(self.curprim, astnodes._MapNode):
                    # Map -> for <params> in <module>.ndrange(<range>)
                    newnode = \
                        _copy_location(ast.For(target=ast.Tuple(ctx=ast.Store(),
                                                    elts=[ast.Name(id=name, ctx=ast.Store()) for name in self.curprim.params]),
                                                    iter=ast.parse('%s.ndrange(%s)' % (self.module_name, self.curprim.range.pystr())).body[0].value,
                                                    body=node.body, orelse=[]),
                                            node)
                    node = newnode
                elif isinstance(self.curprim, astnodes._ConsumeNode):
                    stream = self.curprim.stream
                    if isinstance(self.curprim.stream, ast.AST):
                        stream = unparse(self.curprim.stream)
                    # Unsubscripted streams are accessed at index 0
                    if '[' not in stream:
                        stream += '[0]'

                    # Consume -> while len(<stream>) > 0, popping one element
                    # into the consume parameter at the top of each iteration
                    newnode = \
                        _copy_location(ast.While(
                            test=ast.parse('len(%s) > 0' % stream).body[0].value,
                                           body=node.body, orelse=[]),
                                       node)
                    node = newnode
                    node.body.insert(
                        0,
                        _copy_location(
                            ast.parse(
                                '%s = %s.popleft()' %
                                (str(self.curprim.params[0]), stream)).body[0],
                            node))

                elif isinstance(self.curprim, astnodes._TaskletNode):
                    # Strip decorator
                    del node.decorator_list[0]

                    # Tasklet -> "if True:" block holding the function body
                    newnode = \
                        _copy_location(ast.parse('if True: pass').body[0], node)
                    newnode.body = node.body
                    newnode = ast.fix_missing_locations(newnode)
                    node = newnode
                elif isinstance(self.curprim, astnodes._ReduceNode):
                    in_memlet = self.curprim.inputs['input']
                    out_memlet = self.curprim.outputs['output']
                    # Create reduction call
                    params = [unparse(p) for p in node.decorator_list[0].args]
                    params.extend([
                        unparse(kp) for kp in node.decorator_list[0].keywords
                    ])
                    reduction = ast.parse(
                        '%s.simulator.simulate_reduce(%s, %s)' %
                        (self.module_name, node.name,
                         ', '.join(params))).body[0]
                    reduction = _copy_location(reduction, node)
                    # Presumably shifts the call's line number past the
                    # function body
                    reduction = ast.increment_lineno(reduction,
                                                     len(node.body) + 1)
                    reduction = ast.fix_missing_locations(reduction)

                    # Strip decorator
                    del node.decorator_list[0]

                    # The reduction call is emitted after the function def
                    after_nodes.append(reduction)
                elif isinstance(self.curprim, astnodes._IterateNode):
                    # Iterate -> for <params> in <module>.ndrange(<range>)
                    newnode = \
                        _copy_location(ast.For(target=ast.Tuple(ctx=ast.Store(),
                                                    elts=[ast.Name(id=name, ctx=ast.Store()) for name in self.curprim.params]),
                                                    iter=ast.parse('%s.ndrange(%s)' % (self.module_name, self.curprim.range.pystr())).body[0].value,
                                                    body=node.body, orelse=[]),
                                            node)
                    newnode = ast.fix_missing_locations(newnode)
                    node = newnode
                elif isinstance(self.curprim, astnodes._LoopNode):
                    # Loop -> while <first decorator argument>
                    newnode = \
                        _copy_location(ast.While(test=node.decorator_list[0].args[0],
                                                    body=node.body, orelse=[]),
                                            node)
                    newnode = ast.fix_missing_locations(newnode)
                    node = newnode
                else:
                    raise RuntimeError('Unimplemented primitive %s' % decname)
            else:
                return self.generic_visit(node)

        newbody = []
        end_stmts = []
        substitute_stmts = []
        # Incrementally build new body from original body
        for stmt in node.body:
            if isinstance(stmt, ast.Expr):
                res, append, prepend = self.VisitTopLevelExpr(stmt)
                if res is not None:
                    newbody.append(res)
                if append is not None:
                    end_stmts.extend(append)
                if prepend is not None:
                    substitute_stmts.extend(prepend)
            else:
                subnodes = self.visit(stmt)
                if subnodes is not None:
                    if isinstance(subnodes, list):
                        newbody.extend(subnodes)
                    else:
                        newbody.append(subnodes)
        node.body = newbody + end_stmts

        # Restore the enclosing primitive state
        self.curchild = oldchild
        self.curprim = oldprim

        substitute_stmts.append(node)
        if len(after_nodes) > 0:
            return substitute_stmts + after_nodes
        return substitute_stmts
Example #24
0
 def visit_Call(self, node):
     """
     Renames ``<obj>.push(...)`` method calls to ``<obj>.append(...)``.

     Only calls whose callee is an attribute named exactly ``push`` are
     renamed. Calls such as ``q.push_front(...)`` or ``q.push.x(...)`` are
     left untouched. Child nodes are visited afterwards.
     """
     func = node.func
     # Match the attribute name exactly. The former substring test
     # ('.push' in rname(...)) also matched e.g. `q.push_front(...)`, and
     # for `q.push.x(...)` it renamed the wrong attribute (`x`).
     if isinstance(func, ast.Attribute) and func.attr == 'push':
         func.attr = 'append'
     return self.generic_visit(node)
Example #25
0
    def visit_Assign(self, node):
        """
        Rewrites assignments to memlet targets (``self.memlets``) into C++
        statements stored in ``ast.Name`` nodes.

        Non-subscripted target whose memlet has data and is dynamic or refers
        to a ``data.Stream``:

        - with a write-conflict resolution (``wcr``): the assigned value is
          replaced by ``self.codegen.write_and_resolve_expr(...)`` and the
          modified assignment node itself is returned;
        - stream: ``<data>.push(<value>);``
        - scalar (``DefinedType.Scalar``): ``<data> = <value>;``
        - otherwise: ``<cpp_array_expr(memlet)> = <value>;``

        Subscripted target: ``write_and_resolve_expr(..., indices=...)``
        followed by ``;`` when a ``wcr`` is given, otherwise
        ``<target>[<index>] = <value>;``.

        Non-memlet targets, and non-subscripted memlet targets not covered
        above, are visited normally.
        """
        # The last target is the one being written to
        target = rname(node.targets[-1])
        if target not in self.memlets:
            return self.generic_visit(node)

        memlet, nc, wcr, dtype = self.memlets[target]
        value = self.visit(node.value)

        if not isinstance(node.targets[-1], ast.Subscript):
            # Dynamic accesses or streams -> every access counts
            try:
                if memlet and memlet.data and (memlet.dynamic or isinstance(
                        self.sdfg.arrays[memlet.data], data.Stream)):
                    if wcr is not None:
                        # Only the right-hand side is replaced; the original
                        # assignment node is kept and returned
                        newnode = ast.Name(
                            id=self.codegen.write_and_resolve_expr(
                                self.sdfg,
                                memlet,
                                nc,
                                target,
                                cppunparse.cppunparse(value,
                                                      expr_semicolon=False),
                                dtype=dtype))
                        node.value = ast.copy_location(newnode, node.value)
                        return node
                    elif isinstance(self.sdfg.arrays[memlet.data],
                                    data.Stream):
                        newnode = ast.Name(id="%s.push(%s);" % (
                            memlet.data,
                            cppunparse.cppunparse(value, expr_semicolon=False),
                        ))
                    else:
                        # NOTE(review): `ctypedef` is unpacked but unused here
                        var_type, ctypedef = self.codegen._dispatcher.defined_vars.get(
                            memlet.data)
                        if var_type == DefinedType.Scalar:
                            newnode = ast.Name(id="%s = %s;" % (
                                memlet.data,
                                cppunparse.cppunparse(value,
                                                      expr_semicolon=False),
                            ))
                        else:
                            newnode = ast.Name(id="%s = %s;" % (
                                cpp_array_expr(self.sdfg, memlet),
                                cppunparse.cppunparse(value,
                                                      expr_semicolon=False),
                            ))

                    return self._replace_assignment(newnode, node)
            except TypeError:  # cannot determine truth value of Relational
                pass

            return self.generic_visit(node)

        subscript = self._subscript_expr(node.targets[-1].slice, target)

        if wcr is not None:
            newnode = ast.Name(id=self.codegen.write_and_resolve_expr(
                self.sdfg,
                memlet,
                nc,
                target,
                cppunparse.cppunparse(value, expr_semicolon=False),
                indices=sym2cpp(subscript),
                dtype=dtype) + ';')
        else:
            newnode = ast.Name(
                id="%s[%s] = %s;" %
                (target, sym2cpp(subscript),
                 cppunparse.cppunparse(value, expr_semicolon=False)))

        return self._replace_assignment(newnode, node)
Example #26
0
def find_promotable_scalars(sdfg: sd.SDFG,
                            transients_only: bool = True,
                            integers_only: bool = True) -> Set[str]:
    """
    Finds scalars that can be promoted to symbols in the given SDFG.
    Conditions for matching a scalar for symbol-promotion are as follows:
        * Size of data must be 1, it must not be a stream, must not have a
          persistent lifetime, and must be transient (unless
          ``transients_only`` is False).
        * The only inputs to a candidate scalar may be arrays or tasklets.
        * All tasklets that lead to it must have exactly one statement, one
          output, and may have zero or more **array** inputs and not be in a
          scope.
        * An occurrence of the scalar that is written in a state must not be
          read by a library node.
        * Scalar must not be accessed with a write-conflict resolution.
        * Scalar must not be written to more than once in a state.
        * If ``integers_only`` is True and the scalar is not integral (i.e.,
          not an integer type), it must also appear among the free symbols of
          an inter-state edge to be promotable.

    These conditions must apply on all occurrences of the scalar in order for
    it to be promotable.

    :param sdfg: The SDFG to query.
    :param transients_only: If False, also considers global data descriptors (e.g., arguments).
    :param integers_only: If False, also considers non-integral descriptors for promotion.
    :return: A set of promotable scalar names.
    """
    # Keep set of active candidates
    candidates: Set[str] = set()

    # General array checks: single-element, non-stream, non-persistent data
    for aname, desc in sdfg.arrays.items():
        if transients_only and not desc.transient:
            continue
        if isinstance(desc, dt.Stream):
            continue
        if desc.total_size != 1:
            continue
        if desc.lifetime is dtypes.AllocationLifetime.Persistent:
            continue
        candidates.add(aname)

    # Check all occurrences of candidates in SDFG and filter out.
    # ``candidates_seen`` collects every candidate that is written somewhere.
    candidates_seen: Set[str] = set()
    for state in sdfg.nodes():
        candidates_in_state: Set[str] = set()

        for node in state.nodes():
            if not isinstance(node, nodes.AccessNode):
                continue
            candidate = node.data
            if candidate not in candidates:
                continue

            # If candidate is read-only, continue normally
            if state.in_degree(node) == 0:
                continue

            # If candidate is read by a library node, skip.
            # NOTE(review): only access nodes that are written in this state
            # reach this check; read-only occurrences returned early above.
            if any(
                    isinstance(e.dst, nodes.LibraryNode)
                    for oe in state.out_edges(node)
                    for e in state.memlet_tree(oe)):
                candidates.remove(candidate)
                continue

            # Candidate may only be accessed in a top-level scope
            if state.entry_node(node) is not None:
                candidates.remove(candidate)
                continue

            # Candidate may only be written to once within a state
            # (an in-degree above one is rejected just below)
            if (candidate in candidates_in_state
                    and state.in_degree(node) == 1):
                candidates.remove(candidate)
                continue
            candidates_in_state.add(candidate)

            # If input is not a single array nor tasklet, skip
            if state.in_degree(node) > 1:
                candidates.remove(candidate)
                continue
            edge = state.in_edges(node)[0]

            # Edge must not be WCR
            if edge.data.wcr is not None:
                candidates.remove(candidate)
                continue

            # Check inputs
            if isinstance(edge.src, nodes.AccessNode):
                # If input is array, ensure it is not a stream
                if isinstance(sdfg.arrays[edge.src.data], dt.Stream):
                    candidates.remove(candidate)
                    continue
                # Ensure no inputs exist to the array
                if state.in_degree(edge.src) > 0:
                    candidates.remove(candidate)
                    continue
            elif isinstance(edge.src, nodes.Tasklet):
                # If input tasklet has more than one output, skip
                if state.out_degree(edge.src) > 1:
                    candidates.remove(candidate)
                    continue
                # If inputs to tasklets are not arrays, skip
                for tinput in state.in_edges(edge.src):
                    if not isinstance(tinput.src, nodes.AccessNode):
                        candidates.remove(candidate)
                        break
                    if isinstance(sdfg.arrays[tinput.src.data], dt.Stream):
                        candidates.remove(candidate)
                        break
                    # If input is not a single-element memlet, skip
                    if (tinput.data.dynamic
                            or tinput.data.subset.num_elements() != 1):
                        candidates.remove(candidate)
                        break
                    # If input array has inputs of its own (cannot promote
                    # within same state), skip
                    if state.in_degree(tinput.src) > 0:
                        candidates.remove(candidate)
                        break
                else:
                    # All tasklet inputs are acceptable; now check that the
                    # tasklet consists of a single assignment statement.
                    cb: props.CodeBlock = edge.src.code
                    if cb.language is dtypes.Language.Python:
                        # Exactly one statement, which must be an assignment.
                        # (An empty tasklet must be rejected here rather than
                        # crash on ``cb.code[0]``.)
                        if (len(cb.code) != 1
                                or not isinstance(cb.code[0], ast.Assign)):
                            candidates.remove(candidate)
                            continue
                        # Ensure the candidate is assigned to. Only plain names
                        # and subscripts can name the output connector; other
                        # targets (tuples, lists, starred, attributes) never
                        # match, and some of them would make ``rname`` raise.
                        targets = cb.code[0].targets
                        if (len(targets) != 1 or not isinstance(
                                targets[0], (ast.Name, ast.Subscript))
                                or astutils.rname(targets[0]) != edge.src_conn):
                            candidates.remove(candidate)
                            continue
                        # Ensure that the candidate is not assigned through
                        # an "attribute" call, e.g., "dace.int64". These calls
                        # are not supported currently by the SymPy-based
                        # symbolic module.
                        detector = AttributedCallDetector()
                        detector.visit(cb.code[0].value)
                        if detector.detected:
                            candidates.remove(candidate)
                            continue
                    elif cb.language is dtypes.Language.CPP:
                        # Try to match a single C assignment
                        cstr = cb.as_string.strip()
                        # Since we cannot remove subscripts from C++ tasklets,
                        # if the type of the data is an array we will also skip
                        if re.match(r'^[a-zA-Z_][a-zA-Z_0-9]*\s*=.*;$',
                                    cstr) is None:
                            candidates.remove(candidate)
                            continue
                        newcode = translate_cpp_tasklet_to_python(cstr)
                        try:
                            # Parsed only to verify that the translated code
                            # is valid Python; the result is not needed.
                            ast.parse(str(newcode))
                        except SyntaxError:
                            # If we cannot parse the expression to pythonize
                            # it, we cannot promote the candidate
                            candidates.remove(candidate)
                            continue
                    else:  # Other languages are currently unsupported
                        candidates.remove(candidate)
                        continue
            else:  # If input is not an acceptable node type, skip
                candidates.remove(candidate)
        candidates_seen |= candidates_in_state

    # Collect all symbols used on inter-state edges
    interstate_symbols: Set[str] = set()
    for edge in sdfg.edges():
        interstate_symbols |= edge.data.free_symbols

    # Filter out non-integral symbols that do not appear in inter-state edges
    if integers_only:
        for candidate in candidates - interstate_symbols:
            if sdfg.arrays[candidate].dtype not in dtypes.INTEGER_TYPES:
                candidates.remove(candidate)

    # Only keep candidates that are written somewhere in the SDFG or that are
    # used on an inter-state edge
    candidates &= (candidates_seen | interstate_symbols)

    return candidates
Exemple #27
0
    def global_value_to_node(self,
                             value,
                             parent_node,
                             qualname,
                             recurse=False,
                             detect_callables=False):
        """
        Converts a value from the program's global scope into an AST node.

        Symbols become names; constants, SDFGs and objects with ``__sdfg__``
        are registered in the closure and become constant nodes; NumPy arrays
        are registered as closure arrays and referenced by a generated name;
        callables (only if ``detect_callables`` is set) are registered as
        callbacks and, where possible, parsed as DaCe programs. Lists and
        tuples are converted element-wise, one level deep only.

        :param value: The global value to convert.
        :param parent_node: The AST node being replaced, used for source
                            location and, for lists/tuples, the expression
                            context. May be None.
        :param qualname: Qualified source-level name of the value (e.g.,
                         ``obj.attr[1]``), used as key in the closure.
        :param recurse: If False, lists and tuples are not converted.
        :param detect_callables: If True, also try to convert callables.
        :return: The new AST node, or None if the value cannot be converted
                 (or has a registered replacement).
        """
        # If recurse is false, we don't allow recursion into lists.
        # This should not happen anyway; the globals dict should only contain
        # single "level" lists.
        if not recurse and isinstance(value, (list, tuple)):
            # Bail after more than one level of lists
            return None

        if isinstance(value, (list, tuple)):
            # Convert each element; ``recurse`` is intentionally not passed
            # on, so nested sequences make the whole conversion fail.
            elts = [
                self.global_value_to_node(v,
                                          parent_node,
                                          qualname + f'[{i}]',
                                          detect_callables=detect_callables)
                for i, v in enumerate(value)
            ]
            if any(e is None for e in elts):
                return None
            # Not every parent node carries an expression context (and there
            # may be no parent at all); default to a load context then.
            ctx = getattr(parent_node, 'ctx', ast.Load())
            seqtype = ast.List if isinstance(value, list) else ast.Tuple
            newnode = seqtype(elts=elts, ctx=ctx)
        elif isinstance(value, symbolic.symbol):
            # Symbols resolve to the symbol name
            newnode = ast.Name(id=value.name, ctx=ast.Load())
        elif (dtypes.isconstant(value) or isinstance(value, SDFG)
              or hasattr(value, '__sdfg__')):
            # Could be a constant, an SDFG, or SDFG-convertible object
            if isinstance(value, SDFG) or hasattr(value, '__sdfg__'):
                self.closure.closure_sdfgs[qualname] = value
            else:
                self.closure.closure_constants[qualname] = value

            # Compatibility check since Python changed their AST nodes
            if sys.version_info >= (3, 8):
                newnode = ast.Constant(value=value, kind='')
            else:
                if value is None:
                    newnode = ast.NameConstant(value=None)
                elif isinstance(value, str):
                    newnode = ast.Str(s=value)
                else:
                    newnode = ast.Num(n=value)

            # Keep a copy of the replaced node on the new constant
            newnode.oldnode = copy.deepcopy(parent_node)

        elif detect_callables and hasattr(value, '__call__') and hasattr(
                value.__call__, '__sdfg__'):
            # Callable object whose __call__ is SDFG-convertible
            return self.global_value_to_node(value.__call__, parent_node,
                                             qualname, recurse,
                                             detect_callables)
        elif isinstance(value, numpy.ndarray):
            # Arrays need to be stored as a new name and fed as an argument.
            # The same array object (by id) is mapped to a single name.
            if id(value) in self.closure.array_mapping:
                arrname = self.closure.array_mapping[id(value)]
            else:
                arrname = self._qualname_to_array_name(qualname)
                desc = data.create_datadescriptor(value)
                self.closure.closure_arrays[arrname] = (
                    qualname, desc, lambda: eval(qualname, self.globals),
                    False)
                self.closure.array_mapping[id(value)] = arrname

            newnode = ast.Name(id=arrname, ctx=ast.Load())
        elif detect_callables and callable(value):
            # Try parsing the function as a dace function/method
            newnode = None
            try:
                from dace.frontend.python import parser  # Avoid import loops

                parent_object = None
                if hasattr(value, '__self__'):
                    parent_object = value.__self__

                # If it is a callable object
                if (not inspect.isfunction(value)
                        and not inspect.ismethod(value)
                        and not inspect.isbuiltin(value)
                        and hasattr(value, '__call__')):
                    parent_object = value
                    value = value.__call__

                # Replacements take precedence over auto-parsing. Failures
                # in the replacement lookup are deliberately ignored.
                try:
                    if has_replacement(value, parent_object, parent_node):
                        return None
                except Exception:
                    pass

                # Store the handle to the original callable, in case parsing fails
                cbqualname = astutils.rname(parent_node)
                cbname = self._qualname_to_array_name(cbqualname, prefix='')
                self.closure.callbacks[cbname] = (cbqualname, value, False)

                # From this point on, any failure will result in a callback
                newnode = ast.Name(id=cbname, ctx=ast.Load())

                # Decorated or functions with missing source code
                sast, _, _, _ = astutils.function_to_ast(value)
                if len(sast.body[0].decorator_list) > 0:
                    return newnode

                parsed = parser.DaceProgram(value, [], {}, False,
                                            dtypes.DeviceType.CPU)
                # If method, add the first argument (which disappears due to
                # being a bound method) and the method's object
                if parent_object is not None:
                    parsed.methodobj = parent_object
                    parsed.objname = inspect.getfullargspec(value).args[0]

                # The callback registered above is intentionally not removed
                # here; it is kept in case parsing the program fails later.
                return self.global_value_to_node(parsed, parent_node,
                                                 qualname, recurse,
                                                 detect_callables)
            except Exception:  # Parsing failed (almost any exception can occur)
                # Falls back to the callback name if it was already created
                return newnode
        else:
            return None

        if parent_node is not None:
            return ast.copy_location(newnode, parent_node)
        else:
            return newnode