Ejemplo n.º 1
0
    def _capture_tensor_as_extra_input(self, tensor, name=None):
        """Substitutes an outside `tensor` with a new placeholder input.

        `tensor` is appended to `self.extra_inputs`. The new placeholder is
        appended to `self.inputs` and `self.extra_args` and recorded in
        `self._captured[tensor]`.

        Args:
          tensor: The tensor being captured from outside this graph.
          name: Optional name for the new placeholder.

        Returns:
          The placeholder, wrapped in `array_ops.guarantee_const` when
          `_is_guaranteed_const(tensor)` is true.
        """
        self.extra_inputs.append(tensor)

        # Clearing control dependencies hoists the placeholder out of any
        # control-flow context that is currently active.
        with ops.control_dependencies(None):
            placeholder = array_ops.placeholder(
                tensor.dtype, shape=tensor.get_shape(), name=name)

        # Carry the resource-handle shape/type data over to the placeholder.
        # pylint: disable=protected-access
        if not ops._USE_C_SHAPES:
            placeholder._handle_data = tensor._handle_data
        else:
            if isinstance(tensor, ops.EagerTensor):
                serialized = tensor._handle_data
                if serialized:
                    serialized = serialized.SerializeToString()
            else:
                serialized = c_api.GetResourceHandleShapeAndType(
                    tensor.graph._c_graph, tensor._as_tf_output())
            if serialized:
                c_api.SetResourceHandleShapeAndType(
                    placeholder.graph._c_graph,
                    placeholder._as_tf_output(),
                    compat.as_bytes(serialized))
        # pylint: enable=protected-access

        self.inputs.append(placeholder)
        self._captured[tensor] = placeholder
        self.extra_args.append(placeholder)

        if not _is_guaranteed_const(tensor):
            return placeholder
        with ops.control_dependencies(None):
            return array_ops.guarantee_const(placeholder)
Ejemplo n.º 2
0
def _copy_resource_handle_data(source, target):
    """Copy resource-handle shape/dtype data from `source` onto `target`.

    Does nothing when `source` has no handle data (None, or `is_set` false).

    Args:
      source: The captured outside tensor.
      target: The resource-dtype placeholder standing in for `source`.
    """
    # pylint: disable=protected-access
    if ops._USE_C_SHAPES and not isinstance(source, ops.EagerTensor):
        handle_data = resource_variable_ops.get_resource_handle_data(source)
    else:
        handle_data = source._handle_data
    if handle_data is None or not handle_data.is_set:
        return

    if ops._USE_C_SHAPES:
        pywrap_tensorflow.SetResourceHandleShapeAndType(
            target.graph._c_graph,
            target._as_tf_output(),
            handle_data.SerializeToString())
    else:
        target._handle_data = handle_data

    # Ensure that shapes and dtypes are propagated.
    pairs = list(handle_data.shape_and_type)
    if not pairs:
        # Previously `shapes, types = zip(*[])` raised ValueError here; with
        # no entries there are no shapes or dtypes to propagate.
        return
    shape_protos = [pair.shape for pair in pairs]
    types = tuple(pair.dtype for pair in pairs)
    # A rank of -1 (and dims of None) encodes an unknown-rank shape.
    ranks = [-1 if s.unknown_rank else len(s.dim) for s in shape_protos]
    dims = [None if s.unknown_rank else [d.size for d in s.dim]
            for s in shape_protos]
    pywrap_tensorflow.TF_GraphSetOutputHandleShapesAndTypes_wrapper(
        target._op._graph._c_graph,
        target._as_tf_output(),
        dims,
        ranks,
        types)
    # pylint: enable=protected-access


def capture_value(tensor_map, value, dtype, name):
    """Capture a value from outside the function, to pass in as an extra arg.

    Looks up `value` in `tensor_map` under `ops.tensor_id(value)`.

    On a miss, it creates a graph placeholder with `dtype` (falling back to
    `value.dtype`), `value.shape` and `name`. For resource-dtype placeholders,
    it also copies the handle shape/dtype data of `value` onto the placeholder.
    It then caches `(value, placeholder)` in `tensor_map`.

    On a hit, it reuses the cached placeholder.

    Either way, a "captured_value" operation from `value` to the placeholder
    is recorded via `tape.record_operation`. Its function returns its input
    unchanged.

    Args:
      tensor_map: Dict mapping `ops.tensor_id(value)` to a
        `(value, placeholder)` tuple; updated in place on a miss.
      value: The outside tensor being captured.
      dtype: Optional dtype for the placeholder; if falsy, `value.dtype` is
        used.
      name: Name for a newly created placeholder.

    Returns:
      The placeholder standing in for `value`.
    """
    key = ops.tensor_id(value)
    cached = tensor_map.get(key)
    if cached is not None:
        captured_value = cached[1]
    else:
        captured_value = graph_placeholder(dtype=dtype or value.dtype,
                                           shape=value.shape,
                                           name=name)
        if captured_value.dtype == dtypes_module.resource:
            _copy_resource_handle_data(value, captured_value)
        tensor_map[key] = (value, captured_value)
    tape.record_operation("captured_value", [captured_value], [value],
                          lambda x: [x])
    return captured_value