def test_fail_due_to_call_in_invalid_place_subclass():
    """Check that seeding outside ``build`` is not enough for a subclassed model.

    ``set_seed`` is called once, before ``build`` is defined, instead of right
    before the random variables are generated.  The test expects
    ``is_reproducible(build)`` to be falsy in that situation.
    """
    # Seeded only once here, so repeated `build` calls are expected to produce
    # models whose variables differ from one another.
    set_seed(200)

    def build():
        """Return a model built from a single, freshly created ``SubClass``."""

        class SubClass(tf.keras.Model):
            """Conv2D -> BatchNormalization -> ReLU -> Flatten -> Dense(2)."""

            def __init__(self, *args, **kwargs):
                super().__init__(*args, **kwargs)
                self.a = tf.keras.layers.Conv2D(4, 3)
                self.b = tf.keras.layers.BatchNormalization()
                self.c = tf.keras.layers.Activation("relu")
                self.d = tf.keras.layers.Flatten()
                self.e = tf.keras.layers.Dense(2)

            def call(self, x):
                conv_out = self.a(x)
                normed = self.b(conv_out)
                activated = self.c(normed)
                flat = self.d(activated)
                return self.e(flat)

        # NOTE(review): `build_helper` is defined elsewhere; presumably it
        # assembles the layers into a model for a (16, 16, 3) input -- confirm.
        return build_helper([SubClass()], (16, 16, 3))

    assert not is_reproducible(build)
 def build():
     """Call ``set_seed(200)`` and return a ``tf.keras.Sequential`` model.

     The stack is Conv2D -> BatchNormalization -> ReLU -> Flatten -> Dense(2);
     the first layer declares the ``(16, 16, 3)`` input shape itself.
     """
     # Seeding happens inside `build`, before any layer object is created.
     set_seed(200)
     return tf.keras.Sequential(
         [
             tf.keras.layers.Conv2D(4, 3, input_shape=(16, 16, 3)),
             tf.keras.layers.BatchNormalization(),
             tf.keras.layers.Activation("relu"),
             tf.keras.layers.Flatten(),
             tf.keras.layers.Dense(2),
         ]
     )
 def build():
     """Call ``set_seed(200)``, then build a model through ``build_helper``.

     The layer stack is Conv2D -> BatchNormalization -> ReLU -> Flatten ->
     Dense(2); it is handed to ``build_helper`` together with the shape
     ``(16, 16, 3)``.
     """
     # Seeding happens inside `build`, before any layer object is created.
     set_seed(200)
     input_shape = (16, 16, 3)
     stack = [
         tf.keras.layers.Conv2D(4, 3),
         tf.keras.layers.BatchNormalization(),
         tf.keras.layers.Activation("relu"),
         tf.keras.layers.Flatten(),
         tf.keras.layers.Dense(2),
     ]
     # NOTE(review): `build_helper` is defined elsewhere; presumably it wires
     # the layers to an input of the given shape -- confirm.
     return build_helper(stack, input_shape)
def test_fail_due_to_call_in_invalid_place_functional_api():
    """Check that seeding outside ``build`` is not enough for this model either.

    ``set_seed`` is called once, before ``build`` is defined, instead of right
    before the random variables are generated.  The test expects
    ``is_reproducible(build)`` to be falsy in that situation.
    """
    # Seeded only once here, so repeated `build` calls are expected to produce
    # models whose variables differ from one another.
    set_seed(200)

    def build():
        """Return a model built from a fresh Conv/BN/ReLU/Flatten/Dense stack."""
        input_shape = (16, 16, 3)
        stack = [
            tf.keras.layers.Conv2D(4, 3),
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.Activation("relu"),
            tf.keras.layers.Flatten(),
            tf.keras.layers.Dense(2),
        ]
        # NOTE(review): `build_helper` is defined elsewhere; presumably it
        # wires the layers to an input of the given shape -- confirm.
        return build_helper(stack, input_shape)

    assert not is_reproducible(build)
    def build():
        """Call ``set_seed(200)``, then build a model from a fresh ``SubClass``.

        Unlike a one-off seed outside the builder, the seed is set on every
        call, before the layers are instantiated.
        """
        set_seed(200)

        class SubClass(tf.keras.Model):
            """Conv2D -> BatchNormalization -> ReLU -> Flatten -> Dense(2)."""

            def __init__(self, *args, **kwargs):
                super().__init__(*args, **kwargs)
                self.a = tf.keras.layers.Conv2D(4, 3)
                self.b = tf.keras.layers.BatchNormalization()
                self.c = tf.keras.layers.Activation("relu")
                self.d = tf.keras.layers.Flatten()
                self.e = tf.keras.layers.Dense(2)

            def call(self, x):
                # Apply the five layers in order: a, b, c, d, e.
                return self.e(self.d(self.c(self.b(self.a(x)))))

        # NOTE(review): `build_helper` is defined elsewhere; presumably it
        # assembles the layers into a model for a (16, 16, 3) input -- confirm.
        return build_helper([SubClass()], (16, 16, 3))