Example #1
class Nasbench201(Dataset):
    """Nasbench201 Dataset."""
    def __init__(self):
        """Construct the Nasbench201 class."""
        super(Nasbench201, self).__init__()
        self.args.data_path = FileOps.download_dataset(self.args.data_path)
        self.nasbench201_api = API(self.args.data_path)

    def query(self, arch_str, dataset):
        """Query an item from the dataset according to the given arch_str and dataset .

        :arch_str: arch_str to define the topology of the cell
        :type path: str
        :dataset: dataset type
        :type dataset: str
        :return: an item of the dataset, which contains the network info and its results like accuracy, flops and etc
        :rtype: dict
        """
        if dataset not in VALID_DATASET:
            raise ValueError(
                "Only cifar10, cifar100, and Imagenet dataset is supported.")
        ops_list = self.nasbench201_api.str2lists(arch_str)
        for op in ops_list:
            if op not in VALID_OPS:
                raise ValueError(
                    "{} is not in the nasbench201 space.".format(op))
        index = self.nasbench201_api.query_index_by_arch(arch_str)

        results = self.nasbench201_api.query_by_index(index, dataset)
        return results
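A minimal usage sketch for the class above. It assumes the surrounding framework has already populated self.args.data_path with a local NAS-Bench-201 file and that VALID_DATASET/VALID_OPS are defined; the arch_str is just an illustrative cell encoding:

# Hypothetical usage sketch (arch_str is illustrative, not from the source).
benchmark = Nasbench201()
arch_str = '|nor_conv_3x3~0|+|nor_conv_3x3~0|skip_connect~1|+|skip_connect~0|none~1|nor_conv_1x1~2|'
results = benchmark.query(arch_str, 'cifar10')
print(results)  # dict with the network info and results such as accuracy and FLOPs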
Example #2
def main(meta_file: str, weight_dir, save_dir, xdata, use_12epochs_result, valid_or_test):
  api = API(meta_file)
  datasets = ['cifar10-valid', 'cifar10', 'cifar100', 'ImageNet16-120']
  print(time_string() + ' ' + '='*50)
  for data in datasets:
    nums = api.statistics(data, True)
    total = sum([k*v for k, v in nums.items()])
    print('Using 012 epochs, trained on {:20s} : {:} trials in total ({:}).'.format(data, total, nums))
  print(time_string() + ' ' + '='*50)
  for data in datasets:
    nums = api.statistics(data, False)
    total = sum([k*v for k, v in nums.items()])
    print('Using 200 epochs, trained on {:20s} : {:} trials in total ({:}).'.format(data, total, nums))
  print(time_string() + ' ' + '='*50)

  #evaluate(api, weight_dir, 'cifar10-valid', False, True)
  evaluate(api, weight_dir, xdata, use_12epochs_result, valid_or_test)
  
  print('{:} finish this test.'.format(time_string()))
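A hedged invocation sketch for this entry point; the file and directory names below are placeholders, not the author's actual paths, and API, evaluate, and time_string must be importable from the surrounding project:

# Placeholder invocation; meta_file should point at a local NAS-Bench-201
# snapshot such as NAS-Bench-201-v1_0-e61699.pth.
if __name__ == '__main__':
  main(meta_file='NAS-Bench-201-v1_0-e61699.pth',
       weight_dir='./checkpoints', save_dir='./output',
       xdata='cifar10-valid', use_12epochs_result=False,
       valid_or_test=True)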
Example #3
def check_cor_for_bandit(meta_file,
                         test_epoch,
                         use_less_or_not,
                         is_rand=True,
                         need_print=False):
    if isinstance(meta_file, API):
        api = meta_file
    else:
        api = API(str(meta_file))
    cifar10_currs = []
    cifar10_valid = []
    cifar10_test = []
    cifar100_valid = []
    cifar100_test = []
    imagenet_test = []
    imagenet_valid = []
    for idx, arch in enumerate(api):
        results = api.get_more_info(idx, 'cifar10-valid', test_epoch - 1,
                                    use_less_or_not, is_rand)
        cifar10_currs.append(results['valid-accuracy'])
        # --->>>>>
        results = api.get_more_info(idx, 'cifar10-valid', None, False, is_rand)
        cifar10_valid.append(results['valid-accuracy'])
        results = api.get_more_info(idx, 'cifar10', None, False, is_rand)
        cifar10_test.append(results['test-accuracy'])
        results = api.get_more_info(idx, 'cifar100', None, False, is_rand)
        cifar100_test.append(results['test-accuracy'])
        cifar100_valid.append(results['valid-accuracy'])
        results = api.get_more_info(idx, 'ImageNet16-120', None, False,
                                    is_rand)
        imagenet_test.append(results['test-accuracy'])
        imagenet_valid.append(results['valid-accuracy'])

    def get_cor(A, B):
        return float(np.corrcoef(A, B)[0, 1])

    cors = []
    for basestr, xlist in zip(
        ['C-010-V', 'C-010-T', 'C-100-V', 'C-100-T', 'I16-V', 'I16-T'], [
            cifar10_valid, cifar10_test, cifar100_valid, cifar100_test,
            imagenet_valid, imagenet_test
        ]):
        correlation = get_cor(cifar10_currs, xlist)
        if need_print:
            print(
                'With {:3d}/{:}-epochs-training, the correlation between cifar10-valid and {:} is : {:}'
                .format(test_epoch, '012' if use_less_or_not else '200',
                        basestr, correlation))
        cors.append(correlation)
        #print ('With {:3d}/200-epochs-training, the correlation between cifar10-valid and {:} is : {:}'.format(test_epoch, basestr, get_cor(cifar10_valid_200, xlist)))
        #print('-'*200)
    #print('*'*230)
    return cors
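The get_cor helper above is plain Pearson correlation; a self-contained sketch of what it computes (the accuracy lists are made-up numbers):

# np.corrcoef returns a 2x2 correlation matrix; [0, 1] selects the
# cross-correlation entry between the two lists.
import numpy as np

def get_cor(A, B):
    return float(np.corrcoef(A, B)[0, 1])

print(get_cor([92.1, 93.4, 90.8, 94.0], [71.0, 72.5, 69.9, 73.1]))  # close to 1.0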
Example #4
    def query_architecture(self, arch_weights):
        arch_weight_idx_to_parent = {0: 0, 1: 0, 2: 1, 3: 0, 4: 1, 5: 2}
        arch_strs = {
            'cell_normal_from_0_to_1': '',
            'cell_normal_from_0_to_2': '',
            'cell_normal_from_1_to_2': '',
            'cell_normal_from_0_to_3': '',
            'cell_normal_from_1_to_3': '',
            'cell_normal_from_2_to_3': '',
        }
        for arch_weight_idx, (edge_key, edge_weights) in enumerate(arch_weights.items()):
            edge_weights_norm = torch.softmax(edge_weights, dim=-1)
            selected_op_str = PRIMITIVES[edge_weights_norm.argmax()]
            arch_strs[edge_key] = '{}~{}'.format(selected_op_str, arch_weight_idx_to_parent[arch_weight_idx])

        arch_str = '|{}|+|{}|{}|+|{}|{}|{}|'.format(*arch_strs.values())
        if not hasattr(self, 'nasbench_api'):
            self.nasbench_api = API('/home/siemsj/nasbench_201.pth')
        index = self.nasbench_api.query_index_by_arch(arch_str)
        self.nasbench_api.show(index)
        info = self.nasbench_api.query_by_index(index)
        return self.export_nasbench_201_results_to_dict(info)
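For reference, a small sketch of the string packing used above; the edge choices are hypothetical, and the layout is the standard NAS-Bench-201 cell encoding of six edges grouped by destination node:

# Sketch of the '|e01|+|e02|e12|+|e03|e13|e23|' packing (edge ops hypothetical).
edges = ['nor_conv_3x3~0', 'skip_connect~0', 'nor_conv_1x1~1',
         'none~0', 'avg_pool_3x3~1', 'nor_conv_3x3~2']
arch_str = '|{}|+|{}|{}|+|{}|{}|{}|'.format(*edges)
print(arch_str)
# |nor_conv_3x3~0|+|skip_connect~0|nor_conv_1x1~1|+|none~0|avg_pool_3x3~1|nor_conv_3x3~2|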
Example #5
def main():
    logger.info("Logger is set - training start")


    # set seed
    np.random.seed(config.seed)
    torch.manual_seed(config.seed)
    torch.cuda.manual_seed_all(config.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True

    # TODO
    # api = None
    api = API('/home/hongyuan/benchmark/NAS-Bench-201-v1_0-e61699.pth')

    if config.distributed:
        config.gpu = config.local_rank % torch.cuda.device_count()
        torch.cuda.set_device(config.gpu)
        # distributed init
        torch.distributed.init_process_group(backend='nccl', init_method=config.dist_url,
                world_size=config.world_size, rank=config.local_rank)

        config.world_size = torch.distributed.get_world_size()

        config.total_batch_size = config.world_size * config.batch_size
    else:
        config.total_batch_size = config.batch_size


    loaders, samplers = get_search_datasets(config)
    train_loader, valid_loader = loaders
    train_sampler, valid_sampler = samplers

    net_crit = nn.CrossEntropyLoss().cuda()
    controller = CDARTSController(config, net_crit, n_nodes=4, stem_multiplier=config.stem_multiplier)

    resume_state = None
    if config.resume:
        resume_state = torch.load(config.resume_path, map_location='cpu')

    if config.resume:
        controller.load_state_dict(resume_state['controller'])

    controller = controller.cuda()
    if config.sync_bn:
        if config.use_apex:
            controller = apex.parallel.convert_syncbn_model(controller)
        else:
            controller = torch.nn.SyncBatchNorm.convert_sync_batchnorm(controller)

    if config.use_apex:
        controller = DDP(controller, delay_allreduce=True)
    else:
        controller = DDP(controller, device_ids=[config.gpu])

    # warm up model_search
    if config.ensemble_param:
        w_optim = torch.optim.SGD([ {"params": controller.module.feature_extractor.parameters()},
                                    {"params": controller.module.super_layers.parameters()},
                                    {"params": controller.module.fc_super.parameters()},
                                    {"params": controller.module.distill_aux_head1.parameters()},
                                    {"params": controller.module.distill_aux_head2.parameters()},
                                    {"params": controller.module.ensemble_param}],
                                    lr=config.w_lr, momentum=config.w_momentum, weight_decay=config.w_weight_decay)
    else:
        w_optim = torch.optim.SGD([ {"params": controller.module.feature_extractor.parameters()},
                                    {"params": controller.module.super_layers.parameters()},
                                    {"params": controller.module.fc_super.parameters()},
                                    {"params": controller.module.distill_aux_head1.parameters()},
                                    {"params": controller.module.distill_aux_head2.parameters()}],
                                    lr=config.w_lr, momentum=config.w_momentum, weight_decay=config.w_weight_decay)


    # search training loop
    sta_search_iter = 0
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        w_optim, config.search_iter * config.search_iter_epochs, eta_min=config.w_lr_min)
    lr_scheduler_retrain = nn.ModuleList()
    alpha_optim = nn.ModuleList()
    optimizer = nn.ModuleList()
    sub_epoch = 0

    for search_iter in range(sta_search_iter, config.search_iter):
        if search_iter < config.pretrain_epochs:
            if config.local_rank == 0:
                logger.info("####### Super model warmup #######")
            train_sampler.set_epoch(search_iter)
            retrain_warmup(train_loader, controller, w_optim, search_iter, writer, logger, True, config.pretrain_epochs, config)
            #lr_scheduler.step()
        else:
            # build new controller
            genotype = controller.module.genotype()
            controller.module.build_nas_model(genotype)

            controller_b = copy.deepcopy(controller.module)
            del controller
            controller = controller_b.cuda()

            # sync params from super layer pool
            controller.copy_params_from_super_layer()
        
            if config.sync_bn:
                if config.use_apex:
                    controller = apex.parallel.convert_syncbn_model(controller)
                else:
                    controller = torch.nn.SyncBatchNorm.convert_sync_batchnorm(controller)

            if config.use_apex:
                controller = DDP(controller, delay_allreduce=True)
            else:
                controller = DDP(controller, device_ids=[config.gpu])

            # weights optimizer
            if config.ensemble_param:
                w_optim = torch.optim.SGD([ {"params": controller.module.feature_extractor.parameters()},
                                            {"params": controller.module.super_layers.parameters()},
                                            {"params": controller.module.fc_super.parameters()},
                                            {"params": controller.module.distill_aux_head1.parameters()},
                                            {"params": controller.module.distill_aux_head2.parameters()},
                                            {"params": controller.module.ensemble_param}],
                                            lr=config.w_lr, momentum=config.w_momentum, weight_decay=config.w_weight_decay)
            else:
                w_optim = torch.optim.SGD([ {"params": controller.module.feature_extractor.parameters()},
                                            {"params": controller.module.super_layers.parameters()},
                                            {"params": controller.module.fc_super.parameters()},
                                            {"params": controller.module.distill_aux_head1.parameters()},
                                            {"params": controller.module.distill_aux_head2.parameters()}],
                                            lr=config.w_lr, momentum=config.w_momentum, weight_decay=config.w_weight_decay)
            # arch_params optimizer
            alpha_optim = torch.optim.Adam(controller.module.arch_parameters(), config.alpha_lr, betas=(0.5, 0.999),
                                        weight_decay=config.alpha_weight_decay)

                                            
            if config.ensemble_param:
                optimizer = torch.optim.SGD([{"params": controller.module.feature_extractor.parameters()},
                                            {"params": controller.module.nas_layers.parameters()},
                                            {"params": controller.module.ensemble_param},
                                            {"params": controller.module.distill_aux_head1.parameters()},
                                            {"params": controller.module.distill_aux_head2.parameters()},
                                            {"params": controller.module.fc_nas.parameters()}],
                                            lr=config.nasnet_lr, momentum=config.w_momentum, weight_decay=config.w_weight_decay)
            else:
                optimizer = torch.optim.SGD([{"params": controller.module.feature_extractor.parameters()},
                                            {"params": controller.module.nas_layers.parameters()},
                                            {"params": controller.module.distill_aux_head1.parameters()},
                                            {"params": controller.module.distill_aux_head2.parameters()},
                                            {"params": controller.module.fc_nas.parameters()}],
                                            lr=config.nasnet_lr, momentum=config.w_momentum, weight_decay=config.w_weight_decay)

            lr_scheduler_retrain = torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer, config.search_iter_epochs, eta_min=config.w_lr_min)
            lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                w_optim, config.search_iter * config.search_iter_epochs, eta_min=config.w_lr_min)
    
            # warmup model main
            if config.local_rank == 0:
                logger.info("####### Sub model warmup #######")
            for warmup_epoch in range(config.nasnet_warmup):
                valid_sampler.set_epoch(warmup_epoch)
                retrain_warmup(valid_loader, controller, optimizer, warmup_epoch, writer, logger, False, config.nasnet_warmup, config)
            

            lr_search = lr_scheduler.get_lr()[0]
            lr_main = lr_scheduler_retrain.get_lr()[0]

            search_epoch = search_iter

            # reset iterators
            train_sampler.set_epoch(search_epoch)
            valid_sampler.set_epoch(search_epoch)

            # training
            search(train_loader, valid_loader, controller, optimizer, w_optim, alpha_optim, search_epoch, writer, logger, config)
 
            # sync params to super layer pool
            controller.module.copy_params_from_nas_layer()
            
            # nasbench201
            if config.local_rank == 0:
                logger.info('{}'.format(controller.module._arch_parameters))
                result = api.query_by_arch(controller.module.genotype())
                logger.info('{:}'.format(result))
                cifar10_train, cifar10_test, cifar100_train, cifar100_valid, \
                    cifar100_test, imagenet16_train, imagenet16_valid, imagenet16_test = utils.distill(result)

                writer.add_scalars('nasbench201/cifar10', {'train':cifar10_train,'test':cifar10_test}, search_epoch)
                writer.add_scalars('nasbench201/cifar100', {'train':cifar100_train,'valid':cifar100_valid, 'test':cifar100_test}, search_epoch)
                writer.add_scalars('nasbench201/imagenet16', {'train':imagenet16_train,'valid':imagenet16_valid, 'test':imagenet16_test}, search_epoch)

                
            #lr_scheduler.step()
            #lr_scheduler_retrain.step()
        torch.cuda.empty_cache()
Example #6
def main(xargs):
    assert torch.cuda.is_available(), 'CUDA is not available.'
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.set_num_threads(xargs.workers)
    prepare_seed(xargs.rand_seed)

    if os.path.isdir(xargs.save_dir):
        if click.confirm(
                '\nSave directory already exists in {}. Erase?'.format(
                    xargs.save_dir), default=False):
            os.system('rm -r ' + xargs.save_dir)
            assert not os.path.exists(xargs.save_dir)
            os.mkdir(xargs.save_dir)

    logger = prepare_logger(args)
    writer = SummaryWriter(xargs.save_dir)
    perturb_alpha = None
    if xargs.perturb:
        perturb_alpha = random_alpha

    train_data, valid_data, xshape, class_num = get_datasets(
        xargs.dataset, xargs.data_path, -1)
    # config_path = 'configs/nas-benchmark/algos/DARTS.config'
    config = load_config(xargs.config_path, {
        'class_num': class_num,
        'xshape': xshape
    }, logger)
    search_loader, _, valid_loader = get_nas_search_loaders(
        train_data, valid_data, xargs.dataset, 'configs/nas-benchmark/',
        config.batch_size, xargs.workers)
    logger.log(
        '||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'
        .format(xargs.dataset, len(search_loader), len(valid_loader),
                config.batch_size))
    logger.log('||||||| {:10s} ||||||| Config={:}'.format(
        xargs.dataset, config))

    search_space = get_search_spaces('cell', xargs.search_space_name)
    if xargs.model_config is None:
        model_config = dict2config(
            {
                'name': xargs.model,
                'C': xargs.channel,
                'N': xargs.num_cells,
                'max_nodes': xargs.max_nodes,
                'num_classes': class_num,
                'space': search_space,
                'affine': bool(xargs.affine),
                'track_running_stats': bool(xargs.track_running_stats)
            }, None)
    else:
        model_config = load_config(
            xargs.model_config, {
                'num_classes': class_num,
                'space': search_space,
                'affine': bool(xargs.affine),
                'track_running_stats': bool(xargs.track_running_stats)
            }, None)
    search_model = get_cell_based_tiny_net(model_config)
    # logger.log('search-model :\n{:}'.format(search_model))

    w_optimizer, w_scheduler, criterion = get_optim_scheduler(
        search_model.get_weights(), config, xargs.weight_learning_rate)
    a_optimizer = torch.optim.Adam(search_model.get_alphas(),
                                   lr=xargs.arch_learning_rate,
                                   betas=(0.5, 0.999),
                                   weight_decay=xargs.arch_weight_decay)
    logger.log('w-optimizer : {:}'.format(w_optimizer))
    logger.log('a-optimizer : {:}'.format(a_optimizer))
    logger.log('w-scheduler : {:}'.format(w_scheduler))
    logger.log('criterion   : {:}'.format(criterion))
    flop, param = get_model_infos(search_model, xshape)
    # logger.log('{:}'.format(search_model))
    logger.log('FLOP = {:.2f} M, Params = {:.2f} MB'.format(flop, param))
    if xargs.arch_nas_dataset is None:
        api = None
    else:
        api = API(xargs.arch_nas_dataset)
    logger.log('{:} create API = {:} done'.format(time_string(), api))

    last_info, model_base_path, model_best_path = logger.path(
        'info'), logger.path('model'), logger.path('best')
    network, criterion = torch.nn.DataParallel(
        search_model).cuda(), criterion.cuda()

    if last_info.exists():  # automatically resume from previous checkpoint
        logger.log("=> loading checkpoint of the last-info '{:}' start".format(
            last_info))
        last_info = torch.load(last_info)
        start_epoch = last_info['epoch']
        checkpoint = torch.load(last_info['last_checkpoint'])
        genotypes = checkpoint['genotypes']
        valid_accuracies = checkpoint['valid_accuracies']
        search_model.load_state_dict(checkpoint['search_model'])
        w_scheduler.load_state_dict(checkpoint['w_scheduler'])
        w_optimizer.load_state_dict(checkpoint['w_optimizer'])
        a_optimizer.load_state_dict(checkpoint['a_optimizer'])
        logger.log(
            "=> loading checkpoint of the last-info '{:}' start with {:}-th epoch."
            .format(last_info, start_epoch))
    else:
        logger.log("=> do not find the last-info file : {:}".format(last_info))
        start_epoch, valid_accuracies, genotypes = 0, {
            'best': -1
        }, {
            -1: search_model.genotype()
        }

    # start training
    # start_time, search_time, epoch_time, total_epoch = time.time(), AverageMeter(), AverageMeter(), config.epochs + config.warmup
    start_time, search_time, epoch_time = time.time(), AverageMeter(
    ), AverageMeter()
    total_epoch = config.epochs + config.warmup
    assert 0 < xargs.early_stop_epoch <= total_epoch - 1
    for epoch in range(start_epoch, total_epoch):
        if epoch >= xargs.early_stop_epoch:
            logger.log(f"Early stop @ {epoch} epoch.")
            break
        if xargs.perturb:
            epsilon_alpha = 0.03 + (xargs.epsilon_alpha -
                                    0.03) * epoch / total_epoch
            logger.log(f'epoch {epoch} epsilon_alpha {epsilon_alpha}')
        else:
            epsilon_alpha = None

        w_scheduler.update(epoch, 0.0)
        need_time = 'Time Left: {:}'.format(
            convert_secs2time(epoch_time.val * (total_epoch - epoch), True))
        epoch_str = '{:03d}-{:03d}'.format(epoch, total_epoch)
        logger.log('\n[Search the {:}-th epoch] {:}, LR={:}'.format(
            epoch_str, need_time, min(w_scheduler.get_lr())))

        search_w_loss, search_w_top1, search_w_top5, search_a_loss, search_a_top1, search_a_top5 = search_func(
            search_loader, network, criterion, w_scheduler, w_optimizer,
            a_optimizer, epoch_str, xargs.print_freq, logger,
            xargs.gradient_clip, perturb_alpha, epsilon_alpha)
        search_time.update(time.time() - start_time)
        logger.log(
            '[{:}] searching : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s'
            .format(epoch_str, search_w_loss, search_w_top1, search_w_top5,
                    search_time.sum))
        valid_a_loss, valid_a_top1, valid_a_top5 = valid_func(
            valid_loader, network, criterion)

        writer.add_scalar('search/weight_loss', search_w_loss, epoch)
        writer.add_scalar('search/weight_top1_acc', search_w_top1, epoch)
        writer.add_scalar('search/weight_top5_acc', search_w_top5, epoch)

        writer.add_scalar('search/arch_loss', search_a_loss, epoch)
        writer.add_scalar('search/arch_top1_acc', search_a_top1, epoch)
        writer.add_scalar('search/arch_top5_acc', search_a_top5, epoch)

        writer.add_scalar('evaluate/loss', valid_a_loss, epoch)
        writer.add_scalar('evaluate/top1_acc', valid_a_top1, epoch)
        writer.add_scalar('evaluate/top5_acc', valid_a_top5, epoch)
        logger.log(
            '[{:}] evaluate  : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'
            .format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5))
        writer.add_scalar('entropy', search_model.entropy, epoch)
        per_edge_dict = get_per_egde_value_dict(search_model.arch_parameters)
        for edge_name, edge_val in per_edge_dict.items():
            writer.add_scalars(f"cell/{edge_name}", edge_val, epoch)
        # check the best accuracy
        valid_accuracies[epoch] = valid_a_top1
        if valid_a_top1 > valid_accuracies['best']:
            valid_accuracies['best'] = valid_a_top1
            genotypes['best'] = search_model.genotype()
            find_best = True
        else:
            find_best = False

        genotypes[epoch] = search_model.genotype()
        logger.log('<<<--->>> The {:}-th epoch : {:}'.format(
            epoch_str, genotypes[epoch]))
        # save checkpoint
        save_path = save_checkpoint(
            {
                'epoch': epoch + 1,
                'args': deepcopy(xargs),
                'search_model': search_model.state_dict(),
                'w_optimizer': w_optimizer.state_dict(),
                'a_optimizer': a_optimizer.state_dict(),
                'w_scheduler': w_scheduler.state_dict(),
                'genotypes': genotypes,
                'valid_accuracies': valid_accuracies
            }, model_base_path, logger)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'args': deepcopy(args),
                'last_checkpoint': save_path,
            }, logger.path('info'), logger)

        if xargs.snapshoot > 0 and epoch % xargs.snapshoot == 0:
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'args': deepcopy(args),
                    'search_model': search_model.state_dict(),
                },
                os.path.join(str(logger.model_dir),
                             f"checkpoint_epoch{epoch}.pth"), logger)

        if find_best:
            logger.log(
                '<<<--->>> The {:}-th epoch : find the highest validation accuracy : {:.2f}%.'
                .format(epoch_str, valid_a_top1))
            copy_checkpoint(model_base_path, model_best_path, logger)
        with torch.no_grad():
            logger.log('{:}'.format(search_model.show_alphas()))
        if api is not None:
            logger.log('{:}'.format(api.query_by_arch(genotypes[epoch])))
            index = api.query_index_by_arch(genotypes[epoch])
            info = api.query_meta_info_by_index(
                index)  # This is an instance of `ArchResults`
            res_metrics = info.get_metrics(
                f'{xargs.dataset}',
                'ori-test')  # This is a dict with metric names as keys
            # cost_metrics = info.get_comput_costs('cifar10')
            writer.add_scalar(f'{xargs.dataset}_ground_acc_ori-test',
                              res_metrics['accuracy'], epoch)
            writer.add_scalar(f'{xargs.dataset}_search_acc', valid_a_top1,
                              epoch)
            if xargs.dataset.lower() != 'cifar10':
                writer.add_scalar(
                    f'{xargs.dataset}_ground_acc_x-test',
                    info.get_metrics(f'{xargs.dataset}', 'x-test')['accuracy'],
                    epoch)
            if find_best:
                valid_accuracies['best_gt'] = res_metrics['accuracy']
            writer.add_scalar(f"{xargs.dataset}_cur_best_gt_acc_ori-test",
                              valid_accuracies['best_gt'], epoch)

        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()

    logger.log('\n' + '-' * 100)
    logger.log('{:} : run {:} epochs, cost {:.1f} s, last-geno is {:}.'.format(
        args.model, xargs.early_stop_epoch, search_time.sum,
        genotypes[xargs.early_stop_epoch - 1]))
    if api is not None:
        logger.log('{:}'.format(
            api.query_by_arch(genotypes[xargs.early_stop_epoch - 1])))
    logger.close()
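A standalone sketch of the api.query_by_arch pattern used at the end of this example; the snapshot file name is a placeholder for a local copy, and query_by_arch accepts either a genotype or, as far as the NAS-Bench-201 API documents, the raw architecture string:

# Standalone query sketch (path is a placeholder for a local snapshot).
from nas_201_api import NASBench201API as API
api = API('NAS-Bench-201-v1_0-e61699.pth')
arch = '|nor_conv_3x3~0|+|nor_conv_3x3~0|skip_connect~1|+|skip_connect~0|none~1|nor_conv_1x1~2|'
print(api.query_by_arch(arch))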
Example #7
                        type=int,
                        nargs='?',
                        help='number of iterations for optimization method')
    # log
    parser.add_argument('--save_dir',
                        type=str,
                        default='./output/search',
                        help='Folder to save checkpoints and log.')
    parser.add_argument('--rand_seed',
                        type=int,
                        default=-1,
                        help='manual seed')
    args = parser.parse_args()

    if args.search_space == 'tss':
        api = NASBench201API(verbose=False)
    elif args.search_space == 'sss':
        api = NASBench301API(verbose=False)
    else:
        raise ValueError('Invalid search space : {:}'.format(
            args.search_space))

    args.save_dir = os.path.join(
        '{:}-{:}'.format(args.save_dir, args.search_space), args.dataset,
        'BOHB')
    print('save-dir : {:}'.format(args.save_dir))

    if args.rand_seed < 0:
        save_dir, all_info = None, collections.OrderedDict()
        for i in range(args.loops_if_rand):
            print('{:} : {:03d}/{:03d}'.format(time_string(), i,
Example #8
def visualize_rank_over_time(meta_file, vis_save_dir):
    print('\n' + '-' * 150)
    vis_save_dir.mkdir(parents=True, exist_ok=True)
    print('{:} start to visualize rank-over-time into {:}'.format(
        time_string(), vis_save_dir))
    cache_file_path = vis_save_dir / 'rank-over-time-cache-info.pth'
    if not cache_file_path.exists():
        print('Do not find cache file : {:}'.format(cache_file_path))
        nas_bench = API(str(meta_file))
        print('{:} load nas_bench done'.format(time_string()))
        params, flops, train_accs, valid_accs, test_accs, otest_accs = [], [], defaultdict(
            list), defaultdict(list), defaultdict(list), defaultdict(list)
        #for iepoch in range(200): for index in range( len(nas_bench) ):
        for index in tqdm(range(len(nas_bench))):
            info = nas_bench.query_by_index(index, use_12epochs_result=False)
            for iepoch in range(200):
                res = info.get_metrics('cifar10', 'train', iepoch)
                train_acc = res['accuracy']
                res = info.get_metrics('cifar10-valid', 'x-valid', iepoch)
                valid_acc = res['accuracy']
                res = info.get_metrics('cifar10', 'ori-test', iepoch)
                test_acc = res['accuracy']
                res = info.get_metrics('cifar10', 'ori-test', iepoch)
                otest_acc = res['accuracy']
                train_accs[iepoch].append(train_acc)
                valid_accs[iepoch].append(valid_acc)
                test_accs[iepoch].append(test_acc)
                otest_accs[iepoch].append(otest_acc)
                if iepoch == 0:
                    res = info.get_comput_costs('cifar10')
                    flop, param = res['flops'], res['params']
                    flops.append(flop)
                    params.append(param)
        info = {
            'params': params,
            'flops': flops,
            'train_accs': train_accs,
            'valid_accs': valid_accs,
            'test_accs': test_accs,
            'otest_accs': otest_accs
        }
        torch.save(info, cache_file_path)
    else:
        print('Find cache file : {:}'.format(cache_file_path))
        info = torch.load(cache_file_path)
        params, flops, train_accs, valid_accs, test_accs, otest_accs = info[
            'params'], info['flops'], info['train_accs'], info[
                'valid_accs'], info['test_accs'], info['otest_accs']
    print('{:} collect data done.'.format(time_string()))
    #selected_epochs = [0, 100, 150, 180, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199]
    selected_epochs = list(range(200))
    x_xtests = test_accs[199]
    indexes = list(range(len(x_xtests)))
    ord_idxs = sorted(indexes, key=lambda i: x_xtests[i])
    for sepoch in selected_epochs:
        x_valids = valid_accs[sepoch]
        valid_ord_idxs = sorted(indexes, key=lambda i: x_valids[i])
        valid_ord_lbls = []
        for idx in ord_idxs:
            valid_ord_lbls.append(valid_ord_idxs.index(idx))
        # labeled data
        dpi, width, height = 300, 2600, 2600
        figsize = width / float(dpi), height / float(dpi)
        LabelSize, LegendFontsize = 18, 18

        fig = plt.figure(figsize=figsize)
        ax = fig.add_subplot(111)
        plt.xlim(min(indexes), max(indexes))
        plt.ylim(min(indexes), max(indexes))
        plt.yticks(np.arange(min(indexes), max(indexes),
                             max(indexes) // 6),
                   fontsize=LegendFontsize,
                   rotation='vertical')
        plt.xticks(np.arange(min(indexes), max(indexes),
                             max(indexes) // 6),
                   fontsize=LegendFontsize)
        ax.scatter(indexes,
                   valid_ord_lbls,
                   marker='^',
                   s=0.5,
                   c='tab:green',
                   alpha=0.8)
        ax.scatter(indexes,
                   indexes,
                   marker='o',
                   s=0.5,
                   c='tab:blue',
                   alpha=0.8)
        ax.scatter([-1], [-1],
                   marker='^',
                   s=100,
                   c='tab:green',
                   label='CIFAR-10 validation')
        ax.scatter([-1], [-1],
                   marker='o',
                   s=100,
                   c='tab:blue',
                   label='CIFAR-10 test')
        plt.grid(zorder=0)
        ax.set_axisbelow(True)
        plt.legend(loc='upper left', fontsize=LegendFontsize)
        ax.set_xlabel('architecture ranking in the final test accuracy',
                      fontsize=LabelSize)
        ax.set_ylabel('architecture ranking in the validation set',
                      fontsize=LabelSize)
        save_path = (vis_save_dir / 'time-{:03d}.pdf'.format(sepoch)).resolve()
        fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='pdf')
        save_path = (vis_save_dir / 'time-{:03d}.png'.format(sepoch)).resolve()
        fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png')
        print('{:} save into {:}'.format(time_string(), save_path))
        plt.close('all')
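A hedged, isolated sketch of the per-epoch metric lookup this function relies on; the architecture index and epoch are arbitrary examples, and the snapshot path is a placeholder:

# Isolated sketch of the lookup used in the loop above.
from nas_201_api import NASBench201API as API
nas_bench = API('NAS-Bench-201-v1_0-e61699.pth')  # placeholder path
info = nas_bench.query_by_index(0, use_12epochs_result=False)
print(info.get_metrics('cifar10-valid', 'x-valid', 10)['accuracy'])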
Example #9
def main(xargs):
    assert torch.cuda.is_available(), 'CUDA is not available.'
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.set_num_threads(xargs.workers)
    prepare_seed(xargs.rand_seed)
    logger = prepare_logger(args)

    train_data, valid_data, xshape, class_num = get_datasets(
        xargs.dataset, xargs.data_path, -1)
    if xargs.overwite_epochs is None:
        extra_info = {'class_num': class_num, 'xshape': xshape}
    else:
        extra_info = {
            'class_num': class_num,
            'xshape': xshape,
            'epochs': xargs.overwite_epochs
        }
    config = load_config(xargs.config_path, extra_info, logger)
    search_loader, train_loader, valid_loader = get_nas_search_loaders(
        train_data, valid_data, xargs.dataset, 'configs/nas-benchmark/',
        (config.batch_size, config.test_batch_size), xargs.workers)
    logger.log(
        '||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'
        .format(xargs.dataset, len(search_loader), len(valid_loader),
                config.batch_size))
    logger.log('||||||| {:10s} ||||||| Config={:}'.format(
        xargs.dataset, config))

    search_space = get_search_spaces(xargs.search_space, 'nas-bench-301')

    model_config = dict2config(
        dict(name='generic',
             C=xargs.channel,
             N=xargs.num_cells,
             max_nodes=xargs.max_nodes,
             num_classes=class_num,
             space=search_space,
             affine=bool(xargs.affine),
             track_running_stats=bool(xargs.track_running_stats)), None)
    logger.log('search space : {:}'.format(search_space))
    logger.log('model config : {:}'.format(model_config))
    search_model = get_cell_based_tiny_net(model_config)
    search_model.set_algo(xargs.algo)
    logger.log('{:}'.format(search_model))

    w_optimizer, w_scheduler, criterion = get_optim_scheduler(
        search_model.weights, config)
    a_optimizer = torch.optim.Adam(search_model.alphas,
                                   lr=xargs.arch_learning_rate,
                                   betas=(0.5, 0.999),
                                   weight_decay=xargs.arch_weight_decay,
                                   eps=xargs.arch_eps)
    logger.log('w-optimizer : {:}'.format(w_optimizer))
    logger.log('a-optimizer : {:}'.format(a_optimizer))
    logger.log('w-scheduler : {:}'.format(w_scheduler))
    logger.log('criterion   : {:}'.format(criterion))
    params = count_parameters_in_MB(search_model)
    logger.log('The parameters of the search model = {:.2f} MB'.format(params))
    logger.log('search-space : {:}'.format(search_space))
    if bool(xargs.use_api):
        api = API(verbose=False)
    else:
        api = None
    logger.log('{:} create API = {:} done'.format(time_string(), api))

    last_info, model_base_path, model_best_path = logger.path(
        'info'), logger.path('model'), logger.path('best')
    network, criterion = search_model.cuda(), criterion.cuda(
    )  # use a single GPU


    if last_info.exists():  # automatically resume from previous checkpoint
        logger.log("=> loading checkpoint of the last-info '{:}' start".format(
            last_info))
        last_info = torch.load(last_info)
        start_epoch = last_info['epoch']
        checkpoint = torch.load(last_info['last_checkpoint'])
        genotypes = checkpoint['genotypes']
        baseline = checkpoint['baseline']
        valid_accuracies = checkpoint['valid_accuracies']
        search_model.load_state_dict(checkpoint['search_model'])
        w_scheduler.load_state_dict(checkpoint['w_scheduler'])
        w_optimizer.load_state_dict(checkpoint['w_optimizer'])
        a_optimizer.load_state_dict(checkpoint['a_optimizer'])
        logger.log(
            "=> loading checkpoint of the last-info '{:}' start with {:}-th epoch."
            .format(last_info, start_epoch))
    else:
        logger.log("=> do not find the last-info file : {:}".format(last_info))
        start_epoch, valid_accuracies, genotypes = 0, {
            'best': -1
        }, {
            -1: network.return_topK(1, True)[0]
        }
        baseline = None

    # start training
    start_time, search_time, epoch_time, total_epoch = time.time(
    ), AverageMeter(), AverageMeter(), config.epochs + config.warmup
    for epoch in range(start_epoch, total_epoch):
        w_scheduler.update(epoch, 0.0)
        need_time = 'Time Left: {:}'.format(
            convert_secs2time(epoch_time.val * (total_epoch - epoch), True))
        epoch_str = '{:03d}-{:03d}'.format(epoch, total_epoch)
        logger.log('\n[Search the {:}-th epoch] {:}, LR={:}'.format(
            epoch_str, need_time, min(w_scheduler.get_lr())))

        network.set_drop_path(
            float(epoch + 1) / total_epoch, xargs.drop_path_rate)
        if xargs.algo == 'gdas':
            network.set_tau(xargs.tau_max -
                            (xargs.tau_max - xargs.tau_min) * epoch /
                            (total_epoch - 1))
            logger.log('[RESET tau as : {:} and drop_path as {:}]'.format(
                network.tau, network.drop_path))
        search_w_loss, search_w_top1, search_w_top5, search_a_loss, search_a_top1, search_a_top5 \
                    = search_func(search_loader, network, criterion, w_scheduler, w_optimizer, a_optimizer, epoch_str, xargs.print_freq, xargs.algo, logger)
        search_time.update(time.time() - start_time)
        logger.log(
            '[{:}] search [base] : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s'
            .format(epoch_str, search_w_loss, search_w_top1, search_w_top5,
                    search_time.sum))
        logger.log(
            '[{:}] search [arch] : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'
            .format(epoch_str, search_a_loss, search_a_top1, search_a_top5))
        if xargs.algo == 'enas':
            ctl_loss, ctl_acc, baseline, ctl_reward \
                                       = train_controller(valid_loader, network, criterion, a_optimizer, baseline, epoch_str, xargs.print_freq, logger)
            logger.log(
                '[{:}] controller : loss={:}, acc={:}, baseline={:}, reward={:}'
                .format(epoch_str, ctl_loss, ctl_acc, baseline, ctl_reward))

        genotype, temp_accuracy = get_best_arch(valid_loader, network,
                                                xargs.eval_candidate_num,
                                                xargs.algo)
        if xargs.algo == 'setn' or xargs.algo == 'enas':
            network.set_cal_mode('dynamic', genotype)
        elif xargs.algo == 'gdas':
            network.set_cal_mode('gdas', None)
        elif xargs.algo.startswith('darts'):
            network.set_cal_mode('joint', None)
        elif xargs.algo == 'random':
            network.set_cal_mode('urs', None)
        else:
            raise ValueError('Invalid algorithm name : {:}'.format(xargs.algo))
        logger.log('[{:}] - [get_best_arch] : {:} -> {:}'.format(
            epoch_str, genotype, temp_accuracy))
        valid_a_loss, valid_a_top1, valid_a_top5 = valid_func(
            valid_loader, network, criterion, xargs.algo, logger)
        logger.log(
            '[{:}] evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}% | {:}'
            .format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5,
                    genotype))
        valid_accuracies[epoch] = valid_a_top1

        genotypes[epoch] = genotype
        logger.log('<<<--->>> The {:}-th epoch : {:}'.format(
            epoch_str, genotypes[epoch]))
        # save checkpoint
        save_path = save_checkpoint(
            {
                'epoch': epoch + 1,
                'args': deepcopy(xargs),
                'baseline': baseline,
                'search_model': search_model.state_dict(),
                'w_optimizer': w_optimizer.state_dict(),
                'a_optimizer': a_optimizer.state_dict(),
                'w_scheduler': w_scheduler.state_dict(),
                'genotypes': genotypes,
                'valid_accuracies': valid_accuracies
            }, model_base_path, logger)
        last_info = save_checkpoint(
            {
                'epoch': epoch + 1,
                'args': deepcopy(args),
                'last_checkpoint': save_path,
            }, logger.path('info'), logger)
        with torch.no_grad():
            logger.log('{:}'.format(search_model.show_alphas()))
        if api is not None:
            logger.log('{:}'.format(api.query_by_arch(genotypes[epoch],
                                                      '200')))
        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()

    # the final post procedure : count the time
    start_time = time.time()
    genotype, temp_accuracy = get_best_arch(valid_loader, network,
                                            xargs.eval_candidate_num,
                                            xargs.algo)
    if xargs.algo == 'setn' or xargs.algo == 'enas':
        network.set_cal_mode('dynamic', genotype)
    elif xargs.algo == 'gdas':
        network.set_cal_mode('gdas', None)
    elif xargs.algo.startswith('darts'):
        network.set_cal_mode('joint', None)
    elif xargs.algo == 'random':
        network.set_cal_mode('urs', None)
    else:
        raise ValueError('Invalid algorithm name : {:}'.format(xargs.algo))
    search_time.update(time.time() - start_time)

    valid_a_loss, valid_a_top1, valid_a_top5 = valid_func(
        valid_loader, network, criterion, xargs.algo, logger)
    logger.log(
        'Last : the genotype is : {:}, with the validation accuracy of {:.3f}%.'
        .format(genotype, valid_a_top1))

    logger.log('\n' + '-' * 100)
    # check the performance from the architecture dataset
    logger.log('[{:}] run {:} epochs, cost {:.1f} s, last-geno is {:}.'.format(
        xargs.algo, total_epoch, search_time.sum, genotype))
    if api is not None:
        logger.log('{:}'.format(api.query_by_arch(genotype, '200')))
    logger.close()
Example #10
import os
import sys
from pathlib import Path
from collections import defaultdict, OrderedDict
from typing import Dict, Any, Text, List
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
from log_utils    import AverageMeter, time_string, convert_secs2time
from config_utils import load_config, dict2config
from datasets     import get_datasets
from models       import CellStructure, get_cell_based_tiny_net, get_search_spaces
from nats_bench   import pickle_save, pickle_load, ArchResults, ResultsCount
from procedures   import bench_pure_evaluate as pure_evaluate, get_nas_bench_loaders
from utils        import get_md5_file
from nas_201_api  import NASBench201API


api = NASBench201API('{:}/.torch/NAS-Bench-201-v1_0-e61699.pth'.format(os.environ['HOME']))

NATS_TSS_BASE_NAME = 'NATS-tss-v1_0'  # 2020.08.28


def create_result_count(used_seed: int, dataset: Text, arch_config: Dict[Text, Any],
                        results: Dict[Text, Any], dataloader_dict: Dict[Text, Any]) -> ResultsCount:
  xresult = ResultsCount(dataset, results['net_state_dict'], results['train_acc1es'], results['train_losses'],
                         results['param'], results['flop'], arch_config, used_seed, results['total_epoch'], None)
  net_config = dict2config({'name': 'infer.tiny', 'C': arch_config['channel'], 'N': arch_config['num_cells'], 'genotype': CellStructure.str2structure(arch_config['arch_str']), 'num_classes': arch_config['class_num']}, None)
  if 'train_times' in results: # new version
    xresult.update_train_info(results['train_acc1es'], results['train_acc5es'], results['train_losses'], results['train_times'])
    xresult.update_eval(results['valid_acc1es'], results['valid_losses'], results['valid_times'])
  else:
    network = get_cell_based_tiny_net(net_config)
    network.load_state_dict(xresult.get_net_param())
Example #11
def main(xargs):
    PID = os.getpid()
    assert torch.cuda.is_available(), 'CUDA is not available.'
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    prepare_seed(xargs.rand_seed)

    if xargs.timestamp == 'none':
        xargs.timestamp = "{:}".format(
            time.strftime('%h-%d-%C_%H-%M-%s', time.gmtime(time.time())))

    train_data, valid_data, xshape, class_num = get_datasets(xargs, -1)

    ##### config & logging #####
    config = edict()
    config.class_num = class_num
    config.xshape = xshape
    config.batch_size = xargs.batch_size
    xargs.save_dir = xargs.save_dir + \
        "/repeat%d-prunNum%d-prec%d-%s-batch%d"%(
                xargs.repeat, xargs.prune_number, xargs.precision, xargs.init, config["batch_size"]) + \
        "/{:}/seed{:}".format(xargs.timestamp, xargs.rand_seed)
    config.save_dir = xargs.save_dir
    logger = prepare_logger(xargs)
    ###############

    if xargs.dataset in [
            'MiniImageNet', 'MetaMiniImageNet', 'TieredImageNet',
            'MetaTieredImageNet'
    ]:
        train_loader = torch.utils.data.DataLoader(train_data,
                                                   batch_size=xargs.batch_size,
                                                   shuffle=True,
                                                   num_workers=args.workers,
                                                   pin_memory=True)
    elif xargs.dataset != 'imagenet-1k':
        search_loader, train_loader, valid_loader = get_nas_search_loaders(
            train_data, valid_data, xargs.dataset, 'configs/',
            config.batch_size, xargs.workers)
    else:
        train_loader = torch.utils.data.DataLoader(train_data,
                                                   batch_size=xargs.batch_size,
                                                   shuffle=True,
                                                   num_workers=args.workers,
                                                   pin_memory=True)
    logger.log(
        '||||||| {:10s} ||||||| Train-Loader-Num={:}, batch size={:}'.format(
            xargs.dataset, len(train_loader), config.batch_size))
    logger.log('||||||| {:10s} ||||||| Config={:}'.format(
        xargs.dataset, config))

    search_space = get_search_spaces('cell', xargs.search_space_name)
    if xargs.search_space_name == 'nas-bench-201':
        model_config = edict({
            'name': 'DARTS-V1',
            'C': 3,
            'N': 1,
            'depth': -1,
            'use_stem': True,
            'max_nodes': xargs.max_nodes,
            'num_classes': class_num,
            'space': search_space,
            'affine': True,
            'track_running_stats': bool(xargs.track_running_stats),
        })
        model_config_thin = edict({
            'name': 'DARTS-V1',
            'C': 1,
            'N': 1,
            'depth': 1,
            'use_stem': False,
            'max_nodes': xargs.max_nodes,
            'num_classes': class_num,
            'space': search_space,
            'affine': True,
            'track_running_stats': bool(xargs.track_running_stats),
        })
    elif xargs.search_space_name in ['darts', 'darts_fewshot']:
        model_config = edict({
            'name': 'DARTS-V1',
            'C': 1,
            'N': 1,
            'depth': 2,
            'use_stem': True,
            'stem_multiplier': 1,
            'num_classes': class_num,
            'space': search_space,
            'affine': True,
            'track_running_stats': bool(xargs.track_running_stats),
            'super_type': xargs.super_type,
            'steps': xargs.max_nodes,
            'multiplier': xargs.max_nodes,
        })
        model_config_thin = edict({
            'name': 'DARTS-V1',
            'C': 1,
            'N': 1,
            'depth': 2,
            'use_stem': False,
            'stem_multiplier': 1,
            'max_nodes': xargs.max_nodes,
            'num_classes': class_num,
            'space': search_space,
            'affine': True,
            'track_running_stats': bool(xargs.track_running_stats),
            'super_type': xargs.super_type,
            'steps': xargs.max_nodes,
            'multiplier': xargs.max_nodes,
        })
    network = get_cell_based_tiny_net(model_config)
    logger.log('model-config : {:}'.format(model_config))
    arch_parameters = [
        alpha.detach().clone() for alpha in network.get_alphas()
    ]
    for alpha in arch_parameters:
        alpha[:, :] = 0

    # TODO Linear_Region_Collector
    lrc_model = Linear_Region_Collector(xargs,
                                        input_size=(1000, 1, 3, 3),
                                        sample_batch=3,
                                        dataset=xargs.dataset,
                                        data_path=xargs.data_path,
                                        seed=xargs.rand_seed)

    # ### all params trainable (except train_bn) #########################
    flop, param = get_model_infos(network, xshape)
    logger.log('FLOP = {:.2f} M, Params = {:.2f} MB'.format(flop, param))
    logger.log('search-space [{:} ops] : {:}'.format(len(search_space),
                                                     search_space))
    if xargs.arch_nas_dataset is None or xargs.search_space_name in [
            'darts', 'darts_fewshot'
    ]:
        api = None
    else:
        api = API(xargs.arch_nas_dataset)
    logger.log('{:} create API = {:} done'.format(time_string(), api))

    network = network.cuda()

    genotypes = {}
    genotypes['arch'] = {
        -1: network.genotype()
    }

    arch_parameters_history = []
    arch_parameters_history_npy = []
    start_time = time.time()
    epoch = -1

    for alpha in arch_parameters:
        alpha[:, 0] = -INF
    arch_parameters_history.append(
        [alpha.detach().clone() for alpha in arch_parameters])
    arch_parameters_history_npy.append(
        [alpha.detach().clone().cpu().numpy() for alpha in arch_parameters])
    np.save(os.path.join(xargs.save_dir, "arch_parameters_history.npy"),
            arch_parameters_history_npy)
    while not is_single_path(network):
        epoch += 1
        torch.cuda.empty_cache()
        print("<< ============== JOB (PID = %d) %s ============== >>" %
              (PID, '/'.join(xargs.save_dir.split("/")[-6:])))

        arch_parameters, op_pruned = prune_func_rank(
            xargs,
            arch_parameters,
            model_config,
            model_config_thin,
            train_loader,
            lrc_model,
            search_space,
            precision=xargs.precision,
            prune_number=xargs.prune_number)
        # rebuild supernet
        network = get_cell_based_tiny_net(model_config)
        network = network.cuda()
        network.set_alphas(arch_parameters)

        arch_parameters_history.append(
            [alpha.detach().clone() for alpha in arch_parameters])
        arch_parameters_history_npy.append([
            alpha.detach().clone().cpu().numpy() for alpha in arch_parameters
        ])
        np.save(os.path.join(xargs.save_dir, "arch_parameters_history.npy"),
                arch_parameters_history_npy)
        genotypes['arch'][epoch] = network.genotype()

        logger.log('operators remaining (1s) and pruned (0s)\n{:}'.format(
            '\n'.join([
                str((alpha > -INF).int()) for alpha in network.get_alphas()
            ])))

    if xargs.search_space_name in ['darts', 'darts_fewshot']:
        print("===>>> Prune Edge Groups...")
        if xargs.max_nodes == 4:
            edge_groups = [(0, 2), (2, 5), (5, 9), (9, 14)]
        elif xargs.max_nodes == 3:
            edge_groups = [(0, 2), (2, 5), (5, 9)]
        arch_parameters = prune_func_rank_group(
            xargs,
            arch_parameters,
            model_config,
            model_config_thin,
            train_loader,
            lrc_model,
            search_space,
            edge_groups=edge_groups,
            num_per_group=2,
            precision=xargs.precision,
        )
        network = get_cell_based_tiny_net(model_config)
        network = network.cuda()
        network.set_alphas(arch_parameters)
        arch_parameters_history.append(
            [alpha.detach().clone() for alpha in arch_parameters])
        arch_parameters_history_npy.append([
            alpha.detach().clone().cpu().numpy() for alpha in arch_parameters
        ])
        np.save(os.path.join(xargs.save_dir, "arch_parameters_history.npy"),
                arch_parameters_history_npy)

    logger.log('<<<--->>> End: {:}'.format(network.genotype()))
    logger.log('operators remaining (1s) and pruned (0s)\n{:}'.format(
        '\n'.join(
            [str((alpha > -INF).int()) for alpha in network.get_alphas()])))

    end_time = time.time()
    logger.log('\n' + '-' * 100)
    logger.log("Time spent: %d s" % (end_time - start_time))
    # check the performance from the architecture dataset
    if api is not None:
        logger.log('{:}'.format(api.query_by_arch(genotypes['arch'][epoch])))

    logger.close()
Example #12
def main(xargs, myargs):
    assert torch.cuda.is_available(), 'CUDA is not available.'
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.set_num_threads(xargs.workers)
    prepare_seed(xargs.rand_seed)
    logger = prepare_logger(xargs)

    train_data, valid_data, xshape, class_num = get_datasets(
        xargs.dataset, xargs.data_path, -1)
    config = load_config(xargs.config_path, {
        'class_num': class_num,
        'xshape': xshape
    }, logger)
    search_loader, _, valid_loader = get_nas_search_loaders(
        train_data, valid_data, xargs.dataset,
        'AutoDL-Projects/configs/nas-benchmark/',
        (config.batch_size, config.test_batch_size), xargs.num_worker)
    logger.log(
        '||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'
        .format(xargs.dataset, len(search_loader), len(valid_loader),
                config.batch_size))
    logger.log('||||||| {:10s} ||||||| Config={:}'.format(
        xargs.dataset, config))

    search_space = get_search_spaces('cell', xargs.search_space_name)
    if not hasattr(xargs, 'model_config') or xargs.model_config is None:
        model_config = dict2config(
            dict(name='SETN',
                 C=xargs.channel,
                 N=xargs.num_cells,
                 max_nodes=xargs.max_nodes,
                 num_classes=class_num,
                 space=search_space,
                 affine=False,
                 track_running_stats=bool(xargs.track_running_stats)), None)
    else:
        model_config = load_config(
            xargs.model_config,
            dict(num_classes=class_num,
                 space=search_space,
                 affine=False,
                 track_running_stats=bool(xargs.track_running_stats)), None)
    logger.log('search space : {:}'.format(search_space))
    search_model = get_cell_based_tiny_net(model_config)

    w_optimizer, w_scheduler, criterion = get_optim_scheduler(
        search_model.get_weights(), config)
    a_optimizer = torch.optim.Adam(search_model.get_alphas(),
                                   lr=xargs.arch_learning_rate,
                                   betas=(0.5, 0.999),
                                   weight_decay=xargs.arch_weight_decay)
    logger.log('w-optimizer : {:}'.format(w_optimizer))
    logger.log('a-optimizer : {:}'.format(a_optimizer))
    logger.log('w-scheduler : {:}'.format(w_scheduler))
    logger.log('criterion   : {:}'.format(criterion))
    flop, param = get_model_infos(search_model, xshape)
    logger.log('FLOP = {:.2f} M, Params = {:.2f} MB'.format(flop, param))
    logger.log('search-space : {:}'.format(search_space))
    if xargs.arch_nas_dataset is None:
        api = None
    else:
        api = API(xargs.arch_nas_dataset)
    logger.log('{:} create API = {:} done'.format(time_string(), api))

    last_info, model_base_path, model_best_path = logger.path(
        'info'), logger.path('model'), logger.path('best')
    network, criterion = torch.nn.DataParallel(
        search_model).cuda(), criterion.cuda()

    if last_info.exists():  # automatically resume from previous checkpoint
        logger.log("=> loading checkpoint of the last-info '{:}' start".format(
            last_info))
        last_info = torch.load(last_info)
        start_epoch = last_info['epoch']
        checkpoint = torch.load(last_info['last_checkpoint'])
        genotypes = checkpoint['genotypes']
        valid_accuracies = checkpoint['valid_accuracies']
        search_model.load_state_dict(checkpoint['search_model'])
        w_scheduler.load_state_dict(checkpoint['w_scheduler'])
        w_optimizer.load_state_dict(checkpoint['w_optimizer'])
        a_optimizer.load_state_dict(checkpoint['a_optimizer'])
        logger.log(
            "=> loading checkpoint of the last-info '{:}' start with {:}-th epoch."
            .format(last_info, start_epoch))
    else:
        logger.log("=> do not find the last-info file : {:}".format(last_info))
        init_genotype, _ = get_best_arch(valid_loader, network,
                                         xargs.select_num)
        start_epoch, valid_accuracies, genotypes = 0, {
            'best': -1
        }, {
            -1: init_genotype
        }

    # start training
    start_time, search_time, epoch_time = time.time(), AverageMeter(), AverageMeter()
    total_epoch = config.epochs + config.warmup
    for epoch in range(start_epoch, total_epoch):
        w_scheduler.update(epoch, 0.0)
        need_time = 'Time Left: {:}'.format(
            convert_secs2time(epoch_time.val * (total_epoch - epoch), True))
        epoch_str = '{:03d}-{:03d}'.format(epoch, total_epoch)
        logger.log('\n[Search the {:}-th epoch] {:}, LR={:}'.format(
            epoch_str, need_time, min(w_scheduler.get_lr())))

        search_w_loss, search_w_top1, search_w_top5, search_a_loss, search_a_top1, search_a_top5 \
                    = search_func(search_loader, network, criterion, w_scheduler, w_optimizer, a_optimizer, epoch_str, xargs.print_freq, logger)
        search_time.update(time.time() - start_time)
        logger.log(
            '[{:}] search [base] : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s'
            .format(epoch_str, search_w_loss, search_w_top1, search_w_top5,
                    search_time.sum))
        logger.log(
            '[{:}] search [arch] : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'
            .format(epoch_str, search_a_loss, search_a_top1, search_a_top5))

        genotype, temp_accuracy = get_best_arch(valid_loader, network,
                                                xargs.select_num)
        network.module.set_cal_mode('dynamic', genotype)
        valid_a_loss, valid_a_top1, valid_a_top5 = valid_func(
            valid_loader, network, criterion)
        logger.log(
            '[{:}] evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}% | {:}'
            .format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5,
                    genotype))
        #search_model.set_cal_mode('urs')
        #valid_a_loss , valid_a_top1 , valid_a_top5  = valid_func(valid_loader, network, criterion)
        #logger.log('[{:}] URS---evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5))
        #search_model.set_cal_mode('joint')
        #valid_a_loss , valid_a_top1 , valid_a_top5  = valid_func(valid_loader, network, criterion)
        #logger.log('[{:}] JOINT-evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5))
        #search_model.set_cal_mode('select')
        #valid_a_loss , valid_a_top1 , valid_a_top5  = valid_func(valid_loader, network, criterion)
        #logger.log('[{:}] Selec-evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5))
        # check the best accuracy
        valid_accuracies[epoch] = valid_a_top1

        genotypes[epoch] = genotype
        logger.log('<<<--->>> The {:}-th epoch : {:}'.format(
            epoch_str, genotypes[epoch]))
        # save checkpoint
        save_path = save_checkpoint(
            {
                'epoch': epoch + 1,
                'args': deepcopy(xargs),
                'search_model': search_model.state_dict(),
                'w_optimizer': w_optimizer.state_dict(),
                'a_optimizer': a_optimizer.state_dict(),
                'w_scheduler': w_scheduler.state_dict(),
                'genotypes': genotypes,
                'valid_accuracies': valid_accuracies
            }, model_base_path, logger)
        last_info = save_checkpoint(
            {
                'epoch': epoch + 1,
                'args': deepcopy(xargs),
                'last_checkpoint': save_path,
            }, logger.path('info'), logger)
        with torch.no_grad():
            logger.log('{:}'.format(search_model.show_alphas()))
        if api is not None:
            logger.log('{:}'.format(api.query_by_arch(genotypes[epoch],
                                                      '200')))
        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()

    # the final post procedure : count the time
    start_time = time.time()
    genotype, temp_accuracy = get_best_arch(valid_loader, network,
                                            xargs.select_num)
    search_time.update(time.time() - start_time)
    network.module.set_cal_mode('dynamic', genotype)
    valid_a_loss, valid_a_top1, valid_a_top5 = valid_func(
        valid_loader, network, criterion)
    logger.log(
        'Last : the genotype is {:}, with the validation accuracy of {:.3f}%.'
        .format(genotype, valid_a_top1))

    logger.log('\n' + '-' * 100)
    # check the performance from the architecture dataset
    logger.log(
        'SETN : run {:} epochs, cost {:.1f} s, last-geno is {:}.'.format(
            total_epoch, search_time.sum, genotype))
    if api is not None:
        logger.log('{:}'.format(api.query_by_arch(genotype, '200')))
    logger.close()
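
The '200' passed to api.query_by_arch above selects which training schedule's statistics are reported: NAS-Bench-201 trains every architecture under both a short 12-epoch and a full 200-epoch schedule. A minimal standalone sketch of the same query, assuming a local copy of the benchmark file (the file name below is the v1.1 release and may differ on your machine):

from nas_201_api import NASBench201API as API

api = API('NAS-Bench-201-v1_1-096897.pth')
arch = '|nor_conv_3x3~0|+|nor_conv_3x3~0|skip_connect~1|+|skip_connect~0|none~1|nor_conv_1x1~2|'
print(api.query_by_arch(arch, '200'))  # '12' would report the short-schedule statistics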
Example #13
        'avoid loading NASBench2 api and instead load a pickle file with tuple (index, arch_str)'
    )
    args = parser.parse_args()
    args.device = torch.device(
        "cuda:" + str(args.gpu) if torch.cuda.is_available() else "cpu")
    return args


if __name__ == '__main__':
    args = parse_arguments()

    if args.noacc:
        api = pickle.load(open(args.api_loc, 'rb'))
    else:
        from nas_201_api import NASBench201API as API
        api = API(args.api_loc)

    torch.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    train_loader, val_loader = get_cifar_dataloaders(args.batch_size,
                                                     args.batch_size,
                                                     args.dataset,
                                                     args.num_data_workers)

    cached_res = []
    pre = 'cf' if 'cifar' in args.dataset else 'im'
    pfn = f'nb2_{pre}{get_num_classes(args)}_seed{args.seed}_dl{args.dataload}_dlinfo{args.dataload_info}_initw{args.init_w_type}_initb{args.init_b_type}.p'
    op = os.path.join(args.outdir, pfn)
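
When --noacc is set, the script above sidesteps the multi-gigabyte benchmark file and unpickles a plain list of (index, arch_str) tuples instead. A hedged sketch of how such a pickle could be produced once, up front; the file names here are illustrative:

import pickle
from nas_201_api import NASBench201API as API

api = API('NAS-Bench-201-v1_1-096897.pth')             # full benchmark, loaded once
archs = [(idx, arch) for idx, arch in enumerate(api)]  # iterating the API yields arch strings
with open('nb2_archs.p', 'wb') as f:
    pickle.dump(archs, f)                              # later runs load this with --noacc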
Example #14
def eval_score(jacob, labels=None):
    # Correlation matrix of the per-sample Jacobian rows: architectures whose
    # rows are less correlated (eigenvalues away from zero) score higher.
    corrs = np.corrcoef(jacob)
    v, _ = np.linalg.eig(corrs)
    k = 1e-5  # small offset to keep log() and the reciprocal finite
    return -np.sum(np.log(v + k) + 1. / (v + k))



if args.use_GPU:
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
else:
    device = torch.device("cpu")
print(device)
THE_START = time.time()
api = API(args.api_loc)
print("API loaded")
os.makedirs(args.save_loc, exist_ok=True)

train_data, valid_data, xshape, class_num = get_datasets(args.dataset, args.data_loc, cutout=0)

if args.dataset == 'cifar10':
    acc_type = 'ori-test'
    val_acc_type = 'x-valid'
else:
    acc_type = 'x-test'
    val_acc_type = 'x-valid'

if args.trainval:
    cifar_split = load_config('config_utils/cifar-split.txt', None, None)
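
eval_score above expects one flattened Jacobian row per input sample and rewards networks whose rows are weakly correlated. A toy call with random data, only to show the expected shapes (the dimensions are assumptions for a CIFAR-like input):

import numpy as np

batch, flat_dim = 32, 3 * 32 * 32        # hypothetical batch size and input size
jacob = np.random.randn(batch, flat_dim)
print(eval_score(jacob))                 # higher score = less correlated rows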
Example #15
def main():
    api = API(None)
    info = api.get_more_info(100, 'cifar100', 199, False, True)
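
In the nas_201_api releases this snippet appears to target, the positional arguments are (arch index, dataset, training epoch, use_12epochs_result, is_random), so the call requests epoch 199 of the 200-epoch schedule on cifar100 from a randomly chosen trial. A hedged sketch of reading the result; passing an explicit benchmark path is safer than the API(None) above, and the key names follow their use elsewhere in this document:

api = API('NAS-Bench-201-v1_1-096897.pth')  # illustrative file name
info = api.get_more_info(100, 'cifar100', 199, False, True)
print(info['valid-accuracy'], info['test-accuracy'])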
Example #16
def main():
    torch.set_num_threads(3)
    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)

    np.random.seed(args.seed)
    torch.cuda.set_device(args.gpu)
    cudnn.benchmark = True
    torch.manual_seed(args.seed)
    cudnn.enabled = True
    torch.cuda.manual_seed(args.seed)
    logging.info('gpu device = %d' % args.gpu)
    logging.info("args = %s", args)
    
    if 'debug' not in args.save:
        api = API('pth file path')
    criterion = nn.CrossEntropyLoss()
    criterion = criterion.cuda()

    if args.method == 'snas':
        # Create the decrease step for the gumbel softmax temperature
        args.epochs = 100
        tau_step = (args.tau_min - args.tau_max) / args.epochs
        tau_epoch = args.tau_max
        model = TinyNetwork(C=args.init_channels, N=5, max_nodes=4, num_classes=n_classes,
                            criterion=criterion, search_space=NAS_BENCH_201, k=args.k, species='gumbel')
    elif args.method == 'dirichlet':
        model = TinyNetwork(C=args.init_channels, N=5, max_nodes=4, num_classes=n_classes,
                            criterion=criterion, search_space=NAS_BENCH_201, k=args.k, species='dirichlet')
    elif args.method == 'darts':
        model = TinyNetwork(C=args.init_channels, N=5, max_nodes=4, num_classes=n_classes,
                            criterion=criterion, search_space=NAS_BENCH_201, k=args.k, species='softmax')
    model = model.cuda()
    logging.info("param size = %fMB", utils.count_parameters_in_MB(model))

    optimizer = torch.optim.SGD(
        model.get_weights(),
        args.learning_rate,
        momentum=args.momentum,
        weight_decay=args.weight_decay)

    if args.dataset == 'cifar10':
        train_transform, valid_transform = utils._data_transforms_cifar10(args)
        train_data = dset.CIFAR10(root=args.data, train=True, download=True, transform=train_transform)
    elif args.dataset == 'cifar100':
        train_transform, valid_transform = utils._data_transforms_cifar100(args)
        train_data = dset.CIFAR100(root=args.data, train=True, download=True, transform=train_transform)
    elif args.dataset == 'svhn':
        train_transform, valid_transform = utils._data_transforms_svhn(args)
        train_data = dset.SVHN(root=args.data, split='train', download=True, transform=train_transform)
    elif args.dataset == 'imagenet16-120':
        import torchvision.transforms as transforms
        from nasbench201.DownsampledImageNet import ImageNet16
        mean = [x / 255 for x in [122.68, 116.66, 104.01]]
        std = [x / 255 for x in [63.22,  61.26, 65.09]]
        lists = [transforms.RandomHorizontalFlip(), transforms.RandomCrop(16, padding=2), transforms.ToTensor(), transforms.Normalize(mean, std)]
        train_transform = transforms.Compose(lists)
        train_data = ImageNet16(root=os.path.join(args.data,'imagenet16'), train=True, transform=train_transform, use_num_of_class_only=120)
        assert len(train_data) == 151700

    num_train = len(train_data)
    indices = list(range(num_train))
    split = int(np.floor(args.train_portion * num_train))

    train_queue = torch.utils.data.DataLoader(
        train_data, batch_size=args.batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[:split]),
        pin_memory=True)

    valid_queue = torch.utils.data.DataLoader(
        train_data, batch_size=args.batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[split:num_train]),
        pin_memory=True)

    architect = Architect(model, args)
    
    # configure progressive parameter
    epoch = 0
    ks = [4, 2]
    num_keeps = [5, 3]
    train_epochs = [2, 2] if 'debug' in args.save else [50, 50]
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, float(sum(train_epochs)), eta_min=args.learning_rate_min)
    
    for i, current_epochs in enumerate(train_epochs):
        for e in range(current_epochs):
            lr = scheduler.get_lr()[0]
            logging.info('epoch %d lr %e', epoch, lr)
            genotype = model.genotype()
            logging.info('genotype = %s', genotype)
            model.show_arch_parameters()

            # training
            train_acc, train_obj = train(train_queue, valid_queue, model, architect, criterion, optimizer, lr, e)
            logging.info('train_acc %f', train_acc)

            # validation
            valid_acc, valid_obj = infer(valid_queue, model, criterion)
            logging.info('valid_acc %f', valid_acc)

            if 'debug' not in args.save:
                # nasbench201
                result = api.query_by_arch(model.genotype())
                logging.info('{:}'.format(result))
                cifar10_train, cifar10_test, cifar100_train, cifar100_valid, \
                    cifar100_test, imagenet16_train, imagenet16_valid, imagenet16_test = distill(result)
                logging.info('cifar10 train %f test %f', cifar10_train, cifar10_test)
                logging.info('cifar100 train %f valid %f test %f', cifar100_train, cifar100_valid, cifar100_test)
                logging.info('imagenet16 train %f valid %f test %f', imagenet16_train, imagenet16_valid, imagenet16_test)

                # tensorboard
                writer.add_scalars('accuracy', {'train':train_acc,'valid':valid_acc}, epoch)
                writer.add_scalars('loss', {'train':train_obj,'valid':valid_obj}, epoch)
                writer.add_scalars('nasbench201/cifar10', {'train':cifar10_train,'test':cifar10_test}, epoch)
                writer.add_scalars('nasbench201/cifar100', {'train':cifar100_train,'valid':cifar100_valid, 'test':cifar100_test}, epoch)
                writer.add_scalars('nasbench201/imagenet16', {'train':imagenet16_train,'valid':imagenet16_valid, 'test':imagenet16_test}, epoch)

                utils.save_checkpoint({
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'alpha': model.arch_parameters()
                }, False, args.save)
                
            epoch += 1
            scheduler.step()
            if args.method == 'snas':
                # Decrease the temperature for the gumbel softmax linearly
                tau_epoch += tau_step
                logging.info('tau %f', tau_epoch)
                model.set_tau(tau_epoch)

        if i != len(train_epochs) - 1:
            model.pruning(num_keeps[i+1])
            # architect.pruning([model._mask])
            model.wider(ks[i+1])
            optimizer = configure_optimizer(optimizer, torch.optim.SGD(
                model.get_weights(),
                args.learning_rate,
                momentum=args.momentum,
                weight_decay=args.weight_decay))
            scheduler = configure_scheduler(scheduler, torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer, float(sum(train_epochs)), eta_min=args.learning_rate_min))
            logging.info('pruning finished, %d ops left per edge', num_keeps[i+1])
            logging.info('network widening finished, current pc parameter %d', ks[i+1])

    genotype = model.genotype()
    logging.info('genotype = %s', genotype)
    model.show_arch_parameters()
    writer.close()
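
For the snas branch above, the Gumbel-softmax temperature is annealed linearly: tau_step is negative, so adding it each epoch moves tau from tau_max down to tau_min. A worked sketch with assumed values tau_max=1.0, tau_min=0.1 and 100 epochs:

tau_max, tau_min, epochs = 1.0, 0.1, 100   # assumed values for illustration
tau_step = (tau_min - tau_max) / epochs    # -0.009 per epoch
tau = tau_max
for _ in range(epochs):
    tau += tau_step                        # linear decrease, one step per epoch
print(round(tau, 3))                       # 0.1 after the full schedule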
Example #17
def main(xargs):
    assert torch.cuda.is_available(), 'CUDA is not available.'
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.set_num_threads(xargs.workers)
    prepare_seed(xargs.rand_seed)
    logger = prepare_logger(xargs)

    train_data, test_data, xshape, class_num = get_datasets(
        xargs.dataset, xargs.data_path, -1)
    logger.log('use config from : {:}'.format(xargs.config_path))
    config = load_config(xargs.config_path, {
        'class_num': class_num,
        'xshape': xshape
    }, logger)
    _, train_loader, valid_loader = get_nas_search_loaders(
        train_data, test_data, xargs.dataset, 'configs/nas-benchmark/',
        config.batch_size, xargs.workers)
    # since ENAS will train the controller on valid-loader, we need to use train transformation for valid-loader
    valid_loader.dataset.transform = deepcopy(train_loader.dataset.transform)
    if hasattr(valid_loader.dataset, 'transforms'):
        valid_loader.dataset.transforms = deepcopy(
            train_loader.dataset.transforms)
    # data loader
    logger.log(
        '||||||| {:10s} ||||||| Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'
        .format(xargs.dataset, len(train_loader), len(valid_loader),
                config.batch_size))
    logger.log('||||||| {:10s} ||||||| Config={:}'.format(
        xargs.dataset, config))

    search_space = get_search_spaces('cell', xargs.search_space_name)

    if xargs.model_config is None:
        model_config = dict2config(
            {
                'name': 'ENAS',
                'C': xargs.channel,
                'N': xargs.num_cells,
                'max_nodes': xargs.max_nodes,
                'num_classes': class_num,
                'space': search_space,
                'affine': False,
                'track_running_stats': bool(xargs.track_running_stats)
            }, None)
    else:
        model_config = load_config(
            xargs.model_config, {
                'num_classes': class_num,
                'space': search_space,
                'affine': False,
                'track_running_stats': bool(xargs.track_running_stats)
            }, None)

    shared_cnn = get_cell_based_tiny_net(model_config)
    controller = shared_cnn.create_controller()

    w_optimizer, w_scheduler, criterion = get_optim_scheduler(
        shared_cnn.parameters(), config)
    a_optimizer = torch.optim.Adam(controller.parameters(),
                                   lr=config.controller_lr,
                                   betas=config.controller_betas,
                                   eps=config.controller_eps)
    logger.log('w-optimizer : {:}'.format(w_optimizer))
    logger.log('a-optimizer : {:}'.format(a_optimizer))
    logger.log('w-scheduler : {:}'.format(w_scheduler))
    logger.log('criterion   : {:}'.format(criterion))
    #flop, param  = get_model_infos(shared_cnn, xshape)
    #logger.log('{:}'.format(shared_cnn))
    #logger.log('FLOP = {:.2f} M, Params = {:.2f} MB'.format(flop, param))
    logger.log('search-space : {:}'.format(search_space))
    if xargs.arch_nas_dataset is None:
        api = None
    else:
        api = API(xargs.arch_nas_dataset)
    logger.log('{:} create API = {:} done'.format(time_string(), api))
    shared_cnn, controller, criterion = torch.nn.DataParallel(
        shared_cnn).cuda(), controller.cuda(), criterion.cuda()

    last_info, model_base_path, model_best_path = logger.path(
        'info'), logger.path('model'), logger.path('best')

    if last_info.exists():  # automatically resume from previous checkpoint
        logger.log("=> loading checkpoint of the last-info '{:}' start".format(
            last_info))
        last_info = torch.load(last_info)
        start_epoch = last_info['epoch']
        checkpoint = torch.load(last_info['last_checkpoint'])
        genotypes = checkpoint['genotypes']
        baseline = checkpoint['baseline']
        valid_accuracies = checkpoint['valid_accuracies']
        shared_cnn.load_state_dict(checkpoint['shared_cnn'])
        controller.load_state_dict(checkpoint['controller'])
        w_scheduler.load_state_dict(checkpoint['w_scheduler'])
        w_optimizer.load_state_dict(checkpoint['w_optimizer'])
        a_optimizer.load_state_dict(checkpoint['a_optimizer'])
        logger.log(
            "=> loading checkpoint of the last-info '{:}' start with {:}-th epoch."
            .format(last_info, start_epoch))
    else:
        logger.log("=> do not find the last-info file : {:}".format(last_info))
        start_epoch, valid_accuracies, genotypes, baseline = 0, {
            'best': -1
        }, {}, None

    # start training
    start_time, search_time, epoch_time = time.time(), AverageMeter(), AverageMeter()
    total_epoch = config.epochs + config.warmup
    for epoch in range(start_epoch, total_epoch):
        w_scheduler.update(epoch, 0.0)
        need_time = 'Time Left: {:}'.format(
            convert_secs2time(epoch_time.val * (total_epoch - epoch), True))
        epoch_str = '{:03d}-{:03d}'.format(epoch, total_epoch)
        logger.log(
            '\n[Search the {:}-th epoch] {:}, LR={:}, baseline={:}'.format(
                epoch_str, need_time, min(w_scheduler.get_lr()), baseline))

        cnn_loss, cnn_top1, cnn_top5 = train_shared_cnn(
            train_loader, shared_cnn, controller, criterion, w_scheduler,
            w_optimizer, epoch_str, xargs.print_freq, logger)
        logger.log(
            '[{:}] shared-cnn : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'
            .format(epoch_str, cnn_loss, cnn_top1, cnn_top5))
        ctl_loss, ctl_acc, ctl_baseline, ctl_reward, baseline = train_controller(
            valid_loader, shared_cnn, controller, criterion, a_optimizer,
            dict2config({'baseline': baseline,
                         'ctl_train_steps': xargs.controller_train_steps,
                         'ctl_num_aggre': xargs.controller_num_aggregate,
                         'ctl_entropy_w': xargs.controller_entropy_weight,
                         'ctl_bl_dec': xargs.controller_bl_dec}, None),
            epoch_str, xargs.print_freq, logger)
        search_time.update(time.time() - start_time)
        logger.log(
            '[{:}] controller : loss={:.2f}, accuracy={:.2f}%, baseline={:.2f}, reward={:.2f}, current-baseline={:.4f}, time-cost={:.1f} s'
            .format(epoch_str, ctl_loss, ctl_acc, ctl_baseline, ctl_reward,
                    baseline, search_time.sum))
        best_arch, _ = get_best_arch(controller, shared_cnn, valid_loader)
        shared_cnn.module.update_arch(best_arch)
        _, best_valid_acc, _ = valid_func(valid_loader, shared_cnn, criterion)

        genotypes[epoch] = best_arch
        # check the best accuracy
        valid_accuracies[epoch] = best_valid_acc
        if best_valid_acc > valid_accuracies['best']:
            valid_accuracies['best'] = best_valid_acc
            genotypes['best'] = best_arch
            find_best = True
        else:
            find_best = False

        logger.log('<<<--->>> The {:}-th epoch : {:}'.format(
            epoch_str, genotypes[epoch]))
        # save checkpoint
        save_path = save_checkpoint(
            {
                'epoch': epoch + 1,
                'args': deepcopy(xargs),
                'baseline': baseline,
                'shared_cnn': shared_cnn.state_dict(),
                'controller': controller.state_dict(),
                'w_optimizer': w_optimizer.state_dict(),
                'a_optimizer': a_optimizer.state_dict(),
                'w_scheduler': w_scheduler.state_dict(),
                'genotypes': genotypes,
                'valid_accuracies': valid_accuracies
            }, model_base_path, logger)
        last_info = save_checkpoint(
            {
                'epoch': epoch + 1,
                'args': deepcopy(xargs),
                'last_checkpoint': save_path,
            }, logger.path('info'), logger)
        if find_best:
            logger.log(
                '<<<--->>> The {:}-th epoch : find the highest validation accuracy : {:.2f}%.'
                .format(epoch_str, best_valid_acc))
            copy_checkpoint(model_base_path, model_best_path, logger)
        if api is not None:
            logger.log('{:}'.format(api.query_by_arch(genotypes[epoch])))
        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()

    logger.log('\n' + '-' * 100)
    logger.log('During searching, the best architecture is {:}'.format(
        genotypes['best']))
    logger.log('Its accuracy is {:.2f}%'.format(valid_accuracies['best']))
    logger.log('Randomly sample {:} architectures and select the best.'.format(
        xargs.controller_num_samples))
    start_time = time.time()
    final_arch, _ = get_best_arch(controller, shared_cnn, valid_loader,
                                  xargs.controller_num_samples)
    search_time.update(time.time() - start_time)
    shared_cnn.module.update_arch(final_arch)
    final_loss, final_top1, final_top5 = valid_func(valid_loader, shared_cnn,
                                                    criterion)
    logger.log('The Selected Final Architecture : {:}'.format(final_arch))
    logger.log('Loss={:.3f}, Accuracy@1={:.2f}%, Accuracy@5={:.2f}%'.format(
        final_loss, final_top1, final_top5))
    logger.log(
        'ENAS : run {:} epochs, cost {:.1f} s, last-geno is {:}.'.format(
            total_epoch, search_time.sum, final_arch))
    if api is not None: logger.log('{:}'.format(api.query_by_arch(final_arch)))
    logger.close()
Example #18
class MacroGraph(NodeOpGraph):
    def __init__(self, config, primitives, ops_dict, *args, **kwargs):
        self.config = config
        self.primitives = primitives
        self.ops_dict = ops_dict
        self.nasbench_api = API('/home/siemsj/nasbench_201.pth')
        super(MacroGraph, self).__init__(*args, **kwargs)

    def _build_graph(self):
        num_cells_per_stack = self.config['num_cells_per_stack']
        C = self.config['init_channels']
        layer_channels = [C] * num_cells_per_stack + [C * 2] + [C * 2] * num_cells_per_stack + [C * 4] + [
            C * 4] * num_cells_per_stack
        layer_reductions = [False] * num_cells_per_stack + [True] + [False] * num_cells_per_stack + [True] + [
            False] * num_cells_per_stack

        stem = NASBENCH_201_Stem(C=C)
        self.add_node(0, type='input')
        self.add_node(1, op=stem, type='stem')

        C_prev = C
        self.cells = nn.ModuleList()
        for cell_num, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)):
            if reduction:
                cell = ResNetBasicblock(C_prev, C_curr, 2, True)
                self.add_node(cell_num + 2, op=cell, primitives=self.primitives, transform=lambda x: x[0])
            else:
                cell = Cell(primitives=self.primitives, stride=1, C_prev=C_prev, C=C_curr,
                            ops_dict=self.ops_dict, cell_type='normal')
                self.add_node(cell_num + 2, op=cell, primitives=self.primitives)

            C_prev = C_curr

        lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
        pooling = nn.AdaptiveAvgPool2d(1)
        classifier = nn.Linear(C_prev, self.config['num_classes'])

        self.add_node(cell_num + 3, op=lastact, transform=lambda x: x[0], type='postprocessing_nb201')
        self.add_node(cell_num + 4, op=pooling, transform=lambda x: x[0], type='pooling')
        self.add_node(cell_num + 5, op=classifier, transform=lambda x: x[0].view(x[0].size(0), -1),
                      type='output')

        # Edges
        for i in range(1, cell_num + 6):
            self.add_edge(i - 1, i, type='input', desc='previous')

    def sample(self, same_cell_struct=True, n_ops_per_edge=1,
               n_input_edges=None, dist=None, seed=1):
        """
        same_cell_struct:
            True; whether all sampled cells share the same topology
        n_ops_per_edge:
            1; number of sampled operations per edge in cell
        n_input_edges:
            None; list whose length equals the number of intermediate
            nodes; determines the number of predecessor nodes for each of them
        dist:
            None; distribution to sample operations in edges from
        seed:
            1; random seed
        """
        # create a new graph that we will discretize
        new_graph = MacroGraph(self.config, self.primitives, self.ops_dict)
        np.random.seed(seed)
        seeds = {'normal': seed + 1, 'reduction': seed + 2}

        for node in new_graph:
            cell = new_graph.get_node_op(node)
            if not isinstance(cell, Cell):
                continue

            if same_cell_struct:
                np.random.seed(seeds[new_graph.get_node_type(node)])

            for edge in cell.edges:
                op_choices = cell.get_edge_op_choices(*edge)
                sampled_op = np.random.choice(op_choices, n_ops_per_edge,
                                              False, p=dist)
                cell[edge[0]][edge[1]]['op_choices'] = [*sampled_op]

            if n_input_edges is not None:
                for inter_node, k in zip(cell.inter_nodes(), n_input_edges):
                    # in case the start node index is not 0
                    node_idx = list(cell.nodes).index(inter_node)
                    prev_node_choices = list(cell.nodes)[:node_idx]
                    assert k <= len(prev_node_choices), \
                        'cannot sample more than the number of predecessor nodes'

                    sampled_input_edges = np.random.choice(prev_node_choices,
                                                           k, False)
                    for i in set(prev_node_choices) - set(sampled_input_edges):
                        cell.remove_edge(i, inter_node)

        return new_graph

    @classmethod
    def from_config(cls, config=None, filename=None):
        with open(filename, 'r') as f:
            graph_dict = yaml.safe_load(f)

        if config is None:
            raise ValueError('No configuration provided')

        graph = cls(config, [])

        graph_type = graph_dict['type']
        edges = [(*eval(e), attr) for e, attr in graph_dict['edges'].items()]
        graph.clear()
        graph.add_edges_from(edges)

        C = config['init_channels']
        C_curr = config['stem_multiplier'] * C

        stem = Stem(C_curr=C_curr)
        C_prev_prev, C_prev, C_curr = C_curr, C_curr, C

        for node, attr in graph_dict['nodes'].items():
            node_type = attr['type']
            if node_type == 'input':
                graph.add_node(node, type='input')
            elif node_type == 'stem':
                graph.add_node(node, op=stem, type='stem')
            elif node_type in ['normal', 'reduction']:
                assert attr['op']['type'] == 'Cell'
                graph.add_node(node,
                               op=Cell.from_config(attr['op'], primitives=attr['op']['primitives'],
                                                   C_prev_prev=C_prev_prev, C_prev=C_prev,
                                                   C=C_curr,
                                                   reduction_prev=graph_dict['nodes'][node - 1]['type'] == 'reduction',
                                                   cell_type=node_type),
                               type=node_type)
                C_prev_prev, C_prev = C_prev, config['channel_multiplier'] * C_curr
            elif node_type == 'pooling':
                pooling = nn.AdaptiveAvgPool2d(1)
                graph.add_node(node, op=pooling, transform=lambda x: x[0],
                               type='pooling')
            elif node_type == 'output':
                classifier = nn.Linear(C_prev, config['num_classes'])
                graph.add_node(node, op=classifier,
                               transform=lambda x: x[0].view(x[0].size(0), -1),
                               type='output')

        return graph

    @staticmethod
    def export_nasbench_201_results_to_dict(information):
        results_dict = {}
        dataset_names = information.get_dataset_names()
        results_dict['arch'] = information.arch_str
        results_dict['datasets'] = dataset_names

        for ida, dataset in enumerate(dataset_names):
            dataset_results = {}
            dataset_results['dataset'] = dataset

            metric = information.get_compute_costs(dataset)
            flop, param, latency, training_time = metric['flops'], metric['params'], metric['latency'], metric[
                'T-train@total']
            dataset_results['flop'] = flop
            dataset_results['params (MB)'] = param
            dataset_results['latency (ms)'] = latency * 1000 if latency is not None and latency > 0 else None
            dataset_results['training_time'] = training_time

            train_info = information.get_metrics(dataset, 'train')
            if dataset == 'cifar10-valid':
                valid_info = information.get_metrics(dataset, 'x-valid')
                dataset_results['train_loss'] = train_info['loss']
                dataset_results['train_accuracy'] = train_info['accuracy']

                dataset_results['valid_loss'] = valid_info['loss']
                dataset_results['valid_accuracy'] = valid_info['accuracy']

            elif dataset == 'cifar10':
                test__info = information.get_metrics(dataset, 'ori-test')
                dataset_results['train_loss'] = train_info['loss']
                dataset_results['train_accuracy'] = train_info['accuracy']

                dataset_results['test_loss'] = test__info['loss']
                dataset_results['test_accuracy'] = test__info['accuracy']
            else:
                valid_info = information.get_metrics(dataset, 'x-valid')
                test__info = information.get_metrics(dataset, 'x-test')
                dataset_results['train_loss'] = train_info['loss']
                dataset_results['train_accuracy'] = train_info['accuracy']

                dataset_results['valid_loss'] = valid_info['loss']
                dataset_results['valid_accuracy'] = valid_info['accuracy']

                dataset_results['test_loss'] = test__info['loss']
                dataset_results['test_accuracy'] = test__info['accuracy']
            results_dict[dataset] = dataset_results
        return results_dict

    def query_architecture(self, arch_weights):
        arch_weight_idx_to_parent = {0: 0, 1: 0, 2: 1, 3: 0, 4: 1, 5: 2}
        arch_strs = {
            'cell_normal_from_0_to_1': '',
            'cell_normal_from_0_to_2': '',
            'cell_normal_from_1_to_2': '',
            'cell_normal_from_0_to_3': '',
            'cell_normal_from_1_to_3': '',
            'cell_normal_from_2_to_3': '',
        }
        for arch_weight_idx, (edge_key, edge_weights) in enumerate(arch_weights.items()):
            edge_weights_norm = torch.softmax(edge_weights, dim=-1)
            selected_op_str = PRIMITIVES[edge_weights_norm.argmax()]
            arch_strs[edge_key] = '{}~{}'.format(selected_op_str, arch_weight_idx_to_parent[arch_weight_idx])

        arch_str = '|{}|+|{}|{}|+|{}|{}|{}|'.format(*arch_strs.values())
        if not hasattr(self, 'nasbench_api'):
            self.nasbench_api = API('/home/siemsj/nasbench_201.pth')
        index = self.nasbench_api.query_index_by_arch(arch_str)
        self.nasbench_api.show(index)
        info = self.nasbench_api.query_by_index(index)
        return self.export_nasbench_201_results_to_dict(info)
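
query_architecture above assembles the NAS-Bench-201 architecture string from six op~parent tokens, one per edge of the cell DAG, grouped by target node. A small sketch of the same assembly with concrete operation names from the NAS-Bench-201 space:

tokens = ['nor_conv_3x3~0',                            # edge 0 -> 1
          'nor_conv_1x1~0', 'skip_connect~1',          # edges 0 -> 2, 1 -> 2
          'skip_connect~0', 'none~1', 'avg_pool_3x3~2']  # edges into node 3
arch_str = '|{}|+|{}|{}|+|{}|{}|{}|'.format(*tokens)
# |nor_conv_3x3~0|+|nor_conv_1x1~0|skip_connect~1|+|skip_connect~0|none~1|avg_pool_3x3~2|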
Example #19
def main(xargs):
    assert torch.cuda.is_available(), "CUDA is not available."
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.set_num_threads(xargs.workers)
    prepare_seed(xargs.rand_seed)
    logger = prepare_logger(xargs)

    train_data, valid_data, xshape, class_num = get_datasets(
        xargs.dataset, xargs.data_path, -1
    )
    # config_path = 'configs/nas-benchmark/algos/DARTS.config'
    config = load_config(
        xargs.config_path, {"class_num": class_num, "xshape": xshape}, logger
    )
    search_loader, _, valid_loader = get_nas_search_loaders(
        train_data,
        valid_data,
        xargs.dataset,
        "configs/nas-benchmark/",
        config.batch_size,
        xargs.workers,
    )
    logger.log(
        "||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}".format(
            xargs.dataset, len(search_loader), len(valid_loader), config.batch_size
        )
    )
    logger.log("||||||| {:10s} ||||||| Config={:}".format(xargs.dataset, config))

    search_space = get_search_spaces("cell", xargs.search_space_name)
    if xargs.model_config is None:
        model_config = dict2config(
            {
                "name": "DARTS-V1",
                "C": xargs.channel,
                "N": xargs.num_cells,
                "max_nodes": xargs.max_nodes,
                "num_classes": class_num,
                "space": search_space,
                "affine": False,
                "track_running_stats": bool(xargs.track_running_stats),
            },
            None,
        )
    else:
        model_config = load_config(
            xargs.model_config,
            {
                "num_classes": class_num,
                "space": search_space,
                "affine": False,
                "track_running_stats": bool(xargs.track_running_stats),
            },
            None,
        )
    search_model = get_cell_based_tiny_net(model_config)
    logger.log("search-model :\n{:}".format(search_model))

    w_optimizer, w_scheduler, criterion = get_optim_scheduler(
        search_model.get_weights(), config
    )
    a_optimizer = torch.optim.Adam(
        search_model.get_alphas(),
        lr=xargs.arch_learning_rate,
        betas=(0.5, 0.999),
        weight_decay=xargs.arch_weight_decay,
    )
    logger.log("w-optimizer : {:}".format(w_optimizer))
    logger.log("a-optimizer : {:}".format(a_optimizer))
    logger.log("w-scheduler : {:}".format(w_scheduler))
    logger.log("criterion   : {:}".format(criterion))
    flop, param = get_model_infos(search_model, xshape)
    # logger.log('{:}'.format(search_model))
    logger.log("FLOP = {:.2f} M, Params = {:.2f} MB".format(flop, param))
    if xargs.arch_nas_dataset is None:
        api = None
    else:
        api = API(xargs.arch_nas_dataset)
    logger.log("{:} create API = {:} done".format(time_string(), api))

    last_info, model_base_path, model_best_path = (
        logger.path("info"),
        logger.path("model"),
        logger.path("best"),
    )
    network, criterion = torch.nn.DataParallel(search_model).cuda(), criterion.cuda()

    if last_info.exists():  # automatically resume from previous checkpoint
        logger.log(
            "=> loading checkpoint of the last-info '{:}' start".format(last_info)
        )
        last_info = torch.load(last_info)
        start_epoch = last_info["epoch"]
        checkpoint = torch.load(last_info["last_checkpoint"])
        genotypes = checkpoint["genotypes"]
        valid_accuracies = checkpoint["valid_accuracies"]
        search_model.load_state_dict(checkpoint["search_model"])
        w_scheduler.load_state_dict(checkpoint["w_scheduler"])
        w_optimizer.load_state_dict(checkpoint["w_optimizer"])
        a_optimizer.load_state_dict(checkpoint["a_optimizer"])
        logger.log(
            "=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format(
                last_info, start_epoch
            )
        )
    else:
        logger.log("=> do not find the last-info file : {:}".format(last_info))
        start_epoch, valid_accuracies, genotypes = (
            0,
            {"best": -1},
            {-1: search_model.genotype()},
        )

    # start training
    start_time, search_time, epoch_time, total_epoch = (
        time.time(),
        AverageMeter(),
        AverageMeter(),
        config.epochs + config.warmup,
    )
    for epoch in range(start_epoch, total_epoch):
        w_scheduler.update(epoch, 0.0)
        need_time = "Time Left: {:}".format(
            convert_secs2time(epoch_time.val * (total_epoch - epoch), True)
        )
        epoch_str = "{:03d}-{:03d}".format(epoch, total_epoch)
        logger.log(
            "\n[Search the {:}-th epoch] {:}, LR={:}".format(
                epoch_str, need_time, min(w_scheduler.get_lr())
            )
        )

        search_w_loss, search_w_top1, search_w_top5 = search_func(
            search_loader,
            network,
            criterion,
            w_scheduler,
            w_optimizer,
            a_optimizer,
            epoch_str,
            xargs.print_freq,
            logger,
            xargs.gradient_clip,
        )
        search_time.update(time.time() - start_time)
        logger.log(
            "[{:}] searching : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s".format(
                epoch_str, search_w_loss, search_w_top1, search_w_top5, search_time.sum
            )
        )
        valid_a_loss, valid_a_top1, valid_a_top5 = valid_func(
            valid_loader, network, criterion
        )
        logger.log(
            "[{:}] evaluate  : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%".format(
                epoch_str, valid_a_loss, valid_a_top1, valid_a_top5
            )
        )
        # check the best accuracy
        valid_accuracies[epoch] = valid_a_top1
        if valid_a_top1 > valid_accuracies["best"]:
            valid_accuracies["best"] = valid_a_top1
            genotypes["best"] = search_model.genotype()
            find_best = True
        else:
            find_best = False

        genotypes[epoch] = search_model.genotype()
        logger.log(
            "<<<--->>> The {:}-th epoch : {:}".format(epoch_str, genotypes[epoch])
        )
        # save checkpoint
        save_path = save_checkpoint(
            {
                "epoch": epoch + 1,
                "args": deepcopy(xargs),
                "search_model": search_model.state_dict(),
                "w_optimizer": w_optimizer.state_dict(),
                "a_optimizer": a_optimizer.state_dict(),
                "w_scheduler": w_scheduler.state_dict(),
                "genotypes": genotypes,
                "valid_accuracies": valid_accuracies,
            },
            model_base_path,
            logger,
        )
        last_info = save_checkpoint(
            {
                "epoch": epoch + 1,
                "args": deepcopy(args),
                "last_checkpoint": save_path,
            },
            logger.path("info"),
            logger,
        )
        if find_best:
            logger.log(
                "<<<--->>> The {:}-th epoch : find the highest validation accuracy : {:.2f}%.".format(
                    epoch_str, valid_a_top1
                )
            )
            copy_checkpoint(model_base_path, model_best_path, logger)
        with torch.no_grad():
            # logger.log('arch-parameters :\n{:}'.format( nn.functional.softmax(search_model.arch_parameters, dim=-1).cpu() ))
            logger.log("{:}".format(search_model.show_alphas()))
        if api is not None:
            logger.log("{:}".format(api.query_by_arch(genotypes[epoch], "200")))
        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()

    logger.log("\n" + "-" * 100)
    logger.log(
        "DARTS-V1 : run {:} epochs, cost {:.1f} s, last-geno is {:}.".format(
            total_epoch, search_time.sum, genotypes[total_epoch - 1]
        )
    )
    if api is not None:
        logger.log("{:}".format(api.query_by_arch(genotypes[total_epoch - 1], "200")))
    logger.close()
Example #20
def __init__(self, config, primitives, ops_dict, *args, **kwargs):
    self.config = config
    self.primitives = primitives
    self.ops_dict = ops_dict
    self.nasbench_api = API('/home/siemsj/nasbench_201.pth')
    super(MacroGraph, self).__init__(*args, **kwargs)
Example #21
def visualize_info(meta_file, dataset, vis_save_dir):
    print('{:} start to visualize {:} information'.format(
        time_string(), dataset))
    cache_file_path = vis_save_dir / '{:}-cache-info.pth'.format(dataset)
    if not cache_file_path.exists():
        print('Do not find cache file : {:}'.format(cache_file_path))
        nas_bench = API(str(meta_file))
        params, flops, train_accs, valid_accs, test_accs, otest_accs = [], [], [], [], [], []
        for index in range(len(nas_bench)):
            info = nas_bench.query_by_index(index, use_12epochs_result=False)
        resx = info.get_compute_costs(dataset)
            flop, param = resx['flops'], resx['params']
            if dataset == 'cifar10':
                res = info.get_metrics('cifar10', 'train')
                train_acc = res['accuracy']
                res = info.get_metrics('cifar10-valid', 'x-valid')
                valid_acc = res['accuracy']
                res = info.get_metrics('cifar10', 'ori-test')
                test_acc = res['accuracy']
                res = info.get_metrics('cifar10', 'ori-test')
                otest_acc = res['accuracy']
            else:
                res = info.get_metrics(dataset, 'train')
                train_acc = res['accuracy']
                res = info.get_metrics(dataset, 'x-valid')
                valid_acc = res['accuracy']
                res = info.get_metrics(dataset, 'x-test')
                test_acc = res['accuracy']
                res = info.get_metrics(dataset, 'ori-test')
                otest_acc = res['accuracy']
            if index == 11472:  # resnet
                resnet = {
                    'params': param,
                    'flops': flop,
                    'index': 11472,
                    'train_acc': train_acc,
                    'valid_acc': valid_acc,
                    'test_acc': test_acc,
                    'otest_acc': otest_acc
                }
            flops.append(flop)
            params.append(param)
            train_accs.append(train_acc)
            valid_accs.append(valid_acc)
            test_accs.append(test_acc)
            otest_accs.append(otest_acc)
        #resnet = {'params': 0.559, 'flops': 78.56, 'index': 11472, 'train_acc': 99.99, 'valid_acc': 90.84, 'test_acc': 93.97}
        info = {
            'params': params,
            'flops': flops,
            'train_accs': train_accs,
            'valid_accs': valid_accs,
            'test_accs': test_accs,
            'otest_accs': otest_accs
        }
        info['resnet'] = resnet
        torch.save(info, cache_file_path)
    else:
        print('Find cache file : {:}'.format(cache_file_path))
        info = torch.load(cache_file_path)
        params, flops = info['params'], info['flops']
        train_accs, valid_accs = info['train_accs'], info['valid_accs']
        test_accs, otest_accs = info['test_accs'], info['otest_accs']
        resnet = info['resnet']
    print('{:} collect data done.'.format(time_string()))

    indexes = list(range(len(params)))
    dpi, width, height = 300, 2600, 2600
    figsize = width / float(dpi), height / float(dpi)
    LabelSize, LegendFontsize = 22, 22
    resnet_scale, resnet_alpha = 120, 0.5

    fig = plt.figure(figsize=figsize)
    ax = fig.add_subplot(111)
    plt.xticks(np.arange(0, 1.6, 0.3), fontsize=LegendFontsize)
    if dataset == 'cifar10':
        plt.ylim(50, 100)
        plt.yticks(np.arange(50, 101, 10), fontsize=LegendFontsize)
    elif dataset == 'cifar100':
        plt.ylim(25, 75)
        plt.yticks(np.arange(25, 76, 10), fontsize=LegendFontsize)
    else:
        plt.ylim(0, 50)
        plt.yticks(np.arange(0, 51, 10), fontsize=LegendFontsize)
    ax.scatter(params, valid_accs, marker='o', s=0.5, c='tab:blue')
    ax.scatter([resnet['params']], [resnet['valid_acc']],
               marker='*',
               s=resnet_scale,
               c='tab:orange',
               label='resnet',
               alpha=0.4)
    plt.grid(zorder=0)
    ax.set_axisbelow(True)
    plt.legend(loc=4, fontsize=LegendFontsize)
    ax.set_xlabel('#parameters (MB)', fontsize=LabelSize)
    ax.set_ylabel('the validation accuracy (%)', fontsize=LabelSize)
    save_path = (vis_save_dir /
                 '{:}-param-vs-valid.pdf'.format(dataset)).resolve()
    fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='pdf')
    save_path = (vis_save_dir /
                 '{:}-param-vs-valid.png'.format(dataset)).resolve()
    fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png')
    print('{:} save into {:}'.format(time_string(), save_path))

    fig = plt.figure(figsize=figsize)
    ax = fig.add_subplot(111)
    plt.xticks(np.arange(0, 1.6, 0.3), fontsize=LegendFontsize)
    if dataset == 'cifar10':
        plt.ylim(50, 100)
        plt.yticks(np.arange(50, 101, 10), fontsize=LegendFontsize)
    elif dataset == 'cifar100':
        plt.ylim(25, 75)
        plt.yticks(np.arange(25, 76, 10), fontsize=LegendFontsize)
    else:
        plt.ylim(0, 50)
        plt.yticks(np.arange(0, 51, 10), fontsize=LegendFontsize)
    ax.scatter(params, test_accs, marker='o', s=0.5, c='tab:blue')
    ax.scatter([resnet['params']], [resnet['test_acc']],
               marker='*',
               s=resnet_scale,
               c='tab:orange',
               label='resnet',
               alpha=resnet_alpha)
    plt.grid()
    ax.set_axisbelow(True)
    plt.legend(loc=4, fontsize=LegendFontsize)
    ax.set_xlabel('#parameters (MB)', fontsize=LabelSize)
    ax.set_ylabel('the test accuracy (%)', fontsize=LabelSize)
    save_path = (vis_save_dir /
                 '{:}-param-vs-test.pdf'.format(dataset)).resolve()
    fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='pdf')
    save_path = (vis_save_dir /
                 '{:}-param-vs-test.png'.format(dataset)).resolve()
    fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png')
    print('{:} save into {:}'.format(time_string(), save_path))

    fig = plt.figure(figsize=figsize)
    ax = fig.add_subplot(111)
    plt.xticks(np.arange(0, 1.6, 0.3), fontsize=LegendFontsize)
    if dataset == 'cifar10':
        plt.ylim(50, 100)
        plt.yticks(np.arange(50, 101, 10), fontsize=LegendFontsize)
    elif dataset == 'cifar100':
        plt.ylim(20, 100)
        plt.yticks(np.arange(20, 101, 10), fontsize=LegendFontsize)
    else:
        plt.ylim(25, 76)
        plt.yticks(np.arange(25, 76, 10), fontsize=LegendFontsize)
    ax.scatter(params, train_accs, marker='o', s=0.5, c='tab:blue')
    ax.scatter([resnet['params']], [resnet['train_acc']],
               marker='*',
               s=resnet_scale,
               c='tab:orange',
               label='resnet',
               alpha=resnet_alpha)
    plt.grid()
    ax.set_axisbelow(True)
    plt.legend(loc=4, fontsize=LegendFontsize)
    ax.set_xlabel('#parameters (MB)', fontsize=LabelSize)
    ax.set_ylabel('the training accuracy (%)', fontsize=LabelSize)
    save_path = (vis_save_dir /
                 '{:}-param-vs-train.pdf'.format(dataset)).resolve()
    fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='pdf')
    save_path = (vis_save_dir /
                 '{:}-param-vs-train.png'.format(dataset)).resolve()
    fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png')
    print('{:} save into {:}'.format(time_string(), save_path))

    fig = plt.figure(figsize=figsize)
    ax = fig.add_subplot(111)
    plt.xlim(0, max(indexes))
    plt.xticks(np.arange(min(indexes), max(indexes),
                         max(indexes) // 5),
               fontsize=LegendFontsize)
    if dataset == 'cifar10':
        plt.ylim(50, 100)
        plt.yticks(np.arange(50, 101, 10), fontsize=LegendFontsize)
    elif dataset == 'cifar100':
        plt.ylim(25, 75)
        plt.yticks(np.arange(25, 76, 10), fontsize=LegendFontsize)
    else:
        plt.ylim(0, 50)
        plt.yticks(np.arange(0, 51, 10), fontsize=LegendFontsize)
    ax.scatter(indexes, test_accs, marker='o', s=0.5, c='tab:blue')
    ax.scatter([resnet['index']], [resnet['test_acc']],
               marker='*',
               s=resnet_scale,
               c='tab:orange',
               label='resnet',
               alpha=resnet_alpha)
    plt.grid()
    ax.set_axisbelow(True)
    plt.legend(loc=4, fontsize=LegendFontsize)
    ax.set_xlabel('architecture ID', fontsize=LabelSize)
    ax.set_ylabel('the test accuracy (%)', fontsize=LabelSize)
    save_path = (vis_save_dir /
                 '{:}-test-over-ID.pdf'.format(dataset)).resolve()
    fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='pdf')
    save_path = (vis_save_dir /
                 '{:}-test-over-ID.png'.format(dataset)).resolve()
    fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png')
    print('{:} save into {:}'.format(time_string(), save_path))
    plt.close('all')
Example #22
def main(xargs):
    assert torch.cuda.is_available(), "CUDA is not available."
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.set_num_threads(xargs.workers)
    prepare_seed(xargs.rand_seed)
    logger = prepare_logger(xargs)

    train_data, test_data, xshape, class_num = get_datasets(
        xargs.dataset, xargs.data_path, -1)
    logger.log("use config from : {:}".format(xargs.config_path))
    config = load_config(xargs.config_path, {
        "class_num": class_num,
        "xshape": xshape
    }, logger)
    _, train_loader, valid_loader = get_nas_search_loaders(
        train_data,
        test_data,
        xargs.dataset,
        "configs/nas-benchmark/",
        config.batch_size,
        xargs.workers,
    )
    # since ENAS will train the controller on valid-loader, we need to use train transformation for valid-loader
    valid_loader.dataset.transform = deepcopy(train_loader.dataset.transform)
    if hasattr(valid_loader.dataset, "transforms"):
        valid_loader.dataset.transforms = deepcopy(
            train_loader.dataset.transforms)
    # data loader
    logger.log(
        "||||||| {:10s} ||||||| Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}"
        .format(xargs.dataset, len(train_loader), len(valid_loader),
                config.batch_size))
    logger.log("||||||| {:10s} ||||||| Config={:}".format(
        xargs.dataset, config))

    search_space = get_search_spaces("cell", xargs.search_space_name)
    model_config = dict2config(
        {
            "name": "ENAS",
            "C": xargs.channel,
            "N": xargs.num_cells,
            "max_nodes": xargs.max_nodes,
            "num_classes": class_num,
            "space": search_space,
            "affine": False,
            "track_running_stats": bool(xargs.track_running_stats),
        },
        None,
    )
    shared_cnn = get_cell_based_tiny_net(model_config)
    controller = shared_cnn.create_controller()

    w_optimizer, w_scheduler, criterion = get_optim_scheduler(
        shared_cnn.parameters(), config)
    a_optimizer = torch.optim.Adam(
        controller.parameters(),
        lr=config.controller_lr,
        betas=config.controller_betas,
        eps=config.controller_eps,
    )
    logger.log("w-optimizer : {:}".format(w_optimizer))
    logger.log("a-optimizer : {:}".format(a_optimizer))
    logger.log("w-scheduler : {:}".format(w_scheduler))
    logger.log("criterion   : {:}".format(criterion))
    # flop, param  = get_model_infos(shared_cnn, xshape)
    # logger.log('{:}'.format(shared_cnn))
    # logger.log('FLOP = {:.2f} M, Params = {:.2f} MB'.format(flop, param))
    logger.log("search-space : {:}".format(search_space))
    if xargs.arch_nas_dataset is None:
        api = None
    else:
        api = API(xargs.arch_nas_dataset)
    logger.log("{:} create API = {:} done".format(time_string(), api))
    shared_cnn, controller, criterion = (
        torch.nn.DataParallel(shared_cnn).cuda(),
        controller.cuda(),
        criterion.cuda(),
    )

    last_info, model_base_path, model_best_path = (
        logger.path("info"),
        logger.path("model"),
        logger.path("best"),
    )

    if last_info.exists():  # automatically resume from previous checkpoint
        logger.log("=> loading checkpoint of the last-info '{:}' start".format(
            last_info))
        last_info = torch.load(last_info)
        start_epoch = last_info["epoch"]
        checkpoint = torch.load(last_info["last_checkpoint"])
        genotypes = checkpoint["genotypes"]
        baseline = checkpoint["baseline"]
        valid_accuracies = checkpoint["valid_accuracies"]
        shared_cnn.load_state_dict(checkpoint["shared_cnn"])
        controller.load_state_dict(checkpoint["controller"])
        w_scheduler.load_state_dict(checkpoint["w_scheduler"])
        w_optimizer.load_state_dict(checkpoint["w_optimizer"])
        a_optimizer.load_state_dict(checkpoint["a_optimizer"])
        logger.log(
            "=> loading checkpoint of the last-info '{:}' start with {:}-th epoch."
            .format(last_info, start_epoch))
    else:
        logger.log("=> do not find the last-info file : {:}".format(last_info))
        start_epoch, valid_accuracies, genotypes, baseline = 0, {
            "best": -1
        }, {}, None

    # start training
    start_time, search_time, epoch_time, total_epoch = (
        time.time(),
        AverageMeter(),
        AverageMeter(),
        config.epochs + config.warmup,
    )
    for epoch in range(start_epoch, total_epoch):
        w_scheduler.update(epoch, 0.0)
        need_time = "Time Left: {:}".format(
            convert_secs2time(epoch_time.val * (total_epoch - epoch), True))
        epoch_str = "{:03d}-{:03d}".format(epoch, total_epoch)
        logger.log(
            "\n[Search the {:}-th epoch] {:}, LR={:}, baseline={:}".format(
                epoch_str, need_time, min(w_scheduler.get_lr()), baseline))

        cnn_loss, cnn_top1, cnn_top5 = train_shared_cnn(
            train_loader,
            shared_cnn,
            controller,
            criterion,
            w_scheduler,
            w_optimizer,
            epoch_str,
            xargs.print_freq,
            logger,
        )
        logger.log(
            "[{:}] shared-cnn : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%"
            .format(epoch_str, cnn_loss, cnn_top1, cnn_top5))
        ctl_loss, ctl_acc, ctl_baseline, ctl_reward, baseline = train_controller(
            valid_loader,
            shared_cnn,
            controller,
            criterion,
            a_optimizer,
            dict2config(
                {
                    "baseline": baseline,
                    "ctl_train_steps": xargs.controller_train_steps,
                    "ctl_num_aggre": xargs.controller_num_aggregate,
                    "ctl_entropy_w": xargs.controller_entropy_weight,
                    "ctl_bl_dec": xargs.controller_bl_dec,
                },
                None,
            ),
            epoch_str,
            xargs.print_freq,
            logger,
        )
        search_time.update(time.time() - start_time)
        logger.log(
            "[{:}] controller : loss={:.2f}, accuracy={:.2f}%, baseline={:.2f}, reward={:.2f}, current-baseline={:.4f}, time-cost={:.1f} s"
            .format(
                epoch_str,
                ctl_loss,
                ctl_acc,
                ctl_baseline,
                ctl_reward,
                baseline,
                search_time.sum,
            ))
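        # Let the trained controller propose architectures, keep the best one
        # on the validation set, and evaluate it with the shared weights.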
        best_arch, _ = get_best_arch(controller, shared_cnn, valid_loader)
        shared_cnn.module.update_arch(best_arch)
        _, best_valid_acc, _ = valid_func(valid_loader, shared_cnn, criterion)

        genotypes[epoch] = best_arch
        # check the best accuracy
        valid_accuracies[epoch] = best_valid_acc
        if best_valid_acc > valid_accuracies["best"]:
            valid_accuracies["best"] = best_valid_acc
            genotypes["best"] = best_arch
            find_best = True
        else:
            find_best = False

        logger.log("<<<--->>> The {:}-th epoch : {:}".format(
            epoch_str, genotypes[epoch]))
        # save checkpoint
        save_path = save_checkpoint(
            {
                "epoch": epoch + 1,
                "args": deepcopy(xargs),
                "baseline": baseline,
                "shared_cnn": shared_cnn.state_dict(),
                "controller": controller.state_dict(),
                "w_optimizer": w_optimizer.state_dict(),
                "a_optimizer": a_optimizer.state_dict(),
                "w_scheduler": w_scheduler.state_dict(),
                "genotypes": genotypes,
                "valid_accuracies": valid_accuracies,
            },
            model_base_path,
            logger,
        )
        last_info = save_checkpoint(
            {
                "epoch": epoch + 1,
                "args": deepcopy(xargs),
                "last_checkpoint": save_path,
            },
            logger.path("info"),
            logger,
        )
        if find_best:
            logger.log(
                "<<<--->>> The {:}-th epoch : find the highest validation accuracy : {:.2f}%."
                .format(epoch_str, best_valid_acc))
            copy_checkpoint(model_base_path, model_best_path, logger)
        if api is not None:
            logger.log("{:}".format(api.query_by_arch(genotypes[epoch],
                                                      "200")))
        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()

    logger.log("\n" + "-" * 100)
    logger.log("During searching, the best architecture is {:}".format(
        genotypes["best"]))
    logger.log("Its accuracy is {:.2f}%".format(valid_accuracies["best"]))
    logger.log("Randomly select {:} architectures and select the best.".format(
        xargs.controller_num_samples))
    start_time = time.time()
    final_arch, _ = get_best_arch(controller, shared_cnn, valid_loader,
                                  xargs.controller_num_samples)
    search_time.update(time.time() - start_time)
    shared_cnn.module.update_arch(final_arch)
    final_loss, final_top1, final_top5 = valid_func(valid_loader, shared_cnn,
                                                    criterion)
    logger.log("The Selected Final Architecture : {:}".format(final_arch))
    logger.log("Loss={:.3f}, Accuracy@1={:.2f}%, Accuracy@5={:.2f}%".format(
        final_loss, final_top1, final_top5))
    logger.log(
        "ENAS : run {:} epochs, cost {:.1f} s, last-geno is {:}.".format(
            total_epoch, search_time.sum, final_arch))
    if api is not None:
        logger.log("{:}".format(api.query_by_arch(final_arch)))
    logger.close()
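The ENAS search above reads its options from an argparse namespace. A minimal, hedged sketch of the flags the function uses (the flag names are taken from the code above, while the defaults are illustrative assumptions):

import argparse

parser = argparse.ArgumentParser("ENAS")
parser.add_argument('--dataset', type=str, default='cifar10')
parser.add_argument('--data_path', type=str, required=True)
parser.add_argument('--config_path', type=str, required=True)
parser.add_argument('--search_space_name', type=str, default='nas-bench-201')
parser.add_argument('--channel', type=int, default=16)
parser.add_argument('--num_cells', type=int, default=5)
parser.add_argument('--max_nodes', type=int, default=4)
parser.add_argument('--track_running_stats', type=int, default=0)
parser.add_argument('--workers', type=int, default=2)
parser.add_argument('--rand_seed', type=int, default=-1)
parser.add_argument('--print_freq', type=int, default=200)
parser.add_argument('--controller_train_steps', type=int, default=50)
parser.add_argument('--controller_num_aggregate', type=int, default=20)
parser.add_argument('--controller_entropy_weight', type=float, default=0.0001)
parser.add_argument('--controller_bl_dec', type=float, default=0.99)
parser.add_argument('--controller_num_samples', type=int, default=100)
parser.add_argument('--arch_nas_dataset', type=str, default=None)
args = parser.parse_args()
main(args)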
Example #23
def main(xargs):
    assert torch.cuda.is_available(), "CUDA is not available."
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.set_num_threads(xargs.workers)
    prepare_seed(xargs.rand_seed)
    logger = prepare_logger(xargs)

    train_data, valid_data, xshape, class_num = get_datasets(
        xargs.dataset, xargs.data_path, -1)
    config = load_config(xargs.config_path, {
        "class_num": class_num,
        "xshape": xshape
    }, logger)
    search_loader, _, valid_loader = get_nas_search_loaders(
        train_data,
        valid_data,
        xargs.dataset,
        "configs/nas-benchmark/",
        (config.batch_size, config.test_batch_size),
        xargs.workers,
    )
    logger.log(
        "||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}"
        .format(xargs.dataset, len(search_loader), len(valid_loader),
                config.batch_size))
    logger.log("||||||| {:10s} ||||||| Config={:}".format(
        xargs.dataset, config))

    search_space = get_search_spaces("cell", xargs.search_space_name)
    if xargs.model_config is None:
        model_config = dict2config(
            dict(
                name="SETN",
                C=xargs.channel,
                N=xargs.num_cells,
                max_nodes=xargs.max_nodes,
                num_classes=class_num,
                space=search_space,
                affine=False,
                track_running_stats=bool(xargs.track_running_stats),
            ),
            None,
        )
    else:
        model_config = load_config(
            xargs.model_config,
            dict(
                num_classes=class_num,
                space=search_space,
                affine=False,
                track_running_stats=bool(xargs.track_running_stats),
            ),
            None,
        )
    logger.log("search space : {:}".format(search_space))
    search_model = get_cell_based_tiny_net(model_config)

    w_optimizer, w_scheduler, criterion = get_optim_scheduler(
        search_model.get_weights(), config)
    a_optimizer = torch.optim.Adam(
        search_model.get_alphas(),
        lr=xargs.arch_learning_rate,
        betas=(0.5, 0.999),
        weight_decay=xargs.arch_weight_decay,
    )
    logger.log("w-optimizer : {:}".format(w_optimizer))
    logger.log("a-optimizer : {:}".format(a_optimizer))
    logger.log("w-scheduler : {:}".format(w_scheduler))
    logger.log("criterion   : {:}".format(criterion))
    flop, param = get_model_infos(search_model, xshape)
    logger.log("FLOP = {:.2f} M, Params = {:.2f} MB".format(flop, param))
    logger.log("search-space : {:}".format(search_space))
    if xargs.arch_nas_dataset is None:
        api = None
    else:
        api = API(xargs.arch_nas_dataset)
    logger.log("{:} create API = {:} done".format(time_string(), api))

    last_info, model_base_path, model_best_path = (
        logger.path("info"),
        logger.path("model"),
        logger.path("best"),
    )
    network, criterion = torch.nn.DataParallel(
        search_model).cuda(), criterion.cuda()

    if last_info.exists():  # automatically resume from previous checkpoint
        logger.log("=> loading checkpoint of the last-info '{:}' start".format(
            last_info))
        last_info = torch.load(last_info)
        start_epoch = last_info["epoch"]
        checkpoint = torch.load(last_info["last_checkpoint"])
        genotypes = checkpoint["genotypes"]
        valid_accuracies = checkpoint["valid_accuracies"]
        search_model.load_state_dict(checkpoint["search_model"])
        w_scheduler.load_state_dict(checkpoint["w_scheduler"])
        w_optimizer.load_state_dict(checkpoint["w_optimizer"])
        a_optimizer.load_state_dict(checkpoint["a_optimizer"])
        logger.log(
            "=> loading checkpoint of the last-info '{:}' start with {:}-th epoch."
            .format(last_info, start_epoch))
    else:
        logger.log("=> do not find the last-info file : {:}".format(last_info))
        init_genotype, _ = get_best_arch(valid_loader, network,
                                         xargs.select_num)
        start_epoch, valid_accuracies, genotypes = 0, {
            "best": -1
        }, {
            -1: init_genotype
        }

    # start training
    start_time, search_time, epoch_time, total_epoch = (
        time.time(),
        AverageMeter(),
        AverageMeter(),
        config.epochs + config.warmup,
    )
    for epoch in range(start_epoch, total_epoch):
        w_scheduler.update(epoch, 0.0)
        need_time = "Time Left: {:}".format(
            convert_secs2time(epoch_time.val * (total_epoch - epoch), True))
        epoch_str = "{:03d}-{:03d}".format(epoch, total_epoch)
        logger.log("\n[Search the {:}-th epoch] {:}, LR={:}".format(
            epoch_str, need_time, min(w_scheduler.get_lr())))

        (
            search_w_loss,
            search_w_top1,
            search_w_top5,
            search_a_loss,
            search_a_top1,
            search_a_top5,
        ) = search_func(
            search_loader,
            network,
            criterion,
            w_scheduler,
            w_optimizer,
            a_optimizer,
            epoch_str,
            xargs.print_freq,
            logger,
        )
        search_time.update(time.time() - start_time)
        logger.log(
            "[{:}] search [base] : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s"
            .format(epoch_str, search_w_loss, search_w_top1, search_w_top5,
                    search_time.sum))
        logger.log(
            "[{:}] search [arch] : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%"
            .format(epoch_str, search_a_loss, search_a_top1, search_a_top5))

        genotype, temp_accuracy = get_best_arch(valid_loader, network,
                                                xargs.select_num)
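        # SETN 'dynamic' mode evaluates the super-network as this single
        # selected genotype, reusing the shared weights.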
        network.module.set_cal_mode("dynamic", genotype)
        valid_a_loss, valid_a_top1, valid_a_top5 = valid_func(
            valid_loader, network, criterion)
        logger.log(
            "[{:}] evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}% | {:}"
            .format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5,
                    genotype))
        # search_model.set_cal_mode('urs')
        # valid_a_loss , valid_a_top1 , valid_a_top5  = valid_func(valid_loader, network, criterion)
        # logger.log('[{:}] URS---evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5))
        # search_model.set_cal_mode('joint')
        # valid_a_loss , valid_a_top1 , valid_a_top5  = valid_func(valid_loader, network, criterion)
        # logger.log('[{:}] JOINT-evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5))
        # search_model.set_cal_mode('select')
        # valid_a_loss , valid_a_top1 , valid_a_top5  = valid_func(valid_loader, network, criterion)
        # logger.log('[{:}] Selec-evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5))
        # check the best accuracy
        valid_accuracies[epoch] = valid_a_top1

        genotypes[epoch] = genotype
        logger.log("<<<--->>> The {:}-th epoch : {:}".format(
            epoch_str, genotypes[epoch]))
        # save checkpoint
        save_path = save_checkpoint(
            {
                "epoch": epoch + 1,
                "args": deepcopy(xargs),
                "search_model": search_model.state_dict(),
                "w_optimizer": w_optimizer.state_dict(),
                "a_optimizer": a_optimizer.state_dict(),
                "w_scheduler": w_scheduler.state_dict(),
                "genotypes": genotypes,
                "valid_accuracies": valid_accuracies,
            },
            model_base_path,
            logger,
        )
        last_info = save_checkpoint(
            {
                "epoch": epoch + 1,
                "args": deepcopy(xargs),
                "last_checkpoint": save_path,
            },
            logger.path("info"),
            logger,
        )
        with torch.no_grad():
            logger.log("{:}".format(search_model.show_alphas()))
        if api is not None:
            logger.log("{:}".format(api.query_by_arch(genotypes[epoch],
                                                      "200")))
        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()

    # the final post procedure : count the time
    start_time = time.time()
    genotype, temp_accuracy = get_best_arch(valid_loader, network,
                                            xargs.select_num)
    search_time.update(time.time() - start_time)
    network.module.set_cal_mode("dynamic", genotype)
    valid_a_loss, valid_a_top1, valid_a_top5 = valid_func(
        valid_loader, network, criterion)
    logger.log(
        "Last : the genotype is : {:}, with the validation accuracy of {:.3f}%."
        .format(genotype, valid_a_top1))

    logger.log("\n" + "-" * 100)
    # check the performance from the architecture dataset
    logger.log(
        "SETN : run {:} epochs, cost {:.1f} s, last-geno is {:}.".format(
            total_epoch, search_time.sum, genotype))
    if api is not None:
        logger.log("{:}".format(api.query_by_arch(genotype, "200")))
    logger.close()
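Both searches above report benchmark numbers through `api.query_by_arch(genotype, '200')`, i.e. the statistics of the fully trained (200-epoch) entry for a genotype. A minimal stand-alone sketch of the same lookup (treat the file path and the `genotype` variable as assumptions):

import os
from nas_201_api import NASBench201API as API

api = API('{:}/.torch/NAS-Bench-201-v1_0-e61699.pth'.format(os.environ['HOME']))
index = api.query_index_by_arch(genotype)  # `genotype` comes from a finished search
api.show(index)                            # detailed per-dataset results
print(api.query_by_arch(genotype, '200'))  # compact summary string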
Example #24
def main(xargs):
    assert torch.cuda.is_available(), 'CUDA is not available.'
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.set_num_threads(xargs.workers)
    prepare_seed(xargs.rand_seed)
    logger = prepare_logger(xargs)

    train_data, valid_data, xshape, class_num = get_datasets(
        xargs.dataset, xargs.data_path, -1)
    config = load_config(xargs.config_path, {
        'class_num': class_num,
        'xshape': xshape
    }, logger)
    search_loader, _, valid_loader = get_nas_search_loaders(
        train_data, valid_data, xargs.dataset, 'configs/nas-benchmark/',
        (config.batch_size, config.test_batch_size), xargs.workers)
    logger.log(
        '||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'
        .format(xargs.dataset, len(search_loader), len(valid_loader),
                config.batch_size))
    logger.log('||||||| {:10s} ||||||| Config={:}'.format(
        xargs.dataset, config))

    search_space = get_search_spaces('cell', xargs.search_space_name)
    model_config = dict2config(
        {
            'name': 'RANDOM',
            'C': xargs.channel,
            'N': xargs.num_cells,
            'max_nodes': xargs.max_nodes,
            'num_classes': class_num,
            'space': search_space,
            'affine': False,
            'track_running_stats': bool(xargs.track_running_stats)
        }, None)
    search_model = get_cell_based_tiny_net(model_config)

    w_optimizer, w_scheduler, criterion = get_optim_scheduler(
        search_model.parameters(), config)
    logger.log('w-optimizer : {:}'.format(w_optimizer))
    logger.log('w-scheduler : {:}'.format(w_scheduler))
    logger.log('criterion   : {:}'.format(criterion))
    if xargs.arch_nas_dataset is None:
        api = None
    else:
        api = API(xargs.arch_nas_dataset)
    logger.log('{:} create API = {:} done'.format(time_string(), api))

    last_info, model_base_path, model_best_path = logger.path(
        'info'), logger.path('model'), logger.path('best')
    network, criterion = torch.nn.DataParallel(
        search_model).cuda(), criterion.cuda()

    if last_info.exists():  # automatically resume from previous checkpoint
        logger.log("=> loading checkpoint of the last-info '{:}' start".format(
            last_info))
        last_info = torch.load(last_info)
        start_epoch = last_info['epoch']
        checkpoint = torch.load(last_info['last_checkpoint'])
        genotypes = checkpoint['genotypes']
        valid_accuracies = checkpoint['valid_accuracies']
        search_model.load_state_dict(checkpoint['search_model'])
        w_scheduler.load_state_dict(checkpoint['w_scheduler'])
        w_optimizer.load_state_dict(checkpoint['w_optimizer'])
        logger.log(
            "=> loading checkpoint of the last-info '{:}' start with {:}-th epoch."
            .format(last_info, start_epoch))
    else:
        logger.log("=> do not find the last-info file : {:}".format(last_info))
        start_epoch, valid_accuracies, genotypes = 0, {'best': -1}, {}

    # start training
    start_time, search_time, epoch_time, total_epoch = time.time(
    ), AverageMeter(), AverageMeter(), config.epochs + config.warmup
    for epoch in range(start_epoch, total_epoch):
        w_scheduler.update(epoch, 0.0)
        need_time = 'Time Left: {:}'.format(
            convert_secs2time(epoch_time.val * (total_epoch - epoch), True))
        epoch_str = '{:03d}-{:03d}'.format(epoch, total_epoch)
        logger.log('\n[Search the {:}-th epoch] {:}, LR={:}'.format(
            epoch_str, need_time, min(w_scheduler.get_lr())))

        # selected_arch = search_find_best(valid_loader, network, criterion, xargs.select_num)
        search_w_loss, search_w_top1, search_w_top5 = search_func(
            search_loader, network, criterion, w_scheduler, w_optimizer,
            epoch_str, xargs.print_freq, logger)
        search_time.update(time.time() - start_time)
        logger.log(
            '[{:}] searching : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s'
            .format(epoch_str, search_w_loss, search_w_top1, search_w_top5,
                    search_time.sum))
        valid_a_loss, valid_a_top1, valid_a_top5 = valid_func(
            valid_loader, network, criterion)
        logger.log(
            '[{:}] evaluate  : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'
            .format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5))
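        # RANDOM-NAS: draw xargs.select_num random architectures and keep the
        # one with the highest validation accuracy.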
        cur_arch, cur_valid_acc = search_find_best(valid_loader, network,
                                                   xargs.select_num)
        logger.log('[{:}] find-the-best : {:}, accuracy@1={:.2f}%'.format(
            epoch_str, cur_arch, cur_valid_acc))
        genotypes[epoch] = cur_arch
        # check the best accuracy
        valid_accuracies[epoch] = valid_a_top1
        if valid_a_top1 > valid_accuracies['best']:
            valid_accuracies['best'] = valid_a_top1
            find_best = True
        else:
            find_best = False

        # save checkpoint
        save_path = save_checkpoint(
            {
                'epoch': epoch + 1,
                'args': deepcopy(xargs),
                'search_model': search_model.state_dict(),
                'w_optimizer': w_optimizer.state_dict(),
                'w_scheduler': w_scheduler.state_dict(),
                'genotypes': genotypes,
                'valid_accuracies': valid_accuracies
            }, model_base_path, logger)
        last_info = save_checkpoint(
            {
                'epoch': epoch + 1,
                'args': deepcopy(xargs),
                'last_checkpoint': save_path,
            }, logger.path('info'), logger)
        if find_best:
            logger.log(
                '<<<--->>> The {:}-th epoch : find the highest validation accuracy : {:.2f}%.'
                .format(epoch_str, valid_a_top1))
            copy_checkpoint(model_base_path, model_best_path, logger)
        if api is not None:
            logger.log('{:}'.format(api.query_by_arch(genotypes[epoch])))
        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()

    logger.log('\n' + '-' * 200)
    logger.log('Pre-searching costs {:.1f} s'.format(search_time.sum))
    start_time = time.time()
    best_arch, best_acc = search_find_best(valid_loader, network,
                                           xargs.select_num)
    search_time.update(time.time() - start_time)
    logger.log(
        'RANDOM-NAS finds the best one : {:} with accuracy={:.2f}%, with {:.1f} s.'
        .format(best_arch, best_acc, search_time.sum))
    if api is not None:
        logger.log('{:}'.format(api.query_by_arch(best_arch)))
    logger.close()
################################################################################################
import sys, argparse
from pathlib import Path
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
from nas_201_api  import NASBench201API as API

if __name__ == '__main__':
  parser = argparse.ArgumentParser("Analysis of NAS-Bench-201")
  parser.add_argument('--api_path',  type=str, default=None,                                         help='The path to the NAS-Bench-201 benchmark file.')
  args = parser.parse_args()

  meta_file = Path(args.api_path)
  assert meta_file.exists(), 'invalid path for api : {:}'.format(meta_file)

  api = API(str(meta_file))

  # This will show the results of the best architecture based on the validation set of each dataset.
  arch_index, accuracy = api.find_best('cifar10-valid', 'x-valid', None, None, False)
  print('FOR CIFAR-010, using the hyper-parameters with 200 training epochs :::')
  print('arch-index={:5d}, arch={:}'.format(arch_index, api.arch(arch_index)))
  api.show(arch_index)
  print('')

  arch_index, accuracy = api.find_best('cifar100', 'x-valid', None, None, False)
  print('FOR CIFAR-100, using the hyper-parameters with 200 training epochs :::')
  print('arch-index={:5d}, arch={:}'.format(arch_index, api.arch(arch_index)))
  api.show(arch_index)
  print('')

  arch_index, accuracy = api.find_best('ImageNet16-120', 'x-valid', None, None, False)
Example #26
import os, sys
from collections import defaultdict, OrderedDict
from pathlib import Path
from typing import Dict, Any, Text, List

lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve()
if str(lib_dir) not in sys.path:
    sys.path.insert(0, str(lib_dir))
from log_utils import AverageMeter, time_string, convert_secs2time
from config_utils import load_config, dict2config
from datasets import get_datasets
from models import CellStructure, get_cell_based_tiny_net, get_search_spaces
from nats_bench import pickle_save, pickle_load, ArchResults, ResultsCount
from procedures import bench_pure_evaluate as pure_evaluate, get_nas_bench_loaders
from utils import get_md5_file
from nas_201_api import NASBench201API

api = NASBench201API("{:}/.torch/NAS-Bench-201-v1_0-e61699.pth".format(
    os.environ["HOME"]))
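# NOTE: loading the full benchmark file is slow and memory-hungry (it stores
# results for all 15,625 architectures), so create the API object once and
# reuse it.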

NATS_TSS_BASE_NAME = "NATS-tss-v1_0"  # 2020.08.28


def create_result_count(
    used_seed: int,
    dataset: Text,
    arch_config: Dict[Text, Any],
    results: Dict[Text, Any],
    dataloader_dict: Dict[Text, Any],
) -> ResultsCount:
    xresult = ResultsCount(
        dataset,
        results["net_state_dict"],
        results["train_acc1es"],
Example #27
import sys, argparse
from pathlib import Path

lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
from nas_201_api import NASBench201API as API

if __name__ == '__main__':
    parser = argparse.ArgumentParser("Analysis of NAS-Bench-201")
    parser.add_argument('--api_path',
                        type=str,
                        default=None,
                        help='The path to the NAS-Bench-201 benchmark file.')
    args = parser.parse_args()

    meta_file = Path(args.api_path)
    assert meta_file.exists(), 'invalid path for api : {:}'.format(meta_file)

    api = API(str(meta_file))

    # This will show the results of the best architecture based on the validation set of each dataset.
    arch_index, accuracy = api.find_best('cifar10-valid', 'x-valid', None,
                                         None, False)
    print(
        'FOR CIFAR-010, using the hyper-parameters with 200 training epochs :::'
    )
    print('arch-index={:5d}, arch={:}'.format(arch_index,
                                              api.arch(arch_index)))
    api.show(arch_index)
    print('')

    arch_index, accuracy = api.find_best('cifar100', 'x-valid', None, None,
                                         False)
    print(
Example #28
def main(xargs):
    assert torch.cuda.is_available(), 'CUDA is not available.'
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.set_num_threads(xargs.workers)
    prepare_seed(xargs.rand_seed)
    logger = prepare_logger(xargs)

    train_data, valid_data, xshape, class_num = get_datasets(
        xargs.dataset, xargs.data_path, -1)
    #config_path = 'configs/nas-benchmark/algos/GDAS.config'
    config = load_config(xargs.config_path, {
        'class_num': class_num,
        'xshape': xshape
    }, logger)
    search_loader, _, valid_loader = get_nas_search_loaders(
        train_data, valid_data, xargs.dataset, 'configs/nas-benchmark/',
        config.batch_size, xargs.workers)
    logger.log(
        '||||||| {:10s} ||||||| Search-Loader-Num={:}, batch size={:}'.format(
            xargs.dataset, len(search_loader), config.batch_size))
    logger.log('||||||| {:10s} ||||||| Config={:}'.format(
        xargs.dataset, config))

    search_space = get_search_spaces('cell', xargs.search_space_name)
    if xargs.model_config is None and not xargs.constrain:
        model_config = dict2config(
            {
                'name': 'GDAS',
                'C': xargs.channel,
                'N': xargs.num_cells,
                'max_nodes': xargs.max_nodes,
                'num_classes': class_num,
                'space': search_space,
                'inp_size': 0,
                'affine': False,
                'track_running_stats': bool(xargs.track_running_stats)
            }, None)
    elif xargs.model_config is None:
        model_config = dict2config(
            {
                'name': 'GDAS',
                'C': xargs.channel,
                'N': xargs.num_cells,
                'max_nodes': xargs.max_nodes,
                'num_classes': class_num,
                'space': search_space,
                'inp_size': 32,
                'affine': False,
                'track_running_stats': bool(xargs.track_running_stats)
            }, None)
    else:
        model_config = load_config(
            xargs.model_config, {
                'num_classes': class_num,
                'space': search_space,
                'affine': False,
                'track_running_stats': bool(xargs.track_running_stats)
            }, None)
    search_model = get_cell_based_tiny_net(model_config)
    #logger.log('search-model :\n{:}'.format(search_model))
    logger.log('model-config : {:}'.format(model_config))

    w_optimizer, w_scheduler, criterion = get_optim_scheduler(
        search_model.get_weights(), config)
    a_optimizer = torch.optim.Adam(search_model.get_alphas(),
                                   lr=xargs.arch_learning_rate,
                                   betas=(0.5, 0.999),
                                   weight_decay=xargs.arch_weight_decay)
    logger.log('w-optimizer : {:}'.format(w_optimizer))
    logger.log('a-optimizer : {:}'.format(a_optimizer))
    logger.log('w-scheduler : {:}'.format(w_scheduler))
    logger.log('criterion   : {:}'.format(criterion))
    flop, param = get_model_infos(search_model, xshape)
    #logger.log('{:}'.format(search_model))
    logger.log('FLOP = {:.2f} M, Params = {:.2f} MB'.format(flop, param))
    logger.log('search-space [{:} ops] : {:}'.format(len(search_space),
                                                     search_space))
    if xargs.arch_nas_dataset is None:
        api = None
    else:
        api = API(xargs.arch_nas_dataset)
    logger.log('{:} create API = {:} done'.format(time_string(), api))

    last_info, model_base_path, model_best_path = logger.path(
        'info'), logger.path('model'), logger.path('best')
    network, criterion = torch.nn.DataParallel(
        search_model).cuda(), criterion.cuda()
    #network, criterion = search_model.cuda(), criterion.cuda()

    if last_info.exists():  # automatically resume from previous checkpoint
        logger.log("=> loading checkpoint of the last-info '{:}' start".format(
            last_info))
        last_info = torch.load(last_info)
        start_epoch = last_info['epoch']
        checkpoint = torch.load(last_info['last_checkpoint'])
        genotypes = checkpoint['genotypes']
        valid_accuracies = checkpoint['valid_accuracies']
        search_model.load_state_dict(checkpoint['search_model'])
        w_scheduler.load_state_dict(checkpoint['w_scheduler'])
        w_optimizer.load_state_dict(checkpoint['w_optimizer'])
        a_optimizer.load_state_dict(checkpoint['a_optimizer'])
        logger.log(
            "=> loading checkpoint of the last-info '{:}' start with {:}-th epoch."
            .format(last_info, start_epoch))
    else:
        logger.log("=> do not find the last-info file : {:}".format(last_info))
        start_epoch, valid_accuracies, genotypes = 0, {
            'best': -1
        }, {
            -1: search_model.genotype()
        }

    # start training
    start_time, search_time, epoch_time, total_epoch = time.time(
    ), AverageMeter(), AverageMeter(), config.epochs + config.warmup
    sampled_weights = []
    for epoch in range(start_epoch, total_epoch + config.t_epochs):
        w_scheduler.update(epoch, 0.0)
        need_time = 'Time Left: {:}'.format(
            convert_secs2time(
                epoch_time.val * (total_epoch - epoch + config.t_epochs),
                True))
        epoch_str = '{:03d}-{:03d}'.format(epoch, total_epoch)
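        # GDAS: linearly anneal the Gumbel-softmax temperature from tau_max
        # down to tau_min over the search epochs.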
        search_model.set_tau(xargs.tau_max -
                             (xargs.tau_max - xargs.tau_min) * epoch /
                             (total_epoch - 1))
        logger.log('\n[Search the {:}-th epoch] {:}, tau={:}, LR={:}'.format(
            epoch_str, need_time, search_model.get_tau(),
            min(w_scheduler.get_lr())))
        if epoch < total_epoch:
            (search_w_loss, search_w_top1, search_w_top5, valid_a_loss,
             valid_a_top1, valid_a_top5) = search_func(
                 search_loader, network, criterion, w_scheduler, w_optimizer,
                 a_optimizer, epoch_str, xargs.print_freq, logger,
                 xargs.bilevel)
        else:
            (search_w_loss, search_w_top1, search_w_top5, valid_a_loss,
             valid_a_top1, valid_a_top5, arch_iter) = train_func(
                 search_loader, network, criterion, w_scheduler, w_optimizer,
                 epoch_str, xargs.print_freq, sampled_weights[0], arch_iter,
                 logger)
        search_time.update(time.time() - start_time)
        logger.log(
            '[{:}] searching : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s'
            .format(epoch_str, search_w_loss, search_w_top1, search_w_top5,
                    search_time.sum))
        logger.log(
            '[{:}] evaluate  : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'
            .format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5))
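        # Every 50 epochs (or, when config.t_epochs is set, once at the end of
        # the regular search phase), snapshot 100 sampled architectures for
        # later evaluation.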

        if (epoch + 1) % 50 == 0 and not config.t_epochs:
            weights = search_model.sample_weights(100)
            sampled_weights.append(weights)
        elif (epoch + 1) == total_epoch and config.t_epochs:
            weights = search_model.sample_weights(100)
            sampled_weights.append(weights)
            arch_iter = iter(weights)
        # validate with single arch
        single_weight = search_model.sample_weights(1)[0]
        single_valid_acc = AverageMeter()
        network.eval()
        for i in range(10):
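            # valid_iter is built lazily: the first next() call (or an
            # exhausted iterator) raises, and the except branch rebuilds it.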
            try:
                val_input, val_target = next(valid_iter)
            except Exception as e:
                valid_iter = iter(valid_loader)
                val_input, val_target = next(valid_iter)
            n_val = val_input.size(0)
            with torch.no_grad():
                val_target = val_target.cuda(non_blocking=True)
                _, logits, _ = network(val_input, weights=single_weight)
                val_acc1, val_acc5 = obtain_accuracy(logits.data,
                                                     val_target.data,
                                                     topk=(1, 5))
                single_valid_acc.update(val_acc1.item(), n_val)
        logger.log('[{:}] valid : accuracy = {:.2f}'.format(
            epoch_str, single_valid_acc.avg))

        # check the best accuracy
        valid_accuracies[epoch] = valid_a_top1
        if valid_a_top1 > valid_accuracies['best']:
            valid_accuracies['best'] = valid_a_top1
            genotypes['best'] = search_model.genotype()
            find_best = True
        else:
            find_best = False

        if epoch < total_epoch:
            genotypes[epoch] = search_model.genotype()
            logger.log('<<<--->>> The {:}-th epoch : {:}'.format(
                epoch_str, genotypes[epoch]))
        # save checkpoint
        save_path = save_checkpoint(
            {
                'epoch': epoch + 1,
                'args': deepcopy(xargs),
                'search_model': search_model.state_dict(),
                'w_optimizer': w_optimizer.state_dict(),
                'a_optimizer': a_optimizer.state_dict(),
                'w_scheduler': w_scheduler.state_dict(),
                'genotypes': genotypes,
                'valid_accuracies': valid_accuracies
            }, model_base_path, logger)
        last_info = save_checkpoint(
            {
                'epoch': epoch + 1,
                'args': deepcopy(xargs),
                'last_checkpoint': save_path,
            }, logger.path('info'), logger)
        if find_best:
            logger.log(
                '<<<--->>> The {:}-th epoch : find the highest validation accuracy : {:.2f}%.'
                .format(epoch_str, valid_a_top1))
            copy_checkpoint(model_base_path, model_best_path, logger)
        with torch.no_grad():
            logger.log('{:}'.format(search_model.show_alphas()))
        if api is not None and epoch < total_epoch:
            logger.log('{:}'.format(api.query_by_arch(genotypes[epoch])))

        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()

    network.eval()
    # Evaluate the architectures sampled throughout the search
    for i in range(len(sampled_weights) - 1):
        logger.log('Sample eval : epoch {}'.format((i + 1) * 50 - 1))
        for w in sampled_weights[i]:
            sample_valid_acc = AverageMeter()
            for _ in range(10):
                try:
                    val_input, val_target = next(valid_iter)
                except Exception as e:
                    valid_iter = iter(valid_loader)
                    val_input, val_target = next(valid_iter)
                n_val = val_input.size(0)
                with torch.no_grad():
                    val_target = val_target.cuda(non_blocking=True)
                    _, logits, _ = network(val_input, weights=w)
                    val_acc1, val_acc5 = obtain_accuracy(logits.data,
                                                         val_target.data,
                                                         topk=(1, 5))
                    sample_valid_acc.update(val_acc1.item(), n_val)
            w_gene = search_model.genotype(w)
            if api is not None:
                ind = api.query_index_by_arch(w_gene)
                info = api.query_meta_info_by_index(ind)
                metrics = info.get_metrics('cifar10', 'ori-test')
                acc = metrics['accuracy']
            else:
                acc = 0.0
            logger.log(
                'sample valid : val_acc = {:.2f} test_acc = {:.2f}'.format(
                    sample_valid_acc.avg, acc))
    # Evaluate the final sampling separately to find the top 10 architectures
    logger.log('Final sample eval')
    final_archs = []
    for w in sampled_weights[-1]:
        sample_valid_acc = AverageMeter()
        for i in range(10):
            try:
                val_input, val_target = next(valid_iter)
            except Exception as e:
                valid_iter = iter(valid_loader)
                val_input, val_target = next(valid_iter)
            n_val = val_input.size(0)
            with torch.no_grad():
                val_target = val_target.cuda(non_blocking=True)
                _, logits, _ = network(val_input, weights=w)
                val_acc1, val_acc5 = obtain_accuracy(logits.data,
                                                     val_target.data,
                                                     topk=(1, 5))
                sample_valid_acc.update(val_acc1.item(), n_val)
        w_gene = search_model.genotype(w)
        if api is not None:
            ind = api.query_index_by_arch(w_gene)
            info = api.query_meta_info_by_index(ind)
            metrics = info.get_metrics('cifar10', 'ori-test')
            acc = metrics['accuracy']
        else:
            acc = 0.0
        logger.log('sample valid : val_acc = {:.2f} test_acc = {:.2f}'.format(
            sample_valid_acc.avg, acc))
        final_archs.append((w, sample_valid_acc.avg))
    top_10 = sorted(final_archs, key=lambda x: x[1], reverse=True)[:10]
    # Evaluate the top 10 architectures on the entire validation set
    logger.log('Evaluating top archs')
    for w, prev_acc in top_10:
        full_valid_acc = AverageMeter()
        for val_input, val_target in valid_loader:
            n_val = val_input.size(0)
            with torch.no_grad():
                val_target = val_target.cuda(non_blocking=True)
                _, logits, _ = network(val_input, weights=w)
                val_acc1, val_acc5 = obtain_accuracy(logits.data,
                                                     val_target.data,
                                                     topk=(1, 5))
                full_valid_acc.update(val_acc1.item(), n_val)
        w_gene = search_model.genotype(w)
        logger.log('genotype {}'.format(w_gene))
        if api is not None:
            ind = api.query_index_by_arch(w_gene)
            info = api.query_meta_info_by_index(ind)
            metrics = info.get_metrics('cifar10', 'ori-test')
            acc = metrics['accuracy']
        else:
            acc = 0.0
        logger.log(
            'full valid : val_acc = {:.2f} test_acc = {:.2f} pval_acc = {:.2f}'
            .format(full_valid_acc.avg, acc, prev_acc))

    logger.log('\n' + '-' * 100)
    # check the performance from the architecture dataset
    logger.log(
        'GDAS : run {:} epochs, cost {:.1f} s, last-geno is {:}.'.format(
            total_epoch, search_time.sum, genotypes[total_epoch - 1]))
    if api is not None:
        logger.log('{:}'.format(api.query_by_arch(genotypes[total_epoch - 1])))
    logger.close()
Example #29
def main(xargs):
    assert torch.cuda.is_available(), 'CUDA is not available.'
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.set_num_threads(xargs.workers)
    prepare_seed(xargs.rand_seed)
    logger = prepare_logger(xargs)

    train_data, valid_data, xshape, class_num = get_datasets(
        xargs.dataset, xargs.data_path, -1)
    #config_path = 'configs/nas-benchmark/algos/GDAS.config'
    config = load_config(xargs.config_path, {
        'class_num': class_num,
        'xshape': xshape
    }, logger)
    search_loader, _, valid_loader = get_nas_search_loaders(
        train_data, valid_data, xargs.dataset, 'configs/nas-benchmark/',
        config.batch_size, xargs.workers)
    logger.log(
        '||||||| {:10s} ||||||| Search-Loader-Num={:}, batch size={:}'.format(
            xargs.dataset, len(search_loader), config.batch_size))
    logger.log('||||||| {:10s} ||||||| Config={:}'.format(
        xargs.dataset, config))

    search_space = get_search_spaces('cell', xargs.search_space_name)
    if xargs.model_config is None:
        model_config = dict2config(
            {
                'name': 'GDAS',
                'C': xargs.channel,
                'N': xargs.num_cells,
                'max_nodes': xargs.max_nodes,
                'num_classes': class_num,
                'space': search_space,
                'affine': False,
                'track_running_stats': bool(xargs.track_running_stats)
            }, None)
    else:
        model_config = load_config(
            xargs.model_config, {
                'num_classes': class_num,
                'space': search_space,
                'affine': False,
                'track_running_stats': bool(xargs.track_running_stats)
            }, None)
    search_model = get_cell_based_tiny_net(model_config)
    logger.log('search-model :\n{:}'.format(search_model))
    logger.log('model-config : {:}'.format(model_config))

    w_optimizer, w_scheduler, criterion = get_optim_scheduler(
        search_model.get_weights(), config)
    a_optimizer = torch.optim.Adam(search_model.get_alphas(),
                                   lr=xargs.arch_learning_rate,
                                   betas=(0.5, 0.999),
                                   weight_decay=xargs.arch_weight_decay)
    logger.log('w-optimizer : {:}'.format(w_optimizer))
    logger.log('a-optimizer : {:}'.format(a_optimizer))
    logger.log('w-scheduler : {:}'.format(w_scheduler))
    logger.log('criterion   : {:}'.format(criterion))
    flop, param = get_model_infos(search_model, xshape)
    logger.log('FLOP = {:.2f} M, Params = {:.2f} MB'.format(flop, param))
    logger.log('search-space [{:} ops] : {:}'.format(len(search_space),
                                                     search_space))
    if xargs.arch_nas_dataset is None:
        api = None
    else:
        api = API(xargs.arch_nas_dataset)
    logger.log('{:} create API = {:} done'.format(time_string(), api))

    last_info, model_base_path, model_best_path = logger.path(
        'info'), logger.path('model'), logger.path('best')
    network, criterion = torch.nn.DataParallel(
        search_model).cuda(), criterion.cuda()

    if last_info.exists():  # automatically resume from previous checkpoint
        logger.log("=> loading checkpoint of the last-info '{:}' start".format(
            last_info))
        last_info = torch.load(last_info)
        start_epoch = last_info['epoch']
        checkpoint = torch.load(last_info['last_checkpoint'])
        genotypes = checkpoint['genotypes']
        valid_accuracies = checkpoint['valid_accuracies']
        search_model.load_state_dict(checkpoint['search_model'])
        w_scheduler.load_state_dict(checkpoint['w_scheduler'])
        w_optimizer.load_state_dict(checkpoint['w_optimizer'])
        a_optimizer.load_state_dict(checkpoint['a_optimizer'])
        logger.log(
            "=> loading checkpoint of the last-info '{:}' start with {:}-th epoch."
            .format(last_info, start_epoch))
    else:
        logger.log("=> do not find the last-info file : {:}".format(last_info))
        start_epoch, valid_accuracies, genotypes = 0, {
            'best': -1
        }, {
            -1: search_model.genotype()
        }

    # start training
    start_time, search_time, epoch_time, total_epoch = time.time(
    ), AverageMeter(), AverageMeter(), config.epochs + config.warmup
    for epoch in range(start_epoch, total_epoch):
        w_scheduler.update(epoch, 0.0)
        need_time = 'Time Left: {:}'.format(
            convert_secs2time(epoch_time.val * (total_epoch - epoch), True))
        epoch_str = '{:03d}-{:03d}'.format(epoch, total_epoch)
        search_model.set_tau(xargs.tau_max -
                             (xargs.tau_max - xargs.tau_min) * epoch /
                             (total_epoch - 1))
        logger.log('\n[Search the {:}-th epoch] {:}, tau={:}, LR={:}'.format(
            epoch_str, need_time, search_model.get_tau(),
            min(w_scheduler.get_lr())))

        (search_w_loss, search_w_top1, search_w_top5, valid_a_loss,
         valid_a_top1, valid_a_top5) = search_func(
             search_loader, network, criterion, w_scheduler, w_optimizer,
             a_optimizer, epoch_str, xargs.print_freq, logger)
        search_time.update(time.time() - start_time)
        logger.log(
            '[{:}] searching : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s'
            .format(epoch_str, search_w_loss, search_w_top1, search_w_top5,
                    search_time.sum))
        logger.log(
            '[{:}] evaluate  : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'
            .format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5))
        # check the best accuracy
        valid_accuracies[epoch] = valid_a_top1
        if valid_a_top1 > valid_accuracies['best']:
            valid_accuracies['best'] = valid_a_top1
            genotypes['best'] = search_model.genotype()
            find_best = True
        else:
            find_best = False

        genotypes[epoch] = search_model.genotype()
        logger.log('<<<--->>> The {:}-th epoch : {:}'.format(
            epoch_str, genotypes[epoch]))
        # save checkpoint
        save_path = save_checkpoint(
            {
                'epoch': epoch + 1,
                'args': deepcopy(xargs),
                'search_model': search_model.state_dict(),
                'w_optimizer': w_optimizer.state_dict(),
                'a_optimizer': a_optimizer.state_dict(),
                'w_scheduler': w_scheduler.state_dict(),
                'genotypes': genotypes,
                'valid_accuracies': valid_accuracies
            }, model_base_path, logger)
        last_info = save_checkpoint(
            {
                'epoch': epoch + 1,
                'args': deepcopy(xargs),
                'last_checkpoint': save_path,
            }, logger.path('info'), logger)
        if find_best:
            logger.log(
                '<<<--->>> The {:}-th epoch : find the highest validation accuracy : {:.2f}%.'
                .format(epoch_str, valid_a_top1))
            copy_checkpoint(model_base_path, model_best_path, logger)
        with torch.no_grad():
            logger.log('{:}'.format(search_model.show_alphas()))
        if api is not None:
            logger.log('{:}'.format(api.query_by_arch(genotypes[epoch],
                                                      '200')))
        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()

    logger.log('\n' + '-' * 100)
    # check the performance from the architecture dataset
    logger.log(
        'GDAS : run {:} epochs, cost {:.1f} s, last-geno is {:}.'.format(
            total_epoch, search_time.sum, genotypes[total_epoch - 1]))
    if api is not None:
        logger.log('{:}'.format(
            api.query_by_arch(genotypes[total_epoch - 1], '200')))
    logger.close()
Example #30
import random, logging
import numpy as np
import torch
import torch.backends.cudnn as cudnn
from datasets import get_datasets
from nas_201_api import NASBench201API as API

# `args` is assumed to be an argparse namespace parsed earlier in the script.
torch.manual_seed(args.seed)
torch.cuda.manual_seed(args.seed)
torch.cuda.manual_seed_all(args.seed)
np.random.seed(args.seed)
random.seed(args.seed)

device = torch.device("cuda:{}".format(args.gpu))
cpu_device = torch.device("cpu")

torch.cuda.set_device(args.gpu)
cudnn.deterministic = True
cudnn.enabled = True
cudnn.benchmark = False

assert args.api_path is not None, 'NAS201 data path has not been provided'
api = API(args.api_path, verbose=False)
logging.info(f'length of api: {len(api)}')

# Configuring dataset and dataloader
if args.dataset == 'cifar10':
  acc_type     = 'ori-test'
  val_acc_type = 'x-valid'
else:
  acc_type     = 'x-test'
  val_acc_type = 'x-valid'
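# NAS-Bench-201 convention: CIFAR-10 reports on its original test set
# ('ori-test'), while CIFAR-100 and ImageNet16-120 split their evaluation
# data into 'x-valid' and 'x-test'.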

datasets = ['cifar10', 'cifar100', 'ImageNet16-120']
assert args.dataset in datasets, 'Incorrect dataset'
if args.cutout:
  train_data, valid_data, xshape, num_classes = get_datasets(
      name=args.dataset, root=args.data, cutout=args.cutout)
else: