Example #1
0
    def _visit_CommentLine(self, stmt):
        """Convert a single-line comment node into a Pyccel AST node.

        Lines starting with ``#$`` are annotated comments:

        * ``#$ omp ...`` is handed to ``omp_parse``.
        * ``#$ acc ...`` is handed to ``acc_parse``.
        * ``#$ header ...`` is handed to ``hdr_parse``.
        * Any other ``#$`` annotation is reported as ``PYCCEL_INVALID_HEADER``
          and ``None`` is returned.

        Every other comment becomes a ``Comment`` holding the text after the
        leading ``#``, with left whitespace stripped.
        """
        text = stmt.s

        # Ordinary comment: drop the '#' marker and any leading blanks.
        if not text.startswith('#$'):
            return Comment(text[1:].lstrip())

        directive = text[2:].lstrip()

        if directive.startswith('omp'):
            return omp_parse(stmts=text)

        if directive.startswith('acc'):
            return acc_parse(stmts=text)

        if not directive.startswith('header'):
            errors.report(PYCCEL_INVALID_HEADER,
                          symbol=stmt,
                          severity='error')
            return None

        parsed = hdr_parse(stmts=text)
        if not isinstance(parsed, MetaVariable):
            parsed.set_fst(stmt)
            return parsed

        # A metavariable never reaches the semantic stage; it is only stored
        # so it can be used to modify the AST. An empty node takes its place.
        self._metavars[str(parsed.name)] = str(parsed.value)
        return EmptyNode()
Example #2
0
    def _visit_CommentMultiLine(self, stmt):
        """Convert a multi-line comment node into Pyccel AST node(s).

        The comment text is split on newlines and each line is handled like a
        single-line comment:

        * ``#$ omp`` and ``#$ acc`` lines go to their pragma parsers.
        * ``#$ header`` lines go to ``hdr_parse``.
        * Other ``#$`` lines are reported as ``PYCCEL_INVALID_HEADER`` and
          produce no node.
        * Plain lines become ``Comment`` nodes.

        Returns the lone node when exactly one was produced. Otherwise the
        nodes (possibly none) are returned wrapped in a ``CodeBlock``.
        """
        nodes = []
        for line in stmt.s.split('\n'):
            if not line.startswith('#$'):
                nodes.append(Comment(line[1:].lstrip()))
                continue

            directive = line[2:].lstrip()
            if directive.startswith('omp'):
                nodes.append(omp_parse(stmts=line))
            elif directive.startswith('acc'):
                nodes.append(acc_parse(stmts=line))
            elif directive.startswith('header'):
                parsed = hdr_parse(stmts=line)
                if isinstance(parsed, MetaVariable):
                    # Metavariables do not reach the semantic stage; record
                    # them so they can modify the AST, and emit a placeholder.
                    self._metavars[str(parsed.name)] = str(parsed.value)
                    parsed = EmptyNode()
                else:
                    parsed.set_fst(stmt)
                nodes.append(parsed)
            else:
                errors.report(PYCCEL_INVALID_HEADER,
                              symbol=stmt,
                              severity='error')

        return nodes[0] if len(nodes) == 1 else CodeBlock(nodes)
Example #3
0
    def _visit_FunctionDef(self, stmt):
        """Build a ``FunctionDef`` (or a stand-in node) from a ``def`` statement.

        Decorators drive most of the work:

        * ``@bypass``: the function is skipped and an ``EmptyNode`` returned.
        * ``@stack_array`` / ``@allow_negative_index``: their arguments are
          flattened into a tuple of strings.
        * ``@template``: each occurrence becomes a template header (parsed from
          ``#$ header template NAME(t1|t2|...)``) stored in ``templates`` under
          its name; a duplicated name only triggers a warning.
        * ``@types``: each occurrence (optionally ending with ``results=...``)
          becomes a function header; it is made static when the function name
          is in ``self.namespace.static_functions``.
        * ``@sympy`` / ``@python``: a ``SympyFunction`` / ``PythonFunction``
          holding the function source is registered via
          ``self.insert_function`` and an ``EmptyNode`` is returned.
        * ``@pure``, ``@elemental``, ``@private``: set the matching flags.
          ``@elemental`` with more than one argument is reported as an error.

        For each return position, the returned symbol is used as the result
        name when every ``return`` yields the same symbol there and that
        symbol is not an argument. Otherwise a fresh ``Out`` variable is
        created.

        Parameters
        ----------
        stmt : ast.FunctionDef
            Presumably a Python ``ast`` node; it is accessed through
            ``decorator_list``, ``args``, ``body``, ``lineno`` and ``col_offset``.

        Returns
        -------
        FunctionDef or EmptyNode
        """
        import ast  # only needed to strip @sympy/@python and recover source

        #  TODO check all inputs and which ones should be treated in stage 1 or 2

        name = self._visit(stmt.name)
        name = name.replace("'", '')

        arguments = self._visit(stmt.args)

        local_vars = []
        global_vars = []
        headers = []
        templates = {}
        hide = False
        kind = 'function'
        is_pure = False
        is_elemental = False
        is_private = False
        imports = []

        def fill_types(ls):
            """Turn decorator arguments into type-name strings.

            Symbols contribute their name and string literals their unquoted
            text. Anything else is reported as an error and skipped.
            """
            container = []
            for arg in ls:
                if isinstance(arg, Symbol):
                    container.append(arg.name)
                elif isinstance(arg, LiteralString):
                    container.append(str(arg).strip("'").strip('"'))
                else:
                    msg = 'Invalid argument of type {} passed to types decorator'.format(
                        type(arg))
                    errors.report(msg,
                                  bounding_box=(stmt.lineno, stmt.col_offset),
                                  severity='error')
            return container

        def strip_decorator_and_get_source(label):
            """Remove ``@label`` / ``@label(...)`` from ``stmt``; return its source.

            ``stmt`` is an ``ast`` node, so decorators live in
            ``decorator_list``; the previous ``stmt.decorators.pop()`` raised
            AttributeError, and would have removed the *last* decorator rather
            than the matching one.
            """
            def is_label(node):
                target = node.func if isinstance(node, ast.Call) else node
                return isinstance(target, ast.Name) and target.id == label

            stmt.decorator_list = [d for d in stmt.decorator_list
                                   if not is_label(d)]
            unparse = getattr(ast, 'unparse', None)  # available from Python 3.9
            # NOTE(review): without ast.unparse the fallback is the node's repr,
            # not real source code.
            return unparse(stmt) if unparse is not None else stmt.__str__()

        # Group visited decorators. Symbols are keyed by str(d), other nodes
        # by str(type(d)); repeated decorators accumulate in order.
        decorators = {}
        for d in self._visit(stmt.decorator_list):
            tmp_var = str(d) if isinstance(d, Symbol) else str(type(d))
            decorators.setdefault(tmp_var, []).append(d)

        if 'bypass' in decorators:
            return EmptyNode()

        if 'stack_array' in decorators:
            decorators['stack_array'] = tuple(
                str(b) for a in decorators['stack_array'] for b in a.args)

        if 'allow_negative_index' in decorators:
            decorators['allow_negative_index'] = tuple(
                str(b) for a in decorators['allow_negative_index']
                for b in a.args)

        # extract the templates
        if 'template' in decorators:
            for comb_types in decorators['template']:
                cache.clear_cache()
                args = comb_types.args
                if len(args) != 2:
                    msg = 'Number of Arguments provided to the template decorator is not valid'
                    errors.report(msg,
                                  symbol=comb_types,
                                  bounding_box=(stmt.lineno, stmt.col_offset),
                                  severity='error')
                    # The code below indexes args[0] and args[1]; skip the
                    # malformed decorator instead of risking an IndexError.
                    continue

                for i in args:
                    if isinstance(i, ValuedArgument) and i.name not in ('name',
                                                                        'types'):
                        msg = 'Argument provided to the template decorator is not valid'
                        errors.report(msg,
                                      symbol=comb_types,
                                      bounding_box=(stmt.lineno,
                                                    stmt.col_offset),
                                      severity='error')

                if all(isinstance(i, ValuedArgument) for i in args):
                    # Keyword form: name= and types= may come in either order.
                    if args[0].name == 'name':
                        tp_name, ls = args[0].value, args[1].value
                    else:
                        tp_name, ls = args[1].value, args[0].value
                else:
                    tp_name = args[0]
                    ls = args[1]
                    ls = ls.value if isinstance(ls, ValuedArgument) else ls
                try:
                    tp_name = str(tp_name)
                    ls = ls if isinstance(ls, PythonTuple) else list(ls)
                except TypeError:
                    msg = 'Argument provided to the template decorator is not valid'
                    errors.report(msg,
                                  symbol=comb_types,
                                  bounding_box=(stmt.lineno, stmt.col_offset),
                                  severity='fatal')

                types = fill_types(ls)

                txt = '#$ header template ' + str(tp_name)
                txt += '(' + '|'.join(types) + ')'
                if tp_name in templates:
                    msg = 'The template "{}" is duplicated'.format(tp_name)
                    errors.report(msg,
                                  bounding_box=(stmt.lineno, stmt.col_offset),
                                  severity='warning')

                templates[tp_name] = hdr_parse(stmts=txt)

        # extract the types to construct a header
        if 'types' in decorators:
            for comb_types in decorators['types']:

                cache.clear_cache()
                results = []
                ls = comb_types.args

                if len(ls) > 0 and isinstance(ls[-1], ValuedArgument):
                    arg_name = ls[-1].name
                    if arg_name != 'results':
                        msg = 'Argument "{}" provided to the types decorator is not valid'.format(
                            arg_name)
                        errors.report(msg,
                                      symbol=comb_types,
                                      bounding_box=(stmt.lineno,
                                                    stmt.col_offset),
                                      severity='error')
                    else:
                        container = ls[-1].value
                        container = container if isinstance(
                            container, PythonTuple) else [container]
                        results = fill_types(container)
                    types = fill_types(ls[:-1])
                else:
                    types = fill_types(ls)

                txt = '#$ header ' + name
                txt += '(' + ','.join(types) + ')'

                if results:
                    txt += ' results(' + ','.join(results) + ')'

                header = hdr_parse(stmts=txt)
                if name in self.namespace.static_functions:
                    header = header.to_static()
                headers += [header]

        if 'sympy' in decorators:
            # TODO maybe we should run pylint here
            func = SympyFunction(name, arguments, [],
                                 [strip_decorator_and_get_source('sympy')])
            func.set_fst(stmt)
            self.insert_function(func)
            return EmptyNode()

        if 'python' in decorators:
            # TODO maybe we should run pylint here
            func = PythonFunction(name, arguments, [],
                                  [strip_decorator_and_get_source('python')])
            func.set_fst(stmt)
            self.insert_function(func)
            return EmptyNode()

        body = self._visit(stmt.body)

        if 'pure' in decorators:
            is_pure = True

        if 'elemental' in decorators:
            is_elemental = True
            if len(arguments) > 1:
                errors.report(FORTRAN_ELEMENTAL_SINGLE_ARGUMENT,
                              symbol=decorators['elemental'],
                              bounding_box=(stmt.lineno, stmt.col_offset),
                              severity='error')

        if 'private' in decorators:
            is_private = True

        returns = [i.expr for i in _atomic(body, cls=Return)]
        assert all(len(i) == len(returns[0]) for i in returns)
        results = []
        result_counter = 1
        for values in zip(*returns):
            candidate = values[0]
            # A fresh output name is needed when the return statements disagree,
            # when the value is not a plain symbol, or when it is an argument.
            needs_new_name = (not all(candidate == v for v in values)
                              or not isinstance(candidate, Symbol)
                              or any(candidate.name == x.name for x in arguments))
            if needs_new_name:
                result_name, result_counter = create_variable(
                    self._used_names, prefix='Out', counter=result_counter)
                results.append(result_name)
            else:
                results.append(candidate)

        func = FunctionDef(name,
                           arguments,
                           results,
                           body,
                           local_vars=local_vars,
                           global_vars=global_vars,
                           hide=hide,
                           kind=kind,
                           is_pure=is_pure,
                           is_elemental=is_elemental,
                           is_private=is_private,
                           imports=imports,
                           decorators=decorators,
                           headers=headers,
                           templates=templates)

        func.set_fst(stmt)
        return func