Example #1
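Saves one checkpoint without and one with optimizer slot variables, then restores both into an empty object graph: each restore stays deferred until the matching variables, optimizer, and slots are created.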
  def testDeferredSlotRestoration(self):
    checkpoint_directory = self.get_temp_dir()

    root = trackable_utils.Checkpoint()
    root.var = trackable_utils.add_variable(
        root, name="var", initializer=0.)
    optimizer = adam.AdamOptimizer(0.1)
    if context.executing_eagerly():
      optimizer.minimize(root.var.read_value)
    else:
      train_op = optimizer.minimize(root.var)
      # Note that `optimizer` has not been added as a dependency of
      # `root`. Create a one-off grouping so that slot variables for `root.var`
      # get initialized too.
      self.evaluate(trackable_utils.gather_initializers(
          trackable_utils.Checkpoint(root=root, optimizer=optimizer)))
      self.evaluate(train_op)
    self.evaluate(state_ops.assign(root.var, 12.))
    no_slots_path = root.save(os.path.join(checkpoint_directory, "no_slots"))
    root.optimizer = optimizer
    self.evaluate(state_ops.assign(root.var, 13.))
    self.evaluate(state_ops.assign(optimizer.get_slot(name="m", var=root.var),
                                   14.))
    slots_path = root.save(os.path.join(checkpoint_directory, "with_slots"))
    new_root = trackable_utils.Checkpoint()
    # Load the slot-containing checkpoint (deferred), then immediately overwrite
    # the non-slot variable (also deferred).
    slot_status = new_root.restore(slots_path)
    no_slot_status = new_root.restore(no_slots_path)
    with self.assertRaises(AssertionError):
      no_slot_status.assert_consumed()
    new_root.var = trackable_utils.add_variable(
        new_root, name="var", shape=[])
    no_slot_status.assert_consumed()
    no_slot_status.run_restore_ops()
    self.assertEqual(12., self.evaluate(new_root.var))
    new_root.optimizer = adam.AdamOptimizer(0.1)
    slot_status.assert_existing_objects_matched()
    with self.assertRaisesRegex(AssertionError, "beta1_power"):
      slot_status.assert_consumed()
    self.assertEqual(12., self.evaluate(new_root.var))
    if context.executing_eagerly():
      # Slot variables are only created with restoring initializers when
      # executing eagerly.
      self.assertEqual(14., self.evaluate(
          new_root.optimizer.get_slot(name="m", var=new_root.var)))
    else:
      self.assertIs(new_root.optimizer.get_slot(name="m", var=new_root.var),
                    None)
    if context.executing_eagerly():
      new_root.optimizer.minimize(new_root.var.read_value)
    else:
      train_op = new_root.optimizer.minimize(new_root.var)
      # The slot variable now exists; restore() didn't create it, but we should
      # now have a restore op for it.
      slot_status.run_restore_ops()
      self.assertEqual(14., self.evaluate(
          new_root.optimizer.get_slot(name="m", var=new_root.var)))
      self.evaluate(train_op)
    slot_status.assert_consumed()
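Example #2
`add_variable` works even when a `Trackable` subclass skips `super().__init__()`: `Trackable`'s own initialization runs implicitly.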
    def testInitNotCalled(self):
        class NoInit(tracking.AutoTrackable):
            def __init__(self):
                pass

        # __init__ for Trackable will be called implicitly.
        trackable_utils.add_variable(NoInit(), "var", shape=[])
Example #4
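Checks the checkpoint name of a variable reached through a child object: the object-graph path `leaf/v` plus the `.ATTRIBUTES/VARIABLE_VALUE` suffix.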
 def testNumberedPath(self):
   root = tracking.AutoTrackable()
   leaf = tracking.AutoTrackable()
   root.leaf = leaf
   trackable_utils.add_variable(leaf, name="v", shape=[])
   (named_variable,), _, _ = graph_view.ObjectGraphView(
       root).serialize_object_graph()
   self.assertEqual(r"leaf/v/.ATTRIBUTES/VARIABLE_VALUE", named_variable.name)
Example #5
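A helper that returns a variable's checkpoint name and verifies the name is legal as an op-name prefix.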
 def _get_checkpoint_name(self, name):
   root = tracking.AutoTrackable()
   trackable_utils.add_variable(
       root, name=name, shape=[1, 2], dtype=dtypes.float64)
   (named_variable,), _, _ = graph_view.ObjectGraphView(
       root).serialize_object_graph()
   with ops.name_scope("root/" + named_variable.name):
     pass  # Make sure we can use this as an op name if we prefix it.
   return named_variable.name
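Example #6
A variant of the helper in Example #5 that serializes through the private `trackable_utils._serialize_object_graph` instead of `graph_view.ObjectGraphView`.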
 def _get_checkpoint_name(self, name):
   root = tracking.AutoTrackable()
   trackable_utils.add_variable(
       root, name=name, shape=[1, 2], dtype=dtypes.float64)
   (named_variable,), _, _ = trackable_utils._serialize_object_graph(
       root, saveables_cache=None)
   with ops.name_scope("root/" + named_variable.name):
     pass  # Make sure we can use this as an op name if we prefix it.
   return named_variable.name
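Example #7
Dots in local names are escaped during serialization, so a child tracked under the name ".ATTRIBUTES" cannot collide with the reserved `.ATTRIBUTES` suffix.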
 def testLocalNameValidation(self):
     root = tracking.AutoTrackable()
     leaf = tracking.AutoTrackable()
     # Dots are escaped, which avoids conflicts with reserved names.
     root._track_trackable(leaf, name=".ATTRIBUTES")
     trackable_utils.add_variable(trackable=leaf, name="a", shape=[])
      (named_variable,), _, _ = graph_view.ObjectGraphView(
          root).serialize_object_graph()
     self.assertEqual("..ATTRIBUTES/a/.ATTRIBUTES/VARIABLE_VALUE",
                      named_variable.name)
Example #8
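The same checkpoint-name helper built on `module.Module` and `name_scope_v2`, again serializing through the private `_serialize_object_graph`.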
 def _get_checkpoint_name(self, name):
     root = module.Module()
     trackable_utils.add_variable(root,
                                  name=name,
                                  shape=[1, 2],
                                  dtype=dtypes.float64)
     (named_variable, ), _, _ = trackable_utils._serialize_object_graph(
         root, saveables_cache=None)
     with ops.name_scope_v2("root/" + named_variable.name):
         pass  # Make sure we can use this as an op name if we prefix it.
     return named_variable.name
Example #9
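Checks that `add_variable` honors explicit `dtype` and `shape` arguments, whether the initializer is a constant or an initializer class.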
 def testShapeDtype(self):
   root = tracking.AutoTrackable()
   v1 = trackable_utils.add_variable(
       root, name="v1", initializer=3., dtype=dtypes.float64)
   self.assertEqual(dtypes.float64, v1.dtype)
   v2 = trackable_utils.add_variable(
       root,
       name="v2",
       shape=[3],
       initializer=init_ops.ones_initializer,
       dtype=dtypes.float64)
   self.assertEqual(dtypes.float64, v2.dtype)
   self.assertAllEqual([1., 1., 1.], self.evaluate(v2))
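Example #10
Restores two checkpoints into two root objects that are later pointed at the same dependency; as the inline comment notes, the `restore()` issued last for the shared variable determines its final value.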
    def testOverlappingRestores(self):
        checkpoint_directory = self.get_temp_dir()
        save_root = trackable_utils.Checkpoint()
        save_root.dep = tracking.AutoTrackable()
        save_root.dep.var = trackable_utils.add_variable(save_root.dep,
                                                         name="var",
                                                         initializer=0.)
        self.evaluate(state_ops.assign(save_root.dep.var, 12.))
        first_path = save_root.save(os.path.join(checkpoint_directory,
                                                 "first"))
        self.evaluate(state_ops.assign(save_root.dep.var, 13.))
        second_path = save_root.save(
            os.path.join(checkpoint_directory, "second"))

        first_root = trackable_utils.Checkpoint()
        second_root = trackable_utils.Checkpoint()
        first_status = first_root.restore(first_path)
        second_status = second_root.restore(second_path)
        load_dep = tracking.AutoTrackable()
        load_dep.var = trackable_utils.add_variable(load_dep,
                                                    name="var",
                                                    shape=[])
        first_root.dep = load_dep
        first_status.assert_consumed()
        first_status.run_restore_ops()
        self.assertEqual(12., self.evaluate(load_dep.var))
        second_root.dep = load_dep
        second_status.assert_consumed()
        second_status.run_restore_ops()
        self.assertEqual(13., self.evaluate(load_dep.var))

        # Try again with the order of the restore() reversed. The last restore
        # determines the final value.
        first_root = trackable_utils.Checkpoint()
        second_root = trackable_utils.Checkpoint()
        second_status = second_root.restore(second_path)
        first_status = first_root.restore(first_path)
        load_dep = tracking.AutoTrackable()
        load_dep.var = trackable_utils.add_variable(load_dep,
                                                    name="var",
                                                    shape=[])
        first_root.dep = load_dep
        first_status.assert_consumed()
        first_status.run_restore_ops()
        self.assertEqual(12., self.evaluate(load_dep.var))
        second_root.dep = load_dep
        second_status.assert_consumed()
        second_status.run_restore_ops()
        self.assertEqual(12., self.evaluate(load_dep.var))
Example #11
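Covers `add_variable` argument validation and naming: passing `shape` alongside a scalar initializer raises, duplicate local names within one object raise, and checkpoint names stay independent of graph-level name uniquification.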
  def testAddVariable(self):
    obj = NonLayerTrackable()
    with self.assertRaisesRegex(ValueError, "do not specify shape"):
      trackable_utils.add_variable(
          obj, name="shape_specified_twice", shape=[], initializer=1)
    constant_initializer = trackable_utils.add_variable(
        obj, name="constant_initializer", initializer=1)
    with variable_scope.variable_scope("some_variable_scope"):
      ones_initializer = trackable_utils.add_variable(
          obj,
          name="ones_initializer",
          shape=[2],
          initializer=init_ops.ones_initializer(dtype=dtypes.float32))
    bare_initializer = trackable_utils.add_variable(
        obj,
        name="bare_initializer",
        shape=[2, 2],
        dtype=dtypes.float64,
        initializer=init_ops.zeros_initializer)

    # Even in graph mode, there are no naming conflicts between objects, only
    # naming conflicts within an object.
    other_duplicate = resource_variable_ops.ResourceVariable(
        name="duplicate", initial_value=1.)
    duplicate = trackable_utils.add_variable(
        obj, name="duplicate", shape=[])
    with self.assertRaisesRegex(ValueError, "'duplicate'.*already declared"):
      trackable_utils.add_variable(obj, name="duplicate", shape=[])

    self.evaluate(trackable_utils.gather_initializers(obj))
    self.assertEqual("constant_initializer:0", constant_initializer.name)
    self.assertEqual(1, self.evaluate(constant_initializer))
    self.assertEqual("some_variable_scope/ones_initializer:0",
                     ones_initializer.name)
    self.assertAllEqual([1, 1], self.evaluate(ones_initializer))
    self.assertAllEqual([[0., 0.],
                         [0., 0.]], self.evaluate(bare_initializer))
    self.assertEqual("a_variable:0", obj.a_variable.name)
    self.assertEqual("duplicate:0", other_duplicate.name)
    if context.executing_eagerly():
      # When executing eagerly, there's no uniquification of variable names. The
      # checkpoint name will be the same.
      self.assertEqual("duplicate:0", duplicate.name)
    else:
      # The .name attribute may be globally influenced, but the checkpoint name
      # won't be (tested below).
      self.assertEqual("duplicate_1:0", duplicate.name)
    named_variables, _, _ = (
        graph_view.ObjectGraphView(obj).serialize_object_graph())
    expected_checkpoint_names = (
        "a_variable/.ATTRIBUTES/VARIABLE_VALUE",
        "bare_initializer/.ATTRIBUTES/VARIABLE_VALUE",
        "constant_initializer/.ATTRIBUTES/VARIABLE_VALUE",
        "duplicate/.ATTRIBUTES/VARIABLE_VALUE",
        "ones_initializer/.ATTRIBUTES/VARIABLE_VALUE",
    )
    six.assertCountEqual(
        self, expected_checkpoint_names, [v.name for v in named_variables])
Example #12
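An application-level constructor (outside the TensorFlow test suite) that registers an optimizer via `_track_trackable` and a step counter via `add_variable`; `VariableSource`, `ObjectiveFunction`, `Config`, and `Optimizer` are types from the surrounding project.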
    def __init__(self,
                 variables: VariableSource,
                 objective: ObjectiveFunction,
                 config: Config = None):
        self.objective = objective
        self.config = config or Trainer.Config()
        self.optimizer = Optimizer(config=self.config.optimizer)
        self.get_variables = self._make_variables_extractor(variables)

        tf.Variable(0, name='num_completed_steps', dtype=tf.int64)
        self.loss = None

        # Setup trackable internal state that'll be saved/restored
        self._track_trackable(
            trackable=self.optimizer,
            name='optimizer'
        )
        with tf.device('/CPU:0'):
            self._num_completed_steps = add_variable(
                trackable=self,
                name='num_completed_steps',
                dtype=tf.int64,
                initializer=0,
                trainable=False
            )
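Example #13
Checkpointing survives a reference cycle between two objects; restoration works both deferred (variables created after `restore()`) and when the variables already exist.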
    def testDependencyLoop(self):
        # Note: this test creates garbage during eager execution because it
        # purposefully creates a reference cycle.
        first = trackable_utils.Checkpoint()
        second = trackable_utils.Checkpoint()
        first.second = second
        second.first = first
        first.v = trackable_utils.add_variable(first,
                                               "v1",
                                               initializer=[3., 1., 4.])
        second.v = trackable_utils.add_variable(second,
                                                "v2",
                                                initializer=[1., 1., 2., 3.])
        self.evaluate(trackable_utils.gather_initializers(first))
        checkpoint_directory = self.get_temp_dir()
        save_path = first.save(os.path.join(checkpoint_directory, "ckpt"))

        # Test deferred loading
        first_load = trackable_utils.Checkpoint()
        status = first_load.restore(save_path)
        second_load = tracking.AutoTrackable()
        first_load.second = second_load
        second_load.first = first_load
        with self.assertRaises(AssertionError):
            status.assert_consumed()
        first_load.v = trackable_utils.add_variable(first_load,
                                                    "v1",
                                                    shape=[3])
        second_load.v = trackable_utils.add_variable(second_load,
                                                     "v2",
                                                     shape=[4])
        status.assert_consumed()
        status.run_restore_ops()
        self.assertAllEqual([3., 1., 4.], self.evaluate(first_load.v))
        self.assertAllEqual([1., 1., 2., 3.], self.evaluate(second_load.v))

        # Test loading when variables have already been created
        self.evaluate(first_load.v.assign([2., 7., 1.]))
        self.assertAllEqual([2., 7., 1.], self.evaluate(first_load.v))
        self.evaluate(second_load.v.assign([2., 7., 1., 8.]))
        self.assertAllEqual([2., 7., 1., 8.], self.evaluate(second_load.v))
        status = first_load.restore(save_path).assert_consumed()
        status.run_restore_ops()
        self.assertAllEqual([3., 1., 4.], self.evaluate(first_load.v))
        self.assertAllEqual([1., 1., 2., 3.], self.evaluate(second_load.v))
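Example #14
Two saved objects may be restored into a single Python object: `dep_one` and `dep_two` both resolve to the same `AutoTrackable`, and both variables load.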
 def testObjectsCombined(self):
     # Currently fine to load two checkpoint objects into one Python object
     checkpoint_directory = self.get_temp_dir()
     save_root = trackable_utils.Checkpoint()
     save_root.dep_one = tracking.AutoTrackable()
     save_root.dep_two = tracking.AutoTrackable()
     trackable_utils.add_variable(save_root.dep_one,
                                  name="var1",
                                  initializer=32.,
                                  dtype=dtypes.float64)
     trackable_utils.add_variable(save_root.dep_two,
                                  name="var2",
                                  initializer=64.,
                                  dtype=dtypes.float64)
     self.evaluate(trackable_utils.gather_initializers(save_root))
     save_path = save_root.save(os.path.join(checkpoint_directory, "ckpt"))
     load_root = trackable_utils.Checkpoint()
     load_root.dep_one = tracking.AutoTrackable()
     load_root.dep_two = load_root.dep_one
     v1 = trackable_utils.add_variable(load_root.dep_one,
                                       name="var1",
                                       shape=[],
                                       dtype=dtypes.float64)
     v2 = trackable_utils.add_variable(load_root.dep_one,
                                       name="var2",
                                       shape=[],
                                       dtype=dtypes.float64)
     status = load_root.restore(
         save_path).assert_consumed().assert_existing_objects_matched()
     status.run_restore_ops()
     self.assertEqual(32., self.evaluate(v1))
     self.assertEqual(64., self.evaluate(v2))
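Example #15
The reverse is not allowed: one saved object cannot be split across two restored objects, so both `assert_consumed()` and `assert_existing_objects_matched()` raise.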
 def testAmbiguousLoad(self):
     # Not OK to split one checkpoint object into two
     checkpoint_directory = self.get_temp_dir()
     save_root = trackable_utils.Checkpoint()
     save_root.dep_one = tracking.AutoTrackable()
     save_root.dep_two = tracking.AutoTrackable()
     dep_three = tracking.AutoTrackable()
     save_root.dep_one.dep_three = dep_three
     save_root.dep_two.dep_three = dep_three
     trackable_utils.add_variable(dep_three, name="var", initializer=0.)
     self.evaluate(trackable_utils.gather_initializers(save_root))
     save_path = save_root.save(os.path.join(checkpoint_directory, "ckpt"))
     load_root = trackable_utils.Checkpoint()
     status = load_root.restore(save_path)
     load_root.dep_one = tracking.AutoTrackable()
     load_root.dep_two = tracking.AutoTrackable()
     load_root.dep_one.dep_three = tracking.AutoTrackable()
     load_root.dep_two.dep_three = tracking.AutoTrackable()
     trackable_utils.add_variable(load_root.dep_one.dep_three,
                                  name="var",
                                  initializer=0.)
     trackable_utils.add_variable(load_root.dep_two.dep_three,
                                  name="var",
                                  initializer=0.)
     with self.assertRaises(AssertionError):
         status.assert_consumed()
     with self.assertRaises(AssertionError):
         status.assert_existing_objects_matched()
Example #16
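The TF2 counterpart of Example #1, using `tf.train.Checkpoint` and `adam.Adam` with `apply_gradients`; graph-mode and eager differences (slot creation, error messages) are asserted explicitly.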
    def testDeferredSlotRestoration(self):
        with self.test_session():
            checkpoint_directory = self.get_temp_dir()

            root = tf.train.Checkpoint()
            root.var = trackable_utils.add_variable(root,
                                                    name="var",
                                                    initializer=0.)
            optimizer = adam.Adam(0.1)
            variables = [root.var]
            gradients = [1.]
            train_op = optimizer.apply_gradients(zip(gradients, variables))
            # Note that `optimizer` has not been added as a dependency of
            # `root`. Create a one-off grouping so that slot variables for `root.var`
            # get initialized too.
            self.evaluate(
                trackable_utils.gather_initializers(
                    tf.train.Checkpoint(root=root, optimizer=optimizer)))
            self.evaluate(train_op)
            self.evaluate(tf.compat.v1.assign(root.var, 12.))
            no_slots_path = root.save(
                os.path.join(checkpoint_directory, "no_slots"))
            root.optimizer = optimizer
            self.evaluate(tf.compat.v1.assign(root.var, 13.))
            self.evaluate(
                tf.compat.v1.assign(
                    optimizer.get_slot(slot_name="m", var=root.var), 14.))
            slots_path = root.save(
                os.path.join(checkpoint_directory, "with_slots"))
            new_root = tf.train.Checkpoint()
            # Load the slot-containing checkpoint (deferred), then immediately
            # overwrite the non-slot variable (also deferred).
            slot_status = new_root.restore(slots_path)
            no_slot_status = new_root.restore(no_slots_path)
            with self.assertRaises(AssertionError):
                no_slot_status.assert_consumed()
            new_root.var = trackable_utils.add_variable(new_root,
                                                        name="var",
                                                        shape=[])
            no_slot_status.assert_consumed()
            no_slot_status.run_restore_ops()
            self.assertEqual(12., self.evaluate(new_root.var))
            new_root.optimizer = adam.Adam(0.1)
            slot_status.assert_existing_objects_matched()
            if not tf.executing_eagerly():
                with self.assertRaisesRegex(AssertionError,
                                            "Unresolved object"):
                    slot_status.assert_consumed()
            self.assertEqual(12., self.evaluate(new_root.var))
            if tf.executing_eagerly():
                # Slot variables are only created with restoring initializers when
                # executing eagerly.
                self.assertEqual(
                    14.,
                    self.evaluate(
                        new_root.optimizer.get_slot(slot_name="m",
                                                    var=new_root.var)))
            else:
                # Slot variables are not created eagerly when graph building.
                with self.assertRaises(KeyError):
                    new_root.optimizer.get_slot(slot_name="m",
                                                var=new_root.var)
            variables = [new_root.var]
            gradients = [1.]
            train_op = new_root.optimizer.apply_gradients(
                zip(gradients, variables))
            # The slot variable now exists; restore() didn't create it, but we should
            # now have a restore op for it.
            slot_status.run_restore_ops()
            if not tf.executing_eagerly():
                # The train op hasn't run when graph building, so the slot variable has
                # its restored value. It has run in eager, so the value will
                # be different.
                self.assertEqual(
                    14.,
                    self.evaluate(
                        new_root.optimizer.get_slot(slot_name="m",
                                                    var=new_root.var)))
            self.evaluate(train_op)
            slot_status.assert_consumed()
Example #17
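A minimal non-Layer `Trackable` that registers a variable in its constructor; here the utility module is imported as `util`.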
 def __init__(self):
   super(NonLayerTrackable, self).__init__()
   self.a_variable = util.add_variable(
       self, name="a_variable", shape=[])
Example #18
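The same constructor pattern with the `trackable_utils` alias.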
 def __init__(self):
   super(NonLayerTrackable, self).__init__()
   self.a_variable = trackable_utils.add_variable(
       self, name="a_variable", shape=[])
 def build(self):
     self.var = trackable_utils.add_variable(self,
                                             "var",
                                             initializer=0.)
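Closing note: every example above revolves around the same pattern. `add_variable` both creates a variable and registers it as a named checkpoint dependency of a trackable object, which is what makes deferred restoration possible. Below is a minimal end-to-end sketch, not taken from any example above: it assumes eager execution and the TF-internal module paths used in these examples, which are private and have moved between releases.

    import os
    import tempfile

    from tensorflow.python.training.tracking import tracking
    from tensorflow.python.training.tracking import util as trackable_utils

    # Create a variable and register it as a checkpoint dependency of
    # `root` under the local name "v".
    root = tracking.AutoTrackable()
    root.v = trackable_utils.add_variable(root, name="v", initializer=3.)
    save_path = trackable_utils.Checkpoint(root=root).save(
        os.path.join(tempfile.mkdtemp(), "ckpt"))

    # Deferred restoration: restore into an empty object graph first, then
    # attach a matching variable; when executing eagerly its value is
    # filled in from the checkpoint as soon as the variable is created.
    new_root = tracking.AutoTrackable()
    status = trackable_utils.Checkpoint(root=new_root).restore(save_path)
    new_root.v = trackable_utils.add_variable(new_root, name="v", shape=[])
    print(new_root.v.numpy())  # -> 3.0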