def train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator,
          model, interp, criterion, optimizer: SGD, lr_scheduler: LambdaLR,
          epoch: int, visualize, args: argparse.Namespace):
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    losses_s = AverageMeter('Loss (s)', ':3.2f')
    losses_t = AverageMeter('Loss (t)', ':3.2f')
    losses_entropy_t = AverageMeter('Entropy (t)', ':3.2f')
    accuracies_s = Meter('Acc (s)', ':3.2f')
    accuracies_t = Meter('Acc (t)', ':3.2f')
    iou_s = Meter('IoU (s)', ':3.2f')
    iou_t = Meter('IoU (t)', ':3.2f')
    confmat_s = ConfusionMatrix(model.num_classes)
    confmat_t = ConfusionMatrix(model.num_classes)
    progress = ProgressMeter(args.iters_per_epoch,
                             [batch_time, data_time, losses_s, losses_t, losses_entropy_t,
                              accuracies_s, accuracies_t, iou_s, iou_t],
                             prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        optimizer.zero_grad()

        x_s, label_s = next(train_source_iter)
        x_t, label_t = next(train_target_iter)
        x_s = x_s.to(device)
        label_s = label_s.long().to(device)
        x_t = x_t.to(device)
        label_t = label_t.long().to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # compute output
        y_s = model(x_s)
        pred_s = interp(y_s)
        loss_cls_s = criterion(pred_s, label_s)
        loss_cls_s.backward()

        y_t = model(x_t)
        pred_t = interp(y_t)
        loss_cls_t = criterion(pred_t, label_t)
        loss_entropy_t = robust_entropy(y_t, args.ita)
        (args.entropy_weight * loss_entropy_t).backward()

        # compute gradient and do SGD step
        optimizer.step()
        lr_scheduler.step()

        # measure accuracy and record loss
        losses_s.update(loss_cls_s.item(), x_s.size(0))
        losses_t.update(loss_cls_t.item(), x_s.size(0))
        losses_entropy_t.update(loss_entropy_t.item(), x_s.size(0))

        confmat_s.update(label_s.flatten(), pred_s.argmax(1).flatten())
        confmat_t.update(label_t.flatten(), pred_t.argmax(1).flatten())
        acc_global_s, acc_s, iu_s = confmat_s.compute()
        acc_global_t, acc_t, iu_t = confmat_t.compute()
        accuracies_s.update(acc_s.mean().item())
        accuracies_t.update(acc_t.mean().item())
        iou_s.update(iu_s.mean().item())
        iou_t.update(iu_t.mean().item())

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
            if visualize is not None:
                visualize(x_s[0], pred_s[0], label_s[0], "source_{}".format(i))
                visualize(x_t[0], pred_t[0], label_t[0], "target_{}".format(i))
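
# A minimal sketch of what `robust_entropy` above is assumed to compute: the
# robust entropy-minimization objective in the style of FDA, i.e. per-pixel
# Shannon entropy normalized by log(num_classes), passed through a
# Charbonnier-style penalty with exponent `ita`. Name and signature are
# illustrative, not the library's definitive implementation.
import math
import torch
import torch.nn.functional as F

def robust_entropy_sketch(y: torch.Tensor, ita: float = 2.0, eps: float = 1e-8) -> torch.Tensor:
    # y: segmentation logits of shape (B, C, H, W)
    p = F.softmax(y, dim=1)
    log_p = F.log_softmax(y, dim=1)
    ent = -(p * log_p).sum(dim=1) / math.log(y.size(1))  # normalized to [0, 1]
    # the power penalty weights high-entropy (uncertain) pixels more heavily
    return ((ent ** 2 + eps) ** ita).mean()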
def train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator,
          model: ImageClassifier, mkmmd_loss: MultipleKernelMaximumMeanDiscrepancy,
          optimizer: SGD, lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace):
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    losses = AverageMeter('Loss', ':3.2f')
    trans_losses = AverageMeter('Trans Loss', ':5.4f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    tgt_accs = AverageMeter('Tgt Acc', ':3.1f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses, trans_losses, cls_accs, tgt_accs],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()
    mkmmd_loss.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        x_s, labels_s = next(train_source_iter)
        x_t, labels_t = next(train_target_iter)

        x_s = x_s.to(device)
        x_t = x_t.to(device)
        labels_s = labels_s.to(device)
        labels_t = labels_t.to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # compute output
        y_s, f_s = model(x_s)
        y_t, f_t = model(x_t)

        cls_loss = F.cross_entropy(y_s, labels_s)
        transfer_loss = mkmmd_loss(f_s, f_t)
        loss = cls_loss + transfer_loss * args.trade_off

        cls_acc = accuracy(y_s, labels_s)[0]
        tgt_acc = accuracy(y_t, labels_t)[0]

        losses.update(loss.item(), x_s.size(0))
        cls_accs.update(cls_acc.item(), x_s.size(0))
        tgt_accs.update(tgt_acc.item(), x_t.size(0))
        trans_losses.update(transfer_loss.item(), x_s.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
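
# A minimal sketch of the quantity `mkmmd_loss` is assumed to estimate: the
# (biased) multiple-kernel maximum mean discrepancy between source and target
# feature batches under a sum of Gaussian kernels. The bandwidths are
# illustrative; the library class may use an unbiased estimator instead.
import torch

def mk_mmd_sketch(f_s: torch.Tensor, f_t: torch.Tensor, sigmas=(1.0, 2.0, 4.0)) -> torch.Tensor:
    features = torch.cat([f_s, f_t], dim=0)
    dist2 = torch.cdist(features, features) ** 2           # pairwise squared distances
    kernel = sum(torch.exp(-dist2 / (2 * s ** 2)) for s in sigmas)
    b = f_s.size(0)
    k_ss, k_tt, k_st = kernel[:b, :b], kernel[b:, b:], kernel[:b, b:]
    return k_ss.mean() + k_tt.mean() - 2 * k_st.mean()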
def train(train_source_iter, train_target_iter, model, criterion, regression_disparity,
          optimizer_f, optimizer_h, optimizer_h_adv, lr_scheduler_f, lr_scheduler_h,
          lr_scheduler_h_adv, epoch: int, visualize, args: argparse.Namespace):
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    losses_s = AverageMeter('Loss (s)', ":.2e")
    losses_gf = AverageMeter('Loss (t, false)', ":.2e")
    losses_gt = AverageMeter('Loss (t, truth)', ":.2e")
    acc_s = AverageMeter("Acc (s)", ":3.2f")
    acc_t = AverageMeter("Acc (t)", ":3.2f")
    acc_s_adv = AverageMeter("Acc (s, adv)", ":3.2f")
    acc_t_adv = AverageMeter("Acc (t, adv)", ":3.2f")
    progress = ProgressMeter(args.iters_per_epoch,
                             [batch_time, data_time, losses_s, losses_gf, losses_gt,
                              acc_s, acc_t, acc_s_adv, acc_t_adv],
                             prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        x_s, label_s, weight_s, meta_s = next(train_source_iter)
        x_t, label_t, weight_t, meta_t = next(train_target_iter)
        x_s = x_s.to(device)
        label_s = label_s.to(device)
        weight_s = weight_s.to(device)
        x_t = x_t.to(device)
        label_t = label_t.to(device)
        weight_t = weight_t.to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # Step A: train all networks to minimize loss on the source domain
        optimizer_f.zero_grad()
        optimizer_h.zero_grad()
        optimizer_h_adv.zero_grad()
        y_s, y_s_adv = model(x_s)
        loss_s = criterion(y_s, label_s, weight_s) + \
                 args.margin * args.trade_off * regression_disparity(y_s, y_s_adv, weight_s, mode='min')
        loss_s.backward()
        optimizer_f.step()
        optimizer_h.step()
        optimizer_h_adv.step()

        # Step B: train the adversarial regressor to maximize regression disparity
        optimizer_h_adv.zero_grad()
        y_t, y_t_adv = model(x_t)
        loss_ground_false = args.trade_off * regression_disparity(y_t, y_t_adv, weight_t, mode='max')
        loss_ground_false.backward()
        optimizer_h_adv.step()

        # Step C: train the feature extractor to minimize regression disparity
        optimizer_f.zero_grad()
        y_t, y_t_adv = model(x_t)
        loss_ground_truth = args.trade_off * regression_disparity(y_t, y_t_adv, weight_t, mode='min')
        loss_ground_truth.backward()
        optimizer_f.step()

        # do update step
        model.step()
        lr_scheduler_f.step()
        lr_scheduler_h.step()
        lr_scheduler_h_adv.step()

        # measure accuracy and record loss
        _, avg_acc_s, cnt_s, pred_s = accuracy(y_s.detach().cpu().numpy(),
                                               label_s.detach().cpu().numpy())
        acc_s.update(avg_acc_s, cnt_s)
        _, avg_acc_t, cnt_t, pred_t = accuracy(y_t.detach().cpu().numpy(),
                                               label_t.detach().cpu().numpy())
        acc_t.update(avg_acc_t, cnt_t)
        _, avg_acc_s_adv, cnt_s_adv, pred_s_adv = accuracy(y_s_adv.detach().cpu().numpy(),
                                                           label_s.detach().cpu().numpy())
        acc_s_adv.update(avg_acc_s_adv, cnt_s)
        _, avg_acc_t_adv, cnt_t_adv, pred_t_adv = accuracy(y_t_adv.detach().cpu().numpy(),
                                                           label_t.detach().cpu().numpy())
        acc_t_adv.update(avg_acc_t_adv, cnt_t)
        # record scalar loss values rather than graph-holding tensors
        losses_s.update(loss_s.item(), cnt_s)
        losses_gf.update(loss_ground_false.item(), cnt_s)
        losses_gt.update(loss_ground_truth.item(), cnt_s)

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
            if visualize is not None:
                visualize(x_s[0], pred_s[0] * args.image_size / args.heatmap_size,
                          "source_{}_pred".format(i))
                visualize(x_s[0], meta_s['keypoint2d'][0], "source_{}_label".format(i))
                visualize(x_t[0], pred_t[0] * args.image_size / args.heatmap_size,
                          "target_{}_pred".format(i))
                visualize(x_t[0], meta_t['keypoint2d'][0], "target_{}_label".format(i))
                visualize(x_s[0], pred_s_adv[0] * args.image_size / args.heatmap_size,
                          "source_adv_{}_pred".format(i))
                visualize(x_t[0], pred_t_adv[0] * args.image_size / args.heatmap_size,
                          "target_adv_{}_pred".format(i))
def train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator,
          model, criterion_ce: CrossEntropyLossWithLabelSmooth,
          criterion_triplet: SoftTripletLoss, optimizer: Adam, epoch: int,
          args: argparse.Namespace):
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    losses_ce = AverageMeter('CeLoss', ':3.2f')
    losses_triplet = AverageMeter('TripletLoss', ':3.2f')
    losses = AverageMeter('Loss', ':3.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses_ce, losses_triplet, losses, cls_accs],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        x_s, _, labels_s, _ = next(train_source_iter)
        x_t, _, _, _ = next(train_target_iter)

        x_s = x_s.to(device)
        x_t = x_t.to(device)
        labels_s = labels_s.to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # compute output
        y_s, f_s = model(x_s)
        y_t, f_t = model(x_t)

        # cross entropy loss
        loss_ce = criterion_ce(y_s, labels_s)
        # triplet loss
        loss_triplet = criterion_triplet(f_s, f_s, labels_s)
        loss = loss_ce + loss_triplet * args.trade_off

        cls_acc = accuracy(y_s, labels_s)[0]
        losses_ce.update(loss_ce.item(), x_s.size(0))
        losses_triplet.update(loss_triplet.item(), x_s.size(0))
        losses.update(loss.item(), x_s.size(0))
        cls_accs.update(cls_acc.item(), x_s.size(0))

        # compute gradient and do optimization step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
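
# A minimal sketch of cross entropy with label smoothing, the behavior assumed
# for `criterion_ce` (CrossEntropyLossWithLabelSmooth) above. The smoothing
# factor `epsilon` is illustrative.
import torch
import torch.nn.functional as F

def cross_entropy_label_smooth_sketch(y: torch.Tensor, labels: torch.Tensor,
                                      epsilon: float = 0.1) -> torch.Tensor:
    n_classes = y.size(1)
    log_p = F.log_softmax(y, dim=1)
    # smoothed target: (1 - epsilon) on the true class, epsilon spread uniformly
    targets = torch.full_like(log_p, epsilon / n_classes)
    targets.scatter_(1, labels.unsqueeze(1), 1 - epsilon + epsilon / n_classes)
    return (-targets * log_p).sum(dim=1).mean()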
def train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator,
          model: ImageClassifier, domain_adv: DomainAdversarialLoss, optimizer: SGD,
          lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace):
    batch_time = AverageMeter('Time', ':5.2f')
    data_time = AverageMeter('Data', ':5.2f')
    losses = AverageMeter('Loss', ':6.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    tgt_accs = AverageMeter('Tgt Acc', ':3.1f')
    domain_accs = AverageMeter('Domain Acc', ':3.1f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses, cls_accs, tgt_accs, domain_accs],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()
    domain_adv.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        x_s, labels_s = next(train_source_iter)
        x_t, labels_t = next(train_target_iter)

        x_s = x_s.to(device)
        x_t = x_t.to(device)
        labels_s = labels_s.to(device)
        labels_t = labels_t.to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # compute output
        x = torch.cat((x_s, x_t), dim=0)
        y, f = model(x)
        y_s, y_t = y.chunk(2, dim=0)
        f_s, f_t = f.chunk(2, dim=0)

        cls_loss = F.cross_entropy(y_s, labels_s)
        transfer_loss = domain_adv(f_s, f_t)
        domain_acc = domain_adv.domain_discriminator_accuracy
        loss = cls_loss + transfer_loss * args.trade_off

        cls_acc = accuracy(y_s, labels_s)[0]
        tgt_acc = accuracy(y_t, labels_t)[0]

        losses.update(loss.item(), x_s.size(0))
        cls_accs.update(cls_acc.item(), x_s.size(0))
        tgt_accs.update(tgt_acc.item(), x_s.size(0))
        domain_accs.update(domain_acc.item(), x_s.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
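
# A minimal sketch of the mechanism assumed inside `domain_adv`: features pass
# through a gradient reversal layer before a domain discriminator, so that
# minimizing the discriminator's binary cross entropy simultaneously trains
# the feature extractor to confuse it. `discriminator` is a hypothetical
# module producing one logit per sample.
import torch
import torch.nn.functional as F

class GradientReverse(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, coeff):
        ctx.coeff = coeff
        return x.clone()

    @staticmethod
    def backward(ctx, grad_output):
        # reverse (and scale) the gradient flowing back to the features
        return -ctx.coeff * grad_output, None

def domain_adversarial_loss_sketch(f_s, f_t, discriminator, coeff=1.0):
    d = discriminator(GradientReverse.apply(torch.cat([f_s, f_t], dim=0), coeff))
    d_s, d_t = d.chunk(2, dim=0)
    return 0.5 * (F.binary_cross_entropy_with_logits(d_s, torch.ones_like(d_s)) +
                  F.binary_cross_entropy_with_logits(d_t, torch.zeros_like(d_t)))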
def train(train_iter: ForeverDataIterator, model: Classifier,
          backbone_regularization: nn.Module, head_regularization: nn.Module,
          target_getter: IntermediateLayerGetter, source_getter: IntermediateLayerGetter,
          optimizer: SGD, epoch: int, args: argparse.Namespace):
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    losses = AverageMeter('Loss', ':3.2f')
    losses_reg_head = AverageMeter('Loss (reg, head)', ':3.2f')
    losses_reg_backbone = AverageMeter('Loss (reg, backbone)', ':3.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    progress = ProgressMeter(args.iters_per_epoch,
                             [batch_time, data_time, losses, losses_reg_head,
                              losses_reg_backbone, cls_accs],
                             prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        x, labels = next(train_iter)
        x = x.to(device)
        labels = labels.to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # compute output
        intermediate_output_s, output_s = source_getter(x)
        intermediate_output_t, output_t = target_getter(x)
        y, f = output_t

        # measure accuracy and record loss
        cls_acc = accuracy(y, labels)[0]
        cls_loss = F.cross_entropy(y, labels)
        # both feature-map variants regularize the backbone's intermediate activations
        if args.regularization_type in ('feature_map', 'attention_feature_map'):
            loss_reg_backbone = backbone_regularization(intermediate_output_s,
                                                        intermediate_output_t)
        else:
            loss_reg_backbone = backbone_regularization()
        loss_reg_head = head_regularization()
        loss = cls_loss + args.trade_off_backbone * loss_reg_backbone \
               + args.trade_off_head * loss_reg_head

        losses_reg_backbone.update(loss_reg_backbone.item() * args.trade_off_backbone, x.size(0))
        losses_reg_head.update(loss_reg_head.item() * args.trade_off_head, x.size(0))
        losses.update(loss.item(), x.size(0))
        cls_accs.update(cls_acc.item(), x.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
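
# Minimal sketches of the two kinds of regularizers assumed above. The
# feature-map branch penalizes drift of intermediate activations from the
# frozen source model; the default branch (and the head regularizer) would
# typically be an L2 / L2-SP penalty on parameters. Both sketches assume
# aligned sequences of tensors; names are illustrative.
import torch
import torch.nn.functional as F

def feature_map_regularization_sketch(feats_source_model, feats_finetuned_model):
    # keep finetuned activations close to the pretrained model's activations
    return sum(F.mse_loss(f_t, f_s.detach())
               for f_s, f_t in zip(feats_source_model, feats_finetuned_model))

def l2_sp_sketch(params, pretrained_params):
    # L2-SP: penalize deviation of weights from their pretrained values
    return sum(((p - p0) ** 2).sum() for p, p0 in zip(params, pretrained_params))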
def train(train_source_iter, train_target_iter, netG_S2T, netG_T2S, netD_S, netD_T,
          criterion_gan, criterion_cycle, criterion_identity, optimizer_G, optimizer_D,
          fake_S_pool, fake_T_pool, epoch: int, visualize, args: argparse.Namespace):
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    losses_G_S2T = AverageMeter('G_S2T', ':3.2f')
    losses_G_T2S = AverageMeter('G_T2S', ':3.2f')
    losses_D_S = AverageMeter('D_S', ':3.2f')
    losses_D_T = AverageMeter('D_T', ':3.2f')
    losses_cycle_S = AverageMeter('cycle_S', ':3.2f')
    losses_cycle_T = AverageMeter('cycle_T', ':3.2f')
    losses_identity_S = AverageMeter('idt_S', ':3.2f')
    losses_identity_T = AverageMeter('idt_T', ':3.2f')
    progress = ProgressMeter(args.iters_per_epoch,
                             [batch_time, data_time, losses_G_S2T, losses_G_T2S,
                              losses_D_S, losses_D_T, losses_cycle_S, losses_cycle_T,
                              losses_identity_S, losses_identity_T],
                             prefix="Epoch: [{}]".format(epoch))

    end = time.time()
    for i in range(args.iters_per_epoch):
        real_S, _ = next(train_source_iter)
        real_T, _ = next(train_target_iter)
        real_S = real_S.to(device)
        real_T = real_T.to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # Compute fake images and reconstruction images.
        fake_T = netG_S2T(real_S)
        rec_S = netG_T2S(fake_T)
        fake_S = netG_T2S(real_T)
        rec_T = netG_S2T(fake_S)

        # Optimize the generators; the discriminators require no gradients here
        set_requires_grad(netD_S, False)
        set_requires_grad(netD_T, False)
        optimizer_G.zero_grad()
        # GAN loss D_T(G_S2T(S))
        loss_G_S2T = criterion_gan(netD_T(fake_T), real=True)
        # GAN loss D_S(G_T2S(T))
        loss_G_T2S = criterion_gan(netD_S(fake_S), real=True)
        # Cycle loss || G_T2S(G_S2T(S)) - S ||
        loss_cycle_S = criterion_cycle(rec_S, real_S) * args.trade_off_cycle
        # Cycle loss || G_S2T(G_T2S(T)) - T ||
        loss_cycle_T = criterion_cycle(rec_T, real_T) * args.trade_off_cycle
        # Identity loss
        # G_S2T should be the identity if real_T is fed: ||G_S2T(real_T) - real_T||
        identity_T = netG_S2T(real_T)
        loss_identity_T = criterion_identity(identity_T, real_T) * args.trade_off_identity
        # G_T2S should be the identity if real_S is fed: ||G_T2S(real_S) - real_S||
        identity_S = netG_T2S(real_S)
        loss_identity_S = criterion_identity(identity_S, real_S) * args.trade_off_identity
        # combined loss and calculate gradients
        loss_G = loss_G_S2T + loss_G_T2S + loss_cycle_S + loss_cycle_T \
                 + loss_identity_S + loss_identity_T
        loss_G.backward()
        optimizer_G.step()

        # Optimize the discriminators
        set_requires_grad(netD_S, True)
        set_requires_grad(netD_T, True)
        optimizer_D.zero_grad()
        # Calculate GAN loss for discriminator D_S
        fake_S_ = fake_S_pool.query(fake_S.detach())
        loss_D_S = 0.5 * (criterion_gan(netD_S(real_S), True) +
                          criterion_gan(netD_S(fake_S_), False))
        loss_D_S.backward()
        # Calculate GAN loss for discriminator D_T
        fake_T_ = fake_T_pool.query(fake_T.detach())
        loss_D_T = 0.5 * (criterion_gan(netD_T(real_T), True) +
                          criterion_gan(netD_T(fake_T_), False))
        loss_D_T.backward()
        optimizer_D.step()

        # record losses
        losses_G_S2T.update(loss_G_S2T.item(), real_S.size(0))
        losses_G_T2S.update(loss_G_T2S.item(), real_S.size(0))
        losses_D_S.update(loss_D_S.item(), real_S.size(0))
        losses_D_T.update(loss_D_T.item(), real_S.size(0))
        losses_cycle_S.update(loss_cycle_S.item(), real_S.size(0))
        losses_cycle_T.update(loss_cycle_T.item(), real_S.size(0))
        losses_identity_S.update(loss_identity_S.item(), real_S.size(0))
        losses_identity_T.update(loss_identity_T.item(), real_S.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
            for tensor, name in zip([real_S, real_T, fake_S, fake_T,
                                     rec_S, rec_T, identity_S, identity_T],
                                    ["real_S", "real_T", "fake_S", "fake_T",
                                     "rec_S", "rec_T", "identity_S", "identity_T"]):
                visualize(tensor[0], "{}_{}".format(i, name))
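
# A minimal sketch of the `set_requires_grad` helper used above, assuming it
# simply toggles gradient computation for all parameters of a network (the
# standard trick to skip discriminator gradients while the generators update).
import torch.nn as nn

def set_requires_grad_sketch(net: nn.Module, requires_grad: bool = False):
    for param in net.parameters():
        param.requires_grad = requires_grad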
def train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator,
          model: ImageClassifier, adaptive_feature_norm: AdaptiveFeatureNorm,
          optimizer: SGD, epoch: int, args: argparse.Namespace):
    batch_time = AverageMeter('Time', ':3.1f')
    data_time = AverageMeter('Data', ':3.1f')
    cls_losses = AverageMeter('Cls Loss', ':3.2f')
    norm_losses = AverageMeter('Norm Loss', ':3.2f')
    src_feature_norm = AverageMeter('Source Feature Norm', ':3.2f')
    tgt_feature_norm = AverageMeter('Target Feature Norm', ':3.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    tgt_accs = AverageMeter('Tgt Acc', ':3.1f')
    progress = ProgressMeter(args.iters_per_epoch,
                             [batch_time, data_time, cls_losses, norm_losses,
                              src_feature_norm, tgt_feature_norm, cls_accs, tgt_accs],
                             prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        x_s, labels_s = next(train_source_iter)
        x_t, labels_t = next(train_target_iter)

        x_s = x_s.to(device)
        x_t = x_t.to(device)
        labels_s = labels_s.to(device)
        labels_t = labels_t.to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # compute output
        y_s, f_s = model(x_s)
        y_t, f_t = model(x_t)

        # classification loss
        cls_loss = F.cross_entropy(y_s, labels_s)
        # norm loss
        norm_loss = adaptive_feature_norm(f_s) + adaptive_feature_norm(f_t)
        loss = cls_loss + norm_loss * args.trade_off_norm

        # using entropy minimization
        if args.trade_off_entropy:
            y_t = F.softmax(y_t, dim=1)
            entropy_loss = entropy(y_t, reduction='mean')
            loss += entropy_loss * args.trade_off_entropy

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # update statistics
        cls_acc = accuracy(y_s, labels_s)[0]
        tgt_acc = accuracy(y_t, labels_t)[0]
        cls_losses.update(cls_loss.item(), x_s.size(0))
        norm_losses.update(norm_loss.item(), x_s.size(0))
        src_feature_norm.update(f_s.norm(p=2, dim=1).mean().item(), x_s.size(0))
        tgt_feature_norm.update(f_t.norm(p=2, dim=1).mean().item(), x_s.size(0))
        cls_accs.update(cls_acc.item(), x_s.size(0))
        tgt_accs.update(tgt_acc.item(), x_s.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
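
# A minimal sketch of the stepwise feature-norm loss assumed for
# `adaptive_feature_norm` (SAFN-style): push each feature's L2 norm toward its
# current value plus a fixed residual `delta`, treating the enlarged radius as
# a constant. `delta` is illustrative.
import torch

def adaptive_feature_norm_sketch(f: torch.Tensor, delta: float = 1.0) -> torch.Tensor:
    radius = f.norm(p=2, dim=1).detach() + delta   # enlarged target norm, no gradient
    return ((f.norm(p=2, dim=1) - radius) ** 2).mean()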
def calculate_channel_attention(dataset, return_layers, args):
    backbone = models.__dict__[args.arch](pretrained=True)
    classifier = Classifier(backbone, dataset.num_classes).to(device)
    optimizer = SGD(classifier.get_parameters(args.lr), momentum=args.momentum,
                    weight_decay=args.wd, nesterov=True)
    data_loader = DataLoader(dataset, batch_size=args.attention_batch_size,
                             shuffle=True, num_workers=args.workers, drop_last=False)
    lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(
        optimizer, gamma=math.exp(math.log(0.1) / args.attention_lr_decay_epochs))
    criterion = nn.CrossEntropyLoss()

    channel_weights = []
    for layer_id, name in enumerate(return_layers):
        layer = get_attribute(classifier, name)
        layer_channel_weight = [0] * layer.out_channels
        channel_weights.append(layer_channel_weight)

    # train the classifier head with a frozen backbone
    classifier.train()
    classifier.backbone.requires_grad_(False)
    print("Pretrain a classifier to calculate channel attention.")
    for epoch in range(args.attention_epochs):
        losses = AverageMeter('Loss', ':3.2f')
        cls_accs = AverageMeter('Cls Acc', ':3.1f')
        progress = ProgressMeter(len(data_loader), [losses, cls_accs],
                                 prefix="Epoch: [{}]".format(epoch))

        for i, data in enumerate(data_loader):
            inputs, labels = data
            inputs = inputs.to(device)
            labels = labels.to(device)
            outputs, _ = classifier(inputs)
            loss = criterion(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            cls_acc = accuracy(outputs, labels)[0]
            losses.update(loss.item(), inputs.size(0))
            cls_accs.update(cls_acc.item(), inputs.size(0))
            if i % args.print_freq == 0:
                progress.display(i)
        lr_scheduler.step()

    # calculate the channel attention
    print('Calculating channel attention.')
    classifier.eval()
    if args.attention_iteration_limit > 0:
        total_iteration = min(len(data_loader), args.attention_iteration_limit)
    else:
        total_iteration = len(data_loader)
    progress = ProgressMeter(total_iteration, [], prefix="Iteration: ")
    for i, data in enumerate(data_loader):
        if i >= total_iteration:
            break
        inputs, labels = data
        inputs = inputs.to(device)
        labels = labels.to(device)
        outputs, _ = classifier(inputs)
        loss_0 = criterion(outputs, labels)
        progress.display(i)
        for layer_id, name in enumerate(tqdm(return_layers)):
            layer = get_attribute(classifier, name)
            for j in range(layer.out_channels):
                # zero out channel j, measure the loss increase, then restore it
                tmp = classifier.state_dict()[name + '.weight'][j, ].clone()
                classifier.state_dict()[name + '.weight'][j, ] = 0.0
                outputs, _ = classifier(inputs)
                loss_1 = criterion(outputs, labels)
                difference = (loss_1 - loss_0).detach().cpu().numpy().item()
                history_value = channel_weights[layer_id][j]
                # running mean of the loss difference over batches
                channel_weights[layer_id][j] = (i * history_value + difference) / (i + 1)
                classifier.state_dict()[name + '.weight'][j, ] = tmp

    channel_attention = []
    for weight in channel_weights:
        weight = np.array(weight)
        weight = (weight - np.mean(weight)) / np.std(weight)
        weight = torch.from_numpy(weight).float().to(device)
        channel_attention.append(F.softmax(weight / 5, dim=0).detach())
    return channel_attention
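
# A minimal sketch of the `get_attribute` helper assumed above: resolve a
# dotted module path such as "backbone.layer4.1.conv2" on a model. Numeric
# path segments work because nn.Module looks up child modules by name.
def get_attribute_sketch(obj, name: str):
    for field in name.split('.'):
        obj = getattr(obj, field)
    return obj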
def train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator,
          model: Regressor, rsd, optimizer: SGD, lr_scheduler: LambdaLR, epoch: int,
          args: argparse.Namespace):
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    mse_losses = AverageMeter('MSE Loss', ':6.3f')
    rsd_losses = AverageMeter('RSD Loss', ':6.3f')
    mae_losses_s = AverageMeter('MAE Loss (s)', ':6.3f')
    mae_losses_t = AverageMeter('MAE Loss (t)', ':6.3f')
    progress = ProgressMeter(args.iters_per_epoch,
                             [batch_time, data_time, mse_losses, rsd_losses,
                              mae_losses_s, mae_losses_t],
                             prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        optimizer.zero_grad()

        x_s, labels_s = next(train_source_iter)
        x_s = x_s.to(device)
        labels_s = labels_s.to(device).float()
        x_t, labels_t = next(train_target_iter)
        x_t = x_t.to(device)
        labels_t = labels_t.to(device).float()

        # measure data loading time
        data_time.update(time.time() - end)

        # compute output
        y_s, f_s = model(x_s)
        y_t, f_t = model(x_t)

        mse_loss = F.mse_loss(y_s, labels_s)
        mae_loss_s = F.l1_loss(y_s, labels_s)
        mae_loss_t = F.l1_loss(y_t, labels_t)
        rsd_loss = rsd(f_s, f_t)
        loss = mse_loss + rsd_loss * args.trade_off

        mse_losses.update(mse_loss.item(), x_s.size(0))
        rsd_losses.update(rsd_loss.item(), x_s.size(0))
        mae_losses_s.update(mae_loss_s.item(), x_s.size(0))
        mae_losses_t.update(mae_loss_t.item(), x_s.size(0))

        # compute gradient and do SGD step
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
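
# A minimal sketch of the representation subspace distance assumed for `rsd`:
# compare the principal subspaces of the two feature batches via the sines of
# their principal angles, plus a penalty aligning the corresponding bases.
# Assumes equal batch sizes; `trade_off2` is illustrative.
import torch

def rsd_sketch(f_s: torch.Tensor, f_t: torch.Tensor, trade_off2: float = 0.1) -> torch.Tensor:
    u_s, _, _ = torch.svd(f_s.t())
    u_t, _, _ = torch.svd(f_t.t())
    p_s, cos_a, p_t = torch.svd(u_s.t() @ u_t)      # cosines of principal angles
    sin_a = torch.sqrt(torch.clamp(1 - cos_a ** 2, min=0))
    return torch.norm(sin_a, 1) + trade_off2 * torch.norm(p_s.abs() - p_t.abs(), 2)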
def train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator,
          model, mdd: MarginDisparityDiscrepancy, optimizer: SGD, lr_scheduler: LambdaLR,
          epoch: int, args: argparse.Namespace):
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    source_losses = AverageMeter('Source Loss', ':6.3f')
    trans_losses = AverageMeter('Trans Loss', ':6.3f')
    mae_losses_s = AverageMeter('MAE Loss (s)', ':6.3f')
    mae_losses_t = AverageMeter('MAE Loss (t)', ':6.3f')
    progress = ProgressMeter(args.iters_per_epoch,
                             [batch_time, data_time, source_losses, trans_losses,
                              mae_losses_s, mae_losses_t],
                             prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()
    mdd.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        optimizer.zero_grad()

        x_s, labels_s = next(train_source_iter)
        x_s = x_s.to(device)
        labels_s = labels_s.to(device).float()
        x_t, labels_t = next(train_target_iter)
        x_t = x_t.to(device)
        labels_t = labels_t.to(device).float()

        # measure data loading time
        data_time.update(time.time() - end)

        # compute output
        x = torch.cat([x_s, x_t], dim=0)
        outputs, outputs_adv = model(x)
        y_s, y_t = outputs.chunk(2, dim=0)
        y_s_adv, y_t_adv = outputs_adv.chunk(2, dim=0)

        # compute mean squared error on the source domain
        mse_loss = F.mse_loss(y_s, labels_s)
        # compute margin disparity discrepancy between domains;
        # for the adversarial head, minimizing the negative MDD is
        # equivalent to maximizing MDD
        transfer_loss = mdd(y_s, y_s_adv, y_t, y_t_adv)
        loss = mse_loss - transfer_loss * args.trade_off
        model.step()

        mae_loss_s = F.l1_loss(y_s, labels_s)
        mae_loss_t = F.l1_loss(y_t, labels_t)

        source_losses.update(mse_loss.item(), x_s.size(0))
        trans_losses.update(transfer_loss.item(), x_s.size(0))
        mae_losses_s.update(mae_loss_s.item(), x_s.size(0))
        mae_losses_t.update(mae_loss_t.item(), x_s.size(0))

        # compute gradient and do SGD step
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
def train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator,
          model, interp, criterion, dann, optimizer: SGD, lr_scheduler: LambdaLR,
          optimizer_d: SGD, lr_scheduler_d: LambdaLR, epoch: int, visualize,
          args: argparse.Namespace):
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    losses_s = AverageMeter('Loss (s)', ':3.2f')
    losses_transfer = AverageMeter('Loss (transfer)', ':3.2f')
    losses_discriminator = AverageMeter('Loss (discriminator)', ':3.2f')
    accuracies_s = Meter('Acc (s)', ':3.2f')
    accuracies_t = Meter('Acc (t)', ':3.2f')
    iou_s = Meter('IoU (s)', ':3.2f')
    iou_t = Meter('IoU (t)', ':3.2f')
    confmat_s = ConfusionMatrix(model.num_classes)
    confmat_t = ConfusionMatrix(model.num_classes)
    progress = ProgressMeter(args.iters_per_epoch,
                             [batch_time, data_time, losses_s, losses_transfer,
                              losses_discriminator, accuracies_s, accuracies_t,
                              iou_s, iou_t],
                             prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        x_s, label_s = next(train_source_iter)
        x_t, label_t = next(train_target_iter)
        x_s = x_s.to(device)
        label_s = label_s.long().to(device)
        x_t = x_t.to(device)
        label_t = label_t.long().to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        optimizer.zero_grad()
        optimizer_d.zero_grad()

        # Step 1: train the segmentation network, freeze the discriminator
        dann.eval()
        y_s = model(x_s)
        pred_s = interp(y_s)
        loss_cls_s = criterion(pred_s, label_s)
        loss_cls_s.backward()

        # adversarial training to fool the discriminator
        y_t = model(x_t)
        pred_t = interp(y_t)
        loss_transfer = dann(pred_t, 'source')
        (loss_transfer * args.trade_off).backward()

        # Step 2: train the discriminator
        dann.train()
        loss_discriminator = 0.5 * (dann(pred_s.detach(), 'source') +
                                    dann(pred_t.detach(), 'target'))
        loss_discriminator.backward()

        # compute gradient and do SGD step
        optimizer.step()
        optimizer_d.step()
        lr_scheduler.step()
        lr_scheduler_d.step()

        # measure accuracy and record loss
        losses_s.update(loss_cls_s.item(), x_s.size(0))
        losses_transfer.update(loss_transfer.item(), x_s.size(0))
        losses_discriminator.update(loss_discriminator.item(), x_s.size(0))

        confmat_s.update(label_s.flatten(), pred_s.argmax(1).flatten())
        confmat_t.update(label_t.flatten(), pred_t.argmax(1).flatten())
        acc_global_s, acc_s, iu_s = confmat_s.compute()
        acc_global_t, acc_t, iu_t = confmat_t.compute()
        accuracies_s.update(acc_s.mean().item())
        accuracies_t.update(acc_t.mean().item())
        iou_s.update(iu_s.mean().item())
        iou_t.update(iu_t.mean().item())

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
            if visualize is not None:
                visualize(x_s[0], pred_s[0], label_s[0], "source_{}".format(i))
                visualize(x_t[0], pred_t[0], label_t[0], "target_{}".format(i))
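
# A minimal sketch of the behavior assumed for `dann` above in segmentation:
# the softmaxed prediction map goes through a fully-convolutional domain
# discriminator, and per-pixel binary cross entropy is computed against the
# domain label the predictions should be classified as (or mistaken for).
# Names are illustrative.
import torch
import torch.nn.functional as F

def seg_domain_loss_sketch(pred, domain: str, discriminator):
    d = discriminator(F.softmax(pred, dim=1))       # per-pixel domain logits
    target = torch.ones_like(d) if domain == 'source' else torch.zeros_like(d)
    return F.binary_cross_entropy_with_logits(d, target)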
def train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator,
          G: nn.Module, F1: ImageClassifierHead, F2: ImageClassifierHead,
          optimizer_g: SGD, optimizer_f: SGD, epoch: int, args: argparse.Namespace):
    batch_time = AverageMeter('Time', ':3.1f')
    data_time = AverageMeter('Data', ':3.1f')
    losses = AverageMeter('Loss', ':3.2f')
    trans_losses = AverageMeter('Trans Loss', ':3.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    tgt_accs = AverageMeter('Tgt Acc', ':3.1f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses, trans_losses, cls_accs, tgt_accs],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    G.train()
    F1.train()
    F2.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        x_s, labels_s = next(train_source_iter)
        x_t, labels_t = next(train_target_iter)
        x_s = x_s.to(device)
        x_t = x_t.to(device)
        labels_s = labels_s.to(device)
        labels_t = labels_t.to(device)
        x = torch.cat((x_s, x_t), dim=0)
        assert x.requires_grad is False

        # measure data loading time
        data_time.update(time.time() - end)

        # Step A: train all networks to minimize loss on the source domain
        optimizer_g.zero_grad()
        optimizer_f.zero_grad()

        g = G(x)
        y_1 = F1(g)
        y_2 = F2(g)
        y1_s, y1_t = y_1.chunk(2, dim=0)
        y2_s, y2_t = y_2.chunk(2, dim=0)
        y1_t, y2_t = F.softmax(y1_t, dim=1), F.softmax(y2_t, dim=1)

        loss = F.cross_entropy(y1_s, labels_s) + F.cross_entropy(y2_s, labels_s) + \
               0.01 * (entropy(y1_t) + entropy(y2_t))
        loss.backward()
        optimizer_g.step()
        optimizer_f.step()

        # Step B: train the classifiers to maximize discrepancy
        optimizer_g.zero_grad()
        optimizer_f.zero_grad()

        g = G(x)
        y_1 = F1(g)
        y_2 = F2(g)
        y1_s, y1_t = y_1.chunk(2, dim=0)
        y2_s, y2_t = y_2.chunk(2, dim=0)
        y1_t, y2_t = F.softmax(y1_t, dim=1), F.softmax(y2_t, dim=1)

        loss = F.cross_entropy(y1_s, labels_s) + F.cross_entropy(y2_s, labels_s) + \
               0.01 * (entropy(y1_t) + entropy(y2_t)) - \
               classifier_discrepancy(y1_t, y2_t) * args.trade_off
        loss.backward()
        optimizer_f.step()

        # Step C: train the generator to minimize discrepancy
        for k in range(args.num_k):
            optimizer_g.zero_grad()
            g = G(x)
            y_1 = F1(g)
            y_2 = F2(g)
            y1_s, y1_t = y_1.chunk(2, dim=0)
            y2_s, y2_t = y_2.chunk(2, dim=0)
            y1_t, y2_t = F.softmax(y1_t, dim=1), F.softmax(y2_t, dim=1)
            mcd_loss = classifier_discrepancy(y1_t, y2_t) * args.trade_off
            mcd_loss.backward()
            optimizer_g.step()

        cls_acc = accuracy(y1_s, labels_s)[0]
        tgt_acc = accuracy(y1_t, labels_t)[0]

        losses.update(loss.item(), x_s.size(0))
        cls_accs.update(cls_acc.item(), x_s.size(0))
        tgt_accs.update(tgt_acc.item(), x_t.size(0))
        trans_losses.update(mcd_loss.item(), x_s.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
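
# A minimal sketch of `classifier_discrepancy` as used in MCD: the mean
# absolute difference between the two classifiers' softmax outputs on the
# target domain.
import torch

def classifier_discrepancy_sketch(p1: torch.Tensor, p2: torch.Tensor) -> torch.Tensor:
    return torch.mean(torch.abs(p1 - p2))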
def train(train_iter: ForeverDataIterator, model: Classifier, optimizer,
          lr_scheduler: CosineAnnealingLR, epoch: int, n_domains_per_batch: int,
          args: argparse.Namespace):
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    losses = AverageMeter('Loss', ':3.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    progress = ProgressMeter(args.iters_per_epoch,
                             [batch_time, data_time, losses, cls_accs],
                             prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        x, labels, _ = next(train_iter)
        x = x.to(device)
        labels = labels.to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # split into support domains and query domains
        x_list = x.chunk(n_domains_per_batch, dim=0)
        labels_list = labels.chunk(n_domains_per_batch, dim=0)
        support_domain_list, query_domain_list = random_split(
            x_list, labels_list, n_domains_per_batch, args.n_support_domains)

        # clear grad
        optimizer.zero_grad()

        # compute output
        with higher.innerloop_ctx(model, optimizer,
                                  copy_initial_weights=False) as (inner_model, inner_optimizer):
            # perform inner optimization on the support domains
            for _ in range(args.inner_iters):
                loss_inner = 0
                for (x_s, labels_s) in support_domain_list:
                    y_s, _ = inner_model(x_s)
                    # normalize loss by the number of support domains
                    loss_inner += F.cross_entropy(y_s, labels_s) / args.n_support_domains
                inner_optimizer.step(loss_inner)

            # calculate outer loss
            loss_outer = 0
            cls_acc = 0

            # loss on support domains
            for (x_s, labels_s) in support_domain_list:
                y_s, _ = model(x_s)
                # normalize loss by the number of support domains
                loss_outer += F.cross_entropy(y_s, labels_s) / args.n_support_domains

            # loss on query domains, evaluated with the adapted inner model
            for (x_q, labels_q) in query_domain_list:
                y_q, _ = inner_model(x_q)
                # normalize loss by the number of query domains
                loss_outer += F.cross_entropy(y_q, labels_q) * args.trade_off / args.n_query_domains
                cls_acc += accuracy(y_q, labels_q)[0] / args.n_query_domains

            # update statistics
            losses.update(loss_outer.item(), args.batch_size)
            cls_accs.update(cls_acc.item(), args.batch_size)

            # compute gradient and do SGD step
            loss_outer.backward()

        optimizer.step()
        lr_scheduler.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
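
# A minimal sketch of the `random_split` helper assumed above: shuffle the
# per-domain chunks and hand the first `n_support_domains` of them to the
# inner (support) loop and the rest to the outer (query) loss.
import random

def random_split_sketch(x_list, labels_list, n_domains_per_batch, n_support_domains):
    idxes = list(range(n_domains_per_batch))
    random.shuffle(idxes)
    support = [(x_list[j], labels_list[j]) for j in idxes[:n_support_domains]]
    query = [(x_list[j], labels_list[j]) for j in idxes[n_support_domains:]]
    return support, query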
def train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator,
          classifier: ImageClassifier, mdd: MarginDisparityDiscrepancy, optimizer: SGD,
          lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace):
    batch_time = AverageMeter('Time', ':3.1f')
    data_time = AverageMeter('Data', ':3.1f')
    losses = AverageMeter('Loss', ':3.2f')
    trans_losses = AverageMeter('Trans Loss', ':3.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    tgt_accs = AverageMeter('Tgt Acc', ':3.1f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses, trans_losses, cls_accs, tgt_accs],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    classifier.train()
    mdd.train()

    criterion = nn.CrossEntropyLoss().to(device)

    end = time.time()
    for i in range(args.iters_per_epoch):
        optimizer.zero_grad()

        x_s, labels_s = next(train_source_iter)
        x_t, labels_t = next(train_target_iter)
        x_s = x_s.to(device)
        x_t = x_t.to(device)
        labels_s = labels_s.to(device)
        labels_t = labels_t.to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # compute output
        x = torch.cat((x_s, x_t), dim=0)
        outputs, outputs_adv = classifier(x)
        y_s, y_t = outputs.chunk(2, dim=0)
        y_s_adv, y_t_adv = outputs_adv.chunk(2, dim=0)

        # compute cross entropy loss on the source domain
        cls_loss = criterion(y_s, labels_s)
        # compute margin disparity discrepancy between domains;
        # for the adversarial classifier, minimizing the negative MDD is
        # equivalent to maximizing MDD
        transfer_loss = -mdd(y_s, y_s_adv, y_t, y_t_adv)
        loss = cls_loss + transfer_loss * args.trade_off
        classifier.step()

        cls_acc = accuracy(y_s, labels_s)[0]
        tgt_acc = accuracy(y_t, labels_t)[0]

        losses.update(loss.item(), x_s.size(0))
        cls_accs.update(cls_acc.item(), x_s.size(0))
        tgt_accs.update(tgt_acc.item(), x_t.size(0))
        trans_losses.update(transfer_loss.item(), x_s.size(0))

        # compute gradient and do SGD step
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
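
# A minimal sketch of the margin disparity discrepancy assumed for `mdd`
# above: the adversarial head is trained to agree with the main head's pseudo
# labels on the source domain (scaled by the margin) and to disagree on the
# target domain via a shifted log on (1 - p). `margin` and `eps` are
# illustrative.
import torch
import torch.nn.functional as F

def mdd_sketch(y_s, y_s_adv, y_t, y_t_adv, margin: float = 4.0, eps: float = 1e-6):
    pred_s = y_s.detach().argmax(dim=1)             # source pseudo labels
    pred_t = y_t.detach().argmax(dim=1)             # target pseudo labels
    loss_s = F.cross_entropy(y_s_adv, pred_s)
    p_adv_t = F.softmax(y_t_adv, dim=1)
    loss_t = F.nll_loss(torch.log(torch.clamp(1 - p_adv_t + eps, max=1.0)), pred_t)
    return margin * loss_s + loss_t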
def train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator,
          model: ImageClassifier, teacher: EmaTeacher, consistent_loss, class_balance_loss,
          optimizer: Adam, lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace):
    batch_time = AverageMeter('Time', ':3.1f')
    data_time = AverageMeter('Data', ':3.1f')
    cls_losses = AverageMeter('Cls Loss', ':3.2f')
    cons_losses = AverageMeter('Cons Loss', ':3.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    tgt_accs = AverageMeter('Tgt Acc', ':3.1f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, cls_losses, cons_losses, cls_accs, tgt_accs],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()
    teacher.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        x_s, labels_s = next(train_source_iter)
        (x_t1, x_t2), labels_t = next(train_target_iter)

        x_s = x_s.to(device)
        x_t1 = x_t1.to(device)
        x_t2 = x_t2.to(device)
        labels_s = labels_s.to(device)
        labels_t = labels_t.to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # compute output
        y_s, _ = model(x_s)
        y_t, _ = model(x_t1)
        y_t_teacher, _ = teacher(x_t2)

        # classification loss
        cls_loss = F.cross_entropy(y_s, labels_s)
        # compute output and confidence mask
        y_t = F.softmax(y_t, dim=1)
        y_t_teacher = F.softmax(y_t_teacher, dim=1)
        max_prob, _ = y_t_teacher.max(dim=1)
        mask = (max_prob > args.threshold).float()
        # consistency loss
        cons_loss = consistent_loss(y_t, y_t_teacher, mask)
        # class balance loss
        balance_loss = class_balance_loss(y_t) * mask.mean()

        loss = cls_loss + args.trade_off_cons * cons_loss + args.trade_off_balance * balance_loss

        # compute gradient and do optimization step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        # update teacher
        teacher.update()

        # update statistics
        cls_acc = accuracy(y_s, labels_s)[0]
        tgt_acc = accuracy(y_t, labels_t)[0]
        cls_losses.update(cls_loss.item(), x_s.size(0))
        cons_losses.update(cons_loss.item(), x_s.size(0))
        cls_accs.update(cls_acc.item(), x_s.size(0))
        tgt_accs.update(tgt_acc.item(), x_s.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
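
# A minimal sketch of the exponential-moving-average update assumed for
# `teacher.update()`: teacher weights track the student's with decay `alpha`
# (the value is illustrative).
import torch
import torch.nn as nn

@torch.no_grad()
def ema_update_sketch(teacher: nn.Module, student: nn.Module, alpha: float = 0.999):
    for t_param, s_param in zip(teacher.parameters(), student.parameters()):
        t_param.mul_(alpha).add_(s_param, alpha=1 - alpha)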