def train(self, model, epoch, optim_obj='Weights', search_stage=0):
    """Run one search-phase training epoch, optimizing either the network
    weights or the architecture parameters.

    Args:
        model: the supernet being searched.
        epoch (int): current epoch index, used only for logging.
        optim_obj (str): 'Weights' trains on ``self.train_data`` via
            ``weight_step``; 'Arch' trains on ``self.val_data`` via
            ``arch_step``.
        search_stage: forwarded to the search optimizer's step function
            (semantics defined by ``self.search_optim`` — not visible here).

    Returns:
        (top1.avg, top5.avg, objs.avg, sub_obj_avg.avg, batch_time.avg)
    """
    assert optim_obj in ['Weights', 'Arch']
    objs = utils.AverageMeter()          # running average of the loss
    top1 = utils.AverageMeter()
    top5 = utils.AverageMeter()
    sub_obj_avg = utils.AverageMeter()   # secondary objective (see self.sub_obj_type)
    data_time = utils.AverageMeter()
    batch_time = utils.AverageMeter()
    model.train()
    start = time.time()
    # Weight updates consume the training split; architecture updates
    # consume the validation split.
    if optim_obj == 'Weights':
        prefetcher = data_prefetcher(self.train_data)
    elif optim_obj == 'Arch':
        prefetcher = data_prefetcher(self.val_data)
    input, target = prefetcher.next()
    step = 0
    while input is not None:
        # NOTE(review): data_prefetcher implementations usually already move
        # batches to GPU; this .cuda() may be redundant — confirm.
        input, target = input.cuda(), target.cuda()
        data_t = time.time() - start  # time spent waiting for this batch
        n = input.size(0)
        if optim_obj == 'Weights':
            # LR scheduler is stepped per iteration (not per epoch) during search.
            self.scheduler.step()
            if step == 0:
                logging.info(
                    'epoch %d weight_lr %e', epoch,
                    self.search_optim.weight_optimizer.param_groups[0]['lr'])
            logits, loss, sub_obj = self.search_optim.weight_step(
                input, target, model, search_stage)
        elif optim_obj == 'Arch':
            if step == 0:
                logging.info(
                    'epoch %d arch_lr %e', epoch,
                    self.search_optim.arch_optimizer.param_groups[0]['lr'])
            logits, loss, sub_obj = self.search_optim.arch_step(
                input, target, model, search_stage)
        prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))
        # Free the batch tensors eagerly to reduce peak GPU memory before
        # the next prefetch.
        del logits, input, target
        batch_t = time.time() - start
        # loss/sub_obj are recorded as returned by the step functions —
        # presumably already detached scalars; verify against search_optim.
        objs.update(loss, n)
        top1.update(prec1.item(), n)
        top5.update(prec5.item(), n)
        sub_obj_avg.update(sub_obj)
        data_time.update(data_t)
        batch_time.update(batch_t)
        if step != 0 and step % self.args.report_freq == 0:
            logging.info(
                'Train%s epoch %03d step %03d | loss %.4f %s %.2f top1_acc %.2f top5_acc %.2f | batch_time %.3f data_time %.3f',
                optim_obj, epoch, step, objs.avg, self.sub_obj_type,
                sub_obj_avg.avg, top1.avg, top5.avg, batch_time.avg,
                data_time.avg)
        start = time.time()
        step += 1
        input, target = prefetcher.next()
    return top1.avg, top5.avg, objs.avg, sub_obj_avg.avg, batch_time.avg
def infer(self, model, epoch=0):
    """Evaluate ``model`` on ``self.val_data`` for one pass.

    Args:
        model: network whose forward returns ``(logits, logits_aux)``;
            the auxiliary head output is ignored here.
        epoch (int): epoch index, used only for logging.

    Returns:
        (top1.avg, top5.avg, batch_time.avg, data_time.avg)
    """
    top1 = utils.AverageMeter()
    top5 = utils.AverageMeter()
    data_time = utils.AverageMeter()
    batch_time = utils.AverageMeter()
    model.eval()
    start = time.time()
    prefetcher = data_prefetcher(self.val_data)
    input, target = prefetcher.next()
    step = 0
    while input is not None:
        step += 1
        data_t = time.time() - start  # time spent fetching this batch
        n = input.size(0)
        # Fix: run the forward pass under no_grad. model.eval() alone does
        # not disable autograd, so the original built (and kept) a
        # computation graph for every validation batch, wasting memory and
        # time. The sibling `validate` function already does this.
        with torch.no_grad():
            logits, logits_aux = model(input)
        prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))
        batch_t = time.time() - start
        top1.update(prec1.item(), n)
        top5.update(prec5.item(), n)
        data_time.update(data_t)
        batch_time.update(batch_t)
        if step % self.report_freq == 0:
            logging.info(
                'Val epoch %03d step %03d | top1_acc %.2f top5_acc %.2f | batch_time %.3f data_time %.3f',
                epoch, step, top1.avg, top5.avg, batch_time.avg,
                data_time.avg)
        start = time.time()
        input, target = prefetcher.next()
    logging.info(
        'EPOCH%d Valid_acc top1 %.2f top5 %.2f batch_time %.3f data_time %.3f',
        epoch, top1.avg, top5.avg, batch_time.avg, data_time.avg)
    return top1.avg, top5.avg, batch_time.avg, data_time.avg
def train(self, model, epoch):
    """Train ``model`` for one epoch on ``self.train_data``.

    The LR scheduler is stepped once per iteration. Label smoothing and
    gradient clipping are applied according to ``self.config.optim``.

    Returns:
        (top1 acc, top5 acc, mean loss, mean batch time, mean data time)
    """
    loss_meter = utils.AverageMeter()
    acc1_meter = utils.AverageMeter()
    acc5_meter = utils.AverageMeter()
    fetch_time = utils.AverageMeter()   # data-loading time per step
    iter_time = utils.AverageMeter()    # full step time (load + compute)
    model.train()
    tick = time.time()
    loader = data_prefetcher(self.train_data)
    images, labels = loader.next()
    step = 0
    while images is not None:
        fetch_elapsed = time.time() - tick
        self.scheduler.step()  # per-iteration LR schedule
        batch_size = images.size(0)
        if step == 0:
            logging.info('epoch %d lr %e', epoch,
                         self.optimizer.param_groups[0]['lr'])
        self.optimizer.zero_grad()
        outputs = model(images)
        if self.config.optim.label_smooth:
            batch_loss = self.criterion(outputs, labels,
                                        self.config.optim.smooth_alpha)
        else:
            batch_loss = self.criterion(outputs, labels)
        batch_loss.backward()
        if self.config.optim.use_grad_clip:
            nn.utils.clip_grad_norm_(model.parameters(),
                                     self.config.optim.grad_clip)
        self.optimizer.step()
        acc1, acc5 = utils.accuracy(outputs, labels, topk=(1, 5))
        iter_elapsed = time.time() - tick
        tick = time.time()  # restart timer before bookkeeping/logging
        loss_meter.update(batch_loss.item(), batch_size)
        acc1_meter.update(acc1.item(), batch_size)
        acc5_meter.update(acc5.item(), batch_size)
        fetch_time.update(fetch_elapsed)
        iter_time.update(iter_elapsed)
        if step != 0 and step % self.report_freq == 0:
            logging.info(
                'Train epoch %03d step %03d | loss %.4f top1_acc %.2f top5_acc %.2f | batch_time %.3f data_time %.3f',
                epoch, step, loss_meter.avg, acc1_meter.avg, acc5_meter.avg,
                iter_time.avg, fetch_time.avg)
        images, labels = loader.next()
        step += 1
    logging.info(
        'EPOCH%d Train_acc top1 %.2f top5 %.2f batch_time %.3f data_time %.3f',
        epoch, acc1_meter.avg, acc5_meter.avg, iter_time.avg, fetch_time.avg)
    return acc1_meter.avg, acc5_meter.avg, loss_meter.avg, iter_time.avg, fetch_time.avg
def train_with_distill(train_loader, d_net, optimizer, epoch):
    """Train the student for one epoch with knowledge distillation.

    Args:
        train_loader: dataset loader wrapped by ``data_prefetcher``.
        d_net: DataParallel-wrapped distillation module exposing
            ``module.s_net`` (student) and ``module.t_net`` (teacher);
            its forward returns ``(outputs, loss)``.
        optimizer: optimizer over the student parameters.
        epoch (int): epoch index, used only for logging.
    """
    d_net.train()
    # Ensure both wrapped sub-networks are in train mode as well.
    d_net.module.s_net.train()
    d_net.module.t_net.train()
    train_loss = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    btic = time.time()
    prefetcher = data_prefetcher(train_loader, is_sample=False)
    inputs, targets = prefetcher.next()
    step = 0
    while inputs is not None:
        batch_size = inputs.size(0)
        if step == 0:
            print('epoch %d lr %e' % (epoch, optimizer.param_groups[0]['lr']))
        outputs, loss = d_net(inputs, targets)
        # DataParallel returns one loss per replica; reduce to a scalar.
        loss = torch.mean(loss)
        err1, err5 = accuracy(outputs.data, targets, topk=(1, 5))
        train_loss.update(loss.item(), batch_size)
        top1.update(err1.item(), batch_size)
        top5.update(err5.item(), batch_size)
        # Fix: zero_grad() was called twice per step (once before the
        # forward and again here); a single call right before backward()
        # is sufficient and equivalent.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if step % config.train_params.print_freq == 0:
            # NOTE(review): speed uses the configured batch size, not the
            # actual (possibly smaller last) batch — approximate by design.
            speed = config.train_params.print_freq * config.data.batch_size / (time.time() - btic)
            print(
                'Train with distillation: [Epoch %d/%d][Batch %d/%d]\t, speed %.3f, Loss %.3f, Top 1-error %.3f, Top 5-error %.3f'
                % (epoch, config.train_params.epochs, step, len(train_loader),
                   speed, train_loss.avg, top1.avg, top5.avg))
            btic = time.time()
        inputs, targets = prefetcher.next()
        step += 1
def validate(val_loader, model, epoch):
    """Evaluate ``model`` on the validation set for one pass.

    Args:
        val_loader: validation dataset loader, wrapped by ``data_prefetcher``.
        model: network under evaluation.
        epoch (int): epoch index, used only for logging.

    Returns:
        (top1.avg, top5.avg) — accumulated top-1/top-5 metrics.
    """
    batch_time = AverageMeter()
    # NOTE(review): `losses` is never updated anywhere in this function, so
    # the final "Test Loss" print always shows the meter's initial value —
    # either compute a loss here or drop it from the summary line.
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    # switch to evaluate mode
    model.eval()
    end = time.time()
    prefetcher = data_prefetcher(val_loader)
    input, target = prefetcher.next()
    step = 0
    while input is not None:
        # Disable autograd for inference (modern replacement for the old
        # PyTorch <=0.3 `volatile=True` flag).
        with torch.no_grad():
            output = model(input)
        # measure accuracy and record loss
        err1, err5 = accuracy(output.data, target, topk=(1, 5))
        top1.update(err1.item(), input.size(0))
        top5.update(err5.item(), input.size(0))
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        if step % config.train_params.print_freq == 0:
            print('Test (on val set): [Epoch {0}/{1}][Batch {2}/{3}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Top 1-err {top1.val:.4f} ({top1.avg:.4f})\t'
                  'Top 5-err {top5.val:.4f} ({top5.avg:.4f})'.format(
                      epoch, config.train_params.epochs, step, len(val_loader),
                      batch_time=batch_time, top1=top1, top5=top5))
        input, target = prefetcher.next()
        step += 1
    print('* Epoch: [{0}/{1}]\t Top 1-err {top1.avg:.3f} Top 5-err {top5.avg:.3f}\t Test Loss {loss.avg:.3f}'
          .format(epoch, config.train_params.epochs, top1=top1, top5=top5, loss=losses))
    return top1.avg, top5.avg
def infer(self, model, epoch):
    """Validate the supernet for one pass over ``self.val_data``.

    Delegates the forward/loss computation to
    ``self.search_optim.valid_step``; the model is deliberately left in
    train mode so BatchNorm uses batch statistics rather than running
    averages during search.

    Returns:
        (top1 acc, top5 acc, mean loss, mean sub-objective, mean batch time)
    """
    loss_meter = utils.AverageMeter()
    acc1_meter = utils.AverageMeter()
    acc5_meter = utils.AverageMeter()
    sub_obj_meter = utils.AverageMeter()
    fetch_time = utils.AverageMeter()
    iter_time = utils.AverageMeter()
    model.train()  # don't use running_mean and running_var during search
    tick = time.time()
    loader = data_prefetcher(self.val_data)
    images, labels = loader.next()
    step = 0
    while images is not None:
        step += 1
        fetch_elapsed = time.time() - tick
        batch_size = images.size(0)
        outputs, batch_loss, sub_obj = self.search_optim.valid_step(
            images, labels, model)
        acc1, acc5 = utils.accuracy(outputs, labels, topk=(1, 5))
        iter_elapsed = time.time() - tick
        loss_meter.update(batch_loss, batch_size)
        acc1_meter.update(acc1.item(), batch_size)
        acc5_meter.update(acc5.item(), batch_size)
        sub_obj_meter.update(sub_obj)
        fetch_time.update(fetch_elapsed)
        iter_time.update(iter_elapsed)
        if step % self.args.report_freq == 0:
            logging.info(
                'Val epoch %03d step %03d | loss %.4f %s %.2f top1_acc %.2f top5_acc %.2f | batch_time %.3f data_time %.3f',
                epoch, step, loss_meter.avg, self.sub_obj_type,
                sub_obj_meter.avg, acc1_meter.avg, acc5_meter.avg,
                iter_time.avg, fetch_time.avg)
        tick = time.time()
        images, labels = loader.next()
    return acc1_meter.avg, acc5_meter.avg, loss_meter.avg, sub_obj_meter.avg, iter_time.avg