Ejemplo n.º 1
0
def train(train_source_iter: ForeverDataIterator,
          train_target_iter: ForeverDataIterator, model: ImageClassifier,
          adaptive_feature_norm: AdaptiveFeatureNorm, optimizer: SGD,
          epoch: int, args: argparse.Namespace):
    """Run one training epoch with a feature-norm penalty on both domains.

    For ``args.iters_per_epoch`` iterations, one source batch and one target
    batch are drawn and the following loss is minimized::

        cross_entropy(source logits, source labels)
          + args.trade_off_norm * (adaptive_feature_norm(f_s) + adaptive_feature_norm(f_t))
          [+ args.trade_off_entropy * entropy(softmax(target logits))]

    The entropy term is only added when ``args.trade_off_entropy`` is truthy.
    Target labels are used only to report target accuracy; they never enter
    the loss.

    Args:
        train_source_iter: iterator yielding ``(inputs, labels)`` source
            batches; ``next()`` is called once per iteration.
        train_target_iter: iterator yielding ``(inputs, labels)`` target
            batches; ``next()`` is called once per iteration.
        model: network whose call returns ``(logits, features)``.
        adaptive_feature_norm: callable mapping a feature batch to a scalar
            loss tensor (its exact formulation is defined elsewhere).
        optimizer: optimizer stepped once per iteration.
        epoch: epoch index, used only in the progress prefix.
        args: reads ``iters_per_epoch``, ``trade_off_norm``,
            ``trade_off_entropy`` and ``print_freq``.
    """
    batch_time = AverageMeter('Time', ':3.1f')
    data_time = AverageMeter('Data', ':3.1f')
    cls_losses = AverageMeter('Cls Loss', ':3.2f')
    norm_losses = AverageMeter('Norm Loss', ':3.2f')
    src_feature_norm = AverageMeter('Source Feature Norm', ':3.2f')
    tgt_feature_norm = AverageMeter('Target Feature Norm', ':3.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    tgt_accs = AverageMeter('Tgt Acc', ':3.1f')

    progress = ProgressMeter(args.iters_per_epoch, [
        batch_time, data_time, cls_losses, norm_losses, src_feature_norm,
        tgt_feature_norm, cls_accs, tgt_accs
    ],
                             prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        x_s, labels_s = next(train_source_iter)
        x_t, labels_t = next(train_target_iter)

        x_s = x_s.to(device)
        x_t = x_t.to(device)
        labels_s = labels_s.to(device)
        labels_t = labels_t.to(device)

        # Source and target batches may differ in size (e.g. a short final
        # batch), so each statistic is weighted by the batch it describes.
        batch_size_s = x_s.size(0)
        batch_size_t = x_t.size(0)

        # measure data loading time
        data_time.update(time.time() - end)

        # compute output
        y_s, f_s = model(x_s)
        y_t, f_t = model(x_t)

        # classification loss (source domain only)
        cls_loss = F.cross_entropy(y_s, labels_s)
        # norm loss, applied to features from both domains
        norm_loss = adaptive_feature_norm(f_s) + adaptive_feature_norm(f_t)
        loss = cls_loss + norm_loss * args.trade_off_norm

        # optional entropy minimization on target predictions
        if args.trade_off_entropy:
            # Keep y_t as logits; probabilities get their own name.
            p_t = F.softmax(y_t, dim=1)
            entropy_loss = entropy(p_t, reduction='mean')
            # Out-of-place add rather than mutating a tensor in the graph.
            loss = loss + entropy_loss * args.trade_off_entropy

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # update statistics (monitoring only). Softmax preserves the ordering
        # of scores, so ranking-based accuracy on logits matches probabilities.
        cls_acc = accuracy(y_s, labels_s)[0]
        tgt_acc = accuracy(y_t, labels_t)[0]

        cls_losses.update(cls_loss.item(), batch_size_s)
        norm_losses.update(norm_loss.item(), batch_size_s)
        src_feature_norm.update(
            f_s.norm(p=2, dim=1).mean().item(), batch_size_s)
        tgt_feature_norm.update(
            f_t.norm(p=2, dim=1).mean().item(), batch_size_t)
        cls_accs.update(cls_acc.item(), batch_size_s)
        tgt_accs.update(tgt_acc.item(), batch_size_t)

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
Ejemplo n.º 2
0
def train(train_source_iter: ForeverDataIterator,
          train_target_iter: ForeverDataIterator, model: ImageClassifier,
          domain_adv_D: DomainAdversarialLoss,
          domain_adv_D_0: DomainAdversarialLoss, importance_weight_module,
          optimizer: SGD, lr_scheduler: LambdaLR, epoch: int,
          args: argparse.Namespace):
    """Run one training epoch with two domain discriminators and source importance weights.

    Each iteration draws a source and a target batch, runs them through the
    model in a single forward pass and minimizes::

        cross_entropy(source)
          + 1.5 * args.trade_off * domain_adv_D(f_s.detach(), f_t.detach())
          + args.trade_off * domain_adv_D_0(f_s, f_t, w_s=w_s)
          + args.gamma * entropy(softmax(target logits))

    where ``w_s`` comes from
    ``importance_weight_module.get_importance_weight(f_s)``. Because D
    receives detached features, its loss cannot backpropagate into the
    feature extractor. ``lr_scheduler`` is stepped once per iteration.
    Target labels are used only to report target accuracy.

    Args:
        train_source_iter: iterator yielding ``(inputs, labels)`` source
            batches; ``next()`` is called once per iteration.
        train_target_iter: iterator yielding ``(inputs, labels)`` target
            batches; ``next()`` is called once per iteration.
        model: network whose call returns ``(logits, features)``.
        domain_adv_D: adversarial loss on detached features; exposes
            ``domain_discriminator_accuracy``.
        domain_adv_D_0: adversarial loss on source-weighted features; exposes
            ``domain_discriminator_accuracy``.
        importance_weight_module: provides ``get_importance_weight`` and
            ``get_partial_classes_weight`` (defined elsewhere).
        optimizer: optimizer stepped once per iteration.
        lr_scheduler: scheduler stepped once per iteration.
        epoch: epoch index, used only in the progress prefix.
        args: reads ``iters_per_epoch``, ``trade_off``, ``gamma`` and
            ``print_freq``.
    """
    batch_time = AverageMeter('Time', ':5.2f')
    data_time = AverageMeter('Data', ':5.2f')
    losses = AverageMeter('Loss', ':6.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    tgt_accs = AverageMeter('Tgt Acc', ':3.1f')
    domain_accs_D = AverageMeter('Domain Acc for D', ':3.1f')
    domain_accs_D_0 = AverageMeter('Domain Acc for D_0', ':3.1f')
    partial_classes_weights = AverageMeter('Partial Weight', ':3.2f')
    non_partial_classes_weights = AverageMeter('Non-Partial Weight', ':3.2f')

    progress = ProgressMeter(args.iters_per_epoch, [
        batch_time, data_time, losses, cls_accs, tgt_accs, domain_accs_D,
        domain_accs_D_0, partial_classes_weights, non_partial_classes_weights
    ],
                             prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()
    domain_adv_D.train()
    domain_adv_D_0.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        x_s, labels_s = next(train_source_iter)
        x_t, labels_t = next(train_target_iter)

        x_s = x_s.to(device)
        x_t = x_t.to(device)
        labels_s = labels_s.to(device)
        labels_t = labels_t.to(device)

        # The two batches are not guaranteed to be the same size (e.g. a
        # short final batch), so keep both sizes for splitting and weighting.
        batch_size_s = x_s.size(0)
        batch_size_t = x_t.size(0)

        # measure data loading time
        data_time.update(time.time() - end)

        # compute output: one forward pass over the concatenated batch
        x = torch.cat((x_s, x_t), dim=0)
        y, f = model(x)
        # Split at the true source/target boundary. chunk(2) would split an
        # uneven concatenation down the middle and misalign source logits
        # with labels_s.
        y_s, y_t = y.split([batch_size_s, batch_size_t], dim=0)
        f_s, f_t = f.split([batch_size_s, batch_size_t], dim=0)

        # classification loss
        cls_loss = F.cross_entropy(y_s, labels_s)

        # domain adversarial loss for D; detached so it does not update the
        # feature extractor
        adv_loss_D = domain_adv_D(f_s.detach(), f_t.detach())

        # get importance weights for source samples
        w_s = importance_weight_module.get_importance_weight(f_s)
        # domain adversarial loss for D_0, with source samples weighted by w_s
        adv_loss_D_0 = domain_adv_D_0(f_s, f_t, w_s=w_s)

        # entropy loss on target predictions; keep y_t as logits
        p_t = F.softmax(y_t, dim=1)
        entropy_loss = entropy(p_t, reduction='mean')

        loss = cls_loss + 1.5 * args.trade_off * adv_loss_D + \
               args.trade_off * adv_loss_D_0 + args.gamma * entropy_loss

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        # Monitoring only. Softmax preserves score ordering, so ranking-based
        # accuracy on logits matches accuracy on probabilities.
        cls_acc = accuracy(y_s, labels_s)[0]
        tgt_acc = accuracy(y_t, labels_t)[0]

        losses.update(loss.item(), batch_size_s)
        cls_accs.update(cls_acc.item(), batch_size_s)
        tgt_accs.update(tgt_acc.item(), batch_size_t)
        domain_accs_D.update(domain_adv_D.domain_discriminator_accuracy,
                             batch_size_s)
        domain_accs_D_0.update(domain_adv_D_0.domain_discriminator_accuracy,
                               batch_size_s)

        # debug: class weight averaged over the partial and non-partial
        # classes respectively
        partial_class_weight, non_partial_classes_weight = \
            importance_weight_module.get_partial_classes_weight(w_s, labels_s)
        partial_classes_weights.update(partial_class_weight.item(),
                                       batch_size_s)
        non_partial_classes_weights.update(non_partial_classes_weight.item(),
                                           batch_size_s)

        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)