예제 #1
0
def train(epochs, iterations, batchsize, validsize, outdir, modeldir,
          data_path, extension, img_size, latent_dim, learning_rate, beta1,
          beta2, enable):
    """Run adversarial training of the generator and discriminator.

    Every step first updates the discriminator on a generated image that has
    been detached from the generator, then updates the generator (and, when
    ``enable`` is truthy, the encoder) with the adversarial loss, the content
    loss and an optional KL term. At the first step of every epoch the
    generator is serialized and evaluated on a fixed validation batch.

    Args:
        epochs: Number of iterations of the outer loop.
        iterations: Upper bound of the inner ``range``. The loop advances in
            strides of ``batchsize``, so ``len(range(0, iterations,
            batchsize))`` update steps run per epoch.
        batchsize: Batch size requested from the data loader and stride of
            the inner loop.
        validsize: Size of the validation batch requested from the data
            loader; also forwarded to the evaluator.
        outdir: Forwarded to the evaluator.
        modeldir: Directory in which ``generator_{epoch}.model`` is written.
        data_path: Forwarded to ``DataLoader``.
        extension: Forwarded to ``DataLoader``.
        img_size: Forwarded to ``DataLoader``.
        latent_dim: Forwarded to ``DataLoader``.
        learning_rate: Forwarded to ``set_optimizer`` for every model.
        beta1: Forwarded to ``set_optimizer`` for every model.
        beta2: Forwarded to ``set_optimizer`` for every model.
        enable: When truthy, an ``Encoder`` is trained alongside the
            generator and the latent code is sampled from its output instead
            of from the data loader's noise generator.

    Returns:
        None. Progress is reported with ``print``.
    """

    # Dataset Definition
    dataloader = DataLoader(data_path, extension, img_size, latent_dim)
    print(dataloader)
    color_valid, line_valid = dataloader(validsize, mode="valid")
    noise_valid = dataloader.noise_generator(validsize)

    # Model Definition
    if enable:
        encoder = Encoder()
        encoder.to_gpu()
        # Use the same hyperparameters as the generator and discriminator
        # instead of silently falling back to set_optimizer's defaults.
        enc_opt = set_optimizer(encoder, learning_rate, beta1, beta2)

    generator = Generator()
    generator.to_gpu()
    gen_opt = set_optimizer(generator, learning_rate, beta1, beta2)

    discriminator = Discriminator()
    discriminator.to_gpu()
    dis_opt = set_optimizer(discriminator, learning_rate, beta1, beta2)

    # Loss Function Definition
    lossfunc = GauGANLossFunction()

    # Evaluation Definition
    evaluator = Evaluaton()

    # The inner loop strides by batchsize, so the number of update steps per
    # epoch is the length of this range, not `iterations`.
    batch_starts = range(0, iterations, batchsize)
    # Guard against an empty range so the epoch summary never divides by zero.
    steps_per_epoch = max(len(batch_starts), 1)

    for epoch in range(epochs):
        sum_dis_loss = 0
        sum_gen_loss = 0
        for batch in batch_starts:
            color, line = dataloader(batchsize)
            z = dataloader.noise_generator(batchsize)

            # Discriminator update
            if enable:
                # Chainer's F.gaussian treats its second argument as the
                # log-variance of the distribution.
                mu, ln_var = encoder(color)
                z = F.gaussian(mu, ln_var)
            y = generator(z, line)

            # Detach the fake image so the discriminator step does not
            # backpropagate into the generator (or encoder).
            y.unchain_backward()

            dis_loss = lossfunc.dis_loss(discriminator, F.concat([y, line]),
                                         F.concat([color, line]))

            discriminator.cleargrads()
            dis_loss.backward()
            dis_opt.update()
            dis_loss.unchain_backward()

            sum_dis_loss += dis_loss.data

            # Generator update
            z = dataloader.noise_generator(batchsize)

            if enable:
                mu, ln_var = encoder(color)
                z = F.gaussian(mu, ln_var)
            y = generator(z, line)

            gen_loss = lossfunc.gen_loss(discriminator, F.concat([y, line]),
                                         F.concat([color, line]))
            gen_loss += lossfunc.content_loss(y, color)

            if enable:
                # KL regulariser on the encoder output, weighted by 0.05 and
                # normalised by the batch size.
                gen_loss += 0.05 * F.gaussian_kl_divergence(
                    mu, ln_var) / batchsize

            generator.cleargrads()
            if enable:
                encoder.cleargrads()
            gen_loss.backward()
            gen_opt.update()
            if enable:
                enc_opt.update()
            gen_loss.unchain_backward()

            sum_gen_loss += gen_loss.data

            if batch == 0:
                # Snapshot and evaluate once per epoch, on the first step.
                serializers.save_npz(f"{modeldir}/generator_{epoch}.model",
                                     generator)

                # Inference only: switch off train mode and skip building the
                # backward graph to save GPU memory.
                with chainer.using_config("train", False), \
                        chainer.no_backprop_mode():
                    y_valid = generator(noise_valid, line_valid)
                y_valid = y_valid.data.get()
                sr = line_valid.data.get()
                cr = color_valid.data.get()

                evaluator(y_valid, cr, sr, outdir, epoch, validsize=validsize)

        mean_dis_loss = sum_dis_loss / steps_per_epoch
        mean_gen_loss = sum_gen_loss / steps_per_epoch
        print(f"epoch: {epoch}")
        print(
            f"dis loss: {mean_dis_loss} gen loss: {mean_gen_loss}"
        )
예제 #2
0
    right_of_box.append(ref)

# Wrap the collected left/right box lists as float32 Chainer variables.
lotest = chainer.as_variable(xp.array(left_of_box).astype(xp.float32))
rotest = chainer.as_variable(xp.array(right_of_box).astype(xp.float32))

# Load the test image; prepare_test is defined outside this excerpt.
# NOTE(review): judging by the names, lefteye/righteye look like eye crops
# and leftlist/rightlist their locations -- confirm against prepare_test.
test_path = "./test.png"
test, lefteye, leftlist, righteye, rightlist = prepare_test(test_path)
# Convert each crop to a float32 variable reshaped to (1, 3, 32, 32).
left = chainer.as_variable(xp.array(lefteye).astype(xp.float32)).reshape(
    1, 3, 32, 32)
right = chainer.as_variable(xp.array(righteye).astype(xp.float32)).reshape(
    1, 3, 32, 32)
# Repeat each crop `framesize` times along the first axis.
left = F.tile(left, (framesize, 1, 1, 1))
right = F.tile(right, (framesize, 1, 1, 1))

# Build each model, move it to the GPU and attach an optimizer created by
# set_optimizer called with the model only (no extra hyperparameters).
encoder = Encoder()
encoder.to_gpu()
enc_opt = set_optimizer(encoder)

refine = Refine()
refine.to_gpu()
ref_opt = set_optimizer(refine)

discriminator = Discriminator()
discriminator.to_gpu()
dis_opt = set_optimizer(discriminator)

for epoch in range(epochs):
    sum_gen_loss = 0
    sum_dis_loss = 0
    for batch in range(0, iterations, framesize):
        input_box = []