def infer(self, graph, criterion, valid_queue, *args, **kwargs):
    """Evaluate *graph* on the validation queue.

    Args:
        graph: model to evaluate; must be callable as ``graph(input)``.
        criterion: loss function, called as ``criterion(logits, target)``.
        valid_queue: iterable of ``(input, target)`` minibatches.
        **kwargs: must contain ``device``; may contain ``config``
            (falls back to ``graph.config``).

    Returns:
        Tuple ``(top1_accuracy_avg, loss_avg)`` over the whole queue.

    Raises:
        ValueError: if no config / device can be resolved.
    """
    try:
        config = kwargs.get('config', graph.config)
        device = kwargs['device']
    except (AttributeError, KeyError) as err:
        # BUGFIX: the original did `raise ('...')`, which raises a plain str
        # and therefore fails with "exceptions must derive from BaseException"
        # instead of delivering the intended message.
        raise ValueError('No configuration specified in graph or kwargs') from err

    objs = utils.AvgrageMeter()
    top1 = utils.AvgrageMeter()
    top5 = utils.AvgrageMeter()

    graph.eval()
    # No gradients needed for evaluation.
    with torch.no_grad():
        for step, (input, target) in enumerate(valid_queue):
            input = input.to(device)
            target = target.to(device, non_blocking=True)

            # logits, _ = graph(input)
            logits = graph(input)
            loss = criterion(logits, target)

            prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))
            n = input.size(0)
            objs.update(loss.data.item(), n)
            top1.update(prec1.data.item(), n)
            top5.update(prec5.data.item(), n)

            if step % config.report_freq == 0:
                logging.info('valid %03d %e %f %f', step, objs.avg,
                             top1.avg, top5.avg)

    return top1.avg, objs.avg
def train(self, epoch, graph, optimizer, criterion, train_queue, valid_queue,
          *args, **kwargs):
    """Run one bilevel search epoch: alternate architecture and weight updates.

    Args:
        epoch: current search epoch (forwarded to the arch optimizer).
        graph: supernet model; trainable weights are updated in place.
        optimizer: optimizer over the model (op) weights.
        criterion: loss function ``criterion(logits, target)``.
        train_queue: iterable of training ``(input, target)`` minibatches.
        valid_queue: iterable of validation minibatches used for the
            architecture update.
        **kwargs: must contain ``device`` and ``arch_optimizer``; may contain
            ``config`` (falls back to ``graph.config``).

    Returns:
        Tuple ``(top1_accuracy_avg, loss_avg, elapsed_seconds)``.

    Raises:
        ModuleNotFoundError: if the required kwargs are missing.
            NOTE(review): a ValueError would be more accurate; the type is
            kept for backward compatibility with existing callers.
    """
    try:
        config = kwargs.get('config', graph.config)
        device = kwargs['device']
        arch_optimizer = kwargs['arch_optimizer']
    except Exception as e:
        raise ModuleNotFoundError(
            'No configuration specified in graph or kwargs') from e

    objs = utils.AvgrageMeter()
    top1 = utils.AvgrageMeter()
    top5 = utils.AvgrageMeter()

    # Adjust arch optimizer for new search epoch
    arch_optimizer.new_epoch(epoch)

    start_time = time.time()
    for step, (input_train, target_train) in enumerate(train_queue):
        graph.train()
        n = input_train.size(0)

        input_train = input_train.to(device)
        target_train = target_train.to(device, non_blocking=True)

        # --- Architecture update ---
        arch_optimizer.forward_pass_adjustment()
        # NOTE(review): `iter(valid_queue)` is re-created every step; with a
        # non-shuffling loader this always yields the *first* valid batch.
        # Kept as-is since a shuffling DataLoader re-shuffles per iter() —
        # confirm the intended sampling behavior.
        input_valid, target_valid = next(iter(valid_queue))
        input_valid = input_valid.to(device)
        target_valid = target_valid.to(device, non_blocking=True)
        arch_optimizer.step(graph, criterion, input_train, target_train,
                            input_valid, target_valid, self.lr,
                            self.optimizer, config.unrolled)

        # --- OP-weight update ---
        optimizer.zero_grad()
        arch_optimizer.forward_pass_adjustment()
        logits = graph(input_train)
        loss = criterion(logits, target_train)
        loss.backward()
        nn.utils.clip_grad_norm_(graph.parameters(), config.grad_clip)
        optimizer.step()

        prec1, prec5 = utils.accuracy(logits, target_train, topk=(1, 5))
        objs.update(loss.data.item(), n)
        top1.update(prec1.data.item(), n)
        top5.update(prec5.data.item(), n)

        if step % config.report_freq == 0:
            # Removed dead code: the last architectural-weights key was
            # looked up here every interval but never used.
            logging.info('train %03d %e %f %f', step, objs.avg,
                         top1.avg, top5.avg)
    end_time = time.time()

    return top1.avg, objs.avg, end_time - start_time
def train_batch(self, arch):
    """Train the shared-weight model for one minibatch under *arch*.

    At each epoch boundary (``steps % len(train_queue) == 0``) the LR
    scheduler is stepped and the running meters are reset; after the last
    batch of an epoch the model is evaluated on *arch* and checkpointed.

    Args:
        arch: architecture description accepted by
            ``get_weights_from_arch`` / ``evaluate``.

    Side effects:
        Updates ``self.model`` weights, ``self.steps``, ``self.epochs``,
        the meters, and (at epoch end) ``self.train_iter`` and a saved
        checkpoint.
    """
    if self.steps % len(self.train_queue) == 0:
        # New epoch: advance the LR schedule and reset the meters.
        self.scheduler.step()
        self.objs = utils.AvgrageMeter()
        self.top1 = utils.AvgrageMeter()
        self.top5 = utils.AvgrageMeter()

    weights = self.get_weights_from_arch(arch)
    self.set_arch_model_weights(weights)

    step = self.steps % len(self.train_queue)
    input, target = next(self.train_iter)

    self.model.train()
    n = input.size(0)

    input = input.cuda()
    target = target.cuda(non_blocking=True)

    # get a random_ws minibatch from the search queue with replacement
    self.optimizer.zero_grad()
    logits = self.model(input, discrete=True)
    loss = self.criterion(logits, target)

    loss.backward()
    # BUGFIX: `clip_grad_norm` (no underscore) is deprecated and removed in
    # modern PyTorch; the in-place `clip_grad_norm_` is the supported API.
    nn.utils.clip_grad_norm_(self.model.parameters(), self.args.grad_clip)
    self.optimizer.step()

    prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))
    self.objs.update(loss.data.item(), n)
    self.top1.update(prec1.data.item(), n)
    self.top5.update(prec5.data.item(), n)

    if step % self.args.report_freq == 0:
        logging.info('train %03d %e %f %f', step, self.objs.avg,
                     self.top1.avg, self.top5.avg)
    self.steps += 1

    if self.steps % len(self.train_queue) == 0:
        # Save the model weights
        self.epochs += 1
        self.train_iter = iter(self.train_queue)
        valid_err = self.evaluate(arch)
        logging.info('epoch %d | train_acc %f | valid_acc %f' %
                     (self.epochs, self.top1.avg, 1 - valid_err))
        self.save(epoch=self.epochs)
def train(self, epoch, graph, optimizer, criterion, train_queue, valid_queue,
          *args, **kwargs):
    """Train *graph* for one epoch over *train_queue*.

    Resolves ``config`` (from kwargs or ``graph.config``) and ``device``
    from kwargs, performs a standard SGD-style loop with gradient clipping,
    and logs running averages every ``config.report_freq`` steps.

    Returns:
        Tuple ``(top1_accuracy_avg, loss_avg, elapsed_seconds)``.

    Raises:
        ModuleNotFoundError: when ``device`` (or a usable config) is missing.
    """
    try:
        config = kwargs.get('config', graph.config)
        device = kwargs['device']
    except Exception as e:
        raise ModuleNotFoundError(
            'No configuration specified in graph or kwargs')

    loss_meter = utils.AvgrageMeter()
    acc1_meter = utils.AvgrageMeter()
    acc5_meter = utils.AvgrageMeter()

    t0 = time.time()
    for step, (batch_x, batch_y) in enumerate(train_queue):
        graph.train()
        batch_size = batch_x.size(0)
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device, non_blocking=True)

        optimizer.zero_grad()
        outputs = graph(batch_x)
        batch_loss = criterion(outputs, batch_y)
        batch_loss.backward()
        nn.utils.clip_grad_norm_(graph.parameters(), config.grad_clip)
        optimizer.step()

        acc1, acc5 = utils.accuracy(outputs, batch_y, topk=(1, 5))
        loss_meter.update(batch_loss.data.item(), batch_size)
        acc1_meter.update(acc1.data.item(), batch_size)
        acc5_meter.update(acc5.data.item(), batch_size)

        # Periodic progress report on the running averages.
        if step % config.report_freq == 0:
            logging.info('train %03d %e %f %f', step, loss_meter.avg,
                         acc1_meter.avg, acc5_meter.avg)
    t1 = time.time()

    return acc1_meter.avg, loss_meter.avg, t1 - t0
def evaluate_test(self, arch, split=None, discrete=False, normalize=True):
    """Evaluate *arch* on the test queue and return its top-1 error.

    The model weights are set from *arch* first. When ``split`` is None
    only 10 batches are scored; otherwise the whole test queue is used.
    The test iterator is transparently restarted when exhausted.

    Returns:
        ``1 - 0.01 * top1_accuracy`` (an error rate, suitable for
        minimization).
    """
    # Return error since we want to minimize obj val
    logging.info(arch)

    loss_meter = utils.AvgrageMeter()
    acc1_meter = utils.AvgrageMeter()
    acc5_meter = utils.AvgrageMeter()

    self.set_arch_model_weights(self.get_weights_from_arch(arch))
    self.model.eval()

    # Quick 10-batch estimate unless a full split was requested.
    n_batches = 10 if split is None else len(self.test_queue)

    for step in range(n_batches):
        try:
            batch_x, batch_y = next(self.test_iter)
        except Exception as e:
            logging.info('looping back over valid set')
            self.test_iter = iter(self.test_queue)
            batch_x, batch_y = next(self.test_iter)

        batch_x = batch_x.cuda()
        batch_y = batch_y.cuda(non_blocking=True)

        outputs = self.model(batch_x, discrete=discrete, normalize=normalize)
        batch_loss = self.criterion(outputs, batch_y)

        acc1, acc5 = utils.accuracy(outputs, batch_y, topk=(1, 5))
        batch_size = batch_x.size(0)
        loss_meter.update(batch_loss.data.item(), batch_size)
        acc1_meter.update(acc1.data.item(), batch_size)
        acc5_meter.update(acc5.data.item(), batch_size)

        if step % self.args.report_freq == 0:
            logging.info('test %03d %e %f %f', step, loss_meter.avg,
                         acc1_meter.avg, acc5_meter.avg)

    return 1 - 0.01 * acc1_meter.avg