def _load_nats_api(benchmark_dir, search_space):
    """Instantiate the NATS-Bench API for ``search_space`` (fast_mode=True, verbose=False).

    Raises:
        ValueError: If ``search_space`` is neither ``SearchSpace.TOPOLOGY``
            nor ``SearchSpace.SIZE``.
    """
    if search_space == SearchSpace.TOPOLOGY:
        return NATStopology(benchmark_dir, True, False)
    if search_space == SearchSpace.SIZE:
        return NATSsize(benchmark_dir, True, False)
    raise ValueError("invalid search space")


def extract_archs_benchmarks(benchmark_dir: str, search_space: SearchSpace) -> Dict[str, Dict]:
    """Collect per-architecture train/test metrics from a NATS-Bench benchmark.

    For every architecture in the benchmark, ``api.get_more_info`` is queried
    for each dataset and for each training-schedule length available in the
    chosen search space, and a compact subset of the returned metrics is kept.

    Args:
        benchmark_dir: Benchmark location handed to the NATS-Bench API
            constructor.
        search_space: ``SearchSpace.TOPOLOGY`` (schedules of 12 and 200
            epochs) or ``SearchSpace.SIZE`` (schedules of 1, 12 and 90 epochs).

    Returns:
        Nested dict ``archs[arch_str][dataset_key][num_epochs]``, where
        ``arch_str`` is ``api.arch(arch_id)``, ``dataset_key`` is one of
        ``'cifar10'``, ``'cifar100'``, ``'ImageNet16-120'`` and ``num_epochs``
        is an int. Each leaf maps ``train_loss``, ``train_accuracy``,
        ``train_time``, ``test_loss``, ``test_accuracy`` and ``test_time`` to
        the corresponding values reported by the API.

    Raises:
        ValueError: If ``search_space`` is neither TOPOLOGY nor SIZE.
    """
    api = _load_nats_api(benchmark_dir, search_space)

    # Dataset keys passed to ``get_more_info``.
    datasets = ('cifar10', 'cifar100', 'ImageNet16-120')
    # Schedule length in epochs -> ``hp`` string passed to ``get_more_info``.
    if search_space == SearchSpace.TOPOLOGY:
        nums_epochs = {12: '12', 200: '200'}
    else:
        nums_epochs = {1: '01', 12: '12', 90: '90'}

    # Rebuild the API object every this many architectures (RAM workaround).
    reload_every = 1000

    archs = {}
    for arch_id in range(len(api)):
        arch_key = api.arch(arch_id)
        print(f'Architecture {arch_id} topology: {arch_key}')
        archs[arch_key] = {}

        # HACK: free RAM by periodically discarding the API object and loading
        # a fresh one. ``del`` first so the old instance can be released
        # before the replacement is built.
        if arch_id != 0 and arch_id % reload_every == 0:
            del api
            api = _load_nats_api(benchmark_dir, search_space)

        for dataset in datasets:
            archs[arch_key][dataset] = {}
            for num_epochs, hp in nums_epochs.items():
                info = api.get_more_info(arch_id, dataset, hp=hp)
                archs[arch_key][dataset][num_epochs] = {
                    "train_loss": info['train-loss'],
                    "train_accuracy": info['train-accuracy'],
                    "train_time": info['train-all-time'],
                    # Previously read 'train-all-time' by copy-paste mistake.
                    "test_loss": info['test-loss'],
                    "test_accuracy": info['test-accuracy'],
                    "test_time": info['test-all-time'],
                }

    return archs
Beispiel #2
0
def _test_nats_bench(benchmark_dir, is_tss, fake_random, verbose=False):
    """The main test entry for NATS-Bench.

    Loads the topology API (``is_tss`` true) or the size API from
    ``benchmark_dir``. For each selected architecture and each of the three
    datasets, it queries and prints:

    * the performance info;
    * the cost info;
    * a simulated train/eval run with ``hp='12'``;
    * the network configuration.

    It then calls ``api.show`` for the architecture. Finally it asserts that
    querying index 100000 raises ``ValueError``.

    Args:
      benchmark_dir: Benchmark location handed to the API constructor.
      is_tss: Select ``NATStopology`` when truthy, otherwise ``NATSsize``.
      fake_random: When truthy, use the fixed indexes [0, 11, 284] instead of
        drawing 10 random indexes.
      verbose: Forwarded to the API constructor.
    """
    api_class = NATStopology if is_tss else NATSsize
    api = api_class(benchmark_dir, True, verbose)

    if not fake_random:
        test_indexes = [random.randint(0, len(api) - 1) for _ in range(10)]
    else:
        test_indexes = [0, 11, 284]

    # Dataset key used by the API -> human-readable name for the printout.
    pretty_names = {
        'cifar10': 'CIFAR-10',
        'cifar100': 'CIFAR-100',
        'ImageNet16-120': 'ImageNet16-120',
    }

    for arch_index in test_indexes:
        print(f'\n\nEvaluate the {arch_index:5d}-th architecture.')

        for dataset_key, pretty_name in pretty_names.items():
            # Loss / accuracy / time metrics of this architecture on this dataset.
            perf = api.get_more_info(arch_index, dataset_key)
            print(f'  -->> The performance on {pretty_name}: {perf}')

            # Flops, params and latency.
            cost = api.get_cost_info(arch_index, dataset_key)
            print(f'  -->> The cost info on {pretty_name}: {cost}')

            # Simulate training this architecture under the hp='12' schedule.
            valid_acc, latency, step_time, total_time = api.simulate_train_eval(
                arch_index, dataset=dataset_key, hp='12')
            print(f'  -->> The validation accuracy={valid_acc}, latency={latency}, '
                  f'the current time cost={step_time} s, '
                  f'accumulated time cost={total_time} s')

            # Network configuration of this architecture for this dataset.
            net_config = api.get_net_config(arch_index, dataset_key)
            print(f'  -->> The configuration on {pretty_name} is {net_config}')

        api.show(arch_index)

    # An out-of-range architecture index must be rejected.
    with pytest.raises(ValueError):
        api.get_more_info(100000, 'cifar10')
Beispiel #3
0
def create(file_path_or_dict, search_space, fast_mode=False, verbose=True):
    """Create an instance of the NATS-Bench API.

    Args:
      file_path_or_dict: None, a file path, or a directory path; forwarded
        unchanged to the API constructor.
      search_space: String naming the NATS-Bench search space. It is matched
        against ``NATS_BENCH_TSS_NAMEs`` (topology search space) first, then
        ``NATS_BENCH_SSS_NAMEs`` (size search space).
      fast_mode: If True, data for each candidate architecture is loaded
        lazily when that architecture is queried. If False, all data is loaded
        during initialization.
      verbose: Whether to log additional information.

    Raises:
      ValueError: If ``search_space`` matches neither search-space description.

    Returns:
      The created NATS-Bench API (a ``NATStopology`` or ``NATSsize`` instance).
    """
    if search_space in NATS_BENCH_TSS_NAMEs:
        api_class = NATStopology
    elif search_space in NATS_BENCH_SSS_NAMEs:
        api_class = NATSsize
    else:
        raise ValueError(f'invalid search space : {search_space}')
    return api_class(file_path_or_dict, fast_mode, verbose)