def __init__(self, device='cpu', last=nn.Sigmoid):
    super(SGAN, self).__init__()
    self.device = device
    self.net_g = G()
    self.net_d = D(last=last)
    self.criterion = GANLoss(relativistic=False)
    self.optim_G = Adam(self.net_g.parameters())
    self.optim_D = Adam(self.net_d.parameters())
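For context, a minimal sketch of the GANLoss interface these examples rely on: a callable taking (prediction, target_is_real) and returning a scalar loss. The class name and internals below are illustrative assumptions for the non-relativistic case, not the library's actual implementation.

import torch
import torch.nn as nn

class MinimalGANLoss(nn.Module):
    """Illustrative stand-in for GANLoss(relativistic=False)."""

    def __init__(self, real_label_val=1.0, fake_label_val=0.0):
        super().__init__()
        self.real_label_val = real_label_val
        self.fake_label_val = fake_label_val
        self.loss = nn.BCEWithLogitsLoss()  # standard GAN objective on raw logits

    def forward(self, prediction, target_is_real):
        # fill a target tensor of the prediction's shape with the real/fake label
        val = self.real_label_val if target_is_real else self.fake_label_val
        target = torch.full_like(prediction, val)
        return self.loss(prediction, target)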
Example #2
# content loss
if opt.content_loss_type == 'L1_Charbonnier':
    content_loss = L1_Charbonnier_loss()
elif opt.content_loss_type == 'L1':
    content_loss = torch.nn.L1Loss()
elif opt.content_loss_type == 'L2':
    content_loss = torch.nn.MSELoss()

# pixel loss
if opt.pixel_loss_type == 'L1':
    pixel_loss = torch.nn.L1Loss()
elif opt.pixel_loss_type == 'L2':
    pixel_loss = torch.nn.MSELoss()

# gan loss
GAN_loss = GANLoss(opt.gan_type, real_label_val=1.0, fake_label_val=0.0)
edge_loss = edgeV_loss()
tv_loss = TV_loss()
# GPU
if opt.cuda and not torch.cuda.is_available():  # check whether a GPU is available
    raise Exception('No GPU found, please run without --cuda')
print("===> Setting GPU")
if opt.cuda:
    print('cuda_mode:', opt.cuda)
    generator = generator.cuda()
    discriminator = discriminator.cuda()
    feature_extractor = feature_extractor.cuda()
    content_loss = content_loss.cuda()
    pixel_loss = pixel_loss.cuda()
    GAN_loss = GAN_loss.cuda()
    edge_loss = edge_loss.cuda()
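The snippet is cut off before the losses are applied. Purely as an illustration, a weighted generator objective in this style might combine them as below; the opt.*_weight names and the sr/hr tensors are assumptions for the sketch, not from the original.

# Illustrative only: sr = generator output, hr = ground-truth image;
# all opt.*_weight attributes are assumed names.
loss_G = (opt.content_weight * content_loss(sr, hr)
          + opt.pixel_weight * pixel_loss(sr, hr)
          + opt.gan_weight * GAN_loss(discriminator(sr), True)
          + opt.edge_weight * edge_loss(sr, hr)
          + opt.tv_weight * tv_loss(sr))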
Example #3
        vec.append(word2vec[term])

del word2vec

# BERT Model
model = modeling.BertNoEmbed(vocab=vocab, hidden_size=1024, enc_num_layer=3)
model.load_state_dict(torch.load('checkpoint/bert-LanGen-last.pt')['state'])
model.cuda()
d_net = modeling.TextCNNClassify(vocab, vec, num_labels=2)
d_net.cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
optimizer_d = torch.optim.SGD(d_net.parameters(), lr=0.01)

label_smoothing = modeling.LabelSmoothing(len(vocab), 0, 0.1)
label_smoothing.cuda()
gan_loss = GANLoss()
gan_loss.cuda()
G_STEP = 1
D_STEP = 3
D_PRE = 5
SAVE_EVERY = 50
PENALTY_EPOCH = -1
DRAW_LEARNING_CURVE = False
data = []

# Tokenized input
print('Tokenization...')
with open('pair.csv') as PAIR:
    for line in tqdm(PAIR):
        [text, summary, _] = line.split(',')
        texts = []
Example #4
cprint('==> Preparing Data Set: Complete\n', 'green')

################################################################################
cprint('==> Building Models', 'yellow')
netG = define_G(opt.input_nc, opt.output_nc, opt.ngf, norm='batch', use_dropout=False, gpu_ids=gpu_ids)
netD = define_D(opt.input_nc + opt.output_nc, opt.ndf, norm='batch', use_sigmoid=False, gpu_ids=gpu_ids)

print('---------- Networks initialized -------------')
print_network(netG)
print_network(netD)
print('-----------------------------------------------\n')
cprint('==> Building Models: Complete\n', 'green')

################################################################################
criterionGAN = GANLoss()
criterionL1 = nn.L1Loss()
criterionMSE = nn.MSELoss()

# setup optimizer
optimizerG = optim.Adam(netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
optimizerD = optim.Adam(netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))

real_a = torch.FloatTensor(opt.batchSize, opt.input_nc, 256, 256)
real_b = torch.FloatTensor(opt.batchSize, opt.output_nc, 256, 256)

if opt.cuda:
    netD = netD.cuda()
    netG = netG.cuda()
    criterionGAN = criterionGAN.cuda()
    criterionL1 = criterionL1.cuda()
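The example stops after moving the criteria to the GPU. For context, one pix2pix-style iteration with the objects defined above would typically look like the following sketch; only netG, netD, the criteria, the optimizers, and real_a/real_b come from the snippet, the loop itself and the L1 weight are assumptions.

# Illustrative training step (assumed, not from the original source)
fake_b = netG(real_a)

# D update: the conditional discriminator sees (input, output) pairs
optimizerD.zero_grad()
loss_d_fake = criterionGAN(netD(torch.cat((real_a, fake_b.detach()), 1)), False)
loss_d_real = criterionGAN(netD(torch.cat((real_a, real_b), 1)), True)
loss_d = (loss_d_fake + loss_d_real) * 0.5
loss_d.backward()
optimizerD.step()

# G update: fool D and stay close to the target in L1
optimizerG.zero_grad()
loss_g = criterionGAN(netD(torch.cat((real_a, fake_b), 1)), True) \
         + criterionL1(fake_b, real_b) * 10  # L1 weight is illustrative
loss_g.backward()
optimizerG.step()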
Example #5
def train(opt):
    #### device
    device = torch.device('cuda:{}'.format(opt.gpu_id)
                          if opt.gpu_id >= 0 else 'cpu')

    #### dataset
    data_loader = UnAlignedDataLoader()
    data_loader.initialize(opt)
    data_set = data_loader.load_data()
    print("The number of training images = %d." % len(data_set))

    #### initialize models
    ## declaration
    E_a2Zb = Encoder(input_nc=opt.input_nc,
                     ngf=opt.ngf,
                     norm_type=opt.norm_type,
                     use_dropout=not opt.no_dropout,
                     n_blocks=9)
    G_Zb2b = Decoder(output_nc=opt.output_nc,
                     ngf=opt.ngf,
                     norm_type=opt.norm_type)
    T_Zb2Za = LatentTranslator(n_channels=256,
                               norm_type=opt.norm_type,
                               use_dropout=not opt.no_dropout)
    D_b = Discriminator(input_nc=opt.input_nc,
                        ndf=opt.ndf,
                        n_layers=opt.n_layers,
                        norm_type=opt.norm_type)

    E_b2Za = Encoder(input_nc=opt.input_nc,
                     ngf=opt.ngf,
                     norm_type=opt.norm_type,
                     use_dropout=not opt.no_dropout,
                     n_blocks=9)
    G_Za2a = Decoder(output_nc=opt.output_nc,
                     ngf=opt.ngf,
                     norm_type=opt.norm_type)
    T_Za2Zb = LatentTranslator(n_channels=256,
                               norm_type=opt.norm_type,
                               use_dropout=not opt.no_dropout)
    D_a = Discriminator(input_nc=opt.input_nc,
                        ndf=opt.ndf,
                        n_layers=opt.n_layers,
                        norm_type=opt.norm_type)

    ## initialization
    E_a2Zb = init_net(E_a2Zb, init_type=opt.init_type).to(device)
    G_Zb2b = init_net(G_Zb2b, init_type=opt.init_type).to(device)
    T_Zb2Za = init_net(T_Zb2Za, init_type=opt.init_type).to(device)
    D_b = init_net(D_b, init_type=opt.init_type).to(device)

    E_b2Za = init_net(E_b2Za, init_type=opt.init_type).to(device)
    G_Za2a = init_net(G_Za2a, init_type=opt.init_type).to(device)
    T_Za2Zb = init_net(T_Za2Zb, init_type=opt.init_type).to(device)
    D_a = init_net(D_a, init_type=opt.init_type).to(device)
    print(
        "+------------------------------------------------------+\nFinish initializing networks."
    )

    #### optimizer and criterion
    ## criterion
    criterionGAN = GANLoss(opt.gan_mode).to(device)
    criterionZId = nn.L1Loss()
    criterionIdt = nn.L1Loss()
    criterionCTC = nn.L1Loss()
    criterionZCyc = nn.L1Loss()

    ## optimizer
    optimizer_G = torch.optim.Adam(itertools.chain(E_a2Zb.parameters(),
                                                   G_Zb2b.parameters(),
                                                   T_Zb2Za.parameters(),
                                                   E_b2Za.parameters(),
                                                   G_Za2a.parameters(),
                                                   T_Za2Zb.parameters()),
                                   lr=opt.lr,
                                   betas=(opt.beta1, opt.beta2))
    optimizer_D = torch.optim.Adam(itertools.chain(D_a.parameters(),
                                                   D_b.parameters()),
                                   lr=opt.lr,
                                   betas=(opt.beta1, opt.beta2))

    ## scheduler
    scheduler = [
        get_scheduler(optimizer_G, opt),
        get_scheduler(optimizer_D, opt)
    ]

    print(
        "+------------------------------------------------------+\nFinish initializing the optimizers and criterions."
    )

    #### global variables
    checkpoints_pth = os.path.join(opt.checkpoints, opt.name)
    if not os.path.exists(checkpoints_pth):
        os.mkdir(checkpoints_pth)
        os.mkdir(os.path.join(checkpoints_pth, 'images'))
    record_fh = open(os.path.join(checkpoints_pth, 'records.txt'),
                     'w',
                     encoding='utf-8')
    loss_names = [
        'GAN_A', 'Adv_A', 'Idt_A', 'CTC_A', 'ZId_A', 'ZCyc_A', 'GAN_B',
        'Adv_B', 'Idt_B', 'CTC_B', 'ZId_B', 'ZCyc_B'
    ]

    fake_A_pool = ImagePool(
        opt.pool_size
    )  # create image buffer to store previously generated images
    fake_B_pool = ImagePool(
        opt.pool_size
    )  # create image buffer to store previously generated images

    print(
        "+------------------------------------------------------+\nFinish preparing the other works."
    )
    print(
        "+------------------------------------------------------+\nNow training is beginning .."
    )
    #### training
    cur_iter = 0
    for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1):
        epoch_start_time = time.time()  # timer for entire epoch

        for i, data in enumerate(data_set):
            ## setup inputs
            real_A = data['A'].to(device)
            real_B = data['B'].to(device)

            ## forward
            # image cycle / GAN
            latent_B = E_a2Zb(real_A)  #-> a -> Zb     : E_a2b(a)
            fake_B = G_Zb2b(latent_B)  #-> Zb -> b'    : G_b(E_a2b(a))
            latent_A = E_b2Za(real_B)  #-> b -> Za     : E_b2a(b)
            fake_A = G_Za2a(latent_A)  #-> Za -> a'    : G_a(E_b2a(b))

            # Idt
            '''
            rec_A = G_Za2a(E_b2Za(fake_B))          #-> b' -> Za' -> rec_a  : G_a(E_b2a(fake_b))
            rec_B = G_Zb2b(E_a2Zb(fake_A))          #-> a' -> Zb' -> rec_b  : G_b(E_a2b(fake_a))
            '''
            idt_latent_A = E_b2Za(real_A)  #-> a -> Za        : E_b2a(a)
            idt_A = G_Za2a(idt_latent_A)  #-> Za -> idt_a    : G_a(E_b2a(a))
            idt_latent_B = E_a2Zb(real_B)  #-> b -> Zb        : E_a2b(b)
            idt_B = G_Zb2b(idt_latent_B)  #-> Zb -> idt_b    : G_b(E_a2b(b))

            # ZIdt
            T_latent_A = T_Zb2Za(latent_B)  #-> Zb -> Za''  : T_b2a(E_a2b(a))
            T_rec_A = G_Za2a(
                T_latent_A)  #-> Za'' -> a'' : G_a(T_b2a(E_a2b(a)))
            T_latent_B = T_Za2Zb(latent_A)  #-> Za -> Zb''  : T_a2b(E_b2a(b))
            T_rec_B = G_Zb2b(
                T_latent_B)  #-> Zb'' -> b'' : G_b(T_a2b(E_b2a(b)))

            # CTC
            T_idt_latent_B = T_Za2Zb(idt_latent_A)  #-> a -> T_a2b(E_b2a(a))
            T_idt_latent_A = T_Zb2Za(idt_latent_B)  #-> b -> T_b2a(E_a2b(b))

            # ZCyc
            TT_latent_B = T_Za2Zb(T_latent_A)  #-> T_a2b(T_b2a(E_a2b(a)))
            TT_latent_A = T_Zb2Za(T_latent_B)  #-> T_b2a(T_a2b(E_b2a(b)))

            ### optimize parameters
            ## Generator updating
            set_requires_grad(
                [D_b, D_a],
                False)  #-> set Discriminator to require no gradient
            optimizer_G.zero_grad()
            # GAN loss
            loss_G_A = criterionGAN(D_b(fake_B), True)
            loss_G_B = criterionGAN(D_a(fake_A), True)
            loss_GAN = loss_G_A + loss_G_B
            # Idt loss
            loss_idt_A = criterionIdt(idt_A, real_A)
            loss_idt_B = criterionIdt(idt_B, real_B)
            loss_Idt = loss_idt_A + loss_idt_B
            # Latent cross-identity loss
            loss_Zid_A = criterionZId(T_rec_A, real_A)
            loss_Zid_B = criterionZId(T_rec_B, real_B)
            loss_Zid = loss_Zid_A + loss_Zid_B
            # Latent cross-translation consistency
            loss_CTC_A = criterionCTC(T_idt_latent_A, latent_A)
            loss_CTC_B = criterionCTC(T_idt_latent_B, latent_B)
            loss_CTC = loss_CTC_B + loss_CTC_A
            # Latent cycle consistency
            loss_ZCyc_A = criterionZCyc(TT_latent_A, latent_A)
            loss_ZCyc_B = criterionZCyc(TT_latent_B, latent_B)
            loss_ZCyc = loss_ZCyc_B + loss_ZCyc_A

            loss_G = (opt.lambda_gan * loss_GAN + opt.lambda_idt * loss_Idt +
                      opt.lambda_zid * loss_Zid + opt.lambda_ctc * loss_CTC +
                      opt.lambda_zcyc * loss_ZCyc)

            # backward and gradient updating
            loss_G.backward()
            optimizer_G.step()

            ## Discriminator updating
            set_requires_grad([D_b, D_a],
                              True)  # -> set Discriminator to require gradient
            optimizer_D.zero_grad()

            # backward D_b
            fake_B_ = fake_B_pool.query(fake_B)
            #-> real_B, fake_B
            pred_real_B = D_b(real_B)
            loss_D_real_B = criterionGAN(pred_real_B, True)

            pred_fake_B = D_b(fake_B_)
            loss_D_fake_B = criterionGAN(pred_fake_B, False)

            loss_D_B = (loss_D_real_B + loss_D_fake_B) * 0.5
            loss_D_B.backward()

            # backward D_a
            fake_A_ = fake_A_pool.query(fake_A)
            #-> real_A, fake_A
            pred_real_A = D_a(real_A)
            loss_D_real_A = criterionGAN(pred_real_A, True)

            pred_fake_A = D_a(fake_A_)
            loss_D_fake_A = criterionGAN(pred_fake_A, False)

            loss_D_A = (loss_D_real_A + loss_D_fake_A) * 0.5
            loss_D_A.backward()

            # update the gradients
            optimizer_D.step()

            ### validate here, both qualitatively and quantitatively
            ## record the losses
            if cur_iter % opt.log_freq == 0:
                # loss_names = ['GAN_A', 'Adv_A', 'Idt_A', 'CTC_A', 'ZId_A', 'ZCyc_A', 'GAN_B', 'Adv_B', 'Idt_B', 'CTC_B', 'ZId_B', 'ZCyc_B']
                losses = [
                    loss_G_A.item(),
                    loss_D_A.item(),
                    loss_idt_A.item(),
                    loss_CTC_A.item(),
                    loss_Zid_A.item(),
                    loss_ZCyc_A.item(),
                    loss_G_B.item(),
                    loss_D_B.item(),
                    loss_idt_B.item(),
                    loss_CTC_B.item(),
                    loss_Zid_B.item(),
                    loss_ZCyc_B.item()
                ]
                # record
                line = ''
                for loss in losses:
                    line += '{} '.format(loss)
                record_fh.write(line[:-1] + '\n')
                # print out
                print('Epoch: %3d/%3d Iter: %9d--------------------------+' %
                      (epoch, opt.niter + opt.niter_decay, i))
                field_names = loss_names[:len(loss_names) // 2]
                table = PrettyTable(field_names=field_names)
                for l_n in field_names:
                    table.align[l_n] = 'c'
                table.add_row(losses[:len(field_names)])
                print(table.get_string(reversesort=True))

                field_names = loss_names[len(loss_names) // 2:]
                table = PrettyTable(field_names=field_names)
                for l_n in field_names:
                    table.align[l_n] = 'c'
                table.add_row(losses[-len(field_names):])
                print(table.get_string(reversesort=True))

            ## visualize
            if cur_iter % opt.vis_freq == 0:
                if opt.gpu_id >= 0:
                    real_A = real_A.cpu().data
                    real_B = real_B.cpu().data
                    fake_A = fake_A.cpu().data
                    fake_B = fake_B.cpu().data
                    idt_A = idt_A.cpu().data
                    idt_B = idt_B.cpu().data
                    T_rec_A = T_rec_A.cpu().data
                    T_rec_B = T_rec_B.cpu().data

                plt.subplot(241), plt.title('real_A'), plt.imshow(
                    tensor2image_RGB(real_A[0, ...]))
                plt.subplot(242), plt.title('fake_B'), plt.imshow(
                    tensor2image_RGB(fake_B[0, ...]))
                plt.subplot(243), plt.title('idt_A'), plt.imshow(
                    tensor2image_RGB(idt_A[0, ...]))
                plt.subplot(244), plt.title('L_idt_A'), plt.imshow(
                    tensor2image_RGB(T_rec_A[0, ...]))

                plt.subplot(245), plt.title('real_B'), plt.imshow(
                    tensor2image_RGB(real_B[0, ...]))
                plt.subplot(246), plt.title('fake_A'), plt.imshow(
                    tensor2image_RGB(fake_A[0, ...]))
                plt.subplot(247), plt.title('idt_B'), plt.imshow(
                    tensor2image_RGB(idt_B[0, ...]))
                plt.subplot(248), plt.title('L_idt_B'), plt.imshow(
                    tensor2image_RGB(T_rec_B[0, ...]))

                plt.savefig(
                    os.path.join(checkpoints_pth, 'images',
                                 '%03d_%09d.jpg' % (epoch, i)))

            cur_iter += 1
            #break #-> debug

        ## end of this epoch: update the learning rate
        update_learning_rate(schedulers=scheduler,
                             opt=opt,
                             optimizer=optimizer_D)
        ## save the model
        if epoch % opt.ckp_freq == 0:
            #-> save models
            # torch.save(model.state_dict(), PATH)
            #-> load in models
            # model.load_state_dict(torch.load(PATH))
            # model.eval()
            if opt.gpu_id >= 0:
                E_a2Zb = E_a2Zb.cpu()
                G_Zb2b = G_Zb2b.cpu()
                T_Zb2Za = T_Zb2Za.cpu()
                D_b = D_b.cpu()

                E_b2Za = E_b2Za.cpu()
                G_Za2a = G_Za2a.cpu()
                T_Za2Zb = T_Za2Zb.cpu()
                D_a = D_a.cpu()
                '''
                torch.save( E_a2Zb.cpu().state_dict(), os.path.join(checkpoints_pth, 'epoch_%3d-E_a2b.pth' % epoch))
                torch.save( G_Zb2b.cpu().state_dict(), os.path.join(checkpoints_pth,   'epoch_%3d-G_b.pth' % epoch))
                torch.save(T_Zb2Za.cpu().state_dict(), os.path.join(checkpoints_pth, 'epoch_%3d-T_b2a.pth' % epoch))
                torch.save(    D_b.cpu().state_dict(), os.path.join(checkpoints_pth,   'epoch_%3d-D_b.pth' % epoch))

                torch.save( E_b2Za.cpu().state_dict(), os.path.join(checkpoints_pth, 'epoch_%3d-E_b2a.pth' % epoch))
                torch.save( G_Za2a.cpu().state_dict(), os.path.join(checkpoints_pth,   'epoch_%3d-G_a.pth' % epoch))
                torch.save(T_Za2Zb.cpu().state_dict(), os.path.join(checkpoints_pth, 'epoch_%3d-T_a2b.pth' % epoch))
                torch.save(    D_a.cpu().state_dict(), os.path.join(checkpoints_pth,   'epoch_%3d-D_a.pth' % epoch))
                '''
            torch.save(
                E_a2Zb.state_dict(),
                os.path.join(checkpoints_pth, 'epoch_%3d-E_a2b.pth' % epoch))
            torch.save(
                G_Zb2b.state_dict(),
                os.path.join(checkpoints_pth, 'epoch_%3d-G_b.pth' % epoch))
            torch.save(
                T_Zb2Za.state_dict(),
                os.path.join(checkpoints_pth, 'epoch_%3d-T_b2a.pth' % epoch))
            torch.save(
                D_b.state_dict(),
                os.path.join(checkpoints_pth, 'epoch_%3d-D_b.pth' % epoch))

            torch.save(
                E_b2Za.state_dict(),
                os.path.join(checkpoints_pth, 'epoch_%3d-E_b2a.pth' % epoch))
            torch.save(
                G_Za2a.state_dict(),
                os.path.join(checkpoints_pth, 'epoch_%3d-G_a.pth' % epoch))
            torch.save(
                T_Za2Zb.state_dict(),
                os.path.join(checkpoints_pth, 'epoch_%3d-T_a2b.pth' % epoch))
            torch.save(
                D_a.state_dict(),
                os.path.join(checkpoints_pth, 'epoch_%3d-D_a.pth' % epoch))
            if opt.gpu_id >= 0:
                E_a2Zb = E_a2Zb.to(device)
                G_Zb2b = G_Zb2b.to(device)
                T_Zb2Za = T_Zb2Za.to(device)
                D_b = D_b.to(device)

                E_b2Za = E_b2Za.to(device)
                G_Za2a = G_Za2a.to(device)
                T_Za2Zb = T_Za2Zb.to(device)
                D_a = D_a.to(device)
            print("+Successfully saving models in epoch: %3d.-------------+" %
                  epoch)
        #break #-> debug
    record_fh.close()
    print("≧◔◡◔≦ Congratulation! Finishing the training!")
Example #6
    def __init__(self, opt):
        """Pix2PIxHD model

        Parameters
        ----------
        opt : argparse.Namespace
            options for this model, e.g. gain, isAffine

        """
        super(Pix2PixHDModel, self).__init__()
        self.opt = opt

        if opt.gpu_ids == 0:
            self.device = torch.device("cuda:0")
        elif opt.gpu_ids == 1:
            self.device = torch.device("cuda:1")
        else:
            self.device = torch.device("cpu")

        # define networks respectively
        input_nc = opt.label_num
        if not opt.no_use_feature:
            input_nc += opt.feature_nc
        if not opt.no_use_edge:
            input_nc += 1
        self.netG = define_G(
            input_nc=input_nc,
            output_nc=opt.output_nc,
            ngf=opt.ngf,
            g_type=opt.g_type,
            device=self.device,
            isAffine=opt.isAffine,
            use_relu=opt.use_relu,
        )

        input_nc = opt.output_nc
        if not opt.no_use_edge:
            input_nc += opt.label_num + 1
        else:
            input_nc += opt.label_num
        self.netD = define_D(
            input_nc=input_nc,
            ndf=opt.ndf,
            n_layers_D=opt.n_layers_D,
            device=self.device,
            isAffine=opt.isAffine,
            num_D=opt.num_D,
        )

        self.netE = define_E(
            input_nc=opt.output_nc,
            feat_num=opt.feature_nc,
            nef=opt.nef,
            device=self.device,
            isAffine=opt.isAffine,
        )

        # define optimizer respectively
        # initialize optimizer G&E
        # if opt.niter_fix_global > 0, fix the parameters of the global generator
        if opt.niter_fix_global > 0:
            finetune_list = set()

            params = []
            for key, value in self.netG.named_parameters():
                if key.startswith("model" + str(opt.n_local_enhancers)):
                    params += [value]
                    finetune_list.add(key.split(".")[0])
            print(
                "------------- Only training the local enhancer network (for %d epochs) ------------"
                % opt.niter_fix_global)
            print("The layers that are finetuned are ", sorted(finetune_list))
        else:
            params = list(self.netG.parameters())
        if not self.opt.no_use_feature:
            params += list(self.netE.parameters())
        self.optimizer_G = torch.optim.Adam(params,
                                            lr=opt.lr,
                                            betas=(opt.beta1, 0.999))
        self.scheduler_G = LinearDecayLR(self.optimizer_G,
                                         niter_decay=opt.niter_decay)

        # initialize optimizer D
        # optimizer D
        self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                            lr=opt.lr,
                                            betas=(opt.beta1, 0.999))
        self.scheduler_D = LinearDecayLR(self.optimizer_D,
                                         niter_decay=opt.niter_decay)

        # define loss functions
        if opt.gpu_ids == 0 or opt.gpu_ids == 1:
            self.Tensor = torch.cuda.FloatTensor
        else:
            self.Tensor = torch.FloatTensor

        self.criterionGAN = GANLoss(self.device,
                                    use_lsgan=not opt.no_lsgan,
                                    tensor=self.Tensor)
        if not self.opt.no_fmLoss:
            self.criterionFM = FMLoss(num_D=opt.num_D,
                                      n_layers=opt.n_layers_D,
                                      lambda_feat=opt.lambda_feat)
        if not self.opt.no_pLoss:
            self.criterionP = PerceptualLoss(
                self.device, lambda_perceptual=opt.lambda_perceptual)
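This GANLoss variant takes a use_lsgan flag. As a hedged illustration of what the LSGAN branch computes, the least-squares objective replaces cross-entropy with MSE against the real/fake labels; the function below is an assumed sketch, not the library's code.

import torch
import torch.nn.functional as F

def lsgan_loss(prediction, target_is_real, real_label=1.0, fake_label=0.0):
    # LSGAN: least-squares distance to the real/fake label (assumed internals)
    target = torch.full_like(prediction, real_label if target_is_real else fake_label)
    return F.mse_loss(prediction, target)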
Example #7
def train(learning_rate=0.0002, beta1=0.5, epochs=1):

    # parse data from args passed
    data_dir = args.data
    batch_size = args.batch_size
    num_workers = args.num_workers

    # check that the data dir exists
    assert os.path.isdir(data_dir), "{} is not a valid directory".format(
        data_dir)
    '''
    # create dataset (transforms are also included in this only)
    print('Loading dataset...')
    dataset = DehazeDataset(data_dir)
    print('Dataset loaded successfully...')
    print('Dataset contains {} distinct datapoints in X(source) & Y(target) domain\n\n'.format(len(dataset)))

    # create custom DataLoader
    dataloader = DataLoader(dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=num_workers)
    '''

    # create G, F
    print('Loading Generators(G & F)...')
    G = Generator()
    F = Generator()
    print('Generators(G & F) loaded successfully...')

    # create Dx, Dy
    print('Loading Discriminators(Dx, Dy)...')
    Dx = Discriminator()
    Dy = Discriminator()
    print('Discriminators(Dx, Dy) loaded successfully...')

    # check generator summary

    #summary(G,(3,256,256))
    # OR
    print(G)  # print Generator

    # check discriminator summary

    #summary(Dx,(3,256,256))
    # OR
    print(Dx)  # print Discriminator

    # create the three loss functions: adversarial, cycle-consistency, and identity
    criterionGAN = GANLoss().to(device)  # TODO: change device
    criterionCycle = nn.L1Loss()
    criterionIdt = nn.L1Loss()

    # create optimizers
    optimizers = []
    optimizer_G = optim.Adam(itertools.chain(G.parameters(), F.parameters()),
                             lr=learning_rate,
                             betas=(beta1, 0.999))
    optimizer_D = optim.Adam(itertools.chain(Dx.parameters(), Dy.parameters()),
                             lr=learning_rate,
                             betas=(beta1, 0.999))
    optimizers.append(optimizer_G)
    optimizers.append(optimizer_D)

    # make dataset ready for training
    data_loader = CustomDatasetLoader()
    dataset = data_loader.load_data()
    print('Number of training images = %d' % len(dataset))

    # iterate over dataset for training
    for epoch in range(
            epochs
    ):  # outer loop for different epochs; we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>
        #epoch_start_time = time.time()  # timer for entire epoch
        #iter_data_time = time.time()    # timer for data loading per iteration
        epoch_iter = 0  # the number of training iterations in current epoch, reset to 0 every epoch

        for i, batch in enumerate(dataset):  # inner loop within one epoch
            pass
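The loop body is left as pass in this example. A hedged sketch of the CycleGAN-style step it would perform follows; the batch keys 'A'/'B' and the 10.0/5.0 weights (CycleGAN's usual defaults) are assumptions, while G, F, Dx, Dy, the criteria, and the optimizers come from the snippet above.

            # Illustrative CycleGAN step (assumed, not part of the original example)
            real_x, real_y = batch['A'].to(device), batch['B'].to(device)

            fake_y = G(real_x)   # X -> Y
            rec_x = F(fake_y)    # Y -> X (cycle)
            fake_x = F(real_y)   # Y -> X
            rec_y = G(fake_x)    # X -> Y (cycle)

            optimizer_G.zero_grad()
            loss_g = (criterionGAN(Dy(fake_y), True)
                      + criterionGAN(Dx(fake_x), True)
                      + 10.0 * (criterionCycle(rec_x, real_x) + criterionCycle(rec_y, real_y))
                      + 5.0 * (criterionIdt(G(real_y), real_y) + criterionIdt(F(real_x), real_x)))
            loss_g.backward()
            optimizer_G.step()

            optimizer_D.zero_grad()
            loss_d = (criterionGAN(Dy(real_y), True) + criterionGAN(Dy(fake_y.detach()), False)
                      + criterionGAN(Dx(real_x), True) + criterionGAN(Dx(fake_x.detach()), False)) * 0.5
            loss_d.backward()
            optimizer_D.step()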
Example #8
def run():
    # Dataset
    transform = transforms.Compose([
        transforms.Resize(64),
        transforms.ToTensor(),
        transforms.Normalize((0.5, ), (0.5, ))
    ])

    dataset = datasets.MNIST('.', transform=transform, download=True)
    dataloader = data.DataLoader(dataset, batch_size=4)
    print("[INFO] Define DataLoader")

    # Define Model
    g = Generator()
    d = Discriminator()
    print("[INFO] Define Model")

    # optimizer, loss
    gan_loss = GANLoss()

    optim_G = optim.Adam(g.parameters(), lr=0.0002, betas=(0.5, 0.999))
    optim_D = optim.Adam(d.parameters(), lr=0.0002, betas=(0.5, 0.999))
    print('[INFO] Define optimizer and loss')

    # train
    num_epoch = 2

    print('[INFO] Start Training!!')
    for epoch in range(num_epoch):
        total_batch = len(dataloader)

        for idx, (image, _) in enumerate(dataloader):
            d.train()
            g.train()

            # generate a fake image
            noise = torch.randn(4, 100, 1, 1)
            output_fake = g(noise)

            # Loss

            d_loss_fake = gan_loss(d(output_fake.detach()), False)
            d_loss_real = gan_loss(d(image), True)
            d_loss = (d_loss_fake + d_loss_real) / 2

            g_loss = gan_loss(d(output_fake), True)

            # update
            optim_G.zero_grad()
            g_loss.backward()
            optim_G.step()

            optim_D.zero_grad()
            d_loss.backward()
            optim_D.step()

            if ((epoch * total_batch) + idx) % 1000 == 0:
                print(
                    'Epoch [%d/%d], Iter [%d/%d], D_loss: %.4f, G_loss: %.4f' %
                    (epoch, num_epoch, idx + 1, total_batch, d_loss.item(),
                     g_loss.item()))

                save_model('model', 'GAN', g, {'loss': g_loss.item()})
Example #9
import numpy as np
import torch
import torch.nn as nn
from torch.nn.parallel import DataParallel
import os
import time
from glob import glob
from collections import OrderedDict
from os import makedirs, environ
from os.path import join, exists, split, isfile
from scipy.misc import imread, imresize, imsave, imrotate  # removed in SciPy >= 1.2; requires an older SciPy

from loss import GANLoss, gram_matrix

cri_gan = GANLoss('gan', 1.0, 0.0)
from model import VGGMOD, SR, Discriminator, compute_gradient_penalty

torch.set_default_dtype(torch.float32)
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# some global variables
MODEL_FOLDER = 'model'
SAMPLE_FOLDER = 'sample'

input_dir = 'sr_data/CUFED_128/input'  # original images
ref_dir = 'sr_data/CUFED_128/ref'  # reference images
map_dir = 'sr_data/CUFED_128/map_321'  # texture maps after texture swapping

use_gpu = True
use_train_ref = True
pre_load_img = True
Example #10
def __init__(self, device='cpu'):
    super(RaSGAN, self).__init__(device=device)
    self.criterion = GANLoss(relativistic=True, average=True)
Example #11
def __init__(self, device='cpu'):
    super(RSGAN, self).__init__(device=device, last=None)
    self.criterion = GANLoss(relativistic=True, average=False)
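Examples #10 and #11 toggle the relativistic and average flags. As a hedged sketch of what they mean, a relativistic discriminator scores real samples against fake ones directly (RSGAN, average=False) or against the other side's mean score (RaSGAN, average=True); the function below is an assumed illustration of the discriminator-side loss, not the library's code.

import torch
import torch.nn.functional as F

def relativistic_d_loss(real_logits, fake_logits, average=False):
    # Assumed illustration of GANLoss(relativistic=True, average=...)
    if average:
        # RaSGAN: each logit is compared with the mean logit of the other side
        real_rel = real_logits - fake_logits.mean()
        fake_rel = fake_logits - real_logits.mean()
        return (F.binary_cross_entropy_with_logits(real_rel, torch.ones_like(real_rel))
                + F.binary_cross_entropy_with_logits(fake_rel, torch.zeros_like(fake_rel)))
    # RSGAN: D estimates the probability that real is more realistic than fake
    diff = real_logits - fake_logits
    return F.binary_cross_entropy_with_logits(diff, torch.ones_like(diff))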
Example #12
data_dir = "/content/drive/My Drive/Grasping GAN/processed"
model_dir = "/content/drive/My Drive/Grasping GAN/models"
batch_size = 8
epochs = 1
lr = 0.01

dataset = GraspingDataset(data_dir)
data_loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

net_g = define_G(3, 3, 64, "batch", False, "normal", 0.02, gpu_id=device)
net_d = define_D(3 + 3, 64, "basic", gpu_id=device)

criterionGAN = GANLoss().to(device)
criterionL1 = nn.L1Loss().to(device)
criterionMSE = nn.MSELoss().to(device)

optimizer_g = optim.Adam(net_g.parameters(), lr=lr)
optimizer_d = optim.Adam(net_d.parameters(), lr=lr)

l1_weight = 10

for epoch in range(epochs):
    # train
    for iteration, batch in enumerate(data_loader, 1):
        # forward
        real_a, real_b = batch[0].to(device), batch[1].to(device)
        fake_b = net_g(real_a)
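The example is cut off after the forward pass. A hedged continuation in the usual pix2pix pattern is sketched below; only net_g, net_d, the criteria, the optimizers, and l1_weight come from the snippet, the rest of the step is assumed.

        # Assumed continuation of the training step (not in the original)
        # D update: the conditional discriminator sees (input, output) pairs
        optimizer_d.zero_grad()
        loss_d_fake = criterionGAN(net_d(torch.cat((real_a, fake_b.detach()), 1)), False)
        loss_d_real = criterionGAN(net_d(torch.cat((real_a, real_b), 1)), True)
        loss_d = (loss_d_fake + loss_d_real) * 0.5
        loss_d.backward()
        optimizer_d.step()

        # G update: adversarial term plus weighted L1 reconstruction
        optimizer_g.zero_grad()
        loss_g = criterionGAN(net_d(torch.cat((real_a, fake_b), 1)), True) \
                 + criterionL1(fake_b, real_b) * l1_weight
        loss_g.backward()
        optimizer_g.step()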