Esempio n. 1
0
    def test_create_patches_3(self):
        """Check ``create_patches`` with a name filter and decorators ignored.

        For each destination/source pair, only the members whose name
        contains ``'method'`` are expected, each patched under its own name.
        """
        def only_methods(name, value):
            return 'method' in name

        # (destination, source object, member names expected to be patched)
        cases = [
            (_tomodule, _frommodule,
             ('unbound_class_method', 'unbound_method',
              'unbound_static_method')),
            (_tomodule.Class, _frommodule.Class,
             ('class_method', 'method', 'static_method')),
            (_tomodule.Parent, _frommodule.Parent, ('method',)),
            (_tomodule.Child, _frommodule.Child, ('method',)),
        ]
        for destination, obj, member_names in cases:
            patches = gorilla.create_patches(destination,
                                             obj,
                                             filter=only_methods,
                                             use_decorators=False)
            expected_patches = [
                gorilla.Patch(destination, member,
                              gorilla.get_attribute(obj, member))
                for member in member_names
            ]
            self.assertEqual(patches, expected_patches)
Esempio n. 2
0
    def test_apply_patch_with_hit_3(self):
        """Apply patches over existing attributes with ``store_hit`` disabled.

        For every combination of source object and target attribute, the
        attribute must resolve to the patch object afterwards, and asking
        for the original attribute must raise ``AttributeError``.
        """
        settings = gorilla.Settings(allow_hit=True, store_hit=False)

        sources = [''] + _list_attribute_paths(_frommodule)
        targets = _list_attribute_paths(_tomodule)
        for source_path, target_path in itertools.product(sources, targets):
            # Re-run the fixture setup for each combination.
            self.setUp()

            parent_path, name = _split_attribute_path(target_path)
            destination = _get_attribute_from_path(_tomodule, parent_path)
            replacement = _get_attribute_from_path(_frommodule, source_path)
            gorilla.apply(
                gorilla.Patch(destination, name, replacement,
                              settings=settings))

            # The destination object itself must still be the same object.
            self.assertIs(
                destination,
                _get_attribute_from_path(_tomodule, parent_path))

            # The patched name now resolves to the replacement.
            self.assertIs(gorilla.get_attribute(destination, name),
                          replacement)

            # With store_hit=False no original attribute is retrievable.
            self.assertRaises(AttributeError, gorilla.get_original_attribute,
                              destination, name)

            self.tearDown()
Esempio n. 3
0
    def test_create_patches_5(self):
        """Check ``create_patches`` after renaming members via ``gorilla.name``.

        Three classes of the source module get a name override before the
        patches are created with the default arguments.
        """
        destination = _tomodule
        obj = _frommodule
        fetch = gorilla.get_attribute

        # (overriding name, member receiving the override)
        for new_name, member in (('function', 'Class'),
                                 ('dummy_1', 'Parent'),
                                 ('dummy_2', 'Child')):
            gorilla.name(new_name)(fetch(obj, member))
        patches = gorilla.create_patches(destination, obj)

        # (name expected on the patch, member providing the patch object)
        plain_expectations = (
            ('dummy_2', 'Child'),
            ('function', 'Class'),
            ('dummy_1', 'Parent'),
            ('function', 'function'),
            ('global_variable', 'global_variable'),
            ('whatever', 'unbound_class_method'),
        )
        expected_patches = [
            gorilla.Patch(destination, patch_name, fetch(obj, member))
            for patch_name, member in plain_expectations
        ]
        # The last expected patch is the only one carrying explicit settings.
        expected_patches.append(
            gorilla.Patch(destination,
                          'unbound_static_method',
                          fetch(obj, 'unbound_static_method'),
                          settings=gorilla.Settings(allow_hit=True)))
        self.assertEqual(patches, expected_patches)
Esempio n. 4
0
    def test_name_decorator(self):
        """Check that ``gorilla.name`` overrides show up in recorded patches.

        Three members are given the same overriding name before the class is
        decorated with ``gorilla.patches``; the patches stored in the
        decorator data are then compared against the expected list.
        """
        destination = _tomodule.Class
        obj = _frommodule.Class
        fetch = gorilla.get_attribute

        name_override = 'whatever'
        for owner, member in ((obj, 'class_method'),
                              (obj, 'static_method'),
                              (obj.Inner, 'method')):
            gorilla.name(name_override)(fetch(owner, member))
        gorilla.patches(destination)(obj)

        # (patch destination, patch name, source owner, source member)
        expectations = (
            (destination, 'STATIC_VALUE', obj, 'STATIC_VALUE'),
            (destination, name_override, obj, 'class_method'),
            (destination, 'method', obj, 'method'),
            (destination, name_override, obj, 'static_method'),
            (destination, 'value', obj, 'value'),
            (destination.Inner, 'STATIC_VALUE', obj.Inner, 'STATIC_VALUE'),
            (destination.Inner, name_override, obj.Inner, 'method'),
        )
        expected_patches = [
            gorilla.Patch(target, patch_name, fetch(owner, member))
            for target, patch_name, owner, member in expectations
        ]

        decorator_data = gorilla.get_decorator_data(obj)
        self.assertEqual(decorator_data.patches, expected_patches)
Esempio n. 5
0
    def patch_frelu(cls):
        """Replace ``torch.nn.functional.relu`` with an identity function.

        The patch is applied with ``allow_hit=True`` so the existing
        attribute may be overwritten, and a confirmation line is printed.
        """
        import gorilla
        import torch.nn.functional as F

        gorilla.apply(
            gorilla.Patch(F,
                          'relu',
                          lambda x: x,
                          settings=gorilla.Settings(allow_hit=True)))

        print('F.relu have been patched to identity fn.')
Esempio n. 6
0
def autolog():
    """
    Enable automatic logging from Keras to MLflow.

    Patches ``keras.Model.fit`` so that every call gets an extra callback
    which logs the per-epoch metrics and, once training ends, the number of
    layers, the optimizer name, the learning rate and epsilon (when the
    optimizer has them), the model summary as a tag, and the model itself
    under the ``'model'`` artifact path.
    """
    import keras

    class __MLflowKerasCallback(keras.callbacks.Callback):
        """
        Callback for auto-logging metrics and parameters.
        Records available logs after each epoch.
        Records model structural information as params after training finishes.
        """
        def on_epoch_end(self, epoch, logs=None):
            if not logs:
                return
            try_mlflow_log(mlflow.log_metrics, logs, step=epoch)

        def on_train_end(self, logs=None):
            try_mlflow_log(mlflow.log_param, 'num_layers',
                           len(self.model.layers))
            try_mlflow_log(mlflow.log_param, 'optimizer_name',
                           type(self.model.optimizer).__name__)
            # Plain Python floats are logged as-is; any other value is
            # evaluated through the Keras backend first.
            if hasattr(self.model.optimizer, 'lr'):
                lr = self.model.optimizer.lr if \
                    type(self.model.optimizer.lr) is float \
                    else keras.backend.eval(self.model.optimizer.lr)
                try_mlflow_log(mlflow.log_param, 'learning_rate', lr)
            if hasattr(self.model.optimizer, 'epsilon'):
                epsilon = self.model.optimizer.epsilon if \
                    type(self.model.optimizer.epsilon) is float \
                    else keras.backend.eval(self.model.optimizer.epsilon)
                try_mlflow_log(mlflow.log_param, 'epsilon', epsilon)
            # Collect the summary lines instead of printing them.
            sum_list = []
            self.model.summary(print_fn=sum_list.append)
            summary = '\n'.join(sum_list)
            try_mlflow_log(mlflow.set_tag, 'summary', summary)
            try_mlflow_log(log_model, self.model, artifact_path='model')

    @gorilla.patch(keras.Model)
    def fit(self, *args, **kwargs):
        original = gorilla.get_original_attribute(keras.Model, 'fit')
        # Always build a NEW callbacks list: appending in place would mutate
        # the list owned by the caller (accumulating one MLflow callback per
        # fit() call) and would fail when callbacks is explicitly None.
        if len(args) >= 6:
            # 'callbacks' was passed positionally (sixth argument after self).
            positional = list(args)
            positional[5] = list(positional[5] or []) + \
                [__MLflowKerasCallback()]
            args = tuple(positional)
        else:
            kwargs['callbacks'] = list(kwargs.get('callbacks') or []) + \
                [__MLflowKerasCallback()]
        return original(self, *args, **kwargs)

    # store_hit=True keeps the original fit retrievable through
    # gorilla.get_original_attribute, which the replacement relies on.
    settings = gorilla.Settings(allow_hit=True, store_hit=True)
    patch = gorilla.Patch(keras.Model, 'fit', fit, settings=settings)
    gorilla.apply(patch)
Esempio n. 7
0
def set_seed(seed=100):
    """Seed NumPy and ``random``, then monkey patch ``random_seed.get_seed``.

    The replacement for ``get_seed`` forwards the operation seed to
    ``better_get_seed`` together with ``seed``.

    :param seed: Seed passed to both random generators and to
        ``better_get_seed``.
    """
    np.random.seed(seed)
    random.seed(seed)
    # Monkey Patch get_seed; store_hit keeps the original retrievable.
    gorilla.apply(
        gorilla.Patch(
            random_seed,
            'get_seed',
            lambda op_seed: better_get_seed(seed, op_seed),
            settings=gorilla.Settings(allow_hit=True, store_hit=True),
        ))
Esempio n. 8
0
    def test_patch_decorator_on_class(self):
        """``gorilla.patch`` on a class returns it and records one patch."""
        destination = _tomodule
        obj = _frommodule.Class

        decorated = gorilla.patch(destination)(obj)
        self.assertIs(decorated, obj)

        # A single patch named after the class is expected.
        recorded = gorilla.get_decorator_data(obj).patches
        self.assertEqual(recorded, [gorilla.Patch(destination, 'Class', obj)])
Esempio n. 9
0
    def test_patch_decorator_on_function(self):
        """``gorilla.patch`` on a function returns it and records one patch."""
        destination = _tomodule
        obj = gorilla.get_attribute(_frommodule, 'function')

        decorated = gorilla.patch(destination)(obj)
        self.assertIs(decorated, obj)

        # A single patch named after the function is expected.
        recorded = gorilla.get_decorator_data(obj).patches
        self.assertEqual(recorded,
                         [gorilla.Patch(destination, 'function', obj)])
Esempio n. 10
0
def apply_gorrila(function: Callable, module: Any):
    """Override an attribute of ``module`` with ``function`` via gorilla.

    The attribute replaced is the one named ``function.__name__``; the patch
    is applied with ``allow_hit=True`` so an existing attribute of that name
    may be overwritten.

    Args:
        function (Callable): Override function
        module (Any): Function caller module
    """
    overwrite_allowed = gorilla.Settings(allow_hit=True)
    target_name = function.__name__
    gorilla.apply(
        gorilla.Patch(module, target_name, function,
                      settings=overwrite_allowed))
Esempio n. 11
0
    def test_filter_decorator(self):
        """Check that ``gorilla.filter`` marks affect the recorded patches.

        ``__init__`` is marked with ``True`` and is expected in the result;
        ``class_method`` and ``Inner.method`` are marked with ``False`` and
        are expected to be absent.
        """
        destination = _tomodule.Class
        obj = _frommodule.Class
        fetch = gorilla.get_attribute

        for keep, owner, member in ((True, obj, '__init__'),
                                    (False, obj, 'class_method'),
                                    (False, obj.Inner, 'method')):
            gorilla.filter(keep)(fetch(owner, member))
        gorilla.patches(destination)(obj)

        expected_patches = [
            gorilla.Patch(destination, member, fetch(obj, member))
            for member in ('STATIC_VALUE', '__init__', 'method',
                           'static_method', 'value')
        ]
        expected_patches.append(
            gorilla.Patch(destination.Inner, 'STATIC_VALUE',
                          fetch(obj.Inner, 'STATIC_VALUE')))

        decorator_data = gorilla.get_decorator_data(obj)
        self.assertEqual(decorator_data.patches, expected_patches)
Esempio n. 12
0
    def test_settings_decorator_2(self):
        """Check how per-member settings combine with ``patches`` settings.

        Individual members receive setting overrides through
        ``gorilla.settings`` before the class is decorated with
        ``gorilla.patches`` using ``allow_hit=True, store_hit=False``.
        """
        destination = _tomodule.Class
        obj = _frommodule.Class
        fetch = gorilla.get_attribute

        # (member owner, member name, setting overrides applied to it)
        overrides = (
            (obj, 'method', {'some_value': 123}),
            (obj, 'static_method', {'allow_hit': False}),
            (obj, 'value', {'store_hit': True}),
            (obj.Inner, 'method', {'allow_hit': False, 'store_hit': True}),
        )
        for owner, member, values in overrides:
            gorilla.settings(**values)(fetch(owner, member))
        gorilla.patches(destination,
                        settings=gorilla.Settings(allow_hit=True,
                                                  store_hit=False))(obj)

        # (patch destination, source owner, member name, expected settings)
        expectations = (
            (destination, obj, 'STATIC_VALUE',
             {'allow_hit': True, 'store_hit': False}),
            (destination, obj, 'class_method',
             {'allow_hit': True, 'store_hit': False}),
            (destination, obj, 'method',
             {'allow_hit': True, 'some_value': 123, 'store_hit': False}),
            (destination, obj, 'static_method', {'store_hit': False}),
            (destination, obj, 'value', {'allow_hit': True}),
            (destination.Inner, obj.Inner, 'STATIC_VALUE',
             {'allow_hit': True, 'store_hit': False}),
            (destination.Inner, obj.Inner, 'method',
             {'allow_hit': False, 'store_hit': True}),
        )
        expected_patches = [
            gorilla.Patch(target,
                          member,
                          fetch(owner, member),
                          settings=gorilla.Settings(**values))
            for target, owner, member, values in expectations
        ]

        decorator_data = gorilla.get_decorator_data(obj)
        self.assertEqual(decorator_data.patches, expected_patches)
Esempio n. 13
0
    def test_patch(self):
        """Exercise ``gorilla.Patch`` equality and string conversion."""
        first = gorilla.Patch(_tomodule, 'dummy', _frommodule.function)
        second = gorilla.Patch(_tomodule,
                               'dummy',
                               _frommodule.function,
                               settings=None)
        # Omitting settings and passing settings=None compare equal.
        self.assertEqual(first, second)

        # A dict holding the same fields must not compare equal to a patch.
        lookalike = {
            'destination': _tomodule,
            'name': 'dummy',
            'obj': _frommodule.function,
            'settings': None
        }
        self.assertNotEqual(first, lookalike)

        # Change one side only (patches differ), then bring the other side
        # in line (patches are equal again).
        for attribute, value in (('name', 'function'), ('some_value', 123)):
            setattr(first, attribute, value)
            self.assertNotEqual(first, second)

            setattr(second, attribute, value)
            self.assertEqual(first, second)

        patch = gorilla.Patch(_tomodule, 'dummy', _frommodule.function)
        expected_text = (
            "Patch(destination=%r, name='dummy', obj=%r, settings=None)" %
            (_tomodule, _frommodule.function))
        self.assertEqual(str(patch), expected_text)

        # The string form is expected to be the same after adding an
        # extra attribute.
        patch.some_value = 123
        self.assertEqual(str(patch), expected_text)
Esempio n. 14
0
def monkey_patch_tf_get_seed(seed: int,
                             default_op_seed: int = 1923746) -> None:
    """
    Monkey patching tensorflow.random.get_seed to avoid the increasing memory usage arising from
    repeated random sampling from tensorflow distributions.

    This code is taken from https://github.com/lerobitaille/tf-issue-36164-workaround which remedies
    issue 36164 (https://github.com/tensorflow/tensorflow/issues/36164).

    We have raised our own clearer and more concise issue, which should be the reference point
    for this memory leak: https://github.com/tensorflow/tensorflow/issues/37252

    :param seed: Seed to set as the TensorFlow global seed.
    :param default_op_seed: Default seed for any random operations if required.
    """
    warn(
        "WARNING: Patching native TensorFlow functionality to avoid memory leak when setting "
        "a random seed.")
    warn("WARNING: Patch required due to TensorFlow issue 37252. "
         "Check if the issue is resolved at "
         "https://github.com/tensorflow/tensorflow/issues/37252")
    # Lazy imports to show which imports to remove once the issue is resolved and to avoid wider
    # usage of monkey patching and usage of the TensorFlow back end which involves imports the
    # linter does not like.
    # pylint: disable=no-name-in-module,import-error
    from tensorflow.python.eager import context
    from tensorflow.python import pywrap_tensorflow
    from tensorflow.python.framework import random_seed
    # Remove gorilla dependency completely when issue fixed. (Remove from requirements.txt)
    import gorilla

    def better_get_seed(global_seed, op_seed):
        """Return ``(global_seed, op_seed)``, using the default op seed for ``None``."""
        if op_seed is not None:
            return global_seed, op_seed
        return global_seed, default_op_seed

    # Monkey Patch get_seed.
    def func(op_seed):
        # The seed pair must be returned: without the ``return`` the patched
        # get_seed would hand ``None`` back to every caller.
        return better_get_seed(seed, op_seed)

    settings = gorilla.Settings(allow_hit=True, store_hit=True)
    patch = gorilla.Patch(random_seed, 'get_seed', func, settings=settings)
    gorilla.apply(patch)

    # Also clear the kernel cache, to reset any existing seeds
    # pylint: disable=protected-access
    _context = context.context()
    if _context._context_handle is not None:
        pywrap_tensorflow.TFE_ContextClearCaches(_context._context_handle)
Esempio n. 15
0
    def test_create_patches_6(self):
        """Check non-recursive ``create_patches`` without filter or decorators.

        One patch per entry of the source class ``__dict__`` is expected,
        in sorted order.
        """
        destination = _tomodule.Class
        obj = _frommodule.Class
        options = {'filter': None, 'recursive': False,
                   'use_decorators': False}
        patches = gorilla.create_patches(destination, obj, **options)

        members = sorted(_iteritems(obj.__dict__))
        expected_patches = [
            gorilla.Patch(destination, member_name, member)
            for member_name, member in members
        ]

        self.assertEqual(patches, expected_patches)
Esempio n. 16
0
    def test_patch_decorator(self):
        """Check the patch recorded by ``gorilla.patch`` with settings.

        The settings object is modified after the decoration; the recorded
        patch is still expected to carry the values given at decoration time.
        """
        destination = _tomodule
        obj = gorilla.get_attribute(_frommodule, 'function')

        settings = gorilla.Settings(allow_hit=True, store_hit=True)
        decorated = gorilla.patch(destination, settings=settings)(obj)
        self.assertIs(decorated, obj)

        # Flip both values on the caller's object after the fact.
        settings.allow_hit = False
        settings.store_hit = False

        untouched = gorilla.Settings(allow_hit=True, store_hit=True)
        expected_patches = [
            gorilla.Patch(destination, 'function', obj, untouched),
        ]
        recorded = gorilla.get_decorator_data(obj).patches
        self.assertEqual(recorded, expected_patches)
Esempio n. 17
0
    def test_get_original_attribute(self):
        """The original attribute stays retrievable across repeated patching.

        The same patch is applied twice; after each application the
        original attribute must still be the object found before patching.
        """
        destination = _tomodule.Class
        name = 'method'
        target = gorilla.get_attribute(destination, name)
        replacement = gorilla.get_attribute(_frommodule, 'unbound_method')
        patch = gorilla.Patch(destination,
                              name,
                              replacement,
                              settings=gorilla.Settings(allow_hit=True))

        for _ in range(2):
            gorilla.apply(patch)
            original = gorilla.get_original_attribute(destination, name)
            self.assertIs(_unfold(original), target)
Esempio n. 18
0
    def test_apply_patch_with_hit_1(self):
        """Applying over an existing attribute with default settings fails.

        Every combination of source object and target attribute is expected
        to raise ``RuntimeError`` from ``gorilla.apply``.
        """
        settings = gorilla.Settings()

        sources = [''] + _list_attribute_paths(_frommodule)
        targets = _list_attribute_paths(_tomodule)
        for source_path, target_path in itertools.product(sources, targets):
            # Re-run the fixture setup for each combination.
            self.setUp()

            parent_path, name = _split_attribute_path(target_path)
            destination = _get_attribute_from_path(_tomodule, parent_path)
            replacement = _get_attribute_from_path(_frommodule, source_path)
            patch = gorilla.Patch(destination, name, replacement,
                                  settings=settings)
            self.assertRaises(RuntimeError, gorilla.apply, patch)

            self.tearDown()
Esempio n. 19
0
def wrap_patch(destination, name, patch, settings=None):
    """
    Patch an attribute while copying the metadata (e.g. __doc__) of the
    attribute being replaced onto the patch function.

    :param destination: Patch destination
    :param name: Name of the attribute at the destination
    :param patch: Patch function
    :param settings: Settings for gorilla.Patch; when omitted, settings with
        ``allow_hit=True`` and ``store_hit=True`` are used
    """
    effective_settings = settings
    if effective_settings is None:
        effective_settings = gorilla.Settings(allow_hit=True, store_hit=True)

    # Look up the current attribute first so its metadata can be copied.
    current = getattr(destination, name)
    decorated = functools.wraps(current)(patch)
    gorilla.apply(
        gorilla.Patch(destination, name, decorated,
                      settings=effective_settings))
Esempio n. 20
0
    def test_patches_decorator(self):
        """Check the patches recorded by ``gorilla.patches`` with settings.

        The settings object is modified after the decoration; every recorded
        patch is still expected to carry the values given at decoration time.
        """
        destination = _tomodule.Class
        obj = _frommodule.Class
        fetch = gorilla.get_attribute

        settings = gorilla.Settings(allow_hit=True, store_hit=True)
        decorated = gorilla.patches(destination, settings=settings)(obj)
        self.assertIs(decorated, obj)

        # Flip both values on the caller's object after the fact.
        settings.allow_hit = False
        settings.store_hit = False

        # (patch destination, source owner, member name)
        expectations = (
            (destination, obj, 'STATIC_VALUE'),
            (destination, obj, 'class_method'),
            (destination, obj, 'method'),
            (destination, obj, 'static_method'),
            (destination, obj, 'value'),
            (destination.Inner, obj.Inner, 'STATIC_VALUE'),
            (destination.Inner, obj.Inner, 'method'),
        )
        expected_patches = [
            gorilla.Patch(target,
                          member,
                          fetch(owner, member),
                          settings=gorilla.Settings(allow_hit=True,
                                                    store_hit=True))
            for target, owner, member in expectations
        ]

        decorator_data = gorilla.get_decorator_data(obj)
        self.assertEqual(decorator_data.patches, expected_patches)
Esempio n. 21
0
    def test_create_patches_4(self):
        """Check ``create_patches`` with a name filter and default options.

        Unlike the variant that passes ``use_decorators=False``, the
        expected patches here include overridden names and explicit
        settings.
        """
        def only_methods(name, value):
            return 'method' in name

        fetch = gorilla.get_attribute

        # Module level.
        destination = _tomodule
        obj = _frommodule
        patches = gorilla.create_patches(destination, obj,
                                         filter=only_methods)
        expected_patches = [
            gorilla.Patch(destination, patch_name, fetch(obj, member))
            for patch_name, member in (('function', 'function'),
                                       ('whatever', 'unbound_class_method'))
        ]
        expected_patches.append(
            gorilla.Patch(destination,
                          'unbound_static_method',
                          fetch(obj, 'unbound_static_method'),
                          settings=gorilla.Settings(allow_hit=True)))
        self.assertEqual(patches, expected_patches)

        # Class level.
        destination = _tomodule.Class
        obj = _frommodule.Class
        patches = gorilla.create_patches(destination, obj,
                                         filter=only_methods)
        expected_patches = [
            gorilla.Patch(destination, patch_name, fetch(obj, member))
            for patch_name, member in (('class_method', 'class_method'),
                                       ('whatever', 'method'))
        ]
        self.assertEqual(patches, expected_patches)

        # Parent and Child each expect a single 'method' patch.
        for destination, obj in ((_tomodule.Parent, _frommodule.Parent),
                                 (_tomodule.Child, _frommodule.Child)):
            patches = gorilla.create_patches(destination, obj,
                                             filter=only_methods)
            expected_patches = [
                gorilla.Patch(destination, 'method', fetch(obj, 'method')),
            ]
            self.assertEqual(patches, expected_patches)
Esempio n. 22
0
    def test_patches_decorator_on_class(self):
        """``gorilla.patches`` on a class returns it and records its members.

        Patches are expected for the members of the class and for the
        members of its ``Inner`` class, targeting ``destination.Inner``.
        """
        destination = _tomodule.Class
        obj = _frommodule.Class
        fetch = gorilla.get_attribute

        decorated = gorilla.patches(destination)(obj)
        self.assertIs(decorated, obj)

        outer_members = ('STATIC_VALUE', 'class_method', 'method',
                         'static_method', 'value')
        inner_members = ('STATIC_VALUE', 'method')
        expected_patches = [
            gorilla.Patch(destination, member, fetch(obj, member))
            for member in outer_members
        ] + [
            gorilla.Patch(destination.Inner, member,
                          fetch(obj.Inner, member))
            for member in inner_members
        ]

        decorator_data = gorilla.get_decorator_data(obj)
        self.assertEqual(decorator_data.patches, expected_patches)
Esempio n. 23
0
def autolog():
    """
    Enable automatic logging from Fastai to MLflow.

    Patches ``Learner.fit`` so that every call logs its arguments as
    parameters, the loss and other metrics after each epoch, optimizer data
    and a model summary when training begins, and the trained model under
    the ``"model"`` artifact path when training ends. A run is started (and
    ended after training) when none is active.

    MLflow will also log the parameters of the EarlyStopping and OneCycleScheduler callbacks
    """
    from fastai.basic_train import LearnerCallback, Learner
    from fastai.callbacks.hooks import model_summary, layers_info
    from fastai.callbacks import EarlyStoppingCallback, OneCycleScheduler

    class __MLflowFastaiCallback(LearnerCallback):
        """
        Callback for auto-logging metrics and parameters.
        Records model structural information as params when training begins
        """
        def __init__(
            self,
            learner,
        ):
            super().__init__(learner)
            self.learner = learner
            self.opt = self.learn.opt
            self.metrics_names = ["train_loss", "valid_loss"
                                  ] + [o.__name__ for o in learner.metrics]

        def on_epoch_end(self, **kwargs):
            """
            Log loss and other metrics values after each epoch
            """
            if kwargs["smooth_loss"] is None or kwargs["last_metrics"] is None:
                return
            epoch = kwargs["epoch"]
            metrics = [kwargs["smooth_loss"]] + kwargs["last_metrics"]
            metrics = map(float, metrics)
            metrics = dict(zip(self.metrics_names, metrics))
            try_mlflow_log(mlflow.log_metrics, metrics, step=epoch)

        def on_train_begin(self, **kwargs):
            """
            Log layer count, optimizer data and the model summary
            """
            info = layers_info(self.learner)
            try_mlflow_log(mlflow.log_param, "num_layers", len(info))
            # NOTE(review): opt_func and train_bn are not set on this
            # callback; presumably attribute access falls through to the
            # learner - confirm against LearnerCallback.
            try_mlflow_log(mlflow.log_param, "opt_func",
                           self.opt_func.func.__name__)

            if hasattr(self.opt, "true_wd"):
                try_mlflow_log(mlflow.log_param, "true_wd", self.opt.true_wd)

            if hasattr(self.opt, "bn_wd"):
                try_mlflow_log(mlflow.log_param, "bn_wd", self.opt.bn_wd)

            if hasattr(self.opt, "train_bn"):
                try_mlflow_log(mlflow.log_param, "train_bn", self.train_bn)

            summary = model_summary(self.learner)
            try_mlflow_log(mlflow.set_tag, "model_summary", summary)

            # The summary is also logged as a text artifact through a
            # temporary directory that is always removed afterwards.
            tempdir = tempfile.mkdtemp()
            try:
                summary_file = os.path.join(tempdir, "model_summary.txt")
                with open(summary_file, "w") as f:
                    f.write(summary)
                try_mlflow_log(mlflow.log_artifact, local_path=summary_file)
            finally:
                shutil.rmtree(tempdir)

        def on_train_end(self, **kwargs):
            """
            Log the trained learner as a model
            """
            try_mlflow_log(log_model, self.learner, artifact_path="model")

    def _find_callback_of_type(callback_type, callbacks):
        # Return the first callback of the given type, or None.
        for callback in callbacks:
            if isinstance(callback, callback_type):
                return callback
        return None

    def _log_early_stop_callback_params(callback):
        # Best effort: any failure while reading the attributes is ignored.
        if callback:
            try:
                earlystopping_params = {
                    "early_stop_monitor": callback.monitor,
                    "early_stop_min_delta": callback.min_delta,
                    "early_stop_patience": callback.patience,
                    "early_stop_mode": callback.mode,
                }
                try_mlflow_log(mlflow.log_params, earlystopping_params)
            except Exception:  # pylint: disable=W0703
                return

    def _log_one_cycle_callback_params(callback):
        # Best effort: any failure while reading the attributes is ignored.
        if callback:
            try:
                params = {
                    "lr_max": callback.lr_max,
                    "div_factor": callback.div_factor,
                    "pct_start": callback.pct_start,
                    "final_div": callback.final_div,
                    "tot_epochs": callback.tot_epochs,
                    "start_epoch": callback.start_epoch,
                    "moms": callback.moms,
                }
                try_mlflow_log(mlflow.log_params, params)
            except Exception:  # pylint: disable=W0703
                return

    def _run_and_log_function(self, original, args, kwargs, unlogged_params,
                              callback_arg_index):
        # Start a run only when none is active, and remember to end it.
        if not mlflow.active_run():
            try_mlflow_log(mlflow.start_run)
            auto_end_run = True
        else:
            auto_end_run = False

        log_fn_args_as_params(original, [self] + list(args), kwargs,
                              unlogged_params)

        # Callbacks known to the learner; only inspected for parameter
        # logging below, never passed on.
        callbacks = [cb(self)
                     for cb in self.callback_fns] + (self.callbacks or [])

        # Checking if the 'callback' argument of the function is set.
        # A NEW list is always built for the call: appending in place would
        # mutate the list owned by the caller (accumulating one MLflow
        # callback per fit() call) and would fail when callbacks is None.
        if len(args) > callback_arg_index:
            tmp_list = list(args)
            user_callbacks = list(tmp_list[callback_arg_index] or [])
            tmp_list[callback_arg_index] = user_callbacks + \
                [__MLflowFastaiCallback(self)]
            args = tuple(tmp_list)
        else:
            user_callbacks = list(kwargs.get("callbacks") or [])
            kwargs["callbacks"] = user_callbacks + \
                [__MLflowFastaiCallback(self)]
        callbacks += user_callbacks

        early_stop_callback = _find_callback_of_type(EarlyStoppingCallback,
                                                     callbacks)
        one_cycle_callback = _find_callback_of_type(OneCycleScheduler,
                                                    callbacks)

        _log_early_stop_callback_params(early_stop_callback)
        _log_one_cycle_callback_params(one_cycle_callback)

        result = original(self, *args, **kwargs)

        if auto_end_run:
            try_mlflow_log(mlflow.end_run)

        return result

    @gorilla.patch(Learner)
    def fit(self, *args, **kwargs):
        original = gorilla.get_original_attribute(Learner, "fit")
        unlogged_params = ["self", "callbacks", "learner"]
        # 3 is the positional index used to look for the callbacks argument.
        return _run_and_log_function(self, original, args, kwargs,
                                     unlogged_params, 3)

    # store_hit=True keeps the original fit retrievable through
    # gorilla.get_original_attribute, which the replacement relies on.
    settings = gorilla.Settings(allow_hit=True, store_hit=True)
    gorilla.apply(gorilla.Patch(Learner, "fit", fit, settings=settings))
Esempio n. 24
0
def autolog():
    """
    Enables automatic logging from LightGBM to MLflow. Logs the following.

    - parameters specified in `lightgbm.train`_.
    - metrics on each iteration (if ``valid_sets`` specified).
    - metrics at the best iteration (if ``early_stopping_rounds`` specified).
    - feature importance (both "split" and "gain") as JSON files and plots.
    - trained model.

    Note that the `scikit-learn API`_ is not supported.
    """
    import lightgbm
    import numpy as np

    @gorilla.patch(lightgbm)
    def train(*args, **kwargs):
        """
        Replacement for ``lightgbm.train``: calls the original function and logs its
        params, per-iteration metrics, feature importance and the trained model.

        :return: The model returned by the original ``lightgbm.train``.
        """

        def record_eval_results(eval_results):
            """
            Create a callback function that records evaluation results.

            Every invocation of the returned callback appends one
            ``{"<data_name>-<eval_name>": value}`` dict to ``eval_results``.
            """
            def callback(env):
                res = {}
                for data_name, eval_name, value, _ in env.evaluation_result_list:
                    key = data_name + '-' + eval_name
                    res[key] = value

                eval_results.append(res)
            return callback

        def log_feature_importance_plot(features, importance, importance_type):
            """
            Log feature importance plot.

            :param features: Feature names.
            :param importance: Importance values (indexed with ``np.argsort`` output).
            :param importance_type: Label used in the plot title and the file name.
            """
            import matplotlib.pyplot as plt

            indices = np.argsort(importance)
            features = np.array(features)[indices]
            importance = importance[indices]
            num_features = len(features)

            # If num_features > 10, increase the figure height to prevent the plot
            # from being too dense.
            w, h = [6.4, 4.8]  # matplotlib's default figure size
            h = h + 0.1 * num_features if num_features > 10 else h
            fig, ax = plt.subplots(figsize=(w, h))

            yloc = np.arange(num_features)
            ax.barh(yloc, importance, align='center', height=0.5)
            ax.set_yticks(yloc)
            ax.set_yticklabels(features)
            ax.set_xlabel('Importance')
            ax.set_title('Feature Importance ({})'.format(importance_type))
            fig.tight_layout()

            tmpdir = tempfile.mkdtemp()
            try:
                # Use this function's own parameter rather than the loop variable
                # of the enclosing scope.
                filepath = os.path.join(
                    tmpdir, 'feature_importance_{}.png'.format(importance_type))
                fig.savefig(filepath)
                try_mlflow_log(mlflow.log_artifact, filepath)
            finally:
                plt.close(fig)
                shutil.rmtree(tmpdir)

        if not mlflow.active_run():
            try_mlflow_log(mlflow.start_run)
            auto_end_run = True
        else:
            auto_end_run = False

        original = gorilla.get_original_attribute(lightgbm, 'train')

        # logging booster params separately via mlflow.log_params to extract key/value pairs
        # and make it easier to compare them across runs.
        params = args[0] if len(args) > 0 else kwargs['params']
        try_mlflow_log(mlflow.log_params, params)

        unlogged_params = ['params', 'train_set', 'valid_sets', 'valid_names', 'fobj', 'feval',
                           'init_model', 'evals_result', 'learning_rates', 'callbacks']

        log_fn_args_as_params(original, args, kwargs, unlogged_params)

        # ``inspect.getargspec`` no longer exists on Python 3.11+; prefer
        # ``getfullargspec`` when available. Element 0 is the positional argument
        # name list in both cases.
        get_arg_spec = (getattr(inspect, 'getfullargspec', None)
                        or inspect.getargspec)  # pylint: disable=W1505
        all_arg_names = get_arg_spec(original)[0]
        num_pos_args = len(args)

        # adding a callback that records evaluation results.
        eval_results = []
        callbacks_index = all_arg_names.index('callbacks')
        callback = record_eval_results(eval_results)
        # Build a new list rather than extending the caller's list in place, and
        # treat an explicit ``None`` the same as "no callbacks".
        if num_pos_args >= callbacks_index + 1:
            tmp_list = list(args)
            tmp_list[callbacks_index] = list(tmp_list[callbacks_index] or []) + [callback]
            args = tuple(tmp_list)
        else:
            kwargs['callbacks'] = list(kwargs.get('callbacks') or []) + [callback]

        # training model
        model = original(*args, **kwargs)

        # logging metrics on each iteration.
        for idx, metrics in enumerate(eval_results):
            try_mlflow_log(mlflow.log_metrics, metrics, step=idx)

        # If early_stopping_rounds is present, logging metrics at the best iteration
        # as extra metrics with the max step + 1.
        early_stopping_index = all_arg_names.index('early_stopping_rounds')
        early_stopping = (num_pos_args >= early_stopping_index + 1 or
                          'early_stopping_rounds' in kwargs)
        if early_stopping:
            extra_step = len(eval_results)
            try_mlflow_log(mlflow.log_metric, 'stopped_iteration', len(eval_results))
            # best_iteration is set even if training does not stop early.
            try_mlflow_log(mlflow.log_metric, 'best_iteration', model.best_iteration)
            # Nothing to replay when no evaluation results were recorded; indexing
            # an empty list here would raise after training already succeeded.
            if eval_results:
                # iteration starts from 1 in LightGBM.
                try_mlflow_log(mlflow.log_metrics, eval_results[model.best_iteration - 1],
                               step=extra_step)

        # logging feature importance as artifacts.
        features = model.feature_name()
        for imp_type in ['split', 'gain']:
            importance = model.feature_importance(importance_type=imp_type)
            try:
                log_feature_importance_plot(features, importance, imp_type)
            except Exception:  # pylint: disable=broad-except
                _logger.exception('Failed to log feature importance plot. LightGBM autologging '
                                  'will ignore the failure and continue. Exception: ')

            imp = dict(zip(features, importance.tolist()))
            tmpdir = tempfile.mkdtemp()
            try:
                filepath = os.path.join(tmpdir, 'feature_importance_{}.json'.format(imp_type))
                with open(filepath, 'w') as f:
                    json.dump(imp, f, indent=2)
                try_mlflow_log(mlflow.log_artifact, filepath)
            finally:
                shutil.rmtree(tmpdir)

        try_mlflow_log(log_model, model, artifact_path='model')

        # Only close the run if this call was the one that opened it.
        if auto_end_run:
            try_mlflow_log(mlflow.end_run)
        return model

    # ``allow_hit`` permits replacing the existing ``lightgbm.train``; ``store_hit``
    # keeps the original retrievable via ``gorilla.get_original_attribute``.
    settings = gorilla.Settings(allow_hit=True, store_hit=True)
    gorilla.apply(gorilla.Patch(lightgbm, 'train', train, settings=settings))
def autolog(every_n_iter=100):
    # pylint: disable=E0611
    """
    Enables automatic logging from TensorFlow to MLflow.
    Note that autologging for ``tf.keras`` is handled by :py:func:`mlflow.tensorflow.autolog`,
    not :py:func:`mlflow.keras.autolog`.
    As an example, try running the
    `TensorFlow examples <https://github.com/mlflow/mlflow/tree/master/examples/tensorflow>`_.

    For each TensorFlow module, autologging captures the following information:

    **tf.keras**
     - **Metrics** and **Parameters**

      - Training loss; validation loss; user-specified metrics
      - ``fit()`` or ``fit_generator()`` parameters; optimizer name; learning rate; epsilon

     - **Artifacts**

      - Model summary on training start
      - `MLflow Model <https://mlflow.org/docs/latest/models.html>`_ (Keras model)
      - TensorBoard logs on training end

    **tf.keras.callbacks.EarlyStopping**
     - **Metrics** and **Parameters**

      - Metrics from the ``EarlyStopping`` callbacks: ``stopped_epoch``, ``restored_epoch``,
        ``restore_best_weight``, etc
      - ``fit()`` or ``fit_generator()`` parameters associated with ``EarlyStopping``:
        ``min_delta``, ``patience``, ``baseline``, ``restore_best_weights``, etc

    **tf.estimator**
     - **Metrics** and **Parameters**

      - TensorBoard metrics: ``average_loss``, ``loss``, etc
      - Parameters ``steps`` and ``max_steps``

     - **Artifacts**

      - `MLflow Model <https://mlflow.org/docs/latest/models.html>`_ (TF saved model) on call
        to ``tf.estimator.export_saved_model``

    **TensorFlow Core**
     - **Metrics**

      - All ``tf.summary.scalar`` calls

    Refer to the autologging tracking documentation for more
    information on `TensorFlow workflows
    <https://www.mlflow.org/docs/latest/tracking.html#tensorflow-and-keras-experimental>`_.

    :param every_n_iter: The frequency with which metrics should be logged.
                                  Defaults to 100. Ex: a value of 100 will log metrics
                                  at step 0, 100, 200, etc.
    """
    import tensorflow

    global _LOG_EVERY_N_STEPS
    _LOG_EVERY_N_STEPS = every_n_iter

    # The two literals were previously joined without a separating space,
    # producing "versions1.12" in the emitted warning.
    unsupported_version_msg = ("Could not log to MLflow. Only TensorFlow versions "
                               "1.12 <= v <= 2.0.0 are supported.")

    if LooseVersion(tensorflow.__version__) < LooseVersion("1.12"):
        warnings.warn(unsupported_version_msg)
        return

    try:
        from tensorflow.python.summary.writer.event_file_writer import EventFileWriter
        from tensorflow.python.summary.writer.event_file_writer_v2 import EventFileWriterV2
        from tensorflow.python.saved_model import tag_constants
        from tensorflow.python.summary.writer.writer import FileWriter
    except ImportError:
        warnings.warn(unsupported_version_msg)
        return

    @contextmanager
    def _manage_active_run():
        """
        Yield the active MLflow run, starting one first if none is active.

        A run started here has its id stored in ``_AUTOLOG_RUN_ID``; on normal exit
        the active run is ended only when its id matches ``_AUTOLOG_RUN_ID``.
        """
        global _AUTOLOG_RUN_ID
        if not mlflow.active_run():
            try_mlflow_log(mlflow.start_run)
            # defensive check in case `mlflow.start_run` fails
            if mlflow.active_run() is not None:
                _AUTOLOG_RUN_ID = mlflow.active_run().info.run_id
        yield mlflow.active_run()
        if (mlflow.active_run() is not None
                and mlflow.active_run().info.run_id == _AUTOLOG_RUN_ID):
            try_mlflow_log(mlflow.end_run)

    @gorilla.patch(tensorflow.estimator.Estimator)
    def train(self, *args, **kwargs):
        """Replacement for ``Estimator.train`` that logs ``steps`` / ``max_steps`` as params."""
        with _manage_active_run():
            original = gorilla.get_original_attribute(
                tensorflow.estimator.Estimator, "train")

            # Checking step and max_step parameters for logging
            if len(args) >= 3:
                try_mlflow_log(mlflow.log_param, "steps", args[2])
                if len(args) >= 4:
                    try_mlflow_log(mlflow.log_param, "max_steps", args[3])
            if "steps" in kwargs:
                try_mlflow_log(mlflow.log_param, "steps", kwargs["steps"])
            if "max_steps" in kwargs:
                try_mlflow_log(mlflow.log_param, "max_steps",
                               kwargs["max_steps"])

            result = original(self, *args, **kwargs)

            return result

    def _export_and_log_model(estimator, export_fn_name, args, kwargs):
        """
        Shared body of the two patched export methods.

        Calls the original ``Estimator.<export_fn_name>`` and logs the exported
        saved model. When no run is active, the run recorded in
        ``_AUTOLOG_RUN_ID`` is resumed if set, otherwise a fresh run is started;
        the run is ended afterwards if it is the autolog run or was started here.

        :return: The value returned by the original export function.
        """
        auto_end = False
        if not mlflow.active_run():
            if _AUTOLOG_RUN_ID:
                try_mlflow_log(mlflow.start_run, _AUTOLOG_RUN_ID)
            else:
                try_mlflow_log(mlflow.start_run)
                auto_end = True

        original = gorilla.get_original_attribute(
            tensorflow.estimator.Estimator, export_fn_name)
        serialized = original(estimator, *args, **kwargs)
        try_mlflow_log(
            log_model,
            tf_saved_model_dir=serialized.decode("utf-8"),
            tf_meta_graph_tags=[tag_constants.SERVING],
            tf_signature_def_key="predict",
            artifact_path="model",
        )
        if (mlflow.active_run() is not None and mlflow.active_run().info.run_id
                == _AUTOLOG_RUN_ID) or auto_end:
            try_mlflow_log(mlflow.end_run)
        return serialized

    @gorilla.patch(tensorflow.estimator.Estimator)
    def export_saved_model(self, *args, **kwargs):
        """Replacement for ``Estimator.export_saved_model`` that also logs the model."""
        return _export_and_log_model(self, "export_saved_model", args, kwargs)

    @gorilla.patch(tensorflow.estimator.Estimator)
    def export_savedmodel(self, *args, **kwargs):
        """Replacement for ``Estimator.export_savedmodel`` that also logs the model."""
        return _export_and_log_model(self, "export_savedmodel", args, kwargs)

    def _early_stop_check(callbacks):
        """Return the first ``EarlyStopping`` instance in ``callbacks``, or ``None``."""
        for callback in callbacks:
            if isinstance(callback, tensorflow.keras.callbacks.EarlyStopping):
                return callback
        return None

    def _log_early_stop_callback_params(callback):
        """Log the early-stopping configuration of ``callback`` as params (best effort)."""
        if callback:
            try:
                earlystopping_params = {
                    "monitor": callback.monitor,
                    "min_delta": callback.min_delta,
                    "patience": callback.patience,
                    "baseline": callback.baseline,
                    "restore_best_weights": callback.restore_best_weights,
                }
                try_mlflow_log(mlflow.log_params, earlystopping_params)
            except Exception:  # pylint: disable=W0703
                return

    def _get_early_stop_callback_attrs(callback):
        """Return ``(stopped_epoch, restore_best_weights, patience)`` or ``None`` on failure."""
        try:
            return callback.stopped_epoch, callback.restore_best_weights, callback.patience
        except Exception:  # pylint: disable=W0703
            return None

    def _log_early_stop_callback_metrics(callback, history):
        """
        Log ``stopped_epoch`` and, when weights were restored, ``restored_epoch`` plus
        the metrics of the restored epoch as one extra step after the last epoch.
        """
        if callback:
            callback_attrs = _get_early_stop_callback_attrs(callback)
            if callback_attrs is None:
                return
            stopped_epoch, restore_best_weights, patience = callback_attrs
            try_mlflow_log(mlflow.log_metric, "stopped_epoch", stopped_epoch)
            # Weights are restored only if early stopping occurs
            if stopped_epoch != 0 and restore_best_weights:
                restored_epoch = stopped_epoch - max(1, patience)
                try_mlflow_log(mlflow.log_metric, "restored_epoch",
                               restored_epoch)
                restored_metrics = {
                    key: history.history[key][restored_epoch]
                    for key in history.history.keys()
                }
                # Metrics are logged as 'epoch_loss' and 'epoch_acc' in TF 1.X
                if LooseVersion(
                        tensorflow.__version__) < LooseVersion("2.0.0"):
                    if "loss" in restored_metrics:
                        restored_metrics["epoch_loss"] = restored_metrics.pop(
                            "loss")
                    if "acc" in restored_metrics:
                        restored_metrics["epoch_acc"] = restored_metrics.pop(
                            "acc")
                # Checking that a metric history exists
                metric_key = next(iter(history.history), None)
                if metric_key is not None:
                    last_epoch = len(history.history[metric_key])
                    try_mlflow_log(mlflow.log_metrics,
                                   restored_metrics,
                                   step=last_epoch)

    @gorilla.patch(tensorflow.keras.Model)
    def fit(self, *args, **kwargs):
        """
        Replacement for ``tf.keras.Model.fit`` that logs params, early-stopping
        information and the TensorBoard log directory to MLflow.
        """
        with _manage_active_run():
            original = gorilla.get_original_attribute(tensorflow.keras.Model,
                                                      "fit")

            unlogged_params = [
                "self", "x", "y", "callbacks", "validation_data", "verbose"
            ]

            log_fn_args_as_params(original, args, kwargs, unlogged_params)

            # Checking if the 'callback' argument of fit() is set. An explicit
            # ``None`` is treated like an absent argument (empty list), since
            # iterating ``None`` in ``_early_stop_check`` would raise.
            if len(args) >= 6:
                tmp_list = list(args)
                user_callbacks = tmp_list[5] or []
                early_stop_callback = _early_stop_check(user_callbacks)
                tmp_list[5], log_dir = _setup_callbacks(user_callbacks)
                args = tuple(tmp_list)
            else:
                user_callbacks = kwargs.get("callbacks") or []
                early_stop_callback = _early_stop_check(user_callbacks)
                kwargs["callbacks"], log_dir = _setup_callbacks(user_callbacks)

            _log_early_stop_callback_params(early_stop_callback)

            history = original(self, *args, **kwargs)

            _log_early_stop_callback_metrics(early_stop_callback, history)

            _flush_queue()
            _log_artifacts_with_warning(local_dir=log_dir.location,
                                        artifact_path="tensorboard_logs")
            if log_dir.is_temp:
                shutil.rmtree(log_dir.location)

            return history

    @gorilla.patch(tensorflow.keras.Model)
    def fit_generator(self, *args, **kwargs):
        """
        Replacement for ``tf.keras.Model.fit_generator`` that logs params and the
        TensorBoard log directory to MLflow.
        """
        with _manage_active_run():
            original = gorilla.get_original_attribute(tensorflow.keras.Model,
                                                      "fit_generator")

            unlogged_params = [
                "self", "generator", "callbacks", "validation_data", "verbose"
            ]

            log_fn_args_as_params(original, args, kwargs, unlogged_params)

            # Checking if the 'callback' argument of fit_generator() is set; an
            # explicit ``None`` is handled like an absent argument (empty list).
            if len(args) >= 5:
                tmp_list = list(args)
                tmp_list[4], log_dir = _setup_callbacks(tmp_list[4] or [])
                args = tuple(tmp_list)
            else:
                kwargs["callbacks"], log_dir = _setup_callbacks(
                    kwargs.get("callbacks") or [])
            result = original(self, *args, **kwargs)
            _flush_queue()
            _log_artifacts_with_warning(local_dir=log_dir.location,
                                        artifact_path="tensorboard_logs")
            if log_dir.is_temp:
                shutil.rmtree(log_dir.location)

            return result

    @gorilla.patch(EventFileWriter)
    def add_event(self, event):
        """Pass ``event`` to ``_log_event`` and then to the original ``add_event``."""
        _log_event(event)
        original = gorilla.get_original_attribute(EventFileWriter, "add_event")
        return original(self, event)

    @gorilla.patch(FileWriter)
    def add_summary(self, *args, **kwargs):
        """Call the original ``add_summary`` and then ``_flush_queue``."""
        original = gorilla.get_original_attribute(FileWriter, "add_summary")
        result = original(self, *args, **kwargs)
        _flush_queue()
        return result

    # ``allow_hit`` permits replacing existing attributes; ``store_hit`` keeps the
    # originals retrievable via ``gorilla.get_original_attribute``.
    settings = gorilla.Settings(allow_hit=True, store_hit=True)
    patches = [
        gorilla.Patch(EventFileWriter,
                      "add_event",
                      add_event,
                      settings=settings),
        gorilla.Patch(EventFileWriterV2,
                      "add_event",
                      add_event,
                      settings=settings),
        gorilla.Patch(tensorflow.estimator.Estimator,
                      "train",
                      train,
                      settings=settings),
        gorilla.Patch(tensorflow.keras.Model, "fit", fit, settings=settings),
        gorilla.Patch(tensorflow.keras.Model,
                      "fit_generator",
                      fit_generator,
                      settings=settings),
        gorilla.Patch(
            tensorflow.estimator.Estimator,
            "export_saved_model",
            export_saved_model,
            settings=settings,
        ),
        gorilla.Patch(
            tensorflow.estimator.Estimator,
            "export_savedmodel",
            export_savedmodel,
            settings=settings,
        ),
        gorilla.Patch(FileWriter,
                      "add_summary",
                      add_summary,
                      settings=settings),
    ]

    for x in patches:
        gorilla.apply(x)
Esempio n. 26
0
def autolog():
    # pylint: disable=E0611
    """
    Enables automatic logging from Keras to MLflow. Autologging captures the following information:

    **Metrics** and **Parameters**
     - Training loss; validation loss; user-specified metrics
     - Metrics associated with the ``EarlyStopping`` callbacks: ``stopped_epoch``,
       ``restored_epoch``, ``restore_best_weight``, ``last_epoch``, etc
     - ``fit()`` or ``fit_generator()`` parameters; optimizer name; learning rate; epsilon
     - ``fit()`` or ``fit_generator()`` parameters associated with ``EarlyStopping``: ``min_delta``,
       ``patience``, ``baseline``, ``restore_best_weights``, etc
    **Artifacts**
     - Model summary on training start
     - `MLflow Model <https://mlflow.org/docs/latest/models.html>`_ (Keras model) on training end

    .. code-block:: python
        :caption: Example

        import mlflow
        import mlflow.keras
        # Build, compile, enable autologging, and train your model
        keras_model = ...
        keras_model.compile(optimizer="rmsprop", loss="mse", metrics=["accuracy"])
        # autolog your metrics, parameters, and model
        mlflow.keras.autolog()
        results = keras_model.fit(
            x_train, y_train, epochs=20, batch_size=128, validation_data=(x_val, y_val))

    ``EarlyStopping Integration with Keras AutoLogging``

    MLflow will detect if an ``EarlyStopping`` callback is used in a ``fit()`` or
    ``fit_generator()`` call, and if the ``restore_best_weights`` parameter is set to be ``True``,
    then MLflow will log the metrics associated with the restored model as a final, extra step.
    The epoch of the restored model will also be logged as the metric ``restored_epoch``.
    This allows for easy comparison between the actual metrics of the restored model and
    the metrics of other models.

    If ``restore_best_weights`` is set to be ``False``, then MLflow will not log an additional step.

    Regardless of ``restore_best_weights``, MLflow will also log ``stopped_epoch``,
    which indicates the epoch at which training stopped due to early stopping.

    If training does not end due to early stopping, then ``stopped_epoch`` will be logged as ``0``.

    MLflow will also log the parameters of the ``EarlyStopping`` callback,
    excluding ``mode`` and ``verbose``.
    """
    import keras

    class __MLflowKerasCallback(keras.callbacks.Callback):
        """
        Callback for auto-logging metrics and parameters.
        Records available logs after each epoch.
        Records model structural information as params when training begins
        """
        def on_train_begin(self, logs=None):  # pylint: disable=unused-argument
            try_mlflow_log(mlflow.log_param, 'num_layers',
                           len(self.model.layers))
            try_mlflow_log(mlflow.log_param, 'optimizer_name',
                           type(self.model.optimizer).__name__)
            # ``lr`` / ``epsilon`` are logged directly when they are plain floats and
            # evaluated through the Keras backend otherwise.
            if hasattr(self.model.optimizer, 'lr'):
                lr = self.model.optimizer.lr if \
                    type(self.model.optimizer.lr) is float \
                    else keras.backend.eval(self.model.optimizer.lr)
                try_mlflow_log(mlflow.log_param, 'learning_rate', lr)
            if hasattr(self.model.optimizer, 'epsilon'):
                epsilon = self.model.optimizer.epsilon if \
                    type(self.model.optimizer.epsilon) is float \
                    else keras.backend.eval(self.model.optimizer.epsilon)
                try_mlflow_log(mlflow.log_param, 'epsilon', epsilon)

            # Capture the model summary line by line and log it as a text artifact.
            sum_list = []
            self.model.summary(print_fn=sum_list.append)
            summary = '\n'.join(sum_list)
            tempdir = tempfile.mkdtemp()
            try:
                summary_file = os.path.join(tempdir, "model_summary.txt")
                with open(summary_file, 'w') as f:
                    f.write(summary)
                try_mlflow_log(mlflow.log_artifact, local_path=summary_file)
            finally:
                shutil.rmtree(tempdir)

        def on_epoch_end(self, epoch, logs=None):
            if not logs:
                return
            try_mlflow_log(mlflow.log_metrics, logs, step=epoch)

        def on_train_end(self, logs=None):
            try_mlflow_log(log_model, self.model, artifact_path='model')

        # As of Keras 2.4.0, Keras Callback implementations must define the following
        # methods indicating whether or not the callback overrides functions for
        # batch training/testing/inference
        def _implements_train_batch_hooks(self):
            return False

        def _implements_test_batch_hooks(self):
            return False

        def _implements_predict_batch_hooks(self):
            return False

    def _early_stop_check(callbacks):
        """Return the first ``EarlyStopping`` instance in ``callbacks``, or ``None``."""
        if LooseVersion(keras.__version__) < LooseVersion('2.3.0'):
            es_callback = keras.callbacks.EarlyStopping
        else:
            es_callback = keras.callbacks.callbacks.EarlyStopping
        for callback in callbacks:
            if isinstance(callback, es_callback):
                return callback
        return None

    def _log_early_stop_callback_params(callback):
        """Log the early-stopping configuration of ``callback`` as params (best effort)."""
        if callback:
            try:
                earlystopping_params = {
                    'monitor': callback.monitor,
                    'min_delta': callback.min_delta,
                    'patience': callback.patience,
                    'baseline': callback.baseline,
                    'restore_best_weights': callback.restore_best_weights
                }
                try_mlflow_log(mlflow.log_params, earlystopping_params)
            except Exception:  # pylint: disable=W0703
                return

    def _get_early_stop_callback_attrs(callback):
        """Return ``(stopped_epoch, restore_best_weights, patience)`` or ``None`` on failure."""
        try:
            return callback.stopped_epoch, callback.restore_best_weights, callback.patience
        except Exception:  # pylint: disable=W0703
            return None

    def _log_early_stop_callback_metrics(callback, history):
        """
        Log ``stopped_epoch`` and, when weights were restored, ``restored_epoch`` plus
        the metrics of the restored epoch as one extra step after the last epoch.
        """
        if callback:
            callback_attrs = _get_early_stop_callback_attrs(callback)
            if callback_attrs is None:
                return
            stopped_epoch, restore_best_weights, patience = callback_attrs
            try_mlflow_log(mlflow.log_metric, 'stopped_epoch', stopped_epoch)
            # Weights are restored only if early stopping occurs
            if stopped_epoch != 0 and restore_best_weights:
                restored_epoch = stopped_epoch - max(1, patience)
                try_mlflow_log(mlflow.log_metric, 'restored_epoch',
                               restored_epoch)
                restored_metrics = {
                    key: history.history[key][restored_epoch]
                    for key in history.history.keys()
                }
                # Checking that a metric history exists
                metric_key = next(iter(history.history), None)
                if metric_key is not None:
                    last_epoch = len(history.history[metric_key])
                    try_mlflow_log(mlflow.log_metrics,
                                   restored_metrics,
                                   step=last_epoch)

    def _run_and_log_function(self, original, args, kwargs, unlogged_params,
                              callback_arg_index):
        """
        Run ``original`` on the model ``self`` with MLflow logging wired in.

        Starts an MLflow run when none is active (and ends it after ``original``
        returns), logs the call arguments, passes ``original`` a callbacks list
        extended with a ``__MLflowKerasCallback`` and logs early-stopping params
        and metrics when an ``EarlyStopping`` callback was supplied.

        :param self: The model the patched function was called on.
        :param original: The unpatched function to invoke.
        :param args: Positional arguments of the call, excluding ``self``.
        :param kwargs: Keyword arguments of the call; its ``'callbacks'`` entry is replaced.
        :param unlogged_params: Passed through to ``log_fn_args_as_params``
                                (presumably the argument names to leave out).
        :param callback_arg_index: Index of the callbacks argument within ``args``.
        :return: Whatever ``original`` returns.
        """
        if not mlflow.active_run():
            try_mlflow_log(mlflow.start_run)
            auto_end_run = True
        else:
            auto_end_run = False

        # NOTE(review): ``args`` excludes ``self`` while ``original`` is looked up on
        # the class - confirm ``log_fn_args_as_params`` pairs names and values correctly.
        log_fn_args_as_params(original, args, kwargs, unlogged_params)

        mlflow_callback = __MLflowKerasCallback()

        # Build a *new* callbacks list instead of using ``+=`` on the caller's
        # object: in-place extension would add another MLflow callback to a list
        # the caller may reuse across fits, and fails for ``None`` or a tuple.
        if len(args) > callback_arg_index:
            tmp_list = list(args)
            user_callbacks = list(tmp_list[callback_arg_index] or [])
            tmp_list[callback_arg_index] = user_callbacks + [mlflow_callback]
            args = tuple(tmp_list)
        else:
            user_callbacks = list(kwargs.get('callbacks') or [])
            kwargs['callbacks'] = user_callbacks + [mlflow_callback]

        early_stop_callback = _early_stop_check(user_callbacks)
        _log_early_stop_callback_params(early_stop_callback)

        history = original(self, *args, **kwargs)

        _log_early_stop_callback_metrics(early_stop_callback, history)

        # Only close the run if this call was the one that opened it.
        if auto_end_run:
            try_mlflow_log(mlflow.end_run)

        return history

    @gorilla.patch(keras.Model)
    def fit(self, *args, **kwargs):
        """Replacement for ``keras.Model.fit`` routed through ``_run_and_log_function``."""
        original = gorilla.get_original_attribute(keras.Model, 'fit')
        unlogged_params = [
            'self', 'x', 'y', 'callbacks', 'validation_data', 'verbose'
        ]
        return _run_and_log_function(self, original, args, kwargs,
                                     unlogged_params, 5)

    @gorilla.patch(keras.Model)
    def fit_generator(self, *args, **kwargs):
        """Replacement for ``keras.Model.fit_generator`` routed through ``_run_and_log_function``."""
        original = gorilla.get_original_attribute(keras.Model, 'fit_generator')
        unlogged_params = [
            'self', 'generator', 'callbacks', 'validation_data', 'verbose'
        ]
        return _run_and_log_function(self, original, args, kwargs,
                                     unlogged_params, 4)

    # ``allow_hit`` permits replacing existing attributes; ``store_hit`` keeps the
    # originals retrievable via ``gorilla.get_original_attribute``.
    settings = gorilla.Settings(allow_hit=True, store_hit=True)
    gorilla.apply(gorilla.Patch(keras.Model, 'fit', fit, settings=settings))
    gorilla.apply(
        gorilla.Patch(keras.Model,
                      'fit_generator',
                      fit_generator,
                      settings=settings))
Esempio n. 27
0
from splicemachine.mlflow_support import *
from splicemachine.mlflow_support.mlflow_support import _GORILLA_SETTINGS
import gorilla
import mlflow.pyfunc


def _log_model(name='pyfunc_model', **flavor_options):
    """
    Forward a pyfunc model to ``mlflow.log_model`` with ``model_lib='pyfunc'``.

    :param name: Name passed on to ``mlflow.log_model``.
    :param flavor_options: Extra keyword options. ``python_model``, when present, is
                           removed and passed as the model (``None`` otherwise); the
                           remaining options are forwarded unchanged.
    """
    python_model = flavor_options.pop('python_model', None)
    mlflow.log_model(python_model,
                     name=name,
                     model_lib='pyfunc',
                     **flavor_options)


# Attach ``_log_model`` to ``mlflow.pyfunc``; the attribute name is the function's
# own name with its leading underscores stripped.
gorilla.apply(
    gorilla.Patch(
        mlflow.pyfunc,
        _log_model.__name__.lstrip('_'),
        _log_model,
        settings=_GORILLA_SETTINGS,
    )
)
Esempio n. 28
0
def autolog(every_n_iter=100):
    # pylint: disable=E0611
    """
    Enable automatic logging from TensorFlow to MLflow. If applicable,
    model checkpoints are logged as artifacts to a 'models' directory, along
    with any TensorBoard log data.

    Refer to the tracking documentation for
    information on what is logged with different TensorFlow workflows.

    If the installed TensorFlow is older than 1.12, or the internal summary
    writer modules cannot be imported, a warning is emitted and nothing is
    patched.

    :param every_n_iter: The frequency with which metrics should be logged.
                         Defaults to 100. Ex: a value of 100 will log metrics
                         at step 0, 100, 200, etc.

    """
    global _LOG_EVERY_N_STEPS
    _LOG_EVERY_N_STEPS = every_n_iter

    # Shared by both bail-out paths. The space after "versions" matters: the
    # two fragments used to be glued together as "versions1.12".
    unsupported_msg = ("Could not log to MLflow. Only TensorFlow versions "
                       "1.12 <= v <= 2.0.0 are supported.")

    if LooseVersion(tensorflow.__version__) < LooseVersion('1.12'):
        warnings.warn(unsupported_msg)
        return

    try:
        # These names are used by the patch functions defined further down.
        from tensorflow.python.summary.writer.event_file_writer import EventFileWriter
        from tensorflow.python.summary.writer.event_file_writer_v2 import EventFileWriterV2
        from tensorflow.python.saved_model import tag_constants
        from tensorflow.python.summary.writer.writer import FileWriter
    except ImportError:
        warnings.warn(unsupported_msg)
        return

    @contextmanager
    def _manage_active_run():
        """Context manager that makes sure an MLflow run is active for the body.

        When no run is active one is started and its id is remembered in the
        module-level ``_AUTOLOG_RUN_ID``. On exit the active run is ended only
        if its id equals ``_AUTOLOG_RUN_ID``.

        The cleanup runs in a ``finally`` clause so that a run is also closed
        when the body raises; in that case it is ended with status ``FAILED``
        and the exception is re-raised unchanged.

        :return: Yields the value of ``mlflow.active_run()`` at entry.
        """
        global _AUTOLOG_RUN_ID
        if not mlflow.active_run():
            try_mlflow_log(mlflow.start_run)
            started_run = mlflow.active_run()
            # defensive check in case `mlflow.start_run` fails
            if started_run is not None:
                _AUTOLOG_RUN_ID = started_run.info.run_id

        # Extra positional arguments for ``mlflow.end_run``; empty on success
        # so the default terminal status is used.
        end_run_args = ()
        try:
            yield mlflow.active_run()
        except BaseException:
            end_run_args = ('FAILED',)
            raise
        finally:
            current_run = mlflow.active_run()
            if (current_run is not None
                    and current_run.info.run_id == _AUTOLOG_RUN_ID):
                try_mlflow_log(mlflow.end_run, *end_run_args)

    @gorilla.patch(tensorflow.estimator.Estimator)
    def train(self, *args, **kwargs):
        """Replacement for ``Estimator.train`` that logs step limits as params.

        ``steps`` and ``max_steps`` are logged when supplied positionally
        (third and fourth positional argument) and/or by keyword, then the
        original ``train`` is invoked inside ``_manage_active_run()``.

        :return: The value returned by the original ``train``.
        """
        with _manage_active_run():
            estimator_train = gorilla.get_original_attribute(
                tensorflow.estimator.Estimator, 'train')

            # Positional form first: index 2 -> 'steps', index 3 -> 'max_steps'.
            for position, param_name in ((2, 'steps'), (3, 'max_steps')):
                if len(args) > position:
                    try_mlflow_log(mlflow.log_param, param_name,
                                   args[position])

            # Keyword form second, in the same parameter order.
            for param_name in ('steps', 'max_steps'):
                if param_name in kwargs:
                    try_mlflow_log(mlflow.log_param, param_name,
                                   kwargs[param_name])

            return estimator_train(self, *args, **kwargs)

    @gorilla.patch(tensorflow.estimator.Estimator)
    def export_saved_model(self, *args, **kwargs):
        """Replacement for ``Estimator.export_saved_model`` that also logs the model.

        If no run is active, the run identified by ``_AUTOLOG_RUN_ID`` is
        resumed when that id is set; otherwise a new run is started and is
        ended again before returning. After the original export, the decoded
        export path is passed to ``log_model``. The active run is ended when
        its id equals ``_AUTOLOG_RUN_ID`` or when this call started it.

        :return: The value returned by the original ``export_saved_model``.
        """
        global _AUTOLOG_RUN_ID
        started_new_run = False
        if not mlflow.active_run():
            if _AUTOLOG_RUN_ID:
                try_mlflow_log(mlflow.start_run, _AUTOLOG_RUN_ID)
            else:
                try_mlflow_log(mlflow.start_run)
                started_new_run = True

        export_fn = gorilla.get_original_attribute(
            tensorflow.estimator.Estimator, 'export_saved_model')
        export_path = export_fn(self, *args, **kwargs)

        try_mlflow_log(log_model,
                       tf_saved_model_dir=export_path.decode('utf-8'),
                       tf_meta_graph_tags=[tag_constants.SERVING],
                       tf_signature_def_key='predict',
                       artifact_path='model')

        current_run = mlflow.active_run()
        is_autolog_run = (current_run is not None
                          and current_run.info.run_id == _AUTOLOG_RUN_ID)
        if is_autolog_run or started_new_run:
            try_mlflow_log(mlflow.end_run)
        return export_path

    @gorilla.patch(tensorflow.estimator.Estimator)
    def export_savedmodel(self, *args, **kwargs):
        """Replacement for ``Estimator.export_savedmodel`` that also logs the model.

        Mirrors the ``export_saved_model`` replacement: a run is resumed from
        ``_AUTOLOG_RUN_ID`` or freshly started when none is active, the
        original export is performed, the decoded export path is passed to
        ``log_model``, and the run is ended when its id equals
        ``_AUTOLOG_RUN_ID`` or when this call started it.

        :return: The value returned by the original ``export_savedmodel``.
        """
        global _AUTOLOG_RUN_ID
        opened_here = False
        if not mlflow.active_run():
            if not _AUTOLOG_RUN_ID:
                try_mlflow_log(mlflow.start_run)
                opened_here = True
            else:
                try_mlflow_log(mlflow.start_run, _AUTOLOG_RUN_ID)

        legacy_export = gorilla.get_original_attribute(
            tensorflow.estimator.Estimator, 'export_savedmodel')
        saved_path = legacy_export(self, *args, **kwargs)

        try_mlflow_log(log_model,
                       tf_saved_model_dir=saved_path.decode('utf-8'),
                       tf_meta_graph_tags=[tag_constants.SERVING],
                       tf_signature_def_key='predict',
                       artifact_path='model')

        run = mlflow.active_run()
        belongs_to_autolog = (run is not None
                              and run.info.run_id == _AUTOLOG_RUN_ID)
        if belongs_to_autolog or opened_here:
            try_mlflow_log(mlflow.end_run)
        return saved_path

    @gorilla.patch(tensorflow.keras.Model)
    def fit(self, *args, **kwargs):
        """Replacement for ``tensorflow.keras.Model.fit`` with TensorBoard log capture.

        The ``callbacks`` argument (sixth positional argument, or the
        ``callbacks`` keyword, or an empty list when absent) is passed through
        ``_setup_callbacks``, which returns the callbacks to use and a log
        directory. After the original ``fit`` returns, the queue is flushed,
        the log directory is logged under ``tensorboard_logs`` and removed.

        :return: The value returned by the original ``fit``.
        """
        callbacks_index = 5  # positional slot checked for ``callbacks``
        with _manage_active_run():
            keras_fit = gorilla.get_original_attribute(tensorflow.keras.Model,
                                                       'fit')

            if len(args) > callbacks_index:
                # Tuples are immutable: swap the callbacks via a list copy.
                positional = list(args)
                positional[callbacks_index], log_dir = _setup_callbacks(
                    positional[callbacks_index])
                args = tuple(positional)
            else:
                kwargs['callbacks'], log_dir = _setup_callbacks(
                    kwargs.get('callbacks', []))

            history = keras_fit(self, *args, **kwargs)
            _flush_queue()
            _log_artifacts_with_warning(local_dir=log_dir,
                                        artifact_path='tensorboard_logs')
            # NOTE(review): the directory is removed unconditionally; if
            # ``_setup_callbacks`` can return a user-supplied directory this
            # deletes it -- confirm against ``_setup_callbacks``.
            shutil.rmtree(log_dir)

            return history

    @gorilla.patch(tensorflow.keras.Model)
    def fit_generator(self, *args, **kwargs):
        """Replacement for ``tensorflow.keras.Model.fit_generator`` with log capture.

        The ``callbacks`` argument (fifth positional argument, or the
        ``callbacks`` keyword, or an empty list when absent) is passed through
        ``_setup_callbacks``, which returns the callbacks to use and a log
        directory. After the original ``fit_generator`` returns, the queue is
        flushed, the log directory is logged under ``tensorboard_logs`` and
        removed.

        :return: The value returned by the original ``fit_generator``.
        """
        callbacks_index = 4  # positional slot checked for ``callbacks``
        with _manage_active_run():
            keras_fit_generator = gorilla.get_original_attribute(
                tensorflow.keras.Model, 'fit_generator')

            if len(args) > callbacks_index:
                # Tuples are immutable: swap the callbacks via a list copy.
                positional = list(args)
                positional[callbacks_index], log_dir = _setup_callbacks(
                    positional[callbacks_index])
                args = tuple(positional)
            else:
                kwargs['callbacks'], log_dir = _setup_callbacks(
                    kwargs.get('callbacks', []))

            history = keras_fit_generator(self, *args, **kwargs)
            _flush_queue()
            _log_artifacts_with_warning(local_dir=log_dir,
                                        artifact_path='tensorboard_logs')
            # NOTE(review): the directory is removed unconditionally; if
            # ``_setup_callbacks`` can return a user-supplied directory this
            # deletes it -- confirm against ``_setup_callbacks``.
            shutil.rmtree(log_dir)

            return history

    @gorilla.patch(EventFileWriter)
    def add_event(self, event):
        """Pass ``event`` to ``_log_event``, then to the original ``add_event``.

        :param event: The event handed to the writer.
        :return: The value returned by the original ``add_event``.
        """
        _log_event(event)
        write_event = gorilla.get_original_attribute(EventFileWriter,
                                                     'add_event')
        return write_event(self, event)

    @gorilla.patch(FileWriter)
    def add_summary(self, *args, **kwargs):
        """Call the original ``FileWriter.add_summary`` and then flush the queue.

        :return: The value returned by the original ``add_summary``.
        """
        write_summary = gorilla.get_original_attribute(FileWriter,
                                                       'add_summary')
        outcome = write_summary(self, *args, **kwargs)
        # Flush only after the original call has completed.
        _flush_queue()
        return outcome

    # (destination, attribute name, replacement) triples, applied in this
    # order with one shared settings object. ``add_event`` is installed on both
    # writer classes.
    settings = gorilla.Settings(allow_hit=True, store_hit=True)
    patch_table = (
        (EventFileWriter, 'add_event', add_event),
        (EventFileWriterV2, 'add_event', add_event),
        (tensorflow.estimator.Estimator, 'train', train),
        (tensorflow.keras.Model, 'fit', fit),
        (tensorflow.keras.Model, 'fit_generator', fit_generator),
        (tensorflow.estimator.Estimator, 'export_saved_model',
         export_saved_model),
        (tensorflow.estimator.Estimator, 'export_savedmodel',
         export_savedmodel),
        (FileWriter, 'add_summary', add_summary),
    )

    for destination, attr_name, replacement in patch_table:
        gorilla.apply(
            gorilla.Patch(destination,
                          attr_name,
                          replacement,
                          settings=settings))
Esempio n. 29
0
        cnts[status] = cnt
    best_top1_acc = 0.
    for trial in filter(lambda x: x.status == Trial.TERMINATED, self._trials):
        if not trial.last_result:
            continue
        best_top1_acc = max(best_top1_acc, trial.last_result['top1_valid'])
    print('iter',
          self._iteration,
          'top1_acc=%.3f' % best_top1_acc,
          cnts,
          end='\r')
    return original(self)


# Install ``step_w_log`` as ``TrialRunner.step``.
patch = gorilla.Patch(
    destination=ray.tune.trial_runner.TrialRunner,
    name='step',
    obj=step_w_log,
    settings=gorilla.Settings(allow_hit=True),
)
gorilla.apply(patch)

# Module-level logger.
logger = get_logger('Fast AutoAugment')


def _get_path(dataset, model, tag):
    """Build the checkpoint path for a dataset/model/tag combination.

    The path is rooted at the directory containing this file (symlinks
    resolved), inside ``FastAutoAugment/models``.

    :param dataset: Value formatted into the first slot of the file name.
    :param model: Value formatted into the second slot of the file name.
    :param tag: Value formatted into the third slot of the file name.
    :return: ``<this dir>/FastAutoAugment/models/<dataset>_<model>_<tag>.pt``.
    """
    module_dir = os.path.dirname(os.path.realpath(__file__))
    relative_name = 'models/%s_%s_%s.pt' % (dataset, model, tag)  # TODO
    return os.path.join(module_dir, 'FastAutoAugment', relative_name)


# @ray.remote(num_gpus=4, max_calls=1) #TODO: change to num_gpus=1 ???
# @ray.remote
def train_model(config,
Esempio n. 30
0
def autolog():
    """
    Enables autologging for scikit-learn estimators.

    **When is autologging performed?**
      Autologging is performed when you call:

      - ``estimator.fit()``
      - ``estimator.fit_predict()``
      - ``estimator.fit_transform()``

    **Logged information**
      **Parameters**
        - Parameters obtained by ``estimator.get_params(deep=True)``. Note that ``get_params``
          is called with ``deep=True``. This means when you fit a meta estimator that chains
          a series of estimators, the parameters of these child estimators are also logged.

      **Metrics**
        - A training score obtained by ``estimator.score``. Note that the training score is
          computed using parameters given to ``fit()``.
        - Common metrics for classifier:

          - `precision score`_

          .. _precision score:
              https://scikit-learn.org/stable/modules/generated/sklearn.metrics.precision_score.html

          - `recall score`_

          .. _recall score:
              https://scikit-learn.org/stable/modules/generated/sklearn.metrics.recall_score.html

          - `f1 score`_

          .. _f1 score:
              https://scikit-learn.org/stable/modules/generated/sklearn.metrics.f1_score.html

          - `accuracy score`_

          .. _accuracy score:
              https://scikit-learn.org/stable/modules/generated/sklearn.metrics.accuracy_score.html

          If the classifier has method ``predict_proba``, we additionally log:

          - `log loss`_

          .. _log loss:
              https://scikit-learn.org/stable/modules/generated/sklearn.metrics.log_loss.html

          - `roc auc score`_

          .. _roc auc score:
              https://scikit-learn.org/stable/modules/generated/sklearn.metrics.roc_auc_score.html

        - Common metrics for regressor:

          - `mean squared error`_

          .. _mean squared error:
              https://scikit-learn.org/stable/modules/generated/sklearn.metrics.mean_squared_error.html

          - root mean squared error

          - `mean absolute error`_

          .. _mean absolute error:
              https://scikit-learn.org/stable/modules/generated/sklearn.metrics.mean_absolute_error.html

          - `r2 score`_

          .. _r2 score:
              https://scikit-learn.org/stable/modules/generated/sklearn.metrics.r2_score.html

      **Tags**
        - An estimator class name (e.g. "LinearRegression").
        - A fully qualified estimator class name
          (e.g. "sklearn.linear_model._base.LinearRegression").

      **Artifacts**
        - A fitted estimator (logged by :py:func:`mlflow.sklearn.log_model()`).

    **How does autologging work for meta estimators?**
      When a meta estimator (e.g. `Pipeline`_, `GridSearchCV`_) calls ``fit()``, it internally calls
      ``fit()`` on its child estimators. Autologging does NOT perform logging on these constituent
      ``fit()`` calls.

      **Parameter search**
          In addition to recording the information discussed above, autologging for parameter
          search meta estimators (`GridSearchCV`_ and `RandomizedSearchCV`_) records child runs
          with metrics for each set of explored parameters, as well as artifacts and parameters
          for the best model (if available).

    **Supported estimators**
      - All estimators obtained by `sklearn.utils.all_estimators`_ (including meta estimators).
      - `Pipeline`_
      - Parameter search estimators (`GridSearchCV`_ and `RandomizedSearchCV`_)

    .. _sklearn.utils.all_estimators:
        https://scikit-learn.org/stable/modules/generated/sklearn.utils.all_estimators.html

    .. _Pipeline:
        https://scikit-learn.org/stable/modules/generated/sklearn.pipeline.Pipeline.html

    .. _GridSearchCV:
        https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.GridSearchCV.html

    .. _RandomizedSearchCV:
        https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.RandomizedSearchCV.html

    **Example**

    `See more examples <https://github.com/mlflow/mlflow/blob/master/examples/sklearn_autolog>`_

    .. code-block:: python

        from pprint import pprint
        import numpy as np
        from sklearn.linear_model import LinearRegression
        import mlflow

        def fetch_logged_data(run_id):
            client = mlflow.tracking.MlflowClient()
            data = client.get_run(run_id).data
            tags = {k: v for k, v in data.tags.items() if not k.startswith("mlflow.")}
            artifacts = [f.path for f in client.list_artifacts(run_id, "model")]
            return data.params, data.metrics, tags, artifacts

        # enable autologging
        mlflow.sklearn.autolog()

        # prepare training data
        X = np.array([[1, 1], [1, 2], [2, 2], [2, 3]])
        y = np.dot(X, np.array([1, 2])) + 3

        # train a model
        model = LinearRegression()
        with mlflow.start_run() as run:
            model.fit(X, y)

        # fetch logged data
        params, metrics, tags, artifacts = fetch_logged_data(run.info.run_id)

        pprint(params)
        # {'copy_X': 'True',
        #  'fit_intercept': 'True',
        #  'n_jobs': 'None',
        #  'normalize': 'False'}

        pprint(metrics)
        # {'training_score': 1.0,
        #  'training_mae': 2.220446049250313e-16,
        #  'training_mse': 1.9721522630525295e-31,
        #  'training_r2_score': 1.0,
        #  'training_rmse': 4.440892098500626e-16}

        pprint(tags)
        # {'estimator_class': 'sklearn.linear_model._base.LinearRegression',
        #  'estimator_name': 'LinearRegression'}

        pprint(artifacts)
        # ['model/MLmodel', 'model/conda.yaml', 'model/model.pkl']
    """
    # Function-scope imports: the names below are used by the helpers defined
    # later in this function.
    import pandas as pd
    import sklearn

    from mlflow.models import infer_signature
    from mlflow.sklearn.utils import (
        _MIN_SKLEARN_VERSION,
        _is_supported_version,
        _chunk_dict,
        _get_args_for_score,
        _log_specialized_estimator_content,
        _get_Xy,
        _all_estimators,
        _truncate_dict,
        _get_arg_names,
        _get_estimator_info_tags,
        _get_meta_estimators_for_autologging,
        _is_parameter_search_estimator,
        _log_parameter_search_results_as_artifact,
        _create_child_runs_for_parameter_search,
    )
    from mlflow.tracking.context import registry as context_registry
    from mlflow.utils.validation import (
        MAX_PARAMS_TAGS_PER_BATCH,
        MAX_PARAM_VAL_LENGTH,
        MAX_ENTITY_KEY_LENGTH,
    )

    # Unsupported versions only trigger a warning; execution continues.
    if not _is_supported_version():
        version_warning = (
            "Autologging utilities may not work properly on scikit-learn < {} "
            "(current version: {})"
        ).format(_MIN_SKLEARN_VERSION, sklearn.__version__)
        warnings.warn(version_warning, stacklevel=2)

    def fit_mlflow(self, func_name, *args, **kwargs):
        """Run the original fitting routine with MLflow logging around it.

        A run is started when none is active. Metadata is logged before and
        after the original ``func_name`` method is invoked. A run started here
        is always ended here: normally on success, and with status ``FAILED``
        when any of the three steps (pre-training logging, training,
        post-training logging) raises. The exception is re-raised unchanged.

        :param func_name: Name of the patched method whose original
                          implementation should be called.
        :param args: Positional arguments for the original method.
        :param kwargs: Keyword arguments for the original method.
        :return: The value returned by the original method.
        """
        should_start_run = mlflow.active_run() is None
        if should_start_run:
            try_mlflow_log(mlflow.start_run)

        try:
            _log_pretraining_metadata(self, *args, **kwargs)

            original_fit = gorilla.get_original_attribute(self, func_name)
            fit_output = original_fit(*args, **kwargs)

            _log_posttraining_metadata(self, *args, **kwargs)
        except Exception:
            # Do not leave a run we opened dangling when anything above fails.
            if should_start_run:
                try_mlflow_log(mlflow.end_run,
                               RunStatus.to_string(RunStatus.FAILED))
            # Bare ``raise`` keeps the original traceback intact.
            raise

        if should_start_run:
            try_mlflow_log(mlflow.end_run)

        return fit_output

    def _log_pretraining_metadata(estimator, *args, **kwargs):  # pylint: disable=unused-argument
        """
        Log params and tags for a scikit-learn estimator before it is trained.
        Meant to be called from a patched training routine (e.g., `fit()`,
        `fit_transform()`, ...) while an MLflow run is active, since it logs
        through the fluent Tracking API.

        :param estimator: The scikit-learn estimator for which to log metadata.
        :param args: The arguments passed to the scikit-learn training routine (e.g.,
                     `fit()`, `fit_transform()`, ...). Unused here.
        :param kwargs: The keyword arguments passed to the scikit-learn training routine.
                       Unused here.
        """
        # Child-estimator parameters are included (deep=True) except for
        # parameter search estimators: there the children only seed the search,
        # so their initial, untuned parameters are not logged.
        log_deep = not _is_parameter_search_estimator(estimator)
        all_params = estimator.get_params(deep=log_deep)

        # Log in chunks to avoid hitting the log_batch API limit.
        param_chunks = _chunk_dict(all_params,
                                   chunk_size=MAX_PARAMS_TAGS_PER_BATCH)
        for param_chunk in param_chunks:
            try_mlflow_log(
                mlflow.log_params,
                _truncate_dict(param_chunk, MAX_ENTITY_KEY_LENGTH,
                               MAX_PARAM_VAL_LENGTH))

        info_tags = _get_estimator_info_tags(estimator)
        try_mlflow_log(mlflow.set_tags, info_tags)

    def _log_posttraining_metadata(estimator, *args, **kwargs):
        """
        Log metrics and artifacts for a scikit-learn estimator once training is done.
        Meant to be called from a patched training routine (e.g., `fit()`,
        `fit_transform()`, ...) while an MLflow run is active, since it logs
        through the fluent Tracking API.

        Steps, in order: training score (when the estimator has ``score``),
        specialized estimator content, the fitted model with an inferred
        signature and input example (when the estimator has ``predict``), and,
        for parameter search estimators, the best estimator, the best
        parameters, child runs and the search results artifact.

        :param estimator: The scikit-learn estimator for which to log metadata.
        :param args: The arguments passed to the scikit-learn training routine (e.g.,
                     `fit()`, `fit_transform()`, ...).
        :param kwargs: The keyword arguments passed to the scikit-learn training routine.
        """
        # --- training score -------------------------------------------------
        if hasattr(estimator, "score"):
            try:
                scoring_inputs = _get_args_for_score(estimator.score,
                                                     estimator.fit, args,
                                                     kwargs)
                training_score = estimator.score(*scoring_inputs)
            except Exception as e:  # pylint: disable=broad-except
                # A scoring failure only costs the metric; it is reported and
                # the remaining logging continues.
                _logger.warning(
                    estimator.score.__qualname__ +
                    " failed. The 'training_score' metric will not be recorded. Scoring error: "
                    + str(e))
            else:
                try_mlflow_log(mlflow.log_metric, "training_score",
                               training_score)

        # --- common metrics and artifacts (classifier, regressor) ------------
        active_run_id = mlflow.active_run().info.run_id
        _log_specialized_estimator_content(estimator, active_run_id, args,
                                           kwargs)

        # --- signature and input example --------------------------------------
        input_example = None
        signature = None
        if hasattr(estimator, "predict"):
            try:
                # The input example is the first few rows of the training data
                # supplied to the training routine (e.g., `fit()`).
                sample_rows = 5
                x_name, y_name = _get_arg_names(estimator.fit)[:2]
                training_inputs = _get_Xy(args, kwargs, x_name, y_name)[0]
                input_example = training_inputs[:sample_rows]

                predictions = estimator.predict(input_example)
                signature = infer_signature(input_example, predictions)
            except Exception as e:  # pylint: disable=broad-except
                # Drop a partially built example so that it is not logged
                # without a matching signature.
                input_example = None
                _logger.warning(
                    "Failed to infer an input example and model signature: " +
                    str(e))

        # --- fitted model -----------------------------------------------------
        try_mlflow_log(
            log_model,
            estimator,
            artifact_path="model",
            signature=signature,
            input_example=input_example,
        )

        # Everything below applies to parameter search estimators only.
        if not _is_parameter_search_estimator(estimator):
            return

        if hasattr(estimator, "best_estimator_"):
            try_mlflow_log(
                log_model,
                estimator.best_estimator_,
                artifact_path="best_estimator",
                signature=signature,
                input_example=input_example,
            )

        if hasattr(estimator, "best_params_"):
            prefixed_best_params = {}
            for name, value in estimator.best_params_.items():
                prefixed_best_params["best_{}".format(name)] = value
            try_mlflow_log(mlflow.log_params, prefixed_best_params)

        if not hasattr(estimator, "cv_results_"):
            return

        try:
            # Environment-specific tags (e.g., user and source) are resolved
            # and passed to the child runs so that lineage information is
            # consistent with the parent run.
            child_run_tags = context_registry.resolve_tags()
            _create_child_runs_for_parameter_search(
                cv_estimator=estimator,
                parent_run=mlflow.active_run(),
                child_tags=child_run_tags,
            )
        except Exception as e:  # pylint: disable=broad-except
            _logger.warning(
                "Encountered exception during creation of child runs for parameter search."
                " Child runs may be missing. Exception: {}".format(str(e)))

        try:
            search_results = pd.DataFrame.from_dict(estimator.cv_results_)
            _log_parameter_search_results_as_artifact(
                search_results,
                mlflow.active_run().info.run_id)
        except Exception as e:  # pylint: disable=broad-except
            _logger.warning(
                "Failed to log parameter search results as an artifact."
                " Exception: {}".format(str(e)))

    def patched_fit(self, func_name, *args, **kwargs):
        """
        Entry point shared by the patched fitting methods. Intended for sklearn
        model classes that define a `fit` method and inherit from
        `BaseEstimator` (thereby defining the `get_params()` method).

        A `_SklearnTrainingSession` decides whether this call is logged: if so
        the call goes through `fit_mlflow`, otherwise the original method is
        invoked directly.

        :param func_name: Name of the patched method being called.
        :return: The value returned by the fitting routine.
        """
        session = _SklearnTrainingSession(clazz=self.__class__,
                                          allow_children=False)
        with session as t:
            if not t.should_log():
                unpatched = gorilla.get_original_attribute(self, func_name)
                return unpatched(*args, **kwargs)
            return fit_mlflow(self, func_name, *args, **kwargs)

    def create_patch_func(func_name):
        """Return a method-shaped function that forwards to ``patched_fit``.

        :param func_name: Name of the method the returned function stands in
                          for; it is captured by the closure and passed on to
                          ``patched_fit`` on every call.
        :return: A function taking ``(self, *args, **kwargs)``.
        """
        # A dedicated factory gives each returned function its own
        # ``func_name`` binding.
        def f(self, *args, **kwargs):
            return patched_fit(self, func_name, *args, **kwargs)

        return f

    patch_settings = gorilla.Settings(allow_hit=True, store_hit=True)
    # Collect the estimator classes with a comprehension rather than
    # ``zip(*...)`` unpacking, so an empty ``_all_estimators()`` result yields
    # an empty set instead of raising ``ValueError``.
    estimators_to_patch = {
        estimator_class for _, estimator_class in _all_estimators()
    }
    # Ensure that relevant meta estimators (e.g. GridSearchCV, Pipeline) are selected
    # for patching if they are not already included in the output of `all_estimators()`
    estimators_to_patch.update(_get_meta_estimators_for_autologging())

    for class_def in estimators_to_patch:
        for func_name in ("fit", "fit_transform", "fit_predict"):
            if not hasattr(class_def, func_name):
                continue

            original = getattr(class_def, func_name)

            # A couple of estimators use property methods to return fitting functions,
            # rather than defining the fitting functions on the estimator class directly.
            #
            # Example: https://github.com/scikit-learn/scikit-learn/blob/0.23.2/sklearn/neighbors/_lof.py#L183  # noqa
            #
            # We currently exclude these property fitting methods from patching because
            # it's challenging to patch them correctly.
            #
            # Excluded fitting methods:
            # - sklearn.cluster._agglomerative.FeatureAgglomeration.fit_predict
            # - sklearn.neighbors._lof.LocalOutlierFactor.fit_predict
            #
            # You can list property fitting methods by inserting "print(class_def, func_name)"
            # in the if clause below.
            if isinstance(original, property):
                continue

            # TODO(harupy): Package this wrap & patch routine into a utility function so we can
            # reuse it in other autologging integrations.
            # ``functools.wraps`` preserves the original function attributes on
            # the replacement.
            patch_func = functools.wraps(original)(
                create_patch_func(func_name))
            patch = gorilla.Patch(class_def,
                                  func_name,
                                  patch_func,
                                  settings=patch_settings)
            gorilla.apply(patch)