def infer(testsize, outdir, model_path, con_path, sty_path, extension, coord_size, crop_size, alpha):
    """Stylize a validation batch with a trained decoder and visualize it.

    Loads decoder weights from ``model_path``, encodes content and style
    images with ``VGG``, blends the ``adain`` output with the content
    features by ``alpha``, decodes the blend, and passes the content, style
    and generated images to ``Visualizer``.

    Args:
        testsize: Number of samples requested from ``DatasetLoader.valid``;
            also forwarded to the visualizer.
        outdir: Output directory forwarded to the visualizer.
        model_path: ``.npz`` snapshot loaded into ``Decoder`` via
            ``serializers.load_npz``.
        con_path, sty_path, extension, coord_size, crop_size: Forwarded
            unchanged to ``DatasetLoader``.
        alpha: Blend weight, ``alpha * adain(...) + (1 - alpha) * content_feat``
            (1 keeps only the adain output, 0 keeps only the content features).
    """
    # Dataset definition
    dataloader = DatasetLoader(con_path, sty_path, extension, coord_size, crop_size)
    print(dataloader)
    con_valid, sty_valid = dataloader.valid(testsize)

    # Model definition (inference only, so no optimizer is created)
    decoder = Decoder()
    decoder.to_gpu()
    serializers.load_npz(model_path, decoder)

    vgg = VGG()
    vgg.to_gpu()

    # Visualizer definition
    visualizer = Visualizer()

    # "train"=False selects test-time layer behaviour; no_backprop_mode stops
    # Chainer from recording a computational graph that is never
    # back-propagated, which saves GPU memory during pure inference.
    with chainer.using_config("train", False), chainer.no_backprop_mode():
        # Only the last feature map of each VGG pass is used.
        style_feat_list = vgg(sty_valid)
        content_feat = vgg(con_valid)[-1]

        t = adain(content_feat, style_feat_list[-1])
        t = alpha * t + (1 - alpha) * content_feat
        g_t = decoder(t)

    # .data.get() copies the device arrays back to host NumPy arrays
    # (the models were moved with to_gpu()).
    g_t = g_t.data.get()
    con = con_valid.data.get()
    sty = sty_valid.data.get()
    visualizer(con, sty, g_t, outdir, 0, testsize)
def train(
    epochs, iterations, batchsize, validsize, src_path, tgt_path, extension,
    img_size, outdir, modeldir, lr_dis, lr_gen, beta1, beta2,
):
    """Train an InstaGAN-style model: two mask-aware generators and two discriminators.

    Each generator maps an (image, mask) pair to a translated (image, mask)
    pair. Every step performs one discriminator update, then one generator
    update whose objective sums the adversarial, cycle-consistency,
    identity-mapping and context-preserving terms of ``InstaGANLossFunction``.
    At the first step of each epoch both generators are saved to ``modeldir``,
    and translations of a fixed validation batch are sent to ``Visualizer``
    (``switch="mtot"`` for X->Y, ``switch="ttom"`` for Y->X).

    Args:
        epochs: Number of epochs.
        iterations: Samples per epoch; the inner loop runs
            ``len(range(0, iterations, batchsize))`` steps.
        batchsize: Mini-batch size for ``DatasetLoader.train``.
        validsize: Size of the fixed validation batch.
        src_path, tgt_path, extension, img_size: Forwarded to ``DatasetLoader``.
        outdir: Output directory forwarded to the visualizer.
        modeldir: Directory for ``generator_{xy,yx}_{epoch}.model`` snapshots.
        lr_dis, lr_gen: Forwarded to ``set_optimizer`` for the discriminators /
            generators respectively (learning rates, judging by name).
        beta1, beta2: Forwarded to ``set_optimizer`` for all four models.
    """
    # Dataset definition
    dataset = DatasetLoader(src_path, tgt_path, extension, img_size)
    print(dataset)
    x_val, x_mask_val, y_val, y_mask_val = dataset.valid(validsize)

    # Model & Optimizer definition
    generator_xy = Generator()
    generator_xy.to_gpu()
    gen_xy_opt = set_optimizer(generator_xy, lr_gen, beta1, beta2)

    generator_yx = Generator()
    generator_yx.to_gpu()
    gen_yx_opt = set_optimizer(generator_yx, lr_gen, beta1, beta2)

    discriminator_y = Discriminator()
    discriminator_y.to_gpu()
    dis_y_opt = set_optimizer(discriminator_y, lr_dis, beta1, beta2)

    discriminator_x = Discriminator()
    discriminator_x.to_gpu()
    dis_x_opt = set_optimizer(discriminator_x, lr_dis, beta1, beta2)

    # Loss Function definition
    lossfunc = InstaGANLossFunction()

    # Visualizer definition
    visualize = Visualizer()

    for epoch in range(epochs):
        sum_gen_loss = 0
        sum_dis_loss = 0
        for batch in range(0, iterations, batchsize):
            x, x_mask, y, y_mask = dataset.train(batchsize)

            # Discriminator update: detach the generated images/masks so that
            # dis_loss.backward() does not propagate into the generators.
            xy, xy_mask = generator_xy(x, x_mask)
            yx, yx_mask = generator_yx(y, y_mask)
            xy.unchain_backward()
            xy_mask.unchain_backward()
            yx.unchain_backward()
            yx_mask.unchain_backward()

            dis_loss = lossfunc.adversarial_dis_loss(discriminator_y, xy, xy_mask, y, y_mask)
            dis_loss += lossfunc.adversarial_dis_loss(discriminator_x, yx, yx_mask, x, x_mask)

            discriminator_y.cleargrads()
            discriminator_x.cleargrads()
            dis_loss.backward()
            dis_y_opt.update()
            dis_x_opt.update()

            sum_dis_loss += dis_loss.data

            # Generator update: translations, reconstructions (cycle) and
            # identity passes, all kept in the graph for backward.
            xy, xy_mask = generator_xy(x, x_mask)
            yx, yx_mask = generator_yx(y, y_mask)
            xyx, xyx_mask = generator_yx(xy, xy_mask)
            yxy, yxy_mask = generator_xy(yx, yx_mask)
            x_id, x_mask_id = generator_yx(x, x_mask)
            y_id, y_mask_id = generator_xy(y, y_mask)

            gen_loss = lossfunc.adversarial_gen_loss(discriminator_y, xy, xy_mask)
            gen_loss += lossfunc.adversarial_gen_loss(discriminator_x, yx, yx_mask)
            gen_loss += lossfunc.cycle_consistency_loss(xyx, xyx_mask, x, x_mask)
            gen_loss += lossfunc.cycle_consistency_loss(yxy, yxy_mask, y, y_mask)
            gen_loss += lossfunc.identity_mapping_loss(x_id, x_mask_id, x, x_mask)
            gen_loss += lossfunc.identity_mapping_loss(y_id, y_mask_id, y, y_mask)
            gen_loss += lossfunc.context_preserving_loss(xy, xy_mask, x, x_mask)
            gen_loss += lossfunc.context_preserving_loss(yx, yx_mask, y, y_mask)

            generator_xy.cleargrads()
            generator_yx.cleargrads()
            gen_loss.backward()
            gen_xy_opt.update()
            gen_yx_opt.update()

            sum_gen_loss += gen_loss.data

            if batch == 0:
                serializers.save_npz(f"{modeldir}/generator_xy_{epoch}.model", generator_xy)
                serializers.save_npz(f"{modeldir}/generator_yx_{epoch}.model", generator_yx)

                # Preview in test mode and without recording a graph. Results
                # go into *_val names so the training-batch variables above
                # are not overwritten with host arrays.
                with chainer.using_config("train", False), chainer.no_backprop_mode():
                    xy_val, xy_mask_val = generator_xy(x_val, x_mask_val)
                    yx_val, yx_mask_val = generator_yx(y_val, y_mask_val)

                # .data.get() copies device arrays to host NumPy arrays.
                visualize(
                    x_val.data.get(), x_mask_val.data.get(),
                    xy_val.data.get(), xy_mask_val.data.get(),
                    outdir, epoch, validsize, switch="mtot",
                )
                visualize(
                    y_val.data.get(), y_mask_val.data.get(),
                    yx_val.data.get(), yx_mask_val.data.get(),
                    outdir, epoch, validsize, switch="ttom",
                )

        print(f"epoch: {epoch}")
        # NOTE(review): the sums hold one value per step, and there are
        # len(range(0, iterations, batchsize)) steps, yet they are divided by
        # `iterations` - confirm whether a per-step mean was intended.
        print(
            f"dis loss: {sum_dis_loss / iterations} gen loss: {sum_gen_loss / iterations}"
        )
def train(
    epochs, iterations, batchsize, validsize, outdir, modeldir,
    src_path, tgt_path, extension, img_size, learning_rate, beta1,
):
    """Train a CycleGAN (two generators, two discriminators) on unpaired data.

    Every mini-batch performs one discriminator step and one generator step.
    The generator objective is adversarial + cycle-consistency +
    identity-mapping terms from ``CycleGANLossCalculator``. On the first step
    of each epoch both generators are written to ``modeldir``, and the X->Y
    translation of a fixed validation batch is passed, together with
    ``outdir``, to ``Visualization``.

    Args:
        epochs: Number of epochs.
        iterations: Samples per epoch; the inner loop advances by ``batchsize``.
        batchsize: Mini-batch size for ``DatasetLoader.train``.
        validsize: Size of the fixed validation batch.
        outdir: Output directory forwarded to the visualizer.
        modeldir: Directory for ``generator_{xy,yx}_{epoch}.model`` snapshots.
        src_path, tgt_path, extension, img_size: Forwarded to ``DatasetLoader``.
        learning_rate, beta1: Forwarded to ``set_optimizer`` for every model.
    """
    loader = DatasetLoader(src_path, tgt_path, extension, img_size)
    print(loader)
    val_src = loader.valid(validsize)

    # Each network is moved to the GPU and paired with its own optimizer.
    g_xy = Generator()
    g_xy.to_gpu()
    opt_g_xy = set_optimizer(g_xy, learning_rate, beta1)

    g_yx = Generator()
    g_yx.to_gpu()
    opt_g_yx = set_optimizer(g_yx, learning_rate, beta1)

    d_y = Discriminator()
    d_y.to_gpu()
    opt_d_y = set_optimizer(d_y, learning_rate, beta1)

    d_x = Discriminator()
    d_x.to_gpu()
    opt_d_x = set_optimizer(d_x, learning_rate, beta1)

    criterion = CycleGANLossCalculator()
    show = Visualization()

    for epoch in range(epochs):
        gen_running = 0
        dis_running = 0
        for step_start in range(0, iterations, batchsize):
            real_x, real_y = loader.train(batchsize)

            # ---- discriminator step ----
            fake_y = g_xy(real_x)
            fake_x = g_yx(real_y)
            # Detach fakes: the discriminator loss must not reach the generators.
            fake_y.unchain_backward()
            fake_x.unchain_backward()
            dis_total = (
                criterion.dis_loss(d_y, fake_y, real_y)
                + criterion.dis_loss(d_x, fake_x, real_x)
            )
            d_x.cleargrads()
            d_y.cleargrads()
            dis_total.backward()
            opt_d_x.update()
            opt_d_y.update()
            dis_running += dis_total.data

            # ---- generator step ----
            fake_y = g_xy(real_x)
            fake_x = g_yx(real_y)
            rec_x = g_yx(fake_y)
            rec_y = g_xy(fake_x)
            same_y = g_xy(real_y)
            same_x = g_yx(real_x)

            adv_xy = criterion.gen_loss(d_y, fake_y)
            adv_yx = criterion.gen_loss(d_x, fake_x)
            cyc_y = criterion.cycle_consitency_loss(rec_y, real_y)
            cyc_x = criterion.cycle_consitency_loss(rec_x, real_x)
            idt_y = criterion.identity_mapping_loss(same_y, real_y)
            idt_x = criterion.identity_mapping_loss(same_x, real_x)
            # Terms are summed strictly left to right.
            gen_total = adv_xy + adv_yx + cyc_x + cyc_y + idt_x + idt_y

            g_xy.cleargrads()
            g_yx.cleargrads()
            gen_total.backward()
            opt_g_xy.update()
            opt_g_yx.update()
            gen_running += gen_total.data

            # Checkpoint and preview only on the first step of the epoch.
            if step_start:
                continue
            serializers.save_npz(f"{modeldir}/generator_xy_{epoch}.model", g_xy)
            serializers.save_npz(f"{modeldir}/generator_yx_{epoch}.model", g_yx)
            with chainer.using_config('train', False):
                preview = g_xy(val_src)
                show(val_src.data.get(), preview.data.get(), outdir, epoch, validsize)

        print(f"epoch: {epoch}")
        print(f"dis loss: {dis_running/iterations} gen loss: {gen_running/iterations}")
def train(
    epochs, iterations, batchsize, validsize, outdir, modeldir, extension,
    train_size, valid_size, data_path, sketch_path, digi_path,
    learning_rate, beta1, weight_decay,
):
    """Train a pix2pix-style UNet translator against a discriminator.

    Each step alternates one discriminator update (on a detached UNet output)
    with one UNet update, whose loss is ``gen_loss + content_loss`` from
    ``Pix2pixLossCalculator``. On the first step of every epoch the UNet is
    saved to ``modeldir``, and its prediction on a fixed validation batch is
    passed to ``Visualizer`` along with the inputs and targets.

    Args:
        epochs: Number of epochs.
        iterations: Samples per epoch; the inner loop advances by ``batchsize``.
        batchsize: Mini-batch size for ``DatasetLoader.train``.
        validsize: Size of the fixed validation batch.
        outdir: Output directory forwarded to the visualizer.
        modeldir: Directory for ``unet_{epoch}.model`` snapshots.
        extension, train_size, valid_size, data_path, sketch_path, digi_path:
            Forwarded to ``DatasetLoader``.
        learning_rate, beta1, weight_decay: Forwarded to ``set_optimizer``.
    """
    loader = DatasetLoader(data_path, sketch_path, digi_path, extension, train_size, valid_size)
    print(loader)
    val_in, val_target = loader.valid(validsize)

    net_g = UNet()
    net_g.to_gpu()
    opt_g = set_optimizer(net_g, learning_rate, beta1, weight_decay)

    net_d = Discriminator()
    net_d.to_gpu()
    opt_d = set_optimizer(net_d, learning_rate, beta1, weight_decay)

    criterion = Pix2pixLossCalculator()
    show = Visualizer()

    for epoch in range(epochs):
        d_running = 0
        g_running = 0
        for offset in range(0, iterations, batchsize):
            src, target = loader.train(batchsize)

            # ---- discriminator step (generator output detached) ----
            fake = net_g(src)
            fake.unchain_backward()
            d_loss = criterion.dis_loss(net_d, fake, target)
            net_d.cleargrads()
            d_loss.backward()
            opt_d.update()
            d_running += d_loss.data

            # ---- generator step: adversarial + content terms ----
            fake = net_g(src)
            g_loss = criterion.gen_loss(net_d, fake) + criterion.content_loss(fake, target)
            net_g.cleargrads()
            g_loss.backward()
            opt_g.update()
            g_running += g_loss.data

            # Checkpoint and preview only on the first step of the epoch.
            if offset:
                continue
            serializers.save_npz(f"{modeldir}/unet_{epoch}.model", net_g)
            with chainer.using_config("train", False):
                pred = net_g(val_in)
                show(
                    val_in.data.get(), val_target.data.get(), pred.data.get(),
                    outdir, epoch, validsize,
                )

        print(f"epoch: {epoch}")
        print(f"dis loss: {d_running/iterations} gen loss: {g_running/iterations}")
def train(epochs, iterations, batchsize, testsize, outdir, modeldir, n_dis, img_path, tag_path):
    """Train a tag-conditioned generator against a discriminator (``RGANLossFunction``).

    Each step performs ``n_dis`` discriminator updates (adversarial loss plus
    ``gradient_penalty``, each on a freshly drawn batch), then one generator
    update. On the first step of each epoch both models are saved to
    ``modeldir``, and the generator output for a fixed validation noise batch
    is passed to ``Evaluation``.

    Args:
        epochs: Number of epochs.
        iterations: Samples per epoch; the inner loop runs
            ``len(range(0, iterations, batchsize))`` steps.
        batchsize: Mini-batch size for ``DatasetLoader.train``.
        testsize: Size of the fixed validation noise batch; also forwarded to
            the evaluator.
        outdir: Output directory forwarded to the evaluator.
        modeldir: Directory for ``generator_{epoch}.model`` /
            ``discriminator_{epoch}.model`` snapshots.
        n_dis: Discriminator updates per generator update.
        img_path, tag_path: Forwarded to ``DatasetLoader``.
    """
    # Dataset definition
    dataloader = DatasetLoader(img_path, tag_path)
    # Size the fixed validation batch by `testsize` so it matches the count
    # handed to the evaluator below (drawing `batchsize` samples disagreed
    # with that count whenever testsize != batchsize).
    zvis_valid, ztag_valid = dataloader.valid(testsize)
    # Generator input is the concatenation of the two noise parts
    # (F.concat joins along axis 1 by default).
    noise_valid = F.concat([zvis_valid, ztag_valid])

    # Model & Optimizer Definition
    generator = Generator()
    generator.to_gpu()
    gen_opt = set_optimizer(generator)

    discriminator = Discriminator()
    discriminator.to_gpu()
    dis_opt = set_optimizer(discriminator)

    # Loss Function Definition
    lossfunc = RGANLossFunction()

    # Evaluation
    evaluator = Evaluation()

    for epoch in range(epochs):
        # Only the generator loss is accumulated and reported.
        sum_loss = 0
        for batch in range(0, iterations, batchsize):
            # Discriminator updates; generated images are detached so the
            # discriminator loss does not reach the generator.
            for _ in range(n_dis):
                zvis, ztag, img, tag = dataloader.train(batchsize)
                y = generator(F.concat([zvis, ztag]))
                y.unchain_backward()

                loss = lossfunc.dis_loss(discriminator, y, img, tag, ztag)
                loss += lossfunc.gradient_penalty(discriminator, img, tag)

                discriminator.cleargrads()
                loss.backward()
                dis_opt.update()
                # Release the graph now that the update is done.
                loss.unchain_backward()

            # Generator update on a fresh noise batch.
            zvis, ztag, _, _ = dataloader.train(batchsize)
            y = generator(F.concat([zvis, ztag]))
            loss = lossfunc.gen_loss(discriminator, y, ztag)

            generator.cleargrads()
            loss.backward()
            gen_opt.update()
            loss.unchain_backward()

            sum_loss += loss.data

            if batch == 0:
                serializers.save_npz(f"{modeldir}/generator_{epoch}.model", generator)
                serializers.save_npz(f"{modeldir}/discriminator_{epoch}.model", discriminator)

                with chainer.using_config('train', False):
                    y = generator(noise_valid)
                # .data.get() copies the device array back to a host NumPy array.
                y = y.data.get()
                evaluator(y, outdir, epoch, testsize)

        print(f"epoch: {epoch}")
        # NOTE(review): one value is accumulated per step, but the sum is
        # divided by `iterations` rather than the number of steps - confirm.
        print(f"loss: {sum_loss / iterations}")