# Example #1
def cifar_standardization(x, mode='FEATURE_NORMALIZE', data_samples=None):
    """Standardize CIFAR inputs with a Keras preprocessing layer.

    Args:
        x: Tensor to transform; it is passed through the created layer(s).
        mode: Case-insensitive, one of ``'FEATURE_NORMALIZE'`` or
            ``'PIXEL_MEAN_SUBTRACT'``.

            * ``FEATURE_NORMALIZE`` applies a ``Rescaling`` layer with
              ``scale=1/CIFAR_STD`` and ``offset=-CIFAR_MEAN/CIFAR_STD``,
              i.e. ``(x - CIFAR_MEAN) / CIFAR_STD``. ``CIFAR_MEAN`` and
              ``CIFAR_STD`` are defined outside this function.
            * ``PIXEL_MEAN_SUBTRACT`` adapts a ``Normalization`` layer
              (statistics kept over axes 1, 2 and 3) on ``data_samples``,
              forces its variance to 1 so it only subtracts the mean, and
              then scales the result by ``1 / 255``.
        data_samples: Data passed to ``Normalization.adapt``. Required when
            ``mode`` is ``'PIXEL_MEAN_SUBTRACT'``.

    Returns:
        The transformed tensor.

    Raises:
        ValueError: If ``mode`` is not a supported value, or if
            ``data_samples`` is ``None`` in ``'PIXEL_MEAN_SUBTRACT'`` mode.
    """
    mode = mode.upper()
    supported_modes = ('FEATURE_NORMALIZE', 'PIXEL_MEAN_SUBTRACT')
    # Raise explicitly rather than `assert`: asserts vanish under `python -O`.
    if mode not in supported_modes:
        raise ValueError(
            f'`mode` must be one of {supported_modes}, got {mode!r}.')

    # Compare with `None` explicitly: `not data_samples` raises for
    # multi-element NumPy arrays / tensors ("truth value is ambiguous").
    if mode == 'PIXEL_MEAN_SUBTRACT' and data_samples is None:
        raise ValueError('`data_samples` argument should not be `None`, '
                         'when `mode="PIXEL_MEAN_SUBTRACT"`.')

    if mode == 'FEATURE_NORMALIZE':
        cifar_mean = tf.cast(CIFAR_MEAN, tf.float32).numpy()
        cifar_std = tf.cast(CIFAR_STD, tf.float32).numpy()

        # x * (1 / std) + (-mean / std) == (x - mean) / std
        x = Rescaling(scale=1. / cifar_std,
                      offset=-(cifar_mean / cifar_std),
                      name='mean_normalization')(x)
    elif mode == 'PIXEL_MEAN_SUBTRACT':
        mean_subtraction_layer = Normalization(axis=[1, 2, 3],
                                               name='pixel_mean_subtraction')
        mean_subtraction_layer.adapt(data_samples)

        # Keep the adapted mean but set variance = 1 so the layer only
        # subtracts the mean. Overwrite by position and keep any trailing
        # weights: some Keras versions also store a `count` weight, which a
        # hard-coded two-element list would not match.
        weights = mean_subtraction_layer.get_weights()
        weights[1] = tf.ones_like(weights[0])
        mean_subtraction_layer.set_weights(weights)
        # NOTE(review): newer Keras versions cache mean/variance tensors in
        # `finalize_state()`; re-run it when present so the overwritten
        # variance is actually used by the layer.
        finalize_state = getattr(mean_subtraction_layer, 'finalize_state',
                                 None)
        if callable(finalize_state):
            finalize_state()

        x = mean_subtraction_layer(x)
        x = Rescaling(scale=1 / 255., name='rescaling')(x)
    return x
already-preprocessed data. This is because if your
model expects preprocessed data, any time you export
your model to use it elsewhere (in a web browser, in a
mobile app), you'll need to reimplement the exact same
preprocessing pipeline, which can be tricky to do.
"""
import numpy as np

# normalize in range [0, 1]
scaling_layer = Rescaling(1.0 / 255)

# normalize in range [-1, 1]: with mean = 127.5 and variance = 127.5 ** 2 the
# Normalization layer computes (x - 127.5) / 127.5, mapping [0, 255] to [-1, 1].
input_ = tf.keras.Input(shape=(32, 32, 3))
norm_neg_one_to_one = Normalization()
# Calling the layer builds it, so it owns weights that can be overwritten.
x = norm_neg_one_to_one(input_)
# An array is required here: a plain Python list does not support `** 2`.
mean = np.full(3, 127.5, dtype=np.float32)
var = mean ** 2
# Overwrite mean/variance by position and keep any trailing weights (some
# Keras versions also store a `count` weight, so a two-element list fails).
neg_one_to_one_weights = norm_neg_one_to_one.get_weights()
neg_one_to_one_weights[0] = mean
neg_one_to_one_weights[1] = var
norm_neg_one_to_one.set_weights(neg_one_to_one_weights)
# NOTE(review): newer Keras versions cache mean/variance tensors in
# `finalize_state()`; re-run it when present so the new weights take effect.
if callable(getattr(norm_neg_one_to_one, 'finalize_state', None)):
    norm_neg_one_to_one.finalize_state()

# normalize with mean 0 and std 1
norm_mean_std = Normalization()
# NOTE(review): if `x_train` is an array of images, `x_train[0]` is a single
# image and the statistics come from that one sample only -- confirm whether
# adapting on all of `x_train` was intended.
norm_mean_std.adapt(x_train[0])

model_ = Sequential([
    tf.keras.Input(shape=(32, 32, 3)),
    norm_mean_std,
    model
])

model_.compile(
optimizer="Adam",
loss="sparse_categorical_crossentropy",