def __init__(self, opt):
        """Create the U-Net generator and, in training mode, the discriminator,
        losses and Adam optimizers.

        Args:
            opt: options forwarded to ``BaseModel.__init__``. ``bilinear`` and
                ``scenepath`` are read back through ``self.opt``; ``lr`` and
                ``lrg2d`` are read from ``opt`` directly.
        """
        BaseModel.__init__(self, opt)
        self.net_names = ['G']

        generator_kwargs = dict(
            inchannel=3,
            outchannel=3,
            ndf=64,
            enc_blocks=1,
            dec_blocks=1,
            depth=3,
            concat=True,
            bilinear=self.opt.bilinear,
            norm_layer='LN',
        )
        self.net_G = Unet(**generator_kwargs)

        # Without training, only the generator is needed.
        if not self.isTrain:
            return

        self.loss_names = ['G_ad_gen', 'D', 'G', 'G_style', 'G_scene']
        self.net_names.append('D')
        self.net_D = ResDiscriminator(ndf=32, img_f=128, layers=4,
                                      use_spect=True)

        self.GANloss = AdversarialLoss('lsgan').to(self.device)
        self.Styleloss = vgg_style_loss().to(self.device)
        self.sceneloss = scene_loss(self.opt.scenepath).to(self.device)

        # Both optimizers share beta1=0; the discriminator's learning rate is
        # the generator's scaled by opt.lrg2d.
        adam_betas = (0.0, 0.999)
        self.optim_G = torch.optim.Adam(self.net_G.parameters(),
                                        lr=opt.lr,
                                        betas=adam_betas)
        self.optim_D = torch.optim.Adam(self.net_D.parameters(),
                                        lr=opt.lr * opt.lrg2d,
                                        betas=adam_betas)
Esempio n. 2
0
    def __init__(self,
                 model,
                 losses,
                 metrics,
                 optimizer_g,
                 optimizer_d_s,
                 optimizer_d_t,
                 resume,
                 config,
                 data_loader,
                 valid_data_loader=None,
                 lr_scheduler=None,
                 train_logger=None,
                 learn_mask=True,
                 test_data_loader=None,
                 pretrained_path=None):
        """Configure data loaders, GAN loss weights and validation settings.

        ``model``, ``losses``, ``metrics``, the three optimizers, ``resume``,
        ``config``, ``train_logger`` and ``pretrained_path`` are forwarded to
        the base trainer. Settings are read from the ``visualization``,
        ``gan_losses``, ``trainer`` and ``validation`` sections of ``config``
        plus its top-level ``use_flow`` entry.

        NOTE(review): ``learn_mask`` is accepted but neither stored nor used
        in this constructor.
        """
        super().__init__(model, losses, metrics, optimizer_g, optimizer_d_s,
                         optimizer_d_t, resume, config, train_logger,
                         pretrained_path)
        self.config = config

        self.data_loader = data_loader
        self.valid_data_loader = valid_data_loader
        self.test_data_loader = test_data_loader
        self.do_validation = valid_data_loader is not None
        self.lr_scheduler = lr_scheduler
        self.log_step = config['visualization']['log_step']

        gan_weights = config['gan_losses']
        self.loss_gan_s_w = gan_weights['loss_gan_spatial_weight']
        self.loss_gan_t_w = gan_weights['loss_gan_temporal_weight']
        self.adv_loss_fn = AdversarialLoss()
        self.evaluate_score = config['trainer'].get('evaluate_score', True)
        # Flags that start disabled.
        self.store_gated_values = False
        self.printlog = False
        self.use_flow = config['use_flow']

        validation_cfg = config['validation']
        self.valid_length = validation_cfg['valid_length']
        self.valid_interval = validation_cfg['valid_interval']

        # Test-time helpers exist only when a test loader was supplied.
        if test_data_loader is not None:
            self.toPILImage = ToPILImage()
            self.evaluate_test_warp_error = config.get(
                'evaluate_test_warp_error', False)
            self.test_output_root_dir = os.path.join(
                self.checkpoint_dir, 'test_outputs')

        # Presumably initializes a shared I3D model used for evaluation - confirm.
        init_i3d_model()
Esempio n. 3
0
 def __init__(self, opt):
     """Create the U-Net generator and, in training mode, the discriminator,
     losses and optimizers.

     Args:
         opt: options forwarded to ``BaseModel.__init__``; ``bilinear`` is
             read back through ``self.opt``.
     """
     BaseModel.__init__(self, opt)
     self.net_names = ['G']

     unet_config = {
         'inchannel': 3,
         'outchannel': 3,
         'ndf': 64,
         'enc_blocks': 2,
         'dec_blocks': 2,
         'depth': 3,
         'bilinear': self.opt.bilinear,
         'norm_layer': 'LN',
     }
     self.net_G = Unet(**unet_config)

     # Without training, only the generator is needed.
     if not self.isTrain:
         return

     self.loss_names = ['G_ad_gen', 'D', 'G', 'G_style', 'G_scene']
     self.net_names.append('D')
     self.net_D = ResDiscriminator(ndf=32, img_f=128, layers=4,
                                   use_spect=True)

     self.GANloss = AdversarialLoss('lsgan').to(self.device)
     self.Styleloss = vgg_style_loss().to(self.device)
     # NOTE(review): the scene-loss checkpoint path is hard-coded and relative
     # to the working directory; '.path' may be a typo for '.pth' - confirm.
     scene_checkpoint = './checkpoints/net_E_s.path'
     self.sceneloss = scene_loss(scene_checkpoint).to(self.device)

     # net_names is now ['G', 'D']; set_Adam_optims presumably builds one
     # optimizer per listed network - confirm against BaseModel.
     self.set_Adam_optims(self.net_names)
Esempio n. 4
0
args = parser.parse_args()

# Fall back to the CPU when no CUDA device is present instead of failing on
# the first .to(device) call.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

dataset_train = RSEdgeDataSet()
dataloader_train = torch.utils.data.DataLoader(dataset_train,
                                               batch_size=args.batchsize,
                                               shuffle=True,
                                               num_workers=0)

# The checkpoints contain complete pickled modules (their .parameters() are
# used below), so the architecture comes from the files themselves.
# Constructing EdgeGenerator/Discriminator first only to overwrite them was
# dead work, and the loaded modules were never moved to `device`.
# map_location lets checkpoints saved on any device load here, and .to(device)
# puts them on the training device.
# NOTE(review): torch.load unpickles arbitrary objects - only load trusted files.
# NOTE(review): the loaded discriminator keeps whatever use_sigmoid it was
# saved with; make sure it matches args.GAN_LOSS ('hinge' => no sigmoid).
generator = torch.load('Edge_generator.pth', map_location=device).to(device)
discriminator = torch.load('Edge_discriminator.pth',
                           map_location=device).to(device)

l1_loss = nn.L1Loss().to(device)
adversarial_loss = AdversarialLoss(type=args.GAN_LOSS).to(device)

gen_optimizer = optim.Adam(
    params=generator.parameters(),
    lr=float(args.LR),
    betas=(args.BETA1, args.BETA2)
)

# The discriminator learning rate is the generator's scaled by D2G_LR.
dis_optimizer = optim.Adam(
    params=discriminator.parameters(),
    lr=float(args.LR) * float(args.D2G_LR),
    betas=(args.BETA1, args.BETA2)
)

# Loss histories, presumably filled by the training loop that follows.
list_dloss = []
list_gloss = []
list_gloss = []
Esempio n. 5
0
    def __init__(self, config):
        """Assemble the inpainting generator, optional discriminator, losses
        and optimizers.

        Args:
            config: nested mapping. ``config["gpu"]`` is a comma-separated GPU
                id string (falsy skips DataParallel); loss weights, optimizer
                settings and the teacher-forcing schedule are read from
                ``config["training"]``.

        Raises:
            ValueError: if ``config["training"]["optimizer"]`` is neither
                ``"adam"`` nor ``"radam"``.
        """
        super(InpaintingModel, self).__init__('InpaintingModel', config)
        generator = InpaintingGenerator(config)
        training = config['training']
        self.with_discriminator = training['discriminator']
        if self.with_discriminator:
            discriminator = InpaintingDiscriminator(config)

        if config["gpu"]:
            gpus = [int(i) for i in config["gpu"].split(",")]

            if len(gpus) > 1:
                # Only the number of listed ids is used: DataParallel gets
                # 0..n-1. Presumably the physical GPUs are selected elsewhere
                # (e.g. CUDA_VISIBLE_DEVICES) - confirm.
                gpus = list(range(len(gpus)))
                generator = nn.DataParallel(generator, gpus)
                if self.with_discriminator:
                    discriminator = nn.DataParallel(discriminator, gpus)

        self.add_module('generator', generator)

        self.add_module('l1_loss', nn.L1Loss())
        self.rec_loss_weight = training["rec_loss_weight"]
        self.step_loss_weight = training["step_loss_weight"]

        self.add_module('mse_loss', nn.MSELoss())
        self.mse_loss_weight = training["mse_loss_weight"]

        self.add_module('style_loss', StyleLoss())
        self.style_loss_weight = training["style_loss_weight"]

        self.add_module('per_loss', PerceptualLoss())
        self.per_loss_weight = training["per_loss_weight"]

        learning_rate = training["learning_rate"]
        betas = (training["beta1"], training["beta2"])

        optimizer_name = training['optimizer']
        if optimizer_name == 'adam':
            optimizer_cls = torch.optim.Adam
        elif optimizer_name == 'radam':
            optimizer_cls = RAdam
        else:
            # Fail fast: otherwise gen_optimizer would be left unset and the
            # misconfiguration would only surface later as an AttributeError.
            raise ValueError(
                "Unsupported optimizer {!r}; expected 'adam' or 'radam'".format(
                    optimizer_name))
        self.gen_optimizer = optimizer_cls(generator.parameters(),
                                           lr=learning_rate,
                                           betas=betas)

        if self.with_discriminator:
            self.add_module('discriminator', discriminator)
            adversarial_loss = AdversarialLoss(type=training['gan_loss'])
            self.add_module('adversarial_loss', adversarial_loss)
            self.adversarial_loss_weight = training["adv_loss_weight"]

            # The discriminator always uses Adam; its learning rate is the
            # generator's scaled by d2g_lr.
            self.dis_optimizer = torch.optim.Adam(
                discriminator.parameters(),
                lr=learning_rate * training['d2g_lr'],
                betas=betas)

        # Teacher forcing
        self.beta = training['beta']

        self.alpha = training['alpha']
        self.alpha_decay = training['alpha_decay']
        self.alpha_decay_start_iter = training['alpha_decay_start_iter']