Beispiel #1
0
def train(epochs, iterations, batchsize, validsize, outdir, modeldir,
          data_path, extension, img_size, latent_dim, learning_rate, beta1,
          beta2, enable):
    """Adversarially train a generator/discriminator pair, optionally with an encoder.

    Each step first updates the discriminator on a generated image (detached
    from the generator graph) versus the real colour image, both concatenated
    with the line image, and then updates the generator (and the encoder when
    ``enable`` is set) with the adversarial loss plus a content loss.

    On the first step of every epoch the generator is saved to
    ``{modeldir}/generator_{epoch}.model`` and the fixed validation batch is
    passed to the evaluator.

    Args:
        epochs: Number of epochs to run.
        iterations: Upper bound of the per-epoch counter, which advances in
            strides of ``batchsize`` (presumably the number of training
            samples consumed per epoch -- confirm with the caller).
        batchsize: Number of samples requested from the data loader per step.
        validsize: Number of samples in the fixed validation batch.
        outdir: Output location forwarded to the evaluator.
        modeldir: Directory in which generator snapshots are written.
        data_path: Forwarded to ``DataLoader``.
        extension: Forwarded to ``DataLoader``.
        img_size: Forwarded to ``DataLoader``.
        latent_dim: Forwarded to ``DataLoader``.
        learning_rate: Forwarded to ``set_optimizer`` for every model.
        beta1: Forwarded to ``set_optimizer`` for every model.
        beta2: Forwarded to ``set_optimizer`` for every model.
        enable: When truthy, an ``Encoder`` is trained as well and the latent
            code is sampled from its output instead of from the data loader's
            noise generator; a KL term weighted by 0.05 is added to the
            generator loss.
    """
    # Dataset definition
    dataloader = DataLoader(data_path, extension, img_size, latent_dim)
    print(dataloader)
    # Fixed validation batch, reused for every evaluation so that the
    # per-epoch outputs are comparable.
    color_valid, line_valid = dataloader(validsize, mode="valid")
    noise_valid = dataloader.noise_generator(validsize)

    # Model definition
    if enable:
        encoder = Encoder()
        encoder.to_gpu()
        # Use the same hyperparameters as the generator and discriminator;
        # previously the encoder optimizer ignored the configured values.
        enc_opt = set_optimizer(encoder, learning_rate, beta1, beta2)

    generator = Generator()
    generator.to_gpu()
    gen_opt = set_optimizer(generator, learning_rate, beta1, beta2)

    discriminator = Discriminator()
    discriminator.to_gpu()
    dis_opt = set_optimizer(discriminator, learning_rate, beta1, beta2)

    # Loss function definition
    lossfunc = GauGANLossFunction()

    # Evaluation definition
    evaluator = Evaluaton()

    for epoch in range(epochs):
        sum_dis_loss = 0
        sum_gen_loss = 0
        # Number of optimisation steps actually executed in this epoch; the
        # counter below advances by `batchsize`, so this is not `iterations`.
        num_steps = 0
        for batch in range(0, iterations, batchsize):
            num_steps += 1
            color, line = dataloader(batchsize)
            z = dataloader.noise_generator(batchsize)

            # Discriminator update
            if enable:
                # NOTE(review): Chainer's F.gaussian takes (mean, ln_var), so
                # `sigma` is presumably a log-variance -- confirm in Encoder.
                mu, sigma = encoder(color)
                z = F.gaussian(mu, sigma)
            y = generator(z, line)

            # Cut the graph so the discriminator loss does not backpropagate
            # into the generator (or the encoder).
            y.unchain_backward()

            dis_loss = lossfunc.dis_loss(discriminator, F.concat([y, line]),
                                         F.concat([color, line]))

            discriminator.cleargrads()
            dis_loss.backward()
            dis_opt.update()
            dis_loss.unchain_backward()

            sum_dis_loss += dis_loss.data

            # Generator update (a fresh latent code is drawn for this phase)
            z = dataloader.noise_generator(batchsize)

            if enable:
                mu, sigma = encoder(color)
                z = F.gaussian(mu, sigma)
            y = generator(z, line)

            gen_loss = lossfunc.gen_loss(discriminator, F.concat([y, line]),
                                         F.concat([color, line]))
            gen_loss += lossfunc.content_loss(y, color)

            if enable:
                # KL regulariser on the encoder output, normalised by the
                # batch size.
                gen_loss += 0.05 * F.gaussian_kl_divergence(mu,
                                                            sigma) / batchsize

            generator.cleargrads()
            if enable:
                encoder.cleargrads()
            gen_loss.backward()
            gen_opt.update()
            if enable:
                enc_opt.update()
            gen_loss.unchain_backward()

            sum_gen_loss += gen_loss.data

            if batch == 0:
                # Snapshot and evaluate once per epoch, on its first step.
                serializers.save_npz(f"{modeldir}/generator_{epoch}.model",
                                     generator)

                # Inference only: disable backprop as well as train mode so
                # no computational graph is retained for the validation pass.
                with chainer.using_config("train", False), \
                        chainer.no_backprop_mode():
                    y = generator(noise_valid, line_valid)
                # Copy the results from the GPU to host memory.
                y = y.data.get()
                sr = line_valid.data.get()
                cr = color_valid.data.get()

                evaluator(y, cr, sr, outdir, epoch, validsize=validsize)

        # Average over the steps that actually ran; guard against an empty
        # epoch (iterations <= 0) to avoid a ZeroDivisionError.
        denom = max(num_steps, 1)
        print(f"epoch: {epoch}")
        print(
            f"dis loss: {sum_dis_loss / denom} gen loss: {sum_gen_loss / denom}"
        )
Beispiel #2
0
        y = encoder(F.concat([x, opt]))

        _, channels, height, width = y.shape
        y = y.reshape(1, framesize, channels, height,
                      width).transpose(0, 2, 1, 3, 4)
        opt3 = opt.reshape(1, framesize, channels, height,
                           width).transpose(0, 2, 1, 3, 4)
        y = refine(y)

        t = t.reshape(1, framesize, channels, height,
                      width).transpose(0, 2, 1, 3, 4)
        gen_loss = F.mean_absolute_error(y, t)
        #y_dis = discriminator(y)
        #gen_loss+=F.mean(F.softplus(-y_dis))

        encoder.cleargrads()
        #decoder.cleargrads()
        refine.cleargrads()

        gen_loss.backward()

        enc_opt.update()
        #dec_opt.update()
        ref_opt.update()

        gen_loss.unchain_backward()

        #for p in discriminator.params():
        #    p.data = xp.clip(p.data,-0.01,0.01)

        sum_gen_loss += gen_loss.data.get()