def _make_metrics(self,
                  metric_fn,
                  mode=tf.estimator.ModeKeys.EVAL,
                  multi_head=False,
                  sess=None):
    """Builds a "test" subnetwork and ensemble, then evaluates their metrics.

    Everything is constructed under `context.graph_mode()`. A single
    subnetwork spec and a single ensemble spec (scalar mixture weights) are
    built on a fixed two-example batch, their eval metrics are materialized
    with `call_eval_metrics`, variables are initialized, and the metric ops
    are run once before the metric value tensors are read back.

    Args:
      metric_fn: Forwarded to both `_EnsembleBuilder` and
        `_SubnetworkManager`.
      mode: Mode passed to `build_subnetwork_spec` and `build_ensemble_spec`.
      multi_head: When true, uses a `MultiHead` made of two binary-class heads
        named "head1" and "head2" with a labels dict keyed by those names;
        otherwise a single binary-class head with a plain labels tensor.
      sess: Optional session. When given, `sess.run` is used for evaluation
        instead of `self.evaluate`.

    Returns:
      A `(subnetwork_metrics, ensemble_metrics)` tuple of dicts, each mapping
      a metric key to the evaluated first element of that metric's entry.
    """

    def _no_op_train_op_fn(unused0, unused1):
        return tf.no_op()

    with context.graph_mode():
        if multi_head:
            head_names = ("head1", "head2")
            head = multi_head_lib.MultiHead(heads=[
                binary_class_head.BinaryClassHead(
                    name=head_name, loss_reduction=tf_compat.SUM)
                for head_name in head_names
            ])
            labels = {
                head_name: tf.constant([0, 1]) for head_name in head_names
            }
        else:
            head = binary_class_head.BinaryClassHead(
                loss_reduction=tf_compat.SUM)
            labels = tf.constant([0, 1])
        features = {"x": tf.constant([[1.], [2.]])}

        ensemble_builder = _EnsembleBuilder(head, metric_fn=metric_fn)
        manager = _SubnetworkManager(head, metric_fn=metric_fn)
        candidate_builder = _Builder(
            _no_op_train_op_fn,
            _no_op_train_op_fn,
            use_logits_last_layer=True)

        subnetwork_spec = manager.build_subnetwork_spec(
            name="test",
            subnetwork_builder=candidate_builder,
            summary=_FakeSummary(),
            features=features,
            mode=mode,
            labels=labels)

        candidate = EnsembleCandidate("foo", [candidate_builder], None)
        ensembler = ComplexityRegularizedEnsembler(
            mixture_weight_type=MixtureWeightType.SCALAR)
        ensemble_spec = ensemble_builder.build_ensemble_spec(
            name="test",
            candidate=candidate,
            ensembler=ensembler,
            subnetwork_specs=[subnetwork_spec],
            summary=_FakeSummary(),
            features=features,
            iteration_number=0,
            labels=labels,
            mode=mode)

        subnetwork_metric_ops = call_eval_metrics(subnetwork_spec.eval_metrics)
        ensemble_metric_ops = call_eval_metrics(ensemble_spec.eval_metrics)

        # Prefer the caller's session when one is supplied.
        run = self.evaluate if sess is None else sess.run

        run((tf_compat.v1.global_variables_initializer(),
             tf_compat.v1.local_variables_initializer()))
        # Run every metric entry once before reading the values back.
        run((subnetwork_metric_ops, ensemble_metric_ops))

        # Only the idempotent tensor part of each (tensor, op) metric tuple is
        # evaluated for the returned values.
        subnetwork_values = {
            key: run(subnetwork_metric_ops[key][0])
            for key in subnetwork_metric_ops
        }
        ensemble_values = {
            key: run(ensemble_metric_ops[key][0])
            for key in ensemble_metric_ops
        }
        return subnetwork_values, ensemble_values
def test_build_ensemble_spec(
        self,
        want_logits,
        want_loss=None,
        want_adanet_loss=None,
        want_ensemble_trainable_vars=None,
        adanet_lambda=0.,
        adanet_beta=0.,
        ensemble_spec_fn=lambda: None,
        use_bias=False,
        use_logits_last_layer=False,
        mixture_weight_type=MixtureWeightType.MATRIX,
        mixture_weight_initializer=tf_compat.v1.zeros_initializer(),
        warm_start_mixture_weights=True,
        subnetwork_builder_class=_Builder,
        mode=tf.estimator.ModeKeys.TRAIN,
        multi_head=False,
        want_subnetwork_trainable_vars=2):
    """Builds one subnetwork and one ensemble spec and checks the results.

    The test builds the specs in a fresh graph, then verifies that:
      * the train-op functions see only their own trainable variables, an
        iteration-scoped step, and the fake (scoped) summary functions;
      * outside those functions the real global step and real summaries are
        visible again;
      * in PREDICT mode there is no loss / train op but logits and export
        outputs exist;
      * otherwise three training steps reduce the loss, and the final logits,
        bias, loss and AdaNet loss match the expectations.

    NOTE(review): the arguments are presumably supplied by a parameterized
    test decorator that is outside this view -- confirm.

    Args:
      want_logits: Expected ensemble logits (compared with `atol=1e-3`).
      want_loss: Expected ensemble loss after training (3 decimal places).
        Not checked in PREDICT mode.
      want_adanet_loss: Expected AdaNet loss after training (3 decimal
        places). Not checked in PREDICT mode.
      want_ensemble_trainable_vars: Expected number of variables handed to
        the mixture-weights train-op function.
      adanet_lambda: Forwarded to `ComplexityRegularizedEnsembler`.
      adanet_beta: Forwarded to `ComplexityRegularizedEnsembler`.
      ensemble_spec_fn: Zero-argument callable returning the previous
        ensemble spec, or a falsy value when there is none.
      use_bias: Forwarded to `ComplexityRegularizedEnsembler`; also selects
        whether the learned bias is expected to be non-zero.
      use_logits_last_layer: Forwarded to the subnetwork builder.
      mixture_weight_type: Forwarded to `ComplexityRegularizedEnsembler`.
      mixture_weight_initializer: Forwarded to
        `ComplexityRegularizedEnsembler`.
      warm_start_mixture_weights: Forwarded to
        `ComplexityRegularizedEnsembler`.
      subnetwork_builder_class: Class used to construct the subnetwork
        builder.
      mode: Mode passed to both spec builders.
      multi_head: When true, uses a `MultiHead` of two binary-class heads
        ("head1", "head2") and a labels dict keyed by those names.
      want_subnetwork_trainable_vars: Expected number of variables handed to
        the subnetwork train-op function.
    """
    seed = 64

    if multi_head:
        head = multi_head_lib.MultiHead(heads=[
            binary_class_head.BinaryClassHead(
                name="head1", loss_reduction=tf_compat.SUM),
            binary_class_head.BinaryClassHead(
                name="head2", loss_reduction=tf_compat.SUM)
        ])
    else:
        head = binary_class_head.BinaryClassHead(loss_reduction=tf_compat.SUM)
    builder = _EnsembleBuilder(head=head)

    def _make_train_op_fn(want_num_vars, want_step_name):
        """Returns a train-op fn that asserts on its scoped environment."""

        def _train_op_fn(loss, var_list):
            # Only this candidate's own variables should be visible here.
            self.assertLen(var_list, want_num_vars)
            self.assertEqual(
                var_list,
                tf_compat.v1.get_collection(
                    tf_compat.v1.GraphKeys.TRAINABLE_VARIABLES))
            # Inside the train-op fn the "global step" is the candidate's
            # iteration step rather than the real global step.
            self.assertEqual(want_step_name,
                             tf_compat.v1.train.get_global_step().op.name)
            # Summary functions are replaced by the scoped fake summary.
            self.assertEqual("fake_scalar",
                             tf_compat.v1.summary.scalar("scalar", 1.))
            self.assertEqual("fake_image",
                             tf_compat.v1.summary.image("image", 1.))
            self.assertEqual("fake_histogram",
                             tf_compat.v1.summary.histogram("histogram", 1.))
            self.assertEqual("fake_audio",
                             tf_compat.v1.summary.audio("audio", 1., 1.))
            optimizer = tf_compat.v1.train.GradientDescentOptimizer(
                learning_rate=.1)
            return optimizer.minimize(loss, var_list=var_list)

        return _train_op_fn

    _subnetwork_train_op_fn = _make_train_op_fn(
        want_subnetwork_trainable_vars, "subnetwork_test/iteration_step")
    _mixture_weights_train_op_fn = _make_train_op_fn(
        want_ensemble_trainable_vars, "ensemble_test/iteration_step")

    previous_ensemble = None
    previous_ensemble_spec = ensemble_spec_fn()
    if previous_ensemble_spec:
        previous_ensemble = previous_ensemble_spec.ensemble

    subnetwork_manager = _SubnetworkManager(head)
    subnetwork_builder = subnetwork_builder_class(
        _subnetwork_train_op_fn,
        _mixture_weights_train_op_fn,
        use_logits_last_layer,
        seed,
        multi_head=multi_head)

    with tf.Graph().as_default() as g:
        # A trainable variable to later verify that creating models does not
        # affect the global variables collection.
        _ = tf_compat.v1.get_variable("some_var", 0., trainable=True)

        features = {"x": tf.constant([[1.], [2.]])}
        if multi_head:
            labels = {
                "head1": tf.constant([0, 1]),
                "head2": tf.constant([0, 1])
            }
        else:
            labels = tf.constant([0, 1])

        subnetwork_spec = subnetwork_manager.build_subnetwork_spec(
            name="test",
            subnetwork_builder=subnetwork_builder,
            iteration_step=tf_compat.v1.train.get_or_create_global_step(),
            summary=_FakeSummary(),
            features=features,
            mode=mode,
            labels=labels,
            previous_ensemble=previous_ensemble)
        ensemble_spec = builder.build_ensemble_spec(
            # Note: when ensemble_spec is not None and
            # warm_start_mixture_weights is True, we need to make sure that
            # the bias and mixture weights are already saved to the
            # checkpoint_dir.
            name="test",
            previous_ensemble_spec=previous_ensemble_spec,
            candidate=EnsembleCandidate("foo", [subnetwork_builder], None),
            ensembler=ComplexityRegularizedEnsembler(
                mixture_weight_type=mixture_weight_type,
                mixture_weight_initializer=mixture_weight_initializer,
                warm_start_mixture_weights=warm_start_mixture_weights,
                model_dir=self.test_subdirectory,
                adanet_lambda=adanet_lambda,
                adanet_beta=adanet_beta,
                use_bias=use_bias),
            subnetwork_specs=[subnetwork_spec],
            summary=_FakeSummary(),
            features=features,
            iteration_number=1,
            iteration_step=tf_compat.v1.train.get_or_create_global_step(),
            labels=labels,
            mode=mode)

        # Use the session itself as the context manager: unlike
        # `Session.as_default()`, this closes the session on exit (including
        # the early return in PREDICT mode) while still installing it as the
        # default session.
        with tf_compat.v1.Session(graph=g) as sess:
            sess.run(tf_compat.v1.global_variables_initializer())

            # Equals the number of subnetwork and ensemble trainable
            # variables, plus the one 'some_var' created earlier.
            self.assertLen(
                tf_compat.v1.trainable_variables(),
                want_subnetwork_trainable_vars + want_ensemble_trainable_vars
                + 1)

            # Get the real global step outside a subnetwork's context.
            self.assertEqual("global_step",
                             tf_compat.v1.train.get_global_step().op.name)
            self.assertEqual("global_step", train.get_global_step().op.name)
            self.assertEqual("global_step",
                             tf_v1.train.get_global_step().op.name)
            self.assertEqual("global_step",
                             training_util.get_global_step().op.name)
            self.assertEqual(
                "global_step",
                tf_compat.v1.train.get_or_create_global_step().op.name)
            self.assertEqual("global_step",
                             train.get_or_create_global_step().op.name)
            self.assertEqual("global_step",
                             tf_v1.train.get_or_create_global_step().op.name)
            self.assertEqual(
                "global_step",
                training_util.get_or_create_global_step().op.name)

            # Get global tf.summary outside a subnetwork's context.
            self.assertNotEqual("fake_scalar",
                                tf_compat.v1.summary.scalar("scalar", 1.))
            self.assertNotEqual("fake_image",
                                tf_compat.v1.summary.image("image", 1.))
            self.assertNotEqual(
                "fake_histogram",
                tf_compat.v1.summary.histogram("histogram", 1.))
            self.assertNotEqual(
                "fake_audio", tf_compat.v1.summary.audio("audio", 1., 1.))

            if mode == tf.estimator.ModeKeys.PREDICT:
                # Prediction has logits and export outputs but nothing to
                # train.
                self.assertAllClose(
                    want_logits,
                    sess.run(ensemble_spec.ensemble.logits),
                    atol=1e-3)
                self.assertIsNone(ensemble_spec.loss)
                self.assertIsNone(ensemble_spec.adanet_loss)
                self.assertIsNone(ensemble_spec.train_op)
                self.assertIsNotNone(ensemble_spec.export_outputs)
                return

            # Verify that train_op works, previous loss should be greater
            # than loss after a train op.
            loss = sess.run(ensemble_spec.loss)
            train_op = tf.group(subnetwork_spec.train_op.train_op,
                                ensemble_spec.train_op.train_op)
            for _ in range(3):
                sess.run(train_op)
            self.assertGreater(loss, sess.run(ensemble_spec.loss))

            self.assertAllClose(
                want_logits,
                sess.run(ensemble_spec.ensemble.logits),
                atol=1e-3)

            # Bias should learn a non-zero value when used.
            bias = sess.run(ensemble_spec.ensemble.bias)
            if isinstance(bias, dict):
                # Multi-head case: collapse per-head biases into one magnitude.
                bias = sum(abs(b) for b in bias.values())
            if use_bias:
                self.assertNotEqual(0., bias)
            else:
                self.assertAlmostEqual(0., bias)

            self.assertAlmostEqual(
                want_loss, sess.run(ensemble_spec.loss), places=3)
            self.assertAlmostEqual(
                want_adanet_loss,
                sess.run(ensemble_spec.adanet_loss),
                places=3)