示例#1
0
    def test_loss_exceptions(self):
        """compile() rejects a loss configured with from_logits=False."""
        keras_model = model.Model(
            subnetwork_generator=SimpleGenerator([_DNNBuilder("dnn")]),
            max_iteration_steps=1)

        probability_loss = tf.keras.losses.BinaryCrossentropy(
            from_logits=False)

        with self.assertRaises(ValueError):
            keras_model.compile(loss=probability_loss)
示例#2
0
  def test_tpu_estimator_simple_lifecycle(self, use_tpu, want_loss):
    """Trains, evaluates, predicts with, and exports a TPUEstimator."""
    run_config = tf.contrib.tpu.RunConfig(master="", tf_random_seed=42)
    estimator = TPUEstimator(
        head=tu.head(),
        subnetwork_generator=SimpleGenerator(
            [_DNNBuilder("dnn", use_tpu=use_tpu)]),
        max_iteration_steps=100,
        model_dir=self.test_subdirectory,
        config=run_config,
        use_tpu=use_tpu,
        train_batch_size=64 if use_tpu else 0)
    max_steps = 300

    train_input_fn = tu.dummy_input_fn(
        [[1., 0.], [0., 0], [0., 1.], [1., 1.]],  # XOR features.
        [[1.], [0.], [1.], [0.]])  # XOR labels.

    # Train, then evaluate a single batch.
    estimator.train(
        input_fn=train_input_fn, steps=None, max_steps=max_steps, hooks=None)
    eval_results = estimator.evaluate(
        input_fn=train_input_fn, steps=1, hooks=None)

    # TODO: skip predictions on TF versions 1.11 and 1.12 since
    # some TPU hooks seem to be failing on predict.
    # The predictions are only iterated in the assertion loop at the end.
    skipped_versions = (LooseVersion("1.11.0"), LooseVersion("1.12.0"))
    if LooseVersion(tf.VERSION) in skipped_versions:
      predictions = []
    else:
      predictions = estimator.predict(
          input_fn=tu.dataset_input_fn(features=[0., 0.], labels=None))

    def serving_input_fn():
      """Serving input fn: a string placeholder plus constant features."""
      serialized_example = tf.placeholder(
          dtype=tf.string, shape=None, name="serialized_example")
      return tf.estimator.export.ServingInputReceiver(
          features={"x": tf.constant([[0., 0.]], name="serving_x")},
          receiver_tensors=serialized_example)

    # Prefer `export_saved_model`; fall back to the legacy `export_savedmodel`
    # spelling when the former is not available on this estimator.
    exporter = getattr(estimator, "export_saved_model", None)
    if not callable(exporter):
      exporter = estimator.export_savedmodel
    exporter(
        export_dir_base=estimator.model_dir,
        serving_input_receiver_fn=serving_input_fn)

    self.assertAlmostEqual(want_loss, eval_results["loss"], places=2)
    self.assertEqual(max_steps, eval_results["global_step"])
    for prediction in predictions:
      self.assertIsNotNone(prediction["predictions"])
示例#3
0
    def test_compile_exceptions(self):
        """fit, evaluate and predict raise RuntimeError on an uncompiled model."""
        keras_model = model.Model(
            subnetwork_generator=SimpleGenerator([_DNNBuilder("dnn")]),
            max_iteration_steps=1)
        train_data = tf.data.Dataset.from_tensors(([[1., 1.]], [[1.]]))
        predict_data = tf.data.Dataset.from_tensors([[1., 1.]])

        # compile() is never called, so every entry point must refuse to run.
        for method_name, data in (("fit", train_data),
                                  ("evaluate", train_data),
                                  ("predict", predict_data)):
            with self.assertRaises(RuntimeError):
                getattr(keras_model, method_name)(data)
示例#4
0
def train_and_evaluate_estimator():
  """Runs Estimator distributed training.

  Builds the estimator selected by `FLAGS.estimator_type`, then trains and
  evaluates it with `tf.estimator.train_and_evaluate` on a tiny constant
  dataset. For the "autoensemble_trees_multiclass" type the function logs a
  warning and returns early when `bt_losses` is unavailable.

  Raises:
    ValueError: If `FLAGS.estimator_type` is not a supported estimator type.
  """

  # The tf.estimator.RunConfig automatically parses the TF_CONFIG environment
  # variables during construction.
  # For more information on how tf.estimator.RunConfig uses TF_CONFIG, see
  # https://www.tensorflow.org/api_docs/python/tf/estimator/RunConfig.
  config = tf.estimator.RunConfig(
      tf_random_seed=42,
      save_checkpoints_steps=10,
      save_checkpoints_secs=None,
      # Keep all checkpoints to avoid checkpoint GC causing failures during
      # evaluation.
      # TODO: Prevent checkpoints that are currently being
      # evaluated by another process from being garbage collected.
      keep_checkpoint_max=None,
      model_dir=FLAGS.model_dir,
      session_config=tf_compat.v1.ConfigProto(
          log_device_placement=False,
          # Ignore other workers; only talk to parameter servers.
          # Otherwise, when a chief/worker terminates, the others will hang.
          device_filters=["/job:ps"]))

  # Toy XOR data read by `input_fn`. `labels` is rebound in the multiclass
  # branch below; `input_fn` reads these names when it is called, so it sees
  # the final values.
  features = [[1., 0.], [0., 0], [0., 1.], [1., 1.]]
  labels = [[1.], [0.], [1.], [0.]]

  def input_fn():
    """Returns a repeating dataset built from `features` and `labels`."""
    input_features = {"x": tf.constant(features, name="x")}
    input_labels = tf.constant(labels, name="y")
    return tf.data.Dataset.from_tensors((input_features, input_labels)).repeat()

  kwargs = {
      "max_iteration_steps": 100,
      "force_grow": True,
      "delay_secs_per_worker": .2,
      "max_worker_delay_secs": 1,
      "worker_wait_secs": 1,
      # Set low timeout to reduce wait time for failures.
      "worker_wait_timeout_secs": 180,
      "evaluator": Evaluator(input_fn, steps=10),
      "config": config
  }

  head = head_lib._regression_head(  # pylint: disable=protected-access
      loss_reduction=tf.losses.Reduction.SUM_OVER_BATCH_SIZE)

  estimator_type = FLAGS.estimator_type
  if FLAGS.placement_strategy == "round_robin":
    kwargs["experimental_placement_strategy"] = RoundRobinStrategy()
  if estimator_type == "autoensemble":
    feature_columns = [tf.feature_column.numeric_column("x", shape=[2])]
    # pylint: disable=g-long-lambda
    # TODO: Switch optimizers to tf.keras.optimizers.Adam once the
    # distribution bug is fixed.
    candidate_pool = {
        "linear":
            tf.estimator.LinearEstimator(
                head=head,
                feature_columns=feature_columns,
                optimizer=lambda: tf_compat.v1.train.AdamOptimizer(
                    learning_rate=.001)),
        "dnn":
            tf.estimator.DNNEstimator(
                head=head,
                feature_columns=feature_columns,
                optimizer=lambda: tf_compat.v1.train.AdamOptimizer(
                    learning_rate=.001),
                hidden_units=[3]),
        "dnn2":
            tf.estimator.DNNEstimator(
                head=head,
                feature_columns=feature_columns,
                optimizer=lambda: tf_compat.v1.train.AdamOptimizer(
                    learning_rate=.001),
                hidden_units=[10, 10]),
    }
    # pylint: enable=g-long-lambda

    estimator = AutoEnsembleEstimator(
        head=head, candidate_pool=candidate_pool, **kwargs)
  elif estimator_type == "estimator":
    subnetwork_generator = SimpleGenerator([
        _DNNBuilder("dnn1", config, layer_size=3),
        _DNNBuilder("dnn2", config, layer_size=4),
        _DNNBuilder("dnn3", config, layer_size=5),
    ])

    estimator = Estimator(
        head=head, subnetwork_generator=subnetwork_generator, **kwargs)
  elif estimator_type == "autoensemble_trees_multiclass":
    if not bt_losses:
      logging.warning(
          "Skipped autoensemble_trees_multiclass test since contrib is missing."
      )
      return
    n_classes = 3
    head = head_lib._multi_class_head_with_softmax_cross_entropy_loss(  # pylint: disable=protected-access
        n_classes=n_classes,
        loss_reduction=tf.losses.Reduction.SUM_OVER_BATCH_SIZE)

    def tree_loss_fn(labels, logits):
      result = bt_losses.per_example_maxent_loss(
          labels=labels, logits=logits, num_classes=n_classes, weights=None)
      return result[0]

    tree_head = head_lib._multi_class_head_with_softmax_cross_entropy_loss(  # pylint: disable=protected-access
        loss_fn=tree_loss_fn,
        n_classes=n_classes,
        loss_reduction=tf.losses.Reduction.SUM_OVER_BATCH_SIZE)
    # Integer class ids replace the regression labels used by `input_fn`.
    labels = [[1], [0], [1], [2]]
    feature_columns = [tf.feature_column.numeric_column("x", shape=[2])]
    # TODO: Switch optimizers to tf.keras.optimizers.Adam once the
    # distribution bug is fixed.
    candidate_pool = lambda config: {  # pylint: disable=g-long-lambda
        "linear":
            tf.estimator.LinearEstimator(
                head=head,
                feature_columns=feature_columns,
                optimizer=tf_compat.v1.train.AdamOptimizer(
                    learning_rate=.001),
                config=config),
        "gbdt":
            tf.estimator.BoostedTreesEstimator(
                head=tree_head,
                feature_columns=feature_columns,
                n_trees=10,
                n_batches_per_layer=1,
                center_bias=False,
                config=config),
    }

    estimator = AutoEnsembleEstimator(
        head=head, candidate_pool=candidate_pool, **kwargs)

  elif estimator_type == "estimator_with_experimental_multiworker_strategy":

    def _model_fn(features, labels, mode):
      """Test model_fn."""
      layer = tf.keras.layers.Dense(1)
      logits = layer(features["x"])

      if mode == tf.estimator.ModeKeys.PREDICT:
        predictions = {"logits": logits}
        return tf.estimator.EstimatorSpec(mode, predictions=predictions)

      loss = tf.losses.mean_squared_error(
          labels=labels,
          predictions=logits,
          reduction=tf.losses.Reduction.SUM_OVER_BATCH_SIZE)

      if mode == tf.estimator.ModeKeys.EVAL:
        return tf.estimator.EstimatorSpec(mode, loss=loss)

      if mode == tf.estimator.ModeKeys.TRAIN:
        optimizer = tf.train.GradientDescentOptimizer(0.2)
        train_op = optimizer.minimize(
            loss, global_step=tf.train.get_global_step())
        return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)

    if json.loads(os.environ["TF_CONFIG"])["task"]["type"] == "evaluator":
      # The evaluator job would crash if MultiWorkerMirroredStrategy is called.
      distribution = None
    else:
      distribution = tf.distribute.experimental.MultiWorkerMirroredStrategy()

    multiworker_config = tf.estimator.RunConfig(
        tf_random_seed=42,
        model_dir=FLAGS.model_dir,
        train_distribute=distribution,
        session_config=tf_compat.v1.ConfigProto(log_device_placement=False))
    # TODO: Replace with adanet.Estimator. Currently this just verifies
    # that the distributed testing framework supports distribute strategies.
    estimator = tf.estimator.Estimator(
        model_fn=_model_fn, config=multiworker_config)
  else:
    # Without this, `estimator` would be unbound and the failure would surface
    # later as a confusing UnboundLocalError in train_and_evaluate.
    raise ValueError(
        "Unsupported estimator_type: {!r}".format(estimator_type))

  train_hooks = [
      tf.estimator.ProfilerHook(save_steps=50, output_dir=FLAGS.model_dir)
  ]
  # Train for three iterations.
  train_spec = tf.estimator.TrainSpec(
      input_fn=input_fn, max_steps=300, hooks=train_hooks)
  eval_spec = tf.estimator.EvalSpec(
      input_fn=input_fn, steps=1, start_delay_secs=.5, throttle_secs=.05)

  # Calling train_and_evaluate is the official way to perform distributed
  # training with an Estimator. Calling Estimator#train directly results
  # in an error when the TF_CONFIG is setup for a cluster.
  tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
示例#5
0
    def test_tpu_estimator_summaries(self, use_tpu, want_loss,
                                     want_adanet_loss, want_eval_summary_loss,
                                     want_predictions):
        """Checks the summaries a TPUEstimator writes during train and eval."""
        run_config = tf.contrib.tpu.RunConfig(tf_random_seed=42,
                                              save_summary_steps=100,
                                              log_step_count_steps=100)
        assert run_config.log_step_count_steps

        def metric_fn(predictions):
            mean_prediction = tf_compat.v1.metrics.mean(
                predictions["predictions"])
            return {"predictions": mean_prediction}

        max_steps = 100
        estimator = TPUEstimator(
            head=tu.head(),
            subnetwork_generator=SimpleGenerator(
                [_DNNBuilder("dnn", use_tpu=use_tpu)]),
            max_iteration_steps=max_steps,
            model_dir=self.test_subdirectory,
            metric_fn=metric_fn,
            config=run_config,
            use_tpu=use_tpu,
            train_batch_size=64 if use_tpu else 0)
        train_input_fn = tu.dummy_input_fn(
            [[1., 0.], [0., 0], [0., 1.], [1., 1.]],  # XOR features.
            [[1.], [0.], [1.], [0.]])  # XOR labels.

        estimator.train(input_fn=train_input_fn, max_steps=max_steps)
        eval_results = estimator.evaluate(input_fn=train_input_fn, steps=1)
        self.assertAlmostEqual(want_loss, eval_results["loss"], places=2)
        self.assertEqual(max_steps, eval_results["global_step"])
        self.assertEqual(0, eval_results["iteration"])

        model_dir = self.test_subdirectory
        subnetwork_dir = os.path.join(model_dir, "subnetwork/t0_dnn")
        ensemble_dir = os.path.join(
            model_dir, "ensemble/t0_dnn_grow_complexity_regularized")
        subnetwork_eval_dir = os.path.join(subnetwork_dir, "eval")
        weighted_ensemble = "adanet/adanet_weighted_ensemble"

        # (expected value, summary tag, directory, places). places=None means
        # an exact assertEqual; otherwise assertAlmostEqual with those places.
        expectations = [
            # TODO: Why is the adanet_loss written to 'loss'?
            (want_adanet_loss, "loss", model_dir, 1),
            (0., "iteration/adanet/iteration", model_dir, None),
            (3., "scalar", subnetwork_dir, 3),
            ((3, 3, 1), "image/image/0", subnetwork_dir, None),
            (5., "nested/scalar", subnetwork_dir, 3),
            (want_adanet_loss, "adanet_loss/" + weighted_ensemble,
             ensemble_dir, 1),
            (0., "complexity_regularization/" + weighted_ensemble,
             ensemble_dir, 1),
            (1., "mixture_weight_norms/" + weighted_ensemble + "/subnetwork_0",
             ensemble_dir, 1),
            # Eval metric summaries are always written out during eval.
            (want_eval_summary_loss, "loss", subnetwork_eval_dir, 1),
            (want_eval_summary_loss, "average_loss", subnetwork_eval_dir, 1),
            (want_predictions, "predictions", subnetwork_eval_dir, 3),
        ]
        for expected, tag, directory, places in expectations:
            actual = _get_summary_value(tag, directory)
            if places is None:
                self.assertEqual(expected, actual)
            else:
                self.assertAlmostEqual(expected, actual, places=places)

        eval_dir = os.path.join(model_dir, "eval")
        for subdir in (os.path.join(ensemble_dir, "eval"), eval_dir):
            self.assertEqual(
                [b"| dnn |"],
                _get_summary_value("architecture/adanet/ensembles/0", subdir))
            if subdir == eval_dir:
                self.assertAlmostEqual(want_loss,
                                       _get_summary_value("loss", subdir),
                                       places=1)
            self.assertAlmostEqual(want_eval_summary_loss,
                                   _get_summary_value("average_loss", subdir),
                                   places=1)
示例#6
0
class TPUEstimatorTest(tu.AdanetTestCase):
    """Lifecycle and summary tests for the AdaNet TPUEstimator."""

    def setUp(self):
        super(TPUEstimatorTest, self).setUp()

        if not tf_compat.version_greater_or_equal("1.14.0"):
            self.skipTest(
                "TPUEmbedding not supported in version 1.13.0 and below.")

        # TPUConfig initializes model_dir from TF_CONFIG and checks that the
        # user-provided model_dir matches it, so publish the test directory.
        os.environ["TF_CONFIG"] = json.dumps(
            {"model_dir": self.test_subdirectory})

    @parameterized.named_parameters({
        "testcase_name": "not_use_tpu",
        "use_tpu": False,
        "subnetwork_generator": SimpleGenerator(
            [_DNNBuilder("dnn", use_tpu=False)]),
        "want_loss": .27357,
    })
    def test_tpu_estimator_simple_lifecycle(self, use_tpu,
                                            subnetwork_generator, want_loss):
        """Trains, evaluates, predicts with, and exports a TPUEstimator."""
        run_config = tf.contrib.tpu.RunConfig(master="", tf_random_seed=42)
        estimator = TPUEstimator(
            head=tu.head(),
            subnetwork_generator=subnetwork_generator,
            max_iteration_steps=100,
            model_dir=self.test_subdirectory,
            config=run_config,
            use_tpu=use_tpu,
            train_batch_size=64 if use_tpu else 0)
        max_steps = 300

        train_input_fn = tu.dummy_input_fn(
            [[1., 0.], [0., 0], [0., 1.], [1., 1.]],  # XOR features.
            [[1.], [0.], [1.], [0.]])  # XOR labels.

        estimator.train(input_fn=train_input_fn, steps=None,
                        max_steps=max_steps, hooks=None)
        eval_results = estimator.evaluate(input_fn=train_input_fn, steps=1,
                                          hooks=None)
        # Iterated only in the assertion loop at the end of the test.
        predictions = estimator.predict(
            input_fn=tu.dataset_input_fn(features=[0., 0.], labels=None))

        def serving_input_fn():
            """Serving input fn: a string placeholder plus constant features."""
            serialized_example = tf.placeholder(dtype=tf.string,
                                                shape=None,
                                                name="serialized_example")
            return tf.estimator.export.ServingInputReceiver(
                features={"x": tf.constant([[0., 0.]], name="serving_x")},
                receiver_tensors=serialized_example)

        # Prefer `export_saved_model`; fall back to the legacy
        # `export_savedmodel` spelling when the former is unavailable.
        exporter = getattr(estimator, "export_saved_model", None)
        if not callable(exporter):
            exporter = estimator.export_savedmodel
        exporter(export_dir_base=estimator.model_dir,
                 serving_input_receiver_fn=serving_input_fn)

        self.assertAlmostEqual(want_loss, eval_results["loss"], places=2)
        self.assertEqual(max_steps, eval_results["global_step"])
        self.assertEqual(2, eval_results["iteration"])
        for prediction in predictions:
            self.assertIsNotNone(prediction["predictions"])

    @parameterized.named_parameters({
        "testcase_name": "not_use_tpu",
        "use_tpu": False,
        "want_loss": .31239,
        "want_adanet_loss": .64416,
        "want_eval_summary_loss": .31239,
        "want_predictions": .45473,
    })
    def test_tpu_estimator_summaries(self, use_tpu, want_loss,
                                     want_adanet_loss, want_eval_summary_loss,
                                     want_predictions):
        """Checks the summaries a TPUEstimator writes during train and eval."""
        run_config = tf.contrib.tpu.RunConfig(tf_random_seed=42,
                                              save_summary_steps=100,
                                              log_step_count_steps=100)
        assert run_config.log_step_count_steps

        def metric_fn(predictions):
            mean_prediction = tf_compat.v1.metrics.mean(
                predictions["predictions"])
            return {"predictions": mean_prediction}

        max_steps = 100
        estimator = TPUEstimator(
            head=tu.head(),
            subnetwork_generator=SimpleGenerator(
                [_DNNBuilder("dnn", use_tpu=use_tpu)]),
            max_iteration_steps=max_steps,
            model_dir=self.test_subdirectory,
            metric_fn=metric_fn,
            config=run_config,
            use_tpu=use_tpu,
            train_batch_size=64 if use_tpu else 0)
        train_input_fn = tu.dummy_input_fn(
            [[1., 0.], [0., 0], [0., 1.], [1., 1.]],  # XOR features.
            [[1.], [0.], [1.], [0.]])  # XOR labels.

        estimator.train(input_fn=train_input_fn, max_steps=max_steps)
        eval_results = estimator.evaluate(input_fn=train_input_fn, steps=1)
        self.assertAlmostEqual(want_loss, eval_results["loss"], places=2)
        self.assertEqual(max_steps, eval_results["global_step"])
        self.assertEqual(0, eval_results["iteration"])

        model_dir = self.test_subdirectory
        subnetwork_dir = os.path.join(model_dir, "subnetwork/t0_dnn")
        ensemble_dir = os.path.join(
            model_dir, "ensemble/t0_dnn_grow_complexity_regularized")
        subnetwork_eval_dir = os.path.join(subnetwork_dir, "eval")
        weighted_ensemble = "adanet/adanet_weighted_ensemble"

        # (expected value, summary tag, directory, places). places=None means
        # an exact assertEqual; otherwise assertAlmostEqual with those places.
        expectations = [
            # TODO: Why is the adanet_loss written to 'loss'?
            (want_adanet_loss, "loss", model_dir, 1),
            (0., "iteration/adanet/iteration", model_dir, None),
            (3., "scalar", subnetwork_dir, 3),
            ((3, 3, 1), "image/image/0", subnetwork_dir, None),
            (5., "nested/scalar", subnetwork_dir, 3),
            (want_adanet_loss, "adanet_loss/" + weighted_ensemble,
             ensemble_dir, 1),
            (0., "complexity_regularization/" + weighted_ensemble,
             ensemble_dir, 1),
            (1., "mixture_weight_norms/" + weighted_ensemble + "/subnetwork_0",
             ensemble_dir, 1),
            # Eval metric summaries are always written out during eval.
            (want_eval_summary_loss, "loss", subnetwork_eval_dir, 1),
            (want_eval_summary_loss, "average_loss", subnetwork_eval_dir, 1),
            (want_predictions, "predictions", subnetwork_eval_dir, 3),
        ]
        for expected, tag, directory, places in expectations:
            actual = _get_summary_value(tag, directory)
            if places is None:
                self.assertEqual(expected, actual)
            else:
                self.assertAlmostEqual(expected, actual, places=places)

        eval_dir = os.path.join(model_dir, "eval")
        for subdir in (os.path.join(ensemble_dir, "eval"), eval_dir):
            self.assertEqual(
                [b"| dnn |"],
                _get_summary_value("architecture/adanet/ensembles/0", subdir))
            if subdir == eval_dir:
                self.assertAlmostEqual(want_loss,
                                       _get_summary_value("loss", subdir),
                                       places=1)
            self.assertAlmostEqual(want_eval_summary_loss,
                                   _get_summary_value("average_loss", subdir),
                                   places=1)
def train_and_evaluate_estimator():
    """Runs Estimator distributed training.

    Builds the estimator selected by `FLAGS.estimator_type` ("autoensemble" or
    "estimator"), then trains and evaluates it with
    `tf.estimator.train_and_evaluate` on a constant XOR dataset.

    Raises:
        ValueError: If `FLAGS.estimator_type` is not a supported type.
    """

    # The tf.estimator.RunConfig automatically parses the TF_CONFIG environment
    # variables during construction.
    # For more information on how tf.estimator.RunConfig uses TF_CONFIG, see
    # https://www.tensorflow.org/api_docs/python/tf/estimator/RunConfig.
    config = tf.estimator.RunConfig(
        tf_random_seed=42,
        model_dir=FLAGS.model_dir,
        session_config=tf.ConfigProto(
            log_device_placement=False,
            # Ignore other workers; only talk to parameter servers.
            # Otherwise, when a chief/worker terminates, the others will hang.
            device_filters=["/job:ps"]))
    head = tf.contrib.estimator.regression_head(
        loss_reduction=tf.losses.Reduction.SUM_OVER_BATCH_SIZE)

    kwargs = {
        "max_iteration_steps": 100,
        "force_grow": True,
        "delay_secs_per_worker": .2,
        "max_worker_delay_secs": 1,
        "worker_wait_secs": .5,
        # Set low timeout to reduce wait time for failures.
        "worker_wait_timeout_secs": 60,
        "config": config
    }
    if FLAGS.placement_strategy == "round_robin":
        kwargs["experimental_placement_strategy"] = RoundRobinStrategy()
    if FLAGS.estimator_type == "autoensemble":
        feature_columns = [tf.feature_column.numeric_column("x", shape=[2])]
        # Use the core estimators when available, else the contrib versions.
        if hasattr(tf.estimator, "LinearEstimator"):
            linear_estimator_fn = tf.estimator.LinearEstimator
        else:
            linear_estimator_fn = tf.contrib.estimator.LinearEstimator
        if hasattr(tf.estimator, "DNNEstimator"):
            dnn_estimator_fn = tf.estimator.DNNEstimator
        else:
            dnn_estimator_fn = tf.contrib.estimator.DNNEstimator
        candidate_pool = {
            "linear":
            linear_estimator_fn(
                head=head,
                feature_columns=feature_columns,
                optimizer=tf.train.AdamOptimizer(learning_rate=.001)),
            "dnn":
            dnn_estimator_fn(
                head=head,
                feature_columns=feature_columns,
                optimizer=tf.train.AdamOptimizer(learning_rate=.001),
                hidden_units=[3]),
            "dnn2":
            dnn_estimator_fn(
                head=head,
                feature_columns=feature_columns,
                optimizer=tf.train.AdamOptimizer(learning_rate=.001),
                hidden_units=[5])
        }

        estimator = AutoEnsembleEstimator(head=head,
                                          candidate_pool=candidate_pool,
                                          **kwargs)

    elif FLAGS.estimator_type == "estimator":
        subnetwork_generator = SimpleGenerator([
            _DNNBuilder("dnn1", config, layer_size=3),
            _DNNBuilder("dnn2", config, layer_size=4),
            _DNNBuilder("dnn3", config, layer_size=5),
        ])

        estimator = Estimator(head=head,
                              subnetwork_generator=subnetwork_generator,
                              **kwargs)
    else:
        # Without this, `estimator` would be unbound and the failure would
        # surface later as a confusing UnboundLocalError.
        raise ValueError(
            "Unsupported estimator_type: {!r}".format(FLAGS.estimator_type))

    def input_fn():
        """Returns constant XOR features and labels."""
        xor_features = [[1., 0.], [0., 0], [0., 1.], [1., 1.]]
        xor_labels = [[1.], [0.], [1.], [0.]]
        input_features = {"x": tf.constant(xor_features, name="x")}
        input_labels = tf.constant(xor_labels, name="y")
        return input_features, input_labels

    train_hooks = []
    # ProfilerHook raises the following error in older TensorFlow versions:
    # ValueError: The provided tag was already used for this event type.
    if LooseVersion(tf.VERSION) >= LooseVersion("1.13.0"):
        train_hooks = [
            tf.train.ProfilerHook(save_steps=50, output_dir=FLAGS.model_dir)
        ]
    # Train for three iterations.
    train_spec = tf.estimator.TrainSpec(input_fn=input_fn,
                                        max_steps=300,
                                        hooks=train_hooks)
    eval_spec = tf.estimator.EvalSpec(input_fn=input_fn,
                                      steps=1,
                                      start_delay_secs=.5,
                                      throttle_secs=.5)

    # Calling train_and_evaluate is the official way to perform distributed
    # training with an Estimator. Calling Estimator#train directly results
    # in an error when the TF_CONFIG is setup for a cluster.
    tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
示例#8
0
class ModelTest(tu.AdanetTestCase):
    """Lifecycle and error-handling tests for the AdaNet Keras `model.Model`."""

    @parameterized.named_parameters(
        {
            "testcase_name": "one_step_binary_crossentropy_loss",
            "loss": "binary_crossentropy",
            "subnetwork_generator": SimpleGenerator([_DNNBuilder("dnn")]),
            "max_iteration_steps": 1,
            "epochs": 1,
            "steps_per_epoch": 3,
            "want_loss": 0.7690,
        },
        {
            "testcase_name": "one_step_mse_loss",
            "loss": "mse",
            "subnetwork_generator": SimpleGenerator([_DNNBuilder("dnn")]),
            "max_iteration_steps": 1,
            "epochs": 1,
            "steps_per_epoch": 3,
            "want_loss": 0.6354,
        },
        {
            "testcase_name": "one_step_sparse_categorical_crossentropy_loss",
            "loss": "sparse_categorical_crossentropy",
            "subnetwork_generator": SimpleGenerator([_DNNBuilder("dnn")]),
            "max_iteration_steps": 1,
            "epochs": 1,
            "steps_per_epoch": 3,
            "want_loss": 1.2521,
            "logits_dimension": 3,
            "dataset": lambda: tf.data.Dataset.from_tensors(  # pylint: disable=g-long-lambda
                ({"x": XOR_FEATURES}, XOR_CLASS_LABELS)),
        })
    @test_util.run_in_graph_and_eager_modes
    def test_lifecycle(self,
                       loss,
                       subnetwork_generator,
                       max_iteration_steps,
                       want_loss,
                       logits_dimension=1,
                       ensemblers=None,
                       ensemble_strategies=None,
                       evaluator=None,
                       adanet_loss_decay=0.9,
                       dataset=None,
                       epochs=None,
                       steps_per_epoch=None):
        """Compiles, fits and evaluates a model; also predicts for MSE loss."""
        keras_model = model.Model(
            subnetwork_generator=subnetwork_generator,
            max_iteration_steps=max_iteration_steps,
            logits_dimension=logits_dimension,
            ensemblers=ensemblers,
            ensemble_strategies=ensemble_strategies,
            evaluator=evaluator,
            adanet_loss_decay=adanet_loss_decay,
            filepath=self.test_subdirectory)

        keras_model.compile(loss=loss)
        # metrics_names must already be available right after compile().
        self.assertEqual(["loss"], keras_model.metrics_names)

        if dataset is None:

            def _repeated_xor_dataset():
                return tf.data.Dataset.from_tensors(
                    ({"x": XOR_FEATURES}, XOR_LABELS)).repeat()

            dataset = _repeated_xor_dataset

        keras_model.fit(dataset, epochs=epochs, steps_per_epoch=steps_per_epoch)

        eval_results = keras_model.evaluate(dataset, steps=3)
        self.assertAlmostEqual(want_loss, eval_results["loss"], places=3)

        # TODO: Predict not currently working for BinaryClassHead and
        # MultiClassHead.
        if loss != "mse":
            return

        def prediction_data():
            return tf.data.Dataset.from_tensors({"x": XOR_FEATURES})

        predictions = keras_model.predict(prediction_data)
        self.assertLen(predictions, 4)

    @test_util.run_in_graph_and_eager_modes
    def test_compile_exceptions(self):
        """fit, evaluate and predict raise RuntimeError before compile()."""
        keras_model = model.Model(
            subnetwork_generator=SimpleGenerator([_DNNBuilder("dnn")]),
            max_iteration_steps=1)
        train_data = tf.data.Dataset.from_tensors(([[1., 1.]], [[1.]]))
        predict_data = tf.data.Dataset.from_tensors([[1., 1.]])

        for method_name, data in (("fit", train_data),
                                  ("evaluate", train_data),
                                  ("predict", predict_data)):
            with self.assertRaises(RuntimeError):
                getattr(keras_model, method_name)(data)
示例#9
0
class ModelTest(tu.AdanetTestCase):
    """Lifecycle tests for the AdaNet Keras ``model.Model`` with an MSE loss."""

    @parameterized.named_parameters({
        "testcase_name": "one_step",
        "subnetwork_generator": SimpleGenerator([_DNNBuilder("dnn")]),
        "max_iteration_steps": 1,
        "epochs": 1,
        "steps_per_epoch": 1,
        "want_loss": 1.2208,
    })
    def test_lifecycle(self,
                       subnetwork_generator,
                       max_iteration_steps,
                       want_loss,
                       ensemblers=None,
                       ensemble_strategies=None,
                       evaluator=None,
                       adanet_loss_decay=0.9,
                       dataset=None,
                       epochs=None,
                       steps_per_epoch=None):
        """Compiles with loss="mse", then fits, evaluates and predicts.

        When ``dataset`` is None a repeating (XOR_FEATURES, XOR_LABELS)
        dataset is used for both training and evaluation.
        """
        adanet_model = model.Model(
            subnetwork_generator=subnetwork_generator,
            max_iteration_steps=max_iteration_steps,
            ensemblers=ensemblers,
            ensemble_strategies=ensemble_strategies,
            evaluator=evaluator,
            adanet_loss_decay=adanet_loss_decay,
            filepath=self.test_subdirectory)

        adanet_model.compile(loss="mse")
        # metrics_names has to be populated as soon as compile() returns.
        self.assertEqual(["loss"], adanet_model.metrics_names)

        train_ds = dataset
        if train_ds is None:
            train_ds = tf.data.Dataset.from_tensors(
                (XOR_FEATURES, XOR_LABELS)).repeat()

        adanet_model.fit(train_ds,
                         epochs=epochs,
                         steps_per_epoch=steps_per_epoch)

        results = adanet_model.evaluate(train_ds, steps=3)
        self.assertAlmostEqual(want_loss, results["loss"], places=3)

        predictions = adanet_model.predict(
            tf.data.Dataset.from_tensors(XOR_FEATURES))
        self.assertLen(predictions, 4)

    def test_compile_exceptions(self):
        """Checks that an uncompiled model rejects fit, evaluate and predict.

        The model below is never compiled, so each entry point is expected to
        raise RuntimeError.
        """
        uncompiled_model = model.Model(
            subnetwork_generator=SimpleGenerator([_DNNBuilder("dnn")]),
            max_iteration_steps=1)
        labeled_ds = tf.data.Dataset.from_tensors(([[1., 1.]], [[1.]]))
        unlabeled_ds = tf.data.Dataset.from_tensors(([[1., 1.]]))

        cases = (
            (uncompiled_model.fit, labeled_ds),
            (uncompiled_model.evaluate, labeled_ds),
            (uncompiled_model.predict, unlabeled_ds),
        )
        for entry_point, data in cases:
            with self.assertRaises(RuntimeError):
                entry_point(data)
def train_and_evaluate_estimator():
    """Runs Estimator distributed training.

    Builds the estimator selected by ``FLAGS.estimator_type`` and runs
    ``tf.estimator.train_and_evaluate`` on a tiny four-example dataset. The
    model directory comes from ``FLAGS.model_dir``, which also receives the
    profiler output.

    Supported ``FLAGS.estimator_type`` values:
      * "autoensemble": AutoEnsembleEstimator over one linear and two DNN
        candidate estimators with a regression head.
      * "estimator": AdaNet Estimator with three ``_DNNBuilder`` subnetworks.
      * "autoensemble_trees_multiclass": 3-class AutoEnsembleEstimator over a
        linear estimator and ``CoreGradientBoostedDecisionTreeEstimator``.

    If ``FLAGS.placement_strategy == "round_robin"``, ``RoundRobinStrategy``
    is used as the experimental placement strategy.

    Raises:
      ValueError: If ``FLAGS.estimator_type`` is not one of the values above.
    """

    # The tf.estimator.RunConfig automatically parses the TF_CONFIG environment
    # variables during construction.
    # For more information on how tf.estimator.RunConfig uses TF_CONFIG, see
    # https://www.tensorflow.org/api_docs/python/tf/estimator/RunConfig.
    config = tf.estimator.RunConfig(
        tf_random_seed=42,
        model_dir=FLAGS.model_dir,
        session_config=tf_compat.v1.ConfigProto(
            log_device_placement=False,
            # Ignore other workers; only talk to parameter servers.
            # Otherwise, when a chief/worker terminates, the others will hang.
            device_filters=["/job:ps"]))

    kwargs = {
        "max_iteration_steps": 100,
        "force_grow": True,
        "delay_secs_per_worker": .2,
        "max_worker_delay_secs": 1,
        "worker_wait_secs": .5,
        # Set low timeout to reduce wait time for failures.
        "worker_wait_timeout_secs": 60,
        "config": config
    }
    head = regression_head.RegressionHead(
        loss_reduction=tf_compat.SUM_OVER_BATCH_SIZE)
    features = [[1., 0.], [0., 0], [0., 1.], [1., 1.]]
    labels = [[1.], [0.], [1.], [0.]]
    if FLAGS.placement_strategy == "round_robin":
        kwargs["experimental_placement_strategy"] = RoundRobinStrategy()
    if FLAGS.estimator_type == "autoensemble":
        feature_columns = [tf.feature_column.numeric_column("x", shape=[2])]
        # `learning_rate` is the supported keyword; the legacy `lr` alias is
        # deprecated and is ignored or rejected by newer Keras optimizers.
        candidate_pool = {
            "linear":
            tf.estimator.LinearEstimator(
                head=head,
                feature_columns=feature_columns,
                optimizer=lambda: tf.keras.optimizers.Adam(learning_rate=.001)),
            "dnn":
            tf.estimator.DNNEstimator(
                head=head,
                feature_columns=feature_columns,
                optimizer=lambda: tf.keras.optimizers.Adam(learning_rate=.001),
                hidden_units=[3]),
            "dnn2":
            tf.estimator.DNNEstimator(
                head=head,
                feature_columns=feature_columns,
                optimizer=lambda: tf.keras.optimizers.Adam(learning_rate=.001),
                hidden_units=[5]),
        }

        estimator = AutoEnsembleEstimator(head=head,
                                          candidate_pool=candidate_pool,
                                          **kwargs)
    elif FLAGS.estimator_type == "estimator":
        subnetwork_generator = SimpleGenerator([
            _DNNBuilder("dnn1", config, layer_size=3),
            _DNNBuilder("dnn2", config, layer_size=4),
            _DNNBuilder("dnn3", config, layer_size=5),
        ])

        estimator = Estimator(head=head,
                              subnetwork_generator=subnetwork_generator,
                              **kwargs)
    elif FLAGS.estimator_type == "autoensemble_trees_multiclass":
        n_classes = 3
        head = multi_class_head.MultiClassHead(
            n_classes=n_classes, loss_reduction=tf_compat.SUM_OVER_BATCH_SIZE)

        def tree_loss_fn(labels, logits):
            result = bt_losses.per_example_maxent_loss(labels=labels,
                                                       logits=logits,
                                                       num_classes=n_classes,
                                                       weights=None)
            return result[0]

        tree_head = multi_class_head.MultiClassHead(
            loss_fn=tree_loss_fn,
            n_classes=n_classes,
            loss_reduction=tf_compat.SUM_OVER_BATCH_SIZE)
        # Class-index labels replace the regression targets for this head.
        labels = [[1], [0], [1], [2]]
        feature_columns = [tf.feature_column.numeric_column("x", shape=[2])]
        candidate_pool = lambda config: {  # pylint: disable=g-long-lambda
            "linear":
                tf.estimator.LinearEstimator(
                    head=head,
                    feature_columns=feature_columns,
                    optimizer=tf.keras.optimizers.Adam(learning_rate=.001),
                    config=config),
            "gbdt":
                CoreGradientBoostedDecisionTreeEstimator(
                    head=tree_head,
                    learner_config=learner_pb2.LearnerConfig(num_classes=n_classes),
                    examples_per_layer=8,
                    num_trees=None,
                    center_bias=False,  # Required for multi-class.
                    feature_columns=feature_columns,
                    config=config),
        }

        estimator = AutoEnsembleEstimator(head=head,
                                          candidate_pool=candidate_pool,
                                          **kwargs)
    else:
        # Without this branch `estimator` would be unbound and the failure
        # would surface as a confusing NameError at train_and_evaluate below.
        raise ValueError(
            "Unsupported estimator_type: {!r}. Expected one of "
            "'autoensemble', 'estimator' or "
            "'autoensemble_trees_multiclass'.".format(FLAGS.estimator_type))

    def input_fn():
        input_features = {"x": tf.constant(features, name="x")}
        input_labels = tf.constant(labels, name="y")
        return input_features, input_labels

    train_hooks = [
        tf.estimator.ProfilerHook(save_steps=50, output_dir=FLAGS.model_dir)
    ]
    # Train for three iterations (max_steps=300, max_iteration_steps=100).
    train_spec = tf.estimator.TrainSpec(input_fn=input_fn,
                                        max_steps=300,
                                        hooks=train_hooks)
    eval_spec = tf.estimator.EvalSpec(input_fn=input_fn,
                                      steps=1,
                                      start_delay_secs=.5,
                                      throttle_secs=.5)

    # Calling train_and_evaluate is the official way to perform distributed
    # training with an Estimator. Calling Estimator#train directly results
    # in an error when the TF_CONFIG is setup for a cluster.
    tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
示例#11
0
    def test_summaries(self):
        """Tests that summaries are written to candidate directory."""
        model_dir = self.test_subdirectory
        # Save summaries and log step counts every 2 steps.
        run_config = tf.estimator.RunConfig(
            tf_random_seed=42,
            log_step_count_steps=2,
            save_summary_steps=2,
            model_dir=model_dir)
        generator = SimpleGenerator([_SimpleBuilder("dnn")])
        materializer = ReportMaterializer(
            input_fn=tu.dummy_input_fn([[1., 1.]], [[0.]]), steps=1)
        head = regression_head.RegressionHead(
            loss_reduction=tf_compat.SUM_OVER_BATCH_SIZE)
        estimator = Estimator(
            head=head,
            subnetwork_generator=generator,
            report_materializer=materializer,
            max_iteration_steps=10,
            config=run_config)
        estimator.train(
            input_fn=tu.dummy_input_fn([[1., 0.]], [[1.]]), max_steps=3)

        read = tu.check_eventfile_for_keyword
        want_loss = 1.52950

        # Summaries in the top-level model directory.
        self.assertAlmostEqual(want_loss, read("loss", model_dir), places=3)
        self.assertIsNotNone(read("global_step/sec", model_dir))
        self.assertEqual(0., read("iteration/adanet/iteration", model_dir))

        # Summaries written for the "dnn" candidate subnetwork.
        subnetwork_dir = os.path.join(model_dir, "subnetwork/t0_dnn")
        self.assertAlmostEqual(3., read("scalar", subnetwork_dir), places=3)
        self.assertEqual((3, 3, 1), read("image", subnetwork_dir))
        self.assertAlmostEqual(
            5., read("nested/scalar", subnetwork_dir), places=3)

        # Summaries written for the candidate ensemble.
        ensemble_dir = os.path.join(
            model_dir, "ensemble/t0_dnn_grow_complexity_regularized")
        self.assertAlmostEqual(
            want_loss,
            read("adanet_loss/adanet/adanet_weighted_ensemble", ensemble_dir),
            places=1)
        self.assertAlmostEqual(
            0.,
            read("complexity_regularization/adanet/adanet_weighted_ensemble",
                 ensemble_dir),
            places=3)
        self.assertAlmostEqual(
            1.,
            read("mixture_weight_norms/adanet/"
                 "adanet_weighted_ensemble/subnetwork_0", ensemble_dir),
            places=3)
示例#12
0
class ModelTest(tu.AdanetTestCase):
    """Tests for the AdaNet Keras ``model.Model`` using Keras loss objects."""

    # pylint: disable=g-long-lambda
    @parameterized.named_parameters(
        {
            "testcase_name": "one_step_binary_crossentropy_loss",
            "loss": tf.keras.losses.BinaryCrossentropy(from_logits=True),
            "metrics": [
                lambda: tf.keras.metrics.BinaryCrossentropy(
                    name="bin_acc", from_logits=True)
            ],
            "subnetwork_generator": SimpleGenerator([_DNNBuilder("dnn")]),
            "max_iteration_steps": 1,
            "epochs": 1,
            "steps_per_epoch": 3,
            "want_metrics_names": ["loss", "bin_acc"],
            "want_loss": 0.7690,
            "want_metrics": [0.7690],
        },
        {
            "testcase_name": "one_step_mse_loss",
            "loss": tf.keras.losses.MeanSquaredError(),
            "metrics": [
                lambda: tf.keras.metrics.MeanAbsoluteError(name="mae")
            ],
            "subnetwork_generator": SimpleGenerator([_DNNBuilder("dnn")]),
            "max_iteration_steps": 1,
            "epochs": 1,
            "steps_per_epoch": 3,
            "want_metrics_names": ["loss", "mae"],
            "want_loss": 0.6354,
            "want_metrics": [0.6191],
        },
        {
            "testcase_name": "one_step_sparse_categorical_crossentropy_loss",
            "loss": tf.keras.losses.SparseCategoricalCrossentropy(
                from_logits=True),
            "subnetwork_generator": SimpleGenerator([_DNNBuilder("dnn")]),
            "max_iteration_steps": 1,
            "epochs": 1,
            "steps_per_epoch": 3,
            "want_loss": 1.2521,
            "logits_dimension": 3,
            "dataset": lambda: tf.data.Dataset.from_tensors(
                ({"x": XOR_FEATURES}, XOR_CLASS_LABELS)),
        })
    # pylint: enable=g-long-lambda
    @test_util.run_in_graph_and_eager_modes
    def test_lifecycle(self,
                       loss,
                       subnetwork_generator,
                       max_iteration_steps,
                       want_loss,
                       want_metrics=None,
                       want_metrics_names=None,
                       metrics=None,
                       logits_dimension=1,
                       ensemblers=None,
                       ensemble_strategies=None,
                       evaluator=None,
                       adanet_loss_decay=0.9,
                       dataset=None,
                       epochs=None,
                       steps_per_epoch=None):
        """Compiles, fits, evaluates and predicts, checking loss and metrics.

        ``dataset`` is a zero-argument callable returning a tf.data.Dataset.
        When it is None, a repeating ({"x": XOR_FEATURES}, XOR_LABELS) dataset
        is used. ``want_metrics_names`` defaults to ["loss"].
        """
        adanet_model = model.Model(
            subnetwork_generator=subnetwork_generator,
            max_iteration_steps=max_iteration_steps,
            logits_dimension=logits_dimension,
            ensemblers=ensemblers,
            ensemble_strategies=ensemble_strategies,
            evaluator=evaluator,
            adanet_loss_decay=adanet_loss_decay,
            filepath=self.test_subdirectory)

        adanet_model.compile(loss=loss, metrics=metrics)
        expected_names = (["loss"] if want_metrics_names is None else
                          want_metrics_names)
        # metrics_names has to be populated as soon as compile() returns.
        self.assertEqual(expected_names, adanet_model.metrics_names)

        def _xor_dataset():
            return tf.data.Dataset.from_tensors(
                ({"x": XOR_FEATURES}, XOR_LABELS)).repeat()

        data_fn = _xor_dataset if dataset is None else dataset

        adanet_model.fit(data_fn,
                         epochs=epochs,
                         steps_per_epoch=steps_per_epoch)

        results = adanet_model.evaluate(data_fn, steps=3)
        # The loss comes first; any compiled metrics follow in order.
        self.assertAlmostEqual(want_loss, results[0], places=3)
        if metrics:
            self.assertAllClose(want_metrics, results[1:], 1e-3, 1e-3)

        def _xor_features():
            return tf.data.Dataset.from_tensors({"x": XOR_FEATURES})

        # TODO: Change the assertion to actually check the values rather
        # than the length of the returned predictions array.
        predictions = adanet_model.predict(_xor_features)
        self.assertLen(predictions, 4)

    @test_util.run_in_graph_and_eager_modes
    def test_compile_exceptions(self):
        """Checks that an uncompiled model rejects fit, evaluate and predict.

        The model below is never compiled, so each entry point is expected to
        raise RuntimeError.
        """
        uncompiled_model = model.Model(
            subnetwork_generator=SimpleGenerator([_DNNBuilder("dnn")]),
            max_iteration_steps=1)
        labeled_ds = tf.data.Dataset.from_tensors(([[1., 1.]], [[1.]]))
        unlabeled_ds = tf.data.Dataset.from_tensors(([[1., 1.]]))

        cases = (
            (uncompiled_model.fit, labeled_ds),
            (uncompiled_model.evaluate, labeled_ds),
            (uncompiled_model.predict, unlabeled_ds),
        )
        for entry_point, data in cases:
            with self.assertRaises(RuntimeError):
                entry_point(data)

    @test_util.run_in_graph_and_eager_modes
    def test_loss_exceptions(self):
        """Check that ValueError is raised when from_logits=False for loss."""
        adanet_model = model.Model(
            subnetwork_generator=SimpleGenerator([_DNNBuilder("dnn")]),
            max_iteration_steps=1)
        non_logit_loss = tf.keras.losses.BinaryCrossentropy(from_logits=False)

        with self.assertRaises(ValueError):
            adanet_model.compile(loss=non_logit_loss)
class TPUEstimatorTest(tu.AdanetTestCase):
    def setUp(self):
        """Skips on TF < 1.14; otherwise points TF_CONFIG at the test dir."""
        super(TPUEstimatorTest, self).setUp()

        if not tf_compat.version_greater_or_equal("1.14.0"):
            self.skipTest(
                "TPUEmbedding not supported in version 1.13.0 and below.")

        # TPUConfig reads model_dir from TF_CONFIG and verifies that it agrees
        # with the model_dir passed by the user, so keep the two in sync.
        os.environ["TF_CONFIG"] = json.dumps(
            dict(model_dir=self.test_subdirectory))

    def tearDown(self):
        """Removes the TF_CONFIG variable, even if the base tearDown fails."""
        try:
            super(TPUEstimatorTest, self).tearDown()
        finally:
            # Clear TF_CONFIG unconditionally so a failing base tearDown cannot
            # leak it into later tests; pop() also tolerates it being absent.
            os.environ.pop("TF_CONFIG", None)

    @parameterized.named_parameters(
        {
            "testcase_name": "not_use_tpu",
            "use_tpu": False,
            "subnetwork_generator":
                SimpleGenerator([_DNNBuilder("dnn", use_tpu=False)]),
            "want_loss": 0.41315794,
        }, )
    def test_tpu_estimator_simple_lifecycle(self, use_tpu,
                                            subnetwork_generator, want_loss):
        """Trains, evaluates, predicts with and exports a TPUEstimator."""
        run_config = tf.compat.v1.estimator.tpu.RunConfig(master="",
                                                          tf_random_seed=42)
        # CPU runs (use_tpu=False) pass a train batch size of 0.
        batch_size = 64 if use_tpu else 0
        estimator = TPUEstimator(
            # TODO: Add test with estimator Head v2.
            head=make_regression_head(use_tpu),
            subnetwork_generator=subnetwork_generator,
            max_iteration_steps=10,
            model_dir=self.test_subdirectory,
            config=run_config,
            use_tpu=use_tpu,
            train_batch_size=batch_size)
        total_steps = 30

        xor_input_fn = tu.dummy_input_fn(
            [[1., 0.], [0., 0], [0., 1.], [1., 1.]],
            [[1.], [0.], [1.], [0.]])

        estimator.train(input_fn=xor_input_fn,
                        steps=None,
                        max_steps=total_steps,
                        hooks=None)
        metrics = estimator.evaluate(input_fn=xor_input_fn,
                                     steps=1,
                                     hooks=None)

        # Consume every prediction before moving on; otherwise the TPU is not
        # shut down.
        predict_input_fn = tu.dataset_input_fn(features=[0., 0.],
                                               return_dataset=True)
        for prediction in estimator.predict(input_fn=predict_input_fn):
            self.assertIsNotNone(prediction["predictions"])

        def serving_input_fn():
            """Input fn for serving export, starting from serialized example."""
            receiver = tf.compat.v1.placeholder(
                dtype=tf.string, shape=(None), name="serialized_example")
            serving_x = tf.constant([[0., 0.]], name="serving_x")
            return tf.estimator.export.ServingInputReceiver(
                features={"x": serving_x}, receiver_tensors=receiver)

        estimator.export_saved_model(
            export_dir_base=estimator.model_dir,
            serving_input_receiver_fn=serving_input_fn)

        self.assertAlmostEqual(want_loss, metrics["loss"], places=2)
        self.assertEqual(total_steps, metrics["global_step"])
        self.assertEqual(2, metrics["iteration"])

    @parameterized.named_parameters(
        {
            "testcase_name": "not_use_tpu",
            "use_tpu": False,
            "want_loss": 0.55584925,
            "want_adanet_loss": .64416,
            "want_eval_summary_loss": 0.555849,
            "want_predictions": 0.46818,
        }, )
    def test_tpu_estimator_summaries(self, use_tpu, want_loss,
                                     want_adanet_loss, want_eval_summary_loss,
                                     want_predictions):
        max_steps = 10
        config = tf.compat.v1.estimator.tpu.RunConfig(
            tf_random_seed=42,
            save_summary_steps=2,
            log_step_count_steps=max_steps)
        assert config.log_step_count_steps

        def metric_fn(predictions):
            return {
                "predictions":
                tf_compat.v1.metrics.mean(predictions["predictions"])
            }

        estimator = TPUEstimator(head=make_regression_head(use_tpu),
                                 subnetwork_generator=SimpleGenerator(
                                     [_DNNBuilder("dnn", use_tpu=use_tpu)]),
                                 max_iteration_steps=max_steps,
                                 model_dir=self.test_subdirectory,
                                 metric_fn=metric_fn,
                                 config=config,
                                 use_tpu=use_tpu,
                                 train_batch_size=64 if use_tpu else 0)
        xor_features = [[1., 0.], [0., 0], [0., 1.], [1., 1.]]
        xor_labels = [[1.], [0.], [1.], [0.]]
        train_input_fn = tu.dummy_input_fn(xor_features, xor_labels)

        estimator.train(input_fn=train_input_fn, max_steps=max_steps)
        eval_results = estimator.evaluate(input_fn=train_input_fn, steps=1)
        self.assertAlmostEqual(want_loss, eval_results["loss"], places=2)
        self.assertEqual(max_steps, eval_results["global_step"])
        self.assertEqual(0, eval_results["iteration"])

        subnetwork_subdir = os.path.join(self.test_subdirectory,
                                         "subnetwork/t0_dnn")

        ensemble_subdir = os.path.join(
            self.test_subdirectory,
            "ensemble/t0_dnn_grow_complexity_regularized")

        # TODO: Why is the adanet_loss written to 'loss'?
        self.assertAlmostEqual(want_adanet_loss,
                               tu.check_eventfile_for_keyword(
                                   "loss", self.test_subdirectory),
                               places=1)
        self.assertEqual(
            0.,
            tu.check_eventfile_for_keyword("iteration/adanet/iteration",
                                           self.test_subdirectory))
        self.assertAlmostEqual(3.,
                               tu.check_eventfile_for_keyword(
                                   "scalar", subnetwork_subdir),
                               places=3)
        self.assertEqual(
            (3, 3, 1),
            tu.check_eventfile_for_keyword(
                # When TF 2 behavior is enabled AdaNet uses V2 summaries.
                "image"
                if tf_compat.is_v2_behavior_enabled() else "image/image/0",
                subnetwork_subdir))
        self.assertAlmostEqual(5.,
                               tu.check_eventfile_for_keyword(
                                   "nested/scalar", subnetwork_subdir),
                               places=3)
        self.assertAlmostEqual(
            want_adanet_loss,
            tu.check_eventfile_for_keyword(
                "adanet_loss/adanet/adanet_weighted_ensemble",
                ensemble_subdir),
            places=1)
        self.assertAlmostEqual(
            0.,
            tu.check_eventfile_for_keyword(
                "complexity_regularization/adanet/adanet_weighted_ensemble",
                ensemble_subdir),
            places=1)
        self.assertAlmostEqual(1.,
                               tu.check_eventfile_for_keyword(
                                   "mixture_weight_norms/adanet/"
                                   "adanet_weighted_ensemble/subnetwork_0",
                                   ensemble_subdir),
                               places=1)

        # Eval metric summaries are always written out during eval.
        subnetwork_eval_subdir = os.path.join(subnetwork_subdir, "eval")
        self.assertAlmostEqual(want_eval_summary_loss,
                               tu.check_eventfile_for_keyword(
                                   "loss", subnetwork_eval_subdir),
                               places=1)
        # TODO: Check why some eval metrics are zero on TPU.
        self.assertAlmostEqual(0.0 if use_tpu else want_eval_summary_loss,
                               tu.check_eventfile_for_keyword(
                                   "average_loss", subnetwork_eval_subdir),
                               places=1)
        self.assertAlmostEqual(want_predictions,
                               tu.check_eventfile_for_keyword(
                                   "predictions", subnetwork_eval_subdir),
                               places=3)

        eval_subdir = os.path.join(self.test_subdirectory, "eval")
        ensemble_eval_subdir = os.path.join(ensemble_subdir, "eval")
        for subdir in [ensemble_eval_subdir, eval_subdir]:
            self.assertEqual([b"| dnn |"],
                             tu.check_eventfile_for_keyword(
                                 "architecture/adanet/ensembles/0", subdir))
            if subdir == eval_subdir:
                self.assertAlmostEqual(want_loss,
                                       tu.check_eventfile_for_keyword(
                                           "loss", subdir),
                                       places=1)
            # TODO: Check why some eval metrics are zero on TPU.
            self.assertAlmostEqual(0.0 if use_tpu else want_eval_summary_loss,
                                   tu.check_eventfile_for_keyword(
                                       "average_loss", subdir),
                                   places=1)