Example 1
    @contextlib.contextmanager  # assumed decorator: the yield inside try/finally implies a generator-based context manager
    def child_env(self, model, train=True):
        """Temporarily swap the given child model (and, when training, a fresh optimizer/lr scheduler) into the trainer."""
        logging.info('Creating child {} environment'.format(
            'train' if train else 'valid'))
        if train:
            optimizer = build_optimizer(self.hparams, model.parameters())
            lr_scheduler = build_lr_scheduler(self.hparams, optimizer)
            logging.info('Child optimizer: {}'.format(
                optimizer.__class__.__name__))
            logging.info('Child LR Scheduler: {}'.format(
                lr_scheduler.__class__.__name__))

        old_model = self.model
        self.model = model
        if train:
            old_optimizer = self.optimizer
            old_lr_scheduler = self.lr_scheduler
            self.optimizer = optimizer
            self.lr_scheduler = lr_scheduler

        try:
            yield
        finally:
            self.model = old_model
            if train:
                self.optimizer = old_optimizer
                self.lr_scheduler = old_lr_scheduler
            logging.info('Trainer restored')
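
The yield inside try/finally shows this method is meant to be driven as a context manager (typically via @contextlib.contextmanager). A minimal usage sketch; `trainer`, `child_model`, `train_step`, and `batch` are placeholder names, not taken from the example:

    # Hypothetical usage sketch -- none of these names appear in the example above.
    with trainer.child_env(child_model, train=True):
        # Inside the block, trainer.model / trainer.optimizer / trainer.lr_scheduler
        # refer to the child environment.
        trainer.train_step(batch)
    # On exit, the original model, optimizer, and lr scheduler are restored.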
Example 2
    def __init__(self, hparams, model, criterion):
        super().__init__(hparams, model, criterion)

        # [NOTE]: In DARTS, the optimizer is fixed to momentum SGD and the lr scheduler to CosineAnnealingLR.
        assert hparams.optimizer == 'sgd', 'DARTS training must use SGD as optimizer'

        # [NOTE]: In DARTS, the arch optimizer is fixed to Adam, and there is no arch lr scheduler.
        with hparams_env(
                hparams,
                optimizer=hparams.arch_optimizer,
                lr=[hparams.arch_lr],
                adam_betas=hparams.arch_adam_betas,
                adam_eps=1e-8,
                weight_decay=hparams.arch_weight_decay,
        ) as arch_hparams:
            self.arch_optimizer = build_optimizer(arch_hparams,
                                                  self.model.arch_parameters())
            logging.info('Arch optimizer: {}'.format(
                self.arch_optimizer.__class__.__name__))

        self.network_momentum = hparams.momentum
        self.network_weight_decay = hparams.weight_decay
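
Both this example and the next pass keyword overrides to an `hparams_env` helper and build an optimizer from the namespace it yields. A minimal sketch of what such a helper could look like, assuming `hparams` is an argparse.Namespace-like object; the project's actual implementation may instead override and restore the original attributes in place:

    import contextlib
    import copy

    @contextlib.contextmanager
    def hparams_env(hparams, **overrides):
        """Yield a copy of `hparams` with the given attributes overridden (sketch only)."""
        new_hparams = copy.deepcopy(hparams)
        for key, value in overrides.items():
            setattr(new_hparams, key, value)
        yield new_hparams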
Example 3
    def __init__(self, hparams, criterion, only_epd_cuda=False):
        # [NOTE]: The model here is the controller's "shared" (weight-sharing) model.
        self.controller = NAOController(hparams).cuda(
            only_epd=only_epd_cuda, epd_device=hparams.epd_device)
        super().__init__(hparams, self.controller.shared_weights, criterion)

        self.only_epd_cuda = only_epd_cuda
        self.main_device = th.cuda.current_device()
        self.device_ids_for_gen = self._get_device_ids_for_gen()

        self.arch_pool = []
        self.arch_pool_prob = None
        self.eval_arch_pool = []
        self.performance_pool = []
        self._ref_tokens = None
        self._ref_dict = None
        self._current_child_size = 0
        self._current_grad_size = 0

        with hparams_env(
                hparams,
                optimizer=hparams.ctrl_optimizer,
                lr=[hparams.ctrl_lr],
                adam_eps=1e-8,
                weight_decay=hparams.ctrl_weight_decay,
        ) as ctrl_hparams:
            self.ctrl_optimizer = build_optimizer(
                ctrl_hparams, self.controller.epd.parameters())
            logging.info('Controller optimizer: {}'.format(
                self.ctrl_optimizer.__class__.__name__))

        # Meters.
        self._ctrl_best_pa = {
            'training': 0.00,
            'test': 0.00,
        }
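
For context, the `ctrl_optimizer` built above is constructed over the controller's EPD parameters only, separately from the shared weights. A hypothetical sketch of how it might be stepped and how the `_ctrl_best_pa` meters (read here as best pairwise accuracy per split, which is an assumption) could be updated; `ctrl_loss` and `pairwise_acc` are placeholder inputs, and the optimizer is assumed to expose the usual zero_grad/step interface:

    def _hypothetical_ctrl_step(self, ctrl_loss, pairwise_acc, split='training'):
        # Update only the controller's EPD parameters; the shared weights are untouched here.
        self.ctrl_optimizer.zero_grad()
        ctrl_loss.backward()
        self.ctrl_optimizer.step()

        # Track the best value seen so far for this split ('training' or 'test').
        if pairwise_acc > self._ctrl_best_pa[split]:
            self._ctrl_best_pa[split] = pairwise_acc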