def test_wrap_patch_with_module():
    """Patching a module-level function via ``_wrap_patch`` replaces it in place.

    The test module's attribute is restored afterwards, so the patch does not
    leak into other tests that call ``sample_function_to_patch``.
    """
    import sys

    this_module = sys.modules[__name__]
    name = sample_function_to_patch.__name__
    # Capture the unpatched function so it can be put back in ``finally``.
    original = getattr(this_module, name)

    def new_sample_function(a, b):
        """Replacement for ``sample_function_to_patch`` that subtracts."""
        return a - b

    # Unpatched behaviour: (10, 5) -> 15.
    assert sample_function_to_patch(10, 5) == 15
    try:
        _wrap_patch(this_module, name, new_sample_function)
        # The global-name lookup now resolves to the patched function.
        assert sample_function_to_patch(10, 5) == 5
    finally:
        # Undo the module-level mutation so later tests see the original.
        setattr(this_module, name, original)
def test_wrap_patch_with_class():
    """A method patched onto a class can delegate to its original implementation."""

    class Math:
        def add(self, a, b):
            """add"""
            return a + b

    def new_add(self, *args, **kwargs):
        """new add"""
        # gorilla keeps the pre-patch attribute reachable by name.
        original_add = gorilla.get_original_attribute(self, "add")
        doubled = 2 * original_add(*args, **kwargs)
        return doubled

    target_name = Math.add.__name__
    _wrap_patch(Math, target_name, new_add)

    calculator = Math()
    # Original add gives 1 + 2 == 3; the patched version doubles that.
    assert calculator.add(1, 2) == 6