Example #1

# Imports and a class wrapper added so the excerpt parses on its own; the
# method body is a NASLib-style trainer constructor.
import logging

import torch

from naslib.utils import utils

logger = logging.getLogger(__name__)


class Trainer:
    def __init__(self, optimizer, config, lightweight_output=False):
        """
        Initializes the trainer.

        Args:
            optimizer: A NASLib optimizer.
            config (AttrDict): The configuration loaded from a yaml file, e.g.
                via `utils.get_config_from_args()`.
            lightweight_output (bool): If True, record a reduced set of outputs.
        """
        self.optimizer = optimizer
        self.config = config
        self.epochs = self.config.search.epochs
        self.lightweight_output = lightweight_output

        # preparations
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        # meters for tracking train/validation statistics
        self.train_top1 = utils.AverageMeter()
        self.train_top5 = utils.AverageMeter()
        self.train_loss = utils.AverageMeter()
        self.val_top1 = utils.AverageMeter()
        self.val_top5 = utils.AverageMeter()
        self.val_loss = utils.AverageMeter()

        n_parameters = optimizer.get_model_size()
        logger.info("param size = %fMB", n_parameters)
        self.errors_dict = utils.AttrDict(
            {'train_acc': [],
             'train_loss': [],
             'valid_acc': [],
             'valid_loss': [],
             'test_acc': [],
             'test_loss': [],
             'runtime': [],
             'train_time': [],
             'arch_eval': [],
             'params': n_parameters}
        )
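
A minimal usage sketch, assuming the Trainer wrapper above and NASLib's high-level API: `utils.get_config_from_args()` comes from the docstring, while `adapt_search_space` and `trainer.search()` are assumptions about the surrounding library rather than code shown here.

from naslib.optimizers import DARTSOptimizer
from naslib.search_spaces import HierarchicalSearchSpace
from naslib.utils import utils

config = utils.get_config_from_args()       # yaml config, as in the docstring
optimizer = DARTSOptimizer(config)
optimizer.adapt_search_space(HierarchicalSearchSpace())  # assumed optimizer API

trainer = Trainer(optimizer, config)
trainer.search()                            # assumed entry point; runs config.search.epochs epochs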
Example #2

import unittest
import logging
import torch
import os

from naslib.search_spaces import HierarchicalSearchSpace
from naslib.optimizers import DARTSOptimizer, GDASOptimizer
from naslib.utils import utils, setup_logger

logger = setup_logger(
    os.path.join(utils.get_project_root().parent, "tmp", "tests.log"))
logger.handlers[0].setLevel(logging.FATAL)

config = utils.AttrDict()
config.dataset = 'cifar10'
config.search = utils.AttrDict()
config.search.grad_clip = None
config.search.learning_rate = 0.01
config.search.momentum = 0.1
config.search.weight_decay = 0.1
config.search.arch_learning_rate = 0.01
config.search.arch_weight_decay = 0.1
config.search.tau_max = 10
config.search.tau_min = 1
config.search.epochs = 2

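# two dummy CIFAR-10-sized batches: tensors of shape [2, 3, 32, 32] with long labels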
data_train = (torch.ones([2, 3, 32, 32]), torch.ones([2]).long())
data_val = (torch.ones([2, 3, 32, 32]), torch.ones([2]).long())

if torch.cuda.is_available():
    data_train = tuple(x.cuda() for x in data_train)
    data_val = tuple(x.cuda() for x in data_val)
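
Based on the imports at the top of this example, this fixture plausibly feeds optimizer integration tests. Below is a minimal sketch of one such test; the test class name and the adapt_search_space/before_training/step calls are assumptions about NASLib's optimizer interface, not code shown in this file.

class DartsHierarchicalTest(unittest.TestCase):

    def test_single_step(self):
        optimizer = DARTSOptimizer(config)
        optimizer.adapt_search_space(HierarchicalSearchSpace())  # assumed API
        optimizer.before_training()                              # assumed API
        # one optimization step on the dummy train/val batches above
        stats = optimizer.step(data_train, data_val)
        self.assertIsNotNone(stats)


if __name__ == '__main__':
    unittest.main()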
Example #3

# Imports and a class wrapper added so the excerpt parses on its own; the
# method body is an older NASLib-style evaluator/trainer constructor.
import logging
import random

import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn

from naslib.utils import utils


class Evaluator:
    def __init__(self, graph, parser, *args, **kwargs):
        self.graph = graph
        self.parser = parser
        try:
            # fall back to the graph's config if none was passed explicitly
            self.config = kwargs.get('config', graph.config)
        except AttributeError:
            raise ValueError('No configuration specified in graph or kwargs')
        # seed all RNGs for reproducibility (torch is seeded regardless of device)
        np.random.seed(self.config.seed)
        random.seed(self.config.seed)
        torch.manual_seed(self.config.seed)
        if torch.cuda.is_available():
            torch.cuda.set_device(self.config.gpu)
            cudnn.benchmark = False
            cudnn.enabled = True
            cudnn.deterministic = True
            torch.cuda.manual_seed_all(self.config.seed)

        self.device = torch.device(
            "cuda:0" if torch.cuda.is_available() else "cpu")

        # dataloaders
        (train_queue, valid_queue, test_queue,
         train_transform, valid_transform) = parser.get_train_val_loaders()
        self.train_queue = train_queue
        self.valid_queue = valid_queue
        self.test_queue = test_queue
        self.train_transform = train_transform
        self.valid_transform = valid_transform

        # resolve the loss by name, e.g. config.criterion == 'CrossEntropyLoss'
        criterion = getattr(nn, self.config.criterion)()
        self.criterion = criterion.to(self.device)

        self.model = self.graph.to(self.device)

        n_parameters = utils.count_parameters_in_MB(self.model)
        logging.info("param size = %fMB", n_parameters)

        optimizer = torch.optim.SGD(self.model.parameters(),
                                    self.config.learning_rate,
                                    momentum=self.config.momentum,
                                    weight_decay=self.config.weight_decay)
        self.optimizer = optimizer

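        # cosine-annealing LR schedule: T_max = config.epochs, floor at learning_rate_min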
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            float(self.config.epochs),
            eta_min=self.config.learning_rate_min)

        logging.info('Args: {}'.format(self.config))
        self.run_kwargs = {}

        self.errors_dict = utils.AttrDict({
            'train_acc': [],
            'train_loss': [],
            'valid_acc': [],
            'valid_loss': [],
            'test_acc': [],
            'test_loss': [],
            'runtime': [],
            'params': n_parameters
        })
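
For reference, a minimal sketch of the configuration fields this constructor reads; the field names come from the code above, while the values are placeholders rather than recommended settings.

config = utils.AttrDict()
config.seed = 42
config.gpu = 0                           # device index for torch.cuda.set_device
config.criterion = 'CrossEntropyLoss'    # resolved via getattr(nn, ...)
config.learning_rate = 0.025
config.learning_rate_min = 0.001         # eta_min of the cosine schedule
config.momentum = 0.9
config.weight_decay = 3e-4
config.epochs = 50                       # T_max of the cosine schedule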