else: acc_type = 'x-test' val_acc_type = 'x-valid' datasets = ['cifar10', 'cifar100', 'ImageNet16-120'] assert args.dataset in datasets, 'Incorrect dataset' if args.cutout: train_data, valid_data, xshape, num_classes = get_datasets(name = args.dataset, root = args.data, cutout=args.cutout) else: train_data, valid_data, xshape, num_classes = get_datasets(name = args.dataset, root = args.data, cutout=-1) logging.info("train data len: {}, valid data len: {}, xshape: {}, #classes: {}".format(len(train_data), len(valid_data), xshape, num_classes)) config = load_config(path=args.config_path, extra={'class_num': num_classes, 'xshape': xshape}, logger=None) logging.info(f'config: {config}') _, train_loader, valid_loader = get_nas_search_loaders(train_data=train_data, valid_data=valid_data, dataset=args.dataset, config_root='configs', batch_size=(args.batch_size, args.valid_batch_size), workers=args.workers) train_queue, valid_queue = train_loader, valid_loader logging.info('search_loader: {}, valid_loader: {}'.format(len(train_queue), len(valid_queue))) # Model Initialization #model_config = {'C': 16, 'N': 5, 'num_classes': num_classes, 'max_nodes': 4, 'search_space': NAS_BENCH_201, 'affine': False} model = TinyNetwork(C = args.init_channels, N = args.num_cells, max_nodes = args.max_nodes, num_classes = num_classes, search_space = NAS_BENCH_201, affine = False, track_running_stats = args.track_running_stats) model = model.to(device) #logging.info(model) optimizer, _, criterion = get_optim_scheduler(parameters=model.get_weights(), config=config) criterion = criterion.cuda() logging.info(f'optimizer: {optimizer}\nCriterion: {criterion}')
def main(xargs):
    """Train a shared-weight network and pick the best randomly sampled cell.

    The function builds the data loaders and a 'RANDOM' cell-based network,
    optionally resumes from the checkpoint referenced by the logger's
    'info' path, then for every epoch runs ``search_func`` (weight
    training), ``valid_func`` (evaluation) and ``search_find_best``
    (architecture selection), saving a checkpoint after each epoch.  After
    the last epoch ``search_find_best`` is run once more and the result is
    logged (and looked up in the benchmark API when one is configured).

    Args:
        xargs: Parsed command-line namespace.  Attributes read here:
            ``workers``, ``rand_seed``, ``dataset``, ``data_path``,
            ``config_path``, ``search_space_name``, ``channel``,
            ``num_cells``, ``max_nodes``, ``track_running_stats``,
            ``arch_nas_dataset``, ``print_freq`` and ``select_num``.

    Raises:
        AssertionError: If CUDA is not available.
    """
    assert torch.cuda.is_available(), 'CUDA is not available.'
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.set_num_threads(xargs.workers)
    prepare_seed(xargs.rand_seed)
    # Use the parameter, not a module-level ``args``: the function must work
    # when it is imported and called with an explicit namespace.
    logger = prepare_logger(xargs)

    train_data, valid_data, xshape, class_num = get_datasets(
        xargs.dataset, xargs.data_path, -1)
    config = load_config(xargs.config_path, {
        'class_num': class_num,
        'xshape': xshape
    }, logger)
    search_loader, _, valid_loader = get_nas_search_loaders(
        train_data, valid_data, xargs.dataset, 'configs/nas-benchmark/',
        (config.batch_size, config.test_batch_size), xargs.workers)
    logger.log(
        '||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'
        .format(xargs.dataset, len(search_loader), len(valid_loader),
                config.batch_size))
    logger.log('||||||| {:10s} ||||||| Config={:}'.format(
        xargs.dataset, config))

    search_space = get_search_spaces('cell', xargs.search_space_name)
    model_config = dict2config(
        {
            'name': 'RANDOM',
            'C': xargs.channel,
            'N': xargs.num_cells,
            'max_nodes': xargs.max_nodes,
            'num_classes': class_num,
            'space': search_space,
            'affine': False,
            'track_running_stats': bool(xargs.track_running_stats)
        }, None)
    search_model = get_cell_based_tiny_net(model_config)

    w_optimizer, w_scheduler, criterion = get_optim_scheduler(
        search_model.parameters(), config)
    logger.log('w-optimizer : {:}'.format(w_optimizer))
    logger.log('w-scheduler : {:}'.format(w_scheduler))
    logger.log('criterion : {:}'.format(criterion))

    api = None if xargs.arch_nas_dataset is None else API(
        xargs.arch_nas_dataset)
    logger.log('{:} create API = {:} done'.format(time_string(), api))

    last_info = logger.path('info')
    model_base_path = logger.path('model')
    model_best_path = logger.path('best')
    network = torch.nn.DataParallel(search_model).cuda()
    criterion = criterion.cuda()

    if last_info.exists():  # automatically resume from previous checkpoint
        logger.log("=> loading checkpoint of the last-info '{:}' start".format(
            last_info))
        last_info = torch.load(last_info)
        start_epoch = last_info['epoch']
        checkpoint = torch.load(last_info['last_checkpoint'])
        genotypes = checkpoint['genotypes']
        valid_accuracies = checkpoint['valid_accuracies']
        search_model.load_state_dict(checkpoint['search_model'])
        w_scheduler.load_state_dict(checkpoint['w_scheduler'])
        w_optimizer.load_state_dict(checkpoint['w_optimizer'])
        logger.log(
            "=> loading checkpoint of the last-info '{:}' start with {:}-th epoch."
            .format(last_info, start_epoch))
    else:
        logger.log("=> do not find the last-info file : {:}".format(last_info))
        start_epoch, valid_accuracies, genotypes = 0, {'best': -1}, {}

    # start training
    start_time = time.time()
    search_time, epoch_time = AverageMeter(), AverageMeter()
    total_epoch = config.epochs + config.warmup
    for epoch in range(start_epoch, total_epoch):
        w_scheduler.update(epoch, 0.0)
        need_time = 'Time Left: {:}'.format(
            convert_secs2time(epoch_time.val * (total_epoch - epoch), True))
        epoch_str = '{:03d}-{:03d}'.format(epoch, total_epoch)
        logger.log('\n[Search the {:}-th epoch] {:}, LR={:}'.format(
            epoch_str, need_time, min(w_scheduler.get_lr())))

        search_w_loss, search_w_top1, search_w_top5 = search_func(
            search_loader, network, criterion, w_scheduler, w_optimizer,
            epoch_str, xargs.print_freq, logger)
        search_time.update(time.time() - start_time)
        logger.log(
            '[{:}] searching : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s'
            .format(epoch_str, search_w_loss, search_w_top1, search_w_top5,
                    search_time.sum))
        valid_a_loss, valid_a_top1, valid_a_top5 = valid_func(
            valid_loader, network, criterion)
        logger.log(
            '[{:}] evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'
            .format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5))
        cur_arch, cur_valid_acc = search_find_best(valid_loader, network,
                                                   xargs.select_num)
        logger.log('[{:}] find-the-best : {:}, accuracy@1={:.2f}%'.format(
            epoch_str, cur_arch, cur_valid_acc))
        genotypes[epoch] = cur_arch

        # check the best accuracy
        valid_accuracies[epoch] = valid_a_top1
        find_best = valid_a_top1 > valid_accuracies['best']
        if find_best:
            valid_accuracies['best'] = valid_a_top1

        # save checkpoint; 'epoch' stores the next epoch to run on resume
        save_path = save_checkpoint(
            {
                'epoch': epoch + 1,
                'args': deepcopy(xargs),
                'search_model': search_model.state_dict(),
                'w_optimizer': w_optimizer.state_dict(),
                'w_scheduler': w_scheduler.state_dict(),
                'genotypes': genotypes,
                'valid_accuracies': valid_accuracies
            }, model_base_path, logger)
        last_info = save_checkpoint(
            {
                'epoch': epoch + 1,
                'args': deepcopy(xargs),
                'last_checkpoint': save_path,
            }, logger.path('info'), logger)
        if find_best:
            logger.log(
                '<<<--->>> The {:}-th epoch : find the highest validation accuracy : {:.2f}%.'
                .format(epoch_str, valid_a_top1))
            copy_checkpoint(model_base_path, model_best_path, logger)
        if api is not None:
            logger.log('{:}'.format(api.query_by_arch(genotypes[epoch])))

        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()

    logger.log('\n' + '-' * 200)
    logger.log('Pre-searching costs {:.1f} s'.format(search_time.sum))
    start_time = time.time()
    best_arch, best_acc = search_find_best(valid_loader, network,
                                           xargs.select_num)
    search_time.update(time.time() - start_time)
    logger.log(
        'RANDOM-NAS finds the best one : {:} with accuracy={:.2f}%, with {:.1f} s.'
        .format(best_arch, best_acc, search_time.sum))
    if api is not None:
        logger.log('{:}'.format(api.query_by_arch(best_arch)))
    logger.close()
def _next_valid_batch(valid_loader, valid_iter):
    """Return ``(batch, iterator)``, restarting the loader once it is exhausted.

    ``valid_iter`` may be ``None``; the iterator is then created lazily on
    first use.  Only ``StopIteration`` triggers a restart, so genuine data
    loading errors propagate instead of being silently retried.
    """
    if valid_iter is not None:
        try:
            return next(valid_iter), valid_iter
        except StopIteration:
            pass
    valid_iter = iter(valid_loader)
    return next(valid_iter), valid_iter


def _partial_valid_accuracy(network, valid_loader, valid_iter, weights,
                            num_batches=10):
    """Average top-1 accuracy of ``weights`` over ``num_batches`` batches.

    Returns ``(average_top1, valid_iter)`` so the caller can keep drawing
    from the same iterator on the next call.
    """
    meter = AverageMeter()
    for _ in range(num_batches):
        (val_input, val_target), valid_iter = _next_valid_batch(
            valid_loader, valid_iter)
        n_val = val_input.size(0)
        with torch.no_grad():
            val_target = val_target.cuda(non_blocking=True)
            _, logits, _ = network(val_input, weights=weights)
            val_acc1, _ = obtain_accuracy(logits.data,
                                          val_target.data,
                                          topk=(1, 5))
        meter.update(val_acc1.item(), n_val)
    return meter.avg, valid_iter


def _benchmark_test_accuracy(api, genotype):
    """Look up the 'cifar10' / 'ori-test' accuracy of ``genotype``; 0.0 without an API."""
    if api is None:
        return 0.0
    index = api.query_index_by_arch(genotype)
    info = api.query_meta_info_by_index(index)
    return info.get_metrics('cifar10', 'ori-test')['accuracy']


def main(xargs):
    """Run the GDAS search, then evaluate architectures sampled from it.

    Phases, as implemented below:

    1. Build loaders, the 'GDAS' cell-based network and its weight/alpha
       optimizers; resume from the logger's 'info' checkpoint if present.
    2. For ``config.epochs + config.warmup`` epochs call ``search_func``;
       for a further ``config.t_epochs`` epochs call ``train_func`` on the
       first batch of sampled weights.  Every epoch a single sampled
       architecture is validated on 10 batches and a checkpoint is saved.
    3. Evaluate every sampled weight set on 10 validation batches, keep the
       10 best of the final sample and re-evaluate those on the whole
       validation loader.

    Args:
        xargs: Parsed command-line namespace.  Attributes read here:
            ``workers``, ``rand_seed``, ``dataset``, ``data_path``,
            ``config_path``, ``search_space_name``, ``model_config``,
            ``constrain``, ``channel``, ``num_cells``, ``max_nodes``,
            ``track_running_stats``, ``arch_learning_rate``,
            ``arch_weight_decay``, ``arch_nas_dataset``, ``tau_max``,
            ``tau_min``, ``print_freq`` and ``bilevel``.

    Raises:
        AssertionError: If CUDA is not available.
    """
    assert torch.cuda.is_available(), 'CUDA is not available.'
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.set_num_threads(xargs.workers)
    prepare_seed(xargs.rand_seed)
    # Use the parameter, not a module-level ``args``: the function must work
    # when it is imported and called with an explicit namespace.
    logger = prepare_logger(xargs)

    train_data, valid_data, xshape, class_num = get_datasets(
        xargs.dataset, xargs.data_path, -1)
    config = load_config(xargs.config_path, {
        'class_num': class_num,
        'xshape': xshape
    }, logger)
    search_loader, _, valid_loader = get_nas_search_loaders(
        train_data, valid_data, xargs.dataset, 'configs/nas-benchmark/',
        config.batch_size, xargs.workers)
    logger.log(
        '||||||| {:10s} ||||||| Search-Loader-Num={:}, batch size={:}'.format(
            xargs.dataset, len(search_loader), config.batch_size))
    logger.log('||||||| {:10s} ||||||| Config={:}'.format(
        xargs.dataset, config))

    search_space = get_search_spaces('cell', xargs.search_space_name)
    if xargs.model_config is None:
        # The two built-in configurations differ only in 'inp_size':
        # 32 when --constrain is set, 0 otherwise.
        model_config = dict2config(
            {
                'name': 'GDAS',
                'C': xargs.channel,
                'N': xargs.num_cells,
                'max_nodes': xargs.max_nodes,
                'num_classes': class_num,
                'space': search_space,
                'inp_size': 32 if xargs.constrain else 0,
                'affine': False,
                'track_running_stats': bool(xargs.track_running_stats)
            }, None)
    else:
        model_config = load_config(
            xargs.model_config, {
                'num_classes': class_num,
                'space': search_space,
                'affine': False,
                'track_running_stats': bool(xargs.track_running_stats)
            }, None)
    search_model = get_cell_based_tiny_net(model_config)
    logger.log('model-config : {:}'.format(model_config))

    w_optimizer, w_scheduler, criterion = get_optim_scheduler(
        search_model.get_weights(), config)
    a_optimizer = torch.optim.Adam(search_model.get_alphas(),
                                   lr=xargs.arch_learning_rate,
                                   betas=(0.5, 0.999),
                                   weight_decay=xargs.arch_weight_decay)
    logger.log('w-optimizer : {:}'.format(w_optimizer))
    logger.log('a-optimizer : {:}'.format(a_optimizer))
    logger.log('w-scheduler : {:}'.format(w_scheduler))
    logger.log('criterion : {:}'.format(criterion))
    flop, param = get_model_infos(search_model, xshape)
    logger.log('FLOP = {:.2f} M, Params = {:.2f} MB'.format(flop, param))
    logger.log('search-space [{:} ops] : {:}'.format(len(search_space),
                                                     search_space))

    api = None if xargs.arch_nas_dataset is None else API(
        xargs.arch_nas_dataset)
    logger.log('{:} create API = {:} done'.format(time_string(), api))

    last_info = logger.path('info')
    model_base_path = logger.path('model')
    model_best_path = logger.path('best')
    network = torch.nn.DataParallel(search_model).cuda()
    criterion = criterion.cuda()

    if last_info.exists():  # automatically resume from previous checkpoint
        logger.log("=> loading checkpoint of the last-info '{:}' start".format(
            last_info))
        last_info = torch.load(last_info)
        start_epoch = last_info['epoch']
        checkpoint = torch.load(last_info['last_checkpoint'])
        genotypes = checkpoint['genotypes']
        valid_accuracies = checkpoint['valid_accuracies']
        search_model.load_state_dict(checkpoint['search_model'])
        w_scheduler.load_state_dict(checkpoint['w_scheduler'])
        w_optimizer.load_state_dict(checkpoint['w_optimizer'])
        a_optimizer.load_state_dict(checkpoint['a_optimizer'])
        logger.log(
            "=> loading checkpoint of the last-info '{:}' start with {:}-th epoch."
            .format(last_info, start_epoch))
    else:
        logger.log("=> do not find the last-info file : {:}".format(last_info))
        start_epoch = 0
        valid_accuracies = {'best': -1}
        genotypes = {-1: search_model.genotype()}

    # start training
    start_time = time.time()
    search_time, epoch_time = AverageMeter(), AverageMeter()
    total_epoch = config.epochs + config.warmup
    sampled_weights = []
    # Created lazily on first use so the loader iterator is not instantiated
    # earlier than it was before this refactor.
    valid_iter = None
    for epoch in range(start_epoch, total_epoch + config.t_epochs):
        w_scheduler.update(epoch, 0.0)
        need_time = 'Time Left: {:}'.format(
            convert_secs2time(
                epoch_time.val * (total_epoch - epoch + config.t_epochs),
                True))
        epoch_str = '{:03d}-{:03d}'.format(epoch, total_epoch)
        # Linear interpolation from tau_max (epoch 0) to tau_min
        # (epoch total_epoch - 1).
        search_model.set_tau(xargs.tau_max -
                             (xargs.tau_max - xargs.tau_min) * epoch /
                             (total_epoch - 1))
        logger.log('\n[Search the {:}-th epoch] {:}, tau={:}, LR={:}'.format(
            epoch_str, need_time, search_model.get_tau(),
            min(w_scheduler.get_lr())))

        if epoch < total_epoch:
            (search_w_loss, search_w_top1, search_w_top5, valid_a_loss,
             valid_a_top1, valid_a_top5) = search_func(
                 search_loader, network, criterion, w_scheduler, w_optimizer,
                 a_optimizer, epoch_str, xargs.print_freq, logger,
                 xargs.bilevel)
        else:
            # Extra training epochs: ``arch_iter`` was created when the
            # search phase finished (see the sampling step below).
            (search_w_loss, search_w_top1, search_w_top5, valid_a_loss,
             valid_a_top1, valid_a_top5, arch_iter) = train_func(
                 search_loader, network, criterion, w_scheduler, w_optimizer,
                 epoch_str, xargs.print_freq, sampled_weights[0], arch_iter,
                 logger)
        search_time.update(time.time() - start_time)
        logger.log(
            '[{:}] searching : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s'
            .format(epoch_str, search_w_loss, search_w_top1, search_w_top5,
                    search_time.sum))
        logger.log(
            '[{:}] evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'
            .format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5))

        # Without extra training epochs, sample every 50 epochs; otherwise
        # sample once, right when the search phase ends.
        if (epoch + 1) % 50 == 0 and not config.t_epochs:
            weights = search_model.sample_weights(100)
            sampled_weights.append(weights)
        elif (epoch + 1) == total_epoch and config.t_epochs:
            weights = search_model.sample_weights(100)
            sampled_weights.append(weights)
            arch_iter = iter(weights)

        # validate with single arch
        single_weight = search_model.sample_weights(1)[0]
        network.eval()
        single_valid_avg, valid_iter = _partial_valid_accuracy(
            network, valid_loader, valid_iter, single_weight)
        logger.log('[{:}] valid : accuracy = {:.2f}'.format(
            epoch_str, single_valid_avg))

        # check the best accuracy
        valid_accuracies[epoch] = valid_a_top1
        find_best = valid_a_top1 > valid_accuracies['best']
        if find_best:
            valid_accuracies['best'] = valid_a_top1
            genotypes['best'] = search_model.genotype()

        if epoch < total_epoch:
            genotypes[epoch] = search_model.genotype()
            logger.log('<<<--->>> The {:}-th epoch : {:}'.format(
                epoch_str, genotypes[epoch]))

        # save checkpoint; 'epoch' stores the next epoch to run on resume
        save_path = save_checkpoint(
            {
                'epoch': epoch + 1,
                'args': deepcopy(xargs),
                'search_model': search_model.state_dict(),
                'w_optimizer': w_optimizer.state_dict(),
                'a_optimizer': a_optimizer.state_dict(),
                'w_scheduler': w_scheduler.state_dict(),
                'genotypes': genotypes,
                'valid_accuracies': valid_accuracies
            }, model_base_path, logger)
        last_info = save_checkpoint(
            {
                'epoch': epoch + 1,
                'args': deepcopy(xargs),
                'last_checkpoint': save_path,
            }, logger.path('info'), logger)
        if find_best:
            logger.log(
                '<<<--->>> The {:}-th epoch : find the highest validation accuracy : {:.2f}%.'
                .format(epoch_str, valid_a_top1))
            copy_checkpoint(model_base_path, model_best_path, logger)
        with torch.no_grad():
            logger.log('{:}'.format(search_model.show_alphas()))
        if api is not None and epoch < total_epoch:
            logger.log('{:}'.format(api.query_by_arch(genotypes[epoch])))

        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()

    network.eval()

    # Evaluate the architectures sampled throughout the search
    # (all samples except the last one, which is handled separately below).
    for sample_idx, weight_set in enumerate(sampled_weights[:-1]):
        logger.log('Sample eval : epoch {}'.format((sample_idx + 1) * 50 - 1))
        for w in weight_set:
            sample_avg, valid_iter = _partial_valid_accuracy(
                network, valid_loader, valid_iter, w)
            acc = _benchmark_test_accuracy(api, search_model.genotype(w))
            logger.log(
                'sample valid : val_acc = {:.2f} test_acc = {:.2f}'.format(
                    sample_avg, acc))

    # Evaluate the final sampling separately to find the top 10 architectures
    logger.log('Final sample eval')
    final_archs = []
    # Guard against runs where nothing was sampled (e.g. fewer than 50
    # epochs without extra training epochs) instead of raising IndexError.
    final_weights = sampled_weights[-1] if sampled_weights else []
    for w in final_weights:
        sample_avg, valid_iter = _partial_valid_accuracy(
            network, valid_loader, valid_iter, w)
        acc = _benchmark_test_accuracy(api, search_model.genotype(w))
        logger.log('sample valid : val_acc = {:.2f} test_acc = {:.2f}'.format(
            sample_avg, acc))
        final_archs.append((w, sample_avg))
    top_10 = sorted(final_archs, key=lambda x: x[1], reverse=True)[:10]

    # Evaluate the top 10 architectures on the entire validation set
    logger.log('Evaluating top archs')
    for w, prev_acc in top_10:
        full_valid_acc = AverageMeter()
        for val_input, val_target in valid_loader:
            n_val = val_input.size(0)
            with torch.no_grad():
                val_target = val_target.cuda(non_blocking=True)
                _, logits, _ = network(val_input, weights=w)
                val_acc1, _ = obtain_accuracy(logits.data,
                                              val_target.data,
                                              topk=(1, 5))
            full_valid_acc.update(val_acc1.item(), n_val)
        w_gene = search_model.genotype(w)
        logger.log('genotype {}'.format(w_gene))
        acc = _benchmark_test_accuracy(api, w_gene)
        logger.log(
            'full valid : val_acc = {:.2f} test_acc = {:.2f} pval_acc = {:.2f}'
            .format(full_valid_acc.avg, acc, prev_acc))

    logger.log('\n' + '-' * 100)
    # check the performance from the architecture dataset
    logger.log(
        'GDAS : run {:} epochs, cost {:.1f} s, last-geno is {:}.'.format(
            total_epoch, search_time.sum, genotypes[total_epoch - 1]))
    if api is not None:
        logger.log('{:}'.format(api.query_by_arch(genotypes[total_epoch - 1])))
    logger.close()
def main(xargs):
    """Prune a supernet down to a single-path architecture.

    The function builds a 'DARTS-V1' supernet for the chosen search space,
    masks the first operator of every alpha tensor with ``-INF``, then
    repeatedly calls ``prune_func_rank`` and rebuilds the network until
    ``is_single_path(network)`` holds.  For the 'darts' / 'darts_fewshot'
    spaces an additional ``prune_func_rank_group`` pass keeps two edges per
    group.  The alpha history is saved to ``arch_parameters_history.npy``
    in ``xargs.save_dir`` after every step.

    Args:
        xargs: Parsed command-line namespace.  ``xargs.timestamp`` and
            ``xargs.save_dir`` are rewritten in place.  Other attributes
            read here: ``rand_seed``, ``batch_size``, ``repeat``,
            ``prune_number``, ``precision``, ``init``, ``dataset``,
            ``data_path``, ``workers``, ``search_space_name``,
            ``max_nodes``, ``track_running_stats``, ``super_type`` and
            ``arch_nas_dataset``.

    Raises:
        AssertionError: If CUDA is not available.
        ValueError: If ``xargs.search_space_name`` has no model
            configuration here, or ``xargs.max_nodes`` has no edge grouping
            for the DARTS spaces.
    """
    PID = os.getpid()
    assert torch.cuda.is_available(), 'CUDA is not available.'
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    prepare_seed(xargs.rand_seed)

    if xargs.timestamp == 'none':
        # NOTE(review): '%C' is the century and '%s' / '%h' are
        # platform-specific strftime extensions; '%y' / '%S' / '%b' may have
        # been intended.  Left unchanged because it determines the save path.
        xargs.timestamp = "{:}".format(
            time.strftime('%h-%d-%C_%H-%M-%s', time.gmtime(time.time())))

    train_data, valid_data, xshape, class_num = get_datasets(xargs, -1)

    ##### config & logging #####
    config = edict()
    config.class_num = class_num
    config.xshape = xshape
    config.batch_size = xargs.batch_size
    xargs.save_dir = xargs.save_dir + \
        "/repeat%d-prunNum%d-prec%d-%s-batch%d"%(
            xargs.repeat, xargs.prune_number, xargs.precision, xargs.init, config["batch_size"]) + \
        "/{:}/seed{:}".format(xargs.timestamp, xargs.rand_seed)
    config.save_dir = xargs.save_dir
    logger = prepare_logger(xargs)
    ###############

    plain_loader_datasets = [
        'MiniImageNet', 'MetaMiniImageNet', 'TieredImageNet',
        'MetaTieredImageNet', 'imagenet-1k'
    ]
    if xargs.dataset in plain_loader_datasets:
        # ``xargs.workers`` rather than a module-level ``args``: the function
        # must work when called with an explicit namespace.
        train_loader = torch.utils.data.DataLoader(
            train_data,
            batch_size=xargs.batch_size,
            shuffle=True,
            num_workers=xargs.workers,
            pin_memory=True)
    else:
        _, train_loader, _ = get_nas_search_loaders(train_data, valid_data,
                                                    xargs.dataset, 'configs/',
                                                    config.batch_size,
                                                    xargs.workers)
    logger.log(
        '||||||| {:10s} ||||||| Train-Loader-Num={:}, batch size={:}'.format(
            xargs.dataset, len(train_loader), config.batch_size))
    logger.log('||||||| {:10s} ||||||| Config={:}'.format(
        xargs.dataset, config))

    search_space = get_search_spaces('cell', xargs.search_space_name)
    track_stats = bool(xargs.track_running_stats)
    if xargs.search_space_name == 'nas-bench-201':
        model_config = edict({
            'name': 'DARTS-V1',
            'C': 3,
            'N': 1,
            'depth': -1,
            'use_stem': True,
            'max_nodes': xargs.max_nodes,
            'num_classes': class_num,
            'space': search_space,
            'affine': True,
            'track_running_stats': track_stats,
        })
        model_config_thin = edict({
            'name': 'DARTS-V1',
            'C': 1,
            'N': 1,
            'depth': 1,
            'use_stem': False,
            'max_nodes': xargs.max_nodes,
            'num_classes': class_num,
            'space': search_space,
            'affine': True,
            'track_running_stats': track_stats,
        })
    elif xargs.search_space_name in ['darts', 'darts_fewshot']:
        model_config = edict({
            'name': 'DARTS-V1',
            'C': 1,
            'N': 1,
            'depth': 2,
            'use_stem': True,
            'stem_multiplier': 1,
            'num_classes': class_num,
            'space': search_space,
            'affine': True,
            'track_running_stats': track_stats,
            'super_type': xargs.super_type,
            'steps': xargs.max_nodes,
            'multiplier': xargs.max_nodes,
        })
        model_config_thin = edict({
            'name': 'DARTS-V1',
            'C': 1,
            'N': 1,
            'depth': 2,
            'use_stem': False,
            'stem_multiplier': 1,
            'max_nodes': xargs.max_nodes,
            'num_classes': class_num,
            'space': search_space,
            'affine': True,
            'track_running_stats': track_stats,
            'super_type': xargs.super_type,
            'steps': xargs.max_nodes,
            'multiplier': xargs.max_nodes,
        })
    else:
        # Previously this fell through and failed later with an
        # UnboundLocalError on ``model_config``.
        raise ValueError('unsupported search space : {:}'.format(
            xargs.search_space_name))

    network = get_cell_based_tiny_net(model_config)
    logger.log('model-config : {:}'.format(model_config))
    # Work on detached copies so pruning never touches the live parameters.
    arch_parameters = [
        alpha.detach().clone() for alpha in network.get_alphas()
    ]
    for alpha in arch_parameters:
        alpha[:, :] = 0

    # TODO Linear_Region_Collector
    lrc_model = Linear_Region_Collector(xargs,
                                        input_size=(1000, 1, 3, 3),
                                        sample_batch=3,
                                        dataset=xargs.dataset,
                                        data_path=xargs.data_path,
                                        seed=xargs.rand_seed)

    # ### all params trainable (except train_bn) #########################
    flop, param = get_model_infos(network, xshape)
    logger.log('FLOP = {:.2f} M, Params = {:.2f} MB'.format(flop, param))
    logger.log('search-space [{:} ops] : {:}'.format(len(search_space),
                                                     search_space))
    if xargs.arch_nas_dataset is None or xargs.search_space_name in [
            'darts', 'darts_fewshot'
    ]:
        api = None
    else:
        api = API(xargs.arch_nas_dataset)
    logger.log('{:} create API = {:} done'.format(time_string(), api))

    network = network.cuda()

    genotypes = {}
    genotypes['arch'] = {-1: network.genotype()}

    arch_parameters_history = []
    arch_parameters_history_npy = []
    history_path = os.path.join(xargs.save_dir, "arch_parameters_history.npy")

    def record_history():
        # Snapshot the current alphas (tensor and numpy form) and persist
        # the numpy history.
        arch_parameters_history.append(
            [alpha.detach().clone() for alpha in arch_parameters])
        arch_parameters_history_npy.append([
            alpha.detach().clone().cpu().numpy() for alpha in arch_parameters
        ])
        np.save(history_path, arch_parameters_history_npy)

    start_time = time.time()
    epoch = -1

    # Mask out the operator at index 0 on every edge before pruning starts.
    for alpha in arch_parameters:
        alpha[:, 0] = -INF
    record_history()

    while not is_single_path(network):
        epoch += 1
        torch.cuda.empty_cache()
        print("<< ============== JOB (PID = %d) %s ============== >>" %
              (PID, '/'.join(xargs.save_dir.split("/")[-6:])))

        arch_parameters, op_pruned = prune_func_rank(
            xargs,
            arch_parameters,
            model_config,
            model_config_thin,
            train_loader,
            lrc_model,
            search_space,
            precision=xargs.precision,
            prune_number=xargs.prune_number)
        # rebuild supernet
        network = get_cell_based_tiny_net(model_config)
        network = network.cuda()
        network.set_alphas(arch_parameters)

        record_history()
        genotypes['arch'][epoch] = network.genotype()

        logger.log('operators remaining (1s) and prunned (0s)\n{:}'.format(
            '\n'.join([
                str((alpha > -INF).int()) for alpha in network.get_alphas()
            ])))

    if xargs.search_space_name in ['darts', 'darts_fewshot']:
        print("===>>> Prune Edge Groups...")
        if xargs.max_nodes == 4:
            edge_groups = [(0, 2), (2, 5), (5, 9), (9, 14)]
        elif xargs.max_nodes == 3:
            edge_groups = [(0, 2), (2, 5), (5, 9)]
        else:
            # Previously an UnboundLocalError on ``edge_groups``.
            raise ValueError(
                'edge groups are only defined for max_nodes in (3, 4), got {:}'
                .format(xargs.max_nodes))
        arch_parameters = prune_func_rank_group(
            xargs,
            arch_parameters,
            model_config,
            model_config_thin,
            train_loader,
            lrc_model,
            search_space,
            edge_groups=edge_groups,
            num_per_group=2,
            precision=xargs.precision,
        )
        network = get_cell_based_tiny_net(model_config)
        network = network.cuda()
        network.set_alphas(arch_parameters)
        record_history()

    logger.log('<<<--->>> End: {:}'.format(network.genotype()))
    logger.log('operators remaining (1s) and prunned (0s)\n{:}'.format(
        '\n'.join(
            [str((alpha > -INF).int()) for alpha in network.get_alphas()])))

    end_time = time.time()
    logger.log('\n' + '-' * 100)
    logger.log("Time spent: %d s" % (end_time - start_time))
    # check the performance from the architecture dataset
    if api is not None:
        logger.log('{:}'.format(api.query_by_arch(genotypes['arch'][epoch])))

    logger.close()
def main(xargs):
    """Train (or load) a weight-sharing supernet, then run its searcher.

    The function builds the loaders and a 'GDAS'-named cell-based network.
    If ``xargs.supernet_path`` is non-empty, finished supernet weights are
    loaded from it; otherwise the supernet is trained with
    ``train_shared_cnn`` and checkpointed.  Finally the searcher returned by
    ``search_model.getSearcher`` is run and its results are saved to the
    logger's 'best' path.

    Resuming from the logger's 'info' checkpoint is disabled in this script:
    training always starts from epoch 0.

    Args:
        xargs: Parsed command-line namespace.  Attributes read here:
            ``workers``, ``rand_seed``, ``dataset``, ``data_path``,
            ``config_path``, ``search_space_name``, ``model_config``,
            ``channel``, ``num_cells``, ``max_nodes``,
            ``track_running_stats``, ``arch_nas_dataset``,
            ``supernet_path`` and ``print_freq``.

    Raises:
        AssertionError: If CUDA is not available, or if the file at
            ``xargs.supernet_path`` was not saved with ``'epoch'`` equal to
            ``'finished'``.
    """
    assert torch.cuda.is_available(), 'CUDA is not available.'
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.set_num_threads(xargs.workers)
    prepare_seed(xargs.rand_seed)
    # Use the parameter, not a module-level ``args``: the function must work
    # when it is imported and called with an explicit namespace.
    logger = prepare_logger(xargs)

    train_data, valid_data, xshape, class_num = get_datasets(
        xargs.dataset, xargs.data_path, -1)
    config = load_config(xargs.config_path, {
        'class_num': class_num,
        'xshape': xshape
    }, logger)
    search_loader, train_loader, valid_loader = get_nas_search_loaders(
        train_data, valid_data, xargs.dataset, 'configs/nas-benchmark/',
        config.batch_size, xargs.workers)
    logger.log(
        '||||||| {:10s} ||||||| Search-Loader-Num={:}, batch size={:}'.format(
            xargs.dataset, len(search_loader), config.batch_size))
    logger.log('||||||| {:10s} ||||||| Config={:}'.format(
        xargs.dataset, config))

    search_space = get_search_spaces('cell', xargs.search_space_name)
    if xargs.model_config is None:
        model_config = dict2config(
            {
                'name': 'GDAS',
                'C': xargs.channel,
                'N': xargs.num_cells,
                'max_nodes': xargs.max_nodes,
                'num_classes': class_num,
                'space': search_space,
                'affine': False,
                'track_running_stats': bool(xargs.track_running_stats)
            }, None)
    else:
        model_config = load_config(
            xargs.model_config, {
                'num_classes': class_num,
                'space': search_space,
                'affine': False,
                'track_running_stats': bool(xargs.track_running_stats)
            }, None)
    search_model = get_cell_based_tiny_net(model_config)
    logger.log('search-model :\n{:}'.format(search_model))
    logger.log('model-config : {:}'.format(model_config))

    w_optimizer, w_scheduler, criterion = get_optim_scheduler(
        search_model.get_weights(), config)
    logger.log('w-optimizer : {:}'.format(w_optimizer))
    logger.log('w-scheduler : {:}'.format(w_scheduler))
    logger.log('criterion : {:}'.format(criterion))
    flop, param = get_model_infos(search_model, xshape)
    logger.log('FLOP = {:.2f} M, Params = {:.2f} MB'.format(flop, param))
    logger.log('search-space [{:} ops] : {:}'.format(len(search_space),
                                                     search_space))

    api = None if xargs.arch_nas_dataset is None else API(
        xargs.arch_nas_dataset)
    logger.log('{:} create API = {:} done'.format(time_string(), api))

    last_info = logger.path('info')
    model_base_path = logger.path('model')
    model_best_path = logger.path('best')
    network = torch.nn.DataParallel(search_model).cuda()
    criterion = criterion.cuda()

    # Checkpoint resume was hard-disabled (``if False``) in the original
    # code; the unreachable branch has been dropped and this message is
    # still logged unconditionally, as before.
    logger.log("=> do not find the last-info file : {:}".format(last_info))
    start_epoch = 0

    if len(xargs.supernet_path) > 0:
        saved_info = torch.load(xargs.supernet_path)
        assert saved_info[
            'epoch'] == 'finished', "Epoch is not finished in this file"
        search_model.load_state_dict(saved_info['search_model'])
    else:
        # start training supernet
        start_time = time.time()
        train_shared_cnn(train_loader, network, criterion, w_scheduler,
                         w_optimizer, xargs.print_freq, logger, config,
                         start_epoch)
        logger.log('Supernet trained. Time-cost = {:.1f} s'.format(
            time.time() - start_time))
        # save supernetweight
        save_path = save_checkpoint(
            {
                'epoch': 'finished',
                'args': deepcopy(xargs),
                'search_model': search_model.state_dict(),
                'w_optimizer': w_optimizer.state_dict(),
                'w_scheduler': w_scheduler.state_dict()
            }, model_base_path, logger)
        last_info = save_checkpoint(
            {
                'epoch': 'finished',
                'args': deepcopy(xargs),
                'last_checkpoint': save_path,
            }, logger.path('info'), logger)

    search_start_time = time.time()
    searcher = search_model.getSearcher(network, train_loader, valid_loader,
                                        logger, config)
    best_cands, performance_dict, performance_trace = searcher.search()
    logger.log('Architect Searched. Time-cost = {:.1f} s'.format(
        time.time() - search_start_time))
    save_checkpoint(
        {
            'epoch': 'finished',
            'args': deepcopy(xargs),
            'genotypes': best_cands,
            'performance_dict': performance_dict,
            'performance_trace': performance_trace
        }, model_best_path, logger)
    logger.close()
def main(xargs):
    """Run the ENAS search loop (shared CNN + controller) and log the result.

    Each epoch trains the shared CNN, then the controller, samples the best
    architecture on the validation loader, and writes a checkpoint. If a
    "last-info" file already exists under the logger's directory, training
    resumes from the checkpoint it points to.

    Args:
        xargs: parsed command-line namespace. Attributes read here:
            ``workers``, ``rand_seed``, ``dataset``, ``data_path``,
            ``config_path``, ``search_space_name``, ``model_config``,
            ``channel``, ``num_cells``, ``max_nodes``,
            ``track_running_stats``, ``arch_nas_dataset``, ``print_freq``,
            ``controller_train_steps``, ``controller_num_aggregate``,
            ``controller_entropy_weight``, ``controller_bl_dec`` and
            ``controller_num_samples``. The whole namespace is also handed
            to ``prepare_logger`` and stored in every checkpoint.

    Raises:
        AssertionError: if CUDA is not available.
    """
    assert torch.cuda.is_available(), 'CUDA is not available.'
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.set_num_threads(xargs.workers)
    prepare_seed(xargs.rand_seed)
    # Use the function argument, not a module-level ``args`` global.
    logger = prepare_logger(xargs)

    train_data, test_data, xshape, class_num = get_datasets(
        xargs.dataset, xargs.data_path, -1)
    logger.log('use config from : {:}'.format(xargs.config_path))
    config = load_config(xargs.config_path, {
        'class_num': class_num,
        'xshape': xshape
    }, logger)
    _, train_loader, valid_loader = get_nas_search_loaders(
        train_data, test_data, xargs.dataset, 'configs/nas-benchmark/',
        config.batch_size, xargs.workers)
    # since ENAS will train the controller on valid-loader, we need to use
    # train transformation for valid-loader
    valid_loader.dataset.transform = deepcopy(train_loader.dataset.transform)
    if hasattr(valid_loader.dataset, 'transforms'):
        valid_loader.dataset.transforms = deepcopy(
            train_loader.dataset.transforms)
    # data loader
    logger.log(
        '||||||| {:10s} ||||||| Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'
        .format(xargs.dataset, len(train_loader), len(valid_loader),
                config.batch_size))
    logger.log('||||||| {:10s} ||||||| Config={:}'.format(
        xargs.dataset, config))

    search_space = get_search_spaces('cell', xargs.search_space_name)
    if xargs.model_config is None:
        model_config = dict2config(
            {
                'name': 'ENAS',
                'C': xargs.channel,
                'N': xargs.num_cells,
                'max_nodes': xargs.max_nodes,
                'num_classes': class_num,
                'space': search_space,
                'affine': False,
                'track_running_stats': bool(xargs.track_running_stats)
            }, None)
    else:
        model_config = load_config(
            xargs.model_config, {
                'num_classes': class_num,
                'space': search_space,
                'affine': False,
                'track_running_stats': bool(xargs.track_running_stats)
            }, None)
    shared_cnn = get_cell_based_tiny_net(model_config)
    controller = shared_cnn.create_controller()

    w_optimizer, w_scheduler, criterion = get_optim_scheduler(
        shared_cnn.parameters(), config)
    a_optimizer = torch.optim.Adam(controller.parameters(),
                                   lr=config.controller_lr,
                                   betas=config.controller_betas,
                                   eps=config.controller_eps)
    logger.log('w-optimizer : {:}'.format(w_optimizer))
    logger.log('a-optimizer : {:}'.format(a_optimizer))
    logger.log('w-scheduler : {:}'.format(w_scheduler))
    logger.log('criterion : {:}'.format(criterion))
    logger.log('search-space : {:}'.format(search_space))
    if xargs.arch_nas_dataset is None:
        api = None
    else:
        api = API(xargs.arch_nas_dataset)
    logger.log('{:} create API = {:} done'.format(time_string(), api))
    shared_cnn, controller, criterion = torch.nn.DataParallel(
        shared_cnn).cuda(), controller.cuda(), criterion.cuda()

    last_info, model_base_path, model_best_path = logger.path(
        'info'), logger.path('model'), logger.path('best')

    if last_info.exists():  # automatically resume from previous checkpoint
        logger.log("=> loading checkpoint of the last-info '{:}' start".format(
            last_info))
        last_info = torch.load(last_info)
        start_epoch = last_info['epoch']
        checkpoint = torch.load(last_info['last_checkpoint'])
        genotypes = checkpoint['genotypes']
        baseline = checkpoint['baseline']
        valid_accuracies = checkpoint['valid_accuracies']
        shared_cnn.load_state_dict(checkpoint['shared_cnn'])
        controller.load_state_dict(checkpoint['controller'])
        w_scheduler.load_state_dict(checkpoint['w_scheduler'])
        w_optimizer.load_state_dict(checkpoint['w_optimizer'])
        a_optimizer.load_state_dict(checkpoint['a_optimizer'])
        logger.log(
            "=> loading checkpoint of the last-info '{:}' start with {:}-th epoch."
            .format(last_info, start_epoch))
    else:
        logger.log("=> do not find the last-info file : {:}".format(last_info))
        start_epoch, valid_accuracies, genotypes, baseline = 0, {
            'best': -1
        }, {}, None

    # start training
    start_time, search_time, epoch_time, total_epoch = time.time(
    ), AverageMeter(), AverageMeter(), config.epochs + config.warmup
    for epoch in range(start_epoch, total_epoch):
        w_scheduler.update(epoch, 0.0)
        need_time = 'Time Left: {:}'.format(
            convert_secs2time(epoch_time.val * (total_epoch - epoch), True))
        epoch_str = '{:03d}-{:03d}'.format(epoch, total_epoch)
        logger.log(
            '\n[Search the {:}-th epoch] {:}, LR={:}, baseline={:}'.format(
                epoch_str, need_time, min(w_scheduler.get_lr()), baseline))

        cnn_loss, cnn_top1, cnn_top5 = train_shared_cnn(
            train_loader, shared_cnn, controller, criterion, w_scheduler,
            w_optimizer, epoch_str, xargs.print_freq, logger)
        logger.log(
            '[{:}] shared-cnn : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'
            .format(epoch_str, cnn_loss, cnn_top1, cnn_top5))
        controller_config = dict2config(
            {
                'baseline': baseline,
                'ctl_train_steps': xargs.controller_train_steps,
                'ctl_num_aggre': xargs.controller_num_aggregate,
                'ctl_entropy_w': xargs.controller_entropy_weight,
                'ctl_bl_dec': xargs.controller_bl_dec
            }, None)
        ctl_loss, ctl_acc, ctl_baseline, ctl_reward, baseline = train_controller(
            valid_loader, shared_cnn, controller, criterion, a_optimizer,
            controller_config, epoch_str, xargs.print_freq, logger)
        search_time.update(time.time() - start_time)
        logger.log(
            '[{:}] controller : loss={:.2f}, accuracy={:.2f}%, baseline={:.2f}, reward={:.2f}, current-baseline={:.4f}, time-cost={:.1f} s'
            .format(epoch_str, ctl_loss, ctl_acc, ctl_baseline, ctl_reward,
                    baseline, search_time.sum))
        best_arch, _ = get_best_arch(controller, shared_cnn, valid_loader)
        shared_cnn.module.update_arch(best_arch)
        _, best_valid_acc, _ = valid_func(valid_loader, shared_cnn, criterion)

        genotypes[epoch] = best_arch
        # check the best accuracy
        valid_accuracies[epoch] = best_valid_acc
        if best_valid_acc > valid_accuracies['best']:
            valid_accuracies['best'] = best_valid_acc
            genotypes['best'] = best_arch
            find_best = True
        else:
            find_best = False

        logger.log('<<<--->>> The {:}-th epoch : {:}'.format(
            epoch_str, genotypes[epoch]))
        # save checkpoint
        save_path = save_checkpoint(
            {
                'epoch': epoch + 1,
                'args': deepcopy(xargs),
                'baseline': baseline,
                'shared_cnn': shared_cnn.state_dict(),
                'controller': controller.state_dict(),
                'w_optimizer': w_optimizer.state_dict(),
                'a_optimizer': a_optimizer.state_dict(),
                'w_scheduler': w_scheduler.state_dict(),
                'genotypes': genotypes,
                'valid_accuracies': valid_accuracies
            }, model_base_path, logger)
        last_info = save_checkpoint(
            {
                'epoch': epoch + 1,
                'args': deepcopy(xargs),
                'last_checkpoint': save_path,
            }, logger.path('info'), logger)
        if find_best:
            logger.log(
                '<<<--->>> The {:}-th epoch : find the highest validation accuracy : {:.2f}%.'
                .format(epoch_str, best_valid_acc))
            copy_checkpoint(model_base_path, model_best_path, logger)
        if api is not None:
            logger.log('{:}'.format(api.query_by_arch(genotypes[epoch])))
        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()

    logger.log('\n' + '-' * 100)
    logger.log('During searching, the best architecture is {:}'.format(
        genotypes['best']))
    logger.log('Its accuracy is {:.2f}%'.format(valid_accuracies['best']))
    logger.log('Randomly select {:} architectures and select the best.'.format(
        xargs.controller_num_samples))
    start_time = time.time()
    final_arch, _ = get_best_arch(controller, shared_cnn, valid_loader,
                                  xargs.controller_num_samples)
    search_time.update(time.time() - start_time)
    shared_cnn.module.update_arch(final_arch)
    final_loss, final_top1, final_top5 = valid_func(valid_loader, shared_cnn,
                                                    criterion)
    logger.log('The Selected Final Architecture : {:}'.format(final_arch))
    logger.log('Loss={:.3f}, Accuracy@1={:.2f}%, Accuracy@5={:.2f}%'.format(
        final_loss, final_top1, final_top5))
    logger.log(
        'ENAS : run {:} epochs, cost {:.1f} s, last-geno is {:}.'.format(
            total_epoch, search_time.sum, final_arch))
    if api is not None:
        logger.log('{:}'.format(api.query_by_arch(final_arch)))
    logger.close()
def main(xargs):
    """Train an SPOS supernet and shrink its per-edge operation sets.

    Builds the search model and data loaders, then for each shrinking
    iteration calls ``train_func`` followed by ``shrinking`` on the
    per-edge ``operations`` table.

    Args:
        xargs: parsed command-line namespace. Attributes read here:
            ``workers``, ``rand_seed``, ``dataset``, ``data_path``,
            ``config_path``, ``search_space_name``, ``channel``,
            ``num_cells``, ``max_nodes``, ``track_running_stats``,
            ``arch_learning_rate``, ``arch_weight_decay``,
            ``arch_nas_dataset`` and ``epochs``. The whole namespace is also
            passed to ``prepare_logger``, ``train_func`` and ``shrinking``.

    Raises:
        AssertionError: if CUDA is not available.
    """
    assert torch.cuda.is_available(), 'CUDA is not available.'
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.set_num_threads(xargs.workers)
    prepare_seed(xargs.rand_seed)
    # Use the function argument, not a module-level ``args`` global.
    logger = prepare_logger(xargs)

    train_data, valid_data, xshape, class_num = get_datasets(
        xargs.dataset, xargs.data_path, -1)
    config = load_config(xargs.config_path, {
        'class_num': class_num,
        'xshape': xshape
    }, logger)
    search_loader, _, valid_loader = get_nas_search_loaders(
        train_data, valid_data, xargs.dataset, 'configs/nas-benchmark/',
        (config.batch_size, config.test_batch_size), xargs.workers)
    logger.log(
        '||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'
        .format(xargs.dataset, len(search_loader), len(valid_loader),
                config.batch_size))
    logger.log('||||||| {:10s} ||||||| Config={:}'.format(
        xargs.dataset, config))

    search_space = get_search_spaces('cell', xargs.search_space_name)
    model_config = dict2config(
        {
            'name': 'SPOS',
            'C': xargs.channel,
            'N': xargs.num_cells,
            'max_nodes': xargs.max_nodes,
            'num_classes': class_num,
            'space': search_space,
            'affine': False,
            'track_running_stats': bool(xargs.track_running_stats)
        }, None)
    logger.log('search space : {:}'.format(search_space))
    search_model = get_cell_based_tiny_net(model_config)

    w_optimizer, w_scheduler, criterion = get_optim_scheduler(
        search_model.get_weights(), config)
    # NOTE(review): a_optimizer is only logged here; it is never passed to
    # train_func/shrinking below.
    a_optimizer = torch.optim.Adam(search_model.get_alphas(),
                                   lr=xargs.arch_learning_rate,
                                   betas=(0.5, 0.999),
                                   weight_decay=xargs.arch_weight_decay)
    logger.log('w-optimizer : {:}'.format(w_optimizer))
    logger.log('a-optimizer : {:}'.format(a_optimizer))
    logger.log('w-scheduler : {:}'.format(w_scheduler))
    logger.log('criterion : {:}'.format(criterion))
    flop, param = get_model_infos(search_model, xshape)
    logger.log('FLOP = {:.2f} M, Params = {:.2f} MB'.format(flop, param))
    logger.log('search-space : {:}'.format(search_space))
    if xargs.arch_nas_dataset is None:
        api = None
    else:
        api = API(xargs.arch_nas_dataset)
    logger.log('{:} create API = {:} done'.format(time_string(), api))

    # Paths are resolved as in the sibling scripts but not used afterwards.
    last_info, model_base_path, model_best_path = logger.path(
        'info'), logger.path('model'), logger.path('best')
    network, criterion = torch.nn.DataParallel(
        search_model).cuda(), criterion.cuda()

    total_epochs = xargs.epochs
    drop_ops_num = 20
    # Number of train/shrink rounds. Presumably 30 is the total op count
    # (6 edges x 5 ops) -- TODO confirm against the search space.
    iters = 30 // drop_ops_num
    start_iter, epochs, operations = 0, total_epochs // iters, {}
    # One candidate-op list per edge (i, j) with j < i, for nodes 1..3.
    for i in range(1, 4):
        for j in range(i):
            node_str = (i, j)
            operations[node_str] = copy.deepcopy(search_space)
    logger.log('operations={}'.format(operations))

    for i in range(start_iter, iters):
        train_func(xargs, search_loader, valid_loader, network, operations,
                   criterion, w_scheduler, w_optimizer, logger, i, epochs)
        shrinking(xargs, valid_loader, i, network, operations, drop_ops_num,
                  search_space, logger, api)
    logger.close()
def main(xargs):
    """Run the GDAS architecture search, optionally with an SVHN OOD loader.

    Each epoch sets the model's tau from a linear schedule between
    ``xargs.tau_max`` and ``xargs.tau_min``, calls ``search_func``, records
    the genotype and validation accuracy, and writes a checkpoint. If a
    "last-info" file already exists under the logger's directory, the search
    resumes from the checkpoint it points to.

    Args:
        xargs: parsed command-line namespace. Attributes read here:
            ``workers``, ``rand_seed``, ``dataset``, ``data_path``,
            ``config_path``, ``ood_inner``, ``ood_outer``,
            ``search_space_name``, ``model_config``, ``channel``,
            ``num_cells``, ``max_nodes``, ``track_running_stats``,
            ``arch_learning_rate``, ``arch_weight_decay``,
            ``arch_nas_dataset``, ``tau_max`` and ``tau_min``. The whole
            namespace is also passed to ``prepare_logger`` and
            ``search_func`` and stored in every checkpoint.

    Raises:
        AssertionError: if CUDA is not available.
    """
    assert torch.cuda.is_available(), 'CUDA is not available.'
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.set_num_threads(xargs.workers)
    prepare_seed(xargs.rand_seed)
    # Use the function argument, not a module-level ``args`` global.
    logger = prepare_logger(xargs)

    train_data, valid_data, xshape, class_num = get_datasets(
        xargs.dataset, xargs.data_path, -1)
    config = load_config(xargs.config_path, {
        'class_num': class_num,
        'xshape': xshape
    }, logger)
    search_loader, _, valid_loader = get_nas_search_loaders(
        train_data, valid_data, xargs.dataset, 'configs/nas-benchmark/',
        config.batch_size, xargs.workers)

    if xargs.ood_inner or xargs.ood_outer:
        # Presumably the CIFAR-10 per-channel statistics -- TODO confirm.
        mean = [x / 255 for x in [125.3, 123.0, 113.9]]
        std = [x / 255 for x in [63.0, 62.1, 66.7]]
        # No augmentation is applied to the OOD samples, only normalisation.
        lists = [transforms.ToTensor(), transforms.Normalize(mean, std)]
        ood_transform = transforms.Compose(lists)
        ood_data = dset.SVHN(root=xargs.data_path,
                             split='train',
                             download=True,
                             transform=ood_transform)
        # Sample from at most the first len(train_data) SVHN images.
        ood_loader = torch.utils.data.DataLoader(
            ood_data,
            batch_size=config.batch_size,
            sampler=torch.utils.data.sampler.SubsetRandomSampler(
                list(range(len(ood_data)))[:len(train_data)]),
            pin_memory=True,
            num_workers=xargs.workers)
    else:
        ood_loader = None

    logger.log(
        '||||||| {:10s} ||||||| Search-Loader-Num={:}, batch size={:}'.format(
            xargs.dataset, len(search_loader), config.batch_size))
    logger.log('||||||| {:10s} ||||||| Config={:}'.format(
        xargs.dataset, config))

    # Published at module level; presumably other functions in this file
    # read it -- verify before making it local.
    global search_space
    search_space = get_search_spaces('cell', xargs.search_space_name)
    if xargs.model_config is None:
        model_config = dict2config(
            {
                'name': 'GDAS',
                'C': xargs.channel,
                'N': xargs.num_cells,
                'max_nodes': xargs.max_nodes,
                'num_classes': class_num,
                'space': search_space,
                'affine': False,
                'track_running_stats': bool(xargs.track_running_stats),
            }, None)
    else:
        model_config = load_config(
            xargs.model_config, {
                'num_classes': class_num,
                'space': search_space,
                'affine': False,
                'track_running_stats': bool(xargs.track_running_stats)
            }, None)
    search_model = get_cell_based_tiny_net(model_config)
    logger.log('model-config : {:}'.format(model_config))

    w_optimizer, w_scheduler, criterion = get_optim_scheduler(
        search_model.get_weights(), config)
    a_optimizer = torch.optim.Adam(search_model.get_alphas(),
                                   lr=xargs.arch_learning_rate,
                                   betas=(0.5, 0.999),
                                   weight_decay=xargs.arch_weight_decay)
    logger.log('w-optimizer : {:}'.format(w_optimizer))
    logger.log('a-optimizer : {:}'.format(a_optimizer))
    logger.log('w-scheduler : {:}'.format(w_scheduler))
    logger.log('criterion : {:}'.format(criterion))
    flop, param = get_model_infos(search_model, xshape)
    logger.log('FLOP = {:.2f} M, Params = {:.2f} MB'.format(flop, param))
    logger.log('search-space [{:} ops] : {:}'.format(len(search_space),
                                                     search_space))
    if xargs.arch_nas_dataset is None:
        api = None
    else:
        api = API(xargs.arch_nas_dataset)
    logger.log('{:} create API = {:} done'.format(time_string(), api))

    last_info, model_base_path, model_best_path = logger.path(
        'info'), logger.path('model'), logger.path('best')
    # Single-GPU: the model is moved to CUDA without a DataParallel wrapper.
    network, criterion = search_model.cuda(), criterion.cuda()

    if last_info.exists():  # automatically resume from previous checkpoint
        logger.log("=> loading checkpoint of the last-info '{:}' start".format(
            last_info))
        last_info = torch.load(last_info)
        start_epoch = last_info['epoch']
        checkpoint = torch.load(last_info['last_checkpoint'])
        genotypes = checkpoint['genotypes']
        valid_accuracies = checkpoint['valid_accuracies']
        search_model.load_state_dict(checkpoint['search_model'])
        w_scheduler.load_state_dict(checkpoint['w_scheduler'])
        w_optimizer.load_state_dict(checkpoint['w_optimizer'])
        a_optimizer.load_state_dict(checkpoint['a_optimizer'])
        logger.log(
            "=> loading checkpoint of the last-info '{:}' start with {:}-th epoch."
            .format(last_info, start_epoch))
    else:
        logger.log("=> do not find the last-info file : {:}".format(last_info))
        start_epoch, valid_accuracies, genotypes = 0, {
            'best': -1
        }, {
            -1: search_model.genotype()
        }

    # start training
    start_time, search_time, epoch_time, total_epoch = time.time(
    ), AverageMeter(), AverageMeter(), config.epochs + config.warmup
    # Guard the schedule denominator: a one-epoch run would otherwise divide
    # by zero; with the guard it simply stays at tau_max.
    tau_steps = max(total_epoch - 1, 1)
    for epoch in range(start_epoch, total_epoch):
        w_scheduler.update(epoch, 0.0)
        need_time = 'Time Left: {:}'.format(
            convert_secs2time(epoch_time.val * (total_epoch - epoch), True))
        epoch_str = '{:03d}-{:03d}'.format(epoch, total_epoch)
        search_model.set_tau(xargs.tau_max -
                             (xargs.tau_max - xargs.tau_min) * epoch /
                             tau_steps)
        logger.log('\n[Search the {:}-th epoch] {:}, tau={:}, LR={:}'.format(
            epoch_str, need_time, search_model.get_tau(),
            min(w_scheduler.get_lr())))

        (search_w_loss, search_w_top1, search_w_top5, valid_a_loss,
         valid_a_top1, valid_a_top5) = search_func(
             search_loader, network, criterion, w_scheduler, w_optimizer,
             a_optimizer, epoch_str, xargs, logger, ood_loader)
        search_time.update(time.time() - start_time)
        logger.log(
            '[{:}] searching : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s'
            .format(epoch_str, search_w_loss, search_w_top1, search_w_top5,
                    search_time.sum))
        logger.log(
            '[{:}] evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'
            .format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5))
        # check the best accuracy
        valid_accuracies[epoch] = valid_a_top1
        if valid_a_top1 > valid_accuracies['best']:
            valid_accuracies['best'] = valid_a_top1
            genotypes['best'] = search_model.genotype()
            find_best = True
        else:
            find_best = False

        genotypes[epoch] = search_model.genotype()
        logger.log('<<<--->>> The {:}-th epoch : {:}'.format(
            epoch_str, genotypes[epoch]))
        # save checkpoint
        save_path = save_checkpoint(
            {
                'epoch': epoch + 1,
                'args': deepcopy(xargs),
                'search_model': search_model.state_dict(),
                'w_optimizer': w_optimizer.state_dict(),
                'a_optimizer': a_optimizer.state_dict(),
                'w_scheduler': w_scheduler.state_dict(),
                'genotypes': genotypes,
                'valid_accuracies': valid_accuracies
            }, model_base_path, logger)
        last_info = save_checkpoint(
            {
                'epoch': epoch + 1,
                'args': deepcopy(xargs),
                'last_checkpoint': save_path,
            }, logger.path('info'), logger)
        if find_best:
            logger.log(
                '<<<--->>> The {:}-th epoch : find the highest validation accuracy : {:.2f}%.'
                .format(epoch_str, valid_a_top1))
            copy_checkpoint(model_base_path, model_best_path, logger)
        with torch.no_grad():
            logger.log('{:}'.format(search_model.show_alphas()))
        if api is not None:
            logger.log('{:}'.format(api.query_by_arch(genotypes[epoch],
                                                      '200')))
        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()

    logger.log('\n' + '-' * 100)
    # check the performance from the architecture dataset
    logger.log(
        'GDAS : run {:} epochs, cost {:.1f} s, last-geno is {:}.'.format(
            total_epoch, search_time.sum, genotypes[total_epoch - 1]))
    if api is not None:
        logger.log('{:}'.format(
            api.query_by_arch(genotypes[total_epoch - 1], '200')))
    logger.close()
def main(xargs):
    """Evaluate every architecture of a trained SPOS supernet.

    Loads supernet weights from a checkpoint path built from
    ``xargs.dataset``, ``xargs.rand_seed`` and ``xargs.epoch``, then, for
    each architecture returned by ``get_all_archs()`` (in shuffled order),
    switches the network to that genotype, recalculates batch-norm statistics
    on the search loader and validates it. The top-1 accuracies, keyed by
    ``genotype.tostr()``, are saved to ``<xargs.save_dir>/result.dat``.

    Args:
        xargs: parsed command-line namespace. Attributes read here:
            ``workers``, ``rand_seed``, ``dataset``, ``data_path``,
            ``config_path``, ``search_space_name``, ``channel``,
            ``num_cells``, ``max_nodes``, ``track_running_stats``,
            ``arch_learning_rate``, ``arch_weight_decay``,
            ``arch_nas_dataset``, ``epoch`` and ``save_dir``. The whole
            namespace is also passed to ``prepare_logger``.

    Raises:
        AssertionError: if CUDA is not available.
    """
    assert torch.cuda.is_available(), 'CUDA is not available.'
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.set_num_threads(xargs.workers)
    prepare_seed(xargs.rand_seed)
    # Use the function argument, not a module-level ``args`` global.
    logger = prepare_logger(xargs)

    train_data, valid_data, xshape, class_num = get_datasets(
        xargs.dataset, xargs.data_path, -1)
    config = load_config(xargs.config_path, {
        'class_num': class_num,
        'xshape': xshape
    }, logger)
    print(config)
    search_loader, _, valid_loader = get_nas_search_loaders(
        train_data, valid_data, xargs.dataset, 'configs/nas-benchmark/',
        (config.batch_size, config.test_batch_size), xargs.workers)
    logger.log(
        '||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'
        .format(xargs.dataset, len(search_loader), len(valid_loader),
                config.batch_size))
    logger.log('||||||| {:10s} ||||||| Config={:}'.format(
        xargs.dataset, config))

    search_space = get_search_spaces('cell', xargs.search_space_name)
    model_config = dict2config(
        {
            'name': 'SPOS',
            'C': xargs.channel,
            'N': xargs.num_cells,
            'max_nodes': xargs.max_nodes,
            'num_classes': class_num,
            'space': search_space,
            'affine': False,
            'track_running_stats': bool(xargs.track_running_stats)
        }, None)
    logger.log('search space : {:}'.format(search_space))
    model = get_cell_based_tiny_net(model_config)

    w_optimizer, w_scheduler, criterion = get_optim_scheduler(
        model.get_weights(), config)
    # NOTE(review): the optimizers/scheduler are only logged; no training
    # happens in this function.
    a_optimizer = torch.optim.Adam(model.get_alphas(),
                                   lr=xargs.arch_learning_rate,
                                   betas=(0.5, 0.999),
                                   weight_decay=xargs.arch_weight_decay)
    logger.log('w-optimizer : {:}'.format(w_optimizer))
    logger.log('a-optimizer : {:}'.format(a_optimizer))
    logger.log('w-scheduler : {:}'.format(w_scheduler))
    logger.log('criterion : {:}'.format(criterion))
    flop, param = get_model_infos(model, xshape)
    logger.log('FLOP = {:.2f} M, Params = {:.2f} MB'.format(flop, param))
    logger.log('search-space : {:}'.format(search_space))
    if xargs.arch_nas_dataset is None:
        api = None
    else:
        api = API(xargs.arch_nas_dataset)
    logger.log('{:} create API = {:} done'.format(time_string(), api))

    # Paths are resolved as in the sibling scripts but not used afterwards.
    last_info, model_base_path, model_best_path = logger.path(
        'info'), logger.path('model'), logger.path('best')
    network, criterion = torch.nn.DataParallel(model).cuda(), criterion.cuda()

    # The path is always a string, so the checkpoint is loaded
    # unconditionally; torch.load raises if the file is missing.
    checkpoint_path = 'output/search-cell-nas-bench-102/result-{}/checkpoint/seed-{}_epoch-{}.pth'.format(
        xargs.dataset, xargs.rand_seed, xargs.epoch)
    logger.log("=> loading checkpoint from {}".format(checkpoint_path))
    checkpoint = torch.load(checkpoint_path)
    model.load_state_dict(checkpoint['search_model'])

    # start inference
    total_epoch = config.epochs + config.warmup
    all_archs = network.module.get_all_archs()
    random.shuffle(all_archs)

    valid_accuracies = {}
    genotype = None  # stays None when there is nothing to evaluate
    process_start_time = time.time()
    for i, genotype in enumerate(all_archs):
        network.module.set_cal_mode('dynamic', genotype)
        recalculate_bn(network, search_loader)
        valid_a_loss, valid_a_top1, valid_a_top5 = valid_func(
            valid_loader, network, criterion)
        logger.log(
            '[{:}] evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}% | {:}'
            .format(i, valid_a_loss, valid_a_top1, valid_a_top5, genotype))
        valid_accuracies[genotype.tostr()] = valid_a_top1
    elapsed = time.time() - process_start_time
    logger.log('process time: {}'.format(elapsed))

    torch.save(valid_accuracies, '{}/result.dat'.format(xargs.save_dir))

    logger.log('\n' + '-' * 100)
    # check the performance from the architecture dataset
    # Report the measured evaluation time rather than a never-updated meter.
    logger.log(
        'SPOS : run {:} epochs, cost {:.1f} s, last-geno is {:}.'.format(
            total_epoch, elapsed, genotype))
    if api is not None and genotype is not None:
        logger.log('{:}'.format(api.query_by_arch(genotype)))
    logger.close()
def main(xargs):
    """Run the ENAS search loop (shared CNN + controller) and log the result.

    Each epoch trains the shared CNN, then the controller, samples the best
    architecture on the validation loader, and writes a checkpoint. If a
    "last-info" file already exists under the logger's directory, training
    resumes from the checkpoint it points to.

    Args:
        xargs: parsed command-line namespace. Attributes read here:
            ``workers``, ``rand_seed``, ``dataset``, ``data_path``,
            ``config_path``, ``search_space_name``, ``channel``,
            ``num_cells``, ``max_nodes``, ``track_running_stats``,
            ``arch_nas_dataset``, ``print_freq``,
            ``controller_train_steps``, ``controller_num_aggregate``,
            ``controller_entropy_weight``, ``controller_bl_dec`` and
            ``controller_num_samples``. The whole namespace is also handed
            to ``prepare_logger`` and stored in every checkpoint.

    Raises:
        AssertionError: if CUDA is not available.
    """
    assert torch.cuda.is_available(), "CUDA is not available."
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.set_num_threads(xargs.workers)
    prepare_seed(xargs.rand_seed)
    # Use the function argument, not a module-level ``args`` global.
    logger = prepare_logger(xargs)

    train_data, test_data, xshape, class_num = get_datasets(
        xargs.dataset, xargs.data_path, -1)
    logger.log("use config from : {:}".format(xargs.config_path))
    config = load_config(xargs.config_path, {
        "class_num": class_num,
        "xshape": xshape
    }, logger)
    _, train_loader, valid_loader = get_nas_search_loaders(
        train_data,
        test_data,
        xargs.dataset,
        "configs/nas-benchmark/",
        config.batch_size,
        xargs.workers,
    )
    # since ENAS will train the controller on valid-loader, we need to use
    # train transformation for valid-loader
    valid_loader.dataset.transform = deepcopy(train_loader.dataset.transform)
    if hasattr(valid_loader.dataset, "transforms"):
        valid_loader.dataset.transforms = deepcopy(
            train_loader.dataset.transforms)
    # data loader
    logger.log(
        "||||||| {:10s} ||||||| Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}"
        .format(xargs.dataset, len(train_loader), len(valid_loader),
                config.batch_size))
    logger.log("||||||| {:10s} ||||||| Config={:}".format(
        xargs.dataset, config))

    search_space = get_search_spaces("cell", xargs.search_space_name)
    model_config = dict2config(
        {
            "name": "ENAS",
            "C": xargs.channel,
            "N": xargs.num_cells,
            "max_nodes": xargs.max_nodes,
            "num_classes": class_num,
            "space": search_space,
            "affine": False,
            "track_running_stats": bool(xargs.track_running_stats),
        },
        None,
    )
    shared_cnn = get_cell_based_tiny_net(model_config)
    controller = shared_cnn.create_controller()

    w_optimizer, w_scheduler, criterion = get_optim_scheduler(
        shared_cnn.parameters(), config)
    a_optimizer = torch.optim.Adam(
        controller.parameters(),
        lr=config.controller_lr,
        betas=config.controller_betas,
        eps=config.controller_eps,
    )
    logger.log("w-optimizer : {:}".format(w_optimizer))
    logger.log("a-optimizer : {:}".format(a_optimizer))
    logger.log("w-scheduler : {:}".format(w_scheduler))
    logger.log("criterion : {:}".format(criterion))
    logger.log("search-space : {:}".format(search_space))
    if xargs.arch_nas_dataset is None:
        api = None
    else:
        api = API(xargs.arch_nas_dataset)
    logger.log("{:} create API = {:} done".format(time_string(), api))
    shared_cnn, controller, criterion = (
        torch.nn.DataParallel(shared_cnn).cuda(),
        controller.cuda(),
        criterion.cuda(),
    )

    last_info, model_base_path, model_best_path = (
        logger.path("info"),
        logger.path("model"),
        logger.path("best"),
    )

    if last_info.exists():  # automatically resume from previous checkpoint
        logger.log("=> loading checkpoint of the last-info '{:}' start".format(
            last_info))
        last_info = torch.load(last_info)
        start_epoch = last_info["epoch"]
        checkpoint = torch.load(last_info["last_checkpoint"])
        genotypes = checkpoint["genotypes"]
        baseline = checkpoint["baseline"]
        valid_accuracies = checkpoint["valid_accuracies"]
        shared_cnn.load_state_dict(checkpoint["shared_cnn"])
        controller.load_state_dict(checkpoint["controller"])
        w_scheduler.load_state_dict(checkpoint["w_scheduler"])
        w_optimizer.load_state_dict(checkpoint["w_optimizer"])
        a_optimizer.load_state_dict(checkpoint["a_optimizer"])
        logger.log(
            "=> loading checkpoint of the last-info '{:}' start with {:}-th epoch."
            .format(last_info, start_epoch))
    else:
        logger.log("=> do not find the last-info file : {:}".format(last_info))
        start_epoch, valid_accuracies, genotypes, baseline = 0, {
            "best": -1
        }, {}, None

    # start training
    start_time, search_time, epoch_time, total_epoch = (
        time.time(),
        AverageMeter(),
        AverageMeter(),
        config.epochs + config.warmup,
    )
    for epoch in range(start_epoch, total_epoch):
        w_scheduler.update(epoch, 0.0)
        need_time = "Time Left: {:}".format(
            convert_secs2time(epoch_time.val * (total_epoch - epoch), True))
        epoch_str = "{:03d}-{:03d}".format(epoch, total_epoch)
        logger.log(
            "\n[Search the {:}-th epoch] {:}, LR={:}, baseline={:}".format(
                epoch_str, need_time, min(w_scheduler.get_lr()), baseline))

        cnn_loss, cnn_top1, cnn_top5 = train_shared_cnn(
            train_loader,
            shared_cnn,
            controller,
            criterion,
            w_scheduler,
            w_optimizer,
            epoch_str,
            xargs.print_freq,
            logger,
        )
        logger.log(
            "[{:}] shared-cnn : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%"
            .format(epoch_str, cnn_loss, cnn_top1, cnn_top5))
        ctl_loss, ctl_acc, ctl_baseline, ctl_reward, baseline = train_controller(
            valid_loader,
            shared_cnn,
            controller,
            criterion,
            a_optimizer,
            dict2config(
                {
                    "baseline": baseline,
                    "ctl_train_steps": xargs.controller_train_steps,
                    "ctl_num_aggre": xargs.controller_num_aggregate,
                    "ctl_entropy_w": xargs.controller_entropy_weight,
                    "ctl_bl_dec": xargs.controller_bl_dec,
                },
                None,
            ),
            epoch_str,
            xargs.print_freq,
            logger,
        )
        search_time.update(time.time() - start_time)
        logger.log(
            "[{:}] controller : loss={:.2f}, accuracy={:.2f}%, baseline={:.2f}, reward={:.2f}, current-baseline={:.4f}, time-cost={:.1f} s"
            .format(
                epoch_str,
                ctl_loss,
                ctl_acc,
                ctl_baseline,
                ctl_reward,
                baseline,
                search_time.sum,
            ))
        best_arch, _ = get_best_arch(controller, shared_cnn, valid_loader)
        shared_cnn.module.update_arch(best_arch)
        _, best_valid_acc, _ = valid_func(valid_loader, shared_cnn, criterion)

        genotypes[epoch] = best_arch
        # check the best accuracy
        valid_accuracies[epoch] = best_valid_acc
        if best_valid_acc > valid_accuracies["best"]:
            valid_accuracies["best"] = best_valid_acc
            genotypes["best"] = best_arch
            find_best = True
        else:
            find_best = False

        logger.log("<<<--->>> The {:}-th epoch : {:}".format(
            epoch_str, genotypes[epoch]))
        # save checkpoint
        save_path = save_checkpoint(
            {
                "epoch": epoch + 1,
                "args": deepcopy(xargs),
                "baseline": baseline,
                "shared_cnn": shared_cnn.state_dict(),
                "controller": controller.state_dict(),
                "w_optimizer": w_optimizer.state_dict(),
                "a_optimizer": a_optimizer.state_dict(),
                "w_scheduler": w_scheduler.state_dict(),
                "genotypes": genotypes,
                "valid_accuracies": valid_accuracies,
            },
            model_base_path,
            logger,
        )
        last_info = save_checkpoint(
            {
                "epoch": epoch + 1,
                "args": deepcopy(xargs),
                "last_checkpoint": save_path,
            },
            logger.path("info"),
            logger,
        )
        if find_best:
            logger.log(
                "<<<--->>> The {:}-th epoch : find the highest validation accuracy : {:.2f}%."
                .format(epoch_str, best_valid_acc))
            copy_checkpoint(model_base_path, model_best_path, logger)
        if api is not None:
            logger.log("{:}".format(api.query_by_arch(genotypes[epoch],
                                                      "200")))
        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()

    logger.log("\n" + "-" * 100)
    logger.log("During searching, the best architecture is {:}".format(
        genotypes["best"]))
    logger.log("Its accuracy is {:.2f}%".format(valid_accuracies["best"]))
    logger.log("Randomly select {:} architectures and select the best.".format(
        xargs.controller_num_samples))
    start_time = time.time()
    final_arch, _ = get_best_arch(controller, shared_cnn, valid_loader,
                                  xargs.controller_num_samples)
    search_time.update(time.time() - start_time)
    shared_cnn.module.update_arch(final_arch)
    final_loss, final_top1, final_top5 = valid_func(valid_loader, shared_cnn,
                                                    criterion)
    logger.log("The Selected Final Architecture : {:}".format(final_arch))
    logger.log("Loss={:.3f}, Accuracy@1={:.2f}%, Accuracy@5={:.2f}%".format(
        final_loss, final_top1, final_top5))
    logger.log(
        "ENAS : run {:} epochs, cost {:.1f} s, last-geno is {:}.".format(
            total_epoch, search_time.sum, final_arch))
    # NOTE(review): unlike the in-loop query above, this call passes no
    # "200" argument -- confirm which form the API expects.
    if api is not None:
        logger.log("{:}".format(api.query_by_arch(final_arch)))
    logger.close()
def main(xargs):
    """Run the weight-sharing search over the "nats-bench" search space.

    The function seeds the run, builds the data loaders and the "generic"
    search-shape super-network, then alternates ``search_func`` and
    ``valid_func`` for ``config.epochs + config.warmup`` epochs.  After
    each epoch the current genotype and validation accuracy are recorded
    and a checkpoint is written, so an interrupted run resumes from the
    last saved epoch.

    Args:
        xargs: Parsed command-line namespace.  Attributes read here:
            ``workers``, ``rand_seed``, ``dataset``, ``data_path``,
            ``overwite_epochs``, ``config_path``, ``search_space``,
            ``genotype``, ``affine``, ``track_running_stats``, ``algo``,
            ``arch_learning_rate``, ``arch_weight_decay``, ``arch_eps``,
            ``use_api``, ``warmup_ratio``, ``tau_max``, ``tau_min`` and
            ``print_freq``.

    Raises:
        AssertionError: If CUDA is not available.
    """
    assert torch.cuda.is_available(), "CUDA is not available."
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.set_num_threads(xargs.workers)
    prepare_seed(xargs.rand_seed)
    # Use the namespace we were handed, not a module-level ``args`` global,
    # so that main() also works when called from another module.
    logger = prepare_logger(xargs)

    train_data, valid_data, xshape, class_num = get_datasets(
        xargs.dataset, xargs.data_path, -1)
    extra_info = {"class_num": class_num, "xshape": xshape}
    if xargs.overwite_epochs is not None:
        # Optional command-line override of the epoch count in the config.
        extra_info["epochs"] = xargs.overwite_epochs
    config = load_config(xargs.config_path, extra_info, logger)
    search_loader, _, valid_loader = get_nas_search_loaders(
        train_data,
        valid_data,
        xargs.dataset,
        "configs/nas-benchmark/",
        (config.batch_size, config.test_batch_size),
        xargs.workers,
    )
    logger.log(
        "||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}"
        .format(xargs.dataset, len(search_loader), len(valid_loader),
                config.batch_size))
    logger.log("||||||| {:10s} ||||||| Config={:}".format(
        xargs.dataset, config))

    search_space = get_search_spaces(xargs.search_space, "nats-bench")
    model_config = dict2config(
        dict(
            name="generic",
            super_type="search-shape",
            candidate_Cs=search_space["candidates"],
            max_num_Cs=search_space["numbers"],
            num_classes=class_num,
            genotype=xargs.genotype,
            affine=bool(xargs.affine),
            track_running_stats=bool(xargs.track_running_stats),
        ),
        None,
    )
    logger.log("search space : {:}".format(search_space))
    logger.log("model config : {:}".format(model_config))
    search_model = get_cell_based_tiny_net(model_config)
    search_model.set_algo(xargs.algo)
    logger.log("{:}".format(search_model))

    # Separate optimizers for the network weights and the architecture
    # parameters (alphas).
    w_optimizer, w_scheduler, criterion = get_optim_scheduler(
        search_model.weights, config)
    a_optimizer = torch.optim.Adam(
        search_model.alphas,
        lr=xargs.arch_learning_rate,
        betas=(0.5, 0.999),
        weight_decay=xargs.arch_weight_decay,
        eps=xargs.arch_eps,
    )
    logger.log("w-optimizer : {:}".format(w_optimizer))
    logger.log("a-optimizer : {:}".format(a_optimizer))
    logger.log("w-scheduler : {:}".format(w_scheduler))
    logger.log("criterion : {:}".format(criterion))
    params = count_parameters_in_MB(search_model)
    logger.log("The parameters of the search model = {:.2f} MB".format(params))
    logger.log("search-space : {:}".format(search_space))
    if bool(xargs.use_api):
        api = create(None, "size", fast_mode=True, verbose=False)
    else:
        api = None
    logger.log("{:} create API = {:} done".format(time_string(), api))

    info_path, model_base_path = logger.path("info"), logger.path("model")
    # use a single GPU
    network, criterion = search_model.cuda(), criterion.cuda()

    if info_path.exists():  # automatically resume from previous checkpoint
        logger.log("=> loading checkpoint of the last-info '{:}' start".format(
            info_path))
        last_info = torch.load(info_path)
        start_epoch = last_info["epoch"]
        checkpoint = torch.load(last_info["last_checkpoint"])
        genotypes = checkpoint["genotypes"]
        valid_accuracies = checkpoint["valid_accuracies"]
        search_model.load_state_dict(checkpoint["search_model"])
        w_scheduler.load_state_dict(checkpoint["w_scheduler"])
        w_optimizer.load_state_dict(checkpoint["w_optimizer"])
        a_optimizer.load_state_dict(checkpoint["a_optimizer"])
        # Report the path, not the loaded dictionary, in the message.
        logger.log(
            "=> loading checkpoint of the last-info '{:}' start with {:}-th epoch."
            .format(info_path, start_epoch))
    else:
        logger.log("=> do not find the last-info file : {:}".format(info_path))
        start_epoch, valid_accuracies = 0, {"best": -1}
        # NOTE(review): presumably ``network.random`` is a randomly drawn
        # architecture used as the pre-training entry -- confirm.
        genotypes = {-1: network.random}

    # start training
    start_time, search_time, epoch_time = (
        time.time(), AverageMeter(), AverageMeter())
    total_epoch = config.epochs + config.warmup
    for epoch in range(start_epoch, total_epoch):
        w_scheduler.update(epoch, 0.0)
        need_time = "Time Left: {:}".format(
            convert_secs2time(epoch_time.val * (total_epoch - epoch), True))
        epoch_str = "{:03d}-{:03d}".format(epoch, total_epoch)

        # The controller stays disabled for the first ``warmup_ratio``
        # fraction of the run; during that phase the network receives a
        # ratio that shrinks linearly from 1 towards 0.
        progress = float(epoch) / total_epoch
        if xargs.warmup_ratio is None or xargs.warmup_ratio <= progress:
            enable_controller = True
            network.set_warmup_ratio(None)
        else:
            enable_controller = False
            network.set_warmup_ratio(1.0 - progress / xargs.warmup_ratio)

        logger.log(
            "\n[Search the {:}-th epoch] {:}, LR={:}, controller-warmup={:}, enable_controller={:}"
            .format(
                epoch_str,
                need_time,
                min(w_scheduler.get_lr()),
                network.warmup_ratio,
                enable_controller,
            ))

        if xargs.algo == "mask_gumbel" or xargs.algo == "tas":
            # Linear anneal from tau_max to tau_min.  The denominator is
            # clamped so that a one-epoch run does not divide by zero.
            network.set_tau(xargs.tau_max -
                            (xargs.tau_max - xargs.tau_min) * epoch /
                            max(total_epoch - 1, 1))
            logger.log("[RESET tau as : {:}]".format(network.tau))

        (
            search_w_loss,
            search_w_top1,
            search_w_top5,
            search_a_loss,
            search_a_top1,
            search_a_top5,
        ) = search_func(
            search_loader,
            network,
            criterion,
            w_scheduler,
            w_optimizer,
            a_optimizer,
            enable_controller,
            xargs.algo,
            epoch_str,
            xargs.print_freq,
            logger,
        )
        search_time.update(time.time() - start_time)
        logger.log(
            "[{:}] search [base] : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s"
            .format(epoch_str, search_w_loss, search_w_top1, search_w_top5,
                    search_time.sum))
        logger.log(
            "[{:}] search [arch] : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%"
            .format(epoch_str, search_a_loss, search_a_top1, search_a_top5))

        genotype = network.genotype
        logger.log("[{:}] - [get_best_arch] : {:}".format(epoch_str, genotype))
        valid_a_loss, valid_a_top1, valid_a_top5 = valid_func(
            valid_loader, network, criterion, logger)
        logger.log(
            "[{:}] evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}% | {:}"
            .format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5,
                    genotype))
        valid_accuracies[epoch] = valid_a_top1
        genotypes[epoch] = genotype
        logger.log("<<<--->>> The {:}-th epoch : {:}".format(
            epoch_str, genotypes[epoch]))

        # save checkpoint: first the full training state, then a small
        # "info" record pointing at it (this is what the resume path reads)
        save_path = save_checkpoint(
            {
                "epoch": epoch + 1,
                "args": deepcopy(xargs),
                "search_model": search_model.state_dict(),
                "w_optimizer": w_optimizer.state_dict(),
                "a_optimizer": a_optimizer.state_dict(),
                "w_scheduler": w_scheduler.state_dict(),
                "genotypes": genotypes,
                "valid_accuracies": valid_accuracies,
            },
            model_base_path,
            logger,
        )
        save_checkpoint(
            {
                "epoch": epoch + 1,
                "args": deepcopy(xargs),
                "last_checkpoint": save_path,
            },
            info_path,
            logger,
        )

        with torch.no_grad():
            logger.log("{:}".format(search_model.show_alphas()))
        if api is not None:
            logger.log("{:}".format(api.query_by_arch(genotypes[epoch], "90")))
        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()

    # the final post procedure : count the time
    start_time = time.time()
    genotype = network.genotype
    search_time.update(time.time() - start_time)

    valid_a_loss, valid_a_top1, valid_a_top5 = valid_func(
        valid_loader, network, criterion, logger)
    logger.log(
        "Last : the gentotype is : {:}, with the validation accuracy of {:.3f}%."
        .format(genotype, valid_a_top1))

    logger.log("\n" + "-" * 100)
    # check the performance from the architecture dataset
    logger.log("[{:}] run {:} epochs, cost {:.1f} s, last-geno is {:}.".format(
        xargs.algo, total_epoch, search_time.sum, genotype))
    if api is not None:
        logger.log("{:}".format(api.query_by_arch(genotype, "90")))
    logger.close()
def main(xargs):
    """Train a "RANDOM" cell-based super-network and pick the best cell.

    Every epoch runs ``search_func`` on the search loader, evaluates with
    ``valid_func``, asks ``search_find_best`` for a candidate architecture
    and writes a checkpoint.  Whenever the validation top-1 accuracy beats
    the stored best, the checkpoint is also copied to the "best" path.
    After the last epoch ``search_find_best`` is called once more and its
    result is reported as the final architecture.

    Args:
        xargs: Parsed command-line namespace.  Attributes read here:
            ``workers``, ``rand_seed``, ``dataset``, ``data_path``,
            ``config_path``, ``search_space_name``, ``channel``,
            ``num_cells``, ``max_nodes``, ``track_running_stats``,
            ``arch_nas_dataset``, ``print_freq`` and ``select_num``.

    Raises:
        AssertionError: If CUDA is not available.
    """
    assert torch.cuda.is_available(), "CUDA is not available."
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.set_num_threads(xargs.workers)
    prepare_seed(xargs.rand_seed)
    # Use the namespace we were handed, not a module-level ``args`` global.
    logger = prepare_logger(xargs)

    train_data, valid_data, xshape, class_num = get_datasets(
        xargs.dataset, xargs.data_path, -1)
    config = load_config(xargs.config_path, {
        "class_num": class_num,
        "xshape": xshape
    }, logger)
    search_loader, _, valid_loader = get_nas_search_loaders(
        train_data,
        valid_data,
        xargs.dataset,
        "configs/nas-benchmark/",
        (config.batch_size, config.test_batch_size),
        xargs.workers,
    )
    logger.log(
        "||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}"
        .format(xargs.dataset, len(search_loader), len(valid_loader),
                config.batch_size))
    logger.log("||||||| {:10s} ||||||| Config={:}".format(
        xargs.dataset, config))

    search_space = get_search_spaces("cell", xargs.search_space_name)
    model_config = dict2config(
        {
            "name": "RANDOM",
            "C": xargs.channel,
            "N": xargs.num_cells,
            "max_nodes": xargs.max_nodes,
            "num_classes": class_num,
            "space": search_space,
            "affine": False,
            "track_running_stats": bool(xargs.track_running_stats),
        },
        None,
    )
    search_model = get_cell_based_tiny_net(model_config)

    w_optimizer, w_scheduler, criterion = get_optim_scheduler(
        search_model.parameters(), config)
    logger.log("w-optimizer : {:}".format(w_optimizer))
    logger.log("w-scheduler : {:}".format(w_scheduler))
    logger.log("criterion : {:}".format(criterion))
    if xargs.arch_nas_dataset is None:
        api = None
    else:
        api = API(xargs.arch_nas_dataset)
    logger.log("{:} create API = {:} done".format(time_string(), api))

    info_path, model_base_path, model_best_path = (
        logger.path("info"),
        logger.path("model"),
        logger.path("best"),
    )
    # ``network`` is the DataParallel wrapper used for forward passes;
    # state dicts are saved from / loaded into the bare ``search_model``.
    network, criterion = torch.nn.DataParallel(
        search_model).cuda(), criterion.cuda()

    if info_path.exists():  # automatically resume from previous checkpoint
        logger.log("=> loading checkpoint of the last-info '{:}' start".format(
            info_path))
        last_info = torch.load(info_path)
        start_epoch = last_info["epoch"]
        checkpoint = torch.load(last_info["last_checkpoint"])
        genotypes = checkpoint["genotypes"]
        valid_accuracies = checkpoint["valid_accuracies"]
        search_model.load_state_dict(checkpoint["search_model"])
        w_scheduler.load_state_dict(checkpoint["w_scheduler"])
        w_optimizer.load_state_dict(checkpoint["w_optimizer"])
        # Report the path, not the loaded dictionary, in the message.
        logger.log(
            "=> loading checkpoint of the last-info '{:}' start with {:}-th epoch."
            .format(info_path, start_epoch))
    else:
        logger.log("=> do not find the last-info file : {:}".format(info_path))
        start_epoch, valid_accuracies, genotypes = 0, {"best": -1}, {}

    # start training
    start_time, search_time, epoch_time = (
        time.time(), AverageMeter(), AverageMeter())
    total_epoch = config.epochs + config.warmup
    for epoch in range(start_epoch, total_epoch):
        w_scheduler.update(epoch, 0.0)
        need_time = "Time Left: {:}".format(
            convert_secs2time(epoch_time.val * (total_epoch - epoch), True))
        epoch_str = "{:03d}-{:03d}".format(epoch, total_epoch)
        logger.log("\n[Search the {:}-th epoch] {:}, LR={:}".format(
            epoch_str, need_time, min(w_scheduler.get_lr())))

        search_w_loss, search_w_top1, search_w_top5 = search_func(
            search_loader,
            network,
            criterion,
            w_scheduler,
            w_optimizer,
            epoch_str,
            xargs.print_freq,
            logger,
        )
        search_time.update(time.time() - start_time)
        logger.log(
            "[{:}] searching : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s"
            .format(epoch_str, search_w_loss, search_w_top1, search_w_top5,
                    search_time.sum))
        valid_a_loss, valid_a_top1, valid_a_top5 = valid_func(
            valid_loader, network, criterion)
        logger.log(
            "[{:}] evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%"
            .format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5))
        cur_arch, cur_valid_acc = search_find_best(valid_loader, network,
                                                   xargs.select_num)
        logger.log("[{:}] find-the-best : {:}, accuracy@1={:.2f}%".format(
            epoch_str, cur_arch, cur_valid_acc))
        genotypes[epoch] = cur_arch

        # check the best accuracy (tracked on valid_func's top-1, not on
        # the accuracy returned by search_find_best)
        valid_accuracies[epoch] = valid_a_top1
        find_best = valid_a_top1 > valid_accuracies["best"]
        if find_best:
            valid_accuracies["best"] = valid_a_top1

        # save checkpoint: first the full training state, then a small
        # "info" record pointing at it (this is what the resume path reads)
        save_path = save_checkpoint(
            {
                "epoch": epoch + 1,
                "args": deepcopy(xargs),
                "search_model": search_model.state_dict(),
                "w_optimizer": w_optimizer.state_dict(),
                "w_scheduler": w_scheduler.state_dict(),
                "genotypes": genotypes,
                "valid_accuracies": valid_accuracies,
            },
            model_base_path,
            logger,
        )
        save_checkpoint(
            {
                "epoch": epoch + 1,
                "args": deepcopy(xargs),
                "last_checkpoint": save_path,
            },
            info_path,
            logger,
        )
        if find_best:
            logger.log(
                "<<<--->>> The {:}-th epoch : find the highest validation accuracy : {:.2f}%."
                .format(epoch_str, valid_a_top1))
            copy_checkpoint(model_base_path, model_best_path, logger)
        if api is not None:
            logger.log("{:}".format(api.query_by_arch(genotypes[epoch],
                                                      "200")))
        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()

    logger.log("\n" + "-" * 200)
    logger.log("Pre-searching costs {:.1f} s".format(search_time.sum))
    # The final selection is timed and added to the search cost.
    start_time = time.time()
    best_arch, best_acc = search_find_best(valid_loader, network,
                                           xargs.select_num)
    search_time.update(time.time() - start_time)
    logger.log(
        "RANDOM-NAS finds the best one : {:} with accuracy={:.2f}%, with {:.1f} s."
        .format(best_arch, best_acc, search_time.sum))
    if api is not None:
        logger.log("{:}".format(api.query_by_arch(best_arch, "200")))
    logger.close()
def main(xargs):
    """Train the "SPOS" super-network and save one checkpoint per epoch.

    The function seeds the run, builds the loaders and the super-network,
    then calls ``search_func`` once per epoch for ``xargs.epochs`` epochs.
    After each epoch the model, both optimizers and the scheduler are
    saved to ``"<model path>_epoch-<epoch>.pth"``.  No validation and no
    resume logic is performed here.

    Args:
        xargs: Parsed command-line namespace.  Attributes read here:
            ``workers``, ``rand_seed``, ``dataset``, ``data_path``,
            ``config_path``, ``search_space_name``, ``channel``,
            ``num_cells``, ``max_nodes``, ``track_running_stats``,
            ``arch_learning_rate``, ``arch_weight_decay``,
            ``arch_nas_dataset``, ``epochs`` and ``print_freq``.

    Raises:
        AssertionError: If CUDA is not available.
    """
    assert torch.cuda.is_available(), 'CUDA is not available.'
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.set_num_threads(xargs.workers)
    prepare_seed(xargs.rand_seed)
    # Use the namespace we were handed, not a module-level ``args`` global.
    logger = prepare_logger(xargs)

    train_data, valid_data, xshape, class_num = get_datasets(
        xargs.dataset, xargs.data_path, -1)
    config = load_config(xargs.config_path, {
        'class_num': class_num,
        'xshape': xshape
    }, logger)
    search_loader, _, valid_loader = get_nas_search_loaders(
        train_data, valid_data, xargs.dataset, 'configs/nas-benchmark/',
        (config.batch_size, config.test_batch_size), xargs.workers)
    logger.log(
        '||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'
        .format(xargs.dataset, len(search_loader), len(valid_loader),
                config.batch_size))
    logger.log('||||||| {:10s} ||||||| Config={:}'.format(
        xargs.dataset, config))

    search_space = get_search_spaces('cell', xargs.search_space_name)
    model_config = dict2config(
        {
            'name': 'SPOS',
            'C': xargs.channel,
            'N': xargs.num_cells,
            'max_nodes': xargs.max_nodes,
            'num_classes': class_num,
            'space': search_space,
            'affine': False,
            'track_running_stats': bool(xargs.track_running_stats)
        }, None)
    logger.log('search space : {:}'.format(search_space))
    search_model = get_cell_based_tiny_net(model_config)

    # Separate optimizers for the network weights and the architecture
    # parameters (alphas).
    w_optimizer, w_scheduler, criterion = get_optim_scheduler(
        search_model.get_weights(), config)
    a_optimizer = torch.optim.Adam(search_model.get_alphas(),
                                   lr=xargs.arch_learning_rate,
                                   betas=(0.5, 0.999),
                                   weight_decay=xargs.arch_weight_decay)
    logger.log('w-optimizer : {:}'.format(w_optimizer))
    logger.log('a-optimizer : {:}'.format(a_optimizer))
    logger.log('w-scheduler : {:}'.format(w_scheduler))
    logger.log('criterion : {:}'.format(criterion))
    flop, param = get_model_infos(search_model, xshape)
    logger.log('FLOP = {:.2f} M, Params = {:.2f} MB'.format(flop, param))
    logger.log('search-space : {:}'.format(search_space))
    if xargs.arch_nas_dataset is None:
        api = None
    else:
        api = API(xargs.arch_nas_dataset)
    logger.log('{:} create API = {:} done'.format(time_string(), api))

    model_base_path = logger.path('model')
    network, criterion = torch.nn.DataParallel(
        search_model).cuda(), criterion.cuda()

    # start training
    start_time, search_time, epoch_time = (
        time.time(), AverageMeter(), AverageMeter())
    total_epoch = config.epochs + config.warmup
    # NOTE(review): the loop is bounded by ``xargs.epochs`` while progress
    # strings and the final summary use ``total_epoch`` (config.epochs +
    # config.warmup).  Left as-is; confirm the two are meant to differ.
    for epoch in range(xargs.epochs):
        w_scheduler.update(epoch, 0.0)
        need_time = 'Time Left: {:}'.format(
            convert_secs2time(epoch_time.val * (total_epoch - epoch), True))
        epoch_str = '{:03d}-{:03d}'.format(epoch, total_epoch)
        logger.log('\n[Search the {:}-th epoch] {:}, LR={:}'.format(
            epoch_str, need_time, min(w_scheduler.get_lr())))

        (search_w_loss, search_w_top1, search_w_top5, search_a_loss,
         search_a_top1, search_a_top5) = search_func(
             search_loader, network, criterion, w_scheduler, w_optimizer,
             a_optimizer, epoch_str, xargs.print_freq, logger)
        search_time.update(time.time() - start_time)
        logger.log(
            '[{:}] search [base] : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s'
            .format(epoch_str, search_w_loss, search_w_top1, search_w_top5,
                    search_time.sum))
        logger.log(
            '[{:}] search [arch] : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'
            .format(epoch_str, search_a_loss, search_a_top1, search_a_top5))

        # save checkpoint: one file per epoch, keyed by the epoch index
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'args': deepcopy(xargs),
                'search_model': search_model.state_dict(),
                'w_optimizer': w_optimizer.state_dict(),
                'a_optimizer': a_optimizer.state_dict(),
                'w_scheduler': w_scheduler.state_dict()
            }, "{}_epoch-{}.pth".format(model_base_path, epoch), logger)
        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()

    # the final post procedure : count the time
    # The architecture-selection step that used to sit between these two
    # lines is disabled, so only a near-zero interval is recorded.
    start_time = time.time()
    search_time.update(time.time() - start_time)

    logger.log('\n' + '-' * 100)
    logger.log('SPOS : run {:} epochs, cost {:.1f} s.'.format(
        total_epoch, search_time.sum))
    logger.close()
def main(xargs):
    """Log Kendall's tau between a stored metric and benchmark accuracy.

    Loads ``<xargs.save_dir>/result.dat`` (a mapping from architecture key
    to metric value), looks up each architecture's accuracy through
    ``get_arch_real_acc``, ranks the architectures by both quantities in
    descending order and logs the Kendall tau of the two rankings.

    Args:
        xargs: Parsed command-line namespace.  Attributes read here:
            ``workers``, ``rand_seed``, ``dataset``, ``data_path``,
            ``config_path``, ``search_space_name``, ``channel``,
            ``num_cells``, ``max_nodes``, ``track_running_stats``,
            ``arch_nas_dataset`` and ``save_dir``.

    Raises:
        AssertionError: If CUDA is not available.
        SystemExit: With status 1 if an architecture has no accuracy.
    """
    # ``scipy.stats.stats`` is a deprecated private namespace; import the
    # public function directly.
    from scipy.stats import kendalltau

    assert torch.cuda.is_available(), 'CUDA is not available.'
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.set_num_threads(xargs.workers)
    prepare_seed(xargs.rand_seed)
    # Use the namespace we were handed, not a module-level ``args`` global.
    logger = prepare_logger(xargs)

    # NOTE(review): the loaders and the model built below are not used by
    # the ranking itself.  The calls are kept in case they have side
    # effects -- confirm whether they can be removed.
    train_data, valid_data, xshape, class_num = get_datasets(
        xargs.dataset, xargs.data_path, -1)
    config = load_config(xargs.config_path, {
        'class_num': class_num,
        'xshape': xshape
    }, logger)
    get_nas_search_loaders(
        train_data, valid_data, xargs.dataset, 'configs/nas-benchmark/',
        (config.batch_size, config.test_batch_size), xargs.workers)

    search_space = get_search_spaces('cell', xargs.search_space_name)
    model_config = dict2config(
        {
            'name': 'SPOS',
            'C': xargs.channel,
            'N': xargs.num_cells,
            'max_nodes': xargs.max_nodes,
            'num_classes': class_num,
            'space': search_space,
            'affine': False,
            'track_running_stats': bool(xargs.track_running_stats)
        }, None)
    model = get_cell_based_tiny_net(model_config)
    get_model_infos(model, xshape)

    if xargs.arch_nas_dataset is None:
        api = None
    else:
        api = API(xargs.arch_nas_dataset)

    file_name = '{}/result.dat'.format(xargs.save_dir)
    result = torch.load(file_name)
    logger.log('load file from {}'.format(file_name))
    logger.log("length of result={}".format(len(result)))

    real_acc = {}
    for key in result:
        real_acc[key] = get_arch_real_acc(api, key, xargs.dataset)
        if real_acc[key] is None:
            # Abort with a failing status instead of reporting success.
            logger.log('no accuracy found for architecture {}'.format(key))
            sys.exit(1)

    # Both orderings are "best first"; rank 1 is the highest value.
    real_acc = sorted(real_acc.items(), key=lambda d: d[1], reverse=True)
    result = sorted(result.items(), key=lambda d: d[1], reverse=True)
    metric_rank = {arch: rank for rank, (arch, _) in enumerate(result, 1)}

    # Walk the architectures in true-accuracy order and collect the rank
    # the metric assigned to each of them.
    real_rank_list = list(range(1, len(real_acc) + 1))
    rank_list = [metric_rank[arch] for arch, _ in real_acc]

    logger.log("ktau = {}".format(kendalltau(real_rank_list, rank_list)[0]))
def main(xargs):
    """Run the weight-sharing search over the "nas-bench-301" search space.

    The function seeds the run, builds the data loaders and the "generic"
    search-shape super-network, then alternates ``search_func`` and
    ``valid_func`` for ``config.epochs + config.warmup`` epochs.  After
    each epoch the current genotype and validation accuracy are recorded
    and a checkpoint is written, so an interrupted run resumes from the
    last saved epoch.

    Args:
        xargs: Parsed command-line namespace.  Attributes read here:
            ``workers``, ``rand_seed``, ``dataset``, ``data_path``,
            ``overwite_epochs``, ``config_path``, ``search_space``,
            ``genotype``, ``affine``, ``track_running_stats``, ``algo``,
            ``arch_learning_rate``, ``arch_weight_decay``, ``arch_eps``,
            ``use_api``, ``tau_max``, ``tau_min`` and ``print_freq``.

    Raises:
        AssertionError: If CUDA is not available.
    """
    assert torch.cuda.is_available(), 'CUDA is not available.'
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.set_num_threads(xargs.workers)
    prepare_seed(xargs.rand_seed)
    # Use the namespace we were handed, not a module-level ``args`` global,
    # so that main() also works when called from another module.
    logger = prepare_logger(xargs)

    train_data, valid_data, xshape, class_num = get_datasets(
        xargs.dataset, xargs.data_path, -1)
    extra_info = {'class_num': class_num, 'xshape': xshape}
    if xargs.overwite_epochs is not None:
        # Optional command-line override of the epoch count in the config.
        extra_info['epochs'] = xargs.overwite_epochs
    config = load_config(xargs.config_path, extra_info, logger)
    search_loader, _, valid_loader = get_nas_search_loaders(
        train_data, valid_data, xargs.dataset, 'configs/nas-benchmark/',
        (config.batch_size, config.test_batch_size), xargs.workers)
    logger.log(
        '||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'
        .format(xargs.dataset, len(search_loader), len(valid_loader),
                config.batch_size))
    logger.log('||||||| {:10s} ||||||| Config={:}'.format(
        xargs.dataset, config))

    search_space = get_search_spaces(xargs.search_space, 'nas-bench-301')
    model_config = dict2config(
        dict(name='generic',
             super_type='search-shape',
             candidate_Cs=search_space['candidates'],
             max_num_Cs=search_space['numbers'],
             num_classes=class_num,
             genotype=xargs.genotype,
             affine=bool(xargs.affine),
             track_running_stats=bool(xargs.track_running_stats)), None)
    logger.log('search space : {:}'.format(search_space))
    logger.log('model config : {:}'.format(model_config))
    search_model = get_cell_based_tiny_net(model_config)
    search_model.set_algo(xargs.algo)
    logger.log('{:}'.format(search_model))

    # Separate optimizers for the network weights and the architecture
    # parameters (alphas).
    w_optimizer, w_scheduler, criterion = get_optim_scheduler(
        search_model.weights, config)
    a_optimizer = torch.optim.Adam(search_model.alphas,
                                   lr=xargs.arch_learning_rate,
                                   betas=(0.5, 0.999),
                                   weight_decay=xargs.arch_weight_decay,
                                   eps=xargs.arch_eps)
    logger.log('w-optimizer : {:}'.format(w_optimizer))
    logger.log('a-optimizer : {:}'.format(a_optimizer))
    logger.log('w-scheduler : {:}'.format(w_scheduler))
    logger.log('criterion : {:}'.format(criterion))
    params = count_parameters_in_MB(search_model)
    logger.log('The parameters of the search model = {:.2f} MB'.format(params))
    logger.log('search-space : {:}'.format(search_space))
    if bool(xargs.use_api):
        api = create(None, 'size', fast_mode=True, verbose=False)
    else:
        api = None
    logger.log('{:} create API = {:} done'.format(time_string(), api))

    info_path, model_base_path = logger.path('info'), logger.path('model')
    # use a single GPU
    network, criterion = search_model.cuda(), criterion.cuda()

    if info_path.exists():  # automatically resume from previous checkpoint
        logger.log("=> loading checkpoint of the last-info '{:}' start".format(
            info_path))
        last_info = torch.load(info_path)
        start_epoch = last_info['epoch']
        checkpoint = torch.load(last_info['last_checkpoint'])
        genotypes = checkpoint['genotypes']
        valid_accuracies = checkpoint['valid_accuracies']
        search_model.load_state_dict(checkpoint['search_model'])
        w_scheduler.load_state_dict(checkpoint['w_scheduler'])
        w_optimizer.load_state_dict(checkpoint['w_optimizer'])
        a_optimizer.load_state_dict(checkpoint['a_optimizer'])
        # Report the path, not the loaded dictionary, in the message.
        logger.log(
            "=> loading checkpoint of the last-info '{:}' start with {:}-th epoch."
            .format(info_path, start_epoch))
    else:
        logger.log("=> do not find the last-info file : {:}".format(info_path))
        start_epoch, valid_accuracies = 0, {'best': -1}
        # NOTE(review): presumably ``network.random`` is a randomly drawn
        # architecture used as the pre-training entry -- confirm.
        genotypes = {-1: network.random}

    # start training
    start_time, search_time, epoch_time = (
        time.time(), AverageMeter(), AverageMeter())
    total_epoch = config.epochs + config.warmup
    for epoch in range(start_epoch, total_epoch):
        w_scheduler.update(epoch, 0.0)
        need_time = 'Time Left: {:}'.format(
            convert_secs2time(epoch_time.val * (total_epoch - epoch), True))
        epoch_str = '{:03d}-{:03d}'.format(epoch, total_epoch)
        logger.log('\n[Search the {:}-th epoch] {:}, LR={:}'.format(
            epoch_str, need_time, min(w_scheduler.get_lr())))

        if xargs.algo == 'fbv2' or xargs.algo == 'tas':
            # Linear anneal from tau_max to tau_min.  The denominator is
            # clamped so that a one-epoch run does not divide by zero.
            network.set_tau(xargs.tau_max -
                            (xargs.tau_max - xargs.tau_min) * epoch /
                            max(total_epoch - 1, 1))
            logger.log('[RESET tau as : {:}]'.format(network.tau))

        (search_w_loss, search_w_top1, search_w_top5, search_a_loss,
         search_a_top1, search_a_top5) = search_func(
             search_loader, network, criterion, w_scheduler, w_optimizer,
             a_optimizer, xargs.algo, epoch_str, xargs.print_freq, logger)
        search_time.update(time.time() - start_time)
        logger.log(
            '[{:}] search [base] : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s'
            .format(epoch_str, search_w_loss, search_w_top1, search_w_top5,
                    search_time.sum))
        logger.log(
            '[{:}] search [arch] : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'
            .format(epoch_str, search_a_loss, search_a_top1, search_a_top5))

        genotype = network.genotype
        logger.log('[{:}] - [get_best_arch] : {:}'.format(epoch_str, genotype))
        valid_a_loss, valid_a_top1, valid_a_top5 = valid_func(
            valid_loader, network, criterion, logger)
        logger.log(
            '[{:}] evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}% | {:}'
            .format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5,
                    genotype))
        valid_accuracies[epoch] = valid_a_top1
        genotypes[epoch] = genotype
        logger.log('<<<--->>> The {:}-th epoch : {:}'.format(
            epoch_str, genotypes[epoch]))

        # save checkpoint: first the full training state, then a small
        # "info" record pointing at it (this is what the resume path reads)
        save_path = save_checkpoint(
            {
                'epoch': epoch + 1,
                'args': deepcopy(xargs),
                'search_model': search_model.state_dict(),
                'w_optimizer': w_optimizer.state_dict(),
                'a_optimizer': a_optimizer.state_dict(),
                'w_scheduler': w_scheduler.state_dict(),
                'genotypes': genotypes,
                'valid_accuracies': valid_accuracies
            }, model_base_path, logger)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'args': deepcopy(xargs),
                'last_checkpoint': save_path,
            }, info_path, logger)

        with torch.no_grad():
            logger.log('{:}'.format(search_model.show_alphas()))
        if api is not None:
            logger.log('{:}'.format(api.query_by_arch(genotypes[epoch], '90')))
        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()

    # the final post procedure : count the time
    start_time = time.time()
    genotype = network.genotype
    search_time.update(time.time() - start_time)

    valid_a_loss, valid_a_top1, valid_a_top5 = valid_func(
        valid_loader, network, criterion, logger)
    logger.log(
        'Last : the gentotype is : {:}, with the validation accuracy of {:.3f}%.'
        .format(genotype, valid_a_top1))

    logger.log('\n' + '-' * 100)
    # check the performance from the architecture dataset
    logger.log('[{:}] run {:} epochs, cost {:.1f} s, last-geno is {:}.'.format(
        xargs.algo, total_epoch, search_time.sum, genotype))
    if api is not None:
        logger.log('{:}'.format(api.query_by_arch(genotype, '90')))
    logger.close()
def main(xargs):
    """Log Kendall's tau between architecture "angle" and accuracy rank.

    For every 200th architecture index below 10000 the function loads the
    epoch-0 and the final-epoch stand-alone checkpoints into the same
    model, snapshots both with ``deepcopy`` and passes them to
    ``get_arch_angle``.  The architectures are then ranked by angle and by
    the accuracy returned from ``get_arch_real_acc`` (both descending),
    and the Kendall tau of the two rankings is logged.

    Args:
        xargs: Parsed command-line namespace.  Attributes read here:
            ``workers``, ``rand_seed``, ``dataset``, ``data_path``,
            ``config_path``, ``search_space_name``, ``channel``,
            ``num_cells``, ``max_nodes``, ``track_running_stats`` and
            ``arch_nas_dataset``.  The whole namespace is also forwarded
            to ``get_arch_real_acc``.

    Raises:
        AssertionError: If CUDA is not available, or if an architecture
            has no accuracy.
    """
    # ``scipy.stats.stats`` is a deprecated private namespace; import the
    # public function directly.
    from scipy.stats import kendalltau

    assert torch.cuda.is_available(), 'CUDA is not available.'
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.set_num_threads(xargs.workers)
    prepare_seed(xargs.rand_seed)
    # Use the namespace we were handed, not a module-level ``args`` global.
    logger = prepare_logger(xargs)

    train_data, valid_data, xshape, class_num = get_datasets(
        xargs.dataset, xargs.data_path, -1)
    config = load_config(xargs.config_path, {
        'class_num': class_num,
        'xshape': xshape
    }, logger)
    search_loader, _, valid_loader = get_nas_search_loaders(
        train_data, valid_data, xargs.dataset, 'configs/nas-benchmark/',
        (config.batch_size, config.test_batch_size), xargs.workers)
    logger.log(
        '||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'
        .format(xargs.dataset, len(search_loader), len(valid_loader),
                config.batch_size))
    logger.log('||||||| {:10s} ||||||| Config={:}'.format(
        xargs.dataset, config))

    search_space = get_search_spaces('cell', xargs.search_space_name)
    model_config = dict2config(
        {
            'name': 'SPOS',
            'C': xargs.channel,
            'N': xargs.num_cells,
            'max_nodes': xargs.max_nodes,
            'num_classes': class_num,
            'space': search_space,
            'affine': False,
            'track_running_stats': bool(xargs.track_running_stats)
        }, None)
    logger.log('search space : {:}'.format(search_space))
    model = get_cell_based_tiny_net(model_config)
    flop, param = get_model_infos(model, xshape)
    logger.log('FLOP = {:.2f} M, Params = {:.2f} MB'.format(flop, param))
    logger.log('search-space : {:}'.format(search_space))
    if xargs.arch_nas_dataset is None:
        api = None
    else:
        api = API(xargs.arch_nas_dataset)
    logger.log('{:} create API = {:} done'.format(time_string(), api))

    # Loop invariants: checkpoint naming scheme and index of the last epoch.
    checkpoint_path_template = 'output/search-cell-nas-bench-102/result-{}/standalone_arch-{}/checkpoint/seed-{}_epoch-{}.pth'
    last_epoch = config.epochs - 1

    angles = {}
    for arch_idx in range(0, 10000, 200):
        init_path = checkpoint_path_template.format(
            xargs.dataset, arch_idx, xargs.rand_seed, 0)
        final_path = checkpoint_path_template.format(
            xargs.dataset, arch_idx, xargs.rand_seed, last_epoch)

        # ``load`` fills ``model`` in place, so each state is snapshotted
        # with deepcopy before the next checkpoint overwrites it.
        logger.log("=> loading checkpoint from {}".format(init_path))
        load(init_path, model)
        init_model = deepcopy(model)

        logger.log("=> loading checkpoint from {}".format(final_path))
        genotype = load(final_path, model)
        cur_model = deepcopy(model)

        angle = get_arch_angle(init_model, cur_model, genotype, search_space)
        logger.log('[{:}] cal angle : angle={} | {:}, acc: {}'.format(
            arch_idx, angle, genotype,
            get_arch_real_acc(api, genotype, xargs)))
        angles[genotype.tostr()] = angle

    real_acc = {}
    for key in angles:
        # NOTE(review): the lookup here uses the string key, whereas the
        # log line above passes the genotype object -- confirm that
        # get_arch_real_acc accepts both.
        real_acc[key] = get_arch_real_acc(api, key, xargs)
        if real_acc[key] is None:
            # Raised explicitly so the check also runs under ``python -O``.
            raise AssertionError(
                'no accuracy found for architecture {}'.format(key))

    # Both orderings are "best first"; rank 1 is the highest value.
    real_acc = sorted(real_acc.items(), key=lambda d: d[1], reverse=True)
    angles = sorted(angles.items(), key=lambda d: d[1], reverse=True)
    angle_rank = {arch: rank for rank, (arch, _) in enumerate(angles, 1)}

    # Walk the architectures in true-accuracy order and collect the rank
    # the angle metric assigned to each of them.
    real_rank_list = list(range(1, len(real_acc) + 1))
    angle_rank_list = [angle_rank[arch] for arch, _ in real_acc]

    logger.log('Real_rank_list={}'.format(real_rank_list))
    logger.log('Angle_rank_list={}'.format(angle_rank_list))
    logger.log('Tau={}'.format(
        kendalltau(real_rank_list, angle_rank_list)[0]))
def main(xargs):
    """Run the DARTS-V2 search loop on a sub search space with checkpointing.

    Builds the data loaders and the search network, resumes from the logger's
    last-info checkpoint when that file exists, then alternates
    ``search_func`` and ``valid_func`` for ``config.epochs + config.warmup``
    epochs. After every epoch the derived genotype, both optimizers, the
    scheduler and the accuracy history are saved.

    Args:
        xargs: Parsed command-line namespace. Attributes read here:
            ``workers``, ``rand_seed``, ``dataset``, ``data_path``,
            ``config_path``, ``search_space_name``, ``channel``,
            ``num_cells``, ``max_nodes``, ``track_running_stats``,
            ``arch_learning_rate``, ``arch_weight_decay``,
            ``arch_nas_dataset`` and ``print_freq``.

    Raises:
        AssertionError: If CUDA is not available.
    """
    assert torch.cuda.is_available(), 'CUDA is not available.'
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.set_num_threads(xargs.workers)
    prepare_seed(xargs.rand_seed)
    # Use the function argument, not a module-level ``args``.
    logger = prepare_logger(xargs)

    train_data, valid_data, xshape, class_num = get_datasets(
        xargs.dataset, xargs.data_path, -1)
    config = load_config(xargs.config_path,
                         {'class_num': class_num, 'xshape': xshape}, logger)
    search_loader, _, valid_loader = get_nas_search_loaders(
        train_data, valid_data, xargs.dataset, 'configs/nas-benchmark/',
        config.batch_size, xargs.workers)
    logger.log(
        '||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'
        .format(xargs.dataset, len(search_loader), len(valid_loader),
                config.batch_size))
    logger.log('||||||| {:10s} ||||||| Config={:}'.format(
        xargs.dataset, config))

    # This variant searches a sub search space, not the full cell space.
    search_space = get_sub_search_spaces('cell', xargs.search_space_name)
    logger.log('search_space={}'.format(search_space))
    model_config = dict2config(
        {
            'name': 'DARTS-V2',
            'C': xargs.channel,
            'N': xargs.num_cells,
            'max_nodes': xargs.max_nodes,
            'num_classes': class_num,
            'space': search_space,
            'affine': False,
            'track_running_stats': bool(xargs.track_running_stats)
        }, None)
    search_model = get_cell_based_tiny_net(model_config)
    logger.log('search-model :\n{:}'.format(search_model))

    # Separate optimizers for the network weights and the arch parameters.
    w_optimizer, w_scheduler, criterion = get_optim_scheduler(
        search_model.get_weights(), config)
    a_optimizer = torch.optim.Adam(search_model.get_alphas(),
                                   lr=xargs.arch_learning_rate,
                                   betas=(0.5, 0.999),
                                   weight_decay=xargs.arch_weight_decay)
    logger.log('w-optimizer : {:}'.format(w_optimizer))
    logger.log('a-optimizer : {:}'.format(a_optimizer))
    logger.log('w-scheduler : {:}'.format(w_scheduler))
    logger.log('criterion : {:}'.format(criterion))
    flop, param = get_model_infos(search_model, xshape)
    logger.log('FLOP = {:.2f} M, Params = {:.2f} MB'.format(flop, param))

    if xargs.arch_nas_dataset is None:
        api = None
    else:
        api = API(xargs.arch_nas_dataset)
    logger.log('{:} create API = {:} done'.format(time_string(), api))

    # NOTE: model_best_path is currently unused because copying the best
    # checkpoint is disabled in this variant.
    last_info, model_base_path, model_best_path = logger.path(
        'info'), logger.path('model'), logger.path('best')
    network, criterion = torch.nn.DataParallel(
        search_model).cuda(), criterion.cuda()

    if last_info.exists():  # automatically resume from previous checkpoint
        logger.log("=> loading checkpoint of the last-info '{:}' start".format(
            last_info))
        last_info = torch.load(last_info)
        start_epoch = last_info['epoch']
        checkpoint = torch.load(last_info['last_checkpoint'])
        genotypes = checkpoint['genotypes']
        valid_accuracies = checkpoint['valid_accuracies']
        search_model.load_state_dict(checkpoint['search_model'])
        w_scheduler.load_state_dict(checkpoint['w_scheduler'])
        w_optimizer.load_state_dict(checkpoint['w_optimizer'])
        a_optimizer.load_state_dict(checkpoint['a_optimizer'])
        logger.log(
            "=> loading checkpoint of the last-info '{:}' start with {:}-th epoch."
            .format(last_info, start_epoch))
    else:
        logger.log("=> do not find the last-info file : {:}".format(last_info))
        start_epoch, valid_accuracies, genotypes = 0, {
            'best': -1
        }, {
            -1: search_model.genotype()
        }

    # start training
    start_time, search_time, epoch_time, total_epoch = time.time(
    ), AverageMeter(), AverageMeter(), config.epochs + config.warmup
    for epoch in range(start_epoch, total_epoch):
        w_scheduler.update(epoch, 0.0)
        need_time = 'Time Left: {:}'.format(
            convert_secs2time(epoch_time.val * (total_epoch - epoch), True))
        epoch_str = '{:03d}-{:03d}'.format(epoch, total_epoch)
        min_LR = min(w_scheduler.get_lr())
        logger.log('\n[Search the {:}-th epoch] {:}, LR={:}'.format(
            epoch_str, need_time, min_LR))

        search_w_loss, search_w_top1, search_w_top5 = search_func(
            search_loader, network, criterion, w_scheduler, w_optimizer,
            a_optimizer, epoch_str, xargs.print_freq, logger)
        search_time.update(time.time() - start_time)
        logger.log(
            '[{:}] searching : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s'
            .format(epoch_str, search_w_loss, search_w_top1, search_w_top5,
                    search_time.sum))
        valid_a_loss, valid_a_top1, valid_a_top5 = valid_func(
            valid_loader, network, criterion)
        logger.log(
            '[{:}] evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'
            .format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5))

        # check the best accuracy
        valid_accuracies[epoch] = valid_a_top1
        if valid_a_top1 > valid_accuracies['best']:
            valid_accuracies['best'] = valid_a_top1
            genotypes['best'] = search_model.genotype()

        genotypes[epoch] = search_model.genotype()
        logger.log('<<<--->>> The {:}-th epoch : {:}'.format(
            epoch_str, genotypes[epoch]))

        # save checkpoint
        save_path = save_checkpoint(
            {
                'epoch': epoch + 1,
                'args': deepcopy(xargs),
                'search_model': search_model.state_dict(),
                'w_optimizer': w_optimizer.state_dict(),
                'a_optimizer': a_optimizer.state_dict(),
                'w_scheduler': w_scheduler.state_dict(),
                'genotypes': genotypes,
                'valid_accuracies': valid_accuracies
            }, model_base_path, logger)
        # Store the same namespace as the model checkpoint above, rather than
        # a module-level ``args``.
        last_info = save_checkpoint(
            {
                'epoch': epoch + 1,
                'args': deepcopy(xargs),
                'last_checkpoint': save_path,
            }, logger.path('info'), logger)

        with torch.no_grad():
            logger.log('arch-parameters :\n{:}'.format(
                nn.functional.softmax(search_model.arch_parameters,
                                      dim=-1).cpu()))
        if api is not None:
            logger.log('{:}'.format(api.query_by_arch(genotypes[epoch])))

        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()

    logger.log('\n' + '-' * 100)
    # check the performance from the architecture dataset
    logger.log(
        'DARTS-V2 : run {:} epochs, cost {:.1f} s, last-geno is {:}.'.format(
            total_epoch, search_time.sum, genotypes[total_epoch - 1]))
    if api is not None:
        logger.log('{:}'.format(api.query_by_arch(genotypes[total_epoch - 1])))
    logger.close()
def main(xargs, myargs):
    """Run the SETN search loop with per-epoch checkpointing.

    The search network is trained with ``search_func``. After each epoch
    ``get_best_arch`` picks a genotype, the network is switched to the
    ``'dynamic'`` mode for that genotype, and ``valid_func`` evaluates it.
    State is saved every epoch and restored from the logger's last-info file
    when it exists. After the loop the selection and evaluation are repeated
    once and logged.

    Args:
        xargs: Parsed command-line namespace. Attributes read here:
            ``workers``, ``num_worker``, ``rand_seed``, ``dataset``,
            ``data_path``, ``config_path``, ``search_space_name``,
            ``channel``, ``num_cells``, ``max_nodes``,
            ``track_running_stats``, ``arch_learning_rate``,
            ``arch_weight_decay``, ``arch_nas_dataset``, ``select_num``,
            ``print_freq`` and, optionally, ``model_config``.
        myargs: Not used inside this function.

    Raises:
        AssertionError: If CUDA is not available.
    """
    assert torch.cuda.is_available(), 'CUDA is not available.'
    cudnn = torch.backends.cudnn
    cudnn.enabled = True
    cudnn.benchmark = False
    cudnn.deterministic = True
    torch.set_num_threads(xargs.workers)
    prepare_seed(xargs.rand_seed)
    logger = prepare_logger(xargs)

    # ---- data -------------------------------------------------------------
    train_data, valid_data, xshape, class_num = get_datasets(
        xargs.dataset, xargs.data_path, -1)
    extra_cfg = {'class_num': class_num, 'xshape': xshape}
    config = load_config(xargs.config_path, extra_cfg, logger)
    batch_sizes = (config.batch_size, config.test_batch_size)
    # NOTE(review): loaders read ``num_worker`` while the thread count above
    # reads ``workers`` -- confirm both attributes exist on xargs.
    search_loader, _, valid_loader = get_nas_search_loaders(
        train_data, valid_data, xargs.dataset,
        'AutoDL-Projects/configs/nas-benchmark/', batch_sizes,
        xargs.num_worker)
    loader_msg = '||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'
    logger.log(loader_msg.format(xargs.dataset, len(search_loader),
                                 len(valid_loader), config.batch_size))
    logger.log('||||||| {:10s} ||||||| Config={:}'.format(
        xargs.dataset, config))

    # ---- model ------------------------------------------------------------
    search_space = get_search_spaces('cell', xargs.search_space_name)
    keep_stats = bool(xargs.track_running_stats)
    if getattr(xargs, 'model_config', None) is None:
        # No external model description: assemble the default SETN config.
        model_config = dict2config(
            dict(name='SETN',
                 C=xargs.channel,
                 N=xargs.num_cells,
                 max_nodes=xargs.max_nodes,
                 num_classes=class_num,
                 space=search_space,
                 affine=False,
                 track_running_stats=keep_stats), None)
    else:
        model_config = load_config(
            xargs.model_config,
            dict(num_classes=class_num,
                 space=search_space,
                 affine=False,
                 track_running_stats=keep_stats), None)
    logger.log('search space : {:}'.format(search_space))
    search_model = get_cell_based_tiny_net(model_config)

    # ---- optimizers -------------------------------------------------------
    w_optimizer, w_scheduler, criterion = get_optim_scheduler(
        search_model.get_weights(), config)
    a_optimizer = torch.optim.Adam(
        search_model.get_alphas(),
        lr=xargs.arch_learning_rate,
        betas=(0.5, 0.999),
        weight_decay=xargs.arch_weight_decay)
    logger.log('w-optimizer : {:}'.format(w_optimizer))
    logger.log('a-optimizer : {:}'.format(a_optimizer))
    logger.log('w-scheduler : {:}'.format(w_scheduler))
    logger.log('criterion : {:}'.format(criterion))
    flop, param = get_model_infos(search_model, xshape)
    logger.log('FLOP = {:.2f} M, Params = {:.2f} MB'.format(flop, param))
    logger.log('search-space : {:}'.format(search_space))

    api = None
    if xargs.arch_nas_dataset is not None:
        api = API(xargs.arch_nas_dataset)
    logger.log('{:} create API = {:} done'.format(time_string(), api))

    info_path = logger.path('info')
    model_base_path = logger.path('model')
    model_best_path = logger.path('best')
    network = torch.nn.DataParallel(search_model).cuda()
    criterion = criterion.cuda()

    # ---- resume or start from scratch -------------------------------------
    if not info_path.exists():
        logger.log("=> do not find the last-info file : {:}".format(info_path))
        init_genotype, _ = get_best_arch(valid_loader, network,
                                         xargs.select_num)
        start_epoch = 0
        valid_accuracies = {'best': -1}
        genotypes = {-1: init_genotype}
    else:
        logger.log("=> loading checkpoint of the last-info '{:}' start".format(
            info_path))
        resume_info = torch.load(info_path)
        start_epoch = resume_info['epoch']
        checkpoint = torch.load(resume_info['last_checkpoint'])
        genotypes = checkpoint['genotypes']
        valid_accuracies = checkpoint['valid_accuracies']
        search_model.load_state_dict(checkpoint['search_model'])
        w_scheduler.load_state_dict(checkpoint['w_scheduler'])
        w_optimizer.load_state_dict(checkpoint['w_optimizer'])
        a_optimizer.load_state_dict(checkpoint['a_optimizer'])
        logger.log(
            "=> loading checkpoint of the last-info '{:}' start with {:}-th epoch."
            .format(resume_info, start_epoch))

    # ---- search loop ------------------------------------------------------
    start_time = time.time()
    search_time = AverageMeter()
    epoch_time = AverageMeter()
    total_epoch = config.epochs + config.warmup
    for epoch in range(start_epoch, total_epoch):
        w_scheduler.update(epoch, 0.0)
        remaining = convert_secs2time(
            epoch_time.val * (total_epoch - epoch), True)
        need_time = 'Time Left: {:}'.format(remaining)
        epoch_str = '{:03d}-{:03d}'.format(epoch, total_epoch)
        logger.log('\n[Search the {:}-th epoch] {:}, LR={:}'.format(
            epoch_str, need_time, min(w_scheduler.get_lr())))

        (w_loss, w_top1, w_top5,
         a_loss, a_top1, a_top5) = search_func(
             search_loader, network, criterion, w_scheduler, w_optimizer,
             a_optimizer, epoch_str, xargs.print_freq, logger)
        search_time.update(time.time() - start_time)
        logger.log(
            '[{:}] search [base] : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s'
            .format(epoch_str, w_loss, w_top1, w_top5, search_time.sum))
        logger.log(
            '[{:}] search [arch] : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'
            .format(epoch_str, a_loss, a_top1, a_top5))

        # Pick the current genotype and evaluate the network in that mode.
        genotype, _ = get_best_arch(valid_loader, network, xargs.select_num)
        network.module.set_cal_mode('dynamic', genotype)
        valid_a_loss, valid_a_top1, valid_a_top5 = valid_func(
            valid_loader, network, criterion)
        logger.log(
            '[{:}] evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}% | {:}'
            .format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5,
                    genotype))

        # Record this epoch's result.
        valid_accuracies[epoch] = valid_a_top1
        genotypes[epoch] = genotype
        logger.log('<<<--->>> The {:}-th epoch : {:}'.format(
            epoch_str, genotypes[epoch]))

        # Persist the full training state, then the pointer to it.
        full_state = {
            'epoch': epoch + 1,
            'args': deepcopy(xargs),
            'search_model': search_model.state_dict(),
            'w_optimizer': w_optimizer.state_dict(),
            'a_optimizer': a_optimizer.state_dict(),
            'w_scheduler': w_scheduler.state_dict(),
            'genotypes': genotypes,
            'valid_accuracies': valid_accuracies
        }
        save_path = save_checkpoint(full_state, model_base_path, logger)
        pointer_state = {
            'epoch': epoch + 1,
            'args': deepcopy(xargs),
            'last_checkpoint': save_path,
        }
        save_checkpoint(pointer_state, logger.path('info'), logger)

        with torch.no_grad():
            logger.log('{:}'.format(search_model.show_alphas()))
        if api is not None:
            logger.log('{:}'.format(
                api.query_by_arch(genotypes[epoch], '200')))

        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()

    # ---- final selection (its cost is added to the search time) ------------
    start_time = time.time()
    genotype, _ = get_best_arch(valid_loader, network, xargs.select_num)
    search_time.update(time.time() - start_time)
    network.module.set_cal_mode('dynamic', genotype)
    valid_a_loss, valid_a_top1, valid_a_top5 = valid_func(
        valid_loader, network, criterion)
    logger.log(
        'Last : the gentotype is : {:}, with the validation accuracy of {:.3f}%.'
        .format(genotype, valid_a_top1))

    logger.log('\n' + '-' * 100)
    # check the performance from the architecture dataset
    logger.log(
        'SETN : run {:} epochs, cost {:.1f} s, last-geno is {:}.'.format(
            total_epoch, search_time.sum, genotype))
    if api is not None:
        logger.log('{:}'.format(api.query_by_arch(genotype, '200')))
    logger.close()
def main(xargs):
    """Run the DARTS-V2 search loop with checkpointing and best-model copy.

    Builds the data loaders and the search network, resumes from the logger's
    last-info checkpoint when that file exists, then alternates
    ``search_func`` and ``valid_func`` for ``config.epochs + config.warmup``
    epochs. Each epoch's state is saved; when the validation top-1 accuracy
    improves, the model checkpoint is also copied to the "best" path.

    Args:
        xargs: Parsed command-line namespace. Attributes read here:
            ``workers``, ``rand_seed``, ``dataset``, ``data_path``,
            ``config_path``, ``search_space_name``, ``channel``,
            ``num_cells``, ``max_nodes``, ``track_running_stats``,
            ``arch_learning_rate``, ``arch_weight_decay``,
            ``arch_nas_dataset`` and ``print_freq``.

    Raises:
        AssertionError: If CUDA is not available.
    """
    assert torch.cuda.is_available(), "CUDA is not available."
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.set_num_threads(xargs.workers)
    prepare_seed(xargs.rand_seed)
    # Use the function argument, not a module-level ``args``.
    logger = prepare_logger(xargs)

    train_data, valid_data, xshape, class_num = get_datasets(
        xargs.dataset, xargs.data_path, -1)
    config = load_config(xargs.config_path, {
        "class_num": class_num,
        "xshape": xshape
    }, logger)
    search_loader, _, valid_loader = get_nas_search_loaders(
        train_data,
        valid_data,
        xargs.dataset,
        "configs/nas-benchmark/",
        config.batch_size,
        xargs.workers,
    )
    logger.log(
        "||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}"
        .format(xargs.dataset, len(search_loader), len(valid_loader),
                config.batch_size))
    logger.log("||||||| {:10s} ||||||| Config={:}".format(
        xargs.dataset, config))

    search_space = get_search_spaces("cell", xargs.search_space_name)
    model_config = dict2config(
        {
            "name": "DARTS-V2",
            "C": xargs.channel,
            "N": xargs.num_cells,
            "max_nodes": xargs.max_nodes,
            "num_classes": class_num,
            "space": search_space,
            "affine": False,
            "track_running_stats": bool(xargs.track_running_stats),
        },
        None,
    )
    search_model = get_cell_based_tiny_net(model_config)
    logger.log("search-model :\n{:}".format(search_model))

    # Separate optimizers for the network weights and the arch parameters.
    w_optimizer, w_scheduler, criterion = get_optim_scheduler(
        search_model.get_weights(), config)
    a_optimizer = torch.optim.Adam(
        search_model.get_alphas(),
        lr=xargs.arch_learning_rate,
        betas=(0.5, 0.999),
        weight_decay=xargs.arch_weight_decay,
    )
    logger.log("w-optimizer : {:}".format(w_optimizer))
    logger.log("a-optimizer : {:}".format(a_optimizer))
    logger.log("w-scheduler : {:}".format(w_scheduler))
    logger.log("criterion : {:}".format(criterion))
    flop, param = get_model_infos(search_model, xshape)
    logger.log("FLOP = {:.2f} M, Params = {:.2f} MB".format(flop, param))

    if xargs.arch_nas_dataset is None:
        api = None
    else:
        api = API(xargs.arch_nas_dataset)
    logger.log("{:} create API = {:} done".format(time_string(), api))

    last_info, model_base_path, model_best_path = (
        logger.path("info"),
        logger.path("model"),
        logger.path("best"),
    )
    network, criterion = torch.nn.DataParallel(
        search_model).cuda(), criterion.cuda()

    if last_info.exists():  # automatically resume from previous checkpoint
        logger.log("=> loading checkpoint of the last-info '{:}' start".format(
            last_info))
        last_info = torch.load(last_info)
        start_epoch = last_info["epoch"]
        checkpoint = torch.load(last_info["last_checkpoint"])
        genotypes = checkpoint["genotypes"]
        valid_accuracies = checkpoint["valid_accuracies"]
        search_model.load_state_dict(checkpoint["search_model"])
        w_scheduler.load_state_dict(checkpoint["w_scheduler"])
        w_optimizer.load_state_dict(checkpoint["w_optimizer"])
        a_optimizer.load_state_dict(checkpoint["a_optimizer"])
        logger.log(
            "=> loading checkpoint of the last-info '{:}' start with {:}-th epoch."
            .format(last_info, start_epoch))
    else:
        logger.log("=> do not find the last-info file : {:}".format(last_info))
        start_epoch, valid_accuracies, genotypes = (
            0,
            {
                "best": -1
            },
            {
                -1: search_model.genotype()
            },
        )

    # start training
    start_time, search_time, epoch_time, total_epoch = (
        time.time(),
        AverageMeter(),
        AverageMeter(),
        config.epochs + config.warmup,
    )
    for epoch in range(start_epoch, total_epoch):
        w_scheduler.update(epoch, 0.0)
        need_time = "Time Left: {:}".format(
            convert_secs2time(epoch_time.val * (total_epoch - epoch), True))
        epoch_str = "{:03d}-{:03d}".format(epoch, total_epoch)
        min_LR = min(w_scheduler.get_lr())
        logger.log("\n[Search the {:}-th epoch] {:}, LR={:}".format(
            epoch_str, need_time, min_LR))

        search_w_loss, search_w_top1, search_w_top5 = search_func(
            search_loader,
            network,
            criterion,
            w_scheduler,
            w_optimizer,
            a_optimizer,
            epoch_str,
            xargs.print_freq,
            logger,
        )
        search_time.update(time.time() - start_time)
        logger.log(
            "[{:}] searching : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s"
            .format(epoch_str, search_w_loss, search_w_top1, search_w_top5,
                    search_time.sum))
        valid_a_loss, valid_a_top1, valid_a_top5 = valid_func(
            valid_loader, network, criterion)
        logger.log(
            "[{:}] evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%"
            .format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5))

        # check the best accuracy
        valid_accuracies[epoch] = valid_a_top1
        find_best = valid_a_top1 > valid_accuracies["best"]
        if find_best:
            valid_accuracies["best"] = valid_a_top1
            genotypes["best"] = search_model.genotype()

        genotypes[epoch] = search_model.genotype()
        logger.log("<<<--->>> The {:}-th epoch : {:}".format(
            epoch_str, genotypes[epoch]))

        # save checkpoint
        save_path = save_checkpoint(
            {
                "epoch": epoch + 1,
                "args": deepcopy(xargs),
                "search_model": search_model.state_dict(),
                "w_optimizer": w_optimizer.state_dict(),
                "a_optimizer": a_optimizer.state_dict(),
                "w_scheduler": w_scheduler.state_dict(),
                "genotypes": genotypes,
                "valid_accuracies": valid_accuracies,
            },
            model_base_path,
            logger,
        )
        # Store the same namespace as the model checkpoint above, rather than
        # a module-level ``args``.
        last_info = save_checkpoint(
            {
                "epoch": epoch + 1,
                "args": deepcopy(xargs),
                "last_checkpoint": save_path,
            },
            logger.path("info"),
            logger,
        )
        if find_best:
            logger.log(
                "<<<--->>> The {:}-th epoch : find the highest validation accuracy : {:.2f}%."
                .format(epoch_str, valid_a_top1))
            copy_checkpoint(model_base_path, model_best_path, logger)

        with torch.no_grad():
            logger.log("arch-parameters :\n{:}".format(
                nn.functional.softmax(search_model.arch_parameters,
                                      dim=-1).cpu()))
        if api is not None:
            logger.log("{:}".format(
                api.query_by_arch(genotypes[epoch], "200")))

        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()

    logger.log("\n" + "-" * 100)
    # check the performance from the architecture dataset
    logger.log(
        "DARTS-V2 : run {:} epochs, cost {:.1f} s, last-geno is {:}.".format(
            total_epoch, search_time.sum, genotypes[total_epoch - 1]))
    if api is not None:
        # "200" is an argument of query_by_arch, as in the per-epoch query
        # above. Previously it was passed to str.format, which ignored it.
        logger.log("{:}".format(
            api.query_by_arch(genotypes[total_epoch - 1], "200")))
    logger.close()
def main(xargs):
    """Compute the angle of every architecture of the SPOS network and save them.

    Loads the epoch-0 checkpoint and the checkpoint of epoch ``xargs.epochs``
    into the same network, evaluates ``get_arch_angle`` for each architecture
    returned by ``model.get_all_archs()``, logs it together with the value of
    ``get_arch_real_acc``, and stores the angle mapping (keyed by
    ``genotype.tostr()``) in ``<xargs.save_dir>/result.dat`` via ``torch.save``.

    Args:
        xargs: Parsed command-line namespace. Attributes read here:
            ``workers``, ``rand_seed``, ``dataset``, ``data_path``,
            ``config_path``, ``search_space_name``, ``channel``,
            ``num_cells``, ``max_nodes``, ``track_running_stats``,
            ``arch_nas_dataset``, ``epochs`` and ``save_dir``.

    Raises:
        AssertionError: If CUDA is not available.
    """
    assert torch.cuda.is_available(), 'CUDA is not available.'
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.set_num_threads(xargs.workers)
    prepare_seed(xargs.rand_seed)
    # Use the function argument, not a module-level ``args``.
    logger = prepare_logger(xargs)

    train_data, valid_data, xshape, class_num = get_datasets(
        xargs.dataset, xargs.data_path, -1)
    config = load_config(xargs.config_path, {
        'class_num': class_num,
        'xshape': xshape
    }, logger)
    search_loader, _, valid_loader = get_nas_search_loaders(
        train_data, valid_data, xargs.dataset, 'configs/nas-benchmark/',
        (config.batch_size, config.test_batch_size), xargs.workers)
    logger.log(
        '||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'
        .format(xargs.dataset, len(search_loader), len(valid_loader),
                config.batch_size))
    logger.log('||||||| {:10s} ||||||| Config={:}'.format(
        xargs.dataset, config))

    search_space = get_search_spaces('cell', xargs.search_space_name)
    model_config = dict2config(
        {
            'name': 'SPOS',
            'C': xargs.channel,
            'N': xargs.num_cells,
            'max_nodes': xargs.max_nodes,
            'num_classes': class_num,
            'space': search_space,
            'affine': False,
            'track_running_stats': bool(xargs.track_running_stats)
        }, None)
    logger.log('search space : {:}'.format(search_space))
    model = get_cell_based_tiny_net(model_config)
    flop, param = get_model_infos(model, xshape)
    logger.log('FLOP = {:.2f} M, Params = {:.2f} MB'.format(flop, param))
    logger.log('search-space : {:}'.format(search_space))

    if xargs.arch_nas_dataset is None:
        api = None
    else:
        api = API(xargs.arch_nas_dataset)
    logger.log('{:} create API = {:} done'.format(time_string(), api))

    # NOTE: these three paths are not used further down in this function.
    last_info, model_base_path, model_best_path = logger.path(
        'info'), logger.path('model'), logger.path('best')

    # The template is a constant, so the checkpoints are always loaded; the
    # former ``is not None`` guard could never be false.
    checkpoint_path_template = 'output/search-cell-nas-bench-102/result-{}/checkpoint/seed-{}_epoch-{}.pth'
    init_path = checkpoint_path_template.format(
        xargs.dataset, xargs.rand_seed, 0)
    logger.log("=> loading checkpoint from {}".format(init_path))
    load(init_path, model)
    # ``model`` is reused for the second load, so snapshot it first.
    init_model = deepcopy(model)
    load(
        checkpoint_path_template.format(xargs.dataset, xargs.rand_seed,
                                        xargs.epochs), model)

    # start calculate angles
    # NOTE: search_time and total_epoch are not read below; epoch_time is
    # only updated.
    start_time, search_time, epoch_time, total_epoch = time.time(
    ), AverageMeter(), AverageMeter(), config.epochs + config.warmup
    all_archs = model.get_all_archs()
    arch_angles = {}
    process_start_time = time.time()
    for i, genotype in enumerate(all_archs):
        angle = get_arch_angle(init_model, model, genotype, search_space)
        logger.log('[{:}] cal angle : angle={} | {:}, acc: {}'.format(
            i, angle, genotype, get_arch_real_acc(api, genotype)))
        arch_angles[genotype.tostr()] = angle

        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()
    process_end_time = time.time()
    logger.log('process time: {}'.format(process_end_time -
                                         process_start_time))

    torch.save(arch_angles, '{}/result.dat'.format(xargs.save_dir))
    logger.close()