Example #1
import logging
import sys
import naslib as nl

from naslib.defaults.predictor_evaluator import PredictorEvaluator
from naslib.defaults.trainer import Trainer
from naslib.optimizers import Bananas, OneShotNASOptimizer, RandomNASOptimizer
from naslib.predictors import OneShotPredictor

from naslib.search_spaces import NasBench101SearchSpace, NasBench201SearchSpace, DartsSearchSpace
from naslib.utils import utils, setup_logger, get_dataset_api
from naslib.utils.utils import get_project_root

# parse the one-shot search configuration from the command line / config file
config = utils.get_config_from_args(config_type='oneshot')

logger = setup_logger(config.save + "/log.log")
logger.setLevel(logging.INFO)

utils.log_args(config)

# optimizer instances, keyed by the name expected in the config
supported_optimizers = {
    'bananas': Bananas(config),
    'oneshot': OneShotNASOptimizer(config),
    'rsws': RandomNASOptimizer(config),
}

# search-space instances, keyed by the name expected in the config
supported_search_spaces = {
    'nasbench101': NasBench101SearchSpace(),
    'nasbench201': NasBench201SearchSpace(),
    'darts': DartsSearchSpace()
}
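
# --- A plausible continuation (not part of the original snippet) -----------------
# The example stops after building the lookup tables above. In the usual NASLib flow,
# the entries named in the config are looked up, the optimizer is adapted to the chosen
# search space, and both are handed to the Trainer. Exact keyword arguments may differ
# between NASLib versions; config.optimizer and config.search_space are assumed to name
# keys of the dictionaries above.
search_space = supported_search_spaces[config.search_space]
optimizer = supported_optimizers[config.optimizer]

dataset_api = get_dataset_api(config.search_space, config.dataset)
optimizer.adapt_search_space(search_space, dataset_api=dataset_api)

trainer = Trainer(optimizer, config)
trainer.search()                            # run the architecture search
trainer.evaluate(dataset_api=dataset_api)   # evaluate the best architecture found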
Example #2

import unittest
import logging
import torch
import os

from naslib.search_spaces import HierarchicalSearchSpace
from naslib.optimizers import DARTSOptimizer, GDASOptimizer
from naslib.utils import utils, setup_logger

# log to a file outside the package tree and suppress everything below FATAL
logger = setup_logger(
    os.path.join(utils.get_project_root().parent, "tmp", "tests.log"))
logger.handlers[0].setLevel(logging.FATAL)

# minimal hand-built search config so the optimizers can be constructed without a config file
config = utils.AttrDict()
config.dataset = 'cifar10'
config.search = utils.AttrDict()
config.search.grad_clip = None
config.search.learning_rate = 0.01
config.search.momentum = 0.1
config.search.weight_decay = 0.1
config.search.arch_learning_rate = 0.01
config.search.arch_weight_decay = 0.1
config.search.tau_max = 10
config.search.tau_min = 1
config.search.epochs = 2

# dummy CIFAR-10-sized batches (2 samples each) used as training and validation data
data_train = (torch.ones([2, 3, 32, 32]), torch.ones([2]).long())
data_val = (torch.ones([2, 3, 32, 32]), torch.ones([2]).long())

# move the dummy batches to the GPU when one is available
if torch.cuda.is_available():
    data_train = tuple(x.cuda() for x in data_train)
    data_val = tuple(x.cuda() for x in data_val)
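
# --- A minimal test-case sketch (not part of the original snippet) ----------------
# The snippet ends before any test is defined. The sketch below assumes NASLib's
# one-shot optimizer interface (adapt_search_space, before_training, step); the class
# name and the assertion are illustrative only.
class HierarchicalDartsTest(unittest.TestCase):

    def setUp(self):
        # adapt the DARTS one-shot optimizer to the hierarchical search space
        self.optimizer = DARTSOptimizer(config)
        self.optimizer.adapt_search_space(HierarchicalSearchSpace())
        self.optimizer.before_training()

    def test_step(self):
        # a single optimization step on the dummy batches should run without errors
        stats = self.optimizer.step(data_train, data_val)
        self.assertIsNotNone(stats)


if __name__ == '__main__':
    unittest.main()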