Exemplo n.º 1
0
def check_cor_for_bandit(meta_file,
                         test_epoch,
                         use_less_or_not,
                         is_rand=True,
                         need_print=False):
    """Correlate an early CIFAR-10 validation signal with reference accuracies.

    For every architecture in the benchmark, the 'cifar10-valid' validation
    accuracy obtained with epoch argument ``test_epoch - 1`` is compared via
    the Pearson correlation coefficient (``np.corrcoef``) against six
    reference metrics.  Each reference is queried with ``None`` as the epoch
    and ``False`` as the use-less flag:
    CIFAR-10 valid/test, CIFAR-100 valid/test and ImageNet16-120 valid/test.

    Args:
        meta_file: an ``API`` instance, or anything ``str()``-convertible that
            ``API`` can be built from.
        test_epoch: epoch number of the early signal; the API is queried with
            ``test_epoch - 1``.
        use_less_or_not: forwarded to the early-signal query; shown in the
            printout as '012' when true and '200' when false.
        is_rand: forwarded to every ``api.get_more_info`` call.
        need_print: when true, print each correlation as it is computed.

    Returns:
        list of six floats ordered C-010-V, C-010-T, C-100-V, C-100-T,
        I16-V, I16-T.
    """
    api = meta_file if isinstance(meta_file, API) else API(str(meta_file))

    early_signal = []
    c10_valid, c10_test = [], []
    c100_valid, c100_test = [], []
    i16_valid, i16_test = [], []
    for arch_index, _ in enumerate(api):
        early = api.get_more_info(arch_index, 'cifar10-valid', test_epoch - 1,
                                  use_less_or_not, is_rand)
        early_signal.append(early['valid-accuracy'])
        # Reference metrics: epoch=None, use-less flag=False.
        ref_c10v = api.get_more_info(arch_index, 'cifar10-valid', None, False,
                                     is_rand)
        c10_valid.append(ref_c10v['valid-accuracy'])
        ref_c10 = api.get_more_info(arch_index, 'cifar10', None, False, is_rand)
        c10_test.append(ref_c10['test-accuracy'])
        ref_c100 = api.get_more_info(arch_index, 'cifar100', None, False,
                                     is_rand)
        c100_test.append(ref_c100['test-accuracy'])
        c100_valid.append(ref_c100['valid-accuracy'])
        ref_i16 = api.get_more_info(arch_index, 'ImageNet16-120', None, False,
                                    is_rand)
        i16_test.append(ref_i16['test-accuracy'])
        i16_valid.append(ref_i16['valid-accuracy'])

    labels = ('C-010-V', 'C-010-T', 'C-100-V', 'C-100-T', 'I16-V', 'I16-T')
    references = (c10_valid, c10_test, c100_valid, c100_test, i16_valid,
                  i16_test)
    schedule = '012' if use_less_or_not else '200'

    correlations = []
    for label, reference in zip(labels, references):
        value = float(np.corrcoef(early_signal, reference)[0, 1])
        if need_print:
            print(
                'With {:3d}/{:}-epochs-training, the correlation between cifar10-valid and {:} is : {:}'
                .format(test_epoch, schedule, label, value))
        correlations.append(value)
    return correlations
Exemplo n.º 2
0
def main(xargs):
  """Run the 'DARTS-V2' cell search and log the results.

  Builds the search/validation loaders and the 'DARTS-V2' search model from
  ``xargs``, optionally resumes from the last checkpoint recorded by the
  logger, then for ``config.epochs + config.warmup`` epochs runs
  ``search_func`` (weight + architecture updates) followed by ``valid_func``.
  A checkpoint is saved every epoch and copied to the "best" path whenever
  the validation top-1 accuracy improves.

  Args:
    xargs: parsed command-line namespace. Fields read here include
      ``workers``, ``rand_seed``, ``dataset``, ``data_path``, ``config_path``,
      ``search_space_name``, ``channel``, ``num_cells``, ``max_nodes``,
      ``track_running_stats``, ``arch_learning_rate``, ``arch_weight_decay``,
      ``arch_nas_dataset`` and ``print_freq``.

  Raises:
    AssertionError: if CUDA is not available.
  """
  assert torch.cuda.is_available(), 'CUDA is not available.'
  torch.backends.cudnn.enabled   = True
  torch.backends.cudnn.benchmark = False
  torch.backends.cudnn.deterministic = True
  torch.set_num_threads( xargs.workers )
  prepare_seed(xargs.rand_seed)
  # Use the argument, not a module-level ``args`` global, so main() also
  # works when imported and called with any namespace.
  logger = prepare_logger(xargs)

  train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
  config = load_config(xargs.config_path, {'class_num': class_num, 'xshape': xshape}, logger)
  search_loader, _, valid_loader = get_nas_search_loaders(train_data, valid_data, xargs.dataset, 'configs/nas-benchmark/', config.batch_size, xargs.workers)
  logger.log('||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(search_loader), len(valid_loader), config.batch_size))
  logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config))

  search_space = get_search_spaces('cell', xargs.search_space_name)
  model_config = dict2config({'name': 'DARTS-V2', 'C': xargs.channel, 'N': xargs.num_cells,
                              'max_nodes': xargs.max_nodes, 'num_classes': class_num,
                              'space'    : search_space,
                              'affine'   : False, 'track_running_stats': bool(xargs.track_running_stats)}, None)
  search_model = get_cell_based_tiny_net(model_config)
  logger.log('search-model :\n{:}'.format(search_model))

  w_optimizer, w_scheduler, criterion = get_optim_scheduler(search_model.get_weights(), config)
  a_optimizer = torch.optim.Adam(search_model.get_alphas(), lr=xargs.arch_learning_rate, betas=(0.5, 0.999), weight_decay=xargs.arch_weight_decay)
  logger.log('w-optimizer : {:}'.format(w_optimizer))
  logger.log('a-optimizer : {:}'.format(a_optimizer))
  logger.log('w-scheduler : {:}'.format(w_scheduler))
  logger.log('criterion   : {:}'.format(criterion))
  flop, param  = get_model_infos(search_model, xshape)
  logger.log('FLOP = {:.2f} M, Params = {:.2f} MB'.format(flop, param))
  if xargs.arch_nas_dataset is None:
    api = None
  else:
    api = API(xargs.arch_nas_dataset)
  logger.log('{:} create API = {:} done'.format(time_string(), api))

  last_info_path, model_base_path, model_best_path = logger.path('info'), logger.path('model'), logger.path('best')
  network, criterion = torch.nn.DataParallel(search_model).cuda(), criterion.cuda()

  if last_info_path.exists(): # automatically resume from previous checkpoint
    logger.log("=> loading checkpoint of the last-info '{:}' start".format(last_info_path))
    last_info   = torch.load(last_info_path)
    start_epoch = last_info['epoch']
    checkpoint  = torch.load(last_info['last_checkpoint'])
    genotypes   = checkpoint['genotypes']
    valid_accuracies = checkpoint['valid_accuracies']
    search_model.load_state_dict( checkpoint['search_model'] )
    w_scheduler.load_state_dict ( checkpoint['w_scheduler'] )
    w_optimizer.load_state_dict ( checkpoint['w_optimizer'] )
    a_optimizer.load_state_dict ( checkpoint['a_optimizer'] )
    # Log the file path, not the loaded dict.
    logger.log("=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format(last_info_path, start_epoch))
  else:
    logger.log("=> do not find the last-info file : {:}".format(last_info_path))
    start_epoch, valid_accuracies, genotypes = 0, {'best': -1}, {-1: search_model.genotype()}

  # start training
  start_time, search_time, epoch_time, total_epoch = time.time(), AverageMeter(), AverageMeter(), config.epochs + config.warmup
  for epoch in range(start_epoch, total_epoch):
    w_scheduler.update(epoch, 0.0)
    need_time = 'Time Left: {:}'.format( convert_secs2time(epoch_time.val * (total_epoch-epoch), True) )
    epoch_str = '{:03d}-{:03d}'.format(epoch, total_epoch)
    min_LR    = min(w_scheduler.get_lr())
    logger.log('\n[Search the {:}-th epoch] {:}, LR={:}'.format(epoch_str, need_time, min_LR))

    search_w_loss, search_w_top1, search_w_top5 = search_func(search_loader, network, criterion, w_scheduler, w_optimizer, a_optimizer, epoch_str, xargs.print_freq, logger)
    search_time.update(time.time() - start_time)
    logger.log('[{:}] searching : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s'.format(epoch_str, search_w_loss, search_w_top1, search_w_top5, search_time.sum))
    valid_a_loss , valid_a_top1 , valid_a_top5  = valid_func(valid_loader, network, criterion)
    logger.log('[{:}] evaluate  : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5))
    # check the best accuracy
    valid_accuracies[epoch] = valid_a_top1
    if valid_a_top1 > valid_accuracies['best']:
      valid_accuracies['best'] = valid_a_top1
      genotypes['best']        = search_model.genotype()
      find_best = True
    else: find_best = False

    genotypes[epoch] = search_model.genotype()
    logger.log('<<<--->>> The {:}-th epoch : {:}'.format(epoch_str, genotypes[epoch]))
    # save checkpoint
    save_path = save_checkpoint({'epoch' : epoch + 1,
                'args'  : deepcopy(xargs),
                'search_model': search_model.state_dict(),
                'w_optimizer' : w_optimizer.state_dict(),
                'a_optimizer' : a_optimizer.state_dict(),
                'w_scheduler' : w_scheduler.state_dict(),
                'genotypes'   : genotypes,
                'valid_accuracies' : valid_accuracies},
                model_base_path, logger)
    # The info file only points at the latest full checkpoint (used to resume).
    save_checkpoint({
          'epoch': epoch + 1,
          'args' : deepcopy(xargs),
          'last_checkpoint': save_path,
          }, last_info_path, logger)
    if find_best:
      logger.log('<<<--->>> The {:}-th epoch : find the highest validation accuracy : {:.2f}%.'.format(epoch_str, valid_a_top1))
      copy_checkpoint(model_base_path, model_best_path, logger)
    with torch.no_grad():
      logger.log('arch-parameters :\n{:}'.format( nn.functional.softmax(search_model.arch_parameters, dim=-1).cpu() ))
    if api is not None: logger.log('{:}'.format(api.query_by_arch( genotypes[epoch] )))
    # measure elapsed time
    epoch_time.update(time.time() - start_time)
    start_time = time.time()

  logger.log('\n' + '-'*100)
  # check the performance from the architecture dataset
  logger.log('DARTS-V2 : run {:} epochs, cost {:.1f} s, last-geno is {:}.'.format(total_epoch, search_time.sum, genotypes[total_epoch-1]))
  if api is not None: logger.log('{:}'.format( api.query_by_arch(genotypes[total_epoch-1]) ))
  logger.close()
Exemplo n.º 3
0
def main(xargs):
  """Run an ENAS cell search (shared CNN + controller) and log the results.

  Builds the loaders (the validation loader is switched to the training
  transform because the controller is trained on it), the 'ENAS' shared CNN
  and its controller, and optionally resumes from the last checkpoint
  recorded by the logger.  Each of the ``config.epochs + config.warmup``
  epochs does the following:
    1. train the shared CNN (``train_shared_cnn``);
    2. train the controller (``train_controller``);
    3. pick an architecture with ``get_best_arch``;
    4. evaluate it with ``valid_func``;
    5. save a checkpoint.
  Finally ``xargs.controller_num_samples`` architectures are sampled and the
  selected one is evaluated and reported.

  Args:
    xargs: parsed command-line namespace. Fields read here include
      ``workers``, ``rand_seed``, ``dataset``, ``data_path``, ``config_path``,
      ``search_space_name``, ``channel``, ``num_cells``, ``max_nodes``,
      ``track_running_stats``, ``arch_nas_dataset``, ``print_freq``,
      ``controller_train_steps``, ``controller_num_aggregate``,
      ``controller_entropy_weight``, ``controller_bl_dec`` and
      ``controller_num_samples``.

  Raises:
    AssertionError: if CUDA is not available.
  """
  assert torch.cuda.is_available(), 'CUDA is not available.'
  torch.backends.cudnn.enabled   = True
  torch.backends.cudnn.benchmark = False
  torch.backends.cudnn.deterministic = True
  torch.set_num_threads( xargs.workers )
  prepare_seed(xargs.rand_seed)
  # Use the argument, not a module-level ``args`` global, so main() also
  # works when imported and called with any namespace.
  logger = prepare_logger(xargs)

  train_data, test_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
  logger.log('use config from : {:}'.format(xargs.config_path))
  config = load_config(xargs.config_path, {'class_num': class_num, 'xshape': xshape}, logger)
  _, train_loader, valid_loader = get_nas_search_loaders(train_data, test_data, xargs.dataset, 'configs/nas-benchmark/', config.batch_size, xargs.workers)
  # since ENAS will train the controller on valid-loader, we need to use train transformation for valid-loader
  valid_loader.dataset.transform = deepcopy(train_loader.dataset.transform)
  if hasattr(valid_loader.dataset, 'transforms'):
    valid_loader.dataset.transforms = deepcopy(train_loader.dataset.transforms)
  # data loader
  logger.log('||||||| {:10s} ||||||| Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(train_loader), len(valid_loader), config.batch_size))
  logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config))

  search_space = get_sub_search_spaces('cell', xargs.search_space_name)
  logger.log('search_space={}'.format(search_space))
  model_config = dict2config({'name': 'ENAS', 'C': xargs.channel, 'N': xargs.num_cells,
                              'max_nodes': xargs.max_nodes, 'num_classes': class_num,
                              'space'    : search_space,
                              'affine'   : False, 'track_running_stats': bool(xargs.track_running_stats)}, None)
  shared_cnn = get_cell_based_tiny_net(model_config)
  controller = shared_cnn.create_controller()

  w_optimizer, w_scheduler, criterion = get_optim_scheduler(shared_cnn.parameters(), config)
  a_optimizer = torch.optim.Adam(controller.parameters(), lr=config.controller_lr, betas=config.controller_betas, eps=config.controller_eps)
  logger.log('w-optimizer : {:}'.format(w_optimizer))
  logger.log('a-optimizer : {:}'.format(a_optimizer))
  logger.log('w-scheduler : {:}'.format(w_scheduler))
  logger.log('criterion   : {:}'.format(criterion))
  # NOTE: FLOP/param counting (get_model_infos) is not done for the ENAS shared CNN.
  logger.log('search-space : {:}'.format(search_space))
  if xargs.arch_nas_dataset is None:
    api = None
  else:
    api = API(xargs.arch_nas_dataset)
  logger.log('{:} create API = {:} done'.format(time_string(), api))
  shared_cnn, controller, criterion = torch.nn.DataParallel(shared_cnn).cuda(), controller.cuda(), criterion.cuda()

  last_info_path, model_base_path, model_best_path = logger.path('info'), logger.path('model'), logger.path('best')

  if last_info_path.exists(): # automatically resume from previous checkpoint
    logger.log("=> loading checkpoint of the last-info '{:}' start".format(last_info_path))
    last_info   = torch.load(last_info_path)
    start_epoch = last_info['epoch']
    checkpoint  = torch.load(last_info['last_checkpoint'])
    genotypes   = checkpoint['genotypes']
    baseline    = checkpoint['baseline']
    valid_accuracies = checkpoint['valid_accuracies']
    # shared_cnn is already wrapped in DataParallel here, matching how its
    # state_dict is saved below.
    shared_cnn.load_state_dict( checkpoint['shared_cnn'] )
    controller.load_state_dict( checkpoint['controller'] )
    w_scheduler.load_state_dict ( checkpoint['w_scheduler'] )
    w_optimizer.load_state_dict ( checkpoint['w_optimizer'] )
    a_optimizer.load_state_dict ( checkpoint['a_optimizer'] )
    # Log the file path, not the loaded dict.
    logger.log("=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format(last_info_path, start_epoch))
  else:
    logger.log("=> do not find the last-info file : {:}".format(last_info_path))
    start_epoch, valid_accuracies, genotypes, baseline = 0, {'best': -1}, {}, None

  # start training
  start_time, search_time, epoch_time, total_epoch = time.time(), AverageMeter(), AverageMeter(), config.epochs + config.warmup
  for epoch in range(start_epoch, total_epoch):
    w_scheduler.update(epoch, 0.0)
    need_time = 'Time Left: {:}'.format( convert_secs2time(epoch_time.val * (total_epoch-epoch), True) )
    epoch_str = '{:03d}-{:03d}'.format(epoch, total_epoch)
    logger.log('\n[Search the {:}-th epoch] {:}, LR={:}, baseline={:}'.format(epoch_str, need_time, min(w_scheduler.get_lr()), baseline))

    cnn_loss, cnn_top1, cnn_top5 = train_shared_cnn(train_loader, shared_cnn, controller, criterion, w_scheduler, w_optimizer, epoch_str, xargs.print_freq, logger)
    logger.log('[{:}] shared-cnn : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, cnn_loss, cnn_top1, cnn_top5))
    ctl_config = dict2config({'baseline': baseline,
                              'ctl_train_steps': xargs.controller_train_steps, 'ctl_num_aggre': xargs.controller_num_aggregate,
                              'ctl_entropy_w': xargs.controller_entropy_weight,
                              'ctl_bl_dec'   : xargs.controller_bl_dec}, None)
    ctl_loss, ctl_acc, ctl_baseline, ctl_reward, baseline = train_controller(
        valid_loader, shared_cnn, controller, criterion, a_optimizer, ctl_config,
        epoch_str, xargs.print_freq, logger)
    search_time.update(time.time() - start_time)
    logger.log('[{:}] controller : loss={:.2f}, accuracy={:.2f}%, baseline={:.2f}, reward={:.2f}, current-baseline={:.4f}, time-cost={:.1f} s'.format(epoch_str, ctl_loss, ctl_acc, ctl_baseline, ctl_reward, baseline, search_time.sum))
    best_arch, _ = get_best_arch(controller, shared_cnn, valid_loader)
    shared_cnn.module.update_arch(best_arch)
    _, best_valid_acc, _ = valid_func(valid_loader, shared_cnn, criterion)

    genotypes[epoch] = best_arch
    # check the best accuracy
    valid_accuracies[epoch] = best_valid_acc
    if best_valid_acc > valid_accuracies['best']:
      valid_accuracies['best'] = best_valid_acc
      genotypes['best']        = best_arch
      find_best = True
    else: find_best = False

    logger.log('<<<--->>> The {:}-th epoch : {:}'.format(epoch_str, genotypes[epoch]))
    # save checkpoint
    save_path = save_checkpoint({'epoch' : epoch + 1,
                'args'  : deepcopy(xargs),
                'baseline'    : baseline,
                'shared_cnn'  : shared_cnn.state_dict(),
                'controller'  : controller.state_dict(),
                'w_optimizer' : w_optimizer.state_dict(),
                'a_optimizer' : a_optimizer.state_dict(),
                'w_scheduler' : w_scheduler.state_dict(),
                'genotypes'   : genotypes,
                'valid_accuracies' : valid_accuracies},
                model_base_path, logger)
    # The info file only points at the latest full checkpoint (used to resume).
    save_checkpoint({
          'epoch': epoch + 1,
          'args' : deepcopy(xargs),
          'last_checkpoint': save_path,
          }, last_info_path, logger)

    if api is not None: logger.log('{:}'.format(api.query_by_arch( genotypes[epoch] )))
    # measure elapsed time
    epoch_time.update(time.time() - start_time)
    start_time = time.time()

  logger.log('\n' + '-'*100)
  logger.log('During searching, the best architecture is {:}'.format(genotypes['best']))
  logger.log('Its accuracy is {:.2f}%'.format(valid_accuracies['best']))
  logger.log('Randomly select {:} architectures and select the best.'.format(xargs.controller_num_samples))
  start_time = time.time()
  final_arch, _ = get_best_arch(controller, shared_cnn, valid_loader, xargs.controller_num_samples)
  search_time.update(time.time() - start_time)
  shared_cnn.module.update_arch(final_arch)
  final_loss, final_top1, final_top5 = valid_func(valid_loader, shared_cnn, criterion)
  logger.log('The Selected Final Architecture : {:}'.format(final_arch))
  logger.log('Loss={:.3f}, Accuracy@1={:.2f}%, Accuracy@5={:.2f}%'.format(final_loss, final_top1, final_top5))
  logger.log('ENAS : run {:} epochs, cost {:.1f} s, last-geno is {:}.'.format(total_epoch, search_time.sum, final_arch))
  if api is not None: logger.log('{:}'.format( api.query_by_arch(final_arch) ))
  logger.close()
Exemplo n.º 4
0
def visualize_rank_over_time(meta_file, vis_save_dir):
    """Plot, per epoch, CIFAR-10 validation ranking against final test ranking.

    Per-epoch CIFAR-10 accuracies are read for every architecture in the
    benchmark:
      * train: ('cifar10', 'train');
      * validation: ('cifar10-valid', 'x-valid');
      * test: ('cifar10', 'ori-test').
    FLOPs and params are read as well.  Everything is cached in
    ``vis_save_dir / 'rank-over-time-cache-info.pth'``; later calls reuse
    that cache instead of loading ``meta_file``.

    For each of the 200 epochs one scatter plot is saved as
    ``time-XXX.pdf`` and ``time-XXX.png``.  The x-axis is each
    architecture's rank by epoch-199 test accuracy; the y-axis is its rank
    by validation accuracy at that epoch.

    Args:
        meta_file: benchmark file; ``str(meta_file)`` is passed to ``API``.
        vis_save_dir: ``pathlib.Path``-like output directory (created if
            missing).
    """
    print('\n' + '-' * 150)
    vis_save_dir.mkdir(parents=True, exist_ok=True)
    print('{:} start to visualize rank-over-time into {:}'.format(
        time_string(), vis_save_dir))
    cache_file_path = vis_save_dir / 'rank-over-time-cache-info.pth'
    if not cache_file_path.exists():
        print('Do not find cache file : {:}'.format(cache_file_path))
        nas_bench = API(str(meta_file))
        print('{:} load nas_bench done'.format(time_string()))
        params, flops = [], []
        train_accs, valid_accs = defaultdict(list), defaultdict(list)
        test_accs, otest_accs = defaultdict(list), defaultdict(list)
        for index in tqdm(range(len(nas_bench))):
            info = nas_bench.query_by_index(index, use_12epochs_result=False)
            # Compute costs are per architecture, not per epoch: query once.
            costs = info.get_comput_costs('cifar10')
            flops.append(costs['flops'])
            params.append(costs['params'])
            for iepoch in range(200):
                train_acc = info.get_metrics('cifar10', 'train',
                                             iepoch)['accuracy']
                valid_acc = info.get_metrics('cifar10-valid', 'x-valid',
                                             iepoch)['accuracy']
                # Both the 'test' and 'otest' series come from the same
                # ('cifar10', 'ori-test') query, so issue it once and reuse it.
                test_acc = info.get_metrics('cifar10', 'ori-test',
                                            iepoch)['accuracy']
                train_accs[iepoch].append(train_acc)
                valid_accs[iepoch].append(valid_acc)
                test_accs[iepoch].append(test_acc)
                otest_accs[iepoch].append(test_acc)
        info = {
            'params': params,
            'flops': flops,
            'train_accs': train_accs,
            'valid_accs': valid_accs,
            'test_accs': test_accs,
            'otest_accs': otest_accs
        }
        torch.save(info, cache_file_path)
    else:
        print('Find cache file : {:}'.format(cache_file_path))
        info = torch.load(cache_file_path)
        params, flops = info['params'], info['flops']
        train_accs, valid_accs = info['train_accs'], info['valid_accs']
        test_accs, otest_accs = info['test_accs'], info['otest_accs']
    print('{:} collect data done.'.format(time_string()))
    selected_epochs = list(range(200))
    x_xtests = test_accs[199]
    indexes = list(range(len(x_xtests)))
    ord_idxs = sorted(indexes, key=lambda i: x_xtests[i])

    # Figure settings are the same for every epoch.
    dpi, width, height = 300, 2600, 2600
    figsize = width / float(dpi), height / float(dpi)
    LabelSize, LegendFontsize = 18, 18
    lo, hi = min(indexes), max(indexes)
    # np.arange raises on a zero step, which hi // 6 gives for < 7 archs.
    tick_step = max(1, hi // 6)

    for sepoch in selected_epochs:
        x_valids = valid_accs[sepoch]
        valid_ord_idxs = sorted(indexes, key=lambda i: x_valids[i])
        # valid_rank[a] is the position of architecture a in the validation
        # ordering. Inverting the permutation once is O(n), avoiding an O(n)
        # list.index lookup per architecture.
        valid_rank = [0] * len(valid_ord_idxs)
        for rank, arch in enumerate(valid_ord_idxs):
            valid_rank[arch] = rank
        valid_ord_lbls = [valid_rank[idx] for idx in ord_idxs]

        fig = plt.figure(figsize=figsize)
        ax = fig.add_subplot(111)
        plt.xlim(lo, hi)
        plt.ylim(lo, hi)
        plt.yticks(np.arange(lo, hi, tick_step),
                   fontsize=LegendFontsize,
                   rotation='vertical')
        plt.xticks(np.arange(lo, hi, tick_step), fontsize=LegendFontsize)
        ax.scatter(indexes,
                   valid_ord_lbls,
                   marker='^',
                   s=0.5,
                   c='tab:green',
                   alpha=0.8)
        ax.scatter(indexes,
                   indexes,
                   marker='o',
                   s=0.5,
                   c='tab:blue',
                   alpha=0.8)
        # Off-canvas dummy points give the legend readable marker sizes.
        ax.scatter([-1], [-1],
                   marker='^',
                   s=100,
                   c='tab:green',
                   label='CIFAR-10 validation')
        ax.scatter([-1], [-1],
                   marker='o',
                   s=100,
                   c='tab:blue',
                   label='CIFAR-10 test')
        plt.grid(zorder=0)
        ax.set_axisbelow(True)
        plt.legend(loc='upper left', fontsize=LegendFontsize)
        ax.set_xlabel('architecture ranking in the final test accuracy',
                      fontsize=LabelSize)
        ax.set_ylabel('architecture ranking in the validation set',
                      fontsize=LabelSize)
        save_path = (vis_save_dir / 'time-{:03d}.pdf'.format(sepoch)).resolve()
        fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='pdf')
        save_path = (vis_save_dir / 'time-{:03d}.png'.format(sepoch)).resolve()
        fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png')
        print('{:} save into {:}'.format(time_string(), save_path))
        plt.close('all')
Exemplo n.º 5
0
def main(xargs):
    """Run a SETN cell search on CIFAR-10 and log the results.

    Builds the search and validation loaders from the CIFAR split file and
    the 'SETN' search model, then optionally resumes from the last
    checkpoint recorded by the logger.  Each of the
    ``config.epochs + config.warmup`` epochs does the following:
      1. run ``search_func``;
      2. pick an architecture with ``get_best_arch``;
      3. evaluate it with ``valid_func``;
      4. save a checkpoint.
    A final ``get_best_arch`` call selects the reported architecture.

    Args:
        xargs: parsed command-line namespace. Fields read here include
            ``workers``, ``rand_seed``, ``dataset``, ``data_path``,
            ``config_path``, ``search_space_name``, ``channel``,
            ``num_cells``, ``max_nodes``, ``track_running_stats``,
            ``arch_learning_rate``, ``arch_weight_decay``,
            ``arch_nas_dataset``, ``select_num`` and ``print_freq``.

    Raises:
        AssertionError: if CUDA is unavailable or ``xargs.dataset`` is not
            'cifar10'.
    """
    assert torch.cuda.is_available(), 'CUDA is not available.'
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.set_num_threads(xargs.workers)
    prepare_seed(xargs.rand_seed)
    # Use the argument, not a module-level ``args`` global, so main() also
    # works when imported and called with any namespace.
    logger = prepare_logger(xargs)

    train_data, valid_data, xshape, class_num = get_datasets(
        xargs.dataset, xargs.data_path, -1)
    assert xargs.dataset == 'cifar10', 'currently only support CIFAR-10'
    # NOTE: because of the assertion above only the first branch is currently
    # reachable; the others are kept for when more datasets are supported.
    if xargs.dataset == 'cifar10' or xargs.dataset == 'cifar100':
        split_Fpath = 'configs/nas-benchmark/cifar-split.txt'
        cifar_split = load_config(split_Fpath, None, None)
        train_split, valid_split = cifar_split.train, cifar_split.valid
        logger.log('Load split file from {:}'.format(split_Fpath))
    elif xargs.dataset.startswith('ImageNet16'):
        split_Fpath = 'configs/nas-benchmark/{:}-split.txt'.format(
            xargs.dataset)
        imagenet16_split = load_config(split_Fpath, None, None)
        train_split, valid_split = imagenet16_split.train, imagenet16_split.valid
        logger.log('Load split file from {:}'.format(split_Fpath))
    else:
        raise ValueError('invalid dataset : {:}'.format(xargs.dataset))
    config = load_config(xargs.config_path, {
        'class_num': class_num,
        'xshape': xshape
    }, logger)
    # Validation samples are drawn from the training set (valid_split
    # indices) but use the evaluation-time transform of ``valid_data``.
    train_data_v2 = deepcopy(train_data)
    train_data_v2.transform = valid_data.transform
    valid_data = train_data_v2
    search_data = SearchDataset(xargs.dataset, train_data, train_split,
                                valid_split)
    # data loader
    search_loader = torch.utils.data.DataLoader(search_data,
                                                batch_size=config.batch_size,
                                                shuffle=True,
                                                num_workers=xargs.workers,
                                                pin_memory=True)
    valid_loader = torch.utils.data.DataLoader(
        valid_data,
        batch_size=config.test_batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split),
        num_workers=xargs.workers,
        pin_memory=True)
    logger.log(
        '||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'
        .format(xargs.dataset, len(search_loader), len(valid_loader),
                config.batch_size))
    logger.log('||||||| {:10s} ||||||| Config={:}'.format(
        xargs.dataset, config))

    search_space = get_search_spaces('cell', xargs.search_space_name)
    model_config = dict2config(
        {
            'name': 'SETN',
            'C': xargs.channel,
            'N': xargs.num_cells,
            'max_nodes': xargs.max_nodes,
            'num_classes': class_num,
            'space': search_space,
            'affine': False,
            'track_running_stats': bool(xargs.track_running_stats)
        }, None)
    logger.log('search space : {:}'.format(search_space))
    search_model = get_cell_based_tiny_net(model_config)

    w_optimizer, w_scheduler, criterion = get_optim_scheduler(
        search_model.get_weights(), config)
    a_optimizer = torch.optim.Adam(search_model.get_alphas(),
                                   lr=xargs.arch_learning_rate,
                                   betas=(0.5, 0.999),
                                   weight_decay=xargs.arch_weight_decay)
    logger.log('w-optimizer : {:}'.format(w_optimizer))
    logger.log('a-optimizer : {:}'.format(a_optimizer))
    logger.log('w-scheduler : {:}'.format(w_scheduler))
    logger.log('criterion   : {:}'.format(criterion))
    flop, param = get_model_infos(search_model, xshape)
    logger.log('FLOP = {:.2f} M, Params = {:.2f} MB'.format(flop, param))
    logger.log('search-space : {:}'.format(search_space))
    if xargs.arch_nas_dataset is None:
        api = None
    else:
        api = API(xargs.arch_nas_dataset)
    logger.log('{:} create API = {:} done'.format(time_string(), api))

    last_info_path, model_base_path = logger.path('info'), logger.path('model')
    network, criterion = torch.nn.DataParallel(
        search_model).cuda(), criterion.cuda()

    if last_info_path.exists():  # automatically resume from previous checkpoint
        logger.log("=> loading checkpoint of the last-info '{:}' start".format(
            last_info_path))
        last_info = torch.load(last_info_path)
        start_epoch = last_info['epoch']
        checkpoint = torch.load(last_info['last_checkpoint'])
        genotypes = checkpoint['genotypes']
        valid_accuracies = checkpoint['valid_accuracies']
        search_model.load_state_dict(checkpoint['search_model'])
        w_scheduler.load_state_dict(checkpoint['w_scheduler'])
        w_optimizer.load_state_dict(checkpoint['w_optimizer'])
        a_optimizer.load_state_dict(checkpoint['a_optimizer'])
        # Log the file path, not the loaded dict.
        logger.log(
            "=> loading checkpoint of the last-info '{:}' start with {:}-th epoch."
            .format(last_info_path, start_epoch))
    else:
        logger.log("=> do not find the last-info file : {:}".format(
            last_info_path))
        start_epoch, valid_accuracies, genotypes = 0, {'best': -1}, {}

    # start training
    start_time, search_time, epoch_time, total_epoch = time.time(
    ), AverageMeter(), AverageMeter(), config.epochs + config.warmup
    for epoch in range(start_epoch, total_epoch):
        w_scheduler.update(epoch, 0.0)
        need_time = 'Time Left: {:}'.format(
            convert_secs2time(epoch_time.val * (total_epoch - epoch), True))
        epoch_str = '{:03d}-{:03d}'.format(epoch, total_epoch)
        logger.log('\n[Search the {:}-th epoch] {:}, LR={:}'.format(
            epoch_str, need_time, min(w_scheduler.get_lr())))

        (search_w_loss, search_w_top1, search_w_top5, search_a_loss,
         search_a_top1, search_a_top5) = search_func(
             search_loader, network, criterion, w_scheduler, w_optimizer,
             a_optimizer, epoch_str, xargs.print_freq, logger)
        search_time.update(time.time() - start_time)
        logger.log(
            '[{:}] search [base] : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s'
            .format(epoch_str, search_w_loss, search_w_top1, search_w_top5,
                    search_time.sum))
        logger.log(
            '[{:}] search [arch] : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'
            .format(epoch_str, search_a_loss, search_a_top1, search_a_top5))

        genotype, _ = get_best_arch(valid_loader, network, xargs.select_num)
        network.module.set_cal_mode('dynamic', genotype)
        valid_a_loss, valid_a_top1, valid_a_top5 = valid_func(
            valid_loader, network, criterion)
        logger.log(
            '[{:}] evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}% | {:}'
            .format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5,
                    genotype))
        # Record this epoch's accuracy; 'best' is not tracked by SETN.
        valid_accuracies[epoch] = valid_a_top1

        genotypes[epoch] = genotype
        logger.log('<<<--->>> The {:}-th epoch : {:}'.format(
            epoch_str, genotypes[epoch]))
        # save checkpoint
        save_path = save_checkpoint(
            {
                'epoch': epoch + 1,
                'args': deepcopy(xargs),
                'search_model': search_model.state_dict(),
                'w_optimizer': w_optimizer.state_dict(),
                'a_optimizer': a_optimizer.state_dict(),
                'w_scheduler': w_scheduler.state_dict(),
                'genotypes': genotypes,
                'valid_accuracies': valid_accuracies
            }, model_base_path, logger)
        # The info file only points at the latest full checkpoint (used to resume).
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'args': deepcopy(xargs),
                'last_checkpoint': save_path,
            }, last_info_path, logger)
        with torch.no_grad():
            logger.log('arch-parameters :\n{:}'.format(
                nn.functional.softmax(search_model.arch_parameters,
                                      dim=-1).cpu()))
        if api is not None:
            logger.log('{:}'.format(api.query_by_arch(genotypes[epoch])))
        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()

    # the final post procedure : count the time
    start_time = time.time()
    genotype, _ = get_best_arch(valid_loader, network, xargs.select_num)
    search_time.update(time.time() - start_time)
    network.module.set_cal_mode('dynamic', genotype)
    valid_a_loss, valid_a_top1, valid_a_top5 = valid_func(
        valid_loader, network, criterion)
    logger.log(
        'Last : the gentotype is : {:}, with the validation accuracy of {:.3f}%.'
        .format(genotype, valid_a_top1))

    logger.log('\n' + '-' * 100)
    # check the performance from the architecture dataset
    logger.log(
        'SETN : run {:} epochs, cost {:.1f} s, last-geno is {:}.'.format(
            total_epoch, search_time.sum, genotype))
    if api is not None: logger.log('{:}'.format(api.query_by_arch(genotype)))
    logger.close()
Exemplo n.º 6
0
def _collect_bench_info(meta_file, dataset, resnet_index=11472):
    """Query every architecture in the benchmark file and gather plot data.

    Args:
      meta_file: path to the benchmark file; passed to ``API`` as a str.
      dataset: dataset name. For 'cifar10' the train/test accuracies come from
        'cifar10' and the validation accuracy from 'cifar10-valid'; any other
        name reads every split from ``dataset`` itself.
      resnet_index: architecture index whose metrics are additionally stored
        under the 'resnet' key (the highlighted reference point in the plots).

    Returns:
      A dict with per-architecture lists 'params', 'flops', 'train_accs',
      'valid_accs', 'test_accs', 'otest_accs', plus a 'resnet' dict.

    Raises:
      ValueError: if the benchmark contains no architecture at ``resnet_index``.
    """
    nas_bench = API(str(meta_file))
    params, flops, train_accs, valid_accs, test_accs, otest_accs = [], [], [], [], [], []
    resnet = None
    for index in range(len(nas_bench)):
        info = nas_bench.query_by_index(index, use_12epochs_result=False)
        costs = info.get_comput_costs(dataset)
        flop, param = costs['flops'], costs['params']
        if dataset == 'cifar10':
            train_acc = info.get_metrics('cifar10', 'train')['accuracy']
            valid_acc = info.get_metrics('cifar10-valid', 'x-valid')['accuracy']
            # Both test entries are read from 'ori-test' for cifar10, so the
            # identical query is issued once and shared.
            test_acc = otest_acc = info.get_metrics('cifar10', 'ori-test')['accuracy']
        else:
            train_acc = info.get_metrics(dataset, 'train')['accuracy']
            valid_acc = info.get_metrics(dataset, 'x-valid')['accuracy']
            test_acc = info.get_metrics(dataset, 'x-test')['accuracy']
            otest_acc = info.get_metrics(dataset, 'ori-test')['accuracy']
        if index == resnet_index:  # the ResNet reference architecture
            resnet = {
                'params': param,
                'flops': flop,
                'index': resnet_index,
                'train_acc': train_acc,
                'valid_acc': valid_acc,
                'test_acc': test_acc,
                'otest_acc': otest_acc
            }
        flops.append(flop)
        params.append(param)
        train_accs.append(train_acc)
        valid_accs.append(valid_acc)
        test_accs.append(test_acc)
        otest_accs.append(otest_acc)
    if resnet is None:
        # Without this check the caller would fail later with a NameError.
        raise ValueError(
            'benchmark {:} has {:} architectures; index {:} (resnet) not found'
            .format(meta_file, len(nas_bench), resnet_index))
    # Example of a stored entry:
    # resnet = {'params': 0.559, 'flops': 78.56, 'index': 11472, 'train_acc': 99.99, 'valid_acc': 90.84, 'test_acc': 93.97}
    return {
        'params': params,
        'flops': flops,
        'train_accs': train_accs,
        'valid_accs': valid_accs,
        'test_accs': test_accs,
        'otest_accs': otest_accs,
        'resnet': resnet
    }


def visualize_info(meta_file, dataset, vis_save_dir):
    """Plot benchmark accuracies against #params and against architecture ID.

    Per-architecture data is loaded from ``<vis_save_dir>/<dataset>-cache-info.pth``
    when that file exists; otherwise it is collected from ``meta_file`` and
    cached there. Four scatter plots are produced:
      - valid accuracy vs #params
      - test accuracy vs #params
      - train accuracy vs #params
      - test accuracy vs architecture ID
    Each plot highlights the 'resnet' entry and is saved as both PDF and PNG
    into ``vis_save_dir``.

    Args:
      meta_file: benchmark file, only read when the cache file is missing.
      dataset: dataset name; also selects the y-axis ranges.
      vis_save_dir: ``pathlib.Path``-like directory for the cache and figures.
    """
    print('{:} start to visualize {:} information'.format(
        time_string(), dataset))
    cache_file_path = vis_save_dir / '{:}-cache-info.pth'.format(dataset)
    if cache_file_path.exists():
        print('Find cache file : {:}'.format(cache_file_path))
        info = torch.load(cache_file_path)
    else:
        print('Do not find cache file : {:}'.format(cache_file_path))
        info = _collect_bench_info(meta_file, dataset)
        torch.save(info, cache_file_path)
    params, train_accs = info['params'], info['train_accs']
    valid_accs, test_accs = info['valid_accs'], info['test_accs']
    resnet = info['resnet']
    print('{:} collect data done.'.format(time_string()))

    dpi, width, height = 300, 2600, 2600
    figsize = width / float(dpi), height / float(dpi)
    LabelSize, LegendFontsize = 22, 22
    resnet_scale, resnet_alpha = 120, 0.5

    # (low, high) y-limits per dataset; train accuracy uses its own ranges.
    acc_yrange = {'cifar10': (50, 100), 'cifar100': (25, 75)}.get(dataset, (0, 50))
    train_yrange = {'cifar10': (50, 100), 'cifar100': (20, 100)}.get(dataset, (25, 76))
    param_xticks = np.arange(0, 1.6, 0.3)

    def _draw(xs, ys, resnet_xkey, resnet_ykey, xlabel, ylabel, yrange, stem,
              xticks, xlim=None):
        """Draw one scatter plot with the resnet marker; save it as PDF and PNG."""
        fig = plt.figure(figsize=figsize)
        ax = fig.add_subplot(111)
        if xlim is not None:
            plt.xlim(*xlim)
        plt.xticks(xticks, fontsize=LegendFontsize)
        plt.ylim(*yrange)
        plt.yticks(np.arange(yrange[0], yrange[1] + 1, 10), fontsize=LegendFontsize)
        ax.scatter(xs, ys, marker='o', s=0.5, c='tab:blue')
        ax.scatter([resnet[resnet_xkey]], [resnet[resnet_ykey]],
                   marker='*',
                   s=resnet_scale,
                   c='tab:orange',
                   label='resnet',
                   alpha=resnet_alpha)
        plt.grid()
        ax.set_axisbelow(True)
        plt.legend(loc=4, fontsize=LegendFontsize)
        ax.set_xlabel(xlabel, fontsize=LabelSize)
        ax.set_ylabel(ylabel, fontsize=LabelSize)
        for fmt in ('pdf', 'png'):
            save_path = (vis_save_dir /
                         '{:}-{:}.{:}'.format(dataset, stem, fmt)).resolve()
            fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format=fmt)
        # Only the last (PNG) path is reported.
        print('{:} save into {:}'.format(time_string(), save_path))

    _draw(params, valid_accs, 'params', 'valid_acc',
          xlabel='#parameters (MB)', ylabel='the validation accuracy (%)',
          yrange=acc_yrange, stem='param-vs-valid', xticks=param_xticks)
    _draw(params, test_accs, 'params', 'test_acc',
          xlabel='#parameters (MB)', ylabel='the test accuracy (%)',
          yrange=acc_yrange, stem='param-vs-test', xticks=param_xticks)
    _draw(params, train_accs, 'params', 'train_acc',
          xlabel='#parameters (MB)', ylabel='the training accuracy (%)',
          yrange=train_yrange, stem='param-vs-train', xticks=param_xticks)

    indexes = list(range(len(params)))
    last_id = max(indexes)
    _draw(indexes, test_accs, 'index', 'test_acc',
          xlabel='architecture ID', ylabel='the test accuracy (%)',
          yrange=acc_yrange, stem='test-over-ID',
          xticks=np.arange(min(indexes), last_id, last_id // 5),
          xlim=(0, last_id))
    plt.close('all')
Exemplo n.º 7
0
def main(xargs):
    """Train a RANDOM weight-sharing supernet and search it for the best cell.

    Each epoch does four things:
      - trains the shared weights with ``search_func``;
      - evaluates them with ``valid_func``;
      - records the architecture picked by ``search_find_best``;
      - checkpoints everything, so an interrupted run resumes from the last
        finished epoch.
    After training, a final ``search_find_best`` pass reports the chosen
    architecture.

    Args:
      xargs: parsed command-line namespace. Fields read here include dataset,
        data_path, config_path, search_space_name, channel, num_cells,
        max_nodes, track_running_stats, select_num, arch_nas_dataset, workers,
        rand_seed and print_freq.
    """
    assert torch.cuda.is_available(), 'CUDA is not available.'
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.set_num_threads(xargs.workers)
    prepare_seed(xargs.rand_seed)
    # Use the namespace passed in, not a module-level ``args`` global.
    logger = prepare_logger4(xargs)

    train_data, valid_data, xshape, class_num = get_datasets(
        xargs.dataset, xargs.data_path, -1)
    config = load_config(xargs.config_path, {
        'class_num': class_num,
        'xshape': xshape
    }, logger)
    search_loader, _, valid_loader = get_nas_search_loaders(
        train_data, valid_data, xargs.dataset, 'configs/nas-benchmark/',
        (config.batch_size, config.test_batch_size), xargs.workers)
    logger.log(
        '||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'
        .format(xargs.dataset, len(search_loader), len(valid_loader),
                config.batch_size))
    logger.log('||||||| {:10s} ||||||| Config={:}'.format(
        xargs.dataset, config))

    # This variant uses the reduced (sub) cell search space rather than
    # get_search_spaces('cell', ...).
    search_space = get_sub_search_spaces('cell', xargs.search_space_name)
    logger.log('search_space={}'.format(search_space))
    model_config = dict2config(
        {
            'name': 'RANDOM',
            'C': xargs.channel,
            'N': xargs.num_cells,
            'max_nodes': xargs.max_nodes,
            'num_classes': class_num,
            'space': search_space,
            'affine': False,
            'track_running_stats': bool(xargs.track_running_stats)
        }, None)
    search_model = get_cell_based_tiny_net(model_config)

    w_optimizer, w_scheduler, criterion = get_optim_scheduler(
        search_model.parameters(), config)
    logger.log('w-optimizer : {:}'.format(w_optimizer))
    logger.log('w-scheduler : {:}'.format(w_scheduler))
    logger.log('criterion   : {:}'.format(criterion))
    api = None if xargs.arch_nas_dataset is None else API(xargs.arch_nas_dataset)
    logger.log('{:} create API = {:} done'.format(time_string(), api))

    last_info_path = logger.path('info')
    model_base_path, model_best_path = logger.path('model'), logger.path('best')
    network, criterion = torch.nn.DataParallel(
        search_model).cuda(), criterion.cuda()

    if last_info_path.exists():  # automatically resume from previous checkpoint
        logger.log("=> loading checkpoint of the last-info '{:}' start".format(
            last_info_path))
        last_info = torch.load(last_info_path)
        start_epoch = last_info['epoch']
        checkpoint = torch.load(last_info['last_checkpoint'])
        genotypes = checkpoint['genotypes']
        valid_accuracies = checkpoint['valid_accuracies']
        search_model.load_state_dict(checkpoint['search_model'])
        w_scheduler.load_state_dict(checkpoint['w_scheduler'])
        w_optimizer.load_state_dict(checkpoint['w_optimizer'])
        # Log the file path, not the loaded dict.
        logger.log(
            "=> loading checkpoint of the last-info '{:}' start with {:}-th epoch."
            .format(last_info_path, start_epoch))
    else:
        logger.log("=> do not find the last-info file : {:}".format(last_info_path))
        start_epoch, valid_accuracies, genotypes = 0, {'best': -1}, {}

    # start training
    start_time, search_time, epoch_time = time.time(), AverageMeter(), AverageMeter()
    total_epoch = config.epochs + config.warmup
    for epoch in range(start_epoch, total_epoch):
        w_scheduler.update(epoch, 0.0)
        need_time = 'Time Left: {:}'.format(
            convert_secs2time(epoch_time.val * (total_epoch - epoch), True))
        epoch_str = '{:03d}-{:03d}'.format(epoch, total_epoch)
        logger.log('\n[Search the {:}-th epoch] {:}, LR={:}'.format(
            epoch_str, need_time, min(w_scheduler.get_lr())))

        search_w_loss, search_w_top1, search_w_top5 = search_func(
            search_loader, network, criterion, w_scheduler, w_optimizer,
            epoch_str, xargs.print_freq, logger)
        search_time.update(time.time() - start_time)
        logger.log(
            '[{:}] searching : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s'
            .format(epoch_str, search_w_loss, search_w_top1, search_w_top5,
                    search_time.sum))
        valid_a_loss, valid_a_top1, valid_a_top5 = valid_func(
            valid_loader, network, criterion)
        logger.log(
            '[{:}] evaluate  : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'
            .format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5))
        cur_arch, cur_valid_acc = search_find_best(valid_loader, network,
                                                   xargs.select_num)
        logger.log('[{:}] find-the-best : {:}, accuracy@1={:.2f}%'.format(
            epoch_str, cur_arch, cur_valid_acc))
        genotypes[epoch] = cur_arch
        # check the best accuracy
        valid_accuracies[epoch] = valid_a_top1
        find_best = valid_a_top1 > valid_accuracies['best']
        if find_best:
            valid_accuracies['best'] = valid_a_top1

        # save checkpoint
        save_path = save_checkpoint(
            {
                'epoch': epoch + 1,
                'args': deepcopy(xargs),
                'search_model': search_model.state_dict(),
                'w_optimizer': w_optimizer.state_dict(),
                'w_scheduler': w_scheduler.state_dict(),
                'genotypes': genotypes,
                'valid_accuracies': valid_accuracies
            }, model_base_path, logger)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'args': deepcopy(xargs),
                'last_checkpoint': save_path,
            }, last_info_path, logger)
        if find_best:
            logger.log(
                '<<<--->>> The {:}-th epoch : find the highest validation accuracy : {:.2f}%.'
                .format(epoch_str, valid_a_top1))
            # (A leftover pdb.set_trace() here used to block the run.)
            copy_checkpoint(model_base_path, model_best_path, logger)
        if api is not None:
            logger.log('{:}'.format(api.query_by_arch(genotypes[epoch])))
        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()

    logger.log('\n' + '-' * 200)
    logger.log('Pre-searching costs {:.1f} s'.format(search_time.sum))
    start_time = time.time()
    best_arch, best_acc = search_find_best(valid_loader, network,
                                           xargs.select_num)
    search_time.update(time.time() - start_time)
    logger.log(
        'RANDOM-NAS finds the best one : {:} with accuracy={:.2f}%, with {:.1f} s.'
        .format(best_arch, best_acc, search_time.sum))
    if api is not None:
        logger.log('{:}'.format(api.query_by_arch(best_arch)))
    logger.close()
Exemplo n.º 8
0
def main(xargs):
    """Evaluate every cell architecture of a pre-trained SPOS supernet.

    Steps:
      1. Load the supernet checkpoint saved for ``xargs.epoch`` (seed
         ``xargs.rand_seed``).
      2. For each architecture returned by ``get_all_archs`` (visited in
         shuffled order):
         - switch the supernet to that cell;
         - call ``recalculate_bn`` on the search loader;
         - measure validation loss and accuracy.
      3. Save top-1 accuracies, keyed by ``genotype.tostr()``, to
         ``<xargs.save_dir>/result.dat``.

    Args:
      xargs: parsed command-line namespace. Fields read here include dataset,
        data_path, config_path, search_space_name, channel, num_cells,
        max_nodes, track_running_stats, arch_learning_rate, arch_weight_decay,
        arch_nas_dataset, epoch, save_dir, workers and rand_seed.
    """
    assert torch.cuda.is_available(), 'CUDA is not available.'
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.set_num_threads(xargs.workers)
    prepare_seed(xargs.rand_seed)
    # Use the namespace passed in, not a module-level ``args`` global.
    logger = prepare_logger(xargs)

    train_data, valid_data, xshape, class_num = get_datasets(
        xargs.dataset, xargs.data_path, -1)
    config = load_config(xargs.config_path, {
        'class_num': class_num,
        'xshape': xshape
    }, logger)
    search_loader, _, valid_loader = get_nas_search_loaders(
        train_data, valid_data, xargs.dataset, 'configs/nas-benchmark/',
        (config.batch_size, config.test_batch_size), xargs.workers)
    logger.log(
        '||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'
        .format(xargs.dataset, len(search_loader), len(valid_loader),
                config.batch_size))
    logger.log('||||||| {:10s} ||||||| Config={:}'.format(
        xargs.dataset, config))

    search_space = get_search_spaces('cell', xargs.search_space_name)
    model_config = dict2config(
        {
            'name': 'SPOS',
            'C': xargs.channel,
            'N': xargs.num_cells,
            'max_nodes': xargs.max_nodes,
            'num_classes': class_num,
            'space': search_space,
            'affine': False,
            'track_running_stats': bool(xargs.track_running_stats)
        }, None)
    logger.log('search space : {:}'.format(search_space))
    model = get_cell_based_tiny_net(model_config)

    w_optimizer, w_scheduler, criterion = get_optim_scheduler(
        model.get_weights(), config)
    a_optimizer = torch.optim.Adam(model.get_alphas(),
                                   lr=xargs.arch_learning_rate,
                                   betas=(0.5, 0.999),
                                   weight_decay=xargs.arch_weight_decay)
    logger.log('w-optimizer : {:}'.format(w_optimizer))
    logger.log('a-optimizer : {:}'.format(a_optimizer))
    logger.log('w-scheduler : {:}'.format(w_scheduler))
    logger.log('criterion   : {:}'.format(criterion))
    flop, param = get_model_infos(model, xshape)
    logger.log('FLOP = {:.2f} M, Params = {:.2f} MB'.format(flop, param))
    logger.log('search-space : {:}'.format(search_space))
    api = None if xargs.arch_nas_dataset is None else API(xargs.arch_nas_dataset)
    logger.log('{:} create API = {:} done'.format(time_string(), api))

    network, criterion = torch.nn.DataParallel(model).cuda(), criterion.cuda()

    # The checkpoint path is always a concrete string, so the weights are
    # loaded unconditionally. torch.load raises if the file is missing, which
    # is preferable to silently evaluating an untrained supernet.
    checkpoint_path = 'output/search-cell-nas-bench-102/result-{}/checkpoint/seed-{}_epoch-{}.pth'.format(
        xargs.dataset, xargs.rand_seed, xargs.epoch)
    logger.log("=> loading checkpoint from {}".format(checkpoint_path))
    checkpoint = torch.load(checkpoint_path)
    model.load_state_dict(checkpoint['search_model'])

    # start inference
    search_time, total_epoch = AverageMeter(), config.epochs + config.warmup
    all_archs = network.module.get_all_archs()
    random.shuffle(all_archs)

    valid_accuracies = {}
    process_start_time = time.time()
    for i, genotype in enumerate(all_archs):
        network.module.set_cal_mode('dynamic', genotype)
        recalculate_bn(network, search_loader)
        valid_a_loss, valid_a_top1, valid_a_top5 = valid_func(
            valid_loader, network, criterion)
        logger.log(
            '[{:}] evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}% | {:}'
            .format(i, valid_a_loss, valid_a_top1, valid_a_top5, genotype))
        valid_accuracies[genotype.tostr()] = valid_a_top1
    process_end_time = time.time()
    # Record the evaluation time so the summary below does not report 0.0 s.
    search_time.update(process_end_time - process_start_time)
    logger.log('process time: {}'.format(process_end_time -
                                         process_start_time))

    torch.save(valid_accuracies, '{}/result.dat'.format(xargs.save_dir))
    logger.log('\n' + '-' * 100)
    # check the performance from the architecture dataset
    logger.log(
        'SPOS : run {:} epochs, cost {:.1f} s, last-geno is {:}.'.format(
            total_epoch, search_time.sum, genotype))
    if api is not None:
        logger.log('{:}'.format(api.query_by_arch(genotype)))
    logger.close()
Exemplo n.º 9
0
def main(xargs):
  """Train a RANDOM weight-sharing supernet on a CIFAR split and search it.

  The training half of the CIFAR train set (cifar-split.txt) feeds
  ``SearchDataset``. The other half, with the test-time transform, is the
  validation loader.

  Each epoch does four things:
    - trains the shared weights with ``search_func``;
    - evaluates them with ``valid_func``;
    - records the architecture picked by ``search_find_best``;
    - checkpoints everything, so an interrupted run resumes from the last
      finished epoch.

  Args:
    xargs: parsed command-line namespace. Fields read here include dataset,
      data_path, config_path, search_space_name, channel, num_cells,
      max_nodes, select_num, arch_nas_dataset, workers, rand_seed and
      print_freq.

  Raises:
    ValueError: if ``xargs.dataset`` is not 'cifar10' or 'cifar100'.
  """
  assert torch.cuda.is_available(), 'CUDA is not available.'
  torch.backends.cudnn.enabled   = True
  torch.backends.cudnn.benchmark = False
  torch.backends.cudnn.deterministic = True
  torch.set_num_threads( xargs.workers )
  prepare_seed(xargs.rand_seed)
  # Use the namespace passed in, not a module-level ``args`` global.
  logger = prepare_logger(xargs)

  train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
  if xargs.dataset == 'cifar10' or xargs.dataset == 'cifar100':
    split_Fpath = 'configs/nas-benchmark/cifar-split.txt'
    cifar_split = load_config(split_Fpath, None, None)
    train_split, valid_split = cifar_split.train, cifar_split.valid
    logger.log('Load split file from {:}'.format(split_Fpath))
  #elif xargs.dataset.startswith('ImageNet16'):
  #  # all_indexes = list(range(len(train_data))) ; random.seed(111) ; random.shuffle(all_indexes)
  #  # train_split, valid_split = sorted(all_indexes[: len(train_data)//2]), sorted(all_indexes[len(train_data)//2 :])
  #  # imagenet16_split = dict2config({'train': train_split, 'valid': valid_split}, None)
  #  # _ = configure2str(imagenet16_split, 'temp.txt')
  #  split_Fpath = 'configs/nas-benchmark/{:}-split.txt'.format(xargs.dataset)
  #  imagenet16_split = load_config(split_Fpath, None, None)
  #  train_split, valid_split = imagenet16_split.train, imagenet16_split.valid
  #  logger.log('Load split file from {:}'.format(split_Fpath))
  else:
    raise ValueError('invalid dataset : {:}'.format(xargs.dataset))
  config = load_config(xargs.config_path, {'class_num': class_num, 'xshape': xshape}, logger)
  logger.log('config : {:}'.format(config))
  # To split data: the validation loader draws from a copy of the training
  # set that uses the evaluation transform.
  train_data_v2 = deepcopy(train_data)
  train_data_v2.transform = valid_data.transform
  valid_data    = train_data_v2
  search_data   = SearchDataset(xargs.dataset, train_data, train_split, valid_split)
  # data loader
  search_loader = torch.utils.data.DataLoader(search_data, batch_size=config.batch_size, shuffle=True, num_workers=xargs.workers, pin_memory=True)
  valid_loader  = torch.utils.data.DataLoader(valid_data, batch_size=config.test_batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=xargs.workers, pin_memory=True)
  logger.log('||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(search_loader), len(valid_loader), config.batch_size))
  logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config))

  search_space = get_search_spaces('cell', xargs.search_space_name)
  model_config = dict2config({'name': 'RANDOM', 'C': xargs.channel, 'N': xargs.num_cells,
                              'max_nodes': xargs.max_nodes, 'num_classes': class_num,
                              'space'    : search_space}, None)
  search_model = get_cell_based_tiny_net(model_config)

  w_optimizer, w_scheduler, criterion = get_optim_scheduler(search_model.parameters(), config)
  logger.log('w-optimizer : {:}'.format(w_optimizer))
  logger.log('w-scheduler : {:}'.format(w_scheduler))
  logger.log('criterion   : {:}'.format(criterion))
  api = None if xargs.arch_nas_dataset is None else API(xargs.arch_nas_dataset)
  logger.log('{:} create API = {:} done'.format(time_string(), api))

  last_info_path = logger.path('info')
  model_base_path, model_best_path = logger.path('model'), logger.path('best')
  network, criterion = torch.nn.DataParallel(search_model).cuda(), criterion.cuda()

  if last_info_path.exists(): # automatically resume from previous checkpoint
    logger.log("=> loading checkpoint of the last-info '{:}' start".format(last_info_path))
    last_info   = torch.load(last_info_path)
    start_epoch = last_info['epoch']
    checkpoint  = torch.load(last_info['last_checkpoint'])
    genotypes   = checkpoint['genotypes']
    valid_accuracies = checkpoint['valid_accuracies']
    search_model.load_state_dict( checkpoint['search_model'] )
    w_scheduler.load_state_dict ( checkpoint['w_scheduler'] )
    w_optimizer.load_state_dict ( checkpoint['w_optimizer'] )
    # Log the file path, not the loaded dict.
    logger.log("=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format(last_info_path, start_epoch))
  else:
    logger.log("=> do not find the last-info file : {:}".format(last_info_path))
    start_epoch, valid_accuracies, genotypes = 0, {'best': -1}, {}

  # start training
  start_time, search_time, epoch_time = time.time(), AverageMeter(), AverageMeter()
  total_epoch = config.epochs + config.warmup
  for epoch in range(start_epoch, total_epoch):
    w_scheduler.update(epoch, 0.0)
    need_time = 'Time Left: {:}'.format( convert_secs2time(epoch_time.val * (total_epoch-epoch), True) )
    epoch_str = '{:03d}-{:03d}'.format(epoch, total_epoch)
    logger.log('\n[Search the {:}-th epoch] {:}, LR={:}'.format(epoch_str, need_time, min(w_scheduler.get_lr())))

    search_w_loss, search_w_top1, search_w_top5 = search_func(search_loader, network, criterion, w_scheduler, w_optimizer, epoch_str, xargs.print_freq, logger)
    search_time.update(time.time() - start_time)
    logger.log('[{:}] searching : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s'.format(epoch_str, search_w_loss, search_w_top1, search_w_top5, search_time.sum))
    valid_a_loss, valid_a_top1, valid_a_top5 = valid_func(valid_loader, network, criterion)
    logger.log('[{:}] evaluate  : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5))
    cur_arch, cur_valid_acc = search_find_best(valid_loader, network, xargs.select_num)
    logger.log('[{:}] find-the-best : {:}, accuracy@1={:.2f}%'.format(epoch_str, cur_arch, cur_valid_acc))
    genotypes[epoch] = cur_arch
    # check the best accuracy
    valid_accuracies[epoch] = valid_a_top1
    find_best = valid_a_top1 > valid_accuracies['best']
    if find_best:
      valid_accuracies['best'] = valid_a_top1

    # save checkpoint
    save_path = save_checkpoint({'epoch' : epoch + 1,
                'args'  : deepcopy(xargs),
                'search_model': search_model.state_dict(),
                'w_optimizer' : w_optimizer.state_dict(),
                'w_scheduler' : w_scheduler.state_dict(),
                'genotypes'   : genotypes,
                'valid_accuracies' : valid_accuracies},
                model_base_path, logger)
    save_checkpoint({
          'epoch': epoch + 1,
          'args' : deepcopy(xargs),
          'last_checkpoint': save_path,
          }, last_info_path, logger)
    if find_best:
      logger.log('<<<--->>> The {:}-th epoch : find the highest validation accuracy : {:.2f}%.'.format(epoch_str, valid_a_top1))
      copy_checkpoint(model_base_path, model_best_path, logger)
    if api is not None: logger.log('{:}'.format(api.query_by_arch( genotypes[epoch] )))
    # measure elapsed time
    epoch_time.update(time.time() - start_time)
    start_time = time.time()

  logger.log('\n' + '-'*200)
  logger.log('Pre-searching costs {:.1f} s'.format(search_time.sum))
  start_time = time.time()
  best_arch, best_acc = search_find_best(valid_loader, network, xargs.select_num)
  search_time.update(time.time() - start_time)
  logger.log('RANDOM-NAS finds the best one : {:} with accuracy={:.2f}%, with {:.1f} s.'.format(best_arch, best_acc, search_time.sum))
  if api is not None: logger.log('{:}'.format( api.query_by_arch(best_arch) ))
  logger.close()
Exemplo n.º 10
0
def main(xargs):
    """Run a DARTS-V2 architecture search, logging and checkpointing each epoch.

    Loads the dataset and its train/valid split file, builds the DARTS-V2
    cell-based super-network plus weight/architecture optimizers, then for
    ``config.epochs + config.warmup`` epochs runs ``search_func`` (given both
    optimizers) followed by ``valid_func``. After every epoch a checkpoint and
    a "last-info" file are written, so a rerun resumes from the last finished
    epoch. Whenever validation top-1 improves, the whole model object is also
    ``torch.save``-d into ``model_save_dir``.

    Args:
        xargs: parsed command-line namespace. Attributes read here:
            ``workers``, ``rand_seed``, ``dataset``, ``data_path``,
            ``search_space_name``, ``channel``, ``num_cells``, ``max_nodes``,
            ``arch_learning_rate``, ``arch_weight_decay``,
            ``arch_nas_dataset`` and ``print_freq`` (plus whatever
            ``prepare_logger`` reads).

    Raises:
        AssertionError: if CUDA is not available.
        ValueError: if ``xargs.dataset`` is not cifar10/cifar100/ImageNet16*.
    """
    import os  # local import: only needed to build/create the model-save path

    # NOTE(review): machine-specific absolute paths kept from the original
    # setup; consider exposing them as command-line options.
    cifar_split_path = '/home/city/Projects/NAS-Projects/configs/nas-benchmark/cifar-split.txt'
    darts_config_path = '/home/city/Projects/NAS-Projects/configs/nas-benchmark/algos/DARTS.config'
    model_save_dir = '/home/city/disk/log/builds-darts'

    assert torch.cuda.is_available(), 'CUDA is not available.'
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.set_num_threads(xargs.workers)
    prepare_seed(xargs.rand_seed)
    # Use the function argument, not the module-level ``args`` global.
    logger = prepare_logger(xargs)

    train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
    if xargs.dataset == 'cifar10' or xargs.dataset == 'cifar100':
        split_Fpath = cifar_split_path
    elif xargs.dataset.startswith('ImageNet16'):
        split_Fpath = 'configs/nas-benchmark/{:}-split.txt'.format(xargs.dataset)
    else:
        raise ValueError('invalid dataset : {:}'.format(xargs.dataset))
    data_split = load_config(split_Fpath, None, None)
    train_split, valid_split = data_split.train, data_split.valid
    logger.log('Load split file from {:}'.format(split_Fpath))

    config = load_config(darts_config_path, {'class_num': class_num, 'xshape': xshape}, logger)

    # The validation set is a copy of the *training* data restricted (via the
    # sampler below) to ``valid_split``, but using the evaluation transform.
    train_data_v2 = deepcopy(train_data)
    train_data_v2.transform = valid_data.transform
    valid_data = train_data_v2
    search_data = SearchDataset(xargs.dataset, train_data, train_split, valid_split)
    # data loader
    search_loader = torch.utils.data.DataLoader(search_data, batch_size=config.batch_size, shuffle=True,
                                                num_workers=xargs.workers, pin_memory=True)
    valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size,
                                               sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split),
                                               num_workers=xargs.workers, pin_memory=True)
    logger.log('||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(search_loader), len(valid_loader), config.batch_size))
    logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config))

    search_space = get_search_spaces('cell', xargs.search_space_name)
    model_config = dict2config({'name': 'DARTS-V2', 'C': xargs.channel, 'N': xargs.num_cells,
                                'max_nodes': xargs.max_nodes, 'num_classes': class_num,
                                'space': search_space}, None)
    search_model = get_cell_based_tiny_net(model_config)
    logger.log('search-model :\n{:}'.format(search_model))

    w_optimizer, w_scheduler, criterion = get_optim_scheduler(search_model.get_weights(), config)
    a_optimizer = torch.optim.Adam(search_model.get_alphas(), lr=xargs.arch_learning_rate,
                                   betas=(0.5, 0.999), weight_decay=xargs.arch_weight_decay)
    logger.log('w-optimizer : {:}'.format(w_optimizer))
    logger.log('a-optimizer : {:}'.format(a_optimizer))
    logger.log('w-scheduler : {:}'.format(w_scheduler))
    logger.log('criterion   : {:}'.format(criterion))
    flop, param = get_model_infos(search_model, xshape)
    logger.log('FLOP = {:.2f} M, Params = {:.2f} MB'.format(flop, param))
    if xargs.arch_nas_dataset is None:
        api = None
    else:
        api = API(xargs.arch_nas_dataset)
    logger.log('{:} create API = {:} done'.format(time_string(), api))

    # Keep the info *path* separate from the loaded info dict so log messages
    # after loading still show the path.
    last_info_path, model_base_path, model_best_path = logger.path('info'), logger.path('model'), logger.path('best')
    network, criterion = torch.nn.DataParallel(search_model).cuda(), criterion.cuda()

    if last_info_path.exists():  # automatically resume from previous checkpoint
        logger.log("=> loading checkpoint of the last-info '{:}' start".format(last_info_path))
        last_info = torch.load(last_info_path)
        start_epoch = last_info['epoch']
        checkpoint = torch.load(last_info['last_checkpoint'])
        genotypes = checkpoint['genotypes']
        valid_accuracies = checkpoint['valid_accuracies']
        search_model.load_state_dict(checkpoint['search_model'])
        w_scheduler.load_state_dict(checkpoint['w_scheduler'])
        w_optimizer.load_state_dict(checkpoint['w_optimizer'])
        a_optimizer.load_state_dict(checkpoint['a_optimizer'])
        logger.log("=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format(last_info_path, start_epoch))
    else:
        logger.log("=> do not find the last-info file : {:}".format(last_info_path))
        start_epoch, valid_accuracies, genotypes = 0, {'best': -1}, {}

    # start training
    start_time, search_time, epoch_time = time.time(), AverageMeter(), AverageMeter()
    total_epoch = config.epochs + config.warmup
    for epoch in range(start_epoch, total_epoch):
        w_scheduler.update(epoch, 0.0)
        need_time = 'Time Left: {:}'.format(convert_secs2time(epoch_time.val * (total_epoch - epoch), True))
        epoch_str = '{:03d}-{:03d}'.format(epoch, total_epoch)
        min_LR = min(w_scheduler.get_lr())
        logger.log('\n[Search the {:}-th epoch] {:}, LR={:}'.format(epoch_str, need_time, min_LR))

        search_w_loss, search_w_top1, search_w_top5 = search_func(search_loader, network, criterion, w_scheduler, w_optimizer, a_optimizer, epoch_str, xargs.print_freq, logger)
        search_time.update(time.time() - start_time)
        logger.log('[{:}] searching : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s'.format(epoch_str, search_w_loss, search_w_top1, search_w_top5, search_time.sum))
        valid_a_loss, valid_a_top1, valid_a_top5 = valid_func(valid_loader, network, criterion)
        logger.log('[{:}] evaluate  : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5))

        # check the best accuracy
        valid_accuracies[epoch] = valid_a_top1
        find_best = valid_a_top1 > valid_accuracies['best']
        if find_best:
            valid_accuracies['best'] = valid_a_top1
            genotypes['best'] = search_model.genotype()
            op_list, _ = genotypes['best'].tolist(remove_str=None)
            best_arch_nums = op_list2str(op_list)
            # Build the file name once, so the saved path and the printed path
            # always agree (two time_string_short() calls may differ).
            best_model_path = os.path.join(
                model_save_dir,
                'darts2_%04d_%s_%s_%.2f.pth' % (epoch, time_string_short(), best_arch_nums, valid_a_top1))
            os.makedirs(model_save_dir, exist_ok=True)
            torch.save(search_model, best_model_path)
            print(best_model_path)

        genotypes[epoch] = search_model.genotype()
        logger.log('<<<--->>> The {:}-th epoch : {:}'.format(epoch_str, genotypes[epoch]))
        # save checkpoint
        save_path = save_checkpoint({'epoch': epoch + 1,
                                     'args': deepcopy(xargs),
                                     'search_model': search_model.state_dict(),
                                     'w_optimizer': w_optimizer.state_dict(),
                                     'a_optimizer': a_optimizer.state_dict(),
                                     'w_scheduler': w_scheduler.state_dict(),
                                     'genotypes': genotypes,
                                     'valid_accuracies': valid_accuracies},
                                    model_base_path, logger)
        # The info file points at the newest checkpoint; it drives the resume
        # logic above.
        save_checkpoint({
            'epoch': epoch + 1,
            'args': deepcopy(xargs),
            'last_checkpoint': save_path,
        }, logger.path('info'), logger)
        if find_best:
            logger.log('<<<--->>> The {:}-th epoch : find the highest validation accuracy : {:.2f}%.'.format(epoch_str, valid_a_top1))
            copy_checkpoint(model_base_path, model_best_path, logger)
        with torch.no_grad():
            arch_probs = nn.functional.softmax(search_model.arch_parameters, dim=-1).cpu()
            logger.log('arch-parameters :\n{:}'.format(arch_probs))
            logger.log('arch :\n{:}'.format(arch_probs.argmax(dim=-1)))
        if api is not None:
            logger.log('{:}'.format(api.query_by_arch(genotypes[epoch])))
        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()

    logger.log('\n' + '-' * 100)
    # check the performance from the architecture dataset
    logger.log('DARTS-V2 : run {:} epochs, cost {:.1f} s, last-geno is {:}.'.format(total_epoch, search_time.sum, genotypes[total_epoch - 1]))
    if api is not None:
        logger.log('{:}'.format(api.query_by_arch(genotypes[total_epoch - 1])))
    logger.close()