class Nasbench201(Dataset):
    """Nasbench201 Dataset."""

    def __init__(self):
        """Construct the Nasbench201 class."""
        super(Nasbench201, self).__init__()
        self.args.data_path = FileOps.download_dataset(self.args.data_path)
        # pass the path variable itself, not the literal string 'self.args.data_path'
        self.nasbench201_api = API(self.args.data_path)

    def query(self, arch_str, dataset):
        """Query an item from the dataset according to the given arch_str and dataset.

        :param arch_str: arch_str that defines the topology of the cell
        :type arch_str: str
        :param dataset: dataset type
        :type dataset: str
        :return: an item of the dataset, which contains the network info and its
            results such as accuracy, flops, etc.
        :rtype: dict
        """
        if dataset not in VALID_DATASET:
            raise ValueError(
                "Only cifar10, cifar100, and ImageNet datasets are supported.")
        ops_list = self.nasbench201_api.str2lists(arch_str)
        for op in ops_list:
            if op not in VALID_OPS:
                raise ValueError(
                    "{} is not in the nasbench201 space.".format(op))
        index = self.nasbench201_api.query_index_by_arch(arch_str)
        results = self.nasbench201_api.query_by_index(index, dataset)
        return results
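# Usage sketch for the query API above (illustrative only: it assumes the
# surrounding Vega config has populated self.args.data_path, and that
# 'cifar10' is one of the keys in VALID_DATASET). The arch string follows the
# NAS-Bench-201 '|op~parent|' cell encoding.
if __name__ == '__main__':
    nb201 = Nasbench201()
    arch_str = ('|nor_conv_3x3~0|+|nor_conv_3x3~0|avg_pool_3x3~1|+'
                '|skip_connect~0|nor_conv_3x3~1|skip_connect~2|')
    results = nb201.query(arch_str, 'cifar10')
    print(results)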
def visualize_rank_over_time(meta_file, vis_save_dir):
    print('\n' + '-' * 150)
    vis_save_dir.mkdir(parents=True, exist_ok=True)
    print('{:} start to visualize rank-over-time into {:}'.format(
        time_string(), vis_save_dir))
    cache_file_path = vis_save_dir / 'rank-over-time-cache-info.pth'
    if not cache_file_path.exists():
        print('Do not find cache file : {:}'.format(cache_file_path))
        nas_bench = API(str(meta_file))
        print('{:} load nas_bench done'.format(time_string()))
        params, flops, train_accs, valid_accs, test_accs, otest_accs = \
            [], [], defaultdict(list), defaultdict(list), defaultdict(list), defaultdict(list)
        # for index in range(len(nas_bench)):
        for index in tqdm(range(len(nas_bench))):
            info = nas_bench.query_by_index(index, use_12epochs_result=False)
            for iepoch in range(200):
                res = info.get_metrics('cifar10', 'train', iepoch)
                train_acc = res['accuracy']
                res = info.get_metrics('cifar10-valid', 'x-valid', iepoch)
                valid_acc = res['accuracy']
                res = info.get_metrics('cifar10', 'ori-test', iepoch)
                test_acc = res['accuracy']
                res = info.get_metrics('cifar10', 'ori-test', iepoch)
                otest_acc = res['accuracy']
                train_accs[iepoch].append(train_acc)
                valid_accs[iepoch].append(valid_acc)
                test_accs[iepoch].append(test_acc)
                otest_accs[iepoch].append(otest_acc)
                if iepoch == 0:
                    res = info.get_comput_costs('cifar10')
                    flop, param = res['flops'], res['params']
                    flops.append(flop)
                    params.append(param)
        info = {
            'params': params,
            'flops': flops,
            'train_accs': train_accs,
            'valid_accs': valid_accs,
            'test_accs': test_accs,
            'otest_accs': otest_accs
        }
        torch.save(info, cache_file_path)
    else:
        print('Find cache file : {:}'.format(cache_file_path))
        info = torch.load(cache_file_path)
        params, flops, train_accs, valid_accs, test_accs, otest_accs = \
            info['params'], info['flops'], info['train_accs'], \
            info['valid_accs'], info['test_accs'], info['otest_accs']
    print('{:} collect data done.'.format(time_string()))

    # selected_epochs = [0, 100, 150, 180, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199]
    selected_epochs = list(range(200))
    x_xtests = test_accs[199]
    indexes = list(range(len(x_xtests)))
    ord_idxs = sorted(indexes, key=lambda i: x_xtests[i])
    for sepoch in selected_epochs:
        x_valids = valid_accs[sepoch]
        valid_ord_idxs = sorted(indexes, key=lambda i: x_valids[i])
        valid_ord_lbls = []
        for idx in ord_idxs:
            valid_ord_lbls.append(valid_ord_idxs.index(idx))
        # labeled data
        dpi, width, height = 300, 2600, 2600
        figsize = width / float(dpi), height / float(dpi)
        LabelSize, LegendFontsize = 18, 18
        fig = plt.figure(figsize=figsize)
        ax = fig.add_subplot(111)
        plt.xlim(min(indexes), max(indexes))
        plt.ylim(min(indexes), max(indexes))
        plt.yticks(np.arange(min(indexes), max(indexes), max(indexes) // 6),
                   fontsize=LegendFontsize, rotation='vertical')
        plt.xticks(np.arange(min(indexes), max(indexes), max(indexes) // 6),
                   fontsize=LegendFontsize)
        ax.scatter(indexes, valid_ord_lbls, marker='^', s=0.5, c='tab:green', alpha=0.8)
        ax.scatter(indexes, indexes, marker='o', s=0.5, c='tab:blue', alpha=0.8)
        ax.scatter([-1], [-1], marker='^', s=100, c='tab:green',
                   label='CIFAR-10 validation')
        ax.scatter([-1], [-1], marker='o', s=100, c='tab:blue',
                   label='CIFAR-10 test')
        plt.grid(zorder=0)
        ax.set_axisbelow(True)
        plt.legend(loc='upper left', fontsize=LegendFontsize)
        ax.set_xlabel('architecture ranking in the final test accuracy',
                      fontsize=LabelSize)
        ax.set_ylabel('architecture ranking in the validation set',
                      fontsize=LabelSize)
        save_path = (vis_save_dir / 'time-{:03d}.pdf'.format(sepoch)).resolve()
        fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='pdf')
        save_path = (vis_save_dir / 'time-{:03d}.png'.format(sepoch)).resolve()
        fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png')
        print('{:} save into {:}'.format(time_string(), save_path))
        plt.close('all')
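# The plots above show rank agreement visually; a compact numeric summary is
# Kendall's tau between the validation ranking at each epoch and the final
# test ranking. A minimal sketch (assumes scipy is available and takes the
# valid_accs/test_accs dicts produced by the cache step above):
from scipy.stats import kendalltau

def rank_correlation_over_time(valid_accs, test_accs, final_epoch=199):
    """Return {epoch: Kendall tau vs. final test accuracy}; +1 means the
    validation ranking at that epoch matches the final test ranking."""
    taus = {}
    final_test = test_accs[final_epoch]
    for epoch, vals in valid_accs.items():
        tau, _ = kendalltau(vals, final_test)
        taus[epoch] = tau
    return taus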
def visualize_info(meta_file, dataset, vis_save_dir):
    print('{:} start to visualize {:} information'.format(time_string(), dataset))
    cache_file_path = vis_save_dir / '{:}-cache-info.pth'.format(dataset)
    if not cache_file_path.exists():
        print('Do not find cache file : {:}'.format(cache_file_path))
        nas_bench = API(str(meta_file))
        params, flops, train_accs, valid_accs, test_accs, otest_accs = [], [], [], [], [], []
        for index in range(len(nas_bench)):
            info = nas_bench.query_by_index(index, use_12epochs_result=False)
            resx = info.get_comput_costs(dataset)
            flop, param = resx['flops'], resx['params']
            if dataset == 'cifar10':
                res = info.get_metrics('cifar10', 'train')
                train_acc = res['accuracy']
                res = info.get_metrics('cifar10-valid', 'x-valid')
                valid_acc = res['accuracy']
                res = info.get_metrics('cifar10', 'ori-test')
                test_acc = res['accuracy']
                res = info.get_metrics('cifar10', 'ori-test')
                otest_acc = res['accuracy']
            else:
                res = info.get_metrics(dataset, 'train')
                train_acc = res['accuracy']
                res = info.get_metrics(dataset, 'x-valid')
                valid_acc = res['accuracy']
                res = info.get_metrics(dataset, 'x-test')
                test_acc = res['accuracy']
                res = info.get_metrics(dataset, 'ori-test')
                otest_acc = res['accuracy']
            if index == 11472:  # resnet
                resnet = {
                    'params': param,
                    'flops': flop,
                    'index': 11472,
                    'train_acc': train_acc,
                    'valid_acc': valid_acc,
                    'test_acc': test_acc,
                    'otest_acc': otest_acc
                }
            flops.append(flop)
            params.append(param)
            train_accs.append(train_acc)
            valid_accs.append(valid_acc)
            test_accs.append(test_acc)
            otest_accs.append(otest_acc)
        # resnet = {'params': 0.559, 'flops': 78.56, 'index': 11472, 'train_acc': 99.99, 'valid_acc': 90.84, 'test_acc': 93.97}
        info = {
            'params': params,
            'flops': flops,
            'train_accs': train_accs,
            'valid_accs': valid_accs,
            'test_accs': test_accs,
            'otest_accs': otest_accs
        }
        info['resnet'] = resnet
        torch.save(info, cache_file_path)
    else:
        print('Find cache file : {:}'.format(cache_file_path))
        info = torch.load(cache_file_path)
        params, flops, train_accs, valid_accs, test_accs, otest_accs = \
            info['params'], info['flops'], info['train_accs'], \
            info['valid_accs'], info['test_accs'], info['otest_accs']
        resnet = info['resnet']
    print('{:} collect data done.'.format(time_string()))

    indexes = list(range(len(params)))
    dpi, width, height = 300, 2600, 2600
    figsize = width / float(dpi), height / float(dpi)
    LabelSize, LegendFontsize = 22, 22
    resnet_scale, resnet_alpha = 120, 0.5

    # parameters vs. validation accuracy
    fig = plt.figure(figsize=figsize)
    ax = fig.add_subplot(111)
    plt.xticks(np.arange(0, 1.6, 0.3), fontsize=LegendFontsize)
    if dataset == 'cifar10':
        plt.ylim(50, 100)
        plt.yticks(np.arange(50, 101, 10), fontsize=LegendFontsize)
    elif dataset == 'cifar100':
        plt.ylim(25, 75)
        plt.yticks(np.arange(25, 76, 10), fontsize=LegendFontsize)
    else:
        plt.ylim(0, 50)
        plt.yticks(np.arange(0, 51, 10), fontsize=LegendFontsize)
    ax.scatter(params, valid_accs, marker='o', s=0.5, c='tab:blue')
    ax.scatter([resnet['params']], [resnet['valid_acc']], marker='*',
               s=resnet_scale, c='tab:orange', label='resnet', alpha=0.4)
    plt.grid(zorder=0)
    ax.set_axisbelow(True)
    plt.legend(loc=4, fontsize=LegendFontsize)
    ax.set_xlabel('#parameters (MB)', fontsize=LabelSize)
    ax.set_ylabel('the validation accuracy (%)', fontsize=LabelSize)
    save_path = (vis_save_dir / '{:}-param-vs-valid.pdf'.format(dataset)).resolve()
    fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='pdf')
    save_path = (vis_save_dir / '{:}-param-vs-valid.png'.format(dataset)).resolve()
    fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png')
    print('{:} save into {:}'.format(time_string(), save_path))

    # parameters vs. test accuracy
    fig = plt.figure(figsize=figsize)
    ax = fig.add_subplot(111)
    plt.xticks(np.arange(0, 1.6, 0.3), fontsize=LegendFontsize)
    if dataset == 'cifar10':
        plt.ylim(50, 100)
        plt.yticks(np.arange(50, 101, 10), fontsize=LegendFontsize)
    elif dataset == 'cifar100':
        plt.ylim(25, 75)
        plt.yticks(np.arange(25, 76, 10), fontsize=LegendFontsize)
    else:
        plt.ylim(0, 50)
        plt.yticks(np.arange(0, 51, 10), fontsize=LegendFontsize)
    ax.scatter(params, test_accs, marker='o', s=0.5, c='tab:blue')
    ax.scatter([resnet['params']], [resnet['test_acc']], marker='*',
               s=resnet_scale, c='tab:orange', label='resnet', alpha=resnet_alpha)
    plt.grid()
    ax.set_axisbelow(True)
    plt.legend(loc=4, fontsize=LegendFontsize)
    ax.set_xlabel('#parameters (MB)', fontsize=LabelSize)
    ax.set_ylabel('the test accuracy (%)', fontsize=LabelSize)
    save_path = (vis_save_dir / '{:}-param-vs-test.pdf'.format(dataset)).resolve()
    fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='pdf')
    save_path = (vis_save_dir / '{:}-param-vs-test.png'.format(dataset)).resolve()
    fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png')
    print('{:} save into {:}'.format(time_string(), save_path))

    # parameters vs. training accuracy
    fig = plt.figure(figsize=figsize)
    ax = fig.add_subplot(111)
    plt.xticks(np.arange(0, 1.6, 0.3), fontsize=LegendFontsize)
    if dataset == 'cifar10':
        plt.ylim(50, 100)
        plt.yticks(np.arange(50, 101, 10), fontsize=LegendFontsize)
    elif dataset == 'cifar100':
        plt.ylim(20, 100)
        plt.yticks(np.arange(20, 101, 10), fontsize=LegendFontsize)
    else:
        plt.ylim(25, 76)
        plt.yticks(np.arange(25, 76, 10), fontsize=LegendFontsize)
    ax.scatter(params, train_accs, marker='o', s=0.5, c='tab:blue')
    ax.scatter([resnet['params']], [resnet['train_acc']], marker='*',
               s=resnet_scale, c='tab:orange', label='resnet', alpha=resnet_alpha)
    plt.grid()
    ax.set_axisbelow(True)
    plt.legend(loc=4, fontsize=LegendFontsize)
    ax.set_xlabel('#parameters (MB)', fontsize=LabelSize)
    ax.set_ylabel('the training accuracy (%)', fontsize=LabelSize)
    save_path = (vis_save_dir / '{:}-param-vs-train.pdf'.format(dataset)).resolve()
    fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='pdf')
    save_path = (vis_save_dir / '{:}-param-vs-train.png'.format(dataset)).resolve()
    fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png')
    print('{:} save into {:}'.format(time_string(), save_path))

    # test accuracy over architecture ID
    fig = plt.figure(figsize=figsize)
    ax = fig.add_subplot(111)
    plt.xlim(0, max(indexes))
    plt.xticks(np.arange(min(indexes), max(indexes), max(indexes) // 5),
               fontsize=LegendFontsize)
    if dataset == 'cifar10':
        plt.ylim(50, 100)
        plt.yticks(np.arange(50, 101, 10), fontsize=LegendFontsize)
    elif dataset == 'cifar100':
        plt.ylim(25, 75)
        plt.yticks(np.arange(25, 76, 10), fontsize=LegendFontsize)
    else:
        plt.ylim(0, 50)
        plt.yticks(np.arange(0, 51, 10), fontsize=LegendFontsize)
    ax.scatter(indexes, test_accs, marker='o', s=0.5, c='tab:blue')
    ax.scatter([resnet['index']], [resnet['test_acc']], marker='*',
               s=resnet_scale, c='tab:orange', label='resnet', alpha=resnet_alpha)
    plt.grid()
    ax.set_axisbelow(True)
    plt.legend(loc=4, fontsize=LegendFontsize)
    ax.set_xlabel('architecture ID', fontsize=LabelSize)
    ax.set_ylabel('the test accuracy (%)', fontsize=LabelSize)
    save_path = (vis_save_dir / '{:}-test-over-ID.pdf'.format(dataset)).resolve()
    fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='pdf')
    save_path = (vis_save_dir / '{:}-test-over-ID.png'.format(dataset)).resolve()
    fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png')
    print('{:} save into {:}'.format(time_string(), save_path))
    plt.close('all')
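# Usage sketch: generate the scatter plots for all three benchmark datasets.
# The meta file name below is the published v1.0 checkpoint; substitute the
# NAS-Bench-201 .pth file you actually downloaded, and any output directory.
if __name__ == '__main__':
    from pathlib import Path
    meta_file = Path('NAS-Bench-201-v1_0-e61699.pth')
    vis_save_dir = Path('./output/vis-info')
    vis_save_dir.mkdir(parents=True, exist_ok=True)
    for dataset in ('cifar10', 'cifar100', 'ImageNet16-120'):
        visualize_info(meta_file, dataset, vis_save_dir)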
class MacroGraph(NodeOpGraph):
    def __init__(self, config, primitives, ops_dict, *args, **kwargs):
        self.config = config
        self.primitives = primitives
        self.ops_dict = ops_dict
        self.nasbench_api = API('/home/siemsj/nasbench_201.pth')
        super(MacroGraph, self).__init__(*args, **kwargs)

    def _build_graph(self):
        num_cells_per_stack = self.config['num_cells_per_stack']
        C = self.config['init_channels']
        layer_channels = [C] * num_cells_per_stack + [C * 2] + \
            [C * 2] * num_cells_per_stack + [C * 4] + \
            [C * 4] * num_cells_per_stack
        layer_reductions = [False] * num_cells_per_stack + [True] + \
            [False] * num_cells_per_stack + [True] + \
            [False] * num_cells_per_stack

        stem = NASBENCH_201_Stem(C=C)
        self.add_node(0, type='input')
        self.add_node(1, op=stem, type='stem')

        C_prev = C
        self.cells = nn.ModuleList()
        for cell_num, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)):
            if reduction:
                cell = ResNetBasicblock(C_prev, C_curr, 2, True)
                self.add_node(cell_num + 2, op=cell, primitives=self.primitives,
                              transform=lambda x: x[0])
            else:
                cell = Cell(primitives=self.primitives, stride=1, C_prev=C_prev,
                            C=C_curr, ops_dict=self.ops_dict, cell_type='normal')
                self.add_node(cell_num + 2, op=cell, primitives=self.primitives)
            C_prev = C_curr

        lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
        pooling = nn.AdaptiveAvgPool2d(1)
        classifier = nn.Linear(C_prev, self.config['num_classes'])

        self.add_node(cell_num + 3, op=lastact, transform=lambda x: x[0],
                      type='postprocessing_nb201')
        self.add_node(cell_num + 4, op=pooling, transform=lambda x: x[0],
                      type='pooling')
        self.add_node(cell_num + 5, op=classifier,
                      transform=lambda x: x[0].view(x[0].size(0), -1),
                      type='output')

        # Edges
        for i in range(1, cell_num + 6):
            self.add_edge(i - 1, i, type='input', desc='previous')

    def sample(self, same_cell_struct=True, n_ops_per_edge=1,
               n_input_edges=None, dist=None, seed=1):
        """
        same_cell_struct: True; whether all sampled cells share the same topology
        n_ops_per_edge: 1; number of sampled operations per edge in a cell
        n_input_edges: None; list whose length equals the number of intermediate
            nodes; determines the number of predecessor nodes for each of them
        dist: None; distribution to sample operations on edges from
        seed: 1; random seed
        """
        # create a new graph that we will discretize
        new_graph = MacroGraph(self.config, self.primitives, self.ops_dict)
        np.random.seed(seed)
        seeds = {'normal': seed + 1, 'reduction': seed + 2}

        for node in new_graph:
            cell = new_graph.get_node_op(node)
            if not isinstance(cell, Cell):
                continue

            if same_cell_struct:
                np.random.seed(seeds[new_graph.get_node_type(node)])

            for edge in cell.edges:
                op_choices = cell.get_edge_op_choices(*edge)
                sampled_op = np.random.choice(op_choices, n_ops_per_edge,
                                              False, p=dist)
                cell[edge[0]][edge[1]]['op_choices'] = [*sampled_op]

            if n_input_edges is not None:
                for inter_node, k in zip(cell.inter_nodes(), n_input_edges):
                    # in case the start node index is not 0
                    node_idx = list(cell.nodes).index(inter_node)
                    prev_node_choices = list(cell.nodes)[:node_idx]
                    assert k <= len(prev_node_choices), \
                        'cannot sample more than the number of predecessor nodes'
                    sampled_input_edges = np.random.choice(prev_node_choices,
                                                           k, False)
                    for i in set(prev_node_choices) - set(sampled_input_edges):
                        cell.remove_edge(i, inter_node)

        return new_graph

    @classmethod
    def from_config(cls, config=None, filename=None):
        with open(filename, 'r') as f:
            graph_dict = yaml.safe_load(f)

        if config is None:
            raise ValueError('No configuration provided')

        graph = cls(config, [])
        graph_type = graph_dict['type']
        edges = [(*eval(e), attr) for e, attr in graph_dict['edges'].items()]
        graph.clear()
        graph.add_edges_from(edges)

        C = config['init_channels']
        C_curr = config['stem_multiplier'] * C

        stem = Stem(C_curr=C_curr)
        C_prev_prev, C_prev, C_curr = C_curr, C_curr, C

        for node, attr in graph_dict['nodes'].items():
            node_type = attr['type']
            if node_type == 'input':
                graph.add_node(node, type='input')
            elif node_type == 'stem':
                graph.add_node(node, op=stem, type='stem')
            elif node_type in ['normal', 'reduction']:
                assert attr['op']['type'] == 'Cell'
                graph.add_node(
                    node,
                    op=Cell.from_config(
                        attr['op'],
                        primitives=attr['op']['primitives'],
                        C_prev_prev=C_prev_prev,
                        C_prev=C_prev,
                        C=C_curr,
                        reduction_prev=graph_dict['nodes'][node - 1]['type'] == 'reduction',
                        cell_type=node_type),
                    type=node_type)
                C_prev_prev, C_prev = C_prev, config['channel_multiplier'] * C_curr
            elif node_type == 'pooling':
                pooling = nn.AdaptiveAvgPool2d(1)
                graph.add_node(node, op=pooling, transform=lambda x: x[0],
                               type='pooling')
            elif node_type == 'output':
                classifier = nn.Linear(C_prev, config['num_classes'])
                graph.add_node(node, op=classifier,
                               transform=lambda x: x[0].view(x[0].size(0), -1),
                               type='output')

        return graph

    @staticmethod
    def export_nasbench_201_results_to_dict(information):
        results_dict = {}
        dataset_names = information.get_dataset_names()
        results_dict['arch'] = information.arch_str
        results_dict['datasets'] = dataset_names

        for ida, dataset in enumerate(dataset_names):
            dataset_results = {}
            dataset_results['dataset'] = dataset

            metric = information.get_compute_costs(dataset)
            flop, param, latency, training_time = metric['flops'], metric['params'], \
                metric['latency'], metric['T-train@total']
            dataset_results['flop'] = flop
            dataset_results['params (MB)'] = param
            dataset_results['latency (ms)'] = latency * 1000 \
                if latency is not None and latency > 0 else None
            dataset_results['training_time'] = training_time

            train_info = information.get_metrics(dataset, 'train')
            if dataset == 'cifar10-valid':
                valid_info = information.get_metrics(dataset, 'x-valid')
                dataset_results['train_loss'] = train_info['loss']
                dataset_results['train_accuracy'] = train_info['accuracy']
                dataset_results['valid_loss'] = valid_info['loss']
                dataset_results['valid_accuracy'] = valid_info['accuracy']
            elif dataset == 'cifar10':
                test__info = information.get_metrics(dataset, 'ori-test')
                dataset_results['train_loss'] = train_info['loss']
                dataset_results['train_accuracy'] = train_info['accuracy']
                dataset_results['test_loss'] = test__info['loss']
                dataset_results['test_accuracy'] = test__info['accuracy']
            else:
                valid_info = information.get_metrics(dataset, 'x-valid')
                test__info = information.get_metrics(dataset, 'x-test')
                dataset_results['train_loss'] = train_info['loss']
                dataset_results['train_accuracy'] = train_info['accuracy']
                dataset_results['valid_loss'] = valid_info['loss']
                dataset_results['valid_accuracy'] = valid_info['accuracy']
                dataset_results['test_loss'] = test__info['loss']
                dataset_results['test_accuracy'] = test__info['accuracy']
            results_dict[dataset] = dataset_results
        return results_dict

    def query_architecture(self, arch_weights):
        arch_weight_idx_to_parent = {0: 0, 1: 0, 2: 1, 3: 0, 4: 1, 5: 2}
        arch_strs = {
            'cell_normal_from_0_to_1': '',
            'cell_normal_from_0_to_2': '',
            'cell_normal_from_1_to_2': '',
            'cell_normal_from_0_to_3': '',
            'cell_normal_from_1_to_3': '',
            'cell_normal_from_2_to_3': '',
        }
        for arch_weight_idx, (edge_key, edge_weights) in enumerate(arch_weights.items()):
            edge_weights_norm = torch.softmax(edge_weights, dim=-1)
            selected_op_str = PRIMITIVES[edge_weights_norm.argmax()]
            arch_strs[edge_key] = '{}~{}'.format(
                selected_op_str, arch_weight_idx_to_parent[arch_weight_idx])
        arch_str = '|{}|+|{}|{}|+|{}|{}|{}|'.format(*arch_strs.values())
        if not hasattr(self, 'nasbench_api'):
            self.nasbench_api = API('/home/siemsj/nasbench_201.pth')
        index = self.nasbench_api.query_index_by_arch(arch_str)
        self.nasbench_api.show(index)
        info = self.nasbench_api.query_by_index(index)
        return self.export_nasbench_201_results_to_dict(info)
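# Usage sketch for query_architecture: each of the six NAS-Bench-201 cell
# edges gets a weight vector over the operation set, and the per-edge argmax
# is looked up in the tabular benchmark. The random weights below are
# illustrative only; in a real search they come from the one-shot model.
def _demo_query(macro_graph):
    arch_weights = {
        'cell_normal_from_0_to_1': torch.randn(len(PRIMITIVES)),
        'cell_normal_from_0_to_2': torch.randn(len(PRIMITIVES)),
        'cell_normal_from_1_to_2': torch.randn(len(PRIMITIVES)),
        'cell_normal_from_0_to_3': torch.randn(len(PRIMITIVES)),
        'cell_normal_from_1_to_3': torch.randn(len(PRIMITIVES)),
        'cell_normal_from_2_to_3': torch.randn(len(PRIMITIVES)),
    }
    return macro_graph.query_architecture(arch_weights)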
# excerpt from the per-architecture scoring loop: score each sampled
# architecture without training, then pick the best-scoring one
network = get_cell_based_tiny_net(config)  # create the network from configuration
network = network.to(device)
jacobs, labels = get_batch_jacobian(network, x, target, 1, device, args)
jacobs = jacobs.reshape(jacobs.size(0), -1).cpu().numpy()
try:
    s = eval_score(jacobs, labels)
except Exception as e:
    print(e)
    s = np.nan
scores.append(s)

# after scoring all sampled architectures, query the benchmark for the winner
best_arch = indices[order_fn(scores)]
info = api.query_by_index(best_arch)
topscores.append(scores[order_fn(scores)])
chosen.append(best_arch)
acc.append(info.get_metrics(dset, acc_type)['accuracy'])

if not args.dataset == 'cifar10' or args.trainval:
    val_acc.append(info.get_metrics(dset, val_acc_type)['accuracy'])

times.append(time.time() - start)
runs.set_description(f"acc: {mean(acc if not args.trainval else val_acc):.2f}%")

print(f"Final mean test accuracy: {np.mean(acc)}")
if len(val_acc) > 1:
    print(f"Final mean validation accuracy: {np.mean(val_acc)}")

state = {'accs': acc,
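# eval_score is not shown in this excerpt. A minimal sketch of the
# Jacobian-covariance heuristic this style of scoring is based on (Mellor et
# al., "Neural Architecture Search without Training") is below; treat it as an
# illustration under that assumption, not the exact function used here.
def eval_score_sketch(jacobs, labels=None, eps=1e-5):
    """Score a network by the eigenvalues of the correlation matrix of
    per-example Jacobians: decorrelated Jacobians yield a higher score."""
    corrs = np.corrcoef(jacobs)          # (batch, batch) correlation matrix
    eigvals = np.linalg.eigvalsh(corrs)  # real eigenvalues (corrs is symmetric)
    return -np.sum(np.log(eigvals + eps) + 1.0 / (eigvals + eps))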
class NAS201(ObjectiveFunction):

    def __init__(self, data_dir, task='cifar10-valid', log_scale=True,
                 negative=True, use_12_epochs_result=False, seed=None):
        """
        data_dir: data directory that contains the NAS-Bench-201-v1_1-096897.pth file
        task: the target image task. Options: cifar10-valid, cifar100, ImageNet16-120
        log_scale: whether to output the objective in log scale
        negative: whether to output the objective in negative form
        use_12_epochs_result: whether to use the statistics at the end of the
            12th training epoch instead of all the way to the end
        seed: the random seed used to access trained model performance.
            Options: 0, 1, 2; seed=None selects the seed randomly
        """
        self.api = API(os.path.join(data_dir, 'NAS-Bench-201-v1_1-096897.pth'))
        if isinstance(task, list):
            task = task[0]
        self.task = task
        self.use_12_epochs_result = use_12_epochs_result

        if task == 'cifar10-valid':
            best_val_arch_index = 6111
            best_val_acc = 91.60666665039064 / 100
            best_test_arch_index = 1459
            best_test_acc = 91.52333333333333 / 100
        elif task == 'cifar100':
            best_val_arch_index = 9930
            best_val_acc = 73.49333323567708 / 100
            best_test_arch_index = 9930
            best_test_acc = 73.51333326009114 / 100
        elif task == 'ImageNet16-120':
            best_val_arch_index = 10676
            best_val_acc = 46.766666727701825 / 100
            best_test_arch_index = 857
            best_test_acc = 47.311111097547744 / 100
        else:
            raise NotImplementedError("task " + str(task) +
                                      " is not implemented in the dataset.")

        # errors are computed from the raw accuracies; log/negative transforms
        # are applied afterwards so the stored best accuracies stay untouched
        best_val_err = 1. - best_val_acc
        best_test_err = 1. - best_test_acc
        if log_scale:
            best_val_err = np.log(best_val_err)
            best_test_err = np.log(best_test_err)
        if negative:
            best_val_err = -best_val_err
            best_test_err = -best_test_err
        self.best_val_err = best_val_err
        self.best_test_err = best_test_err
        self.best_val_acc = best_val_acc
        self.best_test_acc = best_test_acc

        super(NAS201, self).__init__(dim=None,
                                     optimum_location=best_test_arch_index,
                                     optimal_val=best_test_err, bounds=None)

        self.log_scale = log_scale
        self.seed = seed
        self.X = []
        self.y_valid_acc = []
        self.y_test_acc = []
        self.costs = []
        self.negative = negative
        # self.optimal_val =   # lowest mean validation error
        # self.y_star_test =   # lowest mean test error

    def _retrieve(self, G, budget, which='eval'):
        # set the random seed for evaluation
        if which == 'test':
            seed = 3
        else:
            seed_list = [777, 888, 999]
            if self.seed is None:
                seed = random.choice(seed_list)
            elif self.seed >= 3:
                seed = self.seed
            else:
                seed = seed_list[self.seed]

        # find the architecture index
        arch_str = G.name
        # print(arch_str)
        try:
            arch_index = self.api.query_index_by_arch(arch_str)
            acc_results = self.api.query_by_index(
                arch_index, self.task,
                use_12epochs_result=self.use_12_epochs_result)
            if seed is not None and 3 <= seed < 777:
                # some architectures only contain one seed result
                acc_results = self.api.get_more_info(
                    arch_index, self.task, None,
                    use_12epochs_result=self.use_12_epochs_result,
                    is_random=False)
                val_acc = acc_results['valid-accuracy'] / 100
                test_acc = acc_results['test-accuracy'] / 100
            else:
                try:
                    acc_results = self.api.get_more_info(
                        arch_index, self.task, None,
                        use_12epochs_result=self.use_12_epochs_result,
                        is_random=seed)
                    val_acc = acc_results['valid-accuracy'] / 100
                    test_acc = acc_results['test-accuracy'] / 100
                    # val_acc = acc_results[seed].get_eval('x-valid')['accuracy'] / 100
                    # if self.task == 'cifar10-valid':
                    #     test_acc = acc_results[seed].get_eval('ori-test')['accuracy'] / 100
                    # else:
                    #     test_acc = acc_results[seed].get_eval('x-test')['accuracy'] / 100
                except Exception:
                    # some architectures only contain one seed result
                    acc_results = self.api.get_more_info(
                        arch_index, self.task, None,
                        use_12epochs_result=self.use_12_epochs_result,
                        is_random=False)
                    val_acc = acc_results['valid-accuracy'] / 100
                    test_acc = acc_results['test-accuracy'] / 100

            auxiliary_info = self.api.query_meta_info_by_index(
                arch_index, use_12epochs_result=self.use_12_epochs_result)
            cost_info = auxiliary_info.get_compute_costs(self.task)
            # auxiliary cost results such as the number of flops and parameters
            cost_results = {'flops': cost_info['flops'],
                            'params': cost_info['params'],
                            'latency': cost_info['latency']}
        except FileNotFoundError:
            val_acc = 0.01
            test_acc = 0.01
            print('missing arch info')
            cost_results = {'flops': None, 'params': None, 'latency': None}

        # store val and test performance + auxiliary cost information
        self.X.append(arch_str)
        self.y_valid_acc.append(val_acc)
        self.y_test_acc.append(test_acc)
        self.costs.append(cost_results)

        if which == 'eval':
            err = 1 - val_acc
        elif which == 'test':
            err = 1 - test_acc
        else:
            raise ValueError("Unknown query parameter: which = " + str(which))

        if self.log_scale:
            y = np.log(err)
        else:
            y = err
        if self.negative:
            y = -y
        return y

    def eval(self, G, budget=199, n_repeat=1):
        # evaluate a single graph G, averaged over n_repeat retrievals
        if n_repeat == 1:
            return self._retrieve(G, budget, 'eval'), [np.nan]
        return np.mean(np.array([self._retrieve(G, budget, 'eval')
                                 for _ in range(n_repeat)])), [np.nan]

    def test(self, G, budget=199, n_repeat=1):
        return np.mean(np.array([self._retrieve(G, budget, 'test')
                                 for _ in range(n_repeat)]))

    def get_results(self, ignore_invalid_configs=False):
        regret_validation = []
        regret_test = []
        costs = []
        model_graph_specs = []

        inc_valid = 0
        inc_test = 0
        for i in range(len(self.X)):
            if inc_valid < self.y_valid_acc[i]:
                inc_valid = self.y_valid_acc[i]
                inc_test = self.y_test_acc[i]
            regret_validation.append(float(self.best_val_acc - inc_valid))
            regret_test.append(float(self.best_test_acc - inc_test))
            model_graph_specs.append(self.X[i])
            costs.append(self.costs[i])

        res = dict()
        res['regret_validation'] = regret_validation
        res['regret_test'] = regret_test
        res['costs'] = costs
        res['model_graph_specs'] = model_graph_specs
        return res

    @staticmethod
    def get_configuration_space():
        # for the unpruned graph
        cs = ConfigSpace.ConfigurationSpace()
        ops_choices = ['nor_conv_3x3', 'nor_conv_1x1', 'avg_pool_3x3',
                       'skip_connect', 'none']
        for i in range(6):
            cs.add_hyperparameter(
                ConfigSpace.CategoricalHyperparameter("edge_%d" % i, ops_choices))
        return cs
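# Usage sketch: query the benchmark objective with a graph whose .name holds a
# NAS-Bench-201 arch string. 'nasbench201_data' is a placeholder directory
# that must contain the NAS-Bench-201-v1_1-096897.pth checkpoint; networkx is
# assumed here only to provide a named graph object.
if __name__ == '__main__':
    import networkx as nx
    objective = NAS201(data_dir='nasbench201_data', task='cifar10-valid',
                       log_scale=True, negative=True, seed=0)
    G = nx.DiGraph()
    G.name = ('|nor_conv_3x3~0|+|nor_conv_3x3~0|avg_pool_3x3~1|+'
              '|skip_connect~0|nor_conv_3x3~1|skip_connect~2|')
    y, _ = objective.eval(G, budget=199)
    print('objective value:', y)
    print(objective.get_results())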