def __init__(self):
    """Set up `layer_dict`, a tracked mapping of layers built up by mutation.

    The mapping starts as a wrapped dict holding a single Dense layer under
    "output". Two empty wrapped lists are then inserted under "norm" and
    "dense" and populated in place with `append` / `extend`. Filling the
    containers by mutation (rather than constructing them pre-filled) appears
    intended to exercise tracking of nested, mutated structures -- keep the
    statement order as is, since it also fixes the layer-creation order.
    """
    # Zero-argument super() matches the sibling fixtures and does not break
    # if the enclosing class is renamed.
    super().__init__()
    self.layer_dict = tf.__internal__.tracking.wrap(dict(output=core.Dense(7)))
    self.layer_dict["norm"] = tf.__internal__.tracking.wrap([])
    self.layer_dict["dense"] = tf.__internal__.tracking.wrap([])
    # One regularized Dense layer; presumably so that losses from layers
    # nested inside the mapping can be checked by the tests -- confirm.
    self.layer_dict["dense"].extend(
        [core.Dense(5), core.Dense(6, kernel_regularizer=tf.reduce_sum)])
    self.layer_dict["norm"].append(
        batch_normalization_v1.BatchNormalization())
    self.layer_dict["norm"].append(
        batch_normalization_v1.BatchNormalization())
def my_func():
    """Run a v1 BatchNormalization layer in training mode and return its output.

    The layer is applied to a (10, 1) tensor of ones with `training=True`.
    The function then asserts that the layer recorded exactly two updates.
    `self` is not a parameter here; it is captured from the enclosing test
    method, which is expected to trace this function with a `wrap_function`.
    """
    bn_layer = batch_normalization_v1.BatchNormalization()
    outputs = bn_layer(tf.ones((10, 1)), training=True)
    # Even when traced inside a `wrap_function`, the layer must still expose
    # its updates (two of them; presumably the moving statistics -- confirm).
    self.assertLen(bn_layer.updates, 2)
    return outputs
def test_v1_fused_attribute(self):
    """Check how the v1 BatchNormalization layer resolves its `fused` flag.

    Each case constructs a layer from the given keyword arguments. It then
    optionally checks `fused` before the layer is first called, calls the
    layer on a symbolic input of the given shape, and checks `fused` again
    afterwards.
    """
    # (constructor kwargs, input shape,
    #  expected `fused` before the call -- None means "not checked",
    #  expected `fused` after the call)
    cases = (
        ({}, (4, 4, 4), None, True),
        ({"fused": False}, (4, 4, 4), False, False),
        # Starts out fused but ends up unfused once called on this input.
        ({"virtual_batch_size": 2}, (2, 2, 2), True, False),
    )
    for layer_kwargs, input_shape, fused_before, fused_after in cases:
        norm = batch_normalization_v1.BatchNormalization(**layer_kwargs)
        if fused_before is not None:
            self.assertEqual(norm.fused, fused_before)
        inp = keras.layers.Input(shape=input_shape)
        norm(inp)
        self.assertEqual(norm.fused, fused_after)
def __init__(self):
    """Hold sub-layers in plain tuples rather than in wrapped containers.

    `layer_list` is a tuple of three Dense layers; the last one has a kernel
    regularizer. `layers_with_updates` is a one-element tuple holding a v1
    BatchNormalization layer.
    """
    super().__init__()
    dense_layers = (
        core.Dense(3),
        core.Dense(4),
        core.Dense(5, kernel_regularizer=tf.reduce_sum),
    )
    self.layer_list = dense_layers
    # Created after the Dense layers, matching the original creation order.
    self.layers_with_updates = (batch_normalization_v1.BatchNormalization(),)
def __init__(self):
    """Build `layer_list`, a tracked list grown through every mutation form.

    Starting from a wrapped one-element list, layers are added in turn via
    `append`, `extend` with a plain list, `+=` with a plain list, `+=` with
    the concatenation of two wrapped lists, and `extend` with a wrapped list.
    That yields Dense(3) ... Dense(12) in creation order. Each form appears
    to be deliberate, exercising a different list operation on the tracked
    container, so keep them all and keep their order.
    `layers_with_updates` holds a single v1 BatchNormalization layer in a
    wrapped list.
    """
    # Zero-argument super() matches the sibling fixtures and does not break
    # if the enclosing class is renamed.
    super().__init__()
    self.layer_list = tf.__internal__.tracking.wrap([core.Dense(3)])
    self.layer_list.append(core.Dense(4))
    self.layer_list.extend(
        [core.Dense(5), core.Dense(6, kernel_regularizer=tf.reduce_sum)])
    # In-place add with a plain Python list.
    self.layer_list += [
        core.Dense(7, bias_regularizer=tf.reduce_sum),
        core.Dense(8),
    ]
    # In-place add with the sum of two wrapped lists.
    self.layer_list += (tf.__internal__.tracking.wrap([core.Dense(9)]) +
                        tf.__internal__.tracking.wrap([core.Dense(10)]))
    # Extend with a wrapped list built from plain-list concatenation.
    self.layer_list.extend(
        tf.__internal__.tracking.wrap(
            list([core.Dense(11)]) + [core.Dense(12)]))
    self.layers_with_updates = tf.__internal__.tracking.wrap(
        [batch_normalization_v1.BatchNormalization()])