Beispiel #1
0
 def get_arp_for_context(self, node: ast.AST, ctx: Context):
     """
     Returns the permission amount denoted by a bare read permission in the
     current translation context.

     Inside a pure method this is a wildcard permission. Everywhere else ARP
     support must be enabled (``ctx.arp``), otherwise UnsupportedException is
     raised. With a current thread object the amount is the application of
     ``rd_token_fresh`` (when translating a thread start) or ``rd_token`` to
     that object; without one it is the application of ``rd``.
     """
     function = ctx.actual_function
     in_pure_method = (function
                       and isinstance(function, PythonMethod)
                       and function.pure)
     if in_pure_method:
         return self.viper.WildcardPerm(self.to_position(node, ctx),
                                        self.no_info(ctx))
     if not ctx.arp:
         raise UnsupportedException(
             node, 'ARP not supported. Use --arp flag.')

     thread = ctx.current_thread_object
     if thread is None:
         return self.viper.FuncApp('rd', [],
                                   self.to_position(node, ctx),
                                   self.no_info(ctx), self.viper.Perm,
                                   {})

     token_decl = self.viper.LocalVarDecl(
         'tk', self.viper.Ref, self.to_position(node, ctx),
         self.no_info(ctx))
     token_func = 'rd_token_fresh' if ctx.is_thread_start else 'rd_token'
     return self.viper.FuncApp(token_func, [thread],
                               self.to_position(node, ctx),
                               self.no_info(ctx),
                               self.viper.Perm, [token_decl])
Beispiel #2
0
 def translate_builtin_predicate(self, node: ast.Call, perm: Expr,
                                 args: List[Expr], ctx: Context) -> Expr:
     """
     Translates an access with permission ``perm`` to a built-in predicate.

     ``list_pred``, ``set_pred`` and ``dict_pred`` are encoded as field
     permissions on ``args[0]``. ``MayStart`` and ``ThreadPost`` are
     delegated to their dedicated translators. Any other name raises
     UnsupportedException.
     """
     pred_name = node.func.id
     ref_seq_type = self.viper.SeqType(self.viper.Ref)
     ref_set_type = self.viper.SetType(self.viper.Ref)
     position = self.to_position(node, ctx)

     def field_perm(field_name, field_type):
         # Permission to the given field of the predicate's first argument.
         return self._get_field_perm(field_name, field_type, perm, args[0],
                                     position, ctx)

     if pred_name == 'list_pred':
         # field list_acc : Seq[Ref]
         return field_perm('list_acc', ref_seq_type)
     if pred_name == 'set_pred':
         # field set_acc : Set[Ref]
         return field_perm('set_acc', ref_set_type)
     if pred_name == 'dict_pred':
         # field dict_acc : Set[Ref] && dict_acc2 : Ref
         keys_acc = field_perm('dict_acc', ref_set_type)
         extra_acc = field_perm('dict_acc2', self.viper.Ref)
         return self.viper.And(keys_acc, extra_acc, position,
                               self.no_info(ctx))
     if pred_name == 'MayStart':
         return self.translate_may_start(node, args, perm, ctx)
     if pred_name == 'ThreadPost':
         return self.translate_thread_post(node, args, perm, ctx)
     raise UnsupportedException(node)
Beispiel #3
0
    def visit_Call(self, node: ast.Call) -> None:
        """Parse IO operation properties.

        Only property calls whose name is in ``IO_OPERATION_PROPERTY_FUNCS``
        are treated specially; currently ``Terminates`` and
        ``TerminationMeasure`` are recognized. A property call must appear
        as an expression statement directly in the body of the current IO
        operation. Each property may only be set once. The property's
        arguments are then visited with ``self._in_property`` set. For any
        other call, only its positional arguments are visited.
        """
        assert self._current_io_operation is not None
        assert self._current_node is not None

        func = node.func
        is_property = (isinstance(func, ast.Name)
                       and func.id in IO_OPERATION_PROPERTY_FUNCS)
        if not is_property:
            for argument in node.args:
                self.visit(argument)
            return

        placed_in_body = any(
            isinstance(stmt, ast.Expr) and stmt.value == node
            for stmt in self._current_node.body)
        if not placed_in_body:
            self._raise_invalid_operation('misplaced_property', node)

        operation = self._current_io_operation
        first_arg = node.args[0]
        if func.id == 'Terminates':
            accepted = operation.set_terminates(first_arg)
        elif func.id == 'TerminationMeasure':
            accepted = operation.set_termination_measure(first_arg)
        else:
            raise UnsupportedException(node, 'Unsupported property type.')
        # The setters report False when the property was already set.
        if not accepted:
            self._raise_invalid_operation('duplicate_property', node)

        self._in_property = True
        for argument in node.args:
            self.visit(argument)
        self._in_property = False
Beispiel #4
0
 def translate_contract_Call(self, node: ast.Call, ctx: Context) -> Expr:
     """
     Translates a call to a contract wrapper function (one listed in
     ``CONTRACT_WRAPPER_FUNCS``) by translating its first argument as a
     Bool expression.

     Raises UnsupportedException for any other call, and
     InvalidProgramException if translating the argument yields statements.
     """
     if get_func_name(node) not in CONTRACT_WRAPPER_FUNCS:
         raise UnsupportedException(node)
     side_effects, bool_expr = self.translate_expr(node.args[0], ctx,
                                                   self.viper.Bool, True)
     if side_effects:
         raise InvalidProgramException(node, 'purity.violated')
     return bool_expr
Beispiel #5
0
 def _translate_wrapper(self, wrapper: Wrapper, previous: Expr,
                        function: PythonMethod, ctx: Context) -> Expr:
     """
     Translates ``wrapper`` by delegating to the translator for its kind.

     Return wrappers and assign wrappers are supported. Both receive the
     previously translated expression ``previous``. Any other wrapper raises
     UnsupportedException.
     """
     if isinstance(wrapper, ReturnWrapper):
         translate = self._translate_return_wrapper
     elif isinstance(wrapper, AssignWrapper):
         translate = self._translate_assign_wrapper
     else:
         raise UnsupportedException(wrapper)
     return translate(wrapper, previous, function, ctx)
Beispiel #6
0
    def translate_perm_Call(self, node: ast.Call, ctx: Context) -> Expr:
        """
        Translates a call appearing in a permission position.

        The ARP functions require ``ctx.arp``; otherwise UnsupportedException
        is raised:

        * ``ARP()`` yields the context-dependent read permission.
        * ``ARP(n)`` yields ``rdc`` applied to the pure Int ``n``.
        * ``getARP(t)`` yields ``rd_token`` applied to the pure Ref ``t``.

        Every other call, including ``ARP``/``getARP`` with a different
        number of arguments, is translated as an ordinary Int expression.
        Impure argument or call translations raise InvalidProgramException.
        """
        func_name = get_func_name(node)
        if func_name in ('ARP', 'getARP') and not ctx.arp:
            raise UnsupportedException(
                node, 'ARP not supported. Use --arp flag.')
        arg_count = len(node.args)

        if func_name == 'ARP' and arg_count == 0:
            return self.get_arp_for_context(node, ctx)

        if func_name == 'ARP' and arg_count == 1:
            count_stmt, count = self.translate_expr(node.args[0], ctx,
                                                    self.viper.Int)
            if count_stmt:
                raise InvalidProgramException(node, 'purity.violated')
            count_decl = self.viper.LocalVarDecl(
                'count', self.viper.Int, self.to_position(node, ctx),
                self.no_info(ctx))
            return self.viper.FuncApp('rdc', [count],
                                      self.to_position(node, ctx),
                                      self.no_info(ctx), self.viper.Perm,
                                      [count_decl])

        if func_name == 'getARP' and arg_count == 1:
            token_decl = self.viper.LocalVarDecl(
                'tk', self.viper.Ref, self.to_position(node, ctx),
                self.no_info(ctx))
            thread_stmt, thread = self.translate_expr(node.args[0], ctx,
                                                      self.viper.Ref)
            if thread_stmt:
                raise InvalidProgramException(node, 'purity.violated')
            return self.viper.FuncApp('rd_token', [thread],
                                      self.to_position(node, ctx),
                                      self.no_info(ctx), self.viper.Perm,
                                      [token_decl])

        call_stmt, call = self.translate_expr(node, ctx, self.viper.Int)
        if call_stmt:
            raise InvalidProgramException(node, 'purity.violated')
        return call
Beispiel #7
0
def _get_subscript_type(value_type: PythonType, module: PythonModule,
                        node: ast.AST) -> PythonType:
    """
    Returns the type of an element obtained by subscripting a value of type
    ``value_type``. An optional type is first unwrapped to its class.

    Tuples:
      * If ``node`` is an ``ast.Subscript``, the index must be a numeric
        literal, possibly negated. The matching type argument is returned.
        A tuple type with a single type argument always yields that argument.
      * If ``node`` is not a subscript, the common supertype of all type
        arguments is returned.
    Lists, sets, PSeq, PSet and PMultiset yield their element type. Dicts
    (including ``defaultdict`` and ``ExpiringDict``) yield their value type.
    Ranges and bytes yield int. Other classes yield the declared type of
    their ``__getitem__``.

    Raises UnsupportedException for tuple slicing, non-constant tuple
    indices, and types without a known subscript result.
    """
    if isinstance(value_type, OptionalType):
        value_type = value_type.cls
    if value_type.name == TUPLE_TYPE:
        if not isinstance(node, ast.Subscript):
            return common_supertype(value_type.type_args)
        if isinstance(node.slice, ast.Slice):
            raise UnsupportedException(node, 'tuple slicing')
        if len(value_type.type_args) == 1:
            return value_type.type_args[0]
        index_node = node.slice.value
        if (isinstance(index_node, ast.UnaryOp)
                and isinstance(index_node.op, ast.USub)
                and isinstance(index_node.operand, ast.Num)):
            index = -index_node.operand.n
        elif isinstance(index_node, ast.Num):
            index = index_node.n
        else:
            # Any other index expression (e.g. a variable) cannot be mapped
            # statically to a single element type; without this branch
            # 'index' would be unbound below.
            raise UnsupportedException(node, 'dynamic subscript type')
        return value_type.type_args[index]
    if value_type.name in (LIST_TYPE, SET_TYPE, PSEQ_TYPE, PSET_TYPE,
                           PMSET_TYPE):
        return value_type.type_args[0]
    if value_type.name in (DICT_TYPE, 'defaultdict', 'ExpiringDict'):
        # FIXME: This is very unfortunate, but right now we cannot handle this
        # generically, so we have to hard code these dict-like cases for the
        # moment.
        return value_type.type_args[1]
    if value_type.name in (RANGE_TYPE, BYTES_TYPE):
        return module.global_module.classes[INT_TYPE]
    getitem = value_type.python_class.get_func_or_method('__getitem__')
    if getitem:
        return getitem.type
    raise UnsupportedException(node)
Beispiel #8
0
 def translate_perm_Name(self, node: ast.Name, ctx: Context) -> Expr:
     """
     Translates a name in a permission position.

     ``RD_PRED`` becomes an application of ``globalRd`` and requires ARP
     support (``ctx.arp``). Every other name is translated as an ordinary
     expression; InvalidProgramException is raised if that produces
     statements.
     """
     if node.id != 'RD_PRED':
         side_effects, perm_expr = self.translate_expr(node, ctx)
         if side_effects:
             raise InvalidProgramException(node, 'purity.violated')
         return perm_expr
     if not ctx.arp:
         raise UnsupportedException(
             node, 'ARP not supported. Use --arp flag.')
     return self.viper.FuncApp('globalRd', [], self.to_position(node, ctx),
                               self.no_info(ctx), self.viper.Perm, {})
Beispiel #9
0
 def translate_pure_Assign(self, conds: List, node: ast.Assign,
                           ctx: Context) -> List[Wrapper]:
     """
     Translates an assign statement to an AssignWrapper.

     Only a single plain-name target (``x = e``) is supported. Chained
     assignments (``x = y = e``) and other targets (tuples, attributes,
     subscripts) raise UnsupportedException.

     :param conds: conditions stored on the created wrapper (presumably the
                   path conditions reaching this statement -- confirm)
     :return: a one-element list holding the AssignWrapper
     """
     # Validate explicitly rather than with 'assert': asserts are stripped
     # under 'python -O', which would silently ignore all but the first
     # target of a chained assignment.
     if len(node.targets) != 1 or not isinstance(node.targets[0], ast.Name):
         raise UnsupportedException(
             node,
             "Multi-target assignments are not supported in pure functions."
         )
     target = node.targets[0]
     wrapper = AssignWrapper(target.id, conds, node.value, node)
     return [wrapper]
Beispiel #10
0
 def translate_low(self, node: ast.Call, ctx: Context) -> StmtsAndExpr:
     """
     Translates a call to the Low() contract function.

     The single argument must translate without statements. The result is a
     Viper ``Low`` expression carrying dynamic-check info. No statements are
     returned.
     """
     if len(node.args) != 1:
         raise UnsupportedException(node,
                                    "Low() requires exactly one argument")
     [argument] = node.args
     arg_stmts, arg_expr = self.translate_expr(argument, ctx)
     if arg_stmts:
         raise InvalidProgramException(node, 'purity.violated')
     dyn_check_info = self._create_dyn_check_info(ctx)
     position = self.to_position(node, ctx)
     return [], self.viper.Low(arg_expr, None, position, dyn_check_info)
Beispiel #11
0
 def translate_obligation_contractfunc_call(self, node: ast.Call,
                                            ctx: Context) -> StmtsAndExpr:
     """
     Translate a call to an obligation contract function.

     ``MustTerminate`` and ``MustRelease`` are translated. ``WaitLevel`` and
     ``Level`` may not be used here and raise InvalidProgramException. Any
     other function raises UnsupportedException.
     """
     func_name = get_func_name(node)
     translators = {
         'MustTerminate': self._translate_must_terminate,
         'MustRelease': self._translate_must_release,
     }
     misuse_errors = {
         'WaitLevel': 'invalid.wait_level.use',
         'Level': 'invalid.level.use',
     }
     if func_name in translators:
         return translators[func_name](node, ctx)
     if func_name in misuse_errors:
         raise InvalidProgramException(node, misuse_errors[func_name])
     raise UnsupportedException(node, 'Unsupported contract function.')
Beispiel #12
0
    def translate_perm_BinOp(self, node: ast.BinOp, ctx: Context) -> Expr:
        """
        Translates a binary operation in a permission position.

        * ``a / n``: ``n`` must be a pure Int expression. The result is a
          fractional permission if ``a`` is an Int, otherwise a permission
          division.
        * ``a * b``: Int multiplication if both operands are Ints, an
          Int-permission multiplication (Int operand placed on the left) if
          exactly one is, otherwise a permission multiplication.
        * ``a + b`` / ``a - b``: permission addition / subtraction.

        Any other operator raises UnsupportedException.
        """
        op = node.op

        if isinstance(op, ast.Div):
            dividend, dividend_is_int = self.translate_perm_or_int(node.left,
                                                                   ctx)
            divisor_stmt, divisor = self.translate_expr(
                node.right, ctx, target_type=self.viper.Int)
            if divisor_stmt:
                raise InvalidProgramException(node, 'purity.violated')
            if dividend_is_int:
                make_perm = self.viper.FractionalPerm
            else:
                make_perm = self.viper.PermDiv
            return make_perm(dividend, divisor, self.to_position(node, ctx),
                             self.no_info(ctx))

        if isinstance(op, ast.Mult):
            lhs, lhs_is_int = self.translate_perm_or_int(node.left, ctx)
            rhs, rhs_is_int = self.translate_perm_or_int(node.right, ctx)
            if lhs_is_int != rhs_is_int and rhs_is_int:
                # Move the Int operand to the left, where IntPermMul gets it.
                lhs, rhs = rhs, lhs
                lhs_is_int, rhs_is_int = rhs_is_int, lhs_is_int
            if lhs_is_int and rhs_is_int:
                make_product = self.viper.Mul
            elif lhs_is_int or rhs_is_int:
                make_product = self.viper.IntPermMul
            else:
                make_product = self.viper.PermMul
            return make_product(lhs, rhs, self.to_position(node, ctx),
                                self.no_info(ctx))

        if isinstance(op, ast.Add):
            combine = self.viper.PermAdd
        elif isinstance(op, ast.Sub):
            combine = self.viper.PermSub
        else:
            raise UnsupportedException(node)
        left = self.translate_perm(node.left, ctx)
        right = self.translate_perm(node.right, ctx)
        return combine(left, right, self.to_position(node, ctx),
                       self.no_info(ctx))
Beispiel #13
0
 def translate_acc_predicate(self, node: ast.Call, perm: Expr,
                             ctx: Context) -> StmtsAndExpr:
     """
     Translates a call to the Acc() contract function with a predicate call
     inside to a predicate access.

     The callee may be:
       * a built-in predicate, delegated to translate_builtin_predicate;
       * a predicate referenced by name;
       * a predicate of a module (``mod.pred(...)``);
       * a predicate on a receiver (``obj.pred(...)``), in which case the
         translated receiver becomes the first argument.
     For a predicate defined in a class, the access uses the root of its
     predicate family, i.e. the topmost superclass that still defines a
     predicate of the same name.

     Returns the statements produced while translating the arguments and
     the predicate access. Raises InvalidProgramException ('invalid.acc') if
     the callee is not a predicate, and UnsupportedException for other
     kinds of callee expressions.
     """
     assert isinstance(node.args[0], ast.Call)
     call = node.args[0]
     # The predicate inside is a function call in python.
     args = []
     arg_stmts = []
     for arg in call.args:
         arg_stmt, arg_expr = self.translate_expr(arg, ctx)
         arg_stmts = arg_stmts + arg_stmt
         args.append(arg_expr)
     # Get the predicate inside the Acc(). 'name' is bound in every branch
     # because the predicate-family lookup below needs it.
     if isinstance(call.func, ast.Name):
         name = call.func.id
         if name in BUILTIN_PREDICATES:
             return arg_stmts, self.translate_builtin_predicate(
                 call, perm, args, ctx)
         pred = self.get_target(call.func, ctx)
     elif isinstance(call.func, ast.Attribute):
         name = call.func.attr
         receiver = self.get_target(call.func.value, ctx)
         if isinstance(receiver, PythonModule):
             # A missing predicate is reported as 'invalid.acc' below
             # instead of escaping as a KeyError.
             pred = receiver.predicates.get(name)
         else:
             rec_stmt, receiver = self.translate_expr(call.func.value, ctx)
             assert not rec_stmt
             receiver_class = self.get_type(call.func.value, ctx)
             pred = receiver_class.get_predicate(name)
             args = [receiver] + args
     else:
         raise UnsupportedException(node)
     if not (isinstance(pred, PythonMethod) and pred.predicate):
         raise InvalidProgramException(node, 'invalid.acc')
     pred_name = pred.sil_name
     # If the predicate is part of a family, find the correct version.
     if pred.cls:
         family_root = pred.cls
         while (family_root.superclass
                and family_root.superclass.get_predicate(name)):
             family_root = family_root.superclass
         pred_name = family_root.get_predicate(name).sil_name
     return arg_stmts, self.create_predicate_access(pred_name, args, perm,
                                                    node, ctx)
Beispiel #14
0
    def _get_function_call(self,
                           receiver: PythonType,
                           func_name: str,
                           args: List[Expr],
                           arg_types: List[PythonType],
                           node: ast.AST,
                           ctx: Context,
                           position: Position = None) -> FuncApp:
        """
        Creates a function application of the function called func_name, with
        the given receiver and arguments. This method only handles receivers
        of non-union types.

        The function is looked up on ``receiver`` if one is given. Otherwise
        the modules included by the current module are searched, and the
        first module defining it wins. Each argument is converted to the
        representation its parameter declares (``__prim__bool``,
        ``__prim__int``, or a reference). ``arg_types`` is not used for this
        conversion. The application's type is the translated return type of
        the function, and its position is ``position`` if given, else that
        of ``node``.

        Raises UnsupportedException if the receiver only has a (non-pure)
        method of that name, and InvalidProgramException if no function is
        found.
        """
        # Initialized so that a failed module search reports
        # 'unknown.function.called' instead of raising UnboundLocalError.
        func = None
        if receiver:
            func = receiver.get_function(func_name)
        else:
            for container in ctx.module.get_included_modules():
                if func_name in container.functions:
                    func = container.functions[func_name]
                    break
        if not func:
            if receiver and receiver.get_method(func_name):
                msg = 'Called method is expected to be pure: ' + func_name
                raise UnsupportedException(node, msg)
            raise InvalidProgramException(node, 'unknown.function.called')
        params = func.get_args()
        assert len(args) == len(params)
        formal_args = []
        actual_args = []
        # Zip only args with params: including the unused arg_types could
        # silently drop arguments if it were shorter.
        for arg, param in zip(args, params):
            formal_args.append(param.decl)
            if param.type.name == '__prim__bool':
                actual_arg = self.to_bool(arg, ctx)
            elif param.type.name == '__prim__int':
                actual_arg = self.to_int(arg, ctx)
            else:
                actual_arg = self.to_ref(arg, ctx)
            actual_args.append(actual_arg)
        result_type = self.translate_type(func.type, ctx)
        actual_position = position if position else self.to_position(node, ctx)
        return self.viper.FuncApp(func.sil_name, actual_args, actual_position,
                                  self.no_info(ctx), result_type, formal_args)
Beispiel #15
0
 def _translate_wrapper_expr(self, wrapper: Wrapper, ctx: Context) -> Expr:
     """
     Translates the expression held by ``wrapper`` into a pure expression.

     If the expression is a BinOpWrapper (only allowed on an AssignWrapper),
     the result is ``x <op> rhs`` on Ints, where ``x`` is the current alias of
     the assigned variable; only ``+``, ``-`` and ``*`` are supported. This
     presumably models augmented assignments such as ``x += e``; confirm
     with the wrapper construction code. Otherwise the expression is
     translated directly.

     Raises UnsupportedException for other operators, and
     InvalidProgramException if the translation produces statements.
     """
     info = self.no_info(ctx)
     position = self.to_position(wrapper.node, ctx)
     if isinstance(wrapper.expr, BinOpWrapper):
         assert isinstance(wrapper, AssignWrapper)
         translated = wrapper.expr.rhs
         stmt, val = self.translate_expr(translated, ctx)
         val = self.to_int(val, ctx)
         var = ctx.var_aliases[wrapper.name].ref()
         var = self.to_int(var, ctx)
         if isinstance(wrapper.expr.op, ast.Add):
             val = self.viper.Add(var, val, position, info)
         elif isinstance(wrapper.expr.op, ast.Sub):
             val = self.viper.Sub(var, val, position, info)
         elif isinstance(wrapper.expr.op, ast.Mult):
             val = self.viper.Mul(var, val, position, info)
         else:
             raise UnsupportedException(wrapper.node)
     else:
         translated = wrapper.expr
         stmt, val = self.translate_expr(translated, ctx)
     if stmt:
         # Report at the AST expression that was translated; in the BinOp
         # case wrapper.expr is a BinOpWrapper, not an AST node.
         raise InvalidProgramException(translated, 'purity.violated')
     return val
Beispiel #16
0
    def translate_io_contractfunc_call(self, node: ast.Call, ctx: Context,
                                       impure: bool,
                                       statement: bool) -> StmtsAndExpr:
        """Translate a call to an IO contract function.

        Currently supported functions:

        +   ``token``   (only when ``impure`` is set)
        +   ``ctoken``  (only when ``impure`` is set)
        +   ``eval_io`` (only when ``impure`` is set)
        +   ``Open``    (only when ``statement`` is set)
        +   ``Eval``

        Using a function where it is not allowed raises
        InvalidProgramException ('invalid.contract.position'). Any other
        function raises UnsupportedException.
        """
        func_name = get_func_name(node)
        if func_name in ('token', 'ctoken', 'eval_io'):
            if not impure:
                raise InvalidProgramException(node,
                                              'invalid.contract.position')
            if func_name == 'token':
                return self.translate_must_invoke_token(node, ctx)
            if func_name == 'ctoken':
                return self.translate_must_invoke_ctoken(node, ctx)
            return self._translate_eval_io(node, ctx)
        if func_name == 'Open':
            if not statement:
                raise InvalidProgramException(node,
                                              'invalid.contract.position')
            return self._translate_open(node, ctx)
        if func_name == 'Eval':
            return self._translate_eval(node, ctx)
        raise UnsupportedException(node, 'Unsupported contract function.')
Beispiel #17
0
 def translate_contractfunc_call(self,
                                 node: ast.Call,
                                 ctx: Context,
                                 impure=False,
                                 statement=False) -> StmtsAndExpr:
     """
     Translates calls to contract functions like Result() and Acc()

     Dispatches on the name of the called function.
     ``Acc``/``Rd``/``Wildcard`` are only accepted when ``impure`` is set,
     and ``Assert``/``Assume``/``Fold``/``Unfold`` only when ``statement``
     is set. Otherwise InvalidProgramException('invalid.contract.position')
     is raised. ``getMethod`` and ``arg`` may not be used here and raise
     InvalidProgramException. Unknown names raise UnsupportedException.

     :param impure: enables the permission functions (Acc, Rd, Wildcard);
                    also forwarded to Implies, Unfolding, Forall and Let
     :param statement: enables Assert, Assume, Fold and Unfold
     :return: the statements and expression produced by the translation
     """
     func_name = get_func_name(node)
     if func_name == 'Result':
         return self.translate_result(node, ctx)
     elif func_name == 'RaisedException':
         return self.translate_raised_exception(node, ctx)
     elif func_name in ('Acc', 'Rd', 'Wildcard'):
         if not impure:
             raise InvalidProgramException(node,
                                           'invalid.contract.position')
         # Permission amount: the context-dependent read permission for Rd,
         # a wildcard for Wildcard, and for Acc whatever _get_perm derives
         # from the call.
         if func_name == 'Rd':
             perm = self.get_arp_for_context(node, ctx)
         elif func_name == 'Wildcard':
             perm = self.viper.WildcardPerm(self.to_position(node, ctx),
                                            self.no_info(ctx))
         else:
             perm = self._get_perm(node, ctx)
         # Acc(pred(...)): access to a predicate instance.
         if isinstance(node.args[0], ast.Call):
             return self.translate_acc_predicate(node, perm, ctx)
         else:
             if isinstance(node.args[0], ast.Attribute):
                 type = self.get_type(node.args[0].value, ctx)
                 # Receiver of union type: look the field up separately in
                 # each possible receiver class (None excluded) and guard
                 # each access with a type check on the receiver.
                 if isinstance(type, UnionType):
                     guarded_field_access = []
                     stmt, receiver = self.translate_expr(
                         node.args[0].value, ctx)
                     for recv_type in toposort_classes(type.get_types() -
                                                       {None}):
                         target = self.get_target(node.args[0].value, ctx)
                         field_guard = self.var_type_check(
                             target.sil_name, recv_type,
                             self.to_position(node, ctx), ctx)
                         field = recv_type.get_field(
                             node.args[0].attr).actual_field
                         field_access = self.viper.FieldAccess(
                             receiver, field.sil_field,
                             self.to_position(node, ctx), self.no_info(ctx))
                         field_acc = self._translate_acc_field(
                             field_access, field.type, perm,
                             self.to_position(node, ctx), ctx)
                         guarded_field_access.append(
                             (field_guard, field_acc))
                     # A single candidate class needs no guard; otherwise
                     # combine the guarded accesses into a conditional chain.
                     if len(guarded_field_access) == 1:
                         _, field_acc = guarded_field_access[0]
                         return stmt, field_acc
                     else:
                         return (stmt,
                                 chain_cond_exp(guarded_field_access,
                                                self.viper,
                                                self.to_position(node, ctx),
                                                self.no_info(ctx), ctx))
             # Otherwise Acc() must refer to a field or a global variable.
             target = self.get_target(node.args[0], ctx)
             if isinstance(target, PythonField):
                 return self.translate_acc_field(node, perm, ctx)
             else:
                 if not isinstance(target, PythonGlobalVar):
                     raise InvalidProgramException(node, 'invalid.acc')
                 return self.translate_acc_global(node, perm, ctx)
     elif func_name in BUILTIN_PREDICATES:
         return [], self.translate_unwrapped_builtin_predicate(node, ctx)
     elif func_name == 'MaySet':
         return self.translate_may_set(node, ctx)
     elif func_name == 'MayCreate':
         return self.translate_may_create(node, ctx)
     elif func_name in ('Assert', 'Assume', 'Fold', 'Unfold'):
         # These are only accepted when translating a statement.
         if not statement:
             raise InvalidProgramException(node,
                                           'invalid.contract.position')
         if func_name == 'Assert':
             return self.translate_assert(node, ctx)
         elif func_name == 'Assume':
             return self.translate_assume(node, ctx)
         elif func_name == 'Fold':
             return self.translate_fold(node, ctx)
         elif func_name == 'Unfold':
             return self.translate_unfold(node, ctx)
     elif func_name == 'Implies':
         return self.translate_implies(node, ctx, impure)
     elif func_name == 'Old':
         return self.translate_old(node, ctx)
     elif func_name == 'Unfolding':
         return self.translate_unfolding(node, ctx, impure)
     elif func_name == 'Low':
         return self.translate_low(node, ctx)
     elif func_name == 'LowVal':
         return self.translate_lowval(node, ctx)
     elif func_name == 'LowEvent':
         return self.translate_lowevent(node, ctx)
     elif func_name == 'LowExit':
         return self.translate_lowexit(node, ctx)
     elif func_name == 'Declassify':
         return self.translate_declassify(node, ctx)
     elif func_name == 'TerminatesSif':
         return self.translate_terminates_sif(node, ctx)
     elif func_name == 'Forall':
         return self.translate_forall(node, ctx, impure)
     elif func_name == 'Previous':
         return self.translate_previous(node, ctx)
     elif func_name == 'Let':
         return self.translate_let(node, ctx, impure)
     elif func_name == PSEQ_TYPE:
         return self.translate_sequence(node, ctx)
     elif func_name == PSET_TYPE:
         return self.translate_pset(node, ctx)
     elif func_name == PMSET_TYPE:
         return self.translate_mset(node, ctx)
     elif func_name == 'ToSeq':
         return self.translate_to_sequence(node, ctx)
     elif func_name == 'Joinable':
         return self.translate_joinable(node, ctx)
     elif func_name == 'getArg':
         return self.translate_get_arg(node, ctx)
     elif func_name == 'getOld':
         return self.translate_get_old(node, ctx)
     elif func_name == 'getMethod':
         raise InvalidProgramException(node, 'invalid.get.method.use')
     elif func_name == 'arg':
         raise InvalidProgramException(node, 'invalid.arg.use')
     else:
         raise UnsupportedException(node)
Beispiel #18
0
 def translate_pure_Expr(self, conds: List, node: ast.Expr,
                         ctx: Context) -> List[Wrapper]:
     """
     Handles expression statements in pure functions.

     Only string literals, i.e. docstrings, are allowed; they produce no
     wrappers. Any other expression statement raises UnsupportedException.
     """
     is_docstring = isinstance(node.value, ast.Str)
     if not is_docstring:
         raise UnsupportedException(node)
     return []
Beispiel #19
0
 def translate_pure_generic(self, conds: List, node: ast.AST,
                            ctx: Context) -> List[Wrapper]:
     """
     Fallback for statements that have no dedicated pure translator.

     Always raises UnsupportedException for ``node``; never returns.
     """
     raise UnsupportedException(node)
Beispiel #20
0
 def translate_generic(self, node: ast.AST, ctx: Context) -> None:
     """
     Fallback visitor for node types without a dedicated translation method.

     Unconditionally raises UnsupportedException for ``node``.
     """
     raise UnsupportedException(node)
Beispiel #21
0
 def translate_perm_Num(self, node: ast.Num, ctx: Context) -> Expr:
     """
     Translates a numeric literal in a permission position.

     Only ``1`` is accepted and denotes full permission. Any other number
     raises UnsupportedException.
     """
     if node.n != 1:
         raise UnsupportedException(node)
     position = self.to_position(node, ctx)
     return self.viper.FullPerm(position, self.no_info(ctx))
Beispiel #22
0
def _get_call_type(node: ast.Call, module: PythonModule,
                   current_function: PythonMethod,
                   containers: List[ContainerInterface],
                   container: PythonNode) -> PythonType:
    """
    Returns the static type of the call expression ``node``.

    Handles ``super()``, ``len``, the obligation/token functions, collection
    literals of PSEQ_TYPE/PSET_TYPE/PMSET_TYPE, ``enumerate``, contract
    functions, a few builtins, constructor calls, module-level functions,
    and method calls. For method calls on receivers of union type, the
    possible return types are combined.

    May return None, e.g. for ``Declassify``, for a called name that is
    neither a class nor a function of ``module``, or for an attribute call
    whose receiver type is neither a union nor a PythonType. Raises
    InvalidProgramException for malformed ``super``/``Let`` calls and
    UnsupportedException for unsupported calls.
    """
    func_name = get_func_name(node)
    if func_name == 'super':
        # super(Cls, obj): superclass of the named class.
        # super(): superclass of the class that contains 'container'.
        if len(node.args) == 2:
            return module.classes[node.args[0].id].superclass
        elif not node.args:
            return container.cls.superclass
        else:
            raise InvalidProgramException(node, 'invalid.super.call')
    if func_name == 'len':
        return module.global_module.classes[INT_TYPE]
    if func_name in ('token', 'ctoken', 'MustTerminate', 'MustRelease'):
        return module.global_module.classes[BOOL_TYPE]
    if func_name == PSEQ_TYPE:
        return _get_collection_literal_type(node, ['args'], PSEQ_TYPE, module,
                                            containers, container)
    if func_name == PSET_TYPE:
        return _get_collection_literal_type(node, ['args'], PSET_TYPE, module,
                                            containers, container)
    if func_name == PMSET_TYPE:
        return _get_collection_literal_type(node, ['args'], PMSET_TYPE, module,
                                            containers, container)
    if func_name == 'enumerate':
        # enumerate(xs) is typed as a list of (int, element-of-xs) tuples.
        if len(node.args) != 1:
            raise UnsupportedException(
                node, 'enumerate only supported with single arg.')
        list_type = module.global_module.classes[LIST_TYPE]
        int_type = module.global_module.classes[INT_TYPE]
        tuple_type = module.global_module.classes[TUPLE_TYPE]
        arg_type = get_type(node.args[0], containers, container)
        iterable_type = _get_iteration_type(arg_type, module, node)
        return GenericType(
            list_type, [GenericType(tuple_type, [int_type, iterable_type])])
    if isinstance(node.func, ast.Name):
        if node.func.id in CONTRACT_FUNCS:
            if node.func.id == 'Result':
                return current_function.type
            elif node.func.id == 'RaisedException':
                # The exception type comes from the first container that has
                # 'var_aliases' (presumably the translation context -- confirm).
                ctxs = [
                    cont for cont in containers
                    if hasattr(cont, 'var_aliases')
                ]
                ctx = ctxs[0] if ctxs else None
                assert ctx
                assert ctx.current_contract_exception is not None
                return ctx.current_contract_exception
            elif node.func.id in ('Acc', 'Rd', 'Read', 'Implies', 'Forall',
                                  'Exists', 'MayCreate', 'MaySet', 'Low',
                                  'LowVal', 'LowEvent', 'LowExit'):
                return module.global_module.classes[BOOL_TYPE]
            elif node.func.id == 'Declassify':
                return None
            elif node.func.id == 'Old':
                return get_type(node.args[0], containers, container)
            elif node.func.id == 'Unfolding':
                return get_type(node.args[1], containers, container)
            elif node.func.id == 'ToSeq':
                arg_type = get_type(node.args[0], containers, container)
                seq_class = module.global_module.classes[PSEQ_TYPE]
                content_type = _get_iteration_type(arg_type, module, node)
                return GenericType(seq_class, [content_type])
            elif node.func.id == 'Previous':
                # Previous(x) is typed as a PSeq of x's type.
                arg_type = get_type(node.args[0], containers, container)
                list_class = module.global_module.classes[PSEQ_TYPE]
                return GenericType(list_class, [arg_type])
            elif node.func.id in ('getArg', 'getOld', 'getMethod'):
                object_class = module.global_module.classes[OBJECT_TYPE]
                return object_class
            elif node.func.id == 'Let':
                body_type = get_target(node.args[1], containers, container)
                if isinstance(body_type, PythonType):
                    return body_type
                raise InvalidProgramException(node, 'invalid.let')
            else:
                raise UnsupportedException(node)
        elif node.func.id in BUILTINS:
            if node.func.id in ('isinstance', BOOL_TYPE):
                return module.global_module.classes[BOOL_TYPE]
            elif node.func.id == 'cast':
                return get_target(node.args[0], containers, container)
            else:
                raise UnsupportedException(node)
        # Constructor call. NOTE(review): membership is checked in
        # module.classes but the lookup uses module.global_module.classes;
        # confirm the global module contains every class of 'module'.
        if node.func.id in module.classes:
            return module.global_module.classes[node.func.id]
        elif module.get_func_or_method(node.func.id) is not None:
            target = module.get_func_or_method(node.func.id)
            return target.type
    elif isinstance(node.func, ast.Attribute):
        rectype = get_type(node.func.value, containers, container)
        if isinstance(rectype, UnionType):
            # Collect the method's return type for each class in the union.
            # A single type is returned as-is, a type plus None as an
            # OptionalType, and anything else as a UnionType.
            set_of_classes = rectype.get_types() - {None}
            set_of_return_types = {
                type.get_func_or_method(node.func.attr).type
                for type in set_of_classes
            }
            if len(set_of_return_types) == 1:
                return set_of_return_types.pop()
            elif len(set_of_return_types) == 2 and None in set_of_return_types:
                return OptionalType((set_of_return_types - {None}).pop())
            else:
                return UnionType(list(set_of_return_types))
        elif isinstance(rectype, PythonType):
            target = rectype.get_func_or_method(node.func.attr)
            # generic_type != -1 is the index of the receiver's type argument
            # that the method returns.
            if target.generic_type != -1:
                return rectype.type_args[target.generic_type]
            else:
                return target.type
    else:
        raise UnsupportedException(node)
Beispiel #23
0
def _do_get_type(node: ast.AST, containers: List[ContainerInterface],
                 container: PythonNode) -> Optional[PythonType]:
    """
    Does the actual work for get_type without boxing the type.

    Resolution proceeds in two stages:

    1. If ``get_target`` resolves ``node`` to a program entity, the type is
       derived from that entity (variable, method, field, IO operation,
       type or module).
    2. Otherwise the type is computed structurally from the kind of AST node
       (literals, operators, collection displays, calls, comprehensions).

    :param node: The expression whose type is requested.
    :param containers: Passed through unchanged to ``get_target``,
        ``get_type`` and the other helpers called here.
    :param container: The enclosing node. If it is a PythonIOOperation or a
        PythonMethod, its ``module`` is used for builtin class lookups and it
        is passed on as the current function; otherwise it is used directly
        as the module.
    :return: The computed type. Returns None for a Name/Attribute that cannot
        be resolved in the given context. If the node resolves to a
        PythonType or PythonModule, that entity itself is returned.
    :raises UnsupportedException: for node kinds or operators not handled
        here, and for list comprehensions not directly assigned to a
        variable.
    :raises InvalidProgramException: for a call to a generic class
        constructor whose type arguments cannot be determined.
    """
    # Methods and IO operations belong to a module; any other container is
    # treated as the module itself.
    if isinstance(container, (PythonIOOperation, PythonMethod)):
        module = container.module
        current_function = container
    else:
        module = container
        current_function = None
    # Stage 1: try to resolve the node to a known entity.
    target = get_target(node, containers, container)
    if target:
        if isinstance(target, PythonVarBase):
            # Ask the variable for its type at this particular node.
            return target.get_specific_type(node)
        if isinstance(target, PythonMethod):
            # For a method call on a (non-module) receiver, the declared
            # return type may refer to the receiver's type arguments and must
            # be specialized.
            if isinstance(node, ast.Call) and isinstance(
                    node.func, ast.Attribute):
                rec_target = get_target(node.func.value, containers, container)
                if not isinstance(rec_target, PythonModule):
                    rectype = get_type(node.func.value, containers, container)
                    # generic_type is used as an index into the receiver's
                    # type arguments; -1 means no such index is set.
                    if target.generic_type != -1:
                        return rectype.type_args[target.generic_type]
                    if isinstance(target.type, TypeVar):
                        # Walk up from the receiver type to the class that
                        # declares the method, then look up the position of
                        # the type variable among that class's type vars.
                        # NOTE(review): assumes the receiver type is a
                        # subtype of target.cls; otherwise this walks off the
                        # superclass chain.
                        while rectype.python_class is not target.cls:
                            rectype = rectype.superclass
                        name_list = list(rectype.python_class.type_vars.keys())
                        index = name_list.index(target.type.name)
                        return rectype.type_args[index]
            return target.type
        if isinstance(target, PythonField):
            result = target.type
            if isinstance(result, TypeVar):
                # A field typed by a type variable: resolve it through the
                # type arguments of the receiver's matching (super)class.
                assert isinstance(node, ast.Attribute)
                rec_type = _do_get_type(node.value, containers, container)
                while (rec_type.python_class
                       is not result.target_type.python_class):
                    rec_type = rec_type.superclass
                result = rec_type.type_args[result.index]
            return result
        if isinstance(target, PythonIOOperation):
            # Uses of IO operations are typed as bool.
            return module.global_module.classes[BOOL_TYPE]

        if isinstance(target, (PythonType, PythonModule)):
            if (isinstance(node, ast.Call) and isinstance(target, PythonClass)
                    and target.type_vars):
                # This is a call to a constructor of a generic class; it's not
                # enough to just return the class, we need the entire type with
                # type arguments. We only support that if we can get it directly
                # from mypy, i.e., when the result is assigned to a variable
                # and we can get the variable type.
                if node._parent and isinstance(node._parent, ast.Assign):
                    return get_type(node._parent.targets[0], containers,
                                    container)
                elif (target.name in (PSEQ_TYPE, PSET_TYPE, PMSET_TYPE)
                      and isinstance(node, ast.Call) and node.args):
                    # For these constructors, the element type is the common
                    # supertype of the argument types.
                    arg_types = [
                        get_type(arg, containers, container)
                        for arg in node.args
                    ]
                    return GenericType(target, [common_supertype(arg_types)])
                else:
                    error = 'generic.constructor.without.type'
                    raise InvalidProgramException(node, error)
            return target
    # Stage 2: no entity was found; compute the type from the node's shape.
    if isinstance(node, (ast.Attribute, ast.Name)):
        if isinstance(node, ast.Attribute):
            # Attribute access on a (non-Optional) union: look the attribute
            # up in every member type and, if each yields a field or
            # variable, use the common supertype of their types.
            lhs = _do_get_type(node.value, containers, container)
            if isinstance(lhs,
                          UnionType) and not isinstance(lhs, OptionalType):
                candidates = [
                    find_entry(node.attr, False, [t]) for t in lhs.type_args
                ]
                if all(
                        isinstance(c, (PythonField, PythonVarBase))
                        for c in candidates):
                    return common_supertype([c.type for c in candidates])
        # All these cases should be handled by get_target, so if we get here,
        # the node refers to something unknown in the given context.
        return None
    # NOTE(review): ast.Num/Str/Bytes/NameConstant are deprecated aliases
    # since Python 3.8 and were removed in Python 3.14; ast.Constant is the
    # replacement.
    if isinstance(node, ast.Num):
        # NOTE(review): every numeric literal (including float/complex ones
        # matched by ast.Num) is typed as int here — confirm intended.
        return module.global_module.classes[INT_TYPE]
    elif isinstance(node, ast.Tuple):
        args = [get_type(arg, containers, container) for arg in node.elts]
        return GenericType(module.global_module.classes[TUPLE_TYPE], args)
    elif isinstance(node, ast.Subscript):
        return get_subscript_type(node, module, containers, container)
    elif isinstance(node, ast.Str):
        return module.global_module.classes[STRING_TYPE]
    elif isinstance(node, ast.Bytes):
        return module.global_module.classes[BYTES_TYPE]
    elif isinstance(node, ast.Compare):
        return module.global_module.classes[BOOL_TYPE]
    elif isinstance(node, ast.BoolOp):
        # And and Or always return one of their operands, so we use the common
        # supertype of all arguments.
        # TODO: We could also use a union type, but since support for e.g.
        # calling methods on those isn't amazing yet, we don't do that yet.
        operand_types = [
            get_type(operand, containers, container) for operand in node.values
        ]
        return common_supertype(operand_types)
    elif isinstance(node, ast.List):
        return _get_collection_literal_type(node, ['elts'], LIST_TYPE, module,
                                            containers, container)
    elif isinstance(node, ast.Set):
        return _get_collection_literal_type(node, ['elts'], SET_TYPE, module,
                                            containers, container)
    elif isinstance(node, ast.Dict):
        return _get_collection_literal_type(node, ['keys', 'values'],
                                            DICT_TYPE, module, containers,
                                            container)
    elif isinstance(node, ast.IfExp):
        body_type = get_type(node.body, containers, container)
        else_type = get_type(node.orelse, containers, container)
        return pairwise_supertype(body_type, else_type)
    elif isinstance(node, ast.BinOp):
        # The result type is the return type of the left operand's operator
        # method (e.g. __add__). NOTE(review): right_type does not contribute
        # to the result, but it is still evaluated here.
        left_type = get_type(node.left, containers, container)
        right_type = get_type(node.right, containers, container)
        operator_func = OPERATOR_FUNCTIONS[type(node.op)]
        return left_type.get_func_or_method(operator_func).type
    elif isinstance(node, ast.UnaryOp):
        if isinstance(node.op, ast.Not):
            return module.global_module.classes[BOOL_TYPE]
        elif isinstance(node.op, ast.USub):
            # NOTE(review): negation is always typed as int, regardless of
            # the operand's type.
            return module.global_module.classes[INT_TYPE]
        else:
            raise UnsupportedException(node)
    elif isinstance(node, ast.NameConstant):
        if (node.value is True) or (node.value is False):
            return module.global_module.classes[BOOL_TYPE]
        elif node.value is None:
            # The None literal is typed as object.
            return module.global_module.classes[OBJECT_TYPE]
        else:
            raise UnsupportedException(node)
    elif isinstance(node, ast.Call):
        return _get_call_type(node, module, current_function, containers,
                              container)
    elif isinstance(node, ast.ListComp):
        if (node._parent and isinstance(node._parent, ast.Assign)
                and node is node._parent.value):
            # Comprehension is assigned to a variable;
            # we get the type of the list from the type of the
            # variable it's assigned to.
            return get_type(node._parent.targets[0], containers, container)
        else:
            raise UnsupportedException(
                node, 'List comprehensions must be directly '
                'assigned to a local variable.')
    else:
        raise UnsupportedException(node)
Beispiel #24
0
    def visit_Call(self, node: ast.Call) -> None:
        """Parse IO operation properties and quantified lambda variables.

        Handles three kinds of calls inside the IO operation currently being
        parsed:

        * Calls to a name in ``IO_OPERATION_PROPERTY_FUNCS`` (currently
          ``Terminates`` and ``TerminationMeasure``). Such a call must be a
          stand-alone expression statement directly in the body of the
          current node. Its first argument is recorded on the operation, and
          all arguments are then visited with ``self._in_property`` set.
        * ``IOForall`` / ``Forall`` / ``Exists``: the single parameter of the
          lambda passed as second argument is registered as an IO universal
          of the operation. Its type is resolved from the first argument.
        * ``IOExists(...)(lambda ...)``: the first parameter of the lambda is
          registered as an IO existential (through a variable creator).

        For the two quantifier forms, the lambda's prefix is pushed onto
        ``self._current_lambdas`` while the call's arguments are visited.
        Both ``self._in_property`` and ``self._current_lambdas`` are restored
        even if visiting an argument raises.
        """
        assert self._current_io_operation is not None
        assert self._current_node is not None

        body_prefix = None

        if (isinstance(node.func, ast.Name)
                and node.func.id in IO_OPERATION_PROPERTY_FUNCS):
            # Properties may only appear as expression statements directly
            # in the body of the current node (compared by node identity).
            is_statement = any(
                isinstance(child, ast.Expr) and child.value is node
                for child in self._current_node.body)
            if not is_statement:
                self._raise_invalid_operation('misplaced_property', node)

            operation = self._current_io_operation
            arg = node.args[0]
            # A falsy result from a set_* call is reported as a duplicate.
            if node.func.id == 'Terminates':
                if not operation.set_terminates(arg):
                    self._raise_invalid_operation('duplicate_property', node)
            elif node.func.id == 'TerminationMeasure':
                if not operation.set_termination_measure(arg):
                    self._raise_invalid_operation('duplicate_property', node)
            else:
                raise UnsupportedException(node, 'Unsupported property type.')
            self._in_property = True
            try:
                for arg in node.args:
                    self.visit(arg)
            finally:
                # Reset the flag even if a nested visit raises.
                self._in_property = False
            return
        elif (isinstance(node.func, ast.Name)
              and node.func.id in ('IOForall', 'Forall', 'Exists')):
            operation = self._current_io_operation
            lambda_ = node.args[1]
            assert len(lambda_.args.args) == 1
            arg_type = self._parent.get_target(node.args[0], operation.module)
            body_prefix = construct_lambda_prefix(
                lambda_.lineno, getattr(lambda_, 'col_offset', None))
            for arg in lambda_.args.args:
                var = self._node_factory.create_python_var(
                    arg.arg, arg, arg_type)
                operation._io_universals.append(var)
        elif (isinstance(node.func, ast.Call)
              and isinstance(node.func.func, ast.Name)
              and node.func.func.id == 'IOExists'):
            lambda_ = node.args[0]
            arg = lambda_.args.args[0]
            body_prefix = construct_lambda_prefix(lambda_.lineno,
                                                  lambda_.col_offset)
            creator = self._node_factory.create_python_var_creator(
                arg.arg, arg, self._typeof(arg, lambda_))
            existentials = self._current_io_operation.get_io_existentials()
            existentials.append(creator)

        if body_prefix:
            # Keep the lambda prefix on the stack while its arguments are
            # visited.
            self._current_lambdas.append(body_prefix)
        try:
            for arg in node.args:
                self.visit(arg)
        finally:
            # Pop even if a nested visit raises, so the stack stays balanced.
            if body_prefix:
                self._current_lambdas.pop()
Beispiel #25
0
 def translate_contract_Expr(self, node: ast.Expr, ctx: Context) -> Expr:
     """Translate a contract written as an expression statement.

     Only call expressions are accepted; the wrapped call is delegated to
     ``translate_contract``.

     :raises UnsupportedException: if the statement does not wrap a call.
     """
     wrapped = node.value
     if not isinstance(wrapped, ast.Call):
         raise UnsupportedException(node)
     return self.translate_contract(wrapped, ctx)