def train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator,
          model: ImageClassifier, domain_adv: DomainAdversarialLoss, optimizer: SGD,
          lr_scheduler: StepwiseLR, epoch: int, args: argparse.Namespace):
    batch_time = AverageMeter('Time', ':5.2f')
    data_time = AverageMeter('Data', ':5.2f')
    losses = AverageMeter('Loss', ':6.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    domain_accs = AverageMeter('Domain Acc', ':3.1f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses, cls_accs, domain_accs],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()
    domain_adv.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        lr_scheduler.step()

        # measure data loading time
        data_time.update(time.time() - end)

        x_s, labels_s = next(train_source_iter)
        x_t, _ = next(train_target_iter)

        x_s = x_s.to(device)
        x_t = x_t.to(device)
        labels_s = labels_s.to(device)

        # compute output
        x = torch.cat((x_s, x_t), dim=0)
        y, f = model(x)
        y_s, y_t = y.chunk(2, dim=0)
        f_s, f_t = f.chunk(2, dim=0)

        cls_loss = F.cross_entropy(y_s, labels_s)
        transfer_loss = domain_adv(f_s, f_t)
        domain_acc = domain_adv.domain_discriminator_accuracy
        loss = cls_loss + transfer_loss * args.trade_off

        cls_acc = accuracy(y_s, labels_s)[0]

        losses.update(loss.item(), x_s.size(0))
        cls_accs.update(cls_acc.item(), x_s.size(0))
        domain_accs.update(domain_acc.item(), x_s.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
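

# ---------------------------------------------------------------------------
# Note on `domain_adv(f_s, f_t)`: the DomainAdversarialLoss used above wraps a
# binary domain discriminator behind a gradient-reversal layer.  The sketch
# below is only an illustration of that mechanism (not the library's actual
# implementation); the class name, the sigmoid-output discriminator, and the
# percentage-valued `domain_discriminator_accuracy` attribute are assumptions
# chosen to mirror how the training loop above consumes the loss.
# ---------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F


class _GradientReverse(torch.autograd.Function):
    """Identity on the forward pass; negated, scaled gradient on the backward pass."""

    @staticmethod
    def forward(ctx, x, coeff):
        ctx.coeff = coeff
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        return -ctx.coeff * grad_output, None


class SketchDomainAdversarialLoss(nn.Module):
    """Minimal stand-in for a DANN-style domain adversarial loss (illustrative only)."""

    def __init__(self, domain_discriminator: nn.Module, grl_coeff: float = 1.0):
        super().__init__()
        # the discriminator is assumed to output sigmoid probabilities in [0, 1]
        self.domain_discriminator = domain_discriminator
        self.grl_coeff = grl_coeff
        self.domain_discriminator_accuracy = None

    def forward(self, f_s: torch.Tensor, f_t: torch.Tensor) -> torch.Tensor:
        # reverse gradients so that minimizing this loss also confuses the feature extractor
        f = _GradientReverse.apply(torch.cat((f_s, f_t), dim=0), self.grl_coeff)
        d = self.domain_discriminator(f)
        d_s, d_t = d.chunk(2, dim=0)
        d_label_s = torch.ones_like(d_s)    # source domain -> 1
        d_label_t = torch.zeros_like(d_t)   # target domain -> 0
        pred = torch.cat((d_s >= 0.5, d_t >= 0.5), dim=0).float()
        target = torch.cat((d_label_s, d_label_t), dim=0)
        self.domain_discriminator_accuracy = 100.0 * (pred == target).float().mean()
        return 0.5 * (F.binary_cross_entropy(d_s, d_label_s) +
                      F.binary_cross_entropy(d_t, d_label_t))


# Hypothetical wiring for 256-dimensional features:
#   discriminator = nn.Sequential(nn.Linear(256, 1024), nn.ReLU(),
#                                 nn.Linear(1024, 1), nn.Sigmoid())
#   domain_adv = SketchDomainAdversarialLoss(discriminator)
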
def train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator,
          source_model: ImageClassifier, target_model: ImageClassifier,
          domain_adv: DomainAdversarialLoss, optimizer: SGD, lr_scheduler: LambdaLR,
          epoch: int, args: argparse.Namespace):
    batch_time = AverageMeter('Time', ':5.2f')
    data_time = AverageMeter('Data', ':5.2f')
    losses_transfer = AverageMeter('Transfer Loss', ':6.2f')
    domain_accs = AverageMeter('Domain Acc', ':3.1f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses_transfer, domain_accs],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    target_model.train()
    domain_adv.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        x_s, _ = next(train_source_iter)
        x_t, _ = next(train_target_iter)

        x_s = x_s.to(device)
        x_t = x_t.to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        _, f_s = source_model(x_s)
        _, f_t = target_model(x_t)
        loss_transfer = domain_adv(f_s, f_t)

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss_transfer.backward()
        optimizer.step()
        lr_scheduler.step()

        losses_transfer.update(loss_transfer.item(), x_s.size(0))
        domain_acc = domain_adv.domain_discriminator_accuracy
        domain_accs.update(domain_acc.item(), x_s.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
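

# ---------------------------------------------------------------------------
# Both train() loops above call `next(train_source_iter)` an unbounded number
# of times without ever handling StopIteration, which is the contract of
# ForeverDataIterator.  The class below is a sketch of that behaviour assuming
# a standard PyTorch DataLoader underneath; the library's actual
# ForeverDataIterator may differ in details (e.g. device placement).
# ---------------------------------------------------------------------------
from torch.utils.data import DataLoader


class SketchForeverDataIterator:
    """Endlessly iterate over a DataLoader, restarting it whenever it is exhausted."""

    def __init__(self, data_loader: DataLoader):
        self.data_loader = data_loader
        self.iter = iter(data_loader)

    def __next__(self):
        try:
            return next(self.iter)
        except StopIteration:
            # restart the loader (re-shuffles if shuffle=True)
            self.iter = iter(self.data_loader)
            return next(self.iter)

    def __len__(self):
        return len(self.data_loader)
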
def train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator,
          model: ImageClassifier, domain_adv_D: DomainAdversarialLoss,
          domain_adv_D_0: DomainAdversarialLoss, importance_weight_module,
          optimizer: SGD, lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace):
    batch_time = AverageMeter('Time', ':5.2f')
    data_time = AverageMeter('Data', ':5.2f')
    losses = AverageMeter('Loss', ':6.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    tgt_accs = AverageMeter('Tgt Acc', ':3.1f')
    domain_accs_D = AverageMeter('Domain Acc for D', ':3.1f')
    domain_accs_D_0 = AverageMeter('Domain Acc for D_0', ':3.1f')
    partial_classes_weights = AverageMeter('Partial Weight', ':3.2f')
    non_partial_classes_weights = AverageMeter('Non-Partial Weight', ':3.2f')
    progress = ProgressMeter(args.iters_per_epoch, [
        batch_time, data_time, losses, cls_accs, tgt_accs, domain_accs_D, domain_accs_D_0,
        partial_classes_weights, non_partial_classes_weights
    ], prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()
    domain_adv_D.train()
    domain_adv_D_0.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        x_s, labels_s = next(train_source_iter)
        x_t, labels_t = next(train_target_iter)

        x_s = x_s.to(device)
        x_t = x_t.to(device)
        labels_s = labels_s.to(device)
        labels_t = labels_t.to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # compute output
        x = torch.cat((x_s, x_t), dim=0)
        y, f = model(x)
        y_s, y_t = y.chunk(2, dim=0)
        f_s, f_t = f.chunk(2, dim=0)

        # classification loss
        cls_loss = F.cross_entropy(y_s, labels_s)
        # domain adversarial loss for D
        adv_loss_D = domain_adv_D(f_s.detach(), f_t.detach())
        # get importance weights
        w_s = importance_weight_module.get_importance_weight(f_s)
        # domain adversarial loss for D_0
        adv_loss_D_0 = domain_adv_D_0(f_s, f_t, w_s=w_s)
        # entropy loss
        y_t = F.softmax(y_t, dim=1)
        entropy_loss = entropy(y_t, reduction='mean')

        loss = cls_loss + 1.5 * args.trade_off * adv_loss_D + \
            args.trade_off * adv_loss_D_0 + args.gamma * entropy_loss

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        cls_acc = accuracy(y_s, labels_s)[0]
        tgt_acc = accuracy(y_t, labels_t)[0]

        losses.update(loss.item(), x_s.size(0))
        cls_accs.update(cls_acc.item(), x_s.size(0))
        tgt_accs.update(tgt_acc.item(), x_s.size(0))
        domain_accs_D.update(domain_adv_D.domain_discriminator_accuracy, x_s.size(0))
        domain_accs_D_0.update(domain_adv_D_0.domain_discriminator_accuracy, x_s.size(0))

        # debug: output class weight averaged on the partial classes and non-partial classes respectively
        partial_class_weight, non_partial_classes_weight = \
            importance_weight_module.get_partial_classes_weight(w_s, labels_s)
        partial_classes_weights.update(partial_class_weight.item(), x_s.size(0))
        non_partial_classes_weights.update(non_partial_classes_weight.item(), x_s.size(0))

        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
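

# ---------------------------------------------------------------------------
# The partial-DA loop above relies on `importance_weight_module` to
# down-weight source samples from classes that are unlikely to appear in the
# target domain, and on an `entropy` helper for the target predictions.  The
# sketch below shows one plausible realization in the spirit of
# importance-weighted adversarial training: weights derived from the auxiliary
# discriminator D (the one trained on detached features) and normalized to a
# batch mean of 1.  The class names, the `partial_classes_index` argument, and
# all other details are assumptions, not the repository's exact code.
# ---------------------------------------------------------------------------
import torch
import torch.nn as nn


class SketchImportanceWeightModule:
    """Illustrative importance weights computed from an auxiliary domain discriminator."""

    def __init__(self, discriminator: nn.Module, partial_classes_index=None):
        self.discriminator = discriminator                # assumed to output P(source) in [0, 1]
        self.partial_classes_index = partial_classes_index or []

    def get_importance_weight(self, f_s: torch.Tensor) -> torch.Tensor:
        with torch.no_grad():
            d = self.discriminator(f_s)
        w = 1.0 - d                                       # confidently "source-only" samples get small weight
        w = w / (w.mean() + 1e-8)                         # normalize so the batch mean weight is 1
        return w.detach()

    def get_partial_classes_weight(self, w_s: torch.Tensor, labels_s: torch.Tensor):
        """Average weight on the shared (partial) classes vs. all remaining classes."""
        w_s = w_s.squeeze(-1)
        mask = torch.zeros_like(labels_s, dtype=torch.bool)
        for c in self.partial_classes_index:
            mask |= labels_s == c
        return w_s[mask].mean(), w_s[~mask].mean()


def sketch_entropy(predictions: torch.Tensor, reduction: str = 'none') -> torch.Tensor:
    """Shannon entropy per sample for a batch of probability vectors (softmax inputs)."""
    h = -(predictions * torch.log(predictions + 1e-5)).sum(dim=1)
    return h.mean() if reduction == 'mean' else h
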
def train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator,
          model: Regressor, domain_adv: DomainAdversarialLoss, optimizer: SGD,
          lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace):
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    mse_losses = AverageMeter('MSE Loss', ':6.3f')
    dann_losses = AverageMeter('DANN Loss', ':6.3f')
    domain_accs = AverageMeter('Domain Acc', ':3.1f')
    mae_losses_s = AverageMeter('MAE Loss (s)', ':6.3f')
    mae_losses_t = AverageMeter('MAE Loss (t)', ':6.3f')
    progress = ProgressMeter(args.iters_per_epoch, [
        batch_time, data_time, mse_losses, dann_losses, mae_losses_s, mae_losses_t, domain_accs
    ], prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()
    domain_adv.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        optimizer.zero_grad()

        x_s, labels_s = next(train_source_iter)
        x_s = x_s.to(device)
        labels_s = labels_s.to(device).float()

        x_t, labels_t = next(train_target_iter)
        x_t = x_t.to(device)
        labels_t = labels_t.to(device).float()

        # measure data loading time
        data_time.update(time.time() - end)

        # compute output
        y_s, f_s = model(x_s)
        y_t, f_t = model(x_t)

        mse_loss = F.mse_loss(y_s, labels_s)
        mae_loss_s = F.l1_loss(y_s, labels_s)
        mae_loss_t = F.l1_loss(y_t, labels_t)
        transfer_loss = domain_adv(f_s, f_t)
        loss = mse_loss + transfer_loss * args.trade_off
        domain_acc = domain_adv.domain_discriminator_accuracy

        mse_losses.update(mse_loss.item(), x_s.size(0))
        dann_losses.update(transfer_loss.item(), x_s.size(0))
        mae_losses_s.update(mae_loss_s.item(), x_s.size(0))
        mae_losses_t.update(mae_loss_t.item(), x_s.size(0))
        domain_accs.update(domain_acc.item(), x_s.size(0))

        # compute gradient and do SGD step
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
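

# ---------------------------------------------------------------------------
# The regression loop above unpacks `y_s, f_s = model(x_s)`, i.e. the Regressor
# must return both its prediction and the feature that is aligned by the
# domain adversarial loss.  The module below is a hypothetical minimal example
# of such an interface; the input/feature sizes and the assumption that
# backbone features are already extracted are illustrative only.  Similarly,
# the per-iteration scheduler could be built, for example, as
#   lr_scheduler = LambdaLR(optimizer, lambda step: (1. + 0.001 * step) ** (-0.75))
# (a polynomial decay; the actual hyper-parameters are not specified here).
# ---------------------------------------------------------------------------
import torch
import torch.nn as nn


class SketchRegressor(nn.Module):
    """Tiny regressor that returns (prediction, feature), matching the call above."""

    def __init__(self, in_features: int = 2048, num_factors: int = 1):
        super().__init__()
        self.bottleneck = nn.Sequential(nn.Linear(in_features, 256), nn.ReLU())
        self.head = nn.Linear(256, num_factors)

    def forward(self, x: torch.Tensor):
        f = self.bottleneck(x)   # feature fed to the domain adversarial loss
        y = self.head(f)         # regression output scored with MSE / L1 above
        return y, f
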
def train(feature_extractor: FeatureExtractor, domain_adv: DomainAdversarialLoss,
          src_iter: ForeverDataIterator, tar_iter: ForeverDataIterator,
          src_val_loader, tar_val_loader):
    # lr, weight_decay, epoch, iter_per_epoch, w_da, print_freq, writer,
    # global_logger, model_dir and device are module-level globals
    optimizer = Adam(itertools.chain(feature_extractor.parameters(), domain_adv.parameters()),
                     lr=lr, weight_decay=weight_decay)
    npair_loss = NPairsLoss()  # n-pair loss

    # loss meters
    loss_rec = AverageMeter('tot_loss', tb_tag='Loss/tot', writer=writer)
    loss_lb_rec = AverageMeter('lb_loss', tb_tag='Loss/lb', writer=writer)
    loss_lb_g_rec = AverageMeter('lb_g_loss', tb_tag='Loss/lb_g', writer=writer)
    # loss_ulb_rec = AverageMeter('ulb_loss', tb_tag='Loss/ulb')
    loss_da_rec = AverageMeter('da_loss', tb_tag='Loss/da', writer=writer)
    # accuracy meter
    da_acc_rec = AverageMeter('da_acc', tb_tag='Acc/da', writer=writer)

    n_iter = 0
    best_nmi = 0
    for e_i in range(epoch):
        feature_extractor.train()
        domain_adv.train()
        progress = ProgressMeter(
            iter_per_epoch,
            [loss_lb_g_rec, loss_lb_rec, loss_da_rec, da_acc_rec],
            prefix="Epoch: [{}]".format(e_i), logger=global_logger)

        for i in range(iter_per_epoch):
            x_s, l_s = next(src_iter)
            x_t, l_t = next(tar_iter)
            x_s, l_s, x_t, l_t = x_s.to(device), l_s.to(device), x_t.to(device), l_t.to(device)

            x = torch.cat((x_s, x_t), dim=0)
            f, g = feature_extractor(x)
            f_s, f_t = f.chunk(2, dim=0)
            g_s, g_t = g.chunk(2, dim=0)

            # source-only part: n-pair losses on both source embeddings
            loss_s = npair_loss(f_s, l_s)
            loss_s_g = npair_loss(g_s, l_s)
            loss_lb_rec.update(loss_s.item(), x_s.size(0), iter=n_iter)
            loss_lb_g_rec.update(loss_s_g.item(), x_s.size(0), iter=n_iter)

            # dann
            # da_loss = domain_adv(f_s, f_t)
            da_loss = domain_adv(g_s, f_t)
            domain_acc = domain_adv.domain_discriminator_accuracy
            loss_da_rec.update(da_loss.item(), f.size(0), iter=n_iter)
            da_acc_rec.update(domain_acc.item(), f.size(0), iter=n_iter)

            loss = 0.5 * (loss_s + loss_s_g) + w_da * da_loss
            # loss = loss_s

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            n_iter += 1
            if i % print_freq == 0:
                progress.display(i)

        if e_i % 5 == 0:
            # global_logger.info(f"saving embedding in epoch {e_i}")
            # show_embedding(feature_extractor, [src_val_loader], tag=f'src_{e_i}',
            #                epoch=e_i, writer=writer, device=device)
            # show_embedding(feature_extractor, [tar_val_loader], tag=f'tar_{e_i}',
            #                epoch=e_i, writer=writer, device=device)
            nmi = NMI_eval(feature_extractor, src_val_loader, 5, device, type='src')
            global_logger.info(f'test on train set nmi: {nmi}')
            nmi = NMI_eval(feature_extractor, tar_val_loader, 5, device, type='tar')
            global_logger.info(f'test on test set nmi: {nmi}')
            if nmi > best_nmi:
                global_logger.info(f"save best model to {model_dir}")
                torch.save(feature_extractor.state_dict(),
                           os.path.join(model_dir, 'minst_best_model.pth'))
                best_nmi = nmi
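

# ---------------------------------------------------------------------------
# NMI_eval above scores the learned embedding by clustering it and comparing
# the clusters with ground-truth labels.  The helper below is only a guess at
# such an evaluation (k-means + normalized mutual information via
# scikit-learn); the real NMI_eval, its `type` argument, and its choice of
# cluster count may be implemented differently.
# ---------------------------------------------------------------------------
import torch
from sklearn.cluster import KMeans
from sklearn.metrics import normalized_mutual_info_score


@torch.no_grad()
def sketch_nmi_eval(model, data_loader, n_clusters: int, device) -> float:
    """Cluster embeddings with k-means and report NMI against the true labels."""
    model.eval()
    feats, labels = [], []
    for x, y in data_loader:
        f, _ = model(x.to(device))   # assumes the model returns (f, g) like feature_extractor above
        feats.append(f.cpu())
        labels.append(y)
    feats = torch.cat(feats).numpy()
    labels = torch.cat(labels).numpy()
    clusters = KMeans(n_clusters=n_clusters, n_init=10).fit_predict(feats)
    return normalized_mutual_info_score(labels, clusters)
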