Exemplo n.º 1
0
def account_one_arch(arch_index, arch_str, checkpoints, datasets,
                     dataloader_dict):
    """Collect the per-seed results of one architecture from its checkpoints.

    Args:
      arch_index: index of the architecture; passed to ``ArchResults`` and
        used in error messages.
      arch_str: string form of the architecture; copied into each arch config.
      checkpoints: iterable of checkpoint paths (``pathlib.Path`` or ``str``).
        The seed is taken from the file name: the last ``-``-separated token,
        cut at its first ``.`` (e.g. ``...-0777.pth`` -> ``'0777'``).
      datasets: dataset keys that every checkpoint must contain.
      dataloader_dict: forwarded unchanged to ``create_result_count``.

    Returns:
      ``ArchResults(arch_index, arch_str)`` updated with one entry per
      (dataset, seed) pair.

    Raises:
      AssertionError: if a checkpoint lacks one of ``datasets`` or its results
        are not marked ``'finish-train'``.
    """
    from pathlib import Path

    information = ArchResults(arch_index, arch_str)

    for checkpoint_path in checkpoints:
        # NOTE: torch.load unpickles the file; only use trusted checkpoints.
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        # Wrapping in Path lets plain string paths work as well as Path objects.
        used_seed = Path(checkpoint_path).name.split('-')[-1].split('.')[0]
        for dataset in datasets:
            # Explicit raises instead of `assert`, so these checks still run
            # under `python -O`; the exception type is unchanged.
            if dataset not in checkpoint:
                raise AssertionError(
                    'Can not find {:} in arch-{:} from {:}'.format(
                        dataset, arch_index, checkpoint_path))
            results = checkpoint[dataset]
            if not results['finish-train']:
                raise AssertionError(
                    'This {:} arch seed={:} does not finish train on {:} ::: {:}'.format(
                        arch_index, used_seed, dataset, checkpoint_path))
            arch_config = {
                'channel': results['channel'],
                'num_cells': results['num_cells'],
                'arch_str': arch_str,
                'class_num': results['config']['class_num']
            }

            # The seed is passed as text here but stored as an int below.
            xresult = create_result_count(used_seed, dataset, arch_config,
                                          results, dataloader_dict)
            information.update(dataset, int(used_seed), xresult)
    return information
Exemplo n.º 2
0
def account_one_arch(arch_index, arch_str, checkpoints, datasets,
                     dataloader_dict):
    """Collect the per-seed results of one architecture from its checkpoints.

    Args:
      arch_index: index of the architecture; passed to ``ArchResults`` and
        used in error messages.
      arch_str: string form of the architecture; copied into each arch config.
      checkpoints: iterable of checkpoint paths (``pathlib.Path`` or ``str``).
        The seed is taken from the file name: the last ``-``-separated token,
        cut at its first ``.`` (e.g. ``...-0777.pth`` -> ``"0777"``).
      datasets: dataset keys that every checkpoint must contain.
      dataloader_dict: forwarded unchanged to ``create_result_count``.

    Returns:
      ``ArchResults(arch_index, arch_str)`` updated with one entry per
      (dataset, seed) pair.

    Raises:
      AssertionError: if a checkpoint lacks one of ``datasets`` or its results
        are not marked ``"finish-train"``.
    """
    from pathlib import Path

    information = ArchResults(arch_index, arch_str)

    for checkpoint_path in checkpoints:
        # NOTE: torch.load unpickles the file; only use trusted checkpoints.
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        # Wrapping in Path lets plain string paths work as well as Path objects.
        used_seed = Path(checkpoint_path).name.split("-")[-1].split(".")[0]
        for dataset in datasets:
            # Explicit raises instead of `assert`, so these checks still run
            # under `python -O`; the exception type is unchanged.
            if dataset not in checkpoint:
                raise AssertionError(
                    "Can not find {:} in arch-{:} from {:}".format(
                        dataset, arch_index, checkpoint_path))
            results = checkpoint[dataset]
            if not results["finish-train"]:
                raise AssertionError(
                    "This {:} arch seed={:} does not finish train on {:} ::: {:}".format(
                        arch_index, used_seed, dataset, checkpoint_path))
            arch_config = {
                "channel": results["channel"],
                "num_cells": results["num_cells"],
                "arch_str": arch_str,
                "class_num": results["config"]["class_num"],
            }

            # The seed is passed as text here but stored as an int below.
            xresult = create_result_count(used_seed, dataset, arch_config,
                                          results, dataloader_dict)
            information.update(dataset, int(used_seed), xresult)
    return information
Exemplo n.º 3
0
def account_one_arch(arch_index: int, arch_str: Text, checkpoints: List[Text],
                     datasets: List[Text],
                     dataloader_dict: Dict[Text, Any]) -> ArchResults:
    """Collect the per-seed results of one architecture from its checkpoints.

    Datasets missing from a checkpoint are reported with ``print`` and
    skipped. A checkpoint that contains none of ``datasets`` is an error.

    Args:
      arch_index: index of the architecture; passed to ``ArchResults`` and
        used in messages.
      arch_str: string form of the architecture; copied into each arch config.
      checkpoints: checkpoint paths, as ``str`` (per the annotation) or
        ``pathlib.Path``. The seed is taken from the file name: the last
        ``-``-separated token, cut at its first ``.``.
      datasets: dataset keys to look up in every checkpoint.
      dataloader_dict: forwarded unchanged to ``create_result_count``.

    Returns:
      ``ArchResults(arch_index, arch_str)`` updated with one entry per
      (dataset, seed) pair that was found.

    Raises:
      AssertionError: if a found dataset is not marked ``'finish-train'``.
      ValueError: if a checkpoint contains none of ``datasets``.
    """
    from pathlib import Path

    information = ArchResults(arch_index, arch_str)

    for checkpoint_path in checkpoints:
        # NOTE: torch.load unpickles the file; only use trusted checkpoints.
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        # `str` has no `.name`; wrapping in Path makes the declared
        # List[Text] input work, and Path inputs keep working.
        used_seed = Path(checkpoint_path).name.split('-')[-1].split('.')[0]
        ok_dataset = 0
        for dataset in datasets:
            if dataset not in checkpoint:
                print('Can not find {:} in arch-{:} from {:}'.format(
                    dataset, arch_index, checkpoint_path))
                continue
            ok_dataset += 1
            results = checkpoint[dataset]
            # Explicit raise instead of `assert`, so the check survives
            # `python -O`; the exception type is unchanged.
            if not results['finish-train']:
                raise AssertionError(
                    'This {:} arch seed={:} does not finish train on {:} ::: {:}'.format(
                        arch_index, used_seed, dataset, checkpoint_path))
            arch_config = {
                'channel': results['channel'],
                'num_cells': results['num_cells'],
                'arch_str': arch_str,
                'class_num': results['config']['class_num']
            }

            xresult = create_result_count(used_seed, dataset, arch_config,
                                          results, dataloader_dict)
            information.update(dataset, int(used_seed), xresult)
        if ok_dataset == 0:
            raise ValueError(
                '{:} does not find any data'.format(checkpoint_path))
    return information
Exemplo n.º 4
0
def account_one_arch(arch_index: int, arch_str: Text, checkpoints: List[Text],
                     datasets: List[Text]) -> ArchResults:
    """Collect the per-seed results of one architecture from its checkpoints.

    Each found dataset entry is turned into a ``ResultsCount`` with training
    and evaluation info copied from the checkpoint. Missing datasets are
    reported with ``print``. Every checkpoint must still contain all of
    ``datasets``.

    Args:
      arch_index: index of the architecture; passed to ``ArchResults`` and
        used in messages.
      arch_str: string form of the architecture; stored as both ``'channels'``
        and ``'arch_str'`` in each arch config.
      checkpoints: checkpoint paths, as ``str`` (per the annotation) or
        ``pathlib.Path``. The seed is taken from the file name: the last
        ``-``-separated token, cut at its first ``.``.
      datasets: dataset keys that every checkpoint must contain.

    Returns:
      ``ArchResults(arch_index, arch_str)`` updated with one entry per
      (dataset, seed) pair.

    Raises:
      ValueError: if a checkpoint cannot be loaded (the original error is
        chained as the cause), or if it lacks any of ``datasets``.
      AssertionError: if a found dataset is not marked ``'finish-train'``.
    """
    from pathlib import Path

    information = ArchResults(arch_index, arch_str)

    for checkpoint_path in checkpoints:
        # Catch Exception rather than using a bare except, so
        # KeyboardInterrupt/SystemExit are not turned into ValueError.
        # Chain the original error so the real cause stays visible.
        # NOTE: torch.load unpickles the file; only use trusted checkpoints.
        try:
            checkpoint = torch.load(checkpoint_path, map_location='cpu')
        except Exception as err:
            raise ValueError(
                'This checkpoint failed to be loaded : {:}'.format(
                    checkpoint_path)) from err
        # `str` has no `.name`; wrapping in Path makes the declared
        # List[Text] input work, and Path inputs keep working.
        used_seed = Path(checkpoint_path).name.split('-')[-1].split('.')[0]
        ok_dataset = 0
        for dataset in datasets:
            if dataset not in checkpoint:
                print('Can not find {:} in arch-{:} from {:}'.format(
                    dataset, arch_index, checkpoint_path))
                continue
            ok_dataset += 1
            results = checkpoint[dataset]
            # Explicit raise instead of `assert`, so the check survives
            # `python -O`; the exception type is unchanged.
            if not results['finish-train']:
                raise AssertionError(
                    'This {:} arch seed={:} does not finish train on {:} ::: {:}'.format(
                        arch_index, used_seed, dataset, checkpoint_path))
            arch_config = {
                'name': 'infer.shape.tiny',
                'channels': arch_str,
                'arch_str': arch_str,
                'genotype': results['arch_config']['genotype'],
                'class_num': results['arch_config']['num_classes']
            }
            xresult = ResultsCount(dataset, results['net_state_dict'],
                                   results['train_acc1es'],
                                   results['train_losses'], results['param'],
                                   results['flop'], arch_config, used_seed,
                                   results['total_epoch'], None)
            xresult.update_train_info(results['train_acc1es'],
                                      results['train_acc5es'],
                                      results['train_losses'],
                                      results['train_times'])
            xresult.update_eval(results['valid_acc1es'],
                                results['valid_losses'],
                                results['valid_times'])
            information.update(dataset, int(used_seed), xresult)
        if ok_dataset < len(datasets):
            raise ValueError(
                '{:} does not find enough data : {:} vs {:}'.format(
                    checkpoint_path, ok_dataset, len(datasets)))
    return information
Exemplo n.º 5
0
def account_one_arch(
    arch_index: int,
    arch_str: Text,
    checkpoints: List[Text],
    datasets: List[Text],
    dataloader_dict: Dict[Text, Any],
) -> ArchResults:
    """Collect the per-seed results of one architecture from its checkpoints.

    Datasets missing from a checkpoint are reported with ``print`` and
    skipped. A checkpoint that contains none of ``datasets`` is an error.

    Args:
      arch_index: index of the architecture; passed to ``ArchResults`` and
        used in messages.
      arch_str: string form of the architecture; copied into each arch config.
      checkpoints: checkpoint paths, as ``str`` (per the annotation) or
        ``pathlib.Path``. The seed is taken from the file name: the last
        ``-``-separated token, cut at its first ``.``.
      datasets: dataset keys to look up in every checkpoint.
      dataloader_dict: forwarded unchanged to ``create_result_count``.

    Returns:
      ``ArchResults(arch_index, arch_str)`` updated with one entry per
      (dataset, seed) pair that was found.

    Raises:
      AssertionError: if a found dataset is not marked ``"finish-train"``.
      ValueError: if a checkpoint contains none of ``datasets``.
    """
    from pathlib import Path

    information = ArchResults(arch_index, arch_str)

    for checkpoint_path in checkpoints:
        # NOTE: torch.load unpickles the file; only use trusted checkpoints.
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        # `str` has no `.name`; wrapping in Path makes the declared
        # List[Text] input work, and Path inputs keep working.
        used_seed = Path(checkpoint_path).name.split("-")[-1].split(".")[0]
        ok_dataset = 0
        for dataset in datasets:
            if dataset not in checkpoint:
                print("Can not find {:} in arch-{:} from {:}".format(
                    dataset, arch_index, checkpoint_path))
                continue
            ok_dataset += 1
            results = checkpoint[dataset]
            # Explicit raise instead of `assert`, so the check survives
            # `python -O`; the exception type is unchanged.
            if not results["finish-train"]:
                raise AssertionError(
                    "This {:} arch seed={:} does not finish train on {:} ::: {:}".format(
                        arch_index, used_seed, dataset, checkpoint_path))
            arch_config = {
                "channel": results["channel"],
                "num_cells": results["num_cells"],
                "arch_str": arch_str,
                "class_num": results["config"]["class_num"],
            }

            xresult = create_result_count(used_seed, dataset, arch_config,
                                          results, dataloader_dict)
            information.update(dataset, int(used_seed), xresult)
        if ok_dataset == 0:
            raise ValueError(
                "{:} does not find any data".format(checkpoint_path))
    return information