def run(space=1, optimizer_name='SNG', budget=108, runing_times=500,
        runing_epochs=200, step=4, gamma=0.9, save_dir=None, nasbench=None,
        noise=0.0, sample_with_prob=True, utility_function='log',
        utility_function_hyper=0.4):
    """Run the distribution optimizer search on a NAS-Bench search space."""
    print('##### Search Space {} #####'.format(space))
    search_space = eval('SearchSpace{}()'.format(space))
    cat_variables = []
    cs = search_space.get_configuration_space()
    for h in cs.get_hyperparameters():
        if type(h) == ConfigSpace.hyperparameters.CategoricalHyperparameter:
            cat_variables.append(len(h.choices))
    # Get the category sizes from the categorical hyperparameters
    category = cat_variables
    distribution_optimizer = get_optimizer(optimizer_name, category, step=step,
                                           gamma=gamma,
                                           sample_with_prob=sample_with_prob,
                                           utility_function=utility_function,
                                           utility_function_hyper=utility_function_hyper)
    # Path to save the test accuracy
    file_name = '{}_{}_{}_{}_{}_{}_{}_{}_{}.npz'.format(
        optimizer_name, str(space), str(runing_epochs), str(step), str(gamma),
        str(noise), str(sample_with_prob), utility_function,
        str(utility_function_hyper))
    file_name = os.path.join(save_dir, file_name)
    nb_reward = Reward(search_space, nasbench, budget)
    record = {
        'validation_accuracy': np.zeros([runing_times, runing_epochs]) - 1,
        'test_accuracy': np.zeros([runing_times, runing_epochs]) - 1,
    }
    last_test_accuracy = np.zeros([runing_times])
    running_time_interval = np.zeros([runing_times, runing_epochs])
    test_accuracy = 0
    run_timer = Timer()
    for i in tqdm.tqdm(range(runing_times)):
        for j in range(runing_epochs):
            run_timer.tic()
            # Record the last test accuracy at the final epoch or when the
            # optimizer reports that training is finished
            if hasattr(distribution_optimizer, 'training_finish') or j == (runing_epochs - 1):
                last_test_accuracy[i] = test_accuracy
            if hasattr(distribution_optimizer, 'training_finish'):
                if distribution_optimizer.training_finish:
                    break
            sample = distribution_optimizer.sampling()
            sample_index = one_hot_to_index(np.array(sample))
            validation_accuracy = nb_reward.compute_reward(sample_index)
            distribution_optimizer.record_information(sample, validation_accuracy)
            distribution_optimizer.update()
            current_best = np.argmax(distribution_optimizer.p_model.theta, axis=1)
            test_accuracy = nb_reward.get_accuracy(current_best)
            record['validation_accuracy'][i, j] = validation_accuracy
            record['test_accuracy'][i, j] = test_accuracy
            run_timer.toc()
            running_time_interval[i, j] = run_timer.diff
        # Re-initialize the distribution optimizer for the next independent run
        del distribution_optimizer
        distribution_optimizer = get_optimizer(optimizer_name, category, step=step,
                                               gamma=gamma,
                                               sample_with_prob=sample_with_prob,
                                               utility_function=utility_function,
                                               utility_function_hyper=utility_function_hyper)
    np.savez(file_name, record['test_accuracy'], running_time_interval)
    return distribution_optimizer
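# Hedged sketch (not part of the original file): minimal NumPy versions of the
# one_hot_to_index / index_to_one_hot helpers that run() and the training code
# below rely on, assuming a sample is a [num_variables, num_categories] one-hot
# matrix. The repo's real helpers may differ in signature or behaviour.
import numpy as np

def _one_hot_to_index_sketch(one_hot_sample):
    # Argmax along the category axis recovers the chosen index per variable.
    return np.argmax(np.asarray(one_hot_sample), axis=1)

def _index_to_one_hot_sketch(indices, num_categories):
    # Row i is the one-hot encoding of indices[i].
    return np.eye(num_categories)[np.asarray(indices)]

# Example: _index_to_one_hot_sketch([2, 0, 3], 4) yields a 3x4 one-hot matrix,
# and _one_hot_to_index_sketch of that matrix gives back array([2, 0, 3]).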
def compute_time_loader(data_loader):
    """Computes loader time."""
    timer = Timer()
    data_loader_iterator = iter(data_loader)
    total_iter = cfg.PREC_TIME.NUM_ITER + cfg.PREC_TIME.WARMUP_ITER
    total_iter = min(total_iter, len(data_loader))
    for cur_iter in range(total_iter):
        # Reset the timer after the warmup phase
        if cur_iter == cfg.PREC_TIME.WARMUP_ITER:
            timer.reset()
        timer.tic()
        next(data_loader_iterator)
        timer.toc()
    return timer.average_time
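# Hedged usage sketch (assumption, not from this repo): the same
# warmup-then-measure loader timing idea as compute_time_loader, exercised with
# a synthetic TensorDataset and time.perf_counter instead of the repo's
# cfg/Timer objects.
import time
import torch
from torch.utils.data import DataLoader, TensorDataset

def _time_loader_sketch(loader, warmup_iters=5, measure_iters=20):
    it = iter(loader)
    total_iter = min(warmup_iters + measure_iters, len(loader))
    times = []
    for cur_iter in range(total_iter):
        start = time.perf_counter()
        next(it)
        elapsed = time.perf_counter() - start
        if cur_iter >= warmup_iters:  # discard warmup iterations
            times.append(elapsed)
    return sum(times) / max(len(times), 1)

# Example (CPU only):
#   data = TensorDataset(torch.randn(512, 3, 32, 32), torch.zeros(512, dtype=torch.int64))
#   _time_loader_sketch(DataLoader(data, batch_size=64))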
def compute_full_loader(data_loader, epoch=1):
    """Computes full loader time."""
    timer = Timer()
    epoch_avg = []
    data_loader_len = len(data_loader)
    for j in range(epoch):
        timer.tic()
        for i, (inputs, labels) in enumerate(data_loader):
            inputs = inputs.cuda()
            labels = labels.cuda()
            timer.toc()
            logger.info(
                "Epoch {}/{}, Iter {}/{}: Dataloader time is: {}".format(
                    j + 1, epoch, i + 1, data_loader_len, timer.diff))
            timer.tic()
        epoch_avg.append(timer.average_time)
        timer.reset()
    return epoch_avg
def compute_time_train(model, loss_fun):
    """Computes precise model forward + backward time using dummy data."""
    # Use train mode
    model.train()
    # Generate a dummy mini-batch and copy data to GPU
    # NOTE: use the cfg.SEARCH settings instead of cfg.TRAIN
    # im_size, batch_size = cfg.TRAIN.IM_SIZE, int(cfg.TRAIN.BATCH_SIZE / cfg.NUM_GPUS)
    im_size, batch_size = cfg.SEARCH.IM_SIZE, int(cfg.SEARCH.BATCH_SIZE / cfg.NUM_GPUS)
    inputs = torch.rand(batch_size, 3, im_size, im_size).cuda(non_blocking=False)
    labels = torch.zeros(batch_size, dtype=torch.int64).cuda(non_blocking=False)
    # Cache BatchNorm2D running stats so the timing pass does not disturb them
    bns = [m for m in model.modules() if isinstance(m, torch.nn.BatchNorm2d)]
    bn_stats = [[bn.running_mean.clone(), bn.running_var.clone()] for bn in bns]
    # Compute precise forward + backward pass time
    fw_timer, bw_timer = Timer(), Timer()
    total_iter = cfg.PREC_TIME.NUM_ITER + cfg.PREC_TIME.WARMUP_ITER
    for cur_iter in range(total_iter):
        # Reset the timers after the warmup phase
        if cur_iter == cfg.PREC_TIME.WARMUP_ITER:
            fw_timer.reset()
            bw_timer.reset()
        # Forward
        fw_timer.tic()
        preds = model(inputs)
        loss = loss_fun(preds, labels)
        torch.cuda.synchronize()
        fw_timer.toc()
        # Backward
        bw_timer.tic()
        loss.backward()
        torch.cuda.synchronize()
        bw_timer.toc()
    # Restore BatchNorm2D running stats
    for bn, (mean, var) in zip(bns, bn_stats):
        bn.running_mean, bn.running_var = mean, var
    return fw_timer.average_time, bw_timer.average_time
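# Hedged sketch (assumption): why compute_time_train snapshots and restores the
# BatchNorm2d buffers -- a forward pass in train mode updates running_mean and
# running_var, so timing would otherwise perturb the model's statistics. This
# toy round-trip is self-contained and not part of the repo.
import torch

def _bn_stats_roundtrip_sketch():
    model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.BatchNorm2d(8)).train()
    bns = [m for m in model.modules() if isinstance(m, torch.nn.BatchNorm2d)]
    saved = [(bn.running_mean.clone(), bn.running_var.clone()) for bn in bns]
    model(torch.randn(4, 3, 16, 16))  # train-mode forward mutates the BN buffers
    changed = any(not torch.equal(bn.running_mean, m) for bn, (m, _) in zip(bns, saved))
    for bn, (mean, var) in zip(bns, saved):
        bn.running_mean, bn.running_var = mean, var  # restore the snapshot
    restored = all(torch.equal(bn.running_mean, m) for bn, (m, _) in zip(bns, saved))
    return changed, restored  # expected: (True, True)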
def compute_time_eval(model):
    """Computes precise model forward test time using dummy data."""
    # Use eval mode
    model.eval()
    # Generate a dummy mini-batch and copy data to GPU
    im_size, batch_size = cfg.TRAIN.IM_SIZE, int(cfg.TEST.BATCH_SIZE / cfg.NUM_GPUS)
    inputs = torch.zeros(batch_size, 3, im_size, im_size).cuda(non_blocking=False)
    # Compute precise forward pass time
    timer = Timer()
    total_iter = cfg.PREC_TIME.NUM_ITER + cfg.PREC_TIME.WARMUP_ITER
    for cur_iter in range(total_iter):
        # Reset the timer after the warmup phase
        if cur_iter == cfg.PREC_TIME.WARMUP_ITER:
            timer.reset()
        # Forward
        timer.tic()
        model(inputs)
        torch.cuda.synchronize()
        timer.toc()
    return timer.average_time
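# Hedged sketch (assumption): the warmup + torch.cuda.synchronize() timing
# pattern used by compute_time_eval, demonstrated on a throwaway model so it
# runs without the repo's cfg; torch.no_grad() is added here since only
# inference latency is of interest.
import time
import torch

def _forward_timing_sketch(warmup_iters=5, measure_iters=20):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = torch.nn.Sequential(
        torch.nn.Conv2d(3, 16, 3, padding=1),
        torch.nn.ReLU(),
        torch.nn.AdaptiveAvgPool2d(1),
        torch.nn.Flatten(),
        torch.nn.Linear(16, 10),
    ).to(device).eval()
    inputs = torch.zeros(8, 3, 32, 32, device=device)
    times = []
    with torch.no_grad():
        for cur_iter in range(warmup_iters + measure_iters):
            start = time.perf_counter()
            model(inputs)
            if device == "cuda":
                torch.cuda.synchronize()  # wait for all GPU kernels to finish
            if cur_iter >= warmup_iters:  # discard warmup iterations
                times.append(time.perf_counter() - start)
    return sum(times) / max(len(times), 1)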
def basic_darts_cnn_test():
    # DARTS CNN test
    time_ = Timer()
    print("Testing darts CNN")
    search_net = DartsCNN().cuda()
    # Forward with continuous (random) architecture weights
    _random_architecture_weight = torch.randn(
        [search_net.num_edges * 2, len(search_net.basic_op_list)]).cuda()
    _input = torch.randn([2, 3, 32, 32]).cuda()
    time_.tic()
    _out_put = search_net(_input, _random_architecture_weight)
    time_.toc()
    print(_out_put.shape)
    print(time_.average_time)
    time_.reset()
    # Forward with a random one-hot architecture
    _random_one_hot = torch.Tensor(np.eye(len(search_net.basic_op_list))[
        np.random.choice(len(search_net.basic_op_list), search_net.num_edges * 2)]).cuda()
    _input = torch.randn([2, 3, 32, 32]).cuda()
    time_.tic()
    _out_put = search_net(_input, _random_one_hot)
    time_.toc()
    print(_out_put.shape)
    print(time_.average_time)
def basic_nas_bench_201_cnn():
    # NAS-Bench-201 CNN test
    time_ = Timer()
    print("Testing nas bench 201 CNN")
    search_net = NASBench201CNN()
    # Forward with continuous (random) architecture weights
    _random_architecture_weight = torch.randn(
        [search_net.num_edges, len(search_net.basic_op_list)])
    _input = torch.randn([2, 3, 32, 32])
    time_.tic()
    _out_put = search_net(_input, _random_architecture_weight)
    time_.toc()
    print(_out_put.shape)
    print(time_.average_time)
    time_.reset()
    # Forward with a random one-hot architecture
    _random_one_hot = torch.Tensor(np.eye(len(search_net.basic_op_list))[
        np.random.choice(len(search_net.basic_op_list), search_net.num_edges)])
    _input = torch.randn([2, 3, 32, 32])
    time_.tic()
    _out_put = search_net(_input, _random_one_hot)
    time_.toc()
    print(_out_put.shape)
    print(time_.average_time)
def darts_train_model():
    """Train a DARTS model."""
    setup_env()
    # Load the search space
    search_space = build_space()
    # TODO: fix the complexity function
    # search_space = setup_model()
    # Init controller and architect
    loss_fun = build_loss_fun().cuda()
    darts_controller = DartsCNNController(search_space, loss_fun)
    darts_controller.cuda()
    architect = Architect(darts_controller, cfg.OPTIM.MOMENTUM, cfg.OPTIM.WEIGHT_DECAY)
    # Load dataset
    [train_, val_] = construct_loader(cfg.SEARCH.DATASET, cfg.SEARCH.SPLIT, cfg.SEARCH.BATCH_SIZE)
    # Weights optimizer
    w_optim = torch.optim.SGD(darts_controller.weights(),
                              cfg.OPTIM.BASE_LR,
                              momentum=cfg.OPTIM.MOMENTUM,
                              weight_decay=cfg.OPTIM.WEIGHT_DECAY)
    # Alphas optimizer
    a_optim = torch.optim.Adam(darts_controller.alphas(),
                               cfg.DARTS.ALPHA_LR,
                               betas=(0.5, 0.999),
                               weight_decay=cfg.DARTS.ALPHA_WEIGHT_DECAY)
    lr_scheduler = lr_scheduler_builder(w_optim)
    # Init meters
    train_meter = meters.TrainMeter(len(train_))
    val_meter = meters.TestMeter(len(val_))
    # Load checkpoint or initial weights
    start_epoch = 0
    if cfg.SEARCH.AUTO_RESUME and checkpoint.has_checkpoint():
        last_checkpoint = checkpoint.get_last_checkpoint()
        checkpoint_epoch = darts_load_checkpoint(last_checkpoint, darts_controller, w_optim, a_optim)
        logger.info("Loaded checkpoint from: {}".format(last_checkpoint))
        start_epoch = checkpoint_epoch + 1
    elif cfg.SEARCH.WEIGHTS:
        darts_load_checkpoint(cfg.SEARCH.WEIGHTS, darts_controller)
        logger.info("Loaded initial weights from: {}".format(cfg.SEARCH.WEIGHTS))
    # Compute model and loader timings
    if start_epoch == 0 and cfg.PREC_TIME.NUM_ITER > 0:
        benchmark.compute_time_full(darts_controller, loss_fun, train_, val_)
    # Setup timer
    train_timer = Timer()
    # Perform the training loop
    logger.info("Start epoch: {}".format(start_epoch + 1))
    train_timer.tic()
    for cur_epoch in range(start_epoch, cfg.OPTIM.MAX_EPOCH):
        lr = lr_scheduler.get_last_lr()[0]
        train_epoch(train_, val_, darts_controller, architect, loss_fun,
                    w_optim, a_optim, lr, train_meter, cur_epoch)
        # Save a checkpoint
        if (cur_epoch + 1) % cfg.SEARCH.CHECKPOINT_PERIOD == 0:
            checkpoint_file = darts_save_checkpoint(darts_controller, w_optim, a_optim, cur_epoch)
            logger.info("Wrote checkpoint to: {}".format(checkpoint_file))
        lr_scheduler.step()
        # Evaluate the model
        next_epoch = cur_epoch + 1
        if next_epoch % cfg.SEARCH.EVAL_PERIOD == 0 or next_epoch == cfg.OPTIM.MAX_EPOCH:
            logger.info("Start testing")
            test_epoch(val_, darts_controller, val_meter, cur_epoch, writer)
            logger.info("###############Optimal genotype at epoch: {}############".format(cur_epoch))
            logger.info(darts_controller.genotype())
            logger.info("########################################################")
            if cfg.SPACE.NAME == "nasbench301":
                logger.info("Evaluating with nasbench301")
                EvaluateNasbench(darts_controller.alpha, darts_controller.net, logger, "nasbench301")
            darts_controller.print_alphas(logger)
        if torch.cuda.is_available():
            torch.cuda.synchronize()
            torch.cuda.empty_cache()  # https://forums.fast.ai/t/clearing-gpu-memory-pytorch/14637
        gc.collect()
    train_timer.toc()
    logger.info("Overall training time is: {} seconds".format(str(train_timer.total_time)))
def pcdarts_train_model():
    """Train a PC-DARTS model."""
    setup_env()
    # Load the search space
    search_space = build_space()
    # TODO: fix the complexity function
    # search_space = setup_model()
    # Init controller and architect
    loss_fun = build_loss_fun().cuda()
    pcdarts_controller = PCDartsCNNController(search_space, loss_fun)
    pcdarts_controller.cuda()
    architect = Architect(pcdarts_controller, cfg.OPTIM.MOMENTUM, cfg.OPTIM.WEIGHT_DECAY)
    # Load dataset
    train_transform, valid_transform = data_transforms_cifar10(cutout_length=0)
    train_data = dset.CIFAR10(root=cfg.SEARCH.DATASET, train=True,
                              download=True, transform=train_transform)
    num_train = len(train_data)
    indices = list(range(num_train))
    split = int(np.floor(cfg.SEARCH.SPLIT[0] * num_train))
    train_ = torch.utils.data.DataLoader(
        train_data,
        batch_size=cfg.SEARCH.BATCH_SIZE,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[:split]),
        pin_memory=True,
        num_workers=2)
    val_ = torch.utils.data.DataLoader(
        train_data,
        batch_size=cfg.SEARCH.BATCH_SIZE,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[split:num_train]),
        pin_memory=True,
        num_workers=2)
    # Weights optimizer
    w_optim = torch.optim.SGD(pcdarts_controller.weights(),
                              cfg.OPTIM.BASE_LR,
                              momentum=cfg.OPTIM.MOMENTUM,
                              weight_decay=cfg.OPTIM.WEIGHT_DECAY)
    # Alphas optimizer
    a_optim = torch.optim.Adam(pcdarts_controller.alphas(),
                               cfg.DARTS.ALPHA_LR,
                               betas=(0.5, 0.999),
                               weight_decay=cfg.DARTS.ALPHA_WEIGHT_DECAY)
    lr_scheduler = lr_scheduler_builder(w_optim)
    # Init meters
    train_meter = meters.TrainMeter(len(train_))
    val_meter = meters.TestMeter(len(val_))
    # Load checkpoint or initial weights
    start_epoch = 0
    if cfg.SEARCH.AUTO_RESUME and checkpoint.has_checkpoint():
        last_checkpoint = checkpoint.get_last_checkpoint()
        checkpoint_epoch = darts_load_checkpoint(last_checkpoint, pcdarts_controller, w_optim, a_optim)
        logger.info("Loaded checkpoint from: {}".format(last_checkpoint))
        start_epoch = checkpoint_epoch + 1
    elif cfg.SEARCH.WEIGHTS:
        darts_load_checkpoint(cfg.SEARCH.WEIGHTS, pcdarts_controller)
        logger.info("Loaded initial weights from: {}".format(cfg.SEARCH.WEIGHTS))
    # Compute model and loader timings
    if start_epoch == 0 and cfg.PREC_TIME.NUM_ITER > 0:
        benchmark.compute_time_full(pcdarts_controller, loss_fun, train_, val_)
    # Setup timer
    train_timer = Timer()
    # Perform the training loop
    logger.info("Start epoch: {}".format(start_epoch + 1))
    train_timer.tic()
    for cur_epoch in range(start_epoch, cfg.OPTIM.MAX_EPOCH):
        lr = lr_scheduler.get_last_lr()[0]
        train_epoch(train_, val_, pcdarts_controller, architect, loss_fun,
                    w_optim, a_optim, lr, train_meter, cur_epoch)
        # Save a checkpoint
        if (cur_epoch + 1) % cfg.SEARCH.CHECKPOINT_PERIOD == 0:
            checkpoint_file = darts_save_checkpoint(pcdarts_controller, w_optim, a_optim, cur_epoch)
            logger.info("Wrote checkpoint to: {}".format(checkpoint_file))
        lr_scheduler.step()
        # Evaluate the model
        next_epoch = cur_epoch + 1
        if next_epoch % cfg.SEARCH.EVAL_PERIOD == 0 or next_epoch == cfg.OPTIM.MAX_EPOCH:
            logger.info("Start testing")
            test_epoch(val_, pcdarts_controller, val_meter, cur_epoch, writer)
            logger.info("###############Optimal genotype at epoch: {}############".format(cur_epoch))
            logger.info(pcdarts_controller.genotype())
            logger.info("########################################################")
            pcdarts_controller.print_alphas(logger)
        if torch.cuda.is_available():
            torch.cuda.synchronize()
            torch.cuda.empty_cache()  # https://forums.fast.ai/t/clearing-gpu-memory-pytorch/14637
        gc.collect()
    train_timer.toc()
    logger.info("Overall training time is: {} seconds".format(str(train_timer.total_time)))
class TrainMeter(object):
    """Measures training stats."""

    def __init__(self, epoch_iters):
        self.epoch_iters = epoch_iters
        self.max_iter = cfg.OPTIM.MAX_EPOCH * epoch_iters
        self.iter_timer = Timer()
        self.loss = ScalarMeter(cfg.LOG_PERIOD)
        self.loss_total = 0.0
        self.lr = None
        # Current minibatch errors (smoothed over a window)
        self.mb_top1_err = ScalarMeter(cfg.LOG_PERIOD)
        self.mb_top5_err = ScalarMeter(cfg.LOG_PERIOD)
        # Number of misclassified examples
        self.num_top1_mis = 0
        self.num_top5_mis = 0
        self.num_samples = 0

    def reset(self, timer=False):
        if timer:
            self.iter_timer.reset()
        self.loss.reset()
        self.loss_total = 0.0
        self.lr = None
        self.mb_top1_err.reset()
        self.mb_top5_err.reset()
        self.num_top1_mis = 0
        self.num_top5_mis = 0
        self.num_samples = 0

    def iter_tic(self):
        self.iter_timer.tic()

    def iter_toc(self):
        self.iter_timer.toc()

    def update_stats(self, top1_err, top5_err, loss, lr, mb_size):
        # Current minibatch stats
        self.mb_top1_err.add_value(top1_err)
        self.mb_top5_err.add_value(top5_err)
        self.loss.add_value(loss)
        self.lr = lr
        # Aggregate stats
        self.num_top1_mis += top1_err * mb_size
        self.num_top5_mis += top5_err * mb_size
        self.loss_total += loss * mb_size
        self.num_samples += mb_size

    def get_iter_stats(self, cur_epoch, cur_iter):
        cur_iter_total = cur_epoch * self.epoch_iters + cur_iter + 1
        eta_sec = self.iter_timer.average_time * (self.max_iter - cur_iter_total)
        mem_usage = gpu_mem_usage()
        stats = {
            "epoch": "{}/{}".format(cur_epoch + 1, cfg.OPTIM.MAX_EPOCH),
            "iter": "{}/{}".format(cur_iter + 1, self.epoch_iters),
            "time_avg": self.iter_timer.average_time,
            "time_diff": self.iter_timer.diff,
            "eta": time_string(eta_sec),
            "top1_err": self.mb_top1_err.get_win_median(),
            "top5_err": self.mb_top5_err.get_win_median(),
            "loss": self.loss.get_win_median(),
            "lr": self.lr,
            "mem": int(np.ceil(mem_usage)),
        }
        return stats

    def log_iter_stats(self, cur_epoch, cur_iter):
        if (cur_iter + 1) % cfg.LOG_PERIOD != 0:
            return
        stats = self.get_iter_stats(cur_epoch, cur_iter)
        logger.info(logging.dump_log_data(stats, "train_iter"))

    def get_epoch_stats(self, cur_epoch):
        cur_iter_total = (cur_epoch + 1) * self.epoch_iters
        eta_sec = self.iter_timer.average_time * (self.max_iter - cur_iter_total)
        mem_usage = gpu_mem_usage()
        top1_err = self.num_top1_mis / self.num_samples
        top5_err = self.num_top5_mis / self.num_samples
        avg_loss = self.loss_total / self.num_samples
        stats = {
            "epoch": "{}/{}".format(cur_epoch + 1, cfg.OPTIM.MAX_EPOCH),
            "time_avg": self.iter_timer.average_time,
            "eta": time_string(eta_sec),
            "top1_err": top1_err,
            "top5_err": top5_err,
            "loss": avg_loss,
            "lr": self.lr,
            "mem": int(np.ceil(mem_usage)),
        }
        return stats

    def log_epoch_stats(self, cur_epoch):
        stats = self.get_epoch_stats(cur_epoch)
        logger.info(logging.dump_log_data(stats, "train_epoch"))
class TestMeter(object):
    """Measures testing stats."""

    def __init__(self, max_iter):
        self.max_iter = max_iter
        self.iter_timer = Timer()
        # Current minibatch errors (smoothed over a window)
        self.mb_top1_err = ScalarMeter(cfg.LOG_PERIOD)
        self.mb_top5_err = ScalarMeter(cfg.LOG_PERIOD)
        # Min errors (over the full test set)
        self.min_top1_err = 100.0
        self.min_top5_err = 100.0
        # Number of misclassified examples
        self.num_top1_mis = 0
        self.num_top5_mis = 0
        self.num_samples = 0

    def reset(self, min_errs=False):
        if min_errs:
            self.min_top1_err = 100.0
            self.min_top5_err = 100.0
        self.iter_timer.reset()
        self.mb_top1_err.reset()
        self.mb_top5_err.reset()
        self.num_top1_mis = 0
        self.num_top5_mis = 0
        self.num_samples = 0

    def iter_tic(self):
        self.iter_timer.tic()

    def iter_toc(self):
        self.iter_timer.toc()

    def update_stats(self, top1_err, top5_err, mb_size):
        self.mb_top1_err.add_value(top1_err)
        self.mb_top5_err.add_value(top5_err)
        self.num_top1_mis += top1_err * mb_size
        self.num_top5_mis += top5_err * mb_size
        self.num_samples += mb_size

    def get_iter_stats(self, cur_epoch, cur_iter):
        mem_usage = gpu_mem_usage()
        iter_stats = {
            "epoch": "{}/{}".format(cur_epoch + 1, cfg.OPTIM.MAX_EPOCH),
            "iter": "{}/{}".format(cur_iter + 1, self.max_iter),
            "time_avg": self.iter_timer.average_time,
            "time_diff": self.iter_timer.diff,
            "top1_err": self.mb_top1_err.get_win_median(),
            "top5_err": self.mb_top5_err.get_win_median(),
            "mem": int(np.ceil(mem_usage)),
        }
        return iter_stats

    def log_iter_stats(self, cur_epoch, cur_iter):
        if (cur_iter + 1) % cfg.LOG_PERIOD != 0:
            return
        stats = self.get_iter_stats(cur_epoch, cur_iter)
        logger.info(logging.dump_log_data(stats, "test_iter"))

    def get_epoch_stats(self, cur_epoch):
        top1_err = self.num_top1_mis / self.num_samples
        top5_err = self.num_top5_mis / self.num_samples
        self.min_top1_err = min(self.min_top1_err, top1_err)
        self.min_top5_err = min(self.min_top5_err, top5_err)
        mem_usage = gpu_mem_usage()
        stats = {
            "epoch": "{}/{}".format(cur_epoch + 1, cfg.OPTIM.MAX_EPOCH),
            "time_avg": self.iter_timer.average_time,
            "top1_err": top1_err,
            "top5_err": top5_err,
            "min_top1_err": self.min_top1_err,
            "min_top5_err": self.min_top5_err,
            "mem": int(np.ceil(mem_usage)),
        }
        return stats

    def log_epoch_stats(self, cur_epoch):
        stats = self.get_epoch_stats(cur_epoch)
        logger.info(logging.dump_log_data(stats, "test_epoch"))
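# Hedged sketch (assumption, not the repo's ScalarMeter): a minimal windowed
# scalar meter illustrating the add_value / get_win_median behaviour that
# TrainMeter and TestMeter rely on for smoothed minibatch statistics.
from collections import deque
import numpy as np

class _ScalarMeterSketch:
    def __init__(self, window_size):
        self.deque = deque(maxlen=window_size)  # keep only the last window_size values

    def reset(self):
        self.deque.clear()

    def add_value(self, value):
        self.deque.append(value)

    def get_win_median(self):
        return float(np.median(self.deque)) if self.deque else 0.0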
def train_model():
    """SNG search model training."""
    setup_env()
    # Load the search space
    search_space = build_space()
    search_space.cuda()
    loss_fun = build_loss_fun().cuda()
    # Weights optimizer
    w_optim = torch.optim.SGD(search_space.parameters(),
                              cfg.OPTIM.BASE_LR,
                              momentum=cfg.OPTIM.MOMENTUM,
                              weight_decay=cfg.OPTIM.WEIGHT_DECAY)
    # Build the distribution optimizer
    if cfg.SPACE.NAME in ['darts', 'nasbench301']:
        distribution_optimizer = sng_builder([search_space.num_ops] * search_space.all_edges)
    elif cfg.SPACE.NAME in ['proxyless', 'google', 'ofa']:
        distribution_optimizer = sng_builder([search_space.num_ops] * search_space.all_edges)
    elif cfg.SPACE.NAME in ["nasbench1shot1_1", "nasbench1shot1_2", "nasbench1shot1_3"]:
        category = []
        cs = search_space.search_space.get_configuration_space()
        for h in cs.get_hyperparameters():
            if type(h) == ConfigSpace.hyperparameters.CategoricalHyperparameter:
                category.append(len(h.choices))
        distribution_optimizer = sng_builder(category)
    else:
        raise NotImplementedError
    # Load dataset
    [train_, val_] = construct_loader(cfg.SEARCH.DATASET, cfg.SEARCH.SPLIT, cfg.SEARCH.BATCH_SIZE)
    lr_scheduler = lr_scheduler_builder(w_optim)
    all_timer = Timer()
    _over_all_epoch = 0
    # ===== Warm up training =====
    logger.info("start warm up training")
    warm_train_meter = meters.TrainMeter(len(train_))
    warm_val_meter = meters.TestMeter(len(val_))
    all_timer.tic()
    for cur_epoch in range(cfg.OPTIM.WARMUP_EPOCHS):
        # Save a checkpoint
        if (_over_all_epoch + 1) % cfg.SEARCH.CHECKPOINT_PERIOD == 0:
            checkpoint_file = checkpoint.save_checkpoint(search_space, w_optim, _over_all_epoch)
            logger.info("Wrote checkpoint to: {}".format(checkpoint_file))
        lr = lr_scheduler.get_last_lr()[0]
        if cfg.SNG.WARMUP_RANDOM_SAMPLE:
            sample = random_sampling(search_space, distribution_optimizer, epoch=cur_epoch)
            logger.info("Sampling: {}".format(one_hot_to_index(sample)))
            train_epoch(train_, val_, search_space, w_optim, lr,
                        _over_all_epoch, sample, loss_fun, warm_train_meter)
            top1 = test_epoch_with_sample(val_, search_space, warm_val_meter,
                                          _over_all_epoch, sample, writer)
            _over_all_epoch += 1
        else:
            num_ops, total_edges = search_space.num_ops, search_space.all_edges
            array_sample = [random.sample(list(range(num_ops)), num_ops)
                            for i in range(total_edges)]
            array_sample = np.array(array_sample)
            for i in range(num_ops):
                sample = np.transpose(array_sample[:, i])
                sample = index_to_one_hot(sample, distribution_optimizer.p_model.Cmax)
                train_epoch(train_, val_, search_space, w_optim, lr,
                            _over_all_epoch, sample, loss_fun, warm_train_meter)
                top1 = test_epoch_with_sample(val_, search_space, warm_val_meter,
                                              _over_all_epoch, sample, writer)
                _over_all_epoch += 1
    all_timer.toc()
    logger.info("end warm up training")
    # ===== Training procedure =====
    logger.info("start one-shot training")
    train_meter = meters.TrainMeter(len(train_))
    val_meter = meters.TestMeter(len(val_))
    all_timer.tic()
    for cur_epoch in range(cfg.OPTIM.MAX_EPOCH):
        # Save a checkpoint
        if (_over_all_epoch + 1) % cfg.SEARCH.CHECKPOINT_PERIOD == 0:
            checkpoint_file = checkpoint.save_checkpoint(search_space, w_optim, _over_all_epoch)
            logger.info("Wrote checkpoint to: {}".format(checkpoint_file))
        if hasattr(distribution_optimizer, 'training_finish'):
            if distribution_optimizer.training_finish:
                break
        lr = w_optim.param_groups[0]['lr']
        sample = random_sampling(search_space, distribution_optimizer,
                                 epoch=cur_epoch, _random=cfg.SNG.RANDOM_SAMPLE)
        logger.info("Sampling: {}".format(one_hot_to_index(sample)))
        train_epoch(train_, val_, search_space, w_optim, lr,
                    _over_all_epoch, sample, loss_fun, train_meter)
        top1 = test_epoch_with_sample(val_, search_space, val_meter,
                                      _over_all_epoch, sample, writer)
        _over_all_epoch += 1
        lr_scheduler.step()
        distribution_optimizer.record_information(sample, top1)
        distribution_optimizer.update()
        # Evaluate the model
        next_epoch = cur_epoch + 1
        if next_epoch % cfg.SEARCH.EVAL_PERIOD == 0 or next_epoch == cfg.OPTIM.MAX_EPOCH:
            logger.info("Start testing")
            logger.info("###############Optimal genotype at epoch: {}############".format(cur_epoch))
            logger.info(search_space.genotype(distribution_optimizer.p_model.theta))
            logger.info("########################################################")
            logger.info("####### ALPHA #######")
            for alpha in distribution_optimizer.p_model.theta:
                logger.info(alpha)
            logger.info("#####################")
        if torch.cuda.is_available():
            torch.cuda.synchronize()
            torch.cuda.empty_cache()  # https://forums.fast.ai/t/clearing-gpu-memory-pytorch/14637
        gc.collect()
    all_timer.toc()
    # ===== Final epoch =====
    lr = w_optim.param_groups[0]['lr']
    all_timer.tic()
    for cur_epoch in range(cfg.OPTIM.FINAL_EPOCH):
        # Save a checkpoint
        if (_over_all_epoch + 1) % cfg.SEARCH.CHECKPOINT_PERIOD == 0:
            checkpoint_file = checkpoint.save_checkpoint(search_space, w_optim, _over_all_epoch)
            logger.info("Wrote checkpoint to: {}".format(checkpoint_file))
        if cfg.SPACE.NAME in ['darts', 'nasbench301']:
            genotype = search_space.genotype(distribution_optimizer.p_model.theta)
            sample = search_space.genotype_to_onehot_sample(genotype)
        else:
            sample = distribution_optimizer.sampling_best()
        _over_all_epoch += 1
        train_epoch(train_, val_, search_space, w_optim, lr,
                    _over_all_epoch, sample, loss_fun, train_meter)
        test_epoch_with_sample(val_, search_space, val_meter,
                               _over_all_epoch, sample, writer)
    logger.info("Overall training time : {} hours".format(str((all_timer.total_time) / 3600.)))
    # Evaluate through nasbench
    if cfg.SPACE.NAME in ["nasbench1shot1_1", "nasbench1shot1_2", "nasbench1shot1_3",
                          "nasbench201", "nasbench301"]:
        logger.info("starting test using nasbench:{}".format(cfg.SPACE.NAME))
        theta = distribution_optimizer.p_model.theta
        EvaluateNasbench(theta, search_space, logger, cfg.SPACE.NAME)
def run(M=10, N=10, func='rastrigin', optimizer_name='SNG', running_times=500,
        running_epochs=200, step=4, gamma=0.9, save_dir=None, noise=0.0,
        sample_with_prob=True, utility_function='log',
        utility_function_hyper=0.4):
    """Run the distribution optimizer on the synthetic categorical test function."""
    category = [M] * N
    epoc_fun = 'linear'
    test_fun = EpochSumCategoryTestFunction(category, epoch_func=epoc_fun,
                                            func=func, noise_std=noise)
    distribution_optimizer = get_optimizer(
        optimizer_name, category, step=step, gamma=gamma,
        sample_with_prob=sample_with_prob, utility_function=utility_function,
        utility_function_hyper=utility_function_hyper)
    # Path to save the recorded results
    file_name = '{}_{}_{}_{}_{}_{}_{}_{}_{}_{}_{}_{}.npz'.format(
        optimizer_name, str(N), str(M), str(running_epochs), epoc_fun, func,
        str(step), str(gamma), str(noise), str(sample_with_prob),
        utility_function, str(utility_function_hyper))
    file_name = os.path.join(save_dir, file_name)
    record = {
        'objective': np.zeros([running_times, running_epochs]) - 1,
        'l2_distance': np.zeros([running_times, running_epochs]) - 1
    }
    last_l2_distance = np.zeros([running_times])
    running_time_interval = np.zeros([running_times, running_epochs])
    _distance = 100
    run_timer = Timer()
    for i in tqdm.tqdm(range(running_times)):
        for j in range(running_epochs):
            run_timer.tic()
            if hasattr(distribution_optimizer, 'training_finish') or j == (running_epochs - 1):
                last_l2_distance[i] = _distance
            if hasattr(distribution_optimizer, 'training_finish'):
                if distribution_optimizer.training_finish:
                    break
            sample = distribution_optimizer.sampling()
            objective = test_fun.objective_function(sample)
            distribution_optimizer.record_information(sample, objective)
            distribution_optimizer.update()
            current_best = np.argmax(distribution_optimizer.p_model.theta, axis=1)
            _distance = test_fun.l2_distance(current_best)
            record['l2_distance'][i, j] = objective
            record['objective'][i, j] = _distance
            run_timer.toc()
            running_time_interval[i, j] = run_timer.diff
        test_fun.re_new()
        # Re-initialize the distribution optimizer for the next independent run
        del distribution_optimizer
        # print(_distance)
        distribution_optimizer = get_optimizer(
            optimizer_name, category, step=step, gamma=gamma,
            sample_with_prob=sample_with_prob, utility_function=utility_function,
            utility_function_hyper=utility_function_hyper)
    np.savez(file_name, record['l2_distance'], running_time_interval)
    return distribution_optimizer
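# Hedged usage sketch (assumption): a small invocation of the synthetic
# benchmark above. Only parameters from the signature are used; the save
# directory is a placeholder.
#   run(M=10, N=10, func='rastrigin', optimizer_name='SNG',
#       running_times=5, running_epochs=100, save_dir='./experiment')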