def validate(val_loader, model, args, device) -> float:
    """Evaluate ``model`` on ``val_loader`` and report top-1 accuracy.

    Batches are unpacked as ``(images, target, _)``; the third element is
    ignored. The loss and accuracy meters are weighted by batch size.

    Args:
        val_loader: iterable yielding ``(images, target, extra)`` batches;
            ``len(val_loader)`` is used for progress display.
        model: callable returning class logits for ``images``.
        args: namespace providing ``print_freq``.
        device: target device for ``images`` and ``target``.

    Returns:
        Average top-1 accuracy (``AverageMeter.avg``) over the loader.
    """
    time_meter = AverageMeter('Time', ':6.3f')
    loss_meter = AverageMeter('Loss', ':.4e')
    acc_meter = AverageMeter('Acc@1', ':6.2f')
    progress = ProgressMeter(len(val_loader),
                             [time_meter, loss_meter, acc_meter],
                             prefix='Test: ')

    # evaluation mode for the whole pass
    model.eval()

    with torch.no_grad():
        tick = time.time()
        for step, (inputs, labels, _) in enumerate(val_loader):
            inputs = inputs.to(device)
            labels = labels.to(device)

            logits = model(inputs)
            loss = F.cross_entropy(logits, labels)

            batch_size = inputs.size(0)
            top1_acc = accuracy(logits, labels)[0]
            loss_meter.update(loss.item(), batch_size)
            acc_meter.update(top1_acc.item(), batch_size)

            time_meter.update(time.time() - tick)
            tick = time.time()

            if step % args.print_freq == 0:
                progress.display(step)

        print(' * Acc@1 {top1.avg:.3f} '.format(top1=acc_meter))

    return acc_meter.avg
def train(train_iter: ForeverDataIterator, bituning: BiTuning, optimizer: SGD,
          epoch: int, args: argparse.Namespace):
    """Run one epoch (``args.iters_per_epoch`` steps) of Bi-Tuning training.

    Each step draws ``(x, labels)`` from ``train_iter``, where ``x[0]`` is used
    as the query view and ``x[1]`` as the key view. The objective is the
    classification cross-entropy plus ``args.trade_off`` times the sum of two
    KL-divergence contrastive terms.

    Args:
        train_iter: infinite iterator over training batches.
        bituning: module returning ``(y, logits_z, logits_y, bituning_labels)``.
        optimizer: optimizer stepped once per iteration.
        epoch: epoch index, only used in the progress prefix.
        args: namespace providing ``iters_per_epoch``, ``trade_off`` and
            ``print_freq``.

    Note: relies on a module-level ``device``.
    """
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    cls_losses = AverageMeter('Cls Loss', ':3.2f')
    contrastive_losses = AverageMeter('Contrastive Loss', ':3.2f')
    losses = AverageMeter('Loss', ':3.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    meters = [batch_time, data_time, losses, cls_losses, contrastive_losses, cls_accs]
    progress = ProgressMeter(args.iters_per_epoch, meters,
                             prefix="Epoch: [{}]".format(epoch))

    ce_loss_fn = torch.nn.CrossEntropyLoss().to(device)
    # NOTE(review): KLDivLoss expects log-probabilities as its first argument;
    # presumably BiTuning already returns log-softmax outputs -- TODO confirm.
    kl_loss_fn = torch.nn.KLDivLoss(reduction='batchmean').to(device)

    bituning.train()

    end = time.time()
    for step in range(args.iters_per_epoch):
        views, labels = next(train_iter)
        query, key = views[0], views[1]

        query = query.to(device)
        key = key.to(device)
        labels = labels.to(device)

        data_time.update(time.time() - end)

        y, logits_z, logits_y, bituning_labels = bituning(query, key, labels)
        cls_loss = ce_loss_fn(y, labels)
        contrastive_loss = (kl_loss_fn(logits_z, bituning_labels)
                            + kl_loss_fn(logits_y, bituning_labels))
        loss = cls_loss + contrastive_loss * args.trade_off

        n = views[0].size(0)
        losses.update(loss.item(), n)
        cls_losses.update(cls_loss.item(), n)
        contrastive_losses.update(contrastive_loss.item(), n)
        cls_accs.update(accuracy(y, labels)[0].item(), n)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_time.update(time.time() - end)
        end = time.time()

        if step % args.print_freq == 0:
            progress.display(step)
def train(train_iter: ForeverDataIterator, model: Classifier, optimizer, lr_scheduler: CosineAnnealingLR,
          domain_weight_module: AutomaticUpdateDomainWeightModule, n_domains_per_batch: int, epoch: int,
          args: argparse.Namespace):
    """Run one epoch of training with a per-domain weighted classification loss.

    Each batch ``(x, labels, domain_labels)`` is split along dim 0 into
    ``n_domains_per_batch`` chunks. A cross-entropy loss is computed per chunk;
    ``domain_weight_module`` is updated with those losses and then supplies the
    weights used to combine them into the training loss.

    Args:
        train_iter: infinite iterator over training batches.
        model: classifier returning ``(logits, features)``.
        optimizer: optimizer stepped once per iteration.
        lr_scheduler: scheduler stepped once per iteration.
        domain_weight_module: provides ``update(losses, idxes)`` and
            ``get_domain_weight(idxes)``.
        n_domains_per_batch: number of chunks the batch is split into.
        epoch: epoch index, only used in the progress prefix.
        args: namespace providing ``iters_per_epoch`` and ``print_freq``.

    Note: relies on a module-level ``device``.
    """
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    losses = AverageMeter('Loss', ':3.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    progress = ProgressMeter(args.iters_per_epoch,
                             [batch_time, data_time, losses, cls_accs],
                             prefix="Epoch: [{}]".format(epoch))

    model.train()

    end = time.time()
    for step in range(args.iters_per_epoch):
        x_all, labels_all, domain_labels = next(train_iter)
        x_all = x_all.to(device)
        labels_all = labels_all.to(device)
        domain_chunks = domain_labels.to(device).chunk(n_domains_per_batch, dim=0)

        # Presumably each chunk holds samples of a single domain, so its first
        # label identifies that domain -- TODO confirm against the sampler.
        sampled_domain_idxes = [domain_chunks[d][0].item()
                                for d in range(n_domains_per_batch)]

        data_time.update(time.time() - end)

        x_chunks = x_all.chunk(n_domains_per_batch, dim=0)
        label_chunks = labels_all.chunk(n_domains_per_batch, dim=0)
        loss_per_domain = torch.zeros(n_domains_per_batch).to(device)
        cls_acc = 0
        for d, (x_d, labels_d) in enumerate(zip(x_chunks, label_chunks)):
            y_d, _ = model(x_d)
            loss_per_domain[d] = F.cross_entropy(y_d, labels_d)
            # equal-weight average of the per-chunk accuracies
            cls_acc += accuracy(y_d, labels_d)[0] / n_domains_per_batch

        domain_weight_module.update(loss_per_domain, sampled_domain_idxes)
        domain_weight = domain_weight_module.get_domain_weight(sampled_domain_idxes)
        loss = (loss_per_domain * domain_weight).sum()

        losses.update(loss.item(), x_all.size(0))
        cls_accs.update(cls_acc.item(), x_all.size(0))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        batch_time.update(time.time() - end)
        end = time.time()

        if step % args.print_freq == 0:
            progress.display(step)
Beispiel #4
0
def train(train_source_iter: ForeverDataIterator,
          train_target_iter: ForeverDataIterator, model: Classifier,
          unknown_bce: UnknownClassBinaryCrossEntropy, optimizer: SGD,
          lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace):
    """Run one epoch of open-set adaptation training.

    The source batch is classified with cross-entropy. The target batch goes
    through ``model`` with ``grad_reverse=True`` and is scored by
    ``unknown_bce``. The loss is the unweighted sum of the two terms. Target
    labels are only used to report target accuracy.

    Args:
        train_source_iter / train_target_iter: infinite iterators yielding
            ``(x, labels)`` batches.
        model: classifier returning ``(logits, features)``.
        unknown_bce: loss applied to target logits.
        optimizer: optimizer stepped once per iteration.
        lr_scheduler: scheduler stepped once per iteration.
        epoch: epoch index, only used in the progress prefix.
        args: namespace providing ``iters_per_epoch`` and ``print_freq``.

    Note: relies on a module-level ``device``.
    """
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    losses = AverageMeter('Loss', ':3.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    tgt_accs = AverageMeter('Tgt Acc', ':3.1f')
    trans_losses = AverageMeter('Trans Loss', ':3.2f')
    progress = ProgressMeter(args.iters_per_epoch,
                             [batch_time, data_time, losses, trans_losses, cls_accs, tgt_accs],
                             prefix="Epoch: [{}]".format(epoch))

    model.train()

    end = time.time()
    for step in range(args.iters_per_epoch):
        x_s, labels_s = next(train_source_iter)
        x_t, labels_t = next(train_target_iter)

        x_s, x_t = x_s.to(device), x_t.to(device)
        labels_s, labels_t = labels_s.to(device), labels_t.to(device)

        data_time.update(time.time() - end)

        y_s, _ = model(x_s, grad_reverse=False)
        y_t, _ = model(x_t, grad_reverse=True)

        cls_loss = F.cross_entropy(y_s, labels_s)
        trans_loss = unknown_bce(y_t)
        loss = cls_loss + trans_loss

        src_n, tgt_n = x_s.size(0), x_t.size(0)
        losses.update(loss.item(), src_n)
        trans_losses.update(trans_loss.item(), src_n)
        cls_accs.update(accuracy(y_s, labels_s)[0].item(), src_n)
        tgt_accs.update(accuracy(y_t, labels_t)[0].item(), tgt_n)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        batch_time.update(time.time() - end)
        end = time.time()

        if step % args.print_freq == 0:
            progress.display(step)
Beispiel #5
0
def validate(val_loader: DataLoader, model: Regressor,
             args: argparse.Namespace, factors) -> float:
    """Evaluate a multi-output regressor and report per-factor MAE.

    Output column ``j`` and ``target[:, j]`` are compared with L1 loss for
    ``factors[j]``. Each per-factor meter is weighted by batch size.

    Args:
        val_loader: yields ``(images, target)`` batches.
        model: regressor returning ``(output, features)``.
        args: namespace providing ``print_freq``.
        factors: sized sequence of factor names, one per output column.

    Returns:
        The mean over factors of the per-factor average MAE. This is a single
        float; the previous annotation ``Tuple[float, float]`` was wrong.

    Note: relies on a module-level ``device``.
    """
    batch_time = AverageMeter('Time', ':6.3f')
    mae_losses = [
        AverageMeter('mae {}'.format(factor), ':6.3f') for factor in factors
    ]
    progress = ProgressMeter(len(val_loader), [batch_time] + mae_losses,
                             prefix='Test: ')

    # switch to evaluate mode
    model.eval()

    with torch.no_grad():
        end = time.time()
        for i, (images, target) in enumerate(val_loader):
            images = images.to(device)
            target = target.to(device)

            # compute output
            output, _ = model(images)
            for j, meter in enumerate(mae_losses):
                mae_loss = F.l1_loss(output[:, j], target[:, j])
                meter.update(mae_loss.item(), images.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                progress.display(i)

        for factor, meter in zip(factors, mae_losses):
            print("{} MAE {mae.avg:6.3f}".format(factor, mae=meter))
        mean_mae = sum(meter.avg for meter in mae_losses) / len(factors)
    return mean_mae
Beispiel #6
0
def train(train_source_iter: ForeverDataIterator,
          train_target_iter: ForeverDataIterator, model: Regressor, rsd,
          optimizer: SGD, lr_scheduler: LambdaLR, epoch: int,
          args: argparse.Namespace):
    """Run one epoch of regression domain adaptation with an RSD penalty.

    The training loss is the source MSE plus ``args.trade_off`` times
    ``rsd(f_s, f_t)`` on source/target features. Source and target MAE are
    tracked for monitoring only; target labels never enter the loss.

    Args:
        train_source_iter / train_target_iter: infinite iterators yielding
            ``(x, labels)`` batches; labels are cast to float.
        model: regressor returning ``(predictions, features)``.
        rsd: callable computing the feature-alignment penalty.
        optimizer: optimizer stepped once per iteration.
        lr_scheduler: scheduler stepped once per iteration.
        epoch: epoch index, only used in the progress prefix.
        args: namespace providing ``iters_per_epoch``, ``trade_off`` and
            ``print_freq``.

    Note: relies on a module-level ``device``.
    """
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    mse_losses = AverageMeter('MSE Loss', ':6.3f')
    rsd_losses = AverageMeter('RSD Loss', ':6.3f')
    mae_losses_s = AverageMeter('MAE Loss (s)', ':6.3f')
    mae_losses_t = AverageMeter('MAE Loss (t)', ':6.3f')

    progress = ProgressMeter(args.iters_per_epoch, [
        batch_time, data_time, mse_losses, rsd_losses, mae_losses_s,
        mae_losses_t
    ],
                             prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        optimizer.zero_grad()

        x_s, labels_s = next(train_source_iter)
        x_s = x_s.to(device)
        labels_s = labels_s.to(device).float()
        x_t, labels_t = next(train_target_iter)
        x_t = x_t.to(device)
        labels_t = labels_t.to(device).float()

        # measure data loading time
        data_time.update(time.time() - end)

        # compute output
        y_s, f_s = model(x_s)
        y_t, f_t = model(x_t)

        mse_loss = F.mse_loss(y_s, labels_s)
        mae_loss_s = F.l1_loss(y_s, labels_s)
        mae_loss_t = F.l1_loss(y_t, labels_t)
        rsd_loss = rsd(f_s, f_t)
        loss = mse_loss + rsd_loss * args.trade_off

        mse_losses.update(mse_loss.item(), x_s.size(0))
        rsd_losses.update(rsd_loss.item(), x_s.size(0))
        mae_losses_s.update(mae_loss_s.item(), x_s.size(0))
        # The target metric is averaged over the target batch, which may
        # differ in size from the source batch.
        mae_losses_t.update(mae_loss_t.item(), x_t.size(0))

        # compute gradient and do SGD step
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
def train(train_source_iter: ForeverDataIterator,
          train_target_iter: ForeverDataIterator, model,
          criterion_ce: CrossEntropyLossWithLabelSmooth,
          criterion_triplet: SoftTripletLoss, optimizer: Adam, epoch: int,
          args: argparse.Namespace):
    """Run one epoch of supervised ReID training on the source domain.

    The loss is label-smoothed cross-entropy plus ``args.trade_off`` times a
    soft triplet loss computed on source features (``f_s`` passed as both
    embedding arguments). Target images are forwarded through ``model`` but
    their outputs are unused here.

    Args:
        train_source_iter / train_target_iter: infinite iterators yielding
            4-tuples; only the image (and, for source, the label) is used.
        model: network returning ``(logits, features)``.
        criterion_ce: classification loss.
        criterion_triplet: metric-learning loss.
        optimizer: optimizer stepped once per iteration.
        epoch: epoch index, only used in the progress prefix.
        args: namespace providing ``iters_per_epoch``, ``trade_off`` and
            ``print_freq``.

    Note: relies on a module-level ``device``.
    """
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    losses_ce = AverageMeter('CeLoss', ':3.2f')
    losses_triplet = AverageMeter('TripletLoss', ':3.2f')
    losses = AverageMeter('Loss', ':3.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    progress = ProgressMeter(args.iters_per_epoch,
                             [batch_time, data_time, losses_ce, losses_triplet, losses, cls_accs],
                             prefix="Epoch: [{}]".format(epoch))

    model.train()

    end = time.time()
    for step in range(args.iters_per_epoch):
        x_s, _, labels_s, _ = next(train_source_iter)
        x_t, _, _, _ = next(train_target_iter)

        x_s = x_s.to(device)
        x_t = x_t.to(device)
        labels_s = labels_s.to(device)

        data_time.update(time.time() - end)

        y_s, f_s = model(x_s)
        # Target outputs are discarded; presumably this forward pass exists to
        # expose the model (e.g. its BN statistics) to target data -- TODO confirm.
        _ = model(x_t)

        loss_ce = criterion_ce(y_s, labels_s)
        loss_triplet = criterion_triplet(f_s, f_s, labels_s)
        loss = loss_ce + loss_triplet * args.trade_off

        batch_n = x_s.size(0)
        cls_acc = accuracy(y_s, labels_s)[0]
        losses_ce.update(loss_ce.item(), batch_n)
        losses_triplet.update(loss_triplet.item(), batch_n)
        losses.update(loss.item(), batch_n)
        cls_accs.update(cls_acc.item(), batch_n)

        # backward pass and optimizer (Adam) step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_time.update(time.time() - end)
        end = time.time()

        if step % args.print_freq == 0:
            progress.display(step)
def train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator, model: ImageClassifier,
          domain_adv: ConditionalDomainAdversarialLoss, optimizer: SGD,
          lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace):
    """Run one epoch of conditional domain-adversarial training.

    Source and target batches are concatenated, forwarded once, then split back
    into halves. The loss is source cross-entropy plus ``args.trade_off`` times
    ``domain_adv(y_s, f_s, y_t, f_t)``.

    Args:
        train_source_iter / train_target_iter: infinite iterators yielding
            ``(x, labels)`` batches; target labels are ignored.
        model: classifier returning ``(logits, features)``.
        domain_adv: adversarial loss exposing ``domain_discriminator_accuracy``.
        optimizer: optimizer stepped once per iteration.
        lr_scheduler: scheduler stepped once per iteration.
        epoch: epoch index, only used in the progress prefix.
        args: namespace providing ``iters_per_epoch``, ``trade_off`` and
            ``print_freq``.

    Note: relies on a module-level ``device``. ``y.chunk(2)`` assumes the
    source and target batches have equal size.
    """
    batch_time = AverageMeter('Time', ':3.1f')
    data_time = AverageMeter('Data', ':3.1f')
    losses = AverageMeter('Loss', ':3.2f')
    trans_losses = AverageMeter('Trans Loss', ':3.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    domain_accs = AverageMeter('Domain Acc', ':3.1f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses, trans_losses, cls_accs, domain_accs],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()
    domain_adv.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        x_s, labels_s = next(train_source_iter)
        x_t, _ = next(train_target_iter)

        x_s = x_s.to(device)
        x_t = x_t.to(device)
        labels_s = labels_s.to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # compute output
        x = torch.cat((x_s, x_t), dim=0)
        y, f = model(x)
        y_s, y_t = y.chunk(2, dim=0)
        f_s, f_t = f.chunk(2, dim=0)

        cls_loss = F.cross_entropy(y_s, labels_s)
        transfer_loss = domain_adv(y_s, f_s, y_t, f_t)
        domain_acc = domain_adv.domain_discriminator_accuracy
        loss = cls_loss + transfer_loss * args.trade_off

        cls_acc = accuracy(y_s, labels_s)[0]

        # Record plain Python floats, as the other meters do, so no device
        # tensor is kept inside the meters. float() accepts both a
        # single-element tensor and a Python number for the discriminator
        # accuracy.
        losses.update(loss.item(), x_s.size(0))
        cls_accs.update(cls_acc.item(), x_s.size(0))
        domain_accs.update(float(domain_acc), x_s.size(0))
        trans_losses.update(transfer_loss.item(), x_s.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
Beispiel #9
0
def train(train_source_iter: ForeverDataIterator, model, interp, criterion, optimizer: SGD,
          lr_scheduler: LambdaLR, epoch: int, visualize, args: argparse.Namespace):
    """Run one epoch of source-only semantic-segmentation training.

    Each step forwards ``x_s`` through ``model``, resizes the output with
    ``interp``, computes ``criterion`` against the ``long`` label map, and
    backpropagates. The optimizer and LR scheduler are each stepped once per
    iteration. A ConfusionMatrix over ``model.num_classes`` is updated with
    every batch, and its ``compute()`` result feeds the accuracy/IoU meters.

    Args:
        train_source_iter: infinite iterator yielding ``(x, label)`` batches.
        model: segmentation network exposing ``num_classes``.
        interp: callable applied to the raw model output.
        criterion: loss on ``(prediction, label)``.
        optimizer: optimizer stepped once per iteration.
        lr_scheduler: scheduler stepped once per iteration.
        epoch: epoch index, only used in the progress prefix.
        visualize: optional callback ``(image, pred, label, name)`` or ``None``.
        args: namespace providing ``iters_per_epoch`` and ``print_freq``.

    Note: relies on a module-level ``device``.
    """
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    losses_s = AverageMeter('Loss (s)', ':3.2f')
    accuracies_s = Meter('Acc (s)', ':3.2f')
    iou_s = Meter('IoU (s)', ':3.2f')

    confmat_s = ConfusionMatrix(model.num_classes)
    progress = ProgressMeter(args.iters_per_epoch,
                             [batch_time, data_time, losses_s, accuracies_s, iou_s],
                             prefix="Epoch: [{}]".format(epoch))

    model.train()

    end = time.time()
    for step in range(args.iters_per_epoch):
        optimizer.zero_grad()

        x_s, label_s = next(train_source_iter)
        x_s = x_s.to(device)
        label_s = label_s.long().to(device)
        data_time.update(time.time() - end)

        pred_s = interp(model(x_s))
        loss_cls_s = criterion(pred_s, label_s)
        loss_cls_s.backward()

        optimizer.step()
        lr_scheduler.step()

        losses_s.update(loss_cls_s.item(), x_s.size(0))
        confmat_s.update(label_s.flatten(), pred_s.argmax(1).flatten())
        _, class_acc, class_iou = confmat_s.compute()
        accuracies_s.update(class_acc.mean().item())
        iou_s.update(class_iou.mean().item())

        batch_time.update(time.time() - end)
        end = time.time()

        if step % args.print_freq == 0:
            progress.display(step)
            if visualize is not None:
                visualize(x_s[0], pred_s[0], label_s[0], "source_{}".format(step))
def train(train_source_iter, train_target_iter, model, criterion, optimizer,
          epoch: int, visualize, args: argparse.Namespace):
    """Run one epoch of source-only keypoint training.

    ``train_target_iter`` is accepted for interface compatibility but is not
    read. Each step computes ``criterion(y_s, label_s, weight_s)`` on the source
    batch and takes an optimizer step. Accuracy is then computed on CPU numpy
    copies of the output and label.

    Args:
        train_source_iter: infinite iterator yielding
            ``(x, label, weight, meta)``; ``meta['keypoint2d']`` is used for
            visualisation.
        train_target_iter: unused.
        model: network mapping images to predictions comparable with ``label``.
        criterion: loss on ``(prediction, label, weight)``.
        optimizer: optimizer stepped once per iteration.
        epoch: epoch index, only used in the progress prefix.
        visualize: optional callback ``(image, keypoints, filename)`` or ``None``.
        args: namespace providing ``iters_per_epoch``, ``print_freq``,
            ``image_size`` and ``heatmap_size``.

    Note: relies on a module-level ``device``.
    """
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    losses_s = AverageMeter('Loss (s)', ":.2e")
    acc_s = AverageMeter("Acc (s)", ":3.2f")

    progress = ProgressMeter(args.iters_per_epoch,
                             [batch_time, data_time, losses_s, acc_s],
                             prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        optimizer.zero_grad()

        x_s, label_s, weight_s, meta_s = next(train_source_iter)

        x_s = x_s.to(device)
        label_s = label_s.to(device)
        weight_s = weight_s.to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # compute output
        y_s = model(x_s)
        loss_s = criterion(y_s, label_s, weight_s)

        # compute gradient and do SGD step
        loss_s.backward()
        optimizer.step()

        # measure accuracy and record loss
        _, avg_acc_s, cnt_s, pred_s = accuracy(y_s.detach().cpu().numpy(),
                                               label_s.detach().cpu().numpy())
        acc_s.update(avg_acc_s, cnt_s)
        # Record a Python float, as every other meter does. Passing loss_s
        # itself would store a device tensor that still requires grad.
        losses_s.update(loss_s.item(), cnt_s)

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
            if visualize is not None:
                visualize(x_s[0],
                          pred_s[0] * args.image_size / args.heatmap_size,
                          "source_{}_pred.jpg".format(i))
                visualize(x_s[0], meta_s['keypoint2d'][0],
                          "source_{}_label.jpg".format(i))
Beispiel #11
0
def validate(val_loader: DataLoader, G: nn.Module, F1: ImageClassifierHead,
             F2: ImageClassifierHead, args: argparse.Namespace) -> Tuple[float, float]:
    """Evaluate a shared feature extractor with two classifier heads.

    Args:
        val_loader: yields ``(images, target)`` batches.
        G: feature extractor.
        F1, F2: classifier heads applied to ``G(images)``.
        args: namespace providing ``print_freq`` and ``per_class_eval``.

    Returns:
        ``(acc_F1, acc_F2)``: average top-1 accuracy of each head. When
        ``args.per_class_eval`` is set, a per-class confusion matrix of
        ``F1``'s predictions is also printed.

    Note: relies on a module-level ``device``.
    """
    batch_time = AverageMeter('Time', ':6.3f')
    top1_1 = AverageMeter('Acc_1', ':6.2f')
    top1_2 = AverageMeter('Acc_2', ':6.2f')
    progress = ProgressMeter(
        len(val_loader),
        [batch_time, top1_1, top1_2],
        prefix='Test: ')

    # switch to evaluate mode
    G.eval()
    F1.eval()
    F2.eval()

    if args.per_class_eval:
        classes = val_loader.dataset.classes
        confmat = ConfusionMatrix(len(classes))
    else:
        confmat = None

    with torch.no_grad():
        end = time.time()
        for i, (images, target) in enumerate(val_loader):
            images = images.to(device)
            target = target.to(device)

            # compute output
            g = G(images)
            y1, y2 = F1(g), F2(g)

            # measure accuracy and record loss
            acc1, = accuracy(y1, target)
            acc2, = accuracy(y2, target)
            # Compare against None explicitly rather than relying on the
            # truthiness of a ConfusionMatrix object.
            if confmat is not None:
                confmat.update(target, y1.argmax(1))
            top1_1.update(acc1.item(), images.size(0))
            top1_2.update(acc2.item(), images.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                progress.display(i)

        print(' * Acc1 {top1_1.avg:.3f} Acc2 {top1_2.avg:.3f}'
              .format(top1_1=top1_1, top1_2=top1_2))
        if confmat is not None:
            print(confmat.format(classes))

    return top1_1.avg, top1_2.avg
Beispiel #12
0
def train(train_iter: ForeverDataIterator, model: Classifier, kd,
          optimizer: SGD, epoch: int, args: argparse.Namespace):
    """Run one epoch of target fine-tuning with a source-label distillation term.

    ``model(x)`` returns ``(y_s, y_t)``. The objective is cross-entropy of
    ``y_t`` against ``label_t`` plus ``args.trade_off * kd(y_s, label_s)``.
    Note that the 'Loss' meter records only the target cross-entropy term, not
    the combined objective that is back-propagated.

    Args:
        train_iter: infinite iterator yielding ``(x, label_t, label_s)``.
        model: network with source and target heads.
        kd: loss on ``(source_logits, label_s)``.
        optimizer: optimizer stepped once per iteration.
        epoch: epoch index, only used in the progress prefix.
        args: namespace providing ``iters_per_epoch``, ``trade_off`` and
            ``print_freq``.

    Note: relies on a module-level ``device``.
    """
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    losses = AverageMeter('Loss', ':3.2f')
    losses_kd = AverageMeter('Loss (KD)', ':5.4f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    progress = ProgressMeter(args.iters_per_epoch,
                             [batch_time, data_time, losses, losses_kd, cls_accs],
                             prefix="Epoch: [{}]".format(epoch))

    model.train()

    end = time.time()
    for step in range(args.iters_per_epoch):
        x, label_t, label_s = next(train_iter)

        x = x.to(device)
        label_s = label_s.to(device)
        label_t = label_t.to(device)

        data_time.update(time.time() - end)

        y_s, y_t = model(x)
        tgt_loss = F.cross_entropy(y_t, label_t)
        src_loss = kd(y_s, label_s)
        loss = tgt_loss + args.trade_off * src_loss

        n = x.size(0)
        losses.update(tgt_loss.item(), n)
        losses_kd.update(src_loss.item(), n)
        cls_accs.update(accuracy(y_t, label_t)[0].item(), n)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_time.update(time.time() - end)
        end = time.time()

        if step % args.print_freq == 0:
            progress.display(step)
def validate(val_loader: DataLoader, model: ImageClassifier, args: argparse.Namespace) -> float:
    """Evaluate a classifier, reporting loss plus top-1 and top-5 accuracy.

    Args:
        val_loader: yields ``(images, target)`` batches.
        model: classifier returning ``(logits, features)``; needs at least 5
            output classes because ``topk=(1, 5)`` is requested.
        args: namespace providing ``print_freq`` and ``per_class_eval``.

    Returns:
        Average top-1 accuracy. When ``args.per_class_eval`` is set, a per-class
        confusion matrix is also printed.

    Note: relies on a module-level ``device``.
    """
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    progress = ProgressMeter(
        len(val_loader),
        [batch_time, losses, top1, top5],
        prefix='Test: ')

    # switch to evaluate mode
    model.eval()

    if args.per_class_eval:
        classes = val_loader.dataset.classes
        confmat = ConfusionMatrix(len(classes))
    else:
        confmat = None

    with torch.no_grad():
        end = time.time()
        for i, (images, target) in enumerate(val_loader):
            images = images.to(device)
            target = target.to(device)

            # compute output
            output, _ = model(images)
            loss = F.cross_entropy(output, target)

            # measure accuracy and record loss
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            # Compare against None explicitly rather than relying on the
            # truthiness of a ConfusionMatrix object.
            if confmat is not None:
                confmat.update(target, output.argmax(1))
            losses.update(loss.item(), images.size(0))
            top1.update(acc1.item(), images.size(0))
            top5.update(acc5.item(), images.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                progress.display(i)

        print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
              .format(top1=top1, top5=top5))
        if confmat is not None:
            print(confmat.format(classes))

    return top1.avg
def train(train_iter: ForeverDataIterator, model: Classifier, bss_module, optimizer: SGD,
          epoch: int, args: argparse.Namespace):
    """Run one epoch of fine-tuning with a BSS penalty on the features.

    The objective is cross-entropy on the logits plus
    ``args.trade_off * bss_module(features)``.

    Args:
        train_iter: infinite iterator yielding ``(x, labels)``.
        model: classifier returning ``(logits, features)``.
        bss_module: callable producing a scalar penalty from the features.
        optimizer: optimizer stepped once per iteration.
        epoch: epoch index, only used in the progress prefix.
        args: namespace providing ``iters_per_epoch``, ``trade_off`` and
            ``print_freq``.

    Note: relies on a module-level ``device``.
    """
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    losses = AverageMeter('Loss', ':3.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    progress = ProgressMeter(args.iters_per_epoch,
                             [batch_time, data_time, losses, cls_accs],
                             prefix="Epoch: [{}]".format(epoch))

    model.train()

    end = time.time()
    for step in range(args.iters_per_epoch):
        inputs, targets = next(train_iter)
        inputs = inputs.to(device)
        targets = targets.to(device)

        data_time.update(time.time() - end)

        logits, features = model(inputs)
        loss = F.cross_entropy(logits, targets) + args.trade_off * bss_module(features)

        batch_n = inputs.size(0)
        top1 = accuracy(logits, targets)[0]
        losses.update(loss.item(), batch_n)
        cls_accs.update(top1.item(), batch_n)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_time.update(time.time() - end)
        end = time.time()

        if step % args.print_freq == 0:
            progress.display(step)
def train(train_source_iter: ForeverDataIterator, model, optimizer: Adam,
          lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace):
    """Run one epoch of plain supervised training on the source domain.

    The loss is cross-entropy on the logits; the features returned by ``model``
    are not used.

    Args:
        train_source_iter: infinite iterator yielding ``(x, labels)``.
        model: classifier returning ``(logits, features)``.
        optimizer: optimizer stepped once per iteration.
        lr_scheduler: scheduler stepped once per iteration.
        epoch: epoch index, only used in the progress prefix.
        args: namespace providing ``iters_per_epoch`` and ``print_freq``.

    Note: relies on a module-level ``device``.
    """
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    losses = AverageMeter('Loss', ':3.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    progress = ProgressMeter(args.iters_per_epoch,
                             [batch_time, data_time, losses, cls_accs],
                             prefix="Epoch: [{}]".format(epoch))

    model.train()

    end = time.time()
    for step in range(args.iters_per_epoch):
        x_s, labels_s = next(train_source_iter)
        x_s = x_s.to(device)
        labels_s = labels_s.to(device)

        data_time.update(time.time() - end)

        y_s, _ = model(x_s)
        loss = F.cross_entropy(y_s, labels_s)

        top1 = accuracy(y_s, labels_s)[0]
        losses.update(loss.item(), x_s.size(0))
        cls_accs.update(top1.item(), x_s.size(0))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        batch_time.update(time.time() - end)
        end = time.time()

        if step % args.print_freq == 0:
            progress.display(step)
Beispiel #16
0
def train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator,
          source_model: ImageClassifier, target_model: ImageClassifier, domain_adv: DomainAdversarialLoss,
          optimizer: SGD, lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace):
    """Run one epoch of adversarial feature alignment between two encoders.

    Source features come from ``source_model`` and target features from
    ``target_model``. ``domain_adv(f_s, f_t)`` is the only loss. Only
    ``target_model`` and ``domain_adv`` are put into train mode here; the mode
    of ``source_model`` is whatever the caller left it in. Labels are discarded.

    Args:
        train_source_iter / train_target_iter: infinite iterators yielding
            ``(x, labels)`` pairs.
        source_model / target_model: networks returning ``(logits, features)``.
        domain_adv: adversarial loss exposing ``domain_discriminator_accuracy``.
        optimizer: optimizer stepped once per iteration.
        lr_scheduler: scheduler stepped once per iteration.
        epoch: epoch index, only used in the progress prefix.
        args: namespace providing ``iters_per_epoch`` and ``print_freq``.

    Note: relies on a module-level ``device``.
    """
    batch_time = AverageMeter('Time', ':5.2f')
    data_time = AverageMeter('Data', ':5.2f')
    losses_transfer = AverageMeter('Transfer Loss', ':6.2f')
    domain_accs = AverageMeter('Domain Acc', ':3.1f')
    progress = ProgressMeter(args.iters_per_epoch,
                             [batch_time, data_time, losses_transfer, domain_accs],
                             prefix="Epoch: [{}]".format(epoch))

    target_model.train()
    domain_adv.train()

    end = time.time()
    for step in range(args.iters_per_epoch):
        x_s, _ = next(train_source_iter)
        x_t, _ = next(train_target_iter)
        x_s = x_s.to(device)
        x_t = x_t.to(device)

        data_time.update(time.time() - end)

        f_s = source_model(x_s)[1]
        f_t = target_model(x_t)[1]
        loss_transfer = domain_adv(f_s, f_t)

        optimizer.zero_grad()
        loss_transfer.backward()
        optimizer.step()
        lr_scheduler.step()

        # metrics are read after the update step
        batch_n = x_s.size(0)
        losses_transfer.update(loss_transfer.item(), batch_n)
        domain_accs.update(domain_adv.domain_discriminator_accuracy.item(), batch_n)

        batch_time.update(time.time() - end)
        end = time.time()

        if step % args.print_freq == 0:
            progress.display(step)
Beispiel #17
0
def validate(val_loader, model, criterion, visualize,
             args: argparse.Namespace):
    """Evaluate a keypoint model and return grouped accuracies.

    Per-keypoint accuracies from ``accuracy`` are grouped by
    ``val_loader.dataset.group_accuracy`` and accumulated in an
    ``AverageMeterDict`` keyed by ``dataset.keypoints_group``.

    Args:
        val_loader: yields ``(x, label, weight, meta)``; ``meta['keypoint2d']``
            is used for visualisation.
        model: network producing predictions comparable with ``label``.
        criterion: loss on ``(prediction, label, weight)``.
        visualize: optional callback ``(image, keypoints, filename)`` or ``None``.
        args: namespace providing ``print_freq``, ``image_size`` and
            ``heatmap_size``.

    Returns:
        ``acc.average()`` of the grouped accuracy meter.

    Note: relies on a module-level ``device``.
    """
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.2e')
    acc = AverageMeterDict(val_loader.dataset.keypoints_group.keys(), ":3.2f")
    progress = ProgressMeter(len(val_loader), [batch_time, losses, acc['all']],
                             prefix='Test: ')

    model.eval()

    with torch.no_grad():
        end = time.time()
        for step, (images, target, target_weight, meta) in enumerate(val_loader):
            images = images.to(device)
            target = target.to(device)
            target_weight = target_weight.to(device)

            output = model(images)
            loss = criterion(output, target, target_weight)
            losses.update(loss.item(), images.size(0))

            acc_per_points, _, _, pred = accuracy(output.cpu().numpy(),
                                                  target.cpu().numpy())
            acc.update(val_loader.dataset.group_accuracy(acc_per_points), images.size(0))

            batch_time.update(time.time() - end)
            end = time.time()

            if step % args.print_freq == 0:
                progress.display(step)
                if visualize is not None:
                    # predictions are rescaled from heatmap to image resolution
                    scale = args.image_size / args.heatmap_size
                    visualize(images[0], pred[0] * scale,
                              "val_{}_pred.jpg".format(step))
                    visualize(images[0], meta['keypoint2d'][0],
                              "val_{}_label.jpg".format(step))

    return acc.average()
def validate(val_loader: DataLoader, model: Classifier,
             args: argparse.Namespace) -> float:
    """Evaluate an open-set classifier and return its H-score.

    The last output column is treated as the "unknown" class. Its softmax
    probability is overwritten with ``args.threshold``, so a sample is predicted
    as unknown unless some known-class probability is at least the threshold.
    ``argmax`` returns the first maximal index on ties, so known classes win
    ties.

    Args:
        val_loader: yields ``(images, target)``; its dataset exposes ``classes``.
        model: classifier returning ``(logits, features)``.
        args: namespace providing ``print_freq``, ``threshold`` and
            ``per_class_eval``.

    Returns:
        H-score: the harmonic mean of the mean known-class accuracy and the
        unknown-class accuracy, in percent. It is defined as 0.0 when both are
        0; previously this case raised ZeroDivisionError.

    Note: relies on a module-level ``device``.
    """
    batch_time = AverageMeter('Time', ':6.3f')
    classes = val_loader.dataset.classes
    confmat = ConfusionMatrix(len(classes))
    progress = ProgressMeter(len(val_loader), [batch_time], prefix='Test: ')

    # switch to evaluate mode
    model.eval()

    with torch.no_grad():
        end = time.time()
        for i, (images, target) in enumerate(val_loader):
            images = images.to(device)
            target = target.to(device)

            # compute output
            output, _ = model(images)
            softmax_output = F.softmax(output, dim=1)
            softmax_output[:, -1] = args.threshold

            # measure accuracy and record loss
            confmat.update(target, softmax_output.argmax(1))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                progress.display(i)

        acc_global, accs, iu = confmat.compute()
        all_acc = torch.mean(accs).item() * 100
        known = torch.mean(accs[:-1]).item() * 100
        unknown = accs[-1].item() * 100
        # Guard the harmonic mean against 0/0. NaN inputs still propagate,
        # because NaN == 0 is False.
        if known + unknown == 0:
            h_score = 0.0
        else:
            h_score = 2 * known * unknown / (known + unknown)
        if args.per_class_eval:
            print(confmat.format(classes))
        print(
            ' * All {all:.3f} Known {known:.3f} Unknown {unknown:.3f} H-score {h_score:.3f}'
            .format(all=all_acc, known=known, unknown=unknown,
                    h_score=h_score))

    return h_score
Beispiel #19
0
def validate(val_loader: DataLoader, model, interp, criterion, visualize, args: argparse.Namespace):
    """Evaluate a semantic-segmentation model over ``val_loader``.

    Every batch is passed through ``model`` and then ``interp``, scored with
    ``criterion``, and accumulated into a ``ConfusionMatrix``. After each batch
    the matrix is recomputed, so the displayed accuracy/IoU are running values
    over all batches seen so far.

    Args:
        val_loader: yields ``(images, labels)`` batches; labels are cast to long.
        model: segmentation network; must expose ``num_classes``.
        interp: applied to the raw model output before loss/argmax.
        criterion: loss between the interpolated output and the labels.
        visualize: optional callable ``(image, output, label, name)`` invoked on
            the first sample of every ``args.print_freq``-th batch; ``None``
            disables it.
        args: reads ``print_freq``.

    Returns:
        The ``ConfusionMatrix`` accumulated over the whole loader.
    """
    time_meter = AverageMeter('Time', ':6.3f')
    loss_meter = AverageMeter('Loss', ':.4e')
    acc_meter = Meter('Acc', ':3.2f')
    iou_meter = Meter('IoU', ':3.2f')
    progress = ProgressMeter(len(val_loader),
                             [time_meter, loss_meter, acc_meter, iou_meter],
                             prefix='Test: ')

    model.eval()
    confmat = ConfusionMatrix(model.num_classes)

    with torch.no_grad():
        tic = time.time()
        for step, (images, labels) in enumerate(val_loader):
            images = images.to(device)
            labels = labels.long().to(device)

            logits = interp(model(images))
            batch_loss = criterion(logits, labels)

            loss_meter.update(batch_loss.item(), images.size(0))
            confmat.update(labels.flatten(), logits.argmax(1).flatten())
            _, class_acc, class_iou = confmat.compute()
            acc_meter.update(class_acc.mean().item())
            iou_meter.update(class_iou.mean().item())

            time_meter.update(time.time() - tic)
            tic = time.time()

            if step % args.print_freq != 0:
                continue
            progress.display(step)
            if visualize is not None:
                visualize(images[0], logits[0], labels[0], "val_{}".format(step))

    return confmat
# Example #20
# 0
def extract_reid_feature(data_loader,
                         model,
                         device,
                         normalize,
                         print_freq=200):
    """Collect one feature vector per image for person re-identification.

    If ``normalize`` is true, every feature is L2-normalized with
    ``F.normalize`` so that ``cosine`` distance is the intended metric;
    otherwise raw features are kept for ``euclidean`` distance.

    Args:
        data_loader: yields 4-tuples ``(images, filenames, _, _)``; the last two
            items are ignored.
        model: maps an image batch to a feature batch.
        device: where image batches are moved before the forward pass.
        normalize: whether to L2-normalize the features.
        print_freq: progress is displayed every ``print_freq`` batches.

    Returns:
        dict mapping each filename to its feature tensor. A repeated filename
        keeps the last feature seen.
    """
    timer = AverageMeter('Time', ':6.3f')
    progress = ProgressMeter(len(data_loader), [timer],
                             prefix='Collect feature: ')

    model.eval()
    features_by_name = {}

    with torch.no_grad():
        tic = time.time()
        for batch_idx, (images, names, _, _) in enumerate(data_loader):
            feats = model(images.to(device))
            if normalize:
                feats = F.normalize(feats)
            features_by_name.update(zip(names, feats))

            timer.update(time.time() - tic)
            tic = time.time()

            if batch_idx % print_freq == 0:
                progress.display(batch_idx)

    return features_by_name
# Example #21
# 0
def train(train_source_iter: ForeverDataIterator,
          train_target_iter: ForeverDataIterator, model,
          mdd: MarginDisparityDiscrepancy, optimizer: SGD,
          lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace):
    """Run one epoch of margin-disparity-discrepancy training for regression.

    The objective per iteration is ``mse(y_s, labels_s) - args.trade_off *
    mdd(...)``. Mean absolute errors on both domains are logged for monitoring
    only; target labels never enter the loss.

    Args:
        train_source_iter: yields ``(x_s, labels_s)``; labels are cast to float.
        train_target_iter: yields ``(x_t, labels_t)``; labels are cast to float
            and used only for the logged target MAE.
        model: returns ``(outputs, outputs_adv)`` for a concatenated
            source+target batch. ``model.step()`` is called once per iteration
            (presumably advancing an internal schedule; TODO confirm).
        mdd: discrepancy between the main and adversarial heads.
        optimizer: stepped once per iteration.
        lr_scheduler: stepped once per iteration.
        epoch: used only in the progress prefix.
        args: reads ``iters_per_epoch``, ``trade_off`` and ``print_freq``.
    """
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    source_losses = AverageMeter('Source Loss', ':6.3f')
    trans_losses = AverageMeter('Trans Loss', ':6.3f')
    mae_losses_s = AverageMeter('MAE Loss (s)', ':6.3f')
    mae_losses_t = AverageMeter('MAE Loss (t)', ':6.3f')

    progress = ProgressMeter(args.iters_per_epoch, [
        batch_time, data_time, source_losses, trans_losses, mae_losses_s,
        mae_losses_t
    ],
                             prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()
    mdd.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        optimizer.zero_grad()

        x_s, labels_s = next(train_source_iter)
        x_s = x_s.to(device)
        labels_s = labels_s.to(device).float()
        x_t, labels_t = next(train_target_iter)
        x_t = x_t.to(device)
        labels_t = labels_t.to(device).float()

        # measure data loading time
        data_time.update(time.time() - end)

        # compute output; chunk(2) splits back into source/target halves, which
        # assumes equal source and target batch sizes
        x = torch.cat([x_s, x_t], dim=0)
        outputs, outputs_adv = model(x)
        y_s, y_t = outputs.chunk(2, dim=0)
        y_s_adv, y_t_adv = outputs_adv.chunk(2, dim=0)

        # compute mean square loss on source domain
        mse_loss = F.mse_loss(y_s, labels_s)

        # compute margin disparity discrepancy between domains
        transfer_loss = mdd(y_s, y_s_adv, y_t, y_t_adv)
        # for adversarial classifier, minimize negative mdd is equal to maximize mdd
        loss = mse_loss - transfer_loss * args.trade_off
        model.step()

        # MAE is reported only; no autograd graph is needed for it
        with torch.no_grad():
            mae_loss_s = F.l1_loss(y_s, labels_s)
            mae_loss_t = F.l1_loss(y_t, labels_t)

        source_losses.update(mse_loss.item(), x_s.size(0))
        trans_losses.update(transfer_loss.item(), x_s.size(0))
        mae_losses_s.update(mae_loss_s.item(), x_s.size(0))
        # target metric is weighted by the target batch size
        mae_losses_t.update(mae_loss_t.item(), x_t.size(0))

        # compute gradient and do SGD step
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
# Example #22
# 0
def train(train_source_iter: ForeverDataIterator,
          train_target_iter: ForeverDataIterator, model: ImageClassifier,
          domain_discri: DomainDiscriminator,
          domain_adv: DomainAdversarialLoss, gl, optimizer: SGD,
          lr_scheduler: LambdaLR, optimizer_d: SGD, lr_scheduler_d: LambdaLR,
          epoch: int, args: argparse.Namespace):
    """Run one epoch of adversarial domain adaptation with alternating updates.

    Each iteration performs two separate optimizer steps:

    1. Classifier step (``optimizer`` / ``lr_scheduler``): the discriminator's
       parameters are frozen with ``set_requires_grad``. The model minimizes
       source cross-entropy plus ``args.trade_off`` times the domain loss of
       source features labelled ``'target'`` and target features labelled
       ``'source'``, i.e. it tries to fool the discriminator.
    2. Discriminator step (``optimizer_d`` / ``lr_scheduler_d``): the model is
       frozen and the discriminator is trained on the detached features with
       their true domain labels.

    Args:
        train_source_iter: yields labelled ``(x_s, labels_s)`` source batches.
        train_target_iter: yields ``(x_t, _)`` target batches; labels unused.
        model: returns ``(logits, features)`` for a concatenated batch.
        domain_discri: discriminator applied to features.
        domain_adv: domain loss for discriminator outputs given a
            ``'source'`` / ``'target'`` label.
        gl: callable applied to the features before the discriminator in the
            classifier step only (not in the discriminator step). Presumably a
            gradient layer; TODO confirm against the caller.
        args: reads ``iters_per_epoch``, ``trade_off`` and ``print_freq``.
    """
    batch_time = AverageMeter('Time', ':5.2f')
    data_time = AverageMeter('Data', ':5.2f')
    losses_s = AverageMeter('Cls Loss', ':6.2f')
    losses_transfer = AverageMeter('Transfer Loss', ':6.2f')
    losses_discriminator = AverageMeter('Discriminator Loss', ':6.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    domain_accs = AverageMeter('Domain Acc', ':3.1f')
    progress = ProgressMeter(args.iters_per_epoch, [
        batch_time, data_time, losses_s, losses_transfer, losses_discriminator,
        cls_accs, domain_accs
    ],
                             prefix="Epoch: [{}]".format(epoch))

    end = time.time()
    for i in range(args.iters_per_epoch):
        x_s, labels_s = next(train_source_iter)
        x_t, _ = next(train_target_iter)

        x_s = x_s.to(device)
        x_t = x_t.to(device)
        labels_s = labels_s.to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # Step 1: Train the classifier, freeze the discriminator
        model.train()
        domain_discri.eval()
        set_requires_grad(model, True)
        set_requires_grad(domain_discri, False)
        # chunk(2) below assumes equal source and target batch sizes
        x = torch.cat((x_s, x_t), dim=0)
        y, f = model(x)
        y_s, y_t = y.chunk(2, dim=0)
        loss_s = F.cross_entropy(y_s, labels_s)

        # adversarial training to fool the discriminator
        # (domain labels are deliberately swapped: source -> 'target', target -> 'source')
        d = domain_discri(gl(f))
        d_s, d_t = d.chunk(2, dim=0)
        loss_transfer = 0.5 * (domain_adv(d_s, 'target') +
                               domain_adv(d_t, 'source'))

        optimizer.zero_grad()
        (loss_s + loss_transfer * args.trade_off).backward()
        optimizer.step()
        lr_scheduler.step()

        # Step 2: Train the discriminator
        model.eval()
        domain_discri.train()
        set_requires_grad(model, False)
        set_requires_grad(domain_discri, True)
        # reuse the Step 1 features, detached so no gradient reaches the model
        d = domain_discri(f.detach())
        d_s, d_t = d.chunk(2, dim=0)
        loss_discriminator = 0.5 * (domain_adv(d_s, 'source') +
                                    domain_adv(d_t, 'target'))

        optimizer_d.zero_grad()
        loss_discriminator.backward()
        optimizer_d.step()
        lr_scheduler_d.step()

        losses_s.update(loss_s.item(), x_s.size(0))
        losses_transfer.update(loss_transfer.item(), x_s.size(0))
        losses_discriminator.update(loss_discriminator.item(), x_s.size(0))

        cls_acc = accuracy(y_s, labels_s)[0]
        cls_accs.update(cls_acc.item(), x_s.size(0))
        # d_s / d_t here are the Step 2 outputs; source is scored against 1, target against 0
        domain_acc = 0.5 * (binary_accuracy(d_s, torch.ones_like(d_s)) +
                            binary_accuracy(d_t, torch.zeros_like(d_t)))
        domain_accs.update(domain_acc.item(), x_s.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
# Example #23
# 0
def train(train_source_iter: ForeverDataIterator,
          train_target_iter: ForeverDataIterator, model: ImageClassifier,
          adaptive_feature_norm: AdaptiveFeatureNorm, optimizer: SGD,
          epoch: int, args: argparse.Namespace):
    """Run one epoch of training with a feature-norm penalty on both domains.

    Loss = source cross-entropy
           + ``args.trade_off_norm`` * (afn(f_s) + afn(f_t))
           + ``args.trade_off_entropy`` * target prediction entropy
             (the last term only when ``args.trade_off_entropy`` is truthy).

    Target labels are used only for the logged target accuracy.

    Args:
        train_source_iter: yields ``(x_s, labels_s)`` batches.
        train_target_iter: yields ``(x_t, labels_t)`` batches.
        model: returns ``(logits, features)`` for a batch.
        adaptive_feature_norm: callable mapping a feature batch to a loss.
        optimizer: stepped once per iteration (no LR scheduler here).
        epoch: used only in the progress prefix.
        args: reads ``iters_per_epoch``, ``trade_off_norm``,
            ``trade_off_entropy`` and ``print_freq``.
    """
    batch_time = AverageMeter('Time', ':3.1f')
    data_time = AverageMeter('Data', ':3.1f')
    cls_losses = AverageMeter('Cls Loss', ':3.2f')
    norm_losses = AverageMeter('Norm Loss', ':3.2f')
    src_feature_norm = AverageMeter('Source Feature Norm', ':3.2f')
    tgt_feature_norm = AverageMeter('Target Feature Norm', ':3.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    tgt_accs = AverageMeter('Tgt Acc', ':3.1f')

    progress = ProgressMeter(args.iters_per_epoch, [
        batch_time, data_time, cls_losses, norm_losses, src_feature_norm,
        tgt_feature_norm, cls_accs, tgt_accs
    ],
                             prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        x_s, labels_s = next(train_source_iter)
        x_t, labels_t = next(train_target_iter)

        x_s = x_s.to(device)
        x_t = x_t.to(device)
        labels_s = labels_s.to(device)
        labels_t = labels_t.to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # compute output (two separate forward passes, one per domain)
        y_s, f_s = model(x_s)
        y_t, f_t = model(x_t)

        # classification loss
        cls_loss = F.cross_entropy(y_s, labels_s)
        # norm loss
        norm_loss = adaptive_feature_norm(f_s) + adaptive_feature_norm(f_t)
        loss = cls_loss + norm_loss * args.trade_off_norm

        # using entropy minimization
        if args.trade_off_entropy:
            # y_t is overwritten with probabilities here; tgt_acc below is
            # computed from whichever form y_t holds at that point
            y_t = F.softmax(y_t, dim=1)
            entropy_loss = entropy(y_t, reduction='mean')
            loss += entropy_loss * args.trade_off_entropy

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # update statistics
        cls_acc = accuracy(y_s, labels_s)[0]
        tgt_acc = accuracy(y_t, labels_t)[0]

        cls_losses.update(cls_loss.item(), x_s.size(0))
        norm_losses.update(norm_loss.item(), x_s.size(0))
        src_feature_norm.update(
            f_s.norm(p=2, dim=1).mean().item(), x_s.size(0))
        # target-domain statistics are weighted by the target batch size
        tgt_feature_norm.update(
            f_t.norm(p=2, dim=1).mean().item(), x_t.size(0))
        cls_accs.update(cls_acc.item(), x_s.size(0))
        tgt_accs.update(tgt_acc.item(), x_t.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
# Example #24
# 0
def train(train_source_iter: ForeverDataIterator,
          train_target_iter: ForeverDataIterator, model: ImageClassifier,
          domain_adv_D: DomainAdversarialLoss,
          domain_adv_D_0: DomainAdversarialLoss, importance_weight_module,
          optimizer: SGD, lr_scheduler: LambdaLR, epoch: int,
          args: argparse.Namespace):
    """Run one epoch of importance-weighted adversarial domain adaptation.

    Two domain-adversarial losses are combined:

    * ``domain_adv_D`` sees *detached* features, so its gradient does not
      reach the feature extractor. It is weighted by ``1.5 * args.trade_off``.
    * ``domain_adv_D_0`` sees the live features, with source samples weighted
      by ``importance_weight_module.get_importance_weight(f_s)``. It is
      weighted by ``args.trade_off``.

    A target entropy term weighted by ``args.gamma`` is added as well. Target
    labels are used only for logging.

    Args:
        train_source_iter: yields ``(x_s, labels_s)`` batches.
        train_target_iter: yields ``(x_t, labels_t)`` batches.
        model: returns ``(logits, features)`` for a concatenated batch.
        optimizer: stepped once per iteration.
        lr_scheduler: stepped once per iteration.
        epoch: used only in the progress prefix.
        args: reads ``iters_per_epoch``, ``trade_off``, ``gamma`` and
            ``print_freq``.
    """
    batch_time = AverageMeter('Time', ':5.2f')
    data_time = AverageMeter('Data', ':5.2f')
    losses = AverageMeter('Loss', ':6.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    tgt_accs = AverageMeter('Tgt Acc', ':3.1f')
    domain_accs_D = AverageMeter('Domain Acc for D', ':3.1f')
    domain_accs_D_0 = AverageMeter('Domain Acc for D_0', ':3.1f')
    partial_classes_weights = AverageMeter('Partial Weight', ':3.2f')
    non_partial_classes_weights = AverageMeter('Non-Partial Weight', ':3.2f')

    progress = ProgressMeter(args.iters_per_epoch, [
        batch_time, data_time, losses, cls_accs, tgt_accs, domain_accs_D,
        domain_accs_D_0, partial_classes_weights, non_partial_classes_weights
    ],
                             prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()
    domain_adv_D.train()
    domain_adv_D_0.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        x_s, labels_s = next(train_source_iter)
        x_t, labels_t = next(train_target_iter)

        x_s = x_s.to(device)
        x_t = x_t.to(device)
        labels_s = labels_s.to(device)
        labels_t = labels_t.to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # compute output; chunk(2) assumes equal source and target batch sizes
        x = torch.cat((x_s, x_t), dim=0)
        y, f = model(x)
        y_s, y_t = y.chunk(2, dim=0)
        f_s, f_t = f.chunk(2, dim=0)

        # classification loss
        cls_loss = F.cross_entropy(y_s, labels_s)

        # domain adversarial loss for D (detached: trains D only)
        adv_loss_D = domain_adv_D(f_s.detach(), f_t.detach())

        # get importance weights
        w_s = importance_weight_module.get_importance_weight(f_s)
        # domain adversarial loss for D_0
        adv_loss_D_0 = domain_adv_D_0(f_s, f_t, w_s=w_s)

        # entropy loss (y_t now holds probabilities; also used for tgt_acc below)
        y_t = F.softmax(y_t, dim=1)
        entropy_loss = entropy(y_t, reduction='mean')

        loss = cls_loss + 1.5 * args.trade_off * adv_loss_D + \
               args.trade_off * adv_loss_D_0 + args.gamma * entropy_loss

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        cls_acc = accuracy(y_s, labels_s)[0]
        tgt_acc = accuracy(y_t, labels_t)[0]

        losses.update(loss.item(), x_s.size(0))
        cls_accs.update(cls_acc.item(), x_s.size(0))
        # target accuracy is weighted by the target batch size
        tgt_accs.update(tgt_acc.item(), x_t.size(0))
        domain_accs_D.update(domain_adv_D.domain_discriminator_accuracy,
                             x_s.size(0))
        domain_accs_D_0.update(domain_adv_D_0.domain_discriminator_accuracy,
                               x_s.size(0))

        # debug: output class weight averaged on the partial classes and non-partial classes respectively
        partial_class_weight, non_partial_classes_weight = \
            importance_weight_module.get_partial_classes_weight(w_s, labels_s)
        partial_classes_weights.update(partial_class_weight.item(),
                                       x_s.size(0))
        non_partial_classes_weights.update(non_partial_classes_weight.item(),
                                           x_s.size(0))

        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
# Example #25
# 0
def train(train_source_iter, train_target_iter, model, criterion,
          regression_disparity, optimizer_f, optimizer_h, optimizer_h_adv,
          lr_scheduler_f, lr_scheduler_h, lr_scheduler_h_adv, epoch: int,
          visualize, args: argparse.Namespace):
    """Run one epoch of adversarial regression-disparity training (keypoints).

    Each iteration has three steps:

    A. All three optimizers minimize the weighted source ``criterion`` plus
       ``args.margin * args.trade_off`` times the source disparity
       (``mode='min'``).
    B. Only ``optimizer_h_adv`` steps, on the target disparity with
       ``mode='max'``.
    C. Only ``optimizer_f`` steps, on the target disparity with ``mode='min'``.

    Then ``model.step()`` and the three LR schedulers are called once.

    Args:
        train_source_iter: yields ``(x, label, weight, meta)`` batches;
            ``meta['keypoint2d']`` is used only for visualization.
        train_target_iter: same format as the source iterator.
        model: returns ``(y, y_adv)`` for a batch.
        visualize: optional callable ``(image, keypoints, name)``; ``None``
            disables it.
        args: reads ``iters_per_epoch``, ``margin``, ``trade_off``,
            ``print_freq``, ``image_size`` and ``heatmap_size``.
    """
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    losses_s = AverageMeter('Loss (s)', ":.2e")
    losses_gf = AverageMeter('Loss (t, false)', ":.2e")
    losses_gt = AverageMeter('Loss (t, truth)', ":.2e")
    acc_s = AverageMeter("Acc (s)", ":3.2f")
    acc_t = AverageMeter("Acc (t)", ":3.2f")
    acc_s_adv = AverageMeter("Acc (s, adv)", ":3.2f")
    acc_t_adv = AverageMeter("Acc (t, adv)", ":3.2f")

    progress = ProgressMeter(args.iters_per_epoch, [
        batch_time, data_time, losses_s, losses_gf, losses_gt, acc_s, acc_t,
        acc_s_adv, acc_t_adv
    ],
                             prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        x_s, label_s, weight_s, meta_s = next(train_source_iter)
        x_t, label_t, weight_t, meta_t = next(train_target_iter)

        x_s = x_s.to(device)
        label_s = label_s.to(device)
        weight_s = weight_s.to(device)

        x_t = x_t.to(device)
        label_t = label_t.to(device)
        weight_t = weight_t.to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # Step A train all networks to minimize loss on source domain
        optimizer_f.zero_grad()
        optimizer_h.zero_grad()
        optimizer_h_adv.zero_grad()

        y_s, y_s_adv = model(x_s)
        loss_s = criterion(y_s, label_s, weight_s) + \
                 args.margin * args.trade_off * regression_disparity(y_s, y_s_adv, weight_s, mode='min')
        loss_s.backward()
        optimizer_f.step()
        optimizer_h.step()
        optimizer_h_adv.step()

        # Step B train adv regressor to maximize regression disparity
        optimizer_h_adv.zero_grad()
        y_t, y_t_adv = model(x_t)
        loss_ground_false = args.trade_off * regression_disparity(
            y_t, y_t_adv, weight_t, mode='max')
        loss_ground_false.backward()
        optimizer_h_adv.step()

        # Step C train feature extractor to minimize regression disparity
        optimizer_f.zero_grad()
        y_t, y_t_adv = model(x_t)
        loss_ground_truth = args.trade_off * regression_disparity(
            y_t, y_t_adv, weight_t, mode='min')
        loss_ground_truth.backward()
        optimizer_f.step()

        # do update step
        model.step()
        lr_scheduler_f.step()
        lr_scheduler_h.step()
        lr_scheduler_h_adv.step()

        # measure accuracy and record loss
        _, avg_acc_s, cnt_s, pred_s = accuracy(y_s.detach().cpu().numpy(),
                                               label_s.detach().cpu().numpy())
        acc_s.update(avg_acc_s, cnt_s)
        _, avg_acc_t, cnt_t, pred_t = accuracy(y_t.detach().cpu().numpy(),
                                               label_t.detach().cpu().numpy())
        acc_t.update(avg_acc_t, cnt_t)
        _, avg_acc_s_adv, cnt_s_adv, pred_s_adv = accuracy(
            y_s_adv.detach().cpu().numpy(),
            label_s.detach().cpu().numpy())
        acc_s_adv.update(avg_acc_s_adv, cnt_s_adv)
        _, avg_acc_t_adv, cnt_t_adv, pred_t_adv = accuracy(
            y_t_adv.detach().cpu().numpy(),
            label_t.detach().cpu().numpy())
        acc_t_adv.update(avg_acc_t_adv, cnt_t_adv)
        # Record plain floats: passing the loss tensors would keep them (and
        # their autograd history) referenced by the meters across iterations.
        losses_s.update(loss_s.item(), cnt_s)
        losses_gf.update(loss_ground_false.item(), cnt_s)
        losses_gt.update(loss_ground_truth.item(), cnt_s)

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
            if visualize is not None:
                visualize(x_s[0],
                          pred_s[0] * args.image_size / args.heatmap_size,
                          "source_{}_pred".format(i))
                visualize(x_s[0], meta_s['keypoint2d'][0],
                          "source_{}_label".format(i))
                visualize(x_t[0],
                          pred_t[0] * args.image_size / args.heatmap_size,
                          "target_{}_pred".format(i))
                visualize(x_t[0], meta_t['keypoint2d'][0],
                          "target_{}_label".format(i))
                visualize(x_s[0],
                          pred_s_adv[0] * args.image_size / args.heatmap_size,
                          "source_adv_{}_pred".format(i))
                visualize(x_t[0],
                          pred_t_adv[0] * args.image_size / args.heatmap_size,
                          "target_adv_{}_pred".format(i))
# Example #26
# 0
def train(train_source_iter: ForeverDataIterator,
          train_target_iter: ForeverDataIterator, model,
          mdd: MarginDisparityDiscrepancy, optimizer: Adam,
          lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace):
    """Run one epoch of margin-disparity-discrepancy training for classification.

    ``transfer_loss`` is the *negated* ``mdd`` value, so minimizing
    ``cls_loss + args.trade_off * transfer_loss`` drives the discrepancy up,
    which the inline comment describes as the adversarial classifier's goal.
    How the feature extractor sees the opposite sign is not visible here
    (presumably a gradient-reversal layer inside ``model``; TODO confirm).

    Args:
        train_source_iter: yields ``(x_s, labels_s)`` batches.
        train_target_iter: yields ``(x_t, labels_t)`` batches; target labels are
            used only for logged accuracies.
        model: returns ``(outputs, outputs_adv)`` for a concatenated batch.
            ``model.step()`` is called once per iteration, before backward.
        mdd: discrepancy between the main and adversarial heads.
        optimizer: stepped once per iteration.
        lr_scheduler: currently NOT stepped (the call below is commented out).
        epoch: used only in the progress prefix.
        args: reads ``iters_per_epoch``, ``trade_off`` and ``print_freq``.
    """
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    losses = AverageMeter('Loss', ':3.2f')
    trans_losses = AverageMeter('Trans Loss', ':3.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    tgt_accs = AverageMeter('Tgt Acc', ':3.1f')
    cls_adv_accs = AverageMeter('Cls Adv Acc', ':3.1f')
    tgt_adv_accs = AverageMeter('Tgt Adv Acc', ':3.1f')

    progress = ProgressMeter(args.iters_per_epoch, [
        batch_time, data_time, losses, trans_losses, cls_accs, tgt_accs,
        cls_adv_accs, tgt_adv_accs
    ],
                             prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()
    mdd.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        x_s, labels_s = next(train_source_iter)
        x_t, labels_t = next(train_target_iter)

        x_s = x_s.to(device)
        x_t = x_t.to(device)
        labels_s = labels_s.to(device)
        labels_t = labels_t.to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # compute output; chunk(2) assumes equal source and target batch sizes
        x = torch.cat((x_s, x_t), dim=0)
        outputs, outputs_adv = model(x)
        y_s, y_t = outputs.chunk(2, dim=0)
        y_s_adv, y_t_adv = outputs_adv.chunk(2, dim=0)

        # compute cross entropy loss on source domain
        cls_loss = F.cross_entropy(y_s, labels_s)
        # compute margin disparity discrepancy between domains
        # for adversarial classifier, minimize negative mdd is equal to maximize mdd
        transfer_loss = -mdd(y_s, y_s_adv, y_t, y_t_adv)
        loss = cls_loss + transfer_loss * args.trade_off
        model.step()

        cls_acc = accuracy(y_s, labels_s)[0]
        tgt_acc = accuracy(y_t, labels_t)[0]
        cls_adv_acc = accuracy(y_s_adv, labels_s)[0]
        tgt_adv_acc = accuracy(y_t_adv, labels_t)[0]

        losses.update(loss.item(), x_s.size(0))
        cls_accs.update(cls_acc.item(), x_s.size(0))
        tgt_accs.update(tgt_acc.item(), x_t.size(0))
        cls_adv_accs.update(cls_adv_acc.item(), x_s.size(0))
        tgt_adv_accs.update(tgt_adv_acc.item(), x_t.size(0))
        trans_losses.update(transfer_loss.item(), x_s.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # NOTE(review): the scheduler step is disabled, so ``lr_scheduler`` is
        # unused and the learning rate stays constant. Confirm this is intended.
        # lr_scheduler.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
# Example #27
# 0
def train(train_iter: ForeverDataIterator, model: Classifier,
          backbone_regularization: nn.Module, head_regularization: nn.Module,
          target_getter: IntermediateLayerGetter,
          source_getter: IntermediateLayerGetter, optimizer: SGD, epoch: int,
          args: argparse.Namespace):
    """Run one epoch of regularized fine-tuning.

    Loss = cross-entropy
           + ``args.trade_off_backbone`` * backbone regularizer
           + ``args.trade_off_head`` * head regularizer.

    When ``args.regularization_type`` is ``'feature_map'`` or
    ``'attention_feature_map'``, the backbone regularizer compares the
    intermediate outputs of ``source_getter`` and ``target_getter``. For any
    other type it is called with no arguments. The head regularizer is always
    called with no arguments. The logged regularization losses are already
    multiplied by their trade-off factors.

    Args:
        train_iter: yields ``(x, labels)`` batches.
        model: put in train mode; the forward pass goes through the getters.
        target_getter: returns ``(intermediate_outputs, (logits, features))``.
        source_getter: returns ``(intermediate_outputs, _)``.
        optimizer: stepped once per iteration.
        epoch: used only in the progress prefix.
        args: reads ``iters_per_epoch``, ``regularization_type``,
            ``trade_off_backbone``, ``trade_off_head`` and ``print_freq``.
    """
    meter_time = AverageMeter('Time', ':4.2f')
    meter_data = AverageMeter('Data', ':3.1f')
    meter_loss = AverageMeter('Loss', ':3.2f')
    meter_reg_head = AverageMeter('Loss (reg, head)', ':3.2f')
    meter_reg_backbone = AverageMeter('Loss (reg, backbone)', ':3.2f')
    meter_acc = AverageMeter('Cls Acc', ':3.1f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [meter_time, meter_data, meter_loss, meter_reg_head,
         meter_reg_backbone, meter_acc],
        prefix="Epoch: [{}]".format(epoch))

    # regularization types whose backbone term needs the intermediate outputs
    feature_based = ('feature_map', 'attention_feature_map')

    model.train()

    tic = time.time()
    for step in range(args.iters_per_epoch):
        images, targets = next(train_iter)
        images = images.to(device)
        targets = targets.to(device)
        meter_data.update(time.time() - tic)

        feats_s, _ = source_getter(images)
        feats_t, (logits, _) = target_getter(images)

        acc = accuracy(logits, targets)[0]
        ce = F.cross_entropy(logits, targets)
        if args.regularization_type in feature_based:
            reg_backbone = backbone_regularization(feats_s, feats_t)
        else:
            reg_backbone = backbone_regularization()
        reg_head = head_regularization()
        total = (ce + args.trade_off_backbone * reg_backbone
                 + args.trade_off_head * reg_head)

        batch = images.size(0)
        meter_reg_backbone.update(
            reg_backbone.item() * args.trade_off_backbone, batch)
        meter_reg_head.update(reg_head.item() * args.trade_off_head, batch)
        meter_loss.update(total.item(), batch)
        meter_acc.update(acc.item(), batch)

        optimizer.zero_grad()
        total.backward()
        optimizer.step()

        meter_time.update(time.time() - tic)
        tic = time.time()

        if step % args.print_freq == 0:
            progress.display(step)
# Example #28
# 0
def calculate_channel_attention(dataset, return_layers, args):
    """Estimate a per-channel attention distribution for the given layers.

    Phase 1 trains a fresh ``Classifier`` (pretrained ``args.arch`` backbone
    plus a new head) on ``dataset`` for ``args.attention_epochs`` epochs.

    Phase 2 works layer by layer over ``return_layers``. For each output
    channel, it zeroes that channel's weights, measures how much the
    cross-entropy loss rises, and restores the weights. The rise is averaged
    over up to ``args.attention_iteration_limit`` batches, or all batches when
    the limit is <= 0.

    Each layer's scores are then standardized and turned into a distribution
    with ``softmax(z / 5)``.

    Args:
        dataset: must expose ``num_classes`` and yield ``(inputs, labels)``.
        return_layers: attribute paths (resolved via ``get_attribute``) of
            layers with ``out_channels`` and a ``<name>.weight`` state entry.
        args: hyper-parameters (``arch``, ``lr``, ``momentum``, ``wd``,
            ``attention_batch_size``, ``workers``, ``attention_lr_decay_epochs``,
            ``attention_epochs``, ``print_freq``, ``attention_iteration_limit``).

    Returns:
        list of 1-D float tensors on ``device``, one per entry of
        ``return_layers``, each summing to 1.
    """
    backbone = models.__dict__[args.arch](pretrained=True)
    classifier = Classifier(backbone, dataset.num_classes).to(device)
    optimizer = SGD(classifier.get_parameters(args.lr),
                    momentum=args.momentum,
                    weight_decay=args.wd,
                    nesterov=True)
    data_loader = DataLoader(dataset,
                             batch_size=args.attention_batch_size,
                             shuffle=True,
                             num_workers=args.workers,
                             drop_last=False)
    # decays the LR by 10x every ``attention_lr_decay_epochs`` epochs
    lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(
        optimizer,
        gamma=math.exp(math.log(0.1) / args.attention_lr_decay_epochs))
    criterion = nn.CrossEntropyLoss()

    # running mean of the loss increase caused by zeroing each channel
    channel_weights = []
    for name in return_layers:
        layer = get_attribute(classifier, name)
        channel_weights.append([0] * layer.out_channels)

    # train the classifier
    classifier.train()
    # NOTE(review): this sets a plain attribute on the module and does not change
    # any parameter's ``requires_grad``; backbone parameters still get gradients
    # (and are updated if ``get_parameters`` includes them). Use
    # ``classifier.backbone.requires_grad_(False)`` if freezing was intended;
    # left as-is to preserve current behavior.
    classifier.backbone.requires_grad = False
    print("Pretrain a classifier to calculate channel attention.")
    for epoch in range(args.attention_epochs):
        losses = AverageMeter('Loss', ':3.2f')
        cls_accs = AverageMeter('Cls Acc', ':3.1f')
        progress = ProgressMeter(len(data_loader), [losses, cls_accs],
                                 prefix="Epoch: [{}]".format(epoch))

        for i, (inputs, labels) in enumerate(data_loader):
            inputs = inputs.to(device)
            labels = labels.to(device)
            outputs, _ = classifier(inputs)
            loss = criterion(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            cls_acc = accuracy(outputs, labels)[0]

            losses.update(loss.item(), inputs.size(0))
            cls_accs.update(cls_acc.item(), inputs.size(0))

            if i % args.print_freq == 0:
                progress.display(i)
        lr_scheduler.step()

    # calculate the channel attention
    print('Calculating channel attention.')
    classifier.eval()
    if args.attention_iteration_limit > 0:
        total_iteration = min(len(data_loader), args.attention_iteration_limit)
    else:
        # use the loader built above (was ``len(args.data_loader)``, which read
        # an attribute this function never sets)
        total_iteration = len(data_loader)

    progress = ProgressMeter(total_iteration, [], prefix="Iteration: ")

    # Only forward passes are needed here. no_grad avoids building an autograd
    # graph for each of the (channels x layers) ablation forwards per batch.
    with torch.no_grad():
        for i, (inputs, labels) in enumerate(data_loader):
            if i >= total_iteration:
                break
            inputs = inputs.to(device)
            labels = labels.to(device)
            outputs, _ = classifier(inputs)
            loss_0 = criterion(outputs, labels)
            progress.display(i)
            for layer_id, name in enumerate(tqdm(return_layers)):
                layer = get_attribute(classifier, name)
                # state_dict() tensors share storage with the parameters, so
                # in-place writes below modify the live weights. Fetch once per
                # layer instead of rebuilding the state dict per channel.
                weight = classifier.state_dict()[name + '.weight']
                for j in range(layer.out_channels):
                    saved = weight[j].clone()
                    weight[j] = 0.0
                    outputs, _ = classifier(inputs)
                    loss_1 = criterion(outputs, labels)
                    difference = (loss_1 - loss_0).item()
                    history_value = channel_weights[layer_id][j]
                    # incremental mean over the i + 1 batches seen so far
                    channel_weights[layer_id][j] = (i * history_value +
                                                    difference) / (i + 1)
                    weight[j] = saved

    channel_attention = []
    for scores in channel_weights:
        scores = np.array(scores)
        std = np.std(scores)
        # If every channel scored the same, the spread is zero. Standardizing
        # would divide by zero and yield NaN, so use all-zero scores, which the
        # softmax maps to uniform attention.
        if std > 0:
            scores = (scores - np.mean(scores)) / std
        else:
            scores = np.zeros_like(scores)
        scores = torch.from_numpy(scores).float().to(device)
        # dim=0: scores is 1-D; explicit dim avoids the implicit-dim warning
        channel_attention.append(F.softmax(scores / 5, dim=0).detach())
    return channel_attention
# Example #29
# 0
def train(train_source_iter: ForeverDataIterator,
          train_target_iter: ForeverDataIterator, model, interp, criterion,
          dann, optimizer: SGD, lr_scheduler: LambdaLR, optimizer_d: SGD,
          lr_scheduler_d: LambdaLR, epoch: int, visualize,
          args: argparse.Namespace):
    """Run one epoch of adversarial training for semantic segmentation.

    Gradients from three losses are accumulated per iteration before a single
    step of each optimizer:

    1. the source segmentation loss ``criterion(interp(model(x_s)), label_s)``;
    2. ``args.trade_off * dann(pred_t, 'source')``: target predictions are
       scored against the *source* domain label, so the segmentation network
       learns to fool the discriminator;
    3. the discriminator loss on *detached* source/target predictions with their
       true domain labels.

    Target labels are used only for the reported accuracy/IoU.

    Args:
        train_source_iter: yields ``(x_s, label_s)``; labels are cast to long.
        train_target_iter: yields ``(x_t, label_t)``; labels are cast to long.
        model: segmentation network; must expose ``num_classes``.
        interp: applied to raw model outputs before loss/argmax (presumably
            resizing to label resolution; TODO confirm).
        criterion: segmentation loss.
        dann: domain discriminator loss taking ``(pred, 'source'|'target')``.
        optimizer: stepped once per iteration for the segmentation network.
        lr_scheduler: stepped once per iteration.
        optimizer_d: stepped once per iteration for the discriminator.
        lr_scheduler_d: stepped once per iteration.
        epoch: used only in the progress prefix.
        visualize: optional callable ``(image, pred, label, name)`` called every
            ``args.print_freq`` iterations; ``None`` disables it.
        args: reads ``iters_per_epoch``, ``trade_off`` and ``print_freq``.
    """
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    losses_s = AverageMeter('Loss (s)', ':3.2f')
    losses_transfer = AverageMeter('Loss (transfer)', ':3.2f')
    losses_discriminator = AverageMeter('Loss (discriminator)', ':3.2f')
    accuracies_s = Meter('Acc (s)', ':3.2f')
    accuracies_t = Meter('Acc (t)', ':3.2f')
    iou_s = Meter('IoU (s)', ':3.2f')
    iou_t = Meter('IoU (t)', ':3.2f')

    # created once per epoch, so the metrics below are epoch-so-far values
    confmat_s = ConfusionMatrix(model.num_classes)
    confmat_t = ConfusionMatrix(model.num_classes)
    progress = ProgressMeter(args.iters_per_epoch, [
        batch_time, data_time, losses_s, losses_transfer, losses_discriminator,
        accuracies_s, accuracies_t, iou_s, iou_t
    ],
                             prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    end = time.time()

    for i in range(args.iters_per_epoch):
        x_s, label_s = next(train_source_iter)
        x_t, label_t = next(train_target_iter)

        x_s = x_s.to(device)
        label_s = label_s.long().to(device)
        x_t = x_t.to(device)
        label_t = label_t.long().to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # gradients from all backward() calls below accumulate until the single
        # optimizer.step() / optimizer_d.step() further down
        optimizer.zero_grad()
        optimizer_d.zero_grad()

        # Step 1: Train the segmentation network, freeze the discriminator
        # NOTE(review): the adversarial backward below only spares the
        # discriminator's parameters if dann.eval() disables their gradients;
        # otherwise optimizer_d.step() would also apply the "fooling" gradient.
        # Confirm what dann.eval()/train() do.
        dann.eval()
        y_s = model(x_s)
        pred_s = interp(y_s)
        loss_cls_s = criterion(pred_s, label_s)
        loss_cls_s.backward()

        # adversarial training to fool the discriminator
        y_t = model(x_t)
        pred_t = interp(y_t)
        loss_transfer = dann(pred_t, 'source')
        (loss_transfer * args.trade_off).backward()

        # Step 2: Train the discriminator
        # (inputs are detached so this loss does not reach the segmentation network)
        dann.train()
        loss_discriminator = 0.5 * (dann(pred_s.detach(), 'source') +
                                    dann(pred_t.detach(), 'target'))
        loss_discriminator.backward()

        # compute gradient and do SGD step
        optimizer.step()
        optimizer_d.step()
        lr_scheduler.step()
        lr_scheduler_d.step()

        # measure accuracy and record loss
        losses_s.update(loss_cls_s.item(), x_s.size(0))
        losses_transfer.update(loss_transfer.item(), x_s.size(0))
        losses_discriminator.update(loss_discriminator.item(), x_s.size(0))

        confmat_s.update(label_s.flatten(), pred_s.argmax(1).flatten())
        confmat_t.update(label_t.flatten(), pred_t.argmax(1).flatten())
        acc_global_s, acc_s, iu_s = confmat_s.compute()
        acc_global_t, acc_t, iu_t = confmat_t.compute()
        accuracies_s.update(acc_s.mean().item())
        accuracies_t.update(acc_t.mean().item())
        iou_s.update(iu_s.mean().item())
        iou_t.update(iu_t.mean().item())

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)

            if visualize is not None:
                visualize(x_s[0], pred_s[0], label_s[0], "source_{}".format(i))
                visualize(x_t[0], pred_t[0], label_t[0], "target_{}".format(i))
# Example #30
# 0
def train(train_source_iter: ForeverDataIterator,
          train_target_iter: ForeverDataIterator, model, optimizer: Adam,
          xbm: XBM, epoch: int, args: argparse.Namespace):
    """Train the bridge model for one epoch of ``args.iters_per_epoch`` iterations.

    Each iteration does the following:

    * draws one source batch and one target batch;
    * interleaves them per device, so every replica receives a source chunk
      followed by a target chunk (arranged for domain-specific BN);
    * optimizes a weighted sum of six losses: cross-entropy, triplet,
      XBM triplet, bridge-probability, bridge-feature and diversity.

    Args:
        train_source_iter: iterator over the source training set; each item
            is unpacked as ``(images, _, labels, _)``.
        train_target_iter: iterator over the target training set, with the
            same item layout as the source iterator.
        model: network called as ``model(x, stage=args.stage)`` and returning
            ``(logits, features, attention_lam)``. Per device, its outputs are
            split into three equal chunks: source, target, and mixed.
        optimizer: optimizer stepped once per iteration.
        xbm: cross-batch memory. The current source+target features and labels
            are enqueued each iteration, and the memory content is used by the
            XBM triplet loss.
        epoch: current epoch index, used only for the progress prefix.
        args: parsed options. Reads ``n_s_classes``, ``n_t_classes``,
            ``margin``, ``iters_per_epoch``, ``stage``, ``mu1``, ``mu2``,
            ``mu3`` and ``print_freq``.
    """
    num_classes = args.n_s_classes + args.n_t_classes

    # define loss functions
    criterion_ce = BridgeProbLoss(num_classes).to(device)
    criterion_triplet = TripletLoss(margin=args.margin).to(device)
    criterion_triplet_xbm = TripletLossXBM(margin=args.margin).to(device)
    criterion_bridge_feat = BridgeFeatLoss().to(device)
    criterion_diverse = DivLoss().to(device)

    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    losses_ce = AverageMeter('CeLoss', ':3.2f')
    losses_triplet = AverageMeter('TripletLoss', ':3.2f')
    losses_triplet_xbm = AverageMeter('XBMTripletLoss', ':3.2f')
    losses_bridge_prob = AverageMeter('BridgeProbLoss', ':3.2f')
    losses_bridge_feat = AverageMeter('BridgeFeatLoss', ':3.2f')
    losses_diverse = AverageMeter('DiverseLoss', ':3.2f')
    losses = AverageMeter('Loss', ':3.2f')

    cls_accs_s = AverageMeter('Src Cls Acc', ':3.1f')
    cls_accs_t = AverageMeter('Tgt Cls Acc', ':3.1f')

    progress = ProgressMeter(args.iters_per_epoch, [
        batch_time, data_time, losses_ce, losses_triplet, losses_triplet_xbm,
        losses_bridge_prob, losses_bridge_feat, losses_diverse, losses,
        cls_accs_s, cls_accs_t
    ],
                             prefix="Epoch: [{}]".format(epoch))

    # Number of chunks the batch is arranged into, one per visible GPU.
    # torch.cuda.device_count() is 0 on a CPU-only host, which would make
    # every view(device_num, ...) below fail, so fall back to one chunk.
    # It does not change during training, so compute it once.
    device_num = max(1, torch.cuda.device_count())

    # switch to train mode
    model.train()

    end = time.time()

    for i in range(args.iters_per_epoch):
        x_s, _, labels_s, _ = next(train_source_iter)
        x_t, _, labels_t, _ = next(train_target_iter)

        x_s = x_s.to(device)
        x_t = x_t.to(device)
        labels_s = labels_s.to(device)
        labels_t = labels_t.to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # Arrange the batch for domain-specific BN: each device chunk holds
        # its share of source images followed by its share of target images.
        # NOTE(review): assumes x_s and x_t have the same shape and a batch
        # size divisible by device_num.
        batch_size, C, H, W = x_s.size()
        x = torch.cat((x_s.view(device_num, -1, C, H, W),
                       x_t.view(device_num, -1, C, H, W)), 1).view(-1, C, H, W)
        labels = torch.cat((labels_s.view(device_num, -1),
                            labels_t.view(device_num, -1)), 1).view(-1)

        # compute output
        y, f, attention_lam = model(x, stage=args.stage)
        # only (n_s_classes + n_t_classes) classes are meaningful
        y = y[:, 0:num_classes]

        # Split features per device into source / target / mixed chunks.
        # ori_f keeps only source+target features, in the same order as `labels`.
        feat_dim = f.size(-1)
        ori_f = f.view(device_num, -1, feat_dim)
        f_s, f_t, f_mixed = ori_f.split(ori_f.size(1) // 3, dim=1)
        ori_f = torch.cat((f_s, f_t), 1).view(-1, feat_dim)

        # cross entropy loss (and bridge probability loss)
        loss_ce, loss_bridge_prob = criterion_ce(y, labels,
                                                 attention_lam[:, 0].detach(),
                                                 device_num)
        # triplet loss
        loss_triplet = criterion_triplet(ori_f, labels)
        # diverse loss
        loss_diverse = criterion_diverse(attention_lam)
        # bridge feature loss
        f_s = f_s.contiguous().view(-1, feat_dim)
        f_t = f_t.contiguous().view(-1, feat_dim)
        f_mixed = f_mixed.contiguous().view(-1, feat_dim)
        loss_bridge_feat = criterion_bridge_feat(f_s, f_t, f_mixed,
                                                 attention_lam)
        # xbm triplet loss: push this batch into the memory before reading it
        xbm.enqueue_dequeue(ori_f.detach(), labels.detach())
        xbm_f, xbm_labels = xbm.get()
        loss_triplet_xbm = criterion_triplet_xbm(ori_f, labels, xbm_f,
                                                 xbm_labels)

        loss = (1. - args.mu1) * loss_ce + loss_triplet + loss_triplet_xbm + \
               args.mu1 * loss_bridge_prob + args.mu2 * loss_bridge_feat + args.mu3 * loss_diverse

        # compute gradient and do an optimizer step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Per-device source / target logits. Flattening restores the original
        # order of labels_s / labels_t.
        ori_y = y.view(device_num, -1, y.size(-1))
        y_s, y_t, _ = ori_y.split(ori_y.size(1) // 3, dim=1)
        cls_acc_s = accuracy(y_s.reshape(-1, y_s.size(-1)), labels_s)[0]
        cls_acc_t = accuracy(y_t.reshape(-1, y_t.size(-1)), labels_t)[0]

        # update statistics, weighted by the per-domain batch size
        # (not by device_num, which is what a reshaped x_s.size(0) would give)
        losses_ce.update(loss_ce.item(), batch_size)
        losses_triplet.update(loss_triplet.item(), batch_size)
        losses_triplet_xbm.update(loss_triplet_xbm.item(), batch_size)
        losses_bridge_prob.update(loss_bridge_prob.item(), batch_size)
        losses_bridge_feat.update(loss_bridge_feat.item(), batch_size)
        losses_diverse.update(loss_diverse.item(), batch_size)
        losses.update(loss.item(), batch_size)

        cls_accs_s.update(cls_acc_s.item(), batch_size)
        cls_accs_t.update(cls_acc_t.item(), batch_size)

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)