def test_checkpoint_comparison(self):
        """Checks SaveableState, TrackableState and the converter interoperate.

        A checkpoint is written with one object under key `a` and the other
        under key `b`, read back with the keys swapped, and finally read
        through a `SaveableCompatibilityConverter` wrapper under each key.
        """
        saveable = SaveableState(5.)
        trackable = TrackableState(10.)

        # Confirm the starting values before anything is written.
        self.assertEqual(5, self.evaluate(saveable.read()))
        self.assertEqual(10, self.evaluate(trackable.read()))

        ckpt_path = os.path.join(self.get_temp_dir(), "ckpt")
        writer = checkpoint.Checkpoint(a=saveable, b=trackable)
        writer.write(ckpt_path)

        # Reading with the keys swapped should exchange the two values, which
        # shows both object kinds can consume each other's saved data.
        swapped = checkpoint.Checkpoint(b=saveable, a=trackable)
        swap_status = swapped.read(ckpt_path)
        swap_status.assert_consumed()

        self.assertEqual(10, self.evaluate(saveable.read()))
        self.assertEqual(5, self.evaluate(trackable.read()))

        # The converted SaveableState must also be able to read either entry of
        # the checkpoint written above.
        target = SaveableState(0.0)
        wrapped = saveable_object_util.SaveableCompatibilityConverter(target)

        for key, expected in (("a", 5), ("b", 10)):
            status = checkpoint.Checkpoint(**{key: wrapped}).read(ckpt_path)
            status.assert_existing_objects_matched().expect_partial()
            self.assertEqual(expected, self.evaluate(target.read()))
    def testRestoreOrInitialize(self):
        """`restore_or_initialize` runs `init_fn` first, then restores a save."""
        root_dir = self.get_temp_dir()

        # Write the checkpoint that `init_fn` will restore from.
        seed_var = variables.Variable(2.0)
        seed_ckpt = util.Checkpoint(v=seed_var)
        self.evaluate(seed_var.initializer)
        seed_path = seed_ckpt.save(os.path.join(root_dir, "init"))

        # Build the manager around a fresh variable holding a different value.
        v = variables.Variable(1.0)
        ckpt = util.Checkpoint(v=v)

        def _init_from_seed():
            return ckpt.restore(seed_path).run_restore_ops()

        manager = checkpoint_management.CheckpointManager(
            ckpt,
            os.path.join(root_dir, "ckpt"),
            max_to_keep=None,
            init_fn=_init_from_seed)
        self.evaluate(v.initializer)

        # With nothing saved by the manager yet, the call returns None and the
        # variable takes the value from the seed checkpoint.
        first_result = manager.restore_or_initialize()
        self.assertIsNone(first_result)
        self.assertEqual(2.0, self.evaluate(v))

        # After the manager saves, the second call should restore from that
        # checkpoint and return a non-None result.
        manager.save()
        second_result = manager.restore_or_initialize()
        self.assertIsNotNone(second_result)
    def test_checkpointing(self):
        """Saves a parallel-device variable and restores it in three ways.

        The test is skipped unconditionally by the first statement
        (b/216201668), so the body below documents the intended behavior but
        does not currently run.
        """
        self.skipTest(
            "b/216201668: revisit parallel device and checkpointing.")

        prefix = os.path.join(self.get_temp_dir(), "ckpt")
        # Pack two different component values so each can be told apart later.
        different_values = self.device.pack(
            [constant_op.constant(-1.),
             constant_op.constant(3.)])
        with self.device:
            v = variables.Variable(different_values)
            checkpoint = tracking.Checkpoint(v=v)
        save_path = checkpoint.save(prefix)
        # Overwrite the variable so a successful restore is observable.
        with self.device:
            v.assign(constant_op.constant(0.))
        checkpoint.restore(save_path).assert_consumed()
        with self.device:
            outputs = self.device.unpack(v)
        self.assertAllClose([-1., 3.], outputs)

        # Restore is requested before the variable is attached to the checkpoint.
        with self.device:
            restore_on_create = tracking.Checkpoint()
            restore_on_create.restore(save_path)
            restore_on_create.v = variables.Variable(0.)
            outputs = self.device.unpack(restore_on_create.v)
        self.assertAllClose([-1., 3.], outputs)

        # Changing the number of devices / restoring into a single-device copy is OK
        single_device = tracking.Checkpoint(v=variables.Variable(0.))
        status = single_device.restore(save_path)
        status.assert_existing_objects_matched()
        self.assertAllClose(-1., single_device.v)
        with self.assertRaisesRegex(AssertionError, "parallel_component_1"):
            # There are parts of the variable that aren't restored into a
            # single-device copy.
            status.assert_consumed()
# Example #4
# 0
  def test_delayed_restore(self):
    """Restores a ShardedVariable that is attached after `restore` is called."""
    fname = os.path.join(self.get_temp_dir(), 'checkpoint')

    # Save a model whose four shards hold the values 0, 1, 2 and 3.
    saved_model = autotrackable.AutoTrackable()
    saved_shards = [variables_lib.Variable([value]) for value in range(4)]
    saved_model.s = sharded_variable.ShardedVariable(saved_shards)
    writer = util.Checkpoint(model=saved_model)
    writer.write(fname)

    # Request the restore while the target model still has no variables.
    restored_model = autotrackable.AutoTrackable()
    reader = util.Checkpoint(model=restored_model)
    reader.restore(fname)

    # Attach zero-valued shards afterwards; they should pick up saved values.
    restored_shards = [variables_lib.Variable([0]) for _ in range(4)]
    restored_model.s = sharded_variable.ShardedVariable(restored_shards)
    for index in range(4):
      self.assertAllEqual(
          self.evaluate(restored_model.s.variables[index]), [index])
# Example #5
# 0
    def testSaveRestoreNumpyState(self):
        """Round-trips NumPy attributes of `_NumpyState` through two saves."""
        prefix = os.path.join(self.get_temp_dir(), "ckpt")

        source = _NumpyState()
        writer = util.Checkpoint(numpy=source)
        source.a = numpy.ones([2, 2])
        # `b` is assigned twice before the first save; the zeros assigned last
        # are the value checked below.
        source.b = numpy.ones([2, 2])
        source.b = numpy.zeros([2, 2])
        source.c = numpy.int64(3)
        self.assertAllEqual(numpy.ones([2, 2]), source.a)
        self.assertAllEqual(numpy.zeros([2, 2]), source.b)
        self.assertEqual(3, source.c)
        first_save_path = writer.save(prefix)

        # Change `a` in place and replace `c`, then save a second checkpoint.
        source.a[1, 1] = 2.
        source.c = numpy.int64(4)
        second_save_path = writer.save(prefix)

        target = _NumpyState()
        reader = util.Checkpoint(numpy=target)
        first_status = reader.restore(first_save_path)
        first_status.initialize_or_restore()
        self.assertAllEqual(numpy.ones([2, 2]), target.a)
        self.assertAllEqual(numpy.zeros([2, 2]), target.b)
        self.assertEqual(3, target.c)

        # A local modification is undone by restoring the first save again.
        target.a[0, 0] = 42.
        self.assertAllEqual([[42., 1.], [1., 1.]], target.a)
        reader.restore(first_save_path).run_restore_ops()
        self.assertAllEqual(numpy.ones([2, 2]), target.a)

        # The second save carries the modified `a` and `c`.
        reader.restore(second_save_path).run_restore_ops()
        self.assertAllEqual([[1., 1.], [1., 2.]], target.a)
        self.assertAllEqual(numpy.zeros([2, 2]), target.b)
        self.assertEqual(4, target.c)
# Example #6
# 0
 def testAssertConsumedWithUnusedPythonState(self):
     """Expects `assert_consumed` to pass for a plain restored Trackable.

     The saved object carries a `get_config` attribute; the object restored
     into is a bare `base.Trackable()` without one.
     """
     source = base.Trackable()
     source.get_config = lambda: {}
     writer = util.Checkpoint(obj=source)
     ckpt_prefix = os.path.join(self.get_temp_dir(), "ckpt")
     save_path = writer.save(ckpt_prefix)

     reader = util.Checkpoint(obj=base.Trackable())
     status = reader.restore(save_path)
     status.assert_consumed()
  def testDeferredSlotRestoration(self):
    """Restores optimizer slot variables that are created after `restore`.

    Two checkpoints are saved from `root`: "no_slots" before the optimizer is
    attached to `root`, and "with_slots" after. Both restores are then
    requested on an empty `new_root`, and the variable and the optimizer are
    only added afterwards.
    """
    checkpoint_directory = self.get_temp_dir()

    root = trackable_utils.Checkpoint()
    root.var = trackable_utils.add_variable(
        root, name="var", initializer=0.)
    optimizer = adam.AdamOptimizer(0.1)
    if context.executing_eagerly():
      optimizer.minimize(root.var.read_value)
    else:
      train_op = optimizer.minimize(root.var)
      # Note that `optimizer` has not been added as a dependency of
      # `root`. Create a one-off grouping so that slot variables for `root.var`
      # get initialized too.
      self.evaluate(trackable_utils.gather_initializers(
          trackable_utils.Checkpoint(root=root, optimizer=optimizer)))
      self.evaluate(train_op)
    # First save: variable value 12, optimizer not yet attached to `root`.
    self.evaluate(state_ops.assign(root.var, 12.))
    no_slots_path = root.save(os.path.join(checkpoint_directory, "no_slots"))
    # Second save: variable value 13 and the "m" slot set to 14.
    root.optimizer = optimizer
    self.evaluate(state_ops.assign(root.var, 13.))
    self.evaluate(state_ops.assign(optimizer.get_slot(name="m", var=root.var),
                                   14.))
    slots_path = root.save(os.path.join(checkpoint_directory, "with_slots"))
    new_root = trackable_utils.Checkpoint()
    # Load the slot-containing checkpoint (deferred), then immediately overwrite
    # the non-slot variable (also deferred).
    slot_status = new_root.restore(slots_path)
    no_slot_status = new_root.restore(no_slots_path)
    # Nothing exists on `new_root` yet, so the checkpoint cannot be consumed.
    with self.assertRaises(AssertionError):
      no_slot_status.assert_consumed()
    new_root.var = trackable_utils.add_variable(
        new_root, name="var", shape=[])
    no_slot_status.assert_consumed()
    no_slot_status.run_restore_ops()
    self.assertEqual(12., self.evaluate(new_root.var))
    new_root.optimizer = adam.AdamOptimizer(0.1)
    slot_status.assert_existing_objects_matched()
    with self.assertRaisesRegex(AssertionError, "beta1_power"):
      slot_status.assert_consumed()
    # The variable keeps the value from the "no_slots" restore requested last.
    self.assertEqual(12., self.evaluate(new_root.var))
    if context.executing_eagerly():
      # Slot variables are only created with restoring initializers when
      # executing eagerly.
      self.assertEqual(14., self.evaluate(
          new_root.optimizer.get_slot(name="m", var=new_root.var)))
    else:
      self.assertIs(new_root.optimizer.get_slot(name="m", var=new_root.var),
                    None)
    if context.executing_eagerly():
      new_root.optimizer.minimize(new_root.var.read_value)
    else:
      train_op = new_root.optimizer.minimize(new_root.var)
      # The slot variable now exists; restore() didn't create it, but we should
      # now have a restore op for it.
      slot_status.run_restore_ops()
      self.assertEqual(14., self.evaluate(
          new_root.optimizer.get_slot(name="m", var=new_root.var)))
      self.evaluate(train_op)
    slot_status.assert_consumed()
# Example #8
# 0
  def test_save_restore_different_partitions(self):
    """Round-trips a ShardedVariable between 4-shard and 1-shard layouts.

    A 4-shard variable is written and restored into a single shard; the single
    shard is then given new values, written to the same path, and restored
    back into the original four shards.
    """
    fname = os.path.join(self.get_temp_dir(), 'checkpoint')
    variables = [
        variables_lib.Variable([0]),
        variables_lib.Variable([1]),
        variables_lib.Variable([2]),
        variables_lib.Variable([3])
    ]
    s = sharded_variable.ShardedVariable(variables, name='s')

    cp = util.Checkpoint(s=s)
    cp.write(fname)

    variables2 = [variables_lib.Variable([0, 0, 0, 0])]
    s2 = sharded_variable.ShardedVariable(variables2, name='s')

    # Restore from 4 partitions into 1.
    cp2 = util.Checkpoint(s=s2)
    cp2.restore(fname)
    self.assertAllEqual(self.evaluate(cp2.s.variables[0]), [0, 1, 2, 3])

    self.evaluate(cp2.s.variables[0].assign([5, 10, 15, 20]))
    cp2.write(fname)

    # Restore 1 partition into 4.
    cp.restore(fname)
    # Use assertAllEqual, as in the 4-into-1 check above: the evaluated shards
    # are arrays, and plain assertEqual only works on them by relying on the
    # truthiness of a one-element comparison result.
    for index, expected in enumerate([5, 10, 15, 20]):
      self.assertAllEqual(self.evaluate(cp.s.variables[index]), [expected])
    def testCheckpoint(self, delayed, restore_shards):
        """Writes a 2-shard variable and restores it with `restore_shards` shards.

        Args:
          delayed: If true, `restore` is called before `build` creates the
            variable; otherwise the variable is built first.
          restore_shards: Shard count given to the partitioner of the strategy
            used for restoring. Values 2 and 4 have explicit value checks.
        """

        if test_util.is_xla_enabled() and not delayed and restore_shards == 4:
            self.skipTest(
                "TODO(b/202760274): Would raise an error that is to be "
                "investigated.")

        def make_variable(name, shape, dtype, initializer):
            # Wrap the initializer in a callable so it is invoked with the
            # requested shape and dtype when the variable is created.
            initial_value = functools.partial(initializer, shape, dtype=dtype)
            return variables.Variable(name=name,
                                      initial_value=initial_value,
                                      shape=shape,
                                      dtype=dtype)

        class Model(autotrackable.AutoTrackable):
            def build(self):
                self.w = self._add_variable_with_custom_getter(
                    "w",
                    shape=(4, ),
                    initializer=init_ops_v2.Ones(),
                    getter=make_variable)

        # Save under a strategy that partitions into exactly two shards.
        strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
            self.cluster_resolver, sharded_variable.FixedShardsPartitioner(2))
        ckpt_dir = os.path.join(self.get_temp_dir(), "checkpoint")

        with strategy.scope():
            model1 = Model()
            model1.build()
            self.assertIsInstance(model1.w, sharded_variable.ShardedVariable)
            self.assertLen(model1.w.variables, 2)
            model1.w.assign([1., 2., 3., 4.])

            cp1 = tracking_util.Checkpoint(model=model1)
            cp1.write(ckpt_dir)

        # Restore under a second strategy with the requested shard count.
        strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
            self.cluster_resolver,
            sharded_variable.FixedShardsPartitioner(restore_shards))

        with strategy.scope():
            model2 = Model()
            cp2 = tracking_util.Checkpoint(model=model2)
            if delayed:
                cp2.restore(ckpt_dir)
                model2.build()
            else:
                model2.build()
                cp2.restore(ckpt_dir)
            self.assertIsInstance(model2.w, sharded_variable.ShardedVariable)
            self.assertLen(model2.w.variables, restore_shards)
            # Per-shard values are only checked for the 2- and 4-shard cases.
            if restore_shards == 2:
                self.assertAllEqual(model2.w.variables[0], [1., 2.])
                self.assertAllEqual(model2.w.variables[1], [3., 4.])
            elif restore_shards == 4:
                self.assertAllEqual(model2.w.variables[0], [1.])
                self.assertAllEqual(model2.w.variables[1], [2.])
                self.assertAllEqual(model2.w.variables[2], [3.])
                self.assertAllEqual(model2.w.variables[3], [4.])
# Example #10
# 0
    def test_forward_compatibility(self):
        """Loads a checkpoint written by a new-style Trackable into a legacy one.

        `NewTrackable` serializes through `_serialize_to_tensors` and is
        decorated with the legacy saveable name "foo". `DeprecatedTrackable`
        exposes the same two variables through a `SaveableObject` registered
        under "foo". The checkpoint written by the former must be fully
        consumed by the latter, both with the conversion flag off and on.
        """
        class _MultiSpecSaveable(saveable_object.SaveableObject):
            def __init__(self, obj, name):
                self.obj = obj
                specs = [
                    saveable_object.SaveSpec(obj.a, "", name + "-a"),
                    saveable_object.SaveSpec(obj.b, "", name + "-b")
                ]
                super(_MultiSpecSaveable, self).__init__(None, specs, name)

            def restore(self, restored_tensors, restored_shapes):
                del restored_shapes  # Unused.
                self.obj.a.assign(restored_tensors[0])
                self.obj.b.assign(restored_tensors[1])

        class DeprecatedTrackable(base.Trackable):
            def __init__(self):
                self.a = variables.Variable(1.0)
                self.b = variables.Variable(2.0)

            def _gather_saveables_for_checkpoint(self):
                return {"foo": lambda name: _MultiSpecSaveable(self, name)}

        @saveable_compat.legacy_saveable_name("foo")
        class NewTrackable(base.Trackable):
            def __init__(self):
                self.a = variables.Variable(3.0)
                self.b = variables.Variable(4.0)

            def _serialize_to_tensors(self):
                return {"-a": self.a, "-b": self.b}

            def _restore_from_tensors(self, restored_tensors):
                self.a.assign(restored_tensors["-a"])
                self.b.assign(restored_tensors["-b"])

        new = NewTrackable()

        # The conversion flag is toggled below and is not otherwise reset, so
        # register a cleanup that turns it off again even if an assertion
        # fails; this keeps the setting from leaking into other tests.
        self.addCleanup(saveable_compat.force_checkpoint_conversion, False)

        # Test with the checkpoint conversion flag disabled (normal compatibility).
        saveable_compat.force_checkpoint_conversion(False)
        checkpoint_path = os.path.join(self.get_temp_dir(), "ckpt")
        checkpoint.Checkpoint(new).write(checkpoint_path)

        dep = DeprecatedTrackable()
        checkpoint.Checkpoint(dep).read(checkpoint_path).assert_consumed()
        self.assertEqual(3, self.evaluate(dep.a))
        self.assertEqual(4, self.evaluate(dep.b))

        # Now test with the checkpoint conversion flag enabled (forward compat).
        # The deprecated object will try to load from the new checkpoint.
        saveable_compat.force_checkpoint_conversion()
        checkpoint_path = os.path.join(self.get_temp_dir(), "ckpt2")
        checkpoint.Checkpoint(new).write(checkpoint_path)

        dep = DeprecatedTrackable()
        checkpoint.Checkpoint(dep).read(checkpoint_path).assert_consumed()
        self.assertEqual(3, self.evaluate(dep.a))
        self.assertEqual(4, self.evaluate(dep.b))
    def test_checkpoint_restore_before_variable_creation(self):
        """Restores a TPUEmbedding checkpoint before its variables are built.

        `module1` is initialized with ones and saved. `module2` is configured
        with a zeros initializer, has `restore` requested on it, and only then
        builds its embedding; its table is expected to hold ones.
        """
        self.skip_if_oss()

        class TestModule(module.Module):
            def __init__(self, initializer, rows):
                self._initializer = initializer
                self._rows = rows

                table = tpu_embedding_v2_utils.TableConfig(
                    vocabulary_size=self._rows,
                    dim=4,
                    initializer=self._initializer,
                    combiner='sum',
                    name='table')
                feature_config = (tpu_embedding_v2_utils.FeatureConfig(
                    table=table, name='feature'), )
                optimizer = tpu_embedding_v2_utils.SGD()

                self.tpu_embedding = tpu_embedding_v2.TPUEmbedding(
                    feature_config, optimizer)

            def create_embedding(self):
                # We aren't training so batch_size here doesn't matter.
                self.tpu_embedding.build(64)

        strategy = self._get_strategy()
        with strategy.scope():
            module1 = TestModule(init_ops_v2.Ones(),
                                 strategy.num_replicas_in_sync * 2)
            module1.create_embedding()

        checkpoint = util.Checkpoint(test_module=module1)
        checkpoint.save(self._get_tmpdir('restore_before_create', 'save'))

        # Reinitialize the tpu
        strategy = self._get_strategy()

        with strategy.scope():
            module2 = TestModule(init_ops_v2.Zeros(),
                                 strategy.num_replicas_in_sync * 2)

        # The restore targets 'save-1' while the save above used prefix 'save'.
        checkpoint = util.Checkpoint(test_module=module2)
        checkpoint.restore(self._get_tmpdir('restore_before_create', 'save-1'))

        # Variables are created only now, after the restore was requested.
        with strategy.scope():
            module2.create_embedding()

        def get_values(mid):
            # Reads the first shard of the 'table' parameters as a NumPy array.
            return mid._variables['table']['parameters'].variables[0].numpy()

        self.assertAllClose(np.ones((strategy.num_replicas_in_sync * 2, 4)),
                            get_values(module2.tpu_embedding))

        # Fetch the values from the TPU to check that they are the same.
        module2.tpu_embedding._retrieve_variables()

        self.assertAllClose(np.ones((strategy.num_replicas_in_sync * 2, 4)),
                            get_values(module2.tpu_embedding))
  def testMultipleGraphsNonSlotVariables(self):
    """Shares one optimizer across two graphs and checks they stay independent.

    The same `AdamOptimizer` is used to build a model in two separate graphs.
    Values are saved and restored in the second graph, and the first graph's
    variable, "m" slot and `beta1_power` are then checked to be unchanged.
    """
    with context.graph_mode():
      checkpoint_directory = self.get_temp_dir()
      checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
      optimizer = adam.AdamOptimizer(0.001)
      # Construct a model in one graph
      first_graph = ops.Graph()
      first_session = session_lib.Session(graph=first_graph)
      with first_graph.as_default(), first_session.as_default():
        first_variable = resource_variable_ops.ResourceVariable([1.])
        first_root_trackable = trackable_utils.Checkpoint(
            optimizer=optimizer, variable=first_variable)
        train_op = optimizer.minimize(first_variable.read_value)
        self.evaluate(trackable_utils.gather_initializers(
            first_root_trackable))
        self.evaluate(train_op)
        # Pin known values (1, 2, 3) to compare against at the end of the test.
        self.evaluate(first_variable.assign([1.]))
        self.evaluate(optimizer.get_slot(
            var=first_variable, name="m").assign([2.]))
        beta1_power, _ = optimizer._get_beta_accumulators()
        self.evaluate(beta1_power.assign(3.))

      # Save and load in a second graph
      second_graph = ops.Graph()
      with second_graph.as_default(), session_lib.Session(graph=second_graph):
        second_variable = resource_variable_ops.ResourceVariable([1.])
        second_root_trackable = trackable_utils.Checkpoint(
            optimizer=optimizer, variable=second_variable)
        train_op = optimizer.minimize(second_variable.read_value)
        second_root_trackable.restore(None).initialize_or_restore()
        self.evaluate(train_op)
        # Values 4, 5 and 6 are the ones captured by the save below.
        self.evaluate(second_variable.assign([4.]))
        self.evaluate(optimizer.get_slot(
            var=second_variable, name="m").assign([5.]))
        beta1_power, _ = optimizer._get_beta_accumulators()
        self.evaluate(beta1_power.assign(6.))
        save_path = second_root_trackable.save(checkpoint_prefix)
        # Overwrite the variable and slot so the restore has a visible effect.
        self.evaluate(second_variable.assign([7.]))
        self.evaluate(optimizer.get_slot(
            var=second_variable, name="m").assign([8.]))
        beta1_power, _ = optimizer._get_beta_accumulators()
        self.assertAllEqual(6., self.evaluate(beta1_power))
        status = second_root_trackable.restore(save_path)
        status.assert_consumed().run_restore_ops()
        self.assertAllEqual([4.], self.evaluate(second_variable))
        self.assertAllEqual([5.], self.evaluate(optimizer.get_slot(
            var=second_variable, name="m")))
        beta1_power, _ = optimizer._get_beta_accumulators()
        self.assertAllEqual(6., self.evaluate(beta1_power))

      # Check that the first graph is unmolested
      with first_graph.as_default(), first_session.as_default():
        self.assertAllEqual([1.], self.evaluate(first_variable))
        self.assertAllEqual([2.], self.evaluate(optimizer.get_slot(
            var=first_variable, name="m")))
        beta1_power, _ = optimizer._get_beta_accumulators()
        self.assertAllEqual(3., self.evaluate(beta1_power))
# Example #13
# 0
    def testDocstringExample(self):
        """Restoring undoes an in-place edit and fills a fresh `_NumpyState`."""
        state = _NumpyState()
        ckpt = util.Checkpoint(numpy_arrays=state)
        state.x = numpy.zeros([3, 4])
        ckpt_prefix = os.path.join(self.get_temp_dir(), "ckpt")
        save_path = ckpt.save(ckpt_prefix)

        # Edit the array in place after saving; restore should reset it.
        state.x[1, 1] = 4.
        ckpt.restore(save_path)
        self.assertAllEqual(numpy.zeros([3, 4]), state.x)

        # A brand-new state object receives `x` from the same checkpoint.
        fresh_ckpt = util.Checkpoint(numpy_arrays=_NumpyState())
        fresh_ckpt.restore(save_path)
        restored_x = fresh_ckpt.numpy_arrays.x
        self.assertAllEqual(numpy.zeros([3, 4]), restored_x)
    def test_trackable_save_restore(self):
        """Saves template-created variables and restores them into a new template.

        Variables made inside template "s1" are saved together with an
        optimizer. A second template "s2" is restored from that checkpoint
        before its variables exist, and the restored values are checked.
        """
        def _templated():
            v = variable_scope.get_variable(
                "v",
                shape=[1],
                initializer=init_ops.zeros_initializer(),
                use_resource=True)
            v2 = variable_scope.get_variable(
                "v2",
                shape=[1],
                initializer=init_ops.zeros_initializer(),
                use_resource=True)
            manual = _ManualScope()
            return v, v + 1., v2, manual, manual()

        save_template = template.make_template("s1", _templated)
        v1_save, _, v2_save, manual_scope, manual_scope_v = save_template()
        # Compare by id() so the check is on object identity, not equality.
        six.assertCountEqual(self, [
            id(obj) for obj in
            [v1_save, v2_save, manual_scope, manual_scope_v, save_template]
        ], [id(obj) for obj in trackable_utils.list_objects(save_template)])
        self.assertDictEqual({"in_manual_scope": manual_scope_v},
                             manual_scope._trackable_children())
        optimizer = adam.AdamOptimizer(0.0)
        save_root = trackable_utils.Checkpoint(my_template=save_template,
                                               optimizer=optimizer)
        optimizer.minimize(v1_save.read_value)
        self.evaluate([v.initializer for v in save_template.variables])
        self.evaluate([v.initializer for v in optimizer.variables()])
        self.evaluate(v1_save.assign([12.]))
        self.evaluate(v2_save.assign([14.]))
        checkpoint_directory = self.get_temp_dir()
        checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
        save_path = save_root.save(checkpoint_prefix)

        load_template = template.make_template("s2", _templated)
        load_optimizer = adam.AdamOptimizer(0.0)
        load_root = trackable_utils.Checkpoint(my_template=load_template,
                                               optimizer=load_optimizer)
        # Restore is requested before the template has been called, i.e.
        # before "s2" has created any variables.
        status = load_root.restore(save_path)
        var, var_plus_one, var2, _, _ = load_template()
        load_optimizer.minimize(var.read_value)
        self.assertEqual(3, len(load_template._trackable_children()))
        self.assertEqual(set(["v", "v2", "ManualScope"]),
                         load_template._trackable_children().keys())
        status.assert_consumed().run_restore_ops()
        self.assertAllEqual([12.], self.evaluate(var))
        self.assertAllEqual([13.], self.evaluate(var_plus_one))
        self.assertAllEqual([14.], self.evaluate(var2))
# Example #15
# 0
 def testAssertConsumedFailsWithUsedPythonState(self):
     """Expects `assert_consumed` to fail and mention `foo_attr`.

     The saved object provides a `PythonStringStateSaveable` under
     "foo_attr"; the object restored into is a bare `base.Trackable()`.
     """
     source = base.Trackable()
     saveable_factory = functools.partial(
         base.PythonStringStateSaveable,
         state_callback=lambda: "",
         restore_callback=lambda x: None)
     saveables = {"foo_attr": saveable_factory}
     source._gather_saveables_for_checkpoint = lambda: saveables

     writer = util.Checkpoint(obj=source)
     ckpt_prefix = os.path.join(self.get_temp_dir(), "ckpt")
     save_path = writer.save(ckpt_prefix)

     reader = util.Checkpoint(obj=base.Trackable())
     status = reader.restore(save_path)
     expected_failure = self.assertRaisesRegex(AssertionError, "foo_attr")
     with expected_failure:
         status.assert_consumed()
  def test_spmd_model_checkpointing(self):
    """Saves SPMD model weights, changes them, and restores the saved ones.

    The model is created with weights `w1` and checkpointed. The weights are
    then replaced with `w2` and the output checked, after which the checkpoint
    is restored and the output is expected to match `w1` again.
    """

    class LinearModel(module.Module):

      def __init__(self, w):
        super(LinearModel, self).__init__()
        self.w = variables.Variable(w)

      def __call__(self, x):
        return math_ops.matmul(x, self.w)

      def change_weights_op(self, w_new):
        return self.w.assign(w_new)

    batch_size = 32
    num_feature_in = 16
    num_feature_out = 8
    w1 = random_ops.random_uniform((num_feature_in, num_feature_out),
                                   dtype=dtypes.float32)
    w2 = random_ops.random_uniform((num_feature_in, num_feature_out),
                                   dtype=dtypes.float32)
    x = random_ops.random_uniform((batch_size, num_feature_in),
                                  dtype=dtypes.float32)

    strategy, num_replicas = get_tpu_strategy(enable_spmd=True)
    with strategy.scope():
      model = LinearModel(w1)

    checkpoint_dir = self.get_temp_dir()
    checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
    checkpoint = util.Checkpoint(model=model)

    @def_function.function
    def step_fn(x):
      x = strategy.experimental_split_to_logical_devices(x, [1, 2])
      return model(x)

    with self.cached_session() as sess:
      self.evaluate(variables.global_variables_initializer())
      # Checkpoint the initial weights (w1) before they are replaced.
      checkpoint.save(file_prefix=checkpoint_prefix)

      self.evaluate(model.change_weights_op(w2))
      result = strategy.run(step_fn, args=(x,))
      # The summed result is compared against the single-replica product
      # scaled by the replica count.
      self.assertAllClose(
          math_ops.matmul(x, w2) * num_replicas,
          self.evaluate(strategy.reduce("SUM", result, axis=None)),
          rtol=5e-3,
          atol=5e-3)

      status = checkpoint.restore(
          checkpoint_management.latest_checkpoint(checkpoint_dir))
      status.run_restore_ops(sess)  # must run restore op in non-eager mode.
      status.assert_consumed()
      status.assert_existing_objects_matched()
      result = strategy.run(step_fn, args=(x,))
      self.assertAllClose(
          math_ops.matmul(x, w1) * num_replicas,
          self.evaluate(strategy.reduce("SUM", result, axis=None)),
          rtol=5e-3,
          atol=5e-3)
# Example #17
# 0
 def test_table(self):
     """Saves a HashTable-backed function and imports it from two locations.

     The vocabulary file is deleted after saving, and the SavedModel directory
     is renamed before the second import.
     """
     vocab_initializer = lookup_ops.TextFileInitializer(
         self._vocab_path,
         key_dtype=dtypes.string,
         key_index=lookup_ops.TextFileIndex.WHOLE_LINE,
         value_dtype=dtypes.int64,
         value_index=lookup_ops.TextFileIndex.LINE_NUMBER)
     table = lookup_ops.HashTable(vocab_initializer, default_value=-1)
     root = checkpoint.Checkpoint(table=table)
     string_spec = tensor_spec.TensorSpec(None, dtypes.string)
     root.table_user = def_function.function(
         root.table.lookup, input_signature=[string_spec])
     gamma_id = self.evaluate(root.table_user(constant_op.constant("gamma")))
     self.assertEqual(2, gamma_id)

     save_dir = os.path.join(self.get_temp_dir(), "saved_model")
     save.save(root, save_dir)
     # Remove the original vocabulary file before importing the saved model.
     file_io.delete_file(self._vocab_path)
     first_result = _import_and_infer(save_dir, {"keys": ["gamma", "alpha"]})
     self.assertAllClose({"output_0": [2, 0]}, first_result)

     second_dir = os.path.join(self.get_temp_dir(), "second_dir")
     # Asset paths should track the location the SavedModel is loaded from.
     file_io.rename(save_dir, second_dir)
     second_result = _import_and_infer(second_dir,
                                       {"keys": ["gamma", "beta"]})
     self.assertAllClose({"output_0": [2, 1]}, second_result)
# Example #18
# 0
 def testNestedLists(self):
     """Tracks objects held in nested lists and rejects saving a dirtied list."""

     def _ckpt_prefix():
         return os.path.join(self.get_temp_dir(), "ckpt")

     root = autotrackable.AutoTrackable()
     root.l = []
     first = autotrackable.AutoTrackable()
     root.l.append([first])
     second = autotrackable.AutoTrackable()
     root.l[0].append(second)
     tracked = util.list_objects(root)
     for obj in (first, second):
         self.assertIn(obj, tracked)

     # A plain integer in the list must not show up as a tracked object.
     root.l[0].append(1)
     third = autotrackable.AutoTrackable()
     root.l[0].append(third)
     tracked = util.list_objects(root)
     for obj in (third, first, second):
         self.assertIn(obj, tracked)
     self.assertNotIn(1, tracked)

     # Objects in a list literal, and ones appended to it afterwards.
     from_literal = autotrackable.AutoTrackable()
     appended_later = autotrackable.AutoTrackable()
     root.l1 = [[], [from_literal]]
     root.l1[0].append(appended_later)
     tracked = util.list_objects(root)
     for obj in (from_literal, appended_later):
         self.assertIn(obj, tracked)

     ckpt = util.Checkpoint(a=root)
     ckpt.save(_ckpt_prefix())
     # Appending a NoDependency-wrapped list and mutating it still saves.
     root.l[0].append(data_structures.NoDependency([]))
     root.l[0][-1].append(5)
     ckpt.save(_ckpt_prefix())
     # Dirtying the inner list means the root object is unsaveable.
     root.l[0][1] = 2
     expected_error = self.assertRaisesRegex(
         ValueError, "A list element was replaced")
     with expected_error:
         ckpt.save(_ckpt_prefix())
# Example #19
# 0
  def test_metrics_v2(self):
    """Checks the metrics recorded under the `_CHECKPOINT_V2` API label.

    Three writes are expected to bump the write histogram, the checkpoint-size
    counter and the time-saved metric; one restore is expected to bump only
    the read histogram.
    """
    api_label = util._CHECKPOINT_V2
    prefix = os.path.join(self.get_temp_dir(), 'ckpt')

    with context.eager_mode():
      ckpt = util.Checkpoint(v=variables_lib.Variable(1.))
      # Nothing has been written yet, so both metrics start at zero.
      self.assertEqual(self._get_time_saved(api_label), 0.0)
      self.assertEqual(self._get_write_histogram_proto(api_label).num, 0.0)

      for i in range(3):
        time_saved = self._get_time_saved(api_label)
        # NOTE(review): the sleep presumably guarantees a measurable gap so
        # the time-saved metric strictly increases - confirm.
        time.sleep(1)
        ckpt_path = ckpt.write(file_prefix=prefix)
        filesize = util._get_checkpoint_size(ckpt_path)
        self.assertEqual(self._get_checkpoint_size(api_label, filesize), i + 1)
        self.assertGreater(self._get_time_saved(api_label), time_saved)

    self.assertEqual(self._get_write_histogram_proto(api_label).num, 3.0)
    self.assertEqual(self._get_read_histogram_proto(api_label).num, 0.0)

    time_saved = self._get_time_saved(api_label)
    with context.eager_mode():
      ckpt.restore(ckpt_path)
    self.assertEqual(self._get_read_histogram_proto(api_label).num, 1.0)
    # Restoring a checkpoint in the same "job" does not increase training time
    # saved.
    self.assertEqual(self._get_time_saved(api_label), time_saved)
# Example #20
# 0
  def test_lookup_table_compatibility(self):
    """Compares a freshly written table checkpoint against the legacy one."""
    table_module = generate_checkpoint.TableModule()
    ckpt = checkpoint.Checkpoint(table_module)
    checkpoint_path = os.path.join(self.get_temp_dir(), "ckpt")
    ckpt.write(checkpoint_path)

    def _find_table_node(object_graph):
      # Returns the node reached from the root via the child named
      # "lookup_table", or None when there is no such child.
      for child in object_graph.nodes[0].children:
        if child.local_name == "lookup_table":
          return object_graph.nodes[child.node_id]
      return None

    def _first_attribute_fields(node):
      first_attribute = node.attributes[0]
      return [first_attribute.name, first_attribute.checkpoint_key]

    # Ensure that the checkpoint metadata and keys are the same.
    legacy_metadata = checkpoint.object_metadata(_LEGACY_TABLE_CHECKPOINT_PATH)
    new_metadata = checkpoint.object_metadata(checkpoint_path)
    new_node = _find_table_node(new_metadata)
    legacy_node = _find_table_node(legacy_metadata)
    self.assertAllEqual(_first_attribute_fields(new_node),
                        _first_attribute_fields(legacy_node))

    legacy_reader = checkpoint_utils.load_checkpoint(
        _LEGACY_TABLE_CHECKPOINT_PATH)
    new_reader = checkpoint_utils.load_checkpoint(checkpoint_path)
    legacy_keys = legacy_reader.get_variable_to_shape_map().keys()
    new_keys = new_reader.get_variable_to_shape_map().keys()
    self.assertEqual(legacy_keys, new_keys)

    # Ensure that previous checkpoint can be loaded into current table.
    legacy_status = ckpt.read(_LEGACY_TABLE_CHECKPOINT_PATH)
    legacy_status.assert_consumed()
    def test_training_loop(self):
        """Trains a `_Dense(5)` layer on the parallel device with checkpointing.

        The test is skipped unconditionally by the first statement
        (b/216201668). When enabled, each of the five outer iterations builds a
        new layer and manager over the same temp directory, calls
        `restore_or_initialize`, and saves after every one of ten steps.
        """
        self.skipTest("b/216201668: revisit parallel device and checkpointing")
        for _ in range(5):
            layer = _Dense(5)
            checkpoint = tracking.Checkpoint(layer=layer)
            manager = checkpoint_management.CheckpointManager(
                checkpoint, directory=self.get_temp_dir(), max_to_keep=5)
            manager.restore_or_initialize()

            for _ in range(10):
                # Each component of the parallel device gets its own input.
                x = self.device.pack([
                    constant_op.constant([[-0.5]]),
                    constant_op.constant([[0.5]])
                ])
                with self.device:
                    with backprop.GradientTape() as tape:
                        y = layer(x)
                        loss = (y - math_ops.range(5.))**2.
                    parameters = layer.trainable_variables
                    unreduced_gradients = tape.gradient(loss, parameters)
                    # Sum gradients across the device's components before
                    # applying the update.
                    reduced_gradients = _collective_sum(
                        unreduced_gradients,
                        num_replicas=len(self.device.components))
                    for grad, param in zip(reduced_gradients, parameters):
                        param.assign_sub(0.01 * grad)

                manager.save()
# Example #22
# 0
 def testSaveRestoreMultipleIterator(self):
     """Checkpoints three iterators and restores each to its saved position."""
     checkpoint_prefix = os.path.join(self.get_temp_dir(), "ckpt")
     squares = dataset_ops.Dataset.from_tensor_slices(
         [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])
     squares = squares.map(math_ops.square).batch(2)
     squares_iter_a = iter(squares)
     squares_iter_b = iter(squares)
     counter_iter = iter(dataset_ops.Dataset.range(10))
     ckpt = trackable_utils.Checkpoint(iterator_1=squares_iter_a,
                                       iterator_2=squares_iter_b,
                                       iterator_3=counter_iter)

     # Advance the first and third iterators, leave the second untouched.
     self.assertAllEqual([1, 4], squares_iter_a.get_next())
     for expected in (0, 1, 2):
         self.assertAllEqual(expected, counter_iter.get_next())
     save_path = ckpt.save(checkpoint_prefix)

     # Move past the saved positions so the restore has something to rewind.
     self.assertAllEqual([1, 4], squares_iter_b.get_next())
     self.assertAllEqual([9, 16], squares_iter_b.get_next())
     self.assertAllEqual(3, counter_iter.get_next())

     ckpt.restore(save_path).run_restore_ops()
     # Every iterator resumes from where it was when the save happened.
     self.assertAllEqual([9, 16], squares_iter_a.get_next())
     self.assertAllEqual([1, 4], squares_iter_b.get_next())
     self.assertAllEqual(3, counter_iter.get_next())
# Example #23
# 0
 def test_signature_attribute_reserved(self):
     """Saving fails while `signatures` is set on the root and works after."""
     root = checkpoint.Checkpoint(signatures=variables.Variable(1.))
     export_dir = os.path.join(self.get_temp_dir(), "saved_model")

     # The error message is expected to point the user at `del obj.signatures`.
     expected_error = self.assertRaisesRegex(ValueError, "del obj.signatures")
     with expected_error:
         save.save(root, export_dir)

     # Once the attribute is removed the same object can be saved.
     del root.signatures
     save.save(root, export_dir)
# Example #24
# 0
def save(self,
         path,
         compression=None,
         shard_func=None,
         checkpoint_args=None):
  """Implements the save function and checkpoint functionality.

  When executing eagerly with a non-empty `checkpoint_args`, the dataset is
  written by iterating a `_SaveDataset`. A `Checkpoint` tracking that iterator
  is handed to a `CheckpointManager`, the manager's latest checkpoint is
  restored before iterating, and `manager.save(check_interval=True)` is called
  after every element. In all other cases a single `save_dataset` op is
  issued.

  Args:
    path: Destination forwarded to `_SaveDataset` or, after
      `set_save_dataset_attributes`, to the `save_dataset` op.
    compression: Compression setting forwarded unchanged.
    shard_func: Sharding function forwarded to `_SaveDataset` or
      `set_save_dataset_attributes`.
    checkpoint_args: Optional mapping of keyword arguments used to construct
      the `CheckpointManager`. It must not contain the key `"checkpoint"`,
      which is supplied here. If it contains a non-None `"step_counter"`,
      that object's `assign_add(delta=1)` is called once per element. The
      mapping itself is not modified.

  Raises:
    ValueError: If `checkpoint_args` contains the key `"checkpoint"`.
  """
  if context.executing_eagerly() and checkpoint_args:
    # Validate first so an invalid argument fails before any dataset or
    # iterator is constructed.
    if "checkpoint" in checkpoint_args:
      raise ValueError(
          "Invalid `checkpoint_args`. `checkpoint_args` are not allowed "
          "to include 'checkpoint'."
      )

    save_dataset = _SaveDataset(self, path, shard_func, compression)
    save_iterator = iter(save_dataset)

    checkpoint = checkpoint_lib.Checkpoint(iterator=save_iterator)
    # Work on a copy: inserting "checkpoint" into the caller's mapping would
    # make a second call with the same mapping fail the validation above.
    manager_kwargs = dict(checkpoint_args)
    manager_kwargs["checkpoint"] = checkpoint
    manager = checkpoint_management.CheckpointManager(**manager_kwargs)
    checkpoint.restore(manager.latest_checkpoint)

    # Looked up once; the value does not change while iterating.
    step_counter = checkpoint_args.get("step_counter")
    for _ in save_iterator:
      if step_counter is not None:
        step_counter.assign_add(delta=1)
      manager.save(check_interval=True)
  else:
    dataset, shard_func, use_shard_func, path = set_save_dataset_attributes(
        self, shard_func, path)
    ged_ops.save_dataset(
        dataset._variant_tensor,   # pylint: disable=protected-access
        path=path,
        shard_func_other_args=shard_func.captured_inputs,
        compression=compression,
        shard_func=shard_func,
        use_shard_func=use_shard_func)
    def testCheckpointing(self, distribution, synchronization, aggregation,
                          mode):
        """Round-trips a variable created under `distribution` via a checkpoint.

        The variable is saved, overwritten with different values, and then
        restored; the restored values must match what was saved. The test is
        skipped for `CollectiveAllReduceStrategy` when `mode` is "graph".
        """
        is_mwms = isinstance(
            distribution,
            collective_all_reduce_strategy.CollectiveAllReduceStrategy)
        if is_mwms and mode == "graph":
            self.skipTest(
                "MWMS combinations tests do not work well in graph mode.")

        with distribution.scope():
            initial_value = constant_op.constant([1., 2., 3., 4])
            var = variables_lib.Variable(initial_value,
                                         synchronization=synchronization,
                                         aggregation=aggregation)

        self.evaluate(var.initializer)
        saved_value = self.evaluate(var.read_value())

        # Write the initial values to a checkpoint.
        ckpt = trackable_utils.Checkpoint(v=var)
        ckpt_prefix = os.path.join(self.get_temp_dir(), "ckpt")
        with self.test_session():
            ckpt_path = ckpt.save(ckpt_prefix)

        # Overwrite the variable with the values in reverse order so the
        # restore has something to undo.
        reversed_value = constant_op.constant([4., 3., 2., 1.])
        self.evaluate(var.assign(reversed_value))
        overwritten_value = self.evaluate(var.read_value())
        self.assertNotAllClose(saved_value, overwritten_value)

        # Restoring must bring back the values that were saved.
        with self.test_session():
            restore_status = ckpt.restore(ckpt_path)
            restore_status.assert_consumed().run_restore_ops()
        restored_value = self.evaluate(var)
        self.assertAllClose(saved_value, restored_value)
示例#26
0
    def testStatefulExternalPolicy(self):
        """Checkpoints an iterator whose map function is an `eager_py_func`.

        The dataset is configured with `ExternalStatePolicy.WARN`. Two
        elements are consumed before saving and two after; following the
        restore the last two elements must be produced again, and the
        iterator must then be exhausted.
        """
        ckpt_prefix = os.path.join(self.get_temp_dir(), "ckpt")

        def square(x):
            return x * x

        dataset = dataset_ops.Dataset.range(4).map(
            lambda x: script_ops.eager_py_func(square, [x], dtypes.int64))

        policy_options = options_lib.Options()
        policy_options.experimental_external_state_policy = (
            options_lib.ExternalStatePolicy.WARN)
        dataset = dataset.with_options(policy_options)

        iterator = iter(dataset)
        ckpt = trackable_utils.Checkpoint(iterator=iterator)

        # Consume the first two squares, then save.
        for expected in (0, 1):
            self.assertEqual(expected, iterator.get_next().numpy())
        save_path = ckpt.save(ckpt_prefix)

        # Drain the remaining two elements after the save point.
        for expected in (4, 9):
            self.assertEqual(expected, iterator.get_next().numpy())

        # After restoring, the same two elements come out again.
        ckpt.restore(save_path).run_restore_ops()
        for expected in (4, 9):
            self.assertEqual(expected, iterator.get_next().numpy())
        with self.assertRaises(errors.OutOfRangeError):
            iterator.get_next()
示例#27
0
  def test_registered_saver_is_called_before_save_after_load(self):
    """Writes a checkpoint for a root that has a registered checkpoint saver.

    A saver named "OptionalRestore" is registered for `RestoreClass`. Its
    `save_fn` asserts that the directory containing `file_prefix` is empty
    when it is invoked, and its `restore_fn` asserts that the first
    trackable's variable `v` equals 123. The test body itself only writes a
    checkpoint. Skipped unless executing eagerly.
    """
    if not context.executing_eagerly():
      self.skipTest("This test must run under eager mode.")

    class RestoreClass(autotrackable.AutoTrackable):
      pass

    def save_fn(trackables, file_prefix):
      del trackables  # Unused.
      # Check that directory is empty
      files = gfile.ListDirectory(os.path.dirname(file_prefix.numpy()))
      self.assertEmpty(files)

    def restore_fn(trackables, merged_prefix):
      del merged_prefix  # Unused.
      # `values()` yields a view, which is iterable but not an iterator;
      # calling `next` on it directly raises TypeError, so wrap it in `iter`.
      root = next(iter(trackables.values()))
      self.assertEqual(root.v.numpy(), 123)

    registration.register_checkpoint_saver(
        name="OptionalRestore",
        predicate=lambda x: isinstance(x, RestoreClass),
        save_fn=save_fn,
        restore_fn=restore_fn)

    root = RestoreClass()
    root.v = variables.Variable(123.0)

    ckpt_path = os.path.join(self.get_temp_dir(), "ckpt")
    util.Checkpoint(root).write(ckpt_path)
示例#28
0
  def test_non_strict_predicate(self):
    """Reads a checkpoint into a non-matching root with a non-strict saver.

    The checkpoint is written from an instance that satisfies the registered
    predicate and then read into a plain `AutoTrackable` that does not.
    With `strict_predicate_restore=False` the read must not raise.
    """
    class NonStrictPredicateClass(autotrackable.AutoTrackable):
      pass

    def matches_registered_class(obj):
      return isinstance(obj, NonStrictPredicateClass)

    registration.register_checkpoint_saver(
        name="NonStrictPredicate",
        predicate=matches_registered_class,
        save_fn=lambda **kwargs: [],
        restore_fn=lambda **kwargs: None,
        strict_predicate_restore=False)

    ckpt_path = os.path.join(self.get_temp_dir(), "ckpt")
    registered_root = NonStrictPredicateClass()
    util.Checkpoint(registered_root).write(ckpt_path)

    # Reading into an object the predicate rejects should run without
    # throwing an error.
    plain_root = autotrackable.AutoTrackable()
    util.Checkpoint(plain_root).read(ckpt_path)
示例#29
0
  def test_strict_predicate(self):
    """Reading into a non-matching root fails when the saver is strict.

    The checkpoint is written from an instance that satisfies the registered
    predicate and then read into a plain `AutoTrackable` that does not.
    With `strict_predicate_restore=True` the read must raise a `ValueError`
    whose message contains "saver cannot be used".
    """
    class StrictPredicateClass(autotrackable.AutoTrackable):
      pass

    def matches_registered_class(obj):
      return isinstance(obj, StrictPredicateClass)

    registration.register_checkpoint_saver(
        name="StrictPredicate",
        predicate=matches_registered_class,
        save_fn=lambda **kwargs: [],
        restore_fn=lambda **kwargs: None,
        strict_predicate_restore=True)

    ckpt_path = os.path.join(self.get_temp_dir(), "ckpt")
    registered_root = StrictPredicateClass()
    util.Checkpoint(registered_root).write(ckpt_path)

    plain_root = autotrackable.AutoTrackable()
    with self.assertRaisesRegex(ValueError, "saver cannot be used"):
      util.Checkpoint(plain_root).read(ckpt_path)
示例#30
0
 def testDistStratRestore(self, strat1, strat2, jit_replica_fn):
   """Checks that a generator checkpoint written under one strategy restores under another.

   A generator created in `strat1`'s scope is written to a checkpoint, which
   is then restored into a generator created in `strat2`'s scope. Samples
   drawn from both afterwards must agree on the replicas the two strategies
   have in common.

   Args:
     strat1: Strategy under which the checkpoint is written.
     strat2: Strategy under which it is restored; `None` means reuse `strat1`.
     jit_replica_fn: Whether to wrap the replica function in
       `def_function.function` before running it.
   """
   if strat2 is None:
     strat2 = strat1
   name1, name2 = type(strat1).__name__, type(strat2).__name__
   strategy_names = (name1, name2)

   if any("Default" in name for name in strategy_names):
     self.skipTest(
         "We don't guarantee consistency between strategy and no-strategy.")
   uses_tpu = any("TPU" in name for name in strategy_names)
   if uses_tpu and not jit_replica_fn:
     self.skipTest(
         "TPUStrategy requires the replica function (the function passed to "
         "strategy.run) to be decorated with tf.function")

   def make_coordinator(strategy, strategy_name):
     # Only strategies whose class name contains "ParameterServer" get a
     # cluster coordinator.
     if "ParameterServer" in strategy_name:
       return coordinator_lib.ClusterCoordinator(strategy)
     return None

   coord1 = make_coordinator(strat1, name1)
   coord2 = make_coordinator(strat2, name2)
   ckpt_file = os.path.join(self.get_temp_dir(), "checkpoint")

   def sample_uniform(strategy, coordinator, generator):
     def draw():
       return generator.uniform_full_int([3], dtype=dtypes.int32)
     if jit_replica_fn:
       replica_fn = def_function.function(draw)
     else:
       replica_fn = draw
     per_replica = run_on_strategy(replica_fn, strategy, coordinator)
     return strategy.experimental_local_results(per_replica)

   with strat1.scope():
     gen1 = rng.Generator.from_seed(1)
   with strat2.scope():
     gen2 = rng.Generator.from_seed(10)
   ckpt1 = tracking_util.Checkpoint(g=gen1)
   ckpt2 = tracking_util.Checkpoint(g=gen2)

   def write_restore_compare():
     ckpt1.write(ckpt_file)
     samples1 = sample_uniform(strat1, coord1, gen1)
     ckpt2.restore(ckpt_file)
     samples2 = sample_uniform(strat2, coord2, gen2)
     # Only the replicas present in both strategies can be compared.
     overlap = min(get_num_local_replicas(strat1),
                   get_num_local_replicas(strat2))
     self.assertAllEqual(samples1[:overlap], samples2[:overlap])

   # Two rounds, so the second write happens after the generator state has
   # already been advanced by sampling.
   for _ in range(2):
     write_restore_compare()