Code Example #1
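This function collects raw NATS-Bench topology-search-space results. For every architecture it merges the per-seed checkpoints trained under the '12'- and '200'-epoch hyperparameter settings, writes one per-architecture benchmark file with trained weights (the '-FULL' directory) and one without weights (the '-SIMPLIFY' directory), and finally renames the merged benchmark file using its MD5 checksum. Helpers such as get_nas_bench_loaders, account_one_arch, correct_time_related_info, pickle_save, AverageMeter, and time_string are provided by the surrounding project and are not shown here.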
def simplify(save_dir, save_name, nets, total, sup_config):
  dataloader_dict = get_nas_bench_loaders(6)
  hps, seeds = ['12', '200'], set()
  for hp in hps:
    sub_save_dir = save_dir / 'raw-data-{:}'.format(hp)
    ckps = sorted(list(sub_save_dir.glob('arch-*-seed-*.pth')))
    seed2names = defaultdict(list)
    for ckp in ckps:
      # 'arch-XXXXXX-seed-YYY.pth' -> ['arch', 'XXXXXX', 'seed', 'YYY', 'pth']
      parts = re.split(r'-|\.', ckp.name)
      seed2names[parts[3]].append(ckp.name)
    print('DIR : {:}'.format(sub_save_dir))
    nums = []
    for seed, xlist in seed2names.items():
      seeds.add(seed)
      nums.append(len(xlist))
      print('  [seed={:}] there are {:} checkpoints.'.format(seed, len(xlist)))
    assert len(nets) == total == max(nums), 'some checkpoint files are missing : {:} vs {:}'.format(max(nums), total)
  print('{:} start simplifying the checkpoints.'.format(time_string()))

  datasets = ('cifar10-valid', 'cifar10', 'cifar100', 'ImageNet16-120')

  # Create the directory to save the processed data
  # full_save_dir contains all benchmark files with trained weights.
  # simplify_save_dir contains all benchmark files without trained weights.
  full_save_dir = save_dir / (save_name + '-FULL')
  simple_save_dir = save_dir / (save_name + '-SIMPLIFY')
  full_save_dir.mkdir(parents=True, exist_ok=True)
  simple_save_dir.mkdir(parents=True, exist_ok=True)
  # all data in memory
  arch2infos, evaluated_indexes = dict(), set()
  end_time, arch_time = time.time(), AverageMeter()
  # save the meta information
  temp_final_infos = {'meta_archs' : nets,
                      'total_archs': total,
                      'arch2infos' : None,
                      'evaluated_indexes': set()}
  pickle_save(temp_final_infos, str(full_save_dir / 'meta.pickle'))
  pickle_save(temp_final_infos, str(simple_save_dir / 'meta.pickle'))

  for index in tqdm(range(total)):
    arch_str = nets[index]
    hp2info = OrderedDict()

    full_save_path = full_save_dir / '{:06d}.pickle'.format(index)
    simple_save_path = simple_save_dir / '{:06d}.pickle'.format(index)
    for hp in hps:
      sub_save_dir = save_dir / 'raw-data-{:}'.format(hp)
      ckps = [sub_save_dir / 'arch-{:06d}-seed-{:}.pth'.format(index, seed) for seed in seeds]
      ckps = [x for x in ckps if x.exists()]
      if len(ckps) == 0:
        raise ValueError('Invalid data : index={:}, hp={:}'.format(index, hp))

      arch_info = account_one_arch(index, arch_str, ckps, datasets, dataloader_dict)
      hp2info[hp] = arch_info
    
    hp2info = correct_time_related_info(index, hp2info)
    evaluated_indexes.add(index)
    
    to_save_data = OrderedDict({'12': hp2info['12'].state_dict(),
                                '200': hp2info['200'].state_dict()})
    pickle_save(to_save_data, str(full_save_path))
    
    # strip the trained weights so the simplified file stays small
    for hp in hps:
      hp2info[hp].clear_params()
    to_save_data = OrderedDict({'12': hp2info['12'].state_dict(),
                                '200': hp2info['200'].state_dict()})
    pickle_save(to_save_data, str(simple_save_path))
    arch2infos[index] = to_save_data
    # measure elapsed time
    arch_time.update(time.time() - end_time)
    end_time  = time.time()
    need_time = '{:}'.format(convert_secs2time(arch_time.avg * (total-index-1), True))
    # print('{:} {:06d}/{:06d} : still need {:}'.format(time_string(), index, total, need_time))
  print('{:} {:} done.'.format(time_string(), save_name))
  final_infos = {'meta_archs' : nets,
                 'total_archs': total,
                 'arch2infos' : arch2infos,
                 'evaluated_indexes': evaluated_indexes}
  save_file_name = save_dir / '{:}.pickle'.format(save_name)
  pickle_save(final_infos, str(save_file_name))
  # move the benchmark file to a new path
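  # NOTE: this assumes pickle_save also wrote a bz2-compressed '.pbz2' copy next
  # to the plain pickle; get_md5_file and NATS_TSS_BASE_NAME are utilities from
  # the surrounding project and are not defined in this snippet.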
  hd5sum = get_md5_file(str(save_file_name) + '.pbz2')
  hd5_file_name = save_dir / '{:}-{:}.pickle.pbz2'.format(NATS_TSS_BASE_NAME, hd5sum)
  shutil.move(str(save_file_name) + '.pbz2', hd5_file_name)
  print('Saved {:} / {:} architecture results into {:} -> {:}.'.format(len(evaluated_indexes), total, save_file_name, hd5_file_name))
  # move the directory to a new path
  hd5_full_save_dir = save_dir / '{:}-{:}-full'.format(NATS_TSS_BASE_NAME, hd5sum)
  hd5_simple_save_dir = save_dir / '{:}-{:}-simple'.format(NATS_TSS_BASE_NAME, hd5sum)
  shutil.move(full_save_dir, hd5_full_save_dir)
  shutil.move(simple_save_dir, hd5_simple_save_dir)
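
For orientation, here is a minimal, hypothetical sketch of how this simplify might be invoked. Every value below is a placeholder rather than something taken from the original script; note also that the sup_config parameter is accepted but never used inside the function body above.

from pathlib import Path

# Hypothetical driver: save_dir must already contain the raw-data-12/ and
# raw-data-200/ sub-directories produced during training.
save_dir = Path('./output/nats-bench-tss')
save_name = 'NATS-tss-v1_0'
nets = [
  '|nor_conv_3x3~0|+|skip_connect~0|nor_conv_1x1~1|+|none~0|avg_pool_3x3~1|skip_connect~2|',
  # ... one string per architecture in the search space
]
simplify(save_dir, save_name, nets, len(nets), sup_config=None)
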
Code Example #2
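This older variant targets NAS-Bench-201-style results. It first scans all '<start>-<end>-<basestr>' checkpoint directories and reports how many seeds each architecture was trained with; it then merges, for one target directory, the full-epoch ('full') and reduced-epoch ('less') results of every architecture into per-architecture .pth files and a single summary file under 'simplifies/'. As in the first example, account_one_arch, correct_time_related_info, and the logging helpers come from the surrounding project.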
def simplify(save_dir, meta_file, basestr, target_dir):
    meta_infos = torch.load(meta_file, map_location="cpu")
    meta_archs = meta_infos["archs"]  # a list of architecture strings
    meta_num_archs = meta_infos["total"]
    assert meta_num_archs == len(
        meta_archs), "invalid number of archs : {:} vs {:}".format(
            meta_num_archs, len(meta_archs))

    sub_model_dirs = sorted(list(save_dir.glob("*-*-{:}".format(basestr))))
    print("{:} find {:} directories used to save checkpoints".format(
        time_string(), len(sub_model_dirs)))

    subdir2archs, num_evaluated_arch = collections.OrderedDict(), 0
    num_seeds = defaultdict(int)
    for index, sub_dir in enumerate(sub_model_dirs):
        xcheckpoints = list(sub_dir.glob("arch-*-seed-*.pth"))
        arch_indexes = set()
        for checkpoint in xcheckpoints:
            temp_names = checkpoint.name.split("-")
            assert (len(temp_names) == 4 and temp_names[0] == "arch"
                    and temp_names[2]
                    == "seed"), "invalid checkpoint name : {:}".format(
                        checkpoint.name)
            arch_indexes.add(temp_names[1])
        subdir2archs[sub_dir] = sorted(list(arch_indexes))
        num_evaluated_arch += len(arch_indexes)
        # count number of seeds for each architecture
        for arch_index in arch_indexes:
            num_seeds[len(
                list(sub_dir.glob(
                    "arch-{:}-seed-*.pth".format(arch_index))))] += 1
    print(
        "{:} There are {:5d} architectures that have been evaluated ({:} in total)."
        .format(time_string(), num_evaluated_arch, meta_num_archs))
    for key in sorted(list(num_seeds.keys())):
        print(
            "{:} There are {:5d} architectures that are evaluated {:} times.".
            format(time_string(), num_seeds[key], key))

    dataloader_dict = get_nas_bench_loaders(6)
    to_save_simply = save_dir / "simplifies"
    to_save_allarc = save_dir / "simplifies" / "architectures"
    to_save_simply.mkdir(parents=True, exist_ok=True)
    to_save_allarc.mkdir(parents=True, exist_ok=True)

    assert (save_dir /
            target_dir) in subdir2archs, "cannot find {:}".format(target_dir)
    arch2infos, datasets = {}, (
        "cifar10-valid",
        "cifar10",
        "cifar100",
        "ImageNet16-120",
    )
    evaluated_indexes = set()
    target_full_dir = save_dir / target_dir
    target_less_dir = save_dir / "{:}-LESS".format(target_dir)
    arch_indexes = subdir2archs[target_full_dir]
    num_seeds = defaultdict(int)
    end_time = time.time()
    arch_time = AverageMeter()
    for idx, arch_index in enumerate(arch_indexes):
        checkpoints = list(
            target_full_dir.glob("arch-{:}-seed-*.pth".format(arch_index)))
        ckps_less = list(
            target_less_dir.glob("arch-{:}-seed-*.pth".format(arch_index)))
        # create the arch info for each architecture
        try:
            arch_info_full = account_one_arch(
                arch_index,
                meta_archs[int(arch_index)],
                checkpoints,
                datasets,
                dataloader_dict,
            )
            arch_info_less = account_one_arch(
                arch_index,
                meta_archs[int(arch_index)],
                ckps_less,
                datasets,
                dataloader_dict,
            )
            num_seeds[len(checkpoints)] += 1
        except Exception as e:
            print("Loading {:} failed ({:}) : {:}".format(
                arch_index, e, checkpoints))
            continue
        assert (int(arch_index) not in evaluated_indexes
                ), "conflict arch-index : {:}".format(arch_index)
        assert (0 <= int(arch_index) < len(meta_archs)
                ), "invalid arch-index {:} (not found in meta_archs)".format(
                    arch_index)
        arch_info = {"full": arch_info_full, "less": arch_info_less}
        evaluated_indexes.add(int(arch_index))
        arch2infos[int(arch_index)] = arch_info
        # to correct the latency and training_time info.
        arch_info_full, arch_info_less = correct_time_related_info(
            int(arch_index), arch_info_full, arch_info_less)
        to_save_data = OrderedDict(full=arch_info_full.state_dict(),
                                   less=arch_info_less.state_dict())
        torch.save(to_save_data,
                   to_save_allarc / "{:}-FULL.pth".format(arch_index))
        arch_info["full"].clear_params()
        arch_info["less"].clear_params()
        torch.save(to_save_data,
                   to_save_allarc / "{:}-SIMPLE.pth".format(arch_index))
        # measure elapsed time
        arch_time.update(time.time() - end_time)
        end_time = time.time()
        need_time = "{:}".format(
            convert_secs2time(arch_time.avg * (len(arch_indexes) - idx - 1),
                              True))
        print("{:} {:} [{:03d}/{:03d}] : {:} still need {:}".format(
            time_string(), target_dir, idx, len(arch_indexes), arch_index,
            need_time))
    # summarize how many seeds were available per architecture
    xstrs = [
        "{:}:{:03d}".format(key, num_seeds[key])
        for key in sorted(list(num_seeds.keys()))
    ]
    print("{:} {:} done : {:}".format(time_string(), target_dir, xstrs))
    final_infos = {
        "meta_archs": meta_archs,
        "total_archs": meta_num_archs,
        "basestr": basestr,
        "arch2infos": arch2infos,
        "evaluated_indexes": evaluated_indexes,
    }
    save_file_name = to_save_simply / "{:}.pth".format(target_dir)
    torch.save(final_infos, save_file_name)
    print("Save {:} / {:} architecture results into {:}.".format(
        len(evaluated_indexes), meta_num_archs, save_file_name))
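
As above, a minimal hypothetical invocation. The directory names follow the '<start>-<end>-<basestr>' pattern that the glob over sub-model directories expects, but all concrete values are placeholders.

from pathlib import Path

# Hypothetical driver: save_dir holds sub-directories named like
# '000000-000389-C16-N5' (full-epoch checkpoints) together with matching
# '...-LESS' directories (reduced-epoch checkpoints); meta_file is a torch
# file providing {'archs': [...], 'total': N}.
save_dir = Path('./output/NAS-BENCH-201-4')
meta_file = save_dir / 'meta-archs.pth'
basestr = 'C16-N5'
target_dir = '000000-000389-C16-N5'
simplify(save_dir, meta_file, basestr, target_dir)

# The merged results can later be reloaded with, e.g.:
#   final_infos = torch.load(save_dir / 'simplifies' / '{:}.pth'.format(target_dir),
#                            map_location='cpu')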