예제 #1
0
    def test_apply_patch_with_hit_3(self):
        """Patch every (source, target) pair with ``store_hit=False``.

        Hits are allowed but not stored: after patching, the attribute must be
        the source object and no original attribute may be retrievable.
        """
        no_store = gorilla.Settings(allow_hit=True, store_hit=False)

        sources = [''] + _list_attribute_paths(_frommodule)
        targets = _list_attribute_paths(_tomodule)
        for src_path, dst_path in itertools.product(sources, targets):
            # Each combination runs against a freshly set-up fixture.
            self.setUp()

            container_path, attr = _split_attribute_path(dst_path)
            container = _get_attribute_from_path(_tomodule, container_path)
            replacement = _get_attribute_from_path(_frommodule, src_path)
            gorilla.apply(
                gorilla.Patch(container, attr, replacement, settings=no_store))

            # The patch must not replace the destination object itself.
            self.assertIs(
                container,
                _get_attribute_from_path(_tomodule, container_path))
            self.assertIs(gorilla.get_attribute(container, attr), replacement)

            # With store_hit=False no original is available afterwards.
            self.assertRaises(AttributeError, gorilla.get_original_attribute,
                              container, attr)

            self.tearDown()
def monkey_patch():
    """
    Apply every gorilla patch discovered in the monkey-patching patch sources.
    """
    for found_patch in gorilla.find_patches(
            get_monkey_patching_patch_sources()):
        gorilla.apply(found_patch)
예제 #3
0
파일: evonorm.py 프로젝트: pgsrv/evonorm
    def patch_frelu(cls):
        """Replace ``torch.nn.functional.relu`` with the identity function.

        The gorilla patch is applied with ``allow_hit=True`` so the existing
        ``relu`` attribute may be overwritten.
        """
        import gorilla
        import torch.nn.functional as F

        allow_overwrite = gorilla.Settings(allow_hit=True)
        gorilla.apply(
            gorilla.Patch(F, 'relu', lambda x: x, settings=allow_overwrite))

        print('F.relu have been patched to identity fn.')
예제 #4
0
def set_seed(seed=100):
    """Seed NumPy and ``random``, then patch ``random_seed.get_seed``.

    The patched ``get_seed`` forwards to ``better_get_seed`` with *seed*
    bound as its first argument.
    """
    for seeder in (np.random.seed, random.seed):
        seeder(seed)

    # Monkey Patch get_seed (hits allowed, original stored).
    gorilla.apply(gorilla.Patch(
        random_seed,
        'get_seed',
        lambda op_seed: better_get_seed(seed, op_seed),
        settings=gorilla.Settings(allow_hit=True, store_hit=True)))
예제 #5
0
def autolog():
    """
    Enable automatic logging from Keras to MLflow.

    ``keras.Model.fit`` is patched (through gorilla) so that every call also
    runs an MLflow callback. The callback logs the metrics Keras reports at the
    end of each epoch. When training finishes, it logs the following:

    - the layer count and optimizer data (name, learning rate, epsilon) as
      parameters;
    - the model summary as a tag;
    - the model itself under the 'model' artifact path.

    The caller's own callback list is never modified.
    """
    import keras

    class __MLflowKerasCallback(keras.callbacks.Callback):
        """
        Callback for auto-logging metrics and parameters.
        Records available logs after each epoch.
        Records model structural information as params after training finishes.
        """
        def on_epoch_end(self, epoch, logs=None):
            if not logs:
                return
            try_mlflow_log(mlflow.log_metrics, logs, step=epoch)

        def on_train_end(self, logs=None):
            try_mlflow_log(mlflow.log_param, 'num_layers',
                           len(self.model.layers))
            try_mlflow_log(mlflow.log_param, 'optimizer_name',
                           type(self.model.optimizer).__name__)
            # Hyper-parameters may be plain floats or backend variables;
            # variables are evaluated to obtain a loggable value.
            if hasattr(self.model.optimizer, 'lr'):
                lr = self.model.optimizer.lr if \
                    type(self.model.optimizer.lr) is float \
                    else keras.backend.eval(self.model.optimizer.lr)
                try_mlflow_log(mlflow.log_param, 'learning_rate', lr)
            if hasattr(self.model.optimizer, 'epsilon'):
                epsilon = self.model.optimizer.epsilon if \
                    type(self.model.optimizer.epsilon) is float \
                    else keras.backend.eval(self.model.optimizer.epsilon)
                try_mlflow_log(mlflow.log_param, 'epsilon', epsilon)
            sum_list = []
            self.model.summary(print_fn=sum_list.append)
            summary = '\n'.join(sum_list)
            try_mlflow_log(mlflow.set_tag, 'summary', summary)
            try_mlflow_log(log_model, self.model, artifact_path='model')

    def _with_mlflow_callback(callbacks):
        """Return a NEW list: *callbacks* (``None`` -> empty) plus the MLflow callback.

        Building a new list matters because extending the caller's list in
        place would add another MLflow callback on every ``fit`` call that
        reuses the same list.
        """
        return list(callbacks or []) + [__MLflowKerasCallback()]

    @gorilla.patch(keras.Model)
    def fit(self, *args, **kwargs):
        original = gorilla.get_original_attribute(keras.Model, 'fit')
        # args[5] is the positional `callbacks` argument of keras.Model.fit.
        if len(args) >= 6:
            args = args[:5] + (_with_mlflow_callback(args[5]),) + args[6:]
        else:
            kwargs['callbacks'] = _with_mlflow_callback(
                kwargs.get('callbacks'))
        return original(self, *args, **kwargs)

    settings = gorilla.Settings(allow_hit=True, store_hit=True)
    patch = gorilla.Patch(keras.Model, 'fit', fit, settings=settings)
    gorilla.apply(patch)
예제 #6
0
def apply_gorrila(function: Callable, module: Any):
    """Overriding a function using a gorilla patch.

    The attribute of *module* named ``function.__name__`` is replaced by
    *function*. Hits are allowed, so an existing attribute is overwritten.

    Args:
        function (Callable): Override function
        module (Any): Function caller module
    """
    overwrite_ok = gorilla.Settings(allow_hit=True)
    target_name = function.__name__
    gorilla.apply(
        gorilla.Patch(module, target_name, function, settings=overwrite_ok))
예제 #7
0
    def test_get_original_attribute(self):
        """The stored original must survive repeated application of a patch."""
        destination = _tomodule.Class
        name = 'method'
        expected_original = gorilla.get_attribute(destination, name)
        replacement = gorilla.get_attribute(_frommodule, 'unbound_method')
        patch = gorilla.Patch(destination, name, replacement,
                              settings=gorilla.Settings(allow_hit=True))

        # Applying twice: the second application must not replace the
        # stored original with the value installed by the first one.
        for _ in range(2):
            gorilla.apply(patch)
            self.assertIs(
                _unfold(gorilla.get_original_attribute(destination, name)),
                expected_original)
예제 #8
0
def monkey_patch_tf_get_seed(seed: int,
                             default_op_seed: int = 1923746) -> None:
    """
    Monkey patch tensorflow.random.get_seed to avoid the increasing memory usage arising from
    repeated random sampling from TensorFlow distributions.

    This code is taken from https://github.com/lerobitaille/tf-issue-36164-workaround which remedies
    issue 36164 (https://github.com/tensorflow/tensorflow/issues/36164).

    We have raised our own clearer and more concise issue, which should be the reference point
    for this memory leak: https://github.com/tensorflow/tensorflow/issues/37252

    After patching, ``get_seed(op_seed)`` returns ``(seed, op_seed)``. If ``op_seed`` is None,
    it returns ``(seed, default_op_seed)`` instead. The eager context's caches are then
    cleared so existing seeds are reset.

    :param seed: Seed to set as the TensorFlow global seed.
    :param default_op_seed: Default seed for any random operations if required.
    """
    warn(
        "WARNING: Patching native TensorFlow functionality to avoid memory leak when setting "
        "a random seed.")
    warn("WARNING: Patch required due to TensorFlow issue 37252. "
         "Check if the issue is resolved at "
         "https://github.com/tensorflow/tensorflow/issues/37252")
    # Lazy imports to show which imports to remove once the issue is resolved and to avoid wider
    # usage of monkey patching and usage of the TensorFlow back end which involves imports the
    # linter does not like.
    # pylint: disable=no-name-in-module,import-error
    from tensorflow.python.eager import context
    from tensorflow.python import pywrap_tensorflow
    from tensorflow.python.framework import random_seed
    # Remove gorilla dependency completely when issue fixed. (Remove from requirements.txt)
    import gorilla

    def better_get_seed(global_seed, op_seed):
        """Return ``(global_seed, op_seed)``, substituting the default when op_seed is None."""
        if op_seed is not None:
            return global_seed, op_seed
        return global_seed, default_op_seed

    # Monkey Patch get_seed.
    def func(op_seed):
        # get_seed must *return* the seed pair; TensorFlow unpacks its result.
        return better_get_seed(seed, op_seed)

    settings = gorilla.Settings(allow_hit=True, store_hit=True)
    patch = gorilla.Patch(random_seed, 'get_seed', func, settings=settings)
    gorilla.apply(patch)

    # Also clear the kernel cache, to reset any existing seeds
    # pylint: disable=protected-access
    _context = context.context()
    if _context._context_handle is not None:
        pywrap_tensorflow.TFE_ContextClearCaches(_context._context_handle)
예제 #9
0
    def test_get_original_attribute(self):
        """Re-applying a patch keeps the very first original attribute."""
        target_cls = _tomodule.Class
        attr = 'method'
        original = gorilla.get_attribute(target_cls, attr)
        patch = gorilla.Patch(
            target_cls, attr,
            gorilla.get_attribute(_frommodule, 'unbound_method'),
            settings=gorilla.Settings(allow_hit=True))

        def check_original_preserved():
            stored = gorilla.get_original_attribute(target_cls, attr)
            self.assertIs(_unfold(stored), original)

        gorilla.apply(patch)
        check_original_preserved()

        # A second application must not replace the stored original.
        gorilla.apply(patch)
        check_original_preserved()
예제 #10
0
def wrap_patch(destination, name, patch, settings=None):
    """
    Apply a patch while preserving the attributes (e.g. __doc__) of an original function.

    The current ``destination.<name>`` is looked up and its metadata is copied
    onto *patch* with ``functools.wraps`` before the gorilla patch is applied.

    :param destination: Patch destination
    :param name: Name of the attribute at the destination
    :param patch: Patch function
    :param settings: Settings for gorilla.Patch. When None,
                     ``gorilla.Settings(allow_hit=True, store_hit=True)`` is used.
    """
    effective_settings = (gorilla.Settings(allow_hit=True, store_hit=True)
                          if settings is None else settings)
    current = getattr(destination, name)
    decorated = functools.wraps(current)(patch)
    gorilla.apply(
        gorilla.Patch(destination, name, decorated,
                      settings=effective_settings))
예제 #11
0
def initialize():
    """Initialize the extensions.

    The patches from the Bana package are searched and applied to the Maya API.
    Patches that seem to have already been applied are skipped: when a patch
    does not allow hits and its destination already has the attribute, the
    patch is not applied.
    """
    packages = []
    for package_name in _PACKAGES:
        packages.append(
            importlib.import_module('%s.%s' % (__package__, package_name)))

    fallback_settings = gorilla.Settings()
    for patch in gorilla.find_patches(packages):
        effective = (fallback_settings if patch.settings is None
                     else patch.settings)
        # hasattr is only consulted when hits are not allowed.
        if effective.allow_hit or not hasattr(patch.destination, patch.name):
            gorilla.apply(patch)
예제 #12
0
파일: __init__.py 프로젝트: jonike/bana
def initialize():
    """Initialize the extensions.

    The patches from the Bana package are searched and applied to the Maya API.
    Patches that seem to have already been applied are skipped.
    """
    packages = list(
        importlib.import_module('%s.%s' % (__package__, name))
        for name in _PACKAGES)
    defaults = gorilla.Settings()

    def should_apply(patch):
        """Return False when hits are disallowed and the attribute already exists."""
        settings = defaults if patch.settings is None else patch.settings
        return settings.allow_hit or not hasattr(patch.destination, patch.name)

    for patch in gorilla.find_patches(packages):
        if should_apply(patch):
            gorilla.apply(patch)
예제 #13
0
def autolog(every_n_iter=100):
    # pylint: disable=E0611
    """
    Enable automatic logging from TensorFlow to MLflow.

    The following methods are patched through gorilla:

    - ``tf.estimator.Estimator``: ``train``, ``export_saved_model`` and
      ``export_savedmodel``;
    - ``tf.keras.Model``: ``fit`` and ``fit_generator``;
    - the event file writers' ``add_event`` and ``FileWriter.add_summary``.

    Exported SavedModels are logged under the 'model' artifact path, and
    TensorBoard log data under 'tensorboard_logs'.

    Refer to the tracking documentation for
    information on what is logged with different TensorFlow workflows.

    :param every_n_iter: The frequency with which metrics should be logged.
                         Defaults to 100. Ex: a value of 100 will log metrics
                         at step 0, 100, 200, etc.
    """
    global _LOG_EVERY_N_STEPS
    _LOG_EVERY_N_STEPS = every_n_iter

    # Shared by both "unsupported TensorFlow" exits below.
    unsupported_version_msg = ("Could not log to MLflow. Only TensorFlow versions "
                               "1.12 <= v <= 2.0.0 are supported.")

    if LooseVersion(tensorflow.__version__) < LooseVersion('1.12'):
        warnings.warn(unsupported_version_msg)
        return

    try:
        from tensorflow.python.summary.writer.event_file_writer import EventFileWriter
        from tensorflow.python.summary.writer.event_file_writer_v2 import EventFileWriterV2
        from tensorflow.python.saved_model import tag_constants
        from tensorflow.python.summary.writer.writer import FileWriter
    except ImportError:
        warnings.warn(unsupported_version_msg)
        return

    @contextmanager
    def _manage_active_run():
        """Yield the active run, starting one first if none is active.

        A run started here has its id stored in the module-level
        ``_AUTOLOG_RUN_ID``. Only that run is ended after the body completes.
        NOTE(review): there is no try/finally, so the run is not ended if the
        wrapped call raises.
        """
        global _AUTOLOG_RUN_ID
        if not mlflow.active_run():
            try_mlflow_log(mlflow.start_run)
            # defensive check in case `mlflow.start_run` fails
            if mlflow.active_run() is not None:
                _AUTOLOG_RUN_ID = mlflow.active_run().info.run_id
        yield mlflow.active_run()
        if mlflow.active_run() is not None and mlflow.active_run(
        ).info.run_id == _AUTOLOG_RUN_ID:
            try_mlflow_log(mlflow.end_run)

    @gorilla.patch(tensorflow.estimator.Estimator)
    def train(self, *args, **kwargs):
        with _manage_active_run():
            original = gorilla.get_original_attribute(
                tensorflow.estimator.Estimator, 'train')

            # Checking step and max_step parameters for logging
            # (positional: args[2] is `steps`, args[3] is `max_steps`).
            if len(args) >= 3:
                try_mlflow_log(mlflow.log_param, 'steps', args[2])
                if len(args) >= 4:
                    try_mlflow_log(mlflow.log_param, 'max_steps', args[3])
            if 'steps' in kwargs:
                try_mlflow_log(mlflow.log_param, 'steps', kwargs['steps'])
            if 'max_steps' in kwargs:
                try_mlflow_log(mlflow.log_param, 'max_steps',
                               kwargs['max_steps'])

            result = original(self, *args, **kwargs)

            return result

    def _export_and_log_model(estimator, method_name, args, kwargs):
        """Run the original Estimator export method *method_name* and log its SavedModel.

        If no run is active, the autolog run is resumed when one exists.
        Otherwise a new run is started and ended after logging. The value
        returned by the original export method is returned unchanged.
        """
        auto_end = False
        if not mlflow.active_run():
            if _AUTOLOG_RUN_ID:
                try_mlflow_log(mlflow.start_run, _AUTOLOG_RUN_ID)
            else:
                try_mlflow_log(mlflow.start_run)
                auto_end = True

        original = gorilla.get_original_attribute(
            tensorflow.estimator.Estimator, method_name)
        serialized = original(estimator, *args, **kwargs)
        try_mlflow_log(log_model,
                       tf_saved_model_dir=serialized.decode('utf-8'),
                       tf_meta_graph_tags=[tag_constants.SERVING],
                       tf_signature_def_key='predict',
                       artifact_path='model')
        if (mlflow.active_run() is not None and mlflow.active_run().info.run_id == _AUTOLOG_RUN_ID)\
                or auto_end:
            try_mlflow_log(mlflow.end_run)
        return serialized

    @gorilla.patch(tensorflow.estimator.Estimator)
    def export_saved_model(self, *args, **kwargs):
        return _export_and_log_model(self, 'export_saved_model', args, kwargs)

    @gorilla.patch(tensorflow.estimator.Estimator)
    def export_savedmodel(self, *args, **kwargs):
        return _export_and_log_model(self, 'export_savedmodel', args, kwargs)

    def _early_stop_check(callbacks):
        """Return the first EarlyStopping callback in *callbacks*, or None."""
        for callback in callbacks:
            if isinstance(callback, tensorflow.keras.callbacks.EarlyStopping):
                return callback
        return None

    def _log_early_stop_callback_params(callback):
        """Log the EarlyStopping configuration as params (best effort)."""
        if callback:
            try:
                earlystopping_params = {
                    'monitor': callback.monitor,
                    'min_delta': callback.min_delta,
                    'patience': callback.patience,
                    'baseline': callback.baseline,
                    'restore_best_weights': callback.restore_best_weights
                }
                try_mlflow_log(mlflow.log_params, earlystopping_params)
            except Exception:  # pylint: disable=W0703
                return

    def _get_early_stop_callback_attrs(callback):
        """Return (stopped_epoch, restore_best_weights, patience) or None if unavailable."""
        try:
            return callback.stopped_epoch, callback.restore_best_weights, callback.patience
        except Exception:  # pylint: disable=W0703
            return None

    def _log_early_stop_callback_metrics(callback, history):
        """Log stopped/restored epoch and, when weights were restored, the restored metrics."""
        if callback:
            callback_attrs = _get_early_stop_callback_attrs(callback)
            if callback_attrs is None:
                return
            stopped_epoch, restore_best_weights, patience = callback_attrs
            try_mlflow_log(mlflow.log_metric, 'stopped_epoch', stopped_epoch)
            # Weights are restored only if early stopping occurs
            if stopped_epoch != 0 and restore_best_weights:
                restored_epoch = stopped_epoch - max(1, patience)
                try_mlflow_log(mlflow.log_metric, 'restored_epoch',
                               restored_epoch)
                restored_metrics = {
                    key: history.history[key][restored_epoch]
                    for key in history.history.keys()
                }
                # Metrics are logged as 'epoch_loss' and 'epoch_acc' in TF 1.X
                if LooseVersion(
                        tensorflow.__version__) < LooseVersion('2.0.0'):
                    if 'loss' in restored_metrics:
                        restored_metrics['epoch_loss'] = restored_metrics.pop(
                            'loss')
                    if 'acc' in restored_metrics:
                        restored_metrics['epoch_acc'] = restored_metrics.pop(
                            'acc')
                # Checking that a metric history exists
                metric_key = next(iter(history.history), None)
                if metric_key is not None:
                    last_epoch = len(history.history[metric_key])
                    try_mlflow_log(mlflow.log_metrics,
                                   restored_metrics,
                                   step=last_epoch)

    @gorilla.patch(tensorflow.keras.Model)
    def fit(self, *args, **kwargs):
        with _manage_active_run():
            original = gorilla.get_original_attribute(tensorflow.keras.Model,
                                                      'fit')

            unlogged_params = [
                'self', 'x', 'y', 'callbacks', 'validation_data', 'verbose'
            ]

            log_fn_args_as_params(original, args, kwargs, unlogged_params)

            # args[5] is the positional `callbacks` argument of fit(). A
            # missing or explicit ``callbacks=None`` is treated as an empty
            # list so that _early_stop_check never iterates over None.
            if len(args) >= 6:
                tmp_list = list(args)
                callbacks = tmp_list[5] or []
                early_stop_callback = _early_stop_check(callbacks)
                tmp_list[5], log_dir = _setup_callbacks(callbacks)
                args = tuple(tmp_list)
            else:
                callbacks = kwargs.get('callbacks') or []
                early_stop_callback = _early_stop_check(callbacks)
                kwargs['callbacks'], log_dir = _setup_callbacks(callbacks)

            _log_early_stop_callback_params(early_stop_callback)

            history = original(self, *args, **kwargs)

            _log_early_stop_callback_metrics(early_stop_callback, history)

            _flush_queue()
            _log_artifacts_with_warning(local_dir=log_dir.location,
                                        artifact_path='tensorboard_logs')
            if log_dir.is_temp:
                shutil.rmtree(log_dir.location)

            return history

    @gorilla.patch(tensorflow.keras.Model)
    def fit_generator(self, *args, **kwargs):
        with _manage_active_run():
            original = gorilla.get_original_attribute(tensorflow.keras.Model,
                                                      'fit_generator')

            unlogged_params = [
                'self', 'generator', 'callbacks', 'validation_data', 'verbose'
            ]

            log_fn_args_as_params(original, args, kwargs, unlogged_params)

            # args[4] is the positional `callbacks` argument of
            # fit_generator(); a missing or None value is treated as [].
            if len(args) >= 5:
                tmp_list = list(args)
                tmp_list[4], log_dir = _setup_callbacks(tmp_list[4] or [])
                args = tuple(tmp_list)
            else:
                kwargs['callbacks'], log_dir = _setup_callbacks(
                    kwargs.get('callbacks') or [])
            result = original(self, *args, **kwargs)
            _flush_queue()
            _log_artifacts_with_warning(local_dir=log_dir.location,
                                        artifact_path='tensorboard_logs')
            if log_dir.is_temp:
                shutil.rmtree(log_dir.location)

            return result

    @gorilla.patch(EventFileWriter)
    def add_event(self, event):
        _log_event(event)
        original = gorilla.get_original_attribute(EventFileWriter, 'add_event')
        return original(self, event)

    @gorilla.patch(FileWriter)
    def add_summary(self, *args, **kwargs):
        original = gorilla.get_original_attribute(FileWriter, 'add_summary')
        result = original(self, *args, **kwargs)
        _flush_queue()
        return result

    settings = gorilla.Settings(allow_hit=True, store_hit=True)
    patches = [
        gorilla.Patch(EventFileWriter,
                      'add_event',
                      add_event,
                      settings=settings),
        gorilla.Patch(EventFileWriterV2,
                      'add_event',
                      add_event,
                      settings=settings),
        gorilla.Patch(tensorflow.estimator.Estimator,
                      'train',
                      train,
                      settings=settings),
        gorilla.Patch(tensorflow.keras.Model, 'fit', fit, settings=settings),
        gorilla.Patch(tensorflow.keras.Model,
                      'fit_generator',
                      fit_generator,
                      settings=settings),
        gorilla.Patch(tensorflow.estimator.Estimator,
                      'export_saved_model',
                      export_saved_model,
                      settings=settings),
        gorilla.Patch(tensorflow.estimator.Estimator,
                      'export_savedmodel',
                      export_savedmodel,
                      settings=settings),
        gorilla.Patch(FileWriter,
                      'add_summary',
                      add_summary,
                      settings=settings),
    ]

    for patch in patches:
        gorilla.apply(patch)
예제 #14
0
    def test_apply_patch_with_hit_2(self):
        """Patch every source onto every target with ``allow_hit=True``.

        For each (source path, target path) pair, the patch is applied to a
        freshly set-up fixture. The empty source path stands for the whole
        ``_frommodule``. The test then checks the following:

        - the destination object is unchanged;
        - the patched attribute is the source object;
        - the overwritten value is kept under ``_gorilla_original_<name>``;
        - source-specific behavior of the patched attribute.
        """
        settings = gorilla.Settings(allow_hit=True)

        # Incremented once per behavioral branch exercised; the total is
        # asserted at the end so that silently skipped branches are detected.
        branch_count = 0

        source_paths = [''] + _list_attribute_paths(_frommodule)
        target_paths = _list_attribute_paths(_tomodule)
        combinations = itertools.product(source_paths, target_paths)
        for source_path, target_path in combinations:
            # Re-run fixture setup for each combination (presumably restores
            # pristine fixture modules — TODO confirm against setUp).
            self.setUp()

            destination_path, name = _split_attribute_path(target_path)
            destination = _get_attribute_from_path(_tomodule, destination_path)
            # Value that the patch is expected to overwrite and store.
            target = gorilla.get_attribute(destination, name)
            obj = _get_attribute_from_path(_frommodule, source_path)
            patch = gorilla.Patch(destination, name, obj, settings=settings)
            gorilla.apply(patch)
            # The patch must not replace the destination object itself.
            self.assertIs(
                destination,
                _get_attribute_from_path(_tomodule, destination_path))

            result = gorilla.get_attribute(destination, name)
            self.assertIs(result, obj)

            # `gorilla.get_original_attribute` cannot be used here because it
            # could return a bound method, which would not compare as
            # expected.
            original = gorilla.get_attribute(destination,
                                             '_gorilla_original_%s' % (name, ))
            self.assertIs(original, target)
            self.assertIsNot(original, result)

            # Source-specific checks on the patched attribute (and, where the
            # destination is a class listed in _CLS_REFERENCES, on calls made
            # through it).
            if source_path == '':
                branch_count += 1
                self.assertEqual(result.global_variable,
                                 "frommodule.global_variable")
                self.assertEqual(
                    result.function(),
                    "frommodule.function (frommodule.Class.STATIC_VALUE)")
                self.assertEqual(result.Class.STATIC_VALUE,
                                 "frommodule.Class.STATIC_VALUE")
                self.assertEqual(result.Class.Inner.STATIC_VALUE,
                                 "frommodule.Class.Inner.STATIC_VALUE")
                self.assertEqual(result.Parent.STATIC_VALUE,
                                 "frommodule.Parent.STATIC_VALUE")
                self.assertEqual(result.Child.STATIC_VALUE,
                                 "frommodule.Parent.STATIC_VALUE")
            elif source_path == 'global_variable':
                branch_count += 1
                self.assertEqual(result, "frommodule.global_variable")
                if destination_path in _CLS_REFERENCES and name in (
                        'STATIC_VALUE', ):
                    branch_count += 1
                    self.assertEqual(destination.STATIC_VALUE,
                                     "frommodule.global_variable")

                if target_path == 'Class.STATIC_VALUE':
                    branch_count += 1
                    self.assertEqual(
                        destination.class_method(),
                        "tomodule.Class.class_method (frommodule.global_variable)"
                    )
                    self.assertEqual(
                        destination.static_method(),
                        "tomodule.Class.static_method (frommodule.global_variable)"
                    )
            elif source_path == 'function':
                branch_count += 1
                self.assertEqual(
                    result(),
                    "frommodule.function (frommodule.Class.STATIC_VALUE)")
            elif source_path == 'unbound_method':
                branch_count += 1
                if destination_path in _CLS_REFERENCES and name not in (
                        'STATIC_VALUE', '__init__'):
                    branch_count += 1
                    self.assertEqual(
                        getattr(destination(), name)(),
                        "frommodule.unbound_method (tomodule.%s.STATIC_VALUE, tomodule.%s.instance_value)"
                        % ((_CLS_REFERENCES[destination_path], ) * 2))
            elif source_path == 'unbound_class_method':
                branch_count += 1
                if destination_path in _CLS_REFERENCES and name not in (
                        'STATIC_VALUE', ):
                    branch_count += 1
                    self.assertEqual(
                        getattr(destination, name)(),
                        "frommodule.unbound_class_method (tomodule.%s.STATIC_VALUE)"
                        % (_CLS_REFERENCES[destination_path], ))
            elif source_path == 'unbound_static_method':
                branch_count += 1
                if destination_path in _CLS_REFERENCES:
                    branch_count += 1
                    self.assertEqual(
                        getattr(destination, name)(),
                        "frommodule.unbound_static_method (frommodule.Class.STATIC_VALUE)"
                    )
            elif source_path == 'Class':
                branch_count += 1
                self.assertEqual(result.STATIC_VALUE,
                                 "frommodule.Class.STATIC_VALUE")
            elif source_path == 'Class.Inner':
                branch_count += 1
                self.assertEqual(result.STATIC_VALUE,
                                 "frommodule.Class.Inner.STATIC_VALUE")
            elif source_path == 'Class.value':
                branch_count += 1
                if destination_path in _CLS_REFERENCES and name not in (
                        '__init__', ):
                    branch_count += 1
                    instance = destination()
                    self.assertEqual(
                        getattr(instance, name),
                        "frommodule.Class.value.getter (tomodule.%s.instance_value)"
                        % (_CLS_REFERENCES[destination_path], ))
                    # The patched attribute also accepts assignment; the
                    # getter then reflects the new value.
                    setattr(instance, name, 'hello')
                    self.assertEqual(getattr(instance, name),
                                     "frommodule.Class.value.getter (hello)")
            elif source_path == 'Class.method':
                branch_count += 1
                if destination_path in _CLS_REFERENCES and name not in (
                        'STATIC_VALUE', '__init__'):
                    branch_count += 1
                    self.assertEqual(
                        getattr(destination(), name)(),
                        "frommodule.Class.method (tomodule.%s.STATIC_VALUE, tomodule.%s.instance_value)"
                        % ((_CLS_REFERENCES[destination_path], ) * 2))
            elif source_path == 'Class.class_method':
                branch_count += 1
                if destination_path in _CLS_REFERENCES and name not in (
                        'STATIC_VALUE', ):
                    branch_count += 1
                    self.assertEqual(
                        getattr(destination, name)(),
                        "frommodule.Class.class_method (tomodule.%s.STATIC_VALUE)"
                        % (_CLS_REFERENCES[destination_path], ))
            elif source_path == 'Class.static_method':
                branch_count += 1
                if destination_path in _CLS_REFERENCES:
                    branch_count += 1
                    self.assertEqual(
                        getattr(destination, name)(),
                        "frommodule.Class.static_method (frommodule.Class.STATIC_VALUE)"
                    )
            elif source_path == 'Parent':
                branch_count += 1
                self.assertEqual(result.__slots__,
                                 ('instance_value', 'parent_value', 'to_value',
                                  'from_value'))
                self.assertEqual(result().parent_value,
                                 "frommodule.Parent.parent_value")
                self.assertEqual(
                    result().method(),
                    "frommodule.Parent.method (frommodule.Parent.instance_value)"
                )
            elif source_path == 'Parent.method':
                branch_count += 1
                if destination_path in _CLS_REFERENCES and name not in (
                        '__init__', ):
                    branch_count += 1
                    self.assertEqual(
                        getattr(destination(), name)(),
                        "frommodule.Parent.method (tomodule.%s.instance_value)"
                        % (_CLS_REFERENCES[destination_path], ))
            elif source_path == 'Child':
                branch_count += 1
                self.assertEqual(result.__slots__, ('child_value', ))
                self.assertEqual(result().parent_value,
                                 "frommodule.Parent.parent_value")
                self.assertEqual(result().child_value,
                                 "frommodule.Child.child_value")
                self.assertEqual(
                    result().method(),
                    "frommodule.Parent.method (frommodule.Parent.instance_value)"
                )
            elif source_path == 'Child.method':
                branch_count += 1
                if destination_path in _CLS_REFERENCES and name not in (
                        '__init__', ):
                    branch_count += 1
                    self.assertEqual(
                        getattr(destination(), name)(),
                        "frommodule.Parent.method (tomodule.%s.instance_value)"
                        % (_CLS_REFERENCES[destination_path], ))

            self.tearDown()

        # Make sure that all test branches are covered.
        # NOTE(review): 427 depends on the current fixture modules; it
        # presumably must be updated whenever _frommodule/_tomodule change.
        self.assertEqual(branch_count, 427)
예제 #15
0
def autolog():
    """
    Enable automatic logging from Keras to MLflow.

    Patches ``keras.Model.fit`` and ``keras.Model.fit_generator``. During each call:

    - the number of layers, the optimizer name and, when available, the optimizer's
      learning rate and epsilon are logged as parameters when training begins;
    - the model summary is logged as the ``model_summary`` tag and as a
      ``model_summary.txt`` artifact;
    - the logs reported at the end of each epoch are logged as metrics;
    - the trained model is logged under the ``model`` artifact path when training ends.

    If no MLflow run is active when a patched method is called, a run is started for
    that call and ended when it returns.
    """
    import keras

    def _as_scalar(value):
        """Return ``value`` if it is already a float, otherwise evaluate it with the Keras backend."""
        # isinstance (rather than ``type(...) is float``) also accepts float subclasses
        # such as numpy floats, which do not need backend evaluation.
        return value if isinstance(value, float) else keras.backend.eval(value)

    class __MLflowKerasCallback(keras.callbacks.Callback):
        """
        Callback for auto-logging metrics and parameters.
        Records available logs after each epoch.
        Records model structural information as params when training begins
        """
        def on_train_begin(self, logs=None):  # pylint: disable=unused-argument
            optimizer = self.model.optimizer
            try_mlflow_log(mlflow.log_param, 'num_layers', len(self.model.layers))
            try_mlflow_log(mlflow.log_param, 'optimizer_name', type(optimizer).__name__)
            if hasattr(optimizer, 'lr'):
                try_mlflow_log(mlflow.log_param, 'learning_rate', _as_scalar(optimizer.lr))
            if hasattr(optimizer, 'epsilon'):
                try_mlflow_log(mlflow.log_param, 'epsilon', _as_scalar(optimizer.epsilon))

            summary_lines = []
            self.model.summary(print_fn=summary_lines.append)
            summary = '\n'.join(summary_lines)
            try_mlflow_log(mlflow.set_tag, 'model_summary', summary)

            # The summary is written to a temporary file only so it can be uploaded
            # as an artifact; the directory is always removed afterwards.
            tempdir = tempfile.mkdtemp()
            try:
                summary_file = os.path.join(tempdir, "model_summary.txt")
                with open(summary_file, 'w') as f:
                    f.write(summary)
                try_mlflow_log(mlflow.log_artifact, local_path=summary_file)
            finally:
                shutil.rmtree(tempdir)

        def on_epoch_end(self, epoch, logs=None):
            if not logs:
                return
            try_mlflow_log(mlflow.log_metrics, logs, step=epoch)

        def on_train_end(self, logs=None):  # pylint: disable=unused-argument
            try_mlflow_log(log_model, self.model, artifact_path='model')

    def _with_mlflow_callback(callbacks):
        """
        Return a new list with ``callbacks`` (``None`` treated as empty) followed by a
        fresh MLflow callback.

        A copy is built instead of extending the caller's list in place. Otherwise a
        callbacks list reused across several ``fit`` calls would gain one more MLflow
        callback per call, and ``callbacks=None`` would raise ``TypeError``.
        """
        return list(callbacks or []) + [__MLflowKerasCallback()]

    def _run_and_log_function(self, original, args, kwargs, callback_arg_index):
        """
        Call ``original(self, *args, **kwargs)`` with the MLflow callback attached.

        :param original: The unpatched ``keras.Model`` method.
        :param args: Positional arguments of the call (excluding ``self``).
        :param kwargs: Keyword arguments of the call; may be updated in place.
        :param callback_arg_index: Index of the ``callbacks`` parameter within ``args``.
        :return: Whatever ``original`` returns.
        """
        auto_end_run = not mlflow.active_run()
        if auto_end_run:
            try_mlflow_log(mlflow.start_run)

        # ``callbacks`` may be supplied positionally or by keyword.
        if len(args) > callback_arg_index:
            args = list(args)
            args[callback_arg_index] = _with_mlflow_callback(args[callback_arg_index])
            args = tuple(args)
        else:
            kwargs['callbacks'] = _with_mlflow_callback(kwargs.get('callbacks'))

        result = original(self, *args, **kwargs)
        if auto_end_run:
            try_mlflow_log(mlflow.end_run)
        return result

    @gorilla.patch(keras.Model)
    def fit(self, *args, **kwargs):
        original = gorilla.get_original_attribute(keras.Model, 'fit')
        # ``callbacks`` is positional argument index 5 of ``fit``.
        return _run_and_log_function(self, original, args, kwargs, 5)

    @gorilla.patch(keras.Model)
    def fit_generator(self, *args, **kwargs):
        original = gorilla.get_original_attribute(keras.Model, 'fit_generator')
        # ``callbacks`` is positional argument index 4 of ``fit_generator``.
        return _run_and_log_function(self, original, args, kwargs, 4)

    settings = gorilla.Settings(allow_hit=True, store_hit=True)
    gorilla.apply(gorilla.Patch(keras.Model, 'fit', fit, settings=settings))
    gorilla.apply(gorilla.Patch(keras.Model, 'fit_generator', fit_generator, settings=settings))
예제 #16
0
def autolog():
    # pylint: disable=E0611
    """
    Enables automatic logging from Keras to MLflow. Autologging captures the following information:

    **Metrics** and **Parameters**
     - Training loss; validation loss; user-specified metrics
     - Metrics associated with the ``EarlyStopping`` callbacks: ``stopped_epoch``,
       ``restored_epoch``, ``restore_best_weight``, ``last_epoch``, etc
     - ``fit()`` or ``fit_generator()`` parameters; optimizer name; learning rate; epsilon
     - ``fit()`` or ``fit_generator()`` parameters associated with ``EarlyStopping``: ``min_delta``,
       ``patience``, ``baseline``, ``restore_best_weights``, etc
    **Artifacts**
     - Model summary on training start
     - `MLflow Model <https://mlflow.org/docs/latest/models.html>`_ (Keras model) on training end

    .. code-block:: python
        :caption: Example

        import mlflow
        import mlflow.keras
        # Build, compile, enable autologging, and train your model
        keras_model = ...
        keras_model.compile(optimizer="rmsprop", loss="mse", metrics=["accuracy"])
        # autolog your metrics, parameters, and model
        mlflow.keras.autolog()
        results = keras_model.fit(
            x_train, y_train, epochs=20, batch_size=128, validation_data=(x_val, y_val))

    ``EarlyStopping Integration with Keras AutoLogging``

    MLflow will detect if an ``EarlyStopping`` callback is used in a ``fit()`` or
    ``fit_generator()`` call, and if the ``restore_best_weights`` parameter is set to be ``True``,
    then MLflow will log the metrics associated with the restored model as a final, extra step.
    The epoch of the restored model will also be logged as the metric ``restored_epoch``.
    This allows for easy comparison between the actual metrics of the restored model and
    the metrics of other models.

    If ``restore_best_weights`` is set to be ``False``, then MLflow will not log an additional step.

    Regardless of ``restore_best_weights``, MLflow will also log ``stopped_epoch``,
    which indicates the epoch at which training stopped due to early stopping.

    If training does not end due to early stopping, then ``stopped_epoch`` will be logged as ``0``.

    MLflow will also log the parameters of the ``EarlyStopping`` callback,
    excluding ``mode`` and ``verbose``.
    """
    import keras

    def _as_scalar(value):
        """Return ``value`` if it is already a float, otherwise evaluate it with the Keras backend."""
        # isinstance (rather than ``type(...) is float``) also accepts float subclasses
        # such as numpy floats, which do not need backend evaluation.
        return value if isinstance(value, float) else keras.backend.eval(value)

    class __MLflowKerasCallback(keras.callbacks.Callback):
        """
        Callback for auto-logging metrics and parameters.
        Records available logs after each epoch.
        Records model structural information as params when training begins
        """
        def on_train_begin(self, logs=None):  # pylint: disable=unused-argument
            optimizer = self.model.optimizer
            try_mlflow_log(mlflow.log_param, 'num_layers',
                           len(self.model.layers))
            try_mlflow_log(mlflow.log_param, 'optimizer_name',
                           type(optimizer).__name__)
            if hasattr(optimizer, 'lr'):
                try_mlflow_log(mlflow.log_param, 'learning_rate', _as_scalar(optimizer.lr))
            if hasattr(optimizer, 'epsilon'):
                try_mlflow_log(mlflow.log_param, 'epsilon', _as_scalar(optimizer.epsilon))

            summary_lines = []
            self.model.summary(print_fn=summary_lines.append)
            summary = '\n'.join(summary_lines)
            # Written to a temporary file only so it can be uploaded as an artifact.
            tempdir = tempfile.mkdtemp()
            try:
                summary_file = os.path.join(tempdir, "model_summary.txt")
                with open(summary_file, 'w') as f:
                    f.write(summary)
                try_mlflow_log(mlflow.log_artifact, local_path=summary_file)
            finally:
                shutil.rmtree(tempdir)

        def on_epoch_end(self, epoch, logs=None):
            if not logs:
                return
            try_mlflow_log(mlflow.log_metrics, logs, step=epoch)

        def on_train_end(self, logs=None):  # pylint: disable=unused-argument
            try_mlflow_log(log_model, self.model, artifact_path='model')

        # As of Keras 2.4.0, Keras Callback implementations must define the following
        # methods indicating whether or not the callback overrides functions for
        # batch training/testing/inference
        def _implements_train_batch_hooks(self):
            return False

        def _implements_test_batch_hooks(self):
            return False

        def _implements_predict_batch_hooks(self):
            return False

    def _early_stop_check(callbacks):
        """Return the first ``EarlyStopping`` instance in ``callbacks`` (may be ``None``), or ``None``."""
        # ``keras.callbacks.EarlyStopping`` is the public export. The
        # ``keras.callbacks.callbacks`` submodule is not available in every release,
        # so it is used only as a fallback.
        es_callback = getattr(keras.callbacks, 'EarlyStopping', None)
        if es_callback is None:
            es_callback = keras.callbacks.callbacks.EarlyStopping
        for callback in callbacks or []:
            if isinstance(callback, es_callback):
                return callback
        return None

    def _log_early_stop_callback_params(callback):
        """Log the configuration of ``callback`` as params (best effort; no-op if falsy)."""
        if not callback:
            return
        try:
            earlystopping_params = {
                'monitor': callback.monitor,
                'min_delta': callback.min_delta,
                'patience': callback.patience,
                'baseline': callback.baseline,
                'restore_best_weights': callback.restore_best_weights
            }
            try_mlflow_log(mlflow.log_params, earlystopping_params)
        except Exception:  # pylint: disable=W0703
            return

    def _get_early_stop_callback_attrs(callback):
        """Return ``(stopped_epoch, restore_best_weights, patience)``, or ``None`` if unreadable."""
        try:
            return callback.stopped_epoch, callback.restore_best_weights, callback.patience
        except Exception:  # pylint: disable=W0703
            return None

    def _log_early_stop_callback_metrics(callback, history):
        """
        Log ``stopped_epoch`` and, when weights were restored, ``restored_epoch`` plus
        the metrics of that epoch as one extra step after the last epoch.
        """
        if not callback:
            return
        callback_attrs = _get_early_stop_callback_attrs(callback)
        if callback_attrs is None:
            return
        stopped_epoch, restore_best_weights, patience = callback_attrs
        try_mlflow_log(mlflow.log_metric, 'stopped_epoch', stopped_epoch)
        # Weights are restored only if early stopping occurs
        if stopped_epoch == 0 or not restore_best_weights:
            return
        restored_epoch = stopped_epoch - max(1, patience)
        try_mlflow_log(mlflow.log_metric, 'restored_epoch', restored_epoch)
        # Checking that a metric history exists
        metric_key = next(iter(history.history), None)
        if metric_key is None:
            return
        last_epoch = len(history.history[metric_key])
        # Guard the index: a negative value would silently read from the end of the
        # history, and one past the end would raise after training already finished.
        if not 0 <= restored_epoch < last_epoch:
            return
        restored_metrics = {
            key: values[restored_epoch]
            for key, values in history.history.items()
        }
        try_mlflow_log(mlflow.log_metrics, restored_metrics, step=last_epoch)

    def _attach_mlflow_callback(callbacks):
        """
        Return ``(new_callbacks, early_stop_callback)``.

        ``new_callbacks`` is a fresh list with ``callbacks`` (``None`` treated as
        empty) followed by an MLflow callback. The caller's list is left untouched,
        so reusing it across calls does not accumulate MLflow callbacks.
        """
        callbacks = list(callbacks or [])
        return callbacks + [__MLflowKerasCallback()], _early_stop_check(callbacks)

    def _run_and_log_function(self, original, args, kwargs, unlogged_params,
                              callback_arg_index):
        """
        Call ``original(self, *args, **kwargs)`` with the MLflow callback attached,
        logging call arguments and EarlyStopping information around it.

        :param callback_arg_index: Index of the ``callbacks`` parameter within ``args``.
        :return: The history object returned by ``original``.
        """
        auto_end_run = not mlflow.active_run()
        if auto_end_run:
            try_mlflow_log(mlflow.start_run)

        log_fn_args_as_params(original, args, kwargs, unlogged_params)

        # Checking if the 'callback' argument of the function is set
        if len(args) > callback_arg_index:
            args = list(args)
            args[callback_arg_index], early_stop_callback = _attach_mlflow_callback(
                args[callback_arg_index])
            args = tuple(args)
        else:
            kwargs['callbacks'], early_stop_callback = _attach_mlflow_callback(
                kwargs.get('callbacks'))

        _log_early_stop_callback_params(early_stop_callback)

        history = original(self, *args, **kwargs)

        _log_early_stop_callback_metrics(early_stop_callback, history)

        if auto_end_run:
            try_mlflow_log(mlflow.end_run)

        return history

    @gorilla.patch(keras.Model)
    def fit(self, *args, **kwargs):
        original = gorilla.get_original_attribute(keras.Model, 'fit')
        unlogged_params = [
            'self', 'x', 'y', 'callbacks', 'validation_data', 'verbose'
        ]
        return _run_and_log_function(self, original, args, kwargs,
                                     unlogged_params, 5)

    @gorilla.patch(keras.Model)
    def fit_generator(self, *args, **kwargs):
        original = gorilla.get_original_attribute(keras.Model, 'fit_generator')
        unlogged_params = [
            'self', 'generator', 'callbacks', 'validation_data', 'verbose'
        ]
        return _run_and_log_function(self, original, args, kwargs,
                                     unlogged_params, 4)

    settings = gorilla.Settings(allow_hit=True, store_hit=True)
    gorilla.apply(gorilla.Patch(keras.Model, 'fit', fit, settings=settings))
    gorilla.apply(
        gorilla.Patch(keras.Model,
                      'fit_generator',
                      fit_generator,
                      settings=settings))
예제 #17
0
def autolog(every_n_iter=100):
    # pylint: disable=E0611
    """
    Enable automatic logging from TensorFlow to MLflow. If applicable,
    exported SavedModels are logged as artifacts under the 'model' path, and
    TensorBoard log data under 'tensorboard_logs'.

    Refer to the tracking documentation for
    information on what is logged with different TensorFlow workflows.

    :param every_n_iter: The frequency with which metrics should be logged.
                         Defaults to 100. Ex: a value of 100 will log metrics
                         at step 0, 100, 200, etc.
    """
    global _LOG_EVERY_N_STEPS
    _LOG_EVERY_N_STEPS = every_n_iter

    # Adjacent literals are joined verbatim, so the trailing space is required.
    unsupported_msg = ("Could not log to MLflow. Only TensorFlow versions "
                       "1.12 <= v <= 2.0.0 are supported.")

    if LooseVersion(tensorflow.__version__) < LooseVersion('1.12'):
        warnings.warn(unsupported_msg)
        return

    try:
        from tensorflow.python.summary.writer.event_file_writer import EventFileWriter
        from tensorflow.python.summary.writer.event_file_writer_v2 import EventFileWriterV2
        from tensorflow.python.saved_model import tag_constants
        from tensorflow.python.summary.writer.writer import FileWriter
    except ImportError:
        warnings.warn(unsupported_msg)
        return

    @contextmanager
    def _manage_active_run():
        """
        Yield the active run, starting one first if none is active.

        The id of a run started here is stored in ``_AUTOLOG_RUN_ID``. On exit, the
        then-active run is ended only if its id equals ``_AUTOLOG_RUN_ID``.
        """
        global _AUTOLOG_RUN_ID
        if not mlflow.active_run():
            try_mlflow_log(mlflow.start_run)
            if mlflow.active_run() is not None:  # defensive check in case `mlflow.start_run` fails
                _AUTOLOG_RUN_ID = mlflow.active_run().info.run_id
        yield mlflow.active_run()
        if mlflow.active_run() is not None and mlflow.active_run(
        ).info.run_id == _AUTOLOG_RUN_ID:
            try_mlflow_log(mlflow.end_run)

    @gorilla.patch(tensorflow.estimator.Estimator)
    def train(self, *args, **kwargs):
        with _manage_active_run():
            original = gorilla.get_original_attribute(
                tensorflow.estimator.Estimator, 'train')

            # Checking step and max_step parameters for logging
            if len(args) >= 3:
                try_mlflow_log(mlflow.log_param, 'steps', args[2])
                if len(args) >= 4:
                    try_mlflow_log(mlflow.log_param, 'max_steps', args[3])
            if 'steps' in kwargs:
                try_mlflow_log(mlflow.log_param, 'steps', kwargs['steps'])
            if 'max_steps' in kwargs:
                try_mlflow_log(mlflow.log_param, 'max_steps',
                               kwargs['max_steps'])

            return original(self, *args, **kwargs)

    def _export_and_log_model(self, method_name, args, kwargs):
        """
        Call the original ``Estimator.<method_name>`` and log the exported model.

        If no run is active, the run recorded in ``_AUTOLOG_RUN_ID`` is resumed when
        set; otherwise a new run is started and ended after logging.

        :return: The value returned by the original export method (decoded as a
                 UTF-8 path for logging).
        """
        auto_end = False
        if not mlflow.active_run():
            if _AUTOLOG_RUN_ID:
                try_mlflow_log(mlflow.start_run, _AUTOLOG_RUN_ID)
            else:
                try_mlflow_log(mlflow.start_run)
                auto_end = True

        original = gorilla.get_original_attribute(
            tensorflow.estimator.Estimator, method_name)
        serialized = original(self, *args, **kwargs)
        try_mlflow_log(log_model,
                       tf_saved_model_dir=serialized.decode('utf-8'),
                       tf_meta_graph_tags=[tag_constants.SERVING],
                       tf_signature_def_key='predict',
                       artifact_path='model')
        active_run = mlflow.active_run()
        if (active_run is not None and active_run.info.run_id == _AUTOLOG_RUN_ID) \
                or auto_end:
            try_mlflow_log(mlflow.end_run)
        return serialized

    @gorilla.patch(tensorflow.estimator.Estimator)
    def export_saved_model(self, *args, **kwargs):
        return _export_and_log_model(self, 'export_saved_model', args, kwargs)

    @gorilla.patch(tensorflow.estimator.Estimator)
    def export_savedmodel(self, *args, **kwargs):
        return _export_and_log_model(self, 'export_savedmodel', args, kwargs)

    def _fit_with_callbacks(self, method_name, callback_arg_index, args, kwargs):
        """
        Call the original ``tensorflow.keras.Model.<method_name>`` with callbacks
        prepared by ``_setup_callbacks``, then log the TensorBoard directory.

        :param callback_arg_index: Index of the ``callbacks`` parameter within ``args``.
        :return: Whatever the original method returns.
        """
        with _manage_active_run():
            original = gorilla.get_original_attribute(tensorflow.keras.Model,
                                                      method_name)

            # ``callbacks`` may be passed positionally or by keyword. A fresh list is
            # handed to ``_setup_callbacks`` so an explicit ``None`` is treated as
            # empty and the caller's own list object is never passed on.
            if len(args) > callback_arg_index:
                args = list(args)
                args[callback_arg_index], log_dir = _setup_callbacks(
                    list(args[callback_arg_index] or []))
                args = tuple(args)
            else:
                kwargs['callbacks'], log_dir = _setup_callbacks(
                    list(kwargs.get('callbacks') or []))
            result = original(self, *args, **kwargs)
            _flush_queue()
            _log_artifacts_with_warning(local_dir=log_dir,
                                        artifact_path='tensorboard_logs')
            # NOTE(review): confirm `_setup_callbacks` always returns a temporary
            # directory; otherwise this removes user-owned TensorBoard logs.
            shutil.rmtree(log_dir)

            return result

    @gorilla.patch(tensorflow.keras.Model)
    def fit(self, *args, **kwargs):
        # ``callbacks`` is positional argument index 5 of ``fit``.
        return _fit_with_callbacks(self, 'fit', 5, args, kwargs)

    @gorilla.patch(tensorflow.keras.Model)
    def fit_generator(self, *args, **kwargs):
        # ``callbacks`` is positional argument index 4 of ``fit_generator``.
        return _fit_with_callbacks(self, 'fit_generator', 4, args, kwargs)

    @gorilla.patch(EventFileWriter)
    def add_event(self, event):
        _log_event(event)
        original = gorilla.get_original_attribute(EventFileWriter, 'add_event')
        return original(self, event)

    @gorilla.patch(FileWriter)
    def add_summary(self, *args, **kwargs):
        original = gorilla.get_original_attribute(FileWriter, 'add_summary')
        result = original(self, *args, **kwargs)
        _flush_queue()
        return result

    settings = gorilla.Settings(allow_hit=True, store_hit=True)
    patches = [
        gorilla.Patch(EventFileWriter,
                      'add_event',
                      add_event,
                      settings=settings),
        gorilla.Patch(EventFileWriterV2,
                      'add_event',
                      add_event,
                      settings=settings),
        gorilla.Patch(tensorflow.estimator.Estimator,
                      'train',
                      train,
                      settings=settings),
        gorilla.Patch(tensorflow.keras.Model, 'fit', fit, settings=settings),
        gorilla.Patch(tensorflow.keras.Model,
                      'fit_generator',
                      fit_generator,
                      settings=settings),
        gorilla.Patch(tensorflow.estimator.Estimator,
                      'export_saved_model',
                      export_saved_model,
                      settings=settings),
        gorilla.Patch(tensorflow.estimator.Estimator,
                      'export_savedmodel',
                      export_savedmodel,
                      settings=settings),
        gorilla.Patch(FileWriter,
                      'add_summary',
                      add_summary,
                      settings=settings),
    ]

    for patch in patches:
        gorilla.apply(patch)
예제 #18
0
def autolog():
    """
    Patch ``keras.Model.fit`` and ``keras.Model.fit_generator`` so that each call is
    recorded as an MLS ``Run``.

    For every call the run receives an ``Implementation`` (model class, call
    parameters, Keras version), the call's input values, optimizer information
    gathered when training begins, and the parameters of an ``EarlyStopping``
    callback if one is passed. After training, the serialized run is passed to
    ``log_renku_mls``.
    """
    import keras

    def _as_scalar(value):
        """Return ``value`` if it is already a float, otherwise evaluate it with the Keras backend."""
        # isinstance also accepts float subclasses (e.g. numpy floats), which
        # need no backend evaluation.
        return value if isinstance(value, float) else keras.backend.eval(value)

    class __MLSKerasCallback(keras.callbacks.Callback):
        """Keras callback that owns the MLS ``Run`` for one training call."""

        def __init__(self):
            # Let ``keras.callbacks.Callback`` initialise its own attributes first.
            super().__init__()
            self.mls = Run(uuid1())

        def on_train_begin(self, logs=None):  # pylint: disable=unused-argument
            optimizer = self.model.optimizer
            mls_add_param(self.mls, "num_layers", len(self.model.layers))
            mls_add_param(self.mls, "optimizer_name", type(optimizer).__name__)
            if hasattr(optimizer, "lr"):
                mls_add_param(self.mls, "learning_rate", _as_scalar(optimizer.lr))
            if hasattr(optimizer, "epsilon"):
                mls_add_param(self.mls, "epsilon", _as_scalar(optimizer.epsilon))

        def on_epoch_end(self, epoch, logs=None):  # pylint: disable=unused-argument
            # Per-epoch metrics are not recorded in the MLS run yet.
            return

        def on_train_end(self, logs=None):  # pylint: disable=unused-argument
            return

        # As of Keras 2.4.0, Keras Callback implementations must define the following
        # methods indicating whether or not the callback overrides functions for
        # batch training/testing/inference
        def _implements_train_batch_hooks(self):
            return False

        def _implements_test_batch_hooks(self):
            return False

        def _implements_predict_batch_hooks(self):
            return False

    def _early_stop_check(callbacks):
        """Return the first ``keras.callbacks.EarlyStopping`` in ``callbacks`` (may be ``None``), or ``None``."""
        for callback in callbacks or []:
            if isinstance(callback, keras.callbacks.EarlyStopping):
                return callback
        return None

    def _log_early_stop_callback_params(callback, mls):
        """
        Add the configuration of the EarlyStopping ``callback`` to the MLS run ``mls``.

        No-op when ``callback`` is falsy. Best effort: any exception while reading
        the attributes or adding them is swallowed.
        """
        if not callback:
            return
        try:
            earlystopping_params = {
                "monitor": callback.monitor,
                "min_delta": callback.min_delta,
                "patience": callback.patience,
                "baseline": callback.baseline,
                "restore_best_weights": callback.restore_best_weights,
            }
            mls_add_params(mls, earlystopping_params)
        except Exception:  # pylint: disable=W0703
            return

    def _run_and_log_function(
        self, original, args, kwargs, unlogged_params, callback_arg_index
    ):
        """
        Call ``original(self, *args, **kwargs)`` with an MLS callback attached and
        log the resulting run.

        :param callback_arg_index: Index of the ``callbacks`` parameter within ``args``.
        :return: The history object returned by ``original``.
        """
        mls_callback = __MLSKerasCallback()
        model_class = "keras.Model"

        algo = Algorithm(_id="NeuralNetwork")
        params, input_values = fn_args_as_params(
            original, args, kwargs, mls_callback.mls._id, unlogged_params
        )
        mls_implementation = Implementation(
            model_class, params, algo, keras.__version__
        )

        # NOTE(review): ``params`` is passed to ``Implementation`` and then appended
        # to ``parameters`` again; confirm the constructor does not already store them.
        mls_implementation.parameters += params
        mls_callback.mls.executes = mls_implementation
        mls_callback.mls.input_values += input_values

        # ``callbacks`` may be passed positionally or by keyword. A new list is built
        # so the caller's list is not extended in place and ``None`` is treated as empty.
        if len(args) > callback_arg_index:
            args = list(args)
            user_callbacks = list(args[callback_arg_index] or [])
            args[callback_arg_index] = user_callbacks + [mls_callback]
            args = tuple(args)
        else:
            user_callbacks = list(kwargs.get("callbacks") or [])
            kwargs["callbacks"] = user_callbacks + [mls_callback]

        _log_early_stop_callback_params(
            _early_stop_check(user_callbacks), mls_callback.mls
        )

        history = original(self, *args, **kwargs)

        log_renku_mls(
            RunSchema().dumps(mls_callback.mls), str(self.__hash__()), force=True
        )

        return history

    @gorilla.patch(keras.Model)
    def fit(self, *args, **kwargs):
        original = gorilla.get_original_attribute(keras.Model, "fit")
        unlogged_params = ["self", "x", "y", "callbacks", "validation_data", "verbose"]
        return _run_and_log_function(self, original, args, kwargs, unlogged_params, 5)

    @gorilla.patch(keras.Model)
    def fit_generator(self, *args, **kwargs):
        original = gorilla.get_original_attribute(keras.Model, "fit_generator")
        unlogged_params = [
            "self",
            "generator",
            "callbacks",
            "validation_data",
            "verbose",
        ]
        return _run_and_log_function(self, original, args, kwargs, unlogged_params, 4)

    settings = gorilla.Settings(allow_hit=True, store_hit=True)
    gorilla.apply(gorilla.Patch(keras.Model, "fit", fit, settings=settings))
    gorilla.apply(
        gorilla.Patch(keras.Model, "fit_generator", fit_generator, settings=settings)
    )
예제 #19
0
def autolog():
    """
    Enables automatic logging from LightGBM to MLflow. Logs the following.

    - parameters specified in `lightgbm.train`_.
    - metrics on each iteration (if ``valid_sets`` specified).
    - metrics at the best iteration (if ``early_stopping_rounds`` specified).
    - feature importance (both "split" and "gain").
    - trained model.

    Note that the `scikit-learn API`_ is not supported.
    """
    import lightgbm

    @gorilla.patch(lightgbm)
    def train(*args, **kwargs):

        def record_eval_results(eval_results):
            """
            Create a callback function that appends one ``{"<data>-<metric>": value}``
            dict per iteration to ``eval_results``.
            """
            def callback(env):
                eval_results.append({
                    data_name + '-' + eval_name: value
                    for data_name, eval_name, value, _ in env.evaluation_result_list
                })
            return callback

        auto_end_run = not mlflow.active_run()
        if auto_end_run:
            try_mlflow_log(mlflow.start_run)

        original = gorilla.get_original_attribute(lightgbm, 'train')

        # logging booster params separately via mlflow.log_params to extract key/value pairs
        # and make it easier to compare them across runs.
        params = args[0] if len(args) > 0 else kwargs['params']
        try_mlflow_log(mlflow.log_params, params)

        unlogged_params = ['params', 'train_set', 'valid_sets', 'valid_names', 'fobj', 'feval',
                           'init_model', 'evals_result', 'learning_rates', 'callbacks']

        log_fn_args_as_params(original, args, kwargs, unlogged_params)

        # ``inspect.getargspec`` was removed in Python 3.11; ``getfullargspec().args``
        # lists the same positional parameter names.
        all_arg_names = inspect.getfullargspec(original).args
        num_pos_args = len(args)

        # adding a callback that records evaluation results. A new list is built
        # (rather than extending the caller's in place) so a reused callbacks list
        # does not accumulate recorders, and ``None`` is treated as empty.
        eval_results = []
        callbacks_index = all_arg_names.index('callbacks')
        callback = record_eval_results(eval_results)
        if num_pos_args > callbacks_index:
            args = list(args)
            args[callbacks_index] = list(args[callbacks_index] or []) + [callback]
            args = tuple(args)
        else:
            kwargs['callbacks'] = list(kwargs.get('callbacks') or []) + [callback]

        # training model
        model = original(*args, **kwargs)

        # logging metrics on each iteration.
        for idx, metrics in enumerate(eval_results):
            try_mlflow_log(mlflow.log_metrics, metrics, step=idx)

        # If early_stopping_rounds is given (and not None), log metrics at the best
        # iteration as extra metrics with the max step + 1.
        if 'early_stopping_rounds' in all_arg_names:
            es_index = all_arg_names.index('early_stopping_rounds')
            early_stopping_rounds = (args[es_index] if num_pos_args > es_index
                                     else kwargs.get('early_stopping_rounds'))
        else:
            early_stopping_rounds = None
        if early_stopping_rounds is not None:
            extra_step = len(eval_results)
            try_mlflow_log(mlflow.log_metric, 'stopped_iteration', extra_step)
            # best_iteration is set even if training does not stop early.
            try_mlflow_log(mlflow.log_metric, 'best_iteration', model.best_iteration)
            # iteration starts from 1 in LightGBM, so ``best_iteration - 1`` indexes
            # ``eval_results``. Skip when out of range (e.g. no ``valid_sets``) instead
            # of raising IndexError or wrapping around to the last entry.
            if 1 <= model.best_iteration <= len(eval_results):
                try_mlflow_log(mlflow.log_metrics, eval_results[model.best_iteration - 1],
                               step=extra_step)

        # logging feature importance as artifacts.
        features = model.feature_name()
        for imp_type in ['split', 'gain']:
            importance = model.feature_importance(importance_type=imp_type)
            imp = dict(zip(features, importance.tolist()))
            tmpdir = tempfile.mkdtemp()
            try:
                filepath = os.path.join(tmpdir, 'feature_importance_{}.json'.format(imp_type))
                with open(filepath, 'w') as f:
                    json.dump(imp, f)
                try_mlflow_log(mlflow.log_artifact, filepath)
            finally:
                shutil.rmtree(tmpdir)

        try_mlflow_log(log_model, model, artifact_path='model')

        if auto_end_run:
            try_mlflow_log(mlflow.end_run)
        return model

    settings = gorilla.Settings(allow_hit=True, store_hit=True)
    gorilla.apply(gorilla.Patch(lightgbm, 'train', train, settings=settings))
예제 #20
0
def main(argv):
    """Train, predict with, and/or evaluate a model as selected by ``FLAGS``.

    * ``--do_train``: train ``--init_checkpoint`` on ``--training_tasks`` and
      validate on ``--validation_tasks``. When ``--checkpoint_dir`` is set, the
      final checkpoint is saved there and a ``<checkpoint_dir>_best`` directory
      is created for the best-performing checkpoint.
    * ``--do_predict``: predict ``--testing_tasks`` and save the predictions to
      ``--prediction_dir``.
    * ``--do_test``: evaluate predictions for ``--testing_tasks``; they are
      loaded from ``--prediction_dir`` when given, otherwise generated from
      scratch with the loaded/trained model.

    Args:
        argv: Remaining command-line arguments; unused.
    """
    del argv  # Unused.

    logging.set_verbosity(logging.DEBUG)

    # Disable TQDM threading (solves a weird C runtime error)
    tqdm.tqdm.monitor_interval = 0

    # A model is only needed when training, predicting, or testing without
    # pre-computed predictions from --prediction_dir.
    if FLAGS.do_train or FLAGS.do_predict or (FLAGS.do_test
                                              and not FLAGS.prediction_dir):
        experiment: experiments.Experiment
        if FLAGS.implementation == 'tensorflow':
            # configure_tf(FLAGS.use_xla, FLAGS.use_amp)
            experiment = experiments.TFExperiment(
                cache_dir=FLAGS.cache_dir,
                configuration_name=FLAGS.init_checkpoint,
                max_seq_len=FLAGS.max_seq_len,
                use_xla=FLAGS.use_xla,
                use_amp=FLAGS.use_amp,
                seed=FLAGS.seed)
        elif FLAGS.implementation == 'pytorch':
            experiment = experiments.PTExperiment(
                cache_dir=FLAGS.cache_dir,
                configuration_name=FLAGS.init_checkpoint,
                max_seq_len=FLAGS.max_seq_len,
                use_amp=FLAGS.use_amp,
                warmup_epochs=FLAGS.warmup_epochs,
                seed=FLAGS.seed,
                temperature=FLAGS.temperature,
                dynamic_mixing=FLAGS.dynamic_mixing,
                mix_from_validation=FLAGS.mix_from_validation,
                clip_mixing_size=FLAGS.clip_mixing_size)
        else:
            raise NotImplementedError('Unsupported implementation \"%s\"' %
                                      FLAGS.implementation)

        # Load model
        model = experiment.load_model(model_name=FLAGS.init_checkpoint)

    patch_settings = gorilla.Settings(allow_hit=True)

    def _patched_gcs_dataset_info_files(dataset_dir):
        """Call the original ``gcs_utils.gcs_dataset_info_files``.

        An ``IOError`` raised by the original is logged and ``None`` is
        returned instead of propagating the error.
        """
        original = gorilla.get_original_attribute(gcs_utils,
                                                  'gcs_dataset_info_files')
        try:
            return original(dataset_dir)
        except IOError as ioe:
            logging.error('Failed to connect to GCS', exc_info=ioe)
            return None

    patch = gorilla.Patch(gcs_utils,
                          'gcs_dataset_info_files',
                          _patched_gcs_dataset_info_files,
                          settings=patch_settings)
    gorilla.apply(patch)

    # Setup tfds parameters
    Task.data_dir = FLAGS.data_dir
    Task.add_checksum_dir(FLAGS.checksum_dir)

    # Register all our defined task mappings
    tasks.register_task_mappings()

    if FLAGS.do_train:
        # Parse dataset and split
        training_tasks = Task.parse_train_tasks(FLAGS.training_tasks)
        validation_tasks = Task.parse_validation_tasks(FLAGS.validation_tasks)

        if FLAGS.dynamic_mixing and FLAGS.mix_from_validation:
            train_sets: Dict[str, Task] = {t.dataset: t for t in training_tasks}
            valid_sets: Dict[str, Task] = {t.dataset: t for t in validation_tasks}
            if train_sets.keys() != valid_sets.keys():
                logging.error(
                    'Dynamic mixing from validation requires validation data for each training task!'
                )
            # Give every training dataset a validation counterpart: use its
            # 'validation' split when it exists, otherwise hold out the last
            # 30% of its training split.
            for dataset in train_sets.keys() - valid_sets.keys():
                if Task.split_in_dataset("validation", dataset):
                    valid_sets[dataset] = Task(dataset, 'validation')
                    logging.warning('Adding %s to validation tasks', dataset)
                else:
                    train_sets[dataset] = Task(dataset, 'train[:70%]')
                    valid_sets[dataset] = Task(dataset, 'train[-30%:]')
                    logging.warning(
                        'Adjusting %s to use 70%% for training and 30%% for validation',
                        dataset)
            # Keep training/validation lists aligned per dataset; datasets
            # that only have validation data are appended at the end.
            training_tasks = list(train_sets.values())
            validation_tasks = [valid_sets[dataset] for dataset in train_sets]
            validation_tasks.extend(
                valid_sets[dataset]
                for dataset in valid_sets.keys() - train_sets.keys())

        if FLAGS.checkpoint_dir:
            # Make directories to save best checkpoint and final checkpoint
            os.makedirs(FLAGS.checkpoint_dir, exist_ok=True)
            FLAGS.append_flags_into_file(
                os.path.join(FLAGS.checkpoint_dir, 'flags.cfg'))
            best_dir = "{0}_best".format(FLAGS.checkpoint_dir)
            os.makedirs(best_dir, exist_ok=True)
            FLAGS.append_flags_into_file(os.path.join(best_dir, 'flags.cfg'))

        # Train model
        logging.info('Training %s with %s...', FLAGS.init_checkpoint,
                     ' '.join(FLAGS.training_tasks))
        experiment.train(model,
                         training_tasks=training_tasks,
                         validation_tasks=validation_tasks,
                         num_epochs=FLAGS.num_epochs,
                         steps_per_epoch=FLAGS.steps_per_epoch,
                         prefetch_size=FLAGS.prefetch_size,
                         batch_size=FLAGS.batch_size,
                         eval_batch_size=FLAGS.eval_batch_size,
                         eval_batches=FLAGS.eval_batches,
                         checkpoint_file=FLAGS.checkpoint_dir)

        if FLAGS.checkpoint_dir:
            # Save final checkpoint
            experiment.save_model(model, FLAGS.checkpoint_dir)

        # Only look for a best checkpoint when a checkpoint directory was
        # configured; otherwise keep the model as it is after training.
        if FLAGS.checkpoint_dir and (FLAGS.do_predict or
                                     (FLAGS.do_test and not FLAGS.prediction_dir)):
            model_dir = "{0}_best".format(FLAGS.checkpoint_dir)
            if os.path.isdir(model_dir):
                logging.info("Loading best performing checkpoint: %s",
                             model_dir)
                model = experiment.load_model(model_name=model_dir)

    if FLAGS.do_predict:
        # Evaluate the model
        testing_tasks = Task.parse_test_tasks(FLAGS.testing_tasks)
        logging.info('Predicting %s with %s...', ' '.join(FLAGS.testing_tasks),
                     FLAGS.init_checkpoint)

        predictions = experiment.predict(model,
                                         tasks=testing_tasks,
                                         eval_batch_size=FLAGS.eval_batch_size)
        save_predictions(predictions, FLAGS.prediction_dir)

    if FLAGS.do_test:
        testing_tasks = Task.parse_test_tasks(FLAGS.testing_tasks)
        if FLAGS.prediction_dir:
            predictions = load_predictions(FLAGS.prediction_dir, testing_tasks)
        else:
            logging.warning(
                '--prediction_dir was not specified, generating predictions from scratch'
            )
            predictions = experiment.predict(
                model,
                tasks=testing_tasks,
                eval_batch_size=FLAGS.eval_batch_size)

        evaluator = evaluation.get_evaluator(FLAGS.evaluation)
        results = evaluator.evaluate(predictions, FLAGS.test_limit)
        print('Results:')
        print(results)

    if not any([FLAGS.do_train, FLAGS.do_predict, FLAGS.do_test]):
        logging.error(
            'Please specify at least one of --do_train, --do_predict, or --do_test'
        )
예제 #21
0
def autolog(importance_types=("weight",)):
    """
    Enables automatic logging from XGBoost to MLflow. Logs the following.

    - parameters specified in `xgboost.train`_.
    - metrics on each iteration (if ``evals`` specified).
    - metrics at the best iteration (if ``early_stopping_rounds`` specified).
    - feature importance as JSON files and plots.
    - trained model.

    Note that the `scikit-learn API`_ is not supported.

    :param importance_types: importance types to log (an iterable of strings, each
                             passed to ``Booster.get_score(importance_type=...)``).

    """
    import xgboost
    import numpy as np

    @gorilla.patch(xgboost)
    def train(*args, **kwargs):
        def record_eval_results(eval_results):
            """
            Create a callback function that records evaluation results.
            """
            def callback(env):
                eval_results.append(dict(env.evaluation_result_list))

            return callback

        def log_feature_importance_plot(features, importance, importance_type):
            """
            Log a horizontal bar plot of feature importance as a PNG artifact.
            """
            import matplotlib.pyplot as plt

            features = np.array(features)
            importance = np.array(importance)
            indices = np.argsort(importance)
            features = features[indices]
            importance = importance[indices]
            num_features = len(features)

            # If num_features > 10, increase the figure height to prevent the plot
            # from being too dense.
            w, h = [6.4, 4.8]  # matplotlib's default figure size
            h = h + 0.1 * num_features if num_features > 10 else h
            fig, ax = plt.subplots(figsize=(w, h))

            yloc = np.arange(num_features)
            ax.barh(yloc, importance, align="center", height=0.5)
            ax.set_yticks(yloc)
            ax.set_yticklabels(features)
            ax.set_xlabel("Importance")
            ax.set_title("Feature Importance ({})".format(importance_type))
            fig.tight_layout()

            tmpdir = tempfile.mkdtemp()
            try:
                # Name the file from this function's own parameter rather than
                # the caller's loop variable so the name always matches the plot.
                filepath = os.path.join(
                    tmpdir, "feature_importance_{}.png".format(importance_type))
                fig.savefig(filepath)
                try_mlflow_log(mlflow.log_artifact, filepath)
            finally:
                plt.close(fig)
                shutil.rmtree(tmpdir)

        if not mlflow.active_run():
            try_mlflow_log(mlflow.start_run)
            auto_end_run = True
        else:
            auto_end_run = False

        original = gorilla.get_original_attribute(xgboost, "train")

        # logging booster params separately via mlflow.log_params to extract key/value pairs
        # and make it easier to compare them across runs.
        params = args[0] if len(args) > 0 else kwargs["params"]
        try_mlflow_log(mlflow.log_params, params)

        unlogged_params = [
            "params",
            "dtrain",
            "evals",
            "obj",
            "feval",
            "evals_result",
            "xgb_model",
            "callbacks",
            "learning_rates",
        ]
        log_fn_args_as_params(original, args, kwargs, unlogged_params)

        # getfullargspec replaces the deprecated getargspec. Keyword-only
        # parameters are appended after the positional ones, so an index
        # check against len(args) correctly reports them as "not positional".
        argspec = inspect.getfullargspec(original)
        all_arg_names = argspec.args + argspec.kwonlyargs
        num_pos_args = len(args)

        # adding a callback that records evaluation results.
        eval_results = []
        callback = record_eval_results(eval_results)
        callbacks_index = all_arg_names.index("callbacks")
        # Always build a new list: the caller's callbacks list must not be
        # mutated (it would accumulate callbacks across calls), and an
        # explicit None must be accepted.
        if num_pos_args > callbacks_index:
            tmp_list = list(args)
            tmp_list[callbacks_index] = list(tmp_list[callbacks_index] or []) + [callback]
            args = tuple(tmp_list)
        else:
            kwargs["callbacks"] = list(kwargs.get("callbacks") or []) + [callback]

        # training model
        try:
            model = original(*args, **kwargs)
        except Exception:
            # Don't leave the run started above dangling when training fails.
            if auto_end_run:
                try_mlflow_log(mlflow.end_run, "FAILED")
            raise

        # logging metrics on each iteration.
        for idx, metrics in enumerate(eval_results):
            try_mlflow_log(mlflow.log_metrics, metrics, step=idx)

        # If early_stopping_rounds is given (and not None), logging metrics at the
        # best iteration as extra metrics with the max step + 1.
        early_stopping_index = all_arg_names.index("early_stopping_rounds")
        if num_pos_args > early_stopping_index:
            early_stopping_rounds = args[early_stopping_index]
        else:
            early_stopping_rounds = kwargs.get("early_stopping_rounds")
        if early_stopping_rounds is not None:
            extra_step = len(eval_results)
            try_mlflow_log(mlflow.log_metric, "stopped_iteration",
                           len(eval_results) - 1)
            try_mlflow_log(mlflow.log_metric, "best_iteration",
                           model.best_iteration)
            try_mlflow_log(mlflow.log_metrics,
                           eval_results[model.best_iteration],
                           step=extra_step)

        # logging feature importance as artifacts.
        for imp_type in importance_types:
            imp = model.get_score(importance_type=imp_type)
            # `imp` may be empty (zero-importance features are omitted); there is
            # nothing to plot then, and unpacking zip(*{}) would raise.
            if imp:
                features, importance = zip(*imp.items())
                try:
                    log_feature_importance_plot(features, importance, imp_type)
                except Exception:  # pylint: disable=broad-except
                    _logger.exception(
                        "Failed to log feature importance plot. XGBoost autologging "
                        "will ignore the failure and continue. Exception: ")

            tmpdir = tempfile.mkdtemp()
            try:
                filepath = os.path.join(
                    tmpdir, "feature_importance_{}.json".format(imp_type))
                with open(filepath, "w") as f:
                    json.dump(imp, f)
                try_mlflow_log(mlflow.log_artifact, filepath)
            finally:
                shutil.rmtree(tmpdir)

        try_mlflow_log(log_model, model, artifact_path="model")

        if auto_end_run:
            try_mlflow_log(mlflow.end_run)
        return model

    settings = gorilla.Settings(allow_hit=True, store_hit=True)
    gorilla.apply(gorilla.Patch(xgboost, "train", train, settings=settings))
예제 #22
0
def autolog():
    """
    Enable automatic logging from Gluon to MLflow.
    Logs train and validation metrics at the end of every epoch, and the number
    of layers, epoch/batch limits and optimizer settings as parameters when
    training begins. At the end of training, a ``HybridSequential`` network is
    logged as a model under the ``model`` artifact path.
    """
    class __MLflowGluonCallback(EpochEnd, TrainEnd, TrainBegin):
        """Gluon event handler that forwards training information to MLflow."""

        def __init__(self):
            self.current_epoch = 0

        def epoch_end(self, estimator, *args, **kwargs):
            # Train and validation metrics share one dict; the epoch counter is
            # used as the MLflow step.
            logs = {}
            for metric in estimator.train_metrics:
                metric_name, metric_val = metric.get()
                logs[metric_name] = metric_val
            for metric in estimator.val_metrics:
                metric_name, metric_val = metric.get()
                logs[metric_name] = metric_val
            try_mlflow_log(mlflow.log_metrics, logs, step=self.current_epoch)
            self.current_epoch += 1

        def train_begin(self, estimator, *args, **kwargs):
            try_mlflow_log(mlflow.log_param, "num_layers", len(estimator.net))
            if estimator.max_epoch is not None:
                try_mlflow_log(mlflow.log_param, "epochs", estimator.max_epoch)
            if estimator.max_batch is not None:
                try_mlflow_log(mlflow.log_param, "batches",
                               estimator.max_batch)
            try_mlflow_log(mlflow.log_param, "optimizer_name",
                           type(estimator.trainer.optimizer).__name__)
            if hasattr(estimator.trainer.optimizer, "lr"):
                try_mlflow_log(mlflow.log_param, "learning_rate",
                               estimator.trainer.optimizer.lr)
            if hasattr(estimator.trainer.optimizer, "epsilon"):
                try_mlflow_log(mlflow.log_param, "epsilon",
                               estimator.trainer.optimizer.epsilon)

        def train_end(self, estimator, *args, **kwargs):
            if isinstance(estimator.net, HybridSequential):
                try_mlflow_log(log_model, estimator.net, artifact_path="model")

    def _with_mlflow_handler(event_handlers):
        """Return a NEW handler list with a fresh MLflow callback appended.

        The caller's list is never mutated (mutating it would accumulate MLflow
        callbacks across repeated ``fit`` calls); ``None`` and a single handler
        are accepted as well as a list/tuple of handlers.
        """
        if event_handlers is None:
            handlers = []
        elif isinstance(event_handlers, (list, tuple)):
            handlers = list(event_handlers)
        else:
            handlers = [event_handlers]
        return handlers + [__MLflowGluonCallback()]

    @gorilla.patch(Estimator)
    def fit(self, *args, **kwargs):
        # If no run is active, the logging calls made by the callback operate on
        # a run this wrapper did not open explicitly; end it once fit finishes.
        auto_end_run = not mlflow.active_run()

        original = gorilla.get_original_attribute(Estimator, "fit")
        # Positional index 3 (after train_data, val_data, epochs) is event_handlers.
        if len(args) >= 4:
            arg_list = list(args)
            arg_list[3] = _with_mlflow_handler(arg_list[3])
            args = tuple(arg_list)
        else:
            kwargs["event_handlers"] = _with_mlflow_handler(
                kwargs.get("event_handlers"))
        try:
            result = original(self, *args, **kwargs)
        except Exception:
            # Don't leave the run dangling when training fails; re-raise the
            # original error.
            if auto_end_run:
                try_mlflow_log(mlflow.end_run, "FAILED")
            raise
        if auto_end_run:
            mlflow.end_run()
        return result

    settings = gorilla.Settings(allow_hit=True, store_hit=True)
    gorilla.apply(gorilla.Patch(Estimator, "fit", fit, settings=settings))
예제 #23
0
def autolog():
    """
    Enables autologging for scikit-learn estimators.

    **When is autologging performed?**
      Autologging is performed when you call:

      - ``estimator.fit()``
      - ``estimator.fit_predict()``
      - ``estimator.fit_transform()``

    **Logged information**
      **Parameters**
        - Parameters obtained by ``estimator.get_params(deep=True)``. Note that ``get_params``
          is called with ``deep=True``. This means when you fit a meta estimator that chains
          a series of estimators, the parameters of these child estimators are also logged.

      **Metrics**
        - A training score obtained by ``estimator.score``. Note that the training score is
          computed using parameters given to ``fit()``.
        - Common metrics for classifier:

          - `precision score`_

          .. _precision score:
              https://scikit-learn.org/stable/modules/generated/sklearn.metrics.precision_score.html

          - `recall score`_

          .. _recall score:
              https://scikit-learn.org/stable/modules/generated/sklearn.metrics.recall_score.html

          - `f1 score`_

          .. _f1 score:
              https://scikit-learn.org/stable/modules/generated/sklearn.metrics.f1_score.html

          - `accuracy score`_

          .. _accuracy score:
              https://scikit-learn.org/stable/modules/generated/sklearn.metrics.accuracy_score.html

          If the classifier has method ``predict_proba``, we additionally log:

          - `log loss`_

          .. _log loss:
              https://scikit-learn.org/stable/modules/generated/sklearn.metrics.log_loss.html

          - `roc auc score`_

          .. _roc auc score:
              https://scikit-learn.org/stable/modules/generated/sklearn.metrics.roc_auc_score.html

        - Common metrics for regressor:

          - `mean squared error`_

          .. _mean squared error:
              https://scikit-learn.org/stable/modules/generated/sklearn.metrics.mean_squared_error.html

          - root mean squared error

          - `mean absolute error`_

          .. _mean absolute error:
              https://scikit-learn.org/stable/modules/generated/sklearn.metrics.mean_absolute_error.html

          - `r2 score`_

          .. _r2 score:
              https://scikit-learn.org/stable/modules/generated/sklearn.metrics.r2_score.html

      **Tags**
        - An estimator class name (e.g. "LinearRegression").
        - A fully qualified estimator class name
          (e.g. "sklearn.linear_model._base.LinearRegression").

      **Artifacts**
        - A fitted estimator (logged by :py:func:`mlflow.sklearn.log_model()`).

    **How does autologging work for meta estimators?**
      When a meta estimator (e.g. `Pipeline`_, `GridSearchCV`_) calls ``fit()``, it internally calls
      ``fit()`` on its child estimators. Autologging does NOT perform logging on these constituent
      ``fit()`` calls.

      **Parameter search**
          In addition to recording the information discussed above, autologging for parameter
          search meta estimators (`GridSearchCV`_ and `RandomizedSearchCV`_) records child runs
          with metrics for each set of explored parameters, as well as artifacts and parameters
          for the best model (if available).

    **Supported estimators**
      - All estimators obtained by `sklearn.utils.all_estimators`_ (including meta estimators).
      - `Pipeline`_
      - Parameter search estimators (`GridSearchCV`_ and `RandomizedSearchCV`_)

    .. _sklearn.utils.all_estimators:
        https://scikit-learn.org/stable/modules/generated/sklearn.utils.all_estimators.html

    .. _Pipeline:
        https://scikit-learn.org/stable/modules/generated/sklearn.pipeline.Pipeline.html

    .. _GridSearchCV:
        https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.GridSearchCV.html

    .. _RandomizedSearchCV:
        https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.RandomizedSearchCV.html

    **Example**

    `See more examples <https://github.com/mlflow/mlflow/blob/master/examples/sklearn_autolog>`_

    .. code-block:: python

        from pprint import pprint
        import numpy as np
        from sklearn.linear_model import LinearRegression
        import mlflow

        def fetch_logged_data(run_id):
            client = mlflow.tracking.MlflowClient()
            data = client.get_run(run_id).data
            tags = {k: v for k, v in data.tags.items() if not k.startswith("mlflow.")}
            artifacts = [f.path for f in client.list_artifacts(run_id, "model")]
            return data.params, data.metrics, tags, artifacts

        # enable autologging
        mlflow.sklearn.autolog()

        # prepare training data
        X = np.array([[1, 1], [1, 2], [2, 2], [2, 3]])
        y = np.dot(X, np.array([1, 2])) + 3

        # train a model
        model = LinearRegression()
        with mlflow.start_run() as run:
            model.fit(X, y)

        # fetch logged data
        params, metrics, tags, artifacts = fetch_logged_data(run.info.run_id)

        pprint(params)
        # {'copy_X': 'True',
        #  'fit_intercept': 'True',
        #  'n_jobs': 'None',
        #  'normalize': 'False'}

        pprint(metrics)
        # {'training_score': 1.0,
        #  'training_mae': 2.220446049250313e-16,
        #  'training_mse': 1.9721522630525295e-31,
        #  'training_r2_score': 1.0,
        #  'training_rmse': 4.440892098500626e-16}

        pprint(tags)
        # {'estimator_class': 'sklearn.linear_model._base.LinearRegression',
        #  'estimator_name': 'LinearRegression'}

        pprint(artifacts)
        # ['model/MLmodel', 'model/conda.yaml', 'model/model.pkl']
    """
    import pandas as pd
    import sklearn

    from mlflow.models import infer_signature
    from mlflow.sklearn.utils import (
        _MIN_SKLEARN_VERSION,
        _is_supported_version,
        _chunk_dict,
        _get_args_for_score,
        _log_specialized_estimator_content,
        _get_Xy,
        _all_estimators,
        _truncate_dict,
        _get_arg_names,
        _get_estimator_info_tags,
        _get_meta_estimators_for_autologging,
        _is_parameter_search_estimator,
        _log_parameter_search_results_as_artifact,
        _create_child_runs_for_parameter_search,
    )
    from mlflow.tracking.context import registry as context_registry
    from mlflow.utils.validation import (
        MAX_PARAMS_TAGS_PER_BATCH,
        MAX_PARAM_VAL_LENGTH,
        MAX_ENTITY_KEY_LENGTH,
    )

    if not _is_supported_version():
        warnings.warn(
            "Autologging utilities may not work properly on scikit-learn < {} "
            .format(_MIN_SKLEARN_VERSION) +
            "(current version: {})".format(sklearn.__version__),
            stacklevel=2,
        )

    def fit_mlflow(self, func_name, *args, **kwargs):
        """
        Run the original ``func_name`` training routine of ``self`` with MLflow logging.

        Starts a run if none is active (and ends it afterwards), logs pre-training
        metadata, trains, then logs post-training metadata. If training raises, a
        run started here is ended with FAILED status and the error is re-raised.
        """
        should_start_run = mlflow.active_run() is None
        if should_start_run:
            try_mlflow_log(mlflow.start_run)

        _log_pretraining_metadata(self, *args, **kwargs)

        original_fit = gorilla.get_original_attribute(self, func_name)
        try:
            fit_output = original_fit(*args, **kwargs)
        except Exception:
            if should_start_run:
                try_mlflow_log(mlflow.end_run,
                               RunStatus.to_string(RunStatus.FAILED))
            # Bare raise preserves the original traceback unchanged.
            raise

        _log_posttraining_metadata(self, *args, **kwargs)

        if should_start_run:
            try_mlflow_log(mlflow.end_run)

        return fit_output

    def _log_pretraining_metadata(estimator, *args, **kwargs):  # pylint: disable=unused-argument
        """
        Records metadata (e.g., params and tags) for a scikit-learn estimator prior to training.
        This is intended to be invoked within a patched scikit-learn training routine
        (e.g., `fit()`, `fit_transform()`, ...) and assumes the existence of an active
        MLflow run that can be referenced via the fluent Tracking API.

        :param estimator: The scikit-learn estimator for which to log metadata.
        :param args: The arguments passed to the scikit-learn training routine (e.g.,
                     `fit()`, `fit_transform()`, ...).
        :param kwargs: The keyword arguments passed to the scikit-learn training routine.
        """
        # Deep parameter logging includes parameters from children of a given
        # estimator. For some meta estimators (e.g., pipelines), recording
        # these parameters is desirable. For parameter search estimators,
        # however, child estimators act as seeds for the parameter search
        # process; accordingly, we avoid logging initial, untuned parameters
        # for these seed estimators.
        should_log_params_deeply = not _is_parameter_search_estimator(
            estimator)
        # Chunk model parameters to avoid hitting the log_batch API limit
        for chunk in _chunk_dict(
                estimator.get_params(deep=should_log_params_deeply),
                chunk_size=MAX_PARAMS_TAGS_PER_BATCH,
        ):
            truncated = _truncate_dict(chunk, MAX_ENTITY_KEY_LENGTH,
                                       MAX_PARAM_VAL_LENGTH)
            try_mlflow_log(mlflow.log_params, truncated)

        try_mlflow_log(mlflow.set_tags, _get_estimator_info_tags(estimator))

    def _log_posttraining_metadata(estimator, *args, **kwargs):
        """
        Records metadata for a scikit-learn estimator after training has completed.
        This is intended to be invoked within a patched scikit-learn training routine
        (e.g., `fit()`, `fit_transform()`, ...) and assumes the existence of an active
        MLflow run that can be referenced via the fluent Tracking API.

        :param estimator: The scikit-learn estimator for which to log metadata.
        :param args: The arguments passed to the scikit-learn training routine (e.g.,
                     `fit()`, `fit_transform()`, ...).
        :param kwargs: The keyword arguments passed to the scikit-learn training routine.
        """
        if hasattr(estimator, "score"):
            try:
                score_args = _get_args_for_score(estimator.score,
                                                 estimator.fit, args, kwargs)
                training_score = estimator.score(*score_args)
            except Exception as e:  # pylint: disable=broad-except
                msg = (
                    estimator.score.__qualname__ +
                    " failed. The 'training_score' metric will not be recorded. Scoring error: "
                    + str(e))
                _logger.warning(msg)
            else:
                try_mlflow_log(mlflow.log_metric, "training_score",
                               training_score)

        # log common metrics and artifacts for estimators (classifier, regressor)
        _log_specialized_estimator_content(estimator,
                                           mlflow.active_run().info.run_id,
                                           args, kwargs)

        input_example = None
        signature = None
        if hasattr(estimator, "predict"):
            try:
                # Fetch an input example using the first several rows of the array-like
                # training data supplied to the training routine (e.g., `fit()`)
                SAMPLE_ROWS = 5
                fit_arg_names = _get_arg_names(estimator.fit)
                X_var_name, y_var_name = fit_arg_names[:2]
                input_example = _get_Xy(args, kwargs, X_var_name,
                                        y_var_name)[0][:SAMPLE_ROWS]

                model_output = estimator.predict(input_example)
                signature = infer_signature(input_example, model_output)
            except Exception as e:  # pylint: disable=broad-except
                input_example = None
                msg = "Failed to infer an input example and model signature: " + str(
                    e)
                _logger.warning(msg)

        try_mlflow_log(
            log_model,
            estimator,
            artifact_path="model",
            signature=signature,
            input_example=input_example,
        )

        if _is_parameter_search_estimator(estimator):
            if hasattr(estimator, "best_estimator_"):
                try_mlflow_log(
                    log_model,
                    estimator.best_estimator_,
                    artifact_path="best_estimator",
                    signature=signature,
                    input_example=input_example,
                )

            if hasattr(estimator, "best_params_"):
                best_params = {
                    "best_{param_name}".format(param_name=param_name):
                    param_value
                    for param_name, param_value in
                    estimator.best_params_.items()
                }
                try_mlflow_log(mlflow.log_params, best_params)

            if hasattr(estimator, "cv_results_"):
                try:
                    # Fetch environment-specific tags (e.g., user and source) to ensure that lineage
                    # information is consistent with the parent run
                    environment_tags = context_registry.resolve_tags()
                    _create_child_runs_for_parameter_search(
                        cv_estimator=estimator,
                        parent_run=mlflow.active_run(),
                        child_tags=environment_tags,
                    )
                except Exception as e:  # pylint: disable=broad-except
                    msg = (
                        "Encountered exception during creation of child runs for parameter search."
                        " Child runs may be missing. Exception: {}".format(
                            str(e)))
                    _logger.warning(msg)

                try:
                    cv_results_df = pd.DataFrame.from_dict(
                        estimator.cv_results_)
                    _log_parameter_search_results_as_artifact(
                        cv_results_df,
                        mlflow.active_run().info.run_id)
                except Exception as e:  # pylint: disable=broad-except
                    msg = (
                        "Failed to log parameter search results as an artifact."
                        " Exception: {}".format(str(e)))
                    _logger.warning(msg)

    def patched_fit(self, func_name, *args, **kwargs):
        """
        To be applied to a sklearn model class that defines a `fit` method and
        inherits from `BaseEstimator` (thereby defining the `get_params()` method).

        Logging only happens for the outermost training call
        (``allow_children=False``); nested calls made by meta estimators go
        straight to the original, unpatched routine.
        """
        with _SklearnTrainingSession(clazz=self.__class__,
                                     allow_children=False) as t:
            if t.should_log():
                return fit_mlflow(self, func_name, *args, **kwargs)
            else:
                original_fit = gorilla.get_original_attribute(self, func_name)
                return original_fit(*args, **kwargs)

    def create_patch_func(func_name):
        """
        Build a replacement method for ``func_name`` that routes through ``patched_fit``.

        A factory is used so each patched method captures its own ``func_name``
        (a closure created directly in the loop below would late-bind it).
        """
        def f(self, *args, **kwargs):
            return patched_fit(self, func_name, *args, **kwargs)

        return f

    patch_settings = gorilla.Settings(allow_hit=True, store_hit=True)
    _, estimators_to_patch = zip(*_all_estimators())
    # Ensure that relevant meta estimators (e.g. GridSearchCV, Pipeline) are selected
    # for patching if they are not already included in the output of `all_estimators()`
    estimators_to_patch = set(estimators_to_patch).union(
        set(_get_meta_estimators_for_autologging()))

    for class_def in estimators_to_patch:
        for func_name in ["fit", "fit_transform", "fit_predict"]:
            if hasattr(class_def, func_name):
                original = getattr(class_def, func_name)

                # A couple of estimators use property methods to return fitting functions,
                # rather than defining the fitting functions on the estimator class directly.
                #
                # Example: https://github.com/scikit-learn/scikit-learn/blob/0.23.2/sklearn/neighbors/_lof.py#L183  # noqa
                #
                # We currently exclude these property fitting methods from patching because
                # it's challenging to patch them correctly.
                #
                # Excluded fitting methods:
                # - sklearn.cluster._agglomerative.FeatureAgglomeration.fit_predict
                # - sklearn.neighbors._lof.LocalOutlierFactor.fit_predict
                #
                # You can list property fitting methods by inserting "print(class_def, func_name)"
                # in the if clause below.
                if isinstance(original, property):
                    continue

                patch_func = create_patch_func(func_name)
                # TODO(harupy): Package this wrap & patch routine into a utility function so we can
                # reuse it in other autologging integrations.
                # preserve original function attributes
                patch_func = functools.wraps(original)(patch_func)
                patch = gorilla.Patch(class_def,
                                      func_name,
                                      patch_func,
                                      settings=patch_settings)
                gorilla.apply(patch)
예제 #24
0
def autolog():
    """
    Enables automatic logging from LightGBM to MLflow. Logs the following.

    - parameters specified in `lightgbm.train`_.
    - metrics on each iteration (if ``valid_sets`` specified).
    - metrics at the best iteration (if ``early_stopping_rounds`` specified).
    - feature importance (both "split" and "gain") as JSON files and plots.
    - trained model.

    Note that the `scikit-learn API`_ is not supported.
    """
    import lightgbm
    import numpy as np

    @gorilla.patch(lightgbm)
    def train(*args, **kwargs):
        # Replacement for ``lightgbm.train``: calls the original training function
        # and logs params, per-iteration metrics, feature importances and the model
        # to MLflow. If no run is active, one is started here and ended at the end.

        def record_eval_results(eval_results):
            """
            Create a callback function that records evaluation results.

            Each invocation appends one dict mapping ``"<data_name>-<eval_name>"``
            to the reported value onto ``eval_results``.
            """
            def callback(env):
                res = {}
                for data_name, eval_name, value, _ in env.evaluation_result_list:
                    key = data_name + '-' + eval_name
                    res[key] = value

                eval_results.append(res)
            return callback

        def log_feature_importance_plot(features, importance, importance_type):
            """
            Log a horizontal bar plot of feature importance as an MLflow artifact.

            :param features: Feature names, aligned element-wise with ``importance``.
            :param importance: Array of importance values (as returned by
                               ``Booster.feature_importance``).
            :param importance_type: Importance type (e.g. "split" or "gain"); used in
                                    the plot title and in the artifact file name.
            """
            import matplotlib.pyplot as plt

            indices = np.argsort(importance)
            features = np.array(features)[indices]
            importance = importance[indices]
            num_features = len(features)

            # If num_features > 10, increase the figure height to prevent the plot
            # from being too dense.
            w, h = [6.4, 4.8]  # matplotlib's default figure size
            h = h + 0.1 * num_features if num_features > 10 else h
            fig, ax = plt.subplots(figsize=(w, h))

            yloc = np.arange(num_features)
            ax.barh(yloc, importance, align='center', height=0.5)
            ax.set_yticks(yloc)
            ax.set_yticklabels(features)
            ax.set_xlabel('Importance')
            ax.set_title('Feature Importance ({})'.format(importance_type))
            fig.tight_layout()

            tmpdir = tempfile.mkdtemp()
            try:
                # Name the file from the ``importance_type`` argument rather than
                # the caller's loop variable captured through the closure.
                filepath = os.path.join(
                    tmpdir, 'feature_importance_{}.png'.format(importance_type))
                fig.savefig(filepath)
                try_mlflow_log(mlflow.log_artifact, filepath)
            finally:
                plt.close(fig)
                shutil.rmtree(tmpdir)

        if not mlflow.active_run():
            try_mlflow_log(mlflow.start_run)
            auto_end_run = True
        else:
            auto_end_run = False

        original = gorilla.get_original_attribute(lightgbm, 'train')

        # logging booster params separately via mlflow.log_params to extract key/value pairs
        # and make it easier to compare them across runs.
        params = args[0] if len(args) > 0 else kwargs['params']
        try_mlflow_log(mlflow.log_params, params)

        unlogged_params = ['params', 'train_set', 'valid_sets', 'valid_names', 'fobj', 'feval',
                           'init_model', 'evals_result', 'learning_rates', 'callbacks']

        log_fn_args_as_params(original, args, kwargs, unlogged_params)

        # ``inspect.getargspec`` was removed in Python 3.11; ``getfullargspec``
        # returns the same positional argument names in ``.args``.
        all_arg_names = inspect.getfullargspec(original).args
        num_pos_args = len(args)

        # adding a callback that records evaluation results.
        eval_results = []
        callbacks_index = all_arg_names.index('callbacks')
        callback = record_eval_results(eval_results)
        if num_pos_args >= callbacks_index + 1:
            tmp_list = list(args)
            # Build a new list instead of ``+=`` so the caller's list is not mutated
            # (which would accumulate recorder callbacks across calls); a positional
            # ``None`` is treated as "no callbacks".
            tmp_list[callbacks_index] = list(tmp_list[callbacks_index] or []) + [callback]
            args = tuple(tmp_list)
        elif kwargs.get('callbacks') is not None:
            kwargs['callbacks'] = list(kwargs['callbacks']) + [callback]
        else:
            kwargs['callbacks'] = [callback]

        # training model
        model = original(*args, **kwargs)

        # logging metrics on each iteration.
        for idx, metrics in enumerate(eval_results):
            try_mlflow_log(mlflow.log_metrics, metrics, step=idx)

        # If early_stopping_rounds is present, logging metrics at the best iteration
        # as extra metrics with the max step + 1.
        early_stopping_index = all_arg_names.index('early_stopping_rounds')
        early_stopping = (num_pos_args >= early_stopping_index + 1 or
                          'early_stopping_rounds' in kwargs)
        if early_stopping:
            extra_step = len(eval_results)
            try_mlflow_log(mlflow.log_metric, 'stopped_iteration', len(eval_results))
            # best_iteration is set even if training does not stop early.
            try_mlflow_log(mlflow.log_metric, 'best_iteration', model.best_iteration)
            # iteration starts from 1 in LightGBM.
            try_mlflow_log(mlflow.log_metrics, eval_results[model.best_iteration - 1],
                           step=extra_step)

        # logging feature importance as artifacts.
        for imp_type in ['split', 'gain']:
            features = model.feature_name()
            importance = model.feature_importance(importance_type=imp_type)
            try:
                log_feature_importance_plot(features, importance, imp_type)
            except Exception:  # pylint: disable=broad-except
                _logger.exception('Failed to log feature importance plot. LightGBM autologging '
                                  'will ignore the failure and continue. Exception: ')

            imp = {ft: imp for ft, imp in zip(features, importance.tolist())}
            tmpdir = tempfile.mkdtemp()
            try:
                filepath = os.path.join(tmpdir, 'feature_importance_{}.json'.format(imp_type))
                with open(filepath, 'w') as f:
                    json.dump(imp, f, indent=2)
                try_mlflow_log(mlflow.log_artifact, filepath)
            finally:
                shutil.rmtree(tmpdir)

        try_mlflow_log(log_model, model, artifact_path='model')

        if auto_end_run:
            try_mlflow_log(mlflow.end_run)
        return model

    settings = gorilla.Settings(allow_hit=True, store_hit=True)
    gorilla.apply(gorilla.Patch(lightgbm, 'train', train, settings=settings))
예제 #25
0
        if not trial.last_result:
            continue
        best_top1_acc = max(best_top1_acc, trial.last_result['top1_valid'])
    print('iter',
          self._iteration,
          'top1_acc=%.3f' % best_top1_acc,
          cnts,
          end='\r')
    return original(self)


# Route TrialRunner.step through the logging wrapper `step_w_log`.
# allow_hit=True lets the patch replace the already-existing `step` attribute.
patch = gorilla.Patch(
    ray.tune.trial_runner.TrialRunner, 'step', step_w_log,
    settings=gorilla.Settings(allow_hit=True),
)
gorilla.apply(patch)

logger = get_logger('Fast AutoAugment')


def _get_path(dataset, model, tag):
    return os.path.join(os.path.dirname(os.path.realpath(__file__)),
                        'FastAutoAugment',
                        'models/%s_%s_%s.pt' % (dataset, model, tag))  # TODO


# @ray.remote(num_gpus=4, max_calls=1) #TODO: change to num_gpus=1 ???
# @ray.remote
def train_model(config,
                dataroot,
                augment,
예제 #26
0
def autolog(every_n_iter=100):
    # pylint: disable=E0611
    """
    Enable automatic logging from TensorFlow to MLflow. If applicable,
    model checkpoints are logged as artifacts to a 'models' directory, along
    with any TensorBoard log data.

    Refer to the tracking documentation for
    information on what is logged with different TensorFlow workflows.

    :param every_n_iter: The frequency with which metrics should be logged.
                         Defaults to 100. Ex: a value of 100 will log metrics
                         at step 0, 100, 200, etc.

    """
    global _LOG_EVERY_N_STEPS
    _LOG_EVERY_N_STEPS = every_n_iter

    # Single literal so the version range is separated from "versions" by a space.
    unsupported_msg = ("Could not log to MLflow. Only TensorFlow versions "
                       "1.12 <= v <= 2.0.0 are supported.")

    if LooseVersion(tensorflow.__version__) < LooseVersion('1.12'):
        warnings.warn(unsupported_msg)
        return

    try:
        from tensorflow.python.summary.writer.event_file_writer import EventFileWriter
        from tensorflow.python.summary.writer.event_file_writer_v2 import EventFileWriterV2
        from tensorflow.python.saved_model import tag_constants
        from tensorflow.python.summary.writer.writer import FileWriter
    except ImportError:
        warnings.warn(unsupported_msg)
        return

    @gorilla.patch(tensorflow.estimator.Estimator)
    def export_saved_model(self, *args, **kwargs):
        # Export via the original method, then log the exported SavedModel dir.
        original = gorilla.get_original_attribute(
            tensorflow.estimator.Estimator, 'export_saved_model')
        serialized = original(self, *args, **kwargs)
        try_mlflow_log(log_model,
                       tf_saved_model_dir=serialized.decode('utf-8'),
                       tf_meta_graph_tags=[tag_constants.SERVING],
                       tf_signature_def_key='predict',
                       artifact_path='model')
        return serialized

    @gorilla.patch(tensorflow.estimator.Estimator)
    def export_savedmodel(self, *args, **kwargs):
        # Same as export_saved_model, for the older method name.
        original = gorilla.get_original_attribute(
            tensorflow.estimator.Estimator, 'export_savedmodel')
        serialized = original(self, *args, **kwargs)
        try_mlflow_log(log_model,
                       tf_saved_model_dir=serialized.decode('utf-8'),
                       tf_meta_graph_tags=[tag_constants.SERVING],
                       tf_signature_def_key='predict',
                       artifact_path='model')
        return serialized

    @gorilla.patch(tensorflow.keras.Model)
    def fit(self, *args, **kwargs):
        original = gorilla.get_original_attribute(tensorflow.keras.Model,
                                                  'fit')
        # Checking if the 'callback' argument of fit() is set
        # (``callbacks`` is the 6th positional parameter of ``Model.fit``).
        if len(args) >= 6:
            l = list(args)
            l[5], log_dir = _setup_callbacks(l[5])
            args = tuple(l)
        elif 'callbacks' in kwargs:
            kwargs['callbacks'], log_dir = _setup_callbacks(
                kwargs['callbacks'])
        else:
            kwargs['callbacks'], log_dir = _setup_callbacks([])
        result = original(self, *args, **kwargs)
        _flush_queue()
        _log_artifacts_with_warning(local_dir=log_dir,
                                    artifact_path='tensorboard_logs')
        shutil.rmtree(log_dir)
        return result

    @gorilla.patch(EventFileWriter)
    def add_event(self, event):
        _log_event(event)
        original = gorilla.get_original_attribute(EventFileWriter, 'add_event')
        return original(self, event)

    @gorilla.patch(EventFileWriterV2, name='add_event')
    def add_event_v2(self, event):
        _log_event(event)
        # Look up the original stored on EventFileWriterV2 itself: this wrapper
        # is installed on V2, so the v1 EventFileWriter.add_event is not the
        # implementation being replaced here.
        original = gorilla.get_original_attribute(EventFileWriterV2, 'add_event')
        return original(self, event)

    @gorilla.patch(FileWriter)
    def add_summary(self, *args, **kwargs):
        original = gorilla.get_original_attribute(FileWriter, 'add_summary')
        result = original(self, *args, **kwargs)
        _flush_queue()
        return result

    # allow_hit: overwrite existing attributes; store_hit: keep the originals
    # so the wrappers can retrieve them via gorilla.get_original_attribute.
    settings = gorilla.Settings(allow_hit=True, store_hit=True)
    patches = [
        gorilla.Patch(EventFileWriter,
                      'add_event',
                      add_event,
                      settings=settings),
        gorilla.Patch(EventFileWriterV2,
                      'add_event',
                      add_event_v2,
                      settings=settings),
        gorilla.Patch(tensorflow.keras.Model, 'fit', fit, settings=settings),
        gorilla.Patch(tensorflow.estimator.Estimator,
                      'export_saved_model',
                      export_saved_model,
                      settings=settings),
        gorilla.Patch(tensorflow.estimator.Estimator,
                      'export_savedmodel',
                      export_savedmodel,
                      settings=settings),
        gorilla.Patch(FileWriter,
                      'add_summary',
                      add_summary,
                      settings=settings),
    ]

    for x in patches:
        gorilla.apply(x)
예제 #27
0
from splicemachine.mlflow_support import *
from splicemachine.mlflow_support.mlflow_support import _GORILLA_SETTINGS
import gorilla
import mlflow.pyfunc


def _log_model(name='pyfunc_model', **flavor_options):
    """Log a pyfunc model through ``mlflow.log_model``.

    If ``python_model`` is among the keyword options it is removed and passed
    as the model object (``None`` otherwise); every remaining option is
    forwarded unchanged, together with ``model_lib='pyfunc'``.
    """
    python_model = flavor_options.pop('python_model', None)
    mlflow.log_model(python_model, name=name, model_lib='pyfunc', **flavor_options)


# Install `_log_model` on `mlflow.pyfunc` under its public name: stripping the
# leading underscore turns "_log_model" into "log_model".
gorilla.apply(gorilla.Patch(
    destination=mlflow.pyfunc,
    name=_log_model.__name__.lstrip('_'),
    obj=_log_model,
    settings=_GORILLA_SETTINGS,
))
예제 #28
0
    def test_apply_patch_no_hit(self):
        """Patch every source object onto every destination under a fresh name.

        Each attribute path of ``_frommodule`` (plus ``''``, i.e. the module
        itself) is applied as attribute ``dummy`` onto each distinct destination
        container derived from ``_tomodule``'s attribute paths, using default
        ``gorilla.Settings()``. ``dummy`` is presumably absent from every
        destination (hence "no hit") — TODO confirm against the fixtures.

        For each combination the test checks that:
        - the destination object itself is not replaced;
        - ``dummy`` now resolves to the source object;
        - the patched object behaves as expected for its kind of source.
        """
        name = 'dummy'
        settings = gorilla.Settings()

        source_paths = [''] + _list_attribute_paths(_frommodule)
        target_paths = _list_attribute_paths(_tomodule)

        # Number of assertion branches exercised; compared to a fixed total at
        # the end so that a silently skipped branch makes the test fail.
        branch_count = 0

        # Retrieve the destinations in two passes instead of directly using a
        # set in order to preserve the ordering.
        seen = set()
        destination_paths = [
            _split_attribute_path(path)[0] for path in target_paths
        ]
        # ``seen.add`` returns None, so the condition is true exactly when the
        # path is seen for the first time, and it records the path as a side
        # effect (order-preserving de-duplication).
        destination_paths = [
            path for path in destination_paths
            if path not in seen and seen.add(path) is None
        ]
        combinations = itertools.product(source_paths, destination_paths)
        for source_path, destination_path in combinations:
            # Fixture setup/teardown are run per combination, presumably so
            # each patch starts from unpatched modules.
            self.setUp()

            destination = _get_attribute_from_path(_tomodule, destination_path)
            obj = _get_attribute_from_path(_frommodule, source_path)
            patch = gorilla.Patch(destination, name, obj, settings=settings)
            gorilla.apply(patch)
            self.assertIs(
                destination,
                _get_attribute_from_path(_tomodule, destination_path))

            result = gorilla.get_attribute(destination, name)
            self.assertIs(result, obj)

            # Per-source behavior checks. Method-like sources are only invoked
            # when the destination path is a key of ``_CLS_REFERENCES``
            # (presumably a map from destination path to class name in
            # tomodule); those cases count as an extra branch.
            if source_path == '':
                branch_count += 1
                self.assertEqual(result.global_variable,
                                 "frommodule.global_variable")
                self.assertEqual(
                    result.function(),
                    "frommodule.function (frommodule.Class.STATIC_VALUE)")
                self.assertEqual(result.Class.STATIC_VALUE,
                                 "frommodule.Class.STATIC_VALUE")
                self.assertEqual(result.Class.Inner.STATIC_VALUE,
                                 "frommodule.Class.Inner.STATIC_VALUE")
                self.assertEqual(result.Parent.STATIC_VALUE,
                                 "frommodule.Parent.STATIC_VALUE")
                self.assertEqual(result.Child.STATIC_VALUE,
                                 "frommodule.Parent.STATIC_VALUE")
            elif source_path == 'global_variable':
                branch_count += 1
                self.assertEqual(result, "frommodule.global_variable")
            elif source_path == 'function':
                branch_count += 1
                self.assertEqual(
                    result(),
                    "frommodule.function (frommodule.Class.STATIC_VALUE)")
            elif source_path == 'unbound_method':
                branch_count += 1
                if destination_path in _CLS_REFERENCES:
                    branch_count += 1
                    # Called on an instance: reads the destination class's values.
                    self.assertEqual(
                        getattr(destination(), name)(),
                        "frommodule.unbound_method (tomodule.%s.STATIC_VALUE, tomodule.%s.instance_value)"
                        % ((_CLS_REFERENCES[destination_path], ) * 2))
            elif source_path == 'unbound_class_method':
                branch_count += 1
                if destination_path in _CLS_REFERENCES:
                    branch_count += 1
                    self.assertEqual(
                        getattr(destination, name)(),
                        "frommodule.unbound_class_method (tomodule.%s.STATIC_VALUE)"
                        % (_CLS_REFERENCES[destination_path], ))
            elif source_path == 'unbound_static_method':
                branch_count += 1
                if destination_path in _CLS_REFERENCES:
                    branch_count += 1
                    self.assertEqual(
                        getattr(destination, name)(),
                        "frommodule.unbound_static_method (frommodule.Class.STATIC_VALUE)"
                    )
            elif source_path == 'Class':
                branch_count += 1
                self.assertEqual(result.STATIC_VALUE,
                                 "frommodule.Class.STATIC_VALUE")
            elif source_path == 'Class.Inner':
                branch_count += 1
                self.assertEqual(result.STATIC_VALUE,
                                 "frommodule.Class.Inner.STATIC_VALUE")
            elif source_path == 'Class.value':
                branch_count += 1
                if destination_path in _CLS_REFERENCES:
                    branch_count += 1
                    # Property source: exercise both the getter and the setter
                    # on an instance of the destination class.
                    instance = destination()
                    self.assertEqual(
                        getattr(instance, name),
                        "frommodule.Class.value.getter (tomodule.%s.instance_value)"
                        % (_CLS_REFERENCES[destination_path], ))
                    setattr(instance, name, 'hello')
                    self.assertEqual(getattr(instance, name),
                                     "frommodule.Class.value.getter (hello)")
            elif source_path == 'Class.method':
                branch_count += 1
                if destination_path in _CLS_REFERENCES:
                    branch_count += 1
                    self.assertEqual(
                        getattr(destination(), name)(),
                        "frommodule.Class.method (tomodule.%s.STATIC_VALUE, tomodule.%s.instance_value)"
                        % ((_CLS_REFERENCES[destination_path], ) * 2))
            elif source_path == 'Class.class_method':
                branch_count += 1
                if destination_path in _CLS_REFERENCES:
                    branch_count += 1
                    self.assertEqual(
                        getattr(destination, name)(),
                        "frommodule.Class.class_method (tomodule.%s.STATIC_VALUE)"
                        % (_CLS_REFERENCES[destination_path], ))
            elif source_path == 'Class.static_method':
                branch_count += 1
                if destination_path in _CLS_REFERENCES:
                    branch_count += 1
                    self.assertEqual(
                        getattr(destination, name)(),
                        "frommodule.Class.static_method (frommodule.Class.STATIC_VALUE)"
                    )
            elif source_path == 'Parent':
                branch_count += 1
                self.assertEqual(result.__slots__,
                                 ('instance_value', 'parent_value', 'to_value',
                                  'from_value'))
                self.assertEqual(result().parent_value,
                                 "frommodule.Parent.parent_value")
                self.assertEqual(
                    result().method(),
                    "frommodule.Parent.method (frommodule.Parent.instance_value)"
                )
            elif source_path == 'Parent.method':
                branch_count += 1
                if destination_path in _CLS_REFERENCES:
                    branch_count += 1
                    self.assertEqual(
                        getattr(destination(), name)(),
                        "frommodule.Parent.method (tomodule.%s.instance_value)"
                        % (_CLS_REFERENCES[destination_path], ))
            elif source_path == 'Child':
                branch_count += 1
                self.assertEqual(result.__slots__, ('child_value', ))
                self.assertEqual(result().parent_value,
                                 "frommodule.Parent.parent_value")
                self.assertEqual(result().child_value,
                                 "frommodule.Child.child_value")
                self.assertEqual(
                    result().method(),
                    "frommodule.Parent.method (frommodule.Parent.instance_value)"
                )
            elif source_path == 'Child.method':
                branch_count += 1
                if destination_path in _CLS_REFERENCES:
                    branch_count += 1
                    self.assertEqual(
                        getattr(destination(), name)(),
                        "frommodule.Parent.method (tomodule.%s.instance_value)"
                        % (_CLS_REFERENCES[destination_path], ))

            self.tearDown()

        # Make sure that all test branches are covered.
        # NOTE(review): 116 is tied to the current contents of the frommodule /
        # tomodule fixtures; update it whenever those fixtures change.
        self.assertEqual(branch_count, 116)
예제 #29
0
def autolog(every_n_iter=100):
    # pylint: disable=E0611
    """
    Enables automatic logging from TensorFlow to MLflow.
    Note that autologging for ``tf.keras`` is handled by :py:func:`mlflow.tensorflow.autolog`,
    not :py:func:`mlflow.keras.autolog`.
    As an example, try running the
    `TensorFlow examples <https://github.com/mlflow/mlflow/tree/master/examples/tensorflow>`_.

    For each TensorFlow module, autologging captures the following information:

    **tf.keras**
     - **Metrics** and **Parameters**

      - Training loss; validation loss; user-specified metrics
      - ``fit()`` or ``fit_generator()`` parameters; optimizer name; learning rate; epsilon

     - **Artifacts**

      - Model summary on training start
      - `MLflow Model <https://mlflow.org/docs/latest/models.html>`_ (Keras model)
      - TensorBoard logs on training end

    **tf.keras.callbacks.EarlyStopping**
     - **Metrics** and **Parameters**

      - Metrics from the ``EarlyStopping`` callbacks: ``stopped_epoch``, ``restored_epoch``,
        ``restore_best_weight``, etc
      - ``fit()`` or ``fit_generator()`` parameters associated with ``EarlyStopping``:
        ``min_delta``, ``patience``, ``baseline``, ``restore_best_weights``, etc

    **tf.estimator**
     - **Metrics** and **Parameters**

      - TensorBoard metrics: ``average_loss``, ``loss``, etc
      - Parameters ``steps`` and ``max_steps``

     - **Artifacts**

      - `MLflow Model <https://mlflow.org/docs/latest/models.html>`_ (TF saved model) on call
        to ``tf.estimator.export_saved_model``

    **TensorFlow Core**
     - **Metrics**

      - All ``tf.summary.scalar`` calls

    Refer to the autologging tracking documentation for more
    information on `TensorFlow workflows
    <https://www.mlflow.org/docs/latest/tracking.html#tensorflow-and-keras-experimental>`_.

    :param every_n_iter: The frequency with which metrics should be logged.
                         Defaults to 100. Ex: a value of 100 will log metrics
                         at step 0, 100, 200, etc.
    """
    import tensorflow

    global _LOG_EVERY_N_STEPS
    _LOG_EVERY_N_STEPS = every_n_iter

    # Single literal so the version range is separated from "versions" by a space.
    unsupported_msg = ("Could not log to MLflow. Only TensorFlow versions "
                       "1.12 <= v <= 2.0.0 are supported.")

    if LooseVersion(tensorflow.__version__) < LooseVersion("1.12"):
        warnings.warn(unsupported_msg)
        return

    try:
        from tensorflow.python.summary.writer.event_file_writer import EventFileWriter
        from tensorflow.python.summary.writer.event_file_writer_v2 import EventFileWriterV2
        from tensorflow.python.saved_model import tag_constants
        from tensorflow.python.summary.writer.writer import FileWriter
    except ImportError:
        warnings.warn(unsupported_msg)
        return

    def _autolog_run_is_active():
        # True when the currently active MLflow run is the one autolog started.
        run = mlflow.active_run()
        return run is not None and run.info.run_id == _AUTOLOG_RUN_ID

    @contextmanager
    def _manage_active_run():
        # Start a run if none is active (remembering its id in _AUTOLOG_RUN_ID),
        # yield the active run, and afterwards end it only if it is that
        # autolog-created run.
        global _AUTOLOG_RUN_ID
        if not mlflow.active_run():
            try_mlflow_log(mlflow.start_run)
            if mlflow.active_run() is not None:  # defensive check in case `mlflow.start_run` fails
                _AUTOLOG_RUN_ID = mlflow.active_run().info.run_id
        try:
            yield mlflow.active_run()
        except Exception:
            # Without this, an exception raised by the wrapped call would skip
            # the cleanup below and leave the autolog run open (and reused by
            # later calls). Mark it FAILED, then propagate the error.
            if _autolog_run_is_active():
                try_mlflow_log(mlflow.end_run, "FAILED")
            raise
        if _autolog_run_is_active():
            try_mlflow_log(mlflow.end_run)

    @gorilla.patch(tensorflow.estimator.Estimator)
    def train(self, *args, **kwargs):
        with _manage_active_run():
            original = gorilla.get_original_attribute(
                tensorflow.estimator.Estimator, "train")

            # Checking step and max_step parameters for logging
            # (``steps`` / ``max_steps`` are positional indices 2 / 3).
            if len(args) >= 3:
                try_mlflow_log(mlflow.log_param, "steps", args[2])
                if len(args) >= 4:
                    try_mlflow_log(mlflow.log_param, "max_steps", args[3])
            if "steps" in kwargs:
                try_mlflow_log(mlflow.log_param, "steps", kwargs["steps"])
            if "max_steps" in kwargs:
                try_mlflow_log(mlflow.log_param, "max_steps",
                               kwargs["max_steps"])

            result = original(self, *args, **kwargs)

            return result

    @gorilla.patch(tensorflow.estimator.Estimator)
    def export_saved_model(self, *args, **kwargs):
        # Reattach to the autolog run if one was recorded; otherwise start a new
        # run that is ended after logging.
        auto_end = False
        if not mlflow.active_run():
            if _AUTOLOG_RUN_ID:
                try_mlflow_log(mlflow.start_run, _AUTOLOG_RUN_ID)
            else:
                try_mlflow_log(mlflow.start_run)
                auto_end = True

        original = gorilla.get_original_attribute(
            tensorflow.estimator.Estimator, "export_saved_model")
        serialized = original(self, *args, **kwargs)
        try_mlflow_log(
            log_model,
            tf_saved_model_dir=serialized.decode("utf-8"),
            tf_meta_graph_tags=[tag_constants.SERVING],
            tf_signature_def_key="predict",
            artifact_path="model",
        )
        if _autolog_run_is_active() or auto_end:
            try_mlflow_log(mlflow.end_run)
        return serialized

    @gorilla.patch(tensorflow.estimator.Estimator)
    def export_savedmodel(self, *args, **kwargs):
        # Same as export_saved_model, for the older method name.
        auto_end = False
        if not mlflow.active_run():
            if _AUTOLOG_RUN_ID:
                try_mlflow_log(mlflow.start_run, _AUTOLOG_RUN_ID)
            else:
                try_mlflow_log(mlflow.start_run)
                auto_end = True

        original = gorilla.get_original_attribute(
            tensorflow.estimator.Estimator, "export_savedmodel")
        serialized = original(self, *args, **kwargs)
        try_mlflow_log(
            log_model,
            tf_saved_model_dir=serialized.decode("utf-8"),
            tf_meta_graph_tags=[tag_constants.SERVING],
            tf_signature_def_key="predict",
            artifact_path="model",
        )
        if _autolog_run_is_active() or auto_end:
            try_mlflow_log(mlflow.end_run)
        return serialized

    def _early_stop_check(callbacks):
        # Return the first EarlyStopping callback in ``callbacks``, or None.
        for callback in callbacks:
            if isinstance(callback, tensorflow.keras.callbacks.EarlyStopping):
                return callback
        return None

    def _log_early_stop_callback_params(callback):
        if callback:
            try:
                earlystopping_params = {
                    "monitor": callback.monitor,
                    "min_delta": callback.min_delta,
                    "patience": callback.patience,
                    "baseline": callback.baseline,
                    "restore_best_weights": callback.restore_best_weights,
                }
                try_mlflow_log(mlflow.log_params, earlystopping_params)
            except Exception:  # pylint: disable=W0703
                return

    def _get_early_stop_callback_attrs(callback):
        try:
            return callback.stopped_epoch, callback.restore_best_weights, callback.patience
        except Exception:  # pylint: disable=W0703
            return None

    def _log_early_stop_callback_metrics(callback, history):
        if callback:
            callback_attrs = _get_early_stop_callback_attrs(callback)
            if callback_attrs is None:
                return
            stopped_epoch, restore_best_weights, patience = callback_attrs
            try_mlflow_log(mlflow.log_metric, "stopped_epoch", stopped_epoch)
            # Weights are restored only if early stopping occurs
            if stopped_epoch != 0 and restore_best_weights:
                restored_epoch = stopped_epoch - max(1, patience)
                try_mlflow_log(mlflow.log_metric, "restored_epoch",
                               restored_epoch)
                restored_metrics = {
                    key: history.history[key][restored_epoch]
                    for key in history.history.keys()
                }
                # Metrics are logged as 'epoch_loss' and 'epoch_acc' in TF 1.X
                if LooseVersion(
                        tensorflow.__version__) < LooseVersion("2.0.0"):
                    if "loss" in restored_metrics:
                        restored_metrics["epoch_loss"] = restored_metrics.pop(
                            "loss")
                    if "acc" in restored_metrics:
                        restored_metrics["epoch_acc"] = restored_metrics.pop(
                            "acc")
                # Checking that a metric history exists
                metric_key = next(iter(history.history), None)
                if metric_key is not None:
                    last_epoch = len(history.history[metric_key])
                    try_mlflow_log(mlflow.log_metrics,
                                   restored_metrics,
                                   step=last_epoch)

    @gorilla.patch(tensorflow.keras.Model)
    def fit(self, *args, **kwargs):
        with _manage_active_run():
            original = gorilla.get_original_attribute(tensorflow.keras.Model,
                                                      "fit")

            unlogged_params = [
                "self", "x", "y", "callbacks", "validation_data", "verbose"
            ]

            log_fn_args_as_params(original, args, kwargs, unlogged_params)
            early_stop_callback = None

            # Checking if the 'callback' argument of fit() is set
            # (``callbacks`` is positional index 5 of ``Model.fit``).
            if len(args) >= 6:
                tmp_list = list(args)
                early_stop_callback = _early_stop_check(tmp_list[5])
                tmp_list[5], log_dir = _setup_callbacks(tmp_list[5])
                args = tuple(tmp_list)
            elif "callbacks" in kwargs:
                early_stop_callback = _early_stop_check(kwargs["callbacks"])
                kwargs["callbacks"], log_dir = _setup_callbacks(
                    kwargs["callbacks"])
            else:
                kwargs["callbacks"], log_dir = _setup_callbacks([])

            _log_early_stop_callback_params(early_stop_callback)

            history = original(self, *args, **kwargs)

            _log_early_stop_callback_metrics(early_stop_callback, history)

            _flush_queue()
            _log_artifacts_with_warning(local_dir=log_dir.location,
                                        artifact_path="tensorboard_logs")
            if log_dir.is_temp:
                shutil.rmtree(log_dir.location)

            return history

    @gorilla.patch(tensorflow.keras.Model)
    def fit_generator(self, *args, **kwargs):
        with _manage_active_run():
            original = gorilla.get_original_attribute(tensorflow.keras.Model,
                                                      "fit_generator")

            unlogged_params = [
                "self", "generator", "callbacks", "validation_data", "verbose"
            ]

            log_fn_args_as_params(original, args, kwargs, unlogged_params)

            # Checking if the 'callback' argument of fit_generator() is set
            # (``callbacks`` is positional index 4 of ``Model.fit_generator``).
            if len(args) >= 5:
                tmp_list = list(args)
                tmp_list[4], log_dir = _setup_callbacks(tmp_list[4])
                args = tuple(tmp_list)
            elif "callbacks" in kwargs:
                kwargs["callbacks"], log_dir = _setup_callbacks(
                    kwargs["callbacks"])
            else:
                kwargs["callbacks"], log_dir = _setup_callbacks([])
            result = original(self, *args, **kwargs)
            _flush_queue()
            _log_artifacts_with_warning(local_dir=log_dir.location,
                                        artifact_path="tensorboard_logs")
            if log_dir.is_temp:
                shutil.rmtree(log_dir.location)

            return result

    @gorilla.patch(EventFileWriter)
    def add_event(self, event):
        _log_event(event)
        original = gorilla.get_original_attribute(EventFileWriter, "add_event")
        return original(self, event)

    @gorilla.patch(EventFileWriterV2, name="add_event")
    def add_event_v2(self, event):
        _log_event(event)
        # Look up the original stored on EventFileWriterV2 itself: this wrapper
        # is installed on V2, so the v1 EventFileWriter.add_event is not the
        # implementation being replaced here.
        original = gorilla.get_original_attribute(EventFileWriterV2, "add_event")
        return original(self, event)

    @gorilla.patch(FileWriter)
    def add_summary(self, *args, **kwargs):
        original = gorilla.get_original_attribute(FileWriter, "add_summary")
        result = original(self, *args, **kwargs)
        _flush_queue()
        return result

    # allow_hit: overwrite existing attributes; store_hit: keep the originals
    # so the wrappers can retrieve them via gorilla.get_original_attribute.
    settings = gorilla.Settings(allow_hit=True, store_hit=True)
    patches = [
        gorilla.Patch(EventFileWriter,
                      "add_event",
                      add_event,
                      settings=settings),
        gorilla.Patch(EventFileWriterV2,
                      "add_event",
                      add_event_v2,
                      settings=settings),
        gorilla.Patch(tensorflow.estimator.Estimator,
                      "train",
                      train,
                      settings=settings),
        gorilla.Patch(tensorflow.keras.Model, "fit", fit, settings=settings),
        gorilla.Patch(tensorflow.keras.Model,
                      "fit_generator",
                      fit_generator,
                      settings=settings),
        gorilla.Patch(
            tensorflow.estimator.Estimator,
            "export_saved_model",
            export_saved_model,
            settings=settings,
        ),
        gorilla.Patch(
            tensorflow.estimator.Estimator,
            "export_savedmodel",
            export_savedmodel,
            settings=settings,
        ),
        gorilla.Patch(FileWriter,
                      "add_summary",
                      add_summary,
                      settings=settings),
    ]

    for x in patches:
        gorilla.apply(x)
예제 #30
0
파일: __init__.py 프로젝트: xstex/mlflow
def autolog():
    """
    Enables autologging for scikit-learn estimators.

    **When is autologging performed?**
      Autologging is performed when you call:

      - ``estimator.fit``
      - ``estimator.fit_predict``
      - ``estimator.fit_transform``

    **Logged information**
      **Parameters**
        - Parameters obtained by ``estimator.get_params(deep=True)``. Note that ``get_params``
          is called with ``deep=True``. This means when you fit a meta estimator that chains
          a series of estimators, the parameters of these child estimators are also logged.

      **Metrics**
        - A training score obtained by ``estimator.score``. Note that the training score is
          computed using parameters given to ``fit``.

      **Tags**
        - An estimator class name (e.g. "LinearRegression").
        - A fully qualified estimator class name
          (e.g. "sklearn.linear_model._base.LinearRegression").

      **Artifacts**
        - A fitted estimator (logged by :py:func:`mlflow.sklearn.log_model()`).

    **How does autologging work for meta estimators?**
      When a meta estimator (e.g. `Pipeline`_, `GridSearchCV`_) calls ``fit``, it internally calls
      ``fit`` on its child estimators. Autologging does NOT perform logging on these constituent
      ``fit``.

    **Supported estimators**
      All estimators obtained by `sklearn.utils.all_estimators`_ (including meta estimators).

    .. _sklearn.utils.all_estimators:
        https://scikit-learn.org/stable/modules/generated/sklearn.utils.all_estimators.html

    .. _Pipeline:
        https://scikit-learn.org/stable/modules/generated/sklearn.pipeline.Pipeline.html

    .. _GridSearchCV:
        https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.GridSearchCV.html

    **Example**

    .. code-block:: python

        from pprint import pprint
        import numpy as np
        import sklearn.linear_model
        import mlflow

        # enable autologging
        mlflow.sklearn.autolog()

        # prepare training data
        X = np.array([[1, 1], [1, 2], [2, 2], [2, 3]])
        y = np.dot(X, np.array([1, 2])) + 3

        # train a model
        with mlflow.start_run() as run:
            reg = sklearn.linear_model.LinearRegression().fit(X, y)

        def fetch_logged_data(run_id):
            client = mlflow.tracking.MlflowClient()
            data = client.get_run(run_id).data
            tags = {k: v for k, v in data.tags.items() if not k.startswith("mlflow.")}
            artifacts = [f.path for f in client.list_artifacts(run_id, "model")]
            return data.params, data.metrics, tags, artifacts

        # fetch logged data
        params, metrics, tags, artifacts = fetch_logged_data(run.info.run_id)

        pprint(params)
        # {'copy_X': 'True',
        #  'fit_intercept': 'True',
        #  'n_jobs': 'None',
        #  'normalize': 'False'}

        pprint(metrics)
        # {'training_score': 1.0}

        pprint(tags)
        # {'estimator_class': 'sklearn.linear_model._base.LinearRegression',
        #  'estimator_name': 'LinearRegression'}

        pprint(artifacts)
        # ['model/MLmodel', 'model/conda.yaml', 'model/model.pkl']
    """
    import sklearn
    from mlflow.sklearn.utils import (
        _MIN_SKLEARN_VERSION,
        _is_supported_version,
        _chunk_dict,
        _get_args_for_score,
        _all_estimators,
        _truncate_dict,
    )
    from mlflow.utils.validation import (
        MAX_PARAMS_TAGS_PER_BATCH,
        MAX_PARAM_VAL_LENGTH,
        MAX_ENTITY_KEY_LENGTH,
    )

    if not _is_supported_version():
        warnings.warn(
            "Autologging utilities may not work properly on scikit-learn < {} "
            .format(_MIN_SKLEARN_VERSION) +
            "(current version: {})".format(sklearn.__version__),
            stacklevel=2,
        )

    def fit_mlflow(self, func_name, *args, **kwargs):
        """
        Call the original ``func_name`` method of ``self`` and log params, tags,
        the training score and the fitted estimator to MLflow.

        A run is started (and ended afterwards) only if none is active. If the
        original fitting method raises, such an auto-started run is ended with
        status FAILED and the exception is re-raised unchanged.
        """
        should_start_run = mlflow.active_run() is None
        if should_start_run:
            try_mlflow_log(mlflow.start_run)

        # TODO: We should not log nested estimator parameters for
        # parameter search estimators (GridSearchCV, RandomizedSearchCV)

        # Chunk and truncate model parameters to avoid hitting the log_batch API limit
        for chunk in _chunk_dict(self.get_params(deep=True),
                                 chunk_size=MAX_PARAMS_TAGS_PER_BATCH):
            truncated = _truncate_dict(chunk, MAX_ENTITY_KEY_LENGTH,
                                       MAX_PARAM_VAL_LENGTH)
            try_mlflow_log(mlflow.log_params, truncated)

        try_mlflow_log(
            mlflow.set_tags,
            {
                "estimator_name":
                self.__class__.__name__,
                "estimator_class":
                self.__class__.__module__ + "." + self.__class__.__name__,
            },
        )

        original_fit = gorilla.get_original_attribute(self, func_name)
        try:
            fit_output = original_fit(*args, **kwargs)
        except Exception:
            if should_start_run:
                try_mlflow_log(mlflow.end_run,
                               RunStatus.to_string(RunStatus.FAILED))
            # Bare `raise` re-raises the original exception with its traceback intact.
            raise

        if hasattr(self, "score"):
            try:
                # Reuse the data passed to the fitting call to compute the training score.
                score_args = _get_args_for_score(self.score, self.fit, args,
                                                 kwargs)
                training_score = self.score(*score_args)
            except Exception as e:  # pylint: disable=broad-except
                msg = (
                    self.score.__qualname__ +
                    " failed. The 'training_score' metric will not be recorded. Scoring error: "
                    + str(e))
                _logger.warning(msg)
            else:
                try_mlflow_log(mlflow.log_metric, "training_score",
                               training_score)

        try_mlflow_log(log_model, self, artifact_path="model")

        if should_start_run:
            try_mlflow_log(mlflow.end_run)

        return fit_output

    def patched_fit(self, func_name, *args, **kwargs):
        """
        To be applied to a sklearn model class that defines a `fit` method and
        inherits from `BaseEstimator` (thereby defining the `get_params()` method)

        Logging is performed only when the enclosing ``_SklearnTrainingSession``
        reports ``should_log()``; otherwise the original method is called directly
        (presumably this is how constituent fits of meta estimators are skipped).
        """
        with _SklearnTrainingSession(clazz=self.__class__,
                                     allow_children=False) as t:
            if t.should_log():
                return fit_mlflow(self, func_name, *args, **kwargs)
            else:
                original_fit = gorilla.get_original_attribute(self, func_name)
                return original_fit(*args, **kwargs)

    def create_patch_func(func_name):
        """Return a method that forwards to ``patched_fit`` with ``func_name`` bound."""
        def f(self, *args, **kwargs):
            return patched_fit(self, func_name, *args, **kwargs)

        return f

    patch_settings = gorilla.Settings(allow_hit=True, store_hit=True)
    for _, class_def in _all_estimators():
        for func_name in ["fit", "fit_transform", "fit_predict"]:
            if hasattr(class_def, func_name):
                original = getattr(class_def, func_name)

                # A couple of estimators use property methods to return fitting functions,
                # rather than defining the fitting functions on the estimator class directly.
                #
                # Example: https://github.com/scikit-learn/scikit-learn/blob/0.23.2/sklearn/neighbors/_lof.py#L183  # noqa
                #
                # We currently exclude these property fitting methods from patching because
                # it's challenging to patch them correctly.
                #
                # Excluded fitting methods:
                # - sklearn.cluster._agglomerative.FeatureAgglomeration.fit_predict
                # - sklearn.neighbors._lof.LocalOutlierFactor.fit_predict
                #
                # You can list property fitting methods by inserting "print(class_def, func_name)"
                # in the if clause below.
                if isinstance(original, property):
                    continue

                patch_func = create_patch_func(func_name)
                # TODO(harupy): Package this wrap & patch routine into a utility function so we can
                # reuse it in other autologging integrations.
                # preserve original function attributes
                patch_func = functools.wraps(original)(patch_func)
                patch = gorilla.Patch(class_def,
                                      func_name,
                                      patch_func,
                                      settings=patch_settings)
                gorilla.apply(patch)
예제 #31
0
def autolog():
    """
    Enable automatic logging from Fastai to MLflow.

    Logs loss and any other metrics specified in the fit function, and
    optimizer data as parameters. A model summary is recorded both as the
    ``model_summary`` tag and as a ``model_summary.txt`` artifact, and the
    trained model is logged as an artifact under ``model`` when training ends.

    MLflow will also log the parameters of the EarlyStopping and OneCycleScheduler callbacks
    """
    from fastai.basic_train import LearnerCallback, Learner
    from fastai.callbacks.hooks import model_summary, layers_info
    from fastai.callbacks import EarlyStoppingCallback, OneCycleScheduler

    class __MLflowFastaiCallback(LearnerCallback):
        """
        Callback for auto-logging metrics and parameters.

        Records model structural information as params when training begins,
        loss/metric values at the end of each epoch, and the model when
        training ends.
        """
        def __init__(
            self,
            learner,
        ):
            super().__init__(learner)
            self.learner = learner
            self.opt = self.learn.opt
            # Zipped positionally with [smooth_loss] + last_metrics in on_epoch_end.
            self.metrics_names = ["train_loss", "valid_loss"
                                  ] + [o.__name__ for o in learner.metrics]

        def on_epoch_end(self, **kwargs):
            """
            Log loss and other metrics values after each epoch
            """
            if kwargs["smooth_loss"] is None or kwargs["last_metrics"] is None:
                return
            epoch = kwargs["epoch"]
            metrics = [kwargs["smooth_loss"]] + kwargs["last_metrics"]
            metrics = map(float, metrics)
            metrics = dict(zip(self.metrics_names, metrics))
            try_mlflow_log(mlflow.log_metrics, metrics, step=epoch)

        def on_train_begin(self, **kwargs):
            """Log layer count, optimizer settings and the model summary."""
            info = layers_info(self.learner)
            try_mlflow_log(mlflow.log_param, "num_layers", len(info))
            # NOTE(review): `opt_func` is not set on this callback; presumably it is
            # resolved on the learner through LearnerCallback attribute lookup — confirm.
            try_mlflow_log(mlflow.log_param, "opt_func",
                           self.opt_func.func.__name__)

            if hasattr(self.opt, "true_wd"):
                try_mlflow_log(mlflow.log_param, "true_wd", self.opt.true_wd)

            if hasattr(self.opt, "bn_wd"):
                try_mlflow_log(mlflow.log_param, "bn_wd", self.opt.bn_wd)

            # NOTE(review): the guard checks the optimizer, but the value is read from
            # `self.train_bn`; confirm which object actually owns `train_bn`.
            if hasattr(self.opt, "train_bn"):
                try_mlflow_log(mlflow.log_param, "train_bn", self.train_bn)

            summary = model_summary(self.learner)
            try_mlflow_log(mlflow.set_tag, "model_summary", summary)

            tempdir = tempfile.mkdtemp()
            try:
                summary_file = os.path.join(tempdir, "model_summary.txt")
                with open(summary_file, "w") as f:
                    f.write(summary)
                try_mlflow_log(mlflow.log_artifact, local_path=summary_file)
            finally:
                shutil.rmtree(tempdir)

        def on_train_end(self, **kwargs):
            """Log the trained learner under the ``model`` artifact path."""
            try_mlflow_log(log_model, self.learner, artifact_path="model")

    def _find_callback_of_type(callback_type, callbacks):
        """Return the first callback that is an instance of ``callback_type``, else None."""
        for callback in callbacks:
            if isinstance(callback, callback_type):
                return callback
        return None

    def _log_early_stop_callback_params(callback):
        """Best-effort logging of EarlyStoppingCallback settings; no-op if ``callback`` is falsy."""
        if callback:
            try:
                earlystopping_params = {
                    "early_stop_monitor": callback.monitor,
                    "early_stop_min_delta": callback.min_delta,
                    "early_stop_patience": callback.patience,
                    "early_stop_mode": callback.mode,
                }
                try_mlflow_log(mlflow.log_params, earlystopping_params)
            except Exception:  # pylint: disable=W0703
                return

    def _log_one_cycle_callback_params(callback):
        """Best-effort logging of OneCycleScheduler settings; no-op if ``callback`` is falsy."""
        if callback:
            try:
                params = {
                    "lr_max": callback.lr_max,
                    "div_factor": callback.div_factor,
                    "pct_start": callback.pct_start,
                    "final_div": callback.final_div,
                    "tot_epochs": callback.tot_epochs,
                    "start_epoch": callback.start_epoch,
                    "moms": callback.moms,
                }
                try_mlflow_log(mlflow.log_params, params)
            except Exception:  # pylint: disable=W0703
                return

    def _run_and_log_function(self, original, args, kwargs, unlogged_params,
                              callback_arg_index):
        """
        Run ``original(self, *args, **kwargs)`` with the MLflow callback attached.

        ``callback_arg_index`` is the position of the ``callbacks`` argument in
        ``args`` (which excludes ``self``). The caller's callback list is never
        modified; a new list with the MLflow callback appended is passed on.
        A run started here is ended afterwards, with status FAILED if
        ``original`` raises.
        """
        if not mlflow.active_run():
            try_mlflow_log(mlflow.start_run)
            auto_end_run = True
        else:
            auto_end_run = False

        log_fn_args_as_params(original, [self] + list(args), kwargs,
                              unlogged_params)

        callbacks = [cb(self)
                     for cb in self.callback_fns] + (self.callbacks or [])

        # Build fresh lists instead of using `+=`, which would mutate the caller's
        # list (accumulating MLflow callbacks across repeated fit calls);
        # `or []` tolerates an explicit `callbacks=None`.
        if len(args) > callback_arg_index:
            tmp_list = list(args)
            user_callbacks = list(tmp_list[callback_arg_index] or [])
            callbacks += user_callbacks
            tmp_list[callback_arg_index] = user_callbacks + [
                __MLflowFastaiCallback(self)
            ]
            args = tuple(tmp_list)
        elif "callbacks" in kwargs:
            user_callbacks = list(kwargs["callbacks"] or [])
            callbacks += user_callbacks
            kwargs["callbacks"] = user_callbacks + [
                __MLflowFastaiCallback(self)
            ]
        else:
            kwargs["callbacks"] = [__MLflowFastaiCallback(self)]

        early_stop_callback = _find_callback_of_type(EarlyStoppingCallback,
                                                     callbacks)
        one_cycle_callback = _find_callback_of_type(OneCycleScheduler,
                                                    callbacks)

        _log_early_stop_callback_params(early_stop_callback)
        _log_one_cycle_callback_params(one_cycle_callback)

        try:
            result = original(self, *args, **kwargs)
        except Exception:
            # Do not leave a run started by this wrapper dangling.
            if auto_end_run:
                try_mlflow_log(mlflow.end_run, "FAILED")
            raise

        if auto_end_run:
            try_mlflow_log(mlflow.end_run)

        return result

    @gorilla.patch(Learner)
    def fit(self, *args, **kwargs):
        original = gorilla.get_original_attribute(Learner, "fit")
        unlogged_params = ["self", "callbacks", "learner"]
        # `callbacks` is the 4th positional argument of Learner.fit after `self`.
        return _run_and_log_function(self, original, args, kwargs,
                                     unlogged_params, 3)

    settings = gorilla.Settings(allow_hit=True, store_hit=True)
    gorilla.apply(gorilla.Patch(Learner, "fit", fit, settings=settings))
예제 #32
0
def autolog():
    """
    Enable automatic logging from Keras to MLflow.

    Logs loss and any other metrics specified in the fit function, and
    optimizer data as parameters. A model summary is logged as a
    ``model_summary.txt`` artifact, and the trained model is logged as an
    artifact under ``model`` when training ends.

    EarlyStopping Integration with Keras Automatic Logging

    MLflow will detect if an ``EarlyStopping`` callback is used in a ``fit()``/``fit_generator()``
    call, and if the ``restore_best_weights`` parameter is set to be ``True``, then MLflow will
    log the metrics associated with the restored model as a final, extra step. The epoch of the
    restored model will also be logged as the metric ``restored_epoch``.
    This allows for easy comparison between the actual metrics of the restored model and
    the metrics of other models.

    If ``restore_best_weights`` is set to be ``False``,
    then MLflow will not log an additional step.

    Regardless of ``restore_best_weights``, MLflow will also log ``stopped_epoch``,
    which indicates the epoch at which training stopped due to early stopping.

    If training does not end due to early stopping, then ``stopped_epoch`` will be logged as ``0``.

    MLflow will also log the parameters of the EarlyStopping callback,
    excluding ``mode`` and ``verbose``.
    """
    import keras

    class __MLflowKerasCallback(keras.callbacks.Callback):
        """
        Callback for auto-logging metrics and parameters.
        Records available logs after each epoch.
        Records model structural information as params when training begins
        """
        def on_train_begin(self, logs=None):  # pylint: disable=unused-argument
            try_mlflow_log(mlflow.log_param, 'num_layers',
                           len(self.model.layers))
            try_mlflow_log(mlflow.log_param, 'optimizer_name',
                           type(self.model.optimizer).__name__)
            if hasattr(self.model.optimizer, 'lr'):
                # Plain floats are logged as-is; backend variables are evaluated first.
                lr = self.model.optimizer.lr if \
                    isinstance(self.model.optimizer.lr, float) \
                    else keras.backend.eval(self.model.optimizer.lr)
                try_mlflow_log(mlflow.log_param, 'learning_rate', lr)
            if hasattr(self.model.optimizer, 'epsilon'):
                epsilon = self.model.optimizer.epsilon if \
                    isinstance(self.model.optimizer.epsilon, float) \
                    else keras.backend.eval(self.model.optimizer.epsilon)
                try_mlflow_log(mlflow.log_param, 'epsilon', epsilon)

            sum_list = []
            self.model.summary(print_fn=sum_list.append)
            summary = '\n'.join(sum_list)
            tempdir = tempfile.mkdtemp()
            try:
                summary_file = os.path.join(tempdir, "model_summary.txt")
                with open(summary_file, 'w') as f:
                    f.write(summary)
                try_mlflow_log(mlflow.log_artifact, local_path=summary_file)
            finally:
                shutil.rmtree(tempdir)

        def on_epoch_end(self, epoch, logs=None):
            """Log the epoch's ``logs`` dict as metrics at step ``epoch``."""
            if not logs:
                return
            try_mlflow_log(mlflow.log_metrics, logs, step=epoch)

        def on_train_end(self, logs=None):
            """Log the trained model under the ``model`` artifact path."""
            try_mlflow_log(log_model, self.model, artifact_path='model')

    def _early_stop_check(callbacks):
        """Return the first EarlyStopping instance in ``callbacks``, else None."""
        if LooseVersion(keras.__version__) < LooseVersion('2.3.0'):
            es_callback = keras.callbacks.EarlyStopping
        else:
            es_callback = keras.callbacks.callbacks.EarlyStopping
        for callback in callbacks:
            if isinstance(callback, es_callback):
                return callback
        return None

    def _log_early_stop_callback_params(callback):
        """Best-effort logging of EarlyStopping settings; no-op if ``callback`` is falsy."""
        if callback:
            try:
                earlystopping_params = {
                    'monitor': callback.monitor,
                    'min_delta': callback.min_delta,
                    'patience': callback.patience,
                    'baseline': callback.baseline,
                    'restore_best_weights': callback.restore_best_weights
                }
                try_mlflow_log(mlflow.log_params, earlystopping_params)
            except Exception:  # pylint: disable=W0703
                return

    def _get_early_stop_callback_attrs(callback):
        """Return (stopped_epoch, restore_best_weights, patience), or None if unavailable."""
        try:
            return callback.stopped_epoch, callback.restore_best_weights, callback.patience
        except Exception:  # pylint: disable=W0703
            return None

    def _log_early_stop_callback_metrics(callback, history):
        """Log ``stopped_epoch`` and, if weights were restored, the restored epoch's metrics."""
        if callback:
            callback_attrs = _get_early_stop_callback_attrs(callback)
            if callback_attrs is None:
                return
            stopped_epoch, restore_best_weights, patience = callback_attrs
            try_mlflow_log(mlflow.log_metric, 'stopped_epoch', stopped_epoch)
            # Weights are restored only if early stopping occurs
            if stopped_epoch != 0 and restore_best_weights:
                restored_epoch = stopped_epoch - max(1, patience)
                try_mlflow_log(mlflow.log_metric, 'restored_epoch',
                               restored_epoch)
                restored_metrics = {
                    key: history.history[key][restored_epoch]
                    for key in history.history.keys()
                }
                # Checking that a metric history exists
                metric_key = next(iter(history.history), None)
                if metric_key is not None:
                    # Log the restored metrics as one extra step after the last epoch.
                    last_epoch = len(history.history[metric_key])
                    try_mlflow_log(mlflow.log_metrics,
                                   restored_metrics,
                                   step=last_epoch)

    def _run_and_log_function(self, original, args, kwargs, unlogged_params,
                              callback_arg_index):
        """
        Run ``original(self, *args, **kwargs)`` with the MLflow callback attached.

        ``callback_arg_index`` is the position of the ``callbacks`` argument in
        ``args`` (which excludes ``self``). The caller's callback list is never
        modified; a new list with the MLflow callback appended is passed on.
        A run started here is ended afterwards, with status FAILED if
        ``original`` raises.
        """
        if not mlflow.active_run():
            try_mlflow_log(mlflow.start_run)
            auto_end_run = True
        else:
            auto_end_run = False

        # `original` is the unbound keras.Model method whose signature begins with
        # `self`; prepend it so positional values line up with parameter names.
        log_fn_args_as_params(original, [self] + list(args), kwargs,
                              unlogged_params)
        early_stop_callback = None

        # Build fresh lists instead of using `+=`, which would mutate the caller's
        # list (accumulating MLflow callbacks across repeated fit calls);
        # `or []` tolerates an explicit `callbacks=None`.
        if len(args) > callback_arg_index:
            tmp_list = list(args)
            user_callbacks = list(tmp_list[callback_arg_index] or [])
            early_stop_callback = _early_stop_check(user_callbacks)
            tmp_list[callback_arg_index] = user_callbacks + [
                __MLflowKerasCallback()
            ]
            args = tuple(tmp_list)
        elif 'callbacks' in kwargs:
            user_callbacks = list(kwargs['callbacks'] or [])
            early_stop_callback = _early_stop_check(user_callbacks)
            kwargs['callbacks'] = user_callbacks + [__MLflowKerasCallback()]
        else:
            kwargs['callbacks'] = [__MLflowKerasCallback()]

        _log_early_stop_callback_params(early_stop_callback)

        try:
            history = original(self, *args, **kwargs)
        except Exception:
            # Do not leave a run started by this wrapper dangling.
            if auto_end_run:
                try_mlflow_log(mlflow.end_run, 'FAILED')
            raise

        _log_early_stop_callback_metrics(early_stop_callback, history)

        if auto_end_run:
            try_mlflow_log(mlflow.end_run)

        return history

    @gorilla.patch(keras.Model)
    def fit(self, *args, **kwargs):
        original = gorilla.get_original_attribute(keras.Model, 'fit')
        unlogged_params = [
            'self', 'x', 'y', 'callbacks', 'validation_data', 'verbose'
        ]
        # `callbacks` is the 6th positional argument of Model.fit after `self`.
        return _run_and_log_function(self, original, args, kwargs,
                                     unlogged_params, 5)

    @gorilla.patch(keras.Model)
    def fit_generator(self, *args, **kwargs):
        original = gorilla.get_original_attribute(keras.Model, 'fit_generator')
        unlogged_params = [
            'self', 'generator', 'callbacks', 'validation_data', 'verbose'
        ]
        # `callbacks` is the 5th positional argument of Model.fit_generator after `self`.
        return _run_and_log_function(self, original, args, kwargs,
                                     unlogged_params, 4)

    settings = gorilla.Settings(allow_hit=True, store_hit=True)
    gorilla.apply(gorilla.Patch(keras.Model, 'fit', fit, settings=settings))
    gorilla.apply(
        gorilla.Patch(keras.Model,
                      'fit_generator',
                      fit_generator,
                      settings=settings))