Exemplo n.º 1
0
def convert():
    """Export both generators, optionally restoring a user-chosen checkpoint.

    Two generators and two discriminators are created and tracked by a
    ``tf.train.Checkpoint``.  When the ``checkpoints`` directory contains a
    checkpoint, the managed checkpoints are listed, an index is read from
    stdin and the selected checkpoint is restored.  The generators are then
    saved to ``./g_gan`` and ``./f_gan``.
    """
    checkpoint_dir = 'checkpoints'
    disc_options = {'norm_type': 'instancenorm', 'target': False}

    g_gan = generator()
    f_gan = generator()
    dis_g = pix2pix.discriminator(**disc_options)
    dis_f = pix2pix.discriminator(**disc_options)

    ckpt = tf.train.Checkpoint(ggan=g_gan, fgan=f_gan, gdis=dis_g, fdis=dis_f)

    if tf.train.latest_checkpoint(checkpoint_dir):
        manager = tf.train.CheckpointManager(
            ckpt, checkpoint_dir, max_to_keep=5)
        available = manager.checkpoints

        for position, path in enumerate(available):
            print(f'{position} : {path}')

        # An out-of-range or non-numeric answer raises before anything is
        # restored.
        selected = available[int(input('Index'))]

        print('loaded checkpoint:', selected)
        ckpt.restore(selected)

    g_gan.save('./g_gan')
    f_gan.save('./f_gan')
Exemplo n.º 2
0
def build_cyclegan_models(n_channels, norm_type):
    """Create the two generators and two discriminators of a CycleGAN.

    Args:
        n_channels: Passed to ``pix2pix.unet_generator`` as its first
            positional argument.
        norm_type: Either ``'instancenorm'`` or ``'batchnorm'``; forwarded to
            every network.

    Returns:
        The tuple ``(generator_g, generator_f, discriminator_x,
        discriminator_y)``.

    Raises:
        AssertionError: If ``norm_type`` is not one of the two accepted
            values.
    """
    assert norm_type in ('instancenorm', 'batchnorm')

    def make_generator():
        return pix2pix.unet_generator(n_channels, norm_type=norm_type)

    def make_discriminator():
        return pix2pix.discriminator(norm_type=norm_type, target=False)

    # Construction order is kept: both generators first, then discriminators.
    generator_g, generator_f = make_generator(), make_generator()
    discriminator_x, discriminator_y = make_discriminator(), make_discriminator()

    return generator_g, generator_f, discriminator_x, discriminator_y
Exemplo n.º 3
0
    def __init__(self, checkpoint_path: str = None, restore_checkpoint: bool = True):
        """Build the networks, optimizers, loss and checkpoint machinery.

        Args:
            checkpoint_path: Directory handed to the checkpoint manager;
                ``".."`` is used when this is ``None``.
            restore_checkpoint: When true, the latest checkpoint found by the
                manager (if any) is restored.
        """
        n_output_channels = 3
        norm = "instancenorm"

        logger.info("Creating Generators and Discriminators")
        self.generator_g = pix2pix.unet_generator(n_output_channels, norm_type=norm)
        self.generator_f = pix2pix.unet_generator(n_output_channels, norm_type=norm)
        self.discriminator_x = pix2pix.discriminator(norm_type=norm, target=False)
        self.discriminator_y = pix2pix.discriminator(norm_type=norm, target=False)

        logger.info("Setting up the optimizers")
        step_size, momentum = 2e-4, 0.5
        self.generator_g_optimizer = tf.keras.optimizers.Adam(step_size, beta_1=momentum)
        self.generator_f_optimizer = tf.keras.optimizers.Adam(step_size, beta_1=momentum)
        self.discriminator_x_optimizer = tf.keras.optimizers.Adam(step_size, beta_1=momentum)
        self.discriminator_y_optimizer = tf.keras.optimizers.Adam(step_size, beta_1=momentum)

        self.loss_obj = tf.keras.losses.BinaryCrossentropy(from_logits=True)
        self.LAMBDA = 10

        self.checkpoint_path = ".." if checkpoint_path is None else checkpoint_path

        # Everything listed here is saved and restored together.
        tracked = {
            "generator_g": self.generator_g,
            "generator_f": self.generator_f,
            "discriminator_x": self.discriminator_x,
            "discriminator_y": self.discriminator_y,
            "generator_g_optimizer": self.generator_g_optimizer,
            "generator_f_optimizer": self.generator_f_optimizer,
            "discriminator_x_optimizer": self.discriminator_x_optimizer,
            "discriminator_y_optimizer": self.discriminator_y_optimizer,
        }
        self.ckpt = tf.train.Checkpoint(**tracked)
        self.ckpt_manager = tf.train.CheckpointManager(
            self.ckpt, self.checkpoint_path, max_to_keep=5
        )

        self.restore_checkpoint = restore_checkpoint

        # Restore only when asked to and when a checkpoint actually exists.
        if self.restore_checkpoint and self.ckpt_manager.latest_checkpoint:
            self.ckpt.restore(self.ckpt_manager.latest_checkpoint)
            print("Latest checkpoint restored!!")
Exemplo n.º 4
0
def get_models_from_input_shape(input_shape,
                                norm_type,
                                output_init=0.02,
                                residual_output=False):
    """Return a ``(generator, discriminator)`` pair chosen by ``input_shape``.

    * ``(28, 28, 1)``: the MNIST-specific generator and discriminator.
    * ``(256, 256, 3)``: the stock pix2pix generator and discriminator; the
      ``output_init`` and ``residual_output`` options are not supported here.
    * anything else: ``unet_generator`` built for ``input_shape`` together
      with the pix2pix discriminator.

    Raises:
        NotImplementedError: For ``(256, 256, 3)`` when ``residual_output``
            is ``True`` or ``output_init`` differs from ``0.02``.
    """
    if input_shape == (28, 28, 1):
        # MNIST-like data
        gen = mnist_unet_generator(norm_type=norm_type)
        disc = mnist_discriminator(norm_type=norm_type, target=False)
        return gen, disc

    if input_shape == (256, 256, 3):
        # TODO: just use our unet_generator fn
        uses_unsupported_option = residual_output is True or output_init != 0.02
        if uses_unsupported_option:
            raise NotImplementedError
        gen = pix2pix.unet_generator(output_channels=3, norm_type=norm_type)
        disc = pix2pix.discriminator(norm_type=norm_type, target=False)
        return gen, disc

    gen = unet_generator(
        output_channels=3,
        input_shape=input_shape,
        norm_type=norm_type,
        output_init=output_init,
        residual_output=residual_output,
    )
    disc = pix2pix.discriminator(norm_type=norm_type, target=False)
    return gen, disc
Exemplo n.º 5
0
    def __init__(self, summary, lmbda, nsamples, niters, learning_rate,
                 beta_1):
        """Create the networks, the learning-rate schedule and the optimizers.

        Args:
            summary: Stored unchanged on ``self.summary``.
            lmbda: Stored as a ``tf.float32`` constant on ``self.lmbda``.
            nsamples, niters: Their product, floor-divided by two, is passed
                as the first and third arguments of ``scheduler.LinearDecay``.
            learning_rate: Second argument of ``scheduler.LinearDecay``.
            beta_1: Passed as the second positional argument to each Adam
                optimizer.
        """
        self.lmbda = tf.constant(lmbda, tf.float32)
        self.summary = summary

        self.ggan = generator()
        self.fgan = generator()

        disc_options = {'norm_type': 'instancenorm', 'target': False}
        self.gdis = pix2pix.discriminator(**disc_options)
        self.fdis = pix2pix.discriminator(**disc_options)

        half_of_all_steps = nsamples * niters // 2
        self.lrscheculer = scheduler.LinearDecay(
            half_of_all_steps, learning_rate, half_of_all_steps, 0.0)

        def make_optimizer():
            # Every optimizer is driven by the same schedule object.
            return tf.keras.optimizers.Adam(self.lrscheculer, beta_1)

        self.opti_ggan = make_optimizer()
        self.opti_gdis = make_optimizer()
        self.opti_fgan = make_optimizer()
        self.opti_fdis = make_optimizer()
Exemplo n.º 6
0
    def __init__(self):
        """Create two discriminators, two generators and one optimizer each."""

        def new_discriminator():
            return pix2pix.discriminator(norm_type=NORM_TYPE, target=False)

        def new_generator():
            return pix2pix.unet_generator(OUTPUT_CHANNELS, norm_type=NORM_TYPE)

        def new_optimizer():
            return tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

        # Discriminators (try to identify if the input is a real example of
        # its class)
        self.Dx, self.Dy = new_discriminator(), new_discriminator()
        # Generators (convert from the opposite class into the new class)
        self.Gx, self.Gy = new_generator(), new_generator()
        # Optimizers to perform stochastic gradient descent
        self.Dx_optimizer, self.Dy_optimizer = new_optimizer(), new_optimizer()
        self.Gx_optimizer, self.Gy_optimizer = new_optimizer(), new_optimizer()
Exemplo n.º 7
0
# Zebra sample (left) next to a randomly jittered copy of it (right).
# ``* 0.5 + 0.5`` presumably maps [-1, 1] pixels into [0, 1] for display;
# confirm against the normalization used upstream.
plt.subplot(1, 2, 1)
plt.title('Zebra')
plt.imshow(sample_zebra[0] * 0.5 + 0.5)

plt.subplot(1, 2, 2)
plt.title('Zebra with random jitter')
plt.imshow(random_jitter(sample_zebra[0]) * 0.5 + 0.5)

# Reuse the pix2pix model for the conversion.
OUTPUT_CHANNELS = 3

generator_g = pix2pix.unet_generator(
    OUTPUT_CHANNELS, norm_type='instancenorm')
generator_f = pix2pix.unet_generator(
    OUTPUT_CHANNELS, norm_type='instancenorm')

discriminator_x = pix2pix.discriminator(
    norm_type='instancenorm', target=False)
discriminator_y = pix2pix.discriminator(
    norm_type='instancenorm', target=False)

# Plotting the generated images

to_zebra = generator_g(sample_horse)
to_horse = generator_f(sample_zebra)
plt.figure(figsize=(8, 8))
contrast = 8

imgs = [sample_horse, to_zebra, sample_zebra, to_horse]
title = ['Horse', 'To Zebra', 'Zebra', 'To Horse']

# One titled panel per image on a 2x2 grid.
for i, _ in enumerate(imgs):
    plt.subplot(2, 2, i + 1)
    plt.title(title[i])
Exemplo n.º 8
0
OUTPUT_CHANNELS = 3
LAMBDA = 10


def _prepare(dataset, preprocess):
    """Map ``preprocess`` over ``dataset``, then cache, shuffle and batch it."""
    return (
        dataset
        .map(preprocess, num_parallel_calls=AUTOTUNE)
        .cache()
        .shuffle(BUFFER_SIZE)
        .batch(BATCH_SIZE)
    )


train_horses = _prepare(train_horses, preprocess_image_train)
train_zebras = _prepare(train_zebras, preprocess_image_train)

test_horses = _prepare(test_horses, preprocess_image_test)
test_zebras = _prepare(test_zebras, preprocess_image_test)

# One batch from each training set, used for the forward passes below.
sample_horse = next(iter(train_horses))
sample_zebra = next(iter(train_zebras))

y_generator = pix2pix.unet_generator(OUTPUT_CHANNELS, norm_type='instancenorm')
x_generator = pix2pix.unet_generator(OUTPUT_CHANNELS, norm_type='instancenorm')

y_discriminator = pix2pix.discriminator(norm_type='instancenorm', target=False)
x_discriminator = pix2pix.discriminator(norm_type='instancenorm', target=False)

# Forward pass: generators on the samples, discriminators on their outputs.
generated_y = y_generator(sample_horse)
generated_x = x_generator(sample_zebra)

prediction_y = y_discriminator(generated_y)
prediction_x = x_discriminator(generated_x)

# One Adam optimizer per network, identical hyper-parameters.
y_generator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
x_generator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

y_discriminator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
x_discriminator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

EPOCHS = 40
Exemplo n.º 9
0
 def __init__(self):
     """Create one generator, one discriminator and an Adam optimizer for each."""
     self.generator = pix2pix.unet_generator(
         OUTPUT_CHANNELS, norm_type=NORM_TYPE)
     self.discriminator = pix2pix.discriminator(
         norm_type=NORM_TYPE, target=False)

     # Optimizers to perform stochastic gradient descent; both use the same
     # step size and first-moment decay.
     step_size, momentum = 2e-4, 0.5
     self.generator_optimizer = tf.keras.optimizers.Adam(
         step_size, beta_1=momentum)
     self.discriminator_optimizer = tf.keras.optimizers.Adam(
         step_size, beta_1=momentum)
Exemplo n.º 10
0
  def preprocess_image_train(image, label):
  image = random_jitter(image)
  image = normalize(image)
  return image

  def preprocess_image_test(image, label):
  image = normalize(image)
  return image

  train_horses = train_horses.map(
    preprocess_image_train, num_parallel_calls=AUTOTUNE).cache().shuffle(
    BUFFER_SIZE).batch(1)

# Each pipeline: preprocess, cache, shuffle, then batch one image at a time.
train_zebras = (
    train_zebras
    .map(preprocess_image_train, num_parallel_calls=AUTOTUNE)
    .cache()
    .shuffle(BUFFER_SIZE)
    .batch(1)
)

test_horses = (
    test_horses
    .map(preprocess_image_test, num_parallel_calls=AUTOTUNE)
    .cache()
    .shuffle(BUFFER_SIZE)
    .batch(1)
)

test_zebras = (
    test_zebras
    .map(preprocess_image_test, num_parallel_calls=AUTOTUNE)
    .cache()
    .shuffle(BUFFER_SIZE)
    .batch(1)
)

# One batch from each training set, reused for previews.
sample_horse = next(iter(train_horses))
sample_zebra = next(iter(train_zebras))

# Horse sample (left) next to a randomly jittered copy of it (right).
plt.subplot(121)
plt.title('Horse')
plt.imshow(sample_horse[0] * 0.5 + 0.5)

plt.subplot(122)
plt.title('Horse with random jitter')
plt.imshow(random_jitter(sample_horse[0]) * 0.5 + 0.5)
# Show the horse figure before the 121/122 slots are used again; without
# this the zebra panels are drawn into the same figure as the horse panels.
plt.show()

# Zebra sample (left) next to a randomly jittered copy of it (right).
plt.subplot(121)
plt.title('Zebra')
plt.imshow(sample_zebra[0] * 0.5 + 0.5)

plt.subplot(122)
plt.title('Zebra with random jitter')
plt.imshow(random_jitter(sample_zebra[0]) * 0.5 + 0.5)
plt.show()

OUTPUT_CHANNELS = 3

generator_g = pix2pix.unet_generator(
    OUTPUT_CHANNELS, norm_type='instancenorm')
generator_f = pix2pix.unet_generator(
    OUTPUT_CHANNELS, norm_type='instancenorm')

discriminator_x = pix2pix.discriminator(
    norm_type='instancenorm', target=False)
discriminator_y = pix2pix.discriminator(
    norm_type='instancenorm', target=False)

# Run the freshly built generators on the two sample batches.
to_zebra = generator_g(sample_horse)
to_horse = generator_f(sample_zebra)
plt.figure(figsize=(8, 8))
contrast = 8

imgs = [sample_horse, to_zebra, sample_zebra, to_horse]
title = ['Horse', 'To Zebra', 'Zebra', 'To Horse']

for i, img in enumerate(imgs):
  plt.subplot(2, 2, i + 1)
  plt.title(title[i])
  picture = img[0] * 0.5
  if i % 2:
    # Odd positions hold generator outputs; their contrast is boosted.
    picture = picture * contrast
  plt.imshow(picture + 0.5)
plt.show()

plt.figure(figsize=(8, 8))

# Last channel of each discriminator's output for the first batch element.
plt.subplot(1, 2, 1)
plt.title('Is a real zebra?')
plt.imshow(discriminator_y(sample_zebra)[0, ..., -1], cmap='RdBu_r')

plt.subplot(1, 2, 2)
plt.title('Is a real horse?')
plt.imshow(discriminator_x(sample_horse)[0, ..., -1], cmap='RdBu_r')

plt.show()

# Weight multiplied into the cycle-consistency loss and (halved) into the
# identity loss defined below.
LAMBDA = 10

# Binary cross-entropy computed on raw logits (from_logits=True); shared by
# the discriminator and generator adversarial losses.
loss_obj = tf.keras.losses.BinaryCrossentropy(from_logits=True)

def discriminator_loss(real, generated):
  """Return half the sum of the real and generated cross-entropy terms.

  ``real`` is scored against an all-ones target and ``generated`` against an
  all-zeros target, both through the module-level ``loss_obj``.
  """
  loss_on_real = loss_obj(tf.ones_like(real), real)
  loss_on_generated = loss_obj(tf.zeros_like(generated), generated)
  return (loss_on_real + loss_on_generated) * 0.5

def generator_loss(generated):
  """Return ``loss_obj`` of ``generated`` against an all-ones target."""
  all_ones = tf.ones_like(generated)
  return loss_obj(all_ones, generated)

def calc_cycle_loss(real_image, cycled_image):
  """Return ``LAMBDA`` times the mean absolute difference of the two images."""
  difference = real_image - cycled_image
  return LAMBDA * tf.reduce_mean(tf.abs(difference))

def identity_loss(real_image, same_image):
  """Return ``LAMBDA * 0.5`` times the mean absolute difference of the images."""
  mean_abs_difference = tf.reduce_mean(tf.abs(real_image - same_image))
  return LAMBDA * 0.5 * mean_abs_difference

# One Adam optimizer per network, all with the same hyper-parameters.
generator_g_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
generator_f_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

discriminator_x_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
discriminator_y_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

checkpoint_path = "./checkpoints/train"

# Networks and optimizers are saved and restored together.
ckpt = tf.train.Checkpoint(
    generator_g=generator_g,
    generator_f=generator_f,
    discriminator_x=discriminator_x,
    discriminator_y=discriminator_y,
    generator_g_optimizer=generator_g_optimizer,
    generator_f_optimizer=generator_f_optimizer,
    discriminator_x_optimizer=discriminator_x_optimizer,
    discriminator_y_optimizer=discriminator_y_optimizer,
)

ckpt_manager = tf.train.CheckpointManager(
    ckpt, checkpoint_path, max_to_keep=5)

# if a checkpoint exists, restore the latest checkpoint.
if ckpt_manager.latest_checkpoint:
  ckpt.restore(ckpt_manager.latest_checkpoint)
  print('Latest checkpoint restored!!')

EPOCHS = 40
def generate_images(model, test_input):
  """Plot the first input image beside the model's prediction for it.

  ``model`` is called on ``test_input``; element 0 of the input and of the
  prediction are shown side by side with axes hidden, then the figure is
  displayed.
  """
  prediction = model(test_input)

  plt.figure(figsize=(12, 12))

  panels = (
      ('Input Image', test_input[0]),
      ('Predicted Image', prediction[0]),
  )
  for position, (caption, image) in enumerate(panels, start=1):
    plt.subplot(1, 2, position)
    plt.title(caption)
    # getting the pixel values between [0, 1] to plot it.
    plt.imshow(image * 0.5 + 0.5)
    plt.axis('off')
  plt.show()

@tf.function
def train_step(real_x, real_y):
  """Run one optimization step on both generators and both discriminators.

  Computes the forward passes (translation, cycle and identity) under a
  single persistent gradient tape, derives the four losses, and applies one
  optimizer update per network. Nothing is returned.

  Args:
    real_x: Batch from domain X, fed to ``generator_g`` and
      ``discriminator_x``.
    real_y: Batch from domain Y, fed to ``generator_f`` and
      ``discriminator_y``.
  """
  # persistent is set to True because the tape is used more than
  # once to calculate the gradients.
  with tf.GradientTape(persistent=True) as tape:
    # Generator G translates X -> Y
    # Generator F translates Y -> X.
    
    fake_y = generator_g(real_x, training=True)
    cycled_x = generator_f(fake_y, training=True)

    fake_x = generator_f(real_y, training=True)
    cycled_y = generator_g(fake_x, training=True)

    # same_x and same_y are used for identity loss.
    same_x = generator_f(real_x, training=True)
    same_y = generator_g(real_y, training=True)

    # Discriminator scores for real and generated images of each domain.
    disc_real_x = discriminator_x(real_x, training=True)
    disc_real_y = discriminator_y(real_y, training=True)

    disc_fake_x = discriminator_x(fake_x, training=True)
    disc_fake_y = discriminator_y(fake_y, training=True)

    # calculate the loss
    gen_g_loss = generator_loss(disc_fake_y)
    gen_f_loss = generator_loss(disc_fake_x)
    
    # Both cycle directions are summed and shared by the two generators.
    total_cycle_loss = calc_cycle_loss(real_x, cycled_x) + calc_cycle_loss(real_y, cycled_y)
    
    # Total generator loss = adversarial loss + cycle loss + identity loss
    total_gen_g_loss = gen_g_loss + total_cycle_loss + identity_loss(real_y, same_y)
    total_gen_f_loss = gen_f_loss + total_cycle_loss + identity_loss(real_x, same_x)

    disc_x_loss = discriminator_loss(disc_real_x, disc_fake_x)
    disc_y_loss = discriminator_loss(disc_real_y, disc_fake_y)
  
  # Calculate the gradients for generator and discriminator
  # (all four are computed before any optimizer update is applied).
  generator_g_gradients = tape.gradient(total_gen_g_loss, 
                                        generator_g.trainable_variables)
  generator_f_gradients = tape.gradient(total_gen_f_loss, 
                                        generator_f.trainable_variables)
  
  discriminator_x_gradients = tape.gradient(disc_x_loss, 
                                            discriminator_x.trainable_variables)
  discriminator_y_gradients = tape.gradient(disc_y_loss, 
                                            discriminator_y.trainable_variables)
  
  # Apply the gradients to the optimizer
  generator_g_optimizer.apply_gradients(zip(generator_g_gradients, 
                                            generator_g.trainable_variables))

  generator_f_optimizer.apply_gradients(zip(generator_f_gradients, 
                                            generator_f.trainable_variables))
  
  discriminator_x_optimizer.apply_gradients(zip(discriminator_x_gradients,
                                                discriminator_x.trainable_variables))
  
  discriminator_y_optimizer.apply_gradients(zip(discriminator_y_gradients,
                                                discriminator_y.trainable_variables))

for epoch in range(EPOCHS):
  start = time.time()

  n = 0
  paired_batches = tf.data.Dataset.zip((train_horses, train_zebras))
  for image_x, image_y in paired_batches:
    train_step(image_x, image_y)
    # Progress marker: one dot for every ten training steps.
    if n % 10 == 0:
      print('.', end='')
    n += 1

  clear_output(wait=True)
  # Using a consistent image (sample_horse) so that the progress of the model
  # is clearly visible.
  generate_images(generator_g, sample_horse)

  # A checkpoint is written after every fifth epoch.
  if (epoch + 1) % 5 == 0:
    ckpt_save_path = ckpt_manager.save()
    print(f'Saving checkpoint for epoch {epoch + 1} at {ckpt_save_path}')

  print(f'Time taken for epoch {epoch + 1} is {time.time() - start} sec\n')

# Run the trained model on the test dataset
for inp in test_horses.take(5):
  generate_images(generator_g, inp)
Exemplo n.º 11
0
# Horse sample (left) next to a randomly jittered copy of it (right).
plt.subplot(1, 2, 1)
plt.title('Horse')
plt.imshow(sample_horse[0] * 0.5 + 0.5)

plt.subplot(1, 2, 2)
plt.title('Horse with random jitter')
plt.imshow(random_jitter(sample_horse[0]) * 0.5 + 0.5)
plt.show()

OUTPUT_CHANNELS = 3

# Generators for each translation direction and one discriminator per domain.
generator_h_z = pix2pix.unet_generator(
    OUTPUT_CHANNELS, norm_type='instancenorm')
generator_z_h = pix2pix.unet_generator(
    OUTPUT_CHANNELS, norm_type='instancenorm')

discriminator_isHorse = pix2pix.discriminator(
    norm_type='instancenorm', target=False)
discriminator_isZebra = pix2pix.discriminator(
    norm_type='instancenorm', target=False)

to_zebra = generator_h_z(sample_horse)
to_horse = generator_z_h(sample_zebra)
plt.figure(figsize=(8, 8))
contrast = 8

imgs = [sample_horse, to_zebra, sample_zebra, to_horse]
title = ['Horse', 'To Zebra', 'Zebra', 'To Horse']

# Only the even-indexed panels (the real samples) get an image here; the
# odd-indexed panels receive a title only.
for i, img in enumerate(imgs):
    plt.subplot(2, 2, i + 1)
    plt.title(title[i])
    if i % 2 == 0:
      plt.imshow(img[0] * 0.5 + 0.5)