def set_network(args):
    # VGG16 layers used to extract style/perceptual features
    style_layers = [4, 9, 16]
    # Pix2pix-style networks
    # netG = UnetGenerator(in_channels=3, num_downs=8)
    net_label = label_Discriminator(in_channels=1, use_sigmoid=False)
    netG = STE()
    netD = Discriminator(in_channels=6, use_sigmoid=False)
    # Pretrained VGG16, passed to get_net together with the chosen style layers
    netvgg = models.vgg16(pretrained=True)
    net = get_net(netvgg, style_layers)
    net.collect_params().reset_ctx(ctx)
    # net.collect_params().setattr('grad_req', 'null')
    # Initialize parameters
    netG.initialize(ctx=ctx, init=init.Xavier())
    if args.model:
        netG.collect_params().load(args.model, ctx=ctx)
    netG.collect_params().reset_ctx(ctx)
    network_init(netD)
    # The label discriminator is kept fixed: constant weights, gradients disabled
    net_label.initialize(ctx=ctx, init=mx.initializer.One())
    net_label.collect_params().setattr('grad_req', 'null')
    net_label.collect_params().reset_ctx(ctx)

    trainerG = gluon.Trainer(netG.collect_params(), 'adam',
                             {'learning_rate': args.lr})
    trainerD = gluon.Trainer(netD.collect_params(), 'adam',
                             {'learning_rate': args.lr})
    trainerV = gluon.Trainer(net.collect_params(), 'adam',
                             {'learning_rate': args.lr})
    trainerL = gluon.Trainer(net_label.collect_params(), 'adam',
                             {'learning_rate': args.lr})
    return netG, netD, net, net_label, trainerG, trainerD, trainerV, trainerL
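
Note: the network_init helper used above (and again in the DiscoGAN example below) is not defined in these snippets. A minimal sketch of what such a helper usually does in Gluon; the initializer choice here is an assumption, not from the source:

def network_init(net, ctx=mx.cpu()):
    # Initialize every parameter of a Gluon block. DCGAN-style code often uses
    # init.Normal(sigma=0.02) instead of Xavier, so the real helper may differ.
    net.initialize(init=init.Xavier(), ctx=ctx)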
Example #2
class DiscoGAN(object):
    """
    Class for building the Disco GAN model
    """
    def __init__(self, learning_rate=0.005, betas=(0.5, 0.999), conv_dim=64):
        self.learning_rate = learning_rate
        self.betas = betas
        self.conv_dim = conv_dim
        self.build_model()

    def build_model(self):
        self.D_a = Discriminator(dim=self.conv_dim)
        self.D_b = Discriminator(dim=self.conv_dim)

        self.G_ab = GeneratorCNN(dim=self.conv_dim)
        self.G_ba = GeneratorCNN(dim=self.conv_dim)

        network_init(self.D_a)
        network_init(self.D_b)
        network_init(self.G_ab)
        network_init(self.G_ba)

        self.chain_params()

        self.trainerD = gluon.Trainer(self.d, 'adam',
                                      {'learning_rate': self.learning_rate,
                                       'beta1': self.betas[0],
                                       'beta2': self.betas[1]})
        self.trainerG = gluon.Trainer(self.g, 'adam',
                                      {'learning_rate': self.learning_rate,
                                       'beta1': self.betas[0],
                                       'beta2': self.betas[1]})

        self.recon_criterion = gluon.loss.L2Loss()
        self.gan_criterion = gluon.loss.SigmoidBinaryCrossEntropyLoss()
        self.feat_criterion = gluon.loss.HingeLoss()
        
    def chain_params(self):
        # Merge the parameters of both generators into one dict and both
        # discriminators into another, so a single Trainer updates each pair.
        self.g = {}
        for key, value in self.G_ab.collect_params().items():
            self.g[key] = value
        for key, value in self.G_ba.collect_params().items():
            self.g[key] = value

        self.d = {}
        for key, value in self.D_a.collect_params().items():
            self.d[key] = value
        for key, value in self.D_b.collect_params().items():
            self.d[key] = value
    
    def data_loader(self, train_data_dir, batch_size=256, img_height=64, img_width=64):
        # Initiate a generator with data augmentation
        # (uses Keras' ImageDataGenerator; pixel values are rescaled to [0, 1])
        datagen = ImageDataGenerator(
                rescale=1. / 255,
                fill_mode="nearest")

        data_generator = datagen.flow_from_directory(
                train_data_dir,
                target_size=(img_height, img_width),
                batch_size=batch_size,
                class_mode=None,
                shuffle=True)

        return data_generator
        
    def get_gan_loss(self, dis_real, dis_fake):
        # Real samples are labelled 1, fake samples 0
        labels_dis_real = nd.ones(shape=(dis_real.shape[0], 1))
        labels_dis_fake = nd.zeros(shape=(dis_fake.shape[0], 1))
        labels_gen = nd.ones(shape=(dis_fake.shape[0], 1))

        dis_loss = self.gan_criterion(dis_real, labels_dis_real) * 0.5 \
            + self.gan_criterion(dis_fake, labels_dis_fake) * 0.5
        gen_loss = self.gan_criterion(dis_fake, labels_gen)

        return dis_loss, gen_loss
    
    def discriminator_loss(self, dis_real, dis_fake):
        labels_dis_real = nd.ones(shape=(dis_real.shape[0], 1))
        labels_dis_fake = nd.zeros(shape=(dis_fake.shape[0], 1))
        dis_loss = self.gan_criterion(dis_real, labels_dis_real) * 0.5 \
            + self.gan_criterion(dis_fake, labels_dis_fake) * 0.5
        return dis_loss

    def gan_loss(self, dis_fake):
        labels_gen = nd.ones(shape=(dis_fake.shape[0], 1))
        gen_loss = self.gan_criterion(dis_fake, labels_gen)
        return gen_loss

    def get_fm_loss(self, real_feats, fake_feats):
        # Feature-matching loss over the discriminator's intermediate feature maps
        losses = 0
        for real_feat, fake_feat in zip(real_feats, fake_feats):
            diff = real_feat.mean(axis=0) - fake_feat.mean(axis=0)
            l2 = diff * diff
            loss = self.feat_criterion(nd.ones(shape=l2.shape), l2).mean()
            losses = losses + loss
        return losses
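    # Note: gluon.loss.HingeLoss computes max(0, margin - pred * label) with
    # margin = 1 by default, which is not the same as the HingeEmbeddingLoss
    # used by the original PyTorch DiscoGAN. With pred fixed at 1, the term
    # above reduces to max(0, 1 - l2), e.g.:
    #   >>> hinge = gluon.loss.HingeLoss()
    #   >>> hinge(nd.ones(3), nd.array([0.0, 0.5, 2.0]))
    #   [1.0, 0.5, 0.0]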
    


    def generate(self, a, b, path, save=True, epoch=None, num_images=10, is_calc_loss=False):
        # Channels-last numpy batches from Keras -> channels-first NDArrays
        A = a.swapaxes(1, 3).swapaxes(2, 3)
        B = b.swapaxes(1, 3).swapaxes(2, 3)
        A, B = nd.array(A), nd.array(B)
        BA = self.G_ab(B)
        AB = self.G_ba(A)
        BAB = self.G_ba(BA)
        ABA = self.G_ab(AB)

        if is_calc_loss:
            gen_loss_A_total, gen_loss_B_total, dis_loss_A, dis_loss_B = self.calc_loss(A, B, BA, AB, BAB, ABA)

        # Include the epoch number in the filename prefix when one is given
        if epoch is None:
            f_a = path + 'A_'
            f_b = path + 'B_'
        else:
            f_a = path + 'A_Epoch_'
            f_b = path + 'B_Epoch_'

        if save:
            for im in range(num_images):
                A_val = A[im]
                B_val = B[im]
                BA_val = BA[im]
                ABA_val = ABA[im]
                AB_val = AB[im]
                BAB_val = BAB[im]

                filename = f_a + str(epoch) + '_im_' + str(im) + '.jpg'
                utils.save_image([A_val.asnumpy(), AB_val.asnumpy(), ABA_val.asnumpy()], filename=filename)
                filename = f_b + str(epoch) + '_im_' + str(im) + '.jpg'
                utils.save_image([B_val.asnumpy(), BA_val.asnumpy(), BAB_val.asnumpy()], filename=filename)

        if is_calc_loss:
            return gen_loss_A_total, gen_loss_B_total, dis_loss_A, dis_loss_B
    
    
    def calc_loss(self, x_A, x_B, x_BA, x_AB, x_BAB, x_ABA):
        A_dis_real, A_feats_real = self.D_a(x_A)
        A_dis_fake, A_feats_fake = self.D_a(x_BA)
        dis_loss_A = self.discriminator_loss(A_dis_real, A_dis_fake)

        B_dis_real, B_feats_real = self.D_b(x_B)
        B_dis_fake, B_feats_fake = self.D_b(x_AB)
        dis_loss_B = self.discriminator_loss(B_dis_real, B_dis_fake)

        recon_loss_A = self.recon_criterion(x_ABA, x_A)
        recon_loss_B = self.recon_criterion(x_BAB, x_B)

        gen_loss_A = self.gan_loss(A_dis_fake)
        fm_loss_A = self.get_fm_loss(A_feats_real, A_feats_fake)

        gen_loss_B = self.gan_loss(B_dis_fake)
        fm_loss_B = self.get_fm_loss(B_feats_real, B_feats_fake)

        # Weighting follows the original DiscoGAN implementation
        rate = 0.01
        gen_loss_A_total = (gen_loss_B * 0.1 + fm_loss_B * 0.9) * (1. - rate) + recon_loss_A * rate
        gen_loss_B_total = (gen_loss_A * 0.1 + fm_loss_A * 0.9) * (1. - rate) + recon_loss_B * rate

        return gen_loss_A_total, gen_loss_B_total, dis_loss_A, dis_loss_B
        
    def train(self, n_epochs=100, batch_size=64, print_freq=2, save=True, rate_tup=(0.01, 0.5),
              gan_curriculum=1000,
              data_dir='/home/chocolatethunder/Documents/Borealis/PyTorch/Disco_Gan/Data/facades/',
              path='Generate_Images/'):

        train_dir_a = data_dir + 'train_A/'
        train_dir_b = data_dir + 'train_B/'

        val_dir_a = data_dir + 'val_A/'
        val_dir_b = data_dir + 'val_B/'

        path_train = path + 'train'
        path_val = path + 'val'

        val_A_generator = self.data_loader(val_dir_a, batch_size=10)
        val_B_generator = self.data_loader(val_dir_b, batch_size=10)

        # Compile the networks into static graphs for speed
        self.G_ab.hybridize()
        self.G_ba.hybridize()
        self.D_a.hybridize()
        self.D_b.hybridize()

        for epoch in range(n_epochs):
            A_generator = self.data_loader(train_dir_a, batch_size=batch_size)
            B_generator = self.data_loader(train_dir_b, batch_size=batch_size)

            ctr = 0
            for data_a, data_b in zip(A_generator, B_generator):
                # Channels-last numpy batches -> channels-first NDArrays
                data_a = data_a.swapaxes(1, 3).swapaxes(2, 3)
                data_b = data_b.swapaxes(1, 3).swapaxes(2, 3)
                x_A, x_B = nd.array(data_a), nd.array(data_b)

                # ---- Discriminator update ----
                with autograd.record():
                    x_AB = self.G_ba(x_A)
                    x_BA = self.G_ab(x_B)

                    A_dis_real, A_feats_real = self.D_a(x_A)
                    A_dis_fake, A_feats_fake = self.D_a(x_BA)
                    dis_loss_A = self.discriminator_loss(A_dis_real, A_dis_fake)

                    B_dis_real, B_feats_real = self.D_b(x_B)
                    B_dis_fake, B_feats_fake = self.D_b(x_AB)
                    dis_loss_B = self.discriminator_loss(B_dis_real, B_dis_fake)

                    dis_loss = dis_loss_A + dis_loss_B
                    dis_loss.backward()
                self.trainerD.step(batch_size)

                # ---- Generator update ----
                with autograd.record():
                    x_AB = self.G_ba(x_A)
                    x_BA = self.G_ab(x_B)

                    x_ABA = self.G_ab(x_AB)
                    x_BAB = self.G_ba(x_BA)

                    recon_loss_A = self.recon_criterion(x_ABA, x_A)
                    recon_loss_B = self.recon_criterion(x_BAB, x_B)

                    A_dis_real, A_feats_real = self.D_a(x_A)
                    A_dis_fake, A_feats_fake = self.D_a(x_BA)
                    gen_loss_A = self.gan_loss(A_dis_fake)
                    fm_loss_A = self.get_fm_loss(A_feats_real, A_feats_fake)

                    B_dis_real, B_feats_real = self.D_b(x_B)
                    B_dis_fake, B_feats_fake = self.D_b(x_AB)
                    gen_loss_B = self.gan_loss(B_dis_fake)
                    fm_loss_B = self.get_fm_loss(B_feats_real, B_feats_fake)

                    # Curriculum: weight the GAN/feature-matching terms heavily for
                    # the first gan_curriculum batches, then shift more weight to
                    # the reconstruction term
                    if ctr < gan_curriculum:
                        rate = rate_tup[0]
                    else:
                        rate = rate_tup[1]

                    gen_loss_A_total = (gen_loss_B * 0.1 + fm_loss_B * 0.9) * (1. - rate) + recon_loss_A * rate
                    gen_loss_B_total = (gen_loss_A * 0.1 + fm_loss_A * 0.9) * (1. - rate) + recon_loss_B * rate

                    # Total generator loss
                    gen_loss = gen_loss_A_total + gen_loss_B_total
                    gen_loss.backward()
                self.trainerG.step(batch_size)

                if ctr % 4 == 0:
                    print("---------------------")
                    print("GEN Loss:", gen_loss_A_total.asnumpy().mean(), gen_loss_B_total.asnumpy().mean())
                    print("DIS Loss:", dis_loss_A.asnumpy().mean(), dis_loss_B.asnumpy().mean())

                # Keras generators loop forever, so stop after one pass over the data
                ctr += 1
                if ctr >= len(A_generator):
                    break
                
            if epoch == 0 or (epoch + 1) % print_freq == 0:
                # Save one batch of training translations
                for data_a, data_b in zip(A_generator, B_generator):
                    self.generate(data_a, data_b, save=save, path=path_train, epoch=epoch)
                    break

                # Evaluate on the validation set
                ctr = 0
                gen_loss_a_mean = 0
                gen_loss_b_mean = 0
                dis_loss_a_mean = 0
                dis_loss_b_mean = 0
                for val_a, val_b in zip(val_A_generator, val_B_generator):
                    gen_loss_A_total, gen_loss_B_total, dis_loss_A, dis_loss_B = self.generate(
                        val_a, val_b, save=save, path=path_val, epoch=epoch, is_calc_loss=True)
                    gen_loss_a_mean += gen_loss_A_total.asnumpy().mean()
                    gen_loss_b_mean += gen_loss_B_total.asnumpy().mean()
                    dis_loss_a_mean += dis_loss_A.asnumpy().mean()
                    dis_loss_b_mean += dis_loss_B.asnumpy().mean()

                    ctr += 1
                    if ctr >= len(val_A_generator):
                        break

                print("---------------------")
                print("Val GEN Loss:", gen_loss_a_mean / ctr, gen_loss_b_mean / ctr)
                print("Val DIS Loss:", dis_loss_a_mean / ctr, dis_loss_b_mean / ctr)
Example #3
import mxnet as mx
from mxnet import init, npx
from mxnet.gluon.loss import SigmoidBCELoss

from engine import train_generator
from engine import train_discriminator

# Generator and Discriminator are assumed to come from a local model module
device = npx.gpu() if npx.num_gpus() > 0 else npx.cpu()

gen = Generator()
gen.collect_params().initialize(init=init.Normal(sigma=0.02),
                                force_reinit=True,
                                ctx=device)
# noise = random.randn(1, 100, 1, 1)
# output = gen(noise)
# print(output.shape)

dis = Discriminator()
dis.collect_params().initialize(init=init.Normal(sigma=0.02),
                                force_reinit=True,
                                ctx=device)
# noise = random.randn(1, 3, 64, 64)
# output = dis(noise)
# print(output.shape)

loss_fn = SigmoidBCELoss()

dataset = mx.gluon.data.vision.datasets.ImageFolderDataset('input/')
transform_fn = mx.gluon.data.vision.transforms.Compose([
    mx.gluon.data.vision.transforms.Resize(size=(64, 64)),
    mx.gluon.data.vision.transforms.ToTensor(),
    mx.gluon.data.vision.transforms.Normalize(
        mean=[0.49139969, 0.48215842, 0.44653093],
        std=[0.20220212, 0.19931542, 0.20086347]),
])
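
The Compose transform above is built but never applied in this snippet. The usual Gluon pattern is to map it over the dataset and wrap the result in a DataLoader; the batch size below is an assumption:

batch_size = 128  # assumed value
dataloader = mx.gluon.data.DataLoader(
    dataset.transform_first(transform_fn),  # apply the transforms to the image, not the label
    batch_size=batch_size,
    shuffle=True)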
Example #4
                                optimizer_params={
                                    'learning_rate': args.lr_default,
                                    'beta1': 0.0,
                                    'beta2': 0.99
                                },
                                kvstore='local')

    # Set a different learning rate for the style layers by setting lr_mult to 0.01
    for k in generator.collect_params().keys():
        if k.startswith('hybridsequential2'):
            generator.collect_params()[k].lr_mult = 0.01

    discriminator = Discriminator(
        from_rgb_activate=not args.no_from_rgb_activate)
    discriminator.initialize(ctx=context)
    discriminator.collect_params().reset_ctx(context)

    d_optimizer = gluon.Trainer(discriminator.collect_params(),
                                optimizer='adam',
                                optimizer_params={
                                    'learning_rate': args.lr_default,
                                    'beta1': 0.0,
                                    'beta2': 0.99
                                },
                                kvstore='local')

    g_running = StyledGenerator(code_size)
    g_running.initialize(ctx=mx.gpu(0))
    g_running.collect_params().reset_ctx(mx.gpu(0))
    requires_grad(g_running, False)
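
The requires_grad helper is not defined in this snippet. In Gluon it is commonly a small function that toggles gradient bookkeeping for a block's parameters, which here keeps the running-average generator g_running frozen. A sketch (assumed, not from the source):

def requires_grad(net, flag=True):
    # 'null' tells autograd not to allocate or compute gradients for a parameter
    for param in net.collect_params().values():
        param.grad_req = 'write' if flag else 'null'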
Example #5
                              target_gv,
                              is_train=True,
                              segment_length=segment_length)

    ##############################################################################
    # Create Models
    ##############################################################################

    G_A = Generator()
    G_B = Generator()
    D_A = Discriminator()
    D_B = Discriminator()

    G_A.collect_params().initialize(ctx=contexts[0])
    G_B.collect_params().initialize(ctx=contexts[0])
    D_A.collect_params().initialize(ctx=contexts[0])
    D_B.collect_params().initialize(ctx=contexts[0])

    G_A_trainer = gluon.Trainer(
        G_A.collect_params(),
        optimizer='adam',
        optimizer_params={'learning_rate': G_learning_rate})

    G_B_trainer = gluon.Trainer(
        G_B.collect_params(),
        optimizer='adam',
        optimizer_params={'learning_rate': G_learning_rate})

    D_A_trainer = gluon.Trainer(
        D_A.collect_params(),
        optimizer='adam',