def main(xargs):
    """Run ENAS architecture search on a cell-based NAS-Bench search space.

    Each epoch:
      1. trains the weight-shared CNN on the train loader (``train_shared_cnn``);
      2. trains the controller on the valid loader (``train_controller``);
      3. picks an architecture with ``get_best_arch`` and evaluates it with ``valid_func``.

    State is checkpointed after every epoch. A run resumes automatically when
    ``logger.path('info')`` exists. After ``config.epochs + config.warmup``
    epochs, ``xargs.controller_num_samples`` architectures are sampled and the
    selected one is evaluated and logged.

    Args:
        xargs: parsed command-line namespace (dataset, paths, seed, controller
            hyper-parameters, ...).
    """
    assert torch.cuda.is_available(), 'CUDA is not available.'
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.set_num_threads(xargs.workers)
    prepare_seed(xargs.rand_seed)
    # Use the namespace passed in, not a module-level ``args`` that only exists
    # when this file is executed as a script.
    logger = prepare_logger(xargs)

    train_data, test_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
    logger.log('use config from : {:}'.format(xargs.config_path))
    config = load_config(xargs.config_path, {'class_num': class_num, 'xshape': xshape}, logger)
    _, train_loader, valid_loader = get_nas_search_loaders(
        train_data, test_data, xargs.dataset, 'configs/nas-benchmark/', config.batch_size, xargs.workers)
    # since ENAS will train the controller on valid-loader, we need to use train transformation for valid-loader
    valid_loader.dataset.transform = deepcopy(train_loader.dataset.transform)
    if hasattr(valid_loader.dataset, 'transforms'):
        valid_loader.dataset.transforms = deepcopy(train_loader.dataset.transforms)
    # data loader
    logger.log('||||||| {:10s} ||||||| Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(
        xargs.dataset, len(train_loader), len(valid_loader), config.batch_size))
    logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config))

    search_space = get_sub_search_spaces('cell', xargs.search_space_name)
    logger.log('search_space={}'.format(search_space))
    model_config = dict2config({'name': 'ENAS', 'C': xargs.channel, 'N': xargs.num_cells,
                                'max_nodes': xargs.max_nodes, 'num_classes': class_num,
                                'space': search_space, 'affine': False,
                                'track_running_stats': bool(xargs.track_running_stats)}, None)
    shared_cnn = get_cell_based_tiny_net(model_config)
    controller = shared_cnn.create_controller()

    w_optimizer, w_scheduler, criterion = get_optim_scheduler(shared_cnn.parameters(), config)
    a_optimizer = torch.optim.Adam(controller.parameters(), lr=config.controller_lr,
                                   betas=config.controller_betas, eps=config.controller_eps)
    logger.log('w-optimizer : {:}'.format(w_optimizer))
    logger.log('a-optimizer : {:}'.format(a_optimizer))
    logger.log('w-scheduler : {:}'.format(w_scheduler))
    logger.log('criterion : {:}'.format(criterion))
    logger.log('search-space : {:}'.format(search_space))
    if xargs.arch_nas_dataset is None:
        api = None
    else:
        api = API(xargs.arch_nas_dataset)
    logger.log('{:} create API = {:} done'.format(time_string(), api))

    shared_cnn, controller, criterion = torch.nn.DataParallel(shared_cnn).cuda(), controller.cuda(), criterion.cuda()

    last_info_path, model_base_path = logger.path('info'), logger.path('model')
    if last_info_path.exists():  # automatically resume from previous checkpoint
        logger.log("=> loading checkpoint of the last-info '{:}' start".format(last_info_path))
        last_info = torch.load(last_info_path)
        start_epoch = last_info['epoch']
        checkpoint = torch.load(last_info['last_checkpoint'])
        genotypes = checkpoint['genotypes']
        baseline = checkpoint['baseline']
        valid_accuracies = checkpoint['valid_accuracies']
        shared_cnn.load_state_dict(checkpoint['shared_cnn'])
        controller.load_state_dict(checkpoint['controller'])
        w_scheduler.load_state_dict(checkpoint['w_scheduler'])
        w_optimizer.load_state_dict(checkpoint['w_optimizer'])
        a_optimizer.load_state_dict(checkpoint['a_optimizer'])
        # Log the path, not the dict that ``last_info`` now holds.
        logger.log("=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format(
            last_info_path, start_epoch))
    else:
        logger.log("=> do not find the last-info file : {:}".format(last_info_path))
        start_epoch, valid_accuracies, genotypes, baseline = 0, {'best': -1}, {}, None

    # start training
    start_time, search_time, epoch_time = time.time(), AverageMeter(), AverageMeter()
    total_epoch = config.epochs + config.warmup
    for epoch in range(start_epoch, total_epoch):
        w_scheduler.update(epoch, 0.0)
        need_time = 'Time Left: {:}'.format(convert_secs2time(epoch_time.val * (total_epoch - epoch), True))
        epoch_str = '{:03d}-{:03d}'.format(epoch, total_epoch)
        logger.log('\n[Search the {:}-th epoch] {:}, LR={:}, baseline={:}'.format(
            epoch_str, need_time, min(w_scheduler.get_lr()), baseline))

        cnn_loss, cnn_top1, cnn_top5 = train_shared_cnn(train_loader, shared_cnn, controller, criterion,
                                                        w_scheduler, w_optimizer, epoch_str, xargs.print_freq, logger)
        logger.log('[{:}] shared-cnn : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(
            epoch_str, cnn_loss, cnn_top1, cnn_top5))
        ctl_config = dict2config({'baseline': baseline,
                                  'ctl_train_steps': xargs.controller_train_steps,
                                  'ctl_num_aggre': xargs.controller_num_aggregate,
                                  'ctl_entropy_w': xargs.controller_entropy_weight,
                                  'ctl_bl_dec': xargs.controller_bl_dec}, None)
        ctl_loss, ctl_acc, ctl_baseline, ctl_reward, baseline = train_controller(
            valid_loader, shared_cnn, controller, criterion, a_optimizer, ctl_config,
            epoch_str, xargs.print_freq, logger)
        search_time.update(time.time() - start_time)
        logger.log('[{:}] controller : loss={:.2f}, accuracy={:.2f}%, baseline={:.2f}, reward={:.2f}, current-baseline={:.4f}, time-cost={:.1f} s'.format(
            epoch_str, ctl_loss, ctl_acc, ctl_baseline, ctl_reward, baseline, search_time.sum))

        best_arch, _ = get_best_arch(controller, shared_cnn, valid_loader)
        shared_cnn.module.update_arch(best_arch)
        _, best_valid_acc, _ = valid_func(valid_loader, shared_cnn, criterion)

        genotypes[epoch] = best_arch
        # check the best accuracy
        valid_accuracies[epoch] = best_valid_acc
        find_best = best_valid_acc > valid_accuracies['best']
        if find_best:
            valid_accuracies['best'] = best_valid_acc
            genotypes['best'] = best_arch
        logger.log('<<<--->>> The {:}-th epoch : {:}'.format(epoch_str, genotypes[epoch]))
        if find_best:
            logger.log('<<<--->>> The {:}-th epoch : find the highest validation accuracy : {:.2f}%.'.format(
                epoch_str, best_valid_acc))

        # save checkpoint
        save_path = save_checkpoint({'epoch': epoch + 1,
                                     'args': deepcopy(xargs),
                                     'baseline': baseline,
                                     'shared_cnn': shared_cnn.state_dict(),
                                     'controller': controller.state_dict(),
                                     'w_optimizer': w_optimizer.state_dict(),
                                     'a_optimizer': a_optimizer.state_dict(),
                                     'w_scheduler': w_scheduler.state_dict(),
                                     'genotypes': genotypes,
                                     'valid_accuracies': valid_accuracies},
                                    model_base_path, logger)
        save_checkpoint({'epoch': epoch + 1,
                         'args': deepcopy(xargs),
                         'last_checkpoint': save_path},
                        logger.path('info'), logger)
        if api is not None:
            logger.log('{:}'.format(api.query_by_arch(genotypes[epoch])))
        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()

    logger.log('\n' + '-' * 100)
    # ``.get`` avoids a KeyError when no epoch has run (e.g. total_epoch == 0).
    logger.log('During searching, the best architecture is {:}'.format(genotypes.get('best')))
    logger.log('Its accuracy is {:.2f}%'.format(valid_accuracies['best']))
    logger.log('Randomly select {:} architectures and select the best.'.format(xargs.controller_num_samples))
    start_time = time.time()
    final_arch, _ = get_best_arch(controller, shared_cnn, valid_loader, xargs.controller_num_samples)
    search_time.update(time.time() - start_time)
    shared_cnn.module.update_arch(final_arch)
    final_loss, final_top1, final_top5 = valid_func(valid_loader, shared_cnn, criterion)
    logger.log('The Selected Final Architecture : {:}'.format(final_arch))
    logger.log('Loss={:.3f}, Accuracy@1={:.2f}%, Accuracy@5={:.2f}%'.format(final_loss, final_top1, final_top5))
    logger.log('ENAS : run {:} epochs, cost {:.1f} s, last-geno is {:}.'.format(total_epoch, search_time.sum, final_arch))
    if api is not None:
        logger.log('{:}'.format(api.query_by_arch(final_arch)))
    logger.close()
def main(xargs):
    """Run the DARTS-V2 cell search on a NAS-Bench search space.

    Each epoch calls ``search_func`` (which steps both the weight and the
    architecture optimizers) and then ``valid_func``. The genotype and a
    checkpoint are saved after every epoch. When the validation top-1 improves,
    the checkpoint is also copied to ``logger.path('best')``. A run resumes
    automatically when ``logger.path('info')`` exists.

    Args:
        xargs: parsed command-line namespace (dataset, paths, seed, search
            hyper-parameters, ...).
    """
    assert torch.cuda.is_available(), 'CUDA is not available.'
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.set_num_threads(xargs.workers)
    prepare_seed(xargs.rand_seed)
    # Use the namespace passed in, not a module-level ``args`` that only exists
    # when this file is executed as a script.
    logger = prepare_logger(xargs)

    train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
    config = load_config(xargs.config_path, {'class_num': class_num, 'xshape': xshape}, logger)
    search_loader, _, valid_loader = get_nas_search_loaders(
        train_data, valid_data, xargs.dataset, 'configs/nas-benchmark/', config.batch_size, xargs.workers)
    logger.log('||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(
        xargs.dataset, len(search_loader), len(valid_loader), config.batch_size))
    logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config))

    search_space = get_search_spaces('cell', xargs.search_space_name)
    model_config = dict2config({'name': 'DARTS-V2', 'C': xargs.channel, 'N': xargs.num_cells,
                                'max_nodes': xargs.max_nodes, 'num_classes': class_num,
                                'space': search_space, 'affine': False,
                                'track_running_stats': bool(xargs.track_running_stats)}, None)
    search_model = get_cell_based_tiny_net(model_config)
    logger.log('search-model :\n{:}'.format(search_model))

    w_optimizer, w_scheduler, criterion = get_optim_scheduler(search_model.get_weights(), config)
    a_optimizer = torch.optim.Adam(search_model.get_alphas(), lr=xargs.arch_learning_rate,
                                   betas=(0.5, 0.999), weight_decay=xargs.arch_weight_decay)
    logger.log('w-optimizer : {:}'.format(w_optimizer))
    logger.log('a-optimizer : {:}'.format(a_optimizer))
    logger.log('w-scheduler : {:}'.format(w_scheduler))
    logger.log('criterion : {:}'.format(criterion))
    flop, param = get_model_infos(search_model, xshape)
    logger.log('FLOP = {:.2f} M, Params = {:.2f} MB'.format(flop, param))
    if xargs.arch_nas_dataset is None:
        api = None
    else:
        api = API(xargs.arch_nas_dataset)
    logger.log('{:} create API = {:} done'.format(time_string(), api))

    last_info_path, model_base_path, model_best_path = logger.path('info'), logger.path('model'), logger.path('best')
    network, criterion = torch.nn.DataParallel(search_model).cuda(), criterion.cuda()

    if last_info_path.exists():  # automatically resume from previous checkpoint
        logger.log("=> loading checkpoint of the last-info '{:}' start".format(last_info_path))
        last_info = torch.load(last_info_path)
        start_epoch = last_info['epoch']
        checkpoint = torch.load(last_info['last_checkpoint'])
        genotypes = checkpoint['genotypes']
        valid_accuracies = checkpoint['valid_accuracies']
        search_model.load_state_dict(checkpoint['search_model'])
        w_scheduler.load_state_dict(checkpoint['w_scheduler'])
        w_optimizer.load_state_dict(checkpoint['w_optimizer'])
        a_optimizer.load_state_dict(checkpoint['a_optimizer'])
        # Log the path, not the dict that ``last_info`` now holds.
        logger.log("=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format(
            last_info_path, start_epoch))
    else:
        logger.log("=> do not find the last-info file : {:}".format(last_info_path))
        # Key -1 holds the initial genotype, so ``genotypes[total_epoch - 1]``
        # below still resolves when total_epoch == 0.
        start_epoch, valid_accuracies, genotypes = 0, {'best': -1}, {-1: search_model.genotype()}

    # start training
    start_time, search_time, epoch_time = time.time(), AverageMeter(), AverageMeter()
    total_epoch = config.epochs + config.warmup
    for epoch in range(start_epoch, total_epoch):
        w_scheduler.update(epoch, 0.0)
        need_time = 'Time Left: {:}'.format(convert_secs2time(epoch_time.val * (total_epoch - epoch), True))
        epoch_str = '{:03d}-{:03d}'.format(epoch, total_epoch)
        min_LR = min(w_scheduler.get_lr())
        logger.log('\n[Search the {:}-th epoch] {:}, LR={:}'.format(epoch_str, need_time, min_LR))

        search_w_loss, search_w_top1, search_w_top5 = search_func(
            search_loader, network, criterion, w_scheduler, w_optimizer, a_optimizer,
            epoch_str, xargs.print_freq, logger)
        search_time.update(time.time() - start_time)
        logger.log('[{:}] searching : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s'.format(
            epoch_str, search_w_loss, search_w_top1, search_w_top5, search_time.sum))
        valid_a_loss, valid_a_top1, valid_a_top5 = valid_func(valid_loader, network, criterion)
        logger.log('[{:}] evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(
            epoch_str, valid_a_loss, valid_a_top1, valid_a_top5))

        # check the best accuracy
        valid_accuracies[epoch] = valid_a_top1
        find_best = valid_a_top1 > valid_accuracies['best']
        if find_best:
            valid_accuracies['best'] = valid_a_top1
            genotypes['best'] = search_model.genotype()
        genotypes[epoch] = search_model.genotype()
        logger.log('<<<--->>> The {:}-th epoch : {:}'.format(epoch_str, genotypes[epoch]))

        # save checkpoint
        save_path = save_checkpoint({'epoch': epoch + 1,
                                     'args': deepcopy(xargs),
                                     'search_model': search_model.state_dict(),
                                     'w_optimizer': w_optimizer.state_dict(),
                                     'a_optimizer': a_optimizer.state_dict(),
                                     'w_scheduler': w_scheduler.state_dict(),
                                     'genotypes': genotypes,
                                     'valid_accuracies': valid_accuracies},
                                    model_base_path, logger)
        save_checkpoint({'epoch': epoch + 1,
                         'args': deepcopy(xargs),
                         'last_checkpoint': save_path},
                        logger.path('info'), logger)
        if find_best:
            logger.log('<<<--->>> The {:}-th epoch : find the highest validation accuracy : {:.2f}%.'.format(
                epoch_str, valid_a_top1))
            copy_checkpoint(model_base_path, model_best_path, logger)
        with torch.no_grad():
            logger.log('arch-parameters :\n{:}'.format(
                nn.functional.softmax(search_model.arch_parameters, dim=-1).cpu()))
        if api is not None:
            logger.log('{:}'.format(api.query_by_arch(genotypes[epoch])))
        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()

    logger.log('\n' + '-' * 100)
    # check the performance from the architecture dataset
    last_geno = genotypes[total_epoch - 1]
    logger.log('DARTS-V2 : run {:} epochs, cost {:.1f} s, last-geno is {:}.'.format(
        total_epoch, search_time.sum, last_geno))
    if api is not None:
        logger.log('{:}'.format(api.query_by_arch(last_geno)))
    logger.close()
def main(xargs):
    """Run the SETN cell search (currently CIFAR-10 only).

    The search data come from ``SearchDataset`` built on the training split.
    Validation uses the training images from ``valid_split``, with the
    test-time transform of ``valid_data``.

    Each epoch:
      1. runs ``search_func``;
      2. picks a genotype with ``get_best_arch``;
      3. evaluates that genotype in ``'dynamic'`` mode;
      4. checkpoints.

    After the last epoch, a final ``get_best_arch`` call selects and evaluates
    the reported genotype. A run resumes automatically when
    ``logger.path('info')`` exists.

    Args:
        xargs: parsed command-line namespace (dataset, paths, seed, search
            hyper-parameters, ...).
    """
    assert torch.cuda.is_available(), 'CUDA is not available.'
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.set_num_threads(xargs.workers)
    prepare_seed(xargs.rand_seed)
    # Use the namespace passed in, not a module-level ``args`` that only exists
    # when this file is executed as a script.
    logger = prepare_logger(xargs)

    train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
    assert xargs.dataset == 'cifar10', 'currently only support CIFAR-10'
    # NOTE: while the assertion above stands, only the CIFAR branch is reachable;
    # the other branches are kept for when more datasets are enabled.
    if xargs.dataset == 'cifar10' or xargs.dataset == 'cifar100':
        split_Fpath = 'configs/nas-benchmark/cifar-split.txt'
        cifar_split = load_config(split_Fpath, None, None)
        train_split, valid_split = cifar_split.train, cifar_split.valid
        logger.log('Load split file from {:}'.format(split_Fpath))
    elif xargs.dataset.startswith('ImageNet16'):
        split_Fpath = 'configs/nas-benchmark/{:}-split.txt'.format(xargs.dataset)
        imagenet16_split = load_config(split_Fpath, None, None)
        train_split, valid_split = imagenet16_split.train, imagenet16_split.valid
        logger.log('Load split file from {:}'.format(split_Fpath))
    else:
        raise ValueError('invalid dataset : {:}'.format(xargs.dataset))
    config = load_config(xargs.config_path, {'class_num': class_num, 'xshape': xshape}, logger)

    # To split data: validation reuses the training images with the test transform.
    train_data_v2 = deepcopy(train_data)
    train_data_v2.transform = valid_data.transform
    valid_data = train_data_v2
    search_data = SearchDataset(xargs.dataset, train_data, train_split, valid_split)
    # data loader
    search_loader = torch.utils.data.DataLoader(search_data, batch_size=config.batch_size, shuffle=True,
                                                num_workers=xargs.workers, pin_memory=True)
    valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=config.test_batch_size,
                                               sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split),
                                               num_workers=xargs.workers, pin_memory=True)
    logger.log('||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(
        xargs.dataset, len(search_loader), len(valid_loader), config.batch_size))
    logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config))

    search_space = get_search_spaces('cell', xargs.search_space_name)
    model_config = dict2config({'name': 'SETN', 'C': xargs.channel, 'N': xargs.num_cells,
                                'max_nodes': xargs.max_nodes, 'num_classes': class_num,
                                'space': search_space, 'affine': False,
                                'track_running_stats': bool(xargs.track_running_stats)}, None)
    logger.log('search space : {:}'.format(search_space))
    search_model = get_cell_based_tiny_net(model_config)

    w_optimizer, w_scheduler, criterion = get_optim_scheduler(search_model.get_weights(), config)
    a_optimizer = torch.optim.Adam(search_model.get_alphas(), lr=xargs.arch_learning_rate,
                                   betas=(0.5, 0.999), weight_decay=xargs.arch_weight_decay)
    logger.log('w-optimizer : {:}'.format(w_optimizer))
    logger.log('a-optimizer : {:}'.format(a_optimizer))
    logger.log('w-scheduler : {:}'.format(w_scheduler))
    logger.log('criterion : {:}'.format(criterion))
    flop, param = get_model_infos(search_model, xshape)
    logger.log('FLOP = {:.2f} M, Params = {:.2f} MB'.format(flop, param))
    logger.log('search-space : {:}'.format(search_space))
    if xargs.arch_nas_dataset is None:
        api = None
    else:
        api = API(xargs.arch_nas_dataset)
    logger.log('{:} create API = {:} done'.format(time_string(), api))

    last_info_path, model_base_path = logger.path('info'), logger.path('model')
    network, criterion = torch.nn.DataParallel(search_model).cuda(), criterion.cuda()
    if last_info_path.exists():  # automatically resume from previous checkpoint
        logger.log("=> loading checkpoint of the last-info '{:}' start".format(last_info_path))
        last_info = torch.load(last_info_path)
        start_epoch = last_info['epoch']
        checkpoint = torch.load(last_info['last_checkpoint'])
        genotypes = checkpoint['genotypes']
        valid_accuracies = checkpoint['valid_accuracies']
        search_model.load_state_dict(checkpoint['search_model'])
        w_scheduler.load_state_dict(checkpoint['w_scheduler'])
        w_optimizer.load_state_dict(checkpoint['w_optimizer'])
        a_optimizer.load_state_dict(checkpoint['a_optimizer'])
        # Log the path, not the dict that ``last_info`` now holds.
        logger.log("=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format(
            last_info_path, start_epoch))
    else:
        logger.log("=> do not find the last-info file : {:}".format(last_info_path))
        start_epoch, valid_accuracies, genotypes = 0, {'best': -1}, {}

    # start training
    start_time, search_time, epoch_time = time.time(), AverageMeter(), AverageMeter()
    total_epoch = config.epochs + config.warmup
    for epoch in range(start_epoch, total_epoch):
        w_scheduler.update(epoch, 0.0)
        need_time = 'Time Left: {:}'.format(convert_secs2time(epoch_time.val * (total_epoch - epoch), True))
        epoch_str = '{:03d}-{:03d}'.format(epoch, total_epoch)
        logger.log('\n[Search the {:}-th epoch] {:}, LR={:}'.format(epoch_str, need_time, min(w_scheduler.get_lr())))

        search_w_loss, search_w_top1, search_w_top5, search_a_loss, search_a_top1, search_a_top5 = search_func(
            search_loader, network, criterion, w_scheduler, w_optimizer, a_optimizer,
            epoch_str, xargs.print_freq, logger)
        search_time.update(time.time() - start_time)
        logger.log('[{:}] search [base] : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s'.format(
            epoch_str, search_w_loss, search_w_top1, search_w_top5, search_time.sum))
        logger.log('[{:}] search [arch] : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(
            epoch_str, search_a_loss, search_a_top1, search_a_top5))

        genotype, _ = get_best_arch(valid_loader, network, xargs.select_num)
        network.module.set_cal_mode('dynamic', genotype)
        valid_a_loss, valid_a_top1, valid_a_top5 = valid_func(valid_loader, network, criterion)
        logger.log('[{:}] evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}% | {:}'.format(
            epoch_str, valid_a_loss, valid_a_top1, valid_a_top5, genotype))

        # record this epoch's result
        valid_accuracies[epoch] = valid_a_top1
        genotypes[epoch] = genotype
        logger.log('<<<--->>> The {:}-th epoch : {:}'.format(epoch_str, genotypes[epoch]))

        # save checkpoint
        save_path = save_checkpoint({'epoch': epoch + 1,
                                     'args': deepcopy(xargs),
                                     'search_model': search_model.state_dict(),
                                     'w_optimizer': w_optimizer.state_dict(),
                                     'a_optimizer': a_optimizer.state_dict(),
                                     'w_scheduler': w_scheduler.state_dict(),
                                     'genotypes': genotypes,
                                     'valid_accuracies': valid_accuracies},
                                    model_base_path, logger)
        save_checkpoint({'epoch': epoch + 1,
                         'args': deepcopy(xargs),
                         'last_checkpoint': save_path},
                        logger.path('info'), logger)
        with torch.no_grad():
            logger.log('arch-parameters :\n{:}'.format(
                nn.functional.softmax(search_model.arch_parameters, dim=-1).cpu()))
        if api is not None:
            logger.log('{:}'.format(api.query_by_arch(genotypes[epoch])))
        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()

    # the final post procedure : count the time
    start_time = time.time()
    genotype, _ = get_best_arch(valid_loader, network, xargs.select_num)
    search_time.update(time.time() - start_time)
    network.module.set_cal_mode('dynamic', genotype)
    valid_a_loss, valid_a_top1, valid_a_top5 = valid_func(valid_loader, network, criterion)
    logger.log('Last : the gentotype is : {:}, with the validation accuracy of {:.3f}%.'.format(genotype, valid_a_top1))

    logger.log('\n' + '-' * 100)
    # check the performance from the architecture dataset
    logger.log('SETN : run {:} epochs, cost {:.1f} s, last-geno is {:}.'.format(total_epoch, search_time.sum, genotype))
    if api is not None:
        logger.log('{:}'.format(api.query_by_arch(genotype)))
    logger.close()
def main(xargs):
    """Run RANDOM-NAS: train a weight-sharing supernet, then pick the best sampled arch.

    Each epoch:
      1. trains the supernet with ``search_func``;
      2. evaluates it with ``valid_func``;
      3. records the architecture returned by ``search_find_best``;
      4. checkpoints.

    The checkpoint is also copied to ``logger.path('best')`` whenever the
    validation top-1 improves. A final ``search_find_best`` call picks the
    reported architecture. A run resumes automatically when
    ``logger.path('info')`` exists.

    Args:
        xargs: parsed command-line namespace (dataset, paths, seed, search
            hyper-parameters, ...).
    """
    assert torch.cuda.is_available(), 'CUDA is not available.'
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.set_num_threads(xargs.workers)
    prepare_seed(xargs.rand_seed)
    # Use the namespace passed in, not a module-level ``args`` that only exists
    # when this file is executed as a script.
    logger = prepare_logger4(xargs)

    train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
    config = load_config(xargs.config_path, {'class_num': class_num, 'xshape': xshape}, logger)
    search_loader, _, valid_loader = get_nas_search_loaders(
        train_data, valid_data, xargs.dataset, 'configs/nas-benchmark/',
        (config.batch_size, config.test_batch_size), xargs.workers)
    logger.log('||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(
        xargs.dataset, len(search_loader), len(valid_loader), config.batch_size))
    logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config))

    search_space = get_sub_search_spaces('cell', xargs.search_space_name)
    logger.log('search_space={}'.format(search_space))
    model_config = dict2config({'name': 'RANDOM', 'C': xargs.channel, 'N': xargs.num_cells,
                                'max_nodes': xargs.max_nodes, 'num_classes': class_num,
                                'space': search_space, 'affine': False,
                                'track_running_stats': bool(xargs.track_running_stats)}, None)
    search_model = get_cell_based_tiny_net(model_config)

    w_optimizer, w_scheduler, criterion = get_optim_scheduler(search_model.parameters(), config)
    logger.log('w-optimizer : {:}'.format(w_optimizer))
    logger.log('w-scheduler : {:}'.format(w_scheduler))
    logger.log('criterion : {:}'.format(criterion))
    if xargs.arch_nas_dataset is None:
        api = None
    else:
        api = API(xargs.arch_nas_dataset)
    logger.log('{:} create API = {:} done'.format(time_string(), api))

    last_info_path, model_base_path, model_best_path = logger.path('info'), logger.path('model'), logger.path('best')
    network, criterion = torch.nn.DataParallel(search_model).cuda(), criterion.cuda()
    if last_info_path.exists():  # automatically resume from previous checkpoint
        logger.log("=> loading checkpoint of the last-info '{:}' start".format(last_info_path))
        last_info = torch.load(last_info_path)
        start_epoch = last_info['epoch']
        checkpoint = torch.load(last_info['last_checkpoint'])
        genotypes = checkpoint['genotypes']
        valid_accuracies = checkpoint['valid_accuracies']
        search_model.load_state_dict(checkpoint['search_model'])
        w_scheduler.load_state_dict(checkpoint['w_scheduler'])
        w_optimizer.load_state_dict(checkpoint['w_optimizer'])
        # Log the path, not the dict that ``last_info`` now holds.
        logger.log("=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format(
            last_info_path, start_epoch))
    else:
        logger.log("=> do not find the last-info file : {:}".format(last_info_path))
        start_epoch, valid_accuracies, genotypes = 0, {'best': -1}, {}

    # start training
    start_time, search_time, epoch_time = time.time(), AverageMeter(), AverageMeter()
    total_epoch = config.epochs + config.warmup
    for epoch in range(start_epoch, total_epoch):
        w_scheduler.update(epoch, 0.0)
        need_time = 'Time Left: {:}'.format(convert_secs2time(epoch_time.val * (total_epoch - epoch), True))
        epoch_str = '{:03d}-{:03d}'.format(epoch, total_epoch)
        logger.log('\n[Search the {:}-th epoch] {:}, LR={:}'.format(epoch_str, need_time, min(w_scheduler.get_lr())))

        search_w_loss, search_w_top1, search_w_top5 = search_func(
            search_loader, network, criterion, w_scheduler, w_optimizer, epoch_str, xargs.print_freq, logger)
        search_time.update(time.time() - start_time)
        logger.log('[{:}] searching : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s'.format(
            epoch_str, search_w_loss, search_w_top1, search_w_top5, search_time.sum))
        valid_a_loss, valid_a_top1, valid_a_top5 = valid_func(valid_loader, network, criterion)
        logger.log('[{:}] evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(
            epoch_str, valid_a_loss, valid_a_top1, valid_a_top5))
        cur_arch, cur_valid_acc = search_find_best(valid_loader, network, xargs.select_num)
        logger.log('[{:}] find-the-best : {:}, accuracy@1={:.2f}%'.format(epoch_str, cur_arch, cur_valid_acc))
        genotypes[epoch] = cur_arch

        # check the best accuracy
        valid_accuracies[epoch] = valid_a_top1
        find_best = valid_a_top1 > valid_accuracies['best']
        if find_best:
            valid_accuracies['best'] = valid_a_top1

        # save checkpoint
        save_path = save_checkpoint({'epoch': epoch + 1,
                                     'args': deepcopy(xargs),
                                     'search_model': search_model.state_dict(),
                                     'w_optimizer': w_optimizer.state_dict(),
                                     'w_scheduler': w_scheduler.state_dict(),
                                     'genotypes': genotypes,
                                     'valid_accuracies': valid_accuracies},
                                    model_base_path, logger)
        save_checkpoint({'epoch': epoch + 1,
                         'args': deepcopy(xargs),
                         'last_checkpoint': save_path},
                        logger.path('info'), logger)
        if find_best:
            logger.log('<<<--->>> The {:}-th epoch : find the highest validation accuracy : {:.2f}%.'.format(
                epoch_str, valid_a_top1))
            # The leftover ``pdb.set_trace()`` that used to sit here froze
            # unattended runs on every new best accuracy; it has been removed.
            copy_checkpoint(model_base_path, model_best_path, logger)
        if api is not None:
            logger.log('{:}'.format(api.query_by_arch(genotypes[epoch])))
        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()

    logger.log('\n' + '-' * 200)
    logger.log('Pre-searching costs {:.1f} s'.format(search_time.sum))
    start_time = time.time()
    best_arch, best_acc = search_find_best(valid_loader, network, xargs.select_num)
    search_time.update(time.time() - start_time)
    logger.log('RANDOM-NAS finds the best one : {:} with accuracy={:.2f}%, with {:.1f} s.'.format(
        best_arch, best_acc, search_time.sum))
    if api is not None:
        logger.log('{:}'.format(api.query_by_arch(best_arch)))
    logger.close()
def main(xargs):
    """Evaluate every architecture of a trained SPOS supernet.

    Loads the supernet checkpoint for ``xargs.dataset`` / ``xargs.rand_seed`` /
    ``xargs.epoch``. For every architecture from ``get_all_archs()``, visited
    in shuffled order, it:
      1. switches the network to that genotype;
      2. recalculates BN statistics on the search loader;
      3. evaluates on the validation loader.

    Top-1 accuracies keyed by ``genotype.tostr()`` are saved to
    ``<xargs.save_dir>/result.dat``.

    Args:
        xargs: parsed command-line namespace (dataset, paths, seed, epoch of
            the checkpoint to load, ...).
    """
    assert torch.cuda.is_available(), 'CUDA is not available.'
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.set_num_threads(xargs.workers)
    prepare_seed(xargs.rand_seed)
    # Use the namespace passed in, not a module-level ``args`` that only exists
    # when this file is executed as a script.
    logger = prepare_logger(xargs)

    train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
    config = load_config(xargs.config_path, {'class_num': class_num, 'xshape': xshape}, logger)
    search_loader, _, valid_loader = get_nas_search_loaders(
        train_data, valid_data, xargs.dataset, 'configs/nas-benchmark/',
        (config.batch_size, config.test_batch_size), xargs.workers)
    logger.log('||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(
        xargs.dataset, len(search_loader), len(valid_loader), config.batch_size))
    logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config))

    search_space = get_search_spaces('cell', xargs.search_space_name)
    model_config = dict2config({'name': 'SPOS', 'C': xargs.channel, 'N': xargs.num_cells,
                                'max_nodes': xargs.max_nodes, 'num_classes': class_num,
                                'space': search_space, 'affine': False,
                                'track_running_stats': bool(xargs.track_running_stats)}, None)
    logger.log('search space : {:}'.format(search_space))
    model = get_cell_based_tiny_net(model_config)

    w_optimizer, w_scheduler, criterion = get_optim_scheduler(model.get_weights(), config)
    a_optimizer = torch.optim.Adam(model.get_alphas(), lr=xargs.arch_learning_rate,
                                   betas=(0.5, 0.999), weight_decay=xargs.arch_weight_decay)
    logger.log('w-optimizer : {:}'.format(w_optimizer))
    logger.log('a-optimizer : {:}'.format(a_optimizer))
    logger.log('w-scheduler : {:}'.format(w_scheduler))
    logger.log('criterion : {:}'.format(criterion))
    flop, param = get_model_infos(model, xshape)
    logger.log('FLOP = {:.2f} M, Params = {:.2f} MB'.format(flop, param))
    logger.log('search-space : {:}'.format(search_space))
    if xargs.arch_nas_dataset is None:
        api = None
    else:
        api = API(xargs.arch_nas_dataset)
    logger.log('{:} create API = {:} done'.format(time_string(), api))
    network, criterion = torch.nn.DataParallel(model).cuda(), criterion.cuda()

    # The trained supernet is mandatory; the former ``if checkpoint_path is not None``
    # guard was always true. ``torch.load`` raises if the file is missing.
    checkpoint_path = 'output/search-cell-nas-bench-102/result-{}/checkpoint/seed-{}_epoch-{}.pth'.format(
        xargs.dataset, xargs.rand_seed, xargs.epoch)
    logger.log("=> loading checkpoint from {}".format(checkpoint_path))
    checkpoint = torch.load(checkpoint_path)
    model.load_state_dict(checkpoint['search_model'])

    # start inference
    search_time, total_epoch = AverageMeter(), config.epochs + config.warmup
    all_archs = network.module.get_all_archs()
    random.shuffle(all_archs)
    valid_accuracies = {}
    genotype = None  # stays None if there is nothing to evaluate
    process_start_time = time.time()
    for i, genotype in enumerate(all_archs):
        network.module.set_cal_mode('dynamic', genotype)
        # Re-estimate BN statistics for this sub-network before evaluating it.
        recalculate_bn(network, search_loader)
        valid_a_loss, valid_a_top1, valid_a_top5 = valid_func(valid_loader, network, criterion)
        logger.log('[{:}] evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}% | {:}'.format(
            i, valid_a_loss, valid_a_top1, valid_a_top5, genotype))
        valid_accuracies[genotype.tostr()] = valid_a_top1
    process_end_time = time.time()
    # Record the evaluation cost so the summary below no longer reports 0.0 s.
    search_time.update(process_end_time - process_start_time)
    logger.log('process time: {}'.format(process_end_time - process_start_time))
    torch.save(valid_accuracies, '{}/result.dat'.format(xargs.save_dir))

    logger.log('\n' + '-' * 100)
    # check the performance from the architecture dataset
    logger.log('SPOS : run {:} epochs, cost {:.1f} s, last-geno is {:}.'.format(
        total_epoch, search_time.sum, genotype))
    if api is not None and genotype is not None:
        logger.log('{:}'.format(api.query_by_arch(genotype)))
    logger.close()
def main(xargs):
    """Run the RANDOM-NAS weight-sharing search (model name 'RANDOM') on CIFAR.

    The shared weights are trained for ``config.epochs + config.warmup``
    epochs with ``search_func`` on the CIFAR search split. After every epoch
    the function:
      * evaluates the network on the validation split (``valid_func``);
      * asks ``search_find_best`` for the best of ``xargs.select_num``
        sampled architectures;
      * writes a checkpoint and an info file (``logger.path('info')``), and
        copies the checkpoint to ``logger.path('best')`` when the validation
        top-1 improves.

    If the info file already exists, the run resumes from the checkpoint it
    points to. When training ends, one more ``search_find_best`` pass reports
    the selected architecture.

    Args:
        xargs: parsed command-line namespace. Every setting this function
            reads is taken from it (dataset, data_path, config_path, workers,
            rand_seed, channel, num_cells, max_nodes, search_space_name,
            arch_nas_dataset, select_num, print_freq).

    Raises:
        ValueError: if ``xargs.dataset`` is neither ``cifar10`` nor
            ``cifar100`` (only the CIFAR split file is supported here).
    """
    assert torch.cuda.is_available(), 'CUDA is not available.'
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.set_num_threads(xargs.workers)
    prepare_seed(xargs.rand_seed)
    # Use the namespace actually passed in, not a module-level `args` global.
    logger = prepare_logger(xargs)

    train_data, valid_data, xshape, class_num = get_datasets(
        xargs.dataset, xargs.data_path, -1)
    if xargs.dataset == 'cifar10' or xargs.dataset == 'cifar100':
        split_Fpath = 'configs/nas-benchmark/cifar-split.txt'
        cifar_split = load_config(split_Fpath, None, None)
        train_split, valid_split = cifar_split.train, cifar_split.valid
        logger.log('Load split file from {:}'.format(split_Fpath))
    else:
        raise ValueError('invalid dataset : {:}'.format(xargs.dataset))
    config = load_config(xargs.config_path,
                         {'class_num': class_num, 'xshape': xshape}, logger)
    logger.log('config : {:}'.format(config))

    # The validation split indexes into the *training* set, evaluated with the
    # transform taken from the original `valid_data`.
    train_data_v2 = deepcopy(train_data)
    train_data_v2.transform = valid_data.transform
    valid_data = train_data_v2
    search_data = SearchDataset(xargs.dataset, train_data, train_split,
                                valid_split)
    # data loader
    search_loader = torch.utils.data.DataLoader(
        search_data, batch_size=config.batch_size, shuffle=True,
        num_workers=xargs.workers, pin_memory=True)
    valid_loader = torch.utils.data.DataLoader(
        valid_data, batch_size=config.test_batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split),
        num_workers=xargs.workers, pin_memory=True)
    logger.log('||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(
        xargs.dataset, len(search_loader), len(valid_loader), config.batch_size))
    logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config))

    search_space = get_search_spaces('cell', xargs.search_space_name)
    model_config = dict2config({'name': 'RANDOM', 'C': xargs.channel,
                                'N': xargs.num_cells,
                                'max_nodes': xargs.max_nodes,
                                'num_classes': class_num,
                                'space': search_space}, None)
    search_model = get_cell_based_tiny_net(model_config)

    w_optimizer, w_scheduler, criterion = get_optim_scheduler(
        search_model.parameters(), config)
    logger.log('w-optimizer : {:}'.format(w_optimizer))
    logger.log('w-scheduler : {:}'.format(w_scheduler))
    logger.log('criterion   : {:}'.format(criterion))
    if xargs.arch_nas_dataset is None:
        api = None
    else:
        api = API(xargs.arch_nas_dataset)
    logger.log('{:} create API = {:} done'.format(time_string(), api))

    last_info, model_base_path, model_best_path = (
        logger.path('info'), logger.path('model'), logger.path('best'))
    network, criterion = torch.nn.DataParallel(search_model).cuda(), criterion.cuda()

    if last_info.exists():  # automatically resume from previous checkpoint
        logger.log("=> loading checkpoint of the last-info '{:}' start".format(last_info))
        # Keep `last_info` as the path so the log line below reports the file,
        # not the loaded dictionary.
        info = torch.load(last_info)
        start_epoch = info['epoch']
        checkpoint = torch.load(info['last_checkpoint'])
        genotypes = checkpoint['genotypes']
        valid_accuracies = checkpoint['valid_accuracies']
        search_model.load_state_dict(checkpoint['search_model'])
        w_scheduler.load_state_dict(checkpoint['w_scheduler'])
        w_optimizer.load_state_dict(checkpoint['w_optimizer'])
        logger.log("=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format(
            last_info, start_epoch))
    else:
        logger.log("=> do not find the last-info file : {:}".format(last_info))
        start_epoch, valid_accuracies, genotypes = 0, {'best': -1}, {}

    # start training
    start_time, search_time, epoch_time, total_epoch = (
        time.time(), AverageMeter(), AverageMeter(), config.epochs + config.warmup)
    for epoch in range(start_epoch, total_epoch):
        w_scheduler.update(epoch, 0.0)
        need_time = 'Time Left: {:}'.format(
            convert_secs2time(epoch_time.val * (total_epoch - epoch), True))
        epoch_str = '{:03d}-{:03d}'.format(epoch, total_epoch)
        logger.log('\n[Search the {:}-th epoch] {:}, LR={:}'.format(
            epoch_str, need_time, min(w_scheduler.get_lr())))

        search_w_loss, search_w_top1, search_w_top5 = search_func(
            search_loader, network, criterion, w_scheduler, w_optimizer,
            epoch_str, xargs.print_freq, logger)
        search_time.update(time.time() - start_time)
        logger.log('[{:}] searching : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s'.format(
            epoch_str, search_w_loss, search_w_top1, search_w_top5, search_time.sum))
        valid_a_loss, valid_a_top1, valid_a_top5 = valid_func(
            valid_loader, network, criterion)
        logger.log('[{:}] evaluate  : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(
            epoch_str, valid_a_loss, valid_a_top1, valid_a_top5))
        cur_arch, cur_valid_acc = search_find_best(valid_loader, network,
                                                   xargs.select_num)
        logger.log('[{:}] find-the-best : {:}, accuracy@1={:.2f}%'.format(
            epoch_str, cur_arch, cur_valid_acc))
        genotypes[epoch] = cur_arch

        # check the best accuracy
        valid_accuracies[epoch] = valid_a_top1
        find_best = valid_a_top1 > valid_accuracies['best']
        if find_best:
            valid_accuracies['best'] = valid_a_top1

        # save checkpoint
        save_path = save_checkpoint({'epoch': epoch + 1,
                                     'args': deepcopy(xargs),
                                     'search_model': search_model.state_dict(),
                                     'w_optimizer': w_optimizer.state_dict(),
                                     'w_scheduler': w_scheduler.state_dict(),
                                     'genotypes': genotypes,
                                     'valid_accuracies': valid_accuracies},
                                    model_base_path, logger)
        save_checkpoint({'epoch': epoch + 1,
                         'args': deepcopy(xargs),
                         'last_checkpoint': save_path},
                        logger.path('info'), logger)
        if find_best:
            logger.log('<<<--->>> The {:}-th epoch : find the highest validation accuracy : {:.2f}%.'.format(
                epoch_str, valid_a_top1))
            copy_checkpoint(model_base_path, model_best_path, logger)
        if api is not None:
            logger.log('{:}'.format(api.query_by_arch(genotypes[epoch])))
        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()

    logger.log('\n' + '-' * 200)
    logger.log('Pre-searching costs {:.1f} s'.format(search_time.sum))
    start_time = time.time()
    best_arch, best_acc = search_find_best(valid_loader, network,
                                           xargs.select_num)
    search_time.update(time.time() - start_time)
    logger.log('RANDOM-NAS finds the best one : {:} with accuracy={:.2f}%, with {:.1f} s.'.format(
        best_arch, best_acc, search_time.sum))
    if api is not None:
        logger.log('{:}'.format(api.query_by_arch(best_arch)))
    logger.close()
def main(xargs):
    """Run the DARTS-V2 cell search (model name 'DARTS-V2').

    The search runs for ``config.epochs + config.warmup`` epochs. Each epoch
    calls ``search_func`` with both the weight optimizer and the architecture
    optimizer, then evaluates on the validation split. It records the current
    genotype, writes a checkpoint and an info file (``logger.path('info')``),
    and logs the softmax of ``search_model.arch_parameters``.

    When the validation top-1 improves:
      * the genotype is stored under ``genotypes['best']``;
      * a full-model snapshot is saved next to the run's model checkpoint;
      * the checkpoint is copied to ``logger.path('best')``.

    If the info file already exists, the run resumes from the checkpoint it
    points to.

    Args:
        xargs: parsed command-line namespace. Every setting this function
            reads is taken from it (dataset, data_path, optional config_path,
            workers, rand_seed, channel, num_cells, max_nodes,
            search_space_name, arch_learning_rate, arch_weight_decay,
            arch_nas_dataset, print_freq).

    Raises:
        ValueError: if ``xargs.dataset`` is not CIFAR-10/100 or ImageNet16*.
    """
    import os  # local import: used to build the best-model snapshot path

    assert torch.cuda.is_available(), 'CUDA is not available.'
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.set_num_threads(xargs.workers)
    prepare_seed(xargs.rand_seed)
    # Use the namespace actually passed in, not a module-level `args` global.
    logger = prepare_logger(xargs)

    train_data, valid_data, xshape, class_num = get_datasets(
        xargs.dataset, xargs.data_path, -1)
    if xargs.dataset == 'cifar10' or xargs.dataset == 'cifar100':
        # NOTE(review): machine-specific absolute path kept from the original;
        # the ImageNet16 branch below uses a path relative to the working dir.
        split_Fpath = '/home/city/Projects/NAS-Projects/configs/nas-benchmark/cifar-split.txt'
        cifar_split = load_config(split_Fpath, None, None)
        train_split, valid_split = cifar_split.train, cifar_split.valid
        logger.log('Load split file from {:}'.format(split_Fpath))
    elif xargs.dataset.startswith('ImageNet16'):
        split_Fpath = 'configs/nas-benchmark/{:}-split.txt'.format(xargs.dataset)
        imagenet16_split = load_config(split_Fpath, None, None)
        train_split, valid_split = imagenet16_split.train, imagenet16_split.valid
        logger.log('Load split file from {:}'.format(split_Fpath))
    else:
        raise ValueError('invalid dataset : {:}'.format(xargs.dataset))

    # Honour --config_path when it is given. Otherwise fall back to the DARTS
    # config this script previously hard-coded unconditionally.
    config_path = (getattr(xargs, 'config_path', None)
                   or '/home/city/Projects/NAS-Projects/configs/nas-benchmark/algos/DARTS.config')
    config = load_config(config_path,
                         {'class_num': class_num, 'xshape': xshape}, logger)
    logger.log('config : {:}'.format(config))

    # The validation split indexes into the *training* set, evaluated with the
    # transform taken from the original `valid_data`.
    train_data_v2 = deepcopy(train_data)
    train_data_v2.transform = valid_data.transform
    valid_data = train_data_v2
    search_data = SearchDataset(xargs.dataset, train_data, train_split,
                                valid_split)
    # data loader
    search_loader = torch.utils.data.DataLoader(
        search_data, batch_size=config.batch_size, shuffle=True,
        num_workers=xargs.workers, pin_memory=True)
    valid_loader = torch.utils.data.DataLoader(
        valid_data, batch_size=config.batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split),
        num_workers=xargs.workers, pin_memory=True)
    logger.log('||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(
        xargs.dataset, len(search_loader), len(valid_loader), config.batch_size))
    logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config))

    search_space = get_search_spaces('cell', xargs.search_space_name)
    model_config = dict2config({'name': 'DARTS-V2', 'C': xargs.channel,
                                'N': xargs.num_cells,
                                'max_nodes': xargs.max_nodes,
                                'num_classes': class_num,
                                'space': search_space}, None)
    search_model = get_cell_based_tiny_net(model_config)
    logger.log('search-model :\n{:}'.format(search_model))

    w_optimizer, w_scheduler, criterion = get_optim_scheduler(
        search_model.get_weights(), config)
    a_optimizer = torch.optim.Adam(search_model.get_alphas(),
                                   lr=xargs.arch_learning_rate,
                                   betas=(0.5, 0.999),
                                   weight_decay=xargs.arch_weight_decay)
    logger.log('w-optimizer : {:}'.format(w_optimizer))
    logger.log('a-optimizer : {:}'.format(a_optimizer))
    logger.log('w-scheduler : {:}'.format(w_scheduler))
    logger.log('criterion   : {:}'.format(criterion))
    flop, param = get_model_infos(search_model, xshape)
    logger.log('FLOP = {:.2f} M, Params = {:.2f} MB'.format(flop, param))
    if xargs.arch_nas_dataset is None:
        api = None
    else:
        api = API(xargs.arch_nas_dataset)
    logger.log('{:} create API = {:} done'.format(time_string(), api))

    last_info, model_base_path, model_best_path = (
        logger.path('info'), logger.path('model'), logger.path('best'))
    network, criterion = torch.nn.DataParallel(search_model).cuda(), criterion.cuda()

    # Best-model snapshots live next to this run's checkpoints, replacing a
    # machine-specific absolute directory that may not exist elsewhere.
    snapshot_dir = os.path.dirname(str(model_base_path))
    if snapshot_dir:
        os.makedirs(snapshot_dir, exist_ok=True)

    if last_info.exists():  # automatically resume from previous checkpoint
        logger.log("=> loading checkpoint of the last-info '{:}' start".format(last_info))
        # Keep `last_info` as the path so the log line below reports the file,
        # not the loaded dictionary.
        info = torch.load(last_info)
        start_epoch = info['epoch']
        checkpoint = torch.load(info['last_checkpoint'])
        genotypes = checkpoint['genotypes']
        valid_accuracies = checkpoint['valid_accuracies']
        search_model.load_state_dict(checkpoint['search_model'])
        w_scheduler.load_state_dict(checkpoint['w_scheduler'])
        w_optimizer.load_state_dict(checkpoint['w_optimizer'])
        a_optimizer.load_state_dict(checkpoint['a_optimizer'])
        logger.log("=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format(
            last_info, start_epoch))
    else:
        logger.log("=> do not find the last-info file : {:}".format(last_info))
        start_epoch, valid_accuracies, genotypes = 0, {'best': -1}, {}

    # start training
    start_time, search_time, epoch_time, total_epoch = (
        time.time(), AverageMeter(), AverageMeter(), config.epochs + config.warmup)
    for epoch in range(start_epoch, total_epoch):
        w_scheduler.update(epoch, 0.0)
        need_time = 'Time Left: {:}'.format(
            convert_secs2time(epoch_time.val * (total_epoch - epoch), True))
        epoch_str = '{:03d}-{:03d}'.format(epoch, total_epoch)
        min_LR = min(w_scheduler.get_lr())
        logger.log('\n[Search the {:}-th epoch] {:}, LR={:}'.format(
            epoch_str, need_time, min_LR))

        search_w_loss, search_w_top1, search_w_top5 = search_func(
            search_loader, network, criterion, w_scheduler, w_optimizer,
            a_optimizer, epoch_str, xargs.print_freq, logger)
        search_time.update(time.time() - start_time)
        logger.log('[{:}] searching : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s'.format(
            epoch_str, search_w_loss, search_w_top1, search_w_top5, search_time.sum))
        valid_a_loss, valid_a_top1, valid_a_top5 = valid_func(
            valid_loader, network, criterion)
        logger.log('[{:}] evaluate  : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(
            epoch_str, valid_a_loss, valid_a_top1, valid_a_top5))

        # check the best accuracy
        valid_accuracies[epoch] = valid_a_top1
        find_best = valid_a_top1 > valid_accuracies['best']
        if find_best:
            valid_accuracies['best'] = valid_a_top1
            genotypes['best'] = search_model.genotype()
            op_list, _ = genotypes['best'].tolist(remove_str=None)
            best_arch_nums = op_list2str(op_list)
            # Build the file name once so the saved and logged paths always
            # agree (time_string_short() was previously called twice).
            snapshot_path = os.path.join(
                snapshot_dir,
                'darts2_%04d_%s_%s_%.2f.pth' % (epoch, time_string_short(),
                                                best_arch_nums, valid_a_top1))
            torch.save(search_model, snapshot_path)
            logger.log('save the best search model into {:}'.format(snapshot_path))

        genotypes[epoch] = search_model.genotype()
        logger.log('<<<--->>> The {:}-th epoch : {:}'.format(epoch_str, genotypes[epoch]))
        # save checkpoint
        save_path = save_checkpoint({'epoch': epoch + 1,
                                     'args': deepcopy(xargs),
                                     'search_model': search_model.state_dict(),
                                     'w_optimizer': w_optimizer.state_dict(),
                                     'a_optimizer': a_optimizer.state_dict(),
                                     'w_scheduler': w_scheduler.state_dict(),
                                     'genotypes': genotypes,
                                     'valid_accuracies': valid_accuracies},
                                    model_base_path, logger)
        save_checkpoint({'epoch': epoch + 1,
                         'args': deepcopy(xargs),
                         'last_checkpoint': save_path},
                        logger.path('info'), logger)
        if find_best:
            logger.log('<<<--->>> The {:}-th epoch : find the highest validation accuracy : {:.2f}%.'.format(
                epoch_str, valid_a_top1))
            copy_checkpoint(model_base_path, model_best_path, logger)
        with torch.no_grad():
            # Compute the softmax once and log both the probabilities and
            # their argmax along the last dimension.
            arch_probs = nn.functional.softmax(search_model.arch_parameters, dim=-1).cpu()
            logger.log('arch-parameters :\n{:}'.format(arch_probs))
            logger.log('arch :\n{:}'.format(arch_probs.argmax(dim=-1)))
        if api is not None:
            logger.log('{:}'.format(api.query_by_arch(genotypes[epoch])))
        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()

    logger.log('\n' + '-' * 100)
    # check the performance from the architecture dataset
    logger.log('DARTS-V2 : run {:} epochs, cost {:.1f} s, last-geno is {:}.'.format(
        total_epoch, search_time.sum, genotypes[total_epoch - 1]))
    if api is not None:
        logger.log('{:}'.format(api.query_by_arch(genotypes[total_epoch - 1])))
    logger.close()