Ejemplo n.º 1
0
def parse_memlet(visitor, src: MemletType, dst: MemletType,
                 defined_arrays_and_symbols: Dict[str, data.Data]):
    """
    Build a memlet from an assignment in which one side is a local variable.

    A side counts as a local variable when it is a bare ``ast.Name`` whose
    name is not a key of ``defined_arrays_and_symbols``. The other side is
    parsed with ``ParseMemlet``.

    :param visitor: Forwarded to ``ParseMemlet`` and ``DaceSyntaxError``.
    :param src: AST node of the memlet source.
    :param dst: AST node of the memlet destination.
    :param defined_arrays_and_symbols: Mapping of known names; used to
                                       tell data apart from local names.
    :return: A 2-tuple of the local variable name and the ``Memlet`` built
             from the parsed side.
    :raises DaceSyntaxError: If both sides are local variables.
    :raises NotImplementedError: If neither side is a local variable (this
                                 would require creating two memlets).
    """
    known = defined_arrays_and_symbols

    def _is_local(candidate) -> bool:
        # A bare, not-yet-defined name denotes a local variable
        return isinstance(candidate, ast.Name) and rname(candidate) not in known

    local_name = None
    src_memlet = None
    dst_memlet = None

    # The source side is examined (and, if needed, parsed) first
    if _is_local(src):
        local_name = rname(src)
    else:
        src_memlet = ParseMemlet(visitor, known, src)

    if _is_local(dst):
        if local_name is not None:
            raise DaceSyntaxError(
                visitor, src,
                'Memlet source and destination cannot both be local variables')
        local_name = rname(dst)
    else:
        dst_memlet = ParseMemlet(visitor, known, dst)

    if src_memlet is not None and dst_memlet is not None:
        # Data-to-data assignment would need two memlets
        raise NotImplementedError

    # Exactly one side was parsed at this point
    parsed = src_memlet if src_memlet is not None else dst_memlet
    memlet = Memlet(parsed.name,
                    parsed.accesses,
                    parsed.subset,
                    1,
                    wcr=parsed.wcr)
    return local_name, memlet
Ejemplo n.º 2
0
def ParseMemlet(visitor,
                defined_arrays_and_symbols: Dict[str, Any],
                node: MemletType,
                parsed_slice: Any = None) -> MemletExpr:
    """
    Parse a memlet AST node into a ``MemletExpr``.

    Besides plain names and subscripts, expressions of the form
    ``A(2)[...]``, ``A(300)`` and ``A(1, sum)[:]`` are recognized: the first
    call argument gives the number of accesses and the optional second one
    the write-conflict resolution.

    :param visitor: Forwarded to ``DaceSyntaxError`` for error reporting.
    :param defined_arrays_and_symbols: Mapping from defined names to their
                                       objects.
    :param node: The AST node describing the memlet.
    :param parsed_slice: Forwarded unchanged to ``parse_memlet_subset``.
    :return: A ``MemletExpr`` built from the data name, number of accesses,
             write-conflict resolution, subset, new axes and array
             dimensions.
    :raises DaceSyntaxError: If the data name is undefined, or if the call
                             form has fewer than one or more than three
                             arguments.
    """
    das = defined_arrays_and_symbols
    arrname = rname(node)
    if arrname not in das:
        raise DaceSyntaxError(visitor, node,
                              'Use of undefined data "%s" in memlet' % arrname)
    array = das[arrname]

    # Locate the call part of "A(...)" or "A(...)[...]", if there is one
    if isinstance(node, ast.Call):
        call = node
    elif isinstance(node, ast.Subscript) and isinstance(node.value, ast.Call):
        call = node.value
    else:
        call = None

    accesses = None
    wcr = None
    if call is not None:
        call_args = call.args
        # Up to three arguments are accepted; only the first two are read
        if not 1 <= len(call_args) <= 3:
            raise DaceSyntaxError(
                visitor, node,
                'Number of accesses in memlet must be a number, symbolic '
                'expression, or -1 (dynamic)')
        accesses = pyexpr_to_symbolic(das, call_args[0])
        if len(call_args) >= 2:
            wcr = call_args[1]

    subset, new_axes, arrdims = parse_memlet_subset(array, node, das,
                                                    parsed_slice)

    # Without an explicit count, fall back to the size of the slice
    if accesses is None:
        accesses = subset.num_elements()

    return MemletExpr(arrname, accesses, wcr, subset, new_axes, arrdims)
Ejemplo n.º 3
0
    def visit_For(self, node: ast.For) -> Any:
        """
        Unroll a ``for`` loop at parse time when its iterator permits it.

        A loop is unrolled if it is wrapped in ``unroll(...)`` (explicit
        request), or implicitly if its iterator is not a call/subscript or
        resolves to one of ``LoopUnroller.STATELESS_GENERATORS``. The
        iterator must be a literal tuple/list/set/dict or be evaluable with
        ``astutils.evalnode`` against ``self.globals``.

        Loops with ``break``/``continue`` statements of their own, or with an
        ``else`` clause, are never unrolled: pasting the body would leave
        those statements outside their loop and drop the ``else`` body.

        :param node: The ``for`` loop to process.
        :return: The (generically visited) node if the loop is left as is,
                 or a list with one copy of the loop body per iteration, in
                 which the loop targets are replaced by the iteration's
                 elements.
        :raises DaceSyntaxError: If unrolling was explicitly requested but
                                 the loop contains "break", "continue", or
                                 "else".
        """
        # Avoid import loops
        EXPLICIT_GENERATORS = [
            range,  # Handled in ProgramVisitor
            dace.map,
            dace.consume,
        ]

        node = self.generic_visit(node)

        # Detect loops that contain break/continue that is part of this
        # for loop (rather than nested ones), or that have an else clause
        cannot_unroll = False
        cflow_finder = _FindBreakContinueStmts()
        for stmt in node.body:
            cflow_finder.visit(stmt)
        if cflow_finder.has_cflow or node.orelse:
            cannot_unroll = True

        niter = node.iter

        # Find out if loop was explicitly requested to be unrolled with unroll,
        # and whether it should be done implicitly
        explicitly_requested = False
        if isinstance(niter, ast.Call):
            # Avoid import loop
            from dace.frontend.python.interface import unroll

            try:
                genfunc = astutils.evalnode(niter.func, self.globals)
            except SyntaxError:
                genfunc = None

            if genfunc is unroll:
                explicitly_requested = True
                # Continue with the iterable wrapped by unroll(...)
                niter = niter.args[0]

        if explicitly_requested and cannot_unroll:
            raise DaceSyntaxError(
                None, node, 'Cannot unroll loop due to '
                '"break", "continue", or "else" statements.')

        # Implicit unrolling must skip such loops as well: the unrolled body
        # would contain break/continue outside of their loop and the else
        # clause would be lost
        if cannot_unroll:
            return node

        # Find out if unrolling should be done implicitly
        implicit = True
        # Anything not a call is implicitly allowed
        if isinstance(niter, (ast.Call, ast.Subscript)):
            if isinstance(niter, ast.Subscript):
                nfunc = niter.value
            else:
                nfunc = niter.func

            implicit = False
            # Try to see if it's one of the allowed stateless generators
            try:
                genfunc = astutils.evalnode(nfunc, self.globals)

                # If genfunc is a bound method, try to extract function from type
                if hasattr(genfunc, '__self__'):
                    genfunc = getattr(type(genfunc.__self__), genfunc.__name__,
                                      False)

                if genfunc in LoopUnroller.STATELESS_GENERATORS:
                    implicit = True
                elif genfunc in EXPLICIT_GENERATORS:
                    # Stated for clarity; implicit is already False here
                    implicit = False

            except SyntaxError:
                # Generator function is not known at compile time
                pass

        # Loop will not be unrolled
        if not implicit and not explicitly_requested:
            return node

        # Check if loop target is supported
        if isinstance(node.target, ast.Tuple):
            to_replace = node.target.elts
        elif isinstance(node.target, ast.Name):
            to_replace = [node.target]
        else:
            # Unsupported loop target
            return node

        if isinstance(niter, (ast.Tuple, ast.List, ast.Set)):
            # Check if a literal tuple/list/set
            generator = niter.elts
        elif isinstance(niter, ast.Dict):
            # If dict, take keys (Python compatible)
            generator = niter.keys
        else:
            # NOTE: comprehensions and generator expressions get no special
            # treatment; like any other expression they must be
            # compile-time constant to be unrolled
            try:
                generator = astutils.evalnode(niter, self.globals)
            except SyntaxError:
                # Cannot evaluate generator at compile time
                return node

        # Too verbose?
        if implicit and not explicitly_requested:
            warnings.warn(f'Loop at {self.filename}:{node.lineno} will be '
                          'implicitly unrolled.')

        ##########################################
        # Unroll loop
        new_body = []
        for elem in generator:
            # Wrap non-iterable elements so that they can be zipped with the
            # loop targets
            try:
                iter(elem)
            except (TypeError, ValueError):
                elem = [elem]

            # Paste loop body with replaced elements
            elembody = copy.deepcopy(node.body)
            replace = astutils.ASTFindReplace(dict(zip(to_replace, elem)))
            for stmt in elembody:
                new_body.append(replace.visit(stmt))

        return new_body