def train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator, model: ImageClassifier,
          adaptive_feature_norm: AdaptiveFeatureNorm, optimizer: SGD, epoch: int, args: argparse.Namespace):
    batch_time = AverageMeter('Time', ':3.1f')
    data_time = AverageMeter('Data', ':3.1f')
    cls_losses = AverageMeter('Cls Loss', ':3.2f')
    norm_losses = AverageMeter('Norm Loss', ':3.2f')
    src_feature_norm = AverageMeter('Source Feature Norm', ':3.2f')
    tgt_feature_norm = AverageMeter('Target Feature Norm', ':3.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    tgt_accs = AverageMeter('Tgt Acc', ':3.1f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, cls_losses, norm_losses,
         src_feature_norm, tgt_feature_norm, cls_accs, tgt_accs],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        x_s, labels_s = next(train_source_iter)
        x_t, labels_t = next(train_target_iter)

        x_s = x_s.to(device)
        x_t = x_t.to(device)
        labels_s = labels_s.to(device)
        labels_t = labels_t.to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # compute output
        y_s, f_s = model(x_s)
        y_t, f_t = model(x_t)

        # classification loss
        cls_loss = F.cross_entropy(y_s, labels_s)
        # norm loss
        norm_loss = adaptive_feature_norm(f_s) + adaptive_feature_norm(f_t)
        loss = cls_loss + norm_loss * args.trade_off_norm

        # using entropy minimization
        if args.trade_off_entropy:
            y_t = F.softmax(y_t, dim=1)
            entropy_loss = entropy(y_t, reduction='mean')
            loss += entropy_loss * args.trade_off_entropy

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # update statistics
        cls_acc = accuracy(y_s, labels_s)[0]
        tgt_acc = accuracy(y_t, labels_t)[0]

        cls_losses.update(cls_loss.item(), x_s.size(0))
        norm_losses.update(norm_loss.item(), x_s.size(0))
        src_feature_norm.update(f_s.norm(p=2, dim=1).mean().item(), x_s.size(0))
        tgt_feature_norm.update(f_t.norm(p=2, dim=1).mean().item(), x_s.size(0))
        cls_accs.update(cls_acc.item(), x_s.size(0))
        tgt_accs.update(tgt_acc.item(), x_s.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
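
# For context, a minimal sketch of the two loss helpers used above. Both are
# assumptions: the feature-norm penalty follows the SAFN formulation
# ("Larger Norm More Transferable", Xu et al., ICCV 2019), where each batch's
# feature norms are pushed toward their current (detached) value plus a fixed
# step `delta`; the `entropy` helper is a standard Shannon entropy over
# softmax probabilities. The repo's own implementations may differ.

import torch
import torch.nn as nn


class AdaptiveFeatureNorm(nn.Module):
    """Stepwise feature-norm loss: penalize deviation from current norm + delta."""

    def __init__(self, delta: float = 1.0):
        super().__init__()
        self.delta = delta

    def forward(self, f: torch.Tensor) -> torch.Tensor:
        # frozen target radius: current per-sample norm, enlarged by delta
        radius = f.norm(p=2, dim=1).detach() + self.delta
        return ((f.norm(p=2, dim=1) - radius) ** 2).mean()


def entropy(predictions: torch.Tensor, reduction: str = 'none') -> torch.Tensor:
    """Shannon entropy of per-sample class probabilities (shape: batch x classes)."""
    epsilon = 1e-5  # guard against log(0)
    h = -(predictions * torch.log(predictions + epsilon)).sum(dim=1)
    return h.mean() if reduction == 'mean' else h
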
def train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator, model: ImageClassifier,
          domain_adv_D: DomainAdversarialLoss, domain_adv_D_0: DomainAdversarialLoss, importance_weight_module,
          optimizer: SGD, lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace):
    batch_time = AverageMeter('Time', ':5.2f')
    data_time = AverageMeter('Data', ':5.2f')
    losses = AverageMeter('Loss', ':6.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    tgt_accs = AverageMeter('Tgt Acc', ':3.1f')
    domain_accs_D = AverageMeter('Domain Acc for D', ':3.1f')
    domain_accs_D_0 = AverageMeter('Domain Acc for D_0', ':3.1f')
    partial_classes_weights = AverageMeter('Partial Weight', ':3.2f')
    non_partial_classes_weights = AverageMeter('Non-Partial Weight', ':3.2f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses, cls_accs, tgt_accs,
         domain_accs_D, domain_accs_D_0, partial_classes_weights, non_partial_classes_weights],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()
    domain_adv_D.train()
    domain_adv_D_0.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        x_s, labels_s = next(train_source_iter)
        x_t, labels_t = next(train_target_iter)

        x_s = x_s.to(device)
        x_t = x_t.to(device)
        labels_s = labels_s.to(device)
        labels_t = labels_t.to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # compute output
        x = torch.cat((x_s, x_t), dim=0)
        y, f = model(x)
        y_s, y_t = y.chunk(2, dim=0)
        f_s, f_t = f.chunk(2, dim=0)

        # classification loss
        cls_loss = F.cross_entropy(y_s, labels_s)
        # domain adversarial loss for D (features detached: D does not train the backbone)
        adv_loss_D = domain_adv_D(f_s.detach(), f_t.detach())
        # get importance weights
        w_s = importance_weight_module.get_importance_weight(f_s)
        # domain adversarial loss for D_0
        adv_loss_D_0 = domain_adv_D_0(f_s, f_t, w_s=w_s)
        # entropy loss
        y_t = F.softmax(y_t, dim=1)
        entropy_loss = entropy(y_t, reduction='mean')
        loss = cls_loss + 1.5 * args.trade_off * adv_loss_D + \
            args.trade_off * adv_loss_D_0 + args.gamma * entropy_loss

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        cls_acc = accuracy(y_s, labels_s)[0]
        tgt_acc = accuracy(y_t, labels_t)[0]

        losses.update(loss.item(), x_s.size(0))
        cls_accs.update(cls_acc.item(), x_s.size(0))
        tgt_accs.update(tgt_acc.item(), x_s.size(0))
        domain_accs_D.update(domain_adv_D.domain_discriminator_accuracy, x_s.size(0))
        domain_accs_D_0.update(domain_adv_D_0.domain_discriminator_accuracy, x_s.size(0))

        # debug: output class weight averaged on the partial classes and non-partial classes respectively
        partial_class_weight, non_partial_classes_weight = \
            importance_weight_module.get_partial_classes_weight(w_s, labels_s)
        partial_classes_weights.update(partial_class_weight.item(), x_s.size(0))
        non_partial_classes_weights.update(non_partial_classes_weight.item(), x_s.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
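
# For context, a minimal sketch of what `importance_weight_module` could look like.
# This is an assumption modeled on IWAN-style weighting ("Importance Weighted
# Adversarial Nets for Partial Domain Adaptation", Zhang et al., CVPR 2018):
# source samples that the auxiliary discriminator D confidently labels as source
# are likely to belong to outlier (non-shared) classes, so they are down-weighted
# in the loss for D_0. The class name, method signatures, and the
# `partial_classes_index` attribute below are hypothetical; the repo's own
# module may differ.

import torch
import torch.nn as nn


class ImportanceWeightModule:
    """Derive per-sample source weights from the auxiliary discriminator D."""

    def __init__(self, discriminator: nn.Module, partial_classes_index=None):
        self.discriminator = discriminator
        # indices of source classes assumed to be shared with the target domain
        self.partial_classes_index = partial_classes_index or []

    def get_importance_weight(self, f_s: torch.Tensor) -> torch.Tensor:
        # D(f) near 1 means "looks like source"; weight = 1 - D(f) favors
        # source samples that resemble the target domain
        w = 1.0 - self.discriminator(f_s)
        w = w / w.mean()   # normalize so weights average to 1 per batch
        return w.detach()  # weights act as constants w.r.t. the backbone

    def get_partial_classes_weight(self, weights: torch.Tensor, labels: torch.Tensor):
        # debug statistic: average weight on shared (partial) vs. outlier classes
        weights = weights.view(-1)
        is_partial = torch.tensor(
            [label.item() in self.partial_classes_index for label in labels],
            device=labels.device)
        return weights[is_partial].mean(), weights[~is_partial].mean()
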