# Example #1 (score: 0)
def create_management_objects(dir_tree, logger, pbar, config):
    """Instantiates the run-management objects that were not supplied.

    Any of `dir_tree`, `logger`, `pbar` passed as None (or the sentinel
    string "default_pbar" for `pbar`) is built from `config`; objects the
    caller already provides are returned untouched.
    """
    # Directory tree: build and materialise one only if none was supplied
    if dir_tree is None:
        dir_tree = DirectoryTree(alg_name=config.alg_name,
                                 task_name=config.task_name,
                                 desc=config.desc,
                                 seed=config.seed,
                                 root=config.root_dir)
        dir_tree.create_directories()

    # Logger: a MASTER logger writing into the seed directory; the config
    # is always dumped at debug level for traceability
    if logger is None:
        logger = create_logger('MASTER', config.log_level,
                               dir_tree.seed_dir / 'logger.out')
    logger.debug(config_to_str(config))

    # Progress bar: the sentinel string requests a fresh tqdm instance
    if pbar == "default_pbar":
        pbar = tqdm()

    if pbar is not None:
        run_id = '/'.join([dir_tree.storage_dir.name,
                           dir_tree.experiment_dir.name,
                           dir_tree.seed_dir.name])
        pbar.n = 0
        pbar.desc += run_id
        pbar.total = config.max_episodes

    return dir_tree, logger, pbar
def create_experiment_dir(storage_name_id, config, config_unique_dict, SEEDS, root_dir, git_hashes=None):
    """Creates one ready-to-run seed directory per seed for a new experiment.

    Side effects: mutates `config.seed` and `config_unique_dict['seed']` in
    place (they are left at the last seed value).

    Returns the DirectoryTree of the last seed directory created.

    Raises:
        ValueError: if SEEDS is empty (the original code would hit a
            NameError on the return statement instead).
    """
    if not SEEDS:
        raise ValueError("SEEDS must contain at least one seed")

    # Determine experiment number from a throw-away tree
    # (the seed value used here is irrelevant)

    tmp_dir_tree = DirectoryTree(id=storage_name_id, alg_name=config.alg_name, task_name=config.task_name,
                                 desc=config.desc, seed=1, git_hashes=git_hashes, root=root_dir)

    # Slice off the 'experiment' prefix instead of str.strip('experiment'):
    # strip() removes a *character set* from both ends and would mangle
    # unexpected directory names
    experiment_num = int(tmp_dir_tree.experiment_dir.name[len('experiment'):])

    # For each seed in these experiments, creates a directory

    for seed in SEEDS:
        config.seed = seed
        config_unique_dict['seed'] = seed

        # Creates the experiment directory

        dir_tree = DirectoryTree(id=storage_name_id,
                                 alg_name=config.alg_name,
                                 task_name=config.task_name,
                                 desc=config.desc,
                                 seed=config.seed,
                                 experiment_num=experiment_num,
                                 git_hashes=git_hashes,
                                 root=root_dir)

        dir_tree.create_directories()

        # Saves the config as json file (to be run later)

        save_config_to_json(config, filename=str(dir_tree.seed_dir / 'config.json'))

        # Saves a dictionary of what makes each seed_dir unique (just for display on graphs)

        validate_config_unique(config, config_unique_dict)
        save_dict_to_json(config_unique_dict, filename=str(dir_tree.seed_dir / 'config_unique.json'))

        # Creates empty file UNHATCHED meaning that the experiment is ready
        # to be run (context manager guarantees the handle is closed even
        # if the write fails)

        with open(str(dir_tree.seed_dir / 'UNHATCHED'), 'w+'):
            pass

    return dir_tree
# Example #3 (score: 0)
def _compute_seed_scores(storage_dir, performance_metric,
                         performance_aggregation, group_key, bar_key,
                         re_run_if_exists, save_dir, logger, root_dir,
                         n_eval_runs):
    """Computes one aggregated score per seed for every experiment in
    `storage_dir` and pickles them to
    `storage_dir/save_dir/{save_dir}_seed_scores.pkl`.

    Scores are stored in a two-level OrderedDict keyed first by `bar_key`
    then by `group_key`. A companion json file records how the scores were
    produced (metric, aggregation, number of evaluation runs).
    """
    # Skips the whole computation if scores are already saved,
    # unless a re-run is explicitly requested
    if (storage_dir / save_dir /
            f"{save_dir}_seed_scores.pkl").exists() and not re_run_if_exists:
        logger.info(
            f" SKIPPING {storage_dir} - {save_dir}_seed_scores.pkl already exists"
        )
        return

    else:
        logger.info(f"Benchmarking {storage_dir}...")

    assert group_key in [
        'task_name', 'storage_name', 'experiment_num', 'alg_name'
    ]
    assert bar_key in [
        'task_name', 'storage_name', 'experiment_num', 'alg_name'
    ]

    # Initialize container

    scores = OrderedDict()

    # Get all experiment directories

    all_experiments = DirectoryTree.get_all_experiments(
        storage_dir=storage_dir)

    for experiment_dir in all_experiments:

        # For that experiment, get all seed directories

        experiment_seeds = DirectoryTree.get_all_seeds(
            experiment_dir=experiment_dir)

        # Guards against experiments with no seed directory: the code after
        # the seed loop would otherwise reference outer_key/inner_key
        # before assignment (NameError), or reuse stale keys
        if not experiment_seeds:
            continue

        # Initialize container

        all_seeds_scores = []

        for i, seed_dir in enumerate(experiment_seeds):
            # Prints which seed directory is being treated

            logger.debug(f"{seed_dir}")

            # Loads training config

            config_dict = load_dict_from_json(str(seed_dir / "config.json"))

            # Selects how data will be identified.
            # Slicing is used instead of str.strip('experiment'), which
            # strips a *character set* from both ends, not a prefix.

            keys = {
                "task_name": config_dict["task_name"],
                "storage_name": seed_dir.parents[1].name,
                "alg_name": config_dict["alg_name"],
                "experiment_num": seed_dir.parents[0].name[len('experiment'):]
            }

            outer_key = keys[bar_key]
            inner_key = keys[group_key]

            # Evaluation phase

            if performance_metric == 'evaluation_runs':

                assert n_eval_runs is not None

                try:
                    from evaluate import evaluate, get_evaluation_args
                except ImportError as e:
                    raise ImportError(
                        f"{e}\nTo evaluate models based on --performance_metric='evaluation_runs' "
                        f"alfred.benchmark assumes the following structure that the working directory contains "
                        f"a file called evaluate.py containing two functions: "
                        f"\n\t1. a function evaluate() that returns a score for each evaluation run"
                        f"\n\t2. a function get_evaluation_args() that returns a Namespace of arguments for evaluate()"
                    )

                # Sets config for evaluation phase

                eval_config = get_evaluation_args(overwritten_args="")
                eval_config.storage_name = seed_dir.parents[1].name
                eval_config.experiment_num = int(
                    seed_dir.parents[0].name[len("experiment"):])
                eval_config.seed_num = int(seed_dir.name[len("seed"):])
                eval_config.render = False
                eval_config.n_episodes = n_eval_runs
                eval_config.root_dir = root_dir

                # Evaluates agent and stores the return

                performance_data = evaluate(eval_config)

            else:

                # Loads training data recorded during training

                loaded_recorder = Recorder.init_from_pickle_file(
                    filename=str(seed_dir / 'recorders' /
                                 'train_recorder.pkl'))

                performance_data = loaded_recorder.tape[performance_metric]

            # Aggregation phase: reduces the per-run/per-step data
            # to a single scalar score for this seed

            if performance_aggregation == 'min':
                score = np.min(performance_data)

            elif performance_aggregation == 'max':
                score = np.max(performance_data)

            elif performance_aggregation == 'avg':
                score = np.mean(performance_data)

            elif performance_aggregation == 'last':
                score = performance_data[-1]

            elif performance_aggregation == 'mean_on_last_20_percents':
                eighty_percent_index = int(0.8 * len(performance_data))
                score = np.mean(performance_data[eighty_percent_index:])
            else:
                raise NotImplementedError

            all_seeds_scores.append(score)

        # outer_key/inner_key are identical for all seeds of one experiment,
        # so the values left by the last loop iteration are valid here
        if outer_key not in scores.keys():
            scores[outer_key] = OrderedDict()

        scores[outer_key][inner_key] = np.stack(all_seeds_scores)

    os.makedirs(storage_dir / save_dir, exist_ok=True)

    with open(storage_dir / save_dir / f"{save_dir}_seed_scores.pkl",
              "wb") as f:
        pickle.dump(scores, f)

    scores_info = {
        'n_eval_runs': n_eval_runs,
        'performance_metric': performance_metric,
        'performance_aggregation': performance_aggregation
    }

    save_dict_to_json(scores_info,
                      filename=str(storage_dir / save_dir /
                                   f"{save_dir}_seed_scores_info.json"))
# Example #4 (score: 0)
def _make_vertical_densities_figure(storage_dirs, visuals_file,
                                    additional_curves_file, make_box_plot,
                                    queried_performance_metric,
                                    queried_performance_aggregation, save_dir,
                                    load_dir, logger):
    """Plots vertical score densities (or box-plots), one graph per task,
    from the seed scores saved by summarize_search in every storage_dir.

    The figure and a json listing its source storage_dirs are saved under
    `storage_dir/save_dir` for every storage_dir.

    Raises:
        AssertionError: if the saved scores were not all produced with the
            same (and the queried) performance metric/aggregation.
    """
    # Initialize containers

    all_means = OrderedDict()
    long_labels = OrderedDict()
    titles = OrderedDict()
    labels = OrderedDict()
    colors = OrderedDict()
    markers = OrderedDict()
    all_performance_metrics = []
    all_performance_aggregation = []

    # Loads the visuals and additional-curves dictionaries once: their
    # content does not depend on the graph being drawn, so there is no
    # need to re-read the files on every loop iteration

    if visuals_file is not None:
        visuals = load_dict_from_json(visuals_file)
    else:
        visuals = None

    if additional_curves_file is not None:
        additional_curves = load_dict_from_json(additional_curves_file)
    else:
        additional_curves = None

    # Gathers data

    for storage_dir in storage_dirs:
        logger.debug(storage_dir)

        # Loads the scores and scores_info saved by summarize_search

        with open(str(storage_dir / load_dir / f"{load_dir}_seed_scores.pkl"),
                  "rb") as f:
            scores = pickle.load(f)

        # NOTE(review): scores are read from `load_dir` but scores_info is
        # hard-coded to "summary" — confirm this is intended when
        # load_dir != "summary"
        scores_info = load_dict_from_json(
            str(storage_dir / "summary" / f"summary_seed_scores_info.json"))
        all_performance_metrics.append(scores_info['performance_metric'])
        all_performance_aggregation.append(
            scores_info['performance_aggregation'])

        x = list(scores.keys())[0]
        storage_name = storage_dir.name

        # The graph key (outer_key, e.g. the task) is encoded in the
        # storage name; extracted once (the original extracted it twice)

        _, _, _, outer_key, _ = DirectoryTree.extract_info_from_storage_name(
            storage_name)

        # Adding task_name if first time it is encountered

        if outer_key not in all_means:
            all_means[outer_key] = OrderedDict()

        # Taking the mean across evaluations and seeds

        all_means[outer_key][storage_name] = [
            array.mean() for array in scores[x].values()
        ]

        if outer_key not in long_labels:
            long_labels[outer_key] = [storage_dir]
        else:
            long_labels[outer_key].append(storage_dir)

    # Security checks

    assert len(set(all_performance_metrics)) == 1 and len(set(all_performance_aggregation)) == 1, \
        "Error: all seeds do not have scores computed using the same performance metric or performance aggregation. " \
        "You should benchmark with --re_run_if_exists=True using the desired --performance_aggregation and " \
        "--performance_metric so that all seeds that you want to compare have the same metrics."
    actual_performance_metric = all_performance_metrics.pop()
    actual_performance_aggregation = all_performance_aggregation.pop()

    assert queried_performance_metric == actual_performance_metric and \
           queried_performance_aggregation == actual_performance_aggregation, \
        "Error: The performance_metric or performance_aggregation that was queried for the vertical_densities " \
        "is not the same as what was saved by summarize_search. You should benchmark with --re_run_if_exists=True " \
        "using the desired --performance_aggregation and  --performance_metric so that all seeds that you want " \
        "to compare have the same metrics."

    # Initialize figure layout: 1x3 for exactly three graphs, otherwise
    # the squarest grid that fits all graphs

    n_graphs = len(all_means.keys())

    if n_graphs == 3:
        axes_shape = (1, 3)

    elif n_graphs > 1:
        i_max = int(np.ceil(np.sqrt(len(all_means.keys()))))
        axes_shape = (int(np.ceil(len(all_means.keys()) / i_max)), i_max)
    else:
        axes_shape = (1, 1)

    # Creates figure

    gs = gridspec.GridSpec(*axes_shape)
    fig = plt.figure(figsize=(12 * axes_shape[1], 5 * axes_shape[0]))

    for i, outer_key in enumerate(all_means.keys()):

        # Selects right ax object

        if axes_shape == (1, 1):
            current_ax = fig.add_subplot(gs[0, 0])
        elif any(np.array(axes_shape) == 1):
            current_ax = fig.add_subplot(gs[0, i])
        else:
            current_ax = fig.add_subplot(gs[i // axes_shape[1],
                                            i % axes_shape[1]])

        # Collect algorithm names (initialized so that `algs` is always
        # bound, even if some long_labels entry is not a PosixPath)

        algs = []
        if all([
                type(long_label) is pathlib.PosixPath
                for long_label in long_labels[outer_key]
        ]):
            for path in long_labels[outer_key]:
                _, _, alg, _, _ = DirectoryTree.extract_info_from_storage_name(
                    path.name)
                algs.append(alg)

        # Sets visuals (falls back to raw keys/paths when no visuals file
        # or no matching entry is provided)

        if type(visuals) is dict and 'titles_dict' in visuals.keys():
            titles[outer_key] = visuals['titles_dict'][outer_key]
        else:
            titles[outer_key] = outer_key

        if type(visuals) is dict and 'labels_dict' in visuals.keys():
            labels[outer_key] = [visuals['labels_dict'][alg] for alg in algs]
        else:
            labels[outer_key] = long_labels[outer_key]

        if type(visuals) is dict and 'colors_dict' in visuals.keys():
            colors[outer_key] = [visuals['colors_dict'][alg] for alg in algs]
        else:
            colors[outer_key] = [None for _ in long_labels[outer_key]]

        if type(visuals) is dict and 'markers_dict' in visuals.keys():
            markers[outer_key] = [visuals['markers_dict'][alg] for alg in algs]
        else:
            markers[outer_key] = [None for _ in long_labels[outer_key]]

        logger.info(
            f"Graph for {outer_key}:\n\tlabels={labels}\n\tcolors={colors}\n\tmarkers={markers}"
        )

        if additional_curves_file is not None:
            hlines = additional_curves['hlines'][outer_key]
        else:
            hlines = None

        # Makes the plots

        plot_vertical_densities(
            ax=current_ax,
            ys=list(all_means[outer_key].values()),
            labels=labels[outer_key],
            colors=colors[outer_key],
            make_boxplot=make_box_plot,
            title=titles[outer_key].upper(),
            ylabel=
            f"{actual_performance_aggregation}-{actual_performance_metric}",
            hlines=hlines)

    # Saves the figure

    plt.tight_layout()

    filename_addon = "boxplot" if make_box_plot else ""

    for storage_dir in storage_dirs:
        os.makedirs(storage_dir / save_dir, exist_ok=True)

        fig.savefig(storage_dir / save_dir /
                    f'{save_dir}_vertical_densities_{filename_addon}.pdf',
                    bbox_inches="tight")

        # BUGFIX: the original saved `[str(storage_dir) in storage_dirs]`,
        # a single membership-test boolean (always False since it compares
        # a str to Path objects); the intent is to record the list of
        # source directories the figure was built from
        save_dict_to_json([str(d) for d in storage_dirs],
                          storage_dir / save_dir /
                          f'{save_dir}_vertical_densities_sources.json')

    plt.close(fig)
# Example #5 (score: 0)
def _make_benchmark_learning_figure(x_data,
                                    y_data,
                                    x_metric,
                                    y_metric,
                                    y_error_bars,
                                    storage_dirs,
                                    save_dir,
                                    logger,
                                    n_labels=np.inf,
                                    visuals_file=None,
                                    additional_curves_file=None):
    """Plots learning curves (mean +/- error bars), one graph per outer_key
    of `y_data`, and saves the figure under every storage_dir/save_dir.

    `y_error_bars` selects the error-bar computation: "stderr" (symmetric
    standard error) or "bootstrapped_CI" (95% confidence interval).
    `n_labels` limits the legend to the n best curves per graph.
    """
    # Initialize containers

    y_data_means = OrderedDict()
    y_data_err_up = OrderedDict()
    y_data_err_down = OrderedDict()
    long_labels = OrderedDict()
    titles = OrderedDict()
    x_axis_titles = OrderedDict()
    y_axis_titles = OrderedDict()
    labels = OrderedDict()
    colors = OrderedDict()
    markers = OrderedDict()

    for outer_key in y_data:
        y_data_means[outer_key] = OrderedDict()
        y_data_err_up[outer_key] = OrderedDict()
        y_data_err_down[outer_key] = OrderedDict()

    # Initialize figure layout: 1x3 for exactly three graphs, otherwise
    # the squarest grid that fits all graphs

    n_graphs = len(y_data.keys())

    if n_graphs == 3:
        axes_shape = (1, 3)

    elif n_graphs > 1:
        i_max = int(np.ceil(np.sqrt(len(y_data.keys()))))
        axes_shape = (int(np.ceil(len(y_data.keys()) / i_max)), i_max)
    else:
        axes_shape = (1, 1)

    # Creates figure

    gs = gridspec.GridSpec(*axes_shape)
    fig = plt.figure(figsize=(8 * axes_shape[1], 4 * axes_shape[0]))

    # Compute means and stds for all inner_key curve from raw data

    for i, outer_key in enumerate(y_data.keys()):
        for inner_key in y_data[outer_key].keys():
            x_data[outer_key][inner_key] = x_data[outer_key][inner_key][
                0]  # assumes all x_data are the same

            if y_error_bars == "stderr":
                # Number of samples (seeds/runs) for the standard error,
                # captured before stacking
                n_samples = len(y_data[outer_key][inner_key])
                y_data_means[outer_key][inner_key] = np.stack(
                    y_data[outer_key][inner_key], axis=-1).mean(-1)
                # BUGFIX: stderr = std / sqrt(n_samples); the original
                # divided by the number of *time steps*
                # (len of the mean curve), consistent now with _gather_scores
                y_data_err_up[outer_key][inner_key] = np.stack(y_data[outer_key][inner_key], axis=-1).std(-1) \
                                                      / n_samples ** 0.5
                # BUGFIX: assign per-curve instead of rebinding the whole
                # container (`y_data_err_down = y_data_err_up` aliased both
                # dicts, discarding the initialized structure)
                y_data_err_down[outer_key][inner_key] = \
                    y_data_err_up[outer_key][inner_key]

            elif y_error_bars == "bootstrapped_CI":
                y_data_samples = np.stack(
                    y_data[outer_key][inner_key],
                    axis=-1)  # dim=0 is accross time (n_time_steps, n_samples)
                mean, err_up, err_down = get_95_confidence_interval_of_sequence(
                    list_of_samples=y_data_samples, method=y_error_bars)
                y_data_means[outer_key][inner_key] = mean
                y_data_err_up[outer_key][inner_key] = err_up
                y_data_err_down[outer_key][inner_key] = err_down

            else:
                raise NotImplementedError

        long_labels[outer_key] = list(y_data_means[outer_key].keys())

        # Limits the number of labels to be displayed (only displays labels of n_labels best experiments)

        if n_labels < np.inf:
            mean_over_entire_curves = np.array(
                [array.mean() for array in y_data_means[outer_key].values()])
            n_max_idxs = (-mean_over_entire_curves).argsort()[:n_labels]

            for k in range(len(long_labels[outer_key])):
                if k in n_max_idxs:
                    continue
                else:
                    long_labels[outer_key][k] = None

        # Selects right ax object

        if axes_shape == (1, 1):
            current_ax = fig.add_subplot(gs[0, 0])
        elif any(np.array(axes_shape) == 1):
            current_ax = fig.add_subplot(gs[0, i])
        else:
            current_ax = fig.add_subplot(gs[i // axes_shape[1],
                                            i % axes_shape[1]])

        # Collect algorithm names (initialized so that `algs` is always
        # bound, even if some long_labels entry is not a PosixPath)

        algs = []
        if all([
                type(long_label) is pathlib.PosixPath
                for long_label in long_labels[outer_key]
        ]):
            for path in long_labels[outer_key]:
                _, _, alg, _, _ = DirectoryTree.extract_info_from_storage_name(
                    path.name)
                algs.append(alg)

        # Loads visuals dictionaries

        if visuals_file is not None:
            visuals = load_dict_from_json(visuals_file)
        else:
            visuals = None

        # Loads additional curves file

        if additional_curves_file is not None:
            additional_curves = load_dict_from_json(additional_curves_file)
        else:
            additional_curves = None

        # Sets visuals (falls back to raw keys/metrics when no visuals file
        # or no matching entry is provided)

        if type(visuals) is dict and 'titles_dict' in visuals.keys():
            titles[outer_key] = visuals['titles_dict'][outer_key]
        else:
            titles[outer_key] = outer_key

        if type(visuals) is dict and 'axis_titles_dict' in visuals.keys():
            x_axis_titles[outer_key] = visuals['axis_titles_dict'][x_metric]
            y_axis_titles[outer_key] = visuals['axis_titles_dict'][y_metric]
        else:
            x_axis_titles[outer_key] = x_metric
            y_axis_titles[outer_key] = y_metric

        if type(visuals) is dict and 'labels_dict' in visuals.keys():
            labels[outer_key] = [
                visuals['labels_dict'][inner_key]
                for inner_key in y_data_means[outer_key].keys()
            ]
        else:
            labels[outer_key] = long_labels[outer_key]

        if type(visuals) is dict and 'colors_dict' in visuals.keys():
            colors[outer_key] = [
                visuals['colors_dict'][inner_key]
                for inner_key in y_data_means[outer_key].keys()
            ]
        else:
            colors[outer_key] = [None for _ in long_labels[outer_key]]

        if type(visuals) is dict and 'markers_dict' in visuals.keys():
            markers[outer_key] = [
                visuals['markers_dict'][inner_key]
                for inner_key in y_data_means[outer_key].keys()
            ]
        else:
            markers[outer_key] = [None for _ in long_labels[outer_key]]

        logger.info(
            f"Graph for {outer_key}:\n\tlabels={labels}\n\tcolors={colors}\n\tmarkers={markers}"
        )

        if additional_curves_file is not None:
            hlines = additional_curves['hlines'][outer_key]
            n_lines = len(hlines)
        else:
            hlines = None
            n_lines = 0

        # Plots the curves

        plot_curves(
            current_ax,
            xs=list(x_data[outer_key].values()),
            ys=list(y_data_means[outer_key].values()),
            fill_up=list(y_data_err_up[outer_key].values()),
            fill_down=list(y_data_err_down[outer_key].values()),
            labels=labels[outer_key],
            colors=colors[outer_key],
            markers=markers[outer_key],
            xlabel=x_axis_titles[outer_key],
            ylabel=y_axis_titles[outer_key] if i == 0 else "",
            title=titles[outer_key].upper(),
            add_legend=True if i == (len(list(y_data.keys())) - 1) else False,
            legend_outside=True,
            legend_loc="upper right",
            legend_pos=(0.95, -0.2),
            legend_n_columns=len(list(y_data_means[outer_key].values())) +
            n_lines,
            hlines=hlines,
            tick_font_size=22,
            axis_font_size=26,
            legend_font_size=26,
            title_font_size=28)

    plt.tight_layout()

    for storage_dir in storage_dirs:
        os.makedirs(storage_dir / save_dir, exist_ok=True)
        fig.savefig(storage_dir / save_dir / f'{save_dir}_learning.pdf',
                    bbox_inches='tight')

    plt.close(fig)
# Example #6 (score: 0)
def _gather_experiments_training_curves(storage_dir,
                                        graph_key,
                                        curve_key,
                                        logger,
                                        x_metric,
                                        y_metric,
                                        x_data=None,
                                        y_data=None):
    """Collects training curves (x_metric, y_metric) for every seed of every
    experiment in `storage_dir` into two nested OrderedDicts keyed by
    `graph_key` (one graph per key) then `curve_key` (one curve per key).

    `x_data`/`y_data` may be passed in to accumulate across several
    storage_dirs; they are mutated in place and returned.
    """

    # Initialize containers

    if x_data is None:
        x_data = OrderedDict()
    else:
        assert type(x_data) is OrderedDict

    if y_data is None:
        y_data = OrderedDict()
    else:
        assert type(y_data) is OrderedDict

    # Get all experiment directories

    all_experiments = DirectoryTree.get_all_experiments(
        storage_dir=storage_dir)

    for experiment_dir in all_experiments:

        # For that experiment, get all seed directories

        experiment_seeds = DirectoryTree.get_all_seeds(
            experiment_dir=experiment_dir)

        for i, seed_dir in enumerate(experiment_seeds):

            # Prints which seed directory is being treated

            logger.debug(f"{seed_dir}")

            # Loads training config

            config_dict = load_dict_from_json(str(seed_dir / "config.json"))

            # Keys can be any information stored in config.json
            # We also handle a few special cases (e.g. "experiment_num").
            # Slicing is used instead of str.strip('experiment'), which
            # strips a *character set* from both ends, not a prefix.

            keys = config_dict.copy()
            keys['experiment_num'] = seed_dir.parent.stem[len('experiment'):]
            # NOTE(review): this stores a Path while _compute_seed_scores
            # stores `.name` — confirm whether the Path is intentional
            keys['storage_name'] = seed_dir.parents[1]

            outer_key = keys[graph_key]  # number of graphs to be made
            inner_key = keys[curve_key]  # number of curves per graph

            # Loads training data

            loaded_recorder = Recorder.init_from_pickle_file(
                filename=str(seed_dir / 'recorders' / 'train_recorder.pkl'))

            # Stores the data

            if outer_key not in y_data.keys():
                x_data[outer_key] = OrderedDict()
                y_data[outer_key] = OrderedDict()

            if inner_key not in y_data[outer_key].keys():
                x_data[outer_key][inner_key] = []
                y_data[outer_key][inner_key] = []

            x_data[outer_key][inner_key].append(loaded_recorder.tape[x_metric])
            y_data[outer_key][inner_key].append(
                loaded_recorder.tape[y_metric]
            )  # TODO: make sure that this is a scalar metric, even for eval_return (and not 10 points for every eval_step). All metrics saved in the recorder should be scalars for every time point.

    return x_data, y_data
# Example #7 (score: 0)
def _gather_scores(storage_dirs,
                   save_dir,
                   y_error_bars,
                   logger,
                   normalize_with_first_model=True,
                   sort_bars=False):
    """Loads the pickled seed scores of every storage_dir and computes, for
    each (outer_key, inner_key) pair, a mean score and error bars according
    to `y_error_bars` ("stderr", "10th_quantiles" or "bootstrapped_CI").

    If `normalize_with_first_model` is True, scores are normalized by the
    mean scores of the first storage_dir's outer_key (the reference).

    Returns:
        (scores, scores_means, scores_err_up, scores_err_down,
         sorted_inner_keys, reference_key)
    """
    # Initialize containers

    scores_means = OrderedDict()
    scores_err_up = OrderedDict()
    scores_err_down = OrderedDict()

    # Loads performance benchmark data

    individual_scores = OrderedDict()
    for storage_dir in storage_dirs:
        with open(storage_dir / save_dir / f"{save_dir}_seed_scores.pkl",
                  "rb") as f:
            individual_scores[storage_dir.name] = pickle.load(f)

    # Print keys so that user can verify all these benchmarks make sense to compare (e.g. same tasks)

    for storage_name, idv_score in individual_scores.items():
        logger.debug(storage_name)
        for outer_key in idv_score.keys():
            logger.debug(f"{outer_key}: {list(idv_score[outer_key].keys())}")
        logger.debug(f"\n")

    # Reorganize all individual_scores in a single dictionary.
    # BUGFIX: the inner membership test used to check `scores` instead of
    # `scores[outer_key]`; since the value is unconditionally overwritten
    # anyway, the dead OrderedDict initialization (and an unused
    # extract_info_from_storage_name call) were simply dropped.

    scores = OrderedDict()
    for storage_name, idv_score in individual_scores.items():
        for outer_key in idv_score:
            if outer_key not in scores:
                scores[outer_key] = OrderedDict()
            for inner_key in idv_score[outer_key]:
                scores[outer_key][inner_key] = idv_score[outer_key][inner_key]

    # First storage_dir will serve as reference if normalize_with_first_model is True

    reference_key = list(scores.keys())[0]
    reference_means = OrderedDict()
    for inner_key in scores[reference_key].keys():
        if normalize_with_first_model:
            reference_means[inner_key] = scores[reference_key][inner_key].mean(
            )
        else:
            reference_means[inner_key] = 1.

    # Sorts inner_keys (bars among groups) by decreasing reference mean

    sorted_inner_keys = list(
        reversed(
            sorted(reference_means.keys(),
                   key=lambda item:
                   (scores[reference_key][item].mean(), item))))

    if sort_bars:
        inner_keys = sorted_inner_keys
    else:
        inner_keys = scores[reference_key].keys()

    # Computes means and error bars (the +1e-8 guards against division by
    # zero when a reference mean is exactly zero)

    for inner_key in inner_keys:
        for outer_key in scores.keys():
            if outer_key not in scores_means.keys():
                scores_means[outer_key] = OrderedDict()
                scores_err_up[outer_key] = OrderedDict()
                scores_err_down[outer_key] = OrderedDict()

            if y_error_bars == "stderr":
                scores_means[outer_key][inner_key] = np.mean(
                    scores[outer_key][inner_key] /
                    (reference_means[inner_key] + 1e-8))

                scores_err_down[outer_key][inner_key] = np.std(
                    scores[outer_key][inner_key] /
                    (reference_means[inner_key] + 1e-8)) / len(
                        scores[outer_key][inner_key])**0.5
                scores_err_up[outer_key][inner_key] = scores_err_down[
                    outer_key][inner_key]

            elif y_error_bars == "10th_quantiles":
                scores_means[outer_key][inner_key] = np.mean(
                    scores[outer_key][inner_key] /
                    (reference_means[inner_key] + 1e-8))

                quantile = 0.10
                scores_err_down[outer_key][inner_key] = np.abs(
                    np.quantile(a=scores[outer_key][inner_key] / (reference_means[inner_key] + 1e-8), q=0. + quantile) \
                    - scores_means[outer_key][inner_key])
                scores_err_up[outer_key][inner_key] = np.abs(
                    np.quantile(a=scores[outer_key][inner_key] / (reference_means[inner_key] + 1e-8), q=1. - quantile) \
                    - scores_means[outer_key][inner_key])

            elif y_error_bars == "bootstrapped_CI":
                scores_samples = scores[outer_key][inner_key] / (
                    reference_means[inner_key] + 1e-8)

                mean, err_up, err_down = get_95_confidence_interval(
                    samples=scores_samples, method=y_error_bars)
                scores_means[outer_key][inner_key] = mean
                scores_err_up[outer_key][inner_key] = err_up
                scores_err_down[outer_key][inner_key] = err_down

            else:
                raise NotImplementedError

    return scores, scores_means, scores_err_up, scores_err_down, sorted_inner_keys, reference_key
def prepare_schedule(desc, schedule_file, root_dir, add_to_folder, resample, logger, ask_for_validation):
    """Parses a schedule file (grid or random search) and creates the corresponding
    storage / experiment / seed directory trees on disk.

    :param desc: description string for a NEW storage_dir (mutually exclusive with add_to_folder)
    :param schedule_file: path to a 'grid_schedule_NAME.py' or 'random_schedule_NAME.py' file
    :param root_dir: root directory under which storage_dirs live
    :param add_to_folder: name of an EXISTING storage_dir to append experiments to
    :param resample: if True (random search only), draws fresh hyperparameter samples
                     for every (alg_name, task_name) combination instead of reusing one draw
    :param logger: logger used for summary / progress messages
    :param ask_for_validation: if True, prints a summary and asks for confirmation before creating anything
    :raises ValueError: if the schedule_file does not exist or its name fits neither search type
    """
    # Infers the search_type (grid or random) from provided schedule_file

    schedule_file_path = Path(schedule_file)

    assert schedule_file_path.suffix == '.py', f"The provided --schedule_file should be a python file " \
                                               f"(see: alfred/schedule_examples). You provided " \
                                               f"'--schedule_file={schedule_file}'"

    if "grid_schedule" in schedule_file_path.name:
        search_type = 'grid'
    elif "random_schedule" in schedule_file_path.name:
        search_type = 'random'
    else:
        raise ValueError(f"Provided --schedule_file has the name '{schedule_file_path.name}'. "
                         "Only grid_schedule's and random_schedule's are supported. "
                         "The name of the provided '--schedule_file' must fit one of the following forms: "
                         "'grid_schedule_NAME.py' or 'random_schedule_NAME.py'.")

    if not schedule_file_path.exists():
        raise ValueError(f"Cannot find the provided '--schedule_file': {schedule_file_path}")

    # Gets experiments parameters
    # Converts the file path into a dotted module path for import
    # (raw string: '\.' in a plain literal is an invalid escape sequence)

    schedule_module = re.sub(r'\.py$', '', ".".join(schedule_file.split('/')))

    if search_type == 'grid':

        VARIATIONS, ALG_NAMES, TASK_NAMES, SEEDS, experiments, varied_params, get_run_args, schedule = extract_schedule_grid(schedule_module)

    elif search_type == 'random':

        param_samples, ALG_NAMES, TASK_NAMES, SEEDS, experiments, varied_params, get_run_args, schedule = extract_schedule_random(schedule_module)

    else:
        raise NotImplementedError

    # Creates a list of alg_agent and task_name unique combinations

    if desc is not None:
        assert add_to_folder is None, "If --desc is defined, a new storage_dir folder will be created." \
                                      "No --add_to_folder should be provided."

        desc = f"{search_type}_{desc}"
        agent_task_combinations = list(itertools.product(ALG_NAMES, TASK_NAMES))
        mode = "NEW_STORAGE"

        # BUGFIX: git_hashes was previously computed only inside the
        # ask_for_validation branch below, causing a NameError at
        # create_experiment_dir() whenever validation was skipped.
        git_hashes = DirectoryTree.get_git_hashes()

    elif add_to_folder is not None:
        assert (get_root(root_dir) / add_to_folder).exists(), f"{add_to_folder} does not exist."
        assert desc is None, "If --add_to_folder is defined, new experiments will be added to the existing folder." \
                             "No --desc should be provided."

        storage_name_id, git_hashes, alg_name, task_name, desc = \
            DirectoryTree.extract_info_from_storage_name(add_to_folder)

        agent_task_combinations = list(itertools.product([alg_name], [task_name]))
        mode = "EXISTING_STORAGE"

    else:
        raise NotImplementedError

    # Duplicates or resamples hyperparameters to match the number of agent_task_combinations

    n_combinations = len(agent_task_combinations)

    experiments = [experiments]
    if search_type == 'random':
        param_samples = [param_samples]

    if search_type == 'random' and resample:
        assert not add_to_folder
        for i in range(n_combinations - 1):
            param_sa, _, _, _, expe, varied_pa, get_run_args, _ = extract_schedule_random(schedule_module)
            experiments.append(expe)
            param_samples.append(param_sa)

    else:
        experiments = experiments * n_combinations
        if search_type == 'random':
            param_samples = param_samples * n_combinations

    # Printing summary of schedule_xyz.py

    info_str = f"\n\nPreparing a {search_type.upper()} search over {len(experiments)} experiments, {len(SEEDS)} seeds"
    info_str += f"\nALG_NAMES: {ALG_NAMES}"
    info_str += f"\nTASK_NAMES: {TASK_NAMES}"
    info_str += f"\nSEEDS: {SEEDS}"

    if search_type == "grid":
        info_str += f"\n\nVARIATIONS:"
        for key in VARIATIONS.keys():
            info_str += f"\n\t{key}: {VARIATIONS[key]}"
    else:
        info_str += f"\n\nParams to be varied over: {varied_params}"

    info_str += f"\n\nDefault {config_to_str(get_run_args(overwritten_cmd_line=''))}\n"

    logger.debug(info_str)

    # Asking for user validation

    if ask_for_validation:

        if mode == "NEW_STORAGE":
            string = "\n"
            for alg_name, task_name in agent_task_combinations:
                string += f"\n\tID_{git_hashes}_{alg_name}_{task_name}_{desc}"
            logger.debug(f"\n\nAbout to create {len(agent_task_combinations)} storage directories, "
                         f"each with {len(experiments)} experiments:"
                         f"{string}")

        else:
            # BUGFIX: .iterdir() must be called on the joined Path, not on the
            # add_to_folder string (operator precedence made the old code raise
            # AttributeError).
            n_existing_experiments = len([path for path in (get_root(root_dir) / add_to_folder).iterdir()
                                          if path.name.startswith('experiment')])

            logger.debug(f"\n\nAbout to add {len(experiments)} experiment folders in the following directory"
                         f" (there are currently {n_existing_experiments} in this folder):"
                         f"\n\t{add_to_folder}")

        answer = input("\nShould we proceed? [y or n]")
        if answer.lower() not in ['y', 'yes']:
            logger.debug("Aborting...")
            sys.exit()

    logger.debug("Starting...")

    # For each storage_dir to be created

    all_storage_dirs = []

    for alg_task_i, (alg_name, task_name) in enumerate(agent_task_combinations):

        # Determines storing ID (if new storage_dir)

        if mode == "NEW_STORAGE":
            tmp_dir_tree = DirectoryTree(alg_name=alg_name, task_name=task_name, desc=desc, seed=1, root=root_dir)
            storage_name_id = tmp_dir_tree.storage_dir.name.split('_')[0]

        # For each experiments...

        for param_dict in experiments[alg_task_i]:

            # Creates dictionary pointer-access to a training config object initialized by default

            config = get_run_args(overwritten_cmd_line="")
            config_dict = vars(config)

            # Modifies the config for this particular experiment

            config.alg_name = alg_name
            config.task_name = task_name
            config.desc = desc

            config_unique_dict = {k: v for k, v in param_dict.items() if k in varied_params}
            config_unique_dict['alg_name'] = config.alg_name
            config_unique_dict['task_name'] = config.task_name
            config_unique_dict['seed'] = config.seed

            for param_name in param_dict.keys():
                if param_name not in config_dict.keys():
                    raise ValueError(f"'{param_name}' taken from the schedule is not a valid hyperparameter "
                                     f"i.e. it cannot be found in the Namespace returned by get_run_args().")
                else:
                    config_dict[param_name] = param_dict[param_name]

            # Create the experiment directory

            dir_tree = create_experiment_dir(storage_name_id, config, config_unique_dict, SEEDS, root_dir, git_hashes)

        all_storage_dirs.append(dir_tree.storage_dir)

        # Saves VARIATIONS in the storage directory

        first_experiment_created = int(dir_tree.current_experiment.strip('experiment')) - len(experiments[0]) + 1
        last_experiment_created = first_experiment_created + len(experiments[0]) - 1

        if search_type == 'grid':

            VARIATIONS['alg_name'] = ALG_NAMES
            VARIATIONS['task_name'] = TASK_NAMES
            VARIATIONS['seed'] = SEEDS

            key = f'{first_experiment_created}-{last_experiment_created}'

            if (dir_tree.storage_dir / 'variations.json').exists():
                variations_dict = load_dict_from_json(filename=str(dir_tree.storage_dir / 'variations.json'))
                assert key not in variations_dict.keys()
                variations_dict[key] = VARIATIONS
            else:
                variations_dict = {key: VARIATIONS}

            save_dict_to_json(variations_dict, filename=str(dir_tree.storage_dir / 'variations.json'))
            open(str(dir_tree.storage_dir / 'GRID_SEARCH'), 'w+').close()

        elif search_type == 'random':
            len_samples = len(param_samples[alg_task_i])
            fig_width = 2 * len_samples if len_samples > 0 else 2
            fig, ax = plt.subplots(len(param_samples[alg_task_i]), 1, figsize=(6, fig_width))
            if not hasattr(ax, '__iter__'):
                ax = [ax]

            plot_sampled_hyperparams(ax, param_samples[alg_task_i],
                                     log_params=['lr', 'tau', 'initial_alpha', 'grad_clip_value', 'lamda1', 'lamda2'])

            # Finds the first unused variations{j}.png filename
            j = 1
            while True:
                if (dir_tree.storage_dir / f'variations{j}.png').exists():
                    j += 1
                else:
                    break
            fig.savefig(str(dir_tree.storage_dir / f'variations{j}.png'))
            plt.close(fig)

            open(str(dir_tree.storage_dir / 'RANDOM_SEARCH'), 'w+').close()

        # Printing summary

        logger.info(f'Created directories '
                    f'{str(dir_tree.storage_dir)}/experiment{first_experiment_created}-{last_experiment_created}')

    # Saving the list of created storage_dirs in a text file located with the provided schedule_file

    schedule_name = Path(schedule.__file__).parent.stem
    with open(Path(schedule.__file__).parent / f"list_searches_{schedule_name}.txt", "a+") as f:
        for storage_dir in all_storage_dirs:
            f.write(f"{storage_dir.name}\n")

    logger.info(f"\nEach of these experiments contain directories for the following seeds: {SEEDS}")
def create_plot_arrays(
        from_file,
        storage_name,
        root_dir,
        remove_none,
        logger,
        plots_to_make=alfred.defaults.DEFAULT_PLOTS_ARRAYS_TO_MAKE):
    """
    Creates and saves comparative figures containing a plot of each recorded metric
    for every (experiment, seed) pair in the selected storage directories.

    :param from_file: file listing storage_dirs to process (passed to select_storage_dirs)
    :param storage_name: name of a single storage_dir to process (alternative to from_file)
    :param root_dir: root directory under which storage_dirs live
    :param remove_none: if True, strips None entries from the recorded metric tapes before plotting
    :param logger: logger used for progress / debug messages
    :param plots_to_make: list of (x_metric, y_metric, x_lim, y_lim) tuples indicating
                          which comparative plots to make
    :return: None
    """
    # Select storage_dirs to run over

    storage_dirs = select_storage_dirs(from_file, storage_name, root_dir)

    for storage_dir in storage_dirs:

        # Get all experiment directories and sorts them numerically

        sorted_experiments = DirectoryTree.get_all_experiments(storage_dir)

        all_seeds_dir = []
        for experiment in sorted_experiments:
            all_seeds_dir = all_seeds_dir + DirectoryTree.get_all_seeds(
                experiment)

        # Determines what type of search was done

        if (storage_dir / 'GRID_SEARCH').exists():
            search_type = 'grid'
        elif (storage_dir / 'RANDOM_SEARCH').exists():
            search_type = 'random'
        else:
            search_type = 'unknown'

        # Determines row and columns of subplots

        if search_type == 'grid':
            variations = load_dict_from_json(filename=str(storage_dir /
                                                          'variations.json'))

            # experiment_groups account for the fact that all the experiment_dir in a storage_dir may have been created
            # though several runs of prepare_schedule.py, and therefore, many "groups" of experiments have been created
            experiment_groups = {key: {} for key in variations.keys()}
            for group_key, properties in experiment_groups.items():
                properties['variations'] = variations[group_key]

                properties['variations_lengths'] = {
                    k: len(properties['variations'][k])
                    for k in properties['variations'].keys()
                }

                # Deleting alg_name and task_name from variations (because they will not be contained in same storage_dir)

                hyperparam_variations_lengths = deepcopy(
                    properties['variations_lengths'])
                del hyperparam_variations_lengths['alg_name']
                del hyperparam_variations_lengths['task_name']

                i_max = sorted(hyperparam_variations_lengths.values())[-1]
                j_max = int(
                    np.prod(
                        sorted(hyperparam_variations_lengths.values())[:-1]))

                if i_max < 4 and j_max == 1:
                    # If only one hyperparameter was varied over, we order plots on a line
                    j_max = i_max
                    i_max = 1
                    ax_array_dim = 1

                elif i_max >= 4 and j_max == 1:
                    # ... unless there are 4 or more variations, then we put them in a square-ish fashion
                    j_max = int(np.sqrt(i_max))
                    i_max = int(np.ceil(float(i_max) / float(j_max)))
                    ax_array_dim = 2

                else:
                    ax_array_dim = 2

                properties['ax_array_shape'] = (i_max, j_max)
                properties['ax_array_dim'] = ax_array_dim

        else:
            experiment_groups = {"all": {}}
            for group_key, properties in experiment_groups.items():
                i_max = len(sorted_experiments
                            )  # each experiment is on a different row
                j_max = len(all_seeds_dir
                            ) // i_max  # each seed is on a different column

                if i_max == 1:
                    ax_array_dim = 1
                else:
                    ax_array_dim = 2

                properties['ax_array_shape'] = (i_max, j_max)
                properties['ax_array_dim'] = ax_array_dim

        for group_key, properties in experiment_groups.items():
            logger.debug(
                f"\n===========================\nPLOTS FOR EXPERIMENT GROUP: {group_key}"
            )
            i_max, j_max = properties['ax_array_shape']
            ax_array_dim = properties['ax_array_dim']

            # NOTE: first_exp is a string for real groups (e.g. "12" from "12-15")
            # and the int 0 for the catch-all group, so the != 0 test below
            # distinguishes the two cases rather than comparing numbers
            first_exp = group_key.split('-')[0] if group_key != "all" else 0
            if first_exp != 0:
                for seed_idx, seed_dir in enumerate(all_seeds_dir):
                    if seed_dir.parent.stem.strip('experiment') == first_exp:
                        first_seed_idx = seed_idx
                        break
            else:
                first_seed_idx = 0

            for plot_to_make in plots_to_make:
                x_metric, y_metric, x_lim, y_lim = plot_to_make
                logger.debug(f'\n{y_metric} as a function of {x_metric}:')

                # Creates the subplots

                fig, ax_array = plt.subplots(i_max,
                                             j_max,
                                             figsize=(10 * j_max, 6 * i_max))

                for i in range(i_max):
                    for j in range(j_max):

                        if ax_array_dim == 1 and i_max == 1 and j_max == 1:
                            current_ax = ax_array
                        elif ax_array_dim == 1 and (i_max > 1 or j_max > 1):
                            current_ax = ax_array[j]
                        elif ax_array_dim == 2:
                            current_ax = ax_array[i, j]
                        else:
                            raise Exception(
                                'ax_array should not have more than two dimensions'
                            )

                        try:
                            seed_dir = all_seeds_dir[first_seed_idx +
                                                     (i * j_max + j)]
                            # Raises IndexError (caught below) when the seed_dir
                            # falls outside this group's experiment number range
                            if group_key != 'all' \
                                    and (int(str(seed_dir.parent).split('experiment')[1]) < int(group_key.split('-')[0]) \
                                         or int(str(seed_dir.parent).split('experiment')[1]) > int(
                                        group_key.split('-')[1])):
                                raise IndexError
                            logger.debug(str(seed_dir))
                        except IndexError as e:
                            logger.debug(
                                f'experiment{i * j_max + j} does not exist')
                            current_ax.text(0.2,
                                            0.2,
                                            "no experiment\n found",
                                            transform=current_ax.transAxes,
                                            fontsize=24,
                                            fontweight='bold',
                                            color='red')
                            continue

                        logger.debug(seed_dir)

                        # Writes unique hyperparameters on plot

                        config = load_config_from_json(
                            filename=str(seed_dir / 'config.json'))
                        config_unique_dict = load_dict_from_json(
                            filename=str(seed_dir / 'config_unique.json'))
                        validate_config_unique(config, config_unique_dict)

                        if search_type == 'grid':
                            sorted_keys = sorted(
                                config_unique_dict.keys(),
                                key=lambda item:
                                (properties['variations_lengths'][item], item),
                                reverse=True)

                        else:
                            sorted_keys = config_unique_dict

                        info_str = f'{seed_dir.parent.stem}\n' + '\n'.join([
                            f'{k} = {config_unique_dict[k]}'
                            for k in sorted_keys
                        ])
                        bbox_props = dict(facecolor='gray', alpha=0.1)
                        current_ax.text(0.05,
                                        0.95,
                                        info_str,
                                        transform=current_ax.transAxes,
                                        fontsize=12,
                                        verticalalignment='top',
                                        bbox=bbox_props)

                        # Skip cases of UNHATCHED or CRASHED experiments

                        if (seed_dir / 'UNHATCHED').exists():
                            logger.debug('UNHATCHED')
                            current_ax.text(0.2,
                                            0.2,
                                            "UNHATCHED",
                                            transform=current_ax.transAxes,
                                            fontsize=24,
                                            fontweight='bold',
                                            color='blue')
                            continue

                        if (seed_dir / 'CRASH.txt').exists():
                            logger.debug('CRASHED')
                            current_ax.text(0.2,
                                            0.2,
                                            "CRASHED",
                                            transform=current_ax.transAxes,
                                            fontsize=24,
                                            fontweight='bold',
                                            color='red')
                            continue

                        try:

                            # Loading the recorder

                            loaded_recorder = Recorder.init_from_pickle_file(
                                filename=str(seed_dir / 'recorders' /
                                             'train_recorder.pkl'))

                            # Checking if provided metrics are present in the recorder

                            if y_metric not in loaded_recorder.tape.keys():
                                logger.debug(
                                    f"'{y_metric}' was not recorded in train_recorder."
                                )
                                current_ax.text(0.2,
                                                0.2,
                                                "ABSENT METRIC",
                                                transform=current_ax.transAxes,
                                                fontsize=24,
                                                fontweight='bold',
                                                color='red')
                                continue

                            # x_metric=None means "plot y_metric against its index"
                            # (simplified: the old inner `if x_metric is None: pass`
                            # branch was unreachable under this condition)
                            if x_metric is not None and x_metric not in loaded_recorder.tape.keys():
                                logger.debug(
                                    f"'{x_metric}' was not recorded in train_recorder."
                                )
                                current_ax.text(
                                    0.2,
                                    0.2,
                                    "ABSENT METRIC",
                                    transform=current_ax.transAxes,
                                    fontsize=24,
                                    fontweight='bold',
                                    color='red')
                                continue

                            # Removing None entries

                            if remove_none:
                                # BUGFIX: only touch the x tape when an x_metric
                                # was requested (tape[None] raised KeyError)
                                if x_metric is not None:
                                    loaded_recorder.tape[x_metric] = remove_nones(
                                        loaded_recorder.tape[x_metric])
                                loaded_recorder.tape[y_metric] = remove_nones(
                                    loaded_recorder.tape[y_metric])

                            # Plotting

                            try:

                                if x_metric is not None:
                                    plot_curves(
                                        current_ax,
                                        ys=[loaded_recorder.tape[y_metric]],
                                        xs=[loaded_recorder.tape[x_metric]],
                                        xlim=x_lim,
                                        ylim=y_lim,
                                        xlabel=x_metric,
                                        title=y_metric)
                                else:
                                    plot_curves(
                                        current_ax,
                                        ys=[loaded_recorder.tape[y_metric]],
                                        xlim=x_lim,
                                        ylim=y_lim,
                                        title=y_metric)

                            except Exception as e:
                                logger.debug(f'Plotting error: {e}')

                        except FileNotFoundError:
                            logger.debug('Training recorder not found')
                            current_ax.text(0.2,
                                            0.2,
                                            "'train_recorder'\nnot found",
                                            transform=current_ax.transAxes,
                                            fontsize=24,
                                            fontweight='bold',
                                            color='red')
                            continue

                plt.tight_layout()
                fig.savefig(
                    str(storage_dir /
                        f'{group_key}_comparative_{y_metric}_over_{x_metric}.png'
                        ))
                plt.close(fig)
Exemple #10
0
def evaluate(args):
    """Rolls out a saved model for several episodes and returns its mean return.

    Loads the config and best model checkpoint from the seed directory identified
    by args (root_dir, storage_name, experiment_num, seed_num), then runs
    args.number_of_eps episodes, optionally rendering, saving a gif, and/or
    collecting expert demonstration trajectories.

    :param args: namespace with at least root_dir, storage_name, experiment_num,
                 seed_num, model_name, eval_seed, number_of_eps, fps, render,
                 waiting, make_gif, make_expert, act_deterministic,
                 n_skipped_frames, max_ep_len, demos_folder, expert_save_path
    :return: mean episode return over the evaluation episodes
    """
    # loads config and model

    dir_tree = DirectoryTree.init_from_branching_info(
        root_dir=args.root_dir,
        storage_name=args.storage_name,
        experiment_num=args.experiment_num,
        seed_num=args.seed_num)

    config = load_config_from_json(dir_tree.seed_dir / "config.json")

    if args.model_name is not None:
        model_path = dir_tree.seed_dir / args.model_name
    else:
        # Falls back through rl_alg_name -> irl_alg_name -> alg_name to pick
        # the checkpoint prefix, for configs saved by different training setups
        if 'rl_alg_name' in config.__dict__.keys():
            if config.rl_alg_name == "":
                model_name = config.irl_alg_name
            else:
                model_name = config.rl_alg_name
        else:
            model_name = config.alg_name
        model_path = dir_tree.seed_dir / (model_name + '_model_best.pt')
    learner = init_from_save(model_path, device=torch.device('cpu'))

    if args.make_gif:
        gif_path = dir_tree.storage_dir / 'gifs'
        gif_path.mkdir(exist_ok=True)
        gif_full_name = uniquify(gif_path / f"{config.task_name}"
                                 f"_experiment{args.experiment_num}"
                                 f"_seed{args.seed_num}"
                                 f"_evalseed{args.eval_seed}.gif")

        if config.task_name in POMMERMAN_TASKS:
            temp_png_folder_base = uniquify(gif_path / 'temp_png')
        else:
            temp_png_folder = False

    # Makes task_name and recorders

    env = make_env(config.task_name)
    ml.set_seeds(args.eval_seed, env)
    Ti = TrainingIterator(args.number_of_eps)
    frames = []
    dt = 1. / args.fps
    trajectories = []

    # camera angles and stuff

    # BUGFIX: config is a Namespace (attribute access), not a dict --
    # config['task_name'] raised TypeError
    if config.task_name in MUJOCO_TASKS:
        env.render(mode='human' if args.render else 'rgb_array')
        env.unwrapped.viewer.cam.type = const.CAMERA_TRACKING

        # # Option 1 (FROM THE SIDE)
        # env.unwrapped.viewer.cam.trackbodyid = 0
        # env.unwrapped.viewer.cam.elevation = -25
        # env.unwrapped.viewer.cam.distance = 6

        # Option 2 (FROM PERSPECTIVE)
        env.unwrapped.viewer.cam.trackbodyid = 0
        env.unwrapped.viewer.cam.elevation = -15
        env.unwrapped.viewer.cam.distance = 4
        env.unwrapped.viewer.cam.azimuth = 35

    # Get expert demonstrations initial states

    if config.task_name in POMMERMAN_TASKS:

        if args.demos_folder is None:
            args.demos_folder = config.task_name.replace(
                'learnable', 'agent47')

        # BUGFIX: use args.demos_folder (just defaulted above); the old code
        # read config.demos_folder, silently ignoring the --demos_folder flag
        # NOTE(review): assumes config has no authoritative demos_folder of its
        # own -- confirm against the training config schema
        demos = load_expert_demos(args.demos_folder, config.demos_name)
        env.init_game_states = load_game_states_from_demos(demos, idx=0)

    # Episodes loop

    for it in Ti:
        t = 0
        trajectory = []
        ret = 0
        done = False

        # Initial reset

        obs = env.reset()

        # Rendering options

        if args.make_gif:
            if config.task_name in POMMERMAN_TASKS:  # pommerman saves .png per episode
                temp_png_folder = temp_png_folder_base / f"ep_{it.itr}"
                temp_png_folder.mkdir(parents=True, exist_ok=True)
            record_frame(env, t, args.n_skipped_frames, config.task_name,
                         frames, temp_png_folder)

        if args.render:
            env.render()

        if args.waiting:
            wait_for_ENTER_keypress()

        # transitions loop

        while not done:
            calc_start = time.time()
            action = learner.act(obs=obs, sample=not args.act_deterministic)
            next_obs, reward, done, _ = env.step(action)

            if args.make_expert:
                trajectory.append(
                    (obs, action, next_obs, reward, ml.mask(done)))

            obs = next_obs
            ret += reward
            t += 1

            if args.render:
                # Enforces the fps config
                calc_end = time.time()
                elapsed = calc_end - calc_start
                if elapsed < dt:
                    time.sleep(dt - elapsed)
                env.render('human')

            if args.waiting:
                wait_for_ENTER_keypress()

            if args.make_gif:
                # we want the last frame even if we skip some frames
                record_frame(env, t * (1 - done), args.n_skipped_frames,
                             config.task_name, frames, temp_png_folder)
            if t > args.max_ep_len:
                break
        it.record('eval_return', ret)
        if args.make_expert:
            trajectories.append(trajectory)

    # Saves gif of all the episodes

    if args.make_gif:
        if config.task_name in POMMERMAN_TASKS:
            save_gif_from_png_folder(temp_png_folder_base,
                                     gif_full_name,
                                     1 / dt,
                                     delete_folder=True)
        else:
            imageio.mimsave(str(gif_full_name), frames, duration=dt)
    env.close()

    # Saves expert_trajectories

    if args.make_expert:
        if args.expert_save_path is not None:
            expert_path = Path(args.expert_save_path)
        else:
            expert_path = Path('./data/' + config.task_name +
                               f'/expert_demo_{args.number_of_eps}.pkl')

        expert_path.parent.mkdir(exist_ok=True, parents=True)
        expert_path = uniquify(expert_path)
        # `with` closes the file; the old explicit fp.close() was redundant
        with open(str(expert_path), 'wb') as fp:
            pickle.dump(trajectories, fp)
    return Ti.pop_all_means()['eval_return']
Exemple #11
0
def create_retrain_best(from_file, storage_name, best_experiments_mapping,
                        n_retrain_seeds, train_time_factor, root_dir):
    """Creates 'retrainBest' storage directories for a set of search storages.

    For each selected search storage_dir, locates the best config (either from
    the storage's summary/ folder or from an explicit mapping file), extends its
    training budget by `train_time_factor`, renames its description, and creates
    a fresh experiment directory with `n_retrain_seeds` seeds.

    Args:
        from_file: path to a text file listing storage_dirs (or None).
        storage_name: explicit storage name to select (used when from_file is None).
        best_experiments_mapping: path to a json mapping storage_dir name ->
            best experiment number, or None to use the bestConfig summary file.
        n_retrain_seeds: number of seeds to create for each retrain experiment.
        train_time_factor: multiplier applied to max_episodes / max_steps.
        root_dir: root directory under which storages live.

    Returns:
        List of retrainBest storage_dir paths (existing and newly created).
    """
    logger = create_logger(name="CREATE_RETRAIN", loglevel=logging.INFO)
    logger.info("\nCREATING retrainBest directories")

    # Select storage_dirs to run over

    storage_dirs = select_storage_dirs(from_file, storage_name, root_dir)

    # Sanity-check that storages exist

    storage_dirs = [
        storage_dir for storage_dir in storage_dirs
        if sanity_check_exists(storage_dir, logger)
    ]

    # Imports schedule file to have same settings for DirectoryTree.git_repos_to_track
    # (imported for its side effects; the module object itself is unused)

    if from_file:
        schedule_file = str([
            path for path in Path(from_file).parent.iterdir()
            if 'schedule' in path.name and path.name.endswith('.py')
        ][0])
        schedule_module = ".".join(schedule_file.split('/'))
        # BUGFIX: str.strip('.py') removes the *characters* '.', 'p', 'y' from
        # both ends and corrupts names like 'py_schedule'; remove the literal
        # '.py' suffix instead.
        if schedule_module.endswith('.py'):
            schedule_module = schedule_module[:-len('.py')]
        schedule = import_module(schedule_module)

    # Creates retrainBest directories

    retrainBest_storage_dirs = []
    new_retrainBest_storage_dirs = []
    for storage_dir in storage_dirs:

        try:
            # Checks if a retrainBest directory already exists for this search

            search_storage_id = storage_dir.name.split('_')[0]
            corresponding_retrain_directories = [
                path for path in get_root(root_dir).iterdir()
                if f"retrainBest{search_storage_id}" in path.name.split('_')
            ]

            if len(corresponding_retrain_directories) > 0:
                assert len(corresponding_retrain_directories) == 1
                retrainBest_dir = corresponding_retrain_directories[0]

                logger.info(f"Existing retrainBest\n\n"
                            f"\t{storage_dir.name} -> {retrainBest_dir.name}")

                retrainBest_storage_dirs.append(retrainBest_dir)
                continue

            else:

                # The retrainBest directory will contain one experiment with bestConfig from the search...

                if best_experiments_mapping is None:

                    # ... bestConfig is found in the summary/ folder from the search

                    # BUGFIX: the original asserted len()/type() on the already
                    # indexed Path (raising TypeError that was swallowed by the
                    # broad except below); validate the candidate list first.
                    best_config_candidates = [
                        path for path in (storage_dir / "summary").iterdir()
                        if path.name.startswith("bestConfig")
                    ]
                    assert len(best_config_candidates) == 1
                    best_config = best_config_candidates[0]

                else:

                    # ... bestConfig is loaded based on specified --best_experiment_mapping

                    best_experiments_mapping_dict = load_dict_from_json(
                        best_experiments_mapping)
                    assert storage_dir.name in best_experiments_mapping_dict.keys(
                    )

                    best_experiment_num = best_experiments_mapping_dict[
                        storage_dir.name]
                    seed_dir = DirectoryTree.get_all_seeds(
                        experiment_dir=storage_dir /
                        f"experiment{best_experiment_num}")[0]
                    best_config = seed_dir / "config.json"

                config_dict = load_dict_from_json(filename=str(best_config))

                # Retrain experiments run for train_time_factor times as long

                if config_dict['max_episodes'] is not None:
                    config_dict['max_episodes'] = int(
                        config_dict['max_episodes'] * train_time_factor)
                elif config_dict['max_steps'] is not None:
                    config_dict['max_steps'] = int(config_dict['max_steps'] *
                                                   train_time_factor)
                else:
                    raise ValueError(
                        "At least one of max_episodes or max_steps should be defined"
                    )

                # Updates the description so the retrain run is identifiable

                if "random" in config_dict['desc'] or "grid" in config_dict[
                        'desc']:
                    new_desc = config_dict['desc'] \
                        .replace("random", f"retrainBest{search_storage_id}") \
                        .replace("grid", f"retrainBest{search_storage_id}")
                else:
                    new_desc = config_dict[
                        'desc'] + f"_retrainBest{search_storage_id}"

                config_dict['desc'] = new_desc

                # Creates config Namespace with loaded config_dict
                # (empty parse_args("") yields a bare Namespace to populate)

                config = argparse.ArgumentParser().parse_args("")
                config_pointer = vars(config)
                config_pointer.update(config_dict)  # updates config

                config_unique_dict = {}
                config_unique_dict['alg_name'] = config.alg_name
                config_unique_dict['task_name'] = config.task_name
                config_unique_dict['seed'] = config.seed

                # Gets new storage_name_id (throwaway tree just to read the id)

                tmp_dir_tree = DirectoryTree(alg_name="",
                                             task_name="",
                                             desc="",
                                             seed=1,
                                             root=root_dir)
                retrain_storage_id = tmp_dir_tree.storage_dir.name.split(
                    '_')[0]

                # Creates the new storage_dir for retrainBest

                dir_tree = create_experiment_dir(
                    storage_name_id=retrain_storage_id,
                    config=config,
                    config_unique_dict=config_unique_dict,
                    SEEDS=[i * 10 for i in range(n_retrain_seeds)],
                    root_dir=root_dir,
                    git_hashes=DirectoryTree.get_git_hashes())

                retrainBest_storage_dirs.append(dir_tree.storage_dir)
                new_retrainBest_storage_dirs.append(dir_tree.storage_dir)

                logger.info(
                    f"New retrainBest:\n\n"
                    f"\t{storage_dir.name} -> {dir_tree.storage_dir.name}")

        except Exception as e:
            # Best-effort: a failure on one storage_dir must not abort the rest
            logger.info(
                f"Could not create retrainBest-storage_dir {storage_dir}")
            logger.info(f"\n\n{e}\n{traceback.format_exc()}")

    # Saving the list of created storage_dirs in a text file located with the provided schedule_file
    # BUGFIX: only possible when from_file was given; Path(None) raises TypeError
    # in storage_name mode.

    if from_file:
        schedule_name = Path(from_file).parent.stem
        with open(
                Path(from_file).parent / f"list_retrains_{schedule_name}.txt",
                "a+") as f:
            for storage_dir in new_retrainBest_storage_dirs:
                f.write(f"{storage_dir.name}\n")

    return retrainBest_storage_dirs