Example #1
  def testNames(self):
    checkpoint_directory = self.get_temp_dir()
    checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")

    x1 = resource_variable_ops.ResourceVariable(2.)
    x2 = resource_variable_ops.ResourceVariable(3.)
    x3 = resource_variable_ops.ResourceVariable(4.)
    y = resource_variable_ops.ResourceVariable(5.)
    slots = containers.UniqueNameTracker()
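    # Duplicate base names are deduplicated: the second variable tracked as "x"
    # becomes "x_1", and the explicit "x_1" (already taken) becomes "x_1_1".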
    slots.track(x1, "x")
    slots.track(x2, "x")
    slots.track(x3, "x_1")
    slots.track(y, "y")
    self.evaluate((x1.initializer, x2.initializer, x3.initializer,
                   y.initializer))
    save_root = checkpointable_utils.Checkpoint(slots=slots)
    save_path = save_root.save(checkpoint_prefix)

    restore_slots = checkpointable.Checkpointable()
    restore_root = checkpointable_utils.Checkpoint(
        slots=restore_slots)
    status = restore_root.restore(save_path)
    restore_slots.x = resource_variable_ops.ResourceVariable(0.)
    restore_slots.x_1 = resource_variable_ops.ResourceVariable(0.)
    restore_slots.x_1_1 = resource_variable_ops.ResourceVariable(0.)
    restore_slots.y = resource_variable_ops.ResourceVariable(0.)
    status.assert_consumed().run_restore_ops()
    self.assertEqual(2., self.evaluate(restore_slots.x))
    self.assertEqual(3., self.evaluate(restore_slots.x_1))
    self.assertEqual(4., self.evaluate(restore_slots.x_1_1))
    self.assertEqual(5., self.evaluate(restore_slots.y))
Example #2
    def testSaveRestoreSplitDep(self):
        save_checkpoint = checkpointable_utils.Checkpoint(
            dep=SaveTensorSlicesAsDeps())
        self.evaluate(save_checkpoint.dep.combined.assign([1., 2., 3., 4.]))
        checkpoint_directory = self.get_temp_dir()
        checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
        save_path = save_checkpoint.save(checkpoint_prefix)

        regular_deps = HasRegularDeps()
        regular_restore_checkpoint = checkpointable_utils.Checkpoint(
            dep=regular_deps)
        regular_restore_checkpoint.restore(
            save_path).assert_consumed().run_restore_ops()
        self.assertAllEqual([1., 2.], self.evaluate(regular_deps.first_half))
        self.assertAllEqual([3., 4.], self.evaluate(regular_deps.second_half))

        one_dep = OnlyOneDep()
        one_dep_restore_checkpoint = checkpointable_utils.Checkpoint(
            dep=one_dep)
        status = one_dep_restore_checkpoint.restore(save_path)
        with self.assertRaises(AssertionError):
            # Missing the second dependency.
            status.assert_consumed()
        status.run_restore_ops()
        self.assertAllEqual([1., 2.], self.evaluate(one_dep.first_half))

        restore_checkpoint = checkpointable_utils.Checkpoint()
        status = restore_checkpoint.restore(save_path)
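        # Attaching the dependency after restore() still works: the saved values
        # are matched to the new objects as they are created (deferred restore).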
        restore_checkpoint.dep = SaveTensorSlicesAsDeps()
        status.assert_consumed().run_restore_ops()
        self.assertAllEqual([1., 2., 3., 4.],
                            self.evaluate(restore_checkpoint.dep.combined))
Example #3
    def testAnonymousVarsInInit(self):
        class Model(training.Model):
            def __init__(self):
                super(Model, self).__init__()
                self.w = resource_variable_ops.ResourceVariable(0.0)
                self.b = resource_variable_ops.ResourceVariable(0.0)
                self.vars = [self.w, self.b]

            def call(self, x):
                return x * self.w + self.b

        with context.eager_mode():
            model = Model()
            optimizer = adam.AdamOptimizer(learning_rate=0.05)
            checkpoint_directory = self.get_temp_dir()
            checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
            checkpoint = checkpointable_utils.Checkpoint(model=model,
                                                         optimizer=optimizer)
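            # The first save() runs before any apply_gradients(), so Adam's slot
            # variables do not yet exist when that checkpoint is written.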
            for _ in range(2):
                checkpoint.save(checkpoint_prefix)
                with backprop.GradientTape() as tape:
                    loss = (constant_op.constant(1.) -
                            model(constant_op.constant(1.)))**2
                grad = tape.gradient(loss, model.vars)
                optimizer.apply_gradients([(g, v)
                                           for g, v in zip(grad, model.vars)])
Example #4
  def testSaveRestoreMultipleIterator(self):
    checkpoint_directory = self.get_temp_dir()
    checkpoint_prefix = os.path.join(checkpoint_directory, 'ckpt')
    dataset = Dataset.from_tensor_slices([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])
    dataset = dataset.map(math_ops.square).batch(2)
    iterator_1 = datasets.Iterator(dataset)
    iterator_2 = datasets.Iterator(dataset)
    dataset_2 = Dataset.range(10)
    iterator_3 = datasets.Iterator(dataset_2)

    checkpoint = checkpointable_utils.Checkpoint(
        iterator_1=iterator_1, iterator_2=iterator_2, iterator_3=iterator_3)
    self.assertAllEqual([1, 4], iterator_1.get_next().numpy())
    self.assertEqual(0, iterator_3.get_next().numpy())
    self.assertEqual(1, iterator_3.get_next().numpy())
    self.assertEqual(2, iterator_3.get_next().numpy())

    save_path = checkpoint.save(checkpoint_prefix)
    self.assertAllEqual([1, 4], iterator_2.get_next().numpy())
    self.assertAllEqual([9, 16], iterator_2.get_next().numpy())
    self.assertEqual(3, iterator_3.get_next().numpy())
    checkpoint.restore(save_path)
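    # restore() returns each iterator to the position it had at save time.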
    self.assertAllEqual([9, 16], iterator_1.get_next().numpy())
    self.assertAllEqual([1, 4], iterator_2.get_next().numpy())
    self.assertEqual(3, iterator_3.get_next().numpy())
Example #5
  def testExample(self):
    class SlotManager(checkpointable.Checkpointable):

      def __init__(self):
        self.slotdeps = containers.UniqueNameTracker()
        slotdeps = self.slotdeps
        slots = []
        slots.append(slotdeps.track(
            resource_variable_ops.ResourceVariable(3.), "x"))
        slots.append(slotdeps.track(
            resource_variable_ops.ResourceVariable(4.), "y"))
        slots.append(slotdeps.track(
            resource_variable_ops.ResourceVariable(5.), "x"))
        self.slots = slots

    manager = SlotManager()
    self.evaluate([v.initializer for v in manager.slots])
    checkpoint = checkpointable_utils.Checkpoint(slot_manager=manager)
    checkpoint_directory = self.get_temp_dir()
    checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
    save_path = checkpoint.save(checkpoint_prefix)
    metadata = object_metadata(save_path)
    dependency_names = []
    for node in metadata.nodes:
      for child in node.children:
        dependency_names.append(child.local_name)
    six.assertCountEqual(
        self,
        dependency_names,
        ["x", "x_1", "y", "slot_manager", "slotdeps", "save_counter"])
Example #6
 def testAgnosticUsage(self):
     """Graph/eager agnostic usage."""
     # Does create garbage when executing eagerly due to ops.Graph() creation.
     num_training_steps = 10
     checkpoint_directory = self.get_temp_dir()
     checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
     for training_continuation in range(3):
         with ops.Graph().as_default(), self.test_session(
                 graph=ops.get_default_graph()), test_util.device(
                     use_gpu=True):
             model = MyModel()
             optimizer = adam.AdamOptimizer(0.001)
             root = checkpointable_utils.Checkpoint(
                 optimizer=optimizer,
                 model=model,
                 global_step=training_util.get_or_create_global_step())
             checkpoint_path = core_saver.latest_checkpoint(
                 checkpoint_directory)
             status = root.restore(save_path=checkpoint_path)
             input_value = constant_op.constant([[3.]])
             train_fn = functools.partial(optimizer.minimize,
                                          functools.partial(
                                              model, input_value),
                                          global_step=root.global_step)
             if not context.executing_eagerly():
                 train_fn = functools.partial(self.evaluate, train_fn())
             status.initialize_or_restore()
             for _ in range(num_training_steps):
                 train_fn()
             root.save(file_prefix=checkpoint_prefix)
             self.assertEqual(
                 (training_continuation + 1) * num_training_steps,
                 self.evaluate(root.global_step))
             self.assertEqual(training_continuation + 1,
                              self.evaluate(root.save_counter))
Example #7
    def test_checkpointable_save_restore(self):
        def _templated():
            v = variable_scope.get_variable(
                "v",
                shape=[1],
                initializer=init_ops.zeros_initializer(),
                use_resource=True)
            v2 = variable_scope.get_variable(
                "v2",
                shape=[1],
                initializer=init_ops.zeros_initializer(),
                use_resource=True)
            return v, v + 1., v2

        save_template = template.make_template("s1", _templated)
        v1_save, _, v2_save = save_template()
        optimizer = adam.AdamOptimizer(0.0)
        save_root = checkpointable_utils.Checkpoint(my_template=save_template,
                                                    optimizer=optimizer)
        optimizer.minimize(v1_save.read_value)
        self.evaluate([v.initializer for v in optimizer.variables()])
        self.evaluate(v1_save.assign([12.]))
        self.evaluate(v2_save.assign([14.]))
        checkpoint_directory = self.get_temp_dir()
        checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
        save_path = save_root.save(checkpoint_prefix)

        load_template = template.make_template("s2", _templated)
        load_optimizer = adam.AdamOptimizer(0.0)
        load_root = checkpointable_utils.Checkpoint(my_template=load_template,
                                                    optimizer=load_optimizer)
        status = load_root.restore(save_path)
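        # Variables created by the "s2" template are matched to the saved "s1"
        # values by dependency name rather than by variable scope.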
        var, var_plus_one, var2 = load_template()
        load_optimizer.minimize(var.read_value)
        self.assertEqual(2, len(load_template._checkpoint_dependencies))
        self.assertEqual("v", load_template._checkpoint_dependencies[0].name)
        self.assertEqual("v2", load_template._checkpoint_dependencies[1].name)
        status.assert_consumed().run_restore_ops()
        self.assertAllEqual([12.], self.evaluate(var))
        self.assertAllEqual([13.], self.evaluate(var_plus_one))
        self.assertAllEqual([14.], self.evaluate(var2))
Example #8
  def testRestoreExhaustedIterator(self):
    checkpoint_directory = self.get_temp_dir()
    checkpoint_prefix = os.path.join(checkpoint_directory, 'ckpt')
    dataset = Dataset.range(3)
    iterator = datasets.Iterator(dataset)

    checkpoint = checkpointable_utils.Checkpoint(iterator=iterator)
    self.assertEqual(0, iterator.get_next().numpy())
    self.assertEqual(1, iterator.get_next().numpy())
    save_path = checkpoint.save(checkpoint_prefix)
    self.assertEqual(2, iterator.get_next().numpy())
    checkpoint.restore(save_path)
    self.assertEqual(2, iterator.get_next().numpy())
Example #9
 def testSaveRestore(self):
   checkpoint_directory = self.get_temp_dir()
   checkpoint_prefix = os.path.join(checkpoint_directory, 'ckpt')
   dataset = Dataset.from_tensor_slices([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])
   dataset = dataset.map(math_ops.square).batch(2)
   iterator = datasets.Iterator(dataset)
   checkpoint = checkpointable_utils.Checkpoint(iterator=iterator)
   self.assertAllEqual([1, 4], iterator.get_next().numpy())
   save_path = checkpoint.save(checkpoint_prefix)
   self.assertAllEqual([9, 16], iterator.get_next().numpy())
   self.assertAllEqual([25, 36], iterator.get_next().numpy())
   checkpoint.restore(save_path)
   self.assertAllEqual([9, 16], iterator.get_next().numpy())
   self.assertAllEqual([25, 36], iterator.get_next().numpy())
Example #10
  def testSaveRestore(self):
    checkpoint_directory = self.get_temp_dir()
    checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
    mean = metrics.Mean()
    checkpoint = checkpointable_utils.Checkpoint(mean=mean)
    mean.build()
    mean._built = True
    self.evaluate(mean.init_variables())
    self.evaluate(mean(100.))
    self.evaluate(mean(200.))
    save_path = checkpoint.save(checkpoint_prefix)
    self.evaluate(mean(1000.))
    checkpoint.restore(save_path).assert_consumed().run_restore_ops()
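    # The 1000. update is discarded by the restore, so the mean is taken over
    # 100., 200., and 300. only.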
    self.evaluate(mean(300.))
    self.assertAllEqual(200., self.evaluate(mean.value()))

    restore_mean = metrics.Mean()
    restore_checkpoint = checkpointable_utils.Checkpoint(mean=restore_mean)
    status = restore_checkpoint.restore(save_path)
    restore_update = restore_mean(300.)
    status.assert_consumed().run_restore_ops()
    self.evaluate(restore_update)
    self.assertAllEqual(200., self.evaluate(restore_mean.value()))
    self.assertEqual(3, self.evaluate(restore_mean.denom))
Example #11
    def testWithDefun(self):
        num_training_steps = 2
        checkpoint_directory = self.get_temp_dir()
        checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
        for training_continuation in range(3):
            with ops.Graph().as_default(), self.test_session(
                    graph=ops.get_default_graph()), test_util.device(
                        use_gpu=True):
                model = MyModel()
                # Don't actually train so we can test variable values
                optimizer = adam.AdamOptimizer(0.)
                root = checkpointable_utils.Checkpoint(
                    optimizer=optimizer,
                    model=model,
                    global_step=training_util.get_or_create_global_step())
                checkpoint_path = core_saver.latest_checkpoint(
                    checkpoint_directory)
                status = root.restore(save_path=checkpoint_path)

                def train_fn():
                    @function.defun
                    def _call_model(x):
                        return model(x)

                    with backprop.GradientTape() as tape:
                        loss = _call_model(constant_op.constant([[3.]]))
                    gradients = tape.gradient(loss, model.variables)
                    return optimizer.apply_gradients(
                        zip(gradients, model.variables),
                        global_step=root.global_step)

                if not context.executing_eagerly():
                    train_fn = functools.partial(self.evaluate, train_fn())
                status.initialize_or_restore()
                for _ in range(num_training_steps):
                    train_fn()
                if training_continuation > 0:
                    status.assert_consumed()
                    self.assertAllClose([[42.]],
                                        self.evaluate(model.variables[0]))
                else:
                    self.evaluate(model.variables[0].assign([[42.]]))
                root.save(file_prefix=checkpoint_prefix)
                self.assertEqual(
                    (training_continuation + 1) * num_training_steps,
                    self.evaluate(root.global_step))
                self.assertEqual(training_continuation + 1,
                                 self.evaluate(root.save_counter))
Example #12
    def testMakeDotGraph(self):
        with context.eager_mode():
            input_value = constant_op.constant([[3.]])
            model = MyModel()
            optimizer = adam.AdamOptimizer(0.001)
            optimizer_step = resource_variable_ops.ResourceVariable(12)
            save_checkpoint = checkpointable_utils.Checkpoint(
                optimizer=optimizer,
                model=model,
                optimizer_step=optimizer_step)
            optimizer.minimize(functools.partial(model, input_value))
            checkpoint_directory = self.get_temp_dir()
            checkpoint_prefix = os.path.join(checkpoint_directory, 'ckpt')
            save_path = save_checkpoint.save(checkpoint_prefix)
            prefix = save_checkpoint.save(save_path)

        dot_graph_string = visualize.dot_graph_from_checkpoint(prefix)

        # The remainder of this test is more-or-less optional since it's so
        # dependent on pydot/platform/Python versions.
        if pydot is None:
            self.skipTest('pydot is required for the remainder of this test.')
        try:
            parsed, = pydot.graph_from_dot_data(dot_graph_string)
        except NameError as e:
            if "name 'dot_parser' is not defined" in str(e):
                self.skipTest("pydot isn't working")
            else:
                raise
        # Check that the graph isn't completely trivial
        self.assertEqual(
            '"model"',
            parsed.obj_dict['edges'][('N_0', 'N_1')][0]['attributes']['label'])
        image_path = os.path.join(self.get_temp_dir(), 'saved.svg')
        try:
            parsed.write_svg(image_path)
        except Exception as e:  # pylint: disable=broad-except
            # For some reason PyDot's "dot not available" error is an Exception, not
            # something more specific.
            if '"dot" not found in path' in str(e):
                self.skipTest("pydot won't save SVGs (dot not available)")
            else:
                raise
Example #13
 def _initialized_model(self):
     input_value = constant_op.constant([[3.]])
     model = MyModel()
     optimizer = adam.AdamOptimizer(0.001)
     optimizer_step = training_util.get_or_create_global_step()
     root_checkpointable = checkpointable_utils.Checkpoint(
         optimizer=optimizer, model=model, optimizer_step=optimizer_step)
     train_op = optimizer.minimize(functools.partial(model, input_value),
                                   global_step=optimizer_step)
     self.evaluate(
         checkpointable_utils.gather_initializers(root_checkpointable))
     self.evaluate(train_op)
     # A regular variable, a slot variable, and a non-slot Optimizer variable
     # with known values to check when loading.
     self.evaluate(model._named_dense.bias.assign([1.]))
     self.evaluate(
         optimizer.get_slot(var=model._named_dense.bias,
                            name="m").assign([2.]))
     beta1_power, _ = optimizer._get_beta_accumulators()
     self.evaluate(beta1_power.assign(3.))
     return root_checkpointable
Example #14
 def testDeferredRestorationUsageEager(self):
     """An idiomatic eager execution example."""
     num_training_steps = 10
     checkpoint_directory = self.get_temp_dir()
     checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
     for training_continuation in range(3):
         model = MyModel()
         optimizer = adam.AdamOptimizer(0.001)
         root = checkpointable_utils.Checkpoint(
             optimizer=optimizer,
             model=model,
             optimizer_step=training_util.get_or_create_global_step())
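          # On the first pass latest_checkpoint() returns None and restore() has
          # nothing to load, so training starts from the initial values.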
         root.restore(core_saver.latest_checkpoint(checkpoint_directory))
         for _ in range(num_training_steps):
             # TODO(allenl): Use a Dataset and serialize/checkpoint it.
             input_value = constant_op.constant([[3.]])
             optimizer.minimize(
                 lambda: model(input_value),  # pylint: disable=cell-var-from-loop
                 global_step=root.optimizer_step)
         root.save(file_prefix=checkpoint_prefix)
         self.assertEqual((training_continuation + 1) * num_training_steps,
                          root.optimizer_step.numpy())
Example #15
 def testUsageGraph(self):
     """Expected usage when graph building."""
     with context.graph_mode():
         num_training_steps = 10
         checkpoint_directory = self.get_temp_dir()
         checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
         for training_continuation in range(3):
             with ops.Graph().as_default():
                 model = MyModel()
                 optimizer = adam.AdamOptimizer(0.001)
                 root = checkpointable_utils.Checkpoint(
                     optimizer=optimizer,
                     model=model,
                     global_step=training_util.get_or_create_global_step())
                 input_value = constant_op.constant([[3.]])
                 train_op = optimizer.minimize(model(input_value),
                                               global_step=root.global_step)
                 checkpoint_path = core_saver.latest_checkpoint(
                     checkpoint_directory)
                 with self.test_session(
                         graph=ops.get_default_graph()) as session:
                     status = root.restore(save_path=checkpoint_path)
                     status.initialize_or_restore(session=session)
                     if checkpoint_path is None:
                         self.assertEqual(0, training_continuation)
                         with self.assertRaises(AssertionError):
                             status.assert_consumed()
                     else:
                         status.assert_consumed()
                     for _ in range(num_training_steps):
                         session.run(train_op)
                     root.save(file_prefix=checkpoint_prefix,
                               session=session)
                     self.assertEqual(
                         (training_continuation + 1) * num_training_steps,
                         session.run(root.global_step))
                     self.assertEqual(training_continuation + 1,
                                      session.run(root.save_counter))
Example #16
 def testSaveRestore(self):
     model = MyModel()
     optimizer = adam.AdamOptimizer(0.001)
     root_checkpointable = checkpointable_utils.Checkpoint(
         optimizer=optimizer, model=model)
     input_value = constant_op.constant([[3.]])
     if context.executing_eagerly():
         optimizer.minimize(lambda: model(input_value))
     else:
         train_op = optimizer.minimize(model(input_value))
         # TODO(allenl): Make initialization more pleasant when graph building.
         root_checkpointable.save_counter  # pylint: disable=pointless-statement
         self.evaluate(
             checkpointable_utils.gather_initializers(root_checkpointable))
         self.evaluate(train_op)
     prefix = os.path.join(self.get_temp_dir(), "ckpt")
     self.evaluate(state_ops.assign(model._named_dense.variables[1], [42.]))
     m_bias_slot = optimizer.get_slot(model._named_dense.variables[1], "m")
     self.evaluate(state_ops.assign(m_bias_slot, [1.5]))
     save_path = root_checkpointable.save(file_prefix=prefix)
     self.evaluate(state_ops.assign(model._named_dense.variables[1], [43.]))
     self.evaluate(state_ops.assign(root_checkpointable.save_counter, 3))
     optimizer_variables = self.evaluate(optimizer.variables())
     self.evaluate(state_ops.assign(m_bias_slot, [-2.]))
     # Immediate restoration
     status = root_checkpointable.restore(
         save_path=save_path).assert_consumed()
     status.run_restore_ops()
     self.assertAllEqual([42.],
                         self.evaluate(model._named_dense.variables[1]))
     self.assertAllEqual(1, self.evaluate(root_checkpointable.save_counter))
     self.assertAllEqual([1.5], self.evaluate(m_bias_slot))
     if not context.executing_eagerly():
         return  # Restore-on-create is only supported when executing eagerly
     on_create_model = MyModel()
     on_create_optimizer = adam.AdamOptimizer(
         0.001,
          # Preserve beta1_power and beta2_power when applying gradients so we can
         # test that they've been restored correctly.
         beta1=1.0,
         beta2=1.0)
     on_create_root = checkpointable_utils.Checkpoint(
         optimizer=on_create_optimizer, model=on_create_model)
     # Deferred restoration
     status = on_create_root.restore(save_path=save_path)
     on_create_model(constant_op.constant([[3.]]))  # create variables
     self.assertAllEqual(1, self.evaluate(on_create_root.save_counter))
     self.assertAllEqual([42.],
                         self.evaluate(
                             on_create_model._named_dense.variables[1]))
     on_create_m_bias_slot = on_create_optimizer.get_slot(
         on_create_model._named_dense.variables[1], "m")
     # Optimizer slot variables are created when the original variable is
     # restored.
     self.assertAllEqual([1.5], self.evaluate(on_create_m_bias_slot))
     self.assertAllEqual(optimizer_variables[2:],
                         self.evaluate(on_create_optimizer.variables()))
     dummy_var = resource_variable_ops.ResourceVariable([1.])
     on_create_optimizer.minimize(loss=dummy_var.read_value)
     status.assert_consumed()
     beta1_power, beta2_power = on_create_optimizer._get_beta_accumulators()
     self.assertAllEqual(optimizer_variables[0], self.evaluate(beta1_power))
     self.assertAllEqual(optimizer_variables[1], self.evaluate(beta2_power))
Example #17
    def testMultipleGraphsNonSlotVariables(self):
        with context.graph_mode():
            checkpoint_directory = self.get_temp_dir()
            checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
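            # A single optimizer instance is shared across two graphs; each
            # graph gets (and checkpoints) its own variables and slot variables.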
            optimizer = adam.AdamOptimizer(0.001)
            # Construct a model in one graph
            first_graph = ops.Graph()
            first_session = session_lib.Session(graph=first_graph)
            with first_graph.as_default(), first_session.as_default():
                first_variable = resource_variable_ops.ResourceVariable([1.])
                first_root_checkpointable = checkpointable_utils.Checkpoint(
                    optimizer=optimizer, variable=first_variable)
                train_op = optimizer.minimize(first_variable.read_value)
                self.evaluate(
                    checkpointable_utils.gather_initializers(
                        first_root_checkpointable))
                self.evaluate(train_op)
                self.evaluate(first_variable.assign([1.]))
                self.evaluate(
                    optimizer.get_slot(var=first_variable,
                                       name="m").assign([2.]))
                beta1_power, _ = optimizer._get_beta_accumulators()
                self.evaluate(beta1_power.assign(3.))

            # Save and load in a second graph
            second_graph = ops.Graph()
            with second_graph.as_default(), session_lib.Session(
                    graph=second_graph):
                second_variable = resource_variable_ops.ResourceVariable([1.])
                second_root_checkpointable = checkpointable_utils.Checkpoint(
                    optimizer=optimizer, variable=second_variable)
                train_op = optimizer.minimize(second_variable.read_value)
                second_root_checkpointable.restore(
                    None).initialize_or_restore()
                self.evaluate(train_op)
                self.evaluate(second_variable.assign([4.]))
                self.evaluate(
                    optimizer.get_slot(var=second_variable,
                                       name="m").assign([5.]))
                beta1_power, _ = optimizer._get_beta_accumulators()
                self.evaluate(beta1_power.assign(6.))
                save_path = second_root_checkpointable.save(checkpoint_prefix)
                self.evaluate(second_variable.assign([7.]))
                self.evaluate(
                    optimizer.get_slot(var=second_variable,
                                       name="m").assign([8.]))
                beta1_power, _ = optimizer._get_beta_accumulators()
                self.assertAllEqual(6., self.evaluate(beta1_power))
                status = second_root_checkpointable.restore(save_path)
                status.assert_consumed().run_restore_ops()
                self.assertAllEqual([4.], self.evaluate(second_variable))
                self.assertAllEqual([5.],
                                    self.evaluate(
                                        optimizer.get_slot(var=second_variable,
                                                           name="m")))
                beta1_power, _ = optimizer._get_beta_accumulators()
                self.assertAllEqual(6., self.evaluate(beta1_power))

            # Check that the first graph is unmolested
            with first_graph.as_default(), first_session.as_default():
                self.assertAllEqual([1.], self.evaluate(first_variable))
                self.assertAllEqual([2.],
                                    self.evaluate(
                                        optimizer.get_slot(var=first_variable,
                                                           name="m")))
                beta1_power, _ = optimizer._get_beta_accumulators()
                self.assertAllEqual(3., self.evaluate(beta1_power))
Example #18
    def testDeferredSlotRestoration(self):
        checkpoint_directory = self.get_temp_dir()

        root = checkpointable.Checkpointable()
        root.var = checkpointable_utils.add_variable(root,
                                                     name="var",
                                                     initializer=0.)
        optimizer = adam.AdamOptimizer(0.1)
        if context.executing_eagerly():
            optimizer.minimize(root.var.read_value)
        else:
            train_op = optimizer.minimize(root.var)
            # Note that `optimizer` has not been added as a dependency of
            # `root`. Create a one-off grouping so that slot variables for `root.var`
            # get initialized too.
            self.evaluate(
                checkpointable_utils.gather_initializers(
                    checkpointable_utils.Checkpoint(root=root,
                                                    optimizer=optimizer)))
            self.evaluate(train_op)
        self.evaluate(state_ops.assign(root.var, 12.))
        no_slots_path = checkpointable_utils.CheckpointableSaver(root).save(
            os.path.join(checkpoint_directory, "no_slots"))
        root.optimizer = optimizer
        self.evaluate(state_ops.assign(root.var, 13.))
        self.evaluate(
            state_ops.assign(optimizer.get_slot(name="m", var=root.var), 14.))
        slots_path = checkpointable_utils.CheckpointableSaver(root).save(
            os.path.join(checkpoint_directory, "with_slots"))
        new_root = checkpointable.Checkpointable()
        # Load the slot-containing checkpoint (deferred), then immediately overwrite
        # the non-slot variable (also deferred).
        slot_status = checkpointable_utils.CheckpointableSaver(
            new_root).restore(slots_path)
        no_slot_status = checkpointable_utils.CheckpointableSaver(
            new_root).restore(no_slots_path)
        with self.assertRaises(AssertionError):
            no_slot_status.assert_consumed()
        new_root.var = checkpointable_utils.add_variable(new_root,
                                                         name="var",
                                                         shape=[])
        no_slot_status.assert_consumed()
        no_slot_status.run_restore_ops()
        self.assertEqual(12., self.evaluate(new_root.var))
        new_root.optimizer = adam.AdamOptimizer(0.1)
        with self.assertRaisesRegexp(AssertionError, "beta1_power"):
            slot_status.assert_consumed()
        self.assertEqual(12., self.evaluate(new_root.var))
        if context.executing_eagerly():
            # Slot variables are only created with restoring initializers when
            # executing eagerly.
            self.assertEqual(
                14.,
                self.evaluate(
                    new_root.optimizer.get_slot(name="m", var=new_root.var)))
        else:
            self.assertIs(
                new_root.optimizer.get_slot(name="m", var=new_root.var), None)
        if context.executing_eagerly():
            new_root.optimizer.minimize(new_root.var.read_value)
        else:
            train_op = new_root.optimizer.minimize(new_root.var)
            # The slot variable now exists; restore() didn't create it, but we should
            # now have a restore op for it.
            slot_status.run_restore_ops()
            self.assertEqual(
                14.,
                self.evaluate(
                    new_root.optimizer.get_slot(name="m", var=new_root.var)))
            self.evaluate(train_op)
        slot_status.assert_consumed()
Example #19
 def testNamingWithOptimizer(self):
     input_value = constant_op.constant([[3.]])
     model = MyModel()
     # A nuisance Model using the same optimizer. Its slot variables should not
     # go in the checkpoint, since it is never depended on.
     other_model = MyModel()
     optimizer = adam.AdamOptimizer(0.001)
     optimizer_step = training_util.get_or_create_global_step()
     root_checkpointable = checkpointable_utils.Checkpoint(
         optimizer=optimizer, model=model, optimizer_step=optimizer_step)
     if context.executing_eagerly():
         optimizer.minimize(lambda: model(input_value),
                            global_step=optimizer_step)
         optimizer.minimize(lambda: other_model(input_value),
                            global_step=optimizer_step)
     else:
         train_op = optimizer.minimize(model(input_value),
                                       global_step=optimizer_step)
         optimizer.minimize(other_model(input_value),
                            global_step=optimizer_step)
         self.evaluate(
             checkpointable_utils.gather_initializers(root_checkpointable))
         self.evaluate(train_op)
     named_variables, serialized_graph = (
         checkpointable_utils._serialize_object_graph(root_checkpointable))
     expected_checkpoint_names = (
         # Created in the root node, so no prefix.
         "optimizer_step",
         "model/_second/kernel",
         "model/_named_dense/kernel",
         "model/_named_dense/bias",
         # non-Layer dependency of the model
         "model/_non_layer/a_variable",
         # The optimizer creates two non-slot variables
         "optimizer/beta1_power",
         "optimizer/beta2_power",
         # Slot variables
         "model/_second/kernel/.OPTIMIZER_SLOT/optimizer/m",
         "model/_second/kernel/.OPTIMIZER_SLOT/optimizer/v",
         "model/_named_dense/kernel/.OPTIMIZER_SLOT/optimizer/m",
         "model/_named_dense/kernel/.OPTIMIZER_SLOT/optimizer/v",
         "model/_named_dense/bias/.OPTIMIZER_SLOT/optimizer/m",
         "model/_named_dense/bias/.OPTIMIZER_SLOT/optimizer/v",
     )
     suffix = "/.ATTRIBUTES/VARIABLE_VALUE"
     expected_checkpoint_names = [
         name + suffix for name in expected_checkpoint_names
     ]
     six.assertCountEqual(self, expected_checkpoint_names,
                          named_variables.keys())
     # Check that we've mapped to the right variable objects (not exhaustive)
     self.assertEqual("global_step:0",
                      named_variables["optimizer_step" + suffix].name)
     self.assertEqual("my_model/dense_1/kernel:0",
                      named_variables["model/_second/kernel" + suffix].name)
     self.assertEqual(
         "my_model/dense/kernel:0",
         named_variables["model/_named_dense/kernel" + suffix].name)
     self.assertEqual(
         "beta1_power:0",
         named_variables["optimizer/beta1_power" + suffix].name)
     self.assertEqual(
         "beta2_power:0",
         named_variables["optimizer/beta2_power" + suffix].name)
     # Spot check the generated protocol buffers.
     self.assertEqual("optimizer",
                      serialized_graph.nodes[0].children[1].local_name)
     optimizer_node = serialized_graph.nodes[
         serialized_graph.nodes[0].children[1].node_id]
     self.assertEqual("beta1_power", optimizer_node.children[0].local_name)
     self.assertEqual(
         "beta1_power", serialized_graph.nodes[
             optimizer_node.children[0].node_id].attributes[0].full_name)
     self.assertEqual(
         "my_model/dense/kernel",
         serialized_graph.nodes[optimizer_node.slot_variables[
             0].original_variable_node_id].attributes[0].full_name)
     # We strip off the :0 suffix, as variable.name-based saving does.
     self.assertEqual(
         "my_model/dense/kernel/Adam",
         serialized_graph.nodes[optimizer_node.slot_variables[
             0].slot_variable_node_id].attributes[0].full_name)
     self.assertEqual(
         "my_model/dense/kernel/Adam:0",
         optimizer.get_slot(
             var=named_variables["model/_named_dense/kernel" + suffix],
             name="m").name)
     self.assertEqual(
         "model/_named_dense/kernel" + suffix,
         serialized_graph.nodes[optimizer_node.slot_variables[
             0].original_variable_node_id].attributes[0].checkpoint_key)
     self.assertEqual("m", optimizer_node.slot_variables[0].slot_name)
     self.assertEqual(
         "model/_named_dense/kernel/.OPTIMIZER_SLOT/optimizer/m" + suffix,
         serialized_graph.nodes[optimizer_node.slot_variables[
             0].slot_variable_node_id].attributes[0].checkpoint_key)
Example #20
 def test_checkpointing_not_implemented(self):
   checkpoint_directory = self.get_temp_dir()
   checkpoint = checkpointable_utils.Checkpoint(net=MyNetwork())
   with self.assertRaises(NotImplementedError):
     checkpoint.save(checkpoint_directory)
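
The snippets above exercise object-based checkpointing through internal test
utilities (checkpointable_utils.Checkpoint, checkpointable.Checkpointable, and
friends). For orientation, the following is a minimal sketch of the same
save/restore pattern written against the public tf.train.Checkpoint API, which
is essentially the public successor of this machinery. All names, values, and
the directory below are illustrative assumptions, not taken from the tests
above.

import tensorflow as tf

# Illustrative state to checkpoint (assumes TensorFlow 2.x eager execution).
step = tf.Variable(0, dtype=tf.int64)
bias = tf.Variable([0.0])

checkpoint = tf.train.Checkpoint(step=step, bias=bias)
manager = tf.train.CheckpointManager(
    checkpoint, directory="/tmp/ckpt_example", max_to_keep=3)

# Resume from the most recent checkpoint if one exists; otherwise start fresh.
checkpoint.restore(manager.latest_checkpoint)

for _ in range(10):
  step.assign_add(1)        # stand-in for a real training step
  bias.assign_add([0.1])

save_path = manager.save()  # e.g. "/tmp/ckpt_example/ckpt-1"

# Restore into a fresh object graph: values are matched by the structure of the
# Checkpoint's dependencies, just as in the tests above.
restored_bias = tf.Variable([0.0])
restored = tf.train.Checkpoint(step=tf.Variable(0, dtype=tf.int64),
                               bias=restored_bias)
restored.restore(save_path)
print(restored_bias.numpy())  # approximately [1.0]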