def test_ss():
    """Round-trip dense_rob genotypes through their string form.

    Covers a randomly sampled rollout, the fixed msrobnet-1560M genotype in
    its own 8-cell-group search space, and a genotype string that does not
    fit that search space (expected to raise ``TypeError``).
    """
    from aw_nas.common import get_search_space, genotype_from_str

    # --- random sample round trip ---
    ss = get_search_space(
        "dense_rob",
        cell_layout=[0, 0, 1, 0, 0, 1, 0, 0],
        reduce_cell_groups=[1],
    )
    sampled = ss.random_sample()
    print(sampled.genotype)
    # Only checks that rebuilding a rollout from the genotype does not raise.
    ss.rollout_from_genotype(sampled.genotype)
    sampled_str = str(sampled.genotype)
    assert genotype_from_str(sampled_str, ss) == sampled.genotype

    # --- test msrobnet-1560M and its search space ---
    ss = get_search_space(
        "dense_rob",
        num_cell_groups=8,
        num_init_nodes=2,
        cell_layout=[0, 1, 2, 3, 4, 5, 6, 7],
        reduce_cell_groups=[2, 5],
        num_steps=4,
        primitives=["none", "skip_connect", "sep_conv_3x3", "ResSepConv"],
    )
    # One literal per cell group; the pieces concatenate to a single string.
    msrobnet_str = (
        "DenseRobGenotype("
        "normal_0='init_node~2+|skip_connect~0|sep_conv_3x3~1|+|none~0|skip_connect~1|skip_connect~2|+|none~0|sep_conv_3x3~1|ResSepConv~2|skip_connect~3|+|skip_connect~0|sep_conv_3x3~1|sep_conv_3x3~2|sep_conv_3x3~3|sep_conv_3x3~4|', "
        "normal_1='init_node~2+|ResSepConv~0|sep_conv_3x3~1|+|none~0|sep_conv_3x3~1|skip_connect~2|+|ResSepConv~0|ResSepConv~1|ResSepConv~2|none~3|+|none~0|skip_connect~1|sep_conv_3x3~2|sep_conv_3x3~3|skip_connect~4|', "
        "reduce_2='init_node~2+|ResSepConv~0|skip_connect~1|+|sep_conv_3x3~0|none~1|none~2|+|skip_connect~0|ResSepConv~1|sep_conv_3x3~2|none~3|+|ResSepConv~0|skip_connect~1|ResSepConv~2|none~3|skip_connect~4|', "
        "normal_3='init_node~2+|skip_connect~0|skip_connect~1|+|ResSepConv~0|none~1|ResSepConv~2|+|ResSepConv~0|none~1|ResSepConv~2|skip_connect~3|+|none~0|sep_conv_3x3~1|none~2|skip_connect~3|sep_conv_3x3~4|', "
        "normal_4='init_node~2+|ResSepConv~0|sep_conv_3x3~1|+|skip_connect~0|skip_connect~1|none~2|+|sep_conv_3x3~0|skip_connect~1|sep_conv_3x3~2|sep_conv_3x3~3|+|ResSepConv~0|ResSepConv~1|none~2|skip_connect~3|sep_conv_3x3~4|', "
        "reduce_5='init_node~2+|sep_conv_3x3~0|sep_conv_3x3~1|+|ResSepConv~0|sep_conv_3x3~1|ResSepConv~2|+|none~0|sep_conv_3x3~1|ResSepConv~2|sep_conv_3x3~3|+|ResSepConv~0|skip_connect~1|skip_connect~2|skip_connect~3|sep_conv_3x3~4|', "
        "normal_6='init_node~2+|none~0|ResSepConv~1|+|none~0|sep_conv_3x3~1|ResSepConv~2|+|skip_connect~0|ResSepConv~1|ResSepConv~2|none~3|+|ResSepConv~0|skip_connect~1|ResSepConv~2|none~3|sep_conv_3x3~4|', "
        "normal_7='init_node~2+|none~0|none~1|+|none~0|ResSepConv~1|none~2|+|sep_conv_3x3~0|skip_connect~1|ResSepConv~2|none~3|+|ResSepConv~0|skip_connect~1|skip_connect~2|none~3|ResSepConv~4|')"
    )
    # Only checks that parsing succeeds.
    genotype_from_str(msrobnet_str, ss)

    # --- test a genotype does not fit for this search space ---
    mismatched_str = (
        "DenseRobGenotype("
        "normal_0='init_node~1+|sep_conv_3x3~0|+|skip_connect~0|none~1|+|sep_conv_3x3~0|sep_conv_3x3~1|skip_connect~2|+|skip_connect~0|skip_connect~1|sep_conv_3x3~2|none~3|', "
        "reduce_1='init_node~1+|none~0|+|none~0|sep_conv_3x3~1|+|none~0|none~1|sep_conv_3x3~2|+|sep_conv_3x3~0|none~1|none~2|none~3|', "
        "normal_2='init_node~1+|sep_conv_3x3~0|+|sep_conv_3x3~0|skip_connect~1|+|skip_connect~0|none~1|skip_connect~2|+|skip_connect~0|sep_conv_3x3~1|sep_conv_3x3~2|skip_connect~3|', "
        "reduce_3='init_node~1+|skip_connect~0|+|skip_connect~0|skip_connect~1|+|none~0|sep_conv_3x3~1|sep_conv_3x3~2|+|skip_connect~0|skip_connect~1|sep_conv_3x3~2|skip_connect~3|', "
        "normal_4='init_node~1+|sep_conv_3x3~0|+|skip_connect~0|skip_connect~1|+|skip_connect~0|sep_conv_3x3~1|sep_conv_3x3~2|+|sep_conv_3x3~0|none~1|skip_connect~2|skip_connect~3|')"
    )
    with pytest.raises(TypeError):
        genotype_from_str(mismatched_str, ss)
def genotype_from_str(self, genotype_str):
    """Parse a combined macro/micro genotype string.

    The string must contain ``(<...Genotype(...)>, <...Genotype(...)>)``; the
    first part is parsed against ``self.macro_search_space`` and the second
    against ``self.micro_search_space``.

    Args:
        genotype_str: String representation holding both genotypes.

    Returns:
        A ``(macro_genotype, micro_genotype)`` tuple.

    Raises:
        ValueError: If ``genotype_str`` does not contain the expected pair of
            genotype representations.
    """
    match = re.search(r"\((.+Genotype\(.+\)), (.+Genotype\(.+\))\)", genotype_str)
    if match is None:
        # Fail with a readable message instead of an AttributeError on None.
        raise ValueError(
            "Cannot find a (macro, micro) genotype pair in: {!r}".format(genotype_str))
    macro_genotype_str, micro_genotype_str = match.groups()
    return (
        genotype_from_str(macro_genotype_str, self.macro_search_space),
        genotype_from_str(micro_genotype_str, self.micro_search_space),
    )
def __setstate__(self, state):
    """Restore the controller, turning string keys back into genotypes.

    ``state["population"]`` and ``state["gt_population"]`` are keyed by
    genotype strings; both are rebuilt as dicts keyed by genotypes parsed
    against ``self.search_space``.
    """
    super(ParetoEvoController, self).__setstate__(state)

    def _rekey(serialized):
        # Parse every string key with the controller's search space.
        rebuilt = {}
        for geno_str, value in serialized.items():
            rebuilt[genotype_from_str(geno_str, self.search_space)] = value
        return rebuilt

    self.population = _rekey(state["population"])
    self.gt_population = _rekey(state["gt_population"])
def load(self, path):
    """Load controller state from a checkpoint written with ``torch.save``.

    Restores ``epoch``, both genotype-keyed populations (string keys are
    parsed against ``self.search_space``) and ``_start_pareto_sample``.

    Args:
        path: Checkpoint file path; tensors are mapped to the CPU.
    """
    cpu = torch.device("cpu")
    state = torch.load(path, map_location=cpu)
    self.epoch = state["epoch"]
    for attr_name in ("population", "gt_population"):
        rebuilt = {}
        for geno_str, value in state[attr_name].items():
            rebuilt[genotype_from_str(geno_str, self.search_space)] = value
        setattr(self, attr_name, rebuilt)
    self._start_pareto_sample = state["_start_pareto_sample"]
def genotype_from_str(self, genotype_str):
    """Parse an ensemble genotype string into ``self.genotype_type``.

    ``self.genotype_str_pattern`` is matched at the start of ``genotype_str``;
    capture groups ``1 .. self.ensemble_size`` are each parsed against
    ``self.inner_search_space``.

    Args:
        genotype_str: String representation of the ensemble genotype.

    Returns:
        ``self.genotype_type`` built from the parsed inner genotypes.

    Raises:
        ValueError: If ``genotype_str`` does not match the pattern.
    """
    matched = re.match(self.genotype_str_pattern, genotype_str)
    if matched is None:
        # Fail with a readable message instead of an AttributeError on None.
        raise ValueError(
            "Genotype string {!r} does not match pattern {!r}".format(
                genotype_str, self.genotype_str_pattern))
    inner_genotypes = [
        genotype_from_str(matched.group(i + 1), self.inner_search_space)
        for i in range(self.ensemble_size)
    ]
    return self.genotype_type(*inner_genotypes)
def rollout_from_genotype(self, genotype):
    """Build a ``MNasNetOFARollout`` from a genotype or its string form.

    The genotype is indexed positionally: the first
    ``len(self.num_cell_groups)`` entries are the per-stage depths, followed
    by one entry per block slot (``max_depth`` slots per stage). An entry is
    either indexable as ``(width, kernel)`` or a bare width, in which case the
    kernel size falls back to 3. Slots beyond a stage's depth are skipped.

    Args:
        genotype: Genotype object, or a string parsed with
            ``genotype_from_str`` against this search space.

    Returns:
        ``MNasNetOFARollout`` whose arch has ``depth``, ``width`` and
        ``kernel`` keys.
    """
    if isinstance(genotype, str):
        genotype = genotype_from_str(genotype, self)
    num_stages = len(self.num_cell_groups)
    depth = genotype[:num_stages]
    width = []
    kernel = []
    ind = num_stages
    for stage_depth, max_depth in zip(depth, self.num_cell_groups):
        width_list = []
        kernel_list = []
        for j in range(max_depth):
            if j < stage_depth:
                entry = genotype[ind]
                # Decode the entry completely before appending, so a failure
                # on the kernel lookup cannot leave a stray width behind.
                try:
                    block_width, block_kernel = entry[0], entry[1]
                except (TypeError, IndexError, KeyError):
                    # Entry carries only a width; use the default kernel size.
                    block_width, block_kernel = entry, 3
                width_list.append(block_width)
                kernel_list.append(block_kernel)
            # Every slot consumes one genotype entry, active or not.
            ind += 1
        width.append(width_list)
        kernel.append(kernel_list)
    arch = {"depth": depth, "width": width, "kernel": kernel}
    return MNasNetOFARollout(arch, {}, self)
def test_layer2_ss(tmp_path):
    """Sample a layer2 rollout, round-trip it via genotype/string, and plot it."""
    from aw_nas.common import get_search_space, genotype_from_str, rollout_from_genotype_str

    macro_cfg = {
        "num_cell_groups": 2,
        "cell_layout": [0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0],
        "width_choice": [0.25, 0.5, 0.75, 1.0],
        "reduce_cell_groups": [1]
    }
    micro_cfg = {
        "num_cell_groups": 2,
        "num_steps": 4
    }
    ss = get_search_space("layer2",
                          macro_search_space_cfg=macro_cfg,
                          micro_search_space_cfg=micro_cfg)
    sampled = ss.random_sample()
    print(sampled.genotype)

    # genotype -> rollout
    from_genotype = ss.rollout_from_genotype(sampled.genotype)
    assert from_genotype == sampled

    # genotype from str (the parse itself only has to succeed)
    genotype_from_str(str(sampled.genotype), ss)
    from_string = rollout_from_genotype_str(str(sampled.genotype), ss)
    assert from_string == sampled

    # plot
    plot_path = os.path.join(str(tmp_path), "layer2")
    sampled.plot_arch(plot_path, label="layer2 rollout example")
    print("Plot save to path: ", plot_path)
def add_model(self, model_record, index=None):
    """Register a model record and its parsed genotype.

    Args:
        model_record: Record to store; its ``genotype`` is parsed against
            ``self.search_space``.
        index: Slot to store the record under; defaults to
            ``self._next_index``.

    Returns:
        The index the record was stored under.
    """
    if index is None:
        index = self._next_index
    self.model_records[index] = model_record
    parsed_genotype = genotype_from_str(model_record.genotype, self.search_space)
    self.genotype_records[index] = parsed_genotype
    # Both counters advance on every call, whether or not index was given.
    self._next_index += 1
    self._size += 1
    return index
def __init__(self,
             search_space,
             device,
             genotypes,
             num_classes=10,
             dropout_rate=0.0,
             dropblock_rate=0.0,
             schedule_cfg=None):
    """Build a dense-block network from a genotype string.

    The parsed genotype is flattened to a list of field values that is read
    positionally: index 0 is the stem output channel count, index
    ``1 + 2 * i`` holds the growth rates of dense block ``i`` and index
    ``2 + 2 * i`` the output channels of the transition after block ``i``
    (there is no transition after the last block).

    Args:
        search_space: Search space used to parse ``genotypes``; must expose
            ``num_dense_blocks``.
        device: Device the assembled model is moved to.
        genotypes: Genotype string (asserted to be ``str``).
        num_classes: Output size of the final linear classifier.
        dropout_rate: Dropout probability before the classifier; a falsy or
            non-positive value installs an identity op instead.
        dropblock_rate: Stored on the instance; not otherwise used here.
        schedule_cfg: Forwarded to the parent constructor.
    """
    super(DenseGenotypeModel, self).__init__(schedule_cfg)
    self.search_space = search_space
    self.device = device
    assert isinstance(genotypes, str)
    self.genotypes = list(
        genotype_from_str(genotypes, self.search_space)._asdict().values())
    self.num_classes = num_classes

    # training
    self.dropout_rate = dropout_rate
    self.dropblock_rate = dropblock_rate

    self._num_blocks = self.search_space.num_dense_blocks

    # build model
    self.stem = nn.Conv2d(3, self.genotypes[0], kernel_size=3, padding=1)

    # Collected as plain lists first, wrapped in nn.ModuleList after the loop.
    self.dense_blocks = []
    self.trans_blocks = []
    last_channel = self.genotypes[0]
    for i_block in range(self._num_blocks):
        growths = self.genotypes[1 + i_block * 2]
        self.dense_blocks.append(
            self._new_dense_block(last_channel, growths))
        # Channel count after the block: input plus the sum of all growths.
        last_channel = int(last_channel + np.sum(growths))
        if i_block != self._num_blocks - 1:
            # Every block except the last is followed by a transition.
            out_c = self.genotypes[2 + i_block * 2]
            self.trans_blocks.append(
                self._new_transition_block(last_channel, out_c))
            last_channel = out_c
    self.dense_blocks = nn.ModuleList(self.dense_blocks)
    self.trans_blocks = nn.ModuleList(self.trans_blocks)

    self.final_bn = nn.BatchNorm2d(last_channel)
    self.final_relu = nn.ReLU()
    self.global_pooling = nn.AdaptiveAvgPool2d(1)
    if self.dropout_rate and self.dropout_rate > 0:
        self.dropout = nn.Dropout(p=self.dropout_rate)
    else:
        self.dropout = ops.Identity()
    self.classifier = nn.Linear(last_channel, self.num_classes)

    self.to(self.device)

    # for flops calculation
    self.total_flops = 0
    self._flops_calculated = False
    self.set_hook()
def finalize(self, genotypes, filter_regex=None):
    """Finalize the backbone to the architecture given by a genotype string.

    The string is parsed against ``self.search_space``, its field values are
    decoded by ``self.parse`` into ``(depth, width, kernel)`` and
    ``self.backbone`` is replaced with ``self.backbone.finalize(...)``.

    Args:
        genotypes: Genotype string; any other type fails the assertion.
        filter_regex: Accepted for interface compatibility; unused here.

    Returns:
        ``self``.
    """
    assert isinstance(genotypes, str), \
        "Type str expected, got {} instead.".format(type(genotypes))
    genotype_values = list(
        genotype_from_str(genotypes, self.search_space)._asdict().values())
    depth, width, kernel = self.parse(genotype_values)
    self.backbone = self.backbone.finalize(depth, width, kernel)
    return self
def test_morphism(population, tmp_path):
    """Mutate one primitive of a trained parent and check weight inheritance.

    A parent CNN is built from a fixed genotype, saved and registered in the
    population. After a single PRIMITIVE cell mutation the candidate network
    assembled by ``MorphismWeightsManager`` must share every parameter with
    the parent except the added edges, and must not contain the removed ones.
    """
    from aw_nas.rollout.mutation import MutationRollout, ConfigTemplate, ModelRecord, CellMutation
    from aw_nas.final import CNNGenotypeModel
    from aw_nas.main import _init_component
    from aw_nas.common import genotype_from_str, rollout_from_genotype_str
    from aw_nas.weights_manager import MorphismWeightsManager

    cfg = yaml.safe_load(SAMPLE_MODEL_CFG)
    device = "cuda:0"
    search_space = population.search_space
    genotype_str = ("normal_0=[('sep_conv_5x5', 0, 2), ('sep_conv_3x3', 1, 2), "
                    "('sep_conv_3x3', 2, 3), ('none', 2, 3)], "
                    "reduce_1=[('sep_conv_5x5', 0, 2), ('none', 0, 2), "
                    "('sep_conv_5x5', 0, 3), ('sep_conv_5x5', 1, 3)]")
    # Only checks that the genotype string converts to a rollout.
    rollout_from_genotype_str(genotype_str, search_space)

    cfg["final_model_cfg"]["genotypes"] = genotype_str
    parent_net = _init_component(cfg, "final_model",
                                 search_space=search_space,
                                 device=device)
    parent_weights = parent_net.state_dict()
    torch.save(parent_net, os.path.join(tmp_path, "test"))

    # add this record to the population
    parent_genotype = genotype_from_str(
        cfg["final_model_cfg"]["genotypes"], parent_net.search_space)
    parent_record = ModelRecord(
        parent_genotype,
        cfg,
        parent_net.search_space,
        info_path=os.path.join(tmp_path, "test.yaml"),
        checkpoint_path=os.path.join(tmp_path, "test"),
        finished=True,
        confidence=1,
        perfs={"acc": np.random.rand(),
               "loss": np.random.uniform(0, 10)})
    parent_index = population.add_model(parent_record)

    new_primitive = search_space.shared_primitives.index("sep_conv_5x5")
    mutation = CellMutation(search_space, CellMutation.PRIMITIVE, cell=0, step=0,
                            connection=1,
                            modified=new_primitive)
    print("mutation: ", mutation)
    rollout = MutationRollout(population, parent_index, [mutation], search_space)
    assert rollout.genotype != parent_net.genotypes

    w_manager = MorphismWeightsManager(search_space, device, "mutation")
    child_net = w_manager.assemble_candidate(rollout)
    child_weights = child_net.state_dict()

    # Layers built from the mutated cell group.
    mutated_layers = []
    for layer_idx, cell_group in enumerate(search_space.cell_layout):
        if cell_group == mutation.cell:
            mutated_layers.append(layer_idx)
    removed_edges = []
    added_edges = []
    for layer in mutated_layers:
        removed_edges.append(
            "cells.{layer}.edge_mod.f_1_t_2-sep_conv_3x3-0".format(layer=layer))
        added_edges.append(
            "cells.{layer}.edge_mod.f_1_t_2-sep_conv_5x5-0".format(layer=layer))

    for name, value in six.iteritems(child_weights):
        if name in added_edges:
            continue
        assert name in parent_weights
        assert (parent_weights[name].data.cpu().numpy() == value.data.cpu().numpy()).all()
    for name in removed_edges:
        assert name not in child_weights
def rollout_from_genotype(self, genotype):
    """Build a ``MNasNetOFARollout`` from a genotype or its string form.

    Strings are first parsed with ``genotype_from_str`` against this search
    space; ``self.parse`` then yields ``(image_size, depth, width, kernel)``.
    """
    parsed = genotype_from_str(genotype, self) if isinstance(genotype, str) else genotype
    image_size, depth, width, kernel = self.parse(parsed)
    arch = dict(
        image_size=image_size,
        depth=depth,
        width=width,
        kernel=kernel,
    )
    return MNasNetOFARollout(arch, {}, self)
def init_from_file(cls, path, search_space):
    """Create a record from a YAML meta-information file.

    The file must provide ``genotypes``, ``config``, ``checkpoint_path``,
    ``finished`` and ``perfs``; ``confidence`` is optional and defaults to
    ``None``. The genotype is parsed against ``search_space`` and stored in
    its string form.

    Args:
        path: Path of the YAML file; recorded as an absolute path.
        search_space: Search space used to parse the genotype.

    Returns:
        The new ``cls`` instance.
    """
    with open(path, "r") as meta_f:
        meta_info = yaml.safe_load(meta_f)
    genotype_repr = str(genotype_from_str(meta_info["genotypes"], search_space))
    config = meta_info["config"]
    info_path = os.path.abspath(path)
    checkpoint_path = meta_info["checkpoint_path"]
    return cls(genotype_repr,
               config,
               search_space,
               info_path,
               checkpoint_path,
               finished=meta_info["finished"],
               confidence=meta_info.get("confidence", None),
               perfs=meta_info["perfs"])
def __setstate__(self, state):
    """Restore the controller's population and ground-truth rollouts.

    Population keys are genotype strings parsed against
    ``self.search_space``. Each saved rollout state is applied to a freshly
    sampled rollout, via its ``__setstate__`` when it has one and by updating
    its ``__dict__`` otherwise.
    """
    super(EvoController, self).__setstate__(state)
    population = {}
    for geno_str, value in state["population"].items():
        population[genotype_from_str(geno_str, self.search_space)] = value
    self.population = population

    self._gt_rollouts = []
    for rollout_state in state["gt_rollouts"]:
        # A sampled rollout only serves as an empty shell for the saved state.
        restored = self.search_space.random_sample()
        set_state = getattr(restored, "__setstate__", None)
        if set_state is None:
            restored.__dict__.update(rollout_state)
        else:
            set_state(rollout_state)
        self._gt_rollouts.append(restored)
def load(self, path):
    """Load controller state from a checkpoint written with ``torch.save``.

    Restores ``epoch``, the genotype-keyed ``population`` (string keys are
    parsed against ``self.search_space``), the ground-truth rollouts and
    ``_gt_scores``.

    Each saved rollout state is applied to a freshly sampled rollout. Rollouts
    that define ``__setstate__`` use it; others have their ``__dict__``
    updated, so rollout classes without a custom ``__setstate__`` load
    instead of raising ``AttributeError``.

    Args:
        path: Checkpoint file path; tensors are mapped to the CPU.
    """
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    state = torch.load(path, map_location=torch.device("cpu"))
    self.epoch = state["epoch"]
    self.population = {
        genotype_from_str(k, self.search_space): v
        for k, v in state["population"].items()
    }
    self._gt_rollouts = []
    for r in state["gt_rollouts"]:
        rollout = self.search_space.random_sample()
        if hasattr(rollout, "__setstate__"):
            rollout.__setstate__(r)
        else:
            rollout.__dict__.update(r)
        self._gt_rollouts.append(rollout)
    self._gt_scores = state["gt_scores"]
def test_dense_dynamic_rollout():
    """Check genotype fields and string round trip with dynamic transitions."""
    from aw_nas.common import get_search_space, genotype_from_str

    ss = get_search_space("cnn_dense",
                          dynamic_transition=True,
                          reduction=None,
                          transition_channels=None,
                          num_dense_blocks=4)
    block_growths = [[4] * 6, [4] * 12, [4] * 24, [4] * 6]
    transition_channels = [18, 20, 30]
    arch = (block_growths, transition_channels)
    genotype = ss.genotype(arch)

    # Every transition width must show up under its own genotype field.
    for idx, expected_channels in enumerate(arch[1]):
        field_name = "transition_{}".format(idx)
        assert getattr(genotype, field_name) == expected_channels

    recovered = genotype_from_str(str(genotype), ss)
    assert recovered == genotype
    print(genotype)
    print("relative conv flops: ", ss.relative_conv_flops(arch))
def test_dense_morphism_wider(population, tmp_path):
    """A WIDER dense mutation must leave the network output unchanged.

    A parent dense model is built from a fixed genotype, saved and registered
    in the population. The candidate assembled for a WIDER mutation has to
    produce the same logits as the reloaded parent, with a mean absolute
    difference below 1e-6.
    """
    from aw_nas.rollout.mutation import ConfigTemplate, ModelRecord, CellMutation
    from aw_nas.rollout.dense import DenseMutationRollout, DenseMutation
    from aw_nas.final import DenseGenotypeModel
    from aw_nas.main import _init_component
    from aw_nas.common import genotype_from_str, rollout_from_genotype_str
    from aw_nas.weights_manager import DenseMorphismWeightsManager

    cfg = yaml.safe_load(DENSE_SAMPLE_MODEL_CFG)
    device = "cuda:0"
    search_space = population.search_space
    genotype_str = ("stem=8, block_0=[4, 4, 4, 4, 4, 4], transition_0=16, "
                    "block_1=[4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4], transition_1=32, "
                    "block_2=[4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, "
                    "4, 4, 4, 4, 4], transition_2=64, block_3=[4, 4, 4, 4, 4, 4]")
    # Only checks that the genotype string converts to a rollout.
    rollout_from_genotype_str(genotype_str, search_space)

    cfg["final_model_cfg"]["genotypes"] = genotype_str
    parent_net = _init_component(cfg, "final_model",
                                 search_space=search_space,
                                 device=device)
    # Snapshot taken as in the original flow; not compared afterwards.
    parent_net.state_dict()
    torch.save(parent_net, os.path.join(tmp_path, "test"))

    # add this record to the population
    parent_genotype = genotype_from_str(
        cfg["final_model_cfg"]["genotypes"], parent_net.search_space)
    parent_record = ModelRecord(
        parent_genotype,
        cfg,
        parent_net.search_space,
        info_path=os.path.join(tmp_path, "test.yaml"),
        checkpoint_path=os.path.join(tmp_path, "test"),
        finished=True,
        confidence=1,
        perfs={"acc": np.random.rand(),
               "loss": np.random.uniform(0, 10)})
    parent_index = population.add_model(parent_record)

    mutation = DenseMutation(search_space, DenseMutation.WIDER,
                             block_idx=1, miniblock_idx=0, modified=8)
    rollout = DenseMutationRollout(population, parent_index, [mutation], search_space)
    assert rollout.genotype != parent_net.genotypes

    w_manager = DenseMorphismWeightsManager(search_space, device, "dense_mutation")
    child_net = w_manager.assemble_candidate(rollout)
    child_net.eval()
    child_net.state_dict()

    data = _cnn_data()
    child_logits = child_net.forward(data[0])

    # Reload the parent from its checkpoint and compare the outputs.
    parent_ckpt = rollout.population.get_model(rollout.parent_index).checkpoint_path
    reloaded_parent = torch.load(parent_ckpt)
    reloaded_parent.eval()
    parent_logits = reloaded_parent.forward(data[0])
    assert (child_logits - parent_logits).abs().mean() < 1e-6
def rollout_from_genotype(self, genotype):
    """Build an ``SSDOFARollout`` from a genotype or its string form.

    The last ``self.num_head`` genotype entries are ``(width, kernel)`` pairs
    for the heads; everything before them is decoded by ``self.parse`` into
    ``(image_size, depth, width, kernel)``.
    """
    if isinstance(genotype, str):
        genotype = genotype_from_str(genotype, self)
    backbone_part = genotype[:-self.num_head]
    image_size, depth, width, kernel = self.parse(backbone_part)
    head_part = genotype[-self.num_head:]
    # Transpose the (width, kernel) pairs into two parallel tuples.
    head_width, head_kernel = list(zip(*head_part))
    arch = dict(
        image_size=image_size,
        depth=depth,
        width=width,
        kernel=kernel,
        head_width=head_width,
        head_kernel=head_kernel,
    )
    return SSDOFARollout(arch, {}, self)
def __init__(self, search_space, device, genotypes, schedule_cfg=None):
    """Assemble a sequential list of ops from a general genotype.

    Args:
        search_space: Search space used to parse ``genotypes`` when it is a
            string.
        device: Device the module list is moved to.
        genotypes: Genotype string, or an iterable of per-op dicts (which is
            deep-copied). Each dict needs ``prim_type`` and ``spatial_size``;
            the remaining items are passed to the op constructor as keyword
            arguments.
        schedule_cfg: Forwarded to the parent constructor.
    """
    super(GeneralGenotypeModel, self).__init__(schedule_cfg)
    self.search_space = search_space
    self.device = device
    self.genotypes = (
        list(genotype_from_str(genotypes, self.search_space))
        if isinstance(genotypes, str)
        else copy.deepcopy(genotypes)
    )

    layers = []
    # Iterate over a copy: the pops below must not alter self.genotypes.
    for op_cfg in copy.deepcopy(self.genotypes):
        prim_type = op_cfg.pop("prim_type")
        op_cfg.pop("spatial_size")
        layers.append(get_op(prim_type)(**op_cfg))

    self.model = nn.ModuleList(layers)
    self.model.to(self.device)
def test_genprof(case):
    """Check generated profiling primitives and the nets assembled from them.

    Verifies the primitive count against the search-space config, that each
    primitive is a dict with every ``Prim`` field except ``kwargs``, and that
    consecutive ops in each assembled net agree on channels and spatial size.
    """
    from aw_nas.common import get_search_space, genotype_from_str
    from aw_nas.hardware.base import MixinProfilingSearchSpace
    from aw_nas.hardware.utils import Prim, assemble_profiling_nets
    from aw_nas.rollout.general import GeneralSearchSpace

    ss = get_search_space("ofa_mixin", **case["search_space_cfg"])
    assert isinstance(ss, MixinProfilingSearchSpace)

    primitives = ss.generate_profiling_primitives(**case["prof_prim_cfg"])
    cfg = case["search_space_cfg"]
    expected_num = (len(cfg["width_choice"]) * len(cfg["kernel_choice"]) *
                    len(cfg["num_cell_groups"][1:]) * 2 + 2)
    assert len(primitives) == expected_num

    required_fields = {name for name in Prim._fields if name != "kwargs"}
    for prim in primitives:
        assert isinstance(prim, dict)
        assert required_fields.issubset(prim.keys())

    base_cfg = {"final_model_cfg": {}}
    nets = assemble_profiling_nets(primitives, base_cfg, image_size=224)
    gss = GeneralSearchSpace([])
    total_ops = 0
    for net in nets:
        genotype = net["final_model_cfg"]["genotypes"]
        assert isinstance(genotype, (list, str))
        if isinstance(genotype, str):
            genotype = genotype_from_str(genotype, gss)
        total_ops += len(genotype)

        # Output channels of an op must equal the input channels of the next.
        channels_match = [
            cur["C_out"] == nxt["C"]
            for cur, nxt in zip(genotype[:-1], genotype[1:])
        ]
        assert all(channels_match)

        # Spatial size must shrink by exactly each op's stride.
        spatial_sizes = [op["spatial_size"] for op in genotype]
        strides = [op["stride"] for op in genotype]
        sizes_match = [
            round(cur_size / op_stride) == next_size
            for op_stride, cur_size, next_size in zip(
                strides, spatial_sizes[:-1], spatial_sizes[1:])
        ]
        assert all(sizes_match)
    assert total_ops >= len(primitives)
def __init__(self, search_space, model_records, cfg_template, next_index=None):
    """Create a population from existing model records.

    Args:
        search_space: Search space used to parse each record's genotype.
        model_records: Mapping from index to model record.
        cfg_template: Stored on the instance as ``cfg_template``.
        next_index: Index for the next added model. When ``None`` it is one
            past the largest existing index, or 0 for an empty population.
    """
    super(Population, self).__init__(schedule_cfg=None)
    self.search_space = search_space
    self._model_records = model_records

    parsed_genotypes = collections.OrderedDict()
    for ind, record in six.iteritems(self._model_records):
        parsed_genotypes[ind] = genotype_from_str(record.genotype, self.search_space)
    self.genotype_records = parsed_genotypes

    # _size will be adjusted along with self._model_records
    self._size = len(model_records)
    self.cfg_template = cfg_template

    if next_index is not None:
        self._next_index = next_index
    elif model_records:
        self._next_index = np.max(list(model_records.keys())) + 1
    else:
        self._next_index = 0
    self.start_save_index = self._next_index
def __init__(
        self,
        search_space,  # layer2
        device,
        genotypes,  # layer2
        micro_model_type="micro-dense-model",
        micro_model_cfg={},
        num_classes=10,
        init_channels=36,
        stem_multiplier=1,
        dropout_rate=0.0,
        dropout_path_rate=0.0,
        use_stem="conv_bn_3x3",
        stem_stride=1,
        stem_affine=True,
        auxiliary_head=False,
        auxiliary_cfg=None,
        schedule_cfg=None,
):
    """Build a stage-wise macro network whose layers are micro cells.

    ``genotypes`` is parsed against ``search_space`` into a macro genotype
    (connectivity and per-layer width ratios) and a micro genotype (cell
    architectures). One cell is built per macro layer.

    Args:
        search_space: Layer2 search space exposing ``macro_search_space`` and
            ``micro_search_space``.
        device: Device the assembled model is moved to.
        genotypes: Genotype string (asserted to be ``str``).
        micro_model_type: Name used to look up the cell class through
            ``FinalModel.get_class_``.
        micro_model_cfg: Extra keyword arguments for every cell. It is copied
            before use and ``postprocess`` is forced to ``True`` in the copy;
            the caller's dict and the shared default are never modified.
        num_classes: Output size of the classifier.
        init_channels: Channel count of the first layer before width scaling.
        stem_multiplier: Must be 1 (asserted).
        dropout_rate: Dropout probability before the classifier; a falsy or
            non-positive value installs an identity op.
        dropout_path_rate: Stored on the instance.
        use_stem: Op name of the stem. Falsy values and lists/tuples raise
            ``NotImplementedError``.
        stem_stride: Stride of the stem op.
        stem_affine: ``affine`` argument of the stem op.
        auxiliary_head: Falsy for no auxiliary head, ``"imagenet"`` for
            ``AuxiliaryHeadImageNet``, any other truthy value for
            ``AuxiliaryHead``.
        auxiliary_cfg: Extra keyword arguments for the auxiliary head.
        schedule_cfg: Forwarded to the parent constructor.

    Raises:
        NotImplementedError: If ``use_stem`` is falsy or a list/tuple.
    """
    super(MacroStagewiseFinalModel, self).__init__(schedule_cfg)
    self.macro_ss = search_space.macro_search_space
    self.micro_ss = search_space.micro_search_space
    self.device = device
    assert isinstance(genotypes, str)
    self.genotypes_str = genotypes
    self.macro_g, self.micro_g = genotype_from_str(genotypes, search_space)

    # micro model (cell) class
    micro_model_cls = FinalModel.get_class_(micro_model_type)  # cell type

    self.num_classes = num_classes
    self.init_channels = init_channels
    self.stem_multiplier = stem_multiplier
    self.stem_stride = stem_stride
    self.stem_affine = stem_affine
    self.use_stem = use_stem

    # training
    self.dropout_rate = dropout_rate
    self.dropout_path_rate = dropout_path_rate
    self.auxiliary_head = auxiliary_head

    self.overall_adj = self.macro_ss.parse_overall_adj(self.macro_g)
    self.layer_widths = [float(w) for w in self.macro_g.width.split(",")]

    # sort channels out
    assert self.stem_multiplier == 1, "Cannot handle stem_multiplier != 1 now"
    # Base channels double after every reduction layer ...
    self.input_channel_list = [self.init_channels]
    for i in range(1, self.macro_ss.num_layers):
        self.input_channel_list.append(
            self.input_channel_list[i - 1] * 2
            if self._is_reduce(i - 1) else self.input_channel_list[i - 1])
    # ... and are then scaled by the layer's width ratio; a reduction layer
    # uses the ratio of the layer before it.
    for i in range(self.macro_ss.num_layers):
        self.input_channel_list[i] = int(
            self.input_channel_list[i] * self.layer_widths[i]
            if not self._is_reduce(i) else
            self.input_channel_list[i] * self.layer_widths[i - 1])

    self.output_channel_list = self.input_channel_list[1:] + [
        self.input_channel_list[-1]
    ]

    # construct cells
    # Only a single named stem op is supported: no stem and multi-op stems
    # (list/tuple) are both rejected.
    if not self.use_stem or isinstance(self.use_stem, (list, tuple)):
        raise NotImplementedError
    self.stem = ops.get_op(self.use_stem)(3,
                                          self.input_channel_list[0],
                                          stride=stem_stride,
                                          affine=stem_affine)

    self.cells = nn.ModuleList()
    self.micro_arch_list = self.micro_ss.rollout_from_genotype(
        self.micro_g).arch

    # Work on a private copy so that neither the caller's dict nor the shared
    # default argument is modified.
    micro_model_cfg = dict(micro_model_cfg or {})
    # FIXME: Currently MacroStageWiseFinalModel does not support postprocess = False
    micro_model_cfg["postprocess"] = True

    for i_layer in range(self.macro_ss.num_layers):
        stride = 2 if self._is_reduce(i_layer) else 1
        cg_idx = self.macro_ss.cell_layout[i_layer]
        # construct micro cell
        cell = micro_model_cls(
            self.micro_ss,
            self.micro_arch_list[cg_idx],
            num_input_channels=self.input_channel_list[i_layer],
            num_out_channels=self.output_channel_list[i_layer],
            stride=stride,
            **micro_model_cfg)
        # assume non-reduce cell does not change channel number
        self.cells.append(cell)
        # add auxiliary head
        if i_layer == (
                2 * self.macro_ss.num_layers) // 3 and self.auxiliary_head:
            if auxiliary_head == "imagenet":
                self.auxiliary_net = AuxiliaryHeadImageNet(
                    self.output_channel_list[i_layer], num_classes,
                    **(auxiliary_cfg or {}))
            else:
                self.auxiliary_net = AuxiliaryHead(
                    self.output_channel_list[i_layer], num_classes,
                    **(auxiliary_cfg or {}))

    self.lastact = nn.Identity()
    self.global_pooling = nn.AdaptiveAvgPool2d(1)
    if self.dropout_rate and self.dropout_rate > 0:
        self.dropout = nn.Dropout(p=self.dropout_rate)
    else:
        self.dropout = ops.Identity()
    self.classifier = nn.Linear(self.output_channel_list[-1],
                                self.num_classes)
    self.to(self.device)

    # for flops calculation
    self.total_flops = 0
    self._flops_calculated = False
    self._set_hook()
"arch": [[4] * 6, [4] * 12, [4] * 24, [4] * 6], "stem": 8, "trans": [16, 32, 64] }, { "arch": [[4] * 6, [6] * 6, [4, 5, 6, 7, 8, 9, 10]], "stem": 11, "trans": [17, 26] }]) def test_dense_rollout(case): from aw_nas.common import get_search_space, genotype_from_str ss = get_search_space("cnn_dense", num_dense_blocks=len(case["arch"])) genotype = ss.genotype(case["arch"]) assert genotype.stem == case["stem"] for i, trans_c in enumerate(case["trans"]): assert getattr(genotype, "transition_{}".format(i)) == trans_c assert genotype_from_str(str(genotype), ss) == genotype print(genotype) print("relative conv flops: ", ss.relative_conv_flops(case["arch"])) def test_dense_dynamic_rollout(): from aw_nas.common import get_search_space, genotype_from_str ss = get_search_space("cnn_dense", dynamic_transition=True, reduction=None, transition_channels=None, num_dense_blocks=4) arch = ([[4] * 6, [4] * 12, [4] * 24, [4] * 6], [18, 20, 30]) genotype = ss.genotype(arch) for i, trans_c in enumerate(arch[1]): assert getattr(genotype, "transition_{}".format(i)) == trans_c
def genotype(self):
    """Return the genotype parsed from the stored ``self._genotype`` string."""
    genotype_str = self._genotype
    return genotype_from_str(genotype_str, self.search_space)
def __init__(
        self,
        search_space,  # layer2
        device,
        genotypes,  # layer2
        micro_model_type="micro-dense-model",
        micro_model_cfg={},
        num_classes=10,
        init_channels=36,
        stem_multiplier=1,
        dropout_rate=0.0,
        dropout_path_rate=0.0,
        use_stem="conv_bn_3x3",
        stem_stride=1,
        stem_affine=True,
        auxiliary_head=False,
        auxiliary_cfg=None,
        schedule_cfg=None,
):
    """Build a macro network in which only connected layers get a cell.

    ``genotypes`` is parsed against ``search_space`` into a macro genotype
    (connectivity and per-layer width ratios) and a micro genotype (cell
    architectures). Layers that do not appear in ``self.connected_cells``
    (derived from the macro adjacency) are skipped.

    Args:
        search_space: Layer2 search space exposing ``macro_search_space`` and
            ``micro_search_space``.
        device: Device the assembled model is moved to.
        genotypes: Genotype string (asserted to be ``str``).
        micro_model_type: Name used to look up the cell class through
            ``FinalModel.get_class_``.
        micro_model_cfg: Extra keyword arguments for every cell; its
            ``postprocess`` entry (default ``False``) selects the channel
            bookkeeping below. Stored by reference as ``self.micro_model_cfg``.
        num_classes: Output size of the classifier.
        init_channels: Channel count of the first layer before width scaling.
        stem_multiplier: Must be 1 (asserted).
        dropout_rate: Dropout probability before the classifier; a falsy or
            non-positive value installs an identity op.
        dropout_path_rate: Stored on the instance.
        use_stem: Op name of the stem. Falsy values and lists/tuples raise
            ``NotImplementedError``.
        stem_stride: Stride of the stem op.
        stem_affine: ``affine`` argument of the stem op.
        auxiliary_head: Falsy for no auxiliary head, ``"imagenet"`` for
            ``AuxiliaryHeadImageNet``, any other truthy value for
            ``AuxiliaryHead``.
        auxiliary_cfg: Extra keyword arguments for the auxiliary head.
        schedule_cfg: Forwarded to the parent constructor.
    """
    super(MacroStagewiseFinalModel, self).__init__(schedule_cfg)
    self.macro_ss = search_space.macro_search_space
    self.micro_ss = search_space.micro_search_space
    self.device = device
    assert isinstance(genotypes, str)
    self.genotypes_str = genotypes
    self.macro_g, self.micro_g = genotype_from_str(genotypes, search_space)

    # micro model (cell) class
    micro_model_cls = FinalModel.get_class_(micro_model_type)  # cell type

    self.num_classes = num_classes
    self.init_channels = init_channels
    self.stem_multiplier = stem_multiplier
    self.stem_stride = stem_stride
    self.stem_affine = stem_affine
    self.use_stem = use_stem

    # training
    self.dropout_rate = dropout_rate
    self.dropout_path_rate = dropout_path_rate
    self.auxiliary_head = auxiliary_head

    self.overall_adj = self.macro_ss.parse_overall_adj(self.macro_g)
    self.layer_widths = [float(w) for w in self.macro_g.width.split(",")]

    self.micro_model_cfg = micro_model_cfg
    if "postprocess" in self.micro_model_cfg.keys():
        self.cell_use_postprocess = self.micro_model_cfg["postprocess"]
    else:
        self.cell_use_postprocess = False

    # sort channels out
    assert self.stem_multiplier == 1, "Cannot handle stem_multiplier != 1 now"
    # Base channels double after every reduction layer ...
    self.input_channel_list = [self.init_channels]
    for i in range(1, self.macro_ss.num_layers):
        self.input_channel_list.append(
            self.input_channel_list[i - 1] * 2
            if self._is_reduce(i - 1) else self.input_channel_list[i - 1])
    # ... and are then scaled by the layer's width ratio; a reduction layer
    # uses the ratio of the layer before it.
    for i in range(self.macro_ss.num_layers):
        self.input_channel_list[i] = int(
            self.input_channel_list[i] * self.layer_widths[i]
            if not self._is_reduce(i) else
            self.input_channel_list[i] * self.layer_widths[i - 1])

    self.output_channel_list = self.input_channel_list[1:] + [
        self.input_channel_list[-1]
    ]

    # construct cells
    # NOTE(review): the statements following each `raise` below are
    # unreachable; only the final `else` branch ever builds a stem.
    if not self.use_stem:
        raise NotImplementedError
        c_stem = 3
    elif isinstance(self.use_stem, (list, tuple)):
        raise NotImplementedError
        self.stems = []
        c_stem = self.stem_multiplier * self.init_channels
        for i, stem_type in enumerate(self.use_stem):
            c_in = 3 if i == 0 else c_stem
            self.stems.append(
                ops.get_op(stem_type)(c_in,
                                      c_stem,
                                      stride=stem_stride,
                                      affine=stem_affine))
        self.stem = nn.Sequential(self.stems)
    else:
        self.stem = ops.get_op(self.use_stem)(3,
                                              self.input_channel_list[0],
                                              stride=stem_stride,
                                              affine=stem_affine)

    # 1x1 conv that widens the stem output by a factor of micro_ss.num_steps.
    self.extra_stem = ops.get_op("nor_conv_1x1")(
        self.input_channel_list[0],
        self.input_channel_list[0] * self.micro_ss.num_steps,
        stride=1,
        affine=True,
    )

    # For sink-connect, don't init all cells, just init connected cells
    connected_cells = []
    for cell_idx in range(1, self.macro_ss.num_layers + 2):
        if len(self.overall_adj[cell_idx].nonzero()[0]) > 0:
            connected_cells.append(self.overall_adj[cell_idx].nonzero()[0])
    # -1 to make the 1st element 0
    self.connected_cells = np.concatenate(connected_cells)[1:] - 1

    """ ininitialize cells, only connected cells are initialized also use `use_next_stage_width` to handle the disalignment of width due to width search """
    self.cells = nn.ModuleList()
    self.micro_arch_list = self.micro_ss.rollout_from_genotype(
        self.micro_g).arch
    for i_layer in range(self.macro_ss.num_layers):
        stride = 2 if self._is_reduce(i_layer) else 1
        # NOTE(review): the index arrays and widths below do not depend on
        # i_layer (apart from use_next_stage_width) yet are recomputed on
        # every iteration.
        connected_is_reduce = [
            self._is_reduce(i) for i in self.connected_cells
        ]
        # the layer-idx to use next stage's width: the last cell before the
        # reduction cell in each stage
        use_next_stage_width_layer_idx = self.connected_cells[
            np.argwhere(np.array(connected_is_reduce)).reshape(-1) - 1]
        reduction_layer_idx = self.connected_cells[np.argwhere(
            np.array(connected_is_reduce)
        ).reshape(
            -1)]  # positions of the reduction cells among the connected cells
        # NOTE(review): both branches compute the same value (with `// 2`),
        # although the first comment says "no //2" — confirm which is intended.
        if not self.cell_use_postprocess:
            next_stage_widths = (np.array(
                self.output_channel_list)[self.macro_ss.stages_begin[1:]] // 2)  # preprocess, so no //2
        else:
            next_stage_widths = (
                np.array(self.output_channel_list)[
                    self.macro_ss.stages_begin[1:]] // 2
            )  # the width to use for `use_next_stage_width`, the reduction cell is of expansion 2, so //2
        use_next_stage_width = (
            next_stage_widths[np.argwhere(
                use_next_stage_width_layer_idx == i_layer).reshape(-1)]
            if np.argwhere(use_next_stage_width_layer_idx == i_layer).size > 0
            else None)
        input_channel_list_n = np.array(self.input_channel_list)
        input_channel_list_n[
            reduction_layer_idx] = next_stage_widths  # input of the reduction should be half of the next stage's width
        cg_idx = self.macro_ss.cell_layout[i_layer]
        # Unconnected layers get no cell at all.
        if i_layer not in self.connected_cells:
            continue
        # construct micro cell
        cell = micro_model_cls(
            self.micro_ss,
            self.micro_arch_list[cg_idx],
            num_input_channels=int(
                input_channel_list_n[i_layer]
            ),  # TODO: input_channel_list is of type: np.int64
            num_out_channels=self.output_channel_list[i_layer],
            stride=stride,
            use_next_stage_width=use_next_stage_width,
            is_last_cell=True
            if i_layer == self.connected_cells[-1] else False,
            is_first_cell=True
            if i_layer == self.connected_cells[0] else False,
            skip_cell=False,
            **micro_model_cfg)
        # assume non-reduce cell does not change channel number
        self.cells.append(cell)

    # add auxiliary head
    # connected_cells has 1 more element [0] than the self.cells
    # NOTE(review): relies on input_channel_list_n from the last loop
    # iteration, so this presumably requires num_layers > 0 — confirm.
    if self.auxiliary_head:
        self.where_aux_head = self.connected_cells[(2 * len(self.cells)) // 3]
        extra_expansion_for_aux = (
            1 if self.cell_use_postprocess else self.micro_ss.num_steps
        )  # if use preprocess, aux head's input ch num should change accordingly
        # aux head is connected to last cell's output
        if auxiliary_head == "imagenet":
            self.auxiliary_net = AuxiliaryHeadImageNet(
                input_channel_list_n[self.where_aux_head] *
                extra_expansion_for_aux, num_classes,
                **(auxiliary_cfg or {}))
        else:
            self.auxiliary_net = AuxiliaryHead(
                input_channel_list_n[self.where_aux_head] *
                extra_expansion_for_aux, num_classes,
                **(auxiliary_cfg or {}))

    self.lastact = nn.Identity()
    self.global_pooling = nn.AdaptiveAvgPool2d(1)
    if self.dropout_rate and self.dropout_rate > 0:
        self.dropout = nn.Dropout(p=self.dropout_rate)
    else:
        self.dropout = ops.Identity()
    # Without postprocess the classifier input is num_steps times wider.
    if not self.cell_use_postprocess:
        self.classifier = nn.Linear(
            self.output_channel_list[-1] * self.micro_ss.num_steps,
            self.num_classes)
    else:
        self.classifier = nn.Linear(self.output_channel_list[-1],
                                    self.num_classes)
    self.to(self.device)

    # for flops calculation
    self.total_flops = 0
    self._flops_calculated = False
    self._set_hook()
def __setstate__(self, state):
    """Restore the population and rebuild ``genotype_records``.

    After the parent restores its state, every record in
    ``self._model_records`` has its genotype parsed against
    ``self.search_space`` again, keeping the records' order.
    """
    super(Population, self).__setstate__(state)
    rebuilt = collections.OrderedDict()
    for ind, record in six.iteritems(self._model_records):
        rebuilt[ind] = genotype_from_str(record.genotype, self.search_space)
    self.genotype_records = rebuilt
def __init__(self, search_space, device, genotypes,
             num_classes=10, init_channels=36, layer_channels=tuple(),
             stem_multiplier=3, dropout_rate=0.1, dropout_path_rate=0.2,
             auxiliary_head=False, auxiliary_cfg=None,
             use_stem="conv_bn_3x3", stem_stride=1, stem_affine=True,
             no_fc=False, cell_use_preprocess=True,
             cell_pool_batchnorm=False, cell_group_kwargs=None,
             cell_independent_conn=False,
             cell_preprocess_stride="skip_connect",
             cell_preprocess_normal="relu_conv_bn_1x1",
             schedule_cfg=None):
    """Build the stem, the stack of cells, and the head from a genotype string.

    Args:
        search_space: Search space used to parse ``genotypes``; also read
            for ``num_cell_groups``, ``num_init_nodes``, ``cell_layout``,
            ``reduce_cell_groups`` and ``num_layers``.
        device: Device the finished module is moved to.
        genotypes: Genotype description; must be a ``str``.
        num_classes: Output size of the classifier and auxiliary head.
        init_channels: Channel number of the first layer when
            ``layer_channels`` is empty.
        layer_channels: Optional explicit channel list with
            ``num_layers + 1`` entries; layer ``i`` uses entry ``i`` as its
            channel number and entry ``i + 1`` as its output channel number.
        stem_multiplier: The stem outputs
            ``stem_multiplier * init_channels`` channels.
        dropout_rate: A truthy, positive value enables ``nn.Dropout``.
        dropout_path_rate: Stored on the instance only.
        auxiliary_head: Falsy disables the auxiliary head; ``"imagenet"``
            selects ``AuxiliaryHeadImageNet``, any other truthy value
            selects ``AuxiliaryHead``.
        auxiliary_cfg: Extra keyword arguments for the auxiliary head.
        use_stem: Falsy for no stem, a list/tuple of op names for stacked
            stems, otherwise a single op name.
        stem_stride: Stride passed to every stem op.
        stem_affine: ``affine`` flag passed to every stem op.
        no_fc: When true the classifier is ``ops.Identity``.
        cell_use_preprocess, cell_pool_batchnorm, cell_independent_conn,
        cell_preprocess_stride, cell_preprocess_normal: Forwarded to every
            ``CNNGenotypeCell``.
        cell_group_kwargs: Optional per-cell-group extra keyword arguments,
            indexed by cell group id.  ``"C_in"`` / ``"C_out"`` entries
            override the channel numbers instead of being forwarded.
        schedule_cfg: Passed to the parent constructor.
    """
    super(CNNGenotypeModel, self).__init__(schedule_cfg)

    self.search_space = search_space
    self.device = device
    assert isinstance(genotypes, str)

    # Parse the genotype string and flatten its fields into a list.  The
    # first `num_cell_groups` entries are regrouped by destination node and
    # paired with the entries that follow them.
    parsed_genotype = genotype_from_str(genotypes, self.search_space)
    self.genotypes = list(parsed_genotype._asdict().values())
    grouped_conns = [
        group_and_sort_by_to_node(conns)
        for conns in self.genotypes[:self.search_space.num_cell_groups]
    ]
    remaining = self.genotypes[self.search_space.num_cell_groups:]
    self.genotypes_grouped = list(zip(grouped_conns, remaining))

    self.num_classes = num_classes
    self.init_channels = init_channels
    self.layer_channels = layer_channels
    self.stem_multiplier = stem_multiplier
    self.use_stem = use_stem
    self.cell_use_preprocess = cell_use_preprocess
    self.cell_group_kwargs = cell_group_kwargs
    self.cell_independent_conn = cell_independent_conn
    self.no_fc = no_fc

    # training related settings
    self.dropout_rate = dropout_rate
    self.dropout_path_rate = dropout_path_rate
    self.auxiliary_head = auxiliary_head

    # cached search space configuration
    self._num_init = self.search_space.num_init_nodes
    self._cell_layout = self.search_space.cell_layout
    self._reduce_cgs = self.search_space.reduce_cell_groups
    self._num_layers = self.search_space.num_layers

    group_mismatch_msg = ("Config genotype cell group number({}) "
                          "does not match search_space cell group number({})")\
        .format(len(self.genotypes_grouped), self.search_space.num_cell_groups)
    expect(len(self.genotypes_grouped) == self.search_space.num_cell_groups,
           group_mismatch_msg)

    # ---- stem ----
    if not self.use_stem:
        # no stem: the cells consume the 3-channel input directly
        stem_channels = 3
        init_strides = [1] * self._num_init
    else:
        stem_channels = self.stem_multiplier * self.init_channels
        if isinstance(self.use_stem, (list, tuple)):
            # only the first stem of the stack sees the 3-channel input
            stem_ops = [
                ops.get_op(stem_type)(3 if idx == 0 else stem_channels,
                                      stem_channels,
                                      stride=stem_stride,
                                      affine=stem_affine)
                for idx, stem_type in enumerate(self.use_stem)
            ]
            self.stems = nn.ModuleList(stem_ops)
            init_strides = [stem_stride] * self._num_init
        else:
            self.stem = ops.get_op(self.use_stem)(3, stem_channels,
                                                  stride=stem_stride,
                                                  affine=stem_affine)
            init_strides = [1] * self._num_init

    # ---- cells ----
    self.cells = nn.ModuleList()
    cur_channels = self.init_channels
    prev_num_channels = [stem_channels] * self._num_init
    strides = [
        2 if self._is_reduce(layer_idx) else 1
        for layer_idx in range(self._num_layers)
    ]

    if self.layer_channels:
        channel_mismatch_msg = \
            "Config cell channels({}) does not match search_space num layers + 1 ({})"\
            .format(len(self.layer_channels), self.search_space.num_layers + 1)
        expect(len(self.layer_channels) == len(strides) + 1,
               channel_mismatch_msg, ConfigException)

    # the auxiliary head (if any) is attached right after this layer
    aux_layer_idx = (2 * self._num_layers) // 3

    for layer_idx, stride in enumerate(strides):
        if self.layer_channels:
            # channel numbers of every layer are given explicitly
            cur_channels = self.layer_channels[layer_idx]
            out_channels = self.layer_channels[layer_idx + 1]
        else:
            # default rule: scale the channel number by the stride
            if stride > 1:
                cur_channels *= stride
            out_channels = cur_channels

        if cell_group_kwargs is None:
            extra_kwargs = {}
        else:
            # Copy the per-group kwargs so that popping the channel
            # overrides does not modify the caller's configuration.
            extra_kwargs = dict(
                cell_group_kwargs[self._cell_layout[layer_idx]].items())
            cur_channels = extra_kwargs.pop("C_in", cur_channels)
            out_channels = extra_kwargs.pop("C_out", out_channels)

        group_idx = self.search_space.cell_layout[layer_idx]
        cell = CNNGenotypeCell(
            self.search_space,
            self.genotypes_grouped[group_idx],
            layer_index=layer_idx,
            num_channels=cur_channels,
            num_out_channels=out_channels,
            prev_num_channels=tuple(prev_num_channels),
            stride=stride,
            prev_strides=init_strides + strides[:layer_idx],
            use_preprocess=cell_use_preprocess,
            pool_batchnorm=cell_pool_batchnorm,
            independent_conn=cell_independent_conn,
            preprocess_stride=cell_preprocess_stride,
            preprocess_normal=cell_preprocess_normal,
            **extra_kwargs)

        # TODO: support specify concat explicitly
        # slide the window of previous channel numbers by one cell
        prev_num_channels = prev_num_channels[1:] + [cell.num_out_channel()]
        self.cells.append(cell)

        if layer_idx == aux_layer_idx and self.auxiliary_head:
            aux_head_cls = (AuxiliaryHeadImageNet
                            if auxiliary_head == "imagenet" else AuxiliaryHead)
            self.auxiliary_net = aux_head_cls(prev_num_channels[-1],
                                              num_classes,
                                              **(auxiliary_cfg or {}))

    # ---- head ----
    self.global_pooling = nn.AdaptiveAvgPool2d(1)
    use_dropout = self.dropout_rate and self.dropout_rate > 0
    self.dropout = nn.Dropout(p=self.dropout_rate) if use_dropout else ops.Identity()
    if self.no_fc:
        self.classifier = ops.Identity()
    else:
        self.classifier = nn.Linear(prev_num_channels[-1], self.num_classes)

    self.to(self.device)

    # bookkeeping for flops calculation
    self.total_flops = 0
    self._flops_calculated = False
    self.set_hook()
def __setstate__(self, state):
    """Restore the instance from ``state`` and re-parse the genotypes.

    After the parent class has restored the state, ``self.genotypes_str``
    is parsed with ``genotype_from_str`` against ``self.search_space`` and
    the resulting pair is stored as ``self.macro_g`` / ``self.micro_g``.

    Args:
        state: The state object handed over by pickle / copy.
    """
    # NOTE: this hook used to be spelled ``__setstate`` (no trailing double
    # underscore).  Inside a class body such a name is private-name-mangled,
    # so pickle/copy never called it, and the equally mangled
    # ``super().__setstate`` lookup could not resolve on the parent class.
    # NOTE(review): assumes the parent class provides ``__setstate__`` --
    # confirm against the base class.
    super(MacroStagewiseFinalModel, self).__setstate__(state)
    self.macro_g, self.micro_g = genotype_from_str(self.genotypes_str,
                                                   self.search_space)
def __init__(
        self,
        search_space,
        device,
        genotypes,
        num_classes=10,
        init_channels=36,
        stem_multiplier=1,
        dropout_rate=0.0,
        dropout_path_rate=0.0,
        use_stem="conv_bn_3x3",
        stem_stride=1,
        stem_affine=True,
        schedule_cfg=None,
):
    """Build the stem, the stack of ``DenseRobCell`` and the classifier.

    Args:
        search_space: Search space used to parse ``genotypes`` and to turn
            the genotype into a rollout; also read for ``num_init_nodes``,
            ``num_layers`` and ``cell_layout``.
        device: Device the finished module is moved to.
        genotypes: Genotype description; must be a ``str``.
        num_classes: Output size of the final linear classifier.
        init_channels: Channel number of the first layer.
        stem_multiplier: The stem outputs
            ``stem_multiplier * init_channels`` channels.
        dropout_rate: A truthy, positive value enables ``nn.Dropout``.
        dropout_path_rate: Stored on the instance only.
        use_stem: Falsy for no stem, a list/tuple of op names for stacked
            stems, otherwise a single op name.
        stem_stride: Stride passed to every stem op.
        stem_affine: ``affine`` flag passed to every stem op.
        schedule_cfg: Passed to the parent constructor.
    """
    super(DenseRobFinalModel, self).__init__(schedule_cfg)

    self.search_space = search_space
    self.device = device
    assert isinstance(genotypes, str)

    # string -> genotype -> rollout; only the rollout's `arch` is kept
    parsed_genotype = genotype_from_str(genotypes, self.search_space)
    rollout = self.search_space.rollout_from_genotype(parsed_genotype)
    self.arch_list = rollout.arch

    self.num_classes = num_classes
    self.init_channels = init_channels
    self.stem_multiplier = stem_multiplier
    self.use_stem = use_stem

    # training related settings
    self.dropout_rate = dropout_rate
    self.dropout_path_rate = dropout_path_rate

    # cached search space configuration
    self._num_init = self.search_space.num_init_nodes
    self._num_layers = self.search_space.num_layers

    # ---- stem ----
    if not self.use_stem:
        # no stem: the cells consume the 3-channel input directly
        stem_channels = 3
        init_strides = [1] * self._num_init
    else:
        stem_channels = self.stem_multiplier * self.init_channels
        if isinstance(self.use_stem, (list, tuple)):
            # only the first stem of the stack sees the 3-channel input
            stem_ops = [
                ops.get_op(stem_type)(3 if idx == 0 else stem_channels,
                                      stem_channels,
                                      stride=stem_stride,
                                      affine=stem_affine)
                for idx, stem_type in enumerate(self.use_stem)
            ]
            self.stems = nn.ModuleList(stem_ops)
            init_strides = [stem_stride] * self._num_init
        else:
            self.stem = ops.get_op(self.use_stem)(3, stem_channels,
                                                  stride=stem_stride,
                                                  affine=stem_affine)
            init_strides = [1] * self._num_init

    # ---- cells ----
    self.cells = nn.ModuleList()
    cur_channels = self.init_channels
    prev_num_channels = [stem_channels] * self._num_init
    strides = [
        2 if self._is_reduce(layer_idx) else 1
        for layer_idx in range(self._num_layers)
    ]

    for layer_idx, stride in enumerate(strides):
        # scale the channel number by the stride of the layer
        if stride > 1:
            cur_channels *= stride

        group_idx = self.search_space.cell_layout[layer_idx]
        cell = DenseRobCell(
            self.search_space,
            self.arch_list[group_idx],
            num_input_channels=prev_num_channels,
            num_out_channels=cur_channels,
            prev_strides=init_strides + strides[:layer_idx],
            stride=stride)

        # Slide the window of previous channel numbers by one cell.  The
        # list handed to the cell above is appended to in place before the
        # name is rebound to a fresh slice; this order is kept on purpose.
        prev_num_channels.append(cell.num_out_channel())
        prev_num_channels = prev_num_channels[1:]
        self.cells.append(cell)

    # ---- head ----
    self.lastact = nn.Identity()
    self.global_pooling = nn.AdaptiveAvgPool2d(1)
    use_dropout = self.dropout_rate and self.dropout_rate > 0
    self.dropout = nn.Dropout(p=self.dropout_rate) if use_dropout else ops.Identity()
    self.classifier = nn.Linear(prev_num_channels[-1], self.num_classes)

    self.to(self.device)

    # bookkeeping for flops calculation
    self.total_flops = 0
    self._flops_calculated = False
    self.set_hook()
def rollout_from_genotype(self, genotype):
    """Wrap ``genotype`` into a ``GeneralRollout`` bound to this search space.

    Args:
        genotype: Either a genotype object or its ``str`` form; a string
            is first parsed with ``genotype_from_str`` against ``self``.

    Returns:
        ``GeneralRollout(genotype, {}, self)`` built from the (possibly
        parsed) genotype.
    """
    parsed = (genotype_from_str(genotype, self)
              if isinstance(genotype, str) else genotype)
    return GeneralRollout(parsed, {}, self)