Пример #1
0
def main(xargs):
  """Run the ENAS architecture search from setup to final selection.

  The function configures cuDNN, seeds and the logger, builds the data
  loaders, the shared CNN and its controller, resumes from the last saved
  checkpoint when the logger's 'info' file exists, and then for every epoch
  trains the shared CNN, trains the controller, picks the current best
  architecture and saves a checkpoint.  After the loop a final architecture is
  chosen with ``get_best_arch`` and evaluated on the validation loader.

  Args:
    xargs: options object.  Attributes read here: ``workers``, ``rand_seed``,
      ``dataset``, ``data_path``, ``config_path``, ``search_space_name``,
      ``channel``, ``num_cells``, ``max_nodes``, ``track_running_stats``,
      ``arch_nas_dataset``, ``print_freq``, ``controller_train_steps``,
      ``controller_num_aggregate``, ``controller_entropy_weight``,
      ``controller_bl_dec`` and ``controller_num_samples``.

  Raises:
    AssertionError: if CUDA is not available.
  """
  assert torch.cuda.is_available(), 'CUDA is not available.'
  # cuDNN is enabled, with benchmark mode off and deterministic mode requested.
  torch.backends.cudnn.enabled   = True
  torch.backends.cudnn.benchmark = False
  torch.backends.cudnn.deterministic = True
  torch.set_num_threads( xargs.workers )
  prepare_seed(xargs.rand_seed)
  # Use the function's own argument; the previous code read a module-level
  # `args`, which only exists when this file is executed as a script.
  logger = prepare_logger(xargs)

  train_data, test_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
  logger.log('use config from : {:}'.format(xargs.config_path))
  config = load_config(xargs.config_path, {'class_num': class_num, 'xshape': xshape}, logger)
  _, train_loader, valid_loader = get_nas_search_loaders(train_data, test_data, xargs.dataset, 'configs/nas-benchmark/', config.batch_size, xargs.workers)
  # since ENAS will train the controller on valid-loader, we need to use train transformation for valid-loader
  valid_loader.dataset.transform = deepcopy(train_loader.dataset.transform)
  if hasattr(valid_loader.dataset, 'transforms'):
    valid_loader.dataset.transforms = deepcopy(train_loader.dataset.transforms)
  # data loader
  logger.log('||||||| {:10s} ||||||| Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(train_loader), len(valid_loader), config.batch_size))
  logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config))

  # Build the shared network and the controller that samples architectures.
  search_space = get_sub_search_spaces('cell', xargs.search_space_name)
  logger.log('search_space={}'.format(search_space))
  model_config = dict2config({'name': 'ENAS', 'C': xargs.channel, 'N': xargs.num_cells,
                              'max_nodes': xargs.max_nodes, 'num_classes': class_num,
                              'space'    : search_space,
                              'affine'   : False, 'track_running_stats': bool(xargs.track_running_stats)}, None)
  shared_cnn = get_cell_based_tiny_net(model_config)
  controller = shared_cnn.create_controller()

  # w_* optimises the shared CNN weights, a_* optimises the controller.
  w_optimizer, w_scheduler, criterion = get_optim_scheduler(shared_cnn.parameters(), config)
  a_optimizer = torch.optim.Adam(controller.parameters(), lr=config.controller_lr, betas=config.controller_betas, eps=config.controller_eps)
  logger.log('w-optimizer : {:}'.format(w_optimizer))
  logger.log('a-optimizer : {:}'.format(a_optimizer))
  logger.log('w-scheduler : {:}'.format(w_scheduler))
  logger.log('criterion   : {:}'.format(criterion))
  logger.log('search-space : {:}'.format(search_space))
  if xargs.arch_nas_dataset is None:
    api = None
  else:
    api = API(xargs.arch_nas_dataset)
  logger.log('{:} create API = {:} done'.format(time_string(), api))
  # From here on `shared_cnn` is the DataParallel wrapper; the underlying
  # network is reached through `shared_cnn.module`.
  shared_cnn, controller, criterion = torch.nn.DataParallel(shared_cnn).cuda(), controller.cuda(), criterion.cuda()

  last_info, model_base_path, model_best_path = logger.path('info'), logger.path('model'), logger.path('best')

  if last_info.exists(): # automatically resume from previous checkpoint
    logger.log("=> loading checkpoint of the last-info '{:}' start".format(last_info))
    last_info   = torch.load(last_info)
    start_epoch = last_info['epoch']
    checkpoint  = torch.load(last_info['last_checkpoint'])
    genotypes   = checkpoint['genotypes']
    baseline    = checkpoint['baseline']
    valid_accuracies = checkpoint['valid_accuracies']
    shared_cnn.load_state_dict( checkpoint['shared_cnn'] )
    controller.load_state_dict( checkpoint['controller'] )
    w_scheduler.load_state_dict ( checkpoint['w_scheduler'] )
    w_optimizer.load_state_dict ( checkpoint['w_optimizer'] )
    a_optimizer.load_state_dict ( checkpoint['a_optimizer'] )
    logger.log("=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format(last_info, start_epoch))
  else:
    logger.log("=> do not find the last-info file : {:}".format(last_info))
    start_epoch, valid_accuracies, genotypes, baseline = 0, {'best': -1}, {}, None

  # start training
  start_time, search_time, epoch_time, total_epoch = time.time(), AverageMeter(), AverageMeter(), config.epochs + config.warmup
  for epoch in range(start_epoch, total_epoch):
    w_scheduler.update(epoch, 0.0)
    need_time = 'Time Left: {:}'.format( convert_secs2time(epoch_time.val * (total_epoch-epoch), True) )
    epoch_str = '{:03d}-{:03d}'.format(epoch, total_epoch)
    logger.log('\n[Search the {:}-th epoch] {:}, LR={:}, baseline={:}'.format(epoch_str, need_time, min(w_scheduler.get_lr()), baseline))

    # Phase 1: train the shared CNN weights on the training loader.
    cnn_loss, cnn_top1, cnn_top5 = train_shared_cnn(train_loader, shared_cnn, controller, criterion, w_scheduler, w_optimizer, epoch_str, xargs.print_freq, logger)
    logger.log('[{:}] shared-cnn : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, cnn_loss, cnn_top1, cnn_top5))

    # Phase 2: train the controller on the validation loader; the returned
    # baseline is carried over to the next epoch and stored in the checkpoint.
    controller_config = dict2config({'baseline': baseline,
                                     'ctl_train_steps': xargs.controller_train_steps, 'ctl_num_aggre': xargs.controller_num_aggregate,
                                     'ctl_entropy_w': xargs.controller_entropy_weight,
                                     'ctl_bl_dec'   : xargs.controller_bl_dec}, None)
    ctl_loss, ctl_acc, ctl_baseline, ctl_reward, baseline = train_controller(
        valid_loader, shared_cnn, controller, criterion, a_optimizer,
        controller_config, epoch_str, xargs.print_freq, logger)
    search_time.update(time.time() - start_time)
    logger.log('[{:}] controller : loss={:.2f}, accuracy={:.2f}%, baseline={:.2f}, reward={:.2f}, current-baseline={:.4f}, time-cost={:.1f} s'.format(epoch_str, ctl_loss, ctl_acc, ctl_baseline, ctl_reward, baseline, search_time.sum))

    # Pick this epoch's architecture and evaluate it with the shared weights.
    best_arch, _ = get_best_arch(controller, shared_cnn, valid_loader)
    shared_cnn.module.update_arch(best_arch)
    _, best_valid_acc, _ = valid_func(valid_loader, shared_cnn, criterion)

    genotypes[epoch] = best_arch
    # check the best accuracy
    valid_accuracies[epoch] = best_valid_acc
    if best_valid_acc > valid_accuracies['best']:
      valid_accuracies['best'] = best_valid_acc
      genotypes['best']        = best_arch

    logger.log('<<<--->>> The {:}-th epoch : {:}'.format(epoch_str, genotypes[epoch]))
    # save checkpoint ('epoch' stores the next epoch to run when resuming)
    save_path = save_checkpoint({'epoch' : epoch + 1,
                'args'  : deepcopy(xargs),
                'baseline'    : baseline,
                'shared_cnn'  : shared_cnn.state_dict(),
                'controller'  : controller.state_dict(),
                'w_optimizer' : w_optimizer.state_dict(),
                'a_optimizer' : a_optimizer.state_dict(),
                'w_scheduler' : w_scheduler.state_dict(),
                'genotypes'   : genotypes,
                'valid_accuracies' : valid_accuracies},
                model_base_path, logger)
    last_info = save_checkpoint({
          'epoch': epoch + 1,
          'args' : deepcopy(xargs),
          'last_checkpoint': save_path,
          }, logger.path('info'), logger)

    if api is not None: logger.log('{:}'.format(api.query_by_arch( genotypes[epoch] )))
    # measure elapsed time
    epoch_time.update(time.time() - start_time)
    start_time = time.time()

  logger.log('\n' + '-'*100)
  logger.log('During searching, the best architecture is {:}'.format(genotypes['best']))
  logger.log('Its accuracy is {:.2f}%'.format(valid_accuracies['best']))
  logger.log('Randomly select {:} architectures and select the best.'.format(xargs.controller_num_samples))
  # Final selection; its duration is added to the accumulated search time.
  start_time = time.time()
  final_arch, _ = get_best_arch(controller, shared_cnn, valid_loader, xargs.controller_num_samples)
  search_time.update(time.time() - start_time)
  shared_cnn.module.update_arch(final_arch)
  final_loss, final_top1, final_top5 = valid_func(valid_loader, shared_cnn, criterion)
  logger.log('The Selected Final Architecture : {:}'.format(final_arch))
  logger.log('Loss={:.3f}, Accuracy@1={:.2f}%, Accuracy@5={:.2f}%'.format(final_loss, final_top1, final_top5))
  logger.log('ENAS : run {:} epochs, cost {:.1f} s, last-geno is {:}.'.format(total_epoch, search_time.sum, final_arch))
  if api is not None: logger.log('{:}'.format( api.query_by_arch(final_arch) ))
  logger.close()
Пример #2
0
def main(xargs):
  """Run the DARTS-V2 architecture search from setup to the final report.

  The function configures cuDNN, seeds and the logger, builds the search and
  validation loaders and the search model, resumes from the last saved
  checkpoint when the logger's 'info' file exists, and then for every epoch
  calls ``search_func``, evaluates on the validation loader, records the
  model's current genotype and saves a checkpoint.  Whenever the validation
  top-1 accuracy improves, the checkpoint is also copied to the 'best' path.

  Args:
    xargs: options object.  Attributes read here: ``workers``, ``rand_seed``,
      ``dataset``, ``data_path``, ``config_path``, ``search_space_name``,
      ``channel``, ``num_cells``, ``max_nodes``, ``track_running_stats``,
      ``arch_learning_rate``, ``arch_weight_decay``, ``arch_nas_dataset`` and
      ``print_freq``.

  Raises:
    AssertionError: if CUDA is not available.
  """
  assert torch.cuda.is_available(), 'CUDA is not available.'
  # cuDNN is enabled, with benchmark mode off and deterministic mode requested.
  torch.backends.cudnn.enabled   = True
  torch.backends.cudnn.benchmark = False
  torch.backends.cudnn.deterministic = True
  torch.set_num_threads( xargs.workers )
  prepare_seed(xargs.rand_seed)
  # Use the function's own argument; the previous code read a module-level
  # `args`, which only exists when this file is executed as a script.
  logger = prepare_logger(xargs)

  train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
  config = load_config(xargs.config_path, {'class_num': class_num, 'xshape': xshape}, logger)
  search_loader, _, valid_loader = get_nas_search_loaders(train_data, valid_data, xargs.dataset, 'configs/nas-benchmark/', config.batch_size, xargs.workers)
  logger.log('||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(search_loader), len(valid_loader), config.batch_size))
  logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config))

  search_space = get_search_spaces('cell', xargs.search_space_name)
  model_config = dict2config({'name': 'DARTS-V2', 'C': xargs.channel, 'N': xargs.num_cells,
                              'max_nodes': xargs.max_nodes, 'num_classes': class_num,
                              'space'    : search_space,
                              'affine'   : False, 'track_running_stats': bool(xargs.track_running_stats)}, None)
  search_model = get_cell_based_tiny_net(model_config)
  logger.log('search-model :\n{:}'.format(search_model))

  # w_* optimises the network weights, a_* optimises the architecture alphas.
  w_optimizer, w_scheduler, criterion = get_optim_scheduler(search_model.get_weights(), config)
  a_optimizer = torch.optim.Adam(search_model.get_alphas(), lr=xargs.arch_learning_rate, betas=(0.5, 0.999), weight_decay=xargs.arch_weight_decay)
  logger.log('w-optimizer : {:}'.format(w_optimizer))
  logger.log('a-optimizer : {:}'.format(a_optimizer))
  logger.log('w-scheduler : {:}'.format(w_scheduler))
  logger.log('criterion   : {:}'.format(criterion))
  flop, param  = get_model_infos(search_model, xshape)
  logger.log('FLOP = {:.2f} M, Params = {:.2f} MB'.format(flop, param))
  if xargs.arch_nas_dataset is None:
    api = None
  else:
    api = API(xargs.arch_nas_dataset)
  logger.log('{:} create API = {:} done'.format(time_string(), api))

  last_info, model_base_path, model_best_path = logger.path('info'), logger.path('model'), logger.path('best')
  # `network` is the DataParallel wrapper used for forward passes;
  # `search_model` stays the unwrapped module used for state dicts/genotypes.
  network, criterion = torch.nn.DataParallel(search_model).cuda(), criterion.cuda()

  if last_info.exists(): # automatically resume from previous checkpoint
    logger.log("=> loading checkpoint of the last-info '{:}' start".format(last_info))
    last_info   = torch.load(last_info)
    start_epoch = last_info['epoch']
    checkpoint  = torch.load(last_info['last_checkpoint'])
    genotypes   = checkpoint['genotypes']
    valid_accuracies = checkpoint['valid_accuracies']
    search_model.load_state_dict( checkpoint['search_model'] )
    w_scheduler.load_state_dict ( checkpoint['w_scheduler'] )
    w_optimizer.load_state_dict ( checkpoint['w_optimizer'] )
    a_optimizer.load_state_dict ( checkpoint['a_optimizer'] )
    logger.log("=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format(last_info, start_epoch))
  else:
    logger.log("=> do not find the last-info file : {:}".format(last_info))
    # Key -1 holds the genotype of the freshly initialised model.
    start_epoch, valid_accuracies, genotypes = 0, {'best': -1}, {-1: search_model.genotype()}

  # start training
  start_time, search_time, epoch_time, total_epoch = time.time(), AverageMeter(), AverageMeter(), config.epochs + config.warmup
  for epoch in range(start_epoch, total_epoch):
    w_scheduler.update(epoch, 0.0)
    need_time = 'Time Left: {:}'.format( convert_secs2time(epoch_time.val * (total_epoch-epoch), True) )
    epoch_str = '{:03d}-{:03d}'.format(epoch, total_epoch)
    min_LR    = min(w_scheduler.get_lr())
    logger.log('\n[Search the {:}-th epoch] {:}, LR={:}'.format(epoch_str, need_time, min_LR))

    search_w_loss, search_w_top1, search_w_top5 = search_func(search_loader, network, criterion, w_scheduler, w_optimizer, a_optimizer, epoch_str, xargs.print_freq, logger)
    search_time.update(time.time() - start_time)
    logger.log('[{:}] searching : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s'.format(epoch_str, search_w_loss, search_w_top1, search_w_top5, search_time.sum))
    valid_a_loss , valid_a_top1 , valid_a_top5  = valid_func(valid_loader, network, criterion)
    logger.log('[{:}] evaluate  : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5))
    # check the best accuracy
    valid_accuracies[epoch] = valid_a_top1
    find_best = valid_a_top1 > valid_accuracies['best']
    if find_best:
      valid_accuracies['best'] = valid_a_top1
      genotypes['best']        = search_model.genotype()

    genotypes[epoch] = search_model.genotype()
    logger.log('<<<--->>> The {:}-th epoch : {:}'.format(epoch_str, genotypes[epoch]))
    # save checkpoint ('epoch' stores the next epoch to run when resuming)
    save_path = save_checkpoint({'epoch' : epoch + 1,
                'args'  : deepcopy(xargs),
                'search_model': search_model.state_dict(),
                'w_optimizer' : w_optimizer.state_dict(),
                'a_optimizer' : a_optimizer.state_dict(),
                'w_scheduler' : w_scheduler.state_dict(),
                'genotypes'   : genotypes,
                'valid_accuracies' : valid_accuracies},
                model_base_path, logger)
    last_info = save_checkpoint({
          'epoch': epoch + 1,
          'args' : deepcopy(xargs),
          'last_checkpoint': save_path,
          }, logger.path('info'), logger)
    if find_best:
      logger.log('<<<--->>> The {:}-th epoch : find the highest validation accuracy : {:.2f}%.'.format(epoch_str, valid_a_top1))
      copy_checkpoint(model_base_path, model_best_path, logger)
    with torch.no_grad():
      logger.log('arch-parameters :\n{:}'.format( nn.functional.softmax(search_model.arch_parameters, dim=-1).cpu() ))
    if api is not None: logger.log('{:}'.format(api.query_by_arch( genotypes[epoch] )))
    # measure elapsed time
    epoch_time.update(time.time() - start_time)
    start_time = time.time()

  logger.log('\n' + '-'*100)
  # check the performance from the architecture dataset
  logger.log('DARTS-V2 : run {:} epochs, cost {:.1f} s, last-geno is {:}.'.format(total_epoch, search_time.sum, genotypes[total_epoch-1]))
  if api is not None: logger.log('{:}'.format( api.query_by_arch(genotypes[total_epoch-1]) ))
  logger.close()
Пример #3
0
def main(xargs):
    """Run the SETN architecture search from setup to the final report.

    The function configures cuDNN, seeds and the logger, loads the
    train/valid index split, builds the search and validation loaders and the
    search model, resumes from the last saved checkpoint when the logger's
    'info' file exists, and then for every epoch calls ``search_func``,
    selects a genotype with ``get_best_arch``, evaluates it in 'dynamic' mode
    and saves a checkpoint.  After the loop one more selection/evaluation is
    performed and reported.

    Args:
        xargs: options object.  Attributes read here: ``workers``,
            ``rand_seed``, ``dataset``, ``data_path``, ``config_path``,
            ``search_space_name``, ``channel``, ``num_cells``, ``max_nodes``,
            ``track_running_stats``, ``arch_learning_rate``,
            ``arch_weight_decay``, ``arch_nas_dataset``, ``print_freq`` and
            ``select_num``.

    Raises:
        AssertionError: if CUDA is not available or the dataset is not
            'cifar10'.
        ValueError: if the dataset name matches none of the handled cases.
    """
    assert torch.cuda.is_available(), 'CUDA is not available.'
    # cuDNN is enabled, with benchmark mode off and deterministic mode
    # requested.
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.set_num_threads(xargs.workers)
    prepare_seed(xargs.rand_seed)
    # Use the function's own argument; the previous code read a module-level
    # `args`, which only exists when this file is executed as a script.
    logger = prepare_logger(xargs)

    train_data, valid_data, xshape, class_num = get_datasets(
        xargs.dataset, xargs.data_path, -1)
    # NOTE: this assert currently restricts the run to 'cifar10'; the other
    # dataset branches below only become reachable if it is relaxed.
    assert xargs.dataset == 'cifar10', 'currently only support CIFAR-10'
    if xargs.dataset == 'cifar10' or xargs.dataset == 'cifar100':
        split_Fpath = 'configs/nas-benchmark/cifar-split.txt'
        cifar_split = load_config(split_Fpath, None, None)
        train_split, valid_split = cifar_split.train, cifar_split.valid
        logger.log('Load split file from {:}'.format(split_Fpath))
    elif xargs.dataset.startswith('ImageNet16'):
        split_Fpath = 'configs/nas-benchmark/{:}-split.txt'.format(
            xargs.dataset)
        imagenet16_split = load_config(split_Fpath, None, None)
        train_split, valid_split = imagenet16_split.train, imagenet16_split.valid
        logger.log('Load split file from {:}'.format(split_Fpath))
    else:
        raise ValueError('invalid dataset : {:}'.format(xargs.dataset))
    config = load_config(xargs.config_path, {
        'class_num': class_num,
        'xshape': xshape
    }, logger)
    # To split data: the validation loader draws from a copy of the training
    # set that uses the validation transform, restricted to `valid_split`.
    train_data_v2 = deepcopy(train_data)
    train_data_v2.transform = valid_data.transform
    valid_data = train_data_v2
    search_data = SearchDataset(xargs.dataset, train_data, train_split,
                                valid_split)
    # data loader
    search_loader = torch.utils.data.DataLoader(search_data,
                                                batch_size=config.batch_size,
                                                shuffle=True,
                                                num_workers=xargs.workers,
                                                pin_memory=True)
    valid_loader = torch.utils.data.DataLoader(
        valid_data,
        batch_size=config.test_batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split),
        num_workers=xargs.workers,
        pin_memory=True)
    logger.log(
        '||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'
        .format(xargs.dataset, len(search_loader), len(valid_loader),
                config.batch_size))
    logger.log('||||||| {:10s} ||||||| Config={:}'.format(
        xargs.dataset, config))

    search_space = get_search_spaces('cell', xargs.search_space_name)
    model_config = dict2config(
        {
            'name': 'SETN',
            'C': xargs.channel,
            'N': xargs.num_cells,
            'max_nodes': xargs.max_nodes,
            'num_classes': class_num,
            'space': search_space,
            'affine': False,
            'track_running_stats': bool(xargs.track_running_stats)
        }, None)
    logger.log('search space : {:}'.format(search_space))
    search_model = get_cell_based_tiny_net(model_config)

    # w_* optimises the network weights, a_* optimises the architecture
    # alphas.
    w_optimizer, w_scheduler, criterion = get_optim_scheduler(
        search_model.get_weights(), config)
    a_optimizer = torch.optim.Adam(search_model.get_alphas(),
                                   lr=xargs.arch_learning_rate,
                                   betas=(0.5, 0.999),
                                   weight_decay=xargs.arch_weight_decay)
    logger.log('w-optimizer : {:}'.format(w_optimizer))
    logger.log('a-optimizer : {:}'.format(a_optimizer))
    logger.log('w-scheduler : {:}'.format(w_scheduler))
    logger.log('criterion   : {:}'.format(criterion))
    flop, param = get_model_infos(search_model, xshape)
    logger.log('FLOP = {:.2f} M, Params = {:.2f} MB'.format(flop, param))
    logger.log('search-space : {:}'.format(search_space))
    if xargs.arch_nas_dataset is None:
        api = None
    else:
        api = API(xargs.arch_nas_dataset)
    logger.log('{:} create API = {:} done'.format(time_string(), api))

    last_info, model_base_path, model_best_path = logger.path(
        'info'), logger.path('model'), logger.path('best')
    # `network` is the DataParallel wrapper used for forward passes;
    # `search_model` stays the unwrapped module used for state dicts.
    network, criterion = torch.nn.DataParallel(
        search_model).cuda(), criterion.cuda()

    if last_info.exists():  # automatically resume from previous checkpoint
        logger.log("=> loading checkpoint of the last-info '{:}' start".format(
            last_info))
        last_info = torch.load(last_info)
        start_epoch = last_info['epoch']
        checkpoint = torch.load(last_info['last_checkpoint'])
        genotypes = checkpoint['genotypes']
        valid_accuracies = checkpoint['valid_accuracies']
        search_model.load_state_dict(checkpoint['search_model'])
        w_scheduler.load_state_dict(checkpoint['w_scheduler'])
        w_optimizer.load_state_dict(checkpoint['w_optimizer'])
        a_optimizer.load_state_dict(checkpoint['a_optimizer'])
        logger.log(
            "=> loading checkpoint of the last-info '{:}' start with {:}-th epoch."
            .format(last_info, start_epoch))
    else:
        logger.log("=> do not find the last-info file : {:}".format(last_info))
        start_epoch, valid_accuracies, genotypes = 0, {'best': -1}, {}

    # start training
    start_time, search_time, epoch_time, total_epoch = time.time(
    ), AverageMeter(), AverageMeter(), config.epochs + config.warmup
    for epoch in range(start_epoch, total_epoch):
        w_scheduler.update(epoch, 0.0)
        need_time = 'Time Left: {:}'.format(
            convert_secs2time(epoch_time.val * (total_epoch - epoch), True))
        epoch_str = '{:03d}-{:03d}'.format(epoch, total_epoch)
        logger.log('\n[Search the {:}-th epoch] {:}, LR={:}'.format(
            epoch_str, need_time, min(w_scheduler.get_lr())))

        search_w_loss, search_w_top1, search_w_top5, search_a_loss, search_a_top1, search_a_top5 \
                    = search_func(search_loader, network, criterion, w_scheduler, w_optimizer, a_optimizer, epoch_str, xargs.print_freq, logger)
        search_time.update(time.time() - start_time)
        logger.log(
            '[{:}] search [base] : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s'
            .format(epoch_str, search_w_loss, search_w_top1, search_w_top5,
                    search_time.sum))
        logger.log(
            '[{:}] search [arch] : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'
            .format(epoch_str, search_a_loss, search_a_top1, search_a_top5))

        # Select this epoch's genotype, then evaluate the network with that
        # genotype fixed ('dynamic' mode).
        genotype, _ = get_best_arch(valid_loader, network, xargs.select_num)
        network.module.set_cal_mode('dynamic', genotype)
        valid_a_loss, valid_a_top1, valid_a_top5 = valid_func(
            valid_loader, network, criterion)
        logger.log(
            '[{:}] evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}% | {:}'
            .format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5,
                    genotype))
        # record this epoch's accuracy and genotype
        valid_accuracies[epoch] = valid_a_top1

        genotypes[epoch] = genotype
        logger.log('<<<--->>> The {:}-th epoch : {:}'.format(
            epoch_str, genotypes[epoch]))
        # save checkpoint ('epoch' stores the next epoch to run when resuming)
        save_path = save_checkpoint(
            {
                'epoch': epoch + 1,
                'args': deepcopy(xargs),
                'search_model': search_model.state_dict(),
                'w_optimizer': w_optimizer.state_dict(),
                'a_optimizer': a_optimizer.state_dict(),
                'w_scheduler': w_scheduler.state_dict(),
                'genotypes': genotypes,
                'valid_accuracies': valid_accuracies
            }, model_base_path, logger)
        last_info = save_checkpoint(
            {
                'epoch': epoch + 1,
                'args': deepcopy(xargs),
                'last_checkpoint': save_path,
            }, logger.path('info'), logger)
        with torch.no_grad():
            logger.log('arch-parameters :\n{:}'.format(
                nn.functional.softmax(search_model.arch_parameters,
                                      dim=-1).cpu()))
        if api is not None:
            logger.log('{:}'.format(api.query_by_arch(genotypes[epoch])))
        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()

    # the final post procedure : count the time
    start_time = time.time()
    genotype, _ = get_best_arch(valid_loader, network, xargs.select_num)
    search_time.update(time.time() - start_time)
    network.module.set_cal_mode('dynamic', genotype)
    valid_a_loss, valid_a_top1, valid_a_top5 = valid_func(
        valid_loader, network, criterion)
    logger.log(
        'Last : the gentotype is : {:}, with the validation accuracy of {:.3f}%.'
        .format(genotype, valid_a_top1))

    logger.log('\n' + '-' * 100)
    # check the performance from the architecture dataset
    logger.log(
        'SETN : run {:} epochs, cost {:.1f} s, last-geno is {:}.'.format(
            total_epoch, search_time.sum, genotype))
    if api is not None: logger.log('{:}'.format(api.query_by_arch(genotype)))
    logger.close()
Пример #4
0
def main(xargs):
    """Run the RANDOM (random-search with shared weights) NAS procedure.

    The function configures cuDNN, seeds and the logger, builds the search and
    validation loaders and the search model, resumes from the last saved
    checkpoint when the logger's 'info' file exists, and then for every epoch
    calls ``search_func``, evaluates on the validation loader, asks
    ``search_find_best`` for the current best architecture and saves a
    checkpoint.  Whenever the validation top-1 accuracy improves, the
    checkpoint is also copied to the 'best' path.  After the loop a final
    ``search_find_best`` call picks the reported architecture.

    Args:
        xargs: options object.  Attributes read here: ``workers``,
            ``rand_seed``, ``dataset``, ``data_path``, ``config_path``,
            ``search_space_name``, ``channel``, ``num_cells``, ``max_nodes``,
            ``track_running_stats``, ``arch_nas_dataset``, ``print_freq`` and
            ``select_num``.

    Raises:
        AssertionError: if CUDA is not available.
    """
    assert torch.cuda.is_available(), 'CUDA is not available.'
    # cuDNN is enabled, with benchmark mode off and deterministic mode
    # requested.
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.set_num_threads(xargs.workers)
    prepare_seed(xargs.rand_seed)
    # Use the function's own argument; the previous code read a module-level
    # `args`, which only exists when this file is executed as a script.
    logger = prepare_logger4(xargs)

    train_data, valid_data, xshape, class_num = get_datasets(
        xargs.dataset, xargs.data_path, -1)
    config = load_config(xargs.config_path, {
        'class_num': class_num,
        'xshape': xshape
    }, logger)
    search_loader, _, valid_loader = get_nas_search_loaders(
        train_data, valid_data, xargs.dataset, 'configs/nas-benchmark/',
        (config.batch_size, config.test_batch_size), xargs.workers)
    logger.log(
        '||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'
        .format(xargs.dataset, len(search_loader), len(valid_loader),
                config.batch_size))
    logger.log('||||||| {:10s} ||||||| Config={:}'.format(
        xargs.dataset, config))

    search_space = get_sub_search_spaces('cell', xargs.search_space_name)
    logger.log('search_space={}'.format(search_space))
    model_config = dict2config(
        {
            'name': 'RANDOM',
            'C': xargs.channel,
            'N': xargs.num_cells,
            'max_nodes': xargs.max_nodes,
            'num_classes': class_num,
            'space': search_space,
            'affine': False,
            'track_running_stats': bool(xargs.track_running_stats)
        }, None)
    search_model = get_cell_based_tiny_net(model_config)

    # Only a weight optimizer is created here; there is no architecture
    # optimizer in this procedure.
    w_optimizer, w_scheduler, criterion = get_optim_scheduler(
        search_model.parameters(), config)
    logger.log('w-optimizer : {:}'.format(w_optimizer))
    logger.log('w-scheduler : {:}'.format(w_scheduler))
    logger.log('criterion   : {:}'.format(criterion))
    if xargs.arch_nas_dataset is None:
        api = None
    else:
        api = API(xargs.arch_nas_dataset)
    logger.log('{:} create API = {:} done'.format(time_string(), api))

    last_info, model_base_path, model_best_path = logger.path(
        'info'), logger.path('model'), logger.path('best')
    # `network` is the DataParallel wrapper used for forward passes;
    # `search_model` stays the unwrapped module used for state dicts.
    network, criterion = torch.nn.DataParallel(
        search_model).cuda(), criterion.cuda()

    if last_info.exists():  # automatically resume from previous checkpoint
        logger.log("=> loading checkpoint of the last-info '{:}' start".format(
            last_info))
        last_info = torch.load(last_info)
        start_epoch = last_info['epoch']
        checkpoint = torch.load(last_info['last_checkpoint'])
        genotypes = checkpoint['genotypes']
        valid_accuracies = checkpoint['valid_accuracies']
        search_model.load_state_dict(checkpoint['search_model'])
        w_scheduler.load_state_dict(checkpoint['w_scheduler'])
        w_optimizer.load_state_dict(checkpoint['w_optimizer'])
        logger.log(
            "=> loading checkpoint of the last-info '{:}' start with {:}-th epoch."
            .format(last_info, start_epoch))
    else:
        logger.log("=> do not find the last-info file : {:}".format(last_info))
        start_epoch, valid_accuracies, genotypes = 0, {'best': -1}, {}

    # start training
    start_time, search_time, epoch_time, total_epoch = time.time(
    ), AverageMeter(), AverageMeter(), config.epochs + config.warmup
    for epoch in range(start_epoch, total_epoch):
        w_scheduler.update(epoch, 0.0)
        need_time = 'Time Left: {:}'.format(
            convert_secs2time(epoch_time.val * (total_epoch - epoch), True))
        epoch_str = '{:03d}-{:03d}'.format(epoch, total_epoch)
        logger.log('\n[Search the {:}-th epoch] {:}, LR={:}'.format(
            epoch_str, need_time, min(w_scheduler.get_lr())))

        search_w_loss, search_w_top1, search_w_top5 = search_func(
            search_loader, network, criterion, w_scheduler, w_optimizer,
            epoch_str, xargs.print_freq, logger)
        search_time.update(time.time() - start_time)
        logger.log(
            '[{:}] searching : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s'
            .format(epoch_str, search_w_loss, search_w_top1, search_w_top5,
                    search_time.sum))
        valid_a_loss, valid_a_top1, valid_a_top5 = valid_func(
            valid_loader, network, criterion)
        logger.log(
            '[{:}] evaluate  : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'
            .format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5))
        cur_arch, cur_valid_acc = search_find_best(valid_loader, network,
                                                   xargs.select_num)
        logger.log('[{:}] find-the-best : {:}, accuracy@1={:.2f}%'.format(
            epoch_str, cur_arch, cur_valid_acc))
        genotypes[epoch] = cur_arch
        # check the best accuracy
        valid_accuracies[epoch] = valid_a_top1
        find_best = valid_a_top1 > valid_accuracies['best']
        if find_best:
            valid_accuracies['best'] = valid_a_top1

        # save checkpoint ('epoch' stores the next epoch to run when resuming)
        save_path = save_checkpoint(
            {
                'epoch': epoch + 1,
                'args': deepcopy(xargs),
                'search_model': search_model.state_dict(),
                'w_optimizer': w_optimizer.state_dict(),
                'w_scheduler': w_scheduler.state_dict(),
                'genotypes': genotypes,
                'valid_accuracies': valid_accuracies
            }, model_base_path, logger)
        last_info = save_checkpoint(
            {
                'epoch': epoch + 1,
                'args': deepcopy(xargs),
                'last_checkpoint': save_path,
            }, logger.path('info'), logger)
        if find_best:
            logger.log(
                '<<<--->>> The {:}-th epoch : find the highest validation accuracy : {:.2f}%.'
                .format(epoch_str, valid_a_top1))
            # A leftover pdb.set_trace() used to sit here and blocked the run
            # on every new best; it has been removed.
            copy_checkpoint(model_base_path, model_best_path, logger)
        if api is not None:
            logger.log('{:}'.format(api.query_by_arch(genotypes[epoch])))
        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()

    logger.log('\n' + '-' * 200)
    logger.log('Pre-searching costs {:.1f} s'.format(search_time.sum))
    # Final selection; its duration is added to the accumulated search time.
    start_time = time.time()
    best_arch, best_acc = search_find_best(valid_loader, network,
                                           xargs.select_num)
    search_time.update(time.time() - start_time)
    logger.log(
        'RANDOM-NAS finds the best one : {:} with accuracy={:.2f}%, with {:.1f} s.'
        .format(best_arch, best_acc, search_time.sum))
    if api is not None: logger.log('{:}'.format(api.query_by_arch(best_arch)))
    logger.close()
Пример #5
0
def main(xargs):
    """Evaluate every architecture of a trained SPOS supernet.

    The supernet weights are loaded from a checkpoint whose path is built
    from ``xargs.dataset``, ``xargs.rand_seed`` and ``xargs.epoch``.  Every
    architecture returned by ``get_all_archs()`` is then visited in shuffled
    order: the network is switched to it with ``set_cal_mode('dynamic', ...)``,
    ``recalculate_bn`` is called with the search loader, and ``valid_func``
    is run on the validation loader.  The resulting mapping
    ``{genotype.tostr(): top-1 accuracy}`` is saved with ``torch.save`` to
    ``<xargs.save_dir>/result.dat``.

    Args:
        xargs: parsed command-line namespace.  Attributes read here:
            ``workers``, ``rand_seed``, ``dataset``, ``data_path``,
            ``config_path``, ``search_space_name``, ``channel``,
            ``num_cells``, ``max_nodes``, ``track_running_stats``,
            ``arch_learning_rate``, ``arch_weight_decay``,
            ``arch_nas_dataset``, ``epoch`` and ``save_dir``.

    Raises:
        AssertionError: if CUDA is not available.
    """
    assert torch.cuda.is_available(), 'CUDA is not available.'
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.set_num_threads(xargs.workers)
    prepare_seed(xargs.rand_seed)
    # Use the function parameter instead of a module-level ``args`` global so
    # that main() also works when it is imported and called directly.
    logger = prepare_logger(xargs)

    train_data, valid_data, xshape, class_num = get_datasets(
        xargs.dataset, xargs.data_path, -1)
    config = load_config(xargs.config_path, {
        'class_num': class_num,
        'xshape': xshape
    }, logger)
    print(config)
    search_loader, _, valid_loader = get_nas_search_loaders(
        train_data, valid_data, xargs.dataset, 'configs/nas-benchmark/',
        (config.batch_size, config.test_batch_size), xargs.workers)
    logger.log(
        '||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'
        .format(xargs.dataset, len(search_loader), len(valid_loader),
                config.batch_size))
    logger.log('||||||| {:10s} ||||||| Config={:}'.format(
        xargs.dataset, config))

    search_space = get_search_spaces('cell', xargs.search_space_name)
    model_config = dict2config(
        {
            'name': 'SPOS',
            'C': xargs.channel,
            'N': xargs.num_cells,
            'max_nodes': xargs.max_nodes,
            'num_classes': class_num,
            'space': search_space,
            'affine': False,
            'track_running_stats': bool(xargs.track_running_stats)
        }, None)
    logger.log('search space : {:}'.format(search_space))
    model = get_cell_based_tiny_net(model_config)

    # The optimizers and scheduler are only built and logged here; this
    # function performs no training step.
    w_optimizer, w_scheduler, criterion = get_optim_scheduler(
        model.get_weights(), config)
    a_optimizer = torch.optim.Adam(model.get_alphas(),
                                   lr=xargs.arch_learning_rate,
                                   betas=(0.5, 0.999),
                                   weight_decay=xargs.arch_weight_decay)
    logger.log('w-optimizer : {:}'.format(w_optimizer))
    logger.log('a-optimizer : {:}'.format(a_optimizer))
    logger.log('w-scheduler : {:}'.format(w_scheduler))
    logger.log('criterion   : {:}'.format(criterion))
    flop, param = get_model_infos(model, xshape)
    logger.log('FLOP = {:.2f} M, Params = {:.2f} MB'.format(flop, param))
    logger.log('search-space : {:}'.format(search_space))
    if xargs.arch_nas_dataset is None:
        api = None
    else:
        api = API(xargs.arch_nas_dataset)
    logger.log('{:} create API = {:} done'.format(time_string(), api))

    last_info, model_base_path, model_best_path = logger.path(
        'info'), logger.path('model'), logger.path('best')
    network, criterion = torch.nn.DataParallel(model).cuda(), criterion.cuda()

    # Load the trained supernet weights.  The path is always a string, so the
    # load is unconditional.
    checkpoint_path = 'output/search-cell-nas-bench-102/result-{}/checkpoint/seed-{}_epoch-{}.pth'.format(
        xargs.dataset, xargs.rand_seed, xargs.epoch)
    logger.log("=> loading checkpoint from {}".format(checkpoint_path))
    checkpoint = torch.load(checkpoint_path)
    model.load_state_dict(checkpoint['search_model'])

    # start inference
    search_time = AverageMeter()
    total_epoch = config.epochs + config.warmup
    all_archs = network.module.get_all_archs()
    random.shuffle(all_archs)

    valid_accuracies = {}
    # Stays None when there is no architecture to evaluate, so the summary
    # below cannot hit an unbound loop variable.
    genotype = None
    process_start_time = time.time()
    for i, genotype in enumerate(all_archs):
        network.module.set_cal_mode('dynamic', genotype)
        recalculate_bn(network, search_loader)
        valid_a_loss, valid_a_top1, valid_a_top5 = valid_func(
            valid_loader, network, criterion)
        logger.log(
            '[{:}] evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}% | {:}'
            .format(i, valid_a_loss, valid_a_top1, valid_a_top5, genotype))
        valid_accuracies[genotype.tostr()] = valid_a_top1
    process_time = time.time() - process_start_time
    # Record the elapsed time so the "cost" in the summary is the real
    # evaluation time instead of always 0.0.
    search_time.update(process_time)
    logger.log('process time: {}'.format(process_time))

    torch.save(valid_accuracies, '{}/result.dat'.format(xargs.save_dir))
    logger.log('\n' + '-' * 100)
    # check the performance from the architecture dataset
    logger.log(
        'SPOS : run {:} epochs, cost {:.1f} s, last-geno is {:}.'.format(
            total_epoch, search_time.sum, genotype))
    if api is not None and genotype is not None:
        logger.log('{:}'.format(api.query_by_arch(genotype)))
    logger.close()
Пример #6
0
def main(xargs):
  """Train a weight-sharing network and pick architectures by random search.

  Builds a ``'RANDOM'`` cell-based network, trains its weights epoch by epoch
  with ``search_func`` and, after every epoch, evaluates it with
  ``valid_func`` and asks ``search_find_best`` for the best of
  ``xargs.select_num`` candidates.  A checkpoint is written after every
  epoch and, when the last-info file already exists, the run resumes from it.
  After the final epoch ``search_find_best`` is called once more and its
  result is logged as the outcome of the search.

  Args:
    xargs: parsed command-line namespace.  Attributes read here:
      ``workers``, ``rand_seed``, ``dataset``, ``data_path``,
      ``config_path``, ``search_space_name``, ``channel``, ``num_cells``,
      ``max_nodes``, ``arch_nas_dataset``, ``print_freq`` and
      ``select_num``.

  Raises:
    AssertionError: if CUDA is not available.
    ValueError: if ``xargs.dataset`` is neither ``'cifar10'`` nor
      ``'cifar100'``.
  """
  assert torch.cuda.is_available(), 'CUDA is not available.'
  torch.backends.cudnn.enabled   = True
  torch.backends.cudnn.benchmark = False
  torch.backends.cudnn.deterministic = True
  torch.set_num_threads( xargs.workers )
  prepare_seed(xargs.rand_seed)
  # Use the function parameter instead of a module-level ``args`` global so
  # that main() also works when it is imported and called directly.
  logger = prepare_logger(xargs)

  train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
  # Only the CIFAR split file is handled by this variant; any other dataset
  # (including ImageNet16-*) is rejected.
  if xargs.dataset in ('cifar10', 'cifar100'):
    split_Fpath = 'configs/nas-benchmark/cifar-split.txt'
    cifar_split = load_config(split_Fpath, None, None)
    train_split, valid_split = cifar_split.train, cifar_split.valid
    logger.log('Load split file from {:}'.format(split_Fpath))
  else:
    raise ValueError('invalid dataset : {:}'.format(xargs.dataset))
  config = load_config(xargs.config_path, {'class_num': class_num, 'xshape': xshape}, logger)
  logger.log('config : {:}'.format(config))
  # To split data: the validation loader reads a copy of the training set
  # that uses the validation transform, restricted to ``valid_split``.
  train_data_v2 = deepcopy(train_data)
  train_data_v2.transform = valid_data.transform
  valid_data    = train_data_v2
  search_data   = SearchDataset(xargs.dataset, train_data, train_split, valid_split)
  # data loader
  search_loader = torch.utils.data.DataLoader(search_data, batch_size=config.batch_size, shuffle=True , num_workers=xargs.workers, pin_memory=True)
  valid_loader  = torch.utils.data.DataLoader(valid_data, batch_size=config.test_batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=xargs.workers, pin_memory=True)
  logger.log('||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(search_loader), len(valid_loader), config.batch_size))
  logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config))

  search_space = get_search_spaces('cell', xargs.search_space_name)
  model_config = dict2config({'name': 'RANDOM', 'C': xargs.channel, 'N': xargs.num_cells,
                              'max_nodes': xargs.max_nodes, 'num_classes': class_num,
                              'space'    : search_space}, None)
  search_model = get_cell_based_tiny_net(model_config)

  w_optimizer, w_scheduler, criterion = get_optim_scheduler(search_model.parameters(), config)
  logger.log('w-optimizer : {:}'.format(w_optimizer))
  logger.log('w-scheduler : {:}'.format(w_scheduler))
  logger.log('criterion   : {:}'.format(criterion))
  if xargs.arch_nas_dataset is None:
    api = None
  else:
    api = API(xargs.arch_nas_dataset)
  logger.log('{:} create API = {:} done'.format(time_string(), api))

  last_info, model_base_path, model_best_path = logger.path('info'), logger.path('model'), logger.path('best')
  network, criterion = torch.nn.DataParallel(search_model).cuda(), criterion.cuda()

  if last_info.exists(): # automatically resume from previous checkpoint
    logger.log("=> loading checkpoint of the last-info '{:}' start".format(last_info))
    # From here on ``last_info`` holds the loaded record, not the path.
    last_info   = torch.load(last_info)
    start_epoch = last_info['epoch']
    checkpoint  = torch.load(last_info['last_checkpoint'])
    genotypes   = checkpoint['genotypes']
    valid_accuracies = checkpoint['valid_accuracies']
    search_model.load_state_dict( checkpoint['search_model'] )
    w_scheduler.load_state_dict ( checkpoint['w_scheduler'] )
    w_optimizer.load_state_dict ( checkpoint['w_optimizer'] )
    logger.log("=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format(last_info, start_epoch))
  else:
    logger.log("=> do not find the last-info file : {:}".format(last_info))
    start_epoch, valid_accuracies, genotypes = 0, {'best': -1}, {}

  # start training
  start_time, search_time, epoch_time, total_epoch = time.time(), AverageMeter(), AverageMeter(), config.epochs + config.warmup
  for epoch in range(start_epoch, total_epoch):
    w_scheduler.update(epoch, 0.0)
    need_time = 'Time Left: {:}'.format( convert_secs2time(epoch_time.val * (total_epoch-epoch), True) )
    epoch_str = '{:03d}-{:03d}'.format(epoch, total_epoch)
    logger.log('\n[Search the {:}-th epoch] {:}, LR={:}'.format(epoch_str, need_time, min(w_scheduler.get_lr())))

    search_w_loss, search_w_top1, search_w_top5 = search_func(search_loader, network, criterion, w_scheduler, w_optimizer, epoch_str, xargs.print_freq, logger)
    search_time.update(time.time() - start_time)
    logger.log('[{:}] searching : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s'.format(epoch_str, search_w_loss, search_w_top1, search_w_top5, search_time.sum))
    valid_a_loss , valid_a_top1 , valid_a_top5  = valid_func(valid_loader, network, criterion)
    logger.log('[{:}] evaluate  : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5))
    cur_arch, cur_valid_acc = search_find_best(valid_loader, network, xargs.select_num)
    logger.log('[{:}] find-the-best : {:}, accuracy@1={:.2f}%'.format(epoch_str, cur_arch, cur_valid_acc))
    genotypes[epoch] = cur_arch
    # check the best accuracy
    valid_accuracies[epoch] = valid_a_top1
    find_best = valid_a_top1 > valid_accuracies['best']
    if find_best:
      valid_accuracies['best'] = valid_a_top1

    # save checkpoint
    save_path = save_checkpoint({'epoch' : epoch + 1,
                'args'  : deepcopy(xargs),
                'search_model': search_model.state_dict(),
                'w_optimizer' : w_optimizer.state_dict(),
                'w_scheduler' : w_scheduler.state_dict(),
                'genotypes'   : genotypes,
                'valid_accuracies' : valid_accuracies},
                model_base_path, logger)
    # The last-info record points at the checkpoint just written so that a
    # later run can resume from it.
    last_info = save_checkpoint({
          'epoch': epoch + 1,
          'args' : deepcopy(xargs),
          'last_checkpoint': save_path,
          }, logger.path('info'), logger)
    if find_best:
      logger.log('<<<--->>> The {:}-th epoch : find the highest validation accuracy : {:.2f}%.'.format(epoch_str, valid_a_top1))
      copy_checkpoint(model_base_path, model_best_path, logger)
    if api is not None:
      logger.log('{:}'.format(api.query_by_arch( genotypes[epoch] )))
    # measure elapsed time
    epoch_time.update(time.time() - start_time)
    start_time = time.time()

  logger.log('\n' + '-'*200)
  logger.log('Pre-searching costs {:.1f} s'.format(search_time.sum))
  # Final selection: its duration is added to the accumulated search time.
  start_time = time.time()
  best_arch, best_acc = search_find_best(valid_loader, network, xargs.select_num)
  search_time.update(time.time() - start_time)
  logger.log('RANDOM-NAS finds the best one : {:} with accuracy={:.2f}%, with {:.1f} s.'.format(best_arch, best_acc, search_time.sum))
  if api is not None:
    logger.log('{:}'.format( api.query_by_arch(best_arch) ))
  logger.close()
Пример #7
0
def main(xargs):
    """Run the DARTS-V2 architecture search.

    Builds a ``'DARTS-V2'`` cell-based network, then alternates per epoch
    between ``search_func`` (which receives both the weight optimizer and the
    architecture optimizer) and ``valid_func``.  After every epoch the
    current genotype is recorded and a checkpoint is written; when the
    last-info file already exists the run resumes from it.  Whenever the
    validation top-1 accuracy improves, the whole model object is also saved
    with ``torch.save`` to a snapshot file and the checkpoint is copied to
    the "best" path.

    Args:
        xargs: parsed command-line namespace.  Attributes read here:
            ``workers``, ``rand_seed``, ``dataset``, ``data_path``,
            ``search_space_name``, ``channel``, ``num_cells``,
            ``max_nodes``, ``arch_learning_rate``, ``arch_weight_decay``,
            ``arch_nas_dataset`` and ``print_freq``.

    Raises:
        AssertionError: if CUDA is not available.
        ValueError: if ``xargs.dataset`` is not ``'cifar10'``,
            ``'cifar100'`` or an ``'ImageNet16'``-prefixed name.
    """
    assert torch.cuda.is_available(), 'CUDA is not available.'
    torch.backends.cudnn.enabled   = True
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.set_num_threads( xargs.workers )
    prepare_seed(xargs.rand_seed)
    # Use the function parameter instead of a module-level ``args`` global so
    # that main() also works when it is imported and called directly.
    logger = prepare_logger(xargs)

    train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
    if xargs.dataset in ('cifar10', 'cifar100'):
        # NOTE(review): machine-specific absolute path; the ImageNet16 branch
        # below uses a path relative to the working directory instead.
        split_Fpath = '/home/city/Projects/NAS-Projects/configs/nas-benchmark/cifar-split.txt'
        cifar_split = load_config(split_Fpath, None, None)
        train_split, valid_split = cifar_split.train, cifar_split.valid
        logger.log('Load split file from {:}'.format(split_Fpath))
    elif xargs.dataset.startswith('ImageNet16'):
        split_Fpath = 'configs/nas-benchmark/{:}-split.txt'.format(xargs.dataset)
        imagenet16_split = load_config(split_Fpath, None, None)
        train_split, valid_split = imagenet16_split.train, imagenet16_split.valid
        logger.log('Load split file from {:}'.format(split_Fpath))
    else:
        raise ValueError('invalid dataset : {:}'.format(xargs.dataset))
    # NOTE(review): the config file is hard-coded here and any config path on
    # the command line is not consulted -- confirm this is intended.
    config_path = '/home/city/Projects/NAS-Projects/configs/nas-benchmark/algos/DARTS.config'
    config = load_config(config_path, {'class_num': class_num, 'xshape': xshape}, logger)
    print('config')
    print(config)
    # To split data: the validation loader reads a copy of the training set
    # that uses the validation transform, restricted to ``valid_split``.
    train_data_v2 = deepcopy(train_data)
    train_data_v2.transform = valid_data.transform
    valid_data    = train_data_v2
    search_data   = SearchDataset(xargs.dataset, train_data, train_split, valid_split)
    # data loader
    search_loader = torch.utils.data.DataLoader(search_data, batch_size=config.batch_size, shuffle=True , num_workers=xargs.workers, pin_memory=True)
    valid_loader  = torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=xargs.workers, pin_memory=True)
    logger.log('||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(search_loader), len(valid_loader), config.batch_size))
    logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config))

    search_space = get_search_spaces('cell', xargs.search_space_name)
    model_config = dict2config({'name': 'DARTS-V2', 'C': xargs.channel, 'N': xargs.num_cells,
                                'max_nodes': xargs.max_nodes, 'num_classes': class_num,
                                'space'    : search_space}, None)
    search_model = get_cell_based_tiny_net(model_config)
    logger.log('search-model :\n{:}'.format(search_model))

    w_optimizer, w_scheduler, criterion = get_optim_scheduler(search_model.get_weights(), config)
    a_optimizer = torch.optim.Adam(search_model.get_alphas(), lr=xargs.arch_learning_rate, betas=(0.5, 0.999), weight_decay=xargs.arch_weight_decay)
    logger.log('w-optimizer : {:}'.format(w_optimizer))
    logger.log('a-optimizer : {:}'.format(a_optimizer))
    logger.log('w-scheduler : {:}'.format(w_scheduler))
    logger.log('criterion   : {:}'.format(criterion))
    flop, param  = get_model_infos(search_model, xshape)
    logger.log('FLOP = {:.2f} M, Params = {:.2f} MB'.format(flop, param))
    if xargs.arch_nas_dataset is None:
        api = None
    else:
        api = API(xargs.arch_nas_dataset)
    logger.log('{:} create API = {:} done'.format(time_string(), api))

    last_info, model_base_path, model_best_path = logger.path('info'), logger.path('model'), logger.path('best')
    network, criterion = torch.nn.DataParallel(search_model).cuda(), criterion.cuda()

    if last_info.exists(): # automatically resume from previous checkpoint
        logger.log("=> loading checkpoint of the last-info '{:}' start".format(last_info))
        # From here on ``last_info`` holds the loaded record, not the path.
        last_info   = torch.load(last_info)
        start_epoch = last_info['epoch']
        checkpoint  = torch.load(last_info['last_checkpoint'])
        genotypes   = checkpoint['genotypes']
        valid_accuracies = checkpoint['valid_accuracies']
        search_model.load_state_dict( checkpoint['search_model'] )
        w_scheduler.load_state_dict ( checkpoint['w_scheduler'] )
        w_optimizer.load_state_dict ( checkpoint['w_optimizer'] )
        a_optimizer.load_state_dict ( checkpoint['a_optimizer'] )
        logger.log("=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format(last_info, start_epoch))
    else:
        logger.log("=> do not find the last-info file : {:}".format(last_info))
        start_epoch, valid_accuracies, genotypes = 0, {'best': -1}, {}

    # start training
    start_time, search_time, epoch_time, total_epoch = time.time(), AverageMeter(), AverageMeter(), config.epochs + config.warmup
    for epoch in range(start_epoch, total_epoch):
        w_scheduler.update(epoch, 0.0)
        need_time = 'Time Left: {:}'.format( convert_secs2time(epoch_time.val * (total_epoch-epoch), True) )
        epoch_str = '{:03d}-{:03d}'.format(epoch, total_epoch)
        min_LR    = min(w_scheduler.get_lr())
        logger.log('\n[Search the {:}-th epoch] {:}, LR={:}'.format(epoch_str, need_time, min_LR))

        search_w_loss, search_w_top1, search_w_top5 = search_func(search_loader, network, criterion, w_scheduler, w_optimizer, a_optimizer, epoch_str, xargs.print_freq, logger)
        search_time.update(time.time() - start_time)
        logger.log('[{:}] searching : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s'.format(epoch_str, search_w_loss, search_w_top1, search_w_top5, search_time.sum))
        valid_a_loss , valid_a_top1 , valid_a_top5  = valid_func(valid_loader, network, criterion)
        logger.log('[{:}] evaluate  : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5))
        # check the best accuracy
        valid_accuracies[epoch] = valid_a_top1
        if valid_a_top1 > valid_accuracies['best']:
            valid_accuracies['best'] = valid_a_top1
            genotypes['best']        = search_model.genotype()
            op_list, _ = genotypes['best'].tolist(remove_str=None)
            find_best = True
            best_arch_nums = op_list2str(op_list)
            # Build the snapshot path once: it embeds a timestamp, so
            # formatting it separately for saving and for printing could
            # report a file name different from the one actually written.
            snapshot_path = '/home/city/disk/log/builds-darts/darts2_%04d_%s_%s_%.2f.pth' % (epoch, time_string_short(), best_arch_nums, valid_a_top1)
            torch.save(search_model, snapshot_path)
            print(snapshot_path)
        else:
            find_best = False

        genotypes[epoch] = search_model.genotype()
        logger.log('<<<--->>> The {:}-th epoch : {:}'.format(epoch_str, genotypes[epoch]))
        # save checkpoint
        save_path = save_checkpoint({'epoch' : epoch + 1,
                                     'args'  : deepcopy(xargs),
                                     'search_model': search_model.state_dict(),
                                     'w_optimizer' : w_optimizer.state_dict(),
                                     'a_optimizer' : a_optimizer.state_dict(),
                                     'w_scheduler' : w_scheduler.state_dict(),
                                     'genotypes'   : genotypes,
                                     'valid_accuracies' : valid_accuracies},
                                    model_base_path, logger)
        # The last-info record points at the checkpoint just written so that
        # a later run can resume from it.
        last_info = save_checkpoint({
            'epoch': epoch + 1,
            'args' : deepcopy(xargs),
            'last_checkpoint': save_path,
        }, logger.path('info'), logger)
        if find_best:
            logger.log('<<<--->>> The {:}-th epoch : find the highest validation accuracy : {:.2f}%.'.format(epoch_str, valid_a_top1))
            copy_checkpoint(model_base_path, model_best_path, logger)
        # Log the softmax over the architecture parameters and the arg-max
        # choice per row, without tracking gradients.
        with torch.no_grad():
            logger.log('arch-parameters :\n{:}'.format( nn.functional.softmax(search_model.arch_parameters, dim=-1).cpu() ))
            logger.log('arch :\n{:}'.format(nn.functional.softmax(search_model.arch_parameters, dim=-1).cpu().argmax(dim=-1)))
        if api is not None:
            logger.log('{:}'.format(api.query_by_arch( genotypes[epoch] )))
        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()

    logger.log('\n' + '-'*100)
    # check the performance from the architecture dataset
    logger.log('DARTS-V2 : run {:} epochs, cost {:.1f} s, last-geno is {:}.'.format(total_epoch, search_time.sum, genotypes[total_epoch-1]))
    if api is not None:
        logger.log('{:}'.format( api.query_by_arch(genotypes[total_epoch-1]) ))
    logger.close()