def init_from_dirs(cls, dirs, search_space=None, cfg_template_file=None):
    """
    Init population from directories.

    Args:
      dirs: [directory paths]
      search_space: SearchSpace. If None, it is constructed from the
          "search_space_type" and "search_space_cfg" entries of the
          template config.
      cfg_template_file: if not specified, default: "template.yaml" under `dirs[0]`

    Returns: Population

    Raises:
      AssertionError: if `dirs` is empty.
      ValueError: if a non-template yaml file has a name whose stem is not
          an integer.
      A duplicate index across the directories is reported through `expect`.

    There should be multiple meta-info (yaml) files named as "<number>.yaml"
    under each directory, each of them specifies the meta information for a
    model in the population, with `<number>` representing its index.
    Note there should not be duplicate indexes; if there are duplicate
    indexes, rename or soft-link the files.

    In each meta-info file, the possible meta informations are:
    * genotype
    * train_config
    * checkpoint_path
    * (optional) confidence
    * perfs: a dict of performance name to performance value

    "template.yaml" under the first dir will be used as the template
    training config for new candidate model.
    """
    assert dirs, "No dirs specified!"
    logger = _logger.getChild("population")

    if cfg_template_file is None:
        cfg_template_file = os.path.join(dirs[0], "template.yaml")
    with open(cfg_template_file, "r") as cfg_f:
        cfg_template = ConfigTemplate(yaml.safe_load(cfg_f))
    logger.info("Read the template config from %s", cfg_template_file)

    model_records = collections.OrderedDict()
    if search_space is None:
        # assume can parse search space from config template
        from aw_nas.common import get_search_space
        search_space = get_search_space(cfg_template["search_space_type"],
                                        **cfg_template["search_space_cfg"])

    for dir_ in dirs:
        for fname in glob.glob(os.path.join(dir_, "*.yaml")):
            base_name = os.path.basename(fname)
            # Match on the file name only: testing the whole path would skip
            # every record of a directory whose own path happens to contain
            # "template.yaml".
            if "template.yaml" in base_name:
                # do not parse template.yaml
                continue
            index = int(base_name.rsplit(".", 1)[0])
            expect(
                index not in model_records,
                "There are duplicate index: {}. rename or soft-link the files"
                .format(index))
            model_records[index] = ModelRecord.init_from_file(fname, search_space)

    logger.info("Parsed %d directories, total %d model records loaded.",
                len(dirs), len(model_records))
    return Population(search_space, model_records, cfg_template)
def random_sample(cls, population, parent_index, num_mutations=1, primitive_prob=0.5):
    """
    Randomly draw a MutationRollout by mutating the parent model's architecture.

    Each of the `num_mutations` mutations picks a random
    (cell, step, connection) position and, with probability `primitive_prob`,
    changes the primitive on that connection; otherwise it changes the input
    node of that connection. For every position, the original value and all
    values already drawn for it are excluded from later draws, so repeated
    mutations never produce a duplicate.

    Args:
      population: population holding the parent model and the search space
      parent_index: index of the parent model in `population`
      num_mutations: number of mutations to draw
      primitive_prob: probability that a mutation is a primitive mutation

    Returns: `cls(population, parent_index, mutations, search_space)`
    """
    search_space = population.search_space
    parent_genotype = population.get_model(parent_index).genotype
    arch = search_space.rollout_from_genotype(parent_genotype).arch

    primitive_exhausted = ("There are no non-duplicate primitive mutation available"
                           " anymore for ({}, {}, {}) after {} mutations")
    node_exhausted = ("There are no non-duplicate input node mutation available"
                      " anymore for ({}, {}, {}) after {} mutations")

    # Bookkeeping per mutation kind, keyed by (cell, step, connection):
    # the values still allowed at a position, and how many mutations hit it.
    primitive_pools, primitive_hits = {}, {}
    node_pools, node_hits = {}, {}

    sampled = []
    for _ in range(num_mutations):
        if np.random.rand() < primitive_prob:
            mutation_type = CellMutation.PRIMITIVE
        else:
            mutation_type = CellMutation.NODE
        cell = np.random.randint(low=0, high=search_space.num_cell_groups)
        step = np.random.randint(low=0, high=search_space.num_steps)
        connection = np.random.randint(low=0, high=search_space.num_node_inputs)

        position = (cell, step, connection)
        flat_idx = search_space.num_node_inputs * step + connection
        on_primitive = mutation_type == CellMutation.PRIMITIVE

        # arch[cell][1] holds the primitives, arch[cell][0] the input nodes
        if on_primitive:
            row, pools, hits, message = 1, primitive_pools, primitive_hits, primitive_exhausted
        else:
            row, pools, hits, message = 0, node_pools, node_hits, node_exhausted

        candidates = pools.get(position)
        if candidates is None:
            # First visit of this position: every value except the current one
            current = arch[cell][row][flat_idx]
            if not on_primitive:
                total = search_space.num_init_nodes + step
            elif search_space.cellwise_primitives:
                total = search_space._num_primitives_list[cell]
            else:
                total = search_space._num_primitives
            candidates = list(range(total))
            candidates.remove(current)
            pools[position] = candidates

        done = hits.get(position, 0)
        expect(candidates, message.format(cell, step, connection, done))

        picked = np.random.choice(candidates)
        candidates.remove(picked)
        arch[cell][row][flat_idx] = picked
        hits[position] = done + 1

        sampled.append(
            CellMutation(search_space, mutation_type, cell, step, connection,
                         modified=picked))

    return cls(population, parent_index, sampled, search_space)