コード例 #1
0
ファイル: test_api.py プロジェクト: magedhelmy1/ray
    def testTrialPlateauStopper(self):
        """Check the iteration at which ``TrialPlateauStopper`` stops a trial.

        The trainable below reports 10.0, 11.0 and 12.0, followed by 20.0 ten
        times: 13 results in total. The metric is therefore constant from
        iteration 4 onwards.
        """
        def train(config):
            # ``config`` is unused here; it is kept because Tune invokes the
            # function trainable with a config argument.
            # NOTE(review): the positional value given to ``tune.report`` is
            # presumably recorded under the ``_metric`` key used below.
            tune.report(10.0)
            tune.report(11.0)
            tune.report(12.0)
            for _ in range(10):
                tune.report(20.0)

        # num_results = 4, no other constraints --> early stop after 7
        # (iterations 4-7 all report 20.0, so the last 4 results have
        # plateaued by iteration 7).
        stopper = TrialPlateauStopper(metric="_metric", num_results=4)

        out = tune.run(train, stop=stopper)
        self.assertEqual(out.trials[0].last_result[TRAINING_ITERATION], 7)

        # num_results = 4, grace period 9 --> early stop after 9
        # (the plateau exists from iteration 7, but no stop happens before
        # the grace period has elapsed).
        stopper = TrialPlateauStopper(metric="_metric",
                                      num_results=4,
                                      grace_period=9)

        out = tune.run(train, stop=stopper)
        self.assertEqual(out.trials[0].last_result[TRAINING_ITERATION], 9)

        # num_results = 4, metric_threshold = 22 with mode="max" --> the
        # reported metric never goes above 20.0, so the threshold is never
        # exceeded and the trial runs the full 13 iterations.
        stopper = TrialPlateauStopper(metric="_metric",
                                      num_results=4,
                                      metric_threshold=22.0,
                                      mode="max")

        out = tune.run(train, stop=stopper)
        self.assertEqual(out.trials[0].last_result[TRAINING_ITERATION], 13)
コード例 #2
0
    def test_plateau(self):
        """Check that ``TuneGridSearchCV`` honours a ``TrialPlateauStopper``.

        Every grid point uses a classifier whose objective plateaus after
        ``converge_after`` iterations. No trial should run much past that
        point before the stopper ends it. The test is skipped when the
        installed Ray version does not provide ``TrialPlateauStopper``.
        """
        try:
            from ray.tune.stopper import TrialPlateauStopper
        except ImportError:
            # ``skipTest`` raises ``unittest.SkipTest``; nothing after it in
            # this branch would ever run, so no ``return`` is needed.
            self.skipTest("`TrialPlateauStopper` not available in "
                          "current Ray version.")

        X, y = make_classification(n_samples=50,
                                   n_features=50,
                                   n_informative=3,
                                   random_state=0)

        converge_after = 4
        clf = PlateauClassifier(converge_after=converge_after)

        stopper = TrialPlateauStopper(metric="objective")

        search = TuneGridSearchCV(clf, {"foo_param": [2.0, 3.0, 4.0]},
                                  cv=2,
                                  max_iters=20,
                                  stopper=stopper,
                                  early_stopping=True)

        search.fit(X, y)

        print(search.cv_results_)

        # Converges after ``converge_after`` iterations, but the stopper needs
        # another 4 results to detect that it converged.
        # NOTE(review): 4 is presumably the stopper's default ``num_results``.
        max_expected_iters = converge_after + 4
        for iters in search.cv_results_["training_iteration"]:
            self.assertLessEqual(iters, max_expected_iters)