Example #1
0
def train(train_source_iter: ForeverDataIterator, model, interp, criterion, optimizer: SGD,
          lr_scheduler: LambdaLR, epoch: int, visualize, args: argparse.Namespace):
    """Train ``model`` on source batches for ``args.iters_per_epoch`` iterations.

    Args:
        train_source_iter: iterator yielding ``(image, label)`` batches via
            ``next()``.
        model: network to optimize; must expose ``num_classes``.
        interp: callable applied to the raw model output before the loss and
            the metrics are computed.
        criterion: loss callable taking ``(prediction, label)``.
        optimizer: stepped once per iteration.
        lr_scheduler: stepped once per iteration, right after ``optimizer``.
        epoch: only used in the progress-line prefix.
        visualize: optional callable ``visualize(image, pred, label, name)``;
            skipped when ``None``.
        args: reads ``iters_per_epoch`` and ``print_freq``.
    """
    # Meters rendered on every progress line.
    time_meter = AverageMeter('Time', ':4.2f')
    load_meter = AverageMeter('Data', ':3.1f')
    loss_meter = AverageMeter('Loss (s)', ':3.2f')
    acc_meter = Meter('Acc (s)', ':3.2f')
    iou_meter = Meter('IoU (s)', ':3.2f')

    source_confmat = ConfusionMatrix(model.num_classes)
    tracked = [time_meter, load_meter, loss_meter, acc_meter, iou_meter]
    progress = ProgressMeter(args.iters_per_epoch, tracked,
                             prefix="Epoch: [{}]".format(epoch))

    model.train()

    tick = time.time()
    for step in range(args.iters_per_epoch):
        optimizer.zero_grad()

        images, labels = next(train_source_iter)
        images = images.to(device)
        labels = labels.long().to(device)
        # Time spent waiting for the batch and moving it to the device.
        load_meter.update(time.time() - tick)

        # Forward pass, loss, backward pass.
        prediction = interp(model(images))
        loss = criterion(prediction, labels)
        loss.backward()

        # Parameter update followed by the per-iteration LR update.
        optimizer.step()
        lr_scheduler.step()

        # Book-keeping: loss plus mean per-class accuracy / IoU.
        loss_meter.update(loss.item(), images.size(0))
        source_confmat.update(labels.flatten(), prediction.argmax(1).flatten())
        _, class_acc, class_iou = source_confmat.compute()
        acc_meter.update(class_acc.mean().item())
        iou_meter.update(class_iou.mean().item())

        # Wall-clock time of the whole iteration.
        time_meter.update(time.time() - tick)
        tick = time.time()

        if step % args.print_freq != 0:
            continue

        progress.display(step)
        if visualize is not None:
            visualize(images[0], prediction[0], labels[0], "source_{}".format(step))
def validate(val_loader: DataLoader, model: Classifier,
             args: argparse.Namespace) -> float:
    """Evaluate an open-set classifier and return its H-score.

    The last class index is treated as the "unknown" class: its softmax
    column is overwritten with ``args.threshold``, so that class is predicted
    unless some other class scores above the threshold.

    Args:
        val_loader: loader yielding ``(images, target)`` batches; its dataset
            must expose ``classes``.
        model: called as ``output, _ = model(images)``; only the first
            returned value is used.
        args: reads ``threshold``, ``print_freq`` and ``per_class_eval``.

    Returns:
        The H-score, i.e. the harmonic mean of the mean known-class accuracy
        and the unknown-class accuracy (both in percent). ``0.0`` is returned
        when both accuracies are zero.
    """
    batch_time = AverageMeter('Time', ':6.3f')
    classes = val_loader.dataset.classes
    confmat = ConfusionMatrix(len(classes))
    progress = ProgressMeter(len(val_loader), [batch_time], prefix='Test: ')

    # switch to evaluate mode
    model.eval()

    with torch.no_grad():
        end = time.time()
        for i, (images, target) in enumerate(val_loader):
            images = images.to(device)
            target = target.to(device)

            # compute output
            output, _ = model(images)
            softmax_output = F.softmax(output, dim=1)
            # Replace the unknown-class probability with the fixed threshold
            # so the argmax below doubles as the known/unknown decision.
            softmax_output[:, -1] = args.threshold

            # accumulate predictions into the confusion matrix
            confmat.update(target, softmax_output.argmax(1))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                progress.display(i)

        acc_global, accs, iu = confmat.compute()
        all_acc = torch.mean(accs).item() * 100
        known = torch.mean(accs[:-1]).item() * 100
        unknown = accs[-1].item() * 100

        # Harmonic mean of known and unknown accuracy.  When both are exactly
        # zero the denominator vanishes; report 0.0 instead of raising
        # ZeroDivisionError.  NaN accuracies still propagate unchanged because
        # ``nan != 0`` is true.
        denominator = known + unknown
        if denominator != 0:
            h_score = 2 * known * unknown / denominator
        else:
            h_score = 0.0

        if args.per_class_eval:
            print(confmat.format(classes))
        print(
            ' * All {all:.3f} Known {known:.3f} Unknown {unknown:.3f} H-score {h_score:.3f}'
            .format(all=all_acc, known=known, unknown=unknown,
                    h_score=h_score))

    return h_score
Example #3
0
def validate(val_loader: DataLoader, model, interp, criterion, visualize, args: argparse.Namespace):
    """Evaluate ``model`` on ``val_loader`` and return the confusion matrix.

    Args:
        val_loader: loader yielding ``(image, label)`` batches.
        model: network to evaluate; must expose ``num_classes``.
        interp: callable applied to the raw model output before the loss and
            the metrics are computed.
        criterion: loss callable taking ``(prediction, label)``.
        visualize: optional callable ``visualize(image, pred, label, name)``;
            skipped when ``None``.
        args: reads ``print_freq``.

    Returns:
        The ``ConfusionMatrix`` updated with every batch of ``val_loader``.
    """
    # Meters rendered on every progress line.
    time_meter = AverageMeter('Time', ':6.3f')
    loss_meter = AverageMeter('Loss', ':.4e')
    acc_meter = Meter('Acc', ':3.2f')
    iou_meter = Meter('IoU', ':3.2f')
    tracked = [time_meter, loss_meter, acc_meter, iou_meter]
    progress = ProgressMeter(len(val_loader), tracked, prefix='Test: ')

    model.eval()
    confmat = ConfusionMatrix(model.num_classes)

    with torch.no_grad():
        tick = time.time()
        for step, (images, labels) in enumerate(val_loader):
            images = images.to(device)
            labels = labels.long().to(device)

            # Forward pass and loss.
            prediction = interp(model(images))
            batch_loss = criterion(prediction, labels)

            # Book-keeping: loss plus mean per-class accuracy / IoU.
            loss_meter.update(batch_loss.item(), images.size(0))
            confmat.update(labels.flatten(), prediction.argmax(1).flatten())
            _, class_acc, class_iou = confmat.compute()
            acc_meter.update(class_acc.mean().item())
            iou_meter.update(class_iou.mean().item())

            # Wall-clock time of the whole iteration.
            time_meter.update(time.time() - tick)
            tick = time.time()

            if step % args.print_freq != 0:
                continue

            progress.display(step)
            if visualize is not None:
                visualize(images[0], prediction[0], labels[0], "val_{}".format(step))

    return confmat
Example #4
0
def train(train_source_iter: ForeverDataIterator,
          train_target_iter: ForeverDataIterator, model, interp, criterion,
          dann, optimizer: SGD, lr_scheduler: LambdaLR, optimizer_d: SGD,
          lr_scheduler_d: LambdaLR, epoch: int, visualize,
          args: argparse.Namespace):
    """Run ``args.iters_per_epoch`` iterations of adversarial training.

    Every iteration accumulates three backward passes before any optimizer
    step: the supervised loss on the source batch, the transfer loss
    ``dann(pred_t, 'source')`` scaled by ``args.trade_off``, and the
    discriminator loss computed on detached predictions.  Both optimizers and
    both LR schedulers are then stepped once.

    Args:
        train_source_iter: iterator yielding ``(x_s, label_s)`` via ``next()``.
        train_target_iter: iterator yielding ``(x_t, label_t)`` via ``next()``.
            ``label_t`` is only used for metrics and visualization, never in
            a loss.
        model: network being trained; must expose ``num_classes``.
        interp: callable applied to the raw model output before losses and
            metrics are computed.
        criterion: loss callable taking ``(pred_s, label_s)``.
        dann: callable ``dann(prediction, domain_name)`` returning a loss;
            toggled between ``eval()`` and ``train()`` inside each iteration.
        optimizer: stepped once per iteration (presumably owns the parameters
            of ``model`` -- confirm at the call site).
        lr_scheduler: stepped once per iteration after the optimizers.
        optimizer_d: stepped once per iteration (presumably owns the
            discriminator parameters -- confirm at the call site).
        lr_scheduler_d: stepped once per iteration after the optimizers.
        epoch: only used in the progress-line prefix.
        visualize: optional callable ``visualize(image, pred, label, name)``;
            skipped when ``None``.
        args: reads ``iters_per_epoch``, ``trade_off`` and ``print_freq``.
    """
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    losses_s = AverageMeter('Loss (s)', ':3.2f')
    losses_transfer = AverageMeter('Loss (transfer)', ':3.2f')
    losses_discriminator = AverageMeter('Loss (discriminator)', ':3.2f')
    accuracies_s = Meter('Acc (s)', ':3.2f')
    accuracies_t = Meter('Acc (t)', ':3.2f')
    iou_s = Meter('IoU (s)', ':3.2f')
    iou_t = Meter('IoU (t)', ':3.2f')

    # One confusion matrix per domain, created once per call and never reset
    # in this function.
    confmat_s = ConfusionMatrix(model.num_classes)
    confmat_t = ConfusionMatrix(model.num_classes)
    progress = ProgressMeter(args.iters_per_epoch, [
        batch_time, data_time, losses_s, losses_transfer, losses_discriminator,
        accuracies_s, accuracies_t, iou_s, iou_t
    ],
                             prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    end = time.time()

    for i in range(args.iters_per_epoch):
        x_s, label_s = next(train_source_iter)
        x_t, label_t = next(train_target_iter)

        # `device` is a module-level name defined outside this function.
        x_s = x_s.to(device)
        label_s = label_s.long().to(device)
        x_t = x_t.to(device)
        label_t = label_t.long().to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # Clear gradients once; the three backward() calls below accumulate
        # into them before either optimizer steps.
        optimizer.zero_grad()
        optimizer_d.zero_grad()

        # Step 1: Train the segmentation network, freeze the discriminator
        # NOTE(review): the freeze relies on what `dann.eval()` does; that
        # behavior is defined outside this block -- confirm it stops
        # gradients from reaching the discriminator parameters.
        dann.eval()
        y_s = model(x_s)
        pred_s = interp(y_s)
        loss_cls_s = criterion(pred_s, label_s)
        loss_cls_s.backward()

        # adversarial training to fool the discriminator: the target
        # prediction is scored against the 'source' domain label
        y_t = model(x_t)
        pred_t = interp(y_t)
        loss_transfer = dann(pred_t, 'source')
        (loss_transfer * args.trade_off).backward()

        # Step 2: Train the discriminator
        # Predictions are detached so this loss cannot back-propagate into
        # `model`.
        dann.train()
        loss_discriminator = 0.5 * (dann(pred_s.detach(), 'source') +
                                    dann(pred_t.detach(), 'target'))
        loss_discriminator.backward()

        # compute gradient and do SGD step
        optimizer.step()
        optimizer_d.step()
        lr_scheduler.step()
        lr_scheduler_d.step()

        # measure accuracy and record loss
        # All three loss meters are weighted by the source batch size.
        losses_s.update(loss_cls_s.item(), x_s.size(0))
        losses_transfer.update(loss_transfer.item(), x_s.size(0))
        losses_discriminator.update(loss_discriminator.item(), x_s.size(0))

        # Metrics are recomputed from the confusion matrices on every
        # iteration; the meters record the mean over classes.
        confmat_s.update(label_s.flatten(), pred_s.argmax(1).flatten())
        confmat_t.update(label_t.flatten(), pred_t.argmax(1).flatten())
        acc_global_s, acc_s, iu_s = confmat_s.compute()
        acc_global_t, acc_t, iu_t = confmat_t.compute()
        accuracies_s.update(acc_s.mean().item())
        accuracies_t.update(acc_t.mean().item())
        iou_s.update(iu_s.mean().item())
        iou_t.update(iu_t.mean().item())

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)

            # Visualize only the first sample of each batch.
            if visualize is not None:
                visualize(x_s[0], pred_s[0], label_s[0], "source_{}".format(i))
                visualize(x_t[0], pred_t[0], label_t[0], "target_{}".format(i))