Example #1
def evaluate_all_datasets(arch, datasets, xpaths, splits, use_less, seed,
                          arch_config, workers, logger):
    machine_info, arch_config = get_machine_info(), deepcopy(arch_config)
    all_infos = {'info': machine_info}
    all_dataset_keys = []
    # loop over all the datasets
    for dataset, xpath, split in zip(datasets, xpaths, splits):
        # load the train/valid data
        task = None
        train_data, valid_data, xshape, class_num = get_datasets(
            dataset, xpath, -1, task)

        # load the configuration
        if dataset in ['mnist', 'svhn', 'aircraft', 'pets']:
            if use_less:
                config_path = 'nas_bench_201/configs/nas-benchmark/LESS.config'
            else:
                config_path = 'nas_bench_201/configs/nas-benchmark/{}.config'.format(
                    dataset)

            p = 'nas_bench_201/configs/nas-benchmark/{:}-split.txt'.format(
                dataset)
            if not os.path.exists(p):
                import json
                label_list = list(range(len(train_data)))
                random.shuffle(label_list)
                strlist = [str(idx) for idx in label_list]
                # 50/50 split of the shuffled indices, stored as strings
                index_split = {
                    'train': ["int", strlist[:len(train_data) // 2]],
                    'valid': ["int", strlist[len(train_data) // 2:]]
                }
                with open(p, 'w') as f:
                    f.write(json.dumps(index_split))
            split_info = load_config(p, None, None)
        else:
            raise ValueError('invalid dataset : {:}'.format(dataset))

        config = load_config(config_path, {
            'class_num': class_num,
            'xshape': xshape
        }, logger)
        # data loader
        train_loader = torch.utils.data.DataLoader(
            train_data,
            batch_size=config.batch_size,
            shuffle=True,
            num_workers=workers,
            pin_memory=True)
        valid_loader = torch.utils.data.DataLoader(
            valid_data,
            batch_size=config.batch_size,
            shuffle=False,
            num_workers=workers,
            pin_memory=True)
        # note: use a distinct name so the `splits` argument is not shadowed
        test_splits = load_config(
            'nas_bench_201/configs/nas-benchmark/{:}-test-split.txt'.format(
                dataset), None, None)
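        # three evaluation views: the full test set ('ori-test') plus the two
        # disjoint subsets ('x-valid' / 'x-test') listed in the test-split file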
        ValLoaders = {
            'ori-test':
            valid_loader,
            'x-valid':
            torch.utils.data.DataLoader(
                valid_data,
                batch_size=config.batch_size,
                sampler=torch.utils.data.sampler.SubsetRandomSampler(
                    test_splits.xvalid),
                num_workers=workers,
                pin_memory=True),
            'x-test':
            torch.utils.data.DataLoader(
                valid_data,
                batch_size=config.batch_size,
                sampler=torch.utils.data.sampler.SubsetRandomSampler(
                    test_splits.xtest),
                num_workers=workers,
                pin_memory=True)
        }
        dataset_key = '{:}'.format(dataset)
        if bool(split): dataset_key = dataset_key + '-valid'
        logger.log(
            'Evaluate ||||||| {:10s} ||||||| Train-Num={:}, Valid-Num={:}, Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'
            .format(dataset_key, len(train_data), len(valid_data),
                    len(train_loader), len(valid_loader), config.batch_size))
        logger.log('Evaluate ||||||| {:10s} ||||||| Config={:}'.format(
            dataset_key, config))
        for key, value in ValLoaders.items():
            logger.log('Evaluate ---->>>> {:10s} with {:} batches'.format(
                key, len(value)))

        results = evaluate_for_seed(arch_config, config, arch, train_loader,
                                    ValLoaders, seed, logger)
        all_infos[dataset_key] = results
        all_dataset_keys.append(dataset_key)
    all_infos['all_dataset_keys'] = all_dataset_keys
    return all_infos
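As a side note, the branch that creates a missing split file is just a shuffle-and-serialize round trip. The following self-contained sketch (the toy size and the temporary file name are assumptions made for illustration) reproduces that round trip with the same string-index layout:

import json
import os
import random
import tempfile

n_samples = 10                                       # toy dataset size (assumed)
indices = list(range(n_samples))
random.shuffle(indices)
strlist = [str(idx) for idx in indices]
index_split = {                                      # same layout as above: a type tag
    'train': ["int", strlist[:n_samples // 2]],      # plus the stringified indices
    'valid': ["int", strlist[n_samples // 2:]]
}
path = os.path.join(tempfile.gettempdir(), 'toy-split.txt')
with open(path, 'w') as f:
    f.write(json.dumps(index_split))

with open(path) as f:                                # read back and re-cast to int
    loaded = json.load(f)
train_idx = [int(s) for s in loaded['train'][1]]
valid_idx = [int(s) for s in loaded['valid'][1]]
assert sorted(train_idx + valid_idx) == list(range(n_samples))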
Example #2
def evaluate_all_datasets(arch, datasets, xpaths, splits, use_less, seed, arch_config, workers, logger):
  machine_info, arch_config = get_machine_info(), deepcopy(arch_config)
  all_infos = {'info': machine_info}
  all_dataset_keys = []
  # loop over all the datasets
  for dataset, xpath, split in zip(datasets, xpaths, splits):
    # load the train/valid data
    train_data, valid_data, xshape, class_num = get_datasets(dataset, xpath, -1)
    # load the configuration
    if dataset == 'cifar10' or dataset == 'cifar100':
      if use_less: config_path = 'configs/nas-benchmark/LESS.config'
      else       : config_path = 'configs/nas-benchmark/CIFAR.config'
      split_info  = load_config('configs/nas-benchmark/cifar-split.txt', None, None)
    elif dataset.startswith('ImageNet16'):
      if use_less: config_path = 'configs/nas-benchmark/LESS.config'
      else       : config_path = 'configs/nas-benchmark/ImageNet-16.config'
      split_info  = load_config('configs/nas-benchmark/{:}-split.txt'.format(dataset), None, None)
    else:
      raise ValueError('invalid dataset : {:}'.format(dataset))
    config = load_config(config_path, \
                            {'class_num': class_num,
                             'xshape'   : xshape}, \
                            logger)
    # check whether to use the split validation set
    if bool(split):
      assert dataset == 'cifar10'
      ValLoaders = {'ori-test': torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, shuffle=False, num_workers=workers, pin_memory=True)}
      assert len(train_data) == len(split_info.train) + len(split_info.valid), 'invalid length : {:} vs {:} + {:}'.format(len(train_data), len(split_info.train), len(split_info.valid))
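      # reuse the training images for validation, but with the deterministic
      # test-time transform instead of the training augmentation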
      train_data_v2 = deepcopy(train_data)
      train_data_v2.transform = valid_data.transform
      valid_data = train_data_v2
      # data loader
      train_loader = torch.utils.data.DataLoader(train_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(split_info.train), num_workers=workers, pin_memory=True)
      valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(split_info.valid), num_workers=workers, pin_memory=True)
      ValLoaders['x-valid'] = valid_loader
    else:
      # data loader
      train_loader = torch.utils.data.DataLoader(train_data, batch_size=config.batch_size, shuffle=True , num_workers=workers, pin_memory=True)
      valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, shuffle=False, num_workers=workers, pin_memory=True)
      if dataset == 'cifar10':
        ValLoaders = {'ori-test': valid_loader}
      elif dataset == 'cifar100':
        cifar100_splits = load_config('configs/nas-benchmark/cifar100-test-split.txt', None, None)
        ValLoaders = {'ori-test': valid_loader,
                      'x-valid' : torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar100_splits.xvalid), num_workers=workers, pin_memory=True),
                      'x-test'  : torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar100_splits.xtest ), num_workers=workers, pin_memory=True)
                     }
      elif dataset == 'ImageNet16-120':
        imagenet16_splits = load_config('configs/nas-benchmark/imagenet-16-120-test-split.txt', None, None)
        ValLoaders = {'ori-test': valid_loader,
                      'x-valid' : torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(imagenet16_splits.xvalid), num_workers=workers, pin_memory=True),
                      'x-test'  : torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(imagenet16_splits.xtest ), num_workers=workers, pin_memory=True)
                     }
      else:
        raise ValueError('invalid dataset : {:}'.format(dataset))

    dataset_key = '{:}'.format(dataset)
    if bool(split): dataset_key = dataset_key + '-valid'
    logger.log('Evaluate ||||||| {:10s} ||||||| Train-Num={:}, Valid-Num={:}, Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(dataset_key, len(train_data), len(valid_data), len(train_loader), len(valid_loader), config.batch_size))
    logger.log('Evaluate ||||||| {:10s} ||||||| Config={:}'.format(dataset_key, config))
    for key, value in ValLoaders.items():
      logger.log('Evaluate ---->>>> {:10s} with {:} batches'.format(key, len(value)))
    results = evaluate_for_seed(arch_config, config, arch, train_loader, ValLoaders, seed, logger)
    all_infos[dataset_key] = results
    all_dataset_keys.append( dataset_key )
  all_infos['all_dataset_keys'] = all_dataset_keys
  return all_infos
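Both examples lean on torch.utils.data.sampler.SubsetRandomSampler to carve disjoint loaders out of a single underlying dataset. Here is a minimal, self-contained illustration of that pattern; the toy TensorDataset and the 80/20 index split are assumptions for the sketch, not values from the benchmark:

import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.sampler import SubsetRandomSampler

# toy dataset: 100 samples whose feature value equals their index
data = TensorDataset(torch.arange(100).float().unsqueeze(1),
                     torch.zeros(100, dtype=torch.long))
train_idx, valid_idx = list(range(80)), list(range(80, 100))

train_loader = DataLoader(data, batch_size=16,
                          sampler=SubsetRandomSampler(train_idx))
valid_loader = DataLoader(data, batch_size=16,
                          sampler=SubsetRandomSampler(valid_idx))

seen = set()
for x, _ in valid_loader:                  # the sampler only yields valid_idx
    seen.update(int(v) for v in x.squeeze(1).tolist())
assert seen == set(valid_idx)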
Example #3
def train_and_eval(args,
                   data,
                   logger,
                   arch,
                   nas_bench,
                   extra_info,
                   dataname='cifar10-valid',
                   use_012_epoch_training=True,
                   use_loss_extrapolation=False):

    if use_loss_extrapolation:
        arch_index = nas_bench.query_index_by_arch(arch)
        print('=' * 80)
        print('ARCH INDEX: {0}'.format(arch_index))
        print('=' * 80)
        nepoch = 12
        assert arch_index >= 0, 'cannot find this arch : {:}'.format(arch)
        config, train_loader, valid_loader = data
        arch_config = {'channel': args.channel, 'num_cells': args.num_cells}
        t0 = time.time()
        print(config, 'CONFIG')
        results = evaluate_for_seed(arch_config,
                                    config,
                                    arch,
                                    train_loader, {'valid': valid_loader},
                                    seed=random.randint(0, (2**32) - 1),
                                    logger=logger)
        key = '{:}@{:}'.format('valid', nepoch - 1)
        valid_acc = results['valid_acc1es'][key]
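        # leave time_cost unset so the estimation block at the end computes it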
        time_cost = None

        info = nas_bench.get_more_info(arch_index, dataname, None, True)
        valid_acc2, time_cost2 = info[
            'valid-accuracy'], info['train-all-time'] + info['valid-per-time']
        print('Taken: {0:.2f} s'.format(time.time() - t0))
        print('API time: {0:.2f}'.format(time_cost2))
        print('Validation accuracy after {0} epochs: {1:.2f}'.format(
            nepoch, valid_acc))
        print('Validation accuracy reference: {0:.2f}'.format(valid_acc2))

    elif use_012_epoch_training and nas_bench is not None:
        arch_index = nas_bench.query_index_by_arch(arch)
        print('=' * 80)
        print('ARCH INDEX: {0}'.format(arch_index))
        print('=' * 80)
        assert arch_index >= 0, 'cannot find this arch : {:}'.format(arch)
        info = nas_bench.get_more_info(arch_index, dataname, None, True)
        valid_acc, time_cost = info[
            'valid-accuracy'], info['train-all-time'] + info['valid-per-time']
        #_, valid_acc = info.get_metrics('cifar10-valid', 'x-valid' , 25, True) # use the validation accuracy after 25 training epochs
    elif not use_012_epoch_training and nas_bench is not None:
        # Please contact me if you want to use the following logic, because it has some potential issues.
        # Please use `use_012_epoch_training=False` for cifar10 only.
        # It did return values for cifar100 and ImageNet16-120, but it has some potential issues. (Please email me for more details)
        arch_index, nepoch = nas_bench.query_index_by_arch(arch), 25
        assert arch_index >= 0, 'cannot find this arch : {:}'.format(arch)
        # `info` was read before being assigned in the original; fetch the
        # 25-epoch record the same way the estimation block below does
        info = nas_bench.get_more_info(arch_index, dataname, nepoch, False,
                                       True)
        try:
            valid_acc = info['valid-accuracy']
        except KeyError:
            valid_acc = info['valtest-accuracy']

        time_cost = None
    else:
        # train a model from scratch.
        raise ValueError('NOT IMPLEMENTED YET')

    if time_cost is None:
        # The following code estimates the time cost.
        # When we built NAS-Bench-201, architectures were trained on different machines, so those time records are not directly comparable.
        # When we created checkpoints for converged_LR, all experiments ran on a 1080 Ti, so the time for each architecture can be fairly compared.
        nums = {
            'ImageNet16-120-train': 151700,
            'ImageNet16-120-valid': 3000,
            'cifar10-valid-train': 25000,
            'cifar10-valid-valid': 25000,
            'cifar100-train': 50000,
            'cifar100-valid': 5000
        }

        xoinfo = nas_bench.get_more_info(arch_index, 'cifar10-valid', None,
                                         True)
        xocost = nas_bench.get_cost_info(arch_index, 'cifar10-valid', False)
        info = nas_bench.get_more_info(
            arch_index, dataname, nepoch, False, True
        )  # use the validation accuracy after 25 training epochs, which is used in our ICLR submission (not the camera ready).
        cost = nas_bench.get_cost_info(arch_index, dataname, False)
        estimated_train_cost = (xoinfo['train-per-time'] /
                                nums['cifar10-valid-train'] *
                                nums['{:}-train'.format(dataname)] /
                                xocost['latency'] * cost['latency'] * nepoch)
        estimated_valid_cost = (xoinfo['valid-per-time'] /
                                nums['cifar10-valid-valid'] *
                                nums['{:}-valid'.format(dataname)] /
                                xocost['latency'] * cost['latency'])
        time_cost = estimated_train_cost + estimated_valid_cost
        print('Adjusted time cost: {0:.2f}'.format(time_cost))
    return valid_acc, time_cost
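The estimation at the end scales the per-epoch timings recorded for cifar10-valid by the target dataset's sample counts and by the latency ratio of the two cost records. A worked example of that arithmetic with made-up numbers (none of the values below are real benchmark entries):

# Illustrative values only; the real ones come from get_more_info / get_cost_info.
nums = {'cifar10-valid-train': 25000, 'cifar10-valid-valid': 25000,
        'cifar100-train': 50000, 'cifar100-valid': 5000}
xoinfo = {'train-per-time': 10.0, 'valid-per-time': 1.0}  # seconds/epoch on cifar10-valid
xocost = {'latency': 0.02}                                # per-batch latency, cifar10-valid
cost = {'latency': 0.03}                                  # per-batch latency, cifar100
nepoch = 25

# scale by sample count, then by relative latency, then by the epoch count
estimated_train_cost = (xoinfo['train-per-time'] / nums['cifar10-valid-train'] *
                        nums['cifar100-train'] / xocost['latency'] *
                        cost['latency'] * nepoch)
estimated_valid_cost = (xoinfo['valid-per-time'] / nums['cifar10-valid-valid'] *
                        nums['cifar100-valid'] / xocost['latency'] *
                        cost['latency'])
print('estimated cost: {:.1f}s'.format(estimated_train_cost + estimated_valid_cost))
# -> estimated cost: 750.3s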