Esempio n. 1
0
class Nasbench201(Dataset):
    """Nasbench201 Dataset.

    Thin wrapper around the NAS-Bench-201 ``API`` object. It validates the
    requested dataset and the operations used by an architecture string
    before looking the architecture up in the benchmark.
    """

    def __init__(self):
        """Construct the Nasbench201 class.

        Resolves ``self.args.data_path`` through ``FileOps.download_dataset``
        and loads the benchmark API from the resulting location.
        """
        super(Nasbench201, self).__init__()
        self.args.data_path = FileOps.download_dataset(self.args.data_path)
        # Pass the resolved path itself: quoting it ('self.args.data_path')
        # handed the API a literal string instead of the downloaded file.
        self.nasbench201_api = API(self.args.data_path)

    def query(self, arch_str, dataset):
        """Query an item from the dataset according to the given arch_str and dataset.

        :param arch_str: arch_str to define the topology of the cell
        :type arch_str: str
        :param dataset: dataset type; must be a member of ``VALID_DATASET``
        :type dataset: str
        :return: an item of the dataset, which contains the network info and its results like accuracy, flops and etc
        :rtype: dict
        :raises ValueError: if ``dataset`` is not supported or ``arch_str``
            uses an operation that is not in ``VALID_OPS``
        """
        if dataset not in VALID_DATASET:
            raise ValueError(
                "Only cifar10, cifar100, and Imagenet dataset is supported.")
        # Per the NAS-Bench-201 API, str2lists yields one tuple per node, each
        # holding (op_name, input_node_index) pairs; unpack to get op names.
        for node in self.nasbench201_api.str2lists(arch_str):
            for op, _ in node:
                if op not in VALID_OPS:
                    raise ValueError(
                        "{} is not in the nasbench201 space.".format(op))
        index = self.nasbench201_api.query_index_by_arch(arch_str)
        return self.nasbench201_api.query_by_index(index, dataset)
def main(xargs):
    """Run a gradient-based cell architecture search and log its progress.

    Trains the supernet weights (``w_optimizer``) and the architecture
    parameters (Adam ``a_optimizer``) epoch by epoch. Metrics go to the
    experiment logger and to TensorBoard. A checkpoint is written every
    epoch, and the run resumes automatically from ``logger.path('info')``
    when that file exists. When ``xargs.arch_nas_dataset`` is set, the
    derived genotypes are also looked up in the benchmark API.

    :param xargs: parsed command-line arguments for the search run
    """
    import shutil  # local import: only needed to clear an old save directory

    assert torch.cuda.is_available(), 'CUDA is not available.'
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.set_num_threads(xargs.workers)
    prepare_seed(xargs.rand_seed)

    if os.path.isdir(xargs.save_dir):
        # `default` is an argument of click.confirm, not of str.format.
        if click.confirm(
                '\nSave directory already exists in {}. Erase?'.format(
                    xargs.save_dir),
                default=False):
            # Remove without a shell: the old `rm -r ` + path string broke on
            # paths with spaces and would run any shell metacharacters.
            shutil.rmtree(xargs.save_dir)
            assert not os.path.exists(xargs.save_dir)
            os.mkdir(xargs.save_dir)

    logger = prepare_logger(xargs)
    writer = SummaryWriter(xargs.save_dir)
    perturb_alpha = None
    if xargs.perturb:
        perturb_alpha = random_alpha

    train_data, valid_data, xshape, class_num = get_datasets(
        xargs.dataset, xargs.data_path, -1)
    config = load_config(xargs.config_path, {
        'class_num': class_num,
        'xshape': xshape
    }, logger)
    search_loader, _, valid_loader = get_nas_search_loaders(
        train_data, valid_data, xargs.dataset, 'configs/nas-benchmark/',
        config.batch_size, xargs.workers)
    logger.log(
        '||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'
        .format(xargs.dataset, len(search_loader), len(valid_loader),
                config.batch_size))
    logger.log('||||||| {:10s} ||||||| Config={:}'.format(
        xargs.dataset, config))

    search_space = get_search_spaces('cell', xargs.search_space_name)
    if xargs.model_config is None:
        model_config = dict2config(
            {
                'name': xargs.model,
                'C': xargs.channel,
                'N': xargs.num_cells,
                'max_nodes': xargs.max_nodes,
                'num_classes': class_num,
                'space': search_space,
                'affine': bool(xargs.affine),
                'track_running_stats': bool(xargs.track_running_stats)
            }, None)
    else:
        model_config = load_config(
            xargs.model_config, {
                'num_classes': class_num,
                'space': search_space,
                'affine': bool(xargs.affine),
                'track_running_stats': bool(xargs.track_running_stats)
            }, None)
    search_model = get_cell_based_tiny_net(model_config)

    w_optimizer, w_scheduler, criterion = get_optim_scheduler(
        search_model.get_weights(), config, xargs.weight_learning_rate)
    a_optimizer = torch.optim.Adam(search_model.get_alphas(),
                                   lr=xargs.arch_learning_rate,
                                   betas=(0.5, 0.999),
                                   weight_decay=xargs.arch_weight_decay)
    logger.log('w-optimizer : {:}'.format(w_optimizer))
    logger.log('a-optimizer : {:}'.format(a_optimizer))
    logger.log('w-scheduler : {:}'.format(w_scheduler))
    logger.log('criterion   : {:}'.format(criterion))
    flop, param = get_model_infos(search_model, xshape)
    logger.log('FLOP = {:.2f} M, Params = {:.2f} MB'.format(flop, param))
    if xargs.arch_nas_dataset is None:
        api = None
    else:
        api = API(xargs.arch_nas_dataset)
    logger.log('{:} create API = {:} done'.format(time_string(), api))

    last_info, model_base_path, model_best_path = logger.path(
        'info'), logger.path('model'), logger.path('best')
    network, criterion = torch.nn.DataParallel(
        search_model).cuda(), criterion.cuda()

    if last_info.exists():  # automatically resume from previous checkpoint
        logger.log("=> loading checkpoint of the last-info '{:}' start".format(
            last_info))
        last_info = torch.load(last_info)
        start_epoch = last_info['epoch']
        checkpoint = torch.load(last_info['last_checkpoint'])
        genotypes = checkpoint['genotypes']
        valid_accuracies = checkpoint['valid_accuracies']
        search_model.load_state_dict(checkpoint['search_model'])
        w_scheduler.load_state_dict(checkpoint['w_scheduler'])
        w_optimizer.load_state_dict(checkpoint['w_optimizer'])
        a_optimizer.load_state_dict(checkpoint['a_optimizer'])
        logger.log(
            "=> loading checkpoint of the last-info '{:}' start with {:}-th epoch."
            .format(last_info, start_epoch))
    else:
        logger.log("=> do not find the last-info file : {:}".format(last_info))
        start_epoch, valid_accuracies, genotypes = 0, {
            'best': -1
        }, {
            -1: search_model.genotype()
        }

    # start training
    start_time, search_time, epoch_time = time.time(), AverageMeter(
    ), AverageMeter()
    total_epoch = config.epochs + config.warmup
    assert 0 < xargs.early_stop_epoch <= total_epoch - 1
    for epoch in range(start_epoch, total_epoch):
        if epoch >= xargs.early_stop_epoch:
            logger.log(f"Early stop @ {epoch} epoch.")
            break
        if xargs.perturb:
            # Linearly ramp the perturbation radius from 0.03 to epsilon_alpha.
            epsilon_alpha = 0.03 + (xargs.epsilon_alpha -
                                    0.03) * epoch / total_epoch
            logger.log(f'epoch {epoch} epsilon_alpha {epsilon_alpha}')
        else:
            epsilon_alpha = None

        w_scheduler.update(epoch, 0.0)
        need_time = 'Time Left: {:}'.format(
            convert_secs2time(epoch_time.val * (total_epoch - epoch), True))
        epoch_str = '{:03d}-{:03d}'.format(epoch, total_epoch)
        logger.log('\n[Search the {:}-th epoch] {:}, LR={:}'.format(
            epoch_str, need_time, min(w_scheduler.get_lr())))

        search_w_loss, search_w_top1, search_w_top5, search_a_loss, search_a_top1, search_a_top5 = search_func(
            search_loader, network, criterion, w_scheduler, w_optimizer,
            a_optimizer, epoch_str, xargs.print_freq, logger,
            xargs.gradient_clip, perturb_alpha, epsilon_alpha)
        search_time.update(time.time() - start_time)
        logger.log(
            '[{:}] searching : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s'
            .format(epoch_str, search_w_loss, search_w_top1, search_w_top5,
                    search_time.sum))
        valid_a_loss, valid_a_top1, valid_a_top5 = valid_func(
            valid_loader, network, criterion)

        writer.add_scalar('search/weight_loss', search_w_loss, epoch)
        writer.add_scalar('search/weight_top1_acc', search_w_top1, epoch)
        writer.add_scalar('search/weight_top5_acc', search_w_top5, epoch)

        writer.add_scalar('search/arch_loss', search_a_loss, epoch)
        writer.add_scalar('search/arch_top1_acc', search_a_top1, epoch)
        writer.add_scalar('search/arch_top5_acc', search_a_top5, epoch)

        writer.add_scalar('evaluate/loss', valid_a_loss, epoch)
        writer.add_scalar('evaluate/top1_acc', valid_a_top1, epoch)
        writer.add_scalar('evaluate/top5_acc', valid_a_top5, epoch)
        logger.log(
            '[{:}] evaluate  : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'
            .format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5))
        writer.add_scalar('entropy', search_model.entropy, epoch)
        per_edge_dict = get_per_egde_value_dict(search_model.arch_parameters)
        for edge_name, edge_val in per_edge_dict.items():
            writer.add_scalars(f"cell/{edge_name}", edge_val, epoch)
        # check the best accuracy
        valid_accuracies[epoch] = valid_a_top1
        if valid_a_top1 > valid_accuracies['best']:
            valid_accuracies['best'] = valid_a_top1
            genotypes['best'] = search_model.genotype()
            find_best = True
        else:
            find_best = False

        genotypes[epoch] = search_model.genotype()
        logger.log('<<<--->>> The {:}-th epoch : {:}'.format(
            epoch_str, genotypes[epoch]))
        # save checkpoint
        save_path = save_checkpoint(
            {
                'epoch': epoch + 1,
                'args': deepcopy(xargs),
                'search_model': search_model.state_dict(),
                'w_optimizer': w_optimizer.state_dict(),
                'a_optimizer': a_optimizer.state_dict(),
                'w_scheduler': w_scheduler.state_dict(),
                'genotypes': genotypes,
                'valid_accuracies': valid_accuracies
            }, model_base_path, logger)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'args': deepcopy(xargs),
                'last_checkpoint': save_path,
            }, logger.path('info'), logger)

        if xargs.snapshoot > 0 and epoch % xargs.snapshoot == 0:
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'args': deepcopy(xargs),
                    'search_model': search_model.state_dict(),
                },
                os.path.join(str(logger.model_dir),
                             f"checkpoint_epoch{epoch}.pth"), logger)

        if find_best:
            logger.log(
                '<<<--->>> The {:}-th epoch : find the highest validation accuracy : {:.2f}%.'
                .format(epoch_str, valid_a_top1))
            copy_checkpoint(model_base_path, model_best_path, logger)
        with torch.no_grad():
            logger.log('{:}'.format(search_model.show_alphas()))
        if api is not None:
            logger.log('{:}'.format(api.query_by_arch(genotypes[epoch])))
            index = api.query_index_by_arch(genotypes[epoch])
            info = api.query_meta_info_by_index(
                index)  # This is an instance of `ArchResults`
            res_metrics = info.get_metrics(
                f'{xargs.dataset}',
                'ori-test')  # This is a dict with metric names as keys
            writer.add_scalar(f'{xargs.dataset}_ground_acc_ori-test',
                              res_metrics['accuracy'], epoch)
            writer.add_scalar(f'{xargs.dataset}_search_acc', valid_a_top1,
                              epoch)
            if xargs.dataset.lower() != 'cifar10':
                writer.add_scalar(
                    f'{xargs.dataset}_ground_acc_x-test',
                    info.get_metrics(f'{xargs.dataset}', 'x-test')['accuracy'],
                    epoch)
            if find_best:
                valid_accuracies['best_gt'] = res_metrics['accuracy']
            # A checkpoint resumed from the first epoch is saved before
            # 'best_gt' is filled in, so the key may be missing here.
            if 'best_gt' in valid_accuracies:
                writer.add_scalar(f"{xargs.dataset}_cur_best_gt_acc_ori-test",
                                  valid_accuracies['best_gt'], epoch)

        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()

    logger.log('\n' + '-' * 100)
    logger.log('{:} : run {:} epochs, cost {:.1f} s, last-geno is {:}.'.format(
        xargs.model, xargs.early_stop_epoch, search_time.sum,
        genotypes[xargs.early_stop_epoch - 1]))
    if api is not None:
        logger.log('{:}'.format(
            api.query_by_arch(genotypes[xargs.early_stop_epoch - 1])))
    writer.close()  # flush pending TensorBoard events
    logger.close()
Esempio n. 3
0
# Build the search network directly. The equivalent config dict would be
# {'C': 16, 'N': 5, 'num_classes': num_classes, 'max_nodes': 4,
#  'search_space': NAS_BENCH_201, 'affine': False}.
model = TinyNetwork(C=args.init_channels, N=args.num_cells, max_nodes=args.max_nodes,
                    num_classes=num_classes, search_space=NAS_BENCH_201, affine=False,
                    track_running_stats=args.track_running_stats)
model = model.to(device)

optimizer, _, criterion = get_optim_scheduler(parameters=model.get_weights(), config=config)
# Put the loss on the same device as the model instead of hard-coding CUDA.
criterion = criterion.to(device)
logging.info(f'optimizer: {optimizer}\nCriterion: {criterion}')

# Log the benchmark scores of the initial (epoch-0) architecture.
best_arch_per_epoch = []

init_genotype = model.genotype()
arch_str = init_genotype.tostr()
arch_index = api.query_index_by_arch(init_genotype)
# For CIFAR-10 the validation score is looked up under 'cifar10-valid';
# every other dataset uses its own name for both lookups.
init_valid_dataset = 'cifar10-valid' if args.dataset == 'cifar10' else args.dataset
test_acc = get_arch_score(api, arch_index, args.dataset, 200, acc_type)
valid_acc = get_arch_score(api, arch_index, init_valid_dataset, 200, val_acc_type)
writer.add_scalar("test_acc", test_acc, 0)
writer.add_scalar("valid_acc", valid_acc, 0)
best_arch_per_epoch.append((arch_str, test_acc, valid_acc))
'''
optimizer = torch.optim.SGD(model.parameters(), args.learning_rate, momentum = args.momentum, weight_decay = args.weight_decay)
Esempio n. 4
0
def _sampled_valid_accuracy(network, valid_loader, valid_iter, weights,
                            num_batches=10):
    """Estimate the top-1 validation accuracy of one fixed architecture.

    Runs ``network`` with the architecture ``weights`` on the next
    ``num_batches`` batches of ``valid_iter``. The iterator is created on
    first use (``valid_iter is None``) and restarted when the loader is
    exhausted. Returns ``(average_top1, valid_iter)`` so the caller keeps
    drawing from the same iterator.
    """
    meter = AverageMeter()
    if valid_iter is None:
        valid_iter = iter(valid_loader)
    for _ in range(num_batches):
        try:
            val_input, val_target = next(valid_iter)
        except StopIteration:
            # Loader exhausted: start a new pass over the validation set.
            valid_iter = iter(valid_loader)
            val_input, val_target = next(valid_iter)
        with torch.no_grad():
            val_target = val_target.cuda(non_blocking=True)
            _, logits, _ = network(val_input, weights=weights)
            val_acc1, _ = obtain_accuracy(logits.data,
                                          val_target.data,
                                          topk=(1, 5))
            meter.update(val_acc1.item(), val_input.size(0))
    return meter.avg, valid_iter


def _benchmark_test_accuracy(api, genotype, dataset):
    """Return the benchmark 'ori-test' accuracy of ``genotype`` on ``dataset``.

    Returns 0.0 when no benchmark ``api`` is available.
    """
    if api is None:
        return 0.0
    index = api.query_index_by_arch(genotype)
    info = api.query_meta_info_by_index(index)
    return info.get_metrics(dataset, 'ori-test')['accuracy']


def main(xargs):
    """Run a GDAS search, then re-rank architectures sampled from it.

    During the first ``total_epoch`` epochs the supernet weights and
    architecture parameters are trained with ``search_func``. If
    ``config.t_epochs`` is set, 100 architectures are sampled at the end of
    the search and the weights are trained for ``t_epochs`` more epochs with
    ``train_func`` on those fixed samples. Otherwise 100 architectures are
    sampled every 50 epochs. Afterwards each sample is scored on a few
    validation batches, and the top 10 of the final sampling are re-scored
    on the full validation set.

    :param xargs: parsed command-line arguments for the search run
    """
    assert torch.cuda.is_available(), 'CUDA is not available.'
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.set_num_threads(xargs.workers)
    prepare_seed(xargs.rand_seed)
    logger = prepare_logger(xargs)

    train_data, valid_data, xshape, class_num = get_datasets(
        xargs.dataset, xargs.data_path, -1)
    config = load_config(xargs.config_path, {
        'class_num': class_num,
        'xshape': xshape
    }, logger)
    search_loader, _, valid_loader = get_nas_search_loaders(
        train_data, valid_data, xargs.dataset, 'configs/nas-benchmark/',
        config.batch_size, xargs.workers)
    logger.log(
        '||||||| {:10s} ||||||| Search-Loader-Num={:}, batch size={:}'.format(
            xargs.dataset, len(search_loader), config.batch_size))
    logger.log('||||||| {:10s} ||||||| Config={:}'.format(
        xargs.dataset, config))

    search_space = get_search_spaces('cell', xargs.search_space_name)
    if xargs.model_config is None:
        model_config = dict2config(
            {
                'name': 'GDAS',
                'C': xargs.channel,
                'N': xargs.num_cells,
                'max_nodes': xargs.max_nodes,
                'num_classes': class_num,
                'space': search_space,
                # Constrained searches pass a 32-pixel input size; others 0.
                'inp_size': 32 if xargs.constrain else 0,
                'affine': False,
                'track_running_stats': bool(xargs.track_running_stats)
            }, None)
    else:
        model_config = load_config(
            xargs.model_config, {
                'num_classes': class_num,
                'space': search_space,
                'affine': False,
                'track_running_stats': bool(xargs.track_running_stats)
            }, None)
    search_model = get_cell_based_tiny_net(model_config)
    logger.log('model-config : {:}'.format(model_config))

    w_optimizer, w_scheduler, criterion = get_optim_scheduler(
        search_model.get_weights(), config)
    a_optimizer = torch.optim.Adam(search_model.get_alphas(),
                                   lr=xargs.arch_learning_rate,
                                   betas=(0.5, 0.999),
                                   weight_decay=xargs.arch_weight_decay)
    logger.log('w-optimizer : {:}'.format(w_optimizer))
    logger.log('a-optimizer : {:}'.format(a_optimizer))
    logger.log('w-scheduler : {:}'.format(w_scheduler))
    logger.log('criterion   : {:}'.format(criterion))
    flop, param = get_model_infos(search_model, xshape)
    logger.log('FLOP = {:.2f} M, Params = {:.2f} MB'.format(flop, param))
    logger.log('search-space [{:} ops] : {:}'.format(len(search_space),
                                                     search_space))
    if xargs.arch_nas_dataset is None:
        api = None
    else:
        api = API(xargs.arch_nas_dataset)
    logger.log('{:} create API = {:} done'.format(time_string(), api))

    last_info, model_base_path, model_best_path = logger.path(
        'info'), logger.path('model'), logger.path('best')
    network, criterion = torch.nn.DataParallel(
        search_model).cuda(), criterion.cuda()

    if last_info.exists():  # automatically resume from previous checkpoint
        logger.log("=> loading checkpoint of the last-info '{:}' start".format(
            last_info))
        last_info = torch.load(last_info)
        start_epoch = last_info['epoch']
        checkpoint = torch.load(last_info['last_checkpoint'])
        genotypes = checkpoint['genotypes']
        valid_accuracies = checkpoint['valid_accuracies']
        search_model.load_state_dict(checkpoint['search_model'])
        w_scheduler.load_state_dict(checkpoint['w_scheduler'])
        w_optimizer.load_state_dict(checkpoint['w_optimizer'])
        a_optimizer.load_state_dict(checkpoint['a_optimizer'])
        logger.log(
            "=> loading checkpoint of the last-info '{:}' start with {:}-th epoch."
            .format(last_info, start_epoch))
    else:
        logger.log("=> do not find the last-info file : {:}".format(last_info))
        start_epoch, valid_accuracies, genotypes = 0, {
            'best': -1
        }, {
            -1: search_model.genotype()
        }

    # start training
    start_time, search_time, epoch_time = time.time(), AverageMeter(
    ), AverageMeter()
    total_epoch = config.epochs + config.warmup
    sampled_weights = []
    valid_iter = None  # created lazily by _sampled_valid_accuracy
    for epoch in range(start_epoch, total_epoch + config.t_epochs):
        w_scheduler.update(epoch, 0.0)
        need_time = 'Time Left: {:}'.format(
            convert_secs2time(
                epoch_time.val * (total_epoch - epoch + config.t_epochs),
                True))
        epoch_str = '{:03d}-{:03d}'.format(epoch, total_epoch)
        # Anneal tau linearly from tau_max to tau_min over the search phase
        # and hold it at tau_min during the extra t_epochs. Unclamped, it
        # kept falling (possibly below zero); the max() also avoids a
        # division by zero when total_epoch == 1.
        progress = min(epoch, total_epoch - 1) / max(total_epoch - 1, 1)
        search_model.set_tau(xargs.tau_max -
                             (xargs.tau_max - xargs.tau_min) * progress)
        logger.log('\n[Search the {:}-th epoch] {:}, tau={:}, LR={:}'.format(
            epoch_str, need_time, search_model.get_tau(),
            min(w_scheduler.get_lr())))
        if epoch < total_epoch:
            (search_w_loss, search_w_top1, search_w_top5, valid_a_loss,
             valid_a_top1, valid_a_top5) = search_func(
                 search_loader, network, criterion, w_scheduler, w_optimizer,
                 a_optimizer, epoch_str, xargs.print_freq, logger,
                 xargs.bilevel)
        else:
            if not sampled_weights:
                # The sampled architectures are not stored in checkpoints, so
                # this phase cannot be resumed in the middle.
                raise RuntimeError(
                    'No sampled architectures available for the fixed-architecture '
                    'training phase (resuming past the search phase is unsupported).')
            (search_w_loss, search_w_top1, search_w_top5, valid_a_loss,
             valid_a_top1, valid_a_top5, arch_iter) = train_func(
                 search_loader, network, criterion, w_scheduler, w_optimizer,
                 epoch_str, xargs.print_freq, sampled_weights[0], arch_iter,
                 logger)
        search_time.update(time.time() - start_time)
        logger.log(
            '[{:}] searching : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s'
            .format(epoch_str, search_w_loss, search_w_top1, search_w_top5,
                    search_time.sum))
        logger.log(
            '[{:}] evaluate  : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'
            .format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5))

        if (epoch + 1) % 50 == 0 and not config.t_epochs:
            sampled_weights.append(search_model.sample_weights(100))
        elif (epoch + 1) == total_epoch and config.t_epochs:
            weights = search_model.sample_weights(100)
            sampled_weights.append(weights)
            arch_iter = iter(weights)
        # validate with single arch
        single_weight = search_model.sample_weights(1)[0]
        network.eval()
        single_valid_acc, valid_iter = _sampled_valid_accuracy(
            network, valid_loader, valid_iter, single_weight)
        logger.log('[{:}] valid : accuracy = {:.2f}'.format(
            epoch_str, single_valid_acc))

        # check the best accuracy
        valid_accuracies[epoch] = valid_a_top1
        if valid_a_top1 > valid_accuracies['best']:
            valid_accuracies['best'] = valid_a_top1
            genotypes['best'] = search_model.genotype()
            find_best = True
        else:
            find_best = False

        if epoch < total_epoch:
            genotypes[epoch] = search_model.genotype()
            logger.log('<<<--->>> The {:}-th epoch : {:}'.format(
                epoch_str, genotypes[epoch]))
        # save checkpoint
        save_path = save_checkpoint(
            {
                'epoch': epoch + 1,
                'args': deepcopy(xargs),
                'search_model': search_model.state_dict(),
                'w_optimizer': w_optimizer.state_dict(),
                'a_optimizer': a_optimizer.state_dict(),
                'w_scheduler': w_scheduler.state_dict(),
                'genotypes': genotypes,
                'valid_accuracies': valid_accuracies
            }, model_base_path, logger)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'args': deepcopy(xargs),
                'last_checkpoint': save_path,
            }, logger.path('info'), logger)
        if find_best:
            logger.log(
                '<<<--->>> The {:}-th epoch : find the highest validation accuracy : {:.2f}%.'
                .format(epoch_str, valid_a_top1))
            copy_checkpoint(model_base_path, model_best_path, logger)
        with torch.no_grad():
            logger.log('{:}'.format(search_model.show_alphas()))
        if api is not None and epoch < total_epoch:
            logger.log('{:}'.format(api.query_by_arch(genotypes[epoch])))

        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()

    network.eval()
    # Evaluate the architectures sampled at the intermediate (every-50-epoch)
    # checkpoints; the last sampling is handled separately below.
    for sample_idx, weight_batch in enumerate(sampled_weights[:-1]):
        logger.log('Sample eval : epoch {}'.format((sample_idx + 1) * 50 - 1))
        for w in weight_batch:
            sample_acc, valid_iter = _sampled_valid_accuracy(
                network, valid_loader, valid_iter, w)
            # Look up the dataset being searched (was hard-coded to cifar10).
            acc = _benchmark_test_accuracy(api, search_model.genotype(w),
                                           xargs.dataset)
            logger.log(
                'sample valid : val_acc = {:.2f} test_acc = {:.2f}'.format(
                    sample_acc, acc))
    # Evaluate the final sampling separately to find the top 10 architectures
    logger.log('Final sample eval')
    final_archs = []
    final_weights = sampled_weights[-1] if sampled_weights else []
    for w in final_weights:
        sample_acc, valid_iter = _sampled_valid_accuracy(
            network, valid_loader, valid_iter, w)
        acc = _benchmark_test_accuracy(api, search_model.genotype(w),
                                       xargs.dataset)
        logger.log('sample valid : val_acc = {:.2f} test_acc = {:.2f}'.format(
            sample_acc, acc))
        final_archs.append((w, sample_acc))
    top_10 = sorted(final_archs, key=lambda x: x[1], reverse=True)[:10]
    # Evaluate the top 10 architectures on the entire validation set
    logger.log('Evaluating top archs')
    for w, prev_acc in top_10:
        full_valid_acc = AverageMeter()
        for val_input, val_target in valid_loader:
            with torch.no_grad():
                val_target = val_target.cuda(non_blocking=True)
                _, logits, _ = network(val_input, weights=w)
                val_acc1, _ = obtain_accuracy(logits.data,
                                              val_target.data,
                                              topk=(1, 5))
                full_valid_acc.update(val_acc1.item(), val_input.size(0))
        w_gene = search_model.genotype(w)
        logger.log('genotype {}'.format(w_gene))
        acc = _benchmark_test_accuracy(api, w_gene, xargs.dataset)
        logger.log(
            'full valid : val_acc = {:.2f} test_acc = {:.2f} pval_acc = {:.2f}'
            .format(full_valid_acc.avg, acc, prev_acc))

    logger.log('\n' + '-' * 100)
    # check the performance from the architecture dataset
    logger.log(
        'GDAS : run {:} epochs, cost {:.1f} s, last-geno is {:}.'.format(
            total_epoch, search_time.sum, genotypes[total_epoch - 1]))
    if api is not None:
        logger.log('{:}'.format(api.query_by_arch(genotypes[total_epoch - 1])))
    logger.close()
Esempio n. 5
0
class MacroGraph(NodeOpGraph):
    """NAS-Bench-201 macro (network-level) graph made of a linear chain of nodes.

    The chain built by ``_build_graph`` is:
    input -> stem -> cells -> BatchNorm+ReLU -> global average pooling -> linear classifier.
    Normal positions hold searchable ``Cell`` objects; the two reduction positions hold
    fixed ``ResNetBasicblock`` modules (stride 2).
    """

    def __init__(self, config, primitives, ops_dict, *args, **kwargs):
        """
        :param config: dict read for 'num_cells_per_stack', 'init_channels' and 'num_classes'
        :param primitives: candidate operation names for the cell edges
        :param ops_dict: operation mapping handed to every normal ``Cell``
        """
        self.config = config
        self.primitives = primitives
        self.ops_dict = ops_dict
        # NOTE(review): hard-coded, user-specific path; the benchmark file is loaded on every
        # construction, including every call to ``sample`` (which builds a new MacroGraph).
        self.nasbench_api = API('/home/siemsj/nasbench_201.pth')
        super(MacroGraph, self).__init__(*args, **kwargs)

    def _build_graph(self):
        """Add all nodes and the chain of 'previous' edges to this graph."""
        num_cells_per_stack = self.config['num_cells_per_stack']
        C = self.config['init_channels']
        # Three stacks of normal cells at C, 2C and 4C channels, separated by two reductions.
        layer_channels = ([C] * num_cells_per_stack + [C * 2]
                          + [C * 2] * num_cells_per_stack + [C * 4]
                          + [C * 4] * num_cells_per_stack)
        layer_reductions = ([False] * num_cells_per_stack + [True]
                            + [False] * num_cells_per_stack + [True]
                            + [False] * num_cells_per_stack)

        stem = NASBENCH_201_Stem(C=C)
        self.add_node(0, type='input')
        self.add_node(1, op=stem, type='stem')

        C_prev = C
        self.cells = nn.ModuleList()
        for cell_num, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)):
            if reduction:
                # Reduction positions are fixed residual blocks, not searchable cells.
                cell = ResNetBasicblock(C_prev, C_curr, 2, True)
                self.add_node(cell_num + 2, op=cell, primitives=self.primitives, transform=lambda x: x[0])
            else:
                cell = Cell(primitives=self.primitives, stride=1, C_prev=C_prev, C=C_curr,
                            ops_dict=self.ops_dict, cell_type='normal')
                self.add_node(cell_num + 2, op=cell, primitives=self.primitives)

            C_prev = C_curr

        lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
        pooling = nn.AdaptiveAvgPool2d(1)
        classifier = nn.Linear(C_prev, self.config['num_classes'])

        self.add_node(cell_num + 3, op=lastact, transform=lambda x: x[0], type='postprocessing_nb201')
        self.add_node(cell_num + 4, op=pooling, transform=lambda x: x[0], type='pooling')
        # Flatten the pooled feature map before the linear classifier.
        self.add_node(cell_num + 5, op=classifier, transform=lambda x: x[0].view(x[0].size(0), -1),
                      type='output')

        # Edges: a simple chain 0 -> 1 -> ... -> output node.
        for i in range(1, cell_num + 6):
            self.add_edge(i - 1, i, type='input', desc='previous')

    def sample(self, same_cell_struct=True, n_ops_per_edge=1,
               n_input_edges=None, dist=None, seed=1):
        """Return a new MacroGraph whose cell edges keep only a random subset of op choices.

        same_cell_struct:
            True; reseed per node type so every cell of the same type samples the same structure
        n_ops_per_edge:
            1; number of operations sampled (without replacement) per cell edge
        n_input_edges:
            None; list with one entry per intermediate node giving how many predecessor
            nodes it keeps; the remaining input edges are removed
        dist:
            None; probability distribution passed as ``p`` to ``np.random.choice``
        seed:
            1; random seed
        """
        # create a new graph that we will discretize
        new_graph = MacroGraph(self.config, self.primitives, self.ops_dict)
        np.random.seed(seed)
        seeds = {'normal': seed + 1, 'reduction': seed + 2}

        for node in new_graph:
            cell = new_graph.get_node_op(node)
            if not isinstance(cell, Cell):
                continue

            if same_cell_struct:
                # NOTE(review): _build_graph adds normal cells without a 'type' attribute;
                # confirm get_node_type returns a key present in ``seeds``.
                np.random.seed(seeds[new_graph.get_node_type(node)])

            for edge in cell.edges:
                op_choices = cell.get_edge_op_choices(*edge)
                sampled_op = np.random.choice(op_choices, n_ops_per_edge,
                                              False, p=dist)
                cell[edge[0]][edge[1]]['op_choices'] = [*sampled_op]

            if n_input_edges is not None:
                for inter_node, k in zip(cell.inter_nodes(), n_input_edges):
                    # in case the start node index is not 0
                    node_idx = list(cell.nodes).index(inter_node)
                    prev_node_choices = list(cell.nodes)[:node_idx]
                    # The message was previously split across two statements, so the
                    # second half was a no-op string and never part of the assertion.
                    assert k <= len(prev_node_choices), \
                        'cannot sample more than number of predecessor nodes'

                    sampled_input_edges = np.random.choice(prev_node_choices,
                                                           k, False)
                    for i in set(prev_node_choices) - set(sampled_input_edges):
                        cell.remove_edge(i, inter_node)

        return new_graph

    @classmethod
    def from_config(cls, config=None, filename=None):
        """Rebuild a graph from a YAML dump with 'edges' and 'nodes' sections.

        :param config: model configuration dict (required)
        :param filename: path of the YAML file to load
        :raises ValueError: if ``config`` is None
        """
        if config is None:
            # Previously ``raise ('...')``, which raises TypeError instead of the intended error.
            raise ValueError('No configuration provided')

        with open(filename, 'r') as f:
            graph_dict = yaml.safe_load(f)

        # NOTE(review): __init__ requires ``ops_dict`` positionally, so this call appears to be
        # missing an argument; confirm the intended value.
        graph = cls(config, [])

        # SECURITY NOTE(review): ``eval`` runs on keys read from the YAML file; only load
        # trusted files (``ast.literal_eval`` would suffice for tuple keys).
        edges = [(*eval(e), attr) for e, attr in graph_dict['edges'].items()]
        graph.clear()
        graph.add_edges_from(edges)

        # NOTE(review): the Stem / C_prev_prev / reduction_prev bookkeeping below follows a
        # two-input (DARTS-style) cell layout rather than the NB201 layout of _build_graph.
        C = config['init_channels']
        C_curr = config['stem_multiplier'] * C

        stem = Stem(C_curr=C_curr)
        C_prev_prev, C_prev, C_curr = C_curr, C_curr, C

        for node, attr in graph_dict['nodes'].items():
            node_type = attr['type']
            if node_type == 'input':
                graph.add_node(node, type='input')
            elif node_type == 'stem':
                graph.add_node(node, op=stem, type='stem')
            elif node_type in ['normal', 'reduction']:
                assert attr['op']['type'] == 'Cell'
                graph.add_node(node,
                               op=Cell.from_config(attr['op'], primitives=attr['op']['primitives'],
                                                   C_prev_prev=C_prev_prev, C_prev=C_prev,
                                                   C=C_curr,
                                                   reduction_prev=graph_dict['nodes'][node - 1]['type'] == 'reduction',
                                                   cell_type=node_type),
                               type=node_type)
                C_prev_prev, C_prev = C_prev, config['channel_multiplier'] * C_curr
            elif node_type == 'pooling':
                pooling = nn.AdaptiveAvgPool2d(1)
                graph.add_node(node, op=pooling, transform=lambda x: x[0],
                               type='pooling')
            elif node_type == 'output':
                classifier = nn.Linear(C_prev, config['num_classes'])
                graph.add_node(node, op=classifier, transform=lambda x:
                x[0].view(x[0].size(0), -1), type='output')

        return graph

    @staticmethod
    def export_nasbench_201_results_to_dict(information):
        """Flatten a NAS-Bench-201 result object into a plain dict.

        The returned dict has 'arch' and 'datasets' entries plus one sub-dict per dataset
        with compute costs and train metrics. 'cifar10-valid' adds valid metrics,
        'cifar10' adds 'ori-test' metrics, and any other dataset adds both
        'x-valid' and 'x-test' metrics.
        """
        results_dict = {}
        dataset_names = information.get_dataset_names()
        results_dict['arch'] = information.arch_str
        results_dict['datasets'] = dataset_names

        for dataset in dataset_names:
            dataset_results = {}
            dataset_results['dataset'] = dataset

            metric = information.get_compute_costs(dataset)
            latency = metric['latency']
            dataset_results['flop'] = metric['flops']
            dataset_results['params (MB)'] = metric['params']
            # Latency is reported in seconds; non-positive / missing values are treated as unknown.
            dataset_results['latency (ms)'] = latency * 1000 if latency is not None and latency > 0 else None
            dataset_results['training_time'] = metric['T-train@total']

            train_info = information.get_metrics(dataset, 'train')
            dataset_results['train_loss'] = train_info['loss']
            dataset_results['train_accuracy'] = train_info['accuracy']

            # 'cifar10' (train+valid training) has no separate validation split.
            if dataset != 'cifar10':
                valid_info = information.get_metrics(dataset, 'x-valid')
                dataset_results['valid_loss'] = valid_info['loss']
                dataset_results['valid_accuracy'] = valid_info['accuracy']

            # 'cifar10-valid' reports no test split here.
            if dataset != 'cifar10-valid':
                test_setname = 'ori-test' if dataset == 'cifar10' else 'x-test'
                test_info = information.get_metrics(dataset, test_setname)
                dataset_results['test_loss'] = test_info['loss']
                dataset_results['test_accuracy'] = test_info['accuracy']

            results_dict[dataset] = dataset_results
        return results_dict

    def query_architecture(self, arch_weights):
        """Discretize ``arch_weights`` and look the resulting architecture up in NAS-Bench-201.

        :param arch_weights: dict mapping cell edge keys (e.g. 'cell_normal_from_0_to_1') to
            weight tensors; presumably iterated in the same edge order as ``arch_strs`` below.
        :return: the dict produced by ``export_nasbench_201_results_to_dict``
        """
        # Source node index of each of the six NB201 edges, in iteration order.
        arch_weight_idx_to_parent = {0: 0, 1: 0, 2: 1, 3: 0, 4: 1, 5: 2}
        arch_strs = {
            'cell_normal_from_0_to_1': '',
            'cell_normal_from_0_to_2': '',
            'cell_normal_from_1_to_2': '',
            'cell_normal_from_0_to_3': '',
            'cell_normal_from_1_to_3': '',
            'cell_normal_from_2_to_3': '',
        }
        for arch_weight_idx, (edge_key, edge_weights) in enumerate(arch_weights.items()):
            # Pick the operation with the highest softmax weight on this edge.
            edge_weights_norm = torch.softmax(edge_weights, dim=-1)
            selected_op_str = PRIMITIVES[edge_weights_norm.argmax()]
            arch_strs[edge_key] = '{}~{}'.format(selected_op_str, arch_weight_idx_to_parent[arch_weight_idx])

        arch_str = '|{}|+|{}|{}|+|{}|{}|{}|'.format(*arch_strs.values())
        if not hasattr(self, 'nasbench_api'):
            self.nasbench_api = API('/home/siemsj/nasbench_201.pth')
        index = self.nasbench_api.query_index_by_arch(arch_str)
        self.nasbench_api.show(index)
        info = self.nasbench_api.query_by_index(index)
        return self.export_nasbench_201_results_to_dict(info)
Esempio n. 6
0
def make(ww="without"):
    """Collect NAS-Bench-201 CIFAR-10 test accuracies for saved search trajectories.

    For every (mode, way, trial) combination (skipping mode "P" with way "best"), the
    architectures read by ``opener`` from ``./snapshot/<ww>/<way>/<trial>_<mode>.txt`` are
    looked up in NAS-Bench-201. Per-step accuracies are stored in ``macc[mode + way]`` (one
    row per trial) and the parameter count of the last successfully queried architecture
    in ``nparams[mode + way]``. Both dicts are pickled to ``./<ww>.pickle``.

    :param ww: snapshot sub-directory / output file stem
    :return: ``macc``
    """
    from nas_201_api import NASBench201API as API
    api = API('/Users/madoibito80/NAS-Bench-201-v1_0-e61699.pth')

    macc = {}
    nparams = {}

    for mode in modes:
        for way in ways:
            if mode == "P" and way == "best":
                continue
            key = mode + way

            for trial in trials:
                fname = str(trial) + "_" + mode
                arcs = opener("./snapshot/" + ww + "/" + way + "/" + fname +
                              ".txt",
                              freeze=50,
                              mode=mode)
                accs = []
                # Last successfully queried architecture of this trial (None if none succeeded).
                info = None
                for i, arc in enumerate(arcs):
                    if i % 100 == 0:
                        print(mode, way, trial, i)
                    index = api.query_index_by_arch(arc)
                    try:
                        info = api.query_meta_info_by_index(index)
                    except Exception:
                        print("error: ", mode, way, trial, i)
                        # Carry the previous accuracy forward; NaN if nothing succeeded yet
                        # (previously this raised IndexError on the first step).
                        accs.append(accs[-1] if accs else float('nan'))
                        continue

                    # cifar10: model trained on the CIFAR-10 train + valid set, evaluated on
                    # 'ori-test'. The trailing False is passed as is_random (see the API docs).
                    res = info.get_metrics('cifar10', 'ori-test', None, False)
                    accs.append(float(res['accuracy']))

                if trial == 0:
                    macc[key] = np.zeros((len(trials), len(accs)))
                    nparams[key] = np.zeros((len(trials), 1))
                macc[key][trial] = np.array(accs)

                # Final parameter count for the table (was misspelled get_comput_costs).
                if info is not None:
                    metric = info.get_compute_costs('cifar10')
                    nparams[key][trial] = float(metric['params'])

    with open("./" + ww + ".pickle", "wb") as f:
        pickle.dump(macc, f)
        pickle.dump(nparams, f)

    print(macc)
    print(nparams)
    return macc
Esempio n. 7
0
class Nasbench201:
    """Wrapper around the NAS-Bench-201 API that addresses architectures by integer uid.

    A uid is simply the benchmark index (0 .. 15624). Network configs, 12-epoch
    accuracies and training times are always taken from the 'cifar10-valid' split.
    """

    def __init__(self, dataset, apiloc):
        self.dataset = dataset
        self.api = API(apiloc, verbose=False)
        # Hyper-parameter tag passed as ``hp`` for the short (12-epoch) schedule.
        self.epochs = '12'

    def get_network(self, uid):
        """Build the cell-based network for ``uid`` with a single output unit."""
        net_config = self.api.get_net_config(uid, 'cifar10-valid')
        net_config['num_classes'] = 1
        return get_cell_based_tiny_net(net_config)

    def __iter__(self):
        """Yield ``(uid, network)`` for every architecture in the benchmark."""
        for arch_id in range(len(self)):
            yield arch_id, self.get_network(arch_id)

    def __getitem__(self, index):
        """Identity mapping: the key of an architecture is its index."""
        return index

    def __len__(self):
        """Number of architectures in NAS-Bench-201."""
        return 15625

    def num_activations(self):
        """Width of the classifier input, read from the network of uid 0."""
        return self.get_network(0).classifier.in_features

    def get_12epoch_accuracy(self,
                             uid,
                             acc_type,
                             trainval,
                             traincifar10=False):
        """Validation accuracy on 'cifar10-valid' after the 12-epoch schedule (random seed).

        ``acc_type``, ``trainval`` and ``traincifar10`` are accepted for interface
        compatibility but are not used.
        """
        more_info = self.api.get_more_info(uid,
                                           'cifar10-valid',
                                           iepoch=None,
                                           hp=self.epochs,
                                           is_random=True)
        return more_info['valid-accuracy']

    def get_final_accuracy(self, uid, acc_type, trainval):
        """Accuracy after the 200-epoch schedule.

        For cifar10 with ``trainval`` the 'x-valid' accuracy of 'cifar10-valid' is used,
        otherwise the ``acc_type`` accuracy on ``self.dataset``.
        """
        meta = self.api.query_meta_info_by_index(uid, hp='200')
        if trainval and self.dataset == 'cifar10':
            metrics = meta.get_metrics('cifar10-valid', 'x-valid')
        else:
            metrics = meta.get_metrics(self.dataset, acc_type)
        return metrics['accuracy']

    def get_accuracy(self, uid, acc_type, trainval=True):
        """``acc_type`` accuracy with the API's default ``hp`` setting."""
        meta = self.api.query_meta_info_by_index(uid)
        use_valid_split = self.dataset == 'cifar10' and trainval
        split_name = 'cifar10-valid' if use_valid_split else self.dataset
        return meta.get_metrics(split_name, acc_type)['accuracy']

    def get_accuracy_for_all_datasets(self, uid):
        """Return (c10, c10_val, c100, c100_val, imagenet, imagenet_val) 200-epoch accuracies."""
        meta = self.api.query_meta_info_by_index(uid, hp='200')
        lookups = (
            ('cifar10', 'ori-test'),
            ('cifar10-valid', 'x-valid'),
            ('cifar100', 'x-test'),
            ('cifar100', 'x-valid'),
            ('ImageNet16-120', 'x-test'),
            ('ImageNet16-120', 'x-valid'),
        )
        return tuple(meta.get_metrics(name, split)['accuracy'] for name, split in lookups)

    def train_and_eval(self,
                       arch,
                       dataname,
                       acc_type,
                       trainval=True,
                       traincifar10=False):
        """Simulate training ``arch``: return (12-epoch accuracy, final accuracy, training time).

        ``dataname`` is accepted for interface compatibility but is not used.
        """
        key = self[arch]
        elapsed = self.get_training_time(key)
        short_acc = self.get_12epoch_accuracy(key, acc_type, trainval,
                                              traincifar10)
        long_acc = self.get_final_accuracy(key, acc_type, trainval)
        return short_acc, long_acc, elapsed

    def random_arch(self):
        """Uniformly random uid."""
        return random.randint(0, len(self) - 1)

    def get_training_time(self, unique_hash):
        """Total training time plus one validation pass on 'cifar10-valid' (12-epoch schedule)."""
        timing = self.api.get_more_info(unique_hash,
                                        'cifar10-valid',
                                        iepoch=None,
                                        hp='12',
                                        is_random=True)
        return timing['train-all-time'] + timing['valid-per-time']

    def mutate_arch(self, arch):
        """Return the uid of ``arch`` with one randomly chosen edge switched to a different op."""
        candidate_ops = get_search_spaces('cell', 'nas-bench-201')
        net_config = self.api.get_net_config(arch, 'cifar10-valid')
        mutated = deepcopy(Structure(self.api.str2lists(net_config['arch_str'])))

        target_node = random.randint(0, len(mutated.nodes) - 1)
        node_edges = list(mutated.nodes[target_node])
        target_edge = random.randint(0, len(node_edges) - 1)
        current_op = node_edges[target_edge][0]

        # Resample until the new op differs from the current one.
        new_op = random.choice(candidate_ops)
        while new_op == current_op:
            new_op = random.choice(candidate_ops)

        node_edges[target_edge] = (new_op, node_edges[target_edge][1])
        mutated.nodes[target_node] = tuple(node_edges)
        return self.api.query_index_by_arch(mutated)
Esempio n. 8
0
class NAS201(ObjectiveFunction):
    """NAS-Bench-201 objective function backed by the tabular benchmark.

    Architectures are passed as graphs whose ``name`` attribute is the NAS-Bench-201
    architecture string. Every lookup is recorded in ``X``, ``y_valid_acc``,
    ``y_test_acc`` and ``costs`` so ``get_results`` can report the regret trajectory.
    """

    def __init__(self, data_dir, task='cifar10-valid', log_scale=True, negative=True,
                 use_12_epochs_result=False,
                 seed=None):
        """
        data_dir: data directory that contains the NAS-Bench-201-v1_1-096897.pth file
        task: the target image task (if a list is given, its first element is used).
              Options: cifar10-valid, cifar100, ImageNet16-120
        log_scale: whether to output the objective in log scale
        negative: whether to output the objective in negative form
        use_12_epochs_result: whether to use the statistics at the end of the 12-epoch
                              schedule instead of the full schedule.
        seed: which trained-model seed is used for validation lookups.
              0, 1, 2 map to benchmark seeds 777, 888, 999; values >= 3 are used as-is
              (values in [3, 777) use the non-random, is_random=False result);
              None picks one of 777/888/999 at random on every lookup.
        raises NotImplementedError: for an unsupported task.
        """

        self.api = API(os.path.join(data_dir, 'NAS-Bench-201-v1_1-096897.pth'))
        if isinstance(task, list):
            task = task[0]
        self.task = task
        self.use_12_epochs_result = use_12_epochs_result

        # Benchmark optima per task (table stores accuracies in %, converted to fractions).
        if task == 'cifar10-valid':
            best_val_arch_index = 6111
            best_val_acc = 91.60666665039064 / 100
            best_test_arch_index = 1459
            best_test_acc = 91.52333333333333 / 100
        elif task == 'cifar100':
            best_val_arch_index = 9930
            best_val_acc = 73.49333323567708 / 100
            best_test_arch_index = 9930
            best_test_acc = 73.51333326009114 / 100
        elif task == 'ImageNet16-120':
            best_val_arch_index = 10676
            best_val_acc = 46.766666727701825 / 100
            best_test_arch_index = 857
            best_test_acc = 47.311111097547744 / 100
        else:
            raise NotImplementedError("task " + str(task) + " is not implemented in the dataset.")

        # Accuracies stay linear: get_results() subtracts the raw accuracies recorded by
        # _retrieve() from them. Only the errors, which mirror the objective returned by
        # _retrieve(), are log-scaled / negated. (Previously best_val_acc was log-scaled
        # before computing the error, and the test error was derived from the val error.)
        best_val_err = 1. - best_val_acc
        best_test_err = 1. - best_test_acc
        if log_scale:
            best_val_err = np.log(best_val_err)
            best_test_err = np.log(best_test_err)
        if negative:
            best_val_err = -best_val_err
            best_test_err = -best_test_err

        self.best_val_err = best_val_err
        self.best_test_err = best_test_err
        self.best_val_acc = best_val_acc
        self.best_test_acc = best_test_acc

        super(NAS201, self).__init__(dim=None, optimum_location=best_test_arch_index, optimal_val=best_test_err,
                                     bounds=None)

        self.log_scale = log_scale
        self.seed = seed
        self.X = []
        self.y_valid_acc = []
        self.y_test_acc = []
        self.costs = []
        self.negative = negative

    def _resolve_seed(self, which):
        """Return the seed used for a lookup of kind ``which``.

        Test lookups always use 3, i.e. the non-random (is_random=False) result.
        """
        if which == 'test':
            return 3
        seed_list = [777, 888, 999]
        if self.seed is None:
            return random.choice(seed_list)
        if self.seed >= 3:
            return self.seed
        return seed_list[self.seed]

    def _lookup_accuracies(self, arch_index, seed):
        """Return ``(val_acc, test_acc)`` as fractions for ``arch_index`` and ``seed``."""

        def fetch(is_random):
            info = self.api.get_more_info(arch_index, self.task, None,
                                          use_12epochs_result=self.use_12_epochs_result,
                                          is_random=is_random)
            return info['valid-accuracy'] / 100, info['test-accuracy'] / 100

        if seed is not None and 3 <= seed < 777:
            return fetch(False)
        try:
            return fetch(seed)
        except Exception:
            # some architectures only contain 1 seed result
            return fetch(False)

    def _retrieve(self, G, budget, which='eval'):
        """Look up ``G`` in the benchmark, record the result and return the objective.

        :param G: graph whose ``name`` is the NAS-Bench-201 architecture string
        :param budget: unused (kept for interface compatibility)
        :param which: 'eval' returns the validation error, 'test' the test error
        :raises ValueError: for any other ``which`` (checked before the lookup)
        """
        if which not in ('eval', 'test'):
            raise ValueError("Unknown query parameter: which = " + str(which))

        seed = self._resolve_seed(which)
        arch_str = G.name

        try:
            arch_index = self.api.query_index_by_arch(arch_str)
            val_acc, test_acc = self._lookup_accuracies(arch_index, seed)

            auxiliary_info = self.api.query_meta_info_by_index(arch_index,
                                                               use_12epochs_result=self.use_12_epochs_result)
            cost_info = auxiliary_info.get_compute_costs(self.task)

            # auxiliary cost results such as number of flops and number of parameters
            cost_results = {'flops': cost_info['flops'], 'params': cost_info['params'],
                            'latency': cost_info['latency']}

        except FileNotFoundError:
            val_acc = 0.01
            test_acc = 0.01
            print('missing arch info')
            cost_results = {'flops': None, 'params': None,
                            'latency': None}

        # store val and test performance + auxiliary cost information
        self.X.append(arch_str)
        self.y_valid_acc.append(val_acc)
        self.y_test_acc.append(test_acc)
        self.costs.append(cost_results)

        err = 1 - val_acc if which == 'eval' else 1 - test_acc
        y = np.log(err) if self.log_scale else err
        if self.negative:
            y = -y
        return y

    def eval(self, G, budget=199, n_repeat=1):
        """Return ``(objective, [nan])``; with ``n_repeat > 1`` the objective is a mean."""
        if n_repeat == 1:
            return self._retrieve(G, budget, 'eval'), [np.nan]
        return np.mean(np.array([self._retrieve(G, budget, 'eval') for _ in range(n_repeat)])), [np.nan]

    def test(self, G, budget=199, n_repeat=1):
        """Return the mean test objective over ``n_repeat`` lookups."""
        return np.mean(np.array([self._retrieve(G, budget, 'test') for _ in range(n_repeat)]))

    def get_results(self, ignore_invalid_configs=False):
        """Return the incumbent regret trajectory of all recorded lookups.

        The incumbent is the lookup with the highest validation accuracy so far; regrets
        are ``best_acc - incumbent_acc`` in linear accuracy units.
        ``ignore_invalid_configs`` is accepted but not used.
        """
        regret_validation = []
        regret_test = []
        costs = []
        model_graph_specs = []

        inc_valid = 0
        inc_test = 0

        for arch_str, valid_acc, test_acc, cost in zip(self.X, self.y_valid_acc,
                                                       self.y_test_acc, self.costs):
            if inc_valid < valid_acc:
                inc_valid = valid_acc
                inc_test = test_acc

            regret_validation.append(float(self.best_val_acc - inc_valid))
            regret_test.append(float(self.best_test_acc - inc_test))
            model_graph_specs.append(arch_str)
            costs.append(cost)

        res = dict()
        res['regret_validation'] = regret_validation
        res['regret_test'] = regret_test
        res['costs'] = costs
        res['model_graph_specs'] = model_graph_specs

        return res

    @staticmethod
    def get_configuration_space():
        """Six categorical edge hyperparameters (edge_0..edge_5) over the NB201 ops (unpruned graph)."""
        cs = ConfigSpace.ConfigurationSpace()

        ops_choices = ['nor_conv_3x3', 'nor_conv_1x1', 'avg_pool_3x3', 'skip_connect', 'none']
        for i in range(6):
            cs.add_hyperparameter(ConfigSpace.CategoricalHyperparameter("edge_%d" % i, ops_choices))
        return cs