def train(epochs, iterations, dataset_path, test_path, outdir, batchsize,
          testsize, recon_weight, fm_weight, gp_weight, spectral_norm=False):
    """Train a FUNIT-style generator/discriminator pair.

    Each epoch performs one discriminator update followed by one generator
    update for every step of ``range(0, iterations, batchsize)``. On the
    first step of each epoch, both models are written to
    ``generator.model`` / ``discriminator.model`` (relative paths, so they
    land in the working directory and are overwritten every epoch). The
    generator is then run on a fixed validation batch, and the results are
    handed to the evaluator.

    Args:
        epochs: Number of epochs to run.
        iterations: Upper bound of the step range. The loop advances by
            ``batchsize``, so an epoch has ``ceil(iterations / batchsize)``
            steps. Also used as the divisor of the printed loss.
        dataset_path: Training data location, passed to ``DatasetLoader``.
        test_path: Validation data location, passed to ``DatasetLoader``.
        outdir: Output location forwarded to the evaluator.
        batchsize: Number of samples requested from the loader per step.
        testsize: Size of the validation batch drawn once up front.
        recon_weight: Weight of the reconstruction term in the generator loss.
        fm_weight: Weight of the feature-matching term in the generator loss.
        gp_weight: Accepted but not used in this body.
            NOTE(review): the gradient penalty is added unscaled. Confirm
            whether ``FUNITLossFunction.gradient_penalty`` weights it
            internally.
        spectral_norm: Build ``SNGenerator`` instead of ``Generator``.
    """
    loader = DatasetLoader(dataset_path, test_path)
    content_val, style_val = loader.test(testsize)

    gen = SNGenerator() if spectral_norm else Generator()
    gen.to_gpu()
    gen_optimizer = set_optimizer(gen)

    dis = Discriminator()
    dis.to_gpu()
    dis_optimizer = set_optimizer(dis)

    criterion = FUNITLossFunction()
    evaluate = Evaluation()

    for epoch in range(epochs):
        running_loss = 0
        for step in range(0, iterations, batchsize):
            content, content_label, style, style_label = loader.train(batchsize)

            # Discriminator step. The fake batch is cut from the graph so
            # backprop stops at the generator output.
            fake = gen(content, style)
            fake.unchain_backward()
            dis_loss = (criterion.dis_loss(dis, fake, style, style_label)
                        + criterion.gradient_penalty(dis, style, fake, style_label))
            dis.cleargrads()
            dis_loss.backward()
            dis_optimizer.update()
            dis_loss.unchain_backward()

            # Generator step: translation (content + style) plus
            # self-reconstruction (content + content).
            translated = gen(content, style)
            reconstructed = gen(content, content)
            adv_term, recon_term, fm_term = criterion.gen_loss(
                dis, translated, reconstructed, style, content,
                style_label, content_label)
            gen_loss = adv_term + recon_weight * recon_term + fm_weight * fm_term
            gen.cleargrads()
            gen_loss.backward()
            gen_optimizer.update()
            gen_loss.unchain_backward()

            running_loss += gen_loss.data

            if step == 0:
                # Snapshot and preview once per epoch.
                serializers.save_npz('generator.model', gen)
                serializers.save_npz('discriminator.model', dis)
                with chainer.using_config('train', False):
                    preview = gen(content_val, style_val)
                    preview.unchain_backward()
                    # .data.get() copies the arrays back to host memory
                    # before evaluation.
                    evaluate(preview.data.get(), content_val.data.get(),
                             style_val.data.get(), outdir, epoch, testsize)

        # NOTE(review): divides by `iterations`, not by the number of steps.
        # If gen_loss is a per-batch mean, this under-reports by roughly
        # `batchsize`. Confirm intent.
        print(f"epoch: {epoch}")
        print(f"loss: {running_loss / iterations}")
def train(epochs, iterations, batchsize, testsize, img_path, seg_path,
          outdir, modeldir, n_dis, mode):
    """Train an SGAN generator that emits an image and a segmentation map.

    Each step of ``range(0, iterations, batchsize)`` runs ``n_dis``
    discriminator updates followed by one generator update. On the first
    step of each epoch, both models are saved to
    ``{modeldir}/generator_{epoch}.model`` and
    ``{modeldir}/discriminator_{epoch}.model``. The generator is then run
    on a fixed validation noise batch, and its outputs go to the evaluator.

    Args:
        epochs: Number of epochs to run.
        iterations: Upper bound of the step range. The loop advances by
            ``batchsize``, so an epoch has ``ceil(iterations / batchsize)``
            steps. Also used as the divisor of the printed loss.
        batchsize: Number of samples requested from the loader per draw.
        testsize: Passed to ``dataloader.test``. The returned value is the
            fixed noise batch used for the per-epoch preview.
        img_path: Image data location, passed to ``DatasetLoader``.
        seg_path: Segmentation data location, passed to ``DatasetLoader``.
        outdir: Output location forwarded to the evaluator.
        modeldir: Directory where per-epoch model snapshots are written.
        n_dis: Number of discriminator updates per generator update.
        mode: Forwarded to ``lossfunc.gradient_penalty``.
    """
    # Dataset Definition
    dataloader = DatasetLoader(img_path, seg_path)
    print(dataloader)
    valid_noise = dataloader.test(testsize)

    # Model & Optimizer Definition
    generator = Generator()
    generator.to_gpu()
    gen_opt = set_optimizer(generator)

    discriminator = Discriminator()
    discriminator.to_gpu()
    dis_opt = set_optimizer(discriminator)

    # Loss Function Definition
    lossfunc = SGANLossFunction()

    # Evaluation Definition
    evaluator = Evaluation()

    for epoch in range(epochs):
        sum_loss = 0
        for batch in range(0, iterations, batchsize):
            for _ in range(n_dis):
                t, s, noise = dataloader.train(batchsize)
                y_img, y_seg = generator(noise)
                # Detach the fake batch. The discriminator step must not
                # backprop into the generator, whose grads are cleared
                # before its own update anyway. Without this, each of the
                # n_dis steps wastes a full backward pass through it.
                y_img.unchain_backward()
                y_seg.unchain_backward()
                loss = lossfunc.dis_loss(discriminator, y_img, y_seg, t, s)
                loss += lossfunc.gradient_penalty(
                    discriminator, y_img, y_seg, t, s, mode=mode)
                discriminator.cleargrads()
                loss.backward()
                dis_opt.update()
                loss.unchain_backward()

            # Generator step on a fresh noise draw (real data is unused here).
            _, _, noise = dataloader.train(batchsize)
            y_img, y_seg = generator(noise)
            loss = lossfunc.gen_loss(discriminator, y_img, y_seg)
            generator.cleargrads()
            loss.backward()
            gen_opt.update()
            loss.unchain_backward()

            # Accumulate over the epoch. Plain `=` kept only the last
            # step's loss.
            sum_loss += loss.data

            if batch == 0:
                serializers.save_npz(f"{modeldir}/generator_{epoch}.model", generator)
                serializers.save_npz(f"{modeldir}/discriminator_{epoch}.model", discriminator)
                with chainer.using_config('train', False):
                    y_img, y_seg = generator(valid_noise)
                    # .data.get() copies the arrays back to host memory.
                    y_img = y_img.data.get()
                    y_seg = y_seg.data.get()
                    evaluator(y_img, y_seg, epoch, outdir, testsize=testsize)

        # NOTE(review): divides by `iterations`, not by the number of steps.
        # If gen_loss is a per-batch mean, this under-reports by roughly
        # `batchsize`. Confirm intent.
        print(f"epoch: {epoch}")
        print(f"loss: {sum_loss / iterations}")