def test_callback_is_picklable():
    cb = __MLflowTfKeras2Callback(
        log_models=True,
        metrics_logger=BatchMetricsLogger(run_id="1234"),
        log_every_n_steps=5)
    pickle.dumps(cb)

    tb = _TensorBoard()
    pickle.dumps(tb)
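
For orientation: the callbacks exercised in these tests all funnel metrics through `BatchMetricsLogger.record_metrics` instead of logging each value individually. A minimal sketch of that pattern follows; the callback class and its `on_batch_end` hook are hypothetical, and only `record_metrics` is taken from the examples on this page.

class ExampleMetricsCallback:
    """Hypothetical callback illustrating the record-then-flush pattern."""

    def __init__(self, metrics_logger, log_every_n_steps):
        self.metrics_logger = metrics_logger
        self.log_every_n_steps = log_every_n_steps

    def on_batch_end(self, step, logs=None):
        # Buffer metrics; BatchMetricsLogger batches them and decides when
        # to flush them to the tracking server.
        if logs and step % self.log_every_n_steps == 0:
            self.metrics_logger.record_metrics(logs, step)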
Example #2
    def _run_and_log_function(self, original, args, kwargs):
        """
        This method would be called from patched fit method and
        It adds the custom callback class into callback list.
        """

        # The run_id is not set here. Rather, it will be retrieved from
        # the current mlflow run's training session inside BatchMetricsLogger.
        metrics_logger = BatchMetricsLogger()
        __MLflowPLCallback = getPLCallback(log_models, metrics_logger)
        if not any(isinstance(callback, __MLflowPLCallback) for callback in self.callbacks):
            self.callbacks += [__MLflowPLCallback()]
        result = original(self, *args, **kwargs)

        return result
def patched_fit(original, self, *args, **kwargs):
    """
    A patched implementation of `paddle.Model.fit` which logs early stopping
    parameters and metrics, the model summary, and (optionally) the trained
    model to MLflow.
    """
    run_id = mlflow.active_run().info.run_id
    tracking_uri = mlflow.get_tracking_uri()
    client = MlflowAutologgingQueueingClient(tracking_uri)
    metrics_logger = BatchMetricsLogger(run_id, tracking_uri)

    log_models = get_autologging_config(mlflow.paddle.FLAVOR_NAME,
                                        "log_models", True)
    log_every_n_epoch = get_autologging_config(mlflow.paddle.FLAVOR_NAME,
                                               "log_every_n_epoch", 1)

    early_stop_callback = None
    mlflow_callback = __MLflowPaddleCallback(client, metrics_logger, run_id,
                                             log_models, log_every_n_epoch)
    if "callbacks" in kwargs:
        callbacks = kwargs["callbacks"]
        for callback in callbacks:
            if isinstance(callback, paddle.callbacks.EarlyStopping):
                early_stop_callback = callback
                _log_early_stop_params(early_stop_callback, client, run_id)
                break
        kwargs["callbacks"].append(mlflow_callback)
    else:
        kwargs["callbacks"] = [mlflow_callback]
    client.flush(synchronous=False)

    result = original(self, *args, **kwargs)

    if early_stop_callback is not None:
        _log_early_stop_metrics(early_stop_callback, client, run_id)

    mlflow.log_text(str(self.summary()), "model_summary.txt")

    if log_models:
        mlflow.paddle.log_model(pd_model=self, artifact_path="model")

    client.flush(synchronous=True)

    return result
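
For context, the patched fit above only runs once PaddlePaddle autologging is enabled. The sketch below loosely follows MLflow's own PaddlePaddle autolog example; treat the tiny regression network and the UCIHousing dataset as illustrative assumptions rather than part of the original snippet.

import paddle
import mlflow
import mlflow.paddle

# Enabling autologging is what patches paddle.Model.fit; the keyword arguments
# correspond to the values read back via get_autologging_config above.
mlflow.paddle.autolog(log_every_n_epoch=1, log_models=True)

train_dataset = paddle.text.datasets.UCIHousing(mode="train")

network = paddle.nn.Sequential(paddle.nn.Linear(13, 1))
model = paddle.Model(network)
model.prepare(optimizer=paddle.optimizer.Adam(parameters=model.parameters()),
              loss=paddle.nn.MSELoss())

with mlflow.start_run():
    model.fit(train_dataset, epochs=2, batch_size=8, verbose=1)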
Example #4
def test_batch_metrics_logger_flush_logs_to_mlflow(start_run):
    run_id = mlflow.active_run().info.run_id

    # Need to patch _should_flush() to return False, so that we can manually flush the logger
    with mock.patch(
        "mlflow.utils.autologging_utils.BatchMetricsLogger._should_flush", return_value=False
    ):
        metrics_logger = BatchMetricsLogger(run_id)
        metrics_logger.record_metrics({"my_metric": 10}, 5)

        # Recorded metrics should not be logged to the MLflow run before flushing the BatchMetricsLogger
        metrics_on_run = mlflow.tracking.MlflowClient().get_run(run_id).data.metrics
        assert "my_metric" not in metrics_on_run

        metrics_logger.flush()

        # The recorded metric should be logged to the MLflow run after flushing the BatchMetricsLogger
        metrics_on_run = mlflow.tracking.MlflowClient().get_run(run_id).data.metrics
        assert "my_metric" in metrics_on_run
        assert metrics_on_run["my_metric"] == 10
Example #5
def test_callback_is_picklable():
    cb = __MLflowGluonCallback(
        log_models=True, metrics_logger=BatchMetricsLogger(run_id="1234"))
    pickle.dumps(cb)
Example #6
def patched_fit(original, self, *args, **kwargs):
    """
    A patched implementation of `pytorch_lightning.Trainer.fit` which enables logging the
    following parameters, metrics and artifacts:

    - Training epochs
    - Optimizer parameters
    - `EarlyStoppingCallback`_ parameters
    - Metrics stored in `trainer.callback_metrics`
    - Model checkpoints
    - Trained model

    .. _EarlyStoppingCallback:
        https://pytorch-lightning.readthedocs.io/en/latest/early_stopping.html
    """
    run_id = mlflow.active_run().info.run_id
    tracking_uri = mlflow.get_tracking_uri()
    client = MlflowAutologgingQueueingClient(tracking_uri)
    metrics_logger = BatchMetricsLogger(run_id, tracking_uri)

    log_models = get_autologging_config(mlflow.pytorch.FLAVOR_NAME,
                                        "log_models", True)
    log_every_n_epoch = get_autologging_config(mlflow.pytorch.FLAVOR_NAME,
                                               "log_every_n_epoch", 1)
    log_every_n_step = get_autologging_config(mlflow.pytorch.FLAVOR_NAME,
                                              "log_every_n_step", None)

    early_stop_callback = None
    for callback in self.callbacks:
        if isinstance(callback, pl.callbacks.early_stopping.EarlyStopping):
            early_stop_callback = callback
            _log_early_stop_params(early_stop_callback, client, run_id)

    if not any(
            isinstance(callback, __MLflowPLCallback)
            for callback in self.callbacks):
        self.callbacks += [
            __MLflowPLCallback(client, metrics_logger, run_id, log_models,
                               log_every_n_epoch, log_every_n_step)
        ]

    client.flush(synchronous=False)

    result = original(self, *args, **kwargs)

    if early_stop_callback is not None:
        _log_early_stop_metrics(early_stop_callback, client, run_id)

    if Version(pl.__version__) < Version("1.4.0"):
        summary = str(ModelSummary(self.model, mode="full"))
    else:
        summary = str(ModelSummary(self.model, max_depth=-1))

    tempdir = tempfile.mkdtemp()
    try:
        summary_file = os.path.join(tempdir, "model_summary.txt")
        with open(summary_file, "w") as f:
            f.write(summary)

        mlflow.log_artifact(local_path=summary_file)
    finally:
        shutil.rmtree(tempdir)

    if log_models:
        registered_model_name = get_autologging_config(
            mlflow.pytorch.FLAVOR_NAME, "registered_model_name", None)
        mlflow.pytorch.log_model(
            pytorch_model=self.model,
            artifact_path="model",
            registered_model_name=registered_model_name,
        )

        if early_stop_callback is not None and self.checkpoint_callback.best_model_path:
            mlflow.log_artifact(
                local_path=self.checkpoint_callback.best_model_path,
                artifact_path="restored_model_checkpoint",
            )

    client.flush(synchronous=True)

    return result
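
As with the Paddle flavor, this patched fit is reached by enabling autologging before training. A minimal sketch, assuming a trivial LightningModule and random data purely for illustration:

import mlflow
import mlflow.pytorch
import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader, TensorDataset

# Patches pytorch_lightning.Trainer.fit as shown above; the keyword arguments
# mirror the values read via get_autologging_config in patched_fit.
mlflow.pytorch.autolog(log_every_n_epoch=1, log_models=True)


class TinyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.mse_loss(self.layer(x), y)
        self.log("train_loss", loss)  # surfaces in trainer.callback_metrics
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.01)


train_loader = DataLoader(TensorDataset(torch.randn(64, 4), torch.randn(64, 1)),
                          batch_size=8)

with mlflow.start_run():
    pl.Trainer(max_epochs=2).fit(TinyModel(), train_loader)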
def test_callback_class_is_pickable():
    from mlflow.xgboost._autolog import AutologCallback

    cb = AutologCallback(BatchMetricsLogger(run_id="1234"), eval_results={})
    pickle.dumps(cb)
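
The `AutologCallback` above is what XGBoost autologging wires into training. A usage sketch with synthetic data (the data and parameters are placeholders, not part of the original test):

import mlflow
import mlflow.xgboost
import numpy as np
import xgboost as xgb

mlflow.xgboost.autolog()

X = np.random.rand(100, 4)
y = np.random.randint(0, 2, size=100)
dtrain = xgb.DMatrix(X, label=y)

with mlflow.start_run():
    # Training through xgb.train triggers the autologging callback, which
    # records evaluation metrics through a BatchMetricsLogger.
    xgb.train({"objective": "binary:logistic"}, dtrain,
              num_boost_round=10, evals=[(dtrain, "train")])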
def test_callback_func_is_pickable():
    cb = picklable_exception_safe_function(
        functools.partial(autolog_callback,
                          BatchMetricsLogger(run_id="1234"),
                          eval_results={}))
    pickle.dumps(cb)
Example #9
def test_callback_is_picklable():
    cb = __MlflowFastaiCallback(BatchMetricsLogger(run_id="1234"),
                                log_models=True,
                                is_fine_tune=False)
    pickle.dumps(cb)