Example #1
# Assumed imports for this excerpt (keras_tuner-style module layout):
from keras_tuner.engine import metrics_tracking
from keras_tuner.engine import oracle as oracle_module


def format_objective(objective, direction=None):
    """Formats `objective` into a list of `oracle_module.Objective`.

    Arguments:
        objective: An `oracle_module.Objective`, a string, or a list of
            either. If a string, the optimization direction ('min' or
            'max') is inferred from the metric name.
        direction: Optional string, 'min' or 'max'. Only used when
            `objective` is a single string.

    Returns:
        A list of `oracle_module.Objective`.

    Raises:
        TypeError: If `objective` is not in a recognized format.
    """
    if isinstance(objective, oracle_module.Objective):
        return [objective]
    if isinstance(objective, str):
        if direction:
            return [oracle_module.Objective(objective, direction)]
        return [
            oracle_module.Objective(
                objective, metrics_tracking.infer_metric_direction(objective))
        ]
    if isinstance(objective, list):
        if isinstance(objective[0], oracle_module.Objective):
            return objective
        if isinstance(objective[0], str):
            return [
                oracle_module.Objective(
                    m, metrics_tracking.infer_metric_direction(m))
                for m in objective
            ]
    raise TypeError(
        'Objective should be either a string, an oracle_module.Objective, '
        'or a list of either; found {}'.format(objective))
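
A quick usage sketch (hypothetical metric names; directions are inferred by
metrics_tracking.infer_metric_direction):

objs = format_objective("val_loss")           # one Objective, direction "min"
objs = format_objective(["MAE", "binary_accuracy"])  # directions "min", "max"
objs = format_objective("custom_score", direction="max")  # explicit direction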
Example #2
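    # Method excerpt: assumes the enclosing Oracle class provides
    # self.service, self.trials, self._start_time and self._get_objective(),
    # along with `time`, `tensorflow as tf`, typing's Mapping/Text/Union,
    # `metrics_tracking`, and `trial_module` imports.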
    def update_trial(self,
                     trial_id: Text,
                     metrics: Mapping[Text, Union[int, float]],
                     step: int = 0):
        """Used by a worker to report the status of a trial."""
        # Constructs the measurement.
        # Adds the measurement of the objective functions to a trial.
        elapsed_secs = time.time() - self._start_time
        if elapsed_secs < 0 or step < 0:
            raise ValueError(
                "Both elapsed_secs and step must be non-negative.")
        if elapsed_secs == 0 and step == 0:
            raise ValueError(
                "At least one of {elapsed_secs, step} must be positive")

        metric_list = []
        for ob in self._get_objective():
            if ob.name not in metrics:
                # Falls back to the metric name without the "val_" prefix,
                # e.g. "val_accuracy" -> "accuracy".
                ob_name = ob.name.replace("val_", "")
                if ob_name in metrics:
                    metric_list.append({
                        "metric": ob_name,
                        "value": float(metrics.get(ob_name))
                    })
                else:
                    tf.get_logger().info(
                        'Objective "{}" is not found in metrics.'.format(
                            ob.name))
                continue

            metric_list.append({
                "metric": ob.name,
                "value": float(metrics.get(ob.name))
            })

        self.service.report_intermediate_objective_value(
            step, elapsed_secs, metric_list, trial_id)

        # Ensure metrics of trials are updated locally.
        keras_tuner_trial = self.trials[trial_id]
        for metric_name, metric_value in metrics.items():
            if not keras_tuner_trial.metrics.exists(metric_name):
                direction = metrics_tracking.infer_metric_direction(
                    metric_name)
                keras_tuner_trial.metrics.register(metric_name,
                                                   direction=direction)
            keras_tuner_trial.metrics.update(metric_name,
                                             metric_value,
                                             step=step)

        # Checks whether a trial should stop or not.
        tf.get_logger().info("UpdateTrial: polls the stop decision.")
        should_stop = self.service.should_trial_stop(trial_id)

        if should_stop:
            keras_tuner_trial.status = trial_module.TrialStatus.STOPPED
        return keras_tuner_trial.status
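
To isolate the "val_" fallback in the loop above, here is a minimal,
self-contained sketch (hypothetical metric names and values):

# When the objective is "val_accuracy" but the worker reported only
# "accuracy", the stripped name is matched instead.
objective_name = "val_accuracy"
reported = {"accuracy": 0.93, "loss": 0.21}

if objective_name not in reported:
    stripped = objective_name.replace("val_", "")
    if stripped in reported:
        measurement = {"metric": stripped, "value": float(reported[stripped])}
        # measurement == {"metric": "accuracy", "value": 0.93}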
Example #3
# Assumed imports for this excerpt (exact module paths vary across
# keras_tuner versions):
from keras_tuner import Objective
from keras_tuner.engine import metrics_tracking
from tensorflow.keras import losses, metrics


def _format_objective(objective):
    if objective is None:
        return Objective("default_objective", "min")
    if isinstance(objective, list):
        return [_format_objective(obj) for obj in objective]
    if isinstance(objective, Objective):
        return objective
    if isinstance(objective, str):
        direction = metrics_tracking.infer_metric_direction(objective)
        if direction is None:
            error_msg = (
                'Could not infer optimization direction ("min" or "max") '
                'for unknown metric "{obj}". Please specify the objective as '
                "a `keras_tuner.Objective`, for example `keras_tuner.Objective("
                '"{obj}", direction="min")`.')
            error_msg = error_msg.format(obj=objective)
            raise ValueError(error_msg)
        return Objective(name=objective, direction=direction)
    else:
        raise ValueError("`objective` not understood, expected str or "
                         "`Objective` object, found: {}".format(objective))
def test_metric_direction_inference():
    # Test min metrics.
    assert metrics_tracking.infer_metric_direction("MAE") == "min"
    assert (metrics_tracking.infer_metric_direction(
        metrics.binary_crossentropy) == "min")
    assert metrics_tracking.infer_metric_direction(
        metrics.FalsePositives()) == "min"

    # All losses in keras.losses are considered as 'min'.
    assert metrics_tracking.infer_metric_direction("squared_hinge") == "min"
    assert metrics_tracking.infer_metric_direction(losses.hinge) == "min"
    assert (metrics_tracking.infer_metric_direction(
        losses.CategoricalCrossentropy()) == "min")

    # Test max metrics.
    assert metrics_tracking.infer_metric_direction("binary_accuracy") == "max"
    assert (metrics_tracking.infer_metric_direction(
        metrics.categorical_accuracy) == "max")
    assert metrics_tracking.infer_metric_direction(
        metrics.Precision()) == "max"

    # Test unknown metrics.
    assert metrics_tracking.infer_metric_direction("my_metric") is None

    def my_metric_fn(x, y):
        return x

    assert metrics_tracking.infer_metric_direction(my_metric_fn) is None

    class MyMetric(metrics.Metric):
        def update_state(self, x, y):
            return 1

        def result(self):
            return 1

    assert metrics_tracking.infer_metric_direction(MyMetric()) is None

    # Test special cases.
    assert metrics_tracking.infer_metric_direction("loss") == "min"
    assert metrics_tracking.infer_metric_direction("acc") == "max"
    assert metrics_tracking.infer_metric_direction("val_acc") == "max"
    assert metrics_tracking.infer_metric_direction("crossentropy") == "min"
    assert metrics_tracking.infer_metric_direction("ce") == "min"
    assert metrics_tracking.infer_metric_direction("weighted_acc") == "max"
    assert metrics_tracking.infer_metric_direction("val_weighted_ce") == "min"
    assert (metrics_tracking.infer_metric_direction("weighted_binary_accuracy")
            == "max")