def training_step(
    draft_gen: Generator,
    gen: Generator,
    gen_optim,
    batch: Tuple[Tensor, Tensor, Tensor, Tensor],
) -> Dict[str, Tensor]:
    line, line_draft, hint, color = batch
    batch_size = line.shape[0]

    mask = opt.mask_gen(list(hint.shape), X, batch_size // 2)
    hint = hint * mask

    image_size = line.shape[1]
    draft = build_draft(draft_gen, line_draft, hint, image_size)

    with tf.GradientTape() as tape:
        _color = gen({"line": line, "hint": draft}, training=True)
        loss = losses.l1_loss(_color, color)

    grad = tape.gradient(loss, gen.trainable_variables)
    gen_optim.apply_gradients(zip(grad, gen.trainable_variables))

    return {
        # image
        "line": line,
        "hint": hint,
        "color": color,
        "draft": draft,
        "_color": _color,
        # scalar
        "l1_loss": loss,
    }
def testL1Loss(self):
    with self.test_session():
        shape = [5, 5, 5]
        num_elem = 5 * 5 * 5
        weights = tf.constant(1.0, shape=shape)
        wd = 0.01
        loss = losses.l1_loss(weights, wd)
        self.assertEqual(loss.op.name, 'L1Loss/value')
        self.assertAlmostEqual(loss.eval(), num_elem * wd, 5)
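The l1_loss under test here is the weight-decay flavor (a single tensor scaled by wd), not an elementwise distance. A minimal sketch consistent with the assertions above, assuming a TF1-style implementation; the name scope and the 'value' op name are what make loss.op.name == 'L1Loss/value' hold:

import tensorflow.compat.v1 as tf

def l1_loss(tensor, weight=1.0, scope=None):
    # weight * sum(|tensor|), wrapped so the resulting op is 'L1Loss/value'.
    with tf.name_scope(scope, 'L1Loss', [tensor]):
        return tf.multiply(weight, tf.reduce_sum(tf.abs(tensor)), name='value')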
def _frc_losses(self, ops=None, suffix=''):
    # Avoid a shared mutable default argument
    if ops is None:
        ops = {}

    # classification loss
    cls_score = self.get_output('cls_score' + suffix)
    # ops['cls_score' + "_0"] = cls_score
    # ops['y_outs']
    ops['loss_cls' + suffix] = losses.sparse_softmax(
        cls_score, self.data['labels'], name='cls_loss' + suffix)

    # bounding box regression L1 loss
    if cfg.TRAIN.BBOX_REG:
        bbox_pred = self.get_output('bbox_pred' + suffix)
        ops['loss_box' + suffix] = losses.l1_loss(
            bbox_pred, self.data['bbox_targets'], 'reg_loss' + suffix,
            self.data['bbox_inside_weights'])
    else:
        print('NO BBOX REGRESSION!!!!!')

    return ops
def training_step(
    gen: Generator,
    disc: Discriminator,
    gen_optim,
    disc_optim,
    batch: Tuple[Tensor, Tensor, Tensor],
) -> Dict[str, Tensor]:
    line, hint, color = batch
    batch_size = line.shape[0]

    mask = opt.mask_gen(list(hint.shape), X, batch_size // 2)
    hint = hint * mask

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        # Pass training as the Keras call kwarg, not as a model input
        _color = gen({"line": line, "hint": hint}, training=True)
        logits = disc(color, training=True)
        _logits = disc(_color, training=True)

        adv_loss = losses.binary_crossentropy(_logits, tf.ones_like(_logits))
        l1_loss = losses.l1_loss(_color, color)
        gen_loss = adv_loss + (l1_loss * 100)

        real_loss = losses.binary_crossentropy(logits, tf.ones_like(logits))
        fake_loss = losses.binary_crossentropy(_logits, tf.zeros_like(_logits))
        disc_loss = (real_loss + fake_loss) / 2

    gen_grad = gen_tape.gradient(gen_loss, gen.trainable_variables)
    disc_grad = disc_tape.gradient(disc_loss, disc.trainable_variables)
    gen_optim.apply_gradients(zip(gen_grad, gen.trainable_variables))
    disc_optim.apply_gradients(zip(disc_grad, disc.trainable_variables))

    return {
        # image
        "line": line,
        "hint": hint,
        "color": color,
        "_color": _color,
        # scalar
        "logits": logits,
        "_logits": _logits,
        "adv_loss": adv_loss,
        "l1_loss": l1_loss,
        "gen_loss": gen_loss,
        "real_loss": real_loss,
        "fake_loss": fake_loss,
        "disc_loss": disc_loss,
    }
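losses.l1_loss serves here as the pix2pix reconstruction term (note the adv_loss + l1_loss * 100 weighting). Its implementation is not shown in these examples; a minimal sketch, assuming it is the plain mean absolute error:

import tensorflow as tf

def l1_loss(y_pred, y_true):
    # Mean absolute error over every element of the batch.
    return tf.reduce_mean(tf.abs(y_pred - y_true))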
def find_lr(data, batch_size=1, start_lr=1e-9, end_lr=1):
    generator = Generator(input_shape=[None, None, 2])
    discriminator = Discriminator(input_shape=[None, None, 1])
    generator_optimizer = tf.keras.optimizers.Adam(learning_rate=start_lr)
    discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate=start_lr)

    model_name = data['training'].origin + '_2_any'
    checkpoint_prefix = os.path.join(CHECKPOINT_DIR, model_name)
    if not os.path.isdir(checkpoint_prefix):
        os.makedirs(checkpoint_prefix)

    epoch_size = len(data['training'])
    # Multiplying by lr_mult after every batch sweeps the LR
    # exponentially from start_lr to end_lr over one epoch.
    lr_mult = (end_lr / start_lr) ** (1 / epoch_size)

    lrs = []
    losses = {'gen_mae': [], 'gen_loss': [], 'disc_loss': []}
    best_losses = {'gen_mae': 1e9, 'gen_loss': 1e9, 'disc_loss': 1e9}

    print()
    print("Finding the optimal LR with the following parameters: ")
    print("\tCheckpoints: \t", checkpoint_prefix)
    print("\tEpochs: \t", 1)
    print("\tBatchSize: \t", batch_size)
    print("\tnBatches: \t", epoch_size)
    print()

    print('Epoch {}/{}'.format(1, 1))
    progbar = tf.keras.utils.Progbar(epoch_size)
    for i in range(epoch_size):
        # Get the data from the DataGenerator
        input_image, target = data['training'][i]

        with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
            # Generate a fake image
            gen_output = generator(input_image, training=True)

            # Train the discriminator
            disc_real_output = discriminator([input_image[:, :, :, 0:1], target], training=True)
            disc_generated_output = discriminator([input_image[:, :, :, 0:1], gen_output], training=True)

            # Compute the losses
            gen_mae = l1_loss(target, gen_output)
            gen_loss = generator_loss(disc_generated_output, gen_mae)
            disc_loss = discriminator_loss(disc_real_output, disc_generated_output)

        # Compute the gradients
        generator_gradients = gen_tape.gradient(gen_loss, generator.trainable_variables)
        discriminator_gradients = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

        # Apply the gradients
        generator_optimizer.apply_gradients(zip(generator_gradients, generator.trainable_variables))
        discriminator_optimizer.apply_gradients(zip(discriminator_gradients, discriminator.trainable_variables))

        # Convert losses to numpy
        gen_mae = gen_mae.numpy()
        gen_loss = gen_loss.numpy()
        disc_loss = disc_loss.numpy()

        # Update the progress bar
        progbar.add(1, values=[("gen_mae", gen_mae), ("gen_loss", gen_loss), ("disc_loss", disc_loss)])

        # On batch end
        lr = tf.keras.backend.get_value(generator_optimizer.lr)
        lrs.append(lr)

        # Update the lr
        lr *= lr_mult
        tf.keras.backend.set_value(generator_optimizer.lr, lr)
        tf.keras.backend.set_value(discriminator_optimizer.lr, lr)

        # Update the losses
        losses['gen_mae'].append(gen_mae)
        losses['gen_loss'].append(gen_loss)
        losses['disc_loss'].append(disc_loss)

        # Update the best losses
        if best_losses['gen_mae'] > gen_mae:
            best_losses['gen_mae'] = gen_mae
        if best_losses['gen_loss'] > gen_loss:
            best_losses['gen_loss'] = gen_loss
        if best_losses['disc_loss'] > disc_loss:
            best_losses['disc_loss'] = disc_loss

        # Stop early once any loss explodes past 100x its best value
        if (gen_mae >= 100 * best_losses['gen_mae']
                or gen_loss >= 100 * best_losses['gen_loss']
                or disc_loss >= 100 * best_losses['disc_loss']):
            break

    plot_loss_findlr(losses['gen_mae'], lrs, os.path.join(checkpoint_prefix, 'LRFinder_gen_mae.tiff'))
    plot_loss_findlr(losses['gen_loss'], lrs, os.path.join(checkpoint_prefix, 'LRFinder_gen_loss.tiff'))
    plot_loss_findlr(losses['disc_loss'], lrs, os.path.join(checkpoint_prefix, 'LRFinder_disc_loss.tiff'))

    print('Best losses:')
    print('gen_mae =', best_losses['gen_mae'])
    print('gen_loss =', best_losses['gen_loss'])
    print('disc_loss =', best_losses['disc_loss'])
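The LR finder above relies on one identity: multiplying the learning rate by lr_mult = (end_lr / start_lr) ** (1 / epoch_size) after every batch sweeps it exponentially from start_lr to end_lr over exactly one epoch. A standalone check with illustrative values:

start_lr, end_lr, n_batches = 1e-9, 1.0, 1000
lr_mult = (end_lr / start_lr) ** (1 / n_batches)

lr = start_lr
for _ in range(n_batches):
    lr *= lr_mult

# lr == start_lr * lr_mult ** n_batches == end_lr (up to float rounding)
assert abs(lr - end_lr) < 1e-9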
def train(data, epochs, batch_size=1, gen_lr=5e-6, disc_lr=5e-7, epoch_offset=0):
    generator = Generator(input_shape=[None, None, 2])
    discriminator = Discriminator(input_shape=[None, None, 1])
    generator_optimizer = tf.keras.optimizers.Adam(gen_lr)
    discriminator_optimizer = tf.keras.optimizers.Adam(disc_lr)

    model_name = data['training'].origin + '_2_any'
    checkpoint_prefix = os.path.join(CHECKPOINT_DIR, model_name)
    if not os.path.isdir(checkpoint_prefix):
        os.makedirs(checkpoint_prefix)
    else:
        if os.path.isfile(os.path.join(checkpoint_prefix, 'generator.h5')):
            generator.load_weights(os.path.join(checkpoint_prefix, 'generator.h5'), by_name=True)
            print('Generator weights restored from ' + checkpoint_prefix)
        if os.path.isfile(os.path.join(checkpoint_prefix, 'discriminator.h5')):
            discriminator.load_weights(os.path.join(checkpoint_prefix, 'discriminator.h5'), by_name=True)
            print('Discriminator weights restored from ' + checkpoint_prefix)

    # Get the number of batches in the training set
    epoch_size = len(data['training'])

    print()
    print("Started training with the following parameters: ")
    print("\tCheckpoints: \t", checkpoint_prefix)
    print("\tEpochs: \t", epochs)
    print("\tgen_lr: \t", gen_lr)
    print("\tdisc_lr: \t", disc_lr)
    print("\tBatchSize: \t", batch_size)
    print("\tnBatches: \t", epoch_size)
    print()

    # Precompute the test input and target for validation
    audio_input = load_audio(os.path.join(TEST_AUDIOS_PATH, data['training'].origin + '.wav'))
    mag_input, phase = forward_transform(audio_input)
    mag_input = amplitude_to_db(mag_input)
    test_input = slice_magnitude(mag_input, mag_input.shape[0])
    test_input = (test_input * 2) - 1

    test_inputs = []
    test_targets = []
    for t in data['training'].target:
        audio_target = load_audio(os.path.join(TEST_AUDIOS_PATH, t + '.wav'))
        mag_target, _ = forward_transform(audio_target)
        mag_target = amplitude_to_db(mag_target)
        test_target = slice_magnitude(mag_target, mag_target.shape[0])
        test_target = (test_target * 2) - 1
        test_target_perm = test_target[np.random.permutation(test_target.shape[0]), :, :, :]
        test_inputs.append(np.concatenate([test_input, test_target_perm], axis=3))
        test_targets.append(test_target)

    gen_mae_list, gen_mae_val_list = [], []
    gen_loss_list, gen_loss_val_list = [], []
    disc_loss_list, disc_loss_val_list = [], []
    for epoch in range(epochs):
        gen_mae_total, gen_mae_val_total = 0, 0
        gen_loss_total, gen_loss_val_total = 0, 0
        disc_loss_total, disc_loss_val_total = 0, 0

        print('Epoch {}/{}'.format((epoch + 1) + epoch_offset, epochs + epoch_offset))
        progbar = tf.keras.utils.Progbar(epoch_size)
        for i in range(epoch_size):
            # Get the data from the DataGenerator
            input_image, target = data['training'][i]

            with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
                # Generate a fake image
                gen_output = generator(input_image, training=True)

                # Train the discriminator
                disc_real_output = discriminator([input_image[:, :, :, 0:1], target], training=True)
                disc_generated_output = discriminator([input_image[:, :, :, 0:1], gen_output], training=True)

                # Compute the losses
                gen_mae = l1_loss(target, gen_output)
                gen_loss = generator_loss(disc_generated_output, gen_mae)
                disc_loss = discriminator_loss(disc_real_output, disc_generated_output)

            # Compute the gradients
            generator_gradients = gen_tape.gradient(gen_loss, generator.trainable_variables)
            discriminator_gradients = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

            # Apply the gradients
            generator_optimizer.apply_gradients(zip(generator_gradients, generator.trainable_variables))
            discriminator_optimizer.apply_gradients(zip(discriminator_gradients, discriminator.trainable_variables))

            # Update the progress bar
            gen_mae = gen_mae.numpy()
            gen_loss = gen_loss.numpy()
            disc_loss = disc_loss.numpy()
            gen_mae_total += gen_mae
            gen_loss_total += gen_loss
            disc_loss_total += disc_loss
            progbar.add(1, values=[("gen_mae", gen_mae), ("gen_loss", gen_loss), ("disc_loss", disc_loss)])

        gen_mae_list.append(gen_mae_total / epoch_size)
        gen_mae_val_list.append(gen_mae_val_total / epoch_size)
        gen_loss_list.append(gen_loss_total / epoch_size)
        gen_loss_val_list.append(gen_loss_val_total / epoch_size)
        disc_loss_list.append(disc_loss_total / epoch_size)
        disc_loss_val_list.append(disc_loss_val_total / epoch_size)

        history = pd.DataFrame({
            'gen_mae': gen_mae_list,
            'gen_mae_val': gen_mae_val_list,
            'gen_loss': gen_loss_list,
            'gen_loss_val': gen_loss_val_list,
            'disc_loss': disc_loss_list,
            'disc_loss_val': disc_loss_val_list
        })
        write_csv(history, os.path.join(checkpoint_prefix, 'history.csv'))

        epoch_output = os.path.join(OUTPUT_PATH, model_name, str((epoch + 1) + epoch_offset).zfill(3))
        init_directory(epoch_output)

        # Generate audios and save spectrograms for the entire audios
        for j in range(len(data['training'].target)):
            prediction = generator(test_inputs[j], training=False)
            prediction = (prediction + 1) / 2
            generate_images(prediction, (test_inputs[j] + 1) / 2, (test_targets[j] + 1) / 2,
                            os.path.join(epoch_output, 'spectrogram_' + data['training'].target[j]))
            generate_audio(prediction, phase,
                           os.path.join(epoch_output, 'audio_' + data['training'].target[j] + '.wav'))
        print('Epoch outputs saved in ' + epoch_output)

        # Save the weights
        generator.save_weights(os.path.join(checkpoint_prefix, 'generator.h5'))
        discriminator.save_weights(os.path.join(checkpoint_prefix, 'discriminator.h5'))
        print('Weights saved in ' + checkpoint_prefix)

        # Callback at the end of the epoch for the DataGenerator
        data['training'].on_epoch_end()
def train(data, epochs, batch_size=1, lr=1e-3, epoch_offset=0):
    generator = Generator()
    generator_optimizer = tf.keras.optimizers.Adam(lr)

    model_name = data['training'].origin + '_2_' + data['training'].target + '_generator'
    checkpoint_prefix = os.path.join(CHECKPOINT_DIR, model_name)
    if not os.path.isdir(checkpoint_prefix):
        os.makedirs(checkpoint_prefix)
    else:
        if os.path.isfile(os.path.join(checkpoint_prefix, 'generator.h5')):
            generator.load_weights(os.path.join(checkpoint_prefix, 'generator.h5'), by_name=True)
            print('Generator weights restored from ' + checkpoint_prefix)

    # Get the number of batches in the training set
    epoch_size = len(data['training'])

    print()
    print("Started training with the following parameters: ")
    print("\tCheckpoints: \t", checkpoint_prefix)
    print("\tEpochs: \t", epochs)
    print("\tgen_lr: \t", lr)
    print("\tBatchSize: \t", batch_size)
    print("\tnBatches: \t", epoch_size)
    print()

    # Precompute the test input and target for validation
    audio_input = load_audio(os.path.join(TEST_AUDIOS_PATH, data['training'].origin + '.wav'))
    mag_input, phase = forward_transform(audio_input)
    mag_input = amplitude_to_db(mag_input)
    test_input = slice_magnitude(mag_input, mag_input.shape[0])
    test_input = (test_input * 2) - 1

    audio_target = load_audio(os.path.join(TEST_AUDIOS_PATH, data['training'].target + '.wav'))
    mag_target, _ = forward_transform(audio_target)
    mag_target = amplitude_to_db(mag_target)
    test_target = slice_magnitude(mag_target, mag_target.shape[0])
    test_target = (test_target * 2) - 1

    gen_mae_list, gen_mae_val_list = [], []
    for epoch in range(epochs):
        gen_mae_total, gen_mae_val_total = 0, 0
        print('Epoch {}/{}'.format((epoch + 1) + epoch_offset, epochs + epoch_offset))
        progbar = tf.keras.utils.Progbar(epoch_size)
        for i in range(epoch_size):
            input_image, target = data['training'][i]

            with tf.GradientTape() as gen_tape:
                # Generate a fake image
                gen_output = generator(input_image, training=True)

                # Compute the losses
                gen_mae = l1_loss(target, gen_output)  # Timbre transfer
                # gen_mae = l1_loss(input_image, gen_output)  # Autoencoder

            # Compute the gradients
            generator_gradients = gen_tape.gradient(gen_mae, generator.trainable_variables)

            # Apply the gradients
            generator_optimizer.apply_gradients(zip(generator_gradients, generator.trainable_variables))

            # Update the progress bar
            gen_mae = gen_mae.numpy()
            gen_mae_total += gen_mae
            progbar.add(1, values=[("gen_mae", gen_mae)])

        gen_mae_total /= epoch_size
        gen_mae_list.append(gen_mae_total)
        gen_mae_val_list.append(gen_mae_val_total)
        history = pd.DataFrame({
            'gen_mae': gen_mae_list,
            'gen_mae_val': gen_mae_val_list
        })
        write_csv(history, os.path.join(checkpoint_prefix, 'history.csv'))

        epoch_output = os.path.join(OUTPUT_PATH, model_name, str((epoch + 1) + epoch_offset).zfill(3))
        init_directory(epoch_output)

        # Generate audios and save spectrograms for the entire audios
        prediction = generator(test_input, training=False)
        prediction = (prediction + 1) / 2
        generate_images(prediction, (test_input + 1) / 2, (test_target + 1) / 2,
                        os.path.join(epoch_output, 'spectrogram'))
        generate_audio(prediction, phase, os.path.join(epoch_output, 'audio.wav'))
        print('Epoch outputs saved in ' + epoch_output)

        # Save the weights
        generator.save_weights(os.path.join(checkpoint_prefix, 'generator.h5'))
        print('Weights saved in ' + checkpoint_prefix)

        # Callback at the end of the epoch for the DataGenerator
        data['training'].on_epoch_end()
def main():
    # Get data loader
    loader = get_loader(config, train=False)
    num_test = len(loader.dataset)

    # Generator network for target domain
    if img_size == 64:
        G = Generator64().cuda()
    else:
        G = Generator32().cuda()
    g_ckpt = torch.load('models/' + gen_ckpt)
    G.load_state_dict(g_ckpt)
    G.eval()

    # Transfer network from target to source domain
    T = Net(config).cuda()
    last_model_name = get_model_list(checkpoint_directory)
    t_ckpt = torch.load(last_model_name)
    T.load_state_dict(t_ckpt)
    T.eval()

    vgg19 = VGG19(vgg_ckpt).cuda()

    for idx in range(num_test):
        # Source image
        source = loader.dataset[idx]
        if dataset in ['svhn', 'mnist']:
            source = source[0]
        source = source.cuda()
        source = source.unsqueeze(0)
        source = source.repeat(num_sample, 1, 1, 1)

        # Latent Z for training set
        z = torch.randn(num_sample, z_dim, 1, 1).cuda()
        z.requires_grad = True
        optimizer_Z = optim.Adam([z], lr=lr_z)

        for it in range(num_eval_iter):
            # Update Z vector
            target = G(z)
            target_downsampled = T.get_downsampled_images(target)
            target2source = T(target_downsampled)

            source_features = vgg19(source)
            target2source_features = vgg19(target2source)

            l1_loss_samples_z = l1_loss(target2source, source)
            perceptual_loss_samples_z = perceptual_loss(target2source_features, source_features)
            loss_z = l1_w * l1_loss_samples_z + vgg_w * perceptual_loss_samples_z
            loss_z_samples = loss_z
            loss_z = loss_z.mean()

            optimizer_Z.zero_grad()
            loss_z.backward()
            optimizer_Z.step()

            print("Image: {}, Step: {}, Loss: {}, L1: {}, VGG: {}".format(
                idx, it, loss_z, l1_loss_samples_z.mean(), perceptual_loss_samples_z.mean()))

        save_result(source, target, idx, num_save, evaluation_directory, loss_z_samples)
def compute_losses(self):
    cycle_consistency_loss_a = \
        self._lambda_a * losses.cycle_consistency_loss(
            real_images=tf.expand_dims(self.input_a[:, :, :, 1], axis=3),
            generated_images=self.cycle_images_a,
        )
    cycle_consistency_loss_b = \
        self._lambda_b * losses.cycle_consistency_loss(
            real_images=tf.expand_dims(self.input_b[:, :, :, 1], axis=3),
            generated_images=self.cycle_images_b,
        )

    lsgan_loss_a = losses.lsgan_loss_generator(self.prob_fake_a_is_real)
    lsgan_loss_b = losses.lsgan_loss_generator(self.prob_fake_b_is_real)
    lsgan_loss_f = losses.lsgan_loss_generator(self.prob_fea_b_is_real)
    lsgan_loss_a_aux = losses.lsgan_loss_generator(self.prob_fake_a_aux_is_real)

    losssea = 0.01 * losses.l1_loss(self.fea_A_separate_B)
    lossseb = 0.01 * losses.l1_loss(self.fea_B_separate_A)
    lossseaf = 0.01 * losses.l1_loss(self.fea_FA_separate_B)
    losssebf = 0.01 * losses.l1_loss(self.fea_FB_separate_A)
    dif_loss = losssea + lossseb + lossseaf + losssebf

    ce_loss_b, dice_loss_b = losses.task_loss(self.pred_mask_fake_b, self.gt_a)
    ce_loss_a, dice_loss_a = losses.task_loss(self.pre_mask_real_a, self.gt_a)

    l2_loss_b = tf.add_n([
        0.0001 * tf.nn.l2_loss(v) for v in tf.trainable_variables()
        if '/s_B/' in v.name or '/e_B/' in v.name
    ])

    g_loss_A = cycle_consistency_loss_a + cycle_consistency_loss_b + lsgan_loss_b
    g_loss_B = cycle_consistency_loss_b + cycle_consistency_loss_a + lsgan_loss_a

    self.loss_f_weight = tf.placeholder(tf.float32, shape=[], name="loss_f_weight")
    self.loss_f_weight_summ = tf.summary.scalar("loss_f_weight", self.loss_f_weight)

    # seg_loss_B = ce_loss_b + dice_loss_b + ce_loss_a + dice_loss_a + l2_loss_b + 0.1*g_loss_B + self.loss_f_weight*lsgan_loss_f + 0.1*lsgan_loss_a_aux
    seg_loss_B = ce_loss_b + dice_loss_b + l2_loss_b + 0.1 * g_loss_B + self.loss_f_weight * lsgan_loss_f + 0.1 * lsgan_loss_a_aux
    seg_loss_A = ce_loss_a + dice_loss_a + l2_loss_b + 0.1 * g_loss_A

    d_loss_A = losses.lsgan_loss_discriminator(
        prob_real_is_real=self.prob_real_a_is_real,
        prob_fake_is_real=self.prob_fake_pool_a_is_real,
    )
    d_loss_A_aux = losses.lsgan_loss_discriminator(
        prob_real_is_real=self.prob_cycle_a_aux_is_real,
        prob_fake_is_real=self.prob_fake_pool_a_aux_is_real,
    )
    d_loss_A = d_loss_A + d_loss_A_aux
    d_loss_B = losses.lsgan_loss_discriminator(
        prob_real_is_real=self.prob_real_b_is_real,
        prob_fake_is_real=self.prob_fake_pool_b_is_real,
    )
    d_loss_F = losses.lsgan_loss_discriminator(
        prob_real_is_real=self.prob_fea_fake_b_is_real,
        prob_fake_is_real=self.prob_fea_b_is_real,
    )

    optimizer = tf.train.AdamOptimizer(self.learning_rate, beta1=0.5)
    optimizer_seg = tf.train.AdamOptimizer(self.learning_rate_seg)

    self.model_vars = tf.trainable_variables()

    d_A_vars = [var for var in self.model_vars if '/d_A/' in var.name]
    d_B_vars = [var for var in self.model_vars if '/d_B/' in var.name]
    e_c_vars = [var for var in self.model_vars if '/e_c/' in var.name]
    e_cs_vars = [var for var in self.model_vars if '/e_cs/' in var.name]
    e_ct_vars = [var for var in self.model_vars if '/e_ct/' in var.name]
    de_B_vars = [var for var in self.model_vars if '/de_B/' in var.name]
    de_A_vars = [var for var in self.model_vars if '/de_A/' in var.name]
    de_c_vars = [var for var in self.model_vars if '/de_c/' in var.name]
    s_B_vars = [var for var in self.model_vars if '/s_B/' in var.name]
    d_F_vars = [var for var in self.model_vars if '/d_F/' in var.name]
    e_dB_vars = [var for var in self.model_vars if '/e_dB/' in var.name]
    e_dA_vars = [var for var in self.model_vars if '/e_dA/' in var.name]

    self.d_A_trainer = optimizer.minimize(d_loss_A, var_list=d_A_vars)
    self.d_B_trainer = optimizer.minimize(d_loss_B, var_list=d_B_vars)
    self.dif_trainer = optimizer.minimize(dif_loss, var_list=e_dB_vars + e_dA_vars)
    self.g_A_trainer = optimizer.minimize(g_loss_A, var_list=de_A_vars + de_c_vars + e_c_vars + e_cs_vars + e_dB_vars)
    self.g_B_trainer = optimizer.minimize(g_loss_B, var_list=de_B_vars + de_c_vars + e_c_vars + e_ct_vars + e_dA_vars)
    self.s_B_trainer = optimizer_seg.minimize(seg_loss_B, var_list=e_c_vars + e_ct_vars + s_B_vars)
    self.s_A_trainer = optimizer_seg.minimize(seg_loss_A, var_list=e_c_vars + e_cs_vars + s_B_vars)
    self.d_F_trainer = optimizer.minimize(d_loss_F, var_list=d_F_vars)

    for var in self.model_vars:
        print(var.name)

    # Summary variables for tensorboard
    self.g_A_loss_summ = tf.summary.scalar("g_A_loss", g_loss_A)
    self.g_B_loss_summ = tf.summary.scalar("g_B_loss", g_loss_B)
    self.d_A_loss_summ = tf.summary.scalar("d_A_loss", d_loss_A)
    self.d_B_loss_summ = tf.summary.scalar("d_B_loss", d_loss_B)
    self.dif_loss_summ = tf.summary.scalar("dif_loss", dif_loss)
    self.ce_B_loss_summ = tf.summary.scalar("ce_B_loss", ce_loss_b)
    self.dice_B_loss_summ = tf.summary.scalar("dice_B_loss", dice_loss_b)
    self.l2_B_loss_summ = tf.summary.scalar("l2_B_loss", l2_loss_b)
    self.s_B_loss_summ = tf.summary.scalar("s_B_loss", seg_loss_B)
    self.s_A_loss_summ = tf.summary.scalar("s_A_loss", seg_loss_A)
    self.s_B_loss_merge_summ = tf.summary.merge([
        self.ce_B_loss_summ, self.dice_B_loss_summ, self.l2_B_loss_summ,
        self.s_B_loss_summ, self.s_A_loss_summ
    ])
    self.d_F_loss_summ = tf.summary.scalar("d_F_loss", d_loss_F)
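Note the single-argument form in this example: losses.l1_loss(self.fea_A_separate_B) penalizes a feature tensor directly rather than a difference between two tensors. A minimal sketch under that assumption (the mean absolute value of the tensor, used as a sparsity-style regularizer on the "separate" feature maps):

import tensorflow.compat.v1 as tf

def l1_loss(tensor):
    # L1 norm of the tensor itself, averaged over elements;
    # drives the feature maps toward zero.
    return tf.reduce_mean(tf.abs(tensor))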
def build_model(self):
    self.input_shape = (self.batchsize, self.args['input']['size'],
                        self.args['input']['size'], self.args['input']['channel'])
    self.output_shape = (self.batchsize, self.args['output']['size'],
                         self.args['output']['size'], self.args['output']['channel'])

    # set placeholder
    ## s for source
    # if self.flag_L1 or self.flag_d_intra or self.flag_task:
    self.s_gm = tf.placeholder(dtype=tf.float32, shape=self.input_shape)
    # if self.flag_L1 or self.flag_d_intra or self.flag_d_inter:
    self.s_color = tf.placeholder(dtype=tf.float32, shape=self.output_shape)
    if self.flag_task:
        temp = (self.args['batchsize'], self.args['model']['tasknet']['num_classes'])
        self.s_label = tf.placeholder(dtype=tf.float32, shape=temp)
    ## t for target
    if self.flag_d_inter:
        self.t_color = tf.placeholder(dtype=tf.float32, shape=self.output_shape)
        self.t_gm = tf.placeholder(dtype=tf.float32, shape=self.input_shape)

    # generate images
    if self.flag_L1 or self.flag_d_intra or self.flag_task:
        self.fake_s_color = generator(self.s_gm, gf_dim=self.gf_dim, o_c=self.output_shape[-1])
    if self.flag_d_inter:
        self.fake_t_color = generator(self.t_gm, gf_dim=self.gf_dim, o_c=self.output_shape[-1])

    # compute loss
    ## intra-domain
    self.loss_dict = {}
    if self.flag_d_intra:
        d_intra_logits_real = discriminator(self.s_color, df_dim=self.df_dim, name='intra_discriminator')
        d_intra_logits_fake = discriminator(self.fake_s_color, df_dim=self.df_dim, name='intra_discriminator')
        d_intra_loss_real = losses.nsgan_loss(d_intra_logits_real, is_real=True)
        d_intra_loss_fake = losses.nsgan_loss(d_intra_logits_fake, is_real=False)
        self.d_intra_loss = d_intra_loss_real + d_intra_loss_fake
        tf.summary.scalar("d_intra_loss", self.d_intra_loss)
        self.loss_dict.update({'d_intra_loss': self.d_intra_loss})
    ## inter-domain
    if self.flag_d_inter:
        d_inter_logits_real = discriminator(self.s_color, df_dim=self.df_dim, name='inter_discriminator')
        d_inter_logits_fake = discriminator(self.fake_t_color, df_dim=self.df_dim, name='inter_discriminator')
        d_inter_loss_real = losses.nsgan_loss(d_inter_logits_real, is_real=True)
        d_inter_loss_fake = losses.nsgan_loss(d_inter_logits_fake, is_real=False)
        self.d_inter_loss = d_inter_loss_real + d_inter_loss_fake
        tf.summary.scalar("d_inter_loss", self.d_inter_loss)
        self.loss_dict.update({'d_inter_loss': self.d_inter_loss})

    ## Generator loss
    flag = False
    self.g_loss_dict = {}
    ### l1 loss
    if self.flag_L1:
        self.l1_loss = losses.l1_loss(self.fake_s_color, self.s_color)
        self.g_loss_dict.update({'g_l1': self.l1_loss})
        _lambda = self.args['model']['lambda_L1']
        self.g_loss = _lambda * self.l1_loss
        flag = True
        tf.summary.scalar("l1_loss", self.l1_loss)
    ### task loss
    if self.flag_task:
        num_classes = self.args['model']['tasknet']['num_classes']
        task_net = ResNet50(self.fake_s_color, num_classes, phase=False)
        task_pred_logits = task_net.outputs
        self.task_loss = losses.task_loss(task_pred_logits, self.s_label)
        self.g_loss_dict.update({'g_loss_task': self.task_loss})
        _lambda = self.args['model']['tasknet']['lambda_L_task']
        if flag:
            self.g_loss += _lambda * self.task_loss
        else:
            self.g_loss = _lambda * self.task_loss
        flag = True
        tf.summary.scalar("g_task_loss", self.task_loss)
    ### d-intra loss
    if self.flag_d_intra:
        self.g_loss_intra = losses.nsgan_loss(d_intra_logits_fake, True)
        self.g_loss_dict.update({'g_d_intra': self.g_loss_intra})
        _lambda = self.args['model']['discriminator_intra']['lambda_L_d_intra']
        if flag:
            self.g_loss += _lambda * self.g_loss_intra
        else:
            self.g_loss = _lambda * self.g_loss_intra
        flag = True
        tf.summary.scalar("g_loss_intra", self.g_loss_intra)
    ### d-inter loss
    if self.flag_d_inter:
        self.g_loss_inter = losses.nsgan_loss(d_inter_logits_fake, True)
        self.g_loss_dict.update({'g_d_inter': self.g_loss_inter})
        _lambda = self.args['model']['discriminator_inter']['lambda_L_d_inter']
        if flag:
            self.g_loss += _lambda * self.g_loss_inter
        else:
            self.g_loss = _lambda * self.g_loss_inter
        flag = True
        tf.summary.scalar("g_loss_inter", self.g_loss_inter)

    tf.summary.scalar("g_loss", self.g_loss)
    self.loss_dict.update(self.g_loss_dict)

    # log
    self.sample = tf.concat([self.fake_s_color, self.s_gm[:, :, :, 1:], self.s_color], 2)
    if self.flag_d_inter:
        sample_t = tf.concat([self.fake_t_color, self.t_gm[:, :, :, 1:], self.t_color], 2)
        self.sample = tf.concat([self.sample, sample_t], 1)
    self.sample = (self.sample + 1) * 127.5

    # divide variable group
    t_vars = tf.trainable_variables()
    global_vars = tf.global_variables()
    self.normnet_vars_global = []
    if self.flag_d_intra:
        self.d_intra_vars = [var for var in t_vars if 'intra_discriminator' in var.name]
        self.d_intra_vars_global = [var for var in global_vars if 'intra_discriminator' in var.name]
        self.normnet_vars_global += self.d_intra_vars_global
    if self.flag_d_inter:
        self.d_inter_vars = [var for var in t_vars if 'inter_discriminator' in var.name]
        self.d_inter_vars_global = [var for var in global_vars if 'inter_discriminator' in var.name]
        self.normnet_vars_global += self.d_inter_vars_global
    self.g_vars = [var for var in t_vars if 'generator' in var.name]
    self.g_vars_global = [var for var in global_vars if 'generator' in var.name]
    self.normnet_vars_global += self.g_vars_global
    if self.flag_task:
        self.tasknet_vars = [var for var in t_vars if var not in self.normnet_vars_global]
        # self.tasknet_vars_tainable = self.tasknet_vars[44:]
        self.tasknet_vars_global = [var for var in global_vars if var not in self.normnet_vars_global]

    # saver
    vars_save = self.normnet_vars_global
    self.saver = tf.train.Saver(var_list=vars_save, max_to_keep=20)
def test_(image_size: int):
    image = build_tensor(1, image_size, 3)
    loss = l1_loss(image, image)
    assert loss == 0
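build_tensor and l1_loss come from the module under test; with identical inputs any L1 variant is zero, so the assertion above holds regardless of the reduction used. A complementary check (hypothetical values, assuming the mean-absolute-error form of l1_loss):

import tensorflow as tf

def test_unit_offset():
    a = tf.zeros([1, 4, 4, 3])
    b = tf.ones([1, 4, 4, 3])
    # |0 - 1| == 1 for every element, so the mean-reduced loss is exactly 1
    assert float(l1_loss(a, b)) == 1.0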
def main():
    # Get data loader
    loader = get_loader(config)

    # Latent Z for training set
    z_y = torch.randn(num_train, z_dim).cuda()
    z_y = z_y.view(num_train, z_dim, 1, 1)
    z_y.requires_grad = True
    optimizer_Z = optim.Adam([z_y], lr=lr_z)
    scheduler_Z = optim.lr_scheduler.StepLR(optimizer_Z, step_size=step_size, gamma=gamma)

    # Generator network for target domain
    if img_size == 64:
        G = Generator64().cuda()
    else:
        G = Generator32().cuda()
    g_ckpt = torch.load('models/' + gen_ckpt)
    G.load_state_dict(g_ckpt)
    G.eval()

    # Transfer network from target to source domain
    T = Net(config).cuda()
    T.weight_init(mean=0.0, std=0.02)
    optimizer_T = optim.Adam(T.parameters(), lr=lr_t)
    scheduler_T = optim.lr_scheduler.StepLR(optimizer_T, step_size=step_size, gamma=gamma)

    vgg19 = VGG19(vgg_ckpt).cuda()

    for epoch in range(num_epochs):
        for step, data in enumerate(loader):
            if dataset in ['svhn', 'mnist']:
                data = data[0]
            source = data.cuda()
            z = z_y[step * batch_size:(step + 1) * batch_size]

            T.eval()
            # Update Z vector
            target = G(z)
            target_downsampled = T.get_downsampled_images(target)
            target2source = T(target_downsampled)

            source_features = vgg19(source)
            target2source_features = vgg19(target2source)

            l1_loss_samples_z = l1_loss(target2source, source)
            perceptual_loss_samples_z = perceptual_loss(target2source_features, source_features)
            loss_z = l1_w * l1_loss_samples_z + vgg_w * perceptual_loss_samples_z
            loss_z = loss_z.mean()

            optimizer_Z.zero_grad()
            loss_z.backward()
            optimizer_Z.step()

            T.train()
            # Update the T network
            target = G(z)
            target_downsampled = T.get_downsampled_images(target)
            target2source = T(target_downsampled)

            source_features = vgg19(source)
            target2source_features = vgg19(target2source)

            l1_loss_samples_t = l1_loss(target2source, source)
            perceptual_loss_samples_t = perceptual_loss(target2source_features, source_features)
            loss_t = l1_w * l1_loss_samples_t + vgg_w * perceptual_loss_samples_t
            loss_t = loss_t.mean()

            optimizer_T.zero_grad()
            loss_t.backward()
            optimizer_T.step()

            print("Epoch: {}, Step: {}, Loss: {}, L1: {}, VGG: {}".format(
                epoch, step, loss_t, l1_loss_samples_t.mean(), perceptual_loss_samples_t.mean()))

        scheduler_Z.step()
        scheduler_T.step()

        if (epoch + 1) % img_save_it == 0:
            save_result(source, target, epoch, num_save, image_directory)
        if (epoch + 1) % model_save_it == 0:
            torch.save(T.state_dict(),
                       os.path.join(checkpoint_directory, "transfer_network_param_{}.pkl".format(epoch)))

    torch.save(T.state_dict(),
               os.path.join(checkpoint_directory, "transfer_network_param_final.pkl"))
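Both PyTorch examples call .mean() on the value returned by l1_loss, which implies a per-sample loss (one scalar per batch element) rather than a fully reduced one. A minimal sketch under that assumption:

import torch

def l1_loss(x, y):
    # Per-sample mean absolute error: reduce over all dims except batch,
    # so callers can weight or average the per-sample values themselves.
    return (x - y).abs().flatten(start_dim=1).mean(dim=1)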