示例#1
0
def test_basic_patch_for_class(gorilla_setting):
    A, B = gen_class_A_B()

    original_A_f1 = A.f1
    original_A_f2 = A.f2
    original_B_f1 = B.f1

    def patched_A_f1(self):  # pylint: disable=unused-argument
        pass

    def patched_A_f2(self):  # pylint: disable=unused-argument
        pass

    def patched_B_f1(self):  # pylint: disable=unused-argument
        pass

    patch_A_f1 = gorilla.Patch(A, "f1", patched_A_f1, gorilla_setting)
    patch_A_f2 = gorilla.Patch(A, "f2", patched_A_f2, gorilla_setting)
    patch_B_f1 = gorilla.Patch(B, "f1", patched_B_f1, gorilla_setting)

    assert gorilla.get_original_attribute(A, "f1") is original_A_f1
    assert gorilla.get_original_attribute(B, "f1") is original_B_f1
    assert gorilla.get_original_attribute(B, "f2") is original_A_f2

    gorilla.apply(patch_A_f1)
    assert A.f1 is patched_A_f1
    assert gorilla.get_original_attribute(A, "f1") is original_A_f1
    assert gorilla.get_original_attribute(B, "f1") is original_B_f1

    gorilla.apply(patch_B_f1)
    assert A.f1 is patched_A_f1
    assert B.f1 is patched_B_f1
    assert gorilla.get_original_attribute(A, "f1") is original_A_f1
    assert gorilla.get_original_attribute(B, "f1") is original_B_f1

    gorilla.apply(patch_A_f2)
    assert A.f2 is patched_A_f2
    assert B.f2 is patched_A_f2
    assert gorilla.get_original_attribute(A, "f2") is original_A_f2
    assert gorilla.get_original_attribute(B, "f2") is original_A_f2

    gorilla.revert(patch_A_f2)
    assert A.f2 is original_A_f2
    assert B.f2 is original_A_f2
    assert gorilla.get_original_attribute(A, "f2") == original_A_f2
    assert gorilla.get_original_attribute(B, "f2") == original_A_f2

    gorilla.revert(patch_B_f1)
    assert A.f1 is patched_A_f1
    assert B.f1 is original_B_f1
    assert gorilla.get_original_attribute(A, "f1") == original_A_f1
    assert gorilla.get_original_attribute(B, "f1") == original_B_f1

    gorilla.revert(patch_A_f1)
    assert A.f1 is original_A_f1
    assert B.f1 is original_B_f1
    assert gorilla.get_original_attribute(A, "f1") == original_A_f1
    assert gorilla.get_original_attribute(B, "f1") == original_B_f1
示例#2
0
    def patch_class_tree(klass):
        """
        Patches all subclasses that override any auto-loggable method via monkey patching using
        the gorilla package, taking the argument as the tree root in the class hierarchy. Every
        auto-loggable method found in any of the subclasses is replaced by the patched version.
        :param klass: root in the class hierarchy to be analyzed and patched recursively
        """

        # TODO: add more autologgable methods here (e.g. fit_regularized, from_formula, etc)
        # See https://www.statsmodels.org/dev/api.html
        autolog_supported_func = {"fit": wrapper_fit}

        glob_settings = gorilla.Settings(allow_hit=True, store_hit=True)
        glob_subclasses = set(find_subclasses(klass))

        # Create a patch for every method that needs to be patched, i.e. those
        # which actually override an autologgable method
        patches_list = [
            # Link the patched function with the original via a local variable in the closure
            # to allow invoking superclass methods in the context of the subclass, and not
            # losing the trace of the true original method
            gorilla.Patch(
                c, method_name, wrapper_func(getattr(c, method_name)), settings=glob_settings
            )
            for c in glob_subclasses
            for (method_name, wrapper_func) in autolog_supported_func.items()
            if overrides(c, method_name)
        ]

        for p in patches_list:
            apply_gorilla_patch(p)
示例#3
0
def test_patch_on_attribute_not_exist(store_hit):
    A, _ = gen_class_A_B()

    def patched_fx(self):  # pylint: disable=unused-argument
        return 101

    gorilla_setting = gorilla.Settings(allow_hit=True, store_hit=store_hit)
    fx_patch = gorilla.Patch(A, "fx", patched_fx, gorilla_setting)
    gorilla.apply(fx_patch)
    a1 = A()
    assert a1.fx() == 101
    gorilla.revert(fx_patch)
    assert not hasattr(A, "fx")
示例#4
0
def test_patch_for_descriptor(gorilla_setting):
    A, _ = gen_class_A_B()

    original_A_f3_raw = object.__getattribute__(A, "f3")

    def patched_A_f3(self):  # pylint: disable=unused-argument
        pass

    patch_A_f3 = gorilla.Patch(A, "f3", patched_A_f3, gorilla_setting)

    assert gorilla.get_original_attribute(A, "f3") is A.delegated_f3
    assert (gorilla.get_original_attribute(
        A, "f3", bypass_descriptor_protocol=True) is original_A_f3_raw)

    gorilla.apply(patch_A_f3)
    assert A.f3 is patched_A_f3
    assert gorilla.get_original_attribute(A, "f3") is A.delegated_f3
    assert (gorilla.get_original_attribute(
        A, "f3", bypass_descriptor_protocol=True) is original_A_f3_raw)

    gorilla.revert(patch_A_f3)
    assert A.f3 is A.delegated_f3
    assert gorilla.get_original_attribute(A, "f3") is A.delegated_f3
    assert (gorilla.get_original_attribute(
        A, "f3", bypass_descriptor_protocol=True) is original_A_f3_raw)

    # test patch a descriptor
    @delegate(patched_A_f3)
    def new_patched_A_f3(self):  # pylint: disable=unused-argument
        pass

    new_patch_A_f3 = gorilla.Patch(A, "f3", new_patched_A_f3, gorilla_setting)
    gorilla.apply(new_patch_A_f3)
    assert A.f3 is patched_A_f3
    assert object.__getattribute__(A, "f3") is new_patched_A_f3
    assert gorilla.get_original_attribute(A, "f3") is A.delegated_f3
    assert (gorilla.get_original_attribute(
        A, "f3", bypass_descriptor_protocol=True) is original_A_f3_raw)
示例#5
0
def _wrap_patch(destination, name, patch_obj, settings=None):
    """
    Apply a patch.

    :param destination: Patch destination
    :param name: Name of the attribute at the destination
    :param patch_obj: Patch object, it should be a function or a property decorated function
                      to be assigned to the patch point {destination}.{name}
    :param settings: Settings for gorilla.Patch
    """
    if settings is None:
        settings = gorilla.Settings(allow_hit=True, store_hit=True)

    patch = gorilla.Patch(destination, name, patch_obj, settings=settings)
    gorilla.apply(patch)
    return patch
示例#6
0
def wrap_patch(destination, name, patch, settings=None):
    """
    Apply a patch while preserving the attributes (e.g. __doc__) of an original function.

    :param destination: Patch destination
    :param name: Name of the attribute at the destination
    :param patch: Patch function
    :param settings: Settings for gorilla.Patch
    """
    if settings is None:
        settings = gorilla.Settings(allow_hit=True, store_hit=True)

    original = getattr(destination, name)
    wrapped = functools.wraps(original)(patch)
    patch = gorilla.Patch(destination, name, wrapped, settings=settings)
    gorilla.apply(patch)
示例#7
0
def test_patch_on_inherit_method(store_hit):
    A, B = gen_class_A_B()

    original_A_f2 = A.f2

    def patched_B_f2(self):  # pylint: disable=unused-argument
        pass

    gorilla_setting = gorilla.Settings(allow_hit=True, store_hit=store_hit)
    patch_B_f2 = gorilla.Patch(B, "f2", patched_B_f2, gorilla_setting)
    gorilla.apply(patch_B_f2)

    assert B.f2 is patched_B_f2

    assert gorilla.get_original_attribute(B, "f2") is original_A_f2

    gorilla.revert(patch_B_f2)
    assert B.f2 is original_A_f2
    assert gorilla.get_original_attribute(B, "f2") is original_A_f2
    assert "f2" not in B.__dict__  # assert no side effect after reverting
示例#8
0
def wrap_patch(destination, name, patch, settings=None):
    """
    Apply a patch while preserving the attributes (e.g. __doc__) of an original function.

    TODO(dbczumar): Convert this to an internal method once existing `wrap_patch` calls
                    outside of `autologging_utils` have been converted to `safe_patch`

    :param destination: Patch destination
    :param name: Name of the attribute at the destination
    :param patch: Patch function
    :param settings: Settings for gorilla.Patch
    """
    if settings is None:
        settings = gorilla.Settings(allow_hit=True, store_hit=True)

    original = getattr(destination, name)
    wrapped = _update_wrapper_extended(patch, original)

    patch = gorilla.Patch(destination, name, wrapped, settings=settings)
    gorilla.apply(patch)