def simplify(save_dir, save_name, nets, total, sup_config):
    """Consolidate raw per-architecture checkpoints into benchmark pickle files.

    Scans ``save_dir / 'raw-data-{hp}'`` (hp in {'12', '200'}) for files named
    ``arch-{index:06d}-seed-{seed}.pth``, accounts each architecture via
    ``account_one_arch``, and writes two per-arch pickle trees plus a merged
    benchmark file that is finally renamed with its md5 digest.

    Args:
        save_dir: root ``pathlib.Path`` holding the ``raw-data-{hp}`` folders.
        save_name: base name used for the generated benchmark files/dirs.
        nets: list of architecture strings; list position defines the arch index.
        total: expected number of architectures; must equal ``len(nets)``.
        sup_config: unused in this function body — kept for interface
            compatibility with callers (NOTE(review): confirm it is needed).

    Raises:
        ValueError: if no checkpoint exists for some (index, hp) pair.
    """
    dataloader_dict = get_nas_bench_loaders(6)
    hps, seeds = ['12', '200'], set()
    # First pass: discover every seed that was used, and sanity-check that each
    # seed covers all `total` architectures.
    for hp in hps:
        sub_save_dir = save_dir / 'raw-data-{:}'.format(hp)
        ckps = sorted(list(sub_save_dir.glob('arch-*-seed-*.pth')))
        seed2names = defaultdict(list)
        for ckp in ckps:
            # Checkpoint names look like arch-XXXXXX-seed-YYY.pth; after
            # splitting on '-' and '.', parts[3] is the seed token.
            # Raw string is required: '\.' in a plain string is an invalid
            # escape sequence (SyntaxWarning on modern Python).
            parts = re.split(r'-|\.', ckp.name)
            seed2names[parts[3]].append(ckp.name)
        print('DIR : {:}'.format(sub_save_dir))
        nums = []
        for seed, xlist in seed2names.items():
            seeds.add(seed)
            nums.append(len(xlist))
            print(' [seed={:}] there are {:} checkpoints.'.format(seed, len(xlist)))
        assert len(nets) == total == max(nums), 'there are some missed files : {:} vs {:}'.format(max(nums), total)
    print('{:} start simplify the checkpoint.'.format(time_string()))
    datasets = ('cifar10-valid', 'cifar10', 'cifar100', 'ImageNet16-120')
    # Create the directory to save the processed data
    # full_save_dir contains all benchmark files with trained weights.
    # simplify_save_dir contains all benchmark files without trained weights.
    full_save_dir = save_dir / (save_name + '-FULL')
    simple_save_dir = save_dir / (save_name + '-SIMPLIFY')
    full_save_dir.mkdir(parents=True, exist_ok=True)
    simple_save_dir.mkdir(parents=True, exist_ok=True)
    # all data in memory
    arch2infos, evaluated_indexes = dict(), set()
    end_time, arch_time = time.time(), AverageMeter()
    # save the meta information (arch2infos filled in later for the merged file)
    temp_final_infos = {'meta_archs': nets,
                        'total_archs': total,
                        'arch2infos': None,
                        'evaluated_indexes': set()}
    pickle_save(temp_final_infos, str(full_save_dir / 'meta.pickle'))
    pickle_save(temp_final_infos, str(simple_save_dir / 'meta.pickle'))
    for index in tqdm(range(total)):
        arch_str = nets[index]
        hp2info = OrderedDict()
        full_save_path = full_save_dir / '{:06d}.pickle'.format(index)
        simple_save_path = simple_save_dir / '{:06d}.pickle'.format(index)
        for hp in hps:
            sub_save_dir = save_dir / 'raw-data-{:}'.format(hp)
            ckps = [sub_save_dir / 'arch-{:06d}-seed-{:}.pth'.format(index, seed) for seed in seeds]
            ckps = [x for x in ckps if x.exists()]
            if len(ckps) == 0:
                raise ValueError('Invalid data : index={:}, hp={:}'.format(index, hp))
            arch_info = account_one_arch(index, arch_str, ckps, datasets, dataloader_dict)
            hp2info[hp] = arch_info
        # correct_time_related_info may return new objects — reassign.
        hp2info = correct_time_related_info(index, hp2info)
        evaluated_indexes.add(index)
        # Full dump keeps trained weights.
        to_save_data = OrderedDict({'12': hp2info['12'].state_dict(),
                                    '200': hp2info['200'].state_dict()})
        pickle_save(to_save_data, str(full_save_path))
        # Strip weights, then re-serialize for the simplified dump.
        for hp in hps:
            hp2info[hp].clear_params()
        to_save_data = OrderedDict({'12': hp2info['12'].state_dict(),
                                    '200': hp2info['200'].state_dict()})
        pickle_save(to_save_data, str(simple_save_path))
        arch2infos[index] = to_save_data
        # measure elapsed time
        arch_time.update(time.time() - end_time)
        end_time = time.time()
        need_time = '{:}'.format(convert_secs2time(arch_time.avg * (total-index-1), True))
        # print('{:} {:06d}/{:06d} : still need {:}'.format(time_string(), index, total, need_time))
    print('{:} {:} done.'.format(time_string(), save_name))
    final_infos = {'meta_archs': nets,
                   'total_archs': total,
                   'arch2infos': arch2infos,
                   'evaluated_indexes': evaluated_indexes}
    save_file_name = save_dir / '{:}.pickle'.format(save_name)
    pickle_save(final_infos, str(save_file_name))
    # move the benchmark file to a new path that embeds its md5 digest
    # (pickle_save evidently produces a '.pbz2' compressed artifact).
    hd5sum = get_md5_file(str(save_file_name) + '.pbz2')
    hd5_file_name = save_dir / '{:}-{:}.pickle.pbz2'.format(NATS_TSS_BASE_NAME, hd5sum)
    shutil.move(str(save_file_name) + '.pbz2', hd5_file_name)
    print('Save {:} / {:} architecture results into {:} -> {:}.'.format(len(evaluated_indexes), total, save_file_name, hd5_file_name))
    # move the directory to a new path; pass str() for consistency with the
    # other shutil.move call above and for pre-3.9 path-like compatibility.
    hd5_full_save_dir = save_dir / '{:}-{:}-full'.format(NATS_TSS_BASE_NAME, hd5sum)
    hd5_simple_save_dir = save_dir / '{:}-{:}-simple'.format(NATS_TSS_BASE_NAME, hd5sum)
    shutil.move(str(full_save_dir), str(hd5_full_save_dir))
    shutil.move(str(simple_save_dir), str(hd5_simple_save_dir))
def simplify(save_dir, meta_file, basestr, target_dir):
    """Summarize trained-architecture checkpoints for one target directory.

    Loads the meta file, enumerates the ``*-*-{basestr}`` sub-directories under
    ``save_dir``, then builds full/less ``arch_info`` pairs for every
    architecture found in ``target_dir`` and saves both per-arch files and a
    merged ``simplifies/{target_dir}.pth`` summary.

    NOTE(review): this module defines another top-level function also named
    ``simplify`` above; the later definition shadows the earlier one at import
    time — confirm which one callers expect.

    Args:
        save_dir: root ``pathlib.Path`` containing the checkpoint directories.
        meta_file: path to a torch file with keys ``archs`` (list of arch
            strings) and ``total`` (expected arch count).
        basestr: suffix used to glob the checkpoint sub-directories.
        target_dir: name (relative to ``save_dir``) of the directory to process;
            ``{target_dir}-LESS`` is read for the reduced-schedule checkpoints.
    """
    meta_infos = torch.load(meta_file, map_location="cpu")
    meta_archs = meta_infos["archs"]  # a list of architecture strings
    meta_num_archs = meta_infos["total"]
    assert meta_num_archs == len(
        meta_archs), "invalid number of archs : {:} vs {:}".format(
            meta_num_archs, len(meta_archs))
    sub_model_dirs = sorted(list(save_dir.glob("*-*-{:}".format(basestr))))
    print("{:} find {:} directories used to save checkpoints".format(
        time_string(), len(sub_model_dirs)))
    # Map each sub-directory to its sorted arch indexes, and histogram how many
    # seeds each architecture was evaluated with.
    subdir2archs, num_evaluated_arch = collections.OrderedDict(), 0
    num_seeds = defaultdict(lambda: 0)
    for index, sub_dir in enumerate(sub_model_dirs):
        xcheckpoints = list(sub_dir.glob("arch-*-seed-*.pth"))
        arch_indexes = set()
        for checkpoint in xcheckpoints:
            # Expected name: arch-{index}-seed-{seed}.pth
            temp_names = checkpoint.name.split("-")
            assert (len(temp_names) == 4 and temp_names[0] == "arch"
                    and temp_names[2] == "seed"
                    ), "invalid checkpoint name : {:}".format(checkpoint.name)
            arch_indexes.add(temp_names[1])
        subdir2archs[sub_dir] = sorted(list(arch_indexes))
        num_evaluated_arch += len(arch_indexes)
        # count number of seeds for each architecture
        for arch_index in arch_indexes:
            num_seeds[len(
                list(sub_dir.glob(
                    "arch-{:}-seed-*.pth".format(arch_index))))] += 1
    print(
        "{:} There are {:5d} architectures that have been evaluated ({:} in total)."
        .format(time_string(), num_evaluated_arch, meta_num_archs))
    for key in sorted(list(num_seeds.keys())):
        print(
            "{:} There are {:5d} architectures that are evaluated {:} times.".
            format(time_string(), num_seeds[key], key))
    dataloader_dict = get_nas_bench_loaders(6)
    to_save_simply = save_dir / "simplifies"
    to_save_allarc = save_dir / "simplifies" / "architectures"
    if not to_save_simply.exists():
        to_save_simply.mkdir(parents=True, exist_ok=True)
    if not to_save_allarc.exists():
        to_save_allarc.mkdir(parents=True, exist_ok=True)
    assert (save_dir / target_dir) in subdir2archs, "can not find {:}".format(target_dir)
    arch2infos, datasets = {}, (
        "cifar10-valid",
        "cifar10",
        "cifar100",
        "ImageNet16-120",
    )
    evaluated_indexes = set()
    target_full_dir = save_dir / target_dir
    target_less_dir = save_dir / "{:}-LESS".format(target_dir)
    arch_indexes = subdir2archs[target_full_dir]
    num_seeds = defaultdict(lambda: 0)  # reused: now histograms by #checkpoints
    end_time = time.time()
    arch_time = AverageMeter()
    for idx, arch_index in enumerate(arch_indexes):
        checkpoints = list(
            target_full_dir.glob("arch-{:}-seed-*.pth".format(arch_index)))
        ckps_less = list(
            target_less_dir.glob("arch-{:}-seed-*.pth".format(arch_index)))
        # create the arch info for each architecture
        try:
            arch_info_full = account_one_arch(
                arch_index,
                meta_archs[int(arch_index)],
                checkpoints,
                datasets,
                dataloader_dict,
            )
            arch_info_less = account_one_arch(
                arch_index,
                meta_archs[int(arch_index)],
                ckps_less,
                datasets,
                dataloader_dict,
            )
            num_seeds[len(checkpoints)] += 1
        # Narrowed from a bare `except:`, which would also swallow
        # KeyboardInterrupt / SystemExit and make the run un-interruptible.
        except Exception:
            print("Loading {:} failed, : {:}".format(arch_index, checkpoints))
            continue
        assert (int(arch_index) not in evaluated_indexes
                ), "conflict arch-index : {:}".format(arch_index)
        assert (0 <= int(arch_index) < len(meta_archs)
                ), "invalid arch-index {:} (not found in meta_archs)".format(
                    arch_index)
        # to correct the latency and training_time info; it may return new
        # objects, so bind arch_info / arch2infos AFTER the correction (the
        # sibling simplify() above follows the same reassignment pattern).
        arch_info_full, arch_info_less = correct_time_related_info(
            int(arch_index), arch_info_full, arch_info_less)
        arch_info = {"full": arch_info_full, "less": arch_info_less}
        evaluated_indexes.add(int(arch_index))
        arch2infos[int(arch_index)] = arch_info
        to_save_data = OrderedDict(full=arch_info_full.state_dict(),
                                   less=arch_info_less.state_dict())
        torch.save(to_save_data,
                   to_save_allarc / "{:}-FULL.pth".format(arch_index))
        # Strip weights and RE-serialize for the SIMPLE file; previously the
        # same pre-clear `to_save_data` was written twice, so the SIMPLE file
        # could still contain trained weights (the sibling simplify() above
        # rebuilds the state_dict after clear_params).
        arch_info["full"].clear_params()
        arch_info["less"].clear_params()
        to_save_data = OrderedDict(full=arch_info_full.state_dict(),
                                   less=arch_info_less.state_dict())
        torch.save(to_save_data,
                   to_save_allarc / "{:}-SIMPLE.pth".format(arch_index))
        # measure elapsed time
        arch_time.update(time.time() - end_time)
        end_time = time.time()
        need_time = "{:}".format(
            convert_secs2time(arch_time.avg * (len(arch_indexes) - idx - 1),
                              True))
        print("{:} {:} [{:03d}/{:03d}] : {:} still need {:}".format(
            time_string(), target_dir, idx, len(arch_indexes), arch_index,
            need_time))
    # measure time
    xstrs = [
        "{:}:{:03d}".format(key, num_seeds[key])
        for key in sorted(list(num_seeds.keys()))
    ]
    print("{:} {:} done : {:}".format(time_string(), target_dir, xstrs))
    final_infos = {
        "meta_archs": meta_archs,
        "total_archs": meta_num_archs,
        "basestr": basestr,
        "arch2infos": arch2infos,
        "evaluated_indexes": evaluated_indexes,
    }
    save_file_name = to_save_simply / "{:}.pth".format(target_dir)
    torch.save(final_infos, save_file_name)
    print("Save {:} / {:} architecture results into {:}.".format(
        len(evaluated_indexes), meta_num_archs, save_file_name))