class HierarchicalDartsIntegrationTest(unittest.TestCase):
    """Regression tests for DARTS on the hierarchical search space.

    The expected numbers below are hard-coded reference values. They depend on
    the fixed seed in ``setUp`` and on the module-level ``config``,
    ``data_train`` and ``data_val`` objects (not shown here).
    """

    def setUp(self):
        """Build a freshly seeded DARTS optimizer before every test."""
        # Seed first so that everything built afterwards is reproducible and
        # the reference values asserted in the tests stay valid.
        utils.set_seed(1)
        self.optimizer = DARTSOptimizer(config)
        self.optimizer.adapt_search_space(HierarchicalSearchSpace())
        self.optimizer.before_training()

    def test_update(self):
        """One optimizer ``step`` returns four stats matching the references."""
        stats = self.optimizer.step(data_train, data_val)
        # assertEqual reports both operands on failure, unlike assertTrue(a == b).
        self.assertEqual(len(stats), 4)
        # NOTE(review): indices 2 and 3 are presumably the train/val losses —
        # confirm against DARTSOptimizer.step.
        self.assertAlmostEqual(stats[2].item(), 2.4094, places=3)
        self.assertAlmostEqual(stats[3].item(), 2.4094, places=3)

    def test_feed_forward(self):
        """The final architecture's forward pass yields the reference logits."""
        final_arch = self.optimizer.get_final_architecture()
        logits = final_arch(data_train[0])
        # Presumably (batch_size, num_classes) = (2, 10) — TODO confirm
        # against the data fixtures.
        self.assertEqual(tuple(logits.shape), (2, 10))
        self.assertAlmostEqual(logits[0, 0].item(), -0.0545, places=3)
 def setUp(self):
     """Build a seeded DARTS optimizer over the hierarchical search space.

     NOTE(review): this 1-space-indented ``setUp`` is a stray copy of a test
     fixture whose enclosing class is not visible here; its body matches
     ``HierarchicalDartsIntegrationTest.setUp`` at the top of this file.
     """
     # Fixed seed, set before anything is constructed — presumably so the
     # tests' hard-coded expected values are reproducible.
     utils.set_seed(1)
     self.optimizer = DARTSOptimizer(config)
     self.optimizer.adapt_search_space(HierarchicalSearchSpace())
     self.optimizer.before_training()
예제 #3
0
    HierarchicalSearchSpace,
)

from naslib.utils import utils, setup_logger

# Read args and config, set up the logger.
config = utils.get_config_from_args()
utils.set_seed(config.seed)

logger = setup_logger(config.save + "/log.log")
logger.setLevel(logging.INFO)  # default DEBUG is too verbose

utils.log_args(config)

# NOTE: every optimizer below is instantiated eagerly, even though only the
# one selected by ``config.optimizer`` is used.
supported_optimizers = {
    'darts': DARTSOptimizer(config.search),
    'gdas': GDASOptimizer(config.search),
    'random': RandomSearch(sample_size=1),
    #'re': RegularizedEvolution(config.search),
}

# Exactly one search space is active; the alternatives are kept for switching.
# search_space = SimpleCellSearchSpace()
# NOTE(review): spelled "Seach" here but "NasBench201SearchSpace" in other
# NASLib snippets — confirm against this example's (truncated) import.
search_space = NasBench201SeachSpace()
# search_space = HierarchicalSearchSpace()
# search_space = DartsSearchSpace()

# Explicit check rather than ``assert``: asserts are stripped under
# ``python -O``, which would let a non-queryable space through silently.
if not search_space.QUERYABLE:
    raise ValueError(
        "Search space {} is not queryable".format(type(search_space).__name__))

try:
    optimizer = supported_optimizers[config.optimizer]
except KeyError:
    # Same error style as the other runner scripts for an unknown optimizer.
    raise ValueError("Unknown optimizer : {}".format(config.optimizer)) from None

optimizer.adapt_search_space(search_space)
예제 #4
0
파일: runner.py 프로젝트: kashankrm/NASLib
OneShotNASOptimizer, RandomNASOptimizer, RandomSearch, \
RegularizedEvolution, LocalSearch, Bananas, BasePredictor

from naslib.search_spaces import NasBench201SearchSpace
from naslib.utils import utils, setup_logger, get_dataset_api

# Read args and config, set up the logger.
config = utils.get_config_from_args()
utils.set_seed(config.seed)

logger = setup_logger(config.save + "/log.log")
logger.setLevel(logging.INFO)   # default DEBUG is too verbose

utils.log_args(config)

# Map of CLI optimizer names to optimizer instances. NOTE: all of them are
# instantiated eagerly; only the one named by ``config.optimizer`` is used.
supported_optimizers = {
    'darts': DARTSOptimizer(config),
    'gdas': GDASOptimizer(config),
    'oneshot': OneShotNASOptimizer(config),
    'rsws': RandomNASOptimizer(config),
    're': RegularizedEvolution(config),
    'rs': RandomSearch(config),
    'ls': LocalSearch(config),
    'bananas': Bananas(config),
    'bp': BasePredictor(config)
}

search_space = NasBench201SearchSpace()
dataset_api = get_dataset_api('nasbench201', config.dataset)

try:
    optimizer = supported_optimizers[config.optimizer]
except KeyError:
    # Same error style as the other runner scripts for an unknown optimizer.
    raise ValueError("Unknown optimizer : {}".format(config.optimizer)) from None
optimizer.adapt_search_space(search_space)
예제 #5
0
import logging

from naslib.defaults.trainer import Trainer
from naslib.optimizers import DARTSOptimizer, GDASOptimizer, RandomSearch
from naslib.search_spaces import DartsSearchSpace, SimpleCellSearchSpace

from naslib.utils import set_seed, setup_logger, get_config_from_args

# Minimal end-to-end run: DARTS search on the DARTS search space, followed by
# evaluation of the found architecture. Statement order matters: the seed is
# set before any model is built.
config = get_config_from_args()     # use --help to see the options
set_seed(config.seed)

# Logger writing (presumably) to <config.save>/log.log.
logger = setup_logger(config.save + "/log.log")
logger.setLevel(logging.INFO)   # default DEBUG is very verbose

search_space = DartsSearchSpace()   # use SimpleCellSearchSpace() for less heavy search

optimizer = DARTSOptimizer(config)
# Attach the search space to the optimizer before handing it to the trainer.
optimizer.adapt_search_space(search_space)

trainer = Trainer(optimizer, config)
trainer.search()        # Search for an architecture
trainer.evaluate()      # Evaluate the best architecture

예제 #6
0
from naslib.search_spaces.hierarchical.graph import HierarchicalSearchSpace, LiuFinalArch
from naslib.optimizers.discrete.rs.optimizer import sample_random_architecture

# Read args and config, set up the logger.
config = utils.get_config_from_args()
utils.set_seed(config.seed)

logger = setup_logger(config.save + "/log.log")
logger.setLevel(logging.INFO)  # default DEBUG is too verbose

utils.log_args(config)

search_space = HierarchicalSearchSpace()

# Stays None unless the chosen mode fixes the architecture up front.
best_arch = None

# Stage 1: build the optimizer. 'liu_et_al' and 'random' do not search at all;
# a DARTS optimizer is created for them only so the trainer can be built —
# it is ignored during evaluation.
if config.optimizer == 'gdas':
    optimizer = GDASOptimizer(config)
elif config.optimizer in ('darts', 'liu_et_al', 'random'):
    optimizer = DARTSOptimizer(config)
else:
    raise ValueError("Unknown optimizer : {}".format(config.optimizer))

# Stage 2: for the non-searching modes, pick the architecture to evaluate.
if config.optimizer == 'liu_et_al':
    best_arch = LiuFinalArch()
elif config.optimizer == 'random':
    best_arch = sample_random_architecture(search_space,
                                           search_space.OPTIMIZER_SCOPE)

optimizer.adapt_search_space(search_space)