def evaluate_all_datasets(
    channels: Text,
    datasets: List[Text],
    xpaths: List[Text],
    splits: List[Text],
    config_path: Text,
    seed: int,
    workers: int,
    logger,
    genotype: Text = "|nor_conv_3x3~0|+|nor_conv_3x3~0|nor_conv_3x3~1|+|skip_connect~0|nor_conv_3x3~1|nor_conv_3x3~2|",
):
    """Train and evaluate one fixed-genotype "infer.shape.tiny" network on each dataset.

    For every ``(dataset, xpath, split)`` triple this builds a training loader and a
    dict of named evaluation loaders, then hands them to ``bench_evaluate_for_seed``
    together with a network config built from ``channels`` and ``genotype``.

    Args:
        channels: forwarded unchanged as ``channels`` to ``dict2config``.
        datasets: dataset names; each must be ``cifar10``, ``cifar100`` or start
            with ``ImageNet16`` (only ``ImageNet16-120`` is fully supported).
        xpaths: data location for each dataset, passed to ``get_datasets``.
        splits: per-dataset flag tested with ``bool(split)``. When true, the
            cifar10 training set is divided into train / "x-valid" subsets using
            ``configs/nas-benchmark/cifar-split.txt``; only allowed for cifar10.
            NOTE(review): because of ``bool()``, a string such as "0" counts as
            true -- callers presumably pass ints despite the annotation; confirm.
        config_path: training config, loaded with ``class_num``/``xshape`` extras.
        seed: forwarded to ``bench_evaluate_for_seed``.
        workers: ``num_workers`` used for every DataLoader.
        logger: object with a ``log(str)`` method; also forwarded to
            ``load_config`` and ``bench_evaluate_for_seed``.
        genotype: cell genotype string. The default is the architecture the
            original code hard-coded (arch-index 9930), noted there as the one
            with the highest accuracy on the CIFAR-100 validation set.

    Returns:
        dict with ``"info"`` (machine info), one entry per dataset key
        (``<dataset>``, or ``<dataset>-valid`` when split) holding the evaluator's
        results, and ``"all_dataset_keys"`` listing those keys in order.

    Raises:
        ValueError: if the three lists differ in length, a dataset is unsupported,
            a split is requested for a dataset other than cifar10, or the split
            indices do not cover the cifar10 training set exactly.
    """
    # zip() would silently drop trailing entries of the longer lists; fail loudly.
    datasets, xpaths, splits = list(datasets), list(xpaths), list(splits)
    if not len(datasets) == len(xpaths) == len(splits):
        raise ValueError(
            "datasets, xpaths and splits must have the same length : {:} vs {:} vs {:}".format(
                len(datasets), len(xpaths), len(splits)
            )
        )

    subset_sampler = torch.utils.data.sampler.SubsetRandomSampler

    def make_loader(data, batch_size, **kwargs):
        # Every loader in this function shares the worker count and pinned memory.
        return torch.utils.data.DataLoader(
            data, batch_size=batch_size, num_workers=workers, pin_memory=True, **kwargs
        )

    # Index files defining the "x-valid" / "x-test" subsets of ``valid_data``.
    test_split_files = {
        "cifar100": "configs/nas-benchmark/cifar100-test-split.txt",
        "ImageNet16-120": "configs/nas-benchmark/imagenet-16-120-test-split.txt",
    }

    all_infos = {"info": get_machine_info()}
    all_dataset_keys = []
    for dataset, xpath, split in zip(datasets, xpaths, splits):
        train_data, valid_data, xshape, class_num = get_datasets(dataset, xpath, -1)
        if not (dataset in ("cifar10", "cifar100") or dataset.startswith("ImageNet16")):
            raise ValueError("invalid dataset : {:}".format(dataset))
        config = load_config(config_path, dict(class_num=class_num, xshape=xshape), logger)
        batch_size = config.batch_size

        if bool(split):
            # An explicit check (not assert) so it still runs under `python -O`.
            if dataset != "cifar10":
                raise ValueError(
                    "the split validation set is only available for cifar10, got : {:}".format(dataset)
                )
            split_info = load_config("configs/nas-benchmark/cifar-split.txt", None, None)
            val_loaders = {"ori-test": make_loader(valid_data, batch_size, shuffle=False)}
            if len(train_data) != len(split_info.train) + len(split_info.valid):
                raise ValueError(
                    "invalid length : {:} vs {:} + {:}".format(
                        len(train_data), len(split_info.train), len(split_info.valid)
                    )
                )
            # The x-valid images are drawn from the training set, but are read
            # through a copy that uses valid_data's transform instead.
            train_data_v2 = deepcopy(train_data)
            train_data_v2.transform = valid_data.transform
            valid_data = train_data_v2
            train_loader = make_loader(
                train_data, batch_size, sampler=subset_sampler(split_info.train)
            )
            valid_loader = make_loader(
                valid_data, batch_size, sampler=subset_sampler(split_info.valid)
            )
            val_loaders["x-valid"] = valid_loader
        else:
            train_loader = make_loader(train_data, batch_size, shuffle=True)
            valid_loader = make_loader(valid_data, batch_size, shuffle=False)
            if dataset == "cifar10":
                val_loaders = {"ori-test": valid_loader}
            elif dataset in test_split_files:
                test_splits = load_config(test_split_files[dataset], None, None)
                val_loaders = {
                    "ori-test": valid_loader,
                    "x-valid": make_loader(
                        valid_data, batch_size, sampler=subset_sampler(test_splits.xvalid)
                    ),
                    "x-test": make_loader(
                        valid_data, batch_size, sampler=subset_sampler(test_splits.xtest)
                    ),
                }
            else:
                raise ValueError("invalid dataset : {:}".format(dataset))

        dataset_key = "{:}".format(dataset)
        if bool(split):
            dataset_key = dataset_key + "-valid"
        logger.log(
            "Evaluate ||||||| {:10s} ||||||| Train-Num={:}, Valid-Num={:}, Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}".format(
                dataset_key,
                len(train_data),
                len(valid_data),
                len(train_loader),
                len(valid_loader),
                config.batch_size,
            )
        )
        logger.log("Evaluate ||||||| {:10s} ||||||| Config={:}".format(dataset_key, config))
        for key, value in val_loaders.items():
            logger.log("Evaluate ---->>>> {:10s} with {:} batchs".format(key, len(value)))

        arch_config = dict2config(
            dict(
                name="infer.shape.tiny",
                channels=channels,
                genotype=genotype,
                num_classes=class_num,
            ),
            None,
        )
        results = bench_evaluate_for_seed(
            arch_config, config, train_loader, val_loaders, seed, logger
        )
        all_infos[dataset_key] = results
        all_dataset_keys.append(dataset_key)
    all_infos["all_dataset_keys"] = all_dataset_keys
    return all_infos
def evaluate_all_datasets(
    arch: Text,
    datasets: List[Text],
    xpaths: List[Text],
    splits: List[Text],
    config_path: Text,
    seed: int,
    raw_arch_config,
    workers,
    logger,
):
    """Train and evaluate the "infer.tiny" network with genotype ``arch`` on each dataset.

    For every ``(dataset, xpath, split)`` triple this builds a training loader and a
    dict of named evaluation loaders, then hands them to ``bench_evaluate_for_seed``
    together with a network config built from ``arch`` and ``raw_arch_config``.

    Args:
        arch: genotype string, forwarded as ``genotype`` to ``dict2config``.
        datasets: dataset names; each must be ``cifar10``, ``cifar100`` or start
            with ``ImageNet16`` (only ``ImageNet16-120`` is fully supported).
        xpaths: data location for each dataset, passed to ``get_datasets``.
        splits: per-dataset flag tested with ``bool(split)``. When true, the
            cifar10 training set is divided into train / "x-valid" subsets using
            ``configs/nas-benchmark/cifar-split.txt``; only allowed for cifar10.
            NOTE(review): because of ``bool()``, a string such as "0" counts as
            true -- callers presumably pass ints despite the annotation; confirm.
        config_path: training config, loaded with ``class_num``/``xshape`` extras.
        seed: forwarded to ``bench_evaluate_for_seed``.
        raw_arch_config: mapping providing ``"channel"`` (used as ``C``) and
            ``"num_cells"`` (used as ``N``); a deep copy is taken up front.
        workers: ``num_workers`` used for every DataLoader.
        logger: object with a ``log(str)`` method; also forwarded to
            ``load_config`` and ``bench_evaluate_for_seed``.

    Returns:
        dict with ``"info"`` (machine info), one entry per dataset key
        (``<dataset>``, or ``<dataset>-valid`` when split) holding the evaluator's
        results, and ``"all_dataset_keys"`` listing those keys in order.

    Raises:
        ValueError: if the three lists differ in length, a dataset is unsupported,
            a split is requested for a dataset other than cifar10, or the split
            indices do not cover the cifar10 training set exactly.
    """
    # zip() would silently drop trailing entries of the longer lists; fail loudly.
    datasets, xpaths, splits = list(datasets), list(xpaths), list(splits)
    if not len(datasets) == len(xpaths) == len(splits):
        raise ValueError(
            "datasets, xpaths and splits must have the same length : {:} vs {:} vs {:}".format(
                len(datasets), len(xpaths), len(splits)
            )
        )

    machine_info, raw_arch_config = get_machine_info(), deepcopy(raw_arch_config)
    subset_sampler = torch.utils.data.sampler.SubsetRandomSampler

    def make_loader(data, batch_size, **kwargs):
        # Every loader in this function shares the worker count and pinned memory.
        return torch.utils.data.DataLoader(
            data, batch_size=batch_size, num_workers=workers, pin_memory=True, **kwargs
        )

    # Index files defining the "x-valid" / "x-test" subsets of ``valid_data``.
    test_split_files = {
        "cifar100": "configs/nas-benchmark/cifar100-test-split.txt",
        "ImageNet16-120": "configs/nas-benchmark/imagenet-16-120-test-split.txt",
    }

    all_infos = {"info": machine_info}
    all_dataset_keys = []
    for dataset, xpath, split in zip(datasets, xpaths, splits):
        train_data, valid_data, xshape, class_num = get_datasets(dataset, xpath, -1)
        if not (dataset in ("cifar10", "cifar100") or dataset.startswith("ImageNet16")):
            raise ValueError("invalid dataset : {:}".format(dataset))
        config = load_config(config_path, dict(class_num=class_num, xshape=xshape), logger)
        batch_size = config.batch_size

        if bool(split):
            # An explicit check (not assert) so it still runs under `python -O`.
            if dataset != "cifar10":
                raise ValueError(
                    "the split validation set is only available for cifar10, got : {:}".format(dataset)
                )
            split_info = load_config("configs/nas-benchmark/cifar-split.txt", None, None)
            val_loaders = {"ori-test": make_loader(valid_data, batch_size, shuffle=False)}
            if len(train_data) != len(split_info.train) + len(split_info.valid):
                raise ValueError(
                    "invalid length : {:} vs {:} + {:}".format(
                        len(train_data), len(split_info.train), len(split_info.valid)
                    )
                )
            # The x-valid images are drawn from the training set, but are read
            # through a copy that uses valid_data's transform instead.
            train_data_v2 = deepcopy(train_data)
            train_data_v2.transform = valid_data.transform
            valid_data = train_data_v2
            train_loader = make_loader(
                train_data, batch_size, sampler=subset_sampler(split_info.train)
            )
            valid_loader = make_loader(
                valid_data, batch_size, sampler=subset_sampler(split_info.valid)
            )
            val_loaders["x-valid"] = valid_loader
        else:
            train_loader = make_loader(train_data, batch_size, shuffle=True)
            valid_loader = make_loader(valid_data, batch_size, shuffle=False)
            if dataset == "cifar10":
                val_loaders = {"ori-test": valid_loader}
            elif dataset in test_split_files:
                test_splits = load_config(test_split_files[dataset], None, None)
                val_loaders = {
                    "ori-test": valid_loader,
                    "x-valid": make_loader(
                        valid_data, batch_size, sampler=subset_sampler(test_splits.xvalid)
                    ),
                    "x-test": make_loader(
                        valid_data, batch_size, sampler=subset_sampler(test_splits.xtest)
                    ),
                }
            else:
                raise ValueError("invalid dataset : {:}".format(dataset))

        dataset_key = "{:}".format(dataset)
        if bool(split):
            dataset_key = dataset_key + "-valid"
        logger.log(
            "Evaluate ||||||| {:10s} ||||||| Train-Num={:}, Valid-Num={:}, Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}".format(
                dataset_key,
                len(train_data),
                len(valid_data),
                len(train_loader),
                len(valid_loader),
                config.batch_size,
            )
        )
        logger.log("Evaluate ||||||| {:10s} ||||||| Config={:}".format(dataset_key, config))
        for key, value in val_loaders.items():
            logger.log("Evaluate ---->>>> {:10s} with {:} batchs".format(key, len(value)))

        arch_config = dict2config(
            dict(
                name="infer.tiny",
                C=raw_arch_config["channel"],
                N=raw_arch_config["num_cells"],
                genotype=arch,
                num_classes=config.class_num,
            ),
            None,
        )
        results = bench_evaluate_for_seed(
            arch_config, config, train_loader, val_loaders, seed, logger
        )
        all_infos[dataset_key] = results
        all_dataset_keys.append(dataset_key)
    all_infos["all_dataset_keys"] = all_dataset_keys
    return all_infos