Example #1
def init_criterion(opt):
    criterions = {}

    # Build each criterion once, then move it to the GPU when one is available
    # (.to('cpu') is a no-op, so this matches the original ternaries).
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    criterions['GANLoss'] = GANLoss(opt.gan_mode).to(device)
    criterions['CELoss'] = nn.CrossEntropyLoss().to(device)

    return criterions
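
A minimal usage sketch for the factory above (assuming `GANLoss` from a pix2pix-style codebase and an `opt` object carrying `gan_mode`; the tensor shape is illustrative):

import torch
from types import SimpleNamespace

opt = SimpleNamespace(gan_mode='lsgan')          # hypothetical option object
criterions = init_criterion(opt)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
pred = torch.randn(4, 1, 30, 30, device=device)  # stand-in discriminator output
loss_real = criterions['GANLoss'](pred, True)    # adversarial loss, target "real"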
Example #2
    def __init__(self, hyperparameters):
        super(MUNIT_Trainer, self).__init__()
        lr = hyperparameters['lr']
        # Initiate the networks
        self.gen_a = AdaINGen(hyperparameters['input_dim_a'], hyperparameters['gen'])  # auto-encoder for domain a
        self.gen_b = AdaINGen(hyperparameters['input_dim_b'], hyperparameters['gen'])  # auto-encoder for domain b
        self.dis_a = MsImageDis(hyperparameters['input_dim_a'], hyperparameters['dis'])  # discriminator for domain a
        self.dis_b = MsImageDis(hyperparameters['input_dim_b'], hyperparameters['dis'])  # discriminator for domain b
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        self.style_dim = hyperparameters['gen']['style_dim']
        self.criterionGAN = GANLoss(hyperparameters['dis']['gan_type']).cuda()
        self.featureLoss = nn.MSELoss(reduction='mean')
        # fix the noise used in sampling
        display_size = int(hyperparameters['display_size'])
        self.s_a = torch.randn(display_size, self.style_dim, 1, 1).cuda()
        self.s_b = torch.randn(display_size, self.style_dim, 1, 1).cuda()

        # Setup the optimizers
        beta1 = hyperparameters['beta1']
        beta2 = hyperparameters['beta2']
        dis_params = list(self.dis_a.parameters()) + list(self.dis_b.parameters())
        gen_params = list(self.gen_a.parameters()) + list(self.gen_b.parameters())
        self.dis_opt = torch.optim.Adam([p for p in dis_params if p.requires_grad],
                                        lr=lr, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay'])
        self.gen_opt = torch.optim.Adam([p for p in gen_params if p.requires_grad],
                                        lr=lr, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay'])
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)

        # Network weight initialization
        self.apply(weights_init(hyperparameters['init']))
        self.dis_a.apply(weights_init('gaussian'))
        self.dis_b.apply(weights_init('gaussian'))
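
An illustrative config for this constructor; only the keys read directly in __init__ are shown with placeholder values, and the 'gen'/'dis' sub-dicts would carry the further options AdaINGen and MsImageDis expect:

hyperparameters = {
    'lr': 1e-4, 'beta1': 0.5, 'beta2': 0.999, 'weight_decay': 1e-4,
    'input_dim_a': 3, 'input_dim_b': 3,
    'display_size': 16,
    'init': 'kaiming',
    'gen': {'style_dim': 8},        # plus the usual AdaINGen settings
    'dis': {'gan_type': 'lsgan'},   # plus the usual MsImageDis settings
}
trainer = MUNIT_Trainer(hyperparameters)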
Example #3
                                  batch_size=opt.batchSize,
                                  shuffle=True)
testing_data_loader = DataLoader(dataset=test_set,
                                 num_workers=opt.threads,
                                 batch_size=opt.testBatchSize,
                                 shuffle=False)

print('===> Building model')

netG = define_G(opt.input_nc, opt.output_nc, opt.ngf, 'batch', False, [0])

netD = define_D(opt.input_nc + opt.output_nc, opt.ndf, 'batch', False, [0])

print('loading done')

criterionGAN = GANLoss()
criterionL1 = nn.L1Loss()
criterionMSE = nn.MSELoss()

# setup optimizer
optimizerG = optim.Adam(netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
optimizerD = optim.Adam(netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))

print('---------- Networks initialized -------------')
print_network(netG)
print_network(netD)
print('-----------------------------------------------')

real_a = torch.FloatTensor(opt.batchSize, opt.input_nc, 256, 256)
real_b = torch.FloatTensor(opt.batchSize, opt.output_nc, 256, 256)
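
The snippet stops before the training loop; one pix2pix-style iteration with these objects might look like the sketch below (the same pattern appears in full in Example #6; an `opt.lamb` weight on the L1 term is assumed):

fake_b = netG(real_a)

# (1) D step: real pairs labeled True, fake pairs (detached) labeled False
optimizerD.zero_grad()
loss_d_fake = criterionGAN(netD(torch.cat((real_a, fake_b.detach()), 1)), False)
loss_d_real = criterionGAN(netD(torch.cat((real_a, real_b), 1)), True)
loss_d = (loss_d_fake + loss_d_real) * 0.5
loss_d.backward()
optimizerD.step()

# (2) G step: fool the discriminator and stay close to the target in L1
optimizerG.zero_grad()
loss_g = criterionGAN(netD(torch.cat((real_a, fake_b), 1)), True) \
         + criterionL1(fake_b, real_b) * opt.lamb
loss_g.backward()
optimizerG.step()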
Example #4
                                               shuffle=False)
test_input, test_target = next(iter(test_data_loader))

# Models
G = ResnetGenerator(input_nc=3, output_nc=3, ngf=64, n_blocks=6)
D = Discriminator(input_nc=6, ndf=64)
#G.cuda()
#D.cuda()
init_weights(G)
init_weights(D)
#G.init_weights(mean=0.0, std=0.02)
#D.init_weights(mean=0.0, std=0.02)

# Loss function
#BCE_loss = torch.nn.BCELoss()#.cuda()
BCE_loss = GANLoss()
L1_loss = torch.nn.L1Loss()  #.cuda()

# Optimizers
G_optimizer = torch.optim.Adam(G.parameters(),
                               lr=params.lrG,
                               betas=(params.beta1, params.beta2))
D_optimizer = torch.optim.Adam(D.parameters(),
                               lr=params.lrD,
                               betas=(params.beta1, params.beta2))

# Training GAN
D_avg_losses = []
G_avg_losses = []

D_losses = []
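
The snippet is cut off inside the bookkeeping setup; a small sketch of how per-iteration losses are typically folded into the epoch averages (the helper name is ours):

def epoch_average(losses):
    # Mean of the per-iteration losses; guards against an empty epoch.
    return sum(losses) / len(losses) if losses else 0.0

# at the end of each epoch:
D_avg_losses.append(epoch_average(D_losses))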
Example #5
                                  num_workers=opt.threads,
                                  batch_size=opt.batch_size,
                                  shuffle=True)
device = torch.device("cuda:0" if opt.cuda else "cpu")
print('===> Building models')
net_g = define_G(opt.input_nc,
                 opt.output_nc,
                 opt.ngf,
                 'batch',
                 False,
                 'normal',
                 0.02,
                 gpu_id=device)
net_d = define_D(opt.input_nc + opt.output_nc, opt.ndf, 'basic', gpu_id=device)

criterionGAN = GANLoss().to(device)
criterionL1 = nn.L1Loss().to(device)
criterionMSE = nn.MSELoss().to(device)

# setup optimizer
optimizer_g = optim.Adam(net_g.parameters(),
                         lr=opt.lr,
                         betas=(opt.beta1, 0.999))
optimizer_d = optim.Adam(net_d.parameters(),
                         lr=opt.lr,
                         betas=(opt.beta1, 0.999))
net_g_scheduler = get_scheduler(optimizer_g, opt)
net_d_scheduler = get_scheduler(optimizer_d, opt)
face_descriptor = FaceDescriptor().to(device)
face_landmarks = FaceLandmarks().to(device)
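
FaceDescriptor and FaceLandmarks are project-specific modules; assuming they return feature tensors, they could drive auxiliary identity/geometry terms next to the L1 loss (real_a/real_b stand for the usual input/target batch, and the weighting is application-specific):

fake_b = net_g(real_a)
loss_id = criterionMSE(face_descriptor(fake_b), face_descriptor(real_b))
loss_lm = criterionMSE(face_landmarks(fake_b), face_landmarks(real_b))
loss_g_extra = loss_id + loss_lm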
Example #6
def main(name_exp, segloss=False, cuda=True, finetune=False):
    # Training settings
    parser = argparse.ArgumentParser(description='pix2pix-PyTorch-implementation')
    parser.add_argument('--batchSize', type=int, default=8, help='training batch size')
    parser.add_argument('--testBatchSize', type=int, default=8, help='testing batch size')
    parser.add_argument('--nEpochs', type=int, default=100, help='number of epochs to train for')
    parser.add_argument('--input_nc', type=int, default=3, help='input image channels')
    parser.add_argument('--output_nc', type=int, default=3, help='output image channels')
    parser.add_argument('--ngf', type=int, default=64, help='generator filters in first conv layer')
    parser.add_argument('--ndf', type=int, default=64, help='discriminator filters in first conv layer')
    parser.add_argument('--lr', type=float, default=0.0002, help='Learning Rate. Default=0.0002')
    parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. default=0.5')
    parser.add_argument('--threads', type=int, default=8, help='number of threads for data loader to use')
    parser.add_argument('--seed', type=int, default=123, help='random seed to use. Default=123')
    parser.add_argument('--lamb', type=int, default=10, help='weight on L1 term in objective')
    opt = parser.parse_args()

    cudnn.benchmark = True



    def val():
        net_current = "path_exp/checkpoint/DFS/{}/netG_model_current.pth".format(name_exp)
        netVal = torch.load(net_current)
        netVal.eval()
        SEG_NET.eval()
        features.eval()
        with torch.no_grad():
            total_mse = 0
            total_mse2 = 0
            avg_psnr_depth = 0
            avg_psnr_dehaze = 0
            avg_ssim_depth = 0
            avg_ssim_dehaze = 0
            for batch in validation_data_loader:
                input, target, depth = Variable(batch[0]), Variable(batch[1]), Variable(batch[2])
                if cuda:
                    input = input.cuda()
                    target = target.cuda()
                    depth = depth.cuda()

                dehaze = netVal(input)
                prediction = SEG_NET(dehaze)

                avg_ssim_dehaze += pytorch_ssim.ssim(dehaze, target).item()

                mse = criterionMSE(prediction, depth)
                total_mse += mse.item()
                avg_psnr_depth += 10 * log10(1 / mse.item())

                mse2 = criterionMSE(dehaze, target)
                total_mse2 += mse2.item()
                avg_psnr_dehaze += 10 * log10(1 / mse2.item())

                avg_ssim_depth += pytorch_ssim.ssim(prediction, depth).item()


                visual_ret_val = OrderedDict()

                visual_ret_val['Haze'] = input
                visual_ret_val['Seg estimate'] = prediction
                visual_ret_val['Dehaze '] = dehaze
                visual_ret_val['GT dehaze'] = target
                visual_ret_val['GT Seg '] = depth

                visualizer.display_current_results(visual_ret_val, epoch, True)


            print("===> Validation")
            #f.write("===> Testing: \r\n")

            print("===> PSNR seg: {:.4f} ".format(avg_psnr_depth / len(validation_data_loader)))
            #f.write("===> PSNR depth: {:.4f} \r\n".format(avg_psnr_depth / len(validation_data_loader)))

            print("===> Mse seg: {:.4f} ".format(total_mse / len(validation_data_loader)))
            #f.write("===> Mse depth: {:.4f} \r\n".format(total_mse / len(validation_data_loader)))

            print("===> SSIM seg: {:.4f} ".format(avg_ssim_depth / len(validation_data_loader)))
            #f.write("===> SSIM depth: {:.4f} \r\n".format(avg_ssim_depth / len(validation_data_loader)))

            return total_mse / len(validation_data_loader)






    def testing():
        path = "path_exp/checkpoint/DFS/{}/netG_model_best.pth".format(name_exp)
        net = torch.load(path)
        net.eval()
        SEG_NET.eval()
        features.eval()
        with torch.no_grad():
            total_mse = 0
            total_mse2 = 0
            avg_psnr_depth = 0
            avg_psnr_dehaze = 0
            avg_ssim_depth = 0
            avg_ssim_dehaze = 0
            for batch in testing_data_loader:
                input, target, depth = Variable(batch[0]), Variable(batch[1]), Variable(batch[2])
                if cuda:
                    input = input.cuda()
                    target = target.cuda()
                    depth = depth.cuda()

                dehaze = net(input)
                prediction = SEG_NET(dehaze)

                avg_ssim_dehaze += pytorch_ssim.ssim(dehaze, target).item()

                mse = criterionMSE(prediction, depth)
                total_mse += mse.item()
                avg_psnr_depth += 10 * log10(1 / mse.item())

                mse2 = criterionMSE(dehaze, target)
                total_mse2 += mse2.item()
                avg_psnr_dehaze += 10 * log10(1 / mse2.item())

                avg_ssim_depth += pytorch_ssim.ssim(prediction, depth).item()

            print("===> Testing")
            print("===> PSNR seg: {:.4f} ".format(avg_psnr_depth / len(testing_data_loader)))
            print("===> Mse seg: {:.4f} ".format(total_mse / len(testing_data_loader)))
            print("===> SSIM seg: {:.4f} ".format(avg_ssim_depth / len(testing_data_loader)))
            print("===> PSNR dehaze: {:.4f} ".format(avg_psnr_dehaze / len(testing_data_loader)))
            print("===> SSIM dehaze: {:.4f} ".format(avg_ssim_dehaze / len(testing_data_loader)))





    def checkpoint():
        if not os.path.exists("checkpoint"):
            os.mkdir("checkpoint")
        if not os.path.exists(os.path.join("path_exp/checkpoint/DFS", name_exp)):
            os.mkdir(os.path.join("path_exp/checkpoint/DFS", name_exp))
        net_g_model_out_path = "path_exp/checkpoint/DFS/{}/netG_model_best.pth".format(name_exp)
        net_d_model_out_path = "path_exp/checkpoint/DFS/{}/netD_model_best.pth".format(name_exp)
        torch.save(netG, net_g_model_out_path)
        torch.save(netD, net_d_model_out_path)


    def checkpoint_current():
        if not os.path.exists(os.path.join("path_exp/checkpoint/DFS", name_exp)):
            os.mkdir(os.path.join("path_exp/checkpoint/DFS", name_exp))
        net_g_model_out_path = "path_exp/checkpoint/DFS/{}/netG_model_current.pth".format(name_exp)
        torch.save(netG, net_g_model_out_path)

    def checkpoint_seg():
        if not os.path.exists(os.path.join("path_exp/checkpoint/DFS", name_exp)):
            os.mkdir(os.path.join("path_exp/checkpoint/DFS", name_exp))
        net_g_model_out_path = "path_exp/checkpoint/DFS/{}/seg_net.pth".format(name_exp)
        torch.save(SEG_NET, net_g_model_out_path)



    torch.manual_seed(opt.seed)
    if cuda:
        torch.cuda.manual_seed(opt.seed)

    print(" ")
    print(name_exp)
    print(" ")

    print('===> Loading datasets')
    train_set = get_training_set('path_exp/cityscape/HAZE')
    val_set = get_val_set('path_exp/cityscape/HAZE')
    test_set = get_test_set('path_exp/cityscape/HAZE')


    training_data_loader = DataLoader(dataset=train_set, num_workers=opt.threads, batch_size=opt.batchSize, shuffle=True)
    validation_data_loader = DataLoader(dataset=val_set, num_workers=opt.threads, batch_size=opt.batchSize, shuffle=True)
    testing_data_loader= DataLoader(dataset=test_set, num_workers=opt.threads, batch_size=opt.testBatchSize, shuffle=False)

    print('===> Building model')
    netG = define_G(opt.input_nc, opt.output_nc, opt.ngf, 'batch', False, [0])
    netD = define_D(opt.input_nc + opt.output_nc, opt.ndf, 'batch', False, [0])

    criterionGAN = GANLoss()
    criterionL1 = nn.L1Loss()
    criterionMSE = nn.MSELoss()

    # setup optimizer
    optimizerG = optim.Adam(netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
    optimizerD = optim.Adam(netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))



    print('---------- Networks initialized -------------')
    print_network(netG)
    print_network(netD)
    print('-----------------------------------------------')


    real_a = torch.FloatTensor(opt.batchSize, opt.input_nc, 256, 256)
    real_b = torch.FloatTensor(opt.batchSize, opt.output_nc, 256, 256)
    real_c = torch.FloatTensor(opt.batchSize, opt.output_nc, 256, 256)

    if cuda:
        netD = netD.cuda()
        netG = netG.cuda()
        criterionGAN = criterionGAN.cuda()
        criterionL1 = criterionL1.cuda()
        criterionMSE = criterionMSE.cuda()
        real_a = real_a.cuda()
        real_b = real_b.cuda()
        real_c = real_c.cuda()

    real_a = Variable(real_a)
    real_b = Variable(real_b)
    real_c = Variable(real_c)



    SEG_NET = torch.load("path_exp/SEG_NET.pth")

    optimizerSeg = optim.Adam(SEG_NET.parameters(), lr=opt.lr/10, betas=(opt.beta1, 0.999))



    features = Vgg16()

    if cuda:
        SEG_NET.cuda()
        features.cuda()


    bon = 100000000  # best (lowest) validation MSE seen so far
    for epoch in range(opt.nEpochs):
        features.eval()

        if finetune and epoch > 50:
            SEG_NET.train()
        else:
            SEG_NET.eval()

        loss_epoch_gen = 0
        loss_epoch_dis = 0
        total_segloss = 0
        loss_seg = 0
        i = 0
        for iteration, batch in enumerate(training_data_loader, 1):

            netG.train()
            i += 1

            # forward
            real_a_cpu, real_b_cpu, real_c_cpu = batch[0], batch[1], batch[2]

            # Copy the CPU batch into the pre-allocated (possibly GPU) tensors.
            with torch.no_grad():
                real_a = real_a.resize_(real_a_cpu.size()).copy_(real_a_cpu)
                real_b = real_b.resize_(real_b_cpu.size()).copy_(real_b_cpu)
                real_c = real_c.resize_(real_c_cpu.size()).copy_(real_c_cpu)


            fake_b = netG(real_a)

            ############################
            # (1) Update D network: maximize log(D(x,y)) + log(1 - D(x,G(x)))
            ###########################

            optimizerD.zero_grad()

            # train with fake
            fake_ab = torch.cat((real_a, fake_b), 1)
            pred_fake = netD.forward(fake_ab.detach())
            loss_d_fake = criterionGAN(pred_fake, False)

            # train with real
            real_ab = torch.cat((real_a, real_b), 1)
            pred_real = netD.forward(real_ab)
            loss_d_real = criterionGAN(pred_real, True)

            # Combined loss
            loss_d = (loss_d_fake + loss_d_real) * 0.5

            loss_d.backward()

            optimizerD.step()

            ############################
            # (2) Update G network: maximize log(D(x,G(x))) + L1(y,G(x))
            ##########################
            optimizerG.zero_grad()
            # First, G(A) should fake the discriminator
            fake_ab = torch.cat((real_a, fake_b), 1)
            pred_fake = netD.forward(fake_ab)
            loss_g_gan = criterionGAN(pred_fake, True)


            # Second, G(A) = B
            loss_g_l1 = criterionL1(fake_b, real_b) * opt.lamb

            features_y = features(fake_b)
            features_x = features(real_b)

            loss_content = criterionMSE(features_y[1], features_x[1])*10


            if segloss == True:
                fake_seg = SEG_NET(fake_b)
                loss_seg = criterionMSE(fake_seg, real_c) * 10

                total_segloss += loss_seg.item()

                features_y = features(fake_seg)
                features_x = features(real_c)

                # Perceptual term on the segmentation; note it is computed but
                # never added to loss_g below (kept as in the original).
                ssim_seg = criterionMSE(features_y[1], features_x[1]) * 10

                loss_g = loss_g_gan + loss_g_l1 + loss_content + loss_seg


            else:
                loss_g = loss_g_gan + loss_g_l1 + loss_content

            loss_epoch_gen += loss_g.item()
            loss_epoch_dis += loss_d.item()





            if finetune and epoch > 50:
                loss_g.backward(retain_graph=True)
                optimizerG.step()

                # Bug fix: zero_grad() must precede backward(); the original
                # zeroed the freshly computed gradients, making the seg update a no-op.
                optimizerSeg.zero_grad()
                loss_seg.backward()
                optimizerSeg.step()

            else:
                loss_g.backward()
                optimizerG.step()



            errors_ret = OrderedDict()
            errors_ret['Total_G'] = float(loss_g)
            errors_ret['Content'] = float(loss_content)
            errors_ret['GAN'] = float(loss_g_gan)
            errors_ret['L1'] = float(loss_g_l1)
            errors_ret['D'] = float(loss_d)



            if i % 10 == 0:  # print training losses and save logging information to the disk
                visualizer.plot_current_losses(epoch, i / (len(training_data_loader) * opt.batchSize), errors_ret)




        print("===> Epoch[{}]: Loss_D: {:.4f} Loss_G: {:.4f} Loss Seg: {:.4f} ".format(epoch, loss_epoch_dis,loss_epoch_gen, total_segloss))
        checkpoint_current()
        MSE=val()
        if MSE < bon:
            bon = MSE
            checkpoint()
            checkpoint_seg()
            print("BEST EPOCH SAVED")

    testing()
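
A possible entry point for the function above (the experiment name and flags are caller-chosen, not fixed by the source):

if __name__ == '__main__':
    main('dfs_segloss_finetune', segloss=True, cuda=torch.cuda.is_available(), finetune=True)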
Example #7
training_data_loader = DataLoader(
    dataset=train_set, num_workers=opt.threads, batch_size=opt.batchSize, shuffle=True)
testing_data_loader = DataLoader(
    dataset=test_set, num_workers=opt.threads, batch_size=opt.batchSize, shuffle=False)



## Define Network
netG = define_G(opt.input_nc, opt.output_nc, opt.ngf, 'batch', 'leakyrelu',
                opt.useDropout, opt.upConvType, opt.gtype, opt.blockType,
                opt.nblocks, gpus, n_downsampling=opt.ndowns)
netD = define_D(opt.input_nc + opt.output_nc, opt.ndf, 'batch',
                not opt.lsgan, opt.nlayers, gpus)

## Define Losses
criterionGAN = GANLoss(use_lsgan=opt.lsgan)
criterionL1 = nn.L1Loss()
criterionMSE = nn.MSELoss()

## Define Optimizers
optimizerG = optim.Adam(netG.parameters(), lr=opt.lr, betas=(
    opt.beta1, 0.999), weight_decay=opt.regTerm)
optimizerD = optim.Adam(netD.parameters(), lr=opt.lr, betas=(
    opt.beta1, 0.999), weight_decay=opt.regTerm)

## Continue from a checkpoint

if opt.cont:
    ## Load Networks
    netG.load_state_dict(torch.load(
        'checkpoint/{}/netG_model_epoch_{}.pth'.format(opt.savePath, opt.contEpoch),
Example #8
    def __init__(self, config):
        # start from basic parameters
        self.config = config
        self.gpuid = config.gpuid
        self.mode = config.mode
        self.pretrained_weights = config.pretrained_weights
        self.criterionMSE = nn.MSELoss()
        self.criterionL1 = nn.L1Loss()
        self.criterionBCE = nn.BCELoss()
        self.criterionGAN = GANLoss(use_lsgan=False).to(self.gpuid)
        self.parallel = config.parallel

        self.ckpt_dir = config.check_point_dir
        self.model_name = 'HeavyRain-stage%s-%s' % (self.config.training_stage,
                                                    str(datetime.date.today()))
        self.best_valid_acc = 0
        self.logs_dir = '.logs/'
        self.use_tensorboard = True

        # == hyper-parameters ==
        self.epoch = 0
        self.total_loss = 0
        self.input_list = []
        self.output_list = []
        self.ch_in = 6
        self.ch_out = 3
        self.ndf = 64
        self.batch_size = config.batch_size
        self.image_size = config.image_size
        self.epoch_limit = config.epoch_limit
        self.LR = config.learning_rate
        self.training_stage = config.training_stage

        # ==  I/O paths ==
        self.train_dir = config.train_dir
        self.val_dir = config.val_dir
        self.test_input_dir = config.test_input_dir

        # == initialize tensors ==
        # input
        self.image_in = torch.FloatTensor(self.batch_size, 3, self.image_size,
                                          self.image_size)
        self.image_in_var = None
        self.streak_gt_var = None
        self.trans_gt_var = None
        self.atm_gt_var = None
        self.clean_gt_var = None

        # outputs
        self.st_out = None
        self.trans_out = None
        self.atm_out = None
        self.clean_out = None
        self.realrain_st = None
        self.realrain_trans = None
        self.realrain_atm = None
        self.realrain_out = torch.FloatTensor(self.batch_size, 3,
                                              self.image_size, self.image_size)
        # Placeholder adversarial loss, evaluated on an empty tensor here;
        # the real value is recomputed from D's outputs during training.
        self.loss_adv_realrain = self.criterionGAN(
            torch.FloatTensor(0).cuda(), True)
        self.accs = AverageMeter()
        self.probability = -1
        self.fl = -1
        self.tl = -1

        # == declare models ==
        self.G = None
        self.D = None
        self.G_optim = None  # optimizer
        self.D_optim = None  # optimizer
        self.vgg_model = None
        self.create_model()

        # ==  reset other infrastructure  ==
        self.reset(config)
Example #9
def main():

	print(f"epoch: {opt.niter+opt.niter_decay}")
	print(f"cuda: {opt.cuda}")
	print(f"dataset: {opt.dataset}")
	print(f"output: {opt.output_path}")

	if opt.cuda and not torch.cuda.is_available():
		raise Exception("No GPU found, please run without --cuda")

	cudnn.benchmark = True

	torch.manual_seed(opt.seed)
	if opt.cuda:
		torch.cuda.manual_seed(opt.seed)

	print('Loading datasets')
	train_set = get_training_set(root_path + opt.dataset, opt.direction)
	test_set = get_test_set(root_path + opt.dataset, opt.direction)

	training_data_loader = DataLoader(dataset=train_set, num_workers=opt.threads, batch_size=opt.batch_size, shuffle=True)
	testing_data_loader = DataLoader(dataset=test_set, num_workers=opt.threads, batch_size=opt.test_batch_size, shuffle=False)

	device = torch.device("cuda:0" if opt.cuda else "cpu")

	print('Building models')
	net_g = define_G(opt.input_nc, opt.output_nc, opt.g_ch, len(class_name_array), 'batch', False, 'normal', 0.02, gpu_id=device)
	net_d = define_D(opt.input_nc + opt.output_nc, opt.d_ch, len(class_name_array), 'basic', gpu_id=device)

	criterionGAN = GANLoss().to(device)
	criterionL1 = nn.L1Loss().to(device)
	criterionMSE = nn.MSELoss().to(device)

	# setup optimizer
	optimizer_g = optim.Adam(net_g.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
	optimizer_d = optim.Adam(net_d.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
	net_g_scheduler = get_scheduler(optimizer_g, opt)
	net_d_scheduler = get_scheduler(optimizer_d, opt)

	start_time = time.time()

	#save loss
	G_loss_array = []
	D_loss_array = []
	epoch_array = []

	for epoch in tqdm(range(opt.epoch_count, opt.niter + opt.niter_decay + 1), desc="Epoch"):
		# train
		loss_g_sum = 0
		loss_d_sum = 0
		for iteration, batch in enumerate(tqdm(training_data_loader, desc="Batch"), 1):
			# forward
			real_a, real_b, class_label, _ = batch[0].to(device), batch[1].to(device), batch[2].to(device), batch[3][0]
			fake_b = net_g(real_a, class_label)

			######################
			# (1) Update D network
			######################

			optimizer_d.zero_grad()
			
			# train with fake
			if opt.padding:
				real_a_for_d = padding(real_a)
				real_b_for_d = padding(real_b)
				fake_b_for_d = padding(fake_b)
			else:
				real_a_for_d = real_a
				real_b_for_d = real_b
				fake_b_for_d = fake_b
			
			fake_ab = torch.cat((real_a_for_d, fake_b_for_d), 1)
			pred_fake = net_d.forward(fake_ab.detach(), class_label)
			loss_d_fake = criterionGAN(pred_fake, False)

			# train with real
			real_ab = torch.cat((real_a_for_d, real_b_for_d), 1)
			pred_real = net_d.forward(real_ab, class_label)
			loss_d_real = criterionGAN(pred_real, True)
			
			# Combined D loss
			loss_d = (loss_d_fake + loss_d_real) * 0.5

			loss_d.backward()
		   
			optimizer_d.step()

			######################
			# (2) Update G network
			######################

			optimizer_g.zero_grad()

			# First, G(A) should fake the discriminator
			fake_ab = torch.cat((real_a_for_d, fake_b_for_d), 1)
			pred_fake = net_d.forward(fake_ab, class_label)
			loss_g_gan = criterionGAN(pred_fake, True)

			# Second, G(A) = B
			loss_g_l1 = criterionL1(fake_b, real_b) * opt.lamb
			
			loss_g = loss_g_gan + loss_g_l1
			
			loss_g.backward()

			optimizer_g.step()
			loss_d_sum += loss_d.item()
			loss_g_sum += loss_g.item()

		update_learning_rate(net_g_scheduler, optimizer_g)
		update_learning_rate(net_d_scheduler, optimizer_d)
		
		# test
		avg_psnr = 0
		dst = Image.new('RGB', (512*4, 256*4))
		n = 0
		for batch in tqdm(testing_data_loader, desc="Batch"):
			input, target, class_label, _ = batch[0].to(device), batch[1].to(device), batch[2].to(device), batch[3][0]

			prediction = net_g(input, class_label)
			mse = criterionMSE(prediction, target)
			psnr = 10 * log10(1 / mse.item())
			avg_psnr += psnr
			
			n += 1
			if n <= 16:
				#make test preview
				out_img = prediction.detach().squeeze(0).cpu()
				image_numpy = out_img.float().numpy()
				image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0
				image_numpy = image_numpy.clip(0, 255)
				image_numpy = image_numpy.astype(np.uint8)
				image_pil = Image.fromarray(image_numpy)
				dst.paste(image_pil, ((n-1)%4*512, (n-1)//4*256))
				
		if not os.path.exists("results"):
			os.mkdir("results")
		if not os.path.exists(os.path.join("results", opt.output_path)):
			os.mkdir(os.path.join("results", opt.output_path))
		dst.save(f"results/{opt.output_path}/epoch{epoch}_test_preview.jpg")
		
		epoch_array += [epoch]
		G_loss_array += [loss_g_sum/len(training_data_loader)]
		D_loss_array += [loss_d_sum/len(training_data_loader)]
		
		if opt.graph_save_while_training and len(epoch_array) > 1:
			output_graph(epoch_array, G_loss_array, D_loss_array, False)
		
		#checkpoint
		if epoch % opt.save_interval == 0:
			if not os.path.exists("checkpoint"):
				os.mkdir("checkpoint")
			if not os.path.exists(os.path.join("checkpoint", opt.output_path)):
				os.mkdir(os.path.join("checkpoint", opt.output_path))
			net_g_model_out_path = "checkpoint/{}/netG_model_epoch_{}.pth".format(opt.output_path, epoch)
			net_d_model_out_path = "checkpoint/{}/netD_model_epoch_{}.pth".format(opt.output_path, epoch)
			torch.save(net_g, net_g_model_out_path)
			torch.save(net_d, net_d_model_out_path)

	#save the latest net
	if not os.path.exists("checkpoint"):
		os.mkdir("checkpoint")
	if not os.path.exists(os.path.join("checkpoint", opt.output_path)):
		os.mkdir(os.path.join("checkpoint", opt.output_path))
	net_g_model_out_path = "checkpoint/{}/netG_model_epoch_{}.pth".format(opt.output_path, opt.niter + opt.niter_decay)
	net_d_model_out_path = "checkpoint/{}/netD_model_epoch_{}.pth".format(opt.output_path, opt.niter + opt.niter_decay)
	torch.save(net_g, net_g_model_out_path)
	torch.save(net_d, net_d_model_out_path)
	print("\nCheckpoint saved to {}".format("checkpoint/" + opt.output_path))

	# output loss graph
	output_graph(epoch_array, G_loss_array, D_loss_array)

	# finish training
	now_time = time.time()
	t = now_time - start_time
	print(f"Training time: {t/60:.1f}m")
Example #10
net_d = define_D(opt.input_nc + opt.output_nc, opt.ndf, 'basic')

if opt.gpus[-1] > 0:
    net_g = nn.DataParallel(net_g, device_ids=gpus).cuda()
    net_d = nn.DataParallel(net_d, device_ids=gpus).cuda()

if opt.net_G:
    G_checkpoint = torch.load(opt.net_G).state_dict()
    net_g.load_state_dict(G_checkpoint)
    print("=> loaded checkpoint net_G")
if opt.net_D:
    D_checkpoint = torch.load(opt.net_D).state_dict()
    net_d.load_state_dict(D_checkpoint)
    print("=> loaded checkpoint net_D")

criterionGAN = GANLoss().cuda()
criterionL1 = nn.L1Loss().cuda()
criterionMSE = nn.MSELoss().cuda()

# setup optimizer
optimizer_g = optim.Adam(net_g.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
optimizer_d = optim.Adam(net_d.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
net_g_scheduler = get_scheduler(optimizer_g, opt)
net_d_scheduler = get_scheduler(optimizer_d, opt)

for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1):
    # train
    for iteration, batch in enumerate(training_data_loader, 1):
        # forward
        real_a, real_b = batch[0].cuda(), batch[1].cuda()
        fake_b = net_g(real_a)