Beispiel #1
0
        10, size=(int(args.wm_num[1]/args.wm_batchsize), args.wm_batchsize))
wm_labels = torch.from_numpy(np_labels).cuda()

#wm_labels = SpecifiedLabel()
best_real_acc, best_wm_acc, best_wm_input_acc = 0, 0, 0
start_epoch = 0  # start from epoch 0 or last checkpoint epoch
# Each history is a pair of lists; presumably [real-data, watermark] series
# (matching best_real_acc / best_wm_acc) — TODO confirm against the train loop.
train_loss, test_loss = [[], []], [[], []]
train_acc, test_acc = [[], []], [[], []]

# Model
print('==> Building model..')
# Hiding network (Hidnet) and discriminator (Disnet) are dataset specific.
if args.dataset == 'mnist':
    Hidnet = UnetGenerator_mnist()
    Disnet = DiscriminatorNet_mnist()
elif args.dataset == 'cifar10':
    Hidnet = UnetGenerator()
    Disnet = DiscriminatorNet()
else:
    # Fail fast instead of leaving Hidnet/Disnet unbound (later NameError).
    raise ValueError(
        "unsupported dataset %r; expected 'mnist' or 'cifar10'" % (args.dataset,))

# Host classifier. Other architectures tried with this script include
# LeNet5(), VGG('VGG19'), gcv.models.resnet50(pretrained=False),
# PreActResNet18(), GoogLeNet(), DenseNet121(), ResNeXt29_2x64d(),
# MobileNet(), MobileNetV2(), DPN26() and ShuffleNetG2().
Dnnet = ResNet34()
Beispiel #2
0
def main():
    """Parse options, build the hiding (Hnet) / reveal (Rnet) networks, then
    train them or analyse a trained checkpoint.

    Training runs when ``opt.test`` is empty. Otherwise the checkpoint
    ``./training/<opt.test>/checkPoints/best_checkpoint.pth.tar`` is loaded
    (optionally with Hnet taken from ``opt.test_diff``) and analysed.

    Side effects: assigns the module-level globals declared below, creates the
    experiment directories (unless ``opt.debug``), writes logs, TensorBoard
    summaries and checkpoints.

    Raises:
        ValueError: if ``opt.norm`` or ``opt.loss`` is not a supported choice.
    """
    ############### define global parameters ###############
    global opt, optimizer, optimizerR, writer, logPath, scheduler, schedulerR, val_loader, smallestLoss, DATA_DIR, noiser_dropout, noiser_gaussian, noiser_identity

    opt = parser.parse_args()
    opt.ngpu = torch.cuda.device_count()
    if torch.cuda.is_available() and not opt.cuda:
        print("WARNING: You have a CUDA device, "
              "so you should probably run with --cuda")

    cudnn.benchmark = True

    # Machine-specific dataset roots. For any other hostname DATA_DIR must
    # already hold a value at module level, otherwise the check below fails.
    if opt.hostname == 'DL178':
        DATA_DIR = '/media/user/SSD1TB-2/ImageNet'
    elif opt.hostname == 'amax':
        # (DCMMC) server 199
        DATA_DIR = '/data/xwt/Universal-Deep-Hiding/ImageNet'
    assert DATA_DIR

    ############  create the dirs to save the result #############
    if not opt.debug:
        try:
            cur_time = time.strftime('%Y-%m-%d_H%H-%M-%S', time.localtime())
            if opt.test == '':
                secret_comment = 'color' if opt.channel_secret == 3 else 'gray'
                cover_comment = 'color' if opt.channel_cover == 3 else 'gray'
                comment = str(opt.num_secret) + secret_comment + 'In' + str(opt.num_cover) + cover_comment
                experiment_dir = "_".join([
                    opt.hostname, cur_time, str(opt.imageSize), str(opt.num_secret),
                    str(opt.num_training), str(opt.bs_secret), str(opt.ngpu),
                    opt.norm, opt.loss, str(opt.beta), comment, opt.remark,
                ])
                opt.outckpts += experiment_dir + "/checkPoints"
                opt.trainpics += experiment_dir + "/trainPics"
                opt.validationpics += experiment_dir + "/validationPics"
                opt.outlogs += experiment_dir + "/trainingLogs"
                opt.outcodes += experiment_dir + "/codes"
                for out_dir in (opt.outckpts, opt.trainpics, opt.validationpics,
                                opt.outlogs, opt.outcodes):
                    os.makedirs(out_dir, exist_ok=True)
                save_current_codes(opt.outcodes)
            else:
                experiment_dir = opt.test
                opt.testPics += experiment_dir + "/testPics"
                opt.validationpics = opt.testPics
                opt.outlogs += experiment_dir + "/testLogs"
                os.makedirs(opt.testPics, exist_ok=True)
                os.makedirs(opt.outlogs, exist_ok=True)
        except OSError:
            print("mkdir failed   XXXXXXXXXXXXXXXXXXXXX") # ignore

    logPath = opt.outlogs + '/%s_%d_log.txt' % (opt.dataset, opt.bs_secret)
    if opt.debug:
        logPath = './debug/debug_logs/debug.txt'
    print_log(str(opt), logPath)

    ##################  Datasets  #################
    traindir = os.path.join(DATA_DIR, 'train')
    valdir = os.path.join(DATA_DIR, 'val')

    transforms_color = transforms.Compose([
        transforms.Resize([opt.imageSize, opt.imageSize]),
        transforms.ToTensor(),
    ])
    transforms_gray = transforms.Compose([
        transforms.Grayscale(num_output_channels=1),
        transforms.Resize([opt.imageSize, opt.imageSize]),
        transforms.ToTensor(),
    ])
    # A channel count of 1 selects the grayscale pipeline; anything else uses colour.
    transforms_cover = transforms_gray if opt.channel_cover == 1 else transforms_color
    transforms_secret = transforms_gray if opt.channel_secret == 1 else transforms_color

    if opt.test == '':
        train_dataset_cover = ImageFolder(traindir, transforms_cover)
        train_dataset_secret = ImageFolder(traindir, transforms_secret)
        val_dataset_cover = ImageFolder(valdir, transforms_cover)
        val_dataset_secret = ImageFolder(valdir, transforms_secret)

        assert train_dataset_cover; assert train_dataset_secret
        assert val_dataset_cover; assert val_dataset_secret
    else:
        opt.checkpoint = "./training/" + opt.test + "/checkPoints/" + "best_checkpoint.pth.tar"
        if opt.test_diff != '':
            opt.checkpoint_diff = "./training/" + opt.test_diff + "/checkPoints/" + "best_checkpoint.pth.tar"
        # Testing reuses the validation split.
        testdir = valdir
        test_dataset_cover = ImageFolder(testdir, transforms_cover)
        test_dataset_secret = ImageFolder(testdir, transforms_secret)
        assert test_dataset_cover; assert test_dataset_secret

    ##################  Hiding and Reveal  #################
    # Presumably each of the num_downs U-Net stages halves the resolution,
    # hence the multiple-of-32 (2**5) requirement — TODO confirm in UnetGenerator.
    assert opt.imageSize % 32 == 0
    num_downs = 5
    norm_layers = {'instance': nn.InstanceNorm2d, 'batch': nn.BatchNorm2d, 'none': None}
    if opt.norm not in norm_layers:
        raise ValueError("unsupported norm %r; expected one of %s"
                         % (opt.norm, sorted(norm_layers)))
    norm_layer = norm_layers[opt.norm]

    if opt.cover_dependent:
        # Cover-dependent hiding: Hnet sees secret and cover stacked on channels.
        Hnet = UnetGenerator(input_nc=opt.channel_secret*opt.num_secret+opt.channel_cover*opt.num_cover, output_nc=opt.channel_cover*opt.num_cover, num_downs=num_downs, norm_layer=norm_layer, output_function=nn.Sigmoid)
    else:
        # Cover-agnostic hiding: Hnet sees only the secret(s).
        Hnet = UnetGenerator(input_nc=opt.channel_secret*opt.num_secret, output_nc=opt.channel_cover*opt.num_cover, num_downs=num_downs, norm_layer=norm_layer, output_function=nn.Tanh)
    Rnet = RevealNet(input_nc=opt.channel_cover*opt.num_cover, output_nc=opt.channel_secret*opt.num_secret, nhf=64, norm_layer=norm_layer, output_function=nn.Sigmoid)

    p = 0.3
    noiser_dropout = Dropout([p, p])
    noiser_gaussian = gaussian_kernel()
    noiser_identity = Identity()

    if opt.cover_dependent:
        assert opt.num_training == 1
        assert opt.no_cover == False

    ##### We used kaiming normalization #####
    Hnet.apply(weights_init)
    Rnet.apply(weights_init)

    ##### Always set to multiple GPU mode  #####
    Hnet = torch.nn.DataParallel(Hnet).cuda()
    Rnet = torch.nn.DataParallel(Rnet).cuda()

    noiser_dropout = torch.nn.DataParallel(noiser_dropout).cuda()
    noiser_gaussian = torch.nn.DataParallel(noiser_gaussian).cuda()
    noiser_identity = torch.nn.DataParallel(noiser_identity).cuda()

    if opt.checkpoint != "":
        checkpoint = torch.load(opt.checkpoint)
        # With a second checkpoint, Hnet comes from it while Rnet stays from the first.
        h_source = checkpoint if opt.checkpoint_diff == "" else torch.load(opt.checkpoint_diff)
        Hnet.load_state_dict(h_source['H_state_dict'])
        Rnet.load_state_dict(checkpoint['R_state_dict'])

    print_network(Hnet)
    print_network(Rnet)

    # Loss and Metric
    criteria = {'l1': nn.L1Loss, 'l2': nn.MSELoss}
    if opt.loss not in criteria:
        raise ValueError("unsupported loss %r; expected one of %s"
                         % (opt.loss, sorted(criteria)))
    criterion = criteria[opt.loss]().cuda()

    # Train the networks when opt.test is empty
    if opt.test == '':
        # tensorboardX writer
        if not opt.debug:
            writer = SummaryWriter(log_dir='runs/' + experiment_dir)

        params = list(Hnet.parameters()) + list(Rnet.parameters())
        optimizer = optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999))
        scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.2, patience=8, verbose=True)

        train_loader_secret = DataLoader(train_dataset_secret, batch_size=opt.bs_secret*opt.num_secret,
                                         shuffle=True, num_workers=int(opt.workers))
        train_loader_cover = DataLoader(train_dataset_cover, batch_size=opt.bs_secret*opt.num_cover*opt.num_training,
                                        shuffle=True, num_workers=int(opt.workers))
        val_loader_secret = DataLoader(val_dataset_secret, batch_size=opt.bs_secret*opt.num_secret,
                                       shuffle=False, num_workers=int(opt.workers))
        val_loader_cover = DataLoader(val_dataset_cover, batch_size=opt.bs_secret*opt.num_cover*opt.num_training,
                                      shuffle=True, num_workers=int(opt.workers))

        smallestLoss = 10000
        print_log("training is beginning .......................................................", logPath)
        for epoch in range(opt.epochs):
            adjust_learning_rate(optimizer, epoch)
            # zip() returns a one-shot iterator, so build fresh ones every epoch.
            train_loader = zip(train_loader_secret, train_loader_cover)
            val_loader = zip(val_loader_secret, val_loader_cover)
            ######################## train ##########################################
            train(train_loader, epoch, Hnet=Hnet, Rnet=Rnet, criterion=criterion)

            ####################### validation  #####################################
            val_hloss, val_rloss, val_hdiff, val_rdiff = validation(val_loader, epoch, Hnet=Hnet, Rnet=Rnet, criterion=criterion)

            ####################### adjust learning rate ############################
            scheduler.step(val_rloss)

            # Save the parameters; flag the checkpoint as best when it beats the
            # best value seen so far. The running best is only replaced on
            # improvement, otherwise is_best would just mean "better than the
            # previous epoch".
            sum_diff = val_hdiff + val_rdiff
            is_best = sum_diff < smallestLoss
            if is_best:
                smallestLoss = sum_diff

            save_checkpoint({
                'epoch': epoch + 1,
                'H_state_dict': Hnet.state_dict(),
                'R_state_dict': Rnet.state_dict(),
                'optimizer' : optimizer.state_dict(),
            }, is_best, epoch, '%s/epoch_%d_Hloss_%.4f_Rloss=%.4f_Hdiff_Hdiff%.4f_Rdiff%.4f' % (opt.outckpts, epoch, val_hloss, val_rloss, val_hdiff, val_rdiff) )

        if not opt.debug:
            writer.close()

    # For testing the trained network
    else:
        test_loader_secret = DataLoader(test_dataset_secret, batch_size=opt.bs_secret*opt.num_secret,
                                        shuffle=False, num_workers=int(opt.workers))
        test_loader_cover = DataLoader(test_dataset_cover, batch_size=opt.bs_secret*opt.num_cover*opt.num_training,
                                       shuffle=True, num_workers=int(opt.workers))
        # zip() is a one-shot iterator: whatever analysis() consumes would be
        # missing for analysis_img_save(), so each consumer gets its own.
        analysis(zip(test_loader_secret, test_loader_cover), 0, Hnet=Hnet, Rnet=Rnet, criterion=criterion)
        analysis_img_save(zip(test_loader_secret, test_loader_cover), 0, Hnet=Hnet, Rnet=Rnet, criterion=criterion)
def main():
    """Parse options, build the hiding (Hnet) and reveal (Rnet) networks, and
    either train them or run the test pass.

    Training runs when ``opt.test`` is empty; otherwise ``opt.test`` is the
    directory of test images and the networks are loaded from ``opt.Hnet`` /
    ``opt.Rnet`` (falling back to the epoch-73 checkpoints under
    ``./checkPoint/`` when those are empty).

    Side effects: assigns the globals declared below, creates the experiment
    directories (unless ``opt.debug``), writes logs, a copy of the code,
    TensorBoard summaries and a checkpoint whenever the validation sum-loss
    improves.
    """
    ############### define global parameters ###############
    global opt, optimizerH, optimizerR, writer, logPath, schedulerH, schedulerR, val_loader, smallestLoss

    #################  output configuration   ###############
    opt = parser.parse_args()

    if torch.cuda.is_available() and not opt.cuda:
        print("WARNING: You have a CUDA device, "
              "so you should probably run with --cuda")

    cudnn.benchmark = True

    ############  create dirs to save the result #############
    if not opt.debug:
        try:
            cur_time = time.strftime('%Y-%m-%d-%H_%M_%S', time.localtime())
            experiment_dir = opt.hostname + "_" + cur_time + opt.remark
            opt.outckpts += experiment_dir + "/checkPoints"
            opt.trainpics += experiment_dir + "/trainPics"
            opt.validationpics += experiment_dir + "/validationPics"
            opt.outlogs += experiment_dir + "/trainingLogs"
            opt.outcodes += experiment_dir + "/codes"
            opt.testPics += experiment_dir + "/testPics"
            for out_dir in (opt.outckpts, opt.trainpics, opt.validationpics,
                            opt.outlogs, opt.outcodes):
                os.makedirs(out_dir, exist_ok=True)
            # The test-picture directory is only needed in test mode.
            if opt.test != '':
                os.makedirs(opt.testPics, exist_ok=True)
        except OSError:
            print("mkdir failed   XXXXXXXXXXXXXXXXXXXXX")

    logPath = opt.outlogs + '/%s_%d_log.txt' % (opt.dataset, opt.batchSize)

    print_log(str(opt), logPath)
    save_current_codes(opt.outcodes)

    # Every split uses the same pipeline: resize to imageSize x imageSize, then to tensor.
    image_transform = transforms.Compose([
        transforms.Resize([opt.imageSize, opt.imageSize]),
        transforms.ToTensor(),
    ])

    if opt.test == '':
        # tensorboardX writer
        writer = SummaryWriter(comment='**' + opt.remark)
        ##############   get dataset   ############################
        traindir = os.path.join(DATA_DIR, 'train')
        valdir = os.path.join(DATA_DIR, 'val')
        train_dataset = MyImageFolder(traindir, image_transform)
        val_dataset = MyImageFolder(valdir, image_transform)
        assert train_dataset
        assert val_dataset
    else:
        # Use the epoch-73 checkpoints only when no network path was supplied,
        # so an explicit --Hnet / --Rnet is honoured in test mode.
        if opt.Hnet == "":
            opt.Hnet = "./checkPoint/netH_epoch_73,sumloss=0.000447,Hloss=0.000258.pth"
        if opt.Rnet == "":
            opt.Rnet = "./checkPoint/netR_epoch_73,sumloss=0.000447,Rloss=0.000252.pth"
        testdir = opt.test
        test_dataset = MyImageFolder(testdir, image_transform)
        assert test_dataset

    # Weights are loaded into the bare networks, before any DataParallel wrap.
    Hnet = UnetGenerator(input_nc=6,
                         output_nc=3,
                         num_downs=7,
                         output_function=nn.Sigmoid)
    Hnet.cuda()
    Hnet.apply(weights_init)
    # whether to load pre-trained model
    if opt.Hnet != "":
        Hnet.load_state_dict(torch.load(opt.Hnet))
    if opt.ngpu > 1:
        Hnet = torch.nn.DataParallel(Hnet).cuda()
    print_network(Hnet)

    Rnet = RevealNet(output_function=nn.Sigmoid)
    Rnet.cuda()
    Rnet.apply(weights_init)
    if opt.Rnet != '':
        Rnet.load_state_dict(torch.load(opt.Rnet))
    if opt.ngpu > 1:
        Rnet = torch.nn.DataParallel(Rnet).cuda()
    print_network(Rnet)

    # MSE loss
    criterion = nn.MSELoss().cuda()
    # training mode
    if opt.test == '':
        # setup optimizer
        optimizerH = optim.Adam(Hnet.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
        schedulerH = ReduceLROnPlateau(optimizerH, mode='min', factor=0.2, patience=5, verbose=True)

        optimizerR = optim.Adam(Rnet.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
        schedulerR = ReduceLROnPlateau(optimizerR, mode='min', factor=0.2, patience=8, verbose=True)

        train_loader = DataLoader(train_dataset,
                                  batch_size=opt.batchSize,
                                  shuffle=True,
                                  num_workers=int(opt.workers))
        val_loader = DataLoader(val_dataset,
                                batch_size=opt.batchSize,
                                shuffle=False,
                                num_workers=int(opt.workers))

        # Checkpoints are later loaded into the *bare* networks (see above), so
        # save the bare networks too: a DataParallel state_dict prefixes every
        # key with "module." and would be rejected by load_state_dict.
        hnet_to_save = Hnet.module if isinstance(Hnet, torch.nn.DataParallel) else Hnet
        rnet_to_save = Rnet.module if isinstance(Rnet, torch.nn.DataParallel) else Rnet

        smallestLoss = 10000
        print_log(
            "training is beginning .......................................................",
            logPath)
        for epoch in range(opt.niter):
            ######################## train ##########################################
            train(train_loader, epoch, Hnet=Hnet, Rnet=Rnet, criterion=criterion)

            ####################### validation  #####################################
            val_hloss, val_rloss, val_sumloss = validation(val_loader,
                                                           epoch,
                                                           Hnet=Hnet,
                                                           Rnet=Rnet,
                                                           criterion=criterion)

            ####################### adjust learning rate ############################
            schedulerH.step(val_sumloss)
            schedulerR.step(val_rloss)

            # save the best model parameters
            if val_sumloss < smallestLoss:
                smallestLoss = val_sumloss
                # do checkPointing
                torch.save(
                    hnet_to_save.state_dict(),
                    '%s/netH_epoch_%d,sumloss=%.6f,Hloss=%.6f.pth' %
                    (opt.outckpts, epoch, val_sumloss, val_hloss))
                torch.save(
                    rnet_to_save.state_dict(),
                    '%s/netR_epoch_%d,sumloss=%.6f,Rloss=%.6f.pth' %
                    (opt.outckpts, epoch, val_sumloss, val_rloss))

        writer.close()

    # test mode
    else:
        test_loader = DataLoader(test_dataset,
                                 batch_size=opt.batchSize,
                                 shuffle=False,
                                 num_workers=int(opt.workers))
        test(test_loader, 0, Hnet=Hnet, Rnet=Rnet, criterion=criterion)
        print(
            "##################   test is completed, the result pic is saved in the ./training/yourcompuer+time/testPics/   ######################"
        )
def main():
    """Train the hiding (Hnet), reveal (Rnet) and discriminator (Dnet) networks.

    Covers come from ``./datasets/<opt.datasets>/{train,val}`` and secrets from
    ``./datasets/<opt.secret>``. Every epoch the three networks are validated,
    their plateau schedulers are stepped, and all three are checkpointed.

    Side effects: assigns the globals declared below, creates the experiment
    directories, writes logs, a copy of the code, TensorBoard summaries and
    per-epoch checkpoints.
    """
    ############### define global parameters ###############
    global opt, optimizerH, optimizerR, optimizerD, writer, logPath, schedulerH, schedulerR
    global val_loader, smallestLoss,  mse_loss, gan_loss, pixel_loss, patch, criterion_GAN, criterion_pixelwise

    #################  parse the configuration   ###############
    opt = parser.parse_args()

    if torch.cuda.is_available() and not opt.cuda:
        print("WARNING: You have a CUDA device, "
              "so you should probably run with --cuda")

    cudnn.benchmark = True

    ############  create the dirs to save the result #############
    cur_time = time.strftime('%Y-%m-%d-%H_%M_%S', time.localtime())
    experiment_dir = opt.hostname  + "_" + cur_time + opt.remark
    opt.outckpts += experiment_dir + "/checkPoints"
    opt.trainpics += experiment_dir + "/trainPics"
    opt.validationpics += experiment_dir + "/validationPics"
    opt.outlogs += experiment_dir + "/trainingLogs"
    opt.outcodes += experiment_dir + "/codes"
    opt.testPics += experiment_dir + "/testPics"
    for out_dir in (opt.outckpts, opt.trainpics, opt.validationpics,
                    opt.outlogs, opt.outcodes):
        os.makedirs(out_dir, exist_ok=True)
    # The test-picture directory is only needed in test mode.
    if opt.test != '':
        os.makedirs(opt.testPics, exist_ok=True)

    logPath = opt.outlogs + '/%s_%d_log.txt' % (opt.dataset, opt.batchSize)

    # Log this run's parameters.
    print_log(str(opt), logPath)
    # Keep a copy of this experiment's code.
    save_current_codes(opt.outcodes)
    # tensorboardX writer
    writer = SummaryWriter(comment='**' + opt.hostname + "_" + opt.remark)

    ##############   load the datasets   ############################
    DATA_DIR_root = './datasets/'
    DATA_DIR = os.path.join(DATA_DIR_root, opt.datasets)

    traindir = os.path.join(DATA_DIR, 'train')
    valdir = os.path.join(DATA_DIR, 'val')
    secretdir = os.path.join(DATA_DIR_root, opt.secret)

    # Covers are resized to imageSize x 512 (Resize takes [height, width]);
    # secrets are square imageSize x imageSize.
    cover_transform = transforms.Compose([
        transforms.Resize([opt.imageSize, 512]),
        transforms.ToTensor(),
    ])
    secret_transform = transforms.Compose([
        transforms.Resize([opt.imageSize, opt.imageSize]),
        transforms.ToTensor(),
    ])

    train_dataset = MyImageFolder(traindir, cover_transform)
    val_dataset = MyImageFolder(valdir, cover_transform)
    secret_dataset = MyImageFolder(secretdir, secret_transform)

    assert train_dataset
    assert val_dataset
    assert secret_dataset

    # The same (unshuffled) secret loader feeds both training and validation.
    train_loader = DataLoader(train_dataset, batch_size=opt.batchSize,
                              shuffle=True, num_workers=int(opt.workers))
    secret_loader = DataLoader(secret_dataset, batch_size=opt.batchSize,
                               shuffle=False, num_workers=int(opt.workers))
    val_loader = DataLoader(val_dataset, batch_size=opt.batchSize,
                            shuffle=True, num_workers=int(opt.workers))

    ##############   network architecture   ############################
    Hnet = UnetGenerator(input_nc=6, output_nc=3, num_downs= opt.num_downs, output_function=nn.Sigmoid)
    Hnet.cuda()
    Hnet.apply(weights_init)

    Rnet = RevealNet(output_function=nn.Sigmoid)
    Rnet.cuda()
    Rnet.apply(weights_init)

    # Discriminator variant is selected by opt.Dnorm.
    if opt.Dnorm == "spectral":
        Dnet = Discriminator_SN(in_channels=3)
    elif opt.Dnorm == "switch":
        Dnet = Discriminator_Switch(in_channels=3)
    else:
        Dnet = Discriminator(in_channels=3)
    Dnet.cuda()
    # Dnet keeps its default initialisation; weights_init is not applied to it.

    # Output shape of the image discriminator (PatchGAN): 1 x imageSize/16 x imageSize/16.
    # NOTE(review): covers are resized to imageSize x 512, so the width term
    # only matches when imageSize == 512 — confirm against Discriminator.
    patch = (1, opt.imageSize // 2 ** 4, opt.imageSize // 2 ** 4)

    # setup optimizer
    optimizerH = optim.Adam(Hnet.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
    schedulerH = ReduceLROnPlateau(optimizerH, mode='min', factor=0.2, patience=5, verbose=True)

    optimizerR = optim.Adam(Rnet.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
    schedulerR = ReduceLROnPlateau(optimizerR, mode='min', factor=0.2, patience=8, verbose=True)

    optimizerD = optim.Adam(Dnet.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
    schedulerD = ReduceLROnPlateau(optimizerD, mode='min', factor=0.2, patience=5, verbose=True)

    # Resume from earlier training if checkpoint paths were given. Weights are
    # loaded into the bare networks, then wrapped for multi-GPU use.
    if opt.Hnet != "":
        Hnet.load_state_dict(torch.load(opt.Hnet))
    if opt.ngpu > 1:
        Hnet = torch.nn.DataParallel(Hnet).cuda()
    print_network(Hnet)

    if opt.Rnet != '':
        Rnet.load_state_dict(torch.load(opt.Rnet))
    if opt.ngpu > 1:
        Rnet = torch.nn.DataParallel(Rnet).cuda()
    print_network(Rnet)

    if opt.Dnet != '':
        Dnet.load_state_dict(torch.load(opt.Dnet))
    if opt.ngpu > 1:
        Dnet = torch.nn.DataParallel(Dnet).cuda()
    print_network(Dnet)

    # Checkpoints are loaded into the *bare* networks above, so save the bare
    # networks as well: a DataParallel state_dict prefixes every key with
    # "module." and would be rejected by load_state_dict when resuming.
    hnet_to_save = Hnet.module if isinstance(Hnet, torch.nn.DataParallel) else Hnet
    rnet_to_save = Rnet.module if isinstance(Rnet, torch.nn.DataParallel) else Rnet
    dnet_to_save = Dnet.module if isinstance(Dnet, torch.nn.DataParallel) else Dnet

    # define loss
    mse_loss = nn.MSELoss().cuda()
    criterion_GAN = nn.MSELoss().cuda()
    criterion_pixelwise = nn.L1Loss().cuda()

    # Every epoch is checkpointed below; smallestLoss is initialised but not
    # updated by this loop.
    smallestLoss = 10000
    print_log("training is beginning .......................................................", logPath)
    for epoch in range(opt.niter):
        ######################## train ##########################################
        train(train_loader, secret_loader, epoch, Hnet=Hnet, Rnet=Rnet, Dnet=Dnet)

        ####################### validation  #####################################
        val_hloss, val_rloss, val_r_mseloss, val_r_consistloss, val_dloss, val_fakedloss, val_realdloss, val_Ganlosses, val_Pixellosses, val_sumloss = validation(val_loader, secret_loader, epoch, Hnet=Hnet, Rnet=Rnet, Dnet=Dnet)

        ####################### adjust learning rate ############################
        schedulerH.step(val_sumloss)
        schedulerR.step(val_rloss)
        schedulerD.step(val_dloss)

        # save the epoch model parameters
        torch.save(hnet_to_save.state_dict(),
                   '%s/netH_epoch_%d,sumloss=%.6f,Hloss=%.6f.pth' % (
                       opt.outckpts, epoch, val_sumloss, val_hloss))
        torch.save(rnet_to_save.state_dict(),
                   '%s/netR_epoch_%d,sumloss=%.6f,Rloss=%.6f.pth' % (
                       opt.outckpts, epoch, val_sumloss, val_rloss))
        torch.save(dnet_to_save.state_dict(),
                   '%s/netD_epoch_%d,sumloss=%.6f,Dloss=%.6f.pth' % (
                       opt.outckpts, epoch, val_sumloss, val_dloss))

    writer.close()