コード例 #1
0
def main(xargs, nas_bench):
  """Run a time-budgeted random architecture search on CIFAR-10.

  Architectures are drawn with ``random_architecture_func`` and scored with
  ``train_and_eval`` until their accumulated cost would exceed
  ``xargs.time_budget``. The best-scoring architecture is then looked up in
  ``nas_bench`` and logged.

  Args:
    xargs: parsed argument namespace. Reads ``workers``, ``rand_seed``,
      ``dataset`` (must be ``'cifar10'``), ``data_path``,
      ``search_space_name``, ``max_nodes`` and ``time_budget``, and is
      passed to ``prepare_logger``.
    nas_bench: benchmark handle forwarded to ``train_and_eval`` and queried
      via ``query_by_arch`` / ``query_index_by_arch`` (presumably a
      NAS-Bench-201 API object -- TODO confirm).

  Returns:
    ``(logger.log_dir, nas_bench.query_index_by_arch(best_arch))``.

  Raises:
    AssertionError: if CUDA is unavailable or ``xargs.dataset`` is not
      ``'cifar10'``.
  """
  assert torch.cuda.is_available(), 'CUDA is not available.'
  # Disable cuDNN autotuning and request deterministic kernels so that runs
  # with the same seed are comparable.
  torch.backends.cudnn.enabled   = True
  torch.backends.cudnn.benchmark = False
  torch.backends.cudnn.deterministic = True
  torch.set_num_threads( xargs.workers )
  prepare_seed(xargs.rand_seed)
  # Build the logger from this function's own argument. A module-level
  # ``args`` only exists when the file runs as a script (NameError otherwise).
  logger = prepare_logger(xargs)

  assert xargs.dataset == 'cifar10', 'currently only support CIFAR-10'
  train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
  split_Fpath = 'configs/nas-benchmark/cifar-split.txt'
  cifar_split = load_config(split_Fpath, None, None)
  train_split, valid_split = cifar_split.train, cifar_split.valid
  logger.log('Load split file from {:}'.format(split_Fpath))
  config_path = 'configs/nas-benchmark/algos/R-EA.config'
  config = load_config(config_path, {'class_num': class_num, 'xshape': xshape}, logger)
  # Both splits index into the original *training* set; the validation copy
  # only swaps in the evaluation-time transform.
  train_data_v2 = deepcopy(train_data)
  train_data_v2.transform = valid_data.transform
  valid_data    = train_data_v2
  # NOTE(review): ``search_data`` is never used below. It is kept in case the
  # SearchDataset constructor performs checks on the splits -- confirm, and
  # drop it if not.
  search_data   = SearchDataset(xargs.dataset, train_data, train_split, valid_split)
  # Data loaders: each samples only its own split's indices, in random order.
  train_loader = torch.utils.data.DataLoader(
    train_data, batch_size=config.batch_size,
    sampler=torch.utils.data.sampler.SubsetRandomSampler(train_split),
    num_workers=xargs.workers, pin_memory=True)
  valid_loader = torch.utils.data.DataLoader(
    valid_data, batch_size=config.batch_size,
    sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split),
    num_workers=xargs.workers, pin_memory=True)
  logger.log('||||||| {:10s} ||||||| Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(train_loader), len(valid_loader), config.batch_size))
  logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config))
  # Everything train_and_eval needs besides the architecture and benchmark.
  extra_info = {'config': config, 'train_loader': train_loader, 'valid_loader': valid_loader}

  search_space = get_search_spaces('cell', xargs.search_space_name)
  random_arch = random_architecture_func(xargs.max_nodes, search_space)
  logger.log('{:} use nas_bench : {:}'.format(time_string(), nas_bench))
  best_arch, best_acc, total_time_cost, history = None, -1, 0, []
  while total_time_cost < xargs.time_budget:
    arch = random_arch()
    accuracy, cost_time = train_and_eval(arch, nas_bench, extra_info)
    # The candidate is evaluated before the budget check: the evaluation that
    # would overrun the budget is discarded (neither counted nor recorded).
    if total_time_cost + cost_time > xargs.time_budget: break
    else: total_time_cost += cost_time
    history.append(arch)
    # Strict '<' keeps the earliest architecture when accuracies tie.
    if best_arch is None or best_acc < accuracy:
      best_acc, best_arch = accuracy, arch
    logger.log('[{:03d}] : {:} : accuracy = {:.2f}%'.format(len(history), arch, accuracy))
  # NOTE(review): if no architecture fit within the budget, best_arch is still
  # None (and best_acc is -1) at this point.
  logger.log('{:} best arch is {:}, accuracy = {:.2f}%, visit {:} archs with {:.1f} s.'.format(time_string(), best_arch, best_acc, len(history), total_time_cost))

  info = nas_bench.query_by_arch( best_arch )
  if info is None: logger.log('Did not find this architecture : {:}.'.format(best_arch))
  else           : logger.log('{:}'.format(info))
  logger.log('-'*100)
  logger.close()
  return logger.log_dir, nas_bench.query_index_by_arch( best_arch )
コード例 #2
0
ファイル: RANDOM.py プロジェクト: Joey61Liuyi/AutoDL-Projects
def main(xargs, nas_bench):
    """Run a time-budgeted random architecture search.

    Architectures are drawn with ``random_architecture_func`` and scored with
    ``train_and_eval`` until their accumulated cost would exceed
    ``xargs.time_budget``. The best-scoring architecture is then looked up in
    ``nas_bench`` and logged.

    When ``xargs.data_path`` is set, train/valid data loaders are built and
    handed to ``train_and_eval``. Otherwise only the config is loaded and
    both loaders are passed as ``None``.

    Args:
        xargs: parsed argument namespace. Reads ``workers``, ``rand_seed``,
            ``dataset``, ``data_path``, ``search_space_name``, ``max_nodes``
            and ``time_budget``, and is passed to ``prepare_logger``.
        nas_bench: benchmark handle forwarded to ``train_and_eval`` and
            queried via ``query_by_arch`` / ``query_index_by_arch``
            (presumably a NAS-Bench-201 API object -- TODO confirm).

    Returns:
        ``(logger.log_dir, nas_bench.query_index_by_arch(best_arch))``.

    Raises:
        AssertionError: if CUDA is unavailable.
    """
    assert torch.cuda.is_available(), "CUDA is not available."
    # Disable cuDNN autotuning and request deterministic kernels so that runs
    # with the same seed are comparable.
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.set_num_threads(xargs.workers)
    prepare_seed(xargs.rand_seed)
    # Build the logger from this function's own argument. A module-level
    # ``args`` only exists when the file runs as a script (NameError otherwise).
    logger = prepare_logger(xargs)

    # Name forwarded to train_and_eval; CIFAR-10 is requested as "cifar10-valid".
    if xargs.dataset == "cifar10":
        dataname = "cifar10-valid"
    else:
        dataname = xargs.dataset
    config_path = "configs/nas-benchmark/algos/R-EA.config"
    if xargs.data_path is not None:
        train_data, valid_data, xshape, class_num = get_datasets(
            xargs.dataset, xargs.data_path, -1)
        # NOTE(review): the CIFAR split file is loaded regardless of
        # xargs.dataset -- confirm this is intended for non-CIFAR datasets.
        split_Fpath = "configs/nas-benchmark/cifar-split.txt"
        cifar_split = load_config(split_Fpath, None, None)
        train_split, valid_split = cifar_split.train, cifar_split.valid
        logger.log("Load split file from {:}".format(split_Fpath))
        config = load_config(config_path, {
            "class_num": class_num,
            "xshape": xshape
        }, logger)
        # Both splits index into the original *training* set; the validation
        # copy only swaps in the evaluation-time transform.
        train_data_v2 = deepcopy(train_data)
        train_data_v2.transform = valid_data.transform
        valid_data = train_data_v2
        # NOTE(review): ``search_data`` is never used below. It is kept in case
        # the SearchDataset constructor performs checks on the splits --
        # confirm, and drop it if not.
        search_data = SearchDataset(xargs.dataset, train_data, train_split,
                                    valid_split)
        # Data loaders: each samples only its own split's indices, in random order.
        train_loader = torch.utils.data.DataLoader(
            train_data,
            batch_size=config.batch_size,
            sampler=torch.utils.data.sampler.SubsetRandomSampler(train_split),
            num_workers=xargs.workers,
            pin_memory=True,
        )
        valid_loader = torch.utils.data.DataLoader(
            valid_data,
            batch_size=config.batch_size,
            sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split),
            num_workers=xargs.workers,
            pin_memory=True,
        )
        logger.log(
            "||||||| {:10s} ||||||| Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}"
            .format(xargs.dataset, len(train_loader), len(valid_loader),
                    config.batch_size))
        logger.log("||||||| {:10s} ||||||| Config={:}".format(
            xargs.dataset, config))
        extra_info = {
            "config": config,
            "train_loader": train_loader,
            "valid_loader": valid_loader,
        }
    else:
        # No dataset on disk: only the search config is loaded, and
        # train_and_eval receives no loaders.
        config = load_config(config_path, None, logger)
        logger.log("||||||| {:10s} ||||||| Config={:}".format(
            xargs.dataset, config))
        extra_info = {
            "config": config,
            "train_loader": None,
            "valid_loader": None
        }
    search_space = get_search_spaces("cell", xargs.search_space_name)
    random_arch = random_architecture_func(xargs.max_nodes, search_space)
    # Wall-clock start of the search loop only; data setup is excluded from
    # the "real-cost" figure logged below.
    x_start_time = time.time()
    logger.log("{:} use nas_bench : {:}".format(time_string(), nas_bench))
    best_arch, best_acc, total_time_cost, history = None, -1, 0, []
    while total_time_cost < xargs.time_budget:
        arch = random_arch()
        accuracy, cost_time = train_and_eval(arch, nas_bench, extra_info,
                                             dataname)
        # The candidate is evaluated before the budget check: the evaluation
        # that would overrun the budget is discarded (neither counted nor
        # recorded).
        if total_time_cost + cost_time > xargs.time_budget:
            break
        else:
            total_time_cost += cost_time
        history.append(arch)
        # Strict '<' keeps the earliest architecture when accuracies tie.
        if best_arch is None or best_acc < accuracy:
            best_acc, best_arch = accuracy, arch
        logger.log("[{:03d}] : {:} : accuracy = {:.2f}%".format(
            len(history), arch, accuracy))
    # NOTE(review): if no architecture fit within the budget, best_arch is
    # still None (and best_acc is -1) at this point.
    logger.log(
        "{:} best arch is {:}, accuracy = {:.2f}%, visit {:} archs with {:.1f} s (real-cost = {:.3f} s)."
        .format(
            time_string(),
            best_arch,
            best_acc,
            len(history),
            total_time_cost,
            time.time() - x_start_time,
        ))

    # NOTE(review): "200" is presumably the training-epoch setting of the
    # benchmark record to fetch -- TODO confirm against the nas_bench API.
    info = nas_bench.query_by_arch(best_arch, "200")
    if info is None:
        logger.log("Did not find this architecture : {:}.".format(best_arch))
    else:
        logger.log("{:}".format(info))
    logger.log("-" * 100)
    logger.close()
    return logger.log_dir, nas_bench.query_index_by_arch(best_arch)