def _eager_safe_variable_handle(shape, dtype, shared_name, name, graph_mode):
  """Builds a resource variable handle that carries shape-inference data.

  The handle is created with `var_handle_op` in the default graph's container
  (the empty string when that graph has no container set). In graph mode it is
  returned as-is. Otherwise an identical `var_handle_op` is built in a scratch
  graph, and the resource handle data computed there is copied onto the
  returned handle's `_handle_data`. Shape inference does not run eagerly, so
  this data is needed for when the handle is captured by an eager-mode
  function.

  Args:
    shape: Shape of the variable, forwarded to `var_handle_op`.
    dtype: Dtype of the variable, forwarded to `var_handle_op`.
    shared_name: Shared name forwarded to `var_handle_op`.
    name: Op name forwarded to `var_handle_op`.
    graph_mode: When true, skip the handle-data copy and return immediately.

  Returns:
    The handle produced by `var_handle_op`.
  """
  # pylint: disable=protected-access
  current_container = ops.get_default_graph()._container
  # pylint: enable=protected-access
  handle_kwargs = dict(
      shape=shape,
      dtype=dtype,
      shared_name=shared_name,
      name=name,
      container="" if current_container is None else current_container)

  handle = resource_variable_ops.var_handle_op(**handle_kwargs)
  if graph_mode:
    return handle

  # Build a twin handle in a throwaway graph, where shape inference does run,
  # and copy the resulting handle data onto the eager handle.
  with context.graph_mode(), ops.Graph().as_default() as scratch_graph:
    graph_handle = resource_variable_ops.var_handle_op(**handle_kwargs)
    handle._handle_data = (  # pylint: disable=protected-access
        resource_variable_ops.get_resource_handle_data(graph_handle))
  # Break the scratch graph's op->graph->op reference cycles.
  ops.dismantle_graph(scratch_graph)
  return handle
Beispiel #2
0
def _eager_safe_variable_handle(shape, dtype, shared_name, name, graph_mode):
  """Builds a resource variable handle that carries shape-inference data.

  The handle is created with `var_handle_op` in the default graph's container
  (the empty string when that graph has no container set). In graph mode it is
  returned as-is. Otherwise an identical `var_handle_op` is built in a scratch
  graph, and the resource handle data computed there is copied onto the
  returned handle's `_handle_data`. Shape inference does not run eagerly, so
  this data is needed for when the handle is captured by an eager-mode
  function.

  Args:
    shape: Shape of the variable, forwarded to `var_handle_op`.
    dtype: Dtype of the variable, forwarded to `var_handle_op`.
    shared_name: Shared name forwarded to `var_handle_op`.
    name: Op name forwarded to `var_handle_op`.
    graph_mode: When true, skip the handle-data copy and return immediately.

  Returns:
    The handle produced by `var_handle_op`.
  """
  # pylint: disable=protected-access
  current_container = ops.get_default_graph()._container
  # pylint: enable=protected-access
  handle_kwargs = dict(
      shape=shape,
      dtype=dtype,
      shared_name=shared_name,
      name=name,
      container="" if current_container is None else current_container)

  handle = resource_variable_ops.var_handle_op(**handle_kwargs)
  if graph_mode:
    return handle

  # Build a twin handle in a throwaway graph, where shape inference does run,
  # and copy the resulting handle data onto the eager handle.
  with context.graph_mode(), ops.Graph().as_default() as scratch_graph:
    graph_handle = resource_variable_ops.var_handle_op(**handle_kwargs)
    handle._handle_data = (  # pylint: disable=protected-access
        resource_variable_ops.get_resource_handle_data(graph_handle))
  # Break the scratch graph's op->graph->op reference cycles.
  ops.dismantle_graph(scratch_graph)
  return handle
def copy_handle_data(source_t, target_t):
    """Propagates resource/variant HandleData from `source_t` to `target_t`.

    HandleData (the CppShapeInferenceResult::HandleData proto) records the
    shapes and dtypes of the element tensors behind a resource or variant
    tensor. It must be carried across function boundaries: when a placeholder
    captures a tensor, or when a function tensor is returned as an output.
    Otherwise the element tensors end up with unknown shapes. For example,
    elements popped from a captured TensorList variant would have unknown
    shape.

    Nothing is copied unless `target_t` is a resource or variant tensor and
    the source handle data exists, is set, and has shape-and-type entries.
    An eager `target_t` receives the proto on its `_handle_data` attribute.
    A graph `target_t` has it written into its graph via
    `SetHandleShapeAndType`.

    Args:
      source_t: The tensor to copy HandleData from.
      target_t: The tensor to copy HandleData to.
    """
    if not (target_t.dtype == dtypes.resource
            or target_t.dtype == dtypes.variant):
        return

    # pylint: disable=protected-access
    if isinstance(source_t, ops.EagerTensor):
        # Eager tensors keep their handle data on the Python object itself.
        handle_data = source_t._handle_data
    else:
        handle_data = resource_variable_ops.get_resource_handle_data(source_t)

    if not (handle_data is not None and handle_data.is_set
            and handle_data.shape_and_type):
        return

    if isinstance(target_t, ops.EagerTensor):
        target_t._handle_data = handle_data
        return

    pywrap_tf_session.SetHandleShapeAndType(
        target_t.graph._c_graph, target_t._as_tf_output(),
        handle_data.SerializeToString())
    # pylint: enable=protected-access
Beispiel #4
0
def _copy_source(s, graph, op_map, handle_captures, inverse_captures,
                 base_graph):
  """Creates, in `graph`, a placeholder counterpart of the source tensor `s`.

  The counterpart is chosen as follows:

  1) If `handle_captures` is true and `s` appears in `inverse_captures`, the
     captured Tensor/Variable is re-captured in `graph` under `s`'s op name.
  2) Otherwise, if `s` is a PlaceholderWithDefault whose default is constant,
     the default value is copied into `graph` and a new
     PlaceholderWithDefault is built on top of it.
  3) Otherwise, a plain placeholder with `s`'s dtype and shape is created.

  In all three cases, resource handle shape/type data from `s` is attached to
  the new placeholder when present. Both `s` and `s.op` are then mapped to
  the new tensor/op in `op_map`.

  Args:
    s: The source tensor to reproduce.
    graph: The destination graph.
    op_map: Dict from old-graph ops/tensors to their new-graph counterparts;
      updated in place.
    handle_captures: Whether captured sources should be re-captured rather
      than replaced by a vanilla placeholder.
    inverse_captures: Dict mapping `s` back to the Tensor or Variable it
      captures.
    base_graph: The graph being copied from.

  Raises:
    AssertionError: If the default value of a PlaceholderWithDefault cannot
      be copied because it depends on unavailable inputs.
  """
  source_op = s.op
  if handle_captures and s in inverse_captures:
    new_tensor = graph.capture(inverse_captures[s], name=source_op.name)
  elif source_op.type == "PlaceholderWithDefault" and _constant_inputs(s):
    # Bring the constant default value over first so it can feed the new op.
    default_value = source_op.inputs[0]
    missing_inputs, missing_control_inputs = _copy_non_source(
        op=default_value.op, graph=graph, op_map=op_map,
        base_graph=base_graph)
    if missing_inputs or missing_control_inputs:
      message = "Could not copy source node {} because it has inputs."
      raise AssertionError(message.format(default_value))
    with ops.device(source_op.device):
      new_tensor = array_ops.placeholder_with_default(
          input=op_map[default_value], shape=s.shape, name=source_op.name)
  else:
    with ops.device(source_op.device):
      new_tensor = array_ops.placeholder(
          dtype=s.dtype, shape=s.shape, name=source_op.name)

  source_handle_data = resource_variable_ops.get_resource_handle_data(s)
  if source_handle_data.shape_and_type:
    # pylint: disable=protected-access
    resource_variable_ops._set_handle_shapes_and_types(
        new_tensor, source_handle_data, graph_mode=True)
    # pylint: enable=protected-access

  op_map[s] = new_tensor
  # Map the op as well so nodes with control dependencies on it resolve.
  op_map[source_op] = new_tensor.op
Beispiel #5
0
def _copy_source(s, graph, op_map, handle_captures, inverse_captures):
  """Creates, in `graph`, a placeholder counterpart of the source tensor `s`.

  The counterpart is chosen as follows:

  1) If `handle_captures` is true and `s` appears in `inverse_captures`, the
     captured Tensor/Variable is re-captured in `graph` under `s`'s op name.
  2) Otherwise, if `s` is a PlaceholderWithDefault whose default is constant,
     the default value is copied into `graph` and a new
     PlaceholderWithDefault is built on top of it.
  3) Otherwise, a plain placeholder with `s`'s dtype and shape is created.

  In all three cases, resource handle shape/type data from `s` is attached to
  the new placeholder when present. Both `s` and `s.op` are then mapped to
  the new tensor/op in `op_map`.

  Args:
    s: The source tensor to reproduce.
    graph: The destination graph.
    op_map: Dict from old-graph ops/tensors to their new-graph counterparts;
      updated in place.
    handle_captures: Whether captured sources should be re-captured rather
      than replaced by a vanilla placeholder.
    inverse_captures: Dict mapping `s` back to the Tensor or Variable it
      captures.

  Raises:
    AssertionError: If the default value of a PlaceholderWithDefault cannot
      be copied because it depends on unavailable inputs.
  """
  source_op = s.op
  if handle_captures and s in inverse_captures:
    new_tensor = graph.capture(inverse_captures[s], name=source_op.name)
  elif source_op.type == "PlaceholderWithDefault" and _constant_inputs(s):
    # Bring the constant default value over first so it can feed the new op.
    default_value = source_op.inputs[0]
    missing_inputs, missing_control_inputs = _copy_non_source(
        op=default_value.op, graph=graph, op_map=op_map)
    if missing_inputs or missing_control_inputs:
      message = "Could not copy source node {} because it has inputs."
      raise AssertionError(message.format(default_value))
    with ops.device(source_op.device):
      new_tensor = array_ops.placeholder_with_default(
          input=op_map[default_value], shape=s.shape, name=source_op.name)
  else:
    with ops.device(source_op.device):
      new_tensor = array_ops.placeholder(
          dtype=s.dtype, shape=s.shape, name=source_op.name)

  source_handle_data = resource_variable_ops.get_resource_handle_data(s)
  if source_handle_data.shape_and_type:
    # pylint: disable=protected-access
    resource_variable_ops._set_handle_shapes_and_types(
        new_tensor, source_handle_data, graph_mode=True)
    # pylint: enable=protected-access

  op_map[s] = new_tensor
  # Map the op as well so nodes with control dependencies on it resolve.
  op_map[source_op] = new_tensor.op
Beispiel #6
0
def capture_value(tensor_map, value, dtype, name):
    """Capture a value from outside the function, to pass in as an extra arg.

    `value` is looked up in `tensor_map` by `ops.tensor_id`. On a miss, a
    placeholder is created with `graph_placeholder`, using `dtype` (or
    `value.dtype` when `dtype` is falsy) and `value.shape`. If that placeholder
    is a resource tensor, `value`'s resource handle data is copied onto it.
    The pair `(value, placeholder)` is then stored in `tensor_map`. In every
    case the capture is recorded on the tape as a "captured_value" operation
    whose backward function passes the gradient through unchanged.

    Args:
      tensor_map: Dict mapping `ops.tensor_id(v)` to `(v, placeholder)` pairs;
        updated in place on a cache miss.
      value: The outside tensor being captured.
      dtype: Optional dtype for the placeholder; falls back to `value.dtype`.
      name: Name for a newly created placeholder.

    Returns:
      The placeholder that stands in for `value`.
    """
    captured_value = tensor_map.get(ops.tensor_id(value), None)
    if captured_value is None:
        captured_value = graph_placeholder(dtype=dtype or value.dtype,
                                           shape=value.shape,
                                           name=name)
        if captured_value.dtype == dtypes_module.resource:
            if ops._USE_C_SHAPES:  # pylint: disable=protected-access
                if isinstance(value, ops.EagerTensor):
                    handle_data = value._handle_data  # pylint: disable=protected-access
                else:
                    handle_data = resource_variable_ops.get_resource_handle_data(
                        value)
            else:
                handle_data = value._handle_data  # pylint: disable=protected-access
            if handle_data is not None and handle_data.is_set:
                # pylint: disable=protected-access
                if ops._USE_C_SHAPES:
                    pywrap_tensorflow.TFE_SetResourceHandleShapeAndType(
                        captured_value.graph._c_graph,
                        captured_value._as_tf_output(),
                        handle_data.SerializeToString())
                else:
                    captured_value._handle_data = handle_data
                # pylint: enable=protected-access
                # Handle data can be set yet carry no shape_and_type entries.
                # In that case `zip(*[])` yields nothing, and unpacking it into
                # (shapes, types) would raise ValueError, so only propagate when
                # there is something to propagate.
                if handle_data.shape_and_type:
                    # Ensure that shapes and dtypes are propagated.
                    shapes, types = zip(*[(pair.shape, pair.dtype)
                                          for pair in handle_data.shape_and_type])
                    # Unknown-rank shapes are encoded as rank -1 / no dims.
                    ranks = [
                        len(s.dim) if not s.unknown_rank else -1 for s in shapes
                    ]
                    shapes = [[d.size
                               for d in s.dim] if not s.unknown_rank else None
                              for s in shapes]
                    pywrap_tensorflow.TF_GraphSetOutputHandleShapesAndTypes_wrapper(
                        captured_value._op._graph._c_graph,  # pylint: disable=protected-access
                        captured_value._as_tf_output(),  # pylint: disable=protected-access
                        shapes,
                        ranks,
                        types)

        tensor_map[ops.tensor_id(value)] = (value, captured_value)
    else:
        # Cache hit: entries are (original value, placeholder) pairs.
        captured_value = captured_value[1]
    tape.record_operation("captured_value", [captured_value], [value],
                          lambda x: [x])
    return captured_value
Beispiel #7
0
def copy_handle_data(source_t, target_t):
    """Propagates resource/variant HandleData from `source_t` to `target_t`.

    HandleData (the CppShapeInferenceResult::HandleData proto) records the
    shapes and dtypes of the element tensors behind a resource or variant
    tensor. It must be carried across function boundaries: when a placeholder
    captures a tensor, or when a function tensor is returned as an output.
    Otherwise the element tensors end up with unknown shapes. For example,
    elements popped from a captured TensorList variant would have unknown
    shape.

    Nothing is copied unless `target_t` is a resource or variant tensor and
    the source handle data exists, is set, and has shape-and-type entries.
    When it is copied, the serialized proto is written into `target_t`'s graph
    with `SetHandleShapeAndType`. The per-element shapes, ranks and dtypes are
    also recorded with `TF_GraphSetOutputHandleShapesAndTypes_wrapper`.

    Args:
      source_t: The tensor to copy HandleData from.
      target_t: The tensor to copy HandleData to.
    """
    if not (target_t.dtype == dtypes.resource
            or target_t.dtype == dtypes.variant):
        return

    # pylint: disable=protected-access
    if isinstance(source_t, ops.EagerTensor):
        # Eager tensors keep their handle data on the Python object itself.
        handle_data = source_t._handle_data
    else:
        handle_data = resource_variable_ops.get_resource_handle_data(source_t)

    if not (handle_data is not None and handle_data.is_set
            and handle_data.shape_and_type):
        return

    pywrap_tf_session.SetHandleShapeAndType(
        target_t.graph._c_graph, target_t._as_tf_output(),
        handle_data.SerializeToString())

    # Ensure that shapes and dtypes are propagated. An unknown-rank shape is
    # encoded as rank -1 with no dimension list.
    entries = handle_data.shape_and_type
    types = tuple(entry.dtype for entry in entries)
    ranks = []
    dims = []
    for entry in entries:
        shape_proto = entry.shape
        if shape_proto.unknown_rank:
            ranks.append(-1)
            dims.append(None)
        else:
            ranks.append(len(shape_proto.dim))
            dims.append([d.size for d in shape_proto.dim])

    pywrap_tf_session.TF_GraphSetOutputHandleShapesAndTypes_wrapper(
        target_t._op._graph._c_graph, target_t._as_tf_output(), dims, ranks,
        types)
    # pylint: enable=protected-access
Beispiel #8
0
def capture_value(tensor_map, value, dtype, name):
  """Capture a value from outside the function, to pass in as an extra arg.

  `value` is looked up in `tensor_map` by `ops.tensor_id`. On a miss, a
  placeholder is created with `graph_placeholder`, using `dtype` (or
  `value.dtype` when `dtype` is falsy) and `value.shape`. If that placeholder
  is a resource tensor, `value`'s resource handle data is copied onto it. The
  pair `(value, placeholder)` is then stored in `tensor_map`. In every case
  the capture is recorded on the tape as a "captured_value" operation whose
  backward function passes the gradient through unchanged.

  Args:
    tensor_map: Dict mapping `ops.tensor_id(v)` to `(v, placeholder)` pairs;
      updated in place on a cache miss.
    value: The outside tensor being captured.
    dtype: Optional dtype for the placeholder; falls back to `value.dtype`.
    name: Name for a newly created placeholder.

  Returns:
    The placeholder that stands in for `value`.
  """
  captured_value = tensor_map.get(ops.tensor_id(value), None)
  if captured_value is None:
    captured_value = graph_placeholder(
        dtype=dtype or value.dtype, shape=value.shape, name=name)
    if captured_value.dtype == dtypes_module.resource:
      if ops._USE_C_SHAPES:  # pylint: disable=protected-access
        if isinstance(value, ops.EagerTensor):
          handle_data = value._handle_data  # pylint: disable=protected-access
        else:
          handle_data = resource_variable_ops.get_resource_handle_data(value)
      else:
        handle_data = value._handle_data  # pylint: disable=protected-access
      if handle_data is not None and handle_data.is_set:
        # pylint: disable=protected-access
        if ops._USE_C_SHAPES:
          pywrap_tensorflow.SetResourceHandleShapeAndType(
              captured_value.graph._c_graph, captured_value._as_tf_output(),
              handle_data.SerializeToString())
        else:
          captured_value._handle_data = handle_data
        # pylint: enable=protected-access
        # Handle data can be set yet carry no shape_and_type entries. In that
        # case `zip(*[])` yields nothing, and unpacking it into (shapes, types)
        # would raise ValueError, so only propagate when there is something
        # to propagate.
        if handle_data.shape_and_type:
          # Ensure that shapes and dtypes are propagated.
          shapes, types = zip(*[(pair.shape, pair.dtype)
                                for pair in handle_data.shape_and_type])
          # Unknown-rank shapes are encoded as rank -1 / no dims.
          ranks = [len(s.dim) if not s.unknown_rank else -1 for s in shapes]
          shapes = [[d.size for d in s.dim]
                    if not s.unknown_rank else None for s in shapes]
          pywrap_tensorflow.TF_GraphSetOutputHandleShapesAndTypes_wrapper(
              captured_value._op._graph._c_graph,  # pylint: disable=protected-access
              captured_value._as_tf_output(),  # pylint: disable=protected-access
              shapes, ranks, types)

    tensor_map[ops.tensor_id(value)] = (value, captured_value)
  else:
    # Cache hit: entries are (original value, placeholder) pairs.
    captured_value = captured_value[1]
  tape.record_operation("captured_value", [captured_value], [value],
                        lambda x: [x])
  return captured_value
def copy_handle_data(source_t, target_t):
  """Propagates resource/variant HandleData from `source_t` to `target_t`.

  HandleData (the CppShapeInferenceResult::HandleData proto) records the
  shapes and dtypes of the element tensors behind a resource or variant
  tensor. It must be carried across function boundaries: when a placeholder
  captures a tensor, or when a function tensor is returned as an output.
  Otherwise the element tensors end up with unknown shapes. For example,
  elements popped from a captured TensorList variant would have unknown
  shape.

  Nothing is copied unless `target_t` is a resource or variant tensor and the
  source handle data exists, is set, and has shape-and-type entries. When it
  is copied, the serialized proto is written into `target_t`'s graph with
  `SetHandleShapeAndType`. The per-element shapes, ranks and dtypes are also
  recorded with `TF_GraphSetOutputHandleShapesAndTypes_wrapper`.

  Args:
    source_t: The tensor to copy HandleData from.
    target_t: The tensor to copy HandleData to.
  """
  if not (target_t.dtype == dtypes.resource or
          target_t.dtype == dtypes.variant):
    return

  # pylint: disable=protected-access
  if isinstance(source_t, ops.EagerTensor):
    # Eager tensors keep their handle data on the Python object itself.
    handle_data = source_t._handle_data
  else:
    handle_data = resource_variable_ops.get_resource_handle_data(source_t)

  if not (handle_data is not None
          and handle_data.is_set
          and handle_data.shape_and_type):
    return

  pywrap_tensorflow.SetHandleShapeAndType(target_t.graph._c_graph,
                                          target_t._as_tf_output(),
                                          handle_data.SerializeToString())

  # Ensure that shapes and dtypes are propagated. An unknown-rank shape is
  # encoded as rank -1 with no dimension list.
  entries = handle_data.shape_and_type
  types = tuple(entry.dtype for entry in entries)
  ranks = []
  dims = []
  for entry in entries:
    shape_proto = entry.shape
    if shape_proto.unknown_rank:
      ranks.append(-1)
      dims.append(None)
    else:
      ranks.append(len(shape_proto.dim))
      dims.append([d.size for d in shape_proto.dim])

  pywrap_tensorflow.TF_GraphSetOutputHandleShapesAndTypes_wrapper(
      target_t._op._graph._c_graph, target_t._as_tf_output(), dims, ranks,
      types)
  # pylint: enable=protected-access