Example #1
0
File: NS.py Project: salehiac/BR-NS
def _build_problem(problem_cfg):
    """Instantiate the problem selected by ``problem_cfg["name"]``.

    Args:
        problem_cfg: the ``"problem"`` section of the yaml config. Must hold
            ``name``, ``max_steps``, ``bd_type`` and ``assets``.

    Returns:
        A ``HardMaze.HardMaze`` or ``LargeAntMaze.LargeAntMaze`` instance.

    Raises:
        NotImplementedError: if ``name`` is not a supported problem.
    """
    name = problem_cfg["name"]

    if name == "hardmaze":
        max_steps = problem_cfg["max_steps"]
        bd_type = problem_cfg["bd_type"]
        assets = problem_cfg["assets"]
        # Imported here so the module is only needed when this problem is
        # actually selected.
        import HardMaze
        return HardMaze.HardMaze(bd_type=bd_type,
                                 max_steps=max_steps,
                                 assets=assets)

    if name in ("large_ant_maze", "huge_ant_maze"):
        max_steps = problem_cfg["max_steps"]
        bd_type = problem_cfg["bd_type"]
        assets = problem_cfg["assets"]
        pb_type = "huge" if name == "huge_ant_maze" else "large"
        import LargeAntMaze
        return LargeAntMaze.LargeAntMaze(pb_type=pb_type,
                                         bd_type=bd_type,
                                         max_steps=max_steps,
                                         assets=assets)

    raise NotImplementedError("Problem type")


def _build_novelty_estimator(config, problem):
    """Create the novelty estimator and its archive (if it uses one).

    Args:
        config: the full parsed yaml config.
        problem: the problem instance; only queried for the ``"learned"``
            estimator.

    Returns:
        A ``(nov_estimator, arch)`` tuple. ``arch`` is ``None`` for the
        ``"learned"`` estimator.

    Raises:
        NotImplementedError: if ``config["novelty_estimator"]["type"]`` is
            neither ``"archive_based"`` nor ``"learned"``. Without this check
            the names would stay unbound and the failure would only surface
            later as an ``UnboundLocalError``.
        KeyError: if ``config["archive"]["type"]`` is not a known archive.
    """
    estimator_type = config["novelty_estimator"]["type"]

    if estimator_type == "archive_based":
        nov_estimator = NoveltyEstimators.ArchiveBasedNoveltyEstimator(
            k=config["hyperparams"]["k"])
        arch_types = {"list_based": Archives.ListArchive}
        archive_cfg = config["archive"]
        arch = arch_types[archive_cfg["type"]](
            max_size=archive_cfg["max_size"],
            growth_rate=archive_cfg["growth_rate"],
            growth_strategy=archive_cfg["growth_strategy"],
            removal_strategy=archive_cfg["removal_strategy"])
        return nov_estimator, arch

    if estimator_type == "learned":
        bd_dims = problem.get_bd_dims()
        nov_estimator = NoveltyEstimators.LearnedNovelty1d(
            in_dim=bd_dims,
            emb_dim=2 * bd_dims,
            pb_limits=problem.get_behavior_space_boundaries())
        return nov_estimator, None

    raise NotImplementedError("novelty estimator type")


def _build_selector(config):
    """Create the selection operator named by ``config["selector"]["type"]``.

    Raises:
        Exception: for ``"roulette_with_thresh"``, which is deliberately
            disabled (see the message below).
        NotImplementedError: for any other unknown selector type.
    """
    selector_type = config["selector"]["type"]

    if selector_type == "elitist_with_thresh":
        return functools.partial(
            MiscUtils.selBest, k=config["hyperparams"]["population_size"])

    if selector_type == "roulette_with_thresh":
        roulette_msg = "Usage currently not supported: it ends up chosing the same element many times, this duplicates agent._ids etc"
        roulette_msg += " fixing this bug is not a priority since selBest with thresholding actually works well"
        raise Exception(roulette_msg)

    if selector_type == "nsga2_with_thresh":
        return MiscUtils.NSGA2(k=config["hyperparams"]["population_size"])

    if selector_type == "elitist":
        return functools.partial(MiscUtils.selBest,
                                 k=config["hyperparams"]["population_size"],
                                 automatic_threshold=False)

    raise NotImplementedError("selector")


def _build_agent_factory(config, problem):
    """Return a zero-argument callable that creates a fresh agent.

    Args:
        config: the full parsed yaml config.
        problem: the problem instance; supplies ``dim_obs``/``dim_act`` and,
            for ``"agent1d"``, ``env.phi_vals``.

    Raises:
        NotImplementedError: if ``config["population"]["individual_type"]``
            is neither ``"simple_fw_fc"`` nor ``"agent1d"``. Without this
            check the factory would stay unbound and the failure would only
            surface later as an ``UnboundLocalError``.
    """
    in_dims = problem.dim_obs
    out_dims = problem.dim_act
    individual_type = config["population"]["individual_type"]

    if individual_type == "simple_fw_fc":
        # NOTE(review): only "large_ant_maze" gets the tanh/4-layer network;
        # "huge_ant_maze" uses the defaults. Confirm this is intended.
        if config["problem"]["name"] == "large_ant_maze":
            normalise_output_with, num_hidden, hidden_dim = "tanh", 4, 10
        else:
            normalise_output_with, num_hidden, hidden_dim = "", 3, 10

        def make_ag():
            return Agents.SmallFC_FW(
                in_d=in_dims,
                out_d=out_dims,
                num_hidden=num_hidden,
                hidden_dim=hidden_dim,
                output_normalisation=normalise_output_with)

        return make_ag

    if individual_type == "agent1d":

        def make_ag():
            return Agents.Agent1d(min(problem.env.phi_vals),
                                  max(problem.env.phi_vals))

        return make_ag

    raise NotImplementedError("individual type")


def _build_mutator(config, make_ag):
    """Create the mutation operator named by ``config["mutator"]["type"]``.

    Args:
        config: the full parsed yaml config.
        make_ag: agent factory; called to obtain the genotype length and, for
            ``"poly_same"`` with ``"agent1d"`` individuals, the value bounds.

    Raises:
        NotImplementedError: if the mutator type is unknown.
    """
    mutator_type = config["mutator"]["type"]
    # Computed before dispatching, exactly as before, so the number of
    # agents created through make_ag() is unchanged.
    genotype_len = make_ag().get_genotype_len()

    if mutator_type == "gaussian_same":
        params = config["mutator"]["gaussian_params"]
        mu, sigma, indpb = params["mu"], params["sigma"], params["indpb"]
        # The same mu/sigma is applied to every gene.
        return functools.partial(deap_tools.mutGaussian,
                                 mu=[mu] * genotype_len,
                                 sigma=[sigma] * genotype_len,
                                 indpb=indpb)

    if mutator_type == "poly_same":
        params = config["mutator"]["poly_params"]
        eta, low, up, indpb = (params["eta"], params["low"], params["up"],
                               params["indpb"])

        if config["population"]["individual_type"] == "agent1d":
            # For agent1d the bounds come from the agent itself and override
            # the configured low/up.
            dummy_ag = make_ag()
            low = dummy_ag.min_val
            up = dummy_ag.max_val

        return functools.partial(deap_tools.mutPolynomialBounded,
                                 eta=eta,
                                 low=low,
                                 up=up,
                                 indpb=indpb)

    raise NotImplementedError("mutation type")


def main():
    """Run Novelty Search as described by a yaml config file.

    Parses ``--config`` from the command line, loads the yaml file, builds
    the problem, novelty estimator/archive, selector, agent factory and
    mutator it describes, creates a ``NoveltySearch`` instance and runs it
    for ``config["hyperparams"]["num_generations"]`` iterations. The config
    file is copied to ``<ns.log_dir_path>/config.yaml`` with ``cp`` (through
    ``MiscUtils.bash_command``) before the run starts.

    Raises:
        Exception: if no config file is given, or if the disabled
            ``"roulette_with_thresh"`` selector is requested.
        NotImplementedError: if the config names an unsupported problem,
            novelty estimator, selector, individual type or mutator.
    """
    parser = argparse.ArgumentParser(description='Novelty Search.')
    parser.add_argument('--config',
                        type=str,
                        help="yaml config file for ns",
                        default="")

    args = parser.parse_args()

    if not args.config:
        raise Exception("You need to provide a yaml config file")

    with open(args.config, "r") as fl:
        # NOTE(review): FullLoader is not meant for untrusted input; if
        # configs can come from outside, consider yaml.safe_load (only valid
        # if no config relies on python-specific tags).
        config = yaml.load(fl, Loader=yaml.FullLoader)

    problem = _build_problem(config["problem"])
    nov_estimator, arch = _build_novelty_estimator(config, problem)
    selector = _build_selector(config)
    num_pop = config["hyperparams"]["population_size"]
    make_ag = _build_agent_factory(config, problem)
    mutator = _build_mutator(config, make_ag)

    map_t = "scoop" if config["use_scoop"] else "std"
    visualise_bds = config["visualise_bds"]
    ns = NoveltySearch(
        archive=arch,
        nov_estimator=nov_estimator,
        mutator=mutator,
        problem=problem,
        selector=selector,
        n_pop=num_pop,
        n_offspring=config["hyperparams"]["offspring_size"],
        agent_factory=make_ag,
        visualise_bds_flag=visualise_bds,
        map_type=map_t,
        logs_root=config["ns_log_root"],
        compute_parent_child_stats=config["compute_parent_child_stats"])

    # Keep a copy of the exact config next to the logs of this run.
    MiscUtils.bash_command(
        ["cp", args.config, ns.log_dir_path + "/config.yaml"])

    stop_on_reaching_task = config["stop_when_task_solved"]
    nov_estimator.log_dir = ns.log_dir_path
    ns.disable_tqdm = config["disable_tqdm"]
    ns.save_archive_to_file = config["archive"]["save_to_file"]
    if ns.disable_tqdm:
        print(
            colored("[NS info] tqdm is disabled.",
                    "magenta",
                    attrs=["bold"]))

    final_pop, solutions = ns(
        iters=config["hyperparams"]["num_generations"],
        stop_on_reaching_task=stop_on_reaching_task,
        save_checkpoints=config["save_checkpoints"])