示例#1
0
    def _CastXML(self, node: ET.Element):  # pylint: disable=invalid-name
        """Resolve every relevant type reachable from the CastXML root ``node``.

        Records the file id, seeds the search with ``find_types`` and then runs
        resolution rounds for as long as ``_new_relevant_types`` is non-empty.
        Each round merges the pending types into ``relevant_types`` and calls
        ``transform_one`` on every relevant type that is not yet in
        ``resolved_types``. The round's results are merged into
        ``resolved_types`` only after the round finishes.

        NOTE(review): ``transform_one`` presumably adds newly discovered types
        to ``_new_relevant_types``, which is what triggers further rounds;
        confirm against its implementation.

        Afterwards ``_fix_resolved_types`` post-processes the result and a
        summary of all resolved types is logged.
        """
        self.file_id = self._determine_file_id(node)
        self.find_types(node)

        while self._new_relevant_types:
            pending = self._new_relevant_types
            _LOG.debug('there are %i new relevant types', len(pending))
            self.relevant_types.update(pending)
            self._new_relevant_types = {}

            # Collect this round's results separately; resolved_types is only
            # extended once the whole round is done.
            round_results = {}
            for type_id, type_node in self.relevant_types.items():
                if type_id not in self.resolved_types:
                    resolved = self.transform_one(type_node)
                    round_results[type_id] = resolved
                    _LOG.debug('resolved %s into %s',
                               ET.tostring(type_node).decode().rstrip(),
                               horast.unparse(resolved))
            self.resolved_types.update(round_results)

        self._fix_resolved_types(self.resolved_types)

        summary = {type_id: horast.unparse(resolved_type).strip()
                   for type_id, resolved_type in self.resolved_types.items()}
        _LOG.info('type resolution complete: %s', pprint.pformat(summary))
示例#2
0
    def _Call(self, t):
        """Unparse a call expression as C++.

        Three cases are handled:

        * ``np.zeros(args..., dtype=T)`` becomes ``std::valarray<T>(args...)``
          and registers the ``valarray`` include.
        * ``print(a, b)`` becomes ``std::cout << a << b`` and registers the
          ``iostream`` include.
        * Any other call is delegated to the base class.

        Keyword arguments are passed to ``_unsupported_syntax``, except for
        the single required ``dtype`` keyword of ``np.zeros``.
        """
        func_name = horast.unparse(t.func).strip()

        if func_name == 'np.zeros':
            self._includes['valarray'] = True
            self.write('std::valarray<')
            assert len(t.keywords) == 1, len(t.keywords)
            assert t.keywords[0].arg == 'dtype'
            self.dispatch_type(t.keywords[0].value)
            self.write('>(')
            for index, arg in enumerate(t.args):
                if index:
                    self.write(",")
                self.dispatch(arg)
            # Close the constructor call opened by '>(' above; without this the
            # emitted C++ expression is left unbalanced.
            self.write(')')
            return

        if t.keywords:
            self._unsupported_syntax(t, 'with keyword arguments')

        if func_name == 'print':
            self._includes['iostream'] = True
            self.write('std::cout << ')
            for index, arg in enumerate(itertools.chain(t.args, t.keywords)):
                if index:
                    self.write(" << ")
                self.dispatch(arg)
            return

        super()._Call(t)
示例#3
0
def inline(target_function, inlined_function, globals_=None) -> object:
    """Inline all calls to given inlined function within the target.

    Can be used as decorator.

    :param target_function: function in which calls are inlined
    :param inlined_function: function whose calls get inlined into the target
    :param globals_: namespace passed to ``inline_syntax`` and used as the
        globals when executing the generated code. When None, a namespace
        holding only ``__builtins__`` is used for execution.
    :returns: a new function object, compiled from the inlined source, with
        the same ``__name__`` as ``target_function``
    """
    assert isinstance(target_function, types.FunctionType)
    assert isinstance(inlined_function, types.FunctionType)
    language = Language.find('Python 3')
    parser = Parser.find(language)()
    target_code = CodeReader.read_function(target_function)
    inlined_code = CodeReader.read_function(inlined_function)
    target_syntax = parser.parse(target_code).body[0]
    inlined_syntax = parser.parse(inlined_code).body[0]
    target_inlined_syntax = inline_syntax(target_syntax, inlined_syntax, globals_=globals_,
                                          verbose=False)
    target_inlined_code = horast.unparse(target_inlined_syntax).lstrip()

    # TODO: this leaves garbage behind in /tmp/ but is needed by subsequent transpiler passes
    # (delete=False keeps the file; its path is also recorded as the code object's filename).
    with tempfile.NamedTemporaryFile(suffix='.py', delete=False) as output_file:
        target_inlined_path = pathlib.Path(output_file.name)
    # Write only after our handle is closed: opening a still-open
    # NamedTemporaryFile a second time by name is not supported on Windows.
    code_writer = CodeWriter('.py')
    code_writer.write_file(target_inlined_code, target_inlined_path)

    code_obj = compile(target_inlined_code, filename=str(target_inlined_path), mode='exec')
    if globals_ is None:
        globals_ = {'__builtins__': globals()['__builtins__']}
    locals_ = {}
    # Executing the generated module-level code defines the function in locals_.
    eval_result = eval(code_obj, globals_, locals_)
    assert eval_result is None, eval_result
    assert target_function.__name__ in locals_
    target_inlined_function = locals_[target_function.__name__]
    assert isinstance(target_inlined_function, types.FunctionType)
    return target_inlined_function
示例#4
0
    def dispatch_var_type(self, tree):
        """Write the Fortran declaration type for the annotation ``tree``.

        Supported forms:

        * Annotations whose unparsed text is a key of
          ``PYTHON_FORTRAN_TYPE_PAIRS``. These are written as the mapped type
          name, followed by ``*precision`` when a precision is given.
        * Array annotations accepted by ``_match_array``. These are written as
          ``<element type>, dimension(<shape>)``. The element type is the
          second subscript element. The shape is ``:`` when only two elements
          are given; otherwise it is the third element. When
          ``self._context_input_args`` is set, a shape tuple ``(n, m, ...)`` is
          rewritten to 0-based bounds ``(0:n-1, 0:m-1, ...)``.
        * IO annotations accepted by ``_match_io``, written as ``integer``.
        * ``type(...)`` calls, which are dispatched unchanged.

        :raises NotImplementedError: for any other annotation
        """
        key = horast.unparse(tree).strip()
        if key in PYTHON_FORTRAN_TYPE_PAIRS:
            fortran_type, precision = PYTHON_FORTRAN_TYPE_PAIRS[key]
            self.write(fortran_type)
            if precision is not None:
                self.write('*')
                self.write(str(precision))
            return

        if _match_array(tree):
            index = tree.slice
            assert isinstance(index,
                              typed_ast3.Index), typed_astunparse.dump(tree)
            params = index.value
            assert isinstance(params, typed_ast3.Tuple)
            assert len(params.elts) in (2, 3), params.elts
            parts = params.elts
            self.dispatch_var_type(parts[1])
            self.write(', ')
            self.write('dimension(')
            if len(parts) == 2:
                self.write(':')
            elif not self._context_input_args:
                self.dispatch(parts[2])
            else:
                shape = parts[2]
                assert isinstance(shape, typed_ast3.Tuple)
                _LOG.warning('coercing indices of %i dims to 0-based',
                             len(shape.elts))
                bounds = []
                for extent in shape.elts:
                    if isinstance(extent, typed_ast3.Num):
                        assert isinstance(extent.n, int)
                        last_index = typed_ast3.Num(n=extent.n - 1)
                    else:
                        assert isinstance(extent, typed_ast3.Name)
                        last_index = typed_ast3.BinOp(left=extent,
                                                      op=typed_ast3.Sub(),
                                                      right=typed_ast3.Num(n=1))
                    bounds.append(typed_ast3.Slice(lower=typed_ast3.Num(n=0),
                                                   upper=last_index,
                                                   step=None))
                self.dispatch(typed_ast3.Tuple(elts=bounds,
                                               ctx=typed_ast3.Load()))
            self.write(')')
            return

        if _match_io(tree):
            self.write('integer')
            return

        is_type_call = isinstance(tree, typed_ast3.Call) \
            and isinstance(tree.func, typed_ast3.Name) \
            and tree.func.id == 'type'
        if is_type_call:
            self.dispatch(tree)
            return

        raise NotImplementedError('not yet implemented: {}'.format(
            typed_astunparse.dump(tree)))
示例#5
0
def tree_to_str(tree: Module) -> str:
    """Render the AST module ``tree`` as autopep8-formatted source code.

    Missing location attributes are filled in with
    ``fix_missing_locations`` first. That call presumably modifies ``tree``
    in place, as ``ast.fix_missing_locations`` does. The unparsed code is
    then passed through ``autopep8.fix_code`` at aggressiveness level 1.
    """
    # TODO: running validator.visit(tree) here fails -- investigate why.
    fix_missing_locations(tree)
    return autopep8.fix_code(unparse(tree), options={'aggressive': 1})
示例#6
0
 def dispatch_type(self, type_hint):
     """Write the C++ spelling of the type annotation ``type_hint``.

     Mappings:

     * Generic types are written as ``Container<Arg>``. When the container
       attribute appears in ``CPP_GENERIC_TYPE_INCLUDES``, its header is
       marked in ``self._includes``.
     * Pointer annotations are written as ``T*``, with ``str`` mapped to
       ``char``.
     * ``name.attr`` annotations are looked up in ``PY_TO_CPP_TYPES``.
     * ``None`` is written as ``void``.
     * Anything else is handed to ``dispatch``.

     Other subscript or attribute forms are logged as errors and passed to
     ``_unsupported_syntax``. If that call returns, ``dispatch`` is still
     invoked on the node.
     """
     _LOG.debug('dispatching type hint %s', type_hint)

     if is_generic_type(type_hint):
         container = type_hint.value.attr
         if container in CPP_GENERIC_TYPE_INCLUDES:
             self._includes[CPP_GENERIC_TYPE_INCLUDES[container]] = True
         self.dispatch_type(type_hint.value)
         self.write("<")
         self.dispatch_type(type_hint.slice)
         self.write(">")
         return

     if is_pointer(type_hint):
         index = type_hint.slice
         assert isinstance(index, typed_ast3.Index), type(index)
         pointee = index.value
         points_to_str = isinstance(pointee, typed_ast3.Name) and pointee.id == 'str'
         if points_to_str:
             self.write('char')
         else:
             self.dispatch_type(pointee)
         self.write('*')
         return

     if isinstance(type_hint, typed_ast3.Subscript):
         _LOG.error('encountered unsupported subscript form: %s',
                    horast.unparse(type_hint).strip())
         self._unsupported_syntax(type_hint)
     elif isinstance(type_hint, typed_ast3.Attribute):
         if isinstance(type_hint.value, typed_ast3.Name):
             self.write(PY_TO_CPP_TYPES[horast.unparse(type_hint).strip()])
             return
         _LOG.error('encountered unsupported attribute form: %s',
                    horast.unparse(type_hint).strip())
         self._unsupported_syntax(type_hint)
     elif isinstance(type_hint, typed_ast3.NameConstant):
         assert type_hint.value is None
         self.write('void')
         return

     self.dispatch(type_hint)
示例#7
0
def unparsing_unsupported(language_name: str, syntax, comment: str = None, error: bool = True):
    """Report that ``syntax`` cannot be unparsed into ``language_name``.

    The report is always logged as an error. When ``error`` is true, the same
    message is also raised as a SyntaxError.

    :param language_name: name of the target language, used in the message
    :param syntax: AST node that could not be handled
    :param comment: optional detail appended (after a space) to the node class name
    :param error: whether to raise after logging
    :raises SyntaxError: when ``error`` is true
    """
    unparsed = 'invalid'
    try:
        unparsed = '"""{}"""'.format(horast.unparse(syntax).strip())
    except AttributeError:
        # Best effort: unparse may raise AttributeError for some nodes, in
        # which case the 'invalid' placeholder is reported instead.
        pass
    # A missing comment must contribute nothing; previously it was rendered as
    # the literal text "None" right after the class name.
    suffix = '' if comment is None else ' ' + comment
    message = 'unparsing {}{} like """{}""" ({} in Python) is unsupported for {}'.format(
        syntax.__class__.__name__, suffix, horast.dump(syntax), unparsed, language_name)
    _LOG.error('%s', message)
    if error:
        raise SyntaxError(message)
示例#8
0
 def dispatch_type(self, type_hint):
     """Write the C++ spelling of the type annotation ``type_hint``.

     Mappings:

     * ``name.attr[...]`` subscripts write the ``PY_TO_CPP_TYPES`` entry for
       ``name.attr``. For ``st.ndarray`` an (currently empty) template argument
       list ``<>`` is appended.
     * ``name.attr`` attributes write their ``PY_TO_CPP_TYPES`` entry.
     * ``None`` is written as ``void``.
     * Anything else is handed to ``dispatch``.

     Unsupported subscript/attribute forms are passed to ``_unsupported_syntax``.
     """
     _LOG.debug('dispatching type hint %s', type_hint)
     if isinstance(type_hint, typed_ast3.Subscript):
         if isinstance(type_hint.value, typed_ast3.Attribute) \
                 and isinstance(type_hint.value.value, typed_ast3.Name):
             unparsed = horast.unparse(type_hint.value).strip()
             self.write(PY_TO_CPP_TYPES[unparsed])
             if unparsed == 'st.ndarray':
                 # TODO: emit template arguments derived from type_hint.slice;
                 # until then the argument list is written empty.
                 self.write('<')
                 self.write('>')
             return
         self._unsupported_syntax(type_hint)
     if isinstance(type_hint, typed_ast3.Attribute):
         if isinstance(type_hint.value, typed_ast3.Name):
             unparsed = horast.unparse(type_hint).strip()
             self.write(PY_TO_CPP_TYPES[unparsed])
             return
         self._unsupported_syntax(type_hint)
     if isinstance(type_hint, typed_ast3.NameConstant):
         assert type_hint.value is None
         self.write('void')
         return
     self.dispatch(type_hint)
 def test_inline_syntax(self):
     """Check ``inline_syntax`` against every entry of ``INLINING_EXAMPLES``.

     Each key is a ``(target, inlined)`` pair of functions and each value is a
     reference function. The reference source, with every ``_inlined(``
     replaced by ``(``, must equal the unparsed inlining result, ignoring
     leading whitespace. The replacement presumably maps the reference
     function's name back to the target's name.
     """
     language = Language.find('Python 3')
     parser = Parser.find(language)()
     for (target, inlined), target_inlined in INLINING_EXAMPLES.items():
         with self.subTest(target=target, inlined=inlined):
             # Read and parse inside the subtest so that one broken example is
             # reported on its own instead of aborting the remaining ones.
             target_code = CodeReader.read_function(target)
             inlined_code = CodeReader.read_function(inlined)
             reference_code = CodeReader.read_function(target_inlined)
             target_syntax = parser.parse(target_code).body[0]
             inlined_syntax = parser.parse(inlined_code).body[0]
             target_inlined_syntax = inline_syntax(target_syntax, inlined_syntax, verbose=False)
             target_inlined_code = horast.unparse(target_inlined_syntax)
             # Diagnostic output only; it is not a warning condition.
             _LOG.debug('%s', target_inlined_code)
             self.assertEqual(reference_code.replace('_inlined(', '(').lstrip(),
                              target_inlined_code.lstrip())
示例#10
0
def transform_code(code: str, xformer: "ASTMigrator") -> str:
    """Apply a transformer to a given chunk of source code.

    The code is parsed with ``horast``, falling back to the standard ``ast``
    module. ``xformer.scan_ast`` then yields the interesting nodes. Each match
    is rewritten with ``xformer.transform_match``. Every statement that
    encloses at least one match is unparsed and spliced back into the
    original text, so the surrounding source keeps its formatting.

    :param code: source code to transform
    :param xformer: transformer providing ``scan_ast`` and ``transform_match``
    :returns: the transformed source code
    :raises CantParseException: if neither parser accepts ``code``
    """
    import io

    # Split only where Python itself ends a line (\n, \r\n, \r).
    # str.splitlines() also splits on form feeds and other Unicode separators,
    # which would desynchronise these offsets from the AST's ``lineno``.
    lines = io.StringIO(code, newline='').readlines()
    line_ends = list(accumulate(len(line) for line in lines))
    line_starts = [0] + line_ends[:-1]

    try:
        tree = horast.parse(code)
    except Exception:
        # fall back to the standard-library parser
        try:
            tree = ast.parse(code)
        except Exception as e:
            raise CantParseException(str(e), code) from e

    matched = list(xformer.scan_ast(tree))

    astmonkey.transformers.ParentChildNodeTransformer().visit(tree)

    def node_to_code_offset(node, use_col_offset=True):
        """Return the character offset into the original ``code`` where ``node`` starts."""
        offset = line_starts[node.lineno - 1]
        if use_col_offset:
            # AST col_offset counts UTF-8 bytes, not characters; convert it so
            # that lines containing non-ASCII text are sliced correctly.
            prefix = lines[node.lineno - 1].encode('utf-8')[:node.col_offset]
            offset += len(prefix.decode('utf-8', errors='ignore'))
        return offset

    def position(node):
        return node.lineno, node.col_offset

    # Apply every rewrite first, remembering each enclosing statement once.
    # Splicing one statement per match would re-splice a statement that holds
    # several matches using offsets that no longer match the edited text.
    statements = []
    seen = set()
    for match in sorted(matched, key=position, reverse=True):
        xformer.transform_match(match)
        parent_statement = find_parent_statement(match)
        if id(parent_statement) not in seen:
            seen.add(id(parent_statement))
            statements.append(parent_statement)

    # Splice in reverse source order so that offsets of statements still to
    # be processed (all earlier in the text) remain valid.
    for statement in sorted(statements, key=position, reverse=True):
        next_statement = find_next_sibling(statement)

        code_start = node_to_code_offset(statement)
        if next_statement:
            code_end = node_to_code_offset(next_statement, use_col_offset=False)
        else:
            code_end = len(code)

        new_code = horast.unparse(statement).strip()
        code = code[:code_start] + new_code + "\n" + code[code_end:]

    return code
示例#11
0
 def _Subscript(self, t):
     """Unparse a subscript expression as Fortran.

     Handled forms:

     * A subscripted value whose unparsed text is in
       ``PYTHON_FORTRAN_INTRINSICS`` is either rewritten by the mapped
       callable or renamed to the mapped name (on a shallow copy).
     * ``Fortran.file_handles[x]`` is written as just ``x``.
     * ``Fortran.TypeByNamePrefix[base, (r1, r2, ...)]`` is written as
       ``<base type> (r1, r2, ...)``, with the first and last characters of
       each range string dropped.
     * Any other subscript is written as ``value(slice)``.

     :raises NotImplementedError: for other ``Fortran.<attr>[...]`` forms
     """
     val = t.value
     unparsed_val = horast.unparse(val).strip()
     if unparsed_val in PYTHON_FORTRAN_INTRINSICS:
         new_val = PYTHON_FORTRAN_INTRINSICS[unparsed_val]
         if isinstance(new_val, collections.abc.Callable):
             self.dispatch(new_val(t))
             return
         t = copy.copy(t)
         t.value = typed_ast3.Name(id=new_val)
     if isinstance(val, typed_ast3.Attribute) and isinstance(val.value, typed_ast3.Name) \
             and val.value.id == 'Fortran':
         attr = val.attr
         if attr == 'file_handles':
             self.dispatch(t.slice)
         elif attr == 'TypeByNamePrefix':
             base_type, letter_ranges = t.slice.value.elts
             assert isinstance(letter_ranges,
                               typed_ast3.Tuple), type(letter_ranges)
             self.dispatch_var_type(base_type)
             self.write(' (')
             # interleave() ignores what its item callback returns, so the
             # callback has to write each range itself; previously the ranges
             # were computed and silently dropped.
             interleave(lambda: self.write(', '),
                        lambda letter_range: self.write(letter_range.s[1:-1]),
                        letter_ranges.elts)
             self.write(')')
         else:
             raise NotImplementedError(
                 'Fortran.{}[] cannot be handled yet.'.format(attr))
         return
     self.dispatch(t.value)
     self.write("(")
     # Every slice kind (index, slice, extended slice) is delegated to dispatch.
     self.dispatch(t.slice)
     self.write(")")
示例#12
0
 def unparse(self, tree: typed_ast.ast3.AST) -> str:
     """Return the source code produced by ``horast.unparse`` for ``tree``."""
     return horast.unparse(tree)
示例#13
0
    def _Call(self, t):
        """Unparse a call expression as Fortran.

        Handled forms:

        * Procedure calls and MPI calls are prefixed with ``call``.
        * MPI method calls are renamed through ``MPI_PYTHON_TO_FORTRAN``, and
          the leading components of the attribute chain are appended as
          arguments.
        * ``Fortran.file_handles[h].read/close(...)`` becomes
          ``read/close(h, ...)``.
        * ``format_label_<n>.format(...)`` becomes ``<n> format(...)``.
        * The methods ``.rstrip``/``.sum``/``.size``/``.shape`` become the
          intrinsics ``trim``/``count``/``size``/``size``, with the receiver
          passed as the first argument.
        * Names in ``PYTHON_FORTRAN_INTRINSICS`` are mapped to Fortran
          intrinsics.
        * ``print`` is written as ``print a, b, ...``.

        Rewrites are applied to copies, so the original call node and its
        argument list are left untouched.

        :raises NotImplementedError: for unknown ``Fortran.file_handles`` methods
            and unmapped ``np.*`` calls
        """
        def detached_copy(call):
            # copy.copy is shallow: without a fresh args list, the insertions
            # below would also modify the original call node in the tree.
            new_call = copy.copy(call)
            new_call.args = list(call.args)
            return new_call

        if getattr(t, 'fortran_metadata', {}).get('is_procedure_call', False):
            self.write('call ')
        func_name = horast.unparse(t.func).strip()
        if has_annotation(t, 'is_mpi_call'):
            self.write('call ')
            func_name = MPI_PYTHON_TO_FORTRAN[t.func.attr]
            # All but the last component of the attribute chain become
            # trailing arguments of the Fortran call.
            components = attribute_chain_components(t.func)[:-1]
            t = detached_copy(t)
            t.func = typed_ast3.Name(func_name, typed_ast3.Load())
            t.args += components
        elif func_name.startswith('Fortran.file_handles['):
            t = detached_copy(t)
            for suffix in ('read', 'close'):
                if func_name.endswith('].{}'.format(suffix)):
                    t.args.insert(0, t.func.value.slice.value)
                    t.func = typed_ast3.Name(id=suffix, ctx=typed_ast3.Load())
                    break
            else:
                raise NotImplementedError(func_name)
        elif func_name.endswith('.format'):
            t = detached_copy(t)
            prefix, _, label = t.func.value.id.rpartition('_')
            assert prefix == 'format_label', prefix
            self.write(label)
            self.write(' ')
            t.func = typed_ast3.Name(id='format', ctx=typed_ast3.Load())
        elif func_name.endswith('.rstrip'):
            t = detached_copy(t)
            t.args.insert(0, t.func.value)
            t.func = typed_ast3.Name(id='trim', ctx=typed_ast3.Load())
        elif func_name.endswith('.sum'):
            t = detached_copy(t)
            t.args.insert(0, t.func.value)
            t.func = typed_ast3.Name(id='count', ctx=typed_ast3.Load())
        elif func_name.endswith('.size'):
            _LOG.warning('assuming np.size()')
            t = detached_copy(t)
            t.args.insert(0, t.func.value)
            t.func = typed_ast3.Name(id='size', ctx=typed_ast3.Load())
        elif func_name.endswith('.shape'):
            _LOG.warning('assuming np.shape()')
            t = detached_copy(t)
            # Presumably converts a 0-based axis index to Fortran's 1-based dim
            # argument. Increment a copy so the node shared with the original
            # tree is not altered.
            dimension = copy.copy(t.args[0])
            dimension.n += 1
            t.args[0] = dimension
            t.args.insert(0, t.func.value)
            t.func = typed_ast3.Name(id='size', ctx=typed_ast3.Load())
        elif func_name in PYTHON_FORTRAN_INTRINSICS \
                and not getattr(t, 'fortran_metadata', {}).get('is_transformed', False):
            new_func = PYTHON_FORTRAN_INTRINSICS[func_name]
            if isinstance(new_func, collections.abc.Callable):
                self.dispatch(new_func(t))
                return
            t = copy.copy(t)
            t.func = typed_ast3.Name(id=new_func, ctx=typed_ast3.Load())
        elif func_name.startswith('np.'):
            raise NotImplementedError('not yet implemented: {}'.format(typed_astunparse.dump(t)))
        if func_name not in ('print',):
            super()._Call(t)
            return

        self.dispatch(t.func)
        self.write(' ')
        for index, arg in enumerate(itertools.chain(t.args, t.keywords)):
            if index:
                self.write(", ")
            self.dispatch(arg)