# --- Example #1 (0) ---
def SaveEvaluation(args, known_acc, auc):
    """Write close-set accuracy and open-set AUROC to a text report.

    The report is written to
    ``results/Test/accuracy/<datasetname>-<split>/
    adv-<adv>-defense-<defense>-<denoisemean>-<defensesnapshot>.txt``
    and overwrites any existing file at that path.

    Args:
        args: Run configuration; reads ``datasetname``, ``split``, ``adv``,
            ``defense``, ``denoisemean`` and ``defensesnapshot``.
        known_acc: Close-set accuracy; presumably a torch tensor (it is moved
            with ``.cpu()`` and converted to a NumPy array before printing).
        auc: Open-set AUROC value, written via ``str()``.
    """
    filefolder = osp.join('results', 'Test', 'accuracy',
                          args.datasetname + '-' + args.split)
    mkdir(filefolder)

    filepath = osp.join(
        filefolder, 'adv-' + str(args.adv) + '-defense-' + str(args.defense) +
        '-' + args.denoisemean + '-' + str(args.defensesnapshot) + '.txt')

    # The context manager closes the file even if converting/writing fails.
    with open(filepath, 'w') as output_file:
        output_file.write('Close-set Accuracy:\n' +
                          str(np.array(known_acc.cpu())))
        output_file.write('\nOpen-set AUROC:\n' + str(auc))
# --- Example #2 (0) ---
def Test(args, FeatExtor, DepthEsmator, FeatEmbder, data_loader_target,
         savefilename):
    """Score every batch of ``data_loader_target`` and dump results to HDF5.

    Each batch of images goes through FeatExtor and FeatEmbder. The second
    output of FeatEmbder is mapped through a sigmoid to get per-sample scores.
    Scores and labels are collected per batch and written to
    ``<args.results_path>/<savefilename>/Test_data.h5`` as the datasets
    ``score`` and ``label``.

    Args:
        args: Run configuration; ``args.results_path`` is read.
        FeatExtor: Feature extractor returning a pair; its second element is
            fed to FeatEmbder.
        DepthEsmator: Only switched to eval mode and wrapped here; it is not
            used for scoring.
        FeatEmbder: Classifier returning a pair whose second element holds
            the logits.
        data_loader_target: Iterable of ``(images, labels)`` batches.
        savefilename: Sub-directory of ``args.results_path`` for the output.
    """
    import torch  # local import: needed for no_grad / sigmoid

    # NOTE(review): ``normtype`` is not defined in this function; it must be a
    # module-level global, otherwise this line raises NameError -- confirm.
    print("***The type of norm is: {}".format(normtype))

    savepath = os.path.join(args.results_path, savefilename)
    mkdir(savepath)
    ####################
    # 1. setup network #
    ####################
    # Put Dropout and BN layers into inference mode.
    FeatExtor.eval()
    DepthEsmator.eval()
    FeatEmbder.eval()

    FeatExtor = DataParallel(FeatExtor)
    DepthEsmator = DataParallel(DepthEsmator)
    FeatEmbder = DataParallel(FeatEmbder)

    score_list = []
    label_list = []

    # Pure inference: disable autograd so no graph is built for each batch.
    with torch.no_grad():
        for idx, (catimages, labels) in enumerate(data_loader_target):
            images = catimages.cuda()

            _, feat_ext = FeatExtor(images)
            _, label_pred = FeatEmbder(feat_ext)
            # torch.sigmoid replaces the deprecated F.sigmoid (same result).
            score = torch.sigmoid(label_pred).cpu().numpy()

            score_list.append(score.squeeze())
            label_list.append(labels.numpy())

            print('SampleNum:{} in total:{}, score:{}'.format(
                idx, len(data_loader_target), score.squeeze()))

    with h5py.File(os.path.join(savepath, 'Test_data.h5'), 'w') as hf:
        hf.create_dataset('score', data=score_list)
        hf.create_dataset('label', data=label_list)
# --- Example #3 (0) ---
def Train(args, FeatExtor, DepthEstor, FeatEmbder, data_loader1_real,
          data_loader1_fake, data_loader2_real, data_loader2_fake,
          data_loader3_real, data_loader3_fake, data_loader_target,
          summary_writer, Saver, savefilename):
    """Meta-learning training of FeatExtor / DepthEstor / FeatEmbder.

    Every step draws one real and one fake batch from each of three source
    domains and shuffles the domain order. The first ``args.metatrainsize``
    domains are used for meta-training and the first remaining domain for
    meta-testing:

    * meta-train: for each meta-train domain, compute the classification and
      depth losses, then take one differentiable inner SGD step
      (``args.meta_step_size``) on FeatEmbder's parameters;
    * meta-test: evaluate every adapted FeatEmbder on the meta-test domain.

    The loss ``meta_train + args.W_metatest * meta_test`` (each part being
    ``cls + args.W_depth * depth``) updates all three networks through one
    Adam optimizer.

    Snapshots go to ``<args.results_path>/snapshots/<savefilename>``:
    ``<Net>-<epoch>-<step>.pt`` every ``args.model_save_step`` steps,
    ``<Net>-<epoch>.pt`` every ``args.model_save_epoch`` epochs, and
    ``<Net>-final.pt`` once training ends.

    Args:
        args: Run configuration (optimizer, learning rates, loss weights,
            epochs, logging/saving intervals, ``metatrainsize``, ...).
        FeatExtor, DepthEstor, FeatEmbder: Networks to train. FeatEmbder must
            provide ``cloned_state_dict()`` and accept a state dict as an
            optional second forward argument (used for adapted weights).
        data_loader{1,2,3}_{real,fake}: Per-domain loaders yielding
            ``(image, depth_map, label)`` batches.
        data_loader_target: Not used here; kept for interface compatibility.
        summary_writer: Writer providing ``add_scalar``.
        Saver: Object providing ``print_current_errors``.
        savefilename: Snapshot sub-directory name.

    Raises:
        NotImplementedError: If ``args.optimizer_meta`` is not ``'adam'``.
    """
    ####################
    # 1. setup network #
    ####################
    # set train state for Dropout and BN layers
    FeatExtor.train()
    DepthEstor.train()
    FeatEmbder.train()

    # FeatEmbder stays unwrapped, presumably so that cloned_state_dict() and
    # the state-dict forward can be called on it directly -- confirm.
    FeatExtor = DataParallel(FeatExtor)
    DepthEstor = DataParallel(DepthEstor)

    # setup criterion and optimizer
    criterionCls = nn.BCEWithLogitsLoss()
    criterionDepth = torch.nn.MSELoss()

    # Compare by value: `is` on strings tests object identity and only matches
    # by accident of interning.
    if args.optimizer_meta == 'adam':
        optimizer_all = optim.Adam(itertools.chain(FeatExtor.parameters(),
                                                   DepthEstor.parameters(),
                                                   FeatEmbder.parameters()),
                                   lr=args.lr_meta,
                                   betas=(args.beta1, args.beta2))
    else:
        raise NotImplementedError('Not a suitable optimizer')

    iternum = max(len(data_loader1_real), len(data_loader1_fake),
                  len(data_loader2_real), len(data_loader2_fake),
                  len(data_loader3_real), len(data_loader3_fake))

    print('iternum={}'.format(iternum))

    # Resolve the snapshot directory up front. Otherwise the final save below
    # hits a NameError when no periodic snapshot was ever taken.
    model_save_path = os.path.join(args.results_path, 'snapshots',
                                   savefilename)
    mkdir(model_save_path)

    def _save_models(suffix):
        # Save all three networks as <Name>-<suffix>.pt.
        for name, net in (('FeatExtor', FeatExtor),
                          ('FeatEmbder', FeatEmbder),
                          ('DepthEstor', DepthEstor)):
            torch.save(net.state_dict(),
                       os.path.join(model_save_path,
                                    '{}-{}.pt'.format(name, suffix)))

    def _shuffled(imgs, labs, deps):
        # Apply one random permutation jointly to images, labels, depth maps.
        batchidx = list(range(len(imgs)))
        random.shuffle(batchidx)
        return imgs[batchidx, :], labs[batchidx], deps[batchidx, :]

    ####################
    # 2. train network #
    ####################
    global_step = 0

    for epoch in range(args.epochs):

        data1_real = get_inf_iterator(data_loader1_real)
        data1_fake = get_inf_iterator(data_loader1_fake)

        data2_real = get_inf_iterator(data_loader2_real)
        data2_fake = get_inf_iterator(data_loader2_fake)

        data3_real = get_inf_iterator(data_loader3_real)
        data3_fake = get_inf_iterator(data_loader3_fake)

        for step in range(iternum):

            #============ one batch extraction ============#

            cat_img1_real, depth_img1_real, lab1_real = next(data1_real)
            cat_img1_fake, depth_img1_fake, lab1_fake = next(data1_fake)

            cat_img2_real, depth_img2_real, lab2_real = next(data2_real)
            cat_img2_fake, depth_img2_fake, lab2_fake = next(data2_fake)

            cat_img3_real, depth_img3_real, lab3_real = next(data3_real)
            cat_img3_fake, depth_img3_fake, lab3_fake = next(data3_fake)

            #============ one batch collection ============#

            catimg1 = torch.cat([cat_img1_real, cat_img1_fake], 0).cuda()
            depth_img1 = torch.cat([depth_img1_real, depth_img1_fake],
                                   0).cuda()
            lab1 = torch.cat([lab1_real, lab1_fake], 0).float().cuda()

            catimg2 = torch.cat([cat_img2_real, cat_img2_fake], 0).cuda()
            depth_img2 = torch.cat([depth_img2_real, depth_img2_fake],
                                   0).cuda()
            lab2 = torch.cat([lab2_real, lab2_fake], 0).float().cuda()

            catimg3 = torch.cat([cat_img3_real, cat_img3_fake], 0).cuda()
            depth_img3 = torch.cat([depth_img3_real, depth_img3_fake],
                                   0).cuda()
            lab3 = torch.cat([lab3_real, lab3_fake], 0).float().cuda()

            #============ domain list augmentation ============#
            catimglist = [catimg1, catimg2, catimg3]
            lablist = [lab1, lab2, lab3]
            deplist = [depth_img1, depth_img2, depth_img3]

            domain_list = list(range(len(catimglist)))
            random.shuffle(domain_list)

            meta_train_list = domain_list[:args.metatrainsize]
            meta_test_list = domain_list[args.metatrainsize:]
            print('metatrn={}, metatst={}'.format(meta_train_list,
                                                  meta_test_list[0]))

            #============ meta training ============#

            Loss_dep_train = 0.0
            Loss_cls_train = 0.0

            adapted_state_dicts = []

            for index in meta_train_list:
                img_rand, lab_rand, depGT_rand = _shuffled(
                    catimglist[index], lablist[index], deplist[index])

                feat_ext_all, feat = FeatExtor(img_rand)
                pred = FeatEmbder(feat)
                depth_Pre = DepthEstor(feat_ext_all)

                Loss_cls = criterionCls(pred.squeeze(), lab_rand)
                Loss_dep = criterionDepth(depth_Pre, depGT_rand)

                Loss_dep_train += Loss_dep
                Loss_cls_train += Loss_cls

                # Inner step: one SGD update of FeatEmbder on this domain.
                # create_graph=True keeps it differentiable so the meta-test
                # loss can back-propagate through the adapted weights.
                zero_param_grad(FeatEmbder.parameters())
                grads_FeatEmbder = torch.autograd.grad(Loss_cls,
                                                       FeatEmbder.parameters(),
                                                       create_graph=True)
                fast_weights_FeatEmbder = FeatEmbder.cloned_state_dict()

                for (key, val), grad in zip(FeatEmbder.named_parameters(),
                                            grads_FeatEmbder):
                    fast_weights_FeatEmbder[key] = (
                        val - args.meta_step_size * grad)

                adapted_state_dicts.append(fast_weights_FeatEmbder)

            #============ meta testing ============#
            index = meta_test_list[0]
            img_rand, lab_rand, depGT_rand = _shuffled(
                catimglist[index], lablist[index], deplist[index])

            feat_ext_all, feat = FeatExtor(img_rand)
            depth_Pre = DepthEstor(feat_ext_all)
            Loss_dep_test = criterionDepth(depth_Pre, depGT_rand)

            # Evaluate every adapted classifier on the held-out domain.
            Loss_cls_test = 0.0
            for a_dict in adapted_state_dicts:
                pred = FeatEmbder(feat, a_dict)
                Loss_cls_test += criterionCls(pred.squeeze(), lab_rand)

            Loss_dep_train_ave = Loss_dep_train / len(meta_train_list)

            Loss_meta_train = Loss_cls_train + args.W_depth * Loss_dep_train
            Loss_meta_test = Loss_cls_test + args.W_depth * Loss_dep_test

            Loss_all = Loss_meta_train + args.W_metatest * Loss_meta_test

            optimizer_all.zero_grad()
            Loss_all.backward()
            optimizer_all.step()

            # Scalars shared by the console log and the tensorboard log.
            info = OrderedDict([
                ('Loss_meta_train', Loss_meta_train.item()),
                ('Loss_meta_test', Loss_meta_test.item()),
                ('Loss_cls_train', Loss_cls_train.item()),
                ('Loss_cls_test', Loss_cls_test.item()),
                ('Loss_dep_train_ave', Loss_dep_train_ave.item()),
                ('Loss_dep_test', Loss_dep_test.item()),
            ])

            if (step + 1) % args.log_step == 0:
                Saver.print_current_errors((epoch + 1), (step + 1), info)

            #============ tensorboard the log info ============#
            for tag, value in info.items():
                summary_writer.add_scalar(tag, value, global_step)

            global_step += 1

            #############################
            # 2.4 save model parameters #
            #############################
            if (step + 1) % args.model_save_step == 0:
                _save_models('{}-{}'.format(epoch + 1, step + 1))

        if (epoch + 1) % args.model_save_epoch == 0:
            _save_models(epoch + 1)

    _save_models('final')
# --- Example #4 (0) ---
def Pre_train(args, FeatExtor, DepthEsmator, data_loader_real,
              data_loader_fake, summary_writer, saver, savefilename):
    """Pre-train FeatExtor and DepthEsmator on depth-map regression.

    Each step concatenates one real and one fake ``(image, depth, label)``
    batch; the labels are not used. It predicts depth maps from FeatExtor's
    first output and minimises their MSE against the ground-truth depth with
    Adam (``args.lr_DG_depth``).

    Snapshots ``DGFA-Ext-<epoch>.pt`` / ``DGFA-Depth-<epoch>.pt`` are written
    every ``args.model_save_epoch`` epochs, and ``DGFA-*-final.pt`` after
    training. All snapshots go to
    ``<args.results_path>/snapshots/<savefilename>``.

    Args:
        args: Run configuration (``pre_epochs``, ``lr_DG_depth``, ``beta1``,
            ``beta2``, ``log_step``, ``model_save_epoch``, ``results_path``).
        FeatExtor, DepthEsmator: Networks to train.
        data_loader_real, data_loader_fake: Loaders yielding
            ``(image, depth_map, label)`` batches.
        summary_writer: Writer providing ``add_scalar``.
        saver: Object providing ``print_current_errors``.
        savefilename: Snapshot sub-directory name.
    """
    ####################
    # 1. setup network #
    ####################
    # set train state for Dropout and BN layers
    FeatExtor.train()
    DepthEsmator.train()

    FeatExtor = DataParallel(FeatExtor)
    DepthEsmator = DataParallel(DepthEsmator)

    criterionDepth = torch.nn.MSELoss()

    optimizer_DG_depth = optim.Adam(list(FeatExtor.parameters()) +
                                    list(DepthEsmator.parameters()),
                                    lr=args.lr_DG_depth,
                                    betas=(args.beta1, args.beta2))

    iternum = max(len(data_loader_real), len(data_loader_fake))

    print('iternum={}'.format(iternum))

    # Resolve the snapshot directory once. Otherwise the final save below
    # raises NameError when no periodic snapshot was taken (e.g.
    # pre_epochs < model_save_epoch).
    model_save_path = os.path.join(args.results_path, 'snapshots',
                                   savefilename)
    mkdir(model_save_path)

    ####################
    # 2. train network #
    ####################
    global_step = 0

    for epoch in range(args.pre_epochs):

        data_real = get_inf_iterator(data_loader_real)
        data_fake = get_inf_iterator(data_loader_fake)

        for step in range(iternum):

            cat_img_real, depth_img_real, lab_real = next(data_real)
            cat_img_fake, depth_img_fake, lab_fake = next(data_fake)

            ori_img = torch.cat([cat_img_real, cat_img_fake], 0).cuda()
            depth_img = torch.cat([depth_img_real, depth_img_fake], 0).cuda()

            feat_ext, _ = FeatExtor(ori_img)
            depth_Pre = DepthEsmator(feat_ext)

            Loss_depth = criterionDepth(depth_Pre, depth_img)

            optimizer_DG_depth.zero_grad()
            Loss_depth.backward()
            optimizer_DG_depth.step()

            info = {
                'Loss_depth': Loss_depth.item(),
            }
            for tag, value in info.items():
                summary_writer.add_scalar(tag, value, global_step)

            #============ print the log info ============#
            if (step + 1) % args.log_step == 0:
                errors = OrderedDict([('Loss_depth', Loss_depth.item())])
                saver.print_current_errors((epoch + 1), (step + 1), errors)

            global_step += 1

        if (epoch + 1) % args.model_save_epoch == 0:
            torch.save(
                FeatExtor.state_dict(),
                os.path.join(model_save_path,
                             "DGFA-Ext-{}.pt".format(epoch + 1)))
            torch.save(
                DepthEsmator.state_dict(),
                os.path.join(model_save_path,
                             "DGFA-Depth-{}.pt".format(epoch + 1)))

    torch.save(FeatExtor.state_dict(),
               os.path.join(model_save_path, "DGFA-Ext-final.pt"))
    torch.save(DepthEsmator.state_dict(),
               os.path.join(model_save_path, "DGFA-Depth-final.pt"))
# --- Example #5 (0) ---
def train_Ours(args, train_loader, val_loader, knownclass, Encoder, Decoder,
               NorClsfier, SSDClsfier, summary_writer, saver):
    """Train the encoder and its heads on adversarially perturbed inputs.

    For every training batch, adversarial versions of ``orig`` (against
    NorClsfier) and ``rot_orig`` (against SSDClsfier) are generated. The
    attack is PGD or FGSM, selected by ``args.adv``. The loss is::

        norClsWgt * CE(NorClsfier) + rotClsWgt * CE(SSDClsfier)
            + RecWgt * MSE(Decoder(Encoder(adv)), orig)

    Every ``args.val_epoch`` epochs, accuracy and mean loss are computed on
    adversarial validation inputs. Encoder, NorClsfier and Decoder are saved
    every ``args.model_save_epoch`` epochs and again after training.

    Args:
        args: Run configuration (attack type, loss weights, epochs, logging
            and saving intervals, output paths, ...).
        train_loader: Yields ``(orig, label, rot_orig, rot_label)``.
        val_loader: Yields ``(images, label)``.
        knownclass: Passed to ``lab_conv`` to remap raw labels.
        Encoder, Decoder, NorClsfier, SSDClsfier: Networks to train.
        summary_writer: Writer providing ``add_scalar`` / ``add_image``.
        saver: Object providing ``print_current_errors``.

    Raises:
        NotImplementedError: If ``args.adv`` is neither ``'PGDattack'`` nor
            ``'FGSMattack'``.
    """
    init_random_seed(args.manual_seed)

    criterionCls = nn.CrossEntropyLoss()
    criterionRec = nn.MSELoss()

    if args.parallel_train:
        Encoder = DataParallel(Encoder)
        Decoder = DataParallel(Decoder)
        NorClsfier = DataParallel(NorClsfier)
        SSDClsfier = DataParallel(SSDClsfier)

    optimizer = optim.Adam(
        list(Encoder.parameters()) + list(NorClsfier.parameters()) +
        list(SSDClsfier.parameters()) + list(Decoder.parameters()),
        lr=args.lr)

    # Compare attack names by value (`is` on strings tests identity). Fail
    # fast on unknown names instead of hitting a NameError at perturb().
    if args.adv == 'PGDattack':
        print("**********Defense PGD Attack**********")
        from advertorch.attacks import PGDAttack
        nor_adversary = PGDAttack(predict1=Encoder,
                                  predict2=NorClsfier,
                                  nb_iter=args.adv_iter)
        rot_adversary = PGDAttack(predict1=Encoder,
                                  predict2=SSDClsfier,
                                  nb_iter=args.adv_iter)
    elif args.adv == 'FGSMattack':
        print("**********Defense FGSM Attack**********")
        from advertorch.attacks import GradientSignAttack
        nor_adversary = GradientSignAttack(predict1=Encoder,
                                           predict2=NorClsfier)
        rot_adversary = GradientSignAttack(predict1=Encoder,
                                           predict2=SSDClsfier)
    else:
        raise NotImplementedError(
            'Unsupported adversary: {}'.format(args.adv))

    # Resolve the snapshot directory once so the final save after training
    # works even if no periodic snapshot was taken.
    model_save_path = os.path.join(args.results_path,
                                   args.training_type, 'snapshots',
                                   args.datasetname + '-' + args.split,
                                   args.denoisemean,
                                   args.adv + str(args.adv_iter))
    mkdir(model_save_path)

    global_step = 0
    # ----------
    #  Training
    # ----------
    for epoch in range(args.n_epoch):

        Encoder.train()
        Decoder.train()
        NorClsfier.train()
        SSDClsfier.train()

        for steps, (orig, label, rot_orig,
                    rot_label) in enumerate(train_loader):

            label = lab_conv(knownclass, label)
            orig, label = orig.cuda(), label.long().cuda()

            rot_orig, rot_label = rot_orig.cuda(), rot_label.long().cuda()

            # Generate attacks with the models in eval mode and without
            # accumulating parameter gradients.
            with ctx_noparamgrad_and_eval(Encoder):
                with ctx_noparamgrad_and_eval(NorClsfier):
                    with ctx_noparamgrad_and_eval(SSDClsfier):
                        adv = nor_adversary.perturb(orig, label)
                        rot_adv = rot_adversary.perturb(rot_orig, rot_label)

            latent_feat = Encoder(adv)
            norpred = NorClsfier(latent_feat)
            norlossCls = criterionCls(norpred, label)

            # Reconstruct the clean image from the adversarial features.
            recon = Decoder(latent_feat)
            lossRec = criterionRec(recon, orig)

            ssdpred = SSDClsfier(Encoder(rot_adv))
            rotlossCls = criterionCls(ssdpred, rot_label)

            loss = (args.norClsWgt * norlossCls +
                    args.rotClsWgt * rotlossCls +
                    args.RecWgt * lossRec)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            #============ tensorboard the log info ============#
            # Previously built but never written; now logged per step.
            lossinfo = OrderedDict([
                ('loss', loss.item()),
                ('norlossCls', norlossCls.item()),
                ('lossRec', lossRec.item()),
                ('rotlossCls', rotlossCls.item()),
            ])
            for tag, value in lossinfo.items():
                summary_writer.add_scalar(tag, value, global_step)

            global_step += 1

            #============ print the log info ============#
            if (steps + 1) % args.log_step == 0:
                saver.print_current_errors((epoch + 1), (steps + 1), lossinfo)

        # evaluate performance on validation set periodically
        if (epoch + 1) % args.val_epoch == 0:

            # switch model to evaluation mode
            Encoder.eval()
            NorClsfier.eval()

            running_corrects = 0.0
            epoch_size = 0.0
            val_loss_list = []

            # calculate accuracy on (adversarial) validation set
            for steps, (images, label) in enumerate(val_loader):

                label = lab_conv(knownclass, label)
                images, label = images.cuda(), label.long().cuda()

                # The attack needs input gradients, so it runs outside no_grad.
                adv = nor_adversary.perturb(images, label)

                with torch.no_grad():
                    logits = NorClsfier(Encoder(adv))
                    _, preds = torch.max(logits, 1)
                    running_corrects += torch.sum(preds == label.data)
                    epoch_size += images.size(0)

                    val_loss = criterionCls(logits, label)

                    val_loss_list.append(val_loss.item())

            val_loss_mean = sum(val_loss_list) / len(val_loss_list)

            val_acc = running_corrects.double() / epoch_size
            print('Val Acc: {:.4f}, Val Loss: {:.4f}'.format(
                val_acc, val_loss_mean))

            # Log the mean validation loss (matches the printed value), not
            # just the last batch's loss.
            valinfo = {
                'Val Acc': val_acc.item(),
                'Val Loss': val_loss_mean,
            }
            for tag, value in valinfo.items():
                summary_writer.add_scalar(tag, value, (epoch + 1))

            # Images come from the last training batch of this epoch.
            orig_show = vutils.make_grid(orig, normalize=True, scale_each=True)
            recon_show = vutils.make_grid(recon,
                                          normalize=True,
                                          scale_each=True)

            summary_writer.add_image('Ori_Image', orig_show, (epoch + 1))
            summary_writer.add_image('Rec_Image', recon_show, (epoch + 1))

        if (epoch + 1) % args.model_save_epoch == 0:
            torch.save(
                Encoder.state_dict(),
                os.path.join(model_save_path,
                             "Encoder-{}.pt".format(epoch + 1)))
            torch.save(
                NorClsfier.state_dict(),
                os.path.join(model_save_path,
                             "NorClsfier-{}.pt".format(epoch + 1)))
            torch.save(
                Decoder.state_dict(),
                os.path.join(model_save_path,
                             "Decoder-{}.pt".format(epoch + 1)))

    torch.save(Encoder.state_dict(),
               os.path.join(model_save_path, "Encoder-final.pt"))
    torch.save(NorClsfier.state_dict(),
               os.path.join(model_save_path, "NorClsfier-final.pt"))
    torch.save(Decoder.state_dict(),
               os.path.join(model_save_path, "Decoder-final.pt"))
# --- Example #6 (0) ---
def Train(args, FeatExtor, DepthEsmator, FeatEmbder, Discriminator1,
          Discriminator2, Discriminator3, PreFeatExtorS1, PreFeatExtorS2,
          PreFeatExtorS3, data_loader1_real, data_loader1_fake,
          data_loader2_real, data_loader2_fake, data_loader3_real,
          data_loader3_fake, data_loader_target, summary_writer, Saver,
          savefilename):
    """Adversarial multi-domain training with depth and classification losses.

    Every step draws one real and one fake batch from each of three source
    domains and then performs, in order:

    1. a depth update of FeatExtor + DepthEsmator: MSE against the depth
       maps, scaled by ``args.W_depth``;
    2. a feature update of FeatExtor + FeatEmbder on
       ``W_trip * triplet + W_cls * BCE + W_gen * W_genave * generator``,
       where the generator terms ask each of Discriminator1..3 to label the
       current features as ``True``;
    3. one update per DiscriminatorK. Each is trained to label features from
       PreFeatExtorSK as ``True`` and the detached current FeatExtor features
       as ``False``.

    PreFeatExtorS1..3 are set to eval mode, run under ``torch.no_grad()`` and
    are not part of any optimizer here.

    Scalar losses are written to ``summary_writer`` every step and printed
    every ``args.log_step`` steps. Every ``args.tst_step`` steps, predicted
    depth maps are logged as images and ``evaluate.evaluate_img`` is run on
    ``data_loader_target``.

    Args:
        args: Run configuration (learning rates, Adam betas, loss weights,
            ``epochs``, ``batchsize``, ``log_step``, ``tst_step``).
        FeatExtor, DepthEsmator, FeatEmbder: Networks being trained.
        Discriminator1, Discriminator2, Discriminator3: Per-domain critics.
        PreFeatExtorS1, PreFeatExtorS2, PreFeatExtorS3: Pre-trained
            extractors supplying the "real" features for each critic.
        data_loader{1,2,3}_{real,fake}: Per-domain loaders yielding
            ``(image, depth_map, label)`` batches.
        data_loader_target: Loader passed to ``evaluate.evaluate_img``.
        summary_writer: Writer providing ``add_scalar`` / ``add_image``.
        Saver: Object providing ``print_current_errors``; also passed to
            ``evaluate.evaluate_img``.
        savefilename: Not used in the lines shown here.
    """

    ####################
    # 1. setup network #
    ####################
    # set train state for Dropout and BN layers
    FeatExtor.train()
    FeatEmbder.train()
    DepthEsmator.train()
    Discriminator1.train()
    Discriminator2.train()
    Discriminator3.train()

    # Pre-trained extractors are only used for inference (see no_grad below).
    PreFeatExtorS1.eval()
    PreFeatExtorS2.eval()
    PreFeatExtorS3.eval()

    FeatExtor = DataParallel(FeatExtor)
    FeatEmbder = DataParallel(FeatEmbder)
    DepthEsmator = DataParallel(DepthEsmator)
    Discriminator1 = DataParallel(Discriminator1)
    Discriminator2 = DataParallel(Discriminator2)
    Discriminator3 = DataParallel(Discriminator3)

    PreFeatExtorS1 = DataParallel(PreFeatExtorS1)
    PreFeatExtorS2 = DataParallel(PreFeatExtorS2)
    PreFeatExtorS3 = DataParallel(PreFeatExtorS3)

    # setup criterion and optimizer
    criterionDepth = torch.nn.MSELoss()
    criterionAdv = loss.GANLoss()
    criterionCls = torch.nn.BCEWithLogitsLoss()

    # NOTE: FeatExtor's parameters appear in both optimizer_DG_depth and
    # optimizer_DG_conf, so it is updated twice per step.
    optimizer_DG_depth = optim.Adam(itertools.chain(FeatExtor.parameters(),
                                                    DepthEsmator.parameters()),
                                    lr=args.lr_DG_depth,
                                    betas=(args.beta1, args.beta2))

    optimizer_DG_conf = optim.Adam(itertools.chain(FeatExtor.parameters(),
                                                   FeatEmbder.parameters()),
                                   lr=args.lr_DG_conf,
                                   betas=(args.beta1, args.beta2))

    optimizer_critic1 = optim.Adam(Discriminator1.parameters(),
                                   lr=args.lr_critic,
                                   betas=(args.beta1, args.beta2))

    optimizer_critic2 = optim.Adam(Discriminator2.parameters(),
                                   lr=args.lr_critic,
                                   betas=(args.beta1, args.beta2))

    optimizer_critic3 = optim.Adam(Discriminator3.parameters(),
                                   lr=args.lr_critic,
                                   betas=(args.beta1, args.beta2))

    iternum = max(len(data_loader1_real), len(data_loader1_fake),
                  len(data_loader2_real), len(data_loader2_fake),
                  len(data_loader3_real), len(data_loader3_fake))

    print('iternum={}'.format(iternum))

    ####################
    # 2. train network #
    ####################
    global_step = 0

    for epoch in range(args.epochs):

        data1_real = get_inf_iterator(data_loader1_real)
        data1_fake = get_inf_iterator(data_loader1_fake)

        data2_real = get_inf_iterator(data_loader2_real)
        data2_fake = get_inf_iterator(data_loader2_fake)

        data3_real = get_inf_iterator(data_loader3_real)
        data3_fake = get_inf_iterator(data_loader3_fake)

        for step in range(iternum):

            # Train mode is re-applied every step, presumably because
            # evaluate.evaluate_img (called below) may switch modules to
            # eval -- confirm.
            FeatExtor.train()
            FeatEmbder.train()
            DepthEsmator.train()
            Discriminator1.train()
            Discriminator2.train()
            Discriminator3.train()

            #============ one batch extraction ============#

            cat_img1_real, depth_img1_real, lab1_real = next(data1_real)
            cat_img1_fake, depth_img1_fake, lab1_fake = next(data1_fake)

            cat_img2_real, depth_img2_real, lab2_real = next(data2_real)
            cat_img2_fake, depth_img2_fake, lab2_fake = next(data2_fake)

            cat_img3_real, depth_img3_real, lab3_real = next(data3_real)
            cat_img3_fake, depth_img3_fake, lab3_fake = next(data3_fake)

            #============ one batch collection ============#
            # Resulting layout along dim 0:
            # [real1, fake1, real2, fake2, real3, fake3].

            ori_img1 = torch.cat([cat_img1_real, cat_img1_fake], 0).cuda()
            depth_img1 = torch.cat([depth_img1_real, depth_img1_fake], 0)
            lab1 = torch.cat([lab1_real, lab1_fake], 0)

            ori_img2 = torch.cat([cat_img2_real, cat_img2_fake], 0).cuda()
            depth_img2 = torch.cat([depth_img2_real, depth_img2_fake], 0)
            lab2 = torch.cat([lab2_real, lab2_fake], 0)

            ori_img3 = torch.cat([cat_img3_real, cat_img3_fake], 0).cuda()
            depth_img3 = torch.cat([depth_img3_real, depth_img3_fake], 0)
            lab3 = torch.cat([lab3_real, lab3_fake], 0)

            ori_img = torch.cat([ori_img1, ori_img2, ori_img3], 0)
            # ori_img = ori_img.cuda()

            depth_GT = torch.cat([depth_img1, depth_img2, depth_img3], 0)
            depth_GT = depth_GT.cuda()

            label = torch.cat([lab1, lab2, lab3], 0)
            label = label.long().squeeze().cuda()

            # Reference features from the frozen per-domain extractors
            # (second element of each extractor's output).
            with torch.no_grad():
                pre_feat_ext1 = PreFeatExtorS1(ori_img1)[1]
                pre_feat_ext2 = PreFeatExtorS2(ori_img2)[1]
                pre_feat_ext3 = PreFeatExtorS3(ori_img3)[1]

            #============ Depth supervision ============#

            ######### 1. depth loss #########
            optimizer_DG_depth.zero_grad()

            feat_ext_all, feat_ext = FeatExtor(ori_img)
            depth_Pre = DepthEsmator(feat_ext_all)

            Loss_depth = args.W_depth * criterionDepth(depth_Pre, depth_GT)

            Loss_depth.backward()
            optimizer_DG_depth.step()

            #============ domain generalization supervision ============#

            optimizer_DG_conf.zero_grad()

            # Recompute features after the depth update.
            _, feat_ext = FeatExtor(ori_img)

            feat_tgt = feat_ext

            #************************* confusion all **********************************#

            # predict on generator
            # Generator side: push all three critics to label feat_tgt True.
            loss_generator1 = criterionAdv(Discriminator1(feat_tgt), True)

            loss_generator2 = criterionAdv(Discriminator2(feat_tgt), True)

            loss_generator3 = criterionAdv(Discriminator3(feat_tgt), True)

            feat_embd, label_pred = FeatEmbder(feat_ext)

            ########## cross-domain triplet loss #########
            Loss_triplet = TripletLossCal(args, feat_embd, lab1, lab2, lab3)

            Loss_cls = criterionCls(label_pred.squeeze(), label.float())

            Loss_gen = args.W_genave * (loss_generator1 + loss_generator2 +
                                        loss_generator3)

            Loss_G = args.W_trip * Loss_triplet + args.W_cls * Loss_cls + args.W_gen * Loss_gen

            Loss_G.backward()
            optimizer_DG_conf.step()

            #************************* confusion domain 1 with 2,3 **********************************#

            # Domain-1 reference features are repeated three times,
            # presumably so the batch size matches feat_tgt, which spans all
            # three domains -- confirm.
            feat_src = torch.cat([pre_feat_ext1, pre_feat_ext1, pre_feat_ext1],
                                 0)

            # predict on discriminator
            # feat_tgt is detached so only Discriminator1 receives gradients.
            optimizer_critic1.zero_grad()

            real_loss = criterionAdv(Discriminator1(feat_src), True)
            fake_loss = criterionAdv(Discriminator1(feat_tgt.detach()), False)

            loss_critic1 = 0.5 * (real_loss + fake_loss)

            loss_critic1.backward()
            optimizer_critic1.step()

            #************************* confusion domain 2 with 1,3 **********************************#

            feat_src = torch.cat([pre_feat_ext2, pre_feat_ext2, pre_feat_ext2],
                                 0)

            # predict on discriminator
            optimizer_critic2.zero_grad()

            real_loss = criterionAdv(Discriminator2(feat_src), True)
            fake_loss = criterionAdv(Discriminator2(feat_tgt.detach()), False)

            loss_critic2 = 0.5 * (real_loss + fake_loss)

            loss_critic2.backward()
            optimizer_critic2.step()

            #************************* confusion domain 3 with 1,2 **********************************#

            feat_src = torch.cat([pre_feat_ext3, pre_feat_ext3, pre_feat_ext3],
                                 0)

            # predict on discriminator
            optimizer_critic3.zero_grad()

            real_loss = criterionAdv(Discriminator3(feat_src), True)
            fake_loss = criterionAdv(Discriminator3(feat_tgt.detach()), False)

            loss_critic3 = 0.5 * (real_loss + fake_loss)

            loss_critic3.backward()
            optimizer_critic3.step()

            #============ tensorboard the log info ============#
            info = {
                'Loss_depth': Loss_depth.item(),
                'Loss_triplet': Loss_triplet.item(),
                'Loss_cls': Loss_cls.item(),
                'Loss_G': Loss_G.item(),
                'loss_critic1': loss_critic1.item(),
                'loss_generator1': loss_generator1.item(),
                'loss_critic2': loss_critic2.item(),
                'loss_generator2': loss_generator2.item(),
                'loss_critic3': loss_critic3.item(),
                'loss_generator3': loss_generator3.item(),
            }
            for tag, value in info.items():
                summary_writer.add_scalar(tag, value, global_step)

            if (step + 1) % args.tst_step == 0:
                # Split predictions into real/fake halves using the
                # [real1, fake1, real2, fake2, real3, fake3] layout. This
                # assumes every batch holds exactly args.batchsize samples
                # -- TODO confirm (a short batch would misalign the slices).
                depth_Pre_real = torch.cat([
                    depth_Pre[0:args.batchsize],
                    depth_Pre[2 * args.batchsize:3 * args.batchsize],
                    depth_Pre[4 * args.batchsize:5 * args.batchsize]
                ], 0)
                depth_Pre_fake = torch.cat([
                    depth_Pre[args.batchsize:2 * args.batchsize],
                    depth_Pre[3 * args.batchsize:4 * args.batchsize],
                    depth_Pre[5 * args.batchsize:6 * args.batchsize]
                ], 0)

                depth_Pre_all = vutils.make_grid(depth_Pre,
                                                 normalize=True,
                                                 scale_each=True)
                depth_Pre_real = vutils.make_grid(depth_Pre_real,
                                                  normalize=True,
                                                  scale_each=True)
                depth_Pre_fake = vutils.make_grid(depth_Pre_fake,
                                                  normalize=True,
                                                  scale_each=True)

                summary_writer.add_image('Depth_Image_all', depth_Pre_all,
                                         global_step)
                summary_writer.add_image('Depth_Image_real', depth_Pre_real,
                                         global_step)
                summary_writer.add_image('Depth_Image_fake', depth_Pre_fake,
                                         global_step)

            #============ print the log info ============#
            if (step + 1) % args.log_step == 0:
                errors = OrderedDict([
                    ('Loss_depth', Loss_depth.item()),
                    ('Loss_triplet', Loss_triplet.item()),
                    ('Loss_cls', Loss_cls.item()), ('Loss_G', Loss_G.item()),
                    ('loss_critic1', loss_critic1.item()),
                    ('loss_generator1', loss_generator1.item()),
                    ('loss_critic2', loss_critic2.item()),
                    ('loss_generator2', loss_generator2.item()),
                    ('loss_critic3', loss_critic3.item()),
                    ('loss_generator3', loss_generator3.item())
                ])

                Saver.print_current_errors((epoch + 1), (step + 1), errors)

            # NOTE(review): global_step is not incremented in the lines shown
            # here, so every add_scalar/add_image above logs at step 0 unless
            # it is updated elsewhere -- confirm.
            if (step + 1) % args.tst_step == 0:
                evaluate.evaluate_img(FeatExtor, DepthEsmator,
                                      data_loader_target, (epoch + 1),
                                      (step + 1), Saver)

            global_step += 1

            #############################
            # 2.4 save model parameters #
            #############################
            if ((step + 1) % args.model_save_step == 0):
                model_save_path = os.path.join(args.results_path, 'snapshots',
                                               savefilename)
                mkdir(model_save_path)

                torch.save(
                    FeatExtor.state_dict(),
                    os.path.join(
                        model_save_path,
                        "DGFA-Ext-{}-{}.pt".format(epoch + 1, step + 1)))

                torch.save(
                    FeatEmbder.state_dict(),
                    os.path.join(
                        model_save_path,
                        "DGFA-Embd-{}-{}.pt".format(epoch + 1, step + 1)))

                torch.save(
                    DepthEsmator.state_dict(),
                    os.path.join(
                        model_save_path,
                        "DGFA-Depth-{}-{}.pt".format(epoch + 1, step + 1)))

                torch.save(
                    Discriminator1.state_dict(),
                    os.path.join(
                        model_save_path,
                        "DGFA-D1-{}-{}.pt".format(epoch + 1, step + 1)))

                torch.save(
                    Discriminator2.state_dict(),
                    os.path.join(
                        model_save_path,
                        "DGFA-D2-{}-{}.pt".format(epoch + 1, step + 1)))

                torch.save(
                    Discriminator3.state_dict(),
                    os.path.join(
                        model_save_path,
                        "DGFA-D3-{}-{}.pt".format(epoch + 1, step + 1)))

        if ((epoch + 1) % args.model_save_epoch == 0):
            model_save_path = os.path.join(args.results_path, 'snapshots',
                                           savefilename)
            mkdir(model_save_path)

            torch.save(
                FeatExtor.state_dict(),
                os.path.join(model_save_path,
                             "DGFA-Ext-{}.pt".format(epoch + 1)))

            torch.save(
                FeatEmbder.state_dict(),
                os.path.join(model_save_path,
                             "DGFA-Embd-{}.pt".format(epoch + 1)))

            torch.save(
                DepthEsmator.state_dict(),
                os.path.join(model_save_path,
                             "DGFA-Depth-{}.pt".format(epoch + 1)))

            torch.save(
                Discriminator1.state_dict(),
                os.path.join(model_save_path,
                             "DGFA-D1-{}.pt".format(epoch + 1)))

            torch.save(
                Discriminator2.state_dict(),
                os.path.join(model_save_path,
                             "DGFA-D2-{}.pt".format(epoch + 1)))

            torch.save(
                Discriminator3.state_dict(),
                os.path.join(model_save_path,
                             "DGFA-D3-{}.pt".format(epoch + 1)))

    torch.save(FeatExtor.state_dict(),
               os.path.join(model_save_path, "DGFA-Ext-final.pt"))

    torch.save(FeatEmbder.state_dict(),
               os.path.join(model_save_path, "DGFA-Embd-final.pt"))

    torch.save(DepthEsmator.state_dict(),
               os.path.join(model_save_path, "DGFA-Depth-final.pt"))

    torch.save(Discriminator1.state_dict(),
               os.path.join(model_save_path, "DGFA-D1-final.pt"))

    torch.save(Discriminator2.state_dict(),
               os.path.join(model_save_path, "DGFA-D2-final.pt"))

    torch.save(Discriminator3.state_dict(),
               os.path.join(model_save_path, "DGFA-D3-final.pt"))