Example #1
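This reads like a snippet from a TensorFlow distribution-strategy test: `self` is the enclosing test case and `optimizer` is created elsewhere in the test, so the function is not standalone. It checks that a variable built under a strategy is a distributed variable, and that the slot variables the optimizer creates on its first `apply_gradients` call are distributed as well.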
 def f():
     v = variables.Variable([1.0])
     self.assertTrue(ds_values.is_distributed_variable(v))
     # Slot variables are created in the first call to apply_gradients.
     optimizer.apply_gradients([(ops.convert_to_tensor([1.0]), v)])
     self.assertTrue(optimizer.get_slot_names())
     for name in optimizer.get_slot_names():
         slot = optimizer.get_slot(v, name)
         self.assertIsNotNone(slot)
         self.assertTrue(ds_values.is_distributed_variable(slot))
Example #2
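From the SavedModel loader: resolves a node id from the saved object graph into the tensor a restored function should capture. Distributed variables are returned whole (their capture is handled specially, as Example #3 shows), plain resource variables contribute their `handle`, and assets contribute their `asset_path`.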
 def _get_tensor_from_node(self, node_id):
   """Resolves a node id into a tensor to be captured for a function."""
   with ops.init_scope():
     obj = self._nodes[node_id]
     if ds_values.is_distributed_variable(obj):
       return obj
     elif resource_variable_ops.is_resource_variable(obj):
       return obj.handle
     elif isinstance(obj, tracking.Asset):
       return obj.asset_path
     elif tensor_util.is_tensor(obj):
       return obj
     elif isinstance(obj, tracking.CapturableResource):
       # Note: this executes restored functions in the CapturableResource.
       return obj.resource_handle
     raise ValueError("Can't convert node %s to tensor" % (type(obj)))
Example #3
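Also from the SavedModel loader: rewires the captured inputs of each restored concrete function. Distributed variables go through `capture_distributed_variable`; ordinary captures are written into the graph's capture map directly, copying resource handle data where a handle is available.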
 def _setup_functions_captures(self):
     """Setup captures and variables in restored functions."""
     concrete_functions = sorted(self._proto.concrete_functions.items())
     for name, proto in concrete_functions:
         concrete_function = self._concrete_functions[name]
         bound_inputs = [
             self._get_tensor_from_node(node_id)
             for node_id in proto.bound_inputs
         ]
         bound_variables = [
             self._nodes[node_id] for node_id in proto.bound_inputs
             if self._proto.nodes[node_id].WhichOneof("kind") == "variable"
         ]
         # TODO(andresp): This is only injecting the captured inputs into the
         # concrete function, note that we did not modify the FuncGraph
         # itself.
         concrete_function._captured_inputs = bound_inputs  # pylint: disable=protected-access
         concrete_function._func_graph.variables = bound_variables  # pylint: disable=protected-access
         if bound_inputs:
             for bound_input, internal_capture in zip(
                     bound_inputs,
                     concrete_function.inputs[-len(bound_inputs):]):
                 if ds_values.is_distributed_variable(bound_input):
                     concrete_function.graph.capture_distributed_variable(
                         bound_input, internal_capture)
                 else:
                     concrete_function.graph._captures[ops.tensor_id(
                         bound_input)] = (  # pylint: disable=protected-access
                             bound_input, internal_capture)
                     if internal_capture.dtype == dtypes.resource:
                         if resource_variable_ops.is_resource_variable(
                                 bound_input):
                             try:
                                 handle = bound_input.handle
                             except ValueError:
                                 # For mirrored variables we'll copy handle data for components
                                 # as they get captured.
                                 pass
                             else:
                                 custom_gradient.copy_handle_data(
                                     handle, internal_capture)
                         else:
                             custom_gradient.copy_handle_data(
                                 bound_input, internal_capture)
                     # Setting "captures" first means "capture" won't create a new
                     # placeholder for this input.
                     concrete_function.graph.capture(bound_input)
Example #4
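This and Example #5 appear to be helpers from the SavedModel export path. In cross-replica context a distributed variable's per-replica handle is not valid, so it is swapped for an unused placeholder handle.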
 def get_cross_replica_handle(x):
   return _unused_handle() if ds_values.is_distributed_variable(x) else x
Example #5
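The in-replica counterpart of Example #4: inside a replica context the distributed variable's concrete `handle` is used directly.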
 def get_in_replica_handle(x):
   return x.handle if ds_values.is_distributed_variable(x) else x
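For context, here is a minimal sketch (not one of the extracted examples) of what the predicate itself distinguishes. It assumes a TF 2.x build in which `is_distributed_variable` still lives in the internal module `tensorflow.python.distribute.values`, as in the snippets above; the helper moved in later releases, so the import path is version-dependent.

 import tensorflow as tf
 from tensorflow.python.distribute import values as ds_values

 # A variable created under a strategy scope is wrapped as a
 # DistributedVariable (e.g. a MirroredVariable)...
 strategy = tf.distribute.MirroredStrategy()
 with strategy.scope():
     dist_var = tf.Variable([1.0])

 # ...while one created outside any strategy is a plain resource variable.
 plain_var = tf.Variable([2.0])

 print(ds_values.is_distributed_variable(dist_var))   # True
 print(ds_values.is_distributed_variable(plain_var))  # False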
Example #6
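`map_resources` from the SavedModel export path: builds fresh resource handle ops in the export graph for tracked resources, variables, and assets. For a distributed variable only its primary component is copied; eager constants captured by functions are re-created as graph constants, and any non-constant capture raises.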
  def map_resources(self):
    """Makes new resource handle ops corresponding to existing resource tensors.

    Creates resource handle ops in the current default graph, whereas
    `accessible_objects` will be from an eager context. Resource mapping adds
    resource handle ops to the main GraphDef of a SavedModel, which allows the
    C++ loader API to interact with variables.

    Returns:
      A tuple of (object_map, resource_map, asset_info):
        object_map: A dictionary mapping from objects in `accessible_objects` to
          replacement objects created to hold the new resource tensors.
        resource_map: A dictionary mapping from resource tensors extracted from
          `accessible_objects` to newly created resource tensors.
        asset_info: An _AssetInfo tuple describing external assets referenced
          from accessible_objects.
    """
    # Only makes sense when adding to the export Graph
    assert not context.executing_eagerly()
    # TODO(allenl): Handle MirroredVariables and other types of variables which
    # may need special casing.
    object_map = object_identity.ObjectIdentityDictionary()
    resource_map = {}
    asset_info = _AssetInfo(
        asset_defs=[],
        asset_initializers_by_resource={},
        asset_filename_map={},
        asset_index={})

    for node_id, obj in enumerate(self.nodes):
      if isinstance(obj, tracking.CapturableResource):
        # pylint: disable=protected-access
        with ops.device(obj._resource_device):
          new_resource = obj._create_resource()
        # pylint: enable=protected-access
        resource_map[obj.resource_handle] = new_resource
        self.captured_tensor_node_ids[obj.resource_handle] = node_id
      elif (ds_values.is_distributed_variable(obj) or
            resource_variable_ops.is_resource_variable(obj)):
        obj_to_copy = obj.primary if ds_values.is_distributed_variable(
            obj) else obj
        new_variable = resource_variable_ops.copy_to_graph_uninitialized(
            obj_to_copy)
        if ds_values.is_distributed_variable(obj):
          self.captured_tensor_node_ids[obj] = node_id
        object_map[obj] = new_variable
        resource_map[obj.handle] = new_variable.handle
        self.captured_tensor_node_ids[obj.handle] = node_id
      elif isinstance(obj, tracking.Asset):
        _process_asset(obj, asset_info, resource_map)
        self.captured_tensor_node_ids[obj.asset_path] = node_id

    for concrete_function in self.concrete_functions:
      if not concrete_function.graph.saveable:
        raise ValueError(
            ("Unable to save function {name} for the following reason(s):\n" +
             "\n".join(concrete_function.graph.saving_errors))
            .format(name=concrete_function.name))
      for capture in concrete_function.captured_inputs:
        if (tensor_util.is_tensor(capture)
            and capture.dtype not in _UNCOPIABLE_DTYPES
            and capture not in self.captured_tensor_node_ids):
          capture_constant_value = tensor_util.constant_value(capture)
          if capture_constant_value is None:
            raise ValueError(
                ("Attempted to save a function {} which references a symbolic "
                 "Tensor {} that is not a simple constant. This is not "
                 "supported.").format(concrete_function.name, capture))
          copied_tensor = constant_op.constant(capture_constant_value)
          node_id = len(self.nodes)
          node = _CapturedConstant(
              eager_tensor=capture, graph_tensor=copied_tensor)
          self.nodes.append(node)
          self.node_ids[capture] = node_id
          self.node_ids[node] = node_id
          self.captured_tensor_node_ids[capture] = node_id
          resource_map[capture] = copied_tensor

    return object_map, resource_map, asset_info
Example #7
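A later revision of the same method. Compared with Example #6: `CapturableResource`s are shallow-copied into the object map, every component of a distributed variable is mapped to the copied variable, functions capturing cached variables are wrapped via `function_serialization.wrap_cached_variables`, and concrete functions with unserializable captures are collected as `bad_functions` and filtered out instead of raising.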
    def map_resources(self):
        """Makes new resource handle ops corresponding to existing resource tensors.

        Creates resource handle ops in the current default graph, whereas
        `accessible_objects` will be from an eager context. Resource mapping adds
        resource handle ops to the main GraphDef of a SavedModel, which allows the
        C++ loader API to interact with variables.

        Returns:
          A tuple of (object_map, resource_map, asset_info):
            object_map: A dictionary mapping from objects in `accessible_objects` to
              replacement objects created to hold the new resource tensors.
            resource_map: A dictionary mapping from resource tensors extracted from
              `accessible_objects` to newly created resource tensors.
            asset_info: An _AssetInfo tuple describing external assets referenced
              from accessible_objects.
        """
        # Only makes sense when adding to the export Graph
        assert not context.executing_eagerly()
        # TODO(allenl): Handle MirroredVariables and other types of variables which
        # may need special casing.
        object_map = object_identity.ObjectIdentityDictionary()
        resource_map = {}
        asset_info = _AssetInfo(asset_defs=[],
                                asset_initializers_by_resource={},
                                asset_filename_map={},
                                asset_index={})

        for node_id, obj in enumerate(self.nodes):
            if isinstance(obj, tracking.CapturableResource):
                new_obj = object_map[obj] = copy.copy(obj)
                # pylint: disable=protected-access
                with ops.device(obj._resource_device):
                    new_resource = new_obj._create_resource()
                new_obj._resource_handle = new_resource
                # pylint: enable=protected-access
                resource_map[obj.resource_handle] = new_resource
                self.captured_tensor_node_ids[obj.resource_handle] = node_id
            elif (ds_values.is_distributed_variable(obj)
                  or resource_variable_ops.is_resource_variable(obj)):
                obj_to_copy = obj.primary if ds_values.is_distributed_variable(
                    obj) else obj
                new_variable = resource_variable_ops.copy_to_graph_uninitialized(
                    obj_to_copy)
                if ds_values.is_distributed_variable(obj):
                    self.captured_tensor_node_ids[obj] = node_id
                    for v in obj.values:
                        object_map[v] = new_variable
                        resource_map[v.handle] = new_variable.handle
                        self.captured_tensor_node_ids[v.handle] = node_id
                object_map[obj] = new_variable
                resource_map[obj.handle] = new_variable.handle
                self.captured_tensor_node_ids[obj.handle] = node_id
            elif isinstance(obj, tracking.Asset):
                _process_asset(obj, asset_info, resource_map)
                self.captured_tensor_node_ids[obj.asset_path] = node_id

        # Note: some concrete functions can have been realized when tracing other
        # functions, and might closure-capture tensors from their parent functions.
        # This is normal, but it means those concrete functions can't be serialized
        # as their own independent endpoints, so we filter them out here.
        bad_functions = []
        for concrete_function in self.concrete_functions:
            if not concrete_function.graph.saveable:
                raise ValueError((
                    "Unable to save function {name} for the following reason(s):\n"
                    + "\n".join(concrete_function.graph.saving_errors)).format(
                        name=concrete_function.name))
            for capture in concrete_function.captured_inputs:
                if (tensor_util.is_tensor(capture)
                        and capture.dtype not in _UNCOPIABLE_DTYPES
                        and capture not in self.captured_tensor_node_ids):
                    if hasattr(capture, "_cached_variable"):
                        if concrete_function not in self.wrapped_functions:
                            wrapped = self.wrapped_functions[
                                concrete_function] = (
                                    function_serialization.
                                    wrap_cached_variables(concrete_function))
                            self.function_name_map[compat.as_text(
                                concrete_function.name)] = (compat.as_text(
                                    wrapped.name))
                        continue
                    capture_constant_value = tensor_util.constant_value(
                        capture)
                    if capture_constant_value is None:
                        bad_functions.append(concrete_function)
                        continue
                    copied_tensor = constant_op.constant(
                        capture_constant_value)
                    node_id = len(self.nodes)
                    node = _CapturedConstant(eager_tensor=capture,
                                             graph_tensor=copied_tensor)
                    self.nodes.append(node)
                    self.node_ids[capture] = node_id
                    self.node_ids[node] = node_id
                    self.captured_tensor_node_ids[capture] = node_id
                    resource_map[capture] = copied_tensor

        self.concrete_functions = [
            self.wrapped_functions.get(x, x) for x in self.concrete_functions
            if x not in bad_functions
        ]
        return object_map, resource_map, asset_info