Ejemplo n.º 1
0
    def testFromPath(self):
        """Check that ExperimentAnalysis can reload results from disk after Ray is restarted.

        The experiment is run once, then analysed from ``self.test_path``.
        Ray is then shut down and the global Tune registry is replaced with a
        newly constructed one before the results are loaded a second time.
        Both loads must yield a best trial for ``self.metric`` (mode "max").
        """
        self.run_test_exp()
        analysis = ExperimentAnalysis(self.test_path)

        self.assertIsNotNone(
            analysis.get_best_trial(metric=self.metric, mode="max"),
            "no best trial found after running the experiment",
        )

        ray.shutdown()
        # Swap in a fresh process-wide registry, presumably so the reload
        # below cannot rely on anything registered by run_test_exp().
        ray.tune.registry._global_registry = ray.tune.registry._Registry(
            prefix="global"
        )

        analysis = ExperimentAnalysis(self.test_path)

        # This will be None if validate_trainable during loading fails
        self.assertIsNotNone(
            analysis.get_best_trial(metric=self.metric, mode="max"),
            "no best trial found after reloading with a fresh registry",
        )
Ejemplo n.º 2
0
def load_best_model(exp: tune.ExperimentAnalysis,
                    metric="val_loss",
                    mode="min"):
    """Load the model stored in the best checkpoint across all trials.

    Args:
        exp: Analysis object for a finished Tune experiment.
        metric: Name of the reported metric used to rank trials/checkpoints.
        mode: ``"min"`` or ``"max"``; whether lower or higher ``metric`` is better.

    Returns:
        The ``CLIPFineTunedModel`` restored from ``<best checkpoint>/checkpoint``.

    Raises:
        ValueError: If no best trial or no checkpoint can be found for
            ``metric``/``mode``.
    """
    import os

    # scope="all": rank each trial by its best reported result, not just its last one.
    tr = exp.get_best_trial(metric=metric, mode=mode, scope="all")
    if tr is None:
        raise ValueError(
            f"No best trial found for metric={metric!r}, mode={mode!r}"
        )
    chkpoint = exp.get_best_checkpoint(tr, metric=metric, mode=mode)
    if chkpoint is None:
        raise ValueError(
            f"No checkpoint found for best trial {tr} "
            f"(metric={metric!r}, mode={mode!r})"
        )
    # NOTE(review): assumes get_best_checkpoint returns a directory path
    # (str/PathLike), as in the Ray version this was written against; newer
    # Ray releases return a Checkpoint object instead.
    return CLIPFineTunedModel.load_from_checkpoint(
        os.path.join(chkpoint, "checkpoint")
    )
Ejemplo n.º 3
0
def get_best_trial(analysis: tune.ExperimentAnalysis, objective: Objective, scope: str):
    """Return the best trial in ``analysis`` according to ``objective``.

    The metric name is derived with ``full_metric_name(objective)`` and the
    optimisation direction is taken from ``objective.mode``; ``scope`` is
    forwarded unchanged to ``ExperimentAnalysis.get_best_trial``.
    """
    metric_name = full_metric_name(objective)
    best = analysis.get_best_trial(metric_name, mode=objective.mode, scope=scope)
    return best