Example #1
class Nasbench201(Dataset):
    """Nasbench201 Dataset."""
    def __init__(self):
        """Construct the Nasbench201 class."""
        super(Nasbench201, self).__init__()
        self.args.data_path = FileOps.download_dataset(self.args.data_path)
        self.nasbench201_api = API(self.args.data_path)

    def query(self, arch_str, dataset):
        """Query an item from the dataset according to the given arch_str and dataset.

        :param arch_str: architecture string that defines the topology of the cell
        :type arch_str: str
        :param dataset: dataset name
        :type dataset: str
        :return: an item of the dataset, which contains the network info and its results such as accuracy and FLOPs
        :rtype: dict
        """
        if dataset not in VALID_DATASET:
            raise ValueError(
                "Only the cifar10, cifar100, and ImageNet datasets are supported.")
        ops_list = self.nasbench201_api.str2lists(arch_str)
        for op in ops_list:
            if op not in VALID_OPS:
                raise ValueError(
                    "{} is not in the nasbench201 space.".format(op))
        index = self.nasbench201_api.query_index_by_arch(arch_str)

        results = self.nasbench201_api.query_by_index(index, dataset)
        return results
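
# Usage sketch (assumptions: the configured data_path points at a local copy of the
# NAS-Bench-201 benchmark file, 'cifar10' is listed in VALID_DATASET, and the
# arch_str below is one valid NAS-Bench-201 topology string):
if __name__ == '__main__':
    bench = Nasbench201()
    arch_str = ('|nor_conv_3x3~0|+|nor_conv_3x3~0|skip_connect~1|'
                '+|skip_connect~0|none~1|nor_conv_3x3~2|')
    results = bench.query(arch_str, 'cifar10')
    print(results)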
def visualize_rank_over_time(meta_file, vis_save_dir):
    print('\n' + '-' * 150)
    vis_save_dir.mkdir(parents=True, exist_ok=True)
    print('{:} start to visualize rank-over-time into {:}'.format(
        time_string(), vis_save_dir))
    cache_file_path = vis_save_dir / 'rank-over-time-cache-info.pth'
    if not cache_file_path.exists():
        print('Do not find cache file : {:}'.format(cache_file_path))
        nas_bench = API(str(meta_file))
        print('{:} load nas_bench done'.format(time_string()))
        params, flops = [], []
        train_accs, valid_accs = defaultdict(list), defaultdict(list)
        test_accs, otest_accs = defaultdict(list), defaultdict(list)
        for index in tqdm(range(len(nas_bench))):
            info = nas_bench.query_by_index(index, use_12epochs_result=False)
            for iepoch in range(200):
                res = info.get_metrics('cifar10', 'train', iepoch)
                train_acc = res['accuracy']
                res = info.get_metrics('cifar10-valid', 'x-valid', iepoch)
                valid_acc = res['accuracy']
                res = info.get_metrics('cifar10', 'ori-test', iepoch)
                test_acc = res['accuracy']
                res = info.get_metrics('cifar10', 'ori-test', iepoch)
                otest_acc = res['accuracy']
                train_accs[iepoch].append(train_acc)
                valid_accs[iepoch].append(valid_acc)
                test_accs[iepoch].append(test_acc)
                otest_accs[iepoch].append(otest_acc)
                if iepoch == 0:
                    res = info.get_comput_costs('cifar10')
                    flop, param = res['flops'], res['params']
                    flops.append(flop)
                    params.append(param)
        info = {
            'params': params,
            'flops': flops,
            'train_accs': train_accs,
            'valid_accs': valid_accs,
            'test_accs': test_accs,
            'otest_accs': otest_accs
        }
        torch.save(info, cache_file_path)
    else:
        print('Find cache file : {:}'.format(cache_file_path))
        info = torch.load(cache_file_path)
        params, flops = info['params'], info['flops']
        train_accs, valid_accs = info['train_accs'], info['valid_accs']
        test_accs, otest_accs = info['test_accs'], info['otest_accs']
    print('{:} collect data done.'.format(time_string()))
    #selected_epochs = [0, 100, 150, 180, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199]
    selected_epochs = list(range(200))
    x_xtests = test_accs[199]
    indexes = list(range(len(x_xtests)))
    ord_idxs = sorted(indexes, key=lambda i: x_xtests[i])
    for sepoch in selected_epochs:
        x_valids = valid_accs[sepoch]
        valid_ord_idxs = sorted(indexes, key=lambda i: x_valids[i])
        valid_ord_lbls = []
        for idx in ord_idxs:
            valid_ord_lbls.append(valid_ord_idxs.index(idx))
        # labeled data
        dpi, width, height = 300, 2600, 2600
        figsize = width / float(dpi), height / float(dpi)
        LabelSize, LegendFontsize = 18, 18

        fig = plt.figure(figsize=figsize)
        ax = fig.add_subplot(111)
        plt.xlim(min(indexes), max(indexes))
        plt.ylim(min(indexes), max(indexes))
        plt.yticks(np.arange(min(indexes), max(indexes),
                             max(indexes) // 6),
                   fontsize=LegendFontsize,
                   rotation='vertical')
        plt.xticks(np.arange(min(indexes), max(indexes),
                             max(indexes) // 6),
                   fontsize=LegendFontsize)
        ax.scatter(indexes,
                   valid_ord_lbls,
                   marker='^',
                   s=0.5,
                   c='tab:green',
                   alpha=0.8)
        ax.scatter(indexes,
                   indexes,
                   marker='o',
                   s=0.5,
                   c='tab:blue',
                   alpha=0.8)
        ax.scatter([-1], [-1],
                   marker='^',
                   s=100,
                   c='tab:green',
                   label='CIFAR-10 validation')
        ax.scatter([-1], [-1],
                   marker='o',
                   s=100,
                   c='tab:blue',
                   label='CIFAR-10 test')
        plt.grid(zorder=0)
        ax.set_axisbelow(True)
        plt.legend(loc='upper left', fontsize=LegendFontsize)
        ax.set_xlabel('architecture ranking in the final test accuracy',
                      fontsize=LabelSize)
        ax.set_ylabel('architecture ranking in the validation set',
                      fontsize=LabelSize)
        save_path = (vis_save_dir / 'time-{:03d}.pdf'.format(sepoch)).resolve()
        fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='pdf')
        save_path = (vis_save_dir / 'time-{:03d}.png'.format(sepoch)).resolve()
        fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png')
        print('{:} save into {:}'.format(time_string(), save_path))
        plt.close('all')
def visualize_info(meta_file, dataset, vis_save_dir):
    print('{:} start to visualize {:} information'.format(
        time_string(), dataset))
    cache_file_path = vis_save_dir / '{:}-cache-info.pth'.format(dataset)
    if not cache_file_path.exists():
        print('Do not find cache file : {:}'.format(cache_file_path))
        nas_bench = API(str(meta_file))
        params, flops, train_accs, valid_accs, test_accs, otest_accs = [], [], [], [], [], []
        for index in range(len(nas_bench)):
            info = nas_bench.query_by_index(index, use_12epochs_result=False)
            resx = info.get_comput_costs(dataset)
            flop, param = resx['flops'], resx['params']
            if dataset == 'cifar10':
                res = info.get_metrics('cifar10', 'train')
                train_acc = res['accuracy']
                res = info.get_metrics('cifar10-valid', 'x-valid')
                valid_acc = res['accuracy']
                res = info.get_metrics('cifar10', 'ori-test')
                test_acc = res['accuracy']
                res = info.get_metrics('cifar10', 'ori-test')
                otest_acc = res['accuracy']
            else:
                res = info.get_metrics(dataset, 'train')
                train_acc = res['accuracy']
                res = info.get_metrics(dataset, 'x-valid')
                valid_acc = res['accuracy']
                res = info.get_metrics(dataset, 'x-test')
                test_acc = res['accuracy']
                res = info.get_metrics(dataset, 'ori-test')
                otest_acc = res['accuracy']
            if index == 11472:  # resnet
                resnet = {
                    'params': param,
                    'flops': flop,
                    'index': 11472,
                    'train_acc': train_acc,
                    'valid_acc': valid_acc,
                    'test_acc': test_acc,
                    'otest_acc': otest_acc
                }
            flops.append(flop)
            params.append(param)
            train_accs.append(train_acc)
            valid_accs.append(valid_acc)
            test_accs.append(test_acc)
            otest_accs.append(otest_acc)
        #resnet = {'params': 0.559, 'flops': 78.56, 'index': 11472, 'train_acc': 99.99, 'valid_acc': 90.84, 'test_acc': 93.97}
        info = {
            'params': params,
            'flops': flops,
            'train_accs': train_accs,
            'valid_accs': valid_accs,
            'test_accs': test_accs,
            'otest_accs': otest_accs
        }
        info['resnet'] = resnet
        torch.save(info, cache_file_path)
    else:
        print('Find cache file : {:}'.format(cache_file_path))
        info = torch.load(cache_file_path)
        params, flops = info['params'], info['flops']
        train_accs, valid_accs = info['train_accs'], info['valid_accs']
        test_accs, otest_accs = info['test_accs'], info['otest_accs']
        resnet = info['resnet']
    print('{:} collect data done.'.format(time_string()))

    indexes = list(range(len(params)))
    dpi, width, height = 300, 2600, 2600
    figsize = width / float(dpi), height / float(dpi)
    LabelSize, LegendFontsize = 22, 22
    resnet_scale, resnet_alpha = 120, 0.5

    fig = plt.figure(figsize=figsize)
    ax = fig.add_subplot(111)
    plt.xticks(np.arange(0, 1.6, 0.3), fontsize=LegendFontsize)
    if dataset == 'cifar10':
        plt.ylim(50, 100)
        plt.yticks(np.arange(50, 101, 10), fontsize=LegendFontsize)
    elif dataset == 'cifar100':
        plt.ylim(25, 75)
        plt.yticks(np.arange(25, 76, 10), fontsize=LegendFontsize)
    else:
        plt.ylim(0, 50)
        plt.yticks(np.arange(0, 51, 10), fontsize=LegendFontsize)
    ax.scatter(params, valid_accs, marker='o', s=0.5, c='tab:blue')
    ax.scatter([resnet['params']], [resnet['valid_acc']],
               marker='*',
               s=resnet_scale,
               c='tab:orange',
               label='resnet',
               alpha=0.4)
    plt.grid(zorder=0)
    ax.set_axisbelow(True)
    plt.legend(loc=4, fontsize=LegendFontsize)
    ax.set_xlabel('#parameters (MB)', fontsize=LabelSize)
    ax.set_ylabel('the validation accuracy (%)', fontsize=LabelSize)
    save_path = (vis_save_dir /
                 '{:}-param-vs-valid.pdf'.format(dataset)).resolve()
    fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='pdf')
    save_path = (vis_save_dir /
                 '{:}-param-vs-valid.png'.format(dataset)).resolve()
    fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png')
    print('{:} save into {:}'.format(time_string(), save_path))

    fig = plt.figure(figsize=figsize)
    ax = fig.add_subplot(111)
    plt.xticks(np.arange(0, 1.6, 0.3), fontsize=LegendFontsize)
    if dataset == 'cifar10':
        plt.ylim(50, 100)
        plt.yticks(np.arange(50, 101, 10), fontsize=LegendFontsize)
    elif dataset == 'cifar100':
        plt.ylim(25, 75)
        plt.yticks(np.arange(25, 76, 10), fontsize=LegendFontsize)
    else:
        plt.ylim(0, 50)
        plt.yticks(np.arange(0, 51, 10), fontsize=LegendFontsize)
    ax.scatter(params, test_accs, marker='o', s=0.5, c='tab:blue')
    ax.scatter([resnet['params']], [resnet['test_acc']],
               marker='*',
               s=resnet_scale,
               c='tab:orange',
               label='resnet',
               alpha=resnet_alpha)
    plt.grid()
    ax.set_axisbelow(True)
    plt.legend(loc=4, fontsize=LegendFontsize)
    ax.set_xlabel('#parameters (MB)', fontsize=LabelSize)
    ax.set_ylabel('the test accuracy (%)', fontsize=LabelSize)
    save_path = (vis_save_dir /
                 '{:}-param-vs-test.pdf'.format(dataset)).resolve()
    fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='pdf')
    save_path = (vis_save_dir /
                 '{:}-param-vs-test.png'.format(dataset)).resolve()
    fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png')
    print('{:} save into {:}'.format(time_string(), save_path))

    fig = plt.figure(figsize=figsize)
    ax = fig.add_subplot(111)
    plt.xticks(np.arange(0, 1.6, 0.3), fontsize=LegendFontsize)
    if dataset == 'cifar10':
        plt.ylim(50, 100)
        plt.yticks(np.arange(50, 101, 10), fontsize=LegendFontsize)
    elif dataset == 'cifar100':
        plt.ylim(20, 100)
        plt.yticks(np.arange(20, 101, 10), fontsize=LegendFontsize)
    else:
        plt.ylim(25, 76)
        plt.yticks(np.arange(25, 76, 10), fontsize=LegendFontsize)
    ax.scatter(params, train_accs, marker='o', s=0.5, c='tab:blue')
    ax.scatter([resnet['params']], [resnet['train_acc']],
               marker='*',
               s=resnet_scale,
               c='tab:orange',
               label='resnet',
               alpha=resnet_alpha)
    plt.grid()
    ax.set_axisbelow(True)
    plt.legend(loc=4, fontsize=LegendFontsize)
    ax.set_xlabel('#parameters (MB)', fontsize=LabelSize)
    ax.set_ylabel('the training accuracy (%)', fontsize=LabelSize)
    save_path = (vis_save_dir /
                 '{:}-param-vs-train.pdf'.format(dataset)).resolve()
    fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='pdf')
    save_path = (vis_save_dir /
                 '{:}-param-vs-train.png'.format(dataset)).resolve()
    fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png')
    print('{:} save into {:}'.format(time_string(), save_path))

    fig = plt.figure(figsize=figsize)
    ax = fig.add_subplot(111)
    plt.xlim(0, max(indexes))
    plt.xticks(np.arange(min(indexes), max(indexes),
                         max(indexes) // 5),
               fontsize=LegendFontsize)
    if dataset == 'cifar10':
        plt.ylim(50, 100)
        plt.yticks(np.arange(50, 101, 10), fontsize=LegendFontsize)
    elif dataset == 'cifar100':
        plt.ylim(25, 75)
        plt.yticks(np.arange(25, 76, 10), fontsize=LegendFontsize)
    else:
        plt.ylim(0, 50)
        plt.yticks(np.arange(0, 51, 10), fontsize=LegendFontsize)
    ax.scatter(indexes, test_accs, marker='o', s=0.5, c='tab:blue')
    ax.scatter([resnet['index']], [resnet['test_acc']],
               marker='*',
               s=resnet_scale,
               c='tab:orange',
               label='resnet',
               alpha=resnet_alpha)
    plt.grid()
    ax.set_axisbelow(True)
    plt.legend(loc=4, fontsize=LegendFontsize)
    ax.set_xlabel('architecture ID', fontsize=LabelSize)
    ax.set_ylabel('the test accuracy (%)', fontsize=LabelSize)
    save_path = (vis_save_dir /
                 '{:}-test-over-ID.pdf'.format(dataset)).resolve()
    fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='pdf')
    save_path = (vis_save_dir /
                 '{:}-test-over-ID.png'.format(dataset)).resolve()
    fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png')
    print('{:} save into {:}'.format(time_string(), save_path))
    plt.close('all')
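
# Usage sketch for the two visualization helpers above (assumptions: the paths are
# hypothetical; both functions expect pathlib.Path arguments and meta_file must
# point at a NAS-Bench-201 benchmark .pth file):
if __name__ == '__main__':
    from pathlib import Path
    meta_file = Path('./NAS-Bench-201-meta.pth')   # hypothetical location
    out_dir = Path('./output/nas-bench-201-vis')
    out_dir.mkdir(parents=True, exist_ok=True)
    visualize_info(meta_file, 'cifar10', out_dir)
    visualize_rank_over_time(meta_file, out_dir / 'rank-over-time')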
Example #4
class MacroGraph(NodeOpGraph):
    def __init__(self, config, primitives, ops_dict, *args, **kwargs):
        self.config = config
        self.primitives = primitives
        self.ops_dict = ops_dict
        self.nasbench_api = API('/home/siemsj/nasbench_201.pth')
        super(MacroGraph, self).__init__(*args, **kwargs)

    def _build_graph(self):
        num_cells_per_stack = self.config['num_cells_per_stack']
        C = self.config['init_channels']
        layer_channels = [C] * num_cells_per_stack + [C * 2] + [C * 2] * num_cells_per_stack + [C * 4] + [
            C * 4] * num_cells_per_stack
        layer_reductions = [False] * num_cells_per_stack + [True] + [False] * num_cells_per_stack + [True] + [
            False] * num_cells_per_stack

        stem = NASBENCH_201_Stem(C=C)
        self.add_node(0, type='input')
        self.add_node(1, op=stem, type='stem')

        C_prev = C
        self.cells = nn.ModuleList()
        for cell_num, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)):
            if reduction:
                cell = ResNetBasicblock(C_prev, C_curr, 2, True)
                self.add_node(cell_num + 2, op=cell, primitives=self.primitives, transform=lambda x: x[0])
            else:
                cell = Cell(primitives=self.primitives, stride=1, C_prev=C_prev, C=C_curr,
                            ops_dict=self.ops_dict, cell_type='normal')
                self.add_node(cell_num + 2, op=cell, primitives=self.primitives)

            C_prev = C_curr

        lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
        pooling = nn.AdaptiveAvgPool2d(1)
        classifier = nn.Linear(C_prev, self.config['num_classes'])

        self.add_node(cell_num + 3, op=lastact, transform=lambda x: x[0], type='postprocessing_nb201')
        self.add_node(cell_num + 4, op=pooling, transform=lambda x: x[0], type='pooling')
        self.add_node(cell_num + 5, op=classifier, transform=lambda x: x[0].view(x[0].size(0), -1),
                      type='output')

        # Edges
        for i in range(1, cell_num + 6):
            self.add_edge(i - 1, i, type='input', desc='previous')

    def sample(self, same_cell_struct=True, n_ops_per_edge=1,
               n_input_edges=None, dist=None, seed=1):
        """
        same_cell_struct:
            True; whether all sampled cells share the same topology
        n_ops_per_edge:
            1; number of operations sampled per edge in a cell
        n_input_edges:
            None; list whose length equals the number of intermediate
            nodes; determines the number of predecessor nodes for each of them
        dist:
            None; distribution from which the operations on the edges are sampled
        seed:
            1; random seed
        """
        # create a new graph that we will discretize
        new_graph = MacroGraph(self.config, self.primitives, self.ops_dict)
        np.random.seed(seed)
        seeds = {'normal': seed + 1, 'reduction': seed + 2}

        for node in new_graph:
            cell = new_graph.get_node_op(node)
            if not isinstance(cell, Cell):
                continue

            if same_cell_struct:
                np.random.seed(seeds[new_graph.get_node_type(node)])

            for edge in cell.edges:
                op_choices = cell.get_edge_op_choices(*edge)
                sampled_op = np.random.choice(op_choices, n_ops_per_edge,
                                              False, p=dist)
                cell[edge[0]][edge[1]]['op_choices'] = [*sampled_op]

            if n_input_edges is not None:
                for inter_node, k in zip(cell.inter_nodes(), n_input_edges):
                    # in case the start node index is not 0
                    node_idx = list(cell.nodes).index(inter_node)
                    prev_node_choices = list(cell.nodes)[:node_idx]
                    assert k <= len(prev_node_choices), \
                        'cannot sample more than the number of predecessor nodes'

                    sampled_input_edges = np.random.choice(prev_node_choices,
                                                           k, False)
                    for i in set(prev_node_choices) - set(sampled_input_edges):
                        cell.remove_edge(i, inter_node)

        return new_graph
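
    # Usage sketch for sample() (assumptions: `graph` is an already constructed
    # MacroGraph; the argument values below are illustrative, not canonical):
    #
    #   discrete = graph.sample(same_cell_struct=True, n_ops_per_edge=1,
    #                           n_input_edges=[1, 2, 3], dist=None, seed=42)
    #
    # With dist=None the operations on each edge are drawn uniformly; passing a
    # probability vector over the edge's op_choices biases the sampling.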

    @classmethod
    def from_config(cls, config=None, filename=None):
        with open(filename, 'r') as f:
            graph_dict = yaml.safe_load(f)

        if config is None:
            raise ValueError('No configuration provided')

        graph = cls(config, [])

        graph_type = graph_dict['type']
        edges = [(*eval(e), attr) for e, attr in graph_dict['edges'].items()]
        graph.clear()
        graph.add_edges_from(edges)

        C = config['init_channels']
        C_curr = config['stem_multiplier'] * C

        stem = Stem(C_curr=C_curr)
        C_prev_prev, C_prev, C_curr = C_curr, C_curr, C

        for node, attr in graph_dict['nodes'].items():
            node_type = attr['type']
            if node_type == 'input':
                graph.add_node(node, type='input')
            elif node_type == 'stem':
                graph.add_node(node, op=stem, type='stem')
            elif node_type in ['normal', 'reduction']:
                assert attr['op']['type'] == 'Cell'
                graph.add_node(node,
                               op=Cell.from_config(attr['op'], primitives=attr['op']['primitives'],
                                                   C_prev_prev=C_prev_prev, C_prev=C_prev,
                                                   C=C_curr,
                                                   reduction_prev=graph_dict['nodes'][node - 1]['type'] == 'reduction',
                                                   cell_type=node_type),
                               type=node_type)
                C_prev_prev, C_prev = C_prev, config['channel_multiplier'] * C_curr
            elif node_type == 'pooling':
                pooling = nn.AdaptiveAvgPool2d(1)
                graph.add_node(node, op=pooling, transform=lambda x: x[0],
                               type='pooling')
            elif node_type == 'output':
                classifier = nn.Linear(C_prev, config['num_classes'])
                graph.add_node(node, op=classifier,
                               transform=lambda x: x[0].view(x[0].size(0), -1),
                               type='output')

        return graph

    @staticmethod
    def export_nasbench_201_results_to_dict(information):
        results_dict = {}
        dataset_names = information.get_dataset_names()
        results_dict['arch'] = information.arch_str
        results_dict['datasets'] = dataset_names

        for ida, dataset in enumerate(dataset_names):
            dataset_results = {}
            dataset_results['dataset'] = dataset

            metric = information.get_compute_costs(dataset)
            flop, param = metric['flops'], metric['params']
            latency, training_time = metric['latency'], metric['T-train@total']
            dataset_results['flop'] = flop
            dataset_results['params (MB)'] = param
            dataset_results['latency (ms)'] = latency * 1000 if latency is not None and latency > 0 else None
            dataset_results['training_time'] = training_time

            train_info = information.get_metrics(dataset, 'train')
            if dataset == 'cifar10-valid':
                valid_info = information.get_metrics(dataset, 'x-valid')
                dataset_results['train_loss'] = train_info['loss']
                dataset_results['train_accuracy'] = train_info['accuracy']

                dataset_results['valid_loss'] = valid_info['loss']
                dataset_results['valid_accuracy'] = valid_info['accuracy']

            elif dataset == 'cifar10':
                test__info = information.get_metrics(dataset, 'ori-test')
                dataset_results['train_loss'] = train_info['loss']
                dataset_results['train_accuracy'] = train_info['accuracy']

                dataset_results['test_loss'] = test__info['loss']
                dataset_results['test_accuracy'] = test__info['accuracy']
            else:
                valid_info = information.get_metrics(dataset, 'x-valid')
                test__info = information.get_metrics(dataset, 'x-test')
                dataset_results['train_loss'] = train_info['loss']
                dataset_results['train_accuracy'] = train_info['accuracy']

                dataset_results['valid_loss'] = valid_info['loss']
                dataset_results['valid_accuracy'] = valid_info['accuracy']

                dataset_results['test_loss'] = test__info['loss']
                dataset_results['test_accuracy'] = test__info['accuracy']
            results_dict[dataset] = dataset_results
        return results_dict

    def query_architecture(self, arch_weights):
        arch_weight_idx_to_parent = {0: 0, 1: 0, 2: 1, 3: 0, 4: 1, 5: 2}
        arch_strs = {
            'cell_normal_from_0_to_1': '',
            'cell_normal_from_0_to_2': '',
            'cell_normal_from_1_to_2': '',
            'cell_normal_from_0_to_3': '',
            'cell_normal_from_1_to_3': '',
            'cell_normal_from_2_to_3': '',
        }
        for arch_weight_idx, (edge_key, edge_weights) in enumerate(arch_weights.items()):
            edge_weights_norm = torch.softmax(edge_weights, dim=-1)
            selected_op_str = PRIMITIVES[edge_weights_norm.argmax()]
            arch_strs[edge_key] = '{}~{}'.format(selected_op_str, arch_weight_idx_to_parent[arch_weight_idx])

        arch_str = '|{}|+|{}|{}|+|{}|{}|{}|'.format(*arch_strs.values())
        if not hasattr(self, 'nasbench_api'):
            self.nasbench_api = API('/home/siemsj/nasbench_201.pth')
        index = self.nasbench_api.query_index_by_arch(arch_str)
        self.nasbench_api.show(index)
        info = self.nasbench_api.query_by_index(index)
        return self.export_nasbench_201_results_to_dict(info)
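
    # A minimal, self-contained sketch of the selection logic in query_architecture
    # above (assumptions: random weights stand in for the learned arch_weights;
    # PRIMITIVES_SKETCH mirrors the usual five NAS-Bench-201 operations):
    #
    #   import torch
    #   PRIMITIVES_SKETCH = ['none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3']
    #   parents = [0, 0, 1, 0, 1, 2]          # parent node of each of the 6 edges
    #   weights = [torch.randn(5) for _ in range(6)]
    #   ops = [PRIMITIVES_SKETCH[torch.softmax(w, dim=-1).argmax().item()] for w in weights]
    #   tokens = ['{}~{}'.format(op, p) for op, p in zip(ops, parents)]
    #   arch_str = '|{}|+|{}|{}|+|{}|{}|{}|'.format(*tokens)
    #   # e.g. '|nor_conv_3x3~0|+|skip_connect~0|none~1|+|nor_conv_1x1~0|avg_pool_3x3~1|nor_conv_3x3~2|'
    #
    # query_index_by_arch then maps such a string to a benchmark index.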
        network = get_cell_based_tiny_net(config)  # create the network from configuration
        network = network.to(device)

        jacobs, labels = get_batch_jacobian(network, x, target, 1, device, args)
        jacobs = jacobs.reshape(jacobs.size(0), -1).cpu().numpy()

        try:
            s = eval_score(jacobs, labels)
        except Exception as e:
            print(e)
            s = np.nan

        scores.append(s)

    best_arch = indices[order_fn(scores)]
    info      = api.query_by_index(best_arch)
    topscores.append(scores[order_fn(scores)])
    chosen.append(best_arch)
    acc.append(info.get_metrics(dset, acc_type)['accuracy'])

    if not args.dataset == 'cifar10' or args.trainval:
        val_acc.append(info.get_metrics(dset, val_acc_type)['accuracy'])

    times.append(time.time()-start)
    runs.set_description(f"acc: {mean(acc if not args.trainval else val_acc):.2f}%")

print(f"Final mean test accuracy: {np.mean(acc)}")
if len(val_acc) > 1:
    print(f"Final mean validation accuracy: {np.mean(val_acc)}")

state = {'accs': acc,
Example #6
class NAS201(ObjectiveFunction):

    def __init__(self, data_dir, task='cifar10-valid', log_scale=True, negative=True,
                 use_12_epochs_result=False,
                 seed=None):
        """
        data_dir: data directory that contains the NAS-Bench-201-v1_1-096897.pth file
        task: the target image classification task. Options: cifar10-valid, cifar100, ImageNet16-120
        log_scale: whether to output the objective on a log scale
        negative: whether to output the objective in negative form
        use_12_epochs_result: whether to use the statistics at the end of the 12th training epoch
                              instead of those at the end of full training
        seed: random seed used to access the trained model performance. Options: 0, 1, 2;
              seed=None selects the seed randomly
        """

        self.api = API(os.path.join(data_dir, 'NAS-Bench-201-v1_1-096897.pth'))
        if isinstance(task, list):
            task = task[0]
        self.task = task
        self.use_12_epochs_result = use_12_epochs_result

        if task == 'cifar10-valid':
            best_val_arch_index = 6111
            best_val_acc = 91.60666665039064 / 100
            best_test_arch_index = 1459
            best_test_acc = 91.52333333333333 / 100
        elif task == 'cifar100':
            best_val_arch_index = 9930
            best_val_acc = 73.49333323567708 / 100
            best_test_arch_index = 9930
            best_test_acc = 73.51333326009114 / 100
        elif task == 'ImageNet16-120':
            best_val_arch_index = 10676
            best_val_acc = 46.766666727701825 / 100
            best_test_arch_index = 857
            best_test_acc = 47.311111097547744 / 100
        else:
            raise NotImplementedError("task " + str(task) + " is not implemented in the dataset.")

        best_val_err = 1. - best_val_acc
        best_test_err = 1. - best_test_acc
        if log_scale:
            best_val_acc = np.log(best_val_acc)
            best_val_err = np.log(best_val_err)
            best_test_err = np.log(best_test_err)
        if negative:
            best_val_err = -best_val_err
            best_test_err = -best_test_err

        self.best_val_err = best_val_err
        self.best_test_err = best_test_err
        self.best_val_acc = best_val_acc
        self.best_test_acc = best_test_acc

        super(NAS201, self).__init__(dim=None, optimum_location=best_test_arch_index, optimal_val=best_test_err,
                                     bounds=None)

        self.log_scale = log_scale
        self.seed = seed
        self.X = []
        self.y_valid_acc = []
        self.y_test_acc = []
        self.costs = []
        self.negative = negative
        # self.optimal_val =   # lowest mean validation error
        # self.y_star_test =   # lowest mean test error

    def _retrieve(self, G, budget, which='eval'):
        #  set random seed for evaluation
        if which == 'test':
            seed = 3
        else:
            seed_list = [777, 888, 999]
            if self.seed is None:
                seed = random.choice(seed_list)
            elif self.seed >= 3:
                seed = self.seed
            else:
                seed = seed_list[self.seed]

        # find architecture index
        arch_str = G.name
        # print(arch_str)

        try:
            arch_index = self.api.query_index_by_arch(arch_str)
            acc_results = self.api.query_by_index(arch_index, self.task, use_12epochs_result=self.use_12_epochs_result,)
            if seed is not None and 3 <= seed < 777:
                # some architectures only contain 1 seed result
                acc_results = self.api.get_more_info(arch_index, self.task, None,
                                                     use_12epochs_result=self.use_12_epochs_result,
                                                     is_random=False)
                val_acc = acc_results['valid-accuracy'] / 100
                test_acc = acc_results['test-accuracy'] / 100
            else:
                try:
                    acc_results = self.api.get_more_info(arch_index, self.task, None,
                                                         use_12epochs_result=self.use_12_epochs_result,
                                                         is_random=seed)
                    val_acc = acc_results['valid-accuracy'] / 100
                    test_acc = acc_results['test-accuracy'] / 100
                    # val_acc = acc_results[seed].get_eval('x-valid')['accuracy'] / 100
                    # if self.task == 'cifar10-valid':
                    #     test_acc = acc_results[seed].get_eval('ori-test')['accuracy'] / 100
                    # else:
                    #     test_acc = acc_results[seed].get_eval('x-test')['accuracy'] / 100
                except Exception:
                    # some architectures only contain 1 seed result
                    acc_results = self.api.get_more_info(arch_index, self.task, None,
                                                         use_12epochs_result=self.use_12_epochs_result,
                                                         is_random=False)
                    val_acc = acc_results['valid-accuracy'] / 100
                    test_acc = acc_results['test-accuracy'] / 100

            auxiliary_info = self.api.query_meta_info_by_index(arch_index,
                                                               use_12epochs_result=self.use_12_epochs_result)
            cost_info = auxiliary_info.get_compute_costs(self.task)

            # auxiliary cost results such as number of flops and number of parameters
            cost_results = {'flops': cost_info['flops'], 'params': cost_info['params'],
                            'latency': cost_info['latency']}

        except FileNotFoundError:
            val_acc = 0.01
            test_acc = 0.01
            print('missing arch info')
            cost_results = {'flops': None, 'params': None,
                            'latency': None}

        # store val and test performance + auxiliary cost information
        self.X.append(arch_str)
        self.y_valid_acc.append(val_acc)
        self.y_test_acc.append(test_acc)
        self.costs.append(cost_results)

        if which == 'eval':
            err = 1 - val_acc
        elif which == 'test':
            err = 1 - test_acc
        else:
            raise ValueError("Unknown query parameter: which = " + str(which))

        if self.log_scale:
            y = np.log(err)
        else:
            y = err
        if self.negative:
            y = -y
        return y

    def eval(self, G, budget=199, n_repeat=1):
        # G is a single graph; when n_repeat > 1 the retrieved values are averaged
        if n_repeat == 1:
            return self._retrieve(G, budget, 'eval'),  [np.nan]
        return np.mean(np.array([self._retrieve(G, budget, 'eval') for _ in range(n_repeat)])), [np.nan]

    def test(self, G, budget=199, n_repeat=1):
        return np.mean(np.array([self._retrieve(G, budget, 'test') for _ in range(n_repeat)]))

    def get_results(self, ignore_invalid_configs=False):

        regret_validation = []
        regret_test = []
        costs = []
        model_graph_specs = []

        inc_valid = 0
        inc_test = 0

        for i in range(len(self.X)):

            if inc_valid < self.y_valid_acc[i]:
                inc_valid = self.y_valid_acc[i]
                inc_test = self.y_test_acc[i]

            regret_validation.append(float(self.best_val_acc - inc_valid))
            regret_test.append(float(self.best_test_acc - inc_test))
            model_graph_specs.append(self.X[i])
            costs.append(self.costs[i])

        res = dict()
        res['regret_validation'] = regret_validation
        res['regret_test'] = regret_test
        res['costs'] = costs
        res['model_graph_specs'] = model_graph_specs

        return res

    @staticmethod
    def get_configuration_space():
        # for unpruned graph
        cs = ConfigSpace.ConfigurationSpace()

        ops_choices = ['nor_conv_3x3', 'nor_conv_1x1', 'avg_pool_3x3', 'skip_connect', 'none']
        for i in range(6):
            cs.add_hyperparameter(ConfigSpace.CategoricalHyperparameter("edge_%d" % i, ops_choices))
        return cs
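

# Usage sketch (assumptions: './data' holds the NAS-Bench-201-v1_1-096897.pth file,
# and the queried graph object only needs a `.name` attribute carrying an
# architecture string, so a tiny stand-in class is used here for illustration):
if __name__ == '__main__':
    class _ArchGraph:
        def __init__(self, name):
            self.name = name

    bench = NAS201(data_dir='./data', task='cifar10-valid', log_scale=True, negative=True, seed=0)
    g = _ArchGraph('|nor_conv_3x3~0|+|nor_conv_3x3~0|skip_connect~1|'
                   '+|skip_connect~0|none~1|nor_conv_3x3~2|')
    y_val, _ = bench.eval(g)    # negative log validation error
    y_test = bench.test(g)      # negative log test error
    print(y_val, y_test, bench.get_results())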