def SaveEvaluation(args, known_acc, auc):
    """Write the closed-set accuracy and open-set AUROC to a results text file.

    The file is created under ``results/Test/accuracy/<datasetname>-<split>/``
    and its name is assembled from the attack/defense settings on ``args``.

    Args:
        args: Namespace providing ``datasetname``, ``split``, ``adv``,
            ``defense``, ``denoisemean`` and ``defensesnapshot``.
        known_acc: Object exposing ``.cpu()`` whose result ``np.array`` can
            convert (presumably a torch tensor -- confirm against the caller).
        auc: Value written to the file through ``str()``.
    """
    filefolder = osp.join('results', 'Test', 'accuracy',
                          args.datasetname + '-' + args.split)
    mkdir(filefolder)

    filename = ('adv-' + str(args.adv) + '-defense-' + str(args.defense) +
                '-' + args.denoisemean + '-' + str(args.defensesnapshot) +
                '.txt')
    filepath = osp.join(filefolder, filename)

    # The context manager closes the handle even if one of the writes raises.
    with open(filepath, 'w') as output_file:
        output_file.write('Close-set Accuracy:\n' +
                          str(np.array(known_acc.cpu())))
        output_file.write('\nOpen-set AUROC:\n' + str(auc))
def Test(args, FeatExtor, DepthEsmator, FeatEmbder, data_loader_target,
         savefilename):
    """Score every batch of the target loader and dump scores/labels to HDF5.

    Each batch is passed through ``FeatExtor`` and ``FeatEmbder``; the sigmoid
    of the embedder's second output is collected together with the labels and
    written to ``<args.results_path>/<savefilename>/Test_data.h5`` under the
    dataset names ``score`` and ``label``.

    Args:
        args: Namespace providing ``results_path``.
        FeatExtor: Feature extractor; called with the image batch and expected
            to return a 2-tuple whose second element feeds ``FeatEmbder``.
        DepthEsmator: Switched to eval mode and wrapped, but not otherwise
            used by this function.
        FeatEmbder: Returns a 2-tuple; the second element is treated as the
            logits to be squashed by a sigmoid.
        data_loader_target: Iterable yielding ``(images, labels)`` pairs.
        savefilename: Sub-directory name used under ``args.results_path``.
    """
    # NOTE(review): ``normtype`` is not defined in this function; presumably a
    # module-level global -- confirm it exists, otherwise this raises NameError.
    print("***The type of norm is: {}".format(normtype))

    savepath = os.path.join(args.results_path, savefilename)
    mkdir(savepath)

    ####################
    # 1. setup network #
    ####################
    # set eval state for Dropout and BN layers
    FeatExtor.eval()
    DepthEsmator.eval()
    FeatEmbder.eval()

    FeatExtor = DataParallel(FeatExtor)
    DepthEsmator = DataParallel(DepthEsmator)
    FeatEmbder = DataParallel(FeatEmbder)

    score_list = []
    label_list = []
    total = len(data_loader_target)

    for idx, (catimages, labels) in enumerate(data_loader_target):
        images = catimages.cuda()

        # Pure inference: disable autograd so no graph is built per batch.
        with torch.no_grad():
            _, feat_ext = FeatExtor(images)
            _, label_pred = FeatEmbder(feat_ext)
            score = torch.sigmoid(label_pred).cpu().detach().numpy()

        score_list.append(score.squeeze())
        label_list.append(labels.numpy())

        print('SampleNum:{} in total:{}, score:{}'.format(
            idx, total, score.squeeze()))

    with h5py.File(os.path.join(savepath, 'Test_data.h5'), 'w') as hf:
        hf.create_dataset('score', data=score_list)
        hf.create_dataset('label', data=label_list)
def Train(args, FeatExtor, DepthEstor, FeatEmbder,
          data_loader1_real, data_loader1_fake,
          data_loader2_real, data_loader2_fake,
          data_loader3_real, data_loader3_fake,
          data_loader_target,
          summary_writer, Saver, savefilename):
    """Meta-train the feature extractor, depth estimator and feature embedder.

    Every step draws one real and one fake batch from each of the three source
    domains, randomly splits the domains into meta-train and meta-test sets,
    takes one gradient step on ``FeatEmbder`` per meta-train domain (kept in
    the graph via ``create_graph=True``), evaluates the adapted weights on the
    first meta-test domain and optimizes the combined loss.

    Args:
        args: Namespace providing ``optimizer_meta``, ``lr_meta``, ``beta1``,
            ``beta2``, ``epochs``, ``metatrainsize``, ``meta_step_size``,
            ``W_depth``, ``W_metatest``, ``log_step``, ``model_save_step``,
            ``model_save_epoch`` and ``results_path``.
        FeatExtor: Returns ``(feat_ext_all, feat)`` for an image batch.
        DepthEstor: Maps ``feat_ext_all`` to a depth prediction.
        FeatEmbder: Classifier head; must provide ``cloned_state_dict()`` and
            accept an optional second positional argument with adapted weights.
        data_loader{1,2,3}_{real,fake}: Loaders yielding
            ``(image, depth, label)`` triples for each source domain.
        data_loader_target: Unused by this function.
        summary_writer: Object with ``add_scalar(tag, value, step)``.
        Saver: Object with ``print_current_errors(epoch, step, errors)``.
        savefilename: Sub-directory name for the saved snapshots.

    Raises:
        NotImplementedError: If ``args.optimizer_meta`` is not ``'adam'``.
    """
    ####################
    # 1. setup network #
    ####################
    # set train state for Dropout and BN layers
    FeatExtor.train()
    DepthEstor.train()
    FeatEmbder.train()

    # FeatEmbder is deliberately left unwrapped: its cloned_state_dict() and
    # two-argument forward are called directly below.
    FeatExtor = DataParallel(FeatExtor)
    DepthEstor = DataParallel(DepthEstor)

    # setup criterion and optimizer
    criterionCls = nn.BCEWithLogitsLoss()
    criterionDepth = torch.nn.MSELoss()

    # Compare by value: ``is`` tests object identity and only matches a string
    # literal by accident of interning (e.g. not for command-line values).
    if args.optimizer_meta == 'adam':
        optimizer_all = optim.Adam(
            itertools.chain(FeatExtor.parameters(),
                            DepthEstor.parameters(),
                            FeatEmbder.parameters()),
            lr=args.lr_meta,
            betas=(args.beta1, args.beta2))
    else:
        raise NotImplementedError('Not a suitable optimizer')

    loader_pairs = ((data_loader1_real, data_loader1_fake),
                    (data_loader2_real, data_loader2_fake),
                    (data_loader3_real, data_loader3_fake))
    iternum = max(len(loader) for pair in loader_pairs for loader in pair)
    print('iternum={}'.format(iternum))

    # Resolved up front so the final save below never hits an unbound name
    # when no periodic-save branch has run.
    model_save_path = os.path.join(args.results_path, 'snapshots',
                                   savefilename)
    # (filename prefix, module) in the order the snapshots are written
    checkpoints = (('FeatExtor', FeatExtor),
                   ('FeatEmbder', FeatEmbder),
                   ('DepthEstor', DepthEstor))

    ####################
    # 2. train network #
    ####################
    global_step = 0

    for epoch in range(args.epochs):

        domain_iters = [(get_inf_iterator(real), get_inf_iterator(fake))
                        for real, fake in loader_pairs]

        for step in range(iternum):

            #============ one batch extraction / collection ============#
            catimglist, deplist, lablist = [], [], []
            for real_iter, fake_iter in domain_iters:
                img_real, dep_real, lab_real = next(real_iter)
                img_fake, dep_fake, lab_fake = next(fake_iter)
                # real samples first, fake samples second within each domain
                catimglist.append(torch.cat([img_real, img_fake], 0).cuda())
                deplist.append(torch.cat([dep_real, dep_fake], 0).cuda())
                lablist.append(
                    torch.cat([lab_real, lab_fake], 0).float().cuda())

            #============ domain list augmentation ============#
            domain_list = list(range(len(catimglist)))
            random.shuffle(domain_list)

            meta_train_list = domain_list[:args.metatrainsize]
            meta_test_list = domain_list[args.metatrainsize:]
            print('metatrn={}, metatst={}'.format(meta_train_list,
                                                  meta_test_list[0]))

            #============ meta training ============#
            Loss_dep_train = 0.0
            Loss_cls_train = 0.0
            adapted_state_dicts = []

            for index in meta_train_list:
                catimg_meta = catimglist[index]
                lab_meta = lablist[index]
                depGT_meta = deplist[index]

                # shuffle the samples inside the domain batch
                batchidx = list(range(len(catimg_meta)))
                random.shuffle(batchidx)

                img_rand = catimg_meta[batchidx, :]
                lab_rand = lab_meta[batchidx]
                depGT_rand = depGT_meta[batchidx, :]

                feat_ext_all, feat = FeatExtor(img_rand)
                pred = FeatEmbder(feat)
                depth_Pre = DepthEstor(feat_ext_all)

                Loss_cls = criterionCls(pred.squeeze(), lab_rand)
                Loss_dep = criterionDepth(depth_Pre, depGT_rand)

                Loss_dep_train += Loss_dep
                Loss_cls_train += Loss_cls

                zero_param_grad(FeatEmbder.parameters())
                # create_graph=True keeps the inner step differentiable so the
                # meta-test loss can back-propagate through it.
                grads_FeatEmbder = torch.autograd.grad(
                    Loss_cls, FeatEmbder.parameters(), create_graph=True)

                fast_weights_FeatEmbder = FeatEmbder.cloned_state_dict()
                for (key, val), grad in zip(FeatEmbder.named_parameters(),
                                            grads_FeatEmbder):
                    fast_weights_FeatEmbder[key] = (
                        val - args.meta_step_size * grad)

                adapted_state_dicts.append(fast_weights_FeatEmbder)

            #============ meta testing ============#
            Loss_cls_test = 0.0

            # only the first held-out domain is used for meta testing
            index = meta_test_list[0]
            catimg_meta = catimglist[index]
            lab_meta = lablist[index]
            depGT_meta = deplist[index]

            batchidx = list(range(len(catimg_meta)))
            random.shuffle(batchidx)

            img_rand = catimg_meta[batchidx, :]
            lab_rand = lab_meta[batchidx]
            depGT_rand = depGT_meta[batchidx, :]

            feat_ext_all, feat = FeatExtor(img_rand)
            depth_Pre = DepthEstor(feat_ext_all)
            Loss_dep_test = criterionDepth(depth_Pre, depGT_rand)

            # classify the meta-test batch once per adapted set of weights
            for a_dict in adapted_state_dicts:
                pred = FeatEmbder(feat, a_dict)
                Loss_cls_test += criterionCls(pred.squeeze(), lab_rand)

            # the average is reported only; the summed depth loss is optimized
            Loss_dep_train_ave = Loss_dep_train / len(meta_train_list)

            Loss_meta_train = Loss_cls_train + args.W_depth * Loss_dep_train
            Loss_meta_test = Loss_cls_test + args.W_depth * Loss_dep_test

            Loss_all = Loss_meta_train + args.W_metatest * Loss_meta_test

            optimizer_all.zero_grad()
            Loss_all.backward()
            optimizer_all.step()

            info = OrderedDict([
                ('Loss_meta_train', Loss_meta_train.item()),
                ('Loss_meta_test', Loss_meta_test.item()),
                ('Loss_cls_train', Loss_cls_train.item()),
                ('Loss_cls_test', Loss_cls_test.item()),
                ('Loss_dep_train_ave', Loss_dep_train_ave.item()),
                ('Loss_dep_test', Loss_dep_test.item()),
            ])

            #============ print the log info ============#
            if (step + 1) % args.log_step == 0:
                Saver.print_current_errors((epoch + 1), (step + 1),
                                           OrderedDict(info))

            #============ tensorboard the log info ============#
            for tag, value in info.items():
                summary_writer.add_scalar(tag, value, global_step)

            global_step += 1

            #############################
            # 2.4 save model parameters #
            #############################
            if (step + 1) % args.model_save_step == 0:
                mkdir(model_save_path)
                for prefix, model in checkpoints:
                    torch.save(
                        model.state_dict(),
                        os.path.join(
                            model_save_path,
                            "{}-{}-{}.pt".format(prefix, epoch + 1,
                                                 step + 1)))

        if (epoch + 1) % args.model_save_epoch == 0:
            mkdir(model_save_path)
            for prefix, model in checkpoints:
                torch.save(
                    model.state_dict(),
                    os.path.join(model_save_path,
                                 "{}-{}.pt".format(prefix, epoch + 1)))

    # make sure the target directory exists even if no periodic save ran
    mkdir(model_save_path)
    for prefix, model in checkpoints:
        torch.save(model.state_dict(),
                   os.path.join(model_save_path,
                                "{}-final.pt".format(prefix)))
def Pre_train(args, FeatExtor, DepthEsmator, data_loader_real,
              data_loader_fake, summary_writer, saver, savefilename):
    """Pre-train the feature extractor and depth estimator with an MSE loss.

    Each step concatenates one real and one fake batch, predicts a depth map
    from the first output of ``FeatExtor`` and regresses it onto the provided
    depth images.

    Args:
        args: Namespace providing ``lr_DG_depth``, ``beta1``, ``beta2``,
            ``pre_epochs``, ``log_step``, ``model_save_epoch`` and
            ``results_path``.
        FeatExtor: Returns a 2-tuple; the first element feeds ``DepthEsmator``.
        DepthEsmator: Maps extracted features to a depth prediction.
        data_loader_real: Loader yielding ``(image, depth, label)`` triples.
        data_loader_fake: Loader yielding ``(image, depth, label)`` triples.
        summary_writer: Object with ``add_scalar(tag, value, step)``.
        saver: Object with ``print_current_errors(epoch, step, errors)``.
        savefilename: Sub-directory name for the saved snapshots.
    """
    ####################
    # 1. setup network #
    ####################
    # set train state for Dropout and BN layers
    FeatExtor.train()
    DepthEsmator.train()

    FeatExtor = DataParallel(FeatExtor)
    DepthEsmator = DataParallel(DepthEsmator)

    criterionDepth = torch.nn.MSELoss()

    optimizer_DG_depth = optim.Adam(
        list(FeatExtor.parameters()) + list(DepthEsmator.parameters()),
        lr=args.lr_DG_depth,
        betas=(args.beta1, args.beta2))

    iternum = max(len(data_loader_real), len(data_loader_fake))
    print('iternum={}'.format(iternum))

    # Resolved up front so the final save below never hits an unbound name
    # when no epoch matched ``model_save_epoch``.
    model_save_path = os.path.join(args.results_path, 'snapshots',
                                   savefilename)

    ####################
    # 2. train network #
    ####################
    global_step = 0

    for epoch in range(args.pre_epochs):

        data_real = get_inf_iterator(data_loader_real)
        data_fake = get_inf_iterator(data_loader_fake)

        for step in range(iternum):
            # labels are not needed for depth regression
            cat_img_real, depth_img_real, _ = next(data_real)
            cat_img_fake, depth_img_fake, _ = next(data_fake)

            ori_img = torch.cat([cat_img_real, cat_img_fake], 0).cuda()
            depth_img = torch.cat([depth_img_real, depth_img_fake], 0).cuda()

            feat_ext, _ = FeatExtor(ori_img)
            depth_Pre = DepthEsmator(feat_ext)

            Loss_depth = criterionDepth(depth_Pre, depth_img)

            optimizer_DG_depth.zero_grad()
            Loss_depth.backward()
            optimizer_DG_depth.step()

            #============ tensorboard the log info ============#
            summary_writer.add_scalar('Loss_depth', Loss_depth.item(),
                                      global_step)

            #============ print the log info ============#
            if (step + 1) % args.log_step == 0:
                errors = OrderedDict([('Loss_depth', Loss_depth.item())])
                saver.print_current_errors((epoch + 1), (step + 1), errors)

            global_step += 1

        if (epoch + 1) % args.model_save_epoch == 0:
            mkdir(model_save_path)
            torch.save(
                FeatExtor.state_dict(),
                os.path.join(model_save_path,
                             "DGFA-Ext-{}.pt".format(epoch + 1)))
            torch.save(
                DepthEsmator.state_dict(),
                os.path.join(model_save_path,
                             "DGFA-Depth-{}.pt".format(epoch + 1)))

    # make sure the target directory exists even if no periodic save ran
    mkdir(model_save_path)
    torch.save(FeatExtor.state_dict(),
               os.path.join(model_save_path, "DGFA-Ext-final.pt"))
    torch.save(DepthEsmator.state_dict(),
               os.path.join(model_save_path, "DGFA-Depth-final.pt"))
def train_Ours(args, train_loader, val_loader, knownclass, Encoder, Decoder,
               NorClsfier, SSDClsfier, summary_writer, saver):
    """Train encoder, decoder and both classifiers on adversarial examples.

    For every training batch, adversarial versions of the original and the
    rotated images are generated against ``Encoder``+``NorClsfier`` and
    ``Encoder``+``SSDClsfier`` respectively. The models are then optimized on
    a weighted sum of the classification loss, the rotation classification
    loss and the reconstruction loss between the decoder output and the clean
    images. Validation accuracy/loss are computed on adversarial validation
    images every ``args.val_epoch`` epochs.

    Args:
        args: Namespace providing ``manual_seed``, ``parallel_train``, ``lr``,
            ``adv``, ``adv_iter``, ``n_epoch``, ``norClsWgt``, ``rotClsWgt``,
            ``RecWgt``, ``log_step``, ``val_epoch``, ``model_save_epoch``,
            ``results_path``, ``training_type``, ``datasetname``, ``split``
            and ``denoisemean``.
        train_loader: Yields ``(orig, label, rot_orig, rot_label)`` batches.
        val_loader: Yields ``(images, label)`` batches.
        knownclass: Passed to ``lab_conv`` together with the raw labels.
        Encoder, Decoder, NorClsfier, SSDClsfier: The modules being trained.
        summary_writer: Object with ``add_scalar`` and ``add_image``.
        saver: Object with ``print_current_errors(epoch, step, errors)``.

    Raises:
        ValueError: If ``args.adv`` is neither ``'PGDattack'`` nor
            ``'FGSMattack'``.
    """
    # called for its seeding side effect; the returned value is not needed here
    init_random_seed(args.manual_seed)

    criterionCls = nn.CrossEntropyLoss()
    criterionRec = nn.MSELoss()

    if args.parallel_train:
        Encoder = DataParallel(Encoder)
        Decoder = DataParallel(Decoder)
        NorClsfier = DataParallel(NorClsfier)
        SSDClsfier = DataParallel(SSDClsfier)

    optimizer = optim.Adam(
        list(Encoder.parameters()) + list(NorClsfier.parameters()) +
        list(SSDClsfier.parameters()) + list(Decoder.parameters()),
        lr=args.lr)

    # Compare by value: ``is`` tests object identity and only matches a string
    # literal by accident of interning.
    if args.adv == 'PGDattack':
        print("**********Defense PGD Attack**********")
        from advertorch.attacks import PGDAttack
        nor_adversary = PGDAttack(predict1=Encoder, predict2=NorClsfier,
                                  nb_iter=args.adv_iter)
        rot_adversary = PGDAttack(predict1=Encoder, predict2=SSDClsfier,
                                  nb_iter=args.adv_iter)
    elif args.adv == 'FGSMattack':
        print("**********Defense FGSM Attack**********")
        from advertorch.attacks import GradientSignAttack
        nor_adversary = GradientSignAttack(predict1=Encoder,
                                           predict2=NorClsfier)
        rot_adversary = GradientSignAttack(predict1=Encoder,
                                           predict2=SSDClsfier)
    else:
        # Fail early instead of hitting an unbound adversary in the loop.
        raise ValueError('Unsupported adversarial attack: {}'.format(args.adv))

    # Resolved up front so the final save below never hits an unbound name
    # when no epoch matched ``model_save_epoch``.
    model_save_path = os.path.join(args.results_path, args.training_type,
                                   'snapshots',
                                   args.datasetname + '-' + args.split,
                                   args.denoisemean,
                                   args.adv + str(args.adv_iter))

    global_step = 0
    # ----------
    #  Training
    # ----------
    for epoch in range(args.n_epoch):

        Encoder.train()
        Decoder.train()
        NorClsfier.train()
        SSDClsfier.train()

        for steps, (orig, label, rot_orig, rot_label) in enumerate(
                train_loader):

            label = lab_conv(knownclass, label)
            orig, label = orig.cuda(), label.long().cuda()
            rot_orig, rot_label = rot_orig.cuda(), rot_label.long().cuda()

            # craft the adversarial inputs inside the project's
            # ctx_noparamgrad_and_eval context for all three attacked modules
            with ctx_noparamgrad_and_eval(Encoder):
                with ctx_noparamgrad_and_eval(NorClsfier):
                    with ctx_noparamgrad_and_eval(SSDClsfier):
                        adv = nor_adversary.perturb(orig, label)
                        rot_adv = rot_adversary.perturb(rot_orig, rot_label)

            latent_feat = Encoder(adv)
            norpred = NorClsfier(latent_feat)
            norlossCls = criterionCls(norpred, label)

            # reconstruct the clean image from the adversarial latent feature
            recon = Decoder(latent_feat)
            lossRec = criterionRec(recon, orig)

            ssdpred = SSDClsfier(Encoder(rot_adv))
            rotlossCls = criterionCls(ssdpred, rot_label)

            loss = (args.norClsWgt * norlossCls +
                    args.rotClsWgt * rotlossCls +
                    args.RecWgt * lossRec)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            global_step += 1

            #============ print the log info ============#
            if (steps + 1) % args.log_step == 0:
                errors = OrderedDict([
                    ('loss', loss.item()),
                    ('norlossCls', norlossCls.item()),
                    ('lossRec', lossRec.item()),
                    ('rotlossCls', rotlossCls.item()),
                ])
                saver.print_current_errors((epoch + 1), (steps + 1), errors)

        # evaluate performance on validation set periodically
        if (epoch + 1) % args.val_epoch == 0:

            # switch model to evaluation mode
            Encoder.eval()
            NorClsfier.eval()

            running_corrects = 0.0
            epoch_size = 0.0
            val_loss_list = []

            # calculate accuracy on (adversarial) validation set
            for steps, (images, label) in enumerate(val_loader):

                label = lab_conv(knownclass, label)
                images, label = images.cuda(), label.long().cuda()

                # the attack is run outside the no_grad block below
                adv = nor_adversary.perturb(images, label)

                with torch.no_grad():
                    logits = NorClsfier(Encoder(adv))
                    _, preds = torch.max(logits, 1)
                    running_corrects += torch.sum(preds == label.data)
                    epoch_size += images.size(0)

                    val_loss = criterionCls(logits, label)
                    val_loss_list.append(val_loss.item())

            val_loss_mean = sum(val_loss_list) / len(val_loss_list)
            val_acc = running_corrects.double() / epoch_size
            print('Val Acc: {:.4f}, Val Loss: {:.4f}'.format(
                val_acc, val_loss_mean))

            # log the epoch mean (as printed above), not the last batch's loss
            valinfo = {
                'Val Acc': val_acc.item(),
                'Val Loss': val_loss_mean,
            }
            for tag, value in valinfo.items():
                summary_writer.add_scalar(tag, value, (epoch + 1))

            # ``orig``/``recon`` come from the last training batch of the epoch
            orig_show = vutils.make_grid(orig, normalize=True,
                                         scale_each=True)
            recon_show = vutils.make_grid(recon, normalize=True,
                                          scale_each=True)

            summary_writer.add_image('Ori_Image', orig_show, (epoch + 1))
            summary_writer.add_image('Rec_Image', recon_show, (epoch + 1))

        if (epoch + 1) % args.model_save_epoch == 0:
            mkdir(model_save_path)
            torch.save(
                Encoder.state_dict(),
                os.path.join(model_save_path,
                             "Encoder-{}.pt".format(epoch + 1)))
            torch.save(
                NorClsfier.state_dict(),
                os.path.join(model_save_path,
                             "NorClsfier-{}.pt".format(epoch + 1)))
            torch.save(
                Decoder.state_dict(),
                os.path.join(model_save_path,
                             "Decoder-{}.pt".format(epoch + 1)))

    # make sure the target directory exists even if no periodic save ran
    mkdir(model_save_path)
    torch.save(Encoder.state_dict(),
               os.path.join(model_save_path, "Encoder-final.pt"))
    torch.save(NorClsfier.state_dict(),
               os.path.join(model_save_path, "NorClsfier-final.pt"))
    torch.save(Decoder.state_dict(),
               os.path.join(model_save_path, "Decoder-final.pt"))
def Train(args, FeatExtor, DepthEsmator, FeatEmbder,
          Discriminator1, Discriminator2, Discriminator3,
          PreFeatExtorS1, PreFeatExtorS2, PreFeatExtorS3,
          data_loader1_real, data_loader1_fake,
          data_loader2_real, data_loader2_fake,
          data_loader3_real, data_loader3_fake,
          data_loader_target,
          summary_writer, Saver, savefilename):
    """Adversarially train a shared feature extractor over three domains.

    Each step performs, in order:
      1. a depth-regression update of ``FeatExtor`` + ``DepthEsmator``;
      2. a generator update of ``FeatExtor`` + ``FeatEmbder`` combining the
         triplet loss, the BCE classification loss and the loss of fooling all
         three discriminators;
      3. one update per discriminator, where features from the matching
         pre-trained extractor are the "real" samples and the (detached)
         shared features are the "fake" samples.

    Args:
        args: Namespace providing ``lr_DG_depth``, ``lr_DG_conf``,
            ``lr_critic``, ``beta1``, ``beta2``, ``epochs``, ``W_depth``,
            ``W_genave``, ``W_trip``, ``W_cls``, ``W_gen``, ``batchsize``,
            ``tst_step``, ``log_step``, ``model_save_step``,
            ``model_save_epoch`` and ``results_path``.
        FeatExtor: Returns ``(feat_ext_all, feat_ext)`` for an image batch.
        DepthEsmator: Maps ``feat_ext_all`` to a depth prediction.
        FeatEmbder: Returns ``(feat_embd, label_pred)``.
        Discriminator1..3: Domain critics operating on ``feat_ext``.
        PreFeatExtorS1..3: Frozen (eval mode, no grad) per-domain extractors;
            element ``[1]`` of their output is used as the critic's real input.
        data_loader{1,2,3}_{real,fake}: Loaders yielding
            ``(image, depth, label)`` triples for each source domain.
        data_loader_target: Passed to ``evaluate.evaluate_img``.
        summary_writer: Object with ``add_scalar`` and ``add_image``.
        Saver: Object with ``print_current_errors``; also passed to
            ``evaluate.evaluate_img``.
        savefilename: Sub-directory name for the saved snapshots.
    """
    ####################
    # 1. setup network #
    ####################
    # set train state for Dropout and BN layers
    FeatExtor.train()
    FeatEmbder.train()
    DepthEsmator.train()
    Discriminator1.train()
    Discriminator2.train()
    Discriminator3.train()
    PreFeatExtorS1.eval()
    PreFeatExtorS2.eval()
    PreFeatExtorS3.eval()

    FeatExtor = DataParallel(FeatExtor)
    FeatEmbder = DataParallel(FeatEmbder)
    DepthEsmator = DataParallel(DepthEsmator)
    Discriminator1 = DataParallel(Discriminator1)
    Discriminator2 = DataParallel(Discriminator2)
    Discriminator3 = DataParallel(Discriminator3)
    PreFeatExtorS1 = DataParallel(PreFeatExtorS1)
    PreFeatExtorS2 = DataParallel(PreFeatExtorS2)
    PreFeatExtorS3 = DataParallel(PreFeatExtorS3)

    # setup criterion and optimizer
    criterionDepth = torch.nn.MSELoss()
    criterionAdv = loss.GANLoss()
    criterionCls = torch.nn.BCEWithLogitsLoss()

    adam_betas = (args.beta1, args.beta2)
    optimizer_DG_depth = optim.Adam(
        itertools.chain(FeatExtor.parameters(), DepthEsmator.parameters()),
        lr=args.lr_DG_depth, betas=adam_betas)
    optimizer_DG_conf = optim.Adam(
        itertools.chain(FeatExtor.parameters(), FeatEmbder.parameters()),
        lr=args.lr_DG_conf, betas=adam_betas)
    optimizer_critic1 = optim.Adam(Discriminator1.parameters(),
                                   lr=args.lr_critic, betas=adam_betas)
    optimizer_critic2 = optim.Adam(Discriminator2.parameters(),
                                   lr=args.lr_critic, betas=adam_betas)
    optimizer_critic3 = optim.Adam(Discriminator3.parameters(),
                                   lr=args.lr_critic, betas=adam_betas)

    # (discriminator, its optimizer) in domain order 1, 2, 3
    critics = ((Discriminator1, optimizer_critic1),
               (Discriminator2, optimizer_critic2),
               (Discriminator3, optimizer_critic3))
    trainable = (FeatExtor, FeatEmbder, DepthEsmator,
                 Discriminator1, Discriminator2, Discriminator3)

    iternum = max(len(data_loader1_real), len(data_loader1_fake),
                  len(data_loader2_real), len(data_loader2_fake),
                  len(data_loader3_real), len(data_loader3_fake))
    print('iternum={}'.format(iternum))

    # Resolved up front so the final save below never hits an unbound name
    # when no periodic-save branch has run.
    model_save_path = os.path.join(args.results_path, 'snapshots',
                                   savefilename)
    # (filename prefix, module) in the order the snapshots are written
    checkpoints = (('DGFA-Ext', FeatExtor),
                   ('DGFA-Embd', FeatEmbder),
                   ('DGFA-Depth', DepthEsmator),
                   ('DGFA-D1', Discriminator1),
                   ('DGFA-D2', Discriminator2),
                   ('DGFA-D3', Discriminator3))

    ####################
    # 2. train network #
    ####################
    global_step = 0

    for epoch in range(args.epochs):

        data1_real = get_inf_iterator(data_loader1_real)
        data1_fake = get_inf_iterator(data_loader1_fake)
        data2_real = get_inf_iterator(data_loader2_real)
        data2_fake = get_inf_iterator(data_loader2_fake)
        data3_real = get_inf_iterator(data_loader3_real)
        data3_fake = get_inf_iterator(data_loader3_fake)

        for step in range(iternum):

            # Train mode is re-applied at the start of every step.
            # NOTE(review): presumably because evaluate.evaluate_img switches
            # the models to eval mode -- confirm.
            for module in trainable:
                module.train()

            #============ one batch extraction ============#
            cat_img1_real, depth_img1_real, lab1_real = next(data1_real)
            cat_img1_fake, depth_img1_fake, lab1_fake = next(data1_fake)

            cat_img2_real, depth_img2_real, lab2_real = next(data2_real)
            cat_img2_fake, depth_img2_fake, lab2_fake = next(data2_fake)

            cat_img3_real, depth_img3_real, lab3_real = next(data3_real)
            cat_img3_fake, depth_img3_fake, lab3_fake = next(data3_fake)

            #============ one batch collection ============#
            # within each domain: real samples first, fake samples second
            ori_img1 = torch.cat([cat_img1_real, cat_img1_fake], 0).cuda()
            depth_img1 = torch.cat([depth_img1_real, depth_img1_fake], 0)
            lab1 = torch.cat([lab1_real, lab1_fake], 0)

            ori_img2 = torch.cat([cat_img2_real, cat_img2_fake], 0).cuda()
            depth_img2 = torch.cat([depth_img2_real, depth_img2_fake], 0)
            lab2 = torch.cat([lab2_real, lab2_fake], 0)

            ori_img3 = torch.cat([cat_img3_real, cat_img3_fake], 0).cuda()
            depth_img3 = torch.cat([depth_img3_real, depth_img3_fake], 0)
            lab3 = torch.cat([lab3_real, lab3_fake], 0)

            ori_img = torch.cat([ori_img1, ori_img2, ori_img3], 0)
            depth_GT = torch.cat([depth_img1, depth_img2, depth_img3],
                                 0).cuda()
            label = torch.cat([lab1, lab2, lab3], 0).long().squeeze().cuda()

            # reference features from the frozen per-domain extractors
            with torch.no_grad():
                pre_feats = (PreFeatExtorS1(ori_img1)[1],
                             PreFeatExtorS2(ori_img2)[1],
                             PreFeatExtorS3(ori_img3)[1])

            #============ Depth supervision ============#
            optimizer_DG_depth.zero_grad()
            feat_ext_all, feat_ext = FeatExtor(ori_img)
            depth_Pre = DepthEsmator(feat_ext_all)
            Loss_depth = args.W_depth * criterionDepth(depth_Pre, depth_GT)
            Loss_depth.backward()
            optimizer_DG_depth.step()

            #============ domain generalization supervision ============#
            optimizer_DG_conf.zero_grad()
            # second forward pass: FeatExtor was just updated by the depth step
            _, feat_ext = FeatExtor(ori_img)
            feat_tgt = feat_ext

            # generator side: try to make every critic predict "real"
            gen_losses = [criterionAdv(disc(feat_tgt), True)
                          for disc, _ in critics]

            feat_embd, label_pred = FeatEmbder(feat_ext)

            ########## cross-domain triplet loss #########
            Loss_triplet = TripletLossCal(args, feat_embd, lab1, lab2, lab3)

            Loss_cls = criterionCls(label_pred.squeeze(), label.float())
            Loss_gen = args.W_genave * (
                gen_losses[0] + gen_losses[1] + gen_losses[2])

            Loss_G = (args.W_trip * Loss_triplet + args.W_cls * Loss_cls +
                      args.W_gen * Loss_gen)

            Loss_G.backward()
            optimizer_DG_conf.step()

            #============ critic updates, one per source domain ============#
            critic_losses = []
            for (disc, optimizer_critic), pre_feat in zip(critics,
                                                          pre_feats):
                # repeat the domain's reference features three times, as the
                # original code does before comparing against feat_tgt
                feat_src = torch.cat([pre_feat, pre_feat, pre_feat], 0)

                # zero_grad also clears what Loss_G.backward() left behind
                optimizer_critic.zero_grad()
                real_loss = criterionAdv(disc(feat_src), True)
                fake_loss = criterionAdv(disc(feat_tgt.detach()), False)
                loss_critic = 0.5 * (real_loss + fake_loss)
                loss_critic.backward()
                optimizer_critic.step()

                critic_losses.append(loss_critic)

            #============ tensorboard the log info ============#
            info = OrderedDict([
                ('Loss_depth', Loss_depth.item()),
                ('Loss_triplet', Loss_triplet.item()),
                ('Loss_cls', Loss_cls.item()),
                ('Loss_G', Loss_G.item()),
            ])
            for n, (l_critic, l_gen) in enumerate(
                    zip(critic_losses, gen_losses), start=1):
                info['loss_critic{}'.format(n)] = l_critic.item()
                info['loss_generator{}'.format(n)] = l_gen.item()

            for tag, value in info.items():
                summary_writer.add_scalar(tag, value, global_step)

            if (step + 1) % args.tst_step == 0:
                bs = args.batchsize
                # assumes every loader yields exactly ``args.batchsize``
                # samples, so the batch is laid out as
                # [real1, fake1, real2, fake2, real3, fake3] -- TODO confirm
                depth_Pre_real = torch.cat([
                    depth_Pre[0:bs],
                    depth_Pre[2 * bs:3 * bs],
                    depth_Pre[4 * bs:5 * bs],
                ], 0)
                depth_Pre_fake = torch.cat([
                    depth_Pre[bs:2 * bs],
                    depth_Pre[3 * bs:4 * bs],
                    depth_Pre[5 * bs:6 * bs],
                ], 0)

                depth_Pre_all = vutils.make_grid(
                    depth_Pre, normalize=True, scale_each=True)
                depth_Pre_real = vutils.make_grid(
                    depth_Pre_real, normalize=True, scale_each=True)
                depth_Pre_fake = vutils.make_grid(
                    depth_Pre_fake, normalize=True, scale_each=True)

                summary_writer.add_image('Depth_Image_all', depth_Pre_all,
                                         global_step)
                summary_writer.add_image('Depth_Image_real', depth_Pre_real,
                                         global_step)
                summary_writer.add_image('Depth_Image_fake', depth_Pre_fake,
                                         global_step)

            #============ print the log info ============#
            if (step + 1) % args.log_step == 0:
                Saver.print_current_errors((epoch + 1), (step + 1),
                                           OrderedDict(info))

            if (step + 1) % args.tst_step == 0:
                evaluate.evaluate_img(FeatExtor, DepthEsmator,
                                      data_loader_target,
                                      (epoch + 1), (step + 1), Saver)

            global_step += 1

            #############################
            # 2.4 save model parameters #
            #############################
            if (step + 1) % args.model_save_step == 0:
                mkdir(model_save_path)
                for prefix, model in checkpoints:
                    torch.save(
                        model.state_dict(),
                        os.path.join(
                            model_save_path,
                            "{}-{}-{}.pt".format(prefix, epoch + 1,
                                                 step + 1)))

        if (epoch + 1) % args.model_save_epoch == 0:
            mkdir(model_save_path)
            for prefix, model in checkpoints:
                torch.save(
                    model.state_dict(),
                    os.path.join(model_save_path,
                                 "{}-{}.pt".format(prefix, epoch + 1)))

    # make sure the target directory exists even if no periodic save ran
    mkdir(model_save_path)
    for prefix, model in checkpoints:
        torch.save(model.state_dict(),
                   os.path.join(model_save_path,
                                "{}-final.pt".format(prefix)))