import torch

# DataPrefetcher and adjust_loss_alpha are repo-local helpers; Config and
# args are the module-level config and parsed command-line arguments.


def validate(val_loader, net):
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to evaluate mode
    net.eval()

    prefetcher = DataPrefetcher(val_loader)
    inputs, labels = prefetcher.next()
    with torch.no_grad():
        while inputs is not None:
            inputs = inputs.float().cuda()
            labels = labels.cuda()

            # only the student's final output is needed for accuracy
            stu_outputs, _ = net(inputs)
            acc1, acc5 = accuracy(stu_outputs[-1], labels, topk=(1, 5))
            top1.update(acc1.item(), inputs.size(0))
            top5.update(acc5.item(), inputs.size(0))

            inputs, labels = prefetcher.next()

    return top1.avg, top5.avg
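# AverageMeter and accuracy are assumed to follow the standard PyTorch
# ImageNet example (pytorch/examples); a minimal sketch under that
# assumption — the repo's own definitions may differ:
class AverageMeter:
    """Tracks the latest value and the running average of a metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val, self.avg, self.sum, self.count = 0.0, 0.0, 0.0, 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def accuracy(output, target, topk=(1,)):
    """Computes top-k accuracy (in percent) from logits and labels."""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)
        # indices of the maxk highest logits per sample, shape [maxk, batch]
        _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))
        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res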
def train(train_loader, net, criterion, optimizer, scheduler, epoch, logger):
    top1 = AverageMeter()
    top5 = AverageMeter()
    loss_total = AverageMeter()
    # one independent meter per loss term; [AverageMeter()] * n would alias
    # a single meter n times
    loss_ams = [AverageMeter() for _ in range(len(criterion))]
    loss_alphas = []
    for loss_item in Config.loss_list:
        loss_rate = loss_item["loss_rate"]
        factor = loss_item["factor"]
        loss_type = loss_item["loss_type"]
        loss_rate_decay = loss_item["loss_rate_decay"]
        loss_alphas.append(
            adjust_loss_alpha(loss_rate, epoch, factor, loss_type,
                              loss_rate_decay))

    # switch to train mode
    net.train()

    iters = len(train_loader.dataset) // args.batch_size
    prefetcher = DataPrefetcher(train_loader)
    inputs, labels = prefetcher.next()
    iter_index = 1

    # gradients must accumulate across accumulation_steps iterations, so only
    # zero them before the loop and after each optimizer step; zeroing at the
    # top of every iteration would discard the accumulated gradients
    optimizer.zero_grad()
    while inputs is not None:
        inputs, labels = inputs.float().cuda(), labels.cuda()

        # forward
        stu_outputs, tea_outputs = net(inputs)

        # sum the weighted distillation loss terms
        loss = 0
        loss_detail = []
        for i, loss_item in enumerate(Config.loss_list):
            loss_type = loss_item["loss_type"]
            if loss_type == "ce_family":
                tmp_loss = loss_alphas[i] * criterion[i](stu_outputs[-1],
                                                         labels)
            elif loss_type == "kd_family":
                tmp_loss = loss_alphas[i] * criterion[i](stu_outputs[-1],
                                                         tea_outputs[-1])
            elif loss_type == "gkd_family":
                tmp_loss = loss_alphas[i] * criterion[i](
                    stu_outputs[-1], tea_outputs[-1], labels)
            elif loss_type == "fd_family":
                # feature-distillation losses consume the intermediate
                # feature maps, not the final logits
                tmp_loss = loss_alphas[i] * criterion[i](stu_outputs[:-1],
                                                         tea_outputs[:-1])
            else:
                raise ValueError(f"unknown loss_type: {loss_type}")

            loss_detail.append(tmp_loss.item())
            loss_ams[i].update(tmp_loss.item(), inputs.size(0))
            loss += tmp_loss

        loss = loss / args.accumulation_steps

        # backward, with apex mixed-precision loss scaling if enabled
        if args.apex:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        if iter_index % args.accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()

        acc1, acc5 = accuracy(stu_outputs[-1], labels, topk=(1, 5))
        top1.update(acc1.item(), inputs.size(0))
        top5.update(acc5.item(), inputs.size(0))
        loss_total.update(loss.item(), inputs.size(0))

        inputs, labels = prefetcher.next()

        if iter_index % args.print_interval == 0:
            loss_log = (
                f"train: epoch {epoch:0>3d}, iter [{iter_index:0>4d}, "
                f"{iters:0>4d}], lr: {scheduler.get_lr()[0]:.6f}, "
                f"top1 acc: {acc1.item():.2f}%, top5 acc: {acc5.item():.2f}%, "
                f"loss_total: {loss.item():.2f}, ")
            for i, loss_item in enumerate(Config.loss_list):
                loss_name = loss_item["loss_name"]
                loss_log += (f"{loss_name}: {loss_detail[i]:.2f}, "
                             f"alpha: {loss_alphas[i]:.2f}, ")
            logger.info(loss_log)

        iter_index += 1

    scheduler.step()

    return top1.avg, top5.avg, loss_total.avg
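# A minimal sketch of how the two functions might be wired together in a
# main() epoch loop; args.epochs, val_loader, and the logger setup are
# assumptions, not part of this file:
for epoch in range(1, args.epochs + 1):
    train_acc1, train_acc5, train_loss = train(train_loader, net, criterion,
                                               optimizer, scheduler, epoch,
                                               logger)
    val_acc1, val_acc5 = validate(val_loader, net)
    logger.info(f"epoch {epoch:0>3d}, train top1: {train_acc1:.2f}%, "
                f"train loss: {train_loss:.2f}, val top1: {val_acc1:.2f}%, "
                f"val top5: {val_acc5:.2f}%")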