Ejemplo n.º 1
0
def test_patch_on_attribute_not_exist(store_hit):
    """Patching a name the class does not define adds it; reverting removes it again."""
    cls_a, _unused_b = gen_class_A_B()

    def patched_fx(self):  # pylint: disable=unused-argument
        return 101

    settings = gorilla.Settings(allow_hit=True, store_hit=store_hit)
    patch = gorilla.Patch(cls_a, "fx", patched_fx, settings)

    gorilla.apply(patch)
    instance = cls_a()
    assert instance.fx() == 101

    # Reverting a patch for a previously-missing attribute must delete it outright.
    gorilla.revert(patch)
    assert not hasattr(cls_a, "fx")
Ejemplo n.º 2
0
def revert_patches(autologging_integration):
    """
    Reverts all patches registered for the specified autologging integration, for autologging
    disablement purposes.

    Every patch stored in ``_AUTOLOGGING_PATCHES`` under ``autologging_integration`` is
    reverted, and the integration's entry is then removed. If no patches are registered under
    that name, this is a no-op.

    :param autologging_integration: The name of the autologging integration associated with the
                                    patches. Note: If called via the fluent API
                                    (`autologging_integration="mlflow"`), then revert all patches
                                    for all active autologging integrations.
    """
    # A missing key yields an empty iterable, so unknown integrations are silently ignored.
    for patch in _AUTOLOGGING_PATCHES.get(autologging_integration, []):
        gorilla.revert(patch)

    # Drop the bookkeeping entry only after every patch has been reverted.
    _AUTOLOGGING_PATCHES.pop(autologging_integration, None)
Ejemplo n.º 3
0
def test_patch_on_inherit_method(store_hit):
    """Patching an inherited method on a subclass leaves no trace in its ``__dict__`` after revert."""
    base, derived = gen_class_A_B()
    inherited_f2 = base.f2

    def patched_B_f2(self):  # pylint: disable=unused-argument
        pass

    patch = gorilla.Patch(
        derived,
        "f2",
        patched_B_f2,
        gorilla.Settings(allow_hit=True, store_hit=store_hit),
    )
    gorilla.apply(patch)

    # While patched, the subclass sees the replacement, but the recorded original is still
    # the method it inherits from the base class.
    assert derived.f2 is patched_B_f2
    assert gorilla.get_original_attribute(derived, "f2") is inherited_f2

    gorilla.revert(patch)

    # After reverting, lookup falls through to the base class again instead of finding a copy
    # pinned onto the subclass.
    assert derived.f2 is inherited_f2
    assert gorilla.get_original_attribute(derived, "f2") is inherited_f2
    assert "f2" not in derived.__dict__
Ejemplo n.º 4
0
def test_basic_patch_for_class(gorilla_setting):
    """
    Apply and revert patches on ``A.f1``, ``B.f1`` and ``A.f2`` in interleaved order.

    After every step, checks both the live attribute on each class and what
    ``gorilla.get_original_attribute`` reports for it. Identity (``is``) is used for
    every comparison because the test is about which exact function object is
    installed or reported, not about value equality.
    """
    A, B = gen_class_A_B()

    original_A_f1 = A.f1
    original_A_f2 = A.f2
    original_B_f1 = B.f1

    def patched_A_f1(self):  # pylint: disable=unused-argument
        pass

    def patched_A_f2(self):  # pylint: disable=unused-argument
        pass

    def patched_B_f1(self):  # pylint: disable=unused-argument
        pass

    patch_A_f1 = gorilla.Patch(A, "f1", patched_A_f1, gorilla_setting)
    patch_A_f2 = gorilla.Patch(A, "f2", patched_A_f2, gorilla_setting)
    patch_B_f1 = gorilla.Patch(B, "f1", patched_B_f1, gorilla_setting)

    # Before anything is applied, the originals are the current attributes;
    # B's f2 resolves to A's.
    assert gorilla.get_original_attribute(A, "f1") is original_A_f1
    assert gorilla.get_original_attribute(B, "f1") is original_B_f1
    assert gorilla.get_original_attribute(B, "f2") is original_A_f2

    gorilla.apply(patch_A_f1)
    assert A.f1 is patched_A_f1
    assert gorilla.get_original_attribute(A, "f1") is original_A_f1
    assert gorilla.get_original_attribute(B, "f1") is original_B_f1

    # Patching B.f1 must not disturb the patch already applied to A.f1.
    gorilla.apply(patch_B_f1)
    assert A.f1 is patched_A_f1
    assert B.f1 is patched_B_f1
    assert gorilla.get_original_attribute(A, "f1") is original_A_f1
    assert gorilla.get_original_attribute(B, "f1") is original_B_f1

    # Patching A.f2 is visible through B, whose f2 resolves to A's.
    gorilla.apply(patch_A_f2)
    assert A.f2 is patched_A_f2
    assert B.f2 is patched_A_f2
    assert gorilla.get_original_attribute(A, "f2") is original_A_f2
    assert gorilla.get_original_attribute(B, "f2") is original_A_f2

    gorilla.revert(patch_A_f2)
    assert A.f2 is original_A_f2
    assert B.f2 is original_A_f2
    assert gorilla.get_original_attribute(A, "f2") is original_A_f2
    assert gorilla.get_original_attribute(B, "f2") is original_A_f2

    # Reverting B.f1 restores only B; A.f1 stays patched.
    gorilla.revert(patch_B_f1)
    assert A.f1 is patched_A_f1
    assert B.f1 is original_B_f1
    assert gorilla.get_original_attribute(A, "f1") is original_A_f1
    assert gorilla.get_original_attribute(B, "f1") is original_B_f1

    gorilla.revert(patch_A_f1)
    assert A.f1 is original_A_f1
    assert B.f1 is original_B_f1
    assert gorilla.get_original_attribute(A, "f1") is original_A_f1
    assert gorilla.get_original_attribute(B, "f1") is original_B_f1
Ejemplo n.º 5
0
def test_patch_for_descriptor(gorilla_setting):
    """Patching a descriptor attribute keeps both the resolved and raw originals intact."""
    A, _ = gen_class_A_B()

    # The raw class-dict entry, fetched without triggering the descriptor protocol.
    raw_f3 = object.__getattribute__(A, "f3")

    def check_originals_unchanged():
        # Resolved original goes through the descriptor; the bypassed one is the raw entry.
        assert gorilla.get_original_attribute(A, "f3") is A.delegated_f3
        assert (
            gorilla.get_original_attribute(A, "f3", bypass_descriptor_protocol=True)
            is raw_f3
        )

    def patched_A_f3(self):  # pylint: disable=unused-argument
        pass

    first_patch = gorilla.Patch(A, "f3", patched_A_f3, gorilla_setting)
    check_originals_unchanged()

    gorilla.apply(first_patch)
    assert A.f3 is patched_A_f3
    check_originals_unchanged()

    gorilla.revert(first_patch)
    assert A.f3 is A.delegated_f3
    check_originals_unchanged()

    # Now install a descriptor as the patch itself.
    @delegate(patched_A_f3)
    def new_patched_A_f3(self):  # pylint: disable=unused-argument
        pass

    descriptor_patch = gorilla.Patch(A, "f3", new_patched_A_f3, gorilla_setting)
    gorilla.apply(descriptor_patch)
    # Attribute access resolves through the new descriptor, while the raw slot holds it directly.
    assert A.f3 is patched_A_f3
    assert object.__getattribute__(A, "f3") is new_patched_A_f3
    check_originals_unchanged()