示例#1
0
  def test_materialize_subnetwork_reports(self,
                                          input_fn,
                                          subnetwork_reports_fn,
                                          steps,
                                          iteration_number=0,
                                          included_subnetwork_names=None,
                                          want_materialized_reports=None):
    """Materializes subnetwork reports and compares them to expectations.

    Args:
      input_fn: Callable returning a `(features, labels)` pair. It is also
        handed to the `ReportMaterializer` under test.
      subnetwork_reports_fn: Callable taking `(features, labels)` and returning
        the subnetwork reports to materialize.
      steps: Forwarded as `steps` to the `ReportMaterializer`.
      iteration_number: Iteration passed to the materializer; every
        materialized report is expected to carry this same value.
      included_subnetwork_names: Forwarded unchanged to
        `materialize_subnetwork_reports`.
      want_materialized_reports: Expected reports. Each one must expose
        `name`, `hparams`, `attributes` and `metrics`.
    """

    def assert_mapping_matches(want_mapping, got_mapping, convert=None):
      """Asserts that both mappings share keys and hold matching values.

      Float expectations are compared with `assertAllClose`; everything else
      must be exactly equal. When `convert` is given it is applied to each
      materialized value before the comparison.
      """
      self.assertSetEqual(set(want_mapping.keys()), set(got_mapping.keys()))
      for key, want_value in want_mapping.items():
        got_value = got_mapping[key]
        if convert is not None:
          got_value = convert(got_value)
        if isinstance(want_value, float):
          self.assertAllClose(want_value, got_value)
        else:
          self.assertEqual(want_value, got_value)

    with context.graph_mode():
      tf.constant(0.)  # dummy op so that the session graph is never empty.
      features, labels = input_fn()
      subnetwork_reports = subnetwork_reports_fn(features, labels)
      with self.test_session() as sess:
        sess.run(tf_compat.v1.initializers.local_variables())
        report_materializer = ReportMaterializer(input_fn=input_fn, steps=steps)
        materialized_reports = (
            report_materializer.materialize_subnetwork_reports(
                sess, iteration_number, subnetwork_reports,
                included_subnetwork_names))
        self.assertEqual(
            len(want_materialized_reports), len(materialized_reports))
        reports_by_name = {
            report.name: report for report in materialized_reports
        }
        for want_report in want_materialized_reports:
          # Fail with a readable message instead of a bare KeyError when an
          # expected report was not materialized at all.
          self.assertIn(
              want_report.name, reports_by_name,
              "No materialized report named %r" % (want_report.name,))
          got_report = reports_by_name[want_report.name]
          self.assertEqual(iteration_number, got_report.iteration_number)

          # hparams are compared as materialized; attributes and metrics are
          # passed through `decode` before being compared.
          assert_mapping_matches(want_report.hparams, got_report.hparams)
          assert_mapping_matches(
              want_report.attributes, got_report.attributes, convert=decode)
          assert_mapping_matches(
              want_report.metrics, got_report.metrics, convert=decode)
示例#2
0
    def test_summaries(self):
        """Checks the summaries written while training a one-candidate estimator.

        Trains an ``Estimator`` built from a single "dnn" subnetwork builder
        for at most three steps, then looks up values by keyword (through
        ``tu.check_eventfile_for_keyword``) in the model directory, in the
        subnetwork subdirectory and in the ensemble subdirectory.
        """
        model_dir = self.test_subdirectory

        def logged(keyword, directory):
            # Thin wrapper that keeps the assertions below short.
            return tu.check_eventfile_for_keyword(keyword, directory)

        config = tf.estimator.RunConfig(
            tf_random_seed=42,
            log_step_count_steps=2,
            save_summary_steps=2,
            model_dir=model_dir,
        )
        generator = SimpleGenerator([_SimpleBuilder("dnn")])
        materializer = ReportMaterializer(
            input_fn=tu.dummy_input_fn([[1., 1.]], [[0.]]),
            steps=1,
        )
        head = regression_head.RegressionHead(
            loss_reduction=tf_compat.SUM_OVER_BATCH_SIZE)
        estimator = Estimator(
            head=head,
            subnetwork_generator=generator,
            report_materializer=materializer,
            max_iteration_steps=10,
            config=config,
        )
        estimator.train(
            input_fn=tu.dummy_input_fn([[1., 0.]], [[1.]]),
            max_steps=3,
        )

        # --- Values looked up in the top-level model directory. ---
        ensemble_loss = 1.52950
        top_level_loss = logged("loss", model_dir)
        self.assertAlmostEqual(ensemble_loss, top_level_loss, places=3)

        steps_per_sec = logged("global_step/sec", model_dir)
        self.assertIsNotNone(steps_per_sec)

        iteration = logged("iteration/adanet/iteration", model_dir)
        self.assertEqual(0., iteration)

        # --- Values looked up in the subnetwork subdirectory. ---
        subnetwork_subdir = os.path.join(model_dir, "subnetwork/t0_dnn")
        scalar = logged("scalar", subnetwork_subdir)
        self.assertAlmostEqual(3., scalar, places=3)

        image = logged("image", subnetwork_subdir)
        self.assertEqual((3, 3, 1), image)

        nested_scalar = logged("nested/scalar", subnetwork_subdir)
        self.assertAlmostEqual(5., nested_scalar, places=3)

        # --- Values looked up in the ensemble subdirectory. ---
        ensemble_subdir = os.path.join(
            model_dir, "ensemble/t0_dnn_grow_complexity_regularized")
        adanet_loss = logged(
            "adanet_loss/adanet/adanet_weighted_ensemble", ensemble_subdir)
        # This check is looser (places=1) than the top-level loss check.
        self.assertAlmostEqual(ensemble_loss, adanet_loss, places=1)

        complexity_regularization = logged(
            "complexity_regularization/adanet/adanet_weighted_ensemble",
            ensemble_subdir)
        self.assertAlmostEqual(0., complexity_regularization, places=3)

        mixture_weight_norm = logged(
            "mixture_weight_norms/adanet/"
            "adanet_weighted_ensemble/subnetwork_0",
            ensemble_subdir)
        self.assertAlmostEqual(1., mixture_weight_norm, places=3)