Example #1
0
    def test_cycles_with_path(self):
        """Checks path-keyed `_flatten` output when modules form cycles."""
        root = module.Module()
        root.w = variables.Variable(1.)
        root.encoder = module.Module()
        root.encoder.w = [({"k": root.w}, {"k": root.w})]
        root.decoder = root.encoder

        # Pointing back at the root creates two cycles, one reachable through
        # root.encoder.mod and one through root.decoder.mod.
        root.decoder.mod = root

        flattened = dict(
            root._flatten(with_path=True, predicate=module._is_variable))

        expected = {
            ("w", ): root.w,
            ("encoder", "mod", "w"): root.encoder.mod.w,
            ("decoder", "mod", "w"): root.decoder.mod.w,
            ("encoder", "w", 0, 0, "k"): root.encoder.w[0][0]["k"],
            ("encoder", "w", 0, 1, "k"): root.encoder.w[0][1]["k"],
            ("decoder", "w", 0, 0, "k"): root.decoder.w[0][0]["k"],
            ("decoder", "w", 0, 1, "k"): root.decoder.w[0][1]["k"],
        }
        self.assertEqual(flattened, expected)
Example #2
0
  def _testConvertedFunction(self, obj, func, converted_concrete_func,
                             input_data):
    """Validates a converted ConcreteFunction against the original function.

    Checks that the converted graph contains no variables and no
    StatefulPartitionedCall ops, that it yields the same flattened outputs as
    `func`, that its input shapes match `input_data`, and that it still
    produces the same outputs after being saved as a signature and loaded
    back.

    Args:
      obj: Not used by this helper; kept so existing call sites keep working.
      func: The original callable, invoked as `func(**input_data)`.
      converted_concrete_func: The converted function under test.
      input_data: Dict of keyword arguments; also indexed by each input
        tensor's name with the ":<index>" suffix removed.
    """
    # Ensure the converted graph has no variables and no function calls.
    constant_graph_def = converted_concrete_func.graph.as_graph_def()
    self.assertEqual(0, self._getNumVariables(constant_graph_def))
    self.assertFalse(self._hasStatefulPartitionedCallOp(constant_graph_def))

    # Check that the converted ConcreteFunction produces the same result as the
    # original Function.
    expected_value = nest.flatten(func(**input_data))
    actual_value = nest.flatten(converted_concrete_func(**input_data))

    # zip() stops at the shorter sequence, so without this check a missing or
    # extra output would go unnoticed.
    self.assertEqual(len(expected_value), len(actual_value))
    for expected, actual in zip(expected_value, actual_value):
      np.testing.assert_almost_equal(expected.numpy(), actual.numpy())

    # Ensure the shape is retained.
    for tensor in converted_concrete_func.inputs:
      actual_shape = input_data[tensor.name.split(":")[0]].shape
      self.assertEqual(tensor.shape, actual_shape)

    # Save the converted ConcreteFunction as a signature.
    save_dir = os.path.join(self.get_temp_dir(), "frozen_saved_model")
    root = module.Module()
    root.f = converted_concrete_func
    save(root, save_dir, {"mykey": converted_concrete_func})

    # Load it back and make sure it works.
    loaded_obj = load(save_dir)
    actual_value = nest.flatten(loaded_obj.signatures["mykey"](**input_data))
    # Same guard as above: the loaded signature must return every output.
    self.assertEqual(len(expected_value), len(actual_value))
    for expected, actual in zip(expected_value, actual_value):
      np.testing.assert_almost_equal(expected.numpy(), actual.numpy())
Example #3
0
 def test_model_wrapped_in_module_discovers_submodules(self):
   """A compiled Sequential model set on a Module exposes its contents."""
   model = models.Sequential([layers.Dense(units=1, input_shape=[1])])
   model.compile(optimizer="sgd", loss="mean_squared_error")
   wrapper = module.Module()
   wrapper.l = model
   # The wrapper must report submodules and exactly two variables.
   self.assertNotEmpty(wrapper.submodules)
   self.assertLen(wrapper.variables, 2)
 def testLoopAssignedModule(self):
     """A tuple containing the module itself is its only tracked child."""
     looped = module.Module()
     # The attribute refers back to its own owner.
     looped.s = (looped, )
     self.assertLen(looped._trackable_children(), 1)
     self.assertIn("s", looped._trackable_children())
     self.assertIs(looped.s, looped._trackable_children()["s"])
     self.assertEqual((), looped.trainable_variables)
Example #5
0
 def testLoopAssignedModule(self):
     """A tuple containing the module itself is its only dependency, "s"."""
     m = module.Module()
     # The attribute refers back to its own owner.
     m.s = (m, )
     self.assertLen(m._checkpoint_dependencies, 1)
     self.assertIs(m.s, m._checkpoint_dependencies[0].ref)
     # Names are compared by value: identity of equal strings is an
     # interpreter interning detail, not something to assert on.
     self.assertEqual("s", m._checkpoint_dependencies[0].name)
     self.assertEqual((), m.trainable_variables)
 def testNamedTupleConflictingAttributes(self):
   """A namedtuple field called "weights" keeps its value on a Module."""
   tuple_type = collections.namedtuple("Named", ("x", "weights"))
   var = variables.Variable(2)
   holder = module.Module()
   holder.nt = tuple_type(x=var, weights=3)
   self.assertEqual(3, holder.nt.weights)
Example #7
0
    def test_saved_model(self):
        """Round-trips a module holding a packed variable through SavedModel.

        The module is built and called under `self.device`, saved, and then
        loaded twice: once outside the device scope and once inside it.
        """
        different_values = self.device.pack(
            [constant_op.constant(-1.),
             constant_op.constant(3.)])
        with self.device:
            m = module.Module()
            m.v = variables.Variable(different_values)
            m.f = def_function.function(lambda: m.v * 2.)
            self.assertAllClose([-2., 6.], self.device.unpack(m.f()))
        saved_model_path = os.path.join(self.get_temp_dir(), "saved_model")
        save.save(m, saved_model_path)

        # Reset the context and rerun the fixture's setUp before loading.
        context._reset_context()
        self.setUp()

        # Loaded outside the device scope, `f` is compared to the scalar -2.
        single_device_loaded = load.load(saved_model_path)
        self.assertAllClose(-2., single_device_loaded.f())
        assign_value = self.device.pack(
            [constant_op.constant(.1),
             constant_op.constant(.2)])
        with self.device:
            # Loaded inside the device scope, both packed values come back.
            parallel_loaded = load.load(saved_model_path)
            self.assertAllClose([-2., 6.],
                                self.device.unpack(parallel_loaded.f()))
            self.assertAllClose([-1., 3.],
                                self.device.unpack(parallel_loaded.v))
            # After assigning new packed values, `f` must reflect them.
            parallel_loaded.v.assign(assign_value)
            self.assertAllClose([.2, .4],
                                self.device.unpack(parallel_loaded.f()))
Example #8
0
    def test_composite_variable(self):
        """Variables inside a CompositeTensor attribute are collected."""
        class Spec(type_spec.TypeSpec):

            @property
            def value_type(self):
                return CompositeVariable

            def _component_specs(self):
                pass

            def _serialize(self):
                pass

            def _to_components(self, value):
                return value._variables

            def _from_components(self, variable_list):
                return CompositeVariable(variable_list)

        class CompositeVariable(composite_tensor.CompositeTensor):
            def __init__(self, variable_list):
                self._variables = variable_list

            @property
            def _type_spec(self):
                return Spec()

        mod = module.Module()
        mod.a = CompositeVariable(
            [variables.Variable(1.), variables.Variable(2.)])
        self.assertAllEqual(mod.variables, mod.a._variables)
Example #9
0
    def test_memoized_in_tf2(self):
        """With TF2 enabled, reading `name_scope` twice gives one object."""
        if not tf2.enabled():
            self.skipTest("Requires TF2")

        mod = module.Module(name="name")
        first_scope, second_scope = mod.name_scope, mod.name_scope
        self.assertIs(first_scope, second_scope)
Example #10
0
    def test_raises_error_with_path(self):
        """Accessing `variables` reports the failing property by name."""
        mod = module.Module()
        # Two distinct plain `object` instances serve as the dict keys.
        mod.layers = {object(): None, object(): None}
        with self.assertRaisesRegex(ValueError,
                                    "Error processing property 'layers'"):
            mod.variables  # pylint: disable=pointless-statement
Example #11
0
 def test_supports_variable_like_objects(self):
     """A `VariableLike` attribute is collected by `variables`.

     It is reported by `trainable_variables` only once `trainable` has been
     set to True on it.
     """
     mod = module.Module()
     fake_var = VariableLike()
     self.assertFalse(hasattr(fake_var, "trainable"))
     mod.v = fake_var
     self.assertEqual(mod.variables, (fake_var, ))
     self.assertEmpty(mod.trainable_variables)
     mod.v.trainable = True
     self.assertEqual(mod.trainable_variables, (fake_var, ))
Example #12
0
 def testDictWrapperBadKeys(self):
     """Saving weights fails when a dict holds a `List` under a key of 1."""
     sub = module.Module()
     sub.d = {}
     sub.d[1] = data_structures.List()
     model = training.Model()
     model.sub = sub
     ckpt_prefix = os.path.join(self.get_temp_dir(), "ckpt")
     with self.assertRaisesRegex(ValueError, "non-string key"):
         model.save_weights(ckpt_prefix)
Example #13
0
    def test_not_memoized_in_tf1(self):
        """Without TF2, each `name_scope` read is a new, same-named object."""
        if tf2.enabled():
            self.skipTest("Requires TF1")

        mod = module.Module(name="name")
        first_scope, second_scope = mod.name_scope, mod.name_scope
        self.assertIsNot(first_scope, second_scope)
        self.assertEqual(first_scope.name, second_scope.name)
Example #14
0
 def testDictWrapperNoDependency(self):
     """A `NoDependency` dict attribute is left out of `list_objects`."""
     holder = module.Module()
     holder.d = data_structures.NoDependency({})
     holder.d[1] = [3]
     self.assertEqual([holder], util.list_objects(holder))
     model = training.Model()
     model.sub = holder
     # Saving and loading weights must both succeed with this attribute.
     ckpt_prefix = os.path.join(self.get_temp_dir(), "ckpt")
     model.save_weights(ckpt_prefix)
     model.load_weights(ckpt_prefix)
Example #15
0
  def test_untracked_variable_useful_message(self):
    """Saving fails with an error that names the captured variable."""
    root = module.Module()
    # The variable is only captured by `f`; it is never set on `root`.
    untracked = variables.Variable(1., name="some_unique_name")

    @def_function.function(input_signature=[])
    def f():
      return untracked.read_value()

    root.f = f
    with self.assertRaisesRegex(AssertionError, "some_unique_name"):
      save.save(root, os.path.join(self.get_temp_dir(), "saved_model"))
 def testNamedTuple(self):
     """A namedtuple attribute keeps its fields and tracks its variable."""
     tuple_type = collections.namedtuple("Named", ("x", "y"))
     var = variables.Variable(2)
     mod = module.Module()
     mod.nt = tuple_type(x=var, y=2)
     # The variable is reachable by field name, by index and as child "x".
     self.assertIs(var, mod.nt.x)
     self.assertIs(var, mod.nt[0])
     self.assertIs(
         var, mod._trackable_children()["nt"]._trackable_children()["x"])
     self.assertEqual(2, mod.nt.y)
 def testNamedTuple(self):
   """A namedtuple attribute keeps its fields and tracks its variable."""
   tuple_type = collections.namedtuple("Named", ("x", "y"))
   var = variables.Variable(2)
   mod = module.Module()
   mod.nt = tuple_type(x=var, y=2)
   # The variable is reachable by field name, by index and as a dependency.
   self.assertIs(var, mod.nt.x)
   self.assertIs(var, mod.nt[0])
   self.assertIs(
       var,
       mod._checkpoint_dependencies[0].ref._checkpoint_dependencies[0].ref)
   self.assertEqual(2, mod.nt.y)
Example #18
0
 def _get_checkpoint_name(self, name):
     """Returns the serialized name of a variable added under `name`.

     A [1, 2] float64 variable called `name` is added to a fresh Module and
     the object graph is serialized. The single entry in the first returned
     element supplies the name, which is also opened as a name scope
     prefixed with "root/" before being returned.

     Args:
       name: Name passed to `trackable_utils.add_variable`.

     Returns:
       The `name` attribute of the single serialized entry.
     """
     owner = module.Module()
     trackable_utils.add_variable(
         owner, name=name, shape=[1, 2], dtype=dtypes.float64)
     serialized = trackable_utils._serialize_object_graph(
         owner, saveables_cache=None)
     (entry, ), _, _ = serialized
     with ops.name_scope_v2("root/" + entry.name):
         pass  # Make sure we can use this as an op name if we prefix it.
     return entry.name
Example #19
0
    def test_with_path(self):
        """Checks path-keyed `_flatten` output with a shared submodule."""
        root = module.Module()
        root.w = variables.Variable(1.)
        root.encoder = module.Module()
        root.encoder.w = [({"k": root.w}, {"k": root.w})]
        # The same submodule is reachable under two attribute names.
        root.decoder = root.encoder

        flattened = dict(
            root._flatten(with_path=True, predicate=module._IS_VARIABLE))

        expected = {
            ("w", ): root.w,
            ("encoder", "w", 0, 0, "k"): root.encoder.w[0][0]["k"],
            ("encoder", "w", 0, 1, "k"): root.encoder.w[0][1]["k"],
            ("decoder", "w", 0, 0, "k"): root.decoder.w[0][0]["k"],
            ("decoder", "w", 0, 1, "k"): root.decoder.w[0][1]["k"],
        }
        self.assertEqual(flattened, expected)
Example #20
0
 def testNoDepList(self):
     """`insert` is fine on a `NoDependency` list but breaks a plain one.

     The first save, with only the `NoDependency` list mutated, succeeds.
     The second, after inserting a Module into a plain list attribute, must
     raise ValueError.
     """
     model = training.Model()
     model.l1 = data_structures.NoDependency([])
     model.l1.insert(1, 0)
     self.assertIsInstance(model.l1, list)
     checkpoint = util.Checkpoint(a=model)
     checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt"))
     model.l2 = []
     model.l2.insert(1, module.Module())
     with self.assertRaisesRegex(ValueError, "A list element was replaced"):
         checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt"))
Example #21
0
 def testNonAppendNotTrackable(self):
     """Overwriting or deleting untracked dict values does not break saving."""
     # Non-append mutations (deleting or overwriting values) are OK when the
     # values aren't tracked.
     holder = module.Module()
     holder.d = {}
     holder.d["a"] = [3]
     holder.d[1] = 3
     holder.d[1] = 2
     self.assertEqual(2, holder.d[1])
     del holder.d[1]
     holder.d[2] = data_structures.NoDependency(module.Module())
     replacement = module.Module()
     holder.d[2] = data_structures.NoDependency(replacement)
     self.assertIs(replacement, holder.d[2])
     self.assertEqual([holder, holder.d, holder.d["a"]],
                      util.list_objects(holder))
     model = training.Model()
     model.sub = holder
     ckpt_prefix = os.path.join(self.get_temp_dir(), "ckpt")
     model.save_weights(ckpt_prefix)
     model.load_weights(ckpt_prefix)
Example #22
0
    def testNoDependency(self):
        """`NoDependency` values and a decorated `__init__` add no tracking."""
        root = module.Module()
        tracked = module.Module()
        root.hasdep = tracked
        untracked = module.Module()
        root.nodep = data_structures.NoDependency(untracked)
        # Only the plain assignment is recorded as a dependency.
        self.assertEqual(1, len(root._checkpoint_dependencies))
        self.assertIs(root._checkpoint_dependencies[0].ref, root.hasdep)
        self.assertIs(root.hasdep, tracked)
        self.assertIs(root.nodep, untracked)

        class NoDependencyModel(training.Model):
            @base.no_automatic_dependency_tracking
            def __init__(self):
                super(NoDependencyModel, self).__init__()
                self.a = []
                self.b = module.Module()

        bare_model = NoDependencyModel()
        self.assertEqual([bare_model], util.list_objects(bare_model))
Example #23
0
 def testNonStringKeyNotTrackableValue(self):
     """An integer key is accepted when its value is `NoDependency`."""
     holder = module.Module()
     holder.d = {}
     holder.d["a"] = [3]
     holder.d[1] = data_structures.NoDependency([3])
     self.assertEqual([holder, holder.d, holder.d["a"]],
                      util.list_objects(holder))
     model = training.Model()
     model.sub = holder
     ckpt_prefix = os.path.join(self.get_temp_dir(), "ckpt")
     model.save_weights(ckpt_prefix)
     model.load_weights(ckpt_prefix)
Example #24
0
    def testLoadSavedModelWithUnregisteredStruct(self):
        """Loads a model saved while a struct spec was temporarily registered.

        The spec is registered only around the `save.save` call. The loaded
        function is then called with tensors, with its own outputs, and with
        a `MaskedTensor` reinterpreted as `AnonymousStruct`; every result is
        expected to be an `AnonymousStruct`.
        """
        MaskedTensor = build_simple_masked_tensor_type()

        # Adds values and ANDs masks; plain tensors use a mask of True.
        def f(x, y):
            x_values = x.values if isinstance(x, MaskedTensor) else x
            y_values = y.values if isinstance(y, MaskedTensor) else y
            x_mask = x.mask if isinstance(x, MaskedTensor) else True
            y_mask = y.mask if isinstance(y, MaskedTensor) else True
            return MaskedTensor(x_values + y_values, x_mask & y_mask)

        t_spec = tensor_spec.TensorSpec(None, dtypes.int32)
        b_spec = tensor_spec.TensorSpec(None, dtypes.bool)
        mt_spec = MaskedTensor.Spec(values=t_spec, mask=b_spec)
        model = module.Module()
        model.f = def_function.function(f)
        # Request one concrete function per tensor/struct argument pairing.
        model.f.get_concrete_function(t_spec, t_spec)
        model.f.get_concrete_function(t_spec, mt_spec)
        model.f.get_concrete_function(mt_spec, t_spec)
        model.f.get_concrete_function(mt_spec, mt_spec)

        path = tempfile.mkdtemp(prefix=test.get_temp_dir())
        # The spec is registered only for the duration of the save.
        with temporarily_register_type_spec('tf.test.MaskedTensor.Spec',
                                            MaskedTensor.Spec):
            save.save(model, path)
        loaded_model = load.load(path)

        # NOTE(review): the name looked up here has no '.Spec' suffix, unlike
        # the name registered above - confirm this is intended.
        with self.assertRaises(ValueError):
            type_spec.lookup('tf.test.MaskedTensor')

        # Tensor + tensor.
        t = constant_op.constant([10, 20, 30])
        v1 = loaded_model.f(t, t)
        self.assertIsInstance(v1, tensor_struct.AnonymousStruct)
        self.assertAllEqual(v1.values, [20, 40, 60])
        self.assertAllEqual(v1.mask, True)

        # Struct + struct, feeding the previous result back in.
        v2 = loaded_model.f(v1, v1)
        self.assertIsInstance(v2, tensor_struct.AnonymousStruct)
        self.assertAllEqual(v2.values, [40, 80, 120])
        self.assertAllEqual(v2.mask, True)

        # Tensor + a MaskedTensor reinterpreted as an AnonymousStruct.
        mt = MaskedTensor([1, 2, 3], [True, True, False])
        v3 = loaded_model.f(
            t,
            tensor_struct.reinterpret_struct(mt,
                                             tensor_struct.AnonymousStruct))
        self.assertIsInstance(v3, tensor_struct.AnonymousStruct)
        self.assertAllEqual(v3.values, [11, 22, 33])
        self.assertAllEqual(v3.mask, [True, True, False])

        # The anonymous result can be reinterpreted back as a MaskedTensor.
        v4 = tensor_struct.reinterpret_struct(v3, MaskedTensor)
        self.assertIsInstance(v4, MaskedTensor)
        self.assertAllEqual(v4.values, [11, 22, 33])
        self.assertAllEqual(v4.mask, [True, True, False])
  def test_supports_distributed_variables(self):
    """Mirrored, TPU-mirrored and aggregating variables are all collected."""
    mirrored_var = distributed_values.MirroredVariable(
        None, [variables.Variable(1.)], variables.VariableAggregation.SUM)
    tpu_var = tpu_values.TPUMirroredVariable(
        strategy=None, values=[variables.Variable(42.)], aggregation=None)
    aggregating_var = ps_values.AggregatingVariable(
        strategy=None, v=variables.Variable(1.), aggregation=None)

    mod = module.Module()
    mod.a = mirrored_var
    mod.b = tpu_var
    mod.c = aggregating_var
    expected = (mirrored_var, tpu_var, aggregating_var)
    self.assertEqual(mod.variables, expected)
Example #26
0
 def testLayerCollectionWithExternalMutation(self):
     """Layers added to the original dict show up on the module attribute."""
     backing = {}
     root = module.Module()
     root.wrapper = backing
     self.assertEqual([], root.wrapper.layers)
     self.assertEqual([], root.wrapper.trainable_weights)
     first_layer, second_layer = core.Dense(1), core.Dense(1)
     # Mutate the original dict rather than the module attribute.
     backing["a"] = first_layer
     backing["b"] = second_layer
     self.assertEqual([first_layer, second_layer], root.wrapper.layers)
     # The layers have still not created variables
     self.assertEqual([], root.wrapper.trainable_weights)
Example #27
0
    def testSameStructure(self):
        """Tuple and namedtuple attributes match their originals in `nest`."""
        var_tuple = (variables.Variable(1.), )
        mod = module.Module()
        mod.t = var_tuple
        # The comparison is checked with the arguments in both orders.
        nest.assert_same_structure(var_tuple, mod.t)
        nest.assert_same_structure(mod.t, var_tuple)

        pair_type = collections.namedtuple("nt", ["x", "y"])
        pair = pair_type(x=1, y=2)
        mod.nt = pair
        nest.assert_same_structure(mod.nt, pair)
        with self.assertRaises(TypeError):  # pylint: disable=g-error-prone-assert-raises
            nest.assert_same_structure(mod.nt, mod.t)
Example #28
0
    def testNamedtupleSubclassWithCustomNew(self):
        """Save fails for a namedtuple subclass with a no-arg `__new__`."""
        class SubclassWithDifferentArgs(collections.namedtuple("A", ["x"])):
            def __new__(cls):
                return super(SubclassWithDifferentArgs, cls).__new__(cls, [])

        instance = SubclassWithDifferentArgs()
        mod = module.Module()
        mod.nt = instance
        mod.nt.x.append(variables.Variable(1.))
        ckpt_prefix = os.path.join(self.get_temp_dir(), "ckpt")
        checkpoint = util.Checkpoint(m=mod)
        with self.assertRaises(ValueError):
            checkpoint.save(ckpt_prefix)
Example #29
0
  def test_raises_error_with_path(self):
    """Accessing `variables` reports the failing property by name."""
    # Plain `object` is used as the key type unless running on Python 2,
    # where a class with `__lt__` set to None is used instead.
    non_orderable = object
    if six.PY2:
      class NonOrderable(object):
        __lt__ = None

      non_orderable = NonOrderable

    mod = module.Module()
    mod.layers = {non_orderable(): None, non_orderable(): None}
    with self.assertRaisesRegexp(ValueError,
                                 "Error processing property 'layers'"):
      mod.variables  # pylint: disable=pointless-statement
Example #30
0
    def test_attributes_to_ignore(self):
        """Attributes listed as ignored are left out of the collections."""
        class DangerousModule(module.Module):
            # Extend the base class's ignore set with two extra names.
            _TF_MODULE_IGNORED_PROPERTIES = frozenset(
                itertools.chain(("dangerous_submodule", "dangerous_variable"),
                                module.Module._TF_MODULE_IGNORED_PROPERTIES))

        dangerous = DangerousModule()
        dangerous.dangerous_submodule = module.Module()
        dangerous.dangerous_variable = variables.Variable(1.)
        dangerous.normal_variable = variables.Variable(2.)

        self.assertEmpty(dangerous.submodules)
        self.assertLen(dangerous.variables, 1)
        self.assertEqual(dangerous.variables[0], dangerous.normal_variable)