Example #1
import sys
import logging

import torch
import torch.nn.functional as F
from tqdm import tqdm
from detectron2.structures import ImageList  # used by Example #2

# Project-local modules (import paths assumed; adjust to this repo's layout):
# `config` holds the hyper-parameters (niters_per_epoch, num_classes, arch_idx),
# `seg_metrics` the mIoU bookkeeping, `reduce_tensor` the distributed all-reduce.
from config import config
import seg_metrics
from utils import reduce_tensor


def train(train_loader, models, criterion, distill_criterion, optimizer, logger, epoch):
    if len(models) == 1:
        # train teacher solo
        models[0].train()
    else:
        # train student (w. distill from teacher)
        models[0].eval()
        models[1].train()

    bar_format = '{desc}[{elapsed}<{remaining},{rate_fmt}]'
    pbar = tqdm(range(config.niters_per_epoch), file=sys.stdout, bar_format=bar_format, ncols=80)
    dataloader = iter(train_loader)

    metrics = [ seg_metrics.Seg_Metrics(n_classes=config.num_classes) for _ in range(len(models)) ]
    lamb = 0.2
    for step in pbar:
        optimizer.zero_grad()

        minibatch = next(dataloader)  # .next() was removed in Python 3; use next()
        imgs = minibatch['data']
        target = minibatch['label']
        imgs = imgs.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)

        logits_list = []
        loss = 0
        description = ""
        for idx, arch_idx in enumerate(config.arch_idx):
            model = models[idx]
            if arch_idx == 0 and len(models) > 1:
                # teacher forward pass: no gradients needed, it only provides
                # soft targets for distillation
                with torch.no_grad():
                    logits8 = model(imgs)
                    logits_list.append(logits8)
            else:
                logits8, logits16, logits32 = model(imgs)
                logits_list.append(logits8)
                loss = loss + criterion(logits8, target)
                loss = loss + lamb * criterion(logits16, target)
                loss = loss + lamb * criterion(logits32, target)
                if len(logits_list) > 1:
                    # distill: KLDivLoss expects student log-probs and teacher probs;
                    # log_softmax is the numerically stable form of softmax().log()
                    loss = loss + distill_criterion(F.log_softmax(logits_list[1], dim=1),
                                                    F.softmax(logits_list[0], dim=1))

            metrics[idx].update(logits8.data, target)
            description += "[mIoU%d: %.3f]"%(arch_idx, metrics[idx].get_scores())

        pbar.set_description("[Step %d/%d]" % (step + 1, config.niters_per_epoch) + description)
        logger.add_scalar('loss/train', loss, epoch * config.niters_per_epoch + step)

        loss.backward()
        optimizer.step()

    return [ metric.get_scores() for metric in metrics ]
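
A minimal calling sketch for Example #1, under stated assumptions: build_teacher, build_student, and make_train_loader are hypothetical placeholders, not functions from this repo; only the criterion/optimizer/logger wiring mirrors what train() actually consumes. Note that distill_criterion must take student log-probabilities as its first argument, which is the nn.KLDivLoss convention used inside the loop.

# Hypothetical wiring -- build_teacher/build_student/make_train_loader are
# placeholders, not part of this repo.
import torch
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter

teacher = build_teacher().cuda()    # placeholder; frozen during distillation
student = build_student().cuda()    # placeholder
models = [teacher, student]         # with config.arch_idx == [0, 1]

criterion = nn.CrossEntropyLoss(ignore_index=255)
# KLDivLoss matches the distillation call: student log-probs vs. teacher probs
distill_criterion = nn.KLDivLoss(reduction='batchmean')

optimizer = torch.optim.SGD(student.parameters(), lr=0.01, momentum=0.9)
logger = SummaryWriter('logs')        # provides the add_scalar() used above
train_loader = make_train_loader()    # placeholder; yields {'data': ..., 'label': ...}

for epoch in range(100):
    mious = train(train_loader, models, criterion, distill_criterion,
                  optimizer, logger, epoch)
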
Example #2
def train(len_det2_train, det2_dataset, model, model_ema, criterion,
          num_classes, lr_scheduler, optimizer, logger, epoch, args, cfg):

    model.train()
    pixel_mean = cfg.MODEL.PIXEL_MEAN
    pixel_std = cfg.MODEL.PIXEL_STD
    pixel_mean = torch.Tensor(pixel_mean).view(-1, 1, 1).cuda()
    pixel_std = torch.Tensor(pixel_std).view(-1, 1, 1).cuda()

    metric = seg_metrics.Seg_Metrics(n_classes=num_classes)
    lamb = 0.2
    # for i, sample in enumerate(train_loader):
    for i in range(len_det2_train):
        cur_iter = epoch * len_det2_train + i
        lr_scheduler(optimizer, cur_iter)

        det2_data = next(det2_dataset)
        det2_inputs = [x["image"].cuda(non_blocking=True) for x in det2_data]
        det2_inputs = [(x - pixel_mean) / pixel_std for x in det2_inputs]
        det2_inputs = ImageList.from_tensors(det2_inputs,
                                             args.size_divisibility).tensor

        b, c, h, w = det2_inputs.shape
        if h % 32 != 0 or w % 32 != 0:
            # the network produces 1/32-resolution logits, so skip batches
            # whose padded size is not divisible by 32
            logging.info("skipping batch: size %dx%d not divisible by 32", h, w)
            continue

        det2_targets = [
            x["sem_seg"].cuda(non_blocking=True) for x in det2_data
        ]
        det2_targets = ImageList.from_tensors(det2_targets,
                                              args.size_divisibility,
                                              args.ignore).tensor

        N = det2_inputs.size(0)

        loss = 0
        description = ""

        logits8, logits16, logits32 = model(det2_inputs)
        loss = loss + criterion(logits8, det2_targets)
        if logits16 is not None:
            loss = loss + lamb * criterion(logits16, det2_targets)
        if logits32 is not None:
            loss = loss + lamb * criterion(logits32, det2_targets)

        inter, union = seg_metrics.batch_intersection_union(
            logits8.data, det2_targets, num_classes)
        inter = reduce_tensor(torch.FloatTensor(inter).cuda(), args.world_size)
        union = reduce_tensor(torch.FloatTensor(union).cuda(), args.world_size)
        metric.update(inter.cpu().numpy(), union.cpu().numpy(), N)

        if args.local_rank == 0:
            description += "[mIoU%d: %.3f]" % (0, metric.get_scores())

        torch.cuda.synchronize()

        reduced_loss = reduce_tensor(loss.data, args.world_size)
        if args.local_rank == 0 and i % 20 == 0:
            logger.add_scalar('loss/train', reduced_loss,
                              epoch * len_det2_train + i)
            logging.info('epoch: {0}\t'
                         'iter: {1}/{2}\t'
                         'lr: {3:.6f}\t'
                         'loss: {4:.4f}\t{5}'.format(epoch + 1, i + 1,
                                                     len_det2_train,
                                                     lr_scheduler.get_lr(optimizer),
                                                     reduced_loss, description))

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        torch.cuda.synchronize()

        if model_ema is not None:
            model_ema.update(model)

    return metric.get_scores()
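
A hedged sketch of the caller side for Example #2, assuming a detectron2 pipeline: get_cfg, merge_from_file, and build_detection_train_loader are real detectron2 APIs, while the config path, len_det2_train value, and the model/optimizer/args objects are placeholders standing in for this repo's own setup.

# Hedged detectron2 wiring; model, model_ema, criterion, lr_scheduler,
# optimizer, logger, and args are assumed to be built elsewhere in the repo.
from detectron2.config import get_cfg
from detectron2.data import build_detection_train_loader

cfg = get_cfg()
cfg.merge_from_file('configs/sem_seg.yaml')      # hypothetical config path

det2_loader = build_detection_train_loader(cfg)
det2_dataset = iter(det2_loader)                 # train() pulls batches with next()
len_det2_train = 1000                            # iterations per epoch (placeholder)
num_classes = 19                                 # e.g. Cityscapes

for epoch in range(100):
    miou = train(len_det2_train, det2_dataset, model, model_ema, criterion,
                 num_classes, lr_scheduler, optimizer, logger, epoch, args, cfg)
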