class AdelaideRandom(SearchAlgorithm):
    """Search algorithm of the random structures."""

    def __init__(self, search_space=None):
        """Construct the AdelaideRandom class.

        :param search_space: Config of the search space
        """
        super(AdelaideRandom, self).__init__(search_space)
        self.search_space = search_space
        self.codec = Codec(self.cfg.codec, search_space)
        self.max_sample = self.cfg.max_sample
        self.sample_count = 0

    @property
    def is_completed(self):
        """Tell whether the search process is completed.

        :return: True if completed, or False otherwise
        """
        return self.sample_count >= self.max_sample

    def search(self):
        """Search one random model.

        :return: current number of samples, and the model
        """
        search_desc = self.search_space.search_space.custom
        num_ops = len(search_desc.op_names)
        ops = [random.randrange(num_ops) for _ in range(7)]
        inputs = list()
        for inputs_index in range(3):
            for i in range(2):
                inputs.append(random.randint(0, (inputs_index + 2) * 3 - 5))
        conns = list()
        for conns_index in range(3):
            for i in range(2):
                conns.append(random.randint(0, conns_index + 3))
        decoder_cell_str = list()
        decoder_cell_str.append(ops[0])
        decoder_cell_str.append([inputs[0], inputs[1], ops[1], ops[2]])
        decoder_cell_str.append([inputs[2], inputs[3], ops[3], ops[4]])
        decoder_cell_str.append([inputs[4], inputs[5], ops[5], ops[6]])
        decoder_conn_str = [[conns[0], conns[1]], [conns[2], conns[3]], [conns[4], conns[5]]]
        decoder_arch_str = [decoder_cell_str, decoder_conn_str]
        search_desc['config'] = decoder_arch_str
        search_desc['method'] = "random"
        search_desc = self.codec.encode(search_desc)
        self.sample_count += 1
        return self.sample_count, NetworkDesc(self.search_space.search_space)

    def update(self, local_worker_path):
        """Update function.

        :param local_worker_path: Local path that saved `performance.txt`
        :type local_worker_path: str
        """
        pass
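# Illustrative sketch (not part of the original source): how a generation loop
# might drive AdelaideRandom. `build_search_space` and `train_and_evaluate` are
# hypothetical callables standing in for the surrounding pipeline; only
# `is_completed`, `search()` and `update()` come from the class above.
def _example_adelaide_random_loop(build_search_space, train_and_evaluate):
    """Draw random decoder architectures until max_sample is reached."""
    algorithm = AdelaideRandom(build_search_space())
    while not algorithm.is_completed:
        sample_id, net_desc = algorithm.search()    # one random decoder architecture
        worker_path = train_and_evaluate(net_desc)  # hypothetical training/evaluation step
        algorithm.update(worker_path)               # no-op for pure random search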
class SRRandom(SearchAlgorithm):
    """Search algorithm of the random structures."""

    def __init__(self, search_space=None):
        """Construct the SRRandom class.

        :param search_space: Config of the search space
        """
        super(SRRandom, self).__init__(search_space)
        self.search_space = search_space
        self.codec = Codec(self.cfg.codec, search_space)
        self.max_sample = self.policy.num_sample
        self.sample_count = 0

    @property
    def is_completed(self):
        """Tell whether the search process is completed.

        :return: True if completed, or False otherwise
        """
        return self.sample_count >= self.max_sample

    def search(self):
        """Search one random model.

        :return: current number of samples, and the model
        """
        search_desc = self.search_space.search_space.custom
        num_blocks = random.randint(*search_desc.block_range)
        num_cibs = random.randint(*search_desc.cib_range)
        candidates = search_desc.candidates
        blocks = [random.choice(candidates) for _ in range(num_blocks)]
        for _ in range(num_cibs):
            cib = [random.choice(candidates) for _ in range(2)]
            blocks.insert(random.randint(0, len(blocks)), cib)
        search_desc['blocks'] = blocks
        search_desc['method'] = "random"
        search_desc = self.codec.encode(search_desc)
        self.sample_count += 1
        return self.sample_count, NetworkDesc(self.search_space.search_space)

    def update(self, local_worker_path):
        """Update function.

        :param local_worker_path: Local path that saved `performance.txt`
        :type local_worker_path: str
        """
        pass
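# Illustrative sketch (not part of the original source): what the `blocks` list
# built by SRRandom.search() can look like. The candidate names and ranges below
# are assumptions for the example only; what is grounded in the class above is
# that plain candidate strings are ordinary blocks and that each CIB is a nested
# two-element list inserted at a random position.
def _example_sr_random_blocks():
    """Build a block list the same way SRRandom.search() does."""
    import random
    candidates = ['res2', 'res3']              # assumed candidate block types
    block_range, cib_range = (2, 4), (1, 2)    # assumed search ranges
    num_blocks = random.randint(*block_range)
    num_cibs = random.randint(*cib_range)
    blocks = [random.choice(candidates) for _ in range(num_blocks)]
    for _ in range(num_cibs):
        cib = [random.choice(candidates) for _ in range(2)]
        blocks.insert(random.randint(0, len(blocks)), cib)
    # a result mixes plain candidate names with nested CIB pairs
    return blocks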
class SpNas(SearchAlgorithm):
    """Search Algorithm Stage of SPNAS."""

    def __init__(self, search_space=None):
        """Construct the SpNas class.

        :param search_space: Config of the search space
        """
        super(SpNas, self).__init__(search_space)
        self.search_space = search_space
        self.codec = Codec(self.cfg.codec, search_space)
        self.sample_level = self.cfg.sample_level
        self.max_sample = self.cfg.max_sample
        self.max_optimal = self.cfg.max_optimal
        self._total_list_name = self.cfg.total_list
        self.serial_settings = self.cfg.serial_settings
        self._total_list = ListDict()
        self.sample_count = 0
        self.init_code = None
        remote_output_path = FileOps.join_path(self.local_output_path, self.cfg.step_name)
        if 'last_search_result' in self.cfg:
            last_search_file = self.cfg.last_search_result
            assert FileOps.exists(os.path.join(remote_output_path, last_search_file)), \
                "Serial results not found!"
            # self.download_task_folder()
            last_search_results = os.path.join(self.local_output_path, last_search_file)
            last_search_results = ListDict.load_csv(last_search_results)
            pre_worker_id, pre_arch = self.select_from_remote(self.max_optimal, last_search_results)
            # re-write config template
            if self.cfg.regnition:
                self.codec.config_template['model']['backbone']['reignition'] = True
                assert FileOps.exists(os.path.join(remote_output_path, pre_arch + '_imagenet.pth')), \
                    "Pretrained .pth file of {} not found!".format(pre_arch)
                pretrained_pth = os.path.join(self.local_output_path, pre_arch + '_imagenet.pth')
                self.codec.config_template['model']['pretrained'] = pretrained_pth
                pre_worker_id = -1
            # update config template
            self.init_code = dict(arch=pre_arch,
                                  pre_arch=pre_arch.split('_')[1],
                                  pre_worker_id=pre_worker_id)
        logging.info("initialized SpNas {}-level search...".format(self.sample_level))

    @property
    def is_completed(self):
        """Check whether sampling is finished.

        :return: True if completed, or False otherwise
        :rtype: bool
        """
        return self.sample_count > self.max_sample

    @property
    def num_samples(self):
        """Get the number of sampled architectures.

        :return: The number of sampled architectures
        :rtype: int
        """
        return len(self._total_list)

    def select_from_remote(self, max_optimal, total_list):
        """Select a base model to mutate.

        :return: worker id and arch encoding
        :rtype: int, str
        """
        def normalization(x):
            sum_ = sum(x)
            return [float(i) / float(sum_) for i in x]

        # rank with mAP and memory
        top_ = total_list.sort('mAP')[:max_optimal]
        if max_optimal > 1:
            prob = [round(np.log(i + 1e-2), 2) for i in range(1, len(top_) + 1)]
            # copy so the memory-based bonus is added to the original mAP-rank scores
            prob_temp = prob.copy()
            sorted_ind = sorted(range(len(top_)), key=lambda k: top_['memory'][k], reverse=True)
            for idx, ind in enumerate(sorted_ind):
                prob[ind] += prob_temp[idx]
            ind = np.random.choice(len(top_), p=normalization(prob))
            worker_id, arch = top_['worker_id', 'arch'][ind]
        else:
            worker_id, arch = top_['worker_id', 'arch'][0]
        return worker_id, arch

    def search(self):
        """Search a sample.

        :return: sample count and info
        :rtype: int, dict
        """
        code = self.init_code
        if self.num_samples > 0:
            pre_worker_id, pre_arch = self.select_from_remote(self.max_optimal, self._total_list)
            block_type, pre_serial, pre_paral = pre_arch.split('_')
            success = False
            while not success:
                serialnet, parallelnet = pre_serial, pre_paral
                if self.sample_level == 'serial':
                    serialnet = self._mutate_serialnet(serialnet, **self.serial_settings)
                    parallelnet = '-'.join(['0'] * len(serialnet.split('-')))
                elif self.sample_level == 'parallel':
                    parallelnet = self._mutate_parallelnet(parallelnet)
                    pre_worker_id = self.init_code['pre_worker_id']
                else:
                    raise ValueError("Unknown type of sample level")
                arch = self.codec.encode(block_type, serialnet, parallelnet)
                if arch not in self._total_list['arch'] or len(self._total_list['arch']) == 0:
                    success = True
            code = dict(arch=arch,
                        pre_arch=pre_serial,
                        pre_worker_id=pre_worker_id)
        self.sample_count += 1
        logging.info("The {}-th successful sampling result: {}".format(self.sample_count, code))
        net_desc = self.codec.decode(code)
        return self.sample_count, net_desc

    def update(self, worker_result_path):
        """Update the sampler with the latest worker result."""
        performance_file = self.performance_path(worker_result_path)
        logging.info(
            "SpNas.update(), performance file={}".format(performance_file))
        info = FileOps.load_pickle(performance_file)
        if info is not None:
            self._total_list.append(info)
        else:
            logging.info("SpNas.update(), performance file does not exist, "
                         "performance file={}".format(performance_file))
        self.save_output(self.local_output_path)
        if self.backup_base_path is not None:
            FileOps.copy_folder(self.local_output_path, self.backup_base_path)

    def performance_path(self, worker_result_path):
        """Get the path of the performance file."""
        performance_dir = os.path.join(worker_result_path, 'performance')
        if not os.path.exists(performance_dir):
            FileOps.make_dir(performance_dir)
        return os.path.join(performance_dir, 'performance.pkl')

    def save_output(self, local_output_path):
        """Save results."""
        local_totallist_path = os.path.join(local_output_path, self._total_list_name)
        self._total_list.to_csv(local_totallist_path)

    def _mutate_serialnet(self, arch, num_mutate=3, expend_ratio=0, addstage_ratio=0.95, max_stages=6):
        """Swap & Expend operations in Serial-level searching.

        :param arch: base arch encoding
        :type arch: str
        :param num_mutate: number of mutation steps
        :type num_mutate: int
        :param expend_ratio: threshold for the expend operation
        :type expend_ratio: float
        :param addstage_ratio: threshold for adding a new stage
        :type addstage_ratio: float
        :param max_stages: maximum number of stages allowed
        :type max_stages: int
        :return: arch encoding after mutation
        :rtype: str
        """
        def is_valid(arch):
            stages = arch.split('-')
            for stage in stages:
                if len(stage) == 0:
                    return False
            return True

        def expend(arc):
            idx = np.random.randint(low=1, high=len(arc))
            arc = arc[:idx] + '1' + arc[idx:]
            return arc, idx

        def swap(arc, len_step=3):
            is_not_valid = True
            arc_origin = copy.deepcopy(arc)
            temp = arc.split('-')
            num_insert = len(temp) - 1
            while is_not_valid or arc == arc_origin:
                next_start = 0
                arc = list(''.join(temp))
                for i in range(num_insert):
                    pos = arc_origin[next_start:].find('-') + next_start
                    assert arc_origin[pos] == '-', "Wrong '-' is found!"
                    max_step = min(len_step, max(len(temp[i]), len(temp[i + 1])))
                    step_range = list(range(-1 * max_step, max_step))
                    step = random.choice(step_range)
                    next_start = pos + 1
                    pos = pos + step
                    arc.insert(pos, '-')
                arc = ''.join(arc)
                is_not_valid = (not is_valid(arc))
            return arc

        arch_origin = arch
        success = False
        k = 0
        while not success:
            k += 1
            arch = arch_origin
            ops = []
            for i in range(num_mutate):
                op_idx = np.random.randint(low=0, high=3)
                adds_thresh_ = addstage_ratio if len(arch.split('-')) < max_stages else 1
                if op_idx == 0 and random.random() > expend_ratio:
                    arch, idx = expend(arch)
                    arch, idx = expend(arch)
                    ops.append('expend')
                elif op_idx == 1:
                    arch = swap(arch)
                    ops.append('swap')
                elif op_idx == 2 and random.random() > adds_thresh_:
                    arch = arch + '-1'
                    ops.append('add stage')
                else:
                    ops.append('Do Nothing.')
            success = arch != arch_origin
            flag = 'Success' if success else 'Failed'
            logging.info('Serial-level Sample{}: {}. {}.'.format(k + 1, ' -> '.join(ops), flag))
        return arch

    def _mutate_parallelnet(self, arch):
        """Mutate operation in Parallel-level searching.

        :param arch: base arch encoding
        :type arch: str
        :return: parallel arch encoding after mutation
        :rtype: str
        """
        def limited_random(num_stage):
            p = [0.4, 0.3, 0.2, 0.1]
            choices = np.random.choice(4, size=num_stage, replace=True, p=p)
            choices = [str(i) for i in choices]
            return '-'.join(choices)

        num_stage = len(arch.split('-'))
        success = False
        k = 0
        while not success:
            k += 1
            new_arch = limited_random(num_stage)
            success = new_arch != arch
            flag = 'Success' if success else 'Failed'
            logging.info('Parallel-level Sample{}: {}.'.format(k + 1, flag))
        return new_arch
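# Illustrative sketch (not part of the original source): the string encodings
# manipulated by the mutation operators above. The concrete digits below are
# made-up; what is grounded in the class is only that '-' separates stages, that
# expend() inserts a '1', that 'add stage' appends '-1', that the serial mutation
# resets the parallel code to all zeros, and that _mutate_parallelnet() redraws
# one digit in 0-3 per stage with p=[0.4, 0.3, 0.2, 0.1].
def _example_spnas_codes():
    """Relate a serial-level code to its parallel-level counterpart."""
    serialnet = '111-1111-11111-111'   # hypothetical 4-stage serial code
    parallelnet = '0-1-2-0'            # one parallel digit per serial stage
    assert len(serialnet.split('-')) == len(parallelnet.split('-'))
    # expend:    '111-1111-11111-111' -> '1111-1111-11111-111'
    # add stage: '111-1111-11111-111' -> '111-1111-11111-111-1'
    # swap:      moves one '-' a few positions left or right
    return serialnet, parallelnet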
class BackboneNas(SearchAlgorithm):
    """BackboneNas.

    :param search_space: input search_space
    :type: SearchSpace
    """

    def __init__(self, search_space=None):
        """Init BackboneNas."""
        super(BackboneNas, self).__init__(search_space)
        # ea or random
        self.search_space = search_space
        self.codec = Codec(self.cfg.codec, search_space)
        self.num_mutate = self.policy.num_mutate
        self.random_ratio = self.policy.random_ratio
        self.max_sample = self.range.max_sample
        self.min_sample = self.range.min_sample
        self.sample_count = 0
        logging.info("initialized BackboneNas")
        self.pareto_front = ParetoFront(self.cfg)
        self.random_search = RandomSearchAlgorithm(self.search_space)
        self._best_desc_file = 'nas_model_desc.json'
        if 'best_desc_file' in self.cfg and self.cfg.best_desc_file is not None:
            self._best_desc_file = self.cfg.best_desc_file

    @property
    def is_completed(self):
        """Check if NAS is finished."""
        return self.sample_count > self.max_sample

    def search(self):
        """Search in search_space and return a sample."""
        sample = {}
        while sample is None or 'code' not in sample:
            pareto_dict = self.pareto_front.get_pareto_front()
            pareto_list = list(pareto_dict.values())
            if self.pareto_front.size < self.min_sample \
                    or random.random() < self.random_ratio or len(pareto_list) == 0:
                sample_desc = self.random_search.search()
                sample = self.codec.encode(sample_desc)
            else:
                sample = pareto_list[0]
            if sample is not None and 'code' in sample:
                code = sample['code']
                code = self.ea_sample(code)
                sample['code'] = code
            if not self.pareto_front._add_to_board(id=self.sample_count + 1,
                                                   config=sample):
                sample = None
        self.sample_count += 1
        logging.info(sample)
        sample_desc = self.codec.decode(sample)
        return self.sample_count, NetworkDesc(sample_desc)

    def random_sample(self):
        """Randomly sample from search_space."""
        sample_desc = self.random_search.search()
        sample = self.codec.encode(sample_desc, is_random=True)
        return sample

    def ea_sample(self, code):
        """Use EA ops to change an arch code.

        :param code: list of code for arch
        :type code: list
        :return: changed code
        :rtype: list
        """
        new_arch = code.copy()
        self._insert(new_arch)
        self._remove(new_arch)
        self._swap(new_arch[0], self.num_mutate // 2)
        self._swap(new_arch[1], self.num_mutate // 2)
        return new_arch

    def update(self, worker_result_path):
        """Use train and evaluate result to update algorithm.

        :param worker_result_path: current result path
        :type: str
        """
        step_name = os.path.basename(os.path.dirname(worker_result_path))
        config_id = int(os.path.basename(worker_result_path))
        performance = self._get_performance(step_name, config_id)
        logging.info("update performance={}".format(performance))
        self.pareto_front.add_pareto_score(config_id, performance)
        self.save_output(self.local_output_path)
        if self.backup_base_path is not None:
            FileOps.copy_folder(self.local_base_path, self.backup_base_path)

    def _get_performance(self, step_name, worker_id):
        """Load the performance list saved by a worker."""
        saved_folder = self.get_local_worker_path(step_name, worker_id)
        performance_file = FileOps.join_path(saved_folder, "performance.txt")
        if not os.path.isfile(performance_file):
            logging.info("Performance file does not exist, file={}".format(
                performance_file))
            return []
        with open(performance_file, 'r') as f:
            performance = []
            for line in f.readlines():
                line = line.strip()
                if line == "":
                    continue
                data = json.loads(line)
                if isinstance(data, list):
                    data = data[0]
                performance.append(data)
            logging.info("performance={}".format(performance))
        return performance

    def _insert(self, arch):
        """Randomly insert into the arch code.

        :param arch: input arch code
        :type arch: list
        :return: changed arch code
        :rtype: list
        """
        idx = np.random.randint(low=0, high=len(arch[0]))
        arch[0].insert(idx, 1)
        idx = np.random.randint(low=0, high=len(arch[1]))
        arch[1].insert(idx, 1)
        return arch

    def _remove(self, arch):
        """Randomly remove one entry from the arch code.

        :param arch: input arch code
        :type arch: list
        :return: changed arch code
        :rtype: list
        """
        # random pop from arch[0]
        ones_index = [i for i, char in enumerate(arch[0]) if char == 1]
        idx = random.choice(ones_index)
        arch[0].pop(idx)
        # random pop from arch[1]
        ones_index = [i for i, char in enumerate(arch[1]) if char == 1]
        idx = random.choice(ones_index)
        arch[1].pop(idx)
        return arch

    def _swap(self, arch, R):
        """Randomly swap one entry in the arch code.

        :param arch: input arch code
        :type arch: list
        :return: changed arch code
        :rtype: list
        """
        while True:
            not_ones_index = [i for i, char in enumerate(arch) if char != 1]
            idx = random.choice(not_ones_index)
            r = random.randint(1, R)
            direction = -r if random.random() > 0.5 else r
            try:
                arch[idx], arch[idx + direction] = arch[idx + direction], arch[idx]
                break
            except Exception:
                continue
        return arch

    def is_valid(self, arch):
        """Check if the arch code is valid.

        :param arch: input arch code
        :type arch: list
        :return: whether the current arch is valid
        :rtype: bool
        """
        return True

    def save_output(self, output_path):
        """Save result to output_path.

        :param output_path: the result save output_path
        :type: str
        """
        try:
            self.pareto_front.sieve_board.to_csv(
                os.path.join(output_path, 'nas_score_board.csv'),
                index=None, header=True)
        except Exception as e:
            logging.error("write nas_score_board.csv error:{}".format(str(e)))
        try:
            pareto_dict = self.pareto_front.get_pareto_front()
            if len(pareto_dict) > 0:
                id = list(pareto_dict.keys())[0]
                net_desc = pareto_dict[id]
                net_desc = self.codec.decode(net_desc)
                with open(os.path.join(output_path, self._best_desc_file), 'w') as fp:
                    json.dump(net_desc, fp)
        except Exception as e:
            logging.error("write best model error:{}".format(str(e)))
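# Illustrative sketch (not part of the original source): the shape of the arch
# code mutated by ea_sample(). The two lists below are made-up; what is grounded
# in the class above is that the code is a pair of lists, that _insert() adds a 1
# to each list, _remove() pops a random 1 from each, and _swap() moves a non-1
# entry a few positions. _insert/_remove/_swap never touch `self`, so None is
# passed in place of an instance here.
def _example_backbone_ea_operators():
    """Walk a hypothetical two-part arch code through the three EA operators."""
    code = [[1, 1, 2, 1, 1, 2, 1], [1, 2, 1, 1, 2, 1]]
    BackboneNas._insert(None, code)
    BackboneNas._remove(None, code)
    BackboneNas._swap(None, code[0], 2)   # 2 matches num_mutate // 2 when num_mutate = 4
    BackboneNas._swap(None, code[1], 2)
    return code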