def train(epochs, iterations, outdir, path, batchsize, validsize, model_type):
    """Train a super-resolution model with content and VGG perceptual losses.

    Args:
        epochs: Number of training epochs.
        iterations: Per-epoch sample budget; the inner loop steps through
            ``range(0, iterations, batchsize)``.
        outdir: Directory for per-epoch model snapshots and evaluator output.
        path: Dataset location handed to ``DatasetLoader``.
        batchsize: Training mini-batch size.
        validsize: Size of the fixed validation batch drawn once up front.
        model_type: ``'ram'`` builds ``Model()``; ``'gan'`` builds ``Generator()``.

    Raises:
        ValueError: If ``model_type`` is neither ``'ram'`` nor ``'gan'``.
            Previously an unknown value only surfaced later, after the dataset
            had been loaded, as an ``UnboundLocalError`` on ``model``.
    """
    if model_type not in ('ram', 'gan'):
        raise ValueError(f"model_type must be 'ram' or 'gan', got {model_type!r}")

    # Dataset Definition
    dataloader = DatasetLoader(path)
    print(dataloader)
    t_valid, x_valid = dataloader(validsize, mode="valid")

    # Model & Optimizer Definition
    model = Model() if model_type == 'ram' else Generator()
    model.to_gpu()
    optimizer = set_optimizer(model)

    # Network used for the perceptual loss. `vgg.base` has updates disabled,
    # so vgg_opt can only change parameters outside of it.
    vgg = VGG()
    vgg.to_gpu()
    vgg_opt = set_optimizer(vgg)
    vgg.base.disable_update()

    # Loss Function Definition
    lossfunc = RAMLossFunction()
    print(lossfunc)

    # Evaluation Definition
    evaluator = Evaluation()

    for epoch in range(epochs):
        sum_loss = 0
        for batch in range(0, iterations, batchsize):
            t_train, x_train = dataloader(batchsize, mode="train")
            y_train = model(x_train)
            y_feat = vgg(y_train)
            t_feat = vgg(t_train)

            loss = lossfunc.content_loss(y_train, t_train)
            loss += lossfunc.perceptual_loss(y_feat, t_feat)

            model.cleargrads()
            vgg.cleargrads()
            loss.backward()
            optimizer.update()
            vgg_opt.update()
            # Break references to this step's computational graph.
            loss.unchain_backward()

            sum_loss += loss.data

            # Once per epoch: snapshot the model and evaluate it on the fixed
            # validation batch in inference mode.
            if batch == 0:
                serializers.save_npz(f"{outdir}/model_{epoch}.model", model)
                with chainer.using_config('train', False):
                    y_valid = model(x_valid)
                x = x_valid.data.get()
                y = y_valid.data.get()
                t = t_valid.data.get()
                evaluator(x, y, t, epoch, outdir)

        print(f"epoch: {epoch}")
        # NOTE: averaged over `iterations` (samples), not over the number of
        # optimizer steps taken in the inner loop.
        print(f"loss: {sum_loss / iterations}")
def train(epochs, iterations, batchsize, outdir, data_path, identity_epoch=50):
    """Train a CIN-conditioned StarGAN-VC2 generator against a discriminator.

    Each step does one discriminator update on a detached conversion, then one
    generator update. Once per epoch the generator (per-epoch file) and the
    discriminator (single overwritten file) are saved into ``outdir``.

    Args:
        epochs: Number of training epochs.
        iterations: Per-epoch sample budget; the inner loop steps through
            ``range(0, iterations, batchsize)``.
        batchsize: Mini-batch size.
        outdir: Directory receiving the snapshots. This parameter was
            previously ignored in favour of a hard-coded ``"modeldirCIN"``
            directory and the current working directory.
        data_path: Dataset location handed to ``DatasetLoader``.
        identity_epoch: The identity-mapping loss is added only while
            ``epoch < identity_epoch`` (default 50, the former hard-coded value).
    """
    # Dataset Definition
    dataloader = DatasetLoader(data_path)

    # Model & Optimizer Definition
    generator = GeneratorWithCIN()
    generator.to_gpu()
    gen_opt = set_optimizer(generator, alpha=0.0002)

    discriminator = Discriminator()
    discriminator.to_gpu()
    dis_opt = set_optimizer(discriminator, alpha=0.0001)

    # Loss Function Definition
    lossfunc = StarGANVC2LossFunction()

    for epoch in range(epochs):
        sum_loss = 0
        for batch in range(0, iterations, batchsize):
            x_sp, x_label, y_sp, y_label = dataloader.train(batchsize)

            # Discriminator update: the fake is detached so no gradient
            # reaches the generator.
            y_fake = generator(x_sp, F.concat([y_label, x_label]))
            y_fake.unchain_backward()
            loss = lossfunc.dis_loss(discriminator, y_fake, x_sp, y_label, x_label)
            discriminator.cleargrads()
            loss.backward()
            dis_opt.update()
            loss.unchain_backward()

            # Generator update: x -> y conversion, y -> x reconstruction and
            # an x -> x identity pass.
            y_fake = generator(x_sp, F.concat([y_label, x_label]))
            x_fake = generator(y_fake, F.concat([x_label, y_label]))
            x_identity = generator(x_sp, F.concat([x_label, x_label]))
            loss = lossfunc.gen_loss(discriminator, y_fake, x_fake, x_sp,
                                     F.concat([y_label, x_label]))
            if epoch < identity_epoch:
                loss += lossfunc.identity_loss(x_identity, x_sp)
            generator.cleargrads()
            loss.backward()
            gen_opt.update()
            loss.unchain_backward()

            # Only the generator loss is accumulated for reporting.
            sum_loss += loss.data

            if batch == 0:
                serializers.save_npz(f"{outdir}/generator_{epoch}.model", generator)
                serializers.save_npz(f"{outdir}/discriminator.model", discriminator)

        print(f"epoch: {epoch}")
        print(f"loss: {sum_loss / iterations}")
def train(epochs, batchsize, iterations, nc_size, data_path, modeldir,
          identity_epoch=20, identity_weight=10):
    """Train a StarGAN-VC style generator/discriminator pair.

    Args:
        epochs: Number of training epochs.
        batchsize: Mini-batch size.
        iterations: Per-epoch sample budget; the inner loop steps through
            ``range(0, iterations, batchsize)``.
        nc_size: Size handed to ``DatasetLoader``, ``Generator`` and
            ``Discriminator`` (presumably the number of label classes -
            confirm against those classes).
        data_path: Dataset location.
        modeldir: Directory receiving per-epoch generator snapshots.
        identity_epoch: The L1 identity-mapping term is added only while
            ``epoch < identity_epoch`` (default 20, the former hard-coded value).
        identity_weight: Weight of that identity term (default 10, the former
            hard-coded value).
    """
    # Dataset definition
    dataset = DatasetLoader(data_path, nc_size)

    # Model Definition & Optimizer Definition
    generator = Generator(nc_size)
    generator.to_gpu()
    gen_opt = set_optimizer(generator, 0.0001, 0.5)

    discriminator = Discriminator(nc_size)
    discriminator.to_gpu()
    dis_opt = set_optimizer(discriminator, 0.0001, 0.5)

    for epoch in range(epochs):
        sum_gen_loss = 0
        sum_dis_loss = 0
        for batch in range(0, iterations, batchsize):
            x, x_label, y, y_label = dataset.train(batchsize)

            # Discriminator update on a detached fake.
            y_fake = generator(x, y_label)
            y_fake.unchain_backward()
            loss = adversarial_loss_dis(discriminator, y_fake, x, y_label, x_label)
            discriminator.cleargrads()
            loss.backward()
            dis_opt.update()
            loss.unchain_backward()
            sum_dis_loss += loss.data

            # Generator update: adversarial_loss_gen (which also receives the
            # reconstruction x_fake) plus, early on, an L1 identity term.
            y_fake = generator(x, y_label)
            x_fake = generator(y_fake, x_label)
            x_id = generator(x, x_label)
            loss = adversarial_loss_gen(discriminator, y_fake, x_fake, x, y_label)
            if epoch < identity_epoch:
                loss += identity_weight * F.mean_absolute_error(x_id, x)
            generator.cleargrads()
            loss.backward()
            gen_opt.update()
            loss.unchain_backward()
            sum_gen_loss += loss.data

            if batch == 0:
                serializers.save_npz(f"{modeldir}/generator_{epoch}.model", generator)
                serializers.save_npz("discriminator.model", discriminator)

        print(f"epoch: {epoch} disloss: {sum_dis_loss/iterations} genloss: {sum_gen_loss/iterations}")
def train(epochs, iterations, outdir, path, batchsize, validsize):
    """Pre-train the ESRGAN generator with the content loss alone.

    A fixed validation batch is drawn once. At the first step of every epoch
    the generator is written to ``outdir`` and evaluated on that batch in
    inference mode.
    """
    loader = DatasetLoader(path)
    print(loader)
    t_valid, x_valid = loader(validsize, mode="valid")

    net = Generator()
    net.to_gpu()
    opt = set_optimizer(net)

    criterion = ESRGANPretrainLossFunction()
    print(criterion)

    evaluate = Evaluation()

    def snapshot_and_evaluate(ep):
        """Save the generator for epoch *ep* and score the validation batch."""
        serializers.save_npz(f"{outdir}/model_{ep}.model", net)
        with chainer.using_config('train', False):
            prediction = net(x_valid)
        evaluate(x_valid.data.get(), prediction.data.get(), t_valid.data.get(), ep, outdir)

    for epoch in range(epochs):
        running = 0
        for start in range(0, iterations, batchsize):
            t_batch, x_batch = loader(batchsize, mode="train")
            loss = criterion.content_loss(net(x_batch), t_batch)

            net.cleargrads()
            loss.backward()
            opt.update()
            loss.unchain_backward()

            running += loss.data
            if start == 0:
                snapshot_and_evaluate(epoch)

        print(f"epoch: {epoch}")
        print(f"loss: {running / iterations}")
def train(epochs, iterations, batchsize, validsize, src_path, tgt_path,
          extension, img_size, outdir, modeldir, lr_dis, lr_gen, beta1, beta2):
    """Train InstaGAN-style mask-aware translators in both directions.

    Two generators map (image, mask) pairs x -> y and y -> x; two
    discriminators judge each domain. Once per epoch both generators are saved
    to ``modeldir`` and the fixed validation batches are translated, in
    inference mode, and visualised into ``outdir``.

    Args:
        epochs: Number of training epochs.
        iterations: Per-epoch sample budget; the inner loop steps through
            ``range(0, iterations, batchsize)``.
        batchsize: Mini-batch size.
        validsize: Size of the fixed validation batch.
        src_path, tgt_path, extension, img_size: Passed to ``DatasetLoader``.
        outdir: Visualisation output directory.
        modeldir: Snapshot directory.
        lr_dis, lr_gen, beta1, beta2: Optimizer hyper-parameters.
    """
    # Dataset definition
    dataset = DatasetLoader(src_path, tgt_path, extension, img_size)
    print(dataset)
    x_val, x_mask_val, y_val, y_mask_val = dataset.valid(validsize)

    # Model & Optimizer definition
    generator_xy = Generator()
    generator_xy.to_gpu()
    gen_xy_opt = set_optimizer(generator_xy, lr_gen, beta1, beta2)

    generator_yx = Generator()
    generator_yx.to_gpu()
    gen_yx_opt = set_optimizer(generator_yx, lr_gen, beta1, beta2)

    discriminator_y = Discriminator()
    discriminator_y.to_gpu()
    dis_y_opt = set_optimizer(discriminator_y, lr_dis, beta1, beta2)

    discriminator_x = Discriminator()
    discriminator_x.to_gpu()
    dis_x_opt = set_optimizer(discriminator_x, lr_dis, beta1, beta2)

    # Loss Function definition
    lossfunc = InstaGANLossFunction()

    # Visualizer definition
    visualize = Visualizer()

    for epoch in range(epochs):
        sum_gen_loss = 0
        sum_dis_loss = 0
        for batch in range(0, iterations, batchsize):
            x, x_mask, y, y_mask = dataset.train(batchsize)

            # discriminator update (translations detached from the generators)
            xy, xy_mask = generator_xy(x, x_mask)
            yx, yx_mask = generator_yx(y, y_mask)
            xy.unchain_backward()
            xy_mask.unchain_backward()
            yx.unchain_backward()
            yx_mask.unchain_backward()
            dis_loss = lossfunc.adversarial_dis_loss(discriminator_y, xy, xy_mask, y, y_mask)
            dis_loss += lossfunc.adversarial_dis_loss(discriminator_x, yx, yx_mask, x, x_mask)
            discriminator_y.cleargrads()
            discriminator_x.cleargrads()
            dis_loss.backward()
            dis_y_opt.update()
            dis_x_opt.update()
            sum_dis_loss += dis_loss.data

            # generator update: adversarial, cycle-consistency, identity and
            # context-preserving terms for both directions
            xy, xy_mask = generator_xy(x, x_mask)
            yx, yx_mask = generator_yx(y, y_mask)
            xyx, xyx_mask = generator_yx(xy, xy_mask)
            yxy, yxy_mask = generator_xy(yx, yx_mask)
            x_id, x_mask_id = generator_yx(x, x_mask)
            y_id, y_mask_id = generator_xy(y, y_mask)
            gen_loss = lossfunc.adversarial_gen_loss(discriminator_y, xy, xy_mask)
            gen_loss += lossfunc.adversarial_gen_loss(discriminator_x, yx, yx_mask)
            gen_loss += lossfunc.cycle_consistency_loss(xyx, xyx_mask, x, x_mask)
            gen_loss += lossfunc.cycle_consistency_loss(yxy, yxy_mask, y, y_mask)
            gen_loss += lossfunc.identity_mapping_loss(x_id, x_mask_id, x, x_mask)
            gen_loss += lossfunc.identity_mapping_loss(y_id, y_mask_id, y, y_mask)
            gen_loss += lossfunc.context_preserving_loss(xy, xy_mask, x, x_mask)
            gen_loss += lossfunc.context_preserving_loss(yx, yx_mask, y, y_mask)
            generator_xy.cleargrads()
            generator_yx.cleargrads()
            gen_loss.backward()
            gen_xy_opt.update()
            gen_yx_opt.update()
            sum_gen_loss += gen_loss.data

            if batch == 0:
                serializers.save_npz(f"{modeldir}/generator_xy_{epoch}.model", generator_xy)
                serializers.save_npz(f"{modeldir}/generator_yx_{epoch}.model", generator_yx)

                # Validation runs in inference mode, like every other script
                # here; previously it ran with train=True. Distinct names keep
                # the training batch variables from being shadowed.
                with chainer.using_config('train', False):
                    xy_v, xy_mask_v = generator_xy(x_val, x_mask_val)
                    yx_v, yx_mask_v = generator_yx(y_val, y_mask_val)

                visualize(x_val.data.get(), x_mask_val.data.get(),
                          xy_v.data.get(), xy_mask_v.data.get(),
                          outdir, epoch, validsize, switch="mtot")
                visualize(y_val.data.get(), y_mask_val.data.get(),
                          yx_v.data.get(), yx_mask_v.data.get(),
                          outdir, epoch, validsize, switch="ttom")

        print(f"epoch: {epoch}")
        print(
            f"dis loss: {sum_dis_loss / iterations} gen loss: {sum_gen_loss / iterations}"
        )
def train(epochs, iterations, batchsize, validsize, outdir, modeldir,
          data_path, extension, img_size, latent_dim, learning_rate, beta1,
          beta2, enable):
    """Train a GauGAN-style generator conditioned on line art.

    With ``enable`` set, an encoder turns the colour image into the arguments
    of ``F.gaussian``. The encoder is trained jointly with the generator
    through an extra weighted ``F.gaussian_kl_divergence`` term. Otherwise the
    generator is driven only by ``noise_generator`` samples. Once per epoch the
    generator is saved to ``modeldir`` and evaluated on fixed validation noise
    and line art.
    """
    loader = DataLoader(data_path, extension, img_size, latent_dim)
    print(loader)
    color_valid, line_valid = loader(validsize, mode="valid")
    noise_valid = loader.noise_generator(validsize)

    if enable:
        encoder = Encoder()
        encoder.to_gpu()
        enc_opt = set_optimizer(encoder)

    generator = Generator()
    generator.to_gpu()
    gen_opt = set_optimizer(generator, learning_rate, beta1, beta2)

    discriminator = Discriminator()
    discriminator.to_gpu()
    dis_opt = set_optimizer(discriminator, learning_rate, beta1, beta2)

    criterion = GauGANLossFunction()
    # NOTE(review): `Evaluaton` (sic) differs from `Evaluation` used by the
    # other scripts in this file - confirm which class name actually exists.
    evaluator = Evaluaton()

    def sample_latent(color):
        """Return ``(z, mu, ln_var)`` for one update.

        Prior noise is always requested first. With the encoder enabled, it is
        then replaced by an ``F.gaussian`` sample and the encoder outputs are
        returned for the KL term; otherwise those two slots are ``None``.
        """
        z = loader.noise_generator(batchsize)
        if not enable:
            return z, None, None
        # F.gaussian / F.gaussian_kl_divergence take the log-variance as
        # their second argument.
        mu, ln_var = encoder(color)
        return F.gaussian(mu, ln_var), mu, ln_var

    for epoch in range(epochs):
        dis_total = 0
        gen_total = 0
        for step in range(0, iterations, batchsize):
            color, line = loader(batchsize)

            # Discriminator: detached fake vs. real, both paired with the line art.
            z, _, _ = sample_latent(color)
            fake = generator(z, line)
            fake.unchain_backward()
            dis_loss = criterion.dis_loss(discriminator,
                                          F.concat([fake, line]),
                                          F.concat([color, line]))
            discriminator.cleargrads()
            dis_loss.backward()
            dis_opt.update()
            dis_loss.unchain_backward()
            dis_total += dis_loss.data

            # Generator (and encoder, when enabled).
            z, mu, ln_var = sample_latent(color)
            fake = generator(z, line)
            gen_loss = criterion.gen_loss(discriminator,
                                          F.concat([fake, line]),
                                          F.concat([color, line]))
            gen_loss += criterion.content_loss(fake, color)
            if enable:
                gen_loss += 0.05 * F.gaussian_kl_divergence(mu, ln_var) / batchsize

            generator.cleargrads()
            if enable:
                encoder.cleargrads()
            gen_loss.backward()
            gen_opt.update()
            if enable:
                enc_opt.update()
            gen_loss.unchain_backward()
            gen_total += gen_loss.data

            if step == 0:
                serializers.save_npz(f"{modeldir}/generator_{epoch}.model", generator)
                with chainer.using_config("train", False):
                    preview = generator(noise_valid, line_valid)
                preview_np = preview.data.get()
                line_np = line_valid.data.get()
                color_np = color_valid.data.get()
                evaluator(preview_np, color_np, line_np, outdir, epoch, validsize=validsize)

        print(f"epoch: {epoch}")
        print(
            f"dis loss: {dis_total / iterations} gen loss: {gen_total / iterations}"
        )
def train(epochs, iterations, batchsize, validsize, outdir, modeldir,
          src_path, tgt_path, extension, img_size, learning_rate, beta1):
    """Train an unpaired CycleGAN between a source (x) and a target (y) domain.

    Two generators (x->y, y->x) and two discriminators are updated alternately
    with adversarial, cycle-consistency and identity-mapping losses. Once per
    epoch both generators are saved to ``modeldir``. The x->y generator is
    then run on a fixed validation batch, in inference mode, for visualisation.
    """
    loader = DatasetLoader(src_path, tgt_path, extension, img_size)
    print(loader)
    src_val = loader.valid(validsize)

    def build(net_cls):
        """Instantiate *net_cls* on the GPU and attach its optimizer."""
        net = net_cls()
        net.to_gpu()
        return net, set_optimizer(net, learning_rate, beta1)

    generator_xy, gen_xy_opt = build(Generator)
    generator_yx, gen_yx_opt = build(Generator)
    discriminator_y, dis_y_opt = build(Discriminator)
    discriminator_x, dis_x_opt = build(Discriminator)

    criterion = CycleGANLossCalculator()
    visualizer = Visualization()

    for epoch in range(epochs):
        gen_total = 0
        dis_total = 0
        for step in range(0, iterations, batchsize):
            x, y = loader.train(batchsize)

            # --- discriminators, fed detached translations ---
            fake_y = generator_xy(x)
            fake_x = generator_yx(y)
            fake_y.unchain_backward()
            fake_x.unchain_backward()
            dis_loss = (criterion.dis_loss(discriminator_y, fake_y, y)
                        + criterion.dis_loss(discriminator_x, fake_x, x))
            discriminator_x.cleargrads()
            discriminator_y.cleargrads()
            dis_loss.backward()
            dis_x_opt.update()
            dis_y_opt.update()
            dis_total += dis_loss.data

            # --- generators ---
            fake_y = generator_xy(x)
            fake_x = generator_yx(y)
            cyc_x = generator_yx(fake_y)
            cyc_y = generator_xy(fake_x)
            same_y = generator_xy(y)
            same_x = generator_yx(x)

            adv_xy = criterion.gen_loss(discriminator_y, fake_y)
            adv_yx = criterion.gen_loss(discriminator_x, fake_x)
            cycle_y = criterion.cycle_consitency_loss(cyc_y, y)
            cycle_x = criterion.cycle_consitency_loss(cyc_x, x)
            identity_y = criterion.identity_mapping_loss(same_y, y)
            identity_x = criterion.identity_mapping_loss(same_x, x)
            gen_loss = adv_xy + adv_yx + cycle_x + cycle_y + identity_x + identity_y

            generator_xy.cleargrads()
            generator_yx.cleargrads()
            gen_loss.backward()
            gen_xy_opt.update()
            gen_yx_opt.update()
            gen_total += gen_loss.data

            if step == 0:
                serializers.save_npz(f"{modeldir}/generator_xy_{epoch}.model", generator_xy)
                serializers.save_npz(f"{modeldir}/generator_yx_{epoch}.model", generator_yx)
                with chainer.using_config('train', False):
                    translated = generator_xy(src_val)
                visualizer(src_val.data.get(), translated.data.get(), outdir, epoch, validsize)

        print(f"epoch: {epoch}")
        print(f"dis loss: {dis_total/iterations} gen loss: {gen_total/iterations}")
def train(epochs, iterations, batchsize, validsize, data_path, sketch_path,
          digi_path, extension, img_size, outdir, modeldir, pretrained_epoch,
          adv_weight, enf_weight, sn, bn, activ):
    """Train ``SAGeneratorWithGuide`` with a hinge-loss discriminator.

    Both networks are conditioned on ``extractor``: the ``vgg(..., extract=True)``
    output for the hint ``mask``, average-pooled and detached. Once per epoch
    the generator is saved to ``modeldir`` and evaluated on a fixed validation
    batch.

    Args:
        pretrained_epoch: After this epoch the adversarial weight is forced to
            0.1 and the positive-enforcing weight to 0.0.
        adv_weight: Initial weight of the hinge adversarial terms.
        enf_weight: Initial weight of ``positive_enforcing_loss``.
        sn: Forwarded to ``Discriminator``.
        bn, activ: Forwarded to ``SAGeneratorWithGuide``.
        (Remaining arguments configure the dataset, output dirs and loop size.)
    """
    # Dataset Definition
    dataloader = DataLoader(data_path, sketch_path, digi_path,
                            extension=extension, img_size=img_size)
    print(dataloader)
    color_valid, line_valid, mask_valid, ds_valid = dataloader(validsize, mode="valid")

    # Model & Optimizer Definition
    generator = SAGeneratorWithGuide(attn_type="sa", bn=bn, activ=activ)
    generator.to_gpu()
    gen_opt = set_optimizer(generator)

    discriminator = Discriminator(sn=sn)
    discriminator.to_gpu()
    dis_opt = set_optimizer(discriminator)

    vgg = VGG()
    vgg.to_gpu()
    vgg_opt = set_optimizer(vgg)
    vgg.base.disable_update()

    # Loss Function Definition
    lossfunc = LossCalculator()

    # Evaluation Definition
    evaluator = Evaluation()

    for epoch in range(epochs):
        sum_loss = 0
        # Loss weights switch once pre-training is over (constant within an epoch).
        if epoch > pretrained_epoch:
            adv_weight = 0.1
            enf_weight = 0.0

        for batch in range(0, iterations, batchsize):
            color, line, mask, mask_ds = dataloader(batchsize)
            line_input = F.concat([line, mask])

            extractor = vgg(mask, extract=True)
            extractor = F.average_pooling_2d(extractor, 3, 2, 1)
            extractor.unchain_backward()

            # Discriminator update (fake detached before backward)
            fake, _ = generator(line_input, mask_ds, extractor)
            y_dis = discriminator(fake, extractor)
            t_dis = discriminator(color, extractor)
            loss = adv_weight * lossfunc.dis_hinge_loss(y_dis, t_dis)
            fake.unchain_backward()
            discriminator.cleargrads()
            loss.backward()
            dis_opt.update()
            loss.unchain_backward()

            # Generator update
            fake, guide = generator(line_input, mask_ds, extractor)
            y_dis = discriminator(fake, extractor)
            loss = adv_weight * lossfunc.gen_hinge_loss(y_dis)
            loss += enf_weight * lossfunc.positive_enforcing_loss(fake)
            loss += lossfunc.content_loss(fake, color)
            loss += 0.9 * lossfunc.content_loss(guide, color)
            loss += lossfunc.perceptual_loss(vgg, fake, color)
            generator.cleargrads()
            loss.backward()
            gen_opt.update()
            loss.unchain_backward()

            sum_loss += loss.data

            if batch == 0:
                serializers.save_npz(f"{modeldir}/generator_{epoch}.model", generator)

                # Condition on the validation counterpart of the training
                # `mask` input. The previous code used `line_valid`, which
                # does not match what the networks were trained on.
                extractor = vgg(mask_valid, extract=True)
                extractor = F.average_pooling_2d(extractor, 3, 2, 1)
                extractor.unchain_backward()
                line_valid_input = F.concat([line_valid, mask_valid])
                with chainer.using_config('train', False):
                    y_valid, guide_valid = generator(line_valid_input, ds_valid, extractor)
                y_valid = y_valid.data.get()
                c_valid = color_valid.data.get()
                input_valid = line_valid_input.data.get()
                guide_valid = guide_valid.data.get()
                evaluator(y_valid, c_valid, input_valid, guide_valid, outdir, epoch, validsize)

        print(f"epoch: {epoch}")
        print(f"loss: {sum_loss / iterations}")
def train(epochs, iterations, batchsize, data_path, modeldir, extension,
          img_size, learning_rate, beta1, weight_decay):
    """Train RelGAN: translation conditioned on relative attribute vectors.

    The discriminator is trained with adversarial, interpolation (weight 10)
    and matching losses. The generator is trained with adversarial,
    interpolation (10), matching, cycle-consistency L1 (10) and zero-shift
    self-reconstruction L1 (10) losses. The generator is saved to
    ``modeldir`` once per epoch.
    """
    dataset = DatasetLoader(data_path, extension, img_size)

    generator = Generator(dataset.number)
    generator.to_gpu()
    gen_opt = set_optimizer(generator, learning_rate, beta1, weight_decay)

    discriminator = Discriminator(dataset.number)
    discriminator.to_gpu()
    dis_opt = set_optimizer(discriminator, learning_rate, beta1, weight_decay)

    criterion = RelGANLossFunction()

    def sample_alpha():
        """Draw interpolation weights from [0, 0.5) or [0.5, 1.0) at random.

        Returns the branch index (0 = lower half) and the weights tiled to
        shape (batchsize, dataset.number).
        """
        branch = np.random.randint(2)
        low, high = (0, 0.5) if branch == 0 else (0.5, 1.0)
        weights = xp.random.uniform(low, high, size=batchsize)
        weights = chainer.as_variable(weights.astype(xp.float32))
        return branch, F.tile(F.expand_dims(weights, axis=1), (1, dataset.number))

    for epoch in range(epochs):
        dis_total = 0
        gen_total = 0
        for step in range(0, iterations, batchsize):
            x, x_label, y, y_label, z, z_label = dataset.train(batchsize)
            a = y_label - x_label  # relative attribute for x -> y

            # ---- discriminator ----
            fake = generator(x, a)
            fake.unchain_backward()
            dis_loss = criterion.adversarial_loss_dis(discriminator, fake, y)

            branch, alpha = sample_alpha()
            fake_0 = generator(x, y_label - y_label)
            fake_1 = generator(x, alpha * a)
            fake_0.unchain_backward()
            fake_1.unchain_backward()
            dis_loss += 10 * criterion.interpolation_loss_dis(
                discriminator, fake_0, fake, fake_1, alpha, branch)

            v2 = y_label - z_label
            v3 = z_label - x_label
            dis_loss += criterion.matching_loss_dis(discriminator, x, fake, y, z, a, v2, v3)

            discriminator.cleargrads()
            dis_loss.backward()
            dis_opt.update()
            dis_loss.unchain_backward()
            dis_total += dis_loss.data

            # ---- generator ----
            fake = generator(x, a)
            gen_loss = criterion.adversarial_loss_gen(discriminator, fake)

            _, alpha = sample_alpha()
            fake_alpha = generator(x, alpha * a)
            gen_loss += 10 * criterion.interpolation_loss_gen(discriminator, fake_alpha)

            gen_loss += criterion.matching_loss_gen(discriminator, x, fake, a)

            cyc = generator(fake, -a)
            gen_loss += 10 * F.mean_absolute_error(cyc, x)

            fake_0 = generator(x, y_label - y_label)
            gen_loss += 10 * F.mean_absolute_error(fake_0, x)

            generator.cleargrads()
            gen_loss.backward()
            gen_opt.update()
            gen_loss.unchain_backward()
            gen_total += gen_loss.data

            if step == 0:
                serializers.save_npz(f"{modeldir}/generator_{epoch}.model", generator)

        print(
            f"epoch: {epoch} disloss: {dis_total/iterations} genloss: {gen_total/iterations}"
        )
def train(epochs, iterations, batchsize, src_path, tgt_path, modeldir):
    """Train CycleGAN-VC2 with two generators and two multi-scale discriminators.

    Generator loss = adversarial + 10 * cycle + w * identity. Here w is 5.0
    up to and including epoch 20, and 0.0 afterwards. Both generators are
    saved to ``modeldir`` (one file each, overwritten every epoch).
    """
    corpus = DatasetLoader(src_path, tgt_path)
    print(corpus)

    def build(net_cls):
        """Instantiate *net_cls* on the GPU and attach its optimizer."""
        net = net_cls()
        net.to_gpu()
        return net, set_optimizer(net)

    generator_xy, gen_xy_opt = build(Generator)
    generator_yx, gen_yx_opt = build(Generator)
    discriminator_y, dis_y_opt = build(MSDiscriminator)
    discriminator_x, dis_x_opt = build(MSDiscriminator)

    criterion = CycleGANVC2LossFunction()

    for epoch in range(epochs):
        gen_total = 0
        dis_total = 0
        # Identity mapping is only enforced during the early epochs.
        identity_weight = 0.0 if epoch > 20 else 5.0

        for step in range(0, iterations, batchsize):
            x, y = corpus.train(batchsize)

            # Discriminators, on detached conversions.
            fake_y = generator_xy(x)
            fake_x = generator_yx(y)
            fake_y.unchain_backward()
            fake_x.unchain_backward()
            dis_loss = criterion.adv_dis_loss(discriminator_y, fake_y, y)
            dis_loss += criterion.adv_dis_loss(discriminator_x, fake_x, x)
            dis_total += dis_loss.data
            discriminator_x.cleargrads()
            discriminator_y.cleargrads()
            dis_loss.backward()
            dis_x_opt.update()
            dis_y_opt.update()
            dis_loss.unchain_backward()

            # Generators.
            fake_y = generator_xy(x)
            cyc_x = generator_yx(fake_y)
            same_y = generator_xy(y)
            fake_x = generator_yx(y)
            cyc_y = generator_xy(fake_x)
            same_x = generator_yx(x)

            gen_loss = criterion.adv_gen_loss(discriminator_y, fake_y)
            gen_loss += criterion.adv_gen_loss(discriminator_x, fake_x)
            cycle_loss = criterion.recon_loss(cyc_x, x) + criterion.recon_loss(cyc_y, y)
            identity_loss = criterion.recon_loss(same_y, y) + criterion.recon_loss(same_x, x)
            gen_loss += 10 * cycle_loss + identity_weight * identity_loss

            generator_xy.cleargrads()
            generator_yx.cleargrads()
            gen_loss.backward()
            gen_xy_opt.update()
            gen_yx_opt.update()
            gen_loss.unchain_backward()
            gen_total += gen_loss.data.get()

            if step == 0:
                serializers.save_npz(f"{modeldir}/generator_xy.model", generator_xy)
                serializers.save_npz(f"{modeldir}/generator_yx.model", generator_yx)

        print(f'epoch : {epoch}')
        print(f'Generator loss : {gen_total / iterations}')
        print(f'Discriminator loss : {dis_total / iterations}')
def train_refine(epochs, iterations, batchsize, validsize, data_path,
                 sketch_path, digi_path, st_path, extension, img_size,
                 crop_size, outdir, modeldir, adv_weight, enf_weight):
    """Refinement-stage training of ``SAGeneratorWithGuide`` on ``RefineDataset``.

    Each step makes one hinge-loss discriminator update on a detached fake.
    It then makes one generator update: hinge adversarial loss, plus content
    loss on the output and (weighted 0.9) on the guide output. Both networks
    are conditioned on the average-pooled, detached ``vgg(..., extract=True)``
    output for ``color_mask``. Once per epoch the generator is saved to
    ``modeldir`` and evaluated on a fixed validation batch.

    Args:
        adv_weight: Weight of the hinge adversarial terms.
        enf_weight: Accepted for interface compatibility; unused here.
        (Remaining arguments configure the dataset, output dirs and loop size.)
    """
    # Dataset Definition
    dataloader = RefineDataset(data_path, sketch_path, digi_path, st_path,
                               extension=extension, img_size=img_size,
                               crop_size=crop_size)
    print(dataloader)
    color_valid, line_valid, mask_valid, ds_valid, cm_valid = dataloader(validsize, mode="valid")

    # Model & Optimizer Definition
    generator = SAGeneratorWithGuide(attn_type="sa", base=64, bn=True, activ=F.relu)
    generator.to_gpu()
    gen_opt = set_optimizer(generator)

    discriminator = Discriminator()
    discriminator.to_gpu()
    dis_opt = set_optimizer(discriminator)

    vgg = VGG()
    vgg.to_gpu()
    vgg_opt = set_optimizer(vgg)
    vgg.base.disable_update()

    # Loss Function Definition
    lossfunc = LossCalculator()

    # Evaluation Definition
    evaluator = Evaluation()

    iteration = 0
    for epoch in range(epochs):
        sum_dis_loss = 0
        sum_gen_loss = 0
        for batch in range(0, iterations, batchsize):
            iteration += 1
            color, line, mask, mask_ds, color_mask = dataloader(batchsize)
            line_input = F.concat([line, mask])

            # Conditioning input; detached so no gradient flows into VGG.
            extractor = vgg(color_mask, extract=True)
            extractor = F.average_pooling_2d(extractor, 3, 2, 1)
            extractor.unchain_backward()

            # Discriminator update (fake detached before backward)
            fake, _ = generator(line_input, mask_ds, extractor)
            y_dis = discriminator(fake, extractor)
            t_dis = discriminator(color, extractor)
            loss = adv_weight * lossfunc.dis_hinge_loss(y_dis, t_dis)
            fake.unchain_backward()
            discriminator.cleargrads()
            loss.backward()
            dis_opt.update()
            loss.unchain_backward()
            sum_dis_loss += loss.data

            # Generator update
            fake, guide = generator(line_input, mask_ds, extractor)
            y_dis = discriminator(fake, extractor)
            loss = adv_weight * lossfunc.gen_hinge_loss(y_dis)
            loss += lossfunc.content_loss(fake, color)
            loss += 0.9 * lossfunc.content_loss(guide, color)
            generator.cleargrads()
            loss.backward()
            gen_opt.update()
            loss.unchain_backward()
            sum_gen_loss += loss.data

            if batch == 0:
                serializers.save_npz(f"{modeldir}/generator_{epoch}.model", generator)
                extractor = vgg(cm_valid, extract=True)
                extractor = F.average_pooling_2d(extractor, 3, 2, 1)
                extractor.unchain_backward()
                line_valid_input = F.concat([line_valid, mask_valid])
                with chainer.using_config('train', False):
                    y_valid, guide_valid = generator(line_valid_input, ds_valid, extractor)
                y_valid = y_valid.data.get()
                c_valid = color_valid.data.get()
                input_valid = line_valid_input.data.get()
                cm_val = cm_valid.data.get()
                guide_valid = guide_valid.data.get()
                # Channels 3:6 of the [line, mask] input shown next to
                # cm_valid (assumes 3-channel line images - TODO confirm).
                input_valid = np.concatenate([input_valid[:, 3:6], cm_val], axis=1)
                evaluator(y_valid, c_valid, input_valid, guide_valid, outdir, epoch, validsize)

        # Previously referenced the undefined name `gen_loss` (NameError at
        # the end of every epoch); report the accumulated generator loss.
        print(f"iter: {iteration} dis loss: {sum_dis_loss} gen loss: {sum_gen_loss}")
def train(epochs, iterations, dataset_path, test_path, outdir, batchsize,
          testsize, recon_weight, fm_weight, gp_weight, spectral_norm=False):
    """Train a FUNIT-style translation generator against a discriminator.

    The discriminator loss includes ``gradient_penalty``. The generator loss is
    ``adv + recon_weight * recon + fm_weight * fm``, as returned by
    ``gen_loss``. At the first step of each epoch both networks are saved to
    the working directory and a fixed test batch is evaluated into ``outdir``.
    ``gp_weight`` is accepted but not used in this function.
    """
    loader = DatasetLoader(dataset_path, test_path)
    c_valid, s_valid = loader.test(testsize)

    generator = SNGenerator() if spectral_norm else Generator()
    generator.to_gpu()
    gen_opt = set_optimizer(generator)

    discriminator = Discriminator()
    discriminator.to_gpu()
    dis_opt = set_optimizer(discriminator)

    criterion = FUNITLossFunction()
    evaluator = Evaluation()

    for epoch in range(epochs):
        running = 0
        for step in range(0, iterations, batchsize):
            c, ci, s, si = loader.train(batchsize)

            # Discriminator step on a detached translation.
            fake = generator(c, s)
            fake.unchain_backward()
            loss = criterion.dis_loss(discriminator, fake, s, si)
            loss += criterion.gradient_penalty(discriminator, s, fake, si)
            discriminator.cleargrads()
            loss.backward()
            dis_opt.update()
            loss.unchain_backward()

            # Generator step: adversarial, reconstruction and `fm` terms.
            y_convert = generator(c, s)
            y_recon = generator(c, c)
            adv_loss, recon_loss, fm_loss = criterion.gen_loss(
                discriminator, y_convert, y_recon, s, c, si, ci)
            loss = adv_loss + recon_weight * recon_loss + fm_weight * fm_loss
            generator.cleargrads()
            loss.backward()
            gen_opt.update()
            loss.unchain_backward()

            running += loss.data

            if step == 0:
                serializers.save_npz('generator.model', generator)
                serializers.save_npz('discriminator.model', discriminator)
                with chainer.using_config('train', False):
                    y_valid = generator(c_valid, s_valid)
                    y_valid.unchain_backward()
                evaluator(y_valid.data.get(), c_valid.data.get(), s_valid.data.get(),
                          outdir, epoch, testsize)

        print(f"epoch: {epoch}")
        print(f"loss: {running / iterations}")
def train(epochs, iterations, batchsize, testsize, img_path, seg_path,
          outdir, modeldir, n_dis, mode):
    """Train a GAN whose generator emits an image and a ``seg`` output from noise.

    Args:
        epochs: Number of training epochs.
        iterations: Per-epoch sample budget; the outer inner loop steps
            through ``range(0, iterations, batchsize)``.
        batchsize: Mini-batch size.
        testsize: Size of the fixed evaluation noise batch.
        img_path, seg_path: Passed to ``DatasetLoader``.
        outdir: Evaluation output directory.
        modeldir: Snapshot directory (per-epoch generator and discriminator).
        n_dis: Discriminator updates per generator update; each draws a
            fresh batch.
        mode: Forwarded to ``lossfunc.gradient_penalty``.
    """
    # Dataset Definition
    dataloader = DatasetLoader(img_path, seg_path)
    print(dataloader)
    valid_noise = dataloader.test(testsize)

    # Model & Optimizer Definition
    generator = Generator()
    generator.to_gpu()
    gen_opt = set_optimizer(generator)

    discriminator = Discriminator()
    discriminator.to_gpu()
    dis_opt = set_optimizer(discriminator)

    # Loss Function Definition
    lossfunc = SGANLossFunction()

    # Evaluation Definition
    evaluator = Evaluation()

    for epoch in range(epochs):
        sum_loss = 0
        for batch in range(0, iterations, batchsize):
            for _ in range(n_dis):
                t, s, noise = dataloader.train(batchsize)
                # Generator outputs are not detached here; the generator's
                # gradients are cleared before its own update below.
                y_img, y_seg = generator(noise)
                loss = lossfunc.dis_loss(discriminator, y_img, y_seg, t, s)
                loss += lossfunc.gradient_penalty(discriminator, y_img, y_seg, t, s, mode=mode)
                discriminator.cleargrads()
                loss.backward()
                dis_opt.update()
                loss.unchain_backward()

            _, _, noise = dataloader.train(batchsize)
            y_img, y_seg = generator(noise)
            loss = lossfunc.gen_loss(discriminator, y_img, y_seg)
            generator.cleargrads()
            loss.backward()
            gen_opt.update()
            loss.unchain_backward()

            # Accumulate (previously `=`, which kept only the last batch and
            # made the reported epoch loss meaningless).
            sum_loss += loss.data

            if batch == 0:
                serializers.save_npz(f"{modeldir}/generator_{epoch}.model", generator)
                serializers.save_npz(f"{modeldir}/discriminator_{epoch}.model", discriminator)
                with chainer.using_config('train', False):
                    y_img, y_seg = generator(valid_noise)
                y_img = y_img.data.get()
                y_seg = y_seg.data.get()
                evaluator(y_img, y_seg, epoch, outdir, testsize=testsize)

        print(f"epoch: {epoch}")
        print(f"loss: {sum_loss / iterations}")
input_dim=args.input_dim, n_heads=args.n_heads, n_blocks=args.n_blocks, dropout=args.dropout, ff_hidden_dim=4 * args.input_dim if not args.ff_hidden_dim else args.ff_hidden_dim, transformer_activation=args.transformer_activation, mlp_hidden_dims=mlp_hidden_dims, mlp_activation=args.mlp_activation, mlp_batchnorm=args.mlp_batchnorm, mlp_batchnorm_last=args.mlp_batchnorm_last, mlp_linear_first=args.mlp_linear_first, ) model = WideDeep(wide=wide, deeptabular=deeptabular) optimizers = set_optimizer(model, args) steps_per_epoch = (X_tab_train.shape[0] // args.batch_size) + 1 lr_schedulers = set_lr_scheduler(optimizers, steps_per_epoch, args) early_stopping = EarlyStopping( monitor=args.monitor, min_delta=args.early_stop_delta, patience=args.early_stop_patience, ) trainer = Trainer( model, objective="binary", optimizers=optimizers, lr_schedulers=lr_schedulers,
def train(epochs, iterations, batchsize, modeldir, extension, time_width,
          mel_bins, sampling_rate, g_learning_rate, d_learning_rate, beta1,
          beta2, identity_epoch, adv_type, residual_flag, data_path):
    """Train StarGAN-VC2 with either paired-label or target-only conditioning.

    Args:
        epochs: Number of training epochs.
        iterations: Per-epoch sample budget; the inner loop steps through
            ``range(0, iterations, batchsize)``.
        batchsize: Mini-batch size.
        modeldir: Directory receiving per-epoch generator snapshots.
        extension, time_width, mel_bins, sampling_rate: Accepted for interface
            compatibility; not used in this function.
        g_learning_rate, d_learning_rate, beta1, beta2: Optimizer settings.
        identity_epoch: The identity loss is applied while
            ``epoch < identity_epoch``.
        adv_type: ``'sat'`` conditions on concatenated (target, source)
            labels; ``'orig'`` conditions on the target label only.
        residual_flag: Forwarded to ``dis_loss`` / ``gen_loss``.
        data_path: Dataset location.

    Raises:
        AttributeError: If ``adv_type`` is not ``'sat'`` or ``'orig'``. The
            exception type is unchanged. It is now raised up front with a
            message, instead of bare on the first batch after data loading
            and model construction.
    """
    if adv_type not in ('sat', 'orig'):
        raise AttributeError(f"adv_type must be 'sat' or 'orig', got {adv_type!r}")

    def conditions(x_label, y_label):
        """Return (to_target, to_source, identity) generator conditioning."""
        if adv_type == 'sat':
            return (F.concat([y_label, x_label]),
                    F.concat([x_label, y_label]),
                    F.concat([x_label, x_label]))
        return y_label, x_label, x_label

    # Dataset Definition
    dataloader = DatasetLoader(data_path)

    # Model & Optimizer Definition
    generator = GeneratorWithCIN(adv_type=adv_type)
    generator.to_gpu()
    gen_opt = set_optimizer(generator, g_learning_rate, beta1, beta2)

    discriminator = Discriminator()
    discriminator.to_gpu()
    dis_opt = set_optimizer(discriminator, d_learning_rate, beta1, beta2)

    # Loss Function Definition
    lossfunc = StarGANVC2LossFunction()

    for epoch in range(epochs):
        sum_dis_loss = 0
        sum_gen_loss = 0
        for batch in range(0, iterations, batchsize):
            x_sp, x_label, y_sp, y_label = dataloader.train(batchsize)

            # Discriminator update on a detached conversion.
            to_target, to_source, _ = conditions(x_label, y_label)
            y_fake = generator(x_sp, to_target)
            y_fake.unchain_backward()
            advloss_dis_real, advloss_dis_fake = lossfunc.dis_loss(
                discriminator, y_fake, x_sp, to_target, to_source, residual_flag)
            dis_loss = advloss_dis_real + advloss_dis_fake
            discriminator.cleargrads()
            dis_loss.backward()
            dis_opt.update()
            dis_loss.unchain_backward()

            # Generator update: conversion, back-conversion and identity pass.
            to_target, to_source, identity = conditions(x_label, y_label)
            y_fake = generator(x_sp, to_target)
            x_fake = generator(y_fake, to_source)
            x_identity = generator(x_sp, identity)
            advloss_gen_fake, cycle_loss = lossfunc.gen_loss(
                discriminator, y_fake, x_fake, x_sp, to_target, residual_flag)

            if epoch < identity_epoch:
                identity_loss = lossfunc.identity_loss(x_identity, x_sp)
            else:
                # Presumably a zero-valued stand-in shaped like the
                # discriminator loss - confirm against call_zeros.
                identity_loss = call_zeros(advloss_dis_fake)

            gen_loss = advloss_gen_fake + cycle_loss + identity_loss
            generator.cleargrads()
            gen_loss.backward()
            gen_opt.update()
            gen_loss.unchain_backward()

            sum_dis_loss += dis_loss.data
            sum_gen_loss += gen_loss.data

            if batch == 0:
                serializers.save_npz(f"{modeldir}/generator_{epoch}.model", generator)

        print(f"epoch: {epoch}")
        print(
            f"dis loss: {sum_dis_loss / iterations} gen loss: {sum_gen_loss / iterations}"
        )
def train(epochs, iterations, batchsize, validsize, outdir, modeldir,
          extension, train_size, valid_size, data_path, sketch_path,
          digi_path, learning_rate, beta1, weight_decay):
    """Train a pix2pix U-Net against a discriminator.

    Each step makes one discriminator update on a detached U-Net output, then
    one U-Net update on adversarial plus content loss. Once per epoch the
    U-Net is saved to ``modeldir`` and a fixed validation batch is visualised
    into ``outdir``.
    """
    dataset = DatasetLoader(data_path, sketch_path, digi_path, extension, train_size, valid_size)
    print(dataset)
    x_val, t_val = dataset.valid(validsize)

    def build(net_cls):
        """Instantiate *net_cls* on the GPU and attach its optimizer."""
        net = net_cls()
        net.to_gpu()
        return net, set_optimizer(net, learning_rate, beta1, weight_decay)

    unet, unet_opt = build(UNet)
    discriminator, dis_opt = build(Discriminator)

    criterion = Pix2pixLossCalculator()
    visualizer = Visualizer()

    def discriminator_step(src, tgt):
        """Update the discriminator once and return the loss array."""
        pred = unet(src)
        pred.unchain_backward()
        loss = criterion.dis_loss(discriminator, pred, tgt)
        discriminator.cleargrads()
        loss.backward()
        dis_opt.update()
        return loss.data

    def generator_step(src, tgt):
        """Update the U-Net once and return the loss array."""
        pred = unet(src)
        loss = criterion.gen_loss(discriminator, pred)
        loss += criterion.content_loss(pred, tgt)
        unet.cleargrads()
        loss.backward()
        unet_opt.update()
        return loss.data

    for epoch in range(epochs):
        dis_total = 0
        gen_total = 0
        for step in range(0, iterations, batchsize):
            x, t = dataset.train(batchsize)
            dis_total += discriminator_step(x, t)
            gen_total += generator_step(x, t)

            if step == 0:
                serializers.save_npz(f"{modeldir}/unet_{epoch}.model", unet)
                with chainer.using_config("train", False):
                    pred_val = unet(x_val)
                visualizer(x_val.data.get(), t_val.data.get(), pred_val.data.get(),
                           outdir, epoch, validsize)

        print(f"epoch: {epoch}")
        print(
            f"dis loss: {dis_total/iterations} gen loss: {gen_total/iterations}"
        )
def train(epochs, iterations, outdir, path, batchsize, validsize, adv_weight,
          content_weight, pretrained_path='./outdir_pretrain/model_80.model'):
    """Adversarially fine-tune an ESRGAN generator.

    Args:
        epochs: Number of training epochs.
        iterations: Per-epoch sample budget; the inner loop steps through
            ``range(0, iterations, batchsize)``.
        outdir: Directory for per-epoch snapshots and evaluator output.
        path: Dataset location handed to ``DatasetLoader``.
        batchsize: Mini-batch size.
        validsize: Size of the fixed validation batch.
        adv_weight: Weight of both hinge adversarial losses.
        content_weight: Weight of the content loss.
        pretrained_path: Snapshot loaded into the generator before training.
            It defaults to the path that used to be hard-coded. Pass ``None``
            to skip loading.
    """
    # Dataset Definition
    dataloader = DatasetLoader(path)
    print(dataloader)
    t_valid, x_valid = dataloader(validsize, mode="valid")

    # Model & Optimizer Definition
    model = Generator()
    model.to_gpu()
    optimizer = set_optimizer(model)
    if pretrained_path is not None:
        serializers.load_npz(pretrained_path, model)

    discriminator = Discriminator()
    discriminator.to_gpu()
    dis_opt = set_optimizer(discriminator)

    # `vgg.base` has updates disabled, so vgg_opt can only change parameters
    # outside of it.
    vgg = VGG()
    vgg.to_gpu()
    vgg_opt = set_optimizer(vgg)
    vgg.base.disable_update()

    # Loss Function Definition
    lossfunc = ESRGANLossFunction()
    print(lossfunc)

    # Evaluation Definition
    evaluator = Evaluation()

    for epoch in range(epochs):
        sum_loss = 0
        for batch in range(0, iterations, batchsize):
            t_train, x_train = dataloader(batchsize, mode="train")

            # Discriminator update on a detached super-resolved batch.
            y_train = model(x_train)
            y_train.unchain_backward()
            loss = adv_weight * lossfunc.dis_hinge_loss(discriminator, y_train, t_train)
            discriminator.cleargrads()
            loss.backward()
            dis_opt.update()
            loss.unchain_backward()

            # Generator update: adversarial + content + perceptual.
            y_train = model(x_train)
            loss = adv_weight * lossfunc.gen_hinge_loss(discriminator, y_train)
            loss += content_weight * lossfunc.content_loss(y_train, t_train)
            loss += lossfunc.perceptual_loss(vgg, y_train, t_train)
            model.cleargrads()
            vgg.cleargrads()
            loss.backward()
            optimizer.update()
            vgg_opt.update()
            loss.unchain_backward()

            sum_loss += loss.data

            if batch == 0:
                serializers.save_npz(f"{outdir}/model_{epoch}.model", model)
                with chainer.using_config('train', False):
                    y_valid = model(x_valid)
                x = x_valid.data.get()
                y = y_valid.data.get()
                t = t_valid.data.get()
                evaluator(x, y, t, epoch, outdir)

        print(f"epoch: {epoch}")
        print(f"loss: {sum_loss / iterations}")
def train(epochs, iterations, batchsize, testsize, outdir, modeldir, n_dis,
          img_path, tag_path):
    """Train a tag-conditioned GAN; the generator input is ``concat([zvis, ztag])``.

    Args:
        epochs: Number of training epochs.
        iterations: Per-epoch sample budget; the inner loop steps through
            ``range(0, iterations, batchsize)``.
        batchsize: Mini-batch size for training.
        testsize: Number of fixed validation samples generated and passed to
            the evaluator once per epoch.
        outdir: Evaluation output directory.
        modeldir: Snapshot directory (per-epoch generator and discriminator).
        n_dis: Discriminator updates per generator update; each draws a
            fresh batch.
        img_path, tag_path: Passed to ``DatasetLoader``.
    """
    # Dataset Definition
    dataloader = DatasetLoader(img_path, tag_path)
    # Draw `testsize` samples, the count the evaluator is told about below.
    # This was previously drawn with `batchsize`, which mismatches whenever
    # the two sizes differ.
    zvis_valid, ztag_valid = dataloader.valid(testsize)
    noise_valid = F.concat([zvis_valid, ztag_valid])

    # Model & Optimizer Definition
    generator = Generator()
    generator.to_gpu()
    gen_opt = set_optimizer(generator)

    discriminator = Discriminator()
    discriminator.to_gpu()
    dis_opt = set_optimizer(discriminator)

    # Loss Function Definition
    lossfunc = RGANLossFunction()

    # Evaluation
    evaluator = Evaluation()

    for epoch in range(epochs):
        sum_loss = 0
        for batch in range(0, iterations, batchsize):
            for _ in range(n_dis):
                zvis, ztag, img, tag = dataloader.train(batchsize)
                y = generator(F.concat([zvis, ztag]))
                y.unchain_backward()
                loss = lossfunc.dis_loss(discriminator, y, img, tag, ztag)
                loss += lossfunc.gradient_penalty(discriminator, img, tag)
                discriminator.cleargrads()
                loss.backward()
                dis_opt.update()
                loss.unchain_backward()

            zvis, ztag, _, _ = dataloader.train(batchsize)
            y = generator(F.concat([zvis, ztag]))
            loss = lossfunc.gen_loss(discriminator, y, ztag)
            generator.cleargrads()
            loss.backward()
            gen_opt.update()
            loss.unchain_backward()

            sum_loss += loss.data

            if batch == 0:
                serializers.save_npz(f"{modeldir}/generator_{epoch}.model", generator)
                serializers.save_npz(f"{modeldir}/discriminator_{epoch}.model", discriminator)
                with chainer.using_config('train', False):
                    y = generator(noise_valid)
                y = y.data.get()
                evaluator(y, outdir, epoch, testsize)

        print(f"epoch: {epoch}")
        print(f"loss: {sum_loss / iterations}")