Example #1
  def _make_metrics(self,
                    metric_fn,
                    mode=tf.estimator.ModeKeys.EVAL,
                    multi_head=False,
                    sess=None):
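    """Builds eval metric ops for a subnetwork and its ensemble, runs them,
    and returns two dicts (subnetwork, ensemble) of final metric values."""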

    with context.graph_mode():
      if multi_head:
        head = multi_head_lib.MultiHead(heads=[
            binary_class_head.BinaryClassHead(
                name="head1", loss_reduction=tf_compat.SUM),
            binary_class_head.BinaryClassHead(
                name="head2", loss_reduction=tf_compat.SUM)
        ])
        labels = {"head1": tf.constant([0, 1]), "head2": tf.constant([0, 1])}
      else:
        head = binary_class_head.BinaryClassHead(loss_reduction=tf_compat.SUM)
        labels = tf.constant([0, 1])
      features = {"x": tf.constant([[1.], [2.]])}
      builder = _EnsembleBuilder(head, metric_fn=metric_fn)
      subnetwork_manager = _SubnetworkManager(head, metric_fn=metric_fn)
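      # The two no-op lambdas below stand in for the subnetwork and
      # mixture-weights train-op callbacks (see the fuller examples below).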
      subnetwork_builder = _Builder(
          lambda unused0, unused1: tf.no_op(),
          lambda unused0, unused1: tf.no_op(),
          use_logits_last_layer=True)

      subnetwork_spec = subnetwork_manager.build_subnetwork_spec(
          name="test",
          subnetwork_builder=subnetwork_builder,
          summary=_FakeSummary(),
          features=features,
          mode=mode,
          labels=labels)
      ensemble_spec = builder.build_ensemble_spec(
          name="test",
          candidate=EnsembleCandidate("foo", [subnetwork_builder], None),
          ensembler=ComplexityRegularizedEnsembler(
              mixture_weight_type=MixtureWeightType.SCALAR),
          subnetwork_specs=[subnetwork_spec],
          summary=_FakeSummary(),
          features=features,
          iteration_number=0,
          labels=labels,
          mode=mode)
      subnetwork_metric_ops = call_eval_metrics(subnetwork_spec.eval_metrics)
      ensemble_metric_ops = call_eval_metrics(ensemble_spec.eval_metrics)
      evaluate = self.evaluate
      if sess is not None:
        evaluate = sess.run
      evaluate((tf_compat.v1.global_variables_initializer(),
                tf_compat.v1.local_variables_initializer()))
      evaluate((subnetwork_metric_ops, ensemble_metric_ops))
      # Return the idempotent tensor part of the (tensor, op) metrics tuple.
      return {
          k: evaluate(subnetwork_metric_ops[k][0])
          for k in subnetwork_metric_ops
      }, {k: evaluate(ensemble_metric_ops[k][0]) for k in ensemble_metric_ops}
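
A minimal sketch of a metric_fn that _make_metrics could accept. This is an
assumption based on the AdaNet metric_fn contract (arguments drawn from
"labels", "predictions", and "features"; returns a dict of
(value_tensor, update_op) pairs); the names below are hypothetical.

def _mean_logits_metric_fn(predictions):
  # "logits" is a standard prediction key produced by BinaryClassHead.
  return {"mean_logits": tf_compat.v1.metrics.mean(predictions["logits"])}

# Hypothetical call site:
# subnetwork_metrics, ensemble_metrics = self._make_metrics(_mean_logits_metric_fn)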
Example #2
def _make_metrics(sess,
                  metric_fn,
                  mode=tf.estimator.ModeKeys.EVAL,
                  multi_head=False):
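    """TF1 tf.contrib variant of the metrics helper; runs all ops in `sess`."""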

    if multi_head:
        head = tf.contrib.estimator.multi_head(heads=[
            tf.contrib.estimator.binary_classification_head(
                name="head1", loss_reduction=tf.losses.Reduction.SUM),
            tf.contrib.estimator.binary_classification_head(
                name="head2", loss_reduction=tf.losses.Reduction.SUM)
        ])
        labels = {"head1": tf.constant([0, 1]), "head2": tf.constant([0, 1])}
    else:
        head = tf.contrib.estimator.binary_classification_head(
            loss_reduction=tf.losses.Reduction.SUM)
        labels = tf.constant([0, 1])
    features = {"x": tf.constant([[1.], [2.]])}
    builder = _EnsembleBuilder(head, metric_fn=metric_fn)
    subnetwork_manager = _SubnetworkManager(head, metric_fn)
    subnetwork_builder = _Builder(lambda unused0, unused1: tf.no_op(),
                                  lambda unused0, unused1: tf.no_op(),
                                  use_logits_last_layer=True)

    subnetwork_spec = subnetwork_manager.build_subnetwork_spec(
        name="test",
        subnetwork_builder=subnetwork_builder,
        iteration_step=1,
        summary=_FakeSummary(),
        features=features,
        mode=mode,
        labels=labels)
    ensemble_spec = builder.build_ensemble_spec(
        name="test",
        candidate=EnsembleCandidate("foo", [subnetwork_builder], None),
        ensembler=ComplexityRegularizedEnsembler(
            mixture_weight_type=MixtureWeightType.SCALAR),
        subnetwork_specs=[subnetwork_spec],
        summary=_FakeSummary(),
        features=features,
        iteration_number=0,
        iteration_step=1,
        labels=labels,
        mode=mode)
    fn, kwargs = subnetwork_spec.eval_metrics
    subnetwork_metric_ops = fn(**kwargs)
    fn, kwargs = ensemble_spec.eval_metrics
    ensemble_metric_ops = fn(**kwargs)
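    # The four lines above manually expand what Example #1 does with
    # call_eval_metrics(spec.eval_metrics).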
    sess.run(
        (tf.global_variables_initializer(), tf.local_variables_initializer()))
    sess.run((subnetwork_metric_ops, ensemble_metric_ops))
    # Return the idempotent tensor part of the (tensor, op) metrics tuple.
    return {
        k: sess.run(subnetwork_metric_ops[k][0])
        for k in subnetwork_metric_ops
    }, {k: sess.run(ensemble_metric_ops[k][0])
        for k in ensemble_metric_ops}
Example #3
    def test_build_ensemble_spec(
            self,
            want_logits,
            want_loss=None,
            want_adanet_loss=None,
            want_ensemble_trainable_vars=None,
            adanet_lambda=0.,
            adanet_beta=0.,
            ensemble_spec_fn=lambda: None,
            use_bias=False,
            use_logits_last_layer=False,
            mixture_weight_type=MixtureWeightType.MATRIX,
            mixture_weight_initializer=tf_compat.v1.zeros_initializer(),
            warm_start_mixture_weights=True,
            subnetwork_builder_class=_Builder,
            mode=tf.estimator.ModeKeys.TRAIN,
            multi_head=False,
            want_subnetwork_trainable_vars=2):
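        """Builds a subnetwork spec and an ensemble spec, then checks trainable
        variables, global-step scoping, summaries, losses, and logits."""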
        seed = 64

        if multi_head:
            head = multi_head_lib.MultiHead(heads=[
                binary_class_head.BinaryClassHead(
                    name="head1", loss_reduction=tf_compat.SUM),
                binary_class_head.BinaryClassHead(name="head2",
                                                  loss_reduction=tf_compat.SUM)
            ])
        else:
            head = binary_class_head.BinaryClassHead(
                loss_reduction=tf_compat.SUM)
        builder = _EnsembleBuilder(head=head)

        def _subnetwork_train_op_fn(loss, var_list):
            self.assertLen(var_list, want_subnetwork_trainable_vars)
            self.assertEqual(
                var_list,
                tf_compat.v1.get_collection(
                    tf_compat.v1.GraphKeys.TRAINABLE_VARIABLES))
            # Subnetworks get iteration steps instead of global steps.
            self.assertEqual("subnetwork_test/iteration_step",
                             tf_compat.v1.train.get_global_step().op.name)

            # Subnetworks get scoped summaries.
            self.assertEqual("fake_scalar",
                             tf_compat.v1.summary.scalar("scalar", 1.))
            self.assertEqual("fake_image",
                             tf_compat.v1.summary.image("image", 1.))
            self.assertEqual("fake_histogram",
                             tf_compat.v1.summary.histogram("histogram", 1.))
            self.assertEqual("fake_audio",
                             tf_compat.v1.summary.audio("audio", 1., 1.))
            optimizer = tf_compat.v1.train.GradientDescentOptimizer(
                learning_rate=.1)
            return optimizer.minimize(loss, var_list=var_list)

        def _mixture_weights_train_op_fn(loss, var_list):
            self.assertLen(var_list, want_ensemble_trainable_vars)
            self.assertEqual(
                var_list,
                tf_compat.v1.get_collection(
                    tf_compat.v1.GraphKeys.TRAINABLE_VARIABLES))
            # The ensemble gets iteration steps instead of global steps.
            self.assertEqual("ensemble_test/iteration_step",
                             tf_compat.v1.train.get_global_step().op.name)

            # The ensemble gets scoped summaries.
            self.assertEqual("fake_scalar",
                             tf_compat.v1.summary.scalar("scalar", 1.))
            self.assertEqual("fake_image",
                             tf_compat.v1.summary.image("image", 1.))
            self.assertEqual("fake_histogram",
                             tf_compat.v1.summary.histogram("histogram", 1.))
            self.assertEqual("fake_audio",
                             tf_compat.v1.summary.audio("audio", 1., 1.))
            optimizer = tf_compat.v1.train.GradientDescentOptimizer(
                learning_rate=.1)
            return optimizer.minimize(loss, var_list=var_list)

        previous_ensemble = None
        previous_ensemble_spec = ensemble_spec_fn()
        if previous_ensemble_spec:
            previous_ensemble = previous_ensemble_spec.ensemble

        subnetwork_manager = _SubnetworkManager(head)
        subnetwork_builder = subnetwork_builder_class(
            _subnetwork_train_op_fn,
            _mixture_weights_train_op_fn,
            use_logits_last_layer,
            seed,
            multi_head=multi_head)

        with tf.Graph().as_default() as g:
            # A trainable variable to later verify that creating models does not
            # affect the trainable variables collection.
            _ = tf_compat.v1.get_variable("some_var", shape=0, trainable=True)

            features = {"x": tf.constant([[1.], [2.]])}
            if multi_head:
                labels = {
                    "head1": tf.constant([0, 1]),
                    "head2": tf.constant([0, 1])
                }
            else:
                labels = tf.constant([0, 1])

            subnetwork_spec = subnetwork_manager.build_subnetwork_spec(
                name="test",
                subnetwork_builder=subnetwork_builder,
                iteration_step=tf_compat.v1.train.get_or_create_global_step(),
                summary=_FakeSummary(),
                features=features,
                mode=mode,
                labels=labels,
                previous_ensemble=previous_ensemble)
            ensemble_spec = builder.build_ensemble_spec(
                # Note: when previous_ensemble_spec is not None and
                # warm_start_mixture_weights is True, the bias and mixture
                # weights must already be saved to the checkpoint_dir.
                name="test",
                previous_ensemble_spec=previous_ensemble_spec,
                candidate=EnsembleCandidate("foo", [subnetwork_builder], None),
                ensembler=ComplexityRegularizedEnsembler(
                    mixture_weight_type=mixture_weight_type,
                    mixture_weight_initializer=mixture_weight_initializer,
                    warm_start_mixture_weights=warm_start_mixture_weights,
                    model_dir=self.test_subdirectory,
                    adanet_lambda=adanet_lambda,
                    adanet_beta=adanet_beta,
                    use_bias=use_bias),
                subnetwork_specs=[subnetwork_spec],
                summary=_FakeSummary(),
                features=features,
                iteration_number=1,
                iteration_step=tf_compat.v1.train.get_or_create_global_step(),
                labels=labels,
                mode=mode)

            with tf_compat.v1.Session(graph=g).as_default() as sess:
                sess.run(tf_compat.v1.global_variables_initializer())

                # Equals the number of subnetwork and ensemble trainable variables,
                # plus the one 'some_var' created earlier.
                self.assertLen(
                    tf_compat.v1.trainable_variables(),
                    want_subnetwork_trainable_vars +
                    want_ensemble_trainable_vars + 1)

                # Get the real global step outside a subnetwork's context.
                self.assertEqual("global_step",
                                 tf_compat.v1.train.get_global_step().op.name)
                self.assertEqual("global_step",
                                 train.get_global_step().op.name)
                self.assertEqual("global_step",
                                 tf_v1.train.get_global_step().op.name)
                self.assertEqual("global_step",
                                 training_util.get_global_step().op.name)
                self.assertEqual(
                    "global_step",
                    tf_compat.v1.train.get_or_create_global_step().op.name)
                self.assertEqual("global_step",
                                 train.get_or_create_global_step().op.name)
                self.assertEqual(
                    "global_step",
                    tf_v1.train.get_or_create_global_step().op.name)
                self.assertEqual(
                    "global_step",
                    training_util.get_or_create_global_step().op.name)

                # Get global tf.summary outside a subnetwork's context.
                self.assertNotEqual("fake_scalar",
                                    tf_compat.v1.summary.scalar("scalar", 1.))
                self.assertNotEqual("fake_image",
                                    tf_compat.v1.summary.image("image", 1.))
                self.assertNotEqual(
                    "fake_histogram",
                    tf_compat.v1.summary.histogram("histogram", 1.))
                self.assertNotEqual(
                    "fake_audio", tf_compat.v1.summary.audio("audio", 1., 1.))

                if mode == tf.estimator.ModeKeys.PREDICT:
                    self.assertAllClose(want_logits,
                                        sess.run(
                                            ensemble_spec.ensemble.logits),
                                        atol=1e-3)
                    self.assertIsNone(ensemble_spec.loss)
                    self.assertIsNone(ensemble_spec.adanet_loss)
                    self.assertIsNone(ensemble_spec.train_op)
                    self.assertIsNotNone(ensemble_spec.export_outputs)
                    return

                # Verify that train_op works: the loss before training should
                # be greater than the loss after a few train steps.
                loss = sess.run(ensemble_spec.loss)
                train_op = tf.group(subnetwork_spec.train_op.train_op,
                                    ensemble_spec.train_op.train_op)
                for _ in range(3):
                    sess.run(train_op)
                self.assertGreater(loss, sess.run(ensemble_spec.loss))

                self.assertAllClose(want_logits,
                                    sess.run(ensemble_spec.ensemble.logits),
                                    atol=1e-3)

                # Bias should learn a non-zero value when used.
                bias = sess.run(ensemble_spec.ensemble.bias)
                if isinstance(bias, dict):
                    bias = sum(abs(b) for b in bias.values())
                if use_bias:
                    self.assertNotEqual(0., bias)
                else:
                    self.assertAlmostEqual(0., bias)

                self.assertAlmostEqual(want_loss,
                                       sess.run(ensemble_spec.loss),
                                       places=3)
                self.assertAlmostEqual(want_adanet_loss,
                                       sess.run(ensemble_spec.adanet_loss),
                                       places=3)
Example #4
    def test_init_errors(self, max_steps=None):
        head = binary_class_head.BinaryClassHead(loss_reduction=tf_compat.SUM)
        with self.test_session():
            with self.assertRaises(ValueError):
                _SubnetworkManager(head, max_steps=max_steps)
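
    # The parameterization of test_init_errors is not shown in this snippet.
    # A hypothetical decorator (the max_steps values are assumptions, not
    # from the source):
    # @parameterized.named_parameters(
    #     {"testcase_name": "negative_max_steps", "max_steps": -1})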
    def test_build_ensemble_spec(
            self,
            want_logits,
            want_loss=None,
            want_adanet_loss=None,
            want_ensemble_trainable_vars=None,
            adanet_lambda=0.,
            adanet_beta=0.,
            ensemble_spec_fn=lambda: None,
            use_bias=False,
            use_logits_last_layer=False,
            mixture_weight_type=MixtureWeightType.MATRIX,
            mixture_weight_initializer=tf_compat.v1.zeros_initializer(),
            warm_start_mixture_weights=True,
            subnetwork_builder_class=_Builder,
            mode=tf.estimator.ModeKeys.TRAIN,
            multi_head=False,
            want_subnetwork_trainable_vars=2,
            ensembler_class=ComplexityRegularizedEnsembler,
            my_ensemble_index=None,
            want_replay_indices=None,
            want_predictions=None,
            export_subnetworks=False,
            previous_ensemble_spec=None,
            previous_iteration_checkpoint=None):
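        """Like the variant above, but with a pluggable ensembler class and
        optional export of per-subnetwork logits and last layers."""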
        seed = 64

        if multi_head:
            head = multi_head_lib.MultiHead(heads=[
                binary_class_head.BinaryClassHead(
                    name="head1", loss_reduction=tf_compat.SUM),
                binary_class_head.BinaryClassHead(name="head2",
                                                  loss_reduction=tf_compat.SUM)
            ])
        else:
            head = binary_class_head.BinaryClassHead(
                loss_reduction=tf_compat.SUM)
        builder = _EnsembleBuilder(
            head=head,
            export_subnetwork_logits=export_subnetworks,
            export_subnetwork_last_layer=export_subnetworks)

        def _subnetwork_train_op_fn(loss, var_list):
            self.assertLen(var_list, want_subnetwork_trainable_vars)
            self.assertEqual(
                var_list,
                tf_compat.v1.get_collection(
                    tf_compat.v1.GraphKeys.TRAINABLE_VARIABLES))
            # Subnetworks get iteration steps instead of global steps.
            self.assertEqual("subnetwork_test/iteration_step",
                             tf_compat.v1.train.get_global_step().op.name)

            # Subnetworks get scoped summaries.
            self.assertEqual("fake_scalar",
                             tf_compat.v1.summary.scalar("scalar", 1.))
            self.assertEqual("fake_image",
                             tf_compat.v1.summary.image("image", 1.))
            self.assertEqual("fake_histogram",
                             tf_compat.v1.summary.histogram("histogram", 1.))
            self.assertEqual("fake_audio",
                             tf_compat.v1.summary.audio("audio", 1., 1.))
            optimizer = tf_compat.v1.train.GradientDescentOptimizer(
                learning_rate=.1)
            return optimizer.minimize(loss, var_list=var_list)

        def _mixture_weights_train_op_fn(loss, var_list):
            self.assertLen(var_list, want_ensemble_trainable_vars)
            self.assertEqual(
                var_list,
                tf_compat.v1.get_collection(
                    tf_compat.v1.GraphKeys.TRAINABLE_VARIABLES))
            # The ensemble gets iteration steps instead of global steps.
            self.assertEqual("ensemble_test/iteration_step",
                             tf_compat.v1.train.get_global_step().op.name)

            # The ensemble gets scoped summaries.
            self.assertEqual("fake_scalar",
                             tf_compat.v1.summary.scalar("scalar", 1.))
            self.assertEqual("fake_image",
                             tf_compat.v1.summary.image("image", 1.))
            self.assertEqual("fake_histogram",
                             tf_compat.v1.summary.histogram("histogram", 1.))
            self.assertEqual("fake_audio",
                             tf_compat.v1.summary.audio("audio", 1., 1.))
            if not var_list:
                return tf.no_op()
            optimizer = tf_compat.v1.train.GradientDescentOptimizer(
                learning_rate=.1)
            return optimizer.minimize(loss, var_list=var_list)

        previous_ensemble = None
        previous_ensemble_spec = ensemble_spec_fn()
        if previous_ensemble_spec:
            previous_ensemble = previous_ensemble_spec.ensemble

        subnetwork_manager = _SubnetworkManager(head)
        subnetwork_builder = subnetwork_builder_class(
            _subnetwork_train_op_fn,
            _mixture_weights_train_op_fn,
            use_logits_last_layer,
            seed,
            multi_head=multi_head)

        with tf.Graph().as_default() as g:
            tf_compat.v1.train.get_or_create_global_step()
            # A trainable variable to later verify that creating models does not
            # affect the trainable variables collection.
            _ = tf_compat.v1.get_variable("some_var", shape=0, trainable=True)

            features = {"x": tf.constant([[1.], [2.]])}
            if multi_head:
                labels = {
                    "head1": tf.constant([0, 1]),
                    "head2": tf.constant([0, 1])
                }
            else:
                labels = tf.constant([0, 1])

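            # allow_growth makes TF allocate GPU memory on demand instead of
            # reserving it all up front.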
            session_config = tf.compat.v1.ConfigProto(
                gpu_options=tf.compat.v1.GPUOptions(allow_growth=True))

            subnetwork_spec = subnetwork_manager.build_subnetwork_spec(
                name="test",
                subnetwork_builder=subnetwork_builder,
                summary=_FakeSummary(),
                features=features,
                mode=mode,
                labels=labels,
                previous_ensemble=previous_ensemble)
            ensembler_kwargs = {}
            if ensembler_class is ComplexityRegularizedEnsembler:
                ensembler_kwargs.update({
                    "mixture_weight_type": mixture_weight_type,
                    "mixture_weight_initializer": mixture_weight_initializer,
                    "warm_start_mixture_weights": warm_start_mixture_weights,
                    "model_dir": self.test_subdirectory,
                    "adanet_lambda": adanet_lambda,
                    "adanet_beta": adanet_beta,
                    "use_bias": use_bias
                })
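            # MeanEnsembler takes none of the complexity-regularization kwargs;
            # add_mean_last_layer_predictions also surfaces the averaged last
            # layer in the ensemble's predictions.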
            if ensembler_class is MeanEnsembler:
                ensembler_kwargs.update(
                    {"add_mean_last_layer_predictions": True})
            ensemble_spec = builder.build_ensemble_spec(
                # Note: when previous_ensemble_spec is not None and
                # warm_start_mixture_weights is True, the bias and mixture
                # weights must already be saved to the checkpoint_dir.
                name="test",
                previous_ensemble_spec=previous_ensemble_spec,
                candidate=EnsembleCandidate("foo", [subnetwork_builder], None),
                ensembler=ensembler_class(**ensembler_kwargs),
                subnetwork_specs=[subnetwork_spec],
                summary=_FakeSummary(),
                features=features,
                iteration_number=1,
                labels=labels,
                my_ensemble_index=my_ensemble_index,
                mode=mode,
                previous_iteration_checkpoint=previous_iteration_checkpoint)

            if want_replay_indices:
                self.assertAllEqual(want_replay_indices,
                                    ensemble_spec.architecture.replay_indices)

            with tf_compat.v1.Session(
                    graph=g, config=session_config).as_default() as sess:
                sess.run(tf_compat.v1.global_variables_initializer())

                # Equals the number of subnetwork and ensemble trainable variables,
                # plus the one 'some_var' created earlier.
                self.assertLen(
                    tf_compat.v1.trainable_variables(),
                    want_subnetwork_trainable_vars +
                    want_ensemble_trainable_vars + 1)

                # Get the real global step outside a subnetwork's context.
                self.assertEqual("global_step",
                                 tf_compat.v1.train.get_global_step().op.name)
                self.assertEqual("global_step",
                                 train.get_global_step().op.name)
                self.assertEqual("global_step",
                                 tf_v1.train.get_global_step().op.name)
                self.assertEqual("global_step",
                                 training_util.get_global_step().op.name)
                self.assertEqual(
                    "global_step",
                    tf_compat.v1.train.get_or_create_global_step().op.name)
                self.assertEqual("global_step",
                                 train.get_or_create_global_step().op.name)
                self.assertEqual(
                    "global_step",
                    tf_v1.train.get_or_create_global_step().op.name)
                self.assertEqual(
                    "global_step",
                    training_util.get_or_create_global_step().op.name)

                # Get global tf.summary outside a subnetwork's context.
                self.assertNotEqual("fake_scalar",
                                    tf_compat.v1.summary.scalar("scalar", 1.))
                self.assertNotEqual("fake_image",
                                    tf_compat.v1.summary.image("image", 1.))
                self.assertNotEqual(
                    "fake_histogram",
                    tf_compat.v1.summary.histogram("histogram", 1.))
                self.assertNotEqual(
                    "fake_audio", tf_compat.v1.summary.audio("audio", 1., 1.))

                if mode == tf.estimator.ModeKeys.PREDICT:
                    self.assertAllClose(want_logits,
                                        sess.run(
                                            ensemble_spec.ensemble.logits),
                                        atol=1e-3)
                    self.assertIsNone(ensemble_spec.loss)
                    self.assertIsNone(ensemble_spec.adanet_loss)
                    self.assertIsNone(ensemble_spec.train_op)
                    self.assertIsNotNone(ensemble_spec.export_outputs)
                    if not export_subnetworks:
                        return
                    if not multi_head:
                        subnetwork_logits = sess.run(
                            ensemble_spec.export_outputs[
                                _EnsembleBuilder.
                                _SUBNETWORK_LOGITS_EXPORT_SIGNATURE].outputs)
                        self.assertAllClose(
                            subnetwork_logits["test"],
                            sess.run(subnetwork_spec.subnetwork.logits))
                        subnetwork_last_layer = sess.run(
                            ensemble_spec.export_outputs[
                                _EnsembleBuilder.
                                _SUBNETWORK_LAST_LAYER_EXPORT_SIGNATURE].
                            outputs)
                        self.assertAllClose(
                            subnetwork_last_layer["test"],
                            sess.run(subnetwork_spec.subnetwork.last_layer))
                    else:
                        self.assertIn("subnetwork_logits_head2",
                                      ensemble_spec.export_outputs)
                        subnetwork_logits_head1 = sess.run(
                            ensemble_spec.
                            export_outputs["subnetwork_logits_head1"].outputs)
                        self.assertAllClose(
                            subnetwork_logits_head1["test"],
                            sess.run(
                                subnetwork_spec.subnetwork.logits["head1"]))
                        self.assertIn("subnetwork_logits_head2",
                                      ensemble_spec.export_outputs)
                        subnetwork_last_layer_head1 = sess.run(
                            ensemble_spec.export_outputs[
                                "subnetwork_last_layer_head1"].outputs)
                        self.assertAllClose(
                            subnetwork_last_layer_head1["test"],
                            sess.run(subnetwork_spec.subnetwork.
                                     last_layer["head1"]))
                    return

                # Verify that train_op works: the loss before training should
                # be greater than the loss after a few train steps.
                loss = sess.run(ensemble_spec.loss)
                train_op = tf.group(subnetwork_spec.train_op.train_op,
                                    ensemble_spec.train_op.train_op)
                for _ in range(3):
                    sess.run(train_op)
                self.assertGreater(loss, sess.run(ensemble_spec.loss))

                self.assertAllClose(want_logits,
                                    sess.run(ensemble_spec.ensemble.logits),
                                    atol=1e-3)

                if ensembler_class is ComplexityRegularizedEnsembler:
                    # Bias should learn a non-zero value when used.
                    bias = sess.run(ensemble_spec.ensemble.bias)
                    if isinstance(bias, dict):
                        bias = sum(abs(b) for b in bias.values())
                    if use_bias:
                        self.assertNotEqual(0., bias)
                    else:
                        self.assertAlmostEqual(0., bias)

                self.assertAlmostEqual(want_loss,
                                       sess.run(ensemble_spec.loss),
                                       places=3)
                self.assertAlmostEqual(want_adanet_loss,
                                       sess.run(ensemble_spec.adanet_loss),
                                       places=3)

                if want_predictions:
                    self.assertAllClose(
                        want_predictions,
                        sess.run(ensemble_spec.ensemble.predictions),
                        atol=1e-3)