Example #1
    def test(self, times):
        # Put both generators in eval mode so normalization layers use inference statistics.
        self.GeneratorA2B.eval()
        self.GeneratorB2A.eval()

        # No gradients needed: save one translated sample per direction for this checkpoint.
        with torch.no_grad():
            utils.saveimage(self.GeneratorA2B(self.fixdataA)[0], times, 'A', self.ID)
            utils.saveimage(self.GeneratorB2A(self.fixdataB)[0], times, 'B', self.ID)
Example #2
def cmd_main():
    parser = argparse.ArgumentParser()
    parser.add_argument('-x', type=int, default=100)
    parser.add_argument('-y', type=int, default=150)
    parser.add_argument('-t', '--text', type=str)
    parser.add_argument('-r', '--rotate', type=float, default=0.0)
    parser.add_argument('-c', '--color', type=rgbaspec, default='#ffffff')
    parser.add_argument('-b', '--bg', type=rgbaspec, default=None)
    parser.add_argument('-n', '--fontname', default='Courier New.ttf')
    parser.add_argument('-z', '--size', type=int, default=64)
    parser.add_argument('-p', '--padbg', type=commapair, default='0,0')
    parser.add_argument('--bgsize', type=commapair)
    parser.add_argument('--outprefix', default="t-")
    parser.add_argument('--outnaming', default=None)
    parser.add_argument('--outtype', default="JPEG")
    parser.add_argument('--inplace', action='store_true')
    parser.add_argument('--textoffset', type=commapair, default='0,0')
    parser.add_argument('filenames', nargs='*')
    args = parser.parse_args()

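    # --inplace writes back over the original filename, so a separate --outnaming template would conflict.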
    if args.inplace and args.outnaming:
        print("can't specify both --inplace and --outnaming")
        exit(1)

    if args.inplace:
        outnamefmt = "{}"
    else:
        outnamefmt = args.outnaming or "{outprefix}{pp.name}"

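    # Draw the caption onto each image, then save it under the derived output name.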
    for fn in args.filenames:
        with Image.open(fn) as img:
            addtext(img,
                    args.x,
                    args.y,
                    args.text,
                    font=args.fontname,
                    size=args.size,
                    rgb=args.color,
                    bg=args.bg,
                    bgsize=args.bgsize,
                    bgpad=args.padbg,
                    textoffset=args.textoffset,
                    rotate=args.rotate)

            outname = formatoutname(fn, outnamefmt, outprefix=args.outprefix)
            saveimage(img, outname, image_format=args.outtype)
Example #3
def cmd_main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--outprefix', default="pan-")
    parser.add_argument('--outnaming',
                        default="{outprefix}{pp.stem}-{seq:05d}{pp.suffix}")
    parser.add_argument('--outtype', default="JPEG")
    parser.add_argument('--pan',
                        type=str,
                        help=panspechelp,
                        action='append',
                        default=[])
    parser.add_argument('--size', type=sizetuple, default=None)
    parser.add_argument('-v', '--verbose', action="count", default=0)
    parser.add_argument('filenames', nargs='*')
    args = parser.parse_args()

    if not args.filenames:
        return

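    # Verbosity helpers: vprint prints at -v and above, vvprint only at -vv.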
    def vprint(*pargs, **kwargs):
        if args.verbose:
            print(*pargs, **kwargs)

    def vvprint(*pargs, **kwargs):
        if args.verbose > 1:
            print(*pargs, **kwargs)

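    # Flatten every --pan option into a single list of pan specifications.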
    pans = []
    for s in args.pan:
        pans += gen_panspecs(s)

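    # expand_pans yields one (source filename, crop box) pair per output frame.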
    for seq, t in enumerate(expand_pans(args.filenames, pans)):
        fname, cropbox = t
        if args.verbose > 1 or (seq % 250) == 0:
            vprint(f"#{seq:05d}, cropping {fname} to {cropbox}")
        with Image.open(fname) as img:
            cropped = img.crop(box=cropbox)
            if args.size is not None and cropped.size != args.size:
                vvprint(f"resizing because cropped size is {cropped.size}")
                cropped = cropped.resize(args.size)
            outname = formatoutname(fname,
                                    args.outnaming,
                                    outprefix=args.outprefix,
                                    seq=seq)
            saveimage(cropped, outname, image_format=args.outtype)
Example #4
    def train(self):
        #### Check build_model ####
        try:
            getattr(self, 'Generator')
            getattr(self, 'Discriminator')
            #            getattr(self, 'criterion')
            getattr(self, 'optimG')
            getattr(self, 'optimD')
        except AttributeError:
            assert False, 'build_model() has not been called'

        #### Check load_dataset ####
        try:
            getattr(self, 'dataloader')
        except AttributeError:
            assert False, 'load_dataset() has not been called'

        self.fix_latent = torch.randn(64, 100, 1, 1).to(device=self.device)

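        # +1/-1 scalar seeds for backward(); passing -1 flips the gradient sign (WGAN-GP idiom).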
        one = torch.FloatTensor([1]).cuda().mean()
        mone = torch.FloatTensor([-1]).cuda().mean()

        #### Train ####
        for epoch in range(self.Epoch):

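            # Save a sample grid from the fixed latent every epoch (raise the modulus to sample less often).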
            if epoch % 1 == 0:
                utils.saveimage(
                    self.Generator(self.fix_latent).cpu().detach().numpy(),
                    epoch)
            #if epoch < self.step:
            #    continue
            for i, data in enumerate(self.dataloader):
                latent = torch.randn(data[0].size(0), 100, 1,
                                     1).to(device=self.device)

                #### train Discriminator ####

                # Unfreeze the critic and freeze the generator for the D step.
                for p in self.Discriminator.parameters():
                    p.requires_grad = True
                for p in self.Generator.parameters():
                    p.requires_grad = False

                self.optimD.zero_grad()
                real = data[0].to(self.device)
                fake = self.Generator(latent)
                real_D = self.Discriminator(real)
                fake_D = self.Discriminator(fake)

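                # Critic objective: push D(real) up and D(fake) down; backward(mone/one) sets the sign of each term.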
                loss_real = real_D.mean()
                loss_real.backward(mone)

                loss_fake = fake_D.mean()
                loss_fake.backward(one)

                gradient_penalty = self.calc_gradient_penalty(
                    self.Discriminator, real, fake)
                gradient_penalty.backward()
                lossD = loss_fake - loss_real + gradient_penalty

                self.optimD.step()

                #### train Generator ####
                # Freeze the critic and unfreeze the generator for the G step.
                for p in self.Discriminator.parameters():
                    p.requires_grad = False
                for p in self.Generator.parameters():
                    p.requires_grad = True

                latent = torch.randn(data[0].size(0), 100, 1,
                                     1).to(device=self.device)
                fake = self.Generator(latent)
                fake_D = self.Discriminator(fake)

                self.optimG.zero_grad()
                lossG = fake_D.mean()
                lossG.backward(mone)
                self.optimG.step()

                self.step += 1

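                # Report this step's losses through the project's progress helper.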
                utils.PresentationExperience(epoch,
                                             i,
                                             100,
                                             lossG=-lossG.item(),
                                             lossD=lossD.item())
            if epoch % 20 == 19:
                torch.save(self.Generator, 'Generator.pkl')
                torch.save(self.Discriminator, 'Discriminator.pkl')
Example #5
    def train(self):
        self.fixdataA = utils.make_fix_img(self.fixdataA_idx, self.test_datasetA).to(device=self.device)
        self.fixdataB = utils.make_fix_img(self.fixdataB_idx, self.test_datasetB).to(device=self.device)
        utils.saveimage(self.fixdataA, 0, 'A', self.ID)
        utils.saveimage(self.fixdataB, 0, 'B', self.ID)
#        self.FreezeD()

        # Each outer "times" block runs check_iteration iterations, then tests and checkpoints.
        for times in range(self.total_iteration//self.check_iteration):
            self._train()

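            # Linearly decay both learning rates to zero over the second half of training.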
            if times >= (self.total_iteration//self.check_iteration)//2:
                decay_step = self.lr / ((self.total_iteration//self.check_iteration)//2)
                self.optimG.param_groups[0]['lr'] -= decay_step
                self.optimD.param_groups[0]['lr'] -= decay_step

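            # When resuming from a checkpoint, skip blocks that were already trained.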
            if times < self.times:
                continue
            pbar = tqdm.tqdm(range(self.check_iteration), total=self.check_iteration)

            for step in pbar:
                self.optimD.zero_grad()
                
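                # Discriminator pass (LSGAN targets): real logits -> 1, translated logits -> 0.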
                realA, realB = self._next()
                fakeB, _ = self.GeneratorA2B(realA)
                fakeA, _ = self.GeneratorB2A(realB)
                
                realLA, realLA_CAM = self.DiscriminatorLA(realA)
                realGA, realGA_CAM = self.DiscriminatorGA(realA)
                
                realLB, realLB_CAM = self.DiscriminatorLB(realB)
                realGB, realGB_CAM = self.DiscriminatorGB(realB)
            
                fakeLA, fakeLA_CAM = self.DiscriminatorLA(fakeA)
                fakeGA, fakeGA_CAM = self.DiscriminatorGA(fakeA)
                
                fakeLB, fakeLB_CAM = self.DiscriminatorLB(fakeB)
                fakeGB, fakeGB_CAM = self.DiscriminatorGB(fakeB)
                
                Adversarial_Loss_A = (self.MSELoss(realLA, torch.ones_like(realLA)) + self.MSELoss(realGA, torch.ones_like(realGA))
                                      + self.MSELoss(fakeLA, torch.zeros_like(fakeLA)) + self.MSELoss(fakeGA, torch.zeros_like(fakeGA)))
                Adversarial_Loss_B = (self.MSELoss(realLB, torch.ones_like(realLB)) + self.MSELoss(realGB, torch.ones_like(realGB))
                                      + self.MSELoss(fakeLB, torch.zeros_like(fakeLB)) + self.MSELoss(fakeGB, torch.zeros_like(fakeGB)))

                Ad_CAM_Loss_A = (self.MSELoss(realLA_CAM, torch.ones_like(realLA_CAM)) + self.MSELoss(realGA_CAM, torch.ones_like(realGA_CAM))
                                 + self.MSELoss(fakeLA_CAM, torch.zeros_like(fakeLA_CAM)) + self.MSELoss(fakeGA_CAM, torch.zeros_like(fakeGA_CAM)))
                Ad_CAM_Loss_B = (self.MSELoss(realLB_CAM, torch.ones_like(realLB_CAM)) + self.MSELoss(realGB_CAM, torch.ones_like(realGB_CAM))
                                 + self.MSELoss(fakeLB_CAM, torch.zeros_like(fakeLB_CAM)) + self.MSELoss(fakeGB_CAM, torch.zeros_like(fakeGB_CAM)))
                            
                Discriminator_Loss_A = Adversarial_Loss_A + Ad_CAM_Loss_A
                Discriminator_Loss_B = Adversarial_Loss_B + Ad_CAM_Loss_B
                
                Discriminator_Loss = self.weight[0] * (Discriminator_Loss_A + Discriminator_Loss_B)
                Discriminator_Loss.backward()
                self.optimD.step()
                
                # Release discriminator-step tensors before building the generator graph.
                del fakeB, fakeA, realLA, realLA_CAM, realGA, realGA_CAM, realLB, realLB_CAM, realGB, realGB_CAM, fakeLA, fakeLA_CAM, fakeGA, fakeGA_CAM, fakeLB, fakeLB_CAM, fakeGB, fakeGB_CAM
                
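                # Generator pass: adversarial + CAM + cycle-consistency + identity terms.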
                self.optimG.zero_grad()
                
                fakeB, fakeB_CAM_gen = self.GeneratorA2B(realA)
                fakeA, fakeA_CAM_gen = self.GeneratorB2A(realB)
                
                reconA, _ = self.GeneratorB2A(fakeB)
                reconB, _ = self.GeneratorA2B(fakeA)
                
                fakeB2B, fakeB2B_CAM_gen = self.GeneratorA2B(realB)
                fakeA2A, fakeA2A_CAM_gen = self.GeneratorB2A(realA)
                
            
                fakeLA, fakeLA_CAM = self.DiscriminatorLA(fakeA)
                fakeGA, fakeGA_CAM = self.DiscriminatorGA(fakeA)
                
                fakeLB, fakeLB_CAM = self.DiscriminatorLB(fakeB)
                fakeGB, fakeGB_CAM = self.DiscriminatorGB(fakeB)
                
                Adversarial_Loss_A = self.MSELoss(fakeLA, torch.ones_like(fakeLA)) + self.MSELoss(fakeGA, torch.ones_like(fakeGA))
                Adversarial_Loss_B = self.MSELoss(fakeLB, torch.ones_like(fakeLB)) + self.MSELoss(fakeGB, torch.ones_like(fakeGB))

                Ad_CAM_Loss_A = self.MSELoss(fakeLA_CAM, torch.ones_like(fakeLA_CAM)) + self.MSELoss(fakeGA_CAM, torch.ones_like(fakeGA_CAM))
                Ad_CAM_Loss_B = self.MSELoss(fakeLB_CAM, torch.ones_like(fakeLB_CAM)) + self.MSELoss(fakeGB_CAM, torch.ones_like(fakeGB_CAM))
                
                
                Cycle_Loss_A = self.L1Loss(reconA, realA)
                Cycle_Loss_B = self.L1Loss(reconB, realB)
                Identity_Loss_A = self.L1Loss(fakeA2A, realA)
                Identity_Loss_B = self.L1Loss(fakeB2B, realB)
                
                G_CAM_Loss_A = self.BCELoss(fakeB_CAM_gen, torch.ones_like(fakeB_CAM_gen)) + self.BCELoss(fakeB2B_CAM_gen, torch.zeros_like(fakeB2B_CAM_gen))
                G_CAM_Loss_B = self.BCELoss(fakeA_CAM_gen, torch.ones_like(fakeA_CAM_gen)) + self.BCELoss(fakeA2A_CAM_gen, torch.zeros_like(fakeA2A_CAM_gen))
                
                Generator_Loss_A = (self.weight[0] * (Adversarial_Loss_A + Ad_CAM_Loss_A) + self.weight[1] * Cycle_Loss_A
                                    + self.weight[2] * Identity_Loss_A + self.weight[3] * G_CAM_Loss_A)
                Generator_Loss_B = (self.weight[0] * (Adversarial_Loss_B + Ad_CAM_Loss_B) + self.weight[1] * Cycle_Loss_B
                                    + self.weight[2] * Identity_Loss_B + self.weight[3] * G_CAM_Loss_B)
#                 Generator_vgg_Loss_A = self.vggloss(realA, fakeA)
#                 Generator_vgg_Loss_B = self.vggloss(realB, fakeB)
#                 Generator_Loss_A += self.weight[4] * Generator_vgg_Loss_A
#                 Generator_Loss_B += self.weight[4] * Generator_vgg_Loss_B

                # Total-variation regularization on the translated images.
                Generator_TV_Loss_A = self.tvloss(fakeA)
                Generator_TV_Loss_B = self.tvloss(fakeB)
                Generator_Loss_A += Generator_TV_Loss_A
                Generator_Loss_B += Generator_TV_Loss_B

                Generator_Loss = Generator_Loss_A + Generator_Loss_B
                Generator_Loss.backward()
                self.optimG.step()
                del fakeB, fakeB_CAM_gen, fakeA, fakeA_CAM_gen, reconA, reconB, fakeB2B, fakeB2B_CAM_gen, fakeA2A, fakeA2A_CAM_gen
                
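                # Clamp the generators' rho parameters back into range after each update (RhoClipper, as in U-GAT-IT).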
                self.GeneratorA2B.apply(self.RhoClipper)
                self.GeneratorB2A.apply(self.RhoClipper)
                
#                msg = '[{:03}/{:03}] [Generator A : {:.3f} | B : {:.3f}] [Discriminator A : {:.3f} | B : {:.3f}]'.format(times, self.total_iteration//self.check_iteration, Generator_Loss_A.item(), Generator_Loss_B.item(), Discriminator_Loss_A.item(), Discriminator_Loss_B.item())
                msg = '[{:03}/{:03}] [ModelA G : {:.3f} | D : {:.3f}] [ModelB G : {:.3f} | D : {:.3f}]'.format(times, self.total_iteration//self.check_iteration, Generator_Loss_A.item(), Discriminator_Loss_A.item(), Generator_Loss_B.item(), Discriminator_Loss_B.item())

                pbar.set_description_str(msg)
                
            self.times = times + 1
            self.test(times = self.times)
            self.save()
Example #6
def inference(config):

    # Load test dataset
    x_test, _, nameList = dataset.read_data(config.PATH_DATASET_TEST, None,
                                            (0, 0), True)

    N, h, w, c = x_test.shape

    # Bail out if there are no test images or they are not square.
    if h != w or N == 0:
        return

    modes = [key for (key, value) in config.IS_MODEL.items() if value]

    mode_dict = {}
    saver_dict = {}

    tf.reset_default_graph()

    # Create model
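    # One network per enabled mode, each in its own variable scope so its weights restore independently.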
    for mode in modes:
        network_model = config.NETWORK_MODEL[mode]
        network_layer_size = config.NETWORK_LAYER_SIZE[mode]

        mode_dict[str(mode)] = utils.build_model(mode, network_model, (h, w),
                                                 network_layer_size,
                                                 config.NETWORK_FEATURE_SIZE,
                                                 c)

        var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                     scope=mode)
        saver_dict[str(mode)] = tf.train.Saver(var_list)

    sess = tf.Session(config=tf.ConfigProto(gpu_options=tf.GPUOptions(
        allow_growth=True)))
    init = tf.global_variables_initializer()
    sess.run(init)

    # Load weights
    for mode in modes:
        checkpoint_path = os.path.join(config.PATH_WEIGHT, mode)
        saver_dict[str(mode)].restore(
            sess, tf.train.latest_checkpoint(checkpoint_path))
        print('>%s Restored!' % (mode))

    # run
    with sess as se:

        for i in range(N):
            image = x_test[i].reshape(1, h, w, c)
            time_start = time.time()

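            # Apply the enabled stages in sequence: denoise -> deblur -> SR.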
            if 'denoise' in modes:
                image = se.run([mode_dict['denoise'].out],
                               feed_dict={mode_dict['denoise'].input: image})
                image = image[0]
            if 'deblur' in modes:
                image = se.run([mode_dict['deblur'].out],
                               feed_dict={mode_dict['deblur'].input: image})
                image = image[0]
            if 'SR' in modes:
                # Centre-crop to half resolution; SR upsamples back, matching the (h, w) reshape below.
                crop_diff = (h // 4, w // 4)
                image = image[:, crop_diff[0]:crop_diff[0] + h // 2,
                              crop_diff[1]:crop_diff[1] + w // 2, :]
                image = se.run([mode_dict['SR'].out],
                               feed_dict={mode_dict['SR'].input: image})

            out_images = np.array(image, dtype=np.uint8)
            #_, res_h, res_w, res_c = out_images[0].shape
            #out_images = out_images.reshape(res_h,res_w,res_c)
            out_images = out_images.reshape(h, w, c)

            time_spent = time.time() - time_start

            file_name = '{}_Result'.format(nameList[i])
            path = config.PATH_RESULT + '/%s' % "_".join(
                [str(x) for x in modes])
            utils.saveimage(out_images, path, file_name)
            print('> Inference image : %s(%dx%dx%d) [%0.4f]' %
                  (file_name, h, w, c, time_spent))