コード例 #1
0
def run_noise_test(config):
    """Run the noise test on the best-scoring epoch of every trial of an experiment.

    Make sure to train the models before calling this function.

    For each checkpoint entry of the Ray Tune experiment stored under
    ``os.path.join(config["path"], config["name"])``:
    1. Select the training iteration with the highest ``mean_accuracy``.
    2. Restore that iteration's checkpoint into a ``SparseSpeechExperiment``.
    3. Execute ``run_noise_tests()`` on it.
    4. Write the results to ``noise.json`` in the trial's log directory.

    :param config: The configuration of the pre-trained model. Keys read here:
        ``name``, ``path`` and ``data_dir``. Optionally, ``sync_function`` (a
        shell command template with ``{local_dir}`` and ``{remote_dir}``
        placeholders) together with ``upload_dir``.
    :return: dict mapping each trial's log directory to the noise test
        results returned by ``run_noise_tests()`` for that trial. Checkpoint
        entries without any recorded results are skipped.
    """
    # Load experiment data
    name = config["name"]
    experiment_path = os.path.join(config["path"], name)
    experiment_state = load_ray_tune_experiment(
        experiment_path=experiment_path, load_results=True)

    all_results = {}

    # Go through all checkpoints in the experiment
    for checkpoint in experiment_state["checkpoints"]:
        results = checkpoint["results"]
        # Skip entries that never reported anything: ``max`` below would
        # raise ValueError on an empty result list.
        if not results:
            continue

        # For each checkpoint select the epoch with the best accuracy as the best epoch
        best_result = max(results, key=lambda x: x["mean_accuracy"])
        best_epoch = best_result["training_iteration"]
        # Shallow copy so the overrides below do not mutate the loaded
        # experiment state.
        best_config = dict(best_result["config"])
        print("best_epoch: ", best_epoch)

        # Point the restored config at the locations given by the caller,
        # which may differ from the ones used during training.
        best_config["path"] = config["path"]
        best_config["data_dir"] = config["data_dir"]

        # Load pre-trained model from checkpoint and run noise test on it
        logdir = os.path.join(experiment_path,
                              os.path.basename(checkpoint["logdir"]))
        checkpoint_path = os.path.join(logdir,
                                       "checkpoint_{}".format(best_epoch))
        experiment = SparseSpeechExperiment(best_config)
        experiment.restore(checkpoint_path)

        # Run the tests before opening the output file so that a failure
        # does not leave an empty/truncated noise.json behind.
        res = experiment.run_noise_tests()
        noise_test = os.path.join(logdir, "noise.json")
        with open(noise_test, "w") as f:
            json.dump(res, f)
        all_results[logdir] = res

        # Compute total noise score
        total_correct = 0
        for k, v in res.items():
            print(k, v, v["total_correct"])
            total_correct += v["total_correct"]
        print("Total across all noise values", total_correct)

    # Optionally sync the experiment directory to remote storage (e.g. S3)
    sync_function = config.get("sync_function", None)
    if sync_function is not None:
        upload_dir = config["upload_dir"]
        final_cmd = sync_function.format(local_dir=experiment_path,
                                         remote_dir=upload_dir)
        # ``sync_function`` is a shell command template taken from the
        # (trusted) config, hence shell=True. The process is started in the
        # background and not waited on, so upload failures are not reported
        # back to the caller.
        subprocess.Popen(final_cmd, shell=True)

    return all_results
コード例 #2
0
 def _setup(self, config):
     """Create the ``SparseSpeechExperiment`` for ``config``.

     Stores the experiment on ``self.experiment``.

     NOTE(review): presumably a Ray Tune ``Trainable._setup`` override that
     is called once per trial with the trial's config. Confirm this against
     the enclosing class, which is not visible here.
     """
     self.experiment = SparseSpeechExperiment(config)