Example #1
def d_logistic_loss(
    discriminator: Discriminator,
    trues: torch.Tensor,
    fakes: torch.Tensor,
    labels: torch.Tensor,
    alpha: float,
    **kwargs  # may contain 'rlgamma', the R1 penalty weight
) -> tuple:
    rlgamma = kwargs.get('rlgamma', 10)

    d_fakes = discriminator.forward(fakes, labels, alpha)
    trues.requires_grad_()  # enable gradients on the real batch for the R1 term
    d_trues = discriminator.forward(trues, labels, alpha)
    loss = F.softplus(d_fakes).mean() + F.softplus(-d_trues).mean()
    
    if rlgamma > 0:
        # R1 regularization: penalize the squared gradient norm at real samples
        grad = torch.autograd.grad(
            d_trues.sum(),
            trues,
            create_graph=True
        )[0]
        loss += rlgamma / 2 * (grad ** 2).sum(dim=(1, 2, 3)).mean()
    
    return (loss, )
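The rlgamma branch above is the R1 gradient penalty (Mescheder et al., 2018): the squared norm of the discriminator's gradient at real samples, scaled by gamma/2. Below is a minimal, self-contained sketch of just that term; the toy discriminator toy_d and the tensor shapes are placeholders for illustration, not taken from the example above.

import torch
import torch.nn as nn

# Toy stand-in for the Discriminator (hypothetical, illustration only).
toy_d = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.Flatten(), nn.LazyLinear(1))

trues = torch.randn(4, 3, 16, 16, requires_grad=True)  # real batch with gradients enabled
scores = toy_d(trues)

# d(scores)/d(trues); create_graph=True keeps the penalty itself differentiable
grad = torch.autograd.grad(scores.sum(), trues, create_graph=True)[0]

# R1 term: gamma/2 * E[||grad||^2], summed over C, H, W and averaged over the batch
r1 = 10 / 2 * (grad ** 2).sum(dim=(1, 2, 3)).mean()
r1.backward()  # gradients now flow into toy_d's parameters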
Example #2
def d_lsgan_loss(
    discriminator: Discriminator,
    trues: torch.Tensor,
    fakes: torch.Tensor,
    labels: torch.Tensor,
    alpha: float,
    **kwargs
) -> tuple:
    """平均二乗誤差(MSELoss)に基づくDiscriminatorの損失を計算する.
    
    識別器に正解データ(訓練画像)からなるバッチと,
    生成器からの偽データからなるバッチをそれぞれ与える.
    それらの出力からMSELossを計算し, 損失の平均を取る.

    Args:
        discriminator (Discriminator): 識別器モデル
        trues (torch.Tensor): 正解データ(訓練画像)からなるバッチ.
        fakes (torch.Tensor): 生成器が生成した偽データからなるバッチ.
        labels (torch.Tensor): 分類に用いるラベルのバッチ.
        alpha (float): Style Mixingの割合.

    Returns:
        torch.Tensor: 損失(スカラ値).
    """
    d_trues = discriminator.forward(trues, labels, alpha)
    d_fakes = discriminator.forward(fakes, labels, alpha)
    
    loss_trues = F.mse_loss(d_trues, torch.ones_like(d_trues))
    loss_fakes = F.mse_loss(d_fakes, torch.zeros_like(d_fakes))
    
    loss = (loss_trues + loss_fakes) / 2
    
    return (loss, )
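All of the loss functions in these examples share the same call signature, so a training loop can swap them through a function reference. The sketch below shows the calling convention with a hypothetical stub in place of the real Discriminator; it assumes the functions above and their torch.nn.functional import (F) are in scope.

import torch

class StubDiscriminator:
    # Hypothetical stand-in with the same forward(x, labels, alpha) interface;
    # it just averages each sample's pixels into a single score.
    def forward(self, x, labels, alpha):
        return x.reshape(x.shape[0], -1).mean(dim=1, keepdim=True)

d = StubDiscriminator()
trues = torch.randn(4, 3, 8, 8)            # pretend real batch
fakes = torch.randn(4, 3, 8, 8)            # pretend generated batch
labels = torch.zeros(4, dtype=torch.long)  # dummy class labels

(loss,) = d_lsgan_loss(d, trues, fakes, labels, alpha=1.0)
print(loss.item())  # scalar discriminator loss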
Example #3
def g_logistic_loss(
    discriminator: Discriminator,
    fakes: torch.Tensor,
    labels: torch.Tensor,
    alpha: float,
    **kwargs
) -> tuple:
    d_fakes = discriminator.forward(fakes, labels, alpha)
    return (F.softplus(-d_fakes).mean(), )
Example #4
def g_wgan_loss(
    discriminator: Discriminator,
    fakes: torch.Tensor,
    labels: torch.Tensor,
    alpha: float,
    **kwargs
) -> tuple:
    d_fakes = discriminator.forward(fakes, labels, alpha)
    loss = -d_fakes.mean()
    return (loss, )
Example #5
def d_wgan_loss(
    discriminator: Discriminator,
    trues: torch.Tensor,
    fakes: torch.Tensor,
    labels: torch.Tensor,
    alpha: float,
    **kwargs
) -> tuple:
    epsilon_drift = 1e-3
    lambda_gp = 10
    
    batch_size = fakes.shape[0]
    d_trues = discriminator.forward(trues, labels, alpha)
    d_fakes = discriminator.forward(fakes, labels, alpha)
    
    loss_wd = d_trues.mean() - d_fakes.mean()
    
    # gradient penalty
    epsilon = torch.rand(
        batch_size, 1, 1, 1,
        dtype=fakes.dtype,
        device=fakes.device
    )  # shape: [batch_size, 1, 1, 1]
    intpl = epsilon * fakes + (1 - epsilon) * trues
    intpl.requires_grad_()
    
    f = discriminator.forward(intpl, labels, alpha)
    grad = torch.autograd.grad(f.sum(), intpl, create_graph=True)[0]
    grad_norm = grad.reshape(batch_size, -1).norm(dim=1)
    loss_gp = lambda_gp * ((grad_norm - 1) ** 2).mean()
    
    # drift
    loss_drift = epsilon_drift * (d_trues ** 2).mean()
    
    loss = -loss_wd + loss_gp + loss_drift
    wd = loss_wd.item()
    
    return (loss, wd)
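The gradient penalty interpolates each real/fake pair with its own mixing weight: epsilon has shape [batch_size, 1, 1, 1] so a single scalar per sample broadcasts across all channels and pixels. A quick standalone illustration (the shapes here are arbitrary):

import torch

batch_size = 4
fakes = torch.randn(batch_size, 3, 8, 8)
trues = torch.randn(batch_size, 3, 8, 8)

epsilon = torch.rand(batch_size, 1, 1, 1)        # one mixing weight per sample
intpl = epsilon * fakes + (1 - epsilon) * trues  # broadcasts to [4, 3, 8, 8]
assert intpl.shape == fakes.shape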
Example #6
def g_lsgan_loss(
    discriminator: Discriminator,
    fakes: torch.Tensor,
    labels: torch.Tensor,
    alpha: float,
    **kwargs
) -> tuple:
    d_fakes = discriminator.forward(fakes, labels, alpha)
    
    loss = F.mse_loss(d_fakes, torch.ones_like(d_fakes)) / 2
    
    return (loss, )
Example #7
                noise_sample = Variable(noise(len(sample_data), seq_length))

                # If the generator also returns hidden states, use: dis_fake_data, (h_g_n, c_g_n) = generator.forward(noise_sample, h_g)
                dis_fake_data = generator.forward(noise_sample, h_g).detach()

                y_pred_fake, (h_d_n, c_d_n) = discriminator(dis_fake_data, h_d)

                loss_fake = loss(y_pred_fake,
                                 torch.zeros([len(sample_data), 1]).cuda())
                loss_fake.backward()

                #Train discriminator on real data
                h_d = discriminator.init_hidden()
                real_data = Variable(sample_data.float()).cuda()
                y_pred_real, (h_d_n,
                              c_d_n) = discriminator.forward(real_data, h_d)

                loss_real = loss(y_pred_real,
                                 torch.ones([len(sample_data), 1]).cuda())
                loss_real.backward()

                d_optimizer.step()  # update the weights using the gradients from both the real and fake batches

            for g in range(G_rounds):
                #Train Generator
                generator.zero_grad()
                h_g = generator.init_hidden()
                h_d = discriminator.init_hidden()
                noise_sample = Variable(noise(len(sample_data), seq_length))
Example #8
class Pipeline():
    def __init__(self):
        self.DBpath = 'D:/processed/'
        self.graphPath = ''
        self.modelPath = 'D:/GANSProject/model/'
        self.imgDim = 128
        self.batchSize = 4
        self.numClass = 5
        self.lambdaCls = 1
        self.lambdaRec = 10
        self.lambdaGp = 10

        self.g_skip_step = 5
        self.epochs = 10
        self.epochsDecay = 10
        self.logstep = 10
        self.epochsSave = 2
        self.learningRateD = 0.00001
        self.learningRateG = 0.00001
        self.lrDecaysD = np.linspace(self.learningRateD, 0, self.epochs - self.epochsDecay + 2)[1:]
        self.lrDecaysG = np.linspace(self.learningRateG, 0, self.epochs - self.epochsDecay + 2)[1:]
        self.clipD = 0.005
        self.sample_step = 50
        self.save_step = 5000

    def init_model(self):
        #if not os.path.exists(self.modelPath) or os.listdir(self.modelPath) == []:
        # Create the whole training graph
        self.realX = tf.placeholder(tf.float32, [None, self.imgDim, self.imgDim, 3], name="realX")
        self.realLabels = tf.placeholder(tf.float32, [None, self.numClass], name="realLabels")
        self.realLabelsOneHot = tf.placeholder(tf.float32, [None, self.imgDim, self.imgDim, self.numClass], name="realLabelsOneHot")
        self.fakeLabels = tf.placeholder(tf.float32, [None, self.numClass], name="fakeLabels")
        self.fakeLabelsOneHot = tf.placeholder(tf.float32, [None, self.imgDim, self.imgDim, self.numClass], name="fakeLabelsOneHot")
        self.alphagp = tf.placeholder(tf.float32, [], name="alphagp")


        # Initialize the generator and discriminator
        self.Gen = Generator()
        self.Dis = Discriminator()



        # -----------------------------------------------------------------------------------------
        # -----------------------------------Create D training pipeline----------------------------
        # -----------------------------------------------------------------------------------------

        # Create fake image
        self.fakeX = self.Gen.recForward(self.realX, self.fakeLabelsOneHot)
        YSrc_real, YCls_real = self.Dis.forward(self.realX)
        YSrc_fake, YCls_fake = self.Dis.forward(self.fakeX)

        YCls_real = tf.squeeze(YCls_real)  # remove void dimensions
        self.d_loss_real = - tf.reduce_mean(YSrc_real)
        self.d_loss_fake = tf.reduce_mean(YSrc_fake)
        self.d_loss_cls = tf.reduce_sum(tf.nn.sigmoid_cross_entropy_with_logits(labels=self.realLabels,logits=YCls_real, name="d_loss_cls")) / self.batchSize




        #TOTAL LOSS
        self.d_loss = self.d_loss_real + self.d_loss_fake + self.lambdaCls * self.d_loss_cls #+ self.d_loss_gp
        trainable_vars = tf.trainable_variables()
        self.d_params = [v for v in trainable_vars if v.name.startswith('Discriminator/')]
        train_D = tf.train.AdamOptimizer(learning_rate=self.learningRateD, beta1=0.5, beta2=0.999)
        self.train_D_loss = train_D.minimize(self.d_loss, var_list=self.d_params)
        # gvs = self.train_D.compute_gradients(self.d_loss, var_list=self.d_params)
        # capped_gvs = [(tf.clip_by_value(grad, -1., 1.), var) for grad, var in gvs]
        # self.train_D_loss = self.train_D.apply_gradients(capped_gvs)

        #-------------GRADIENT PENALTY---------------------------
        interpolates = self.alphagp * self.realX + (1 - self.alphagp) * self.fakeX
        out,_ = self.Dis.forward(interpolates)
        gradients = tf.gradients(out, [interpolates])[0]
        slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), axis=[1,2,3]))
        _gradient_penalty = tf.reduce_mean(tf.square(slopes - 1.0))
        self.d_loss_gp   = self.lambdaGp * _gradient_penalty
        self.train_D_gp = train_D.minimize(self.d_loss_gp, var_list=self.d_params)
        # gvs = self.train_D.compute_gradients(self.d_loss_gp)
        # capped_gvs = [(ClipIfNotNone(grad), var) for grad, var in gvs]
        # self.train_D_gp = self.train_D.apply_gradients(capped_gvs)
        #-------------------------------------------------------------------------------

        #-----------------accuracy--------------------------------------------------------------
        YCls_real_sigmoid = tf.sigmoid(YCls_real)
        predicted = tf.to_int32(YCls_real_sigmoid > 0.5)
        labels = tf.to_int32(self.realLabels)
        correct = tf.to_float(tf.equal(predicted, labels))
        hundred = tf.constant(100.0)
        self.accuracy = tf.reduce_mean(tf.cast(correct, tf.float32), axis=0) * hundred
        #--------------------------------------------------------------------------------------


        #CLIP D WEIGHTS
        #self.clip_D = [p.assign(tf.clip_by_value(p, -self.clipD, self.clipD)) for p in self.d_params]


        # -----------------------------------------------------------------------------------------
        # ----------------------------Create G training pipeline-----------------------------------
        # -----------------------------------------------------------------------------------------
        #original to target and target to original domain
        #self.fakeX = self.Gen.recForward(self.realX, self.fakeLabelsOneHot)
        rec_x = self.Gen.recForward(self.fakeX,self.realLabelsOneHot)

        # compute losses
        #out_src, out_cls = self.Dis.forward(self.fakeX)
        self.g_loss_adv = - tf.reduce_mean(YSrc_fake)
        self.g_loss_cls = tf.reduce_sum(tf.nn.sigmoid_cross_entropy_with_logits(labels=self.fakeLabels,logits=tf.squeeze(YCls_fake))) / self.batchSize

        self.g_loss_rec = tf.reduce_mean(tf.abs(self.realX - rec_x))
        # total G loss and optimize
        self.g_loss = self.g_loss_adv + self.lambdaCls * self.g_loss_cls + self.lambdaRec * self.g_loss_rec
        train_G = tf.train.AdamOptimizer(learning_rate=self.learningRateG, beta1=0.5, beta2=0.999)
        self.g_params = [v for v in trainable_vars if v.name.startswith('Generator/')]

        self.train_G_loss = train_G.minimize(self.g_loss, var_list=self.g_params)
        # gvs = self.train_G.compute_gradients(self.g_loss)
        # capped_gvs = [(tf.clip_by_value(grad, -1., 1.), var) for grad, var in gvs]
        # self.train_G_loss = self.train_G.apply_gradients(capped_gvs)

        #TF session

        self.saver = tf.train.Saver()
        self.init = tf.global_variables_initializer()
        self.sess = tf.Session()
        self.sess.run(self.init)


        # Restore the model if a checkpoint exists; the session created above
        # has already run the initializer, so a fresh start needs no extra setup.
        if os.listdir(self.modelPath) == []:
            self.epoch_index = 1
            self.picture = 0
        else:
            checkpoint = tf.train.latest_checkpoint(self.modelPath)
            self.saver.restore(self.sess, checkpoint)
            # Recover the batch counter and epoch index from the checkpoint name,
            # e.g. ".../model/model4999_2" -> picture 4999, epoch 2.
            model_info = checkpoint.split("model/model", 1)[1].split("_", 1)
            self.picture = int(model_info[0])
            self.epoch_index = int(model_info[1])


        #writer = tf.summary.FileWriter(self.graphPath, graph=tf.get_default_graph())

    def train(self):
        # -----------------------------------------------------------------------------------------
        # -----------------------------------START TRAINING----------------------------------------
        # -----------------------------------------------------------------------------------------
        images = []
        trueLabels = []
        gloss = 0
        g_loss_cls = 0
        g_loss_rec = 0
        d_loss_real = 0
        # test image
        img_test = normalize(sci.imread(os.path.join(self.DBpath, "002000_[1, 0, 0, 1, 0].jpg")))
        #002000_[1, 0, 0, 1, 0]
        labels = np.array([0, 1, 0, 1, 0])
        img_test = np.stack([img_test])
        #----------------------------------------------
        # -----------------------------------------------------------------------------------------
        # -----------------------------------BEGIN TRAINING----------------------------------------
        # -----------------------------------------------------------------------------------------
        ##DB
        filenames = os.listdir(self.DBpath)
        for e in range(self.epoch_index,self.epochs):
            train_subset = filenames[self.batchSize*self.picture:]
            for i, filename in enumerate(train_subset):
                img = sci.imread(os.path.join(self.DBpath, filename))
                splits = filename.split('_')
                trueLabels.append(literal_eval(splits[1].split('.')[0]))

                # Normalization and random flip (data augmentation)
                img = normalize(img)
                if random.random() > 0.5:
                    img = np.fliplr(img)

                images.append(img)
                #print(filename, np.mean(img))
                if len(images) % self.batchSize == 0:

                    # Create fake labels and associated images
                    randomLabels = np.random.randint(2, size=(self.batchSize, self.numClass))
                    # realX_fakeLabels has to contain the original image with the labels to generate concatenated
                    imagesWithFakeLabels = stackLabels(images, randomLabels)
                    epsilon = np.random.rand()
                    # -----------------------------------------------------------------------------------------
                    # -----------------------------------TRAIN DISCRIMINATOR-----------------------------------
                    # -----------------------------------------------------------------------------------------
                    #print("training discriminator...")

                    alpha = np.random.uniform(low=0, high=1.0)
                    _ = self.sess.run([self.train_D_loss],
                                            feed_dict={
                                                       self.realLabels: trueLabels,
                                                       self.realLabelsOneHot: stackLabelsOnly(trueLabels),
                                                       self.fakeLabels: randomLabels,
                                                       self.fakeLabelsOneHot: stackLabelsOnly(randomLabels),
                                                       self.realX: np.stack(images),
                                                       self.alphagp: alpha,
                                                       })

                    #GRADIENT PENALTY
                    _ = self.sess.run([self.train_D_gp],
                                            feed_dict={
                                                       self.realLabels: trueLabels,
                                                       self.realLabelsOneHot: stackLabelsOnly(trueLabels),
                                                       self.fakeLabels: randomLabels,
                                                       self.fakeLabelsOneHot: stackLabelsOnly(randomLabels),
                                                       self.realX: np.stack(images),
                                                       self.alphagp: alpha,
                                                       })
                    #CLIPPING
                    # _ = self.sess.run([self.clip_D])



                    if (self.picture+1) % self.g_skip_step == 0:
                    #if np.abs(d_loss_real) > np.abs(g_loss_adv):

                        # -----------------------------------------------------------------------------------------
                        # -----------------------------------TRAIN GENERATOR---------------------------------------
                        # -----------------------------------------------------------------------------------------
                        #print("training generator...")
                        _, = self.sess.run([self.train_G_loss],
                                                feed_dict={
                                                           self.realLabelsOneHot: stackLabelsOnly(trueLabels),
                                                           self.realX: np.stack(images),
                                                           self.fakeLabels: randomLabels,
                                                           self.fakeLabelsOneHot: stackLabelsOnly(randomLabels)})


                    # -----------------------------------------------------------------------------------------
                    # -----------------------------------PRINTING LOSSES---------------------------------------
                    # -----------------------------------------------------------------------------------------
                    #if (self.picture + 1) % self.logstep == 0:
                    print("printing losses...")
                    dloss, d_loss_real, d_loss_fake, d_loss_gp, d_loss_cls, accuracy, gloss, g_loss_adv, g_loss_cls, g_loss_rec = self.sess.run(
                                            [
                                             self.d_loss,
                                             self.d_loss_real,
                                             self.d_loss_fake,
                                             self.d_loss_gp,
                                             self.d_loss_cls,
                                             self.accuracy,
                                             self.g_loss,
                                             self.g_loss_adv,
                                             self.g_loss_cls,
                                             self.g_loss_rec,
                                             ],
                                            feed_dict={
                                                       self.realLabels: trueLabels,
                                                       self.realLabelsOneHot: stackLabelsOnly(trueLabels),
                                                       self.realX: np.stack(images),
                                                       self.fakeLabels: randomLabels,
                                                       self.fakeLabelsOneHot: stackLabelsOnly(randomLabels),
                                                       self.alphagp: alpha
                                                       })



                    #print("Loss = " , dloss + gloss, " ", "Dloss = " , dloss, " ", "Gloss = ", gloss, "Epoch =", e)
                    #print("Dloss = " , dloss, " d_loss_real: ", d_loss_real, " d_loss_fake: ", d_loss_fake, " gradient penalty: ", _gradient_penalty, " d_loss_cls: ", d_loss_cls)
                    #print("YCls_real: ")
                    #print(YCls_real)
                    #print([np.mean(i) for i in gradLoss])
                    print("---------------------------")
                    print(self.batchSize, " Batch: ", self.picture, "Accuracy: ",accuracy, "Dloss = " , dloss, " d_loss_real: ", d_loss_real, " d_loss_fake: ", d_loss_fake, " d_loss_gp: ", d_loss_gp, " d_loss_cls: ", d_loss_cls)
                    print(self.batchSize, " Batch: ", self.picture, "Accuracy: ",accuracy, "Gloss = ", gloss, "g_loss_Adv: ", g_loss_adv, "g_loss_cls: ", g_loss_cls*self.lambdaCls, "g_loss_rec: ", g_loss_rec*self.lambdaRec, "epoch: ", e)

                    #print("Picture: ", i, "Accuracy: ",accuracy, "Loss = " , dloss + gloss, " ", "Dloss = " , dloss, " ", "Gloss = ", gloss, "Epoch =", e)

                    print("---------------------------")

                    # RESET
                    images = []
                    trueLabels = []
                    self.picture += 1
                #save images

                if (self.picture+1) % self.sample_step == 0:
                    generatedImage = np.squeeze(self.sess.run([self.fakeX], feed_dict={self.realX: img_test,self.fakeLabelsOneHot: stackLabelsOnly([labels])}), axis=0)
                    sci.imsave('D:/GANSProject/samples/outfile' + str(self.picture) + "_" + str(e) + '.jpg', denormalize(generatedImage))

                # Save model every 5k batches
                if (self.picture + 1) % self.save_step == 0:
                    if not os.path.exists(self.modelPath):
                        os.makedirs(self.modelPath)
                    save_path = self.saver.save(self.sess, os.path.join(self.modelPath, "model" + str(self.picture) + "_" + str(e)))
                    print("Model saved in file: %s" % save_path)
            #set images index to 0 to start iterating the DB again
            self.picture = 0




    def test(self, labels=None, img=None):
        if not os.path.exists(self.modelPath):
            print("The model does not exit")
            return

        with tf.Session() as sess:
            # Restore model weights from previously saved model
            folder = os.path.dirname(os.path.normpath(self.modelPath))
            saver = tf.train.import_meta_graph(os.path.join(folder,'model.ckpt.meta'))
            saver.restore(sess, tf.train.latest_checkpoint(folder))

            if img is None:
                img = sci.imread(os.path.join(self.DBpath, "000001_[0, 0, 1, 0, 1].jpg"))
            if labels is None:
                labels = np.random.randint(2, size=(1, 5))

            # Get tensors
            recov_fakeX = sess.graph.get_tensor_by_name("fakeX:0")
            recov_realX = sess.graph.get_tensor_by_name("realX:0")
            recov_fakeLabelsOneHot = sess.graph.get_tensor_by_name("fakeLabelsOneHot:0")

            img = np.reshape(img, (1, 128, 128, 3))
            generatedImage = np.squeeze(sess.run([recov_fakeX], feed_dict={recov_realX: np.stack(img),
                                                                                recov_fakeLabelsOneHot: stackLabelsOnly(
                                                                                   labels)}), axis=0)

            # sci.imsave('img.jpg', img)
            # img = normalize(img)
            # sci.imsave('out.jpg', denormalize(img))

            sci.imsave('outfile1.jpg', denormalize(generatedImage))


    def random_samples(self):

        filenames = os.listdir(self.DBpath)
        random_pics_idx = np.random.randint(low=0, high=len(filenames), size=10)
        rows = []
        for e in random_pics_idx:
            img = np.stack([normalize(sci.imread(os.path.join(self.DBpath, filenames[e])))])
            splits = filenames[e].split('_')
            labels = literal_eval(splits[1].split('.')[0])

            row_images = []
            row_images.append(denormalize(np.squeeze(img)))
            for j in range(0,len(labels)):
                fakeLabels = np.copy(labels)
                if j < 3: # hair label
                    if fakeLabels[j] == 0:
                        fakeLabels[0:3] = [0]*3
                        fakeLabels[j] = 1
                else:
                    fakeLabels[j] = 0 if fakeLabels[j]==1 else 1

                generatedImage = np.squeeze(self.sess.run([self.fakeX], feed_dict={self.realX: img,self.fakeLabelsOneHot: stackLabelsOnly([fakeLabels])}), axis=0)
                row_images.append(denormalize(generatedImage))

            row = np.concatenate(row_images, axis=1)
            rows.append(row)
        samples = np.concatenate(rows, axis=0)

        sci.imsave('D:/GANSProject/samples/random_samples.jpg', samples)
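For completeness, a hedged sketch of how this class appears intended to be driven; the entry point and call order below are assumptions inferred from the methods, not shown in the source.

# Assumed driver (hypothetical): build the TF1 graph, then train.
if __name__ == '__main__':
    pipeline = Pipeline()
    pipeline.init_model()        # builds the graph, opens a session, restores any checkpoint
    pipeline.train()             # runs the D/G training loop
    # pipeline.random_samples()  # optional: writes a grid of label-swapped samples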