def sampling_best(self):
    if not self.dynamic_sampling:
        # without dynamic sampling, fall back to a regular sample,
        # returned as op indices
        return one_hot_to_index(np.array(self.sampling()))
    # otherwise take the most probable op per edge under theta
    sample = []
    for i in range(self.p_model.d):
        sample.append(np.argmax(self.p_model.theta[i]))
    sample = np.array(sample)
    return index_to_one_hot(sample, self.p_model.Cmax)
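# A minimal sketch (illustration only, not part of the search code) of what
# the dynamic branch of `sampling_best` computes: per edge, pick the most
# probable op under theta and one-hot encode it. The eye-indexing below only
# assumes the index/one-hot semantics used in this file, not the project's
# actual helper implementations.
def _demo_sampling_best():
    import numpy as np
    theta = np.array([[0.1, 0.7, 0.2],
                      [0.5, 0.3, 0.2]])     # (d=2 edges, Cmax=3 ops)
    best = np.argmax(theta, axis=1)         # most probable op per edge -> [1, 0]
    one_hot = np.eye(theta.shape[1])[best]  # index -> one-hot rows
    return one_hot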
def sampling(self):
    if not self.dynamic_sampling:
        rand = np.random.rand(self.p_model.d, 1)  # uniform in [0, 1)
        cum_theta = self.p_model.theta.cumsum(axis=1)  # (d, Cmax)
        # c[i, j] is True iff cum_theta[i, j] - theta[i, j] <= rand[i] < cum_theta[i, j],
        # so exactly one op per edge is selected, with probability theta[i, j]
        c = (cum_theta - self.p_model.theta <= rand) & (rand < cum_theta)
        return c
    if self.sampling_number_per_edge == 1:
        return index_to_one_hot(self.sampling_index(), self.p_model.Cmax)
    sample = []
    sample_one_hot_like = np.zeros([self.p_model.d, self.p_model.Cmax])
    for i in range(self.p_model.d):
        # restrict to the ops still available on this edge, optionally
        # weighted by their (renormalized) probabilities under theta
        if self.sample_with_prob:
            prob = copy.deepcopy(self.p_model.theta[i, self.sample_index[i]])
            prob = prob / prob.sum()
            sample.append(
                np.random.choice(self.sample_index[i],
                                 size=self.sampling_number_per_edge,
                                 p=prob,
                                 replace=False))
        else:
            sample.append(
                np.random.choice(self.sample_index[i],
                                 size=self.sampling_number_per_edge,
                                 replace=False))
        # remove the drawn ops so later calls sample without replacement
        if len(self.sample_index[i]) > 0:
            for j in sample[i]:
                self.sample_index[i].remove(int(j))
        for j in range(self.sampling_number_per_edge):
            sample_one_hot_like[i, int(sample[i][j])] = 1
    return sample_one_hot_like
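# The non-dynamic branch above uses an inverse-CDF trick: one uniform draw
# per edge is compared against the cumulative thetas, so exactly one entry
# per row comes out True, with probability theta[i, j]. A self-contained
# sketch of that trick (illustrative only):
def _demo_inverse_cdf_mask():
    import numpy as np
    theta = np.array([[0.2, 0.5, 0.3]])        # one edge, three ops
    rand = np.random.rand(theta.shape[0], 1)   # uniform in [0, 1)
    cum = theta.cumsum(axis=1)                 # [0.2, 0.7, 1.0]
    mask = (cum - theta <= rand) & (rand < cum)
    assert mask.sum(axis=1).tolist() == [1]    # exactly one op selected
    return mask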
def random_sampling(search_space, distribution_optimizer, epoch=-1000, _random=False):
    """Sample an architecture, either from the distribution or at random."""
    if not _random:
        return distribution_optimizer.sampling()

    num_ops, total_edges = search_space.num_ops, search_space.all_edges

    # edge importance: pre-select which incoming edges of each node will be
    # forced to the 'none' op (index 7 in the darts op set)
    non_edge_idx = []
    if cfg.SNG.EDGE_SAMPLING and epoch > cfg.SNG.EDGE_SAMPLING_EPOCH:
        assert cfg.SPACE.NAME == 'darts', "only support darts for now!"
        for node in search_space.norm_node_index:
            edge_non_prob = distribution_optimizer.p_model.theta[np.array(node), 7]
            edge_non_prob = edge_non_prob / np.sum(edge_non_prob)
            if len(node) > 2:
                # keep two edges per node and drop the rest
                non_edge_sampling_num = len(node) - 2
                non_edge_idx += list(
                    np.random.choice(node,
                                     non_edge_sampling_num,
                                     p=edge_non_prob,
                                     replace=False))

    def _draw_sample():
        if cfg.SNG.PROB_SAMPLING:
            return np.array([
                np.random.choice(
                    num_ops, 1,
                    p=distribution_optimizer.p_model.theta[i, :])[0]
                for i in range(total_edges)
            ])
        return np.array(
            [np.random.choice(num_ops, 1)[0] for i in range(total_edges)])

    if random.random() < cfg.SNG.BIGMODEL_SAMPLE_PROB:
        # sample a network with high complexity: reject candidates until the
        # number of non-parametric ops is at most cfg.SNG.BIGMODEL_NON_PARA
        _num = 100
        while _num > cfg.SNG.BIGMODEL_NON_PARA:
            _error = False
            sample = _draw_sample()
            _num = 0
            for i in sample[0:search_space.num_edges]:
                if i not in non_edge_idx and i in search_space.non_op_idx:
                    if i == 7:
                        _error = True
                    _num = _num + 1
            if _error:
                _num = 100
    else:
        sample = _draw_sample()

    if cfg.SNG.EDGE_SAMPLING and epoch > cfg.SNG.EDGE_SAMPLING_EPOCH:
        # force the pre-selected non-edges to the 'none' op
        for i in non_edge_idx:
            sample[i] = 7
    sample = index_to_one_hot(sample, distribution_optimizer.p_model.Cmax)
    # the pruning method has to draw a sample anyway to update its state
    distribution_optimizer.sampling()
    return sample
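# The big-model branch of `random_sampling` above is a rejection loop: keep
# redrawing until the count of non-parametric ops is below the configured
# threshold, which biases draws toward high-complexity networks. A
# stripped-down sketch of that idea (the op set, sizes, and threshold are
# made up for illustration; the real loop also handles edge sampling and
# probability-weighted draws):
def _demo_reject_small_models(threshold=2, num_edges=8, num_ops=8,
                              non_op_idx=(5, 6, 7)):
    import numpy as np
    count = threshold + 1
    while count > threshold:
        sample = np.random.randint(num_ops, size=num_edges)
        count = sum(int(op) in non_op_idx for op in sample)
    return sample  # at most `threshold` non-parametric ops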
def main():
    setup_env()
    # loading search space
    search_space = build_space()
    search_space.cuda()
    # init controller and architect
    loss_fun = nn.CrossEntropyLoss().cuda()
    # weights optimizer
    w_optim = torch.optim.SGD(search_space.parameters(),
                              cfg.OPTIM.BASE_LR,
                              momentum=cfg.OPTIM.MOMENTUM,
                              weight_decay=cfg.OPTIM.WEIGHT_DECAY)
    # load dataset
    [train_, val_] = _construct_loader(cfg.SEARCH.DATASET, cfg.SEARCH.SPLIT,
                                       cfg.SEARCH.BATCH_SIZE)
    # build distribution_optimizer
    if cfg.SPACE.NAME in [
            "nasbench1shot1_1", "nasbench1shot1_2", "nasbench1shot1_3"
    ]:
        category = []
        cs = search_space.search_space.get_configuration_space()
        for h in cs.get_hyperparameters():
            if isinstance(h, ConfigSpace.hyperparameters.CategoricalHyperparameter):
                category.append(len(h.choices))
        distribution_optimizer = sng_builder(category)
    else:
        distribution_optimizer = sng_builder([search_space.num_ops] *
                                             search_space.all_edges)
    lr_scheduler = lr_scheduler_builder(w_optim)

    # training loop
    logger.info("start warm up training")
    warm_train_meter = meters.TrainMeter(len(train_))
    warm_val_meter = meters.TestMeter(len(val_))
    start_time = time.time()
    _over_all_epoch = 0
    for epoch in range(cfg.OPTIM.WARMUP_EPOCHS):
        lr = lr_scheduler.get_last_lr()[0]
        # warm up training
        if cfg.SNG.WARMUP_RANDOM_SAMPLE:
            sample = random_sampling(search_space,
                                     distribution_optimizer,
                                     epoch=epoch)
            logger.info("The sample is: {}".format(one_hot_to_index(sample)))
            train(train_, val_, search_space, w_optim, lr, _over_all_epoch,
                  sample, loss_fun, warm_train_meter)
            top1 = test_epoch(val_, search_space, warm_val_meter,
                              _over_all_epoch, sample, writer)
            _over_all_epoch += 1
        else:
            # new version of the warm-up epoch: train every op on every edge
            # exactly once, using one random permutation of the ops per edge
            num_ops, total_edges = search_space.num_ops, search_space.all_edges
            array_sample = [
                random.sample(list(range(num_ops)), num_ops)
                for i in range(total_edges)
            ]
            array_sample = np.array(array_sample)
            for i in range(num_ops):
                sample = np.transpose(array_sample[:, i])
                sample = index_to_one_hot(sample,
                                          distribution_optimizer.p_model.Cmax)
                train(train_, val_, search_space, w_optim, lr,
                      _over_all_epoch, sample, loss_fun, warm_train_meter)
                top1 = test_epoch(val_, search_space, warm_val_meter,
                                  _over_all_epoch, sample, writer)
                _over_all_epoch += 1
    logger.info("end warm up training")

    logger.info("start one-shot searching")
    train_meter = meters.TrainMeter(len(train_))
    val_meter = meters.TestMeter(len(val_))
    for epoch in range(cfg.OPTIM.MAX_EPOCH):
        if hasattr(distribution_optimizer, 'training_finish'):
            if distribution_optimizer.training_finish:
                break
        lr = w_optim.param_groups[0]['lr']
        # sample from the distribution optimizer, or at random
        sample = random_sampling(search_space,
                                 distribution_optimizer,
                                 epoch=epoch,
                                 _random=cfg.SNG.RANDOM_SAMPLE)
        logger.info("The sample is: {}".format(one_hot_to_index(sample)))
        # training
        train(train_, val_, search_space, w_optim, lr, _over_all_epoch,
              sample, loss_fun, train_meter)
        # validation
        top1 = test_epoch(val_, search_space, val_meter, _over_all_epoch,
                          sample, writer)
        _over_all_epoch += 1
        lr_scheduler.step()
        distribution_optimizer.record_information(sample, top1)
        distribution_optimizer.update()
        # evaluate the model
        next_epoch = epoch + 1
        if next_epoch % cfg.SEARCH.EVAL_PERIOD == 0 or next_epoch == cfg.OPTIM.MAX_EPOCH:
            logger.info("Start testing")
            logger.info(
                "###############Optimal genotype at epoch: {}############".format(epoch))
            logger.info(
                search_space.genotype(distribution_optimizer.p_model.theta))
            logger.info(
                "########################################################")
            logger.info("####### ALPHA #######")
            for alpha in distribution_optimizer.p_model.theta:
                logger.info(alpha)
            logger.info("#####################")
        if torch.cuda.is_available():
            torch.cuda.synchronize()
            # https://forums.fast.ai/t/clearing-gpu-memory-pytorch/14637
            torch.cuda.empty_cache()
        gc.collect()
    end_time = time.time()

    # final epochs: retrain the best architecture found by the search
    lr = w_optim.param_groups[0]['lr']
    for epoch in range(cfg.OPTIM.FINAL_EPOCH):
        if cfg.SPACE.NAME == 'darts':
            genotype = search_space.genotype(distribution_optimizer.p_model.theta)
            sample = search_space.genotype_to_onehot_sample(genotype)
        else:
            sample = distribution_optimizer.sampling_best()
        _over_all_epoch += 1
        train(train_, val_, search_space, w_optim, lr, _over_all_epoch,
              sample, loss_fun, train_meter)
        test_epoch(val_, search_space, val_meter, _over_all_epoch, sample,
                   writer)
    logger.info("Overall training time (hr) is: {}".format(
        str((end_time - start_time) / 3600.)))

    # whether to evaluate through nasbench
    if cfg.SPACE.NAME in [
            "nasbench201", "nasbench1shot1_1", "nasbench1shot1_2",
            "nasbench1shot1_3"
    ]:
        logger.info("starting test using nasbench: {}".format(cfg.SPACE.NAME))
        theta = distribution_optimizer.p_model.theta
        EvaluateNasbench(theta, search_space, logger, cfg.SPACE.NAME)
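# The search loop in `main` follows an ask-evaluate-tell pattern: draw a
# sample from the distribution optimizer, measure its validation accuracy,
# and feed that score back via `record_information` + `update`. A minimal
# skeleton of that pattern; `evaluate` is a hypothetical stand-in for the
# train/test_epoch pair, while the optimizer methods are the ones used above:
def _demo_ask_tell_loop(distribution_optimizer, evaluate, max_epoch=10):
    for _ in range(max_epoch):
        sample = distribution_optimizer.sampling()                # ask
        top1 = evaluate(sample)                                   # evaluate
        distribution_optimizer.record_information(sample, top1)   # tell
        distribution_optimizer.update()                           # distribution step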
def sampling(self):
    # cache the sampled indices so other methods can reuse the last draw
    self.sample = self.sampling_index()
    return index_to_one_hot(self.sample, self.p_model.Cmax)
def sampling(self):
    return index_to_one_hot(self.sampling_index(), self.p_model.Cmax)
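# Both `sampling` variants above reduce to an index -> one-hot conversion.
# A toy round trip, assuming the same helper semantics as in this file
# (illustration only):
def _demo_one_hot_round_trip():
    import numpy as np
    idx = np.array([2, 0, 1])          # one op index per edge
    Cmax = 4
    one_hot = np.eye(Cmax)[idx]        # index_to_one_hot equivalent
    back = np.argmax(one_hot, axis=1)  # one_hot_to_index equivalent
    assert (back == idx).all()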
def train_model():
    """SNG search model training."""
    setup_env()
    # Load search space
    search_space = build_space()
    search_space.cuda()
    loss_fun = build_loss_fun().cuda()
    # Weights optimizer
    w_optim = torch.optim.SGD(search_space.parameters(),
                              cfg.OPTIM.BASE_LR,
                              momentum=cfg.OPTIM.MOMENTUM,
                              weight_decay=cfg.OPTIM.WEIGHT_DECAY)
    # Build distribution_optimizer
    if cfg.SPACE.NAME in ['darts', 'nasbench301', 'proxyless', 'google', 'ofa']:
        distribution_optimizer = sng_builder([search_space.num_ops] *
                                             search_space.all_edges)
    elif cfg.SPACE.NAME in [
            "nasbench1shot1_1", "nasbench1shot1_2", "nasbench1shot1_3"
    ]:
        category = []
        cs = search_space.search_space.get_configuration_space()
        for h in cs.get_hyperparameters():
            if isinstance(h, ConfigSpace.hyperparameters.CategoricalHyperparameter):
                category.append(len(h.choices))
        distribution_optimizer = sng_builder(category)
    else:
        raise NotImplementedError
    # Load dataset
    [train_, val_] = construct_loader(cfg.SEARCH.DATASET, cfg.SEARCH.SPLIT,
                                      cfg.SEARCH.BATCH_SIZE)
    lr_scheduler = lr_scheduler_builder(w_optim)
    all_timer = Timer()
    _over_all_epoch = 0

    # ===== Warm up training =====
    logger.info("start warm up training")
    warm_train_meter = meters.TrainMeter(len(train_))
    warm_val_meter = meters.TestMeter(len(val_))
    all_timer.tic()
    for cur_epoch in range(cfg.OPTIM.WARMUP_EPOCHS):
        # Save a checkpoint
        if (_over_all_epoch + 1) % cfg.SEARCH.CHECKPOINT_PERIOD == 0:
            checkpoint_file = checkpoint.save_checkpoint(
                search_space, w_optim, _over_all_epoch)
            logger.info("Wrote checkpoint to: {}".format(checkpoint_file))
        lr = lr_scheduler.get_last_lr()[0]
        if cfg.SNG.WARMUP_RANDOM_SAMPLE:
            sample = random_sampling(search_space,
                                     distribution_optimizer,
                                     epoch=cur_epoch)
            logger.info("Sampling: {}".format(one_hot_to_index(sample)))
            train_epoch(train_, val_, search_space, w_optim, lr,
                        _over_all_epoch, sample, loss_fun, warm_train_meter)
            top1 = test_epoch_with_sample(val_, search_space, warm_val_meter,
                                          _over_all_epoch, sample, writer)
            _over_all_epoch += 1
        else:
            # Train every op on every edge exactly once per warm-up epoch,
            # using one random permutation of the ops per edge
            num_ops, total_edges = search_space.num_ops, search_space.all_edges
            array_sample = [
                random.sample(list(range(num_ops)), num_ops)
                for i in range(total_edges)
            ]
            array_sample = np.array(array_sample)
            for i in range(num_ops):
                sample = np.transpose(array_sample[:, i])
                sample = index_to_one_hot(sample,
                                          distribution_optimizer.p_model.Cmax)
                train_epoch(train_, val_, search_space, w_optim, lr,
                            _over_all_epoch, sample, loss_fun,
                            warm_train_meter)
                top1 = test_epoch_with_sample(val_, search_space,
                                              warm_val_meter, _over_all_epoch,
                                              sample, writer)
                _over_all_epoch += 1
    all_timer.toc()
    logger.info("end warm up training")

    # ===== Training procedure =====
    logger.info("start one-shot training")
    train_meter = meters.TrainMeter(len(train_))
    val_meter = meters.TestMeter(len(val_))
    all_timer.tic()
    for cur_epoch in range(cfg.OPTIM.MAX_EPOCH):
        # Save a checkpoint
        if (_over_all_epoch + 1) % cfg.SEARCH.CHECKPOINT_PERIOD == 0:
            checkpoint_file = checkpoint.save_checkpoint(
                search_space, w_optim, _over_all_epoch)
            logger.info("Wrote checkpoint to: {}".format(checkpoint_file))
        if hasattr(distribution_optimizer, 'training_finish'):
            if distribution_optimizer.training_finish:
                break
        lr = w_optim.param_groups[0]['lr']
        sample = random_sampling(search_space,
                                 distribution_optimizer,
                                 epoch=cur_epoch,
                                 _random=cfg.SNG.RANDOM_SAMPLE)
        logger.info("Sampling: {}".format(one_hot_to_index(sample)))
        train_epoch(train_, val_, search_space, w_optim, lr, _over_all_epoch,
                    sample, loss_fun, train_meter)
        top1 = test_epoch_with_sample(val_, search_space, val_meter,
                                      _over_all_epoch, sample, writer)
        _over_all_epoch += 1
        lr_scheduler.step()
        distribution_optimizer.record_information(sample, top1)
        distribution_optimizer.update()
        # Evaluate the model
        next_epoch = cur_epoch + 1
        if next_epoch % cfg.SEARCH.EVAL_PERIOD == 0 or next_epoch == cfg.OPTIM.MAX_EPOCH:
            logger.info("Start testing")
            logger.info(
                "###############Optimal genotype at epoch: {}############".format(cur_epoch))
            logger.info(
                search_space.genotype(distribution_optimizer.p_model.theta))
            logger.info(
                "########################################################")
            logger.info("####### ALPHA #######")
            for alpha in distribution_optimizer.p_model.theta:
                logger.info(alpha)
            logger.info("#####################")
        if torch.cuda.is_available():
            torch.cuda.synchronize()
            # https://forums.fast.ai/t/clearing-gpu-memory-pytorch/14637
            torch.cuda.empty_cache()
        gc.collect()
    all_timer.toc()

    # ===== Final epoch =====
    lr = w_optim.param_groups[0]['lr']
    all_timer.tic()
    for cur_epoch in range(cfg.OPTIM.FINAL_EPOCH):
        # Save a checkpoint
        if (_over_all_epoch + 1) % cfg.SEARCH.CHECKPOINT_PERIOD == 0:
            checkpoint_file = checkpoint.save_checkpoint(
                search_space, w_optim, _over_all_epoch)
            logger.info("Wrote checkpoint to: {}".format(checkpoint_file))
        if cfg.SPACE.NAME in ['darts', 'nasbench301']:
            genotype = search_space.genotype(
                distribution_optimizer.p_model.theta)
            sample = search_space.genotype_to_onehot_sample(genotype)
        else:
            sample = distribution_optimizer.sampling_best()
        _over_all_epoch += 1
        train_epoch(train_, val_, search_space, w_optim, lr, _over_all_epoch,
                    sample, loss_fun, train_meter)
        test_epoch_with_sample(val_, search_space, val_meter, _over_all_epoch,
                               sample, writer)
    all_timer.toc()  # stop the timer so the final epochs count toward total_time
    logger.info("Overall training time: {} hours".format(
        str(all_timer.total_time / 3600.)))

    # Evaluate through nasbench
    if cfg.SPACE.NAME in [
            "nasbench1shot1_1", "nasbench1shot1_2", "nasbench1shot1_3",
            "nasbench201", "nasbench301"
    ]:
        logger.info("starting test using nasbench: {}".format(cfg.SPACE.NAME))
        theta = distribution_optimizer.p_model.theta
        EvaluateNasbench(theta, search_space, logger, cfg.SPACE.NAME)
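# The non-random warm-up branch in `train_model` (and `main`) builds one
# random permutation of the op indices per edge and then trains column by
# column, so every (edge, op) pair is trained exactly once per warm-up epoch.
# A short sketch of why the columns cover all ops (sizes are illustrative):
def _demo_warmup_permutations(num_ops=4, total_edges=3):
    import random
    import numpy as np
    array_sample = np.array(
        [random.sample(range(num_ops), num_ops) for _ in range(total_edges)])
    # each row is a permutation of range(num_ops), so across the num_ops
    # columns every op appears exactly once on every edge
    for row in array_sample:
        assert sorted(row) == list(range(num_ops))
    return [array_sample[:, i] for i in range(num_ops)]  # one sample per column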