Example #1
class CandidateTest(parameterized.TestCase, tf.test.TestCase):
    @parameterized.named_parameters({
        "testcase_name":
        "valid",
        "ensemble_spec":
        tu.dummy_ensemble_spec("foo"),
        "adanet_loss": [.1],
    })
    @test_util.run_in_graph_and_eager_modes
    def test_new(self, ensemble_spec, adanet_loss):
        with self.test_session():
            got = _Candidate(ensemble_spec, adanet_loss)
            self.assertEqual(got.ensemble_spec, ensemble_spec)
            self.assertEqual(got.adanet_loss, adanet_loss)

    @parameterized.named_parameters(
        {
            "testcase_name": "none_ensemble_spec",
            "ensemble_spec": None,
            "adanet_loss": [.1],
        }, {
            "testcase_name": "none_adanet_loss",
            "ensemble_spec": tu.dummy_ensemble_spec("foo"),
            "adanet_loss": None,
        })
    @test_util.run_in_graph_and_eager_modes
    def test_new_errors(self, ensemble_spec, adanet_loss):
        with self.test_session():
            with self.assertRaises(ValueError):
                _Candidate(ensemble_spec, adanet_loss)
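The snippets on this page are excerpts from adanet's internal test modules, so they lean on a shared harness (`tu`, `_Candidate`, `_CandidateBuilder`, `tf_compat`, `test_util`). Below is a minimal sketch of the imports the candidate-test excerpts assume; the exact module paths are inferred from the adanet source layout and may differ between releases.

# Assumed imports for the candidate test excerpts (paths may vary by adanet version).
from absl.testing import parameterized
import tensorflow as tf
from tensorflow.python.framework import test_util  # run_in_graph_and_eager_modes

from adanet import tf_compat
from adanet.core import testing_utils as tu  # dummy_ensemble_spec and other helpers
from adanet.core.candidate import _Candidate
from adanet.core.candidate import _CandidateBuilder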
Example #2
    def build_ensemble_spec(self,
                            name,
                            candidate,
                            ensembler,
                            subnetwork_specs,
                            summary,
                            features,
                            mode,
                            iteration_number,
                            labels=None,
                            previous_ensemble_spec=None,
                            params=None):
        del ensembler
        del subnetwork_specs
        del summary
        del features
        del mode
        del labels
        del iteration_number
        del params

        num_subnetworks = 0
        if previous_ensemble_spec:
            num_subnetworks += 1

        return tu.dummy_ensemble_spec(
            name=name,
            num_subnetworks=num_subnetworks,
            random_seed=candidate.subnetwork_builders[0].seed,
            subnetwork_builders=candidate.subnetwork_builders,
            dict_predictions=self._dict_predictions,
            eval_metrics=self._eval_metrics,
            export_output_key=self._export_output_key)
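In the adanet tests this `build_ensemble_spec` override lives on a fake ensemble builder (the `_FakeEnsembleBuilder` that the later iteration examples instantiate), so iteration tests can skip real ensembling. A minimal sketch of such a stub follows, with constructor parameters inferred from the attributes the method reads; the real test helper may differ.

class _FakeEnsembleBuilder(object):
    """Hypothetical stub that returns dummy ensemble specs instead of ensembling."""

    def __init__(self, dict_predictions=False, eval_metrics=None,
                 export_output_key=None):
        self._dict_predictions = dict_predictions
        self._eval_metrics = eval_metrics
        self._export_output_key = export_output_key

    # build_ensemble_spec(...) from the example above would be defined here.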
Example #3
    def test_build_candidate(self,
                             training,
                             max_steps,
                             want_adanet_losses,
                             want_is_training,
                             is_previous_best=False):
        # A fake adanet_loss that halves at each train step: 1.0, 0.5, 0.25, ...
        fake_adanet_loss = tf.Variable(1.)
        fake_train_op = fake_adanet_loss.assign(fake_adanet_loss / 2)
        fake_ensemble_spec = tu.dummy_ensemble_spec(
            "new", adanet_loss=fake_adanet_loss, train_op=fake_train_op)

        iteration_step = tf.Variable(0)
        builder = _CandidateBuilder(max_steps=max_steps)
        candidate = builder.build_candidate(ensemble_spec=fake_ensemble_spec,
                                            training=training,
                                            iteration_step=iteration_step,
                                            summary=_FakeSummary(),
                                            is_previous_best=is_previous_best)
        with self.test_session() as sess:
            sess.run(tf_compat.v1.global_variables_initializer())
            adanet_losses = []
            is_training = True
            for _ in range(len(want_adanet_losses)):
                is_training, adanet_loss = sess.run(
                    (candidate.is_training, candidate.adanet_loss))
                adanet_losses.append(adanet_loss)
                sess.run((fake_train_op, iteration_step.assign_add(1)))

        # Verify that adanet_loss moving average works.
        self.assertAllClose(want_adanet_losses, adanet_losses, atol=1e-3)
        self.assertEqual(want_is_training, is_training)
Example #4
def _dummy_candidate():
  """Returns a dummy `_Candidate` instance."""

  return _Candidate(
      ensemble_spec=tu.dummy_ensemble_spec("foo"),
      adanet_loss=1.,
      variables=[tf.Variable(1.)])
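A test that needs a candidate fixture can then call the helper directly, for instance as below (a hypothetical assertion, shown only to illustrate the fields the dummy carries).

candidate = _dummy_candidate()
# The dummy simply echoes the arguments passed to _Candidate and dummy_ensemble_spec.
assert candidate.ensemble_spec.name == "foo"
assert candidate.adanet_loss == 1.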
Example #5
def _dummy_candidate():
  """Returns a dummy `_Candidate` instance."""

  return _Candidate(
      ensemble_spec=tu.dummy_ensemble_spec("foo"),
      adanet_loss=1.,
      is_training=True)
Example #6
class CandidateTest(parameterized.TestCase, tf.test.TestCase):
    @parameterized.named_parameters({
        "testcase_name":
        "valid",
        "ensemble_spec":
        tu.dummy_ensemble_spec("foo"),
        "adanet_loss": [.1],
        "is_training":
        True,
    })
    def test_new(self, ensemble_spec, adanet_loss, is_training):
        with self.test_session():
            got = _Candidate(ensemble_spec, adanet_loss, is_training)
            self.assertEqual(got.ensemble_spec, ensemble_spec)
            self.assertEqual(got.adanet_loss, adanet_loss)
            self.assertEqual(got.is_training, is_training)

    @parameterized.named_parameters(
        {
            "testcase_name": "none_ensemble_spec",
            "ensemble_spec": None,
            "adanet_loss": [.1],
            "is_training": True,
        }, {
            "testcase_name": "none_adanet_loss",
            "ensemble_spec": tu.dummy_ensemble_spec("foo"),
            "adanet_loss": None,
            "is_training": True,
        }, {
            "testcase_name": "none_is_training",
            "ensemble_spec": tu.dummy_ensemble_spec("foo"),
            "adanet_loss": [.1],
            "is_training": None,
        })
    def test_new_errors(self, ensemble_spec, adanet_loss, is_training):
        with self.test_session():
            with self.assertRaises(ValueError):
                _Candidate(ensemble_spec, adanet_loss, is_training)
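Both CandidateTest variants rely on absl's `parameterized.named_parameters`, which turns each dict into its own generated test method keyed by `testcase_name`. A small self-contained illustration of the same pattern (the class and values here are hypothetical, unrelated to adanet):

from absl.testing import parameterized
import tensorflow as tf


class NamedParamsExampleTest(parameterized.TestCase, tf.test.TestCase):

    @parameterized.named_parameters(
        {"testcase_name": "one", "value": 1},
        {"testcase_name": "two", "value": 2})
    def test_value_is_positive(self, value):
        # Runs once per dict; testcase_name distinguishes the generated tests.
        self.assertGreater(value, 0)


if __name__ == "__main__":
    tf.test.main()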
Example #7
    def append_new_subnetwork(self, ensemble_spec, subnetwork_builder,
                              iteration_step, summary, features, mode, labels):
        del summary
        del mode
        del features
        del labels
        del iteration_step

        num_subnetworks = 0
        if ensemble_spec:
            num_subnetworks += 1

        return tu.dummy_ensemble_spec(
            name=subnetwork_builder.name,
            num_subnetworks=num_subnetworks,
            random_seed=subnetwork_builder.seed,
            dict_predictions=self._dict_predictions,
            eval_metric_ops=self._eval_metric_ops_fn(),
            export_output_key=self._export_output_key)
Example #8
    def test_build_candidate(self, training, want_adanet_losses):
        # A fake adanet_loss that halves at each train step: 1.0, 0.5, 0.25, ...
        fake_adanet_loss = tf.Variable(1.)
        fake_train_op = fake_adanet_loss.assign(fake_adanet_loss / 2)
        fake_ensemble_spec = tu.dummy_ensemble_spec(
            "new", adanet_loss=fake_adanet_loss, train_op=fake_train_op)

        builder = _CandidateBuilder()
        candidate = builder.build_candidate(ensemble_spec=fake_ensemble_spec,
                                            training=training,
                                            summary=_FakeSummary())
        with self.test_session() as sess:
            sess.run(tf_compat.v1.global_variables_initializer())
            adanet_losses = []
            for _ in range(len(want_adanet_losses)):
                adanet_loss = sess.run(candidate.adanet_loss)
                adanet_losses.append(adanet_loss)
                sess.run(fake_train_op)

        # Verify that adanet_loss moving average works.
        self.assertAllClose(want_adanet_losses, adanet_losses, atol=1e-3)
Example #9
  def build_ensemble_spec(self,
                          name,
                          candidate,
                          ensembler,
                          subnetwork_specs,
                          summary,
                          features,
                          mode,
                          iteration_number,
                          labels=None,
                          previous_ensemble_spec=None,
                          my_ensemble_index=None,
                          params=None,
                          previous_iteration_checkpoint=None):
    del ensembler
    del subnetwork_specs
    del summary
    del features
    del mode
    del labels
    del iteration_number
    del params
    del my_ensemble_index
    del previous_iteration_checkpoint

    num_subnetworks = 0
    if previous_ensemble_spec:
      num_subnetworks += 1

    return tu.dummy_ensemble_spec(
        name=name,
        num_subnetworks=num_subnetworks,
        random_seed=candidate.subnetwork_builders[0].seed,
        subnetwork_builders=candidate.subnetwork_builders,
        dict_predictions=self._dict_predictions,
        eval_metrics=tu.create_ensemble_metrics(
            metric_fn=self._eval_metric_ops_fn),
        export_output_key=self._export_output_key,
        variables=[tf.Variable(1.)])
Example #10
    def test_build_candidate(self, training, want_adanet_losses):
        # `_CandidateBuilder#build_candidate` will only ever be called in graph mode.
        with context.graph_mode():
            # A fake adanet_loss that halves at each train step: 1.0, 0.5, 0.25, ...
            fake_adanet_loss = tf.Variable(1.)
            fake_train_op = fake_adanet_loss.assign(fake_adanet_loss / 2)
            fake_ensemble_spec = tu.dummy_ensemble_spec(
                "new", adanet_loss=fake_adanet_loss, train_op=fake_train_op)

            builder = _CandidateBuilder()
            candidate = builder.build_candidate(
                ensemble_spec=fake_ensemble_spec,
                training=training,
                summary=_FakeSummary())
            self.evaluate(tf_compat.v1.global_variables_initializer())
            adanet_losses = []
            for _ in range(len(want_adanet_losses)):
                adanet_loss = self.evaluate(candidate.adanet_loss)
                adanet_losses.append(adanet_loss)
                self.evaluate(fake_train_op)

            # Verify that adanet_loss moving average works.
            self.assertAllClose(want_adanet_losses, adanet_losses, atol=1e-3)
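The `context` manager used above is TensorFlow's eager-execution context module; these excerpts assume an import along the lines of the following (a private TF module, so the path is an assumption that may change between TF releases).

from tensorflow.python.eager import context  # provides context.graph_mode()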
Example #11
class IterationBuilderTest(parameterized.TestCase, tf.test.TestCase):

    # pylint: disable=g-long-lambda
    @parameterized.named_parameters(
        {
            "testcase_name": "single_subnetwork_fn",
            "ensemble_builder": _FakeEnsembleBuilder(),
            "subnetwork_builders": [_FakeBuilder("training")],
            "features": lambda: [[1., -1., 0.]],
            "labels": lambda: [1],
            "want_loss": 1.403943,
            "want_predictions": 2.129,
            "want_best_candidate_index": 0,
        }, {
            "testcase_name":
            "single_subnetwork_with_eval_metrics",
            "ensemble_builder":
            _FakeEnsembleBuilder(eval_metric_ops_fn=lambda:
                                 {"a": (tf.constant(1), tf.constant(2))}),
            "subnetwork_builders": [
                _FakeBuilder("training", ),
            ],
            "mode":
            tf.estimator.ModeKeys.EVAL,
            "features":
            lambda: [[1., -1., 0.]],
            "labels":
            lambda: [1],
            "want_loss":
            1.403943,
            "want_predictions":
            2.129,
            "want_eval_metric_ops": ["a"],
            "want_best_candidate_index":
            0,
        }, {
            "testcase_name":
            "single_subnetwork_with_non_tensor_eval_metric_op",
            "ensemble_builder":
            _FakeEnsembleBuilder(eval_metric_ops_fn=lambda:
                                 {"a": (tf.constant(1), tf.no_op())}),
            "subnetwork_builders": [
                _FakeBuilder("training", ),
            ],
            "mode":
            tf.estimator.ModeKeys.EVAL,
            "features":
            lambda: [[1., -1., 0.]],
            "labels":
            lambda: [1],
            "want_loss":
            1.403943,
            "want_predictions":
            2.129,
            "want_eval_metric_ops": ["a"],
            "want_best_candidate_index":
            0,
        }, {
            "testcase_name": "single_subnetwork_done_training_fn",
            "ensemble_builder": _FakeEnsembleBuilder(),
            "subnetwork_builders": [_FakeBuilder("done")],
            "features": lambda: [[1., -1., 0.]],
            "labels": lambda: [1],
            "want_loss": 1.403943,
            "want_predictions": 2.129,
            "want_best_candidate_index": 0,
            "want_is_over": True,
        }, {
            "testcase_name": "single_dict_predictions_subnetwork_fn",
            "ensemble_builder": _FakeEnsembleBuilder(dict_predictions=True),
            "subnetwork_builders": [_FakeBuilder("training")],
            "features": lambda: [[1., -1., 0.]],
            "labels": lambda: [1],
            "want_loss": 1.403943,
            "want_predictions": {
                "classes": 2,
                "logits": 2.129
            },
            "want_best_candidate_index": 0,
        }, {
            "testcase_name": "previous_ensemble",
            "ensemble_builder": _FakeEnsembleBuilder(),
            "subnetwork_builders": [_FakeBuilder("training")],
            "features": lambda: [[1., -1., 0.]],
            "labels": lambda: [1],
            "previous_ensemble_spec": lambda: tu.dummy_ensemble_spec("old"),
            "want_loss": 1.403943,
            "want_predictions": 2.129,
            "want_best_candidate_index": 1,
        }, {
            "testcase_name":
            "previous_ensemble_is_best",
            "ensemble_builder":
            _FakeEnsembleBuilder(),
            "subnetwork_builders": [_FakeBuilder("training")],
            "features":
            lambda: [[1., -1., 0.]],
            "labels":
            lambda: [1],
            "previous_ensemble_spec":
            lambda: tu.dummy_ensemble_spec("old", random_seed=12),
            "want_loss":
            -.437,
            "want_predictions":
            .688,
            "want_best_candidate_index":
            0,
        }, {
            "testcase_name":
            "previous_ensemble_spec_and_eval_metrics",
            "ensemble_builder":
            _FakeEnsembleBuilder(eval_metric_ops_fn=lambda:
                                 {"a": (tf.constant(1), tf.constant(2))}),
            "subnetwork_builders": [_FakeBuilder("training")],
            "mode":
            tf.estimator.ModeKeys.EVAL,
            "features":
            lambda: [[1., -1., 0.]],
            "labels":
            lambda: [1],
            "previous_ensemble_spec":
            lambda: tu.dummy_ensemble_spec("old",
                                           eval_metrics=(lambda: {
                                               "a":
                                               (tf.constant(1), tf.constant(2))
                                           }, {})),
            "want_loss":
            1.403943,
            "want_predictions":
            2.129,
            "want_eval_metric_ops": ["a"],
            "want_best_candidate_index":
            1,
        }, {
            "testcase_name":
            "two_subnetwork_fns",
            "ensemble_builder":
            _FakeEnsembleBuilder(),
            "subnetwork_builders": [
                _FakeBuilder("training"),
                _FakeBuilder("training2", random_seed=7)
            ],
            "features":
            lambda: [[1., -1., 0.]],
            "labels":
            lambda: [1],
            "want_loss":
            1.40394,
            "want_predictions":
            2.129,
            "want_best_candidate_index":
            0,
        }, {
            "testcase_name":
            "two_subnetwork_fns_other_best",
            "ensemble_builder":
            _FakeEnsembleBuilder(),
            "subnetwork_builders": [
                _FakeBuilder("training"),
                _FakeBuilder("training2", random_seed=12)
            ],
            "features":
            lambda: [[1., -1., 0.]],
            "labels":
            lambda: [1],
            "want_loss":
            -.437,
            "want_predictions":
            .688,
            "want_best_candidate_index":
            1,
        }, {
            "testcase_name":
            "two_subnetwork_one_training_fns",
            "ensemble_builder":
            _FakeEnsembleBuilder(),
            "subnetwork_builders":
            [_FakeBuilder("training"),
             _FakeBuilder("done", random_seed=7)],
            "features":
            lambda: [[1., -1., 0.]],
            "labels":
            lambda: [1],
            "want_loss":
            1.403943,
            "want_predictions":
            2.129,
            "want_best_candidate_index":
            0,
        }, {
            "testcase_name":
            "two_subnetwork_done_training_fns",
            "ensemble_builder":
            _FakeEnsembleBuilder(),
            "subnetwork_builders":
            [_FakeBuilder("done"),
             _FakeBuilder("done1", random_seed=7)],
            "features":
            lambda: [[1., -1., 0.]],
            "labels":
            lambda: [1],
            "want_loss":
            1.403943,
            "want_predictions":
            2.129,
            "want_best_candidate_index":
            0,
            "want_is_over":
            True,
        }, {
            "testcase_name":
            "two_dict_predictions_subnetwork_fns",
            "ensemble_builder":
            _FakeEnsembleBuilder(dict_predictions=True),
            "subnetwork_builders": [
                _FakeBuilder("training"),
                _FakeBuilder("training2", random_seed=7)
            ],
            "features":
            lambda: [[1., -1., 0.]],
            "labels":
            lambda: [1],
            "want_loss":
            1.404,
            "want_predictions": {
                "classes": 2,
                "logits": 2.129
            },
            "want_best_candidate_index":
            0,
        }, {
            "testcase_name":
            "two_dict_predictions_subnetwork_fns_predict_classes",
            "ensemble_builder":
            _FakeEnsembleBuilder(
                dict_predictions=True,
                export_output_key=tu.ExportOutputKeys.CLASSIFICATION_CLASSES),
            "subnetwork_builders": [
                _FakeBuilder("training"),
                _FakeBuilder("training2", random_seed=7)
            ],
            "mode":
            tf.estimator.ModeKeys.PREDICT,
            "features":
            lambda: [[1., -1., 0.]],
            "labels":
            lambda: [1],
            "want_loss":
            1.404,
            "want_predictions": {
                "classes": 2,
                "logits": 2.129
            },
            "want_best_candidate_index":
            0,
            "want_export_outputs": {
                tu.ExportOutputKeys.CLASSIFICATION_CLASSES: [2.129],
                "serving_default": [2.129],
            },
        }, {
            "testcase_name":
            "two_dict_predictions_subnetwork_fns_predict_scores",
            "ensemble_builder":
            _FakeEnsembleBuilder(
                dict_predictions=True,
                export_output_key=tu.ExportOutputKeys.CLASSIFICATION_SCORES),
            "subnetwork_builders": [
                _FakeBuilder("training"),
                _FakeBuilder("training2", random_seed=7)
            ],
            "mode":
            tf.estimator.ModeKeys.PREDICT,
            "features":
            lambda: [[1., -1., 0.]],
            "labels":
            lambda: [1],
            "want_loss":
            1.404,
            "want_predictions": {
                "classes": 2,
                "logits": 2.129
            },
            "want_best_candidate_index":
            0,
            "want_export_outputs": {
                tu.ExportOutputKeys.CLASSIFICATION_SCORES: [2.129],
                "serving_default": [2.129],
            },
        }, {
            "testcase_name":
            "two_dict_predictions_subnetwork_fns_predict_regression",
            "ensemble_builder":
            _FakeEnsembleBuilder(
                dict_predictions=True,
                export_output_key=tu.ExportOutputKeys.REGRESSION),
            "subnetwork_builders": [
                _FakeBuilder("training"),
                _FakeBuilder("training2", random_seed=7)
            ],
            "mode":
            tf.estimator.ModeKeys.PREDICT,
            "features":
            lambda: [[1., -1., 0.]],
            "labels":
            lambda: [1],
            "want_predictions": {
                "classes": 2,
                "logits": 2.129
            },
            "want_best_candidate_index":
            0,
            "want_export_outputs": {
                tu.ExportOutputKeys.REGRESSION: 2.129,
                "serving_default": 2.129,
            },
        }, {
            "testcase_name":
            "two_dict_predictions_subnetwork_fns_predict_prediction",
            "ensemble_builder":
            _FakeEnsembleBuilder(
                dict_predictions=True,
                export_output_key=tu.ExportOutputKeys.PREDICTION),
            "subnetwork_builders": [
                _FakeBuilder("training"),
                _FakeBuilder("training2", random_seed=7)
            ],
            "mode":
            tf.estimator.ModeKeys.PREDICT,
            "features":
            lambda: [[1., -1., 0.]],
            "labels":
            lambda: [1],
            "want_predictions": {
                "classes": 2,
                "logits": 2.129
            },
            "want_best_candidate_index":
            0,
            "want_export_outputs": {
                tu.ExportOutputKeys.PREDICTION: {
                    "classes": 2,
                    "logits": 2.129
                },
                "serving_default": {
                    "classes": 2,
                    "logits": 2.129
                },
            },
        })
    def test_build_iteration(self,
                             ensemble_builder,
                             subnetwork_builders,
                             features,
                             labels,
                             want_predictions,
                             want_best_candidate_index,
                             want_eval_metric_ops=(),
                             want_is_over=False,
                             previous_ensemble_spec=lambda: None,
                             want_loss=None,
                             want_export_outputs=None,
                             mode=tf.estimator.ModeKeys.TRAIN):
        global_step = tf.train.create_global_step()
        builder = _IterationBuilder("/tmp", _FakeCandidateBuilder(),
                                    ensemble_builder)
        iteration = builder.build_iteration(
            iteration_number=0,
            subnetwork_builders=subnetwork_builders,
            features=features(),
            labels=labels(),
            mode=mode,
            previous_ensemble_spec=previous_ensemble_spec())
        with self.test_session() as sess:
            init = tf.group(tf.global_variables_initializer(),
                            tf.local_variables_initializer())
            sess.run(init)
            estimator_spec = iteration.estimator_spec
            self.assertAllClose(want_predictions,
                                sess.run(estimator_spec.predictions),
                                atol=1e-3)
            self.assertEqual(set(want_eval_metric_ops),
                             set(estimator_spec.eval_metric_ops.keys()))
            self.assertEqual(want_best_candidate_index,
                             sess.run(iteration.best_candidate_index))

            if mode == tf.estimator.ModeKeys.PREDICT:
                self.assertIsNotNone(estimator_spec.export_outputs)
                self.assertAllClose(want_export_outputs,
                                    sess.run(
                                        _export_output_tensors(
                                            estimator_spec.export_outputs)),
                                    atol=1e-3)
                self.assertEqual(iteration.estimator_spec.train_op.type,
                                 tf.no_op().type)
                self.assertIsNone(iteration.estimator_spec.loss)
                self.assertIsNotNone(want_export_outputs)
                return

            self.assertAlmostEqual(want_loss,
                                   sess.run(iteration.estimator_spec.loss),
                                   places=3)
            self.assertIsNone(iteration.estimator_spec.export_outputs)
            if mode == tf.estimator.ModeKeys.TRAIN:
                sess.run(iteration.estimator_spec.train_op)
                self.assertEqual(want_is_over,
                                 sess.run(iteration.is_over_fn()))
                self.assertEqual(1, sess.run(global_step))
                self.assertEqual(1, sess.run(iteration.step))

    @parameterized.named_parameters(
        {
            "testcase_name": "empty_subnetwork_builders",
            "ensemble_builder": _FakeEnsembleBuilder(),
            "subnetwork_builders": [],
            "want_raises": ValueError,
        }, {
            "testcase_name":
            "same_subnetwork_builder_names",
            "ensemble_builder":
            _FakeEnsembleBuilder(),
            "subnetwork_builders":
            [_FakeBuilder("same_name"),
             _FakeBuilder("same_name")],
            "want_raises":
            ValueError,
        }, {
            "testcase_name":
            "same_name_as_previous_ensemble_spec",
            "ensemble_builder":
            _FakeEnsembleBuilder(),
            "previous_ensemble_spec_fn":
            lambda: tu.dummy_ensemble_spec("same_name"),
            "subnetwork_builders": [
                _FakeBuilder("same_name"),
            ],
            "want_raises":
            ValueError,
        }, {
            "testcase_name":
            "predict_invalid",
            "ensemble_builder":
            _FakeEnsembleBuilder(
                dict_predictions=True,
                export_output_key=tu.ExportOutputKeys.INVALID),
            "subnetwork_builders": [
                _FakeBuilder("training"),
                _FakeBuilder("training2", random_seed=7)
            ],
            "mode":
            tf.estimator.ModeKeys.PREDICT,
            "want_raises":
            TypeError,
        })
    def test_build_iteration_error(self,
                                   ensemble_builder,
                                   subnetwork_builders,
                                   want_raises,
                                   previous_ensemble_spec_fn=lambda: None,
                                   mode=tf.estimator.ModeKeys.TRAIN):
        builder = _IterationBuilder("/tmp", _FakeCandidateBuilder(),
                                    ensemble_builder)
        features = [[1., -1., 0.]]
        labels = [1]
        with self.test_session():
            with self.assertRaises(want_raises):
                builder.build_iteration(
                    iteration_number=0,
                    subnetwork_builders=subnetwork_builders,
                    features=features,
                    labels=labels,
                    mode=mode,
                    previous_ensemble_spec=previous_ensemble_spec_fn())
Example #12
class IterationBuilderTest(tu.AdanetTestCase):
    @parameterized.named_parameters(
        {
            "testcase_name": "negative_max_steps",
            "max_steps": -1,
        }, {
            "testcase_name": "zero_max_steps",
            "max_steps": 0,
        })
    @test_util.run_in_graph_and_eager_modes
    def test_init_errors(self, max_steps):
        with self.assertRaises(ValueError):
            _IterationBuilder(_FakeCandidateBuilder(),
                              _FakeSubnetworkManager(),
                              _FakeEnsembleBuilder(),
                              summary_maker=_ScopedSummary,
                              ensemblers=[_FakeEnsembler()],
                              max_steps=max_steps)

    # pylint: disable=g-long-lambda
    @parameterized.named_parameters(
        {
            "testcase_name": "single_subnetwork_fn",
            "ensemble_builder": _FakeEnsembleBuilder(),
            "subnetwork_builders": [_FakeBuilder("training")],
            "features": lambda: [[1., -1., 0.]],
            "labels": lambda: [1],
            "want_loss": 1.403943,
            "want_predictions": 2.129,
            "want_best_candidate_index": 0,
        }, {
            "testcase_name":
            "single_subnetwork_fn_mock_summary",
            "ensemble_builder":
            _FakeEnsembleBuilder(),
            "subnetwork_builders": [_FakeBuilder("training")],
            "summary_maker":
            functools.partial(_TPUScopedSummary, logdir="/tmp/fakedir"),
            "features":
            lambda: [[1., -1., 0.]],
            "labels":
            lambda: [1],
            "want_loss":
            1.403943,
            "want_predictions":
            2.129,
            "want_best_candidate_index":
            0,
        }, {
            "testcase_name":
            "single_subnetwork_with_eval_metrics",
            "ensemble_builder":
            _FakeEnsembleBuilder(eval_metric_ops_fn=lambda:
                                 {"a": (tf.constant(1), tf.constant(2))}),
            "subnetwork_builders": [
                _FakeBuilder("training", ),
            ],
            "mode":
            tf.estimator.ModeKeys.EVAL,
            "features":
            lambda: [[1., -1., 0.]],
            "labels":
            lambda: [1],
            "want_loss":
            1.403943,
            "want_predictions":
            2.129,
            "want_eval_metric_ops": ["a", "iteration"],
            "want_best_candidate_index":
            0,
        }, {
            "testcase_name":
            "single_subnetwork_with_non_tensor_eval_metric_op",
            "ensemble_builder":
            _FakeEnsembleBuilder(eval_metric_ops_fn=lambda:
                                 {"a": (tf.constant(1), tf.no_op())}),
            "subnetwork_builders": [
                _FakeBuilder("training", ),
            ],
            "mode":
            tf.estimator.ModeKeys.EVAL,
            "features":
            lambda: [[1., -1., 0.]],
            "labels":
            lambda: [1],
            "want_loss":
            1.403943,
            "want_predictions":
            2.129,
            "want_eval_metric_ops": ["a", "iteration"],
            "want_best_candidate_index":
            0,
        }, {
            "testcase_name": "single_subnetwork_done_training_fn",
            "ensemble_builder": _FakeEnsembleBuilder(),
            "subnetwork_builders": [_FakeBuilder("done")],
            "features": lambda: [[1., -1., 0.]],
            "labels": lambda: [1],
            "want_loss": 1.403943,
            "want_predictions": 2.129,
            "want_best_candidate_index": 0,
        }, {
            "testcase_name": "single_dict_predictions_subnetwork_fn",
            "ensemble_builder": _FakeEnsembleBuilder(dict_predictions=True),
            "subnetwork_builders": [_FakeBuilder("training")],
            "features": lambda: [[1., -1., 0.]],
            "labels": lambda: [1],
            "want_loss": 1.403943,
            "want_predictions": {
                "classes": 2,
                "logits": 2.129
            },
            "want_best_candidate_index": 0,
        }, {
            "testcase_name": "previous_ensemble",
            "ensemble_builder": _FakeEnsembleBuilder(),
            "subnetwork_builders": [_FakeBuilder("training")],
            "features": lambda: [[1., -1., 0.]],
            "labels": lambda: [1],
            "previous_ensemble_spec": lambda: tu.dummy_ensemble_spec("old"),
            "want_loss": 1.403943,
            "want_predictions": 2.129,
            "want_best_candidate_index": 1,
        }, {
            "testcase_name":
            "previous_ensemble_is_best",
            "ensemble_builder":
            _FakeEnsembleBuilder(),
            "subnetwork_builders": [_FakeBuilder("training")],
            "features":
            lambda: [[1., -1., 0.]],
            "labels":
            lambda: [1],
            "previous_ensemble_spec":
            lambda: tu.dummy_ensemble_spec("old", random_seed=12),
            "want_loss":
            -.437,
            "want_predictions":
            .688,
            "want_best_candidate_index":
            0,
        }, {
            "testcase_name":
            "previous_ensemble_spec_and_eval_metrics",
            "ensemble_builder":
            _FakeEnsembleBuilder(eval_metric_ops_fn=lambda:
                                 {"a": (tf.constant(1), tf.constant(2))}),
            "subnetwork_builders": [_FakeBuilder("training")],
            "mode":
            tf.estimator.ModeKeys.EVAL,
            "features":
            lambda: [[1., -1., 0.]],
            "labels":
            lambda: [1],
            "previous_ensemble_spec":
            lambda: tu.dummy_ensemble_spec(
                "old",
                eval_metrics=tu.create_ensemble_metrics(
                    metric_fn=lambda: {"a":
                                       (tf.constant(1), tf.constant(2))})),
            "want_loss":
            1.403943,
            "want_predictions":
            2.129,
            "want_eval_metric_ops": ["a", "iteration"],
            "want_best_candidate_index":
            1,
        }, {
            "testcase_name":
            "two_subnetwork_fns",
            "ensemble_builder":
            _FakeEnsembleBuilder(),
            "subnetwork_builders": [
                _FakeBuilder("training"),
                _FakeBuilder("training2", random_seed=7)
            ],
            "features":
            lambda: [[1., -1., 0.]],
            "labels":
            lambda: [1],
            "want_loss":
            1.40394,
            "want_predictions":
            2.129,
            "want_best_candidate_index":
            0,
        }, {
            "testcase_name":
            "two_subnetwork_fns_other_best",
            "ensemble_builder":
            _FakeEnsembleBuilder(),
            "subnetwork_builders": [
                _FakeBuilder("training"),
                _FakeBuilder("training2", random_seed=12)
            ],
            "features":
            lambda: [[1., -1., 0.]],
            "labels":
            lambda: [1],
            "want_loss":
            -.437,
            "want_predictions":
            .688,
            "want_best_candidate_index":
            1,
        }, {
            "testcase_name":
            "two_subnetwork_one_training_fns",
            "ensemble_builder":
            _FakeEnsembleBuilder(),
            "subnetwork_builders":
            [_FakeBuilder("training"),
             _FakeBuilder("done", random_seed=7)],
            "features":
            lambda: [[1., -1., 0.]],
            "labels":
            lambda: [1],
            "want_loss":
            1.403943,
            "want_predictions":
            2.129,
            "want_best_candidate_index":
            0,
        }, {
            "testcase_name":
            "two_subnetwork_done_training_fns",
            "ensemble_builder":
            _FakeEnsembleBuilder(),
            "subnetwork_builders":
            [_FakeBuilder("done"),
             _FakeBuilder("done1", random_seed=7)],
            "features":
            lambda: [[1., -1., 0.]],
            "labels":
            lambda: [1],
            "want_loss":
            1.403943,
            "want_predictions":
            2.129,
            "want_best_candidate_index":
            0,
        }, {
            "testcase_name":
            "two_dict_predictions_subnetwork_fns",
            "ensemble_builder":
            _FakeEnsembleBuilder(dict_predictions=True),
            "subnetwork_builders": [
                _FakeBuilder("training"),
                _FakeBuilder("training2", random_seed=7)
            ],
            "features":
            lambda: [[1., -1., 0.]],
            "labels":
            lambda: [1],
            "want_loss":
            1.404,
            "want_predictions": {
                "classes": 2,
                "logits": 2.129
            },
            "want_best_candidate_index":
            0,
        }, {
            "testcase_name":
            "two_dict_predictions_subnetwork_fns_predict_classes",
            "ensemble_builder":
            _FakeEnsembleBuilder(
                dict_predictions=True,
                export_output_key=tu.ExportOutputKeys.CLASSIFICATION_CLASSES),
            "subnetwork_builders": [
                _FakeBuilder("training"),
                _FakeBuilder("training2", random_seed=7)
            ],
            "mode":
            tf.estimator.ModeKeys.PREDICT,
            "features":
            lambda: [[1., -1., 0.]],
            "labels":
            lambda: [1],
            "want_loss":
            1.404,
            "want_predictions": {
                "classes": 2,
                "logits": 2.129
            },
            "want_best_candidate_index":
            0,
            "want_export_outputs": {
                tu.ExportOutputKeys.CLASSIFICATION_CLASSES: [2.129],
                "serving_default": [2.129],
            },
        }, {
            "testcase_name":
            "two_dict_predictions_subnetwork_fns_predict_scores",
            "ensemble_builder":
            _FakeEnsembleBuilder(
                dict_predictions=True,
                export_output_key=tu.ExportOutputKeys.CLASSIFICATION_SCORES),
            "subnetwork_builders": [
                _FakeBuilder("training"),
                _FakeBuilder("training2", random_seed=7)
            ],
            "mode":
            tf.estimator.ModeKeys.PREDICT,
            "features":
            lambda: [[1., -1., 0.]],
            "labels":
            lambda: [1],
            "want_loss":
            1.404,
            "want_predictions": {
                "classes": 2,
                "logits": 2.129
            },
            "want_best_candidate_index":
            0,
            "want_export_outputs": {
                tu.ExportOutputKeys.CLASSIFICATION_SCORES: [2.129],
                "serving_default": [2.129],
            },
        }, {
            "testcase_name":
            "two_dict_predictions_subnetwork_fns_predict_regression",
            "ensemble_builder":
            _FakeEnsembleBuilder(
                dict_predictions=True,
                export_output_key=tu.ExportOutputKeys.REGRESSION),
            "subnetwork_builders": [
                _FakeBuilder("training"),
                _FakeBuilder("training2", random_seed=7)
            ],
            "mode":
            tf.estimator.ModeKeys.PREDICT,
            "features":
            lambda: [[1., -1., 0.]],
            "labels":
            lambda: [1],
            "want_predictions": {
                "classes": 2,
                "logits": 2.129
            },
            "want_best_candidate_index":
            0,
            "want_export_outputs": {
                tu.ExportOutputKeys.REGRESSION: 2.129,
                "serving_default": 2.129,
            },
        }, {
            "testcase_name":
            "two_dict_predictions_subnetwork_fns_predict_prediction",
            "ensemble_builder":
            _FakeEnsembleBuilder(
                dict_predictions=True,
                export_output_key=tu.ExportOutputKeys.PREDICTION),
            "subnetwork_builders": [
                _FakeBuilder("training"),
                _FakeBuilder("training2", random_seed=7)
            ],
            "mode":
            tf.estimator.ModeKeys.PREDICT,
            "features":
            lambda: [[1., -1., 0.]],
            "labels":
            lambda: [1],
            "want_predictions": {
                "classes": 2,
                "logits": 2.129
            },
            "want_best_candidate_index":
            0,
            "want_export_outputs": {
                tu.ExportOutputKeys.PREDICTION: {
                    "classes": 2,
                    "logits": 2.129
                },
                "serving_default": {
                    "classes": 2,
                    "logits": 2.129
                },
            },
        }, {
            "testcase_name":
            "chief_session_run_hook",
            "ensemble_builder":
            _FakeEnsembleBuilder(),
            "subnetwork_builders":
            [_FakeBuilder("training", chief_hook=tu.ModifierSessionRunHook())],
            "features":
            lambda: [[1., -1., 0.]],
            "labels":
            lambda: [1],
            "want_loss":
            1.403943,
            "want_predictions":
            2.129,
            "want_best_candidate_index":
            0,
            "want_chief_hooks":
            True,
        })
    @test_util.run_in_graph_and_eager_modes
    def test_build_iteration(self,
                             ensemble_builder,
                             subnetwork_builders,
                             features,
                             labels,
                             want_predictions,
                             want_best_candidate_index,
                             want_eval_metric_ops=(),
                             previous_ensemble_spec=lambda: None,
                             want_loss=None,
                             want_export_outputs=None,
                             mode=tf.estimator.ModeKeys.TRAIN,
                             summary_maker=_ScopedSummary,
                             want_chief_hooks=False):
        with context.graph_mode():
            tf_compat.v1.train.create_global_step()
            builder = _IterationBuilder(_FakeCandidateBuilder(),
                                        _FakeSubnetworkManager(),
                                        ensemble_builder,
                                        summary_maker=summary_maker,
                                        ensemblers=[_FakeEnsembler()],
                                        max_steps=1)
            iteration = builder.build_iteration(
                base_global_step=0,
                iteration_number=0,
                ensemble_candidates=[
                    EnsembleCandidate(b.name, [b], None)
                    for b in subnetwork_builders
                ],
                subnetwork_builders=subnetwork_builders,
                features=features(),
                labels=labels(),
                mode=mode,
                config=tf.estimator.RunConfig(
                    model_dir=self.test_subdirectory),
                previous_ensemble_spec=previous_ensemble_spec())
            init = tf.group(tf_compat.v1.global_variables_initializer(),
                            tf_compat.v1.local_variables_initializer())
            self.evaluate(init)
            estimator_spec = iteration.estimator_spec
            if want_chief_hooks:
                self.assertNotEmpty(
                    iteration.estimator_spec.training_chief_hooks)
            self.assertAllClose(want_predictions,
                                self.evaluate(estimator_spec.predictions),
                                atol=1e-3)

            # A default architecture metric is always included, even if we don't
            # specify one.
            eval_metric_ops = estimator_spec.eval_metric_ops
            if "architecture/adanet/ensembles" in eval_metric_ops:
                del eval_metric_ops["architecture/adanet/ensembles"]
            self.assertEqual(set(want_eval_metric_ops),
                             set(eval_metric_ops.keys()))

            self.assertEqual(want_best_candidate_index,
                             self.evaluate(iteration.best_candidate_index))

            if mode == tf.estimator.ModeKeys.PREDICT:
                self.assertIsNotNone(estimator_spec.export_outputs)
                self.assertAllClose(want_export_outputs,
                                    self.evaluate(
                                        _export_output_tensors(
                                            estimator_spec.export_outputs)),
                                    atol=1e-3)
                self.assertIsNone(iteration.estimator_spec.train_op)
                self.assertIsNone(iteration.estimator_spec.loss)
                self.assertIsNotNone(want_export_outputs)
                return

            self.assertAlmostEqual(want_loss,
                                   self.evaluate(
                                       iteration.estimator_spec.loss),
                                   places=3)
            self.assertIsNone(iteration.estimator_spec.export_outputs)
            if mode == tf.estimator.ModeKeys.TRAIN:
                self.evaluate(iteration.estimator_spec.train_op)

    @parameterized.named_parameters(
        {
            "testcase_name": "empty_subnetwork_builders",
            "ensemble_builder": _FakeEnsembleBuilder(),
            "subnetwork_builders": [],
            "want_raises": ValueError,
        }, {
            "testcase_name":
            "same_subnetwork_builder_names",
            "ensemble_builder":
            _FakeEnsembleBuilder(),
            "subnetwork_builders":
            [_FakeBuilder("same_name"),
             _FakeBuilder("same_name")],
            "want_raises":
            ValueError,
        }, {
            "testcase_name":
            "same_name_as_previous_ensemble_spec",
            "ensemble_builder":
            _FakeEnsembleBuilder(),
            "previous_ensemble_spec_fn":
            lambda: tu.dummy_ensemble_spec("same_name"),
            "subnetwork_builders": [
                _FakeBuilder("same_name"),
            ],
            "want_raises":
            ValueError,
        }, {
            "testcase_name":
            "predict_invalid",
            "ensemble_builder":
            _FakeEnsembleBuilder(
                dict_predictions=True,
                export_output_key=tu.ExportOutputKeys.INVALID),
            "subnetwork_builders": [
                _FakeBuilder("training"),
                _FakeBuilder("training2", random_seed=7)
            ],
            "mode":
            tf.estimator.ModeKeys.PREDICT,
            "want_raises":
            TypeError,
        })
    @test_util.run_in_graph_and_eager_modes
    def test_build_iteration_error(self,
                                   ensemble_builder,
                                   subnetwork_builders,
                                   want_raises,
                                   previous_ensemble_spec_fn=lambda: None,
                                   mode=tf.estimator.ModeKeys.TRAIN,
                                   summary_maker=_ScopedSummary):
        with context.graph_mode():
            builder = _IterationBuilder(_FakeCandidateBuilder(),
                                        _FakeSubnetworkManager(),
                                        ensemble_builder,
                                        summary_maker=summary_maker,
                                        ensemblers=[_FakeEnsembler()],
                                        max_steps=100)
            features = [[1., -1., 0.]]
            labels = [1]
            with self.assertRaises(want_raises):
                builder.build_iteration(
                    base_global_step=0,
                    iteration_number=0,
                    ensemble_candidates=[
                        EnsembleCandidate("test", subnetwork_builders, None)
                    ],
                    subnetwork_builders=subnetwork_builders,
                    features=features,
                    labels=labels,
                    mode=mode,
                    config=tf.estimator.RunConfig(
                        model_dir=self.test_subdirectory),
                    previous_ensemble_spec=previous_ensemble_spec_fn())
Example #13
class EnsembleBuilderTest(parameterized.TestCase, tf.test.TestCase):
    @parameterized.named_parameters(
        {
            "testcase_name": "no_previous_ensemble",
            "want_logits": [[.016], [.117]],
            "want_loss": 1.338,
            "want_adanet_loss": 1.338,
            "want_complexity_regularization": 0.,
            "want_mixture_weight_vars": 1,
        }, {
            "testcase_name": "no_previous_ensemble_prune_all",
            "want_logits": [[.016], [.117]],
            "want_loss": 1.338,
            "want_adanet_loss": 1.338,
            "want_complexity_regularization": 0.,
            "want_mixture_weight_vars": 1,
            "subnetwork_builder_class": _BuilderPrunerAll
        }, {
            "testcase_name": "no_previous_ensemble_prune_leave_one",
            "want_logits": [[.016], [.117]],
            "want_loss": 1.338,
            "want_adanet_loss": 1.338,
            "want_complexity_regularization": 0.,
            "want_mixture_weight_vars": 1,
            "subnetwork_builder_class": _BuilderPrunerLeaveOne
        }, {
            "testcase_name": "default_mixture_weight_initializer_scalar",
            "mixture_weight_initializer": None,
            "mixture_weight_type": MixtureWeightType.SCALAR,
            "use_logits_last_layer": True,
            "want_logits": [[.580], [.914]],
            "want_loss": 1.362,
            "want_adanet_loss": 1.362,
            "want_complexity_regularization": 0.,
            "want_mixture_weight_vars": 1,
        }, {
            "testcase_name": "default_mixture_weight_initializer_vector",
            "mixture_weight_initializer": None,
            "mixture_weight_type": MixtureWeightType.VECTOR,
            "use_logits_last_layer": True,
            "want_logits": [[.580], [.914]],
            "want_loss": 1.362,
            "want_adanet_loss": 1.362,
            "want_complexity_regularization": 0.,
            "want_mixture_weight_vars": 1,
        }, {
            "testcase_name": "default_mixture_weight_initializer_matrix",
            "mixture_weight_initializer": None,
            "mixture_weight_type": MixtureWeightType.MATRIX,
            "want_logits": [[.016], [.117]],
            "want_loss": 1.338,
            "want_adanet_loss": 1.338,
            "want_complexity_regularization": 0.,
            "want_mixture_weight_vars": 1,
        }, {
            "testcase_name":
            "default_mixture_weight_initializer_matrix_on_logits",
            "mixture_weight_initializer": None,
            "mixture_weight_type": MixtureWeightType.MATRIX,
            "use_logits_last_layer": True,
            "want_logits": [[.030], [.047]],
            "want_loss": 1.378,
            "want_adanet_loss": 1.378,
            "want_complexity_regularization": 0.,
            "want_mixture_weight_vars": 1,
        }, {
            "testcase_name": "no_previous_ensemble_use_bias",
            "use_bias": True,
            "want_logits": [[0.013], [0.113]],
            "want_loss": 1.338,
            "want_adanet_loss": 1.338,
            "want_complexity_regularization": 0.,
            "want_mixture_weight_vars": 2,
        }, {
            "testcase_name": "no_previous_ensemble_predict_mode",
            "mode": tf.estimator.ModeKeys.PREDICT,
            "want_logits": [[0.], [0.]],
            "want_complexity_regularization": 0.,
        }, {
            "testcase_name": "no_previous_ensemble_lambda",
            "adanet_lambda": .01,
            "want_logits": [[.014], [.110]],
            "want_loss": 1.340,
            "want_adanet_loss": 1.343,
            "want_complexity_regularization": .003,
            "want_mixture_weight_vars": 1,
        }, {
            "testcase_name": "no_previous_ensemble_beta",
            "adanet_beta": .1,
            "want_logits": [[.006], [.082]],
            "want_loss": 1.349,
            "want_adanet_loss": 1.360,
            "want_complexity_regularization": .012,
            "want_mixture_weight_vars": 1,
        }, {
            "testcase_name": "no_previous_ensemble_lambda_and_beta",
            "adanet_lambda": .01,
            "adanet_beta": .1,
            "want_logits": [[.004], [.076]],
            "want_loss": 1.351,
            "want_adanet_loss": 1.364,
            "want_complexity_regularization": .013,
            "want_mixture_weight_vars": 1,
        }, {
            "testcase_name":
            "previous_ensemble",
            "ensemble_spec_fn":
            lambda: tu.dummy_ensemble_spec("test", random_seed=1),
            "adanet_lambda":
            .01,
            "adanet_beta":
            .1,
            "want_logits": [[.089], [.159]],
            "want_loss":
            1.355,
            "want_adanet_loss":
            1.398,
            "want_complexity_regularization":
            .043,
            "want_mixture_weight_vars":
            2,
        }, {
            "testcase_name":
            "previous_ensemble_prune_all",
            "ensemble_spec_fn":
            lambda: tu.dummy_ensemble_spec("test", random_seed=1),
            "adanet_lambda":
            .01,
            "adanet_beta":
            .1,
            "want_logits": [[0.003569], [0.07557]],
            "want_loss":
            1.3510095,
            "want_adanet_loss":
            1.3644928,
            "want_complexity_regularization":
            0.013483323,
            "want_mixture_weight_vars":
            1,
            "subnetwork_builder_class":
            _BuilderPrunerAll
        }, {
            "testcase_name":
            "previous_ensemble_leave_one",
            "ensemble_spec_fn":
            lambda: tu.dummy_ensemble_spec("test", random_seed=1),
            "adanet_lambda":
            .01,
            "adanet_beta":
            .1,
            "want_logits": [[.089], [.159]],
            "want_loss":
            1.355,
            "want_adanet_loss":
            1.398,
            "want_complexity_regularization":
            .043,
            "want_mixture_weight_vars":
            2,
            "subnetwork_builder_class":
            _BuilderPrunerLeaveOne
        }, {
            "testcase_name":
            "previous_ensemble_use_bias",
            "use_bias":
            True,
            "ensemble_spec_fn":
            lambda: tu.dummy_ensemble_spec("test", random_seed=1),
            "adanet_lambda":
            .01,
            "adanet_beta":
            .1,
            "want_logits": [[.075], [.146]],
            "want_loss":
            1.354,
            "want_adanet_loss":
            1.397,
            "want_complexity_regularization":
            .043,
            "want_mixture_weight_vars":
            3,
        }, {
            "testcase_name":
            "previous_ensemble_no_warm_start",
            "ensemble_spec_fn":
            lambda: tu.dummy_ensemble_spec("test", random_seed=1),
            "warm_start_mixture_weights":
            False,
            "adanet_lambda":
            .01,
            "adanet_beta":
            .1,
            "want_logits": [[.007], [.079]],
            "want_loss":
            1.351,
            "want_adanet_loss":
            1.367,
            "want_complexity_regularization":
            .016,
            "want_mixture_weight_vars":
            2,
        })
    def test_append_new_subnetwork(
            self,
            want_logits,
            want_complexity_regularization,
            want_loss=None,
            want_adanet_loss=None,
            want_mixture_weight_vars=None,
            adanet_lambda=0.,
            adanet_beta=0.,
            ensemble_spec_fn=lambda: None,
            use_bias=False,
            use_logits_last_layer=False,
            mixture_weight_type=MixtureWeightType.MATRIX,
            mixture_weight_initializer=tf.zeros_initializer(),
            warm_start_mixture_weights=True,
            subnetwork_builder_class=_Builder,
            mode=tf.estimator.ModeKeys.TRAIN):
        seed = 64
        builder = _EnsembleBuilder(
            head=tf.contrib.estimator.binary_classification_head(
                loss_reduction=tf.losses.Reduction.SUM),
            mixture_weight_type=mixture_weight_type,
            mixture_weight_initializer=mixture_weight_initializer,
            warm_start_mixture_weights=warm_start_mixture_weights,
            adanet_lambda=adanet_lambda,
            adanet_beta=adanet_beta,
            use_bias=use_bias)

        features = {"x": tf.constant([[1.], [2.]])}
        labels = tf.constant([0, 1])

        def _subnetwork_train_op_fn(loss, var_list):
            self.assertEqual(2, len(var_list))
            optimizer = tf.train.GradientDescentOptimizer(learning_rate=.1)
            return optimizer.minimize(loss, var_list=var_list)

        def _mixture_weights_train_op_fn(loss, var_list):
            self.assertEqual(want_mixture_weight_vars, len(var_list))
            optimizer = tf.train.GradientDescentOptimizer(learning_rate=.1)
            return optimizer.minimize(loss, var_list=var_list)

        ensemble_spec = builder.append_new_subnetwork(
            ensemble_spec=ensemble_spec_fn(),
            subnetwork_builder=subnetwork_builder_class(
                _subnetwork_train_op_fn, _mixture_weights_train_op_fn,
                use_logits_last_layer, seed),
            summary=tf.summary,
            features=features,
            iteration_step=tf.train.get_or_create_global_step(),
            labels=labels,
            mode=mode)

        with self.test_session() as sess:
            sess.run(tf.global_variables_initializer())

            if mode == tf.estimator.ModeKeys.PREDICT:
                self.assertAllClose(want_logits,
                                    sess.run(ensemble_spec.ensemble.logits),
                                    atol=1e-3)
                self.assertIsNone(ensemble_spec.loss)
                self.assertIsNone(ensemble_spec.adanet_loss)
                self.assertIsNone(ensemble_spec.train_op)
                self.assertIsNotNone(ensemble_spec.export_outputs)
                return

            # Verify that train_op works: the loss before training should be greater
            # than the loss after a few train ops.
            loss = sess.run(ensemble_spec.loss)
            for _ in range(3):
                sess.run(ensemble_spec.train_op)
            self.assertGreater(loss, sess.run(ensemble_spec.loss))

            self.assertAllClose(want_logits,
                                sess.run(ensemble_spec.ensemble.logits),
                                atol=1e-3)

            # Bias should learn a non-zero value when used.
            if use_bias:
                self.assertNotEqual(0., sess.run(ensemble_spec.ensemble.bias))
            else:
                self.assertAlmostEqual(0.,
                                       sess.run(ensemble_spec.ensemble.bias))

            self.assertAlmostEqual(
                want_complexity_regularization,
                sess.run(ensemble_spec.complexity_regularization),
                places=3)
            self.assertAlmostEqual(want_loss,
                                   sess.run(ensemble_spec.loss),
                                   places=3)
            self.assertAlmostEqual(want_adanet_loss,
                                   sess.run(ensemble_spec.adanet_loss),
                                   places=3)
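Note that `tf.contrib.estimator.binary_classification_head` is a TensorFlow 1.x symbol; `tf.contrib` was removed in TensorFlow 2.x. On newer Estimator releases a comparable head can likely be constructed as below, though the exact loss-reduction enum to pass is an assumption worth verifying.

# Rough TF 2.x equivalent of the contrib head used above (sketch, not adanet's code).
head = tf.estimator.BinaryClassHead(loss_reduction=tf.keras.losses.Reduction.SUM)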