def run_noise_test(config):
    """Run noise tests on the best scoring epoch of every trial in an experiment.

    Make sure to train the models before calling this function.

    For every checkpoint entry in the Ray Tune experiment state, the epoch
    with the highest ``mean_accuracy`` is selected. The model is restored from
    ``checkpoint_<epoch>`` inside the trial log directory. The noise test
    results are written to ``noise.json`` in that same directory.

    :param config: The configuration of the pre-trained model. The keys
                   "name", "path" and "data_dir" are read.
    :return: dict mapping each trial's log directory name to the noise test
             results returned by ``run_noise_tests()`` for that trial.
             Trials without results are skipped.
    """
    # Load experiment data
    name = config["name"]
    experiment_path = os.path.join(config["path"], name)
    experiment_state = load_ray_tune_experiment(
        experiment_path=experiment_path, load_results=True)

    noise_results = {}

    # Go through all checkpoints in the experiment
    for checkpoint in experiment_state["checkpoints"]:
        results = checkpoint["results"]
        # Skip trials with missing OR empty results; max() on an empty
        # sequence would raise ValueError.
        if not results:
            continue

        # For each checkpoint select the epoch with the best accuracy as the
        # best epoch
        best_result = max(results, key=lambda x: x["mean_accuracy"])
        best_epoch = best_result["training_iteration"]

        # Only the basename of the recorded logdir is kept; the trial
        # directory is rebuilt under the local experiment_path.
        trial_name = os.path.basename(checkpoint["logdir"])
        logdir = os.path.join(experiment_path, trial_name)
        checkpoint_path = os.path.join(logdir,
                                       "checkpoint_{}".format(best_epoch))

        # Get the actual config from the saved version (required for sample
        # or grid search experiments). Replace paths to be the locally
        # correct ones.
        params_file = os.path.join(logdir, "params.json")
        with open(params_file, "r", encoding="utf-8") as f:
            saved_params = json.load(f)
        saved_params["data_dir"] = config["data_dir"]
        saved_params["path"] = config["path"]

        # Load pre-trained model from checkpoint and run noise test on it
        experiment = MNISTSparseExperiment(saved_params)
        experiment.restore(checkpoint_path)

        # Run the tests BEFORE opening the output file so a failure does not
        # leave an empty/truncated noise.json behind.
        trial_results = experiment.run_noise_tests()

        # Save noise results in checkpoint log dir
        noise_file = os.path.join(logdir, "noise.json")
        with open(noise_file, "w") as f:
            json.dump(trial_results, f)

        noise_results[trial_name] = trial_results

    return noise_results
class MNISTExperimentTune(tune.Trainable):
    """Ray tune trainable wrapping MNISTSparseExperiment.

    Each Trainable hook delegates directly to the wrapped experiment object.
    """

    def _setup(self, config):
        """Create the wrapped experiment from the trial ``config``."""
        self.experiment = MNISTSparseExperiment(config)

    def _train(self):
        """Train for the current Tune iteration, then evaluate.

        :return: whatever ``self.experiment.test()`` returns, which Tune
                 records as this iteration's result.
        """
        epoch = self._iteration
        self.experiment.train(epoch)
        metrics = self.experiment.test()
        return metrics

    def _save(self, checkpoint_dir):
        """Delegate checkpointing to the experiment and pass its result back."""
        saved = self.experiment.save(checkpoint_dir)
        return saved

    def _restore(self, checkpoint_dir):
        """Delegate restoring from ``checkpoint_dir`` to the experiment."""
        self.experiment.restore(checkpoint_dir)
def _setup(self, config):
    """Build an MNISTSparseExperiment from ``config`` and store it on ``self``."""
    experiment = MNISTSparseExperiment(config)
    self.experiment = experiment