def train(epochs, iterations, batchsize, outdir, data_path):
    # Dataset Definition
    dataloader = DatasetLoader(data_path)

    # Model & Optimizer Definition
    # generator = Generator()
    generator = GeneratorWithCIN()
    generator.to_gpu()
    gen_opt = set_optimizer(generator, alpha=0.0002)

    discriminator = Discriminator()
    discriminator.to_gpu()
    dis_opt = set_optimizer(discriminator, alpha=0.0001)

    # Loss Function Definition
    lossfunc = StarGANVC2LossFunction()

    for epoch in range(epochs):
        sum_loss = 0
        for batch in range(0, iterations, batchsize):
            x_sp, x_label, y_sp, y_label = dataloader.train(batchsize)

            y_fake = generator(x_sp, F.concat([y_label, x_label]))
            y_fake.unchain_backward()

            loss = lossfunc.dis_loss(discriminator, y_fake, x_sp, y_label, x_label)

            discriminator.cleargrads()
            loss.backward()
            dis_opt.update()
            loss.unchain_backward()

            y_fake = generator(x_sp, F.concat([y_label, x_label]))
            x_fake = generator(y_fake, F.concat([x_label, y_label]))
            x_identity = generator(x_sp, F.concat([x_label, x_label]))

            loss = lossfunc.gen_loss(discriminator, y_fake, x_fake, x_sp,
                                     F.concat([y_label, x_label]))
            if epoch < 50:
                loss += lossfunc.identity_loss(x_identity, x_sp)

            generator.cleargrads()
            loss.backward()
            gen_opt.update()
            loss.unchain_backward()

            sum_loss += loss.data

            if batch == 0:
                serializers.save_npz(f"modeldirCIN/generator_{epoch}.model", generator)
                serializers.save_npz('discriminator.model', discriminator)

        print(f"epoch: {epoch}")
        print(f"loss: {sum_loss / iterations}")
def train(epochs, batchsize, iterations, nc_size, data_path, modeldir):
    # Dataset definition
    dataset = DatasetLoader(data_path, nc_size)

    # Model & Optimizer definition
    generator = Generator(nc_size)
    generator.to_gpu()
    gen_opt = set_optimizer(generator, 0.0001, 0.5)

    discriminator = Discriminator(nc_size)
    discriminator.to_gpu()
    dis_opt = set_optimizer(discriminator, 0.0001, 0.5)

    for epoch in range(epochs):
        sum_gen_loss = 0
        sum_dis_loss = 0
        for batch in range(0, iterations, batchsize):
            x, x_label, y, y_label = dataset.train(batchsize)

            y_fake = generator(x, y_label)
            y_fake.unchain_backward()

            loss = adversarial_loss_dis(discriminator, y_fake, x, y_label, x_label)

            discriminator.cleargrads()
            loss.backward()
            dis_opt.update()
            loss.unchain_backward()

            sum_dis_loss += loss.data

            y_fake = generator(x, y_label)
            x_fake = generator(y_fake, x_label)
            x_id = generator(x, x_label)

            loss = adversarial_loss_gen(discriminator, y_fake, x_fake, x, y_label)
            if epoch < 20:
                loss += 10 * F.mean_absolute_error(x_id, x)

            generator.cleargrads()
            loss.backward()
            gen_opt.update()
            loss.unchain_backward()

            sum_gen_loss += loss.data

            if batch == 0:
                serializers.save_npz(f"{modeldir}/generator_{epoch}.model", generator)
                serializers.save_npz("discriminator.model", discriminator)

        print(f"epoch: {epoch} disloss: {sum_dis_loss/iterations} genloss: {sum_gen_loss/iterations}")
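
# ---------------------------------------------------------------------------
# Hedged sketch (an assumption, not part of the original sources): every
# training loop in this collection calls an undefined set_optimizer helper.
# The call sites differ slightly per script (keyword alpha, positional
# alpha/beta1/beta2, sometimes a weight decay), but a minimal Chainer version
# consistent with most of them could look like this:
import chainer


def set_optimizer(model, alpha=0.0002, beta1=0.5, beta2=0.999):
    # Adam is the de-facto choice in these scripts; the defaults here are
    # assumptions and are overridden by the positional arguments above.
    optimizer = chainer.optimizers.Adam(alpha=alpha, beta1=beta1, beta2=beta2)
    optimizer.setup(model)
    return optimizer
# ---------------------------------------------------------------------------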
seg.unchain_backward()

std_data = xp.std(t.data, axis=0, keepdims=True)
rnd_x = xp.random.uniform(0, 1, t.shape).astype(xp.float32)
x_perturbed = rnd_x * t + (1 - rnd_x) * x
s_perturbed = rnd_x * s + (1 - rnd_x) * seg

y_perturbed = discriminator(x_perturbed, s_perturbed)
grad, = chainer.grad([y_perturbed], [x_perturbed], enable_double_backprop=True)
grad = F.sqrt(F.batch_l2_norm_squared(grad))
loss_grad = lambda1 * F.mean_squared_error(grad, xp.ones_like(grad.data))

dis_loss += loss_grad

discriminator.cleargrads()
dis_loss.backward()
dis_loss.unchain_backward()
dis_opt.update()

z = chainer.as_variable(
    xp.random.uniform(-1, 1, (batchsize, 256)).astype(xp.float32))
x, seg = generator(z)
fake = discriminator(x, seg)
gen_loss = loss_hinge_gen(fake)

generator.cleargrads()
gen_loss.backward()
gen_loss.unchain_backward()
gen_opt.update()
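
# ---------------------------------------------------------------------------
# Hedged sketch (assumption): loss_hinge_gen is not defined in the fragment
# above. The standard hinge-GAN generator objective it presumably implements
# is just the negated mean of the discriminator output on fakes:
import chainer.functions as F


def loss_hinge_gen(fake):
    # the generator maximizes D(fake), i.e. minimizes -E[D(fake)]
    return -F.mean(fake)
# ---------------------------------------------------------------------------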
def train(epochs, iterations, batchsize, validsize, src_path, tgt_path,
          extension, img_size, outdir, modeldir, lr_dis, lr_gen, beta1, beta2):
    # Dataset definition
    dataset = DatasetLoader(src_path, tgt_path, extension, img_size)
    print(dataset)
    x_val, x_mask_val, y_val, y_mask_val = dataset.valid(validsize)

    # Model & Optimizer definition
    generator_xy = Generator()
    generator_xy.to_gpu()
    gen_xy_opt = set_optimizer(generator_xy, lr_gen, beta1, beta2)

    generator_yx = Generator()
    generator_yx.to_gpu()
    gen_yx_opt = set_optimizer(generator_yx, lr_gen, beta1, beta2)

    discriminator_y = Discriminator()
    discriminator_y.to_gpu()
    dis_y_opt = set_optimizer(discriminator_y, lr_dis, beta1, beta2)

    discriminator_x = Discriminator()
    discriminator_x.to_gpu()
    dis_x_opt = set_optimizer(discriminator_x, lr_dis, beta1, beta2)

    # Loss Function definition
    lossfunc = InstaGANLossFunction()

    # Visualizer definition
    visualize = Visualizer()

    for epoch in range(epochs):
        sum_gen_loss = 0
        sum_dis_loss = 0
        for batch in range(0, iterations, batchsize):
            x, x_mask, y, y_mask = dataset.train(batchsize)

            # discriminator update
            xy, xy_mask = generator_xy(x, x_mask)
            yx, yx_mask = generator_yx(y, y_mask)

            xy.unchain_backward()
            xy_mask.unchain_backward()
            yx.unchain_backward()
            yx_mask.unchain_backward()

            dis_loss = lossfunc.adversarial_dis_loss(discriminator_y, xy, xy_mask, y, y_mask)
            dis_loss += lossfunc.adversarial_dis_loss(discriminator_x, yx, yx_mask, x, x_mask)

            discriminator_y.cleargrads()
            discriminator_x.cleargrads()
            dis_loss.backward()
            dis_y_opt.update()
            dis_x_opt.update()

            sum_dis_loss += dis_loss.data

            # generator update
            xy, xy_mask = generator_xy(x, x_mask)
            yx, yx_mask = generator_yx(y, y_mask)
            xyx, xyx_mask = generator_yx(xy, xy_mask)
            yxy, yxy_mask = generator_xy(yx, yx_mask)
            x_id, x_mask_id = generator_yx(x, x_mask)
            y_id, y_mask_id = generator_xy(y, y_mask)

            gen_loss = lossfunc.adversarial_gen_loss(discriminator_y, xy, xy_mask)
            gen_loss += lossfunc.adversarial_gen_loss(discriminator_x, yx, yx_mask)
            gen_loss += lossfunc.cycle_consistency_loss(xyx, xyx_mask, x, x_mask)
            gen_loss += lossfunc.cycle_consistency_loss(yxy, yxy_mask, y, y_mask)
            gen_loss += lossfunc.identity_mapping_loss(x_id, x_mask_id, x, x_mask)
            gen_loss += lossfunc.identity_mapping_loss(y_id, y_mask_id, y, y_mask)
            gen_loss += lossfunc.context_preserving_loss(xy, xy_mask, x, x_mask)
            gen_loss += lossfunc.context_preserving_loss(yx, yx_mask, y, y_mask)

            generator_xy.cleargrads()
            generator_yx.cleargrads()
            gen_loss.backward()
            gen_xy_opt.update()
            gen_yx_opt.update()

            sum_gen_loss += gen_loss.data

            if batch == 0:
                serializers.save_npz(f"{modeldir}/generator_xy_{epoch}.model", generator_xy)
                serializers.save_npz(f"{modeldir}/generator_yx_{epoch}.model", generator_yx)

                xy, xy_mask = generator_xy(x_val, x_mask_val)
                yx, yx_mask = generator_yx(y_val, y_mask_val)

                x = x_val.data.get()
                x_mask = x_mask_val.data.get()
                xy = xy.data.get()
                xy_mask = xy_mask.data.get()
                visualize(x, x_mask, xy, xy_mask, outdir, epoch, validsize, switch="mtot")

                y = y_val.data.get()
                y_mask = y_mask_val.data.get()
                yx = yx.data.get()
                yx_mask = yx_mask.data.get()
                visualize(y, y_mask, yx, yx_mask, outdir, epoch, validsize, switch="ttom")

        print(f"epoch: {epoch}")
        print(f"dis loss: {sum_dis_loss / iterations} gen loss: {sum_gen_loss / iterations}")
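
# ---------------------------------------------------------------------------
# Hedged sketch (assumption): InstaGANLossFunction is not shown in these
# sources. As one example, the context_preserving_loss called above follows
# the InstaGAN paper: an L1 penalty restricted to the background, i.e. pixels
# covered by neither the original instance mask nor the translated one. The
# weight is an assumption.
def context_preserving_loss(converted, converted_mask, orig, orig_mask, weight=10.0):
    # background weight: 1 where neither mask is active, 0 elsewhere
    background = (1 - orig_mask) * (1 - converted_mask)
    return weight * F.mean_absolute_error(background * converted, background * orig)
# ---------------------------------------------------------------------------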
def train(epochs, iterations, batchsize, validsize, outdir, modeldir, src_path,
          tgt_path, extension, img_size, learning_rate, beta1):
    # Dataset definition
    dataloader = DatasetLoader(src_path, tgt_path, extension, img_size)
    print(dataloader)
    src_val = dataloader.valid(validsize)

    # Model & Optimizer definition
    generator_xy = Generator()
    generator_xy.to_gpu()
    gen_xy_opt = set_optimizer(generator_xy, learning_rate, beta1)

    generator_yx = Generator()
    generator_yx.to_gpu()
    gen_yx_opt = set_optimizer(generator_yx, learning_rate, beta1)

    discriminator_y = Discriminator()
    discriminator_y.to_gpu()
    dis_y_opt = set_optimizer(discriminator_y, learning_rate, beta1)

    discriminator_x = Discriminator()
    discriminator_x.to_gpu()
    dis_x_opt = set_optimizer(discriminator_x, learning_rate, beta1)

    # Loss function definition
    lossfunc = CycleGANLossCalculator()

    # Visualization
    visualizer = Visualization()

    for epoch in range(epochs):
        sum_gen_loss = 0
        sum_dis_loss = 0
        for batch in range(0, iterations, batchsize):
            x, y = dataloader.train(batchsize)

            # Discriminator update
            xy = generator_xy(x)
            yx = generator_yx(y)

            xy.unchain_backward()
            yx.unchain_backward()

            dis_loss_xy = lossfunc.dis_loss(discriminator_y, xy, y)
            dis_loss_yx = lossfunc.dis_loss(discriminator_x, yx, x)
            dis_loss = dis_loss_xy + dis_loss_yx

            discriminator_x.cleargrads()
            discriminator_y.cleargrads()
            dis_loss.backward()
            dis_x_opt.update()
            dis_y_opt.update()

            sum_dis_loss += dis_loss.data

            # Generator update
            xy = generator_xy(x)
            yx = generator_yx(y)
            xyx = generator_yx(xy)
            yxy = generator_xy(yx)
            y_id = generator_xy(y)
            x_id = generator_yx(x)

            # adversarial loss
            gen_loss_xy = lossfunc.gen_loss(discriminator_y, xy)
            gen_loss_yx = lossfunc.gen_loss(discriminator_x, yx)

            # cycle-consistency loss
            cycle_y = lossfunc.cycle_consitency_loss(yxy, y)
            cycle_x = lossfunc.cycle_consitency_loss(xyx, x)

            # identity mapping loss
            identity_y = lossfunc.identity_mapping_loss(y_id, y)
            identity_x = lossfunc.identity_mapping_loss(x_id, x)

            gen_loss = gen_loss_xy + gen_loss_yx + cycle_x + cycle_y + identity_x + identity_y

            generator_xy.cleargrads()
            generator_yx.cleargrads()
            gen_loss.backward()
            gen_xy_opt.update()
            gen_yx_opt.update()

            sum_gen_loss += gen_loss.data

            if batch == 0:
                serializers.save_npz(f"{modeldir}/generator_xy_{epoch}.model", generator_xy)
                serializers.save_npz(f"{modeldir}/generator_yx_{epoch}.model", generator_yx)

                with chainer.using_config('train', False):
                    tgt = generator_xy(src_val)

                src = src_val.data.get()
                tgt = tgt.data.get()
                visualizer(src, tgt, outdir, epoch, validsize)

        print(f"epoch: {epoch}")
        print(f"dis loss: {sum_dis_loss/iterations} gen loss: {sum_gen_loss/iterations}")
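
# ---------------------------------------------------------------------------
# Hedged sketch (assumption): CycleGANLossCalculator is not shown in these
# sources. The method names below are taken verbatim from the call sites
# above (including the repo's own "consitency" spelling); the bodies are a
# plausible LSGAN-style reconstruction with assumed loss weights, not the
# original implementation.
import chainer.functions as F


class CycleGANLossCalculator:
    def dis_loss(self, discriminator, fake, real):
        # least-squares discriminator loss: real -> 1, fake -> 0
        return 0.5 * (F.mean((discriminator(real) - 1.0) ** 2)
                      + F.mean(discriminator(fake) ** 2))

    def gen_loss(self, discriminator, fake):
        # the generator pushes fakes toward the "real" target 1
        return 0.5 * F.mean((discriminator(fake) - 1.0) ** 2)

    def cycle_consitency_loss(self, recon, real, weight=10.0):
        return weight * F.mean_absolute_error(recon, real)

    def identity_mapping_loss(self, identity, real, weight=5.0):
        return weight * F.mean_absolute_error(identity, real)
# ---------------------------------------------------------------------------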
def train(epochs, iterations, batchsize, validsize, outdir, modeldir, data_path,
          extension, img_size, latent_dim, learning_rate, beta1, beta2, enable):
    # Dataset Definition
    dataloader = DataLoader(data_path, extension, img_size, latent_dim)
    print(dataloader)
    color_valid, line_valid = dataloader(validsize, mode="valid")
    noise_valid = dataloader.noise_generator(validsize)

    # Model Definition
    if enable:
        encoder = Encoder()
        encoder.to_gpu()
        enc_opt = set_optimizer(encoder)

    generator = Generator()
    generator.to_gpu()
    gen_opt = set_optimizer(generator, learning_rate, beta1, beta2)

    discriminator = Discriminator()
    discriminator.to_gpu()
    dis_opt = set_optimizer(discriminator, learning_rate, beta1, beta2)

    # Loss Function Definition
    lossfunc = GauGANLossFunction()

    # Evaluation Definition
    evaluator = Evaluation()

    for epoch in range(epochs):
        sum_dis_loss = 0
        sum_gen_loss = 0
        for batch in range(0, iterations, batchsize):
            color, line = dataloader(batchsize)
            z = dataloader.noise_generator(batchsize)

            # Discriminator update
            if enable:
                mu, sigma = encoder(color)
                z = F.gaussian(mu, sigma)

            y = generator(z, line)
            y.unchain_backward()

            dis_loss = lossfunc.dis_loss(discriminator, F.concat([y, line]),
                                         F.concat([color, line]))

            discriminator.cleargrads()
            dis_loss.backward()
            dis_opt.update()
            dis_loss.unchain_backward()

            sum_dis_loss += dis_loss.data

            # Generator update
            z = dataloader.noise_generator(batchsize)
            if enable:
                mu, sigma = encoder(color)
                z = F.gaussian(mu, sigma)

            y = generator(z, line)

            gen_loss = lossfunc.gen_loss(discriminator, F.concat([y, line]),
                                         F.concat([color, line]))
            gen_loss += lossfunc.content_loss(y, color)
            if enable:
                gen_loss += 0.05 * F.gaussian_kl_divergence(mu, sigma) / batchsize

            generator.cleargrads()
            if enable:
                encoder.cleargrads()
            gen_loss.backward()
            gen_opt.update()
            if enable:
                enc_opt.update()
            gen_loss.unchain_backward()

            sum_gen_loss += gen_loss.data

            if batch == 0:
                serializers.save_npz(f"{modeldir}/generator_{epoch}.model", generator)

                with chainer.using_config("train", False):
                    y = generator(noise_valid, line_valid)

                y = y.data.get()
                sr = line_valid.data.get()
                cr = color_valid.data.get()
                evaluator(y, cr, sr, outdir, epoch, validsize=validsize)

        print(f"epoch: {epoch}")
        print(f"dis loss: {sum_dis_loss / iterations} gen loss: {sum_gen_loss / iterations}")
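
# Note on the encoder branch in the function above: Chainer's
# F.gaussian(mu, ln_var) and F.gaussian_kl_divergence(mu, ln_var) both
# interpret their second argument as a LOG-variance. The tensor named `sigma`
# returned by the encoder must therefore hold ln(sigma^2); F.gaussian then
# computes the reparameterised sample mu + exp(0.5 * ln_var) * eps internally.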
def train(epochs, iterations, batchsize, data_path, modeldir, extension,
          img_size, learning_rate, beta1, weight_decay):
    # Dataset definition
    dataset = DatasetLoader(data_path, extension, img_size)

    # Model & Optimizer definition
    generator = Generator(dataset.number)
    generator.to_gpu()
    gen_opt = set_optimizer(generator, learning_rate, beta1, weight_decay)

    discriminator = Discriminator(dataset.number)
    discriminator.to_gpu()
    dis_opt = set_optimizer(discriminator, learning_rate, beta1, weight_decay)

    # Loss Function definition
    lossfunc = RelGANLossFunction()

    for epoch in range(epochs):
        sum_dis_loss = 0
        sum_gen_loss = 0
        for batch in range(0, iterations, batchsize):
            x, x_label, y, y_label, z, z_label = dataset.train(batchsize)

            # Discriminator update
            # Adversarial loss
            a = y_label - x_label  # relative attribute vector (target minus source)
            fake = generator(x, a)
            fake.unchain_backward()
            loss = lossfunc.adversarial_loss_dis(discriminator, fake, y)

            # Interpolation loss
            rnd = np.random.randint(2)
            if rnd == 0:
                alpha = xp.random.uniform(0, 0.5, size=batchsize)
            else:
                alpha = xp.random.uniform(0.5, 1.0, size=batchsize)
            alpha = chainer.as_variable(alpha.astype(xp.float32))
            alpha = F.tile(F.expand_dims(alpha, axis=1), (1, dataset.number))
            fake_0 = generator(x, y_label - y_label)  # zero vector: no attribute change
            fake_1 = generator(x, alpha * a)
            fake_0.unchain_backward()
            fake_1.unchain_backward()
            loss += 10 * lossfunc.interpolation_loss_dis(
                discriminator, fake_0, fake, fake_1, alpha, rnd)

            # Matching loss
            v2 = y_label - z_label
            v3 = z_label - x_label
            loss += lossfunc.matching_loss_dis(discriminator, x, fake, y, z, a, v2, v3)

            discriminator.cleargrads()
            loss.backward()
            dis_opt.update()
            loss.unchain_backward()

            sum_dis_loss += loss.data

            # Generator update
            # Adversarial loss
            fake = generator(x, a)
            loss = lossfunc.adversarial_loss_gen(discriminator, fake)

            # Interpolation loss
            rnd = np.random.randint(2)
            if rnd == 0:
                alpha = xp.random.uniform(0, 0.5, size=batchsize)
            else:
                alpha = xp.random.uniform(0.5, 1.0, size=batchsize)
            alpha = chainer.as_variable(alpha.astype(xp.float32))
            alpha = F.tile(F.expand_dims(alpha, axis=1), (1, dataset.number))
            fake_alpha = generator(x, alpha * a)
            loss += 10 * lossfunc.interpolation_loss_gen(discriminator, fake_alpha)

            # Matching loss
            loss += lossfunc.matching_loss_gen(discriminator, x, fake, a)

            # Cycle-consistency loss
            cyc = generator(fake, -a)
            loss += 10 * F.mean_absolute_error(cyc, x)

            # Self-reconstruction loss
            fake_0 = generator(x, y_label - y_label)
            loss += 10 * F.mean_absolute_error(fake_0, x)

            generator.cleargrads()
            loss.backward()
            gen_opt.update()
            loss.unchain_backward()

            sum_gen_loss += loss.data

            if batch == 0:
                serializers.save_npz(f"{modeldir}/generator_{epoch}.model", generator)

        print(f"epoch: {epoch} disloss: {sum_dis_loss/iterations} genloss: {sum_gen_loss/iterations}")
def train(epochs, iterations, batchsize, testsize, img_path, seg_path, outdir,
          modeldir, n_dis, mode):
    # Dataset Definition
    dataloader = DatasetLoader(img_path, seg_path)
    print(dataloader)
    valid_noise = dataloader.test(testsize)

    # Model & Optimizer Definition
    generator = Generator()
    generator.to_gpu()
    gen_opt = set_optimizer(generator)

    discriminator = Discriminator()
    discriminator.to_gpu()
    dis_opt = set_optimizer(discriminator)

    # Loss Function Definition
    lossfunc = SGANLossFunction()

    # Evaluation Definition
    evaluator = Evaluation()

    for epoch in range(epochs):
        sum_loss = 0
        for batch in range(0, iterations, batchsize):
            for _ in range(n_dis):
                t, s, noise = dataloader.train(batchsize)
                y_img, y_seg = generator(noise)

                loss = lossfunc.dis_loss(discriminator, y_img, y_seg, t, s)
                loss += lossfunc.gradient_penalty(discriminator, y_img, y_seg, t, s, mode=mode)

                discriminator.cleargrads()
                loss.backward()
                dis_opt.update()
                loss.unchain_backward()

            _, _, noise = dataloader.train(batchsize)
            y_img, y_seg = generator(noise)
            loss = lossfunc.gen_loss(discriminator, y_img, y_seg)

            generator.cleargrads()
            loss.backward()
            gen_opt.update()
            loss.unchain_backward()

            sum_loss += loss.data

            if batch == 0:
                serializers.save_npz(f"{modeldir}/generator_{epoch}.model", generator)
                serializers.save_npz(f"{modeldir}/discriminator_{epoch}.model", discriminator)

                with chainer.using_config('train', False):
                    y_img, y_seg = generator(valid_noise)

                y_img = y_img.data.get()
                y_seg = y_seg.data.get()
                evaluator(y_img, y_seg, epoch, outdir, testsize=testsize)

        print(f"epoch: {epoch}")
        print(f"loss: {sum_loss / iterations}")
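
# ---------------------------------------------------------------------------
# Hedged sketch (assumption): SGANLossFunction.gradient_penalty is not shown.
# Based on the interpolation fragments elsewhere in this collection, a
# WGAN-GP style penalty over joint image/segmentation pairs could look like
# the following; `mode` presumably selects a penalty variant, and only the
# WGAN-GP branch is sketched here. `xp` is the scripts' numpy/cupy alias.
def gradient_penalty(self, discriminator, y_img, y_seg, t, s, mode='wgan-gp'):
    eps = xp.random.uniform(0, 1, size=(t.shape[0], 1, 1, 1)).astype(xp.float32)
    img_mid = eps * t + (1 - eps) * y_img
    seg_mid = eps * s + (1 - eps) * y_seg
    y_mid = discriminator(img_mid, seg_mid)
    grad, = chainer.grad([y_mid], [img_mid], enable_double_backprop=True)
    grad = F.sqrt(F.batch_l2_norm_squared(grad))
    # drive the gradient norm toward 1; the weight 10 is an assumption
    return 10.0 * F.mean_squared_error(grad, xp.ones_like(grad.data))
# ---------------------------------------------------------------------------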
x_y.unchain_backward()
y_x.unchain_backward()

dis_fake_xy = discriminator_xy(x_y)
dis_real_xy = discriminator_xy(t)
# dis_loss_xy = F.mean(F.softplus(dis_fake_xy)) + F.mean(F.softplus(-dis_real_xy))
dis_loss_xy = least_square_loss(dis_fake_xy, dis_real_xy)

dis_fake_yx = discriminator_yx(y_x)
dis_real_yx = discriminator_yx(x)
# dis_loss_yx = F.mean(F.softplus(dis_fake_yx)) + F.mean(F.softplus(-dis_real_yx))
dis_loss_yx = least_square_loss(dis_fake_yx, dis_real_yx)

dis_loss = dis_loss_xy + dis_loss_yx

discriminator_xy.cleargrads()
discriminator_yx.cleargrads()
dis_loss.backward()
dis_opt_xy.update()
dis_opt_yx.update()
dis_loss.unchain_backward()

x_y = generator_xy(x)
x_y_x = generator_yx(x_y)
y_x = generator_yx(t)
y_x_y = generator_xy(y_x)

dis_fake_xy = discriminator_xy(x_y)
dis_fake_yx = discriminator_yx(y_x)
def train(epochs, iterations, dataset_path, test_path, outdir, batchsize,
          testsize, recon_weight, fm_weight, gp_weight, spectral_norm=False):
    # Dataset Definition
    dataloader = DatasetLoader(dataset_path, test_path)
    c_valid, s_valid = dataloader.test(testsize)

    # Model & Optimizer Definition
    if spectral_norm:
        generator = SNGenerator()
    else:
        generator = Generator()
    generator.to_gpu()
    gen_opt = set_optimizer(generator)

    discriminator = Discriminator()
    discriminator.to_gpu()
    dis_opt = set_optimizer(discriminator)

    # Loss Function Definition
    lossfunc = FUNITLossFunction()

    # Evaluator Definition
    evaluator = Evaluation()

    for epoch in range(epochs):
        sum_loss = 0
        for batch in range(0, iterations, batchsize):
            c, ci, s, si = dataloader.train(batchsize)

            y = generator(c, s)
            y.unchain_backward()

            loss = lossfunc.dis_loss(discriminator, y, s, si)
            loss += lossfunc.gradient_penalty(discriminator, s, y, si)

            discriminator.cleargrads()
            loss.backward()
            dis_opt.update()
            loss.unchain_backward()

            y_convert = generator(c, s)
            y_recon = generator(c, c)

            adv_loss, recon_loss, fm_loss = lossfunc.gen_loss(
                discriminator, y_convert, y_recon, s, c, si, ci)
            loss = adv_loss + recon_weight * recon_loss + fm_weight * fm_loss

            generator.cleargrads()
            loss.backward()
            gen_opt.update()
            loss.unchain_backward()

            sum_loss += loss.data

            if batch == 0:
                serializers.save_npz('generator.model', generator)
                serializers.save_npz('discriminator.model', discriminator)

                with chainer.using_config('train', False):
                    y = generator(c_valid, s_valid)
                y.unchain_backward()

                y = y.data.get()
                c = c_valid.data.get()
                s = s_valid.data.get()
                evaluator(y, c, s, outdir, epoch, testsize)

        print(f"epoch: {epoch}")
        print(f"loss: {sum_loss / iterations}")
def train(epochs, iterations, batchsize, modeldir, extension, time_width,
          mel_bins, sampling_rate, g_learning_rate, d_learning_rate, beta1,
          beta2, identity_epoch, adv_type, residual_flag, data_path):
    # Dataset Definition
    dataloader = DatasetLoader(data_path)

    # Model & Optimizer Definition
    generator = GeneratorWithCIN(adv_type=adv_type)
    generator.to_gpu()
    gen_opt = set_optimizer(generator, g_learning_rate, beta1, beta2)

    discriminator = Discriminator()
    discriminator.to_gpu()
    dis_opt = set_optimizer(discriminator, d_learning_rate, beta1, beta2)

    # Loss Function Definition
    lossfunc = StarGANVC2LossFunction()

    for epoch in range(epochs):
        sum_dis_loss = 0
        sum_gen_loss = 0
        for batch in range(0, iterations, batchsize):
            x_sp, x_label, y_sp, y_label = dataloader.train(batchsize)

            if adv_type == 'sat':
                y_fake = generator(x_sp, F.concat([y_label, x_label]))
            elif adv_type == 'orig':
                y_fake = generator(x_sp, y_label)
            else:
                raise AttributeError

            y_fake.unchain_backward()

            if adv_type == 'sat':
                advloss_dis_real, advloss_dis_fake = lossfunc.dis_loss(
                    discriminator, y_fake, x_sp, F.concat([y_label, x_label]),
                    F.concat([x_label, y_label]), residual_flag)
            elif adv_type == 'orig':
                advloss_dis_real, advloss_dis_fake = lossfunc.dis_loss(
                    discriminator, y_fake, x_sp, y_label, x_label, residual_flag)
            else:
                raise AttributeError

            dis_loss = advloss_dis_real + advloss_dis_fake

            discriminator.cleargrads()
            dis_loss.backward()
            dis_opt.update()
            dis_loss.unchain_backward()

            if adv_type == 'sat':
                y_fake = generator(x_sp, F.concat([y_label, x_label]))
                x_fake = generator(y_fake, F.concat([x_label, y_label]))
                x_identity = generator(x_sp, F.concat([x_label, x_label]))
                advloss_gen_fake, cycle_loss = lossfunc.gen_loss(
                    discriminator, y_fake, x_fake, x_sp,
                    F.concat([y_label, x_label]), residual_flag)
            elif adv_type == 'orig':
                y_fake = generator(x_sp, y_label)
                x_fake = generator(y_fake, x_label)
                x_identity = generator(x_sp, x_label)
                advloss_gen_fake, cycle_loss = lossfunc.gen_loss(
                    discriminator, y_fake, x_fake, x_sp, y_label, residual_flag)
            else:
                raise AttributeError

            if epoch < identity_epoch:
                identity_loss = lossfunc.identity_loss(x_identity, x_sp)
            else:
                identity_loss = call_zeros(advloss_dis_fake)

            gen_loss = advloss_gen_fake + cycle_loss + identity_loss

            generator.cleargrads()
            gen_loss.backward()
            gen_opt.update()
            gen_loss.unchain_backward()

            sum_dis_loss += dis_loss.data
            sum_gen_loss += gen_loss.data

            if batch == 0:
                serializers.save_npz(f"{modeldir}/generator_{epoch}.model", generator)

        print(f"epoch: {epoch}")
        print(f"dis loss: {sum_dis_loss / iterations} gen loss: {sum_gen_loss / iterations}")
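
# ---------------------------------------------------------------------------
# Hedged sketch (assumption): call_zeros is undefined in this source. From
# its use above (substituting a zero identity loss once epoch reaches
# identity_epoch) it plausibly returns a zero Variable shaped like its
# argument; `xp` is the scripts' numpy/cupy alias.
def call_zeros(tensor):
    return chainer.as_variable(xp.zeros_like(tensor.data))
# ---------------------------------------------------------------------------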
rnd_x = xp.random.uniform(0, 1, x_dis.shape).astype(xp.float32)
x_perturbed = Variable(cuda.to_gpu(x_dis + 0.5 * rnd_x * std_data))
x_dis = Variable(cuda.to_gpu(x_dis))

y_dis = dis_model(x_dis, Variable(t_dis))
dis_loss += F.mean(F.softplus(-y_dis))

y_perturbed = dis_model(x_perturbed, Variable(t_dis))
grad, = chainer.grad([y_perturbed], [x_perturbed], enable_double_backprop=True)
grad = F.sqrt(F.batch_l2_norm_squared(grad))
loss_grad = lambda1 * F.mean_squared_error(grad, xp.ones_like(grad.data))

dis_loss += loss_grad

dis_model.cleargrads()
dis_loss.backward()
dis_loss.unchain_backward()
dis_opt.update()

z = Variable(xp.random.normal(size=(batchsize, 128), dtype=xp.float32))
label = cuda.to_gpu(get_fake_tag_batch(batchsize, dims, threshold))
z = F.concat([z, Variable(label)])

x = gen_model(z)
y = dis_model(x, Variable(label))
gen_loss = F.mean(F.softplus(-y))

gen_model.cleargrads()
gen_loss.backward()
gen_loss.unchain_backward()
gen_opt.update()
def main():
    parser = argparse.ArgumentParser(description='')
    parser.add_argument('out')
    parser.add_argument('--gpu', '-g', type=int, default=0, help='GPU device ID')
    parser.add_argument('--epoch', '-e', type=int, default=200, help='# of epoch')
    parser.add_argument('--batch_size', '-b', type=int, default=10)
    parser.add_argument('--memory_size', '-m', type=int, default=500)
    parser.add_argument('--real_label', type=float, default=0.9)
    parser.add_argument('--fake_label', type=float, default=0.0)
    parser.add_argument('--block_num', type=int, default=6)
    parser.add_argument('--g_nobn', dest='g_bn', action='store_false', default=True)
    parser.add_argument('--d_nobn', dest='d_bn', action='store_false', default=True)
    parser.add_argument('--variable_size', action='store_true', default=False)
    parser.add_argument('--lambda_dis_real', type=float, default=0)
    parser.add_argument('--size', type=int, default=128)
    parser.add_argument('--lambda_', type=float, default=10)
    # args = parser.parse_args()
    args, unknown = parser.parse_known_args()

    # log directory
    out = datetime.datetime.now().strftime('%m%d%H')
    out = out + '_' + args.out
    out_dir = os.path.abspath(os.path.join(os.path.curdir, "runs", out))
    os.makedirs(os.path.join(out_dir, 'models'), exist_ok=True)
    os.makedirs(os.path.join(out_dir, 'visualize'), exist_ok=True)

    # hyperparameters
    with open(os.path.join(out_dir, 'setting.txt'), 'w') as f:
        for k, v in args._get_kwargs():
            print('{} = {}'.format(k, v))
            f.write('{} = {}\n'.format(k, v))

    trainA = ImageDataset('horse2zebra/trainA', augmentation=True,
                          image_size=256, final_size=args.size)
    trainB = ImageDataset('horse2zebra/trainB', augmentation=True,
                          image_size=256, final_size=args.size)
    testA = ImageDataset('horse2zebra/testA', image_size=256, final_size=args.size)
    testB = ImageDataset('horse2zebra/testB', image_size=256, final_size=args.size)

    train_iterA = chainer.iterators.MultiprocessIterator(
        trainA, args.batch_size, n_processes=min(8, args.batch_size))
    train_iterB = chainer.iterators.MultiprocessIterator(
        trainB, args.batch_size, n_processes=min(8, args.batch_size))
    N = len(trainA)

    # genA converts B -> A, genB converts A -> B
    genA = Generator(block_num=args.block_num, bn=args.g_bn)
    genB = Generator(block_num=args.block_num, bn=args.g_bn)
    # disA discriminates realA from fakeA, disB discriminates realB from fakeB
    disA = Discriminator(bn=args.d_bn)
    disB = Discriminator(bn=args.d_bn)

    if args.gpu >= 0:
        cuda.get_device_from_id(args.gpu).use()
        genA.to_gpu()
        genB.to_gpu()
        disA.to_gpu()
        disB.to_gpu()

    optimizer_genA = chainer.optimizers.Adam(alpha=0.0002, beta1=0.5, beta2=0.9)
    optimizer_genB = chainer.optimizers.Adam(alpha=0.0002, beta1=0.5, beta2=0.9)
    optimizer_disA = chainer.optimizers.Adam(alpha=0.0002, beta1=0.5, beta2=0.9)
    optimizer_disB = chainer.optimizers.Adam(alpha=0.0002, beta1=0.5, beta2=0.9)
    optimizer_genA.setup(genA)
    optimizer_genB.setup(genB)
    optimizer_disA.setup(disA)
    optimizer_disB.setup(disB)

    # start training
    start = time.time()
    fake_poolA = np.zeros(
        (args.memory_size, 3, args.size, args.size)).astype('float32')
    fake_poolB = np.zeros(
        (args.memory_size, 3, args.size, args.size)).astype('float32')
    lambda_ = args.lambda_

    const_realA = np.asarray([testA.get_example(i) for i in range(10)])
    const_realB = np.asarray([testB.get_example(i) for i in range(10)])

    iterations = 0
    for epoch in range(args.epoch):
        # linearly decay the learning rate after epoch 100
        if epoch > 100:
            decay_rate = 0.0002 / 100
            optimizer_genA.alpha -= decay_rate
            optimizer_genB.alpha -= decay_rate
            optimizer_disA.alpha -= decay_rate
            optimizer_disB.alpha -= decay_rate

        # train
        iter_num = N // args.batch_size
        for i in range(iter_num):
            # load real batch
            imagesA = train_iterA.next()
            imagesB = train_iterB.next()
            if args.variable_size:
                crop_size = np.random.choice([160, 192, 224, 256])
                resize_size = np.random.choice([160, 192, 224, 256])
                imagesA = [
                    random_augmentation(image, crop_size, resize_size)
                    for image in imagesA
                ]
                imagesB = [
                    random_augmentation(image, crop_size, resize_size)
                    for image in imagesB
                ]
            realA = chainer.Variable(genA.xp.asarray(imagesA, 'float32'))
            realB = chainer.Variable(genB.xp.asarray(imagesB, 'float32'))

            # load fake batch
            if iterations < args.memory_size:
                fakeA = genA(realB)
                fakeB = genB(realA)
                fakeA.unchain_backward()
                fakeB.unchain_backward()
            else:
                fake_imagesA = fake_poolA[np.random.randint(
                    args.memory_size, size=args.batch_size)]
                fake_imagesB = fake_poolB[np.random.randint(
                    args.memory_size, size=args.batch_size)]
                if args.variable_size:
                    fake_imagesA = [
                        random_augmentation(image, crop_size, resize_size)
                        for image in fake_imagesA
                    ]
                    fake_imagesB = [
                        random_augmentation(image, crop_size, resize_size)
                        for image in fake_imagesB
                    ]
                fakeA = chainer.Variable(genA.xp.asarray(fake_imagesA))
                fakeB = chainer.Variable(genA.xp.asarray(fake_imagesB))

            ############################
            # (1) Update D network
            ############################
            # dis A
            y_realA = disA(realA)
            y_fakeA = disA(fakeA)
            loss_disA = (F.sum((y_realA - args.real_label) ** 2)
                         + F.sum((y_fakeA - args.fake_label) ** 2)) \
                / np.prod(y_fakeA.shape)
            # dis B
            y_realB = disB(realB)
            y_fakeB = disB(fakeB)
            loss_disB = (F.sum((y_realB - args.real_label) ** 2)
                         + F.sum((y_fakeB - args.fake_label) ** 2)) \
                / np.prod(y_fakeB.shape)

            # discriminate realA and realB, not only realA and fakeA
            if args.lambda_dis_real > 0:
                y_realB = disA(realB)
                loss_disA += F.sum(
                    (y_realB - args.fake_label) ** 2) / np.prod(y_realB.shape)
                y_realA = disB(realA)
                loss_disB += F.sum(
                    (y_realA - args.fake_label) ** 2) / np.prod(y_realA.shape)

            # update dis
            disA.cleargrads()
            disB.cleargrads()
            loss_disA.backward()
            loss_disB.backward()
            optimizer_disA.update()
            optimizer_disB.update()

            ###########################
            # (2) Update G network
            ###########################
            # gan A
            fakeA = genA(realB)
            y_fakeA = disA(fakeA)
            loss_ganA = F.sum(
                (y_fakeA - args.real_label) ** 2) / np.prod(y_fakeA.shape)
            # gan B
            fakeB = genB(realA)
            y_fakeB = disB(fakeB)
            loss_ganB = F.sum(
                (y_fakeB - args.real_label) ** 2) / np.prod(y_fakeB.shape)
            # rec A
            recA = genA(fakeB)
            loss_recA = F.mean_absolute_error(recA, realA)
            # rec B
            recB = genB(fakeA)
            loss_recB = F.mean_absolute_error(recB, realB)

            # gen loss
            loss_gen = loss_ganA + loss_ganB + lambda_ * (loss_recA + loss_recB)
            # loss_genB = loss_ganB + lambda_ * (loss_recB + loss_recA)

            # update gen
            genA.cleargrads()
            genB.cleargrads()
            loss_gen.backward()
            # loss_genB.backward()
            optimizer_genA.update()
            optimizer_genB.update()

            # logging
            logger.plot('loss dis A', float(loss_disA.data))
            logger.plot('loss dis B', float(loss_disB.data))
            logger.plot('loss rec A', float(loss_recA.data))
            logger.plot('loss rec B', float(loss_recB.data))
            logger.plot('loss gen A', float(loss_gen.data))
            # logger.plot('loss gen B', float(loss_genB.data))
            logger.tick()

            # save to replay buffer
            fakeA = cuda.to_cpu(fakeA.data)
            fakeB = cuda.to_cpu(fakeB.data)
            for k in range(args.batch_size):
                fake_sampleA = fakeA[k]
                fake_sampleB = fakeB[k]
                if args.variable_size:
                    fake_sampleA = cv2.resize(
                        fake_sampleA.transpose(1, 2, 0), (256, 256),
                        interpolation=cv2.INTER_AREA).transpose(2, 0, 1)
                    fake_sampleB = cv2.resize(
                        fake_sampleB.transpose(1, 2, 0), (256, 256),
                        interpolation=cv2.INTER_AREA).transpose(2, 0, 1)
                fake_poolA[(iterations * args.batch_size) % args.memory_size + k] = fake_sampleA
                fake_poolB[(iterations * args.batch_size) % args.memory_size + k] = fake_sampleB

            iterations += 1
            progress_report(iterations, start, args.batch_size)

        if epoch % 5 == 0:
            logger.flush(out_dir)
            visualize(genA, genB, const_realA, const_realB,
                      epoch=epoch, savedir=os.path.join(out_dir, 'visualize'))
            serializers.save_hdf5(
                os.path.join(out_dir, "models", "{:03d}.disA.model".format(epoch)), disA)
            serializers.save_hdf5(
                os.path.join(out_dir, "models", "{:03d}.disB.model".format(epoch)), disB)
            serializers.save_hdf5(
                os.path.join(out_dir, "models", "{:03d}.genA.model".format(epoch)), genA)
            serializers.save_hdf5(
                os.path.join(out_dir, "models", "{:03d}.genB.model".format(epoch)), genB)
def train(epochs, iterations, batchsize, validsize, outdir, modeldir, extension,
          train_size, valid_size, data_path, sketch_path, digi_path,
          learning_rate, beta1, weight_decay):
    # Dataset definition
    dataset = DatasetLoader(data_path, sketch_path, digi_path, extension,
                            train_size, valid_size)
    print(dataset)
    x_val, t_val = dataset.valid(validsize)

    # Model & Optimizer definition
    unet = UNet()
    unet.to_gpu()
    unet_opt = set_optimizer(unet, learning_rate, beta1, weight_decay)

    discriminator = Discriminator()
    discriminator.to_gpu()
    dis_opt = set_optimizer(discriminator, learning_rate, beta1, weight_decay)

    # Loss function definition
    lossfunc = Pix2pixLossCalculator()

    # Visualization definition
    visualizer = Visualizer()

    for epoch in range(epochs):
        sum_dis_loss = 0
        sum_gen_loss = 0
        for batch in range(0, iterations, batchsize):
            x, t = dataset.train(batchsize)

            # Discriminator update
            y = unet(x)
            y.unchain_backward()

            dis_loss = lossfunc.dis_loss(discriminator, y, t)

            discriminator.cleargrads()
            dis_loss.backward()
            dis_opt.update()

            sum_dis_loss += dis_loss.data

            # Generator update
            y = unet(x)

            gen_loss = lossfunc.gen_loss(discriminator, y)
            gen_loss += lossfunc.content_loss(y, t)

            unet.cleargrads()
            gen_loss.backward()
            unet_opt.update()

            sum_gen_loss += gen_loss.data

            if batch == 0:
                serializers.save_npz(f"{modeldir}/unet_{epoch}.model", unet)

                with chainer.using_config("train", False):
                    y = unet(x_val)

                x = x_val.data.get()
                t = t_val.data.get()
                y = y.data.get()
                visualizer(x, t, y, outdir, epoch, validsize)

        print(f"epoch: {epoch}")
        print(f"dis loss: {sum_dis_loss/iterations} gen loss: {sum_gen_loss/iterations}")
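
# ---------------------------------------------------------------------------
# Hedged sketch (assumption): Pix2pixLossCalculator is not shown in these
# sources. A plausible reconstruction of the three methods called above, using
# the softplus adversarial form seen elsewhere in this collection and the
# standard pix2pix L1 content term; the weight 100 is an assumption.
class Pix2pixLossCalculator:
    def dis_loss(self, discriminator, fake, real):
        return (F.mean(F.softplus(discriminator(fake)))
                + F.mean(F.softplus(-discriminator(real))))

    def gen_loss(self, discriminator, fake):
        return F.mean(F.softplus(-discriminator(fake)))

    def content_loss(self, fake, real, weight=100.0):
        return weight * F.mean_absolute_error(fake, real)
# ---------------------------------------------------------------------------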
def train(epochs, iterations, batchsize, testsize, outdir, modeldir, n_dis,
          img_path, tag_path):
    # Dataset Definition
    dataloader = DatasetLoader(img_path, tag_path)
    zvis_valid, ztag_valid = dataloader.valid(batchsize)
    noise_valid = F.concat([zvis_valid, ztag_valid])

    # Model & Optimizer Definition
    generator = Generator()
    generator.to_gpu()
    gen_opt = set_optimizer(generator)

    discriminator = Discriminator()
    discriminator.to_gpu()
    dis_opt = set_optimizer(discriminator)

    # Loss Function Definition
    lossfunc = RGANLossFunction()

    # Evaluation
    evaluator = Evaluation()

    for epoch in range(epochs):
        sum_loss = 0
        for batch in range(0, iterations, batchsize):
            for _ in range(n_dis):
                zvis, ztag, img, tag = dataloader.train(batchsize)

                y = generator(F.concat([zvis, ztag]))
                y.unchain_backward()

                loss = lossfunc.dis_loss(discriminator, y, img, tag, ztag)
                loss += lossfunc.gradient_penalty(discriminator, img, tag)

                discriminator.cleargrads()
                loss.backward()
                dis_opt.update()
                loss.unchain_backward()

            zvis, ztag, _, _ = dataloader.train(batchsize)
            y = generator(F.concat([zvis, ztag]))
            loss = lossfunc.gen_loss(discriminator, y, ztag)

            generator.cleargrads()
            loss.backward()
            gen_opt.update()
            loss.unchain_backward()

            sum_loss += loss.data

            if batch == 0:
                serializers.save_npz(f"{modeldir}/generator_{epoch}.model", generator)
                serializers.save_npz(f"{modeldir}/discriminator_{epoch}.model", discriminator)

                with chainer.using_config('train', False):
                    y = generator(noise_valid)

                y = y.data.get()
                evaluator(y, outdir, epoch, testsize)

        print(f"epoch: {epoch}")
        print(f"loss: {sum_loss / iterations}")
def main(args):
    # initialize models and load the MNIST dataset
    G = Generator()
    D = Discriminator()
    x = load_dataset()

    # build the optimizer of the generator
    opt_generator = chainer.optimizers.Adam().setup(G)
    opt_generator.use_cleargrads()

    # build the optimizer of the discriminator
    opt_discriminator = chainer.optimizers.Adam().setup(D)
    opt_discriminator.use_cleargrads()

    # make the output folder
    if not os.path.exists(args.output):
        os.makedirs(args.output, exist_ok=True)

    # lists of losses
    Glosses = []
    Dlosses = []

    print("Now starting training loop...")

    # begin the training process
    for train_iter in range(1, args.num_epochs + 1):
        for i in range(0, len(x), 100):
            # Clear all gradient arrays. This should be called before the
            # backward computation at every iteration of the optimization.
            G.cleargrads()
            D.cleargrads()

            # Train the generator
            noise_samples = sample(100)
            Gloss = 0.5 * F.sum(F.square(D(G(np.asarray(noise_samples))) - 1))
            Gloss.backward()
            opt_generator.update()

            # As above
            G.cleargrads()
            D.cleargrads()

            # Train the discriminator
            noise_samples = sample(100)
            Dreal = D(np.asarray(x[i:i + 100]))
            Dgen = D(G(np.asarray(noise_samples)))
            Dloss = 0.5 * F.sum(F.square(Dreal - 1.0)) + 0.5 * F.sum(F.square(Dgen))
            Dloss.backward()
            opt_discriminator.update()

            # save the loss from each batch
            Glosses.append(Gloss.data)
            Dlosses.append(Dloss.data)

        if train_iter % 10 == 0:
            print("epoch {0:04d}".format(train_iter), end=", ")
            print("Gloss: {}".format(Gloss.data), end=", ")
            print("Dloss: {}".format(Dloss.data))
            noise_samples = sample(100)
            print_sample(
                os.path.join(args.output, "epoch_{0:04}.png".format(train_iter)),
                noise_samples, G)

    print("The training process is finished.")
    plotLoss(train_iter, Dlosses, Glosses)
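
# ---------------------------------------------------------------------------
# Hedged sketch (assumption): sample(n) is not defined in this script. For an
# MNIST GAN it plausibly draws n latent vectors; the uniform distribution and
# the latent width of 100 are assumptions.
import numpy as np


def sample(n, latent_dim=100):
    return np.random.uniform(-1, 1, (n, latent_dim)).astype(np.float32)
# ---------------------------------------------------------------------------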
fake.unchain_backward()
fake_2.unchain_backward()
fake_4.unchain_backward()

# LSGAN
# adver_loss = 0.5 * (F.sum((dis_color - 1.0) ** 2) + F.sum(dis_fake ** 2)) / batchsize
# adver_loss += 0.5 * (F.sum((dis2_color - 1.0) ** 2) + F.sum(dis2_fake ** 2)) / batchsize
# adver_loss += 0.5 * (F.sum((dis4_color - 1.0) ** 2) + F.sum(dis4_fake ** 2)) / batchsize

# DCGAN
adver_loss = F.mean(F.softplus(-dis_color)) + F.mean(F.softplus(dis_fake))
adver_loss += F.mean(F.softplus(-dis2_color)) + F.mean(F.softplus(dis2_fake))
adver_loss += F.mean(F.softplus(-dis4_color)) + F.mean(F.softplus(dis4_fake))

discriminator.cleargrads()
discriminator_2.cleargrads()
discriminator_4.cleargrads()
adver_loss.backward()
dis_opt.update()
dis2_opt.update()
dis4_opt.update()
adver_loss.unchain_backward()

fake, _ = global_generator(line)
fake_2 = F.average_pooling_2d(fake, 3, 2, 1)
fake_4 = F.average_pooling_2d(fake_2, 3, 2, 1)

dis_fake, fake_feat = discriminator(F.concat([line, fake]))
dis2_fake, fake_feat2 = discriminator_2(F.concat([line_2, fake_2]))
dis4_fake, fake_feat3 = discriminator_4(F.concat([line_4, fake_4]))
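
# ---------------------------------------------------------------------------
# Hedged sketch (assumption): the three multi-scale discriminators above also
# return intermediate features (fake_feat, fake_feat2, fake_feat3), which
# pix2pixHD-style training matches against features of the real pair. A
# minimal per-scale feature-matching term under that assumption, to be summed
# over the three scales at the call site; the weight is an assumption.
def feature_matching_loss(fake_feat, real_feat, weight=10.0):
    real_feat.unchain_backward()  # no gradient through the real branch
    return weight * F.mean_absolute_error(fake_feat, real_feat)
# ---------------------------------------------------------------------------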
def train_refine(epochs, iterations, batchsize, validsize, data_path, sketch_path,
                 digi_path, st_path, extension, img_size, crop_size, outdir,
                 modeldir, adv_weight, enf_weight):
    # Dataset Definition
    dataloader = RefineDataset(data_path, sketch_path, digi_path, st_path,
                               extension=extension, img_size=img_size,
                               crop_size=crop_size)
    print(dataloader)
    color_valid, line_valid, mask_valid, ds_valid, cm_valid = dataloader(validsize, mode="valid")

    # Model & Optimizer Definition
    generator = SAGeneratorWithGuide(attn_type="sa", base=64, bn=True, activ=F.relu)
    generator.to_gpu()
    gen_opt = set_optimizer(generator)

    discriminator = Discriminator()
    discriminator.to_gpu()
    dis_opt = set_optimizer(discriminator)

    vgg = VGG()
    vgg.to_gpu()
    vgg_opt = set_optimizer(vgg)
    vgg.base.disable_update()

    # Loss Function Definition
    lossfunc = LossCalculator()

    # Evaluation Definition
    evaluator = Evaluation()

    iteration = 0
    for epoch in range(epochs):
        sum_dis_loss = 0
        sum_gen_loss = 0
        for batch in range(0, iterations, batchsize):
            iteration += 1
            color, line, mask, mask_ds, color_mask = dataloader(batchsize)
            line_input = F.concat([line, mask])

            extractor = vgg(color_mask, extract=True)
            extractor = F.average_pooling_2d(extractor, 3, 2, 1)
            extractor.unchain_backward()

            # Discriminator update
            fake, _ = generator(line_input, mask_ds, extractor)

            y_dis = discriminator(fake, extractor)
            t_dis = discriminator(color, extractor)
            loss = adv_weight * lossfunc.dis_hinge_loss(y_dis, t_dis)

            fake.unchain_backward()

            discriminator.cleargrads()
            loss.backward()
            dis_opt.update()
            loss.unchain_backward()

            sum_dis_loss += loss.data

            # Generator update
            fake, guide = generator(line_input, mask_ds, extractor)

            y_dis = discriminator(fake, extractor)
            loss = adv_weight * lossfunc.gen_hinge_loss(y_dis)
            loss += lossfunc.content_loss(fake, color)
            loss += 0.9 * lossfunc.content_loss(guide, color)

            generator.cleargrads()
            loss.backward()
            gen_opt.update()
            loss.unchain_backward()

            sum_gen_loss += loss.data

            if batch == 0:
                serializers.save_npz(f"{modeldir}/generator_{epoch}.model", generator)

                extractor = vgg(cm_valid, extract=True)
                extractor = F.average_pooling_2d(extractor, 3, 2, 1)
                extractor.unchain_backward()

                line_valid_input = F.concat([line_valid, mask_valid])
                with chainer.using_config('train', False):
                    y_valid, guide_valid = generator(line_valid_input, ds_valid, extractor)

                y_valid = y_valid.data.get()
                c_valid = color_valid.data.get()
                input_valid = line_valid_input.data.get()
                cm_val = cm_valid.data.get()
                guide_valid = guide_valid.data.get()
                input_valid = np.concatenate([input_valid[:, 3:6], cm_val], axis=1)
                evaluator(y_valid, c_valid, input_valid, guide_valid, outdir, epoch, validsize)

        print(f"iter: {iteration} dis loss: {sum_dis_loss} gen loss: {sum_gen_loss}")
class Train:
    def __init__(self):
        self.data = dataset()
        self.data.reset()
        self.reset()
        # self.load(1)
        self.setLR()
        self.time = time.time()
        self.dataRate = xp.float32(0.8)
        self.mado = xp.hanning(442).astype(xp.float32)  # analysis window
        # n = 10
        # load_npz(f"param/gen/gen_{n}.npz", self.generator)
        # load_npz(f"param/dis/dis_{n}.npz", self.discriminator)
        self.training(batchsize=6)

    def reset(self):
        self.generator = None
        self.discriminator = None
        self.generator = Generator()
        self.discriminator = Discriminator()
        self.generator.to_gpu()
        self.discriminator.to_gpu()

    def setLR(self, lr=0.002):
        self.gen_opt = optimizers.Adam(alpha=lr)
        self.gen_opt.setup(self.generator)
        self.gen_opt.add_hook(optimizer.WeightDecay(0.0001))
        self.dis_opt = optimizers.Adam(alpha=lr)
        self.dis_opt.setup(self.discriminator)
        self.dis_opt.add_hook(optimizer.WeightDecay(0.0001))

    # def save(self, i):
    #     with open(f"param/com/com{i}.pickle", mode='wb') as f:
    #         pickle.dump(self.compressor, f)
    #     with open(f"param/gen/gen{i}.pickle", mode='wb') as f:
    #         pickle.dump(self.generator, f)
    #     with open(f"param/dis/dis{i}.pickle", mode='wb') as f:
    #         pickle.dump(self.discriminator, f)

    # def load(self, i):
    #     with open(f"param/com/com{i}.pickle", mode='rb') as f:
    #         self.compressor = pickle.load(f)
    #     with open(f"param/gen/gen{i}.pickle", mode='rb') as f:
    #         self.generator = pickle.load(f)
    #     with open(f"param/dis/dis{i}.pickle", mode='rb') as f:
    #         self.discriminator = pickle.load(f)

    def encode(self, x):
        # split the waveform into half-overlapping 442-sample frames,
        # apply the Hanning window, and take the FFT
        a, b, c = x.shape
        x = x.reshape(a, 1, c).astype(xp.float32)
        x = xp.concatenate([
            x[:, :, :-221].reshape(a, -1, 1, 442),
            x[:, :, 221:].reshape(a, -1, 1, 442)
        ], axis=2).reshape(a, -1, 442) * self.mado
        x = xp.fft.fft(x, axis=-1)
        # stack real and imaginary parts as two channels
        x = xp.concatenate(
            [x.real.reshape(a, 1, -1, 442), x.imag.reshape(a, 1, -1, 442)],
            axis=1)
        x = xp.transpose(x, axes=(0, 1, 3, 2))
        return x

    def decode(self, x):
        # inverse of encode: IFFT, un-window, overlap-add
        a, b, c, d = x.shape
        x = x[:, 0] + x[:, 1] * 1j
        x = xp.transpose(xp.fft.ifft(x, axis=1).real, axes=(0, 2, 1))
        x /= self.mado
        x = x[:, :-1:2].reshape(a, -1)[:, 221:] + x[:, 1::2].reshape(a, -1)[:, :-221]
        return x

    def training(self, batchsize=1):
        for x in range(100):
            N = self.data.reset()
            for i in range(N // batchsize - 1):
                res = self.batch(batchsize=batchsize)
                if not i % 10:
                    print(
                        f"{i} time:{int(time.time() - self.time)} "
                        f"G_Loss:{res[0][0]} {res[0][1]} "
                        f"D_Loss:{res[1][0] + res[1][1]} D_Acc:{res[2]}"
                    )
                if not i % 100:
                    # save_npz(f"param/com/com_{i}.npz", self.compressor)
                    save_npz(f"param/gen/gen_{i}.npz", self.generator)
                    save_npz(f"param/dis/dis_{i}.npz", self.discriminator)
                    a = xp.asarray(self.data.testData[0][:88200].reshape(1, 1, 1, -1))
                    d = self.generator(a, xp.array([110])).data.get()
                    self.data.save(d.flatten(), f"Garagara_{i}")

    def batch(self, batchsize=2):
        x, c = self.data.next(batchSize=batchsize, dataSize=[8190], dataSelect=[0])
        x = x[0].reshape(batchsize, 1, 1, -1)
        c = xp.asarray(c[0])
        # draw a random target class guaranteed to differ from the source class
        c_ = xp.random.randint(0, 111, batchsize)
        c_ = c_ + (c_ >= c)

        A_gen = self.generator(x, c_)
        B_gen = self.generator(x, c)
        F_tf, F_c = self.discriminator(A_gen[:, :, :, 5119:])
        T_tf, T_c = self.discriminator(x[:, :, :, 2047:-5119])
        dis_acc = (F.argmax(F_tf, axis=1).data.sum(),
                   xp.int32(batchsize) - F.argmax(T_tf, axis=1).data.sum(),
                   (T_c.data.argmax(axis=-1) == c).sum())
        # self.dataRate = self.dataRate if dis_acc[0] == dis_acc[1] else \
        #     self.dataRate / xp.float32(0.99) if dis_acc[0] > dis_acc[1] else \
        #     self.dataRate * xp.float32(0.99)

        L_gen0 = F.mean_squared_error(B_gen, x[:, :, :, 1023:-1024])
        L_gen1 = F.softmax_cross_entropy(F_tf, xp.zeros(batchsize, dtype=np.int32))
        L_gen2 = F.softmax_cross_entropy(F_c, c_)
        gen_loss = (L_gen0.data, L_gen1.data)
        L_gen = L_gen1 + L_gen0 + L_gen2

        L_dis0 = F.softmax_cross_entropy(F_tf, xp.ones(batchsize, dtype=np.int32))
        L_dis1 = F.softmax_cross_entropy(T_tf, xp.zeros(batchsize, dtype=np.int32))
        L_dis2 = F.softmax_cross_entropy(T_c, c)
        dis_loss = (L_dis0.data.get(), L_dis1.data.get(), L_dis2.data.get())
        # L_dis = L_dis0 * min(xp.float32(1), 1 / self.dataRate) + L_dis1 * min(xp.float32(1), self.dataRate)
        L_dis = L_dis0 + L_dis1 + L_dis2

        self.generator.cleargrads()
        L_gen.backward()
        self.gen_opt.update()
        self.discriminator.cleargrads()
        L_dis.backward()
        self.dis_opt.update()

        # slow exponential learning-rate decay
        self.dis_opt.alpha *= 0.99999
        self.gen_opt.alpha *= 0.99999

        return (gen_loss, dis_loss, dis_acc, self.dataRate, (F_tf.data, T_tf.data))

    def garagara(self):
        pass
def train(epochs, iterations, batchsize, validsize, data_path, sketch_path,
          digi_path, extension, img_size, outdir, modeldir, pretrained_epoch,
          adv_weight, enf_weight, sn, bn, activ):
    # Dataset Definition
    dataloader = DataLoader(data_path, sketch_path, digi_path,
                            extension=extension, img_size=img_size)
    print(dataloader)
    color_valid, line_valid, mask_valid, ds_valid = dataloader(validsize, mode="valid")

    # Model & Optimizer Definition
    generator = SAGeneratorWithGuide(attn_type="sa", bn=bn, activ=activ)
    # generator = SAGenerator(attn_type="sa", base=64)
    generator.to_gpu()
    gen_opt = set_optimizer(generator)

    discriminator = Discriminator(sn=sn)
    discriminator.to_gpu()
    dis_opt = set_optimizer(discriminator)

    vgg = VGG()
    vgg.to_gpu()
    vgg_opt = set_optimizer(vgg)
    vgg.base.disable_update()

    # Loss Function Definition
    lossfunc = LossCalculator()

    # Evaluation Definition
    evaluator = Evaluation()

    for epoch in range(epochs):
        sum_loss = 0
        for batch in range(0, iterations, batchsize):
            color, line, mask, mask_ds = dataloader(batchsize)
            line_input = F.concat([line, mask])

            extractor = vgg(mask, extract=True)
            extractor = F.average_pooling_2d(extractor, 3, 2, 1)
            extractor.unchain_backward()

            if epoch > pretrained_epoch:
                adv_weight = 0.1
                enf_weight = 0.0

            # Discriminator update
            fake, _ = generator(line_input, mask_ds, extractor)

            y_dis = discriminator(fake, extractor)
            t_dis = discriminator(color, extractor)
            loss = adv_weight * lossfunc.dis_hinge_loss(y_dis, t_dis)

            fake.unchain_backward()

            discriminator.cleargrads()
            loss.backward()
            dis_opt.update()
            loss.unchain_backward()

            # Generator update
            fake, guide = generator(line_input, mask_ds, extractor)

            y_dis = discriminator(fake, extractor)
            loss = adv_weight * lossfunc.gen_hinge_loss(y_dis)
            loss += enf_weight * lossfunc.positive_enforcing_loss(fake)
            loss += lossfunc.content_loss(fake, color)
            loss += 0.9 * lossfunc.content_loss(guide, color)
            loss += lossfunc.perceptual_loss(vgg, fake, color)

            generator.cleargrads()
            loss.backward()
            gen_opt.update()
            loss.unchain_backward()

            sum_loss += loss.data

            if batch == 0:
                serializers.save_npz(f"{modeldir}/generator_{epoch}.model", generator)

                extractor = vgg(line_valid, extract=True)
                extractor = F.average_pooling_2d(extractor, 3, 2, 1)
                extractor.unchain_backward()

                line_valid_input = F.concat([line_valid, mask_valid])
                with chainer.using_config('train', False):
                    y_valid, guide_valid = generator(line_valid_input, ds_valid, extractor)

                y_valid = y_valid.data.get()
                c_valid = color_valid.data.get()
                input_valid = line_valid_input.data.get()
                guide_valid = guide_valid.data.get()
                evaluator(y_valid, c_valid, input_valid, guide_valid, outdir, epoch, validsize)

        print(f"epoch: {epoch}")
        print(f"loss: {sum_loss / iterations}")
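
# ---------------------------------------------------------------------------
# Hedged sketch (assumption): LossCalculator.perceptual_loss is not shown in
# these sources. Reusing the vgg(x, extract=True) feature interface seen
# above, a minimal version could compare VGG features of the fake and the
# ground-truth color image; the MSE form and unit weight are assumptions.
def perceptual_loss(vgg, fake, real, weight=1.0):
    fake_feat = vgg(fake, extract=True)
    real_feat = vgg(real, extract=True)
    real_feat.unchain_backward()  # no gradient through the target branch
    return weight * F.mean_squared_error(fake_feat, real_feat)
# ---------------------------------------------------------------------------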