'nasbench101': NasBench101SearchSpace(), 'nasbench201': NasBench201SearchSpace(), 'darts': DartsSearchSpace() } #load_labeled = (True if config.search_space == 'darts' else False) load_labeled = False dataset_api = get_dataset_api(config.search_space, config.dataset) utils.set_seed(config.seed) search_space = supported_search_spaces[config.search_space] optimizer = supported_optimizers[config.optimizer] optimizer.adapt_search_space(search_space, dataset_api=dataset_api) trainer = Trainer(optimizer, config, lightweight_output=True) if config.optimizer == 'bananas': trainer.search(resume_from="") trainer.evaluate(resume_from="", dataset_api=dataset_api) elif config.optimizer in ['oneshot', 'rsws']: predictor = OneShotPredictor(config, trainer, model_path=config.model_path) predictor_evaluator = PredictorEvaluator(predictor, config=config) predictor_evaluator.adapt_search_space(search_space, load_labeled=load_labeled, dataset_api=dataset_api) # evaluate the predictor predictor_evaluator.evaluate()
utils.log_args(config)

# Optimizer classes selectable via config.optimizer. Only the selected class is
# instantiated below, so a constructor that cannot handle this config for an
# optimizer that is not being used can no longer abort the run.
supported_optimizers = {
    'darts': DARTSOptimizer,
    'gdas': GDASOptimizer,
    'oneshot': OneShotNASOptimizer,
    'rsws': RandomNASOptimizer,
    're': RegularizedEvolution,
    'rs': RandomSearch,
    # NOTE(review): 'ls' reuses RandomSearch — confirm a local-search optimizer
    # was not intended here.
    'ls': RandomSearch,
    'bananas': Bananas,
    'bp': BasePredictor,
}
if config.optimizer not in supported_optimizers:
    # Fail fast with a readable message instead of a bare KeyError.
    raise ValueError("Unknown optimizer : {}".format(config.optimizer))
optimizer = supported_optimizers[config.optimizer](config)

search_space = NasBench201SearchSpace()
# Loaded for the (currently disabled) evaluate call below.
dataset_api = get_dataset_api('nasbench201', config.dataset)

optimizer.adapt_search_space(search_space)
trainer = Trainer(optimizer, config, lightweight_output=True)

# Search/evaluation is currently disabled; re-enable the calls below as needed.
# trainer.search()
# if not config.eval_only:
#     checkpoint = utils.get_last_checkpoint(config) if config.resume else ""
#     trainer.search(resume_from=checkpoint)
# checkpoint = utils.get_last_checkpoint(config, search=False) if config.resume else ""
# trainer.evaluate(resume_from=checkpoint, dataset_api=dataset_api)
import logging

from naslib.defaults.trainer import Trainer
from naslib.optimizers import DARTSOptimizer, GDASOptimizer, RandomSearch
from naslib.search_spaces import DartsSearchSpace, SimpleCellSearchSpace
from naslib.utils import get_config_from_args, set_seed, setup_logger

# Read the experiment configuration from the command line (run with --help to
# see the options), then set the random seed from it.
config = get_config_from_args()
set_seed(config.seed)

# Write the log to <config.save>/log.log at INFO level; the default DEBUG
# level is very verbose.
logger = setup_logger(config.save + "/log.log")
logger.setLevel(logging.INFO)

# Swap in SimpleCellSearchSpace() for a less heavy search.
search_space = DartsSearchSpace()

optimizer = DARTSOptimizer(config)
optimizer.adapt_search_space(search_space)

trainer = Trainer(optimizer, config)
trainer.search()    # search for an architecture
trainer.evaluate()  # evaluate the best architecture
# Select the optimizer from the config. The 'liu_et_al' and 'random' choices
# evaluate a fixed architecture instead of searching; they still construct a
# DARTSOptimizer because the Trainer needs an optimizer instance (it is
# ignored during evaluation).
# NOTE(review): best_arch is only assigned in the 'liu_et_al' and 'random'
# branches; this block assumes it was initialised (presumably to None) earlier
# in the script — confirm.
if config.optimizer == 'darts':
    optimizer = DARTSOptimizer(config)
elif config.optimizer == 'gdas':
    optimizer = GDASOptimizer(config)
elif config.optimizer == 'liu_et_al':
    optimizer = DARTSOptimizer(config)  # hack to instantiate the trainer (ignored during eval)
    best_arch = LiuFinalArch()
elif config.optimizer == 'random':
    optimizer = DARTSOptimizer(config)  # hack to instantiate the trainer (ignored during eval)
    best_arch = sample_random_architecture(search_space, search_space.OPTIMIZER_SCOPE)
else:
    raise ValueError("Unknown optimizer : {}".format(config.optimizer))

optimizer.adapt_search_space(search_space)
trainer = Trainer(optimizer, config)

# Compare against None explicitly. An architecture object's truthiness may
# depend on its length (e.g. node count), which is not the question
# "was a fixed architecture chosen?".
if best_arch is not None:
    # A fixed architecture was chosen above: parse it and evaluate it directly.
    best_arch.parse()
    trainer.evaluate(best_arch=best_arch)
else:
    if not config.eval_only:
        checkpoint = utils.get_last_checkpoint(config) if config.resume else ""
        trainer.search(resume_from=checkpoint)
    # Look up the evaluation checkpoint only after any search has run, because
    # the search may have written new checkpoints.
    checkpoint = utils.get_last_checkpoint(config, search=False) if config.resume else ""
    trainer.evaluate(resume_from=checkpoint)
logger.setLevel(logging.INFO)  # default DEBUG is too verbose
utils.log_args(config)

# Optimizer classes selectable via config.optimizer. Only the selected class is
# instantiated, so a constructor that cannot handle this config for an
# optimizer that is not being used can no longer abort the run.
supported_optimizers = {
    'darts': DARTSOptimizer,
    'gdas': GDASOptimizer,
    'oneshot': OneShotNASOptimizer,
    'rsws': RandomNASOptimizer,
    're': RegularizedEvolution,
    'rs': RandomSearch,
    # NOTE(review): 'ls' reuses RandomSearch — confirm a local-search optimizer
    # was not intended here.
    'ls': RandomSearch,
    'bananas': Bananas,
    'bp': BasePredictor,
}
if config.optimizer not in supported_optimizers:
    # Fail fast with a readable message instead of a bare KeyError.
    raise ValueError("Unknown optimizer : {}".format(config.optimizer))
optimizer = supported_optimizers[config.optimizer](config)

# NUM_CLASSES is patched on the class itself before the search space is built.
# Presumably this sizes the network output for CIFAR-100 — confirm.
if config.dataset == 'cifar100':
    DartsSearchSpace.NUM_CLASSES = 100
search_space = DartsSearchSpace()

optimizer.adapt_search_space(search_space)
trainer = Trainer(optimizer, config)

# Search/evaluation is currently disabled; re-enable the calls below as needed.
# trainer.search(resume_from=utils.get_last_checkpoint(config) if config.resume else "")
# if config.eval_only:
#     trainer.evaluate(resume_from=utils.get_last_checkpoint(config, search=False) if config.resume else "")
# else:
#     trainer.search(resume_from=utils.get_last_checkpoint(config) if config.resume else "")
#     trainer.evaluate(resume_from=utils.get_last_checkpoint(config, search=False) if config.resume else "")
config = utils.get_config_from_args()
utils.set_seed(config.seed)

logger = setup_logger(config.save + "/log.log")
logger.setLevel(logging.INFO)  # default DEBUG is too verbose
utils.log_args(config)

# Optimizer classes selectable via config.optimizer. Only the selected class is
# instantiated, so a constructor that cannot handle this config for an
# optimizer that is not being used can no longer abort the run.
supported_optimizers = {
    'darts': DARTSOptimizer,
    'gdas': GDASOptimizer,
    're': RegularizedEvolution,
    'rs': RandomSearch,
}
if config.optimizer not in supported_optimizers:
    # Fail fast with a readable message instead of a bare KeyError.
    raise ValueError("Unknown optimizer : {}".format(config.optimizer))
optimizer = supported_optimizers[config.optimizer](config)

search_space = NasBench201SearchSpace()
optimizer.adapt_search_space(search_space)
trainer = Trainer(optimizer, config)

# Unless only evaluating, run the search first (optionally resuming).
if not config.eval_only:
    checkpoint = utils.get_last_checkpoint(config) if config.resume else ""
    trainer.search(resume_from=checkpoint)

# The evaluation checkpoint is looked up after any search has run.
checkpoint = utils.get_last_checkpoint(config, search=False) if config.resume else ""
trainer.evaluate(resume_from=checkpoint)