Esempio n. 1
0
def _map_resources(accessible_objects):
  """Makes new resource handle ops corresponding to existing resource tensors.

  Creates resource handle ops in the current default graph, whereas
  `accessible_objects` will be from an eager context. Resource mapping adds
  resource handle ops to the main GraphDef of a SavedModel, which allows the C++
  loader API to interact with variables.

  Args:
    accessible_objects: A list of objects, some of which may contain resources,
      to create replacements for.

  Returns:
    A tuple of (object_map, resource_map):
      object_map: A dictionary mapping from object in `accessible_objects` to
        replacement objects created to hold the new resource tensors.
      resource_map: A dictionary mapping from resource tensors extracted from
        `accessible_objects` to newly created resource tensors.
  """
  # TODO(allenl, rohanj): Map generic resources rather than just variables.
  # TODO(allenl): Handle MirroredVariables and other types of variables which
  # may need special casing.
  object_map = {}
  resource_map = {}
  for obj in accessible_objects:
    if resource_variable_ops.is_resource_variable(obj):
      new_variable = resource_variable_ops.copy_to_graph_uninitialized(obj)
      object_map[obj] = new_variable
      resource_map[obj.handle] = new_variable.handle
  return object_map, resource_map
Esempio n. 2
0
def _map_resources(accessible_objects):
    """Makes new resource handle ops corresponding to existing resource tensors.

  Creates resource handle ops in the current default graph, whereas
  `accessible_objects` will be from an eager context. Resource mapping adds
  resource handle ops to the main GraphDef of a SavedModel, which allows the C++
  loader API to interact with variables.

  Args:
    accessible_objects: A list of objects, some of which may contain resources,
      to create replacements for.

  Returns:
    A tuple of (object_map, resource_map):
      object_map: A dictionary mapping from object in `accessible_objects` to
        replacement objects created to hold the new resource tensors.
      resource_map: A dictionary mapping from resource tensors extracted from
        `accessible_objects` to newly created resource tensors.
  """
    # TODO(allenl, rohanj): Map generic resources rather than just variables.
    # TODO(allenl): Handle MirroredVariables and other types of variables which
    # may need special casing.
    object_map = {}
    resource_map = {}
    for obj in accessible_objects:
        if resource_variable_ops.is_resource_variable(obj):
            new_variable = resource_variable_ops.copy_to_graph_uninitialized(
                obj)
            object_map[obj] = new_variable
            resource_map[obj.handle] = new_variable.handle
    return object_map, resource_map
Esempio n. 3
0
  def map_resources(self):
    """Makes new resource handle ops corresponding to existing resource tensors.

    Creates resource handle ops in the current default graph, whereas the
    objects in `self.nodes` will be from an eager context. Resource mapping
    adds resource handle ops to the main GraphDef of a SavedModel, which allows
    the C++ loader API to interact with variables.

    Besides building the returned maps, this method records the node id of
    every mapped tensor in `self.captured_tensor_node_ids`, and appends one
    `_CapturedConstant` node to `self.nodes` (registering it in
    `self.node_ids`) for each copyable tensor captured by a concrete function
    that is not already tracked.

    Returns:
      A tuple of (object_map, resource_map, asset_info):
        object_map: A dictionary mapping from resource variables in
          `self.nodes` to replacement objects created to hold the new resource
          tensors.
        resource_map: A dictionary mapping from resource tensors (and captured
          constant tensors) to newly created tensors.
        asset_info: An _AssetInfo tuple describing external assets referenced
          from `self.nodes`.

    Raises:
      ValueError: If a concrete function captures a tensor whose value cannot
        be determined statically, so that it cannot be copied as a constant.
    """
    # Only makes sense when adding to the export Graph
    assert not context.executing_eagerly()
    # TODO(allenl): Handle MirroredVariables and other types of variables which
    # may need special casing.
    object_map = object_identity.ObjectIdentityDictionary()
    resource_map = {}
    asset_info = _AssetInfo(
        asset_defs=[],
        asset_initializers_by_resource={},
        asset_filename_map={},
        asset_index={})
    for node_id, obj in enumerate(self.nodes):
      if isinstance(obj, tracking.TrackableResource):
        new_resource = obj._create_resource()  # pylint: disable=protected-access
        resource_map[obj.resource_handle] = new_resource
        self.captured_tensor_node_ids[obj.resource_handle] = node_id
      elif resource_variable_ops.is_resource_variable(obj):
        new_variable = resource_variable_ops.copy_to_graph_uninitialized(obj)
        object_map[obj] = new_variable
        resource_map[obj.handle] = new_variable.handle
        self.captured_tensor_node_ids[obj.handle] = node_id
      elif isinstance(obj, tracking.TrackableAsset):
        _process_asset(obj, asset_info, resource_map)
        self.captured_tensor_node_ids[obj.asset_path] = node_id

    for concrete_function in self.concrete_functions:
      for capture in concrete_function.captured_inputs:
        if (tensor_util.is_tensor(capture)
            and capture.dtype not in _UNCOPIABLE_DTYPES
            and capture not in self.captured_tensor_node_ids):
          capture_constant_value = tensor_util.constant_value(capture)
          # `constant_value` yields None when the value is not statically
          # known; fail with an actionable message instead of handing None to
          # `constant_op.constant`.
          if capture_constant_value is None:
            raise ValueError(
                ("Attempted to save a function {} which references a symbolic "
                 "Tensor {} that is not a simple constant. This is not "
                 "supported.").format(concrete_function.name, capture))
          copied_tensor = constant_op.constant(capture_constant_value)
          # The new node is appended at the end, so its id is the current
          # length of the node list.
          node_id = len(self.nodes)
          node = _CapturedConstant(
              eager_tensor=capture, graph_tensor=copied_tensor)
          self.nodes.append(node)
          self.node_ids[capture] = node_id
          self.node_ids[node] = node_id
          self.captured_tensor_node_ids[capture] = node_id
          resource_map[capture] = copied_tensor

    return object_map, resource_map, asset_info
Esempio n. 4
0
  def map_resources(self):
    """Makes new resource handle ops corresponding to existing resource tensors.

    Creates resource handle ops in the current default graph, whereas the
    objects in `self.nodes` will be from an eager context. Resource mapping
    adds resource handle ops to the main GraphDef of a SavedModel, which allows
    the C++ loader API to interact with variables.

    This method also mutates the instance: `self.captured_tensor_node_ids` is
    updated for every mapped tensor, and each copyable, not-yet-tracked tensor
    captured by a concrete function gets a new `_CapturedConstant` entry in
    `self.nodes` and `self.node_ids`.

    Returns:
      A tuple of (object_map, resource_map, asset_info):
        object_map: A dictionary mapping from resource variables in
          `self.nodes` to replacement objects created to hold the new resource
          tensors.
        resource_map: A dictionary mapping from resource tensors (and captured
          constant tensors) to newly created tensors.
        asset_info: An _AssetInfo tuple describing external assets referenced
          from `self.nodes`.

    Raises:
      ValueError: If a captured tensor has no statically known value and
        therefore cannot be copied into the graph as a constant.
    """
    # Only makes sense when adding to the export Graph
    assert not context.executing_eagerly()
    # TODO(allenl): Handle MirroredVariables and other types of variables which
    # may need special casing.
    object_map = object_identity.ObjectIdentityDictionary()
    resource_map = {}
    asset_info = _AssetInfo(
        asset_defs=[],
        asset_initializers_by_resource={},
        asset_filename_map={},
        asset_index={})
    for node_id, obj in enumerate(self.nodes):
      if isinstance(obj, tracking.TrackableResource):
        new_resource = obj._create_resource()  # pylint: disable=protected-access
        resource_map[obj.resource_handle] = new_resource
        self.captured_tensor_node_ids[obj.resource_handle] = node_id
      elif resource_variable_ops.is_resource_variable(obj):
        new_variable = resource_variable_ops.copy_to_graph_uninitialized(obj)
        object_map[obj] = new_variable
        resource_map[obj.handle] = new_variable.handle
        self.captured_tensor_node_ids[obj.handle] = node_id
      elif isinstance(obj, tracking.TrackableAsset):
        _process_asset(obj, asset_info, resource_map)
        self.captured_tensor_node_ids[obj.asset_path] = node_id

    for concrete_function in self.concrete_functions:
      for capture in concrete_function.captured_inputs:
        if (tensor_util.is_tensor(capture)
            and capture.dtype not in _UNCOPIABLE_DTYPES
            and capture not in self.captured_tensor_node_ids):
          static_value = tensor_util.constant_value(capture)
          if static_value is None:
            # Without a static value there is nothing to embed as a constant;
            # report which function and tensor are responsible rather than
            # passing None on to `constant_op.constant`.
            raise ValueError(
                ("Attempted to save a function {} which references a symbolic "
                 "Tensor {} that is not a simple constant. This is not "
                 "supported.").format(concrete_function.name, capture))
          copied_tensor = constant_op.constant(static_value)
          # Ids are positions in `self.nodes`, so the appended node receives
          # the current length.
          node_id = len(self.nodes)
          node = _CapturedConstant(
              eager_tensor=capture, graph_tensor=copied_tensor)
          self.nodes.append(node)
          self.node_ids[capture] = node_id
          self.node_ids[node] = node_id
          self.captured_tensor_node_ids[capture] = node_id
          resource_map[capture] = copied_tensor

    return object_map, resource_map, asset_info
 def testCopyToGraphUninitialized(self):
   """Checks the name and initializer of an uninitialized graph copy.

   The copy made by `copy_to_graph_uninitialized` must keep the original
   variable's name, and running its initializer in a session on the target
   graph must raise `InvalidArgumentError`.
   """
   source_variable = resource_variable_ops.ResourceVariable([0, 1, 2, 3])
   target_graph = ops.Graph()
   with target_graph.as_default():  # Intentionally testing v1 behavior
     graph_copy = resource_variable_ops.copy_to_graph_uninitialized(
         source_variable)
     self.assertEqual(source_variable.name, graph_copy.name)
     with self.session(target_graph) as sess:
       with self.assertRaises(errors.InvalidArgumentError):
         sess.run(graph_copy.initializer)
 def testCopyToGraphUninitialized(self):
     """Verifies that a graph copy shares the name but cannot be initialized.

     A resource variable is copied into a fresh graph with
     `copy_to_graph_uninitialized`; the test asserts the names match and that
     running the copy's initializer raises `InvalidArgumentError`.
     """
     original = resource_variable_ops.ResourceVariable([0, 1, 2, 3])
     destination_graph = ops.Graph()
     # Intentionally testing v1 behavior
     with destination_graph.as_default():
         duplicate = resource_variable_ops.copy_to_graph_uninitialized(
             original)
         self.assertEqual(original.name, duplicate.name)
         with self.session(destination_graph) as sess:
             with self.assertRaises(errors.InvalidArgumentError):
                 sess.run(duplicate.initializer)
Esempio n. 7
0
def _map_resources(accessible_objects):
    """Creates graph-side replacements for resources, variables and assets.

    Resource handle ops are created in the current default graph, whereas
    `accessible_objects` will be from an eager context. Resource mapping adds
    resource handle ops to the main GraphDef of a SavedModel, which allows
    the C++ loader API to interact with variables.

    Each object is handled by the first matching rule:
      1. `tracking.TrackableResource`: a new resource is created and mapped
         from the object's `resource_handle`.
      2. Resource variable: an uninitialized copy is made and both the
         variable and its `handle` are mapped to the copy.
      3. `tracking.TrackableAsset`: delegated to `_process_asset`.
    Anything else is skipped.

    Args:
      accessible_objects: An iterable of objects, some of which may contain
        resources, to create replacements for.

    Returns:
      A tuple of (object_map, resource_map, asset_info):
        object_map: A dictionary mapping from resource variables in
          `accessible_objects` to replacement objects created to hold the
          new resource tensors.
        resource_map: A dictionary mapping from resource tensors extracted
          from `accessible_objects` to newly created resource tensors.
        asset_info: An _AssetInfo tuple describing external assets
          referenced from accessible_objects.
    """
    # TODO(allenl): Handle MirroredVariables and other types of variables which
    # may need special casing.
    object_map = util.ObjectIdentityDictionary()
    resource_map = {}
    asset_info = _AssetInfo(
        asset_defs=[],
        asset_initializers_by_resource={},
        asset_filename_map={},
        asset_index={},
    )
    for candidate in accessible_objects:
        if isinstance(candidate, tracking.TrackableResource):
            created = candidate.create_resource()
            resource_map[candidate.resource_handle] = created
            continue
        if resource_variable_ops.is_resource_variable(candidate):
            graph_copy = resource_variable_ops.copy_to_graph_uninitialized(
                candidate)
            object_map[candidate] = graph_copy
            resource_map[candidate.handle] = graph_copy.handle
            continue
        if isinstance(candidate, tracking.TrackableAsset):
            _process_asset(candidate, asset_info, resource_map)
    return object_map, resource_map, asset_info
Esempio n. 8
0
def _map_resources(accessible_objects):
  """Maps eager resources, variables and assets to new graph counterparts.

  Creates resource handle ops in the current default graph, whereas
  `accessible_objects` will be from an eager context. Resource mapping adds
  resource handle ops to the main GraphDef of a SavedModel, which allows the C++
  loader API to interact with variables.

  Objects are dispatched in this order: `tracking.TrackableResource`, then
  resource variables, then `tracking.TrackableAsset`. Objects matching none of
  these are ignored.

  Args:
    accessible_objects: An iterable of objects, some of which may contain
      resources, to create replacements for.

  Returns:
    A tuple of (object_map, resource_map, asset_info):
      object_map: A dictionary mapping from resource variables in
        `accessible_objects` to replacement objects created to hold the new
        resource tensors.
      resource_map: A dictionary mapping from resource tensors extracted from
        `accessible_objects` to newly created resource tensors.
      asset_info: An _AssetInfo tuple describing external assets referenced from
        accessible_objects.
  """
  # TODO(allenl): Handle MirroredVariables and other types of variables which
  # may need special casing.
  replacements = util.ObjectIdentityDictionary()
  tensor_replacements = {}
  assets = _AssetInfo(
      asset_defs=[],
      asset_initializers_by_resource={},
      asset_filename_map={},
      asset_index={})
  for item in accessible_objects:
    if isinstance(item, tracking.TrackableResource):
      fresh_resource = item.create_resource()
      tensor_replacements[item.resource_handle] = fresh_resource
      continue
    if resource_variable_ops.is_resource_variable(item):
      copied = resource_variable_ops.copy_to_graph_uninitialized(item)
      replacements[item] = copied
      tensor_replacements[item.handle] = copied.handle
      continue
    if isinstance(item, tracking.TrackableAsset):
      _process_asset(item, assets, tensor_replacements)
  return replacements, tensor_replacements, assets
Esempio n. 9
0
  def map_resources(self):
    """Makes new resource handle ops corresponding to existing resource tensors.

    Creates resource handle ops in the current default graph, whereas the
    objects in `self.nodes` will be from an eager context. Resource mapping
    adds resource handle ops to the main GraphDef of a SavedModel, which allows
    the C++ loader API to interact with variables.

    Besides building the returned maps, this method records the node id of
    every mapped tensor (and of each distributed variable itself) in
    `self.captured_tensor_node_ids`, and appends one `_CapturedConstant` node
    to `self.nodes` (registering it in `self.node_ids`) for each copyable
    tensor captured by a concrete function that is not already tracked.

    Returns:
      A tuple of (object_map, resource_map, asset_info):
        object_map: A dictionary mapping from variables in `self.nodes` to
          replacement objects created to hold the new resource tensors.
        resource_map: A dictionary mapping from resource tensors (and captured
          constant tensors) to newly created tensors.
        asset_info: An _AssetInfo tuple describing external assets referenced
          from `self.nodes`.

    Raises:
      ValueError: If a concrete function's graph is not saveable, or if a
        function captures a tensor that has no statically known value.
    """
    # Only makes sense when adding to the export Graph
    assert not context.executing_eagerly()
    # TODO(allenl): Handle MirroredVariables and other types of variables which
    # may need special casing.
    object_map = object_identity.ObjectIdentityDictionary()
    resource_map = {}
    asset_info = _AssetInfo(
        asset_defs=[],
        asset_initializers_by_resource={},
        asset_filename_map={},
        asset_index={})

    for node_id, obj in enumerate(self.nodes):
      if isinstance(obj, tracking.CapturableResource):
        # pylint: disable=protected-access
        with ops.device(obj._resource_device):
          new_resource = obj._create_resource()
        # pylint: enable=protected-access
        resource_map[obj.resource_handle] = new_resource
        self.captured_tensor_node_ids[obj.resource_handle] = node_id
      elif (ds_values.is_distributed_variable(obj) or
            resource_variable_ops.is_resource_variable(obj)):
        is_distributed = ds_values.is_distributed_variable(obj)
        # A distributed variable is copied through its `primary` component.
        obj_to_copy = obj.primary if is_distributed else obj
        new_variable = resource_variable_ops.copy_to_graph_uninitialized(
            obj_to_copy)
        if is_distributed:
          self.captured_tensor_node_ids[obj] = node_id
        object_map[obj] = new_variable
        resource_map[obj.handle] = new_variable.handle
        self.captured_tensor_node_ids[obj.handle] = node_id
      elif isinstance(obj, tracking.Asset):
        _process_asset(obj, asset_info, resource_map)
        self.captured_tensor_node_ids[obj.asset_path] = node_id

    for concrete_function in self.concrete_functions:
      if not concrete_function.graph.saveable:
        # Format only the template: the saving errors are free-form text and
        # may contain braces that `str.format` would misinterpret.
        raise ValueError(
            "Unable to save function {name} for the following reason(s):\n"
            .format(name=concrete_function.name)
            + "\n".join(concrete_function.graph.saving_errors))
      for capture in concrete_function.captured_inputs:
        if (tensor_util.is_tensor(capture)
            and capture.dtype not in _UNCOPIABLE_DTYPES
            and capture not in self.captured_tensor_node_ids):
          capture_constant_value = tensor_util.constant_value(capture)
          if capture_constant_value is None:
            raise ValueError(
                ("Attempted to save a function {} which references a symbolic "
                 "Tensor {} that is not a simple constant. This is not "
                 "supported.").format(concrete_function.name, capture))
          copied_tensor = constant_op.constant(capture_constant_value)
          # The new node is appended at the end, so its id is the current
          # length of the node list.
          node_id = len(self.nodes)
          node = _CapturedConstant(
              eager_tensor=capture, graph_tensor=copied_tensor)
          self.nodes.append(node)
          self.node_ids[capture] = node_id
          self.node_ids[node] = node_id
          self.captured_tensor_node_ids[capture] = node_id
          resource_map[capture] = copied_tensor

    return object_map, resource_map, asset_info
Esempio n. 10
0
    def map_resources(self):
        """Makes new resource handle ops corresponding to existing resource tensors.

        Creates resource handle ops in the current default graph, whereas the
        objects in `self.nodes` will be from an eager context. Resource
        mapping adds resource handle ops to the main GraphDef of a
        SavedModel, which allows the C++ loader API to interact with
        variables.

        Besides building the returned maps, this method mutates the instance:
          * `self.captured_tensor_node_ids` gains an entry for every mapped
            tensor (and for each distributed variable itself).
          * `self.nodes` / `self.node_ids` gain a `_CapturedConstant` node for
            each copyable, not-yet-tracked tensor captured by a concrete
            function whose value is statically known.
          * Functions capturing a tensor with a `_cached_variable` attribute
            are wrapped once via
            `function_serialization.wrap_cached_variables`; the wrapper is
            stored in `self.wrapped_functions` and the name mapping in
            `self.function_name_map`.
          * `self.concrete_functions` is rebuilt at the end: functions with a
            capture whose constant value is unknown are dropped, and wrapped
            functions are replaced by their wrappers.

        Returns:
          A tuple of (object_map, resource_map, asset_info):
            object_map: A dictionary mapping from objects in `self.nodes`
              (and the component variables of distributed variables) to
              replacement objects created to hold the new resource tensors.
            resource_map: A dictionary mapping from resource tensors (and
              captured constant tensors) to newly created tensors.
            asset_info: An _AssetInfo tuple describing external assets
              referenced from `self.nodes`.

        Raises:
          ValueError: If a concrete function's graph is not saveable.
        """
        # Only makes sense when adding to the export Graph
        assert not context.executing_eagerly()
        # TODO(allenl): Handle MirroredVariables and other types of variables which
        # may need special casing.
        object_map = object_identity.ObjectIdentityDictionary()
        resource_map = {}
        asset_info = _AssetInfo(asset_defs=[],
                                asset_initializers_by_resource={},
                                asset_filename_map={},
                                asset_index={})

        for node_id, obj in enumerate(self.nodes):
            if isinstance(obj, tracking.CapturableResource):
                # The replacement is a shallow copy whose handle is pointed at
                # the newly created resource.
                new_obj = object_map[obj] = copy.copy(obj)
                # pylint: disable=protected-access
                with ops.device(obj._resource_device):
                    new_resource = new_obj._create_resource()
                new_obj._resource_handle = new_resource
                # pylint: enable=protected-access
                resource_map[obj.resource_handle] = new_resource
                self.captured_tensor_node_ids[obj.resource_handle] = node_id
            elif (ds_values.is_distributed_variable(obj)
                  or resource_variable_ops.is_resource_variable(obj)):
                obj_to_copy = obj.primary if ds_values.is_distributed_variable(
                    obj) else obj
                new_variable = resource_variable_ops.copy_to_graph_uninitialized(
                    obj_to_copy)
                if ds_values.is_distributed_variable(obj):
                    self.captured_tensor_node_ids[obj] = node_id
                    # Every component variable maps to the same single copy.
                    for v in obj.values:
                        object_map[v] = new_variable
                        resource_map[v.handle] = new_variable.handle
                        self.captured_tensor_node_ids[v.handle] = node_id
                object_map[obj] = new_variable
                resource_map[obj.handle] = new_variable.handle
                self.captured_tensor_node_ids[obj.handle] = node_id
            elif isinstance(obj, tracking.Asset):
                _process_asset(obj, asset_info, resource_map)
                self.captured_tensor_node_ids[obj.asset_path] = node_id

        # Note: some concrete functions can have been realized when tracing other
        # functions, and might closure-capture tensors from their parent functions.
        # This is normal, but it means those concrete functions can't be serialized
        # as their own independent endpoints, so we filter them out here.
        bad_functions = []
        for concrete_function in self.concrete_functions:
            if not concrete_function.graph.saveable:
                # Format only the template: the saving errors are free-form
                # text and may contain braces that `str.format` would
                # misinterpret.
                raise ValueError(
                    "Unable to save function {name} for the following reason(s):\n"
                    .format(name=concrete_function.name)
                    + "\n".join(concrete_function.graph.saving_errors))
            for capture in concrete_function.captured_inputs:
                if (tensor_util.is_tensor(capture)
                        and capture.dtype not in _UNCOPIABLE_DTYPES
                        and capture not in self.captured_tensor_node_ids):
                    if hasattr(capture, "_cached_variable"):
                        # Wrap the function at most once, however many of its
                        # captures carry a cached variable.
                        if concrete_function not in self.wrapped_functions:
                            wrapped = self.wrapped_functions[
                                concrete_function] = (
                                    function_serialization.
                                    wrap_cached_variables(concrete_function))
                            self.function_name_map[compat.as_text(
                                concrete_function.name)] = (compat.as_text(
                                    wrapped.name))
                        continue
                    capture_constant_value = tensor_util.constant_value(
                        capture)
                    if capture_constant_value is None:
                        # Not a simple constant: mark the function for removal
                        # instead of failing the whole export.
                        bad_functions.append(concrete_function)
                        continue
                    copied_tensor = constant_op.constant(
                        capture_constant_value)
                    # The new node is appended at the end, so its id is the
                    # current length of the node list.
                    node_id = len(self.nodes)
                    node = _CapturedConstant(eager_tensor=capture,
                                             graph_tensor=copied_tensor)
                    self.nodes.append(node)
                    self.node_ids[capture] = node_id
                    self.node_ids[node] = node_id
                    self.captured_tensor_node_ids[capture] = node_id
                    resource_map[capture] = copied_tensor

        self.concrete_functions = [
            self.wrapped_functions.get(x, x) for x in self.concrete_functions
            if x not in bad_functions
        ]
        return object_map, resource_map, asset_info