Example #1
  def build_iteration(self,
                      base_global_step,
                      iteration_number,
                      ensemble_candidates,
                      subnetwork_builders,
                      features,
                      mode,
                      config,
                      labels=None,
                      previous_ensemble_summary=None,
                      previous_ensemble_spec=None,
                      rebuilding=False,
                      rebuilding_ensembler_name=None,
                      best_ensemble_index_override=None):
    """Builds and returns AdaNet iteration t.

    This method uses the generated candidate subnetworks given the ensemble
    at iteration t-1 and creates graph operations to train them. The returned
    `_Iteration` tracks the training of all candidates to know when the
    iteration is over, and tracks the best candidate's predictions and loss, as
    defined by lowest complexity-regularized loss on the train set.

    Args:
      base_global_step: Integer global step at the beginning of this iteration.
      iteration_number: Integer iteration number.
      ensemble_candidates: Iterable of `adanet.ensemble.Candidate` instances.
      subnetwork_builders: A list of `Builders` for adding `Subnetworks` to the
        graph. Each subnetwork is then wrapped in a `_Candidate` to train.
      features: Dictionary of `Tensor` objects keyed by feature name.
      mode: Defines whether this is training, evaluation or prediction. See
        `ModeKeys`.
      config: The `tf.estimator.RunConfig` to use for this iteration.
      labels: `Tensor` of labels. Can be `None`.
      previous_ensemble_summary: The `adanet.Summary` for the previous ensemble.
      previous_ensemble_spec: Optional `_EnsembleSpec` for iteration t-1.
      rebuilding: Boolean whether the iteration is being rebuilt only to restore
        the previous best subnetworks and ensembles.
      rebuilding_ensembler_name: Optional name of the ensembler to restrict to;
        only relevant when `rebuilding` is True.
      best_ensemble_index_override: Integer index identifying the best ensemble
        candidate, used instead of dynamically computing the best ensemble
        index from the ensembles' AdaNet losses.

    Returns:
      An _Iteration instance.

    Raises:
      ValueError: If subnetwork_builders is empty.
      ValueError: If two subnetworks share the same name.
      ValueError: If two ensembles share the same name.
    """

    self._placement_strategy.config = config

    logging.info("%s iteration %s", "Rebuilding" if rebuilding else "Building",
                 iteration_number)

    if not subnetwork_builders:
      raise ValueError("Each iteration must have at least one Builder.")

    # TODO: Consider moving builder mode logic to ensemble_builder.py.
    builder_mode = mode
    if rebuilding:
      # Build the subnetworks and ensembles in EVAL mode by default. This way
      # their outputs aren't affected by dropout etc.
      builder_mode = tf.estimator.ModeKeys.EVAL
      if mode == tf.estimator.ModeKeys.PREDICT:
        builder_mode = mode

      # Only replicate in training mode when the user requests it.
      if self._replicate_ensemble_in_training and (
          mode == tf.estimator.ModeKeys.TRAIN):
        builder_mode = mode

    features, labels = self._check_numerics(features, labels)

    training = mode == tf.estimator.ModeKeys.TRAIN
    skip_summaries = mode == tf.estimator.ModeKeys.PREDICT or rebuilding
    with tf_compat.v1.variable_scope("iteration_{}".format(iteration_number)):
      seen_builder_names = {}
      candidates = []
      summaries = []
      subnetwork_reports = {}
      previous_ensemble = None

      if previous_ensemble_spec:
        previous_ensemble = previous_ensemble_spec.ensemble
        # Include previous best subnetwork as a candidate so that its
        # predictions are returned until a new candidate outperforms it.
        seen_builder_names = {previous_ensemble_spec.name: True}
        previous_best_candidate = self._candidate_builder.build_candidate(
            ensemble_spec=previous_ensemble_spec,
            training=training,
            summary=previous_ensemble_summary)
        candidates.append(previous_best_candidate)
        if self._enable_ensemble_summaries:
          summaries.append(previous_ensemble_summary)

        # Generate subnetwork reports.
        if mode == tf.estimator.ModeKeys.EVAL:
          metrics = call_eval_metrics(previous_ensemble_spec.eval_metrics)
          subnetwork_report = subnetwork.Report(
              hparams={},
              attributes={},
              metrics=metrics,
          )
          subnetwork_report.metrics["adanet_loss"] = tf_compat.v1.metrics.mean(
              previous_ensemble_spec.adanet_loss)
          subnetwork_reports["previous_ensemble"] = subnetwork_report

      for subnetwork_builder in subnetwork_builders:
        if subnetwork_builder.name in seen_builder_names:
          raise ValueError("Two subnetworks have the same name '{}'".format(
              subnetwork_builder.name))
        seen_builder_names[subnetwork_builder.name] = True
      subnetwork_specs = []
      num_subnetworks = len(subnetwork_builders)
      skip_summary = skip_summaries or not self._enable_subnetwork_summaries
      for i, subnetwork_builder in enumerate(subnetwork_builders):
        if not self._placement_strategy.should_build_subnetwork(
            num_subnetworks, i) and not rebuilding:
          continue
        with self._placement_strategy.subnetwork_devices(num_subnetworks, i):
          subnetwork_name = "t{}_{}".format(iteration_number,
                                            subnetwork_builder.name)
          subnetwork_summary = self._summary_maker(
              namespace="subnetwork",
              scope=subnetwork_name,
              skip_summary=skip_summary)
          if not skip_summary:
            summaries.append(subnetwork_summary)
          logging.info("%s subnetwork '%s'",
                       "Rebuilding" if rebuilding else "Building",
                       subnetwork_builder.name)
          subnetwork_spec = self._subnetwork_manager.build_subnetwork_spec(
              name=subnetwork_name,
              subnetwork_builder=subnetwork_builder,
              summary=subnetwork_summary,
              features=features,
              mode=builder_mode,
              labels=labels,
              previous_ensemble=previous_ensemble)
          subnetwork_specs.append(subnetwork_spec)
          # Workers that don't build ensembles need a dummy candidate in order
          # to train the subnetwork.
          # Because only ensembles can be considered candidates, we need to
          # convert the subnetwork into a dummy ensemble and subsequently a
          # dummy candidate. However, this dummy candidate is never considered a
          # true candidate during candidate evaluation and selection.
          # TODO: Eliminate need for candidates.
          if not self._placement_strategy.should_build_ensemble(
              num_subnetworks) and not rebuilding:
            candidates.append(
                self._create_dummy_candidate(subnetwork_spec,
                                             subnetwork_builders,
                                             subnetwork_summary, training))
        # Generate subnetwork reports.
        if mode != tf.estimator.ModeKeys.PREDICT:
          subnetwork_report = subnetwork_builder.build_subnetwork_report()
          if not subnetwork_report:
            subnetwork_report = subnetwork.Report(
                hparams={}, attributes={}, metrics={})
          metrics = call_eval_metrics(subnetwork_spec.eval_metrics)
          for metric_name in sorted(metrics):
            metric = metrics[metric_name]
            subnetwork_report.metrics[metric_name] = metric
          subnetwork_reports[subnetwork_builder.name] = subnetwork_report

      # Create (ensemble_candidate*ensembler) ensembles.
      skip_summary = skip_summaries or not self._enable_ensemble_summaries
      seen_ensemble_names = {}
      for ensembler in self._ensemblers:
        if rebuilding and rebuilding_ensembler_name and (
            ensembler.name != rebuilding_ensembler_name):
          continue
        for ensemble_candidate in ensemble_candidates:
          if not self._placement_strategy.should_build_ensemble(
              num_subnetworks) and not rebuilding:
            continue
          ensemble_name = "t{}_{}_{}".format(iteration_number,
                                             ensemble_candidate.name,
                                             ensembler.name)
          if ensemble_name in seen_ensemble_names:
            raise ValueError(
                "Two ensembles have the same name '{}'".format(ensemble_name))
          seen_ensemble_names[ensemble_name] = True
          summary = self._summary_maker(
              namespace="ensemble",
              scope=ensemble_name,
              skip_summary=skip_summary)
          if not skip_summary:
            summaries.append(summary)
          ensemble_spec = self._ensemble_builder.build_ensemble_spec(
              name=ensemble_name,
              candidate=ensemble_candidate,
              ensembler=ensembler,
              subnetwork_specs=subnetwork_specs,
              summary=summary,
              features=features,
              mode=builder_mode,
              iteration_number=iteration_number,
              labels=labels,
              previous_ensemble_spec=previous_ensemble_spec)
          # TODO: Eliminate need for candidates.
          # TODO: Don't track moving average of loss when rebuilding
          # previous ensemble.
          candidate = self._candidate_builder.build_candidate(
              ensemble_spec=ensemble_spec, training=training, summary=summary)
          candidates.append(candidate)
          # TODO: Move adanet_loss from subnetwork report to a new
          # ensemble report, since the adanet_loss is associated with an
          # ensemble, and only when using a ComplexityRegularizedEnsembler.
          # Keep adanet_loss in subnetwork report for backwards compatibility.
          if len(ensemble_candidates) != len(subnetwork_builders):
            continue
          if len(ensemble_candidate.subnetwork_builders) > 1:
            continue
          if mode == tf.estimator.ModeKeys.PREDICT:
            continue
          builder_name = ensemble_candidate.subnetwork_builders[0].name
          subnetwork_reports[builder_name].metrics[
              "adanet_loss"] = tf_compat.v1.metrics.mean(
                  ensemble_spec.adanet_loss)

      # Dynamically select the outputs of best candidate.
      best_candidate_index = self._best_candidate_index(
          candidates, best_ensemble_index_override)
      best_predictions = self._best_predictions(candidates,
                                                best_candidate_index)
      best_loss = self._best_loss(candidates, best_candidate_index, mode)
      best_export_outputs = self._best_export_outputs(candidates,
                                                      best_candidate_index,
                                                      mode, best_predictions)
      train_manager_dir = os.path.join(config.model_dir, "train_manager",
                                       "t{}".format(iteration_number))
      train_manager, training_chief_hooks, training_hooks = self._create_hooks(
          base_global_step, subnetwork_specs, candidates, num_subnetworks,
          rebuilding, train_manager_dir, config.is_chief)
      # Iteration summaries.
      summary = self._summary_maker(
          namespace=None, scope=None, skip_summary=skip_summaries)
      summaries.append(summary)
      with summary.current_scope():
        summary.scalar("iteration/adanet/iteration", iteration_number)
        if best_loss is not None:
          summary.scalar("loss", best_loss)
      iteration_metrics = _IterationMetrics(iteration_number, candidates,
                                            subnetwork_specs)
      # All training happens in hooks so we don't need a train op.
      train_op = tf.no_op()
      if self._use_tpu:
        estimator_spec = tf_compat.v1.estimator.tpu.TPUEstimatorSpec(
            mode=mode,
            predictions=best_predictions,
            loss=best_loss,
            train_op=self._create_tpu_train_op(base_global_step,
                                               subnetwork_specs, candidates,
                                               mode, num_subnetworks, config),
            eval_metrics=iteration_metrics.best_eval_metrics_tuple(
                best_candidate_index, mode),
            export_outputs=best_export_outputs,
            training_hooks=training_hooks)
      else:
        estimator_spec = tf.estimator.EstimatorSpec(
            mode=mode,
            predictions=best_predictions,
            loss=best_loss,
            train_op=train_op,
            eval_metric_ops=iteration_metrics.best_eval_metric_ops(
                best_candidate_index, mode),
            export_outputs=best_export_outputs,
            training_chief_hooks=training_chief_hooks,
            training_hooks=training_hooks)

      return _Iteration(
          number=iteration_number,
          candidates=candidates,
          subnetwork_specs=subnetwork_specs,
          estimator_spec=estimator_spec,
          best_candidate_index=best_candidate_index,
          summaries=summaries,
          train_manager=train_manager,
          subnetwork_reports=subnetwork_reports)
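
As the docstring above notes, the iteration tracks the best candidate as the one with the lowest complexity-regularized (AdaNet) loss, and `_best_candidate_index` / `_best_predictions` select that candidate's outputs dynamically in the graph. Below is a minimal, self-contained sketch of that selection idea using plain TensorFlow ops; the tensor names and values are hypothetical, and this is not the library's internal implementation.

import tensorflow as tf

# Hypothetical per-candidate values; in the real iteration these come from the
# moving averages of each candidate's complexity-regularized loss.
candidate_adanet_losses = tf.constant([0.42, 0.37, 0.55])
candidate_predictions = tf.constant([[1.0], [2.0], [3.0]])

# The best candidate is the one with the lowest AdaNet loss.
best_candidate_index = tf.argmin(candidate_adanet_losses)

# Select that candidate's outputs at graph-execution time, so the selection
# can change as the loss moving averages are updated during training.
best_predictions = tf.gather(candidate_predictions, best_candidate_index)
# Under eager execution, best_candidate_index evaluates to 1 and
# best_predictions to [2.0].
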
Example #2
  def build_iteration(self,
                      iteration_number,
                      ensemble_candidates,
                      subnetwork_builders,
                      features,
                      mode,
                      labels=None,
                      previous_ensemble_summary=None,
                      previous_ensemble_spec=None,
                      skip_summaries=False,
                      rebuilding=False):
    """Builds and returns AdaNet iteration t.

    This method uses the generated candidate subnetworks given the ensemble
    at iteration t-1 and creates graph operations to train them. The returned
    `_Iteration` tracks the training of all candidates to know when the
    iteration is over, and tracks the best candidate's predictions and loss, as
    defined by lowest complexity-regularized loss on the train set.

    Args:
      iteration_number: Integer iteration number.
      ensemble_candidates: Iterable of `adanet.ensemble.Candidate` instances.
      subnetwork_builders: A list of `Builders` for adding `Subnetworks` to the
        graph. Each subnetwork is then wrapped in a `_Candidate` to train.
      features: Dictionary of `Tensor` objects keyed by feature name.
      mode: Defines whether this is training, evaluation or prediction. See
        `ModeKeys`.
      labels: `Tensor` of labels. Can be `None`.
      previous_ensemble_summary: The `adanet.Summary` for the previous ensemble.
      previous_ensemble_spec: Optional `_EnsembleSpec` for iteration t-1.
      skip_summaries: Whether to skip creating the summary ops when building
        the `_Iteration`.
      rebuilding: Boolean whether the iteration is being rebuilt only to restore
        the previous best subnetworks and ensembles.

    Returns:
      An _Iteration instance.

    Raises:
      ValueError: If subnetwork_builders is empty.
      ValueError: If two subnetworks share the same name.
      ValueError: If two ensembles share the same name.
    """

    tf.logging.info("%s iteration %s",
                    "Rebuilding" if rebuilding else "Building",
                    iteration_number)

    if not subnetwork_builders:
      raise ValueError("Each iteration must have at least one Builder.")

    # TODO: Consider moving builder mode logic to ensemble_builder.py.
    builder_mode = mode
    if rebuilding:
      # Build the subnetworks and ensembles in EVAL mode by default. This way
      # their outputs aren't affected by dropout etc.
      builder_mode = tf.estimator.ModeKeys.EVAL
      if mode == tf.estimator.ModeKeys.PREDICT:
        builder_mode = mode

      # Only replicate in training mode when the user requests it.
      if self._replicate_ensemble_in_training and (
          mode == tf.estimator.ModeKeys.TRAIN):
        builder_mode = mode

    features, labels = self._check_numerics(features, labels)

    training = mode == tf.estimator.ModeKeys.TRAIN
    skip_summaries = mode == tf.estimator.ModeKeys.PREDICT
    with tf.variable_scope("iteration_{}".format(iteration_number)):
      # Iteration step to use instead of global step.
      iteration_step = tf.get_variable(
          "step",
          shape=[],
          initializer=tf.zeros_initializer(),
          trainable=False,
          dtype=tf.int64)

      # Convert to tensor so that users cannot mutate it.
      iteration_step_tensor = tf.convert_to_tensor(iteration_step)

      seen_builder_names = {}
      candidates = []
      summaries = []
      subnetwork_reports = {}
      previous_ensemble = None

      if previous_ensemble_spec:
        previous_ensemble = previous_ensemble_spec.ensemble
        # Include previous best subnetwork as a candidate so that its
        # predictions are returned until a new candidate outperforms it.
        seen_builder_names = {previous_ensemble_spec.name: True}
        previous_best_candidate = self._candidate_builder.build_candidate(
            ensemble_spec=previous_ensemble_spec,
            training=training,
            iteration_step=iteration_step_tensor,
            summary=previous_ensemble_summary,
            is_previous_best=True)
        candidates.append(previous_best_candidate)
        summaries.append(previous_ensemble_summary)

        # Generate subnetwork reports.
        if mode == tf.estimator.ModeKeys.EVAL:
          metrics = call_eval_metrics(previous_ensemble_spec.eval_metrics)
          subnetwork_report = subnetwork.Report(
              hparams={},
              attributes={},
              metrics=metrics,
          )
          subnetwork_report.metrics["adanet_loss"] = tf.metrics.mean(
              previous_ensemble_spec.adanet_loss)
          subnetwork_reports["previous_ensemble"] = subnetwork_report

      for subnetwork_builder in subnetwork_builders:
        if subnetwork_builder.name in seen_builder_names:
          raise ValueError("Two subnetworks have the same name '{}'".format(
              subnetwork_builder.name))
        seen_builder_names[subnetwork_builder.name] = True
      subnetwork_specs = []
      num_subnetworks = len(subnetwork_builders)
      for i, subnetwork_builder in enumerate(subnetwork_builders):
        if not self._placement_strategy.should_build_subnetwork(
            num_subnetworks, i) and not rebuilding:
          continue
        subnetwork_name = "t{}_{}".format(iteration_number,
                                          subnetwork_builder.name)
        subnetwork_summary = self._summary_maker(
            namespace="subnetwork",
            scope=subnetwork_name,
            skip_summary=skip_summaries or rebuilding)
        summaries.append(subnetwork_summary)
        tf.logging.info("%s subnetwork '%s'",
                        "Rebuilding" if rebuilding else "Building",
                        subnetwork_builder.name)
        subnetwork_spec = self._subnetwork_manager.build_subnetwork_spec(
            name=subnetwork_name,
            subnetwork_builder=subnetwork_builder,
            iteration_step=iteration_step_tensor,
            summary=subnetwork_summary,
            features=features,
            mode=builder_mode,
            labels=labels,
            previous_ensemble=previous_ensemble)
        subnetwork_specs.append(subnetwork_spec)
        if not self._placement_strategy.should_build_ensemble(
            num_subnetworks) and not rebuilding:
          # Workers that don't build ensembles need a dummy candidate in order
          # to train the subnetwork.
          # Because only ensembles can be considered candidates, we need to
          # convert the subnetwork into a dummy ensemble and subsequently a
          # dummy candidate. However, this dummy candidate is never considered a
          # true candidate during candidate evaluation and selection.
          # TODO: Eliminate need for candidates.
          dummy_candidate = self._candidate_builder.build_candidate(
              # pylint: disable=protected-access
              ensemble_spec=ensemble_builder_lib._EnsembleSpec(
                  name=subnetwork_name,
                  ensemble=None,
                  architecture=None,
                  subnetwork_builders=subnetwork_builders,
                  predictions=subnetwork_spec.predictions,
                  loss=subnetwork_spec.loss,
                  adanet_loss=0.),
              # pylint: enable=protected-access
              training=training,
              iteration_step=iteration_step_tensor,
              summary=subnetwork_summary,
              track_moving_average=False)
          candidates.append(dummy_candidate)
        # Generate subnetwork reports.
        if mode != tf.estimator.ModeKeys.PREDICT:
          subnetwork_report = subnetwork_builder.build_subnetwork_report()
          if not subnetwork_report:
            subnetwork_report = subnetwork.Report(
                hparams={}, attributes={}, metrics={})
          metrics = call_eval_metrics(subnetwork_spec.eval_metrics)
          for metric_name in sorted(metrics):
            metric = metrics[metric_name]
            subnetwork_report.metrics[metric_name] = metric
          subnetwork_reports[subnetwork_builder.name] = subnetwork_report

      # Create (ensemble_candidate*ensembler) ensembles.
      seen_ensemble_names = {}
      for ensembler in self._ensemblers:
        for ensemble_candidate in ensemble_candidates:
          if not self._placement_strategy.should_build_ensemble(
              num_subnetworks) and not rebuilding:
            continue
          ensemble_name = "t{}_{}_{}".format(
              iteration_number, ensemble_candidate.name, ensembler.name)
          if ensemble_name in seen_ensemble_names:
            raise ValueError(
                "Two ensembles have the same name '{}'".format(ensemble_name))
          seen_ensemble_names[ensemble_name] = True
          summary = self._summary_maker(
              namespace="ensemble",
              scope=ensemble_name,
              skip_summary=skip_summaries or rebuilding)
          summaries.append(summary)
          ensemble_spec = self._ensemble_builder.build_ensemble_spec(
              name=ensemble_name,
              candidate=ensemble_candidate,
              ensembler=ensembler,
              subnetwork_specs=subnetwork_specs,
              summary=summary,
              features=features,
              mode=builder_mode,
              iteration_step=iteration_step_tensor,
              iteration_number=iteration_number,
              labels=labels,
              previous_ensemble_spec=previous_ensemble_spec)
          # TODO: Eliminate need for candidates.
          # TODO: Don't track moving average of loss when rebuilding
          # previous ensemble.
          candidate = self._candidate_builder.build_candidate(
              ensemble_spec=ensemble_spec,
              training=training,
              iteration_step=iteration_step_tensor,
              summary=summary)
          candidates.append(candidate)
          # TODO: Move adanet_loss from subnetwork report to a new
          # ensemble report, since the adanet_loss is associated with an
          # ensemble, and only when using a ComplexityRegularizedEnsembler.
          # Keep adanet_loss in subnetwork report for backwards compatibility.
          if len(ensemble_candidates) != len(subnetwork_builders):
            continue
          if len(ensemble_candidate.subnetwork_builders) > 1:
            continue
          if mode == tf.estimator.ModeKeys.PREDICT:
            continue
          builder_name = ensemble_candidate.subnetwork_builders[0].name
          subnetwork_reports[builder_name].metrics[
              "adanet_loss"] = tf.metrics.mean(ensemble_spec.adanet_loss)

      # Dynamically select the outputs of best candidate.
      best_candidate_index = self._best_candidate_index(candidates)
      best_predictions = self._best_predictions(candidates,
                                                best_candidate_index)
      best_loss = self._best_loss(candidates, best_candidate_index, mode)
      best_export_outputs = self._best_export_outputs(
          candidates, best_candidate_index, mode, best_predictions)
      # Hooks on TPU cannot depend on any graph `Tensors`. Instead the value of
      # `is_over` is stored in a `Variable` that can later be retrieved from
      # inside a training hook.
      is_over_var_template = tf.make_template("is_over_var_template",
                                              _is_over_var)
      training_chief_hooks, training_hooks = (), ()
      for subnetwork_spec in subnetwork_specs:
        if not self._placement_strategy.should_train_subnetworks(
            num_subnetworks) and not rebuilding:
          continue
        if not subnetwork_spec.train_op:
          continue
        training_chief_hooks += subnetwork_spec.train_op.chief_hooks or ()
        training_hooks += subnetwork_spec.train_op.hooks or ()
      for candidate in candidates:
        spec = candidate.ensemble_spec
        if not spec.train_op:
          continue
        training_chief_hooks += spec.train_op.chief_hooks or ()
        training_hooks += spec.train_op.hooks or ()
      summary = self._summary_maker(
          namespace=None, scope=None, skip_summary=skip_summaries or rebuilding)
      summaries.append(summary)
      with summary.current_scope():
        summary.scalar("iteration/adanet/iteration", iteration_number)
        summary.scalar("iteration_step/adanet/iteration_step",
                       iteration_step_tensor)
        if best_loss is not None:
          summary.scalar("loss", best_loss)
      train_op = self._create_train_op(subnetwork_specs, candidates, mode,
                                       iteration_step, is_over_var_template,
                                       num_subnetworks)
      iteration_metrics = _IterationMetrics(candidates, subnetwork_specs)
      if self._use_tpu:
        estimator_spec = tf.contrib.tpu.TPUEstimatorSpec(
            mode=mode,
            predictions=best_predictions,
            loss=best_loss,
            train_op=train_op,
            eval_metrics=iteration_metrics.best_eval_metrics_tuple(
                best_candidate_index, mode),
            export_outputs=best_export_outputs,
            training_hooks=training_hooks)
      else:
        estimator_spec = tf.estimator.EstimatorSpec(
            mode=mode,
            predictions=best_predictions,
            loss=best_loss,
            train_op=train_op,
            eval_metric_ops=iteration_metrics.best_eval_metric_ops(
                best_candidate_index, mode),
            export_outputs=best_export_outputs,
            training_chief_hooks=training_chief_hooks,
            training_hooks=training_hooks)

      return _Iteration(
          number=iteration_number,
          candidates=candidates,
          subnetwork_specs=subnetwork_specs,
          estimator_spec=estimator_spec,
          best_candidate_index=best_candidate_index,
          summaries=summaries,
          is_over_fn=is_over_var_template,
          subnetwork_reports=subnetwork_reports,
          step=iteration_step_tensor)
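
Unlike Example #1, this older variant keeps a per-iteration step counter instead of relying only on the global step: a non-trainable int64 variable is created inside the iteration's variable scope and then converted to a tensor so that consumers cannot mutate it. Below is a minimal TF1-style sketch of just that pattern; the scope and variable names are illustrative, not the library's.

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

with tf.variable_scope("iteration_0"):
  # Non-trainable counter scoped to this iteration.
  iteration_step = tf.get_variable(
      "step",
      shape=[],
      initializer=tf.zeros_initializer(),
      trainable=False,
      dtype=tf.int64)

# Converting to a tensor exposes a read-only value, so downstream code cannot
# call `.assign()` on it and mutate the counter.
iteration_step_tensor = tf.convert_to_tensor(iteration_step)

increment_op = tf.assign_add(iteration_step, tf.constant(1, dtype=tf.int64))
with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  sess.run(increment_op)
  print(sess.run(iteration_step_tensor))  # -> 1
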
Example #3
class ReportMaterializerTest(parameterized.TestCase, tf.test.TestCase):

  # pylint: disable=g-long-lambda
  @parameterized.named_parameters(
      {
          "testcase_name":
              "one_empty_subnetwork",
          "input_fn":
              tu.dummy_input_fn([[1., 2]], [[3.]]),
          "subnetwork_reports_fn":
              lambda features, labels: {
                  "foo":
                      subnetwork.Report(hparams={}, attributes={}, metrics={}),
              },
          "steps":
              3,
          "included_subnetwork_names": ["foo"],
          "want_materialized_reports": [
              subnetwork.MaterializedReport(
                  iteration_number=0,
                  name="foo",
                  hparams={},
                  attributes={},
                  metrics={},
                  included_in_final_ensemble=True,
              ),
          ],
      }, {
          "testcase_name":
              "one_subnetwork",
          "input_fn":
              tu.dummy_input_fn([[1., 2]], [[3.]]),
          "subnetwork_reports_fn":
              lambda features, labels: {
                  "foo":
                      subnetwork.Report(
                          hparams={
                              "learning_rate": 1.e-5,
                              "optimizer": "sgd",
                              "num_layers": 0,
                              "use_side_inputs": True,
                          },
                          attributes={
                              "weight_norms": tf.constant(3.14),
                              "foo": tf.constant("bar"),
                              "parameters": tf.constant(7777),
                              "boo": tf.constant(True),
                          },
                          metrics={},
                      ),
              },
          "steps":
              3,
          "included_subnetwork_names": ["foo"],
          "want_materialized_reports": [
              subnetwork.MaterializedReport(
                  iteration_number=0,
                  name="foo",
                  hparams={
                      "learning_rate": 1.e-5,
                      "optimizer": "sgd",
                      "num_layers": 0,
                      "use_side_inputs": True,
                  },
                  attributes={
                      "weight_norms": 3.14,
                      "foo": "bar",
                      "parameters": 7777,
                      "boo": True,
                  },
                  metrics={},
                  included_in_final_ensemble=True,
              ),
          ],
      }, {
          "testcase_name":
              "one_subnetwork_iteration_2",
          "input_fn":
              tu.dummy_input_fn([[1., 2]], [[3.]]),
          "subnetwork_reports_fn":
              lambda features, labels: {
                  "foo":
                      subnetwork.Report(
                          hparams={
                              "learning_rate": 1.e-5,
                              "optimizer": "sgd",
                              "num_layers": 0,
                              "use_side_inputs": True,
                          },
                          attributes={
                              "weight_norms": tf.constant(3.14),
                              "foo": tf.constant("bar"),
                              "parameters": tf.constant(7777),
                              "boo": tf.constant(True),
                          },
                          metrics={},
                      ),
              },
          "steps":
              3,
          "iteration_number":
              2,
          "included_subnetwork_names": ["foo"],
          "want_materialized_reports": [
              subnetwork.MaterializedReport(
                  iteration_number=2,
                  name="foo",
                  hparams={
                      "learning_rate": 1.e-5,
                      "optimizer": "sgd",
                      "num_layers": 0,
                      "use_side_inputs": True,
                  },
                  attributes={
                      "weight_norms": 3.14,
                      "foo": "bar",
                      "parameters": 7777,
                      "boo": True,
                  },
                  metrics={},
                  included_in_final_ensemble=True,
              ),
          ],
      }, {
          "testcase_name":
              "two_subnetworks",
          "input_fn":
              tu.dummy_input_fn([[1., 2]], [[3.]]),
          "subnetwork_reports_fn":
              lambda features, labels: {
                  "foo1":
                      subnetwork.Report(
                          hparams={
                              "learning_rate": 1.e-5,
                              "optimizer": "sgd",
                              "num_layers": 0,
                              "use_side_inputs": True,
                          },
                          attributes={
                              "weight_norms": tf.constant(3.14),
                              "foo": tf.constant("bar"),
                              "parameters": tf.constant(7777),
                              "boo": tf.constant(True),
                          },
                          metrics={},
                      ),
                  "foo2":
                      subnetwork.Report(
                          hparams={
                              "learning_rate": 1.e-6,
                              "optimizer": "sgd",
                              "num_layers": 1,
                              "use_side_inputs": True,
                          },
                          attributes={
                              "weight_norms": tf.constant(3.1445),
                              "foo": tf.constant("baz"),
                              "parameters": tf.constant(7788),
                              "boo": tf.constant(True),
                          },
                          metrics={},
                      ),
              },
          "steps":
              3,
          "included_subnetwork_names": ["foo2"],
          "want_materialized_reports": [
              subnetwork.MaterializedReport(
                  iteration_number=0,
                  name="foo1",
                  hparams={
                      "learning_rate": 1.e-5,
                      "optimizer": "sgd",
                      "num_layers": 0,
                      "use_side_inputs": True,
                  },
                  attributes={
                      "weight_norms": 3.14,
                      "foo": "bar",
                      "parameters": 7777,
                      "boo": True,
                  },
                  metrics={},
                  included_in_final_ensemble=False,
              ),
              subnetwork.MaterializedReport(
                  iteration_number=0,
                  name="foo2",
                  hparams={
                      "learning_rate": 1.e-6,
                      "optimizer": "sgd",
                      "num_layers": 1,
                      "use_side_inputs": True,
                  },
                  attributes={
                      "weight_norms": 3.1445,
                      "foo": "baz",
                      "parameters": 7788,
                      "boo": True,
                  },
                  metrics={},
                  included_in_final_ensemble=True,
              ),
          ],
      }, {
          "testcase_name":
              "two_subnetworks_zero_included",
          "input_fn":
              tu.dummy_input_fn([[1., 2]], [[3.]]),
          "subnetwork_reports_fn":
              lambda features, labels: {
                  "foo1":
                      subnetwork.Report(
                          hparams={},
                          attributes={},
                          metrics={},
                      ),
                  "foo2":
                      subnetwork.Report(
                          hparams={},
                          attributes={},
                          metrics={},
                      ),
              },
          "steps":
              3,
          "included_subnetwork_names": [],
          "want_materialized_reports": [
              subnetwork.MaterializedReport(
                  iteration_number=0,
                  name="foo1",
                  hparams={},
                  attributes={},
                  metrics={},
                  included_in_final_ensemble=False,
              ),
              subnetwork.MaterializedReport(
                  iteration_number=0,
                  name="foo2",
                  hparams={},
                  attributes={},
                  metrics={},
                  included_in_final_ensemble=False,
              ),
          ],
      }, {
          "testcase_name":
              "two_subnetworks_both_included",
          "input_fn":
              tu.dummy_input_fn([[1., 2]], [[3.]]),
          "subnetwork_reports_fn":
              lambda features, labels: {
                  "foo1":
                      subnetwork.Report(
                          hparams={},
                          attributes={},
                          metrics={},
                      ),
                  "foo2":
                      subnetwork.Report(
                          hparams={},
                          attributes={},
                          metrics={},
                      ),
              },
          "steps":
              3,
          "included_subnetwork_names": ["foo1", "foo2"],
          "want_materialized_reports": [
              subnetwork.MaterializedReport(
                  iteration_number=0,
                  name="foo1",
                  hparams={},
                  attributes={},
                  metrics={},
                  included_in_final_ensemble=True,
              ),
              subnetwork.MaterializedReport(
                  iteration_number=0,
                  name="foo2",
                  hparams={},
                  attributes={},
                  metrics={},
                  included_in_final_ensemble=True,
              ),
          ],
      }, {
          "testcase_name":
              "materialize_metrics",
          "input_fn":
              tu.dummy_input_fn([[1., 1.], [1., 1.], [1., 1.]],
                                [[1.], [2.], [3.]]),
          "subnetwork_reports_fn":
              lambda features, labels: {
                  "foo":
                      subnetwork.Report(
                          hparams={},
                          attributes={},
                          metrics={"moo": tf_compat.v1.metrics.mean(labels)},
                      ),
              },
          "steps":
              3,
          "included_subnetwork_names": ["foo"],
          "want_materialized_reports": [
              subnetwork.MaterializedReport(
                  iteration_number=0,
                  name="foo",
                  hparams={},
                  attributes={},
                  metrics={"moo": 2.},
                  included_in_final_ensemble=True,
              ),
          ],
      }, {
          "testcase_name":
              "materialize_metrics_none_steps",
          "input_fn":
              tu.dataset_input_fn([[1., 1.], [1., 1.], [1., 1.]],
                                  [[1.], [2.], [3.]]),
          "subnetwork_reports_fn":
              lambda features, labels: {
                  "foo":
                      subnetwork.Report(
                          hparams={},
                          attributes={},
                          metrics={"moo": tf_compat.v1.metrics.mean(labels)},
                      ),
              },
          "steps":
              None,
          "included_subnetwork_names": ["foo"],
          "want_materialized_reports": [
              subnetwork.MaterializedReport(
                  iteration_number=0,
                  name="foo",
                  hparams={},
                  attributes={},
                  metrics={"moo": 2.},
                  included_in_final_ensemble=True,
              ),
          ],
      }, {
          "testcase_name":
              "materialize_metrics_non_tensor_op",
          "input_fn":
              tu.dummy_input_fn([[1., 2]], [[3.]]),
          "subnetwork_reports_fn":
              lambda features, labels: {
                  "foo":
                      subnetwork.Report(
                          hparams={},
                          attributes={},
                          metrics={"moo": (tf.constant(42), tf.no_op())},
                      ),
              },
          "steps":
              3,
          "included_subnetwork_names": ["foo"],
          "want_materialized_reports": [
              subnetwork.MaterializedReport(
                  iteration_number=0,
                  name="foo",
                  hparams={},
                  attributes={},
                  metrics={"moo": 42},
                  included_in_final_ensemble=True,
              ),
          ],
      })
  @test_util.run_in_graph_and_eager_modes
  def test_materialize_subnetwork_reports(self,
                                          input_fn,
                                          subnetwork_reports_fn,
                                          steps,
                                          iteration_number=0,
                                          included_subnetwork_names=None,
                                          want_materialized_reports=None):
    with context.graph_mode():
      tf.constant(0.)  # dummy op so that the session graph is never empty.
      features, labels = input_fn()
      subnetwork_reports = subnetwork_reports_fn(features, labels)
      with self.test_session() as sess:
        sess.run(tf_compat.v1.initializers.local_variables())
        report_materializer = ReportMaterializer(input_fn=input_fn, steps=steps)
        materialized_reports = (
            report_materializer.materialize_subnetwork_reports(
                sess, iteration_number, subnetwork_reports,
                included_subnetwork_names))
        self.assertEqual(
            len(want_materialized_reports), len(materialized_reports))
        materialized_reports_dict = {
            blrm.name: blrm for blrm in materialized_reports
        }
        for want_materialized_report in want_materialized_reports:
          materialized_report = (
              materialized_reports_dict[want_materialized_report.name])
          self.assertEqual(iteration_number,
                           materialized_report.iteration_number)
          self.assertEqual(
              set(want_materialized_report.hparams.keys()),
              set(materialized_report.hparams.keys()))
          for hparam_key, want_hparam in (
              want_materialized_report.hparams.items()):
            if isinstance(want_hparam, float):
              self.assertAllClose(want_hparam,
                                  materialized_report.hparams[hparam_key])
            else:
              self.assertEqual(want_hparam,
                               materialized_report.hparams[hparam_key])

          self.assertSetEqual(
              set(want_materialized_report.attributes.keys()),
              set(materialized_report.attributes.keys()))
          for attribute_key, want_attribute in (
              want_materialized_report.attributes.items()):
            if isinstance(want_attribute, float):
              self.assertAllClose(
                  want_attribute,
                  decode(materialized_report.attributes[attribute_key]))
            else:
              self.assertEqual(
                  want_attribute,
                  decode(materialized_report.attributes[attribute_key]))

          self.assertSetEqual(
              set(want_materialized_report.metrics.keys()),
              set(materialized_report.metrics.keys()))
          for metric_key, want_metric in (
              want_materialized_report.metrics.items()):
            if isinstance(want_metric, float):
              self.assertAllClose(
                  want_metric, decode(materialized_report.metrics[metric_key]))
            else:
              self.assertEqual(want_metric,
                               decode(materialized_report.metrics[metric_key]))
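
The "materialize_metrics" cases above expect `{"moo": 2.}` because `tf_compat.v1.metrics.mean` returns a `(value_op, update_op)` pair, and the materializer runs the update op once per step before reading the value, so the labels `[[1.], [2.], [3.]]` accumulate to a mean of 2.0. The sketch below illustrates only that metric behaviour, not `ReportMaterializer` itself; the placeholder and batching are assumptions made for the example.

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

labels = tf.placeholder(tf.float32, shape=[1, 1])
mean_value, update_op = tf.metrics.mean(labels)

with tf.Session() as sess:
  # Streaming-metric state lives in local variables, hence this initializer
  # (the test above does the same before materializing reports).
  sess.run(tf.local_variables_initializer())
  for batch in ([[1.]], [[2.]], [[3.]]):
    sess.run(update_op, feed_dict={labels: batch})
  print(sess.run(mean_value))  # -> 2.0
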