Exemple #1
0
    def test_functions_have_same_trace(self):
        """Functions in one call collection should share traced signatures.

        Only the wrapped `call` is invoked (with two different input shapes);
        the wrapped `call2` is never invoked directly, yet it is expected to
        end up with one concrete function per shape as well.
        """

        class Layer(keras.engine.base_layer.Layer):
            """Minimal layer exposing two call-like methods."""

            def call(self, inputs):
                return inputs

            def call2(self, inputs):
                return inputs * 2

        layer = Layer()
        collection = keras_save.LayerCallCollection(layer)
        wrapped_call = collection.add_function(layer.call, 'call')
        wrapped_call2 = collection.add_function(layer.call2, 'call2')

        # Trace only the first function, once per input shape.
        for shape in ((2, 3), (4, 5)):
            wrapped_call(np.ones(shape))

        self.assertLen(
            wrapped_call._list_all_concrete_functions_for_serialization(), 2)
        self.assertLen(
            wrapped_call2._list_all_concrete_functions_for_serialization(), 2)

        # Check that the shapes are correct
        traced_shapes = {
            tuple(concrete.structured_input_signature[0][0].shape.as_list())
            for concrete in
            wrapped_call2._list_all_concrete_functions_for_serialization()
        }
        self.assertEqual({(2, 3), (4, 5)}, traced_shapes)
Exemple #2
0
    def test_maintains_losses(self):
        """Calling a wrapped `call` should leave `layer.losses` unchanged."""
        layer = LayerWithLoss()
        sample = np.ones((2, 3))

        # Call the layer directly first and snapshot the losses it reports.
        layer(sample)
        losses_before = list(layer.losses)

        collection = keras_save.LayerCallCollection(layer)
        wrapped_call = collection.add_function(layer.call, 'call')
        wrapped_call(np.ones((2, 3)))

        self.assertAllEqual(losses_before, layer.losses)
Exemple #3
0
        def assert_num_traces(layer_cls, training_keyword):
            """Check how many concrete functions each wrapped call adds.

            Every call below uses a new input shape and is expected to raise
            the number of serializable concrete functions by two
            (presumably one per value of `training` -- TODO confirm).

            Args:
                layer_cls: Layer class to instantiate with no arguments.
                training_keyword: If truthy, additionally exercise passing
                    `training` positionally and omitting it entirely.
            """
            layer = layer_cls()
            collection = keras_save.LayerCallCollection(layer)
            fn = collection.add_function(layer.call, 'call')

            def expect_traces(expected):
                self.assertLen(
                    fn._list_all_concrete_functions_for_serialization(),
                    expected)

            fn(np.ones((2, 3)), training=True)
            expect_traces(2)

            fn(np.ones((2, 4)), training=False)
            expect_traces(4)

            if not training_keyword:
                return

            # `training` passed positionally.
            fn(np.ones((2, 5)), True)
            expect_traces(6)

            # `training` omitted.
            fn(np.ones((2, 6)))
            expect_traces(8)