Example #1
  def test_train_create_loss_logits_tensor_multi_dim(self):
    """Tests create_loss with multi-dimensional logits of shape [2, 2, 5]."""
    head1 = regression_head.RegressionHead(label_dimension=2, name='head1')
    head2 = regression_head.RegressionHead(label_dimension=3, name='head2')
    multi_head = multi_head_lib.MultiHead([head1, head2])

    logits = np.array(
        [[[-1., 1., 2., -2., 2.], [-1., 1., 2., -2., 2.]],
         [[-1.5, 1.5, -2., 2., -2.], [-1.5, 1.5, -2., 2., -2.]]],
        dtype=np.float32)
    labels = {
        'head1': np.array([[[1., 0.], [1., 0.]],
                           [[1.5, 1.5], [1.5, 1.5]]], dtype=np.float32),
        'head2': np.array([[[0., 1., 0.], [0., 1., 0.]],
                           [[2., 2., 0.], [2., 2., 0.]]], dtype=np.float32),
    }
    # Loss for the first head:
    # loss1 = ((1+1)^2 + (0-1)^2 + (1+1)^2 + (0-1)^2 +
    #          (1.5+1.5)^2 + (1.5-1.5)^2 + (1.5+1.5)^2 + (1.5-1.5)^2) / 8
    #       = 3.5
    # Loss for the second head:
    # loss2 = ((0-2)^2 + (1+2)^2 + (0-2)^2 + (0-2)^2 + (1+2)^2 + (0-2)^2 +
    #          (2+2)^2 + (2-2)^2 + (0+2)^2 + (2+2)^2 + (2-2)^2 + (0+2)^2) / 12
    #       = 6.167
    expected_training_loss = 3.5 + 6.167

    training_loss = multi_head.loss(
        logits=logits,
        labels=labels,
        features={},
        mode=model_fn.ModeKeys.TRAIN)
    tol = 1e-3
    self.assertAllClose(
        expected_training_loss, self.evaluate(training_loss),
        rtol=tol, atol=tol)
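
The expected values in the comments above can be reproduced outside TensorFlow. Below is a plain-NumPy sketch (not part of the test file) that recomputes both per-head mean squared errors, assuming MultiHead splits the last logits axis by each head's label_dimension (2, then 3):

import numpy as np

logits = np.array(
    [[[-1., 1., 2., -2., 2.], [-1., 1., 2., -2., 2.]],
     [[-1.5, 1.5, -2., 2., -2.], [-1.5, 1.5, -2., 2., -2.]]],
    dtype=np.float32)
labels1 = np.array([[[1., 0.], [1., 0.]],
                    [[1.5, 1.5], [1.5, 1.5]]], dtype=np.float32)
labels2 = np.array([[[0., 1., 0.], [0., 1., 0.]],
                    [[2., 2., 0.], [2., 2., 0.]]], dtype=np.float32)
# Split the last axis into widths [2, 3], matching the two heads.
logits1, logits2 = logits[..., :2], logits[..., 2:]
loss1 = np.mean((labels1 - logits1) ** 2)  # 3.5
loss2 = np.mean((labels2 - logits2) ** 2)  # ~6.167
print(loss1 + loss2)  # ~9.667, the expected_training_loss above
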
Example #2
    def test_train_one_head_with_optimizer(self):
        head1 = multi_label_head.MultiLabelHead(n_classes=2, name='head1')
        multi_head = multi_head_lib.MultiHead([head1])

        logits = {
            'head1': np.array([[-10., 10.], [-15., 10.]], dtype=np.float32)
        }
        labels = {'head1': np.array([[1, 0], [1, 1]], dtype=np.int64)}
        features = {'x': np.array(((42, ), ), dtype=np.int32)}
        # For large logits, sigmoid cross entropy loss is approximated as:
        # loss = labels * (logits < 0) * (-logits) +
        #        (1 - labels) * (logits > 0) * logits =>
        # expected_unweighted_loss = [[10., 10.], [15., 0.]]
        # loss = ((10 + 10) / 2 + (15 + 0) / 2) / 2 = 8.75
        expected_loss = 8.75
        tol = 1e-3
        loss = multi_head.loss(logits=logits,
                               labels=labels,
                               features=features,
                               mode=ModeKeys.TRAIN)
        self.assertAllClose(expected_loss,
                            self.evaluate(loss),
                            rtol=tol,
                            atol=tol)
        if tf.executing_eagerly():
            return

        expected_train_result = 'my_train_op'

        class _Optimizer(optimizer_v2.OptimizerV2):
            def get_updates(self, loss, params):
                del params
                return [
                    tf.strings.join([
                        tf.constant(expected_train_result),
                        tf.strings.as_string(loss, precision=3)
                    ])
                ]

            def get_config(self):
                config = super(_Optimizer, self).get_config()
                return config

        spec = multi_head.create_estimator_spec(
            features=features,
            mode=ModeKeys.TRAIN,
            logits=logits,
            labels=labels,
            optimizer=_Optimizer('my_optimizer'),
            trainable_variables=[
                tf.Variable([1.0, 2.0], dtype=tf.dtypes.float32)
            ])

        with self.cached_session() as sess:
            test_lib._initialize_variables(self, spec.scaffold)
            loss, train_result = sess.run((spec.loss, spec.train_op))
            self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol)
            self.assertEqual(
                six.b('{0:s}{1:.3f}'.format(expected_train_result,
                                            expected_loss)), train_result)
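
The large-logit approximation in the comment can be sanity-checked against the exact elementwise sigmoid cross entropy. A minimal standalone NumPy sketch (not from the test file):

import numpy as np

def sigmoid_cross_entropy(labels, logits):
    # Numerically stable form: max(x, 0) - x*z + log(1 + exp(-|x|))
    return (np.maximum(logits, 0.) - logits * labels
            + np.log1p(np.exp(-np.abs(logits))))

logits = np.array([[-10., 10.], [-15., 10.]])
labels = np.array([[1., 0.], [1., 1.]])
unweighted = sigmoid_cross_entropy(labels, logits)
print(np.round(unweighted, 3))         # ~[[10. 10.] [15.  0.]]
print(unweighted.mean(axis=1).mean())  # ~8.75, the expected_loss above
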
Example #3
 def test_loss_reduction_must_be_same(self):
     """Tests the loss reduction must be the same for different heads."""
     head1 = multi_label_head.MultiLabelHead(
         n_classes=2,
         name='head1',
         loss_reduction=losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE)
     head2 = multi_label_head.MultiLabelHead(
         n_classes=3,
         name='head2',
         loss_reduction=losses_utils.ReductionV2.AUTO)
     multi_head = multi_head_lib.MultiHead([head1, head2])
     logits = {
         'head1':
         np.array([[-10., 10.], [-15., 10.]], dtype=np.float32),
         'head2':
         np.array([[20., -20., 20.], [-30., 20., -20.]], dtype=np.float32),
     }
     labels = {
         'head1': np.array([[1, 0], [1, 1]], dtype=np.int64),
         'head2': np.array([[0, 1, 0], [1, 1, 0]], dtype=np.int64),
     }
     with self.assertRaisesRegexp(ValueError, 'must be the same'):
         multi_head.create_estimator_spec(
             features={'x': np.array(((42, ), ), dtype=np.int32)},
             mode=ModeKeys.TRAIN,
             logits=logits,
             labels=labels)
Example #4
    def test_train_loss_logits_tensor_wrong_shape(self):
        """Tests loss with a logits Tensor of the wrong shape."""
        weights1 = np.array([[1.], [2.]], dtype=np.float32)
        weights2 = np.array([[2.], [3.]], dtype=np.float32)
        head1 = multi_label_head.MultiLabelHead(n_classes=2,
                                                name='head1',
                                                weight_column='weights1')
        head2 = multi_label_head.MultiLabelHead(n_classes=3,
                                                name='head2',
                                                weight_column='weights2')
        multi_head = multi_head_lib.MultiHead([head1, head2],
                                              head_weights=[1., 2.])

        # logits tensor is 2x6 instead of 2x5
        logits = np.array([[-10., 10., 20., -20., 20., 70.],
                           [-15., 10., -30., 20., -20., 80.]],
                          dtype=np.float32)
        labels = {
            'head1': np.array([[1, 0], [1, 1]], dtype=np.int64),
            'head2': np.array([[0, 1, 0], [1, 1, 0]], dtype=np.int64),
        }
        with self.assertRaisesRegexp(ValueError, r'Could not split logits'):
            multi_head.loss(features={
                'x': np.array(((42, ), ), dtype=np.int32),
                'weights1': weights1,
                'weights2': weights2
            },
                            mode=ModeKeys.TRAIN,
                            logits=logits,
                            labels=labels)
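
The error is raised because a combined logits tensor must split exactly along its last axis into each head's width (2 + 3 = 5 here, so a 2x6 tensor cannot be split). A small NumPy illustration of the split rule:

import numpy as np

logits = np.array([[-10., 10., 20., -20., 20.],
                   [-15., 10., -30., 20., -20.]], dtype=np.float32)
# Widths follow the heads' logits dimensions: [2, 3].
logits_head1, logits_head2 = np.split(logits, [2], axis=-1)
print(logits_head1.shape, logits_head2.shape)  # (2, 2) (2, 3)
# With a last axis of 6 there is no [2, 3] partition, hence the
# 'Could not split logits' error above.
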
Example #5
  def test_predict_two_heads_logits_tensor_multi_dim(self):
    """Tests predict with multi-dimensional logits of shape [2, 2, 5]."""
    head1 = regression_head.RegressionHead(label_dimension=2, name='head1')
    head2 = regression_head.RegressionHead(label_dimension=3, name='head2')
    multi_head = multi_head_lib.MultiHead([head1, head2])

    logits = np.array(
        [[[-1., 1., 2., -2., 2.], [-1., 1., 2., -2., 2.]],
         [[-1.5, 1., -3., 2., -2.], [-1.5, 1., -3., 2., -2.]]],
        dtype=np.float32)
    expected_logits1 = np.array(
        [[[-1., 1.], [-1., 1.]],
         [[-1.5, 1.], [-1.5, 1.]]],
        dtype=np.float32)
    expected_logits2 = np.array(
        [[[2., -2., 2.], [2., -2., 2.]],
         [[-3., 2., -2.], [-3., 2., -2.]]],
        dtype=np.float32)
    pred_keys = prediction_keys.PredictionKeys

    predictions = multi_head.predictions(logits)
    self.assertAllClose(
        expected_logits1,
        self.evaluate(predictions[('head1', pred_keys.PREDICTIONS)]))
    self.assertAllClose(
        expected_logits2,
        self.evaluate(predictions[('head2', pred_keys.PREDICTIONS)]))
    if context.executing_eagerly():
      return

    spec = multi_head.create_estimator_spec(
        features={'x': np.array(((42,),), dtype=np.int32)},
        mode=model_fn.ModeKeys.PREDICT,
        logits=logits)
    self.assertItemsEqual(
        (test_lib._DEFAULT_SERVING_KEY, 'predict', 'head1', 'head1/regression',
         'head1/predict', 'head2', 'head2/regression', 'head2/predict'),
        spec.export_outputs.keys())
    # Assert predictions and export_outputs.
    with self.cached_session() as sess:
      test_lib._initialize_variables(self, spec.scaffold)
      self.assertIsNone(spec.scaffold.summary_op)
      predictions = sess.run(spec.predictions)
      self.assertAllClose(
          expected_logits1,
          predictions[('head1', pred_keys.PREDICTIONS)])
      self.assertAllClose(
          expected_logits2,
          predictions[('head2', pred_keys.PREDICTIONS)])

      self.assertAllClose(
          expected_logits1,
          sess.run(spec.export_outputs[test_lib._DEFAULT_SERVING_KEY].value))
      self.assertAllClose(
          expected_logits1,
          sess.run(spec.export_outputs['head1'].value))
      self.assertAllClose(
          expected_logits2,
          sess.run(spec.export_outputs['head2'].value))
Example #6
 def test_multi_head_provided(self):
     """Tests error raised when a multi-head is provided."""
     with self.assertRaisesRegexp(
             ValueError,
             '`MultiHead` is not supported with `SequentialHeadWrapper`.'):
         _ = seq_head_lib.SequentialHeadWrapper(
             multi_head.MultiHead(
                 [binary_head_lib.BinaryClassHead(name='test-head')]))
Example #7
 def test_head_weights_wrong_size(self):
   head1 = multi_label_head.MultiLabelHead(n_classes=2, name='head1')
   head2 = multi_label_head.MultiLabelHead(n_classes=3, name='head2')
   with self.assertRaisesRegexp(
       ValueError,
       r'heads and head_weights must have the same size\. '
       r'Given len\(heads\): 2. Given len\(head_weights\): 1\.'):
     multi_head_lib.MultiHead([head1, head2], head_weights=[1.])
Example #8
  def _make_metrics(self,
                    metric_fn,
                    mode=tf.estimator.ModeKeys.EVAL,
                    multi_head=False,
                    sess=None):

    with context.graph_mode():
      if multi_head:
        head = multi_head_lib.MultiHead(heads=[
            binary_class_head.BinaryClassHead(
                name="head1", loss_reduction=tf_compat.SUM),
            binary_class_head.BinaryClassHead(
                name="head2", loss_reduction=tf_compat.SUM)
        ])
        labels = {"head1": tf.constant([0, 1]), "head2": tf.constant([0, 1])}
      else:
        head = binary_class_head.BinaryClassHead(loss_reduction=tf_compat.SUM)
        labels = tf.constant([0, 1])
      features = {"x": tf.constant([[1.], [2.]])}
      builder = _EnsembleBuilder(head, metric_fn=metric_fn)
      subnetwork_manager = _SubnetworkManager(head, metric_fn=metric_fn)
      subnetwork_builder = _Builder(
          lambda unused0, unused1: tf.no_op(),
          lambda unused0, unused1: tf.no_op(),
          use_logits_last_layer=True)

      subnetwork_spec = subnetwork_manager.build_subnetwork_spec(
          name="test",
          subnetwork_builder=subnetwork_builder,
          summary=_FakeSummary(),
          features=features,
          mode=mode,
          labels=labels)
      ensemble_spec = builder.build_ensemble_spec(
          name="test",
          candidate=EnsembleCandidate("foo", [subnetwork_builder], None),
          ensembler=ComplexityRegularizedEnsembler(
              mixture_weight_type=MixtureWeightType.SCALAR),
          subnetwork_specs=[subnetwork_spec],
          summary=_FakeSummary(),
          features=features,
          iteration_number=0,
          labels=labels,
          mode=mode)
      subnetwork_metric_ops = call_eval_metrics(subnetwork_spec.eval_metrics)
      ensemble_metric_ops = call_eval_metrics(ensemble_spec.eval_metrics)
      evaluate = self.evaluate
      if sess is not None:
        evaluate = sess.run
      evaluate((tf_compat.v1.global_variables_initializer(),
                tf_compat.v1.local_variables_initializer()))
      evaluate((subnetwork_metric_ops, ensemble_metric_ops))
      # Return the idempotent tensor part of the (tensor, op) metrics tuple.
      return {
          k: evaluate(subnetwork_metric_ops[k][0])
          for k in subnetwork_metric_ops
      }, {k: evaluate(ensemble_metric_ops[k][0]) for k in ensemble_metric_ops}
Example #9
  def test_train_one_head(self):
    head1 = multi_label_head.MultiLabelHead(n_classes=2, name='head1')
    multi_head = multi_head_lib.MultiHead([head1])

    logits = {'head1': np.array([[-10., 10.], [-15., 10.]], dtype=np.float32)}
    labels = {'head1': np.array([[1, 0], [1, 1]], dtype=np.int64)}
    features = {'x': np.array(((42,),), dtype=np.int32)}
    # For large logits, sigmoid cross entropy loss is approximated as:
    # loss = labels * (logits < 0) * (-logits) +
    #        (1 - labels) * (logits > 0) * logits =>
    # expected_unweighted_loss = [[10., 10.], [15., 0.]]
    # loss = ((10 + 10) / 2 + (15 + 0) / 2) / 2 = 8.75
    expected_loss = 8.75
    tol = 1e-3
    loss = multi_head.loss(
        logits=logits,
        labels=labels,
        features=features,
        mode=model_fn.ModeKeys.TRAIN)
    self.assertAllClose(expected_loss, self.evaluate(loss), rtol=tol, atol=tol)
    if context.executing_eagerly():
      return

    expected_train_result = 'my_train_op'
    def _train_op_fn(loss):
      return string_ops.string_join(
          [constant_op.constant(expected_train_result),
           string_ops.as_string(loss, precision=3)])
    spec = multi_head.create_estimator_spec(
        features=features,
        mode=model_fn.ModeKeys.TRAIN,
        logits=logits,
        labels=labels,
        train_op_fn=_train_op_fn)
    self.assertIsNotNone(spec.loss)
    self.assertEqual({}, spec.eval_metric_ops)
    self.assertIsNotNone(spec.train_op)
    self.assertIsNone(spec.export_outputs)
    test_lib._assert_no_hooks(self, spec)
    # Assert predictions, loss, train_op, and summaries.
    with self.cached_session() as sess:
      test_lib._initialize_variables(self, spec.scaffold)
      self.assertIsNotNone(spec.scaffold.summary_op)
      loss, train_result, summary_str = sess.run((spec.loss, spec.train_op,
                                                  spec.scaffold.summary_op))
      self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol)
      self.assertEqual(
          six.b('{0:s}{1:.3f}'.format(expected_train_result, expected_loss)),
          train_result)
      test_lib._assert_simple_summaries(self, {
          metric_keys.MetricKeys.LOSS: expected_loss,
          metric_keys.MetricKeys.LOSS + '/head1': expected_loss,
      }, summary_str, tol)
Example #10
    def test_train_one_head_with_optimizer(self):
        head1 = multi_label_head.MultiLabelHead(n_classes=2, name='head1')
        multi_head = multi_head_lib.MultiHead([head1])

        logits = {
            'head1': np.array([[-10., 10.], [-15., 10.]], dtype=np.float32)
        }
        labels = {'head1': np.array([[1, 0], [1, 1]], dtype=np.int64)}
        features = {'x': np.array(((42, ), ), dtype=np.int32)}
        # For large logits, sigmoid cross entropy loss is approximated as:
        # loss = labels * (logits < 0) * (-logits) +
        #        (1 - labels) * (logits > 0) * logits =>
        # expected_unweighted_loss = [[10., 10.], [15., 0.]]
        # loss = ((10 + 10) / 2 + (15 + 0) / 2) / 2 = 8.75
        expected_loss = 8.75
        tol = 1e-3
        loss = multi_head.loss(logits=logits,
                               labels=labels,
                               features=features,
                               mode=model_fn.ModeKeys.TRAIN)
        self.assertAllClose(expected_loss,
                            self.evaluate(loss),
                            rtol=tol,
                            atol=tol)
        if context.executing_eagerly():
            return

        expected_train_result = 'my_train_op'

        class _Optimizer(object):
            def minimize(self, loss, global_step):
                del global_step
                return string_ops.string_join([
                    constant_op.constant(expected_train_result),
                    string_ops.as_string(loss, precision=3)
                ])

        spec = multi_head.create_estimator_spec(features=features,
                                                mode=model_fn.ModeKeys.TRAIN,
                                                logits=logits,
                                                labels=labels,
                                                optimizer=_Optimizer())
        with self.cached_session() as sess:
            test_lib._initialize_variables(self, spec.scaffold)
            loss, train_result = sess.run((spec.loss, spec.train_op))
            self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol)
            self.assertEqual(
                six.b('{0:s}{1:.3f}'.format(expected_train_result,
                                            expected_loss)), train_result)
Example #11
    def test_train_create_loss_two_heads_with_weights(self):
        # Use different example weights for each head.
        weights1 = np.array([[1.], [2.]], dtype=np.float32)
        weights2 = np.array([[2.], [3.]], dtype=np.float32)
        head1 = multi_label_head.MultiLabelHead(n_classes=2,
                                                name='head1',
                                                weight_column='weights1')
        head2 = multi_label_head.MultiLabelHead(n_classes=3,
                                                name='head2',
                                                weight_column='weights2')
        multi_head = multi_head_lib.MultiHead([head1, head2],
                                              head_weights=[1., 2.])

        logits = {
            'head1':
            np.array([[-10., 10.], [-15., 10.]], dtype=np.float32),
            'head2':
            np.array([[20., -20., 20.], [-30., 20., -20.]], dtype=np.float32),
        }
        labels = {
            'head1': np.array([[1, 0], [1, 1]], dtype=np.int64),
            'head2': np.array([[0, 1, 0], [1, 1, 0]], dtype=np.int64),
        }
        training_loss = multi_head.loss(logits=logits,
                                        labels=labels,
                                        features={
                                            'x':
                                            np.array(((42, ), ),
                                                     dtype=np.int32),
                                            'weights1':
                                            weights1,
                                            'weights2':
                                            weights2
                                        },
                                        mode=model_fn.ModeKeys.TRAIN)
        tol = 1e-3
        # Unreduced loss of head1 (averaged over classes) = [10, 7.5]
        #   = [(10 + 10) / 2, (15 + 0) / 2]
        # Example-weighted loss1 = (1 * 10 + 2 * 7.5) / 2 = 12.5
        # Unreduced loss of head2 (averaged over classes) = [20, 10]
        #   = [(20 + 20 + 20) / 3, (30 + 0 + 0) / 3]
        # Example-weighted loss2 = (2 * 20 + 3 * 10) / 2 = 35
        # Head-weighted training_loss = 1 * 12.5 + 2 * 35 = 82.5
        self.assertAllClose(82.5,
                            self.evaluate(training_loss),
                            rtol=tol,
                            atol=tol)
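
As a cross-check, the weighted arithmetic in the comments can be reproduced directly with plain NumPy (values taken from the comments above):

import numpy as np

# Class-averaged per-example losses from the comments above.
head1_unreduced = np.array([10., 7.5])
head2_unreduced = np.array([20., 10.])
# Apply example weights, then average over the batch of 2.
loss1 = np.mean(np.array([1., 2.]) * head1_unreduced)  # 12.5
loss2 = np.mean(np.array([2., 3.]) * head2_unreduced)  # 35.0
# Combine with head_weights=[1., 2.].
print(1. * loss1 + 2. * loss2)  # 82.5
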
Example #12
  def test_train_create_loss_one_head(self):
    head1 = multi_label_head.MultiLabelHead(n_classes=2, name='head1')
    multi_head = multi_head_lib.MultiHead([head1])

    logits = {'head1': np.array([[-10., 10.], [-15., 10.]], dtype=np.float32)}
    labels = {'head1': np.array([[1, 0], [1, 1]], dtype=np.int64)}
    loss = multi_head.loss(
        logits=logits,
        labels=labels,
        features={'x': np.array(((42,),), dtype=np.int32)},
        mode=model_fn.ModeKeys.TRAIN)
    tol = 1e-3
    # Unreduced loss of the head is [[(10 + 10) / 2], [(15 + 0) / 2]]
    # (averaged over classes, averaged over examples).
    # loss = sum(unreduced_loss) / 2 = sum([10, 7.5]) / 2 = 8.75
    self.assertAllClose(8.75, self.evaluate(loss), rtol=tol, atol=tol)
Example #13
    def test_train_loss_logits_tensor_multi_dim_wrong_shape(self):
        """Tests loss with a multi-dimensional logits tensor of the wrong shape."""
        head1 = regression_head.RegressionHead(label_dimension=2, name='head1')
        head2 = regression_head.RegressionHead(label_dimension=3, name='head2')
        multi_head = multi_head_lib.MultiHead([head1, head2])

        # logits tensor is 2x2x4 instead of 2x2x5
        logits = np.array([[[-1., 1., 2., -2.], [-1., 1., 2., -2.]],
                           [[-1.5, 1.5, -2., 2.], [-1.5, 1.5, -2., 2.]]],
                          dtype=np.float32)
        labels = {
            'head1':
            np.array([[[1., 0.], [1., 0.]], [[1.5, 1.5], [1.5, 1.5]]],
                     dtype=np.float32),
            'head2':
            np.array(
                [[[0., 1., 0.], [0., 1., 0.]], [[2., 2., 0.], [2., 2., 0.]]],
                dtype=np.float32),
        }
        with self.assertRaisesRegexp(ValueError, r'Could not split logits'):
            multi_head.loss(features={},
                            mode=ModeKeys.TRAIN,
                            logits=logits,
                            labels=labels)
Example #14
  def test_train_with_regularization_losses(self):
    head1 = multi_label_head.MultiLabelHead(n_classes=2, name='head1')
    head2 = multi_label_head.MultiLabelHead(n_classes=3, name='head2')
    multi_head = multi_head_lib.MultiHead(
        [head1, head2], head_weights=[1., 2.])

    logits = {
        'head1': np.array([[-10., 10.], [-15., 10.]], dtype=np.float32),
        'head2': np.array([[20., -20., 20.], [-30., 20., -20.]],
                          dtype=np.float32),
    }
    labels = {
        'head1': np.array([[1, 0], [1, 1]], dtype=np.int64),
        'head2': np.array([[0, 1, 0], [1, 1, 0]], dtype=np.int64),
    }
    features = {'x': np.array(((42,),), dtype=np.int32)}
    regularization_losses = [1.5, 0.5]

    # For large logits, sigmoid cross entropy loss is approximated as:
    # loss = labels * (logits < 0) * (-logits) +
    #        (1 - labels) * (logits > 0) * logits =>
    # head1: expected_unweighted_loss = [[10., 10.], [15., 0.]]
    # loss1 = ((10 + 10) / 2 + (15 + 0) / 2) / 2 = 8.75
    # head2: expected_unweighted_loss = [[20., 20., 20.], [30., 0., 0]]
    # loss2 = ((20 + 20 + 20) / 3 + (30 + 0 + 0) / 3) / 2 = 15
    # Average over classes, weighted sum over batch and heads.
    # weights = [1., 2.]
    # merged_training_loss = 1. * loss1 + 2. * loss2
    # training_loss = merged_training_loss + regularization_loss
    #               = 1. * loss1 + 2. * loss2 + sum([1.5, 0.5])
    expected_loss_head1 = 8.75
    expected_loss_head2 = 15.0
    expected_regularization_loss = 2.
    # training loss.
    expected_loss = (1. * expected_loss_head1 + 2. * expected_loss_head2
                     + expected_regularization_loss)
    tol = 1e-3
    loss = multi_head.loss(
        logits=logits,
        labels=labels,
        features=features,
        mode=model_fn.ModeKeys.TRAIN,
        regularization_losses=regularization_losses)
    self.assertAllClose(expected_loss, self.evaluate(loss), rtol=tol, atol=tol)
    if context.executing_eagerly():
      return

    keys = metric_keys.MetricKeys
    expected_train_result = 'my_train_op'
    def _train_op_fn(loss):
      return string_ops.string_join(
          [constant_op.constant(expected_train_result),
           string_ops.as_string(loss, precision=3)])

    spec = multi_head.create_estimator_spec(
        features=features,
        mode=model_fn.ModeKeys.TRAIN,
        logits=logits,
        labels=labels,
        train_op_fn=_train_op_fn,
        regularization_losses=regularization_losses)
    self.assertIsNotNone(spec.loss)
    self.assertEqual({}, spec.eval_metric_ops)
    self.assertIsNotNone(spec.train_op)
    self.assertIsNone(spec.export_outputs)
    test_lib._assert_no_hooks(self, spec)
    # Assert predictions, loss, train_op, and summaries.
    with self.cached_session() as sess:
      test_lib._initialize_variables(self, spec.scaffold)
      self.assertIsNotNone(spec.scaffold.summary_op)
      loss, train_result, summary_str = sess.run((spec.loss, spec.train_op,
                                                  spec.scaffold.summary_op))
      self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol)
      self.assertEqual(
          six.b('{0:s}{1:.3f}'.format(expected_train_result, expected_loss)),
          train_result)
      test_lib._assert_simple_summaries(self, {
          keys.LOSS_REGULARIZATION: expected_regularization_loss,
          keys.LOSS: expected_loss,
          keys.LOSS + '/head1': expected_loss_head1,
          keys.LOSS + '/head2': expected_loss_head2,
      }, summary_str, tol)
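
The regularization term enters the training loss as a plain sum on top of the head-weighted losses; a one-line check of the arithmetic in the comments above:

loss1, loss2 = 8.75, 15.0  # per-head losses from the comments
head_weights = [1., 2.]
regularization_losses = [1.5, 0.5]
expected_loss = (head_weights[0] * loss1 + head_weights[1] * loss2
                 + sum(regularization_losses))
print(expected_loss)  # 40.75
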
Example #15
 def test_name(self):
   head1 = multi_label_head.MultiLabelHead(n_classes=2, name='head1')
   head2 = multi_label_head.MultiLabelHead(n_classes=3, name='head2')
   multi_head = multi_head_lib.MultiHead([head1, head2])
   self.assertEqual('head1_head2', multi_head.name)
Example #16
    def test_train_two_heads_with_weights(self):
        head1 = multi_label_head.MultiLabelHead(n_classes=2, name='head1')
        head2 = multi_label_head.MultiLabelHead(n_classes=3, name='head2')
        multi_head = multi_head_lib.MultiHead([head1, head2],
                                              head_weights=[1., 2.])

        logits = {
            'head1':
            np.array([[-10., 10.], [-15., 10.]], dtype=np.float32),
            'head2':
            np.array([[20., -20., 20.], [-30., 20., -20.]], dtype=np.float32),
        }
        expected_probabilities = {
            'head1': tf.math.sigmoid(logits['head1']),
            'head2': tf.math.sigmoid(logits['head2']),
        }
        labels = {
            'head1': np.array([[1, 0], [1, 1]], dtype=np.int64),
            'head2': np.array([[0, 1, 0], [1, 1, 0]], dtype=np.int64),
        }
        features = {'x': np.array(((42, ), ), dtype=np.int32)}
        # For large logits, sigmoid cross entropy loss is approximated as:
        # loss = labels * (logits < 0) * (-logits) +
        #        (1 - labels) * (logits > 0) * logits =>
        # head1: expected_unweighted_loss = [[10., 10.], [15., 0.]]
        # loss = ((10 + 10) / 2 + (15 + 0) / 2) / 2 = 8.75
        # head2: expected_unweighted_loss = [[20., 20., 20.], [30., 0., 0]]
        # loss = ((20 + 20 + 20) / 3 + (30 + 0 + 0) / 3) / 2 = 15
        # Average over classes, weighted sum over batch and heads.
        expected_loss_head1 = 8.75
        expected_loss_head2 = 15.0
        expected_loss = 1. * expected_loss_head1 + 2. * expected_loss_head2
        tol = 1e-3
        loss = multi_head.loss(logits=logits,
                               labels=labels,
                               features=features,
                               mode=ModeKeys.TRAIN)
        self.assertAllClose(expected_loss,
                            self.evaluate(loss),
                            rtol=tol,
                            atol=tol)
        if tf.executing_eagerly():
            return

        expected_train_result = 'my_train_op'

        def _train_op_fn(loss):
            return tf.strings.join([
                tf.constant(expected_train_result),
                tf.strings.as_string(loss, precision=3)
            ])

        spec = multi_head.create_estimator_spec(features=features,
                                                mode=ModeKeys.TRAIN,
                                                logits=logits,
                                                labels=labels,
                                                train_op_fn=_train_op_fn)
        self.assertIsNotNone(spec.loss)
        self.assertEqual({}, spec.eval_metric_ops)
        self.assertIsNotNone(spec.train_op)
        self.assertIsNone(spec.export_outputs)
        test_lib._assert_no_hooks(self, spec)
        # Assert predictions, loss, train_op, and summaries.
        with self.cached_session() as sess:
            test_lib._initialize_variables(self, spec.scaffold)
            self.assertIsNotNone(spec.scaffold.summary_op)
            loss, train_result, summary_str, predictions = sess.run(
                (spec.loss, spec.train_op, spec.scaffold.summary_op,
                 spec.predictions))
            self.assertAllClose(
                logits['head1'],
                predictions[('head1', prediction_keys.PredictionKeys.LOGITS)])
            self.assertAllClose(
                expected_probabilities['head1'],
                predictions[('head1',
                             prediction_keys.PredictionKeys.PROBABILITIES)])
            self.assertAllClose(
                logits['head2'],
                predictions[('head2', prediction_keys.PredictionKeys.LOGITS)])
            self.assertAllClose(
                expected_probabilities['head2'],
                predictions[('head2',
                             prediction_keys.PredictionKeys.PROBABILITIES)])
            self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol)
            self.assertEqual(
                six.b('{0:s}{1:.3f}'.format(expected_train_result,
                                            expected_loss)), train_result)
            test_lib._assert_simple_summaries(
                self, {
                    metric_keys.MetricKeys.LOSS: expected_loss,
                    metric_keys.MetricKeys.LOSS + '/head1':
                    expected_loss_head1,
                    metric_keys.MetricKeys.LOSS + '/head2':
                    expected_loss_head2,
                }, summary_str, tol)
Example #17
  def test_predict_two_heads_logits_dict(self):
    """Tests predict with logits as dict."""
    head1 = multi_label_head.MultiLabelHead(n_classes=2, name='head1')
    head2 = multi_label_head.MultiLabelHead(n_classes=3, name='head2')
    multi_head = multi_head_lib.MultiHead([head1, head2])

    logits = {
        'head1': np.array([[-1., 1.], [-1.5, 1.]], dtype=np.float32),
        'head2': np.array([[2., -2., 2.], [-3., 2., -2.]], dtype=np.float32)
    }
    expected_probabilities = {
        'head1': nn.sigmoid(logits['head1']),
        'head2': nn.sigmoid(logits['head2']),
    }
    pred_keys = prediction_keys.PredictionKeys

    predictions = multi_head.predictions(logits)
    self.assertAllClose(
        logits['head1'],
        self.evaluate(predictions[('head1', pred_keys.LOGITS)]))
    self.assertAllClose(
        logits['head2'],
        self.evaluate(predictions[('head2', pred_keys.LOGITS)]))
    self.assertAllClose(
        expected_probabilities['head1'],
        self.evaluate(predictions[('head1', pred_keys.PROBABILITIES)]))
    self.assertAllClose(
        expected_probabilities['head2'],
        self.evaluate(predictions[('head2', pred_keys.PROBABILITIES)]))
    if context.executing_eagerly():
      return

    spec = multi_head.create_estimator_spec(
        features={'x': np.array(((42,),), dtype=np.int32)},
        mode=model_fn.ModeKeys.PREDICT,
        logits=logits)
    self.assertItemsEqual(
        (test_lib._DEFAULT_SERVING_KEY, 'predict', 'head1',
         'head1/classification', 'head1/predict', 'head2',
         'head2/classification', 'head2/predict'), spec.export_outputs.keys())
    # Assert predictions and export_outputs.
    with self.cached_session() as sess:
      test_lib._initialize_variables(self, spec.scaffold)
      self.assertIsNone(spec.scaffold.summary_op)
      predictions = sess.run(spec.predictions)
      self.assertAllClose(
          logits['head1'],
          predictions[('head1', pred_keys.LOGITS)])
      self.assertAllClose(
          logits['head2'],
          predictions[('head2', pred_keys.LOGITS)])
      self.assertAllClose(
          expected_probabilities['head1'],
          predictions[('head1', pred_keys.PROBABILITIES)])
      self.assertAllClose(
          expected_probabilities['head2'],
          predictions[('head2', pred_keys.PROBABILITIES)])

      self.assertAllClose(
          expected_probabilities['head1'],
          sess.run(spec.export_outputs[test_lib._DEFAULT_SERVING_KEY].scores))
      self.assertAllClose(
          expected_probabilities['head1'],
          sess.run(spec.export_outputs['head1'].scores))
      self.assertAllClose(
          expected_probabilities['head2'],
          sess.run(spec.export_outputs['head2'].scores))
      self.assertAllClose(
          expected_probabilities['head1'],
          sess.run(
              spec.export_outputs['predict'].outputs['head1/probabilities']))
      self.assertAllClose(
          expected_probabilities['head2'],
          sess.run(
              spec.export_outputs['predict'].outputs['head2/probabilities']))
      self.assertAllClose(
          expected_probabilities['head1'],
          sess.run(
              spec.export_outputs['head1/predict'].outputs['probabilities']))
      self.assertAllClose(
          expected_probabilities['head2'],
          sess.run(
              spec.export_outputs['head2/predict'].outputs['probabilities']))
Example #18
 def test_head_name_missing(self):
   head1 = multi_label_head.MultiLabelHead(n_classes=2, name='head1')
   head2 = multi_label_head.MultiLabelHead(n_classes=3)
   with self.assertRaisesRegexp(
       ValueError, r'All given heads must have name specified\.'):
     multi_head_lib.MultiHead([head1, head2])
Example #19
 def test_no_heads(self):
   with self.assertRaisesRegexp(
       ValueError, r'Must specify heads\. Given: \[\]'):
     multi_head_lib.MultiHead(heads=[])
Example #20
def simple_multi_head(export_path, eval_export_path):
    """Trains and exports a simple multi-headed model."""
    def eval_input_receiver_fn():
        """Eval input receiver function."""
        serialized_tf_example = tf.compat.v1.placeholder(
            dtype=tf.string, shape=[None], name='input_example_tensor')

        language = tf.feature_column.categorical_column_with_vocabulary_list(
            'language', ['english', 'chinese', 'other'])
        age = tf.feature_column.numeric_column('age')
        english_label = tf.feature_column.numeric_column('english_label')
        chinese_label = tf.feature_column.numeric_column('chinese_label')
        other_label = tf.feature_column.numeric_column('other_label')
        all_features = [
            age, language, english_label, chinese_label, other_label
        ]
        feature_spec = tf.feature_column.make_parse_example_spec(all_features)
        receiver_tensors = {'examples': serialized_tf_example}
        features = tf.io.parse_example(serialized=serialized_tf_example,
                                       features=feature_spec)

        labels = {
            'english_head': features['english_label'],
            'chinese_head': features['chinese_label'],
            'other_head': features['other_label'],
        }

        return export.EvalInputReceiver(features=features,
                                        receiver_tensors=receiver_tensors,
                                        labels=labels)

    def input_fn():
        """Train input function."""
        labels = {
            'english_head': tf.constant([[1], [1], [0], [0], [0], [0]]),
            'chinese_head': tf.constant([[0], [0], [1], [1], [0], [0]]),
            'other_head': tf.constant([[0], [0], [0], [0], [1], [1]])
        }
        features = {
            'age':
            tf.constant([[1], [2], [3], [4], [5], [6]]),
            'language':
            tf.SparseTensor(values=[
                'english', 'english', 'chinese', 'chinese', 'other', 'other'
            ],
                            indices=[[0, 0], [1, 0], [2, 0], [3, 0], [4, 0],
                                     [5, 0]],
                            dense_shape=[6, 1]),
        }
        return features, labels

    language = tf.feature_column.categorical_column_with_vocabulary_list(
        'language', ['english', 'chinese', 'other'])
    age = tf.feature_column.numeric_column('age')
    all_features = [age, language]
    feature_spec = tf.feature_column.make_parse_example_spec(all_features)

    # TODO(b/130299739): Update with tf.estimator.BinaryClassHead and
    #   tf.estimator.MultiHead
    english_head = binary_class_head.BinaryClassHead(name='english_head')
    chinese_head = binary_class_head.BinaryClassHead(name='chinese_head')
    other_head = binary_class_head.BinaryClassHead(name='other_head')
    combined_head = multi_head.MultiHead(
        [english_head, chinese_head, other_head])

    estimator = tf_compat_v1_estimator.DNNLinearCombinedEstimator(
        head=combined_head,
        dnn_feature_columns=[],
        dnn_optimizer=tf.compat.v1.train.AdagradOptimizer(learning_rate=0.01),
        dnn_hidden_units=[],
        linear_feature_columns=[language, age],
        linear_optimizer=tf.compat.v1.train.FtrlOptimizer(learning_rate=0.05))
    estimator.train(input_fn=input_fn, steps=1000)

    return util.export_model_and_eval_model(
        estimator=estimator,
        serving_input_receiver_fn=(
            tf_estimator.export.build_parsing_serving_input_receiver_fn(
                feature_spec)),
        eval_input_receiver_fn=eval_input_receiver_fn,
        export_path=export_path,
        eval_export_path=eval_export_path)
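
A hypothetical invocation of simple_multi_head, assuming the surrounding module is importable and that util.export_model_and_eval_model returns the two export directories (the temp paths below are placeholders, not from the original):

import os
import tempfile

base_dir = tempfile.mkdtemp()  # placeholder output location
export_dir, eval_export_dir = simple_multi_head(
    export_path=os.path.join(base_dir, 'export'),
    eval_export_path=os.path.join(base_dir, 'eval_export'))
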
Example #21
  def test_eval_two_heads_with_weights(self):
    head1 = multi_label_head.MultiLabelHead(n_classes=2, name='head1')
    head2 = multi_label_head.MultiLabelHead(n_classes=3, name='head2')
    multi_head = multi_head_lib.MultiHead(
        [head1, head2], head_weights=[1., 2.])

    logits = {
        'head1': np.array([[-10., 10.], [-15., 10.]], dtype=np.float32),
        'head2': np.array([[20., -20., 20.], [-30., 20., -20.]],
                          dtype=np.float32),
    }
    labels = {
        'head1': np.array([[1, 0], [1, 1]], dtype=np.int64),
        'head2': np.array([[0, 1, 0], [1, 1, 0]], dtype=np.int64),
    }
    features = {'x': np.array(((42,),), dtype=np.int32)}
    # For large logits, sigmoid cross entropy loss is approximated as:
    # loss = labels * (logits < 0) * (-logits) +
    #        (1 - labels) * (logits > 0) * logits =>
    # head1: expected_unweighted_loss = [[10., 10.], [15., 0.]]
    # loss = ((10 + 10) / 2 + (15 + 0) / 2) / 2 = 8.75
    # head2: expected_unweighted_loss = [[20., 20., 20.], [30., 0., 0]]
    # loss = ((20 + 20 + 20) / 3 + (30 + 0 + 0) / 3) / 2 = 15
    expected_loss_head1 = 8.75
    expected_loss_head2 = 15.
    expected_loss = 1. * expected_loss_head1 + 2. * expected_loss_head2
    tol = 1e-3
    keys = metric_keys.MetricKeys
    expected_metrics = {
        keys.LOSS + '/head1': expected_loss_head1,
        keys.LOSS + '/head2': expected_loss_head2,
        # Average loss over examples.
        keys.LOSS_MEAN + '/head1': expected_loss_head1,
        keys.LOSS_MEAN + '/head2': expected_loss_head2,
        # auc and auc_pr cannot be reliably calculated for only 4-6 samples, but
        # this assert tests that the algorithm remains consistent.
        # TODO(yhliang): update metrics
        # keys.AUC + '/head1': 0.1667,
        # keys.AUC + '/head2': 0.3333,
        # keys.AUC_PR + '/head1': 0.6667,
        # keys.AUC_PR + '/head2': 0.5000,
    }

    if context.executing_eagerly():
      loss = multi_head.loss(
          logits, labels, features=features, mode=model_fn.ModeKeys.EVAL)
      self.assertIsNotNone(loss)
      self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol)

      eval_metrics = multi_head.metrics()
      updated_metrics = multi_head.update_metrics(
          eval_metrics, features, logits, labels)
      self.assertItemsEqual(expected_metrics.keys(), updated_metrics.keys())
      self.assertAllClose(
          expected_metrics,
          {k: updated_metrics[k].result() for k in updated_metrics},
          rtol=tol,
          atol=tol)
      return

    spec = multi_head.create_estimator_spec(
        features=features,
        mode=model_fn.ModeKeys.EVAL,
        logits=logits,
        labels=labels)
    # Assert spec contains expected tensors.
    self.assertIsNotNone(spec.loss)
    self.assertItemsEqual(expected_metrics.keys(), spec.eval_metric_ops.keys())
    self.assertIsNone(spec.train_op)
    self.assertIsNone(spec.export_outputs)
    test_lib._assert_no_hooks(self, spec)
    # Assert predictions, loss, and metrics.
    with self.cached_session() as sess:
      test_lib._initialize_variables(self, spec.scaffold)
      self.assertIsNone(spec.scaffold.summary_op)
      value_ops = {k: spec.eval_metric_ops[k][0] for k in spec.eval_metric_ops}
      update_ops = {k: spec.eval_metric_ops[k][1] for k in spec.eval_metric_ops}
      loss, _ = sess.run((spec.loss, update_ops))
      self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol)
      # Check results of value ops (in `metrics`).
      self.assertAllClose(
          expected_metrics, {k: value_ops[k].eval() for k in value_ops},
          rtol=tol,
          atol=tol)
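
The eval_metric_ops exercised above follow the TF1 (value_op, update_op) convention: running the update op accumulates state, and the value op then reads the aggregate idempotently. A minimal standalone sketch using tf.compat.v1.metrics.mean:

import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

values = tf.constant([8.75, 15.0])
mean_value, update_op = tf.metrics.mean(values)
with tf.Session() as sess:
    sess.run(tf.local_variables_initializer())  # metrics use local variables
    sess.run(update_op)                          # accumulate the batch
    print(sess.run(mean_value))                  # 11.875
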
Example #22
    def test_build_ensemble_spec(
            self,
            want_logits,
            want_loss=None,
            want_adanet_loss=None,
            want_ensemble_trainable_vars=None,
            adanet_lambda=0.,
            adanet_beta=0.,
            ensemble_spec_fn=lambda: None,
            use_bias=False,
            use_logits_last_layer=False,
            mixture_weight_type=MixtureWeightType.MATRIX,
            mixture_weight_initializer=tf_compat.v1.zeros_initializer(),
            warm_start_mixture_weights=True,
            subnetwork_builder_class=_Builder,
            mode=tf.estimator.ModeKeys.TRAIN,
            multi_head=False,
            want_subnetwork_trainable_vars=2):
        seed = 64

        if multi_head:
            head = multi_head_lib.MultiHead(heads=[
                binary_class_head.BinaryClassHead(
                    name="head1", loss_reduction=tf_compat.SUM),
                binary_class_head.BinaryClassHead(name="head2",
                                                  loss_reduction=tf_compat.SUM)
            ])
        else:
            head = binary_class_head.BinaryClassHead(
                loss_reduction=tf_compat.SUM)
        builder = _EnsembleBuilder(head=head)

        def _subnetwork_train_op_fn(loss, var_list):
            self.assertLen(var_list, want_subnetwork_trainable_vars)
            self.assertEqual(
                var_list,
                tf_compat.v1.get_collection(
                    tf_compat.v1.GraphKeys.TRAINABLE_VARIABLES))
            # Subnetworks get iteration steps instead of global steps.
            self.assertEqual("subnetwork_test/iteration_step",
                             tf_compat.v1.train.get_global_step().op.name)

            # Subnetworks get scoped summaries.
            self.assertEqual("fake_scalar",
                             tf_compat.v1.summary.scalar("scalar", 1.))
            self.assertEqual("fake_image",
                             tf_compat.v1.summary.image("image", 1.))
            self.assertEqual("fake_histogram",
                             tf_compat.v1.summary.histogram("histogram", 1.))
            self.assertEqual("fake_audio",
                             tf_compat.v1.summary.audio("audio", 1., 1.))
            optimizer = tf_compat.v1.train.GradientDescentOptimizer(
                learning_rate=.1)
            return optimizer.minimize(loss, var_list=var_list)

        def _mixture_weights_train_op_fn(loss, var_list):
            self.assertLen(var_list, want_ensemble_trainable_vars)
            self.assertEqual(
                var_list,
                tf_compat.v1.get_collection(
                    tf_compat.v1.GraphKeys.TRAINABLE_VARIABLES))
            # Subnetworks get iteration steps instead of global steps.
            self.assertEqual("ensemble_test/iteration_step",
                             tf_compat.v1.train.get_global_step().op.name)

            # Subnetworks get scoped summaries.
            self.assertEqual("fake_scalar",
                             tf_compat.v1.summary.scalar("scalar", 1.))
            self.assertEqual("fake_image",
                             tf_compat.v1.summary.image("image", 1.))
            self.assertEqual("fake_histogram",
                             tf_compat.v1.summary.histogram("histogram", 1.))
            self.assertEqual("fake_audio",
                             tf_compat.v1.summary.audio("audio", 1., 1.))
            optimizer = tf_compat.v1.train.GradientDescentOptimizer(
                learning_rate=.1)
            return optimizer.minimize(loss, var_list=var_list)

        previous_ensemble = None
        previous_ensemble_spec = ensemble_spec_fn()
        if previous_ensemble_spec:
            previous_ensemble = previous_ensemble_spec.ensemble

        subnetwork_manager = _SubnetworkManager(head)
        subnetwork_builder = subnetwork_builder_class(
            _subnetwork_train_op_fn,
            _mixture_weights_train_op_fn,
            use_logits_last_layer,
            seed,
            multi_head=multi_head)

        with tf.Graph().as_default() as g:
            # A trainable variable to later verify that creating models does not
            # affect the global variables collection.
            _ = tf_compat.v1.get_variable("some_var", 0., trainable=True)

            features = {"x": tf.constant([[1.], [2.]])}
            if multi_head:
                labels = {
                    "head1": tf.constant([0, 1]),
                    "head2": tf.constant([0, 1])
                }
            else:
                labels = tf.constant([0, 1])

            subnetwork_spec = subnetwork_manager.build_subnetwork_spec(
                name="test",
                subnetwork_builder=subnetwork_builder,
                iteration_step=tf_compat.v1.train.get_or_create_global_step(),
                summary=_FakeSummary(),
                features=features,
                mode=mode,
                labels=labels,
                previous_ensemble=previous_ensemble)
            ensemble_spec = builder.build_ensemble_spec(
                # Note: when previous_ensemble_spec is not None and
                # warm_start_mixture_weights is True, we need to make sure that
                # the bias and mixture weights are already saved to the
                # checkpoint_dir.
                name="test",
                previous_ensemble_spec=previous_ensemble_spec,
                candidate=EnsembleCandidate("foo", [subnetwork_builder], None),
                ensembler=ComplexityRegularizedEnsembler(
                    mixture_weight_type=mixture_weight_type,
                    mixture_weight_initializer=mixture_weight_initializer,
                    warm_start_mixture_weights=warm_start_mixture_weights,
                    model_dir=self.test_subdirectory,
                    adanet_lambda=adanet_lambda,
                    adanet_beta=adanet_beta,
                    use_bias=use_bias),
                subnetwork_specs=[subnetwork_spec],
                summary=_FakeSummary(),
                features=features,
                iteration_number=1,
                iteration_step=tf_compat.v1.train.get_or_create_global_step(),
                labels=labels,
                mode=mode)

            with tf_compat.v1.Session(graph=g).as_default() as sess:
                sess.run(tf_compat.v1.global_variables_initializer())

                # Equals the number of subnetwork and ensemble trainable variables,
                # plus the one 'some_var' created earlier.
                self.assertLen(
                    tf_compat.v1.trainable_variables(),
                    want_subnetwork_trainable_vars +
                    want_ensemble_trainable_vars + 1)

                # Get the real global step outside a subnetwork's context.
                self.assertEqual("global_step",
                                 tf_compat.v1.train.get_global_step().op.name)
                self.assertEqual("global_step",
                                 train.get_global_step().op.name)
                self.assertEqual("global_step",
                                 tf_v1.train.get_global_step().op.name)
                self.assertEqual("global_step",
                                 training_util.get_global_step().op.name)
                self.assertEqual(
                    "global_step",
                    tf_compat.v1.train.get_or_create_global_step().op.name)
                self.assertEqual("global_step",
                                 train.get_or_create_global_step().op.name)
                self.assertEqual(
                    "global_step",
                    tf_v1.train.get_or_create_global_step().op.name)
                self.assertEqual(
                    "global_step",
                    training_util.get_or_create_global_step().op.name)

                # Get global tf.summary outside a subnetwork's context.
                self.assertNotEqual("fake_scalar",
                                    tf_compat.v1.summary.scalar("scalar", 1.))
                self.assertNotEqual("fake_image",
                                    tf_compat.v1.summary.image("image", 1.))
                self.assertNotEqual(
                    "fake_histogram",
                    tf_compat.v1.summary.histogram("histogram", 1.))
                self.assertNotEqual(
                    "fake_audio", tf_compat.v1.summary.audio("audio", 1., 1.))

                if mode == tf.estimator.ModeKeys.PREDICT:
                    self.assertAllClose(want_logits,
                                        sess.run(
                                            ensemble_spec.ensemble.logits),
                                        atol=1e-3)
                    self.assertIsNone(ensemble_spec.loss)
                    self.assertIsNone(ensemble_spec.adanet_loss)
                    self.assertIsNone(ensemble_spec.train_op)
                    self.assertIsNotNone(ensemble_spec.export_outputs)
                    return

                # Verify that train_op works: the loss before training should be
                # greater than the loss after a few train ops.
                loss = sess.run(ensemble_spec.loss)
                train_op = tf.group(subnetwork_spec.train_op.train_op,
                                    ensemble_spec.train_op.train_op)
                for _ in range(3):
                    sess.run(train_op)
                self.assertGreater(loss, sess.run(ensemble_spec.loss))

                self.assertAllClose(want_logits,
                                    sess.run(ensemble_spec.ensemble.logits),
                                    atol=1e-3)

                # Bias should learn a non-zero value when used.
                bias = sess.run(ensemble_spec.ensemble.bias)
                if isinstance(bias, dict):
                    bias = sum(abs(b) for b in bias.values())
                if use_bias:
                    self.assertNotEqual(0., bias)
                else:
                    self.assertAlmostEqual(0., bias)

                self.assertAlmostEqual(want_loss,
                                       sess.run(ensemble_spec.loss),
                                       places=3)
                self.assertAlmostEqual(want_adanet_loss,
                                       sess.run(ensemble_spec.adanet_loss),
                                       places=3)
Example #23
    def test_build_ensemble_spec(
            self,
            want_logits,
            want_loss=None,
            want_adanet_loss=None,
            want_ensemble_trainable_vars=None,
            adanet_lambda=0.,
            adanet_beta=0.,
            ensemble_spec_fn=lambda: None,
            use_bias=False,
            use_logits_last_layer=False,
            mixture_weight_type=MixtureWeightType.MATRIX,
            mixture_weight_initializer=tf_compat.v1.zeros_initializer(),
            warm_start_mixture_weights=True,
            subnetwork_builder_class=_Builder,
            mode=tf.estimator.ModeKeys.TRAIN,
            multi_head=False,
            want_subnetwork_trainable_vars=2,
            ensembler_class=ComplexityRegularizedEnsembler,
            my_ensemble_index=None,
            want_replay_indices=None,
            want_predictions=None,
            export_subnetworks=False,
            previous_ensemble_spec=None,
            previous_iteration_checkpoint=None):
        seed = 64

        if multi_head:
            head = multi_head_lib.MultiHead(heads=[
                binary_class_head.BinaryClassHead(
                    name="head1", loss_reduction=tf_compat.SUM),
                binary_class_head.BinaryClassHead(name="head2",
                                                  loss_reduction=tf_compat.SUM)
            ])
        else:
            head = binary_class_head.BinaryClassHead(
                loss_reduction=tf_compat.SUM)
        builder = _EnsembleBuilder(
            head=head,
            export_subnetwork_logits=export_subnetworks,
            export_subnetwork_last_layer=export_subnetworks)

        def _subnetwork_train_op_fn(loss, var_list):
            self.assertLen(var_list, want_subnetwork_trainable_vars)
            self.assertEqual(
                var_list,
                tf_compat.v1.get_collection(
                    tf_compat.v1.GraphKeys.TRAINABLE_VARIABLES))
            # Subnetworks get iteration steps instead of global steps.
            self.assertEqual("subnetwork_test/iteration_step",
                             tf_compat.v1.train.get_global_step().op.name)

            # Subnetworks get scoped summaries.
            self.assertEqual("fake_scalar",
                             tf_compat.v1.summary.scalar("scalar", 1.))
            self.assertEqual("fake_image",
                             tf_compat.v1.summary.image("image", 1.))
            self.assertEqual("fake_histogram",
                             tf_compat.v1.summary.histogram("histogram", 1.))
            self.assertEqual("fake_audio",
                             tf_compat.v1.summary.audio("audio", 1., 1.))
            optimizer = tf_compat.v1.train.GradientDescentOptimizer(
                learning_rate=.1)
            return optimizer.minimize(loss, var_list=var_list)

        def _mixture_weights_train_op_fn(loss, var_list):
            self.assertLen(var_list, want_ensemble_trainable_vars)
            self.assertEqual(
                var_list,
                tf_compat.v1.get_collection(
                    tf_compat.v1.GraphKeys.TRAINABLE_VARIABLES))
            # Subnetworks get iteration steps instead of global steps.
            self.assertEqual("ensemble_test/iteration_step",
                             tf_compat.v1.train.get_global_step().op.name)

            # Subnetworks get scoped summaries.
            self.assertEqual("fake_scalar",
                             tf_compat.v1.summary.scalar("scalar", 1.))
            self.assertEqual("fake_image",
                             tf_compat.v1.summary.image("image", 1.))
            self.assertEqual("fake_histogram",
                             tf_compat.v1.summary.histogram("histogram", 1.))
            self.assertEqual("fake_audio",
                             tf_compat.v1.summary.audio("audio", 1., 1.))
            if not var_list:
                return tf.no_op()
            optimizer = tf_compat.v1.train.GradientDescentOptimizer(
                learning_rate=.1)
            return optimizer.minimize(loss, var_list=var_list)

        previous_ensemble = None
        previous_ensemble_spec = ensemble_spec_fn()
        if previous_ensemble_spec:
            previous_ensemble = previous_ensemble_spec.ensemble

        subnetwork_manager = _SubnetworkManager(head)
        subnetwork_builder = subnetwork_builder_class(
            _subnetwork_train_op_fn,
            _mixture_weights_train_op_fn,
            use_logits_last_layer,
            seed,
            multi_head=multi_head)

        with tf.Graph().as_default() as g:
            tf_compat.v1.train.get_or_create_global_step()
            # A trainable variable to later verify that creating models does not
            # affect the global variables collection.
            _ = tf_compat.v1.get_variable("some_var", shape=0, trainable=True)

            features = {"x": tf.constant([[1.], [2.]])}
            if multi_head:
                labels = {
                    "head1": tf.constant([0, 1]),
                    "head2": tf.constant([0, 1])
                }
            else:
                labels = tf.constant([0, 1])

            session_config = tf.compat.v1.ConfigProto(
                gpu_options=tf.compat.v1.GPUOptions(allow_growth=True))

            subnetwork_spec = subnetwork_manager.build_subnetwork_spec(
                name="test",
                subnetwork_builder=subnetwork_builder,
                summary=_FakeSummary(),
                features=features,
                mode=mode,
                labels=labels,
                previous_ensemble=previous_ensemble)
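            # The returned spec bundles the subnetwork's outputs (logits and
            # last layer) together with its train op under the "test" name.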
            ensembler_kwargs = {}
            if ensembler_class is ComplexityRegularizedEnsembler:
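                # Knobs for the complexity-regularized (AdaNet-style)
                # objective: adanet_lambda and adanet_beta scale the
                # complexity penalty, and use_bias adds a learned bias term.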
                ensembler_kwargs.update({
                    "mixture_weight_type": mixture_weight_type,
                    "mixture_weight_initializer": mixture_weight_initializer,
                    "warm_start_mixture_weights": warm_start_mixture_weights,
                    "model_dir": self.test_subdirectory,
                    "adanet_lambda": adanet_lambda,
                    "adanet_beta": adanet_beta,
                    "use_bias": use_bias
                })
            if ensembler_class is MeanEnsembler:
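                # The mean ensembler averages subnetwork logits; this flag
                # additionally exposes the mean of their last layers in the
                # ensemble's predictions.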
                ensembler_kwargs.update(
                    {"add_mean_last_layer_predictions": True})
            ensemble_spec = builder.build_ensemble_spec(
                # Note: when previous_ensemble_spec is not None and
                # warm_start_mixture_weights is True, the bias and mixture
                # weights must already be saved to the checkpoint directory.
                name="test",
                previous_ensemble_spec=previous_ensemble_spec,
                candidate=EnsembleCandidate("foo", [subnetwork_builder], None),
                ensembler=ensembler_class(**ensembler_kwargs),
                subnetwork_specs=[subnetwork_spec],
                summary=_FakeSummary(),
                features=features,
                iteration_number=1,
                labels=labels,
                my_ensemble_index=my_ensemble_index,
                mode=mode,
                previous_iteration_checkpoint=previous_iteration_checkpoint)

            if want_replay_indices:
                self.assertAllEqual(want_replay_indices,
                                    ensemble_spec.architecture.replay_indices)

            with tf_compat.v1.Session(
                    graph=g, config=session_config).as_default() as sess:
                sess.run(tf_compat.v1.global_variables_initializer())

                # Equals the number of subnetwork and ensemble trainable variables,
                # plus the one 'some_var' created earlier.
                self.assertLen(
                    tf_compat.v1.trainable_variables(),
                    want_subnetwork_trainable_vars +
                    want_ensemble_trainable_vars + 1)

                # Get the real global step outside a subnetwork's context.
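                # All four aliases below (tf_compat.v1.train, train,
                # tf_v1.train, and training_util) should resolve to the same
                # underlying global step op.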
                self.assertEqual("global_step",
                                 tf_compat.v1.train.get_global_step().op.name)
                self.assertEqual("global_step",
                                 train.get_global_step().op.name)
                self.assertEqual("global_step",
                                 tf_v1.train.get_global_step().op.name)
                self.assertEqual("global_step",
                                 training_util.get_global_step().op.name)
                self.assertEqual(
                    "global_step",
                    tf_compat.v1.train.get_or_create_global_step().op.name)
                self.assertEqual("global_step",
                                 train.get_or_create_global_step().op.name)
                self.assertEqual(
                    "global_step",
                    tf_v1.train.get_or_create_global_step().op.name)
                self.assertEqual(
                    "global_step",
                    training_util.get_or_create_global_step().op.name)

                # Outside a subnetwork's context, tf.summary calls create real
                # global summary ops rather than the fake scoped ones.
                self.assertNotEqual("fake_scalar",
                                    tf_compat.v1.summary.scalar("scalar", 1.))
                self.assertNotEqual("fake_image",
                                    tf_compat.v1.summary.image("image", 1.))
                self.assertNotEqual(
                    "fake_histogram",
                    tf_compat.v1.summary.histogram("histogram", 1.))
                self.assertNotEqual(
                    "fake_audio", tf_compat.v1.summary.audio("audio", 1., 1.))

                if mode == tf.estimator.ModeKeys.PREDICT:
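                    # In PREDICT mode the spec should only provide predictions
                    # and export outputs; loss and train ops must be absent.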
                    self.assertAllClose(want_logits,
                                        sess.run(
                                            ensemble_spec.ensemble.logits),
                                        atol=1e-3)
                    self.assertIsNone(ensemble_spec.loss)
                    self.assertIsNone(ensemble_spec.adanet_loss)
                    self.assertIsNone(ensemble_spec.train_op)
                    self.assertIsNotNone(ensemble_spec.export_outputs)
                    if not export_subnetworks:
                        return
                    if not multi_head:
                        subnetwork_logits = sess.run(
                            ensemble_spec.export_outputs[
                                _EnsembleBuilder.
                                _SUBNETWORK_LOGITS_EXPORT_SIGNATURE].outputs)
                        self.assertAllClose(
                            subnetwork_logits["test"],
                            sess.run(subnetwork_spec.subnetwork.logits))
                        subnetwork_last_layer = sess.run(
                            ensemble_spec.export_outputs[
                                _EnsembleBuilder.
                                _SUBNETWORK_LAST_LAYER_EXPORT_SIGNATURE].
                            outputs)
                        self.assertAllClose(
                            subnetwork_last_layer["test"],
                            sess.run(subnetwork_spec.subnetwork.last_layer))
                    else:
                        self.assertIn("subnetwork_logits_head2",
                                      ensemble_spec.export_outputs)
                        subnetwork_logits_head1 = sess.run(
                            ensemble_spec.
                            export_outputs["subnetwork_logits_head1"].outputs)
                        self.assertAllClose(
                            subnetwork_logits_head1["test"],
                            sess.run(
                                subnetwork_spec.subnetwork.logits["head1"]))
                        self.assertIn("subnetwork_logits_head2",
                                      ensemble_spec.export_outputs)
                        subnetwork_last_layer_head1 = sess.run(
                            ensemble_spec.export_outputs[
                                "subnetwork_last_layer_head1"].outputs)
                        self.assertAllClose(
                            subnetwork_last_layer_head1["test"],
                            sess.run(subnetwork_spec.subnetwork.
                                     last_layer["head1"]))
                    return

                # Verify that the train op works: the loss before training
                # should be greater than the loss after a few train steps.
                loss = sess.run(ensemble_spec.loss)
                train_op = tf.group(subnetwork_spec.train_op.train_op,
                                    ensemble_spec.train_op.train_op)
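                # Each grouped step updates both the subnetwork's weights and
                # the ensemble's mixture weights.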
                for _ in range(3):
                    sess.run(train_op)
                self.assertGreater(loss, sess.run(ensemble_spec.loss))

                self.assertAllClose(want_logits,
                                    sess.run(ensemble_spec.ensemble.logits),
                                    atol=1e-3)

                if ensembler_class is ComplexityRegularizedEnsembler:
                    # Bias should learn a non-zero value when used.
                    bias = sess.run(ensemble_spec.ensemble.bias)
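                    # With a multi-head ensemble the bias is a dict of
                    # per-head values; reduce it to a single magnitude.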
                    if isinstance(bias, dict):
                        bias = sum(abs(b) for b in bias.values())
                    if use_bias:
                        self.assertNotEqual(0., bias)
                    else:
                        self.assertAlmostEqual(0., bias)

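                # adanet_loss includes the complexity regularization penalty,
                # so it can differ from the plain objective loss.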
                self.assertAlmostEqual(want_loss,
                                       sess.run(ensemble_spec.loss),
                                       places=3)
                self.assertAlmostEqual(want_adanet_loss,
                                       sess.run(ensemble_spec.adanet_loss),
                                       places=3)

                if want_predictions:
                    self.assertAllClose(
                        want_predictions,
                        sess.run(ensemble_spec.ensemble.predictions),
                        atol=1e-3)