コード例 #1
0
 def testListWrapperBasic(self):
     """Checks that ListWrapper compares exactly like a built-in list.

     ListWrapper (unlike List) is what plain lists are automatically replaced
     with, so equality, ordering, concatenation and isinstance checks must all
     match the built-in `list` type.
     """
     wrap = data_structures.ListWrapper
     first = autotrackable.AutoTrackable()
     second = autotrackable.AutoTrackable()

     # Equality: plain lists, wrapped lists, and mixed in both directions.
     self.assertEqual([first, first], [first, first])
     self.assertEqual(wrap([first, first]), wrap([first, first]))
     self.assertEqual([first, first], wrap([first, first]))
     self.assertEqual(wrap([first, first]), [first, first])

     # Inequality.
     self.assertNotEqual([first, first], [second, first])
     self.assertNotEqual(wrap([first, first]), wrap([second, first]))
     self.assertNotEqual([first, first], wrap([second, first]))

     # Ordering is lexicographic, so a strict prefix compares as smaller.
     self.assertLess([first], [first, second])
     self.assertLess(wrap([first]), wrap([first, second]))
     self.assertLessEqual([first], [first, second])
     self.assertLessEqual(wrap([first]), wrap([first, second]))
     self.assertGreater([first, second], [first])
     self.assertGreater(wrap([first, second]), wrap([first]))
     self.assertGreaterEqual([first, second], [first])
     self.assertGreaterEqual(wrap([first, second]), wrap([first]))

     # Conversions, concatenation and type identity.
     self.assertEqual([first], wrap([first]))
     self.assertEqual([first], list(data_structures.List([first])))
     self.assertEqual([first, first], wrap([first]) + [first])
     self.assertEqual([first, first], [first] + wrap([first]))
     self.assertIsInstance(wrap([first]), list)

     # Adding a TensorShape to a wrapped list must yield something with
     # `as_list()` whose value equals the concatenated shape.
     expected_dims = tensor_shape.TensorShape([None, 2]).as_list()
     combined = wrap([None]) + tensor_shape.TensorShape([2])
     self.assertEqual(expected_dims, combined.as_list())
コード例 #2
0
 def testShallowCopyTrackable(self):
     """copy.copy of an AutoTrackable is a new object that shares its children."""
     source = autotrackable.AutoTrackable()
     child = autotrackable.AutoTrackable()
     source.a = [[1.]]
     source.b = {"a": child}
     clone = copy.copy(source)

     # Shallow copy: the nested child is the same object, the root is not.
     self.assertIs(child, clone.b["a"])
     self.assertIsNot(source, clone)
     self.assertEqual([[1.]], clone.a)

     # Every attribute of the copy is reachable via util.list_objects.
     dependencies = util.list_objects(clone)
     for tracked in (clone.a, clone.b, clone.b["a"]):
         self.assertIn(tracked, dependencies)
コード例 #3
0
ファイル: save_test.py プロジェクト: wwjiang007/tensorflow
    def test_method_save_annotated_function(self):
        """Saves and reloads a tf.function whose annotations name a foreign type.

        The argument and return annotations reference a locally defined class
        with no TensorFlow meaning. Saving with an explicit signature key must
        still work, and the reloaded signature must compute {"out": 2 * z}.
        """
        root = autotrackable.AutoTrackable()

        class UnknownType(object):
            pass

        def annotated_function(z):
            return {"out": 2. * z}

        # Equivalent to declaring the function as
        #   def annotated_function(z: UnknownType) -> UnknownType:
        # The annotations are assigned explicitly so that their values are the
        # class object itself, even if the module uses
        # `from __future__ import annotations` (which would turn inline
        # annotations into strings).
        annotated_function.__annotations__ = {
            "z": UnknownType,
            "return": UnknownType
        }

        root.f = def_function.function(annotated_function)
        root.f(constant_op.constant(1.))
        save_dir = os.path.join(self.get_temp_dir(), "saved_model")
        save.save(
            root, save_dir, {
                "non_default_key":
                root.f.get_concrete_function(
                    tensor_spec.TensorSpec(None, dtypes.float32))
            })
        self.assertEqual({"out": 2.},
                         _import_and_infer(save_dir, {"z": 1.},
                                           signature_key="non_default_key"))
コード例 #4
0
    def testL2LossOp(self, tf_quantization_mode):
        """Converts an l2_loss function to TFLite and runs it in the interpreter.

        Args:
          tf_quantization_mode: value stored in the converter's
            `_experimental_tf_quantization_mode` before conversion.
        """
        root = autotrackable.AutoTrackable()
        root.l2_loss_func = def_function.function(lambda x: nn_ops.l2_loss(x))  # pylint: disable=unnecessary-lambda
        trace_input = tf.range(4, dtype=tf.float32)
        concrete_func = root.l2_loss_func.get_concrete_function(trace_input)

        converter = lite.TFLiteConverterV2.from_concrete_functions(
            [concrete_func], root)
        converter._experimental_tf_quantization_mode = tf_quantization_mode
        tflite_model = converter.convert()
        self.assertTrue(tflite_model)
        # The converted model is expected to keep L2Loss as a Flex op.
        ops_in_model = tflite_test_util.get_ops_list(tflite_model)
        self.assertIn('FlexL2Loss', ops_in_model)

        # Execute: l2_loss([1, 2, 3, 4]) == (1 + 4 + 9 + 16) / 2 == 15.
        interpreter = Interpreter(model_content=tflite_model)
        interpreter.allocate_tensors()
        input_index = interpreter.get_input_details()[0]['index']
        interpreter.set_tensor(
            input_index, np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32))
        interpreter.invoke()

        output_index = interpreter.get_output_details()[0]['index']
        expected = np.array([15.0], dtype=np.float32)
        actual = interpreter.get_tensor(output_index)
        self.assertTrue(np.all(expected == actual))
コード例 #5
0
    def testAddOp(self, tf_quantization_mode):
        """Converts `x + x` to TFLite and checks the emitted op and its output.

        With 'LEGACY_INTEGER' quantization mode the builtin ADD op is expected;
        in every other mode the add is expected to stay a Flex op (FlexAddV2).

        Args:
          tf_quantization_mode: value stored in the converter's
            `_experimental_tf_quantization_mode` before conversion.
        """
        root = autotrackable.AutoTrackable()
        root.add_func = def_function.function(lambda x: x + x)
        trace_input = tf.reshape(tf.range(4, dtype=tf.float32), [1, 4])
        concrete_func = root.add_func.get_concrete_function(trace_input)

        converter = lite.TFLiteConverterV2.from_concrete_functions(
            [concrete_func], root)
        converter._experimental_tf_quantization_mode = tf_quantization_mode
        tflite_model = converter.convert()
        self.assertTrue(tflite_model)
        expected_op = ('ADD' if tf_quantization_mode == 'LEGACY_INTEGER'
                       else 'FlexAddV2')
        self.assertIn(expected_op, tflite_test_util.get_ops_list(tflite_model))

        # Execute: doubling [1, 2, 3, 4] gives [2, 4, 6, 8].
        interpreter = Interpreter(model_content=tflite_model)
        interpreter.allocate_tensors()
        input_index = interpreter.get_input_details()[0]['index']
        interpreter.set_tensor(
            input_index, np.array([[1.0, 2.0, 3.0, 4.0]], dtype=np.float32))
        interpreter.invoke()

        output_index = interpreter.get_output_details()[0]['index']
        expected = np.array([[2.0, 4.0, 6.0, 8.0]], dtype=np.float32)
        actual = interpreter.get_tensor(output_index)
        self.assertTrue(np.all(expected == actual))
コード例 #6
0
    def testFloat(self, enable_mlir):
        """Converts v1 * v2 * x using only SELECT_TF_OPS and checks the result.

        Args:
          enable_mlir: value assigned to `experimental_new_converter`.
        """
        trace_input = constant_op.constant(1., shape=[1])
        root = autotrackable.AutoTrackable()
        root.v1 = variables.Variable(3.)
        root.v2 = variables.Variable(2.)
        root.f = def_function.function(lambda x: root.v1 * root.v2 * x)
        concrete_func = root.f.get_concrete_function(trace_input)

        converter = lite.TFLiteConverterV2.from_concrete_functions(
            [concrete_func], root)
        converter.target_spec.supported_ops = {lite.OpsSet.SELECT_TF_OPS}
        converter.experimental_new_converter = enable_mlir
        tflite_model = converter.convert()

        # Execute with TensorFlow ops: 3 * 2 * 4 == 24.
        interpreter = Interpreter(model_content=tflite_model)
        interpreter.allocate_tensors()
        input_index = interpreter.get_input_details()[0]['index']
        interpreter.set_tensor(input_index,
                               np.array([4.0], dtype=np.float32))
        interpreter.invoke()

        output_index = interpreter.get_output_details()[0]['index']
        expected = np.array([24.0], dtype=np.float32)
        actual = interpreter.get_tensor(output_index)
        self.assertTrue(np.all(expected == actual))
コード例 #7
0
ファイル: save_test.py プロジェクト: wwjiang007/tensorflow
 def test_nested_outputs(self):
     """Saving a signature whose output is a nested tuple raises ValueError."""
     root = autotrackable.AutoTrackable()
     root.f = def_function.function(lambda x: (2. * x, (3. * x, 4. * x)))
     one = constant_op.constant(1.)
     root.f(one)
     signature = root.f.get_concrete_function(one)
     export_dir = os.path.join(self.get_temp_dir(), "saved_model")
     with self.assertRaisesRegex(ValueError, "non-Tensor value"):
         save.save(root, export_dir, signature)
コード例 #8
0
ファイル: save_test.py プロジェクト: wwjiang007/tensorflow
 def test_non_concrete_error(self):
     """Passing the tf.function itself (not a concrete function) as signatures fails."""
     root = autotrackable.AutoTrackable()
     root.f = def_function.function(lambda x: 2. * x)
     root.f(constant_op.constant(1.))
     export_dir = os.path.join(self.get_temp_dir(), "saved_model")
     expected_message = "Expected a TensorFlow function"
     with self.assertRaisesRegex(ValueError, expected_message):
         save.save(root, export_dir, root.f)
コード例 #9
0
 def testDeepCopyTrackable(self):
     """copy.deepcopy of an AutoTrackable duplicates its nested objects too."""
     source = autotrackable.AutoTrackable()
     child = autotrackable.AutoTrackable()
     source.a = [[1.]]
     source.b = {"a": child}
     # The assigned dict still passes an isinstance(dict) check.
     self.assertIsInstance(source.b, dict)

     clone = copy.deepcopy(source)
     self.assertIsInstance(clone.b, dict)
     self.assertIsNot(source, clone)
     self.assertIsNot(child, clone.b["a"])
     self.assertEqual([[1.]], clone.a)
     self.assertIsInstance(clone.b["a"], autotrackable.AutoTrackable)

     # The copy's dependencies include its own children, not the original child.
     dependencies = util.list_objects(clone)
     for tracked in (clone.a, clone.b, clone.b["a"]):
         self.assertIn(tracked, dependencies)
     self.assertNotIn(child, dependencies)
コード例 #10
0
ファイル: save_test.py プロジェクト: wwjiang007/tensorflow
 def test_nested_inputs(self):
     """A tf.function whose input signature is a list of specs accepts a list.

     Only checks that the call succeeds; nothing is saved here.
     """
     root = autotrackable.AutoTrackable()
     nested_signature = ([
         tensor_spec.TensorSpec(None, dtypes.float32),
         tensor_spec.TensorSpec(None, dtypes.float32),
     ],)
     root.f = def_function.function(
         lambda x: 2. * x[0], input_signature=nested_signature)
     root.f([constant_op.constant(1.), constant_op.constant(1.)])
コード例 #11
0
ファイル: save_test.py プロジェクト: wwjiang007/tensorflow
 def test_variable(self):
     """A saved signature that reads two variables computes v1 * v2 * x on reload."""
     root = autotrackable.AutoTrackable()
     root.v1 = variables.Variable(3.)
     root.v2 = variables.Variable(2.)
     root.f = def_function.function(lambda x: root.v1 * root.v2 * x)
     root.f(constant_op.constant(1.))
     signature = root.f.get_concrete_function(constant_op.constant(1.))
     export_dir = os.path.join(self.get_temp_dir(), "saved_model")
     save.save(root, export_dir, signature)
     # 3 * 2 * 2 == 12
     reloaded_output = _import_and_infer(export_dir, {"x": 2.})
     self.assertAllEqual({"output_0": 12.}, reloaded_output)
コード例 #12
0
ファイル: save_test.py プロジェクト: wwjiang007/tensorflow
 def test_version_information_included(self):
     """The exported MetaGraph records the TensorFlow version and git version."""
     root = autotrackable.AutoTrackable()
     export_dir = os.path.join(self.get_temp_dir(), "saved_model")
     save.save(root, export_dir)
     meta_info = loader_impl.parse_saved_model(
         export_dir).meta_graphs[0].meta_info_def
     self.assertEqual(versions.__version__, meta_info.tensorflow_version)
     self.assertEqual(versions.__git_version__,
                      meta_info.tensorflow_git_version)
コード例 #13
0
ファイル: save_test.py プロジェクト: wwjiang007/tensorflow
    def test_unused_asset(self):
        """Tracking an Asset the function never reads still saves and reloads."""
        root = autotrackable.AutoTrackable()
        scalar_spec = tensor_spec.TensorSpec(None, dtypes.float32)
        root.f = def_function.function(
            lambda x: 2. * x, input_signature=[scalar_spec])
        root.asset = asset.Asset(self._vocab_path)

        export_dir = os.path.join(self.get_temp_dir(), "save_dir")
        save.save(root, export_dir)
        reloaded_output = _import_and_infer(export_dir, {"x": [0.1]})
        self.assertAllClose({"output_0": [0.2]}, reloaded_output)
コード例 #14
0
  def testSavedModel(self):
    """Saves a tiny SavedModel and expects `self._run` to succeed on it.

    `self._run` is defined elsewhere on the test class; it is invoked with
    the `--saved_model_dir` flag pointing at the freshly saved model.
    """
    trace_input = constant_op.constant(1., shape=[1])
    root = autotrackable.AutoTrackable()
    root.f = def_function.function(lambda x: 2. * x)
    signature = root.f.get_concrete_function(trace_input)

    model_dir = self._getFilepath('model')
    save(root, model_dir, signature)

    self._run(f'--saved_model_dir={model_dir}', should_succeed=True)