コード例 #1
0
 def _test_get_best_trials():
     """Exercise `OracleClient.get_best_trials` through the chief oracle.

     A process whose `KERASTUNER_TUNER_ID` contains "chief" serves the oracle.
     Any other process creates ten trials, reports scores 0..9 for them, ends
     them, and then asks the chief for the best three trials.
     """
     hps = kt.HyperParameters()
     hps.Int("a", 0, 100, default=5)
     hps.Int("b", 0, 100, default=6)
     oracle = randomsearch.RandomSearchOracle(
         objective=kt.Objective("score", direction="max"),
         max_trials=10,
         hyperparameters=hps,
     )
     # `tmp_dir` is a free variable, presumably bound by the enclosing test.
     oracle._set_project_dir(tmp_dir, "untitled")
     tuner_id = os.environ["KERASTUNER_TUNER_ID"]
     if "chief" in tuner_id:
         oracle_chief.start_server(oracle)
     else:
         client = oracle_client.OracleClient(oracle)
         # Maps each trial id created by this process to the score it reported.
         trial_scores = {}
         for score in range(10):
             trial = client.create_trial(tuner_id)
             assert trial.status == "RUNNING"
             assert "a" in trial.hyperparameters.values
             assert "b" in trial.hyperparameters.values
             trial_id = trial.trial_id
             client.update_trial(trial_id, {"score": score})
             client.end_trial(trial_id)
             trial_scores[trial_id] = score

         # There must be no early return here: the checks below are the
         # purpose of this test.
         best_trials = client.get_best_trials(3)
         best_scores = [t.score for t in best_trials]
         # If several workers share the chief oracle, their trials may also
         # appear in the result. Only properties that hold regardless of
         # interleaving are asserted.
         assert len(best_trials) == 3
         assert best_scores == sorted(best_scores, reverse=True)
         # This process already finished a trial scoring 9, the largest score
         # any process in this test reports, so it must lead the ranking.
         assert best_scores[0] == 9
         # Trials created by this process must map back to the score it sent.
         for t in best_trials:
             if t.trial_id in trial_scores:
                 assert trial_scores[t.trial_id] == t.score
コード例 #2
0
    def __init__(
        self,
        oracle,
        hypermodel=None,
        directory=None,
        project_name=None,
        logger=None,
        overwrite=False,
    ):
        """Set up the tuner, its oracle, and (optionally) distributed mode.

        Args:
            oracle: Must be an instance of `oracle_module.Oracle`.
            hypermodel: Resolved through `hm_module.get_hypermodel`.
            directory: Base directory; defaults to ".".
            project_name: Project name; defaults to "untitled_project".
            logger: Stored on `self.logger`; not otherwise used here.
            overwrite: If True, an existing `self.project_dir` is deleted, the
                flag is forwarded to `oracle._set_project_dir`, and previously
                saved tuner state is not reloaded.

        Raises:
            ValueError: If `oracle` is not an `oracle_module.Oracle`.
        """
        # Ops and metadata
        self.directory = directory or "."
        self.project_name = project_name or "untitled_project"
        # `self.project_dir` is defined elsewhere on the class; presumably it
        # is derived from `directory` and `project_name`.
        if overwrite and tf.io.gfile.exists(self.project_dir):
            tf.io.gfile.rmtree(self.project_dir)

        if not isinstance(oracle, oracle_module.Oracle):
            raise ValueError(
                "Expected `oracle` argument to be an instance of `Oracle`. "
                f"Received: oracle={oracle} (of type {type(oracle)}).")
        self.oracle = oracle
        self.oracle._set_project_dir(self.directory,
                                     self.project_name,
                                     overwrite=overwrite)

        # Run in distributed mode.
        if dist_utils.is_chief_oracle():
            # Blocks forever.
            oracle_chief.start_server(self.oracle)
        elif dist_utils.has_chief_oracle():
            # Proxies requests to the chief oracle.
            self.oracle = oracle_client.OracleClient(self.oracle)

        # To support tuning distribution.
        self.tuner_id = os.environ.get("KERASTUNER_TUNER_ID", "tuner0")

        self.hypermodel = hm_module.get_hypermodel(hypermodel)

        # Logs etc
        self.logger = logger
        self._display = tuner_utils.Display(oracle=self.oracle)

        self._populate_initial_space()

        # Resume from previously saved tuner state unless told to overwrite.
        if not overwrite:
            tuner_fname = self._get_tuner_fname()
            if tf.io.gfile.exists(tuner_fname):
                tf.get_logger().info(
                    "Reloading Tuner from {}".format(tuner_fname))
                self.reload()
コード例 #3
0
 def _test_get_space():
     """Check that a client sees the search space defined on the oracle.

     The chief process only hosts the oracle. Other processes fetch the space
     through `OracleClient.get_space` and compare it with the single `Int`
     hyperparameter "a" (default 3) declared below.
     """
     space = kt.HyperParameters()
     space.Int("a", 0, 10, default=3)
     search_oracle = randomsearch.RandomSearchOracle(
         objective=kt.Objective("score", "max"),
         max_trials=10,
         hyperparameters=space,
     )
     search_oracle._set_project_dir(tmp_dir, "untitled")
     worker_id = os.environ["KERASTUNER_TUNER_ID"]

     if "chief" in worker_id:
         # Chief role: host the oracle and do nothing else.
         oracle_chief.start_server(search_oracle)
         return

     fetched = oracle_client.OracleClient(search_oracle).get_space()
     assert fetched.values == {"a": 3}
     assert len(fetched.space) == 1
         assert len(retrieved_hps.space) == 1
コード例 #4
0
    def _test_update_space():
        """Push a new search space to the oracle and read it back.

        The oracle is built without a `hyperparameters` argument. A non-chief
        process registers "a" (Int, default 5) and "b" (Choice of 1, 2, 3)
        via `update_space`, then checks the result returned by `get_space`.
        """
        search_oracle = randomsearch.RandomSearchOracle(
            objective=kt.Objective("score", "max"),
            max_trials=10,
        )
        search_oracle._set_project_dir(tmp_dir, "untitled")

        if "chief" in os.environ["KERASTUNER_TUNER_ID"]:
            # Chief role: host the oracle and do nothing else.
            oracle_chief.start_server(search_oracle)
            return

        client = oracle_client.OracleClient(search_oracle)

        new_space = kt.HyperParameters()
        new_space.Int("a", 0, 10, default=5)
        new_space.Choice("b", [1, 2, 3])
        client.update_space(new_space)

        fetched = client.get_space()
        assert len(fetched.space) == 2
        # Each value should equal the hyperparameter's default (explicit for
        # "a", the first choice for "b").
        assert fetched.values["a"] == 5
        assert fetched.values["b"] == 1
コード例 #5
0
 def _test_end_trial():
     """Verify that `end_trial` stores the status it is given.

     A non-chief process creates one trial, reports a score for it, ends it
     with status "INVALID", and reads the trial back from the oracle.
     """
     space = kt.HyperParameters()
     space.Int("a", 0, 10, default=5)
     search_oracle = randomsearch.RandomSearchOracle(
         objective=kt.Objective("score", "max"),
         max_trials=10,
         hyperparameters=space,
     )
     search_oracle._set_project_dir(tmp_dir, "untitled")
     worker_id = os.environ["KERASTUNER_TUNER_ID"]

     if "chief" in worker_id:
         # Chief role: host the oracle and do nothing else.
         oracle_chief.start_server(search_oracle)
         return

     client = oracle_client.OracleClient(search_oracle)
     new_id = client.create_trial(worker_id).trial_id
     client.update_trial(new_id, {"score": 1}, step=2)
     client.end_trial(new_id, "INVALID")
     assert client.get_trial(new_id).status == "INVALID"
コード例 #6
0
 def _test_create_trial():
     """Check that `create_trial` yields a running trial inside the space.

     The chief process hosts the oracle. Any other process creates one trial
     and checks that its sampled values respect the declared ranges: "a" in
     [0, 10] and "b" one of {1, 2, 3}.
     """
     hps = kt.HyperParameters()
     hps.Int("a", 0, 10, default=5)
     hps.Choice("b", [1, 2, 3])
     oracle = randomsearch.RandomSearchOracle(
         objective=kt.Objective("score", "max"),
         max_trials=10,
         hyperparameters=hps,
     )
     oracle._set_project_dir(tmp_dir, "untitled")
     tuner_id = os.environ["KERASTUNER_TUNER_ID"]
     if "chief" in tuner_id:
         oracle_chief.start_server(oracle)
     else:
         client = oracle_client.OracleClient(oracle)
         trial = client.create_trial(tuner_id)
         assert trial.status == "RUNNING"
         a = trial.hyperparameters.get("a")
         # Both bounds are inclusive, matching `hps.Int("a", 0, 10)`.
         assert 0 <= a <= 10
         b = trial.hyperparameters.get("b")
         assert b in {1, 2, 3}
コード例 #7
0
 def _test_update_trial():
     """Check that `update_trial` records a metric observation at its step.

     A non-chief process creates a trial, confirms it has no "score" metric
     yet, reports score 1 at step 2, and inspects the stored history.
     """
     space = kt.HyperParameters()
     space.Int("a", 0, 10, default=5)
     search_oracle = randomsearch.RandomSearchOracle(
         objective=kt.Objective("score", "max"),
         max_trials=10,
         hyperparameters=space,
     )
     search_oracle._set_project_dir(tmp_dir, "untitled")
     worker_id = os.environ["KERASTUNER_TUNER_ID"]

     if "chief" in worker_id:
         # Chief role: host the oracle and do nothing else.
         oracle_chief.start_server(search_oracle)
         return

     client = oracle_client.OracleClient(search_oracle)
     fresh = client.create_trial(worker_id)
     assert "score" not in fresh.metrics.metrics

     client.update_trial(fresh.trial_id, {"score": 1}, step=2)
     history = client.get_trial(fresh.trial_id).metrics.get_history("score")
     expected = [metrics_tracking.MetricObservation([1], step=2)]
     assert history == expected