def visit_Operator(self, node: Operator, *args, **kwargs) -> C.operator:
    """Translate an ``Operator`` value into a freshly built ``C`` operator node.

    Every supported ``Operator`` member has a ``C`` class of the same name, so
    the names are walked in a fixed order and the first one whose ``Operator``
    member compares equal (``==``) to *node* is instantiated from ``C``.

    Raises:
        Exception: if *node* matches none of the supported operators.
    """
    supported_names = (
        'Add', 'Sub', 'Mult', 'MatMult', 'Div', 'Mod', 'Pow',
        'LShift', 'RShift', 'BitOr', 'BitXor', 'BitAnd', 'FloorDiv',
    )
    for op_name in supported_names:
        # Attributes are looked up lazily, only up to the first match.
        if node == getattr(Operator, op_name):
            return getattr(C, op_name)()
    raise Exception(f'unknown Operator {node!r}')
示例#2
0
 def to_node(self):
     """Build a left-associative ``ast.BinOp`` chain for this term.

     ``self.left`` is the first operand. Each spelling in ``self.operator`` is
     paired with the ``self.right`` entry at the same position and folded onto
     the accumulated node from left to right.

     Returns:
         The AST of ``self.left`` when there are no operators, otherwise the
         nested ``ast.BinOp``.

     Raises:
         ValueError: if an operator spelling is not recognised, or if there are
             fewer operators than right-hand operands.
     """
     # Operator spellings accepted at this precedence level.
     op_types = {
         '*': ast.Mult,
         'times': ast.Mult,
         '/': ast.Div,
         'divided by': ast.Div,
         '%': ast.Mod,
         'modulo': ast.Mod,
         '@': ast.MatMult,
         '//': ast.FloorDiv,
     }
     node = self.left.to_node()
     if not self.operator:
         return node
     if len(self.operator) < len(self.right):
         raise ValueError(
             f'{len(self.right)} operands but only '
             f'{len(self.operator)} operators')
     for symbol, operand in zip(self.operator, self.right):
         op_type = op_types.get(symbol)
         if op_type is None:
             # Skipping the operator would silently drop its operand.
             raise ValueError(f'unknown term operator {symbol!r}')
         node = ast.BinOp(node, op_type(), operand.to_node())
     return node
示例#3
0
def AugAssign(draw):
    """Draw a random ``ast.AugAssign`` statement ``<name> <op>= <expression>``.

    ``draw`` is used like a Hypothesis draw callable (presumably this is a
    ``@composite`` strategy; the decorator is not visible here). The target
    is drawn from ``Name(ast.Store)`` and the value from ``expression()``.
    """
    # List each operator exactly once: a duplicate entry would give that
    # operator extra weight in ``sampled_from``.
    op = draw(
        sampled_from([
            ast.Add(),
            ast.Sub(),
            ast.Mult(),
            ast.Div(),
            ast.FloorDiv(),
            ast.Mod(),
            ast.Pow(),
            ast.LShift(),
            ast.RShift(),
            ast.BitOr(),
            ast.BitXor(),
            ast.BitAnd(),
            ast.MatMult(),
        ]))

    return ast.AugAssign(target=draw(Name(ast.Store)),
                         op=op,
                         value=draw(expression()))
示例#4
0
def BinOp(draw, expression) -> ast.BinOp:
    """Draw a random ``ast.BinOp`` whose two operands come from *expression*.

    ``draw`` is used like a Hypothesis draw callable (presumably this is a
    ``@composite`` strategy; the decorator is not visible here).
    """
    # List each operator exactly once: a duplicate entry would give that
    # operator extra weight in ``sampled_from``.
    op = draw(
        sampled_from([
            ast.Add(),
            ast.Sub(),
            ast.Mult(),
            ast.Div(),
            ast.FloorDiv(),
            ast.Mod(),
            ast.Pow(),
            ast.LShift(),
            ast.RShift(),
            ast.BitOr(),
            ast.BitXor(),
            ast.BitAnd(),
            ast.MatMult(),
        ]))

    # Exactly two operands are drawn, so unpack them directly.
    left, right = draw(lists(expression, min_size=2, max_size=2))

    return ast.BinOp(left, op, right)
示例#5
0
 def p_term_op5(self, p):
     ''' term_op : AT factor '''
     # The docstring above is a grammar production (presumably read by PLY
     # yacc), so it must stay verbatim. The rule yields an
     # [operator, operand] pair for '@'.
     operand = p[2]
     p[0] = [ast.MatMult(), operand]
示例#6
0
class Python35EnamlParser(Python34EnamlParser):
    """Enaml parser supporting Python 3.5 syntax.

    Main differences from base parser are:

    - support for matmult syntax
    - support for async/await syntax

    Notes
    -----
    Because the lexer turns await and async into names outside of async def
    blocks we do not need to check that async for, async with and await are
    used in the proper places. (will break for 3.7)

    """
    parser_id = '35'

    lexer = Python35EnamlLexer

    # Base-class tables extended with the entries introduced by Python 3.5.
    augassign_table = dict(
        list(Python34EnamlParser.augassign_table.items()) +
        [('@=', ast.MatMult())])

    _NOTIFICATION_DISALLOWED =\
        dict(list(Python34EnamlParser._NOTIFICATION_DISALLOWED.items()) +
             [(ast.AsyncFunctionDef, 'async function definition')])

    _DECL_FUNCDEF_DISALLOWED =\
        dict(list(Python34EnamlParser._DECL_FUNCDEF_DISALLOWED.items()) +
             [(ast.AsyncFunctionDef, 'async function definition')])

    def set_call_arguments(self, node, args):
        """Set the arguments for an ast.Call node.

        On Python 3.5+, the starargs and kwargs attributes do not exist
        anymore: ``*args`` is appended as an ``ast.Starred`` positional
        argument and ``**kwargs`` as an ``ast.keyword`` with ``arg=None``.

        Parameters
        ----------
        node : ast.Call
            Node whose arguments should be set.

        args : Arguments
            Arguments for the function call. Its lists are not modified.

        """
        # Copy so the lists owned by ``args`` are not mutated in place.
        # ``Load`` is a module-level name defined outside this view
        # (presumably a shared ast.Load() instance).
        pos_args = list(args.args or ())
        if args.starargs:
            pos_args.append(ast.Starred(value=args.starargs, ctx=Load))
        key_args = list(args.keywords or ())
        if args.kwargs:
            key_args.append(ast.keyword(arg=None, value=args.kwargs))
        node.args = pos_args
        node.keywords = key_args

    def p_test_or_star_new2(self, p):
        ''' test_or_star_new : star_expr '''
        p[0] = p[1]

    def p_augassign(self, p):
        ''' augassign : AMPEREQUAL
                      | CIRCUMFLEXEQUAL
                      | DOUBLESLASHEQUAL
                      | DOUBLESTAREQUAL
                      | LEFTSHIFTEQUAL
                      | MINUSEQUAL
                      | PERCENTEQUAL
                      | PLUSEQUAL
                      | RIGHTSHIFTEQUAL
                      | SLASHEQUAL
                      | STAREQUAL
                      | VBAREQUAL
                      | ATEQUAL '''
        # Only the grammar changes (adds ATEQUAL); the action is inherited.
        super(Python35EnamlParser, self).p_augassign(p)

    def p_term_op5(self, p):
        ''' term_op : AT factor '''
        p[0] = [ast.MatMult(), p[2]]

    def p_dosm_colon(self, p):
        ''' dosm_colon : DOUBLESTAR expr '''
        # A ``None`` key marks ``**mapping`` unpacking inside a dict display.
        p[0] = (None, p[2])

    def p_compound_stmt(self, p):
        ''' compound_stmt : if_stmt
                          | while_stmt
                          | for_stmt
                          | try_stmt
                          | with_stmt
                          | funcdef
                          | classdef
                          | decorated
                          | async_funcdef
                          | async_for_stmt
                          | async_with_stmt '''
        super(Python35EnamlParser, self).p_compound_stmt(p)

    def p_decorated(self, p):
        ''' decorated : decorators funcdef
                      | decorators classdef
                      | decorators async_funcdef'''
        target = p[2]
        target.decorator_list = p[1]
        p[0] = target

    def p_async_funcdef1(self, p):
        ''' async_funcdef : ASYNC funcdef '''
        p[0] = self._copy_to_async_node(p[2], ast.AsyncFunctionDef())

    def p_async_for_stmt(self, p):
        ''' async_for_stmt : ASYNC for_stmt '''
        p[0] = self._copy_to_async_node(p[2], ast.AsyncFor())

    def p_async_with_stmt(self, p):
        ''' async_with_stmt : ASYNC with_stmt '''
        p[0] = self._copy_to_async_node(p[2], ast.AsyncWith())

    def p_atom_expr3(self, p):
        ''' atom_expr : AWAIT atom '''
        p[0] = ast.Await(value=p[2])

    def p_atom_expr4(self, p):
        ''' atom_expr : AWAIT atom trailer_list '''
        # Thread each trailer (call / attribute / subscript) onto the
        # expression built so far, then await the outermost result.
        root = p[2]
        for node in p[3]:
            if isinstance(node, ast.Call):
                node.func = root
            elif isinstance(node, (ast.Attribute, ast.Subscript)):
                node.value = root
            else:
                raise TypeError('Unexpected trailer node: %s' % node)
            root = node
        # Await ``root`` rather than the loop variable, which would be
        # unbound for an empty trailer list.
        p[0] = ast.Await(value=root)

    def p_enamldef_suite_item(self, p):
        ''' enamldef_suite_item : enamldef_simple_item
                                | decl_funcdef
                                | async_decl_funcdef
                                | child_def
                                | template_inst '''
        p[0] = p[1]

    def p_child_def_suite_item(self, p):
        ''' child_def_suite_item : child_def_simple_item
                                 | decl_funcdef
                                 | async_decl_funcdef
                                 | child_def
                                 | template_inst '''
        p[0] = p[1]

    def p_async_decl_funcdef(self, p):
        ''' async_decl_funcdef : ASYNC decl_funcdef '''
        decl_funcdef = p[2]
        async_funcdef = self._copy_to_async_node(decl_funcdef.funcdef,
                                                 ast.AsyncFunctionDef())
        ast.fix_missing_locations(async_funcdef)
        # Skip validate because the original function was already validated
        async_decl_funcdef = enaml_ast.AsyncFuncDef()
        async_decl_funcdef.funcdef = async_funcdef
        async_decl_funcdef.lineno = decl_funcdef.lineno
        async_decl_funcdef.is_override = decl_funcdef.is_override
        p[0] = async_decl_funcdef

    @staticmethod
    def _copy_to_async_node(source, target):
        """Copy the fields and location of *source* onto *target*.

        Every name in ``source._fields`` plus ``lineno`` and ``col_offset``
        is copied, which turns a parsed ``def``/``for``/``with`` node into
        the matching async node.

        Returns
        -------
        target : ast.AST
            The updated *target* node.

        """
        for attr in tuple(source._fields) + ('lineno', 'col_offset'):
            setattr(target, attr, getattr(source, attr))
        return target
示例#7
0
class PythonASTWriter(ASTProcessor, StandardOps):
    """
    An AST processor, which translates a given SKAST node to a Python AST.

    >>> import ast
    >>> topy = PythonASTWriter()
    >>> print(ast.dump(topy(Identifier('x'))))
    Name(id='x', ctx=Load())
    """
    def Identifier(self, name):
        """Translate an identifier into an ``ast.Name`` load."""
        return ast.Name(id=name.id, ctx=ast.Load(), **_linearg)

    VectorIdentifier = Identifier

    def IndexedIdentifier(self, sub):
        """Translate ``id[index]`` into an ``ast.Subscript`` load."""
        return ast.Subscript(
            value=self(Identifier(sub.id)),
            slice=ast.Index(value=self(NumberConstant(sub.index))),
            ctx=ast.Load(),
            **_linearg)

    def NumberConstant(self, num):
        """Translate a scalar constant into a number literal."""
        return ast.Num(n=denumpyfy(num.value), **_linearg)

    def VectorConstant(self, vec):
        """Translate a constant vector into ``__np__.array([...])``."""
        result = ast.parse('__np__.array()', mode='eval').body
        result.args = [
            ast.List(elts=[
                ast.Num(n=denumpyfy(el), **_linearg) for el in vec.value
            ],
                     ctx=ast.Load(),
                     **_linearg)
        ]
        return result

    def MakeVector(self, mv):
        """Translate a vector of sub-expressions into ``__np__.array([...])``."""
        result = ast.parse('__np__.array()', mode='eval').body
        result.args = [
            ast.List(elts=[self(el) for el in mv.elems],
                     ctx=ast.Load(),
                     **_linearg)
        ]
        return result

    def MatrixConstant(self, mat):
        """Translate a constant matrix into a nested ``__np__.array`` literal."""
        result = ast.parse('__np__.array()', mode='eval').body
        result.args = [
            ast.List(elts=[
                ast.List(
                    elts=[ast.Num(n=denumpyfy(el), **_linearg) for el in row],
                    ctx=ast.Load(),
                    **_linearg) for row in mat.value
            ],
                     ctx=ast.Load(),
                     **_linearg)
        ]
        return result

    def UnaryFunc(self, node, **kw):
        """Unary minus becomes ``ast.UnaryOp``; other functions become calls."""
        if isinstance(node.op, USub):
            return ast.UnaryOp(op=self(node.op),
                               operand=self(node.arg),
                               **_linearg)
        else:
            return ast.Call(func=self(node.op),
                            args=[self(node.arg)],
                            keywords=[],
                            **_linearg)

    def BinOp(self, node, **kw):
        """Translate a binary SKAST node.

        Boolean (``IsBoolean``) operators become ``ast.Compare``, ``Max``
        becomes a two-argument call, and everything else an ``ast.BinOp``.
        """
        op, left, right = self(node.op), self(node.left), self(node.right)
        if isinstance(node.op, IsBoolean):
            return ast.Compare(left=left,
                               ops=[op],
                               comparators=[right],
                               **_linearg)
        elif isinstance(node.op, Max):
            # Reuse the already-translated operator and operands instead of
            # translating node.op/left/right a second time.
            return ast.Call(func=op,
                            args=[left, right],
                            keywords=[],
                            **_linearg)
        else:
            return ast.BinOp(op=op, left=left, right=right, **_linearg)

    def IfThenElse(self, node):
        """Translate a conditional into an ``ast.IfExp``."""
        return ast.IfExp(test=self(node.test),
                         body=self(node.iftrue),
                         orelse=self(node.iffalse),
                         **_linearg)

    def Let(self, node, **kw):
        """Translate a ``Let`` into an ``ast.Module``.

        Each definition becomes ``_def_<name> = <body>``, followed by a final
        ``__result__ = <expression body>`` assignment.
        """
        code = [
            ast.Assign(targets=[
                ast.Name(id='_def_' + defn.name, ctx=ast.Store(), **_linearg)
            ],
                       value=self(defn.body),
                       **_linearg) for defn in node.defs
        ]
        # Evaluate the expression body into a "__result__" variable
        code.append(
            ast.Assign(targets=[
                ast.Name(id='__result__', ctx=ast.Store(), **_linearg)
            ],
                       value=self(node.body),
                       **_linearg))
        # Compare version tuples: the string comparison
        # ``sys.version < '3.8'`` is wrong for 3.10+ ('3.10' < '3.8').
        return ast.Module(body=code,
                          **({} if sys.version_info < (3, 8) else {
                              "type_ignores": []
                          }))

    def Reference(self, ref):
        """Read back a ``Let`` definition through its ``_def_`` name."""
        return ast.Name(id='_def_' + ref.name, ctx=ast.Load(), **_linearg)

    TypedReference = Reference

    # Functions: emitted through these identifier names (see UnaryFunc/BinOp)
    Exp = _ident('__exp__')
    Sqrt = _ident('__sqrt__')
    Log = _ident('__log__')
    Step = _ident('__step__')
    VecSum = _ident('__sum__')
    ArgMax = _ident('__argmax__')
    Sigmoid = _ident('__sigmoid__')
    Softmax = _ident('__softmax__')
    VecMax = _ident('__vecmax__')
    Max = _ident('__max__')
    Abs = _ident('__abs__')

    # Operators
    Mul = _is(ast.Mult())
    Div = _is(ast.Div())
    Add = _is(ast.Add())
    Sub = _is(ast.Sub())
    USub = _is(ast.USub())
    DotProduct = _is(ast.MatMult())
    MatVecProduct = DotProduct

    # Predicates
    LtEq = _is(ast.LtE())
    Eq = _is(ast.Eq())
示例#8
0
 def test_get_mat_mult(self):
     """``astor.get_op_symbol`` renders ``ast.MatMult`` as ``'@'``."""
     rendered = astor.get_op_symbol(ast.MatMult())
     self.assertEqual('@', rendered)
示例#9
0
 def __matmul__(self, other: Union[T, Expr[T]]) -> BinOp[T]:
     """Build the expression ``self @ other`` as a ``BinOp`` node."""
     matmul_op = ast.MatMult()
     return BinOp(self, matmul_op, other)
示例#10
0
    def _create_assign_lambda(s, o, lamb):
        """Turn ``<signal> //= lambda: <expr>`` into a generated update block.

        The source of *lamb* is re-parsed. Its ``//=`` statement is rebuilt as
        ``<signal> @= <expr>`` inside a generated function named
        ``_lambda__<signal>``. That function is compiled inside a wrapper
        ``closure(lambda_closure)`` that re-binds the lambda's free variables,
        so the generated block sees the same values as the lambda.

        Parameters:
            o: the signal being assigned (must be a ``Signal``); its ``repr``
                names the generated block and rebuilds the assignment target.
            lamb: a lambda taking no arguments whose body is the right-hand
                side.

        Returns:
            The generated block, after registering it with
            ``ComponentLevel1._update`` and ``s._cache_func_meta``.

        Raises:
            AssertionError: if *o* is not a ``Signal`` or the lambda source is
                not a single argument-free ``//=`` statement.
        """
        assert isinstance(
            o, Signal
        ), "You can only assign(//=) a lambda function to a Wire/InPort/OutPort."

        srcs, line = inspect.getsourcelines(lamb)

        src = compiled_re.sub(r'\2', ''.join(srcs)).lstrip(' ')
        root = ast.parse(src)
        assert isinstance(root, ast.Module) and len(
            root.body) == 1, "We only support single-statement lambda."

        root = root.body[0]
        assert isinstance(root, ast.AugAssign) and isinstance(
            root.op, ast.FloorDiv), "Expected a '//=' assignment of a lambda."

        # Shunning: here we need to use ast from repr(o), because root.target
        # can be "m.in_" in some cases where we actually know what m is but the
        # source code still captures "m"
        o_repr = repr(o)
        lhs = ast.parse(f"s{o_repr[len(repr(s)):]}").body[0].value
        lhs.ctx = ast.Store()
        rhs = root.value
        # We expect the lambda to have no argument:
        # {'args': [], 'vararg': None, 'kwonlyargs': [], 'kw_defaults': [], 'kwarg': None, 'defaults': []}
        assert isinstance(rhs, ast.Lambda) and not rhs.args.args and \
            rhs.args.vararg is None, \
            "The lambda shouldn't contain any argument."

        rhs = rhs.body

        # Compose a new and valid function based on the lambda's lhs and rhs
        # Note that we don't need to add those source code of closure var
        # assignment to linecache. To get the matching line number in the
        # error message, we set the line number of update block

        # The signal repr becomes part of an identifier: map '.', '[', ']'
        # and ':' to '_'.
        blk_name = "_lambda__{}".format(
            o_repr.translate(str.maketrans(".[]:", "____")))

        # ``posonlyargs`` (arguments) and ``type_ignores`` (Module) are
        # required by compile() on Python 3.8+ (at least through 3.11); older
        # versions just store the extra keyword as an attribute.
        lambda_upblk = ast.FunctionDef(
            name=blk_name,
            args=ast.arguments(posonlyargs=[],
                               args=[],
                               vararg=None,
                               kwonlyargs=[],
                               kw_defaults=[],
                               kwarg=None,
                               defaults=[]),
            body=[
                ast.AugAssign(target=lhs,
                              op=ast.MatMult(),
                              value=rhs,
                              lineno=2,
                              col_offset=6)
            ],
            decorator_list=[],
            returns=None,
            lineno=1,
            col_offset=4,
        )
        lambda_upblk_module = ast.Module(body=[lambda_upblk], type_ignores=[])

        # Manually wrap the lambda upblk with a closure function that adds the
        # desired variables to the closure of `_lambda__*`
        # We construct AST for the following function to add free variables in the
        # closure of the lambda function to the closure of the generated lambda
        # update block.
        #
        # def closure( lambda_closure ):
        #   <FreeVarName1> = lambda_closure[<Idx1>].cell_contents
        #   <FreeVarName2> = lambda_closure[<Idx2>].cell_contents
        #   ...
        #   <FreeVarNameN> = lambda_closure[<IdxN>].cell_contents
        #   def _lambda__<lambda_blk_name>():
        #     # the assignment statement appears here
        #   return _lambda__<lambda_blk_name>

        freevars = lamb.__code__.co_freevars
        rebind_freevars = [
            ast.Assign(
                targets=[
                    ast.Name(id=var,
                             ctx=ast.Store(),
                             lineno=1 + idx,
                             col_offset=2)
                ],
                value=ast.Attribute(
                    value=ast.Subscript(
                        value=ast.Name(
                            id='lambda_closure',
                            ctx=ast.Load(),
                            lineno=1 + idx,
                            col_offset=5 + len(var),
                        ),
                        # ast.Constant replaces ast.Num (removed in 3.14).
                        slice=ast.Index(value=ast.Constant(
                            value=idx,
                            lineno=1 + idx,
                            col_offset=19 + len(var),
                        )),
                        ctx=ast.Load(),
                        lineno=1 + idx,
                        col_offset=5 + len(var),
                    ),
                    attr='cell_contents',
                    ctx=ast.Load(),
                    lineno=1 + idx,
                    col_offset=5 + len(var),
                ),
                lineno=1 + idx,
                col_offset=2,
            ) for idx, var in enumerate(freevars)
        ]
        return_blk = ast.Return(
            value=ast.Name(
                id=blk_name,
                ctx=ast.Load(),
                lineno=4 + len(freevars),
                col_offset=9,
            ),
            lineno=4 + len(freevars),
            col_offset=2,
        )

        new_root = ast.Module(body=[
            ast.FunctionDef(
                name="closure",
                args=ast.arguments(posonlyargs=[],
                                   args=[
                                       ast.arg(arg="lambda_closure",
                                               annotation=None,
                                               lineno=1,
                                               col_offset=12)
                                   ],
                                   vararg=None,
                                   kwonlyargs=[],
                                   kw_defaults=[],
                                   kwarg=None,
                                   defaults=[]),
                body=rebind_freevars + [lambda_upblk] + [return_blk],
                decorator_list=[],
                returns=None,
                lineno=1,
                col_offset=0,
            )
        ], type_ignores=[])

        # In Python 3 we need to supply a dict as local to get the newly
        # compiled function from closure.
        # Then `closure(lamb.__closure__)` returns the lambda update block with
        # the correct free variables in its closure.

        dict_local = {}
        custom_exec(compile(new_root, blk_name, "exec"), lamb.__globals__,
                    dict_local)
        blk = dict_local['closure'](lamb.__closure__)

        # Add the source code to linecache for the compiled function

        new_src = "def {}():\n {}\n".format(blk_name, src.replace("//=", "@="))
        linecache.cache[blk_name] = (len(new_src), None, new_src.splitlines(),
                                     blk_name)

        ComponentLevel1._update(s, blk)

        # This caching here does no caching because the block name contains
        # the signal name intentionally to avoid conflicts. With //= it is
        # more possible than normal update block to have conflicts:
        # if param == 1:  s.out //= s.in_ + 1
        # else:           s.out //= s.out + 100
        # Here these two blocks will implicity have the same name but they
        # have different contents based on different param.
        # So the cache call here is just to reuse the existing interface to
        # register the AST/src of the generated block for elaborate or passes
        # to use.
        s._cache_func_meta(blk,
                           is_update_ff=False,
                           given=("".join(srcs), lambda_upblk_module, line,
                                  inspect.getsourcefile(lamb)))
        return blk