def test_materialize_subnetwork_reports(self,
                                        input_fn,
                                        subnetwork_reports_fn,
                                        steps,
                                        iteration_number=0,
                                        included_subnetwork_names=None,
                                        want_materialized_reports=None):
  """Checks that `ReportMaterializer` materializes the expected reports.

  Builds subnetwork reports from `input_fn` in graph mode, runs
  `ReportMaterializer.materialize_subnetwork_reports` inside a test session,
  and compares every expected report against the materialized report with
  the same name.

  Args:
    input_fn: Callable returning a `(features, labels)` pair. It is called
      once here and also handed to the `ReportMaterializer`.
    subnetwork_reports_fn: Callable mapping `(features, labels)` to the
      subnetwork reports that will be materialized.
    steps: Number of steps given to the `ReportMaterializer`.
    iteration_number: Iteration number passed to the materializer. Every
      materialized report must carry this same value.
    included_subnetwork_names: Forwarded unchanged to
      `materialize_subnetwork_reports`.
    want_materialized_reports: Expected materialized reports. `None` is
      treated as "no reports expected" rather than crashing on `len(None)`.
  """
  if want_materialized_reports is None:
    want_materialized_reports = []

  def assert_values_match(want_values, got_values, transform):
    """Asserts both dicts share keys and that transformed values match.

    Float expectations are compared approximately with `assertAllClose`;
    everything else must be exactly equal.
    """
    self.assertSetEqual(set(want_values.keys()), set(got_values.keys()))
    for key, want_value in want_values.items():
      got_value = transform(got_values[key])
      if isinstance(want_value, float):
        self.assertAllClose(want_value, got_value)
      else:
        self.assertEqual(want_value, got_value)

  with context.graph_mode():
    tf.constant(0.)  # dummy op so that the session graph is never empty.
    features, labels = input_fn()
    subnetwork_reports = subnetwork_reports_fn(features, labels)
    with self.test_session() as sess:
      sess.run(tf_compat.v1.initializers.local_variables())
      report_materializer = ReportMaterializer(input_fn=input_fn, steps=steps)
      materialized_reports = (
          report_materializer.materialize_subnetwork_reports(
              sess, iteration_number, subnetwork_reports,
              included_subnetwork_names))
      self.assertEqual(
          len(want_materialized_reports), len(materialized_reports))
      materialized_reports_dict = {
          blrm.name: blrm for blrm in materialized_reports
      }
      for want_materialized_report in want_materialized_reports:
        # Fail with a readable message instead of a bare KeyError when an
        # expected report was not materialized.
        self.assertIn(want_materialized_report.name, materialized_reports_dict)
        materialized_report = (
            materialized_reports_dict[want_materialized_report.name])
        self.assertEqual(iteration_number,
                         materialized_report.iteration_number)
        # hparams are compared as-is; attributes and metrics are stored in
        # encoded form and must be run through `decode` first.
        assert_values_match(want_materialized_report.hparams,
                            materialized_report.hparams,
                            lambda value: value)
        assert_values_match(want_materialized_report.attributes,
                            materialized_report.attributes, decode)
        assert_values_match(want_materialized_report.metrics,
                            materialized_report.metrics, decode)
def test_summaries(self):
  """Trains briefly and checks the summaries emitted to each directory.

  After three training steps, the model directory, the `t0_dnn` subnetwork
  directory and the grow-complexity-regularized ensemble directory must each
  contain the expected scalar/image summary values.
  """
  read_event = tu.check_eventfile_for_keyword
  model_dir = self.test_subdirectory

  config = tf.estimator.RunConfig(
      tf_random_seed=42,
      log_step_count_steps=2,
      save_summary_steps=2,
      model_dir=model_dir)
  generator = SimpleGenerator([_SimpleBuilder("dnn")])
  materializer = ReportMaterializer(
      input_fn=tu.dummy_input_fn([[1., 1.]], [[0.]]), steps=1)
  head = regression_head.RegressionHead(
      loss_reduction=tf_compat.SUM_OVER_BATCH_SIZE)
  estimator = Estimator(
      head=head,
      subnetwork_generator=generator,
      report_materializer=materializer,
      max_iteration_steps=10,
      config=config)
  estimator.train(
      input_fn=tu.dummy_input_fn([[1., 0.]], [[1.]]), max_steps=3)

  # Top-level estimator summaries.
  ensemble_loss = 1.52950
  self.assertAlmostEqual(
      ensemble_loss, read_event("loss", model_dir), places=3)
  self.assertIsNotNone(read_event("global_step/sec", model_dir))
  self.assertEqual(
      0., read_event("iteration/adanet/iteration", model_dir))

  # Summaries written by the single "dnn" subnetwork.
  subnetwork_dir = os.path.join(model_dir, "subnetwork/t0_dnn")
  self.assertAlmostEqual(
      3., read_event("scalar", subnetwork_dir), places=3)
  self.assertEqual((3, 3, 1), read_event("image", subnetwork_dir))
  self.assertAlmostEqual(
      5., read_event("nested/scalar", subnetwork_dir), places=3)

  # Summaries written by the ensemble: (keyword, expected value, places).
  ensemble_dir = os.path.join(
      model_dir, "ensemble/t0_dnn_grow_complexity_regularized")
  ensemble_expectations = (
      ("adanet_loss/adanet/adanet_weighted_ensemble", ensemble_loss, 1),
      ("complexity_regularization/adanet/adanet_weighted_ensemble", 0., 3),
      ("mixture_weight_norms/adanet/"
       "adanet_weighted_ensemble/subnetwork_0", 1., 3),
  )
  for keyword, expected, places in ensemble_expectations:
    self.assertAlmostEqual(
        expected, read_event(keyword, ensemble_dir), places=places)