def train(epochs, iterations, batchsize, validsize, outdir, modeldir, data_path, extension, img_size, latent_dim, learning_rate, beta1, beta2, enable):
    """Train a GauGAN-style generator/discriminator pair with Chainer.

    Runs `epochs` passes, each consisting of `iterations // batchsize` (approx.)
    adversarial update steps.  When `enable` is truthy, an Encoder is also
    trained and the generator's noise input `z` is sampled from the encoder's
    predicted Gaussian (VAE-style) instead of the dataloader's noise generator.

    Args:
        epochs: number of training epochs.
        iterations: iteration budget per epoch (stepped through in chunks of
            `batchsize`; also used as the divisor when reporting mean losses).
        batchsize: minibatch size for training steps.
        validsize: number of validation samples drawn once up front.
        outdir: directory passed to the evaluator for validation outputs.
        modeldir: directory where generator snapshots are saved each epoch.
        data_path, extension, img_size, latent_dim: forwarded to DataLoader.
        learning_rate, beta1, beta2: Adam-style hyperparameters for the
            generator and discriminator optimizers.
        enable: if truthy, build/train the Encoder and use encoder-derived
            latents plus a KL regularization term.

    Side effects: saves `{modeldir}/generator_{epoch}.model` once per epoch,
    writes evaluation images via `evaluator`, and prints per-epoch losses.
    """
    # Dataset Definition
    dataloader = DataLoader(data_path, extension, img_size, latent_dim)
    print(dataloader)
    # Fixed validation batch + fixed noise, reused for every epoch's snapshot.
    color_valid, line_valid = dataloader(validsize, mode="valid")
    noise_valid = dataloader.noise_generator(validsize)

    # Model Definition
    if enable:
        encoder = Encoder()
        encoder.to_gpu()
        # NOTE(review): unlike gen_opt/dis_opt below, this call omits
        # learning_rate/beta1/beta2 — presumably set_optimizer has defaults;
        # confirm this asymmetry is intentional.
        enc_opt = set_optimizer(encoder)
    generator = Generator()
    generator.to_gpu()
    gen_opt = set_optimizer(generator, learning_rate, beta1, beta2)
    discriminator = Discriminator()
    discriminator.to_gpu()
    dis_opt = set_optimizer(discriminator, learning_rate, beta1, beta2)

    # Loss Funtion Definition
    lossfunc = GauGANLossFunction()

    # Evaluation Definition
    # NOTE(review): "Evaluaton" looks like a misspelling of "Evaluation" —
    # it must match the actual project class name; verify.
    evaluator = Evaluaton()

    for epoch in range(epochs):
        sum_dis_loss = 0
        sum_gen_loss = 0
        for batch in range(0, iterations, batchsize):
            color, line = dataloader(batchsize)
            z = dataloader.noise_generator(batchsize)

            # Discriminator update
            if enable:
                # Replace the random latent with a sample from the encoder's
                # predicted Gaussian (reparameterized via F.gaussian).
                mu, sigma = encoder(color)
                z = F.gaussian(mu, sigma)
            y = generator(z, line)
            # Detach the fake image so the discriminator loss does not
            # backpropagate into the generator (or encoder).
            y.unchain_backward()
            # Discriminator sees image+line-art pairs (conditional GAN input).
            dis_loss = lossfunc.dis_loss(discriminator, F.concat([y, line]), F.concat([color, line]))

            discriminator.cleargrads()
            dis_loss.backward()
            dis_opt.update()
            dis_loss.unchain_backward()

            sum_dis_loss += dis_loss.data

            # Generator update — fresh latent (re-sampled from the encoder
            # when enable is set, mirroring the discriminator phase).
            z = dataloader.noise_generator(batchsize)
            if enable:
                mu, sigma = encoder(color)
                z = F.gaussian(mu, sigma)
            y = generator(z, line)
            gen_loss = lossfunc.gen_loss(discriminator, F.concat([y, line]), F.concat([color, line]))
            gen_loss += lossfunc.content_loss(y, color)
            if enable:
                # VAE KL regularizer; 0.05 is a hard-coded weight.
                gen_loss += 0.05 * F.gaussian_kl_divergence(mu, sigma) / batchsize

            generator.cleargrads()
            if enable:
                encoder.cleargrads()
            gen_loss.backward()
            gen_opt.update()
            if enable:
                enc_opt.update()
            gen_loss.unchain_backward()

            sum_gen_loss += gen_loss.data

            # Once per epoch (first iteration): snapshot the generator and
            # render the fixed validation batch in inference mode.
            if batch == 0:
                serializers.save_npz(f"{modeldir}/generator_{epoch}.model", generator)
                with chainer.using_config("train", False):
                    y = generator(noise_valid, line_valid)
                # .data.get() copies GPU arrays back to host (cupy -> numpy).
                y = y.data.get()
                sr = line_valid.data.get()
                cr = color_valid.data.get()
                evaluator(y, cr, sr, outdir, epoch, validsize=validsize)

        print(f"epoch: {epoch}")
        # NOTE(review): losses are summed once per batchsize-sized step but
        # divided by `iterations`, so this is a per-sample-ish average only if
        # that convention is intended — confirm against the logging elsewhere.
        print(
            f"dis loss: {sum_dis_loss / iterations} gen loss: {sum_gen_loss / iterations}"
        )
right_of_box.append(ref) lotest = chainer.as_variable(xp.array(left_of_box).astype(xp.float32)) rotest = chainer.as_variable(xp.array(right_of_box).astype(xp.float32)) test_path = "./test.png" test, lefteye, leftlist, righteye, rightlist = prepare_test(test_path) left = chainer.as_variable(xp.array(lefteye).astype(xp.float32)).reshape( 1, 3, 32, 32) right = chainer.as_variable(xp.array(righteye).astype(xp.float32)).reshape( 1, 3, 32, 32) left = F.tile(left, (framesize, 1, 1, 1)) right = F.tile(right, (framesize, 1, 1, 1)) encoder = Encoder() encoder.to_gpu() enc_opt = set_optimizer(encoder) refine = Refine() refine.to_gpu() ref_opt = set_optimizer(refine) discriminator = Discriminator() discriminator.to_gpu() dis_opt = set_optimizer(discriminator) for epoch in range(epochs): sum_gen_loss = 0 sum_dis_loss = 0 for batch in range(0, iterations, framesize): input_box = []