Exemplo n.º 1
0
    'nasbench101': NasBench101SearchSpace(),
    'nasbench201': NasBench201SearchSpace(),
    'darts': DartsSearchSpace()
}

# Whether the predictor evaluator should also load pre-labeled architectures.
# An earlier variant enabled this only for the DARTS space:
#   load_labeled = config.search_space == 'darts'
load_labeled = False

# Fail fast with a readable message (instead of a bare KeyError) before the
# dataset API is loaded or the seed is set.
if config.search_space not in supported_search_spaces:
    raise ValueError("Unknown search space : {}".format(config.search_space))
if config.optimizer not in supported_optimizers:
    raise ValueError("Unknown optimizer : {}".format(config.optimizer))

dataset_api = get_dataset_api(config.search_space, config.dataset)
utils.set_seed(config.seed)

search_space = supported_search_spaces[config.search_space]

optimizer = supported_optimizers[config.optimizer]
optimizer.adapt_search_space(search_space, dataset_api=dataset_api)

trainer = Trainer(optimizer, config, lightweight_output=True)

if config.optimizer == 'bananas':
    # Run the search, then evaluate, both starting fresh ("" = no checkpoint).
    trainer.search(resume_from="")
    trainer.evaluate(resume_from="", dataset_api=dataset_api)
elif config.optimizer in ('oneshot', 'rsws'):
    # Wrap the trainer in a one-shot predictor (weights from config.model_path)
    # and measure its prediction quality on the search space.
    predictor = OneShotPredictor(config, trainer, model_path=config.model_path)

    predictor_evaluator = PredictorEvaluator(predictor, config=config)
    predictor_evaluator.adapt_search_space(search_space,
                                           load_labeled=load_labeled,
                                           dataset_api=dataset_api)

    # evaluate the predictor
    predictor_evaluator.evaluate()
else:
    # Previously this fell through and the script exited having done nothing.
    raise ValueError(
        "No run procedure for optimizer : {}".format(config.optimizer))
Exemplo n.º 2
0
utils.log_args(config)

# NOTE: every optimizer below is constructed eagerly, although only the one
# selected by config.optimizer is used.
# NOTE(review): 'ls' maps to RandomSearch, not a local-search optimizer —
# confirm this is intended.
supported_optimizers = {
    'darts': DARTSOptimizer(config),
    'gdas': GDASOptimizer(config),
    'oneshot': OneShotNASOptimizer(config),
    'rsws': RandomNASOptimizer(config),
    're': RegularizedEvolution(config),
    'rs': RandomSearch(config),
    'ls': RandomSearch(config),
    'bananas': Bananas(config),
    'bp': BasePredictor(config)
}

# Reject unknown names with a readable message instead of a bare KeyError,
# before the dataset API is loaded.
if config.optimizer not in supported_optimizers:
    raise ValueError("Unknown optimizer : {}".format(config.optimizer))

search_space = NasBench201SearchSpace()
dataset_api = get_dataset_api('nasbench201', config.dataset)

optimizer = supported_optimizers[config.optimizer]
optimizer.adapt_search_space(search_space)

trainer = Trainer(optimizer, config, lightweight_output=True)

# Search and evaluation are currently disabled; `dataset_api` is only consumed
# by the commented-out evaluate call. Uncomment to run the full pipeline:
#trainer.search()

#if not config.eval_only:
#    checkpoint = utils.get_last_checkpoint(config) if config.resume else ""
#    trainer.search(resume_from=checkpoint)

#checkpoint = utils.get_last_checkpoint(config, search=False) if config.resume else ""
#trainer.evaluate(resume_from=checkpoint, dataset_api=dataset_api)
Exemplo n.º 3
0
import logging

from naslib.defaults.trainer import Trainer
from naslib.optimizers import DARTSOptimizer, GDASOptimizer, RandomSearch
from naslib.search_spaces import DartsSearchSpace, SimpleCellSearchSpace

from naslib.utils import set_seed, setup_logger, get_config_from_args

# Parse command-line options (run with --help to list them), then seed RNGs.
config = get_config_from_args()
set_seed(config.seed)

# Log to <config.save>/log.log. The logger's default DEBUG level is very
# verbose, so raise it to INFO.
logger = setup_logger(config.save + "/log.log")
logger.setLevel(logging.INFO)

# Full DARTS cell space; SimpleCellSearchSpace() is a lighter alternative.
search_space = DartsSearchSpace()

# Attach the DARTS optimizer to the search space and build the trainer.
optimizer = DARTSOptimizer(config)
optimizer.adapt_search_space(search_space)
trainer = Trainer(optimizer, config)

# Search for an architecture, then evaluate the best one found.
trainer.search()
trainer.evaluate()

Exemplo n.º 4
0
# Architecture to evaluate directly, skipping the search phase. Only the
# 'liu_et_al' and 'random' choices set it; defaulting to None keeps the
# checks below from raising NameError for 'darts' and 'gdas'.
best_arch = None

if config.optimizer == 'darts':
    optimizer = DARTSOptimizer(config)
elif config.optimizer == 'gdas':
    optimizer = GDASOptimizer(config)
elif config.optimizer == 'liu_et_al':
    # Hack: the optimizer is only needed to instantiate the trainer and is
    # ignored during evaluation of the fixed architecture.
    optimizer = DARTSOptimizer(config)
    best_arch = LiuFinalArch()
elif config.optimizer == 'random':
    # Same hack as above: the optimizer only exists to build the trainer.
    optimizer = DARTSOptimizer(config)
    best_arch = sample_random_architecture(search_space,
                                           search_space.OPTIMIZER_SCOPE)
else:
    raise ValueError("Unknown optimizer : {}".format(config.optimizer))

optimizer.adapt_search_space(search_space)
trainer = Trainer(optimizer, config)

# Compare against None explicitly: the architecture object's truth value is
# not a reliable "was it set?" signal (a container-like graph object could be
# falsy, e.g. when it has no nodes).
if best_arch is not None:
    # A fixed architecture was chosen above: parse it and evaluate it directly.
    best_arch.parse()
    trainer.evaluate(best_arch=best_arch)
elif config.eval_only:
    # Evaluation only; resume from the last evaluation checkpoint if requested.
    trainer.evaluate(resume_from=utils.get_last_checkpoint(
        config, search=False) if config.resume else "")
else:
    # Full pipeline. Each checkpoint is looked up right before it is needed,
    # so the evaluation lookup happens after the search has run.
    trainer.search(
        resume_from=utils.get_last_checkpoint(config) if config.resume else "")
    trainer.evaluate(resume_from=utils.get_last_checkpoint(
        config, search=False) if config.resume else "")
Exemplo n.º 5
0
logger.setLevel(logging.INFO)  # default DEBUG is too verbose

utils.log_args(config)

# NOTE: every optimizer below is constructed eagerly, although only the one
# selected by config.optimizer is used.
# NOTE(review): 'ls' maps to RandomSearch, not a local-search optimizer —
# confirm this is intended.
supported_optimizers = {
    'darts': DARTSOptimizer(config),
    'gdas': GDASOptimizer(config),
    'oneshot': OneShotNASOptimizer(config),
    'rsws': RandomNASOptimizer(config),
    're': RegularizedEvolution(config),
    'rs': RandomSearch(config),
    'ls': RandomSearch(config),
    'bananas': Bananas(config),
    'bp': BasePredictor(config)
}

# Reject unknown names with a readable message instead of a bare KeyError.
if config.optimizer not in supported_optimizers:
    raise ValueError("Unknown optimizer : {}".format(config.optimizer))

# NOTE: this assigns a *class* attribute, so it is visible process-wide,
# not only to the instance created below.
if config.dataset == 'cifar100':
    DartsSearchSpace.NUM_CLASSES = 100
search_space = DartsSearchSpace()

optimizer = supported_optimizers[config.optimizer]
optimizer.adapt_search_space(search_space)

trainer = Trainer(optimizer, config)

# Search and evaluation are currently disabled; uncomment to run them:
#trainer.search(resume_from=utils.get_last_checkpoint(config) if config.resume else "")

#if config.eval_only:
#trainer.evaluate(resume_from=utils.get_last_checkpoint(config, search=False) if config.resume else "")
#else:
#trainer.search(resume_from=utils.get_last_checkpoint(config) if config.resume else "")
#trainer.evaluate(resume_from=utils.get_last_checkpoint(config, search=False) if config.resume else "")
Exemplo n.º 6
0
config = utils.get_config_from_args()
utils.set_seed(config.seed)

logger = setup_logger(config.save + "/log.log")
logger.setLevel(logging.INFO)  # default DEBUG is too verbose

utils.log_args(config)

supported_optimizers = {
    'darts': DARTSOptimizer(config),
    'gdas': GDASOptimizer(config),
    're': RegularizedEvolution(config),
    'rs': RandomSearch(config),
}

# Reject unknown names with a readable message instead of a bare KeyError.
if config.optimizer not in supported_optimizers:
    raise ValueError("Unknown optimizer : {}".format(config.optimizer))

search_space = NasBench201SearchSpace()

optimizer = supported_optimizers[config.optimizer]
optimizer.adapt_search_space(search_space)

trainer = Trainer(optimizer, config)

# Search phase (skipped when eval_only is set). Resumes from the last search
# checkpoint when config.resume is set; "" means start fresh.
if not config.eval_only:
    checkpoint = utils.get_last_checkpoint(config) if config.resume else ""
    trainer.search(resume_from=checkpoint)

# Evaluation phase; its checkpoint is looked up only after any search finished.
checkpoint = utils.get_last_checkpoint(config,
                                       search=False) if config.resume else ""
trainer.evaluate(resume_from=checkpoint)