Exemplo n.º 1
0
 def test_intermixed_prior_graph(self):
     """Checks that intermixed ensemble search builds from the best prior trials.

     Runs in graph mode. Checkpoints are created only for trials 2, 3 and 5,
     the best three trials in the test data; if the generator picks any
     suboptimal trial the test fails. The resulting ensemble must produce
     `width` logits, and the graph must contain every node listed in
     `_NONADAPTIVE_GRAPH_NODES`.
     """
     # Single source of truth for the ensemble width; the final assertion
     # depends on it, so the two cannot silently drift apart.
     width = 2
     # Force graph mode
     with tf.compat.v1.Graph().as_default():
         spec = phoenix_spec_pb2.PhoenixSpec(
             problem_type=phoenix_spec_pb2.PhoenixSpec.DNN)
         spec.ensemble_spec.ensemble_search_type = (
             ensembling_spec_pb2.EnsemblingSpec.
             INTERMIXED_NONADAPTIVE_ENSEMBLE_SEARCH)
         spec.ensemble_spec.intermixed_search.width = width
         spec.ensemble_spec.intermixed_search.try_ensembling_every = 4
         spec.ensemble_spec.intermixed_search.num_trials_to_consider = 3
         spec.is_input_shared = True
         generator = prior_generator.PriorGenerator(
             phoenix_spec=spec,
             metadata=ml_metadata_db.MLMetaData(phoenix_spec=spec,
                                                study_name='',
                                                study_owner=''))
         fake_config = collections.namedtuple('RunConfig', ['model_dir'])
         # The model_dir suffix (10000) must be a multiple of
         # try_ensembling_every (4). Presumably the generator reads the current
         # trial id from it and only ensembles on such trials (confirm).
         run_config = fake_config(model_dir=flags.FLAGS.test_tmpdir +
                                  '/10000')
         tf.io.gfile.makedirs(run_config.model_dir)
         # Checkpoints are generated only for the best three trials. If the
         # generator chooses suboptimal (wrong) trials, the test will fail.
         self._create_checkpoint(['search_generator'], 2)
         self._create_checkpoint(['search_generator'], 3)
         self._create_checkpoint(['search_generator'], 5)
         logits, _ = generator.first_time_chief_generate(
             features={},
             input_layer_fn=lambda: None,
             trial_mode=trial_utils.TrialMode.ENSEMBLE_SEARCH,
             shared_input_tensor=tf.zeros([100, 32, 32, 3]),
             shared_lengths=None,
             logits_dimension=10,
             hparams={},
             run_config=run_config,
             is_training=True,
             trials=trial_utils.create_test_trials_intermixed(
                 flags.FLAGS.test_tmpdir))
         self.assertLen(logits, width)
         all_nodes = [
             node.name for node in
             tf.compat.v1.get_default_graph().as_graph_def().node
         ]
         self.assertAllInSet(_NONADAPTIVE_GRAPH_NODES, all_nodes)
Exemplo n.º 2
0
 def test_nonadaptive_prior(self, width, consider):
     """Checks that nonadaptive ensemble search picks up the best prior trials.

     Args:
       width: value assigned to `nonadaptive_search.width`.
       consider: value assigned to `nonadaptive_search.num_trials_to_consider`.

     The generated ensemble must expose `min(width, consider)` logits.
     """
     expected_num_logits = min(width, consider)
     # Build everything inside a fresh graph so the test runs in graph mode.
     with tf.compat.v1.Graph().as_default():
         phoenix_spec = phoenix_spec_pb2.PhoenixSpec(
             problem_type=phoenix_spec_pb2.PhoenixSpec.DNN)
         ensemble_spec = phoenix_spec.ensemble_spec
         ensemble_spec.ensemble_search_type = (
             ensembling_spec_pb2.EnsemblingSpec.NONADAPTIVE_ENSEMBLE_SEARCH)
         search_cfg = ensemble_spec.nonadaptive_search
         search_cfg.width = width
         search_cfg.num_trials_to_consider = consider
         phoenix_spec.is_input_shared = True

         metadata = ml_metadata_db.MLMetaData(
             phoenix_spec=phoenix_spec, study_name='', study_owner='')
         prior_gen = prior_generator.PriorGenerator(
             phoenix_spec=phoenix_spec, metadata=metadata)

         run_config_type = collections.namedtuple('RunConfig', ['model_dir'])
         run_config = run_config_type(
             model_dir=flags.FLAGS.test_tmpdir + '/10000')
         tf.io.gfile.makedirs(run_config.model_dir)

         # Only the three best trials receive checkpoints; selecting any other
         # (suboptimal) trial makes the test fail.
         for trial_id in (3, 4, 5):
             self._create_checkpoint(['search_generator'], trial_id)

         logits, _ = prior_gen.first_time_chief_generate(
             features={},
             input_layer_fn=lambda: None,
             trial_mode=trial_utils.TrialMode.ENSEMBLE_SEARCH,
             shared_input_tensor=tf.zeros([100, 32, 32, 3]),
             shared_lengths=None,
             logits_dimension=10,
             hparams={},
             run_config=run_config,
             is_training=True,
             trials=_create_trials(flags.FLAGS.test_tmpdir))
         self.assertLen(logits, expected_num_logits)