Code example #1
def account_one_arch(arch_index, arch_str, checkpoints, datasets,
                     dataloader_dict):
    information = ArchResults(arch_index, arch_str)

    for checkpoint_path in checkpoints:
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        used_seed = checkpoint_path.name.split("-")[-1].split(".")[0]
        for dataset in datasets:
            assert (
                dataset
                in checkpoint), "Can not find {:} in arch-{:} from {:}".format(
                    dataset, arch_index, checkpoint_path)
            results = checkpoint[dataset]
            assert results[
                "finish-train"], "This {:} arch seed={:} does not finish train on {:} ::: {:}".format(
                    arch_index, used_seed, dataset, checkpoint_path)
            arch_config = {
                "channel": results["channel"],
                "num_cells": results["num_cells"],
                "arch_str": arch_str,
                "class_num": results["config"]["class_num"],
            }

            xresult = create_result_count(used_seed, dataset, arch_config,
                                          results, dataloader_dict)
            information.update(dataset, int(used_seed), xresult)
    return information
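Examples #1–#5 show this checkpoint-aggregation routine at different revisions of the NAS-Bench-201 build scripts (quoting style, missing-dataset tolerance, and a search-space variant differ). A minimal invocation sketch follows; the architecture string, the checkpoint directory, and its file-name pattern are hypothetical, and dataloader_dict is assumed to be built elsewhere by the surrounding script.

# Minimal usage sketch. The paths, arch string, and file-name pattern below
# are hypothetical; account_one_arch and dataloader_dict come from the
# surrounding NAS-Bench-201 build script.
from pathlib import Path

arch_index = 157
arch_str = '|nor_conv_3x3~0|+|skip_connect~0|nor_conv_1x1~1|+|skip_connect~0|none~1|avg_pool_3x3~2|'
# The seed parsing above ("-"[-1], then "."[0]) expects names like arch-000157-seed-777.pth
checkpoints = sorted(Path('output/checkpoints').glob('arch-000157-seed-*.pth'))
datasets = ['cifar10-valid', 'cifar10', 'cifar100', 'ImageNet16-120']

info = account_one_arch(arch_index, arch_str, checkpoints, datasets,
                        dataloader_dict)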
Code example #2
def account_one_arch(arch_index, arch_str, checkpoints, datasets,
                     dataloader_dict):
    information = ArchResults(arch_index, arch_str)

    for checkpoint_path in checkpoints:
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        used_seed = checkpoint_path.name.split('-')[-1].split('.')[0]
        for dataset in datasets:
            assert dataset in checkpoint, 'Can not find {:} in arch-{:} from {:}'.format(
                dataset, arch_index, checkpoint_path)
            results = checkpoint[dataset]
            assert results[
                'finish-train'], 'This {:} arch seed={:} does not finish train on {:} ::: {:}'.format(
                    arch_index, used_seed, dataset, checkpoint_path)
            arch_config = {
                'channel': results['channel'],
                'num_cells': results['num_cells'],
                'arch_str': arch_str,
                'class_num': results['config']['class_num']
            }

            xresult = create_result_count(used_seed, dataset, arch_config,
                                          results, dataloader_dict)
            information.update(dataset, int(used_seed), xresult)
    return information
Code example #3
def account_one_arch(arch_index: int, arch_str: Text, checkpoints: List[Text],
                     datasets: List[Text],
                     dataloader_dict: Dict[Text, Any]) -> ArchResults:
    information = ArchResults(arch_index, arch_str)

    for checkpoint_path in checkpoints:
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        used_seed = checkpoint_path.name.split('-')[-1].split('.')[0]
        ok_dataset = 0
        for dataset in datasets:
            if dataset not in checkpoint:
                print('Can not find {:} in arch-{:} from {:}'.format(
                    dataset, arch_index, checkpoint_path))
                continue
            else:
                ok_dataset += 1
            results = checkpoint[dataset]
            assert results[
                'finish-train'], 'This {:} arch seed={:} does not finish train on {:} ::: {:}'.format(
                    arch_index, used_seed, dataset, checkpoint_path)
            arch_config = {
                'channel': results['channel'],
                'num_cells': results['num_cells'],
                'arch_str': arch_str,
                'class_num': results['config']['class_num']
            }

            xresult = create_result_count(used_seed, dataset, arch_config,
                                          results, dataloader_dict)
            information.update(dataset, int(used_seed), xresult)
        if ok_dataset == 0:
            raise ValueError(
                '{:} does not find any data'.format(checkpoint_path))
    return information
Code example #4
def account_one_arch(arch_index: int, arch_str: Text, checkpoints: List[Text],
                     datasets: List[Text]) -> ArchResults:
    information = ArchResults(arch_index, arch_str)

    for checkpoint_path in checkpoints:
        try:
            checkpoint = torch.load(checkpoint_path, map_location='cpu')
        except Exception as exc:
            raise ValueError(
                'This checkpoint failed to be loaded : {:}'.format(
                    checkpoint_path)) from exc
        used_seed = checkpoint_path.name.split('-')[-1].split('.')[0]
        ok_dataset = 0
        for dataset in datasets:
            if dataset not in checkpoint:
                print('Can not find {:} in arch-{:} from {:}'.format(
                    dataset, arch_index, checkpoint_path))
                continue
            else:
                ok_dataset += 1
            results = checkpoint[dataset]
            assert results[
                'finish-train'], 'This {:} arch seed={:} does not finish train on {:} ::: {:}'.format(
                    arch_index, used_seed, dataset, checkpoint_path)
            arch_config = {
                'name': 'infer.shape.tiny',
                'channels': arch_str,
                'arch_str': arch_str,
                'genotype': results['arch_config']['genotype'],
                'class_num': results['arch_config']['num_classes']
            }
            xresult = ResultsCount(dataset, results['net_state_dict'],
                                   results['train_acc1es'],
                                   results['train_losses'], results['param'],
                                   results['flop'], arch_config, used_seed,
                                   results['total_epoch'], None)
            xresult.update_train_info(results['train_acc1es'],
                                      results['train_acc5es'],
                                      results['train_losses'],
                                      results['train_times'])
            xresult.update_eval(results['valid_acc1es'],
                                results['valid_losses'],
                                results['valid_times'])
            information.update(dataset, int(used_seed), xresult)
        if ok_dataset < len(datasets):
            raise ValueError('{:} did not find enough data : {:} vs {:}'.format(
                checkpoint_path, ok_dataset, len(datasets)))
    return information
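Example #4 targets a different search space: it drops dataloader_dict, reads the genotype from results['arch_config'], and builds a ResultsCount directly instead of going through create_result_count. As a rough orientation, a per-dataset checkpoint entry consumed by this variant might look like the sketch below; the keys mirror exactly what the code reads, but every value is illustrative, not taken from a real checkpoint.

# Hypothetical per-dataset checkpoint entry for example #4. Keys mirror what
# the code reads; all values are illustrative only.
results = {
    'finish-train': True,
    'arch_config': {'genotype': '64:64:64', 'num_classes': 10},
    'net_state_dict': {},                        # model weights, omitted here
    'train_acc1es': {0: 31.2, 1: 45.8},          # epoch -> top-1 accuracy
    'train_acc5es': {0: 80.1, 1: 88.4},          # epoch -> top-5 accuracy
    'train_losses': {0: 1.92, 1: 1.51},
    'train_times': {0: 21.3, 1: 21.1},           # seconds per epoch
    'valid_acc1es': {'x-valid@0': 30.5, 'x-valid@1': 44.9},
    'valid_losses': {'x-valid@0': 1.95, 'x-valid@1': 1.55},
    'valid_times': {'x-valid@0': 2.1, 'x-valid@1': 2.1},
    'param': 0.37,                               # millions of parameters
    'flop': 55.0,                                # MFLOPs
    'total_epoch': 200,
}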
Code example #5
def account_one_arch(
    arch_index: int,
    arch_str: Text,
    checkpoints: List[Text],
    datasets: List[Text],
    dataloader_dict: Dict[Text, Any],
) -> ArchResults:
    information = ArchResults(arch_index, arch_str)

    for checkpoint_path in checkpoints:
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        used_seed = checkpoint_path.name.split("-")[-1].split(".")[0]
        ok_dataset = 0
        for dataset in datasets:
            if dataset not in checkpoint:
                print("Can not find {:} in arch-{:} from {:}".format(
                    dataset, arch_index, checkpoint_path))
                continue
            else:
                ok_dataset += 1
            results = checkpoint[dataset]
            assert results[
                "finish-train"], "This {:} arch seed={:} does not finish train on {:} ::: {:}".format(
                    arch_index, used_seed, dataset, checkpoint_path)
            arch_config = {
                "channel": results["channel"],
                "num_cells": results["num_cells"],
                "arch_str": arch_str,
                "class_num": results["config"]["class_num"],
            }

            xresult = create_result_count(used_seed, dataset, arch_config,
                                          results, dataloader_dict)
            information.update(dataset, int(used_seed), xresult)
        if ok_dataset == 0:
            raise ValueError(
                "{:} does not find any data".format(checkpoint_path))
    return information
Code example #6
def test_nas_api():
    from nas_201_api import ArchResults
    xdata = torch.load('/home/dxy/FOR-RELEASE/NAS-Projects/output/NAS-BENCH-201-4/simplifies/architectures/000157-FULL.pth')
    for key in ['full', 'less']:
        print('\n------------------------- {:} -------------------------'.format(key))
        archRes = ArchResults.create_from_state_dict(xdata[key])
        print(archRes)
        print(archRes.arch_idx_str())
        print(archRes.get_dataset_names())
        print(archRes.get_comput_costs('cifar10-valid'))
        # get the metrics
        print(archRes.get_metrics('cifar10-valid', 'x-valid', None, False))
        print(archRes.get_metrics('cifar10-valid', 'x-valid', None, True))
        print(archRes.query('cifar10-valid', 777))
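Example #6 reads an intermediate artifact from a hardcoded path. For the published benchmark, similar queries can be issued through NASBench201API instead; a sketch, assuming nas_201_api is installed and the release file has been downloaded (the file name below is the v1.1 release and may differ for other versions):

# Sketch of equivalent queries against the released benchmark file
# (assumed setup: nas_201_api installed, .pth release downloaded).
from nas_201_api import NASBench201API as API

api = API('NAS-Bench-201-v1_1-096897.pth')
arch_results = api.query_by_index(157)   # ArchResults for architecture 157
print(arch_results)
print(arch_results.get_dataset_names())
print(arch_results.get_metrics('cifar10-valid', 'x-valid', None, False))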
Code example #7
def correct_time_related_info(arch_index: int, arch_info_full: ArchResults,
                              arch_info_less: ArchResults):
    # calibrate the latency based on NAS-Bench-201-v1_0-e61699.pth
    cifar010_latency = (api.get_latency(arch_index, "cifar10-valid", hp="200")
                        + api.get_latency(arch_index, "cifar10", hp="200")) / 2
    arch_info_full.reset_latency("cifar10-valid", None, cifar010_latency)
    arch_info_full.reset_latency("cifar10", None, cifar010_latency)
    arch_info_less.reset_latency("cifar10-valid", None, cifar010_latency)
    arch_info_less.reset_latency("cifar10", None, cifar010_latency)

    cifar100_latency = api.get_latency(arch_index, "cifar100", hp="200")
    arch_info_full.reset_latency("cifar100", None, cifar100_latency)
    arch_info_less.reset_latency("cifar100", None, cifar100_latency)

    image_latency = api.get_latency(arch_index, "ImageNet16-120", hp="200")
    arch_info_full.reset_latency("ImageNet16-120", None, image_latency)
    arch_info_less.reset_latency("ImageNet16-120", None, image_latency)

    train_per_epoch_time = list(
        arch_info_less.query("cifar10-valid", 777).train_times.values())
    train_per_epoch_time = sum(train_per_epoch_time) / len(
        train_per_epoch_time)
    eval_ori_test_time, eval_x_valid_time = [], []
    for key, value in arch_info_less.query("cifar10-valid",
                                           777).eval_times.items():
        if key.startswith("ori-test@"):
            eval_ori_test_time.append(value)
        elif key.startswith("x-valid@"):
            eval_x_valid_time.append(value)
        else:
            raise ValueError("-- {:} --".format(key))
    eval_ori_test_time, eval_x_valid_time = float(
        np.mean(eval_ori_test_time)), float(np.mean(eval_x_valid_time))
    nums = {
        "ImageNet16-120-train": 151700,
        "ImageNet16-120-valid": 3000,
        "ImageNet16-120-test": 6000,
        "cifar10-valid-train": 25000,
        "cifar10-valid-valid": 25000,
        "cifar10-train": 50000,
        "cifar10-test": 10000,
        "cifar100-train": 50000,
        "cifar100-test": 10000,
        "cifar100-valid": 5000,
    }
    eval_per_sample = (eval_ori_test_time + eval_x_valid_time) / (
        nums["cifar10-valid-valid"] + nums["cifar10-test"])
    for arch_info in [arch_info_less, arch_info_full]:
        arch_info.reset_pseudo_train_times(
            "cifar10-valid",
            None,
            train_per_epoch_time / nums["cifar10-valid-train"] *
            nums["cifar10-valid-train"],
        )
        arch_info.reset_pseudo_train_times(
            "cifar10",
            None,
            train_per_epoch_time / nums["cifar10-valid-train"] *
            nums["cifar10-train"],
        )
        arch_info.reset_pseudo_train_times(
            "cifar100",
            None,
            train_per_epoch_time / nums["cifar10-valid-train"] *
            nums["cifar100-train"],
        )
        arch_info.reset_pseudo_train_times(
            "ImageNet16-120",
            None,
            train_per_epoch_time / nums["cifar10-valid-train"] *
            nums["ImageNet16-120-train"],
        )
        arch_info.reset_pseudo_eval_times(
            "cifar10-valid",
            None,
            "x-valid",
            eval_per_sample * nums["cifar10-valid-valid"],
        )
        arch_info.reset_pseudo_eval_times(
            "cifar10-valid", None, "ori-test",
            eval_per_sample * nums["cifar10-test"])
        arch_info.reset_pseudo_eval_times(
            "cifar10", None, "ori-test",
            eval_per_sample * nums["cifar10-test"])
        arch_info.reset_pseudo_eval_times(
            "cifar100", None, "x-valid",
            eval_per_sample * nums["cifar100-valid"])
        arch_info.reset_pseudo_eval_times(
            "cifar100", None, "x-test",
            eval_per_sample * nums["cifar100-valid"])
        arch_info.reset_pseudo_eval_times(
            "cifar100", None, "ori-test",
            eval_per_sample * nums["cifar100-test"])
        arch_info.reset_pseudo_eval_times(
            "ImageNet16-120",
            None,
            "x-valid",
            eval_per_sample * nums["ImageNet16-120-valid"],
        )
        arch_info.reset_pseudo_eval_times(
            "ImageNet16-120",
            None,
            "x-test",
            eval_per_sample * nums["ImageNet16-120-valid"],
        )
        arch_info.reset_pseudo_eval_times(
            "ImageNet16-120",
            None,
            "ori-test",
            eval_per_sample * nums["ImageNet16-120-test"],
        )
    # arch_info_full.debug_test()
    # arch_info_less.debug_test()
    return arch_info_full, arch_info_less
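The rescaling above works in two steps. The measured per-epoch training time on cifar10-valid (25,000 training images) is converted to a per-image rate and multiplied by each dataset's training-set size; the two measured evaluation times are pooled into a per-sample rate over the 35,000 evaluated images (25,000 x-valid plus 10,000 ori-test) and multiplied by each split's size. A worked example, where the three measured times are hypothetical and the counts come from the nums table above:

# Worked numbers for the rescaling (the three measured times are hypothetical).
train_per_epoch_time = 20.0   # seconds per epoch on cifar10-valid (25k images)
eval_x_valid_time = 10.0      # mean seconds to evaluate x-valid (25k images)
eval_ori_test_time = 4.0      # mean seconds to evaluate ori-test (10k images)

per_image_train = train_per_epoch_time / 25000
eval_per_sample = (eval_ori_test_time + eval_x_valid_time) / (25000 + 10000)

print(per_image_train * 50000)    # pseudo per-epoch train time on cifar10
print(per_image_train * 151700)   # pseudo per-epoch train time on ImageNet16-120
print(eval_per_sample * 10000)    # pseudo ori-test eval time on cifar10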
Code example #8
def correct_time_related_info(arch_index: int, arch_info_full: ArchResults,
                              arch_info_less: ArchResults):
    # calibrate the latency based on NAS-Bench-201-v1_0-e61699.pth
    cifar010_latency = (api.get_latency(arch_index, 'cifar10-valid', hp='200')
                        + api.get_latency(arch_index, 'cifar10', hp='200')) / 2
    arch_info_full.reset_latency('cifar10-valid', None, cifar010_latency)
    arch_info_full.reset_latency('cifar10', None, cifar010_latency)
    arch_info_less.reset_latency('cifar10-valid', None, cifar010_latency)
    arch_info_less.reset_latency('cifar10', None, cifar010_latency)

    cifar100_latency = api.get_latency(arch_index, 'cifar100', hp='200')
    arch_info_full.reset_latency('cifar100', None, cifar100_latency)
    arch_info_less.reset_latency('cifar100', None, cifar100_latency)

    image_latency = api.get_latency(arch_index, 'ImageNet16-120', hp='200')
    arch_info_full.reset_latency('ImageNet16-120', None, image_latency)
    arch_info_less.reset_latency('ImageNet16-120', None, image_latency)

    train_per_epoch_time = list(
        arch_info_less.query('cifar10-valid', 777).train_times.values())
    train_per_epoch_time = sum(train_per_epoch_time) / len(
        train_per_epoch_time)
    eval_ori_test_time, eval_x_valid_time = [], []
    for key, value in arch_info_less.query('cifar10-valid',
                                           777).eval_times.items():
        if key.startswith('ori-test@'):
            eval_ori_test_time.append(value)
        elif key.startswith('x-valid@'):
            eval_x_valid_time.append(value)
        else:
            raise ValueError('-- {:} --'.format(key))
    eval_ori_test_time, eval_x_valid_time = float(
        np.mean(eval_ori_test_time)), float(np.mean(eval_x_valid_time))
    nums = {
        'ImageNet16-120-train': 151700,
        'ImageNet16-120-valid': 3000,
        'ImageNet16-120-test': 6000,
        'cifar10-valid-train': 25000,
        'cifar10-valid-valid': 25000,
        'cifar10-train': 50000,
        'cifar10-test': 10000,
        'cifar100-train': 50000,
        'cifar100-test': 10000,
        'cifar100-valid': 5000
    }
    eval_per_sample = (eval_ori_test_time + eval_x_valid_time) / (
        nums['cifar10-valid-valid'] + nums['cifar10-test'])
    for arch_info in [arch_info_less, arch_info_full]:
        arch_info.reset_pseudo_train_times(
            'cifar10-valid', None, train_per_epoch_time /
            nums['cifar10-valid-train'] * nums['cifar10-valid-train'])
        arch_info.reset_pseudo_train_times(
            'cifar10', None, train_per_epoch_time /
            nums['cifar10-valid-train'] * nums['cifar10-train'])
        arch_info.reset_pseudo_train_times(
            'cifar100', None, train_per_epoch_time /
            nums['cifar10-valid-train'] * nums['cifar100-train'])
        arch_info.reset_pseudo_train_times(
            'ImageNet16-120', None, train_per_epoch_time /
            nums['cifar10-valid-train'] * nums['ImageNet16-120-train'])
        arch_info.reset_pseudo_eval_times(
            'cifar10-valid', None, 'x-valid',
            eval_per_sample * nums['cifar10-valid-valid'])
        arch_info.reset_pseudo_eval_times(
            'cifar10-valid', None, 'ori-test',
            eval_per_sample * nums['cifar10-test'])
        arch_info.reset_pseudo_eval_times(
            'cifar10', None, 'ori-test',
            eval_per_sample * nums['cifar10-test'])
        arch_info.reset_pseudo_eval_times(
            'cifar100', None, 'x-valid',
            eval_per_sample * nums['cifar100-valid'])
        arch_info.reset_pseudo_eval_times(
            'cifar100', None, 'x-test',
            eval_per_sample * nums['cifar100-valid'])
        arch_info.reset_pseudo_eval_times(
            'cifar100', None, 'ori-test',
            eval_per_sample * nums['cifar100-test'])
        arch_info.reset_pseudo_eval_times(
            'ImageNet16-120', None, 'x-valid',
            eval_per_sample * nums['ImageNet16-120-valid'])
        arch_info.reset_pseudo_eval_times(
            'ImageNet16-120', None, 'x-test',
            eval_per_sample * nums['ImageNet16-120-valid'])
        arch_info.reset_pseudo_eval_times(
            'ImageNet16-120', None, 'ori-test',
            eval_per_sample * nums['ImageNet16-120-test'])
    # arch_info_full.debug_test()
    # arch_info_less.debug_test()
    return arch_info_full, arch_info_less