def evaluate_all_datasets(arch, datasets, xpaths, splits, use_less, seed, arch_config, workers, logger):
    machine_info, arch_config = get_machine_info(), deepcopy(arch_config)
    all_infos = {'info': machine_info}
    all_dataset_keys = []
    # loop over all the datasets
    for dataset, xpath, split in zip(datasets, xpaths, splits):
        # train / valid data
        task = None
        train_data, valid_data, xshape, class_num = get_datasets(dataset, xpath, -1, task)
        # load the configuration
        if dataset in ['mnist', 'svhn', 'aircraft', 'pets']:
            if use_less:
                config_path = 'nas_bench_201/configs/nas-benchmark/LESS.config'
            else:
                config_path = 'nas_bench_201/configs/nas-benchmark/{}.config'.format(dataset)
            # create a random half/half train-valid split file if it does not exist yet
            split_path = 'nas_bench_201/configs/nas-benchmark/{:}-split.txt'.format(dataset)
            if not os.path.exists(split_path):
                import json
                label_list = list(range(len(train_data)))
                random.shuffle(label_list)
                strlist = [str(label_list[i]) for i in range(len(label_list))]
                split_dict = {
                    'train': ["int", strlist[:len(train_data) // 2]],
                    'valid': ["int", strlist[len(train_data) // 2:]]
                }
                with open(split_path, 'w') as f:
                    f.write(json.dumps(split_dict))
            split_info = load_config(split_path, None, None)
        else:
            raise ValueError('invalid dataset : {:}'.format(dataset))
        config = load_config(config_path, {'class_num': class_num, 'xshape': xshape}, logger)
        # data loaders
        train_loader = torch.utils.data.DataLoader(
            train_data, batch_size=config.batch_size, shuffle=True,
            num_workers=workers, pin_memory=True)
        valid_loader = torch.utils.data.DataLoader(
            valid_data, batch_size=config.batch_size, shuffle=False,
            num_workers=workers, pin_memory=True)
        # renamed from `splits` to avoid shadowing the function argument
        test_splits = load_config(
            'nas_bench_201/configs/nas-benchmark/{}-test-split.txt'.format(dataset), None, None)
        ValLoaders = {
            'ori-test': valid_loader,
            'x-valid': torch.utils.data.DataLoader(
                valid_data, batch_size=config.batch_size,
                sampler=torch.utils.data.sampler.SubsetRandomSampler(test_splits.xvalid),
                num_workers=workers, pin_memory=True),
            'x-test': torch.utils.data.DataLoader(
                valid_data, batch_size=config.batch_size,
                sampler=torch.utils.data.sampler.SubsetRandomSampler(test_splits.xtest),
                num_workers=workers, pin_memory=True)
        }
        dataset_key = '{:}'.format(dataset)
        if bool(split):
            dataset_key = dataset_key + '-valid'
        logger.log(
            'Evaluate ||||||| {:10s} ||||||| Train-Num={:}, Valid-Num={:}, Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'
            .format(dataset_key, len(train_data), len(valid_data), len(train_loader),
                    len(valid_loader), config.batch_size))
        logger.log('Evaluate ||||||| {:10s} ||||||| Config={:}'.format(dataset_key, config))
        for key, value in ValLoaders.items():
            logger.log('Evaluate ---->>>> {:10s} with {:} batches'.format(key, len(value)))
        results = evaluate_for_seed(arch_config, config, arch, train_loader, ValLoaders, seed, logger)
        all_infos[dataset_key] = results
        all_dataset_keys.append(dataset_key)
    all_infos['all_dataset_keys'] = all_dataset_keys
    return all_infos
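# --- Illustrative sketch (not part of the original code) -----------------------------------
# The branch above writes a '<dataset>-split.txt' file when one is missing: sample indices
# are shuffled, split half/half, stored as strings under an ["int", [...]] wrapper, and
# later read back by `load_config`. The helper below is a minimal, self-contained
# reproduction of that file format; the path and sample count in the example are made up.
def _write_random_half_split(path, num_samples):
    import json
    import random
    indices = list(range(num_samples))
    random.shuffle(indices)
    as_str = [str(i) for i in indices]
    split = {
        'train': ["int", as_str[:num_samples // 2]],
        'valid': ["int", as_str[num_samples // 2:]],
    }
    with open(path, 'w') as f:
        f.write(json.dumps(split))

# Example round-trip (illustrative only):
#   _write_random_half_split('/tmp/aircraft-split.txt', 6667)
#   with open('/tmp/aircraft-split.txt') as f:
#       raw = json.load(f)                      # {'train': ['int', [...]], 'valid': ['int', [...]]}
#   train_idx = [int(i) for i in raw['train'][1]]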
def evaluate_all_datasets(arch, datasets, xpaths, splits, use_less, seed, arch_config, workers, logger):
    machine_info, arch_config = get_machine_info(), deepcopy(arch_config)
    all_infos = {'info': machine_info}
    all_dataset_keys = []
    # loop over all the datasets
    for dataset, xpath, split in zip(datasets, xpaths, splits):
        # train / valid data
        train_data, valid_data, xshape, class_num = get_datasets(dataset, xpath, -1)
        # load the configuration
        if dataset == 'cifar10' or dataset == 'cifar100':
            if use_less:
                config_path = 'configs/nas-benchmark/LESS.config'
            else:
                config_path = 'configs/nas-benchmark/CIFAR.config'
            split_info = load_config('configs/nas-benchmark/cifar-split.txt', None, None)
        elif dataset.startswith('ImageNet16'):
            if use_less:
                config_path = 'configs/nas-benchmark/LESS.config'
            else:
                config_path = 'configs/nas-benchmark/ImageNet-16.config'
            split_info = load_config('configs/nas-benchmark/{:}-split.txt'.format(dataset), None, None)
        else:
            raise ValueError('invalid dataset : {:}'.format(dataset))
        config = load_config(config_path, {'class_num': class_num, 'xshape': xshape}, logger)
        # check whether to use the split validation set
        if bool(split):
            assert dataset == 'cifar10'
            ValLoaders = {
                'ori-test': torch.utils.data.DataLoader(
                    valid_data, batch_size=config.batch_size, shuffle=False,
                    num_workers=workers, pin_memory=True)
            }
            assert len(train_data) == len(split_info.train) + len(split_info.valid), \
                'invalid length : {:} vs {:} + {:}'.format(
                    len(train_data), len(split_info.train), len(split_info.valid))
            # reuse the training set as the validation set, but with the validation transform
            train_data_v2 = deepcopy(train_data)
            train_data_v2.transform = valid_data.transform
            valid_data = train_data_v2
            # data loaders over disjoint index subsets of the training data
            train_loader = torch.utils.data.DataLoader(
                train_data, batch_size=config.batch_size,
                sampler=torch.utils.data.sampler.SubsetRandomSampler(split_info.train),
                num_workers=workers, pin_memory=True)
            valid_loader = torch.utils.data.DataLoader(
                valid_data, batch_size=config.batch_size,
                sampler=torch.utils.data.sampler.SubsetRandomSampler(split_info.valid),
                num_workers=workers, pin_memory=True)
            ValLoaders['x-valid'] = valid_loader
        else:
            # data loaders
            train_loader = torch.utils.data.DataLoader(
                train_data, batch_size=config.batch_size, shuffle=True,
                num_workers=workers, pin_memory=True)
            valid_loader = torch.utils.data.DataLoader(
                valid_data, batch_size=config.batch_size, shuffle=False,
                num_workers=workers, pin_memory=True)
            if dataset == 'cifar10':
                ValLoaders = {'ori-test': valid_loader}
            elif dataset == 'cifar100':
                cifar100_splits = load_config('configs/nas-benchmark/cifar100-test-split.txt', None, None)
                ValLoaders = {
                    'ori-test': valid_loader,
                    'x-valid': torch.utils.data.DataLoader(
                        valid_data, batch_size=config.batch_size,
                        sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar100_splits.xvalid),
                        num_workers=workers, pin_memory=True),
                    'x-test': torch.utils.data.DataLoader(
                        valid_data, batch_size=config.batch_size,
                        sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar100_splits.xtest),
                        num_workers=workers, pin_memory=True)
                }
            elif dataset == 'ImageNet16-120':
                imagenet16_splits = load_config('configs/nas-benchmark/imagenet-16-120-test-split.txt', None, None)
                ValLoaders = {
                    'ori-test': valid_loader,
                    'x-valid': torch.utils.data.DataLoader(
                        valid_data, batch_size=config.batch_size,
                        sampler=torch.utils.data.sampler.SubsetRandomSampler(imagenet16_splits.xvalid),
                        num_workers=workers, pin_memory=True),
                    'x-test': torch.utils.data.DataLoader(
                        valid_data, batch_size=config.batch_size,
                        sampler=torch.utils.data.sampler.SubsetRandomSampler(imagenet16_splits.xtest),
                        num_workers=workers, pin_memory=True)
                }
            else:
                raise ValueError('invalid dataset : {:}'.format(dataset))
        dataset_key = '{:}'.format(dataset)
        if bool(split):
            dataset_key = dataset_key + '-valid'
        logger.log(
            'Evaluate ||||||| {:10s} ||||||| Train-Num={:}, Valid-Num={:}, Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'
            .format(dataset_key, len(train_data), len(valid_data), len(train_loader),
                    len(valid_loader), config.batch_size))
        logger.log('Evaluate ||||||| {:10s} ||||||| Config={:}'.format(dataset_key, config))
        for key, value in ValLoaders.items():
            logger.log('Evaluate ---->>>> {:10s} with {:} batches'.format(key, len(value)))
        results = evaluate_for_seed(arch_config, config, arch, train_loader, ValLoaders, seed, logger)
        all_infos[dataset_key] = results
        all_dataset_keys.append(dataset_key)
    all_infos['all_dataset_keys'] = all_dataset_keys
    return all_infos
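# --- Illustrative sketch (not part of the original code) -----------------------------------
# Both `evaluate_all_datasets` variants build their extra 'x-valid'/'x-test' loaders by
# pointing several DataLoaders at the same dataset through SubsetRandomSampler over disjoint
# index lists. The toy helper below shows that pattern in isolation; it assumes `torch` is
# already imported (as the functions above require), and the dataset, index ranges, and
# batch size in the example are arbitrary.
def _split_loaders(dataset, train_indices, valid_indices, batch_size, workers=0):
    train_loader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(train_indices),
        num_workers=workers, pin_memory=True)
    valid_loader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_indices),
        num_workers=workers, pin_memory=True)
    return train_loader, valid_loader

# Example with a toy dataset (illustrative only):
#   toy = torch.utils.data.TensorDataset(torch.randn(100, 3), torch.randint(0, 10, (100,)))
#   tr_loader, va_loader = _split_loaders(toy, list(range(50)), list(range(50, 100)), batch_size=16)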
def train_and_eval(args, data, logger, arch, nas_bench, extra_info,
                   dataname='cifar10-valid', use_012_epoch_training=True,
                   use_loss_extrapolation=False):
    if use_loss_extrapolation:
        arch_index = nas_bench.query_index_by_arch(arch)
        print('=' * 80)
        print('ARCH INDEX: {0}'.format(arch_index))
        print('=' * 80)
        nepoch = 12
        assert arch_index >= 0, 'can not find this arch : {:}'.format(arch)
        config, train_loader, valid_loader = data
        arch_config = {'channel': args.channel, 'num_cells': args.num_cells}
        t0 = time.time()
        print(config, 'CONFIG')
        results = evaluate_for_seed(
            arch_config, config, arch, train_loader, {'valid': valid_loader},
            seed=random.randint(0, (2**32) - 1), logger=logger)
        key = '{:}@{:}'.format('valid', nepoch - 1)
        valid_acc = results['valid_acc1es'][key]
        time_cost = None
        info = nas_bench.get_more_info(arch_index, dataname, None, True)
        valid_acc2, time_cost2 = info['valid-accuracy'], info['train-all-time'] + info['valid-per-time']
        print('Taken: {0:.2f} s'.format(time.time() - t0))
        print('API time: {0:.2f}'.format(time_cost2))
        print('Validation accuracy after {1} epochs: {0:.2f}'.format(valid_acc, nepoch))
        print('Validation accuracy reference: {0:.2f}'.format(valid_acc2))
    elif use_012_epoch_training and nas_bench is not None:
        arch_index = nas_bench.query_index_by_arch(arch)
        print('=' * 80)
        print('ARCH INDEX: {0}'.format(arch_index))
        print('=' * 80)
        assert arch_index >= 0, 'can not find this arch : {:}'.format(arch)
        info = nas_bench.get_more_info(arch_index, dataname, None, True)
        valid_acc, time_cost = info['valid-accuracy'], info['train-all-time'] + info['valid-per-time']
        # _, valid_acc = info.get_metrics('cifar10-valid', 'x-valid', 25, True)  # use the validation accuracy after 25 training epochs
    elif not use_012_epoch_training and nas_bench is not None:
        # Please contact me if you want to use the following logic, because it has some potential issues.
        # Please use `use_012_epoch_training=False` for cifar10 only.
        # It did return values for cifar100 and ImageNet16-120, but it has some potential issues. (Please email me for more details)
        arch_index, nepoch = nas_bench.query_index_by_arch(arch), 25
        assert arch_index >= 0, 'can not find this arch : {:}'.format(arch)
        # query the 25-epoch statistics here so that `info` is defined before it is used below
        info = nas_bench.get_more_info(arch_index, dataname, nepoch, False, True)
        try:
            valid_acc = info['valid-accuracy']
        except KeyError:
            valid_acc = info['valtest-accuracy']
        time_cost = None
    else:
        # train a model from scratch.
        raise ValueError('NOT IMPLEMENTED YET')
    if time_cost is None:
        # The following code estimates the time cost.
        # When we build NAS-Bench-201, architectures are trained on different machines and we can not use that time record.
        # When we create checkpoints for converged_LR, we run all experiments on 1080Ti, and thus the time for each architecture can be fairly compared.
        nums = {
            'ImageNet16-120-train': 151700, 'ImageNet16-120-valid': 3000,
            'cifar10-valid-train': 25000, 'cifar10-valid-valid': 25000,
            'cifar100-train': 50000, 'cifar100-valid': 5000
        }
        xoinfo = nas_bench.get_more_info(arch_index, 'cifar10-valid', None, True)
        xocost = nas_bench.get_cost_info(arch_index, 'cifar10-valid', False)
        # use the validation accuracy after 25 training epochs, which is used in our ICLR submission (not the camera ready)
        info = nas_bench.get_more_info(arch_index, dataname, nepoch, False, True)
        cost = nas_bench.get_cost_info(arch_index, dataname, False)
        # rescale the cifar10-valid per-epoch times by the dataset-size and latency ratios
        estimated_train_cost = xoinfo['train-per-time'] / nums['cifar10-valid-train'] \
            * nums['{:}-train'.format(dataname)] / xocost['latency'] * cost['latency'] * nepoch
        estimated_valid_cost = xoinfo['valid-per-time'] / nums['cifar10-valid-valid'] \
            * nums['{:}-valid'.format(dataname)] / xocost['latency'] * cost['latency']
        time_cost = estimated_train_cost + estimated_valid_cost
        print('Adjusted time cost: {0:.2f}'.format(time_cost))
    return valid_acc, time_cost
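# --- Illustrative sketch (not part of the original code) -----------------------------------
# The `if time_cost is None` block above extrapolates a wall-clock cost for `dataname` from
# cifar10-valid measurements: per-epoch train/valid time is rescaled by the ratio of dataset
# sizes and by the latency ratio between the two settings, and the train part is multiplied
# by the number of epochs. The helper below restates that formula on plain numbers; the
# values in the example call are made up.
def _estimate_time_cost(train_per_time, valid_per_time,
                        ref_train_num, ref_valid_num,
                        tgt_train_num, tgt_valid_num,
                        ref_latency, tgt_latency, nepoch):
    train_cost = train_per_time / ref_train_num * tgt_train_num / ref_latency * tgt_latency * nepoch
    valid_cost = valid_per_time / ref_valid_num * tgt_valid_num / ref_latency * tgt_latency
    return train_cost + valid_cost

# Example (illustrative numbers only): extrapolate from cifar10-valid to cifar100 for 12 epochs
#   _estimate_time_cost(10.0, 1.0, 25000, 25000, 50000, 5000, 0.012, 0.015, 12)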