def validate(val_loader: DataLoader, model: Classifier, args: argparse.Namespace) -> float:
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    progress = ProgressMeter(
        len(val_loader),
        [batch_time, losses, top1, top5],
        prefix='Test: ')

    # switch to evaluate mode
    model.eval()

    with torch.no_grad():
        end = time.time()
        for i, (images, target) in enumerate(val_loader):
            images = images.to(device)
            target = target.to(device)

            # compute output
            output, _ = model(images)
            loss = F.cross_entropy(output, target)

            # measure accuracy and record loss
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            losses.update(loss.item(), images.size(0))
            top1.update(acc1.item(), images.size(0))
            top5.update(acc5.item(), images.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                progress.display(i)

        print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
              .format(top1=top1, top5=top5))

    return top1.avg
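Every snippet in this listing assumes the same module-level setup (imports, a global `device`) and a pair of metering helpers. Below is a minimal sketch of what `AverageMeter` and `ProgressMeter` are assumed to provide, in the style of the standard PyTorch ImageNet example; the library ships its own implementations.

# Assumed module-level imports and globals shared by the snippets below (hypothetical, for context):
# import copy, math, time, argparse
# import numpy as np
# import torch
# import torch.nn as nn
# import torch.nn.functional as F
# from torch.utils.data import DataLoader
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class AverageMeter:
    """Track the current value and running average of a scalar (loss, accuracy, time)."""
    def __init__(self, name: str, fmt: str = ':f'):
        self.name, self.fmt = name, fmt
        self.reset()

    def reset(self):
        self.val, self.sum, self.count, self.avg = 0.0, 0.0, 0, 0.0

    def update(self, val, n: int = 1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)


class ProgressMeter:
    """Print the batch index together with every registered meter."""
    def __init__(self, num_batches: int, meters, prefix: str = ''):
        self.num_batches, self.meters, self.prefix = num_batches, meters, prefix

    def display(self, batch: int):
        entries = ['{}[{:d}/{:d}]'.format(self.prefix, batch, self.num_batches)]
        entries += [str(meter) for meter in self.meters]
        print('\t'.join(entries))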
def validate(val_loader, model, args, device, visualize=None) -> float:
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    progress = ProgressMeter(
        len(val_loader),
        [batch_time, losses, top1],
        prefix='Test: ')

    # switch to evaluate mode
    model.eval()

    with torch.no_grad():
        end = time.time()
        for i, (images, target) in enumerate(val_loader):
            images = images.to(device)
            target = target.to(device)

            # compute output
            output = model(images)
            loss = F.cross_entropy(output, target)

            # measure accuracy and record loss
            acc1, = accuracy(output, target, topk=(1, ))
            losses.update(loss.item(), images.size(0))
            top1.update(acc1.item(), images.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                progress.display(i)
                if visualize is not None:
                    visualize(images[0], "val_{}".format(i))

        print(' * Acc@1 {top1.avg:.3f}'.format(top1=top1))

    return top1.avg
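The classification snippets also rely on an `accuracy` helper with the usual ImageNet-example contract: it returns one tensor per requested `topk` entry, each holding the percentage of correct predictions. A minimal sketch under that assumption follows; note that the keypoint-detection example further below uses a different, heatmap-based `accuracy` with another return signature.

def accuracy(output, target, topk=(1,)):
    """Compute top-k accuracy (in percent) for logits of shape (batch, classes) and integer targets."""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)
        pred = pred.t()                                    # (maxk, batch)
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res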
def extract_reid_feature(data_loader,
                         model,
                         device,
                         normalize,
                         print_freq=200):
    """Extract feature for person ReID. If `normalize` is True, `cosine` distance will be employed as distance
    metric, otherwise `euclidean` distance.
    """
    batch_time = AverageMeter('Time', ':6.3f')
    progress = ProgressMeter(len(data_loader), [batch_time],
                             prefix='Collect feature: ')

    # switch to eval mode
    model.eval()
    feature_dict = dict()

    with torch.no_grad():
        end = time.time()
        for i, (images_batch, filenames_batch, _, _) in enumerate(data_loader):

            images_batch = images_batch.to(device)
            features_batch = model(images_batch)
            if normalize:
                features_batch = F.normalize(features_batch)

            for filename, feature in zip(filenames_batch, features_batch):
                feature_dict[filename] = feature

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % print_freq == 0:
                progress.display(i)

    return feature_dict
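As a usage illustration only (the function and variable names below are hypothetical), the returned feature dict can be turned into a query-gallery distance matrix, with cosine distance when `normalize` was True and euclidean distance otherwise.

def pairwise_dist(query_dict, gallery_dict, normalized: bool):
    """Stack extracted features and compute a (num_query, num_gallery) distance matrix."""
    q = torch.stack(list(query_dict.values()))    # (num_query, d)
    g = torch.stack(list(gallery_dict.values()))  # (num_gallery, d)
    if normalized:
        # features are unit-length, so cosine distance reduces to 1 - dot product
        return 1.0 - q @ g.t()
    return torch.cdist(q, g, p=2)                 # euclidean distance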
Example #4
def train(train_source_iter: ForeverDataIterator,
          train_target_iter: ForeverDataIterator, model: ImageClassifier,
          mkmmd_loss: MultipleKernelMaximumMeanDiscrepancy, optimizer: SGD,
          lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace):
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    losses = AverageMeter('Loss', ':3.2f')
    trans_losses = AverageMeter('Trans Loss', ':5.4f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    tgt_accs = AverageMeter('Tgt Acc', ':3.1f')

    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses, trans_losses, cls_accs, tgt_accs],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()
    mkmmd_loss.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        x_s, labels_s = next(train_source_iter)
        x_t, labels_t = next(train_target_iter)

        x_s = x_s.to(device)
        x_t = x_t.to(device)
        labels_s = labels_s.to(device)
        labels_t = labels_t.to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # compute output
        y_s, f_s = model(x_s)
        y_t, f_t = model(x_t)

        cls_loss = F.cross_entropy(y_s, labels_s)
        transfer_loss = mkmmd_loss(f_s, f_t)
        loss = cls_loss + transfer_loss * args.trade_off

        cls_acc = accuracy(y_s, labels_s)[0]
        tgt_acc = accuracy(y_t, labels_t)[0]

        losses.update(loss.item(), x_s.size(0))
        cls_accs.update(cls_acc.item(), x_s.size(0))
        tgt_accs.update(tgt_acc.item(), x_t.size(0))
        trans_losses.update(transfer_loss.item(), x_s.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
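`mkmmd_loss` above is the library's multiple-kernel maximum mean discrepancy module. Purely as an illustration of what such a transfer loss measures (this sketch is not the library's implementation), a biased MMD^2 estimate with a sum of Gaussian kernels over two feature batches looks like:

def gaussian_mmd(f_s, f_t, sigmas=(1.0, 2.0, 4.0, 8.0)):
    """Biased MMD^2 estimate between two feature batches using a sum of Gaussian kernels."""
    def kernel(a, b):
        d2 = torch.cdist(a, b, p=2).pow(2)                 # squared pairwise distances
        return sum(torch.exp(-d2 / (2 * s ** 2)) for s in sigmas)

    return kernel(f_s, f_s).mean() + kernel(f_t, f_t).mean() - 2 * kernel(f_s, f_t).mean()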
def train(train_source_iter: ForeverDataIterator,
          train_target_iter: ForeverDataIterator, model, optimizer: Adam,
          xbm: XBM, epoch: int, args: argparse.Namespace):
    # define loss functions
    criterion_ce = BridgeProbLoss(args.n_s_classes +
                                  args.n_t_classes).to(device)
    criterion_triplet = TripletLoss(margin=args.margin).to(device)
    criterion_triplet_xbm = TripletLossXBM(margin=args.margin).to(device)
    criterion_bridge_feat = BridgeFeatLoss().to(device)
    criterion_diverse = DivLoss().to(device)

    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    losses_ce = AverageMeter('CeLoss', ':3.2f')
    losses_triplet = AverageMeter('TripletLoss', ':3.2f')
    losses_triplet_xbm = AverageMeter('XBMTripletLoss', ':3.2f')
    losses_bridge_prob = AverageMeter('BridgeProbLoss', ':3.2f')
    losses_bridge_feat = AverageMeter('BridgeFeatLoss', ':3.2f')
    losses_diverse = AverageMeter('DiverseLoss', ':3.2f')
    losses = AverageMeter('Loss', ':3.2f')

    cls_accs_s = AverageMeter('Src Cls Acc', ':3.1f')
    cls_accs_t = AverageMeter('Tgt Cls Acc', ':3.1f')

    progress = ProgressMeter(args.iters_per_epoch, [
        batch_time, data_time, losses_ce, losses_triplet, losses_triplet_xbm,
        losses_bridge_prob, losses_bridge_feat, losses_diverse, losses,
        cls_accs_s, cls_accs_t
    ],
                             prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    end = time.time()

    for i in range(args.iters_per_epoch):
        x_s, _, labels_s, _ = next(train_source_iter)
        x_t, _, labels_t, _ = next(train_target_iter)

        x_s = x_s.to(device)
        x_t = x_t.to(device)
        labels_s = labels_s.to(device)
        labels_t = labels_t.to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # arrange batch for domain-specific BN
        device_num = torch.cuda.device_count()
        B, C, H, W = x_s.size()

        def reshape(tensor):
            return tensor.view(device_num, -1, C, H, W)

        x_s, x_t = reshape(x_s), reshape(x_t)
        x = torch.cat((x_s, x_t), 1).view(-1, C, H, W)

        labels = torch.cat(
            (labels_s.view(device_num, -1), labels_t.view(device_num, -1)), 1)
        labels = labels.view(-1)

        # compute output
        y, f, attention_lam = model(x, stage=args.stage)
        # only (n_s_classes + n_t_classes) classes are meaningful
        y = y[:, 0:args.n_s_classes + args.n_t_classes]

        # split feats
        ori_f = f.view(device_num, -1, f.size(-1))
        f_s, f_t, f_mixed = ori_f.split(ori_f.size(1) // 3, dim=1)
        ori_f = torch.cat((f_s, f_t), 1).view(-1, ori_f.size(-1))

        # cross entropy loss
        loss_ce, loss_bridge_prob = criterion_ce(y, labels,
                                                 attention_lam[:, 0].detach(),
                                                 device_num)
        # triplet loss
        loss_triplet = criterion_triplet(ori_f, labels)
        # diverse loss
        loss_diverse = criterion_diverse(attention_lam)
        # bridge feature loss
        f_s = f_s.contiguous().view(-1, f.size(-1))
        f_t = f_t.contiguous().view(-1, f.size(-1))
        f_mixed = f_mixed.contiguous().view(-1, f.size(-1))
        loss_bridge_feat = criterion_bridge_feat(f_s, f_t, f_mixed,
                                                 attention_lam)
        # xbm triplet loss
        xbm.enqueue_dequeue(ori_f.detach(), labels.detach())
        xbm_f, xbm_labels = xbm.get()
        loss_triplet_xbm = criterion_triplet_xbm(ori_f, labels, xbm_f,
                                                 xbm_labels)

        loss = (1. - args.mu1) * loss_ce + loss_triplet + loss_triplet_xbm + \
               args.mu1 * loss_bridge_prob + args.mu2 * loss_bridge_feat + args.mu3 * loss_diverse

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        ori_y = y.view(device_num, -1, y.size(-1))
        y_s, y_t, _ = ori_y.split(ori_y.size(1) // 3, dim=1)
        cls_acc_s = accuracy(y_s.reshape(-1, y_s.size(-1)), labels_s)[0]
        cls_acc_t = accuracy(y_t.reshape(-1, y_t.size(-1)), labels_t)[0]

        # update statistics
        losses_ce.update(loss_ce.item(), x_s.size(0))
        losses_triplet.update(loss_triplet.item(), x_s.size(0))
        losses_triplet_xbm.update(loss_triplet_xbm.item(), x_s.size(0))
        losses_bridge_prob.update(loss_bridge_prob.item(), x_s.size(0))
        losses_bridge_feat.update(loss_bridge_feat.item(), x_s.size(0))
        losses_diverse.update(loss_diverse.item(), x_s.size(0))
        losses.update(loss.item(), x_s.size(0))

        cls_accs_s.update(cls_acc_s.item(), x_s.size(0))
        cls_accs_t.update(cls_acc_t.item(), x_s.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
def train(train_source_iter: ForeverDataIterator,
          train_target_iter: ForeverDataIterator, model,
          mdd: MarginDisparityDiscrepancy, optimizer: SGD,
          lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace):
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    source_losses = AverageMeter('Source Loss', ':6.3f')
    trans_losses = AverageMeter('Trans Loss', ':6.3f')
    mae_losses_s = AverageMeter('MAE Loss (s)', ':6.3f')
    mae_losses_t = AverageMeter('MAE Loss (t)', ':6.3f')

    progress = ProgressMeter(args.iters_per_epoch, [
        batch_time, data_time, source_losses, trans_losses, mae_losses_s,
        mae_losses_t
    ],
                             prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()
    mdd.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        optimizer.zero_grad()

        x_s, labels_s = next(train_source_iter)
        x_s = x_s.to(device)
        labels_s = labels_s.to(device).float()
        x_t, labels_t = next(train_target_iter)
        x_t = x_t.to(device)
        labels_t = labels_t.to(device).float()

        # measure data loading time
        data_time.update(time.time() - end)

        # compute output
        x = torch.cat([x_s, x_t], dim=0)
        outputs, outputs_adv = model(x)
        y_s, y_t = outputs.chunk(2, dim=0)
        y_s_adv, y_t_adv = outputs_adv.chunk(2, dim=0)

        # compute the mean squared error loss on the source domain
        mse_loss = F.mse_loss(y_s, labels_s)

        # compute margin disparity discrepancy between domains
        transfer_loss = mdd(y_s, y_s_adv, y_t, y_t_adv)
        # for the adversarial classifier, minimizing the negative MDD is equivalent to maximizing the MDD
        loss = mse_loss - transfer_loss * args.trade_off
        model.step()

        mae_loss_s = F.l1_loss(y_s, labels_s)
        mae_loss_t = F.l1_loss(y_t, labels_t)

        source_losses.update(mse_loss.item(), x_s.size(0))
        trans_losses.update(transfer_loss.item(), x_s.size(0))
        mae_losses_s.update(mae_loss_s.item(), x_s.size(0))
        mae_losses_t.update(mae_loss_t.item(), x_s.size(0))

        # compute gradient and do SGD step
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
Example #7
def train(train_source_iter: ForeverDataIterator,
          train_target_iter: ForeverDataIterator, model: ImageClassifier,
          adaptive_feature_norm: AdaptiveFeatureNorm, optimizer: SGD,
          epoch: int, args: argparse.Namespace):
    batch_time = AverageMeter('Time', ':3.1f')
    data_time = AverageMeter('Data', ':3.1f')
    cls_losses = AverageMeter('Cls Loss', ':3.2f')
    norm_losses = AverageMeter('Norm Loss', ':3.2f')
    src_feature_norm = AverageMeter('Source Feature Norm', ':3.2f')
    tgt_feature_norm = AverageMeter('Target Feature Norm', ':3.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    tgt_accs = AverageMeter('Tgt Acc', ':3.1f')

    progress = ProgressMeter(args.iters_per_epoch, [
        batch_time, data_time, cls_losses, norm_losses, src_feature_norm,
        tgt_feature_norm, cls_accs, tgt_accs
    ],
                             prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        x_s, labels_s = next(train_source_iter)
        x_t, labels_t = next(train_target_iter)

        x_s = x_s.to(device)
        x_t = x_t.to(device)
        labels_s = labels_s.to(device)
        labels_t = labels_t.to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # compute output
        y_s, f_s = model(x_s)
        y_t, f_t = model(x_t)

        # classification loss
        cls_loss = F.cross_entropy(y_s, labels_s)
        # norm loss
        norm_loss = adaptive_feature_norm(f_s) + adaptive_feature_norm(f_t)
        loss = cls_loss + norm_loss * args.trade_off_norm

        # optionally apply entropy minimization on target predictions
        if args.trade_off_entropy:
            y_t = F.softmax(y_t, dim=1)
            entropy_loss = entropy(y_t, reduction='mean')
            loss += entropy_loss * args.trade_off_entropy

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # update statistics
        cls_acc = accuracy(y_s, labels_s)[0]
        tgt_acc = accuracy(y_t, labels_t)[0]

        cls_losses.update(cls_loss.item(), x_s.size(0))
        norm_losses.update(norm_loss.item(), x_s.size(0))
        src_feature_norm.update(
            f_s.norm(p=2, dim=1).mean().item(), x_s.size(0))
        tgt_feature_norm.update(
            f_t.norm(p=2, dim=1).mean().item(), x_s.size(0))
        cls_accs.update(cls_acc.item(), x_s.size(0))
        tgt_accs.update(tgt_acc.item(), x_s.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
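`adaptive_feature_norm` is assumed to implement the stepwise feature-norm penalty from AFN, where each feature's L2 norm is pushed toward its current (detached) value plus a fixed step size. A sketch under that assumption, not the library's actual module:

class AdaptiveFeatureNormSketch(nn.Module):
    """Encourage feature norms to grow by `delta` at each step (stepwise AFN, sketched)."""
    def __init__(self, delta: float = 1.0):
        super().__init__()
        self.delta = delta

    def forward(self, f):
        norm = f.norm(p=2, dim=1)
        target = norm.detach() + self.delta   # stop gradient through the target norm
        return ((norm - target) ** 2).mean()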
Example #8
def train(train_source_iter, train_target_iter, model, criterion,
          regression_disparity, optimizer_f, optimizer_h, optimizer_h_adv,
          lr_scheduler_f, lr_scheduler_h, lr_scheduler_h_adv, epoch: int,
          visualize, args: argparse.Namespace):
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    losses_s = AverageMeter('Loss (s)', ":.2e")
    losses_gf = AverageMeter('Loss (t, false)', ":.2e")
    losses_gt = AverageMeter('Loss (t, truth)', ":.2e")
    acc_s = AverageMeter("Acc (s)", ":3.2f")
    acc_t = AverageMeter("Acc (t)", ":3.2f")
    acc_s_adv = AverageMeter("Acc (s, adv)", ":3.2f")
    acc_t_adv = AverageMeter("Acc (t, adv)", ":3.2f")

    progress = ProgressMeter(args.iters_per_epoch, [
        batch_time, data_time, losses_s, losses_gf, losses_gt, acc_s, acc_t,
        acc_s_adv, acc_t_adv
    ],
                             prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        x_s, label_s, weight_s, meta_s = next(train_source_iter)
        x_t, label_t, weight_t, meta_t = next(train_target_iter)

        x_s = x_s.to(device)
        label_s = label_s.to(device)
        weight_s = weight_s.to(device)

        x_t = x_t.to(device)
        label_t = label_t.to(device)
        weight_t = weight_t.to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # Step A train all networks to minimize loss on source domain
        optimizer_f.zero_grad()
        optimizer_h.zero_grad()
        optimizer_h_adv.zero_grad()

        y_s, y_s_adv = model(x_s)
        loss_s = criterion(y_s, label_s, weight_s) + \
                 args.margin * args.trade_off * regression_disparity(y_s, y_s_adv, weight_s, mode='min')
        loss_s.backward()
        optimizer_f.step()
        optimizer_h.step()
        optimizer_h_adv.step()

        # Step B train adv regressor to maximize regression disparity
        optimizer_h_adv.zero_grad()
        y_t, y_t_adv = model(x_t)
        loss_ground_false = args.trade_off * regression_disparity(
            y_t, y_t_adv, weight_t, mode='max')
        loss_ground_false.backward()
        optimizer_h_adv.step()

        # Step C train feature extractor to minimize regression disparity
        optimizer_f.zero_grad()
        y_t, y_t_adv = model(x_t)
        loss_ground_truth = args.trade_off * regression_disparity(
            y_t, y_t_adv, weight_t, mode='min')
        loss_ground_truth.backward()
        optimizer_f.step()

        # do update step
        model.step()
        lr_scheduler_f.step()
        lr_scheduler_h.step()
        lr_scheduler_h_adv.step()

        # measure accuracy and record loss
        _, avg_acc_s, cnt_s, pred_s = accuracy(y_s.detach().cpu().numpy(),
                                               label_s.detach().cpu().numpy())
        acc_s.update(avg_acc_s, cnt_s)
        _, avg_acc_t, cnt_t, pred_t = accuracy(y_t.detach().cpu().numpy(),
                                               label_t.detach().cpu().numpy())
        acc_t.update(avg_acc_t, cnt_t)
        _, avg_acc_s_adv, cnt_s_adv, pred_s_adv = accuracy(
            y_s_adv.detach().cpu().numpy(),
            label_s.detach().cpu().numpy())
        acc_s_adv.update(avg_acc_s_adv, cnt_s)
        _, avg_acc_t_adv, cnt_t_adv, pred_t_adv = accuracy(
            y_t_adv.detach().cpu().numpy(),
            label_t.detach().cpu().numpy())
        acc_t_adv.update(avg_acc_t_adv, cnt_t)
        losses_s.update(loss_s.item(), cnt_s)
        losses_gf.update(loss_ground_false.item(), cnt_s)
        losses_gt.update(loss_ground_truth.item(), cnt_s)

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
            if visualize is not None:
                visualize(x_s[0],
                          pred_s[0] * args.image_size / args.heatmap_size,
                          "source_{}_pred".format(i))
                visualize(x_s[0], meta_s['keypoint2d'][0],
                          "source_{}_label".format(i))
                visualize(x_t[0],
                          pred_t[0] * args.image_size / args.heatmap_size,
                          "target_{}_pred".format(i))
                visualize(x_t[0], meta_t['keypoint2d'][0],
                          "target_{}_label".format(i))
                visualize(x_s[0],
                          pred_s_adv[0] * args.image_size / args.heatmap_size,
                          "source_adv_{}_pred".format(i))
                visualize(x_t[0],
                          pred_t_adv[0] * args.image_size / args.heatmap_size,
                          "target_adv_{}_pred".format(i))
Example #9
def train(train_iter: ForeverDataIterator, model: Classifier,
          backbone_regularization: nn.Module, head_regularization: nn.Module,
          target_getter: IntermediateLayerGetter,
          source_getter: IntermediateLayerGetter, optimizer: SGD, epoch: int,
          args: argparse.Namespace):
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    losses = AverageMeter('Loss', ':3.2f')
    losses_reg_head = AverageMeter('Loss (reg, head)', ':3.2f')
    losses_reg_backbone = AverageMeter('Loss (reg, backbone)', ':3.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')

    progress = ProgressMeter(args.iters_per_epoch, [
        batch_time, data_time, losses, losses_reg_head, losses_reg_backbone,
        cls_accs
    ],
                             prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        x, labels = next(train_iter)
        x = x.to(device)
        label = labels.to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # compute output
        intermediate_output_s, output_s = source_getter(x)
        intermediate_output_t, output_t = target_getter(x)
        y, f = output_t

        # measure accuracy and record loss
        cls_acc = accuracy(y, label)[0]
        cls_loss = F.cross_entropy(y, label)
        if args.regularization_type in ('feature_map', 'attention_feature_map'):
            loss_reg_backbone = backbone_regularization(
                intermediate_output_s, intermediate_output_t)
        else:
            loss_reg_backbone = backbone_regularization()
        loss_reg_head = head_regularization()
        loss = cls_loss + args.trade_off_backbone * loss_reg_backbone + args.trade_off_head * loss_reg_head

        losses_reg_backbone.update(
            loss_reg_backbone.item() * args.trade_off_backbone, x.size(0))
        losses_reg_head.update(loss_reg_head.item() * args.trade_off_head,
                               x.size(0))
        losses.update(loss.item(), x.size(0))
        cls_accs.update(cls_acc.item(), x.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
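`backbone_regularization` and `head_regularization` are assumed to close over the relevant modules (and, in the feature-map variants, compare intermediate activations of the fine-tuned and pretrained networks). As a rough illustration of the zero-argument form only, a plain parameter-space L2 penalty could look like this hypothetical sketch:

class L2PenaltySketch(nn.Module):
    """Plain squared-L2 penalty over a module's parameters, callable with no arguments."""
    def __init__(self, module: nn.Module):
        super().__init__()
        self.module = module

    def forward(self):
        return 0.5 * sum(p.pow(2).sum() for p in self.module.parameters())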
Example #10
def train(train_source_iter: ForeverDataIterator,
          train_target_iter: ForeverDataIterator, model: ImageClassifier,
          teacher: EmaTeacher, consistent_loss, class_balance_loss,
          optimizer: Adam, lr_scheduler: LambdaLR, epoch: int,
          args: argparse.Namespace):
    batch_time = AverageMeter('Time', ':3.1f')
    data_time = AverageMeter('Data', ':3.1f')
    cls_losses = AverageMeter('Cls Loss', ':3.2f')
    cons_losses = AverageMeter('Cons Loss', ':3.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    tgt_accs = AverageMeter('Tgt Acc', ':3.1f')

    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, cls_losses, cons_losses, cls_accs, tgt_accs],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()
    teacher.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        x_s, labels_s = next(train_source_iter)
        (x_t1, x_t2), labels_t = next(train_target_iter)

        x_s = x_s.to(device)
        x_t1 = x_t1.to(device)
        x_t2 = x_t2.to(device)
        labels_s = labels_s.to(device)
        labels_t = labels_t.to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # compute output
        y_s, _ = model(x_s)
        y_t, _ = model(x_t1)
        y_t_teacher, _ = teacher(x_t2)

        # classification loss
        cls_loss = F.cross_entropy(y_s, labels_s)
        # compute output and mask
        y_t = F.softmax(y_t, dim=1)
        y_t_teacher = F.softmax(y_t_teacher, dim=1)
        max_prob, _ = y_t_teacher.max(dim=1)
        mask = (max_prob > args.threshold).float()

        # consistent loss
        cons_loss = consistent_loss(y_t, y_t_teacher, mask)
        # balance loss
        balance_loss = class_balance_loss(y_t) * mask.mean()

        loss = cls_loss + args.trade_off_cons * cons_loss + args.trade_off_balance * balance_loss

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        # update teacher
        teacher.update()

        # update statistics
        cls_acc = accuracy(y_s, labels_s)[0]
        tgt_acc = accuracy(y_t, labels_t)[0]
        cls_losses.update(cls_loss.item(), x_s.size(0))
        cons_losses.update(cons_loss.item(), x_s.size(0))
        cls_accs.update(cls_acc.item(), x_s.size(0))
        tgt_accs.update(tgt_acc.item(), x_s.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
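`teacher` and `consistent_loss` are assumed to follow the usual mean-teacher recipe: the teacher's weights are an exponential moving average of the student's, and the consistency term is a confidence-masked MSE between the two softmax outputs. A rough sketch under those assumptions (not the library's classes; assumes `import copy`):

class EmaTeacherSketch:
    """Keep an EMA copy of the student's parameters (buffers are copied directly)."""
    def __init__(self, student, alpha: float = 0.99):
        self.student, self.alpha = student, alpha
        self.teacher = copy.deepcopy(student)
        for p in self.teacher.parameters():
            p.requires_grad_(False)

    def __call__(self, x):
        return self.teacher(x)

    def update(self):
        with torch.no_grad():
            for t, s in zip(self.teacher.parameters(), self.student.parameters()):
                t.mul_(self.alpha).add_(s, alpha=1 - self.alpha)
            for t_buf, s_buf in zip(self.teacher.buffers(), self.student.buffers()):
                t_buf.copy_(s_buf)


def masked_consistency(p_student, p_teacher, mask):
    """Per-sample MSE between probability vectors, kept only where the teacher is confident."""
    per_sample = ((p_student - p_teacher) ** 2).sum(dim=1)
    return (per_sample * mask).mean()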
Example #11
def train(train_iter: ForeverDataIterator, model: Classifier, optimizer, lr_scheduler: CosineAnnealingLR,
          correlation_alignment_loss: CorrelationAlignmentLoss, n_domains_per_batch: int, epoch: int,
          args: argparse.Namespace):
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    losses = AverageMeter('Loss', ':3.2f')
    losses_ce = AverageMeter('CELoss', ':3.2f')
    losses_penalty = AverageMeter('Penalty Loss', ':3.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')

    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses, losses_ce, losses_penalty, cls_accs],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        x_all, labels_all, _ = next(train_iter)
        x_all = x_all.to(device)
        labels_all = labels_all.to(device)

        # compute output
        y_all, f_all = model(x_all)

        # measure data loading time
        data_time.update(time.time() - end)

        # separate into different domains
        y_all = y_all.chunk(n_domains_per_batch, dim=0)
        f_all = f_all.chunk(n_domains_per_batch, dim=0)
        labels_all = labels_all.chunk(n_domains_per_batch, dim=0)

        loss_ce = 0
        loss_penalty = 0
        cls_acc = 0
        for domain_i in range(n_domains_per_batch):
            # cls loss
            y_i, labels_i = y_all[domain_i], labels_all[domain_i]
            loss_ce += F.cross_entropy(y_i, labels_i)
            # update acc
            cls_acc += accuracy(y_i, labels_i)[0] / n_domains_per_batch
            # correlation alignment loss
            for domain_j in range(domain_i + 1, n_domains_per_batch):
                f_i = f_all[domain_i]
                f_j = f_all[domain_j]
                loss_penalty += correlation_alignment_loss(f_i, f_j)

        # normalize loss
        loss_ce /= n_domains_per_batch
        loss_penalty /= n_domains_per_batch * (n_domains_per_batch - 1) / 2

        loss = loss_ce + loss_penalty * args.trade_off

        losses.update(loss.item(), x_all.size(0))
        losses_ce.update(loss_ce.item(), x_all.size(0))
        losses_penalty.update(loss_penalty.item(), x_all.size(0))
        cls_accs.update(cls_acc.item(), x_all.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
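The correlation alignment penalty used above is the classic CORAL loss: match the first and second moments (mean and covariance) of two feature batches. A minimal sketch consistent with that definition (the library class may differ in details such as normalization):

def coral_loss(f_i, f_j):
    """Squared distance between feature means plus squared distance between feature covariances."""
    mean_i, mean_j = f_i.mean(0, keepdim=True), f_j.mean(0, keepdim=True)
    cent_i, cent_j = f_i - mean_i, f_j - mean_j
    cov_i = cent_i.t() @ cent_i / (f_i.size(0) - 1)
    cov_j = cent_j.t() @ cent_j / (f_j.size(0) - 1)
    return (mean_i - mean_j).pow(2).mean() + (cov_i - cov_j).pow(2).mean()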
Example #12
def train(train_source_iter: ForeverDataIterator,
          train_target_iter: ForeverDataIterator, G: nn.Module,
          F1: ImageClassifierHead, F2: ImageClassifierHead, optimizer_g: SGD,
          optimizer_f: SGD, epoch: int, args: argparse.Namespace):
    batch_time = AverageMeter('Time', ':3.1f')
    data_time = AverageMeter('Data', ':3.1f')
    losses = AverageMeter('Loss', ':3.2f')
    trans_losses = AverageMeter('Trans Loss', ':3.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    tgt_accs = AverageMeter('Tgt Acc', ':3.1f')

    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses, trans_losses, cls_accs, tgt_accs],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    G.train()
    F1.train()
    F2.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        x_s, labels_s = next(train_source_iter)
        x_t, labels_t = next(train_target_iter)

        x_s = x_s.to(device)
        x_t = x_t.to(device)
        labels_s = labels_s.to(device)
        labels_t = labels_t.to(device)
        x = torch.cat((x_s, x_t), dim=0)
        assert x.requires_grad is False

        # measure data loading time
        data_time.update(time.time() - end)

        # Step A train all networks to minimize loss on source domain
        optimizer_g.zero_grad()
        optimizer_f.zero_grad()

        g = G(x)
        y_1 = F1(g)
        y_2 = F2(g)
        y1_s, y1_t = y_1.chunk(2, dim=0)
        y2_s, y2_t = y_2.chunk(2, dim=0)

        y1_t, y2_t = F.softmax(y1_t, dim=1), F.softmax(y2_t, dim=1)
        loss = F.cross_entropy(y1_s, labels_s) + F.cross_entropy(y2_s, labels_s) + \
               0.01 * (entropy(y1_t) + entropy(y2_t))
        loss.backward()
        optimizer_g.step()
        optimizer_f.step()

        # Step B train classifier to maximize discrepancy
        optimizer_g.zero_grad()
        optimizer_f.zero_grad()

        g = G(x)
        y_1 = F1(g)
        y_2 = F2(g)
        y1_s, y1_t = y_1.chunk(2, dim=0)
        y2_s, y2_t = y_2.chunk(2, dim=0)
        y1_t, y2_t = F.softmax(y1_t, dim=1), F.softmax(y2_t, dim=1)
        loss = F.cross_entropy(y1_s, labels_s) + F.cross_entropy(y2_s, labels_s) + \
               0.01 * (entropy(y1_t) + entropy(y2_t)) - classifier_discrepancy(y1_t, y2_t) * args.trade_off
        loss.backward()
        optimizer_f.step()

        # Step C train generator to minimize discrepancy
        for k in range(args.num_k):
            optimizer_g.zero_grad()
            g = G(x)
            y_1 = F1(g)
            y_2 = F2(g)
            y1_s, y1_t = y_1.chunk(2, dim=0)
            y2_s, y2_t = y_2.chunk(2, dim=0)
            y1_t, y2_t = F.softmax(y1_t, dim=1), F.softmax(y2_t, dim=1)
            mcd_loss = classifier_discrepancy(y1_t, y2_t) * args.trade_off
            mcd_loss.backward()
            optimizer_g.step()

        cls_acc = accuracy(y1_s, labels_s)[0]
        tgt_acc = accuracy(y1_t, labels_t)[0]

        losses.update(loss.item(), x_s.size(0))
        cls_accs.update(cls_acc.item(), x_s.size(0))
        tgt_accs.update(tgt_acc.item(), x_t.size(0))
        trans_losses.update(mcd_loss.item(), x_s.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
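In the MCD example above, `classifier_discrepancy` is assumed to be the mean absolute difference between the two classifiers' softmax outputs, while the 0.01-weighted `entropy` terms act as a small regularizer on the target predictions. A minimal sketch of the discrepancy under that assumption:

def classifier_discrepancy_sketch(p1, p2):
    """Mean absolute difference between two batches of class-probability vectors."""
    return (p1 - p2).abs().mean()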
def train(train_source_iter: ForeverDataIterator,
          train_target_iter: ForeverDataIterator, model: Regressor,
          domain_adv: DomainAdversarialLoss, optimizer: SGD,
          lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace):
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    mse_losses = AverageMeter('MSE Loss', ':6.3f')
    dann_losses = AverageMeter('DANN Loss', ':6.3f')
    domain_accs = AverageMeter('Domain Acc', ':3.1f')
    mae_losses_s = AverageMeter('MAE Loss (s)', ':6.3f')
    mae_losses_t = AverageMeter('MAE Loss (t)', ':6.3f')

    progress = ProgressMeter(args.iters_per_epoch, [
        batch_time, data_time, mse_losses, dann_losses, mae_losses_s,
        mae_losses_t, domain_accs
    ],
                             prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()
    domain_adv.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        optimizer.zero_grad()

        x_s, labels_s = next(train_source_iter)
        x_s = x_s.to(device)
        labels_s = labels_s.to(device).float()
        x_t, labels_t = next(train_target_iter)
        x_t = x_t.to(device)
        labels_t = labels_t.to(device).float()

        # measure data loading time
        data_time.update(time.time() - end)

        # compute output
        y_s, f_s = model(x_s)
        y_t, f_t = model(x_t)

        mse_loss = F.mse_loss(y_s, labels_s)
        mae_loss_s = F.l1_loss(y_s, labels_s)
        mae_loss_t = F.l1_loss(y_t, labels_t)
        transfer_loss = domain_adv(f_s, f_t)
        loss = mse_loss + transfer_loss * args.trade_off
        domain_acc = domain_adv.domain_discriminator_accuracy

        mse_losses.update(mse_loss.item(), x_s.size(0))
        dann_losses.update(transfer_loss.item(), x_s.size(0))
        mae_losses_s.update(mae_loss_s.item(), x_s.size(0))
        mae_losses_t.update(mae_loss_t.item(), x_s.size(0))
        domain_accs.update(domain_acc.item(), x_s.size(0))

        # compute gradient and do SGD step
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
Example #14
def train(train_source_iter: ForeverDataIterator,
          train_target_iter: ForeverDataIterator, netG_S2T, netG_T2S, netD_S,
          netD_T, siamese_net: spgan.SiameseNetwork,
          criterion_gan: cyclegan.LeastSquaresGenerativeAdversarialLoss,
          criterion_cycle: nn.L1Loss, criterion_identity: nn.L1Loss,
          criterion_contrastive: spgan.ContrastiveLoss, optimizer_G: Adam,
          optimizer_D: Adam, optimizer_siamese: Adam, fake_S_pool: ImagePool,
          fake_T_pool: ImagePool, epoch: int, visualize,
          args: argparse.Namespace):
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    losses_G_S2T = AverageMeter('G_S2T', ':3.2f')
    losses_G_T2S = AverageMeter('G_T2S', ':3.2f')
    losses_D_S = AverageMeter('D_S', ':3.2f')
    losses_D_T = AverageMeter('D_T', ':3.2f')
    losses_cycle_S = AverageMeter('cycle_S', ':3.2f')
    losses_cycle_T = AverageMeter('cycle_T', ':3.2f')
    losses_identity_S = AverageMeter('idt_S', ':3.2f')
    losses_identity_T = AverageMeter('idt_T', ':3.2f')
    losses_contrastive_G = AverageMeter('contrastive_G', ':3.2f')
    losses_contrastive_siamese = AverageMeter('contrastive_siamese', ':3.2f')

    progress = ProgressMeter(args.iters_per_epoch, [
        batch_time, data_time, losses_G_S2T, losses_G_T2S, losses_D_S,
        losses_D_T, losses_cycle_S, losses_cycle_T, losses_identity_S,
        losses_identity_T, losses_contrastive_G, losses_contrastive_siamese
    ],
                             prefix="Epoch: [{}]".format(epoch))

    end = time.time()

    for i in range(args.iters_per_epoch):
        real_S, _, _, _ = next(train_source_iter)
        real_T, _, _, _ = next(train_target_iter)

        real_S = real_S.to(device)
        real_T = real_T.to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # Compute fake images and reconstruction images.
        fake_T = netG_S2T(real_S)
        rec_S = netG_T2S(fake_T)
        fake_S = netG_T2S(real_T)
        rec_T = netG_S2T(fake_S)

        # ===============================================
        # train the generators (every other iteration)
        # ===============================================
        if i % 2 == 0:
            # save memory
            set_requires_grad(netD_S, False)
            set_requires_grad(netD_T, False)
            set_requires_grad(siamese_net, False)
            # GAN loss D_T(G_S2T(S))
            loss_G_S2T = criterion_gan(netD_T(fake_T), real=True)
            # GAN loss D_S(G_T2S(B))
            loss_G_T2S = criterion_gan(netD_S(fake_S), real=True)
            # Cycle loss || G_T2S(G_S2T(S)) - S||
            loss_cycle_S = criterion_cycle(rec_S,
                                           real_S) * args.trade_off_cycle
            # Cycle loss || G_S2T(G_T2S(T)) - T||
            loss_cycle_T = criterion_cycle(rec_T,
                                           real_T) * args.trade_off_cycle
            # Identity loss
            # G_S2T should be identity if real_T is fed: ||G_S2T(real_T) - real_T||
            identity_T = netG_S2T(real_T)
            loss_identity_T = criterion_identity(
                identity_T, real_T) * args.trade_off_identity
            # G_T2S should be identity if real_S is fed: ||G_T2S(real_S) - real_S||
            identity_S = netG_T2S(real_S)
            loss_identity_S = criterion_identity(
                identity_S, real_S) * args.trade_off_identity

            # siamese network output
            f_real_S = siamese_net(real_S)
            f_fake_T = siamese_net(fake_T)
            f_real_T = siamese_net(real_T)
            f_fake_S = siamese_net(fake_S)

            # positive pair
            loss_contrastive_p_G = criterion_contrastive(f_real_S, f_fake_T, 0) + \
                                   criterion_contrastive(f_real_T, f_fake_S, 0)
            # negative pair
            loss_contrastive_n_G = criterion_contrastive(f_fake_T, f_real_T, 1) + \
                                   criterion_contrastive(f_fake_S, f_real_S, 1) + \
                                   criterion_contrastive(f_real_S, f_real_T, 1)
            # contrastive loss
            loss_contrastive_G = (
                loss_contrastive_p_G +
                0.5 * loss_contrastive_n_G) / 4 * args.trade_off_contrastive

            # combined loss and calculate gradients
            loss_G = loss_G_S2T + loss_G_T2S + loss_cycle_S + loss_cycle_T + loss_identity_S + loss_identity_T
            if epoch > 1:
                loss_G += loss_contrastive_G
            netG_S2T.zero_grad()
            netG_T2S.zero_grad()
            loss_G.backward()
            optimizer_G.step()

            # update corresponding statistics
            losses_G_S2T.update(loss_G_S2T.item(), real_S.size(0))
            losses_G_T2S.update(loss_G_T2S.item(), real_S.size(0))
            losses_cycle_S.update(loss_cycle_S.item(), real_S.size(0))
            losses_cycle_T.update(loss_cycle_T.item(), real_S.size(0))
            losses_identity_S.update(loss_identity_S.item(), real_S.size(0))
            losses_identity_T.update(loss_identity_T.item(), real_S.size(0))
            if epoch > 1:
                losses_contrastive_G.update(loss_contrastive_G.item(), real_S.size(0))

        # ===============================================
        # train the siamese network (when epoch > 0)
        # ===============================================
        if epoch > 0:
            set_requires_grad(siamese_net, True)
            # siamese network output
            f_real_S = siamese_net(real_S)
            f_fake_T = siamese_net(fake_T.detach())
            f_real_T = siamese_net(real_T)
            f_fake_S = siamese_net(fake_S.detach())

            # positive pair
            loss_contrastive_p_siamese = criterion_contrastive(f_real_S, f_fake_T, 0) + \
                                         criterion_contrastive(f_real_T, f_fake_S, 0)
            # negative pair
            loss_contrastive_n_siamese = criterion_contrastive(
                f_real_S, f_real_T, 1)
            # contrastive loss
            loss_contrastive_siamese = (loss_contrastive_p_siamese +
                                        2 * loss_contrastive_n_siamese) / 3

            # update siamese network
            siamese_net.zero_grad()
            loss_contrastive_siamese.backward()
            optimizer_siamese.step()

            # update corresponding statistics
            losses_contrastive_siamese.update(loss_contrastive_siamese.item(),
                                              real_S.size(0))

        # ===============================================
        # train the discriminators
        # ===============================================

        set_requires_grad(netD_S, True)
        set_requires_grad(netD_T, True)
        # Calculate GAN loss for discriminator D_S
        fake_S_ = fake_S_pool.query(fake_S.detach())
        loss_D_S = 0.5 * (criterion_gan(netD_S(real_S), True) +
                          criterion_gan(netD_S(fake_S_), False))
        # Calculate GAN loss for discriminator D_T
        fake_T_ = fake_T_pool.query(fake_T.detach())
        loss_D_T = 0.5 * (criterion_gan(netD_T(real_T), True) +
                          criterion_gan(netD_T(fake_T_), False))

        # update discriminators
        netD_S.zero_grad()
        netD_T.zero_grad()
        loss_D_S.backward()
        loss_D_T.backward()
        optimizer_D.step()

        # update corresponding statistics
        losses_D_S.update(loss_D_S.item(), real_S.size(0))
        losses_D_T.update(loss_D_T.item(), real_S.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)

            for tensor, name in zip([
                    real_S, real_T, fake_S, fake_T, rec_S, rec_T, identity_S,
                    identity_T
            ], [
                    "real_S", "real_T", "fake_S", "fake_T", "rec_S", "rec_T",
                    "identity_S", "identity_T"
            ]):
                visualize(tensor[0], "{}_{}".format(i, name))
Example #15
def train(train_source_iter: ForeverDataIterator,
          train_target_iter: ForeverDataIterator, model, interp, criterion,
          dann, optimizer: SGD, lr_scheduler: LambdaLR, optimizer_d: SGD,
          lr_scheduler_d: LambdaLR, epoch: int, visualize,
          args: argparse.Namespace):
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    losses_s = AverageMeter('Loss (s)', ':3.2f')
    losses_transfer = AverageMeter('Loss (transfer)', ':3.2f')
    losses_discriminator = AverageMeter('Loss (discriminator)', ':3.2f')
    accuracies_s = Meter('Acc (s)', ':3.2f')
    accuracies_t = Meter('Acc (t)', ':3.2f')
    iou_s = Meter('IoU (s)', ':3.2f')
    iou_t = Meter('IoU (t)', ':3.2f')

    confmat_s = ConfusionMatrix(model.num_classes)
    confmat_t = ConfusionMatrix(model.num_classes)
    progress = ProgressMeter(args.iters_per_epoch, [
        batch_time, data_time, losses_s, losses_transfer, losses_discriminator,
        accuracies_s, accuracies_t, iou_s, iou_t
    ],
                             prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    end = time.time()

    for i in range(args.iters_per_epoch):
        x_s, label_s = next(train_source_iter)
        x_t, label_t = next(train_target_iter)

        x_s = x_s.to(device)
        label_s = label_s.long().to(device)
        x_t = x_t.to(device)
        label_t = label_t.long().to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        optimizer.zero_grad()
        optimizer_d.zero_grad()

        # Step 1: Train the segmentation network, freeze the discriminator
        dann.eval()
        y_s = model(x_s)
        pred_s = interp(y_s)
        loss_cls_s = criterion(pred_s, label_s)
        loss_cls_s.backward()

        # adversarial training to fool the discriminator
        y_t = model(x_t)
        pred_t = interp(y_t)
        loss_transfer = dann(pred_t, 'source')
        (loss_transfer * args.trade_off).backward()

        # Step 2: Train the discriminator
        dann.train()
        loss_discriminator = 0.5 * (dann(pred_s.detach(), 'source') +
                                    dann(pred_t.detach(), 'target'))
        loss_discriminator.backward()

        # compute gradient and do SGD step
        optimizer.step()
        optimizer_d.step()
        lr_scheduler.step()
        lr_scheduler_d.step()

        # measure accuracy and record loss
        losses_s.update(loss_cls_s.item(), x_s.size(0))
        losses_transfer.update(loss_transfer.item(), x_s.size(0))
        losses_discriminator.update(loss_discriminator.item(), x_s.size(0))

        confmat_s.update(label_s.flatten(), pred_s.argmax(1).flatten())
        confmat_t.update(label_t.flatten(), pred_t.argmax(1).flatten())
        acc_global_s, acc_s, iu_s = confmat_s.compute()
        acc_global_t, acc_t, iu_t = confmat_t.compute()
        accuracies_s.update(acc_s.mean().item())
        accuracies_t.update(acc_t.mean().item())
        iou_s.update(iu_s.mean().item())
        iou_t.update(iu_t.mean().item())

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)

            if visualize is not None:
                visualize(x_s[0], pred_s[0], label_s[0], "source_{}".format(i))
                visualize(x_t[0], pred_t[0], label_t[0], "target_{}".format(i))
Example #16
def calculate_channel_attention(dataset, return_layers, args):
    backbone = models.__dict__[args.arch](pretrained=True)
    classifier = Classifier(backbone, dataset.num_classes).to(device)
    optimizer = SGD(classifier.get_parameters(args.lr),
                    momentum=args.momentum,
                    weight_decay=args.wd,
                    nesterov=True)
    data_loader = DataLoader(dataset,
                             batch_size=args.attention_batch_size,
                             shuffle=True,
                             num_workers=args.workers,
                             drop_last=False)
    lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(
        optimizer,
        gamma=math.exp(math.log(0.1) / args.attention_lr_decay_epochs))
    criterion = nn.CrossEntropyLoss()

    channel_weights = []
    for layer_id, name in enumerate(return_layers):
        layer = get_attribute(classifier, name)
        layer_channel_weight = [0] * layer.out_channels
        channel_weights.append(layer_channel_weight)

    # train the classifier
    classifier.train()
    classifier.backbone.requires_grad = False
    print("Pretrain a classifier to calculate channel attention.")
    for epoch in range(args.attention_epochs):
        losses = AverageMeter('Loss', ':3.2f')
        cls_accs = AverageMeter('Cls Acc', ':3.1f')
        progress = ProgressMeter(len(data_loader), [losses, cls_accs],
                                 prefix="Epoch: [{}]".format(epoch))

        for i, data in enumerate(data_loader):
            inputs, labels = data
            inputs = inputs.to(device)
            labels = labels.to(device)
            outputs, _ = classifier(inputs)
            loss = criterion(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            cls_acc = accuracy(outputs, labels)[0]

            losses.update(loss.item(), inputs.size(0))
            cls_accs.update(cls_acc.item(), inputs.size(0))

            if i % args.print_freq == 0:
                progress.display(i)
        lr_scheduler.step()

    # calculate the channel attention
    print('Calculating channel attention.')
    classifier.eval()
    if args.attention_iteration_limit > 0:
        total_iteration = min(len(data_loader), args.attention_iteration_limit)
    else:
        total_iteration = len(data_loader)

    progress = ProgressMeter(total_iteration, [], prefix="Iteration: ")

    for i, data in enumerate(data_loader):
        if i >= total_iteration:
            break
        inputs, labels = data
        inputs = inputs.to(device)
        labels = labels.to(device)
        outputs, _ = classifier(inputs)
        loss_0 = criterion(outputs, labels)
        progress.display(i)
        for layer_id, name in enumerate(tqdm(return_layers)):
            layer = get_attribute(classifier, name)
            for j in range(layer.out_channels):
                tmp = classifier.state_dict()[name + '.weight'][j, ].clone()
                classifier.state_dict()[name + '.weight'][j, ] = 0.0
                outputs, _ = classifier(inputs)
                loss_1 = criterion(outputs, labels)
                difference = loss_1 - loss_0
                difference = difference.detach().cpu().numpy().item()
                history_value = channel_weights[layer_id][j]
                channel_weights[layer_id][j] = 1.0 * (i * history_value +
                                                      difference) / (i + 1)
                classifier.state_dict()[name + '.weight'][j, ] = tmp

    channel_attention = []
    for weight in channel_weights:
        weight = np.array(weight)
        weight = (weight - np.mean(weight)) / np.std(weight)
        weight = torch.from_numpy(weight).float().to(device)
        channel_attention.append(F.softmax(weight / 5, dim=0).detach())
    return channel_attention
Example #17
def train(train_iter: ForeverDataIterator, model: Classifier, optimizer,
          lr_scheduler: CosineAnnealingLR, n_domains_per_batch: int,
          epoch: int, args: argparse.Namespace):
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    losses = AverageMeter('Loss', ':3.2f')
    losses_ce = AverageMeter('CELoss', ':3.2f')
    losses_penalty = AverageMeter('Penalty Loss', ':3.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')

    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses, losses_ce, losses_penalty, cls_accs],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        x_all, labels_all, _ = next(train_iter)
        x_all = x_all.to(device)
        labels_all = labels_all.to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # compute output
        y_all, _ = model(x_all)

        loss_ce_per_domain = torch.zeros(n_domains_per_batch).to(device)
        for domain_id, (y_per_domain, labels_per_domain) in enumerate(
                zip(y_all.chunk(n_domains_per_batch, dim=0),
                    labels_all.chunk(n_domains_per_batch, dim=0))):
            loss_ce_per_domain[domain_id] = F.cross_entropy(
                y_per_domain, labels_per_domain)

        # cls loss
        loss_ce = loss_ce_per_domain.mean()
        # penalty loss
        loss_penalty = ((loss_ce_per_domain - loss_ce)**2).mean()

        global_iter = epoch * args.iters_per_epoch + i
        if global_iter >= args.anneal_iters:
            trade_off = args.trade_off
        else:
            trade_off = 1

        loss = loss_ce + loss_penalty * trade_off
        cls_acc = accuracy(y_all, labels_all)[0]

        losses.update(loss.item(), x_all.size(0))
        losses_ce.update(loss_ce.item(), x_all.size(0))
        losses_penalty.update(loss_penalty.item(), x_all.size(0))
        cls_accs.update(cls_acc.item(), x_all.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
Example #18
def train(train_source_iter: ForeverDataIterator,
          train_target_iter: ForeverDataIterator, model,
          mdd: MarginDisparityDiscrepancy, optimizer: Adam,
          lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace):
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    losses = AverageMeter('Loss', ':3.2f')
    trans_losses = AverageMeter('Trans Loss', ':3.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    tgt_accs = AverageMeter('Tgt Acc', ':3.1f')
    cls_adv_accs = AverageMeter('Cls Adv Acc', ':3.1f')
    tgt_adv_accs = AverageMeter('Tgt Adv Acc', ':3.1f')

    progress = ProgressMeter(args.iters_per_epoch, [
        batch_time, data_time, losses, trans_losses, cls_accs, tgt_accs,
        cls_adv_accs, tgt_adv_accs
    ],
                             prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()
    mdd.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        x_s, labels_s = next(train_source_iter)
        x_t, labels_t = next(train_target_iter)

        x_s = x_s.to(device)
        x_t = x_t.to(device)
        labels_s = labels_s.to(device)
        labels_t = labels_t.to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # compute output
        x = torch.cat((x_s, x_t), dim=0)
        outputs, outputs_adv = model(x)
        y_s, y_t = outputs.chunk(2, dim=0)
        y_s_adv, y_t_adv = outputs_adv.chunk(2, dim=0)

        # compute cross entropy loss on source domain
        cls_loss = F.cross_entropy(y_s, labels_s)
        # compute margin disparity discrepancy between domains
        # for the adversarial classifier, minimizing the negative MDD is equivalent to maximizing the MDD
        transfer_loss = -mdd(y_s, y_s_adv, y_t, y_t_adv)
        loss = cls_loss + transfer_loss * args.trade_off
        model.step()

        cls_acc = accuracy(y_s, labels_s)[0]
        tgt_acc = accuracy(y_t, labels_t)[0]
        cls_adv_acc = accuracy(y_s_adv, labels_s)[0]
        tgt_adv_acc = accuracy(y_t_adv, labels_t)[0]

        losses.update(loss.item(), x_s.size(0))
        cls_accs.update(cls_acc.item(), x_s.size(0))
        tgt_accs.update(tgt_acc.item(), x_t.size(0))
        cls_adv_accs.update(cls_adv_acc.item(), x_s.size(0))
        tgt_adv_accs.update(tgt_adv_acc.item(), x_t.size(0))
        trans_losses.update(transfer_loss.item(), x_s.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # lr_scheduler.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
Example #19
def train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator,
          model, interp, criterion, optimizer: SGD,
          lr_scheduler: LambdaLR, epoch: int, visualize, args: argparse.Namespace):
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    losses_s = AverageMeter('Loss (s)', ':3.2f')
    losses_t = AverageMeter('Loss (t)', ':3.2f')
    losses_entropy_t = AverageMeter('Entropy (t)', ':3.2f')
    accuracies_s = Meter('Acc (s)', ':3.2f')
    accuracies_t = Meter('Acc (t)', ':3.2f')
    iou_s = Meter('IoU (s)', ':3.2f')
    iou_t = Meter('IoU (t)', ':3.2f')

    confmat_s = ConfusionMatrix(model.num_classes)
    confmat_t = ConfusionMatrix(model.num_classes)
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses_s, losses_t, losses_entropy_t,
         accuracies_s, accuracies_t, iou_s, iou_t],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        optimizer.zero_grad()

        x_s, label_s = next(train_source_iter)
        x_t, label_t = next(train_target_iter)

        x_s = x_s.to(device)
        label_s = label_s.long().to(device)
        x_t = x_t.to(device)
        label_t = label_t.long().to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # compute output
        y_s = model(x_s)
        pred_s = interp(y_s)
        loss_cls_s = criterion(pred_s, label_s)
        loss_cls_s.backward()

        y_t = model(x_t)
        pred_t = interp(y_t)
        loss_cls_t = criterion(pred_t, label_t)
        loss_entropy_t = robust_entropy(y_t, args.ita)
        (args.entropy_weight * loss_entropy_t).backward()

        # compute gradient and do SGD step
        optimizer.step()
        lr_scheduler.step()

        # measure accuracy and record loss
        losses_s.update(loss_cls_s.item(), x_s.size(0))
        losses_t.update(loss_cls_t.item(), x_t.size(0))
        losses_entropy_t.update(loss_entropy_t.item(), x_t.size(0))

        confmat_s.update(label_s.flatten(), pred_s.argmax(1).flatten())
        confmat_t.update(label_t.flatten(), pred_t.argmax(1).flatten())
        acc_global_s, acc_s, iu_s = confmat_s.compute()
        acc_global_t, acc_t, iu_t = confmat_t.compute()
        accuracies_s.update(acc_s.mean().item())
        accuracies_t.update(acc_t.mean().item())
        iou_s.update(iu_s.mean().item())
        iou_t.update(iu_t.mean().item())

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)

            if visualize is not None:
                visualize(x_s[0], pred_s[0], label_s[0], "source_{}".format(i))
                visualize(x_t[0], pred_t[0], label_t[0], "target_{}".format(i))
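
robust_entropy(y_t, args.ita) is supplied by the library and not shown here; the sketch below illustrates one plausible form of such a penalty, a Charbonnier-weighted pixel-wise entropy in the spirit of the FDA paper, with the normalization and epsilon chosen for illustration only.

import math

import torch
import torch.nn.functional as F


def robust_entropy_sketch(logits: torch.Tensor, ita: float = 2.0, eps: float = 1e-8) -> torch.Tensor:
    # logits: (N, C, H, W) segmentation scores
    num_classes = logits.shape[1]
    p = F.softmax(logits, dim=1)
    # per-pixel entropy, normalized to [0, 1] by log(C)
    ent = -(p * torch.log(p + eps)).sum(dim=1) / math.log(num_classes)
    # the exponent `ita` concentrates the penalty on high-entropy (uncertain) pixels
    ent = (ent ** 2 + eps) ** ita
    return ent.mean()
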
def train(train_source_iter: ForeverDataIterator,
          train_target_iter: ForeverDataIterator, model: ImageClassifier,
          domain_adv_D: DomainAdversarialLoss,
          domain_adv_D_0: DomainAdversarialLoss, importance_weight_module,
          optimizer: SGD, lr_scheduler: LambdaLR, epoch: int,
          args: argparse.Namespace):
    batch_time = AverageMeter('Time', ':5.2f')
    data_time = AverageMeter('Data', ':5.2f')
    losses = AverageMeter('Loss', ':6.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    tgt_accs = AverageMeter('Tgt Acc', ':3.1f')
    domain_accs_D = AverageMeter('Domain Acc for D', ':3.1f')
    domain_accs_D_0 = AverageMeter('Domain Acc for D_0', ':3.1f')
    partial_classes_weights = AverageMeter('Partial Weight', ':3.2f')
    non_partial_classes_weights = AverageMeter('Non-Partial Weight', ':3.2f')

    progress = ProgressMeter(args.iters_per_epoch, [
        batch_time, data_time, losses, cls_accs, tgt_accs, domain_accs_D,
        domain_accs_D_0, partial_classes_weights, non_partial_classes_weights
    ],
                             prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()
    domain_adv_D.train()
    domain_adv_D_0.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        x_s, labels_s = next(train_source_iter)
        x_t, labels_t = next(train_target_iter)

        x_s = x_s.to(device)
        x_t = x_t.to(device)
        labels_s = labels_s.to(device)
        labels_t = labels_t.to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # compute output
        x = torch.cat((x_s, x_t), dim=0)
        y, f = model(x)
        y_s, y_t = y.chunk(2, dim=0)
        f_s, f_t = f.chunk(2, dim=0)

        # classification loss
        cls_loss = F.cross_entropy(y_s, labels_s)

        # domain adversarial loss for D
        adv_loss_D = domain_adv_D(f_s.detach(), f_t.detach())

        # get importance weights
        w_s = importance_weight_module.get_importance_weight(f_s)
        # domain adversarial loss for D_0
        adv_loss_D_0 = domain_adv_D_0(f_s, f_t, w_s=w_s)

        # entropy loss
        y_t = F.softmax(y_t, dim=1)
        entropy_loss = entropy(y_t, reduction='mean')

        loss = cls_loss + 1.5 * args.trade_off * adv_loss_D + \
               args.trade_off * adv_loss_D_0 + args.gamma * entropy_loss

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        cls_acc = accuracy(y_s, labels_s)[0]
        tgt_acc = accuracy(y_t, labels_t)[0]

        losses.update(loss.item(), x_s.size(0))
        cls_accs.update(cls_acc.item(), x_s.size(0))
        tgt_accs.update(tgt_acc.item(), x_s.size(0))
        domain_accs_D.update(domain_adv_D.domain_discriminator_accuracy,
                             x_s.size(0))
        domain_accs_D_0.update(domain_adv_D_0.domain_discriminator_accuracy,
                               x_s.size(0))

        # for debugging: log the class weight averaged over the partial and non-partial classes respectively
        partial_class_weight, non_partial_classes_weight = \
            importance_weight_module.get_partial_classes_weight(w_s, labels_s)
        partial_classes_weights.update(partial_class_weight.item(),
                                       x_s.size(0))
        non_partial_classes_weights.update(non_partial_classes_weight.item(),
                                           x_s.size(0))

        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
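
importance_weight_module is opaque in this snippet; the sketch below shows one IWAN-style way such weights could be derived from the auxiliary discriminator D, with the weighting and normalization assumed rather than taken from the library.

import torch
import torch.nn as nn


class ImportanceWeightSketch(nn.Module):
    # hypothetical: `discriminator` maps a feature batch to raw logits for P(source | feature)

    def __init__(self, discriminator: nn.Module):
        super().__init__()
        self.discriminator = discriminator

    @torch.no_grad()
    def get_importance_weight(self, f_s: torch.Tensor) -> torch.Tensor:
        d_s = torch.sigmoid(self.discriminator(f_s))  # probability that the feature comes from the source domain
        w_s = 1.0 - d_s                               # samples D confidently marks as source-only get small weights
        w_s = w_s / (w_s.mean() + 1e-8)               # normalize to mean 1 over the batch
        return w_s.detach()
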
def train(train_source_iter: ForeverDataIterator,
          train_target_iter: ForeverDataIterator, model: ImageClassifier,
          teacher: EmaTeacher, cluster_assignment: ASoftmax,
          cluster_distribution: ClusterDistribution, consistent_loss,
          class_balance_loss, kl_loss: nn.KLDivLoss, conditional_loss,
          optimizer: SGD, lr_scheduler: LambdaLR, epoch: int,
          args: argparse.Namespace):

    batch_time = AverageMeter('Time', ':5.2f')
    data_time = AverageMeter('Data', ':5.2f')
    losses = AverageMeter('Loss', ':6.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    tgt_accs = AverageMeter('Tgt Acc', ':3.1f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses, cls_accs, tgt_accs],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        x_s, labels_s = next(train_source_iter)
        x_t, labels_t = next(train_target_iter)

        x_s = x_s.to(device)
        x_t = x_t.to(device)
        labels_s = labels_s.to(device)
        labels_t = labels_t.to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # compute output
        x = torch.cat((x_s, x_t), dim=0)
        y, f = model(x)
        y_s, y_t = y.chunk(2, dim=0)
        f_s, f_t = f.chunk(2, dim=0)

        cls_loss = F.cross_entropy(y_s, labels_s)
        cdt_loss = conditional_loss(y_t)
        kls_loss = kl_loss(cluster_assignment(f),
                           cluster_distribution.calculate(f))

        loss = cls_loss + cdt_loss + kls_loss

        cls_acc = accuracy(y_s, labels_s)[0]
        tgt_acc = accuracy(y_t, labels_t)[0]

        losses.update(loss.item(), x_s.size(0))
        cls_accs.update(cls_acc.item(), x_s.size(0))
        tgt_accs.update(tgt_acc.item(), x_s.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
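
The clustering terms above (cluster_assignment, cluster_distribution, kl_loss) come from modules not shown here; as a rough analogy, a DEC-style soft assignment and sharpened target distribution that a KL divergence could compare might look like the following sketch.

import torch


def soft_cluster_assignment(f: torch.Tensor, centers: torch.Tensor, alpha: float = 1.0) -> torch.Tensor:
    # Student's t-kernel soft assignment q_ij of features (B, D) to cluster centers (K, D)
    dist_sq = torch.cdist(f, centers) ** 2
    q = (1.0 + dist_sq / alpha) ** (-(alpha + 1.0) / 2.0)
    return q / q.sum(dim=1, keepdim=True)


def target_distribution(q: torch.Tensor) -> torch.Tensor:
    # sharpened target p_ij = q_ij^2 / f_j (renormalized); emphasizes confident assignments
    weight = q ** 2 / q.sum(dim=0, keepdim=True)
    return weight / weight.sum(dim=1, keepdim=True)


# a KL(p || q) term analogous to kl_loss(cluster_assignment(f), cluster_distribution.calculate(f))
# could then be written as (p * (p.log() - q.log())).sum(dim=1).mean()
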
def train(train_source_iter: ForeverDataIterator,
          train_target_iter: ForeverDataIterator, model: ImageClassifier,
          domain_discri: DomainDiscriminator,
          domain_adv: DomainAdversarialLoss, gl, optimizer: SGD,
          lr_scheduler: LambdaLR, optimizer_d: SGD, lr_scheduler_d: LambdaLR,
          epoch: int, args: argparse.Namespace):
    batch_time = AverageMeter('Time', ':5.2f')
    data_time = AverageMeter('Data', ':5.2f')
    losses_s = AverageMeter('Cls Loss', ':6.2f')
    losses_transfer = AverageMeter('Transfer Loss', ':6.2f')
    losses_discriminator = AverageMeter('Discriminator Loss', ':6.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    domain_accs = AverageMeter('Domain Acc', ':3.1f')
    progress = ProgressMeter(args.iters_per_epoch, [
        batch_time, data_time, losses_s, losses_transfer, losses_discriminator,
        cls_accs, domain_accs
    ],
                             prefix="Epoch: [{}]".format(epoch))

    end = time.time()
    for i in range(args.iters_per_epoch):
        x_s, labels_s = next(train_source_iter)
        x_t, _ = next(train_target_iter)

        x_s = x_s.to(device)
        x_t = x_t.to(device)
        labels_s = labels_s.to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # Step 1: Train the classifier, freeze the discriminator
        model.train()
        domain_discri.eval()
        set_requires_grad(model, True)
        set_requires_grad(domain_discri, False)
        x = torch.cat((x_s, x_t), dim=0)
        y, f = model(x)
        y_s, y_t = y.chunk(2, dim=0)
        loss_s = F.cross_entropy(y_s, labels_s)

        # adversarial training to fool the discriminator
        d = domain_discri(gl(f))
        d_s, d_t = d.chunk(2, dim=0)
        loss_transfer = 0.5 * (domain_adv(d_s, 'target') +
                               domain_adv(d_t, 'source'))

        optimizer.zero_grad()
        (loss_s + loss_transfer * args.trade_off).backward()
        optimizer.step()
        lr_scheduler.step()

        # Step 2: Train the discriminator
        model.eval()
        domain_discri.train()
        set_requires_grad(model, False)
        set_requires_grad(domain_discri, True)
        d = domain_discri(f.detach())
        d_s, d_t = d.chunk(2, dim=0)
        loss_discriminator = 0.5 * (domain_adv(d_s, 'source') +
                                    domain_adv(d_t, 'target'))

        optimizer_d.zero_grad()
        loss_discriminator.backward()
        optimizer_d.step()
        lr_scheduler_d.step()

        losses_s.update(loss_s.item(), x_s.size(0))
        losses_transfer.update(loss_transfer.item(), x_s.size(0))
        losses_discriminator.update(loss_discriminator.item(), x_s.size(0))

        cls_acc = accuracy(y_s, labels_s)[0]
        cls_accs.update(cls_acc.item(), x_s.size(0))
        domain_acc = 0.5 * (binary_accuracy(d_s, torch.ones_like(d_s)) +
                            binary_accuracy(d_t, torch.zeros_like(d_t)))
        domain_accs.update(domain_acc.item(), x_s.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
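
set_requires_grad is used above to freeze one player of the adversarial game while the other is updated; a typical implementation of such a helper is tiny and looks roughly like this.

import torch.nn as nn


def set_requires_grad(net: nn.Module, requires_grad: bool = False) -> None:
    # freeze or unfreeze every parameter of `net`
    for param in net.parameters():
        param.requires_grad = requires_grad
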
def train(train_source_iter, train_target_iter, netG_S2T, netG_T2S, netD_S,
          netD_T, criterion_gan, criterion_cycle, criterion_identity,
          optimizer_G, optimizer_D, fake_S_pool, fake_T_pool, epoch: int,
          visualize, args: argparse.Namespace):
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    losses_G_S2T = AverageMeter('G_S2T', ':3.2f')
    losses_G_T2S = AverageMeter('G_T2S', ':3.2f')
    losses_D_S = AverageMeter('D_S', ':3.2f')
    losses_D_T = AverageMeter('D_T', ':3.2f')
    losses_cycle_S = AverageMeter('cycle_S', ':3.2f')
    losses_cycle_T = AverageMeter('cycle_T', ':3.2f')
    losses_identity_S = AverageMeter('idt_S', ':3.2f')
    losses_identity_T = AverageMeter('idt_T', ':3.2f')

    progress = ProgressMeter(args.iters_per_epoch, [
        batch_time, data_time, losses_G_S2T, losses_G_T2S, losses_D_S,
        losses_D_T, losses_cycle_S, losses_cycle_T, losses_identity_S,
        losses_identity_T
    ],
                             prefix="Epoch: [{}]".format(epoch))

    end = time.time()

    for i in range(args.iters_per_epoch):
        real_S, _ = next(train_source_iter)
        real_T, _ = next(train_target_iter)

        real_S = real_S.to(device)
        real_T = real_T.to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # Compute fake images and reconstruction images.
        fake_T = netG_S2T(real_S)
        rec_S = netG_T2S(fake_T)
        fake_S = netG_T2S(real_T)
        rec_T = netG_S2T(fake_S)

        # Optimize the generators
        # discriminators require no gradients
        set_requires_grad(netD_S, False)
        set_requires_grad(netD_T, False)

        optimizer_G.zero_grad()
        # GAN loss D_T(G_S2T(S))
        loss_G_S2T = criterion_gan(netD_T(fake_T), real=True)
        # GAN loss D_S(G_T2S(T))
        loss_G_T2S = criterion_gan(netD_S(fake_S), real=True)
        # Cycle loss || G_T2S(G_S2T(S)) - S||
        loss_cycle_S = criterion_cycle(rec_S, real_S) * args.trade_off_cycle
        # Cycle loss || G_S2T(G_T2S(T)) - T||
        loss_cycle_T = criterion_cycle(rec_T, real_T) * args.trade_off_cycle
        # Identity loss
        # G_S2T should be identity if real_T is fed: ||G_S2T(real_T) - real_T||
        identity_T = netG_S2T(real_T)
        loss_identity_T = criterion_identity(identity_T,
                                             real_T) * args.trade_off_identity
        # G_T2S should be identity if real_S is fed: ||G_T2S(real_S) - real_S||
        identity_S = netG_T2S(real_S)
        loss_identity_S = criterion_identity(identity_S,
                                             real_S) * args.trade_off_identity
        # combined loss and calculate gradients
        loss_G = loss_G_S2T + loss_G_T2S + loss_cycle_S + loss_cycle_T + loss_identity_S + loss_identity_T
        loss_G.backward()
        optimizer_G.step()

        # Optimize the discriminators
        set_requires_grad(netD_S, True)
        set_requires_grad(netD_T, True)
        optimizer_D.zero_grad()
        # Calculate GAN loss for discriminator D_S
        fake_S_ = fake_S_pool.query(fake_S.detach())
        loss_D_S = 0.5 * (criterion_gan(netD_S(real_S), True) +
                          criterion_gan(netD_S(fake_S_), False))
        loss_D_S.backward()
        # Calculate GAN loss for discriminator D_T
        fake_T_ = fake_T_pool.query(fake_T.detach())
        loss_D_T = 0.5 * (criterion_gan(netD_T(real_T), True) +
                          criterion_gan(netD_T(fake_T_), False))
        loss_D_T.backward()
        optimizer_D.step()

        # record losses and measure elapsed time
        losses_G_S2T.update(loss_G_S2T.item(), real_S.size(0))
        losses_G_T2S.update(loss_G_T2S.item(), real_S.size(0))
        losses_D_S.update(loss_D_S.item(), real_S.size(0))
        losses_D_T.update(loss_D_T.item(), real_S.size(0))
        losses_cycle_S.update(loss_cycle_S.item(), real_S.size(0))
        losses_cycle_T.update(loss_cycle_T.item(), real_S.size(0))
        losses_identity_S.update(loss_identity_S.item(), real_S.size(0))
        losses_identity_T.update(loss_identity_T.item(), real_S.size(0))
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)

            for tensor, name in zip([
                    real_S, real_T, fake_S, fake_T, rec_S, rec_T, identity_S,
                    identity_T
            ], [
                    "real_S", "real_T", "fake_S", "fake_T", "rec_S", "rec_T",
                    "identity_S", "identity_T"
            ]):
                visualize(tensor[0], "{}_{}".format(i, name))
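
fake_S_pool and fake_T_pool are image history buffers in the CycleGAN sense; a minimal sketch of such a pool follows, with the usual buffer size of 50 and a 50/50 replacement policy stated here as assumptions.

import random

import torch


class ImagePoolSketch:
    # history buffer: with probability 0.5 the discriminator is shown an older
    # generated image instead of the newest one, which reduces oscillation

    def __init__(self, pool_size: int = 50):
        self.pool_size = pool_size
        self.images = []

    def query(self, images: torch.Tensor) -> torch.Tensor:
        if self.pool_size == 0:
            return images
        out = []
        for image in images:
            image = image.unsqueeze(0)
            if len(self.images) < self.pool_size:
                self.images.append(image)
                out.append(image)
            elif random.random() < 0.5:
                idx = random.randrange(self.pool_size)
                out.append(self.images[idx].clone())
                self.images[idx] = image
            else:
                out.append(image)
        return torch.cat(out, dim=0)
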
def train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator,
          model: ImageClassifier, domain_adv: DomainAdversarialLoss, optimizer: SGD,
          lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace):
    batch_time = AverageMeter('Time', ':5.2f')
    data_time = AverageMeter('Data', ':5.2f')
    losses = AverageMeter('Loss', ':6.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    tgt_accs = AverageMeter('Tgt Acc', ':3.1f')
    domain_accs = AverageMeter('Domain Acc', ':3.1f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses, cls_accs, tgt_accs, domain_accs],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()
    domain_adv.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        x_s, labels_s = next(train_source_iter)
        x_t, labels_t = next(train_target_iter)

        x_s = x_s.to(device)
        x_t = x_t.to(device)
        labels_s = labels_s.to(device)
        labels_t = labels_t.to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # compute output
        x = torch.cat((x_s, x_t), dim=0)
        y, f = model(x)
        y_s, y_t = y.chunk(2, dim=0)
        f_s, f_t = f.chunk(2, dim=0)

        cls_loss = F.cross_entropy(y_s, labels_s)
        transfer_loss = domain_adv(f_s, f_t)
        domain_acc = domain_adv.domain_discriminator_accuracy
        loss = cls_loss + transfer_loss * args.trade_off

        cls_acc = accuracy(y_s, labels_s)[0]
        tgt_acc = accuracy(y_t, labels_t)[0]

        losses.update(loss.item(), x_s.size(0))
        cls_accs.update(cls_acc.item(), x_s.size(0))
        tgt_accs.update(tgt_acc.item(), x_s.size(0))
        domain_accs.update(domain_acc.item(), x_s.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
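
DomainAdversarialLoss hides the domain discriminator and the gradient reversal; the sketch below condenses the underlying idea (binary domain classification through a gradient-reversal layer), with the GradReverse function and the loss shape assumed for illustration.

import torch
import torch.nn as nn
import torch.nn.functional as F


class GradReverse(torch.autograd.Function):
    # identity in the forward pass, negated (scaled) gradient in the backward pass

    @staticmethod
    def forward(ctx, x, coeff):
        ctx.coeff = coeff
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        return -ctx.coeff * grad_output, None


def domain_adversarial_loss(discriminator: nn.Module, f_s: torch.Tensor,
                            f_t: torch.Tensor, coeff: float = 1.0) -> torch.Tensor:
    # binary domain classification; the reversed gradient pushes the feature
    # extractor toward domain-invariant representations
    f = GradReverse.apply(torch.cat((f_s, f_t), dim=0), coeff)
    d = discriminator(f)
    d_s, d_t = d.chunk(2, dim=0)
    return 0.5 * (F.binary_cross_entropy_with_logits(d_s, torch.ones_like(d_s)) +
                  F.binary_cross_entropy_with_logits(d_t, torch.zeros_like(d_t)))
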
def train(train_iter: ForeverDataIterator, model: Classifier, optimizer,
          lr_scheduler: CosineAnnealingLR, epoch: int,
          n_domains_per_batch: int, args: argparse.Namespace):
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    losses = AverageMeter('Loss', ':3.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')

    progress = ProgressMeter(args.iters_per_epoch,
                             [batch_time, data_time, losses, cls_accs],
                             prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        x, labels, _ = next(train_iter)
        x = x.to(device)
        labels = labels.to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # split into support domain and query domain
        x_list = x.chunk(n_domains_per_batch, dim=0)
        labels_list = labels.chunk(n_domains_per_batch, dim=0)
        support_domain_list, query_domain_list = random_split(
            x_list, labels_list, n_domains_per_batch, args.n_support_domains)
        # clear grad
        optimizer.zero_grad()

        # compute output
        with higher.innerloop_ctx(
                model, optimizer,
                copy_initial_weights=False) as (inner_model, inner_optimizer):
            # perform inner optimization
            for _ in range(args.inner_iters):
                loss_inner = 0
                for (x_s, labels_s) in support_domain_list:
                    y_s, _ = inner_model(x_s)
                    # normalize loss by support domain num
                    loss_inner += F.cross_entropy(
                        y_s, labels_s) / args.n_support_domains

                inner_optimizer.step(loss_inner)

            # calculate outer loss
            loss_outer = 0
            cls_acc = 0

            # loss on support domains
            for (x_s, labels_s) in support_domain_list:
                y_s, _ = model(x_s)
                # normalize loss by support domain num
                loss_outer += F.cross_entropy(
                    y_s, labels_s) / args.n_support_domains

            # loss on query domains
            for (x_q, labels_q) in query_domain_list:
                y_q, _ = inner_model(x_q)
                # normalize loss by query domain num
                loss_outer += F.cross_entropy(
                    y_q, labels_q) * args.trade_off / args.n_query_domains
                cls_acc += accuracy(y_q, labels_q)[0] / args.n_query_domains

        # update statistics
        losses.update(loss_outer.item(), args.batch_size)
        cls_accs.update(cls_acc.item(), args.batch_size)

        # compute gradient and do SGD step
        loss_outer.backward()
        optimizer.step()
        lr_scheduler.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
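
random_split above partitions the domains in a batch into a support (meta-train) and a query (meta-test) split; the helper below is a sketch of one plausible implementation that mirrors the call random_split(x_list, labels_list, n_domains_per_batch, args.n_support_domains), not the actual library code.

import random
from typing import List, Tuple

import torch


def random_split_sketch(x_list: Tuple[torch.Tensor, ...], labels_list: Tuple[torch.Tensor, ...],
                        n_domains_per_batch: int, n_support_domains: int
                        ) -> Tuple[List[Tuple[torch.Tensor, torch.Tensor]],
                                   List[Tuple[torch.Tensor, torch.Tensor]]]:
    # pick `n_support_domains` domains at random as the support (meta-train) split;
    # the remaining domains form the query (meta-test) split
    support_idx = set(random.sample(range(n_domains_per_batch), n_support_domains))
    support, query = [], []
    for idx in range(n_domains_per_batch):
        pair = (x_list[idx], labels_list[idx])
        (support if idx in support_idx else query).append(pair)
    return support, query
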