Beispiel #1
0
def compare(input_src, expected_output_src, transformer_class):
    """
    Testing utility. Takes the input source and transforms it with
    `transformer_class`. It then compares the output with the given
    reference and throws an exception if they don't match.

    This method also deals with name-mangling: both the input and the
    expected source are run through ``EncodeNames`` before being rendered
    with ``astor``, so the two sides are compared in the same encoding.

    :param input_src: source to transform; passed through ``unindent``
        before parsing.
    :param expected_output_src: reference source the transformed code must
        match line by line; also passed through ``unindent``.
    :param transformer_class: class instantiated with a
        ``naming.UniqueIdentifierFactory``; its ``checked_visit`` performs
        the transformation.
    :raises AssertionError: on the first differing line. The complete
        actual source is written to stderr before raising.
    """
    # itertools.izip_longest was renamed to zip_longest in Python 3.
    try:
        zip_longest = itertools.zip_longest
    except AttributeError:  # Python 2
        zip_longest = itertools.izip_longest

    uid = naming.UniqueIdentifierFactory()

    actual_root = ast.parse(unindent(input_src))
    EncodeNames().visit(actual_root)
    actual_root = transformer_class(uid).checked_visit(actual_root)
    actual_root = ast.fix_missing_locations(actual_root)
    # Fails early if the transformer produced a tree that cannot be compiled.
    compile(actual_root, "<string>", 'exec')
    actual_src = astor.to_source(actual_root)

    expected_root = ast.parse(unindent(expected_output_src))
    EncodeNames().visit(expected_root)
    expected_src = astor.to_source(expected_root)

    # zip_longest pads the shorter side with None, so a missing or extra
    # line is reported as a difference too.
    line_pairs = zip_longest(expected_src.splitlines(), actual_src.splitlines())
    for linenr, (expected_line, actual_line) in enumerate(line_pairs, 1):
        if expected_line != actual_line:
            sys.stderr.write(actual_src)
            sys.stderr.write("\n")
            raise AssertionError("Line %s differs. Expected %s but got %s." % (linenr, repr(expected_line), repr(actual_line)))
Beispiel #2
0
    def process(self, node, id_factory=None):
        """
        Apply every transformation step in ``self.steps`` to ``node``.

        Before each step the tree is annotated by ``scoping.ScopeAssigner``
        and ``scoping.ExtendedScopeAssigner``. The step class is then
        instantiated with the shared identifier factory, and the result of
        its ``visit`` replaces the tree.

        :param node: AST to transform.
        :param id_factory: identifier factory shared by all steps; a new
            ``naming.UniqueIdentifierFactory`` is created when it is falsy.
        :returns: the transformed AST after ``ast.fix_missing_locations``.
        """
        id_factory = id_factory or naming.UniqueIdentifierFactory()

        def apply_step(tree, transformer_cls):
            # Scopes are re-assigned before every step rather than once,
            # presumably because earlier steps may restructure the tree.
            for assigner_cls in (scoping.ScopeAssigner, scoping.ExtendedScopeAssigner):
                assigner_cls().visit(tree)
            return transformer_cls(id_factory).visit(tree)

        for transformer_cls in self.steps:
            node = apply_step(node, transformer_cls)

        return ast.fix_missing_locations(node)
Beispiel #3
0
    def process(self, node, id_factory=None):
        """
        Apply every transformation step in ``self.steps`` to ``node``,
        logging progress at DEBUG level.

        Before each step the tree is annotated by ``scoping.ScopeAssigner``
        and ``scoping.ExtendedScopeAssigner``. The step class is then
        instantiated with the shared identifier factory, and the result of
        its ``visit`` replaces the tree.

        :param node: AST to transform.
        :param id_factory: identifier factory shared by all steps; a new
            ``naming.UniqueIdentifierFactory`` is created when it is falsy.
        :returns: the transformed AST after ``ast.fix_missing_locations``.
        """
        import logging

        if not id_factory:
            id_factory = naming.UniqueIdentifierFactory()

        def process_step(node, step_class):
            scoping.ScopeAssigner().visit(node)
            scoping.ExtendedScopeAssigner().visit(node)

            step = step_class(id_factory)
            return step.visit(node)

        for step_class in self.steps:
            logger.debug("Applying step %s", step_class)
            node = process_step(node, step_class)
            # Rendering the whole tree back to source is expensive; only do
            # it when the DEBUG message would actually be emitted.
            if logger.isEnabledFor(logging.DEBUG):
                logger.debug("After %s:\n%s", step_class, astor.to_source(node))

        return ast.fix_missing_locations(node)
Beispiel #4
0
def translate_function(function, scheduler, saneitize=True):
    """
    Translates a function into a :class:`tasks.ScheduledCallable`.

    The source of ``function`` is obtained with
    :func:`inspect.getsourcelines` and parsed. It is then optionally run
    through :class:`saneitizer.Saneitizer` and translated with
    :class:`Translator`. The :class:`tasks.FunctionDefTask` found in the
    resulting graph is evaluated to obtain the callable.

    :param function: function to translate. Its source must contain exactly
        one function definition (decorators are dropped), it must not have
        default arguments, and it must not be a closure.
    :param scheduler: passed on to :class:`Translator`.
    :param saneitize: whether to run the saneitizer before translating.
    :returns: the ``'function'`` output of the evaluated
        :class:`tasks.FunctionDefTask`.
    :raises ValueError: if the function has default arguments, its module
        is unknown, it lives in ``__main__`` and no importable module name
        can be found for it, no function was translated, the number of
        defaults doesn't match, or it is a closure.
    """
    import logging

    def main_workaround(module_name):
        """
        If the function is inside __main__ we have a problem, since the
        workers will have a different module called __main__.
        So we make a best-effort attempt to find whether __main__ is also
        reachable as a regular module. Other module names are returned
        unchanged.

        :raises ValueError: if ``__main__`` cannot be mapped to an
            importable module.
        """
        if module_name != "__main__":
            return module_name  # we are fine.

        # See if we can find a sys.path entry that matches the location of
        # __main__. Without a __file__ there is nothing to match.
        main_file = getattr(sys.modules["__main__"], "__file__", None) or ""
        if main_file:
            candidates = {path for path in sys.path if main_file.startswith(path)}
            # Shortest prefixes first: they leave the longest, most
            # qualified dotted names.
            for candidate in sorted(candidates, key=len):

                # Try to create the absolute module name from the filename only.
                module_name = main_file[len(candidate):]
                if module_name.endswith(".py"):
                    module_name = module_name[:-3]
                if module_name.endswith(".pyc"):
                    module_name = module_name[:-4]
                module_name = module_name.replace("/", ".")
                module_name = module_name.replace("\\", ".")
                module_name = module_name.lstrip(".")

                # import_module("") raises ValueError, not ImportError, which
                # would escape with a misleading message.
                if not module_name:
                    continue

                # Check if it actually works.
                # NOTE(review): importing the script under a second name runs
                # its top-level code again.
                try:
                    module = importlib.import_module(module_name)
                except ImportError:
                    continue
                return module.__name__

        # we were unlucky.
        raise ValueError(
            "The functions in the __main__ module cannot be translated.")

    func_name = getattr(function, "__name__", repr(function))

    source = "".join(inspect.getsourcelines(function)[0])

    logger.info("Translating: \n%s", source)

    node = ast.parse(utils.unindent(source))

    # Remove decorators

    # TODO handle decorators properly
    assert len(node.body) == 1
    funcdef = node.body[0]
    assert isinstance(funcdef, ast.FunctionDef)
    funcdef.decorator_list = []

    if len(funcdef.args.defaults) != 0:
        # TODO add support
        raise ValueError(
            "Cannot translate %s: @schedule does not support functions with default arguments"
            % func_name)

    id_factory = naming.UniqueIdentifierFactory()

    if saneitize:
        makesane = saneitizer.Saneitizer()
        node = makesane.process(node, id_factory)

    module_name = getattr(function, "__module__", None)
    if not module_name:
        raise ValueError(
            "Cannot translate %s: The module in which it is defined is unknown."
            % func_name)
    module_name = main_workaround(module_name)

    # Rendering the tree back to source is expensive; skip it unless the
    # INFO message will actually be emitted.
    if logger.isEnabledFor(logging.INFO):
        import astor
        logger.info("Preprocessed source:\n%s", astor.to_source(node))

    translator = Translator(id_factory, scheduler, module_name)
    graph = translator.visit(node)

    def find_FunctionDefTask(graph):
        """Return the first :class:`tasks.FunctionDefTask` in ``graph``."""
        for tick in graph.get_all_ticks():
            task = graph.get_task(tick)
            if isinstance(task, tasks.FunctionDefTask):
                return task
        raise ValueError("No function was translated.")

    funcdeftask = find_FunctionDefTask(graph)

    defaults = function.__defaults__ or tuple()

    if funcdeftask.num_defaults != len(defaults):
        raise ValueError("Number of default arguments doesn't match.")

    if function.__closure__:
        raise ValueError("Translating closures currently not supported.")

    inputs = {"default_%s" % i: v for i, v in enumerate(defaults)}
    return funcdeftask.evaluate(inputs)['function']