import sys

import torch
import torch.nn.functional as F
from tqdm import tqdm

# `config` and `seg_metrics` are repo-local modules (training configuration and
# segmentation-metric helpers).

def train(train_loader, models, criterion, distill_criterion, optimizer, logger, epoch):
    if len(models) == 1:
        # train the teacher solo
        models[0].train()
    else:
        # train the student (with distillation from the frozen teacher)
        models[0].eval()
        models[1].train()

    bar_format = '{desc}[{elapsed}<{remaining},{rate_fmt}]'
    pbar = tqdm(range(config.niters_per_epoch), file=sys.stdout, bar_format=bar_format, ncols=80)
    dataloader = iter(train_loader)

    metrics = [seg_metrics.Seg_Metrics(n_classes=config.num_classes) for _ in range(len(models))]
    lamb = 0.2  # weight on the auxiliary (stride-16/32) heads
    for step in pbar:
        optimizer.zero_grad()

        minibatch = next(dataloader)  # Python 3: next(it), not it.next()
        imgs = minibatch['data'].cuda(non_blocking=True)
        target = minibatch['label'].cuda(non_blocking=True)

        logits_list = []
        loss = 0
        loss_kl = 0
        description = ""
        for idx, arch_idx in enumerate(config.arch_idx):
            model = models[idx]
            if arch_idx == 0 and len(models) > 1:
                # teacher forward pass: no gradients, used only as the distillation target
                with torch.no_grad():
                    logits8 = model(imgs)
                    logits_list.append(logits8)
            else:
                logits8, logits16, logits32 = model(imgs)
                logits_list.append(logits8)
                loss = loss + criterion(logits8, target)
                loss = loss + lamb * criterion(logits16, target)
                loss = loss + lamb * criterion(logits32, target)
                if len(logits_list) > 1:
                    # KL distillation: student log-probs vs. teacher probs
                    # (log_softmax is the numerically stable form of softmax().log())
                    loss_kl = distill_criterion(F.log_softmax(logits_list[1], dim=1),
                                                F.softmax(logits_list[0], dim=1))
                    loss = loss + loss_kl

            metrics[idx].update(logits8.data, target)
            description += "[mIoU%d: %.3f]" % (arch_idx, metrics[idx].get_scores())

        # the progress bar iterates config.niters_per_epoch steps, so report that total
        pbar.set_description("[Step %d/%d]" % (step + 1, config.niters_per_epoch) + description)
        # loss already includes the distillation term when a student is being trained
        logger.add_scalar('loss/train', loss, epoch * len(pbar) + step)

        loss.backward()
        optimizer.step()

    return [metric.get_scores() for metric in metrics]
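# `reduce_tensor` is called by the distributed train() below but is not defined in
# this file. The following is a minimal sketch of the common DDP all-reduce-and-average
# helper, assuming torch.distributed is already initialized; the repo's actual
# implementation may differ.
import torch.distributed as dist

def reduce_tensor(tensor, world_size):
    # sum the tensor across all ranks, then average by the world size
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= world_size
    return rt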
import logging

from detectron2.structures import ImageList

def train(len_det2_train, det2_dataset, model, model_ema, criterion, num_classes,
          lr_scheduler, optimizer, logger, epoch, args, cfg):
    model.train()

    # detectron2-style per-channel normalization constants
    pixel_mean = torch.Tensor(cfg.MODEL.PIXEL_MEAN).view(-1, 1, 1).cuda()
    pixel_std = torch.Tensor(cfg.MODEL.PIXEL_STD).view(-1, 1, 1).cuda()

    metric = seg_metrics.Seg_Metrics(n_classes=num_classes)
    lamb = 0.2  # weight on the auxiliary (stride-16/32) heads
    for i in range(len_det2_train):
        cur_iter = epoch * len_det2_train + i
        lr_scheduler(optimizer, cur_iter)

        det2_data = next(det2_dataset)
        det2_inputs = [x["image"].cuda(non_blocking=True) for x in det2_data]
        det2_inputs = [(x - pixel_mean) / pixel_std for x in det2_inputs]
        det2_inputs = ImageList.from_tensors(det2_inputs, args.size_divisibility).tensor

        b, c, h, w = det2_inputs.shape
        if h % 32 != 0 or w % 32 != 0:
            # the network downsamples by 32, so skip batches it cannot handle
            logging.info("pass bad data!")
            continue

        det2_targets = [x["sem_seg"].cuda(non_blocking=True) for x in det2_data]
        det2_targets = ImageList.from_tensors(det2_targets, args.size_divisibility, args.ignore).tensor

        N = det2_inputs.size(0)
        loss = 0
        description = ""

        logits8, logits16, logits32 = model(det2_inputs)
        loss = loss + criterion(logits8, det2_targets)
        if logits16 is not None:
            loss = loss + lamb * criterion(logits16, det2_targets)
        if logits32 is not None:
            loss = loss + lamb * criterion(logits32, det2_targets)

        # accumulate mIoU statistics across all ranks
        inter, union = seg_metrics.batch_intersection_union(logits8.data, det2_targets, num_classes)
        inter = reduce_tensor(torch.FloatTensor(inter).cuda(), args.world_size)
        union = reduce_tensor(torch.FloatTensor(union).cuda(), args.world_size)
        metric.update(inter.cpu().numpy(), union.cpu().numpy(), N)

        if args.local_rank == 0:
            description += "[mIoU%d: %.3f]" % (0, metric.get_scores())

        torch.cuda.synchronize()
        reduced_loss = reduce_tensor(loss.data, args.world_size)
        if args.local_rank == 0 and i % 20 == 0:
            logger.add_scalar('loss/train', reduced_loss, cur_iter)
            logging.info('epoch: {0}\t'
                         'iter: {1}/{2}\t'
                         'lr: {3:.6f}\t'
                         'loss: {4:.4f}'.format(
                             epoch + 1, i + 1, len_det2_train,
                             lr_scheduler.get_lr(optimizer), reduced_loss))

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        torch.cuda.synchronize()
        if model_ema is not None:
            model_ema.update(model)

    return metric.get_scores()
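# train() above drives the learning rate through a callable scheduler:
# lr_scheduler(optimizer, cur_iter) sets the LR and lr_scheduler.get_lr(optimizer)
# reads it back. Below is a minimal sketch matching that interface, assuming a
# standard polynomial ("poly") decay; the schedule the repo actually uses may differ.
class PolyLRScheduler:
    def __init__(self, base_lr, power=0.9, total_iters=1000):
        self.base_lr = base_lr
        self.power = power
        self.total_iters = total_iters

    def __call__(self, optimizer, cur_iter):
        # poly decay: lr = base_lr * (1 - t/T)^power
        lr = self.base_lr * (1 - cur_iter / self.total_iters) ** self.power
        for group in optimizer.param_groups:
            group['lr'] = lr

    def get_lr(self, optimizer):
        return optimizer.param_groups[0]['lr']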