class Nasbench201:
    """Query and build architectures from the NAS-Bench-201 benchmark.

    Architectures are identified by their integer index in the benchmark
    (``0 .. len(self) - 1``). ``__getitem__`` is the identity, so an index
    also serves as the architecture's lookup key ("unique hash").
    """

    def __init__(self, dataset, apiloc):
        """Load the benchmark API.

        Args:
            dataset: dataset name used by the final-accuracy queries
                (``get_final_accuracy`` / ``get_accuracy``).
            apiloc: location passed straight to ``API``.
        """
        self.dataset = dataset
        self.api = API(apiloc, verbose=False)
        # Training-schedule key passed as ``hp`` to the short-budget queries
        # (both the accuracy and the training-time lookup use it).
        self.epochs = '12'

    def get_network(self, uid):
        """Build the network for architecture ``uid`` with a 1-output classifier.

        The config is always read from 'cifar10-valid', independent of
        ``self.dataset``.
        """
        config = self.api.get_net_config(uid, 'cifar10-valid')
        config['num_classes'] = 1
        return get_cell_based_tiny_net(config)

    def __iter__(self):
        """Yield ``(uid, network)`` for every architecture, in index order."""
        for uid in range(len(self)):
            yield uid, self.get_network(uid)

    def __getitem__(self, index):
        """Return ``index`` unchanged (indices double as lookup keys)."""
        return index

    def __len__(self):
        """Number of architectures in the NAS-Bench-201 search space."""
        return 15625

    def num_activations(self):
        """Return the classifier's input width, measured on architecture 0."""
        return self.get_network(0).classifier.in_features

    def get_12epoch_accuracy(self, uid, acc_type, trainval, traincifar10=False):
        """Validation accuracy of ``uid`` under the ``self.epochs`` schedule.

        Always queries 'cifar10-valid' (with ``is_random=True``), whatever
        ``self.dataset`` is. ``acc_type``, ``trainval`` and ``traincifar10``
        are accepted for call-site compatibility but currently unused.
        """
        info = self.api.get_more_info(
            uid, 'cifar10-valid', iepoch=None, hp=self.epochs, is_random=True)
        return info['valid-accuracy']

    def get_final_accuracy(self, uid, acc_type, trainval):
        """Accuracy of ``uid`` after the 200-epoch schedule (``hp='200'``).

        For ``self.dataset == 'cifar10'`` with ``trainval`` the 'x-valid' split
        of 'cifar10-valid' is reported and ``acc_type`` is ignored; otherwise
        the ``acc_type`` split of ``self.dataset`` is reported.
        """
        archinfo = self.api.query_meta_info_by_index(uid, hp='200')
        if self.dataset == 'cifar10' and trainval:
            metrics = archinfo.get_metrics('cifar10-valid', 'x-valid')
        else:
            metrics = archinfo.get_metrics(self.dataset, acc_type)
        return metrics['accuracy']

    def get_accuracy(self, uid, acc_type, trainval=True):
        """Accuracy of ``uid`` on split ``acc_type`` (no ``hp`` is passed).

        When ``self.dataset`` is 'cifar10' and ``trainval`` is true, the
        'cifar10-valid' variant is queried instead of 'cifar10'.
        """
        archinfo = self.api.query_meta_info_by_index(uid)
        if self.dataset == 'cifar10' and trainval:
            dataset = 'cifar10-valid'
        else:
            dataset = self.dataset
        return archinfo.get_metrics(dataset, acc_type)['accuracy']

    def get_accuracy_for_all_datasets(self, uid):
        """200-epoch accuracies of ``uid`` on every benchmark dataset.

        Returns:
            ``(c10, c10_val, c100, c100_val, imagenet, imagenet_val)``:
            'cifar10'/'ori-test', 'cifar10-valid'/'x-valid', then the
            'x-test' and 'x-valid' splits of 'cifar100' and 'ImageNet16-120'.
        """
        archinfo = self.api.query_meta_info_by_index(uid, hp='200')
        queries = (
            ('cifar10', 'ori-test'),
            ('cifar10-valid', 'x-valid'),
            ('cifar100', 'x-test'),
            ('cifar100', 'x-valid'),
            ('ImageNet16-120', 'x-test'),
            ('ImageNet16-120', 'x-valid'),
        )
        return tuple(archinfo.get_metrics(dataset, split)['accuracy']
                     for dataset, split in queries)

    def train_and_eval(self, arch, dataname, acc_type, trainval=True, traincifar10=False):
        """"Train" ``arch`` by looking its results up in the benchmark.

        Returns:
            ``(short_budget_accuracy, final_accuracy, training_time)``.
            ``dataname`` is accepted for compatibility but unused.
        """
        unique_hash = self[arch]
        # Named train_time (not ``time``) to avoid shadowing the time module.
        train_time = self.get_training_time(unique_hash)
        acc12 = self.get_12epoch_accuracy(unique_hash, acc_type, trainval, traincifar10)
        acc = self.get_final_accuracy(unique_hash, acc_type, trainval)
        return acc12, acc, train_time

    def random_arch(self):
        """Return a uniformly random architecture index in ``[0, len(self))``."""
        return random.randint(0, len(self) - 1)

    def get_training_time(self, unique_hash):
        """Simulated cost of training ``unique_hash`` under ``self.epochs``.

        Sums 'train-all-time' and 'valid-per-time' reported for
        'cifar10-valid'. Uses ``self.epochs`` rather than a hard-coded '12' so
        the cost always matches the schedule ``get_12epoch_accuracy`` uses.
        """
        info = self.api.get_more_info(
            unique_hash, 'cifar10-valid', iepoch=None, hp=self.epochs, is_random=True)
        return info['train-all-time'] + info['valid-per-time']

    def mutate_arch(self, arch):
        """Return the index of a random one-op mutation of architecture ``arch``.

        A random node of the cell is chosen, then one of its ``(op, input)``
        pairs, and that op is replaced by a different op from the
        NAS-Bench-201 op list. Note: this never terminates if the op list
        offers no alternative to the current op.
        """
        op_names = get_search_spaces('cell', 'nas-bench-201')
        config = self.api.get_net_config(arch, 'cifar10-valid')
        parent_arch = Structure(self.api.str2lists(config['arch_str']))
        child_arch = deepcopy(parent_arch)
        node_id = random.randint(0, len(child_arch.nodes) - 1)
        node_info = list(child_arch.nodes[node_id])
        snode_id = random.randint(0, len(node_info) - 1)
        # Redraw until the chosen op differs from the current one.
        xop = random.choice(op_names)
        while xop == node_info[snode_id][0]:
            xop = random.choice(op_names)
        node_info[snode_id] = (xop, node_info[snode_id][1])
        child_arch.nodes[node_id] = tuple(node_info)
        return self.api.query_index_by_arch(child_arch)
# Scoring loop over NAS-Bench-201: for each of args.n_runs runs, sample
# args.n_samples architecture indices (with replacement) and score each
# freshly constructed network (no training happens in this fragment) with
# eval_score applied to the output of get_batch_jacobian on one batch.

# np.nanargmax returns the index of the maximum while ignoring NaNs. It is not
# used in this visible fragment; presumably it later picks the best-scoring
# architecture, skipping the NaN scores assigned on failure below.
order_fn = np.nanargmax

runs = trange(args.n_runs, desc='acc: ')
for N in runs:
    # Wall-clock start of this run; read (if at all) by code outside this view.
    start = time.time()
    # 15625 == size of the NAS-Bench-201 space (cf. Nasbench201.__len__);
    # np.random.randint's upper bound is exclusive, so indices are 0..15624.
    indices = np.random.randint(0,15625,args.n_samples)
    scores = []
    for arch in indices:

        # A fresh iterator is created per architecture, so each arch sees the
        # first batch the loader yields (a different batch only if the loader
        # shuffles). NOTE(review): re-creating the iterator per arch may be
        # costly if train_loader uses worker processes — confirm intent.
        data_iterator = iter(train_loader)
        x, target = next(data_iterator)
        x, target = x.to(device), target.to(device)

        # Network config for this index; the classifier is forced to a single
        # output before the model is built.
        config = api.get_net_config(arch, args.dataset)
        config['num_classes'] = 1

        network = get_cell_based_tiny_net(config) # create the network from configuration
        network = network.to(device)

        # get_batch_jacobian is defined elsewhere; presumably returns
        # per-sample input Jacobians plus labels — TODO confirm.
        jacobs, labels= get_batch_jacobian(network, x, target, 1, device, args)
        # Flatten everything except the first dimension, then move to NumPy.
        jacobs = jacobs.reshape(jacobs.size(0), -1).cpu().numpy()

        # Any scoring failure is printed and recorded as NaN rather than
        # aborting the run.
        try:
            s = eval_score(jacobs, labels)
        except Exception as e:
            print(e)
            s = np.nan

        scores.append(s)
plot_shape = (25, 5) num_plots = plot_shape[0]*plot_shape[1] fig, axes = plt.subplots(*plot_shape, sharex=True, figsize=(9, 9) ) plt_cts = [0 for i in range(plot_shape[1])] api = API(args.api_loc) archs = list(range(ARCH_START, ARCH_END)) colours = ['#811F41', '#A92941', '#D15141', '#EF7941', '#F99C4B'] strs = [] random.shuffle(archs) for arch in archs: try: config = api.get_net_config(arch, 'cifar10') archinfo = api.query_meta_info_by_index(arch) acc = archinfo.get_metrics('cifar10-valid', 'x-valid')['accuracy'] network = get_cell_based_tiny_net(config) network = network.to(device) jacobs, labels = get_batch_jacobian(network, train_loader, device) boundaries = [60., 70., 80., 90.] can_plt, row, col, accrange = decide_plot(acc, plt_cts, plot_shape[0], boundaries) if not can_plt: continue axes[row, col].axis('off') plot_hist(jacobs, axes[row, col], colours[col]) if row == 0: