def testCompositeTypeSpecArgWithoutDtype(self):
        """Input(type_spec=...) works with a `TwoTensorsSpecNoOneDtype` spec.

        For both values of ``assign_variant_dtype``, builds a model whose
        output is ``lambda_fn`` applied to the composite input. It checks
        that the model matches calling ``lambda_fn`` directly, then checks
        again after a config round trip and a JSON round trip.
        """
        for assign_variant_dtype in [False, True]:
            # Create a Keras Input from a two-component spec:
            # x is float32 (1, 2, 3), y is int64 (1, 2, 3).
            spec = TwoTensorsSpecNoOneDtype(
                (1, 2, 3),
                dtypes.float32, (1, 2, 3),
                dtypes.int64,
                assign_variant_dtype=assign_variant_dtype)
            x = input_layer_lib.Input(type_spec=spec)

            def lambda_fn(tensors):
                # Cast both components to a common dtype before adding.
                return (math_ops.cast(tensors.x, dtypes.float64) +
                        math_ops.cast(tensors.y, dtypes.float64))

            # Verify you can construct and use a model w/ this input
            model = functional.Functional(x, core.Lambda(lambda_fn)(x))

            # And that the model works on values matching the spec.
            # `ones` takes (shape, dtype, name), so the shape must be passed
            # as a tuple: `ones(1, 2, 3)` would bind 2 and 3 to dtype and
            # name and build a shape-(1,) tensor instead.
            two_tensors = TwoTensors(
                array_ops.ones((1, 2, 3)) * 2.0,
                array_ops.ones((1, 2, 3), dtype=dtypes.int64))
            self.assertAllEqual(model(two_tensors), lambda_fn(two_tensors))

            # Test serialization / deserialization
            model = functional.Functional.from_config(model.get_config())
            self.assertAllEqual(model(two_tensors), lambda_fn(two_tensors))
            model = model_config.model_from_json(model.to_json())
            self.assertAllEqual(model(two_tensors), lambda_fn(two_tensors))
    def testInputTensorArg(self):
        """Input(tensor=...) takes the tensor's static shape and feeds a model."""
        source = array_ops.zeros((7, 32))
        inp = input_layer_lib.Input(tensor=source)
        self.assertAllEqual(inp.shape.as_list(), [7, 32])

        # A trivial functional model that doubles its input.
        doubler = functional.Functional(inp, inp * 2.0)
        probe = array_ops.ones(inp.shape)
        self.assertAllEqual(doubler(probe), array_ops.ones(inp.shape) * 2.0)
    def testBasicOutputShapeWithBatchSize(self):
        """Input(batch_size=..., shape=...) fixes the batch dimension statically.

        Checks that the symbolic shape is ``[6, 32]``. Also checks that a model
        which doubles its input gives the expected result on a tensor of that
        shape.
        """
        # Create a Keras Input with an explicit batch size.
        x = input_layer_lib.Input(batch_size=6, shape=(32, ), name='input_b')
        self.assertAllEqual(x.shape.as_list(), [6, 32])

        # Verify you can construct and use a model w/ this input
        model = functional.Functional(x, x * 2.0)
        self.assertAllEqual(model(array_ops.ones(x.shape)),
                            array_ops.ones(x.shape) * 2.0)
    def testCompositeInputTensorArg(self):
        """Input(tensor=<RaggedTensor>) builds a model that handles ragged data."""
        template = ragged_tensor.RaggedTensor.from_row_splits(
            values=[3, 1, 4, 1, 5, 9, 2, 6], row_splits=[0, 4, 4, 7, 8, 8])
        ragged_in = input_layer_lib.Input(tensor=template)

        # A model that doubles every ragged value.
        doubler = functional.Functional(ragged_in, ragged_in * 2)

        # Feed different values that share the template's row partitioning.
        sample = ragged_tensor.RaggedTensor.from_row_splits(
            values=[3, 21, 4, 1, 53, 9, 2, 6], row_splits=[0, 4, 4, 7, 8, 8])
        expected = sample * 2
        self.assertAllEqual(doubler(sample), expected)
    def testCompositeTypeSpecArg(self):
        """Input(type_spec=<ragged spec>) builds a serializable ragged model.

        The model doubles its ragged input. The output is checked on fresh
        values, then re-checked after a config round trip followed by a JSON
        round trip.
        """
        template = ragged_tensor.RaggedTensor.from_row_splits(
            values=[3, 1, 4, 1, 5, 9, 2, 6], row_splits=[0, 4, 4, 7, 8, 8])
        # Only the spec of the template is used to define the Input.
        ragged_in = input_layer_lib.Input(type_spec=template._type_spec)

        net = functional.Functional(ragged_in, ragged_in * 2)

        sample = ragged_tensor.RaggedTensor.from_row_splits(
            values=[3, 21, 4, 1, 53, 9, 2, 6], row_splits=[0, 4, 4, 7, 8, 8])
        self.assertAllEqual(net(sample), sample * 2)

        # Each step rebuilds from the previous step's model.
        round_trips = (
            lambda m: functional.Functional.from_config(m.get_config()),
            lambda m: model_config.model_from_json(m.to_json()),
        )
        for rebuild in round_trips:
            net = rebuild(net)
            self.assertAllEqual(net(sample), sample * 2)
    def testTypeSpecArg(self):
        """Input(type_spec=TensorSpec) yields a model that survives round trips.

        Checks the static shape, then checks that a doubling model gives the
        expected result when freshly built, after a config round trip, and
        after a JSON round trip.
        """
        spec = tensor_spec.TensorSpec((7, 32), dtypes.float32)
        inp = input_layer_lib.Input(type_spec=spec)
        self.assertAllEqual(inp.shape.as_list(), [7, 32])

        def check_doubles(net):
            # The model must map an all-ones tensor to all-twos.
            self.assertAllEqual(net(array_ops.ones(inp.shape)),
                                array_ops.ones(inp.shape) * 2.0)

        net = functional.Functional(inp, inp * 2.0)
        check_doubles(net)

        # Config round trip, then JSON round trip of the rebuilt model.
        net = functional.Functional.from_config(net.get_config())
        check_doubles(net)

        net = model_config.model_from_json(net.to_json())
        check_doubles(net)