예제 #1
0
def simplify(save_dir, save_name, nets, total, sup_config):
  """Merge the raw per-seed checkpoints into per-architecture benchmark files.

  For each hyper-parameter setting in ``['12', '200']`` the checkpoints named
  ``arch-<index>-seed-<seed>.pth`` under ``save_dir / 'raw-data-<hp>'`` are
  counted per seed and validated against ``total``. Every architecture is then
  summarised with ``account_one_arch`` and written twice: with parameters to
  ``<save_name>-FULL/<index>.pickle`` and, after ``clear_params()``, to
  ``<save_name>-SIMPLIFY/<index>.pickle``. Finally one aggregated pickle is
  written, and the file and both directories are renamed to include the md5
  of the aggregated file.

  Args:
    save_dir: path object supporting ``/``, ``glob`` and ``mkdir``
      (presumably a ``pathlib.Path``); root of raw data and outputs.
    save_name: base name (str) used for the output directories and file.
    nets: sequence of architecture strings, indexed by architecture index.
    total: expected number of architectures; must equal ``len(nets)`` and the
      largest per-seed checkpoint count in every raw-data directory.
    sup_config: not used by this function.

  Raises:
    AssertionError: if ``len(nets)``, ``total`` and the largest per-seed
      checkpoint count disagree (including when no checkpoint is found).
    ValueError: if an architecture has no checkpoint for some setting.
  """
  hps, seeds = ['12', '200'], set()
  for hp in hps:
    sub_save_dir = save_dir / 'raw-data-{:}'.format(hp)
    seed2names = defaultdict(list)
    for ckp in sorted(sub_save_dir.glob('arch-*-seed-*.pth')):
      # 'arch-<index>-seed-<seed>.pth' -> ['arch', <index>, 'seed', <seed>, 'pth']
      # Raw string: '\.' in a plain literal is an invalid escape sequence.
      parts = re.split(r'-|\.', ckp.name)
      seed2names[parts[3]].append(ckp.name)
    print('DIR : {:}'.format(sub_save_dir))
    nums = []
    for seed, xlist in seed2names.items():
      seeds.add(seed)
      nums.append(len(xlist))
      print('  [seed={:}] there are {:} checkpoints.'.format(seed, len(xlist)))
    # default=0 so an empty directory trips the assertion below (with its
    # message) instead of max() raising a bare ValueError.
    max_num = max(nums, default=0)
    assert len(nets) == total == max_num, 'there are some missed files : {:} vs {:}'.format(max_num, total)
  # Built only after validation so a bad checkpoint directory fails first.
  dataloader_dict = get_nas_bench_loaders(6)
  print('{:} start simplify the checkpoint.'.format(time_string()))

  datasets = ('cifar10-valid', 'cifar10', 'cifar100', 'ImageNet16-120')
  # Sorted once so the checkpoint order handed to account_one_arch does not
  # depend on set iteration order.
  seed_list = sorted(seeds)

  # Create the directories to save the processed data:
  # full_save_dir receives the per-architecture data before clear_params(),
  # simple_save_dir receives the same data after clear_params().
  full_save_dir = save_dir / (save_name + '-FULL')
  simple_save_dir = save_dir / (save_name + '-SIMPLIFY')
  full_save_dir.mkdir(parents=True, exist_ok=True)
  simple_save_dir.mkdir(parents=True, exist_ok=True)
  # all simplified data is also kept in memory for the aggregated file
  arch2infos, evaluated_indexes = dict(), set()
  # save the meta information (no per-architecture data yet) in both directories
  temp_final_infos = {'meta_archs' : nets,
                      'total_archs': total,
                      'arch2infos' : None,
                      'evaluated_indexes': set()}
  pickle_save(temp_final_infos, str(full_save_dir / 'meta.pickle'))
  pickle_save(temp_final_infos, str(simple_save_dir / 'meta.pickle'))

  for index in tqdm(range(total)):
    arch_str = nets[index]
    hp2info = OrderedDict()

    full_save_path = full_save_dir / '{:06d}.pickle'.format(index)
    simple_save_path = simple_save_dir / '{:06d}.pickle'.format(index)
    for hp in hps:
      sub_save_dir = save_dir / 'raw-data-{:}'.format(hp)
      ckps = [sub_save_dir / 'arch-{:06d}-seed-{:}.pth'.format(index, seed) for seed in seed_list]
      # a seed may be missing for this architecture; keep only existing files
      ckps = [x for x in ckps if x.exists()]
      if not ckps:
        raise ValueError('Invalid data : index={:}, hp={:}'.format(index, hp))

      hp2info[hp] = account_one_arch(index, arch_str, ckps, datasets, dataloader_dict)

    hp2info = correct_time_related_info(index, hp2info)
    evaluated_indexes.add(index)

    # first dump: state before clear_params()
    to_save_data = OrderedDict((hp, hp2info[hp].state_dict()) for hp in hps)
    pickle_save(to_save_data, str(full_save_path))

    # second dump: state after clear_params(); this one is also aggregated
    for hp in hps:
      hp2info[hp].clear_params()
    to_save_data = OrderedDict((hp, hp2info[hp].state_dict()) for hp in hps)
    pickle_save(to_save_data, str(simple_save_path))
    arch2infos[index] = to_save_data
  print('{:} {:} done.'.format(time_string(), save_name))
  final_infos = {'meta_archs' : nets,
                 'total_archs': total,
                 'arch2infos' : arch2infos,
                 'evaluated_indexes': evaluated_indexes}
  save_file_name = save_dir / '{:}.pickle'.format(save_name)
  pickle_save(final_infos, str(save_file_name))
  # move the benchmark file to a new path; the code below expects pickle_save
  # to have produced '<save_file_name>.pbz2'
  hd5sum = get_md5_file(str(save_file_name) + '.pbz2')
  hd5_file_name = save_dir / '{:}-{:}.pickle.pbz2'.format(NATS_TSS_BASE_NAME, hd5sum)
  shutil.move(str(save_file_name) + '.pbz2', str(hd5_file_name))
  print('Save {:} / {:} architecture results into {:} -> {:}.'.format(len(evaluated_indexes), total, save_file_name, hd5_file_name))
  # move the directories to a new path (str paths: shutil.move only accepts
  # path-like objects for both arguments since Python 3.9)
  hd5_full_save_dir = save_dir / '{:}-{:}-full'.format(NATS_TSS_BASE_NAME, hd5sum)
  hd5_simple_save_dir = save_dir / '{:}-{:}-simple'.format(NATS_TSS_BASE_NAME, hd5sum)
  shutil.move(str(full_save_dir), str(hd5_full_save_dir))
  shutil.move(str(simple_save_dir), str(hd5_simple_save_dir))
예제 #2
0
def simplify(save_dir, save_name, nets, total, sup_config):
    """Rebuild the aggregated benchmark file from existing per-architecture pickles.

    For each hyper-parameter setting in ``["12", "200"]`` the checkpoints
    named ``arch-<index>-seed-<seed>.pth`` under
    ``save_dir / "raw-data-<hp>"`` are counted per seed and validated against
    ``total``. The per-architecture files
    ``<save_name>-SIMPLIFY/<index>.pickle`` are then loaded with
    ``pickle_load`` and collected into one aggregated pickle. Finally the
    aggregated file and the ``-FULL`` / ``-SIMPLIFY`` directories are renamed
    to include the md5 of the aggregated file.

    Args:
        save_dir: path object supporting ``/``, ``glob`` and ``mkdir``
            (presumably a ``pathlib.Path``); root of raw data and outputs.
        save_name: base name (str) used for the directories and output file.
        nets: sequence of architecture strings; only its length is checked.
        total: expected number of architectures; must equal ``len(nets)`` and
            the largest per-seed checkpoint count in every raw-data directory.
        sup_config: not used by this function.

    Raises:
        AssertionError: if ``len(nets)``, ``total`` and the largest per-seed
            checkpoint count disagree (including when no checkpoint is found).
    """
    hps, seeds = ["12", "200"], set()
    for hp in hps:
        sub_save_dir = save_dir / "raw-data-{:}".format(hp)
        seed2names = defaultdict(list)
        for ckp in sorted(sub_save_dir.glob("arch-*-seed-*.pth")):
            # "arch-<index>-seed-<seed>.pth"
            #   -> ["arch", <index>, "seed", <seed>, "pth"]
            # Raw string: "\." in a plain literal is an invalid escape.
            parts = re.split(r"-|\.", ckp.name)
            seed2names[parts[3]].append(ckp.name)
        print("DIR : {:}".format(sub_save_dir))
        nums = []
        for seed, xlist in seed2names.items():
            seeds.add(seed)
            nums.append(len(xlist))
            print("  [seed={:}] there are {:} checkpoints.".format(
                seed, len(xlist)))
        # default=0 so an empty directory trips the assertion below (with its
        # message) instead of max() raising a bare ValueError.
        max_num = max(nums, default=0)
        assert (len(nets) == total ==
                max_num), "there are some missed files : {:} vs {:}".format(
                    max_num, total)
    print("{:} start simplify the checkpoint.".format(time_string()))

    # Make sure both output directories exist.
    # NOTE(review): nothing below writes into either directory; the
    # per-architecture pickles are presumably left by an earlier run - confirm.
    full_save_dir = save_dir / (save_name + "-FULL")
    simple_save_dir = save_dir / (save_name + "-SIMPLIFY")
    full_save_dir.mkdir(parents=True, exist_ok=True)
    simple_save_dir.mkdir(parents=True, exist_ok=True)
    # all data in memory
    arch2infos, evaluated_indexes = dict(), set()
    # load the already simplified per-architecture data
    for index in tqdm(range(total)):
        simple_save_path = simple_save_dir / "{:06d}.pickle".format(index)
        arch2infos[index] = pickle_load(simple_save_path)
        evaluated_indexes.add(index)
    print("{:} {:} done.".format(time_string(), save_name))
    final_infos = {
        "meta_archs": nets,
        "total_archs": total,
        "arch2infos": arch2infos,
        "evaluated_indexes": evaluated_indexes,
    }
    save_file_name = save_dir / "{:}.pickle".format(save_name)
    pickle_save(final_infos, str(save_file_name))
    # move the benchmark file to a new path; the code below expects
    # pickle_save to have produced "<save_file_name>.pbz2"
    hd5sum = get_md5_file(str(save_file_name) + ".pbz2")
    hd5_file_name = save_dir / "{:}-{:}.pickle.pbz2".format(
        NATS_TSS_BASE_NAME, hd5sum)
    shutil.move(str(save_file_name) + ".pbz2", str(hd5_file_name))
    print("Save {:} / {:} architecture results into {:} -> {:}.".format(
        len(evaluated_indexes), total, save_file_name, hd5_file_name))
    # move the directories to a new path (str paths: shutil.move only accepts
    # path-like objects for both arguments since Python 3.9)
    hd5_full_save_dir = save_dir / "{:}-{:}-full".format(
        NATS_TSS_BASE_NAME, hd5sum)
    hd5_simple_save_dir = save_dir / "{:}-{:}-simple".format(
        NATS_TSS_BASE_NAME, hd5sum)
    shutil.move(str(full_save_dir), str(hd5_full_save_dir))
    shutil.move(str(simple_save_dir), str(hd5_simple_save_dir))