예제 #1
0
    def apply_gorilla_patch(patch, force_backup=True):
        """
        Install ``patch`` at its destination, optionally backing up the attribute it replaces.

        Adapted from ``gorilla.apply`` in the gorilla package. The one behavioral difference
        is the ``force_backup`` switch: when the settings request storing a replaced attribute
        (``store_hit``), a True ``force_backup`` overwrites any backup that already exists.
        The original only writes the backup when none exists yet.

        :param patch: ``gorilla.Patch`` to apply. If its ``settings`` is None, a default
                      ``gorilla.Settings()`` is used.
        :param force_backup: When True, always (re)store the replaced attribute under the
                             backup name. When False, keep an existing backup untouched.
        :raises RuntimeError: If an attribute named ``patch.name`` already exists at the
                              destination and ``settings.allow_hit`` is False.
        """
        if patch.settings is not None:
            settings = patch.settings
        else:
            settings = gorilla.Settings()

        destination = patch.destination
        attr_name = patch.name

        # Whatever currently lives under the patch's name is the "target" of a hit.
        try:
            existing = gorilla.get_attribute(destination, attr_name)
        except AttributeError:
            hit = False
        else:
            hit = True

        if hit:
            if not settings.allow_hit:
                raise RuntimeError(
                    "An attribute named '%s' already exists at the destination "
                    "'%s'. Set a different name through the patch object to avoid "
                    "a name clash or set the setting 'allow_hit' to True to "
                    "overwrite the attribute. In the latter case, it is "
                    "recommended to also set the 'store_hit' setting to True in "
                    "order to store the original attribute under a different "
                    "name so it can still be accessed." % (attr_name, destination.__name__)
                )

            if settings.store_hit:
                # The backup name is derived from gorilla's (private) naming template.
                backup_name = gorilla._ORIGINAL_NAME % (attr_name,)
                # Unlike gorilla.apply, force_backup lets an existing backup be replaced.
                if force_backup or not hasattr(destination, backup_name):
                    setattr(destination, backup_name, existing)

        setattr(destination, attr_name, patch.obj)
예제 #2
0
    def patch_class_tree(klass):
        """
        Monkey-patch, via gorilla, the subclasses of ``klass`` that override an auto-loggable method.

        The subclasses are gathered with ``find_subclasses(klass)``, presumably the whole
        subtree rooted at ``klass`` (TODO confirm against that helper). For each subclass and
        each supported method name for which ``overrides`` is true, that subclass's method is
        replaced by the corresponding wrapper.

        :param klass: root in the class hierarchy to be analyzed and patched recursively
        """
        # TODO: add more autologgable methods here (e.g. fit_regularized, from_formula, etc)
        # See https://www.statsmodels.org/dev/api.html
        autolog_supported_func = {"fit": wrapper_fit}

        shared_settings = gorilla.Settings(allow_hit=True, store_hit=True)
        subclasses = set(find_subclasses(klass))

        # Build every patch before applying any of them. Each wrapper then captures the
        # method exactly as it was before this function patched anything.
        pending = []
        for subclass in subclasses:
            for method_name, make_wrapper in autolog_supported_func.items():
                if not overrides(subclass, method_name):
                    continue
                # The wrapper closes over this subclass's own method. Superclass calls then
                # run in the subclass's context, and the true original is never lost.
                own_method = getattr(subclass, method_name)
                pending.append(
                    gorilla.Patch(
                        subclass,
                        method_name,
                        make_wrapper(own_method),
                        settings=shared_settings,
                    )
                )

        for pending_patch in pending:
            apply_gorilla_patch(pending_patch)
예제 #3
0
def test_patch_on_attribute_not_exist(store_hit):
    """Patching a name that does not exist on the class adds it; reverting removes it again."""
    cls_a, _unused_b = gen_class_A_B()

    def patched_fx(self):  # pylint: disable=unused-argument
        return 101

    settings = gorilla.Settings(allow_hit=True, store_hit=store_hit)
    fx_patch = gorilla.Patch(cls_a, "fx", patched_fx, settings)

    gorilla.apply(fx_patch)
    instance = cls_a()
    assert instance.fx() == 101

    gorilla.revert(fx_patch)
    # Revert must leave no trace of the newly added attribute.
    assert not hasattr(cls_a, "fx")
예제 #4
0
def _wrap_patch(destination, name, patch_obj, settings=None):
    """
    Create a gorilla patch for ``{destination}.{name}`` and apply it immediately.

    :param destination: Patch destination
    :param name: Name of the attribute at the destination
    :param patch_obj: Patch object, it should be a function or a property decorated function
                      to be assigned to the patch point {destination}.{name}
    :param settings: Settings for gorilla.Patch. When None, defaults to
                     ``gorilla.Settings(allow_hit=True, store_hit=True)``, which allows
                     overwriting an existing attribute and keeps the original.
    :return: The applied ``gorilla.Patch``.
    """
    effective_settings = (
        gorilla.Settings(allow_hit=True, store_hit=True) if settings is None else settings
    )
    applied = gorilla.Patch(destination, name, patch_obj, settings=effective_settings)
    gorilla.apply(applied)
    return applied
예제 #5
0
def wrap_patch(destination, name, patch, settings=None):
    """
    Apply a patch while preserving the attributes (e.g. __doc__) of an original function.

    :param destination: Patch destination
    :param name: Name of the attribute at the destination
    :param patch: Patch function. It is decorated in place with ``functools.wraps`` so it
                  carries the original function's metadata.
    :param settings: Settings for gorilla.Patch. When None, defaults to
                     ``gorilla.Settings(allow_hit=True, store_hit=True)``.
    :return: The applied ``gorilla.Patch``. Callers can later pass it to ``gorilla.revert``.
             Previously nothing was returned, so existing callers are unaffected.
    :raises AttributeError: If ``destination`` has no attribute ``name``.
    """
    if settings is None:
        settings = gorilla.Settings(allow_hit=True, store_hit=True)

    original = getattr(destination, name)
    # Copy __name__, __doc__, etc. from the original so the patch looks like it.
    wrapped = functools.wraps(original)(patch)
    # Use a distinct name rather than rebinding the ``patch`` parameter to a Patch object.
    gorilla_patch = gorilla.Patch(destination, name, wrapped, settings=settings)
    gorilla.apply(gorilla_patch)
    # Return the patch (as _wrap_patch does) so the caller is able to revert it.
    return gorilla_patch
예제 #6
0
def test_patch_on_inherit_method(store_hit):
    """Patch an inherited method on a subclass, then check that revert restores inheritance cleanly."""
    class_a, class_b = gen_class_A_B()
    inherited_f2 = class_a.f2

    def patched_B_f2(self):  # pylint: disable=unused-argument
        pass

    settings = gorilla.Settings(allow_hit=True, store_hit=store_hit)
    b_f2_patch = gorilla.Patch(class_b, "f2", patched_B_f2, settings)
    gorilla.apply(b_f2_patch)

    # While applied: the subclass exposes the patch, and the inherited original stays reachable.
    assert class_b.f2 is patched_B_f2
    assert gorilla.get_original_attribute(class_b, "f2") is inherited_f2

    gorilla.revert(b_f2_patch)

    # After revert: the subclass inherits the method again.
    assert class_b.f2 is inherited_f2
    assert gorilla.get_original_attribute(class_b, "f2") is inherited_f2
    assert "f2" not in class_b.__dict__  # assert no side effect after reverting
예제 #7
0
def wrap_patch(destination, name, patch, settings=None):
    """
    Apply a patch while preserving the attributes (e.g. __doc__) of an original function.

    TODO(dbczumar): Convert this to an internal method once existing `wrap_patch` calls
                    outside of `autologging_utils` have been converted to `safe_patch`

    :param destination: Patch destination
    :param name: Name of the attribute at the destination
    :param patch: Patch function
    :param settings: Settings for gorilla.Patch. When None, defaults to
                     ``gorilla.Settings(allow_hit=True, store_hit=True)``.
    :return: The applied ``gorilla.Patch``. Callers can later pass it to ``gorilla.revert``.
             Previously nothing was returned, so existing callers are unaffected.
    :raises AttributeError: If ``destination`` has no attribute ``name``.
    """
    if settings is None:
        settings = gorilla.Settings(allow_hit=True, store_hit=True)

    original = getattr(destination, name)
    # Carry the original's metadata over to the patch function (helper defined elsewhere).
    wrapped = _update_wrapper_extended(patch, original)

    # Use a distinct name rather than rebinding the ``patch`` parameter to a Patch object.
    gorilla_patch = gorilla.Patch(destination, name, wrapped, settings=settings)
    gorilla.apply(gorilla_patch)
    # Return the patch so the caller is able to revert it.
    return gorilla_patch
예제 #8
0
def gorilla_setting():
    """Return gorilla settings that allow overwriting an existing attribute and keep the original."""
    settings = gorilla.Settings(store_hit=True, allow_hit=True)
    return settings