예제 #1
0
 def __init__(self):
     """Set up one directly assigned layer and two `NoDependency`-wrapped values."""
     super().__init__()

     # Assigned as-is, with no wrapper.
     self.isdep = keras.layers.Dense(1)

     # Both of the following are wrapped in `NoDependency` before being
     # stored on the instance.
     wrapped_layer = keras.layers.Dense(2)
     self.notdep = data_structures.NoDependency(wrapped_layer)

     wrapped_variable = tf.Variable(1.0, name="notdep_var")
     self.notdep_var = data_structures.NoDependency(wrapped_variable)
예제 #2
0
  def __init__(
      self,
      handle,
      trainable=False,
      arguments=None,
      _sentinel=None,  # pylint: disable=invalid-name
      tags=None,
      signature=None,
      signature_outputs_as_dict=None,
      output_key=None,
      output_shape=None,
      load_options=None,
      **kwargs):
    """Loads the module behind `handle` and configures how it will be called.

    Args:
      handle: Stored as `self._handle` and passed to `load_module`.
      trainable: Forwarded to `self._setup_layer`.
      arguments: Optional mapping stored (wrapped in `NoDependency`) as
        `self._arguments`; a falsy value is replaced by an empty dict.
      _sentinel: Not read anywhere in this method.
      tags: Passed to `load_module`.
      signature: Optional signature name. Falls back to "default" when the
        loaded object reports `_is_hub_module_v1`.
      signature_outputs_as_dict: Stored as-is; only accepted together with a
        signature (see Raises).
      output_key: Stored as-is. Falls back to "default" for a
        `_is_hub_module_v1` object when `signature_outputs_as_dict` is falsy.
      output_shape: If truthy, converted with `_convert_nest_to_shapes` and
        stored as `self._output_shape`; otherwise that attribute is not set
        by this method.
      load_options: Stored as `self._load_options` and passed to
        `load_module`.
      **kwargs: Forwarded to `self._setup_layer`.

    Raises:
      ValueError: If a signature is in effect and `output_key` and
        `signature_outputs_as_dict` are both set or both unset, or if
        `signature_outputs_as_dict` is truthy while no signature is in effect.
    """
    # Note: for compatibility with keras-model serialization this layer is
    # json-serializable. If you add or change arguments here, please also update
    # the `get_config` method.
    # The arguments are marked NoDependency to avoid autoconversion to a
    # trackable _DictWrapper, because that upsets json.dumps() when saving
    # the result of get_config().
    # NOTE(review): `_sentinel` is accepted but never used below; presumably it
    # only exists to discourage passing the later arguments positionally --
    # confirm before relying on it.
    self._handle = handle
    self._arguments = data_structures.NoDependency(arguments or {})
    self._signature = signature
    self._signature_outputs_as_dict = signature_outputs_as_dict
    self._output_key = output_key
    if output_shape:
      # Autograph chokes on _convert_nest_to_shapes(), so we call it here
      # and not from within call().
      self._output_shape = data_structures.NoDependency(
          _convert_nest_to_shapes(output_shape))

    self._load_options = load_options
    self._func = load_module(handle, tags, self._load_options)
    # Objects that do not define the marker attribute are treated as not being
    # in the legacy format.
    self._is_hub_module_v1 = getattr(self._func, "_is_hub_module_v1", False)

    # Update with the defaults when using legacy TF1 Hub format.
    if self._is_hub_module_v1:
      self._signature = self._signature or "default"
      if not self._signature_outputs_as_dict:
        self._output_key = self._output_key or "default"
    # More validity checks.
    # With a signature, exactly one of `output_key` / dict outputs must be
    # chosen: the check fires when both are set or both are unset.
    if self._signature and (bool(self._output_key is not None)
                            == bool(self._signature_outputs_as_dict)):
      raise ValueError("When using a signature, either output_key or "
                       "signature_outputs_as_dict=True should be set.")
    if not self._signature and self._signature_outputs_as_dict:
      raise ValueError("signature_outputs_as_dict is only valid if specifying "
                       "a signature (or using a legacy TF1 Hub format).")

    # These run last, after all the configuration attributes above are set.
    self._callable = self._get_callable()
    self._has_training_argument = func_has_training_argument(self._callable)
    self._setup_layer(trainable, **kwargs)
예제 #3
0
    def __init__(self, mesh: layout.Mesh, root=None, **kwargs):
        """Build a checkpoint whose saver is a `DTrackableSaver` for `mesh`.

        Args:
            mesh: Mesh stored on the instance and handed to `DTrackableSaver`.
            root: Optional object to save from. When truthy it must pass
                `util._assert_trackable`, and each keyword argument is checked
                against the dependencies root already has.
            **kwargs: Named objects set as attributes of this checkpoint; each
                must pass `util._assert_trackable` after assignment.

        Raises:
            ValueError: If `root` already has a dependency under one of the
                keyword names and it compares unequal to the value given here.
        """
        super().__init__(root=root, **kwargs)
        self._mesh = mesh

        # Created lazily for restore-on-create.
        self._save_counter = None
        self._save_assign_op = None

        if root:
            util._assert_trackable(root, "root")
            saver_root = root
            attached_dependencies = []

            # Every keyword argument, root included, becomes a child of root.
            kwargs["root"] = root
            root._maybe_initialize_trackable()

            existing_counter = root._lookup_dependency("save_counter")
            self._save_counter = data_structures.NoDependency(existing_counter)
            self._root = data_structures.NoDependency(root)
        else:
            saver_root = self
            attached_dependencies = None

        for name in sorted(kwargs):
            setattr(self, name, kwargs[name])

            # Read the attribute back rather than reusing the raw value:
            # setattr converts lists/dicts/tuples into Trackable structures.
            tracked_value = getattr(self, name)
            util._assert_trackable(tracked_value, name)

            if not root:
                continue

            # Root must not already own a different dependency with this name.
            if attached_dependencies is None:
                attached_dependencies = []
            existing_child = root._lookup_dependency(name)
            if existing_child is None:
                reference = base.TrackableReference(name, tracked_value)
                attached_dependencies.append(reference)
            elif existing_child != tracked_value:
                raise ValueError(
                    "Cannot create a Checkpoint with keyword argument {name} if "
                    "root.{name} already exists.".format(name=name))

        # DTensor Change:
        # Replace the parent's saver with a DTrackableSaver built on the mesh.
        graph_view = graph_view_lib.ObjectGraphView(
            weakref.ref(saver_root),
            attached_dependencies=attached_dependencies)
        self._saver = DTrackableSaver(mesh, graph_view)
예제 #4
0
 def testNestedLists(self):
     """Objects inside nested lists are listed; replacing an element blocks saving."""

     def ckpt_prefix():
         # Recomputed on each use, exactly like the inline expression it replaces.
         return os.path.join(self.get_temp_dir(), "ckpt")

     root = autotrackable.AutoTrackable()
     root.l = []
     first = autotrackable.AutoTrackable()
     root.l.append([first])
     second = autotrackable.AutoTrackable()
     root.l[0].append(second)
     listed = util.list_objects(root)
     for expected in (first, second):
         self.assertIn(expected, listed)

     # A plain integer in the inner list must not show up among the objects.
     root.l[0].append(1)
     third = autotrackable.AutoTrackable()
     root.l[0].append(third)
     listed = util.list_objects(root)
     for expected in (third, first, second):
         self.assertIn(expected, listed)
     self.assertNotIn(1, listed)

     # Same checks for a nested list literal assigned in one go.
     from_literal = autotrackable.AutoTrackable()
     appended_later = autotrackable.AutoTrackable()
     root.l1 = [[], [from_literal]]
     root.l1[0].append(appended_later)
     listed = util.list_objects(root)
     for expected in (from_literal, appended_later):
         self.assertIn(expected, listed)

     checkpoint = util.Checkpoint(a=root)
     checkpoint.save(ckpt_prefix())
     root.l[0].append(data_structures.NoDependency([]))
     root.l[0][-1].append(5)
     checkpoint.save(ckpt_prefix())

     # Dirtying the inner list means the root object is unsaveable.
     root.l[0][1] = 2
     with self.assertRaisesRegex(ValueError, "A list element was replaced"):
         checkpoint.save(ckpt_prefix())
예제 #5
0
 def testDictWrapperNoDependency(self):
     """A `NoDependency` dict adds no listed objects and weights still round-trip."""
     holder = tf.Module()
     holder.d = data_structures.NoDependency({})
     holder.d[1] = [3]

     # Only the module itself is expected in the object list.
     listed = util.list_objects(holder)
     self.assertEqual([holder], listed)

     model = training.Model()
     model.sub = holder
     weights_path = os.path.join(self.get_temp_dir(), "ckpt")
     model.save_weights(weights_path)
     model.load_weights(weights_path)
예제 #6
0
 def testNoDepList(self):
     """A `NoDependency` list stays a plain list; inserting into a regular one blocks saving."""
     model = training.Model()

     # Wrapped list: still a built-in list after mutation, and saving works.
     model.l1 = data_structures.NoDependency([])
     model.l1.insert(1, 0)
     self.assertIsInstance(model.l1, list)
     checkpoint = tf.train.Checkpoint(a=model)
     checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt"))

     # Regular list: an insert of a module makes the next save raise.
     model.l2 = []
     inserted = tf.Module()
     model.l2.insert(1, inserted)
     with self.assertRaisesRegex(ValueError, "A list element was replaced"):
         checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt"))
예제 #7
0
 def testNonAppendNotTrackable(self):
     """Overwriting or deleting untracked dict values leaves the object saveable."""
     # Non-append mutations (deleting or overwriting values) are OK when the
     # values aren't tracked.
     holder = tf.Module()
     holder.d = {}
     holder.d["a"] = [3]

     # Overwrite, read back, then delete a plain integer entry.
     holder.d[1] = 3
     holder.d[1] = 2
     self.assertEqual(2, holder.d[1])
     del holder.d[1]

     # Overwrite one `NoDependency`-wrapped module with another.
     holder.d[2] = data_structures.NoDependency(tf.Module())
     replacement = tf.Module()
     holder.d[2] = data_structures.NoDependency(replacement)
     self.assertIs(replacement, holder.d[2])

     expected_objects = [holder, holder.d, holder.d["a"]]
     self.assertEqual(expected_objects, util.list_objects(holder))

     model = training.Model()
     model.sub = holder
     weights_path = os.path.join(self.get_temp_dir(), "ckpt")
     model.save_weights(weights_path)
     model.load_weights(weights_path)
예제 #8
0
 def testNonStringKeyNotTrackableValue(self):
     """A `NoDependency` value under a non-string key is not listed and saving works."""
     holder = tf.Module()
     holder.d = {}
     holder.d["a"] = [3]
     # Integer key whose value is wrapped in `NoDependency`.
     holder.d[1] = data_structures.NoDependency([3])

     expected_objects = [holder, holder.d, holder.d["a"]]
     self.assertEqual(expected_objects, util.list_objects(holder))

     model = training.Model()
     model.sub = holder
     weights_path = os.path.join(self.get_temp_dir(), "ckpt")
     model.save_weights(weights_path)
     model.load_weights(weights_path)
예제 #9
0
    def testNoDependency(self):
        """`NoDependency` attributes are readable but absent from trackable children."""
        parent = tf.Module()
        tracked_child = tf.Module()
        parent.hasdep = tracked_child
        untracked_child = tf.Module()
        parent.nodep = data_structures.NoDependency(untracked_child)

        # Only `hasdep` is reported as a child; both attributes return the
        # objects that were assigned.
        self.assertLen(parent._trackable_children(), 1)
        self.assertIs(parent._trackable_children()["hasdep"], parent.hasdep)
        self.assertIs(parent.hasdep, tracked_child)
        self.assertIs(parent.nodep, untracked_child)

        # A model whose constructor is decorated with
        # `no_automatic_dependency_tracking`.
        class NoDependencyModel(training.Model):
            @tf.__internal__.tracking.no_automatic_dependency_tracking
            def __init__(self):
                super().__init__()
                self.a = []
                self.b = tf.Module()

        untracked_model = NoDependencyModel()
        listed = util.list_objects(untracked_model)
        self.assertEqual([untracked_model], listed)
예제 #10
0
 def _no_dependency(self, value):
   """Return `value` wrapped in `data_structures.NoDependency`.

   Override that lets TrackableBase disable dependency tracking.
   """
   wrapped = data_structures.NoDependency(value)
   return wrapped