Esempio n. 1
0
def test_rms_normalization(partial, bias):
    """ Run the generic layer test against ``normalization.RMSNormalization``.

    Parameters
    ----------
    partial
        Passed straight through to the layer as its ``partial`` keyword argument.
    bias
        Passed straight through to the layer as its ``bias`` keyword argument.

    Notes
    -----
    Both values come from the caller (presumably a pytest parametrization that is not
    visible in this part of the file). The layer is exercised on an input of shape (4, 512).
    """
    layer_kwargs = dict(partial=partial, bias=bias)
    layer_test(normalization.RMSNormalization,
               kwargs=layer_kwargs,
               input_shape=(4, 512))
Esempio n. 2
0
def test_layer_normalization(center, scale):
    """ Run the generic layer test against ``normalization.LayerNormalization``.

    Parameters
    ----------
    center
        Passed straight through to the layer as its ``center`` keyword argument.
    scale
        Passed straight through to the layer as its ``scale`` keyword argument.

    Notes
    -----
    Both values come from the caller (presumably a pytest parametrization that is not
    visible in this part of the file). The layer is exercised on an input of shape (4, 512).
    """
    layer_kwargs = dict(center=center, scale=scale)
    layer_test(normalization.LayerNormalization,
               kwargs=layer_kwargs,
               input_shape=(4, 512))
Esempio n. 3
0
def test_group_normalization(dummy):  # pylint:disable=unused-argument
    """ Run the generic layer test against ``normalization.GroupNormalization``.

    Four configurations are checked in order, each pairing a set of constructor keyword
    arguments with an input shape:

    1. custom epsilon with L2 regularizers on gamma and beta, 4D input
    2. custom epsilon with ``axis=1``, 4D input
    3. ``'ones'`` initializers for gamma and beta, 2D input
    4. custom epsilon with ``axis=1`` and ``group=16``, 2D input

    Parameters
    ----------
    dummy
        Not used by the test body.
    """
    # (constructor kwargs, input shape) pairs, run in the order listed
    cases = (
        ({'epsilon': 0.1,
          'gamma_regularizer': regularizers.l2(0.01),
          'beta_regularizer': regularizers.l2(0.01)},
         (4, 3, 4, 128)),
        ({'epsilon': 0.1, 'axis': 1},
         (4, 1, 4, 256)),
        ({'gamma_init': 'ones', 'beta_init': 'ones'},
         (4, 64)),
        ({'epsilon': 0.1, 'axis': 1, 'group': 16},
         (3, 64)),
    )
    for layer_kwargs, shape in cases:
        layer_test(normalization.GroupNormalization,
                   kwargs=layer_kwargs,
                   input_shape=shape)
Esempio n. 4
0
def test_instance_normalization(dummy):  # pylint:disable=unused-argument
    """ Run the generic layer test against ``normalization.InstanceNormalization``.

    Four configurations are checked in order, each pairing a set of constructor keyword
    arguments with an input shape:

    1. custom epsilon with L2 regularizers on gamma and beta, 3D input
    2. custom epsilon with ``axis=1``, 3D input
    3. ``'ones'`` initializers for gamma and beta, 4D input
    4. custom epsilon with ``axis=1`` and both ``scale`` and ``center`` disabled, 4D input

    Parameters
    ----------
    dummy
        Not used by the test body.
    """
    # (constructor kwargs, input shape) pairs, run in the order listed
    cases = (
        ({'epsilon': 0.1,
          'gamma_regularizer': regularizers.l2(0.01),
          'beta_regularizer': regularizers.l2(0.01)},
         (3, 4, 2)),
        ({'epsilon': 0.1, 'axis': 1},
         (1, 4, 1)),
        ({'gamma_initializer': 'ones', 'beta_initializer': 'ones'},
         (3, 4, 2, 4)),
        ({'epsilon': 0.1, 'axis': 1, 'scale': False, 'center': False},
         (3, 4, 2, 4)),
    )
    for layer_kwargs, shape in cases:
        layer_test(normalization.InstanceNormalization,
                   kwargs=layer_kwargs,
                   input_shape=shape)