def train(epochs, iterations, batchsize, validsize, outdir, modeldir,
          data_path, extension, img_size, latent_dim, learning_rate,
          beta1, beta2, enable):
    """Train a generator/discriminator pair, optionally with a latent encoder.

    Each epoch runs ``len(range(0, iterations, batchsize))`` steps. Every step
    performs one discriminator update followed by one generator (and, when
    ``enable`` is truthy, encoder) update. After the first step of every epoch
    the generator is saved to ``modeldir`` and run on a fixed validation batch,
    whose outputs are handed to the evaluator together with ``outdir``.

    Args:
        epochs: Number of epochs to run.
        iterations: Upper bound of ``range(0, iterations, batchsize)``; with
            ``batchsize`` it determines the number of steps per epoch.
        batchsize: Number of samples requested from the data loader per step.
        validsize: Number of samples in the fixed validation batch.
        outdir: Output directory forwarded to the evaluator.
        modeldir: Directory where ``generator_{epoch}.model`` is written.
        data_path, extension, img_size, latent_dim: Forwarded to ``DataLoader``.
        learning_rate, beta1, beta2: Forwarded to ``set_optimizer`` for every
            optimizer created here (generator, discriminator and encoder).
        enable: If truthy, an ``Encoder`` maps the colour batch to
            ``(mu, ln_var)`` and the training latent is sampled from that
            Gaussian instead of the loader's noise; a KL term weighted by
            ``0.05 / batchsize`` is then added to the generator loss.
    """
    # Dataset definition
    dataloader = DataLoader(data_path, extension, img_size, latent_dim)
    print(dataloader)
    color_valid, line_valid = dataloader(validsize, mode="valid")
    noise_valid = dataloader.noise_generator(validsize)

    # Model definition
    if enable:
        encoder = Encoder()
        encoder.to_gpu()
        # Same hyperparameters as the other optimizers; previously the encoder
        # silently fell back to set_optimizer's defaults.
        enc_opt = set_optimizer(encoder, learning_rate, beta1, beta2)

    generator = Generator()
    generator.to_gpu()
    gen_opt = set_optimizer(generator, learning_rate, beta1, beta2)

    discriminator = Discriminator()
    discriminator.to_gpu()
    dis_opt = set_optimizer(discriminator, learning_rate, beta1, beta2)

    # Loss function definition
    lossfunc = GauGANLossFunction()

    # Evaluation definition
    evaluator = Evaluaton()

    # One loss value is accumulated per step, so average over the number of
    # steps actually taken (not over `iterations`). Guard against 0 steps.
    steps_per_epoch = max(len(range(0, iterations, batchsize)), 1)

    for epoch in range(epochs):
        sum_dis_loss = 0
        sum_gen_loss = 0

        for batch in range(0, iterations, batchsize):
            color, line = dataloader(batchsize)
            z = dataloader.noise_generator(batchsize)

            # Discriminator update
            if enable:
                mu, ln_var = encoder(color)
                # F.gaussian takes a log-variance as its second argument.
                z = F.gaussian(mu, ln_var)
            y = generator(z, line)
            # Cut the graph behind the fake so this step does not backprop
            # into the generator/encoder.
            y.unchain_backward()

            dis_loss = lossfunc.dis_loss(
                discriminator,
                F.concat([y, line]),
                F.concat([color, line]),
            )
            discriminator.cleargrads()
            dis_loss.backward()
            dis_opt.update()
            dis_loss.unchain_backward()
            sum_dis_loss += dis_loss.data

            # Generator (and encoder) update
            z = dataloader.noise_generator(batchsize)
            if enable:
                mu, ln_var = encoder(color)
                z = F.gaussian(mu, ln_var)
            y = generator(z, line)

            gen_loss = lossfunc.gen_loss(
                discriminator,
                F.concat([y, line]),
                F.concat([color, line]),
            )
            gen_loss += lossfunc.content_loss(y, color)
            if enable:
                gen_loss += 0.05 * F.gaussian_kl_divergence(mu, ln_var) / batchsize

            generator.cleargrads()
            if enable:
                encoder.cleargrads()
            gen_loss.backward()
            gen_opt.update()
            if enable:
                enc_opt.update()
            gen_loss.unchain_backward()
            sum_gen_loss += gen_loss.data

            if batch == 0:
                serializers.save_npz(f"{modeldir}/generator_{epoch}.model", generator)
                # Inference only: test-mode layers and no graph construction.
                with chainer.using_config("train", False), chainer.no_backprop_mode():
                    y = generator(noise_valid, line_valid)
                y = y.data.get()
                sr = line_valid.data.get()
                cr = color_valid.data.get()
                evaluator(y, cr, sr, outdir, epoch, validsize=validsize)

        print(f"epoch: {epoch}")
        print(
            f"dis loss: {sum_dis_loss / steps_per_epoch} gen loss: {sum_gen_loss / steps_per_epoch}"
        )
y = encoder(F.concat([x, opt])) _, channels, height, width = y.shape y = y.reshape(1, framesize, channels, height, width).transpose(0, 2, 1, 3, 4) opt3 = opt.reshape(1, framesize, channels, height, width).transpose(0, 2, 1, 3, 4) y = refine(y) t = t.reshape(1, framesize, channels, height, width).transpose(0, 2, 1, 3, 4) gen_loss = F.mean_absolute_error(y, t) #y_dis = discriminator(y) #gen_loss+=F.mean(F.softplus(-y_dis)) encoder.cleargrads() #decoder.cleargrads() refine.cleargrads() gen_loss.backward() enc_opt.update() #dec_opt.update() ref_opt.update() gen_loss.unchain_backward() #for p in discriminator.params(): # p.data = xp.clip(p.data,-0.01,0.01) sum_gen_loss += gen_loss.data.get()