예제 #1
0
    def test_suggest_duplicates_between_execution(self, monkeypatch,
                                                  hyperband: Hyperband,
                                                  budgets: list[BudgetTuple]):
        """Check that sampling collisions across hyperband repetitions are skipped.

        A single bracket is fed observed results for every rung (9 trials at
        epoch 1, 3 at epoch 3, 1 at epoch 9). Brackets are then refreshed for
        the second repetition and offered candidate samples. Some candidates
        collide with trials from the first repetition, or with a sample
        already taken by another bracket in this round. Only the
        non-colliding candidates may be suggested, and each is re-labelled
        with its bracket's epoch.
        """
        hyperband.repetitions = 2
        only_bracket = HyperbandBracket(hyperband, budgets, 1)
        hyperband.brackets = [only_bracket]
        assert only_bracket.owner is hyperband

        # First repetition: each trial's objective equals its lr index.
        first_repetition = (
            [((1, lr), lr) for lr in range(9)] +
            [((3, lr), lr) for lr in range(3)] +
            [((9, 0), 0)]
        )
        for point, objective in first_repetition:
            trial = create_trial_for_hb(point, objective=objective)
            force_observe(hyperband, trial)

        assert not hyperband.is_done

        # Offered to the epoch-9 bracket: (9, 0) was already observed.
        epoch_9_candidates = [(9, 0), (9, 2), (9, 3), (9, 10)]
        # Offered to the epoch-3 bracket: (9, 1) would become (3, 1), which
        # the first repetition already has; (9, 3) was just taken by the
        # epoch-9 bracket.
        epoch_3_candidates = [(9, 1), (9, 3), (9, 4), (9, 5), (9, 11)]
        # Offered to the epoch-1 bracket: (9, 0) and (9, 8) would become
        # (1, 0) and (1, 8), both sampled in the first repetition.
        epoch_1_candidates = [(9, 0), (9, 8), (9, 12), (9, 13)]

        all_candidates = (epoch_9_candidates + epoch_3_candidates +
                          epoch_1_candidates)
        candidate_trials = [
            create_trial_for_hb(point) for point in all_candidates
        ]

        hyperband._refresh_brackets()
        mock_samples(hyperband, candidate_trials)
        suggested = hyperband.suggest(100)
        assert suggested is not None
        assert len(suggested) == 8

        expected_params = [
            {"epoch": 9, "lr": 2},
            {"epoch": 9, "lr": 3},
            {"epoch": 9, "lr": 10},
            {"epoch": 3, "lr": 4},
            {"epoch": 3, "lr": 5},
            {"epoch": 3, "lr": 11},
            {"epoch": 1, "lr": 12},
            {"epoch": 1, "lr": 13},
        ]
        assert [trial.params for trial in suggested] == expected_params
예제 #2
0
    def test_full_process(self, monkeypatch, hyperband: Hyperband):
        """Test Hyperband full process.

        The test proceeds in three stages:

        1. Fill the first rung of all three brackets.
        2. Observe results bracket by bracket, checking readiness, promotions
           and completion after each step.
        3. Raise ``repetitions`` to 2 and check that refreshing adds three
           new brackets whose first rungs avoid points already used in the
           first repetition.
        """

        def observe_completed(trials, top_objective):
            """Mark ``trials`` completed, attach objectives and observe them.

            The i-th trial receives objective ``top_objective - i``, so later
            trials in the list get lower values.
            """
            for i, trial in enumerate(trials):
                trial.status = "completed"
                trial._results.append(
                    Trial.Result(name="objective",
                                 type="objective",
                                 value=top_objective - i))
            hyperband.observe(trials)

        sample_trials = [
            create_trial_for_hb(("fidelity", i)) for i in range(100)
        ]

        hyperband._refresh_brackets()
        mock_samples(hyperband, copy.deepcopy(sample_trials))

        # Fill all brackets' first rung. Suggestions come back ordered as the
        # epoch-9 bracket, then epoch 3, then epoch 1.
        first_rung = hyperband.suggest(100)
        assert first_rung is not None
        first_bracket_first_rung = first_rung[6:]
        second_bracket_first_rung = first_rung[3:6]
        third_bracket_first_rung = first_rung[:3]

        compare_trials(
            first_bracket_first_rung,
            [create_trial_for_hb((1, i)) for i in range(6, 15)],
        )
        compare_trials(
            second_bracket_first_rung,
            [create_trial_for_hb((3, i)) for i in range(3, 6)],
        )
        compare_trials(third_bracket_first_rung,
                       [create_trial_for_hb((9, i)) for i in range(3)])
        assert hyperband.brackets is not None
        assert hyperband.brackets[0].has_rung_filled(0)
        assert not hyperband.brackets[0].is_ready()
        # Nothing more is suggested until results are observed.
        assert hyperband.suggest(100) == []
        assert hyperband.suggest(100) == []

        # Observe first bracket first rung
        observe_completed(first_bracket_first_rung, 16)

        assert hyperband.brackets[0].is_ready()
        assert not hyperband.brackets[1].is_ready()
        assert not hyperband.brackets[2].is_ready()

        # Promote first bracket first rung: the three trials with the lowest
        # objectives (lr 14, 13, 12) are expected at epoch 3.
        first_bracket_second_rung = hyperband.suggest(100)
        assert first_bracket_second_rung is not None
        compare_trials(
            first_bracket_second_rung,
            [create_trial_for_hb((3, 3 + 3 + 9 - 1 - i)) for i in range(3)],
        )

        assert hyperband.brackets[0].has_rung_filled(1)
        assert not hyperband.brackets[0].is_ready()
        assert not hyperband.brackets[1].is_ready()
        assert not hyperband.brackets[2].is_ready()

        # Observe first bracket second rung
        observe_completed(first_bracket_second_rung, 8)

        assert hyperband.brackets[0].is_ready()
        assert not hyperband.brackets[1].is_ready()
        assert not hyperband.brackets[2].is_ready()

        # Promote first bracket second rung
        first_bracket_third_rung = hyperband.suggest(100)
        assert first_bracket_third_rung is not None
        compare_trials(first_bracket_third_rung,
                       [create_trial_for_hb((9, 12))])

        assert hyperband.brackets[0].has_rung_filled(2)
        assert not hyperband.brackets[0].is_ready()
        assert not hyperband.brackets[1].is_ready()
        assert not hyperband.brackets[2].is_ready()

        # Observe second bracket first rung
        observe_completed(second_bracket_first_rung, 8)

        assert not hyperband.brackets[0].is_ready()
        assert hyperband.brackets[1].is_ready()
        assert not hyperband.brackets[2].is_ready()

        # Promote second bracket first rung
        second_bracket_second_rung = hyperband.suggest(100)
        assert second_bracket_second_rung is not None
        compare_trials(second_bracket_second_rung,
                       [create_trial_for_hb((9, 5))])

        assert not hyperband.brackets[0].is_ready()
        assert hyperband.brackets[1].has_rung_filled(1)
        assert not hyperband.brackets[1].is_ready()
        assert not hyperband.brackets[2].is_ready()

        # Observe third bracket first rung; this alone completes the bracket.
        observe_completed(third_bracket_first_rung, 3)

        assert not hyperband.brackets[0].is_ready(2)
        assert not hyperband.brackets[1].is_ready(1)
        assert hyperband.brackets[2].is_ready(0)
        assert hyperband.brackets[2].is_done

        # Observe second bracket second rung
        observe_completed(second_bracket_second_rung, 5)

        assert not hyperband.brackets[0].is_ready(2)
        assert hyperband.brackets[1].is_ready(1)
        assert hyperband.brackets[1].is_done

        # NOTE(review): unlike every other rung, this trial is observed
        # without being marked completed or given an objective; confirm
        # this is intentional.
        hyperband.observe(first_bracket_third_rung)

        assert hyperband.is_done
        assert hyperband.brackets[0].is_done
        assert hyperband.suggest(100) == []

        # Second repetition: refreshing adds one new bracket per existing one.
        monkeypatch.setattr(hyperband, "repetitions", 2)
        assert len(hyperband.brackets) == 3
        hyperband._refresh_brackets()
        assert len(hyperband.brackets) == 6
        mock_samples(hyperband,
                     copy.deepcopy(sample_trials[:3] + sample_trials))
        trials = hyperband.suggest(100)
        assert trials is not None
        assert not hyperband.is_done
        assert not hyperband.brackets[3].is_ready(2)
        assert not hyperband.brackets[3].is_done
        # The expected suggestions skip points the first repetition already
        # used at the same epoch, e.g. (9, 5) and (1, 6)-(1, 14).
        compare_trials(
            trials[:3],
            [create_trial_for_hb(p) for p in [(9, 3), (9, 4), (9, 6)]],
        )
        compare_trials(
            trials[3:6],
            [create_trial_for_hb(p) for p in [(3, 7), (3, 8), (3, 9)]],
        )
        compare_trials(trials[6:],
                       [create_trial_for_hb((1, i)) for i in range(15, 24)])