Пример #1
0
    def _create_quantifier_contains_expr(self,
                                         e: Expr,
                                         domain_node: ast.AST,
                                         ctx: Context,
                                         trigger=False
                                         ) -> Tuple[List[Stmt], Expr, bool]:
        """
        Build the antecedent of a quantifier's implication, i.e. the
        expression stating that ``e`` is an element of the domain described
        by ``domain_node``.

        A domain of the form ``Old(d)`` is unwrapped to ``d`` first; the
        resulting membership expression is then wrapped in a Viper ``Old``.

        Returns a triple consisting of the statements produced while
        translating the domain, the membership expression, and a flag that
        says whether the expression is recommended for use as a trigger.
        """
        wrapped_in_old = (isinstance(domain_node, ast.Call)
                          and get_func_name(domain_node) == 'Old')
        if wrapped_in_old:
            domain_node = domain_node.args[0]
        element_ref = self.to_ref(e, ctx)
        position = self.to_position(domain_node, ctx)
        no_info = self.no_info(ctx)

        target = self.get_target(domain_node, ctx)
        if isinstance(target, PythonType):
            # The domain is a type. A type check is very broad and would be
            # triggered a lot, so it is not recommended as a trigger.
            type_test = self.type_check(element_ref, target, position, ctx,
                                        False)
            return [], type_test, False

        domain_stmts, domain_expr = self.translate_expr(domain_node, ctx)
        domain_type = self.get_type(domain_node, ctx)
        membership = self.get_quantifier_lhs(element_ref, domain_type,
                                             domain_expr, domain_node, ctx,
                                             position, trigger)
        if wrapped_in_old:
            membership = self.viper.Old(membership, position, no_info)
        return domain_stmts, membership, True
Пример #2
0
    def translate_unfolding(self,
                            node: ast.Call,
                            ctx: Context,
                            impure=False) -> StmtsAndExpr:
        """
        Translate an ``Unfolding(predicate, expression)`` contract call into
        a Viper ``Unfolding`` expression.

        Raises InvalidProgramException with 'invalid.contract.call' if the
        call does not have exactly two arguments, if the first argument is
        not a call, or if it resolves to something that is not a predicate,
        and with 'purity.violated' if translating either argument produces
        statements.
        """
        if len(node.args) != 2 or not isinstance(node.args[0], ast.Call):
            raise InvalidProgramException(node, 'invalid.contract.call')
        predicate_arg, body_arg = node.args

        # Acc(...) and Rd(...) wrap the actual predicate call.
        if get_func_name(predicate_arg) in ('Acc', 'Rd'):
            predicate_call = predicate_arg.args[0]
        else:
            predicate_call = predicate_arg
        resolved = self.get_target(predicate_call, ctx)
        if resolved:
            if not (isinstance(resolved, PythonMethod)
                    and resolved.predicate):
                raise InvalidProgramException(node, 'invalid.contract.call')

        predicate_stmts, predicate = self.translate_expr(
            predicate_arg, ctx, self.viper.Bool, True)
        if predicate_stmts:
            raise InvalidProgramException(node, 'purity.violated')
        body_stmts, body = self.translate_expr(body_arg, ctx)
        if body_stmts:
            raise InvalidProgramException(node, 'purity.violated')
        unwrapped_body = self.unwrap(body)
        unfolding = self.viper.Unfolding(predicate, unwrapped_body,
                                         self.to_position(node, ctx),
                                         self.no_info(ctx))
        return body_stmts, unfolding
Пример #3
0
 def translate_unfold(self, node: ast.Call, ctx: Context) -> StmtsAndExpr:
     """
     Translate an ``Unfold(predicate)`` contract call into a Viper
     ``Unfold`` statement.

     Returns ``([unfold], None)``, or ``([], None)`` when the call is a
     family fold and the context asks to ignore those. Raises
     InvalidProgramException with 'invalid.contract.call' for a malformed
     call or a non-predicate target, and with 'purity.violated' if
     translating the predicate produces statements.
     """
     if len(node.args) != 1 or not isinstance(node.args[0], ast.Call):
         raise InvalidProgramException(node, 'invalid.contract.call')
     predicate_arg = node.args[0]

     # Acc(...) and Rd(...) wrap the actual predicate call.
     predicate_call = (predicate_arg.args[0]
                       if get_func_name(predicate_arg) in ('Acc', 'Rd')
                       else predicate_arg)
     resolved = self.get_target(predicate_call, ctx)
     if resolved:
         if not (isinstance(resolved, PythonMethod)
                 and resolved.predicate):
             raise InvalidProgramException(node, 'invalid.contract.call')

     predicate_stmts, predicate = self.translate_expr(
         predicate_arg, ctx, self.viper.Bool, True)
     # A predicate called on a receiver belongs to a family; such unfolds
     # are dropped entirely when the context says to ignore them.
     if self._is_family_fold(node) and ctx.ignore_family_folds:
         return [], None
     if predicate_stmts:
         raise InvalidProgramException(node, 'purity.violated')
     unfold_stmt = self.viper.Unfold(predicate,
                                     self.to_position(node, ctx),
                                     self.no_info(ctx))
     return [unfold_stmt], None
Пример #4
0
 def translate_contract_Call(self, node: ast.Call, ctx: Context) -> Expr:
     """
     Translate a call to one of the contract wrapper functions by
     translating its first argument as a boolean expression.

     Raises UnsupportedException if the called function is not in
     CONTRACT_WRAPPER_FUNCS, and InvalidProgramException with
     'purity.violated' if translating the argument produces statements.
     """
     if get_func_name(node) not in CONTRACT_WRAPPER_FUNCS:
         raise UnsupportedException(node)
     stmts, wrapped = self.translate_expr(node.args[0], ctx,
                                          self.viper.Bool, True)
     if stmts:
         raise InvalidProgramException(node, 'purity.violated')
     return wrapped
Пример #5
0
 def translate_obligation_contractfunc_call(self, node: ast.Call,
                                            ctx: Context) -> StmtsAndExpr:
     """
     Translate a call to an obligation contract function.

     ``MustTerminate`` and ``MustRelease`` are delegated to their
     translators. ``WaitLevel`` and ``Level`` are rejected here with an
     InvalidProgramException, and any other function raises
     UnsupportedException.
     """
     name = get_func_name(node)
     if name == 'MustTerminate':
         return self._translate_must_terminate(node, ctx)
     if name == 'MustRelease':
         return self._translate_must_release(node, ctx)

     # Functions that are known but must not be used in this position.
     misuse_errors = {
         'WaitLevel': 'invalid.wait_level.use',
         'Level': 'invalid.level.use',
     }
     if name in misuse_errors:
         raise InvalidProgramException(node, misuse_errors[name])
     raise UnsupportedException(node, 'Unsupported contract function.')
Пример #6
0
    def translate_perm_Call(self, node: ast.Call, ctx: Context) -> Expr:
        """
        Translate a call that is used as a permission amount.

        ``ARP()`` yields the ARP permission of the current context,
        ``ARP(n)`` becomes an application of the 'rdc' function and
        ``getARP(tk)`` an application of the 'rd_token' function. Every
        other call, including ``ARP``/``getARP`` with a different number of
        arguments, is translated as an ordinary Int expression.

        Raises UnsupportedException if ``ARP``/``getARP`` is used while
        ``ctx.arp`` is not set, and InvalidProgramException with
        'purity.violated' if translating an expression produces statements.
        """
        name = get_func_name(node)
        if name in ('ARP', 'getARP'):
            if not ctx.arp:
                raise UnsupportedException(
                    node, 'ARP not supported. Use --arp flag.')
            arg_count = len(node.args)
            if name == 'ARP' and arg_count == 0:
                return self.get_arp_for_context(node, ctx)
            if name == 'ARP' and arg_count == 1:
                count_stmts, count = self.translate_expr(
                    node.args[0], ctx, self.viper.Int)
                if count_stmts:
                    raise InvalidProgramException(node, 'purity.violated')
                count_decl = self.viper.LocalVarDecl(
                    'count', self.viper.Int, self.to_position(node, ctx),
                    self.no_info(ctx))
                return self.viper.FuncApp('rdc', [count],
                                          self.to_position(node, ctx),
                                          self.no_info(ctx), self.viper.Perm,
                                          [count_decl])
            if name == 'getARP' and arg_count == 1:
                token_decl = self.viper.LocalVarDecl(
                    'tk', self.viper.Ref, self.to_position(node, ctx),
                    self.no_info(ctx))
                token_stmts, token = self.translate_expr(
                    node.args[0], ctx, self.viper.Ref)
                if token_stmts:
                    raise InvalidProgramException(node, 'purity.violated')
                return self.viper.FuncApp('rd_token', [token],
                                          self.to_position(node, ctx),
                                          self.no_info(ctx), self.viper.Perm,
                                          [token_decl])

        # Anything else is treated as a regular call that has to be pure.
        call_stmts, call_expr = self.translate_expr(node, ctx, self.viper.Int)
        if call_stmts:
            raise InvalidProgramException(node, 'purity.violated')
        return call_expr
Пример #7
0
    def translate_io_contractfunc_call(self, node: ast.Call, ctx: Context,
                                       impure: bool,
                                       statement: bool) -> StmtsAndExpr:
        """Translate a call to an IO contract function.

        Currently supported functions:

        +   ``token`` (only when ``impure`` is set)
        +   ``ctoken`` (only when ``impure`` is set)
        +   ``Open`` (only when ``statement`` is set)
        +   ``Eval``
        +   ``eval_io`` (only when ``impure`` is set)

        A supported function used in a position it is not allowed in raises
        InvalidProgramException with 'invalid.contract.position'; any other
        function raises UnsupportedException.
        """
        name = get_func_name(node)

        # Reject known functions that occur in a position they may not be
        # used in before dispatching.
        if name in ('token', 'ctoken', 'eval_io') and not impure:
            raise InvalidProgramException(node,
                                          'invalid.contract.position')
        if name == 'Open' and not statement:
            raise InvalidProgramException(node,
                                          'invalid.contract.position')

        if name == 'token':
            return self.translate_must_invoke_token(node, ctx)
        if name == 'ctoken':
            return self.translate_must_invoke_ctoken(node, ctx)
        if name == 'Open':
            return self._translate_open(node, ctx)
        if name == 'Eval':
            return self._translate_eval(node, ctx)
        if name == 'eval_io':
            return self._translate_eval_io(node, ctx)
        raise UnsupportedException(node, 'Unsupported contract function.')
Пример #8
0
def _string_literal_value(node: ast.AST) -> Optional[str]:
    """
    Returns the text of ``node`` if it is a string literal, otherwise None.

    String literals are ``ast.Constant`` nodes since Python 3.8 and
    ``ast.Str`` nodes before that; ``ast.Str`` no longer exists in recent
    Python versions, so both classes are looked up defensively.
    """
    constant_cls = getattr(ast, 'Constant', None)
    if constant_cls is not None and isinstance(node, constant_cls):
        text = node.value
    else:
        legacy_cls = getattr(ast, 'Str', None)
        if legacy_cls is None or not isinstance(node, legacy_cls):
            return None
        text = node.s
    return text if isinstance(text, str) else None


def _subscript_index(node: ast.Subscript) -> ast.AST:
    """
    Returns the index expression of a subscript node.

    Before Python 3.9 a simple index is wrapped in an ``ast.Index`` node;
    since 3.9 ``node.slice`` is the index expression itself.
    """
    index = node.slice
    index_cls = getattr(ast, 'Index', None)
    if index_cls is not None and isinstance(index, index_cls):
        index = index.value
    return index


def get_target(node: ast.AST,
               containers: List[ContainerInterface],
               container: PythonNode,
               type: bool = False) -> Optional[PythonNode]:
    """
    Finds the PythonNode that the given ``node`` refers to, e.g. a PythonClass
    or a PythonVar, if the immediate container (e.g. a PythonMethod) of the node
    is ``container``, by looking in the given ``containers`` (can be e.g.
    PythonMethods, the Context, PythonModules, etc).
    If the ``type`` parameter is set, will also consider string literals as
    potential references.

    Returns None if the node cannot be resolved, e.g. for calls to ``cast``,
    for attribute accesses on a receiver with a union type, and for node
    kinds that are not handled here.
    """
    if isinstance(node, ast.Name):
        return find_entry(node.id, True, containers)
    if type:
        literal = _string_literal_value(node)
        if literal is not None:
            return find_entry(literal, True, containers)
    if isinstance(node, ast.Call):
        # For calls, we return the type of the result of the call
        func_name = get_func_name(node)
        if (container and func_name == 'Result'
                and isinstance(container, PythonMethod)):
            # In this case the immediate container must be a method, and we
            # return its result type
            return container.type
        if (container and func_name == 'super'
                and isinstance(container, PythonMethod)):
            # Return the type of the current method's superclass
            return container.cls.superclass
        if func_name == 'cast':
            return None
        return get_target(node.func, containers, container)
    if isinstance(node, ast.Attribute):
        # Find the type of the LHS, so that we can look through its members.
        lhs = get_type(node.value, containers, container)
        if isinstance(lhs, OptionalType):
            lhs = lhs.optional_type
        if isinstance(lhs, UnionType):
            # When the receiver's type is a union, a method call has multiple
            # targets, therefore None is returned in such cases
            return None
        if isinstance(lhs, GenericType) and lhs.name == 'type':
            # For direct references to type objects, we want to lookup things
            # defined in the class. So instead of type[C], we want to look in
            # class C directly here.
            lhs = lhs.type_args[0]
        if isinstance(lhs, GenericType):
            # Use the class, since we want to look for members.
            lhs = lhs.cls
        # Now collect all containers we have to look through
        search_scope = []
        if isinstance(lhs, PythonModule):
            # We have to look through all included modules as well, but not
            # through the global one, since it makes no sense to refer to
            # global stuff by looking in a different module
            search_scope.extend(
                lhs.get_included_modules(include_global=False))
        else:
            search_scope.append(lhs)
        while (isinstance(search_scope[-1], PythonClass)
               and search_scope[-1].superclass):
            # If we're looking in a class, add all superclasses as well.
            search_scope.append(search_scope[-1].superclass)
        return find_entry(node.attr, False, search_scope)
    if isinstance(node, ast.Subscript) and isinstance(node.value, ast.Name):
        # This might be a type literal like List[int]
        module = next(cont for cont in containers
                      if isinstance(cont, PythonModule))
        generic_name = node.value.id
        type_class = None
        if generic_name == 'Dict':
            type_class = module.global_module.classes[DICT_TYPE]
        elif generic_name == 'Set':
            type_class = module.global_module.classes[SET_TYPE]
        elif generic_name == 'List':
            type_class = module.global_module.classes[LIST_TYPE]
        elif generic_name == 'Tuple':
            type_class = module.global_module.classes[TUPLE_TYPE]
        if not type_class:
            possible_class = get_target(node.value, containers, container)
            if isinstance(possible_class, PythonType):
                type_class = possible_class
        # Unwrap the index in a way that works on every Python version;
        # reading ``node.slice.value`` directly breaks on Python >= 3.9.
        index = _subscript_index(node)
        if type_class:
            # Look up the type arguments. Also consider string arguments.
            if isinstance(index, ast.Tuple):
                args = [
                    get_target(arg, containers, container, True)
                    for arg in index.elts
                ]
            else:
                args = [get_target(index, containers, container, True)]
            return GenericType(type_class, args)
        if generic_name == 'Optional':
            option = get_target(index, containers, container, True)
            return OptionalType(option)
        if generic_name == 'Union':
            if isinstance(index, ast.Tuple):
                elts = [
                    get_target(elt, containers, container, True)
                    for elt in index.elts
                ]
                return UnionType(elts)
            return get_target(index, containers, container, True)
    return None
Пример #9
0
def _get_call_type(node: ast.Call, module: PythonModule,
                   current_function: PythonMethod,
                   containers: List[ContainerInterface],
                   container: PythonNode) -> PythonType:
    """
    Determines the type of the value produced by the call ``node``.

    Handles ``super``, ``len``, obligation/IO predicates, collection
    literals, ``enumerate``, contract functions, a few builtins, class
    instantiations, plain functions and method calls on a receiver. Returns
    None for ``Declassify`` and for calls whose target cannot be resolved
    here. Raises InvalidProgramException for an invalid ``super`` or ``Let``
    call and UnsupportedException for unsupported calls.
    """

    def builtin_class(type_name):
        # Looked up lazily so that only the branches which need a class from
        # the global module touch it.
        return module.global_module.classes[type_name]

    func_name = get_func_name(node)
    if func_name == 'super':
        if len(node.args) == 2:
            return module.classes[node.args[0].id].superclass
        if not node.args:
            return container.cls.superclass
        raise InvalidProgramException(node, 'invalid.super.call')
    if func_name == 'len':
        return builtin_class(INT_TYPE)
    if func_name in ('token', 'ctoken', 'MustTerminate', 'MustRelease'):
        return builtin_class(BOOL_TYPE)
    for collection_type in (PSEQ_TYPE, PSET_TYPE, PMSET_TYPE):
        if func_name == collection_type:
            return _get_collection_literal_type(node, ['args'],
                                                collection_type, module,
                                                containers, container)
    if func_name == 'enumerate':
        if len(node.args) != 1:
            raise UnsupportedException(
                node, 'enumerate only supported with single arg.')
        list_class = builtin_class(LIST_TYPE)
        int_class = builtin_class(INT_TYPE)
        tuple_class = builtin_class(TUPLE_TYPE)
        iterated_type = get_type(node.args[0], containers, container)
        element_type = _get_iteration_type(iterated_type, module, node)
        pair_type = GenericType(tuple_class, [int_class, element_type])
        return GenericType(list_class, [pair_type])

    callee = node.func
    if isinstance(callee, ast.Name):
        callee_name = callee.id
        if callee_name in CONTRACT_FUNCS:
            if callee_name == 'Result':
                return current_function.type
            if callee_name == 'RaisedException':
                # The first container that has ``var_aliases`` is taken to
                # be the translation context.
                ctx = next((cont for cont in containers
                            if hasattr(cont, 'var_aliases')), None)
                assert ctx
                assert ctx.current_contract_exception is not None
                return ctx.current_contract_exception
            if callee_name in ('Acc', 'Rd', 'Read', 'Implies', 'Forall',
                               'Exists', 'MayCreate', 'MaySet', 'Low',
                               'LowVal', 'LowEvent', 'LowExit'):
                return builtin_class(BOOL_TYPE)
            if callee_name == 'Declassify':
                return None
            if callee_name == 'Old':
                return get_type(node.args[0], containers, container)
            if callee_name == 'Unfolding':
                return get_type(node.args[1], containers, container)
            if callee_name == 'ToSeq':
                source_type = get_type(node.args[0], containers, container)
                seq_class = builtin_class(PSEQ_TYPE)
                content_type = _get_iteration_type(source_type, module, node)
                return GenericType(seq_class, [content_type])
            if callee_name == 'Previous':
                source_type = get_type(node.args[0], containers, container)
                seq_class = builtin_class(PSEQ_TYPE)
                return GenericType(seq_class, [source_type])
            if callee_name in ('getArg', 'getOld', 'getMethod'):
                return builtin_class(OBJECT_TYPE)
            if callee_name == 'Let':
                body_type = get_target(node.args[1], containers, container)
                if isinstance(body_type, PythonType):
                    return body_type
                raise InvalidProgramException(node, 'invalid.let')
            raise UnsupportedException(node)
        if callee_name in BUILTINS:
            if callee_name in ('isinstance', BOOL_TYPE):
                return builtin_class(BOOL_TYPE)
            if callee_name == 'cast':
                return get_target(node.args[0], containers, container)
            raise UnsupportedException(node)
        if callee_name in module.classes:
            return builtin_class(callee_name)
        callee_target = module.get_func_or_method(callee_name)
        if callee_target is not None:
            return callee_target.type
        return None

    if not isinstance(callee, ast.Attribute):
        raise UnsupportedException(node)

    receiver_type = get_type(callee.value, containers, container)
    if isinstance(receiver_type, UnionType):
        # Collect the result type of the method in every class of the union.
        receiver_classes = receiver_type.get_types() - {None}
        result_types = {
            cls.get_func_or_method(callee.attr).type
            for cls in receiver_classes
        }
        if len(result_types) == 1:
            return result_types.pop()
        if len(result_types) == 2 and None in result_types:
            return OptionalType((result_types - {None}).pop())
        return UnionType(list(result_types))
    if isinstance(receiver_type, PythonType):
        method = receiver_type.get_func_or_method(callee.attr)
        if method.generic_type != -1:
            # The result is one of the receiver's type arguments.
            return receiver_type.type_args[method.generic_type]
        return method.type
    return None
Пример #10
0
 def translate_contractfunc_call(self,
                                 node: ast.Call,
                                 ctx: Context,
                                 impure=False,
                                 statement=False) -> StmtsAndExpr:
     """
     Translates calls to contract functions like Result() and Acc() by
     dispatching to the translator responsible for the called function.

     ``Acc``, ``Rd`` and ``Wildcard`` are only accepted when ``impure`` is
     set, and ``Assert``, ``Assume``, ``Fold`` and ``Unfold`` only when
     ``statement`` is set; otherwise InvalidProgramException with
     'invalid.contract.position' is raised. ``getMethod`` and ``arg`` are
     always rejected here, and unknown functions raise
     UnsupportedException.
     """
     name = get_func_name(node)
     if name == 'Result':
         return self.translate_result(node, ctx)
     if name == 'RaisedException':
         return self.translate_raised_exception(node, ctx)

     if name in ('Acc', 'Rd', 'Wildcard'):
         if not impure:
             raise InvalidProgramException(node,
                                           'invalid.contract.position')
         if name == 'Rd':
             perm = self.get_arp_for_context(node, ctx)
         elif name == 'Wildcard':
             perm = self.viper.WildcardPerm(self.to_position(node, ctx),
                                            self.no_info(ctx))
         else:
             perm = self._get_perm(node, ctx)

         resource = node.args[0]
         if isinstance(resource, ast.Call):
             return self.translate_acc_predicate(node, perm, ctx)
         if isinstance(resource, ast.Attribute):
             receiver_type = self.get_type(resource.value, ctx)
             if isinstance(receiver_type, UnionType):
                 # Build one field access per class in the union, each
                 # paired with a check of the receiver's type.
                 guarded_accesses = []
                 stmt, receiver = self.translate_expr(resource.value, ctx)
                 member_classes = receiver_type.get_types() - {None}
                 for recv_type in toposort_classes(member_classes):
                     receiver_target = self.get_target(resource.value, ctx)
                     guard = self.var_type_check(
                         receiver_target.sil_name, recv_type,
                         self.to_position(node, ctx), ctx)
                     field = recv_type.get_field(
                         resource.attr).actual_field
                     location = self.viper.FieldAccess(
                         receiver, field.sil_field,
                         self.to_position(node, ctx), self.no_info(ctx))
                     access = self._translate_acc_field(
                         location, field.type, perm,
                         self.to_position(node, ctx), ctx)
                     guarded_accesses.append((guard, access))
                 if len(guarded_accesses) == 1:
                     _, only_access = guarded_accesses[0]
                     return stmt, only_access
                 combined = chain_cond_exp(guarded_accesses, self.viper,
                                           self.to_position(node, ctx),
                                           self.no_info(ctx), ctx)
                 return stmt, combined
         target = self.get_target(resource, ctx)
         if isinstance(target, PythonField):
             return self.translate_acc_field(node, perm, ctx)
         if not isinstance(target, PythonGlobalVar):
             raise InvalidProgramException(node, 'invalid.acc')
         return self.translate_acc_global(node, perm, ctx)

     if name in BUILTIN_PREDICATES:
         return [], self.translate_unwrapped_builtin_predicate(node, ctx)
     if name == 'MaySet':
         return self.translate_may_set(node, ctx)
     if name == 'MayCreate':
         return self.translate_may_create(node, ctx)

     if name in ('Assert', 'Assume', 'Fold', 'Unfold'):
         if not statement:
             raise InvalidProgramException(node,
                                           'invalid.contract.position')
         if name == 'Assert':
             return self.translate_assert(node, ctx)
         if name == 'Assume':
             return self.translate_assume(node, ctx)
         if name == 'Fold':
             return self.translate_fold(node, ctx)
         return self.translate_unfold(node, ctx)

     # Remaining functions, checked in this order; the first match wins.
     # Each entry is (contract function, translator method name, whether
     # ``impure`` is passed on as a third argument).
     ordered_handlers = (
         ('Implies', 'translate_implies', True),
         ('Old', 'translate_old', False),
         ('Unfolding', 'translate_unfolding', True),
         ('Low', 'translate_low', False),
         ('LowVal', 'translate_lowval', False),
         ('LowEvent', 'translate_lowevent', False),
         ('LowExit', 'translate_lowexit', False),
         ('Declassify', 'translate_declassify', False),
         ('TerminatesSif', 'translate_terminates_sif', False),
         ('Forall', 'translate_forall', True),
         ('Previous', 'translate_previous', False),
         ('Let', 'translate_let', True),
         (PSEQ_TYPE, 'translate_sequence', False),
         (PSET_TYPE, 'translate_pset', False),
         (PMSET_TYPE, 'translate_mset', False),
         ('ToSeq', 'translate_to_sequence', False),
         ('Joinable', 'translate_joinable', False),
         ('getArg', 'translate_get_arg', False),
         ('getOld', 'translate_get_old', False),
     )
     for handled_name, method_name, passes_impure in ordered_handlers:
         if name == handled_name:
             translate = getattr(self, method_name)
             if passes_impure:
                 return translate(node, ctx, impure)
             return translate(node, ctx)

     if name == 'getMethod':
         raise InvalidProgramException(node, 'invalid.get.method.use')
     if name == 'arg':
         raise InvalidProgramException(node, 'invalid.arg.use')
     raise UnsupportedException(node)