コード例 #1
0
def show_multi_trial(search_space):
    """Print mean/std validation and test accuracy of every multi-trial
    algorithm on the standard dataset/time-budget pairs of *search_space*.

    Args:
        search_space: either 'tss' or 'sss'.

    Raises:
        ValueError: if *search_space* is not one of the two known names.
    """
    api = create(None, search_space, fast_mode=True, verbose=False)

    def show(dataset):
        print('show {:} on {:} done.'.format(dataset, search_space))
        # dataset is encoded as '<name>-T<time budget>' e.g. 'cifar10-T20000'
        xdataset, max_time = dataset.split('-T')
        alg2data = fetch_data(search_space=search_space, dataset=dataset)
        for idx, (alg, data) in enumerate(alg2data.items()):

            valid_accs, test_accs = [], []
            for _, x in data.items():
                v_acc, t_acc = query_performance(api, x, xdataset,
                                                 float(max_time))
                valid_accs.append(v_acc)
                test_accs.append(t_acc)
            # raw strings: '$\pm$' is LaTeX markup, not a Python escape
            # (plain '\p' raises SyntaxWarning on Python 3.12+)
            valid_str = r'{:.2f}$\pm${:.2f}'.format(np.mean(valid_accs),
                                                    np.std(valid_accs))
            test_str = r'{:.2f}$\pm${:.2f}'.format(np.mean(test_accs),
                                                   np.std(test_accs))
            print('{:} plot alg : {:10s}  | validation = {:} | test = {:}'.
                  format(time_string(), alg, valid_str, test_str))

    if search_space == 'tss':
        datasets = [
            'cifar10-T20000', 'cifar100-T40000', 'ImageNet16-120-T120000'
        ]
    elif search_space == 'sss':
        datasets = [
            'cifar10-T20000', 'cifar100-T40000', 'ImageNet16-120-T60000'
        ]
    else:
        raise ValueError('Unknown search space: {:}'.format(search_space))
    for dataset in datasets:
        show(dataset)
    print('{:} complete show multi-trial results.\n'.format(time_string()))
コード例 #2
0
 def make_nats(path_full: str,
               path_save: str = None) -> MiniNATSBenchTabularBenchmark:
     """Load the full NATS-Bench topology API and condense it to a mini
     benchmark; optionally serialize the result to *path_save*."""
     full_api = create(replace_standard_paths(path_full),
                       'tss',
                       fast_mode=True,
                       verbose=True)
     benchmark = MiniNATSBenchTabularBenchmark.make_from_full_api(full_api)
     if isinstance(path_save, str):
         benchmark.save(path_save)
     return benchmark
コード例 #3
0
ファイル: natsbench.py プロジェクト: ain-soph/trojanzoo
    def __init__(self,
                 name: str = 'nats_bench',
                 model: type[_NATSbench] = _NATSbench,
                 model_index: int = 0,
                 model_seed: int = 999,
                 hp: int = 200,
                 dataset: ImageSet | None = None,
                 dataset_name: str | None = None,
                 nats_path: str | None = None,
                 search_space: str = 'tss',
                 **kwargs):
        """Wrap a NATS-Bench architecture as a model-zoo entry.

        Args:
            name: registry name of this model wrapper.
            model: concrete model class to instantiate.
            model_index: architecture index inside the benchmark.
            model_seed: training seed whose checkpoint is queried.
            hp: training epochs of the queried checkpoint (e.g. 200).
            dataset: optional dataset object; supplies ``dataset_name``
                when the latter is not given.
            dataset_name: benchmark dataset name (e.g. 'cifar10').
            nats_path: path of the NATS-Bench data files (None = default).
            search_space: 'tss' (topology) or 'sss' (size).

        Raises:
            ImportError: when nats_bench or xautodl is not installed.
        """
        try:
            # pip install nats_bench
            from nats_bench import create  # type: ignore
            from xautodl.models import get_cell_based_tiny_net  # type: ignore
        except ImportError as e:
            # chain the original error so the missing package name stays visible
            raise ImportError(
                'You need to install nats_bench and auto-dl library') from e

        if isinstance(dataset, ImageSet):
            kwargs['dataset'] = dataset
            if dataset_name is None:
                dataset_name = dataset.name
                if dataset_name == 'imagenet16':
                    dataset_name = f'imagenet16-{dataset.num_classes:d}'
        assert dataset_name is not None
        # NATS-Bench spells the dataset 'ImageNet16-120'
        dataset_name = dataset_name.replace('imagenet16', 'ImageNet16')
        self.dataset_name = dataset_name

        self.model_index = model_index
        self.model_seed = model_seed
        self.hp = hp
        self.search_space = search_space
        self.nats_path = nats_path

        self.api = create(nats_path,
                          search_space,
                          fast_mode=True,
                          verbose=False)
        config: dict[str,
                     Any] = self.api.get_net_config(model_index, dataset_name)
        self.get_cell_based_tiny_net: Callable[
            ..., nn.Module] = get_cell_based_tiny_net
        network = self.get_cell_based_tiny_net(config)
        super().__init__(name=name, model=model, network=network, **kwargs)
        self.param_list['nats_bench'] = [
            'model_index', 'model_seed', 'hp', 'search_space', 'nats_path'
        ]
        self._model: _NATSbench
コード例 #4
0
    def __init__(self, filepath=None) -> None:
        """Open the NATS-Bench topology API, downloading the data if needed.

        Args:
            filepath: directory holding the benchmark files; defaults to
                ``NATSBENCH_TFRECORD``.
        """
        if filepath is None:
            filepath = NATSBENCH_TFRECORD
        # exist_ok avoids a race between a separate isdir check and mkdir
        os.makedirs(DATA_ROOT, exist_ok=True)

        if not os.path.isdir(filepath):
            print("Downloading NATSBench Data.")
            download_file_from_google_drive(file_id, filepath + ".tar")
            print("Downloaded, extracting.")
            untar_file(filepath + ".tar", DATA_ROOT)

        # use os.path.join instead of manual '/' concatenation
        self.api = create(os.path.join(filepath, NATSBENCH_NAME),
                          "tss",
                          fast_mode=True,
                          verbose=False)
コード例 #5
0
    def __init__(
            self,
            name: str = 'natsbench',
            model: type[_NATSbench] = _NATSbench,
            model_index: int | None = None,
            model_seed: int | None = None,
            dataset: ImageSet | None = None,
            dataset_name: str | None = None,
            nats_path: str = '/data/rbp5354/nats/NATS-tss-v1_0-3ffb9-full',
            autodl_path: str = '/home/rbp5354/workspace/XAutoDL/lib',
            search_space: str = 'tss',
            **kwargs):
        """Instantiate a NATS-Bench architecture as a model-zoo entry.

        Args:
            name: registry name of this model wrapper.
            model: concrete model class to instantiate.
            model_index: architecture index inside the benchmark.
            model_seed: training seed whose checkpoint is queried.
            dataset: optional dataset; supplies ``dataset_name`` if unset.
            dataset_name: benchmark dataset name (e.g. 'cifar10').
            nats_path: path of the NATS-Bench data files.
            autodl_path: source checkout of XAutoDL added to sys.path.
            search_space: 'tss' (topology) or 'sss' (size).

        Raises:
            ImportError: when nats_bench or the XAutoDL sources are missing.
        """
        try:
            import sys
            # XAutoDL is consumed from a source checkout, not a package
            sys.path.append(autodl_path)
            from nats_bench import create  # type: ignore
            from models import get_cell_based_tiny_net  # type: ignore
        except ImportError:
            print('You need to install nats_bench and auto-dl library')
            raise  # bare raise keeps the original traceback intact

        if dataset is not None:
            assert isinstance(dataset, ImageSet)
            kwargs['dataset'] = dataset
            if dataset_name is None:
                dataset_name = dataset.name
        assert dataset_name is not None
        self.dataset_name = dataset_name

        self.search_space = search_space
        self.model_index = model_index
        self.model_seed = model_seed

        self.api = create(nats_path,
                          search_space,
                          fast_mode=True,
                          verbose=False)
        config: dict[str,
                     Any] = self.api.get_net_config(model_index, dataset_name)
        network: nn.Module = get_cell_based_tiny_net(config)
        super().__init__(name=name, model=model, network=network, **kwargs)
        self.param_list['natsbench'] = [
            'model_index', 'model_seed', 'search_space'
        ]
コード例 #6
0
        description="NATS-Bench", formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument(
        "--save_dir",
        type=str,
        default="output/vis-nas-bench",
        help="Folder to save checkpoints and log.",
    )
    # parse the command-line options of this visualization script
    args = parser.parse_args()

    to_save_dir = Path(args.save_dir)

    # the three datasets covered by NATS-Bench
    datasets = ["cifar10", "cifar100", "ImageNet16-120"]
    # Figure 3 (a-c)
    api_tss = create(None, "tss", verbose=True)
    for xdata in datasets:
        visualize_tss_info(api_tss, xdata, to_save_dir)
    # Figure 3 (d-f)
    api_sss = create(None, "size", verbose=True)
    for xdata in datasets:
        visualize_sss_info(api_sss, xdata, to_save_dir)

    # Figure 2
    visualize_relative_info(None, to_save_dir, "tss")
    visualize_relative_info(None, to_save_dir, "sss")

    # Figure 4
    visualize_rank_info(None, to_save_dir, "tss")
    visualize_rank_info(None, to_save_dir, "sss")
コード例 #7
0
ファイル: nasbench.py プロジェクト: Mirofil/SOTL_NAS
from nats_bench import create
import os
# point torch's download/cache directory at the notebook storage volume
os.environ["TORCH_HOME"] = '/notebooks/storage/.torch'
# Create the API instance for the topology search space in NATS
api = create(None, 'tss', fast_mode=True, verbose=True)

# Query cost and performance info of architecture #1234 on CIFAR-10.
info = api.get_more_info(1234, 'cifar10')

# Query the flops, params, latency. info is a dict.
info = api.get_cost_info(12, 'cifar10')

# Simulate the training of the 1224-th candidate:
validation_accuracy, latency, time_cost, current_total_time_cost = api.simulate_train_eval(
    1224, dataset='cifar10', hp='12')
コード例 #8
0
if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='NAS-Bench-X',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--save_dir',
                        type=str,
                        default='output/vis-nas-bench',
                        help='Folder to save checkpoints and log.')
    # parse the command-line options of this visualization script
    args = parser.parse_args()

    to_save_dir = Path(args.save_dir)

    datasets = ['cifar10', 'cifar100', 'ImageNet16-120']
    # topology search space (NAS-Bench-201 style)
    api201 = create(None, 'tss', verbose=True)
    for xdata in datasets:
        visualize_tss_info(api201, xdata, to_save_dir)

    # size search space
    api301 = create(None, 'size', verbose=True)
    for xdata in datasets:
        visualize_sss_info(api301, xdata, to_save_dir)

    visualize_info(None, to_save_dir, 'tss')
    visualize_info(None, to_save_dir, 'sss')
    visualize_rank_info(None, to_save_dir, 'tss')
    visualize_rank_info(None, to_save_dir, 'sss')

    visualize_all_rank_info(None, to_save_dir, 'tss')
    visualize_all_rank_info(None, to_save_dir, 'sss')
コード例 #9
0
def main(search_space, meta_file: str, weight_dir, save_dir, xdata):
    """Plot weight-watcher ranking against test-accuracy ranking.

    Args:
        search_space: NATS-Bench search space name.
        meta_file: path of the benchmark meta file passed to ``create``.
        weight_dir: directory with trained weights, consumed by ``evaluate``.
        save_dir: Path where the PDF/PNG figures are written.
        xdata: dataset name used for evaluation and in the figure file name.
    """
    save_dir.mkdir(parents=True, exist_ok=True)
    api = create(meta_file, search_space, verbose=False)
    datasets = ["cifar10-valid", "cifar10", "cifar100", "ImageNet16-120"]
    print(time_string() + " " + "=" * 50)
    # NOTE: 'avaliable_hps' is the (misspelled) attribute name exposed by
    # nats_bench itself; it only depends on the api, so fetch it once.
    hps = api.avaliable_hps
    for data in datasets:
        for hp in hps:
            nums = api.statistics(data, hp=hp)
            total = sum(k * v for k, v in nums.items())
            print(
                "Using {:3s} epochs, trained on {:20s} : {:} trials in total ({:}).".format(
                    hp, data, total, nums
                )
            )
    print(time_string() + " " + "=" * 50)

    norms, accuracies = evaluate(api, weight_dir, xdata)

    # rank architectures by weight-watcher norm and by accuracy;
    # labels[i] is the accuracy rank of the architecture at norm rank i
    indexes = list(range(len(norms)))
    norm_indexes = sorted(indexes, key=lambda i: norms[i])
    accy_indexes = sorted(indexes, key=lambda i: accuracies[i])
    labels = [accy_indexes.index(index) for index in norm_indexes]

    dpi, width, height = 200, 1400, 800
    figsize = width / float(dpi), height / float(dpi)
    LabelSize, LegendFontsize = 18, 12

    fig = plt.figure(figsize=figsize)
    ax = fig.add_subplot(111)
    plt.xlim(min(indexes), max(indexes))
    plt.ylim(min(indexes), max(indexes))
    plt.yticks(
        np.arange(min(indexes), max(indexes), max(indexes) // 3),
        fontsize=LegendFontsize,
        rotation="vertical",
    )
    plt.xticks(
        np.arange(min(indexes), max(indexes), max(indexes) // 5),
        fontsize=LegendFontsize,
    )
    ax.scatter(indexes, labels, marker="*", s=0.5, c="tab:red", alpha=0.8)
    ax.scatter(indexes, indexes, marker="o", s=0.5, c="tab:blue", alpha=0.8)
    # off-canvas points with large markers give the legend readable glyphs
    ax.scatter([-1], [-1], marker="o", s=100, c="tab:blue", label="Test accuracy")
    ax.scatter([-1], [-1], marker="*", s=100, c="tab:red", label="Weight watcher")
    plt.grid(zorder=0)
    ax.set_axisbelow(True)
    plt.legend(loc=0, fontsize=LegendFontsize)
    ax.set_xlabel(
        "architecture ranking sorted by the test accuracy ", fontsize=LabelSize
    )
    ax.set_ylabel("architecture ranking computed by weight watcher", fontsize=LabelSize)
    save_path = (save_dir / "{:}-{:}-test-ww.pdf".format(search_space, xdata)).resolve()
    fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="pdf")
    save_path = (save_dir / "{:}-{:}-test-ww.png".format(search_space, xdata)).resolve()
    fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="png")
    print("{:} save into {:}".format(time_string(), save_path))

    print("{:} finish this test.".format(time_string()))
コード例 #10
0
ファイル: nats_pgd.py プロジェクト: LiuBoyang93/trojanzoo
    trojanvision.trainer.add_argument(parser)
    trojanvision.marks.add_argument(parser)
    trojanvision.attacks.add_argument(parser)
    args = parser.parse_args()

    # build environment, data, model, trainer, mark and attack from CLI args
    env = trojanvision.environ.create(**args.__dict__)
    dataset = trojanvision.datasets.create(**args.__dict__)
    model = trojanvision.models.create(dataset=dataset, **args.__dict__)
    trainer = trojanvision.trainer.create(dataset=dataset, model=model, **args.__dict__)
    mark = trojanvision.marks.create(dataset=dataset, **args.__dict__)
    attack = trojanvision.attacks.create(dataset=dataset, model=model, mark=mark, **args.__dict__)

    if env['verbose']:
        summary(env=env, dataset=dataset, model=model, mark=mark, trainer=trainer, attack=attack)

    # NOTE(review): machine-specific benchmark path — confirm before reuse
    api = create('/data/rbp5354/nats/NATS-tss-v1_0-3ffb9-full', 'tss', fast_mode=True, verbose=False)

    counter = 0
    succ_rate_list: list[float] = []
    avg_iter_list: list[float] = []
    for idx in range(5000):
        info = api.get_more_info(idx, 'cifar10', hp="200")
        test_acc: float = info['test-accuracy']
        valid_acc, latency, _, _ = api.simulate_train_eval(
            idx, dataset='cifar10', hp='200')
        if test_acc > 92:
            print(f'{counter+1:<5d} {idx=:<5d} {test_acc=:<10.3f} {valid_acc=:<10.3f}')
            args.model_index = idx
            model = trojanvision.models.create(dataset=dataset, **args.__dict__)
            loss, real_acc = model._validate(indent=8)
            if real_acc <= 92:
コード例 #11
0
    if not sss_or_tss:
        arch_str = '|nor_conv_3x3~0|+|nor_conv_3x3~0|avg_pool_3x3~1|+|skip_connect~0|nor_conv_3x3~1|skip_connect~2|'
        matrix = api.str2matrix(arch_str)
        print('Compute the adjacency matrix of {:}'.format(arch_str))
        print(matrix)
    info = api.simulate_train_eval(123, 'cifar10')
    print('simulate_train_eval : {:}\n\n'.format(info))


if __name__ == '__main__':

    # api201 = create('./output/NATS-Bench-topology/process-FULL', 'topology', fast_mode=True, verbose=True)
    # exercise the topology search space under every (fast_mode, verbose) pair
    for fast_mode in [True, False]:
        for verbose in [True, False]:
            api_nats_tss = create(None,
                                  'tss',
                                  fast_mode=fast_mode,
                                  # was hard-coded True, ignoring the loop var
                                  verbose=verbose)
            print('{:} create with fast_mode={:} and verbose={:}'.format(
                time_string(), fast_mode, verbose))
            test_api(api_nats_tss, False)
            del api_nats_tss
            gc.collect()

    for fast_mode in [True, False]:
        for verbose in [True, False]:
            print('{:} create with fast_mode={:} and verbose={:}'.format(
                time_string(), fast_mode, verbose))
            api_nats_sss = create(None,
                                  'size',
                                  fast_mode=fast_mode,
                                  verbose=verbose)
コード例 #12
0
if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='NATS-Bench',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--save_dir',
                        type=str,
                        default='output/vis-nas-bench',
                        help='Folder to save checkpoints and log.')
    # parse the command-line options of this visualization script
    args = parser.parse_args()

    to_save_dir = Path(args.save_dir)

    # the three datasets covered by NATS-Bench
    datasets = ['cifar10', 'cifar100', 'ImageNet16-120']
    # Figure 3 (a-c)
    api_tss = create(None, 'tss', verbose=True)
    for xdata in datasets:
        visualize_tss_info(api_tss, xdata, to_save_dir)
    # Figure 3 (d-f)
    api_sss = create(None, 'size', verbose=True)
    for xdata in datasets:
        visualize_sss_info(api_sss, xdata, to_save_dir)

    # Figure 2
    visualize_relative_info(None, to_save_dir, 'tss')
    visualize_relative_info(None, to_save_dir, 'sss')

    # Figure 4
    visualize_rank_info(None, to_save_dir, 'tss')
    visualize_rank_info(None, to_save_dir, 'sss')
コード例 #13
0
  # Obtain both cost and performance information
  info = api.get_more_info(1234, 'cifar10')
  print('{:}\n'.format(info))
  print('{:} finish testing the api : {:}'.format(time_string(), api))

  if not is_301:
    arch_str = '|nor_conv_3x3~0|+|nor_conv_3x3~0|avg_pool_3x3~1|+|skip_connect~0|nor_conv_3x3~1|skip_connect~2|'
    matrix = api.str2matrix(arch_str)
    print('Compute the adjacency matrix of {:}'.format(arch_str))
    print(matrix)
  info = api.simulate_train_eval(123, 'cifar10')
  print('simulate_train_eval : {:}\n\n'.format(info))


if __name__ == '__main__':

  # api201 = create('./output/NATS-Bench-topology/process-FULL', 'topology', fast_mode=True, verbose=True)
  # exercise both search spaces under every (fast_mode, verbose) combination
  for fast_mode in [True, False]:
    for verbose in [True, False]:
      # pass the loop's verbose through (was hard-coded to True)
      api201 = create(None, 'tss', fast_mode=fast_mode, verbose=verbose)
      print('{:} create with fast_mode={:} and verbose={:}'.format(time_string(), fast_mode, verbose))
      test_api(api201, False)

  for fast_mode in [True, False]:
    for verbose in [True, False]:
      print('{:} create with fast_mode={:} and verbose={:}'.format(time_string(), fast_mode, verbose))
      api301 = create(None, 'size', fast_mode=fast_mode, verbose=verbose)
      print('{:} --->>> {:}'.format(time_string(), api301))
      test_api(api301, True)
コード例 #14
0
                                    dataset='cifar10-valid',
                                    iepoch=11,
                                    hp='200',
                                    is_random=False)
    info = api.query_by_arch(
        '|nor_conv_3x3~0|+|skip_connect~0|nor_conv_3x3~1|+|skip_connect~0|none~1|nor_conv_3x3~2|',
        '200')
    print(info)
    structure = CellStructure.str2structure(
        '|nor_conv_3x3~0|+|skip_connect~0|nor_conv_3x3~1|+|skip_connect~0|none~1|nor_conv_3x3~2|'
    )
    info = api.query_by_arch(structure, '200')
    print(info)


if __name__ == '__main__':

    # regression test against a local NAS-Bench-201 checkpoint
    api201 = create(
        os.path.join(os.environ['TORCH_HOME'],
                     'NAS-Bench-201-v1_0-e61699.pth'), 'topology', True)
    test_issue_81_82(api201)
    print('Test {:} done'.format(api201))

    api201 = create(None, 'topology', True)  # use the default file path
    test_issue_81_82(api201)
    test_api(api201, False)
    print('Test {:} done'.format(api201))

    # third positional argument of create is fast_mode
    api301 = create(None, 'size', True)
    test_api(api301, True)
コード例 #15
0
        ]
    elif search_space == "sss":
        datasets = [
            "cifar10-T20000", "cifar100-T40000", "ImageNet16-120-T60000"
        ]
    else:
        raise ValueError("Unknown search space: {:}".format(search_space))
    for dataset in datasets:
        show(dataset)
    print("{:} complete show multi-trial results.\n".format(time_string()))


if __name__ == "__main__":

    show_multi_trial("tss")
    show_multi_trial("sss")

    api_tss = create(None, "tss", fast_mode=False, verbose=False)
    resnet = "|nor_conv_3x3~0|+|none~0|nor_conv_3x3~1|+|skip_connect~0|none~1|skip_connect~2|"
    resnet_index = api_tss.query_index_by_arch(resnet)
    print(show_valid_test(api_tss, resnet_index))

    for dataset in ["cifar10", "cifar100", "ImageNet16-120"]:
        find_best_valid(api_tss, dataset)

    largest = "64:64:64:64:64"
    largest_index = api_sss.query_index_by_arch(largest)
    print(show_valid_test(api_sss, largest_index))
    for dataset in ["cifar10", "cifar100", "ImageNet16-120"]:
        find_best_valid(api_sss, dataset)
コード例 #16
0
            'cifar10-T20000', 'cifar100-T40000', 'ImageNet16-120-T120000'
        ]
    elif search_space == 'sss':
        datasets = [
            'cifar10-T20000', 'cifar100-T40000', 'ImageNet16-120-T60000'
        ]
    else:
        raise ValueError('Unknown search space: {:}'.format(search_space))
    for dataset in datasets:
        show(dataset)
    print('{:} complete show multi-trial results.\n'.format(time_string()))


if __name__ == '__main__':

    show_multi_trial('tss')
    show_multi_trial('sss')

    api_tss = create(None, 'tss', fast_mode=False, verbose=False)
    resnet = '|nor_conv_3x3~0|+|none~0|nor_conv_3x3~1|+|skip_connect~0|none~1|skip_connect~2|'
    resnet_index = api_tss.query_index_by_arch(resnet)
    print(show_valid_test(api_tss, resnet_index))

    for dataset in ['cifar10', 'cifar100', 'ImageNet16-120']:
        find_best_valid(api_tss, dataset)

    # api_sss was used below without ever being created (NameError);
    # build the size-search-space API before querying it.
    api_sss = create(None, 'sss', fast_mode=False, verbose=False)
    largest = '64:64:64:64:64'
    largest_index = api_sss.query_index_by_arch(largest)
    print(show_valid_test(api_sss, largest_index))
    for dataset in ['cifar10', 'cifar100', 'ImageNet16-120']:
        find_best_valid(api_sss, dataset)
コード例 #17
0
from nats_bench import create


def show_time(api):
  """Accumulate the 12-epoch training time of every architecture in *api*
  and print the totals, expressing CIFAR-100 and ImageNet-16-120 as a
  multiple of the CIFAR-10 total."""
  print('Show the time for {:} with 12-epoch-training'.format(api))
  all_cifar10_time = all_cifar100_time = all_imagenet_time = 0
  for arch_index in tqdm.tqdm(range(len(api))):
    # query each dataset in turn and fold its training time into the totals
    all_imagenet_time += api.get_more_info(arch_index, 'ImageNet16-120', hp='12')['train-all-time']
    all_cifar10_time += api.get_more_info(arch_index, 'cifar10-valid', hp='12')['train-all-time']
    all_cifar100_time += api.get_more_info(arch_index, 'cifar100', hp='12')['train-all-time']
  print('The total training time for CIFAR-10        (held-out train set) is {:} seconds'.format(all_cifar10_time))
  print('The total training time for CIFAR-100       (held-out train set) is {:} seconds, {:.2f} times longer than that on CIFAR-10'.format(all_cifar100_time, all_cifar100_time / all_cifar10_time))
  print('The total training time for ImageNet-16-120 (held-out train set) is {:} seconds, {:.2f} times longer than that on CIFAR-10'.format(all_imagenet_time, all_imagenet_time / all_cifar10_time))


if __name__ == '__main__':

  # topology search space
  api_nats_tss = create(None, 'tss', fast_mode=True, verbose=False)
  show_time(api_nats_tss)

  # size search space
  api_nats_sss = create(None, 'sss', fast_mode=True, verbose=False)
  show_time(api_nats_sss)

コード例 #18
0
    if not sss_or_tss:
        arch_str = "|nor_conv_3x3~0|+|nor_conv_3x3~0|avg_pool_3x3~1|+|skip_connect~0|nor_conv_3x3~1|skip_connect~2|"
        matrix = api.str2matrix(arch_str)
        print("Compute the adjacency matrix of {:}".format(arch_str))
        print(matrix)
    info = api.simulate_train_eval(123, "cifar10")
    print("simulate_train_eval : {:}\n\n".format(info))


if __name__ == "__main__":

    # api201 = create('./output/NATS-Bench-topology/process-FULL', 'topology', fast_mode=True, verbose=True)
    for fast_mode in [True, False]:
        for verbose in [True, False]:
            api_nats_tss = create(None,
                                  "tss",
                                  fast_mode=fast_mode,
                                  verbose=True)
            print("{:} create with fast_mode={:} and verbose={:}".format(
                time_string(), fast_mode, verbose))
            test_api(api_nats_tss, False)
            del api_nats_tss
            gc.collect()

    for fast_mode in [True, False]:
        for verbose in [True, False]:
            print("{:} create with fast_mode={:} and verbose={:}".format(
                time_string(), fast_mode, verbose))
            api_nats_sss = create(None,
                                  "size",
                                  fast_mode=fast_mode,
                                  verbose=True)
コード例 #19
0
ファイル: show-dataset.py プロジェクト: wnov/AutoDL-Projects
from datasets import get_datasets
from nats_bench import create


def show_imagenet_16_120(dataset_dir=None):
  """Print basic statistics of the ImageNet-16-120 dataset and its split.

  Args:
    dataset_dir: directory of the ImageNet16 data; when None, it defaults
      to ``$TORCH_HOME/cifar.python/ImageNet16`` (falling back to
      ``~/.torch`` when TORCH_HOME is unset).
  """
  if dataset_dir is None:
    # only touch $HOME when TORCH_HOME is absent (keeps the lookup lazy)
    if 'TORCH_HOME' in os.environ:
      torch_home_dir = os.environ['TORCH_HOME']
    else:
      torch_home_dir = os.path.join(os.environ['HOME'], '.torch')
    dataset_dir = os.path.join(torch_home_dir, 'cifar.python', 'ImageNet16')
  train_data, valid_data, xshape, class_num = get_datasets('ImageNet16-120', dataset_dir, -1)
  split_info = load_config('configs/nas-benchmark/ImageNet16-120-split.txt', None, None)
  print('=' * 10 + ' ImageNet-16-120 ' + '=' * 10)
  print('Training Data: {:}'.format(train_data))
  print('Evaluation Data: {:}'.format(valid_data))
  print('Hold-out training: {:} images.'.format(len(split_info.train)))
  print('Hold-out valid   : {:} images.'.format(len(split_info.valid)))


if __name__ == '__main__':
  # show_imagenet_16_120()
  api_nats_tss = create(None, 'tss', fast_mode=True, verbose=True)

  valid_acc_12e = []
  test_acc_12e = []
  test_acc_200e = []
  # NOTE(review): 10000 architectures is assumed here — confirm it matches
  # len(api_nats_tss) for the topology search space.
  for index in range(10000):
    info = api_nats_tss.get_more_info(index, 'ImageNet16-120', hp='12')
    valid_acc_12e.append(info['valid-accuracy'])  # the validation accuracy after training the model by 12 epochs
    test_acc_12e.append(info['test-accuracy'])    # the test accuracy after training the model by 12 epochs
    info = api_nats_tss.get_more_info(index, 'ImageNet16-120', hp='200')
    test_acc_200e.append(info['test-accuracy'])   # the test accuracy after training the model by 200 epochs (which I reported in the paper)
コード例 #20
0
def visualize_relative_info(vis_save_dir, search_space, indicator, topk):
  """Plot, for the top-k% architectures, how a randomly selected trial's
  ranking compares to the multi-trial-average ranking on each dataset.

  Args:
    vis_save_dir: Path where the cache file and figures are written.
    search_space: NATS-Bench search space name.
    indicator: 'less' (12-epoch) or 'more' (full-training) accuracy.
    topk: percentage (0-100) of best architectures to keep per dataset.
  """
  vis_save_dir = vis_save_dir.resolve()
  print ('{:} start to visualize {:} with top-{:} information'.format(time_string(), search_space, topk))
  vis_save_dir.mkdir(parents=True, exist_ok=True)
  cache_file_path = vis_save_dir / 'cache-{:}-info.pth'.format(search_space)
  datasets = ['cifar10', 'cifar100', 'ImageNet16-120']
  if not cache_file_path.exists():
    # first run: query every architecture and cache the test accuracies
    api = create(None, search_space, fast_mode=False, verbose=False)
    all_infos = OrderedDict()
    for index in range(len(api)):
      all_info = OrderedDict()
      for dataset in datasets:
        info_less = api.get_more_info(index, dataset, hp='12', is_random=False)
        info_more = api.get_more_info(index, dataset, hp=api.full_train_epochs, is_random=False)
        all_info[dataset] = dict(less=info_less['test-accuracy'],
                                 more=info_more['test-accuracy'])
      all_infos[index] = all_info
    torch.save(all_infos, cache_file_path)
    print ('{:} save all cache data into {:}'.format(time_string(), cache_file_path))
  else:
    # cached run: the api is still needed below for per-index queries
    api = create(None, search_space, fast_mode=True, verbose=False)
    all_infos = torch.load(cache_file_path)


  dpi, width, height = 250, 5000, 1300
  figsize = width / float(dpi), height / float(dpi)
  LabelSize, LegendFontsize = 16, 16

  fig, axs = plt.subplots(1, 3, figsize=figsize)
  datasets = ['cifar10', 'cifar100', 'ImageNet16-120']

  def sub_plot_fn(ax, dataset, indicator):
    # One subplot per dataset: scatter the random-trial ranking against the
    # averaged ranking and return their Kendall-tau correlation.
    performances = []
    # pickup top 10% architectures
    for _index in range(len(api)):
      performances.append((all_infos[_index][dataset][indicator], _index))
    performances = sorted(performances, reverse=True)
    performances = performances[: int(len(api) * topk * 0.01)]
    selected_indexes = [x[1] for x in performances]
    print('{:} plot {:10s} with {:}, {:} architectures'.format(time_string(), dataset, indicator, len(selected_indexes)))
    standard_scores = []
    random_scores = []
    for idx in selected_indexes:
      standard_scores.append(
        api.get_more_info(idx, dataset, hp=api.full_train_epochs if indicator == 'more' else '12', is_random=False)['test-accuracy'])
      random_scores.append(
        api.get_more_info(idx, dataset, hp=api.full_train_epochs if indicator == 'more' else '12', is_random=True)['test-accuracy'])
    indexes = list(range(len(selected_indexes)))
    standard_indexes = sorted(indexes, key=lambda i: standard_scores[i])
    random_indexes = sorted(indexes, key=lambda i: random_scores[i])
    # random_labels[i] = rank (by random trial) of the i-th standard-ranked arch
    random_labels = []
    for idx in standard_indexes:
      random_labels.append(random_indexes.index(idx))
    for tick in ax.get_xticklabels():
      tick.set_fontsize(LabelSize - 3)
    for tick in ax.get_yticklabels():
      tick.set_rotation(25)
      tick.set_fontsize(LabelSize - 3)
    ax.set_xlim(0, len(indexes))
    ax.set_ylim(0, len(indexes))
    ax.set_yticks(np.arange(min(indexes), max(indexes), max(indexes)//3))
    ax.set_xticks(np.arange(min(indexes), max(indexes), max(indexes)//5))
    ax.scatter(indexes, random_labels, marker='^', s=0.5, c='tab:green', alpha=0.8)
    ax.scatter(indexes, indexes      , marker='o', s=0.5, c='tab:blue' , alpha=0.8)
    # off-canvas points with large markers provide readable legend glyphs
    ax.scatter([-1], [-1], marker='o', s=100, c='tab:blue' , label='Average Over Multi-Trials')
    ax.scatter([-1], [-1], marker='^', s=100, c='tab:green', label='Randomly Selected Trial')

    coef, p = scipy.stats.kendalltau(standard_scores, random_scores)
    ax.set_xlabel('architecture ranking in {:}'.format(name2label[dataset]), fontsize=LabelSize)
    if dataset == 'cifar10':
      ax.set_ylabel('architecture ranking', fontsize=LabelSize)
    ax.legend(loc=4, fontsize=LegendFontsize)
    return coef

  for dataset, ax in zip(datasets, axs):
    rank_coef = sub_plot_fn(ax, dataset, indicator)
    print('sub-plot {:} on {:} done, the ranking coefficient is {:.4f}.'.format(dataset, search_space, rank_coef))

  save_path = (vis_save_dir / '{:}-rank-{:}-top{:}.pdf'.format(search_space, indicator, topk)).resolve()
  fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='pdf')
  save_path = (vis_save_dir / '{:}-rank-{:}-top{:}.png'.format(search_space, indicator, topk)).resolve()
  fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png')
  print('Save into {:}'.format(save_path))
コード例 #21
0
def main(xargs):
    """Run a weight-sharing NAS search over the NATS-Bench size search space.

    Args:
        xargs: parsed argparse namespace. Expected attributes (inferred from
            use): dataset, data_path, config_path, search_space, algo,
            genotype, affine, track_running_stats, workers, rand_seed,
            overwite_epochs [sic], arch_learning_rate, arch_weight_decay,
            arch_eps, use_api, warmup_ratio, tau_max, tau_min, print_freq.

    Side effects: trains the super-network on a single GPU, logs progress,
    and writes checkpoints via the logger's paths. Resumes automatically
    from the last-info checkpoint when one exists.
    """
    assert torch.cuda.is_available(), "CUDA is not available."
    torch.backends.cudnn.enabled = True
    # deterministic, non-benchmarked cuDNN for reproducible searches
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    # torch.set_num_threads(xargs.workers)
    prepare_seed(xargs.rand_seed)
    # fix: use the function parameter, not the module-level global `args`
    logger = prepare_logger(xargs)

    train_data, valid_data, xshape, class_num = get_datasets(
        xargs.dataset, xargs.data_path, -1
    )
    if xargs.overwite_epochs is None:
        extra_info = {"class_num": class_num, "xshape": xshape}
    else:
        extra_info = {
            "class_num": class_num,
            "xshape": xshape,
            "epochs": xargs.overwite_epochs,
        }
    config = load_config(xargs.config_path, extra_info, logger)
    search_loader, train_loader, valid_loader = get_nas_search_loaders(
        train_data,
        valid_data,
        xargs.dataset,
        "configs/nas-benchmark/",
        (config.batch_size, config.test_batch_size),
        xargs.workers,
    )
    logger.log(
        "||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}".format(
            xargs.dataset, len(search_loader), len(valid_loader), config.batch_size
        )
    )
    logger.log("||||||| {:10s} ||||||| Config={:}".format(xargs.dataset, config))

    search_space = get_search_spaces(xargs.search_space, "nats-bench")

    model_config = dict2config(
        dict(
            name="generic",
            super_type="search-shape",
            candidate_Cs=search_space["candidates"],
            max_num_Cs=search_space["numbers"],
            num_classes=class_num,
            # fix: read from the parameter, not the global `args`
            genotype=xargs.genotype,
            affine=bool(xargs.affine),
            track_running_stats=bool(xargs.track_running_stats),
        ),
        None,
    )
    logger.log("search space : {:}".format(search_space))
    logger.log("model config : {:}".format(model_config))
    search_model = get_cell_based_tiny_net(model_config)
    search_model.set_algo(xargs.algo)
    logger.log("{:}".format(search_model))

    w_optimizer, w_scheduler, criterion = get_optim_scheduler(
        search_model.weights, config
    )
    a_optimizer = torch.optim.Adam(
        search_model.alphas,
        lr=xargs.arch_learning_rate,
        betas=(0.5, 0.999),
        weight_decay=xargs.arch_weight_decay,
        eps=xargs.arch_eps,
    )
    logger.log("w-optimizer : {:}".format(w_optimizer))
    logger.log("a-optimizer : {:}".format(a_optimizer))
    logger.log("w-scheduler : {:}".format(w_scheduler))
    logger.log("criterion   : {:}".format(criterion))
    params = count_parameters_in_MB(search_model)
    logger.log("The parameters of the search model = {:.2f} MB".format(params))
    logger.log("search-space : {:}".format(search_space))
    if bool(xargs.use_api):
        api = create(None, "size", fast_mode=True, verbose=False)
    else:
        api = None
    logger.log("{:} create API = {:} done".format(time_string(), api))

    # checkpoint paths (previously this triple assignment appeared twice)
    last_info, model_base_path, model_best_path = (
        logger.path("info"),
        logger.path("model"),
        logger.path("best"),
    )
    network, criterion = search_model.cuda(), criterion.cuda()  # use a single GPU

    if last_info.exists():  # automatically resume from previous checkpoint
        logger.log(
            "=> loading checkpoint of the last-info '{:}' start".format(last_info)
        )
        last_info = torch.load(last_info)
        start_epoch = last_info["epoch"]
        checkpoint = torch.load(last_info["last_checkpoint"])
        genotypes = checkpoint["genotypes"]
        valid_accuracies = checkpoint["valid_accuracies"]
        search_model.load_state_dict(checkpoint["search_model"])
        w_scheduler.load_state_dict(checkpoint["w_scheduler"])
        w_optimizer.load_state_dict(checkpoint["w_optimizer"])
        a_optimizer.load_state_dict(checkpoint["a_optimizer"])
        logger.log(
            "=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format(
                last_info, start_epoch
            )
        )
    else:
        logger.log("=> do not find the last-info file : {:}".format(last_info))
        start_epoch, valid_accuracies, genotypes = 0, {"best": -1}, {-1: network.random}

    # start training
    start_time, search_time, epoch_time, total_epoch = (
        time.time(),
        AverageMeter(),
        AverageMeter(),
        config.epochs + config.warmup,
    )
    for epoch in range(start_epoch, total_epoch):
        w_scheduler.update(epoch, 0.0)
        need_time = "Time Left: {:}".format(
            convert_secs2time(epoch_time.val * (total_epoch - epoch), True)
        )
        epoch_str = "{:03d}-{:03d}".format(epoch, total_epoch)

        # the controller (architecture parameters) only starts learning after
        # the warmup fraction of epochs has elapsed
        if (
            xargs.warmup_ratio is None
            or xargs.warmup_ratio <= float(epoch) / total_epoch
        ):
            enable_controller = True
            network.set_warmup_ratio(None)
        else:
            enable_controller = False
            network.set_warmup_ratio(
                1.0 - float(epoch) / total_epoch / xargs.warmup_ratio
            )

        logger.log(
            "\n[Search the {:}-th epoch] {:}, LR={:}, controller-warmup={:}, enable_controller={:}".format(
                epoch_str,
                need_time,
                min(w_scheduler.get_lr()),
                network.warmup_ratio,
                enable_controller,
            )
        )

        if xargs.algo == "mask_gumbel" or xargs.algo == "tas":
            # anneal the Gumbel-softmax temperature linearly over training
            network.set_tau(
                xargs.tau_max
                - (xargs.tau_max - xargs.tau_min) * epoch / (total_epoch - 1)
            )
            logger.log("[RESET tau as : {:}]".format(network.tau))
        (
            search_w_loss,
            search_w_top1,
            search_w_top5,
            search_a_loss,
            search_a_top1,
            search_a_top5,
        ) = search_func(
            search_loader,
            network,
            criterion,
            w_scheduler,
            w_optimizer,
            a_optimizer,
            enable_controller,
            xargs.algo,
            epoch_str,
            xargs.print_freq,
            logger,
        )
        search_time.update(time.time() - start_time)
        logger.log(
            "[{:}] search [base] : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s".format(
                epoch_str, search_w_loss, search_w_top1, search_w_top5, search_time.sum
            )
        )
        logger.log(
            "[{:}] search [arch] : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%".format(
                epoch_str, search_a_loss, search_a_top1, search_a_top5
            )
        )

        genotype = network.genotype
        logger.log("[{:}] - [get_best_arch] : {:}".format(epoch_str, genotype))
        valid_a_loss, valid_a_top1, valid_a_top5 = valid_func(
            valid_loader, network, criterion, logger
        )
        logger.log(
            "[{:}] evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}% | {:}".format(
                epoch_str, valid_a_loss, valid_a_top1, valid_a_top5, genotype
            )
        )
        valid_accuracies[epoch] = valid_a_top1

        genotypes[epoch] = genotype
        logger.log(
            "<<<--->>> The {:}-th epoch : {:}".format(epoch_str, genotypes[epoch])
        )
        # save checkpoint
        save_path = save_checkpoint(
            {
                "epoch": epoch + 1,
                "args": deepcopy(xargs),
                "search_model": search_model.state_dict(),
                "w_optimizer": w_optimizer.state_dict(),
                "a_optimizer": a_optimizer.state_dict(),
                "w_scheduler": w_scheduler.state_dict(),
                "genotypes": genotypes,
                "valid_accuracies": valid_accuracies,
            },
            model_base_path,
            logger,
        )
        last_info = save_checkpoint(
            {
                "epoch": epoch + 1,
                # fix: snapshot the parameter, not the global `args`
                "args": deepcopy(xargs),
                "last_checkpoint": save_path,
            },
            logger.path("info"),
            logger,
        )
        with torch.no_grad():
            logger.log("{:}".format(search_model.show_alphas()))
        if api is not None:
            logger.log("{:}".format(api.query_by_arch(genotypes[epoch], "90")))
        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()

    # the final post procedure : count the time
    start_time = time.time()
    genotype = network.genotype
    search_time.update(time.time() - start_time)

    valid_a_loss, valid_a_top1, valid_a_top5 = valid_func(
        valid_loader, network, criterion, logger
    )
    logger.log(
        "Last : the genotype is : {:}, with the validation accuracy of {:.3f}%.".format(
            genotype, valid_a_top1
        )
    )

    logger.log("\n" + "-" * 100)
    # check the performance from the architecture dataset
    logger.log(
        "[{:}] run {:} epochs, cost {:.1f} s, last-geno is {:}.".format(
            xargs.algo, total_epoch, search_time.sum, genotype
        )
    )
    if api is not None:
        logger.log("{:}".format(api.query_by_arch(genotype, "90")))
    logger.close()
コード例 #22
0
    datasets = ['cifar10', 'cifar100', 'ImageNet16-120']
    for dataset, ax in zip(datasets, axs):
        sub_plot_fn(ax, dataset)
        print('sub-plot {:} on {:} done.'.format(dataset, search_space))
    save_path = (vis_save_dir /
                 '{:}-ws-curve.png'.format(search_space)).resolve()
    fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png')
    print('{:} save into {:}'.format(time_string(), save_path))
    plt.close('all')


if __name__ == '__main__':
    # CLI for visualizing NAS-algorithm curves on a NATS-Bench search space.
    parser = argparse.ArgumentParser(
        description='NAS-Bench-X',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument(
        '--save_dir', type=str,
        default='output/vis-nas-bench/nas-algos',
        help='Folder to save checkpoints and log.')
    parser.add_argument(
        '--search_space', type=str,
        default='tss', choices=['tss', 'sss'],
        help='Choose the search space.')
    args = parser.parse_args()

    save_dir = Path(args.save_dir)

    # Build the benchmark API for the chosen space, then draw the curves.
    api = create(None, args.search_space, verbose=False)
    visualize_curve(api, save_dir, args.search_space)
コード例 #23
0
        dataset_dir = os.path.join(torch_home_dir, "cifar.python",
                                   "ImageNet16")
    train_data, valid_data, xshape, class_num = get_datasets(
        "ImageNet16-120", dataset_dir, -1)
    split_info = load_config("configs/nas-benchmark/ImageNet16-120-split.txt",
                             None, None)
    print("=" * 10 + " ImageNet-16-120 " + "=" * 10)
    print("Training Data: {:}".format(train_data))
    print("Evaluation Data: {:}".format(valid_data))
    print("Hold-out training: {:} images.".format(len(split_info.train)))
    print("Hold-out valid   : {:} images.".format(len(split_info.valid)))


if __name__ == "__main__":
    # show_imagenet_16_120()
    api_nats_tss = create(None, "tss", fast_mode=True, verbose=True)

    valid_acc_12e = []
    test_acc_12e = []
    test_acc_200e = []
    for index in range(10000):
        info = api_nats_tss.get_more_info(index, "ImageNet16-120", hp="12")
        valid_acc_12e.append(
            info["valid-accuracy"]
        )  # the validation accuracy after training the model by 12 epochs
        test_acc_12e.append(
            info["test-accuracy"]
        )  # the test accuracy after training the model by 12 epochs
        info = api_nats_tss.get_more_info(index, "ImageNet16-120", hp="200")
        test_acc_200e.append(
            info["test-accuracy"]
コード例 #24
0
def main(xargs):
    """Run a weight-sharing NAS search (GDAS/DARTS/SETN/ENAS/random) on the
    NATS-Bench topology search space.

    Args:
        xargs: parsed argparse namespace. Expected attributes (inferred from
            use): dataset, data_path, config_path, search_space, algo,
            channel, num_cells, max_nodes, affine, track_running_stats,
            workers, rand_seed, overwite_epochs [sic], arch_learning_rate,
            arch_weight_decay, arch_eps, use_api, drop_path_rate, tau_max,
            tau_min, eval_candidate_num, print_freq.

    Side effects: trains the super-network on a single GPU, logs progress,
    and writes checkpoints via the logger's paths. Resumes automatically
    from the last-info checkpoint when one exists.
    """
    assert torch.cuda.is_available(), 'CUDA is not available.'
    torch.backends.cudnn.enabled = True
    # deterministic, non-benchmarked cuDNN for reproducible searches
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.set_num_threads(xargs.workers)
    prepare_seed(xargs.rand_seed)
    # fix: use the function parameter, not the module-level global `args`
    logger = prepare_logger(xargs)

    train_data, valid_data, xshape, class_num = get_datasets(
        xargs.dataset, xargs.data_path, -1)
    if xargs.overwite_epochs is None:
        extra_info = {'class_num': class_num, 'xshape': xshape}
    else:
        extra_info = {
            'class_num': class_num,
            'xshape': xshape,
            'epochs': xargs.overwite_epochs
        }
    config = load_config(xargs.config_path, extra_info, logger)
    search_loader, train_loader, valid_loader = get_nas_search_loaders(
        train_data, valid_data, xargs.dataset, 'configs/nas-benchmark/',
        (config.batch_size, config.test_batch_size), xargs.workers)
    logger.log(
        '||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'
        .format(xargs.dataset, len(search_loader), len(valid_loader),
                config.batch_size))
    logger.log('||||||| {:10s} ||||||| Config={:}'.format(
        xargs.dataset, config))

    search_space = get_search_spaces(xargs.search_space, 'nas-bench-301')

    model_config = dict2config(
        dict(name='generic',
             C=xargs.channel,
             N=xargs.num_cells,
             max_nodes=xargs.max_nodes,
             num_classes=class_num,
             space=search_space,
             affine=bool(xargs.affine),
             track_running_stats=bool(xargs.track_running_stats)), None)
    logger.log('search space : {:}'.format(search_space))
    logger.log('model config : {:}'.format(model_config))
    search_model = get_cell_based_tiny_net(model_config)
    search_model.set_algo(xargs.algo)
    logger.log('{:}'.format(search_model))

    w_optimizer, w_scheduler, criterion = get_optim_scheduler(
        search_model.weights, config)
    a_optimizer = torch.optim.Adam(search_model.alphas,
                                   lr=xargs.arch_learning_rate,
                                   betas=(0.5, 0.999),
                                   weight_decay=xargs.arch_weight_decay,
                                   eps=xargs.arch_eps)
    logger.log('w-optimizer : {:}'.format(w_optimizer))
    logger.log('a-optimizer : {:}'.format(a_optimizer))
    logger.log('w-scheduler : {:}'.format(w_scheduler))
    logger.log('criterion   : {:}'.format(criterion))
    params = count_parameters_in_MB(search_model)
    logger.log('The parameters of the search model = {:.2f} MB'.format(params))
    logger.log('search-space : {:}'.format(search_space))
    if bool(xargs.use_api):
        api = create(None, 'topology', fast_mode=True, verbose=False)
    else:
        api = None
    logger.log('{:} create API = {:} done'.format(time_string(), api))

    # checkpoint paths (previously this triple assignment appeared twice)
    last_info, model_base_path, model_best_path = logger.path(
        'info'), logger.path('model'), logger.path('best')
    network, criterion = search_model.cuda(), criterion.cuda(
    )  # use a single GPU

    if last_info.exists():  # automatically resume from previous checkpoint
        logger.log("=> loading checkpoint of the last-info '{:}' start".format(
            last_info))
        last_info = torch.load(last_info)
        start_epoch = last_info['epoch']
        checkpoint = torch.load(last_info['last_checkpoint'])
        genotypes = checkpoint['genotypes']
        baseline = checkpoint['baseline']
        valid_accuracies = checkpoint['valid_accuracies']
        search_model.load_state_dict(checkpoint['search_model'])
        w_scheduler.load_state_dict(checkpoint['w_scheduler'])
        w_optimizer.load_state_dict(checkpoint['w_optimizer'])
        a_optimizer.load_state_dict(checkpoint['a_optimizer'])
        logger.log(
            "=> loading checkpoint of the last-info '{:}' start with {:}-th epoch."
            .format(last_info, start_epoch))
    else:
        logger.log("=> do not find the last-info file : {:}".format(last_info))
        start_epoch, valid_accuracies, genotypes = 0, {
            'best': -1
        }, {
            -1: network.return_topK(1, True)[0]
        }
        baseline = None  # ENAS reward baseline starts empty

    # start training
    start_time, search_time, epoch_time, total_epoch = time.time(
    ), AverageMeter(), AverageMeter(), config.epochs + config.warmup
    for epoch in range(start_epoch, total_epoch):
        w_scheduler.update(epoch, 0.0)
        need_time = 'Time Left: {:}'.format(
            convert_secs2time(epoch_time.val * (total_epoch - epoch), True))
        epoch_str = '{:03d}-{:03d}'.format(epoch, total_epoch)
        logger.log('\n[Search the {:}-th epoch] {:}, LR={:}'.format(
            epoch_str, need_time, min(w_scheduler.get_lr())))

        # drop-path probability ramps up linearly over training
        network.set_drop_path(
            float(epoch + 1) / total_epoch, xargs.drop_path_rate)
        if xargs.algo == 'gdas':
            # anneal the Gumbel-softmax temperature linearly over training
            network.set_tau(xargs.tau_max -
                            (xargs.tau_max - xargs.tau_min) * epoch /
                            (total_epoch - 1))
            logger.log('[RESET tau as : {:} and drop_path as {:}]'.format(
                network.tau, network.drop_path))
        search_w_loss, search_w_top1, search_w_top5, search_a_loss, search_a_top1, search_a_top5 \
                    = search_func(search_loader, network, criterion, w_scheduler, w_optimizer, a_optimizer, epoch_str, xargs.print_freq, xargs.algo, logger)
        search_time.update(time.time() - start_time)
        logger.log(
            '[{:}] search [base] : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s'
            .format(epoch_str, search_w_loss, search_w_top1, search_w_top5,
                    search_time.sum))
        logger.log(
            '[{:}] search [arch] : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'
            .format(epoch_str, search_a_loss, search_a_top1, search_a_top5))
        if xargs.algo == 'enas':
            # ENAS additionally trains an RL controller on the valid split
            ctl_loss, ctl_acc, baseline, ctl_reward \
                                       = train_controller(valid_loader, network, criterion, a_optimizer, baseline, epoch_str, xargs.print_freq, logger)
            logger.log(
                '[{:}] controller : loss={:}, acc={:}, baseline={:}, reward={:}'
                .format(epoch_str, ctl_loss, ctl_acc, baseline, ctl_reward))

        genotype, temp_accuracy = get_best_arch(valid_loader, network,
                                                xargs.eval_candidate_num,
                                                xargs.algo)
        # put the super-network into the evaluation mode matching the algorithm
        if xargs.algo == 'setn' or xargs.algo == 'enas':
            network.set_cal_mode('dynamic', genotype)
        elif xargs.algo == 'gdas':
            network.set_cal_mode('gdas', None)
        elif xargs.algo.startswith('darts'):
            network.set_cal_mode('joint', None)
        elif xargs.algo == 'random':
            network.set_cal_mode('urs', None)
        else:
            raise ValueError('Invalid algorithm name : {:}'.format(xargs.algo))
        logger.log('[{:}] - [get_best_arch] : {:} -> {:}'.format(
            epoch_str, genotype, temp_accuracy))
        valid_a_loss, valid_a_top1, valid_a_top5 = valid_func(
            valid_loader, network, criterion, xargs.algo, logger)
        logger.log(
            '[{:}] evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}% | {:}'
            .format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5,
                    genotype))
        valid_accuracies[epoch] = valid_a_top1

        genotypes[epoch] = genotype
        logger.log('<<<--->>> The {:}-th epoch : {:}'.format(
            epoch_str, genotypes[epoch]))
        # save checkpoint
        save_path = save_checkpoint(
            {
                'epoch': epoch + 1,
                'args': deepcopy(xargs),
                'baseline': baseline,
                'search_model': search_model.state_dict(),
                'w_optimizer': w_optimizer.state_dict(),
                'a_optimizer': a_optimizer.state_dict(),
                'w_scheduler': w_scheduler.state_dict(),
                'genotypes': genotypes,
                'valid_accuracies': valid_accuracies
            }, model_base_path, logger)
        last_info = save_checkpoint(
            {
                'epoch': epoch + 1,
                # fix: snapshot the parameter, not the global `args`
                'args': deepcopy(xargs),
                'last_checkpoint': save_path,
            }, logger.path('info'), logger)
        with torch.no_grad():
            logger.log('{:}'.format(search_model.show_alphas()))
        if api is not None:
            logger.log('{:}'.format(api.query_by_arch(genotypes[epoch],
                                                      '200')))
        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()

    # the final post procedure : count the time
    start_time = time.time()
    genotype, temp_accuracy = get_best_arch(valid_loader, network,
                                            xargs.eval_candidate_num,
                                            xargs.algo)
    if xargs.algo == 'setn' or xargs.algo == 'enas':
        network.set_cal_mode('dynamic', genotype)
    elif xargs.algo == 'gdas':
        network.set_cal_mode('gdas', None)
    elif xargs.algo.startswith('darts'):
        network.set_cal_mode('joint', None)
    elif xargs.algo == 'random':
        network.set_cal_mode('urs', None)
    else:
        raise ValueError('Invalid algorithm name : {:}'.format(xargs.algo))
    search_time.update(time.time() - start_time)

    valid_a_loss, valid_a_top1, valid_a_top5 = valid_func(
        valid_loader, network, criterion, xargs.algo, logger)
    logger.log(
        'Last : the genotype is : {:}, with the validation accuracy of {:.3f}%.'
        .format(genotype, valid_a_top1))

    logger.log('\n' + '-' * 100)
    # check the performance from the architecture dataset
    logger.log('[{:}] run {:} epochs, cost {:.1f} s, last-geno is {:}.'.format(
        xargs.algo, total_epoch, search_time.sum, genotype))
    if api is not None:
        logger.log('{:}'.format(api.query_by_arch(genotype, '200')))
    logger.close()
コード例 #25
0
    all_cifar10_time, all_cifar100_time, all_imagenet_time = 0, 0, 0
    for index in tqdm.tqdm(range(len(api))):
        info = api.get_more_info(index, "ImageNet16-120", hp=epoch)
        imagenet_time = info["train-all-time"]
        info = api.get_more_info(index, "cifar10-valid", hp=epoch)
        cifar10_time = info["train-all-time"]
        info = api.get_more_info(index, "cifar100", hp=epoch)
        cifar100_time = info["train-all-time"]
        # accumulate the time
        all_cifar10_time += cifar10_time
        all_cifar100_time += cifar100_time
        all_imagenet_time += imagenet_time
    print(
        "The total training time for CIFAR-10        (held-out train set) is {:} seconds"
        .format(all_cifar10_time))
    print(
        "The total training time for CIFAR-100       (held-out train set) is {:} seconds, {:.2f} times longer than that on CIFAR-10"
        .format(all_cifar100_time, all_cifar100_time / all_cifar10_time))
    print(
        "The total training time for ImageNet-16-120 (held-out train set) is {:} seconds, {:.2f} times longer than that on CIFAR-10"
        .format(all_imagenet_time, all_imagenet_time / all_cifar10_time))


if __name__ == "__main__":
    # Total 12-epoch training time on the topology search space.
    api_nats_tss = create(None, "tss", fast_mode=True, verbose=False)
    show_time(api_nats_tss, 12)
    # Total 12-epoch training time on the size search space.
    api_nats_sss = create(None, "sss", fast_mode=True, verbose=False)
    show_time(api_nats_sss, 12)
コード例 #26
0
                        default='./output/search',
                        help='Folder to save checkpoints and log.')
    parser.add_argument(
        '--arch_nas_dataset',
        type=str,
        help='The path to load the architecture dataset (tiny-nas-benchmark).')
    parser.add_argument('--print_freq',
                        type=int,
                        help='print frequency (default: 200)')
    parser.add_argument('--rand_seed',
                        type=int,
                        default=-1,
                        help='manual seed')
    args = parser.parse_args()

    api = create(None, args.search_space, fast_mode=True, verbose=False)

    args.save_dir = os.path.join(
        '{:}-{:}'.format(args.save_dir, args.search_space), args.dataset,
        'REINFORCE-{:}'.format(args.learning_rate))
    print('save-dir : {:}'.format(args.save_dir))

    if args.rand_seed < 0:
        save_dir, all_info = None, collections.OrderedDict()
        for i in range(args.loops_if_rand):
            print('{:} : {:03d}/{:03d}'.format(time_string(), i,
                                               args.loops_if_rand))
            args.rand_seed = random.randint(1, 100000)
            save_dir, all_archs, all_total_times = main(args, api)
            all_info[i] = {
                'all_archs': all_archs,
コード例 #27
0
        "--save_dir",
        type=str,
        default="output/vis-nas-bench/nas-algos",
        help="Folder to save checkpoints and log.",
    )
    parser.add_argument(
        "--search_space",
        type=str,
        choices=["tss", "sss"],
        help="Choose the search space.",
    )
    args = parser.parse_args()

    save_dir = Path(args.save_dir)

    api = create(None, "tss", fast_mode=True, verbose=False)
    indexes = list(range(1, 10000, 300))
    scores_1 = []
    scores_2 = []
    for index in indexes:
        valid_acc, test_acc, _ = get_valid_test_acc(api, index, "cifar10")
        scores_1.append(valid_acc)
        scores_2.append(test_acc)
    correlation = compute_kendalltau(scores_1, scores_2)
    print(
        "The kendall tau correlation of {:} samples : {:}".format(
            len(indexes), correlation
        )
    )
    correlation = compute_spearmanr(scores_1, scores_2)
    print(