Ejemplo n.º 1
0
    def patched_evaluate(original, self, *args, **kwargs):
        """Call the wrapped ``evaluate`` and, when enabled, log its result as a metric.

        ``original`` is the un-patched method and ``self`` the evaluator it is
        bound to; ``args``/``kwargs`` are forwarded to it untouched. The value
        returned by ``original`` is always returned to the caller, whether or
        not anything was logged.
        """
        manager = _AUTOLOGGING_METRICS_MANAGER

        # Post-training metric logging is off: behave exactly like the original.
        if not manager.should_log_post_training_metrics():
            return original(self, *args, **kwargs)

        # The wrapped call itself runs with post-training metric logging disabled.
        with manager.disable_log_post_training_metrics():
            metric = original(self, *args, **kwargs)

        if not manager.is_metric_value_loggable(metric):
            return metric

        extra_params = get_method_call_arg_value(1, "params", None, args, kwargs)
        # Build the evaluator's param map: `self.copy(params)` constructs an
        # evaluator that carries the extra evaluation params.
        if extra_params is None:
            evaluator = self
        else:
            evaluator = self.copy(extra_params)
        metric_name = evaluator.getMetricName()
        evaluator_info = manager.gen_evaluator_info(evaluator)

        pred_result_dataset = get_method_call_arg_value(
            0, "dataset", None, args, kwargs)
        lookup = manager.get_run_id_and_dataset_name_for_evaluator_call(
            pred_result_dataset)
        run_id, dataset_name = lookup

        # Only log when the evaluated dataset could be resolved to both a run
        # id and a dataset name.
        if run_id and dataset_name:
            metric_key = manager.register_evaluator_call(
                run_id, metric_name, dataset_name, evaluator_info)
            manager.log_post_training_metric(run_id, metric_key, metric)

        return metric
Ejemplo n.º 2
0
    def fit_mlflow(original, self, *args, **kwargs):
        """Call the wrapped ``fit`` and log pre-/post-training metadata around it.

        ``original`` is the un-patched ``fit`` and ``self`` the estimator it is
        bound to; ``args``/``kwargs`` are forwarded to it untouched. The value
        returned by ``original`` is returned to the caller.

        Autologging is skipped (the call is simply delegated) when:

        * the estimator's fully qualified class name starts with
          ``pyspark.ml.feature.``;
        * ``params`` is a list or tuple (a warning is logged in this case);
        * no training dataset can be found in the call arguments.
        """
        params = get_method_call_arg_value(1, "params", None, args, kwargs)

        # Do not perform autologging on direct calls to fit() for featurizers.
        # Note that featurizers will be autologged when they're fit as part of a Pipeline.
        if _get_fully_qualified_class_name(self).startswith(
                "pyspark.ml.feature."):
            return original(self, *args, **kwargs)

        if isinstance(params, (list, tuple)):
            # Skip the case where params is a list or tuple: in that case the
            # call goes through fitMultiple and returns a model iterator.
            _logger.warning(
                _get_warning_msg_for_fit_call_with_a_list_of_params(self))
            return original(self, *args, **kwargs)

        # Resolve the dataset the same way the other patches do, so a call
        # such as `fit(dataset=df)` works instead of failing on `args[0]`.
        training_df = get_method_call_arg_value(0, "dataset", None, args,
                                                kwargs)
        if training_df is None:
            # Nothing to persist or log against; let the original call report
            # the missing/invalid argument itself.
            return original(self, *args, **kwargs)

        from pyspark.storagelevel import StorageLevel

        # We need to generate the estimator param map, so we call
        # `self.copy(params)` to construct an estimator with the extra params.
        estimator = self.copy(params) if params is not None else self
        _log_pretraining_metadata(estimator, params)
        input_training_df = training_df.persist(StorageLevel.MEMORY_AND_DISK)
        try:
            spark_model = original(self, *args, **kwargs)
            _log_posttraining_metadata(estimator, spark_model, params,
                                       input_training_df)
        finally:
            # Always release the persisted data, even if training or the
            # post-training logging raised.
            input_training_df.unpersist()

        return spark_model
Ejemplo n.º 3
0
 def patched_transform(original, self, *args, **kwargs):
     """Call the wrapped ``transform`` and register its input and output datasets.

     ``original`` is the un-patched method and ``self`` the model it is bound
     to; ``args``/``kwargs`` are forwarded to it untouched. The value returned
     by ``original`` is always returned to the caller.
     """
     manager = _AUTOLOGGING_METRICS_MANAGER
     run_id = manager.get_run_id_for_model(self)

     # Without post-training metric logging, or without a run id for this
     # model, there is nothing to register: just delegate.
     if not (manager.should_log_post_training_metrics() and run_id):
         return original(self, *args, **kwargs)

     predict_result = original(self, *args, **kwargs)

     # Register the input dataset against this model, then tie the prediction
     # result to the run and the dataset name the manager returned.
     eval_dataset = get_method_call_arg_value(0, "dataset", None, args, kwargs)
     eval_dataset_name = manager.register_prediction_input_dataset(
         self, eval_dataset)
     manager.register_prediction_result(run_id, eval_dataset_name,
                                        predict_result)
     return predict_result
Ejemplo n.º 4
0
def test_get_method_call_arg_value():
    """Check how ``get_method_call_arg_value`` resolves argument ``b``."""
    # Model a call to a method declared as `def f1(a, b=3, *, c=4, e=5)`.

    # `b` supplied positionally at index 1.
    assert get_method_call_arg_value(1, "b", 3, [1, 2], {}) == 2
    # `b` not supplied at all: the given default (3) comes back.
    assert get_method_call_arg_value(1, "b", 3, [1], {}) == 3
    # `b` supplied by keyword.
    assert get_method_call_arg_value(1, "b", 3, [1], {"b": 2}) == 2