def test_invalid_input(self):
        """Test that functions properly fail on invalid input.

        Covers `run_inception` rejecting a batch of 50x50 images, and
        `_kl_divergence` rejecting a non-floating dtype or a wrong rank for
        each of its three arguments (`p`, `p_logits`, `q`).
        """
        # `assertRaisesRegexp` is a deprecated unittest alias that was removed
        # in Python 3.12; `assertRaisesRegex` is the supported spelling.
        with self.assertRaisesRegex(ValueError, 'Shapes .* are incompatible'):
            classifier_metrics.run_inception(array_ops.ones([7, 50, 50, 3]))

        # Valid baseline arguments; each case below replaces exactly one.
        p = array_ops.zeros([8, 10])
        p_logits = array_ops.zeros([8, 10])
        q = array_ops.zeros([10])

        # Integer dtype in each argument position must be rejected.
        with self.assertRaisesRegex(ValueError, 'must be floating type'):
            classifier_metrics._kl_divergence(
                array_ops.zeros([8, 10], dtype=dtypes.int32), p_logits, q)

        with self.assertRaisesRegex(ValueError, 'must be floating type'):
            classifier_metrics._kl_divergence(
                p, array_ops.zeros([8, 10], dtype=dtypes.int32), q)

        with self.assertRaisesRegex(ValueError, 'must be floating type'):
            classifier_metrics._kl_divergence(
                p, p_logits, array_ops.zeros([10], dtype=dtypes.int32))

        # `p` and `p_logits` must be rank 2; `q` must be rank 1.
        with self.assertRaisesRegex(ValueError, 'must have rank 2'):
            classifier_metrics._kl_divergence(array_ops.zeros([8]), p_logits,
                                              q)

        with self.assertRaisesRegex(ValueError, 'must have rank 2'):
            classifier_metrics._kl_divergence(p, array_ops.zeros([8]), q)

        with self.assertRaisesRegex(ValueError, 'must have rank 1'):
            classifier_metrics._kl_divergence(p, p_logits,
                                              array_ops.zeros([10, 8]))
  def test_invalid_input(self):
    """Test that functions properly fail on invalid input.

    Covers `run_inception` rejecting a batch of 50x50 images, and
    `_kl_divergence` rejecting a non-floating dtype or a wrong rank for each
    of its three arguments (`p`, `p_logits`, `q`).
    """
    # `assertRaisesRegexp` is a deprecated unittest alias that was removed in
    # Python 3.12; `assertRaisesRegex` is the supported spelling.
    with self.assertRaisesRegex(ValueError, 'Shapes .* are incompatible'):
      classifier_metrics.run_inception(array_ops.ones([7, 50, 50, 3]))

    # Valid baseline arguments; each case below replaces exactly one.
    p = array_ops.zeros([8, 10])
    p_logits = array_ops.zeros([8, 10])
    q = array_ops.zeros([10])

    # Integer dtype in each argument position must be rejected.
    with self.assertRaisesRegex(ValueError, 'must be floating type'):
      classifier_metrics._kl_divergence(
          array_ops.zeros([8, 10], dtype=dtypes.int32), p_logits, q)

    with self.assertRaisesRegex(ValueError, 'must be floating type'):
      classifier_metrics._kl_divergence(
          p, array_ops.zeros([8, 10], dtype=dtypes.int32), q)

    with self.assertRaisesRegex(ValueError, 'must be floating type'):
      classifier_metrics._kl_divergence(
          p, p_logits, array_ops.zeros([10], dtype=dtypes.int32))

    # `p` and `p_logits` must be rank 2; `q` must be rank 1.
    with self.assertRaisesRegex(ValueError, 'must have rank 2'):
      classifier_metrics._kl_divergence(array_ops.zeros([8]), p_logits, q)

    with self.assertRaisesRegex(ValueError, 'must have rank 2'):
      classifier_metrics._kl_divergence(p, array_ops.zeros([8]), q)

    with self.assertRaisesRegex(ValueError, 'must have rank 1'):
      classifier_metrics._kl_divergence(p, p_logits, array_ops.zeros([10, 8]))
  def test_run_inception_graph(self, use_default_graph_def):
    """Checks that `run_inception` builds a graph producing logits.

    Args:
      use_default_graph_def: When true, `run_inception` is invoked through
        `_run_with_mock` with no explicit graph def; otherwise the graph def
        returned by `_get_dummy_graphdef()` is passed in directly.
    """
    batch_size = 7
    images = array_ops.ones([batch_size, 299, 299, 3])

    if not use_default_graph_def:
      logits = classifier_metrics.run_inception(images, _get_dummy_graphdef())
    else:
      logits = _run_with_mock(classifier_metrics.run_inception, images)

    self.assertIsInstance(logits, ops.Tensor)
    # One row of 1001 logits per image (presumably ImageNet classes plus a
    # background class -- TODO confirm against the Inception graph).
    expected_shape = [batch_size, 1001]
    logits.shape.assert_is_compatible_with(expected_shape)

    # The imported model must not contribute any trainable variables.
    self.assertListEqual([], variables.trainable_variables())
Esempio n. 4
0
  def test_run_inception_graph(self, use_default_graph_def):
    """Test `run_inception` graph construction.

    Args:
      use_default_graph_def: When true, `run_inception` is invoked through
        `_run_with_mock` with no explicit graph def; otherwise the graph def
        returned by `_get_dummy_graphdef()` is passed in directly.
    """
    batch_size = 7
    img = array_ops.ones([batch_size, 299, 299, 3])

    if use_default_graph_def:
      logits = _run_with_mock(classifier_metrics.run_inception, img)
    else:
      logits = classifier_metrics.run_inception(img, _get_dummy_graphdef())

    # assertIsInstance reports the offending type on failure, whereas
    # assertTrue(isinstance(...)) only reports "False is not true".
    self.assertIsInstance(logits, ops.Tensor)
    # 1001 logits per image (presumably ImageNet classes plus a background
    # class -- TODO confirm against the Inception graph).
    logits.shape.assert_is_compatible_with([batch_size, 1001])

    # Check that none of the model variables are trainable.
    self.assertListEqual([], variables.trainable_variables())
  def test_run_inception_graph_pool_output(self, use_default_graph_def):
    """Checks `run_inception` when asked for the final pooling output.

    Args:
      use_default_graph_def: When true, `run_inception` is invoked through
        `_run_with_mock` with no explicit graph def; otherwise the graph def
        returned by `_get_dummy_graphdef()` is passed in directly.
    """
    batch_size = 3
    images = array_ops.ones([batch_size, 299, 299, 3])
    pool_output = classifier_metrics.INCEPTION_FINAL_POOL

    if not use_default_graph_def:
      pool = classifier_metrics.run_inception(
          images, _get_dummy_graphdef(), output_tensor=pool_output)
    else:
      pool = _run_with_mock(
          classifier_metrics.run_inception, images, output_tensor=pool_output)

    self.assertIsInstance(pool, ops.Tensor)
    # A 2048-wide pooled feature vector per image.
    pool.shape.assert_is_compatible_with([batch_size, 2048])

    # The imported model must not contribute any trainable variables.
    self.assertListEqual([], variables.trainable_variables())
Esempio n. 6
0
  def test_run_inception_graph_pool_output(self, use_default_graph_def):
    """Test `run_inception` graph construction with pool output.

    Args:
      use_default_graph_def: When true, `run_inception` is invoked through
        `_run_with_mock` with no explicit graph def; otherwise the graph def
        returned by `_get_dummy_graphdef()` is passed in directly.
    """
    batch_size = 3
    img = array_ops.ones([batch_size, 299, 299, 3])

    if use_default_graph_def:
      pool = _run_with_mock(
          classifier_metrics.run_inception,
          img,
          output_tensor=classifier_metrics.INCEPTION_FINAL_POOL)
    else:
      pool = classifier_metrics.run_inception(
          img, _get_dummy_graphdef(),
          output_tensor=classifier_metrics.INCEPTION_FINAL_POOL)

    # assertIsInstance reports the offending type on failure, whereas
    # assertTrue(isinstance(...)) only reports "False is not true".
    self.assertIsInstance(pool, ops.Tensor)
    pool.shape.assert_is_compatible_with([batch_size, 2048])

    # Check that none of the model variables are trainable.
    self.assertListEqual([], variables.trainable_variables())
Esempio n. 7
0
def _load_image(image_path):
    """Load one image as a (299, 299, 3) float32 array scaled to [-1, 1].

    Args:
      image_path: Path of the image file to read.

    Returns:
      A float32 numpy array of shape (299, 299, 3) with values in [-1, 1].
    """
    # The context manager closes the file handle Pillow opens lazily.
    with Image.open(image_path) as img:
        # Convert to RGB *before* resizing: Pillow's resize() falls back to
        # nearest-neighbour for palette ("P") and bilevel ("1") images, so
        # resizing first silently ignores BILINEAR for those inputs.
        rgb = img.convert('RGB').resize((299, 299), Image.BILINEAR)
    # np.asarray on a PIL image yields (height, width, channels) directly,
    # avoiding a manual reshape of getdata() by img.size (which is (w, h)).
    data = np.asarray(rgb, dtype=np.float32)
    return 2 * (data / 255) - 1


def _load_png_dir(directory):
    """Load every ``.png`` file directly inside ``directory``.

    Files are read in sorted path order so the batch order (and hence any
    printed embeddings) is reproducible across runs.

    Args:
      directory: ``pathlib.Path`` of the directory to scan (not recursive).

    Returns:
      A numpy array of shape (num_images, 299, 299, 3), or shape (0,) when the
      directory contains no PNG files.
    """
    # Match on the file suffix rather than a substring of the whole path, so
    # a parent directory whose name contains "png" does not pull in every file.
    paths = sorted(
        p for p in directory.iterdir()
        if p.is_file() and p.suffix.lower() == '.png')
    return np.array([_load_image(p) for p in paths])


def main():
    """Compute KID (and optionally FID / embeddings) between two PNG dirs.

    Command-line flags:
      --real_dir / --fake_dir: directories of reference and generated PNGs.
      --get_fid: also compute the Frechet Inception Distance.
      --get_sanity_check: recompute KID via kernel_classifier_distance with
        run_inception as the classifier, as a cross-check.
      --get_embeddings: print the run_inception outputs for both sets.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--real_dir', type=str, default='results/258')
    parser.add_argument('--fake_dir', type=str, default='results/566')
    # `action` takes the *string* 'store_true'; the bare name `store_true` is
    # undefined and raised NameError before any argument was parsed.
    parser.add_argument('--get_fid', action='store_true')
    parser.add_argument('--get_sanity_check', action='store_true')
    parser.add_argument('--get_embeddings', action='store_true')

    args = parser.parse_args()

    real_imgs = _load_png_dir(Path(args.real_dir))
    fake_imgs = _load_png_dir(Path(args.fake_dir))
    # Without this check an empty directory would reach the metric functions
    # as a shape-(0,) array instead of a (N, 299, 299, 3) batch.
    for flag, imgs in (('--real_dir', real_imgs), ('--fake_dir', fake_imgs)):
        if imgs.size == 0:
            parser.error('no .png images found for %s' % flag)
    print(real_imgs.shape)
    print(fake_imgs.shape)

    # KID
    kid = classifier_metrics.kernel_inception_distance(
        real_images=real_imgs, generated_images=fake_imgs)
    with tf.Session() as sess:
        print('KID:', sess.run(kid))

    # Kernel Classifier Distance with Inception (technically KID)
    if args.get_sanity_check:
        kid_general = classifier_metrics.kernel_classifier_distance(
            real_images=real_imgs,
            generated_images=fake_imgs,
            classifier_fn=classifier_metrics.run_inception)
        with tf.Session() as sess:
            print('KID sanity check:', sess.run(kid_general))

    # FID
    if args.get_fid:
        fid = classifier_metrics.frechet_inception_distance(
            real_images=real_imgs, generated_images=fake_imgs)
        with tf.Session() as sess:
            print('FID:', sess.run(fid))

    # Get imagenet embeddings
    if args.get_embeddings:
        inception_embeddings_real = classifier_metrics.run_inception(real_imgs)
        inception_embeddings_fake = classifier_metrics.run_inception(fake_imgs)
        with tf.Session() as sess:
            # TODO: save embeddings
            print(sess.run(inception_embeddings_real))
            print(sess.run(inception_embeddings_fake))