Beispiel #1
0
 def test_simple(self):
     """ Checks which global names GlobalResolver substitutes in ``toresolve``. """
     test_ast, _, _, _ = astutils.function_to_ast(toresolve)
     code = astunparse.unparse(
         GlobalResolver({
             'b': 9,
             'a': -4
         }).visit(test_ast))
     # ``b`` is replaced by its value in the return statement...
     self.assertIn('return 9', code)
     # ...while these occurrences of ``a``/``b`` must remain plain names.
     # assertIn (unlike assertTrue(x in y)) shows the unparsed code on failure.
     self.assertIn('f(a, b)', code)
     self.assertIn('g(b', code)
Beispiel #2
0
def _analyze_and_unparse_code(func: DaceProgram) -> str:
    """
    Runs global resolution, conditional-code resolution and dead-code
    elimination over the source AST of ``func`` and returns the result as
    source code.

    Globals whose names coincide with one of the program's argument names
    are not substituted.
    """
    tree, _, _, _ = astutils.function_to_ast(func.f)

    closure = {}
    for name, value in func.global_vars.items():
        if name not in func.argnames:
            closure[name] = value

    # Passes are applied one after the other, each on the previous result
    tree = GlobalResolver(closure).visit(tree)
    tree = ConditionalCodeResolver(closure).visit(tree)
    tree = DeadCodeEliminator().visit(tree)

    code = astutils.unparse(tree)
    return code
Beispiel #3
0
def test_dace_unroll_break():
    """ Tests that unrolling a ``dace.unroll`` loop containing a ``break``
    statement is rejected with a DaceSyntaxError. """
    @dace.program
    def tounroll(A: dace.float64[1]):
        for i in dace.unroll(range(1, 4)):
            A[0] += i * i
            if i in (2, 3):
                break

    src_ast, fname, _, _ = astutils.function_to_ast(tounroll.f)
    lu = LoopUnroller(tounroll.global_vars, fname)
    with pytest.raises(DaceSyntaxError):
        # Only the exception matters; the visitor's result is never used
        lu.visit(src_ast)
Beispiel #4
0
def test_dace_unroll():
    """ Checks that LoopUnroller expands a ``dace.unroll`` loop and that the
    program still computes the expected value when executed. """
    @dace.program
    def tounroll(A: dace.float64[1]):
        for i in dace.unroll(range(1, 4)):
            A[0] += i * i

    tree, filename, _, _ = astutils.function_to_ast(tounroll.f)
    unroller = LoopUnroller(tounroll.global_vars, filename)
    expanded = unroller.visit(tree)

    # Three iterations of a single-statement loop body
    loop_body = expanded.body[0].body
    assert len(loop_body) == 3

    result = np.zeros([1])
    tounroll(result)
    # 1*1 + 2*2 + 3*3
    assert result[0] == 14
Beispiel #5
0
def test_dace_unroll_multistatement():
    """ Unrolls a loop whose body holds more than one statement and checks
    both the expanded AST and the computed result. """
    @dace.program
    def tounroll(A: dace.float64[1]):
        for i in dace.unroll(range(1, 4)):
            A[0] += i * i
            if i in (3, ):
                A[0] += 2

    tree, filename, _, _ = astutils.function_to_ast(tounroll.f)
    unroller = LoopUnroller(tounroll.global_vars, filename)
    expanded = unroller.visit(tree)

    # Three iterations, two top-level statements per iteration
    statements_per_iteration = 2
    assert len(expanded.body[0].body) == 3 * statements_per_iteration

    result = np.zeros([1])
    tounroll(result)
    # 1 + 4 + 9, plus 2 added on the last iteration
    assert result[0] == 16
Beispiel #6
0
def preprocess_dace_program(
    f: Callable[..., Any],
    argtypes: Dict[str, data.Data],
    global_vars: Dict[str, Any],
    modules: Dict[str, Any],
    resolve_functions: bool = False,
    parent_closure: Optional[SDFGClosure] = None
) -> Tuple[PreprocessedAST, SDFGClosure]:
    """
    Preprocesses a ``@dace.program`` and all its nested functions, returning
    a preprocessed AST object and the closure of the resulting SDFG.

    Note that ``global_vars`` and ``argtypes`` are modified in place: module
    entries are added to ``global_vars`` (see ``modules``), and the data
    descriptors of the closure arrays that remain in use are added to
    ``argtypes``.

    :param f: A Python function to parse.
    :param argtypes: A dictionary of (name, type) for the given
                     function's arguments, which may pertain to data
                     nodes or symbols (scalars).
    :param global_vars: A dictionary of global variables in the closure
                        of `f`.
    :param modules: A dictionary of imported modules used for module
                    resolution. For every entry ``(mod, modval)`` other than
                    ``'builtins'``, ``global_vars[mod]`` must exist and is
                    additionally stored in ``global_vars`` under ``modval``.
    :param resolve_functions: If True, treats all global functions defined
                              outside of the program as returning constant
                              values.
    :param parent_closure: If not None, represents the closure of the parent of
                           the currently processed function.
    :return: A 2-tuple of the AST and its reduced (used) closure.
    :raises DaceRecursionError: If ``f`` is already on the parent closure's
                                call stack, or if recursion is detected in
                                the call tree for a function other than ``f``.
    :raises TypeError: If the configured implicit recursion depth
                       (``frontend.implicit_recursion_depth``) is exceeded,
                       or if recursion is detected in a data-centric context
                       called from ``f``.
    """
    src_ast, src_file, src_line, src = astutils.function_to_ast(f)

    # Resolve data structures
    src_ast = StructTransformer(global_vars).visit(src_ast)

    src_ast = ModuleResolver(modules).visit(src_ast)
    # Convert modules after resolution: make each module's global value also
    # reachable under the ``modval`` key. The ``mod`` entry itself is kept.
    for mod, modval in modules.items():
        if mod == 'builtins':
            continue
        global_vars[modval] = global_vars[mod]

    # Resolve constants to their values (if they are not already defined in
    # this scope) and symbols to their names. Argument names and the ``_``
    # placeholder are never resolved from globals.
    resolved = {
        k: v
        for k, v in global_vars.items() if k not in argtypes and k != '_'
    }
    closure_resolver = GlobalResolver(resolved, resolve_functions)

    # Append element to call stack and handle max recursion depth
    if parent_closure is not None:
        fid = id(f)
        if fid in parent_closure.callstack:
            raise DaceRecursionError(fid)
        if len(parent_closure.callstack) > Config.get(
                'frontend', 'implicit_recursion_depth'):
            raise TypeError(
                'Implicit (automatically parsed) recursion depth '
                'exceeded. Functions below this call will not be '
                'parsed. To change this setting, modify the value '
                '`frontend.implicit_recursion_depth` in .dace.conf')

        closure_resolver.closure.callstack = parent_closure.callstack + [fid]

    src_ast = closure_resolver.visit(src_ast)
    src_ast = LoopUnroller(resolved, src_file).visit(src_ast)
    src_ast = ConditionalCodeResolver(resolved).visit(src_ast)
    src_ast = DeadCodeEliminator().visit(src_ast)
    try:
        ctr = CallTreeResolver(closure_resolver.closure, resolved)
        ctr.visit(src_ast)
    except DaceRecursionError as ex:
        if id(f) == ex.fid:
            # Chain the original error so the recursive call site is not lost
            raise TypeError(
                'Parsing failed due to recursion in a data-centric '
                'context called from this function') from ex
        # Recursion was detected for a different function: propagate it
        # unchanged, keeping the original traceback.
        raise
    used_arrays = ArrayClosureResolver(closure_resolver.closure)
    used_arrays.visit(src_ast)

    # Filter out arrays that are not used after dead code elimination
    closure_resolver.closure.closure_arrays = {
        k: v
        for k, v in closure_resolver.closure.closure_arrays.items()
        if k in used_arrays.arrays
    }

    # Filter out callbacks that were removed after dead code elimination
    closure_resolver.closure.callbacks = {
        k: v
        for k, v in closure_resolver.closure.callbacks.items()
        if k in ctr.seen_calls
    }

    # Filter remaining global variables according to type and scoping rules
    program_globals = {
        k: v
        for k, v in global_vars.items() if k not in argtypes
    }

    # Fill in data descriptors from closure arrays
    argtypes.update({
        arrname: v[1]
        for arrname, v in closure_resolver.closure.closure_arrays.items()
    })

    # Combine nested closures with the current one
    closure_resolver.closure.combine_nested_closures()

    past = PreprocessedAST(src_file, src_line, src, src_ast, program_globals)

    return past, closure_resolver.closure
Beispiel #7
0
    def global_value_to_node(self,
                             value,
                             parent_node,
                             qualname: str,
                             recurse: bool = False,
                             detect_callables: bool = False):
        """
        Converts a global (closure) value into an AST node that replaces
        ``parent_node``, registering the value in ``self.closure`` as needed.

        :param value: The Python value to convert.
        :param parent_node: The AST node being replaced. Its ``ctx`` is reused
                            for list/tuple results, a deep copy of it is stored
                            in the ``oldnode`` attribute of constant results,
                            and its source location is copied onto the
                            returned node when it is not None.
        :param qualname: The qualified name of the value. Used as the key into
                         the closure dictionaries; list/tuple elements get
                         ``[i]`` appended to it.
        :param recurse: If False, list and tuple values are rejected (None is
                        returned). Elements of a list/tuple are converted with
                        the default ``recurse=False``, so only one level of
                        nesting is supported.
        :param detect_callables: If True, callable values are handled too:
                                 either redirected to their ``__call__`` (when
                                 that has ``__sdfg__``) or parsed as a DaCe
                                 program, with a registered callback as the
                                 fallback.
        :return: A new AST node, or None if the value cannot be converted (or
                 if a replacement exists for a callable).
        """
        # if recurse is false, we don't allow recursion into lists
        # this should not happen anyway; the globals dict should only contain
        # single "level" lists
        if not recurse and isinstance(value, (list, tuple)):
            # bail after more than one level of lists
            return None

        # NOTE(review): the list/tuple branches read ``parent_node.ctx``, which
        # would raise AttributeError if parent_node is None (a case the end of
        # this method explicitly handles) — confirm callers never pass None here.
        if isinstance(value, list):
            elts = [
                self.global_value_to_node(v,
                                          parent_node,
                                          qualname + f'[{i}]',
                                          detect_callables=detect_callables)
                for i, v in enumerate(value)
            ]
            # All elements must be convertible, otherwise the whole list is not
            if any(e is None for e in elts):
                return None
            newnode = ast.List(elts=elts, ctx=parent_node.ctx)
        elif isinstance(value, tuple):
            elts = [
                self.global_value_to_node(v,
                                          parent_node,
                                          qualname + f'[{i}]',
                                          detect_callables=detect_callables)
                for i, v in enumerate(value)
            ]
            # All elements must be convertible, otherwise the whole tuple is not
            if any(e is None for e in elts):
                return None
            newnode = ast.Tuple(elts=elts, ctx=parent_node.ctx)
        elif isinstance(value, symbolic.symbol):
            # Symbols resolve to the symbol name
            newnode = ast.Name(id=value.name, ctx=ast.Load())
        elif (dtypes.isconstant(value) or isinstance(value, SDFG)
              or hasattr(value, '__sdfg__')):
            # Could be a constant, an SDFG, or SDFG-convertible object
            if isinstance(value, SDFG) or hasattr(value, '__sdfg__'):
                self.closure.closure_sdfgs[qualname] = value
            else:
                self.closure.closure_constants[qualname] = value

            # Compatibility check since Python changed their AST nodes
            if sys.version_info >= (3, 8):
                newnode = ast.Constant(value=value, kind='')
            else:
                if value is None:
                    newnode = ast.NameConstant(value=None)
                elif isinstance(value, str):
                    newnode = ast.Str(s=value)
                else:
                    newnode = ast.Num(n=value)

            # Keep a copy of the replaced node on the new one
            newnode.oldnode = copy.deepcopy(parent_node)

        elif detect_callables and hasattr(value, '__call__') and hasattr(
                value.__call__, '__sdfg__'):
            # The object's __call__ is SDFG-convertible: convert that instead
            return self.global_value_to_node(value.__call__, parent_node,
                                             qualname, recurse,
                                             detect_callables)
        elif isinstance(value, numpy.ndarray):
            # Arrays need to be stored as a new name and fed as an argument
            # The same array object (by id) always maps to the same name
            if id(value) in self.closure.array_mapping:
                arrname = self.closure.array_mapping[id(value)]
            else:
                arrname = self._qualname_to_array_name(qualname)
                desc = data.create_datadescriptor(value)
                # Entry: (qualname, descriptor, getter, False). The getter
                # re-evaluates ``qualname`` with eval in ``self.globals`` when
                # called. NOTE(review): meaning of the trailing ``False`` is
                # not visible here — confirm against SDFGClosure.
                self.closure.closure_arrays[arrname] = (
                    qualname, desc, lambda: eval(qualname, self.globals),
                    False)
                self.closure.array_mapping[id(value)] = arrname

            newnode = ast.Name(id=arrname, ctx=ast.Load())
        elif detect_callables and callable(value):
            # Try parsing the function as a dace function/method
            # newnode stays None until the callback is registered, so a failure
            # before that point makes this method return None
            newnode = None
            try:
                from dace.frontend.python import parser  # Avoid import loops

                parent_object = None
                if hasattr(value, '__self__'):
                    parent_object = value.__self__

                # If it is a callable object
                if (not inspect.isfunction(value)
                        and not inspect.ismethod(value)
                        and not inspect.isbuiltin(value)
                        and hasattr(value, '__call__')):
                    parent_object = value
                    value = value.__call__

                # Replacements take precedence over auto-parsing
                # (errors while checking for a replacement are ignored)
                try:
                    if has_replacement(value, parent_object, parent_node):
                        return None
                except Exception:
                    pass

                # Store the handle to the original callable, in case parsing fails
                cbqualname = astutils.rname(parent_node)
                cbname = self._qualname_to_array_name(cbqualname, prefix='')
                self.closure.callbacks[cbname] = (cbqualname, value, False)

                # From this point on, any failure will result in a callback
                newnode = ast.Name(id=cbname, ctx=ast.Load())

                # Decorated or functions with missing source code
                sast, _, _, _ = astutils.function_to_ast(value)
                if len(sast.body[0].decorator_list) > 0:
                    return newnode

                parsed = parser.DaceProgram(value, [], {}, False,
                                            dtypes.DeviceType.CPU)
                # If method, add the first argument (which disappears due to
                # being a bound method) and the method's object
                if parent_object is not None:
                    parsed.methodobj = parent_object
                    parsed.objname = inspect.getfullargspec(value).args[0]

                res = self.global_value_to_node(parsed, parent_node, qualname,
                                                recurse, detect_callables)
                # Keep callback in callbacks in case of parsing failure
                # del self.closure.callbacks[cbname]
                return res
            except Exception:  # Parsing failed (almost any exception can occur)
                return newnode
        else:
            return None

        if parent_node is not None:
            return ast.copy_location(newnode, parent_node)
        else:
            return newnode