コード例 #1
0
    def __init__(self, opt):
        """Initialize the pix2pix-style model with the fusion generator.

        Parameters:
            opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions

        Compared with stock pix2pix, the generator comes from
        ``fusion_nets._2layerFusionNets_()`` instead of ``networks.define_G``,
        and training additionally creates an ``MTL_loss`` criterion.
        """
        BaseModel.__init__(self, opt)
        # Training losses reported via <BaseModel.get_current_losses>.
        # NOTE(review): the extra MTL loss has no entry here, so it is presumably
        # not reported by get_current_losses -- confirm whether that is intended.
        self.loss_names = ['G_GAN', 'G_L1', 'D_real', 'D_fake']
        # Images returned via <BaseModel.get_current_visuals>.
        self.visual_names = ['real_A', 'fake_B', 'real_B']
        # Networks handled by <BaseModel.save_networks>/<BaseModel.load_networks>;
        # at test time only the generator is needed.
        self.model_names = ['G', 'D'] if self.isTrain else ['G']

        # LGQ: the generator is the line-encoding fusion network rather than
        # networks.define_G. The bare constructor receives no gpu_ids/device
        # (unlike define_D below), so place it on the model's device explicitly.
        # Module.to() returns the module and is a no-op if it is already there.
        self.netG = fusion_nets._2layerFusionNets_().to(self.device)

        if self.isTrain:
            # Conditional GAN: D sees the input and output images together,
            # so its channel count is input_nc + output_nc.
            self.netD = networks.define_D(opt.input_nc + opt.output_nc,
                                          opt.ndf, opt.netD, opt.n_layers_D,
                                          opt.norm, opt.init_type,
                                          opt.init_gain, self.gpu_ids)

            # Loss functions (LGQ: MTL_loss is an extra generator loss).
            self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device)
            self.criterionL1 = torch.nn.L1Loss()
            self.criterionMTL = MTL_loss()

            # Optimizers; schedulers are created later by <BaseModel.setup>.
            adam_betas = (opt.beta1, 0.999)
            self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
                                                lr=opt.lr, betas=adam_betas)
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                                lr=opt.lr, betas=adam_betas)
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D)

        self.validation_init()
コード例 #2
0
        save_path = os.path.join(opt.checkpoints_dir, opt.name, save_filename)
        network.load_state_dict(torch.load(save_path))
def save_network(network, network_label, epoch_label, save_dir=None):
    """Save ``network``'s weights to ``<save_dir>/<epoch_label>_net_<network_label>.pth``.

    Parameters:
        network (torch.nn.Module) -- model whose ``state_dict`` is saved
        network_label (str)       -- short network tag, e.g. 'G' or 'D'
        epoch_label (str or int)  -- epoch tag used as the filename prefix
        save_dir (str, optional)  -- target directory; defaults to
                                     ``os.path.join(opt.checkpoints_dir, opt.name)``
                                     (the module-level ``opt``)

    The saved tensors are CPU copies, so a plain ``torch.load`` works on any
    machine. The module itself stays on whatever device it was on. The
    previous version called ``network.cpu()`` and then an unconditional
    ``network.cuda()``. That moved CPU-only models onto the GPU, and moved
    models on a non-default GPU onto the default one.
    """
    if save_dir is None:
        save_dir = os.path.join(opt.checkpoints_dir, opt.name)
    # torch.save does not create missing parent directories.
    os.makedirs(save_dir, exist_ok=True)
    save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
    save_path = os.path.join(save_dir, save_filename)

    state = network.state_dict()
    for key, value in state.items():
        # state_dict() returns a fresh dict, so replacing its values in place
        # is safe and keeps its ``_metadata`` attribute intact.
        if torch.is_tensor(value):
            state[key] = value.cpu()
    torch.save(state, save_path)

# Build the deblurring generator, the re-blur network and the discriminator
# from the experiment options.
netG_deblur = networks.define_G(
    opt.input_nc, opt.output_nc, opt.ngf, opt.which_model_netG,
    opt.norm, not opt.no_dropout, opt.gpu_ids, False, opt.learn_residual,
)
phsics_blur = physicsReblurNet(opt)  # (sic) name kept: later cells may reference it
# D gets a sigmoid output only for the plain 'gan' loss.
use_sigmoid = (opt.gan_type == 'gan')
netD = networks.define_D(
    opt.output_nc, opt.ndf, opt.which_model_netD, opt.n_layers_D,
    opt.norm, use_sigmoid, opt.gpu_ids, False,
)

# Restore both networks from the checkpoints saved for opt.which_epoch.
for _net, _label in ((netG_deblur, 'G'), (netD, 'D')):
    load_network(_net, _label, opt.which_epoch)

# Print a framed architecture summary for each network.
for _banner, _net in (
    ('------- Networks deblur_G initialized ---------', netG_deblur),
    ('------- Networks deblur_D initialized ---------', netD),
):
    print(_banner)
    networks.print_network(_net)
    print('-----------------------------------------------')
# Drop the loop temporaries so the module namespace matches the original.
del _net, _label, _banner
# ### Freeze layers

# In[6]:
コード例 #3
0
ファイル: train.py プロジェクト: sunattic/MLFcGAN
# Training / validation data.
training_data_loader, val_data_loader = get_dataset_loader(opt)

device = torch.device('cuda:0') if opt.cuda else torch.device('cpu')

print('===> Building models')

net_g = define_G(
    opt.input_nc, opt.output_nc, opt.ngf, opt.netG, 0, 'batch', False, 'normal', 0.02,
    gpu_id=device, upsample=opt.upsample,
)

# The WGAN-GP critic gets no sigmoid and instance norm; every other loss
# method uses a sigmoid output with batch norm.
if opt.loss_method == 'WGAN-GP':
    use_sigmoid, norm = False, 'instance'
else:
    use_sigmoid, norm = True, 'batch'

net_d = define_D(
    opt.input_nc + opt.output_nc, opt.ndf, opt.netD,
    norm=norm, gpu_id=device, use_sigmoid=use_sigmoid,
)

# Loss functions (ssim is assigned directly, not instantiated).
criterionGAN = GANLoss(opt.loss_method).to(device)
criterionL1 = nn.L1Loss().to(device)
criterionMSE = nn.MSELoss().to(device)
criterionSSIM = ssim

# Identical Adam settings for G and D, each with its own LR scheduler.
optimizer_g, optimizer_d = (
    optim.Adam(net.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
    for net in (net_g, net_d)
)
net_g_scheduler, net_d_scheduler = (
    get_scheduler(optimizer, opt) for optimizer in (optimizer_g, optimizer_d)
)

if opt.resume_netG_path:
    # resume training
    if os.path.isfile(opt.resume_netG_path):
        print("====>loading checkpoint for netG {}".format(opt.resume_netG_path))
        checkpoint = torch.load(opt.resume_netG_path)