Example #1
  def test_lookup_table_compatibility(self):
    table_module = generate_checkpoint.TableModule()
    ckpt = checkpoint.Checkpoint(table_module)
    checkpoint_directory = self.get_temp_dir()
    checkpoint_path = os.path.join(checkpoint_directory, "ckpt")
    ckpt.write(checkpoint_path)

    # Ensure that the checkpoint metadata and keys are the same.
    legacy_metadata = checkpoint.object_metadata(_LEGACY_TABLE_CHECKPOINT_PATH)
    metadata = checkpoint.object_metadata(checkpoint_path)

    def _get_table_node(object_metadata):
      for child in object_metadata.nodes[0].children:
        if child.local_name == "lookup_table":
          return object_metadata.nodes[child.node_id]

    table_proto = _get_table_node(metadata)
    legacy_table_proto = _get_table_node(legacy_metadata)
    self.assertAllEqual(
        [table_proto.attributes[0].name,
         table_proto.attributes[0].checkpoint_key],
        [legacy_table_proto.attributes[0].name,
         legacy_table_proto.attributes[0].checkpoint_key])

    legacy_reader = checkpoint_utils.load_checkpoint(
        _LEGACY_TABLE_CHECKPOINT_PATH)
    reader = checkpoint_utils.load_checkpoint(checkpoint_path)
    self.assertEqual(
        legacy_reader.get_variable_to_shape_map().keys(),
        reader.get_variable_to_shape_map().keys())

    # Ensure that previous checkpoint can be loaded into current table.
    ckpt.read(_LEGACY_TABLE_CHECKPOINT_PATH).assert_consumed()
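
For context, here is a minimal sketch (not part of the test above) of the same pattern using the public TF2 API: write an object-based checkpoint that contains a lookup table, then list its keys and read values back. The table contents and the temporary path are placeholder assumptions.

import tensorflow as tf

# Hypothetical table and path, only to illustrate the write-then-inspect flow.
table = tf.lookup.StaticHashTable(
    tf.lookup.KeyValueTensorInitializer(keys=["a", "b"], values=[1, 2]),
    default_value=-1)
ckpt = tf.train.Checkpoint(lookup_table=table)
path = ckpt.write("/tmp/table_ckpt/ckpt")

# List (name, shape) pairs and read a tensor back, as the test does.
for name, shape in tf.train.list_variables(path):
    print(name, shape)
reader = tf.train.load_checkpoint(path)
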
Example #2
    def _assert_checkpoint(self,
                           expected_global_step,
                           expected_weights=None,
                           expected_bias=None):
        """Assert the values and shapes of the variables saved in the checkpoint."""
        shapes = {
            name: shape
            for (name,
                 shape) in checkpoint_utils.list_variables(self._model_dir)
        }

        reader = checkpoint_utils.load_checkpoint(self._model_dir)

        self.assertEqual([], shapes[ops.GraphKeys.GLOBAL_STEP])
        self.assertEqual(expected_global_step,
                         reader.get_tensor(ops.GraphKeys.GLOBAL_STEP))

        self.assertEqual([3, 2], shapes[WEIGHTS_NAME])
        if expected_weights is not None:
            self.assertAllClose(expected_weights,
                                reader.get_tensor(WEIGHTS_NAME))

        self.assertEqual([2], shapes[BIAS_NAME])
        if expected_bias is not None:
            self.assertAllClose(expected_bias, reader.get_tensor(BIAS_NAME))
Example #3
    def _assert_checkpoint(self,
                           model_dir,
                           global_step,
                           finalized_trees,
                           attempted_layers,
                           bucket_boundaries=None):
        reader = checkpoint_utils.load_checkpoint(model_dir)
        self.assertEqual(global_step,
                         reader.get_tensor(ops.GraphKeys.GLOBAL_STEP))
        serialized = reader.get_tensor('boosted_trees:0_serialized')
        ensemble_proto = boosted_trees_pb2.TreeEnsemble()
        ensemble_proto.ParseFromString(serialized)

        self.assertEqual(
            finalized_trees,
            sum([1 for t in ensemble_proto.tree_metadata if t.is_finalized]))
        self.assertEqual(attempted_layers,
                         ensemble_proto.growing_metadata.num_layers_attempted)

        if bucket_boundaries:
            for i, bucket_boundary in enumerate(bucket_boundaries):
                self.assertAllClose(
                    bucket_boundary,
                    reader.get_tensor(
                        'boosted_trees/QuantileAccumulator/_bucket_boundaries_'
                        + str(i)))
Example #4
    def _assert_checkpoint_and_return_model(self, model_dir, global_step):
        reader = checkpoint_utils.load_checkpoint(model_dir)
        self.assertEqual(global_step,
                         reader.get_tensor(ops.GraphKeys.GLOBAL_STEP))
        serialized = reader.get_tensor("ensemble_model:0_config")
        ensemble_proto = tree_config_pb2.DecisionTreeEnsembleConfig()
        ensemble_proto.ParseFromString(serialized)

        return ensemble_proto
Example #5
    def restore(self, sess, save_path, var_filter=lambda v: True):
        """Restores only variables that are contained in `save_path` and match in shape and dtype and return `True` when passed to `var_filter`."""
        if self._is_empty:
            return
        if save_path is None:
            raise ValueError("Can't load save_path when it is None.")
        tf_logging.info("Restoring parameters from %s", save_path)

        reader = load_checkpoint(save_path)
        shape_map = reader.get_variable_to_shape_map()
        dtype_map = reader.get_variable_to_dtype_map()

        restore_op_name = self.saver_def.restore_op_name
        restore_op_grouped = sess.graph.get_operation_by_name(restore_op_name)

        def get_restore_ops(r_op):
            return sum((get_restore_ops(i) for i in r_op.control_inputs),
                       [r_op] if r_op.type == 'Assign' else [])

        all_restore_ops = get_restore_ops(restore_op_grouped)
        filtered_restore_ops = []

        for r_op in all_restore_ops:
            v = r_op.inputs[0]
            tensor_name = v.op.name
            part_match = re.search(r'/part_\d+$', tensor_name)
            if part_match:
                tf_logging.info('variable %s is sharded', tensor_name)
                tensor_name = tensor_name[:part_match.span()[0]]
            tensor_shape = v.get_shape().as_list()
            tensor_dtype = v.dtype.base_dtype
            if tensor_name not in shape_map or tensor_name not in dtype_map:
                tf_logging.warn('variable %s not in checkpoint', tensor_name)
            elif shape_map[tensor_name] != tensor_shape and not part_match:
                tf_logging.warn(
                    'variable %s in checkpoint, but checkpoint shape %r does not match graph shape %r',
                    tensor_name, shape_map[tensor_name], tensor_shape)
            elif dtype_map[tensor_name] != tensor_dtype:
                tf_logging.warn(
                    'variable %s in checkpoint, but checkpoint dtype %r does not match graph dtype %r',
                    tensor_name, dtype_map[tensor_name], tensor_dtype)
            elif not var_filter(v):
                # Single placeholder in the format string, so pass only the name.
                tf_logging.info('variable %s rejected by var_filter',
                                tensor_name)
            else:
                filtered_restore_ops.append(r_op)
                tf_logging.info('adding variable %s to be restored',
                                tensor_name)

        if context.in_eager_mode():
            raise NotImplementedError(
                'eager selective restoring not supported yet')

        for r_op in filtered_restore_ops:
            sess.run(r_op, {self.saver_def.filename_tensor_name: save_path})
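
A hedged usage sketch for the restore() override above, in graph mode. FilteringSaver is a placeholder name for whatever tf.compat.v1.train.Saver subclass defines this method; the checkpoint path and the filter are assumptions.

import tensorflow.compat.v1 as tf

saver = FilteringSaver()  # placeholder: the Saver subclass defining restore() above
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # Restore only encoder variables; everything else keeps its initial value.
    saver.restore(sess, "/tmp/model/ckpt-100",
                  var_filter=lambda v: "encoder" in v.op.name)
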
Example #7
 def _assert_checkpoint(self, model_dir, global_step, finalized_trees,
                        attempted_layers):
   reader = checkpoint_utils.load_checkpoint(model_dir)
   self.assertEqual(global_step, reader.get_tensor(ops.GraphKeys.GLOBAL_STEP))
   serialized = reader.get_tensor('boosted_trees:0_serialized')
   ensemble_proto = boosted_trees_pb2.TreeEnsemble()
   ensemble_proto.ParseFromString(serialized)
   self.assertEqual(
       finalized_trees,
       sum([1 for t in ensemble_proto.tree_metadata if t.is_finalized]))
   self.assertEqual(attempted_layers,
                    ensemble_proto.growing_metadata.num_layers_attempted)
Example #8
def restore_variables_on_create(save_path):
    """ContextManager that restores variables on creation.

    When save_path is None (e.g. No checkpoint), does nothing.
    Otherwise, it preloads all values from checkpoint. When the
    corresponding variable is first created, it assigns the checkpoint
    value to the variable.

    ```python
    with restore_variables_on_create(
        tf.train.latest_checkpoint(checkpoint_dir)):
    ```

  Args:
    save_path: The checkpoint file prefix.

  Yields:
    Nothing.

  Raises:
    NotFoundError: If the variable is not found in checkpoint.
    ValueError: If not used in eager mode.
  """
    if context.in_graph_mode():
        raise ValueError(
            "Currently, restore_variables_on_create can only be used with "
            "eager execution enabled.")
    if save_path:
        ckpt_var_cache = dict()
        reader = checkpoint_utils.load_checkpoint(save_path)
        for k, _ in checkpoint_utils.list_variables(save_path):
            ckpt_var_cache[k] = reader.get_tensor(k)

        old_init = getattr(resource_variable_ops.ResourceVariable,
                           "_init_from_args", None)
        assert old_init, "ResourceVariable misses _init_from_args method."
        setattr(resource_variable_ops.ResourceVariable, "_init_from_args",
                _init_from_checkpoint)
        setattr(resource_variable_ops.ResourceVariable, "old_init", old_init)
        setattr(resource_variable_ops.ResourceVariable, "ckpt_var_cache",
                ckpt_var_cache)
    try:
        yield
    except Exception as e:
        raise e
    finally:
        if save_path:
            setattr(resource_variable_ops.ResourceVariable, "_init_from_args",
                    old_init)
            setattr(resource_variable_ops.ResourceVariable, "old_init", None)
            setattr(resource_variable_ops.ResourceVariable, "ckpt_var_cache",
                    None)
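
A hedged usage sketch for the context manager above (eager mode only): a variable created inside the block takes its initial value from the checkpoint entry with the matching name instead of from its initializer. The checkpoint directory and the variable name are assumptions.

import tensorflow as tf

ckpt_path = tf.train.latest_checkpoint("/tmp/model")  # assumed checkpoint directory
with restore_variables_on_create(ckpt_path):
    # Created inside the block, so its value is read from the checkpoint entry
    # named "w" (NotFoundError if no such entry exists).
    w = tf.Variable(tf.zeros([10, 10]), name="w")
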
Example #9
def restore_variables_on_create(save_path):
  """ContextManager that restores variables on creation.

    When save_path is None (e.g. No checkpoint), does nothing.
    Otherwise, it preloads all values from checkpoint. When the
    corresponding variable is first created, it assigns the checkpoint
    value to the variable.

    ```python
    with restore_variables_on_create(
        tf.train.latest_checkpoint(checkpoint_dir)):
    ```

  Args:
    save_path: The checkpoint file prefix.

  Yields:
    Nothing.

  Raises:
    NotFoundError: If the variable is not found in checkpoint.
    ValueError: If not used in eager mode.
  """
  if context.in_graph_mode():
    raise ValueError(
        "Currently, restore_variables_on_create can only be used with "
        "eager execution enabled.")
  if save_path:
    ckpt_var_cache = dict()
    reader = checkpoint_utils.load_checkpoint(save_path)
    for k, _ in checkpoint_utils.list_variables(save_path):
      ckpt_var_cache[k] = reader.get_tensor(k)

    old_init = getattr(
        resource_variable_ops.ResourceVariable, "_init_from_args", None)
    assert old_init, "ResourceVariable misses _init_from_args method."
    setattr(resource_variable_ops.ResourceVariable, "_init_from_args",
            _init_from_checkpoint)
    setattr(resource_variable_ops.ResourceVariable, "old_init", old_init)
    setattr(resource_variable_ops.ResourceVariable, "ckpt_var_cache",
            ckpt_var_cache)
  try:
    yield
  except Exception as e:
    raise e
  finally:
    if save_path:
      setattr(resource_variable_ops.ResourceVariable, "_init_from_args",
              old_init)
      setattr(resource_variable_ops.ResourceVariable, "old_init", None)
      setattr(resource_variable_ops.ResourceVariable, "ckpt_var_cache", None)
Example #10
    def testFSPath(self):
        checkpoint_dir = pathlib.Path(self.get_temp_dir())
        with self.cached_session() as session:
            v1, v2, v3, v4 = _create_checkpoints(session, checkpoint_dir)  # pylint: disable=unused-variable

        reader = checkpoint_utils.load_checkpoint(checkpoint_dir)
        self.assertAllEqual(reader.get_tensor("var1"), v1)

        self.assertAllEqual(
            checkpoint_utils.load_variable(checkpoint_dir, "var1"), v1)

        self.assertEqual(checkpoint_utils.list_variables(checkpoint_dir),
                         [("useful_scope/var4", [9, 9]), ("var1", [1, 10]),
                          ("var2", [10, 10]), ("var3", [100, 100])])
Example #11
def list_checkpoint_attributes(ckpt_dir_or_file):
  """Lists all the attributes in a checkpoint.

  Checkpoint keys are paths in a checkpoint graph, and the attribute is the
  first element in the path, e.g. for the checkpoint key
  "optimizer/iter/.ATTRIBUTES/VARIABLE_VALUE", "optimizer" is the attribute.
  The attribute is also the name used to save/restore a variable in a
  checkpoint, e.g. tf.train.Checkpoint(optimizer=optimizer, model=model).

  Args:
    ckpt_dir_or_file: Directory with checkpoint files, or path to a specific
      checkpoint.

  Returns:
    Set of attributes in a checkpoint.
  """
  reader = checkpoint_utils.load_checkpoint(ckpt_dir_or_file)
  variable_map = reader.get_variable_to_shape_map()
  return {name.split('/')[0] for name in variable_map.keys()}
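
A small usage sketch, assuming an object-based checkpoint written with tf.train.Checkpoint; the model and optimizer objects are placeholders. The returned set contains the top-level attribute names.

import tensorflow as tf

ckpt = tf.train.Checkpoint(model=model, optimizer=optimizer)  # placeholders
path = ckpt.save("/tmp/ckpt/demo")
print(list_checkpoint_attributes(path))  # e.g. {'model', 'optimizer', 'save_counter'}
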
Example #12
    def maybe_restore_on_create(self, save_path):
        """ContextManager that restores variables on creation.

      When save_path is None (e.g. No checkpoint), does nothing.
      Otherwise, it preloads all values from checkpoint. When the
      corresponding variable is first created, it assigns the checkpoint
      value to the variable.

    Args:
      save_path: Same as the save_path argument of restore(). If None, do not restore.

    Yields:
      Nothing.

    Raises:
      NotFoundError: If the variable is not found in checkpoint.
    """
        if save_path:
            ckpt_var_cache = dict()
            reader = checkpoint_utils.load_checkpoint(save_path)
            for k, _ in checkpoint_utils.list_variables(save_path):
                ckpt_var_cache[k] = reader.get_tensor(k)

            old_init = getattr(resource_variable_ops.ResourceVariable,
                               "_init_from_args", None)
            assert old_init, "ResourceVariable misses _init_from_args method."
            setattr(resource_variable_ops.ResourceVariable, "_init_from_args",
                    _init_from_checkpoint)
            setattr(resource_variable_ops.ResourceVariable, "old_init",
                    old_init)
            setattr(resource_variable_ops.ResourceVariable, "ckpt_var_cache",
                    ckpt_var_cache)
        try:
            yield
        except Exception as e:
            raise e
        finally:
            if save_path:
                setattr(resource_variable_ops.ResourceVariable,
                        "_init_from_args", old_init)
                setattr(resource_variable_ops.ResourceVariable, "old_init",
                        None)
                setattr(resource_variable_ops.ResourceVariable,
                        "ckpt_var_cache", None)
Example #13
  def maybe_restore_on_create(self, save_path):
    """ContextManager that restores variables on creation.

      When save_path is None (e.g. No checkpoint), does nothing.
      Otherwise, it preloads all values from checkpoint. When the
      corresponding variable is first created, it assigns the checkpoint
      value to the variable.

    Args:
      save_path: Same as the save_path argument of restore(). If None, do not restore.

    Yields:
      Nothing.

    Raises:
      NotFoundError: If the variable is not found in checkpoint.
    """
    if save_path:
      ckpt_var_cache = dict()
      reader = checkpoint_utils.load_checkpoint(save_path)
      for k, _ in checkpoint_utils.list_variables(save_path):
        ckpt_var_cache[k] = reader.get_tensor(k)

      old_init = getattr(
          resource_variable_ops.ResourceVariable, "_init_from_args", None)
      assert old_init, "ResourceVariable misses _init_from_args method."
      setattr(resource_variable_ops.ResourceVariable, "_init_from_args",
              _init_from_checkpoint)
      setattr(resource_variable_ops.ResourceVariable, "old_init", old_init)
      setattr(resource_variable_ops.ResourceVariable, "ckpt_var_cache",
              ckpt_var_cache)
    try:
      yield
    except Exception as e:
      raise e
    finally:
      if save_path:
        setattr(resource_variable_ops.ResourceVariable, "_init_from_args",
                old_init)
        setattr(resource_variable_ops.ResourceVariable, "old_init", None)
        setattr(resource_variable_ops.ResourceVariable, "ckpt_var_cache", None)
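
A note on the pattern in the examples above: the hook works by pre-reading every checkpoint tensor into ckpt_var_cache, temporarily swapping ResourceVariable._init_from_args for the checkpoint-aware _init_from_checkpoint, and putting the original method back in the finally block, so the patch is undone even if variable creation inside the block raises.
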
Example #14
 def _set_restore_on_create(self, save_path, map_func, user_map_func,
                            existing_variables_by_checkpoint_name):
     """If necessary, request deferred restorations of variables."""
     checkpoint_reader = checkpoint_utils.load_checkpoint(save_path)
     checkpointed_variables_to_restore = {}
     for checkpoint_name, _ in checkpoint_utils.list_variables(save_path):
         if checkpoint_name in existing_variables_by_checkpoint_name:
             # This variable was already created and restored.
             continue
         # Save the variable for later restoration in a custom getter.
         checkpointed_variables_to_restore[checkpoint_name] = (
             checkpoint_reader.get_tensor(checkpoint_name))
     # Only set a deferred restoration if there are checkpoint variables which
     # have not been assigned to existing variables. Note that this loses out on
     # some opportunity for error checking, but avoids creating
     # _DeferredRestoration objects once a Network has been built (so that
     # restoring in a loop does not take increasing amounts of memory).
     if checkpointed_variables_to_restore:
         if context.in_eager_mode():
             sess = None
         else:
             sess = ops.get_default_session()
         # We need a name for error messages. If we haven't been added to another
         # Network yet, we're top-level.
         self._finalize_name(False)
         self._set_scope()
         # Save a record of this restoration for use in the custom getter.
         deferred_restoration = _DeferredRestoration(
             map_func=map_func,
             map_func_is_user=(user_map_func is not None),
             checkpointed_variables_to_restore=
             checkpointed_variables_to_restore,
             restored_variables={},
             session=sess,
             network_name=self.name,
             network_scope_name=self.scope_name)
         self._deferred_restorations.append(deferred_restoration)
         # Add the deferred registration to non-Network children, and request that
         # Networks propagate the request to their children.
         self._add_deferred_restoration(deferred_restoration)
Example #15
 def _set_restore_on_create(self, save_path, map_func, user_map_func,
                            existing_variables_by_checkpoint_name):
   """If necessary, request deferred restorations of variables."""
   checkpoint_reader = checkpoint_utils.load_checkpoint(save_path)
   checkpointed_variables_to_restore = {}
   for checkpoint_name, _ in checkpoint_utils.list_variables(save_path):
     if checkpoint_name in existing_variables_by_checkpoint_name:
       # This variable was already created and restored.
       continue
     # Save the variable for later restoration in a custom getter.
     checkpointed_variables_to_restore[checkpoint_name] = (
         checkpoint_reader.get_tensor(checkpoint_name))
   # Only set a deferred restoration if there are checkpoint variables which
   # have not been assigned to existing variables. Note that this loses out on
   # some opportunity for error checking, but avoids creating
   # _DeferredRestoration objects once a Network has been built (so that
   # restoring in a loop does not take increasing amounts of memory).
   if checkpointed_variables_to_restore:
     if context.in_eager_mode():
       sess = None
     else:
       sess = ops.get_default_session()
     # We need a name for error messages. If we haven't been added to another
     # Network yet, we're top-level.
     self._finalize_name(False)
     self._set_scope()
     # Save a record of this restoration for use in the custom getter.
     deferred_restoration = _DeferredRestoration(
         map_func=map_func,
         map_func_is_user=(user_map_func is not None),
         checkpointed_variables_to_restore=checkpointed_variables_to_restore,
         restored_variables={},
         session=sess,
         network_name=self.name,
         network_scope_name=self.scope_name)
     self._deferred_restorations.append(deferred_restoration)
     # Add the deferred registration to non-Network children, and request that
     # Networks propagate the request to their children.
     self._add_deferred_restoration(deferred_restoration)
Example #16
def restore_variables_on_create(save_path, map_func=None):
  """ContextManager that restores variables on creation.

    When save_path is None (e.g. No checkpoint), does nothing.
    Otherwise, it preloads all values from checkpoint. When the
    corresponding variable is first created, it assigns the checkpoint
    value to the variable.

    ```python
    with restore_variables_on_create(
        tf.train.latest_checkpoint(checkpoint_dir)):
    ```

  Args:
    save_path: The checkpoint file prefix.
    map_func: A function that takes the variable name as its argument and
        returns the corresponding variable name in the checkpoint to restore
        from. If None, the variable with the same name in the checkpoint is
        used. It is an error if the mapped variable name does not exist in
        the checkpoint.

  Yields:
    Nothing.

  Raises:
    NotFoundError: If the variable is not found in checkpoint.
    ValueError: If not used in eager mode or map_func is not callable.
  """
  if not context.executing_eagerly():
    raise ValueError(
        "Currently, restore_variables_on_create can only be used with "
        "eager execution enabled.")
  if save_path:
    if map_func is None:
      map_func_wrapper = lambda self, x: x
    else:
      if not callable(map_func):
        raise ValueError("map_func must be callable.")
      map_func_wrapper = lambda self, x: map_func(x)

    ckpt_var_cache = dict()
    reader = checkpoint_utils.load_checkpoint(save_path)
    for k, _ in checkpoint_utils.list_variables(save_path):
      ckpt_var_cache[k] = reader.get_tensor(k)

    old_init = getattr(resource_variable_ops.ResourceVariable,
                       "_init_from_args", None)
    assert old_init, "ResourceVariable misses _init_from_args method."
    setattr(resource_variable_ops.ResourceVariable, "_init_from_args",
            _init_from_checkpoint)
    setattr(resource_variable_ops.ResourceVariable, "_old_init", old_init)
    setattr(resource_variable_ops.ResourceVariable, "_map_func",
            map_func_wrapper)
    setattr(resource_variable_ops.ResourceVariable, "_ckpt_var_cache",
            ckpt_var_cache)
  try:
    yield
  except Exception as e:
    raise e
  finally:
    if save_path:
      setattr(resource_variable_ops.ResourceVariable, "_init_from_args",
              old_init)
      setattr(resource_variable_ops.ResourceVariable, "_old_init", None)
      setattr(resource_variable_ops.ResourceVariable, "_map_func", None)
      setattr(resource_variable_ops.ResourceVariable, "_ckpt_var_cache", None)
Example #17
 def _assert_checkpoint(self, model_dir, global_step):
     reader = checkpoint_utils.load_checkpoint(model_dir)
     self.assertEqual(global_step,
                      reader.get_tensor(ops.GraphKeys.GLOBAL_STEP))
Example #18
 def _assert_checkpoint(self, model_dir, global_step):
   reader = checkpoint_utils.load_checkpoint(model_dir)
   self.assertEqual(global_step, reader.get_tensor(ops.GraphKeys.GLOBAL_STEP))