class Nasbench201(Dataset):
    """Nasbench201 Dataset.

    Wraps an ``API`` object built from ``self.args.data_path`` and exposes a
    validated ``query`` on top of it.
    """

    def __init__(self):
        """Construct the Nasbench201 class.

        ``self.args.data_path`` is first passed through
        ``FileOps.download_dataset`` and the returned value is then used to
        build the ``API`` object stored in ``self.nasbench201_api``.
        """
        # NOTE(review): another class named Nasbench201 is visible further
        # down; if both definitions end up in the same module, the name used
        # in this super() call resolves to the later one at call time --
        # confirm they live in separate modules.
        super(Nasbench201, self).__init__()
        self.args.data_path = FileOps.download_dataset(self.args.data_path)
        # Pass the resolved path itself. The previous code passed the quoted
        # text 'self.args.data_path', i.e. a literal string rather than the
        # value computed on the line above.
        self.nasbench201_api = API(self.args.data_path)

    def query(self, arch_str, dataset):
        """Query an item from the dataset according to the given arch_str and dataset.

        :param arch_str: arch_str to define the topology of the cell
        :type arch_str: str
        :param dataset: dataset type
        :type dataset: str
        :return: an item of the dataset, which contains the network info and
            its results like accuracy, flops and etc
        :rtype: dict
        :raises ValueError: if ``dataset`` is not in ``VALID_DATASET`` or an
            element produced by ``str2lists`` is not in ``VALID_OPS``
        """
        if dataset not in VALID_DATASET:
            raise ValueError(
                "Only cifar10, cifar100, and Imagenet dataset is supported.")
        ops_list = self.nasbench201_api.str2lists(arch_str)
        # NOTE(review): each element of ``ops_list`` is compared against
        # VALID_OPS as-is. If str2lists yields per-node tuples of
        # (op, input_index) pairs rather than bare op names, this check needs
        # to unpack them -- confirm against the API and VALID_OPS.
        for op in ops_list:
            if op not in VALID_OPS:
                raise ValueError(
                    "{} is not in the nasbench201 space.".format(op))
        index = self.nasbench201_api.query_index_by_arch(arch_str)
        results = self.nasbench201_api.query_by_index(index, dataset)
        return results
class Nasbench201:
    """Integer-indexed view over a NAS-Bench-201 style ``API`` object.

    Architectures are addressed by an integer uid in ``range(len(self))``.
    All look-ups are delegated to ``self.api``.
    """

    def __init__(self, dataset, apiloc):
        """Store the dataset name and build the benchmark API.

        :param dataset: dataset name used by the accuracy look-ups
        :param apiloc: value handed to ``API`` as its first argument
        """
        self.dataset = dataset
        self.api = API(apiloc, verbose=False)
        # Training budget used for the short-run queries (passed as ``hp``).
        self.epochs = '12'

    def get_network(self, uid):
        """Build the network for ``uid``.

        The configuration is always fetched for 'cifar10-valid', and its
        ``num_classes`` entry is overwritten with 1 before the network is
        constructed.
        """
        net_config = self.api.get_net_config(uid, 'cifar10-valid')
        net_config['num_classes'] = 1
        return get_cell_based_tiny_net(net_config)

    def __iter__(self):
        """Yield ``(uid, network)`` for every uid, building networks lazily."""
        for uid in range(len(self)):
            yield uid, self.get_network(uid)

    def __getitem__(self, index):
        """Return ``index`` unchanged (a uid is its own key)."""
        return index

    def __len__(self):
        """Return the fixed number of addressable architectures."""
        return 15625

    def num_activations(self):
        """Return ``classifier.in_features`` of the network built for uid 0."""
        return self.get_network(0).classifier.in_features

    def get_12epoch_accuracy(self, uid, acc_type, trainval, traincifar10=False):
        """Return 'valid-accuracy' for ``uid`` from the short-budget results.

        Always queries 'cifar10-valid' with ``hp=self.epochs``; ``acc_type``,
        ``trainval`` and ``traincifar10`` are accepted but not used.
        """
        short_run = self.api.get_more_info(
            uid, 'cifar10-valid', iepoch=None, hp=self.epochs, is_random=True)
        return short_run['valid-accuracy']

    def get_final_accuracy(self, uid, acc_type, trainval):
        """Return the 'accuracy' metric for ``uid`` queried with ``hp='200'``.

        When ``self.dataset`` is 'cifar10' and ``trainval`` is truthy, the
        ('cifar10-valid', 'x-valid') metrics are used; otherwise
        (``self.dataset``, ``acc_type``).
        """
        if self.dataset == 'cifar10' and trainval:
            dataset_name, split = 'cifar10-valid', 'x-valid'
        else:
            dataset_name, split = self.dataset, acc_type
        meta = self.api.query_meta_info_by_index(uid, hp='200')
        return meta.get_metrics(dataset_name, split)['accuracy']

    def get_accuracy(self, uid, acc_type, trainval=True):
        """Return the 'accuracy' metric of ``acc_type`` for ``uid``.

        Uses 'cifar10-valid' instead of ``self.dataset`` when the dataset is
        'cifar10' and ``trainval`` is truthy. No ``hp`` is passed to the API.
        """
        meta = self.api.query_meta_info_by_index(uid)
        use_valid_split = self.dataset == 'cifar10' and trainval
        dataset_name = 'cifar10-valid' if use_valid_split else self.dataset
        return meta.get_metrics(dataset_name, acc_type)['accuracy']

    def get_accuracy_for_all_datasets(self, uid):
        """Return six 'accuracy' values for ``uid`` queried with ``hp='200'``.

        :return: tuple ordered as (cifar10 ori-test, cifar10-valid x-valid,
            cifar100 x-test, cifar100 x-valid, ImageNet16-120 x-test,
            ImageNet16-120 x-valid)
        """
        lookups = (
            ('cifar10', 'ori-test'),
            ('cifar10-valid', 'x-valid'),
            ('cifar100', 'x-test'),
            ('cifar100', 'x-valid'),
            ('ImageNet16-120', 'x-test'),
            ('ImageNet16-120', 'x-valid'),
        )
        meta = self.api.query_meta_info_by_index(uid, hp='200')
        return tuple(
            meta.get_metrics(dataset_name, split)['accuracy']
            for dataset_name, split in lookups
        )

    def train_and_eval(self, arch, dataname, acc_type, trainval=True, traincifar10=False):
        """Look up the short-run accuracy, final accuracy and training time.

        ``dataname`` is accepted but not used.

        :return: tuple ``(acc12, acc, time)``
        """
        uid = self[arch]
        elapsed = self.get_training_time(uid)
        short_acc = self.get_12epoch_accuracy(uid, acc_type, trainval, traincifar10)
        final_acc = self.get_final_accuracy(uid, acc_type, trainval)
        return short_acc, final_acc, elapsed

    def random_arch(self):
        """Return a random uid drawn from ``[0, len(self) - 1]`` inclusive."""
        upper = len(self) - 1
        return random.randint(0, upper)

    def get_training_time(self, unique_hash):
        """Return 'train-all-time' + 'valid-per-time' for ``unique_hash``.

        The query is pinned to 'cifar10-valid' with ``hp='12'``.
        """
        timing = self.api.get_more_info(
            unique_hash, 'cifar10-valid', iepoch=None, hp='12', is_random=True)
        return timing['train-all-time'] + timing['valid-per-time']

    def mutate_arch(self, arch):
        """Return the index of an architecture differing from ``arch`` by one op.

        One node and one entry within that node are picked at random, and the
        entry's op (its first element) is replaced by a different op from the
        'nas-bench-201' cell search space. The entry's second element is kept.
        """
        candidate_ops = get_search_spaces('cell', 'nas-bench-201')
        arch_str = self.api.get_net_config(arch, 'cifar10-valid')['arch_str']
        parent = Structure(self.api.str2lists(arch_str))
        child = deepcopy(parent)

        # Draw order (node, entry, op) is kept so seeded runs are unchanged.
        node_pos = random.randint(0, len(child.nodes) - 1)
        entries = list(child.nodes[node_pos])
        entry_pos = random.randint(0, len(entries) - 1)
        current_op, input_ref = entries[entry_pos][0], entries[entry_pos][1]

        # Resample until the op actually changes.
        new_op = random.choice(candidate_ops)
        while new_op == current_op:
            new_op = random.choice(candidate_ops)

        entries[entry_pos] = (new_op, input_ref)
        child.nodes[node_pos] = tuple(entries)
        return self.api.query_index_by_arch(child)