def get_class_conditions(self, cls: type) -> ClassConditions:
    """
    Collect the invariants and per-method conditions for ``cls``.

    Conditions for the classes in ``cls.__bases__`` (as reported by the
    top-level parser) are merged together. They are then combined with the
    conditions parsed for methods defined directly on ``cls``. Finally, the
    invariants returned by ``self.get_class_invariants(cls)`` are appended to
    each method's preconditions and/or postconditions. Which of the two
    depends on the method name (see the dispatch below).

    :param cls: the class to inspect.
    :return: a ``ClassConditions`` built from the class invariants and a dict
        mapping method name to that method's ``Conditions``. Only methods
        whose conditions report ``has_any()`` are included. Classes that are
        not pure Python yield ``ClassConditions([], {})``.
    """
    if not is_pure_python(cls):
        # We can't get conditions/line numbers for classes written in C.
        return ClassConditions([], {})

    toplevel_parser = self.get_toplevel_parser()
    methods = {}
    super_conditions = merge_class_conditions(
        [toplevel_parser.get_class_conditions(base) for base in cls.__bases__]
    )
    inv = self.get_class_invariants(cls)
    super_methods = super_conditions.methods
    method_names = set(cls.__dict__.keys()) | super_methods.keys()
    for method_name in method_names:
        method = cls.__dict__.get(method_name, None)
        super_method_conditions = super_methods.get(method_name)
        if super_method_conditions is not None:
            # Re-type the first argument of the inherited signature to this class.
            revised_sig = set_first_arg_type(super_method_conditions.sig, cls)
            super_method_conditions = replace(super_method_conditions, sig=revised_sig)
        if method is None:
            # Not defined on cls itself: only inherited conditions can apply.
            if super_method_conditions is None:
                continue
            conditions: Conditions = super_method_conditions
        else:
            parsed_conditions = toplevel_parser.get_fn_conditions(
                FunctionInfo.from_class(cls, method_name)
            )
            if parsed_conditions is None:
                # Unable to determine the function signature; skip this method.
                continue
            if super_method_conditions is None:
                conditions = parsed_conditions
            else:
                conditions = merge_fn_conditions(parsed_conditions, super_method_conditions)

        # Decide where the class invariants apply for this kind of method.
        if method_name in ("__new__", "__repr__"):
            # __new__ isn't passed a concrete instance.
            # __repr__ is itself required for reporting problems with invariants.
            use_pre, use_post = False, False
        elif method_name == "__del__":
            use_pre, use_post = True, False
        elif method_name == "__init__":
            use_pre, use_post = False, True
        elif method_name.startswith("__") and method_name.endswith("__"):
            use_pre, use_post = True, True
        elif method_name.startswith("_"):
            use_pre, use_post = False, False
        else:
            use_pre, use_post = True, True

        # Build fresh lists instead of extending in place. ``replace`` above is
        # a shallow copy, so ``conditions.pre``/``conditions.post`` may be the
        # very list objects held by a base class's conditions (or returned by
        # the parser). Extending them would leak this class's invariants into
        # the base class's methods and duplicate them on repeated calls.
        if use_pre:
            conditions = replace(conditions, pre=[*conditions.pre, *inv])
        if use_post:
            conditions = replace(conditions, post=[*conditions.post, *inv])
        if conditions.has_any():
            methods[method_name] = conditions
    return ClassConditions(inv, methods)
def _wrap_class_members(self, cls: type, class_conditions: ClassConditions) -> None:
    """
    Patch the conditioned members of ``cls`` with wrapped versions.

    Only attributes defined directly on ``cls`` (present in ``cls.__dict__``)
    that also have an entry in ``class_conditions.methods`` are replaced. When
    a member has any conditions, its callable is wrapped in an
    ``EnforcementWrapper``; otherwise the plain callable is used. The
    resulting wrapper is cached in ``self.wrapper_map``, keyed by the raw
    descriptor. The object installed on ``cls`` is recorded in
    ``self.original_map``, mapped back to that raw descriptor.

    :param cls: the class whose members are patched in place via ``setattr``.
    :param class_conditions: conditions for ``cls``, typically as produced by
        ``get_class_conditions`` (NOTE(review): confirm callers).
    """
    method_conditions = dict(class_conditions.methods)
    # Snapshot the names: setattr below writes into cls.__dict__ while we loop.
    for method_name in list(cls.__dict__.keys()):
        conditions = method_conditions.get(method_name)
        if conditions is None:
            continue
        ctxfn = FunctionInfo.from_class(cls, method_name)
        raw_fn = ctxfn.descriptor
        wrapper = self.wrapper_map.get(raw_fn)
        if wrapper is None:
            # Resolve the callable before branching: both branches need ``fn``.
            # Fetching it only in the first branch left ``fn`` unbound in the
            # fallback branch (UnboundLocalError).
            fn, _ = ctxfn.callable()
            if conditions.has_any():
                wrapper = EnforcementWrapper(self.interceptor(fn), conditions, self)
                functools.update_wrapper(wrapper, fn)
            else:
                wrapper = fn
            # Reuse this wrapper if the same descriptor is seen again.
            self.wrapper_map[raw_fn] = wrapper
        outer_wrapper = ctxfn.patch_logic(wrapper)
        # Map the installed object back to the original descriptor
        # (presumably so the class can be restored later; verify with callers).
        self.original_map[IdentityWrapper(outer_wrapper)] = raw_fn
        setattr(cls, method_name, outer_wrapper)