def _init_model(self):
    self.criterion = nn.CrossEntropyLoss()
    model = Network(
        self.args.image_channels,
        self.args.init_channels,
        self.args.train_classes,
        layers=self.args.layers,
        criterion=self.criterion,
        num_inp_node=2,
        num_meta_node=self.args.num_meta_node,
        reduce_level=0 if 'cifar' in self.args.train_dataset else 1,
        use_sparse=self.args.use_sparse)
    self.model = model.cuda()
    self.logger.info('param size = %fMB', dutils.calc_parameters_count(model))

    self.optimizer = torch.optim.SGD(
        self.model.parameters(),
        lr=self.args.learning_rate,
        momentum=self.args.momentum,
        weight_decay=self.args.weight_decay)

    last_epoch = -1 if self.args.start_epoch == 0 else self.args.start_epoch
    self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        self.optimizer, float(self.args.epochs),
        eta_min=self.args.learning_rate_min, last_epoch=last_epoch)

    self.architect = Architecture(self.model, self.args)
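# A minimal sketch of how this search trainer might be driven. `TrainSearch`,
# the `args` namespace, and setting `trainer.lr` by hand are assumptions for
# illustration; only _init_model(), train(), scheduler, and logger come from
# the code above.
trainer = TrainSearch(args)                          # hypothetical wrapper class
trainer._init_model()
for epoch in range(args.start_epoch, args.epochs):
    # train() reads self.lr when calling architect.step, so refresh it each epoch
    trainer.lr = trainer.scheduler.get_last_lr()[0]  # PyTorch >= 1.4 API
    top1, loss = trainer.train()
    trainer.scheduler.step()
    trainer.logger.info('epoch %d top1 %f loss %e', epoch, top1, loss)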
def train(self):
    objs = dutils.AverageMeter()
    top1 = dutils.AverageMeter()
    for step, (input, target) in enumerate(self.train_queue):
        self.model.train()
        n = input.size(0)
        input = input.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)

        # Get a random minibatch from the search queue (validation set) with
        # replacement. Note: next(iter(...)) builds a fresh iterator every
        # step, so with shuffle=True this always draws the first shuffled batch.
        input_search, target_search = next(iter(self.valid_queue))
        input_search = input_search.cuda(non_blocking=True)
        target_search = target_search.cuda(non_blocking=True)

        # Update the architecture parameters
        self.architect.step(input, target, input_search, target_search,
                            self.lr, self.optimizer, unrolled=self.args.sec_approx)

        self.optimizer.zero_grad()
        logits = self.model(input)
        loss = self.criterion(logits, target)
        loss.backward()
        nn.utils.clip_grad_norm_(self.model.parameters(), self.args.grad_clip)
        # Update the network parameters
        self.optimizer.step()

        prec1 = dutils.accuracy(logits, target, topk=(1,))[0]
        objs.update(loss.item(), n)
        top1.update(prec1.item(), n)

        if step % self.args.report_freq == 0:  # was `args.report_freq`, an undefined name here
            self.logger.info('model size: %f', dutils.calc_parameters_count(self.model))
            self.logger.info('train %03d loss: %e top1: %f', step, objs.avg, top1.avg)

    return top1.avg, objs.avg
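# A rough sketch of what Architecture.step can look like when unrolled=False,
# i.e. the first-order DARTS approximation. The repository's actual
# Architecture class may differ; arch_parameters(), _loss(), and the Adam
# hyperparameter names below are assumptions.
class Architecture(object):
    def __init__(self, model, args):
        self.model = model
        self.optimizer = torch.optim.Adam(
            model.arch_parameters(),               # architecture weights (alpha)
            lr=args.arch_learning_rate,            # assumed flag name
            weight_decay=args.arch_weight_decay)   # assumed flag name

    def step(self, input, target, input_search, target_search,
             lr, net_optimizer, unrolled=False):
        # First-order variant: hold the network weights fixed and descend the
        # validation loss with respect to the architecture parameters only.
        self.optimizer.zero_grad()
        loss = self.model._loss(input_search, target_search)
        loss.backward()
        self.optimizer.step()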
def _init_model(self):
    self.train_queue, self.valid_queue = self._load_dataset_queue()

    def _init_scheduler():
        if 'cifar' in self.args.train_dataset:
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                self.optimizer, float(self.args.epochs))
        else:
            scheduler = torch.optim.lr_scheduler.StepLR(
                self.optimizer, self.args.decay_period, gamma=self.args.gamma)
        return scheduler

    genotype = getattr(geno_types, self.args.arch)  # safer equivalent of eval('geno_types.%s' % ...)
    # NOTE: both branches evaluate to 0 here; the search-phase code above uses
    # 1 for non-CIFAR datasets, so the `else` branch may be a latent bug.
    reduce_level = (0 if 'cifar10' in self.args.train_dataset else 0)
    model = EvalNetwork(self.args.init_channels, self.args.num_classes, 0,
                        self.args.layers, self.args.auxiliary, genotype, reduce_level)

    # Try to move the model onto multiple GPUs
    if torch.cuda.device_count() > 1 and self.args.multi_gpus:
        self.logger.info('use: %d gpus', torch.cuda.device_count())
        model = nn.DataParallel(model)
    else:
        self.logger.info('gpu device = %d', self.device_id)
        torch.cuda.set_device(self.device_id)
    self.model = model.to(self.device)
    self.logger.info('param size = %fM', dutils.calc_parameters_count(model))

    criterion = nn.CrossEntropyLoss()
    if self.args.num_classes >= 50:
        criterion = CrossEntropyLabelSmooth(self.args.num_classes, self.args.label_smooth)
    self.criterion = criterion.to(self.device)

    if self.args.opt == 'adam':
        # Note: the 'adam' flag actually selects Adamax.
        self.optimizer = torch.optim.Adamax(
            model.parameters(), self.args.learning_rate,
            weight_decay=self.args.weight_decay)
    elif self.args.opt == 'adabound':
        self.optimizer = AdaBound(
            model.parameters(), self.args.learning_rate,
            weight_decay=self.args.weight_decay)
    else:
        self.optimizer = torch.optim.SGD(
            model.parameters(), self.args.learning_rate,
            momentum=self.args.momentum,
            weight_decay=self.args.weight_decay)

    self.best_acc_top1 = 0
    # Optionally resume from a checkpoint. Load it once and reuse it for the
    # scheduler below instead of reading the file from disk a second time.
    checkpoint = None
    if self.args.resume:
        if os.path.isfile(self.args.resume):
            print("=> loading checkpoint {}".format(self.args.resume))
            checkpoint = torch.load(self.args.resume)
            self.dur_time = checkpoint['dur_time']
            self.args.start_epoch = checkpoint['epoch']
            self.best_acc_top1 = checkpoint['best_acc_top1']
            self.args.drop_path_prob = checkpoint['drop_path_prob']
            self.model.load_state_dict(checkpoint['state_dict'])
            self.optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(self.args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(self.args.resume))

    self.scheduler = _init_scheduler()
    # Reload the scheduler state if a checkpoint was loaded above
    if checkpoint is not None:
        self.scheduler.load_state_dict(checkpoint['scheduler'])
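# For the resume branch above to work, the training loop must save a
# checkpoint with matching keys. A minimal sketch; the dictionary keys mirror
# what _init_model reads back, while `epoch` and `save_path` are assumed to be
# provided by the surrounding training loop.
torch.save({
    'epoch': epoch + 1,
    'dur_time': self.dur_time,
    'best_acc_top1': self.best_acc_top1,
    'drop_path_prob': self.args.drop_path_prob,
    'state_dict': self.model.state_dict(),
    'optimizer': self.optimizer.state_dict(),
    'scheduler': self.scheduler.state_dict(),
}, os.path.join(save_path, 'checkpoint.pth.tar'))  # save_path is an assumption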