示例#1
0
def visualize_rank_over_time(meta_file, vis_save_dir):
    """Plot validation ranking versus final test ranking for every epoch.

    For each of the 200 epochs one scatter figure is written to
    ``vis_save_dir`` as ``time-XXX.pdf`` and ``time-XXX.png``. The x axis is
    the position of an architecture when all architectures are sorted
    (ascending) by their test accuracy at the last epoch; the y axis is its
    position when sorted (ascending) by the validation accuracy of the
    plotted epoch.

    The per-epoch accuracies are read from the benchmark once and cached in
    ``vis_save_dir / 'rank-over-time-cache-info.pth'``; later calls load that
    file instead of opening ``meta_file`` again.

    Args:
        meta_file: location of the benchmark file; handed to ``API`` as a
            string.
        vis_save_dir: output directory, a path object supporting ``mkdir``
            and ``/`` (e.g. ``pathlib.Path``). It is created when missing.
    """
    total_epochs = 200
    print('\n' + '-' * 150)
    vis_save_dir.mkdir(parents=True, exist_ok=True)
    print('{:} start to visualize rank-over-time into {:}'.format(
        time_string(), vis_save_dir))
    cache_file_path = vis_save_dir / 'rank-over-time-cache-info.pth'
    if not cache_file_path.exists():
        print('Do not find cache file : {:}'.format(cache_file_path))
        nas_bench = API(str(meta_file))
        print('{:} load nas_bench done'.format(time_string()))
        params, flops = [], []
        # epoch -> list of accuracies, one entry per architecture index
        train_accs, valid_accs = defaultdict(list), defaultdict(list)
        test_accs, otest_accs = defaultdict(list), defaultdict(list)
        for index in tqdm(range(len(nas_bench))):
            arch_info = nas_bench.query_by_index(index,
                                                 use_12epochs_result=False)
            # The cost query takes no epoch argument, so it is issued once
            # per architecture rather than inside the epoch loop.
            costs = arch_info.get_comput_costs('cifar10')
            flops.append(costs['flops'])
            params.append(costs['params'])
            for iepoch in range(total_epochs):
                train_res = arch_info.get_metrics('cifar10', 'train', iepoch)
                valid_res = arch_info.get_metrics('cifar10-valid', 'x-valid',
                                                  iepoch)
                # The 'ori-test' split feeds both the test and the
                # original-test series, so a single query serves both.
                test_res = arch_info.get_metrics('cifar10', 'ori-test',
                                                 iepoch)
                train_accs[iepoch].append(train_res['accuracy'])
                valid_accs[iepoch].append(valid_res['accuracy'])
                test_accs[iepoch].append(test_res['accuracy'])
                otest_accs[iepoch].append(test_res['accuracy'])
        cache = {
            'params': params,
            'flops': flops,
            'train_accs': train_accs,
            'valid_accs': valid_accs,
            'test_accs': test_accs,
            'otest_accs': otest_accs,
        }
        torch.save(cache, cache_file_path)
    else:
        print('Find cache file : {:}'.format(cache_file_path))
        # NOTE: torch.load unpickles the file; only reuse cache files that
        # were written by this function.
        cache = torch.load(cache_file_path)
        valid_accs, test_accs = cache['valid_accs'], cache['test_accs']
    print('{:} collect data done.'.format(time_string()))

    selected_epochs = list(range(total_epochs))  # one figure per epoch
    x_xtests = test_accs[total_epochs - 1]
    indexes = list(range(len(x_xtests)))
    # Architecture indices ordered by final test accuracy (ascending).
    ord_idxs = sorted(indexes, key=lambda i: x_xtests[i])

    # Figure settings are identical for every epoch.
    dpi, width, height = 300, 2600, 2600
    figsize = width / float(dpi), height / float(dpi)
    LabelSize, LegendFontsize = 18, 18
    lowest, highest = min(indexes), max(indexes)
    # A step of 0 would make np.arange fail when there are fewer than seven
    # architectures.
    ticks = np.arange(lowest, highest, max(1, highest // 6))

    for sepoch in selected_epochs:
        x_valids = valid_accs[sepoch]
        valid_ord_idxs = sorted(indexes, key=lambda i: x_valids[i])
        # Map each architecture to its validation rank once, instead of a
        # linear list.index() search per architecture.
        valid_rank = {arch: rank for rank, arch in enumerate(valid_ord_idxs)}
        valid_ord_lbls = [valid_rank[arch] for arch in ord_idxs]

        fig = plt.figure(figsize=figsize)
        ax = fig.add_subplot(111)
        plt.xlim(lowest, highest)
        plt.ylim(lowest, highest)
        plt.yticks(ticks, fontsize=LegendFontsize, rotation='vertical')
        plt.xticks(ticks, fontsize=LegendFontsize)
        ax.scatter(indexes,
                   valid_ord_lbls,
                   marker='^',
                   s=0.5,
                   c='tab:green',
                   alpha=0.8)
        ax.scatter(indexes,
                   indexes,
                   marker='o',
                   s=0.5,
                   c='tab:blue',
                   alpha=0.8)
        # Off-canvas points drawn at a readable size; they exist only to
        # provide the legend entries.
        ax.scatter([-1], [-1],
                   marker='^',
                   s=100,
                   c='tab:green',
                   label='CIFAR-10 validation')
        ax.scatter([-1], [-1],
                   marker='o',
                   s=100,
                   c='tab:blue',
                   label='CIFAR-10 test')
        plt.grid(zorder=0)
        ax.set_axisbelow(True)
        plt.legend(loc='upper left', fontsize=LegendFontsize)
        ax.set_xlabel('architecture ranking in the final test accuracy',
                      fontsize=LabelSize)
        ax.set_ylabel('architecture ranking in the validation set',
                      fontsize=LabelSize)
        save_path = (vis_save_dir / 'time-{:03d}.pdf'.format(sepoch)).resolve()
        fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='pdf')
        save_path = (vis_save_dir / 'time-{:03d}.png'.format(sepoch)).resolve()
        fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png')
        print('{:} save into {:}'.format(time_string(), save_path))
        # Release the figure before the next epoch to keep memory bounded.
        plt.close('all')
示例#2
0
def _save_accuracy_scatter(vis_save_dir, file_stem, xs, ys, resnet_xy,
                           xticks, ylim, yticks, xlabel, ylabel, xlim=None):
    """Draw one accuracy scatter figure and save it as PDF and PNG.

    Args:
        vis_save_dir: output directory (path object supporting ``/``).
        file_stem: file name without extension.
        xs, ys: coordinates of the per-architecture points.
        resnet_xy: ``(x, y)`` of the highlighted 'resnet' point, or ``None``
            to draw neither the marker nor the legend.
        xticks, yticks: tick positions for the two axes.
        ylim: ``(low, high)`` limits of the y axis.
        xlabel, ylabel: axis labels.
        xlim: optional ``(low, high)`` limits of the x axis; when ``None``
            the x range is left to matplotlib.
    """
    dpi, width, height = 300, 2600, 2600
    figsize = width / float(dpi), height / float(dpi)
    LabelSize, LegendFontsize = 22, 22
    # One shared marker style so the highlighted point looks the same in
    # every figure.
    resnet_scale, resnet_alpha = 120, 0.5

    fig = plt.figure(figsize=figsize)
    ax = fig.add_subplot(111)
    if xlim is not None:
        plt.xlim(*xlim)
    plt.xticks(xticks, fontsize=LegendFontsize)
    plt.ylim(*ylim)
    plt.yticks(yticks, fontsize=LegendFontsize)
    ax.scatter(xs, ys, marker='o', s=0.5, c='tab:blue')
    if resnet_xy is not None:
        ax.scatter([resnet_xy[0]], [resnet_xy[1]],
                   marker='*',
                   s=resnet_scale,
                   c='tab:orange',
                   label='resnet',
                   alpha=resnet_alpha)
    plt.grid()
    ax.set_axisbelow(True)
    if resnet_xy is not None:
        plt.legend(loc=4, fontsize=LegendFontsize)
    ax.set_xlabel(xlabel, fontsize=LabelSize)
    ax.set_ylabel(ylabel, fontsize=LabelSize)
    for suffix in ('pdf', 'png'):
        save_path = (vis_save_dir /
                     '{:}.{:}'.format(file_stem, suffix)).resolve()
        fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format=suffix)
    print('{:} save into {:}'.format(time_string(), save_path))


def visualize_info(meta_file, dataset, vis_save_dir):
    """Plot accuracy-versus-size and accuracy-versus-ID figures for a dataset.

    Four scatter figures are written to ``vis_save_dir``, each as PDF and
    PNG: parameters vs. validation accuracy, parameters vs. test accuracy,
    parameters vs. training accuracy, and architecture index vs. test
    accuracy. The architecture at index 11472 is highlighted with a star
    labelled 'resnet' when the benchmark contains that index.

    The collected numbers are cached in
    ``vis_save_dir / '<dataset>-cache-info.pth'``; later calls load that file
    instead of opening ``meta_file`` again.

    Args:
        meta_file: location of the benchmark file; handed to ``API`` as a
            string.
        dataset: dataset name used for the benchmark queries. 'cifar10' and
            'cifar100' get dedicated y-axis ranges; any other name uses the
            fallback ranges.
        vis_save_dir: output directory, a path object supporting ``mkdir``
            and ``/`` (e.g. ``pathlib.Path``). It is created when missing.
    """
    resnet_index = 11472
    print('{:} start to visualize {:} information'.format(
        time_string(), dataset))
    # The cache and the figures are written here, so the directory has to
    # exist before anything is saved.
    vis_save_dir.mkdir(parents=True, exist_ok=True)
    cache_file_path = vis_save_dir / '{:}-cache-info.pth'.format(dataset)
    if not cache_file_path.exists():
        print('Do not find cache file : {:}'.format(cache_file_path))
        nas_bench = API(str(meta_file))
        params, flops = [], []
        train_accs, valid_accs, test_accs, otest_accs = [], [], [], []
        # Stays None when the benchmark has no architecture at resnet_index.
        resnet = None
        # For cifar10 the validation numbers are queried from the
        # 'cifar10-valid' entry and the test numbers from the 'ori-test'
        # split; other datasets use their own 'x-valid' / 'x-test' splits.
        if dataset == 'cifar10':
            valid_query = ('cifar10-valid', 'x-valid')
            test_query = ('cifar10', 'ori-test')
        else:
            valid_query = (dataset, 'x-valid')
            test_query = (dataset, 'x-test')
        for index in range(len(nas_bench)):
            arch_info = nas_bench.query_by_index(index,
                                                 use_12epochs_result=False)
            costs = arch_info.get_comput_costs(dataset)
            flop, param = costs['flops'], costs['params']
            train_acc = arch_info.get_metrics(dataset, 'train')['accuracy']
            valid_acc = arch_info.get_metrics(*valid_query)['accuracy']
            test_acc = arch_info.get_metrics(*test_query)['accuracy']
            otest_acc = arch_info.get_metrics(dataset,
                                              'ori-test')['accuracy']
            if index == resnet_index:
                resnet = {
                    'params': param,
                    'flops': flop,
                    'index': resnet_index,
                    'train_acc': train_acc,
                    'valid_acc': valid_acc,
                    'test_acc': test_acc,
                    'otest_acc': otest_acc,
                }
            flops.append(flop)
            params.append(param)
            train_accs.append(train_acc)
            valid_accs.append(valid_acc)
            test_accs.append(test_acc)
            otest_accs.append(otest_acc)
        info = {
            'params': params,
            'flops': flops,
            'train_accs': train_accs,
            'valid_accs': valid_accs,
            'test_accs': test_accs,
            'otest_accs': otest_accs,
            'resnet': resnet,
        }
        torch.save(info, cache_file_path)
    else:
        print('Find cache file : {:}'.format(cache_file_path))
        # NOTE: torch.load unpickles the file; only reuse cache files that
        # were written by this function.
        info = torch.load(cache_file_path)
        params, train_accs = info['params'], info['train_accs']
        valid_accs, test_accs = info['valid_accs'], info['test_accs']
        resnet = info['resnet']
    print('{:} collect data done.'.format(time_string()))

    def resnet_point(x_key, y_key):
        """Return the highlighted point for a figure, or None if absent."""
        if resnet is None:
            return None
        return resnet[x_key], resnet[y_key]

    # y-axis range and ticks: one pair for the validation/test figures and
    # one for the training figure.
    if dataset == 'cifar10':
        eval_ylim, eval_yticks = (50, 100), np.arange(50, 101, 10)
        train_ylim, train_yticks = (50, 100), np.arange(50, 101, 10)
    elif dataset == 'cifar100':
        eval_ylim, eval_yticks = (25, 75), np.arange(25, 76, 10)
        train_ylim, train_yticks = (20, 100), np.arange(20, 101, 10)
    else:
        eval_ylim, eval_yticks = (0, 50), np.arange(0, 51, 10)
        train_ylim, train_yticks = (25, 76), np.arange(25, 76, 10)
    param_xticks = np.arange(0, 1.6, 0.3)

    _save_accuracy_scatter(
        vis_save_dir, '{:}-param-vs-valid'.format(dataset),
        params, valid_accs, resnet_point('params', 'valid_acc'),
        xticks=param_xticks, ylim=eval_ylim, yticks=eval_yticks,
        xlabel='#parameters (MB)', ylabel='the validation accuracy (%)')

    _save_accuracy_scatter(
        vis_save_dir, '{:}-param-vs-test'.format(dataset),
        params, test_accs, resnet_point('params', 'test_acc'),
        xticks=param_xticks, ylim=eval_ylim, yticks=eval_yticks,
        xlabel='#parameters (MB)', ylabel='the test accuracy (%)')

    _save_accuracy_scatter(
        vis_save_dir, '{:}-param-vs-train'.format(dataset),
        params, train_accs, resnet_point('params', 'train_acc'),
        xticks=param_xticks, ylim=train_ylim, yticks=train_yticks,
        xlabel='#parameters (MB)', ylabel='the training accuracy (%)')

    indexes = list(range(len(params)))
    first_index, last_index = min(indexes), max(indexes)
    # A step of 0 would make np.arange fail when there are fewer than six
    # architectures.
    id_xticks = np.arange(first_index, last_index, max(1, last_index // 5))
    _save_accuracy_scatter(
        vis_save_dir, '{:}-test-over-ID'.format(dataset),
        indexes, test_accs, resnet_point('index', 'test_acc'),
        xticks=id_xticks, ylim=eval_ylim, yticks=eval_yticks,
        xlabel='architecture ID', ylabel='the test accuracy (%)',
        xlim=(0, last_index))
    plt.close('all')