Esempio n. 1
0
    def build_whole_net_with_loss(self):
        """Build generator, discriminators, VGG feature taps and all losses.

        Reads the input/label tensors stored on ``self`` (``i_t``, ``i_s``,
        ``t_sk``, ``t_t``, ``t_b``, ``t_f``, ``mask_t``) and sets:

        * ``self.o_sk``, ``self.o_t``, ``self.o_b``, ``self.o_f`` - named
          identity copies of the generator outputs;
        * ``self.d_loss`` / ``self.d_loss_detail`` - sum and list of the two
          discriminator losses;
        * ``self.g_loss`` / ``self.g_loss_detail`` - as returned by
          ``build_generator_loss``.
        """
        i_t, i_s = self.i_t, self.i_s
        t_sk, t_t, t_b, t_f, mask_t = self.t_sk, self.t_t, self.t_b, self.t_f, self.mask_t
        inputs = [i_t, i_s]
        labels = [t_sk, t_t, t_b, t_f]

        o_sk, o_t, o_b, o_f = self.build_generator(inputs)
        # Named identities so the outputs can be fetched by name later.
        self.o_sk = tf.identity(o_sk, name='o_sk')
        self.o_t = tf.identity(o_t, name='o_t')
        self.o_b = tf.identity(o_b, name='o_b')
        self.o_f = tf.identity(o_f, name='o_f')

        # Real and predicted pairs are concatenated on the channel axis with
        # their conditioning input, then stacked on the batch axis so each
        # discriminator scores both in a single pass.
        i_db_true = tf.concat([t_b, i_s], axis=-1, name='db_true_concat')
        i_db_pred = tf.concat([o_b, i_s], axis=-1, name='db_pred_concat')
        i_db = tf.concat([i_db_true, i_db_pred], axis=0, name='db_concat')

        i_df_true = tf.concat([t_f, i_t], axis=-1, name='df_true_concat')
        i_df_pred = tf.concat([o_f, i_t], axis=-1, name='df_pred_concat')
        i_df = tf.concat([i_df_true, i_df_pred], axis=0, name='df_concat')

        o_db = self.build_discriminator(i_db, name='db')
        o_df = self.build_discriminator(i_df, name='df')

        # Target and predicted fusion images share one VGG pass (batch concat).
        i_vgg = tf.concat([t_f, o_f], axis=0, name='vgg_concat')

        # Load the serialized VGG19 GraphDef and splice it into the current
        # graph with its "inputs:0" placeholder replaced by i_vgg.
        vgg_graph_def = tf.GraphDef()
        with open(cfg.vgg19_weights, 'rb') as f:
            vgg_graph_def.ParseFromString(f.read())
        tf.import_graph_def(vgg_graph_def, input_map={"inputs:0": i_vgg})

        # import_graph_def adds the nodes to the default graph under the
        # "import/" prefix. Look them up there directly instead of opening a
        # throwaway tf.Session just to reach sess.graph (the same graph).
        graph = tf.get_default_graph()
        out_vgg = [
            graph.get_tensor_by_name("import/block{}_conv1/Relu:0".format(block))
            for block in range(1, 6)
        ]

        out_g = [o_sk, o_t, o_b, o_f, mask_t]
        out_d = [o_db, o_df]

        db_loss = build_discriminator_loss(o_db, name='db_loss')
        df_loss = build_discriminator_loss(o_df, name='df_loss')
        self.d_loss_detail = [db_loss, df_loss]
        self.d_loss = tf.add(db_loss, df_loss, name='d_loss')
        self.g_loss, self.g_loss_detail = build_generator_loss(out_g,
                                                               out_d,
                                                               out_vgg,
                                                               labels,
                                                               name='g_loss')
Esempio n. 2
0
def main():
    """Train SRNet with alternating discriminator / generator updates.

    Builds the training and example loaders, the generator ``G``, the two
    discriminators ``D1`` (fed background pairs, logged as ``D_bg``) and
    ``D2`` (fed fusion pairs, logged as ``D_fus``), a VGG19 feature extractor
    and one Adam optimizer per network. It optionally resumes from
    ``cfg.ckpt_path`` and then runs ``cfg.max_iter`` steps. D1/D2 are updated
    every step and G every 5th step. Checkpoints, log lines and example images
    are written at the intervals configured on ``cfg``.
    """
    os.environ['CUDA_VISIBLE_DEVICES'] = str(cfg.gpu)

    train_name = get_train_name()

    print_log('Initializing SRNET', content_color=PrintColor['yellow'])

    train_data = datagen_srnet(cfg)
    train_data = DataLoader(dataset=train_data, batch_size=cfg.batch_size, shuffle=False,
                            collate_fn=custom_collate, pin_memory=True)

    trfms = To_tensor()
    example_data = example_dataset(transform=trfms)
    example_loader = DataLoader(dataset=example_data, batch_size=1, shuffle=False)

    print_log('training start.', content_color=PrintColor['yellow'])

    G = Generator(in_channels=3).cuda()
    D1 = Discriminator(in_channels=6).cuda()
    D2 = Discriminator(in_channels=6).cuda()
    vgg_features = Vgg19().cuda()

    G_solver = torch.optim.Adam(G.parameters(), lr=cfg.learning_rate, betas=(cfg.beta1, cfg.beta2))
    D1_solver = torch.optim.Adam(D1.parameters(), lr=cfg.learning_rate, betas=(cfg.beta1, cfg.beta2))
    D2_solver = torch.optim.Adam(D2.parameters(), lr=cfg.learning_rate, betas=(cfg.beta1, cfg.beta2))

    # NOTE: learning-rate schedulers are currently disabled, so checkpoints
    # carry no scheduler state.

    # Only torch.load can raise FileNotFoundError; keep the try body minimal.
    try:
        checkpoint = torch.load(cfg.ckpt_path)
    except FileNotFoundError:
        print('checkpoint not found')
    else:
        G.load_state_dict(checkpoint['generator'])
        D1.load_state_dict(checkpoint['discriminator1'])
        D2.load_state_dict(checkpoint['discriminator2'])
        G_solver.load_state_dict(checkpoint['g_optimizer'])
        D1_solver.load_state_dict(checkpoint['d1_optimizer'])
        D2_solver.load_state_dict(checkpoint['d2_optimizer'])
        print('Resuming after loading...')

    # Start in discriminator mode: G frozen, D1/D2 trainable.
    requires_grad(G, False)
    requires_grad(D1, True)
    requires_grad(D2, True)

    trainiter = iter(train_data)
    example_iter = iter(example_loader)

    # ZeroPad2d((left, right, top, bottom)) adds one column on the right and
    # one row on top of each generator output. Presumably this matches the
    # label resolution; confirm against Generator's output shape.
    K = torch.nn.ZeroPad2d((0, 1, 1, 0))

    # Defined only after the first generator update; logging handles None.
    g_loss = None

    for step in tqdm(range(cfg.max_iter)):

        D1_solver.zero_grad()
        D2_solver.zero_grad()

        if (step + 1) % cfg.save_ckpt_interval == 0:
            torch.save(
                {
                    'generator': G.state_dict(),
                    'discriminator1': D1.state_dict(),
                    'discriminator2': D2.state_dict(),
                    'g_optimizer': G_solver.state_dict(),
                    'd1_optimizer': D1_solver.state_dict(),
                    'd2_optimizer': D2_solver.state_dict(),
                },
                cfg.checkpoint_savedir + f'train_step-{step+1}.model',
            )

        # DataLoader iterators have no .next() method; use the builtin next().
        try:
            batch = next(trainiter)
        except StopIteration:
            # End of epoch: restart the loader.
            trainiter = iter(train_data)
            batch = next(trainiter)
        i_t, i_s, t_sk, t_t, t_b, t_f, mask_t = (x.cuda() for x in batch)
        labels = [t_sk, t_t, t_b, t_f]

        # ---- Discriminator update (G frozen) ----
        # Only the background and fusion outputs feed the discriminators.
        _, _, o_b, o_f = G(i_t, i_s)
        o_b = K(o_b)
        o_f = K(o_f)

        i_db_true = torch.cat((t_b, i_s), dim=1)
        i_df_true = torch.cat((t_f, i_t), dim=1)

        db_loss = build_discriminator_loss(D1(i_db_true), D1(torch.cat((o_b, i_s), dim=1)))
        df_loss = build_discriminator_loss(D2(i_df_true), D2(torch.cat((o_f, i_t), dim=1)))

        db_loss.backward()
        df_loss.backward()

        D1_solver.step()
        D2_solver.step()

        # NOTE(review): clip_grad runs after the optimizer step. If it clips
        # gradients rather than weights it has no effect here; confirm intent.
        clip_grad(D1)
        clip_grad(D2)

        # ---- Generator update every 5th step (D frozen) ----
        if (step + 1) % 5 == 0:

            requires_grad(G, True)
            requires_grad(D1, False)
            requires_grad(D2, False)

            G_solver.zero_grad()

            # Fresh forward pass so the graph tracks G's (now trainable) params.
            o_sk, o_t, o_b, o_f = (K(o) for o in G(i_t, i_s))

            out_g = [o_sk, o_t, o_b, o_f, mask_t]
            out_d = [D1(torch.cat((o_b, i_s), dim=1)),
                     D2(torch.cat((o_f, i_t), dim=1))]
            out_vgg = vgg_features(torch.cat((t_f, o_f), dim=0))

            g_loss, _ = build_generator_loss(out_g, out_d, out_vgg, labels)

            g_loss.backward()

            G_solver.step()

            requires_grad(G, False)
            requires_grad(D1, True)
            requires_grad(D2, True)

        if (step + 1) % cfg.write_log_interval == 0:
            # Report NaN for the generator until its first update has run.
            gen_val = g_loss.item() if g_loss is not None else float('nan')
            print('Iter: {}/{} | Gen: {} | D_bg: {} | D_fus: {}'.format(step+1, cfg.max_iter, gen_val, db_loss.item(), df_loss.item()))

        if (step + 1) % cfg.gen_example_interval == 0:

            savedir = os.path.join(cfg.example_result_dir, train_name, 'iter-' + str(step+1).zfill(len(str(cfg.max_iter))))

            with torch.no_grad():

                try:
                    inp = next(example_iter)
                except StopIteration:
                    example_iter = iter(example_loader)
                    inp = next(example_iter)

                # Separate names so the training batch variables stay untouched.
                ex_t = inp[0].cuda()
                ex_s = inp[1].cuda()
                name = str(inp[2][0])

                o_sk, o_t, o_b, o_f = (o.squeeze(0).to('cpu') for o in G(ex_t, ex_s))

                os.makedirs(savedir, exist_ok=True)

                # o_sk is saved as-is; the other outputs are mapped from
                # [-1, 1] to [0, 1] (assumed tanh range; confirm).
                o_sk = F.to_pil_image(o_sk)
                o_t = F.to_pil_image((o_t + 1)/2)
                o_b = F.to_pil_image((o_b + 1)/2)
                o_f = F.to_pil_image((o_f + 1)/2)

                o_f.save(os.path.join(savedir, name + 'o_f.png'))
                o_sk.save(os.path.join(savedir, name + 'o_sk.png'))
                o_t.save(os.path.join(savedir, name + 'o_t.png'))
                o_b.save(os.path.join(savedir, name + 'o_b.png'))
Esempio n. 3
0
def main():
    """Train SRNet with alternating discriminator / generator updates.

    D1 (background pairs) and D2 (fusion pairs) are updated every step. G is
    updated every 5th step with its own forward pass. Checkpoints, log lines
    and example images are written at the intervals configured on ``cfg``.
    """

    def to_ubyte(tensor, from_tanh):
        """Convert one CHW image tensor to a uint8 HWC (or HW) numpy array.

        When ``from_tanh`` is true the tensor is first mapped from [-1, 1] to
        [0, 1] (assumed tanh output range; confirm against Generator).
        """
        if from_tanh:
            tensor = (tensor + 1) / 2
        arr = tensor.clamp(0, 1).permute(1, 2, 0).cpu().numpy()
        if arr.shape[-1] == 1:
            # Drop the channel axis for single-channel maps.
            arr = arr[..., 0]
        return skimage.img_as_ubyte(arr)

    os.environ['CUDA_VISIBLE_DEVICES'] = str(cfg.gpu)

    train_name = get_train_name()

    print_log('Initializing SRNET', content_color=PrintColor['yellow'])
    # NOTE(review): `model` is never used by the PyTorch loop below.
    model = SRNet(shape=cfg.data_shape, name=train_name)
    print_log('model compiled.', content_color=PrintColor['yellow'])

    train_data = datagen_srnet(cfg)

    # DataLoader's keyword is `collate_fn`; `collate=` raises TypeError.
    train_data = DataLoader(dataset=train_data,
                            batch_size=cfg.batch_size,
                            shuffle=False,
                            collate_fn=custom_collate,
                            pin_memory=True)

    trfms = To_tensor()
    example_data = example_dataset(transform=trfms)

    # The whole example set is a single batch.
    example_loader = DataLoader(dataset=example_data,
                                batch_size=len(example_data),
                                shuffle=False)

    print_log('training start.', content_color=PrintColor['yellow'])

    G = Generator(in_channels=3).cuda()
    D1 = Discriminator(in_channels=6).cuda()
    D2 = Discriminator(in_channels=6).cuda()
    vgg_features = Vgg19().cuda()

    adam_kwargs = dict(lr=cfg.learning_rate, betas=(cfg.beta1, cfg.beta2))
    G_solver = torch.optim.Adam(G.parameters(), **adam_kwargs)
    D1_solver = torch.optim.Adam(D1.parameters(), **adam_kwargs)
    D2_solver = torch.optim.Adam(D2.parameters(), **adam_kwargs)

    # NOTE(review): these step once per update, so milestones are counted in
    # update steps rather than epochs; confirm that is intended.
    g_scheduler = torch.optim.lr_scheduler.MultiStepLR(G_solver,
                                                       milestones=[30, 200],
                                                       gamma=0.5)
    d1_scheduler = torch.optim.lr_scheduler.MultiStepLR(D1_solver,
                                                        milestones=[30, 200],
                                                        gamma=0.5)
    d2_scheduler = torch.optim.lr_scheduler.MultiStepLR(D2_solver,
                                                        milestones=[30, 200],
                                                        gamma=0.5)

    # Start in discriminator mode: G frozen, D1/D2 trainable.
    requires_grad(G, False)
    requires_grad(D1, True)
    requires_grad(D2, True)

    trainiter = iter(train_data)
    example_iter = iter(example_loader)

    # Defined only after the first generator update; logging handles None.
    g_loss = None

    for step in tqdm(range(cfg.max_iter)):

        D1.zero_grad()
        D2.zero_grad()

        if (step + 1) % cfg.save_ckpt_interval == 0:
            # The networks are plain modules (not DataParallel): no `.module`.
            torch.save(
                {
                    'generator': G.state_dict(),
                    'discriminator1': D1.state_dict(),
                    'discriminator2': D2.state_dict(),
                    'g_optimizer': G_solver.state_dict(),
                    'd1_optimizer': D1_solver.state_dict(),
                    'd2_optimizer': D2_solver.state_dict(),
                },
                f'checkpoint/train_step-{step+1}.model',
            )

        # DataLoader iterators have no .next() method; restart at epoch end.
        try:
            batch = next(trainiter)
        except StopIteration:
            trainiter = iter(train_data)
            batch = next(trainiter)
        i_t, i_s, t_sk, t_t, t_b, t_f, mask_t = (x.cuda() for x in batch)

        labels = [t_sk, t_t, t_b, t_f]

        # ---- Discriminator update (G frozen) ----
        o_sk, o_t, o_b, o_f = G(i_t, i_s)

        i_db_true = torch.cat((t_b, i_s), dim=1)
        i_db_pred = torch.cat((o_b, i_s), dim=1)

        i_df_true = torch.cat((t_f, i_t), dim=1)
        i_df_pred = torch.cat((o_f, i_t), dim=1)

        o_db_true = D1(i_db_true)
        o_db_pred = D1(i_db_pred)

        o_df_true = D2(i_df_true)
        o_df_pred = D2(i_df_pred)

        db_loss = build_discriminator_loss(o_db_true, o_db_pred)
        df_loss = build_discriminator_loss(o_df_true, o_df_pred)

        db_loss.backward()
        df_loss.backward()

        D1_solver.step()
        D2_solver.step()

        d1_scheduler.step()
        d2_scheduler.step()

        # NOTE(review): clip_grad runs after the optimizer step. If it clips
        # gradients rather than weights it has no effect here; confirm intent.
        clip_grad(D1)
        clip_grad(D2)

        # ---- Generator update every 5th step (D frozen) ----
        # The generator loss must come from a fresh forward pass. The graph
        # above was built with G frozen, was freed by the D backward passes,
        # and references D weights the D optimizers have since changed in place.
        if (step + 1) % 5 == 0:

            requires_grad(G, True)
            requires_grad(D1, False)
            requires_grad(D2, False)

            G.zero_grad()

            o_sk, o_t, o_b, o_f = G(i_t, i_s)

            i_db_pred = torch.cat((o_b, i_s), dim=1)
            i_df_pred = torch.cat((o_f, i_t), dim=1)

            out_g = [o_sk, o_t, o_b, o_f, mask_t]
            out_d = [D1(i_db_true), D1(i_db_pred), D2(i_df_true), D2(i_df_pred)]
            out_vgg = vgg_features(torch.cat((t_f, o_f), dim=0))

            g_loss, _ = build_generator_loss(out_g, out_d, out_vgg, labels)

            g_loss.backward()

            G_solver.step()

            g_scheduler.step()

            clip_grad(G)

            requires_grad(G, False)
            requires_grad(D1, True)
            requires_grad(D2, True)

        if (step + 1) % cfg.write_log_interval == 0:
            # Report NaN for the generator until its first update has run.
            gen_val = g_loss.item() if g_loss is not None else float('nan')
            print('Iter: {}/{} | Gen: {} | D_bg: {} | D_fus: {}'.format(
                step + 1, cfg.max_iter, gen_val, db_loss.item(),
                df_loss.item()))

        if (step + 1) % cfg.gen_example_interval == 0:

            savedir = os.path.join(
                cfg.example_result_dir, train_name,
                'iter-' + str(step + 1).zfill(len(str(cfg.max_iter))))

            with torch.no_grad():

                # One batch per pass over the example set; restart when exhausted.
                try:
                    inp = next(example_iter)
                except StopIteration:
                    example_iter = iter(example_loader)
                    inp = next(example_iter)

                # Assumes example_dataset yields (i_t, i_s, name); confirm.
                ex_t = inp[0].cuda()
                ex_s = inp[1].cuda()
                names = [str(n) for n in inp[2]]

                o_sk, o_t, o_b, o_f = G(ex_t, ex_s)

                os.makedirs(savedir, exist_ok=True)

                # Outputs are batched: save each sample separately.
                for k, name in enumerate(names):
                    io.imsave(os.path.join(savedir, name + 'o_f.png'), to_ubyte(o_f[k], True))
                    io.imsave(os.path.join(savedir, name + 'o_sk.png'), to_ubyte(o_sk[k], False))
                    io.imsave(os.path.join(savedir, name + 'o_t.png'), to_ubyte(o_t[k], True))
                    io.imsave(os.path.join(savedir, name + 'o_b.png'), to_ubyte(o_b[k], True))