def testCompositeTypeSpecArgWithoutDtype(self):
        """Build and run a model from a composite type spec, then round-trip it.

        The scenario runs once for each value of `assign_variant_dtype`
        (False, True). For each value it creates an `Input` from a
        `TwoTensorsSpecNoOneDtype`, wraps a `Lambda` layer around it, calls
        the model on a `TwoTensors` value, and repeats the call after
        rebuilding the model from its config and from its JSON form.
        """
        for assign_variant_dtype in [False, True]:
            # Create a Keras Input
            spec = TwoTensorsSpecNoOneDtype(
                (1, 2, 3),
                tf.float32,
                (1, 2, 3),
                tf.int64,
                assign_variant_dtype=assign_variant_dtype,
            )
            x = input_layer_lib.Input(type_spec=spec)

            def lambda_fn(tensors):
                # Cast both components to float64 so they can be added.
                return tf.cast(tensors.x, tf.float64) + tf.cast(
                    tensors.y, tf.float64)

            # Verify you can construct and use a model w/ this input
            model = functional.Functional(x, core.Lambda(lambda_fn)(x))

            # And that the model works.
            # The shape must be passed as a single tuple: `tf.ones(1, 2, 3)`
            # would bind 2 to `dtype` and 3 to `name`. The second component
            # uses the shape and dtype declared for it in the spec above.
            two_tensors = TwoTensors(
                tf.ones((1, 2, 3)) * 2.0,
                tf.ones((1, 2, 3), dtype=tf.int64))
            self.assertAllEqual(model(two_tensors), lambda_fn(two_tensors))

            # Test serialization / deserialization
            model = functional.Functional.from_config(model.get_config())
            self.assertAllEqual(model(two_tensors), lambda_fn(two_tensors))
            model = model_config.model_from_json(model.to_json())
            self.assertAllEqual(model(two_tensors), lambda_fn(two_tensors))
    def testInputTensorArg(self):
        """Check an `Input` created from a concrete tensor and a model on it."""
        # Wrap a (7, 32) tensor of zeros in a Keras Input.
        keras_input = input_layer_lib.Input(tensor=tf.zeros((7, 32)))
        static_shape = keras_input.shape.as_list()
        self.assertAllEqual(static_shape, [7, 32])

        # A model that doubles its input must turn ones into twos.
        doubler = functional.Functional(keras_input, keras_input * 2.0)
        ones = tf.ones(keras_input.shape)
        self.assertAllEqual(doubler(ones), ones * 2.0)
    def testBasicOutputShapeWithBatchSize(self):
        """Check that `batch_size=6` with `shape=(32,)` yields shape [6, 32]."""
        # Build the Keras Input with an explicit batch size.
        keras_input = input_layer_lib.Input(
            batch_size=6, shape=(32, ), name="input_b")
        static_shape = keras_input.shape.as_list()
        self.assertAllEqual(static_shape, [6, 32])

        # A model that doubles its input must turn ones into twos.
        doubler = functional.Functional(keras_input, keras_input * 2.0)
        ones = tf.ones(keras_input.shape)
        self.assertAllEqual(doubler(ones), ones * 2.0)
Exemple #4
0
    def testBasicOutputShapeNoBatchSize(self):
        """Check that `shape=(32,)` alone yields the shape [None, 32]."""
        # Build the Keras Input without specifying a batch size.
        keras_input = input_layer_lib.Input(shape=(32, ), name='input_a')
        static_shape = keras_input.shape.as_list()
        self.assertAllEqual(static_shape, [None, 32])

        # A model that doubles its input, exercised with a batch of three.
        doubler = functional.Functional(keras_input, keras_input * 2.0)
        ones = tf.ones((3, 32))
        self.assertAllEqual(doubler(ones), ones * 2.0)
        def run_model(inp):
            """Apply the stored model to `inp`, building it first if needed.

            `model_container` comes from the enclosing scope, which is not
            visible here. While it is empty, a model that multiplies a
            (10, 16) tensor-backed Input by 3.0 is built and stored under the
            "model" key.
            """
            if model_container:
                # Already populated: reuse the stored model.
                return model_container["model"](inp)

            # Wrap a (10, 16) tensor of zeros in a Keras Input.
            keras_input = input_layer_lib.Input(tensor=tf.zeros((10, 16)))
            self.assertAllEqual(keras_input.shape.as_list(), [10, 16])

            # Build the tripling model and keep it for later calls.
            tripler = functional.Functional(keras_input, keras_input * 3.0)
            model_container["model"] = tripler
            return model_container["model"](inp)
        def run_model(inp):
            nonlocal model
            if not model:
                # Create a Keras Input
                x = input_layer_lib.Input(shape=(8, ), name="input_a")
                self.assertAllEqual(x.shape.as_list(), [None, 8])

                # Verify you can construct and use a model w/ this input
                model = functional.Functional(x, x * 2.0)
            return model(inp)
Exemple #7
0
        def run_model(inp):
            """Apply the stored model to `inp`, building it first if needed.

            `model_container` comes from the enclosing scope, which is not
            visible here. While it is empty, a model that multiplies an Input
            created from a (10, 16) float32 `TensorSpec` by 3.0 is built and
            stored under the 'model' key.
            """
            if model_container:
                # Already populated: reuse the stored model.
                return model_container['model'](inp)

            # Create the Keras Input from a type spec rather than a tensor.
            tensor_spec = tf.TensorSpec((10, 16), tf.float32)
            keras_input = input_layer_lib.Input(type_spec=tensor_spec)
            self.assertAllEqual(keras_input.shape.as_list(), [10, 16])

            # Build the tripling model and keep it for later calls.
            tripler = functional.Functional(keras_input, keras_input * 3.0)
            model_container['model'] = tripler
            return model_container['model'](inp)
Exemple #8
0
    def run_model(inp):
      """Apply the stored model to `inp`, building it first if needed.

      `model_container` comes from the enclosing scope, which is not visible
      here. While it is empty, a model that multiplies an Input created from
      a ragged tensor's type spec by 3 is built and stored under 'model'.
      """
      if model_container:
        # Already populated: reuse the stored model.
        return model_container['model'](inp)

      # This ragged tensor is only used as the source of the type spec.
      spec_source = tf.RaggedTensor.from_row_splits(
          values=[3, 1, 4, 1, 5, 9, 2, 6], row_splits=[0, 4, 4, 7, 8, 8])
      keras_input = input_layer_lib.Input(type_spec=spec_source._type_spec)

      # Build the tripling model and keep it for later calls.
      tripler = functional.Functional(keras_input, keras_input * 3)
      model_container['model'] = tripler
      return model_container['model'](inp)
    def testCompositeInputTensorArg(self):
        """Check an `Input` created from a ragged tensor and a model on it."""
        # Wrap a ragged tensor in a Keras Input.
        template = tf.RaggedTensor.from_row_splits(
            values=[3, 1, 4, 1, 5, 9, 2, 6], row_splits=[0, 4, 4, 7, 8, 8])
        keras_input = input_layer_lib.Input(tensor=template)

        # Build a model that doubles its input.
        doubler = functional.Functional(keras_input, keras_input * 2)

        # Call it on a ragged tensor with the same row splits but new values.
        sample = tf.RaggedTensor.from_row_splits(
            values=[3, 21, 4, 1, 53, 9, 2, 6], row_splits=[0, 4, 4, 7, 8, 8])
        self.assertAllEqual(doubler(sample), sample * 2)
Exemple #10
0
    def testTypeSpecArg(self):
        """Check an `Input` created from a `TensorSpec`, with round-trips.

        The doubling model is exercised three times: as built, after being
        rebuilt from its config, and after being rebuilt from its JSON form.
        """
        # Create the Keras Input from a (7, 32) float32 type spec.
        tensor_spec = tf.TensorSpec((7, 32), tf.float32)
        keras_input = input_layer_lib.Input(type_spec=tensor_spec)
        self.assertAllEqual(keras_input.shape.as_list(), [7, 32])

        ones = tf.ones(keras_input.shape)

        # The freshly built model doubles its input.
        model = functional.Functional(keras_input, keras_input * 2.0)
        self.assertAllEqual(model(ones), ones * 2.0)

        # Round-trip through the config dictionary.
        config = model.get_config()
        model = functional.Functional.from_config(config)
        self.assertAllEqual(model(ones), ones * 2.0)

        # Round-trip through JSON.
        json_text = model.to_json()
        model = model_config.model_from_json(json_text)
        self.assertAllEqual(model(ones), ones * 2.0)
Exemple #11
0
    def testCompositeTypeSpecArg(self):
        """Check an `Input` created from a ragged type spec, with round-trips.

        The doubling model is exercised three times: as built, after being
        rebuilt from its config, and after being rebuilt from its JSON form.
        """
        # This ragged tensor is only used as the source of the type spec.
        spec_source = tf.RaggedTensor.from_row_splits(
            values=[3, 1, 4, 1, 5, 9, 2, 6], row_splits=[0, 4, 4, 7, 8, 8])
        keras_input = input_layer_lib.Input(type_spec=spec_source._type_spec)

        # Build a model that doubles its input.
        model = functional.Functional(keras_input, keras_input * 2)

        # Call it on a ragged tensor with the same row splits but new values.
        sample = tf.RaggedTensor.from_row_splits(
            values=[3, 21, 4, 1, 53, 9, 2, 6], row_splits=[0, 4, 4, 7, 8, 8])
        self.assertAllEqual(model(sample), sample * 2)

        # Round-trip through the config dictionary.
        config = model.get_config()
        model = functional.Functional.from_config(config)
        self.assertAllEqual(model(sample), sample * 2)

        # Round-trip through JSON.
        json_text = model.to_json()
        model = model_config.model_from_json(json_text)
        self.assertAllEqual(model(sample), sample * 2)