示例#1
0
文件: common.py 项目: zzzDavid/aw_nas
 def mutate(self, rollout, node_mutate_prob=0.5):
     """
     Build a new rollout by changing one decision of `rollout`'s architecture.

     A cell group is chosen uniformly at random. With probability
     `node_mutate_prob` one input connection in that cell group is re-wired;
     otherwise one operation is replaced by another primitive. The
     architecture is deep-copied first, so `rollout.arch` is not modified.

     Args:
         rollout: Rollout to mutate; only its `arch` attribute is read. Each
             `arch[i]` is indexed as `[0]` (connection list) and `[1]`
             (operation list).
         node_mutate_prob: Probability of mutating a connection instead of
             an operation.

     Returns:
         A new `Rollout` wrapping the mutated architecture.

     Raises:
         ValueError: Propagated from `np.random.randint` when the selected
             decision has no alternative value (e.g. only one primitive).
     """
     new_arch = copy.deepcopy(rollout.arch)
     # randomly select a cell group to modify
     mutate_i_cg = np.random.randint(0, self.num_cell_groups)
     nodes, ops = new_arch[mutate_i_cg][0], new_arch[mutate_i_cg][1]
     if np.random.random() < node_mutate_prob:
         # mutate connection
         # if #cell init nodes is 1, the first node has a single possible
         # input, so its connections are excluded from the candidates
         start = int(self.num_init_nodes == 1) * self.num_node_inputs
         node_mutate_idx = np.random.randint(start, len(nodes))
         i_out = node_mutate_idx // self.num_node_inputs
         out_node = i_out + self.num_init_nodes
         # a non-zero offset modulo `out_node` always yields a different
         # predecessor index in [0, out_node)
         offset = np.random.randint(1, out_node)
         nodes[node_mutate_idx] = (nodes[node_mutate_idx] + offset) % out_node
     else:
         # mutate op: both the index and the write target the op list
         if self.cellwise_primitives:
             num_prims = self._num_primitives_list[mutate_i_cg]
         else:
             num_prims = self._num_primitives
         op_mutate_idx = np.random.randint(0, len(ops))
         offset = np.random.randint(1, num_prims)
         ops[op_mutate_idx] = (ops[op_mutate_idx] + offset) % num_prims
     return Rollout(new_arch, info={}, search_space=self)
示例#2
0
文件: common.py 项目: zzzDavid/aw_nas
 def random_sample(self):
     """
     Draw a random discrete architecture and wrap it in a rollout.

     The primitive-count argument handed to `Rollout.random_sample_arch` is
     `_num_primitives_list` when `cellwise_primitives` is truthy, and
     `_num_primitives` otherwise.
     """
     if self.cellwise_primitives:
         num_prims = self._num_primitives_list
     else:
         num_prims = self._num_primitives
     sampled_arch = Rollout.random_sample_arch(
         self.num_cell_groups,
         self.num_steps,
         self.num_init_nodes,
         self.num_node_inputs,
         num_prims,
     )
     return Rollout(sampled_arch, info={}, search_space=self)
示例#3
0
文件: common.py 项目: zzzDavid/aw_nas
 def rollout_from_genotype(self, genotype):
     """
     Convert a genotype (namedtuple-like, exposes `_asdict`) into a rollout.

     For every cell group, each connection entry `conn` contributes
     `conn[1]` to the node list and the index of `conn[0]` within
     `self.shared_primitives` to the op list.
     """
     fields = list(genotype._asdict().values())
     assert len(fields) == 2 * self.num_cell_groups
     arch = []
     for i_cg in range(self.num_cell_groups):
         # NOTE(review): connections are read from every second field, i.e.
         # this presumes an interleaved field layout -- confirm against the
         # genotype type definition.
         cell_geno = fields[2 * i_cg]
         conns = [
             cell_geno[i_out * self.num_node_inputs + i_in]
             for i_out in range(self.num_steps)
             for i_in in range(self.num_node_inputs)
         ]
         nodes = [conn[1] for conn in conns]
         ops = [self.shared_primitives.index(conn[0]) for conn in conns]
         arch.append([nodes, ops])
     return Rollout(arch, {}, self)
示例#4
0
    def rollout_from_genotype(self, genotype):
        """
        Convert a genotype (namedtuple-like, exposes `_asdict`) into a rollout.

        The first `num_cell_groups` fields of the genotype are taken as the
        per-cell-group connection lists. Each connection entry `conn`
        contributes `conn[1]` to the node list and the index of `conn[0]`
        within `self.cell_shared_primitives[i_cg]` to the op list.
        """
        fields = list(genotype._asdict().values())
        assert len(fields) == 2 * self.num_cell_groups

        arch = []
        for i_cg, cell_geno in enumerate(fields[:self.num_cell_groups]):
            # Flatten (step, input) pairs in row-major order.
            conns = [
                cell_geno[i_out * self.num_node_inputs + i_in]
                for i_out in range(self.get_num_steps(i_cg))
                for i_in in range(self.num_node_inputs)
            ]
            nodes = [conn[1] for conn in conns]
            ops = [
                self.cell_shared_primitives[i_cg].index(conn[0])
                for conn in conns
            ]
            arch.append([nodes, ops])
        return Rollout(arch, {}, self)
示例#5
0
文件: common.py 项目: zeta1999/aw_nas
 def mutate(self, rollout, node_mutate_prob=0.5):
     """
     Build a new rollout by changing one decision of `rollout`'s architecture.

     A cell group is chosen uniformly at random. With probability
     `node_mutate_prob` one input connection in that cell group is re-wired
     to a different predecessor; otherwise one operation is replaced by a
     different primitive. The architecture is deep-copied first, so
     `rollout.arch` is not modified.

     Args:
         rollout: Rollout to mutate; only its `arch` attribute is read. Each
             `arch[i]` is indexed as `[0]` (connection list) and `[1]`
             (operation list).
         node_mutate_prob: Probability of mutating a connection instead of
             an operation.

     Returns:
         A new `Rollout` wrapping the mutated architecture.

     Raises:
         ValueError: Propagated from `np.random.randint` when the selected
             decision has no alternative value (e.g. only one primitive).
     """
     new_arch = copy.deepcopy(rollout.arch)
     mutate_i_cg = np.random.randint(0, self.num_cell_groups)
     nodes, ops = new_arch[mutate_i_cg][0], new_arch[mutate_i_cg][1]
     if np.random.random() < node_mutate_prob:
         # With a single init node, the first node has only one possible
         # input, so its connections are excluded from the candidates.
         start = int(self.num_init_nodes == 1) * self.num_node_inputs
         node_mutate_idx = np.random.randint(start, len(nodes))
         # The node list holds `num_node_inputs` entries per step, so that
         # is the divisor that recovers the step index.
         i_out = node_mutate_idx // self.num_node_inputs
         out_node = i_out + self.num_init_nodes
         # A non-zero offset modulo `out_node` picks uniformly among the
         # other predecessors and cannot loop forever.
         offset = np.random.randint(1, out_node)
         nodes[node_mutate_idx] = (nodes[node_mutate_idx] + offset) % out_node
     else:
         # Mutate an operation: index, compare and write all use the op list.
         if self.cellwise_primitives:
             num_prims = self._num_primitives_list[mutate_i_cg]
         else:
             num_prims = self._num_primitives
         op_mutate_idx = np.random.randint(0, len(ops))
         offset = np.random.randint(1, num_prims)
         ops[op_mutate_idx] = (ops[op_mutate_idx] + offset) % num_prims
     return Rollout(new_arch, info={}, search_space=self)