def _create_quantifier_contains_expr(self, e: Expr, domain_node: ast.AST,
                                     ctx: Context, trigger=False
                                     ) -> Tuple[List[Stmt], Expr, bool]:
    """
    Build the condition "``e`` is an element of the domain", i.e. the
    left-hand side of the implication inside a quantifier.

    :param e: the quantified variable.
    :param domain_node: AST of the quantifier's domain; it may be wrapped in
        ``Old(...)``, in which case the resulting condition is wrapped in a
        Viper ``Old`` expression.
    :param ctx: the translation context.
    :param trigger: forwarded unchanged to ``get_quantifier_lhs``.
    :return: ``(stmts, condition, use_as_trigger)``; the last element says
        whether ``condition`` is recommended for use as a trigger.
    """
    is_old_domain = (isinstance(domain_node, ast.Call) and
                     get_func_name(domain_node) == 'Old')
    if is_old_domain:
        # Old(dom): work on dom itself and re-wrap the condition at the end.
        domain_node = domain_node.args[0]
    element = self.to_ref(e, ctx)
    position = self.to_position(domain_node, ctx)
    info = self.no_info(ctx)
    domain_target = self.get_target(domain_node, ctx)
    if isinstance(domain_target, PythonType):
        # Quantifying over a type: membership is a plain type check. It is
        # not suggested as a trigger because it is very broad and would
        # fire far too often.
        type_condition = self.type_check(element, domain_target, position,
                                         ctx, False)
        return [], type_condition, False
    stmts, domain = self.translate_expr(domain_node, ctx)
    domain_type = self.get_type(domain_node, ctx)
    condition = self.get_quantifier_lhs(element, domain_type, domain,
                                        domain_node, ctx, position, trigger)
    if is_old_domain:
        condition = self.viper.Old(condition, position, info)
    return stmts, condition, True
def translate_unfolding(self, node: ast.Call, ctx: Context,
                        impure=False) -> StmtsAndExpr:
    """
    Translates a call to the Unfolding() contract function.

    ``Unfolding(pred, expr)`` becomes a Viper ``Unfolding`` expression.
    ``pred`` may be wrapped in ``Acc(...)`` or ``Rd(...)``. Both arguments
    must translate without auxiliary statements.

    :param node: the ``Unfolding(...)`` call.
    :param ctx: the translation context.
    :param impure: accepted for interface compatibility with the other
        contract translators; it is not used here.
    :return: an empty statement list and the Viper Unfolding expression.
    :raises InvalidProgramException: ``'invalid.contract.call'`` if the call
        does not have exactly two arguments, if its first argument is not a
        call (or an argument-less ``Acc``/``Rd``), or if it does not refer to
        a predicate; ``'purity.violated'`` if either argument requires
        statements.
    """
    if len(node.args) != 2:
        raise InvalidProgramException(node, 'invalid.contract.call')
    pred_arg = node.args[0]
    if not isinstance(pred_arg, ast.Call):
        raise InvalidProgramException(node, 'invalid.contract.call')
    if get_func_name(pred_arg) in ('Acc', 'Rd'):
        # Acc(pred(...)) / Rd(pred(...)): the predicate is the wrapped call.
        # Reject an argument-less wrapper with a proper user-facing error
        # rather than crashing with an IndexError.
        if not pred_arg.args:
            raise InvalidProgramException(node, 'invalid.contract.call')
        pred_call = pred_arg.args[0]
    else:
        pred_call = pred_arg
    target_pred = self.get_target(pred_call, ctx)
    if (target_pred and
            (not isinstance(target_pred, PythonMethod) or
             not target_pred.predicate)):
        raise InvalidProgramException(node, 'invalid.contract.call')
    pred_stmt, pred = self.translate_expr(pred_arg, ctx, self.viper.Bool,
                                          True)
    if pred_stmt:
        raise InvalidProgramException(node, 'purity.violated')
    expr_stmt, expr = self.translate_expr(node.args[1], ctx)
    if expr_stmt:
        raise InvalidProgramException(node, 'purity.violated')
    expr = self.unwrap(expr)
    unfold = self.viper.Unfolding(pred, expr, self.to_position(node, ctx),
                                  self.no_info(ctx))
    # Both argument translations are known to be statement-free here.
    return [], unfold
def translate_unfold(self, node: ast.Call, ctx: Context) -> StmtsAndExpr:
    """
    Translates a call to the Unfold() contract function.

    ``Unfold(pred)`` becomes a Viper ``Unfold`` statement. ``pred`` may be
    wrapped in ``Acc(...)`` or ``Rd(...)``.

    :param node: the ``Unfold(...)`` call.
    :param ctx: the translation context.
    :return: ``([unfold_stmt], None)``; or ``([], None)`` when the unfold is
        a family fold (see ``_is_family_fold``) and
        ``ctx.ignore_family_folds`` is set.
    :raises InvalidProgramException: ``'invalid.contract.call'`` if the call
        does not have exactly one argument, if that argument is not a call
        (or is an argument-less ``Acc``/``Rd``), or if it does not refer to a
        predicate; ``'purity.violated'`` if translating the predicate
        produces statements.
    """
    if len(node.args) != 1:
        raise InvalidProgramException(node, 'invalid.contract.call')
    pred_arg = node.args[0]
    if not isinstance(pred_arg, ast.Call):
        raise InvalidProgramException(node, 'invalid.contract.call')
    if get_func_name(pred_arg) in ('Acc', 'Rd'):
        # Acc(pred(...)) / Rd(pred(...)): the predicate is the wrapped call.
        # Reject an argument-less wrapper with a proper user-facing error
        # rather than crashing with an IndexError.
        if not pred_arg.args:
            raise InvalidProgramException(node, 'invalid.contract.call')
        pred_call = pred_arg.args[0]
    else:
        pred_call = pred_arg
    target_pred = self.get_target(pred_call, ctx)
    if (target_pred and
            (not isinstance(target_pred, PythonMethod) or
             not target_pred.predicate)):
        raise InvalidProgramException(node, 'invalid.contract.call')
    pred_stmt, pred = self.translate_expr(pred_arg, ctx, self.viper.Bool,
                                          True)
    # Predicate called on receiver, so it belongs to a family; such unfolds
    # are dropped entirely when the context asks to ignore family folds.
    if self._is_family_fold(node) and ctx.ignore_family_folds:
        return [], None
    if pred_stmt:
        raise InvalidProgramException(node, 'purity.violated')
    unfold = self.viper.Unfold(pred, self.to_position(node, ctx),
                               self.no_info(ctx))
    return [unfold], None
def translate_contract_Call(self, node: ast.Call, ctx: Context) -> Expr:
    """
    Translate a call to a contract wrapper function (a name listed in
    ``CONTRACT_WRAPPER_FUNCS``) by translating its first argument as a
    ``Bool`` expression.

    The trailing ``True`` passed to ``translate_expr`` presumably permits
    impure (permission-bearing) expressions — TODO confirm.

    :raises UnsupportedException: if the callee is not a contract wrapper.
    :raises InvalidProgramException: ``'purity.violated'`` if the wrapped
        argument needs auxiliary statements.
    """
    if get_func_name(node) not in CONTRACT_WRAPPER_FUNCS:
        raise UnsupportedException(node)
    wrapped_stmts, wrapped = self.translate_expr(node.args[0], ctx,
                                                 self.viper.Bool, True)
    if wrapped_stmts:
        raise InvalidProgramException(node, 'purity.violated')
    return wrapped
def translate_obligation_contractfunc_call(self, node: ast.Call,
                                           ctx: Context) -> StmtsAndExpr:
    """
    Translate a call to an obligation contract function.

    ``MustTerminate`` and ``MustRelease`` are delegated to their dedicated
    translators. ``WaitLevel`` and ``Level`` are rejected with their own
    error codes, and any other name is reported as unsupported.

    :raises InvalidProgramException: for ``WaitLevel``/``Level`` used here.
    :raises UnsupportedException: for any other function name.
    """
    func_name = get_func_name(node)
    if func_name == 'MustTerminate':
        return self._translate_must_terminate(node, ctx)
    if func_name == 'MustRelease':
        return self._translate_must_release(node, ctx)
    # Obligation functions that may not appear in this position.
    misuse_codes = {
        'WaitLevel': 'invalid.wait_level.use',
        'Level': 'invalid.level.use',
    }
    if func_name in misuse_codes:
        raise InvalidProgramException(node, misuse_codes[func_name])
    raise UnsupportedException(node, 'Unsupported contract function.')
def translate_perm_Call(self, node: ast.Call, ctx: Context) -> Expr:
    """
    Translate a call that denotes a permission amount.

    * ``ARP()`` yields the ARP permission for the current context.
    * ``ARP(n)`` applies the Viper function ``rdc`` to ``n``.
    * ``getARP(tk)`` applies the Viper function ``rd_token`` to ``tk``.

    ``ARP``/``getARP`` require ``ctx.arp``. Every other call, including
    ``ARP``/``getARP`` with an argument count not listed above, is translated
    as an ordinary ``Int`` expression, which must be pure.

    :raises UnsupportedException: if ``ARP``/``getARP`` is used while
        ``ctx.arp`` is not set.
    :raises InvalidProgramException: ``'purity.violated'`` if an argument or
        the call itself requires auxiliary statements.
    """
    name = get_func_name(node)
    is_arp = name == 'ARP'
    is_get_arp = name == 'getARP'
    if (is_arp or is_get_arp) and not ctx.arp:
        raise UnsupportedException(
            node, 'ARP not supported. Use --arp flag.')
    arg_count = len(node.args)

    if is_arp and arg_count == 0:
        return self.get_arp_for_context(node, ctx)

    if is_arp and arg_count == 1:
        count_stmts, count = self.translate_expr(node.args[0], ctx,
                                                 self.viper.Int)
        if count_stmts:
            raise InvalidProgramException(node, 'purity.violated')
        count_decl = self.viper.LocalVarDecl(
            'count', self.viper.Int, self.to_position(node, ctx),
            self.no_info(ctx))
        return self.viper.FuncApp('rdc', [count], self.to_position(node, ctx),
                                  self.no_info(ctx), self.viper.Perm,
                                  [count_decl])

    if is_get_arp and arg_count == 1:
        token_decl = self.viper.LocalVarDecl(
            'tk', self.viper.Ref, self.to_position(node, ctx),
            self.no_info(ctx))
        token_stmts, token = self.translate_expr(node.args[0], ctx,
                                                 self.viper.Ref)
        if token_stmts:
            raise InvalidProgramException(node, 'purity.violated')
        return self.viper.FuncApp('rd_token', [token],
                                  self.to_position(node, ctx),
                                  self.no_info(ctx), self.viper.Perm,
                                  [token_decl])

    # Fallback: an ordinary (pure) integer-valued call.
    call_stmts, call = self.translate_expr(node, ctx, self.viper.Int)
    if call_stmts:
        raise InvalidProgramException(node, 'purity.violated')
    return call
def translate_io_contractfunc_call(self, node: ast.Call, ctx: Context,
                                   impure: bool,
                                   statement: bool) -> StmtsAndExpr:
    """Translate a call to an IO contract function.

    Supported functions and where they may appear:

    + ``token``, ``ctoken``, ``eval_io`` -- only where ``impure`` is set
    + ``Open`` -- only where ``statement`` is set
    + ``Eval`` -- anywhere

    :raises InvalidProgramException: ``'invalid.contract.position'`` if a
        function is used in a position it is not allowed in.
    :raises UnsupportedException: for any other function name.
    """
    name = get_func_name(node)
    if name in ('token', 'ctoken', 'eval_io') and not impure:
        raise InvalidProgramException(node, 'invalid.contract.position')
    if name == 'Open' and not statement:
        raise InvalidProgramException(node, 'invalid.contract.position')

    if name == 'token':
        return self.translate_must_invoke_token(node, ctx)
    if name == 'ctoken':
        return self.translate_must_invoke_ctoken(node, ctx)
    if name == 'Open':
        return self._translate_open(node, ctx)
    if name == 'Eval':
        return self._translate_eval(node, ctx)
    if name == 'eval_io':
        return self._translate_eval_io(node, ctx)
    raise UnsupportedException(node, 'Unsupported contract function.')
def get_target(node: ast.AST,
               containers: List[ContainerInterface],
               container: PythonNode,
               type: bool = False) -> Optional[PythonNode]:
    """
    Finds the PythonNode that the given ``node`` refers to, e.g. a PythonClass
    or a PythonVar, if the immediate container (e.g. a PythonMethod) of the
    node is ``container``, by looking in the given ``containers`` (can be e.g.
    PythonMethods, the Context, PythonModules, etc). If the ``type`` parameter
    is set, will also consider string literals as potential references.

    Handled node kinds:

    * ``ast.Name`` (and ``ast.Str`` when ``type`` is set): looked up by name
      via ``find_entry`` in ``containers``.
    * ``ast.Call``: ``Result()`` / ``super()`` inside a method resolve to the
      method's result type / its class's superclass; ``cast(...)`` yields
      ``None``; any other call resolves whatever its callee expression
      (``node.func``) refers to.
    * ``ast.Attribute``: the attribute is looked up among the members of the
      receiver's type (plus its superclasses), or in a module and the
      modules it includes.
    * ``ast.Subscript`` with a plain name on the left: type literals such as
      ``List[int]``, ``MyGeneric[T]``, ``Optional[T]`` and ``Union[A, B]``.

    :return: the resolved node, or ``None`` when it cannot be resolved (for
        example an attribute on a union-typed receiver, or a node kind not
        listed above).
    :raises StopIteration: for a name-based subscript when ``containers``
        holds no PythonModule (``next`` is called without a default).
    """
    if isinstance(node, ast.Name):
        return find_entry(node.id, True, containers)
    elif type and isinstance(node, ast.Str):
        # String literal used as a forward type reference, e.g. 'MyClass'.
        return find_entry(node.s, True, containers)
    elif isinstance(node, ast.Call):
        # For calls, we return the type of the result of the call
        # (for plain calls, the target of the callee expression itself is
        # returned, e.g. the class for a constructor call).
        func_name = get_func_name(node)
        if (container and func_name == 'Result' and
                isinstance(container, PythonMethod)):
            # In this case the immediate container must be a method, and we
            # return its result type
            return container.type
        elif (container and func_name == 'super' and
                isinstance(container, PythonMethod)):
            # Return the type of the current method's superclass
            return container.cls.superclass
        elif func_name == 'cast':
            return None
        return get_target(node.func, containers, container)
    elif isinstance(node, ast.Attribute):
        # Find the type of the LHS, so that we can look through its members.
        lhs = get_type(node.value, containers, container)
        if isinstance(lhs, OptionalType):
            # Optional[C]: look up members in C itself.
            lhs = lhs.optional_type
        if isinstance(lhs, UnionType):
            # When receiver's type is union, a method call have multiple
            # targets, therefore None is returned in such cases
            return None
        if isinstance(lhs, GenericType) and lhs.name == 'type':
            # For direct references to type objects, we want to lookup things
            # defined in the class. So instead of type[C], we want to look in
            # class C directly here.
            lhs = lhs.type_args[0]
        if isinstance(lhs, GenericType):
            # Use the class, since we want to look for members.
            lhs = lhs.cls
        # Now collect all containers we have to look through
        # (this deliberately replaces the caller's scopes for the lookup).
        containers = []
        if isinstance(lhs, PythonModule):
            # We have to look through all included modules as well, but not
            # through the global one, since it makes no sense to refer to
            # global stuff by looking in a different module
            containers.extend(lhs.get_included_modules(include_global=False))
        else:
            containers.append(lhs)
        while (isinstance(containers[-1], PythonClass) and
               containers[-1].superclass):
            # If we're looking in a class, add all superclasses as well.
            containers.append(containers[-1].superclass)
        return find_entry(node.attr, False, containers)
    elif isinstance(node, ast.Subscript):
        # This might be a type literal like List[int]
        if isinstance(node.value, ast.Name):
            # Builtin generic classes are taken from the global module of the
            # first module found among the containers.
            module = next(cont for cont in containers
                          if isinstance(cont, PythonModule))
            type_class = None
            if node.value.id == 'Dict':
                type_class = module.global_module.classes[DICT_TYPE]
            if node.value.id == 'Set':
                type_class = module.global_module.classes[SET_TYPE]
            if node.value.id == 'List':
                type_class = module.global_module.classes[LIST_TYPE]
            if node.value.id == 'Tuple':
                type_class = module.global_module.classes[TUPLE_TYPE]
            if not type_class:
                # Possibly a user-defined generic class, e.g. MyGeneric[int].
                # This lookup also runs for Optional/Union; its result is only
                # used if it is a PythonType.
                possible_class = get_target(node.value, containers, container)
                if isinstance(possible_class, PythonType):
                    type_class = possible_class
            if type_class:
                # Look up the type arguments. Also consider string arguments.
                if isinstance(node.slice.value, ast.Tuple):
                    args = [
                        get_target(arg, containers, container, True)
                        for arg in node.slice.value.elts
                    ]
                else:
                    args = [
                        get_target(node.slice.value, containers, container,
                                   True)
                    ]
                return GenericType(type_class, args)
            if node.value.id == 'Optional':
                option = get_target(node.slice.value, containers, container,
                                    True)
                return OptionalType(option)
            if node.value.id == 'Union':
                if isinstance(node.slice.value, ast.Tuple):
                    elts = [
                        get_target(e, containers, container, True)
                        for e in node.slice.value.elts
                    ]
                    return UnionType(elts)
                else:
                    # Union[X] with a single argument is just X.
                    return get_target(node.slice.value, containers,
                                      container, True)
        # Any other subscript falls through and yields None implicitly.
    else:
        return None
def _get_call_type(node: ast.Call, module: PythonModule,
                   current_function: PythonMethod,
                   containers: List[ContainerInterface],
                   container: PythonNode) -> Optional[PythonType]:
    """
    Returns the static type of the value produced by the call ``node``.

    :param node: the call expression.
    :param module: the module the call occurs in; builtin classes are taken
        from ``module.global_module``, user classes and functions from the
        module itself.
    :param current_function: the function containing the call; its type is
        the type of ``Result()``.
    :param containers: lookup scopes forwarded to ``get_type``/``get_target``.
    :param container: the immediate container of ``node``; for a
        zero-argument ``super()`` it must provide ``cls``.
    :return: the result type, or ``None`` if it cannot be determined (e.g.
        ``Declassify(...)`` or a call to a name that is neither a class nor a
        function of ``module``).
    :raises InvalidProgramException: for ``super`` with one or more than two
        arguments, or a ``Let`` whose body does not denote a type.
    :raises UnsupportedException: for unsupported contract functions or
        builtins, ``enumerate`` with other than one argument, or a method call
        on a receiver that is neither a union nor a PythonType.
    """
    func_name = get_func_name(node)
    if func_name == 'super':
        if len(node.args) == 2:
            # super(C, self): the superclass of the explicitly named class.
            return module.classes[node.args[0].id].superclass
        elif not node.args:
            return container.cls.superclass
        else:
            raise InvalidProgramException(node, 'invalid.super.call')
    if func_name == 'len':
        return module.global_module.classes[INT_TYPE]
    if func_name in ('token', 'ctoken', 'MustTerminate', 'MustRelease'):
        return module.global_module.classes[BOOL_TYPE]
    if func_name in (PSEQ_TYPE, PSET_TYPE, PMSET_TYPE):
        # Collection literal such as PSeq(1, 2, 3); func_name equals the
        # matching *_TYPE constant here.
        return _get_collection_literal_type(node, ['args'], func_name, module,
                                            containers, container)
    if func_name == 'enumerate':
        if len(node.args) != 1:
            raise UnsupportedException(
                node, 'enumerate only supported with single arg.')
        list_type = module.global_module.classes[LIST_TYPE]
        int_type = module.global_module.classes[INT_TYPE]
        tuple_type = module.global_module.classes[TUPLE_TYPE]
        arg_type = get_type(node.args[0], containers, container)
        iterable_type = _get_iteration_type(arg_type, module, node)
        # enumerate(xs) is modelled as list[tuple[int, element type]].
        return GenericType(
            list_type, [GenericType(tuple_type, [int_type, iterable_type])])
    if isinstance(node.func, ast.Name):
        if node.func.id in CONTRACT_FUNCS:
            if node.func.id == 'Result':
                return current_function.type
            elif node.func.id == 'RaisedException':
                # The translation Context is recognised among the containers
                # by its ``var_aliases`` attribute.
                ctxs = [
                    cont for cont in containers
                    if hasattr(cont, 'var_aliases')
                ]
                ctx = ctxs[0] if ctxs else None
                assert ctx
                assert ctx.current_contract_exception is not None
                return ctx.current_contract_exception
            elif node.func.id in ('Acc', 'Rd', 'Read', 'Implies', 'Forall',
                                  'Exists', 'MayCreate', 'MaySet', 'Low',
                                  'LowVal', 'LowEvent', 'LowExit'):
                return module.global_module.classes[BOOL_TYPE]
            elif node.func.id == 'Declassify':
                return None
            elif node.func.id == 'Old':
                return get_type(node.args[0], containers, container)
            elif node.func.id == 'Unfolding':
                # Unfolding(pred, expr) has the type of expr.
                return get_type(node.args[1], containers, container)
            elif node.func.id == 'ToSeq':
                arg_type = get_type(node.args[0], containers, container)
                seq_class = module.global_module.classes[PSEQ_TYPE]
                content_type = _get_iteration_type(arg_type, module, node)
                return GenericType(seq_class, [content_type])
            elif node.func.id == 'Previous':
                # Previous(x) is a PSeq of values with x's type.
                arg_type = get_type(node.args[0], containers, container)
                seq_class = module.global_module.classes[PSEQ_TYPE]
                return GenericType(seq_class, [arg_type])
            elif node.func.id in ('getArg', 'getOld', 'getMethod'):
                object_class = module.global_module.classes[OBJECT_TYPE]
                return object_class
            elif node.func.id == 'Let':
                body_type = get_target(node.args[1], containers, container)
                if isinstance(body_type, PythonType):
                    return body_type
                raise InvalidProgramException(node, 'invalid.let')
            else:
                raise UnsupportedException(node)
        elif node.func.id in BUILTINS:
            if node.func.id in ('isinstance', BOOL_TYPE):
                return module.global_module.classes[BOOL_TYPE]
            elif node.func.id == 'cast':
                return get_target(node.args[0], containers, container)
            else:
                raise UnsupportedException(node)
        if node.func.id in module.classes:
            # Constructor call of a class defined in this module. The lookup
            # must use the same dictionary as the membership test above;
            # reading ``module.global_module.classes`` here raised KeyError
            # for every user-defined class.
            return module.classes[node.func.id]
        target = module.get_func_or_method(node.func.id)
        if target is not None:
            return target.type
    elif isinstance(node.func, ast.Attribute):
        rectype = get_type(node.func.value, containers, container)
        if isinstance(rectype, UnionType):
            # Method call on a union-typed receiver: combine the method's
            # return types over all non-None member classes.
            set_of_classes = rectype.get_types() - {None}
            set_of_return_types = {
                cls.get_func_or_method(node.func.attr).type
                for cls in set_of_classes
            }
            if len(set_of_return_types) == 1:
                return set_of_return_types.pop()
            elif len(set_of_return_types) == 2 and None in set_of_return_types:
                return OptionalType((set_of_return_types - {None}).pop())
            else:
                return UnionType(list(set_of_return_types))
        elif isinstance(rectype, PythonType):
            target = rectype.get_func_or_method(node.func.attr)
            if target.generic_type != -1:
                # The method returns one of the receiver's type arguments,
                # identified by its index.
                return rectype.type_args[target.generic_type]
            else:
                return target.type
        else:
            raise UnsupportedException(node)
    return None
def translate_contractfunc_call(self, node: ast.Call, ctx: Context,
                                impure=False,
                                statement=False) -> StmtsAndExpr:
    """
    Translate a call to a contract function such as Result() or Acc().

    :param node: the contract function call.
    :param ctx: the translation context.
    :param impure: whether the call is in a position where permission
        assertions are allowed; ``Acc``/``Rd``/``Wildcard`` require it, and it
        is forwarded to ``Implies``, ``Unfolding``, ``Forall`` and ``Let``.
    :param statement: whether the call is used as a statement; required by
        ``Assert``, ``Assume``, ``Fold`` and ``Unfold``.
    :raises InvalidProgramException: ``'invalid.contract.position'`` for a
        call in a disallowed position, ``'invalid.acc'`` for an ``Acc``
        target that is neither a field nor a global variable, and dedicated
        codes for misuse of ``getMethod`` / ``arg``.
    :raises UnsupportedException: for any other function name.
    """
    func_name = get_func_name(node)
    if func_name == 'Result':
        return self.translate_result(node, ctx)
    if func_name == 'RaisedException':
        return self.translate_raised_exception(node, ctx)
    if func_name in ('Acc', 'Rd', 'Wildcard'):
        return self._translate_acc_contract_call(node, func_name, ctx, impure)
    if func_name in BUILTIN_PREDICATES:
        return [], self.translate_unwrapped_builtin_predicate(node, ctx)
    if func_name in ('Assert', 'Assume', 'Fold', 'Unfold') and not statement:
        raise InvalidProgramException(node, 'invalid.contract.position')

    # Translators that also receive the ``impure`` flag.
    impure_translators = {
        'Implies': 'translate_implies',
        'Unfolding': 'translate_unfolding',
        'Forall': 'translate_forall',
        'Let': 'translate_let',
    }
    # Translators that only take ``(node, ctx)``.
    plain_translators = {
        'MaySet': 'translate_may_set',
        'MayCreate': 'translate_may_create',
        'Assert': 'translate_assert',
        'Assume': 'translate_assume',
        'Fold': 'translate_fold',
        'Unfold': 'translate_unfold',
        'Old': 'translate_old',
        'Low': 'translate_low',
        'LowVal': 'translate_lowval',
        'LowEvent': 'translate_lowevent',
        'LowExit': 'translate_lowexit',
        'Declassify': 'translate_declassify',
        'TerminatesSif': 'translate_terminates_sif',
        'Previous': 'translate_previous',
        PSEQ_TYPE: 'translate_sequence',
        PSET_TYPE: 'translate_pset',
        PMSET_TYPE: 'translate_mset',
        'ToSeq': 'translate_to_sequence',
        'Joinable': 'translate_joinable',
        'getArg': 'translate_get_arg',
        'getOld': 'translate_get_old',
    }
    # Contract functions that are never valid in this position.
    misuse_codes = {
        'getMethod': 'invalid.get.method.use',
        'arg': 'invalid.arg.use',
    }
    if func_name in impure_translators:
        translator = getattr(self, impure_translators[func_name])
        return translator(node, ctx, impure)
    if func_name in plain_translators:
        translator = getattr(self, plain_translators[func_name])
        return translator(node, ctx)
    if func_name in misuse_codes:
        raise InvalidProgramException(node, misuse_codes[func_name])
    raise UnsupportedException(node)


def _translate_acc_contract_call(self, node: ast.Call, func_name: str,
                                 ctx: Context, impure: bool) -> StmtsAndExpr:
    """
    Translate ``Acc(...)``, ``Rd(...)`` or ``Wildcard(...)`` for
    ``translate_contractfunc_call``.

    The first argument is either a predicate call, a field access (possibly
    on a union-typed receiver), or a global variable.
    """
    if not impure:
        raise InvalidProgramException(node, 'invalid.contract.position')
    if func_name == 'Rd':
        perm = self.get_arp_for_context(node, ctx)
    elif func_name == 'Wildcard':
        perm = self.viper.WildcardPerm(self.to_position(node, ctx),
                                       self.no_info(ctx))
    else:
        perm = self._get_perm(node, ctx)

    location = node.args[0]
    if isinstance(location, ast.Call):
        return self.translate_acc_predicate(node, perm, ctx)

    if isinstance(location, ast.Attribute):
        receiver_node = location.value
        receiver_type = self.get_type(receiver_node, ctx)
        if isinstance(receiver_type, UnionType):
            # One permission per possible receiver class, each guarded by a
            # type check on the receiver variable.
            stmt, receiver = self.translate_expr(receiver_node, ctx)
            guarded = []
            for recv_type in toposort_classes(
                    receiver_type.get_types() - {None}):
                receiver_var = self.get_target(receiver_node, ctx)
                guard = self.var_type_check(receiver_var.sil_name, recv_type,
                                            self.to_position(node, ctx), ctx)
                field = recv_type.get_field(location.attr).actual_field
                access = self.viper.FieldAccess(receiver, field.sil_field,
                                                self.to_position(node, ctx),
                                                self.no_info(ctx))
                acc = self._translate_acc_field(access, field.type, perm,
                                                self.to_position(node, ctx),
                                                ctx)
                guarded.append((guard, acc))
            if len(guarded) == 1:
                return stmt, guarded[0][1]
            return stmt, chain_cond_exp(guarded, self.viper,
                                        self.to_position(node, ctx),
                                        self.no_info(ctx), ctx)

    target = self.get_target(location, ctx)
    if isinstance(target, PythonField):
        return self.translate_acc_field(node, perm, ctx)
    if not isinstance(target, PythonGlobalVar):
        raise InvalidProgramException(node, 'invalid.acc')
    return self.translate_acc_global(node, perm, ctx)