def test_patch_decorator(self):
    """``gorilla.patch`` returns the decorated object unchanged and records a
    patch whose settings keep the values they had at decoration time, even if
    the settings object is modified afterwards."""
    destination = _tomodule
    function = gorilla.get_attribute(_frommodule, 'function')
    hit_settings = gorilla.Settings(allow_hit=True, store_hit=True)

    decorated = gorilla.patch(destination, settings=hit_settings)(function)
    self.assertIs(decorated, function)

    # Changing the settings after decoration must not alter the recorded patch.
    hit_settings.allow_hit = False
    hit_settings.store_hit = False

    recorded = gorilla.get_decorator_data(function)
    self.assertEqual(
        recorded.patches,
        [
            gorilla.Patch(destination, 'function', function,
                          gorilla.Settings(allow_hit=True, store_hit=True)),
        ],
    )
def test_create_patches_5(self):
    """Names assigned with ``gorilla.name`` to source members are used as the
    patch names produced by ``gorilla.create_patches``."""
    destination = _tomodule
    source = _frommodule

    def member(attribute):
        return gorilla.get_attribute(source, attribute)

    renames = (
        ('function', 'Class'),
        ('dummy_1', 'Parent'),
        ('dummy_2', 'Child'),
    )
    for patch_name, attribute in renames:
        gorilla.name(patch_name)(member(attribute))

    patches = gorilla.create_patches(destination, source)

    expected_patches = [
        gorilla.Patch(destination, 'dummy_2', member('Child')),
        gorilla.Patch(destination, 'function', member('Class')),
        gorilla.Patch(destination, 'dummy_1', member('Parent')),
        gorilla.Patch(destination, 'function', member('function')),
        gorilla.Patch(destination, 'global_variable', member('global_variable')),
        gorilla.Patch(destination, 'whatever', member('unbound_class_method')),
        gorilla.Patch(destination, 'unbound_static_method',
                      member('unbound_static_method'),
                      settings=gorilla.Settings(allow_hit=True)),
    ]
    self.assertEqual(patches, expected_patches)
def test_apply_patch_with_hit_3(self):
    """With ``allow_hit=True`` and ``store_hit=False``, applying a patch replaces
    the target attribute and no original attribute can be retrieved."""
    settings = gorilla.Settings(allow_hit=True, store_hit=False)
    sources = [''] + _list_attribute_paths(_frommodule)
    targets = _list_attribute_paths(_tomodule)

    for source_path, target_path in itertools.product(sources, targets):
        self.setUp()
        parent_path, name = _split_attribute_path(target_path)
        destination = _get_attribute_from_path(_tomodule, parent_path)
        obj = _get_attribute_from_path(_frommodule, source_path)

        gorilla.apply(gorilla.Patch(destination, name, obj, settings=settings))

        # The destination object itself must still be the same object.
        self.assertIs(destination, _get_attribute_from_path(_tomodule, parent_path))
        self.assertIs(gorilla.get_attribute(destination, name), obj)
        # store_hit=False: the overwritten attribute is not kept.
        self.assertRaises(AttributeError, gorilla.get_original_attribute,
                          destination, name)
        self.tearDown()
def test_settings(self):
    """Check ``gorilla.Settings`` equality semantics and its string form."""
    default = gorilla.Settings()
    explicit = gorilla.Settings(allow_hit=False, store_hit=True)

    # The defaults are allow_hit=False and store_hit=True.
    self.assertEqual(default, explicit)
    # A plain dict holding the same values does not compare equal.
    self.assertNotEqual(default, {'allow_hit': False, 'store_hit': True})

    default.allow_hit = True
    self.assertNotEqual(default, explicit)
    explicit.allow_hit = True
    self.assertEqual(default, explicit)

    # Extra, user-defined attributes take part in the comparison too.
    default.some_value = 123
    self.assertNotEqual(default, explicit)
    explicit.some_value = 123
    self.assertEqual(default, explicit)

    string_forms = [
        (gorilla.Settings(),
         "Settings(allow_hit=False, store_hit=True)"),
        (gorilla.Settings(allow_hit=True),
         "Settings(allow_hit=True, store_hit=True)"),
        (gorilla.Settings(store_hit=False),
         "Settings(allow_hit=False, store_hit=False)"),
        (gorilla.Settings(some_value=123),
         "Settings(allow_hit=False, some_value=123, store_hit=True)"),
        (gorilla.Settings(string='abc'),
         "Settings(allow_hit=False, store_hit=True, string='abc')"),
    ]
    for settings, text in string_forms:
        self.assertEqual(str(settings), text)
def set_seed(seed=100):
    """Seed NumPy's global RNG and Python's ``random`` module with ``seed``, then
    monkey-patch ``random_seed.get_seed`` so that it returns
    ``better_get_seed(seed, op_seed)``.

    NOTE(review): ``random_seed`` and ``better_get_seed`` are module-level names
    defined outside this function; presumably ``random_seed`` is TensorFlow's
    ``tensorflow.python.framework.random_seed`` — confirm.
    """
    np.random.seed(seed)
    random.seed(seed)

    def patched_get_seed(op_seed):
        return better_get_seed(seed, op_seed)

    # Monkey Patch get_seed.
    overwrite = gorilla.Settings(allow_hit=True, store_hit=True)
    gorilla.apply(
        gorilla.Patch(random_seed, 'get_seed', patched_get_seed, settings=overwrite))
def patch_frelu(cls):
    """Replace ``torch.nn.functional.relu`` with the identity function.

    The patch is applied to the ``torch.nn.functional`` module itself, so every
    later lookup of ``F.relu`` gets the identity. ``cls`` is not used.
    """
    import gorilla
    import torch.nn.functional as F

    def identity(x):
        return x

    # allow_hit=True: ``relu`` already exists on the module and must be overwritten.
    gorilla.apply(
        gorilla.Patch(F, 'relu', identity, settings=gorilla.Settings(allow_hit=True)))
    print('F.relu have been patched to identity fn.')
def autolog():
    """
    Enable automatic logging from Keras to MLflow.

    Patches ``keras.Model.fit`` so that an MLflow callback is appended to the
    callbacks of every ``fit()`` call. The callback:

    - logs the metrics Keras reports in ``logs`` at the end of each epoch;
    - when training ends, logs the number of layers, the optimizer name and, if
      the optimizer has them, the learning rate and epsilon as parameters;
    - sets the model summary as the ``summary`` tag;
    - logs the model as an artifact under ``model``.

    The caller's ``callbacks`` list is never modified; a new list is passed to
    the original ``fit``.
    """
    import keras

    def _to_python_value(value):
        # Optimizer hyper-parameters may be plain floats or backend variables;
        # only the latter need to be evaluated through the backend.
        return value if isinstance(value, float) else keras.backend.eval(value)

    class __MLflowKerasCallback(keras.callbacks.Callback):
        """
        Callback for auto-logging metrics and parameters.
        Records available logs after each epoch.
        Records model structural information as params after training finishes.
        """

        def on_epoch_end(self, epoch, logs=None):
            if not logs:
                return
            try_mlflow_log(mlflow.log_metrics, logs, step=epoch)

        def on_train_end(self, logs=None):
            optimizer = self.model.optimizer
            try_mlflow_log(mlflow.log_param, 'num_layers', len(self.model.layers))
            try_mlflow_log(mlflow.log_param, 'optimizer_name', type(optimizer).__name__)
            if hasattr(optimizer, 'lr'):
                lr = _to_python_value(optimizer.lr)
                try_mlflow_log(mlflow.log_param, 'learning_rate', lr)
            if hasattr(optimizer, 'epsilon'):
                epsilon = _to_python_value(optimizer.epsilon)
                try_mlflow_log(mlflow.log_param, 'epsilon', epsilon)
            sum_list = []
            self.model.summary(print_fn=sum_list.append)
            summary = '\n'.join(sum_list)
            try_mlflow_log(mlflow.set_tag, 'summary', summary)
            try_mlflow_log(log_model, self.model, artifact_path='model')

    def _with_mlflow_callback(callbacks):
        # Build a fresh list rather than extending the caller's list in place:
        # in-place ``+=`` would make repeated fit() calls with the same list
        # accumulate MLflow callbacks. ``None`` means "no callbacks".
        return list(callbacks or []) + [__MLflowKerasCallback()]

    @gorilla.patch(keras.Model)
    def fit(self, *args, **kwargs):
        original = gorilla.get_original_attribute(keras.Model, 'fit')
        if len(args) >= 6:
            # ``callbacks`` is the 6th positional argument of ``Model.fit``.
            args = list(args)
            args[5] = _with_mlflow_callback(args[5])
            args = tuple(args)
        else:
            kwargs['callbacks'] = _with_mlflow_callback(kwargs.get('callbacks'))
        return original(self, *args, **kwargs)

    settings = gorilla.Settings(allow_hit=True, store_hit=True)
    patch = gorilla.Patch(keras.Model, 'fit', fit, settings=settings)
    gorilla.apply(patch)
def test_settings_decorator_2(self):
    """Settings set on individual members with ``gorilla.settings`` are merged
    with, and take precedence over, the settings passed to ``gorilla.patches``."""
    destination = _tomodule.Class
    obj = _frommodule.Class

    def member(name, owner=obj):
        return gorilla.get_attribute(owner, name)

    gorilla.settings(some_value=123)(member('method'))
    gorilla.settings(allow_hit=False)(member('static_method'))
    gorilla.settings(store_hit=True)(member('value'))
    gorilla.settings(allow_hit=False, store_hit=True)(member('method', obj.Inner))
    gorilla.patches(destination,
                    settings=gorilla.Settings(allow_hit=True, store_hit=False))(obj)

    decorator_data = gorilla.get_decorator_data(obj)

    expected_patches = [
        gorilla.Patch(destination, 'STATIC_VALUE', member('STATIC_VALUE'),
                      settings=gorilla.Settings(allow_hit=True, store_hit=False)),
        gorilla.Patch(destination, 'class_method', member('class_method'),
                      settings=gorilla.Settings(allow_hit=True, store_hit=False)),
        gorilla.Patch(destination, 'method', member('method'),
                      settings=gorilla.Settings(allow_hit=True, some_value=123,
                                                store_hit=False)),
        gorilla.Patch(destination, 'static_method', member('static_method'),
                      settings=gorilla.Settings(store_hit=False)),
        gorilla.Patch(destination, 'value', member('value'),
                      settings=gorilla.Settings(allow_hit=True)),
        gorilla.Patch(destination.Inner, 'STATIC_VALUE',
                      member('STATIC_VALUE', obj.Inner),
                      settings=gorilla.Settings(allow_hit=True, store_hit=False)),
        gorilla.Patch(destination.Inner, 'method', member('method', obj.Inner),
                      settings=gorilla.Settings(allow_hit=False, store_hit=True)),
    ]
    self.assertEqual(decorator_data.patches, expected_patches)
def apply_gorrila(function: Callable, module: Any):
    """Override ``module.<function.__name__>`` with ``function`` using a gorilla patch.

    The patch is created with ``allow_hit=True``, so an attribute of the same
    name that already exists on ``module`` is overwritten.

    Args:
        function (Callable): Override function; its ``__name__`` selects the
            attribute to replace.
        module (Any): Function caller module (the patch destination).
    """
    target_name = function.__name__
    allow_overwrite = gorilla.Settings(allow_hit=True)
    gorilla.apply(
        gorilla.Patch(module, target_name, function, settings=allow_overwrite))
def test_patches_decorator(self):
    """``gorilla.patches`` returns the class unchanged and records one patch per
    member, including the members of the nested ``Inner`` class. The recorded
    settings keep allow_hit=True/store_hit=True even though the settings object
    is changed afterwards."""
    destination = _tomodule.Class
    obj = _frommodule.Class
    settings = gorilla.Settings(allow_hit=True, store_hit=True)
    self.assertIs(gorilla.patches(destination, settings=settings)(obj), obj)

    settings.allow_hit = False
    settings.store_hit = False

    decorator_data = gorilla.get_decorator_data(obj)

    def expected(dest, src, name):
        return gorilla.Patch(dest, name, gorilla.get_attribute(src, name),
                             settings=gorilla.Settings(allow_hit=True, store_hit=True))

    outer_names = ('STATIC_VALUE', 'class_method', 'method', 'static_method', 'value')
    inner_names = ('STATIC_VALUE', 'method')
    expected_patches = [expected(destination, obj, name) for name in outer_names]
    expected_patches += [expected(destination.Inner, obj.Inner, name)
                         for name in inner_names]
    self.assertEqual(decorator_data.patches, expected_patches)
def monkey_patch_tf_get_seed(seed: int, default_op_seed: int = 1923746) -> None:
    """
    Monkey patch ``tensorflow.random.get_seed`` to avoid the increasing memory usage
    that arises from repeated random sampling from TensorFlow distributions.

    This code is taken from https://github.com/lerobitaille/tf-issue-36164-workaround,
    which remedies issue 36164 (https://github.com/tensorflow/tensorflow/issues/36164).
    We have raised our own clearer and more concise issue, which should be the
    reference point for this memory leak:
    https://github.com/tensorflow/tensorflow/issues/37252

    After patching, ``get_seed(op_seed)`` returns ``(seed, op_seed)``, or
    ``(seed, default_op_seed)`` when ``op_seed`` is ``None``. The kernel cache of
    the current eager context is also cleared.

    :param seed: Seed to set as the TensorFlow global seed.
    :param default_op_seed: Default seed for any random operations if required.
    """
    warn(
        "WARNING: Patching native TensorFlow functionality to avoid memory leak when setting "
        "a random seed.")
    warn("WARNING: Patch required due to TensorFlow issue 37252. "
         "Check if the issue is resolved at "
         "https://github.com/tensorflow/tensorflow/issues/37252")

    # Lazy imports to show which imports to remove once the issue is resolved and to avoid wider
    # usage of monkey patching and usage of the TensorFlow back end which involves imports the
    # linter does not like.
    # pylint: disable=no-name-in-module,import-error
    from tensorflow.python.eager import context
    from tensorflow.python import pywrap_tensorflow
    from tensorflow.python.framework import random_seed

    # Remove gorilla dependency completely when issue fixed. (Remove from requirements.txt)
    import gorilla

    def better_get_seed(global_seed, op_seed):
        if op_seed is not None:
            return global_seed, op_seed
        return global_seed, default_op_seed

    # Monkey Patch get_seed.
    def func(op_seed):
        # The replacement must hand back the (global_seed, op_seed) pair; without
        # this ``return`` the patched get_seed would return None.
        return better_get_seed(seed, op_seed)

    settings = gorilla.Settings(allow_hit=True, store_hit=True)
    patch = gorilla.Patch(random_seed, 'get_seed', func, settings=settings)
    gorilla.apply(patch)

    # Also clear the kernel cache, to reset any existing seeds
    # pylint: disable=protected-access
    _context = context.context()
    if _context._context_handle is not None:
        pywrap_tensorflow.TFE_ContextClearCaches(_context._context_handle)
def test_get_original_attribute(self):
    """After a patch with ``allow_hit=True`` is applied (once, then again), the
    overwritten attribute is still returned by ``get_original_attribute``."""
    destination = _tomodule.Class
    name = 'method'
    original = gorilla.get_attribute(destination, name)
    replacement = gorilla.get_attribute(_frommodule, 'unbound_method')
    patch = gorilla.Patch(destination, name, replacement,
                          settings=gorilla.Settings(allow_hit=True))

    for _ in range(2):
        gorilla.apply(patch)
        stored = gorilla.get_original_attribute(destination, name)
        self.assertIs(_unfold(stored), original)
def test_apply_patch_with_hit_1(self):
    """With default settings, applying a patch onto any attribute path listed
    from ``_tomodule`` raises ``RuntimeError``."""
    settings = gorilla.Settings()
    source_paths = [''] + _list_attribute_paths(_frommodule)
    target_paths = _list_attribute_paths(_tomodule)

    for source_path, target_path in itertools.product(source_paths, target_paths):
        self.setUp()
        destination_path, name = _split_attribute_path(target_path)
        patch = gorilla.Patch(
            _get_attribute_from_path(_tomodule, destination_path),
            name,
            _get_attribute_from_path(_frommodule, source_path),
            settings=settings,
        )
        self.assertRaises(RuntimeError, gorilla.apply, patch)
        self.tearDown()
def wrap_patch(destination, name, patch, settings=None):
    """
    Apply a patch while preserving the attributes (e.g. __doc__) of the original function.

    ``functools.wraps`` copies the metadata of ``getattr(destination, name)`` onto
    ``patch`` before the gorilla patch is applied.

    :param destination: Patch destination
    :param name: Name of the attribute at the destination
    :param patch: Patch function
    :param settings: Settings for gorilla.Patch; defaults to
                     ``gorilla.Settings(allow_hit=True, store_hit=True)``.
    """
    effective_settings = (gorilla.Settings(allow_hit=True, store_hit=True)
                          if settings is None else settings)
    target = getattr(destination, name)
    decorated = functools.wraps(target)(patch)
    gorilla.apply(
        gorilla.Patch(destination, name, decorated, settings=effective_settings))
def test_create_patches_4(self):
    """A ``filter`` callable passed to ``gorilla.create_patches`` restricts the
    patches to members whose name contains ``'method'``."""
    def name_contains_method(name, value):
        return 'method' in name

    def check(destination, obj, expected_members):
        # expected_members: (patch name, source attribute, extra Patch kwargs)
        patches = gorilla.create_patches(destination, obj, filter=name_contains_method)
        expected_patches = [
            gorilla.Patch(destination, patch_name,
                          gorilla.get_attribute(obj, attribute), **extra)
            for patch_name, attribute, extra in expected_members
        ]
        self.assertEqual(patches, expected_patches)

    check(_tomodule, _frommodule, [
        ('function', 'function', {}),
        ('whatever', 'unbound_class_method', {}),
        ('unbound_static_method', 'unbound_static_method',
         {'settings': gorilla.Settings(allow_hit=True)}),
    ])
    check(_tomodule.Class, _frommodule.Class, [
        ('class_method', 'class_method', {}),
        ('whatever', 'method', {}),
    ])
    check(_tomodule.Parent, _frommodule.Parent, [
        ('method', 'method', {}),
    ])
    check(_tomodule.Child, _frommodule.Child, [
        ('method', 'method', {}),
    ])
def initialize():
    """Initialize the extensions.

    The patches from the Bana package are searched and applied to the Maya
    API. A patch is skipped when its destination already has an attribute of
    the same name and its settings (``gorilla.Settings()`` when the patch has
    none) do not allow hits.
    """
    packages = [importlib.import_module('%s.%s' % (__package__, package_name))
                for package_name in _PACKAGES]
    fallback = gorilla.Settings()
    for patch in gorilla.find_patches(packages):
        settings = patch.settings if patch.settings is not None else fallback
        if settings.allow_hit or not hasattr(patch.destination, patch.name):
            gorilla.apply(patch)
def autolog():
    """
    Enables autologging for scikit-learn estimators.

    **When is autologging performed?**
      Autologging is performed when you call:

      - ``estimator.fit()``
      - ``estimator.fit_predict()``
      - ``estimator.fit_transform()``

    **Logged information**
      **Parameters**
        - Parameters obtained by ``estimator.get_params(deep=True)``. Note that ``get_params``
          is called with ``deep=True``. This means when you fit a meta estimator that chains
          a series of estimators, the parameters of these child estimators are also logged.

      **Metrics**
        - A training score obtained by ``estimator.score``. Note that the training score is
          computed using parameters given to ``fit()``.
        - Common metrics for classifier:

          - `precision score`_

          .. _precision score:
              https://scikit-learn.org/stable/modules/generated/sklearn.metrics.precision_score.html

          - `recall score`_

          .. _recall score:
              https://scikit-learn.org/stable/modules/generated/sklearn.metrics.recall_score.html

          - `f1 score`_

          .. _f1 score:
              https://scikit-learn.org/stable/modules/generated/sklearn.metrics.f1_score.html

          - `accuracy score`_

          .. _accuracy score:
              https://scikit-learn.org/stable/modules/generated/sklearn.metrics.accuracy_score.html

          If the classifier has method ``predict_proba``, we additionally log:

          - `log loss`_

          .. _log loss:
              https://scikit-learn.org/stable/modules/generated/sklearn.metrics.log_loss.html

          - `roc auc score`_

          .. _roc auc score:
              https://scikit-learn.org/stable/modules/generated/sklearn.metrics.roc_auc_score.html

        - Common metrics for regressor:

          - `mean squared error`_

          .. _mean squared error:
              https://scikit-learn.org/stable/modules/generated/sklearn.metrics.mean_squared_error.html

          - root mean squared error

          - `mean absolute error`_

          .. _mean absolute error:
              https://scikit-learn.org/stable/modules/generated/sklearn.metrics.mean_absolute_error.html

          - `r2 score`_

          .. _r2 score:
              https://scikit-learn.org/stable/modules/generated/sklearn.metrics.r2_score.html

      **Tags**
        - An estimator class name (e.g. "LinearRegression").
        - A fully qualified estimator class name
          (e.g. "sklearn.linear_model._base.LinearRegression").

      **Artifacts**
        - A fitted estimator (logged by :py:func:`mlflow.sklearn.log_model()`).

    **How does autologging work for meta estimators?**
      When a meta estimator (e.g. `Pipeline`_, `GridSearchCV`_) calls ``fit()``, it internally
      calls ``fit()`` on its child estimators. Autologging does NOT perform logging on these
      constituent ``fit()`` calls.

      **Parameter search**
          In addition to recording the information discussed above, autologging for parameter
          search meta estimators (`GridSearchCV`_ and `RandomizedSearchCV`_) records child runs
          with metrics for each set of explored parameters, as well as artifacts and parameters
          for the best model (if available).

    **Supported estimators**
      - All estimators obtained by `sklearn.utils.all_estimators`_ (including meta estimators).
      - `Pipeline`_
      - Parameter search estimators (`GridSearchCV`_ and `RandomizedSearchCV`_)

    .. _sklearn.utils.all_estimators:
        https://scikit-learn.org/stable/modules/generated/sklearn.utils.all_estimators.html

    .. _Pipeline:
        https://scikit-learn.org/stable/modules/generated/sklearn.pipeline.Pipeline.html

    .. _GridSearchCV:
        https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.GridSearchCV.html

    .. _RandomizedSearchCV:
        https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.RandomizedSearchCV.html

    **Example**

    `See more examples <https://github.com/mlflow/mlflow/blob/master/examples/sklearn_autolog>`_

    .. code-block:: python

        from pprint import pprint
        import numpy as np
        from sklearn.linear_model import LinearRegression
        import mlflow

        def fetch_logged_data(run_id):
            client = mlflow.tracking.MlflowClient()
            data = client.get_run(run_id).data
            tags = {k: v for k, v in data.tags.items() if not k.startswith("mlflow.")}
            artifacts = [f.path for f in client.list_artifacts(run_id, "model")]
            return data.params, data.metrics, tags, artifacts

        # enable autologging
        mlflow.sklearn.autolog()

        # prepare training data
        X = np.array([[1, 1], [1, 2], [2, 2], [2, 3]])
        y = np.dot(X, np.array([1, 2])) + 3

        # train a model
        model = LinearRegression()
        with mlflow.start_run() as run:
            model.fit(X, y)

        # fetch logged data
        params, metrics, tags, artifacts = fetch_logged_data(run.info.run_id)

        pprint(params)
        # {'copy_X': 'True',
        #  'fit_intercept': 'True',
        #  'n_jobs': 'None',
        #  'normalize': 'False'}

        pprint(metrics)
        # {'training_score': 1.0,
        #  'training_mae': 2.220446049250313e-16,
        #  'training_mse': 1.9721522630525295e-31,
        #  'training_r2_score': 1.0,
        #  'training_rmse': 4.440892098500626e-16}

        pprint(tags)
        # {'estimator_class': 'sklearn.linear_model._base.LinearRegression',
        #  'estimator_name': 'LinearRegression'}

        pprint(artifacts)
        # ['model/MLmodel', 'model/conda.yaml', 'model/model.pkl']
    """
    # Imports are deferred to call time so that scikit-learn and pandas are only
    # needed when autologging is actually enabled.
    import pandas as pd
    import sklearn

    from mlflow.models import infer_signature
    from mlflow.sklearn.utils import (
        _MIN_SKLEARN_VERSION,
        _is_supported_version,
        _chunk_dict,
        _get_args_for_score,
        _log_specialized_estimator_content,
        _get_Xy,
        _all_estimators,
        _truncate_dict,
        _get_arg_names,
        _get_estimator_info_tags,
        _get_meta_estimators_for_autologging,
        _is_parameter_search_estimator,
        _log_parameter_search_results_as_artifact,
        _create_child_runs_for_parameter_search,
    )
    from mlflow.tracking.context import registry as context_registry
    from mlflow.utils.validation import (
        MAX_PARAMS_TAGS_PER_BATCH,
        MAX_PARAM_VAL_LENGTH,
        MAX_ENTITY_KEY_LENGTH,
    )

    # Unsupported scikit-learn versions only get a warning; patching still proceeds.
    if not _is_supported_version():
        warnings.warn(
            "Autologging utilities may not work properly on scikit-learn < {} "
            .format(_MIN_SKLEARN_VERSION)
            + "(current version: {})".format(sklearn.__version__),
            stacklevel=2,
        )

    def fit_mlflow(self, func_name, *args, **kwargs):
        """Run the original ``func_name`` of ``self`` with pre- and post-training
        logging. A run is started only if none is active; that run is ended
        afterwards (with status FAILED if training raises)."""
        should_start_run = mlflow.active_run() is None
        if should_start_run:
            try_mlflow_log(mlflow.start_run)

        _log_pretraining_metadata(self, *args, **kwargs)

        original_fit = gorilla.get_original_attribute(self, func_name)
        try:
            fit_output = original_fit(*args, **kwargs)
        except Exception as e:
            if should_start_run:
                try_mlflow_log(mlflow.end_run, RunStatus.to_string(RunStatus.FAILED))

            # NOTE(review): a bare ``raise`` would re-raise without adding this frame.
            raise e

        _log_posttraining_metadata(self, *args, **kwargs)

        if should_start_run:
            try_mlflow_log(mlflow.end_run)

        return fit_output

    def _log_pretraining_metadata(estimator, *args, **kwargs):  # pylint: disable=unused-argument
        """
        Records metadata (e.g., params and tags) for a scikit-learn estimator prior to training.
        This is intended to be invoked within a patched scikit-learn training routine
        (e.g., `fit()`, `fit_transform()`, ...) and assumes the existence of an active
        MLflow run that can be referenced via the fluent Tracking API.

        :param estimator: The scikit-learn estimator for which to log metadata.
        :param args: The arguments passed to the scikit-learn training routine (e.g.,
                     `fit()`, `fit_transform()`, ...).
        :param kwargs: The keyword arguments passed to the scikit-learn training routine.
        """
        # Deep parameter logging includes parameters from children of a given
        # estimator. For some meta estimators (e.g., pipelines), recording
        # these parameters is desirable. For parameter search estimators,
        # however, child estimators act as seeds for the parameter search
        # process; accordingly, we avoid logging initial, untuned parameters
        # for these seed estimators.
        should_log_params_deeply = not _is_parameter_search_estimator(estimator)
        # Chunk model parameters to avoid hitting the log_batch API limit
        for chunk in _chunk_dict(
            estimator.get_params(deep=should_log_params_deeply),
            chunk_size=MAX_PARAMS_TAGS_PER_BATCH,
        ):
            # Keys and values are truncated to the tracking server's length limits.
            truncated = _truncate_dict(chunk, MAX_ENTITY_KEY_LENGTH, MAX_PARAM_VAL_LENGTH)
            try_mlflow_log(mlflow.log_params, truncated)

        try_mlflow_log(mlflow.set_tags, _get_estimator_info_tags(estimator))

    def _log_posttraining_metadata(estimator, *args, **kwargs):
        """
        Records metadata for a scikit-learn estimator after training has completed.
        This is intended to be invoked within a patched scikit-learn training routine
        (e.g., `fit()`, `fit_transform()`, ...) and assumes the existence of an active
        MLflow run that can be referenced via the fluent Tracking API.

        :param estimator: The scikit-learn estimator for which to log metadata.
        :param args: The arguments passed to the scikit-learn training routine (e.g.,
                     `fit()`, `fit_transform()`, ...).
        :param kwargs: The keyword arguments passed to the scikit-learn training routine.
        """
        if hasattr(estimator, "score"):
            try:
                # Score on the same data that was passed to the training routine.
                score_args = _get_args_for_score(estimator.score, estimator.fit, args, kwargs)
                training_score = estimator.score(*score_args)
            except Exception as e:  # pylint: disable=broad-except
                msg = (
                    estimator.score.__qualname__
                    + " failed. The 'training_score' metric will not be recorded. Scoring error: "
                    + str(e)
                )
                _logger.warning(msg)
            else:
                try_mlflow_log(mlflow.log_metric, "training_score", training_score)

        # log common metrics and artifacts for estimators (classifier, regressor)
        # NOTE(review): ``mlflow.active_run()`` is not checked for None here; if the
        # run failed to start this raises AttributeError — confirm that is acceptable.
        _log_specialized_estimator_content(estimator, mlflow.active_run().info.run_id, args, kwargs)

        input_example = None
        signature = None
        if hasattr(estimator, "predict"):
            try:
                # Fetch an input example using the first several rows of the array-like
                # training data supplied to the training routine (e.g., `fit()`)
                SAMPLE_ROWS = 5
                fit_arg_names = _get_arg_names(estimator.fit)
                X_var_name, y_var_name = fit_arg_names[:2]
                input_example = _get_Xy(args, kwargs, X_var_name, y_var_name)[0][:SAMPLE_ROWS]

                model_output = estimator.predict(input_example)
                signature = infer_signature(input_example, model_output)
            except Exception as e:  # pylint: disable=broad-except
                # Logging the model still proceeds, just without example/signature.
                input_example = None
                msg = "Failed to infer an input example and model signature: " + str(e)
                _logger.warning(msg)

        try_mlflow_log(
            log_model,
            estimator,
            artifact_path="model",
            signature=signature,
            input_example=input_example,
        )

        if _is_parameter_search_estimator(estimator):
            if hasattr(estimator, "best_estimator_"):
                try_mlflow_log(
                    log_model,
                    estimator.best_estimator_,
                    artifact_path="best_estimator",
                    signature=signature,
                    input_example=input_example,
                )

            if hasattr(estimator, "best_params_"):
                # Prefix with "best_" so these do not collide with the search's own params.
                best_params = {
                    "best_{param_name}".format(param_name=param_name): param_value
                    for param_name, param_value in estimator.best_params_.items()
                }
                try_mlflow_log(mlflow.log_params, best_params)

            if hasattr(estimator, "cv_results_"):
                try:
                    # Fetch environment-specific tags (e.g., user and source) to ensure that lineage
                    # information is consistent with the parent run
                    environment_tags = context_registry.resolve_tags()
                    _create_child_runs_for_parameter_search(
                        cv_estimator=estimator,
                        parent_run=mlflow.active_run(),
                        child_tags=environment_tags,
                    )
                except Exception as e:  # pylint: disable=broad-except

                    msg = (
                        "Encountered exception during creation of child runs for parameter search."
                        " Child runs may be missing. Exception: {}".format(str(e))
                    )
                    _logger.warning(msg)

                try:
                    cv_results_df = pd.DataFrame.from_dict(estimator.cv_results_)
                    _log_parameter_search_results_as_artifact(
                        cv_results_df, mlflow.active_run().info.run_id
                    )
                except Exception as e:  # pylint: disable=broad-except

                    msg = (
                        "Failed to log parameter search results as an artifact."
                        " Exception: {}".format(str(e))
                    )
                    _logger.warning(msg)

    def patched_fit(self, func_name, *args, **kwargs):
        """
        To be applied to a sklearn model class that defines a `fit` method and
        inherits from `BaseEstimator` (thereby defining the `get_params()` method)
        """
        # Logging happens only when the training session says so; presumably
        # ``allow_children=False`` is what suppresses logging for nested fit()
        # calls made by meta estimators — confirm in _SklearnTrainingSession.
        with _SklearnTrainingSession(clazz=self.__class__, allow_children=False) as t:
            if t.should_log():
                return fit_mlflow(self, func_name, *args, **kwargs)
            else:
                original_fit = gorilla.get_original_attribute(self, func_name)
                return original_fit(*args, **kwargs)

    def create_patch_func(func_name):
        # A factory binds ``func_name`` per patch; a closure created directly in
        # the loop below would late-bind to the last loop value.
        def f(self, *args, **kwargs):
            return patched_fit(self, func_name, *args, **kwargs)

        return f

    patch_settings = gorilla.Settings(allow_hit=True, store_hit=True)
    _, estimators_to_patch = zip(*_all_estimators())
    # Ensure that relevant meta estimators (e.g. GridSearchCV, Pipeline) are selected
    # for patching if they are not already included in the output of `all_estimators()`
    estimators_to_patch = set(estimators_to_patch).union(
        set(_get_meta_estimators_for_autologging())
    )

    for class_def in estimators_to_patch:
        for func_name in ["fit", "fit_transform", "fit_predict"]:
            if hasattr(class_def, func_name):
                original = getattr(class_def, func_name)

                # A couple of estimators use property methods to return fitting functions,
                # rather than defining the fitting functions on the estimator class directly.
                #
                # Example: https://github.com/scikit-learn/scikit-learn/blob/0.23.2/sklearn/neighbors/_lof.py#L183  # noqa
                #
                # We currently exclude these property fitting methods from patching because
                # it's challenging to patch them correctly.
                #
                # Excluded fitting methods:
                # - sklearn.cluster._agglomerative.FeatureAgglomeration.fit_predict
                # - sklearn.neighbors._lof.LocalOutlierFactor.fit_predict
                #
                # You can list property fitting methods by inserting "print(class_def, func_name)"
                # in the if clause below.
                if isinstance(original, property):
                    continue

                patch_func = create_patch_func(func_name)
                # TODO(harupy): Package this wrap & patch routine into a utility function so we can
                # reuse it in other autologging integrations.
                # preserve original function attributes
                patch_func = functools.wraps(original)(patch_func)
                patch = gorilla.Patch(class_def, func_name, patch_func, settings=patch_settings)
                gorilla.apply(patch)
def autolog():
    """
    Enable automatic logging from Gluon to MLflow.

    Patches ``Estimator.fit`` so that an MLflow event handler is added to every
    ``fit()`` call. The handler:

    - at the start of training, logs the number of layers, ``max_epoch`` and
      ``max_batch`` (when set), the optimizer name and, if present, its learning
      rate and epsilon as parameters;
    - at the end of each epoch, logs the estimator's training and validation
      metrics;
    - at the end of training, logs the network as an artifact under ``model``
      when it is a ``HybridSequential``.

    The caller's ``event_handlers`` are never modified; a new list is passed to
    the original ``fit``. If no MLflow run was active when ``fit()`` was called,
    the active run is ended afterwards (with status FAILED if ``fit()`` raises).
    """
    class __MLflowGluonCallback(EpochEnd, TrainEnd, TrainBegin):
        def __init__(self):
            self.current_epoch = 0

        def epoch_end(self, estimator, *args, **kwargs):
            logs = {}
            for metric in estimator.train_metrics:
                metric_name, metric_val = metric.get()
                logs[metric_name] = metric_val
            for metric in estimator.val_metrics:
                metric_name, metric_val = metric.get()
                logs[metric_name] = metric_val
            try_mlflow_log(mlflow.log_metrics, logs, step=self.current_epoch)
            self.current_epoch += 1

        def train_begin(self, estimator, *args, **kwargs):
            try_mlflow_log(mlflow.log_param, "num_layers", len(estimator.net))
            if estimator.max_epoch is not None:
                try_mlflow_log(mlflow.log_param, "epochs", estimator.max_epoch)
            if estimator.max_batch is not None:
                try_mlflow_log(mlflow.log_param, "batches", estimator.max_batch)
            try_mlflow_log(mlflow.log_param, "optimizer_name",
                           type(estimator.trainer.optimizer).__name__)
            if hasattr(estimator.trainer.optimizer, "lr"):
                try_mlflow_log(mlflow.log_param, "learning_rate",
                               estimator.trainer.optimizer.lr)
            if hasattr(estimator.trainer.optimizer, "epsilon"):
                try_mlflow_log(mlflow.log_param, "epsilon",
                               estimator.trainer.optimizer.epsilon)

        def train_end(self, estimator, *args, **kwargs):
            if isinstance(estimator.net, HybridSequential):
                try_mlflow_log(log_model, estimator.net, artifact_path="model")

    def _with_mlflow_handler(event_handlers):
        # Return a fresh list so the caller's list is not extended in place
        # (repeated fit() calls would otherwise accumulate MLflow handlers).
        # ``None`` means no handlers; a single handler object is wrapped in a list.
        if event_handlers is None:
            handlers = []
        elif isinstance(event_handlers, (list, tuple)):
            handlers = list(event_handlers)
        else:
            handlers = [event_handlers]
        return handlers + [__MLflowGluonCallback()]

    @gorilla.patch(Estimator)
    def fit(self, *args, **kwargs):
        # Remember whether a run was already active; if not, whatever run is
        # active after training belongs to this call and is ended here.
        auto_end_run = not mlflow.active_run()
        original = gorilla.get_original_attribute(Estimator, "fit")
        if len(args) >= 4:
            # ``event_handlers`` is the 4th positional argument of ``Estimator.fit``.
            args = list(args)
            args[3] = _with_mlflow_handler(args[3])
            args = tuple(args)
        else:
            kwargs["event_handlers"] = _with_mlflow_handler(kwargs.get("event_handlers"))
        try:
            result = original(self, *args, **kwargs)
        except Exception:
            # Do not leave a run we own dangling when training fails.
            if auto_end_run:
                mlflow.end_run(status="FAILED")
            raise
        if auto_end_run:
            mlflow.end_run()
        return result

    settings = gorilla.Settings(allow_hit=True, store_hit=True)
    gorilla.apply(gorilla.Patch(Estimator, "fit", fit, settings=settings))
def autolog(every_n_iter=100):  # pylint: disable=E0611
    """
    Enable automatic logging from TensorFlow to MLflow.

    If applicable, the exported SavedModel is logged as an MLflow model under the 'model'
    artifact path, and TensorBoard log data produced by Keras ``fit``/``fit_generator`` is
    logged under 'tensorboard_logs'.

    Refer to the tracking documentation for information on what is logged with different
    TensorFlow workflows.

    :param every_n_iter: The frequency with which metrics should be logged.
                         Defaults to 100. Ex: a value of 100 will log metrics
                         at step 0, 100, 200, etc.
    """
    global _LOG_EVERY_N_STEPS
    _LOG_EVERY_N_STEPS = every_n_iter

    # Adjacent literals are concatenated. The trailing space keeps "versions" and "1.12"
    # apart in the rendered warning.
    unsupported_version_msg = ("Could not log to MLflow. Only TensorFlow versions "
                               "1.12 <= v <= 2.0.0 are supported.")

    if LooseVersion(tensorflow.__version__) < LooseVersion('1.12'):
        warnings.warn(unsupported_version_msg)
        return

    try:
        from tensorflow.python.summary.writer.event_file_writer import EventFileWriter
        from tensorflow.python.summary.writer.event_file_writer_v2 import EventFileWriterV2
        from tensorflow.python.saved_model import tag_constants
        from tensorflow.python.summary.writer.writer import FileWriter
    except ImportError:
        warnings.warn(unsupported_version_msg)
        return

    @contextmanager
    def _manage_active_run():
        """
        Ensure an MLflow run is active for the duration of the ``with`` block.

        If no run is active, one is started and its ID is stored in the module-level
        ``_AUTOLOG_RUN_ID``. On exit, whether normal or through an exception, the active run
        is ended if it is that autolog run.
        """
        global _AUTOLOG_RUN_ID
        if not mlflow.active_run():
            try_mlflow_log(mlflow.start_run)
            # defensive check in case `mlflow.start_run` fails
            if mlflow.active_run() is not None:
                _AUTOLOG_RUN_ID = mlflow.active_run().info.run_id
        try:
            yield mlflow.active_run()
        finally:
            active_run = mlflow.active_run()
            if active_run is not None and active_run.info.run_id == _AUTOLOG_RUN_ID:
                try_mlflow_log(mlflow.end_run)

    @gorilla.patch(tensorflow.estimator.Estimator)
    def train(self, *args, **kwargs):
        with _manage_active_run():
            original = gorilla.get_original_attribute(
                tensorflow.estimator.Estimator, 'train')

            # `steps` / `max_steps` may be passed positionally (indices 2 and 3 once `self`
            # is excluded) or by keyword.
            if len(args) >= 3:
                try_mlflow_log(mlflow.log_param, 'steps', args[2])
            if len(args) >= 4:
                try_mlflow_log(mlflow.log_param, 'max_steps', args[3])
            if 'steps' in kwargs:
                try_mlflow_log(mlflow.log_param, 'steps', kwargs['steps'])
            if 'max_steps' in kwargs:
                try_mlflow_log(mlflow.log_param, 'max_steps', kwargs['max_steps'])

            return original(self, *args, **kwargs)

    def _export_and_log_model(self, method_name, args, kwargs):
        """
        Call the original ``Estimator.<method_name>`` and log the export as an MLflow model.

        If no run is active, the autolog run (``_AUTOLOG_RUN_ID``) is resumed when one
        exists; otherwise a new run is started. Afterwards, the run is ended if it is the
        autolog run or if it was newly started here.
        """
        auto_end = False
        if not mlflow.active_run():
            if _AUTOLOG_RUN_ID:
                try_mlflow_log(mlflow.start_run, _AUTOLOG_RUN_ID)
            else:
                try_mlflow_log(mlflow.start_run)
                auto_end = True

        original = gorilla.get_original_attribute(tensorflow.estimator.Estimator, method_name)
        try:
            serialized = original(self, *args, **kwargs)
            # The export path comes back as bytes, hence the decode.
            try_mlflow_log(log_model,
                           tf_saved_model_dir=serialized.decode('utf-8'),
                           tf_meta_graph_tags=[tag_constants.SERVING],
                           tf_signature_def_key='predict',
                           artifact_path='model')
        finally:
            active_run = mlflow.active_run()
            if (active_run is not None and active_run.info.run_id == _AUTOLOG_RUN_ID) \
                    or auto_end:
                try_mlflow_log(mlflow.end_run)
        return serialized

    @gorilla.patch(tensorflow.estimator.Estimator)
    def export_saved_model(self, *args, **kwargs):
        return _export_and_log_model(self, 'export_saved_model', args, kwargs)

    @gorilla.patch(tensorflow.estimator.Estimator)
    def export_savedmodel(self, *args, **kwargs):
        return _export_and_log_model(self, 'export_savedmodel', args, kwargs)

    def _early_stop_check(callbacks):
        """Return the first ``EarlyStopping`` in ``callbacks`` (``None``-safe), else ``None``."""
        for callback in callbacks or []:
            if isinstance(callback, tensorflow.keras.callbacks.EarlyStopping):
                return callback
        return None

    def _log_early_stop_callback_params(callback):
        if callback:
            try:
                earlystopping_params = {
                    'monitor': callback.monitor,
                    'min_delta': callback.min_delta,
                    'patience': callback.patience,
                    'baseline': callback.baseline,
                    'restore_best_weights': callback.restore_best_weights,
                }
                try_mlflow_log(mlflow.log_params, earlystopping_params)
            except Exception:  # pylint: disable=W0703
                return

    def _get_early_stop_callback_attrs(callback):
        try:
            return callback.stopped_epoch, callback.restore_best_weights, callback.patience
        except Exception:  # pylint: disable=W0703
            return None

    def _log_early_stop_callback_metrics(callback, history):
        if callback:
            callback_attrs = _get_early_stop_callback_attrs(callback)
            if callback_attrs is None:
                return
            stopped_epoch, restore_best_weights, patience = callback_attrs
            try_mlflow_log(mlflow.log_metric, 'stopped_epoch', stopped_epoch)
            # Weights are restored only if early stopping occurs
            if stopped_epoch != 0 and restore_best_weights:
                restored_epoch = stopped_epoch - max(1, patience)
                try_mlflow_log(mlflow.log_metric, 'restored_epoch', restored_epoch)
                restored_metrics = {
                    key: history.history[key][restored_epoch]
                    for key in history.history.keys()
                }
                # Metrics are logged as 'epoch_loss' and 'epoch_acc' in TF 1.X
                if LooseVersion(tensorflow.__version__) < LooseVersion('2.0.0'):
                    if 'loss' in restored_metrics:
                        restored_metrics['epoch_loss'] = restored_metrics.pop('loss')
                    if 'acc' in restored_metrics:
                        restored_metrics['epoch_acc'] = restored_metrics.pop('acc')
                # Checking that a metric history exists
                metric_key = next(iter(history.history), None)
                if metric_key is not None:
                    last_epoch = len(history.history[metric_key])
                    try_mlflow_log(mlflow.log_metrics, restored_metrics, step=last_epoch)

    def _inject_callbacks(args, kwargs, callbacks_index):
        """
        Route the user's Keras callbacks through ``_setup_callbacks``.

        Callbacks are read from ``args[callbacks_index]`` when passed positionally, otherwise
        from ``kwargs['callbacks']``. ``None`` or a missing value is treated as an empty
        list. The result of ``_setup_callbacks`` is written back to the same place
        (``kwargs`` is updated in place).

        :return: ``(args, early_stop_callback, log_dir)``, where ``args`` is the possibly
                 rebuilt positional-argument tuple.
        """
        positional = len(args) > callbacks_index
        user_callbacks = args[callbacks_index] if positional else kwargs.get('callbacks')
        if user_callbacks is None:
            user_callbacks = []
        early_stop_callback = _early_stop_check(user_callbacks)
        new_callbacks, log_dir = _setup_callbacks(user_callbacks)
        if positional:
            args = args[:callbacks_index] + (new_callbacks,) + args[callbacks_index + 1:]
        else:
            kwargs['callbacks'] = new_callbacks
        return args, early_stop_callback, log_dir

    @gorilla.patch(tensorflow.keras.Model)
    def fit(self, *args, **kwargs):
        with _manage_active_run():
            original = gorilla.get_original_attribute(tensorflow.keras.Model, 'fit')
            unlogged_params = ['self', 'x', 'y', 'callbacks', 'validation_data', 'verbose']
            log_fn_args_as_params(original, args, kwargs, unlogged_params)

            # `callbacks` is positional index 5 of fit() once `self` is excluded.
            args, early_stop_callback, log_dir = _inject_callbacks(args, kwargs, 5)
            _log_early_stop_callback_params(early_stop_callback)
            try:
                history = original(self, *args, **kwargs)
                _log_early_stop_callback_metrics(early_stop_callback, history)
                _flush_queue()
                _log_artifacts_with_warning(local_dir=log_dir.location,
                                            artifact_path='tensorboard_logs')
            finally:
                # Clean up the temporary TensorBoard dir even if training raised.
                if log_dir.is_temp:
                    shutil.rmtree(log_dir.location, ignore_errors=True)
            return history

    @gorilla.patch(tensorflow.keras.Model)
    def fit_generator(self, *args, **kwargs):
        with _manage_active_run():
            original = gorilla.get_original_attribute(tensorflow.keras.Model, 'fit_generator')
            unlogged_params = ['self', 'generator', 'callbacks', 'validation_data', 'verbose']
            log_fn_args_as_params(original, args, kwargs, unlogged_params)

            # `callbacks` is positional index 4 of fit_generator() once `self` is excluded.
            args, _, log_dir = _inject_callbacks(args, kwargs, 4)
            try:
                result = original(self, *args, **kwargs)
                _flush_queue()
                _log_artifacts_with_warning(local_dir=log_dir.location,
                                            artifact_path='tensorboard_logs')
            finally:
                if log_dir.is_temp:
                    shutil.rmtree(log_dir.location, ignore_errors=True)
            return result

    @gorilla.patch(EventFileWriter)
    def add_event(self, event):
        _log_event(event)
        original = gorilla.get_original_attribute(EventFileWriter, 'add_event')
        return original(self, event)

    # Separate wrapper so that V2 writers call their OWN original `add_event` rather than
    # the V1 implementation.
    @gorilla.patch(EventFileWriterV2)
    def add_event_v2(self, event):
        _log_event(event)
        original = gorilla.get_original_attribute(EventFileWriterV2, 'add_event')
        return original(self, event)

    @gorilla.patch(FileWriter)
    def add_summary(self, *args, **kwargs):
        original = gorilla.get_original_attribute(FileWriter, 'add_summary')
        result = original(self, *args, **kwargs)
        _flush_queue()
        return result

    settings = gorilla.Settings(allow_hit=True, store_hit=True)
    patches = [
        gorilla.Patch(EventFileWriter, 'add_event', add_event, settings=settings),
        gorilla.Patch(EventFileWriterV2, 'add_event', add_event_v2, settings=settings),
        gorilla.Patch(tensorflow.estimator.Estimator, 'train', train, settings=settings),
        gorilla.Patch(tensorflow.keras.Model, 'fit', fit, settings=settings),
        gorilla.Patch(tensorflow.keras.Model, 'fit_generator', fit_generator,
                      settings=settings),
        gorilla.Patch(tensorflow.estimator.Estimator, 'export_saved_model',
                      export_saved_model, settings=settings),
        gorilla.Patch(tensorflow.estimator.Estimator, 'export_savedmodel',
                      export_savedmodel, settings=settings),
        gorilla.Patch(FileWriter, 'add_summary', add_summary, settings=settings),
    ]

    for x in patches:
        gorilla.apply(x)
def test_apply_patch_with_hit_2(self):
    """
    Apply every (source, target) patch combination with ``allow_hit=True``.

    Each attribute path of ``_tomodule`` is patched with each attribute
    path of ``_frommodule``, plus the empty path. The empty path presumably
    resolves to ``_frommodule`` itself: its branch checks module-level
    attributes on the result. For every combination the test checks that:

    - the destination object itself is not replaced;
    - the patched attribute is exactly the source object;
    - the attribute that was overwritten is still reachable as
      ``_gorilla_original_<name>`` on the destination;
    - the patched object behaves as expected for its source kind.

    ``branch_count`` counts every assertion branch that is taken. The final
    check pins it to 427 so that a fixture change which silently skips
    branches makes the test fail.
    """
    settings = gorilla.Settings(allow_hit=True)

    branch_count = 0

    source_paths = [''] + _list_attribute_paths(_frommodule)
    target_paths = _list_attribute_paths(_tomodule)
    combinations = itertools.product(source_paths, target_paths)
    for source_path, target_path in combinations:
        # setUp/tearDown run once per combination, presumably to restore the
        # fixture modules between patches.
        self.setUp()

        destination_path, name = _split_attribute_path(target_path)
        destination = _get_attribute_from_path(_tomodule, destination_path)
        # Captured before patching so it can be compared with the stored original.
        target = gorilla.get_attribute(destination, name)
        obj = _get_attribute_from_path(_frommodule, source_path)
        patch = gorilla.Patch(destination, name, obj, settings=settings)
        gorilla.apply(patch)
        self.assertIs(
            destination, _get_attribute_from_path(_tomodule, destination_path))

        result = gorilla.get_attribute(destination, name)
        self.assertIs(result, obj)

        # `gorilla.get_original_attribute` cannot be used here because it
        # could return a bounded method, which would not compare as
        # expected.
        original = gorilla.get_attribute(destination,
                                         '_gorilla_original_%s' % (name, ))
        self.assertIs(original, target)
        self.assertIsNot(original, result)

        # Behavioral checks per source kind. Inner `if`s narrow to targets where
        # the extra check applies (e.g. class destinations listed in
        # _CLS_REFERENCES) and count as additional branches.
        if source_path == '':
            branch_count += 1
            self.assertEqual(result.global_variable,
                             "frommodule.global_variable")
            self.assertEqual(
                result.function(),
                "frommodule.function (frommodule.Class.STATIC_VALUE)")
            self.assertEqual(result.Class.STATIC_VALUE,
                             "frommodule.Class.STATIC_VALUE")
            self.assertEqual(result.Class.Inner.STATIC_VALUE,
                             "frommodule.Class.Inner.STATIC_VALUE")
            self.assertEqual(result.Parent.STATIC_VALUE,
                             "frommodule.Parent.STATIC_VALUE")
            self.assertEqual(result.Child.STATIC_VALUE,
                             "frommodule.Parent.STATIC_VALUE")
        elif source_path == 'global_variable':
            branch_count += 1
            self.assertEqual(result, "frommodule.global_variable")

            if destination_path in _CLS_REFERENCES and name in (
                    'STATIC_VALUE', ):
                branch_count += 1
                self.assertEqual(destination.STATIC_VALUE,
                                 "frommodule.global_variable")

                # Class methods reading STATIC_VALUE must see the patched value.
                if target_path == 'Class.STATIC_VALUE':
                    branch_count += 1
                    self.assertEqual(
                        destination.class_method(),
                        "tomodule.Class.class_method (frommodule.global_variable)"
                    )
                    self.assertEqual(
                        destination.static_method(),
                        "tomodule.Class.static_method (frommodule.global_variable)"
                    )
        elif source_path == 'function':
            branch_count += 1
            self.assertEqual(
                result(),
                "frommodule.function (frommodule.Class.STATIC_VALUE)")
        elif source_path == 'unbound_method':
            branch_count += 1
            if destination_path in _CLS_REFERENCES and name not in (
                    'STATIC_VALUE', '__init__'):
                branch_count += 1
                # Called on a fresh instance: `self` must be the tomodule class.
                self.assertEqual(
                    getattr(destination(), name)(),
                    "frommodule.unbound_method (tomodule.%s.STATIC_VALUE, tomodule.%s.instance_value)"
                    % ((_CLS_REFERENCES[destination_path], ) * 2))
        elif source_path == 'unbound_class_method':
            branch_count += 1
            if destination_path in _CLS_REFERENCES and name not in (
                    'STATIC_VALUE', ):
                branch_count += 1
                self.assertEqual(
                    getattr(destination, name)(),
                    "frommodule.unbound_class_method (tomodule.%s.STATIC_VALUE)"
                    % (_CLS_REFERENCES[destination_path], ))
        elif source_path == 'unbound_static_method':
            branch_count += 1
            if destination_path in _CLS_REFERENCES:
                branch_count += 1
                self.assertEqual(
                    getattr(destination, name)(),
                    "frommodule.unbound_static_method (frommodule.Class.STATIC_VALUE)"
                )
        elif source_path == 'Class':
            branch_count += 1
            self.assertEqual(result.STATIC_VALUE,
                             "frommodule.Class.STATIC_VALUE")
        elif source_path == 'Class.Inner':
            branch_count += 1
            self.assertEqual(result.STATIC_VALUE,
                             "frommodule.Class.Inner.STATIC_VALUE")
        elif source_path == 'Class.value':
            branch_count += 1
            if destination_path in _CLS_REFERENCES and name not in (
                    '__init__', ):
                branch_count += 1
                # Exercises both the getter and the setter of the patched value.
                instance = destination()
                self.assertEqual(
                    getattr(instance, name),
                    "frommodule.Class.value.getter (tomodule.%s.instance_value)"
                    % (_CLS_REFERENCES[destination_path], ))

                setattr(instance, name, 'hello')
                self.assertEqual(getattr(instance, name),
                                 "frommodule.Class.value.getter (hello)")
        elif source_path == 'Class.method':
            branch_count += 1
            if destination_path in _CLS_REFERENCES and name not in (
                    'STATIC_VALUE', '__init__'):
                branch_count += 1
                self.assertEqual(
                    getattr(destination(), name)(),
                    "frommodule.Class.method (tomodule.%s.STATIC_VALUE, tomodule.%s.instance_value)"
                    % ((_CLS_REFERENCES[destination_path], ) * 2))
        elif source_path == 'Class.class_method':
            branch_count += 1
            if destination_path in _CLS_REFERENCES and name not in (
                    'STATIC_VALUE', ):
                branch_count += 1
                self.assertEqual(
                    getattr(destination, name)(),
                    "frommodule.Class.class_method (tomodule.%s.STATIC_VALUE)"
                    % (_CLS_REFERENCES[destination_path], ))
        elif source_path == 'Class.static_method':
            branch_count += 1
            if destination_path in _CLS_REFERENCES:
                branch_count += 1
                self.assertEqual(
                    getattr(destination, name)(),
                    "frommodule.Class.static_method (frommodule.Class.STATIC_VALUE)"
                )
        elif source_path == 'Parent':
            branch_count += 1
            self.assertEqual(result.__slots__,
                             ('instance_value', 'parent_value', 'to_value',
                              'from_value'))
            self.assertEqual(result().parent_value,
                             "frommodule.Parent.parent_value")
            self.assertEqual(
                result().method(),
                "frommodule.Parent.method (frommodule.Parent.instance_value)"
            )
        elif source_path == 'Parent.method':
            branch_count += 1
            if destination_path in _CLS_REFERENCES and name not in (
                    '__init__', ):
                branch_count += 1
                self.assertEqual(
                    getattr(destination(), name)(),
                    "frommodule.Parent.method (tomodule.%s.instance_value)"
                    % (_CLS_REFERENCES[destination_path], ))
        elif source_path == 'Child':
            branch_count += 1
            self.assertEqual(result.__slots__, ('child_value', ))
            self.assertEqual(result().parent_value,
                             "frommodule.Parent.parent_value")
            self.assertEqual(result().child_value,
                             "frommodule.Child.child_value")
            self.assertEqual(
                result().method(),
                "frommodule.Parent.method (frommodule.Parent.instance_value)"
            )
        elif source_path == 'Child.method':
            branch_count += 1
            if destination_path in _CLS_REFERENCES and name not in (
                    '__init__', ):
                branch_count += 1
                self.assertEqual(
                    getattr(destination(), name)(),
                    "frommodule.Parent.method (tomodule.%s.instance_value)"
                    % (_CLS_REFERENCES[destination_path], ))

        self.tearDown()

    # Make sure that all test branches are covered.
    self.assertEqual(branch_count, 427)
def autolog():
    """
    Enables automatic logging from LightGBM to MLflow. Logs the following.

    - parameters specified in `lightgbm.train`_.
    - metrics on each iteration (if ``valid_sets`` specified).
    - metrics at the best iteration (if ``early_stopping_rounds`` specified).
    - feature importance (both "split" and "gain") as JSON files and plots.
    - trained model.

    Note that the `scikit-learn API`_ is not supported.
    """
    import lightgbm
    import numpy as np

    @gorilla.patch(lightgbm)
    def train(*args, **kwargs):

        def record_eval_results(eval_results):
            """
            Create a callback function that records evaluation results.

            Each invocation appends one ``{"<data_name>-<eval_name>": value}`` dict to
            ``eval_results``.
            """
            def callback(env):
                res = {}
                for data_name, eval_name, value, _ in env.evaluation_result_list:
                    key = data_name + '-' + eval_name
                    res[key] = value
                eval_results.append(res)
            return callback

        def with_callback(callbacks, new_callback):
            """
            Return a NEW list with ``new_callback`` appended to ``callbacks`` (may be ``None``).

            The user's list is copied rather than mutated, so repeated ``train`` calls that
            share one list do not accumulate recorder callbacks.
            """
            return ([] if callbacks is None else list(callbacks)) + [new_callback]

        def log_feature_importance_plot(features, importance, importance_type):
            """
            Log a horizontal bar plot of feature importance as the artifact
            ``feature_importance_<importance_type>.png``.
            """
            import matplotlib.pyplot as plt

            indices = np.argsort(importance)
            features = np.array(features)[indices]
            importance = importance[indices]
            num_features = len(features)

            # If num_features > 10, increase the figure height to prevent the plot
            # from being too dense.
            w, h = [6.4, 4.8]  # matplotlib's default figure size
            h = h + 0.1 * num_features if num_features > 10 else h
            fig, ax = plt.subplots(figsize=(w, h))

            yloc = np.arange(num_features)
            ax.barh(yloc, importance, align='center', height=0.5)
            ax.set_yticks(yloc)
            ax.set_yticklabels(features)
            ax.set_xlabel('Importance')
            ax.set_title('Feature Importance ({})'.format(importance_type))
            fig.tight_layout()

            tmpdir = tempfile.mkdtemp()
            try:
                # Name the file after this function's own parameter, not the caller's
                # loop variable.
                filepath = os.path.join(tmpdir,
                                        'feature_importance_{}.png'.format(importance_type))
                fig.savefig(filepath)
                try_mlflow_log(mlflow.log_artifact, filepath)
            finally:
                plt.close(fig)
                shutil.rmtree(tmpdir)

        if not mlflow.active_run():
            try_mlflow_log(mlflow.start_run)
            auto_end_run = True
        else:
            auto_end_run = False

        try:
            original = gorilla.get_original_attribute(lightgbm, 'train')

            # logging booster params separately via mlflow.log_params to extract key/value pairs
            # and make it easier to compare them across runs.
            params = args[0] if len(args) > 0 else kwargs['params']
            try_mlflow_log(mlflow.log_params, params)

            unlogged_params = ['params', 'train_set', 'valid_sets', 'valid_names', 'fobj', 'feval',
                               'init_model', 'evals_result', 'learning_rates', 'callbacks']
            log_fn_args_as_params(original, args, kwargs, unlogged_params)

            # `inspect.getargspec` is deprecated and was removed in Python 3.11.
            all_arg_names = inspect.getfullargspec(original).args
            num_pos_args = len(args)

            # adding a callback that records evaluation results.
            eval_results = []
            callbacks_index = all_arg_names.index('callbacks')
            callback = record_eval_results(eval_results)
            if num_pos_args > callbacks_index:
                args = (args[:callbacks_index]
                        + (with_callback(args[callbacks_index], callback),)
                        + args[callbacks_index + 1:])
            else:
                kwargs['callbacks'] = with_callback(kwargs.get('callbacks'), callback)

            # training model
            model = original(*args, **kwargs)

            # logging metrics on each iteration.
            for idx, metrics in enumerate(eval_results):
                try_mlflow_log(mlflow.log_metrics, metrics, step=idx)

            # If early stopping is enabled, log the metrics at the best iteration as extra
            # metrics with the max step + 1. An explicit `early_stopping_rounds=None` does NOT
            # enable early stopping.
            early_stopping_index = all_arg_names.index('early_stopping_rounds')
            if num_pos_args > early_stopping_index:
                early_stopping_rounds = args[early_stopping_index]
            else:
                early_stopping_rounds = kwargs.get('early_stopping_rounds')
            if early_stopping_rounds is not None:
                extra_step = len(eval_results)
                try_mlflow_log(mlflow.log_metric, 'stopped_iteration', len(eval_results))
                # best_iteration is set even if training does not stop early.
                try_mlflow_log(mlflow.log_metric, 'best_iteration', model.best_iteration)
                # iteration starts from 1 in LightGBM. Skip values outside the recorded
                # range instead of indexing with a wrapped or invalid position.
                if 1 <= model.best_iteration <= len(eval_results):
                    try_mlflow_log(mlflow.log_metrics,
                                   eval_results[model.best_iteration - 1],
                                   step=extra_step)

            # logging feature importance as artifacts.
            for imp_type in ['split', 'gain']:
                features = model.feature_name()
                importance = model.feature_importance(importance_type=imp_type)
                try:
                    log_feature_importance_plot(features, importance, imp_type)
                except Exception:  # pylint: disable=broad-except
                    _logger.exception('Failed to log feature importance plot. LightGBM autologging '
                                      'will ignore the failure and continue. Exception: ')

                imp = {feature: value for feature, value in zip(features, importance.tolist())}
                tmpdir = tempfile.mkdtemp()
                try:
                    filepath = os.path.join(tmpdir, 'feature_importance_{}.json'.format(imp_type))
                    with open(filepath, 'w') as f:
                        json.dump(imp, f, indent=2)
                    try_mlflow_log(mlflow.log_artifact, filepath)
                finally:
                    shutil.rmtree(tmpdir)

            try_mlflow_log(log_model, model, artifact_path='model')
        finally:
            # End a run started here even when training or logging raised.
            if auto_end_run:
                try_mlflow_log(mlflow.end_run)
        return model

    settings = gorilla.Settings(allow_hit=True, store_hit=True)
    gorilla.apply(gorilla.Patch(lightgbm, 'train', train, settings=settings))
def autolog():  # pylint: disable=E0611
    """
    Enables automatic logging from Keras to MLflow. Autologging captures the following information:

    **Metrics** and **Parameters**
      - Training loss; validation loss; user-specified metrics
      - Metrics associated with the ``EarlyStopping`` callbacks: ``stopped_epoch``,
        ``restored_epoch``, and (when weights are restored) the restored model's metrics,
        logged as one extra final step
      - ``fit()`` or ``fit_generator()`` parameters; optimizer name; learning rate; epsilon
      - ``fit()`` or ``fit_generator()`` parameters associated with ``EarlyStopping``:
        ``monitor``, ``min_delta``, ``patience``, ``baseline``, ``restore_best_weights``
    **Artifacts**
      - Model summary on training start
      - `MLflow Model <https://mlflow.org/docs/latest/models.html>`_ (Keras model) on training end

    .. code-block:: python
        :caption: Example

        import mlflow
        import mlflow.keras
        # Build, compile, enable autologging, and train your model
        keras_model = ...
        keras_model.compile(optimizer="rmsprop", loss="mse", metrics=["accuracy"])
        # autolog your metrics, parameters, and model
        mlflow.keras.autolog()
        results = keras_model.fit(
            x_train, y_train, epochs=20, batch_size=128, validation_data=(x_val, y_val))

    ``EarlyStopping Integration with Keras AutoLogging``

    MLflow will detect if an ``EarlyStopping`` callback is used in a ``fit()`` or
    ``fit_generator()`` call, and if the ``restore_best_weights`` parameter is set to be ``True``,
    then MLflow will log the metrics associated with the restored model as a final, extra step.
    The epoch of the restored model will also be logged as the metric ``restored_epoch``.
    This allows for easy comparison between the actual metrics of the restored model and
    the metrics of other models.

    If ``restore_best_weights`` is set to be ``False``, then MLflow will not log an additional step.

    Regardless of ``restore_best_weights``, MLflow will also log ``stopped_epoch``,
    which indicates the epoch at which training stopped due to early stopping.

    If training does not end due to early stopping, then ``stopped_epoch`` will be logged as ``0``.

    MLflow will also log the parameters of the ``EarlyStopping`` callback,
    excluding ``mode`` and ``verbose``.
    """
    import keras

    def _eval_optimizer_value(value):
        """Return ``value`` unchanged if it is a float, else evaluate it with the Keras backend."""
        # isinstance (not `type(...) is float`) also accepts float subclasses such as
        # numpy.float64, which would otherwise be passed to keras.backend.eval.
        return value if isinstance(value, float) else keras.backend.eval(value)

    class __MLflowKerasCallback(keras.callbacks.Callback):
        """
        Callback for auto-logging metrics and parameters.
        Records available logs after each epoch.
        Records model structural information as params when training begins
        """
        def on_train_begin(self, logs=None):  # pylint: disable=unused-argument
            try_mlflow_log(mlflow.log_param, 'num_layers', len(self.model.layers))
            try_mlflow_log(mlflow.log_param, 'optimizer_name',
                           type(self.model.optimizer).__name__)
            if hasattr(self.model.optimizer, 'lr'):
                lr = _eval_optimizer_value(self.model.optimizer.lr)
                try_mlflow_log(mlflow.log_param, 'learning_rate', lr)
            if hasattr(self.model.optimizer, 'epsilon'):
                epsilon = _eval_optimizer_value(self.model.optimizer.epsilon)
                try_mlflow_log(mlflow.log_param, 'epsilon', epsilon)

            sum_list = []
            self.model.summary(print_fn=sum_list.append)
            summary = '\n'.join(sum_list)
            tempdir = tempfile.mkdtemp()
            try:
                summary_file = os.path.join(tempdir, "model_summary.txt")
                with open(summary_file, 'w') as f:
                    f.write(summary)
                try_mlflow_log(mlflow.log_artifact, local_path=summary_file)
            finally:
                shutil.rmtree(tempdir)

        def on_epoch_end(self, epoch, logs=None):
            if not logs:
                return
            try_mlflow_log(mlflow.log_metrics, logs, step=epoch)

        def on_train_end(self, logs=None):
            try_mlflow_log(log_model, self.model, artifact_path='model')

        # As of Keras 2.4.0, Keras Callback implementations must define the following
        # methods indicating whether or not the callback overrides functions for
        # batch training/testing/inference
        def _implements_train_batch_hooks(self):
            return False

        def _implements_test_batch_hooks(self):
            return False

        def _implements_predict_batch_hooks(self):
            return False

    def _early_stop_check(callbacks):
        """Return the first ``EarlyStopping`` in ``callbacks`` (``None``-safe), else ``None``."""
        if LooseVersion(keras.__version__) < LooseVersion('2.3.0'):
            es_callback = keras.callbacks.EarlyStopping
        else:
            # Keras 2.3.x defines callbacks in the `keras.callbacks.callbacks` submodule.
            # Releases without that submodule (presumably Keras >= 2.4) expose EarlyStopping
            # directly on `keras.callbacks`.
            es_callback = getattr(keras.callbacks, 'callbacks', keras.callbacks).EarlyStopping
        for callback in callbacks or []:
            if isinstance(callback, es_callback):
                return callback
        return None

    def _log_early_stop_callback_params(callback):
        if callback:
            try:
                earlystopping_params = {
                    'monitor': callback.monitor,
                    'min_delta': callback.min_delta,
                    'patience': callback.patience,
                    'baseline': callback.baseline,
                    'restore_best_weights': callback.restore_best_weights,
                }
                try_mlflow_log(mlflow.log_params, earlystopping_params)
            except Exception:  # pylint: disable=W0703
                return

    def _get_early_stop_callback_attrs(callback):
        try:
            return callback.stopped_epoch, callback.restore_best_weights, callback.patience
        except Exception:  # pylint: disable=W0703
            return None

    def _log_early_stop_callback_metrics(callback, history):
        if callback:
            callback_attrs = _get_early_stop_callback_attrs(callback)
            if callback_attrs is None:
                return
            stopped_epoch, restore_best_weights, patience = callback_attrs
            try_mlflow_log(mlflow.log_metric, 'stopped_epoch', stopped_epoch)
            # Weights are restored only if early stopping occurs
            if stopped_epoch != 0 and restore_best_weights:
                restored_epoch = stopped_epoch - max(1, patience)
                try_mlflow_log(mlflow.log_metric, 'restored_epoch', restored_epoch)
                restored_metrics = {
                    key: history.history[key][restored_epoch]
                    for key in history.history.keys()
                }
                # Checking that a metric history exists
                metric_key = next(iter(history.history), None)
                if metric_key is not None:
                    last_epoch = len(history.history[metric_key])
                    try_mlflow_log(mlflow.log_metrics, restored_metrics, step=last_epoch)

    def _with_mlflow_callback(callbacks):
        """
        Return a NEW list: the user's ``callbacks`` (``None`` is treated as empty) followed by
        an MLflow callback. The caller's list is not mutated, so reusing it across calls does
        not accumulate MLflow callbacks.
        """
        return ([] if callbacks is None else list(callbacks)) + [__MLflowKerasCallback()]

    def _run_and_log_function(self, original, args, kwargs, unlogged_params, callback_arg_index):
        """
        Run ``original(self, *args, **kwargs)`` with MLflow logging attached.

        ``callback_arg_index`` is the position of ``callbacks`` in ``args`` (``self``
        excluded). A run started here is ended afterwards, even if training raises.
        """
        if not mlflow.active_run():
            try_mlflow_log(mlflow.start_run)
            auto_end_run = True
        else:
            auto_end_run = False

        try:
            log_fn_args_as_params(original, args, kwargs, unlogged_params)

            # Checking if the 'callback' argument of the function is set
            if len(args) > callback_arg_index:
                user_callbacks = args[callback_arg_index]
                early_stop_callback = _early_stop_check(user_callbacks)
                args = (args[:callback_arg_index]
                        + (_with_mlflow_callback(user_callbacks),)
                        + args[callback_arg_index + 1:])
            else:
                user_callbacks = kwargs.get('callbacks')
                early_stop_callback = _early_stop_check(user_callbacks)
                kwargs['callbacks'] = _with_mlflow_callback(user_callbacks)

            _log_early_stop_callback_params(early_stop_callback)

            history = original(self, *args, **kwargs)

            _log_early_stop_callback_metrics(early_stop_callback, history)
        finally:
            if auto_end_run:
                try_mlflow_log(mlflow.end_run)

        return history

    @gorilla.patch(keras.Model)
    def fit(self, *args, **kwargs):
        original = gorilla.get_original_attribute(keras.Model, 'fit')
        unlogged_params = ['self', 'x', 'y', 'callbacks', 'validation_data', 'verbose']
        return _run_and_log_function(self, original, args, kwargs, unlogged_params, 5)

    @gorilla.patch(keras.Model)
    def fit_generator(self, *args, **kwargs):
        original = gorilla.get_original_attribute(keras.Model, 'fit_generator')
        unlogged_params = ['self', 'generator', 'callbacks', 'validation_data', 'verbose']
        return _run_and_log_function(self, original, args, kwargs, unlogged_params, 4)

    settings = gorilla.Settings(allow_hit=True, store_hit=True)
    gorilla.apply(gorilla.Patch(keras.Model, 'fit', fit, settings=settings))
    gorilla.apply(gorilla.Patch(keras.Model, 'fit_generator', fit_generator, settings=settings))
for trial in filter(lambda x: x.status == Trial.TERMINATED, self._trials): if not trial.last_result: continue best_top1_acc = max(best_top1_acc, trial.last_result['top1_valid']) print('iter', self._iteration, 'top1_acc=%.3f' % best_top1_acc, cnts, end='\r') return original(self) patch = gorilla.Patch(ray.tune.trial_runner.TrialRunner, 'step', step_w_log, settings=gorilla.Settings(allow_hit=True)) gorilla.apply(patch) logger = get_logger('Fast AutoAugment') def _get_path(dataset, model, tag): return os.path.join(os.path.dirname(os.path.realpath(__file__)), 'FastAutoAugment', 'models/%s_%s_%s.pt' % (dataset, model, tag)) # TODO # @ray.remote(num_gpus=4, max_calls=1) #TODO: change to num_gpus=1 ??? # @ray.remote def train_model(config, dataroot,
def autolog():
    """
    Enables automatic logging from LightGBM to MLflow. Logs the following.

    - parameters specified in `lightgbm.train`_.
    - metrics on each iteration (if ``valid_sets`` specified).
    - metrics at the best iteration (if ``early_stopping_rounds`` specified).
    - feature importance (both "split" and "gain").
    - trained model.

    Note that the `scikit-learn API`_ is not supported.
    """
    import lightgbm

    @gorilla.patch(lightgbm)
    def train(*args, **kwargs):

        def record_eval_results(eval_results):
            """
            Create a callback function that records evaluation results.

            Each invocation appends one ``{"<data_name>-<eval_name>": value}`` dict to
            ``eval_results``.
            """
            def callback(env):
                res = {}
                for data_name, eval_name, value, _ in env.evaluation_result_list:
                    key = data_name + '-' + eval_name
                    res[key] = value
                eval_results.append(res)
            return callback

        def with_callback(callbacks, new_callback):
            """
            Return a NEW list with ``new_callback`` appended to ``callbacks`` (may be ``None``).

            The user's list is copied rather than mutated, so repeated ``train`` calls that
            share one list do not accumulate recorder callbacks.
            """
            return ([] if callbacks is None else list(callbacks)) + [new_callback]

        if not mlflow.active_run():
            try_mlflow_log(mlflow.start_run)
            auto_end_run = True
        else:
            auto_end_run = False

        try:
            original = gorilla.get_original_attribute(lightgbm, 'train')

            # logging booster params separately via mlflow.log_params to extract key/value pairs
            # and make it easier to compare them across runs.
            params = args[0] if len(args) > 0 else kwargs['params']
            try_mlflow_log(mlflow.log_params, params)

            unlogged_params = ['params', 'train_set', 'valid_sets', 'valid_names', 'fobj', 'feval',
                               'init_model', 'evals_result', 'learning_rates', 'callbacks']
            log_fn_args_as_params(original, args, kwargs, unlogged_params)

            # `inspect.getargspec` is deprecated and was removed in Python 3.11.
            all_arg_names = inspect.getfullargspec(original).args
            num_pos_args = len(args)

            # adding a callback that records evaluation results.
            eval_results = []
            callbacks_index = all_arg_names.index('callbacks')
            callback = record_eval_results(eval_results)
            if num_pos_args > callbacks_index:
                args = (args[:callbacks_index]
                        + (with_callback(args[callbacks_index], callback),)
                        + args[callbacks_index + 1:])
            else:
                kwargs['callbacks'] = with_callback(kwargs.get('callbacks'), callback)

            # training model
            model = original(*args, **kwargs)

            # logging metrics on each iteration.
            for idx, metrics in enumerate(eval_results):
                try_mlflow_log(mlflow.log_metrics, metrics, step=idx)

            # If early stopping is enabled, log the metrics at the best iteration as extra
            # metrics with the max step + 1. An explicit `early_stopping_rounds=None` does NOT
            # enable early stopping.
            early_stopping_index = all_arg_names.index('early_stopping_rounds')
            if num_pos_args > early_stopping_index:
                early_stopping_rounds = args[early_stopping_index]
            else:
                early_stopping_rounds = kwargs.get('early_stopping_rounds')
            if early_stopping_rounds is not None:
                extra_step = len(eval_results)
                try_mlflow_log(mlflow.log_metric, 'stopped_iteration', len(eval_results))
                # best_iteration is set even if training does not stop early.
                try_mlflow_log(mlflow.log_metric, 'best_iteration', model.best_iteration)
                # iteration starts from 1 in LightGBM. Skip values outside the recorded
                # range instead of indexing with a wrapped or invalid position.
                if 1 <= model.best_iteration <= len(eval_results):
                    try_mlflow_log(mlflow.log_metrics,
                                   eval_results[model.best_iteration - 1],
                                   step=extra_step)

            # logging feature importance as artifacts.
            for imp_type in ['split', 'gain']:
                features = model.feature_name()
                importance = model.feature_importance(importance_type=imp_type)
                imp = {feature: value for feature, value in zip(features, importance.tolist())}
                tmpdir = tempfile.mkdtemp()
                try:
                    filepath = os.path.join(tmpdir, 'feature_importance_{}.json'.format(imp_type))
                    with open(filepath, 'w') as f:
                        json.dump(imp, f)
                    try_mlflow_log(mlflow.log_artifact, filepath)
                finally:
                    shutil.rmtree(tmpdir)

            try_mlflow_log(log_model, model, artifact_path='model')
        finally:
            # End a run started here even when training or logging raised.
            if auto_end_run:
                try_mlflow_log(mlflow.end_run)
        return model

    settings = gorilla.Settings(allow_hit=True, store_hit=True)
    gorilla.apply(gorilla.Patch(lightgbm, 'train', train, settings=settings))
def autolog(every_n_iter=100):  # pylint: disable=E0611
    """
    Enable automatic logging from TensorFlow to MLflow.
    If applicable, model checkpoints are logged as artifacts to a 'models' directory, along
    with any TensorBoard log data.

    Refer to the tracking documentation for information on what is logged with different
    TensorFlow workflows.

    :param every_n_iter: The frequency with which metrics should be logged.
                                  Defaults to 100. Ex: a value of 100 will log metrics
                                  at step 0, 100, 200, etc.
    """
    global _LOG_EVERY_N_STEPS
    _LOG_EVERY_N_STEPS = every_n_iter

    # The two pieces are concatenated, so the first must end with a space.
    unsupported_msg = ("Could not log to MLflow. Only TensorFlow versions " +
                       "1.12 <= v <= 2.0.0 are supported.")

    if LooseVersion(tensorflow.__version__) < LooseVersion('1.12'):
        warnings.warn(unsupported_msg)
        return

    try:
        from tensorflow.python.summary.writer.event_file_writer import EventFileWriter
        from tensorflow.python.summary.writer.event_file_writer_v2 import EventFileWriterV2
        from tensorflow.python.saved_model import tag_constants
        from tensorflow.python.summary.writer.writer import FileWriter
    except ImportError:
        warnings.warn(unsupported_msg)
        return

    @contextmanager
    def _manage_active_run():
        """
        Start an MLflow run if none is active and remember its ID as the autolog run.
        After the ``with`` body completes normally, end the active run if it is that
        autolog run (an exception raised by the body skips this step).
        """
        global _AUTOLOG_RUN_ID
        if not mlflow.active_run():
            try_mlflow_log(mlflow.start_run)
            if mlflow.active_run() is not None:  # defensive check in case `mlflow.start_run` fails
                _AUTOLOG_RUN_ID = mlflow.active_run().info.run_id
        yield mlflow.active_run()
        if mlflow.active_run() is not None and \
                mlflow.active_run().info.run_id == _AUTOLOG_RUN_ID:
            try_mlflow_log(mlflow.end_run)

    @gorilla.patch(tensorflow.estimator.Estimator)
    def train(self, *args, **kwargs):
        with _manage_active_run():
            original = gorilla.get_original_attribute(tensorflow.estimator.Estimator, 'train')

            # Checking step and max_step parameters for logging
            if len(args) >= 3:
                try_mlflow_log(mlflow.log_param, 'steps', args[2])
            if len(args) >= 4:
                try_mlflow_log(mlflow.log_param, 'max_steps', args[3])
            if 'steps' in kwargs:
                try_mlflow_log(mlflow.log_param, 'steps', kwargs['steps'])
            if 'max_steps' in kwargs:
                try_mlflow_log(mlflow.log_param, 'max_steps', kwargs['max_steps'])

            result = original(self, *args, **kwargs)

            return result

    def _export_and_log_model(self, method_name, args, kwargs):
        """
        Call the original ``Estimator.<method_name>`` export and log the exported
        SavedModel directory as an MLflow model under the ``model`` artifact path.

        With no active run, the remembered autolog run is resumed if there is one;
        otherwise a new run is started. The run is ended afterwards when it is the
        autolog run or was started here.
        """
        auto_end = False
        # ``_AUTOLOG_RUN_ID`` is only read here, so it resolves to the module global.
        if not mlflow.active_run():
            if _AUTOLOG_RUN_ID:
                try_mlflow_log(mlflow.start_run, _AUTOLOG_RUN_ID)
            else:
                try_mlflow_log(mlflow.start_run)
                auto_end = True

        original = gorilla.get_original_attribute(tensorflow.estimator.Estimator, method_name)
        serialized = original(self, *args, **kwargs)
        try_mlflow_log(log_model, tf_saved_model_dir=serialized.decode('utf-8'),
                       tf_meta_graph_tags=[tag_constants.SERVING],
                       tf_signature_def_key='predict',
                       artifact_path='model')
        if (mlflow.active_run() is not None and
                mlflow.active_run().info.run_id == _AUTOLOG_RUN_ID) or auto_end:
            try_mlflow_log(mlflow.end_run)
        return serialized

    @gorilla.patch(tensorflow.estimator.Estimator)
    def export_saved_model(self, *args, **kwargs):
        return _export_and_log_model(self, 'export_saved_model', args, kwargs)

    @gorilla.patch(tensorflow.estimator.Estimator)
    def export_savedmodel(self, *args, **kwargs):
        return _export_and_log_model(self, 'export_savedmodel', args, kwargs)

    def _fit_with_tensorboard(self, method_name, callback_arg_index, args, kwargs):
        """
        Run the original ``tensorflow.keras.Model.<method_name>`` with the callbacks
        returned by ``_setup_callbacks``, then flush queued metrics, upload the
        TensorBoard log directory and delete it.

        ``callback_arg_index`` is the position of ``callbacks`` in ``args``.
        """
        with _manage_active_run():
            original = gorilla.get_original_attribute(tensorflow.keras.Model, method_name)

            # Checking if the 'callbacks' argument is set. ``callbacks=None`` is treated
            # like an omitted argument, for which ``[]`` is passed to _setup_callbacks.
            if len(args) > callback_arg_index:
                tmp_list = list(args)
                user_callbacks = tmp_list[callback_arg_index]
                tmp_list[callback_arg_index], log_dir = _setup_callbacks(
                    [] if user_callbacks is None else user_callbacks)
                args = tuple(tmp_list)
            else:
                user_callbacks = kwargs.get('callbacks')
                kwargs['callbacks'], log_dir = _setup_callbacks(
                    [] if user_callbacks is None else user_callbacks)

            result = original(self, *args, **kwargs)

            _flush_queue()
            _log_artifacts_with_warning(local_dir=log_dir, artifact_path='tensorboard_logs')
            # NOTE(review): log_dir is removed unconditionally; confirm _setup_callbacks
            # always returns a temporary directory here.
            shutil.rmtree(log_dir)

            return result

    @gorilla.patch(tensorflow.keras.Model)
    def fit(self, *args, **kwargs):
        return _fit_with_tensorboard(self, 'fit', 5, args, kwargs)

    @gorilla.patch(tensorflow.keras.Model)
    def fit_generator(self, *args, **kwargs):
        return _fit_with_tensorboard(self, 'fit_generator', 4, args, kwargs)

    def _make_add_event(writer_cls):
        """
        Build an ``add_event`` patch for ``writer_cls`` that logs the event and then
        delegates to ``writer_cls``'s own original ``add_event`` (one shared function
        would send every writer class to a single class's original implementation).
        """

        @gorilla.patch(writer_cls)
        def add_event(self, event):
            _log_event(event)
            original = gorilla.get_original_attribute(writer_cls, 'add_event')
            return original(self, event)

        return add_event

    @gorilla.patch(FileWriter)
    def add_summary(self, *args, **kwargs):
        original = gorilla.get_original_attribute(FileWriter, 'add_summary')
        result = original(self, *args, **kwargs)
        _flush_queue()
        return result

    settings = gorilla.Settings(allow_hit=True, store_hit=True)
    patches = [
        gorilla.Patch(EventFileWriter, 'add_event', _make_add_event(EventFileWriter),
                      settings=settings),
        gorilla.Patch(EventFileWriterV2, 'add_event', _make_add_event(EventFileWriterV2),
                      settings=settings),
        gorilla.Patch(tensorflow.estimator.Estimator, 'train', train, settings=settings),
        gorilla.Patch(tensorflow.keras.Model, 'fit', fit, settings=settings),
        gorilla.Patch(tensorflow.keras.Model, 'fit_generator', fit_generator, settings=settings),
        gorilla.Patch(tensorflow.estimator.Estimator, 'export_saved_model',
                      export_saved_model, settings=settings),
        gorilla.Patch(tensorflow.estimator.Estimator, 'export_savedmodel',
                      export_savedmodel, settings=settings),
        gorilla.Patch(FileWriter, 'add_summary', add_summary, settings=settings),
    ]

    for x in patches:
        gorilla.apply(x)
def autolog():
    """
    Enable automatic logging from Keras to MLflow.
    Logs loss and any other metrics specified in the fit
    function, and optimizer data as parameters. Model checkpoints
    are logged as artifacts to a 'models' directory.

    EarlyStopping Integration with Keras Automatic Logging

    MLflow will detect if an ``EarlyStopping`` callback is used in a ``fit()``/``fit_generator()``
    call, and if the ``restore_best_weights`` parameter is set to be ``True``, then MLflow will
    log the metrics associated with the restored model as a final, extra step. The epoch of the
    restored model will also be logged as the metric ``restored_epoch``.
    This allows for easy comparison between the actual metrics of the restored model and
    the metrics of other models.

    If ``restore_best_weights`` is set to be ``False``,
    then MLflow will not log an additional step.

    Regardless of ``restore_best_weights``, MLflow will also log ``stopped_epoch``,
    which indicates the epoch at which training stopped due to early stopping.

    If training does not end due to early stopping, then ``stopped_epoch`` will be logged as ``0``.

    MLflow will also log the parameters of the EarlyStopping callback,
    excluding ``mode`` and ``verbose``.
    """
    import keras

    def _optimizer_value(value):
        """
        Return ``value`` unchanged when it is already a float (including float
        subclasses); otherwise evaluate it with ``keras.backend.eval``.
        """
        return value if isinstance(value, float) else keras.backend.eval(value)

    class __MLflowKerasCallback(keras.callbacks.Callback):
        """
        Callback for auto-logging metrics and parameters.
        Records available logs after each epoch.
        Records model structural information as params when training begins
        """

        def on_train_begin(self, logs=None):  # pylint: disable=unused-argument
            optimizer = self.model.optimizer
            try_mlflow_log(mlflow.log_param, 'num_layers', len(self.model.layers))
            try_mlflow_log(mlflow.log_param, 'optimizer_name', type(optimizer).__name__)
            if hasattr(optimizer, 'lr'):
                try_mlflow_log(mlflow.log_param, 'learning_rate', _optimizer_value(optimizer.lr))
            if hasattr(optimizer, 'epsilon'):
                try_mlflow_log(mlflow.log_param, 'epsilon', _optimizer_value(optimizer.epsilon))
            sum_list = []
            self.model.summary(print_fn=sum_list.append)
            summary = '\n'.join(sum_list)
            tempdir = tempfile.mkdtemp()
            try:
                summary_file = os.path.join(tempdir, "model_summary.txt")
                with open(summary_file, 'w') as f:
                    f.write(summary)
                try_mlflow_log(mlflow.log_artifact, local_path=summary_file)
            finally:
                shutil.rmtree(tempdir)

        def on_epoch_end(self, epoch, logs=None):
            if not logs:
                return
            try_mlflow_log(mlflow.log_metrics, logs, step=epoch)

        def on_train_end(self, logs=None):
            try_mlflow_log(log_model, self.model, artifact_path='model')

    def _early_stop_check(callbacks):
        """Return the first ``EarlyStopping`` instance in ``callbacks``, or ``None``."""
        if LooseVersion(keras.__version__) < LooseVersion('2.3.0'):
            es_callback = keras.callbacks.EarlyStopping
        else:
            es_callback = keras.callbacks.callbacks.EarlyStopping
        for callback in callbacks:
            if isinstance(callback, es_callback):
                return callback
        return None

    def _log_early_stop_callback_params(callback):
        """Log the EarlyStopping configuration as params (best effort)."""
        if callback:
            try:
                earlystopping_params = {
                    'monitor': callback.monitor,
                    'min_delta': callback.min_delta,
                    'patience': callback.patience,
                    'baseline': callback.baseline,
                    'restore_best_weights': callback.restore_best_weights
                }
                try_mlflow_log(mlflow.log_params, earlystopping_params)
            except Exception:  # pylint: disable=W0703
                return

    def _get_early_stop_callback_attrs(callback):
        """Return ``(stopped_epoch, restore_best_weights, patience)`` or ``None``."""
        try:
            return callback.stopped_epoch, callback.restore_best_weights, callback.patience
        except Exception:  # pylint: disable=W0703
            return None

    def _log_early_stop_callback_metrics(callback, history):
        """Log ``stopped_epoch`` and, when weights were restored, the restored metrics."""
        if callback:
            callback_attrs = _get_early_stop_callback_attrs(callback)
            if callback_attrs is None:
                return
            stopped_epoch, restore_best_weights, patience = callback_attrs
            try_mlflow_log(mlflow.log_metric, 'stopped_epoch', stopped_epoch)
            # Weights are restored only if early stopping occurs
            if stopped_epoch != 0 and restore_best_weights:
                restored_epoch = stopped_epoch - max(1, patience)
                try_mlflow_log(mlflow.log_metric, 'restored_epoch', restored_epoch)
                restored_metrics = {key: history.history[key][restored_epoch]
                                    for key in history.history.keys()}
                # Checking that a metric history exists
                metric_key = next(iter(history.history), None)
                if metric_key is not None:
                    last_epoch = len(history.history[metric_key])
                    try_mlflow_log(mlflow.log_metrics, restored_metrics, step=last_epoch)

    def _run_and_log_function(self, original, args, kwargs, unlogged_params, callback_arg_index):
        """
        Run ``original`` with an MLflow callback appended to its callbacks, logging
        its arguments and EarlyStopping information. Starts (and afterwards ends) a
        run when none is active. ``callback_arg_index`` is the position of
        ``callbacks`` in ``args``.
        """
        if not mlflow.active_run():
            try_mlflow_log(mlflow.start_run)
            auto_end_run = True
        else:
            auto_end_run = False

        log_fn_args_as_params(original, args, kwargs, unlogged_params)

        # Checking if the 'callback' argument of the function is set. A new list is
        # built rather than extending with ``+=`` so the caller's list is not mutated,
        # and ``callbacks=None`` (Keras' default) is treated as "no callbacks".
        if len(args) > callback_arg_index:
            tmp_list = list(args)
            user_callbacks = tmp_list[callback_arg_index] or []
            early_stop_callback = _early_stop_check(user_callbacks)
            tmp_list[callback_arg_index] = list(user_callbacks) + [__MLflowKerasCallback()]
            args = tuple(tmp_list)
        else:
            user_callbacks = kwargs.get('callbacks') or []
            early_stop_callback = _early_stop_check(user_callbacks)
            kwargs['callbacks'] = list(user_callbacks) + [__MLflowKerasCallback()]

        _log_early_stop_callback_params(early_stop_callback)

        history = original(self, *args, **kwargs)

        _log_early_stop_callback_metrics(early_stop_callback, history)

        if auto_end_run:
            try_mlflow_log(mlflow.end_run)

        return history

    @gorilla.patch(keras.Model)
    def fit(self, *args, **kwargs):
        original = gorilla.get_original_attribute(keras.Model, 'fit')
        unlogged_params = ['self', 'x', 'y', 'callbacks', 'validation_data', 'verbose']
        return _run_and_log_function(self, original, args, kwargs, unlogged_params, 5)

    @gorilla.patch(keras.Model)
    def fit_generator(self, *args, **kwargs):
        original = gorilla.get_original_attribute(keras.Model, 'fit_generator')
        unlogged_params = ['self', 'generator', 'callbacks', 'validation_data', 'verbose']
        return _run_and_log_function(self, original, args, kwargs, unlogged_params, 4)

    settings = gorilla.Settings(allow_hit=True, store_hit=True)
    gorilla.apply(gorilla.Patch(keras.Model, 'fit', fit, settings=settings))
    gorilla.apply(gorilla.Patch(keras.Model, 'fit_generator', fit_generator, settings=settings))
def autolog(every_n_iter=100):  # pylint: disable=E0611
    """
    Enables automatic logging from TensorFlow to MLflow.
    Note that autologging for ``tf.keras`` is handled by :py:func:`mlflow.tensorflow.autolog`,
    not :py:func:`mlflow.keras.autolog`.
    As an example, try running the
    `TensorFlow examples <https://github.com/mlflow/mlflow/tree/master/examples/tensorflow>`_.

    For each TensorFlow module, autologging captures the following information:

    **tf.keras**
     - **Metrics** and **Parameters**
      - Training loss; validation loss; user-specified metrics
      - ``fit()`` or ``fit_generator()`` parameters; optimizer name; learning rate; epsilon
     - **Artifacts**
      - Model summary on training start
      - `MLflow Model <https://mlflow.org/docs/latest/models.html>`_ (Keras model)
      - TensorBoard logs on training end

    **tf.keras.callbacks.EarlyStopping**
     - **Metrics** and **Parameters**
      - Metrics from the ``EarlyStopping`` callbacks: ``stopped_epoch``, ``restored_epoch``,
        ``restore_best_weight``, etc
      - ``fit()`` or ``fit_generator()`` parameters associated with ``EarlyStopping``:
        ``min_delta``, ``patience``, ``baseline``, ``restore_best_weights``, etc

    **tf.estimator**
     - **Metrics** and **Parameters**
      - TensorBoard metrics: ``average_loss``, ``loss``, etc
      - Parameters ``steps`` and ``max_steps``
     - **Artifacts**
      - `MLflow Model <https://mlflow.org/docs/latest/models.html>`_ (TF saved model) on call
        to ``tf.estimator.export_saved_model``

    **TensorFlow Core**
     - **Metrics**
      - All ``tf.summary.scalar`` calls

    Refer to the autologging tracking documentation for more
    information on `TensorFlow workflows
    <https://www.mlflow.org/docs/latest/tracking.html#tensorflow-and-keras-experimental>`_.

    :param every_n_iter: The frequency with which metrics should be logged.
                                  Defaults to 100. Ex: a value of 100 will log metrics
                                  at step 0, 100, 200, etc.
    """
    import tensorflow

    global _LOG_EVERY_N_STEPS
    _LOG_EVERY_N_STEPS = every_n_iter

    # The two pieces are concatenated, so the first must end with a space.
    unsupported_msg = ("Could not log to MLflow. Only TensorFlow versions " +
                       "1.12 <= v <= 2.0.0 are supported.")

    if LooseVersion(tensorflow.__version__) < LooseVersion("1.12"):
        warnings.warn(unsupported_msg)
        return

    try:
        from tensorflow.python.summary.writer.event_file_writer import EventFileWriter
        from tensorflow.python.summary.writer.event_file_writer_v2 import EventFileWriterV2
        from tensorflow.python.saved_model import tag_constants
        from tensorflow.python.summary.writer.writer import FileWriter
    except ImportError:
        warnings.warn(unsupported_msg)
        return

    @contextmanager
    def _manage_active_run():
        """
        Start an MLflow run if none is active and remember its ID as the autolog run.
        After the ``with`` body completes normally, end the active run if it is that
        autolog run (an exception raised by the body skips this step).
        """
        global _AUTOLOG_RUN_ID
        if not mlflow.active_run():
            try_mlflow_log(mlflow.start_run)
            if mlflow.active_run() is not None:  # defensive check in case `mlflow.start_run` fails
                _AUTOLOG_RUN_ID = mlflow.active_run().info.run_id
        yield mlflow.active_run()
        if mlflow.active_run() is not None and \
                mlflow.active_run().info.run_id == _AUTOLOG_RUN_ID:
            try_mlflow_log(mlflow.end_run)

    @gorilla.patch(tensorflow.estimator.Estimator)
    def train(self, *args, **kwargs):
        with _manage_active_run():
            original = gorilla.get_original_attribute(tensorflow.estimator.Estimator, "train")

            # Checking step and max_step parameters for logging
            if len(args) >= 3:
                try_mlflow_log(mlflow.log_param, "steps", args[2])
            if len(args) >= 4:
                try_mlflow_log(mlflow.log_param, "max_steps", args[3])
            if "steps" in kwargs:
                try_mlflow_log(mlflow.log_param, "steps", kwargs["steps"])
            if "max_steps" in kwargs:
                try_mlflow_log(mlflow.log_param, "max_steps", kwargs["max_steps"])

            result = original(self, *args, **kwargs)

            return result

    def _export_and_log_model(self, method_name, args, kwargs):
        """
        Call the original ``Estimator.<method_name>`` export and log the exported
        SavedModel directory as an MLflow model under the ``model`` artifact path.

        With no active run, the remembered autolog run is resumed if there is one;
        otherwise a new run is started. The run is ended afterwards when it is the
        autolog run or was started here.
        """
        auto_end = False
        # ``_AUTOLOG_RUN_ID`` is only read here, so it resolves to the module global.
        if not mlflow.active_run():
            if _AUTOLOG_RUN_ID:
                try_mlflow_log(mlflow.start_run, _AUTOLOG_RUN_ID)
            else:
                try_mlflow_log(mlflow.start_run)
                auto_end = True

        original = gorilla.get_original_attribute(tensorflow.estimator.Estimator, method_name)
        serialized = original(self, *args, **kwargs)
        try_mlflow_log(
            log_model,
            tf_saved_model_dir=serialized.decode("utf-8"),
            tf_meta_graph_tags=[tag_constants.SERVING],
            tf_signature_def_key="predict",
            artifact_path="model",
        )
        if (mlflow.active_run() is not None and
                mlflow.active_run().info.run_id == _AUTOLOG_RUN_ID) or auto_end:
            try_mlflow_log(mlflow.end_run)
        return serialized

    @gorilla.patch(tensorflow.estimator.Estimator)
    def export_saved_model(self, *args, **kwargs):
        return _export_and_log_model(self, "export_saved_model", args, kwargs)

    @gorilla.patch(tensorflow.estimator.Estimator)
    def export_savedmodel(self, *args, **kwargs):
        return _export_and_log_model(self, "export_savedmodel", args, kwargs)

    def _early_stop_check(callbacks):
        """Return the first ``tf.keras`` ``EarlyStopping`` in ``callbacks``, or ``None``."""
        for callback in callbacks:
            if isinstance(callback, tensorflow.keras.callbacks.EarlyStopping):
                return callback
        return None

    def _log_early_stop_callback_params(callback):
        """Log the EarlyStopping configuration as params (best effort)."""
        if callback:
            try:
                earlystopping_params = {
                    "monitor": callback.monitor,
                    "min_delta": callback.min_delta,
                    "patience": callback.patience,
                    "baseline": callback.baseline,
                    "restore_best_weights": callback.restore_best_weights,
                }
                try_mlflow_log(mlflow.log_params, earlystopping_params)
            except Exception:  # pylint: disable=W0703
                return

    def _get_early_stop_callback_attrs(callback):
        """Return ``(stopped_epoch, restore_best_weights, patience)`` or ``None``."""
        try:
            return callback.stopped_epoch, callback.restore_best_weights, callback.patience
        except Exception:  # pylint: disable=W0703
            return None

    def _log_early_stop_callback_metrics(callback, history):
        """Log ``stopped_epoch`` and, when weights were restored, the restored metrics."""
        if callback:
            callback_attrs = _get_early_stop_callback_attrs(callback)
            if callback_attrs is None:
                return
            stopped_epoch, restore_best_weights, patience = callback_attrs
            try_mlflow_log(mlflow.log_metric, "stopped_epoch", stopped_epoch)
            # Weights are restored only if early stopping occurs
            if stopped_epoch != 0 and restore_best_weights:
                restored_epoch = stopped_epoch - max(1, patience)
                try_mlflow_log(mlflow.log_metric, "restored_epoch", restored_epoch)
                restored_metrics = {
                    key: history.history[key][restored_epoch]
                    for key in history.history.keys()
                }
                # Metrics are logged as 'epoch_loss' and 'epoch_acc' in TF 1.X
                if LooseVersion(tensorflow.__version__) < LooseVersion("2.0.0"):
                    if "loss" in restored_metrics:
                        restored_metrics["epoch_loss"] = restored_metrics.pop("loss")
                    if "acc" in restored_metrics:
                        restored_metrics["epoch_acc"] = restored_metrics.pop("acc")
                # Checking that a metric history exists
                metric_key = next(iter(history.history), None)
                if metric_key is not None:
                    last_epoch = len(history.history[metric_key])
                    try_mlflow_log(mlflow.log_metrics, restored_metrics, step=last_epoch)

    def _fit_with_autologging(self, method_name, unlogged_params, callback_arg_index,
                              args, kwargs):
        """
        Shared implementation of the patched ``fit`` / ``fit_generator``. It logs the
        call's arguments and EarlyStopping information, and runs the original method
        with the callbacks returned by ``_setup_callbacks``. Afterwards it flushes
        queued metrics and uploads the TensorBoard logs; the log directory is removed
        only if it is temporary.

        ``callback_arg_index`` is the position of ``callbacks`` in ``args``.
        """
        with _manage_active_run():
            original = gorilla.get_original_attribute(tensorflow.keras.Model, method_name)
            log_fn_args_as_params(original, args, kwargs, unlogged_params)

            # Checking if the 'callbacks' argument is set. ``callbacks=None`` is treated
            # like an omitted argument, for which ``[]`` is used.
            if len(args) > callback_arg_index:
                tmp_list = list(args)
                user_callbacks = tmp_list[callback_arg_index]
                if user_callbacks is None:
                    user_callbacks = []
                early_stop_callback = _early_stop_check(user_callbacks)
                tmp_list[callback_arg_index], log_dir = _setup_callbacks(user_callbacks)
                args = tuple(tmp_list)
            else:
                user_callbacks = kwargs.get("callbacks")
                if user_callbacks is None:
                    user_callbacks = []
                early_stop_callback = _early_stop_check(user_callbacks)
                kwargs["callbacks"], log_dir = _setup_callbacks(user_callbacks)

            _log_early_stop_callback_params(early_stop_callback)

            history = original(self, *args, **kwargs)

            _log_early_stop_callback_metrics(early_stop_callback, history)

            _flush_queue()
            _log_artifacts_with_warning(local_dir=log_dir.location,
                                        artifact_path="tensorboard_logs")
            if log_dir.is_temp:
                shutil.rmtree(log_dir.location)

            return history

    @gorilla.patch(tensorflow.keras.Model)
    def fit(self, *args, **kwargs):
        unlogged_params = ["self", "x", "y", "callbacks", "validation_data", "verbose"]
        return _fit_with_autologging(self, "fit", unlogged_params, 5, args, kwargs)

    @gorilla.patch(tensorflow.keras.Model)
    def fit_generator(self, *args, **kwargs):
        unlogged_params = ["self", "generator", "callbacks", "validation_data", "verbose"]
        return _fit_with_autologging(self, "fit_generator", unlogged_params, 4, args, kwargs)

    def _make_add_event(writer_cls):
        """
        Build an ``add_event`` patch for ``writer_cls`` that logs the event and then
        delegates to ``writer_cls``'s own original ``add_event`` (one shared function
        would send every writer class to a single class's original implementation).
        """

        @gorilla.patch(writer_cls)
        def add_event(self, event):
            _log_event(event)
            original = gorilla.get_original_attribute(writer_cls, "add_event")
            return original(self, event)

        return add_event

    @gorilla.patch(FileWriter)
    def add_summary(self, *args, **kwargs):
        original = gorilla.get_original_attribute(FileWriter, "add_summary")
        result = original(self, *args, **kwargs)
        _flush_queue()
        return result

    settings = gorilla.Settings(allow_hit=True, store_hit=True)
    patches = [
        gorilla.Patch(EventFileWriter, "add_event", _make_add_event(EventFileWriter),
                      settings=settings),
        gorilla.Patch(EventFileWriterV2, "add_event", _make_add_event(EventFileWriterV2),
                      settings=settings),
        gorilla.Patch(tensorflow.estimator.Estimator, "train", train, settings=settings),
        gorilla.Patch(tensorflow.keras.Model, "fit", fit, settings=settings),
        gorilla.Patch(tensorflow.keras.Model, "fit_generator", fit_generator, settings=settings),
        gorilla.Patch(
            tensorflow.estimator.Estimator,
            "export_saved_model",
            export_saved_model,
            settings=settings,
        ),
        gorilla.Patch(
            tensorflow.estimator.Estimator,
            "export_savedmodel",
            export_savedmodel,
            settings=settings,
        ),
        gorilla.Patch(FileWriter, "add_summary", add_summary, settings=settings),
    ]

    for x in patches:
        gorilla.apply(x)
def autolog(every_n_iter=100):  # pylint: disable=E0611
    """
    Enable automatic logging from TensorFlow to MLflow.
    If applicable, model checkpoints are logged as artifacts to a 'models' directory, along
    with any TensorBoard log data.

    Refer to the tracking documentation for information on what is logged with different
    TensorFlow workflows.

    :param every_n_iter: The frequency with which metrics should be logged.
                                  Defaults to 100. Ex: a value of 100 will log metrics
                                  at step 0, 100, 200, etc.
    """
    global _LOG_EVERY_N_STEPS
    _LOG_EVERY_N_STEPS = every_n_iter

    # The two pieces are concatenated, so the first must end with a space.
    unsupported_msg = ("Could not log to MLflow. Only TensorFlow versions " +
                       "1.12 <= v <= 2.0.0 are supported.")

    if LooseVersion(tensorflow.__version__) < LooseVersion('1.12'):
        warnings.warn(unsupported_msg)
        return

    try:
        from tensorflow.python.summary.writer.event_file_writer import EventFileWriter
        from tensorflow.python.summary.writer.event_file_writer_v2 import EventFileWriterV2
        from tensorflow.python.saved_model import tag_constants
        from tensorflow.python.summary.writer.writer import FileWriter
    except ImportError:
        warnings.warn(unsupported_msg)
        return

    def _export_and_log_model(self, method_name, args, kwargs):
        """
        Call the original ``Estimator.<method_name>`` export and log the exported
        SavedModel directory as an MLflow model under the ``model`` artifact path.
        """
        original = gorilla.get_original_attribute(tensorflow.estimator.Estimator, method_name)
        serialized = original(self, *args, **kwargs)
        try_mlflow_log(log_model, tf_saved_model_dir=serialized.decode('utf-8'),
                       tf_meta_graph_tags=[tag_constants.SERVING],
                       tf_signature_def_key='predict',
                       artifact_path='model')
        return serialized

    @gorilla.patch(tensorflow.estimator.Estimator)
    def export_saved_model(self, *args, **kwargs):
        return _export_and_log_model(self, 'export_saved_model', args, kwargs)

    @gorilla.patch(tensorflow.estimator.Estimator)
    def export_savedmodel(self, *args, **kwargs):
        return _export_and_log_model(self, 'export_savedmodel', args, kwargs)

    @gorilla.patch(tensorflow.keras.Model)
    def fit(self, *args, **kwargs):
        original = gorilla.get_original_attribute(tensorflow.keras.Model, 'fit')
        # Checking if the 'callbacks' argument of fit() is set (position 5). A value of
        # ``None`` is treated like an omitted argument, for which ``[]`` is used.
        if len(args) >= 6:
            tmp_list = list(args)
            user_callbacks = tmp_list[5]
            tmp_list[5], log_dir = _setup_callbacks(
                [] if user_callbacks is None else user_callbacks)
            args = tuple(tmp_list)
        else:
            user_callbacks = kwargs.get('callbacks')
            kwargs['callbacks'], log_dir = _setup_callbacks(
                [] if user_callbacks is None else user_callbacks)
        result = original(self, *args, **kwargs)
        _flush_queue()
        _log_artifacts_with_warning(local_dir=log_dir, artifact_path='tensorboard_logs')
        # NOTE(review): log_dir is removed unconditionally; confirm _setup_callbacks
        # always returns a temporary directory here.
        shutil.rmtree(log_dir)
        return result

    def _make_add_event(writer_cls):
        """
        Build an ``add_event`` patch for ``writer_cls`` that logs the event and then
        delegates to ``writer_cls``'s own original ``add_event`` (one shared function
        would send every writer class to a single class's original implementation).
        """

        @gorilla.patch(writer_cls)
        def add_event(self, event):
            _log_event(event)
            original = gorilla.get_original_attribute(writer_cls, 'add_event')
            return original(self, event)

        return add_event

    @gorilla.patch(FileWriter)
    def add_summary(self, *args, **kwargs):
        original = gorilla.get_original_attribute(FileWriter, 'add_summary')
        result = original(self, *args, **kwargs)
        _flush_queue()
        return result

    settings = gorilla.Settings(allow_hit=True, store_hit=True)
    patches = [
        gorilla.Patch(EventFileWriter, 'add_event', _make_add_event(EventFileWriter),
                      settings=settings),
        gorilla.Patch(EventFileWriterV2, 'add_event', _make_add_event(EventFileWriterV2),
                      settings=settings),
        gorilla.Patch(tensorflow.keras.Model, 'fit', fit, settings=settings),
        gorilla.Patch(tensorflow.estimator.Estimator, 'export_saved_model',
                      export_saved_model, settings=settings),
        gorilla.Patch(tensorflow.estimator.Estimator, 'export_savedmodel',
                      export_savedmodel, settings=settings),
        gorilla.Patch(FileWriter, 'add_summary', add_summary, settings=settings),
    ]

    for x in patches:
        gorilla.apply(x)
def autolog():
    """
    Enable automatic logging from Fastai to MLflow.
    Logs loss and any other metrics specified in the fit
    function, and optimizer data as parameters. Model checkpoints
    are logged as artifacts to a 'models' directory.

    MLflow will also log the parameters of the EarlyStopping and OneCycleScheduler callbacks
    """
    from fastai.basic_train import LearnerCallback, Learner
    from fastai.callbacks.hooks import model_summary, layers_info
    from fastai.callbacks import EarlyStoppingCallback, OneCycleScheduler

    class __MLflowFastaiCallback(LearnerCallback):
        """
        Callback for auto-logging metrics and parameters.
        Records model structural information as params when training begins
        """

        def __init__(self, learner):
            super().__init__(learner)
            self.learner = learner
            self.opt = self.learn.opt
            self.metrics_names = ["train_loss", "valid_loss"] + \
                [o.__name__ for o in learner.metrics]

        def on_epoch_end(self, **kwargs):
            """
            Log loss and other metrics values after each epoch
            """
            if kwargs["smooth_loss"] is None or kwargs["last_metrics"] is None:
                return
            epoch = kwargs["epoch"]
            metrics = [kwargs["smooth_loss"]] + kwargs["last_metrics"]
            metrics = map(float, metrics)
            metrics = dict(zip(self.metrics_names, metrics))
            try_mlflow_log(mlflow.log_metrics, metrics, step=epoch)

        def on_train_begin(self, **kwargs):
            info = layers_info(self.learner)
            try_mlflow_log(mlflow.log_param, "num_layers", len(info))
            # NOTE(review): ``opt_func`` / ``train_bn`` are not set on this callback;
            # presumably resolved through LearnerCallback's attribute delegation to
            # the learner — confirm.
            try_mlflow_log(mlflow.log_param, "opt_func", self.opt_func.func.__name__)

            if hasattr(self.opt, "true_wd"):
                try_mlflow_log(mlflow.log_param, "true_wd", self.opt.true_wd)

            if hasattr(self.opt, "bn_wd"):
                try_mlflow_log(mlflow.log_param, "bn_wd", self.opt.bn_wd)

            # NOTE(review): checks ``self.opt`` for train_bn but logs ``self.train_bn``;
            # confirm which object is intended.
            if hasattr(self.opt, "train_bn"):
                try_mlflow_log(mlflow.log_param, "train_bn", self.train_bn)

            summary = model_summary(self.learner)
            try_mlflow_log(mlflow.set_tag, "model_summary", summary)

            tempdir = tempfile.mkdtemp()
            try:
                summary_file = os.path.join(tempdir, "model_summary.txt")
                with open(summary_file, "w") as f:
                    f.write(summary)
                try_mlflow_log(mlflow.log_artifact, local_path=summary_file)
            finally:
                shutil.rmtree(tempdir)

        def on_train_end(self, **kwargs):
            try_mlflow_log(log_model, self.learner, artifact_path="model")

    def _find_callback_of_type(callback_type, callbacks):
        """Return the first element of ``callbacks`` that is a ``callback_type``, or None."""
        for callback in callbacks:
            if isinstance(callback, callback_type):
                return callback
        return None

    def _log_early_stop_callback_params(callback):
        """Log the EarlyStoppingCallback configuration as params (best effort)."""
        if callback:
            try:
                earlystopping_params = {
                    "early_stop_monitor": callback.monitor,
                    "early_stop_min_delta": callback.min_delta,
                    "early_stop_patience": callback.patience,
                    "early_stop_mode": callback.mode,
                }
                try_mlflow_log(mlflow.log_params, earlystopping_params)
            except Exception:  # pylint: disable=W0703
                return

    def _log_one_cycle_callback_params(callback):
        """Log the OneCycleScheduler configuration as params (best effort)."""
        if callback:
            try:
                params = {
                    "lr_max": callback.lr_max,
                    "div_factor": callback.div_factor,
                    "pct_start": callback.pct_start,
                    "final_div": callback.final_div,
                    "tot_epochs": callback.tot_epochs,
                    "start_epoch": callback.start_epoch,
                    "moms": callback.moms,
                }
                try_mlflow_log(mlflow.log_params, params)
            except Exception:  # pylint: disable=W0703
                return

    def _run_and_log_function(self, original, args, kwargs, unlogged_params, callback_arg_index):
        """
        Run ``original`` with an MLflow callback appended to its callbacks, logging its
        arguments and EarlyStopping/OneCycleScheduler parameters. Starts (and afterwards
        ends) a run when none is active. ``callback_arg_index`` is the position of
        ``callbacks`` in ``args``.
        """
        if not mlflow.active_run():
            try_mlflow_log(mlflow.start_run)
            auto_end_run = True
        else:
            auto_end_run = False

        log_fn_args_as_params(original, [self] + list(args), kwargs, unlogged_params)

        callbacks = [cb(self) for cb in self.callback_fns] + (self.callbacks or [])

        # Checking if the 'callback' argument of the function is set. A new list is
        # built rather than extending with ``+=`` so the caller's list is not mutated,
        # and ``callbacks=None`` is treated as "no callbacks".
        if len(args) > callback_arg_index:
            tmp_list = list(args)
            user_callbacks = list(tmp_list[callback_arg_index] or [])
            callbacks += user_callbacks
            tmp_list[callback_arg_index] = user_callbacks + [__MLflowFastaiCallback(self)]
            args = tuple(tmp_list)
        else:
            user_callbacks = list(kwargs.get("callbacks") or [])
            callbacks += user_callbacks
            kwargs["callbacks"] = user_callbacks + [__MLflowFastaiCallback(self)]

        early_stop_callback = _find_callback_of_type(EarlyStoppingCallback, callbacks)
        one_cycle_callback = _find_callback_of_type(OneCycleScheduler, callbacks)

        _log_early_stop_callback_params(early_stop_callback)
        _log_one_cycle_callback_params(one_cycle_callback)

        result = original(self, *args, **kwargs)

        if auto_end_run:
            try_mlflow_log(mlflow.end_run)

        return result

    @gorilla.patch(Learner)
    def fit(self, *args, **kwargs):
        original = gorilla.get_original_attribute(Learner, "fit")
        unlogged_params = ["self", "callbacks", "learner"]
        return _run_and_log_function(self, original, args, kwargs, unlogged_params, 3)

    settings = gorilla.Settings(allow_hit=True, store_hit=True)
    gorilla.apply(gorilla.Patch(Learner, "fit", fit, settings=settings))
def test_apply_patch_no_hit(self):
    """Apply a ``'dummy'`` patch for every source/destination combination.

    For each source path (``''`` meaning the ``_frommodule`` module itself)
    and each unique parent path derived from ``_tomodule``'s attribute paths,
    a ``gorilla.Patch`` named ``'dummy'`` is applied with a default
    ``gorilla.Settings()``. The test then checks:

    - the destination object is unchanged;
    - the patched attribute is exactly the source object;
    - the behavior specific to each kind of source attribute.

    A branch counter makes sure every expected branch actually ran.
    NOTE(review): presumably ``'dummy'`` does not already exist on any
    destination, hence "no hit" -- confirm against the fixture modules.
    """
    name = 'dummy'
    settings = gorilla.Settings()

    source_paths = [''] + _list_attribute_paths(_frommodule)
    target_paths = _list_attribute_paths(_tomodule)

    branch_count = 0

    # Collect the unique destination (parent) paths while preserving their
    # first-seen order; a plain set alone would lose the ordering. An
    # explicit loop is used instead of a comprehension with a side effect.
    seen = set()
    destination_paths = []
    for path in target_paths:
        parent_path = _split_attribute_path(path)[0]
        if parent_path not in seen:
            seen.add(parent_path)
            destination_paths.append(parent_path)

    combinations = itertools.product(source_paths, destination_paths)
    for source_path, destination_path in combinations:
        # Re-run the fixture for each combination so that patches applied in
        # one iteration do not leak into the next.
        self.setUp()

        destination = _get_attribute_from_path(_tomodule, destination_path)
        obj = _get_attribute_from_path(_frommodule, source_path)
        patch = gorilla.Patch(destination, name, obj, settings=settings)
        gorilla.apply(patch)
        self.assertIs(
            destination, _get_attribute_from_path(_tomodule, destination_path))

        result = gorilla.get_attribute(destination, name)
        self.assertIs(result, obj)

        # Each branch bumps the counter once. Branches that exercise the
        # patched attribute through a class destination bump it a second
        # time when the destination is listed in _CLS_REFERENCES.
        if source_path == '':
            branch_count += 1
            self.assertEqual(result.global_variable,
                             "frommodule.global_variable")
            self.assertEqual(
                result.function(),
                "frommodule.function (frommodule.Class.STATIC_VALUE)")
            self.assertEqual(result.Class.STATIC_VALUE,
                             "frommodule.Class.STATIC_VALUE")
            self.assertEqual(result.Class.Inner.STATIC_VALUE,
                             "frommodule.Class.Inner.STATIC_VALUE")
            self.assertEqual(result.Parent.STATIC_VALUE,
                             "frommodule.Parent.STATIC_VALUE")
            self.assertEqual(result.Child.STATIC_VALUE,
                             "frommodule.Parent.STATIC_VALUE")
        elif source_path == 'global_variable':
            branch_count += 1
            self.assertEqual(result, "frommodule.global_variable")
        elif source_path == 'function':
            branch_count += 1
            self.assertEqual(
                result(),
                "frommodule.function (frommodule.Class.STATIC_VALUE)")
        elif source_path == 'unbound_method':
            branch_count += 1
            if destination_path in _CLS_REFERENCES:
                branch_count += 1
                self.assertEqual(
                    getattr(destination(), name)(),
                    "frommodule.unbound_method (tomodule.%s.STATIC_VALUE, tomodule.%s.instance_value)"
                    % ((_CLS_REFERENCES[destination_path], ) * 2))
        elif source_path == 'unbound_class_method':
            branch_count += 1
            if destination_path in _CLS_REFERENCES:
                branch_count += 1
                self.assertEqual(
                    getattr(destination, name)(),
                    "frommodule.unbound_class_method (tomodule.%s.STATIC_VALUE)"
                    % (_CLS_REFERENCES[destination_path], ))
        elif source_path == 'unbound_static_method':
            branch_count += 1
            if destination_path in _CLS_REFERENCES:
                branch_count += 1
                self.assertEqual(
                    getattr(destination, name)(),
                    "frommodule.unbound_static_method (frommodule.Class.STATIC_VALUE)"
                )
        elif source_path == 'Class':
            branch_count += 1
            self.assertEqual(result.STATIC_VALUE,
                             "frommodule.Class.STATIC_VALUE")
        elif source_path == 'Class.Inner':
            branch_count += 1
            self.assertEqual(result.STATIC_VALUE,
                             "frommodule.Class.Inner.STATIC_VALUE")
        elif source_path == 'Class.value':
            branch_count += 1
            if destination_path in _CLS_REFERENCES:
                branch_count += 1
                instance = destination()
                self.assertEqual(
                    getattr(instance, name),
                    "frommodule.Class.value.getter (tomodule.%s.instance_value)"
                    % (_CLS_REFERENCES[destination_path], ))
                # The patched property's setter must be honored as well.
                setattr(instance, name, 'hello')
                self.assertEqual(getattr(instance, name),
                                 "frommodule.Class.value.getter (hello)")
        elif source_path == 'Class.method':
            branch_count += 1
            if destination_path in _CLS_REFERENCES:
                branch_count += 1
                self.assertEqual(
                    getattr(destination(), name)(),
                    "frommodule.Class.method (tomodule.%s.STATIC_VALUE, tomodule.%s.instance_value)"
                    % ((_CLS_REFERENCES[destination_path], ) * 2))
        elif source_path == 'Class.class_method':
            branch_count += 1
            if destination_path in _CLS_REFERENCES:
                branch_count += 1
                self.assertEqual(
                    getattr(destination, name)(),
                    "frommodule.Class.class_method (tomodule.%s.STATIC_VALUE)"
                    % (_CLS_REFERENCES[destination_path], ))
        elif source_path == 'Class.static_method':
            branch_count += 1
            if destination_path in _CLS_REFERENCES:
                branch_count += 1
                self.assertEqual(
                    getattr(destination, name)(),
                    "frommodule.Class.static_method (frommodule.Class.STATIC_VALUE)"
                )
        elif source_path == 'Parent':
            branch_count += 1
            self.assertEqual(result.__slots__,
                             ('instance_value', 'parent_value', 'to_value',
                              'from_value'))
            self.assertEqual(result().parent_value,
                             "frommodule.Parent.parent_value")
            self.assertEqual(
                result().method(),
                "frommodule.Parent.method (frommodule.Parent.instance_value)"
            )
        elif source_path == 'Parent.method':
            branch_count += 1
            if destination_path in _CLS_REFERENCES:
                branch_count += 1
                self.assertEqual(
                    getattr(destination(), name)(),
                    "frommodule.Parent.method (tomodule.%s.instance_value)"
                    % (_CLS_REFERENCES[destination_path], ))
        elif source_path == 'Child':
            branch_count += 1
            self.assertEqual(result.__slots__, ('child_value', ))
            self.assertEqual(result().parent_value,
                             "frommodule.Parent.parent_value")
            self.assertEqual(result().child_value,
                             "frommodule.Child.child_value")
            self.assertEqual(
                result().method(),
                "frommodule.Parent.method (frommodule.Parent.instance_value)"
            )
        elif source_path == 'Child.method':
            branch_count += 1
            if destination_path in _CLS_REFERENCES:
                branch_count += 1
                self.assertEqual(
                    getattr(destination(), name)(),
                    "frommodule.Parent.method (tomodule.%s.instance_value)"
                    % (_CLS_REFERENCES[destination_path], ))

        self.tearDown()

    # Make sure that all test branches are covered.
    self.assertEqual(branch_count, 116)