示例#1
0
def correct_time_related_info(arch_index: int, arch_info_full: ArchResults,
                              arch_info_less: ArchResults):
    """Overwrite the latency and pseudo train/eval times of two result objects.

    Latencies are read from the module-level ``api`` (hp="200") for
    ``arch_index``.  Train and eval times are derived from the
    ``("cifar10-valid", 777)`` entry of ``arch_info_less`` and rescaled by
    the per-dataset counts in ``nums``.

    Args:
        arch_index: index forwarded to ``api.get_latency``.
        arch_info_full: results object updated in place.
        arch_info_less: results object updated in place; also the source of
            the reference train/eval timings.

    Returns:
        The tuple ``(arch_info_full, arch_info_less)``.

    Raises:
        ValueError: if an eval-time key starts with neither ``"ori-test@"``
            nor ``"x-valid@"``.
    """
    both_full_first = (arch_info_full, arch_info_less)

    # Original note: calibrate the latency based on
    # NAS-Bench-201-v1_0-e61699.pth.
    # Both cifar10 variants share the average of their two latencies.
    shared_cifar10_latency = (
        api.get_latency(arch_index, "cifar10-valid", hp="200")
        + api.get_latency(arch_index, "cifar10", hp="200")
    ) / 2
    for results in both_full_first:
        for dataset in ("cifar10-valid", "cifar10"):
            results.reset_latency(dataset, None, shared_cifar10_latency)

    # The remaining datasets use their own latency unchanged.
    for dataset in ("cifar100", "ImageNet16-120"):
        latency = api.get_latency(arch_index, dataset, hp="200")
        for results in both_full_first:
            results.reset_latency(dataset, None, latency)

    # Mean of the recorded train times of the reference entry.
    epoch_times = list(
        arch_info_less.query("cifar10-valid", 777).train_times.values())
    mean_epoch_time = sum(epoch_times) / len(epoch_times)

    # Bucket the reference eval times by key prefix; prefixes are tried in
    # insertion order and an unmatched key is an error.
    buckets = {"ori-test@": [], "x-valid@": []}
    reference_eval = arch_info_less.query("cifar10-valid", 777).eval_times
    for key, value in reference_eval.items():
        for prefix, bucket in buckets.items():
            if key.startswith(prefix):
                bucket.append(value)
                break
        else:
            raise ValueError("-- {:} --".format(key))
    mean_ori_test = float(np.mean(buckets["ori-test@"]))
    mean_x_valid = float(np.mean(buckets["x-valid@"]))

    # Per-dataset/split counts used as scaling factors (presumably the
    # number of samples in each split).
    nums = {
        "ImageNet16-120-train": 151700,
        "ImageNet16-120-valid": 3000,
        "ImageNet16-120-test": 6000,
        "cifar10-valid-train": 25000,
        "cifar10-valid-valid": 25000,
        "cifar10-train": 50000,
        "cifar10-test": 10000,
        "cifar100-train": 50000,
        "cifar100-test": 10000,
        "cifar100-valid": 5000,
    }
    eval_per_sample = (mean_ori_test + mean_x_valid) / (
        nums["cifar10-valid-valid"] + nums["cifar10-test"])
    train_per_sample = mean_epoch_time / nums["cifar10-valid-train"]

    # (dataset, key into nums) for the pseudo train times.
    train_plan = (
        ("cifar10-valid", "cifar10-valid-train"),
        ("cifar10", "cifar10-train"),
        ("cifar100", "cifar100-train"),
        ("ImageNet16-120", "ImageNet16-120-train"),
    )
    # (dataset, eval split name, key into nums) for the pseudo eval times.
    # The "x-test" rows deliberately reuse the "-valid" counts, exactly as
    # the unrolled calls did.
    eval_plan = (
        ("cifar10-valid", "x-valid", "cifar10-valid-valid"),
        ("cifar10-valid", "ori-test", "cifar10-test"),
        ("cifar10", "ori-test", "cifar10-test"),
        ("cifar100", "x-valid", "cifar100-valid"),
        ("cifar100", "x-test", "cifar100-valid"),
        ("cifar100", "ori-test", "cifar100-test"),
        ("ImageNet16-120", "x-valid", "ImageNet16-120-valid"),
        ("ImageNet16-120", "x-test", "ImageNet16-120-valid"),
        ("ImageNet16-120", "ori-test", "ImageNet16-120-test"),
    )
    for results in (arch_info_less, arch_info_full):
        for dataset, count_key in train_plan:
            results.reset_pseudo_train_times(
                dataset, None, train_per_sample * nums[count_key])
        for dataset, split, count_key in eval_plan:
            results.reset_pseudo_eval_times(
                dataset, None, split, eval_per_sample * nums[count_key])
    return arch_info_full, arch_info_less
示例#2
0
def correct_time_related_info(arch_index: int, arch_info_full: ArchResults,
                              arch_info_less: ArchResults):
    """Reset latency and pseudo train/eval times on both result objects.

    Latency values come from the module-level ``api`` (hp='200').  Timing
    values are taken from the ``('cifar10-valid', 777)`` entry of
    ``arch_info_less``, converted to a per-unit cost and multiplied by the
    counts stored in ``nums``.

    Args:
        arch_index: index passed to ``api.get_latency``.
        arch_info_full: results object modified in place.
        arch_info_less: results object modified in place and queried for the
            reference timings.

    Returns:
        ``(arch_info_full, arch_info_less)``.

    Raises:
        ValueError: when an eval-time key has neither the ``'ori-test@'`` nor
            the ``'x-valid@'`` prefix.
    """
    # Original note: calibrate the latency based on
    # NAS-Bench-201-v1_0-e61699.pth.
    latency_valid = api.get_latency(arch_index, 'cifar10-valid', hp='200')
    latency_plain = api.get_latency(arch_index, 'cifar10', hp='200')
    cifar10_latency = (latency_valid + latency_plain) / 2
    for target in (arch_info_full, arch_info_less):
        target.reset_latency('cifar10-valid', None, cifar10_latency)
        target.reset_latency('cifar10', None, cifar10_latency)

    for name in ('cifar100', 'ImageNet16-120'):
        own_latency = api.get_latency(arch_index, name, hp='200')
        arch_info_full.reset_latency(name, None, own_latency)
        arch_info_less.reset_latency(name, None, own_latency)

    # Average recorded train time of the reference entry.
    reference_train = arch_info_less.query('cifar10-valid', 777)
    recorded = list(reference_train.train_times.values())
    avg_train_time = sum(recorded) / len(recorded)

    # Split the reference eval times by key prefix.
    ori_test_times, x_valid_times = [], []
    reference_eval = arch_info_less.query('cifar10-valid', 777)
    for key, value in reference_eval.eval_times.items():
        if key.startswith('ori-test@'):
            ori_test_times.append(value)
            continue
        if not key.startswith('x-valid@'):
            raise ValueError('-- {:} --'.format(key))
        x_valid_times.append(value)
    avg_ori_test = float(np.mean(ori_test_times))
    avg_x_valid = float(np.mean(x_valid_times))

    # Scaling counts per dataset/split (presumably sample counts).
    nums = {
        'ImageNet16-120-train': 151700,
        'ImageNet16-120-valid': 3000,
        'ImageNet16-120-test': 6000,
        'cifar10-valid-train': 25000,
        'cifar10-valid-valid': 25000,
        'cifar10-train': 50000,
        'cifar10-test': 10000,
        'cifar100-train': 50000,
        'cifar100-test': 10000,
        'cifar100-valid': 5000
    }
    eval_unit = (avg_ori_test + avg_x_valid) / (
        nums['cifar10-valid-valid'] + nums['cifar10-test'])
    train_unit = avg_train_time / nums['cifar10-valid-train']

    # dataset -> key into nums for the train-time scaling; dicts keep
    # insertion order, so the calls are issued in the listed order.
    train_counts = {
        'cifar10-valid': 'cifar10-valid-train',
        'cifar10': 'cifar10-train',
        'cifar100': 'cifar100-train',
        'ImageNet16-120': 'ImageNet16-120-train',
    }
    # (dataset, split, key into nums); 'x-test' rows intentionally use the
    # '-valid' counts, mirroring the explicit calls this table replaces.
    eval_counts = [
        ('cifar10-valid', 'x-valid', 'cifar10-valid-valid'),
        ('cifar10-valid', 'ori-test', 'cifar10-test'),
        ('cifar10', 'ori-test', 'cifar10-test'),
        ('cifar100', 'x-valid', 'cifar100-valid'),
        ('cifar100', 'x-test', 'cifar100-valid'),
        ('cifar100', 'ori-test', 'cifar100-test'),
        ('ImageNet16-120', 'x-valid', 'ImageNet16-120-valid'),
        ('ImageNet16-120', 'x-test', 'ImageNet16-120-valid'),
        ('ImageNet16-120', 'ori-test', 'ImageNet16-120-test'),
    ]
    for target in (arch_info_less, arch_info_full):
        for name, count_key in train_counts.items():
            target.reset_pseudo_train_times(
                name, None, train_unit * nums[count_key])
        for name, split, count_key in eval_counts:
            target.reset_pseudo_eval_times(
                name, None, split, eval_unit * nums[count_key])
    return arch_info_full, arch_info_less