Ejemplo n.º 1
0
class MNISTExperimentTune(tune.Trainable):
    """Adapter that exposes MNISTSparseExperiment via the Trainable hooks.

    Every hook simply forwards to the wrapped experiment instance, which is
    created in ``_setup`` and stored on ``self.experiment``.
    """

    def _setup(self, config):
        """Create the wrapped experiment from the given configuration.

        :param config: Configuration passed straight to MNISTSparseExperiment.
        """
        wrapped = MNISTSparseExperiment(config)
        self.experiment = wrapped

    def _train(self):
        """Train for the current iteration and return the test outcome.

        :return: Whatever ``self.experiment.test()`` returns.
        """
        epoch = self._iteration
        self.experiment.train(epoch)
        metrics = self.experiment.test()
        return metrics

    def _save(self, checkpoint_dir):
        """Delegate checkpoint saving to the wrapped experiment.

        :param checkpoint_dir: Directory handed to ``experiment.save``.
        :return: Whatever ``self.experiment.save(checkpoint_dir)`` returns.
        """
        saved = self.experiment.save(checkpoint_dir)
        return saved

    def _restore(self, checkpoint_dir):
        """Delegate checkpoint restoring to the wrapped experiment.

        :param checkpoint_dir: Directory handed to ``experiment.restore``.
        """
        self.experiment.restore(checkpoint_dir)
Ejemplo n.º 2
0
def run_noise_test(config):
    """Run noise test on the best scoring model found during training. Make
    sure to train the models before calling this function.

    For every trial checkpoint entry with recorded results, the epoch with
    the highest "mean_accuracy" is selected, the model saved for that epoch
    is restored, and the noise test results are written to "noise.json"
    inside the trial's log directory.

    :param config: The configuration of the pre-trained model. The keys
                   "name", "path" and "data_dir" are read from it.
    :return: dict with noise test results over all experiments, keyed by the
             base name of each trial's log directory. Trials without any
             recorded results are omitted.
    """
    # Load experiment data
    name = config["name"]
    experiment_path = os.path.join(config["path"], name)
    experiment_state = load_ray_tune_experiment(
        experiment_path=experiment_path, load_results=True)

    noise_results = {}

    # Go through all checkpoints in the experiment
    for checkpoint in experiment_state["checkpoints"]:
        results = checkpoint["results"]
        # Skip trials with missing *or empty* results: max() raises
        # ValueError on an empty sequence. len() is used instead of
        # truthiness so array-like containers are handled as well.
        if results is None or len(results) == 0:
            continue

        # For each checkpoint select the epoch with the best accuracy as the
        # best epoch
        best_result = max(results, key=lambda x: x["mean_accuracy"])
        best_epoch = best_result["training_iteration"]

        # Rebuild the log dir under the local experiment path; the stored
        # "logdir" may point to wherever the experiment originally ran.
        trial_name = os.path.basename(checkpoint["logdir"])
        logdir = os.path.join(experiment_path, trial_name)
        checkpoint_path = os.path.join(logdir,
                                       "checkpoint_{}".format(best_epoch))

        # Get the actual config from the saved version (required for sample
        # or grid search experiments). Replace paths to be the locally
        # correct ones
        filename = os.path.join(logdir, "params.json")
        with open(filename, "r") as f:
            saved_params = json.load(f)
        saved_params["data_dir"] = config["data_dir"]
        saved_params["path"] = config["path"]

        # Load pre-trained model from checkpoint and run noise test on it
        experiment = MNISTSparseExperiment(saved_params)
        experiment.restore(checkpoint_path)
        trial_noise = experiment.run_noise_tests()

        # Save noise results in checkpoint log dir
        noise_test = os.path.join(logdir, "noise.json")
        with open(noise_test, "w") as f:
            json.dump(trial_noise, f)

        noise_results[trial_name] = trial_noise

    return noise_results