def correct_time_related_info(arch_index: int, arch_info_full: ArchResults, arch_info_less: ArchResults):
    """Overwrite latency and pseudo train/eval times on two ``ArchResults`` objects.

    Latencies are taken from the module-level ``api`` (``hp="200"``).  Training
    and evaluation times are re-estimated from the ``cifar10-valid`` / seed 777
    record stored in ``arch_info_less``: they are converted into a per-sample
    cost and then scaled by the size of each dataset split, so every dataset
    gets times derived from the same reference measurement.

    NOTE(review): this function is defined twice in this file; the later
    definition shadows this one at import time.

    Args:
        arch_index: architecture index forwarded to ``api.get_latency``.
        arch_info_full: results object updated in place.
        arch_info_less: results object updated in place; its
            ``cifar10-valid`` / seed 777 record is the timing source.

    Returns:
        The tuple ``(arch_info_full, arch_info_less)`` (the same objects, mutated).

    Raises:
        ValueError: if the reference record has no training times, lacks
            ``ori-test@`` or ``x-valid@`` evaluation times, or contains an
            evaluation key with any other prefix.
    """
    # Number of images in each dataset split, used to scale per-sample costs.
    nums = {
        "ImageNet16-120-train": 151700,
        "ImageNet16-120-valid": 3000,
        "ImageNet16-120-test": 6000,
        "cifar10-valid-train": 25000,
        "cifar10-valid-valid": 25000,
        "cifar10-train": 50000,
        "cifar10-test": 10000,
        "cifar100-train": 50000,
        "cifar100-test": 10000,
        "cifar100-valid": 5000,
    }

    # calibrate the latency based on NAS-Bench-201-v1_0-e61699.pth
    # cifar10-valid and cifar10 both receive the average of their two latencies.
    cifar10_latency = (
        api.get_latency(arch_index, "cifar10-valid", hp="200")
        + api.get_latency(arch_index, "cifar10", hp="200")
    ) / 2
    latencies = {
        "cifar10-valid": cifar10_latency,
        "cifar10": cifar10_latency,
        "cifar100": api.get_latency(arch_index, "cifar100", hp="200"),
        "ImageNet16-120": api.get_latency(arch_index, "ImageNet16-120", hp="200"),
    }
    for arch_info in (arch_info_full, arch_info_less):
        for dataset, latency in latencies.items():
            arch_info.reset_latency(dataset, None, latency)

    # Query the reference run once and reuse it for both train and eval times.
    reference = arch_info_less.query("cifar10-valid", 777)

    epoch_times = list(reference.train_times.values())
    if not epoch_times:
        raise ValueError("no training times recorded for cifar10-valid, seed 777")
    train_per_epoch_time = sum(epoch_times) / len(epoch_times)
    # Per-sample training cost: the reference run trains on the cifar10-valid train split.
    train_per_sample = train_per_epoch_time / nums["cifar10-valid-train"]

    eval_ori_test_time, eval_x_valid_time = [], []
    for key, value in reference.eval_times.items():
        if key.startswith("ori-test@"):
            eval_ori_test_time.append(value)
        elif key.startswith("x-valid@"):
            eval_x_valid_time.append(value)
        else:
            raise ValueError("-- {:} --".format(key))
    if not eval_ori_test_time or not eval_x_valid_time:
        # np.mean([]) would silently yield nan and poison every pseudo eval time.
        raise ValueError("missing ori-test@ or x-valid@ eval times for cifar10-valid, seed 777")
    # Mean ori-test and x-valid eval times cover cifar10-test + cifar10-valid-valid images.
    eval_per_sample = (
        float(np.mean(eval_ori_test_time)) + float(np.mean(eval_x_valid_time))
    ) / (nums["cifar10-valid-valid"] + nums["cifar10-test"])

    train_sizes = {
        "cifar10-valid": nums["cifar10-valid-train"],
        "cifar10": nums["cifar10-train"],
        "cifar100": nums["cifar100-train"],
        "ImageNet16-120": nums["ImageNet16-120-train"],
    }
    # (dataset, eval split, number of evaluated images).  x-test has no entry of
    # its own in ``nums``; it reuses the -valid size (x-valid + x-test sizes add
    # up to the corresponding -test size).
    eval_sizes = (
        ("cifar10-valid", "x-valid", nums["cifar10-valid-valid"]),
        ("cifar10-valid", "ori-test", nums["cifar10-test"]),
        ("cifar10", "ori-test", nums["cifar10-test"]),
        ("cifar100", "x-valid", nums["cifar100-valid"]),
        ("cifar100", "x-test", nums["cifar100-valid"]),
        ("cifar100", "ori-test", nums["cifar100-test"]),
        ("ImageNet16-120", "x-valid", nums["ImageNet16-120-valid"]),
        ("ImageNet16-120", "x-test", nums["ImageNet16-120-valid"]),
        ("ImageNet16-120", "ori-test", nums["ImageNet16-120-test"]),
    )
    for arch_info in (arch_info_less, arch_info_full):
        for dataset, n_train in train_sizes.items():
            arch_info.reset_pseudo_train_times(dataset, None, train_per_sample * n_train)
        for dataset, split, n_eval in eval_sizes:
            arch_info.reset_pseudo_eval_times(dataset, None, split, eval_per_sample * n_eval)
    return arch_info_full, arch_info_less
def correct_time_related_info(arch_index: int, arch_info_full: ArchResults, arch_info_less: ArchResults):
    """Overwrite latency and pseudo train/eval times on two ``ArchResults`` objects.

    Latencies are taken from the module-level ``api`` (``hp='200'``).  Training
    and evaluation times are re-estimated from the ``cifar10-valid`` / seed 777
    record stored in ``arch_info_less``: they are converted into a per-sample
    cost and then scaled by the size of each dataset split, so every dataset
    gets times derived from the same reference measurement.

    Args:
        arch_index: architecture index forwarded to ``api.get_latency``.
        arch_info_full: results object updated in place.
        arch_info_less: results object updated in place; its
            ``cifar10-valid`` / seed 777 record is the timing source.

    Returns:
        The tuple ``(arch_info_full, arch_info_less)`` (the same objects, mutated).

    Raises:
        ValueError: if the reference record has no training times, lacks
            ``ori-test@`` or ``x-valid@`` evaluation times, or contains an
            evaluation key with any other prefix.
    """
    # Number of images in each dataset split, used to scale per-sample costs.
    nums = {
        'ImageNet16-120-train': 151700,
        'ImageNet16-120-valid': 3000,
        'ImageNet16-120-test': 6000,
        'cifar10-valid-train': 25000,
        'cifar10-valid-valid': 25000,
        'cifar10-train': 50000,
        'cifar10-test': 10000,
        'cifar100-train': 50000,
        'cifar100-test': 10000,
        'cifar100-valid': 5000,
    }

    # calibrate the latency based on NAS-Bench-201-v1_0-e61699.pth
    # cifar10-valid and cifar10 both receive the average of their two latencies.
    cifar10_latency = (
        api.get_latency(arch_index, 'cifar10-valid', hp='200')
        + api.get_latency(arch_index, 'cifar10', hp='200')
    ) / 2
    latencies = {
        'cifar10-valid': cifar10_latency,
        'cifar10': cifar10_latency,
        'cifar100': api.get_latency(arch_index, 'cifar100', hp='200'),
        'ImageNet16-120': api.get_latency(arch_index, 'ImageNet16-120', hp='200'),
    }
    for arch_info in (arch_info_full, arch_info_less):
        for dataset, latency in latencies.items():
            arch_info.reset_latency(dataset, None, latency)

    # Query the reference run once and reuse it for both train and eval times.
    reference = arch_info_less.query('cifar10-valid', 777)

    epoch_times = list(reference.train_times.values())
    if not epoch_times:
        raise ValueError('no training times recorded for cifar10-valid, seed 777')
    train_per_epoch_time = sum(epoch_times) / len(epoch_times)
    # Per-sample training cost: the reference run trains on the cifar10-valid train split.
    train_per_sample = train_per_epoch_time / nums['cifar10-valid-train']

    eval_ori_test_time, eval_x_valid_time = [], []
    for key, value in reference.eval_times.items():
        if key.startswith('ori-test@'):
            eval_ori_test_time.append(value)
        elif key.startswith('x-valid@'):
            eval_x_valid_time.append(value)
        else:
            raise ValueError('-- {:} --'.format(key))
    if not eval_ori_test_time or not eval_x_valid_time:
        # np.mean([]) would silently yield nan and poison every pseudo eval time.
        raise ValueError('missing ori-test@ or x-valid@ eval times for cifar10-valid, seed 777')
    # Mean ori-test and x-valid eval times cover cifar10-test + cifar10-valid-valid images.
    eval_per_sample = (
        float(np.mean(eval_ori_test_time)) + float(np.mean(eval_x_valid_time))
    ) / (nums['cifar10-valid-valid'] + nums['cifar10-test'])

    train_sizes = {
        'cifar10-valid': nums['cifar10-valid-train'],
        'cifar10': nums['cifar10-train'],
        'cifar100': nums['cifar100-train'],
        'ImageNet16-120': nums['ImageNet16-120-train'],
    }
    # (dataset, eval split, number of evaluated images).  x-test has no entry of
    # its own in ``nums``; it reuses the -valid size (x-valid + x-test sizes add
    # up to the corresponding -test size).
    eval_sizes = (
        ('cifar10-valid', 'x-valid', nums['cifar10-valid-valid']),
        ('cifar10-valid', 'ori-test', nums['cifar10-test']),
        ('cifar10', 'ori-test', nums['cifar10-test']),
        ('cifar100', 'x-valid', nums['cifar100-valid']),
        ('cifar100', 'x-test', nums['cifar100-valid']),
        ('cifar100', 'ori-test', nums['cifar100-test']),
        ('ImageNet16-120', 'x-valid', nums['ImageNet16-120-valid']),
        ('ImageNet16-120', 'x-test', nums['ImageNet16-120-valid']),
        ('ImageNet16-120', 'ori-test', nums['ImageNet16-120-test']),
    )
    for arch_info in (arch_info_less, arch_info_full):
        for dataset, n_train in train_sizes.items():
            arch_info.reset_pseudo_train_times(dataset, None, train_per_sample * n_train)
        for dataset, split, n_eval in eval_sizes:
            arch_info.reset_pseudo_eval_times(dataset, None, split, eval_per_sample * n_eval)
    return arch_info_full, arch_info_less