def get_arp_for_context(self, node: ast.AST, ctx: Context):
    """Return the abstract read permission (ARP) amount to use at ``node``.

    Inside a pure method the amount is a wildcard permission. Everywhere
    else ARP support must be enabled (``ctx.arp``). If a current thread
    object is set, the result applies ``rd_token_fresh`` (when
    ``ctx.is_thread_start``) or ``rd_token`` to that object. Otherwise it
    applies the nullary ``rd`` function.

    :raises UnsupportedException: if ARP is needed but ``ctx.arp`` is off.
    """
    function = ctx.actual_function
    in_pure_method = (function and isinstance(function, PythonMethod) and
                      function.pure)
    if in_pure_method:
        return self.viper.WildcardPerm(self.to_position(node, ctx),
                                       self.no_info(ctx))
    if not ctx.arp:
        raise UnsupportedException(
            node, 'ARP not supported. Use --arp flag.')
    thread_obj = ctx.current_thread_object
    if thread_obj is None:
        return self.viper.FuncApp('rd', [], self.to_position(node, ctx),
                                  self.no_info(ctx), self.viper.Perm, {})
    # Token-based read permission parameterised by the thread object.
    token_param = self.viper.LocalVarDecl(
        'tk', self.viper.Ref, self.to_position(node, ctx), self.no_info(ctx))
    rd_func = 'rd_token_fresh' if ctx.is_thread_start else 'rd_token'
    return self.viper.FuncApp(rd_func, [thread_obj],
                              self.to_position(node, ctx), self.no_info(ctx),
                              self.viper.Perm, [token_param])
def translate_builtin_predicate(self, node: ast.Call, perm: Expr,
                                args: List[Expr], ctx: Context) -> Expr:
    """Translate an access to one of the builtin predicates.

    ``list_pred``/``set_pred`` become a permission to a single field of the
    first argument. ``dict_pred`` becomes the conjunction of permissions to
    ``dict_acc`` and ``dict_acc2``. ``MayStart`` and ``ThreadPost`` are
    delegated to their dedicated translators.

    :raises UnsupportedException: for any other predicate name.
    """
    pred_name = node.func.id
    seq_ref = self.viper.SeqType(self.viper.Ref)
    set_ref = self.viper.SetType(self.viper.Ref)
    pos = self.to_position(node, ctx)
    # Predicates backed by exactly one field: name -> (field, field type).
    single_field_preds = {
        'list_pred': ('list_acc', seq_ref),  # field list_acc : Seq[Ref]
        'set_pred': ('set_acc', set_ref),  # field set_acc : Set[Ref]
    }
    if pred_name in single_field_preds:
        field_name, field_type = single_field_preds[pred_name]
        return self._get_field_perm(field_name, field_type, perm, args[0],
                                    pos, ctx)
    if pred_name == 'dict_pred':
        # field dict_acc : Set[Ref] && dict_acc2 : Ref
        keys_acc = self._get_field_perm('dict_acc', set_ref, perm, args[0],
                                        pos, ctx)
        extra_acc = self._get_field_perm('dict_acc2', self.viper.Ref, perm,
                                         args[0], pos, ctx)
        return self.viper.And(keys_acc, extra_acc, pos, self.no_info(ctx))
    if pred_name == 'MayStart':
        return self.translate_may_start(node, args, perm, ctx)
    if pred_name == 'ThreadPost':
        return self.translate_thread_post(node, args, perm, ctx)
    raise UnsupportedException(node)
def visit_Call(self, node: ast.Call) -> None:
    """Parse IO operation properties.

    Only ``Terminates`` and ``TerminationMeasure`` calls (names in
    IO_OPERATION_PROPERTY_FUNCS) are treated as properties. They must
    appear as top-level expression statements of the current operation
    body, and each may be set at most once. Arguments of a property are
    visited with ``self._in_property`` set. Arguments of any other call are
    simply visited.
    """
    assert self._current_io_operation is not None
    assert self._current_node is not None
    is_property = (isinstance(node.func, ast.Name) and
                   node.func.id in IO_OPERATION_PROPERTY_FUNCS)
    if not is_property:
        for sub in node.args:
            self.visit(sub)
        return
    placed_correctly = any(
        isinstance(child, ast.Expr) and child.value == node
        for child in self._current_node.body)
    if not placed_correctly:
        self._raise_invalid_operation('misplaced_property', node)
    operation = self._current_io_operation
    prop_name = node.func.id
    first_arg = node.args[0]
    if prop_name == 'Terminates':
        accepted = operation.set_terminates(first_arg)
    elif prop_name == 'TerminationMeasure':
        accepted = operation.set_termination_measure(first_arg)
    else:
        raise UnsupportedException(node, 'Unsupported property type.')
    if not accepted:
        self._raise_invalid_operation('duplicate_property', node)
    self._in_property = True
    for sub in node.args:
        self.visit(sub)
    self._in_property = False
def translate_contract_Call(self, node: ast.Call, ctx: Context) -> Expr:
    """Translate a contract-wrapper call to its (pure, boolean) argument.

    :raises UnsupportedException: if the callee is not in
        CONTRACT_WRAPPER_FUNCS.
    :raises InvalidProgramException: 'purity.violated' if translating the
        argument produces statements.
    """
    if get_func_name(node) not in CONTRACT_WRAPPER_FUNCS:
        raise UnsupportedException(node)
    side_effects, result = self.translate_expr(node.args[0], ctx,
                                               self.viper.Bool, True)
    if side_effects:
        raise InvalidProgramException(node, 'purity.violated')
    return result
def _translate_wrapper(self, wrapper: Wrapper, previous: Expr,
                       function: PythonMethod, ctx: Context) -> Expr:
    """Dispatch ``wrapper`` to the translator for its wrapper kind.

    :raises UnsupportedException: for wrappers that are neither
        ReturnWrapper nor AssignWrapper.
    """
    if isinstance(wrapper, ReturnWrapper):
        translate = self._translate_return_wrapper
    elif isinstance(wrapper, AssignWrapper):
        translate = self._translate_assign_wrapper
    else:
        raise UnsupportedException(wrapper)
    return translate(wrapper, previous, function, ctx)
def translate_perm_Call(self, node: ast.Call, ctx: Context) -> Expr:
    """Translate a call appearing in a permission-amount position.

    ``ARP()`` yields the context-dependent read permission, and ``ARP(n)``
    applies ``rdc`` to ``n``. ``getARP(t)`` applies ``rd_token`` to ``t``.
    Both require ARP support (``ctx.arp``). Any other call, including
    ``ARP``/``getARP`` with a different argument count, is translated as an
    ordinary integer expression.

    :raises UnsupportedException: if ARP is used without ``ctx.arp``.
    :raises InvalidProgramException: 'purity.violated' if a translated
        expression produces statements.
    """
    func_name = get_func_name(node)
    arg_count = len(node.args)
    if func_name in ('ARP', 'getARP') and not ctx.arp:
        raise UnsupportedException(
            node, 'ARP not supported. Use --arp flag.')
    if func_name == 'ARP':
        if arg_count == 0:
            return self.get_arp_for_context(node, ctx)
        if arg_count == 1:
            count_stmt, count = self.translate_expr(node.args[0], ctx,
                                                    self.viper.Int)
            if count_stmt:
                raise InvalidProgramException(node, 'purity.violated')
            count_param = self.viper.LocalVarDecl(
                'count', self.viper.Int, self.to_position(node, ctx),
                self.no_info(ctx))
            return self.viper.FuncApp('rdc', [count],
                                      self.to_position(node, ctx),
                                      self.no_info(ctx), self.viper.Perm,
                                      [count_param])
    elif func_name == 'getARP' and arg_count == 1:
        token_param = self.viper.LocalVarDecl(
            'tk', self.viper.Ref, self.to_position(node, ctx),
            self.no_info(ctx))
        thread_stmt, thread = self.translate_expr(node.args[0], ctx,
                                                  self.viper.Ref)
        if thread_stmt:
            raise InvalidProgramException(node, 'purity.violated')
        return self.viper.FuncApp('rd_token', [thread],
                                  self.to_position(node, ctx),
                                  self.no_info(ctx), self.viper.Perm,
                                  [token_param])
    # Fallback: an ordinary, integer-typed pure expression.
    call_stmt, call = self.translate_expr(node, ctx, self.viper.Int)
    if call_stmt:
        raise InvalidProgramException(node, 'purity.violated')
    return call
def _get_subscript_type(value_type: PythonType, module: PythonModule,
                        node: ast.AST) -> PythonType:
    """
    Returns the element type obtained by subscripting a ``value_type``.

    For tuples, a subscript node with an integer-literal index (optionally
    negated) selects the corresponding type argument. A single type
    argument is returned for any index. A non-subscript ``node`` yields the
    common supertype of all type arguments. Lists, sets, dicts
    (incl. ``defaultdict``/``ExpiringDict``), ranges, bytes and the
    PSeq/PSet/PMultiset types use their known element type. Other classes
    use the declared type of ``__getitem__``.

    :raises UnsupportedException: for tuple slicing, tuple subscripts with a
        non-literal index, and types without a known element type.
    """
    if isinstance(value_type, OptionalType):
        # Optional[T] is subscripted like T.
        value_type = value_type.cls
    if value_type.name == TUPLE_TYPE:
        if isinstance(node, ast.Subscript):
            if isinstance(node.slice, ast.Slice):
                raise UnsupportedException(node, 'tuple slicing')
            if len(value_type.type_args) == 1:
                return value_type.type_args[0]
            index_node = node.slice.value
            if isinstance(index_node, ast.UnaryOp):
                if (isinstance(index_node.op, ast.USub) and
                        isinstance(index_node.operand, ast.Num)):
                    index = -index_node.operand.n
                else:
                    raise UnsupportedException(node, 'dynamic subscript type')
            elif isinstance(index_node, ast.Num):
                index = index_node.n
            else:
                # Any other index (a name, call, ...) cannot be resolved
                # statically; previously this left ``index`` unbound.
                raise UnsupportedException(node, 'dynamic subscript type')
            return value_type.type_args[index]
        else:
            return common_supertype(value_type.type_args)
    elif value_type.name == LIST_TYPE:
        return value_type.type_args[0]
    elif value_type.name == SET_TYPE:
        return value_type.type_args[0]
    elif value_type.name in (DICT_TYPE, 'defaultdict', 'ExpiringDict'):
        # FIXME: This is very unfortunate, but right now we cannot handle this
        # generically, so we have to hard code these two cases for the moment.
        return value_type.type_args[1]
    elif value_type.name in (RANGE_TYPE, BYTES_TYPE):
        return module.global_module.classes[INT_TYPE]
    elif value_type.name == PSEQ_TYPE:
        return value_type.type_args[0]
    elif value_type.name == PSET_TYPE:
        return value_type.type_args[0]
    elif value_type.name == PMSET_TYPE:
        return value_type.type_args[0]
    getitem = value_type.python_class.get_func_or_method('__getitem__')
    if getitem:
        return getitem.type
    raise UnsupportedException(node)
def translate_perm_Name(self, node: ast.Name, ctx: Context) -> Expr:
    """Translate a name used as a permission amount.

    ``RD_PRED`` maps to the nullary ``globalRd`` function and requires ARP
    support. Any other name is translated as a pure expression.

    :raises UnsupportedException: if ``RD_PRED`` is used without
        ``ctx.arp``.
    :raises InvalidProgramException: 'purity.violated' if the translation
        produces statements.
    """
    if node.id != 'RD_PRED':
        side_effects, perm = self.translate_expr(node, ctx)
        if side_effects:
            raise InvalidProgramException(node, 'purity.violated')
        return perm
    if not ctx.arp:
        raise UnsupportedException(
            node, 'ARP not supported. Use --arp flag.')
    return self.viper.FuncApp('globalRd', [], self.to_position(node, ctx),
                              self.no_info(ctx), self.viper.Perm, {})
def translate_pure_Assign(self, conds: List, node: ast.Assign,
                          ctx: Context) -> List[Wrapper]:
    """
    Translates an assign statement to an AssignWrapper.

    Only assignments with exactly one target that is a plain name are
    supported in pure functions.

    :raises UnsupportedException: for chained assignments (``a = b = e``)
        and for non-name targets (attributes, subscripts, tuples).
    """
    # Chained assignments used to be rejected by an ``assert``, which is
    # stripped under ``-O`` and then silently dropped every target but the
    # first; report them like other unsupported targets instead.
    if len(node.targets) != 1 or not isinstance(node.targets[0], ast.Name):
        raise UnsupportedException(
            node,
            "Multi-target assignments are not supported in pure functions."
        )
    wrapper = AssignWrapper(node.targets[0].id, conds, node.value, node)
    return [wrapper]
def translate_low(self, node: ast.Call, ctx: Context) -> StmtsAndExpr:
    """Translate a call to the ``Low()`` contract function.

    The single argument is translated as a pure expression and wrapped in a
    Viper ``Low`` node carrying the dynamic-check info for ``ctx``.

    :raises UnsupportedException: unless exactly one argument is given.
    :raises InvalidProgramException: 'purity.violated' if the argument's
        translation produces statements.
    """
    if len(node.args) != 1:
        raise UnsupportedException(node,
                                   "Low() requires exactly one argument")
    (low_arg,) = node.args
    side_effects, low_expr = self.translate_expr(low_arg, ctx)
    if side_effects:
        raise InvalidProgramException(node, 'purity.violated')
    dyn_info = self._create_dyn_check_info(ctx)
    position = self.to_position(node, ctx)
    return [], self.viper.Low(low_expr, None, position, dyn_info)
def translate_obligation_contractfunc_call(self, node: ast.Call,
                                           ctx: Context) -> StmtsAndExpr:
    """Translate a call to an obligation contract function.

    ``MustTerminate`` and ``MustRelease`` are translated. ``WaitLevel`` and
    ``Level`` are rejected as invalid uses.

    :raises InvalidProgramException: for ``WaitLevel``/``Level`` calls.
    :raises UnsupportedException: for any other function.
    """
    func_name = get_func_name(node)
    if func_name == 'MustTerminate':
        return self._translate_must_terminate(node, ctx)
    if func_name == 'MustRelease':
        return self._translate_must_release(node, ctx)
    misuse_codes = {
        'WaitLevel': 'invalid.wait_level.use',
        'Level': 'invalid.level.use',
    }
    if func_name in misuse_codes:
        raise InvalidProgramException(node, misuse_codes[func_name])
    raise UnsupportedException(node, 'Unsupported contract function.')
def translate_perm_BinOp(self, node: ast.BinOp, ctx: Context) -> Expr:
    """Translate a binary operation on permission amounts.

    ``/`` divides a permission or integer by an integer: a ``FractionalPerm``
    when the left side is an integer, otherwise a ``PermDiv``. ``*`` yields
    integer ``Mul``, ``IntPermMul`` (integer factor moved to the left) or
    ``PermMul``, depending on which operands are integers. ``+`` and ``-``
    combine two permissions.

    :raises InvalidProgramException: 'purity.violated' if the divisor's
        translation produces statements.
    :raises UnsupportedException: for any other operator.
    """
    op = node.op
    if isinstance(op, ast.Div):
        numerator, numerator_is_int = self.translate_perm_or_int(node.left,
                                                                 ctx)
        divisor_stmt, divisor = self.translate_expr(
            node.right, ctx, target_type=self.viper.Int)
        if divisor_stmt:
            raise InvalidProgramException(node, 'purity.violated')
        make_div = (self.viper.FractionalPerm if numerator_is_int
                    else self.viper.PermDiv)
        return make_div(numerator, divisor, self.to_position(node, ctx),
                        self.no_info(ctx))
    if isinstance(op, ast.Mult):
        lhs, lhs_int = self.translate_perm_or_int(node.left, ctx)
        rhs, rhs_int = self.translate_perm_or_int(node.right, ctx)
        if rhs_int and lhs_int != rhs_int:
            # Normalise so that a lone integer factor is on the left.
            lhs, lhs_int, rhs, rhs_int = rhs, rhs_int, lhs, lhs_int
        if lhs_int and rhs_int:
            make_mul = self.viper.Mul
        elif lhs_int or rhs_int:
            make_mul = self.viper.IntPermMul
        else:
            make_mul = self.viper.PermMul
        return make_mul(lhs, rhs, self.to_position(node, ctx),
                        self.no_info(ctx))
    if isinstance(op, ast.Add):
        combine = self.viper.PermAdd
    elif isinstance(op, ast.Sub):
        combine = self.viper.PermSub
    else:
        raise UnsupportedException(node)
    lhs = self.translate_perm(node.left, ctx)
    rhs = self.translate_perm(node.right, ctx)
    return combine(lhs, rhs, self.to_position(node, ctx), self.no_info(ctx))
def translate_acc_predicate(self, node: ast.Call, perm: Expr,
                            ctx: Context) -> StmtsAndExpr:
    """
    Translates a call to the Acc() contract function with a predicate call
    inside to a predicate access.

    ``node`` is the ``Acc(...)`` call; its first argument must be a call of
    the predicate. Supported callee forms are a plain name (a builtin
    predicate from BUILTIN_PREDICATES, or whatever get_target resolves),
    ``module.pred(...)`` and ``receiver.pred(...)``. In the last case the
    translated receiver becomes the first predicate argument.

    Returns the statements produced while translating the predicate
    arguments, together with the predicate access expression.

    :raises InvalidProgramException: 'invalid.acc' if the callee is not a
        predicate; 'purity.violated' if translating the receiver produces
        statements.
    :raises UnsupportedException: for any other callee shape.
    """
    assert isinstance(node.args[0], ast.Call)
    call = node.args[0]
    # The predicate inside is a function call in python.
    args = []
    arg_stmts = []
    for arg in call.args:
        arg_stmt, arg_expr = self.translate_expr(arg, ctx)
        arg_stmts.extend(arg_stmt)
        args.append(arg_expr)
    # Get the predicate inside the Acc(). ``name`` is bound on every
    # supported path because the predicate-family lookup below needs it
    # (previously it was only bound for ``receiver.pred(...)``).
    if isinstance(call.func, ast.Name):
        name = call.func.id
        if name in BUILTIN_PREDICATES:
            return arg_stmts, self.translate_builtin_predicate(
                call, perm, args, ctx)
        pred = self.get_target(call.func, ctx)
    elif isinstance(call.func, ast.Attribute):
        name = call.func.attr
        receiver = self.get_target(call.func.value, ctx)
        if isinstance(receiver, PythonModule):
            pred = receiver.predicates[name]
        else:
            rec_stmt, receiver = self.translate_expr(call.func.value, ctx)
            # Same purity error the other contract translators report; an
            # assert here would vanish under ``-O``.
            if rec_stmt:
                raise InvalidProgramException(node, 'purity.violated')
            receiver_class = self.get_type(call.func.value, ctx)
            pred = receiver_class.get_predicate(name)
            args = [receiver] + args
    else:
        raise UnsupportedException(node)
    if not (isinstance(pred, PythonMethod) and pred.predicate):
        raise InvalidProgramException(node, 'invalid.acc')
    pred_name = pred.sil_name
    # If the predicate is part of a family, find the correct version: walk
    # up to the top-most superclass that still has a predicate of this name
    # and use that declaration's Viper name.
    if pred.cls:
        family_root = pred.cls
        while (family_root.superclass and
               family_root.superclass.get_predicate(name)):
            family_root = family_root.superclass
        pred_name = family_root.get_predicate(name).sil_name
    return arg_stmts, self.create_predicate_access(pred_name, args, perm,
                                                   node, ctx)
def _get_function_call(self, receiver: PythonType,
                       func_name: str, args: List[Expr],
                       arg_types: List[PythonType], node: ast.AST,
                       ctx: Context,
                       position: Position = None) -> FuncApp:
    """
    Creates a function application of the function called func_name, with
    the given receiver and arguments.

    If ``receiver`` is given, the function is looked up on that type.
    Otherwise the included modules of ``ctx.module`` are searched in order
    and the first match wins. Each argument is converted to the
    representation its parameter expects (primitive bool, primitive int,
    or reference). ``position`` defaults to the position of ``node``. This
    method only handles receivers of non-union types.

    :raises UnsupportedException: if the receiver only has an impure method
        of that name.
    :raises InvalidProgramException: 'unknown.function.called' if no
        function of that name is found.
    """
    # Start from None so that a fruitless module search reports
    # 'unknown.function.called' instead of raising UnboundLocalError.
    func = None
    if receiver:
        func = receiver.get_function(func_name)
    else:
        for container in ctx.module.get_included_modules():
            if func_name in container.functions:
                func = container.functions[func_name]
                break
    if not func:
        if receiver and receiver.get_method(func_name):
            msg = 'Called method is expected to be pure: ' + func_name
            raise UnsupportedException(node, msg)
        raise InvalidProgramException(node, 'unknown.function.called')
    formal_args = []
    actual_args = []
    assert len(args) == len(func.get_args())
    # arg_types stays in the zip: zip stops at its shortest input, so it
    # still bounds the number of converted arguments exactly as before.
    for arg, param, _arg_type in zip(args, func.get_args(), arg_types):
        formal_args.append(param.decl)
        if param.type.name == '__prim__bool':
            actual_arg = self.to_bool(arg, ctx)
        elif param.type.name == '__prim__int':
            actual_arg = self.to_int(arg, ctx)
        else:
            actual_arg = self.to_ref(arg, ctx)
        actual_args.append(actual_arg)
    result_type = self.translate_type(func.type, ctx)
    actual_position = position if position else self.to_position(node, ctx)
    return self.viper.FuncApp(func.sil_name, actual_args, actual_position,
                              self.no_info(ctx), result_type, formal_args)
def _translate_wrapper_expr(self, wrapper: Wrapper, ctx: Context) -> Expr:
    """
    Translates the expression carried by ``wrapper`` to a Viper expression.

    If ``wrapper.expr`` is a BinOpWrapper (only allowed on AssignWrappers),
    the result is ``<current value of wrapper.name> op <translated rhs>``
    over integers, for ``+``, ``-`` and ``*``. Otherwise ``wrapper.expr``
    is translated directly.

    :raises UnsupportedException: for any other augmenting operator.
    :raises InvalidProgramException: 'purity.violated' if the translation
        produces statements.
    """
    info = self.no_info(ctx)
    position = self.to_position(wrapper.node, ctx)
    if isinstance(wrapper.expr, BinOpWrapper):
        assert isinstance(wrapper, AssignWrapper)
        # wrapper.expr is a BinOpWrapper here, not an AST node, so purity
        # errors are reported on the AST of its right-hand side.
        source_node = wrapper.expr.rhs
        stmt, val = self.translate_expr(source_node, ctx)
        val = self.to_int(val, ctx)
        var = ctx.var_aliases[wrapper.name].ref()
        var = self.to_int(var, ctx)
        if isinstance(wrapper.expr.op, ast.Add):
            val = self.viper.Add(var, val, position, info)
        elif isinstance(wrapper.expr.op, ast.Sub):
            val = self.viper.Sub(var, val, position, info)
        elif isinstance(wrapper.expr.op, ast.Mult):
            val = self.viper.Mul(var, val, position, info)
        else:
            raise UnsupportedException(wrapper.node)
    else:
        source_node = wrapper.expr
        stmt, val = self.translate_expr(source_node, ctx)
    if stmt:
        raise InvalidProgramException(source_node, 'purity.violated')
    return val
def translate_io_contractfunc_call(self, node: ast.Call,
                                   ctx: Context, impure: bool,
                                   statement: bool) -> StmtsAndExpr:
    """Translate a call to an IO contract function.

    Currently supported functions:

    +   ``token``, ``ctoken`` and ``eval_io`` -- only in impure positions
    +   ``Open`` -- only as a statement
    +   ``Eval`` -- in any position

    :raises InvalidProgramException: 'invalid.contract.position' when a
        function is used where it is not allowed.
    :raises UnsupportedException: for any other function.
    """
    func_name = get_func_name(node)
    # Position requirement per function; functions not listed have none.
    position_ok = {
        'token': impure,
        'ctoken': impure,
        'eval_io': impure,
        'Open': statement,
    }.get(func_name, True)
    if not position_ok:
        raise InvalidProgramException(node, 'invalid.contract.position')
    if func_name == 'token':
        return self.translate_must_invoke_token(node, ctx)
    if func_name == 'ctoken':
        return self.translate_must_invoke_ctoken(node, ctx)
    if func_name == 'Open':
        return self._translate_open(node, ctx)
    if func_name == 'Eval':
        return self._translate_eval(node, ctx)
    if func_name == 'eval_io':
        return self._translate_eval_io(node, ctx)
    raise UnsupportedException(node, 'Unsupported contract function.')
def translate_contractfunc_call(self, node: ast.Call, ctx: Context,
                                impure=False,
                                statement=False) -> StmtsAndExpr:
    """
    Translates calls to contract functions like Result() and Acc()

    Dispatches on the callee name returned by ``get_func_name``.
    Permission functions (``Acc``, ``Rd``, ``Wildcard``) require
    ``impure``. ``Assert``, ``Assume``, ``Fold`` and ``Unfold`` require
    ``statement``. Otherwise InvalidProgramException with code
    'invalid.contract.position' is raised. ``getMethod`` and ``arg`` are
    always rejected with InvalidProgramException. Names without a
    translation raise UnsupportedException.
    """
    func_name = get_func_name(node)
    if func_name == 'Result':
        return self.translate_result(node, ctx)
    elif func_name == 'RaisedException':
        return self.translate_raised_exception(node, ctx)
    elif func_name in ('Acc', 'Rd', 'Wildcard'):
        # Permission expressions are only allowed in impure positions.
        if not impure:
            raise InvalidProgramException(node, 'invalid.contract.position')
        # Choose the permission amount: ARP-based for Rd, a wildcard for
        # Wildcard, otherwise whatever _get_perm derives from the Acc call.
        if func_name == 'Rd':
            perm = self.get_arp_for_context(node, ctx)
        elif func_name == 'Wildcard':
            perm = self.viper.WildcardPerm(self.to_position(node, ctx),
                                           self.no_info(ctx))
        else:
            perm = self._get_perm(node, ctx)
        if isinstance(node.args[0], ast.Call):
            # Acc(pred(...)): permission to a predicate instance.
            return self.translate_acc_predicate(node, perm, ctx)
        else:
            if isinstance(node.args[0], ast.Attribute):
                type = self.get_type(node.args[0].value, ctx)
                if isinstance(type, UnionType):
                    # Union-typed receiver: one field permission per
                    # possible receiver class, each guarded by a type check
                    # on the receiver, combined via chain_cond_exp.
                    guarded_field_access = []
                    stmt, receiver = self.translate_expr(
                        node.args[0].value, ctx)
                    for recv_type in toposort_classes(type.get_types() -
                                                      {None}):
                        # NOTE(review): this lookup does not depend on
                        # recv_type and uses target.sil_name, i.e. it
                        # presumably assumes the receiver is a variable.
                        target = self.get_target(node.args[0].value, ctx)
                        field_guard = self.var_type_check(
                            target.sil_name, recv_type,
                            self.to_position(node, ctx), ctx)
                        field = recv_type.get_field(
                            node.args[0].attr).actual_field
                        field_access = self.viper.FieldAccess(
                            receiver, field.sil_field,
                            self.to_position(node, ctx), self.no_info(ctx))
                        field_acc = self._translate_acc_field(
                            field_access, field.type, perm,
                            self.to_position(node, ctx), ctx)
                        guarded_field_access.append(
                            (field_guard, field_acc))
                    # A single candidate class needs no guard.
                    if len(guarded_field_access) == 1:
                        _, field_acc = guarded_field_access[0]
                        return stmt, field_acc
                    else:
                        return (stmt,
                                chain_cond_exp(guarded_field_access,
                                               self.viper,
                                               self.to_position(node, ctx),
                                               self.no_info(ctx), ctx))
            # Non-union case: Acc(obj.field) or Acc(global_var).
            target = self.get_target(node.args[0], ctx)
            if isinstance(target, PythonField):
                return self.translate_acc_field(node, perm, ctx)
            else:
                if not isinstance(target, PythonGlobalVar):
                    raise InvalidProgramException(node,
                                                  'invalid.acc')
                return self.translate_acc_global(node, perm, ctx)
    elif func_name in BUILTIN_PREDICATES:
        return [], self.translate_unwrapped_builtin_predicate(node, ctx)
    elif func_name == 'MaySet':
        return self.translate_may_set(node, ctx)
    elif func_name == 'MayCreate':
        return self.translate_may_create(node, ctx)
    elif func_name in ('Assert', 'Assume', 'Fold', 'Unfold'):
        # These are only allowed as stand-alone statements.
        if not statement:
            raise InvalidProgramException(node, 'invalid.contract.position')
        if func_name == 'Assert':
            return self.translate_assert(node, ctx)
        elif func_name == 'Assume':
            return self.translate_assume(node, ctx)
        elif func_name == 'Fold':
            return self.translate_fold(node, ctx)
        elif func_name == 'Unfold':
            return self.translate_unfold(node, ctx)
    elif func_name == 'Implies':
        return self.translate_implies(node, ctx, impure)
    elif func_name == 'Old':
        return self.translate_old(node, ctx)
    elif func_name == 'Unfolding':
        return self.translate_unfolding(node, ctx, impure)
    elif func_name == 'Low':
        return self.translate_low(node, ctx)
    elif func_name == 'LowVal':
        return self.translate_lowval(node, ctx)
    elif func_name == 'LowEvent':
        return self.translate_lowevent(node, ctx)
    elif func_name == 'LowExit':
        return self.translate_lowexit(node, ctx)
    elif func_name == 'Declassify':
        return self.translate_declassify(node, ctx)
    elif func_name == 'TerminatesSif':
        return self.translate_terminates_sif(node, ctx)
    elif func_name == 'Forall':
        return self.translate_forall(node, ctx, impure)
    elif func_name == 'Previous':
        return self.translate_previous(node, ctx)
    elif func_name == 'Let':
        return self.translate_let(node, ctx, impure)
    elif func_name == PSEQ_TYPE:
        return self.translate_sequence(node, ctx)
    elif func_name == PSET_TYPE:
        return self.translate_pset(node, ctx)
    elif func_name == PMSET_TYPE:
        return self.translate_mset(node, ctx)
    elif func_name == 'ToSeq':
        return self.translate_to_sequence(node, ctx)
    elif func_name == 'Joinable':
        return self.translate_joinable(node, ctx)
    elif func_name == 'getArg':
        return self.translate_get_arg(node, ctx)
    elif func_name == 'getOld':
        return self.translate_get_old(node, ctx)
    elif func_name == 'getMethod':
        # Not translatable in this position; always an error here.
        raise InvalidProgramException(node, 'invalid.get.method.use')
    elif func_name == 'arg':
        raise InvalidProgramException(node, 'invalid.arg.use')
    else:
        raise UnsupportedException(node)
def translate_pure_Expr(self, conds: List, node: ast.Expr,
                        ctx: Context) -> List[Wrapper]:
    """Translate an expression statement inside a pure function.

    Only string-literal statements (docstrings) are accepted; they produce
    no wrappers.

    :raises UnsupportedException: for any other expression statement.
    """
    is_docstring = isinstance(node.value, ast.Str)
    if not is_docstring:
        raise UnsupportedException(node)
    return []
def translate_pure_generic(self, conds: List, node: ast.AST,
                           ctx: Context) -> List[Wrapper]:
    """Fallback for statements with no pure-function translator.

    :raises UnsupportedException: always, for ``node``.
    """
    error = UnsupportedException(node)
    raise error
def translate_generic(self, node: ast.AST, ctx: Context) -> None:
    """
    Fallback visitor used when no dedicated translator exists for ``node``.

    :raises UnsupportedException: always.
    """
    error = UnsupportedException(node)
    raise error
def translate_perm_Num(self, node: ast.Num, ctx: Context) -> Expr:
    """Translate a numeric literal used as a permission amount.

    Only the literal ``1`` is accepted and becomes a full permission.

    :raises UnsupportedException: for any other number.
    """
    if node.n != 1:
        raise UnsupportedException(node)
    return self.viper.FullPerm(self.to_position(node, ctx),
                               self.no_info(ctx))
def _get_call_type(node: ast.Call, module: PythonModule,
                   current_function: PythonMethod,
                   containers: List[ContainerInterface],
                   container: PythonNode) -> PythonType:
    """
    Returns the type of the result of the call ``node``.

    Covers ``super``, ``len``, obligation/IO contract functions, the
    PSeq/PSet/PMultiset constructors, ``enumerate``, the other contract
    functions, selected builtins, calls of module-level classes/functions,
    and method calls on a receiver (including union-typed receivers).
    May return None, e.g. for ``Declassify``.

    :raises InvalidProgramException: 'invalid.super.call' for ``super``
        with an unsupported argument count; 'invalid.let' if the body of
        ``Let`` does not resolve to a type.
    :raises UnsupportedException: for unsupported calls.
    """
    func_name = get_func_name(node)
    if func_name == 'super':
        # super(Cls, obj) uses the named class; super() the enclosing class.
        if len(node.args) == 2:
            return module.classes[node.args[0].id].superclass
        elif not node.args:
            return container.cls.superclass
        else:
            raise InvalidProgramException(node, 'invalid.super.call')
    if func_name == 'len':
        return module.global_module.classes[INT_TYPE]
    if func_name in ('token', 'ctoken', 'MustTerminate', 'MustRelease'):
        return module.global_module.classes[BOOL_TYPE]
    if func_name == PSEQ_TYPE:
        return _get_collection_literal_type(node, ['args'], PSEQ_TYPE, module,
                                            containers, container)
    if func_name == PSET_TYPE:
        return _get_collection_literal_type(node, ['args'], PSET_TYPE, module,
                                            containers, container)
    if func_name == PMSET_TYPE:
        return _get_collection_literal_type(node, ['args'], PMSET_TYPE, module,
                                            containers, container)
    if func_name == 'enumerate':
        # enumerate(xs) is typed as list[tuple[int, <element of xs>]].
        if len(node.args) != 1:
            raise UnsupportedException(
                node, 'enumerate only supported with single arg.')
        list_type = module.global_module.classes[LIST_TYPE]
        int_type = module.global_module.classes[INT_TYPE]
        tuple_type = module.global_module.classes[TUPLE_TYPE]
        arg_type = get_type(node.args[0], containers, container)
        iterable_type = _get_iteration_type(arg_type, module, node)
        return GenericType(
            list_type, [GenericType(tuple_type, [int_type, iterable_type])])
    if isinstance(node.func, ast.Name):
        if node.func.id in CONTRACT_FUNCS:
            if node.func.id == 'Result':
                return current_function.type
            elif node.func.id == 'RaisedException':
                # Use the first container that looks like a translation
                # context (has var_aliases).
                ctxs = [
                    cont for cont in containers if hasattr(cont, 'var_aliases')
                ]
                ctx = ctxs[0] if ctxs else None
                assert ctx
                assert ctx.current_contract_exception is not None
                return ctx.current_contract_exception
            elif node.func.id in ('Acc', 'Rd', 'Read', 'Implies', 'Forall',
                                  'Exists', 'MayCreate', 'MaySet', 'Low',
                                  'LowVal', 'LowEvent', 'LowExit'):
                return module.global_module.classes[BOOL_TYPE]
            elif node.func.id == 'Declassify':
                return None
            elif node.func.id == 'Old':
                return get_type(node.args[0], containers, container)
            elif node.func.id == 'Unfolding':
                return get_type(node.args[1], containers, container)
            elif node.func.id == 'ToSeq':
                arg_type = get_type(node.args[0], containers, container)
                seq_class = module.global_module.classes[PSEQ_TYPE]
                content_type = _get_iteration_type(arg_type, module, node)
                return GenericType(seq_class, [content_type])
            elif node.func.id == 'Previous':
                arg_type = get_type(node.args[0], containers, container)
                list_class = module.global_module.classes[PSEQ_TYPE]
                return GenericType(list_class, [arg_type])
            elif node.func.id in ('getArg', 'getOld', 'getMethod'):
                object_class = module.global_module.classes[OBJECT_TYPE]
                return object_class
            elif node.func.id == 'Let':
                body_type = get_target(node.args[1], containers, container)
                if isinstance(body_type, PythonType):
                    return body_type
                raise InvalidProgramException(node, 'invalid.let')
            else:
                raise UnsupportedException(node)
        elif node.func.id in BUILTINS:
            if node.func.id in ('isinstance', BOOL_TYPE):
                return module.global_module.classes[BOOL_TYPE]
            elif node.func.id == 'cast':
                return get_target(node.args[0], containers, container)
            else:
                raise UnsupportedException(node)
        # NOTE(review): membership is checked in module.classes but the
        # class is fetched from module.global_module.classes; a class that
        # exists only in the user module would raise KeyError — confirm.
        if node.func.id in module.classes:
            return module.global_module.classes[node.func.id]
        elif module.get_func_or_method(node.func.id) is not None:
            target = module.get_func_or_method(node.func.id)
            return target.type
        # NOTE(review): any other plain-name callee falls through and the
        # function implicitly returns None.
    elif isinstance(node.func, ast.Attribute):
        rectype = get_type(node.func.value, containers, container)
        if isinstance(rectype, UnionType):
            # Collect the method's return type for every possible receiver
            # class: one result is returned as-is, one result plus None
            # becomes an OptionalType, anything else a UnionType.
            set_of_classes = rectype.get_types() - {None}
            set_of_return_types = {
                type.get_func_or_method(node.func.attr).type
                for type in set_of_classes
            }
            if len(set_of_return_types) == 1:
                return set_of_return_types.pop()
            elif len(set_of_return_types) == 2 and None in set_of_return_types:
                return OptionalType((set_of_return_types - {None}).pop())
            else:
                return UnionType(list(set_of_return_types))
        elif isinstance(rectype, PythonType):
            # A generic_type other than -1 is used as an index into the
            # receiver's type arguments.
            target = rectype.get_func_or_method(node.func.attr)
            if target.generic_type != -1:
                return rectype.type_args[target.generic_type]
            else:
                return target.type
    else:
        raise UnsupportedException(node)
def _do_get_type(node: ast.AST, containers: List[ContainerInterface],
                 container: PythonNode) -> Optional[PythonType]:
    """
    Does the actual work for get_type without boxing the type.

    If ``get_target`` resolves ``node`` (variable, method, field, IO
    operation, type or module), the type is derived from that target.
    Otherwise it is derived from the syntactic form of ``node`` (literals,
    operators, calls, comprehensions). Unresolvable names/attributes give
    None.

    :raises UnsupportedException: for unsupported expression forms.
    :raises InvalidProgramException: 'generic.constructor.without.type' for
        generic constructor calls whose type arguments cannot be determined.
    """
    # Determine the enclosing module and, if any, the enclosing function.
    if isinstance(container, (PythonIOOperation, PythonMethod)):
        module = container.module
        current_function = container
    else:
        module = container
        current_function = None
    target = get_target(node, containers, container)
    if target:
        if isinstance(target, PythonVarBase):
            return target.get_specific_type(node)
        if isinstance(target, PythonMethod):
            if isinstance(node, ast.Call) and isinstance(
                    node.func, ast.Attribute):
                # Method call on a receiver: the declared return type may
                # need to be instantiated with the receiver's type arguments.
                rec_target = get_target(node.func.value, containers, container)
                if not isinstance(rec_target, PythonModule):
                    rectype = get_type(node.func.value, containers, container)
                    if target.generic_type != -1:
                        return rectype.type_args[target.generic_type]
                    if isinstance(target.type, TypeVar):
                        # Walk up to the class declaring the method, then map
                        # the type variable to its positional type argument.
                        # NOTE(review): presumably assumes target.cls is an
                        # ancestor of rectype; otherwise superclass runs out.
                        while rectype.python_class is not target.cls:
                            rectype = rectype.superclass
                        name_list = list(rectype.python_class.type_vars.keys())
                        index = name_list.index(target.type.name)
                        return rectype.type_args[index]
            return target.type
        if isinstance(target, PythonField):
            result = target.type
            if isinstance(result, TypeVar):
                # Field typed with a class type variable: instantiate it from
                # the receiver's type arguments.
                assert isinstance(node, ast.Attribute)
                rec_type = _do_get_type(node.value, containers, container)
                while (rec_type.python_class is not
                       result.target_type.python_class):
                    rec_type = rec_type.superclass
                result = rec_type.type_args[result.index]
            return result
        if isinstance(target, PythonIOOperation):
            return module.global_module.classes[BOOL_TYPE]
        if isinstance(target, (PythonType, PythonModule)):
            if (isinstance(node, ast.Call) and isinstance(target, PythonClass)
                    and target.type_vars):
                # This is a call to a constructor of a generic class; it's not
                # enough to just return the class, we need the entire type with
                # type arguments. We only support that if we can get it directly
                # from mypy, i.e., when the result is assigned to a variable
                # and we can get the variable type.
                if node._parent and isinstance(node._parent, ast.Assign):
                    return get_type(node._parent.targets[0], containers,
                                    container)
                elif (target.name in (PSEQ_TYPE, PSET_TYPE, PMSET_TYPE) and
                      isinstance(node, ast.Call) and node.args):
                    # PSeq/PSet/PMultiset(...) with arguments: element type
                    # is the common supertype of the arguments' types.
                    arg_types = [
                        get_type(arg, containers, container)
                        for arg in node.args
                    ]
                    return GenericType(target, [common_supertype(arg_types)])
                else:
                    error = 'generic.constructor.without.type'
                    raise InvalidProgramException(node, error)
            return target
    if isinstance(node, (ast.Attribute, ast.Name)):
        if isinstance(node, ast.Attribute):
            # Attribute on a (non-optional) union: if every member type has
            # a field/variable of that name, use their common supertype.
            lhs = _do_get_type(node.value, containers, container)
            if isinstance(lhs, UnionType) and not isinstance(lhs, OptionalType):
                candidates = [
                    find_entry(node.attr, False, [t]) for t in lhs.type_args
                ]
                if all(
                        isinstance(c, (PythonField, PythonVarBase))
                        for c in candidates):
                    return common_supertype([c.type for c in candidates])
        # All these cases should be handled by get_target, so if we get here,
        # the node refers to something unknown in the given context.
        return None
    if isinstance(node, ast.Num):
        return module.global_module.classes[INT_TYPE]
    elif isinstance(node, ast.Tuple):
        args = [get_type(arg, containers, container) for arg in node.elts]
        return GenericType(module.global_module.classes[TUPLE_TYPE], args)
    elif isinstance(node, ast.Subscript):
        return get_subscript_type(node, module, containers, container)
    elif isinstance(node, ast.Str):
        return module.global_module.classes[STRING_TYPE]
    elif isinstance(node, ast.Bytes):
        return module.global_module.classes[BYTES_TYPE]
    elif isinstance(node, ast.Compare):
        return module.global_module.classes[BOOL_TYPE]
    elif isinstance(node, ast.BoolOp):
        # And and Or always return one of their operands, so we use the common
        # supertype of all arguments.
        # TODO: We could also use a union type, but since support for e.g.
        # calling methods on those isn't amazing yet, we don't do that yet.
        operand_types = [
            get_type(operand, containers, container) for operand in node.values
        ]
        return common_supertype(operand_types)
    elif isinstance(node, ast.List):
        return _get_collection_literal_type(node, ['elts'], LIST_TYPE, module,
                                            containers, container)
    elif isinstance(node, ast.Set):
        return _get_collection_literal_type(node, ['elts'], SET_TYPE, module,
                                            containers, container)
    elif isinstance(node, ast.Dict):
        return _get_collection_literal_type(node, ['keys', 'values'],
                                            DICT_TYPE, module, containers,
                                            container)
    elif isinstance(node, ast.IfExp):
        body_type = get_type(node.body, containers, container)
        else_type = get_type(node.orelse, containers, container)
        return pairwise_supertype(body_type, else_type)
    elif isinstance(node, ast.BinOp):
        # The result type is the return type of the left operand's operator
        # method (looked up via OPERATOR_FUNCTIONS).
        # NOTE(review): right_type is computed but never used.
        left_type = get_type(node.left, containers, container)
        right_type = get_type(node.right, containers, container)
        operator_func = OPERATOR_FUNCTIONS[type(node.op)]
        return left_type.get_func_or_method(operator_func).type
    elif isinstance(node, ast.UnaryOp):
        if isinstance(node.op, ast.Not):
            return module.global_module.classes[BOOL_TYPE]
        elif isinstance(node.op, ast.USub):
            return module.global_module.classes[INT_TYPE]
        else:
            raise UnsupportedException(node)
    elif isinstance(node, ast.NameConstant):
        if (node.value is True) or (node.value is False):
            return module.global_module.classes[BOOL_TYPE]
        elif node.value is None:
            return module.global_module.classes[OBJECT_TYPE]
        else:
            raise UnsupportedException(node)
    elif isinstance(node, ast.Call):
        return _get_call_type(node, module, current_function, containers,
                              container)
    elif isinstance(node, ast.ListComp):
        if (node._parent and isinstance(node._parent, ast.Assign) and
                node is node._parent.value):
            # Constructor is assigned to variable;
            # we get the type of the dict from the type of the
            # variable it's assigned to.
            return get_type(node._parent.targets[0], containers, container)
        else:
            raise UnsupportedException(
                node, 'List comprehensions must be directly '
                'assigned to a local variable.')
    else:
        raise UnsupportedException(node)
def visit_Call(self, node: ast.Call) -> None:
    """Parse IO operation properties.

    Currently, only parses properties such as ``Terminates`` and
    ``TerminationMeasure``.

    In addition, the bound variable(s) of ``IOForall``/``Forall``/``Exists``
    lambdas are registered as IO universals of the current operation. The
    bound variable of an ``IOExists`` lambda is registered as an IO
    existential. The call's arguments are then visited while the lambda's
    prefix is pushed on ``self._current_lambdas``.
    """
    assert self._current_io_operation is not None
    assert self._current_node is not None
    # Set when this call introduces a lambda whose body needs a prefix.
    body_prefix = None
    if (isinstance(node.func, ast.Name) and
            node.func.id in IO_OPERATION_PROPERTY_FUNCS):
        # Properties must be top-level expression statements of the body.
        for child in self._current_node.body:
            if (isinstance(child, ast.Expr) and
                    child.value == node):
                break
        else:
            self._raise_invalid_operation('misplaced_property', node)
        operation = self._current_io_operation
        arg = node.args[0]
        if node.func.id == 'Terminates':
            if not operation.set_terminates(arg):
                self._raise_invalid_operation('duplicate_property', node)
        elif node.func.id == 'TerminationMeasure':
            if not operation.set_termination_measure(arg):
                self._raise_invalid_operation('duplicate_property', node)
        else:
            raise UnsupportedException(node, 'Unsupported property type.')
        self._in_property = True
        for arg in node.args:
            self.visit(arg)
        self._in_property = False
        return
    elif isinstance(
            node.func,
            ast.Name) and node.func.id in ('IOForall', 'Forall', 'Exists'):
        # Quantifier: args[0] gives the bound variable's type, args[1] is a
        # one-parameter lambda.
        operation = self._current_io_operation
        assert len(node.args[1].args.args) == 1
        arg_type = self._parent.get_target(node.args[0], operation.module)
        lambda_ = node.args[1]
        body_prefix = construct_lambda_prefix(
            lambda_.lineno, getattr(lambda_, 'col_offset', None))
        for arg in lambda_.args.args:
            var = self._node_factory.create_python_var(
                arg.arg, arg, arg_type)
            operation._io_universals.append(var)
    elif (isinstance(node.func, ast.Call) and
          isinstance(node.func.func, ast.Name) and
          node.func.func.id == 'IOExists'):
        # IOExists(...)(lambda x: ...): register the lambda's first
        # parameter as an IO existential of the current operation.
        lambda_ = node.args[0]
        arg = lambda_.args.args[0]
        body_prefix = construct_lambda_prefix(lambda_.lineno,
                                              lambda_.col_offset)
        creator = self._node_factory.create_python_var_creator(
            arg.arg, arg, self._typeof(arg, lambda_))
        current_existentials = self._current_io_operation.get_io_existentials(
        )
        current_existentials.append(creator)
    if body_prefix:
        self._current_lambdas.append(body_prefix)
    for arg in node.args:
        self.visit(arg)
    if body_prefix:
        self._current_lambdas.pop()
def translate_contract_Expr(self, node: ast.Expr, ctx: Context) -> Expr:
    """Translate an expression statement that holds a contract call.

    :raises UnsupportedException: if the statement's value is not a call.
    """
    wrapped = node.value
    if not isinstance(wrapped, ast.Call):
        raise UnsupportedException(node)
    return self.translate_contract(wrapped, ctx)