Example 1
    def _init_model(self):
        # Loss used for both the weight and the architecture updates
        self.criterion = nn.CrossEntropyLoss()
        model = Network(
            self.args.image_channels,
            self.args.init_channels,
            self.args.train_classes,
            layers=self.args.layers,
            criterion=self.criterion,
            num_inp_node=2,
            num_meta_node=self.args.num_meta_node,
            reduce_level=0 if 'cifar' in self.args.train_dataset else 1,
            use_sparse=self.args.use_sparse)
        self.model = model.cuda()
        self.logger.info('param size = %fMB',
                         dutils.calc_parameters_count(model))

        # SGD updates the network weights; the architecture parameters are
        # handled separately by the Architecture helper below.
        self.optimizer = torch.optim.SGD(self.model.parameters(),
                                         lr=self.args.learning_rate,
                                         momentum=self.args.momentum,
                                         weight_decay=self.args.weight_decay)

        # Cosine annealing over the full search budget; last_epoch lets a
        # resumed run continue the schedule from start_epoch.
        last_epoch = -1 if self.args.start_epoch == 0 else self.args.start_epoch
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer,
            float(self.args.epochs),
            eta_min=self.args.learning_rate_min,
            last_epoch=last_epoch)

        self.architect = Architecture(self.model, self.args)
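
Example 1 builds everything except the per-epoch driver that connects it to the training loop in Example 2, which reads the current learning rate as self.lr. A minimal sketch of such a driver is given below; the run method name and the logging details are assumptions, not part of the original code.

    def run(self):
        # Hypothetical epoch driver (not in the original excerpt): it wires the
        # scheduler from _init_model() to the train() loop shown in Example 2.
        for epoch in range(self.args.start_epoch, self.args.epochs):
            # train() reads the current learning rate as self.lr when it calls
            # self.architect.step(...).
            self.lr = self.optimizer.param_groups[0]['lr']
            self.logger.info('epoch %d lr %e', epoch, self.lr)

            train_acc, train_obj = self.train()
            self.logger.info('train top1: %f loss: %e', train_acc, train_obj)

            # Advance the cosine schedule once per epoch.
            self.scheduler.step()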
Example 2
    def train(self):
        objs = dutils.AverageMeter()
        top1 = dutils.AverageMeter()

        for step, (input, target) in enumerate(self.train_queue):
            self.model.train()
            n = input.size(0)
            input = input.cuda(non_blocking=True)
            target = target.cuda(non_blocking=True)

            # Draw a fresh minibatch from the search queue (the validation set);
            # re-creating the iterator each step samples with replacement.
            input_search, target_search = next(iter(self.valid_queue))
            input_search = input_search.cuda(non_blocking=True)
            target_search = target_search.cuda(non_blocking=True)

            # Update the architecture parameters
            self.architect.step(input,
                                target,
                                input_search,
                                target_search,
                                self.lr,
                                self.optimizer,
                                unrolled=self.args.sec_approx)

            self.optimizer.zero_grad()

            logits = self.model(input)
            loss = self.criterion(logits, target)

            loss.backward()
            nn.utils.clip_grad_norm_(self.model.parameters(),
                                     self.args.grad_clip)

            # Update the network parameters
            self.optimizer.step()

            prec1 = dutils.accuracy(logits, target, topk=(1, ))[0]
            objs.update(loss.item(), n)
            top1.update(prec1.item(), n)

            if step % self.args.report_freq == 0:
                self.logger.info('model size: %f',
                                 dutils.calc_parameters_count(self.model))
                self.logger.info('train %03d loss: %e top1: %f', step,
                                 objs.avg, top1.avg)

        return top1.avg, objs.avg
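
The loop above leans on two small helpers from dutils that the excerpt does not show. Their interfaces can be inferred from how they are called; a conventional sketch, which may differ from the actual dutils implementation, looks like this:

class AverageMeter:
    """Keeps a running average of a scalar such as the loss or top-1 accuracy."""

    def __init__(self):
        self.sum = 0.0
        self.cnt = 0
        self.avg = 0.0

    def update(self, val, n=1):
        self.sum += val * n
        self.cnt += n
        self.avg = self.sum / self.cnt


def accuracy(output, target, topk=(1,)):
    """Returns the top-k accuracies (in percent) for a batch of logits."""
    maxk = max(topk)
    batch_size = target.size(0)
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    res = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res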
Example 3
    def _init_model(self):

        self.train_queue, self.valid_queue = self._load_dataset_queue()

        def _init_scheduler():
            if 'cifar' in self.args.train_dataset:
                scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer, float(self.args.epochs))
            else:
                scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, self.args.decay_period,
                                                            gamma=self.args.gamma)
            return scheduler

        genotype = getattr(geno_types, self.args.arch)
        reduce_level = 0 if 'cifar10' in self.args.train_dataset else 1
        model = EvalNetwork(self.args.init_channels, self.args.num_classes, 0,
                            self.args.layers, self.args.auxiliary, genotype, reduce_level)

        # Try to move the model onto multiple GPUs if available
        if torch.cuda.device_count() > 1 and self.args.multi_gpus:
            self.logger.info('use: %d gpus', torch.cuda.device_count())
            model = nn.DataParallel(model)
        else:
            self.logger.info('gpu device = %d', self.device_id)
            torch.cuda.set_device(self.device_id)
        self.model = model.to(self.device)

        self.logger.info('param size = %fM', dutils.calc_parameters_count(model))

        # Plain cross-entropy by default; label smoothing for datasets with many classes
        criterion = nn.CrossEntropyLoss()
        if self.args.num_classes >= 50:
            criterion = CrossEntropyLabelSmooth(self.args.num_classes, self.args.label_smooth)
        self.criterion = criterion.to(self.device)

        # Note: the 'adam' option instantiates Adamax (a variant of Adam).
        if self.args.opt == 'adam':
            self.optimizer = torch.optim.Adamax(
                model.parameters(),
                self.args.learning_rate,
                weight_decay=self.args.weight_decay
            )
        elif self.args.opt == 'adabound':
            self.optimizer = AdaBound(
                model.parameters(),
                self.args.learning_rate,
                weight_decay=self.args.weight_decay
            )
        else:
            self.optimizer = torch.optim.SGD(
                model.parameters(),
                self.args.learning_rate,
                momentum=self.args.momentum,
                weight_decay=self.args.weight_decay
            )

        self.best_acc_top1 = 0
        # optionally resume from a checkpoint
        if self.args.resume:
            if os.path.isfile(self.args.resume):
                print("=> loading checkpoint {}".format(self.args.resume))
                checkpoint = torch.load(self.args.resume)
                self.dur_time = checkpoint['dur_time']
                self.args.start_epoch = checkpoint['epoch']
                self.best_acc_top1 = checkpoint['best_acc_top1']
                self.args.drop_path_prob = checkpoint['drop_path_prob']
                self.model.load_state_dict(checkpoint['state_dict'])
                self.optimizer.load_state_dict(checkpoint['optimizer'])
                print("=> loaded checkpoint '{}' (epoch {})".format(self.args.resume, checkpoint['epoch']))
            else:
                print("=> no checkpoint found at '{}'".format(self.args.resume))

        self.scheduler = _init_scheduler()
        # reload the scheduler if possible
        if self.args.resume and os.path.isfile(self.args.resume):
            checkpoint = torch.load(self.args.resume)
            self.scheduler.load_state_dict(checkpoint['scheduler'])
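
The resume branch above expects a checkpoint containing specific keys ('dur_time', 'epoch', 'best_acc_top1', 'drop_path_prob', 'state_dict', 'optimizer', 'scheduler'). The saving side is not part of the excerpt; a sketch of a matching save method, assuming a hypothetical _save_checkpoint name, could be:

    def _save_checkpoint(self, epoch, path='checkpoint.pth.tar'):
        # Hypothetical counterpart to the resume logic in _init_model(): it
        # writes exactly the keys that are read back above.
        state = {
            'dur_time': self.dur_time,
            'epoch': epoch + 1,
            'best_acc_top1': self.best_acc_top1,
            'drop_path_prob': self.args.drop_path_prob,
            'state_dict': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'scheduler': self.scheduler.state_dict(),
        }
        torch.save(state, path)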