Ejemplo n.º 1
0
    def init_from_dirs(cls, dirs, search_space=None, cfg_template_file=None):
        """
        Init population from directories.

        Args:
          dirs: [directory paths], must be non-empty
          search_space: SearchSpace. If None, it is constructed from the
            "search_space_type" and "search_space_cfg" entries of the config template.
          cfg_template_file: if not specified, default: "template.yaml" under `dirs[0]`
        Returns: Population

        Raises:
          AssertionError: if `dirs` is empty.
          ValueError: if the name (without extension) of a non-template yaml file
            is not an integer.
          Whatever `expect` raises when two meta-info files share the same index.

        There should be multiple meta-info (yaml) files named as "`<number>.yaml`" under each
        directory, each of them specifies the meta information for a model in the population,
        with `<number>` representing its index.
        Note there should not be duplicate indexes across all the directories;
        if there are duplicate indexes, rename or soft-link the files.

        In each meta-info file, the possible meta informations are:
        * genotype
        * train_config
        * checkpoint_path
        * (optional) confidence
        * perfs: a dict of performance name to performance value

        "template.yaml" under the first dir will be used as the template training config for
        new candidate model
        """
        assert dirs, "No dirs specified!"
        if cfg_template_file is None:
            cfg_template_file = os.path.join(dirs[0], "template.yaml")
        with open(cfg_template_file, "r") as cfg_f:
            cfg_template = ConfigTemplate(yaml.safe_load(cfg_f))
        logger = _logger.getChild("population")
        logger.info("Read the template config from %s", cfg_template_file)
        model_records = collections.OrderedDict()
        if search_space is None:
            # assume can parse search space from config template
            from aw_nas.common import get_search_space
            search_space = get_search_space(cfg_template["search_space_type"],
                                            **cfg_template["search_space_cfg"])
        for dir_ in dirs:
            # `glob.glob` returns files in arbitrary order; sort them so that the
            # insertion order of `model_records` is reproducible across runs.
            meta_files = sorted(glob.glob(os.path.join(dir_, "*.yaml")))
            for fname in meta_files:
                base_name = os.path.basename(fname)
                # Only look at the file name itself: matching against the whole path
                # would skip every record under a directory whose path happens to
                # contain "template.yaml".
                if "template.yaml" in base_name:
                    # do not parse template.yaml
                    continue
                index = int(base_name.rsplit(".", 1)[0])
                expect(
                    index not in model_records,
                    "There are duplicate index: {}. rename or soft-link the files"
                    .format(index))
                model_records[index] = ModelRecord.init_from_file(
                    fname, search_space)
        logger.info(
            "Parsed %d directories, total %d model records loaded.", len(dirs),
            len(model_records))
        return Population(search_space, model_records, cfg_template)
Ejemplo n.º 2
0
    def random_sample(cls,
                      population,
                      parent_index,
                      num_mutations=1,
                      primitive_prob=0.5):
        """
        Randomly draw `num_mutations` mutations on top of a parent model and
        wrap them in a rollout.

        Each mutation picks a random (cell, step, connection) position and either
        replaces the primitive on that connection (with probability
        `primitive_prob`) or replaces its input node. The same position is never
        mutated twice to the same value, nor back to the parent's original value;
        when no candidate value is left for a drawn position, `expect` fails.

        Args:
          population: provides `search_space` and `get_model(parent_index)`
          parent_index: index of the parent model in `population`
          num_mutations: number of mutations to sample
          primitive_prob: probability that a mutation is a primitive mutation
        Returns: `cls(population, parent_index, mutations, search_space)`
        """
        search_space = population.search_space
        parent_genotype = population.get_model(parent_index).genotype
        base_arch = search_space.rollout_from_genotype(parent_genotype).arch

        # Remaining candidate values and applied-mutation counts, both keyed by
        # (cell, step, connection); kept separately for the two mutation types.
        prim_pool, prim_count = {}, {}
        node_pool, node_count = {}, {}

        mutations = []
        for _ in range(num_mutations):
            if np.random.rand() < primitive_prob:
                mutation_type = CellMutation.PRIMITIVE
            else:
                mutation_type = CellMutation.NODE
            cell = np.random.randint(low=0, high=search_space.num_cell_groups)
            step = np.random.randint(low=0, high=search_space.num_steps)
            connection = np.random.randint(low=0,
                                           high=search_space.num_node_inputs)
            key = (cell, step, connection)
            # flat position of this connection inside the cell's arch lists
            slot = search_space.num_node_inputs * step + connection

            is_primitive = mutation_type == CellMutation.PRIMITIVE
            if is_primitive:
                # base_arch[cell][1] holds the primitives
                pools, counts, part = prim_pool, prim_count, 1
                exhausted_msg = (
                    "There are no non-duplicate primitive mutation available"
                    " anymore for ({}, {}, {}) after {} mutations")
            else:
                # base_arch[cell][0] holds the input nodes
                pools, counts, part = node_pool, node_count, 0
                exhausted_msg = (
                    "There are no non-duplicate input node mutation available"
                    " anymore for ({}, {}, {}) after {} mutations")

            if key not in pools:
                # first visit of this position: every value but the current one
                current = base_arch[cell][part][slot]
                if not is_primitive:
                    num_options = search_space.num_init_nodes + step
                elif search_space.cellwise_primitives:
                    num_options = search_space._num_primitives_list[cell]
                else:
                    num_options = search_space._num_primitives
                candidates = list(range(num_options))
                candidates.remove(current)
                pools[key] = candidates
            choices = pools[key]

            expect(choices,
                   exhausted_msg.format(cell, step, connection,
                                        counts.get(key, 0)))
            new_choice = np.random.choice(choices)
            # never propose the same value for this position again
            choices.remove(new_choice)
            base_arch[cell][part][slot] = new_choice
            counts[key] = counts.get(key, 0) + 1

            mutations.append(
                CellMutation(search_space,
                             mutation_type,
                             cell,
                             step,
                             connection,
                             modified=new_choice))
        return cls(population, parent_index, mutations, search_space)