Example #1
0
    def test_function_with_default_bool_input(self, cycles):
        """Exercises a function with a defaulted bool argument around `self.cycle`.

        The wrapped function is called with `training=True` and with the
        default value before cycling, and again on the object that
        `self.cycle(root, cycles)` returns.
        """
        def func(x, training=False):
            if not training:
                return 7
            return 2 * x

        trackable = tracking.AutoTrackable()
        trackable.f = def_function.function(func)

        # Calls on the original object, before it is cycled.
        self.assertEqual(
            20, trackable.f(constant_op.constant(10), True).numpy())
        self.assertEqual(7, trackable.f(constant_op.constant(1)).numpy())
        self.assertEqual(
            2, trackable.f(constant_op.constant(1), True).numpy())

        reloaded = self.cycle(trackable, cycles)

        # The same two call shapes on the cycled object.
        self.assertEqual(
            4, reloaded.f(constant_op.constant(2), True).numpy())
        self.assertEqual(7, reloaded.f(constant_op.constant(2)).numpy())
Example #2
0
  def test_prefer_specific_trace(self, cycles):
    """Calls a function with Python ints and tensors before and after cycling.

    The function returns a Python int unchanged and adds one to anything else;
    the expected values below encode which path each call is expected to take.
    """
    @def_function.function(autograph=False)
    def func(a):
      return a if isinstance(a, int) else a + 1

    self.assertAllEqual(2, func(2).numpy())
    self.assertAllEqual(3, func(constant_op.constant(2)).numpy())

    trackable = tracking.AutoTrackable()
    trackable.f = func
    reloaded = self.cycle(trackable, cycles)

    # Python int arguments.
    self.assertAllEqual(2, reloaded.f(2).numpy())
    self.assertAllEqual(4, reloaded.f(3).numpy())
    # Tensor arguments.
    self.assertAllEqual(3, reloaded.f(constant_op.constant(2)).numpy())
    self.assertAllEqual(4, reloaded.f(constant_op.constant(3)).numpy())
  def _create_unsupported_saved_model(self):
    """Saves a model whose exported function calls `tf.linalg.diag`.

    The model is written to `SAVED_MODEL_DIR` under `self._tmp_dir`, with a
    concrete function traced for a scalar float32 input.
    """
    model = tracking.AutoTrackable()
    model.w = variables.Variable(tf.random.uniform([2, 2]))

    @def_function.function
    def exported_function(x):
      model.x = constant_op.constant([[37.0, -23.0], [1.0, 4.0]])
      model.y = tf.matmul(model.x, model.w)
      # unsupported op: linalg.diag
      model.z = tf.linalg.diag(model.y)
      return model.z * x

    model.f = exported_function
    scalar_spec = tensor_spec.TensorSpec([], dtypes.float32)
    concrete_fn = model.f.get_concrete_function(scalar_spec)

    export_path = os.path.join(self._tmp_dir, SAVED_MODEL_DIR)
    save(model, export_path, concrete_fn)
    def test_non_strict_predicate(self):
        """Reads a checkpoint from a registered saver into a non-matching object.

        The saver is registered with `strict_predicate_restore=False`; reading
        the checkpoint into an object that fails the predicate is expected not
        to raise.
        """
        class NonStrictPredicateClass(tracking.AutoTrackable):
            pass

        def _is_registered_type(x):
            return isinstance(x, NonStrictPredicateClass)

        def _save_nothing(**kwargs):
            return []

        def _restore_nothing(**kwargs):
            return None

        registration.register_checkpoint_saver(
            name="NonStrictPredicate",
            predicate=_is_registered_type,
            save_fn=_save_nothing,
            restore_fn=_restore_nothing,
            strict_predicate_restore=False)

        ckpt_path = os.path.join(self.get_temp_dir(), "ckpt")
        writer_root = NonStrictPredicateClass()
        util.Checkpoint(writer_root).write(ckpt_path)

        # This should run without throwing an error.
        reader_root = tracking.AutoTrackable()
        util.Checkpoint(reader_root).read(ckpt_path)
Example #5
0
    def test_captures_unreachable_variable(self):
        """Expects saving to fail when a function captures an untracked variable.

        `increase_variable` reads one variable attached to `root` and one that
        is only a local of this test. Saving `root` is expected to raise a
        `KeyError` whose message matches "not reachable from root".
        """
        root = tracking.AutoTrackable()
        # Deliberately never attached to `root`.
        unreachable_variable = variables.Variable([5.0, 2.0])
        root.reachable_variable = variables.Variable([1.0, 3.0])

        @def_function.function
        def increase_variable(x):
            return 2 * unreachable_variable * x + root.reachable_variable

        root.f = increase_variable

        # 2 * [5, 2] * [10, 20] + [1, 3] == [101, 83]
        self.assertAllEqual([101.0, 83.0],
                            root.f(constant_op.constant([10.0, 20.0])).numpy())

        save_dir = os.path.join(self.get_temp_dir(), "saved_model")

        # `assertRaisesRegexp` is a deprecated alias that was removed in
        # Python 3.12; use the supported spelling.
        with self.assertRaisesRegex(KeyError, "not reachable from root"):
            save.save(root, save_dir)
    def test_strict_predicate(self):
        """Expects reading into a non-matching object to fail for a strict saver.

        The saver is registered with `strict_predicate_restore=True`; reading
        the checkpoint into an object that fails the predicate is expected to
        raise a `ValueError` matching "saver cannot be used".
        """
        class StrictPredicateClass(tracking.AutoTrackable):
            pass

        def _is_registered_type(x):
            return isinstance(x, StrictPredicateClass)

        def _save_nothing(**kwargs):
            return []

        def _restore_nothing(**kwargs):
            return None

        registration.register_checkpoint_saver(
            name="StrictPredicate",
            predicate=_is_registered_type,
            save_fn=_save_nothing,
            restore_fn=_restore_nothing,
            strict_predicate_restore=True)

        ckpt_path = os.path.join(self.get_temp_dir(), "ckpt")
        writer_root = StrictPredicateClass()
        util.Checkpoint(writer_root).write(ckpt_path)

        reader_root = tracking.AutoTrackable()
        with self.assertRaisesRegex(ValueError, "saver cannot be used"):
            util.Checkpoint(reader_root).read(ckpt_path)
Example #7
0
  def test_partial_with_passed_fn_as_default(self, cycles):
    """Skipped: wraps a partial that binds a Python callable as first argument.

    The test body after `skipTest` is kept for when the TODO below is resolved.
    """
    # TODO(b/124441704): Figure out the story for FunctionSpec vs partial.
    self.skipTest("Partial does not work for serialization.")

    def f(x, y):
      return x(3) + y

    def my_func(a):
      return 2 * a

    bound = functools.partial(f, my_func)
    wrapped = def_function.function(bound)

    trackable = tracking.AutoTrackable()
    trackable.f = wrapped
    # my_func(3) + 3 == 9
    self.assertEqual(trackable.f(constant_op.constant(3)).numpy(), 9)

    reloaded = self.cycle(trackable, cycles)
    self.assertEqual(reloaded.f(constant_op.constant(3)).numpy(), 9)
Example #8
0
    def test_method_save_list_func(self):
        """Saves a function that calls `switch_case` with a list of branches.

        The saved function doubles the result of `case_fn`; the exported model
        is then run through `_import_and_infer` with `x = 1.`.
        """
        @def_function.function
        def case_fn(x):
            branch_index = constant_op.constant(1)
            branches = [lambda: x, lambda: x + 1]
            return control_flow_ops.switch_case(branch_index, branches)

        model = tracking.AutoTrackable()
        float_spec = tensor_spec.TensorSpec(None, dtypes.float32)
        model.f = def_function.function(
            lambda x: 2. * case_fn(x),
            input_signature=[float_spec])
        model.f(constant_op.constant(1.))

        export_path = os.path.join(self.get_temp_dir(), "saved_model")
        save.save(model, export_path, model.f)

        # Branch 1 is selected: 2 * (1 + 1) == 4.
        inferred = _import_and_infer(export_path, {"x": 1.})
        self.assertEqual({"output_0": 4.}, inferred)
Example #9
0
    def testConstModel(self):
        """Converts a saved variable-free function and compares its output."""
        sample = constant_op.constant(1., shape=[1])
        model = tracking.AutoTrackable()
        model.f = def_function.function(lambda x: 2. * x)
        concrete_fn = model.f.get_concrete_function(sample)

        export_path = os.path.join(self.get_temp_dir(), 'saved_model')
        save(model, export_path, concrete_fn)

        # Convert the saved model.
        tflite_model = lite.TFLiteConverterV2.from_saved_model(
            export_path).convert()

        # The converted model should produce what the original function does.
        expected = model.f(sample)
        actual = self._evaluateTFLiteModel(tflite_model, [sample])
        self.assertEqual(expected.numpy(), actual)
Example #10
0
  def test_initialize_with_root_object_and_kwargs(self):
    """Checks `Checkpoint` built from a root object plus keyword attachments.

    A checkpoint is saved from `Checkpoint(model, separate_variable=...)` and
    restored into three differently shaped checkpoints. Each restore is
    expected to be fully consumed and to yield `v == 3`,
    `separate_variable == 5` and `save_counter == 1`.
    """
    model = self._create_trackable()
    model.v.assign(3.)
    separate_variable = variables_lib.Variable(5.)

    # A keyword that collides with an attribute of the root is rejected.
    with self.assertRaisesRegex(ValueError, "root.v already exists"):
      trackable_utils.Checkpoint(model, v=separate_variable)

    checkpoint = trackable_utils.Checkpoint(
        model, separate_variable=separate_variable)
    checkpoint_directory = self.get_temp_dir()
    checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
    save_path = checkpoint.save(checkpoint_prefix)

    # Case 1: Loading checkpoint with same configuration.
    new_model = self._create_trackable()
    separate_variable = variables_lib.Variable(1.)
    load_checkpoint = trackable_utils.Checkpoint(
        new_model, separate_variable=separate_variable)
    load_checkpoint.restore(save_path).assert_consumed()
    self.assertEqual(self.evaluate(new_model.v), 3)
    self.assertEqual(self.evaluate(separate_variable), 5)
    self.assertEqual(self.evaluate(load_checkpoint.save_counter), 1)

    # Case 2: Loading checkpoint where v and separate_variable are swapped:
    # v is not attached to the root, while separate variable is attached to root
    new_model = tracking.AutoTrackable()
    new_model.separate_variable = variables_lib.Variable(200.)
    v = variables_lib.Variable(100.)
    load_checkpoint = trackable_utils.Checkpoint(new_model, v=v)
    load_checkpoint.restore(save_path).assert_consumed()
    self.assertEqual(self.evaluate(v), 3)
    self.assertEqual(self.evaluate(new_model.separate_variable), 5)
    self.assertEqual(self.evaluate(load_checkpoint.save_counter), 1)

    # Case 3: Loading checkpoint where no root object is specified
    separate_variable = variables_lib.Variable(200.)
    v = variables_lib.Variable(100.)
    load_checkpoint = trackable_utils.Checkpoint(
        v=v, separate_variable=separate_variable)
    load_checkpoint.restore(save_path).assert_consumed()
    self.assertEqual(self.evaluate(v), 3)
    # Check the variable handed to this checkpoint. `new_model` still refers to
    # the Case 2 object, whose variable was already restored to 5 and would
    # make this assertion pass regardless of what Case 3 restored.
    self.assertEqual(self.evaluate(separate_variable), 5)
    self.assertEqual(self.evaluate(load_checkpoint.save_counter), 1)
Example #11
0
  def test_control_outputs(self, cycles):
    """Checks a named assign op is a control output before and after cycling."""
    op_name = "should_be_control_output"

    source = tracking.AutoTrackable()
    source.v = variables.Variable(1.)
    source.f = def_function.function(
        lambda: source.v.assign(2., name="should_be_control_output"))
    source_graph = source.f.get_concrete_function().graph
    self.assertIn(
        source_graph.get_operation_by_name(op_name),
        source_graph.control_outputs)

    reloaded = self.cycle(source, cycles)
    # Calling get_concrete_function wraps in a second call operation; we want to
    # inspect the original function body for the control output; digging into
    # graph.as_graph_def() and its FunctionDefLibrary is another option.
    [reloaded_concrete] = reloaded.f._concrete_functions
    reloaded_graph = reloaded_concrete.graph
    self.assertIn(
        reloaded_graph.get_operation_by_name(op_name),
        reloaded_graph.control_outputs)
Example #12
0
    def testTxtSignatureDefs(self):
        """Checks the analyzer's text output for a model with two signatures.

        A model exposing `add` and `sub` signatures is saved, converted with
        the TFLite converter, and analyzed while `sys.stdout` is patched so
        the printed report can be inspected.
        """
        with tempfile.TemporaryDirectory() as tmp_dir:

            @tf.function(input_signature=[
                tf.TensorSpec(shape=None, dtype=tf.float32),
                tf.TensorSpec(shape=None, dtype=tf.float32)
            ])
            def add(a, b):
                return {'add_result': tf.add(a, b)}

            @tf.function(input_signature=[
                tf.TensorSpec(shape=None, dtype=tf.float32),
                tf.TensorSpec(shape=None, dtype=tf.float32)
            ])
            def sub(x, y):
                return {'sub_result': tf.subtract(x, y)}

            model = tracking.AutoTrackable()
            model.f1 = add.get_concrete_function()
            model.f2 = sub.get_concrete_function()

            signature_map = {'add': model.f1, 'sub': model.f2}
            tf.saved_model.save(model, tmp_dir, signatures=signature_map)

            fb_model = tf.lite.TFLiteConverter.from_saved_model(
                tmp_dir).convert()

            # Capture what the analyzer prints.
            captured = io.StringIO()
            with test.mock.patch.object(sys, 'stdout', captured):
                analyzer.ModelAnalyzer.analyze(model_content=fb_model)
            txt = captured.getvalue()

            expected_fragments = (
                'Your TFLite model has ‘2’ signature_def(s).',
                "Signature#0 key: 'add'",
                "  'a' : T#1",
                "  'b' : T#0",
                "  'add_result' : T#2",
                "Signature#1 key: 'sub'",
                "  'x' : T#1_1",
                "  'y' : T#1_0",
                "  'sub_result' : T#1_2",
            )
            for fragment in expected_fragments:
                self.assertIn(fragment, txt)
Example #13
0
    def test_revived_concrete_function_tensorspec_kwargs(self, cycles):
        """Calls a concrete function by the names given in its TensorSpecs.

        The function takes `*args`; the keyword names `x` and `y` used in the
        calls come from the `name` arguments of the specs it was traced with.
        """
        @def_function.function
        def func(*args):
            x, y = args
            return x * (y + 1.)

        def _call_by_keyword(fn):
            return fn(y=constant_op.constant(3.),
                      x=constant_op.constant(2.)).numpy()

        trackable = tracking.AutoTrackable()
        trackable.f = func.get_concrete_function(
            tensor_spec.TensorSpec([], dtypes.float32, name="x"),
            tensor_spec.TensorSpec([], dtypes.float32, name="y"))
        # 2 * (3 + 1) == 8
        self.assertEqual(8., _call_by_keyword(trackable.f))

        reloaded = self.cycle(trackable, cycles, signatures={})
        self.assertEqual(8., _call_by_keyword(reloaded.f))
Example #14
0
    def test_partial(self, cycles):
        """Skipped: wraps a partial that binds both arguments as keywords.

        The test body after `skipTest` is kept for when the TODO below is
        resolved.
        """
        # TODO(vbardiovsky): Figure out the story for FunctionSpec vs partial vs
        # input_signature.
        self.skipTest("Partial does not work for serialization.")

        def f(x, y):
            return x + y

        bound = functools.partial(f,
                                  x=array_ops.zeros([1]),
                                  y=array_ops.zeros([1]))
        wrapped = def_function.function(bound)

        trackable = tracking.AutoTrackable()
        trackable.f = wrapped
        self.assertAllEqual(trackable.f(), [0.0])

        reloaded = self.cycle(trackable, cycles)
        self.assertAllEqual(reloaded.f(), [0.0])
Example #15
0
def _create_saved_model_v2_complex64(save_dir):
  """Saves a TF V2 model that uses a complex64 variable.

  Args:
    save_dir: directory name of where the saved model will be stored.

  Returns:
    A dict with "async", "inputs" and "outputs" entries describing an input
    value and the corresponding output value for the saved function.
  """
  sample = constant_op.constant(1., shape=[1])
  model = tracking.AutoTrackable()
  model.v1 = variables.Variable(3 + 1j, dtype=tf.complex64)
  model.f = def_function.function(lambda x: tf.complex(x, x) + model.v1)
  concrete_fn = model.f.get_concrete_function(sample)

  save(model, save_dir, concrete_fn)

  input_info = {"x": {"value": [1], "shape": [1], "dtype": 'float32'}}
  # complex(1, 1) + (3 + 1j) == 4 + 2j, listed as [real, imag].
  output_info = {
      "Identity:0": {"value": [4, 2], "shape": [1], "dtype": "complex64"}}
  return {
      "async": False,
      "inputs": input_info,
      "outputs": output_info,
  }
 def testCheckpointStateChangingVarList(self):
     """Saves ten times while attaching a new variable before every save.

     Afterwards it checks the files on disk, the recorded checkpoint state,
     and which variables are overwritten when restoring checkpoints 6 and 10.
     """
     checkpoint_directory = self.get_temp_dir()
     checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
     tracked = tracking.AutoTrackable()
     tracked.var = variable_scope.get_variable(name="v", initializer=0.)
     self.evaluate(trackable_utils.gather_initializers(tracked))
     checkpoint = trackable_utils.Checkpoint(obj=tracked)

     attached_variables = []
     for step in range(10):
         step_variable = resource_variable_ops.ResourceVariable(step)
         self.evaluate(step_variable.initializer)
         setattr(checkpoint, "var_%d" % step, step_variable)
         checkpoint.save(checkpoint_prefix)
         attached_variables.append(step_variable)

     # We've copied the saver each time, but checkpoint management should still
     # be consistent. Nothing gets deleted.
     expected_filenames = ["checkpoint"] + [
         "ckpt-%d.index" % (number, ) for number in range(1, 11)
     ]
     missing = set(expected_filenames) - set(os.listdir(checkpoint_directory))
     self.assertEmpty(missing)
     self.assertEqual(
         checkpoint_prefix + "-10",
         checkpoint_management.latest_checkpoint(checkpoint_directory))
     # The checkpoint list only contains the most recent checkpoint, but they're
     # all on disk. This means we won't eventually run into proto size limits.
     state = checkpoint_management.get_checkpoint_state(checkpoint_directory)
     self.assertEqual([checkpoint_prefix + "-10"],
                      state.all_model_checkpoint_paths)

     # Overwrite every variable so restored values are distinguishable.
     for variable in attached_variables:
         self.evaluate(variable.assign(314))

     checkpoint.restore(checkpoint_prefix + "-6").run_restore_ops()
     after_sixth = (
         (checkpoint.var_9, 314),
         (checkpoint.var_8, 314),
         (checkpoint.var_6, 314),
         (checkpoint.var_5, 5),
         (checkpoint.var_1, 1),
         (checkpoint.var_0, 0),
     )
     for variable, expected in after_sixth:
         self.assertEqual(expected, self.evaluate(variable))

     checkpoint.restore(checkpoint_prefix + "-10").run_restore_ops()
     after_tenth = (
         (checkpoint.var_9, 9),
         (checkpoint.var_8, 8),
         (checkpoint.var_1, 1),
         (checkpoint.var_0, 0),
     )
     for variable, expected in after_tenth:
         self.assertEqual(expected, self.evaluate(variable))
    def testConstSavedModel(self):
        """Converts the serving signature of a saved variable-free function.

        The loaded signature's graph is expected to contain no variables and
        a non-empty function library before conversion.
        """
        input_data = {"x": constant_op.constant(1., shape=[1])}
        model = tracking.AutoTrackable()
        model.f = def_function.function(lambda x: 2. * x)
        concrete_fn = model.f.get_concrete_function(input_data["x"])

        export_path = os.path.join(self.get_temp_dir(), "saved_model")
        save(model, export_path, concrete_fn)
        serving_fn = load(export_path).signatures["serving_default"]

        loaded_graph_def = serving_fn.graph.as_graph_def()
        self.assertEqual(0, self._getNumVariables(loaded_graph_def))
        self.assertTrue(loaded_graph_def.library.function)

        converted_fn = convert_to_constants.convert_variables_to_constants_v2(
            serving_fn)
        self._testConvertedFunction(model, model.f, converted_fn, input_data)
Example #18
0
  def test_weakref_root(self):
    """Checks a `Checkpoint` built from a weak reference to its root.

    The checkpoint is constructed with `weakref.ref(root)`, saved, and
    restored after the variable is modified. Once `root` is deleted, the
    reference count of the variable is inspected to confirm nothing besides
    `ref` still holds it.
    """
    root = tracking.AutoTrackable()
    root.v = variables_lib.Variable(1)
    # Keep a handle on the variable so it can be inspected after `del root`.
    ref = root.v.ref()

    ckpt = trackable_utils.Checkpoint(root=weakref.ref(root))
    save_path = ckpt.save(os.path.join(self.get_temp_dir(), "ckpt"))
    # Change the value after saving; restore is expected to bring back 1.
    root.v.assign(2)
    ckpt.restore(save_path)
    self.assertEqual(root.v.numpy(), 1)

    del root

    # Verifying if the variable is only referenced from `ref`.
    # We expect the reference counter to be 1, but `sys.getrefcount` reports
    # one higher reference counter because a temporary is created when we call
    # sys.getrefcount().  Hence check if the number returned is 2.
    # https://docs.python.org/3/library/sys.html#sys.getrefcount
    self.assertEqual(sys.getrefcount(ref.deref()), 2)
Example #19
0
    def test_shapes_available(self, cycles):
        """Checks input and output shapes on the graph of a cycled function."""
        x_spec = tensor_spec.TensorSpec([None, 3], dtypes.int32)
        y_spec = tensor_spec.TensorSpec([None, 2], dtypes.int32)

        @def_function.function(input_signature=[x_spec, y_spec])
        def func(x, y):
            return array_ops.concat([x, y], axis=1)

        trackable = tracking.AutoTrackable()
        trackable.f = func

        reloaded = self.cycle(trackable, cycles)

        reloaded_graph = reloaded.f.get_concrete_function().graph
        first_input, second_input = reloaded_graph.inputs
        self.assertEqual([None, 3], first_input.shape.as_list())
        self.assertEqual([None, 2], second_input.shape.as_list())
        # Concatenating along axis 1 gives 3 + 2 columns.
        [only_output] = reloaded_graph.outputs
        self.assertEqual([None, 5], only_output.shape.as_list())
Example #20
0
    def test_revived_concrete_function_kwargs(self, cycles):
        """Calls a concrete function with keyword arguments around `self.cycle`."""
        @def_function.function
        def func(x, y):
            return x * (y + 1.)

        def _call_by_keyword(fn):
            return fn(y=constant_op.constant(3.),
                      x=constant_op.constant(2.)).numpy()

        scalar_spec_x = tensor_spec.TensorSpec([], dtypes.float32)
        scalar_spec_y = tensor_spec.TensorSpec([], dtypes.float32)
        trackable = tracking.AutoTrackable()
        trackable.f = func.get_concrete_function(scalar_spec_x, scalar_spec_y)
        # 2 * (3 + 1) == 8
        self.assertEqual(8., _call_by_keyword(trackable.f))

        # TODO(andresp): Fix exporting of loaded concrete functions as signatures.
        reloaded = self.cycle(trackable, cycles, signatures={})
        self.assertEqual(8., _call_by_keyword(reloaded.f))
Example #21
0
    def test_save_variable_devices(self, save_devices, meta_graph_only):
        """Checks the device recorded on saved variable nodes.

        Two variables are created on "CPU:0" and "CPU:1" and exported either
        as a meta graph or as a full SavedModel. With the
        `SAVE_VARIABLE_DEVICES` policy the nodes are expected to carry their
        device; otherwise the device field is expected to be empty.
        """
        context._reset_context()
        cpus = context.context().list_physical_devices("CPU")
        if len(cpus) == 1:
            # Configure two logical devices on the only physical CPU.
            logical_cpus = [
                context.LogicalDeviceConfiguration(),
                context.LogicalDeviceConfiguration()
            ]
            context.context().set_logical_device_configuration(
                cpus[0], logical_cpus)
        context.ensure_initialized()

        model = tracking.AutoTrackable()
        with ops.device("CPU:0"):
            model.v0 = variables.Variable(1., name="v0")
        with ops.device("CPU:1"):
            model.v1 = variables.Variable(1., name="v1")

        options = save_options.SaveOptions(
            experimental_variable_policy=save_devices)
        export_path = os.path.join(self.get_temp_dir(), "saved_model")

        # Export, then read the graph def back with the matching reader.
        if meta_graph_only:
            save.export_meta_graph(obj=model,
                                   filename=export_path,
                                   options=options)
            graph_def = meta_graph.read_meta_graph_file(export_path).graph_def
        else:
            save.save(obj=model, export_dir=export_path, options=options)
            graph_def = loader_impl.parse_saved_model(
                export_path).meta_graphs[0].graph_def

        def _node_named(name):
            return next((n for n in graph_def.node if n.name == name), None)

        v0_node = _node_named("v0")
        v1_node = _node_named("v1")
        self.assertIsNotNone(v0_node)
        self.assertIsNotNone(v1_node)

        devices_saved = (
            save_devices == save_options.VariablePolicy.SAVE_VARIABLE_DEVICES)
        for node, device in ((v0_node, "CPU:0"), (v1_node, "CPU:1")):
            if devices_saved:
                self.assertIn(device, node.device)
            else:
                self.assertEmpty(node.device)
Example #22
0
    def _create_saved_model_with_debug_ops(self):
        """Saves a model whose exported function calls Print/Assert/check_numerics.

        The model is written to `SAVED_MODEL_DIR` under `self._tmp_dir`, with
        a concrete function traced for a scalar float32 input.
        """
        model = tracking.AutoTrackable()
        model.w = variables.Variable(tf.random.uniform([2, 2]))

        @def_function.function
        def exported_function(x):
            model.x = constant_op.constant([[37.0, -23.0], [1.0, 4.0]])
            model.y = tf.matmul(model.x, model.w)
            # Debug ops; their results are not used by the return value.
            tf.compat.v1.Print(model.x, [model.x])
            tf.compat.v1.Assert(
                tf.greater(tf.reduce_max(model.x), 0), [model.x])
            tf.compat.v1.check_numerics(model.x, 'NaN found')
            return model.y * x

        model.f = exported_function
        scalar_spec = tensor_spec.TensorSpec([], dtypes.float32)
        concrete_fn = model.f.get_concrete_function(scalar_spec)

        export_path = os.path.join(self._tmp_dir, SAVED_MODEL_DIR)
        save(model, export_path, concrete_fn)
Example #23
0
  def test_save_composite_tensor_signature(self):
    """Saves a signature taking a ragged tensor and runs it via V2 and V1 APIs.

    The signature is called with the flattened components in `flat_inp`; the
    V1 path feeds and fetches those components by tensor name.
    """
    @def_function.function(
        input_signature=[ragged_tensor.RaggedTensorSpec(ragged_rank=2)])
    def f(x):
      return {"output_key": x}

    model = tracking.AutoTrackable()
    export_path = os.path.join(self.get_temp_dir(), "saved_model")
    ragged_input = ragged_factory_ops.constant([[[1.0, 2.0], [3.0]], [[5.]]])
    flat_inp = {
        "x": constant_op.constant([1., 2., 3., 5]),
        "x_1": constant_op.constant([0, 2, 3], dtype=dtypes.int64),
        "x_2": constant_op.constant([0, 2, 3, 4], dtype=dtypes.int64)
    }
    save.save(model, export_path,
              signatures={"key": f.get_concrete_function()})

    # Test that the ragged signature can be loaded back into Python with V2 APIs
    reloaded = load.load(export_path)
    v2_result = reloaded.signatures["key"](**flat_inp)["output_key"]
    self.assertAllEqual(ragged_input, v2_result)
    graph = ops.Graph()

    # Try running the signature with V1 APIs.
    with graph.as_default(), session_lib.Session() as session:
      meta_graph_def = loader.load(
          session, [tag_constants.SERVING], export_path)
      signature = meta_graph_def.signature_def["key"]

      # Map each input tensor (looked up by name) to its numpy value.
      feed_dict = {}
      for arg_name, arg_value in flat_inp.items():
        tensor_name = signature.inputs[arg_name].name
        input_tensor = session.graph.get_tensor_by_name(tensor_name)
        feed_dict[input_tensor] = arg_value.numpy()

      # Get composite tensor components
      output_components = (
          signature.outputs["output_key"].composite_tensor.components)
      components_keys = ["x", "x_1", "x_2"]
      fetches = {
          key: session.graph.get_tensor_by_name(component.name)
          for key, component in zip(components_keys, output_components)
      }

      outputs = session.run(fetches, feed_dict)

    self.assertAllClose(flat_inp, outputs)
Example #24
0
  def test_concrete_function_backprop(self, cycles):
    """Checks the gradient through a concrete function before and after cycling."""
    vector_spec = tensor_spec.TensorSpec([None], dtypes.float32)

    @def_function.function(input_signature=[vector_spec])
    def func(x):
      return x ** 2.

    trackable = tracking.AutoTrackable()
    trackable.f = func.get_concrete_function()

    def _compute_gradient(function):
      # Gradient of `function` with respect to a watched constant input of 1.
      with backprop.GradientTape() as tape:
        point = constant_op.constant(1.)
        tape.watch(point)
        result = function(point)
      return tape.gradient(result, point)

    # d/dx x**2 at x == 1 is 2.
    self.assertEqual(2., _compute_gradient(trackable.f).numpy())
    # TODO(andresp): Fix exporting of loaded concrete functions as signatures.
    reloaded = self.cycle(trackable, cycles, signatures={})
    self.assertEqual(2., _compute_gradient(reloaded.f).numpy())
Example #25
0
 def test_captured_constant(self, cycles):
   """Checks two functions capturing one constant share it after cycling."""
   const = array_ops.zeros([100])
   trackable = tracking.AutoTrackable()
   trackable.f = def_function.function(lambda: const + 1.)
   trackable.g = def_function.function(lambda: const + 2.)
   self.assertAllClose(array_ops.ones([100]), trackable.f())
   self.assertAllClose(2. * array_ops.ones([100]), trackable.g())

   reloaded = self.cycle(trackable, cycles)
   self.assertAllClose(array_ops.ones([100]), reloaded.f())
   self.assertAllClose(2. * array_ops.ones([100]), reloaded.g())

   # TODO(b/123408994): Use the public get_concrete_function.
   f_concrete = reloaded.f._list_all_concrete_functions_for_serialization()[0]
   g_concrete = reloaded.g._list_all_concrete_functions_for_serialization()[0]
   f_captures = f_concrete.captured_inputs
   g_captures = g_concrete.captured_inputs
   self.assertLen(f_captures, 1)
   self.assertLen(g_captures, 1)
   # We should be using the same captured EagerTensor in both functions, not
   # duplicating the constant.
   self.assertIs(f_concrete.captured_inputs[0],
                 g_concrete.captured_inputs[0])
Example #26
0
  def test_named_tuple(self, cycles):
    """Calls a cycled object with a namedtuple-subclass argument."""

    class NamedTupleType(collections.namedtuple("NamedTupleType", ["a", "b"])):
      pass

    @def_function.function
    def f(x):
      return x.a + x.b

    # Trace once with a namedtuple of float32 specs before attaching `f`.
    spec_tuple = NamedTupleType(
        a=tensor_spec.TensorSpec(None, dtypes.float32, name="a"),
        b=tensor_spec.TensorSpec(None, dtypes.float32, name="b"))
    f.get_concrete_function(spec_tuple)

    trackable = tracking.AutoTrackable()
    trackable.__call__ = f
    reloaded = self.cycle(trackable, cycles)

    argument = NamedTupleType(a=constant_op.constant(1.),
                              b=constant_op.constant(2.))
    self.assertAllClose(3., reloaded(argument))
Example #27
0
  def test_functions_list(self, cycles):
    """Cycles an object holding a list of functions and a list of variables.

    The second loss function appends a new variable to `root.variables` when
    the list has exactly one entry, then reads the entry at index 1.
    """
    root = tracking.AutoTrackable()
    first_variable = variables.Variable(1.)
    root.losses = [
        def_function.function(lambda: math_ops.reduce_sum(first_variable ** 2))
    ]
    root.variables = [first_variable]

    @def_function.function
    def _v2_loss():
      if len(root.variables) == 1:
        v2 = variables.Variable(2.)
        root.variables.append(v2)
      return math_ops.reduce_sum(root.variables[1] ** 2)

    root.losses.append(_v2_loss)

    def _evaluate_losses(obj):
      return [loss() for loss in obj.losses]

    self.assertAllClose([1., 4.], _evaluate_losses(root))
    reloaded = self.cycle(root, cycles)
    self.assertAllClose([1., 4.], _evaluate_losses(reloaded))

    # New values assigned on the cycled object should show up in its losses.
    reloaded.variables[0].assign(3.)
    reloaded.variables[1].assign(4.)
    self.assertAllClose([9., 16.], _evaluate_losses(reloaded))
Example #28
0
  def test_concrete_function(self, cycles):
    """Calls a concrete function with inputs of several lengths around cycling."""
    vector_spec = tensor_spec.TensorSpec([None], dtypes.int32)

    @def_function.function(input_signature=[vector_spec])
    def func(x):
      return 2 * x

    def _doubled(fn, values):
      return fn(constant_op.constant(values)).numpy()

    trackable = tracking.AutoTrackable()
    trackable.f = func.get_concrete_function()

    self.assertAllEqual([2], _doubled(trackable.f, [1]))
    self.assertAllEqual([2, 4], _doubled(trackable.f, [1, 2]))

    # TODO(andresp): Fix exporting of loaded concrete functions as signatures.
    reloaded = self.cycle(trackable, cycles, signatures={})

    self.assertAllEqual([2, 4, 6, 8], _doubled(reloaded.f, [1, 2, 3, 4]))
    self.assertAllEqual([2, 4, 6], _doubled(reloaded.f, [1, 2, 3]))
Example #29
0
  def test_load_in_graph_mode(self, cycles):
    """Loads a saved object inside a fresh graph and evaluates it in a session.

    The object is cycled `cycles - 1` times via `self.cycle`, saved once more
    explicitly, then loaded under `ops.Graph().as_default()`. A variable and a
    function output are evaluated with a `MonitoredSession`.
    """
    root = tracking.AutoTrackable()
    root.v1 = variables.Variable(1.)
    root.v2 = variables.Variable(2.)
    root.f = def_function.function(
        lambda x: root.v2 * x,
        input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])

    # The final save/load happens below, so only cycle the remaining times.
    if cycles > 1:
      root = self.cycle(root, cycles - 1)
    # Create the export directory inside the test's temp dir. Passing the temp
    # dir as `prefix` would instead create a sibling directory whose name
    # merely starts with that path.
    path = tempfile.mkdtemp(dir=self.get_temp_dir())
    save.save(root, path)

    with ops.Graph().as_default():
      imported = load.load(path)
      var_v1 = imported.v1
      output = imported.f(constant_op.constant(2.))
      with monitored_session.MonitoredSession() as sess:
        self.assertEqual(1.0, sess.run(var_v1))
        # v2 * x == 2 * 2
        self.assertEqual(4.0, sess.run(output))
Example #30
0
    def testScalarModel(self):
        """Converts a two-variable scalar function to constants.

        The traced graph is expected to hold two variables; after conversion
        it is expected to hold none and no StatefulPartitionedCall op.
        """
        input_data = {"x": constant_op.constant(1., shape=[])}
        model = tracking.AutoTrackable()
        model.v1 = variables.Variable(3.)
        model.v2 = variables.Variable(2.)
        model.f = def_function.function(lambda x: model.v1 * model.v2 * x)
        concrete_fn = model.f.get_concrete_function(input_data["x"])

        # Before conversion.
        original_graph_def = concrete_fn.graph.as_graph_def()
        self.assertEqual(2, self._getNumVariables(original_graph_def))

        converted_fn = convert_to_constants.convert_variables_to_constants_v2(
            concrete_fn)

        # After conversion.
        converted_graph_def = converted_fn.graph.as_graph_def()
        self.assertEqual(0, self._getNumVariables(converted_graph_def))
        has_call_op = self._hasStatefulPartitionedCallOp(converted_graph_def)
        self.assertFalse(has_call_op)

        self._testConvertedFunction(model, model.f, converted_fn, input_data)