'x': tf.TensorSpec([4], tf.float32),
        'y': tf.TensorSpec([5], tf.float32),
        'z': tf.TensorSpec([6], tf.float32),
        'a': tf.TensorSpec([1], tf.float32),
        'b': tf.TensorSpec([2], tf.float32),
        'c': tf.TensorSpec([3], tf.float32),
    }])
    def f0004_dict_many_keys(self, d):
        return

    # Check a slightly more complex recursive structure.
    # Note that list elements can have heterogeneous types (here, tensors of different shapes).
    #
    # CHECK:      func {{@[a-zA-Z_0-9]+}}(
    # CHECK-SAME:   %arg0: tensor<1xf32> {tf._user_specified_name = "d", tf_saved_model.index_path = [0, "x", 0]},
    # CHECK-SAME:   %arg1: tensor<2xf32> {tf._user_specified_name = "d", tf_saved_model.index_path = [0, "x", 1]},
    # CHECK-SAME:   %arg2: tensor<3xf32> {tf._user_specified_name = "d", tf_saved_model.index_path = [0, "y"]})
    # CHECK-SAME: attributes {{.*}} tf_saved_model.exported_names = ["f0005_more_complex_recursive_structure"]
    @tf.function(input_signature=[{
        'x': [
            tf.TensorSpec([1], tf.float32),
            tf.TensorSpec([2], tf.float32),
        ],
        'y': tf.TensorSpec([3], tf.float32),
    }])
    def f0005_more_complex_recursive_structure(self, d):
        # Intentionally empty: the test only inspects how the nested dict/list
        # argument `d` is flattened into individual function arguments.
        return


if __name__ == '__main__':
    # Entry point: hand this file's module class to the shared
    # `common.do_test` tf_saved_model test driver.
    common.do_test(TestModule)
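# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#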
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# RUN: %p/partially_shaped_variables | FileCheck %s

# pylint: disable=missing-docstring,line-too-long
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow.compat.v2 as tf
from tensorflow.compiler.mlir.tensorflow.tests.tf_saved_model import common


class TestModule(tf.Module):
    """Holds two variables whose declared shapes are only partially known."""

    def __init__(self):
        super(TestModule, self).__init__()
        # CHECK: "tf_saved_model.global_tensor"() {is_mutable, {{.*}} tf_saved_model.exported_names = ["v0"], type = tensor<*xf32>, value = dense<0.000000e+00> : tensor<1xf32>} : () -> ()
        # CHECK: "tf_saved_model.global_tensor"() {is_mutable, {{.*}} tf_saved_model.exported_names = ["v1"], type = tensor<?xf32>, value = dense<[0.000000e+00, 1.000000e+00]> : tensor<2xf32>} : () -> ()
        # v0 is declared with an unknown rank (expected type tensor<*xf32>),
        # even though its initial value is a 1-element vector.
        unknown_rank = tf.TensorShape(None)
        self.v0 = tf.Variable([0.0], shape=unknown_rank)
        # v1 is declared as rank 1 with an unknown length (tensor<?xf32>),
        # initialized from a 2-element vector.
        unknown_length = tf.TensorShape([None])
        self.v1 = tf.Variable([0.0, 1.0], shape=unknown_length)


if __name__ == '__main__':
    # The module defines no tf.functions, so nothing is exported by name;
    # the FileCheck patterns above only look at the global tensors it holds.
    common.do_test(TestModule, exported_names=[])
# ---- Example #3 ----
# Separator left behind when these test files were concatenated. It used to
# be a bare non-ASCII identifier plus a bare `0`; the identifier is undefined
# and raised NameError whenever this module was executed or imported.
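# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#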
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# RUN: %p/debug_info | FileCheck %s

# pylint: disable=missing-docstring,line-too-long
import tensorflow.compat.v2 as tf
from tensorflow.compiler.mlir.tensorflow.tests.tf_saved_model import common


class TestModule(tf.Module):
    """Exports one scalar addition so its recorded source location can be inspected."""

    @tf.function(input_signature=[
        tf.TensorSpec([], tf.float32),
        tf.TensorSpec([], tf.float32),
    ])
    def some_function(self, x, y):
        # Basic check that debug info is correctly saved and loaded: the AddV2
        # op produced by the addition below must carry a call-site location
        # that points into debug_info.py (line and column are wildcards).
        #
        # CHECK: "tf.AddV2"{{.*}}loc(#[[LOC:.*]])
        # CHECK: #[[LOC]] = loc({{.*}}callsite("{{[^"]*}}/debug_info.py{{.*}}":{{[0-9]+}}:{{[0-9]+}}
        total = x + y
        return total


if __name__ == '__main__':
    # show_debug_info=True so the emitted MLIR includes the loc(...) attributes
    # that the FileCheck patterns in `some_function` match against.
    common.do_test(TestModule, show_debug_info=True)
# ---- Example #4 ----
# Separator left behind when these test files were concatenated. It used to
# be a bare non-ASCII identifier plus a bare `0`; the identifier is undefined
# and raised NameError whenever this module was executed or imported.
import tensorflow.compat.v2 as tf
from tensorflow.compiler.mlir.tensorflow.tests.tf_saved_model import common


def mnist_model():
    """Build the small Keras classifier wrapped by `TestModule`.

    The layer stack is: flatten the input, a 128-unit dense layer with ReLU,
    then a 10-unit dense layer with softmax. No input shape is declared here;
    `TestModule.my_predict` feeds it 1x28x28x1 float32 tensors.

    Returns:
        A `tf.keras.models.Sequential` model containing those three layers.
    """
    keras_layers = tf.keras.layers
    return tf.keras.models.Sequential([
        keras_layers.Flatten(),
        keras_layers.Dense(128, activation='relu'),
        keras_layers.Dense(10, activation='softmax'),
    ])


class TestModule(tf.Module):
    """Wraps the MNIST-style Keras model behind one exported prediction function."""

    def __init__(self):
        super(TestModule, self).__init__()
        self.model = mnist_model()

    # CHECK: func {{@[a-zA-Z_0-9]+}}(%arg0: tensor<1x28x28x1xf32> {tf_saved_model.index_path = [0]}
    # CHECK: attributes {{.*}} tf_saved_model.exported_names = ["my_predict"]
    @tf.function(input_signature=[
        tf.TensorSpec([1, 28, 28, 1], tf.float32),
    ])
    def my_predict(self, x):
        # Run a batch of one 28x28x1 float32 input through the wrapped model.
        predictions = self.model(x)
        return predictions


if __name__ == '__main__':
    # Run the shared test driver; `exported_names` limits the exported
    # functions to `my_predict`, the name the FileCheck patterns above expect.
    common.do_test(TestModule, exported_names=['my_predict'])