Exemplo n.º 1
0
def init_population_dir(tmp_path, request):
    """Create an on-disk population of mock CNN model records.

    Writes the config template, one saved model checkpoint per record, and
    the saved population under ``tmp_path``.

    Args:
        tmp_path: Base directory under which ``init_population_dir`` and
            ``init_ckpt_path`` are created.
        request: Object whose optional ``param`` attribute is a dict that may
            hold ``search_space_cfg`` (keyword arguments for the search
            space) and ``num_records`` (number of mock records, default 3).

    Returns:
        A ``(path, search_space)`` tuple: the directory the population was
        saved to and the search space used to sample the rollouts.
    """
    import torch
    from aw_nas.common import get_search_space
    from aw_nas import utils
    from aw_nas.main import _init_component

    # Work on a shallow copy: popping from `request.param` directly would
    # mutate a dict that may be shared between several tests.
    fixture_cfg = dict(getattr(request, "param", {}))
    scfg = fixture_cfg.pop("search_space_cfg", {})
    search_space = get_search_space(cls="cnn", **scfg)
    path = utils.makedir(os.path.join(tmp_path, "init_population_dir"))
    ckpt_dir = utils.makedir(os.path.join(tmp_path, "init_ckpt_path"))

    # dump config template
    with open(os.path.join(path, "template.yaml"), "w") as wf:
        wf.write(sample_config)

    # generate mock records, ckpts
    num_records = fixture_cfg.get("num_records", 3)
    # `yaml.load` without an explicit Loader is deprecated (PyYAML >= 5.1) and
    # an error in PyYAML >= 6; `safe_load` works across versions.
    cfg_template = ConfigTemplate(yaml.safe_load(sample_config))
    model_records = collections.OrderedDict()
    device = torch.device("cpu")
    for ind in range(num_records):
        rollout = search_space.random_sample()
        # Named `model_cfg` so it does not shadow the fixture configuration.
        model_cfg = cfg_template.create_cfg(rollout.genotype)
        ckpt_path = os.path.join(ckpt_dir, str(ind))
        cnn_model = _init_component(model_cfg,
                                    "final_model",
                                    search_space=search_space,
                                    device=device)
        torch.save(cnn_model, ckpt_path)
        model_records[ind] = ModelRecord(rollout.genotype,
                                         model_cfg,
                                         search_space,
                                         checkpoint_path=ckpt_path,
                                         finished=True,
                                         confidence=1,
                                         perfs={
                                             "acc": np.random.rand(),
                                             "loss": np.random.uniform(0, 10)
                                         })
    # initialize population
    population = Population(search_space, model_records, cfg_template)
    # save population
    population.save(path, 0)

    # ugly: return ss for reference
    return (path, search_space)
Exemplo n.º 2
0
def test_population_init(init_population_dir):
    """Load a population from a directory and check its size and membership.

    ``init_population_dir`` is expected to unpack into the population
    directory and the search space it was built with.
    """
    import glob
    from aw_nas.common import rollout_from_genotype_str
    from aw_nas.rollout.mutation import Population

    init_dir, search_space = init_population_dir
    population = Population.init_from_dirs([init_dir], search_space)

    # One of the yaml files in the directory is not a model record
    # (presumably the dumped config template), hence the "- 1".
    yaml_files = glob.glob(os.path.join(init_dir, "*.yaml"))
    expected_size = len(yaml_files) - 1
    assert population.size == expected_size

    # test `population.contain` judgement: a rollout rebuilt from the string
    # form of a stored genotype must be recognised as part of the population
    genotype_str = str(population.get_model(0).genotype)
    rollout = rollout_from_genotype_str(genotype_str, search_space)
    assert str(rollout.genotype) == str(population.get_model(0).genotype)
    assert population.contain_rollout(rollout)
Exemplo n.º 3
0
def population(request):
    """Build a population from the optional configuration in ``request.param``.

    Recognised keys of ``request.param`` (all optional):
        init_dirs: Directories to load the population from; when truthy the
            population is created with ``Population.init_from_dirs``.
        search_space_cfg: Keyword arguments for ``get_search_space``.
        search_space_type: Search space type, defaults to ``"cnn"``.
        cfg_template: Config template handed to ``StubPopulation``, defaults
            to ``sample_config``.
        num_records: Number of records for ``StubPopulation``, defaults to 3.

    Returns:
        The loaded ``Population`` if ``init_dirs`` is given, otherwise a
        ``StubPopulation``.
    """
    from aw_nas.common import get_search_space

    # Work on a shallow copy: popping from `request.param` directly would
    # mutate a dict that may be shared between several tests, silently
    # dropping these settings for every use after the first.
    cfg = dict(getattr(request, "param", {}))
    init_dirs = cfg.get("init_dirs", None)
    scfg = cfg.pop("search_space_cfg", {})
    s_type = cfg.pop("search_space_type", "cnn")
    cfg_template = cfg.pop("cfg_template", sample_config)

    search_space = get_search_space(s_type, **scfg)
    if init_dirs:
        return Population.init_from_dirs(init_dirs, search_space)
    return StubPopulation(search_space,
                          num_records=cfg.get("num_records", 3),
                          config_template=cfg_template)
Exemplo n.º 4
0
    def __init__(
            self,
            search_space,
            device,
            rollout_type="mutation",
            mode="eval",
            score_func="choose('acc')",
            population_dirs=[],
            result_population_dir=None,
            num_mutations_per_child=1,
            # choose parent
            parent_pool_size=25,
            mutation_sampler_type="random",
            mutation_sampler_cfg=None):
        """Set up the controller: load the population and build the sampler.

        Args:
            search_space: Forwarded to the base-class constructor.
            device: Stored on the instance and passed to the mutation sampler.
            rollout_type: Forwarded to the base-class constructor.
            mode: Forwarded to the base-class constructor.
            score_func: Score function specification, resolved through
                ``self._get_score_func``.
            population_dirs: Directories the initial population is loaded
                from; must not be empty.
            result_population_dir: Stored on the instance; must be given.
            num_mutations_per_child: Stored on the instance.
            parent_pool_size: Stored on the instance.
            mutation_sampler_type: Name used to look up the mutation sampler
                class via ``BaseMutationSampler.get_class_``.
            mutation_sampler_cfg: Optional extra keyword arguments for the
                mutation sampler constructor.

        Both configuration checks go through ``expect`` with
        ``ConfigException``.
        """
        super(PopulationController, self).__init__(
            search_space, rollout_type, mode)

        self.device = device

        # Validate the required configuration before touching the disk.
        # NOTE: the mutable `[]` default of `population_dirs` is never
        # modified in this method.
        expect(population_dirs,
               "Config `population_dirs` should not be empty.",
               ConfigException)
        expect(result_population_dir,
               "Config `result_population_dir` must be given.",
               ConfigException)

        self.result_population_dir = result_population_dir
        self.num_mutations_per_child = num_mutations_per_child
        self.parent_pool_size = parent_pool_size

        # Load the population, then derive indexes and scores from it.
        self.population = Population.init_from_dirs(
            population_dirs, self.search_space)
        self.score_func = self._get_score_func(score_func)
        indexes_and_scores = self._init_indexes_and_scores(
            self.population, self.score_func)
        self.indexes, self.scores = indexes_and_scores

        # Instantiate the configured mutation sampler.
        sampler_cls = BaseMutationSampler.get_class_(mutation_sampler_type)
        sampler_kwargs = mutation_sampler_cfg or {}
        self.mutation_sampler = sampler_cls(
            self.search_space, self.population, self.device, **sampler_kwargs)