Example #1
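Both examples are adapted from adanet's ensemble-builder test suite. They assume imports along the following lines; this is a hedged sketch whose module paths mirror the adanet tests but may differ across versions, and _Builder and _FakeSummary are private helpers defined in the test module itself.

import tensorflow.compat.v1 as tf_v1
import tensorflow.compat.v2 as tf

from adanet import tf_compat
from adanet.core.ensemble_builder import _EnsembleBuilder
from adanet.core.ensemble_builder import _SubnetworkManager
from adanet.core.eval_metrics import call_eval_metrics
from adanet.ensemble import Candidate as EnsembleCandidate
from adanet.ensemble import ComplexityRegularizedEnsembler
from adanet.ensemble import MixtureWeightType
from tensorflow.python.eager import context
from tensorflow.python.training import training as train
from tensorflow.python.training import training_util
from tensorflow_estimator.python.estimator.head import binary_class_head
from tensorflow_estimator.python.estimator.head import multi_head as multi_head_lib
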
  def _make_metrics(self,
                    metric_fn,
                    mode=tf.estimator.ModeKeys.EVAL,
                    multi_head=False,
                    sess=None):

    with context.graph_mode():
      if multi_head:
        head = multi_head_lib.MultiHead(heads=[
            binary_class_head.BinaryClassHead(
                name="head1", loss_reduction=tf_compat.SUM),
            binary_class_head.BinaryClassHead(
                name="head2", loss_reduction=tf_compat.SUM)
        ])
        labels = {"head1": tf.constant([0, 1]), "head2": tf.constant([0, 1])}
      else:
        head = binary_class_head.BinaryClassHead(loss_reduction=tf_compat.SUM)
        labels = tf.constant([0, 1])
      features = {"x": tf.constant([[1.], [2.]])}
      builder = _EnsembleBuilder(head, metric_fn=metric_fn)
      subnetwork_manager = _SubnetworkManager(head, metric_fn=metric_fn)
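      # The two lambdas stand in for the _Builder test helper's subnetwork
      # and mixture-weight train-op functions; each returns tf.no_op() since
      # these metric tests never train.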
      subnetwork_builder = _Builder(
          lambda unused0, unused1: tf.no_op(),
          lambda unused0, unused1: tf.no_op(),
          use_logits_last_layer=True)

      subnetwork_spec = subnetwork_manager.build_subnetwork_spec(
          name="test",
          subnetwork_builder=subnetwork_builder,
          summary=_FakeSummary(),
          features=features,
          mode=mode,
          labels=labels)
      ensemble_spec = builder.build_ensemble_spec(
          name="test",
          candidate=EnsembleCandidate("foo", [subnetwork_builder], None),
          ensembler=ComplexityRegularizedEnsembler(
              mixture_weight_type=MixtureWeightType.SCALAR),
          subnetwork_specs=[subnetwork_spec],
          summary=_FakeSummary(),
          features=features,
          iteration_number=0,
          labels=labels,
          mode=mode)
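      # Each v1-style metric is a (value_tensor, update_op) pair: running the
      # update ops accumulates the metric state, and the value tensors then
      # read the result back idempotently.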
      subnetwork_metric_ops = call_eval_metrics(subnetwork_spec.eval_metrics)
      ensemble_metric_ops = call_eval_metrics(ensemble_spec.eval_metrics)
      evaluate = self.evaluate
      if sess is not None:
        evaluate = sess.run
      evaluate((tf_compat.v1.global_variables_initializer(),
                tf_compat.v1.local_variables_initializer()))
      evaluate((subnetwork_metric_ops, ensemble_metric_ops))
      # Return the idempotent tensor part of the (tensor, op) metrics tuple.
      return {
          k: evaluate(subnetwork_metric_ops[k][0])
          for k in subnetwork_metric_ops
      }, {k: evaluate(ensemble_metric_ops[k][0]) for k in ensemble_metric_ops}
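
A hedged usage sketch, not from the source: a test built on _make_metrics could pass a custom metric_fn and assert that its metric is reported for both the subnetwork and the ensemble. The test name, the metric name "mean_x", and the single-argument metric_fn signature are illustrative.

  def test_metrics_with_custom_metric_fn(self):

    def _metric_fn(features):
      return {"mean_x": tf_compat.v1.metrics.mean(features["x"])}

    subnetwork_metrics, ensemble_metrics = self._make_metrics(_metric_fn)
    self.assertIn("mean_x", subnetwork_metrics)
    self.assertIn("mean_x", ensemble_metrics)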
Example #2

    def test_build_ensemble_spec(
            self,
            want_logits,
            want_loss=None,
            want_adanet_loss=None,
            want_ensemble_trainable_vars=None,
            adanet_lambda=0.,
            adanet_beta=0.,
            ensemble_spec_fn=lambda: None,
            use_bias=False,
            use_logits_last_layer=False,
            mixture_weight_type=MixtureWeightType.MATRIX,
            mixture_weight_initializer=tf_compat.v1.zeros_initializer(),
            warm_start_mixture_weights=True,
            subnetwork_builder_class=_Builder,
            mode=tf.estimator.ModeKeys.TRAIN,
            multi_head=False,
            want_subnetwork_trainable_vars=2):
        seed = 64

        if multi_head:
            head = multi_head_lib.MultiHead(heads=[
                binary_class_head.BinaryClassHead(
                    name="head1", loss_reduction=tf_compat.SUM),
                binary_class_head.BinaryClassHead(name="head2",
                                                  loss_reduction=tf_compat.SUM)
            ])
        else:
            head = binary_class_head.BinaryClassHead(
                loss_reduction=tf_compat.SUM)
        builder = _EnsembleBuilder(head=head)

        def _subnetwork_train_op_fn(loss, var_list):
            self.assertLen(var_list, want_subnetwork_trainable_vars)
            self.assertEqual(
                var_list,
                tf_compat.v1.get_collection(
                    tf_compat.v1.GraphKeys.TRAINABLE_VARIABLES))
            # Subnetworks get iteration steps instead of global steps.
            self.assertEqual("subnetwork_test/iteration_step",
                             tf_compat.v1.train.get_global_step().op.name)

            # Subnetworks get scoped summaries.
            self.assertEqual("fake_scalar",
                             tf_compat.v1.summary.scalar("scalar", 1.))
            self.assertEqual("fake_image",
                             tf_compat.v1.summary.image("image", 1.))
            self.assertEqual("fake_histogram",
                             tf_compat.v1.summary.histogram("histogram", 1.))
            self.assertEqual("fake_audio",
                             tf_compat.v1.summary.audio("audio", 1., 1.))
            optimizer = tf_compat.v1.train.GradientDescentOptimizer(
                learning_rate=.1)
            return optimizer.minimize(loss, var_list=var_list)

        def _mixture_weights_train_op_fn(loss, var_list):
            self.assertLen(var_list, want_ensemble_trainable_vars)
            self.assertEqual(
                var_list,
                tf_compat.v1.get_collection(
                    tf_compat.v1.GraphKeys.TRAINABLE_VARIABLES))
            # The ensemble also gets an iteration step instead of the global
            # step.
            self.assertEqual("ensemble_test/iteration_step",
                             tf_compat.v1.train.get_global_step().op.name)

            # The ensemble also gets scoped summaries.
            self.assertEqual("fake_scalar",
                             tf_compat.v1.summary.scalar("scalar", 1.))
            self.assertEqual("fake_image",
                             tf_compat.v1.summary.image("image", 1.))
            self.assertEqual("fake_histogram",
                             tf_compat.v1.summary.histogram("histogram", 1.))
            self.assertEqual("fake_audio",
                             tf_compat.v1.summary.audio("audio", 1., 1.))
            optimizer = tf_compat.v1.train.GradientDescentOptimizer(
                learning_rate=.1)
            return optimizer.minimize(loss, var_list=var_list)

        previous_ensemble = None
        previous_ensemble_spec = ensemble_spec_fn()
        if previous_ensemble_spec:
            previous_ensemble = previous_ensemble_spec.ensemble

        subnetwork_manager = _SubnetworkManager(head)
        subnetwork_builder = subnetwork_builder_class(
            _subnetwork_train_op_fn,
            _mixture_weights_train_op_fn,
            use_logits_last_layer,
            seed,
            multi_head=multi_head)

        with tf.Graph().as_default() as g:
            # A trainable variable created up front, to later verify that
            # building the models adds exactly the expected number of
            # variables to the trainable variables collection.
            _ = tf_compat.v1.get_variable(
                "some_var", initializer=0., trainable=True)

            features = {"x": tf.constant([[1.], [2.]])}
            if multi_head:
                labels = {
                    "head1": tf.constant([0, 1]),
                    "head2": tf.constant([0, 1])
                }
            else:
                labels = tf.constant([0, 1])

            subnetwork_spec = subnetwork_manager.build_subnetwork_spec(
                name="test",
                subnetwork_builder=subnetwork_builder,
                iteration_step=tf_compat.v1.train.get_or_create_global_step(),
                summary=_FakeSummary(),
                features=features,
                mode=mode,
                labels=labels,
                previous_ensemble=previous_ensemble)
            ensemble_spec = builder.build_ensemble_spec(
                # Note: when previous_ensemble_spec is not None and
                # warm_start_mixture_weights is True, the bias and mixture
                # weights must already be saved to the checkpoint_dir.
                name="test",
                previous_ensemble_spec=previous_ensemble_spec,
                candidate=EnsembleCandidate("foo", [subnetwork_builder], None),
                ensembler=ComplexityRegularizedEnsembler(
                    mixture_weight_type=mixture_weight_type,
                    mixture_weight_initializer=mixture_weight_initializer,
                    warm_start_mixture_weights=warm_start_mixture_weights,
                    model_dir=self.test_subdirectory,
                    adanet_lambda=adanet_lambda,
                    adanet_beta=adanet_beta,
                    use_bias=use_bias),
                subnetwork_specs=[subnetwork_spec],
                summary=_FakeSummary(),
                features=features,
                iteration_number=1,
                iteration_step=tf_compat.v1.train.get_or_create_global_step(),
                labels=labels,
                mode=mode)

            with tf_compat.v1.Session(graph=g).as_default() as sess:
                sess.run(tf_compat.v1.global_variables_initializer())

                # Equals the number of subnetwork and ensemble trainable variables,
                # plus the one 'some_var' created earlier.
                self.assertLen(
                    tf_compat.v1.trainable_variables(),
                    want_subnetwork_trainable_vars +
                    want_ensemble_trainable_vars + 1)

                # Get the real global step outside a subnetwork's context.
                self.assertEqual("global_step",
                                 tf_compat.v1.train.get_global_step().op.name)
                self.assertEqual("global_step",
                                 train.get_global_step().op.name)
                self.assertEqual("global_step",
                                 tf_v1.train.get_global_step().op.name)
                self.assertEqual("global_step",
                                 training_util.get_global_step().op.name)
                self.assertEqual(
                    "global_step",
                    tf_compat.v1.train.get_or_create_global_step().op.name)
                self.assertEqual("global_step",
                                 train.get_or_create_global_step().op.name)
                self.assertEqual(
                    "global_step",
                    tf_v1.train.get_or_create_global_step().op.name)
                self.assertEqual(
                    "global_step",
                    training_util.get_or_create_global_step().op.name)

                # Get global tf.summary outside a subnetwork's context.
                self.assertNotEqual("fake_scalar",
                                    tf_compat.v1.summary.scalar("scalar", 1.))
                self.assertNotEqual("fake_image",
                                    tf_compat.v1.summary.image("image", 1.))
                self.assertNotEqual(
                    "fake_histogram",
                    tf_compat.v1.summary.histogram("histogram", 1.))
                self.assertNotEqual(
                    "fake_audio", tf_compat.v1.summary.audio("audio", 1., 1.))

                if mode == tf.estimator.ModeKeys.PREDICT:
                    self.assertAllClose(want_logits,
                                        sess.run(
                                            ensemble_spec.ensemble.logits),
                                        atol=1e-3)
                    self.assertIsNone(ensemble_spec.loss)
                    self.assertIsNone(ensemble_spec.adanet_loss)
                    self.assertIsNone(ensemble_spec.train_op)
                    self.assertIsNotNone(ensemble_spec.export_outputs)
                    return

                # Verify that the train ops work: the loss before training
                # should be greater than the loss after a few steps.
                loss = sess.run(ensemble_spec.loss)
                train_op = tf.group(subnetwork_spec.train_op.train_op,
                                    ensemble_spec.train_op.train_op)
                for _ in range(3):
                    sess.run(train_op)
                self.assertGreater(loss, sess.run(ensemble_spec.loss))

                self.assertAllClose(want_logits,
                                    sess.run(ensemble_spec.ensemble.logits),
                                    atol=1e-3)

                # Bias should learn a non-zero value when used.
                bias = sess.run(ensemble_spec.ensemble.bias)
                if isinstance(bias, dict):
                    bias = sum(abs(b) for b in bias.values())
                if use_bias:
                    self.assertNotEqual(0., bias)
                else:
                    self.assertAlmostEqual(0., bias)

                self.assertAlmostEqual(want_loss,
                                       sess.run(ensemble_spec.loss),
                                       places=3)
                self.assertAlmostEqual(want_adanet_loss,
                                       sess.run(ensemble_spec.adanet_loss),
                                       places=3)
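
A hedged sketch of how test_build_ensemble_spec is typically driven, using absl's @parameterized.named_parameters decorator; the case name and expected values below are placeholders rather than real fixtures.

    # from absl.testing import parameterized
    #
    # @parameterized.named_parameters({
    #     "testcase_name": "scalar_mixture_weight",
    #     "want_logits": ...,           # expected ensemble logits
    #     "want_loss": ...,             # expected head loss after training
    #     "want_adanet_loss": ...,      # loss plus the complexity penalty
    #     "mixture_weight_type": MixtureWeightType.SCALAR,
    #     "want_ensemble_trainable_vars": ...,
    # })
    # def test_build_ensemble_spec(self, ...):  # body as defined above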