def main():
    setup_env()
    # loading search space
    search_space = build_space()
    search_space.cuda()
    # init controller and architect
    loss_fun = nn.CrossEntropyLoss().cuda()
    # weights optimizer
    w_optim = torch.optim.SGD(search_space.parameters(),
                              cfg.OPTIM.BASE_LR,
                              momentum=cfg.OPTIM.MOMENTUM,
                              weight_decay=cfg.OPTIM.WEIGHT_DECAY)
    # load dataset
    [train_, val_] = _construct_loader(cfg.SEARCH.DATASET, cfg.SEARCH.SPLIT,
                                       cfg.SEARCH.BATCH_SIZE)
    # build distribution_optimizer
    if cfg.SPACE.NAME in [
            "nasbench1shot1_1", "nasbench1shot1_2", "nasbench1shot1_3"
    ]:
        category = []
        cs = search_space.search_space.get_configuration_space()
        for h in cs.get_hyperparameters():
            if type(h) == ConfigSpace.hyperparameters.CategoricalHyperparameter:
                category.append(len(h.choices))
        distribution_optimizer = sng_builder(category)
    else:
        distribution_optimizer = sng_builder([search_space.num_ops] *
                                             search_space.all_edges)
    lr_scheduler = lr_scheduler_builder(w_optim)

    # training loop
    logger.info("start warm up training")
    warm_train_meter = meters.TrainMeter(len(train_))
    warm_val_meter = meters.TestMeter(len(val_))
    start_time = time.time()
    _over_all_epoch = 0
    for epoch in range(cfg.OPTIM.WARMUP_EPOCHS):
        # lr_scheduler.step()
        lr = lr_scheduler.get_last_lr()[0]
        # warm up training
        if cfg.SNG.WARMUP_RANDOM_SAMPLE:
            sample = random_sampling(search_space,
                                     distribution_optimizer,
                                     epoch=epoch)
            logger.info("The sample is: {}".format(one_hot_to_index(sample)))
            train(train_, val_, search_space, w_optim, lr, _over_all_epoch,
                  sample, loss_fun, warm_train_meter)
            top1 = test_epoch(val_, search_space, warm_val_meter,
                              _over_all_epoch, sample, writer)
            _over_all_epoch += 1
        else:
            num_ops, total_edges = search_space.num_ops, search_space.all_edges
            array_sample = [
                random.sample(list(range(num_ops)), num_ops)
                for i in range(total_edges)
            ]
            array_sample = np.array(array_sample)
            for i in range(num_ops):
                sample = np.transpose(array_sample[:, i])
                sample = index_to_one_hot(sample,
                                          distribution_optimizer.p_model.Cmax)
                train(train_, val_, search_space, w_optim, lr, _over_all_epoch,
                      sample, loss_fun, warm_train_meter)
                top1 = test_epoch(val_, search_space, warm_val_meter,
                                  _over_all_epoch, sample, writer)
                _over_all_epoch += 1
    # new version of warmup epoch
    logger.info("end warm up training")
    logger.info("start one-shot searching")
    train_meter = meters.TrainMeter(len(train_))
    val_meter = meters.TestMeter(len(val_))
    for epoch in range(cfg.OPTIM.MAX_EPOCH):
        if hasattr(distribution_optimizer, 'training_finish'):
            if distribution_optimizer.training_finish:
                break
        lr = w_optim.param_groups[0]['lr']
        # sample by the distribution optimizer
        # _ = distribution_optimizer.sampling()
        # random sample
        sample = random_sampling(search_space,
                                 distribution_optimizer,
                                 epoch=epoch,
                                 _random=cfg.SNG.RANDOM_SAMPLE)
        logger.info("The sample is: {}".format(one_hot_to_index(sample)))
        # training
        train(train_, val_, search_space, w_optim, lr, _over_all_epoch, sample,
              loss_fun, train_meter)
        # validation
        top1 = test_epoch(val_, search_space, val_meter, _over_all_epoch,
                          sample, writer)
        _over_all_epoch += 1
        lr_scheduler.step()
        distribution_optimizer.record_information(sample, top1)
        distribution_optimizer.update()
        # Evaluate the model
        next_epoch = epoch + 1
        if next_epoch % cfg.SEARCH.EVAL_PERIOD == 0 or next_epoch == cfg.OPTIM.MAX_EPOCH:
            logger.info("Start testing")
            logger.info("###############Optimal genotype at epoch: {}############".format(epoch))
            logger.info(search_space.genotype(distribution_optimizer.p_model.theta))
            logger.info("########################################################")
            logger.info("####### ALPHA #######")
            for alpha in distribution_optimizer.p_model.theta:
                logger.info(alpha)
            logger.info("#####################")
        if torch.cuda.is_available():
            torch.cuda.synchronize()
            torch.cuda.empty_cache()  # https://forums.fast.ai/t/clearing-gpu-memory-pytorch/14637
        gc.collect()
    end_time = time.time()
    lr = w_optim.param_groups[0]['lr']
    for epoch in range(cfg.OPTIM.FINAL_EPOCH):
        if cfg.SPACE.NAME == 'darts':
            genotype = search_space.genotype(distribution_optimizer.p_model.theta)
            sample = search_space.genotype_to_onehot_sample(genotype)
        else:
            sample = distribution_optimizer.sampling_best()
        _over_all_epoch += 1
        train(train_, val_, search_space, w_optim, lr, _over_all_epoch, sample,
              loss_fun, train_meter)
        test_epoch(val_, search_space, val_meter, _over_all_epoch, sample,
                   writer)
    logger.info("Overall training time (hr) is: {}".format(
        str((end_time - start_time) / 3600.)))

    # whether to evaluate through nasbench
    if cfg.SPACE.NAME in [
            "nasbench201", "nasbench1shot1_1", "nasbench1shot1_2",
            "nasbench1shot1_3"
    ]:
        logger.info("starting test using nasbench: {}".format(cfg.SPACE.NAME))
        theta = distribution_optimizer.p_model.theta
        EvaluateNasbench(theta, search_space, logger, cfg.SPACE.NAME)
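# Editor's note: `one_hot_to_index` and `index_to_one_hot` are used throughout the
# SNG search loops above but are not defined in this section. The sketch below is an
# assumption about their behavior (per-edge one-hot rows <-> flat op-index vectors),
# not the library's actual implementation.
import numpy as np


def one_hot_to_index_sketch(sample):
    """(num_edges, num_ops) one-hot rows -> (num_edges,) op indices."""
    return np.argmax(np.asarray(sample), axis=1)


def index_to_one_hot_sketch(indices, num_ops):
    """(num_edges,) op indices -> (num_edges, num_ops) one-hot rows."""
    indices = np.asarray(indices)
    one_hot = np.zeros((indices.size, num_ops))
    one_hot[np.arange(indices.size), indices] = 1.0
    return one_hot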
def main():
    setup_env()
    # loading search space
    search_space = build_space()
    search_space.cuda()
    # init controller and architect
    loss_fun = nn.CrossEntropyLoss().cuda()
    # weights optimizer
    w_optim = torch.optim.SGD(search_space.parameters(),
                              cfg.OPTIM.BASE_LR,
                              momentum=cfg.OPTIM.MOMENTUM,
                              weight_decay=cfg.OPTIM.WEIGHT_DECAY)
    # load dataset
    [train_, val_] = _construct_loader(cfg.SEARCH.DATASET, cfg.SEARCH.SPLIT,
                                       cfg.SEARCH.BATCH_SIZE)
    distribution_optimizer = sng_builder([search_space.num_ops] *
                                         search_space.all_edges)
    lr_scheduler = lr_scheduler_builder(w_optim)
    num_ops, total_edges = search_space.num_ops, search_space.all_edges

    # training loop
    logger.info("start warm up training")
    warm_train_meter = meters.TrainMeter(len(train_))
    for epoch in range(cfg.OPTIM.WARMUP_EPOCHS):
        # lr_scheduler.step()
        lr = lr_scheduler.get_last_lr()[0]
        # warm up training
        array_sample = [
            random.sample(list(range(num_ops)), num_ops)
            for i in range(total_edges)
        ]
        array_sample = np.array(array_sample)
        for i in range(num_ops):
            sample = np.transpose(array_sample[:, i])
            train(train_, val_, search_space, w_optim, lr, epoch, sample,
                  loss_fun, warm_train_meter)
    logger.info("end warm up training")
    logger.info("start one-shot searching")
    train_meter = meters.TrainMeter(len(train_))
    val_meter = meters.TestMeter(len(val_))
    for epoch in range(cfg.OPTIM.MAX_EPOCH):
        if hasattr(distribution_optimizer, 'training_finish'):
            if distribution_optimizer.training_finish:
                break
        lr = w_optim.param_groups[0]['lr']
        sample = distribution_optimizer.sampling()
        # training
        train(train_, val_, search_space, w_optim, lr, epoch, sample, loss_fun,
              train_meter)
        # validation
        top1 = test_epoch(val_, search_space, val_meter, epoch, sample, writer)
        lr_scheduler.step()
        distribution_optimizer.record_information(sample, top1)
        distribution_optimizer.update()
        # Evaluate the model
        next_epoch = epoch + 1
        if next_epoch % cfg.SEARCH.EVAL_PERIOD == 0 or next_epoch == cfg.OPTIM.MAX_EPOCH:
            logger.info("Start testing")
            logger.info("###############Optimal genotype at epoch: {}############".format(epoch))
            logger.info(search_space.genotype(distribution_optimizer.p_model.theta))
            logger.info("########################################################")
            logger.info("####### ALPHA #######")
            for alpha in distribution_optimizer.p_model.theta:
                logger.info(alpha)
            logger.info("#####################")
        if torch.cuda.is_available():
            torch.cuda.synchronize()
            torch.cuda.empty_cache()  # https://forums.fast.ai/t/clearing-gpu-memory-pytorch/14637
        gc.collect()
def darts_train_model():
    """train DARTS model"""
    setup_env()
    # Loading search space
    search_space = build_space()
    # TODO: fix the complexity function
    # search_space = setup_model()
    # Init controller and architect
    loss_fun = build_loss_fun().cuda()
    darts_controller = DartsCNNController(search_space, loss_fun)
    darts_controller.cuda()
    architect = Architect(darts_controller, cfg.OPTIM.MOMENTUM,
                          cfg.OPTIM.WEIGHT_DECAY)
    # Load dataset
    [train_, val_] = construct_loader(cfg.SEARCH.DATASET, cfg.SEARCH.SPLIT,
                                      cfg.SEARCH.BATCH_SIZE)
    # weights optimizer
    w_optim = torch.optim.SGD(darts_controller.weights(),
                              cfg.OPTIM.BASE_LR,
                              momentum=cfg.OPTIM.MOMENTUM,
                              weight_decay=cfg.OPTIM.WEIGHT_DECAY)
    # alphas optimizer
    a_optim = torch.optim.Adam(darts_controller.alphas(),
                               cfg.DARTS.ALPHA_LR,
                               betas=(0.5, 0.999),
                               weight_decay=cfg.DARTS.ALPHA_WEIGHT_DECAY)
    lr_scheduler = lr_scheduler_builder(w_optim)
    # Init meters
    train_meter = meters.TrainMeter(len(train_))
    val_meter = meters.TestMeter(len(val_))
    # Load checkpoint or initial weights
    start_epoch = 0
    if cfg.SEARCH.AUTO_RESUME and checkpoint.has_checkpoint():
        last_checkpoint = checkpoint.get_last_checkpoint()
        checkpoint_epoch = darts_load_checkpoint(last_checkpoint,
                                                 darts_controller, w_optim,
                                                 a_optim)
        logger.info("Loaded checkpoint from: {}".format(last_checkpoint))
        start_epoch = checkpoint_epoch + 1
    elif cfg.SEARCH.WEIGHTS:
        darts_load_checkpoint(cfg.SEARCH.WEIGHTS, darts_controller)
        logger.info("Loaded initial weights from: {}".format(cfg.SEARCH.WEIGHTS))
    # Compute model and loader timings
    if start_epoch == 0 and cfg.PREC_TIME.NUM_ITER > 0:
        benchmark.compute_time_full(darts_controller, loss_fun, train_, val_)
    # Setup timer
    train_timer = Timer()
    # Perform the training loop
    logger.info("Start epoch: {}".format(start_epoch + 1))
    train_timer.tic()
    for cur_epoch in range(start_epoch, cfg.OPTIM.MAX_EPOCH):
        lr = lr_scheduler.get_last_lr()[0]
        train_epoch(train_, val_, darts_controller, architect, loss_fun,
                    w_optim, a_optim, lr, train_meter, cur_epoch)
        # Save a checkpoint
        if (cur_epoch + 1) % cfg.SEARCH.CHECKPOINT_PERIOD == 0:
            checkpoint_file = darts_save_checkpoint(darts_controller, w_optim,
                                                    a_optim, cur_epoch)
            logger.info("Wrote checkpoint to: {}".format(checkpoint_file))
        lr_scheduler.step()
        # Evaluate the model
        next_epoch = cur_epoch + 1
        if next_epoch % cfg.SEARCH.EVAL_PERIOD == 0 or next_epoch == cfg.OPTIM.MAX_EPOCH:
            logger.info("Start testing")
            test_epoch(val_, darts_controller, val_meter, cur_epoch, writer)
            logger.info("###############Optimal genotype at epoch: {}############".format(cur_epoch))
            logger.info(darts_controller.genotype())
            logger.info("########################################################")
            if cfg.SPACE.NAME == "nasbench301":
                logger.info("Evaluating with nasbench301")
                EvaluateNasbench(darts_controller.alpha, darts_controller.net,
                                 logger, "nasbench301")
            darts_controller.print_alphas(logger)
        if torch.cuda.is_available():
            torch.cuda.synchronize()
            torch.cuda.empty_cache()  # https://forums.fast.ai/t/clearing-gpu-memory-pytorch/14637
        gc.collect()
    train_timer.toc()
    logger.info("Overall training time (hr) is: {}".format(
        str(train_timer.total_time)))
def pcdarts_train_model():
    """train PC-DARTS model"""
    setup_env()
    # Loading search space
    search_space = build_space()
    # TODO: fix the complexity function
    # search_space = setup_model()
    # Init controller and architect
    loss_fun = build_loss_fun().cuda()
    pcdarts_controller = PCDartsCNNController(search_space, loss_fun)
    pcdarts_controller.cuda()
    architect = Architect(pcdarts_controller, cfg.OPTIM.MOMENTUM,
                          cfg.OPTIM.WEIGHT_DECAY)
    # Load dataset
    train_transform, valid_transform = data_transforms_cifar10(cutout_length=0)
    train_data = dset.CIFAR10(root=cfg.SEARCH.DATASET,
                              train=True,
                              download=True,
                              transform=train_transform)
    num_train = len(train_data)
    indices = list(range(num_train))
    split = int(np.floor(cfg.SEARCH.SPLIT[0] * num_train))
    train_ = torch.utils.data.DataLoader(
        train_data,
        batch_size=cfg.SEARCH.BATCH_SIZE,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[:split]),
        pin_memory=True,
        num_workers=2)
    val_ = torch.utils.data.DataLoader(
        train_data,
        batch_size=cfg.SEARCH.BATCH_SIZE,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(
            indices[split:num_train]),
        pin_memory=True,
        num_workers=2)
    # weights optimizer
    w_optim = torch.optim.SGD(pcdarts_controller.weights(),
                              cfg.OPTIM.BASE_LR,
                              momentum=cfg.OPTIM.MOMENTUM,
                              weight_decay=cfg.OPTIM.WEIGHT_DECAY)
    # alphas optimizer
    a_optim = torch.optim.Adam(pcdarts_controller.alphas(),
                               cfg.DARTS.ALPHA_LR,
                               betas=(0.5, 0.999),
                               weight_decay=cfg.DARTS.ALPHA_WEIGHT_DECAY)
    lr_scheduler = lr_scheduler_builder(w_optim)
    # Init meters
    train_meter = meters.TrainMeter(len(train_))
    val_meter = meters.TestMeter(len(val_))
    # Load checkpoint or initial weights
    start_epoch = 0
    if cfg.SEARCH.AUTO_RESUME and checkpoint.has_checkpoint():
        last_checkpoint = checkpoint.get_last_checkpoint()
        checkpoint_epoch = darts_load_checkpoint(last_checkpoint,
                                                 pcdarts_controller, w_optim,
                                                 a_optim)
        logger.info("Loaded checkpoint from: {}".format(last_checkpoint))
        start_epoch = checkpoint_epoch + 1
    elif cfg.SEARCH.WEIGHTS:
        darts_load_checkpoint(cfg.SEARCH.WEIGHTS, pcdarts_controller)
        logger.info("Loaded initial weights from: {}".format(cfg.SEARCH.WEIGHTS))
    # Compute model and loader timings
    if start_epoch == 0 and cfg.PREC_TIME.NUM_ITER > 0:
        benchmark.compute_time_full(pcdarts_controller, loss_fun, train_, val_)
    # Setup timer
    train_timer = Timer()
    # Perform the training loop
    logger.info("Start epoch: {}".format(start_epoch + 1))
    train_timer.tic()
    for cur_epoch in range(start_epoch, cfg.OPTIM.MAX_EPOCH):
        lr = lr_scheduler.get_last_lr()[0]
        train_epoch(train_, val_, pcdarts_controller, architect, loss_fun,
                    w_optim, a_optim, lr, train_meter, cur_epoch)
        # Save a checkpoint
        if (cur_epoch + 1) % cfg.SEARCH.CHECKPOINT_PERIOD == 0:
            checkpoint_file = darts_save_checkpoint(pcdarts_controller,
                                                    w_optim, a_optim,
                                                    cur_epoch)
            logger.info("Wrote checkpoint to: {}".format(checkpoint_file))
        lr_scheduler.step()
        # Evaluate the model
        next_epoch = cur_epoch + 1
        if next_epoch % cfg.SEARCH.EVAL_PERIOD == 0 or next_epoch == cfg.OPTIM.MAX_EPOCH:
            logger.info("Start testing")
            test_epoch(val_, pcdarts_controller, val_meter, cur_epoch, writer)
            logger.info("###############Optimal genotype at epoch: {}############".format(cur_epoch))
            logger.info(pcdarts_controller.genotype())
            logger.info("########################################################")
            pcdarts_controller.print_alphas(logger)
        if torch.cuda.is_available():
            torch.cuda.synchronize()
            torch.cuda.empty_cache()  # https://forums.fast.ai/t/clearing-gpu-memory-pytorch/14637
        gc.collect()
    train_timer.toc()
    logger.info("Overall training time (hr) is: {}".format(
        str(train_timer.total_time)))
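# Editor's note: `data_transforms_cifar10` / `_data_transforms_cifar10` are called by
# the PC-DARTS data loaders above but not defined in this section. The sketch below is
# an assumption based on the standard DARTS-style CIFAR-10 pipeline (random crop, flip,
# normalization); the `cutout_length` argument is accepted here but Cutout itself is
# omitted for brevity.
import torchvision.transforms as transforms

_CIFAR_MEAN = [0.49139968, 0.48215827, 0.44653124]
_CIFAR_STD = [0.24703233, 0.24348505, 0.26158768]


def data_transforms_cifar10_sketch(cutout_length=0):
    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(_CIFAR_MEAN, _CIFAR_STD),
    ])
    valid_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(_CIFAR_MEAN, _CIFAR_STD),
    ])
    return train_transform, valid_transform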
def main():
    setup_env()
    # loading search space
    search_space = build_space()
    # init controller and architect
    loss_fun = nn.CrossEntropyLoss().cuda()
    darts_controller = DartsCNNController(search_space, loss_fun)
    darts_controller.cuda()
    architect = Architect(darts_controller, cfg.OPTIM.MOMENTUM,
                          cfg.OPTIM.WEIGHT_DECAY)
    # load dataset
    [train_, val_] = _construct_loader(cfg.SEARCH.DATASET, cfg.SEARCH.SPLIT,
                                       cfg.SEARCH.BATCH_SIZE)
    # weights optimizer
    w_optim = torch.optim.SGD(darts_controller.weights(),
                              cfg.OPTIM.BASE_LR,
                              momentum=cfg.OPTIM.MOMENTUM,
                              weight_decay=cfg.OPTIM.WEIGHT_DECAY)
    # alphas optimizer
    alpha_optim = torch.optim.Adam(darts_controller.alphas(),
                                   cfg.DARTS.ALPHA_LR,
                                   betas=(0.5, 0.999),
                                   weight_decay=cfg.DARTS.ALPHA_WEIGHT_DECAY)
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        w_optim, cfg.OPTIM.MAX_EPOCH, eta_min=cfg.OPTIM.MIN_LR)
    train_meter = meters.TrainMeter(len(train_))
    val_meter = meters.TestMeter(len(val_))
    start_epoch = 0
    # Perform the training loop
    logger.info("Start epoch: {}".format(start_epoch + 1))
    for cur_epoch in range(start_epoch, cfg.OPTIM.MAX_EPOCH):
        lr = lr_scheduler.get_last_lr()[0]
        train_epoch(train_, val_, darts_controller, architect, loss_fun,
                    w_optim, alpha_optim, lr, train_meter, cur_epoch)
        # Save a checkpoint
        if (cur_epoch + 1) % cfg.SEARCH.CHECKPOINT_PERIOD == 0:
            checkpoint_file = checkpoint.save_checkpoint(
                darts_controller, w_optim, cur_epoch)
            logger.info("Wrote checkpoint to: {}".format(checkpoint_file))
        lr_scheduler.step()
        # Evaluate the model
        next_epoch = cur_epoch + 1
        if next_epoch % cfg.SEARCH.EVAL_PERIOD == 0 or next_epoch == cfg.OPTIM.MAX_EPOCH:
            logger.info("Start testing")
            test_epoch(val_, darts_controller, val_meter, cur_epoch,
                       tensorboard_writer=writer)
            logger.info("###############Optimal genotype at epoch: {}############".format(cur_epoch))
            logger.info(darts_controller.genotype())
            logger.info("########################################################")
            darts_controller.print_alphas(logger)
        if torch.cuda.is_available():
            torch.cuda.synchronize()
            torch.cuda.empty_cache()  # https://forums.fast.ai/t/clearing-gpu-memory-pytorch/14637
        gc.collect()
def main():
    tic = time.time()
    setup_env()
    # loading search space
    search_space = build_space()
    # init controller and architect
    loss_fun = nn.CrossEntropyLoss().cuda()
    darts_controller = PCDartsCNNController(search_space, loss_fun)
    darts_controller.cuda()
    architect = Architect(darts_controller, cfg.OPTIM.MOMENTUM,
                          cfg.OPTIM.WEIGHT_DECAY)
    # load dataset
    # [train_, val_] = _construct_loader(
    #     cfg.SEARCH.DATASET, cfg.SEARCH.SPLIT, cfg.SEARCH.BATCH_SIZE)
    train_transform, valid_transform = _data_transforms_cifar10()
    train_data = dset.CIFAR10(root=cfg.SEARCH.DATASET,
                              train=True,
                              download=True,
                              transform=train_transform)
    num_train = len(train_data)
    indices = list(range(num_train))
    split = int(np.floor(cfg.SEARCH.SPLIT[0] * num_train))
    train_ = torch.utils.data.DataLoader(
        train_data,
        batch_size=cfg.SEARCH.BATCH_SIZE,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[:split]),
        pin_memory=True,
        num_workers=2)
    val_ = torch.utils.data.DataLoader(
        train_data,
        batch_size=cfg.SEARCH.BATCH_SIZE,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(
            indices[split:num_train]),
        pin_memory=True,
        num_workers=2)
    # weights optimizer
    w_optim = torch.optim.SGD(darts_controller.weights(),
                              cfg.OPTIM.BASE_LR,
                              momentum=cfg.OPTIM.MOMENTUM,
                              weight_decay=cfg.OPTIM.WEIGHT_DECAY)
    # alphas optimizer
    alpha_optim = torch.optim.Adam(darts_controller.alphas(),
                                   cfg.DARTS.ALPHA_LR,
                                   betas=(0.5, 0.999),
                                   weight_decay=cfg.DARTS.ALPHA_WEIGHT_DECAY)
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        w_optim, cfg.OPTIM.MAX_EPOCH, eta_min=cfg.OPTIM.MIN_LR)
    train_meter = meters.TrainMeter(len(train_))
    val_meter = meters.TestMeter(len(val_))
    start_epoch = 0
    # Perform the training loop
    logger.info("Start epoch: {}".format(start_epoch + 1))
    for cur_epoch in range(start_epoch, cfg.OPTIM.MAX_EPOCH):
        logger.info("###############Optimal genotype at epoch: {}############".format(cur_epoch))
        logger.info(darts_controller.genotype())
        logger.info("########################################################")
        darts_controller.print_alphas(logger)
        lr_scheduler.step()
        lr = lr_scheduler.get_last_lr()[0]
        train_epoch(train_, val_, darts_controller, architect, loss_fun,
                    w_optim, alpha_optim, lr, train_meter, cur_epoch)
        # Save a checkpoint
        if (cur_epoch + 1) % cfg.SEARCH.CHECKPOINT_PERIOD == 0:
            checkpoint_file = checkpoint.save_checkpoint(
                darts_controller, w_optim, cur_epoch)
            logger.info("Wrote checkpoint to: {}".format(checkpoint_file))
        # Evaluate the model
        next_epoch = cur_epoch + 1
        if next_epoch % cfg.SEARCH.EVAL_PERIOD == 0 or next_epoch == cfg.OPTIM.MAX_EPOCH:
            logger.info("Start testing")
            test_epoch(val_, darts_controller, val_meter, cur_epoch,
                       tensorboard_writer=writer)
            logger.info("###############Optimal genotype at epoch: {}############".format(cur_epoch))
            logger.info(darts_controller.genotype())
            logger.info("########################################################")
            darts_controller.print_alphas(logger)
        if torch.cuda.is_available():
            torch.cuda.synchronize()
            torch.cuda.empty_cache()  # https://forums.fast.ai/t/clearing-gpu-memory-pytorch/14637
        gc.collect()
    toc = time.time()
    logger.info("Search time (GPU hours): {}".format((toc - tic) / 3600))
def train_model():
    """SNG search model training"""
    setup_env()
    # Load search space
    search_space = build_space()
    search_space.cuda()
    loss_fun = build_loss_fun().cuda()
    # Weights optimizer
    w_optim = torch.optim.SGD(search_space.parameters(),
                              cfg.OPTIM.BASE_LR,
                              momentum=cfg.OPTIM.MOMENTUM,
                              weight_decay=cfg.OPTIM.WEIGHT_DECAY)
    # Build distribution_optimizer
    if cfg.SPACE.NAME in ['darts', 'nasbench301']:
        distribution_optimizer = sng_builder([search_space.num_ops] *
                                             search_space.all_edges)
    elif cfg.SPACE.NAME in ['proxyless', 'google', 'ofa']:
        distribution_optimizer = sng_builder([search_space.num_ops] *
                                             search_space.all_edges)
    elif cfg.SPACE.NAME in [
            "nasbench1shot1_1", "nasbench1shot1_2", "nasbench1shot1_3"
    ]:
        category = []
        cs = search_space.search_space.get_configuration_space()
        for h in cs.get_hyperparameters():
            if type(h) == ConfigSpace.hyperparameters.CategoricalHyperparameter:
                category.append(len(h.choices))
        distribution_optimizer = sng_builder(category)
    else:
        raise NotImplementedError
    # Load dataset
    [train_, val_] = construct_loader(cfg.SEARCH.DATASET, cfg.SEARCH.SPLIT,
                                      cfg.SEARCH.BATCH_SIZE)
    lr_scheduler = lr_scheduler_builder(w_optim)
    all_timer = Timer()
    _over_all_epoch = 0

    # ===== Warm up training =====
    logger.info("start warm up training")
    warm_train_meter = meters.TrainMeter(len(train_))
    warm_val_meter = meters.TestMeter(len(val_))
    all_timer.tic()
    for cur_epoch in range(cfg.OPTIM.WARMUP_EPOCHS):
        # Save a checkpoint
        if (_over_all_epoch + 1) % cfg.SEARCH.CHECKPOINT_PERIOD == 0:
            checkpoint_file = checkpoint.save_checkpoint(
                search_space, w_optim, _over_all_epoch)
            logger.info("Wrote checkpoint to: {}".format(checkpoint_file))
        lr = lr_scheduler.get_last_lr()[0]
        if cfg.SNG.WARMUP_RANDOM_SAMPLE:
            sample = random_sampling(search_space,
                                     distribution_optimizer,
                                     epoch=cur_epoch)
            logger.info("Sampling: {}".format(one_hot_to_index(sample)))
            train_epoch(train_, val_, search_space, w_optim, lr,
                        _over_all_epoch, sample, loss_fun, warm_train_meter)
            top1 = test_epoch_with_sample(val_, search_space, warm_val_meter,
                                          _over_all_epoch, sample, writer)
            _over_all_epoch += 1
        else:
            num_ops, total_edges = search_space.num_ops, search_space.all_edges
            array_sample = [
                random.sample(list(range(num_ops)), num_ops)
                for i in range(total_edges)
            ]
            array_sample = np.array(array_sample)
            for i in range(num_ops):
                sample = np.transpose(array_sample[:, i])
                sample = index_to_one_hot(sample,
                                          distribution_optimizer.p_model.Cmax)
                train_epoch(train_, val_, search_space, w_optim, lr,
                            _over_all_epoch, sample, loss_fun,
                            warm_train_meter)
                top1 = test_epoch_with_sample(val_, search_space,
                                              warm_val_meter, _over_all_epoch,
                                              sample, writer)
                _over_all_epoch += 1
    all_timer.toc()
    logger.info("end warm up training")

    # ===== Training procedure =====
    logger.info("start one-shot training")
    train_meter = meters.TrainMeter(len(train_))
    val_meter = meters.TestMeter(len(val_))
    all_timer.tic()
    for cur_epoch in range(cfg.OPTIM.MAX_EPOCH):
        # Save a checkpoint
        if (_over_all_epoch + 1) % cfg.SEARCH.CHECKPOINT_PERIOD == 0:
            checkpoint_file = checkpoint.save_checkpoint(
                search_space, w_optim, _over_all_epoch)
            logger.info("Wrote checkpoint to: {}".format(checkpoint_file))
        if hasattr(distribution_optimizer, 'training_finish'):
            if distribution_optimizer.training_finish:
                break
        lr = w_optim.param_groups[0]['lr']
        sample = random_sampling(search_space,
                                 distribution_optimizer,
                                 epoch=cur_epoch,
                                 _random=cfg.SNG.RANDOM_SAMPLE)
        logger.info("Sampling: {}".format(one_hot_to_index(sample)))
        train_epoch(train_, val_, search_space, w_optim, lr, _over_all_epoch,
                    sample, loss_fun, train_meter)
        top1 = test_epoch_with_sample(val_, search_space, val_meter,
                                      _over_all_epoch, sample, writer)
        _over_all_epoch += 1
        lr_scheduler.step()
        distribution_optimizer.record_information(sample, top1)
        distribution_optimizer.update()
        # Evaluate the model
        next_epoch = cur_epoch + 1
        if next_epoch % cfg.SEARCH.EVAL_PERIOD == 0 or next_epoch == cfg.OPTIM.MAX_EPOCH:
            logger.info("Start testing")
            logger.info("###############Optimal genotype at epoch: {}############".format(cur_epoch))
            logger.info(search_space.genotype(distribution_optimizer.p_model.theta))
            logger.info("########################################################")
            logger.info("####### ALPHA #######")
            for alpha in distribution_optimizer.p_model.theta:
                logger.info(alpha)
            logger.info("#####################")
        if torch.cuda.is_available():
            torch.cuda.synchronize()
            torch.cuda.empty_cache()  # https://forums.fast.ai/t/clearing-gpu-memory-pytorch/14637
        gc.collect()
    all_timer.toc()

    # ===== Final epoch =====
    lr = w_optim.param_groups[0]['lr']
    all_timer.tic()
    for cur_epoch in range(cfg.OPTIM.FINAL_EPOCH):
        # Save a checkpoint
        if (_over_all_epoch + 1) % cfg.SEARCH.CHECKPOINT_PERIOD == 0:
            checkpoint_file = checkpoint.save_checkpoint(
                search_space, w_optim, _over_all_epoch)
            logger.info("Wrote checkpoint to: {}".format(checkpoint_file))
        if cfg.SPACE.NAME in ['darts', 'nasbench301']:
            genotype = search_space.genotype(distribution_optimizer.p_model.theta)
            sample = search_space.genotype_to_onehot_sample(genotype)
        else:
            sample = distribution_optimizer.sampling_best()
        _over_all_epoch += 1
        train_epoch(train_, val_, search_space, w_optim, lr, _over_all_epoch,
                    sample, loss_fun, train_meter)
        test_epoch_with_sample(val_, search_space, val_meter, _over_all_epoch,
                               sample, writer)
    # Close the final timing segment before reporting the total.
    all_timer.toc()
    logger.info("Overall training time: {} hours".format(
        str(all_timer.total_time / 3600.)))

    # Evaluate through nasbench
    if cfg.SPACE.NAME in [
            "nasbench1shot1_1", "nasbench1shot1_2", "nasbench1shot1_3",
            "nasbench201", "nasbench301"
    ]:
        logger.info("starting test using nasbench: {}".format(cfg.SPACE.NAME))
        theta = distribution_optimizer.p_model.theta
        EvaluateNasbench(theta, search_space, logger, cfg.SPACE.NAME)
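# Editor's note: `random_sampling` is called by the warm-up and search loops above but
# is not defined in this section. The sketch below is an illustrative assumption only:
# with `_random` set it draws one op per edge uniformly and returns a one-hot sample,
# otherwise it defers to the distribution optimizer. The real helper also receives the
# search space and the epoch, which this simplified version ignores.
import numpy as np


def random_sampling_sketch(num_ops, all_edges, distribution_optimizer, _random=True):
    if _random:
        idx = np.random.randint(num_ops, size=all_edges)  # one op index per edge
        sample = np.zeros((all_edges, num_ops))
        sample[np.arange(all_edges), idx] = 1.0
        return sample
    return distribution_optimizer.sampling()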
def main():
    setup_env()
    # loading search space
    # init controller and architect
    loss_fun = nn.CrossEntropyLoss().cuda()
    # load dataset
    [train_, val_] = _construct_loader(cfg.SEARCH.DATASET, cfg.SEARCH.SPLIT,
                                       cfg.SEARCH.BATCH_SIZE)
    num_to_keep = [5, 3, 1]
    eps_no_archs = [10, 10, 10]
    drop_rate = [0.1, 0.4, 0.7]
    add_layers = [0, 6, 12]
    add_width = [0, 0, 0]
    PRIMITIVES = cfg.SPACE.PRIMITIVES
    edgs_num = (cfg.SPACE.NODES + 3) * cfg.SPACE.NODES // 2
    basic_op = []
    for i in range(edgs_num * 2):
        basic_op.append(PRIMITIVES)
    for sp in range(len(num_to_keep)):
        # update the info of the supernet config
        cfg.defrost()
        cfg.SEARCH.add_layers = add_layers[sp]
        cfg.SEARCH.add_width = add_width[sp]
        cfg.SEARCH.dropout_rate = float(drop_rate[sp])
        cfg.SPACE.BASIC_OP = basic_op
        search_space = build_space()
        controller = PdartsCNNController(search_space, loss_fun)
        controller.cuda()
        architect = Architect(controller, cfg.OPTIM.MOMENTUM,
                              cfg.OPTIM.WEIGHT_DECAY)
        # weights optimizer
        w_optim = torch.optim.SGD(controller.subnet_weights(),
                                  cfg.OPTIM.BASE_LR,
                                  momentum=cfg.OPTIM.MOMENTUM,
                                  weight_decay=cfg.OPTIM.WEIGHT_DECAY)
        # alphas optimizer
        alpha_optim = torch.optim.Adam(controller.alphas(),
                                       cfg.DARTS.ALPHA_LR,
                                       betas=(0.5, 0.999),
                                       weight_decay=cfg.DARTS.ALPHA_WEIGHT_DECAY)
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            w_optim, cfg.OPTIM.MAX_EPOCH, eta_min=cfg.OPTIM.MIN_LR)
        train_meter = meters.TrainMeter(len(train_))
        val_meter = meters.TestMeter(len(val_))
        start_epoch = 0
        # Perform the training loop
        logger.info("Start epoch: {}".format(start_epoch + 1))
        scale_factor = 0.2
        for cur_epoch in range(start_epoch, cfg.OPTIM.MAX_EPOCH):
            print('cur_epoch', cur_epoch)
            lr = lr_scheduler.get_last_lr()[0]
            if cur_epoch < eps_no_archs[sp]:
                controller.update_p(
                    float(drop_rate[sp]) *
                    (cfg.OPTIM.MAX_EPOCH - cur_epoch - 1) / cfg.OPTIM.MAX_EPOCH)
                train_epoch(train_, val_, controller, architect, loss_fun,
                            w_optim, alpha_optim, lr, train_meter, cur_epoch,
                            train_arch=False)
            else:
                controller.update_p(
                    float(drop_rate[sp]) *
                    np.exp(-(cur_epoch - eps_no_archs[sp]) * scale_factor))
                train_epoch(train_, val_, controller, architect, loss_fun,
                            w_optim, alpha_optim, lr, train_meter, cur_epoch,
                            train_arch=True)
            # Save a checkpoint
            if (cur_epoch + 1) % cfg.SEARCH.CHECKPOINT_PERIOD == 0:
                checkpoint_file = checkpoint.save_checkpoint(
                    controller, w_optim, cur_epoch)
                logger.info("Wrote checkpoint to: {}".format(checkpoint_file))
            lr_scheduler.step()
            # Evaluate the model
            next_epoch = cur_epoch + 1
            if next_epoch >= cfg.OPTIM.MAX_EPOCH - 5:
                logger.info("Start testing")
                test_epoch(val_, controller, val_meter, cur_epoch,
                           tensorboard_writer=writer)
                logger.info("###############Optimal genotype at epoch: {}############".format(cur_epoch))
                logger.info(controller.genotype())
                logger.info("########################################################")
                controller.print_alphas(logger)
            if torch.cuda.is_available():
                torch.cuda.synchronize()
                torch.cuda.empty_cache()  # https://forums.fast.ai/t/clearing-gpu-memory-pytorch/14637
            gc.collect()
        print("now top k primitive", num_to_keep[sp],
              controller.get_topk_op(num_to_keep[sp]))
        if sp == len(num_to_keep) - 1:
            logger.info("###############final Optimal genotype: {}############")
            logger.info(controller.genotype(final=True))
            logger.info("########################################################")
            controller.print_alphas(logger)
            logger.info('Restricting skipconnect...')
            for sks in range(0, 9):
                max_sk = 8 - sks
                num_sk = controller.get_skip_number()
                if not num_sk > max_sk:
                    continue
                while num_sk > max_sk:
                    controller.delete_skip()
                    num_sk = controller.get_skip_number()
                logger.info('Number of skip-connect: %d', max_sk)
                logger.info(controller.genotype(final=True))
        else:
            basic_op = controller.get_topk_op(num_to_keep[sp])
    logger.info("###############final Optimal genotype: {}############")
    logger.info(controller.genotype(final=True))
    logger.info("########################################################")
    controller.print_alphas(logger)
    logger.info('Restricting skipconnect...')
    for sks in range(0, 9):
        max_sk = 8 - sks
        num_sk = controller.get_skip_number()
        if not num_sk > max_sk:
            continue
        while num_sk > max_sk:
            controller.delete_skip()
            num_sk = controller.get_skip_number()
        logger.info('Number of skip-connect: %d', max_sk)
        logger.info(controller.genotype(final=True))
def main():
    setup_env()
    # 32 3 10 === 32 16 10
    # print(input_size, input_channels, n_classes, '===', cfg.SEARCH.IM_SIZE, cfg.SPACE.CHANNEL, cfg.SEARCH.NUM_CLASSES)
    loss_fun = build_loss_fun().cuda()
    use_aux = cfg.TRAIN.AUX_WEIGHT > 0.
    # SEARCH.INIT_CHANNEL is 3 for RGB; TRAIN.CHANNELS is set to 32 manually.
    # IM_SIZE, CHANNEL and NUM_CLASSES should match the search period.
    model = AugmentCNN(cfg.SEARCH.IM_SIZE, cfg.SEARCH.INPUT_CHANNEL,
                       cfg.TRAIN.CHANNELS, cfg.SEARCH.NUM_CLASSES,
                       cfg.TRAIN.LAYERS, use_aux, cfg.TRAIN.GENOTYPE)
    # TODO: Parallel
    # model = nn.DataParallel(model, device_ids=cfg.NUM_GPUS).to(device)
    model.cuda()
    # weights optimizer
    optimizer = torch.optim.SGD(model.parameters(),
                                cfg.OPTIM.BASE_LR,
                                momentum=cfg.OPTIM.MOMENTUM,
                                weight_decay=cfg.OPTIM.WEIGHT_DECAY)
    # Get data loader
    [train_loader, valid_loader] = construct_loader(cfg.TRAIN.DATASET,
                                                    cfg.TRAIN.SPLIT,
                                                    cfg.TRAIN.BATCH_SIZE)
    lr_scheduler = lr_scheduler_builder(optimizer)
    best_top1err = 100.  # track the lowest top-1 error seen so far
    # TODO: DALI backend support
    # if config.data_loader_type == 'DALI':
    #     len_train_loader = get_train_loader_len(config.dataset.lower(), config.batch_size, is_train=True)
    # else:
    len_train_loader = len(train_loader)
    # Training loop
    # TODO: RESUME
    train_meter = meters.TrainMeter(len(train_loader))
    valid_meter = meters.TestMeter(len(valid_loader))
    start_epoch = 0
    for cur_epoch in range(start_epoch, cfg.OPTIM.MAX_EPOCH):
        drop_prob = cfg.TRAIN.DROP_PATH_PROB * cur_epoch / cfg.OPTIM.MAX_EPOCH
        if cfg.NUM_GPUS > 1:
            model.module.drop_path_prob(drop_prob)
        else:
            model.drop_path_prob(drop_prob)
        # Training
        train_epoch(train_loader, model, optimizer, loss_fun, cur_epoch,
                    train_meter)
        # Save a checkpoint
        if (cur_epoch + 1) % cfg.TRAIN.CHECKPOINT_PERIOD == 0:
            checkpoint_file = checkpoint.save_checkpoint(model, optimizer,
                                                         cur_epoch)
            logger.info("Wrote checkpoint to: {}".format(checkpoint_file))
        lr_scheduler.step()
        # Validation
        cur_step = (cur_epoch + 1) * len(train_loader)
        top1_err = valid_epoch(valid_loader, model, loss_fun, cur_epoch,
                               cur_step, valid_meter)
        logger.info("top1 error@epoch {}: {}".format(cur_epoch + 1, top1_err))
        best_top1err = min(best_top1err, top1_err)
    logger.info("Final best Prec@1 = {:.4f}%".format(100 - best_top1err))
def pdarts_train_model():
    """train PDARTS model"""
    num_to_keep = [5, 3, 1]
    eps_no_archs = [10, 10, 10]
    drop_rate = [0.1, 0.4, 0.7]
    add_layers = [0, 6, 12]
    add_width = [0, 0, 0]
    scale_factor = 0.2
    PRIMITIVES = cfg.SPACE.PRIMITIVES
    edgs_num = (cfg.SPACE.NODES + 3) * cfg.SPACE.NODES // 2
    basic_op = []
    for i in range(edgs_num * 2):
        basic_op.append(PRIMITIVES)
    setup_env()
    loss_fun = build_loss_fun().cuda()
    # Load dataset
    [train_, val_] = construct_loader(cfg.SEARCH.DATASET, cfg.SEARCH.SPLIT,
                                      cfg.SEARCH.BATCH_SIZE)
    for sp in range(len(num_to_keep)):
        # Update info of supernet config
        cfg.defrost()
        cfg.SEARCH.add_layers = add_layers[sp]
        cfg.SEARCH.add_width = add_width[sp]
        cfg.SEARCH.dropout_rate = float(drop_rate[sp])
        cfg.SPACE.BASIC_OP = basic_op
        # Loading search space
        search_space = build_space()
        # TODO: fix the complexity function
        # search_space = setup_model()
        # Init controller and architect
        pdarts_controller = PDartsCNNController(search_space, loss_fun)
        pdarts_controller.cuda()
        architect = Architect(pdarts_controller, cfg.OPTIM.MOMENTUM,
                              cfg.OPTIM.WEIGHT_DECAY)
        # weights optimizer
        w_optim = torch.optim.SGD(pdarts_controller.subnet_weights(),
                                  cfg.OPTIM.BASE_LR,
                                  momentum=cfg.OPTIM.MOMENTUM,
                                  weight_decay=cfg.OPTIM.WEIGHT_DECAY)
        # alphas optimizer
        a_optim = torch.optim.Adam(pdarts_controller.alphas(),
                                   cfg.DARTS.ALPHA_LR,
                                   betas=(0.5, 0.999),
                                   weight_decay=cfg.DARTS.ALPHA_WEIGHT_DECAY)
        lr_scheduler = lr_scheduler_builder(w_optim)
        # Init meters
        train_meter = meters.TrainMeter(len(train_))
        val_meter = meters.TestMeter(len(val_))
        # Load checkpoint or initial weights
        start_epoch = 0
        if cfg.SEARCH.AUTO_RESUME and checkpoint.has_checkpoint():
            last_checkpoint = checkpoint.get_last_checkpoint()
            checkpoint_epoch = darts_load_checkpoint(last_checkpoint,
                                                     pdarts_controller,
                                                     w_optim, a_optim)
            logger.info("Loaded checkpoint from: {}".format(last_checkpoint))
            start_epoch = checkpoint_epoch + 1
        elif cfg.SEARCH.WEIGHTS:
            darts_load_checkpoint(cfg.SEARCH.WEIGHTS, pdarts_controller)
            logger.info("Loaded initial weights from: {}".format(
                cfg.SEARCH.WEIGHTS))
        # Compute model and loader timings
        if start_epoch == 0 and cfg.PREC_TIME.NUM_ITER > 0:
            benchmark.compute_time_full(pdarts_controller, loss_fun, train_,
                                        val_)
        # TODO: Setup timer
        # train_timer = Timer()
        # Perform the training loop
        logger.info("Start epoch: {}".format(start_epoch + 1))
        for cur_epoch in range(start_epoch, cfg.OPTIM.MAX_EPOCH):
            print('cur_epoch', cur_epoch)
            lr = lr_scheduler.get_last_lr()[0]
            if cur_epoch < eps_no_archs[sp]:
                pdarts_controller.update_p(
                    float(drop_rate[sp]) *
                    (cfg.OPTIM.MAX_EPOCH - cur_epoch - 1) / cfg.OPTIM.MAX_EPOCH)
                train_epoch(train_, val_, pdarts_controller, architect,
                            loss_fun, w_optim, a_optim, lr, train_meter,
                            cur_epoch, train_arch=False)
            else:
                pdarts_controller.update_p(
                    float(drop_rate[sp]) *
                    np.exp(-(cur_epoch - eps_no_archs[sp]) * scale_factor))
                train_epoch(train_, val_, pdarts_controller, architect,
                            loss_fun, w_optim, a_optim, lr, train_meter,
                            cur_epoch, train_arch=True)
            # Save a checkpoint
            if (cur_epoch + 1) % cfg.SEARCH.CHECKPOINT_PERIOD == 0:
                checkpoint_file = darts_save_checkpoint(pdarts_controller,
                                                        w_optim, a_optim,
                                                        cur_epoch)
                logger.info("Wrote checkpoint to: {}".format(checkpoint_file))
            lr_scheduler.step()
            # Evaluate the model
            next_epoch = cur_epoch + 1
            if next_epoch >= cfg.OPTIM.MAX_EPOCH - 5:
                # if next_epoch % cfg.SEARCH.EVAL_PERIOD == 0 or next_epoch == cfg.OPTIM.MAX_EPOCH:
                logger.info("Start testing")
                test_epoch(val_, pdarts_controller, val_meter, cur_epoch,
                           writer)
                logger.info("###############Optimal genotype at epoch: {}############".format(cur_epoch))
                logger.info(pdarts_controller.genotype())
                logger.info("########################################################")
                pdarts_controller.print_alphas(logger)
            if torch.cuda.is_available():
                torch.cuda.synchronize()
                torch.cuda.empty_cache()  # https://forums.fast.ai/t/clearing-gpu-memory-pytorch/14637
            gc.collect()
        print("Top-{} primitive: {}".format(
            num_to_keep[sp], pdarts_controller.get_topk_op(num_to_keep[sp])))
        if sp == len(num_to_keep) - 1:
            logger.info("###############final Optimal genotype: {}############")
            logger.info(pdarts_controller.genotype(final=True))
            logger.info("########################################################")
            pdarts_controller.print_alphas(logger)
            logger.info('Restricting skipconnect...')
            for sks in range(0, 9):
                max_sk = 8 - sks
                num_sk = pdarts_controller.get_skip_number()
                if not num_sk > max_sk:
                    continue
                while num_sk > max_sk:
                    pdarts_controller.delete_skip()
                    num_sk = pdarts_controller.get_skip_number()
                logger.info('Number of skip-connect: %d', max_sk)
                logger.info(pdarts_controller.genotype(final=True))
        else:
            basic_op = pdarts_controller.get_topk_op(num_to_keep[sp])
    logger.info("###############final Optimal genotype: {}############")
    logger.info(pdarts_controller.genotype(final=True))
    logger.info("########################################################")
    pdarts_controller.print_alphas(logger)
    logger.info('Restricting skipconnect...')
    for sks in range(0, 9):
        max_sk = 8 - sks
        num_sk = pdarts_controller.get_skip_number()
        if not num_sk > max_sk:
            continue
        while num_sk > max_sk:
            pdarts_controller.delete_skip()
            num_sk = pdarts_controller.get_skip_number()
        logger.info('Number of skip-connect: %d', max_sk)
        logger.info(pdarts_controller.genotype(final=True))