Пример #1
0
    def __init__(self, dataset: str) -> None:
        """Build a ProblemFactory backed by NASBench201.

        Args:
            dataset:
                Dataset identifier; one of "cifar10", "cifar100" or
                "ImageNet16-120".
        """
        # "cifar10" is stored as "cifar10-valid" (the name used in the
        # dataset API); every other identifier is stored unchanged.
        self._dataset = "cifar10-valid" if dataset == "cifar10" else dataset
        # The loader is called with the identifier exactly as the caller
        # supplied it, not with the remapped name.
        self._dataset_api = get_dataset_api("nasbench201", dataset)
Пример #2
0
# Script section: pick an optimizer and a search space from `config`, bind
# them together, and run the flow that matches the chosen optimizer.

# Lookup table from optimizer name to optimizer instance. Every entry is
# constructed eagerly while this dict is built; only the one selected by
# `config.optimizer` is used below.
supported_optimizers = {
    'bananas': Bananas(config),
    'oneshot': OneShotNASOptimizer(config),
    'rsws': RandomNASOptimizer(config),
}

# Lookup table from search-space name to search-space instance (all three
# are likewise constructed up front).
supported_search_spaces = {
    'nasbench101': NasBench101SearchSpace(),
    'nasbench201': NasBench201SearchSpace(),
    'darts': DartsSearchSpace()
}

#load_labeled = (True if config.search_space == 'darts' else False)
# NOTE(review): `load_labeled` is not read anywhere in the visible part of
# this script — confirm whether it is still needed.
load_labeled = False
dataset_api = get_dataset_api(config.search_space, config.dataset)
# Seeding happens after the optimizers and search spaces above were built.
utils.set_seed(config.seed)

# Both lookups raise KeyError when the configured name is not a key of the
# corresponding table.
search_space = supported_search_spaces[config.search_space]

optimizer = supported_optimizers[config.optimizer]
optimizer.adapt_search_space(search_space, dataset_api=dataset_api)

trainer = Trainer(optimizer, config, lightweight_output=True)

if config.optimizer == 'bananas':
    # Search followed by evaluation; an empty `resume_from` is passed to both
    # calls (presumably meaning "do not resume from a checkpoint" — confirm
    # against Trainer).
    trainer.search(resume_from="")
    trainer.evaluate(resume_from="", dataset_api=dataset_api)
elif config.optimizer in ['oneshot', 'rsws']:
    # The one-shot optimizers get a predictor built around the trainer and
    # `config.model_path`; how it is used afterwards is not visible here.
    predictor = OneShotPredictor(config, trainer, model_path=config.model_path)
Пример #3
0
# Script section: log the configuration, pick an optimizer by name, bind it to
# the NAS-Bench-201 search space, and construct a Trainer. The search and
# evaluation calls further down are commented out, so nothing is run here.
utils.log_args(config)

# Lookup table from optimizer name to optimizer instance. Every entry is
# constructed eagerly while this dict is built; only the one selected by
# `config.optimizer` is used below.
supported_optimizers = {
    'darts': DARTSOptimizer(config),
    'gdas': GDASOptimizer(config),
    'oneshot': OneShotNASOptimizer(config),
    'rsws': RandomNASOptimizer(config),
    're': RegularizedEvolution(config),
    'rs': RandomSearch(config),
    # NOTE(review): 'ls' maps to RandomSearch, the same class as 'rs' —
    # confirm this is intended.
    'ls': RandomSearch(config),
    'bananas': Bananas(config),
    'bp': BasePredictor(config)
}

search_space = NasBench201SearchSpace()
# NOTE(review): `dataset_api` is referenced only by the commented-out evaluate
# call at the bottom and is not passed to adapt_search_space() — confirm
# whether that is deliberate.
dataset_api = get_dataset_api('nasbench201', config.dataset)

# Raises KeyError when `config.optimizer` is not a key of the table above.
optimizer = supported_optimizers[config.optimizer]
optimizer.adapt_search_space(search_space)

trainer = Trainer(optimizer, config, lightweight_output=True)
#trainer.search()

# Disabled search/evaluate flow with optional checkpoint resume, kept for
# reference.
#if not config.eval_only:
#    checkpoint = utils.get_last_checkpoint(config) if config.resume else ""
#    trainer.search(resume_from=checkpoint)

#checkpoint = utils.get_last_checkpoint(config, search=False) if config.resume else ""
#trainer.evaluate(resume_from=checkpoint, dataset_api=dataset_api)