Example 1
def create_ensemble_metrics(metric_fn,
                            use_tpu=False,
                            features=None,
                            labels=None,
                            estimator_spec=None,
                            architecture=None):
    """Creates an instance of the _EnsembleMetrics class.

  Args:
    metric_fn: A function which should obey the following signature:
    - Args: can only have following three arguments in any order:
        * predictions: Predictions `Tensor` or dict of `Tensor` created by given
          `Head`.
        * features: Input `dict` of `Tensor` objects created by `input_fn` which
          is given to `estimator.evaluate` as an argument.
        * labels:  Labels `Tensor` or dict of `Tensor` (for multi-head) created
          by `input_fn` which is given to `estimator.evaluate` as an argument.
      - Returns: Dict of metric results keyed by name. Final metrics are a union
        of this and `estimator`s existing metrics. If there is a name conflict
        between this and `estimator`s existing metrics, this will override the
        existing one. The values of the dict are the results of calling a metric
        function, namely a `(metric_tensor, update_op)` tuple.
    use_tpu: Whether to use TPU-specific variable sharing logic.
    features: Input `dict` of `Tensor` objects.
    labels: Labels `Tensor` or a dictionary of string label name to `Tensor`
      (for multi-head).
    estimator_spec: The `EstimatorSpec` created by a `Head` instance.
    architecture: `_Architecture` object.

  Returns:
    An instance of _EnsembleMetrics.
  """

    if not estimator_spec:
        estimator_spec = tf_compat.v1.estimator.tpu.TPUEstimatorSpec(
            mode=tf.estimator.ModeKeys.EVAL,
            loss=tf.constant(2.),
            predictions=None,
            eval_metrics=None)
        if not use_tpu:
            estimator_spec = estimator_spec.as_estimator_spec()

    if not architecture:
        architecture = _Architecture(None, None)

    metrics = _EnsembleMetrics(use_tpu=use_tpu)
    metrics.create_eval_metrics(features, labels, estimator_spec, metric_fn,
                                architecture)

    return metrics
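
For reference, here is a minimal sketch of a `metric_fn` that satisfies the signature documented above. It assumes `predictions` and `labels` are plain `Tensor`s rather than multi-head dicts; `tf.compat.v1.metrics.accuracy` already returns the required `(metric_tensor, update_op)` tuple:

import tensorflow as tf

def my_metric_fn(predictions, labels):
    # Each dict value must be a (metric_tensor, update_op) tuple, which is
    # exactly what the tf.compat.v1.metrics functions return.
    return {
        "my_accuracy": tf.compat.v1.metrics.accuracy(
            labels=labels, predictions=predictions),
    }

Per the docstring above, any metric this function returns overrides an existing estimator metric of the same name.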
Example 2
    def test_ensemble_metrics(self):
        architecture = _Architecture("test_ensemble_candidate")
        architecture.add_subnetwork(iteration_number=0, builder_name="b_0_0")
        architecture.add_subnetwork(iteration_number=0, builder_name="b_0_1")
        architecture.add_subnetwork(iteration_number=1, builder_name="b_1_0")
        architecture.add_subnetwork(iteration_number=2, builder_name="b_2_0")

        metrics = _EnsembleMetrics()
        metrics.create_eval_metrics(self._features, self._labels,
                                    self._estimator_spec, self._metric_fn,
                                    architecture)

        with self.test_session() as sess:
            actual = _run_metrics(sess, metrics.eval_metrics_tuple())

        serialized_arch_proto = actual["architecture/adanet/ensembles"]
        expected_arch_string = b"| b_0_0 | b_0_1 | b_1_0 | b_2_0 |"
        self.assertIn(expected_arch_string, serialized_arch_proto)
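
`_run_metrics` is a test helper that is not shown on this page. A minimal sketch of what it could look like, assuming `eval_metrics_tuple()` returns a `(metric_fn, tensors)` pair whose `metric_fn` yields `{name: (value, update_op)}`:

import tensorflow as tf

def _run_metrics(sess, eval_metrics):
    metric_fn, tensors = eval_metrics
    metric_ops = (metric_fn(**tensors) if isinstance(tensors, dict)
                  else metric_fn(*tensors))
    sess.run([tf.compat.v1.global_variables_initializer(),
              tf.compat.v1.local_variables_initializer()])
    # Run every update op once, then read back the metric values.
    sess.run({name: op for name, (_, op) in metric_ops.items()})
    return sess.run({name: value for name, (value, _) in metric_ops.items()})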
Example 3
  def build_ensemble_spec(self,
                          name,
                          candidate,
                          ensembler,
                          subnetwork_specs,
                          summary,
                          features,
                          mode,
                          iteration_step,
                          iteration_number,
                          labels=None,
                          previous_ensemble_spec=None):
    """Builds an `_EnsembleSpec` with the given `adanet.ensemble.Candidate`.

    Args:
      name: The string name of the ensemble. Typically the name of the builder
        that returned the given `Subnetwork`.
      candidate: The `adanet.ensemble.Candidate` for this spec.
      ensembler: The :class:`adanet.ensemble.Ensembler` to use to ensemble a
        group of subnetworks.
      subnetwork_specs: Iterable of `_SubnetworkSpecs` for this iteration.
      summary: A `_ScopedSummary` instance for recording ensemble summaries.
      features: Input `dict` of `Tensor` objects.
      mode: Estimator `ModeKeys` indicating training, evaluation, or inference.
      iteration_step: Integer `Tensor` representing the step since the beginning
        of the current iteration, as opposed to the global step.
      iteration_number: Integer current iteration number.
      labels: Labels `Tensor` or a dictionary of string label name to `Tensor`
        (for multi-head).
      previous_ensemble_spec: The `_EnsembleSpec` from iteration t-1. Used for
        creating the subnetwork train_op.

    Returns:
      An `_EnsembleSpec` instance.
    """

    with tf.variable_scope("ensemble_{}".format(name)):
      architecture = _Architecture(candidate.name)
      previous_subnetworks = []
      subnetwork_builders = []
      previous_ensemble = None
      if previous_ensemble_spec:
        previous_ensemble = previous_ensemble_spec.ensemble
        previous_architecture = previous_ensemble_spec.architecture
        keep_indices = range(len(previous_ensemble.subnetworks))
        if len(candidate.subnetwork_builders) == 1 and previous_ensemble:
          # Prune previous ensemble according to the subnetwork.Builder for
          # backwards compatibility.
          tf.logging.warn(
              "Using an `adanet.subnetwork.Builder#prune_previous_ensemble` "
              "is deprecated. Please use a custom `adanet.ensemble.Strategy` "
              "instead.")
          subnetwork_builder = candidate.subnetwork_builders[0]
          keep_indices = subnetwork_builder.prune_previous_ensemble(
              previous_ensemble)
        for i, builder in enumerate(previous_ensemble_spec.subnetwork_builders):
          if i not in keep_indices:
            continue
          if builder not in candidate.previous_ensemble_subnetwork_builders:
            continue
          previous_subnetworks.append(previous_ensemble.subnetworks[i])
          subnetwork_builders.append(builder)
          architecture.add_subnetwork(*previous_architecture.subnetworks[i])
      for builder in candidate.subnetwork_builders:
        architecture.add_subnetwork(iteration_number, builder.name)
        subnetwork_builders.append(builder)
      subnetwork_map = {s.builder.name: s.subnetwork for s in subnetwork_specs}
      subnetworks = [
          subnetwork_map[s.name] for s in candidate.subnetwork_builders
      ]
      ensemble_scope = tf.get_variable_scope()
      before_var_list = tf.trainable_variables()
      with summary.current_scope(), _monkey_patch_context(
          iteration_step_scope=ensemble_scope,
          scoped_summary=summary,
          trainable_vars=[]):
        ensemble = ensembler.build_ensemble(
            subnetworks,
            previous_ensemble_subnetworks=previous_subnetworks,
            features=features,
            labels=labels,
            logits_dimension=self._head.logits_dimension,
            training=mode == tf.estimator.ModeKeys.TRAIN,
            iteration_step=iteration_step,
            summary=summary,
            previous_ensemble=previous_ensemble)
      ensemble_var_list = _new_trainable_variables(before_var_list)

      estimator_spec = _create_estimator_spec(
          self._head, features, labels, mode, ensemble.logits, self._use_tpu)

      ensemble_loss = estimator_spec.loss
      adanet_loss = None
      if mode != tf.estimator.ModeKeys.PREDICT:
        # TODO: Support any kind of Ensemble. Use a moving average of
        # their train loss for the 'adanet_loss'.
        if not isinstance(ensemble, ComplexityRegularized):
          raise ValueError(
              "Only ComplexityRegularized ensembles are supported.")
        adanet_loss = estimator_spec.loss + ensemble.complexity_regularization

      ensemble_metrics = _EnsembleMetrics()
      if mode == tf.estimator.ModeKeys.EVAL:
        ensemble_metrics.create_eval_metrics(
            features=features,
            labels=labels,
            estimator_spec=estimator_spec,
            metric_fn=self._metric_fn,
            architecture=architecture)

      if mode == tf.estimator.ModeKeys.TRAIN:
        with summary.current_scope():
          summary.scalar("loss", estimator_spec.loss)

      # Create train ops for training subnetworks and ensembles.
      train_op = None
      if mode == tf.estimator.ModeKeys.TRAIN:
        # Note that these mixture weights are on top of the last_layer of the
        # subnetwork constructed in TRAIN mode, which means that dropout is
        # still applied when the mixture weights are being trained.
        ensemble_scope = tf.get_variable_scope()
        with tf.variable_scope("train_mixture_weights"):
          with summary.current_scope(), _monkey_patch_context(
              iteration_step_scope=ensemble_scope,
              scoped_summary=summary,
              trainable_vars=ensemble_var_list):
            # For backwards compatibility.
            subnetwork_builder = candidate.subnetwork_builders[0]
            old_train_op_fn = getattr(subnetwork_builder,
                                      "build_mixture_weights_train_op", None)
            if callable(old_train_op_fn):
              tf.logging.warn(
                  "The `build_mixture_weights_train_op` method is deprecated. "
                  "Please use the `Ensembler#build_train_op` instead.")
              train_op = _to_train_op_spec(
                  subnetwork_builder.build_mixture_weights_train_op(
                      loss=adanet_loss,
                      var_list=ensemble_var_list,
                      logits=ensemble.logits,
                      labels=labels,
                      iteration_step=iteration_step,
                      summary=summary))
            else:
              train_op = _to_train_op_spec(
                  ensembler.build_train_op(
                      ensemble=ensemble,
                      loss=adanet_loss,
                      var_list=ensemble_var_list,
                      labels=labels,
                      iteration_step=iteration_step,
                      summary=summary,
                      previous_ensemble=previous_ensemble))
    return _EnsembleSpec(
        name=name,
        architecture=architecture,
        subnetwork_builders=subnetwork_builders,
        ensemble=ensemble,
        predictions=estimator_spec.predictions,
        loss=ensemble_loss,
        adanet_loss=adanet_loss,
        train_op=train_op,
        eval_metrics=ensemble_metrics.eval_metrics_tuple(),
        export_outputs=estimator_spec.export_outputs)
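
The deprecation warning above points users toward `adanet.ensemble.Strategy`. As a hedged sketch of the replacement pattern (the `KeepLastStrategy` name is hypothetical; the `Strategy` and `Candidate` interfaces match the fields this method reads, e.g. `candidate.previous_ensemble_subnetwork_builders`), pruning can instead be expressed through the candidates a strategy generates:

import adanet

class KeepLastStrategy(adanet.ensemble.Strategy):
    """Hypothetical strategy: keep only the newest previous subnetwork."""

    def generate_ensemble_candidates(self, subnetwork_builders,
                                     previous_ensemble_subnetwork_builders):
        # Pruning is expressed declaratively: each candidate lists which
        # previous-ensemble subnetwork builders it retains, instead of a
        # Builder overriding prune_previous_ensemble.
        kept = list(previous_ensemble_subnetwork_builders or [])[-1:]
        return [
            adanet.ensemble.Candidate(
                name="keep_last_{}".format(builder.name),
                subnetwork_builders=[builder],
                previous_ensemble_subnetwork_builders=kept)
            for builder in subnetwork_builders
        ]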
Example 4
    def build_ensemble_spec(self, name, candidate, ensembler, subnetwork_specs,
                            summary, features, mode, iteration_number, labels,
                            my_ensemble_index, previous_ensemble_spec,
                            previous_iteration_checkpoint):
        """Builds an `_EnsembleSpec` with the given `adanet.ensemble.Candidate`.

        Args:
          name: The string name of the ensemble. Typically the name of the
            builder that returned the given `Subnetwork`.
          candidate: The `adanet.ensemble.Candidate` for this spec.
          ensembler: The :class:`adanet.ensemble.Ensembler` to use to ensemble
            a group of subnetworks.
          subnetwork_specs: Iterable of `_SubnetworkSpecs` for this iteration.
          summary: A `_ScopedSummary` instance for recording ensemble
            summaries.
          features: Input `dict` of `Tensor` objects.
          mode: Estimator `ModeKeys` indicating training, evaluation, or
            inference.
          iteration_number: Integer current iteration number.
          labels: Labels `Tensor` or a dictionary of string label name to
            `Tensor` (for multi-head).
          my_ensemble_index: An integer holding the index of the ensemble in
            the candidates list of AdaNet.
          previous_ensemble_spec: The `_EnsembleSpec` from iteration t-1. Used
            for creating the subnetwork train_op.
          previous_iteration_checkpoint: `tf.train.Checkpoint` for iteration
            t-1.

        Returns:
          An `_EnsembleSpec` instance.
        """

        with tf_compat.v1.variable_scope("ensemble_{}".format(name)):
            step = tf_compat.v1.get_variable(
                "step",
                shape=[],
                initializer=tf_compat.v1.zeros_initializer(),
                trainable=False,
                dtype=tf.int64)
            # Convert to tensor so that users cannot mutate it.
            step_tensor = tf.convert_to_tensor(value=step)
            with summary.current_scope():
                summary.scalar("iteration_step/adanet/iteration_step",
                               step_tensor)
            replay_indices = []
            if previous_ensemble_spec:
                replay_indices = copy.copy(
                    previous_ensemble_spec.architecture.replay_indices)
            if my_ensemble_index is not None:
                replay_indices.append(my_ensemble_index)

            architecture = _Architecture(candidate.name,
                                         ensembler.name,
                                         replay_indices=replay_indices)
            previous_subnetworks = []
            previous_subnetwork_specs = []
            subnetwork_builders = []
            previous_ensemble = None
            if previous_ensemble_spec:
                previous_ensemble = previous_ensemble_spec.ensemble
                previous_architecture = previous_ensemble_spec.architecture
                keep_indices = range(len(previous_ensemble.subnetworks))
                if len(candidate.subnetwork_builders) == 1 and previous_ensemble:
                    # Prune previous ensemble according to the subnetwork.Builder for
                    # backwards compatibility.
                    subnetwork_builder = candidate.subnetwork_builders[0]
                    prune_previous_ensemble = getattr(
                        subnetwork_builder, "prune_previous_ensemble", None)
                    if callable(prune_previous_ensemble):
                        logging.warn(
                            "Using an `adanet.subnetwork.Builder#prune_previous_ensemble` "
                            "is deprecated. Please use a custom `adanet.ensemble.Strategy` "
                            "instead.")
                        keep_indices = prune_previous_ensemble(
                            previous_ensemble)
                for i, builder in enumerate(
                        previous_ensemble_spec.subnetwork_builders):
                    if i not in keep_indices:
                        continue
                    if builder not in candidate.previous_ensemble_subnetwork_builders:
                        continue
                    previous_subnetworks.append(
                        previous_ensemble.subnetworks[i])
                    previous_subnetwork_specs.append(
                        previous_ensemble_spec.subnetwork_specs[i])
                    subnetwork_builders.append(builder)
                    architecture.add_subnetwork(
                        *previous_architecture.subnetworks[i])
            for builder in candidate.subnetwork_builders:
                architecture.add_subnetwork(iteration_number, builder.name)
                subnetwork_builders.append(builder)
            subnetwork_spec_map = {s.builder.name: s for s in subnetwork_specs}
            relevant_subnetwork_specs = [
                subnetwork_spec_map[s.name]
                for s in candidate.subnetwork_builders
            ]
            ensemble_scope = tf_compat.v1.get_variable_scope()

            # Snapshot existing variables so that the ensemble's newly created
            # variables can be diffed out below.
            old_vars = _get_current_vars()

            with summary.current_scope(), _monkey_patch_context(
                    iteration_step_scope=ensemble_scope,
                    scoped_summary=summary,
                    trainable_vars=[]):
                ensemble = ensembler.build_ensemble(
                    subnetworks=[
                        s.subnetwork for s in relevant_subnetwork_specs
                    ],
                    previous_ensemble_subnetworks=previous_subnetworks,
                    features=features,
                    labels=labels,
                    logits_dimension=self._head.logits_dimension,
                    training=mode == tf.estimator.ModeKeys.TRAIN,
                    iteration_step=step_tensor,
                    summary=summary,
                    previous_ensemble=previous_ensemble,
                    previous_iteration_checkpoint=previous_iteration_checkpoint
                )

            estimator_spec = _create_estimator_spec(self._head, features,
                                                    labels, mode,
                                                    ensemble.logits,
                                                    self._use_tpu)

            ensemble_loss = estimator_spec.loss
            adanet_loss = None
            if mode != tf.estimator.ModeKeys.PREDICT:
                adanet_loss = estimator_spec.loss
                # Add ensembler specific loss
                if isinstance(ensemble, ensemble_lib.ComplexityRegularized):
                    adanet_loss += ensemble.complexity_regularization

            predictions = estimator_spec.predictions
            export_outputs = estimator_spec.export_outputs

            # Optionally export each subnetwork's logits as additional serving
            # signatures (one signature per head in the multi-head case).
            if (self._export_subnetwork_logits and export_outputs
                    and subnetwork_spec_map):
                first_subnetwork_logits = list(
                    subnetwork_spec_map.values())[0].subnetwork.logits
                if isinstance(first_subnetwork_logits, dict):
                    for head_name in first_subnetwork_logits.keys():
                        subnetwork_logits = {
                            subnetwork_name:
                            subnetwork_spec.subnetwork.logits[head_name]
                            for subnetwork_name, subnetwork_spec in
                            subnetwork_spec_map.items()
                        }
                        export_outputs.update({
                            "{}_{}".format(
                                _EnsembleBuilder._SUBNETWORK_LOGITS_EXPORT_SIGNATURE, head_name):
                            tf.estimator.export.PredictOutput(
                                subnetwork_logits)
                        })
                else:
                    subnetwork_logits = {
                        subnetwork_name: subnetwork_spec.subnetwork.logits
                        for subnetwork_name, subnetwork_spec in
                        subnetwork_spec_map.items()
                    }
                    export_outputs.update({
                        _EnsembleBuilder._SUBNETWORK_LOGITS_EXPORT_SIGNATURE:
                        tf.estimator.export.PredictOutput(subnetwork_logits)
                    })

            # Likewise export each subnetwork's last layer, when one exists.
            if (self._export_subnetwork_last_layer and export_outputs
                    and subnetwork_spec_map and list(
                        subnetwork_spec_map.values())[0].subnetwork.last_layer
                    is not None):
                first_subnetwork_last_layer = list(
                    subnetwork_spec_map.values())[0].subnetwork.last_layer
                if isinstance(first_subnetwork_last_layer, dict):
                    for head_name in first_subnetwork_last_layer.keys():
                        subnetwork_last_layer = {
                            subnetwork_name:
                            subnetwork_spec.subnetwork.last_layer[head_name]
                            for subnetwork_name, subnetwork_spec in
                            subnetwork_spec_map.items()
                        }
                        export_outputs.update({
                            "{}_{}".format(
                                _EnsembleBuilder._SUBNETWORK_LAST_LAYER_EXPORT_SIGNATURE, head_name):
                            tf.estimator.export.PredictOutput(
                                subnetwork_last_layer)
                        })
                else:
                    subnetwork_last_layer = {
                        subnetwork_name: subnetwork_spec.subnetwork.last_layer
                        for subnetwork_name, subnetwork_spec in
                        subnetwork_spec_map.items()
                    }
                    export_outputs.update({
                        _EnsembleBuilder._SUBNETWORK_LAST_LAYER_EXPORT_SIGNATURE:
                        tf.estimator.export.PredictOutput(
                            subnetwork_last_layer)
                    })

            if ensemble.predictions and predictions:
                predictions.update(ensemble.predictions)
            if ensemble.predictions and export_outputs:
                export_outputs.update({
                    k: tf.estimator.export.PredictOutput(v)
                    for k, v in ensemble.predictions.items()
                })

            ensemble_metrics = _EnsembleMetrics(use_tpu=self._use_tpu)
            if mode == tf.estimator.ModeKeys.EVAL:
                ensemble_metrics.create_eval_metrics(
                    features=features,
                    labels=labels,
                    estimator_spec=estimator_spec,
                    metric_fn=self._metric_fn,
                    architecture=architecture)

            if mode == tf.estimator.ModeKeys.TRAIN:
                with summary.current_scope():
                    summary.scalar("loss", estimator_spec.loss)

            ensemble_trainable_vars = _get_current_vars(
                diffbase=old_vars)["trainable"]
            # Create train ops for training subnetworks and ensembles.
            train_op = None
            if mode == tf.estimator.ModeKeys.TRAIN:
                # Note that these mixture weights are on top of the last_layer of the
                # subnetwork constructed in TRAIN mode, which means that dropout is
                # still applied when the mixture weights are being trained.
                ensemble_scope = tf_compat.v1.get_variable_scope()
                with tf_compat.v1.variable_scope("train_mixture_weights"):
                    with summary.current_scope(), _monkey_patch_context(
                            iteration_step_scope=ensemble_scope,
                            scoped_summary=summary,
                            trainable_vars=ensemble_trainable_vars):
                        # For backwards compatibility.
                        subnetwork_builder = candidate.subnetwork_builders[0]
                        old_train_op_fn = getattr(
                            subnetwork_builder,
                            "build_mixture_weights_train_op", None)
                        if callable(old_train_op_fn):
                            logging.warn(
                                "The `build_mixture_weights_train_op` method is deprecated. "
                                "Please use the `Ensembler#build_train_op` instead."
                            )
                            train_op = _to_train_op_spec(
                                subnetwork_builder.build_mixture_weights_train_op(
                                    loss=adanet_loss,
                                    var_list=ensemble_trainable_vars,
                                    logits=ensemble.logits,
                                    labels=labels,
                                    iteration_step=step_tensor,
                                    summary=summary))
                        else:
                            train_op = _to_train_op_spec(
                                ensembler.build_train_op(
                                    ensemble=ensemble,
                                    loss=adanet_loss,
                                    var_list=ensemble_trainable_vars,
                                    labels=labels,
                                    iteration_step=step_tensor,
                                    summary=summary,
                                    previous_ensemble=previous_ensemble))

            new_vars = _get_current_vars(diffbase=old_vars)
            # Sort our dictionary by key to remove non-determinism of variable order.
            new_vars = collections.OrderedDict(sorted(new_vars.items()))
            # Combine all trainable, global and savable variables into a single list.
            ensemble_variables = sum(new_vars.values(), []) + [step]

        return _EnsembleSpec(name=name,
                             architecture=architecture,
                             subnetwork_builders=subnetwork_builders,
                             subnetwork_specs=previous_subnetwork_specs +
                             relevant_subnetwork_specs,
                             ensemble=ensemble,
                             predictions=predictions,
                             step=step,
                             variables=ensemble_variables,
                             loss=ensemble_loss,
                             adanet_loss=adanet_loss,
                             train_op=train_op,
                             eval_metrics=ensemble_metrics,
                             export_outputs=export_outputs)
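
`_get_current_vars` is a private helper that is not shown here. From its use above (a snapshot into `old_vars`, then `diffbase=old_vars` to recover only newly created variables, keyed by collection), a simplified sketch might look like the following. This is an assumption about the helper's behavior; the real implementation presumably also tracks savable objects, per the comment above about combining "trainable, global and savable" variables:

import tensorflow as tf

def _get_current_vars(diffbase=None):
    # Snapshot the graph's variables by collection; with `diffbase`, return
    # only the variables created since that earlier snapshot.
    trainable = tf.compat.v1.trainable_variables()
    trainable_ids = {id(v) for v in trainable}
    current = {
        "trainable": trainable,
        "global": [v for v in tf.compat.v1.global_variables()
                   if id(v) not in trainable_ids],
    }
    if diffbase:
        current = {
            key: [v for v in varlist
                  if id(v) not in {id(old) for old in diffbase.get(key, [])}]
            for key, varlist in current.items()
        }
    return current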