Exemple #1
0
    def _capture_tensor_as_extra_input(self, tensor, name=None):
        """Substitute `tensor` with a new placeholder and register the pair.

        The tensor is appended to `self.extra_inputs`; the placeholder is
        appended to `self.inputs` and `self.extra_args`, and is stored in
        `self._captured` under `tensor`. Resource handle data found on
        `tensor` is copied onto the placeholder.

        Args:
          tensor: The tensor to replace. Its dtype and shape are used to
            build the placeholder.
          name: Optional name forwarded to `array_ops.placeholder`.

        Returns:
          The placeholder, or the result of `array_ops.guarantee_const` on
          the placeholder when `_is_guaranteed_const(tensor)` is true.
        """
        self.extra_inputs.append(tensor)

        # Build the placeholder with control dependencies cleared so that it
        # is hoisted out of whatever control flow context is active.
        with ops.control_dependencies(None):
            placeholder = array_ops.placeholder(
                tensor.dtype, shape=tensor.get_shape(), name=name)

        # pylint: disable=protected-access
        if not ops._USE_C_SHAPES:
            # Without C shapes the handle data is a plain Python attribute.
            placeholder._handle_data = tensor._handle_data
        else:
            if isinstance(tensor, ops.EagerTensor):
                # Eager tensors carry the handle data as an object that has
                # to be serialized before it is passed to the C API.
                serialized = tensor._handle_data
                if serialized:
                    serialized = serialized.SerializeToString()
            else:
                serialized = c_api.GetResourceHandleShapeAndType(
                    tensor.graph._c_graph, tensor._as_tf_output())

            # Only write handle data when there is something to copy.
            if serialized:
                c_api.SetResourceHandleShapeAndType(
                    placeholder.graph._c_graph,
                    placeholder._as_tf_output(),
                    compat.as_bytes(serialized))
        # pylint: enable=protected-access

        self.inputs.append(placeholder)
        self._captured[tensor] = placeholder
        self.extra_args.append(placeholder)

        if not _is_guaranteed_const(tensor):
            return placeholder
        # Wrap the placeholder, again outside any control dependencies.
        with ops.control_dependencies(None):
            return array_ops.guarantee_const(placeholder)
Exemple #2
0
def get_resource_handle_data(graph_op):
  """Return the resource handle data of `graph_op` as a `HandleData` proto.

  The serialized handle shape-and-type information is read through
  `pywrap_tensorflow.GetResourceHandleShapeAndType` and then parsed with
  `CppShapeInferenceResult.HandleData.FromString`.

  Args:
    graph_op: An object whose type is exactly `ops.Tensor`.

  Returns:
    The `CppShapeInferenceResult.HandleData` message parsed from the bytes
    returned by the C API.

  Raises:
    AssertionError: If `ops._USE_C_SHAPES` is falsy or `graph_op` is not
      exactly of type `ops.Tensor` (only while assertions are enabled).
  """
  # pylint: disable=protected-access
  assert ops._USE_C_SHAPES
  # The exact type comparison is deliberate: subclasses are rejected too.
  assert type(graph_op) == ops.Tensor  # pylint: disable=unidiomatic-typecheck

  c_graph = graph_op.graph._c_graph
  tf_output = graph_op._as_tf_output()
  # pylint: enable=protected-access
  serialized = pywrap_tensorflow.GetResourceHandleShapeAndType(
      c_graph, tf_output)

  handle_data_cls = cpp_shape_inference_pb2.CppShapeInferenceResult.HandleData
  return handle_data_cls.FromString(compat.as_bytes(serialized))