Example #1
def test_bayesian_oracle_maximize(tmp_dir):
    hps = hp_module.HyperParameters()
    hps.Int('a', -100, 100)

    oracle = bo_module.BayesianOptimizationOracle(objective=kt.Objective(
        'score', direction='max'),
                                                  max_trials=20,
                                                  hyperparameters=hps,
                                                  num_initial_points=2)
    oracle._set_project_dir(tmp_dir, 'untitled')

    # Make examples with high 'a' and high score.
    for i in range(5):
        trial = trial_module.Trial(hyperparameters=hps.copy())
        trial.hyperparameters.values['a'] = 10 * i
        trial.score = i
        trial.status = 'COMPLETED'
        oracle.trials[trial.trial_id] = trial

    # Make examples with low 'a' and low score.
    for i in range(5):
        trial = trial_module.Trial(hyperparameters=hps.copy())
        trial.hyperparameters.values['a'] = -10 * i
        trial.score = -i
        trial.status = 'COMPLETED'
        oracle.trials[trial.trial_id] = trial

    trial = oracle.create_trial('tuner0')
    assert trial.status == 'RUNNING'
    # Assert that the oracle suggests hps it thinks will maximize.
    assert trial.hyperparameters.get('a') > 0
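By symmetry with the test above, a minimizing oracle should drift toward negative 'a'. The following is a hypothetical counterpart, a minimal sketch assuming the same kt, hp_module, bo_module, and trial_module imports used throughout these examples:

def test_bayesian_oracle_minimize(tmp_dir):
    hps = hp_module.HyperParameters()
    hps.Int('a', -100, 100)

    oracle = bo_module.BayesianOptimizationOracle(
        objective=kt.Objective('score', direction='min'),
        max_trials=20,
        hyperparameters=hps,
        num_initial_points=2)
    oracle._set_project_dir(tmp_dir, 'untitled')

    # Seed trials where higher 'a' yields a higher (i.e. worse) score.
    for i in range(10):
        trial = trial_module.Trial(hyperparameters=hps.copy())
        trial.hyperparameters.values['a'] = 10 * i - 50
        trial.score = i
        trial.status = 'COMPLETED'
        oracle.trials[trial.trial_id] = trial

    trial = oracle.create_trial('tuner0')
    assert trial.status == 'RUNNING'
    # The oracle should suggest low 'a' values when minimizing.
    assert trial.hyperparameters.get('a') < 0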
Example #2
def test_hyperband_tuner(patch_fit, patch_load, tmp_dir):
    x = np.random.rand(10, 2, 2).astype('float32')
    y = np.random.randint(0, 1, (10, ))
    val_x = np.random.rand(10, 2, 2).astype('float32')
    val_y = np.random.randint(0, 1, (10, ))

    tuner = hyperband_module.Hyperband(build_model,
                                       objective='val_accuracy',
                                       max_trials=15,
                                       factor=2,
                                       min_epochs=1,
                                       max_epochs=2,
                                       directory=tmp_dir)

    hp = hyperparameters.HyperParameters()
    history_trial = trial_module.Trial(hyperparameters=hp.copy())
    history_trial.score = 1
    history_trial.best_step = 0
    hp.values['tuner/epochs'] = 10
    hp.values['tuner/trial_id'] = history_trial.trial_id
    tuner.oracle.trials[history_trial.trial_id] = history_trial

    trial = trial_module.Trial(hyperparameters=hp)
    tuner.oracle.trials[trial.trial_id] = trial
    tuner.run_trial(trial, x=x, y=y, epochs=1, validation_data=(val_x, val_y))
    assert patch_fit.called
    assert patch_load.called
Example #3
def test_bayesian_oracle_maximize(tmp_dir):
    hps = hp_module.HyperParameters()
    hps.Int("a", -100, 100)

    oracle = bo_module.BayesianOptimizationOracle(
        objective=kt.Objective("score", direction="max"),
        max_trials=20,
        hyperparameters=hps,
        num_initial_points=2,
    )
    oracle._set_project_dir(tmp_dir, "untitled")

    # Make examples with high 'a' and high score.
    for i in range(5):
        trial = trial_module.Trial(hyperparameters=hps.copy())
        trial.hyperparameters.values["a"] = 10 * i
        trial.score = i
        trial.status = "COMPLETED"
        oracle.trials[trial.trial_id] = trial

    # Make examples with low 'a' and low score.
    for i in range(5):
        trial = trial_module.Trial(hyperparameters=hps.copy())
        trial.hyperparameters.values["a"] = -10 * i
        trial.score = -i
        trial.status = "COMPLETED"
        oracle.trials[trial.trial_id] = trial

    trial = oracle.create_trial("tuner0")
    assert trial.status == "RUNNING"
    # Assert that the oracle suggests hps it thinks will maximize.
    assert trial.hyperparameters.get("a") > 0
Example #4
def test_hyperparameters_added(tmp_dir):
    hps = hp_module.HyperParameters()
    hps.Int("a", -100, 100)

    oracle = bo_module.BayesianOptimizationOracle(
        objective=kt.Objective("score", direction="max"),
        max_trials=20,
        hyperparameters=hps,
        num_initial_points=2,
    )
    oracle._set_project_dir(tmp_dir, "untitled")

    # Populate initial trials.
    for i in range(10):
        trial = trial_module.Trial(hyperparameters=hps.copy())
        trial.hyperparameters.values["a"] = 10 * i
        trial.score = i
        trial.status = "COMPLETED"
        oracle.trials[trial.trial_id] = trial

    # Update the space.
    new_hps = hp_module.HyperParameters()
    new_hps.Float("b", 3.2, 6.4, step=0.2, default=3.6)
    new_hps.Boolean("c", default=True)
    oracle.update_space(new_hps)

    # Make a new trial; it should have both "b" and "c" set.
    trial = oracle.create_trial("tuner0")
    assert trial.status == "RUNNING"
    assert "b" in trial.hyperparameters.values
    assert "c" in trial.hyperparameters.values
Example #5
def convert_completed_vizier_trial_to_keras_trial(
    vizier_trial: Dict[Text, Any],
    hyperparameter_space: hp_module.HyperParameters,
) -> trial_module.Trial:
    """Converts completed Vizier Trial into KerasTuner Trial.

    Args:
        vizier_trial: A Vizier Trial Instance.
        hyperparameter_space: Mandatory and must include definitions for all
            hyperparameters used during the search.

    Returns:
        A KerasTuner Trial.
    """
    kerastuner_trial = trial_module.Trial(
        hyperparameters=convert_vizier_trial_to_hps(hyperparameter_space,
                                                    vizier_trial),
        trial_id=get_trial_id(vizier_trial),
        status=trial_module.TrialStatus.COMPLETED,
    )
    final_measurement = vizier_trial.get("finalMeasurement")
    if not final_measurement:
        raise ValueError(
            '"finalMeasurement" not found in this trial {}'.format(
                vizier_trial))

    # If the trial ended before any intermediate metric reporting,
    # "stepCount" is absent and best_step defaults to 0.
    kerastuner_trial.best_step = int(final_measurement.get("stepCount", 0))
    kerastuner_trial.score = final_measurement["metrics"][0].get("value")
    return kerastuner_trial
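A minimal usage sketch of the converter above; the payload is hypothetical, but its shape mirrors the mocked Vizier responses used elsewhere in these examples:

# Hypothetical completed Vizier trial, shaped like the mocks in these tests.
vizier_trial = {
    "name": "1",
    "state": "COMPLETED",
    "parameters": [{"parameter": "learning_rate", "floatValue": 0.01}],
    "finalMeasurement": {
        "stepCount": "3",
        "metrics": [{"metric": "val_acc", "value": 0.7}],
    },
}
hps = hp_module.HyperParameters()
hps.Choice("learning_rate", [1e-4, 1e-3, 1e-2])

kerastuner_trial = convert_completed_vizier_trial_to_keras_trial(
    vizier_trial, hps)
assert kerastuner_trial.status == trial_module.TrialStatus.COMPLETED
assert kerastuner_trial.best_step == 3
assert kerastuner_trial.score == 0.7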
Example #6
def test_hyperband_tuner(patch_fit, patch_load, tmp_dir):
    x = np.random.rand(10, 2, 2).astype('float32')
    y = np.random.randint(0, 1, (10, ))
    val_x = np.random.rand(10, 2, 2).astype('float32')
    val_y = np.random.randint(0, 1, (10, ))

    tuner = HyperbandStub(build_model,
                          objective='val_accuracy',
                          max_trials=15,
                          factor=2,
                          min_epochs=1,
                          max_epochs=2,
                          executions_per_trial=3,
                          directory=tmp_dir)

    hp = hyperparameters.HyperParameters()
    hp.values['tuner/epochs'] = 10
    trial_id = '1'
    hp.values['tuner/trial_id'] = trial_id

    tuner.run_trial(
        trial_module.Trial(trial_id, hp, 5, base_directory=tmp_dir), hp, [], {
            'x': x,
            'y': y,
            'epochs': 1,
            'validation_data': (val_x, val_y)
        })
    assert patch_fit.called
    assert patch_load.called
Example #7
    def test_end_trial_success(self):
        self._tuner_with_hparams()
        self.mock_client.complete_trial.return_value = {
            "name": "1",
            "state": "COMPLETED",
            "parameters": [{"parameter": "learning_rate", "floatValue": 0.01}],
            "finalMeasurement": {
                "stepCount": "3",
                "metrics": [{"metric": "val_acc", "value": 0.7}],
            },
            "trial_infeasible": False,
            "infeasible_reason": None,
        }
        mock_save_trial = mock.Mock()
        self.tuner.oracle._save_trial = mock_save_trial
        self.tuner.oracle.ongoing_trials = {"tuner_0": self._test_trial}
        expected_trial = trial_module.Trial(
            hyperparameters=self._test_hyperparameters,
            trial_id="1",
            status=trial_module.TrialStatus.COMPLETED,
        )
        expected_trial.best_step = 3
        expected_trial.score = 0.7

        self.tuner.oracle.end_trial(trial_id="1")

        self.mock_client.complete_trial.assert_called_once_with(
            "1", False, None)
        self.assertEqual(repr(mock_save_trial.call_args[0][0].get_state()),
                         repr(expected_trial.get_state()))
Example #8
def test_trial_proto():
    hps = hp_module.HyperParameters()
    hps.Int('a', 0, 10, default=3)
    trial = trial_module.Trial(hps, trial_id='trial1', status='COMPLETED')
    trial.metrics.register('score', direction='max')
    trial.metrics.update('score', 10, step=1)

    proto = trial.to_proto()
    assert len(proto.hyperparameters.space.int_space) == 1
    assert proto.hyperparameters.values.values['a'].int_value == 3
    assert not proto.HasField('score')

    new_trial = trial_module.Trial.from_proto(proto)
    assert new_trial.status == 'COMPLETED'
    assert new_trial.hyperparameters.get('a') == 3
    assert new_trial.trial_id == 'trial1'
    assert new_trial.score is None
    assert new_trial.best_step is None

    trial.score = -10
    trial.best_step = 3

    proto = trial.to_proto()
    assert proto.HasField('score')
    assert proto.score.value == -10
    assert proto.score.step == 3

    new_trial = trial_module.Trial.from_proto(proto)
    assert new_trial.score == -10
    assert new_trial.best_step == 3
    assert new_trial.metrics.get_history('score') == [
        metrics_tracking.MetricObservation(10, step=1)
    ]
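Because to_proto() returns a regular protobuf message, the round trip above also extends to bytes. Continuing with the trial from the test above, a minimal sketch relying only on the generic protobuf API (SerializeToString / FromString):

# Serialize to bytes and restore; type(proto) is the generated message
# class, so its FromString classmethod rebuilds the proto from bytes.
proto = trial.to_proto()
data = proto.SerializeToString()
restored = trial_module.Trial.from_proto(type(proto).FromString(data))
assert restored.trial_id == trial.trial_id
assert restored.score == trial.score
assert restored.best_step == trial.best_step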
Example #9
def test_trial_proto():
    hps = hp_module.HyperParameters()
    hps.Int("a", 0, 10, default=3)
    trial = trial_module.Trial(hps, trial_id="trial1", status="COMPLETED")
    trial.metrics.register("score", direction="max")
    trial.metrics.update("score", 10, step=1)

    proto = trial.to_proto()
    assert len(proto.hyperparameters.space.int_space) == 1
    assert proto.hyperparameters.values.values["a"].int_value == 3
    assert not proto.HasField("score")

    new_trial = trial_module.Trial.from_proto(proto)
    assert new_trial.status == "COMPLETED"
    assert new_trial.hyperparameters.get("a") == 3
    assert new_trial.trial_id == "trial1"
    assert new_trial.score is None
    assert new_trial.best_step is None

    trial.score = -10
    trial.best_step = 3

    proto = trial.to_proto()
    assert proto.HasField("score")
    assert proto.score.value == -10
    assert proto.score.step == 3

    new_trial = trial_module.Trial.from_proto(proto)
    assert new_trial.score == -10
    assert new_trial.best_step == 3
    assert new_trial.metrics.get_history("score") == [
        metrics_tracking.MetricObservation(10, step=1)
    ]
Example #10
def test_hyperparameters_added(tmp_dir):
    hps = hp_module.HyperParameters()
    hps.Int('a', -100, 100)

    oracle = bo_module.BayesianOptimizationOracle(objective=kt.Objective(
        'score', direction='max'),
                                                  max_trials=20,
                                                  hyperparameters=hps,
                                                  num_initial_points=2)
    oracle._set_project_dir(tmp_dir, 'untitled')

    # Populate initial trials.
    for i in range(10):
        trial = trial_module.Trial(hyperparameters=hps.copy())
        trial.hyperparameters.values['a'] = 10 * i
        trial.score = i
        trial.status = 'COMPLETED'
        oracle.trials[trial.trial_id] = trial

    # Update the space.
    new_hps = hp_module.HyperParameters()
    new_hps.Float('b', 3.2, 6.4, step=0.2, default=3.6)
    new_hps.Boolean('c', default=True)
    oracle.update_space(new_hps)

    # Make a new trial; it should have both 'b' and 'c' set.
    trial = oracle.create_trial('tuner0')
    assert trial.status == 'RUNNING'
    assert 'b' in trial.hyperparameters.values
    assert 'c' in trial.hyperparameters.values
Example #11
    def setUp(self):
        super(CloudTunerTest, self).setUp()
        self.addCleanup(mock.patch.stopall)

        self._study_id = "study-a"
        self._region = "us-central1"
        self._remote_dir = "gs://remote_dir"
        self._project_id = "project-a"
        # TODO(b/170687807) Switch from using "{}".format() to f-string
        self._trial_parent = "projects/{}/locations/{}/studies/{}".format(
            self._project_id, self._region, self._study_id)
        self._container_uri = "test_container_uri"
        hps = hp_module.HyperParameters()
        hps.Choice("learning_rate", [1e-4, 1e-3, 1e-2])
        self._test_hyperparameters = hps

        self._study_config = {
            "algorithm":
            "ALGORITHM_UNSPECIFIED",
            "metrics": [{
                "metric": "val_acc",
                "goal": "MAXIMIZE"
            }],
            "parameters": [{
                "parameter": "learning_rate",
                "discrete_value_spec": {
                    "values": [0.0001, 0.001, 0.01]
                },
                "type": "DISCRETE",
            }],
            "automatedStoppingConfig": {
                "decayCurveStoppingConfig": {
                    "useElapsedTime": True
                }
            },
        }

        self._test_trial = trial_module.Trial(
            hyperparameters=self._test_hyperparameters,
            trial_id="1",
            status=trial_module.TrialStatus,
        )
        # TODO(b/170687807) Switch from using "{}".format() to f-string
        self._job_id = "{}_{}".format(self._study_id,
                                      self._test_trial.trial_id)
        self.mock_optimizer_client_module = mock.patch.object(
            tuner, "optimizer_client", autospec=True).start()

        self.mock_client = mock.create_autospec(
            optimizer_client._OptimizerClient)
        self.mock_optimizer_client_module.create_or_load_study.return_value = (
            self.mock_client)
Example #12
    def __init__(self, hypermodel, objective, max_trials, **kwargs):
        super().__init__(hypermodel, objective, max_trials, **kwargs)
        hp = hyperparameters.HyperParameters()
        trial = trial_module.Trial('1', hp, 5, base_directory=self.directory)
        trial.executions = [
            execution_module.Execution('a',
                                       'b',
                                       1,
                                       3,
                                       base_directory=self.directory)
        ]
        trial.executions[0].best_checkpoint = 'x'
        self.trials = [trial]
Example #13
    def get_best_trials(self, num_trials: int = 1) -> List[trial_module.Trial]:
        """Returns the trials with the best objective values found so far.

        Arguments:
            num_trials: positive int, number of trials to return.
        Returns:
            List of KerasTuner Trials.
        """
        if len(self.objective) > 1:
            raise ValueError(
                "Getting the best trials for multi-objective optimization "
                "is not supported. "
            )

        maximizing = (
            utils.format_goal(self.objective[0].direction) == "MAXIMIZE")

        # List all trials associated with the same study
        trial_list = self.service.list_trials()

        optimizer_trials = [t for t in trial_list if t["state"] == "COMPLETED"]

        if not optimizer_trials:
            return []

        sorted_trials = sorted(
            optimizer_trials,
            key=lambda t: t["finalMeasurement"]["metrics"][0]["value"],
            reverse=maximizing,
        )
        best_optimizer_trials = sorted_trials[:num_trials]

        best_trials = []
        # Convert Optimizer trials to KerasTuner Trial instance
        for optimizer_trial in best_optimizer_trials:
            final_measurement = optimizer_trial["finalMeasurement"]
            kerastuner_trial = trial_module.Trial(
                hyperparameters=utils.convert_optimizer_trial_to_hps(
                    self.hyperparameters.copy(), optimizer_trial
                ),
                trial_id=utils.get_trial_id(optimizer_trial),
                status=trial_module.TrialStatus.COMPLETED,
            )
            # If the trial ended before any intermediate metric reporting,
            # "stepCount" is absent and best_step defaults to 1.
            kerastuner_trial.best_step = final_measurement.get("stepCount", 1)
            kerastuner_trial.score = final_measurement["metrics"][0]["value"]
            best_trials.append(kerastuner_trial)
        return best_trials
Example #14
    def setUp(self):
        super(CloudTunerTest, self).setUp()
        self.addCleanup(mock.patch.stopall)

        self._study_id = 'study-a'
        self._region = 'us-central1'
        self._project_id = 'project-a'
        self._trial_parent = 'projects/{}/locations/{}/studies/{}'.format(
            self._project_id, self._region,
            'CloudTuner_study_{}'.format(self._study_id))

        hps = hp_module.HyperParameters()
        hps.Choice('learning_rate', [1e-4, 1e-3, 1e-2])
        self._test_hyperparameters = hps

        self._study_config = {
            'algorithm': 'ALGORITHM_UNSPECIFIED',
            'metrics': [{
                'metric': 'val_acc',
                'goal': 'MAXIMIZE'
            }],
            'parameters': [{
                'parameter': 'learning_rate',
                'discrete_value_spec': {
                    'values': [0.0001, 0.001, 0.01]
                },
                'type': 'DISCRETE'
            }],
            'automatedStoppingConfig': {
                'decayCurveStoppingConfig': {
                    'useElapsedTime': True
                }
            }
        }

        self._test_trial = trial_module.Trial(
            hyperparameters=self._test_hyperparameters,
            trial_id='1',
            status=trial_module.TrialStatus)

        self.mock_optimizer_client_module = mock.patch.object(
            cloud_tuner, 'optimizer_client', autospec=True).start()

        self.mock_client = mock.create_autospec(
            optimizer_client._OptimizerClient)
        self.mock_optimizer_client_module.create_or_load_study.return_value = (
            self.mock_client)
Example #15
    def get_best_trials(self, num_trials=1):
        """Returns the trials with the best objective values found so far.

    Arguments:
      num_trials: positive int, number of trials to return.

    Returns:
      List of KerasTuner Trials.
    """
        if len(self.objective) > 1:
            raise ValueError(
                'Getting the best trials for multi-objective optimization '
                'is not supported. ')

        maximizing = cloud_tuner_utils.format_goal(
            self.objective[0].direction) == 'MAXIMIZE'

        # List all trials associated with the same study
        trial_list = self.service.list_trials()

        optimizer_trials = [t for t in trial_list if t['state'] == 'COMPLETED']

        if not optimizer_trials:
            return []

        sorted_trials = sorted(
            optimizer_trials,
            key=lambda t: t['finalMeasurement']['metrics'][0]['value'],
            reverse=maximizing)
        best_optimizer_trials = sorted_trials[:num_trials]

        best_trials = []
        # Convert Optimizer trials to KerasTuner Trial instance
        for optimizer_trial in best_optimizer_trials:
            final_measurement = optimizer_trial['finalMeasurement']
            kerastuner_trial = trial_module.Trial(
                hyperparameters=cloud_tuner_utils.convert_optimizer_trial_to_hps(
                    self.hyperparameters.copy(), optimizer_trial),
                trial_id=cloud_tuner_utils.get_trial_id(optimizer_trial),
                status=trial_module.TrialStatus.COMPLETED)
            # If the trial ended before any intermediate metric reporting,
            # 'stepCount' is absent and best_step defaults to 1.
            kerastuner_trial.best_step = final_measurement.get('stepCount', 1)
            kerastuner_trial.score = final_measurement['metrics'][0]['value']
            best_trials.append(kerastuner_trial)
        return best_trials
Example #16
def test_step_respected(tmp_dir):
    hps = hp_module.HyperParameters()
    hps.Float('c', 0, 10, step=3)
    oracle = bo_module.BayesianOptimizationOracle(objective=kt.Objective(
        'score', direction='max'),
                                                  max_trials=20,
                                                  hyperparameters=hps,
                                                  num_initial_points=2)
    oracle._set_project_dir(tmp_dir, 'untitled')

    # Populate initial trials.
    for i in range(10):
        trial = trial_module.Trial(hyperparameters=hps.copy())
        trial.hyperparameters.values['c'] = 3.
        trial.score = i
        trial.status = 'COMPLETED'
        oracle.trials[trial.trial_id] = trial

    trial = oracle.create_trial('tuner0')
    # Check that oracle respects the `step` param.
    assert trial.hyperparameters.get('c') in {0, 3, 6, 9}
Example #17
def test_step_respected(tmp_dir):
    hps = hp_module.HyperParameters()
    hps.Float("c", 0, 10, step=3)
    oracle = bo_module.BayesianOptimizationOracle(
        objective=kt.Objective("score", direction="max"),
        max_trials=20,
        hyperparameters=hps,
        num_initial_points=2,
    )
    oracle._set_project_dir(tmp_dir, "untitled")

    # Populate initial trials.
    for i in range(10):
        trial = trial_module.Trial(hyperparameters=hps.copy())
        trial.hyperparameters.values["c"] = 3.0
        trial.score = i
        trial.status = "COMPLETED"
        oracle.trials[trial.trial_id] = trial

    trial = oracle.create_trial("tuner0")
    # Check that oracle respects the `step` param.
    assert trial.hyperparameters.get("c") in {0, 3, 6, 9}
Example #18
    def create_trial(self, tuner_id: Text) -> trial_module.Trial:
        """Create a new `Trial` to be run by the `Tuner`.

        Args:
            tuner_id: An ID that identifies the `Tuner` requesting a `Trial`.
                `Tuners` that should run the same trial (for instance, when
                running a multi-worker model) should have the same ID. If
                multiple suggestTrialsRequests have the same tuner_id, the
                service will return the identical suggested trial if the trial
                is PENDING, and provide a new trial if the last suggested trial
                was completed.

        Returns:
            A `Trial` object containing a set of hyperparameter values to run
            in a `Tuner`.

        Raises:
            SuggestionInactiveError: Indicates that a suggestion was requested
                from an inactive study.
        """
        # List all trials from the same study and check whether any trial is
        # stopping or the number of trials has reached `max_trials`.
        trial_list = self.service.list_trials()
        # Note: the KerasTuner trial status "STOPPED" corresponds to the
        # Optimizer trial state "STOPPING".
        stopping_trials = [t for t in trial_list if t["state"] == "STOPPING"]
        if (self.max_trials and
            len(trial_list) >= self.max_trials) or stopping_trials:
            trial_id = "n"
            hyperparameters = self.hyperparameters.copy()
            hyperparameters.values = {}
            # This will break the search loop later.
            return trial_module.Trial(
                hyperparameters=hyperparameters,
                trial_id=trial_id,
                status=trial_module.TrialStatus.STOPPED,
            )

        # Get suggestions
        suggestions = self.service.get_suggestions(tuner_id)

        if "trials" not in suggestions:
            return trial_module.Trial(
                hyperparameters={}, status=trial_module.TrialStatus.STOPPED
            )

        # Fetches the suggested trial.
        # Optimizer Trial instance
        optimizer_trial = suggestions["trials"][0]
        trial_id = utils.get_trial_id(optimizer_trial)

        # KerasTuner Trial instance
        kerastuner_trial = trial_module.Trial(
            hyperparameters=utils.convert_optimizer_trial_to_hps(
                self.hyperparameters.copy(), optimizer_trial
            ),
            trial_id=trial_id,
            status=trial_module.TrialStatus.RUNNING,
        )

        tf.get_logger().info(
            "Hyperparameters requested by tuner ({}): {} ".format(
                tuner_id, kerastuner_trial.hyperparameters.values
            )
        )

        self._start_time = time.time()
        self.trials[trial_id] = kerastuner_trial
        self.ongoing_trials[tuner_id] = kerastuner_trial
        self._save_trial(kerastuner_trial)
        self.save()
        return kerastuner_trial
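The STOPPED trial with empty hyperparameter values acts as a sentinel for the caller. A hypothetical driver loop (not the library's actual search loop, and run_and_report is a made-up helper) might consume it like this:

# Hypothetical driver: request trials until the oracle signals a stop.
while True:
    trial = oracle.create_trial(tuner_id="tuner0")
    if trial.status == trial_module.TrialStatus.STOPPED:
        break  # max_trials reached or the study is being stopped
    run_and_report(trial)  # hypothetical helper: trains, then reports results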
Example #19
    def test_get_best_trials_multi_tuners(self):
        # Instantiate tuner_1
        tuner_1 = self._tuner(
            objective=oracle_module.Objective("val_acc", "max"),
            hyperparameters=self._test_hyperparameters,
            study_config=None,
        )
        tuner_1.tuner_id = "tuner_1"
        # tuner_1 has a completed trial
        trial_1 = trial_module.Trial(
            hyperparameters=self._test_hyperparameters,
            trial_id="1",
            status=trial_module.TrialStatus.COMPLETED,
        )
        tuner_1.oracle.trials = {"1": trial_1}

        # Instantiate tuner_2
        tuner_2 = self._tuner(
            objective=oracle_module.Objective("val_acc", "max"),
            hyperparameters=self._test_hyperparameters,
            study_config=None,
        )
        tuner_2.tuner_id = "tuner_2"
        # tuner_2 has a completed trial
        trial_2 = trial_module.Trial(
            hyperparameters=self._test_hyperparameters,
            trial_id="2",
            status=trial_module.TrialStatus.COMPLETED,
        )
        tuner_2.oracle.trials = {"2": trial_2}

        self.mock_client.list_trials.return_value = [
            {
                "name": "1",
                "state": "COMPLETED",
                "parameters": [{
                    "parameter": "learning_rate",
                    "floatValue": 0.01
                }],
                "finalMeasurement": {
                    "stepCount": "3",
                    "metrics": [{
                        "metric": "val_acc",
                        "value": 0.7
                    }],
                },
                "trial_infeasible": False,
                "infeasible_reason": None,
            },
            {
                "name": "2",
                "state": "COMPLETED",
                "parameters": [{
                    "parameter": "learning_rate",
                    "floatValue": 0.001
                }],
                "finalMeasurement": {
                    "stepCount": "3",
                    "metrics": [{
                        "metric": "val_acc",
                        "value": 0.9
                    }],
                },
                "trial_infeasible": False,
                "infeasible_reason": None,
            },
        ]

        # Any tuner worker that requests the best trials receives the same
        # top-N sorted trials.
        best_trials_1 = tuner_1.oracle.get_best_trials(num_trials=2)
        self.mock_client.list_trials.assert_called_once()

        best_trials_2 = tuner_2.oracle.get_best_trials(num_trials=2)

        self.assertEqual(len(best_trials_1), 2)
        self.assertEqual(best_trials_1[0].trial_id, best_trials_2[0].trial_id)
        self.assertEqual(best_trials_1[1].trial_id, best_trials_2[1].trial_id)
        self.assertEqual(best_trials_1[0].score, 0.9)
        self.assertEqual(best_trials_1[0].best_step, 3)
Example #20
  def test_get_best_trials_multi_tuners(self):
    # Instantiate tuner_1
    tuner_1 = self._tuner(
        objective=oracle_module.Objective('val_acc', 'max'),
        hyperparameters=self._test_hyperparameters,
        study_config=None)
    tuner_1.tuner_id = 'tuner_1'
    # tuner_1 has a completed trial
    trial_1 = trial_module.Trial(
        hyperparameters=self._test_hyperparameters,
        trial_id='1',
        status=trial_module.TrialStatus.COMPLETED)
    tuner_1.oracle.trials = {'1': trial_1}

    # Instantiate tuner_2
    tuner_2 = self._tuner(
        objective=oracle_module.Objective('val_acc', 'max'),
        hyperparameters=self._test_hyperparameters,
        study_config=None)
    tuner_2.tuner_id = 'tuner_2'
    # tuner_2 has a completed trial
    trial_2 = trial_module.Trial(
        hyperparameters=self._test_hyperparameters,
        trial_id='2',
        status=trial_module.TrialStatus.COMPLETED)
    tuner_2.oracle.trials = {'2': trial_2}

    self.mock_client.list_trials.return_value = [{
        'name': '1',
        'state': 'COMPLETED',
        'parameters': [{
            'parameter': 'learning_rate',
            'floatValue': 0.01
        }],
        'finalMeasurement': {
            'stepCount': 3,
            'metrics': [{
                'metric': 'val_acc',
                'value': 0.7
            }]
        },
        'trial_infeasible': False,
        'infeasible_reason': None
    }, {
        'name': '2',
        'state': 'COMPLETED',
        'parameters': [{
            'parameter': 'learning_rate',
            'floatValue': 0.001
        }],
        'finalMeasurement': {
            'stepCount': 3,
            'metrics': [{
                'metric': 'val_acc',
                'value': 0.9
            }]
        },
        'trial_infeasible': False,
        'infeasible_reason': None
    }]

    # Any tuner worker that requests the best trials receives the same
    # top-N sorted trials.
    best_trials_1 = tuner_1.oracle.get_best_trials(num_trials=2)
    self.mock_client.list_trials.assert_called_once()

    best_trials_2 = tuner_2.oracle.get_best_trials(num_trials=2)

    self.assertEqual(len(best_trials_1), 2)
    self.assertEqual(best_trials_1[0].trial_id, best_trials_2[0].trial_id)
    self.assertEqual(best_trials_1[1].trial_id, best_trials_2[1].trial_id)
    self.assertEqual(best_trials_1[0].score, 0.9)
    self.assertEqual(best_trials_1[0].best_step, 3)