# Example #1
0
    def _capture_tensor_as_extra_input(self, tensor, name=None):
        """Capture an external `tensor` by substituting a new placeholder.

        The original tensor is recorded in `self.extra_inputs`. A placeholder
        with the same dtype and static shape is created to stand in for it,
        and any handle shape/type metadata found on the tensor is copied onto
        that placeholder. The placeholder is then appended to `self.inputs`
        and `self.extra_args`, and stored in `self._captured` keyed by
        `tensor`.

        Args:
          tensor: The tensor being captured.
          name: Optional name for the new placeholder.

        Returns:
          The placeholder. It is wrapped in `array_ops.guarantee_const` when
          `_is_guaranteed_const(tensor)` is true.
        """
        self.extra_inputs.append(tensor)

        # Creating the placeholder under control_dependencies(None) hoists it
        # out of any control-flow context that is currently active.
        with ops.control_dependencies(None):
            placeholder = array_ops.placeholder(
                tensor.dtype, shape=tensor.get_shape(), name=name)

        # pylint: disable=protected-access
        if not isinstance(tensor, ops.EagerTensor):
            # Graph tensor: ask the C API for the serialized handle data.
            serialized = c_api.GetHandleShapeAndType(
                tensor.graph._c_graph, tensor._as_tf_output())
        else:
            # Eager tensor: serialize the handle data proto only if present.
            raw = tensor._handle_data
            serialized = raw.SerializeToString() if raw else raw

        if serialized:
            c_api.SetHandleShapeAndType(placeholder.graph._c_graph,
                                        placeholder._as_tf_output(),
                                        compat.as_bytes(serialized))
        # pylint: enable=protected-access

        self.inputs.append(placeholder)
        self._captured[tensor] = placeholder
        self.extra_args.append(placeholder)

        if not _is_guaranteed_const(tensor):
            return placeholder
        with ops.control_dependencies(None):
            return array_ops.guarantee_const(placeholder)
# Example #2
0
def copy_handle_data(source_t, target_t):
    """Propagate HandleData from `source_t` onto `target_t`.

    HandleData (the CppShapeInferenceResult::HandleData proto) describes the
    shapes and dtypes of the element tensors behind a resource- or
    variant-typed tensor. It must be copied explicitly across function
    boundaries, for example when capturing a placeholder or when returning a
    function tensor as an output. Without it the element tensors end up with
    unknown shapes. For instance, if a TensorList variant tensor is captured
    as a placeholder, elements popped from that list would have unknown shape.

    Nothing is done when `target_t` is neither a resource nor a variant
    tensor. It is also a no-op when the source's HandleData is missing, not
    set, or has an empty `shape_and_type` list.

    Args:
      source_t: The tensor to copy HandleData from.
      target_t: The tensor to copy HandleData to.
    """
    target_dtype = target_t.dtype
    carries_handle_data = (target_dtype == dtypes.resource
                           or target_dtype == dtypes.variant)
    if not carries_handle_data:
        return

    # pylint: disable=protected-access
    handle_data = (
        source_t._handle_data
        if isinstance(source_t, ops.EagerTensor)
        else resource_variable_ops.get_resource_handle_data(source_t))

    if (handle_data is None or not handle_data.is_set
            or not handle_data.shape_and_type):
        return

    pywrap_tensorflow.SetHandleShapeAndType(
        target_t.graph._c_graph,
        target_t._as_tf_output(),
        handle_data.SerializeToString())

    # Also set the element shapes/dtypes on the graph output so they are
    # propagated. An unknown-rank entry is encoded as rank -1 with shape None.
    entries = handle_data.shape_and_type
    types = tuple(entry.dtype for entry in entries)
    shapes = []
    ranks = []
    for entry in entries:
        shape_proto = entry.shape
        if shape_proto.unknown_rank:
            shapes.append(None)
            ranks.append(-1)
        else:
            shapes.append([dim.size for dim in shape_proto.dim])
            ranks.append(len(shape_proto.dim))

    pywrap_tensorflow.TF_GraphSetOutputHandleShapesAndTypes_wrapper(
        target_t._op._graph._c_graph,
        target_t._as_tf_output(),
        shapes,
        ranks,
        types)
    # pylint: enable=protected-access