    def initialize(self, opt, net):
        BaseModel.initialize(self, opt)
        self.net = net.to(self.device)
        self.edge_map = EdgeMap(scale=1).to(self.device)

        if self.isTrain:
            # define loss functions
            self.vgg = losses.Vgg19(requires_grad=False).to(self.device)
            self.loss_dic = losses.init_loss(opt, self.Tensor)
            vggloss = losses.ContentLoss()
            vggloss.initialize(losses.VGGLoss(self.vgg))
            self.loss_dic['t_vgg'] = vggloss

            cxloss = losses.ContentLoss()
            if opt.unaligned_loss == 'vgg':
                cxloss.initialize(
                    losses.VGGLoss(self.vgg, weights=[0.1], indices=[31]))
            elif opt.unaligned_loss == 'ctx':
                cxloss.initialize(
                    losses.CXLoss(self.vgg,
                                  weights=[0.1, 0.1, 0.1],
                                  indices=[8, 13, 22]))
            elif opt.unaligned_loss == 'mse':
                cxloss.initialize(nn.MSELoss())
            elif opt.unaligned_loss == 'ctx_vgg':
                cxloss.initialize(
                    losses.CXLoss(self.vgg,
                                  weights=[0.1, 0.1, 0.1, 0.1],
                                  indices=[8, 13, 22, 31],
                                  criterions=[losses.CX_loss] * 3 +
                                  [nn.L1Loss()]))

            else:
                raise NotImplementedError

            self.loss_dic['t_cx'] = cxloss

            # initialize optimizers
            self.optimizer_G = torch.optim.Adam(self.net.parameters(),
                                                lr=opt.lr,
                                                betas=(opt.beta1, 0.999),
                                                weight_decay=opt.wd)

            self._init_optimizer([self.optimizer_G])

            # define discriminator
            # if self.opt.lambda_gan > 0:
            self.netD = networks.define_D(opt, 3)
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                                lr=opt.lr,
                                                betas=(opt.beta1, 0.999))
            self._init_optimizer([self.optimizer_D])

        if opt.no_verbose is False:
            self.print_network()
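# ---------------------------------------------------------------------------
# For reference: the VGG / contextual losses configured above compare images in
# VGG-19 feature space instead of pixel space.  The project's own losses.VGGLoss
# is not reproduced here; the block below is only a self-contained sketch of the
# idea using torchvision.  The layer indices and weights are assumptions, not
# the exact settings used above.
# ---------------------------------------------------------------------------
import torch.nn as nn
from torchvision import models


class PerceptualLossSketch(nn.Module):
    """L1 distance between VGG-19 feature maps of prediction and target."""

    def __init__(self, indices=(8, 13, 22), weights=(0.1, 0.1, 0.1)):
        super().__init__()
        vgg = models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1).features.eval()
        for p in vgg.parameters():
            p.requires_grad = False          # frozen feature extractor
        self.vgg = vgg
        self.indices = indices
        self.weights = weights
        self.criterion = nn.L1Loss()

    def forward(self, pred, target):
        loss = 0.0
        x, y = pred, target
        for i, layer in enumerate(self.vgg):
            x, y = layer(x), layer(y)
            if i in self.indices:            # compare features at selected depths
                loss = loss + self.weights[self.indices.index(i)] * self.criterion(x, y)
            if i == max(self.indices):       # no need to run deeper layers
                break
        return loss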
Example 2
    def initialize(self, opt):
        BaseModel.initialize(self, opt)
        # device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        self.parse_model = Stage_I_Model()
        self.parse_model.initialize(opt, "resNet")
        self.parse_model.eval()

        self.main_model = SemanticAlignModel()
        self.main_model.initialize(opt, "wapResNet_v3_afftps")

        self.net_SK = Skeleton_Model(opt).cuda()
        self.geo = GeoAPI()

        # cpm_model_path = 'openpose_coco_best.pth.tar'
        # self.cpm_model = heatmap_pose.construct_model(cpm_model_path)

        self.parsing_label_nc = opt.parsing_label_nc
        self.opt = opt

        nb = opt.batchSize
        size = opt.fineSize

        self.mask = torch.ones([nb, 1, 46, 32]).cuda()

        print('---------- Networks initialized -------------')
        
        # set loss functions and optimizers
        if self.isTrain:
            if opt.pool_size > 0 and (len(self.gpu_ids)) > 1:
                raise NotImplementedError("Fake Pool Not Implemented for MultiGPU")
            self.fake_pool = ImagePool(opt.pool_size)
            self.old_lr = opt.lr

            # define loss functions
            self.criterionL1 = torch.nn.L1Loss()
            self.criterionFeat = torch.nn.L1Loss()

            # parsing loss for unsupervised pairs
            if not opt.no_Parsing_loss:
                self.criterionParsingLoss = ParsingLoss()
            if not self.opt.no_VGG_loss:
                self.criterionVGG = losses.VGGLoss()
            if not opt.no_TV_loss:
                self.criterionTV = losses.TVLoss()

            self.loss_names = ['D_real', 'D_fake', 'G_GAN', 'G_GAN_Feat', 'G_VGG', 'G_L1', 'G_TV', 'G_Parsing']

            # initialize optimizers
            # optimizer G
            self.optimizer_G = self.main_model.optimizer_G
            # optimizer SK
            self.optimizer_SK = torch.optim.Adam([self.net_SK.alpha], lr=1, betas=(opt.beta2, 0.999))
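# ---------------------------------------------------------------------------
# Note on optimizer_SK above: it optimizes a single tensor (net_SK.alpha) rather
# than a whole module's parameters.  A minimal, self-contained sketch of that
# pattern; BlendModule and its alpha are hypothetical, not this project's
# Skeleton_Model, and the learning rate is only a placeholder.
# ---------------------------------------------------------------------------
import torch
import torch.nn as nn


class BlendModule(nn.Module):
    """Hypothetical module exposing one learnable blending scalar `alpha`."""

    def __init__(self):
        super().__init__()
        self.alpha = nn.Parameter(torch.tensor(0.5))  # single learnable scalar

    def forward(self, a, b):
        return self.alpha * a + (1.0 - self.alpha) * b


blend = BlendModule()
# Passing a list with just that tensor gives it its own optimizer, as done above.
optim_alpha = torch.optim.Adam([blend.alpha], lr=1e-2, betas=(0.5, 0.999))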
Example 3
    def __init__(self, opt):
        super(OASIS_model, self).__init__()
        self.opt = opt
        #--- generator and discriminator ---
        self.netG = generators.OASIS_Generator(opt)
        if opt.phase == "train":
            self.netD = discriminators.OASIS_Discriminator(opt)
        self.print_parameter_count()
        self.init_networks()
        #--- EMA of generator weights ---
        with torch.no_grad():
            self.netEMA = copy.deepcopy(self.netG) if not opt.no_EMA else None
        #--- load previous checkpoints if needed ---
        self.load_checkpoints()
        #--- perceptual loss ---#
        if opt.phase == "train":
            if opt.add_vgg_loss:
                self.VGG_loss = losses.VGGLoss(self.opt.gpu_ids)
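# ---------------------------------------------------------------------------
# The netEMA copy created above is typically maintained as an exponential moving
# average of the generator weights and used for evaluation.  A minimal sketch of
# such an update step; the decay value is an assumption, not necessarily the one
# this project uses.
# ---------------------------------------------------------------------------
import torch


def update_ema(net_ema, net_g, decay=0.999):
    """Blend the EMA copy towards the current generator weights."""
    with torch.no_grad():
        for p_ema, p in zip(net_ema.parameters(), net_g.parameters()):
            p_ema.copy_(decay * p_ema + (1.0 - decay) * p)
        # Buffers (e.g. BatchNorm running statistics) are usually copied directly.
        for b_ema, b in zip(net_ema.buffers(), net_g.buffers()):
            b_ema.copy_(b)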
    def initialize(self, opt):
        if len(opt.gpu_ids) > 0:
            self.device = torch.device(
                "cuda:0" if torch.cuda.is_available() else "cpu")
        else:
            self.device = torch.device("cpu")
        BaseModel.initialize(self, opt)

        in_channels = 3
        self.vgg = None

        if opt.hyper:
            self.vgg = losses.Vgg19(requires_grad=False).to(self.device)
            #in_channels += 1472
            #siju mod
            in_channels += 1280

        self.net_i = arch.__dict__[self.opt.inet](in_channels,
                                                  3).to(self.device)
        networks.init_weights(
            self.net_i,
            init_type=opt.init_type)  # use the same default initialization as EDSR
        self.edge_map = EdgeMap(scale=1).to(self.device)

        if self.isTrain:
            # define loss functions
            self.loss_dic = losses.init_loss(opt, self.Tensor)
            vggloss = losses.ContentLoss()
            vggloss.initialize(losses.VGGLoss(self.vgg))
            self.loss_dic['t_vgg'] = vggloss

            cxloss = losses.ContentLoss()
            if opt.unaligned_loss == 'vgg':
                cxloss.initialize(
                    losses.VGGLoss(self.vgg,
                                   weights=[0.1],
                                   indices=[opt.vgg_layer]))
            elif opt.unaligned_loss == 'ctx':
                cxloss.initialize(
                    losses.CXLoss(self.vgg,
                                  weights=[0.1, 0.1, 0.1],
                                  indices=[8, 13, 22]))
            elif opt.unaligned_loss == 'mse':
                cxloss.initialize(nn.MSELoss())
            elif opt.unaligned_loss == 'ctx_vgg':
                cxloss.initialize(
                    losses.CXLoss(self.vgg,
                                  weights=[0.1, 0.1, 0.1, 0.1],
                                  indices=[8, 13, 22, 31],
                                  criterions=[losses.CX_loss] * 3 +
                                  [nn.L1Loss()]))
            else:
                raise NotImplementedError

            self.loss_dic['t_cx'] = cxloss

            # Define discriminator
            # if self.opt.lambda_gan > 0:
            self.netD = networks.define_D(opt, 3)
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                                lr=opt.lr,
                                                betas=(0.9, 0.999))
            self._init_optimizer([self.optimizer_D])

            # initialize optimizers
            self.optimizer_G = torch.optim.Adam(self.net_i.parameters(),
                                                lr=opt.lr,
                                                betas=(0.9, 0.999),
                                                weight_decay=opt.wd)

            self._init_optimizer([self.optimizer_G])

        if opt.resume:
            self.load(self, opt.resume_epoch)

        if opt.no_verbose is False:
            self.print_network()
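# ---------------------------------------------------------------------------
# Note on opt.hyper above: the extra input channels come from "hypercolumn"
# features, i.e. intermediate VGG feature maps upsampled to image resolution and
# concatenated with the RGB input before it enters net_i.  A minimal sketch of
# the idea; the layer choice and resulting channel count are assumptions, not
# necessarily the 1280 extra channels used above.
# ---------------------------------------------------------------------------
import torch
import torch.nn.functional as F


def build_hypercolumn_input(image, vgg_features):
    """Concatenate `image` with VGG feature maps resized to its resolution.

    `vgg_features` is a list of tensors of shape (N, C_k, H_k, W_k), e.g. the
    outputs of a few intermediate VGG-19 layers for the same batch.
    """
    h, w = image.shape[-2:]
    upsampled = [
        F.interpolate(f, size=(h, w), mode='bilinear', align_corners=False)
        for f in vgg_features
    ]
    # Resulting channel count: 3 + sum of the individual feature channels.
    return torch.cat([image] + upsampled, dim=1)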
Example 5
def main(opt):
    # Training config
    print(opt)

    t.manual_seed(0)

    # Parameters
    lambda_FM = 10
    lambda_P = 10
    lambda_2 = opt.lambda_second

    nf = 64       # base number of filters for both networks
    n_blocks = 6  # number of residual blocks in the generator

    # Load the networks
    device = "cuda" if t.cuda.is_available() else "cpu"

    print(f"Device: {device}")

    if opt.segment:
        disc = networks.MultiScaleDisc(input_nc=4, ndf=nf).to(device)
        gen = networks.Generator(input_nc=6,
                                 output_nc=1,
                                 ngf=nf,
                                 n_blocks=n_blocks,
                                 transposed=opt.transposed).to(device)
    else:
        disc = networks.MultiScaleDisc(input_nc=1, ndf=nf).to(device)
        gen = networks.Generator(input_nc=3,
                                 output_nc=1,
                                 ngf=nf,
                                 n_blocks=n_blocks,
                                 transposed=opt.transposed).to(device)

    if opt.current_epoch != 0:
        disc.load_state_dict(
            t.load(
                os.path.join(opt.checkpoints_file,
                             f"e_{opt.current_epoch:0>3d}_discriminator.pth")))
        gen.load_state_dict(
            t.load(
                os.path.join(opt.checkpoints_file,
                             f"e_{opt.current_epoch:0>3d}_generator.pth")))
        print(f"- e_{opt.current_epoch:0>3d}_generator.pth was loaded! -")
        print(f"- e_{opt.current_epoch:0>3d}_discriminator.pth was loaded! -")

    else:
        disc.apply(utils.weights_init)
        gen.apply(utils.weights_init)
        print("- Weights are initialized from scratch -")

    # Losses to track
    # # Main losses
    loss_change_g = []
    loss_change_d = []
    # # Components
    loss_change_fm1 = []
    loss_change_fm2 = []
    loss_change_d1 = []
    loss_change_d2 = []
    loss_change_g1 = []
    loss_change_g2 = []
    loss_change_p = []

    # Create optimizers (note: the generator uses a lower learning rate than the discriminator)
    optim_g = optim.Adam(gen.parameters(),
                         lr=opt.learning_rate / 5,
                         betas=(0.5, 0.999))
    optim_d = optim.Adam(disc.parameters(),
                         lr=opt.learning_rate,
                         betas=(0.5, 0.999),
                         weight_decay=0.0001)

    # Create Schedulers
    # g_scheduler = t.optim.lr_scheduler.LambdaLR(optim_g, utils.lr_lambda)
    # d_scheduler = t.optim.lr_scheduler.LambdaLR(optim_d, utils.lr_lambda)

    # Create loss functions
    loss = losses.GanLoss()
    loss_fm = losses.FeatureMatchingLoss()
    loss_p = losses.VGGLoss(device)  # perceptual loss

    # Create dataloader
    ds = dataset.CustomDataset(opt.data_dir,
                               is_segment=opt.segment,
                               sf=opt.scale_factor)
    dataloader = DataLoader(ds,
                            batch_size=opt.batch_size,
                            shuffle=True,
                            num_workers=2)

    # Start training
    print("Training is starting...")
    i = 0
    for e in range(1 + opt.current_epoch,
                   1 + opt.training_epoch + opt.current_epoch):
        print(f"---- Epoch #{e} ----")
        start = time.time()

        for data in tqdm(dataloader):
            i += 1

            rgb = data[0].to(device)
            ir = data[1].to(device)

            if opt.segment:
                segment = data[2].to(device)
                condition = t.cat([rgb, segment], dim=1)
                ir_ = t.cat([ir, segment], dim=1)

            else:
                condition = rgb
                ir_ = ir

            out1, out2 = disc(ir_)
            ir_pred = gen(condition)

            # # # Updating Discriminator # # #
            optim_d.zero_grad()

            if opt.segment:
                ir_pred_ = t.cat([ir_pred, segment], dim=1)
            else:
                ir_pred_ = ir_pred

            # The discriminator returns, per scale, a list of feature maps
            # followed by the final prediction as its last element.
            out1_pred, out2_pred = disc(ir_pred_.detach())

            l_d_pred1, l_d_pred2 = loss(out1_pred[-1],
                                        out2_pred[-1],
                                        is_real=False)
            l_d_real1, l_d_real2 = loss(out1[-1], out2[-1], is_real=True)

            l_d_scale1 = l_d_pred1 + l_d_real1
            l_d_scale2 = l_d_pred2 + l_d_real2

            disc_loss = l_d_scale1 + l_d_scale2 * lambda_2

            # Normalize the loss, and track
            loss_change_d += [disc_loss.item() / opt.batch_size]
            loss_change_d1 += [l_d_scale1.item() / opt.batch_size]
            loss_change_d2 += [l_d_scale2.item() / opt.batch_size]

            disc_loss.backward()
            optim_d.step()

            # # # Updating Generator # # #
            optim_g.zero_grad()

            # Again a list per scale: feature maps followed by the final output.
            out1_pred, out2_pred = disc(ir_pred_)

            fm_scale1 = loss_fm(out1_pred[:-1], out1[:-1])
            fm_scale2 = loss_fm(out2_pred[:-1], out2[:-1])

            fm = fm_scale1 + fm_scale2 * lambda_2

            perceptual = loss_p(ir_pred, ir)

            loss_change_fm1 += [fm_scale1.item() / opt.batch_size]
            loss_change_fm2 += [fm_scale2.item() / opt.batch_size]

            loss_change_p += [perceptual.item() / opt.batch_size]

            l_g_scale1, l_g_scale2 = loss(out1_pred[-1],
                                          out2_pred[-1],
                                          is_real=True)
            gen_loss = l_g_scale1 + l_g_scale2 * lambda_2 + fm * lambda_FM + perceptual * lambda_P

            loss_change_g += [gen_loss.item() / opt.batch_size]
            loss_change_g1 += [l_g_scale1.item() / opt.batch_size]
            loss_change_g2 += [l_g_scale2.item() / opt.batch_size]

            gen_loss.backward()
            optim_g.step()

            # Save images
            if i % opt.img_save_freq == 1:
                utils.save_tensor_images(ir_pred, i, opt.results_file, 'pred')
                utils.save_tensor_images(ir, i, opt.results_file, 'ir')
                utils.save_tensor_images(rgb, i, opt.results_file, 'rgb')
                if opt.segment:  # `segment` is only defined when opt.segment is set
                    utils.save_tensor_images(segment, i, opt.results_file,
                                             'segment')
                print('\nExample images saved')

                print("Losses:")
                print(
                    f"G: {loss_change_g[-1]:.4f}; D: {loss_change_d[-1]:.4f}")
                print(
                    f"G1: {loss_change_g1[-1]:.4f}; G2: {loss_change_g2[-1]:.4f}"
                )
                print(
                    f"D1: {loss_change_d1[-1]:.4f}; D2: {loss_change_d2[-1]:.4f}"
                )
                print(
                    f"FM1: {loss_change_fm1[-1]:.4f}; FM2: {loss_change_fm2[-1]:.4f}; P: {loss_change_p[-1]:.4f}"
                )

        # g_scheduler.step()
        # d_scheduler.step()

        duration = time.time() - start
        print(f"Epoch duration: {int(duration // 60)}m {duration % 60:.1f}s")

        if e % opt.model_save_freq == 0:  # save checkpoints every model_save_freq epochs
            utils.save_model(disc, gen, e, opt.checkpoints_file)
    # End of training

    # Main losses are g and d, but I want to save all components separately
    utils.save_loss(d=loss_change_d,
                    d1=loss_change_d1,
                    d2=loss_change_d2,
                    g=loss_change_g,
                    g1=loss_change_g1,
                    g2=loss_change_g2,
                    fm1=loss_change_fm1,
                    fm2=loss_change_fm2,
                    p=loss_change_p,
                    path=opt.loss_file,
                    e=e)

    utils.save_model(disc, gen, e, opt.checkpoints_file)

    utils.show_loss(opt.checkpoints_file)
    print("Done!")