def check_grid_search_complete(filter_func=None):
    """Check that every grid search found on disk covers its full grid.

    The number of settings recorded for each ``(problem, optim_cls)`` pair is
    read from the results directory returned by ``get_results_path``. It is
    compared against the expected grid dimension from
    ``multi_batch_grid_dims``, which accounts for all batch sizes. One line
    is printed per pair.

    Args:
        filter_func: Optional filter forwarded unchanged to
            ``create_grid_search`` and ``multi_batch_grid_dims``.

    Returns:
        bool: ``False`` if any pair has no data or a grid dimension that
        differs from the expected one, ``True`` otherwise.
    """
    print("\nCHECK GRID SEARCH COMPLETE\n")
    results_path = get_results_path()

    # Map (problem, optim_cls) -> number of settings found in the results.
    found_dims = {}
    for entry in _get_statistics_from_deepobs_check(results_path):
        found_dims[(entry[0], entry[1])] = entry[2]

    all_ok = True
    experiments = create_grid_search(filter_func)
    expected_dims = multi_batch_grid_dims(filter_func)
    for experiment, expected in zip(experiments, expected_dims):
        optim_name = experiment._get_optim_name()
        problem_name = experiment.get_deepobs_problem()
        pair = (problem_name, optim_name)

        if pair not in found_dims:
            print("{:55} Expect grid dim {}, but no data found".format(
                str(pair), expected))
            all_ok = False
            continue

        found = found_dims[pair]
        if found != expected:
            all_ok = False
            print("{:55} Expect grid dim {}, but found {}".format(
                str(pair), expected, found))
        else:
            print("{} passed".format(pair))
    return all_ok
def rerun_best_for_seed(damping, problem, curvature, mode, metric, output_dir,
                        seed, extended_logs):
    """Rerun the best setting of one grid search with a single seed.

    Selects the first grid search produced by ``create_grid_search`` whose
    curvature, damping and problem match the arguments. It then wraps that
    search in ``BPBestRun`` and reruns its best setting for ``seed``.

    Args:
        damping: Damping value the grid search must match.
        problem: Problem the grid search must match.
        curvature: Curvature the grid search must match.
        mode: Forwarded to ``BPBestRun``.
        metric: Forwarded to ``BPBestRun``.
        output_dir: Output directory forwarded to ``BPBestRun``.
        seed: Seed to rerun with.
        extended_logs: Forwarded to ``BPBestRun.rerun_best_for_seeds``.
    """

    def filter_config(curv, damp, prob):
        # Parameter names are kept unchanged in case create_grid_search
        # passes them by keyword.
        return curv == curvature and damp == damping and prob == problem

    matching_searches = create_grid_search(filter_func=filter_config)
    selected = matching_searches[0]
    runner = BPBestRun(selected, mode, metric, output_dir=output_dir)
    runner.rerun_best_for_seeds([seed], extended_logs=extended_logs)
def multi_batch_grid_dims(filter_func=None):
    """Return each grid search's dimension, including the batch sizes.

    Each entry is the grid dimension reported by ``get_grid_dim`` multiplied
    by the number of entries in ``BATCH_SIZES``. Entries are in the same
    order as ``create_grid_search(filter_func)``.

    Args:
        filter_func: Optional filter forwarded to ``create_grid_search``.

    Returns:
        list: One dimension per grid search.
    """
    searches = create_grid_search(filter_func)
    batch_factor = len(BATCH_SIZES)

    dims = []
    for search in searches:
        dims.append(batch_factor * search.get_grid_dim())
    return dims
def check_complete(path, filter_func=None):
    """Check that results exist for every (problem, optimizer) pair.

    Reads the statistics found under ``path``. For each experiment produced
    by ``create_grid_search(filter_func)``, prints whether its
    ``(problem, optim_cls)`` pair was found or is missing.

    Args:
        path: Results directory passed to ``_get_statistics_from_deepobs_check``.
        filter_func: Optional filter forwarded to ``create_grid_search``.

    Returns:
        bool: ``True`` if every pair was found, ``False`` otherwise.
    """
    print("\nCHECK PROBLEMS AND OPTIMS COMPLETE\n")
    statistics = _get_statistics_from_deepobs_check(path)

    # Pairs (problem, optim_cls) that have results. A set gives O(1)
    # membership tests instead of a linear list scan per experiment.
    found_pairs = {(item[0], item[1]) for item in statistics}

    passing = True
    for experiment in create_grid_search(filter_func):
        optim_cls = experiment._get_optim_name()
        problem = experiment.get_deepobs_problem()

        item = (problem, optim_cls)
        if item in found_pairs:
            print("{:55} found".format(str(item)))
        else:
            passing = False
            print("{:55} missing".format(str(item)))
    return passing
def get_results_path():
    """Return the output directory of the first unfiltered grid search.

    NOTE(review): presumably all grid searches share this results root.
    Confirm against ``create_grid_search`` / ``get_output_dir``.
    """
    grid_searches = create_grid_search()
    return grid_searches[0].get_output_dir()