示例#1
0
def capture_value(tensor_map, value, dtype, name):
    """Capture a value from outside the function, to pass in as an extra arg.

    On the first capture of `value`, a graph placeholder is created for it and
    cached in `tensor_map`. Later captures of the same tensor reuse that
    placeholder. In both cases the capture is recorded on the tape.

    Args:
      tensor_map: Mapping from `ops.tensor_id(value)` to a
        `(value, placeholder)` pair. A new entry is added when `value` has
        not been captured before.
      value: The tensor from outside the function that is being captured.
      dtype: dtype for a new placeholder. `value.dtype` is used when this is
        falsy.
      name: Name given to a newly created placeholder.

    Returns:
      The placeholder tensor standing in for `value`.
    """
    captured_value = tensor_map.get(ops.tensor_id(value), None)
    if captured_value is None:
        captured_value = graph_placeholder(dtype=dtype or value.dtype,
                                           shape=value.shape,
                                           name=name)
        if captured_value.dtype == dtypes_module.resource:
            # Resource handles carry HandleData describing the shapes/dtypes
            # of their element tensors; copy it onto the placeholder.
            handle_data = value._handle_data  # pylint: disable=protected-access
            captured_value._handle_data = handle_data  # pylint: disable=protected-access
            # `shape_and_type` must be non-empty. `zip(*[])` yields nothing,
            # so unpacking it into `shapes, types` below would otherwise
            # raise ValueError.
            if (handle_data is not None and handle_data.is_set
                    and handle_data.shape_and_type):
                # Ensure that shapes and dtypes are propagated.
                shapes, types = zip(*[(pair.shape, pair.dtype)
                                      for pair in handle_data.shape_and_type])
                # An unknown rank is encoded as rank -1 and shape None.
                ranks = [
                    len(s.dim) if not s.unknown_rank else -1 for s in shapes
                ]
                shapes = [[d.size
                           for d in s.dim] if not s.unknown_rank else None
                          for s in shapes]
                with errors.raise_exception_on_not_ok_status() as status:
                    pywrap_tensorflow.TF_GraphSetOutputHandleShapesAndTypes_wrapper(
                        captured_value._op._graph._c_graph,  # pylint: disable=protected-access
                        captured_value._as_tf_output(),  # pylint: disable=protected-access
                        shapes,
                        ranks,
                        types,
                        status)

        tensor_map[ops.tensor_id(value)] = (value, captured_value)
    else:
        # Cache hit: entries are stored as (value, placeholder) pairs.
        captured_value = captured_value[1]
    # Record the capture on the tape. `lambda x: [x]` returns its input
    # unchanged; presumably this is the backward function, giving an identity
    # gradient (confirm against tape.record_operation).
    tape.record_operation("captured_value", [captured_value], [value],
                          lambda x: [x])
    return captured_value
示例#2
0
def capture_value(tensor_map, value, dtype, name):
    """Capture a value from outside the function, to pass in as an extra arg.

    On the first capture of `value`, a graph placeholder is created for it and
    cached in `tensor_map`. Later captures of the same tensor reuse that
    placeholder. In both cases the capture is recorded on the tape.

    Args:
      tensor_map: Mapping from `ops.tensor_id(value)` to a
        `(value, placeholder)` pair. A new entry is added when `value` has
        not been captured before.
      value: The tensor from outside the function that is being captured.
      dtype: dtype for a new placeholder. `value.dtype` is used when this is
        falsy.
      name: Name given to a newly created placeholder.

    Returns:
      The placeholder tensor standing in for `value`.
    """
    captured_value = tensor_map.get(ops.tensor_id(value), None)
    if captured_value is None:
        captured_value = graph_placeholder(dtype=dtype or value.dtype,
                                           shape=value.shape,
                                           name=name)
        if captured_value.dtype == dtypes_module.resource:
            # Read the source handle's HandleData. With C shapes enabled, an
            # eager tensor exposes it directly, while any other tensor is
            # queried through resource_variable_ops.
            if ops._USE_C_SHAPES:  # pylint: disable=protected-access
                if isinstance(value, ops.EagerTensor):
                    handle_data = value._handle_data  # pylint: disable=protected-access
                else:
                    handle_data = resource_variable_ops.get_resource_handle_data(
                        value)
            else:
                handle_data = value._handle_data  # pylint: disable=protected-access
            if handle_data is not None and handle_data.is_set:
                # pylint: disable=protected-access
                if ops._USE_C_SHAPES:
                    pywrap_tensorflow.TFE_SetResourceHandleShapeAndType(
                        captured_value.graph._c_graph,
                        captured_value._as_tf_output(),
                        handle_data.SerializeToString())
                else:
                    captured_value._handle_data = handle_data
                # pylint: enable=protected-access
                # Only propagate when there is at least one (shape, dtype)
                # pair. `zip(*[])` yields nothing, so unpacking it into
                # `shapes, types` would otherwise raise ValueError.
                if handle_data.shape_and_type:
                    # Ensure that shapes and dtypes are propagated.
                    shapes, types = zip(*[(pair.shape, pair.dtype)
                                          for pair in handle_data.shape_and_type])
                    # An unknown rank is encoded as rank -1 and shape None.
                    ranks = [
                        len(s.dim) if not s.unknown_rank else -1 for s in shapes
                    ]
                    shapes = [[d.size
                               for d in s.dim] if not s.unknown_rank else None
                              for s in shapes]
                    pywrap_tensorflow.TF_GraphSetOutputHandleShapesAndTypes_wrapper(
                        captured_value._op._graph._c_graph,  # pylint: disable=protected-access
                        captured_value._as_tf_output(),  # pylint: disable=protected-access
                        shapes,
                        ranks,
                        types)

        tensor_map[ops.tensor_id(value)] = (value, captured_value)
    else:
        # Cache hit: entries are stored as (value, placeholder) pairs.
        captured_value = captured_value[1]
    # Record the capture on the tape. `lambda x: [x]` returns its input
    # unchanged; presumably this is the backward function, giving an identity
    # gradient (confirm against tape.record_operation).
    tape.record_operation("captured_value", [captured_value], [value],
                          lambda x: [x])
    return captured_value
示例#3
0
def copy_handle_data(source_t, target_t):
    """Copy resource/variant HandleData from `source_t` onto `target_t`.

  A CppShapeInferenceResult::HandleData proto records the shapes and dtypes of
  the element tensors held inside a resource or variant tensor. That
  information has to be carried across function boundaries explicitly, for
  example when a placeholder is captured or a function tensor is returned as
  an output. Without this copy the element tensors end up with unknown
  shapes; for instance, elements popped from a captured TensorList would have
  no known shape.

  Args:
    source_t: Tensor whose HandleData is read.
    target_t: Tensor that receives the HandleData. Nothing is done unless its
      dtype is `resource` or `variant`.
  """
    carries_handle_data = (target_t.dtype == dtypes.resource or
                           target_t.dtype == dtypes.variant)
    if not carries_handle_data:
        return

    # Eager tensors expose HandleData directly; every other tensor is queried
    # through resource_variable_ops.
    if isinstance(source_t, ops.EagerTensor):
        handle_data = source_t._handle_data  # pylint: disable=protected-access
    else:
        handle_data = resource_variable_ops.get_resource_handle_data(source_t)

    # Nothing to copy without a populated proto that lists element info.
    usable = (handle_data is not None and handle_data.is_set and
              handle_data.shape_and_type)
    if not usable:
        return

    # pylint: disable=protected-access
    pywrap_tensorflow.SetHandleShapeAndType(
        target_t.graph._c_graph,
        target_t._as_tf_output(),
        handle_data.SerializeToString())

    # Ensure that shapes and dtypes are propagated. An unknown rank is
    # encoded as rank -1 with a None shape.
    element_shapes = []
    element_ranks = []
    for entry in handle_data.shape_and_type:
        shape_proto = entry.shape
        if shape_proto.unknown_rank:
            element_ranks.append(-1)
            element_shapes.append(None)
        else:
            element_ranks.append(len(shape_proto.dim))
            element_shapes.append([dim.size for dim in shape_proto.dim])
    element_types = tuple(entry.dtype for entry in handle_data.shape_and_type)

    pywrap_tensorflow.TF_GraphSetOutputHandleShapesAndTypes_wrapper(
        target_t._op._graph._c_graph,
        target_t._as_tf_output(),
        element_shapes,
        element_ranks,
        element_types)
    # pylint: enable=protected-access