def _capture_tensor_as_extra_input(self, tensor, name=None):
    """Swap a captured tensor for a placeholder registered as an extra input.

    The captured `tensor` is appended to `self.extra_inputs`, and a new
    placeholder with the same dtype and static shape takes its place. The
    placeholder is appended to `self.inputs` and `self.extra_args`, and the
    `tensor -> placeholder` association is stored in `self._captured`.

    Args:
      tensor: The tensor being captured.
      name: Optional name forwarded to `array_ops.placeholder`.

    Returns:
      The placeholder, or `array_ops.guarantee_const(placeholder)` when
      `_is_guaranteed_const(tensor)` is true.
    """
    self.extra_inputs.append(tensor)

    # Hoist the new input placeholder out of any control flow context
    # we're currently in.
    with ops.control_dependencies(None):
        placeholder = array_ops.placeholder(
            tensor.dtype, shape=tensor.get_shape(), name=name)

    # Carry the resource handle data over from the captured tensor to the
    # placeholder; the mechanism depends on ops._USE_C_SHAPES.
    # pylint: disable=protected-access
    if not ops._USE_C_SHAPES:
        placeholder._handle_data = tensor._handle_data
    else:
        if isinstance(tensor, ops.EagerTensor):
            serialized = tensor._handle_data
            if serialized:
                serialized = serialized.SerializeToString()
        else:
            serialized = c_api.GetResourceHandleShapeAndType(
                tensor.graph._c_graph, tensor._as_tf_output())
        # Only write handle data when the source actually had some.
        if serialized:
            c_api.SetResourceHandleShapeAndType(
                placeholder.graph._c_graph,
                placeholder._as_tf_output(),
                compat.as_bytes(serialized))
    # pylint: enable=protected-access

    self.inputs.append(placeholder)
    self._captured[tensor] = placeholder
    self.extra_args.append(placeholder)

    if not _is_guaranteed_const(tensor):
        return placeholder
    with ops.control_dependencies(None):
        return array_ops.guarantee_const(placeholder)
def capture_value(tensor_map, value, dtype, name):
    """Capture a value from outside the function, to pass in as an extra arg.

    `tensor_map` acts as a cache keyed by `ops.tensor_id(value)`; each entry
    is a `(value, placeholder)` tuple. On a cache miss a new placeholder is
    created via `graph_placeholder` and, for resource-dtype placeholders,
    the handle data of `value` is copied onto it.

    Args:
      tensor_map: Mutable mapping from tensor id to `(value, placeholder)`.
        Updated in place on a cache miss.
      value: The outside value being captured.
      dtype: Dtype for the placeholder; falls back to `value.dtype` when
        falsy.
      name: Name forwarded to `graph_placeholder`.

    Returns:
      The placeholder standing in for `value`.
    """
    captured_value = tensor_map.get(ops.tensor_id(value), None)
    if captured_value is None:
        captured_value = graph_placeholder(
            dtype=dtype or value.dtype, shape=value.shape, name=name)
        if captured_value.dtype == dtypes_module.resource:
            if ops._USE_C_SHAPES:  # pylint: disable=protected-access
                if isinstance(value, ops.EagerTensor):
                    handle_data = value._handle_data  # pylint: disable=protected-access
                else:
                    handle_data = (
                        resource_variable_ops.get_resource_handle_data(value))
            else:
                handle_data = value._handle_data  # pylint: disable=protected-access
            if handle_data is not None and handle_data.is_set:
                # pylint: disable=protected-access
                if ops._USE_C_SHAPES:
                    pywrap_tensorflow.SetResourceHandleShapeAndType(
                        captured_value.graph._c_graph,
                        captured_value._as_tf_output(),
                        handle_data.SerializeToString())
                else:
                    captured_value._handle_data = handle_data
                # pylint: enable=protected-access
                # Ensure that shapes and dtypes are propagated.
                # Unpacking zip(*[]) raises ValueError, so skip propagation
                # when the handle data carries no shape/type pairs.
                pairs = list(handle_data.shape_and_type)
                if pairs:
                    shapes, types = zip(
                        *[(pair.shape, pair.dtype) for pair in pairs])
                    # Unknown rank is encoded as rank -1 with a None shape.
                    ranks = [
                        len(s.dim) if not s.unknown_rank else -1
                        for s in shapes
                    ]
                    shapes = [
                        [d.size for d in s.dim]
                        if not s.unknown_rank else None
                        for s in shapes
                    ]
                    pywrap_tensorflow.TF_GraphSetOutputHandleShapesAndTypes_wrapper(
                        captured_value._op._graph._c_graph,  # pylint: disable=protected-access
                        captured_value._as_tf_output(),  # pylint: disable=protected-access
                        shapes, ranks, types)
        tensor_map[ops.tensor_id(value)] = (value, captured_value)
    else:
        # Cache hit: entries are (value, placeholder) tuples.
        captured_value = captured_value[1]
    # Record the capture on the tape on every call, cached or not.
    tape.record_operation("captured_value", [captured_value], [value],
                          lambda x: [x])
    return captured_value