示例#1
0
def build_cyclegan_models(n_channels, norm_type):
    """Build the four networks of a CycleGAN from pix2pix building blocks.

    Args:
        n_channels: Number of output channels of each U-Net generator.
        norm_type: Normalization used by every network; must be
            ``'instancenorm'`` or ``'batchnorm'``.

    Returns:
        Tuple ``(generator_g, generator_f, discriminator_x, discriminator_y)``.

    Raises:
        ValueError: If ``norm_type`` is not one of the supported values.
    """
    # Explicit raise instead of ``assert``: asserts are stripped under ``python -O``.
    if norm_type not in ('instancenorm', 'batchnorm'):
        raise ValueError(
            f"norm_type must be 'instancenorm' or 'batchnorm', got {norm_type!r}")

    generator_g = pix2pix.unet_generator(n_channels, norm_type=norm_type)
    generator_f = pix2pix.unet_generator(n_channels, norm_type=norm_type)

    # target=False: discriminators are built without the paired-target input.
    discriminator_x = pix2pix.discriminator(norm_type=norm_type, target=False)
    discriminator_y = pix2pix.discriminator(norm_type=norm_type, target=False)

    return generator_g, generator_f, discriminator_x, discriminator_y
示例#2
0
    def __init__(self, checkpoint_path: str = None, restore_checkpoint: bool = True):
        """Build the CycleGAN networks, their optimizers and checkpointing.

        Args:
            checkpoint_path: Directory managed by a ``tf.train.CheckpointManager``
                that keeps at most 5 checkpoints. ``None`` selects ``".."``.
            restore_checkpoint: When true and the manager already has a
                checkpoint, restore the latest one into all networks and
                optimizers.
        """

        output_channels = 3
        logger.info("Creating Generators and Discriminators")
        # Two pix2pix U-Net generators and two discriminators, all instance-normed.
        self.generator_g = pix2pix.unet_generator(
            output_channels, norm_type="instancenorm"
        )
        self.generator_f = pix2pix.unet_generator(
            output_channels, norm_type="instancenorm"
        )

        self.discriminator_x = pix2pix.discriminator(
            norm_type="instancenorm", target=False
        )
        self.discriminator_y = pix2pix.discriminator(
            norm_type="instancenorm", target=False
        )

        logger.info("Setting up the optimizers")

        # One Adam optimizer per network (learning rate 2e-4, beta_1=0.5).
        self.generator_g_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
        self.generator_f_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

        self.discriminator_x_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
        self.discriminator_y_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

        # BCE on raw (logit) outputs.
        self.loss_obj = tf.keras.losses.BinaryCrossentropy(from_logits=True)
        # Presumably the cycle-consistency loss weight; its use is not visible here.
        self.LAMBDA = 10

        # NOTE(review): the ".." default writes checkpoints into the parent of
        # the working directory — confirm this is intended.
        self.checkpoint_path = ".." if checkpoint_path is None else checkpoint_path

        self.ckpt = tf.train.Checkpoint(
            generator_g=self.generator_g,
            generator_f=self.generator_f,
            discriminator_x=self.discriminator_x,
            discriminator_y=self.discriminator_y,
            generator_g_optimizer=self.generator_g_optimizer,
            generator_f_optimizer=self.generator_f_optimizer,
            discriminator_x_optimizer=self.discriminator_x_optimizer,
            discriminator_y_optimizer=self.discriminator_y_optimizer,
        )

        self.ckpt_manager = tf.train.CheckpointManager(
            self.ckpt, self.checkpoint_path, max_to_keep=5
        )

        self.restore_checkpoint = restore_checkpoint

        if self.restore_checkpoint:
            # if a checkpoint exists, restore the latest checkpoint.
            latest = self.ckpt_manager.latest_checkpoint
            if latest:
                self.ckpt.restore(latest)
                # Reported through the module logger like every other status message.
                logger.info("Latest checkpoint restored!!")
示例#3
0
    def __init__(self):
        """Create two discriminators, two generators and one Adam optimizer per network."""
        # Discriminators: decide whether an input is a real example of its class.
        self.Dx, self.Dy = (
            pix2pix.discriminator(norm_type=NORM_TYPE, target=False)
            for _ in range(2)
        )
        # Generators: convert an image from the opposite class into this one.
        self.Gx, self.Gy = (
            pix2pix.unet_generator(OUTPUT_CHANNELS, norm_type=NORM_TYPE)
            for _ in range(2)
        )
        # Optimizers to perform stochastic gradient descent, created in the
        # order Dx, Dy, Gx, Gy.
        (self.Dx_optimizer, self.Dy_optimizer,
         self.Gx_optimizer, self.Gy_optimizer) = (
            tf.keras.optimizers.Adam(2e-4, beta_1=0.5) for _ in range(4)
        )
示例#4
0
def get_models_from_input_shape(input_shape,
                                norm_type,
                                output_init=0.02,
                                residual_output=False):
    """Return a ``(generator, discriminator)`` pair suited to ``input_shape``.

    * ``(28, 28, 1)`` (MNIST-like): MNIST-specific generator and discriminator.
    * ``(256, 256, 3)``: the stock ``pix2pix`` U-Net generator and discriminator;
      only the default ``output_init`` and no residual output are supported.
    * anything else: the local ``unet_generator`` (3 output channels) built for
      ``input_shape``, paired with the ``pix2pix`` discriminator.

    Args:
        input_shape: Image shape as ``(height, width, channels)``; a list is
            treated the same as the equivalent tuple.
        norm_type: Normalization type forwarded to every model builder.
        output_init: Forwarded to the local ``unet_generator`` (fallback path).
        residual_output: Forwarded to the local ``unet_generator`` (fallback path).

    Raises:
        NotImplementedError: For ``(256, 256, 3)`` with a non-default
            ``output_init`` or a truthy ``residual_output``.
    """
    # A list never compares equal to a tuple, so normalize lists first; other
    # shape types are compared as given.
    shape = tuple(input_shape) if isinstance(input_shape, list) else input_shape

    if shape == (28, 28, 1):
        # MNIST-like data
        return (mnist_unet_generator(norm_type=norm_type),
                mnist_discriminator(norm_type=norm_type, target=False))
    if shape == (256, 256, 3):
        # TODO: just use our unet_generator fn
        # pix2pix.unet_generator has no output_init / residual_output options;
        # refuse non-default values rather than silently ignoring them.
        if residual_output or output_init != 0.02:
            raise NotImplementedError(
                "pix2pix.unet_generator does not support a custom output_init "
                "or residual_output for input_shape (256, 256, 3)")
        return (pix2pix.unet_generator(output_channels=3, norm_type=norm_type),
                pix2pix.discriminator(norm_type=norm_type, target=False))
    return (unet_generator(output_channels=3, input_shape=input_shape, norm_type=norm_type,
                           output_init=output_init, residual_output=residual_output),
            pix2pix.discriminator(norm_type=norm_type, target=False))
示例#5
0
def predict(img_path):
    """Translate the image at ``img_path`` with the Monet generator and save it.

    Loads generator weights from ``monet_gen.h5``, maps pixel values from
    ``[0, 255]`` to ``[-1, 1]``, resizes to 256x256, runs the generator and
    writes the result (mapped back to ``[0, 255]``, 8-bit) to
    ``test_results/<name>_pred<ext>``.

    Args:
        img_path: Path of the input image, readable by ``cv2.imread``.

    Returns:
        Path of the written prediction, ``test_results/<name>_pred<ext>``.

    Raises:
        FileNotFoundError: If ``cv2.imread`` cannot read ``img_path``.
        OSError: If ``cv2.imwrite`` reports a failed write.
    """
    monet_gen = pix2pix.unet_generator(3, norm_type="instancenorm")
    monet_gen.load_weights("monet_gen.h5")

    # cv2.imread signals a missing/unreadable file by returning None, not raising.
    img = cv2.imread(img_path)
    if img is None:
        raise FileNotFoundError(f"Could not read image: {img_path}")
    # NOTE(review): cv2 loads pixels in BGR order; confirm the generator was
    # trained on BGR input, otherwise convert with cv2.COLOR_BGR2RGB first.

    def decode_img(img):
        # Scale [0, 255] -> [-1, 1], then resize to 256x256.
        img = tf.cast(img, tf.float32)
        img = img / 127.5 - 1
        img = tf.image.resize(img, [256, 256])
        return img

    img = decode_img(img)
    pred = monet_gen.predict(tf.expand_dims(img, axis=0))
    os.makedirs("test_results", exist_ok=True)

    # splitext/basename keep the real extension for names with extra dots
    # ("a.b.jpg" -> "a.b_pred.jpg") and do not fail when there is no dot.
    stem, ext = os.path.splitext(os.path.basename(img_path))
    pred_name = stem + "_pred" + ext
    out_path = f"test_results/{pred_name}"

    # Map [-1, 1] back to [0, 255] and store as 8-bit; clipping and rounding
    # mirror OpenCV's own saturating conversion for 8-bit formats.
    out_img = (pred[0] * 127.5 + 127.5).clip(0, 255).round().astype("uint8")
    if not cv2.imwrite(out_path, out_img):  #Saving the image.
        raise OSError(f"Could not write prediction to {out_path}")
    return out_path
示例#6
0
plt.subplot(122)
plt.title('Horse with random jitter')
plt.imshow(random_jitter(sample_horse[0]) * 0.5 + 0.5)
# Render the horse pair now: the zebra pair below uses the same subplot
# positions (121/122) and would otherwise be drawn onto these axes.
plt.show()

plt.subplot(121)
plt.title('Zebra')
plt.imshow(sample_zebra[0] * 0.5 + 0.5)

plt.subplot(122)
plt.title('Zebra with random jitter')
plt.imshow(random_jitter(sample_zebra[0]) * 0.5 + 0.5)
plt.show()

# Reuse the pix2pix model for the conversion.
OUTPUT_CHANNELS = 3

# generator_g: horse -> zebra, generator_f: zebra -> horse (see usage below).
generator_g = pix2pix.unet_generator(OUTPUT_CHANNELS, norm_type='instancenorm')
generator_f = pix2pix.unet_generator(OUTPUT_CHANNELS, norm_type='instancenorm')

discriminator_x = pix2pix.discriminator(norm_type='instancenorm', target=False)
discriminator_y = pix2pix.discriminator(norm_type='instancenorm', target=False)

# Plotting the generated images

to_zebra = generator_g(sample_horse)
to_horse = generator_f(sample_zebra)
plt.figure(figsize=(8, 8))
contrast = 8

imgs = [sample_horse, to_zebra, sample_zebra, to_horse]
title = ['Horse', 'To Zebra', 'Zebra', 'To Horse']
            seed=27).take(BUFFER_SIZE).batch(BATCH_SIZE).prefetch(AUTOTUNE)
    # Zebra training pipeline: parallel preprocessing, cache, a seeded shuffle
    # redone every epoch, at most BUFFER_SIZE images, then batch and prefetch.
    train_zebras = train_zebras.map(
        preprocess_image_train, num_parallel_calls=AUTOTUNE).cache().shuffle(
            buffer_size=TRAIN_ZEBRAS_BUFFER,
            reshuffle_each_iteration=True,
            seed=27).take(BUFFER_SIZE).batch(BATCH_SIZE).prefetch(AUTOTUNE)

    # Wrap both datasets for the distribution strategy `strategy` (defined
    # outside this view).
    train_horses = strategy.experimental_distribute_dataset(train_horses)
    train_zebras = strategy.experimental_distribute_dataset(train_zebras)

    # Take the first test batch of each domain as a fixed sample: the loop
    # variable stays bound after `break` (it stays unbound if the dataset is empty).
    for sample_horse in test_horses:
        break
    for sample_zebra in test_zebras:
        break

    # Two pix2pix U-Net generators and two pix2pix discriminators, all with
    # instance normalization. target=False presumably means the discriminator
    # sees only the image, without a paired target — TODO confirm.
    generator_g = pix2pix.unet_generator(output_channels=OUTPUT_CHANNELS,
                                         norm_type='instancenorm')
    generator_f = pix2pix.unet_generator(output_channels=OUTPUT_CHANNELS,
                                         norm_type='instancenorm')

    discriminator_x = pix2pix.discriminator(norm_type='instancenorm',
                                            target=False)
    discriminator_y = pix2pix.discriminator(norm_type='instancenorm',
                                            target=False)

    # One Adam optimizer per network (learning rate 2e-4, beta_1=0.5).
    generator_g_opt = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
    generator_f_opt = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
    discriminator_x_opt = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
    discriminator_y_opt = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

# Binary cross-entropy on raw logits, left unreduced (Reduction.NONE) so the
# caller decides how to average it — presumably over the global batch under
# the distribution strategy; confirm at the call site.
loss_obj = tf.keras.losses.BinaryCrossentropy(
    reduction=tf.keras.losses.Reduction.NONE,
    from_logits=True,
)
示例#8
0
if uploaded_file:

    input_file = Image.open(uploaded_file).convert("RGB")
    input_file = np.asarray(input_file)

    st.write("### Uploaded Image:")
    st.image(input_file)
    st.write("### Choose a Model:")
    # Keep the user's choice in its own name. `model` is rebound to a Keras
    # model below, so comparing `model` to "CycleGAN" afterwards was always
    # False and the CycleGAN input rescaling never ran. Later checks of the
    # chosen model should use `model_name`.
    model_name = st.selectbox("", ("Vanilla", "CycleGAN"))

    if model_name == "CycleGAN":
        from tensorflow_examples.models.pix2pix import pix2pix

        cg_url = "https://docclean.s3.us-east-2.amazonaws.com/cg_weights.tar.gz"
        model = pix2pix.unet_generator(3, norm_type="instancenorm")
        # NOTE(review): the load path assumes the default ~/.keras cache dir
        # and that the archive unpacks to weights/cg — confirm.
        tf.keras.utils.get_file("cg_weights", cg_url, untar=True)
        model.load_weights(os.path.expanduser("~/.keras/datasets/weights/cg"))
    else:
        ae_url = "https://docclean.s3.us-east-2.amazonaws.com/ae_weights.tar.gz"

        model = docclean.autoencoder.Autoencoder().autoencoder_model
        tf.keras.utils.get_file("ae_weights", ae_url, untar=True)
        model.load_weights(os.path.expanduser("~/.keras/datasets/weights/ae"))

    # Presumably splits the uploaded image into model-sized patches.
    im = ImageMosaic(input_file)
    batches = im.make_patches()

    if model_name == "CycleGAN":
        # 2x - 1 maps [0, 1] to [-1, 1]; presumably the patches are in [0, 1]
        # and the generator expects [-1, 1] — confirm against ImageMosaic.
        batches = 2 * batches - 1
示例#9
0
 def __init__(self):
     """Build one generator/discriminator pair and an Adam optimizer for each."""
     # Both networks use the module-level OUTPUT_CHANNELS / NORM_TYPE settings.
     self.generator = pix2pix.unet_generator(OUTPUT_CHANNELS, norm_type=NORM_TYPE)
     self.discriminator = pix2pix.discriminator(norm_type=NORM_TYPE, target=False)

     # Optimizers to perform stochastic gradient descent; shared hyperparameters.
     learning_rate, beta_1 = 2e-4, 0.5
     self.generator_optimizer = tf.keras.optimizers.Adam(learning_rate, beta_1=beta_1)
     self.discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate, beta_1=beta_1)
示例#10
0
  # Training-set preprocessing: random jitter, then normalization.
def preprocess_image_train(image, label):
  """Augment and normalize one training image.

  Args:
    image: Image tensor from the dataset.
    label: Second element of each dataset item; accepted so the function can
      be passed directly to `Dataset.map`, but unused.

  Returns:
    `normalize(random_jitter(image))`.
  """
  image = random_jitter(image)
  image = normalize(image)
  return image

  # Test-set preprocessing: normalization only, no augmentation.
def preprocess_image_test(image, label):
  """Normalize one test image.

  Args:
    image: Image tensor from the dataset.
    label: Second element of each dataset item; accepted so the function can
      be passed directly to `Dataset.map`, but unused.

  Returns:
    `normalize(image)`.
  """
  image = normalize(image)
  return image

  # Input pipelines: parallel preprocessing, cache, shuffle, one image per batch.
train_horses = train_horses.map(
    preprocess_image_train, num_parallel_calls=AUTOTUNE).cache().shuffle(
    BUFFER_SIZE).batch(1)

train_zebras = train_zebras.map(
    preprocess_image_train, num_parallel_calls=AUTOTUNE).cache().shuffle(
    BUFFER_SIZE).batch(1)

test_horses = test_horses.map(
    preprocess_image_test, num_parallel_calls=AUTOTUNE).cache().shuffle(
    BUFFER_SIZE).batch(1)

test_zebras = test_zebras.map(
    preprocess_image_test, num_parallel_calls=AUTOTUNE).cache().shuffle(
    BUFFER_SIZE).batch(1)

# First training batch of each domain, used as fixed samples for visualization.
sample_horse = next(iter(train_horses))
sample_zebra = next(iter(train_zebras))

# `x * 0.5 + 0.5` maps pixel values from [-1, 1] to [0, 1] for imshow.
plt.subplot(121)
plt.title('Horse')
plt.imshow(sample_horse[0] * 0.5 + 0.5)

plt.subplot(122)
plt.title('Horse with random jitter')
plt.imshow(random_jitter(sample_horse[0]) * 0.5 + 0.5)
# Render the horse pair now: the zebra pair below uses the same subplot
# positions (121/122) and would otherwise be drawn onto these axes.
plt.show()

plt.subplot(121)
plt.title('Zebra')
plt.imshow(sample_zebra[0] * 0.5 + 0.5)

plt.subplot(122)
plt.title('Zebra with random jitter')
plt.imshow(random_jitter(sample_zebra[0]) * 0.5 + 0.5)
plt.show()

OUTPUT_CHANNELS = 3

# Generators: generator_g turns horses into zebras, generator_f the reverse.
generator_g, generator_f = (
    pix2pix.unet_generator(OUTPUT_CHANNELS, norm_type='instancenorm')
    for _ in range(2))

# Discriminators: discriminator_x judges horses, discriminator_y judges zebras.
discriminator_x, discriminator_y = (
    pix2pix.discriminator(norm_type='instancenorm', target=False)
    for _ in range(2))

# Outputs of the freshly built (untrained) generators.
to_zebra = generator_g(sample_horse)
to_horse = generator_f(sample_zebra)
plt.figure(figsize=(8, 8))
contrast = 8

imgs = [sample_horse, to_zebra, sample_zebra, to_horse]
title = ['Horse', 'To Zebra', 'Zebra', 'To Horse']

for i, (img, caption) in enumerate(zip(imgs, title)):
  plt.subplot(2, 2, i + 1)
  plt.title(caption)
  # Odd slots hold generated images; stretch their contrast so they are visible.
  shown = img[0] * 0.5 * contrast if i % 2 else img[0] * 0.5
  plt.imshow(shown + 0.5)
plt.show()

plt.figure(figsize=(8, 8))

# Last channel of each discriminator's output for a real sample of its domain.
for pos, (caption, disc, sample) in enumerate(
    [('Is a real zebra?', discriminator_y, sample_zebra),
     ('Is a real horse?', discriminator_x, sample_horse)], start=121):
  plt.subplot(pos)
  plt.title(caption)
  plt.imshow(disc(sample)[0, ..., -1], cmap='RdBu_r')

plt.show()

# Weight of the cycle-consistency loss; the identity loss uses half of it.
LAMBDA = 10

# Binary cross-entropy applied to raw (logit) discriminator outputs.
loss_obj = tf.keras.losses.BinaryCrossentropy(
    from_logits=True,
)

def discriminator_loss(real, generated):
  """Return the discriminator loss for one batch.

  Outputs on real images are scored against label 1, outputs on generated
  images against label 0, and the sum of both terms is halved.
  """
  loss_on_real = loss_obj(tf.ones_like(real), real)
  loss_on_fake = loss_obj(tf.zeros_like(generated), generated)
  return 0.5 * (loss_on_real + loss_on_fake)

def generator_loss(generated):
  """Return the adversarial generator loss: the discriminator's output on
  generated images scored against the "real" label (all ones)."""
  real_labels = tf.ones_like(generated)
  return loss_obj(real_labels, generated)

def calc_cycle_loss(real_image, cycled_image):
  """Return LAMBDA times the mean absolute error between an image and its
  reconstruction after a round trip through both generators."""
  mean_abs_error = tf.reduce_mean(tf.abs(real_image - cycled_image))
  return LAMBDA * mean_abs_error

def identity_loss(real_image, same_image):
  """Return half of LAMBDA times the mean absolute error between an image and
  what the generator targeting its own domain makes of it."""
  return LAMBDA * 0.5 * tf.reduce_mean(tf.abs(real_image - same_image))

# One Adam optimizer per network (learning rate 2e-4, beta_1=0.5), created in
# the order G, F, X, Y.
(generator_g_optimizer, generator_f_optimizer,
 discriminator_x_optimizer, discriminator_y_optimizer) = (
    tf.keras.optimizers.Adam(2e-4, beta_1=0.5) for _ in range(4))

checkpoint_path = "./checkpoints/train"

# Everything needed to resume training: the four networks and their optimizers.
tracked_objects = dict(
    generator_g=generator_g,
    generator_f=generator_f,
    discriminator_x=discriminator_x,
    discriminator_y=discriminator_y,
    generator_g_optimizer=generator_g_optimizer,
    generator_f_optimizer=generator_f_optimizer,
    discriminator_x_optimizer=discriminator_x_optimizer,
    discriminator_y_optimizer=discriminator_y_optimizer,
)
ckpt = tf.train.Checkpoint(**tracked_objects)

# Keep at most the five most recent checkpoints on disk.
ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=5)

# Resume from the newest checkpoint when one exists.
latest_ckpt = ckpt_manager.latest_checkpoint
if latest_ckpt:
  ckpt.restore(latest_ckpt)
  print('Latest checkpoint restored!!')

EPOCHS = 40


def generate_images(model, test_input):
  """Plot the first image of ``test_input`` next to ``model``'s output for it.

  Pixel values are shown as ``x * 0.5 + 0.5``, i.e. mapped from [-1, 1] to
  [0, 1] for imshow.
  """
  prediction = model(test_input)
  panels = (('Input Image', test_input[0]), ('Predicted Image', prediction[0]))

  plt.figure(figsize=(12, 12))
  for slot, (caption, image) in enumerate(panels, start=1):
    plt.subplot(1, 2, slot)
    plt.title(caption)
    plt.imshow(image * 0.5 + 0.5)
    plt.axis('off')
  plt.show()

@tf.function
def train_step(real_x, real_y):
  """Run one CycleGAN update on one batch from each domain.

  The training loop calls this with a horse batch as `real_x` and a zebra
  batch as `real_y`. Generator G maps X -> Y and F maps Y -> X. All losses
  come from a single forward pass; then each of the four networks is updated
  once, using the gradient of its own loss w.r.t. its own variables only.
  Returns nothing; the effect is the four optimizer updates.
  """
  # persistent is set to True because the tape is used more than
  # once to calculate the gradients.
  with tf.GradientTape(persistent=True) as tape:
    # Generator G translates X -> Y
    # Generator F translates Y -> X.
    
    # Round trip X -> Y -> X.
    fake_y = generator_g(real_x, training=True)
    cycled_x = generator_f(fake_y, training=True)

    # Round trip Y -> X -> Y.
    fake_x = generator_f(real_y, training=True)
    cycled_y = generator_g(fake_x, training=True)

    # same_x and same_y are used for identity loss: each generator is fed an
    # image that is already in its output domain.
    same_x = generator_f(real_x, training=True)
    same_y = generator_g(real_y, training=True)

    disc_real_x = discriminator_x(real_x, training=True)
    disc_real_y = discriminator_y(real_y, training=True)

    disc_fake_x = discriminator_x(fake_x, training=True)
    disc_fake_y = discriminator_y(fake_y, training=True)

    # calculate the loss
    # Adversarial terms: each generator is judged by the discriminator of its
    # output domain.
    gen_g_loss = generator_loss(disc_fake_y)
    gen_f_loss = generator_loss(disc_fake_x)
    
    # Both generators share the cycle loss of both round trips.
    total_cycle_loss = calc_cycle_loss(real_x, cycled_x) + calc_cycle_loss(real_y, cycled_y)
    
    # Total generator loss = adversarial loss + cycle loss + identity loss
    total_gen_g_loss = gen_g_loss + total_cycle_loss + identity_loss(real_y, same_y)
    total_gen_f_loss = gen_f_loss + total_cycle_loss + identity_loss(real_x, same_x)

    disc_x_loss = discriminator_loss(disc_real_x, disc_fake_x)
    disc_y_loss = discriminator_loss(disc_real_y, disc_fake_y)
  
  # Calculate the gradients for generator and discriminator.
  # Each loss is differentiated only w.r.t. its own network's variables, so
  # e.g. the discriminator losses never update the generators.
  generator_g_gradients = tape.gradient(total_gen_g_loss, 
                                        generator_g.trainable_variables)
  generator_f_gradients = tape.gradient(total_gen_f_loss, 
                                        generator_f.trainable_variables)
  
  discriminator_x_gradients = tape.gradient(disc_x_loss, 
                                            discriminator_x.trainable_variables)
  discriminator_y_gradients = tape.gradient(disc_y_loss, 
                                            discriminator_y.trainable_variables)
  
  # Apply each network's gradients with that network's own optimizer.
  generator_g_optimizer.apply_gradients(zip(generator_g_gradients, 
                                            generator_g.trainable_variables))

  generator_f_optimizer.apply_gradients(zip(generator_f_gradients, 
                                            generator_f.trainable_variables))
  
  discriminator_x_optimizer.apply_gradients(zip(discriminator_x_gradients,
                                                discriminator_x.trainable_variables))
  
  discriminator_y_optimizer.apply_gradients(zip(discriminator_y_gradients,
                                                discriminator_y.trainable_variables))

for epoch in range(EPOCHS):
  start = time.time()

  # One pass over paired (horse, zebra) batches (Dataset.zip stops at the
  # shorter dataset); print a progress dot every 10 steps.
  for n, (image_x, image_y) in enumerate(
      tf.data.Dataset.zip((train_horses, train_zebras))):
    train_step(image_x, image_y)
    if n % 10 == 0:
      print('.', end='')

  clear_output(wait=True)
  # Using a consistent image (sample_horse) so that the progress of the model
  # is clearly visible.
  generate_images(generator_g, sample_horse)

  # Checkpoint every 5 epochs and always after the final epoch, so the last
  # epochs are not lost when EPOCHS is not a multiple of 5.
  if (epoch + 1) % 5 == 0 or epoch + 1 == EPOCHS:
    ckpt_save_path = ckpt_manager.save()
    print('Saving checkpoint for epoch {} at {}'.format(epoch + 1,
                                                         ckpt_save_path))

  print('Time taken for epoch {} is {} sec\n'.format(epoch + 1,
                                                     time.time() - start))

# Run the trained model on the test dataset
for inp in test_horses.take(5):
  generate_images(generator_g, inp)
示例#11
0
# First training batch of each domain, kept as fixed preview samples.
sample_horse = next(iter(train_horses))
sample_zebra = next(iter(train_zebras))

# A horse next to a randomly jittered copy; `x * 0.5 + 0.5` maps [-1, 1]
# to [0, 1] for imshow.
horse_img = sample_horse[0]
plt.subplot(121)
plt.title('Horse')
plt.imshow(horse_img * 0.5 + 0.5)

plt.subplot(122)
plt.title('Horse with random jitter')
plt.imshow(random_jitter(horse_img) * 0.5 + 0.5)
plt.show()

OUTPUT_CHANNELS = 3

# Generators: horse -> zebra, then zebra -> horse.
generator_h_z, generator_z_h = (
    pix2pix.unet_generator(OUTPUT_CHANNELS, norm_type='instancenorm')
    for _ in range(2))

# Discriminators: one per domain (horse, then zebra).
discriminator_isHorse, discriminator_isZebra = (
    pix2pix.discriminator(norm_type='instancenorm', target=False)
    for _ in range(2))

# Outputs of the freshly built (untrained) generators.
to_zebra = generator_h_z(sample_horse)
to_horse = generator_z_h(sample_zebra)
plt.figure(figsize=(8, 8))
contrast = 8

imgs = [sample_horse, to_zebra, sample_zebra, to_horse]
title = ['Horse', 'To Zebra', 'Zebra', 'To Horse']

for i in range(len(imgs)):
    plt.subplot(2, 2, i+1)