Example #1
0
def process_path(file_path, patch_size):
    """Turn one image file into a `(patch, label)` training pair.

    Args:
      file_path: path of the image file; also handed to `get_label`.
      patch_size: patch size forwarded to `cyclegan_dp.full_image_to_patch`.

    Returns:
      A `(patch, label)` tuple.
    """
    label = get_label(file_path)
    raw_bytes = tf.io.read_file(file_path)  # whole file contents as a string tensor
    decoded = decode_img(raw_bytes)
    patch = cyclegan_dp.full_image_to_patch(decoded, patch_size)
    return patch, label
Example #2
0
def provide_cyclegan_test_set(tfds_name, patch_size, num_images=6):
    """Provide a small, fixed batch of test images from both CycleGAN domains.

    The first ``num_images - num_images // 2`` images come from the ``testA``
    split; the remaining ``num_images // 2`` come from ``testB``.

    Args:
      tfds_name: string, tfds name.
      patch_size: Python int. The patch size to extract.
      num_images: Python int. Total number of images to return; must be
        positive.

    Returns:
      A float32 `np.array` of shape (num_images, patch_size, patch_size, 3)
      holding the images. Values are in [-1, 1].

    Raises:
      ValueError: if `num_images` is not positive.
    """
    if num_images < 1:
        # `Dataset.take(-1)` means "take everything", so a negative count would
        # silently load a whole split; 0 would crash later in `np.max`.
        raise ValueError(
            'num_images must be positive, got {}'.format(num_images))

    ds = tfds.load(tfds_name)

    num_images_B = num_images // 2
    num_images_A = num_images - num_images_B

    examples_A = list(tfds.as_numpy(ds['testA'].take(num_images_A)))
    examples_B = list(tfds.as_numpy(ds['testB'].take(num_images_B)))

    images = np.array(
        [tfds.as_numpy(cyclegan_dp.full_image_to_patch(x['image'], patch_size))
         for x in examples_A + examples_B],
        dtype=np.float32)

    # Sanity checks on the produced batch. A shape mismatch typically means a
    # split held fewer examples than requested (`take` then returns fewer).
    assert images.dtype == np.float32
    assert np.max(np.abs(images)) <= 1.0
    assert images.shape == (num_images, patch_size, patch_size, 3)

    return images
Example #3
0
 def _preprocess(x):
     """Replace `x['image']` with a patch and pass `x['label']` through.

     `patch_size` and `num_channels` are free variables, presumably bound by
     the enclosing function (not visible here).
     """
     image_patch = cyclegan_dp.full_image_to_patch(
         x['image'], patch_size, num_channels)
     return {'image': image_patch, 'label': x['label']}
Example #4
0
 def _preprocess(*elements):
     """Map elements to the example dicts expected by the model.

     `elements[i]` is paired with `domains[i]`; the result maps each domain to
     its image patch and a one-hot label of index `i` over `len(elements)`.
     """
     depth = len(elements)
     examples = {}
     for index, (domain_name, element) in enumerate(zip(domains, elements)):
         examples[domain_name] = {
             'images': data_provider.full_image_to_patch(
                 element['image'], patch_size),
             'labels': tf.one_hot(index, depth),
         }
     return examples
Example #5
0
 def _preprocess(*elements):
     """Map elements to the example dicts expected by the model.

     The result is keyed by each element's position; the label is the one-hot
     encoding of that position over `num_classes` (a free variable, presumably
     from the enclosing function).
     """
     return {
         position: {
             'images': data_provider.full_image_to_patch(
                 element['image'], patch_size, num_channels),
             'labels': tf.one_hot(position, num_classes),
         }
         for position, element in enumerate(elements)
     }
Example #6
0
def make_inference_graph(model_name, patch_dim):
  """Build the inference graph for either the X2Y or Y2X GAN.

  Args:
    model_name: Variable scope name, 'ModelX2Y' or 'ModelY2X'.
    patch_dim: Integer patch size fed to the generator.

  Returns:
    A `(input_placeholder, generated_tensor)` tuple. The placeholder accepts a
    single float32 HWC image with 3 channels.
  """
  input_hwc_pl = tf.placeholder(tf.float32, [None, None, 3])
  patch = data_provider.full_image_to_patch(input_hwc_pl, patch_dim)
  images_x = tf.expand_dims(patch, 0)  # HWC -> NHWC with a batch of one

  with tf.variable_scope(model_name), tf.variable_scope('Generator'):
    generated = networks.generator(images_x)
  return input_hwc_pl, generated
Example #7
0
 def _preprocess(x):
     """Swap `x['image']` for a patch, keeping `x['attributes']` unchanged."""
     patch = cyclegan_dp.full_image_to_patch(x['image'], patch_size)
     attributes = x['attributes']
     return dict(image=patch, attributes=attributes)
Example #8
0
def _save_generated_images(images, source_paths, output_path):
    """Write each image as ``generated_from_<source file name>`` in `output_path`.

    Args:
      images: sequence of float image arrays, paired positionally with
        `source_paths`. Assumed to hold values in [0, 1] (TODO confirm against
        `translate_images`).
      source_paths: paths of the images the generated ones were made from.
      output_path: existing directory the files are written to.
    """
    for image, source_path in zip(images, source_paths):
        file_name = "generated_from_" + os.path.basename(source_path)
        # Clip before the uint8 cast so values slightly outside [0, 1]
        # saturate instead of wrapping around.
        pixels = (255 * np.clip(image, 0.0, 1.0)).astype(np.uint8)
        PIL.Image.fromarray(pixels).save(os.path.join(output_path, file_name))


def gen_transformed_images(checkpoint_struct, output, num_examples, data_dir,
                           patch_size=128):
    """Translate test images with every checkpoint and save the results.

    REF: make_summary_images at evaluate_multiple.py

    Images from ``<data_dir>/testA/*.jpg`` are translated with target index 1
    and images from ``<data_dir>/testB/*.jpg`` with target index 0 by every
    checkpoint in ``checkpoint_struct['all']``. Results are written to
    ``<output>/<checkpoint basename>/generated_from_<source name>``.

    Args:
      checkpoint_struct: dict whose ``'all'`` entry lists checkpoint paths.
      output: output directory. NOTE: it is deleted first if it exists.
      num_examples: max images per domain; ``<= 0`` uses all of them.
      data_dir: directory containing the ``testA`` and ``testB`` folders.
      patch_size: patch size passed to `cyclegan_dp.full_image_to_patch`.

    Raises:
      ValueError: if `data_dir` is empty (loading from tfds is not supported).
    """
    if not data_dir:
        raise ValueError("data_dir is required to load testA/testB images.")

    # glob returns files in arbitrary order; sort so the selected subset and
    # the output are deterministic.
    apple_paths = sorted(glob(os.path.join(data_dir, "testA", "*.jpg")))
    orange_paths = sorted(glob(os.path.join(data_dir, "testB", "*.jpg")))
    if num_examples > 0:
        apple_paths = apple_paths[:num_examples]
        orange_paths = orange_paths[:num_examples]

    examples_apples = load_data_from(apple_paths)
    examples_oranges = load_data_from(orange_paths)

    if os.path.exists(output):
        shutil.rmtree(output)

    input_apples = [
        tfds.as_numpy(cyclegan_dp.full_image_to_patch(x['image'], patch_size))
        for x in examples_apples
    ]
    input_oranges = [
        tfds.as_numpy(cyclegan_dp.full_image_to_patch(x['image'], patch_size))
        for x in examples_oranges
    ]

    stargan_estimator = tfgan.estimator.StarGANEstimator(
        model_dir=None,
        generator_fn=network.generator,
        discriminator_fn=network.discriminator,
        loss_fn=tfgan.stargan_loss,
        add_summaries=tfgan.estimator.SummaryType.IMAGES)

    for checkpoint_path in checkpoint_struct['all']:
        print("checkpoint: {}".format(checkpoint_path))
        output_path = os.path.join(output, os.path.basename(checkpoint_path))
        os.makedirs(output_path, exist_ok=True)
        apple2orange = translate_images(stargan_estimator, input_apples, 1,
                                        checkpoint_path, 2)
        orange2apple = translate_images(stargan_estimator, input_oranges, 0,
                                        checkpoint_path, 2)
        _save_generated_images(apple2orange, apple_paths, output_path)
        _save_generated_images(orange2apple, orange_paths, output_path)
Example #9
0
def _summary_grid(inputs, translations, checkpoint_paths):
    """Stack one row per input image: the input, then each checkpoint's output.

    Args:
      inputs: list of input images, presumably in [-1, 1]; they are rescaled
        here to [0, 1] to sit next to the generated images (assumed to be in
        [0, 1] — TODO confirm against `translate_images`).
      translations: dict mapping checkpoint path -> list of generated images
        aligned with `inputs`.
      checkpoint_paths: checkpoints whose outputs form the columns, in order.

    Returns:
      One array: images in a row joined along axis 1, rows joined along axis 0.
    """
    rows = [
        np.concatenate(
            [(image + 1.0) / 2.0] +
            [translations[path][ind] for path in checkpoint_paths], 1)
        for ind, image in enumerate(inputs)
    ]
    return np.concatenate(rows, 0)


def _write_png(path, image):
    """Save a float image (expected in [0, 1]) to `path` as PNG via tf.io.gfile."""
    # Clip so out-of-range values saturate instead of wrapping in the cast.
    pixels = (255 * np.clip(image, 0.0, 1.0)).astype(np.uint8)
    with tf.io.gfile.GFile(path, 'wb') as f:
        PIL.Image.fromarray(pixels).save(f, 'PNG')


def make_summary_images(checkpoint_dir,
                        checkpoint_struct,
                        dataset_name,
                        num_examples=10,
                        patch_size=256,
                        discriminator_path=(
                            '/home/ec2-user/gan/test_model/a2o_v2/base_model.h5')):
    """Write side-by-side summary grids of translations for groups of checkpoints.

    For every ``(key, checkpoint_paths)`` in `checkpoint_struct`, writes
    ``<checkpoint_dir>summary_apple_<key>.png`` and
    ``<checkpoint_dir>summary_orange_<key>.png``. Each row shows one test image
    followed by its translation under each checkpoint of the group.

    Args:
      checkpoint_dir: output prefix; it is concatenated directly with the file
        name, so include a trailing separator for a directory.
      checkpoint_struct: dict of group name -> list of checkpoint paths. Its
        ``'all'`` entry must contain every checkpoint used by any group.
      dataset_name: tfds name with ``testA``/``testB`` splits.
      num_examples: max images taken from each split.
      patch_size: patch size passed to `cyclegan_dp.full_image_to_patch`.
      discriminator_path: model file for `network.CustomKerasDiscriminator`.

    Raises:
      ValueError: if either split yields no images.
    """
    ds = tfds.load(dataset_name)
    examples_apples = list(tfds.as_numpy(ds['testA'].take(num_examples)))
    examples_oranges = list(tfds.as_numpy(ds['testB'].take(num_examples)))
    input_apples = [
        tfds.as_numpy(cyclegan_dp.full_image_to_patch(
            x['image'], patch_size)).astype('float32')
        for x in examples_apples
    ]
    input_oranges = [
        tfds.as_numpy(cyclegan_dp.full_image_to_patch(
            x['image'], patch_size)).astype('float32')
        for x in examples_oranges
    ]
    if not input_apples or not input_oranges:
        raise ValueError(
            'No test images loaded from {}; check num_examples={}.'.format(
                dataset_name, num_examples))

    discriminator_fn = network.CustomKerasDiscriminator(discriminator_path)

    stargan_estimator = tfgan.estimator.StarGANEstimator(
        model_dir=None,
        generator_fn=network.generator,
        discriminator_fn=discriminator_fn,
        loss_fn=tfgan.stargan_loss,
        add_summaries=tfgan.estimator.SummaryType.IMAGES)

    summary_apples, summary_oranges = {}, {}
    for checkpoint_path in checkpoint_struct['all']:
        summary_apples[checkpoint_path] = translate_images(
            stargan_estimator, input_apples, 1, checkpoint_path, 2)
        summary_oranges[checkpoint_path] = translate_images(
            stargan_estimator, input_oranges, 0, checkpoint_path, 2)

    # Rows are built from the images actually loaded, so a split shorter than
    # num_examples no longer causes an IndexError.
    for key, struct in checkpoint_struct.items():
        _write_png(checkpoint_dir + 'summary_apple_' + key + '.png',
                   _summary_grid(input_apples, summary_apples, struct))
        _write_png(checkpoint_dir + 'summary_orange_' + key + '.png',
                   _summary_grid(input_oranges, summary_oranges, struct))