Example 1
0
    def query_in_benchmark(
            graph_data: BenchmarkGraphData) -> Tuple[float, List[float]]:
        """Query a NAS benchmark for the metrics of one architecture.

        Parameters
        ----------
        graph_data
            Carries the benchmark selector (either a callable or a string
            naming a built-in benchmark) and the ``mutation`` dict that
            encodes the sampled architecture.

        Returns
        -------
        Tuple[float, List[float]]
            The final metric and the list of intermediate metrics, as
            produced by ``_convert_to_final_and_intermediates``.

        Raises
        ------
        ValueError
            If the benchmark name is unsupported, or (for nasbench101) no
            architecture dict can be found in the mutation values.
        """
        # A callable benchmark bypasses the built-in lookup entirely.
        if not isinstance(graph_data.benchmark, str):
            return graph_data.benchmark(graph_data)

        # built-in benchmarks with default query setting
        if graph_data.benchmark == 'nasbench101':
            from nni.nas.benchmarks.nasbench101 import query_nb101_trial_stats
            # The architecture is the (last) dict-valued entry of the
            # mutation mapping; scalar-valued entries are other mutations.
            arch = None
            for t in graph_data.mutation.values():
                if isinstance(t, dict):
                    arch = t
            if arch is None:
                raise ValueError(
                    f'Cannot identify architecture from mutation dict: {graph_data.mutation}'
                )
            return _convert_to_final_and_intermediates(
                query_nb101_trial_stats(arch, 108, include_intermediates=True),
                'valid_acc')
        elif graph_data.benchmark.startswith('nasbench201'):
            from nni.nas.benchmarks.nasbench201 import query_nb201_trial_stats
            # Benchmark names look like 'nasbench201-cifar100'; the suffix
            # selects the dataset.
            dataset = graph_data.benchmark.split('-')[-1]
            return _convert_to_final_and_intermediates(
                query_nb201_trial_stats(_flatten_architecture(
                    graph_data.mutation),
                                        200,
                                        dataset,
                                        include_intermediates=True),
                'valid_acc',
            )
        elif graph_data.benchmark.startswith('nds'):
            # FIXME: not tested yet
            from nni.nas.benchmarks.nds import query_nds_trial_stats
            dataset = graph_data.benchmark.split('-')[-1]
            return _convert_to_final_and_intermediates(
                query_nds_trial_stats(None,
                                      None,
                                      None,
                                      None,
                                      _flatten_architecture(
                                          graph_data.mutation),
                                      dataset,
                                      include_intermediates=True), 'valid_acc')
        elif graph_data.benchmark.startswith('nlp'):
            # FIXME: not tested yet
            from nni.nas.benchmarks.nlp import query_nlp_trial_stats
            # TODO: I'm not sure of the available datasets in this benchmark. and the docs are missing.
            return _convert_to_final_and_intermediates(
                query_nlp_trial_stats(_flatten_architecture(
                    graph_data.mutation),
                                      'ptb',
                                      include_intermediates=True), 'valid_acc')
        else:
            raise ValueError(
                f'{graph_data.benchmark} is not a supported benchmark.')
Example 2
0
def main(args):
    """Simulate one trial by sampling a recorded NAS-Bench-201 result.

    Parameters
    ----------
    args : dict
        Architecture description plus a ``'TRIAL_BUDGET'`` entry giving the
        number of epochs (1-based) at which to read the intermediate result.
        The budget key is popped so that only architecture keys remain for
        the benchmark query.
    """
    # Remove the non-architecture key; its value selects the epoch below.
    r = args.pop('TRIAL_BUDGET')
    dataset = list(
        query_nb201_trial_stats(args, 200, 'cifar100',
                                include_intermediates=True))
    # Pick one of the recorded training runs at random and read the
    # intermediate test accuracy at epoch ``r`` (stored as a percentage).
    test_acc = random.choice(dataset)['intermediates'][r -
                                                       1]['ori_test_acc'] / 100
    # Random delay to make simulated trials finish at different times.
    time.sleep(random.randint(0, 10))
    nni.report_final_result(test_acc)
    logger.debug('Final result is %g', test_acc)
    logger.debug('Send final result done.')
Example 3
0
    dataset_train, dataset_valid = datasets.get_dataset("cifar10")
    model = NASBench201Network(stem_out_channels=args.stem_out_channels,
                               num_modules_per_stack=args.num_modules_per_stack,
                               bn_affine=args.bn_affine,
                               bn_momentum=args.bn_momentum,
                               bn_track_running_stats=args.bn_track_running_stats)

    optim = torch.optim.SGD(model.parameters(), 0.025)
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, args.epochs, eta_min=0.001)
    criterion = nn.CrossEntropyLoss()

    if args.arch is not None:
        logger.info('model retraining...')
        with open(args.arch, 'r') as f:
            arch = json.load(f)
        for trial in query_nb201_trial_stats(arch, 200, 'cifar100'):
            pprint.pprint(trial)
        apply_fixed_architecture(model, args.arch)
        dataloader_train = DataLoader(dataset_train, batch_size=args.batch_size, shuffle=True, num_workers=0)
        dataloader_valid = DataLoader(dataset_valid, batch_size=args.batch_size, shuffle=True, num_workers=0)
        train(args, model, dataloader_train, dataloader_valid, criterion, optim,
              torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
        exit(0)

    trainer = enas.EnasTrainer(model,
                               loss=criterion,
                               metrics=lambda output, target: accuracy(output, target, topk=(1,)),
                               reward_function=reward_accuracy,
                               optimizer=optim,
                               callbacks=[LRSchedulerCallback(lr_scheduler), ArchitectureCheckpoint("./checkpoints")],
                               batch_size=args.batch_size,
Example 4
0
# NAS-Bench-201
# -------------
#
# Use the following architecture as an example:
#
# .. image:: ../../img/nas-bench-201-example.png

# Encode the example cell: each key 'i_j' is the edge from node i to node j,
# and each value names the operator placed on that edge.
arch = {
    '0_1': 'avg_pool_3x3',
    '0_2': 'conv_1x1',
    '1_2': 'skip_connect',
    '0_3': 'conv_1x1',
    '1_3': 'skip_connect',
    '2_3': 'skip_connect'
}
# Dump every recorded 200-epoch CIFAR-100 run of this cell.
for trial in query_nb201_trial_stats(arch, 200, 'cifar100'):
    pprint.pprint(trial)

# %%
# Intermediate results are also available.

# With no epoch constraint, all runs match; print each run's config and
# how many per-epoch (intermediate) records it carries.
for trial in query_nb201_trial_stats(arch, None, 'imagenet16-120', include_intermediates=True):
    print(trial['config'])
    print('Intermediates:', len(trial['intermediates']))

# %%
# NDS
# ---
#
# Use the following architecture as an example:
#