コード例 #1
0
def get_fid_tf(real_img, sample_img):
    """Compute the Frechet Inception Distance (FID) between two image sets.

    Both inputs are converted to tensors and run through ``preprocess_image``.
    Single-channel images are then replicated to three channels, and both
    tensors are passed to ``frechet_inception_distance``.

    Args:
        real_img: Real images, in any form ``tf.convert_to_tensor`` accepts.
            Assumed to be ``[batch, height, width, channels]`` -- TODO
            confirm against callers.
        sample_img: Generated images, in the same layout as ``real_img``.

    Returns:
        The tensor returned by ``frechet_inception_distance``.

    Raises:
        ValueError: If the leading (batch) dimension of ``real_img`` is not
            known when the graph is built, or is smaller than 1.
    """

    def _to_three_channels(images):
        # Tile a trailing channel axis of size 1 three times so grayscale
        # input matches the RGB input Inception consumes.
        # NOTE(review): this replaces an ``ndims == 3`` test whose axis-3
        # concat could not succeed on a rank-3 tensor; confirm that
        # grayscale-to-RGB conversion was the intent.
        if images.shape.as_list()[-1] == 1:
            return tf.concat([images, images, images], -1)
        return images

    real_img = _to_three_channels(
        preprocess_image(tf.convert_to_tensor(real_img)))
    sample_img = _to_three_channels(
        preprocess_image(tf.convert_to_tensor(sample_img)))

    # as_list() yields a plain int (or None) under both TF1 and TF2, unlike
    # shape[0], which is a Dimension object in TF1.
    num_images = real_img.shape.as_list()[0]
    if num_images is None or num_images < 1:
        raise ValueError(
            'get_fid_tf needs a statically known, non-empty batch; got '
            'leading dimension %r' % (num_images,))

    # Batch size = smallest divisor of num_images that is >= 4, so the
    # images split evenly into num_batches groups. Starting from
    # min(4, num_images) guarantees termination: num_images always divides
    # itself, so fewer than 4 images yield a single batch instead of an
    # endless loop.
    batch_size = min(4, num_images)
    while num_images % batch_size:
        batch_size += 1

    return frechet_inception_distance(real_img,
                                      sample_img,
                                      num_batches=num_images // batch_size)
コード例 #2
0
 def test_preprocess_image_graph(self):
   """Check that a `preprocess_image` graph can be built for an odd-sized input.

   A zero-filled 520x240x3 tensor is passed to
   `classifier_metrics.preprocess_image`. The result gets a leading batch
   axis and is handed to `_run_with_mock` together with
   `classifier_metrics.run_inception`.
   """
   raw_image = array_ops.zeros([520, 240, 3])
   resized = classifier_metrics.preprocess_image(images=raw_image)
   batched = array_ops.expand_dims(resized, 0)
   _run_with_mock(classifier_metrics.run_inception, batched)
コード例 #3
0
 def test_preprocess_image_graph(self):
     """Build the `preprocess_image` graph on an input of non-default size.

     `classifier_metrics.preprocess_image` receives a zero-filled 520x240x3
     tensor. Its output is expanded with a batch axis of size 1 and passed,
     along with `classifier_metrics.run_inception`, to `_run_with_mock`.
     """
     odd_shape = [520, 240, 3]
     preprocessed = classifier_metrics.preprocess_image(
         images=array_ops.zeros(odd_shape))
     single_batch = array_ops.expand_dims(preprocessed, 0)
     _run_with_mock(classifier_metrics.run_inception, single_batch)
コード例 #4
0
def get_inception_score_tf(sample_img):
    """Compute the Inception Score of a set of generated images.

    The images are converted to a tensor and run through
    ``preprocess_image``. They are then scored by ``classifier_score``, using
    ``run_image_classifier`` bound to the ``graph_def``, ``INCEPTION_INPUT``
    and ``INCEPTION_OUTPUT`` globals.

    Args:
        sample_img: Generated images, in any form ``tf.convert_to_tensor``
            accepts. The leading dimension is treated as the image count.

    Returns:
        The score tensor produced by ``classifier_score``.

    Raises:
        ValueError: If the leading dimension is not known when the graph is
            built, or is smaller than 1.
    """
    sample_img = preprocess_image(tf.convert_to_tensor(sample_img))

    # as_list() yields a plain int (or None) under both TF1 and TF2.
    num_images = sample_img.shape.as_list()[0]
    if num_images is None or num_images < 1:
        raise ValueError(
            'get_inception_score_tf needs a statically known, non-empty '
            'batch; got leading dimension %r' % (num_images,))

    # Batch size = largest divisor of num_images that is <= 10. A batch
    # size of 1 always qualifies, so the loop terminates.
    batch_size = min(10, num_images)
    while num_images % batch_size:
        batch_size -= 1

    # The images are split into num_batches equal parts, so num_batches is
    # derived from the divisor found above. (A hard-coded ``// 100`` gave 0
    # batches below 100 images and need not divide the image count.)
    # The score tensor itself is returned, not its ``.shape``.
    return classifier_score(
        sample_img,
        functools.partial(
            run_image_classifier,
            graph_def=graph_def,
            input_tensor=INCEPTION_INPUT,
            output_tensor=INCEPTION_OUTPUT,
        ),
        num_batches=num_images // batch_size)