Exemplo n.º 1
0
def training_step(
    draft_gen: Generator,
    gen: Generator,
    gen_optim,
    batch: Tuple[Tensor, Tensor, Tensor, Tensor],
) -> Dict[str, Tensor]:
    """Run one L1-supervised training step for the refinement generator.

    A coarse draft is built from `draft_gen` on the low-res line art, then
    `gen` colorizes the full-res line art conditioned on that draft and is
    updated to minimize the L1 distance to the ground-truth color image.

    Returns a dict of the step's images and the scalar L1 loss for logging.
    """
    # `batch` carries four tensors; the annotation reflects that.
    line, line_draft, hint, color = batch
    batch_size = line.shape[0]
    # Zero out the hints for part of the batch so the model also learns
    # hint-free colorization. NOTE(review): `X` is not defined in this
    # snippet — presumably a module-level constant; confirm at the source.
    mask = opt.mask_gen(list(hint.shape), X, batch_size // 2)
    hint = hint * mask
    image_size = line.shape[1]

    # Draft built outside the tape: no gradient flows into draft_gen here.
    draft = build_draft(draft_gen, line_draft, hint, image_size)

    with tf.GradientTape() as tape:
        _color = gen({"line": line, "hint": draft}, training=True)
        loss = losses.l1_loss(_color, color)

    grad = tape.gradient(loss, gen.trainable_variables)
    gen_optim.apply_gradients(zip(grad, gen.trainable_variables))

    return {
        # image
        "line": line,
        "hint": hint,
        "color": color,
        "draft": draft,
        "_color": _color,
        # scalar
        "l1_loss": loss,
    }
Exemplo n.º 2
0
 def testL1Loss(self):
     """L1 weight-decay loss of an all-ones tensor equals wd * num_elements."""
     with self.test_session():
         shape = [5, 5, 5]
         num_elem = 5 * 5 * 5
         weights = tf.constant(1.0, shape=shape)
         wd = 0.01
         loss = losses.l1_loss(weights, wd)
         # assertEquals is a deprecated unittest alias; use assertEqual.
         self.assertEqual(loss.op.name, 'L1Loss/value')
         self.assertAlmostEqual(loss.eval(), num_elem * wd, 5)
Exemplo n.º 3
0
    def _frc_losses(self, ops=None, suffix=''):
        """Add Fast R-CNN classification and bbox-regression losses to `ops`.

        Args:
            ops: dict to populate. A fresh dict is created when None — the
                previous mutable `{}` default was shared across calls, so
                losses from one call leaked into the next.
            suffix: name suffix distinguishing multiple detection heads.

        Returns:
            The populated `ops` dict.
        """
        if ops is None:
            ops = {}
        # classification loss
        cls_score = self.get_output('cls_score'+suffix)
        ops['loss_cls'+suffix] = losses.sparse_softmax(cls_score, self.data['labels'], name='cls_loss'+suffix)

        # bounding box regression L1 loss (only when bbox regression enabled)
        if cfg.TRAIN.BBOX_REG:
            bbox_pred = self.get_output('bbox_pred'+suffix)
            ops['loss_box'+suffix]  = losses.l1_loss(bbox_pred, self.data['bbox_targets'], 'reg_loss'+suffix,
                                                     self.data['bbox_inside_weights'])
        else:
            print('NO BBOX REGRESSION!!!!!')
        return ops
Exemplo n.º 4
0
def training_step(
    gen: Generator,
    disc: Discriminator,
    gen_optim,
    disc_optim,
    batch: Tuple[Tensor, Tensor, Tensor],
) -> Dict[str, Tensor]:
    """Run one adversarial + L1 training step for the generator/discriminator pair.

    The generator is trained with non-saturating BCE against the
    discriminator plus an L1 reconstruction term weighted by 100 (pix2pix
    style); the discriminator with the mean of its real/fake BCE losses.

    Returns a dict of the step's images, logits, and scalar losses.
    """
    line, hint, color = batch
    batch_size = line.shape[0]
    # Zero out the hints for part of the batch so the model also learns
    # hint-free colorization. NOTE(review): `X` is not defined in this
    # snippet — presumably a module-level constant; confirm at the source.
    mask = opt.mask_gen(list(hint.shape), X, batch_size // 2)
    hint = hint * mask

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        # Pass `training` as the Keras call kwarg (not as an input-dict key),
        # matching the other training_step in this file; as a dict entry it
        # would be treated as model input rather than toggling train mode.
        _color = gen({"line": line, "hint": hint}, training=True)

        logits = disc(color, training=True)
        _logits = disc(_color, training=True)

        adv_loss = losses.binary_crossentropy(_logits, tf.ones_like(_logits))
        l1_loss = losses.l1_loss(_color, color)
        gen_loss = adv_loss + (l1_loss * 100)

        real_loss = losses.binary_crossentropy(logits, tf.ones_like(logits))
        fake_loss = losses.binary_crossentropy(_logits, tf.zeros_like(_logits))
        disc_loss = (real_loss + fake_loss) / 2

    gen_grad = gen_tape.gradient(gen_loss, gen.trainable_variables)
    disc_grad = disc_tape.gradient(disc_loss, disc.trainable_variables)

    gen_optim.apply_gradients(zip(gen_grad, gen.trainable_variables))
    disc_optim.apply_gradients(zip(disc_grad, disc.trainable_variables))

    return {
        # image
        "line": line,
        "hint": hint,
        "color": color,
        "_color": _color,
        # scalar
        "logits": logits,
        "_logits": _logits,
        "adv_loss": adv_loss,
        "l1_loss": l1_loss,
        "gen_loss": gen_loss,
        "real_loss": real_loss,
        "fake_loss": fake_loss,
        "disc_loss": disc_loss,
    }
def find_lr(data, batch_size=1, start_lr=1e-9, end_lr=1):
    """LR range test (Smith-style LR finder) for the GAN pair.

    Runs one epoch while growing the learning rate geometrically from
    `start_lr` to `end_lr`, records each step's losses, stops early if any
    loss explodes past 100x its best value, and saves loss-vs-LR plots under
    the model's checkpoint directory.
    """
    generator = Generator(input_shape=[None, None, 2])
    discriminator = Discriminator(input_shape=[None, None, 1])

    generator_optimizer = tf.keras.optimizers.Adam(lr=start_lr)
    discriminator_optimizer = tf.keras.optimizers.Adam(lr=start_lr)

    model_name = data['training'].origin + '_2_any'
    checkpoint_prefix = os.path.join(CHECKPOINT_DIR, model_name)
    if not os.path.isdir(checkpoint_prefix):
        os.makedirs(checkpoint_prefix)

    epoch_size = len(data['training'])
    # Geometric growth factor so the LR reaches end_lr after epoch_size steps.
    lr_mult = (end_lr / start_lr) ** (1 / epoch_size)

    lrs = []
    # Renamed from `losses` to avoid shadowing the losses module name.
    loss_history = {
        'gen_mae': [],
        'gen_loss': [],
        'disc_loss': []
    }
    best_losses = {
        'gen_mae': 1e9,
        'gen_loss': 1e9,
        'disc_loss': 1e9
    }

    print()
    print("Finding the optimal LR with the following parameters: ")
    print("\tCheckpoints: \t", checkpoint_prefix)
    print("\tEpochs: \t", 1)
    print("\tBatchSize: \t", batch_size)
    print("\tnBatches: \t", epoch_size)
    print()

    print('Epoch {}/{}'.format(1, 1))
    progbar = tf.keras.utils.Progbar(epoch_size)
    for i in range(epoch_size):
        # Get the data from the DataGenerator
        input_image, target = data['training'][i]

        # Only the forward pass and loss computation need to be taped.
        with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
            # Generate a fake image
            gen_output = generator(input_image, training=True)

            # Discriminate real vs generated, conditioned on the input's
            # first channel.
            disc_real_output = discriminator([input_image[:, :, :, 0:1], target], training=True)
            disc_generated_output = discriminator([input_image[:, :, :, 0:1], gen_output], training=True)

            # Compute the losses
            gen_mae = l1_loss(target, gen_output)
            gen_loss = generator_loss(disc_generated_output, gen_mae)
            disc_loss = discriminator_loss(disc_real_output, disc_generated_output)

        # Compute and apply the gradients outside the tape context so the
        # optimizer updates and bookkeeping are not recorded by the tapes.
        generator_gradients = gen_tape.gradient(gen_loss, generator.trainable_variables)
        discriminator_gradients = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
        generator_optimizer.apply_gradients(zip(generator_gradients, generator.trainable_variables))
        discriminator_optimizer.apply_gradients(zip(discriminator_gradients, discriminator.trainable_variables))

        # Convert losses to numpy for logging/bookkeeping
        gen_mae = gen_mae.numpy()
        gen_loss = gen_loss.numpy()
        disc_loss = disc_loss.numpy()

        # Update the progress bar
        progbar.add(1, values=[
            ("gen_mae", gen_mae),
            ("gen_loss", gen_loss),
            ("disc_loss", disc_loss)
        ])

        # Record the LR used for this step, then grow it geometrically.
        lr = tf.keras.backend.get_value(generator_optimizer.lr)
        lrs.append(lr)
        lr *= lr_mult
        tf.keras.backend.set_value(generator_optimizer.lr, lr)
        tf.keras.backend.set_value(discriminator_optimizer.lr, lr)

        # Update the loss history
        loss_history['gen_mae'].append(gen_mae)
        loss_history['gen_loss'].append(gen_loss)
        loss_history['disc_loss'].append(disc_loss)

        # Track the best (lowest) losses; stop early once any current loss
        # reaches 100x its best value (the usual LR-finder divergence test).
        best_losses['gen_mae'] = min(best_losses['gen_mae'], gen_mae)
        best_losses['gen_loss'] = min(best_losses['gen_loss'], gen_loss)
        best_losses['disc_loss'] = min(best_losses['disc_loss'], disc_loss)
        if (gen_mae >= 100 * best_losses['gen_mae']
                or gen_loss >= 100 * best_losses['gen_loss']
                or disc_loss >= 100 * best_losses['disc_loss']):
            break

    plot_loss_findlr(loss_history['gen_mae'], lrs, os.path.join(checkpoint_prefix, 'LRFinder_gen_mae.tiff'))
    plot_loss_findlr(loss_history['gen_loss'], lrs, os.path.join(checkpoint_prefix, 'LRFinder_gen_loss.tiff'))
    plot_loss_findlr(loss_history['disc_loss'], lrs, os.path.join(checkpoint_prefix, 'LRFinder_disc_loss.tiff'))

    print('Best losses:')
    print('gen_mae =', best_losses['gen_mae'])
    print('gen_loss =', best_losses['gen_loss'])
    print('disc_loss =', best_losses['disc_loss'])
def train(data, epochs, batch_size=1, gen_lr=5e-6, disc_lr=5e-7, epoch_offset=0):
    """Train the spectrogram-translation GAN.

    Restores generator/discriminator weights from the checkpoint directory
    when present, trains for `epochs` epochs, and after every epoch writes
    the loss-history CSV, example spectrograms and audio for each target,
    and the model weights.
    """
    generator = Generator(input_shape=[None, None, 2])
    discriminator = Discriminator(input_shape=[None, None, 1])

    generator_optimizer = tf.keras.optimizers.Adam(gen_lr)
    discriminator_optimizer = tf.keras.optimizers.Adam(disc_lr)

    model_name = data['training'].origin + '_2_any'
    checkpoint_prefix = os.path.join(CHECKPOINT_DIR, model_name)
    if not os.path.isdir(checkpoint_prefix):
        os.makedirs(checkpoint_prefix)
    else:
        if os.path.isfile(os.path.join(checkpoint_prefix, 'generator.h5')):
            generator.load_weights(os.path.join(checkpoint_prefix, 'generator.h5'), by_name=True)
            print('Generator weights restored from ' + checkpoint_prefix)

        if os.path.isfile(os.path.join(checkpoint_prefix, 'discriminator.h5')):
            discriminator.load_weights(os.path.join(checkpoint_prefix, 'discriminator.h5'), by_name=True)
            print('Discriminator weights restored from ' + checkpoint_prefix)

    # Get the number of batches in the training set
    epoch_size = len(data['training'])

    print()
    print("Started training with the following parameters: ")
    print("\tCheckpoints: \t", checkpoint_prefix)
    print("\tEpochs: \t", epochs)
    print("\tgen_lr: \t", gen_lr)
    print("\tdisc_lr: \t", disc_lr)
    print("\tBatchSize: \t", batch_size)
    print("\tnBatches: \t", epoch_size)
    print()

    # Precompute the test input and target for validation
    audio_input = load_audio(os.path.join(TEST_AUDIOS_PATH, data['training'].origin + '.wav'))
    mag_input, phase = forward_transform(audio_input)
    mag_input = amplitude_to_db(mag_input)
    test_input = slice_magnitude(mag_input, mag_input.shape[0])
    test_input = (test_input * 2) - 1  # rescale [0, 1] -> [-1, 1]

    test_inputs = []
    test_targets = []

    for t in data['training'].target:
        audio_target = load_audio(os.path.join(TEST_AUDIOS_PATH, t + '.wav'))
        mag_target, _ = forward_transform(audio_target)
        mag_target = amplitude_to_db(mag_target)
        test_target = slice_magnitude(mag_target, mag_target.shape[0])
        test_target = (test_target * 2) - 1

        # Condition each test input on a randomly permuted target batch.
        test_target_perm = test_target[np.random.permutation(test_target.shape[0]), :, :, :]
        test_inputs.append(np.concatenate([test_input, test_target_perm], axis=3))
        test_targets.append(test_target)

    gen_mae_list, gen_mae_val_list = [], []
    gen_loss_list, gen_loss_val_list = [], []
    disc_loss_list, disc_loss_val_list = [], []
    for epoch in range(epochs):
        # NOTE(review): the *_val_total accumulators are never updated, so the
        # *_val columns written to history.csv are always zero — confirm intent.
        gen_mae_total, gen_mae_val_total = 0, 0
        gen_loss_total, gen_loss_val_total = 0, 0
        disc_loss_total, disc_loss_val_total = 0, 0

        print('Epoch {}/{}'.format((epoch + 1) + epoch_offset, epochs + epoch_offset))
        progbar = tf.keras.utils.Progbar(epoch_size)
        for i in range(epoch_size):
            # Get the data from the DataGenerator
            input_image, target = data['training'][i]

            # Only the forward pass and losses need to be taped.
            with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
                # Generate a fake image
                gen_output = generator(input_image, training=True)

                # Discriminate real vs generated, conditioned on the input's
                # first channel.
                disc_real_output = discriminator([input_image[:, :, :, 0:1], target], training=True)
                disc_generated_output = discriminator([input_image[:, :, :, 0:1], gen_output], training=True)

                # Compute the losses
                gen_mae = l1_loss(target, gen_output)
                gen_loss = generator_loss(disc_generated_output, gen_mae)
                disc_loss = discriminator_loss(disc_real_output, disc_generated_output)

            # Compute and apply the gradients outside the tape context so the
            # optimizer updates are not recorded by the tapes.
            generator_gradients = gen_tape.gradient(gen_loss, generator.trainable_variables)
            discriminator_gradients = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
            generator_optimizer.apply_gradients(zip(generator_gradients, generator.trainable_variables))
            discriminator_optimizer.apply_gradients(zip(discriminator_gradients, discriminator.trainable_variables))

            # Update the running totals and the progress bar
            gen_mae = gen_mae.numpy()
            gen_loss = gen_loss.numpy()
            disc_loss = disc_loss.numpy()

            gen_mae_total += gen_mae
            gen_loss_total += gen_loss
            disc_loss_total += disc_loss

            progbar.add(1, values=[
                ("gen_mae", gen_mae),
                ("gen_loss", gen_loss),
                ("disc_loss", disc_loss)
            ])

        # Append the per-epoch averages and persist the history CSV.
        gen_mae_list.append(gen_mae_total / epoch_size)
        gen_mae_val_list.append(gen_mae_val_total / epoch_size)
        gen_loss_list.append(gen_loss_total / epoch_size)
        gen_loss_val_list.append(gen_loss_val_total / epoch_size)
        disc_loss_list.append(disc_loss_total / epoch_size)
        disc_loss_val_list.append(disc_loss_val_total / epoch_size)

        history = pd.DataFrame({
            'gen_mae': gen_mae_list,
            'gen_mae_val': gen_mae_val_list,
            'gen_loss': gen_loss_list,
            'gen_loss_val': gen_loss_val_list,
            'disc_loss': disc_loss_list,
            'disc_loss_val': disc_loss_val_list
        })
        write_csv(history, os.path.join(checkpoint_prefix, 'history.csv'))

        epoch_output = os.path.join(OUTPUT_PATH, model_name, str((epoch + 1) + epoch_offset).zfill(3))
        init_directory(epoch_output)

        # Generate audios and save spectrograms for the entire audios
        for j in range(len(data['training'].target)):
            prediction = generator(test_inputs[j], training=False)
            prediction = (prediction + 1) / 2  # back to [0, 1]
            generate_images(prediction, (test_inputs[j] + 1) / 2, (test_targets[j] + 1) / 2, os.path.join(epoch_output, 'spectrogram_' + data['training'].target[j]))
            generate_audio(prediction, phase, os.path.join(epoch_output, 'audio_' + data['training'].target[j] + '.wav'))
        print('Epoch outputs saved in ' + epoch_output)

        # Save the weights
        generator.save_weights(os.path.join(checkpoint_prefix, 'generator.h5'))
        discriminator.save_weights(os.path.join(checkpoint_prefix, 'discriminator.h5'))
        print('Weights saved in ' + checkpoint_prefix)

        # Callback at the end of the epoch for the DataGenerator
        data['training'].on_epoch_end()
Exemplo n.º 7
0
def train(data, epochs, batch_size=1, lr=1e-3, epoch_offset=0):
    generator = Generator()
    generator_optimizer = tf.keras.optimizers.Adam(lr)

    model_name = data['training'].origin + '_2_' + data[
        'training'].target + '_generator'
    checkpoint_prefix = os.path.join(CHECKPOINT_DIR, model_name)
    if (not os.path.isdir(checkpoint_prefix)):
        os.makedirs(checkpoint_prefix)
    else:
        if (os.path.isfile(os.path.join(checkpoint_prefix, 'generator.h5'))):
            generator.load_weights(os.path.join(checkpoint_prefix,
                                                'generator.h5'),
                                   by_name=True)
            print('Generator weights restorred from ' + checkpoint_prefix)

    # Get the number of batches in the training set
    epoch_size = data['training'].__len__()

    print()
    print("Started training with the following parameters: ")
    print("\tCheckpoints: \t", checkpoint_prefix)
    print("\tEpochs: \t", epochs)
    print("\tgen_lr: \t", lr)
    print("\tBatchSize: \t", batch_size)
    print("\tnBatches: \t", epoch_size)
    print()

    # Precompute the test input and target for validation
    audio_input = load_audio(
        os.path.join(TEST_AUDIOS_PATH, data['training'].origin + '.wav'))
    mag_input, phase = forward_transform(audio_input)
    mag_input = amplitude_to_db(mag_input)
    test_input = slice_magnitude(mag_input, mag_input.shape[0])
    test_input = (test_input * 2) - 1

    audio_target = load_audio(
        os.path.join(TEST_AUDIOS_PATH, data['training'].target + '.wav'))
    mag_target, _ = forward_transform(audio_target)
    mag_target = amplitude_to_db(mag_target)
    test_target = slice_magnitude(mag_target, mag_target.shape[0])
    test_target = (test_target * 2) - 1

    gen_mae_list, gen_mae_val_list = [], []
    for epoch in range(epochs):
        gen_mae_total, gen_mae_val_total = 0, 0
        print('Epoch {}/{}'.format((epoch + 1) + epoch_offset,
                                   epochs + epoch_offset))
        progbar = tf.keras.utils.Progbar(epoch_size)
        for i in range(epoch_size):
            input_image, target = data['training'].__getitem__(i)
            with tf.GradientTape() as gen_tape:
                # Generate a fake image
                gen_output = generator(input_image, training=True)

                # Compute the losses
                gen_mae = l1_loss(target, gen_output)  # Timbre transfer
                # gen_mae = l1_loss(input_image, gen_output) # Autoencoder

                # Compute the gradients
                generator_gradients = gen_tape.gradient(
                    gen_mae, generator.trainable_variables)

                # Apply the gradients
                generator_optimizer.apply_gradients(
                    zip(generator_gradients, generator.trainable_variables))

                # Update the progress bar
                gen_mae = gen_mae.numpy()
                gen_mae_total += gen_mae
                progbar.add(1, values=[("gen_mae", gen_mae)])

        gen_mae_total /= epoch_size
        gen_mae_list.append(gen_mae_total)
        gen_mae_val_list.append(gen_mae_val_total)

        history = pd.DataFrame({
            'gen_mae': gen_mae_list,
            'gen_mae_val': gen_mae_val_list
        })
        write_csv(history, os.path.join(checkpoint_prefix, 'history.csv'))

        epoch_output = os.path.join(OUTPUT_PATH, model_name,
                                    str((epoch + 1) + epoch_offset).zfill(3))
        init_directory(epoch_output)

        # Generate audios and save spectrograms for the entire audios
        prediction = generator(test_input, training=False)
        prediction = (prediction + 1) / 2
        generate_images(prediction, (test_input + 1) / 2,
                        (test_target + 1) / 2,
                        os.path.join(epoch_output, 'spectrogram'))
        generate_audio(prediction, phase,
                       os.path.join(epoch_output, 'audio.wav'))
        print('Epoch outputs saved in ' + epoch_output)

        # Save the weights
        generator.save_weights(os.path.join(checkpoint_prefix, 'generator.h5'))
        print('Weights saved in ' + checkpoint_prefix)

        # Callback at the end of the epoch for the DataGenerator
        data['training'].on_epoch_end()
Exemplo n.º 8
0
def main():
    """Invert a pretrained generator for each test image.

    For every source image, optimizes a latent batch `z` with Adam so that
    the generated target G(z), mapped back to the source domain by the
    transfer network T, matches the source under an L1 + VGG perceptual
    loss, then saves the best samples.

    NOTE(review): relies on module-level globals defined elsewhere in the
    file (config, img_size, gen_ckpt, checkpoint_directory, vgg_ckpt,
    dataset, num_sample, z_dim, lr_z, num_eval_iter, l1_w, vgg_w, num_save,
    evaluation_directory) — confirm they are set before calling.
    """
    # Get data loader
    loader = get_loader(config, train=False)

    num_test = len(loader.dataset)

    # Generator network for target domain (architecture picked by image size)
    if img_size == 64:
        G = Generator64().cuda()
    else:
        G = Generator32().cuda()
    g_ckpt = torch.load('models/' + gen_ckpt)
    G.load_state_dict(g_ckpt)
    G.eval()

    # Transfer network from target to source domain
    T = Net(config).cuda()
    last_model_name = get_model_list(checkpoint_directory)
    t_ckpt = torch.load(last_model_name)
    T.load_state_dict(t_ckpt)
    T.eval()

    # Feature extractor for the perceptual loss
    vgg19 = VGG19(vgg_ckpt).cuda()

    for idx in range(num_test):

        # Source image (labelled datasets yield (image, label) pairs)
        source = loader.dataset[idx]
        if dataset in ['svhn', 'mnist']:
            source = source[0]
        source = source.cuda()
        source = source.unsqueeze(0)
        # Replicate the source so num_sample latents are optimized in parallel.
        source = source.repeat(num_sample, 1, 1, 1)
        
        # Latent Z for training set — the only tensor being optimized.
        z = torch.randn(num_sample, z_dim, 1, 1).cuda()
        z.requires_grad = True

        optimizer_Z = optim.Adam([z], lr=lr_z)
        
        # NOTE(review): if num_eval_iter == 0, `target` and `loss_z_samples`
        # are never assigned and save_result below raises NameError.
        for it in range(num_eval_iter):            
            # Update Z vector
            target = G(z)
            target_downsampled = T.get_downsampled_images(target)
            target2source = T(target_downsampled)
                
            source_features = vgg19(source)
            target2source_features = vgg19(target2source)
            
            # Per-sample losses; the mean is what gets backpropagated.
            l1_loss_samples_z = l1_loss(target2source, source)
            perceptual_loss_samples_z = perceptual_loss(target2source_features, source_features)
            
            loss_z = l1_w * l1_loss_samples_z + vgg_w * perceptual_loss_samples_z
            # Keep the pre-mean per-sample losses for ranking saved results.
            loss_z_samples = loss_z
            loss_z = loss_z.mean()

            optimizer_Z.zero_grad()
            loss_z.backward()
            optimizer_Z.step()

            print("Image: {}, Step: {}, Loss: {}, L1: {}, VGG: {}".format(idx, it, loss_z, l1_loss_samples_z.mean(), perceptual_loss_samples_z.mean()))


        save_result(source, target, idx, num_save, evaluation_directory, loss_z_samples)
Exemplo n.º 9
0
    def compute_losses(self):
        """Build all loss tensors, optimizers, and summaries for the model.

        Wires up: cycle-consistency and LSGAN generator losses, feature
        separation (L1) losses, segmentation task losses with L2 weight
        decay, LSGAN discriminator losses, and the per-subnetwork Adam
        training ops, plus TensorBoard scalar summaries for each.
        """
        # Cycle-consistency on channel 1 of each input (the image channel).
        cycle_consistency_loss_a = \
            self._lambda_a * losses.cycle_consistency_loss(
                real_images=tf.expand_dims(self.input_a[:,:,:,1], axis=3), generated_images=self.cycle_images_a,
            )
        cycle_consistency_loss_b = \
            self._lambda_b * losses.cycle_consistency_loss(
                real_images=tf.expand_dims(self.input_b[:,:,:,1], axis=3), generated_images=self.cycle_images_b,
            )

        # Generator-side LSGAN losses.
        lsgan_loss_a = losses.lsgan_loss_generator(self.prob_fake_a_is_real)
        lsgan_loss_b = losses.lsgan_loss_generator(self.prob_fake_b_is_real)
        lsgan_loss_f = losses.lsgan_loss_generator(self.prob_fea_b_is_real)
        lsgan_loss_a_aux = losses.lsgan_loss_generator(
            self.prob_fake_a_aux_is_real)

        # Feature-separation penalties (L1), weighted 0.01 each.
        losssea = 0.01 * losses.l1_loss(self.fea_A_separate_B)
        lossseb = 0.01 * losses.l1_loss(self.fea_B_separate_A)
        lossseaf = 0.01 * losses.l1_loss(self.fea_FA_separate_B)
        losssebf = 0.01 * losses.l1_loss(self.fea_FB_separate_A)
        dif_loss = losssea + lossseb + lossseaf + losssebf

        # Segmentation task losses against ground truth A.
        # NOTE(review): `pre_mask_real_a` (vs `pred_mask_fake_b`) looks like
        # a typo in the attribute name defined elsewhere — confirm.
        ce_loss_b, dice_loss_b = losses.task_loss(self.pred_mask_fake_b,
                                                  self.gt_a)
        ce_loss_a, dice_loss_a = losses.task_loss(self.pre_mask_real_a,
                                                  self.gt_a)

        # L2 weight decay over the segmenter/encoder B variables.
        l2_loss_b = tf.add_n([
            0.0001 * tf.nn.l2_loss(v) for v in tf.trainable_variables()
            if '/s_B/' in v.name or '/e_B/' in v.name
        ])

        g_loss_A = cycle_consistency_loss_a + cycle_consistency_loss_b + lsgan_loss_b
        g_loss_B = cycle_consistency_loss_b + cycle_consistency_loss_a + lsgan_loss_a

        # Schedulable weight for the feature-adversarial term, fed at runtime.
        self.loss_f_weight = tf.placeholder(tf.float32,
                                            shape=[],
                                            name="loss_f_weight")
        self.loss_f_weight_summ = tf.summary.scalar("loss_f_weight",
                                                    self.loss_f_weight)
        #seg_loss_B = ce_loss_b + dice_loss_b + ce_loss_a + dice_loss_a + l2_loss_b + 0.1*g_loss_B + self.loss_f_weight*lsgan_loss_f + 0.1*lsgan_loss_a_aux
        seg_loss_B = ce_loss_b + dice_loss_b + l2_loss_b + 0.1 * g_loss_B + self.loss_f_weight * lsgan_loss_f + 0.1 * lsgan_loss_a_aux
        # Fixed stray double '+' (a no-op unary plus typo) in the original.
        seg_loss_A = ce_loss_a + dice_loss_a + l2_loss_b + 0.1 * g_loss_A

        # Discriminator-side LSGAN losses (A uses an auxiliary discriminator).
        d_loss_A = losses.lsgan_loss_discriminator(
            prob_real_is_real=self.prob_real_a_is_real,
            prob_fake_is_real=self.prob_fake_pool_a_is_real,
        )
        d_loss_A_aux = losses.lsgan_loss_discriminator(
            prob_real_is_real=self.prob_cycle_a_aux_is_real,
            prob_fake_is_real=self.prob_fake_pool_a_aux_is_real,
        )
        d_loss_A = d_loss_A + d_loss_A_aux
        d_loss_B = losses.lsgan_loss_discriminator(
            prob_real_is_real=self.prob_real_b_is_real,
            prob_fake_is_real=self.prob_fake_pool_b_is_real,
        )
        d_loss_F = losses.lsgan_loss_discriminator(
            prob_real_is_real=self.prob_fea_fake_b_is_real,
            prob_fake_is_real=self.prob_fea_b_is_real,
        )

        optimizer = tf.train.AdamOptimizer(self.learning_rate, beta1=0.5)
        optimizer_seg = tf.train.AdamOptimizer(self.learning_rate_seg)

        self.model_vars = tf.trainable_variables()

        # Partition the variables by subnetwork name scope.
        d_A_vars = [var for var in self.model_vars if '/d_A/' in var.name]
        d_B_vars = [var for var in self.model_vars if '/d_B/' in var.name]
        e_c_vars = [var for var in self.model_vars if '/e_c/' in var.name]
        e_cs_vars = [var for var in self.model_vars if '/e_cs/' in var.name]
        e_ct_vars = [var for var in self.model_vars if '/e_ct/' in var.name]
        de_B_vars = [var for var in self.model_vars if '/de_B/' in var.name]
        de_A_vars = [var for var in self.model_vars if '/de_A/' in var.name]
        de_c_vars = [var for var in self.model_vars if '/de_c/' in var.name]
        s_B_vars = [var for var in self.model_vars if '/s_B/' in var.name]
        d_F_vars = [var for var in self.model_vars if '/d_F/' in var.name]
        e_dB_vars = [var for var in self.model_vars if '/e_dB/' in var.name]
        e_dA_vars = [var for var in self.model_vars if '/e_dA/' in var.name]

        # One training op per loss, each restricted to its own variables.
        self.d_A_trainer = optimizer.minimize(d_loss_A, var_list=d_A_vars)
        self.d_B_trainer = optimizer.minimize(d_loss_B, var_list=d_B_vars)
        self.dif_trainer = optimizer.minimize(dif_loss,
                                              var_list=e_dB_vars + e_dA_vars)
        self.g_A_trainer = optimizer.minimize(g_loss_A,
                                              var_list=de_A_vars + de_c_vars +
                                              e_c_vars + e_cs_vars + e_dB_vars)
        self.g_B_trainer = optimizer.minimize(g_loss_B,
                                              var_list=de_B_vars + de_c_vars +
                                              e_c_vars + e_ct_vars + e_dA_vars)
        self.s_B_trainer = optimizer_seg.minimize(seg_loss_B,
                                                  var_list=e_c_vars +
                                                  e_ct_vars + s_B_vars)
        self.s_A_trainer = optimizer_seg.minimize(seg_loss_A,
                                                  var_list=e_c_vars +
                                                  e_cs_vars + s_B_vars)
        self.d_F_trainer = optimizer.minimize(d_loss_F, var_list=d_F_vars)

        for var in self.model_vars:
            print(var.name)

        # Summary variables for tensorboard
        self.g_A_loss_summ = tf.summary.scalar("g_A_loss", g_loss_A)
        self.g_B_loss_summ = tf.summary.scalar("g_B_loss", g_loss_B)
        self.d_A_loss_summ = tf.summary.scalar("d_A_loss", d_loss_A)
        self.d_B_loss_summ = tf.summary.scalar("d_B_loss", d_loss_B)
        self.dif_loss_summ = tf.summary.scalar("dif_loss", dif_loss)
        self.ce_B_loss_summ = tf.summary.scalar("ce_B_loss", ce_loss_b)
        self.dice_B_loss_summ = tf.summary.scalar("dice_B_loss", dice_loss_b)
        self.l2_B_loss_summ = tf.summary.scalar("l2_B_loss", l2_loss_b)
        self.s_B_loss_summ = tf.summary.scalar("s_B_loss", seg_loss_B)
        self.s_A_loss_summ = tf.summary.scalar("s_A_loss", seg_loss_A)
        self.s_B_loss_merge_summ = tf.summary.merge([
            self.ce_B_loss_summ, self.dice_B_loss_summ, self.l2_B_loss_summ,
            self.s_B_loss_summ, self.s_A_loss_summ
        ])
        self.d_F_loss_summ = tf.summary.scalar("d_F_loss", d_loss_F)
Exemplo n.º 10
0
    def build_model(self):
        """Build placeholders, generators, discriminators, and the combined
        generator loss, gated by the model's flag_* configuration switches.

        Populates `self.loss_dict` with discriminator losses and
        `self.g_loss_dict` / `self.g_loss` with the weighted generator loss
        terms that are enabled.
        """
        self.input_shape = (self.batchsize, self.args['input']['size'],
                            self.args['input']['size'],
                            self.args['input']['channel'])
        self.output_shape = (self.batchsize, self.args['output']['size'],
                             self.args['output']['size'],
                             self.args['output']['channel'])
        # set placeholder
        ## s for source
        self.s_gm = tf.placeholder(dtype=tf.float32, shape=self.input_shape)
        self.s_color = tf.placeholder(dtype=tf.float32,
                                      shape=self.output_shape)
        if self.flag_task:
            temp = (self.args['batchsize'],
                    self.args['model']['tasknet']['num_classes'])
            self.s_label = tf.placeholder(dtype=tf.float32, shape=temp)
        ## t for target (only needed for the inter-domain discriminator)
        if self.flag_d_inter:
            self.t_color = tf.placeholder(dtype=tf.float32,
                                          shape=self.output_shape)
            self.t_gm = tf.placeholder(dtype=tf.float32,
                                       shape=self.input_shape)

        # generate images
        if self.flag_L1 or self.flag_d_intra or self.flag_task:
            self.fake_s_color = generator(self.s_gm,
                                          gf_dim=self.gf_dim,
                                          o_c=self.output_shape[-1])
        if self.flag_d_inter:
            self.fake_t_color = generator(self.t_gm,
                                          gf_dim=self.gf_dim,
                                          o_c=self.output_shape[-1])

        # compute loss
        ## intra-domain discriminator: real vs generated source colors
        self.loss_dict = {}
        if self.flag_d_intra:
            d_intra_logits_real = discriminator(self.s_color,
                                                df_dim=self.df_dim,
                                                name='intra_discriminator')
            d_intra_logits_fake = discriminator(self.fake_s_color,
                                                df_dim=self.df_dim,
                                                name='intra_discriminator')
            d_intra_loss_real = losses.nsgan_loss(d_intra_logits_real,
                                                  is_real=True)
            d_intra_loss_fake = losses.nsgan_loss(d_intra_logits_fake,
                                                  is_real=False)
            self.d_intra_loss = d_intra_loss_real + d_intra_loss_fake
            tf.summary.scalar("d_intra_loss", self.d_intra_loss)
            self.loss_dict.update({'d_intra_loss': self.d_intra_loss})
        ## inter-domain discriminator: source colors vs generated target colors
        if self.flag_d_inter:
            d_inter_logits_real = discriminator(self.s_color,
                                                df_dim=self.df_dim,
                                                name='inter_discriminator')
            d_inter_logits_fake = discriminator(self.fake_t_color,
                                                df_dim=self.df_dim,
                                                name='inter_discriminator')
            d_inter_loss_real = losses.nsgan_loss(d_inter_logits_real,
                                                  is_real=True)
            d_inter_loss_fake = losses.nsgan_loss(d_inter_logits_fake,
                                                  is_real=False)
            self.d_inter_loss = d_inter_loss_real + d_inter_loss_fake
            tf.summary.scalar("d_inter_loss", self.d_inter_loss)
            self.loss_dict.update({'d_inter_loss': self.d_inter_loss})
        ## Generator loss: each enabled term is added with its own lambda.
        ## `flag` tracks whether self.g_loss has been initialized yet.
        flag = False
        self.g_loss_dict = {}
        ### l1 loss
        if self.flag_L1:
            self.l1_loss = losses.l1_loss(self.fake_s_color, self.s_color)
            self.g_loss_dict.update({'g_l1': self.l1_loss})
            _lambda = self.args['model']['lambda_L1']
            self.g_loss = _lambda * self.l1_loss
            flag = True
            tf.summary.scalar("l1_loss", self.l1_loss)
        ### task loss
        if self.flag_task:
            num_classes = self.args['model']['tasknet']['num_classes']
            task_net = ResNet50(self.fake_s_color, num_classes, phase=False)
            task_pred_logits = task_net.outputs
            self.task_loss = losses.task_loss(task_pred_logits, self.s_label)
            self.g_loss_dict.update({'g_loss_task': self.task_loss})
            _lambda = self.args['model']['tasknet']['lambda_L_task']
            if flag:
                self.g_loss += _lambda * self.task_loss
            else:
                self.g_loss = _lambda * self.task_loss
                flag = True
            tf.summary.scalar("g_task_loss", self.task_loss)
        ### d-intra loss
        if self.flag_d_intra:
            self.g_loss_intra = losses.nsgan_loss(d_intra_logits_fake, True)
            self.g_loss_dict.update({'g_d_intra': self.g_loss_intra})
            _lambda = self.args['model']['discriminator_intra'][
                'lambda_L_d_intra']
            if flag:
                self.g_loss += _lambda * self.g_loss_intra
            else:
                self.g_loss = _lambda * self.g_loss_intra
                flag = True
            tf.summary.scalar("g_loss_intra", self.g_loss_intra)
        ### d-inter loss
        if self.flag_d_inter:
            self.g_loss_inter = losses.nsgan_loss(d_inter_logits_fake, True)
            self.g_loss_dict.update({'g_d_inter': self.g_loss_inter})
            _lambda = self.args['model']['discriminator_inter'][
                'lambda_L_d_inter']
            if flag:
                self.g_loss += _lambda * self.g_loss_inter
            else:
                self.g_loss = _lambda * self.g_loss_inter
                flag = True
            tf.summary.scalar("g_loss_inter", self.g_loss_inter)
        # Only summarize g_loss when at least one term defined it; the
        # original summarized unconditionally and crashed with
        # AttributeError when no generator-loss flag was enabled.
        if flag:
            tf.summary.scalar("g_loss", self.g_loss)
        self.loss_dict.update(self.g_loss_dict)

        #log
        self.sample = tf.concat(
            [self.fake_s_color, self.s_gm[:, :, :, 1:], self.s_color], 2)
        if self.flag_d_inter:
            sample_t = tf.concat(
                [self.fake_t_color, self.t_gm[:, :, :, 1:], self.t_color], 2)
            self.sample = tf.concat([self.sample, sample_t], 1)
        self.sample = (self.sample + 1) * 127.5

        #divide variable group
        t_vars = tf.trainable_variables()
        global_vars = tf.global_variables()
        self.normnet_vars_global = []
        if self.flag_d_intra:
            self.d_intra_vars = [
                var for var in t_vars if 'intra_discriminator' in var.name
            ]
            self.d_intra_vars_global = [
                var for var in global_vars if 'intra_discriminator' in var.name
            ]
            self.normnet_vars_global += self.d_intra_vars_global
        if self.flag_d_inter:
            self.d_inter_vars = [
                var for var in t_vars if 'inter_discriminator' in var.name
            ]
            self.d_inter_vars_global = [
                var for var in global_vars if 'inter_discriminator' in var.name
            ]
            self.normnet_vars_global += self.d_inter_vars_global
        self.g_vars = [var for var in t_vars if 'generator' in var.name]
        self.g_vars_global = [
            var for var in global_vars if 'generator' in var.name
        ]
        self.normnet_vars_global += self.g_vars_global
        if self.flag_task:
            self.tasknet_vars = [
                var for var in t_vars if var not in self.normnet_vars_global
            ]
            #            self.tasknet_vars_tainable = self.tasknet_vars[44:]
            self.tasknet_vars_global = [
                var for var in global_vars
                if var not in self.normnet_vars_global
            ]

        #saver
        vars_save = self.normnet_vars_global
        self.saver = tf.train.Saver(var_list=vars_save, max_to_keep=20)
Exemplo n.º 11
0
def test_(image_size: int):
    """L1 distance between a tensor and itself must be exactly zero."""
    img = build_tensor(1, image_size, 3)
    assert l1_loss(img, img) == 0
Exemplo n.º 12
0
def main():
    """Alternately optimize per-image latent codes ``z_y`` and a transfer
    network ``T`` so that downsampled samples from a frozen generator ``G``
    (target domain) map back to real source-domain images.

    Relies on module-level globals defined elsewhere in this file (config,
    num_train, z_dim, lr_z, lr_t, step_size, gamma, img_size, gen_ckpt,
    vgg_ckpt, dataset, batch_size, num_epochs, l1_w, vgg_w, img_save_it,
    model_save_it, num_save, image_directory, checkpoint_directory) —
    assumed to be set at import/arg-parse time; verify before reuse.
    """
    # Get data loader
    loader = get_loader(config)

    # Latent Z for training set
    # One learnable latent vector per training image, optimized directly
    # (GAN-inversion style); reshaped to NCHW with 1x1 spatial dims.
    z_y = torch.randn(num_train, z_dim).cuda()
    z_y = z_y.view(num_train, z_dim, 1, 1)
    z_y.requires_grad = True

    optimizer_Z = optim.Adam([z_y], lr=lr_z)
    scheduler_Z = optim.lr_scheduler.StepLR(optimizer_Z,
                                            step_size=step_size,
                                            gamma=gamma)

    # Generator network for target domain
    # Pre-trained and effectively frozen: its parameters are not given to
    # any optimizer below, so they never change even though gradients may
    # flow through it into z.
    if img_size == 64:
        G = Generator64().cuda()
    else:
        G = Generator32().cuda()
    g_ckpt = torch.load('models/' + gen_ckpt)
    G.load_state_dict(g_ckpt)
    G.eval()

    # Transfer network from target to source domain
    T = Net(config).cuda()
    T.weight_init(mean=0.0, std=0.02)

    optimizer_T = optim.Adam(T.parameters(), lr=lr_t)
    scheduler_T = optim.lr_scheduler.StepLR(optimizer_T,
                                            step_size=step_size,
                                            gamma=gamma)

    # Fixed VGG19 feature extractor used only for the perceptual loss.
    vgg19 = VGG19(vgg_ckpt).cuda()

    for epoch in range(num_epochs):
        for step, data in enumerate(loader):
            if dataset in ['svhn', 'mnist']:
                data = data[0]
            source = data.cuda()

            # Slice of latent codes corresponding to this mini-batch.
            # NOTE(review): this assumes the loader is NOT shuffled, so that
            # batch `step` always pairs with the same rows of z_y — confirm
            # in get_loader.
            z = z_y[step * batch_size:(step + 1) * batch_size]

            T.eval()
            # Update Z vector
            # Phase 1: hold T in eval mode and optimize only the latent
            # codes (optimizer_Z owns z_y alone).
            target = G(z)
            target_downsampled = T.get_downsampled_images(target)
            target2source = T(target_downsampled)

            source_features = vgg19(source)
            target2source_features = vgg19(target2source)

            l1_loss_samples_z = l1_loss(target2source, source)
            perceptual_loss_samples_z = perceptual_loss(
                target2source_features, source_features)

            # Weighted sum of pixel (L1) and perceptual (VGG) losses.
            loss_z = l1_w * l1_loss_samples_z + vgg_w * perceptual_loss_samples_z
            loss_z = loss_z.mean()

            optimizer_Z.zero_grad()
            loss_z.backward()
            optimizer_Z.step()

            T.train()
            # Update the T network
            # Phase 2: re-run the forward pass with the just-updated z and
            # optimize T's parameters. optimizer_T.zero_grad() below also
            # discards any T-gradients left over from the Phase-1 backward.
            target = G(z)
            target_downsampled = T.get_downsampled_images(target)
            target2source = T(target_downsampled)

            source_features = vgg19(source)
            target2source_features = vgg19(target2source)

            l1_loss_samples_t = l1_loss(target2source, source)
            perceptual_loss_samples_t = perceptual_loss(
                target2source_features, source_features)

            loss_t = l1_w * l1_loss_samples_t + vgg_w * perceptual_loss_samples_t
            loss_t = loss_t.mean()

            optimizer_T.zero_grad()
            loss_t.backward()
            optimizer_T.step()

            print("Epoch: {}, Step: {}, Loss: {}, L1: {}, VGG: {}".format(
                epoch, step, loss_t, l1_loss_samples_t.mean(),
                perceptual_loss_samples_t.mean()))

        # Decay both learning rates once per epoch.
        scheduler_Z.step()
        scheduler_T.step()

        # Periodically dump sample images and model checkpoints.
        if (epoch + 1) % img_save_it == 0:
            save_result(source, target, epoch, num_save, image_directory)

        if (epoch + 1) % model_save_it == 0:
            torch.save(
                T.state_dict(),
                os.path.join(checkpoint_directory,
                             "transfer_network_param_{}.pkl".format(epoch)))

    # Final checkpoint after all epochs complete.
    torch.save(
        T.state_dict(),
        os.path.join(checkpoint_directory, "transfer_network_param_final.pkl"))