def apply_gorilla_patch(patch, force_backup=True):
    """
    Apply a gorilla patch, optionally refreshing the stored backup of the original.

    Adapted from ``gorilla.apply`` in the gorilla package. The one behavioral
    difference is ``force_backup``. When it is true, the attribute currently found
    at the destination is always stored under gorilla's "original" name. This
    happens even if such a backup already exists.

    :param patch: ``gorilla.Patch`` giving the destination, attribute name and the
                  replacement object (``patch.obj``).
    :param force_backup: If True, (re)store the attribute being replaced as the
                         backup unconditionally; if False, store it only when no
                         backup exists yet. Only consulted when the effective
                         settings have ``store_hit`` enabled.
    :raises RuntimeError: If an attribute with the patch's name already exists at
                          the destination and ``allow_hit`` is disabled.
    """
    settings = patch.settings
    if settings is None:
        settings = gorilla.Settings()

    destination = patch.destination
    attr_name = patch.name

    # A "hit" means an attribute with the patch's name already exists at the
    # destination; a private sentinel distinguishes "no hit" from any real value.
    no_hit = object()
    try:
        existing = gorilla.get_attribute(destination, attr_name)
    except AttributeError:
        existing = no_hit

    if existing is not no_hit:
        if not settings.allow_hit:
            raise RuntimeError(
                "An attribute named '%s' already exists at the destination "
                "'%s'. Set a different name through the patch object to avoid "
                "a name clash or set the setting 'allow_hit' to True to "
                "overwrite the attribute. In the latter case, it is "
                "recommended to also set the 'store_hit' setting to True in "
                "order to store the original attribute under a different "
                "name so it can still be accessed." % (attr_name, destination.__name__)
            )
        if settings.store_hit:
            backup_name = gorilla._ORIGINAL_NAME % (attr_name,)
            # Unlike gorilla.apply, an existing backup is overwritten when force_backup is set
            if force_backup or not hasattr(destination, backup_name):
                setattr(destination, backup_name, existing)

    setattr(destination, attr_name, patch.obj)
def patch_class_tree(klass):
    """
    Monkey patch, via gorilla, every subclass of ``klass`` that overrides an
    auto-loggable method.

    Subclasses are collected with ``find_subclasses`` using ``klass`` as the root of
    the class hierarchy. For every subclass where ``overrides`` reports that a
    supported method is overridden, that method is replaced by the wrapper built
    around it.

    :param klass: root in the class hierarchy to be analyzed and patched recursively
    """
    # TODO: add more autologgable methods here (e.g. fit_regularized, from_formula, etc)
    # See https://www.statsmodels.org/dev/api.html
    autolog_supported_func = {"fit": wrapper_fit}

    shared_settings = gorilla.Settings(allow_hit=True, store_hit=True)
    subclasses = set(find_subclasses(klass))

    # Every patch is built before any is applied, so each wrapper is created around
    # the method as it was before this call patched anything.
    pending_patches = []
    for subclass in subclasses:
        for method_name, make_wrapper in autolog_supported_func.items():
            if not overrides(subclass, method_name):
                continue
            # The wrapper keeps the original method in its closure. This allows
            # superclass methods to be invoked in the context of the subclass
            # without losing track of the true original.
            original_method = getattr(subclass, method_name)
            pending_patches.append(
                gorilla.Patch(
                    subclass,
                    method_name,
                    make_wrapper(original_method),
                    settings=shared_settings,
                )
            )

    for pending in pending_patches:
        apply_gorilla_patch(pending)
def test_patch_on_attribute_not_exist(store_hit):
    # Patching a name that A does not define should add it, and reverting the
    # patch should remove it again (for whichever store_hit value is supplied).
    cls_a, _ = gen_class_A_B()

    def replacement_fx(self):  # pylint: disable=unused-argument
        return 101

    settings = gorilla.Settings(allow_hit=True, store_hit=store_hit)
    fx_patch = gorilla.Patch(cls_a, "fx", replacement_fx, settings)

    gorilla.apply(fx_patch)
    instance = cls_a()
    assert instance.fx() == 101

    gorilla.revert(fx_patch)
    assert not hasattr(cls_a, "fx")
def _wrap_patch(destination, name, patch_obj, settings=None):
    """
    Build a gorilla patch that installs ``patch_obj`` at ``{destination}.{name}`` and apply it.

    :param destination: Patch destination
    :param name: Name of the attribute at the destination
    :param patch_obj: Patch object, it should be a function or a property decorated function
                      to be assigned to the patch point {destination}.{name}
    :param settings: Settings for gorilla.Patch; when None,
                     ``gorilla.Settings(allow_hit=True, store_hit=True)`` is used
    :return: The applied ``gorilla.Patch``
    """
    effective_settings = (
        settings if settings is not None else gorilla.Settings(allow_hit=True, store_hit=True)
    )
    applied_patch = gorilla.Patch(destination, name, patch_obj, settings=effective_settings)
    gorilla.apply(applied_patch)
    return applied_patch
def wrap_patch(destination, name, patch, settings=None):
    """
    Apply a patch while preserving the attributes (e.g. __doc__) of an original function.

    :param destination: Patch destination
    :param name: Name of the attribute at the destination
    :param patch: Patch function
    :param settings: Settings for gorilla.Patch; when None,
                     ``gorilla.Settings(allow_hit=True, store_hit=True)`` is used
    :return: The applied ``gorilla.Patch``, which can be passed to ``gorilla.revert``
             to undo the patch.
    """
    if settings is None:
        settings = gorilla.Settings(allow_hit=True, store_hit=True)

    original = getattr(destination, name)
    # functools.wraps copies metadata such as __name__ and __doc__ from the original
    # onto the patch function (and records the original as __wrapped__)
    wrapped = functools.wraps(original)(patch)

    # Use a distinct name instead of rebinding the caller-supplied ``patch`` function
    gorilla_patch = gorilla.Patch(destination, name, wrapped, settings=settings)
    gorilla.apply(gorilla_patch)
    # Return the patch so callers can revert it later
    return gorilla_patch
def test_patch_on_inherit_method(store_hit):
    # B is expected to get f2 from A (the post-revert assertions rely on this).
    # Patching B.f2 should expose A's implementation as B's original. Reverting
    # should restore the inherited method without leaving an f2 entry on B.
    cls_a, cls_b = gen_class_A_B()
    inherited_f2 = cls_a.f2

    def replacement_f2(self):  # pylint: disable=unused-argument
        pass

    settings = gorilla.Settings(allow_hit=True, store_hit=store_hit)
    f2_patch = gorilla.Patch(cls_b, "f2", replacement_f2, settings)

    gorilla.apply(f2_patch)
    assert cls_b.f2 is replacement_f2
    assert gorilla.get_original_attribute(cls_b, "f2") is inherited_f2

    gorilla.revert(f2_patch)
    assert cls_b.f2 is inherited_f2
    assert gorilla.get_original_attribute(cls_b, "f2") is inherited_f2
    # Reverting must leave no side effect on B itself
    assert "f2" not in cls_b.__dict__
def wrap_patch(destination, name, patch, settings=None):
    """
    Apply a patch while preserving the attributes (e.g. __doc__) of an original function.

    TODO(dbczumar): Convert this to an internal method once existing `wrap_patch` calls
                    outside of `autologging_utils` have been converted to `safe_patch`

    :param destination: Patch destination
    :param name: Name of the attribute at the destination
    :param patch: Patch function
    :param settings: Settings for gorilla.Patch; when None,
                     ``gorilla.Settings(allow_hit=True, store_hit=True)`` is used
    :return: The applied ``gorilla.Patch``, which can be passed to ``gorilla.revert``
             to undo the patch.
    """
    if settings is None:
        settings = gorilla.Settings(allow_hit=True, store_hit=True)

    original = getattr(destination, name)
    # Copy the original's attributes (e.g. __doc__) onto the patch function
    wrapped = _update_wrapper_extended(patch, original)

    # Use a distinct name instead of rebinding the caller-supplied ``patch`` function
    gorilla_patch = gorilla.Patch(destination, name, wrapped, settings=settings)
    gorilla.apply(gorilla_patch)
    # Return the patch so callers can revert it later
    return gorilla_patch
def gorilla_setting():
    """Return gorilla settings that allow overwriting an existing attribute and store the replaced one."""
    settings = gorilla.Settings(store_hit=True, allow_hit=True)
    return settings