def main(args):
    """Evaluate a saved CNN checkpoint on its validation data and log the metrics.

    The checkpoint is expected to hold the keys 'args', 'model-config',
    'optim-config' and 'base-model'. Most run settings (dataset, batch size,
    workers, procedure name, ...) are read from the arguments stored inside
    the checkpoint; only the data location and the checkpoint path come from
    ``args``.

    Args:
        args: object with ``data_path`` (an existing directory) and
            ``checkpoint`` (an existing file) attributes.

    Raises:
        AssertionError: if ``args.data_path`` is not a directory or
            ``args.checkpoint`` is not a file.
    """
    assert os.path.isdir(args.data_path), 'invalid data-path : {:}'.format(
        args.data_path)
    assert os.path.isfile(args.checkpoint), 'invalid checkpoint : {:}'.format(
        args.checkpoint)

    checkpoint = torch.load(args.checkpoint)
    xargs = checkpoint['args']
    train_data, valid_data, xshape, class_num = get_datasets(
        xargs.dataset, args.data_path, xargs.cutout_length)
    valid_loader = torch.utils.data.DataLoader(valid_data,
                                               batch_size=xargs.batch_size,
                                               shuffle=False,
                                               num_workers=xargs.workers,
                                               pin_memory=True)

    logger = PrintLogger()
    model_config = dict2config(checkpoint['model-config'], logger)
    base_model = obtain_model(model_config)
    flop, param = get_model_infos(base_model, xshape)
    logger.log('model ====>>>>:\n{:}'.format(base_model))
    logger.log('model information : {:}'.format(base_model.get_message()))
    logger.log('-' * 50)
    logger.log('Params={:.2f} MB, FLOPs={:.2f} M ... = {:.2f} G'.format(
        param, flop, flop / 1e3))
    logger.log('-' * 50)
    logger.log('valid_data : {:}'.format(valid_data))

    # Only the criterion is needed for evaluation; optimizer/scheduler are dropped.
    optim_config = dict2config(checkpoint['optim-config'], logger)
    _, _, criterion = get_optim_scheduler(base_model.parameters(), optim_config)
    logger.log('criterion : {:}'.format(criterion))
    base_model.load_state_dict(checkpoint['base-model'])
    _, valid_func = get_procedures(xargs.procedure)
    logger.log(
        'initialize the CNN done, evaluate it using {:}'.format(valid_func))
    network = torch.nn.DataParallel(base_model).cuda()

    eval_args = (valid_loader, network, criterion, optim_config,
                 'pure-evaluation', xargs.print_freq_eval, logger)
    try:
        valid_loss, valid_acc1, valid_acc5 = valid_func(*eval_args)
    except Exception as err:
        # Best-effort fallback to the 'basic' procedure. ``Exception`` (not a
        # bare ``except``) keeps Ctrl-C / SystemExit from being swallowed.
        logger.log('evaluation with {:} failed ({:}), fall back to the basic procedure'.format(
            valid_func, err))
        _, valid_func = get_procedures('basic')
        valid_loss, valid_acc1, valid_acc5 = valid_func(*eval_args)

    device = next(network.parameters()).device
    # NOTE(review): recent PyTorch documents max_memory_cached as deprecated in
    # favour of max_memory_reserved; kept as-is for older installations.
    num_bytes = torch.cuda.max_memory_cached(device) * 1.0
    logger.log(
        '***{:s}*** EVALUATION loss = {:.6f}, accuracy@1 = {:.2f}, accuracy@5 = {:.2f}, error@1 = {:.2f}, error@5 = {:.2f}'
        .format(time_string(), valid_loss, valid_acc1, valid_acc5,
                100 - valid_acc1, 100 - valid_acc5))
    logger.log(
        '[GPU-Memory-Usage on {:} is {:} bytes, {:.2f} KB, {:.2f} MB, {:.2f} GB.]'
        .format(device, int(num_bytes), num_bytes / 1e3, num_bytes / 1e6,
                num_bytes / 1e9))
    logger.close()
def create_result_count(used_seed: int,
                        dataset: Text,
                        arch_config: Dict[Text, Any],
                        results: Dict[Text, Any],
                        dataloader_dict: Dict[Text, Any]) -> ResultsCount:
    """Build a ``ResultsCount`` for one (seed, dataset) training record.

    When ``results`` contains 'train_times' (the "new version" format) the
    stored training/evaluation statistics are copied over directly. Otherwise
    the network is rebuilt from ``arch_config``, its weights are restored, and
    the missing splits are re-evaluated with ``pure_evaluate`` using the
    loaders in ``dataloader_dict`` (keyed as '<dataset>@<split>').

    Args:
        used_seed: seed forwarded to ``ResultsCount``.
        dataset: one of 'cifar10-valid', 'cifar10', 'cifar100', 'ImageNet16-120'
            (only checked on the old-format path).
        arch_config: dict with 'channel', 'num_cells', 'arch_str', 'class_num'.
        results: the training record for this dataset.
        dataloader_dict: loaders used for re-evaluation on the old-format path.

    Returns:
        The populated ``ResultsCount``.

    Raises:
        ValueError: on the old-format path, if ``dataset`` is not recognised.
    """
    xresult = ResultsCount(
        dataset, results['net_state_dict'], results['train_acc1es'],
        results['train_losses'], results['param'], results['flop'],
        arch_config, used_seed, results['total_epoch'], None)
    net_config = dict2config(
        {
            'name': 'infer.tiny',
            'C': arch_config['channel'],
            'N': arch_config['num_cells'],
            'genotype': CellStructure.str2structure(arch_config['arch_str']),
            'num_classes': arch_config['class_num'],
        }, None)

    if 'train_times' in results:  # new version
        xresult.update_train_info(results['train_acc1es'],
                                  results['train_acc5es'],
                                  results['train_losses'],
                                  results['train_times'])
        xresult.update_eval(results['valid_acc1es'], results['valid_losses'],
                            results['valid_times'])
        return xresult

    # Old-format record: rebuild the network and re-evaluate what is missing.
    network = get_cell_based_tiny_net(net_config)
    network.load_state_dict(xresult.get_net_param())

    def run_eval(loader_dataset, split):
        loader = dataloader_dict['{:}@{:}'.format(loader_dataset, split)]
        return pure_evaluate(loader, network.cuda())

    def at_last_epoch(value):
        # Re-evaluated metrics are stored under the index of the final epoch.
        return {results['total_epoch'] - 1: value}

    if dataset == 'cifar10-valid':
        xresult.update_OLD_eval('x-valid', results['valid_acc1es'],
                                results['valid_losses'])
        loss, top1, top5, latencies = run_eval('cifar10', 'test')
        xresult.update_OLD_eval('ori-test', at_last_epoch(top1),
                                at_last_epoch(loss))
        xresult.update_latency(latencies)
    elif dataset == 'cifar10':
        xresult.update_OLD_eval('ori-test', results['valid_acc1es'],
                                results['valid_losses'])
        loss, top1, top5, latencies = run_eval(dataset, 'test')
        xresult.update_latency(latencies)
    elif dataset in ('cifar100', 'ImageNet16-120'):
        xresult.update_OLD_eval('ori-test', results['valid_acc1es'],
                                results['valid_losses'])
        loss, top1, top5, latencies = run_eval(dataset, 'valid')
        xresult.update_OLD_eval('x-valid', at_last_epoch(top1),
                                at_last_epoch(loss))
        loss, top1, top5, latencies = run_eval(dataset, 'test')
        xresult.update_OLD_eval('x-test', at_last_epoch(top1),
                                at_last_epoch(loss))
        # Latency comes from the last evaluation (the test split).
        xresult.update_latency(latencies)
    else:
        raise ValueError('invalid dataset name : {:}'.format(dataset))
    return xresult
def load_net_from_checkpoint(checkpoint):
    """Rebuild a model from a checkpoint file and restore its weights.

    Args:
        checkpoint: path to a file loadable by ``torch.load`` that contains
            the keys 'model-config' and 'base-model'.

    Returns:
        The model created by ``obtain_model`` with 'base-model' loaded into it.

    Raises:
        AssertionError: if ``checkpoint`` is not an existing file.
    """
    assert osp.isfile(checkpoint), 'checkpoint {:} does not exist'.format(checkpoint)
    state = torch.load(checkpoint)
    network = obtain_model(dict2config(state['model-config'], None))
    network.load_state_dict(state['base-model'])
    return network
def get_cell_based_tiny_net(config):
    """Instantiate a cell-based network described by ``config``.

    Dispatch is on ``config.super_type`` (defaults to 'basic' when absent)
    and ``config.name``:

    * 'basic' with a NAS-Bench-201 search name -> entry of ``nas201_super_nets``
    * 'nasnet-super' -> entry of ``nasnet_super_nets``
    * name 'infer.tiny' -> ``TinyNetwork`` built from ``config.genotype`` or,
      failing that, from ``config.arch_str``

    Args:
        config: a config object, or a plain dict (converted via ``dict2config``).

    Returns:
        The constructed network.

    Raises:
        ValueError: if the name is not recognised, or an 'infer.tiny' config
            carries neither ``genotype`` nor ``arch_str``.
    """
    if isinstance(config, dict):
        config = dict2config(config, None)  # to support the argument being a dict
    super_type = getattr(config, 'super_type', 'basic')
    group_names = ['DARTS-V1', 'DARTS-V2', 'GDAS', 'SETN', 'ENAS', 'RANDOM']
    if super_type == 'basic' and config.name in group_names:
        from .cell_searchs import nas201_super_nets as nas_super_nets
        try:
            return nas_super_nets[config.name](config.C, config.N,
                                               config.max_nodes,
                                               config.num_classes,
                                               config.space, config.affine,
                                               config.track_running_stats)
        except Exception:
            # Fall back to the shorter constructor signature (no affine /
            # track_running_stats). ``Exception`` rather than a bare
            # ``except`` so KeyboardInterrupt / SystemExit still propagate.
            return nas_super_nets[config.name](config.C, config.N,
                                               config.max_nodes,
                                               config.num_classes,
                                               config.space)
    elif super_type == 'nasnet-super':
        from .cell_searchs import nasnet_super_nets as nas_super_nets
        return nas_super_nets[config.name](config.C, config.N, config.steps,
                                           config.multiplier,
                                           config.stem_multiplier,
                                           config.num_classes, config.space,
                                           config.affine,
                                           config.track_running_stats)
    elif config.name == 'infer.tiny':
        from .cell_infers import TinyNetwork
        if hasattr(config, 'genotype'):
            genotype = config.genotype
        elif hasattr(config, 'arch_str'):
            genotype = CellStructure.str2structure(config.arch_str)
        else:
            raise ValueError(
                'Can not find genotype from this config : {:}'.format(config))
        return TinyNetwork(config.C, config.N, genotype, config.num_classes)
    else:
        raise ValueError('invalid network name : {:}'.format(config.name))
def test_one_shot_model(ckpath, use_train):
    """Load a SETN search checkpoint and run ``evaluate_one_shot`` on it.

    The dataset and model hyper-parameters are read from the 'args' entry of
    the checkpoint. Only 'cifar10' is supported: the validation loader is the
    training set with the validation transform, sampled from the indices in
    ``cifar-split.txt``.

    Args:
        ckpath: path of the checkpoint (must contain 'args' and 'search_model').
        use_train: anything ``int()`` accepts; values > 0 are forwarded to
            ``evaluate_one_shot`` as ``True``.

    Raises:
        ValueError: if the checkpoint's dataset is not 'cifar10'.
    """
    from models import get_cell_based_tiny_net, get_search_spaces
    from datasets import get_datasets, SearchDataset
    from config_utils import load_config, dict2config
    from utils.nas_utils import evaluate_one_shot
    use_train = int(use_train) > 0
    #ckpath = 'output/search-cell-nas-bench-201/DARTS-V1-cifar10/checkpoint/seed-11416-basic.pth'
    #ckpath = 'output/search-cell-nas-bench-201/DARTS-V1-cifar10/checkpoint/seed-28640-basic.pth'
    print('ckpath : {:}'.format(ckpath))
    ckp = torch.load(ckpath)
    xargs = ckp['args']
    train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
    #config = load_config(xargs.config_path, {'class_num': class_num, 'xshape': xshape}, None)
    config = load_config('./configs/nas-benchmark/algos/DARTS.config',
                         {'class_num': class_num, 'xshape': xshape}, None)
    if xargs.dataset == 'cifar10':
        cifar_split = load_config('configs/nas-benchmark/cifar-split.txt', None, None)
        # Validation images come from the training set, but with the
        # validation-time transform applied.
        xvalid_data = deepcopy(train_data)
        xvalid_data.transform = valid_data.transform
        valid_loader = torch.utils.data.DataLoader(
            xvalid_data,
            batch_size=2048,
            sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar_split.valid),
            num_workers=12,
            pin_memory=True)
    else:
        # Was ``xargs.dataseet`` (typo), which raised AttributeError instead
        # of the intended ValueError.
        raise ValueError('invalid dataset : {:}'.format(xargs.dataset))
    search_space = get_search_spaces('cell', xargs.search_space_name)
    model_config = dict2config(
        {
            'name': 'SETN',
            'C': xargs.channel,
            'N': xargs.num_cells,
            'max_nodes': xargs.max_nodes,
            'num_classes': class_num,
            'space': search_space,
            'affine': False,
            'track_running_stats': True,
        }, None)
    search_model = get_cell_based_tiny_net(model_config)
    search_model.load_state_dict(ckp['search_model'])
    search_model = search_model.cuda()
    # NOTE(review): benchmark location is hard-coded to one user's home directory.
    api = API('/home/dxy/.torch/NAS-Bench-201-v1_0-e61699.pth')
    archs, probs, accuracies = evaluate_one_shot(search_model, valid_loader, api, use_train)
def account_one_arch(arch_index, arch_str, checkpoints, datasets, dataloader_dict):
    """Collect the per-seed, per-dataset results of one architecture.

    Every checkpoint is loaded on CPU; the seed is parsed from the checkpoint
    file name (text after the last '-' and before the first following '.').
    For each dataset a ``ResultsCount`` is created from the stored record. For
    'cifar100' and 'ImageNet16-120' the network is additionally rebuilt and
    re-evaluated on the '<dataset>@valid' and '<dataset>@test' loaders.

    Args:
        arch_index: index of the architecture (used for ``ArchResults`` and
            in error messages).
        arch_str: architecture string.
        checkpoints: iterable of path objects exposing ``.name``.
        datasets: dataset names to look up inside each checkpoint.
        dataloader_dict: loaders keyed as '<dataset>@<split>'.

    Returns:
        The populated ``ArchResults``.

    Raises:
        AssertionError: if a dataset is missing from a checkpoint or its
            record is not marked as 'finish-train'.
        ValueError: if a dataset name is not recognised.
    """
    # Split under which the stored validation curves are recorded.
    stored_split = {
        'cifar10-valid': 'x-valid',
        'cifar10': 'ori-test',
        'cifar100': 'ori-test',
        'ImageNet16-120': 'ori-test',
    }
    information = ArchResults(arch_index, arch_str)
    for ckpt_path in checkpoints:
        ckpt = torch.load(ckpt_path, map_location='cpu')
        used_seed = ckpt_path.name.split('-')[-1].split('.')[0]
        for dataset in datasets:
            assert dataset in ckpt, 'Can not find {:} in arch-{:} from {:}'.format(dataset, arch_index, ckpt_path)
            results = ckpt[dataset]
            assert results['finish-train'], 'This {:} arch seed={:} does not finish train on {:} ::: {:}'.format(arch_index, used_seed, dataset, ckpt_path)
            arch_config = {
                'channel': results['channel'],
                'num_cells': results['num_cells'],
                'arch_str': arch_str,
                'class_num': results['config']['class_num'],
            }
            xresult = ResultsCount(
                dataset, results['net_state_dict'], results['train_acc1es'],
                results['train_losses'], results['param'], results['flop'],
                arch_config, used_seed, results['total_epoch'], None)

            if dataset not in stored_split:
                raise ValueError('invalid dataset name : {:}'.format(dataset))
            xresult.update_eval(stored_split[dataset], results['valid_acc1es'],
                                results['valid_losses'])

            if dataset in ('cifar100', 'ImageNet16-120'):
                net_config = dict2config(
                    {
                        'name': 'infer.tiny',
                        'C': arch_config['channel'],
                        'N': arch_config['num_cells'],
                        'genotype': CellStructure.str2structure(arch_config['arch_str']),
                        'num_classes': arch_config['class_num'],
                    }, None)
                network = get_cell_based_tiny_net(net_config)
                network.load_state_dict(xresult.get_net_param())
                network = network.cuda()
                # Re-evaluated metrics are stored under the final epoch index.
                last_epoch = results['total_epoch'] - 1
                for split, tag in (('valid', 'x-valid'), ('test', 'x-test')):
                    loader = dataloader_dict['{:}@{:}'.format(dataset, split)]
                    loss, top1, top5, latencies = pure_evaluate(loader, network)
                    xresult.update_eval(tag, {last_epoch: top1}, {last_epoch: loss})
                # Latency is taken from the last evaluation (the test split).
                xresult.update_latency(latencies)

            information.update(dataset, int(used_seed), xresult)
    return information
def evaluate_for_seed(arch_config, config, arch, train_loader, valid_loaders, seed, logger):
    """Train one architecture from scratch with a fixed seed and record all metrics.

    The network is an 'infer.tiny' model built from ``arch``; it is trained
    for ``config.epochs + config.warmup`` epochs and evaluated on every
    loader of ``valid_loaders`` after each epoch.

    Args:
        arch_config: dict with 'channel' and 'num_cells'.
        config: training config (reads ``class_num``, ``xshape``, ``epochs``,
            ``warmup`` and is passed to ``get_optim_scheduler``).
        arch: genotype of the architecture.
        train_loader: loader used for training.
        valid_loaders: dict mapping a name to an evaluation loader.
        seed: random seed passed to ``prepare_seed``.
        logger: object exposing ``log``.

    Returns:
        A dict with model statistics, per-epoch training curves (keyed by
        epoch), per-loader evaluation curves (keyed by '<name>@<epoch>'),
        the final state dict and 'finish-train' set to True.
    """
    prepare_seed(seed)  # random seed
    net_config = dict2config(
        {
            'name': 'infer.tiny',
            'C': arch_config['channel'],
            'N': arch_config['num_cells'],
            'genotype': arch,
            'num_classes': config.class_num,
        }, None)
    net = get_cell_based_tiny_net(net_config)
    #net = TinyNetwork(arch_config['channel'], arch_config['num_cells'], arch, config.class_num)
    flop, param = get_model_infos(net, config.xshape)
    logger.log('Network : {:}'.format(net.get_message()), False)
    logger.log('{:} Seed-------------------------- {:} --------------------------'.format(time_string(), seed))
    logger.log('FLOP = {:} MB, Param = {:} MB'.format(flop, param))

    # train and valid
    optimizer, scheduler, criterion = get_optim_scheduler(net.parameters(), config)
    network = torch.nn.DataParallel(net).cuda()
    criterion = criterion.cuda()

    # start training
    start_time = time.time()
    epoch_time = AverageMeter()
    total_epoch = config.epochs + config.warmup
    train_losses, train_acc1es, train_acc5es, train_times = {}, {}, {}, {}
    valid_losses, valid_acc1es, valid_acc5es, valid_times = {}, {}, {}, {}

    for epoch in range(total_epoch):
        scheduler.update(epoch, 0.0)

        train_loss, train_acc1, train_acc5, train_tm = procedure(
            train_loader, network, criterion, scheduler, optimizer, 'train')
        train_losses[epoch] = train_loss
        train_acc1es[epoch] = train_acc1
        train_acc5es[epoch] = train_acc5
        train_times[epoch] = train_tm

        with torch.no_grad():
            for key, xloder in valid_loaders.items():
                valid_loss, valid_acc1, valid_acc5, valid_tm = procedure(
                    xloder, network, criterion, None, None, 'valid')
                tag = '{:}@{:}'.format(key, epoch)
                valid_losses[tag] = valid_loss
                valid_acc1es[tag] = valid_acc1
                valid_acc5es[tag] = valid_acc5
                valid_times[tag] = valid_tm

        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()
        remaining = epoch_time.avg * (total_epoch - epoch - 1)
        need_time = 'Time Left: {:}'.format(convert_secs2time(remaining, True))
        # The "Valid" numbers below are those of the last loader evaluated.
        logger.log('{:} {:} epoch={:03d}/{:03d} :: Train [loss={:.5f}, acc@1={:.2f}%, acc@5={:.2f}%] Valid [loss={:.5f}, acc@1={:.2f}%, acc@5={:.2f}%]'.format(
            time_string(), need_time, epoch, total_epoch,
            train_loss, train_acc1, train_acc5,
            valid_loss, valid_acc1, valid_acc5))

    info_seed = {
        'flop': flop,
        'param': param,
        'channel': arch_config['channel'],
        'num_cells': arch_config['num_cells'],
        'config': config._asdict(),
        'total_epoch': total_epoch,
        'train_losses': train_losses,
        'train_acc1es': train_acc1es,
        'train_acc5es': train_acc5es,
        'train_times': train_times,
        'valid_losses': valid_losses,
        'valid_acc1es': valid_acc1es,
        'valid_acc5es': valid_acc5es,
        'valid_times': valid_times,
        'net_state_dict': net.state_dict(),
        'net_string': '{:}'.format(net),
        'finish-train': True,
    }
    return info_seed
def test_one_shot_model(ckpath, use_train):
    """Load a SETN search checkpoint and run ``evaluate_one_shot`` on it.

    The dataset and model hyper-parameters are read from the "args" entry of
    the checkpoint. Only "cifar10" is supported: the validation loader is the
    training set with the validation transform, sampled from the indices in
    ``cifar-split.txt``.

    Args:
        ckpath: path of the checkpoint (must contain "args" and "search_model").
        use_train: anything ``int()`` accepts; values > 0 are forwarded to
            ``evaluate_one_shot`` as ``True``.

    Raises:
        ValueError: if the checkpoint's dataset is not "cifar10".
    """
    from models import get_cell_based_tiny_net, get_search_spaces
    from datasets import get_datasets, SearchDataset
    from config_utils import load_config, dict2config
    from utils.nas_utils import evaluate_one_shot

    use_train = int(use_train) > 0
    # ckpath = 'output/search-cell-nas-bench-201/DARTS-V1-cifar10/checkpoint/seed-11416-basic.pth'
    # ckpath = 'output/search-cell-nas-bench-201/DARTS-V1-cifar10/checkpoint/seed-28640-basic.pth'
    print("ckpath : {:}".format(ckpath))
    ckp = torch.load(ckpath)
    xargs = ckp["args"]
    train_data, valid_data, xshape, class_num = get_datasets(
        xargs.dataset, xargs.data_path, -1)
    # config = load_config(xargs.config_path, {'class_num': class_num, 'xshape': xshape}, None)
    config = load_config(
        "./configs/nas-benchmark/algos/DARTS.config",
        {
            "class_num": class_num,
            "xshape": xshape
        },
        None,
    )
    if xargs.dataset == "cifar10":
        cifar_split = load_config("configs/nas-benchmark/cifar-split.txt",
                                  None, None)
        # Validation images come from the training set, but with the
        # validation-time transform applied.
        xvalid_data = deepcopy(train_data)
        xvalid_data.transform = valid_data.transform
        valid_loader = torch.utils.data.DataLoader(
            xvalid_data,
            batch_size=2048,
            sampler=torch.utils.data.sampler.SubsetRandomSampler(
                cifar_split.valid),
            num_workers=12,
            pin_memory=True,
        )
    else:
        # Was ``xargs.dataseet`` (typo), which raised AttributeError instead
        # of the intended ValueError.
        raise ValueError("invalid dataset : {:}".format(xargs.dataset))
    search_space = get_search_spaces("cell", xargs.search_space_name)
    model_config = dict2config(
        {
            "name": "SETN",
            "C": xargs.channel,
            "N": xargs.num_cells,
            "max_nodes": xargs.max_nodes,
            "num_classes": class_num,
            "space": search_space,
            "affine": False,
            "track_running_stats": True,
        },
        None,
    )
    search_model = get_cell_based_tiny_net(model_config)
    search_model.load_state_dict(ckp["search_model"])
    search_model = search_model.cuda()
    # NOTE(review): benchmark location is hard-coded to one user's home directory.
    api = API("/home/dxy/.torch/NAS-Bench-201-v1_0-e61699.pth")
    archs, probs, accuracies = evaluate_one_shot(search_model, valid_loader,
                                                 api, use_train)
def get_cell_based_tiny_net(config):
    """Build a ``NASNetworkPRDARTS`` from ``config``.

    Only configs whose ``super_type`` is 'nasnet-super' are accepted; a
    missing ``super_type`` is treated as 'basic' and therefore rejected.

    Args:
        config: a config object, or a plain dict (converted via ``dict2config``).

    Returns:
        The constructed ``NASNetworkPRDARTS``.

    Raises:
        ValueError: for every ``super_type`` other than 'nasnet-super'; the
            message reports ``config.name``.
    """
    if isinstance(config, dict):
        # to support the argument being a dict
        config = dict2config(config, None)

    if getattr(config, 'super_type', 'basic') != 'nasnet-super':
        raise ValueError('invalid network name : {:}'.format(config.name))

    return NASNetworkPRDARTS(
        config.C,
        config.N,
        config.steps,
        config.multiplier,
        config.stem_multiplier,
        config.num_classes,
        config.space,
        config.affine,
        config.track_running_stats,
    )
def get_cell_based_tiny_net(config):
    """Instantiate a cell-based network described by ``config``.

    Dispatch is on ``config.super_type`` (defaults to 'basic' when absent)
    and ``config.name``:

    * 'basic' with a known search name -> entry of ``nas201_super_nets``
    * 'search-shape' -> ``GenericNAS301Model``
    * 'nasnet-super' -> entry of ``nasnet_super_nets``
    * name 'infer.tiny' -> ``TinyNetwork`` (genotype from ``config.genotype``
      or ``config.arch_str``)
    * name 'infer.shape.tiny' -> ``DynamicShapeTinyNet`` (``config.channels``
      may be a ':'-separated string of ints)
    * name 'infer.nasnet-cifar' -> not implemented

    Args:
        config: a config object, or a plain dict (converted via ``dict2config``).

    Returns:
        The constructed network.

    Raises:
        ValueError: if the name is not recognised, or an 'infer.tiny' config
            carries neither ``genotype`` nor ``arch_str``.
        NotImplementedError: for the name 'infer.nasnet-cifar'.
    """
    if isinstance(config, dict):
        config = dict2config(config, None)  # to support the argument being a dict
    super_type = getattr(config, 'super_type', 'basic')
    group_names = [
        'DARTS-V1', 'DARTS-V2', 'GDAS', 'SETN', 'ENAS', 'RANDOM', 'generic'
    ]
    if super_type == 'basic' and config.name in group_names:
        from .cell_searchs import nas201_super_nets as nas_super_nets
        try:
            return nas_super_nets[config.name](config.C, config.N,
                                               config.max_nodes,
                                               config.num_classes,
                                               config.space, config.affine,
                                               config.track_running_stats)
        except Exception:
            # Fall back to the shorter constructor signature (no affine /
            # track_running_stats). ``Exception`` rather than a bare
            # ``except`` so KeyboardInterrupt / SystemExit still propagate.
            return nas_super_nets[config.name](config.C, config.N,
                                               config.max_nodes,
                                               config.num_classes,
                                               config.space)
    elif super_type == 'search-shape':
        from .shape_searchs import GenericNAS301Model
        genotype = CellStructure.str2structure(config.genotype)
        return GenericNAS301Model(config.candidate_Cs, config.max_num_Cs,
                                  genotype, config.num_classes, config.affine,
                                  config.track_running_stats)
    elif super_type == 'nasnet-super':
        from .cell_searchs import nasnet_super_nets as nas_super_nets
        return nas_super_nets[config.name](config.C, config.N, config.steps,
                                           config.multiplier,
                                           config.stem_multiplier,
                                           config.num_classes, config.space,
                                           config.affine,
                                           config.track_running_stats)
    elif config.name == 'infer.tiny':
        from .cell_infers import TinyNetwork
        if hasattr(config, 'genotype'):
            genotype = config.genotype
        elif hasattr(config, 'arch_str'):
            genotype = CellStructure.str2structure(config.arch_str)
        else:
            raise ValueError(
                'Can not find genotype from this config : {:}'.format(config))
        return TinyNetwork(config.C, config.N, genotype, config.num_classes)
    elif config.name == 'infer.shape.tiny':
        from .shape_infers import DynamicShapeTinyNet
        if isinstance(config.channels, str):
            # 'a:b:c' -> (a, b, c)
            channels = tuple(int(x) for x in config.channels.split(':'))
        else:
            channels = config.channels
        genotype = CellStructure.str2structure(config.genotype)
        return DynamicShapeTinyNet(channels, genotype, config.num_classes)
    elif config.name == 'infer.nasnet-cifar':
        from .cell_infers import NASNetonCIFAR
        raise NotImplementedError
    else:
        raise ValueError('invalid network name : {:}'.format(config.name))
def get_cell_based_tiny_net(config):
    """Build one of the DARTS-V1 family super-networks from ``config``.

    Accepted only when ``config.super_type`` is 'basic' (the default when the
    attribute is absent) and ``config.name`` is one of 'DARTS-V1',
    'DARTS-V1-ccbn' or 'DARTS-V1-sabn'.

    Args:
        config: a config object, or a plain dict (converted via ``dict2config``).

    Returns:
        The network created by the matching ``nas201_super_nets`` entry.

    Raises:
        ValueError: for any other ``super_type`` / ``name`` combination.
    """
    if isinstance(config, dict):
        # to support the argument being a dict
        config = dict2config(config, None)

    supported = ['DARTS-V1', 'DARTS-V1-ccbn', 'DARTS-V1-sabn']
    is_basic = getattr(config, 'super_type', 'basic') == 'basic'
    if not (is_basic and config.name in supported):
        raise ValueError('invalid network name : {:}'.format(config.name))

    from .cell_searchs import nas201_super_nets as nas_super_nets
    builder = nas_super_nets[config.name]
    return builder(config.C, config.N, config.max_nodes, config.num_classes,
                   config.space, config.affine, config.track_running_stats)
def main(xargs):
    """Compute the angle between the initial and per-epoch weights of a SPOS model.

    For every epoch in ``range(xargs.epochs)`` the checkpoint
    '<save_dir>/checkpoint/seed-<rand_seed>_epoch-<epoch>.pth' is loaded into
    the model, and ``get_arch_angle`` is evaluated against a copy of the model
    loaded from the epoch-0 checkpoint. The rounded angles are printed at the
    end.

    Args:
        xargs: parsed run arguments (reads ``workers``, ``rand_seed``,
            ``dataset``, ``data_path``, ``config_path``, ``search_space_name``,
            ``channel``, ``num_cells``, ``max_nodes``, ``track_running_stats``,
            ``arch_nas_dataset``, ``save_dir``, ``epochs``).

    Raises:
        AssertionError: if CUDA is not available.
    """
    assert torch.cuda.is_available(), 'CUDA is not available.'
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.set_num_threads(xargs.workers)
    prepare_seed(xargs.rand_seed)
    # Use the parameter consistently instead of the module-global ``args``.
    logger = prepare_logger(xargs)

    train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
    config = load_config(xargs.config_path, {'class_num': class_num, 'xshape': xshape}, logger)
    search_loader, _, valid_loader = get_nas_search_loaders(
        train_data, valid_data, xargs.dataset, 'configs/nas-benchmark/',
        (config.batch_size, config.test_batch_size), xargs.workers)
    logger.log('||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(
        xargs.dataset, len(search_loader), len(valid_loader), config.batch_size))
    logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config))

    search_space = get_search_spaces('cell', xargs.search_space_name)
    model_config = dict2config(
        {
            'name': 'SPOS',
            'C': xargs.channel,
            'N': xargs.num_cells,
            'max_nodes': xargs.max_nodes,
            'num_classes': class_num,
            'space': search_space,
            'affine': False,
            'track_running_stats': bool(xargs.track_running_stats),
        }, None)
    logger.log('search space : {:}'.format(search_space))
    model = get_cell_based_tiny_net(model_config)
    flop, param = get_model_infos(model, xshape)
    logger.log('FLOP = {:.2f} M, Params = {:.2f} MB'.format(flop, param))
    logger.log('search-space : {:}'.format(search_space))
    if xargs.arch_nas_dataset is None:
        api = None
    else:
        api = API(xargs.arch_nas_dataset)
    logger.log('{:} create API = {:} done'.format(time_string(), api))

    checkpoint_path_template = '{}/checkpoint/seed-{}_epoch-{}.pth'
    init_path = checkpoint_path_template.format(xargs.save_dir, xargs.rand_seed, 0)
    logger.log("=> loading checkpoint from {}".format(init_path))
    load(init_path, model)
    # Snapshot of the epoch-0 weights; ``model`` itself is overwritten below.
    init_model = deepcopy(model)

    angles = []
    for epoch in range(xargs.epochs):
        epoch_path = checkpoint_path_template.format(xargs.save_dir, xargs.rand_seed, epoch)
        # Log the path that is actually loaded (the message used to be built
        # from the dataset name rather than the save directory).
        logger.log("=> loading checkpoint from {}".format(epoch_path))
        genotype = load(epoch_path, model)
        cur_model = deepcopy(model)
        angle = get_arch_angle(init_model, cur_model, genotype, search_space)
        logger.log('[{:}] cal angle : angle={}'.format(epoch, angle))
        angles.append(round(angle, 2))
    # NOTE(review): printed once after the loop - confirm against the
    # original layout, which is not recoverable from the collapsed source.
    print(angles)
def main(xargs):
    """Correlate the weight angle of stand-alone architectures with their real accuracy.

    For every architecture index in ``range(0, 10000, 200)`` the epoch-0 and
    final-epoch checkpoints are loaded, the angle between both weight sets is
    computed with ``get_arch_angle`` and stored under the genotype string.
    Architectures are then ranked by angle and by ``get_arch_real_acc`` (both
    descending) and Kendall's tau between the two rankings is logged.

    Args:
        xargs: parsed run arguments (reads ``workers``, ``rand_seed``,
            ``dataset``, ``data_path``, ``config_path``, ``search_space_name``,
            ``channel``, ``num_cells``, ``max_nodes``, ``track_running_stats``,
            ``arch_nas_dataset``; also forwarded to ``get_arch_real_acc``).

    Raises:
        AssertionError: if CUDA is not available, or if no real accuracy can
            be obtained for an architecture.
    """
    assert torch.cuda.is_available(), 'CUDA is not available.'
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.set_num_threads(xargs.workers)
    prepare_seed(xargs.rand_seed)
    # Use the parameter consistently instead of the module-global ``args``.
    logger = prepare_logger(xargs)

    train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
    config = load_config(xargs.config_path, {'class_num': class_num, 'xshape': xshape}, logger)
    search_loader, _, valid_loader = get_nas_search_loaders(
        train_data, valid_data, xargs.dataset, 'configs/nas-benchmark/',
        (config.batch_size, config.test_batch_size), xargs.workers)
    logger.log('||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(
        xargs.dataset, len(search_loader), len(valid_loader), config.batch_size))
    logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config))

    search_space = get_search_spaces('cell', xargs.search_space_name)
    model_config = dict2config(
        {
            'name': 'SPOS',
            'C': xargs.channel,
            'N': xargs.num_cells,
            'max_nodes': xargs.max_nodes,
            'num_classes': class_num,
            'space': search_space,
            'affine': False,
            'track_running_stats': bool(xargs.track_running_stats),
        }, None)
    logger.log('search space : {:}'.format(search_space))
    model = get_cell_based_tiny_net(model_config)
    flop, param = get_model_infos(model, xshape)
    logger.log('FLOP = {:.2f} M, Params = {:.2f} MB'.format(flop, param))
    logger.log('search-space : {:}'.format(search_space))
    if xargs.arch_nas_dataset is None:
        api = None
    else:
        api = API(xargs.arch_nas_dataset)
    logger.log('{:} create API = {:} done'.format(time_string(), api))

    # Loop invariants hoisted out of the per-architecture loop.
    checkpoint_path_template = 'output/search-cell-nas-bench-102/result-{}/standalone_arch-{}/checkpoint/seed-{}_epoch-{}.pth'
    epochs = config.epochs - 1  # index of the final training epoch

    angles = {}
    for arch_idx in range(0, 10000, 200):
        init_path = checkpoint_path_template.format(xargs.dataset, arch_idx, xargs.rand_seed, 0)
        final_path = checkpoint_path_template.format(xargs.dataset, arch_idx, xargs.rand_seed, epochs)

        logger.log("=> loading checkpoint from {}".format(init_path))
        load(init_path, model)
        init_model = deepcopy(model)

        genotype = load(final_path, model)
        logger.log("=> loading checkpoint from {}".format(final_path))
        cur_model = deepcopy(model)

        angle = get_arch_angle(init_model, cur_model, genotype, search_space)
        logger.log('[{:}] cal angle : angle={} | {:}, acc: {}'.format(
            arch_idx, angle, genotype, get_arch_real_acc(api, genotype, xargs)))
        angles[genotype.tostr()] = angle

    real_acc = {}
    for key in angles:
        real_acc[key] = get_arch_real_acc(api, key, xargs)
        assert real_acc[key] is not None

    # Both rankings are descending: rank 1 = largest angle / highest accuracy.
    real_acc = sorted(real_acc.items(), key=lambda d: d[1], reverse=True)
    angles = sorted(angles.items(), key=lambda d: d[1], reverse=True)
    angle_rank = {arch: rank for rank, (arch, _) in enumerate(angles, start=1)}

    # Walk architectures in accuracy order and look up their angle rank.
    angle_rank_list = [angle_rank[arch] for arch, _ in real_acc]
    real_rank_list = list(range(1, len(real_acc) + 1))

    logger.log('Real_rank_list={}'.format(real_rank_list))
    logger.log('Angle_rank_list={}'.format(angle_rank_list))
    # Public entry point; the ``scipy.stats.stats`` namespace is deprecated.
    logger.log('Tau={}'.format(scipy.stats.kendalltau(real_rank_list, angle_rank_list)[0]))
def main(xargs):
    """Run (or resume) a differentiable architecture search and log its progress.

    Workflow: optionally wipe an existing save directory after confirmation,
    build data loaders and the search network, resume from the 'info'
    checkpoint when one exists, then for each epoch run ``search_func`` and
    ``valid_func``, write TensorBoard scalars, track the best validation
    accuracy, save checkpoints and - when a benchmark API is available -
    log the ground-truth accuracy of the current genotype. The loop stops at
    ``xargs.early_stop_epoch``.

    Args:
        xargs: parsed run arguments (among others reads ``save_dir``,
            ``perturb``, ``epsilon_alpha``, ``dataset``, ``data_path``,
            ``config_path``, ``model``, ``model_config``, ``early_stop_epoch``,
            ``snapshoot``, ``arch_nas_dataset``).

    Raises:
        AssertionError: if CUDA is unavailable, if the save directory still
            exists after erasing it, or if ``early_stop_epoch`` is not in
            ``(0, total_epoch - 1]``.
    """
    import shutil  # local import: only needed for the optional clean-up below

    assert torch.cuda.is_available(), 'CUDA is not available.'
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.set_num_threads(xargs.workers)
    prepare_seed(xargs.rand_seed)

    if os.path.isdir(xargs.save_dir):
        # ``default`` belongs to click.confirm, not to str.format.
        if click.confirm(
                '\nSave directory already exists in {}. Erase?'.format(
                    xargs.save_dir),
                default=False):
            # rmtree instead of a string-built ``rm -r`` shell command.
            shutil.rmtree(xargs.save_dir)
            assert not os.path.exists(xargs.save_dir)
            os.mkdir(xargs.save_dir)
        # NOTE(review): assert/mkdir are assumed to sit under the confirm
        # branch; declining keeps the directory so the run can resume.

    # Use the parameter consistently instead of the module-global ``args``.
    logger = prepare_logger(xargs)
    writer = SummaryWriter(xargs.save_dir)

    perturb_alpha = None
    if xargs.perturb:
        perturb_alpha = random_alpha

    train_data, valid_data, xshape, class_num = get_datasets(
        xargs.dataset, xargs.data_path, -1)
    # config_path = 'configs/nas-benchmark/algos/DARTS.config'
    config = load_config(xargs.config_path, {
        'class_num': class_num,
        'xshape': xshape
    }, logger)
    search_loader, _, valid_loader = get_nas_search_loaders(
        train_data, valid_data, xargs.dataset, 'configs/nas-benchmark/',
        config.batch_size, xargs.workers)
    logger.log(
        '||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'
        .format(xargs.dataset, len(search_loader), len(valid_loader),
                config.batch_size))
    logger.log('||||||| {:10s} ||||||| Config={:}'.format(
        xargs.dataset, config))

    search_space = get_search_spaces('cell', xargs.search_space_name)
    if xargs.model_config is None:
        model_config = dict2config(
            {
                'name': xargs.model,
                'C': xargs.channel,
                'N': xargs.num_cells,
                'max_nodes': xargs.max_nodes,
                'num_classes': class_num,
                'space': search_space,
                'affine': bool(xargs.affine),
                'track_running_stats': bool(xargs.track_running_stats)
            }, None)
    else:
        model_config = load_config(
            xargs.model_config, {
                'num_classes': class_num,
                'space': search_space,
                'affine': bool(xargs.affine),
                'track_running_stats': bool(xargs.track_running_stats)
            }, None)
    search_model = get_cell_based_tiny_net(model_config)
    # logger.log('search-model :\n{:}'.format(search_model))

    # Separate optimizers for network weights and architecture parameters.
    w_optimizer, w_scheduler, criterion = get_optim_scheduler(
        search_model.get_weights(), config, xargs.weight_learning_rate)
    a_optimizer = torch.optim.Adam(search_model.get_alphas(),
                                   lr=xargs.arch_learning_rate,
                                   betas=(0.5, 0.999),
                                   weight_decay=xargs.arch_weight_decay)
    logger.log('w-optimizer : {:}'.format(w_optimizer))
    logger.log('a-optimizer : {:}'.format(a_optimizer))
    logger.log('w-scheduler : {:}'.format(w_scheduler))
    logger.log('criterion : {:}'.format(criterion))
    flop, param = get_model_infos(search_model, xshape)
    # logger.log('{:}'.format(search_model))
    logger.log('FLOP = {:.2f} M, Params = {:.2f} MB'.format(flop, param))
    if xargs.arch_nas_dataset is None:
        api = None
    else:
        api = API(xargs.arch_nas_dataset)
    logger.log('{:} create API = {:} done'.format(time_string(), api))

    last_info, model_base_path, model_best_path = logger.path(
        'info'), logger.path('model'), logger.path('best')
    network, criterion = torch.nn.DataParallel(
        search_model).cuda(), criterion.cuda()

    if last_info.exists():  # automatically resume from previous checkpoint
        logger.log("=> loading checkpoint of the last-info '{:}' start".format(
            last_info))
        last_info = torch.load(last_info)
        start_epoch = last_info['epoch']
        checkpoint = torch.load(last_info['last_checkpoint'])
        genotypes = checkpoint['genotypes']
        valid_accuracies = checkpoint['valid_accuracies']
        search_model.load_state_dict(checkpoint['search_model'])
        w_scheduler.load_state_dict(checkpoint['w_scheduler'])
        w_optimizer.load_state_dict(checkpoint['w_optimizer'])
        a_optimizer.load_state_dict(checkpoint['a_optimizer'])
        logger.log(
            "=> loading checkpoint of the last-info '{:}' start with {:}-th epoch."
            .format(last_info, start_epoch))
    else:
        logger.log("=> do not find the last-info file : {:}".format(last_info))
        start_epoch, valid_accuracies, genotypes = 0, {
            'best': -1
        }, {
            -1: search_model.genotype()
        }

    # start training
    # start_time, search_time, epoch_time, total_epoch = time.time(), AverageMeter(), AverageMeter(), config.epochs + config.warmup
    start_time, search_time, epoch_time = time.time(), AverageMeter(
    ), AverageMeter()
    total_epoch = config.epochs + config.warmup
    assert 0 < xargs.early_stop_epoch <= total_epoch - 1
    for epoch in range(start_epoch, total_epoch):
        if epoch >= xargs.early_stop_epoch:
            logger.log(f"Early stop @ {epoch} epoch.")
            break
        if xargs.perturb:
            # Linearly interpolate from 0.03 towards xargs.epsilon_alpha.
            epsilon_alpha = 0.03 + (xargs.epsilon_alpha -
                                    0.03) * epoch / total_epoch
            logger.log(f'epoch {epoch} epsilon_alpha {epsilon_alpha}')
        else:
            epsilon_alpha = None

        w_scheduler.update(epoch, 0.0)
        need_time = 'Time Left: {:}'.format(
            convert_secs2time(epoch_time.val * (total_epoch - epoch), True))
        epoch_str = '{:03d}-{:03d}'.format(epoch, total_epoch)
        logger.log('\n[Search the {:}-th epoch] {:}, LR={:}'.format(
            epoch_str, need_time, min(w_scheduler.get_lr())))

        search_w_loss, search_w_top1, search_w_top5, search_a_loss, search_a_top1, search_a_top5 = search_func(
            search_loader, network, criterion, w_scheduler, w_optimizer,
            a_optimizer, epoch_str, xargs.print_freq, logger,
            xargs.gradient_clip, perturb_alpha, epsilon_alpha)
        search_time.update(time.time() - start_time)
        logger.log(
            '[{:}] searching : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s'
            .format(epoch_str, search_w_loss, search_w_top1, search_w_top5,
                    search_time.sum))
        valid_a_loss, valid_a_top1, valid_a_top5 = valid_func(
            valid_loader, network, criterion)

        writer.add_scalar('search/weight_loss', search_w_loss, epoch)
        writer.add_scalar('search/weight_top1_acc', search_w_top1, epoch)
        writer.add_scalar('search/weight_top5_acc', search_w_top5, epoch)
        writer.add_scalar('search/arch_loss', search_a_loss, epoch)
        writer.add_scalar('search/arch_top1_acc', search_a_top1, epoch)
        writer.add_scalar('search/arch_top5_acc', search_a_top5, epoch)
        writer.add_scalar('evaluate/loss', valid_a_loss, epoch)
        writer.add_scalar('evaluate/top1_acc', valid_a_top1, epoch)
        writer.add_scalar('evaluate/top5_acc', valid_a_top5, epoch)
        logger.log(
            '[{:}] evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'
            .format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5))

        writer.add_scalar('entropy', search_model.entropy, epoch)
        per_edge_dict = get_per_egde_value_dict(search_model.arch_parameters)
        for edge_name, edge_val in per_edge_dict.items():
            writer.add_scalars(f"cell/{edge_name}", edge_val, epoch)

        # check the best accuracy
        valid_accuracies[epoch] = valid_a_top1
        if valid_a_top1 > valid_accuracies['best']:
            valid_accuracies['best'] = valid_a_top1
            genotypes['best'] = search_model.genotype()
            find_best = True
        else:
            find_best = False

        genotypes[epoch] = search_model.genotype()
        logger.log('<<<--->>> The {:}-th epoch : {:}'.format(
            epoch_str, genotypes[epoch]))

        # save checkpoint
        save_path = save_checkpoint(
            {
                'epoch': epoch + 1,
                'args': deepcopy(xargs),
                'search_model': search_model.state_dict(),
                'w_optimizer': w_optimizer.state_dict(),
                'a_optimizer': a_optimizer.state_dict(),
                'w_scheduler': w_scheduler.state_dict(),
                'genotypes': genotypes,
                'valid_accuracies': valid_accuracies
            }, model_base_path, logger)
        # The 'info' file points at the latest full checkpoint for resuming.
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'args': deepcopy(xargs),
                'last_checkpoint': save_path,
            }, logger.path('info'), logger)
        if xargs.snapshoot > 0 and epoch % xargs.snapshoot == 0:
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'args': deepcopy(xargs),
                    'search_model': search_model.state_dict(),
                },
                os.path.join(str(logger.model_dir),
                             f"checkpoint_epoch{epoch}.pth"), logger)
        if find_best:
            logger.log(
                '<<<--->>> The {:}-th epoch : find the highest validation accuracy : {:.2f}%.'
                .format(epoch_str, valid_a_top1))
            copy_checkpoint(model_base_path, model_best_path, logger)
        with torch.no_grad():
            logger.log('{:}'.format(search_model.show_alphas()))

        if api is not None:
            logger.log('{:}'.format(api.query_by_arch(genotypes[epoch])))
            index = api.query_index_by_arch(genotypes[epoch])
            info = api.query_meta_info_by_index(
                index)  # This is an instance of `ArchResults`
            res_metrics = info.get_metrics(
                f'{xargs.dataset}',
                'ori-test')  # This is a dict with metric names as keys
            # cost_metrics = info.get_comput_costs('cifar10')
            writer.add_scalar(f'{xargs.dataset}_ground_acc_ori-test',
                              res_metrics['accuracy'], epoch)
            writer.add_scalar(f'{xargs.dataset}_search_acc', valid_a_top1,
                              epoch)
            if xargs.dataset.lower() != 'cifar10':
                writer.add_scalar(
                    f'{xargs.dataset}_ground_acc_x-test',
                    info.get_metrics(f'{xargs.dataset}', 'x-test')['accuracy'],
                    epoch)
            if find_best:
                valid_accuracies['best_gt'] = res_metrics['accuracy']
            writer.add_scalar(f"{xargs.dataset}_cur_best_gt_acc_ori-test",
                              valid_accuracies['best_gt'], epoch)

        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()

    logger.log('\n' + '-' * 100)
    logger.log('{:} : run {:} epochs, cost {:.1f} s, last-geno is {:}.'.format(
        xargs.model, xargs.early_stop_epoch, search_time.sum,
        genotypes[xargs.early_stop_epoch - 1]))
    if api is not None:
        logger.log('{:}'.format(
            api.query_by_arch(genotypes[xargs.early_stop_epoch - 1])))
    # Flush pending TensorBoard events before shutting down.
    writer.close()
    logger.close()
def main(xargs):
    """Run the ENAS architecture search configured by ``xargs``.

    The function loads the dataset and its pre-defined train/valid split,
    builds the shared CNN together with its controller, resumes from the last
    checkpoint when one exists, and then, for ``config.epochs + config.warmup``
    epochs, trains the shared CNN, trains the controller, evaluates the best
    sampled architecture and writes a checkpoint.  The checkpoint of the epoch
    with the highest validation top-1 accuracy is copied to the "best" path.

    Args:
        xargs: parsed argument namespace.  Attributes read here: ``workers``,
            ``rand_seed``, ``dataset``, ``data_path``, ``config_path``,
            ``search_space_name``, ``channel``, ``num_cells``, ``max_nodes``,
            ``print_freq`` and the ``controller_*`` options.

    Raises:
        AssertionError: if CUDA is not available.
        ValueError: if ``xargs.dataset`` is not cifar10, cifar100 or an
            ImageNet16 variant.
    """
    assert torch.cuda.is_available(), 'CUDA is not available.'
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.set_num_threads(xargs.workers)
    prepare_seed(xargs.rand_seed)
    # Use the function argument rather than the module-level ``args`` global,
    # so the function also works when it is not called from ``__main__``.
    logger = prepare_logger(xargs)

    train_data, test_data, xshape, class_num = get_datasets(
        xargs.dataset, xargs.data_path, -1)

    # Pick the split file that matches the dataset, then load it once.
    if xargs.dataset == 'cifar10' or xargs.dataset == 'cifar100':
        split_path = 'configs/nas-benchmark/cifar-split.txt'
    elif xargs.dataset.startswith('ImageNet16'):
        split_path = 'configs/nas-benchmark/{:}-split.txt'.format(
            xargs.dataset)
    else:
        raise ValueError('invalid dataset : {:}'.format(xargs.dataset))
    split_config = load_config(split_path, None, None)
    train_split, valid_split = split_config.train, split_config.valid
    logger.log('Load split file from {:}'.format(split_path))

    logger.log('use config from : {:}'.format(xargs.config_path))
    config = load_config(xargs.config_path,
                         {'class_num': class_num, 'xshape': xshape}, logger)
    logger.log('config: {:}'.format(config))

    # The validation set is a copy of the training set that uses the test-time
    # transform; the two samplers below select disjoint index subsets.
    valid_data = deepcopy(train_data)
    valid_data.transform = test_data.transform

    train_loader = torch.utils.data.DataLoader(
        train_data,
        batch_size=config.batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(train_split),
        num_workers=xargs.workers,
        pin_memory=True)
    valid_loader = torch.utils.data.DataLoader(
        valid_data,
        batch_size=config.batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split),
        num_workers=xargs.workers,
        pin_memory=True)
    logger.log(
        '||||||| {:10s} ||||||| Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'
        .format(xargs.dataset, len(train_loader), len(valid_loader),
                config.batch_size))
    logger.log('||||||| {:10s} ||||||| Config={:}'.format(
        xargs.dataset, config))

    search_space = get_search_spaces('cell', xargs.search_space_name)
    model_config = dict2config(
        {
            'name': 'ENAS',
            'C': xargs.channel,
            'N': xargs.num_cells,
            'max_nodes': xargs.max_nodes,
            'num_classes': class_num,
            'space': search_space,
        }, None)
    shared_cnn = get_cell_based_tiny_net(model_config)
    controller = shared_cnn.create_controller()

    w_optimizer, w_scheduler, criterion = get_optim_scheduler(
        shared_cnn.parameters(), config)
    a_optimizer = torch.optim.Adam(controller.parameters(),
                                   lr=config.controller_lr,
                                   betas=config.controller_betas,
                                   eps=config.controller_eps)
    logger.log('w-optimizer : {:}'.format(w_optimizer))
    logger.log('a-optimizer : {:}'.format(a_optimizer))
    logger.log('w-scheduler : {:}'.format(w_scheduler))
    logger.log('criterion   : {:}'.format(criterion))
    #flop, param  = get_model_infos(shared_cnn, xshape)
    #logger.log('{:}'.format(shared_cnn))
    #logger.log('FLOP = {:.2f} M, Params = {:.2f} MB'.format(flop, param))
    shared_cnn = torch.nn.DataParallel(shared_cnn).cuda()
    controller = controller.cuda()
    criterion = criterion.cuda()

    last_info_path = logger.path('info')
    model_base_path = logger.path('model')
    model_best_path = logger.path('best')

    if last_info_path.exists():  # automatically resume from previous checkpoint
        logger.log("=> loading checkpoint of the last-info '{:}' start".format(
            last_info_path))
        last_info = torch.load(last_info_path)
        start_epoch = last_info['epoch']
        checkpoint = torch.load(last_info['last_checkpoint'])
        genotypes = checkpoint['genotypes']
        baseline = checkpoint['baseline']
        valid_accuracies = checkpoint['valid_accuracies']
        shared_cnn.load_state_dict(checkpoint['shared_cnn'])
        controller.load_state_dict(checkpoint['controller'])
        w_scheduler.load_state_dict(checkpoint['w_scheduler'])
        w_optimizer.load_state_dict(checkpoint['w_optimizer'])
        a_optimizer.load_state_dict(checkpoint['a_optimizer'])
        # Report the path of the info file, not the dictionary loaded from it.
        logger.log(
            "=> loading checkpoint of the last-info '{:}' start with {:}-th epoch."
            .format(last_info_path, start_epoch))
    else:
        logger.log("=> do not find the last-info file : {:}".format(
            last_info_path))
        start_epoch, valid_accuracies, genotypes, baseline = 0, {
            'best': -1
        }, {}, None

    # start training
    start_time, epoch_time = time.time(), AverageMeter()
    total_epoch = config.epochs + config.warmup
    for epoch in range(start_epoch, total_epoch):
        w_scheduler.update(epoch, 0.0)
        need_time = 'Time Left: {:}'.format(
            convert_secs2time(epoch_time.val * (total_epoch - epoch), True))
        epoch_str = '{:03d}-{:03d}'.format(epoch, total_epoch)
        logger.log(
            '\n[Search the {:}-th epoch] {:}, LR={:}, baseline={:}'.format(
                epoch_str, need_time, min(w_scheduler.get_lr()), baseline))

        cnn_loss, cnn_top1, cnn_top5 = train_shared_cnn(
            train_loader, shared_cnn, controller, criterion, w_scheduler,
            w_optimizer, epoch_str, xargs.print_freq, logger)
        logger.log(
            '[{:}] shared-cnn : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'
            .format(epoch_str, cnn_loss, cnn_top1, cnn_top5))

        controller_config = dict2config(
            {
                'baseline': baseline,
                'ctl_train_steps': xargs.controller_train_steps,
                'ctl_num_aggre': xargs.controller_num_aggregate,
                'ctl_entropy_w': xargs.controller_entropy_weight,
                'ctl_bl_dec': xargs.controller_bl_dec,
            }, None)
        ctl_loss, ctl_acc, ctl_baseline, ctl_reward, baseline = train_controller(
            valid_loader, shared_cnn, controller, criterion, a_optimizer,
            controller_config, epoch_str, xargs.print_freq, logger)
        logger.log(
            '[{:}] controller : loss={:.2f}, accuracy={:.2f}%, baseline={:.2f}, reward={:.2f}, current-baseline={:.4f}'
            .format(epoch_str, ctl_loss, ctl_acc, ctl_baseline, ctl_reward,
                    baseline))

        best_arch, _ = get_best_arch(controller, shared_cnn, valid_loader)
        shared_cnn.module.update_arch(best_arch)
        _, best_valid_acc, _ = valid_func(valid_loader, shared_cnn, criterion)

        genotypes[epoch] = best_arch
        # check the best accuracy
        valid_accuracies[epoch] = best_valid_acc
        find_best = best_valid_acc > valid_accuracies['best']
        if find_best:
            valid_accuracies['best'] = best_valid_acc
            genotypes['best'] = best_arch

        logger.log('<<<--->>> The {:}-th epoch : {:}'.format(
            epoch_str, genotypes[epoch]))
        # save checkpoint
        save_path = save_checkpoint(
            {
                'epoch': epoch + 1,
                'args': deepcopy(xargs),
                'baseline': baseline,
                'shared_cnn': shared_cnn.state_dict(),
                'controller': controller.state_dict(),
                'w_optimizer': w_optimizer.state_dict(),
                'a_optimizer': a_optimizer.state_dict(),
                'w_scheduler': w_scheduler.state_dict(),
                'genotypes': genotypes,
                'valid_accuracies': valid_accuracies,
            }, model_base_path, logger)
        # The info file only records where the full checkpoint lives.
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'args': deepcopy(xargs),
                'last_checkpoint': save_path,
            }, logger.path('info'), logger)
        if find_best:
            logger.log(
                '<<<--->>> The {:}-th epoch : find the highest validation accuracy : {:.2f}%.'
                .format(epoch_str, best_valid_acc))
            copy_checkpoint(model_base_path, model_best_path, logger)
        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()

    logger.log('\n' + '-' * 100)
    logger.log('During searching, the best architecture is {:}'.format(
        genotypes['best']))
    logger.log('Its accuracy is {:.2f}%'.format(valid_accuracies['best']))
    logger.log('Randomly select {:} architectures and select the best.'.format(
        xargs.controller_num_samples))
    final_arch, _ = get_best_arch(controller, shared_cnn, valid_loader,
                                  xargs.controller_num_samples)
    shared_cnn.module.update_arch(final_arch)
    final_loss, final_top1, final_top5 = valid_func(valid_loader, shared_cnn,
                                                    criterion)
    logger.log('The Selected Final Architecture : {:}'.format(final_arch))
    logger.log('Loss={:.3f}, Accuracy@1={:.2f}%, Accuracy@5={:.2f}%'.format(
        final_loss, final_top1, final_top5))
    # check the performance from the architecture dataset
    #if xargs.arch_nas_dataset is None or not os.path.isfile(xargs.arch_nas_dataset):
    #  logger.log('Can not find the architecture dataset : {:}.'.format(xargs.arch_nas_dataset))
    #else:
    #  nas_bench = NASBenchmarkAPI(xargs.arch_nas_dataset)
    #  geno = genotypes[total_epoch-1]
    #  logger.log('The last model is {:}'.format(geno))
    #  info = nas_bench.query_by_arch( geno )
    #  if info is None: logger.log('Did not find this architecture : {:}.'.format(geno))
    #  else           : logger.log('{:}'.format(info))
    #  logger.log('-'*100)
    #  geno = genotypes['best']
    #  logger.log('The best model is {:}'.format(geno))
    #  info = nas_bench.query_by_arch( geno )
    #  if info is None: logger.log('Did not find this architecture : {:}.'.format(geno))
    #  else           : logger.log('{:}'.format(info))
    logger.close()
def _apply_cal_mode(network, algo, genotype):
    """Switch ``network`` to the calculation mode that corresponds to ``algo``.

    Raises:
        ValueError: if ``algo`` is not one of the supported algorithm names.
    """
    if algo == 'setn' or algo == 'enas':
        network.set_cal_mode('dynamic', genotype)
    elif algo == 'gdas':
        network.set_cal_mode('gdas', None)
    elif algo.startswith('darts'):
        network.set_cal_mode('joint', None)
    elif algo == 'random':
        network.set_cal_mode('urs', None)
    else:
        raise ValueError('Invalid algorithm name : {:}'.format(algo))


def main(xargs):
    """Run the generic weight-sharing search selected by ``xargs.algo``.

    The function builds the search/valid loaders, the search model and its two
    optimizers (weights and architecture parameters), resumes from the last
    checkpoint when one exists, and runs ``config.epochs + config.warmup``
    search epochs.  After each epoch it picks an architecture with
    ``get_best_arch``, evaluates it on the validation loader and writes a
    checkpoint.  After the loop it repeats the selection once more and logs the
    final genotype.

    Args:
        xargs: parsed argument namespace.  Attributes read here: ``workers``,
            ``rand_seed``, ``dataset``, ``data_path``, ``overwite_epochs``,
            ``config_path``, ``search_space``, ``channel``, ``num_cells``,
            ``max_nodes``, ``affine``, ``track_running_stats``, ``algo``,
            ``arch_learning_rate``, ``arch_weight_decay``, ``arch_eps``,
            ``use_api``, ``drop_path_rate``, ``tau_max``, ``tau_min``,
            ``print_freq`` and ``eval_candidate_num``.

    Raises:
        AssertionError: if CUDA is not available.
        ValueError: if ``xargs.algo`` is not a supported algorithm name.
    """
    assert torch.cuda.is_available(), 'CUDA is not available.'
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.set_num_threads(xargs.workers)
    prepare_seed(xargs.rand_seed)
    # Use the function argument rather than the module-level ``args`` global.
    logger = prepare_logger(xargs)

    train_data, valid_data, xshape, class_num = get_datasets(
        xargs.dataset, xargs.data_path, -1)
    extra_info = {'class_num': class_num, 'xshape': xshape}
    if xargs.overwite_epochs is not None:
        extra_info['epochs'] = xargs.overwite_epochs
    config = load_config(xargs.config_path, extra_info, logger)
    # The second returned loader (train loader) is not needed here.
    search_loader, _, valid_loader = get_nas_search_loaders(
        train_data, valid_data, xargs.dataset, 'configs/nas-benchmark/',
        (config.batch_size, config.test_batch_size), xargs.workers)
    logger.log(
        '||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'
        .format(xargs.dataset, len(search_loader), len(valid_loader),
                config.batch_size))
    logger.log('||||||| {:10s} ||||||| Config={:}'.format(
        xargs.dataset, config))

    search_space = get_search_spaces(xargs.search_space, 'nas-bench-301')
    model_config = dict2config(
        dict(name='generic',
             C=xargs.channel,
             N=xargs.num_cells,
             max_nodes=xargs.max_nodes,
             num_classes=class_num,
             space=search_space,
             affine=bool(xargs.affine),
             track_running_stats=bool(xargs.track_running_stats)), None)
    logger.log('search space : {:}'.format(search_space))
    logger.log('model config : {:}'.format(model_config))
    search_model = get_cell_based_tiny_net(model_config)
    search_model.set_algo(xargs.algo)
    logger.log('{:}'.format(search_model))

    w_optimizer, w_scheduler, criterion = get_optim_scheduler(
        search_model.weights, config)
    a_optimizer = torch.optim.Adam(search_model.alphas,
                                   lr=xargs.arch_learning_rate,
                                   betas=(0.5, 0.999),
                                   weight_decay=xargs.arch_weight_decay,
                                   eps=xargs.arch_eps)
    logger.log('w-optimizer : {:}'.format(w_optimizer))
    logger.log('a-optimizer : {:}'.format(a_optimizer))
    logger.log('w-scheduler : {:}'.format(w_scheduler))
    logger.log('criterion   : {:}'.format(criterion))
    params = count_parameters_in_MB(search_model)
    logger.log('The parameters of the search model = {:.2f} MB'.format(params))
    logger.log('search-space : {:}'.format(search_space))
    if bool(xargs.use_api):
        api = create(None, 'topology', fast_mode=True, verbose=False)
    else:
        api = None
    logger.log('{:} create API = {:} done'.format(time_string(), api))

    network, criterion = search_model.cuda(), criterion.cuda(
    )  # use a single GPU

    last_info_path = logger.path('info')
    model_base_path = logger.path('model')
    model_best_path = logger.path('best')

    if last_info_path.exists():  # automatically resume from previous checkpoint
        logger.log("=> loading checkpoint of the last-info '{:}' start".format(
            last_info_path))
        last_info = torch.load(last_info_path)
        start_epoch = last_info['epoch']
        checkpoint = torch.load(last_info['last_checkpoint'])
        genotypes = checkpoint['genotypes']
        baseline = checkpoint['baseline']
        valid_accuracies = checkpoint['valid_accuracies']
        search_model.load_state_dict(checkpoint['search_model'])
        w_scheduler.load_state_dict(checkpoint['w_scheduler'])
        w_optimizer.load_state_dict(checkpoint['w_optimizer'])
        a_optimizer.load_state_dict(checkpoint['a_optimizer'])
        # Report the path of the info file, not the dictionary loaded from it.
        logger.log(
            "=> loading checkpoint of the last-info '{:}' start with {:}-th epoch."
            .format(last_info_path, start_epoch))
    else:
        logger.log("=> do not find the last-info file : {:}".format(
            last_info_path))
        start_epoch = 0
        valid_accuracies = {'best': -1}
        genotypes = {-1: network.return_topK(1, True)[0]}
        baseline = None

    # start training
    start_time = time.time()
    search_time, epoch_time = AverageMeter(), AverageMeter()
    total_epoch = config.epochs + config.warmup
    for epoch in range(start_epoch, total_epoch):
        w_scheduler.update(epoch, 0.0)
        need_time = 'Time Left: {:}'.format(
            convert_secs2time(epoch_time.val * (total_epoch - epoch), True))
        epoch_str = '{:03d}-{:03d}'.format(epoch, total_epoch)
        logger.log('\n[Search the {:}-th epoch] {:}, LR={:}'.format(
            epoch_str, need_time, min(w_scheduler.get_lr())))

        network.set_drop_path(
            float(epoch + 1) / total_epoch, xargs.drop_path_rate)
        if xargs.algo == 'gdas':
            # Linear decay from tau_max to tau_min over the run.  The divisor
            # is clamped so a single-epoch run does not divide by zero.
            network.set_tau(xargs.tau_max -
                            (xargs.tau_max - xargs.tau_min) * epoch /
                            max(total_epoch - 1, 1))
            logger.log('[RESET tau as : {:} and drop_path as {:}]'.format(
                network.tau, network.drop_path))

        search_w_loss, search_w_top1, search_w_top5, search_a_loss, search_a_top1, search_a_top5 \
            = search_func(search_loader, network, criterion, w_scheduler,
                          w_optimizer, a_optimizer, epoch_str,
                          xargs.print_freq, xargs.algo, logger)
        search_time.update(time.time() - start_time)
        logger.log(
            '[{:}] search [base] : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s'
            .format(epoch_str, search_w_loss, search_w_top1, search_w_top5,
                    search_time.sum))
        logger.log(
            '[{:}] search [arch] : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'
            .format(epoch_str, search_a_loss, search_a_top1, search_a_top5))
        if xargs.algo == 'enas':
            ctl_loss, ctl_acc, baseline, ctl_reward \
                = train_controller(valid_loader, network, criterion,
                                   a_optimizer, baseline, epoch_str,
                                   xargs.print_freq, logger)
            logger.log(
                '[{:}] controller : loss={:}, acc={:}, baseline={:}, reward={:}'
                .format(epoch_str, ctl_loss, ctl_acc, baseline, ctl_reward))

        genotype, temp_accuracy = get_best_arch(valid_loader, network,
                                                xargs.eval_candidate_num,
                                                xargs.algo)
        _apply_cal_mode(network, xargs.algo, genotype)
        logger.log('[{:}] - [get_best_arch] : {:} -> {:}'.format(
            epoch_str, genotype, temp_accuracy))
        valid_a_loss, valid_a_top1, valid_a_top5 = valid_func(
            valid_loader, network, criterion, xargs.algo, logger)
        logger.log(
            '[{:}] evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}% | {:}'
            .format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5,
                    genotype))
        valid_accuracies[epoch] = valid_a_top1

        genotypes[epoch] = genotype
        logger.log('<<<--->>> The {:}-th epoch : {:}'.format(
            epoch_str, genotypes[epoch]))
        # save checkpoint
        save_path = save_checkpoint(
            {
                'epoch': epoch + 1,
                'args': deepcopy(xargs),
                'baseline': baseline,
                'search_model': search_model.state_dict(),
                'w_optimizer': w_optimizer.state_dict(),
                'a_optimizer': a_optimizer.state_dict(),
                'w_scheduler': w_scheduler.state_dict(),
                'genotypes': genotypes,
                'valid_accuracies': valid_accuracies
            }, model_base_path, logger)
        # The info file only records where the full checkpoint lives.
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'args': deepcopy(xargs),
                'last_checkpoint': save_path,
            }, logger.path('info'), logger)
        with torch.no_grad():
            logger.log('{:}'.format(search_model.show_alphas()))
        if api is not None:
            logger.log('{:}'.format(api.query_by_arch(genotypes[epoch],
                                                      '200')))
        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()

    # the final post procedure : count the time
    start_time = time.time()
    genotype, temp_accuracy = get_best_arch(valid_loader, network,
                                            xargs.eval_candidate_num,
                                            xargs.algo)
    _apply_cal_mode(network, xargs.algo, genotype)
    search_time.update(time.time() - start_time)

    valid_a_loss, valid_a_top1, valid_a_top5 = valid_func(
        valid_loader, network, criterion, xargs.algo, logger)
    logger.log(
        'Last : the gentotype is : {:}, with the validation accuracy of {:.3f}%.'
        .format(genotype, valid_a_top1))

    logger.log('\n' + '-' * 100)
    # check the performance from the architecture dataset
    logger.log('[{:}] run {:} epochs, cost {:.1f} s, last-geno is {:}.'.format(
        xargs.algo, total_epoch, search_time.sum, genotype))
    if api is not None:
        logger.log('{:}'.format(api.query_by_arch(genotype, '200')))
    logger.close()
def main(xargs):
    """Run the SETN architecture search configured by ``xargs``.

    The function builds the search/valid loaders and the search model (either
    from the command-line options or from ``xargs.model_config``), resumes
    from the last checkpoint when one exists, and runs
    ``config.epochs + config.warmup`` search epochs.  After each epoch an
    architecture is selected with ``get_best_arch``, evaluated on the
    validation loader, and a checkpoint is written.  After the loop the
    selection is repeated once and the final genotype is logged.

    Args:
        xargs: parsed argument namespace.  Attributes read here: ``workers``,
            ``rand_seed``, ``dataset``, ``data_path``, ``config_path``,
            ``search_space_name``, ``model_config``, ``channel``,
            ``num_cells``, ``max_nodes``, ``track_running_stats``,
            ``arch_learning_rate``, ``arch_weight_decay``,
            ``arch_nas_dataset``, ``select_num`` and ``print_freq``.

    Raises:
        AssertionError: if CUDA is not available.
    """
    assert torch.cuda.is_available(), "CUDA is not available."
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.set_num_threads(xargs.workers)
    prepare_seed(xargs.rand_seed)
    # Use the function argument rather than the module-level ``args`` global.
    logger = prepare_logger(xargs)

    train_data, valid_data, xshape, class_num = get_datasets(
        xargs.dataset, xargs.data_path, -1)
    config = load_config(xargs.config_path, {
        "class_num": class_num,
        "xshape": xshape
    }, logger)
    search_loader, _, valid_loader = get_nas_search_loaders(
        train_data,
        valid_data,
        xargs.dataset,
        "configs/nas-benchmark/",
        (config.batch_size, config.test_batch_size),
        xargs.workers,
    )
    logger.log(
        "||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}"
        .format(xargs.dataset, len(search_loader), len(valid_loader),
                config.batch_size))
    logger.log("||||||| {:10s} ||||||| Config={:}".format(
        xargs.dataset, config))

    search_space = get_search_spaces("cell", xargs.search_space_name)
    if xargs.model_config is None:
        model_config = dict2config(
            dict(
                name="SETN",
                C=xargs.channel,
                N=xargs.num_cells,
                max_nodes=xargs.max_nodes,
                num_classes=class_num,
                space=search_space,
                affine=False,
                track_running_stats=bool(xargs.track_running_stats),
            ),
            None,
        )
    else:
        model_config = load_config(
            xargs.model_config,
            dict(
                num_classes=class_num,
                space=search_space,
                affine=False,
                track_running_stats=bool(xargs.track_running_stats),
            ),
            None,
        )
    logger.log("search space : {:}".format(search_space))
    search_model = get_cell_based_tiny_net(model_config)

    w_optimizer, w_scheduler, criterion = get_optim_scheduler(
        search_model.get_weights(), config)
    a_optimizer = torch.optim.Adam(
        search_model.get_alphas(),
        lr=xargs.arch_learning_rate,
        betas=(0.5, 0.999),
        weight_decay=xargs.arch_weight_decay,
    )
    logger.log("w-optimizer : {:}".format(w_optimizer))
    logger.log("a-optimizer : {:}".format(a_optimizer))
    logger.log("w-scheduler : {:}".format(w_scheduler))
    logger.log("criterion   : {:}".format(criterion))
    flop, param = get_model_infos(search_model, xshape)
    logger.log("FLOP = {:.2f} M, Params = {:.2f} MB".format(flop, param))
    logger.log("search-space : {:}".format(search_space))
    if xargs.arch_nas_dataset is None:
        api = None
    else:
        api = API(xargs.arch_nas_dataset)
    logger.log("{:} create API = {:} done".format(time_string(), api))

    last_info_path = logger.path("info")
    model_base_path = logger.path("model")
    model_best_path = logger.path("best")
    network, criterion = torch.nn.DataParallel(
        search_model).cuda(), criterion.cuda()

    if last_info_path.exists():  # automatically resume from previous checkpoint
        logger.log("=> loading checkpoint of the last-info '{:}' start".format(
            last_info_path))
        last_info = torch.load(last_info_path)
        start_epoch = last_info["epoch"]
        checkpoint = torch.load(last_info["last_checkpoint"])
        genotypes = checkpoint["genotypes"]
        valid_accuracies = checkpoint["valid_accuracies"]
        search_model.load_state_dict(checkpoint["search_model"])
        w_scheduler.load_state_dict(checkpoint["w_scheduler"])
        w_optimizer.load_state_dict(checkpoint["w_optimizer"])
        a_optimizer.load_state_dict(checkpoint["a_optimizer"])
        # Report the path of the info file, not the dictionary loaded from it.
        logger.log(
            "=> loading checkpoint of the last-info '{:}' start with {:}-th epoch."
            .format(last_info_path, start_epoch))
    else:
        logger.log("=> do not find the last-info file : {:}".format(
            last_info_path))
        init_genotype, _ = get_best_arch(valid_loader, network,
                                         xargs.select_num)
        start_epoch = 0
        valid_accuracies = {"best": -1}
        genotypes = {-1: init_genotype}

    # start training
    start_time = time.time()
    search_time, epoch_time = AverageMeter(), AverageMeter()
    total_epoch = config.epochs + config.warmup
    for epoch in range(start_epoch, total_epoch):
        w_scheduler.update(epoch, 0.0)
        need_time = "Time Left: {:}".format(
            convert_secs2time(epoch_time.val * (total_epoch - epoch), True))
        epoch_str = "{:03d}-{:03d}".format(epoch, total_epoch)
        logger.log("\n[Search the {:}-th epoch] {:}, LR={:}".format(
            epoch_str, need_time, min(w_scheduler.get_lr())))

        (
            search_w_loss,
            search_w_top1,
            search_w_top5,
            search_a_loss,
            search_a_top1,
            search_a_top5,
        ) = search_func(
            search_loader,
            network,
            criterion,
            w_scheduler,
            w_optimizer,
            a_optimizer,
            epoch_str,
            xargs.print_freq,
            logger,
        )
        search_time.update(time.time() - start_time)
        logger.log(
            "[{:}] search [base] : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s"
            .format(epoch_str, search_w_loss, search_w_top1, search_w_top5,
                    search_time.sum))
        logger.log(
            "[{:}] search [arch] : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%"
            .format(epoch_str, search_a_loss, search_a_top1, search_a_top5))

        genotype, temp_accuracy = get_best_arch(valid_loader, network,
                                                xargs.select_num)
        network.module.set_cal_mode("dynamic", genotype)
        valid_a_loss, valid_a_top1, valid_a_top5 = valid_func(
            valid_loader, network, criterion)
        logger.log(
            "[{:}] evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}% | {:}"
            .format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5,
                    genotype))
        # search_model.set_cal_mode('urs')
        # valid_a_loss , valid_a_top1 , valid_a_top5  = valid_func(valid_loader, network, criterion)
        # logger.log('[{:}] URS---evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5))
        # search_model.set_cal_mode('joint')
        # valid_a_loss , valid_a_top1 , valid_a_top5  = valid_func(valid_loader, network, criterion)
        # logger.log('[{:}] JOINT-evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5))
        # search_model.set_cal_mode('select')
        # valid_a_loss , valid_a_top1 , valid_a_top5  = valid_func(valid_loader, network, criterion)
        # logger.log('[{:}] Selec-evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5))
        # check the best accuracy
        valid_accuracies[epoch] = valid_a_top1

        genotypes[epoch] = genotype
        logger.log("<<<--->>> The {:}-th epoch : {:}".format(
            epoch_str, genotypes[epoch]))
        # save checkpoint
        save_path = save_checkpoint(
            {
                "epoch": epoch + 1,
                "args": deepcopy(xargs),
                "search_model": search_model.state_dict(),
                "w_optimizer": w_optimizer.state_dict(),
                "a_optimizer": a_optimizer.state_dict(),
                "w_scheduler": w_scheduler.state_dict(),
                "genotypes": genotypes,
                "valid_accuracies": valid_accuracies,
            },
            model_base_path,
            logger,
        )
        # The info file only records where the full checkpoint lives.
        save_checkpoint(
            {
                "epoch": epoch + 1,
                "args": deepcopy(xargs),
                "last_checkpoint": save_path,
            },
            logger.path("info"),
            logger,
        )
        with torch.no_grad():
            logger.log("{:}".format(search_model.show_alphas()))
        if api is not None:
            logger.log("{:}".format(api.query_by_arch(genotypes[epoch],
                                                      "200")))
        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()

    # the final post procedure : count the time
    start_time = time.time()
    genotype, temp_accuracy = get_best_arch(valid_loader, network,
                                            xargs.select_num)
    search_time.update(time.time() - start_time)
    network.module.set_cal_mode("dynamic", genotype)
    valid_a_loss, valid_a_top1, valid_a_top5 = valid_func(
        valid_loader, network, criterion)
    logger.log(
        "Last : the gentotype is : {:}, with the validation accuracy of {:.3f}%."
        .format(genotype, valid_a_top1))

    logger.log("\n" + "-" * 100)
    # check the performance from the architecture dataset
    logger.log(
        "SETN : run {:} epochs, cost {:.1f} s, last-geno is {:}.".format(
            total_epoch, search_time.sum, genotype))
    if api is not None:
        logger.log("{:}".format(api.query_by_arch(genotype, "200")))
    logger.close()
def create_result_count(
    used_seed: int,
    dataset: Text,
    arch_config: Dict[Text, Any],
    results: Dict[Text, Any],
    dataloader_dict: Dict[Text, Any],
) -> ResultsCount:
    """Build a ``ResultsCount`` record from one raw training result.

    When ``results`` contains ``"train_times"`` (new format) the training and
    evaluation statistics are copied over directly.  Otherwise (old format)
    the trained network is rebuilt from the stored weights and evaluated with
    ``pure_evaluate`` on the loaders that belong to ``dataset`` in order to
    fill in the missing evaluation entries and the latency.

    Args:
        used_seed: random seed the result was produced with.
        dataset: dataset name; the old-format path accepts "cifar10-valid",
            "cifar10", "cifar100" and "ImageNet16-120".
        arch_config: architecture description; the keys ``channel``,
            ``num_cells``, ``arch_str`` and ``class_num`` are read.
        results: raw result dictionary of one training run.
        dataloader_dict: loaders keyed by ``"<dataset>@<split>"``.

    Returns:
        The populated ``ResultsCount`` instance.

    Raises:
        ValueError: on the old-format path, if ``dataset`` is not supported.
    """
    xresult = ResultsCount(
        dataset,
        results["net_state_dict"],
        results["train_acc1es"],
        results["train_losses"],
        results["param"],
        results["flop"],
        arch_config,
        used_seed,
        results["total_epoch"],
        None,
    )
    net_config = dict2config(
        dict(
            name="infer.tiny",
            C=arch_config["channel"],
            N=arch_config["num_cells"],
            genotype=CellStructure.str2structure(arch_config["arch_str"]),
            num_classes=arch_config["class_num"],
        ),
        None,
    )

    if "train_times" in results:  # new version
        xresult.update_train_info(
            results["train_acc1es"],
            results["train_acc5es"],
            results["train_losses"],
            results["train_times"],
        )
        xresult.update_eval(results["valid_acc1es"], results["valid_losses"],
                            results["valid_times"])
        return xresult

    # Old version: re-create the network and measure the missing metrics.
    network = get_cell_based_tiny_net(net_config)
    network.load_state_dict(xresult.get_net_param())
    final_epoch = results["total_epoch"] - 1

    def run_eval(data_name, split):
        # Evaluate the restored network on the loader of the given split.
        loader = dataloader_dict["{:}@{:}".format(data_name, split)]
        return pure_evaluate(loader, network.cuda())

    if dataset == "cifar10-valid":
        xresult.update_OLD_eval("x-valid", results["valid_acc1es"],
                                results["valid_losses"])
        loss, top1, top5, latencies = run_eval("cifar10", "test")
        xresult.update_OLD_eval("ori-test", {final_epoch: top1},
                                {final_epoch: loss})
        xresult.update_latency(latencies)
    elif dataset == "cifar10":
        xresult.update_OLD_eval("ori-test", results["valid_acc1es"],
                                results["valid_losses"])
        # Only the latency of this evaluation is recorded.
        loss, top1, top5, latencies = run_eval(dataset, "test")
        xresult.update_latency(latencies)
    elif dataset in ("cifar100", "ImageNet16-120"):
        xresult.update_OLD_eval("ori-test", results["valid_acc1es"],
                                results["valid_losses"])
        loss, top1, top5, latencies = run_eval(dataset, "valid")
        xresult.update_OLD_eval("x-valid", {final_epoch: top1},
                                {final_epoch: loss})
        loss, top1, top5, latencies = run_eval(dataset, "test")
        xresult.update_OLD_eval("x-test", {final_epoch: top1},
                                {final_epoch: loss})
        xresult.update_latency(latencies)
    else:
        raise ValueError("invalid dataset name : {:}".format(dataset))
    return xresult
def main(xargs, myargs):
    """Run the SETN architecture search configured by ``xargs``.

    The function builds the search/valid loaders and the search model (from
    the command-line options, or from ``xargs.model_config`` when that
    attribute exists and is not ``None``), resumes from the last checkpoint
    when one exists, and runs ``config.epochs + config.warmup`` search epochs.
    After each epoch an architecture is selected with ``get_best_arch``,
    evaluated on the validation loader, and a checkpoint is written.  After
    the loop the selection is repeated once and the final genotype is logged.

    Args:
        xargs: parsed argument namespace.  Attributes read here: ``workers``,
            ``num_worker``, ``rand_seed``, ``dataset``, ``data_path``,
            ``config_path``, ``search_space_name``, ``model_config``
            (optional), ``channel``, ``num_cells``, ``max_nodes``,
            ``track_running_stats``, ``arch_learning_rate``,
            ``arch_weight_decay``, ``arch_nas_dataset``, ``select_num`` and
            ``print_freq``.
        myargs: not used by this function; kept for interface compatibility.

    Raises:
        AssertionError: if CUDA is not available.
    """
    assert torch.cuda.is_available(), 'CUDA is not available.'
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.set_num_threads(xargs.workers)
    prepare_seed(xargs.rand_seed)
    logger = prepare_logger(xargs)

    train_data, valid_data, xshape, class_num = get_datasets(
        xargs.dataset, xargs.data_path, -1)
    config = load_config(xargs.config_path, {
        'class_num': class_num,
        'xshape': xshape
    }, logger)
    # NOTE(review): the loaders use ``xargs.num_worker`` while the thread count
    # above uses ``xargs.workers`` -- confirm both attributes are intended.
    search_loader, _, valid_loader = get_nas_search_loaders(
        train_data, valid_data, xargs.dataset,
        'AutoDL-Projects/configs/nas-benchmark/',
        (config.batch_size, config.test_batch_size), xargs.num_worker)
    logger.log(
        '||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'
        .format(xargs.dataset, len(search_loader), len(valid_loader),
                config.batch_size))
    logger.log('||||||| {:10s} ||||||| Config={:}'.format(
        xargs.dataset, config))

    search_space = get_search_spaces('cell', xargs.search_space_name)
    if not hasattr(xargs, 'model_config') or xargs.model_config is None:
        model_config = dict2config(
            dict(name='SETN',
                 C=xargs.channel,
                 N=xargs.num_cells,
                 max_nodes=xargs.max_nodes,
                 num_classes=class_num,
                 space=search_space,
                 affine=False,
                 track_running_stats=bool(xargs.track_running_stats)), None)
    else:
        model_config = load_config(
            xargs.model_config,
            dict(num_classes=class_num,
                 space=search_space,
                 affine=False,
                 track_running_stats=bool(xargs.track_running_stats)), None)
    logger.log('search space : {:}'.format(search_space))
    search_model = get_cell_based_tiny_net(model_config)

    w_optimizer, w_scheduler, criterion = get_optim_scheduler(
        search_model.get_weights(), config)
    a_optimizer = torch.optim.Adam(search_model.get_alphas(),
                                   lr=xargs.arch_learning_rate,
                                   betas=(0.5, 0.999),
                                   weight_decay=xargs.arch_weight_decay)
    logger.log('w-optimizer : {:}'.format(w_optimizer))
    logger.log('a-optimizer : {:}'.format(a_optimizer))
    logger.log('w-scheduler : {:}'.format(w_scheduler))
    logger.log('criterion   : {:}'.format(criterion))
    flop, param = get_model_infos(search_model, xshape)
    logger.log('FLOP = {:.2f} M, Params = {:.2f} MB'.format(flop, param))
    logger.log('search-space : {:}'.format(search_space))
    if xargs.arch_nas_dataset is None:
        api = None
    else:
        api = API(xargs.arch_nas_dataset)
    logger.log('{:} create API = {:} done'.format(time_string(), api))

    last_info_path = logger.path('info')
    model_base_path = logger.path('model')
    model_best_path = logger.path('best')
    network, criterion = torch.nn.DataParallel(
        search_model).cuda(), criterion.cuda()

    if last_info_path.exists():  # automatically resume from previous checkpoint
        logger.log("=> loading checkpoint of the last-info '{:}' start".format(
            last_info_path))
        last_info = torch.load(last_info_path)
        start_epoch = last_info['epoch']
        checkpoint = torch.load(last_info['last_checkpoint'])
        genotypes = checkpoint['genotypes']
        valid_accuracies = checkpoint['valid_accuracies']
        search_model.load_state_dict(checkpoint['search_model'])
        w_scheduler.load_state_dict(checkpoint['w_scheduler'])
        w_optimizer.load_state_dict(checkpoint['w_optimizer'])
        a_optimizer.load_state_dict(checkpoint['a_optimizer'])
        # Report the path of the info file, not the dictionary loaded from it.
        logger.log(
            "=> loading checkpoint of the last-info '{:}' start with {:}-th epoch."
            .format(last_info_path, start_epoch))
    else:
        logger.log("=> do not find the last-info file : {:}".format(
            last_info_path))
        init_genotype, _ = get_best_arch(valid_loader, network,
                                         xargs.select_num)
        start_epoch = 0
        valid_accuracies = {'best': -1}
        genotypes = {-1: init_genotype}

    # start training
    start_time = time.time()
    search_time, epoch_time = AverageMeter(), AverageMeter()
    total_epoch = config.epochs + config.warmup
    for epoch in range(start_epoch, total_epoch):
        w_scheduler.update(epoch, 0.0)
        need_time = 'Time Left: {:}'.format(
            convert_secs2time(epoch_time.val * (total_epoch - epoch), True))
        epoch_str = '{:03d}-{:03d}'.format(epoch, total_epoch)
        logger.log('\n[Search the {:}-th epoch] {:}, LR={:}'.format(
            epoch_str, need_time, min(w_scheduler.get_lr())))

        search_w_loss, search_w_top1, search_w_top5, search_a_loss, search_a_top1, search_a_top5 \
            = search_func(search_loader, network, criterion, w_scheduler,
                          w_optimizer, a_optimizer, epoch_str,
                          xargs.print_freq, logger)
        search_time.update(time.time() - start_time)
        logger.log(
            '[{:}] search [base] : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s'
            .format(epoch_str, search_w_loss, search_w_top1, search_w_top5,
                    search_time.sum))
        logger.log(
            '[{:}] search [arch] : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'
            .format(epoch_str, search_a_loss, search_a_top1, search_a_top5))

        genotype, temp_accuracy = get_best_arch(valid_loader, network,
                                                xargs.select_num)
        network.module.set_cal_mode('dynamic', genotype)
        valid_a_loss, valid_a_top1, valid_a_top5 = valid_func(
            valid_loader, network, criterion)
        logger.log(
            '[{:}] evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}% | {:}'
            .format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5,
                    genotype))
        #search_model.set_cal_mode('urs')
        #valid_a_loss , valid_a_top1 , valid_a_top5  = valid_func(valid_loader, network, criterion)
        #logger.log('[{:}] URS---evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5))
        #search_model.set_cal_mode('joint')
        #valid_a_loss , valid_a_top1 , valid_a_top5  = valid_func(valid_loader, network, criterion)
        #logger.log('[{:}] JOINT-evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5))
        #search_model.set_cal_mode('select')
        #valid_a_loss , valid_a_top1 , valid_a_top5  = valid_func(valid_loader, network, criterion)
        #logger.log('[{:}] Selec-evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5))
        # check the best accuracy
        valid_accuracies[epoch] = valid_a_top1

        genotypes[epoch] = genotype
        logger.log('<<<--->>> The {:}-th epoch : {:}'.format(
            epoch_str, genotypes[epoch]))
        # save checkpoint
        save_path = save_checkpoint(
            {
                'epoch': epoch + 1,
                'args': deepcopy(xargs),
                'search_model': search_model.state_dict(),
                'w_optimizer': w_optimizer.state_dict(),
                'a_optimizer': a_optimizer.state_dict(),
                'w_scheduler': w_scheduler.state_dict(),
                'genotypes': genotypes,
                'valid_accuracies': valid_accuracies
            }, model_base_path, logger)
        # The info file only records where the full checkpoint lives.
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'args': deepcopy(xargs),
                'last_checkpoint': save_path,
            }, logger.path('info'), logger)
        with torch.no_grad():
            logger.log('{:}'.format(search_model.show_alphas()))
        if api is not None:
            logger.log('{:}'.format(api.query_by_arch(genotypes[epoch],
                                                      '200')))
        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()

    # the final post procedure : count the time
    start_time = time.time()
    genotype, temp_accuracy = get_best_arch(valid_loader, network,
                                            xargs.select_num)
    search_time.update(time.time() - start_time)
    network.module.set_cal_mode('dynamic', genotype)
    valid_a_loss, valid_a_top1, valid_a_top5 = valid_func(
        valid_loader, network, criterion)
    logger.log(
        'Last : the gentotype is : {:}, with the validation accuracy of {:.3f}%.'
        .format(genotype, valid_a_top1))

    logger.log('\n' + '-' * 100)
    # check the performance from the architecture dataset
    logger.log(
        'SETN : run {:} epochs, cost {:.1f} s, last-geno is {:}.'.format(
            total_epoch, search_time.sum, genotype))
    if api is not None:
        logger.log('{:}'.format(api.query_by_arch(genotype, '200')))
    logger.close()
def main(xargs):
    """Run the DARTS-V2 cell search and checkpoint after every epoch.

    The function builds the search/validation loaders, the ``DARTS-V2``
    search network, a weight optimizer/scheduler (from the config file) and
    an Adam optimizer for the architecture parameters.  If a ``last-info``
    checkpoint exists it is loaded and the search resumes from the stored
    epoch; otherwise the search starts from epoch 0.

    Args:
        xargs: parsed command-line namespace.  Attributes read here:
            ``workers``, ``rand_seed``, ``dataset``, ``data_path``,
            ``config_path``, ``search_space_name``, ``channel``,
            ``num_cells``, ``max_nodes``, ``track_running_stats``,
            ``arch_learning_rate``, ``arch_weight_decay``,
            ``arch_nas_dataset`` and ``print_freq``.

    Raises:
        AssertionError: if CUDA is not available.
    """
    assert torch.cuda.is_available(), "CUDA is not available."
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.set_num_threads(xargs.workers)
    prepare_seed(xargs.rand_seed)
    # Use the function argument rather than the module-level ``args`` so the
    # function also works when it is not called from this script's __main__.
    logger = prepare_logger(xargs)

    train_data, valid_data, xshape, class_num = get_datasets(
        xargs.dataset, xargs.data_path, -1
    )
    config = load_config(
        xargs.config_path, {"class_num": class_num, "xshape": xshape}, logger
    )
    search_loader, _, valid_loader = get_nas_search_loaders(
        train_data,
        valid_data,
        xargs.dataset,
        "configs/nas-benchmark/",
        config.batch_size,
        xargs.workers,
    )
    logger.log(
        "||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}".format(
            xargs.dataset, len(search_loader), len(valid_loader), config.batch_size
        )
    )
    logger.log("||||||| {:10s} ||||||| Config={:}".format(xargs.dataset, config))

    search_space = get_search_spaces("cell", xargs.search_space_name)
    model_config = dict2config(
        {
            "name": "DARTS-V2",
            "C": xargs.channel,
            "N": xargs.num_cells,
            "max_nodes": xargs.max_nodes,
            "num_classes": class_num,
            "space": search_space,
            "affine": False,
            "track_running_stats": bool(xargs.track_running_stats),
        },
        None,
    )
    search_model = get_cell_based_tiny_net(model_config)
    logger.log("search-model :\n{:}".format(search_model))

    # Two optimizers: one for the network weights, one for the alphas.
    w_optimizer, w_scheduler, criterion = get_optim_scheduler(
        search_model.get_weights(), config
    )
    a_optimizer = torch.optim.Adam(
        search_model.get_alphas(),
        lr=xargs.arch_learning_rate,
        betas=(0.5, 0.999),
        weight_decay=xargs.arch_weight_decay,
    )
    logger.log("w-optimizer : {:}".format(w_optimizer))
    logger.log("a-optimizer : {:}".format(a_optimizer))
    logger.log("w-scheduler : {:}".format(w_scheduler))
    logger.log("criterion : {:}".format(criterion))
    flop, param = get_model_infos(search_model, xshape)
    logger.log("FLOP = {:.2f} M, Params = {:.2f} MB".format(flop, param))

    if xargs.arch_nas_dataset is None:
        api = None
    else:
        api = API(xargs.arch_nas_dataset)
    logger.log("{:} create API = {:} done".format(time_string(), api))

    last_info = logger.path("info")
    model_base_path = logger.path("model")
    model_best_path = logger.path("best")
    network, criterion = torch.nn.DataParallel(search_model).cuda(), criterion.cuda()

    if last_info.exists():  # automatically resume from previous checkpoint
        logger.log(
            "=> loading checkpoint of the last-info '{:}' start".format(last_info)
        )
        last_info = torch.load(last_info)
        start_epoch = last_info["epoch"]
        checkpoint = torch.load(last_info["last_checkpoint"])
        genotypes = checkpoint["genotypes"]
        valid_accuracies = checkpoint["valid_accuracies"]
        search_model.load_state_dict(checkpoint["search_model"])
        w_scheduler.load_state_dict(checkpoint["w_scheduler"])
        w_optimizer.load_state_dict(checkpoint["w_optimizer"])
        a_optimizer.load_state_dict(checkpoint["a_optimizer"])
        logger.log(
            "=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format(
                last_info, start_epoch
            )
        )
    else:
        logger.log("=> do not find the last-info file : {:}".format(last_info))
        # The -1 entry keeps ``genotypes[total_epoch - 1]`` valid at the end
        # even when the loop below runs for zero epochs.
        start_epoch = 0
        valid_accuracies = {"best": -1}
        genotypes = {-1: search_model.genotype()}

    # start training
    start_time = time.time()
    search_time, epoch_time = AverageMeter(), AverageMeter()
    total_epoch = config.epochs + config.warmup
    for epoch in range(start_epoch, total_epoch):
        w_scheduler.update(epoch, 0.0)
        need_time = "Time Left: {:}".format(
            convert_secs2time(epoch_time.val * (total_epoch - epoch), True)
        )
        epoch_str = "{:03d}-{:03d}".format(epoch, total_epoch)
        min_LR = min(w_scheduler.get_lr())
        logger.log(
            "\n[Search the {:}-th epoch] {:}, LR={:}".format(
                epoch_str, need_time, min_LR
            )
        )

        search_w_loss, search_w_top1, search_w_top5 = search_func(
            search_loader,
            network,
            criterion,
            w_scheduler,
            w_optimizer,
            a_optimizer,
            epoch_str,
            xargs.print_freq,
            logger,
        )
        # Only the search step is counted as search cost, not the evaluation.
        search_time.update(time.time() - start_time)
        logger.log(
            "[{:}] searching : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s".format(
                epoch_str, search_w_loss, search_w_top1, search_w_top5, search_time.sum
            )
        )
        valid_a_loss, valid_a_top1, valid_a_top5 = valid_func(
            valid_loader, network, criterion
        )
        logger.log(
            "[{:}] evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%".format(
                epoch_str, valid_a_loss, valid_a_top1, valid_a_top5
            )
        )

        # check the best accuracy
        valid_accuracies[epoch] = valid_a_top1
        find_best = valid_a_top1 > valid_accuracies["best"]
        if find_best:
            valid_accuracies["best"] = valid_a_top1
            genotypes["best"] = search_model.genotype()

        genotypes[epoch] = search_model.genotype()
        logger.log(
            "<<<--->>> The {:}-th epoch : {:}".format(epoch_str, genotypes[epoch])
        )

        # save checkpoint
        save_path = save_checkpoint(
            {
                "epoch": epoch + 1,
                "args": deepcopy(xargs),
                "search_model": search_model.state_dict(),
                "w_optimizer": w_optimizer.state_dict(),
                "a_optimizer": a_optimizer.state_dict(),
                "w_scheduler": w_scheduler.state_dict(),
                "genotypes": genotypes,
                "valid_accuracies": valid_accuracies,
            },
            model_base_path,
            logger,
        )
        last_info = save_checkpoint(
            {
                "epoch": epoch + 1,
                "args": deepcopy(xargs),
                "last_checkpoint": save_path,
            },
            logger.path("info"),
            logger,
        )
        if find_best:
            logger.log(
                "<<<--->>> The {:}-th epoch : find the highest validation accuracy : {:.2f}%.".format(
                    epoch_str, valid_a_top1
                )
            )
            copy_checkpoint(model_base_path, model_best_path, logger)
        with torch.no_grad():
            logger.log(
                "arch-parameters :\n{:}".format(
                    nn.functional.softmax(search_model.arch_parameters, dim=-1).cpu()
                )
            )
        if api is not None:
            logger.log("{:}".format(api.query_by_arch(genotypes[epoch], "200")))

        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()

    logger.log("\n" + "-" * 100)
    # check the performance from the architecture dataset
    last_genotype = genotypes[total_epoch - 1]
    logger.log(
        "DARTS-V2 : run {:} epochs, cost {:.1f} s, last-geno is {:}.".format(
            total_epoch, search_time.sum, last_genotype
        )
    )
    if api is not None:
        # "200" is an argument of query_by_arch (as in the per-epoch call
        # above); it was previously passed to str.format by mistake.
        logger.log("{:}".format(api.query_by_arch(last_genotype, "200")))
    logger.close()
def main(xargs):
    """Run the DARTS-V1 cell search and checkpoint after every epoch.

    The training set is divided with the index lists stored in the
    ``configs/nas-benchmark`` split files: ``SearchDataset`` receives both
    index lists for the search loader, and the validation loader samples the
    ``valid`` indices from a copy of the training data that uses the
    evaluation transform.  If a ``last-info`` checkpoint exists the search
    resumes from the stored epoch.

    Args:
        xargs: parsed command-line namespace.  Attributes read here:
            ``workers``, ``rand_seed``, ``dataset``, ``data_path``,
            ``search_space_name``, ``channel``, ``num_cells``,
            ``max_nodes``, ``arch_learning_rate``, ``arch_weight_decay``
            and ``print_freq``.

    Raises:
        AssertionError: if CUDA is not available.
        ValueError: if ``xargs.dataset`` is neither cifar10/cifar100 nor an
            ``ImageNet16*`` dataset.
    """
    assert torch.cuda.is_available(), 'CUDA is not available.'
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.set_num_threads(xargs.workers)
    prepare_seed(xargs.rand_seed)
    # Use the function argument rather than the module-level ``args`` so the
    # function also works when it is not called from this script's __main__.
    logger = prepare_logger(xargs)

    train_data, valid_data, xshape, class_num = get_datasets(
        xargs.dataset, xargs.data_path, -1)

    # Pick the split file first, then load it once for both dataset families.
    if xargs.dataset == 'cifar10' or xargs.dataset == 'cifar100':
        split_Fpath = 'configs/nas-benchmark/cifar-split.txt'
    elif xargs.dataset.startswith('ImageNet16'):
        split_Fpath = 'configs/nas-benchmark/{:}-split.txt'.format(
            xargs.dataset)
    else:
        raise ValueError('invalid dataset : {:}'.format(xargs.dataset))
    split_info = load_config(split_Fpath, None, None)
    train_split, valid_split = split_info.train, split_info.valid
    logger.log('Load split file from {:}'.format(split_Fpath))

    config_path = 'configs/nas-benchmark/algos/DARTS.config'
    config = load_config(config_path, {
        'class_num': class_num,
        'xshape': xshape
    }, logger)

    # To split data: validation reads training images with the transform
    # taken from the original validation dataset.
    train_data_v2 = deepcopy(train_data)
    train_data_v2.transform = valid_data.transform
    valid_data = train_data_v2
    search_data = SearchDataset(xargs.dataset, train_data, train_split,
                                valid_split)
    # data loader
    search_loader = torch.utils.data.DataLoader(
        search_data,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=xargs.workers,
        pin_memory=True)
    valid_loader = torch.utils.data.DataLoader(
        valid_data,
        batch_size=config.batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split),
        num_workers=xargs.workers,
        pin_memory=True)
    logger.log(
        '||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'
        .format(xargs.dataset, len(search_loader), len(valid_loader),
                config.batch_size))
    logger.log('||||||| {:10s} ||||||| Config={:}'.format(
        xargs.dataset, config))

    search_space = get_search_spaces('cell', xargs.search_space_name)
    model_config = dict2config(
        {
            'name': 'DARTS-V1',
            'C': xargs.channel,
            'N': xargs.num_cells,
            'max_nodes': xargs.max_nodes,
            'num_classes': class_num,
            'space': search_space
        }, None)
    search_model = get_cell_based_tiny_net(model_config)

    # Two optimizers: one for the network weights, one for the alphas.
    w_optimizer, w_scheduler, criterion = get_optim_scheduler(
        search_model.get_weights(), config)
    a_optimizer = torch.optim.Adam(search_model.get_alphas(),
                                   lr=xargs.arch_learning_rate,
                                   betas=(0.5, 0.999),
                                   weight_decay=xargs.arch_weight_decay)
    logger.log('w-optimizer : {:}'.format(w_optimizer))
    logger.log('a-optimizer : {:}'.format(a_optimizer))
    logger.log('w-scheduler : {:}'.format(w_scheduler))
    logger.log('criterion : {:}'.format(criterion))
    flop, param = get_model_infos(search_model, xshape)
    logger.log('FLOP = {:.2f} M, Params = {:.2f} MB'.format(flop, param))

    last_info = logger.path('info')
    model_base_path = logger.path('model')
    model_best_path = logger.path('best')
    network, criterion = torch.nn.DataParallel(
        search_model).cuda(), criterion.cuda()

    if last_info.exists():  # automatically resume from previous checkpoint
        logger.log("=> loading checkpoint of the last-info '{:}' start".format(
            last_info))
        last_info = torch.load(last_info)
        start_epoch = last_info['epoch']
        checkpoint = torch.load(last_info['last_checkpoint'])
        genotypes = checkpoint['genotypes']
        valid_accuracies = checkpoint['valid_accuracies']
        search_model.load_state_dict(checkpoint['search_model'])
        w_scheduler.load_state_dict(checkpoint['w_scheduler'])
        w_optimizer.load_state_dict(checkpoint['w_optimizer'])
        a_optimizer.load_state_dict(checkpoint['a_optimizer'])
        logger.log(
            "=> loading checkpoint of the last-info '{:}' start with {:}-th epoch."
            .format(last_info, start_epoch))
    else:
        logger.log("=> do not find the last-info file : {:}".format(last_info))
        start_epoch, valid_accuracies, genotypes = 0, {'best': -1}, {}

    # start training
    start_time, epoch_time = time.time(), AverageMeter()
    total_epoch = config.epochs + config.warmup
    for epoch in range(start_epoch, total_epoch):
        w_scheduler.update(epoch, 0.0)
        need_time = 'Time Left: {:}'.format(
            convert_secs2time(epoch_time.val * (total_epoch - epoch), True))
        epoch_str = '{:03d}-{:03d}'.format(epoch, total_epoch)
        logger.log('\n[Search the {:}-th epoch] {:}, LR={:}'.format(
            epoch_str, need_time, min(w_scheduler.get_lr())))

        search_w_loss, search_w_top1, search_w_top5 = search_func(
            search_loader, network, criterion, w_scheduler, w_optimizer,
            a_optimizer, epoch_str, xargs.print_freq, logger)
        logger.log(
            '[{:}] searching : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'
            .format(epoch_str, search_w_loss, search_w_top1, search_w_top5))
        valid_a_loss, valid_a_top1, valid_a_top5 = valid_func(
            valid_loader, network, criterion)
        logger.log(
            '[{:}] evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'
            .format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5))

        # check the best accuracy
        valid_accuracies[epoch] = valid_a_top1
        find_best = valid_a_top1 > valid_accuracies['best']
        if find_best:
            valid_accuracies['best'] = valid_a_top1
            genotypes['best'] = search_model.genotype()

        genotypes[epoch] = search_model.genotype()
        logger.log('<<<--->>> The {:}-th epoch : {:}'.format(
            epoch_str, genotypes[epoch]))

        # save checkpoint
        save_path = save_checkpoint(
            {
                'epoch': epoch + 1,
                'args': deepcopy(xargs),
                'search_model': search_model.state_dict(),
                'w_optimizer': w_optimizer.state_dict(),
                'a_optimizer': a_optimizer.state_dict(),
                'w_scheduler': w_scheduler.state_dict(),
                'genotypes': genotypes,
                'valid_accuracies': valid_accuracies
            }, model_base_path, logger)
        last_info = save_checkpoint(
            {
                'epoch': epoch + 1,
                'args': deepcopy(xargs),
                'last_checkpoint': save_path,
            }, logger.path('info'), logger)
        if find_best:
            logger.log(
                '<<<--->>> The {:}-th epoch : find the highest validation accuracy : {:.2f}%.'
                .format(epoch_str, valid_a_top1))
            copy_checkpoint(model_base_path, model_best_path, logger)
        with torch.no_grad():
            logger.log('arch-parameters :\n{:}'.format(
                nn.functional.softmax(search_model.arch_parameters,
                                      dim=-1).cpu()))

        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()

    logger.log('\n' + '-' * 100)
    # NOTE: this script does not query the architecture dataset for the
    # last/best genotype; both remain available in the saved checkpoints
    # under the 'genotypes' key.
    logger.close()
def main(xargs):
    """Train (or load) a weight-sharing supernet, then run its searcher.

    Phase 1 obtains supernet weights: they are loaded from
    ``xargs.supernet_path`` when that path is given, otherwise the supernet
    is trained with ``train_shared_cnn`` and checkpointed.  Phase 2 asks the
    search model for a searcher (``search_model.getSearcher``), runs
    ``searcher.search()`` and stores its three results in the "best"
    checkpoint.

    Args:
        xargs: parsed command-line namespace.  Attributes read here:
            ``workers``, ``rand_seed``, ``dataset``, ``data_path``,
            ``config_path``, ``search_space_name``, ``model_config``,
            ``channel``, ``num_cells``, ``max_nodes``,
            ``track_running_stats``, ``arch_nas_dataset``,
            ``supernet_path`` and ``print_freq``.

    Raises:
        AssertionError: if CUDA is not available, or if the file at
            ``xargs.supernet_path`` was not saved with ``epoch ==
            'finished'``.
    """
    assert torch.cuda.is_available(), 'CUDA is not available.'
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.set_num_threads(xargs.workers)
    prepare_seed(xargs.rand_seed)
    # Use the function argument rather than the module-level ``args`` so the
    # function also works when it is not called from this script's __main__.
    logger = prepare_logger(xargs)

    train_data, valid_data, xshape, class_num = get_datasets(
        xargs.dataset, xargs.data_path, -1)
    config = load_config(xargs.config_path, {
        'class_num': class_num,
        'xshape': xshape
    }, logger)
    search_loader, train_loader, valid_loader = get_nas_search_loaders(
        train_data, valid_data, xargs.dataset, 'configs/nas-benchmark/',
        config.batch_size, xargs.workers)
    logger.log(
        '||||||| {:10s} ||||||| Search-Loader-Num={:}, batch size={:}'.format(
            xargs.dataset, len(search_loader), config.batch_size))
    logger.log('||||||| {:10s} ||||||| Config={:}'.format(
        xargs.dataset, config))

    search_space = get_search_spaces('cell', xargs.search_space_name)
    if xargs.model_config is None:
        model_config = dict2config(
            {
                'name': 'GDAS',
                'C': xargs.channel,
                'N': xargs.num_cells,
                'max_nodes': xargs.max_nodes,
                'num_classes': class_num,
                'space': search_space,
                'affine': False,
                'track_running_stats': bool(xargs.track_running_stats)
            }, None)
    else:
        model_config = load_config(
            xargs.model_config, {
                'num_classes': class_num,
                'space': search_space,
                'affine': False,
                'track_running_stats': bool(xargs.track_running_stats)
            }, None)
    search_model = get_cell_based_tiny_net(model_config)
    logger.log('search-model :\n{:}'.format(search_model))
    logger.log('model-config : {:}'.format(model_config))

    w_optimizer, w_scheduler, criterion = get_optim_scheduler(
        search_model.get_weights(), config)
    logger.log('w-optimizer : {:}'.format(w_optimizer))
    logger.log('w-scheduler : {:}'.format(w_scheduler))
    logger.log('criterion : {:}'.format(criterion))
    flop, param = get_model_infos(search_model, xshape)
    logger.log('FLOP = {:.2f} M, Params = {:.2f} MB'.format(flop, param))
    logger.log('search-space [{:} ops] : {:}'.format(len(search_space),
                                                     search_space))
    if xargs.arch_nas_dataset is None:
        api = None
    else:
        api = API(xargs.arch_nas_dataset)
    logger.log('{:} create API = {:} done'.format(time_string(), api))

    last_info = logger.path('info')
    model_base_path = logger.path('model')
    model_best_path = logger.path('best')
    network, criterion = torch.nn.DataParallel(
        search_model).cuda(), criterion.cuda()

    # Resuming from ``last_info`` is intentionally disabled in this script
    # (the resume branch was guarded by ``if False``), so training always
    # starts at epoch 0 and the "not found" message is always logged.
    logger.log("=> do not find the last-info file : {:}".format(last_info))
    start_epoch = 0

    # Truthiness check so that an unset (None) path behaves like an empty one.
    if xargs.supernet_path:
        saved_info = torch.load(xargs.supernet_path)
        assert saved_info[
            'epoch'] == 'finished', "Epoch is not finished in this file"
        search_model.load_state_dict(saved_info['search_model'])
    else:
        # start training supernet
        start_time = time.time()
        train_shared_cnn(train_loader, network, criterion, w_scheduler,
                         w_optimizer, xargs.print_freq, logger, config,
                         start_epoch)
        logger.log(
            'Supernet trained. Time-cost = {:.1f} s'.format(time.time() -
                                                            start_time))
        # save supernetweight
        # NOTE(review): the checkpoint is written only after training here;
        # confirm that a supernet loaded from disk should not be re-saved.
        save_path = save_checkpoint(
            {
                'epoch': 'finished',  # marker checked when loading supernet_path
                'args': deepcopy(xargs),
                'search_model': search_model.state_dict(),
                'w_optimizer': w_optimizer.state_dict(),
                'w_scheduler': w_scheduler.state_dict()
            }, model_base_path, logger)
        save_checkpoint(
            {
                'epoch': 'finished',
                'args': deepcopy(xargs),
                'last_checkpoint': save_path,
            }, logger.path('info'), logger)

    search_start_time = time.time()
    searcher = search_model.getSearcher(network, train_loader, valid_loader,
                                        logger, config)
    best_cands, performance_dict, performance_trace = searcher.search()
    logger.log(
        'Architect Searched. Time-cost = {:.1f} s'.format(time.time() -
                                                          search_start_time))
    save_checkpoint(
        {
            'epoch': 'finished',
            'args': deepcopy(xargs),
            'genotypes': best_cands,
            'performance_dict': performance_dict,
            'performance_trace': performance_trace
        }, model_best_path, logger)
    logger.close()
def evaluate_all_datasets(
    arch: Text,
    datasets: List[Text],
    xpaths: List[Text],
    splits: List[Text],
    config_path: Text,
    seed: int,
    raw_arch_config,
    workers,
    logger,
):
    """Train and evaluate one architecture on every requested dataset.

    For each ``(dataset, xpath, split)`` triple the function builds a train
    loader plus a dict of evaluation loaders, then hands them to
    ``bench_evaluate_for_seed``.

    Args:
        arch: genotype passed to the ``infer.tiny`` model config.
        datasets: dataset names (cifar10, cifar100 or ImageNet16-120 are the
            ones accepted by the branches below).
        xpaths: dataset root paths, aligned with ``datasets``.
        splits: per-dataset flags; a truthy entry selects the train/valid
            split of the training set (only allowed for cifar10).
        config_path: path of the training configuration file.
        seed: seed forwarded to ``bench_evaluate_for_seed``.
        raw_arch_config: mapping providing ``"channel"`` and ``"num_cells"``.
        workers: number of DataLoader workers.
        logger: object exposing ``log``.

    Returns:
        dict with the machine info under ``"info"``, one entry per dataset
        key, and the ordered list of those keys under
        ``"all_dataset_keys"``.

    Raises:
        ValueError: for an unsupported dataset name.
    """

    def _sequential_loader(data, batch_size, shuffle):
        # Loader over the whole dataset, optionally shuffled.
        return torch.utils.data.DataLoader(
            data,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=workers,
            pin_memory=True,
        )

    def _subset_loader(data, batch_size, indices):
        # Loader that randomly samples only the given indices.
        return torch.utils.data.DataLoader(
            data,
            batch_size=batch_size,
            sampler=torch.utils.data.sampler.SubsetRandomSampler(indices),
            num_workers=workers,
            pin_memory=True,
        )

    machine_info = get_machine_info()
    raw_arch_config = deepcopy(raw_arch_config)
    summary = {"info": machine_info}
    evaluated_keys = []

    # look all the datasets
    for dataset, xpath, split in zip(datasets, xpaths, splits):
        # train valid data
        train_data, valid_data, xshape, class_num = get_datasets(dataset, xpath, -1)

        # load the configuration
        if dataset in ("cifar10", "cifar100"):
            split_file = "configs/nas-benchmark/cifar-split.txt"
        elif dataset.startswith("ImageNet16"):
            split_file = "configs/nas-benchmark/{:}-split.txt".format(dataset)
        else:
            raise ValueError("invalid dataset : {:}".format(dataset))
        split_info = load_config(split_file, None, None)
        config = load_config(
            config_path, dict(class_num=class_num, xshape=xshape), logger
        )
        batch_size = config.batch_size
        use_valid_split = bool(split)

        # check whether use splited validation set
        if use_valid_split:
            assert dataset == "cifar10"
            # Built before ``valid_data`` is rebound, so it covers the
            # original test set.
            eval_loaders = {
                "ori-test": _sequential_loader(valid_data, batch_size, False)
            }
            assert len(train_data) == len(split_info.train) + len(
                split_info.valid
            ), "invalid length : {:} vs {:} + {:}".format(
                len(train_data), len(split_info.train), len(split_info.valid)
            )
            # Validation reads training images with the evaluation transform.
            train_data_v2 = deepcopy(train_data)
            train_data_v2.transform = valid_data.transform
            valid_data = train_data_v2
            # data loader
            train_loader = _subset_loader(train_data, batch_size, split_info.train)
            valid_loader = _subset_loader(valid_data, batch_size, split_info.valid)
            eval_loaders["x-valid"] = valid_loader
        else:
            # data loader
            train_loader = _sequential_loader(train_data, batch_size, True)
            valid_loader = _sequential_loader(valid_data, batch_size, False)
            if dataset == "cifar10":
                eval_loaders = {"ori-test": valid_loader}
            elif dataset in ("cifar100", "ImageNet16-120"):
                if dataset == "cifar100":
                    test_split_file = "configs/nas-benchmark/cifar100-test-split.txt"
                else:
                    test_split_file = (
                        "configs/nas-benchmark/imagenet-16-120-test-split.txt"
                    )
                test_splits = load_config(test_split_file, None, None)
                eval_loaders = {
                    "ori-test": valid_loader,
                    "x-valid": _subset_loader(
                        valid_data, batch_size, test_splits.xvalid
                    ),
                    "x-test": _subset_loader(
                        valid_data, batch_size, test_splits.xtest
                    ),
                }
            else:
                raise ValueError("invalid dataset : {:}".format(dataset))

        dataset_key = "{:}".format(dataset)
        if use_valid_split:
            dataset_key = dataset_key + "-valid"
        logger.log(
            "Evaluate ||||||| {:10s} ||||||| Train-Num={:}, Valid-Num={:}, Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}".format(
                dataset_key,
                len(train_data),
                len(valid_data),
                len(train_loader),
                len(valid_loader),
                config.batch_size,
            )
        )
        logger.log(
            "Evaluate ||||||| {:10s} ||||||| Config={:}".format(dataset_key, config)
        )
        for loader_name, loader in eval_loaders.items():
            logger.log(
                "Evaluate ---->>>> {:10s} with {:} batchs".format(
                    loader_name, len(loader)
                )
            )

        arch_config = dict2config(
            dict(
                name="infer.tiny",
                C=raw_arch_config["channel"],
                N=raw_arch_config["num_cells"],
                genotype=arch,
                num_classes=config.class_num,
            ),
            None,
        )
        summary[dataset_key] = bench_evaluate_for_seed(
            arch_config, config, train_loader, eval_loaders, seed, logger
        )
        evaluated_keys.append(dataset_key)

    summary["all_dataset_keys"] = evaluated_keys
    return summary
def main(xargs):
    """Run the ENAS search: shared-CNN training plus controller training.

    Each epoch trains the shared CNN on the train loader
    (``train_shared_cnn``), trains the controller on the valid loader
    (``train_controller``), picks the current best architecture with
    ``get_best_arch`` and validates it.  State is checkpointed every epoch
    and the search resumes from the ``last-info`` checkpoint when present.
    After the loop, one more ``get_best_arch`` call with
    ``xargs.controller_num_samples`` selects the final architecture.

    Args:
        xargs: parsed command-line namespace.  Attributes read here:
            ``workers``, ``rand_seed``, ``dataset``, ``data_path``,
            ``config_path``, ``search_space_name``, ``model_config``,
            ``channel``, ``num_cells``, ``max_nodes``,
            ``track_running_stats``, ``arch_nas_dataset``, ``print_freq``,
            ``controller_train_steps``, ``controller_num_aggregate``,
            ``controller_entropy_weight``, ``controller_bl_dec`` and
            ``controller_num_samples``.

    Raises:
        AssertionError: if CUDA is not available.
    """
    assert torch.cuda.is_available(), 'CUDA is not available.'
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.set_num_threads(xargs.workers)
    prepare_seed(xargs.rand_seed)
    # Use the function argument rather than the module-level ``args`` so the
    # function also works when it is not called from this script's __main__.
    logger = prepare_logger(xargs)

    train_data, test_data, xshape, class_num = get_datasets(
        xargs.dataset, xargs.data_path, -1)
    logger.log('use config from : {:}'.format(xargs.config_path))
    config = load_config(xargs.config_path, {
        'class_num': class_num,
        'xshape': xshape
    }, logger)
    _, train_loader, valid_loader = get_nas_search_loaders(
        train_data, test_data, xargs.dataset, 'configs/nas-benchmark/',
        config.batch_size, xargs.workers)
    # since ENAS will train the controller on valid-loader, we need to use
    # train transformation for valid-loader
    valid_loader.dataset.transform = deepcopy(train_loader.dataset.transform)
    if hasattr(valid_loader.dataset, 'transforms'):
        valid_loader.dataset.transforms = deepcopy(
            train_loader.dataset.transforms)
    # data loader
    logger.log(
        '||||||| {:10s} ||||||| Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'
        .format(xargs.dataset, len(train_loader), len(valid_loader),
                config.batch_size))
    logger.log('||||||| {:10s} ||||||| Config={:}'.format(
        xargs.dataset, config))

    search_space = get_search_spaces('cell', xargs.search_space_name)
    if xargs.model_config is None:
        model_config = dict2config(
            {
                'name': 'ENAS',
                'C': xargs.channel,
                'N': xargs.num_cells,
                'max_nodes': xargs.max_nodes,
                'num_classes': class_num,
                'space': search_space,
                'affine': False,
                'track_running_stats': bool(xargs.track_running_stats)
            }, None)
    else:
        model_config = load_config(
            xargs.model_config, {
                'num_classes': class_num,
                'space': search_space,
                'affine': False,
                'track_running_stats': bool(xargs.track_running_stats)
            }, None)
    shared_cnn = get_cell_based_tiny_net(model_config)
    controller = shared_cnn.create_controller()

    # Weight optimizer for the shared CNN, Adam for the controller.
    w_optimizer, w_scheduler, criterion = get_optim_scheduler(
        shared_cnn.parameters(), config)
    a_optimizer = torch.optim.Adam(controller.parameters(),
                                   lr=config.controller_lr,
                                   betas=config.controller_betas,
                                   eps=config.controller_eps)
    logger.log('w-optimizer : {:}'.format(w_optimizer))
    logger.log('a-optimizer : {:}'.format(a_optimizer))
    logger.log('w-scheduler : {:}'.format(w_scheduler))
    logger.log('criterion : {:}'.format(criterion))
    logger.log('search-space : {:}'.format(search_space))
    if xargs.arch_nas_dataset is None:
        api = None
    else:
        api = API(xargs.arch_nas_dataset)
    logger.log('{:} create API = {:} done'.format(time_string(), api))

    # From here on ``shared_cnn`` is the DataParallel wrapper; its state
    # dict is saved and restored in that wrapped form.
    shared_cnn, controller, criterion = torch.nn.DataParallel(
        shared_cnn).cuda(), controller.cuda(), criterion.cuda()

    last_info = logger.path('info')
    model_base_path = logger.path('model')
    model_best_path = logger.path('best')

    if last_info.exists():  # automatically resume from previous checkpoint
        logger.log("=> loading checkpoint of the last-info '{:}' start".format(
            last_info))
        last_info = torch.load(last_info)
        start_epoch = last_info['epoch']
        checkpoint = torch.load(last_info['last_checkpoint'])
        genotypes = checkpoint['genotypes']
        baseline = checkpoint['baseline']
        valid_accuracies = checkpoint['valid_accuracies']
        shared_cnn.load_state_dict(checkpoint['shared_cnn'])
        controller.load_state_dict(checkpoint['controller'])
        w_scheduler.load_state_dict(checkpoint['w_scheduler'])
        w_optimizer.load_state_dict(checkpoint['w_optimizer'])
        a_optimizer.load_state_dict(checkpoint['a_optimizer'])
        logger.log(
            "=> loading checkpoint of the last-info '{:}' start with {:}-th epoch."
            .format(last_info, start_epoch))
    else:
        logger.log("=> do not find the last-info file : {:}".format(last_info))
        start_epoch, valid_accuracies, genotypes, baseline = 0, {
            'best': -1
        }, {}, None

    # start training
    start_time = time.time()
    search_time, epoch_time = AverageMeter(), AverageMeter()
    total_epoch = config.epochs + config.warmup
    for epoch in range(start_epoch, total_epoch):
        w_scheduler.update(epoch, 0.0)
        need_time = 'Time Left: {:}'.format(
            convert_secs2time(epoch_time.val * (total_epoch - epoch), True))
        epoch_str = '{:03d}-{:03d}'.format(epoch, total_epoch)
        logger.log(
            '\n[Search the {:}-th epoch] {:}, LR={:}, baseline={:}'.format(
                epoch_str, need_time, min(w_scheduler.get_lr()), baseline))

        cnn_loss, cnn_top1, cnn_top5 = train_shared_cnn(
            train_loader, shared_cnn, controller, criterion, w_scheduler,
            w_optimizer, epoch_str, xargs.print_freq, logger)
        logger.log(
            '[{:}] shared-cnn : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'
            .format(epoch_str, cnn_loss, cnn_top1, cnn_top5))

        controller_config = dict2config(
            {
                'baseline': baseline,
                'ctl_train_steps': xargs.controller_train_steps,
                'ctl_num_aggre': xargs.controller_num_aggregate,
                'ctl_entropy_w': xargs.controller_entropy_weight,
                'ctl_bl_dec': xargs.controller_bl_dec
            }, None)
        # ``baseline`` is carried across epochs (and through checkpoints).
        ctl_loss, ctl_acc, ctl_baseline, ctl_reward, baseline = train_controller(
            valid_loader, shared_cnn, controller, criterion, a_optimizer,
            controller_config, epoch_str, xargs.print_freq, logger)
        # Search cost covers shared-CNN and controller training only.
        search_time.update(time.time() - start_time)
        logger.log(
            '[{:}] controller : loss={:.2f}, accuracy={:.2f}%, baseline={:.2f}, reward={:.2f}, current-baseline={:.4f}, time-cost={:.1f} s'
            .format(epoch_str, ctl_loss, ctl_acc, ctl_baseline, ctl_reward,
                    baseline, search_time.sum))

        best_arch, _ = get_best_arch(controller, shared_cnn, valid_loader)
        shared_cnn.module.update_arch(best_arch)
        _, best_valid_acc, _ = valid_func(valid_loader, shared_cnn, criterion)

        genotypes[epoch] = best_arch
        # check the best accuracy
        valid_accuracies[epoch] = best_valid_acc
        find_best = best_valid_acc > valid_accuracies['best']
        if find_best:
            valid_accuracies['best'] = best_valid_acc
            genotypes['best'] = best_arch

        logger.log('<<<--->>> The {:}-th epoch : {:}'.format(
            epoch_str, genotypes[epoch]))

        # save checkpoint
        save_path = save_checkpoint(
            {
                'epoch': epoch + 1,
                'args': deepcopy(xargs),
                'baseline': baseline,
                'shared_cnn': shared_cnn.state_dict(),
                'controller': controller.state_dict(),
                'w_optimizer': w_optimizer.state_dict(),
                'a_optimizer': a_optimizer.state_dict(),
                'w_scheduler': w_scheduler.state_dict(),
                'genotypes': genotypes,
                'valid_accuracies': valid_accuracies
            }, model_base_path, logger)
        last_info = save_checkpoint(
            {
                'epoch': epoch + 1,
                'args': deepcopy(xargs),
                'last_checkpoint': save_path,
            }, logger.path('info'), logger)
        if find_best:
            logger.log(
                '<<<--->>> The {:}-th epoch : find the highest validation accuracy : {:.2f}%.'
                .format(epoch_str, best_valid_acc))
            copy_checkpoint(model_base_path, model_best_path, logger)
        if api is not None:
            logger.log('{:}'.format(api.query_by_arch(genotypes[epoch])))

        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()

    logger.log('\n' + '-' * 100)
    logger.log('During searching, the best architecture is {:}'.format(
        genotypes['best']))
    logger.log('Its accuracy is {:.2f}%'.format(valid_accuracies['best']))
    logger.log('Randomly select {:} architectures and select the best.'.format(
        xargs.controller_num_samples))

    # The final selection is also counted as search time.
    start_time = time.time()
    final_arch, _ = get_best_arch(controller, shared_cnn, valid_loader,
                                  xargs.controller_num_samples)
    search_time.update(time.time() - start_time)
    shared_cnn.module.update_arch(final_arch)
    final_loss, final_top1, final_top5 = valid_func(valid_loader, shared_cnn,
                                                    criterion)
    logger.log('The Selected Final Architecture : {:}'.format(final_arch))
    logger.log('Loss={:.3f}, Accuracy@1={:.2f}%, Accuracy@5={:.2f}%'.format(
        final_loss, final_top1, final_top5))
    logger.log(
        'ENAS : run {:} epochs, cost {:.1f} s, last-geno is {:}.'.format(
            total_epoch, search_time.sum, final_arch))
    if api is not None:
        logger.log('{:}'.format(api.query_by_arch(final_arch)))
    logger.close()
def main(xargs):
    """Run the SETN architecture search configured by ``xargs``.

    The function builds the search and validation loaders, the ``SETN``
    search network with its weight and architecture optimizers, and resumes
    from the logger's ``info`` checkpoint when that file exists. Each epoch
    calls ``search_func``, picks an architecture with ``get_best_arch``,
    evaluates it with ``valid_func`` and writes a checkpoint. After the last
    epoch one more architecture is selected, evaluated and reported.

    Args:
        xargs: parsed command-line namespace. It is handed to
            ``prepare_logger`` and copied into every checkpoint. Attributes
            read directly here: ``workers``, ``rand_seed``, ``dataset``,
            ``data_path``, ``config_path``, ``search_space_name``,
            ``channel``, ``num_cells``, ``max_nodes``,
            ``track_running_stats``, ``arch_learning_rate``,
            ``arch_weight_decay``, ``arch_nas_dataset``, ``select_num`` and
            ``print_freq``.

    Raises:
        AssertionError: if CUDA is unavailable or ``xargs.dataset`` is not
            ``'cifar10'``.
        ValueError: if the dataset name matches no known split file.
    """
    assert torch.cuda.is_available(), 'CUDA is not available.'
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.set_num_threads(xargs.workers)
    prepare_seed(xargs.rand_seed)
    # Use the namespace that was passed in; the previous code reached for a
    # module-level ``args`` global, which only worked when this function was
    # invoked from the script's own ``__main__`` section.
    logger = prepare_logger(xargs)

    train_data, valid_data, xshape, class_num = get_datasets(
        xargs.dataset, xargs.data_path, -1)
    # NOTE(review): while this assert is active the cifar100 / ImageNet16
    # branches below cannot be reached.
    assert xargs.dataset == 'cifar10', 'currently only support CIFAR-10'
    if xargs.dataset == 'cifar10' or xargs.dataset == 'cifar100':
        split_Fpath = 'configs/nas-benchmark/cifar-split.txt'
        cifar_split = load_config(split_Fpath, None, None)
        train_split, valid_split = cifar_split.train, cifar_split.valid
        logger.log('Load split file from {:}'.format(split_Fpath))
    elif xargs.dataset.startswith('ImageNet16'):
        split_Fpath = 'configs/nas-benchmark/{:}-split.txt'.format(
            xargs.dataset)
        imagenet16_split = load_config(split_Fpath, None, None)
        train_split, valid_split = (imagenet16_split.train,
                                    imagenet16_split.valid)
        logger.log('Load split file from {:}'.format(split_Fpath))
    else:
        raise ValueError('invalid dataset : {:}'.format(xargs.dataset))
    config = load_config(xargs.config_path, {
        'class_num': class_num,
        'xshape': xshape
    }, logger)

    # The validation split is drawn from the training images, so a copy of
    # the training set is given the evaluation transform.
    train_data_v2 = deepcopy(train_data)
    train_data_v2.transform = valid_data.transform
    valid_data = train_data_v2
    search_data = SearchDataset(xargs.dataset, train_data, train_split,
                                valid_split)
    # data loaders
    search_loader = torch.utils.data.DataLoader(search_data,
                                                batch_size=config.batch_size,
                                                shuffle=True,
                                                num_workers=xargs.workers,
                                                pin_memory=True)
    valid_loader = torch.utils.data.DataLoader(
        valid_data,
        batch_size=config.test_batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split),
        num_workers=xargs.workers,
        pin_memory=True)
    logger.log(
        '||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'
        .format(xargs.dataset, len(search_loader), len(valid_loader),
                config.batch_size))
    logger.log('||||||| {:10s} ||||||| Config={:}'.format(
        xargs.dataset, config))

    search_space = get_search_spaces('cell', xargs.search_space_name)
    model_config = dict2config(
        {
            'name': 'SETN',
            'C': xargs.channel,
            'N': xargs.num_cells,
            'max_nodes': xargs.max_nodes,
            'num_classes': class_num,
            'space': search_space,
            'affine': False,
            'track_running_stats': bool(xargs.track_running_stats)
        }, None)
    logger.log('search space : {:}'.format(search_space))
    search_model = get_cell_based_tiny_net(model_config)

    w_optimizer, w_scheduler, criterion = get_optim_scheduler(
        search_model.get_weights(), config)
    a_optimizer = torch.optim.Adam(search_model.get_alphas(),
                                   lr=xargs.arch_learning_rate,
                                   betas=(0.5, 0.999),
                                   weight_decay=xargs.arch_weight_decay)
    logger.log('w-optimizer : {:}'.format(w_optimizer))
    logger.log('a-optimizer : {:}'.format(a_optimizer))
    logger.log('w-scheduler : {:}'.format(w_scheduler))
    logger.log('criterion   : {:}'.format(criterion))
    flop, param = get_model_infos(search_model, xshape)
    logger.log('FLOP = {:.2f} M, Params = {:.2f} MB'.format(flop, param))
    logger.log('search-space : {:}'.format(search_space))
    if xargs.arch_nas_dataset is None:
        api = None
    else:
        api = API(xargs.arch_nas_dataset)
    logger.log('{:} create API = {:} done'.format(time_string(), api))

    last_info, model_base_path, model_best_path = logger.path(
        'info'), logger.path('model'), logger.path('best')
    network, criterion = torch.nn.DataParallel(
        search_model).cuda(), criterion.cuda()

    if last_info.exists():  # automatically resume from previous checkpoint
        logger.log("=> loading checkpoint of the last-info '{:}' start".format(
            last_info))
        last_info = torch.load(last_info)
        start_epoch = last_info['epoch']
        checkpoint = torch.load(last_info['last_checkpoint'])
        genotypes = checkpoint['genotypes']
        valid_accuracies = checkpoint['valid_accuracies']
        search_model.load_state_dict(checkpoint['search_model'])
        w_scheduler.load_state_dict(checkpoint['w_scheduler'])
        w_optimizer.load_state_dict(checkpoint['w_optimizer'])
        a_optimizer.load_state_dict(checkpoint['a_optimizer'])
        logger.log(
            "=> loading checkpoint of the last-info '{:}' start with {:}-th epoch."
            .format(last_info, start_epoch))
    else:
        logger.log("=> do not find the last-info file : {:}".format(last_info))
        start_epoch, valid_accuracies, genotypes = 0, {'best': -1}, {}

    # start training
    start_time, search_time, epoch_time, total_epoch = time.time(
    ), AverageMeter(), AverageMeter(), config.epochs + config.warmup
    for epoch in range(start_epoch, total_epoch):
        w_scheduler.update(epoch, 0.0)
        need_time = 'Time Left: {:}'.format(
            convert_secs2time(epoch_time.val * (total_epoch - epoch), True))
        epoch_str = '{:03d}-{:03d}'.format(epoch, total_epoch)
        logger.log('\n[Search the {:}-th epoch] {:}, LR={:}'.format(
            epoch_str, need_time, min(w_scheduler.get_lr())))

        search_w_loss, search_w_top1, search_w_top5, search_a_loss, search_a_top1, search_a_top5 \
            = search_func(search_loader, network, criterion, w_scheduler,
                          w_optimizer, a_optimizer, epoch_str,
                          xargs.print_freq, logger)
        search_time.update(time.time() - start_time)
        logger.log(
            '[{:}] search [base] : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s'
            .format(epoch_str, search_w_loss, search_w_top1, search_w_top5,
                    search_time.sum))
        logger.log(
            '[{:}] search [arch] : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'
            .format(epoch_str, search_a_loss, search_a_top1, search_a_top5))

        # Select an architecture, switch the network to it and evaluate it.
        genotype, _ = get_best_arch(valid_loader, network, xargs.select_num)
        network.module.set_cal_mode('dynamic', genotype)
        valid_a_loss, valid_a_top1, valid_a_top5 = valid_func(
            valid_loader, network, criterion)
        logger.log(
            '[{:}] evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}% | {:}'
            .format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5,
                    genotype))

        # record this epoch's result
        valid_accuracies[epoch] = valid_a_top1
        genotypes[epoch] = genotype
        logger.log('<<<--->>> The {:}-th epoch : {:}'.format(
            epoch_str, genotypes[epoch]))
        # save checkpoint
        save_path = save_checkpoint(
            {
                'epoch': epoch + 1,
                'args': deepcopy(xargs),
                'search_model': search_model.state_dict(),
                'w_optimizer': w_optimizer.state_dict(),
                'a_optimizer': a_optimizer.state_dict(),
                'w_scheduler': w_scheduler.state_dict(),
                'genotypes': genotypes,
                'valid_accuracies': valid_accuracies
            }, model_base_path, logger)
        last_info = save_checkpoint(
            {
                'epoch': epoch + 1,
                'args': deepcopy(xargs),
                'last_checkpoint': save_path,
            }, logger.path('info'), logger)
        with torch.no_grad():
            logger.log('arch-parameters :\n{:}'.format(
                nn.functional.softmax(search_model.arch_parameters,
                                      dim=-1).cpu()))
        if api is not None:
            logger.log('{:}'.format(api.query_by_arch(genotypes[epoch])))
        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()

    # Final selection; its duration is added to the reported search time.
    start_time = time.time()
    genotype, _ = get_best_arch(valid_loader, network, xargs.select_num)
    search_time.update(time.time() - start_time)
    network.module.set_cal_mode('dynamic', genotype)
    valid_a_loss, valid_a_top1, valid_a_top5 = valid_func(
        valid_loader, network, criterion)
    logger.log(
        'Last : the gentotype is : {:}, with the validation accuracy of {:.3f}%.'
        .format(genotype, valid_a_top1))

    logger.log('\n' + '-' * 100)
    # check the performance from the architecture dataset
    logger.log(
        'SETN : run {:} epochs, cost {:.1f} s, last-geno is {:}.'.format(
            total_epoch, search_time.sum, genotype))
    if api is not None:
        logger.log('{:}'.format(api.query_by_arch(genotype)))
    logger.close()
def main(xargs):
    """Run the GDAS search, optional extra training, and sampled-arch evaluation.

    The loop runs ``config.epochs + config.warmup`` search epochs with
    ``search_func`` followed by ``config.t_epochs`` epochs with
    ``train_func``. Sets of 100 architecture weights are sampled from the
    search model along the way (every 50 epochs when ``config.t_epochs`` is
    falsy, otherwise once at the last search epoch). After the loop the
    sampled sets are scored on 10 validation batches each, and the ten best
    of the last set are re-scored on the whole validation loader.

    Args:
        xargs: parsed command-line namespace. It is handed to
            ``prepare_logger`` and copied into every checkpoint. Attributes
            read directly here: ``workers``, ``rand_seed``, ``dataset``,
            ``data_path``, ``config_path``, ``search_space_name``,
            ``model_config``, ``constrain``, ``channel``, ``num_cells``,
            ``max_nodes``, ``track_running_stats``, ``arch_learning_rate``,
            ``arch_weight_decay``, ``arch_nas_dataset``, ``tau_max``,
            ``tau_min``, ``print_freq`` and ``bilevel``.

    Raises:
        AssertionError: if CUDA is unavailable.
    """
    assert torch.cuda.is_available(), 'CUDA is not available.'
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.set_num_threads(xargs.workers)
    prepare_seed(xargs.rand_seed)
    # Use the namespace that was passed in instead of a module-level ``args``
    # global (the old code only worked when run as the script entry point).
    logger = prepare_logger(xargs)

    train_data, valid_data, xshape, class_num = get_datasets(
        xargs.dataset, xargs.data_path, -1)
    config = load_config(xargs.config_path, {
        'class_num': class_num,
        'xshape': xshape
    }, logger)
    search_loader, _, valid_loader = get_nas_search_loaders(
        train_data, valid_data, xargs.dataset, 'configs/nas-benchmark/',
        config.batch_size, xargs.workers)
    logger.log(
        '||||||| {:10s} ||||||| Search-Loader-Num={:}, batch size={:}'.format(
            xargs.dataset, len(search_loader), config.batch_size))
    logger.log('||||||| {:10s} ||||||| Config={:}'.format(
        xargs.dataset, config))

    search_space = get_search_spaces('cell', xargs.search_space_name)
    if xargs.model_config is None:
        # The two built-in configurations differ only in ``inp_size``:
        # 32 when ``constrain`` is set, 0 otherwise.
        model_config = dict2config(
            {
                'name': 'GDAS',
                'C': xargs.channel,
                'N': xargs.num_cells,
                'max_nodes': xargs.max_nodes,
                'num_classes': class_num,
                'space': search_space,
                'inp_size': 32 if xargs.constrain else 0,
                'affine': False,
                'track_running_stats': bool(xargs.track_running_stats)
            }, None)
    else:
        model_config = load_config(
            xargs.model_config, {
                'num_classes': class_num,
                'space': search_space,
                'affine': False,
                'track_running_stats': bool(xargs.track_running_stats)
            }, None)
    search_model = get_cell_based_tiny_net(model_config)
    logger.log('model-config : {:}'.format(model_config))

    w_optimizer, w_scheduler, criterion = get_optim_scheduler(
        search_model.get_weights(), config)
    a_optimizer = torch.optim.Adam(search_model.get_alphas(),
                                   lr=xargs.arch_learning_rate,
                                   betas=(0.5, 0.999),
                                   weight_decay=xargs.arch_weight_decay)
    logger.log('w-optimizer : {:}'.format(w_optimizer))
    logger.log('a-optimizer : {:}'.format(a_optimizer))
    logger.log('w-scheduler : {:}'.format(w_scheduler))
    logger.log('criterion   : {:}'.format(criterion))
    flop, param = get_model_infos(search_model, xshape)
    logger.log('FLOP = {:.2f} M, Params = {:.2f} MB'.format(flop, param))
    logger.log('search-space [{:} ops] : {:}'.format(len(search_space),
                                                     search_space))
    if xargs.arch_nas_dataset is None:
        api = None
    else:
        api = API(xargs.arch_nas_dataset)
    logger.log('{:} create API = {:} done'.format(time_string(), api))

    last_info, model_base_path, model_best_path = logger.path(
        'info'), logger.path('model'), logger.path('best')
    network, criterion = torch.nn.DataParallel(
        search_model).cuda(), criterion.cuda()

    if last_info.exists():  # automatically resume from previous checkpoint
        logger.log("=> loading checkpoint of the last-info '{:}' start".format(
            last_info))
        last_info = torch.load(last_info)
        start_epoch = last_info['epoch']
        checkpoint = torch.load(last_info['last_checkpoint'])
        genotypes = checkpoint['genotypes']
        valid_accuracies = checkpoint['valid_accuracies']
        search_model.load_state_dict(checkpoint['search_model'])
        w_scheduler.load_state_dict(checkpoint['w_scheduler'])
        w_optimizer.load_state_dict(checkpoint['w_optimizer'])
        a_optimizer.load_state_dict(checkpoint['a_optimizer'])
        logger.log(
            "=> loading checkpoint of the last-info '{:}' start with {:}-th epoch."
            .format(last_info, start_epoch))
    else:
        logger.log("=> do not find the last-info file : {:}".format(last_info))
        start_epoch, valid_accuracies, genotypes = 0, {
            'best': -1
        }, {
            -1: search_model.genotype()
        }

    # start training
    start_time, search_time, epoch_time, total_epoch = time.time(
    ), AverageMeter(), AverageMeter(), config.epochs + config.warmup
    # NOTE(review): ``sampled_weights`` and ``arch_iter`` are not stored in
    # the checkpoint, so resuming at or after ``total_epoch`` reaches the
    # ``train_func`` branch without them — confirm whether resume is expected
    # to work in that phase.
    sampled_weights = []
    valid_iter = None  # validation iterator shared by all partial evaluations
    tau_span = xargs.tau_max - xargs.tau_min
    tau_steps = max(total_epoch - 1, 1)  # avoid dividing by zero for 1 epoch
    for epoch in range(start_epoch, total_epoch + config.t_epochs):
        w_scheduler.update(epoch, 0.0)
        need_time = 'Time Left: {:}'.format(
            convert_secs2time(
                epoch_time.val * (total_epoch - epoch + config.t_epochs),
                True))
        epoch_str = '{:03d}-{:03d}'.format(epoch, total_epoch)
        # NOTE(review): for epoch >= total_epoch this keeps extrapolating
        # below ``tau_min`` — confirm that is intended for the extra epochs.
        search_model.set_tau(xargs.tau_max - tau_span * epoch / tau_steps)
        logger.log('\n[Search the {:}-th epoch] {:}, tau={:}, LR={:}'.format(
            epoch_str, need_time, search_model.get_tau(),
            min(w_scheduler.get_lr())))

        if epoch < total_epoch:
            search_w_loss, search_w_top1, search_w_top5, valid_a_loss, valid_a_top1, valid_a_top5 \
                = search_func(search_loader, network, criterion, w_scheduler,
                              w_optimizer, a_optimizer, epoch_str,
                              xargs.print_freq, logger, xargs.bilevel)
        else:
            search_w_loss, search_w_top1, search_w_top5, valid_a_loss, valid_a_top1, valid_a_top5, arch_iter \
                = train_func(search_loader, network, criterion, w_scheduler,
                             w_optimizer, epoch_str, xargs.print_freq,
                             sampled_weights[0], arch_iter, logger)
        search_time.update(time.time() - start_time)
        logger.log(
            '[{:}] searching : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s'
            .format(epoch_str, search_w_loss, search_w_top1, search_w_top5,
                    search_time.sum))
        logger.log(
            '[{:}] evaluate  : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'
            .format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5))

        if (epoch + 1) % 50 == 0 and not config.t_epochs:
            weights = search_model.sample_weights(100)
            sampled_weights.append(weights)
        elif (epoch + 1) == total_epoch and config.t_epochs:
            weights = search_model.sample_weights(100)
            sampled_weights.append(weights)
            arch_iter = iter(weights)

        # validate with a single freshly sampled architecture
        single_weight = search_model.sample_weights(1)[0]
        network.eval()
        single_valid_avg, valid_iter = _partial_valid_accuracy(
            network, valid_loader, valid_iter, single_weight)
        logger.log('[{:}] valid : accuracy = {:.2f}'.format(
            epoch_str, single_valid_avg))

        # check the best accuracy
        valid_accuracies[epoch] = valid_a_top1
        if valid_a_top1 > valid_accuracies['best']:
            valid_accuracies['best'] = valid_a_top1
            genotypes['best'] = search_model.genotype()
            find_best = True
        else:
            find_best = False

        # Genotypes are only recorded for the search epochs.
        if epoch < total_epoch:
            genotypes[epoch] = search_model.genotype()
            logger.log('<<<--->>> The {:}-th epoch : {:}'.format(
                epoch_str, genotypes[epoch]))
        # save checkpoint
        save_path = save_checkpoint(
            {
                'epoch': epoch + 1,
                'args': deepcopy(xargs),
                'search_model': search_model.state_dict(),
                'w_optimizer': w_optimizer.state_dict(),
                'a_optimizer': a_optimizer.state_dict(),
                'w_scheduler': w_scheduler.state_dict(),
                'genotypes': genotypes,
                'valid_accuracies': valid_accuracies
            }, model_base_path, logger)
        last_info = save_checkpoint(
            {
                'epoch': epoch + 1,
                'args': deepcopy(xargs),
                'last_checkpoint': save_path,
            }, logger.path('info'), logger)
        if find_best:
            logger.log(
                '<<<--->>> The {:}-th epoch : find the highest validation accuracy : {:.2f}%.'
                .format(epoch_str, valid_a_top1))
            copy_checkpoint(model_base_path, model_best_path, logger)
        with torch.no_grad():
            logger.log('{:}'.format(search_model.show_alphas()))
        if api is not None and epoch < total_epoch:
            logger.log('{:}'.format(api.query_by_arch(genotypes[epoch])))
        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()

    network.eval()
    # Evaluate the architectures sampled throughout the search
    for set_index in range(len(sampled_weights) - 1):
        logger.log('Sample eval : epoch {}'.format((set_index + 1) * 50 - 1))
        for w in sampled_weights[set_index]:
            sample_avg, valid_iter = _partial_valid_accuracy(
                network, valid_loader, valid_iter, w)
            acc = _benchmark_test_accuracy(api, search_model.genotype(w))
            logger.log(
                'sample valid : val_acc = {:.2f} test_acc = {:.2f}'.format(
                    sample_avg, acc))

    # Evaluate the final sampling separately to find the top 10 architectures
    logger.log('Final sample eval')
    final_archs = []
    # An empty list here means no set was ever sampled (e.g. fewer than 50
    # epochs without extra training); skip instead of raising IndexError.
    final_weights = sampled_weights[-1] if sampled_weights else []
    for w in final_weights:
        sample_avg, valid_iter = _partial_valid_accuracy(
            network, valid_loader, valid_iter, w)
        acc = _benchmark_test_accuracy(api, search_model.genotype(w))
        logger.log('sample valid : val_acc = {:.2f} test_acc = {:.2f}'.format(
            sample_avg, acc))
        final_archs.append((w, sample_avg))
    top_10 = sorted(final_archs, key=lambda x: x[1], reverse=True)[:10]

    # Evaluate the top 10 architectures on the entire validation set
    logger.log('Evaluating top archs')
    for w, prev_acc in top_10:
        full_valid_acc = AverageMeter()
        for val_input, val_target in valid_loader:
            _update_top1(full_valid_acc, network, val_input, val_target, w)
        w_gene = search_model.genotype(w)
        logger.log('genotype {}'.format(w_gene))
        acc = _benchmark_test_accuracy(api, w_gene)
        logger.log(
            'full valid : val_acc = {:.2f} test_acc = {:.2f} pval_acc = {:.2f}'
            .format(full_valid_acc.avg, acc, prev_acc))

    logger.log('\n' + '-' * 100)
    # check the performance from the architecture dataset
    logger.log(
        'GDAS : run {:} epochs, cost {:.1f} s, last-geno is {:}.'.format(
            total_epoch, search_time.sum, genotypes[total_epoch - 1]))
    if api is not None:
        logger.log('{:}'.format(api.query_by_arch(genotypes[total_epoch - 1])))
    logger.close()


def _update_top1(meter, network, val_input, val_target, weights):
    """Add one batch's top-1 accuracy under ``weights`` to ``meter``."""
    n_val = val_input.size(0)
    with torch.no_grad():
        val_target = val_target.cuda(non_blocking=True)
        _, logits, _ = network(val_input, weights=weights)
        val_acc1, _ = obtain_accuracy(logits.data,
                                      val_target.data,
                                      topk=(1, 5))
    meter.update(val_acc1.item(), n_val)


def _partial_valid_accuracy(network,
                            valid_loader,
                            valid_iter,
                            weights,
                            num_batches=10):
    """Average top-1 accuracy under ``weights`` over ``num_batches`` batches.

    ``valid_iter`` may be ``None`` (not started yet). It is restarted from
    ``valid_loader`` whenever it runs out, and the possibly new iterator is
    returned so the caller can keep consuming where this call stopped.

    Returns:
        ``(average_accuracy, valid_iter)``.
    """
    meter = AverageMeter()
    for _ in range(num_batches):
        if valid_iter is None:
            valid_iter = iter(valid_loader)
        try:
            val_input, val_target = next(valid_iter)
        except StopIteration:
            # Only exhaustion restarts the loader; real loading errors now
            # propagate instead of being swallowed by a broad ``except``.
            valid_iter = iter(valid_loader)
            val_input, val_target = next(valid_iter)
        _update_top1(meter, network, val_input, val_target, weights)
    return meter.avg, valid_iter


def _benchmark_test_accuracy(api, genotype):
    """Return the 'cifar10' / 'ori-test' accuracy from ``api``, or 0.0 without an API."""
    if api is None:
        return 0.0
    index = api.query_index_by_arch(genotype)
    info = api.query_meta_info_by_index(index)
    metrics = info.get_metrics('cifar10', 'ori-test')
    return metrics['accuracy']
def main(xargs):
    """Run the GDAS architecture search configured by ``xargs``.

    Builds the search loaders, the ``GDAS`` search network (from the built-in
    configuration or from ``xargs.model_config``) and its weight and
    architecture optimizers, and resumes from the logger's ``info``
    checkpoint when that file exists. Each epoch sets the temperature,
    calls ``search_func``, tracks the best validation accuracy, and writes a
    checkpoint (copied to the ``best`` path when the accuracy improved).

    Args:
        xargs: parsed command-line namespace. It is handed to
            ``prepare_logger`` and copied into every checkpoint. Attributes
            read directly here: ``workers``, ``rand_seed``, ``dataset``,
            ``data_path``, ``config_path``, ``search_space_name``,
            ``model_config``, ``channel``, ``num_cells``, ``max_nodes``,
            ``track_running_stats``, ``arch_learning_rate``,
            ``arch_weight_decay``, ``arch_nas_dataset``, ``tau_max``,
            ``tau_min`` and ``print_freq``.

    Raises:
        AssertionError: if CUDA is unavailable.
    """
    assert torch.cuda.is_available(), 'CUDA is not available.'
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.set_num_threads(xargs.workers)
    prepare_seed(xargs.rand_seed)
    # Use the namespace that was passed in instead of a module-level ``args``
    # global (the old code only worked when run as the script entry point).
    logger = prepare_logger(xargs)

    train_data, valid_data, xshape, class_num = get_datasets(
        xargs.dataset, xargs.data_path, -1)
    config = load_config(xargs.config_path, {
        'class_num': class_num,
        'xshape': xshape
    }, logger)
    search_loader, _, valid_loader = get_nas_search_loaders(
        train_data, valid_data, xargs.dataset, 'configs/nas-benchmark/',
        config.batch_size, xargs.workers)
    logger.log(
        '||||||| {:10s} ||||||| Search-Loader-Num={:}, batch size={:}'.format(
            xargs.dataset, len(search_loader), config.batch_size))
    logger.log('||||||| {:10s} ||||||| Config={:}'.format(
        xargs.dataset, config))

    search_space = get_search_spaces('cell', xargs.search_space_name)
    if xargs.model_config is None:
        model_config = dict2config(
            {
                'name': 'GDAS',
                'C': xargs.channel,
                'N': xargs.num_cells,
                'max_nodes': xargs.max_nodes,
                'num_classes': class_num,
                'space': search_space,
                'affine': False,
                'track_running_stats': bool(xargs.track_running_stats)
            }, None)
    else:
        model_config = load_config(
            xargs.model_config, {
                'num_classes': class_num,
                'space': search_space,
                'affine': False,
                'track_running_stats': bool(xargs.track_running_stats)
            }, None)
    search_model = get_cell_based_tiny_net(model_config)
    logger.log('search-model :\n{:}'.format(search_model))
    logger.log('model-config : {:}'.format(model_config))

    w_optimizer, w_scheduler, criterion = get_optim_scheduler(
        search_model.get_weights(), config)
    a_optimizer = torch.optim.Adam(search_model.get_alphas(),
                                   lr=xargs.arch_learning_rate,
                                   betas=(0.5, 0.999),
                                   weight_decay=xargs.arch_weight_decay)
    logger.log('w-optimizer : {:}'.format(w_optimizer))
    logger.log('a-optimizer : {:}'.format(a_optimizer))
    logger.log('w-scheduler : {:}'.format(w_scheduler))
    logger.log('criterion   : {:}'.format(criterion))
    flop, param = get_model_infos(search_model, xshape)
    logger.log('FLOP = {:.2f} M, Params = {:.2f} MB'.format(flop, param))
    logger.log('search-space [{:} ops] : {:}'.format(len(search_space),
                                                     search_space))
    if xargs.arch_nas_dataset is None:
        api = None
    else:
        api = API(xargs.arch_nas_dataset)
    logger.log('{:} create API = {:} done'.format(time_string(), api))

    last_info, model_base_path, model_best_path = logger.path(
        'info'), logger.path('model'), logger.path('best')
    network, criterion = torch.nn.DataParallel(
        search_model).cuda(), criterion.cuda()

    if last_info.exists():  # automatically resume from previous checkpoint
        logger.log("=> loading checkpoint of the last-info '{:}' start".format(
            last_info))
        last_info = torch.load(last_info)
        start_epoch = last_info['epoch']
        checkpoint = torch.load(last_info['last_checkpoint'])
        genotypes = checkpoint['genotypes']
        valid_accuracies = checkpoint['valid_accuracies']
        search_model.load_state_dict(checkpoint['search_model'])
        w_scheduler.load_state_dict(checkpoint['w_scheduler'])
        w_optimizer.load_state_dict(checkpoint['w_optimizer'])
        a_optimizer.load_state_dict(checkpoint['a_optimizer'])
        logger.log(
            "=> loading checkpoint of the last-info '{:}' start with {:}-th epoch."
            .format(last_info, start_epoch))
    else:
        logger.log("=> do not find the last-info file : {:}".format(last_info))
        # The -1 entry gives the final summary a genotype to report even when
        # the loop below runs zero epochs.
        start_epoch, valid_accuracies, genotypes = 0, {
            'best': -1
        }, {
            -1: search_model.genotype()
        }

    # start training
    start_time, search_time, epoch_time, total_epoch = time.time(
    ), AverageMeter(), AverageMeter(), config.epochs + config.warmup
    tau_span = xargs.tau_max - xargs.tau_min
    tau_steps = max(total_epoch - 1, 1)  # avoid dividing by zero for 1 epoch
    for epoch in range(start_epoch, total_epoch):
        w_scheduler.update(epoch, 0.0)
        need_time = 'Time Left: {:}'.format(
            convert_secs2time(epoch_time.val * (total_epoch - epoch), True))
        epoch_str = '{:03d}-{:03d}'.format(epoch, total_epoch)
        # Linear schedule from tau_max at epoch 0 to tau_min at the last epoch.
        search_model.set_tau(xargs.tau_max - tau_span * epoch / tau_steps)
        logger.log('\n[Search the {:}-th epoch] {:}, tau={:}, LR={:}'.format(
            epoch_str, need_time, search_model.get_tau(),
            min(w_scheduler.get_lr())))

        search_w_loss, search_w_top1, search_w_top5, valid_a_loss, valid_a_top1, valid_a_top5 \
            = search_func(search_loader, network, criterion, w_scheduler,
                          w_optimizer, a_optimizer, epoch_str,
                          xargs.print_freq, logger)
        search_time.update(time.time() - start_time)
        logger.log(
            '[{:}] searching : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s'
            .format(epoch_str, search_w_loss, search_w_top1, search_w_top5,
                    search_time.sum))
        logger.log(
            '[{:}] evaluate  : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'
            .format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5))

        # check the best accuracy
        valid_accuracies[epoch] = valid_a_top1
        if valid_a_top1 > valid_accuracies['best']:
            valid_accuracies['best'] = valid_a_top1
            genotypes['best'] = search_model.genotype()
            find_best = True
        else:
            find_best = False

        genotypes[epoch] = search_model.genotype()
        logger.log('<<<--->>> The {:}-th epoch : {:}'.format(
            epoch_str, genotypes[epoch]))
        # save checkpoint
        save_path = save_checkpoint(
            {
                'epoch': epoch + 1,
                'args': deepcopy(xargs),
                'search_model': search_model.state_dict(),
                'w_optimizer': w_optimizer.state_dict(),
                'a_optimizer': a_optimizer.state_dict(),
                'w_scheduler': w_scheduler.state_dict(),
                'genotypes': genotypes,
                'valid_accuracies': valid_accuracies
            }, model_base_path, logger)
        last_info = save_checkpoint(
            {
                'epoch': epoch + 1,
                'args': deepcopy(xargs),
                'last_checkpoint': save_path,
            }, logger.path('info'), logger)
        if find_best:
            logger.log(
                '<<<--->>> The {:}-th epoch : find the highest validation accuracy : {:.2f}%.'
                .format(epoch_str, valid_a_top1))
            copy_checkpoint(model_base_path, model_best_path, logger)
        with torch.no_grad():
            logger.log('{:}'.format(search_model.show_alphas()))
        if api is not None:
            logger.log('{:}'.format(api.query_by_arch(genotypes[epoch],
                                                      '200')))
        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()

    logger.log('\n' + '-' * 100)
    # check the performance from the architecture dataset
    logger.log(
        'GDAS : run {:} epochs, cost {:.1f} s, last-geno is {:}.'.format(
            total_epoch, search_time.sum, genotypes[total_epoch - 1]))
    if api is not None:
        logger.log('{:}'.format(
            api.query_by_arch(genotypes[total_epoch - 1], '200')))
    logger.close()
def evaluate_all_datasets(channels: Text, datasets: List[Text],
                          xpaths: List[Text], splits: List[Text],
                          config_path: Text, seed: int, workers: int, logger):
    """Train and evaluate one fixed genotype on every requested dataset.

    For each ``(dataset, xpath, split)`` triple the data is loaded, the
    train/validation loaders are built, and ``bench_evaluate_for_seed`` is
    run on an ``infer.shape.tiny`` configuration using ``channels`` and a
    hard-coded genotype.

    Args:
        channels: channel specification forwarded into the arch config.
        datasets: dataset names ('cifar10', 'cifar100' or 'ImageNet16-120').
        xpaths: data root for each dataset.
        splits: per-dataset flag; a truthy value trains on the 'train' split
            and validates on the 'valid' split of the training images.
            NOTE(review): the annotation says text, but ``bool()`` of any
            non-empty string is True — confirm callers pass ints/bools.
        config_path: path of the training configuration file.
        seed: random seed forwarded to ``bench_evaluate_for_seed``.
        workers: number of DataLoader worker processes.
        logger: object exposing ``log``.

    Returns:
        A dict with the machine info under 'info', one entry per dataset key
        holding its results, and the ordered keys under 'all_dataset_keys'.

    Raises:
        ValueError: for a dataset name that is not handled.
        AssertionError: if a split is requested for a dataset other than
            'cifar10', or the split sizes do not add up to the training set.
    """
    all_infos = {'info': get_machine_info()}
    all_dataset_keys = []
    subset_sampler = torch.utils.data.sampler.SubsetRandomSampler

    # look all the dataset
    for dataset, xpath, split in zip(datasets, xpaths, splits):
        # the train and valid data
        train_data, valid_data, xshape, class_num = get_datasets(
            dataset, xpath, -1)
        # load the configuration
        if dataset in ('cifar10', 'cifar100'):
            split_info = load_config('configs/nas-benchmark/cifar-split.txt',
                                     None, None)
        elif dataset.startswith('ImageNet16'):
            split_info = load_config(
                'configs/nas-benchmark/{:}-split.txt'.format(dataset), None,
                None)
        else:
            raise ValueError('invalid dataset : {:}'.format(dataset))
        config = load_config(config_path,
                             dict(class_num=class_num, xshape=xshape), logger)

        def make_loader(data, **extra):
            # Every loader shares batch size, worker count and pinning; only
            # the shuffle / sampler choice varies between call sites.
            return torch.utils.data.DataLoader(data,
                                               batch_size=config.batch_size,
                                               num_workers=workers,
                                               pin_memory=True,
                                               **extra)

        use_split = bool(split)
        # check whether use the splitted validation set
        if use_split:
            assert dataset == 'cifar10'
            # Built before ``valid_data`` is rebound, so this is the original
            # test set.
            ValLoaders = {'ori-test': make_loader(valid_data, shuffle=False)}
            assert len(train_data) == len(split_info.train) + len(
                split_info.valid
            ), 'invalid length : {:} vs {:} + {:}'.format(
                len(train_data), len(split_info.train), len(split_info.valid))
            train_data_v2 = deepcopy(train_data)
            train_data_v2.transform = valid_data.transform
            valid_data = train_data_v2
            # data loader
            train_loader = make_loader(
                train_data, sampler=subset_sampler(split_info.train))
            valid_loader = make_loader(
                valid_data, sampler=subset_sampler(split_info.valid))
            ValLoaders['x-valid'] = valid_loader
        else:
            # data loader
            train_loader = make_loader(train_data, shuffle=True)
            valid_loader = make_loader(valid_data, shuffle=False)
            if dataset == 'cifar10':
                ValLoaders = {'ori-test': valid_loader}
            elif dataset == 'cifar100':
                cifar100_splits = load_config(
                    'configs/nas-benchmark/cifar100-test-split.txt', None,
                    None)
                ValLoaders = {
                    'ori-test': valid_loader,
                    'x-valid': make_loader(
                        valid_data,
                        sampler=subset_sampler(cifar100_splits.xvalid)),
                    'x-test': make_loader(
                        valid_data,
                        sampler=subset_sampler(cifar100_splits.xtest)),
                }
            elif dataset == 'ImageNet16-120':
                imagenet16_splits = load_config(
                    'configs/nas-benchmark/imagenet-16-120-test-split.txt',
                    None, None)
                ValLoaders = {
                    'ori-test': valid_loader,
                    'x-valid': make_loader(
                        valid_data,
                        sampler=subset_sampler(imagenet16_splits.xvalid)),
                    'x-test': make_loader(
                        valid_data,
                        sampler=subset_sampler(imagenet16_splits.xtest)),
                }
            else:
                raise ValueError('invalid dataset : {:}'.format(dataset))

        dataset_key = '{:}'.format(dataset) + ('-valid' if use_split else '')
        logger.log(
            'Evaluate ||||||| {:10s} ||||||| Train-Num={:}, Valid-Num={:}, Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'
            .format(dataset_key, len(train_data), len(valid_data),
                    len(train_loader), len(valid_loader), config.batch_size))
        logger.log('Evaluate ||||||| {:10s} ||||||| Config={:}'.format(
            dataset_key, config))
        for loader_name, loader in ValLoaders.items():
            logger.log('Evaluate ---->>>> {:10s} with {:} batchs'.format(
                loader_name, len(loader)))

        # arch-index= 9930: this genotype is the architecture with the
        # highest accuracy on CIFAR-100 validation set
        genotype = '|nor_conv_3x3~0|+|nor_conv_3x3~0|nor_conv_3x3~1|+|skip_connect~0|nor_conv_3x3~1|nor_conv_3x3~2|'
        arch_config = dict2config(
            dict(name='infer.shape.tiny',
                 channels=channels,
                 genotype=genotype,
                 num_classes=class_num), None)
        all_infos[dataset_key] = bench_evaluate_for_seed(
            arch_config, config, train_loader, ValLoaders, seed, logger)
        all_dataset_keys.append(dataset_key)

    all_infos['all_dataset_keys'] = all_dataset_keys
    return all_infos
def main(xargs):
    """Prepare the data, search model and optimizers for a DARTS-V2 run.

    The function configures cuDNN flags, seeds the RNGs via ``prepare_seed``,
    loads the dataset together with its train/valid split file, builds the
    search and validation loaders, creates the 'DARTS-V2' search model with
    its weight and architecture optimizers, logs the model statistics, and
    moves the model and criterion to CUDA before closing the logger.

    Args:
        xargs: parsed command-line namespace. The attributes read here are
            ``workers``, ``rand_seed``, ``dataset``, ``data_path``,
            ``search_space_name``, ``channel``, ``num_cells``, ``max_nodes``,
            ``arch_learning_rate``, ``arch_weight_decay`` and
            ``arch_nas_dataset``; the whole namespace is also handed to
            ``prepare_logger``.

    Raises:
        AssertionError: if CUDA is not available.
        ValueError: if ``xargs.dataset`` is neither 'cifar10', 'cifar100'
            nor a name starting with 'ImageNet16'.
    """
    assert torch.cuda.is_available(), 'CUDA is not available.'
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.set_num_threads(xargs.workers)
    prepare_seed(xargs.rand_seed)
    # Use the function parameter rather than a module-level ``args`` so that
    # main() also works when it is called from outside the script entry point.
    logger = prepare_logger(xargs)

    train_data, valid_data, xshape, class_num = get_datasets(
        xargs.dataset, xargs.data_path, -1)

    # Pick the split file for the dataset; both kinds of split file expose
    # ``train`` and ``valid`` attributes.
    if xargs.dataset in ('cifar10', 'cifar100'):
        split_Fpath = 'configs/nas-benchmark/cifar-split.txt'
    elif xargs.dataset.startswith('ImageNet16'):
        split_Fpath = 'configs/nas-benchmark/{:}-split.txt'.format(
            xargs.dataset)
    else:
        raise ValueError('invalid dataset : {:}'.format(xargs.dataset))
    split_info = load_config(split_Fpath, None, None)
    train_split, valid_split = split_info.train, split_info.valid
    logger.log('Load split file from {:}'.format(split_Fpath))

    config_path = 'configs/nas-benchmark/algos/DARTS.config'
    config = load_config(config_path,
                         {'class_num': class_num, 'xshape': xshape},
                         logger)

    # To split data: the validation set is a copy of the training data that
    # uses the validation transform; the held-out samples are then selected
    # through ``valid_split``.
    train_data_v2 = deepcopy(train_data)
    train_data_v2.transform = valid_data.transform
    valid_data = train_data_v2
    search_data = SearchDataset(xargs.dataset, train_data,
                                train_split, valid_split)

    # data loader
    search_loader = torch.utils.data.DataLoader(
        search_data,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=xargs.workers,
        pin_memory=True)
    valid_loader = torch.utils.data.DataLoader(
        valid_data,
        batch_size=config.batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split),
        num_workers=xargs.workers,
        pin_memory=True)
    logger.log(
        '||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'
        .format(xargs.dataset, len(search_loader), len(valid_loader),
                config.batch_size))
    logger.log('||||||| {:10s} ||||||| Config={:}'.format(
        xargs.dataset, config))

    search_space = get_search_spaces('cell', xargs.search_space_name)
    model_config = dict2config(
        {
            'name': 'DARTS-V2',
            'C': xargs.channel,
            'N': xargs.num_cells,
            'max_nodes': xargs.max_nodes,
            'num_classes': class_num,
            'space': search_space,
        }, None)
    search_model = get_cell_based_tiny_net(model_config)
    logger.log('search-model :\n{:}'.format(search_model))

    # Separate optimizers: one for the network weights (built from the config)
    # and an Adam optimizer for the architecture parameters.
    w_optimizer, w_scheduler, criterion = get_optim_scheduler(
        search_model.get_weights(), config)
    a_optimizer = torch.optim.Adam(search_model.get_alphas(),
                                   lr=xargs.arch_learning_rate,
                                   betas=(0.5, 0.999),
                                   weight_decay=xargs.arch_weight_decay)
    logger.log('w-optimizer : {:}'.format(w_optimizer))
    logger.log('a-optimizer : {:}'.format(a_optimizer))
    logger.log('w-scheduler : {:}'.format(w_scheduler))
    logger.log('criterion : {:}'.format(criterion))

    flop, param = get_model_infos(search_model, xshape)
    logger.log('FLOP = {:.2f} M, Params = {:.2f} MB'.format(flop, param))

    # The benchmark API is optional; it is only created when a path is given.
    if xargs.arch_nas_dataset is None:
        api = None
    else:
        api = API(xargs.arch_nas_dataset)
    logger.log('{:} create API = {:} done'.format(time_string(), api))

    last_info, model_base_path, model_best_path = (
        logger.path('info'), logger.path('model'), logger.path('best'))
    network, criterion = (torch.nn.DataParallel(search_model).cuda(),
                          criterion.cuda())

    logger.close()
def main(xargs):
    """Train the 'RANDOM' search model and select an architecture from it.

    The function configures cuDNN flags, seeds the RNGs via ``prepare_seed``,
    builds the search/validation loaders and the search model, optionally
    resumes from the last saved checkpoint, and then runs the epoch loop:
    ``search_func`` for training, ``valid_func`` for evaluation and
    ``search_find_best`` to pick the current architecture. A checkpoint is
    written after every epoch and copied to the "best" path whenever the
    validation top-1 accuracy improves. After the loop one more
    ``search_find_best`` call yields the final architecture, which is logged
    (and queried through the benchmark API when one is configured).

    Args:
        xargs: parsed command-line namespace. The attributes read here are
            ``workers``, ``rand_seed``, ``dataset``, ``data_path``,
            ``config_path``, ``search_space_name``, ``channel``,
            ``num_cells``, ``max_nodes``, ``track_running_stats``,
            ``arch_nas_dataset``, ``print_freq`` and ``select_num``; the
            whole namespace is also handed to ``prepare_logger`` and stored
            in the checkpoints.

    Raises:
        AssertionError: if CUDA is not available.
    """
    assert torch.cuda.is_available(), 'CUDA is not available.'
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.set_num_threads(xargs.workers)
    prepare_seed(xargs.rand_seed)
    # Use the function parameter rather than a module-level ``args`` so that
    # main() also works when it is called from outside the script entry point.
    logger = prepare_logger(xargs)

    train_data, valid_data, xshape, class_num = get_datasets(
        xargs.dataset, xargs.data_path, -1)
    config = load_config(xargs.config_path,
                         {'class_num': class_num, 'xshape': xshape},
                         logger)
    search_loader, _, valid_loader = get_nas_search_loaders(
        train_data, valid_data, xargs.dataset, 'configs/nas-benchmark/',
        (config.batch_size, config.test_batch_size), xargs.workers)
    logger.log(
        '||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'
        .format(xargs.dataset, len(search_loader), len(valid_loader),
                config.batch_size))
    logger.log('||||||| {:10s} ||||||| Config={:}'.format(
        xargs.dataset, config))

    search_space = get_search_spaces('cell', xargs.search_space_name)
    model_config = dict2config(
        {
            'name': 'RANDOM',
            'C': xargs.channel,
            'N': xargs.num_cells,
            'max_nodes': xargs.max_nodes,
            'num_classes': class_num,
            'space': search_space,
            'affine': False,
            'track_running_stats': bool(xargs.track_running_stats),
        }, None)
    search_model = get_cell_based_tiny_net(model_config)

    w_optimizer, w_scheduler, criterion = get_optim_scheduler(
        search_model.parameters(), config)
    logger.log('w-optimizer : {:}'.format(w_optimizer))
    logger.log('w-scheduler : {:}'.format(w_scheduler))
    logger.log('criterion : {:}'.format(criterion))

    # The benchmark API is optional; it is only created when a path is given.
    if xargs.arch_nas_dataset is None:
        api = None
    else:
        api = API(xargs.arch_nas_dataset)
    logger.log('{:} create API = {:} done'.format(time_string(), api))

    # Keep the last-info *path* in its own variable so that log messages keep
    # showing the file location after the file content has been loaded.
    last_info_path = logger.path('info')
    model_base_path = logger.path('model')
    model_best_path = logger.path('best')
    network, criterion = (torch.nn.DataParallel(search_model).cuda(),
                          criterion.cuda())

    if last_info_path.exists():
        # automatically resume from previous checkpoint
        # NOTE(review): these checkpoints hold non-tensor objects (the args
        # namespace and a checkpoint path); recent PyTorch releases changed
        # the torch.load default to weights_only=True -- confirm the
        # targeted PyTorch version.
        logger.log("=> loading checkpoint of the last-info '{:}' start".format(
            last_info_path))
        last_info = torch.load(last_info_path)
        start_epoch = last_info['epoch']
        checkpoint = torch.load(last_info['last_checkpoint'])
        genotypes = checkpoint['genotypes']
        valid_accuracies = checkpoint['valid_accuracies']
        search_model.load_state_dict(checkpoint['search_model'])
        w_scheduler.load_state_dict(checkpoint['w_scheduler'])
        w_optimizer.load_state_dict(checkpoint['w_optimizer'])
        logger.log(
            "=> loading checkpoint of the last-info '{:}' start with {:}-th epoch."
            .format(last_info_path, start_epoch))
    else:
        logger.log("=> do not find the last-info file : {:}".format(
            last_info_path))
        start_epoch, valid_accuracies, genotypes = 0, {'best': -1}, {}

    # start training
    start_time = time.time()
    search_time, epoch_time = AverageMeter(), AverageMeter()
    total_epoch = config.epochs + config.warmup
    for epoch in range(start_epoch, total_epoch):
        w_scheduler.update(epoch, 0.0)
        need_time = 'Time Left: {:}'.format(
            convert_secs2time(epoch_time.val * (total_epoch - epoch), True))
        epoch_str = '{:03d}-{:03d}'.format(epoch, total_epoch)
        logger.log('\n[Search the {:}-th epoch] {:}, LR={:}'.format(
            epoch_str, need_time, min(w_scheduler.get_lr())))

        search_w_loss, search_w_top1, search_w_top5 = search_func(
            search_loader, network, criterion, w_scheduler, w_optimizer,
            epoch_str, xargs.print_freq, logger)
        # Only the training part of the epoch counts towards the search time.
        search_time.update(time.time() - start_time)
        logger.log(
            '[{:}] searching : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s'
            .format(epoch_str, search_w_loss, search_w_top1, search_w_top5,
                    search_time.sum))

        valid_a_loss, valid_a_top1, valid_a_top5 = valid_func(
            valid_loader, network, criterion)
        logger.log(
            '[{:}] evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'
            .format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5))

        cur_arch, cur_valid_acc = search_find_best(valid_loader, network,
                                                   xargs.select_num)
        logger.log('[{:}] find-the-best : {:}, accuracy@1={:.2f}%'.format(
            epoch_str, cur_arch, cur_valid_acc))
        genotypes[epoch] = cur_arch

        # check the best accuracy
        valid_accuracies[epoch] = valid_a_top1
        find_best = valid_a_top1 > valid_accuracies['best']
        if find_best:
            valid_accuracies['best'] = valid_a_top1

        # save checkpoint: first the full training state, then a small
        # "last-info" record pointing at it (read back on resume).
        save_path = save_checkpoint(
            {
                'epoch': epoch + 1,
                'args': deepcopy(xargs),
                'search_model': search_model.state_dict(),
                'w_optimizer': w_optimizer.state_dict(),
                'w_scheduler': w_scheduler.state_dict(),
                'genotypes': genotypes,
                'valid_accuracies': valid_accuracies,
            }, model_base_path, logger)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'args': deepcopy(xargs),
                'last_checkpoint': save_path,
            }, logger.path('info'), logger)
        if find_best:
            logger.log(
                '<<<--->>> The {:}-th epoch : find the highest validation accuracy : {:.2f}%.'
                .format(epoch_str, valid_a_top1))
            copy_checkpoint(model_base_path, model_best_path, logger)
        if api is not None:
            logger.log('{:}'.format(api.query_by_arch(genotypes[epoch])))

        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()

    logger.log('\n' + '-' * 200)
    logger.log('Pre-searching costs {:.1f} s'.format(search_time.sum))

    # Final selection; its duration is added to the accumulated search time.
    start_time = time.time()
    best_arch, best_acc = search_find_best(valid_loader, network,
                                           xargs.select_num)
    search_time.update(time.time() - start_time)
    logger.log(
        'RANDOM-NAS finds the best one : {:} with accuracy={:.2f}%, with {:.1f} s.'
        .format(best_arch, best_acc, search_time.sum))
    if api is not None:
        logger.log('{:}'.format(api.query_by_arch(best_arch)))
    logger.close()