Code example #1
File: sng_nasbench1shot1.py  Project: dercaft/XNAS
def run(space=1, optimizer_name='SNG', budget=108, runing_times=500, runing_epochs=200,
        step=4, gamma=0.9, save_dir=None, nasbench=None, noise=0.0, sample_with_prob=True, utility_function='log',
        utility_function_hyper=0.4):
    print('##### Search Space {} #####'.format(space))
    search_space = eval('SearchSpace{}()'.format(space))
    cat_variables = []
    cs = search_space.get_configuration_space()
    for h in cs.get_hyperparameters():
        if type(h) == ConfigSpace.hyperparameters.CategoricalHyperparameter:
            cat_variables.append(len(h.choices))
    # one entry per categorical hyperparameter: the number of available choices
    category = cat_variables

    distribution_optimizer = get_optimizer(optimizer_name, category, step=step, gamma=gamma,
                                           sample_with_prob=sample_with_prob, utility_function=utility_function,
                                           utility_function_hyper=utility_function_hyper)
    # path to save the test_accuracy
    file_name = '{}_{}_{}_{}_{}_{}_{}_{}_{}.npz'.format(
        optimizer_name, str(space), str(runing_epochs), str(step), str(gamma),
        str(noise), str(sample_with_prob), utility_function,
        str(utility_function_hyper))
    file_name = os.path.join(save_dir, file_name)
    nb_reward = Reward(search_space, nasbench, budget)
    record = {
        'validation_accuracy': np.zeros([runing_times, runing_epochs]) - 1,
        'test_accuracy': np.zeros([runing_times, runing_epochs]) - 1,
    }
    last_test_accuracy = np.zeros([runing_times])
    running_time_interval = np.zeros([runing_times, runing_epochs])
    test_accuracy = 0
    run_timer = Timer()

    for i in tqdm.tqdm(range(runing_times)):
        for j in range(runing_epochs):
            run_timer.tic()
            if hasattr(distribution_optimizer, 'training_finish') or j == (runing_epochs - 1):
                last_test_accuracy[i] = test_accuracy
            if hasattr(distribution_optimizer, 'training_finish'):
                if distribution_optimizer.training_finish:
                    break
            sample = distribution_optimizer.sampling()
            sample_index = one_hot_to_index(np.array(sample))
            validation_accuracy = nb_reward.compute_reward(sample_index)
            distribution_optimizer.record_information(
                sample, validation_accuracy)
            distribution_optimizer.update()
            current_best = np.argmax(
                distribution_optimizer.p_model.theta, axis=1)
            test_accuracy = nb_reward.get_accuracy(current_best)
            record['validation_accuracy'][i, j] = validation_accuracy
            record['test_accuracy'][i, j] = test_accuracy
            run_timer.toc()
            running_time_interval[i, j] = run_timer.diff
        del distribution_optimizer
        distribution_optimizer = get_optimizer(optimizer_name, category, step=step, gamma=gamma,
                                               sample_with_prob=sample_with_prob, utility_function=utility_function,
                                               utility_function_hyper=utility_function_hyper)
    np.savez(file_name, record['test_accuracy'], running_time_interval)
    return distribution_optimizer
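Note: every snippet on this page drives the same Timer utility through tic(), toc(), reset(), diff, average_time, and total_time. The following is only a minimal sketch of that interface for orientation, written to match how it is used in these examples; it is not the project's actual implementation.

import time

class Timer:
    """Illustrative stand-in for the timer interface used in these examples."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.total_time = 0.0    # accumulated seconds over all tic/toc pairs
        self.calls = 0           # number of completed tic/toc pairs
        self.diff = 0.0          # duration of the most recent tic/toc pair
        self.average_time = 0.0  # total_time / calls
        self._start = None

    def tic(self):
        self._start = time.perf_counter()

    def toc(self):
        self.diff = time.perf_counter() - self._start
        self.total_time += self.diff
        self.calls += 1
        self.average_time = self.total_time / self.calls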
Code example #2
def compute_time_loader(data_loader):
    """Computes loader time."""
    timer = Timer()
    data_loader_iterator = iter(data_loader)
    total_iter = cfg.PREC_TIME.NUM_ITER + cfg.PREC_TIME.WARMUP_ITER
    total_iter = min(total_iter, len(data_loader))
    for cur_iter in range(total_iter):
        if cur_iter == cfg.PREC_TIME.WARMUP_ITER:
            timer.reset()
        timer.tic()
        next(data_loader_iterator)
        timer.toc()
    return timer.average_time
Code example #3
def compute_full_loader(data_loader, epoch=1):
    """Computes full loader time."""
    timer = Timer()
    epoch_avg = []
    data_loader_len = len(data_loader)
    for j in range(epoch):
        timer.tic()
        for i, (inputs, labels) in enumerate(data_loader):
            inputs = inputs.cuda()
            labels = labels.cuda()
            timer.toc()
            logger.info(
                "Epoch {}/{}, Iter {}/{}: Dataloader time is: {}".format(
                    j + 1, epoch, i + 1, data_loader_len, timer.diff))
            timer.tic()
        epoch_avg.append(timer.average_time)
        timer.reset()
    return epoch_avg
Code example #4
def compute_time_train(model, loss_fun):
    """Computes precise model forward + backward time using dummy data."""
    # Use train mode
    model.train()
    # Generate a dummy mini-batch and copy data to GPU
    # NOTE: uses the cfg.SEARCH settings instead of cfg.TRAIN
    # im_size, batch_size = cfg.TRAIN.IM_SIZE, int(cfg.TRAIN.BATCH_SIZE / cfg.NUM_GPUS)
    im_size, batch_size = cfg.SEARCH.IM_SIZE, int(cfg.SEARCH.BATCH_SIZE / cfg.NUM_GPUS)
    inputs = torch.rand(batch_size, 3, im_size,
                        im_size).cuda(non_blocking=False)
    labels = torch.zeros(batch_size,
                         dtype=torch.int64).cuda(non_blocking=False)
    # Cache BatchNorm2D running stats
    bns = [m for m in model.modules() if isinstance(m, torch.nn.BatchNorm2d)]
    bn_stats = [[bn.running_mean.clone(),
                 bn.running_var.clone()] for bn in bns]
    # Compute precise forward backward pass time
    fw_timer, bw_timer = Timer(), Timer()
    total_iter = cfg.PREC_TIME.NUM_ITER + cfg.PREC_TIME.WARMUP_ITER
    for cur_iter in range(total_iter):
        # Reset the timers after the warmup phase
        if cur_iter == cfg.PREC_TIME.WARMUP_ITER:
            fw_timer.reset()
            bw_timer.reset()
        # Forward
        fw_timer.tic()
        preds = model(inputs)
        loss = loss_fun(preds, labels)
        torch.cuda.synchronize()
        fw_timer.toc()
        # Backward
        bw_timer.tic()
        loss.backward()
        torch.cuda.synchronize()
        bw_timer.toc()
    # Restore BatchNorm2D running stats
    for bn, (mean, var) in zip(bns, bn_stats):
        bn.running_mean, bn.running_var = mean, var
    return fw_timer.average_time, bw_timer.average_time
Code example #5
def compute_time_eval(model):
    """Computes precise model forward test time using dummy data."""
    # Use eval mode
    model.eval()
    # Generate a dummy mini-batch and copy data to GPU
    im_size, batch_size = cfg.TRAIN.IM_SIZE, int(cfg.TEST.BATCH_SIZE /
                                                 cfg.NUM_GPUS)
    inputs = torch.zeros(batch_size, 3, im_size,
                         im_size).cuda(non_blocking=False)
    # Compute precise forward pass time
    timer = Timer()
    total_iter = cfg.PREC_TIME.NUM_ITER + cfg.PREC_TIME.WARMUP_ITER
    for cur_iter in range(total_iter):
        # Reset the timers after the warmup phase
        if cur_iter == cfg.PREC_TIME.WARMUP_ITER:
            timer.reset()
        # Forward
        timer.tic()
        model(inputs)
        torch.cuda.synchronize()
        timer.toc()
    return timer.average_time
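Code examples #2 through #5 share one measurement pattern: every iteration is timed, but the timer is reset once cfg.PREC_TIME.WARMUP_ITER warmup iterations have passed, so average_time reflects only steady-state iterations. Stripped of the cfg and CUDA plumbing, the pattern reduces to the sketch below, using the Timer interface sketched earlier; measure_fn and the default iteration counts are placeholders, not names from the project.

def precise_average_time(measure_fn, num_iter=30, warmup_iter=5):
    """Average runtime of measure_fn, discarding the warmup iterations (sketch)."""
    timer = Timer()
    for cur_iter in range(num_iter + warmup_iter):
        if cur_iter == warmup_iter:
            timer.reset()  # drop the warmup measurements
        timer.tic()
        measure_fn()
        timer.toc()
    return timer.average_time

For GPU work the measured callable should end with torch.cuda.synchronize(), as the examples above do, so that asynchronously launched kernels fall inside the timed interval.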
Code example #6
File: pcdarts_test.py  Project: taoari/XNAS
def basic_darts_cnn_test():
    # dartscnn test
    time_ = Timer()
    print("Testing darts CNN")
    search_net = DartsCNN().cuda()
    _random_architecture_weight = torch.randn(
        [search_net.num_edges * 2, len(search_net.basic_op_list)]).cuda()
    _input = torch.randn([2, 3, 32, 32]).cuda()
    time_.tic()
    _out_put = search_net(_input, _random_architecture_weight)
    time_.toc()
    print(_out_put.shape)
    print(time_.average_time)
    time_.reset()
    _random_one_hot = torch.Tensor(
        np.eye(len(search_net.basic_op_list))[np.random.choice(
            len(search_net.basic_op_list), search_net.num_edges * 2)]).cuda()
    _input = torch.randn([2, 3, 32, 32]).cuda()
    time_.tic()
    _out_put = search_net(_input, _random_one_hot)
    time_.toc()
    print(_out_put.shape)
    print(time_.average_time)
Code example #7
File: pcdarts_test.py  Project: taoari/XNAS
def basic_nas_bench_201_cnn():
    #  nas_bench_201 test
    time_ = Timer()
    print("Testing nas bench 201 CNN")
    search_net = NASBench201CNN()
    _random_architecture_weight = torch.randn(
        [search_net.num_edges, len(search_net.basic_op_list)])
    _input = torch.randn([2, 3, 32, 32])
    time_.tic()
    _out_put = search_net(_input, _random_architecture_weight)
    time_.toc()
    print(_out_put.shape)
    print(time_.average_time)
    time_.reset()
    _random_one_hot = torch.Tensor(
        np.eye(len(search_net.basic_op_list))[np.random.choice(
            len(search_net.basic_op_list), search_net.num_edges)])
    _input = torch.randn([2, 3, 32, 32])
    time_.tic()
    _out_put = search_net(_input, _random_one_hot)
    time_.toc()
    print(_out_put.shape)
    print(time_.average_time)
Code example #8
def darts_train_model():
    """train DARTS model"""
    setup_env()
    # Loading search space
    search_space = build_space()
    # TODO: fix the complexity function
    # search_space = setup_model()
    # Init controller and architect
    loss_fun = build_loss_fun().cuda()
    darts_controller = DartsCNNController(search_space, loss_fun)
    darts_controller.cuda()
    architect = Architect(darts_controller, cfg.OPTIM.MOMENTUM,
                          cfg.OPTIM.WEIGHT_DECAY)
    # Load dataset
    [train_, val_] = construct_loader(cfg.SEARCH.DATASET, cfg.SEARCH.SPLIT,
                                      cfg.SEARCH.BATCH_SIZE)
    # weights optimizer
    w_optim = torch.optim.SGD(darts_controller.weights(),
                              cfg.OPTIM.BASE_LR,
                              momentum=cfg.OPTIM.MOMENTUM,
                              weight_decay=cfg.OPTIM.WEIGHT_DECAY)
    # alphas optimizer
    a_optim = torch.optim.Adam(darts_controller.alphas(),
                               cfg.DARTS.ALPHA_LR,
                               betas=(0.5, 0.999),
                               weight_decay=cfg.DARTS.ALPHA_WEIGHT_DECAY)
    lr_scheduler = lr_scheduler_builder(w_optim)
    # Init meters
    train_meter = meters.TrainMeter(len(train_))
    val_meter = meters.TestMeter(len(val_))
    # Load checkpoint or initial weights
    start_epoch = 0
    if cfg.SEARCH.AUTO_RESUME and checkpoint.has_checkpoint():
        last_checkpoint = checkpoint.get_last_checkpoint()
        checkpoint_epoch = darts_load_checkpoint(last_checkpoint,
                                                 darts_controller, w_optim,
                                                 a_optim)
        logger.info("Loaded checkpoint from: {}".format(last_checkpoint))
        start_epoch = checkpoint_epoch + 1
    elif cfg.SEARCH.WEIGHTS:
        darts_load_checkpoint(cfg.SEARCH.WEIGHTS, darts_controller)
        logger.info("Loaded initial weights from: {}".format(
            cfg.SEARCH.WEIGHTS))
    # Compute model and loader timings
    if start_epoch == 0 and cfg.PREC_TIME.NUM_ITER > 0:
        benchmark.compute_time_full(darts_controller, loss_fun, train_, val_)
    # Setup timer
    train_timer = Timer()
    # Perform the training loop
    logger.info("Start epoch: {}".format(start_epoch + 1))
    train_timer.tic()
    for cur_epoch in range(start_epoch, cfg.OPTIM.MAX_EPOCH):

        lr = lr_scheduler.get_last_lr()[0]
        train_epoch(train_, val_, darts_controller, architect, loss_fun,
                    w_optim, a_optim, lr, train_meter, cur_epoch)
        # Save a checkpoint
        if (cur_epoch + 1) % cfg.SEARCH.CHECKPOINT_PERIOD == 0:
            checkpoint_file = darts_save_checkpoint(darts_controller, w_optim,
                                                    a_optim, cur_epoch)
            logger.info("Wrote checkpoint to: {}".format(checkpoint_file))
        lr_scheduler.step()
        # Evaluate the model
        next_epoch = cur_epoch + 1
        if next_epoch % cfg.SEARCH.EVAL_PERIOD == 0 or next_epoch == cfg.OPTIM.MAX_EPOCH:
            logger.info("Start testing")
            test_epoch(val_, darts_controller, val_meter, cur_epoch, writer)
            logger.info(
                "###############Optimal genotype at epoch: {}############".
                format(cur_epoch))
            logger.info(darts_controller.genotype())
            logger.info(
                "########################################################")

            if cfg.SPACE.NAME == "nasbench301":
                logger.info("Evaluating with nasbench301")
                EvaluateNasbench(darts_controller.alpha, darts_controller.net,
                                 logger, "nasbench301")

            darts_controller.print_alphas(logger)
        if torch.cuda.is_available():
            torch.cuda.synchronize()
            # https://forums.fast.ai/t/clearing-gpu-memory-pytorch/14637
            torch.cuda.empty_cache()
        gc.collect()
    train_timer.toc()
    logger.info("Overall training time (hr) is:{}".format(
        str(train_timer.total_time)))
Code example #9
def pcdarts_train_model():
    """train PC-DARTS model"""
    setup_env()
    # Loading search space
    search_space = build_space()
    # TODO: fix the complexity function
    # search_space = setup_model()
    # Init controller and architect
    loss_fun = build_loss_fun().cuda()
    pcdarts_controller = PCDartsCNNController(search_space, loss_fun)
    pcdarts_controller.cuda()
    architect = Architect(pcdarts_controller, cfg.OPTIM.MOMENTUM,
                          cfg.OPTIM.WEIGHT_DECAY)

    # Load dataset
    train_transform, valid_transform = data_transforms_cifar10(cutout_length=0)

    train_data = dset.CIFAR10(root=cfg.SEARCH.DATASET,
                              train=True,
                              download=True,
                              transform=train_transform)

    num_train = len(train_data)
    indices = list(range(num_train))
    split = int(np.floor(cfg.SEARCH.SPLIT[0] * num_train))

    train_ = torch.utils.data.DataLoader(
        train_data,
        batch_size=cfg.SEARCH.BATCH_SIZE,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[:split]),
        pin_memory=True,
        num_workers=2)
    val_ = torch.utils.data.DataLoader(
        train_data,
        batch_size=cfg.SEARCH.BATCH_SIZE,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(
            indices[split:num_train]),
        pin_memory=True,
        num_workers=2)

    # weights optimizer
    w_optim = torch.optim.SGD(pcdarts_controller.weights(),
                              cfg.OPTIM.BASE_LR,
                              momentum=cfg.OPTIM.MOMENTUM,
                              weight_decay=cfg.OPTIM.WEIGHT_DECAY)
    # alphas optimizer
    a_optim = torch.optim.Adam(pcdarts_controller.alphas(),
                               cfg.DARTS.ALPHA_LR,
                               betas=(0.5, 0.999),
                               weight_decay=cfg.DARTS.ALPHA_WEIGHT_DECAY)
    lr_scheduler = lr_scheduler_builder(w_optim)
    # Init meters
    train_meter = meters.TrainMeter(len(train_))
    val_meter = meters.TestMeter(len(val_))
    # Load checkpoint or initial weights
    start_epoch = 0
    if cfg.SEARCH.AUTO_RESUME and checkpoint.has_checkpoint():
        last_checkpoint = checkpoint.get_last_checkpoint()
        checkpoint_epoch = darts_load_checkpoint(last_checkpoint,
                                                 pcdarts_controller, w_optim,
                                                 a_optim)
        logger.info("Loaded checkpoint from: {}".format(last_checkpoint))
        start_epoch = checkpoint_epoch + 1
    elif cfg.SEARCH.WEIGHTS:
        darts_load_checkpoint(cfg.SEARCH.WEIGHTS, pcdarts_controller)
        logger.info("Loaded initial weights from: {}".format(
            cfg.SEARCH.WEIGHTS))
    # Compute model and loader timings
    if start_epoch == 0 and cfg.PREC_TIME.NUM_ITER > 0:
        benchmark.compute_time_full(pcdarts_controller, loss_fun, train_, val_)
    # Setup timer
    train_timer = Timer()
    # Perform the training loop
    logger.info("Start epoch: {}".format(start_epoch + 1))
    train_timer.tic()
    for cur_epoch in range(start_epoch, cfg.OPTIM.MAX_EPOCH):
        lr = lr_scheduler.get_last_lr()[0]
        train_epoch(train_, val_, pcdarts_controller, architect, loss_fun,
                    w_optim, a_optim, lr, train_meter, cur_epoch)
        # Save a checkpoint
        if (cur_epoch + 1) % cfg.SEARCH.CHECKPOINT_PERIOD == 0:
            checkpoint_file = darts_save_checkpoint(pcdarts_controller,
                                                    w_optim, a_optim,
                                                    cur_epoch)
            logger.info("Wrote checkpoint to: {}".format(checkpoint_file))
        lr_scheduler.step()
        # Evaluate the model
        next_epoch = cur_epoch + 1
        if next_epoch % cfg.SEARCH.EVAL_PERIOD == 0 or next_epoch == cfg.OPTIM.MAX_EPOCH:
            logger.info("Start testing")
            test_epoch(val_, pcdarts_controller, val_meter, cur_epoch, writer)
            logger.info(
                "###############Optimal genotype at epoch: {}############".
                format(cur_epoch))
            logger.info(pcdarts_controller.genotype())
            logger.info(
                "########################################################")
            pcdarts_controller.print_alphas(logger)
        if torch.cuda.is_available():
            torch.cuda.synchronize()
            # https://forums.fast.ai/t/clearing-gpu-memory-pytorch/14637
            torch.cuda.empty_cache()
        gc.collect()
    train_timer.toc()
    logger.info("Overall training time (hr) is:{}".format(
        str(train_timer.total_time)))
Code example #10
File: meters.py  Project: taoari/XNAS
class TrainMeter(object):
    """Measures training stats."""
    def __init__(self, epoch_iters):
        self.epoch_iters = epoch_iters
        self.max_iter = cfg.OPTIM.MAX_EPOCH * epoch_iters
        self.iter_timer = Timer()
        self.loss = ScalarMeter(cfg.LOG_PERIOD)
        self.loss_total = 0.0
        self.lr = None
        # Current minibatch errors (smoothed over a window)
        self.mb_top1_err = ScalarMeter(cfg.LOG_PERIOD)
        self.mb_top5_err = ScalarMeter(cfg.LOG_PERIOD)
        # Number of misclassified examples
        self.num_top1_mis = 0
        self.num_top5_mis = 0
        self.num_samples = 0

    def reset(self, timer=False):
        if timer:
            self.iter_timer.reset()
        self.loss.reset()
        self.loss_total = 0.0
        self.lr = None
        self.mb_top1_err.reset()
        self.mb_top5_err.reset()
        self.num_top1_mis = 0
        self.num_top5_mis = 0
        self.num_samples = 0

    def iter_tic(self):
        self.iter_timer.tic()

    def iter_toc(self):
        self.iter_timer.toc()

    def update_stats(self, top1_err, top5_err, loss, lr, mb_size):
        # Current minibatch stats
        self.mb_top1_err.add_value(top1_err)
        self.mb_top5_err.add_value(top5_err)
        self.loss.add_value(loss)
        self.lr = lr
        # Aggregate stats
        self.num_top1_mis += top1_err * mb_size
        self.num_top5_mis += top5_err * mb_size
        self.loss_total += loss * mb_size
        self.num_samples += mb_size

    def get_iter_stats(self, cur_epoch, cur_iter):
        cur_iter_total = cur_epoch * self.epoch_iters + cur_iter + 1
        eta_sec = self.iter_timer.average_time * (self.max_iter -
                                                  cur_iter_total)
        mem_usage = gpu_mem_usage()
        stats = {
            "epoch": "{}/{}".format(cur_epoch + 1, cfg.OPTIM.MAX_EPOCH),
            "iter": "{}/{}".format(cur_iter + 1, self.epoch_iters),
            "time_avg": self.iter_timer.average_time,
            "time_diff": self.iter_timer.diff,
            "eta": time_string(eta_sec),
            "top1_err": self.mb_top1_err.get_win_median(),
            "top5_err": self.mb_top5_err.get_win_median(),
            "loss": self.loss.get_win_median(),
            "lr": self.lr,
            "mem": int(np.ceil(mem_usage)),
        }
        return stats

    def log_iter_stats(self, cur_epoch, cur_iter):
        if (cur_iter + 1) % cfg.LOG_PERIOD != 0:
            return
        stats = self.get_iter_stats(cur_epoch, cur_iter)
        logger.info(logging.dump_log_data(stats, "train_iter"))

    def get_epoch_stats(self, cur_epoch):
        cur_iter_total = (cur_epoch + 1) * self.epoch_iters
        eta_sec = self.iter_timer.average_time * (self.max_iter -
                                                  cur_iter_total)
        mem_usage = gpu_mem_usage()
        top1_err = self.num_top1_mis / self.num_samples
        top5_err = self.num_top5_mis / self.num_samples
        avg_loss = self.loss_total / self.num_samples
        stats = {
            "epoch": "{}/{}".format(cur_epoch + 1, cfg.OPTIM.MAX_EPOCH),
            "time_avg": self.iter_timer.average_time,
            "eta": time_string(eta_sec),
            "top1_err": top1_err,
            "top5_err": top5_err,
            "loss": avg_loss,
            "lr": self.lr,
            "mem": int(np.ceil(mem_usage)),
        }
        return stats

    def log_epoch_stats(self, cur_epoch):
        stats = self.get_epoch_stats(cur_epoch)
        logger.info(logging.dump_log_data(stats, "train_epoch"))
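The meter is driven from outside by a training loop. A hypothetical driver is sketched below; train_loader, cur_epoch, top1_err, top5_err, loss, and lr are placeholders and the forward/backward step is elided.

train_meter = TrainMeter(epoch_iters=len(train_loader))
for cur_iter, (inputs, labels) in enumerate(train_loader):
    train_meter.iter_tic()
    # ... forward pass, loss, backward pass, optimizer step ...
    train_meter.iter_toc()
    train_meter.update_stats(top1_err, top5_err, loss, lr, inputs.size(0))
    train_meter.log_iter_stats(cur_epoch, cur_iter)
train_meter.log_epoch_stats(cur_epoch)
train_meter.reset()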
Code example #11
File: meters.py  Project: taoari/XNAS
class TestMeter(object):
    """Measures testing stats."""
    def __init__(self, max_iter):
        self.max_iter = max_iter
        self.iter_timer = Timer()
        # Current minibatch errors (smoothed over a window)
        self.mb_top1_err = ScalarMeter(cfg.LOG_PERIOD)
        self.mb_top5_err = ScalarMeter(cfg.LOG_PERIOD)
        # Min errors (over the full test set)
        self.min_top1_err = 100.0
        self.min_top5_err = 100.0
        # Number of misclassified examples
        self.num_top1_mis = 0
        self.num_top5_mis = 0
        self.num_samples = 0

    def reset(self, min_errs=False):
        if min_errs:
            self.min_top1_err = 100.0
            self.min_top5_err = 100.0
        self.iter_timer.reset()
        self.mb_top1_err.reset()
        self.mb_top5_err.reset()
        self.num_top1_mis = 0
        self.num_top5_mis = 0
        self.num_samples = 0

    def iter_tic(self):
        self.iter_timer.tic()

    def iter_toc(self):
        self.iter_timer.toc()

    def update_stats(self, top1_err, top5_err, mb_size):
        self.mb_top1_err.add_value(top1_err)
        self.mb_top5_err.add_value(top5_err)
        self.num_top1_mis += top1_err * mb_size
        self.num_top5_mis += top5_err * mb_size
        self.num_samples += mb_size

    def get_iter_stats(self, cur_epoch, cur_iter):
        mem_usage = gpu_mem_usage()
        iter_stats = {
            "epoch": "{}/{}".format(cur_epoch + 1, cfg.OPTIM.MAX_EPOCH),
            "iter": "{}/{}".format(cur_iter + 1, self.max_iter),
            "time_avg": self.iter_timer.average_time,
            "time_diff": self.iter_timer.diff,
            "top1_err": self.mb_top1_err.get_win_median(),
            "top5_err": self.mb_top5_err.get_win_median(),
            "mem": int(np.ceil(mem_usage)),
        }
        return iter_stats

    def log_iter_stats(self, cur_epoch, cur_iter):
        if (cur_iter + 1) % cfg.LOG_PERIOD != 0:
            return
        stats = self.get_iter_stats(cur_epoch, cur_iter)
        logger.info(logging.dump_log_data(stats, "test_iter"))

    def get_epoch_stats(self, cur_epoch):
        top1_err = self.num_top1_mis / self.num_samples
        top5_err = self.num_top5_mis / self.num_samples
        self.min_top1_err = min(self.min_top1_err, top1_err)
        self.min_top5_err = min(self.min_top5_err, top5_err)
        mem_usage = gpu_mem_usage()
        stats = {
            "epoch": "{}/{}".format(cur_epoch + 1, cfg.OPTIM.MAX_EPOCH),
            "time_avg": self.iter_timer.average_time,
            "top1_err": top1_err,
            "top5_err": top5_err,
            "min_top1_err": self.min_top1_err,
            "min_top5_err": self.min_top5_err,
            "mem": int(np.ceil(mem_usage)),
        }
        return stats

    def log_epoch_stats(self, cur_epoch):
        stats = self.get_epoch_stats(cur_epoch)
        logger.info(logging.dump_log_data(stats, "test_epoch"))
Code example #12
File: sng_search.py  Project: dercaft/XNAS
def train_model():
    """SNG search model training"""
    setup_env()
    # Load search space
    search_space = build_space()
    search_space.cuda()
    loss_fun = build_loss_fun().cuda()

    # Weights optimizer
    w_optim = torch.optim.SGD(search_space.parameters(),
                              cfg.OPTIM.BASE_LR,
                              momentum=cfg.OPTIM.MOMENTUM,
                              weight_decay=cfg.OPTIM.WEIGHT_DECAY)

    # Build distribution_optimizer
    if cfg.SPACE.NAME in ['darts', 'nasbench301']:
        distribution_optimizer = sng_builder([search_space.num_ops] *
                                             search_space.all_edges)
    elif cfg.SPACE.NAME in ['proxyless', 'google', 'ofa']:
        distribution_optimizer = sng_builder([search_space.num_ops] *
                                             search_space.all_edges)
    elif cfg.SPACE.NAME in [
            "nasbench1shot1_1", "nasbench1shot1_2", "nasbench1shot1_3"
    ]:
        category = []
        cs = search_space.search_space.get_configuration_space()
        for h in cs.get_hyperparameters():
            if type(h) == ConfigSpace.hyperparameters.CategoricalHyperparameter:
                category.append(len(h.choices))
        distribution_optimizer = sng_builder(category)
    else:
        raise NotImplementedError

    # Load dataset
    [train_, val_] = construct_loader(cfg.SEARCH.DATASET, cfg.SEARCH.SPLIT,
                                      cfg.SEARCH.BATCH_SIZE)

    lr_scheduler = lr_scheduler_builder(w_optim)
    all_timer = Timer()
    _over_all_epoch = 0

    # ===== Warm up training =====
    logger.info("start warm up training")
    warm_train_meter = meters.TrainMeter(len(train_))
    warm_val_meter = meters.TestMeter(len(val_))
    all_timer.tic()
    for cur_epoch in range(cfg.OPTIM.WARMUP_EPOCHS):

        # Save a checkpoint
        if (_over_all_epoch + 1) % cfg.SEARCH.CHECKPOINT_PERIOD == 0:
            checkpoint_file = checkpoint.save_checkpoint(
                search_space, w_optim, _over_all_epoch)
            logger.info("Wrote checkpoint to: {}".format(checkpoint_file))

        lr = lr_scheduler.get_last_lr()[0]
        if cfg.SNG.WARMUP_RANDOM_SAMPLE:
            sample = random_sampling(search_space,
                                     distribution_optimizer,
                                     epoch=cur_epoch)
            logger.info("Sampling: {}".format(one_hot_to_index(sample)))
            train_epoch(train_, val_, search_space, w_optim, lr,
                        _over_all_epoch, sample, loss_fun, warm_train_meter)
            top1 = test_epoch_with_sample(val_, search_space, warm_val_meter,
                                          _over_all_epoch, sample, writer)
            _over_all_epoch += 1
        else:
            num_ops, total_edges = search_space.num_ops, search_space.all_edges
            array_sample = [
                random.sample(list(range(num_ops)), num_ops)
                for i in range(total_edges)
            ]
            array_sample = np.array(array_sample)
            for i in range(num_ops):
                sample = np.transpose(array_sample[:, i])
                sample = index_to_one_hot(sample,
                                          distribution_optimizer.p_model.Cmax)
                train_epoch(train_, val_, search_space, w_optim, lr,
                            _over_all_epoch, sample, loss_fun,
                            warm_train_meter)
                top1 = test_epoch_with_sample(val_, search_space,
                                              warm_val_meter, _over_all_epoch,
                                              sample, writer)
                _over_all_epoch += 1
    all_timer.toc()
    logger.info("end warm up training")

    # ===== Training procedure =====
    logger.info("start one-shot training")
    train_meter = meters.TrainMeter(len(train_))
    val_meter = meters.TestMeter(len(val_))
    all_timer.tic()
    for cur_epoch in range(cfg.OPTIM.MAX_EPOCH):

        # Save a checkpoint
        if (_over_all_epoch + 1) % cfg.SEARCH.CHECKPOINT_PERIOD == 0:
            checkpoint_file = checkpoint.save_checkpoint(
                search_space, w_optim, _over_all_epoch)
            logger.info("Wrote checkpoint to: {}".format(checkpoint_file))

        if hasattr(distribution_optimizer, 'training_finish'):
            if distribution_optimizer.training_finish:
                break
        lr = w_optim.param_groups[0]['lr']

        sample = random_sampling(search_space,
                                 distribution_optimizer,
                                 epoch=cur_epoch,
                                 _random=cfg.SNG.RANDOM_SAMPLE)
        logger.info("Sampling: {}".format(one_hot_to_index(sample)))
        train_epoch(train_, val_, search_space, w_optim, lr, _over_all_epoch,
                    sample, loss_fun, train_meter)
        top1 = test_epoch_with_sample(val_, search_space, val_meter,
                                      _over_all_epoch, sample, writer)
        _over_all_epoch += 1

        lr_scheduler.step()
        distribution_optimizer.record_information(sample, top1)
        distribution_optimizer.update()

        # Evaluate the model
        next_epoch = cur_epoch + 1
        if next_epoch % cfg.SEARCH.EVAL_PERIOD == 0 or next_epoch == cfg.OPTIM.MAX_EPOCH:
            logger.info("Start testing")
            logger.info(
                "###############Optimal genotype at epoch: {}############".
                format(cur_epoch))
            logger.info(
                search_space.genotype(distribution_optimizer.p_model.theta))
            logger.info(
                "########################################################")
            logger.info("####### ALPHA #######")
            for alpha in distribution_optimizer.p_model.theta:
                logger.info(alpha)
            logger.info("#####################")
        if torch.cuda.is_available():
            torch.cuda.synchronize()
            # https://forums.fast.ai/t/clearing-gpu-memory-pytorch/14637
            torch.cuda.empty_cache()
        gc.collect()
    all_timer.toc()

    # ===== Final epoch =====
    lr = w_optim.param_groups[0]['lr']
    all_timer.tic()
    for cur_epoch in range(cfg.OPTIM.FINAL_EPOCH):

        # Save a checkpoint
        if (_over_all_epoch + 1) % cfg.SEARCH.CHECKPOINT_PERIOD == 0:
            checkpoint_file = checkpoint.save_checkpoint(
                search_space, w_optim, _over_all_epoch)
            logger.info("Wrote checkpoint to: {}".format(checkpoint_file))

        if cfg.SPACE.NAME in ['darts', 'nasbench301']:
            genotype = search_space.genotype(
                distribution_optimizer.p_model.theta)
            sample = search_space.genotype_to_onehot_sample(genotype)
        else:
            sample = distribution_optimizer.sampling_best()
        _over_all_epoch += 1
        train_epoch(train_, val_, search_space, w_optim, lr, _over_all_epoch,
                    sample, loss_fun, train_meter)
        test_epoch_with_sample(val_, search_space, val_meter, _over_all_epoch,
                               sample, writer)
    logger.info("Overall training time : {} hours".format(
        str((all_timer.total_time) / 3600.)))

    # Evaluate through nasbench
    if cfg.SPACE.NAME in [
            "nasbench1shot1_1", "nasbench1shot1_2", "nasbench1shot1_3",
            "nasbench201", "nasbench301"
    ]:
        logger.info("starting test using nasbench:{}".format(cfg.SPACE.NAME))
        theta = distribution_optimizer.p_model.theta
        EvaluateNasbench(theta, search_space, logger, cfg.SPACE.NAME)
Code example #13
def run(M=10,
        N=10,
        func='rastrigin',
        optimizer_name='SNG',
        running_times=500,
        running_epochs=200,
        step=4,
        gamma=0.9,
        save_dir=None,
        noise=0.0,
        sample_with_prob=True,
        utility_function='log',
        utility_function_hyper=0.4):
    category = [M] * N
    epoc_fun = 'linear'
    test_fun = EpochSumCategoryTestFunction(category,
                                            epoch_func=epoc_fun,
                                            func=func,
                                            noise_std=noise)

    distribution_optimizer = get_optimizer(
        optimizer_name,
        category,
        step=step,
        gamma=gamma,
        sample_with_prob=sample_with_prob,
        utility_function=utility_function,
        utility_function_hyper=utility_function_hyper)
    file_name = '{}_{}_{}_{}_{}_{}_{}_{}_{}_{}_{}_{}.npz'.format(
        optimizer_name, str(N), str(M), str(running_epochs), epoc_fun, func,
        str(step), str(gamma), str(noise), str(sample_with_prob),
        utility_function, str(utility_function_hyper))

    file_name = os.path.join(save_dir, file_name)

    record = {
        'objective': np.zeros([running_times, running_epochs]) - 1,
        'l2_distance': np.zeros([running_times, running_epochs]) - 1
    }
    last_l2_distance = np.zeros([running_times])
    running_time_interval = np.zeros([running_times, running_epochs])
    _distance = 100
    run_timer = Timer()
    for i in tqdm.tqdm(range(running_times)):
        for j in range(running_epochs):
            run_timer.tic()
            if hasattr(distribution_optimizer,
                       'training_finish') or j == (running_epochs - 1):
                last_l2_distance[i] = _distance
            if hasattr(distribution_optimizer, 'training_finish'):
                if distribution_optimizer.training_finish:
                    break
            sample = distribution_optimizer.sampling()
            objective = test_fun.objective_function(sample)
            distribution_optimizer.record_information(sample, objective)
            distribution_optimizer.update()

            current_best = np.argmax(distribution_optimizer.p_model.theta,
                                     axis=1)
            _distance = test_fun.l2_distance(current_best)
            record['objective'][i, j] = objective
            record['l2_distance'][i, j] = _distance
            run_timer.toc()
            running_time_interval[i, j] = run_timer.diff
        test_fun.re_new()
        del distribution_optimizer
        # print(_distance)
        distribution_optimizer = get_optimizer(
            optimizer_name,
            category,
            step=step,
            gamma=gamma,
            sample_with_prob=sample_with_prob,
            utility_function=utility_function,
            utility_function_hyper=utility_function_hyper)
    np.savez(file_name, record['l2_distance'], running_time_interval)
    return distribution_optimizer
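For orientation, a hypothetical call to this benchmark could look like the following; the argument values are illustrative and save_dir must already exist, since the function only joins the path and does not create the directory.

distribution_optimizer = run(M=10, N=10, func='rastrigin', optimizer_name='SNG',
                             running_times=5, running_epochs=50,
                             save_dir='./results')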